mirror of
				https://github.com/optim-enterprises-bv/kubernetes.git
				synced 2025-11-04 04:08:16 +00:00 
			
		
		
		
	Merge pull request #123542 from liggitt/websocket-round-tripper-protocol
Use the websocket protocol header, verify selected protocol
This commit is contained in:
		@@ -18,6 +18,7 @@ package websocket
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
@@ -25,6 +26,7 @@ import (
 | 
			
		||||
	gwebsocket "github.com/gorilla/websocket"
 | 
			
		||||
 | 
			
		||||
	"k8s.io/apimachinery/pkg/util/httpstream"
 | 
			
		||||
	"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
 | 
			
		||||
	utilnet "k8s.io/apimachinery/pkg/util/net"
 | 
			
		||||
	restclient "k8s.io/client-go/rest"
 | 
			
		||||
	"k8s.io/client-go/transport"
 | 
			
		||||
@@ -88,8 +90,8 @@ func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	// set the protocol version directly on the dialer from the header
 | 
			
		||||
	protocolVersions := request.Header[httpstream.HeaderProtocolVersion]
 | 
			
		||||
	delete(request.Header, httpstream.HeaderProtocolVersion)
 | 
			
		||||
	protocolVersions := request.Header[wsstream.WebSocketProtocolHeader]
 | 
			
		||||
	delete(request.Header, wsstream.WebSocketProtocolHeader)
 | 
			
		||||
 | 
			
		||||
	dialer := gwebsocket.Dialer{
 | 
			
		||||
		Proxy:           rt.Proxier,
 | 
			
		||||
@@ -108,8 +110,24 @@ func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response
 | 
			
		||||
	}
 | 
			
		||||
	wsConn, resp, err := dialer.DialContext(request.Context(), request.URL.String(), request.Header)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if errors.Is(err, gwebsocket.ErrBadHandshake) {
 | 
			
		||||
			return nil, &httpstream.UpgradeFailureError{Cause: err}
 | 
			
		||||
		}
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Ensure we got back a protocol we understand
 | 
			
		||||
	foundProtocol := false
 | 
			
		||||
	for _, protocolVersion := range protocolVersions {
 | 
			
		||||
		if protocolVersion == wsConn.Subprotocol() {
 | 
			
		||||
			foundProtocol = true
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if !foundProtocol {
 | 
			
		||||
		wsConn.Close() // nolint:errcheck
 | 
			
		||||
		return nil, &httpstream.UpgradeFailureError{Cause: fmt.Errorf("invalid protocol, expected one of %q, got %q", protocolVersions, wsConn.Subprotocol())}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rt.Conn = wsConn
 | 
			
		||||
 | 
			
		||||
@@ -149,7 +167,8 @@ func RoundTripperFor(config *restclient.Config) (http.RoundTripper, ConnectionHo
 | 
			
		||||
// a WebSocket connection. Upon success, it returns the negotiated connection.
 | 
			
		||||
// The round tripper rt must use the WebSocket round tripper wsRt - see RoundTripperFor.
 | 
			
		||||
func Negotiate(rt http.RoundTripper, connectionInfo ConnectionHolder, req *http.Request, protocols ...string) (*gwebsocket.Conn, error) {
 | 
			
		||||
	req.Header[httpstream.HeaderProtocolVersion] = protocols
 | 
			
		||||
	// Plumb protocols to RoundTripper#RoundTrip
 | 
			
		||||
	req.Header[wsstream.WebSocketProtocolHeader] = protocols
 | 
			
		||||
	resp, err := rt.RoundTrip(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
 
 | 
			
		||||
@@ -54,7 +54,7 @@ func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) {
 | 
			
		||||
	rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host})
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
	requestedProtocol := remotecommand.StreamProtocolV5Name
 | 
			
		||||
	req.Header[httpstream.HeaderProtocolVersion] = []string{requestedProtocol}
 | 
			
		||||
	req.Header[wsstream.WebSocketProtocolHeader] = []string{requestedProtocol}
 | 
			
		||||
	_, err = rt.RoundTrip(req)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
	// WebSocket Connection is stored in websocket RoundTripper.
 | 
			
		||||
@@ -83,11 +83,12 @@ func TestWebSocketRoundTripper_RoundTripperFails(t *testing.T) {
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
	// Requested subprotocol version 1 is not supported by test websocket server.
 | 
			
		||||
	requestedProtocol := remotecommand.StreamProtocolV1Name
 | 
			
		||||
	req.Header[httpstream.HeaderProtocolVersion] = []string{requestedProtocol}
 | 
			
		||||
	req.Header[wsstream.WebSocketProtocolHeader] = []string{requestedProtocol}
 | 
			
		||||
	_, err = rt.RoundTrip(req)
 | 
			
		||||
	// Ensure a "bad handshake" error is returned, since requested protocol is not supported.
 | 
			
		||||
	require.Error(t, err)
 | 
			
		||||
	assert.True(t, strings.Contains(err.Error(), "bad handshake"))
 | 
			
		||||
	assert.True(t, httpstream.IsUpgradeFailure(err))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestWebSocketRoundTripper_NegotiateCreatesConnection(t *testing.T) {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user