mirror of
				https://github.com/optim-enterprises-bv/kubernetes.git
				synced 2025-11-03 19:58:17 +00:00 
			
		
		
		
	Add protocol versions to pkg/util/wsstream
This commit is contained in:
		
				
					committed by
					
						
						Dr. Stefan Schimanski
					
				
			
			
				
	
			
			
			
						parent
						
							7b3c08d7d3
						
					
				
				
					commit
					ce7f003f57
				
			@@ -451,7 +451,7 @@ func write(statusCode int, gv unversioned.GroupVersion, s runtime.NegotiatedSeri
 | 
			
		||||
		defer out.Close()
 | 
			
		||||
 | 
			
		||||
		if wsstream.IsWebSocketRequest(req) {
 | 
			
		||||
			r := wsstream.NewReader(out, true)
 | 
			
		||||
			r := wsstream.NewReader(out, true, wsstream.NewDefaultReaderProtocols())
 | 
			
		||||
			if err := r.Copy(w, req); err != nil {
 | 
			
		||||
				utilruntime.HandleError(fmt.Errorf("error encountered while streaming results via websocket: %v", err))
 | 
			
		||||
			}
 | 
			
		||||
 
 | 
			
		||||
@@ -27,6 +27,7 @@ import (
 | 
			
		||||
 | 
			
		||||
	"github.com/golang/glog"
 | 
			
		||||
	"golang.org/x/net/websocket"
 | 
			
		||||
 | 
			
		||||
	"k8s.io/kubernetes/pkg/util/runtime"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -44,7 +45,7 @@ import (
 | 
			
		||||
//    READ  []byte{1, 10}                # receive "\n" on channel 1 (STDOUT)
 | 
			
		||||
//    CLOSE
 | 
			
		||||
//
 | 
			
		||||
const channelWebSocketProtocol = "channel.k8s.io"
 | 
			
		||||
const ChannelWebSocketProtocol = "channel.k8s.io"
 | 
			
		||||
 | 
			
		||||
// The Websocket subprotocol "base64.channel.k8s.io" base64 encodes each message with a character
 | 
			
		||||
// indicating the channel number (zero indexed) the message was sent on. Messages in both directions
 | 
			
		||||
@@ -60,7 +61,7 @@ const channelWebSocketProtocol = "channel.k8s.io"
 | 
			
		||||
//    READ  []byte{49, 67, 103, 61, 61} # receive "\n" (base64: "Cg==") on channel '1' (STDOUT)
 | 
			
		||||
//    CLOSE
 | 
			
		||||
//
 | 
			
		||||
const base64ChannelWebSocketProtocol = "base64.channel.k8s.io"
 | 
			
		||||
const Base64ChannelWebSocketProtocol = "base64.channel.k8s.io"
 | 
			
		||||
 | 
			
		||||
type codecType int
 | 
			
		||||
 | 
			
		||||
@@ -107,8 +108,9 @@ func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) {
 | 
			
		||||
func handshake(config *websocket.Config, req *http.Request, allowed []string) error {
 | 
			
		||||
	protocols := config.Protocol
 | 
			
		||||
	if len(protocols) == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
		protocols = []string{""}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, protocol := range protocols {
 | 
			
		||||
		for _, allow := range allowed {
 | 
			
		||||
			if allow == protocol {
 | 
			
		||||
@@ -117,12 +119,31 @@ func handshake(config *websocket.Config, req *http.Request, allowed []string) er
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fmt.Errorf("requested protocol(s) are not supported: %v; supports %v", config.Protocol, allowed)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ChannelProtocolConfig describes a websocket subprotocol with channels.
 | 
			
		||||
type ChannelProtocolConfig struct {
 | 
			
		||||
	Binary   bool
 | 
			
		||||
	Channels []ChannelType
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewDefaultChannelProtocols returns a channel protocol map with the
 | 
			
		||||
// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io" and the given
 | 
			
		||||
// channels.
 | 
			
		||||
func NewDefaultChannelProtocols(channels []ChannelType) map[string]ChannelProtocolConfig {
 | 
			
		||||
	return map[string]ChannelProtocolConfig{
 | 
			
		||||
		"": {Binary: true, Channels: channels},
 | 
			
		||||
		ChannelWebSocketProtocol:       {Binary: true, Channels: channels},
 | 
			
		||||
		Base64ChannelWebSocketProtocol: {Binary: false, Channels: channels},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Conn supports sending multiple binary channels over a websocket connection.
 | 
			
		||||
// Supports only the "channel.k8s.io" subprotocol.
 | 
			
		||||
type Conn struct {
 | 
			
		||||
	protocols        map[string]ChannelProtocolConfig
 | 
			
		||||
	selectedProtocol string
 | 
			
		||||
	channels         []*websocketChannel
 | 
			
		||||
	codec            codecType
 | 
			
		||||
	ready            chan struct{}
 | 
			
		||||
@@ -134,24 +155,14 @@ type Conn struct {
 | 
			
		||||
// web socket message with a single byte indicating the channel number (0-N). 255 is reserved for
 | 
			
		||||
// future use. The channel types for each channel are passed as an array, supporting the different
 | 
			
		||||
// duplex modes. Read and Write refer to whether the channel can be used as a Reader or Writer.
 | 
			
		||||
func NewConn(channels ...ChannelType) *Conn {
 | 
			
		||||
	conn := &Conn{
 | 
			
		||||
//
 | 
			
		||||
// The protocols parameter maps subprotocol names to ChannelProtocols. The empty string subprotocol
 | 
			
		||||
// name is used if websocket.Config.Protocol is empty.
 | 
			
		||||
func NewConn(protocols map[string]ChannelProtocolConfig) *Conn {
 | 
			
		||||
	return &Conn{
 | 
			
		||||
		ready:     make(chan struct{}),
 | 
			
		||||
		channels: make([]*websocketChannel, len(channels)),
 | 
			
		||||
		protocols: protocols,
 | 
			
		||||
	}
 | 
			
		||||
	for i := range conn.channels {
 | 
			
		||||
		switch channels[i] {
 | 
			
		||||
		case ReadChannel:
 | 
			
		||||
			conn.channels[i] = newWebsocketChannel(conn, byte(i), true, false)
 | 
			
		||||
		case WriteChannel:
 | 
			
		||||
			conn.channels[i] = newWebsocketChannel(conn, byte(i), false, true)
 | 
			
		||||
		case ReadWriteChannel:
 | 
			
		||||
			conn.channels[i] = newWebsocketChannel(conn, byte(i), true, true)
 | 
			
		||||
		case IgnoreChannel:
 | 
			
		||||
			conn.channels[i] = newWebsocketChannel(conn, byte(i), false, false)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return conn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetIdleTimeout sets the interval for both reads and writes before timeout. If not specified,
 | 
			
		||||
@@ -160,8 +171,9 @@ func (conn *Conn) SetIdleTimeout(duration time.Duration) {
 | 
			
		||||
	conn.timeout = duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Open the connection and create channels for reading and writing.
 | 
			
		||||
func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) ([]io.ReadWriteCloser, error) {
 | 
			
		||||
// Open the connection and create channels for reading and writing. It returns
 | 
			
		||||
// the selected subprotocol, a slice of channels and an error.
 | 
			
		||||
func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) (string, []io.ReadWriteCloser, error) {
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer runtime.HandleCrash()
 | 
			
		||||
		defer conn.Close()
 | 
			
		||||
@@ -172,23 +184,42 @@ func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) ([]io.ReadWrite
 | 
			
		||||
	for i := range conn.channels {
 | 
			
		||||
		rwc[i] = conn.channels[i]
 | 
			
		||||
	}
 | 
			
		||||
	return rwc, nil
 | 
			
		||||
	return conn.selectedProtocol, rwc, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (conn *Conn) initialize(ws *websocket.Conn) {
 | 
			
		||||
	protocols := ws.Config().Protocol
 | 
			
		||||
	switch {
 | 
			
		||||
	case len(protocols) == 0, protocols[0] == channelWebSocketProtocol:
 | 
			
		||||
	negotiated := ws.Config().Protocol
 | 
			
		||||
	conn.selectedProtocol = negotiated[0]
 | 
			
		||||
	p := conn.protocols[conn.selectedProtocol]
 | 
			
		||||
	if p.Binary {
 | 
			
		||||
		conn.codec = rawCodec
 | 
			
		||||
	case protocols[0] == base64ChannelWebSocketProtocol:
 | 
			
		||||
	} else {
 | 
			
		||||
		conn.codec = base64Codec
 | 
			
		||||
	}
 | 
			
		||||
	conn.ws = ws
 | 
			
		||||
	conn.channels = make([]*websocketChannel, len(p.Channels))
 | 
			
		||||
	for i, t := range p.Channels {
 | 
			
		||||
		switch t {
 | 
			
		||||
		case ReadChannel:
 | 
			
		||||
			conn.channels[i] = newWebsocketChannel(conn, byte(i), true, false)
 | 
			
		||||
		case WriteChannel:
 | 
			
		||||
			conn.channels[i] = newWebsocketChannel(conn, byte(i), false, true)
 | 
			
		||||
		case ReadWriteChannel:
 | 
			
		||||
			conn.channels[i] = newWebsocketChannel(conn, byte(i), true, true)
 | 
			
		||||
		case IgnoreChannel:
 | 
			
		||||
			conn.channels[i] = newWebsocketChannel(conn, byte(i), false, false)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	close(conn.ready)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (conn *Conn) handshake(config *websocket.Config, req *http.Request) error {
 | 
			
		||||
	return handshake(config, req, []string{channelWebSocketProtocol, base64ChannelWebSocketProtocol})
 | 
			
		||||
	supportedProtocols := make([]string, 0, len(conn.protocols))
 | 
			
		||||
	for p := range conn.protocols {
 | 
			
		||||
		supportedProtocols = append(supportedProtocols, p)
 | 
			
		||||
	}
 | 
			
		||||
	return handshake(config, req, supportedProtocols)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (conn *Conn) resetTimeout() {
 | 
			
		||||
 
 | 
			
		||||
@@ -20,6 +20,7 @@ import (
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"io"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/http/httptest"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"sync"
 | 
			
		||||
@@ -28,15 +29,19 @@ import (
 | 
			
		||||
	"golang.org/x/net/websocket"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func newServer(handler websocket.Handler) (*httptest.Server, string) {
 | 
			
		||||
func newServer(handler http.Handler) (*httptest.Server, string) {
 | 
			
		||||
	server := httptest.NewServer(handler)
 | 
			
		||||
	serverAddr := server.Listener.Addr().String()
 | 
			
		||||
	return server, serverAddr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRawConn(t *testing.T) {
 | 
			
		||||
	conn := NewConn(ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel)
 | 
			
		||||
	s, addr := newServer(conn.handle)
 | 
			
		||||
	channels := []ChannelType{ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel}
 | 
			
		||||
	conn := NewConn(NewDefaultChannelProtocols(channels))
 | 
			
		||||
 | 
			
		||||
	s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
		conn.Open(w, req)
 | 
			
		||||
	}))
 | 
			
		||||
	defer s.Close()
 | 
			
		||||
 | 
			
		||||
	client, err := websocket.Dial("ws://"+addr, "", "http://localhost/")
 | 
			
		||||
@@ -112,8 +117,10 @@ func TestRawConn(t *testing.T) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestBase64Conn(t *testing.T) {
 | 
			
		||||
	conn := NewConn(ReadWriteChannel, ReadWriteChannel)
 | 
			
		||||
	s, addr := newServer(conn.handle)
 | 
			
		||||
	conn := NewConn(NewDefaultChannelProtocols([]ChannelType{ReadWriteChannel, ReadWriteChannel}))
 | 
			
		||||
	s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
		conn.Open(w, req)
 | 
			
		||||
	}))
 | 
			
		||||
	defer s.Close()
 | 
			
		||||
 | 
			
		||||
	config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
 | 
			
		||||
@@ -167,3 +174,99 @@ func TestBase64Conn(t *testing.T) {
 | 
			
		||||
	client.Close()
 | 
			
		||||
	wg.Wait()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type versionTest struct {
 | 
			
		||||
	supported map[string]bool // protocol -> binary
 | 
			
		||||
	requested []string
 | 
			
		||||
	error     bool
 | 
			
		||||
	expected  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func versionTests() []versionTest {
 | 
			
		||||
	const (
 | 
			
		||||
		binary = true
 | 
			
		||||
		base64 = false
 | 
			
		||||
	)
 | 
			
		||||
	return []versionTest{
 | 
			
		||||
		{
 | 
			
		||||
			supported: nil,
 | 
			
		||||
			requested: []string{"raw"},
 | 
			
		||||
			error:     true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
 | 
			
		||||
			requested: nil,
 | 
			
		||||
			expected:  "",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
 | 
			
		||||
			requested: []string{"v1.raw"},
 | 
			
		||||
			error:     true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
 | 
			
		||||
			requested: []string{"v1.raw", "v1.base64"},
 | 
			
		||||
			error:     true,
 | 
			
		||||
		}, {
 | 
			
		||||
			supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
 | 
			
		||||
			requested: []string{"v1.raw", "raw"},
 | 
			
		||||
			expected:  "raw",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
 | 
			
		||||
			requested: []string{"v1.raw"},
 | 
			
		||||
			expected:  "v1.raw",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
 | 
			
		||||
			requested: []string{"v2.base64"},
 | 
			
		||||
			expected:  "v2.base64",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestVersionedConn(t *testing.T) {
 | 
			
		||||
	for i, test := range versionTests() {
 | 
			
		||||
		func() {
 | 
			
		||||
			supportedProtocols := map[string]ChannelProtocolConfig{}
 | 
			
		||||
			for p, binary := range test.supported {
 | 
			
		||||
				supportedProtocols[p] = ChannelProtocolConfig{
 | 
			
		||||
					Binary:   binary,
 | 
			
		||||
					Channels: []ChannelType{ReadWriteChannel},
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			conn := NewConn(supportedProtocols)
 | 
			
		||||
			// note that it's not enough to wait for conn.ready to avoid a race here. Hence,
 | 
			
		||||
			// we use a channel.
 | 
			
		||||
			selectedProtocol := make(chan string, 0)
 | 
			
		||||
			s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
				p, _, _ := conn.Open(w, req)
 | 
			
		||||
				selectedProtocol <- p
 | 
			
		||||
			}))
 | 
			
		||||
			defer s.Close()
 | 
			
		||||
 | 
			
		||||
			config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Fatal(err)
 | 
			
		||||
			}
 | 
			
		||||
			config.Protocol = test.requested
 | 
			
		||||
			client, err := websocket.DialConfig(config)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				if !test.error {
 | 
			
		||||
					t.Fatalf("test %d: didn't expect error: %v", i, err)
 | 
			
		||||
				} else {
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			defer client.Close()
 | 
			
		||||
			if test.error && err == nil {
 | 
			
		||||
				t.Fatalf("test %d: expected an error", i)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			<-conn.ready
 | 
			
		||||
			if got, expected := <-selectedProtocol, test.expected; got != expected {
 | 
			
		||||
				t.Fatalf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -23,6 +23,7 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"golang.org/x/net/websocket"
 | 
			
		||||
 | 
			
		||||
	"k8s.io/kubernetes/pkg/util/runtime"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -37,23 +38,46 @@ const binaryWebSocketProtocol = "binary.k8s.io"
 | 
			
		||||
// possible.
 | 
			
		||||
const base64BinaryWebSocketProtocol = "base64.binary.k8s.io"
 | 
			
		||||
 | 
			
		||||
// ReaderProtocolConfig describes a websocket subprotocol with one stream.
 | 
			
		||||
type ReaderProtocolConfig struct {
 | 
			
		||||
	Binary bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewDefaultReaderProtocols returns a stream protocol map with the
 | 
			
		||||
// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io".
 | 
			
		||||
func NewDefaultReaderProtocols() map[string]ReaderProtocolConfig {
 | 
			
		||||
	return map[string]ReaderProtocolConfig{
 | 
			
		||||
		"": {Binary: true},
 | 
			
		||||
		binaryWebSocketProtocol:       {Binary: true},
 | 
			
		||||
		base64BinaryWebSocketProtocol: {Binary: false},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Reader supports returning an arbitrary byte stream over a websocket channel.
 | 
			
		||||
// Supports the "binary.k8s.io" and "base64.binary.k8s.io" subprotocols.
 | 
			
		||||
type Reader struct {
 | 
			
		||||
	err              chan error
 | 
			
		||||
	r                io.Reader
 | 
			
		||||
	ping             bool
 | 
			
		||||
	timeout          time.Duration
 | 
			
		||||
	protocols        map[string]ReaderProtocolConfig
 | 
			
		||||
	selectedProtocol string
 | 
			
		||||
 | 
			
		||||
	handleCrash func() // overridable for testing
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewReader creates a WebSocket pipe that will copy the contents of r to a provided
 | 
			
		||||
// WebSocket connection. If ping is true, a zero length message will be sent to the client
 | 
			
		||||
// before the stream begins reading.
 | 
			
		||||
func NewReader(r io.Reader, ping bool) *Reader {
 | 
			
		||||
//
 | 
			
		||||
// The protocols parameter maps subprotocol names to StreamProtocols. The empty string
 | 
			
		||||
// subprotocol name is used if websocket.Config.Protocol is empty.
 | 
			
		||||
func NewReader(r io.Reader, ping bool, protocols map[string]ReaderProtocolConfig) *Reader {
 | 
			
		||||
	return &Reader{
 | 
			
		||||
		r:           r,
 | 
			
		||||
		err:         make(chan error),
 | 
			
		||||
		ping:        ping,
 | 
			
		||||
		protocols:   protocols,
 | 
			
		||||
		handleCrash: func() { runtime.HandleCrash() },
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -64,14 +88,18 @@ func (r *Reader) SetIdleTimeout(duration time.Duration) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Reader) handshake(config *websocket.Config, req *http.Request) error {
 | 
			
		||||
	return handshake(config, req, []string{binaryWebSocketProtocol, base64BinaryWebSocketProtocol})
 | 
			
		||||
	supportedProtocols := make([]string, 0, len(r.protocols))
 | 
			
		||||
	for p := range r.protocols {
 | 
			
		||||
		supportedProtocols = append(supportedProtocols, p)
 | 
			
		||||
	}
 | 
			
		||||
	return handshake(config, req, supportedProtocols)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Copy the reader to the response. The created WebSocket is closed after this
 | 
			
		||||
// method completes.
 | 
			
		||||
func (r *Reader) Copy(w http.ResponseWriter, req *http.Request) error {
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer runtime.HandleCrash()
 | 
			
		||||
		defer r.handleCrash()
 | 
			
		||||
		websocket.Server{Handshake: r.handshake, Handler: r.handle}.ServeHTTP(w, req)
 | 
			
		||||
	}()
 | 
			
		||||
	return <-r.err
 | 
			
		||||
@@ -79,11 +107,12 @@ func (r *Reader) Copy(w http.ResponseWriter, req *http.Request) error {
 | 
			
		||||
 | 
			
		||||
// handle implements a WebSocket handler.
 | 
			
		||||
func (r *Reader) handle(ws *websocket.Conn) {
 | 
			
		||||
	encode := len(ws.Config().Protocol) > 0 && ws.Config().Protocol[0] == base64BinaryWebSocketProtocol
 | 
			
		||||
	negotiated := ws.Config().Protocol
 | 
			
		||||
	r.selectedProtocol = negotiated[0]
 | 
			
		||||
	defer close(r.err)
 | 
			
		||||
	defer ws.Close()
 | 
			
		||||
	go IgnoreReceives(ws, r.timeout)
 | 
			
		||||
	r.err <- messageCopy(ws, r.r, encode, r.ping, r.timeout)
 | 
			
		||||
	r.err <- messageCopy(ws, r.r, !r.protocols[r.selectedProtocol].Binary, r.ping, r.timeout)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func resetTimeout(ws *websocket.Conn, timeout time.Duration) {
 | 
			
		||||
 
 | 
			
		||||
@@ -22,6 +22,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
@@ -32,7 +33,7 @@ import (
 | 
			
		||||
 | 
			
		||||
func TestStream(t *testing.T) {
 | 
			
		||||
	input := "some random text"
 | 
			
		||||
	r := NewReader(bytes.NewBuffer([]byte(input)), true)
 | 
			
		||||
	r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
 | 
			
		||||
	r.SetIdleTimeout(time.Second)
 | 
			
		||||
	data, err := readWebSocket(r, t, nil)
 | 
			
		||||
	if !reflect.DeepEqual(data, []byte(input)) {
 | 
			
		||||
@@ -45,7 +46,7 @@ func TestStream(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
func TestStreamPing(t *testing.T) {
 | 
			
		||||
	input := "some random text"
 | 
			
		||||
	r := NewReader(bytes.NewBuffer([]byte(input)), true)
 | 
			
		||||
	r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
 | 
			
		||||
	r.SetIdleTimeout(time.Second)
 | 
			
		||||
	err := expectWebSocketFrames(r, t, nil, [][]byte{
 | 
			
		||||
		{},
 | 
			
		||||
@@ -59,8 +60,8 @@ func TestStreamPing(t *testing.T) {
 | 
			
		||||
func TestStreamBase64(t *testing.T) {
 | 
			
		||||
	input := "some random text"
 | 
			
		||||
	encoded := base64.StdEncoding.EncodeToString([]byte(input))
 | 
			
		||||
	r := NewReader(bytes.NewBuffer([]byte(input)), true)
 | 
			
		||||
	data, err := readWebSocket(r, t, nil, base64BinaryWebSocketProtocol)
 | 
			
		||||
	r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
 | 
			
		||||
	data, err := readWebSocket(r, t, nil, "base64.binary.k8s.io")
 | 
			
		||||
	if !reflect.DeepEqual(data, []byte(encoded)) {
 | 
			
		||||
		t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
 | 
			
		||||
	}
 | 
			
		||||
@@ -69,6 +70,73 @@ func TestStreamBase64(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestStreamVersionedBase64(t *testing.T) {
 | 
			
		||||
	input := "some random text"
 | 
			
		||||
	encoded := base64.StdEncoding.EncodeToString([]byte(input))
 | 
			
		||||
	r := NewReader(bytes.NewBuffer([]byte(input)), true, map[string]ReaderProtocolConfig{
 | 
			
		||||
		"":                        {Binary: true},
 | 
			
		||||
		"binary.k8s.io":           {Binary: true},
 | 
			
		||||
		"base64.binary.k8s.io":    {Binary: false},
 | 
			
		||||
		"v1.binary.k8s.io":        {Binary: true},
 | 
			
		||||
		"v1.base64.binary.k8s.io": {Binary: false},
 | 
			
		||||
		"v2.binary.k8s.io":        {Binary: true},
 | 
			
		||||
		"v2.base64.binary.k8s.io": {Binary: false},
 | 
			
		||||
	})
 | 
			
		||||
	data, err := readWebSocket(r, t, nil, "v2.base64.binary.k8s.io")
 | 
			
		||||
	if !reflect.DeepEqual(data, []byte(encoded)) {
 | 
			
		||||
		t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestStreamVersionedCopy(t *testing.T) {
 | 
			
		||||
	for i, test := range versionTests() {
 | 
			
		||||
		func() {
 | 
			
		||||
			supportedProtocols := map[string]ReaderProtocolConfig{}
 | 
			
		||||
			for p, binary := range test.supported {
 | 
			
		||||
				supportedProtocols[p] = ReaderProtocolConfig{
 | 
			
		||||
					Binary: binary,
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			input := "some random text"
 | 
			
		||||
			r := NewReader(bytes.NewBuffer([]byte(input)), true, supportedProtocols)
 | 
			
		||||
			s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
				err := r.Copy(w, req)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					w.WriteHeader(503)
 | 
			
		||||
				}
 | 
			
		||||
			}))
 | 
			
		||||
			defer s.Close()
 | 
			
		||||
 | 
			
		||||
			config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Error(err)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			config.Protocol = test.requested
 | 
			
		||||
			client, err := websocket.DialConfig(config)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				if !test.error {
 | 
			
		||||
					t.Errorf("test %d: didn't expect error: %v", i, err)
 | 
			
		||||
				}
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			defer client.Close()
 | 
			
		||||
			if test.error && err == nil {
 | 
			
		||||
				t.Errorf("test %d: expected an error", i)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			<-r.err
 | 
			
		||||
			if got, expected := r.selectedProtocol, test.expected; got != expected {
 | 
			
		||||
				t.Errorf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestStreamError(t *testing.T) {
 | 
			
		||||
	input := "some random text"
 | 
			
		||||
	errs := &errorReader{
 | 
			
		||||
@@ -78,7 +146,7 @@ func TestStreamError(t *testing.T) {
 | 
			
		||||
		},
 | 
			
		||||
		err: fmt.Errorf("bad read"),
 | 
			
		||||
	}
 | 
			
		||||
	r := NewReader(errs, false)
 | 
			
		||||
	r := NewReader(errs, false, NewDefaultReaderProtocols())
 | 
			
		||||
 | 
			
		||||
	data, err := readWebSocket(r, t, nil)
 | 
			
		||||
	if !reflect.DeepEqual(data, []byte(input)) {
 | 
			
		||||
@@ -98,7 +166,10 @@ func TestStreamSurvivesPanic(t *testing.T) {
 | 
			
		||||
		},
 | 
			
		||||
		panicMessage: "bad read",
 | 
			
		||||
	}
 | 
			
		||||
	r := NewReader(errs, false)
 | 
			
		||||
	r := NewReader(errs, false, NewDefaultReaderProtocols())
 | 
			
		||||
 | 
			
		||||
	// do not call runtime.HandleCrash() in handler. Otherwise, the tests are interrupted.
 | 
			
		||||
	r.handleCrash = func() { recover() }
 | 
			
		||||
 | 
			
		||||
	data, err := readWebSocket(r, t, nil)
 | 
			
		||||
	if !reflect.DeepEqual(data, []byte(input)) {
 | 
			
		||||
@@ -121,7 +192,7 @@ func TestStreamClosedDuringRead(t *testing.T) {
 | 
			
		||||
			err:   fmt.Errorf("stuff"),
 | 
			
		||||
			pause: ch,
 | 
			
		||||
		}
 | 
			
		||||
		r := NewReader(errs, false)
 | 
			
		||||
		r := NewReader(errs, false, NewDefaultReaderProtocols())
 | 
			
		||||
 | 
			
		||||
		data, err := readWebSocket(r, t, func(c *websocket.Conn) {
 | 
			
		||||
			c.Close()
 | 
			
		||||
@@ -163,19 +234,13 @@ func (r *errorReader) Read(p []byte) (int, error) {
 | 
			
		||||
 | 
			
		||||
func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols ...string) ([]byte, error) {
 | 
			
		||||
	errCh := make(chan error, 1)
 | 
			
		||||
	s, addr := newServer(func(ws *websocket.Conn) {
 | 
			
		||||
		cfg := ws.Config()
 | 
			
		||||
		cfg.Protocol = protocols
 | 
			
		||||
		go IgnoreReceives(ws, 0)
 | 
			
		||||
		go func() {
 | 
			
		||||
			err := <-r.err
 | 
			
		||||
			errCh <- err
 | 
			
		||||
		}()
 | 
			
		||||
		r.handle(ws)
 | 
			
		||||
	})
 | 
			
		||||
	s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
		errCh <- r.Copy(w, req)
 | 
			
		||||
	}))
 | 
			
		||||
	defer s.Close()
 | 
			
		||||
 | 
			
		||||
	config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
 | 
			
		||||
	config.Protocol = protocols
 | 
			
		||||
	client, err := websocket.DialConfig(config)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@@ -195,19 +260,13 @@ func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols
 | 
			
		||||
 | 
			
		||||
func expectWebSocketFrames(r *Reader, t *testing.T, fn func(*websocket.Conn), frames [][]byte, protocols ...string) error {
 | 
			
		||||
	errCh := make(chan error, 1)
 | 
			
		||||
	s, addr := newServer(func(ws *websocket.Conn) {
 | 
			
		||||
		cfg := ws.Config()
 | 
			
		||||
		cfg.Protocol = protocols
 | 
			
		||||
		go IgnoreReceives(ws, 0)
 | 
			
		||||
		go func() {
 | 
			
		||||
			err := <-r.err
 | 
			
		||||
			errCh <- err
 | 
			
		||||
		}()
 | 
			
		||||
		r.handle(ws)
 | 
			
		||||
	})
 | 
			
		||||
	s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
		errCh <- r.Copy(w, req)
 | 
			
		||||
	}))
 | 
			
		||||
	defer s.Close()
 | 
			
		||||
 | 
			
		||||
	config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
 | 
			
		||||
	config.Protocol = protocols
 | 
			
		||||
	ws, err := websocket.DialConfig(config)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user