diff --git a/pkg/features/kube_features.go b/pkg/features/kube_features.go index 6912b4df743..ae1be776065 100644 --- a/pkg/features/kube_features.go +++ b/pkg/features/kube_features.go @@ -179,6 +179,14 @@ const ( // Enables kubelet to detect CSI volume condition and send the event of the abnormal volume to the corresponding pod that is using it. CSIVolumeHealth featuregate.Feature = "CSIVolumeHealth" + // owner: @seans3 + // kep: http://kep.k8s.io/4006 + // alpha: v1.29 + // + // Enables StreamTranslator proxy to handle WebSockets upgrade requests for the + // version of the RemoteCommand subprotocol that supports the "close" signal. + TranslateStreamCloseWebsocketRequests featuregate.Feature = "TranslateStreamCloseWebsocketRequests" + // owner: @nckturner // kep: http://kep.k8s.io/2699 // alpha: v1.27 @@ -925,6 +933,8 @@ var defaultKubernetesFeatureGates = map[featuregate.Feature]featuregate.FeatureS SkipReadOnlyValidationGCE: {Default: true, PreRelease: featuregate.Deprecated}, // remove in 1.31 + TranslateStreamCloseWebsocketRequests: {Default: false, PreRelease: featuregate.Alpha}, + CloudControllerManagerWebhook: {Default: false, PreRelease: featuregate.Alpha}, ContainerCheckpoint: {Default: false, PreRelease: featuregate.Alpha}, diff --git a/pkg/kubelet/server/server_test.go b/pkg/kubelet/server/server_test.go index 3ee71958dd0..8faeb231408 100644 --- a/pkg/kubelet/server/server_test.go +++ b/pkg/kubelet/server/server_test.go @@ -18,6 +18,7 @@ package server import ( "context" + "crypto/tls" "errors" "fmt" "io" @@ -959,7 +960,10 @@ func TestServeExecInContainerIdleTimeout(t *testing.T) { url := fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?c=ls&c=-a&" + api.ExecStdinParam + "=1" - upgradeRoundTripper := spdy.NewRoundTripper(nil) + upgradeRoundTripper, err := spdy.NewRoundTripper(&tls.Config{}) + if err != nil { + t.Fatalf("Error creating SpdyRoundTripper: %v", err) + } c := &http.Client{Transport: upgradeRoundTripper} resp, err := c.Do(makeReq(t, "POST", url, "v4.channel.k8s.io")) @@ -1115,7 +1119,10 @@ func testExecAttach(t *testing.T, verb string) { upgradeRoundTripper httpstream.UpgradeRoundTripper c *http.Client ) - upgradeRoundTripper = spdy.NewRoundTripper(nil) + upgradeRoundTripper, err = spdy.NewRoundTripper(&tls.Config{}) + if err != nil { + t.Fatalf("Error creating SpdyRoundTripper: %v", err) + } c = &http.Client{Transport: upgradeRoundTripper} resp, err = c.Do(makeReq(t, "POST", url, "v4.channel.k8s.io")) @@ -1211,7 +1218,10 @@ func TestServePortForwardIdleTimeout(t *testing.T) { url := fw.testHTTPServer.URL + "/portForward/" + podNamespace + "/" + podName - upgradeRoundTripper := spdy.NewRoundTripper(nil) + upgradeRoundTripper, err := spdy.NewRoundTripper(&tls.Config{}) + if err != nil { + t.Fatalf("Error creating SpdyRoundTripper: %v", err) + } c := &http.Client{Transport: upgradeRoundTripper} req := makeReq(t, "POST", url, "portforward.k8s.io") @@ -1310,7 +1320,10 @@ func TestServePortForward(t *testing.T) { c *http.Client ) - upgradeRoundTripper = spdy.NewRoundTripper(nil) + upgradeRoundTripper, err = spdy.NewRoundTripper(&tls.Config{}) + if err != nil { + t.Fatalf("Error creating SpdyRoundTripper: %v", err) + } c = &http.Client{Transport: upgradeRoundTripper} req := makeReq(t, "POST", url, "portforward.k8s.io") diff --git a/pkg/registry/core/pod/rest/subresources.go b/pkg/registry/core/pod/rest/subresources.go index 76e0cdd4ffb..0e031412fbd 100644 --- a/pkg/registry/core/pod/rest/subresources.go +++ b/pkg/registry/core/pod/rest/subresources.go @@ -23,12 +23,16 @@ import ( "net/url" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/httpstream/wsstream" "k8s.io/apimachinery/pkg/util/net" "k8s.io/apimachinery/pkg/util/proxy" genericregistry "k8s.io/apiserver/pkg/registry/generic/registry" "k8s.io/apiserver/pkg/registry/rest" + utilfeature "k8s.io/apiserver/pkg/util/feature" + translator "k8s.io/apiserver/pkg/util/proxy" api "k8s.io/kubernetes/pkg/apis/core" "k8s.io/kubernetes/pkg/capabilities" + "k8s.io/kubernetes/pkg/features" "k8s.io/kubernetes/pkg/kubelet/client" "k8s.io/kubernetes/pkg/registry/core/pod" ) @@ -113,7 +117,21 @@ func (r *AttachREST) Connect(ctx context.Context, name string, opts runtime.Obje if err != nil { return nil, err } - return newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder), nil + handler := newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder) + if utilfeature.DefaultFeatureGate.Enabled(features.TranslateStreamCloseWebsocketRequests) { + // Wrap the upgrade aware handler to implement stream translation + // for WebSocket/V5 upgrade requests. + streamOptions := translator.Options{ + Stdin: attachOpts.Stdin, + Stdout: attachOpts.Stdout, + Stderr: attachOpts.Stderr, + Tty: attachOpts.TTY, + } + maxBytesPerSec := capabilities.Get().PerConnectionBandwidthLimitBytesPerSec + streamtranslator := translator.NewStreamTranslatorHandler(location, transport, maxBytesPerSec, streamOptions) + handler = translator.NewTranslatingHandler(handler, streamtranslator, wsstream.IsWebSocketRequestWithStreamCloseProtocol) + } + return handler, nil } // NewConnectOptions returns the versioned object that represents exec parameters @@ -156,7 +174,21 @@ func (r *ExecREST) Connect(ctx context.Context, name string, opts runtime.Object if err != nil { return nil, err } - return newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder), nil + handler := newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder) + if utilfeature.DefaultFeatureGate.Enabled(features.TranslateStreamCloseWebsocketRequests) { + // Wrap the upgrade aware handler to implement stream translation + // for WebSocket/V5 upgrade requests. + streamOptions := translator.Options{ + Stdin: execOpts.Stdin, + Stdout: execOpts.Stdout, + Stderr: execOpts.Stderr, + Tty: execOpts.TTY, + } + maxBytesPerSec := capabilities.Get().PerConnectionBandwidthLimitBytesPerSec + streamtranslator := translator.NewStreamTranslatorHandler(location, transport, maxBytesPerSec, streamOptions) + handler = translator.NewTranslatingHandler(handler, streamtranslator, wsstream.IsWebSocketRequestWithStreamCloseProtocol) + } + return handler, nil } // NewConnectOptions returns the versioned object that represents exec parameters @@ -213,7 +245,7 @@ func (r *PortForwardREST) Connect(ctx context.Context, name string, opts runtime return newThrottledUpgradeAwareProxyHandler(location, transport, false, true, responder), nil } -func newThrottledUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired bool, responder rest.Responder) *proxy.UpgradeAwareHandler { +func newThrottledUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired bool, responder rest.Responder) http.Handler { handler := proxy.NewUpgradeAwareHandler(location, transport, wrapTransport, upgradeRequired, proxy.NewErrorResponder(responder)) handler.MaxBytesPerSec = capabilities.Get().PerConnectionBandwidthLimitBytesPerSec return handler diff --git a/staging/src/k8s.io/apiextensions-apiserver/go.mod b/staging/src/k8s.io/apiextensions-apiserver/go.mod index fc95bbe1a33..2c7debd6204 100644 --- a/staging/src/k8s.io/apiextensions-apiserver/go.mod +++ b/staging/src/k8s.io/apiextensions-apiserver/go.mod @@ -73,9 +73,11 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect + github.com/moby/spdystream v0.2.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_golang v1.16.0 // indirect diff --git a/staging/src/k8s.io/apiextensions-apiserver/go.sum b/staging/src/k8s.io/apiextensions-apiserver/go.sum index 58297dddd7b..7e3a3f0f3ca 100644 --- a/staging/src/k8s.io/apiextensions-apiserver/go.sum +++ b/staging/src/k8s.io/apiextensions-apiserver/go.sum @@ -163,6 +163,7 @@ github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8V github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df h1:7RFfzj4SSt6nnvCPbCqijJi1nWCd+TqAT3bYCStRC18= github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df/go.mod h1:pSwJ0fSY5KhvocuWSx4fz3BA8OrA1bQn+K1Eli3BRwM= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a h1:idn718Q4B6AGu/h5Sxe66HYVdqdGu2l9Iebqhi/AEoA= github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= @@ -378,6 +379,7 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0 github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= +github.com/moby/spdystream v0.2.0 h1:cjW1zVyyoiM0T7b6UoySUFqzXMoqRckQtXwGPiBhOM8= github.com/moby/spdystream v0.2.0/go.mod h1:f7i0iNDQJ059oMTcWxx8MA/zKFIuD/lY+0GqbN2Wy8c= github.com/moby/term v0.0.0-20221205130635-1aeaba878587/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -388,6 +390,7 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J1GEMiLbxo1LJaP8RfCpH6pymGZus= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4= github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o= diff --git a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/httpstream.go b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/httpstream.go index 32f075782a9..a32fce5a0c1 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/httpstream.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/httpstream.go @@ -17,6 +17,7 @@ limitations under the License. package httpstream import ( + "errors" "fmt" "io" "net/http" @@ -95,6 +96,26 @@ type Stream interface { Identifier() uint32 } +// UpgradeFailureError encapsulates the cause for why the streaming +// upgrade request failed. Implements error interface. +type UpgradeFailureError struct { + Cause error +} + +func (u *UpgradeFailureError) Error() string { + return fmt.Sprintf("unable to upgrade streaming request: %s", u.Cause) +} + +// IsUpgradeFailure returns true if the passed error is (or wrapped error contains) +// the UpgradeFailureError. +func IsUpgradeFailure(err error) bool { + if err == nil { + return false + } + var upgradeErr *UpgradeFailureError + return errors.As(err, &upgradeErr) +} + // IsUpgradeRequest returns true if the given request is a connection upgrade request func IsUpgradeRequest(req *http.Request) bool { for _, h := range req.Header[http.CanonicalHeaderKey(HeaderConnection)] { diff --git a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/httpstream_test.go b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/httpstream_test.go index e988bce2b31..11fb928634e 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/httpstream_test.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/httpstream_test.go @@ -17,6 +17,8 @@ limitations under the License. package httpstream import ( + "errors" + "fmt" "net/http" "reflect" "testing" @@ -129,3 +131,40 @@ func TestHandshake(t *testing.T) { } } } + +func TestIsUpgradeFailureError(t *testing.T) { + testCases := map[string]struct { + err error + expected bool + }{ + "nil error should return false": { + err: nil, + expected: false, + }, + "Non-upgrade error should return false": { + err: fmt.Errorf("this is not an upgrade error"), + expected: false, + }, + "UpgradeFailure error should return true": { + err: &UpgradeFailureError{}, + expected: true, + }, + "Wrapped Non-UpgradeFailure error should return false": { + err: fmt.Errorf("%s: %w", "first error", errors.New("Non-upgrade error")), + expected: false, + }, + "Wrapped UpgradeFailure error should return true": { + err: fmt.Errorf("%s: %w", "first error", &UpgradeFailureError{}), + expected: true, + }, + } + + for name, test := range testCases { + t.Run(name, func(t *testing.T) { + actual := IsUpgradeFailure(test.err) + if test.expected != actual { + t.Errorf("expected upgrade failure %t, got %t", test.expected, actual) + } + }) + } +} diff --git a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/roundtripper.go b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/roundtripper.go index 7fe52ee568e..c78326fa3b5 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/roundtripper.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/roundtripper.go @@ -38,6 +38,7 @@ import ( "k8s.io/apimachinery/pkg/runtime/serializer" "k8s.io/apimachinery/pkg/util/httpstream" utilnet "k8s.io/apimachinery/pkg/util/net" + apiproxy "k8s.io/apimachinery/pkg/util/proxy" "k8s.io/apimachinery/third_party/forked/golang/netutil" ) @@ -68,6 +69,10 @@ type SpdyRoundTripper struct { // pingPeriod is a period for sending Ping frames over established // connections. pingPeriod time.Duration + + // upgradeTransport is an optional substitute for dialing if present. This field is + // mutually exclusive with the "tlsConfig", "Dialer", and "proxier". + upgradeTransport http.RoundTripper } var _ utilnet.TLSClientConfigHolder = &SpdyRoundTripper{} @@ -76,43 +81,61 @@ var _ utilnet.Dialer = &SpdyRoundTripper{} // NewRoundTripper creates a new SpdyRoundTripper that will use the specified // tlsConfig. -func NewRoundTripper(tlsConfig *tls.Config) *SpdyRoundTripper { +func NewRoundTripper(tlsConfig *tls.Config) (*SpdyRoundTripper, error) { return NewRoundTripperWithConfig(RoundTripperConfig{ - TLS: tlsConfig, + TLS: tlsConfig, + UpgradeTransport: nil, }) } // NewRoundTripperWithProxy creates a new SpdyRoundTripper that will use the // specified tlsConfig and proxy func. -func NewRoundTripperWithProxy(tlsConfig *tls.Config, proxier func(*http.Request) (*url.URL, error)) *SpdyRoundTripper { +func NewRoundTripperWithProxy(tlsConfig *tls.Config, proxier func(*http.Request) (*url.URL, error)) (*SpdyRoundTripper, error) { return NewRoundTripperWithConfig(RoundTripperConfig{ - TLS: tlsConfig, - Proxier: proxier, + TLS: tlsConfig, + Proxier: proxier, + UpgradeTransport: nil, }) } // NewRoundTripperWithConfig creates a new SpdyRoundTripper with the specified -// configuration. -func NewRoundTripperWithConfig(cfg RoundTripperConfig) *SpdyRoundTripper { +// configuration. Returns an error if the SpdyRoundTripper is misconfigured. +func NewRoundTripperWithConfig(cfg RoundTripperConfig) (*SpdyRoundTripper, error) { + // Process UpgradeTransport, which is mutually exclusive to TLSConfig and Proxier. + if cfg.UpgradeTransport != nil { + if cfg.TLS != nil || cfg.Proxier != nil { + return nil, fmt.Errorf("SpdyRoundTripper: UpgradeTransport is mutually exclusive to TLSConfig or Proxier") + } + tlsConfig, err := utilnet.TLSClientConfig(cfg.UpgradeTransport) + if err != nil { + return nil, fmt.Errorf("SpdyRoundTripper: Unable to retrieve TLSConfig from UpgradeTransport: %v", err) + } + cfg.TLS = tlsConfig + } if cfg.Proxier == nil { cfg.Proxier = utilnet.NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment) } return &SpdyRoundTripper{ - tlsConfig: cfg.TLS, - proxier: cfg.Proxier, - pingPeriod: cfg.PingPeriod, - } + tlsConfig: cfg.TLS, + proxier: cfg.Proxier, + pingPeriod: cfg.PingPeriod, + upgradeTransport: cfg.UpgradeTransport, + }, nil } // RoundTripperConfig is a set of options for an SpdyRoundTripper. type RoundTripperConfig struct { - // TLS configuration used by the round tripper. + // TLS configuration used by the round tripper if UpgradeTransport not present. TLS *tls.Config // Proxier is a proxy function invoked on each request. Optional. Proxier func(*http.Request) (*url.URL, error) // PingPeriod is a period for sending SPDY Pings on the connection. // Optional. PingPeriod time.Duration + // UpgradeTransport is a subtitute transport used for dialing. If set, + // this field will be used instead of "TLS" and "Proxier" for connection creation. + // Optional. + UpgradeTransport http.RoundTripper } // TLSClientConfig implements pkg/util/net.TLSClientConfigHolder for proper TLS checking during @@ -123,7 +146,13 @@ func (s *SpdyRoundTripper) TLSClientConfig() *tls.Config { // Dial implements k8s.io/apimachinery/pkg/util/net.Dialer. func (s *SpdyRoundTripper) Dial(req *http.Request) (net.Conn, error) { - conn, err := s.dial(req) + var conn net.Conn + var err error + if s.upgradeTransport != nil { + conn, err = apiproxy.DialURL(req.Context(), req.URL, s.upgradeTransport) + } else { + conn, err = s.dial(req) + } if err != nil { return nil, err } diff --git a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/roundtripper_test.go b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/roundtripper_test.go index b2c2b88513a..de88f4e6071 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/roundtripper_test.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/spdy/roundtripper_test.go @@ -25,7 +25,9 @@ import ( "net/http" "net/http/httptest" "net/url" + "reflect" "strconv" + "strings" "testing" "github.com/armon/go-socks5" @@ -324,7 +326,10 @@ func TestRoundTripAndNewConnection(t *testing.T) { t.Fatalf("error creating request: %s", err) } - spdyTransport := NewRoundTripper(testCase.clientTLS) + spdyTransport, err := NewRoundTripper(testCase.clientTLS) + if err != nil { + t.Fatalf("error creating SpdyRoundTripper: %v", err) + } var proxierCalled bool var proxyCalledWithHost string @@ -428,6 +433,74 @@ func TestRoundTripAndNewConnection(t *testing.T) { } } +// Tests SpdyRoundTripper constructors +func TestRoundTripConstuctor(t *testing.T) { + testCases := map[string]struct { + tlsConfig *tls.Config + proxier func(req *http.Request) (*url.URL, error) + upgradeTransport http.RoundTripper + expectedTLSConfig *tls.Config + errMsg string + }{ + "Basic TLSConfig; no error": { + tlsConfig: &tls.Config{InsecureSkipVerify: true}, + expectedTLSConfig: &tls.Config{InsecureSkipVerify: true}, + upgradeTransport: nil, + }, + "Basic TLSConfig and Proxier: no error": { + tlsConfig: &tls.Config{InsecureSkipVerify: true}, + proxier: func(req *http.Request) (*url.URL, error) { return nil, nil }, + expectedTLSConfig: &tls.Config{InsecureSkipVerify: true}, + upgradeTransport: nil, + }, + "TLSConfig with UpgradeTransport: error": { + tlsConfig: &tls.Config{InsecureSkipVerify: true}, + upgradeTransport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + expectedTLSConfig: &tls.Config{InsecureSkipVerify: true}, + errMsg: "SpdyRoundTripper: UpgradeTransport is mutually exclusive to TLSConfig or Proxier", + }, + "Proxier with UpgradeTransport: error": { + proxier: func(req *http.Request) (*url.URL, error) { return nil, nil }, + upgradeTransport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + expectedTLSConfig: &tls.Config{InsecureSkipVerify: true}, + errMsg: "SpdyRoundTripper: UpgradeTransport is mutually exclusive to TLSConfig or Proxier", + }, + "Only UpgradeTransport: no error": { + upgradeTransport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + expectedTLSConfig: &tls.Config{InsecureSkipVerify: true}, + }, + } + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + spdyRoundTripper, err := NewRoundTripperWithConfig( + RoundTripperConfig{ + TLS: testCase.tlsConfig, + Proxier: testCase.proxier, + UpgradeTransport: testCase.upgradeTransport, + }, + ) + if testCase.errMsg != "" { + if err == nil { + t.Fatalf("expected error but received none") + } + if !strings.Contains(err.Error(), testCase.errMsg) { + t.Fatalf("expected error message (%s), got (%s)", err.Error(), testCase.errMsg) + } + } + if testCase.errMsg == "" { + if err != nil { + t.Fatalf("unexpected error received: %v", err) + } + actualTLSConfig := spdyRoundTripper.TLSClientConfig() + if !reflect.DeepEqual(testCase.expectedTLSConfig, actualTLSConfig) { + t.Errorf("expected TLSConfig (%v), got (%v)", + testCase.expectedTLSConfig, actualTLSConfig) + } + } + }) + } +} + type Interceptor struct { Authorization socks5.AuthContext proxyCalledWithHost *string @@ -544,7 +617,10 @@ func TestRoundTripSocks5AndNewConnection(t *testing.T) { t.Fatalf("error creating request: %s", err) } - spdyTransport := NewRoundTripper(testCase.clientTLS) + spdyTransport, err := NewRoundTripper(testCase.clientTLS) + if err != nil { + t.Fatalf("error creating SpdyRoundTripper: %v", err) + } var proxierCalled bool var proxyCalledWithHost string @@ -704,7 +780,10 @@ func TestRoundTripPassesContextToDialer(t *testing.T) { cancel() req, err := http.NewRequestWithContext(ctx, "GET", u, nil) require.NoError(t, err) - spdyTransport := NewRoundTripper(&tls.Config{}) + spdyTransport, err := NewRoundTripper(&tls.Config{}) + if err != nil { + t.Fatalf("error creating SpdyRoundTripper: %v", err) + } _, err = spdyTransport.Dial(req) assert.EqualError(t, err, "dial tcp 127.0.0.1:1233: operation was canceled") }) diff --git a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/wsstream/conn.go b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/wsstream/conn.go index d153070cedf..7cfdd063217 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/wsstream/conn.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/wsstream/conn.go @@ -32,6 +32,8 @@ import ( "k8s.io/klog/v2" ) +const WebSocketProtocolHeader = "Sec-Websocket-Protocol" + // The Websocket subprotocol "channel.k8s.io" prepends each binary message with a byte indicating // the channel number (zero indexed) the message was sent on. Messages in both directions should // prefix their messages with this channel byte. When used for remote execution, the channel numbers @@ -87,6 +89,23 @@ func IsWebSocketRequest(req *http.Request) bool { return httpstream.IsUpgradeRequest(req) } +// IsWebSocketRequestWithStreamCloseProtocol returns true if the request contains headers +// identifying that it is requesting a websocket upgrade with a remotecommand protocol +// version that supports the "CLOSE" signal; false otherwise. +func IsWebSocketRequestWithStreamCloseProtocol(req *http.Request) bool { + if !IsWebSocketRequest(req) { + return false + } + requestedProtocols := strings.TrimSpace(req.Header.Get(WebSocketProtocolHeader)) + for _, requestedProtocol := range strings.Split(requestedProtocols, ",") { + if protocolSupportsStreamClose(strings.TrimSpace(requestedProtocol)) { + return true + } + } + + return false +} + // IgnoreReceives reads from a WebSocket until it is closed, then returns. If timeout is set, the // read and write deadlines are pushed every time a new message is received. func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) { @@ -168,15 +187,46 @@ func (conn *Conn) SetIdleTimeout(duration time.Duration) { conn.timeout = duration } +// SetWriteDeadline sets a timeout on writing to the websocket connection. The +// passed "duration" identifies how far into the future the write must complete +// by before the timeout fires. +func (conn *Conn) SetWriteDeadline(duration time.Duration) { + conn.ws.SetWriteDeadline(time.Now().Add(duration)) //nolint:errcheck +} + // Open the connection and create channels for reading and writing. It returns // the selected subprotocol, a slice of channels and an error. func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) (string, []io.ReadWriteCloser, error) { + // serveHTTPComplete is channel that is closed/selected when "websocket#ServeHTTP" finishes. + serveHTTPComplete := make(chan struct{}) + // Ensure panic in spawned goroutine is propagated into the parent goroutine. + panicChan := make(chan any, 1) go func() { - defer runtime.HandleCrash() - defer conn.Close() + // If websocket server returns, propagate panic if necessary. Otherwise, + // signal HTTPServe finished by closing "serveHTTPComplete". + defer func() { + if p := recover(); p != nil { + panicChan <- p + } else { + close(serveHTTPComplete) + } + }() websocket.Server{Handshake: conn.handshake, Handler: conn.handle}.ServeHTTP(w, req) }() - <-conn.ready + + // In normal circumstances, "websocket.Server#ServeHTTP" calls "initialize" which closes + // "conn.ready" and then blocks until serving is complete. + select { + case <-conn.ready: + klog.V(8).Infof("websocket server initialized--serving") + case <-serveHTTPComplete: + // websocket server returned before completing initialization; cleanup and return error. + conn.closeNonThreadSafe() //nolint:errcheck + return "", nil, fmt.Errorf("websocket server finished before becoming ready") + case p := <-panicChan: + panic(p) + } + rwc := make([]io.ReadWriteCloser, len(conn.channels)) for i := range conn.channels { rwc[i] = conn.channels[i] @@ -225,14 +275,23 @@ func (conn *Conn) resetTimeout() { } } -// Close is only valid after Open has been called -func (conn *Conn) Close() error { - <-conn.ready +// closeNonThreadSafe cleans up by closing streams and the websocket +// connection *without* waiting for the "ready" channel. +func (conn *Conn) closeNonThreadSafe() error { for _, s := range conn.channels { s.Close() } - conn.ws.Close() - return nil + var err error + if conn.ws != nil { + err = conn.ws.Close() + } + return err +} + +// Close is only valid after Open has been called +func (conn *Conn) Close() error { + <-conn.ready + return conn.closeNonThreadSafe() } // protocolSupportsStreamClose returns true if the passed protocol @@ -244,8 +303,8 @@ func protocolSupportsStreamClose(protocol string) bool { // handle implements a websocket handler. func (conn *Conn) handle(ws *websocket.Conn) { - defer conn.Close() conn.initialize(ws) + defer conn.Close() supportsStreamClose := protocolSupportsStreamClose(conn.selectedProtocol) for { diff --git a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/wsstream/conn_test.go b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/wsstream/conn_test.go index 8d9f5d5d417..e4a88a1a8cd 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/wsstream/conn_test.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/wsstream/conn_test.go @@ -25,6 +25,8 @@ import ( "sync" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/net/websocket" ) @@ -271,3 +273,146 @@ func TestVersionedConn(t *testing.T) { }() } } + +func TestIsWebSocketRequestWithStreamCloseProtocol(t *testing.T) { + tests := map[string]struct { + headers map[string]string + expected bool + }{ + "No headers returns false": { + headers: map[string]string{}, + expected: false, + }, + "Only connection upgrade header is false": { + headers: map[string]string{ + "Connection": "upgrade", + }, + expected: false, + }, + "Only websocket upgrade header is false": { + headers: map[string]string{ + "Upgrade": "websocket", + }, + expected: false, + }, + "Only websocket and connection upgrade headers is false": { + headers: map[string]string{ + "Connection": "upgrade", + "Upgrade": "websocket", + }, + expected: false, + }, + "Missing connection/upgrade header is false": { + headers: map[string]string{ + "Upgrade": "websocket", + WebSocketProtocolHeader: "v5.channel.k8s.io", + }, + expected: false, + }, + "Websocket connection upgrade headers with v5 protocol is true": { + headers: map[string]string{ + "Connection": "upgrade", + "Upgrade": "websocket", + WebSocketProtocolHeader: "v5.channel.k8s.io", + }, + expected: true, + }, + "Websocket connection upgrade headers with wrong case v5 protocol is false": { + headers: map[string]string{ + "Connection": "upgrade", + "Upgrade": "websocket", + WebSocketProtocolHeader: "v5.CHANNEL.k8s.io", // header value is case-sensitive + }, + expected: false, + }, + "Websocket connection upgrade headers with v4 protocol is false": { + headers: map[string]string{ + "Connection": "upgrade", + "Upgrade": "websocket", + WebSocketProtocolHeader: "v4.channel.k8s.io", + }, + expected: false, + }, + "Websocket connection upgrade headers with multiple protocols but missing v5 is false": { + headers: map[string]string{ + "Connection": "upgrade", + "Upgrade": "websocket", + WebSocketProtocolHeader: "v4.channel.k8s.io,v3.channel.k8s.io,v2.channel.k8s.io", + }, + expected: false, + }, + "Websocket connection upgrade headers with multiple protocols including v5 and spaces is true": { + headers: map[string]string{ + "Connection": "upgrade", + "Upgrade": "websocket", + WebSocketProtocolHeader: "v5.channel.k8s.io, v4.channel.k8s.io", + }, + expected: true, + }, + "Websocket connection upgrade headers with multiple protocols out of order including v5 and spaces is true": { + headers: map[string]string{ + "Connection": "upgrade", + "Upgrade": "websocket", + WebSocketProtocolHeader: "v4.channel.k8s.io, v5.channel.k8s.io, v3.channel.k8s.io", + }, + expected: true, + }, + + "Websocket connection upgrade headers key is case-insensitive": { + headers: map[string]string{ + "Connection": "upgrade", + "Upgrade": "websocket", + "sec-websocket-protocol": "v4.channel.k8s.io, v5.channel.k8s.io, v3.channel.k8s.io", + }, + expected: true, + }, + } + + for name, test := range tests { + req, err := http.NewRequest("GET", "http://www.example.com/", nil) + require.NoError(t, err) + for key, value := range test.headers { + req.Header.Add(key, value) + } + actual := IsWebSocketRequestWithStreamCloseProtocol(req) + assert.Equal(t, test.expected, actual, "%s: expected (%t), got (%t)", name, test.expected, actual) + } +} + +func TestProtocolSupportsStreamClose(t *testing.T) { + tests := map[string]struct { + protocol string + expected bool + }{ + "empty protocol returns false": { + protocol: "", + expected: false, + }, + "not binary protocol returns false": { + protocol: "base64.channel.k8s.io", + expected: false, + }, + "V1 protocol returns false": { + protocol: "channel.k8s.io", + expected: false, + }, + "V4 protocol returns false": { + protocol: "v4.channel.k8s.io", + expected: false, + }, + "V5 protocol returns true": { + protocol: "v5.channel.k8s.io", + expected: true, + }, + "V5 protocol wrong case returns false": { + protocol: "V5.channel.K8S.io", + expected: false, + }, + } + + for name, test := range tests { + actual := protocolSupportsStreamClose(test.protocol) + assert.Equal(t, test.expected, actual, + "%s: expected (%t), got (%t)", name, test.expected, actual) + } +} diff --git a/staging/src/k8s.io/apimachinery/pkg/util/proxy/dial.go b/staging/src/k8s.io/apimachinery/pkg/util/proxy/dial.go index 4ceb2e06eab..e5196d1ee83 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/proxy/dial.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/proxy/dial.go @@ -29,12 +29,12 @@ import ( "k8s.io/klog/v2" ) -// dialURL will dial the specified URL using the underlying dialer held by the passed +// DialURL will dial the specified URL using the underlying dialer held by the passed // RoundTripper. The primary use of this method is to support proxying upgradable connections. // For this reason this method will prefer to negotiate http/1.1 if the URL scheme is https. // If you wish to ensure ALPN negotiates http2 then set NextProto=[]string{"http2"} in the // TLSConfig of the http.Transport -func dialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (net.Conn, error) { +func DialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (net.Conn, error) { dialAddr := netutil.CanonicalAddr(url) dialer, err := utilnet.DialerFor(transport) diff --git a/staging/src/k8s.io/apimachinery/pkg/util/proxy/dial_test.go b/staging/src/k8s.io/apimachinery/pkg/util/proxy/dial_test.go index 32e951e61ca..488e878b723 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/proxy/dial_test.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/proxy/dial_test.go @@ -143,7 +143,7 @@ func TestDialURL(t *testing.T) { u, _ := url.Parse(ts.URL) _, p, _ := net.SplitHostPort(u.Host) u.Host = net.JoinHostPort("127.0.0.1", p) - conn, err := dialURL(context.Background(), u, transport) + conn, err := DialURL(context.Background(), u, transport) // Make sure dialing doesn't mutate the transport's TLSConfig if !reflect.DeepEqual(tc.TLSConfig, tlsConfigCopy) { diff --git a/staging/src/k8s.io/apimachinery/pkg/util/proxy/upgradeaware.go b/staging/src/k8s.io/apimachinery/pkg/util/proxy/upgradeaware.go index ac2ada5472c..76acdfb4aca 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/proxy/upgradeaware.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/proxy/upgradeaware.go @@ -492,7 +492,7 @@ func getResponse(r io.Reader) (*http.Response, []byte, error) { // dial dials the backend at req.URL and writes req to it. func dial(req *http.Request, transport http.RoundTripper) (net.Conn, error) { - conn, err := dialURL(req.Context(), req.URL, transport) + conn, err := DialURL(req.Context(), req.URL, transport) if err != nil { return nil, fmt.Errorf("error dialing backend: %v", err) } diff --git a/staging/src/k8s.io/apiserver/go.mod b/staging/src/k8s.io/apiserver/go.mod index d9cbb34d33b..2cb822a2fc5 100644 --- a/staging/src/k8s.io/apiserver/go.mod +++ b/staging/src/k8s.io/apiserver/go.mod @@ -18,6 +18,7 @@ require ( github.com/google/uuid v1.3.0 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 + github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.3 go.etcd.io/etcd/api/v3 v3.5.9 @@ -87,9 +88,9 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect + github.com/moby/spdystream v0.2.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pquerna/cachecontrol v0.1.0 // indirect diff --git a/staging/src/k8s.io/apiserver/go.sum b/staging/src/k8s.io/apiserver/go.sum index 2677236207d..0d1d14cccaa 100644 --- a/staging/src/k8s.io/apiserver/go.sum +++ b/staging/src/k8s.io/apiserver/go.sum @@ -163,6 +163,7 @@ github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8V github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df h1:7RFfzj4SSt6nnvCPbCqijJi1nWCd+TqAT3bYCStRC18= github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df/go.mod h1:pSwJ0fSY5KhvocuWSx4fz3BA8OrA1bQn+K1Eli3BRwM= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a h1:idn718Q4B6AGu/h5Sxe66HYVdqdGu2l9Iebqhi/AEoA= github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= @@ -376,6 +377,7 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0 github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= +github.com/moby/spdystream v0.2.0 h1:cjW1zVyyoiM0T7b6UoySUFqzXMoqRckQtXwGPiBhOM8= github.com/moby/spdystream v0.2.0/go.mod h1:f7i0iNDQJ059oMTcWxx8MA/zKFIuD/lY+0GqbN2Wy8c= github.com/moby/term v0.0.0-20221205130635-1aeaba878587/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= diff --git a/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtranslator.go b/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtranslator.go new file mode 100644 index 00000000000..94ea13dff5b --- /dev/null +++ b/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtranslator.go @@ -0,0 +1,167 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package proxy + +import ( + "fmt" + "net/http" + "net/url" + + "github.com/mxk/go-flowrate/flowrate" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/httpstream/spdy" + constants "k8s.io/apimachinery/pkg/util/remotecommand" + "k8s.io/client-go/tools/remotecommand" + "k8s.io/client-go/util/exec" +) + +// StreamTranslatorHandler is a handler which translates WebSocket stream data +// to SPDY to proxy to kubelet (and ContainerRuntime). +type StreamTranslatorHandler struct { + // Location is the location of the upstream proxy. It is used as the location to Dial on the upstream server + // for upgrade requests. + Location *url.URL + // Transport provides an optional round tripper to use to proxy. If nil, the default proxy transport is used + Transport http.RoundTripper + // MaxBytesPerSec throttles stream Reader/Writer if necessary + MaxBytesPerSec int64 + // Options define the requested streams (e.g. stdin, stdout). + Options Options +} + +// NewStreamTranslatorHandler creates a new proxy handler. Responder is required for returning +// errors to the caller. +func NewStreamTranslatorHandler(location *url.URL, transport http.RoundTripper, maxBytesPerSec int64, opts Options) *StreamTranslatorHandler { + return &StreamTranslatorHandler{ + Location: location, + Transport: transport, + MaxBytesPerSec: maxBytesPerSec, + Options: opts, + } +} + +func (h *StreamTranslatorHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // Create WebSocket server, including particular streams requested. If this websocket + // endpoint is not able to be upgraded, the websocket library will return errors + // to the client. + websocketStreams, err := webSocketServerStreams(req, w, h.Options) + if err != nil { + return + } + defer websocketStreams.conn.Close() + + // Creating SPDY executor, ensuring redirects are not followed. + spdyRoundTripper, err := spdy.NewRoundTripperWithConfig(spdy.RoundTripperConfig{UpgradeTransport: h.Transport}) + if err != nil { + websocketStreams.writeStatus(apierrors.NewInternalError(err)) //nolint:errcheck + return + } + spdyExecutor, err := remotecommand.NewSPDYExecutorRejectRedirects(spdyRoundTripper, spdyRoundTripper, "POST", h.Location) + if err != nil { + websocketStreams.writeStatus(apierrors.NewInternalError(err)) //nolint:errcheck + return + } + + // Wire the WebSocket server streams output to the SPDY client input. The stdin/stdout/stderr streams + // can be throttled if the transfer rate exceeds the "MaxBytesPerSec" (zero means unset). Throttling + // the streams instead of the underlying connection *may* not perform the same if two streams + // traveling the same direction (e.g. stdout, stderr) are being maxed out. + opts := remotecommand.StreamOptions{} + if h.Options.Stdin { + stdin := websocketStreams.stdinStream + if h.MaxBytesPerSec > 0 { + stdin = flowrate.NewReader(stdin, h.MaxBytesPerSec) + } + opts.Stdin = stdin + } + if h.Options.Stdout { + stdout := websocketStreams.stdoutStream + if h.MaxBytesPerSec > 0 { + stdout = flowrate.NewWriter(stdout, h.MaxBytesPerSec) + } + opts.Stdout = stdout + } + if h.Options.Stderr { + stderr := websocketStreams.stderrStream + if h.MaxBytesPerSec > 0 { + stderr = flowrate.NewWriter(stderr, h.MaxBytesPerSec) + } + opts.Stderr = stderr + } + if h.Options.Tty { + opts.Tty = true + opts.TerminalSizeQueue = &translatorSizeQueue{resizeChan: websocketStreams.resizeChan} + } + // Start the SPDY client with connected streams. Output from the WebSocket server + // streams will be forwarded into the SPDY client. Report SPDY execution errors + // through the websocket error stream. + err = spdyExecutor.StreamWithContext(req.Context(), opts) + if err != nil { + //nolint:errcheck // Ignore writeStatus returned error + if statusErr, ok := err.(*apierrors.StatusError); ok { + websocketStreams.writeStatus(statusErr) + } else if exitErr, ok := err.(exec.CodeExitError); ok && exitErr.Exited() { + websocketStreams.writeStatus(codeExitToStatusError(exitErr)) + } else { + websocketStreams.writeStatus(apierrors.NewInternalError(err)) + } + return + } + + // Write the success status back to the WebSocket client. + //nolint:errcheck + websocketStreams.writeStatus(&apierrors.StatusError{ErrStatus: metav1.Status{ + Status: metav1.StatusSuccess, + }}) +} + +// translatorSizeQueue feeds the size events from the WebSocket +// resizeChan into the SPDY client input. Implements TerminalSizeQueue +// interface. +type translatorSizeQueue struct { + resizeChan chan remotecommand.TerminalSize +} + +func (t *translatorSizeQueue) Next() *remotecommand.TerminalSize { + size, ok := <-t.resizeChan + if !ok { + return nil + } + return &size +} + +// codeExitToStatusError converts a passed CodeExitError to the type necessary +// to send through an error stream using "writeStatus". +func codeExitToStatusError(exitErr exec.CodeExitError) *apierrors.StatusError { + rc := exitErr.ExitStatus() + return &apierrors.StatusError{ + ErrStatus: metav1.Status{ + Status: metav1.StatusFailure, + Reason: constants.NonZeroExitCodeReason, + Details: &metav1.StatusDetails{ + Causes: []metav1.StatusCause{ + { + Type: constants.ExitCodeCauseType, + Message: fmt.Sprintf("%d", rc), + }, + }, + }, + Message: fmt.Sprintf("command terminated with non-zero exit code: %v", exitErr), + }, + } +} diff --git a/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtranslator_test.go b/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtranslator_test.go new file mode 100644 index 00000000000..6246c35d49c --- /dev/null +++ b/staging/src/k8s.io/apiserver/pkg/util/proxy/streamtranslator_test.go @@ -0,0 +1,872 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package proxy + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "io" + "math" + mrand "math/rand" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "strings" + "testing" + "time" + + v1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/httpstream" + "k8s.io/apimachinery/pkg/util/httpstream/spdy" + rcconstants "k8s.io/apimachinery/pkg/util/remotecommand" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/remotecommand" + "k8s.io/client-go/transport" +) + +// TestStreamTranslator_LoopbackStdinToStdout returns random data sent on the client's +// STDIN channel back onto the client's STDOUT channel. There are two servers in this test: the +// upstream fake SPDY server, and the StreamTranslator server. The StreamTranslator proxys the +// data received from the websocket client upstream to the SPDY server (by translating the +// websocket data into spdy). The returned data read on the websocket client STDOUT is then +// compared the random data sent on STDIN to ensure they are the same. +func TestStreamTranslator_LoopbackStdinToStdout(t *testing.T) { + // Create upstream fake SPDY server which copies STDIN back onto STDOUT stream. + spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ctx, err := createSPDYServerStreams(w, req, Options{ + Stdin: true, + Stdout: true, + }) + if err != nil { + t.Errorf("error on createHTTPStreams: %v", err) + return + } + defer ctx.conn.Close() + // Loopback STDIN data onto STDOUT stream. + _, err = io.Copy(ctx.stdoutStream, ctx.stdinStream) + if err != nil { + t.Fatalf("error copying STDIN to STDOUT: %v", err) + } + + })) + defer spdyServer.Close() + // Create StreamTranslatorHandler, which points upstream to fake SPDY server with + // streams STDIN and STDOUT. Create test server from StreamTranslatorHandler. + spdyLocation, err := url.Parse(spdyServer.URL) + if err != nil { + t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL) + } + spdyTransport, err := fakeTransport() + if err != nil { + t.Fatalf("Unexpected error creating transport: %v", err) + } + streams := Options{Stdin: true, Stdout: true} + streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, streams) + streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + streamTranslator.ServeHTTP(w, req) + })) + defer streamTranslatorServer.Close() + // Now create the websocket client (executor), and point it to the "streamTranslatorServer". + streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL) + if err != nil { + t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL) + } + exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + // Generate random data, and set it up to stream on STDIN. The data will be + // returned on the STDOUT buffer. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + var stdout bytes.Buffer + options := &remotecommand.StreamOptions{ + Stdin: bytes.NewReader(randomData), + Stdout: &stdout, + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } + data, err := io.ReadAll(bytes.NewReader(stdout.Bytes())) + if err != nil { + t.Errorf("error reading the stream: %v", err) + return + } + // Check the random data sent on STDIN was the same returned on STDOUT. + if !bytes.Equal(randomData, data) { + t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData)) + } +} + +// TestStreamTranslator_LoopbackStdinToStderr returns random data sent on the client's +// STDIN channel back onto the client's STDERR channel. There are two servers in this test: the +// upstream fake SPDY server, and the StreamTranslator server. The StreamTranslator proxys the +// data received from the websocket client upstream to the SPDY server (by translating the +// websocket data into spdy). The returned data read on the websocket client STDERR is then +// compared the random data sent on STDIN to ensure they are the same. +func TestStreamTranslator_LoopbackStdinToStderr(t *testing.T) { + // Create upstream fake SPDY server which copies STDIN back onto STDERR stream. + spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ctx, err := createSPDYServerStreams(w, req, Options{ + Stdin: true, + Stderr: true, + }) + if err != nil { + t.Errorf("error on createHTTPStreams: %v", err) + return + } + defer ctx.conn.Close() + // Loopback STDIN data onto STDERR stream. + _, err = io.Copy(ctx.stderrStream, ctx.stdinStream) + if err != nil { + t.Fatalf("error copying STDIN to STDERR: %v", err) + } + })) + defer spdyServer.Close() + // Create StreamTranslatorHandler, which points upstream to fake SPDY server with + // streams STDIN and STDERR. Create test server from StreamTranslatorHandler. + spdyLocation, err := url.Parse(spdyServer.URL) + if err != nil { + t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL) + } + spdyTransport, err := fakeTransport() + if err != nil { + t.Fatalf("Unexpected error creating transport: %v", err) + } + streams := Options{Stdin: true, Stderr: true} + streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, streams) + streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + streamTranslator.ServeHTTP(w, req) + })) + defer streamTranslatorServer.Close() + // Now create the websocket client (executor), and point it to the "streamTranslatorServer". + streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL) + if err != nil { + t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL) + } + exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + // Generate random data, and set it up to stream on STDIN. The data will be + // returned on the STDERR buffer. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + var stderr bytes.Buffer + options := &remotecommand.StreamOptions{ + Stdin: bytes.NewReader(randomData), + Stderr: &stderr, + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } + data, err := io.ReadAll(bytes.NewReader(stderr.Bytes())) + if err != nil { + t.Errorf("error reading the stream: %v", err) + return + } + // Check the random data sent on STDIN was the same returned on STDERR. + if !bytes.Equal(randomData, data) { + t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData)) + } +} + +// Returns a random exit code in the range(1-127). +func randomExitCode() int { + errorCode := mrand.Intn(127) // Range: (0 - 126) + errorCode += 1 // Range: (1 - 127) + return errorCode +} + +// TestStreamTranslator_ErrorStream tests the error stream by sending an error with a random +// exit code, then validating the error arrives on the error stream. +func TestStreamTranslator_ErrorStream(t *testing.T) { + expectedExitCode := randomExitCode() + // Create upstream fake SPDY server, returning a non-zero exit code + // on error stream within the structured error. + spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ctx, err := createSPDYServerStreams(w, req, Options{ + Stdout: true, + }) + if err != nil { + t.Errorf("error on createHTTPStreams: %v", err) + return + } + defer ctx.conn.Close() + // Read/discard STDIN data before returning error on error stream. + _, err = io.Copy(io.Discard, ctx.stdinStream) + if err != nil { + t.Fatalf("error copying STDIN to DISCARD: %v", err) + } + // Force an non-zero exit code error returned on the error stream. + err = ctx.writeStatus(&apierrors.StatusError{ErrStatus: metav1.Status{ + Status: metav1.StatusFailure, + Reason: rcconstants.NonZeroExitCodeReason, + Details: &metav1.StatusDetails{ + Causes: []metav1.StatusCause{ + { + Type: rcconstants.ExitCodeCauseType, + Message: fmt.Sprintf("%d", expectedExitCode), + }, + }, + }, + }}) + if err != nil { + t.Fatalf("error writing status: %v", err) + } + })) + defer spdyServer.Close() + // Create StreamTranslatorHandler, which points upstream to fake SPDY server, and + // create a test server using the StreamTranslatorHandler. + spdyLocation, err := url.Parse(spdyServer.URL) + if err != nil { + t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL) + } + spdyTransport, err := fakeTransport() + if err != nil { + t.Fatalf("Unexpected error creating transport: %v", err) + } + streams := Options{Stdin: true} + streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, streams) + streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + streamTranslator.ServeHTTP(w, req) + })) + defer streamTranslatorServer.Close() + // Now create the websocket client (executor), and point it to the "streamTranslatorServer". + streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL) + if err != nil { + t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL) + } + exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + // Generate random data, and set it up to stream on STDIN. The data will be discarded at + // upstream SDPY server. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + options := &remotecommand.StreamOptions{ + Stdin: bytes.NewReader(randomData), + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + // Expect exit code error on error stream. + if err == nil { + t.Errorf("expected error, but received none") + } + expectedError := fmt.Sprintf("command terminated with exit code %d", expectedExitCode) + // Compare expected error with exit code to actual error. + if expectedError != err.Error() { + t.Errorf("expected error (%s), got (%s)", expectedError, err) + } + } +} + +// TestStreamTranslator_MultipleReadChannels tests two streams (STDOUT, STDERR) reading from +// the connections at the same time. +func TestStreamTranslator_MultipleReadChannels(t *testing.T) { + // Create upstream fake SPDY server which copies STDIN back onto STDOUT and STDERR stream. + spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ctx, err := createSPDYServerStreams(w, req, Options{ + Stdin: true, + Stdout: true, + Stderr: true, + }) + if err != nil { + t.Errorf("error on createHTTPStreams: %v", err) + return + } + defer ctx.conn.Close() + // TeeReader copies data read on STDIN onto STDERR. + stdinReader := io.TeeReader(ctx.stdinStream, ctx.stderrStream) + // Also copy STDIN to STDOUT. + _, err = io.Copy(ctx.stdoutStream, stdinReader) + if err != nil { + t.Errorf("error copying STDIN to STDOUT: %v", err) + } + })) + defer spdyServer.Close() + // Create StreamTranslatorHandler, which points upstream to fake SPDY server with + // streams STDIN, STDOUT, and STDERR. Create test server from StreamTranslatorHandler. + spdyLocation, err := url.Parse(spdyServer.URL) + if err != nil { + t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL) + } + spdyTransport, err := fakeTransport() + if err != nil { + t.Fatalf("Unexpected error creating transport: %v", err) + } + streams := Options{Stdin: true, Stdout: true, Stderr: true} + streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, streams) + streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + streamTranslator.ServeHTTP(w, req) + })) + defer streamTranslatorServer.Close() + // Now create the websocket client (executor), and point it to the "streamTranslatorServer". + streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL) + if err != nil { + t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL) + } + exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + // Generate random data, and set it up to stream on STDIN. The data will be + // returned on the STDOUT and STDERR buffer. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + var stdout, stderr bytes.Buffer + options := &remotecommand.StreamOptions{ + Stdin: bytes.NewReader(randomData), + Stdout: &stdout, + Stderr: &stderr, + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } + stdoutBytes, err := io.ReadAll(bytes.NewReader(stdout.Bytes())) + if err != nil { + t.Errorf("error reading the stream: %v", err) + return + } + // Check the random data sent on STDIN was the same returned on STDOUT. + if !bytes.Equal(stdoutBytes, randomData) { + t.Errorf("unexpected data received: %d sent: %d", len(stdoutBytes), len(randomData)) + } + stderrBytes, err := io.ReadAll(bytes.NewReader(stderr.Bytes())) + if err != nil { + t.Errorf("error reading the stream: %v", err) + return + } + // Check the random data sent on STDIN was the same returned on STDERR. + if !bytes.Equal(stderrBytes, randomData) { + t.Errorf("unexpected data received: %d sent: %d", len(stderrBytes), len(randomData)) + } +} + +// TestStreamTranslator_ThrottleReadChannels tests two streams (STDOUT, STDERR) using rate limited streams. +func TestStreamTranslator_ThrottleReadChannels(t *testing.T) { + // Create upstream fake SPDY server which copies STDIN back onto STDOUT and STDERR stream. + spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ctx, err := createSPDYServerStreams(w, req, Options{ + Stdin: true, + Stdout: true, + Stderr: true, + }) + if err != nil { + t.Errorf("error on createHTTPStreams: %v", err) + return + } + defer ctx.conn.Close() + // TeeReader copies data read on STDIN onto STDERR. + stdinReader := io.TeeReader(ctx.stdinStream, ctx.stderrStream) + // Also copy STDIN to STDOUT. + _, err = io.Copy(ctx.stdoutStream, stdinReader) + if err != nil { + t.Errorf("error copying STDIN to STDOUT: %v", err) + } + })) + defer spdyServer.Close() + // Create StreamTranslatorHandler, which points upstream to fake SPDY server with + // streams STDIN, STDOUT, and STDERR. Create test server from StreamTranslatorHandler. + spdyLocation, err := url.Parse(spdyServer.URL) + if err != nil { + t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL) + } + spdyTransport, err := fakeTransport() + if err != nil { + t.Fatalf("Unexpected error creating transport: %v", err) + } + streams := Options{Stdin: true, Stdout: true, Stderr: true} + maxBytesPerSec := 900 * 1024 // slightly less than the 1MB that is being transferred to exercise throttling. + streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, int64(maxBytesPerSec), streams) + streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + streamTranslator.ServeHTTP(w, req) + })) + defer streamTranslatorServer.Close() + // Now create the websocket client (executor), and point it to the "streamTranslatorServer". + streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL) + if err != nil { + t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL) + } + exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + // Generate random data, and set it up to stream on STDIN. The data will be + // returned on the STDOUT and STDERR buffer. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + var stdout, stderr bytes.Buffer + options := &remotecommand.StreamOptions{ + Stdin: bytes.NewReader(randomData), + Stdout: &stdout, + Stderr: &stderr, + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } + stdoutBytes, err := io.ReadAll(bytes.NewReader(stdout.Bytes())) + if err != nil { + t.Errorf("error reading the stream: %v", err) + return + } + // Check the random data sent on STDIN was the same returned on STDOUT. + if !bytes.Equal(stdoutBytes, randomData) { + t.Errorf("unexpected data received: %d sent: %d", len(stdoutBytes), len(randomData)) + } + stderrBytes, err := io.ReadAll(bytes.NewReader(stderr.Bytes())) + if err != nil { + t.Errorf("error reading the stream: %v", err) + return + } + // Check the random data sent on STDIN was the same returned on STDERR. + if !bytes.Equal(stderrBytes, randomData) { + t.Errorf("unexpected data received: %d sent: %d", len(stderrBytes), len(randomData)) + } +} + +// fakeTerminalSizeQueue implements TerminalSizeQueue, returning a random set of +// "maxSizes" number of TerminalSizes, storing the TerminalSizes in "sizes" slice. +type fakeTerminalSizeQueue struct { + maxSizes int + terminalSizes []remotecommand.TerminalSize +} + +// newTerminalSizeQueue returns a pointer to a fakeTerminalSizeQueue passing +// "max" number of random TerminalSizes created. +func newTerminalSizeQueue(max int) *fakeTerminalSizeQueue { + return &fakeTerminalSizeQueue{ + maxSizes: max, + terminalSizes: make([]remotecommand.TerminalSize, 0, max), + } +} + +// Next returns a pointer to the next random TerminalSize, or nil if we have +// already returned "maxSizes" TerminalSizes already. Stores the randomly +// created TerminalSize in "terminalSizes" field for later validation. +func (f *fakeTerminalSizeQueue) Next() *remotecommand.TerminalSize { + if len(f.terminalSizes) >= f.maxSizes { + return nil + } + size := randomTerminalSize() + f.terminalSizes = append(f.terminalSizes, size) + return &size +} + +// randomTerminalSize returns a TerminalSize with random values in the +// range (0-65535) for the fields Width and Height. +func randomTerminalSize() remotecommand.TerminalSize { + randWidth := uint16(mrand.Intn(int(math.Pow(2, 16)))) + randHeight := uint16(mrand.Intn(int(math.Pow(2, 16)))) + return remotecommand.TerminalSize{ + Width: randWidth, + Height: randHeight, + } +} + +// TestStreamTranslator_MultipleWriteChannels +func TestStreamTranslator_TTYResizeChannel(t *testing.T) { + // Create the fake terminal size queue and the actualTerminalSizes which + // will be received at the opposite websocket endpoint. + numSizeQueue := 10000 + sizeQueue := newTerminalSizeQueue(numSizeQueue) + actualTerminalSizes := make([]remotecommand.TerminalSize, 0, numSizeQueue) + // Create upstream fake SPDY server which copies STDIN back onto STDERR stream. + spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ctx, err := createSPDYServerStreams(w, req, Options{ + Tty: true, + }) + if err != nil { + t.Errorf("error on createHTTPStreams: %v", err) + return + } + defer ctx.conn.Close() + // Read the terminal resize requests, storing them in actualTerminalSizes + for i := 0; i < numSizeQueue; i++ { + actualTerminalSize := <-ctx.resizeChan + actualTerminalSizes = append(actualTerminalSizes, actualTerminalSize) + } + })) + defer spdyServer.Close() + // Create StreamTranslatorHandler, which points upstream to fake SPDY server with + // resize (TTY resize) stream. Create test server from StreamTranslatorHandler. + spdyLocation, err := url.Parse(spdyServer.URL) + if err != nil { + t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL) + } + spdyTransport, err := fakeTransport() + if err != nil { + t.Fatalf("Unexpected error creating transport: %v", err) + } + streams := Options{Tty: true} + streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, streams) + streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + streamTranslator.ServeHTTP(w, req) + })) + defer streamTranslatorServer.Close() + // Now create the websocket client (executor), and point it to the "streamTranslatorServer". + streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL) + if err != nil { + t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL) + } + exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + options := &remotecommand.StreamOptions{ + Tty: true, + TerminalSizeQueue: sizeQueue, + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } + // Validate the random TerminalSizes sent on the resize stream are the same + // as the actual TerminalSizes received at the websocket server. + if len(actualTerminalSizes) != numSizeQueue { + t.Fatalf("expected to receive num terminal resizes (%d), got (%d)", + numSizeQueue, len(actualTerminalSizes)) + } + for i, actual := range actualTerminalSizes { + expected := sizeQueue.terminalSizes[i] + if !reflect.DeepEqual(expected, actual) { + t.Errorf("expected terminal resize window %v, got %v", expected, actual) + } + } +} + +// TestStreamTranslator_WebSocketServerErrors validates that when there is a problem creating +// the websocket server as the first step of the StreamTranslator an error is properly returned. +func TestStreamTranslator_WebSocketServerErrors(t *testing.T) { + spdyLocation, err := url.Parse("http://127.0.0.1") + if err != nil { + t.Fatalf("Unable to parse spdy server URL") + } + spdyTransport, err := fakeTransport() + if err != nil { + t.Fatalf("Unexpected error creating transport: %v", err) + } + streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, Options{}) + streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + streamTranslator.ServeHTTP(w, req) + })) + defer streamTranslatorServer.Close() + // Now create the websocket client (executor), and point it to the "streamTranslatorServer". + streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL) + if err != nil { + t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL) + } + exec, err := remotecommand.NewWebSocketExecutorForProtocols( + &rest.Config{Host: streamTranslatorLocation.Host}, + "GET", + streamTranslatorServer.URL, + rcconstants.StreamProtocolV4Name, // RemoteCommand V4 protocol is unsupported + ) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. The WebSocket server within the + // StreamTranslator propagates an error here because the V4 protocol is not supported. + errorChan <- exec.StreamWithContext(context.Background(), remotecommand.StreamOptions{}) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + // Must return "websocket unable to upgrade" (bad handshake) error. + if err == nil { + t.Fatalf("expected error, but received none") + } + if !strings.Contains(err.Error(), "unable to upgrade streaming request") { + t.Errorf("expected websocket bad handshake error, got (%s)", err) + } + } +} + +// TestStreamTranslator_BlockRedirects verifies that the StreamTranslator will *not* follow +// redirects; it will thrown an error instead. +func TestStreamTranslator_BlockRedirects(t *testing.T) { + for _, statusCode := range []int{ + http.StatusMovedPermanently, // 301 + http.StatusFound, // 302 + http.StatusSeeOther, // 303 + http.StatusTemporaryRedirect, // 307 + http.StatusPermanentRedirect, // 308 + } { + // Create upstream fake SPDY server which returns a redirect. + spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Location", "/") + w.WriteHeader(statusCode) + })) + defer spdyServer.Close() + spdyLocation, err := url.Parse(spdyServer.URL) + if err != nil { + t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL) + } + spdyTransport, err := fakeTransport() + if err != nil { + t.Fatalf("Unexpected error creating transport: %v", err) + } + streams := Options{Stdout: true} + streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, streams) + streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + streamTranslator.ServeHTTP(w, req) + })) + defer streamTranslatorServer.Close() + // Now create the websocket client (executor), and point it to the "streamTranslatorServer". + streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL) + if err != nil { + t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL) + } + exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL) + if err != nil { + t.Errorf("unexpected error creating websocket executor: %v", err) + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + // Should return "redirect not allowed" error. + errorChan <- exec.StreamWithContext(context.Background(), remotecommand.StreamOptions{}) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + // Must return "redirect now allowed" error. + if err == nil { + t.Fatalf("expected error, but received none") + } + if !strings.Contains(err.Error(), "redirect not allowed") { + t.Errorf("expected redirect not allowed error, got (%s)", err) + } + } + } +} + +// streamContext encapsulates the structures necessary to communicate through +// a SPDY connection, including the Reader/Writer streams. +type streamContext struct { + conn io.Closer + stdinStream io.ReadCloser + stdoutStream io.WriteCloser + stderrStream io.WriteCloser + resizeStream io.ReadCloser + resizeChan chan remotecommand.TerminalSize + writeStatus func(status *apierrors.StatusError) error +} + +type streamAndReply struct { + httpstream.Stream + replySent <-chan struct{} +} + +// CreateSPDYServerStreams upgrades the passed HTTP request to a SPDY bi-directional streaming +// connection with remote command streams defined in passed options. Returns a streamContext +// structure containing the Reader/Writer streams to communicate through the SDPY connection. +// Returns an error if unable to upgrade the HTTP connection to a SPDY connection. +func createSPDYServerStreams(w http.ResponseWriter, req *http.Request, opts Options) (*streamContext, error) { + _, err := httpstream.Handshake(req, w, []string{rcconstants.StreamProtocolV4Name}) + if err != nil { + return nil, err + } + + upgrader := spdy.NewResponseUpgrader() + streamCh := make(chan streamAndReply) + conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream, replySent <-chan struct{}) error { + streamCh <- streamAndReply{Stream: stream, replySent: replySent} + return nil + }) + ctx := &streamContext{ + conn: conn, + } + + // wait for stream + replyChan := make(chan struct{}, 5) + defer close(replyChan) + receivedStreams := 0 + expectedStreams := 1 // expect at least the error stream + if opts.Stdout { + expectedStreams++ + } + if opts.Stdin { + expectedStreams++ + } + if opts.Stderr { + expectedStreams++ + } + if opts.Tty { + expectedStreams++ + } +WaitForStreams: + for { + select { + case stream := <-streamCh: + streamType := stream.Headers().Get(v1.StreamType) + switch streamType { + case v1.StreamTypeError: + replyChan <- struct{}{} + ctx.writeStatus = v4WriteStatusFunc(stream) + case v1.StreamTypeStdout: + replyChan <- struct{}{} + ctx.stdoutStream = stream + case v1.StreamTypeStdin: + replyChan <- struct{}{} + ctx.stdinStream = stream + case v1.StreamTypeStderr: + replyChan <- struct{}{} + ctx.stderrStream = stream + case v1.StreamTypeResize: + replyChan <- struct{}{} + ctx.resizeStream = stream + default: + // add other stream ... + return nil, errors.New("unimplemented stream type") + } + case <-replyChan: + receivedStreams++ + if receivedStreams == expectedStreams { + break WaitForStreams + } + } + } + + if ctx.resizeStream != nil { + ctx.resizeChan = make(chan remotecommand.TerminalSize) + go handleResizeEvents(req.Context(), ctx.resizeStream, ctx.resizeChan) + } + + return ctx, nil +} + +func v4WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) error { + return func(status *apierrors.StatusError) error { + bs, err := json.Marshal(status.Status()) + if err != nil { + return err + } + _, err = stream.Write(bs) + return err + } +} + +func fakeTransport() (*http.Transport, error) { + cfg := &transport.Config{ + TLS: transport.TLSConfig{ + Insecure: true, + CAFile: "", + }, + } + rt, err := transport.New(cfg) + if err != nil { + return nil, err + } + t, ok := rt.(*http.Transport) + if !ok { + return nil, fmt.Errorf("unknown transport type: %T", rt) + } + return t, nil +} diff --git a/staging/src/k8s.io/apiserver/pkg/util/proxy/translatinghandler.go b/staging/src/k8s.io/apiserver/pkg/util/proxy/translatinghandler.go new file mode 100644 index 00000000000..6f6c0088241 --- /dev/null +++ b/staging/src/k8s.io/apiserver/pkg/util/proxy/translatinghandler.go @@ -0,0 +1,51 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package proxy + +import ( + "net/http" + + "k8s.io/klog/v2" +) + +// translatingHandler wraps the delegate handler, implementing the +// http.Handler interface. The delegate handles all requests unless +// the request satisfies the passed "shouldTranslate" function +// (currently only for WebSocket/V5 request), in which case the translator +// handles the request. +type translatingHandler struct { + delegate http.Handler + translator http.Handler + shouldTranslate func(*http.Request) bool +} + +func NewTranslatingHandler(delegate http.Handler, translator http.Handler, shouldTranslate func(*http.Request) bool) http.Handler { + return &translatingHandler{ + delegate: delegate, + translator: translator, + shouldTranslate: shouldTranslate, + } +} + +func (t *translatingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if t.shouldTranslate(req) { + klog.V(4).Infof("request handled by translator proxy") + t.translator.ServeHTTP(w, req) + return + } + t.delegate.ServeHTTP(w, req) +} diff --git a/staging/src/k8s.io/apiserver/pkg/util/proxy/translatinghandler_test.go b/staging/src/k8s.io/apiserver/pkg/util/proxy/translatinghandler_test.go new file mode 100644 index 00000000000..ee5a53ed88a --- /dev/null +++ b/staging/src/k8s.io/apiserver/pkg/util/proxy/translatinghandler_test.go @@ -0,0 +1,121 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package proxy + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/util/httpstream/wsstream" +) + +// fakeHandler implements http.Handler interface +type fakeHandler struct { + served bool +} + +// ServeHTTP stores the fact that this fake handler was called. +func (fh *fakeHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + fh.served = true +} + +func TestTranslatingHandler(t *testing.T) { + tests := map[string]struct { + upgrade string + version string + expectTranslator bool + }{ + "websocket/v5 upgrade, serves translator": { + upgrade: "websocket", + version: "v5.channel.k8s.io", + expectTranslator: true, + }, + "websocket/v5 upgrade with multiple other versions, serves translator": { + upgrade: "websocket", + version: "v5.channel.k8s.io, v4.channel.k8s.io, v3.channel.k8s.io", + expectTranslator: true, + }, + "websocket/v5 upgrade with multiple other versions out of order, serves translator": { + upgrade: "websocket", + version: "v4.channel.k8s.io, v3.channel.k8s.io, v5.channel.k8s.io", + expectTranslator: true, + }, + "no upgrade, serves delegate": { + upgrade: "", + version: "", + expectTranslator: false, + }, + "no upgrade with v5, serves delegate": { + upgrade: "", + version: "v5.channel.k8s.io", + expectTranslator: false, + }, + "websocket/v5 wrong case upgrade, serves delegage": { + upgrade: "websocket", + version: "v5.CHANNEL.k8s.io", + expectTranslator: false, + }, + "spdy/v5 upgrade, serves delegate": { + upgrade: "spdy", + version: "v5.channel.k8s.io", + expectTranslator: false, + }, + "spdy/v4 upgrade, serves delegate": { + upgrade: "spdy", + version: "v4.channel.k8s.io", + expectTranslator: false, + }, + "websocket/v4 upgrade, serves delegate": { + upgrade: "websocket", + version: "v4.channel.k8s.io", + expectTranslator: false, + }, + "websocket without version upgrade, serves delegate": { + upgrade: "websocket", + version: "", + expectTranslator: false, + }, + } + for name, test := range tests { + req, err := http.NewRequest("GET", "http://www.example.com/", nil) + require.NoError(t, err) + if test.upgrade != "" { + req.Header.Add("Connection", "Upgrade") + req.Header.Add("Upgrade", test.upgrade) + } + if len(test.version) > 0 { + req.Header.Add(wsstream.WebSocketProtocolHeader, test.version) + } + delegate := fakeHandler{} + translator := fakeHandler{} + translatingHandler := NewTranslatingHandler(&delegate, &translator, + wsstream.IsWebSocketRequestWithStreamCloseProtocol) + translatingHandler.ServeHTTP(nil, req) + if !delegate.served && !translator.served { + t.Errorf("unexpected neither translator nor delegate served") + continue + } + if test.expectTranslator { + if !translator.served { + t.Errorf("%s: expected translator served, got delegate served", name) + } + } else if !delegate.served { + t.Errorf("%s: expected delegate served, got translator served", name) + } + } +} diff --git a/staging/src/k8s.io/apiserver/pkg/util/proxy/websocket.go b/staging/src/k8s.io/apiserver/pkg/util/proxy/websocket.go new file mode 100644 index 00000000000..3b9746b3b2f --- /dev/null +++ b/staging/src/k8s.io/apiserver/pkg/util/proxy/websocket.go @@ -0,0 +1,200 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package proxy + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/util/httpstream/wsstream" + constants "k8s.io/apimachinery/pkg/util/remotecommand" + "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/client-go/tools/remotecommand" +) + +const ( + // idleTimeout is the read/write deadline set for websocket server connection. Reading + // or writing the connection will return an i/o timeout if this deadline is exceeded. + // Currently, we use the same value as the kubelet websocket server. + defaultIdleConnectionTimeout = 4 * time.Hour + + // Deadline for writing errors to the websocket connection before io/timeout. + writeErrorDeadline = 10 * time.Second +) + +// Options contains details about which streams are required for +// remote command execution. +type Options struct { + Stdin bool + Stdout bool + Stderr bool + Tty bool +} + +// conns contains the connection and streams used when +// forwarding an attach or execute session into a container. +type conns struct { + conn io.Closer + stdinStream io.ReadCloser + stdoutStream io.WriteCloser + stderrStream io.WriteCloser + writeStatus func(status *apierrors.StatusError) error + resizeStream io.ReadCloser + resizeChan chan remotecommand.TerminalSize + tty bool +} + +// Create WebSocket server streams to respond to a WebSocket client. Creates the streams passed +// in the stream options. +func webSocketServerStreams(req *http.Request, w http.ResponseWriter, opts Options) (*conns, error) { + ctx, err := createWebSocketStreams(req, w, opts) + if err != nil { + return nil, err + } + + if ctx.resizeStream != nil { + ctx.resizeChan = make(chan remotecommand.TerminalSize) + go func() { + // Resize channel closes in panic case, and panic does not take down caller. + defer func() { + if p := recover(); p != nil { + // Standard panic logging. + for _, fn := range runtime.PanicHandlers { + fn(p) + } + } + }() + handleResizeEvents(req.Context(), ctx.resizeStream, ctx.resizeChan) + }() + } + + return ctx, nil +} + +// Read terminal resize events off of passed stream and queue into passed channel. +func handleResizeEvents(ctx context.Context, stream io.Reader, channel chan<- remotecommand.TerminalSize) { + defer close(channel) + + decoder := json.NewDecoder(stream) + for { + size := remotecommand.TerminalSize{} + if err := decoder.Decode(&size); err != nil { + break + } + + select { + case channel <- size: + case <-ctx.Done(): + // To avoid leaking this routine, exit if the http request finishes. This path + // would generally be hit if starting the process fails and nothing is started to + // ingest these resize events. + return + } + } +} + +// createChannels returns the standard channel types for a shell connection (STDIN 0, STDOUT 1, STDERR 2) +// along with the approximate duplex value. It also creates the error (3) and resize (4) channels. +func createChannels(opts Options) []wsstream.ChannelType { + // open the requested channels, and always open the error channel + channels := make([]wsstream.ChannelType, 5) + channels[constants.StreamStdIn] = readChannel(opts.Stdin) + channels[constants.StreamStdOut] = writeChannel(opts.Stdout) + channels[constants.StreamStdErr] = writeChannel(opts.Stderr) + channels[constants.StreamErr] = wsstream.WriteChannel + channels[constants.StreamResize] = wsstream.ReadChannel + return channels +} + +// readChannel returns wsstream.ReadChannel if real is true, or wsstream.IgnoreChannel. +func readChannel(real bool) wsstream.ChannelType { + if real { + return wsstream.ReadChannel + } + return wsstream.IgnoreChannel +} + +// writeChannel returns wsstream.WriteChannel if real is true, or wsstream.IgnoreChannel. +func writeChannel(real bool) wsstream.ChannelType { + if real { + return wsstream.WriteChannel + } + return wsstream.IgnoreChannel +} + +// createWebSocketStreams returns a "conns" struct containing the websocket connection and +// streams needed to perform an exec or an attach. +func createWebSocketStreams(req *http.Request, w http.ResponseWriter, opts Options) (*conns, error) { + channels := createChannels(opts) + conn := wsstream.NewConn(map[string]wsstream.ChannelProtocolConfig{ + // WebSocket server only supports remote command version 5. + constants.StreamProtocolV5Name: { + Binary: true, + Channels: channels, + }, + }) + conn.SetIdleTimeout(defaultIdleConnectionTimeout) + // Opening the connection responds to WebSocket client, negotiating + // the WebSocket upgrade connection and the subprotocol. + _, streams, err := conn.Open(w, req) + if err != nil { + return nil, err + } + + // Send an empty message to the lowest writable channel to notify the client the connection is established + switch { + case opts.Stdout: + _, err = streams[constants.StreamStdOut].Write([]byte{}) + case opts.Stderr: + _, err = streams[constants.StreamStdErr].Write([]byte{}) + default: + _, err = streams[constants.StreamErr].Write([]byte{}) + } + if err != nil { + conn.Close() + return nil, fmt.Errorf("write error during websocket server creation: %v", err) + } + + ctx := &conns{ + conn: conn, + stdinStream: streams[constants.StreamStdIn], + stdoutStream: streams[constants.StreamStdOut], + stderrStream: streams[constants.StreamStdErr], + tty: opts.Tty, + resizeStream: streams[constants.StreamResize], + } + + // writeStatus returns a WriteStatusFunc that marshals a given api Status + // as json in the error channel. + ctx.writeStatus = func(status *apierrors.StatusError) error { + bs, err := json.Marshal(status.Status()) + if err != nil { + return err + } + // Write status error to error stream with deadline. + conn.SetWriteDeadline(writeErrorDeadline) + _, err = streams[constants.StreamErr].Write(bs) + return err + } + + return ctx, nil +} diff --git a/staging/src/k8s.io/client-go/go.mod b/staging/src/k8s.io/client-go/go.mod index b83879f2681..a084e772608 100644 --- a/staging/src/k8s.io/client-go/go.mod +++ b/staging/src/k8s.io/client-go/go.mod @@ -49,6 +49,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/sys v0.13.0 // indirect diff --git a/staging/src/k8s.io/client-go/go.sum b/staging/src/k8s.io/client-go/go.sum index cadbb3297c4..820f7bd63af 100644 --- a/staging/src/k8s.io/client-go/go.sum +++ b/staging/src/k8s.io/client-go/go.sum @@ -75,6 +75,7 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J1GEMiLbxo1LJaP8RfCpH6pymGZus= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4= github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o= diff --git a/staging/src/k8s.io/client-go/tools/remotecommand/fallback.go b/staging/src/k8s.io/client-go/tools/remotecommand/fallback.go new file mode 100644 index 00000000000..4846cdb5509 --- /dev/null +++ b/staging/src/k8s.io/client-go/tools/remotecommand/fallback.go @@ -0,0 +1,57 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "context" +) + +var _ Executor = &fallbackExecutor{} + +type fallbackExecutor struct { + primary Executor + secondary Executor + shouldFallback func(error) bool +} + +// NewFallbackExecutor creates an Executor that first attempts to use the +// WebSocketExecutor, falling back to the legacy SPDYExecutor if the initial +// websocket "StreamWithContext" call fails. +// func NewFallbackExecutor(config *restclient.Config, method string, url *url.URL) (Executor, error) { +func NewFallbackExecutor(primary, secondary Executor, shouldFallback func(error) bool) (Executor, error) { + return &fallbackExecutor{ + primary: primary, + secondary: secondary, + shouldFallback: shouldFallback, + }, nil +} + +// Stream is deprecated. Please use "StreamWithContext". +func (f *fallbackExecutor) Stream(options StreamOptions) error { + return f.StreamWithContext(context.Background(), options) +} + +// StreamWithContext initially attempts to call "StreamWithContext" using the +// primary executor, falling back to calling the secondary executor if the +// initial primary call to upgrade to a websocket connection fails. +func (f *fallbackExecutor) StreamWithContext(ctx context.Context, options StreamOptions) error { + err := f.primary.StreamWithContext(ctx, options) + if f.shouldFallback(err) { + return f.secondary.StreamWithContext(ctx, options) + } + return err +} diff --git a/staging/src/k8s.io/client-go/tools/remotecommand/fallback_test.go b/staging/src/k8s.io/client-go/tools/remotecommand/fallback_test.go new file mode 100644 index 00000000000..70049857050 --- /dev/null +++ b/staging/src/k8s.io/client-go/tools/remotecommand/fallback_test.go @@ -0,0 +1,227 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "bytes" + "context" + "crypto/rand" + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/util/remotecommand" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/rest" +) + +func TestFallbackClient_WebSocketPrimarySucceeds(t *testing.T) { + // Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) + if err != nil { + w.WriteHeader(http.StatusForbidden) + return + } + defer conns.conn.Close() + // Loopback the STDIN stream onto the STDOUT stream. + _, err = io.Copy(conns.stdoutStream, conns.stdinStream) + require.NoError(t, err) + })) + defer websocketServer.Close() + + // Now create the fallback client (executor), and point it to the "websocketServer". + // Must add STDIN and STDOUT query params for the client request. + websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true" + websocketLocation, err := url.Parse(websocketServer.URL) + require.NoError(t, err) + websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) + require.NoError(t, err) + spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketLocation) + require.NoError(t, err) + // Never fallback, so always use the websocketExecutor, which succeeds against websocket server. + exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return false }) + require.NoError(t, err) + // Generate random data, and set it up to stream on STDIN. The data will be + // returned on the STDOUT buffer. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + var stdout bytes.Buffer + options := &StreamOptions{ + Stdin: bytes.NewReader(randomData), + Stdout: &stdout, + } + errorChan := make(chan error) + go func() { + // Start the streaming on the WebSocket "exec" client. + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + if err != nil { + t.Errorf("unexpected error") + } + } + + data, err := io.ReadAll(bytes.NewReader(stdout.Bytes())) + if err != nil { + t.Errorf("error reading the stream: %v", err) + return + } + // Check the random data sent on STDIN was the same returned on STDOUT. + if !bytes.Equal(randomData, data) { + t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData)) + } +} + +func TestFallbackClient_SPDYSecondarySucceeds(t *testing.T) { + // Create fake SPDY server. Copy received STDIN data back onto STDOUT stream. + spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + var stdin, stdout bytes.Buffer + ctx, err := createHTTPStreams(w, req, &StreamOptions{ + Stdin: &stdin, + Stdout: &stdout, + }) + if err != nil { + w.WriteHeader(http.StatusForbidden) + return + } + defer ctx.conn.Close() + _, err = io.Copy(ctx.stdoutStream, ctx.stdinStream) + if err != nil { + t.Fatalf("error copying STDIN to STDOUT: %v", err) + } + })) + defer spdyServer.Close() + + spdyLocation, err := url.Parse(spdyServer.URL) + require.NoError(t, err) + websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: spdyLocation.Host}, "GET", spdyServer.URL) + require.NoError(t, err) + spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: spdyLocation.Host}, "POST", spdyLocation) + require.NoError(t, err) + // Always fallback to spdyExecutor, and spdyExecutor succeeds against fake spdy server. + exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return true }) + require.NoError(t, err) + // Generate random data, and set it up to stream on STDIN. The data will be + // returned on the STDOUT buffer. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + var stdout bytes.Buffer + options := &StreamOptions{ + Stdin: bytes.NewReader(randomData), + Stdout: &stdout, + } + errorChan := make(chan error) + go func() { + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + if err != nil { + t.Errorf("unexpected error") + } + } + + data, err := io.ReadAll(bytes.NewReader(stdout.Bytes())) + if err != nil { + t.Errorf("error reading the stream: %v", err) + return + } + // Check the random data sent on STDIN was the same returned on STDOUT. + if !bytes.Equal(randomData, data) { + t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData)) + } +} + +func TestFallbackClient_PrimaryAndSecondaryFail(t *testing.T) { + // Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) + if err != nil { + w.WriteHeader(http.StatusForbidden) + return + } + defer conns.conn.Close() + // Loopback the STDIN stream onto the STDOUT stream. + _, err = io.Copy(conns.stdoutStream, conns.stdinStream) + require.NoError(t, err) + })) + defer websocketServer.Close() + + // Now create the fallback client (executor), and point it to the "websocketServer". + // Must add STDIN and STDOUT query params for the client request. + websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true" + websocketLocation, err := url.Parse(websocketServer.URL) + require.NoError(t, err) + websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) + require.NoError(t, err) + spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketLocation) + require.NoError(t, err) + // Always fallback to spdyExecutor, but spdyExecutor fails against websocket server. + exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return true }) + require.NoError(t, err) + // Update the websocket executor to request remote command v4, which is unsupported. + fallbackExec, ok := exec.(*fallbackExecutor) + assert.True(t, ok, "error casting executor as fallbackExecutor") + websocketExec, ok := fallbackExec.primary.(*wsStreamExecutor) + assert.True(t, ok, "error casting executor as websocket executor") + // Set the attempted subprotocol version to V4; websocket server only accepts V5. + websocketExec.protocols = []string{remotecommand.StreamProtocolV4Name} + + // Generate random data, and set it up to stream on STDIN. The data will be + // returned on the STDOUT buffer. + randomSize := 1024 * 1024 + randomData := make([]byte, randomSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + var stdout bytes.Buffer + options := &StreamOptions{ + Stdin: bytes.NewReader(randomData), + Stdout: &stdout, + } + errorChan := make(chan error) + go func() { + errorChan <- exec.StreamWithContext(context.Background(), *options) + }() + + select { + case <-time.After(wait.ForeverTestTimeout): + t.Fatalf("expect stream to be closed after connection is closed.") + case err := <-errorChan: + // Ensure secondary executor returned an error. + require.Error(t, err) + } +} diff --git a/staging/src/k8s.io/client-go/tools/remotecommand/spdy.go b/staging/src/k8s.io/client-go/tools/remotecommand/spdy.go index 76ea946b535..c2bfcf8a654 100644 --- a/staging/src/k8s.io/client-go/tools/remotecommand/spdy.go +++ b/staging/src/k8s.io/client-go/tools/remotecommand/spdy.go @@ -34,9 +34,10 @@ type spdyStreamExecutor struct { upgrader spdy.Upgrader transport http.RoundTripper - method string - url *url.URL - protocols []string + method string + url *url.URL + protocols []string + rejectRedirects bool // if true, receiving redirect from upstream is an error } // NewSPDYExecutor connects to the provided server and upgrades the connection to @@ -49,6 +50,20 @@ func NewSPDYExecutor(config *restclient.Config, method string, url *url.URL) (Ex return NewSPDYExecutorForTransports(wrapper, upgradeRoundTripper, method, url) } +// NewSPDYExecutorRejectRedirects returns an Executor that will upgrade the future +// connection to a SPDY bi-directional streaming connection when calling "Stream" (deprecated) +// or "StreamWithContext" (preferred). Additionally, if the upstream server returns a redirect +// during the attempted upgrade in these "Stream" calls, an error is returned. +func NewSPDYExecutorRejectRedirects(transport http.RoundTripper, upgrader spdy.Upgrader, method string, url *url.URL) (Executor, error) { + executor, err := NewSPDYExecutorForTransports(transport, upgrader, method, url) + if err != nil { + return nil, err + } + spdyExecutor := executor.(*spdyStreamExecutor) + spdyExecutor.rejectRedirects = true + return spdyExecutor, nil +} + // NewSPDYExecutorForTransports connects to the provided server using the given transport, // upgrades the response using the given upgrader to multiplexed bidirectional streams. func NewSPDYExecutorForTransports(transport http.RoundTripper, upgrader spdy.Upgrader, method string, url *url.URL) (Executor, error) { @@ -88,9 +103,15 @@ func (e *spdyStreamExecutor) newConnectionAndStream(ctx context.Context, options return nil, nil, fmt.Errorf("error creating request: %v", err) } + client := http.Client{Transport: e.transport} + if e.rejectRedirects { + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return fmt.Errorf("redirect not allowed") + } + } conn, protocol, err := spdy.Negotiate( e.upgrader, - &http.Client{Transport: e.transport}, + &client, req, e.protocols..., ) diff --git a/staging/src/k8s.io/client-go/tools/remotecommand/spdy_test.go b/staging/src/k8s.io/client-go/tools/remotecommand/spdy_test.go index c11177a047f..1b1cf7491d2 100644 --- a/staging/src/k8s.io/client-go/tools/remotecommand/spdy_test.go +++ b/staging/src/k8s.io/client-go/tools/remotecommand/spdy_test.go @@ -183,6 +183,7 @@ func TestSPDYExecutorStream(t *testing.T) { } func newTestHTTPServer(f AttachFunc, options *StreamOptions) *httptest.Server { + //nolint:errcheck server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { ctx, err := createHTTPStreams(writer, request, options) if err != nil { @@ -381,7 +382,7 @@ func TestStreamRandomData(t *testing.T) { } defer ctx.conn.Close() - io.Copy(ctx.stdoutStream, ctx.stdinStream) + io.Copy(ctx.stdoutStream, ctx.stdinStream) //nolint:errcheck })) defer server.Close() diff --git a/staging/src/k8s.io/client-go/tools/remotecommand/websocket.go b/staging/src/k8s.io/client-go/tools/remotecommand/websocket.go index 48e52092ee2..a60986decca 100644 --- a/staging/src/k8s.io/client-go/tools/remotecommand/websocket.go +++ b/staging/src/k8s.io/client-go/tools/remotecommand/websocket.go @@ -85,22 +85,26 @@ type wsStreamExecutor struct { heartbeatDeadline time.Duration } -// NewWebSocketExecutor allows to execute commands via a WebSocket connection. func NewWebSocketExecutor(config *restclient.Config, method, url string) (Executor, error) { + // Only supports V5 protocol for correct version skew functionality. + // Previous api servers will proxy upgrade requests to legacy websocket + // servers on container runtimes which support V1-V4. These legacy + // websocket servers will not handle the new CLOSE signal. + return NewWebSocketExecutorForProtocols(config, method, url, remotecommand.StreamProtocolV5Name) +} + +// NewWebSocketExecutorForProtocols allows to execute commands via a WebSocket connection. +func NewWebSocketExecutorForProtocols(config *restclient.Config, method, url string, protocols ...string) (Executor, error) { transport, upgrader, err := websocket.RoundTripperFor(config) if err != nil { return nil, fmt.Errorf("error creating websocket transports: %v", err) } return &wsStreamExecutor{ - transport: transport, - upgrader: upgrader, - method: method, - url: url, - // Only supports V5 protocol for correct version skew functionality. - // Previous api servers will proxy upgrade requests to legacy websocket - // servers on container runtimes which support V1-V4. These legacy - // websocket servers will not handle the new CLOSE signal. - protocols: []string{remotecommand.StreamProtocolV5Name}, + transport: transport, + upgrader: upgrader, + method: method, + url: url, + protocols: protocols, heartbeatPeriod: pingPeriod, heartbeatDeadline: pingReadDeadline, }, nil @@ -177,10 +181,12 @@ func (e *wsStreamExecutor) StreamWithContext(ctx context.Context, options Stream } type wsStreamCreator struct { - conn *gwebsocket.Conn + conn *gwebsocket.Conn + // Protects writing to websocket connection; reading is lock-free connWriteLock sync.Mutex - streams map[byte]*stream - streamsMu sync.Mutex + // map of stream id to stream; multiple streams read/write the connection + streams map[byte]*stream + streamsMu sync.Mutex } func newWSStreamCreator(conn *gwebsocket.Conn) *wsStreamCreator { @@ -226,7 +232,7 @@ func (c *wsStreamCreator) CreateStream(headers http.Header) (httpstream.Stream, return s, nil } -// readDemuxLoop is the reading processor for this endpoint of the websocket +// readDemuxLoop is the lock-free reading processor for this endpoint of the websocket // connection. This loop reads the connection, and demultiplexes the data // into one of the individual stream pipes (by checking the stream id). This // loop can *not* be run concurrently, because there can only be one websocket diff --git a/staging/src/k8s.io/client-go/tools/remotecommand/websocket_test.go b/staging/src/k8s.io/client-go/tools/remotecommand/websocket_test.go index 2895ba548cd..61df2b77a4c 100644 --- a/staging/src/k8s.io/client-go/tools/remotecommand/websocket_test.go +++ b/staging/src/k8s.io/client-go/tools/remotecommand/websocket_test.go @@ -74,7 +74,7 @@ func TestWebSocketClient_LoopbackStdinToStdout(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -149,7 +149,7 @@ func TestWebSocketClient_DifferentBufferSizes(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -223,7 +223,7 @@ func TestWebSocketClient_LoopbackStdinAsPipe(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -304,7 +304,7 @@ func TestWebSocketClient_LoopbackStdinToStderr(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -377,7 +377,7 @@ func TestWebSocketClient_MultipleReadChannels(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -479,7 +479,7 @@ func TestWebSocketClient_ErrorStream(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -637,7 +637,7 @@ func TestWebSocketClient_MultipleWriteChannels(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -723,7 +723,7 @@ func TestWebSocketClient_ProtocolVersions(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -766,11 +766,14 @@ func TestWebSocketClient_ProtocolVersions(t *testing.T) { func TestWebSocketClient_BadHandshake(t *testing.T) { // Create fake WebSocket server (supports V5 subprotocol). websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) - if err != nil { - t.Fatalf("error on webSocketServerStreams: %v", err) + // Bad handshake means websocket server will not completely initialize. + _, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) + if err == nil { + t.Fatalf("expected error, but received none.") + } + if !strings.Contains(err.Error(), "websocket server finished before becoming ready") { + t.Errorf("expected websocket server error, but got: %v", err) } - defer conns.conn.Close() })) defer websocketServer.Close() @@ -779,7 +782,7 @@ func TestWebSocketClient_BadHandshake(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -831,7 +834,7 @@ func TestWebSocketClient_HeartbeatTimeout(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -909,7 +912,7 @@ func TestWebSocketClient_TextMessageTypeError(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -970,7 +973,7 @@ func TestWebSocketClient_EmptyMessageHandled(t *testing.T) { if err != nil { t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL) } - exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketServer.URL) + exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL) if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -1009,14 +1012,14 @@ func TestWebSocketClient_ExecutorErrors(t *testing.T) { ExecProvider: &clientcmdapi.ExecConfig{}, AuthProvider: &clientcmdapi.AuthProviderConfig{}, } - _, err := NewWebSocketExecutor(&config, "POST", "http://localhost") + _, err := NewWebSocketExecutor(&config, "GET", "http://localhost") if err == nil { t.Errorf("expecting executor constructor error, but received none.") } else if !strings.Contains(err.Error(), "error creating websocket transports") { t.Errorf("expecting error creating transports, got (%s)", err.Error()) } // Verify that a nil context will cause an error in StreamWithContext - exec, err := NewWebSocketExecutor(&rest.Config{}, "POST", "http://localhost") + exec, err := NewWebSocketExecutor(&rest.Config{}, "GET", "http://localhost") if err != nil { t.Errorf("unexpected error creating websocket executor: %v", err) } @@ -1316,7 +1319,16 @@ func createWebSocketStreams(req *http.Request, w http.ResponseWriter, opts *opti resizeStream: streams[remotecommand.StreamResize], } - wsStreams.writeStatus = v4WriteStatusFunc(streams[remotecommand.StreamErr]) + wsStreams.writeStatus = func(stream io.Writer) func(status *apierrors.StatusError) error { + return func(status *apierrors.StatusError) error { + bs, err := json.Marshal(status.Status()) + if err != nil { + return err + } + _, err = stream.Write(bs) + return err + } + }(streams[remotecommand.StreamErr]) return wsStreams, nil } diff --git a/staging/src/k8s.io/client-go/transport/spdy/spdy.go b/staging/src/k8s.io/client-go/transport/spdy/spdy.go index f50b68e5ffb..9fddc6c5f23 100644 --- a/staging/src/k8s.io/client-go/transport/spdy/spdy.go +++ b/staging/src/k8s.io/client-go/transport/spdy/spdy.go @@ -43,11 +43,15 @@ func RoundTripperFor(config *restclient.Config) (http.RoundTripper, Upgrader, er if config.Proxy != nil { proxy = config.Proxy } - upgradeRoundTripper := spdy.NewRoundTripperWithConfig(spdy.RoundTripperConfig{ - TLS: tlsConfig, - Proxier: proxy, - PingPeriod: time.Second * 5, + upgradeRoundTripper, err := spdy.NewRoundTripperWithConfig(spdy.RoundTripperConfig{ + TLS: tlsConfig, + Proxier: proxy, + PingPeriod: time.Second * 5, + UpgradeTransport: nil, }) + if err != nil { + return nil, nil, err + } wrapper, err := restclient.HTTPWrappersForConfig(config, upgradeRoundTripper) if err != nil { return nil, nil, err diff --git a/staging/src/k8s.io/client-go/transport/websocket/roundtripper.go b/staging/src/k8s.io/client-go/transport/websocket/roundtripper.go index e2a4a8abccf..010f916bc7b 100644 --- a/staging/src/k8s.io/client-go/transport/websocket/roundtripper.go +++ b/staging/src/k8s.io/client-go/transport/websocket/roundtripper.go @@ -108,10 +108,7 @@ func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response } wsConn, resp, err := dialer.DialContext(request.Context(), request.URL.String(), request.Header) if err != nil { - if err != gwebsocket.ErrBadHandshake { - return nil, err - } - return nil, fmt.Errorf("unable to upgrade connection: %v", err) + return nil, &httpstream.UpgradeFailureError{Cause: err} } rt.Conn = wsConn @@ -155,7 +152,7 @@ func Negotiate(rt http.RoundTripper, connectionInfo ConnectionHolder, req *http. req.Header[httpstream.HeaderProtocolVersion] = protocols resp, err := rt.RoundTrip(req) if err != nil { - return nil, fmt.Errorf("error sending request: %v", err) + return nil, err } err = resp.Body.Close() if err != nil { diff --git a/staging/src/k8s.io/client-go/transport/websocket/roundtripper_test.go b/staging/src/k8s.io/client-go/transport/websocket/roundtripper_test.go index 168d5d5509b..16bfbf570ba 100644 --- a/staging/src/k8s.io/client-go/transport/websocket/roundtripper_test.go +++ b/staging/src/k8s.io/client-go/transport/websocket/roundtripper_test.go @@ -49,7 +49,7 @@ func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) { // Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()". websocketLocation, err := url.Parse(websocketServer.URL) require.NoError(t, err) - req, err := http.NewRequestWithContext(context.Background(), "POST", websocketServer.URL, nil) + req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil) require.NoError(t, err) rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host}) require.NoError(t, err) @@ -67,18 +67,17 @@ func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) { func TestWebSocketRoundTripper_RoundTripperFails(t *testing.T) { // Create fake WebSocket server. websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - conns, err := webSocketServerStreams(req, w) - if err != nil { - t.Fatalf("error on webSocketServerStreams: %v", err) - } - defer conns.conn.Close() + // Bad handshake means websocket server will not completely initialize. + _, err := webSocketServerStreams(req, w) + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "websocket server finished before becoming ready")) })) defer websocketServer.Close() // Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()". websocketLocation, err := url.Parse(websocketServer.URL) require.NoError(t, err) - req, err := http.NewRequestWithContext(context.Background(), "POST", websocketServer.URL, nil) + req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil) require.NoError(t, err) rt, _, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host}) require.NoError(t, err) @@ -105,7 +104,7 @@ func TestWebSocketRoundTripper_NegotiateCreatesConnection(t *testing.T) { // Create the websocket roundtripper and call "Negotiate" to create websocket connection. websocketLocation, err := url.Parse(websocketServer.URL) require.NoError(t, err) - req, err := http.NewRequestWithContext(context.Background(), "POST", websocketServer.URL, nil) + req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil) require.NoError(t, err) rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host}) require.NoError(t, err) diff --git a/staging/src/k8s.io/kube-aggregator/go.mod b/staging/src/k8s.io/kube-aggregator/go.mod index 41cb5046be0..a1547f950d1 100644 --- a/staging/src/k8s.io/kube-aggregator/go.mod +++ b/staging/src/k8s.io/kube-aggregator/go.mod @@ -49,6 +49,7 @@ require ( github.com/google/cel-go v0.17.6 // indirect github.com/google/gnostic-models v0.6.8 // indirect github.com/google/uuid v1.3.0 // indirect + github.com/gorilla/websocket v1.5.0 // indirect github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0 // indirect github.com/imdario/mergo v0.3.6 // indirect @@ -57,6 +58,7 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect + github.com/moby/spdystream v0.2.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect diff --git a/staging/src/k8s.io/kube-aggregator/go.sum b/staging/src/k8s.io/kube-aggregator/go.sum index 2ae0229af73..2e2ce6f0208 100644 --- a/staging/src/k8s.io/kube-aggregator/go.sum +++ b/staging/src/k8s.io/kube-aggregator/go.sum @@ -163,6 +163,7 @@ github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8V github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df h1:7RFfzj4SSt6nnvCPbCqijJi1nWCd+TqAT3bYCStRC18= github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df/go.mod h1:pSwJ0fSY5KhvocuWSx4fz3BA8OrA1bQn+K1Eli3BRwM= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a h1:idn718Q4B6AGu/h5Sxe66HYVdqdGu2l9Iebqhi/AEoA= github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= @@ -326,6 +327,7 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= @@ -369,6 +371,7 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0 github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= +github.com/moby/spdystream v0.2.0 h1:cjW1zVyyoiM0T7b6UoySUFqzXMoqRckQtXwGPiBhOM8= github.com/moby/spdystream v0.2.0/go.mod h1:f7i0iNDQJ059oMTcWxx8MA/zKFIuD/lY+0GqbN2Wy8c= github.com/moby/term v0.0.0-20221205130635-1aeaba878587/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= diff --git a/staging/src/k8s.io/kubectl/pkg/cmd/attach/attach.go b/staging/src/k8s.io/kubectl/pkg/cmd/attach/attach.go index 263e006c492..af25e072941 100644 --- a/staging/src/k8s.io/kubectl/pkg/cmd/attach/attach.go +++ b/staging/src/k8s.io/kubectl/pkg/cmd/attach/attach.go @@ -28,6 +28,7 @@ import ( corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/httpstream" "k8s.io/cli-runtime/pkg/genericclioptions" "k8s.io/cli-runtime/pkg/genericiooptions" "k8s.io/cli-runtime/pkg/resource" @@ -125,7 +126,7 @@ func NewCmdAttach(f cmdutil.Factory, streams genericiooptions.IOStreams) *cobra. // RemoteAttach defines the interface accepted by the Attach command - provided for test stubbing type RemoteAttach interface { - Attach(method string, url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error + Attach(url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error } // DefaultAttachFunc is the default AttachFunc used @@ -148,7 +149,7 @@ func DefaultAttachFunc(o *AttachOptions, containerToAttach *corev1.Container, ra TTY: raw, }, scheme.ParameterCodec) - return o.Attach.Attach("POST", req.URL(), o.Config, o.In, o.Out, o.ErrOut, raw, sizeQueue) + return o.Attach.Attach(req.URL(), o.Config, o.In, o.Out, o.ErrOut, raw, sizeQueue) } } @@ -156,11 +157,24 @@ func DefaultAttachFunc(o *AttachOptions, containerToAttach *corev1.Container, ra type DefaultRemoteAttach struct{} // Attach executes attach to a running container -func (*DefaultRemoteAttach) Attach(method string, url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error { - exec, err := remotecommand.NewSPDYExecutor(config, method, url) +func (*DefaultRemoteAttach) Attach(url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error { + // Legacy SPDY executor is default. If feature gate enabled, fallback + // executor attempts websockets first--then SPDY. + exec, err := remotecommand.NewSPDYExecutor(config, "POST", url) if err != nil { return err } + if cmdutil.RemoteCommandWebsockets.IsEnabled() { + // WebSocketExecutor must be "GET" method as described in RFC 6455 Sec. 4.1 (page 17). + websocketExec, err := remotecommand.NewWebSocketExecutor(config, "GET", url.String()) + if err != nil { + return err + } + exec, err = remotecommand.NewFallbackExecutor(websocketExec, exec, httpstream.IsUpgradeFailure) + if err != nil { + return err + } + } return exec.StreamWithContext(context.Background(), remotecommand.StreamOptions{ Stdin: stdin, Stdout: stdout, diff --git a/staging/src/k8s.io/kubectl/pkg/cmd/attach/attach_test.go b/staging/src/k8s.io/kubectl/pkg/cmd/attach/attach_test.go index 24b6e71d2f1..6d491323ebb 100644 --- a/staging/src/k8s.io/kubectl/pkg/cmd/attach/attach_test.go +++ b/staging/src/k8s.io/kubectl/pkg/cmd/attach/attach_test.go @@ -43,13 +43,11 @@ import ( ) type fakeRemoteAttach struct { - method string - url *url.URL - err error + url *url.URL + err error } -func (f *fakeRemoteAttach) Attach(method string, url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error { - f.method = method +func (f *fakeRemoteAttach) Attach(url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error { f.url = url return f.err } @@ -327,7 +325,7 @@ func TestAttach(t *testing.T) { return err } - return options.Attach.Attach("POST", u, nil, nil, nil, nil, raw, sizeQueue) + return options.Attach.Attach(u, nil, nil, nil, nil, raw, sizeQueue) } } @@ -347,9 +345,6 @@ func TestAttach(t *testing.T) { t.Errorf("%s: Did not get expected path for exec request: %q %q", test.name, test.attachPath, remoteAttach.url.Path) return } - if remoteAttach.method != "POST" { - t.Errorf("%s: Did not get method for attach request: %s", test.name, remoteAttach.method) - } if remoteAttach.url.Query().Get("container") != "bar" { t.Errorf("%s: Did not have query parameters: %s", test.name, remoteAttach.url.Query()) } @@ -428,7 +423,7 @@ func TestAttachWarnings(t *testing.T) { return err } - return options.Attach.Attach("POST", u, nil, nil, nil, nil, raw, sizeQueue) + return options.Attach.Attach(u, nil, nil, nil, nil, raw, sizeQueue) } } diff --git a/staging/src/k8s.io/kubectl/pkg/cmd/exec/exec.go b/staging/src/k8s.io/kubectl/pkg/cmd/exec/exec.go index 2a29aecf8b8..36d43beceb9 100644 --- a/staging/src/k8s.io/kubectl/pkg/cmd/exec/exec.go +++ b/staging/src/k8s.io/kubectl/pkg/cmd/exec/exec.go @@ -27,6 +27,7 @@ import ( "github.com/spf13/cobra" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/httpstream" "k8s.io/cli-runtime/pkg/genericclioptions" "k8s.io/cli-runtime/pkg/genericiooptions" "k8s.io/cli-runtime/pkg/resource" @@ -113,17 +114,30 @@ func NewCmdExec(f cmdutil.Factory, streams genericiooptions.IOStreams) *cobra.Co // RemoteExecutor defines the interface accepted by the Exec command - provided for test stubbing type RemoteExecutor interface { - Execute(method string, url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error + Execute(url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error } // DefaultRemoteExecutor is the standard implementation of remote command execution type DefaultRemoteExecutor struct{} -func (*DefaultRemoteExecutor) Execute(method string, url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error { - exec, err := remotecommand.NewSPDYExecutor(config, method, url) +func (*DefaultRemoteExecutor) Execute(url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error { + // Legacy SPDY executor is default. If feature gate enabled, fallback + // executor attempts websockets first--then SPDY. + exec, err := remotecommand.NewSPDYExecutor(config, "POST", url) if err != nil { return err } + if cmdutil.RemoteCommandWebsockets.IsEnabled() { + // WebSocketExecutor must be "GET" method as described in RFC 6455 Sec. 4.1 (page 17). + websocketExec, err := remotecommand.NewWebSocketExecutor(config, "GET", url.String()) + if err != nil { + return err + } + exec, err = remotecommand.NewFallbackExecutor(websocketExec, exec, httpstream.IsUpgradeFailure) + if err != nil { + return err + } + } return exec.StreamWithContext(context.Background(), remotecommand.StreamOptions{ Stdin: stdin, Stdout: stdout, @@ -371,7 +385,7 @@ func (p *ExecOptions) Run() error { TTY: t.Raw, }, scheme.ParameterCodec) - return p.Executor.Execute("POST", req.URL(), p.Config, p.In, p.Out, p.ErrOut, t.Raw, sizeQueue) + return p.Executor.Execute(req.URL(), p.Config, p.In, p.Out, p.ErrOut, t.Raw, sizeQueue) } if err := t.Safe(fn); err != nil { diff --git a/staging/src/k8s.io/kubectl/pkg/cmd/exec/exec_test.go b/staging/src/k8s.io/kubectl/pkg/cmd/exec/exec_test.go index 82ffe85e75d..7305231f129 100644 --- a/staging/src/k8s.io/kubectl/pkg/cmd/exec/exec_test.go +++ b/staging/src/k8s.io/kubectl/pkg/cmd/exec/exec_test.go @@ -40,13 +40,11 @@ import ( ) type fakeRemoteExecutor struct { - method string url *url.URL execErr error } -func (f *fakeRemoteExecutor) Execute(method string, url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error { - f.method = method +func (f *fakeRemoteExecutor) Execute(url *url.URL, config *restclient.Config, stdin io.Reader, stdout, stderr io.Writer, tty bool, terminalSizeQueue remotecommand.TerminalSizeQueue) error { f.url = url return f.execErr } @@ -264,9 +262,6 @@ func TestExec(t *testing.T) { t.Errorf("%s: Did not get expected container query param for exec request", test.name) return } - if ex.method != "POST" { - t.Errorf("%s: Did not get method for exec request: %s", test.name, ex.method) - } }) } } diff --git a/staging/src/k8s.io/kubectl/pkg/cmd/util/helpers.go b/staging/src/k8s.io/kubectl/pkg/cmd/util/helpers.go index bbb6e701bbe..03f3e7f0c7f 100644 --- a/staging/src/k8s.io/kubectl/pkg/cmd/util/helpers.go +++ b/staging/src/k8s.io/kubectl/pkg/cmd/util/helpers.go @@ -425,8 +425,10 @@ func GetPodRunningTimeoutFlag(cmd *cobra.Command) (time.Duration, error) { type FeatureGate string const ( - ApplySet FeatureGate = "KUBECTL_APPLYSET" - CmdPluginAsSubcommand FeatureGate = "KUBECTL_ENABLE_CMD_SHADOW" + ApplySet FeatureGate = "KUBECTL_APPLYSET" + CmdPluginAsSubcommand FeatureGate = "KUBECTL_ENABLE_CMD_SHADOW" + InteractiveDelete FeatureGate = "KUBECTL_INTERACTIVE_DELETE" + RemoteCommandWebsockets FeatureGate = "KUBECTL_REMOTE_COMMAND_WEBSOCKETS" ) // IsEnabled returns true iff environment variable is set to true. diff --git a/staging/src/k8s.io/kubelet/go.mod b/staging/src/k8s.io/kubelet/go.mod index 50febbba119..736e1b6a05e 100644 --- a/staging/src/k8s.io/kubelet/go.mod +++ b/staging/src/k8s.io/kubelet/go.mod @@ -35,6 +35,7 @@ require ( github.com/moby/spdystream v0.2.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_golang v1.16.0 // indirect github.com/prometheus/client_model v0.4.0 // indirect diff --git a/staging/src/k8s.io/kubelet/go.sum b/staging/src/k8s.io/kubelet/go.sum index a1e396e4a72..358384dd8cf 100644 --- a/staging/src/k8s.io/kubelet/go.sum +++ b/staging/src/k8s.io/kubelet/go.sum @@ -111,6 +111,7 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J1GEMiLbxo1LJaP8RfCpH6pymGZus= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4= github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o= diff --git a/test/e2e/kubectl/kubectl.go b/test/e2e/kubectl/kubectl.go index b05ac1ceaa2..5f655454122 100644 --- a/test/e2e/kubectl/kubectl.go +++ b/test/e2e/kubectl/kubectl.go @@ -42,6 +42,8 @@ import ( "sigs.k8s.io/yaml" + utilkubectl "k8s.io/kubectl/pkg/cmd/util" + v1 "k8s.io/api/core/v1" rbacv1 "k8s.io/api/rbac/v1" apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" @@ -801,6 +803,66 @@ metadata: framework.ExpectNoError(c.CoreV1().Pods(ns).Delete(ctx, "run-test-3", metav1.DeleteOptions{})) }) + ginkgo.It("should support inline execution and attach with websockets or fallback to spdy", func(ctx context.Context) { + waitForStdinContent := func(pod, content string) string { + var logOutput string + err := wait.PollUntilContextTimeout(ctx, 10*time.Second, 5*time.Minute, false, func(ctx context.Context) (bool, error) { + logOutput = e2ekubectl.RunKubectlOrDie(ns, "logs", pod) + return strings.Contains(logOutput, content), nil + }) + framework.ExpectNoError(err, "waiting for '%v' output", content) + return logOutput + } + + ginkgo.By("executing a command with run and attach with stdin") + // We wait for a non-empty line so we know kubectl has attached + e2ekubectl.NewKubectlCommand(ns, "run", "run-test", "--image="+busyboxImage, "--restart=OnFailure", podRunningTimeoutArg, "--attach=true", "--stdin", "--", "sh", "-c", "echo -n read: && cat && echo 'stdin closed'"). + WithStdinData("value\nabcd1234"). + AppendEnv([]string{string(utilkubectl.RemoteCommandWebsockets), "true"}). + ExecOrDie(ns) + + runOutput := waitForStdinContent("run-test", "stdin closed") + gomega.Expect(runOutput).To(gomega.ContainSubstring("read:value")) + gomega.Expect(runOutput).To(gomega.ContainSubstring("abcd1234")) + gomega.Expect(runOutput).To(gomega.ContainSubstring("stdin closed")) + + framework.ExpectNoError(c.CoreV1().Pods(ns).Delete(ctx, "run-test", metav1.DeleteOptions{})) + + ginkgo.By("executing a command with run and attach without stdin") + // There is a race on this scenario described in #73099 + // It fails if we are not able to attach before the container prints + // "stdin closed", but hasn't exited yet. + // We wait 10 seconds before printing to give time to kubectl to attach + // to the container, this does not solve the race though. + e2ekubectl.NewKubectlCommand(ns, "run", "run-test-2", "--image="+busyboxImage, "--restart=OnFailure", podRunningTimeoutArg, "--attach=true", "--leave-stdin-open=true", "--", "sh", "-c", "cat && echo 'stdin closed'"). + WithStdinData("abcd1234"). + AppendEnv([]string{string(utilkubectl.RemoteCommandWebsockets), "true"}). + ExecOrDie(ns) + + runOutput = waitForStdinContent("run-test-2", "stdin closed") + gomega.Expect(runOutput).ToNot(gomega.ContainSubstring("abcd1234")) + gomega.Expect(runOutput).To(gomega.ContainSubstring("stdin closed")) + + framework.ExpectNoError(c.CoreV1().Pods(ns).Delete(ctx, "run-test-2", metav1.DeleteOptions{})) + + ginkgo.By("executing a command with run and attach with stdin with open stdin should remain running") + e2ekubectl.NewKubectlCommand(ns, "run", "run-test-3", "--image="+busyboxImage, "--restart=OnFailure", podRunningTimeoutArg, "--attach=true", "--leave-stdin-open=true", "--stdin", "--", "sh", "-c", "cat && echo 'stdin closed'"). + WithStdinData("abcd1234\n"). + AppendEnv([]string{string(utilkubectl.RemoteCommandWebsockets), "true"}). + ExecOrDie(ns) + + runOutput = waitForStdinContent("run-test-3", "abcd1234") + gomega.Expect(runOutput).To(gomega.ContainSubstring("abcd1234")) + gomega.Expect(runOutput).ToNot(gomega.ContainSubstring("stdin closed")) + + g := func(pods []*v1.Pod) sort.Interface { return sort.Reverse(controller.ActivePods(pods)) } + runTestPod, _, err := polymorphichelpers.GetFirstPod(f.ClientSet.CoreV1(), ns, "run=run-test-3", 1*time.Minute, g) + framework.ExpectNoError(err) + framework.ExpectNoError(e2epod.WaitTimeoutForPodReadyInNamespace(ctx, c, runTestPod.Name, ns, time.Minute)) + + framework.ExpectNoError(c.CoreV1().Pods(ns).Delete(ctx, "run-test-3", metav1.DeleteOptions{})) + }) + ginkgo.It("should contain last line of the log", func(ctx context.Context) { podName := "run-log-test"