mirror of
				https://github.com/optim-enterprises-bv/kubernetes.git
				synced 2025-10-31 02:08:13 +00:00 
			
		
		
		
	Merge pull request #33684 from fraenkel/port_forward_ws
Automatic merge from submit-queue Add websocket support for port forwarding #32880 **Release note**: ```release-note Port forwarding can forward over websockets or SPDY. ```
This commit is contained in:
		| @@ -3798,6 +3798,13 @@ | ||||
|       "name": "namespace", | ||||
|       "in": "path", | ||||
|       "required": true | ||||
|      }, | ||||
|      { | ||||
|       "uniqueItems": true, | ||||
|       "type": "integer", | ||||
|       "description": "List of ports to forward Required when using WebSockets", | ||||
|       "name": "ports", | ||||
|       "in": "query" | ||||
|      } | ||||
|     ] | ||||
|    }, | ||||
|   | ||||
| @@ -9095,6 +9095,14 @@ | ||||
|       "summary": "connect GET requests to portforward of Pod", | ||||
|       "nickname": "connectGetNamespacedPodPortforward", | ||||
|       "parameters": [ | ||||
|        { | ||||
|         "type": "integer", | ||||
|         "paramType": "query", | ||||
|         "name": "ports", | ||||
|         "description": "List of ports to forward Required when using WebSockets", | ||||
|         "required": false, | ||||
|         "allowMultiple": false | ||||
|        }, | ||||
|        { | ||||
|         "type": "string", | ||||
|         "paramType": "path", | ||||
| @@ -9125,6 +9133,14 @@ | ||||
|       "summary": "connect POST requests to portforward of Pod", | ||||
|       "nickname": "connectPostNamespacedPodPortforward", | ||||
|       "parameters": [ | ||||
|        { | ||||
|         "type": "integer", | ||||
|         "paramType": "query", | ||||
|         "name": "ports", | ||||
|         "description": "List of ports to forward Required when using WebSockets", | ||||
|         "required": false, | ||||
|         "allowMultiple": false | ||||
|        }, | ||||
|        { | ||||
|         "type": "string", | ||||
|         "paramType": "path", | ||||
|   | ||||
| @@ -27,7 +27,7 @@ spec: | ||||
|     command: | ||||
|     - /bin/sh | ||||
|     - -c | ||||
|     - "for i in gcr.io/google_containers/busybox gcr.io/google_containers/busybox:1.24 gcr.io/google_containers/dnsutils:e2e gcr.io/google_containers/eptest:0.1 gcr.io/google_containers/fakegitserver:0.1 gcr.io/google_containers/hostexec:1.2 gcr.io/google_containers/iperf:e2e gcr.io/google_containers/jessie-dnsutils:e2e gcr.io/google_containers/liveness:e2e gcr.io/google_containers/mounttest:0.7 gcr.io/google_containers/mounttest-user:0.3 gcr.io/google_containers/netexec:1.4 gcr.io/google_containers/netexec:1.7 gcr.io/google_containers/nettest:1.7 gcr.io/google_containers/nettest:1.8 gcr.io/google_containers/nginx-slim:0.7 gcr.io/google_containers/nginx-slim:0.8 gcr.io/google_containers/n-way-http:1.0 gcr.io/google_containers/pause:2.0 gcr.io/google_containers/pause-amd64:3.0 gcr.io/google_containers/porter:cd5cb5791ebaa8641955f0e8c2a9bed669b1eaab gcr.io/google_containers/portforwardtester:1.0 gcr.io/google_containers/redis:e2e gcr.io/google_containers/resource_consumer:beta4 gcr.io/google_containers/resource_consumer/controller:beta4 gcr.io/google_containers/serve_hostname:v1.4 gcr.io/google_containers/test-webserver:e2e gcr.io/google_containers/ubuntu:14.04 gcr.io/google_containers/update-demo:kitten gcr.io/google_containers/update-demo:nautilus gcr.io/google_containers/volume-ceph:0.1 gcr.io/google_containers/volume-gluster:0.2 gcr.io/google_containers/volume-iscsi:0.1 gcr.io/google_containers/volume-nfs:0.6 gcr.io/google_containers/volume-rbd:0.1 gcr.io/google_samples/gb-redisslave:v1 gcr.io/google_containers/redis:v1; do echo $(date '+%X') pulling $i; docker pull $i 1>/dev/null; done; exit 0;" | ||||
|     - "for i in gcr.io/google_containers/busybox gcr.io/google_containers/busybox:1.24 gcr.io/google_containers/dnsutils:e2e gcr.io/google_containers/eptest:0.1 gcr.io/google_containers/fakegitserver:0.1 gcr.io/google_containers/hostexec:1.2 gcr.io/google_containers/iperf:e2e gcr.io/google_containers/jessie-dnsutils:e2e gcr.io/google_containers/liveness:e2e gcr.io/google_containers/mounttest:0.7 gcr.io/google_containers/mounttest-user:0.3 gcr.io/google_containers/netexec:1.4 gcr.io/google_containers/netexec:1.7 gcr.io/google_containers/nettest:1.7 gcr.io/google_containers/nettest:1.8 gcr.io/google_containers/nginx-slim:0.7 gcr.io/google_containers/nginx-slim:0.8 gcr.io/google_containers/n-way-http:1.0 gcr.io/google_containers/pause:2.0 gcr.io/google_containers/pause-amd64:3.0 gcr.io/google_containers/porter:cd5cb5791ebaa8641955f0e8c2a9bed669b1eaab gcr.io/google_containers/portforwardtester:1.2 gcr.io/google_containers/redis:e2e gcr.io/google_containers/resource_consumer:beta4 gcr.io/google_containers/resource_consumer/controller:beta4 gcr.io/google_containers/serve_hostname:v1.4 gcr.io/google_containers/test-webserver:e2e gcr.io/google_containers/ubuntu:14.04 gcr.io/google_containers/update-demo:kitten gcr.io/google_containers/update-demo:nautilus gcr.io/google_containers/volume-ceph:0.1 gcr.io/google_containers/volume-gluster:0.2 gcr.io/google_containers/volume-iscsi:0.1 gcr.io/google_containers/volume-nfs:0.6 gcr.io/google_containers/volume-rbd:0.1 gcr.io/google_samples/gb-redisslave:v1 gcr.io/google_containers/redis:v1; do echo $(date '+%X') pulling $i; docker pull $i 1>/dev/null; done; exit 0;" | ||||
|     securityContext: | ||||
|       privileged: true | ||||
|     volumeMounts: | ||||
|   | ||||
| @@ -9047,6 +9047,14 @@ span.icon > [class^="icon-"], span.icon > [class*=" icon-"] { cursor: default; } | ||||
| </thead> | ||||
| <tbody> | ||||
| <tr> | ||||
| <td class="tableblock halign-left valign-top"><p class="tableblock">QueryParameter</p></td> | ||||
| <td class="tableblock halign-left valign-top"><p class="tableblock">ports</p></td> | ||||
| <td class="tableblock halign-left valign-top"><p class="tableblock">List of ports to forward Required when using WebSockets</p></td> | ||||
| <td class="tableblock halign-left valign-top"><p class="tableblock">false</p></td> | ||||
| <td class="tableblock halign-left valign-top"><p class="tableblock">integer (int32)</p></td> | ||||
| <td class="tableblock halign-left valign-top"></td> | ||||
| </tr> | ||||
| <tr> | ||||
| <td class="tableblock halign-left valign-top"><p class="tableblock">PathParameter</p></td> | ||||
| <td class="tableblock halign-left valign-top"><p class="tableblock">namespace</p></td> | ||||
| <td class="tableblock halign-left valign-top"><p class="tableblock">object name and auth scope, such as for teams and projects</p></td> | ||||
| @@ -9152,6 +9160,14 @@ span.icon > [class^="icon-"], span.icon > [class*=" icon-"] { cursor: default; } | ||||
| </thead> | ||||
| <tbody> | ||||
| <tr> | ||||
| <td class="tableblock halign-left valign-top"><p class="tableblock">QueryParameter</p></td> | ||||
| <td class="tableblock halign-left valign-top"><p class="tableblock">ports</p></td> | ||||
| <td class="tableblock halign-left valign-top"><p class="tableblock">List of ports to forward Required when using WebSockets</p></td> | ||||
| <td class="tableblock halign-left valign-top"><p class="tableblock">false</p></td> | ||||
| <td class="tableblock halign-left valign-top"><p class="tableblock">integer (int32)</p></td> | ||||
| <td class="tableblock halign-left valign-top"></td> | ||||
| </tr> | ||||
| <tr> | ||||
| <td class="tableblock halign-left valign-top"><p class="tableblock">PathParameter</p></td> | ||||
| <td class="tableblock halign-left valign-top"><p class="tableblock">namespace</p></td> | ||||
| <td class="tableblock halign-left valign-top"><p class="tableblock">object name and auth scope, such as for teams and projects</p></td> | ||||
| @@ -33308,7 +33324,7 @@ span.icon > [class^="icon-"], span.icon > [class*=" icon-"] { cursor: default; } | ||||
| </div> | ||||
| <div id="footer"> | ||||
| <div id="footer-text"> | ||||
| Last updated 2017-01-06 18:13:51 UTC | ||||
| Last updated 2017-02-01 12:44:12 UTC | ||||
| </div> | ||||
| </div> | ||||
| </body> | ||||
|   | ||||
| @@ -42,16 +42,16 @@ import ( | ||||
| type fakePortForwarder struct { | ||||
| 	lock sync.Mutex | ||||
| 	// stores data expected from the stream per port | ||||
| 	expected map[uint16]string | ||||
| 	expected map[int32]string | ||||
| 	// stores data received from the stream per port | ||||
| 	received map[uint16]string | ||||
| 	received map[int32]string | ||||
| 	// data to be sent to the stream per port | ||||
| 	send map[uint16]string | ||||
| 	send map[int32]string | ||||
| } | ||||
|  | ||||
| var _ portforward.PortForwarder = &fakePortForwarder{} | ||||
|  | ||||
| func (pf *fakePortForwarder) PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error { | ||||
| func (pf *fakePortForwarder) PortForward(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { | ||||
| 	defer stream.Close() | ||||
|  | ||||
| 	// read from the client | ||||
| @@ -77,14 +77,14 @@ func (pf *fakePortForwarder) PortForward(name string, uid types.UID, port uint16 | ||||
|  | ||||
| // fakePortForwardServer creates an HTTP server that can handle port forwarding | ||||
| // requests. | ||||
| func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedFromClient map[uint16]string) http.HandlerFunc { | ||||
| func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedFromClient map[int32]string) http.HandlerFunc { | ||||
| 	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { | ||||
| 		pf := &fakePortForwarder{ | ||||
| 			expected: expectedFromClient, | ||||
| 			received: make(map[uint16]string), | ||||
| 			received: make(map[int32]string), | ||||
| 			send:     serverSends, | ||||
| 		} | ||||
| 		portforward.ServePortForward(w, req, pf, "pod", "uid", 0, 10*time.Second) | ||||
| 		portforward.ServePortForward(w, req, pf, "pod", "uid", nil, 0, 10*time.Second, portforward.SupportedProtocols) | ||||
|  | ||||
| 		for port, expected := range expectedFromClient { | ||||
| 			actual, ok := pf.received[port] | ||||
| @@ -109,19 +109,19 @@ func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedF | ||||
| func TestForwardPorts(t *testing.T) { | ||||
| 	tests := map[string]struct { | ||||
| 		ports       []string | ||||
| 		clientSends map[uint16]string | ||||
| 		serverSends map[uint16]string | ||||
| 		clientSends map[int32]string | ||||
| 		serverSends map[int32]string | ||||
| 	}{ | ||||
| 		"forward 1 port with no data either direction": { | ||||
| 			ports: []string{"5000"}, | ||||
| 		}, | ||||
| 		"forward 2 ports with bidirectional data": { | ||||
| 			ports: []string{"5001", "6000"}, | ||||
| 			clientSends: map[uint16]string{ | ||||
| 			clientSends: map[int32]string{ | ||||
| 				5001: "abcd", | ||||
| 				6000: "ghij", | ||||
| 			}, | ||||
| 			serverSends: map[uint16]string{ | ||||
| 			serverSends: map[int32]string{ | ||||
| 				5001: "1234", | ||||
| 				6000: "5678", | ||||
| 			}, | ||||
|   | ||||
| @@ -1047,10 +1047,14 @@ func typeToJSON(typeName string) string { | ||||
| 		return "string" | ||||
| 	case "byte", "*byte": | ||||
| 		return "string" | ||||
|  | ||||
| 	// TODO: Fix these when go-restful supports a way to specify an array query param: | ||||
| 	// https://github.com/emicklei/go-restful/issues/225 | ||||
| 	case "[]string", "[]*string": | ||||
| 		// TODO: Fix this when go-restful supports a way to specify an array query param: | ||||
| 		// https://github.com/emicklei/go-restful/issues/225 | ||||
| 		return "string" | ||||
| 	case "[]int32", "[]*int32": | ||||
| 		return "integer" | ||||
|  | ||||
| 	default: | ||||
| 		return typeName | ||||
| 	} | ||||
|   | ||||
| @@ -72,6 +72,7 @@ go_library( | ||||
|         "//pkg/kubelet/rkt:go_default_library", | ||||
|         "//pkg/kubelet/secret:go_default_library", | ||||
|         "//pkg/kubelet/server:go_default_library", | ||||
|         "//pkg/kubelet/server/portforward:go_default_library", | ||||
|         "//pkg/kubelet/server/remotecommand:go_default_library", | ||||
|         "//pkg/kubelet/server/stats:go_default_library", | ||||
|         "//pkg/kubelet/server/streaming:go_default_library", | ||||
| @@ -177,6 +178,7 @@ go_test( | ||||
|         "//pkg/kubelet/prober/results:go_default_library", | ||||
|         "//pkg/kubelet/prober/testing:go_default_library", | ||||
|         "//pkg/kubelet/secret:go_default_library", | ||||
|         "//pkg/kubelet/server/portforward:go_default_library", | ||||
|         "//pkg/kubelet/server/remotecommand:go_default_library", | ||||
|         "//pkg/kubelet/server/stats:go_default_library", | ||||
|         "//pkg/kubelet/status:go_default_library", | ||||
|   | ||||
| @@ -130,7 +130,7 @@ type DirectStreamingRuntime interface { | ||||
| 	// tty. | ||||
| 	ExecInContainer(containerID ContainerID, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan term.Size, timeout time.Duration) error | ||||
| 	// Forward the specified port from the specified pod to the stream. | ||||
| 	PortForward(pod *Pod, port uint16, stream io.ReadWriteCloser) error | ||||
| 	PortForward(pod *Pod, port int32, stream io.ReadWriteCloser) error | ||||
| 	// ContainerAttach encapsulates the attaching to containers for testability | ||||
| 	ContainerAttacher | ||||
| } | ||||
| @@ -141,7 +141,7 @@ type DirectStreamingRuntime interface { | ||||
| type IndirectStreamingRuntime interface { | ||||
| 	GetExec(id ContainerID, cmd []string, stdin, stdout, stderr, tty bool) (*url.URL, error) | ||||
| 	GetAttach(id ContainerID, stdin, stdout, stderr, tty bool) (*url.URL, error) | ||||
| 	GetPortForward(podName, podNamespace string, podUID types.UID) (*url.URL, error) | ||||
| 	GetPortForward(podName, podNamespace string, podUID types.UID, ports []int32) (*url.URL, error) | ||||
| } | ||||
|  | ||||
| type ImageService interface { | ||||
|   | ||||
| @@ -73,7 +73,7 @@ type FakeDirectStreamingRuntime struct { | ||||
| 		TTY         bool | ||||
| 		// Port-forward args | ||||
| 		Pod    *Pod | ||||
| 		Port   uint16 | ||||
| 		Port   int32 | ||||
| 		Stream io.ReadWriteCloser | ||||
| 	} | ||||
| } | ||||
| @@ -394,7 +394,7 @@ func (f *FakeRuntime) RemoveImage(image ImageSpec) error { | ||||
| 	return f.Err | ||||
| } | ||||
|  | ||||
| func (f *FakeDirectStreamingRuntime) PortForward(pod *Pod, port uint16, stream io.ReadWriteCloser) error { | ||||
| func (f *FakeDirectStreamingRuntime) PortForward(pod *Pod, port int32, stream io.ReadWriteCloser) error { | ||||
| 	f.Lock() | ||||
| 	defer f.Unlock() | ||||
|  | ||||
| @@ -471,7 +471,7 @@ func (f *FakeIndirectStreamingRuntime) GetAttach(id ContainerID, stdin, stdout, | ||||
| 	return &url.URL{Host: FakeHost}, f.Err | ||||
| } | ||||
|  | ||||
| func (f *FakeIndirectStreamingRuntime) GetPortForward(podName, podNamespace string, podUID types.UID) (*url.URL, error) { | ||||
| func (f *FakeIndirectStreamingRuntime) GetPortForward(podName, podNamespace string, podUID types.UID, ports []int32) (*url.URL, error) { | ||||
| 	f.Lock() | ||||
| 	defer f.Unlock() | ||||
|  | ||||
|   | ||||
| @@ -64,7 +64,7 @@ func (r *streamingRuntime) PortForward(podSandboxID string, port int32, stream i | ||||
| 	if port < 0 || port > math.MaxUint16 { | ||||
| 		return fmt.Errorf("invalid port %d", port) | ||||
| 	} | ||||
| 	return dockertools.PortForward(r.client, podSandboxID, uint16(port), stream) | ||||
| 	return dockertools.PortForward(r.client, podSandboxID, port, stream) | ||||
| } | ||||
|  | ||||
| // ExecSync executes a command in the container, and returns the stdout output. | ||||
|   | ||||
| @@ -1354,7 +1354,7 @@ func noPodInfraContainerError(podName, podNamespace string) error { | ||||
| //  - match cgroups of container | ||||
| //  - should we support nsenter + socat on the host? (current impl) | ||||
| //  - should we support nsenter + socat in a container, running with elevated privs and --pid=host? | ||||
| func (dm *DockerManager) PortForward(pod *kubecontainer.Pod, port uint16, stream io.ReadWriteCloser) error { | ||||
| func (dm *DockerManager) PortForward(pod *kubecontainer.Pod, port int32, stream io.ReadWriteCloser) error { | ||||
| 	podInfraContainer := pod.FindContainerByName(PodInfraContainerName) | ||||
| 	if podInfraContainer == nil { | ||||
| 		return noPodInfraContainerError(pod.Name, pod.Namespace) | ||||
| @@ -1370,7 +1370,7 @@ func (dm *DockerManager) UpdatePodCIDR(podCIDR string) error { | ||||
| } | ||||
|  | ||||
| // Temporarily export this function to share with dockershim. | ||||
| func PortForward(client DockerInterface, podInfraContainerID string, port uint16, stream io.ReadWriteCloser) error { | ||||
| func PortForward(client DockerInterface, podInfraContainerID string, port int32, stream io.ReadWriteCloser) error { | ||||
| 	container, err := client.InspectContainer(podInfraContainerID) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
|   | ||||
| @@ -2171,9 +2171,10 @@ func getStreamingConfig(kubeCfg *componentconfig.KubeletConfiguration, kubeDeps | ||||
| 		BaseURL: &url.URL{ | ||||
| 			Path: "/cri/", | ||||
| 		}, | ||||
| 		StreamIdleTimeout:     kubeCfg.StreamingConnectionIdleTimeout.Duration, | ||||
| 		StreamCreationTimeout: streaming.DefaultConfig.StreamCreationTimeout, | ||||
| 		SupportedProtocols:    streaming.DefaultConfig.SupportedProtocols, | ||||
| 		StreamIdleTimeout:               kubeCfg.StreamingConnectionIdleTimeout.Duration, | ||||
| 		StreamCreationTimeout:           streaming.DefaultConfig.StreamCreationTimeout, | ||||
| 		SupportedRemoteCommandProtocols: streaming.DefaultConfig.SupportedRemoteCommandProtocols, | ||||
| 		SupportedPortForwardProtocols:   streaming.DefaultConfig.SupportedPortForwardProtocols, | ||||
| 	} | ||||
| 	if kubeDeps.TLSOptions != nil { | ||||
| 		config.TLSConfig = kubeDeps.TLSOptions.Config | ||||
|   | ||||
| @@ -50,6 +50,7 @@ import ( | ||||
| 	"k8s.io/kubernetes/pkg/kubelet/envvars" | ||||
| 	"k8s.io/kubernetes/pkg/kubelet/images" | ||||
| 	"k8s.io/kubernetes/pkg/kubelet/qos" | ||||
| 	"k8s.io/kubernetes/pkg/kubelet/server/portforward" | ||||
| 	"k8s.io/kubernetes/pkg/kubelet/server/remotecommand" | ||||
| 	"k8s.io/kubernetes/pkg/kubelet/status" | ||||
| 	kubetypes "k8s.io/kubernetes/pkg/kubelet/types" | ||||
| @@ -1394,7 +1395,7 @@ func (kl *Kubelet) AttachContainer(podFullName string, podUID types.UID, contain | ||||
|  | ||||
| // PortForward connects to the pod's port and copies data between the port | ||||
| // and the stream. | ||||
| func (kl *Kubelet) PortForward(podFullName string, podUID types.UID, port uint16, stream io.ReadWriteCloser) error { | ||||
| func (kl *Kubelet) PortForward(podFullName string, podUID types.UID, port int32, stream io.ReadWriteCloser) error { | ||||
| 	streamingRuntime, ok := kl.containerRuntime.(kubecontainer.DirectStreamingRuntime) | ||||
| 	if !ok { | ||||
| 		return fmt.Errorf("streaming methods not supported by runtime") | ||||
| @@ -1467,7 +1468,7 @@ func (kl *Kubelet) GetAttach(podFullName string, podUID types.UID, containerName | ||||
| } | ||||
|  | ||||
| // GetPortForward gets the URL the port-forward will be served from, or nil if the Kubelet will serve it. | ||||
| func (kl *Kubelet) GetPortForward(podName, podNamespace string, podUID types.UID) (*url.URL, error) { | ||||
| func (kl *Kubelet) GetPortForward(podName, podNamespace string, podUID types.UID, portForwardOpts portforward.V4Options) (*url.URL, error) { | ||||
| 	switch streamingRuntime := kl.containerRuntime.(type) { | ||||
| 	case kubecontainer.DirectStreamingRuntime: | ||||
| 		// Kubelet will serve the attach directly. | ||||
| @@ -1484,7 +1485,7 @@ func (kl *Kubelet) GetPortForward(podName, podNamespace string, podUID types.UID | ||||
| 			return nil, fmt.Errorf("pod not found (%q)", podFullName) | ||||
| 		} | ||||
|  | ||||
| 		return streamingRuntime.GetPortForward(podName, podNamespace, podUID) | ||||
| 		return streamingRuntime.GetPortForward(podName, podNamespace, podUID, portForwardOpts.Ports) | ||||
| 	default: | ||||
| 		return nil, fmt.Errorf("container runtime does not support port-forward") | ||||
| 	} | ||||
|   | ||||
| @@ -37,6 +37,7 @@ import ( | ||||
| 	"k8s.io/kubernetes/pkg/api/v1" | ||||
| 	kubecontainer "k8s.io/kubernetes/pkg/kubelet/container" | ||||
| 	containertest "k8s.io/kubernetes/pkg/kubelet/container/testing" | ||||
| 	"k8s.io/kubernetes/pkg/kubelet/server/portforward" | ||||
| 	"k8s.io/kubernetes/pkg/kubelet/server/remotecommand" | ||||
| ) | ||||
|  | ||||
| @@ -1607,7 +1608,7 @@ func TestPortForward(t *testing.T) { | ||||
| 		podName                = "podFoo" | ||||
| 		podNamespace           = "nsFoo" | ||||
| 		podUID       types.UID = "12345678" | ||||
| 		port         uint16    = 5000 | ||||
| 		port         int32     = 5000 | ||||
| 	) | ||||
| 	var ( | ||||
| 		stream = &fakeReadWriteCloser{} | ||||
| @@ -1646,7 +1647,7 @@ func TestPortForward(t *testing.T) { | ||||
| 		podFullName := kubecontainer.GetPodFullName(podWithUidNameNs(podUID, tc.podName, podNamespace)) | ||||
| 		{ // No streaming case | ||||
| 			description := "no streaming - " + tc.description | ||||
| 			redirect, err := kubelet.GetPortForward(tc.podName, podNamespace, podUID) | ||||
| 			redirect, err := kubelet.GetPortForward(tc.podName, podNamespace, podUID, portforward.V4Options{}) | ||||
| 			assert.Error(t, err, description) | ||||
| 			assert.Nil(t, redirect, description) | ||||
|  | ||||
| @@ -1658,7 +1659,7 @@ func TestPortForward(t *testing.T) { | ||||
| 			fakeRuntime := &containertest.FakeDirectStreamingRuntime{FakeRuntime: testKubelet.fakeRuntime} | ||||
| 			kubelet.containerRuntime = fakeRuntime | ||||
|  | ||||
| 			redirect, err := kubelet.GetPortForward(tc.podName, podNamespace, podUID) | ||||
| 			redirect, err := kubelet.GetPortForward(tc.podName, podNamespace, podUID, portforward.V4Options{}) | ||||
| 			assert.NoError(t, err, description) | ||||
| 			assert.Nil(t, redirect, description) | ||||
|  | ||||
| @@ -1677,7 +1678,7 @@ func TestPortForward(t *testing.T) { | ||||
| 			fakeRuntime := &containertest.FakeIndirectStreamingRuntime{FakeRuntime: testKubelet.fakeRuntime} | ||||
| 			kubelet.containerRuntime = fakeRuntime | ||||
|  | ||||
| 			redirect, err := kubelet.GetPortForward(tc.podName, podNamespace, podUID) | ||||
| 			redirect, err := kubelet.GetPortForward(tc.podName, podNamespace, podUID, portforward.V4Options{}) | ||||
| 			if tc.expectError { | ||||
| 				assert.Error(t, err, description) | ||||
| 			} else { | ||||
|   | ||||
| @@ -237,7 +237,7 @@ func (m *kubeGenericRuntimeManager) getSandboxIDByPodUID(podUID kubetypes.UID, s | ||||
| } | ||||
|  | ||||
| // GetPortForward gets the endpoint the runtime will serve the port-forward request from. | ||||
| func (m *kubeGenericRuntimeManager) GetPortForward(podName, podNamespace string, podUID kubetypes.UID) (*url.URL, error) { | ||||
| func (m *kubeGenericRuntimeManager) GetPortForward(podName, podNamespace string, podUID kubetypes.UID, ports []int32) (*url.URL, error) { | ||||
| 	sandboxIDs, err := m.getSandboxIDByPodUID(podUID, nil) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to find sandboxID for pod %s: %v", format.PodDesc(podName, podNamespace, podUID), err) | ||||
| @@ -245,9 +245,9 @@ func (m *kubeGenericRuntimeManager) GetPortForward(podName, podNamespace string, | ||||
| 	if len(sandboxIDs) == 0 { | ||||
| 		return nil, fmt.Errorf("failed to find sandboxID for pod %s", format.PodDesc(podName, podNamespace, podUID)) | ||||
| 	} | ||||
| 	// TODO: Port is unused for now, but we may need it in the future. | ||||
| 	req := &runtimeapi.PortForwardRequest{ | ||||
| 		PodSandboxId: sandboxIDs[0], | ||||
| 		Port:         ports, | ||||
| 	} | ||||
| 	resp, err := m.runtimeService.PortForward(req) | ||||
| 	if err != nil { | ||||
|   | ||||
| @@ -2107,7 +2107,7 @@ func (r *Runtime) ExecInContainer(containerID kubecontainer.ContainerID, cmd []s | ||||
| //  - should we support nsenter + socat in a container, running with elevated privs and --pid=host? | ||||
| // | ||||
| // TODO(yifan): Merge with the same function in dockertools. | ||||
| func (r *Runtime) PortForward(pod *kubecontainer.Pod, port uint16, stream io.ReadWriteCloser) error { | ||||
| func (r *Runtime) PortForward(pod *kubecontainer.Pod, port int32, stream io.ReadWriteCloser) error { | ||||
| 	glog.V(4).Infof("Rkt port forwarding in container.") | ||||
|  | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), r.requestTimeout) | ||||
|   | ||||
| @@ -55,6 +55,7 @@ go_test( | ||||
|     srcs = [ | ||||
|         "auth_test.go", | ||||
|         "server_test.go", | ||||
|         "server_websocket_test.go", | ||||
|     ], | ||||
|     library = ":go_default_library", | ||||
|     tags = ["automanaged"], | ||||
| @@ -64,6 +65,7 @@ go_test( | ||||
|         "//pkg/kubelet/cm:go_default_library", | ||||
|         "//pkg/kubelet/container:go_default_library", | ||||
|         "//pkg/kubelet/container/testing:go_default_library", | ||||
|         "//pkg/kubelet/server/portforward:go_default_library", | ||||
|         "//pkg/kubelet/server/remotecommand:go_default_library", | ||||
|         "//pkg/kubelet/server/stats:go_default_library", | ||||
|         "//pkg/util/term:go_default_library", | ||||
| @@ -72,6 +74,7 @@ go_test( | ||||
|         "//vendor:github.com/google/cadvisor/info/v2", | ||||
|         "//vendor:github.com/stretchr/testify/assert", | ||||
|         "//vendor:github.com/stretchr/testify/require", | ||||
|         "//vendor:golang.org/x/net/websocket", | ||||
|         "//vendor:k8s.io/apimachinery/pkg/api/errors", | ||||
|         "//vendor:k8s.io/apimachinery/pkg/apis/meta/v1", | ||||
|         "//vendor:k8s.io/apimachinery/pkg/types", | ||||
|   | ||||
| @@ -12,7 +12,9 @@ go_library( | ||||
|     name = "go_default_library", | ||||
|     srcs = [ | ||||
|         "constants.go", | ||||
|         "httpstream.go", | ||||
|         "portforward.go", | ||||
|         "websocket.go", | ||||
|     ], | ||||
|     tags = ["automanaged"], | ||||
|     deps = [ | ||||
| @@ -22,12 +24,17 @@ go_library( | ||||
|         "//vendor:k8s.io/apimachinery/pkg/util/httpstream", | ||||
|         "//vendor:k8s.io/apimachinery/pkg/util/httpstream/spdy", | ||||
|         "//vendor:k8s.io/apimachinery/pkg/util/runtime", | ||||
|         "//vendor:k8s.io/apiserver/pkg/server/httplog", | ||||
|         "//vendor:k8s.io/apiserver/pkg/util/wsstream", | ||||
|     ], | ||||
| ) | ||||
|  | ||||
| go_test( | ||||
|     name = "go_default_test", | ||||
|     srcs = ["portforward_test.go"], | ||||
|     srcs = [ | ||||
|         "httpstream_test.go", | ||||
|         "websocket_test.go", | ||||
|     ], | ||||
|     library = ":go_default_library", | ||||
|     tags = ["automanaged"], | ||||
|     deps = [ | ||||
|   | ||||
| @@ -18,4 +18,6 @@ limitations under the License. | ||||
| package portforward | ||||
|  | ||||
| // The subprotocol "portforward.k8s.io" is used for port forwarding. | ||||
| const PortForwardProtocolV1Name = "portforward.k8s.io" | ||||
| const ProtocolV1Name = "portforward.k8s.io" | ||||
|  | ||||
| var SupportedProtocols = []string{ProtocolV1Name} | ||||
|   | ||||
							
								
								
									
										309
									
								
								pkg/kubelet/server/portforward/httpstream.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										309
									
								
								pkg/kubelet/server/portforward/httpstream.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,309 @@ | ||||
| /* | ||||
| Copyright 2016 The Kubernetes Authors. | ||||
|  | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
|  | ||||
|     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| */ | ||||
|  | ||||
| package portforward | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"k8s.io/apimachinery/pkg/types" | ||||
| 	"k8s.io/apimachinery/pkg/util/httpstream" | ||||
| 	"k8s.io/apimachinery/pkg/util/httpstream/spdy" | ||||
| 	utilruntime "k8s.io/apimachinery/pkg/util/runtime" | ||||
| 	"k8s.io/kubernetes/pkg/api" | ||||
|  | ||||
| 	"github.com/golang/glog" | ||||
| ) | ||||
|  | ||||
| func handleHttpStreams(req *http.Request, w http.ResponseWriter, portForwarder PortForwarder, podName string, uid types.UID, supportedPortForwardProtocols []string, idleTimeout, streamCreationTimeout time.Duration) error { | ||||
| 	_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols) | ||||
| 	// negotiated protocol isn't currently used server side, but could be in the future | ||||
| 	if err != nil { | ||||
| 		// Handshake writes the error to the client | ||||
| 		return err | ||||
| 	} | ||||
| 	streamChan := make(chan httpstream.Stream, 1) | ||||
|  | ||||
| 	glog.V(5).Infof("Upgrading port forward response") | ||||
| 	upgrader := spdy.NewResponseUpgrader() | ||||
| 	conn := upgrader.UpgradeResponse(w, req, httpStreamReceived(streamChan)) | ||||
| 	if conn == nil { | ||||
| 		return errors.New("Unable to upgrade websocket connection") | ||||
| 	} | ||||
| 	defer conn.Close() | ||||
|  | ||||
| 	glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout) | ||||
| 	conn.SetIdleTimeout(idleTimeout) | ||||
|  | ||||
| 	h := &httpStreamHandler{ | ||||
| 		conn:                  conn, | ||||
| 		streamChan:            streamChan, | ||||
| 		streamPairs:           make(map[string]*httpStreamPair), | ||||
| 		streamCreationTimeout: streamCreationTimeout, | ||||
| 		pod:       podName, | ||||
| 		uid:       uid, | ||||
| 		forwarder: portForwarder, | ||||
| 	} | ||||
| 	h.run() | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // httpStreamReceived is the httpstream.NewStreamHandler for port | ||||
| // forward streams. It checks each stream's port and stream type headers, | ||||
| // rejecting any streams that with missing or invalid values. Each valid | ||||
| // stream is sent to the streams channel. | ||||
| func httpStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error { | ||||
| 	return func(stream httpstream.Stream, replySent <-chan struct{}) error { | ||||
| 		// make sure it has a valid port header | ||||
| 		portString := stream.Headers().Get(api.PortHeader) | ||||
| 		if len(portString) == 0 { | ||||
| 			return fmt.Errorf("%q header is required", api.PortHeader) | ||||
| 		} | ||||
| 		port, err := strconv.ParseUint(portString, 10, 16) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("unable to parse %q as a port: %v", portString, err) | ||||
| 		} | ||||
| 		if port < 1 { | ||||
| 			return fmt.Errorf("port %q must be > 0", portString) | ||||
| 		} | ||||
|  | ||||
| 		// make sure it has a valid stream type header | ||||
| 		streamType := stream.Headers().Get(api.StreamType) | ||||
| 		if len(streamType) == 0 { | ||||
| 			return fmt.Errorf("%q header is required", api.StreamType) | ||||
| 		} | ||||
| 		if streamType != api.StreamTypeError && streamType != api.StreamTypeData { | ||||
| 			return fmt.Errorf("invalid stream type %q", streamType) | ||||
| 		} | ||||
|  | ||||
| 		streams <- stream | ||||
| 		return nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // httpStreamHandler is capable of processing multiple port forward | ||||
| // requests over a single httpstream.Connection. | ||||
| type httpStreamHandler struct { | ||||
| 	conn                  httpstream.Connection | ||||
| 	streamChan            chan httpstream.Stream | ||||
| 	streamPairsLock       sync.RWMutex | ||||
| 	streamPairs           map[string]*httpStreamPair | ||||
| 	streamCreationTimeout time.Duration | ||||
| 	pod                   string | ||||
| 	uid                   types.UID | ||||
| 	forwarder             PortForwarder | ||||
| } | ||||
|  | ||||
| // getStreamPair returns a httpStreamPair for requestID. This creates a | ||||
| // new pair if one does not yet exist for the requestID. The returned bool is | ||||
| // true if the pair was created. | ||||
| func (h *httpStreamHandler) getStreamPair(requestID string) (*httpStreamPair, bool) { | ||||
| 	h.streamPairsLock.Lock() | ||||
| 	defer h.streamPairsLock.Unlock() | ||||
|  | ||||
| 	if p, ok := h.streamPairs[requestID]; ok { | ||||
| 		glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID) | ||||
| 		return p, false | ||||
| 	} | ||||
|  | ||||
| 	glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID) | ||||
|  | ||||
| 	p := newPortForwardPair(requestID) | ||||
| 	h.streamPairs[requestID] = p | ||||
|  | ||||
| 	return p, true | ||||
| } | ||||
|  | ||||
| // monitorStreamPair waits for the pair to receive both its error and data | ||||
| // streams, or for the timeout to expire (whichever happens first), and then | ||||
| // removes the pair. | ||||
| func (h *httpStreamHandler) monitorStreamPair(p *httpStreamPair, timeout <-chan time.Time) { | ||||
| 	select { | ||||
| 	case <-timeout: | ||||
| 		err := fmt.Errorf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID) | ||||
| 		utilruntime.HandleError(err) | ||||
| 		p.printError(err.Error()) | ||||
| 	case <-p.complete: | ||||
| 		glog.V(5).Infof("(conn=%v, request=%s) successfully received error and data streams", h.conn, p.requestID) | ||||
| 	} | ||||
| 	h.removeStreamPair(p.requestID) | ||||
| } | ||||
|  | ||||
| // hasStreamPair returns a bool indicating if a stream pair for requestID | ||||
| // exists. | ||||
| func (h *httpStreamHandler) hasStreamPair(requestID string) bool { | ||||
| 	h.streamPairsLock.RLock() | ||||
| 	defer h.streamPairsLock.RUnlock() | ||||
|  | ||||
| 	_, ok := h.streamPairs[requestID] | ||||
| 	return ok | ||||
| } | ||||
|  | ||||
| // removeStreamPair removes the stream pair identified by requestID from streamPairs. | ||||
| func (h *httpStreamHandler) removeStreamPair(requestID string) { | ||||
| 	h.streamPairsLock.Lock() | ||||
| 	defer h.streamPairsLock.Unlock() | ||||
|  | ||||
| 	delete(h.streamPairs, requestID) | ||||
| } | ||||
|  | ||||
| // requestID returns the request id for stream. | ||||
| func (h *httpStreamHandler) requestID(stream httpstream.Stream) string { | ||||
| 	requestID := stream.Headers().Get(api.PortForwardRequestIDHeader) | ||||
| 	if len(requestID) == 0 { | ||||
| 		glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader) | ||||
| 		// If we get here, it's because the connection came from an older client | ||||
| 		// that isn't generating the request id header | ||||
| 		// (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287) | ||||
| 		// | ||||
| 		// This is a best-effort attempt at supporting older clients. | ||||
| 		// | ||||
| 		// When there aren't concurrent new forwarded connections, each connection | ||||
| 		// will have a pair of streams (data, error), and the stream IDs will be | ||||
| 		// consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert | ||||
| 		// the stream ID into a pseudo-request id by taking the stream type and | ||||
| 		// using id = stream.Identifier() when the stream type is error, | ||||
| 		// and id = stream.Identifier() - 2 when it's data. | ||||
| 		// | ||||
| 		// NOTE: this only works when there are not concurrent new streams from | ||||
| 		// multiple forwarded connections; it's a best-effort attempt at supporting | ||||
| 		// old clients that don't generate request ids.  If there are concurrent | ||||
| 		// new connections, it's possible that 1 connection gets streams whose IDs | ||||
| 		// are not consecutive (e.g. 5 and 9 instead of 5 and 7). | ||||
| 		streamType := stream.Headers().Get(api.StreamType) | ||||
| 		switch streamType { | ||||
| 		case api.StreamTypeError: | ||||
| 			requestID = strconv.Itoa(int(stream.Identifier())) | ||||
| 		case api.StreamTypeData: | ||||
| 			requestID = strconv.Itoa(int(stream.Identifier()) - 2) | ||||
| 		} | ||||
|  | ||||
| 		glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier()) | ||||
| 	} | ||||
| 	return requestID | ||||
| } | ||||
|  | ||||
| // run is the main loop for the httpStreamHandler. It processes new | ||||
| // streams, invoking portForward for each complete stream pair. The loop exits | ||||
| // when the httpstream.Connection is closed. | ||||
| func (h *httpStreamHandler) run() { | ||||
| 	glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn) | ||||
| Loop: | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-h.conn.CloseChan(): | ||||
| 			glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn) | ||||
| 			break Loop | ||||
| 		case stream := <-h.streamChan: | ||||
| 			requestID := h.requestID(stream) | ||||
| 			streamType := stream.Headers().Get(api.StreamType) | ||||
| 			glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType) | ||||
|  | ||||
| 			p, created := h.getStreamPair(requestID) | ||||
| 			if created { | ||||
| 				go h.monitorStreamPair(p, time.After(h.streamCreationTimeout)) | ||||
| 			} | ||||
| 			if complete, err := p.add(stream); err != nil { | ||||
| 				msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err) | ||||
| 				utilruntime.HandleError(errors.New(msg)) | ||||
| 				p.printError(msg) | ||||
| 			} else if complete { | ||||
| 				go h.portForward(p) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // portForward invokes the httpStreamHandler's forwarder.PortForward | ||||
| // function for the given stream pair. | ||||
| func (h *httpStreamHandler) portForward(p *httpStreamPair) { | ||||
| 	defer p.dataStream.Close() | ||||
| 	defer p.errorStream.Close() | ||||
|  | ||||
| 	portString := p.dataStream.Headers().Get(api.PortHeader) | ||||
| 	port, _ := strconv.ParseInt(portString, 10, 32) | ||||
|  | ||||
| 	glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString) | ||||
| 	err := h.forwarder.PortForward(h.pod, h.uid, int32(port), p.dataStream) | ||||
| 	glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err) | ||||
| 		utilruntime.HandleError(msg) | ||||
| 		fmt.Fprint(p.errorStream, msg.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // httpStreamPair represents the error and data streams for a port | ||||
| // forwarding request. | ||||
| type httpStreamPair struct { | ||||
| 	lock        sync.RWMutex | ||||
| 	requestID   string | ||||
| 	dataStream  httpstream.Stream | ||||
| 	errorStream httpstream.Stream | ||||
| 	complete    chan struct{} | ||||
| } | ||||
|  | ||||
| // newPortForwardPair creates a new httpStreamPair. | ||||
| func newPortForwardPair(requestID string) *httpStreamPair { | ||||
| 	return &httpStreamPair{ | ||||
| 		requestID: requestID, | ||||
| 		complete:  make(chan struct{}), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // add adds the stream to the httpStreamPair. If the pair already | ||||
| // contains a stream for the new stream's type, an error is returned. add | ||||
| // returns true if both the data and error streams for this pair have been | ||||
| // received. | ||||
| func (p *httpStreamPair) add(stream httpstream.Stream) (bool, error) { | ||||
| 	p.lock.Lock() | ||||
| 	defer p.lock.Unlock() | ||||
|  | ||||
| 	switch stream.Headers().Get(api.StreamType) { | ||||
| 	case api.StreamTypeError: | ||||
| 		if p.errorStream != nil { | ||||
| 			return false, errors.New("error stream already assigned") | ||||
| 		} | ||||
| 		p.errorStream = stream | ||||
| 	case api.StreamTypeData: | ||||
| 		if p.dataStream != nil { | ||||
| 			return false, errors.New("data stream already assigned") | ||||
| 		} | ||||
| 		p.dataStream = stream | ||||
| 	} | ||||
|  | ||||
| 	complete := p.errorStream != nil && p.dataStream != nil | ||||
| 	if complete { | ||||
| 		close(p.complete) | ||||
| 	} | ||||
| 	return complete, nil | ||||
| } | ||||
|  | ||||
| // printError writes s to p.errorStream if p.errorStream has been set. | ||||
| func (p *httpStreamPair) printError(s string) { | ||||
| 	p.lock.RLock() | ||||
| 	defer p.lock.RUnlock() | ||||
| 	if p.errorStream != nil { | ||||
| 		fmt.Fprint(p.errorStream, s) | ||||
| 	} | ||||
| } | ||||
| @@ -25,7 +25,7 @@ import ( | ||||
| 	"k8s.io/kubernetes/pkg/api" | ||||
| ) | ||||
| 
 | ||||
| func TestPortForwardStreamReceived(t *testing.T) { | ||||
| func TestHTTPStreamReceived(t *testing.T) { | ||||
| 	tests := map[string]struct { | ||||
| 		port          string | ||||
| 		streamType    string | ||||
| @@ -62,7 +62,7 @@ func TestPortForwardStreamReceived(t *testing.T) { | ||||
| 	} | ||||
| 	for name, test := range tests { | ||||
| 		streams := make(chan httpstream.Stream, 1) | ||||
| 		f := portForwardStreamReceived(streams) | ||||
| 		f := httpStreamReceived(streams) | ||||
| 		stream := newFakeHttpStream() | ||||
| 		if len(test.port) > 0 { | ||||
| 			stream.headers.Set("port", test.port) | ||||
| @@ -92,48 +92,11 @@ func TestPortForwardStreamReceived(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type fakeHttpStream struct { | ||||
| 	headers http.Header | ||||
| 	id      uint32 | ||||
| } | ||||
| 
 | ||||
| func newFakeHttpStream() *fakeHttpStream { | ||||
| 	return &fakeHttpStream{ | ||||
| 		headers: make(http.Header), | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| var _ httpstream.Stream = &fakeHttpStream{} | ||||
| 
 | ||||
| func (s *fakeHttpStream) Read(data []byte) (int, error) { | ||||
| 	return 0, nil | ||||
| } | ||||
| 
 | ||||
| func (s *fakeHttpStream) Write(data []byte) (int, error) { | ||||
| 	return 0, nil | ||||
| } | ||||
| 
 | ||||
| func (s *fakeHttpStream) Close() error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s *fakeHttpStream) Reset() error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s *fakeHttpStream) Headers() http.Header { | ||||
| 	return s.headers | ||||
| } | ||||
| 
 | ||||
| func (s *fakeHttpStream) Identifier() uint32 { | ||||
| 	return s.id | ||||
| } | ||||
| 
 | ||||
| func TestGetStreamPair(t *testing.T) { | ||||
| 	timeout := make(chan time.Time) | ||||
| 
 | ||||
| 	h := &portForwardStreamHandler{ | ||||
| 		streamPairs: make(map[string]*portForwardStreamPair), | ||||
| 	h := &httpStreamHandler{ | ||||
| 		streamPairs: make(map[string]*httpStreamPair), | ||||
| 	} | ||||
| 
 | ||||
| 	// test adding a new entry | ||||
| @@ -223,7 +186,7 @@ func TestGetStreamPair(t *testing.T) { | ||||
| } | ||||
| 
 | ||||
| func TestRequestID(t *testing.T) { | ||||
| 	h := &portForwardStreamHandler{} | ||||
| 	h := &httpStreamHandler{} | ||||
| 
 | ||||
| 	s := newFakeHttpStream() | ||||
| 	s.headers.Set(api.StreamType, api.StreamTypeError) | ||||
| @@ -244,3 +207,40 @@ func TestRequestID(t *testing.T) { | ||||
| 		t.Errorf("expected %q, got %q", e, a) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type fakeHttpStream struct { | ||||
| 	headers http.Header | ||||
| 	id      uint32 | ||||
| } | ||||
| 
 | ||||
| func newFakeHttpStream() *fakeHttpStream { | ||||
| 	return &fakeHttpStream{ | ||||
| 		headers: make(http.Header), | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| var _ httpstream.Stream = &fakeHttpStream{} | ||||
| 
 | ||||
| func (s *fakeHttpStream) Read(data []byte) (int, error) { | ||||
| 	return 0, nil | ||||
| } | ||||
| 
 | ||||
| func (s *fakeHttpStream) Write(data []byte) (int, error) { | ||||
| 	return 0, nil | ||||
| } | ||||
| 
 | ||||
| func (s *fakeHttpStream) Close() error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s *fakeHttpStream) Reset() error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s *fakeHttpStream) Headers() http.Header { | ||||
| 	return s.headers | ||||
| } | ||||
| 
 | ||||
| func (s *fakeHttpStream) Identifier() uint32 { | ||||
| 	return s.id | ||||
| } | ||||
| @@ -17,28 +17,20 @@ limitations under the License. | ||||
| package portforward | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/golang/glog" | ||||
|  | ||||
| 	"k8s.io/apimachinery/pkg/types" | ||||
| 	"k8s.io/apimachinery/pkg/util/httpstream" | ||||
| 	"k8s.io/apimachinery/pkg/util/httpstream/spdy" | ||||
| 	utilruntime "k8s.io/apimachinery/pkg/util/runtime" | ||||
| 	"k8s.io/kubernetes/pkg/api" | ||||
| 	"k8s.io/apimachinery/pkg/util/runtime" | ||||
| 	"k8s.io/apiserver/pkg/util/wsstream" | ||||
| ) | ||||
|  | ||||
| // PortForwarder knows how to forward content from a data stream to/from a port | ||||
| // in a pod. | ||||
| type PortForwarder interface { | ||||
| 	// PortForwarder copies data between a data stream and a port in a pod. | ||||
| 	PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error | ||||
| 	PortForward(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error | ||||
| } | ||||
|  | ||||
| // ServePortForward handles a port forwarding request.  A single request is | ||||
| @@ -46,278 +38,16 @@ type PortForwarder interface { | ||||
| // been timed out due to idleness. This function handles multiple forwarded | ||||
| // connections; i.e., multiple `curl http://localhost:8888/` requests will be | ||||
| // handled by a single invocation of ServePortForward. | ||||
| func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) { | ||||
| 	supportedPortForwardProtocols := []string{PortForwardProtocolV1Name} | ||||
| 	_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols) | ||||
| 	// negotiated protocol isn't currently used server side, but could be in the future | ||||
| 	if err != nil { | ||||
| 		// Handshake writes the error to the client | ||||
| 		utilruntime.HandleError(err) | ||||
| 		return | ||||
| func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, portForwardOptions *V4Options, idleTimeout time.Duration, streamCreationTimeout time.Duration, supportedProtocols []string) { | ||||
| 	var err error | ||||
| 	if wsstream.IsWebSocketRequest(req) { | ||||
| 		err = handleWebSocketStreams(req, w, portForwarder, podName, uid, portForwardOptions, supportedProtocols, idleTimeout, streamCreationTimeout) | ||||
| 	} else { | ||||
| 		err = handleHttpStreams(req, w, portForwarder, podName, uid, supportedProtocols, idleTimeout, streamCreationTimeout) | ||||
| 	} | ||||
|  | ||||
| 	streamChan := make(chan httpstream.Stream, 1) | ||||
|  | ||||
| 	glog.V(5).Infof("Upgrading port forward response") | ||||
| 	upgrader := spdy.NewResponseUpgrader() | ||||
| 	conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan)) | ||||
| 	if conn == nil { | ||||
| 		return | ||||
| 	} | ||||
| 	defer conn.Close() | ||||
|  | ||||
| 	glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout) | ||||
| 	conn.SetIdleTimeout(idleTimeout) | ||||
|  | ||||
| 	h := &portForwardStreamHandler{ | ||||
| 		conn:                  conn, | ||||
| 		streamChan:            streamChan, | ||||
| 		streamPairs:           make(map[string]*portForwardStreamPair), | ||||
| 		streamCreationTimeout: streamCreationTimeout, | ||||
| 		pod:       podName, | ||||
| 		uid:       uid, | ||||
| 		forwarder: portForwarder, | ||||
| 	} | ||||
| 	h.run() | ||||
| } | ||||
|  | ||||
| // portForwardStreamReceived is the httpstream.NewStreamHandler for port | ||||
| // forward streams. It checks each stream's port and stream type headers, | ||||
| // rejecting any streams that with missing or invalid values. Each valid | ||||
| // stream is sent to the streams channel. | ||||
| func portForwardStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error { | ||||
| 	return func(stream httpstream.Stream, replySent <-chan struct{}) error { | ||||
| 		// make sure it has a valid port header | ||||
| 		portString := stream.Headers().Get(api.PortHeader) | ||||
| 		if len(portString) == 0 { | ||||
| 			return fmt.Errorf("%q header is required", api.PortHeader) | ||||
| 		} | ||||
| 		port, err := strconv.ParseUint(portString, 10, 16) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("unable to parse %q as a port: %v", portString, err) | ||||
| 		} | ||||
| 		if port < 1 { | ||||
| 			return fmt.Errorf("port %q must be > 0", portString) | ||||
| 		} | ||||
|  | ||||
| 		// make sure it has a valid stream type header | ||||
| 		streamType := stream.Headers().Get(api.StreamType) | ||||
| 		if len(streamType) == 0 { | ||||
| 			return fmt.Errorf("%q header is required", api.StreamType) | ||||
| 		} | ||||
| 		if streamType != api.StreamTypeError && streamType != api.StreamTypeData { | ||||
| 			return fmt.Errorf("invalid stream type %q", streamType) | ||||
| 		} | ||||
|  | ||||
| 		streams <- stream | ||||
| 		return nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // portForwardStreamHandler is capable of processing multiple port forward | ||||
| // requests over a single httpstream.Connection. | ||||
| type portForwardStreamHandler struct { | ||||
| 	conn                  httpstream.Connection | ||||
| 	streamChan            chan httpstream.Stream | ||||
| 	streamPairsLock       sync.RWMutex | ||||
| 	streamPairs           map[string]*portForwardStreamPair | ||||
| 	streamCreationTimeout time.Duration | ||||
| 	pod                   string | ||||
| 	uid                   types.UID | ||||
| 	forwarder             PortForwarder | ||||
| } | ||||
|  | ||||
| // getStreamPair returns a portForwardStreamPair for requestID. This creates a | ||||
| // new pair if one does not yet exist for the requestID. The returned bool is | ||||
| // true if the pair was created. | ||||
| func (h *portForwardStreamHandler) getStreamPair(requestID string) (*portForwardStreamPair, bool) { | ||||
| 	h.streamPairsLock.Lock() | ||||
| 	defer h.streamPairsLock.Unlock() | ||||
|  | ||||
| 	if p, ok := h.streamPairs[requestID]; ok { | ||||
| 		glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID) | ||||
| 		return p, false | ||||
| 	} | ||||
|  | ||||
| 	glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID) | ||||
|  | ||||
| 	p := newPortForwardPair(requestID) | ||||
| 	h.streamPairs[requestID] = p | ||||
|  | ||||
| 	return p, true | ||||
| } | ||||
|  | ||||
| // monitorStreamPair waits for the pair to receive both its error and data | ||||
| // streams, or for the timeout to expire (whichever happens first), and then | ||||
| // removes the pair. | ||||
| func (h *portForwardStreamHandler) monitorStreamPair(p *portForwardStreamPair, timeout <-chan time.Time) { | ||||
| 	select { | ||||
| 	case <-timeout: | ||||
| 		err := fmt.Errorf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID) | ||||
| 		utilruntime.HandleError(err) | ||||
| 		p.printError(err.Error()) | ||||
| 	case <-p.complete: | ||||
| 		glog.V(5).Infof("(conn=%v, request=%s) successfully received error and data streams", h.conn, p.requestID) | ||||
| 	} | ||||
| 	h.removeStreamPair(p.requestID) | ||||
| } | ||||
|  | ||||
| // hasStreamPair returns a bool indicating if a stream pair for requestID | ||||
| // exists. | ||||
| func (h *portForwardStreamHandler) hasStreamPair(requestID string) bool { | ||||
| 	h.streamPairsLock.RLock() | ||||
| 	defer h.streamPairsLock.RUnlock() | ||||
|  | ||||
| 	_, ok := h.streamPairs[requestID] | ||||
| 	return ok | ||||
| } | ||||
|  | ||||
| // removeStreamPair removes the stream pair identified by requestID from streamPairs. | ||||
| func (h *portForwardStreamHandler) removeStreamPair(requestID string) { | ||||
| 	h.streamPairsLock.Lock() | ||||
| 	defer h.streamPairsLock.Unlock() | ||||
|  | ||||
| 	delete(h.streamPairs, requestID) | ||||
| } | ||||
|  | ||||
| // requestID returns the request id for stream. | ||||
| func (h *portForwardStreamHandler) requestID(stream httpstream.Stream) string { | ||||
| 	requestID := stream.Headers().Get(api.PortForwardRequestIDHeader) | ||||
| 	if len(requestID) == 0 { | ||||
| 		glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader) | ||||
| 		// If we get here, it's because the connection came from an older client | ||||
| 		// that isn't generating the request id header | ||||
| 		// (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287) | ||||
| 		// | ||||
| 		// This is a best-effort attempt at supporting older clients. | ||||
| 		// | ||||
| 		// When there aren't concurrent new forwarded connections, each connection | ||||
| 		// will have a pair of streams (data, error), and the stream IDs will be | ||||
| 		// consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert | ||||
| 		// the stream ID into a pseudo-request id by taking the stream type and | ||||
| 		// using id = stream.Identifier() when the stream type is error, | ||||
| 		// and id = stream.Identifier() - 2 when it's data. | ||||
| 		// | ||||
| 		// NOTE: this only works when there are not concurrent new streams from | ||||
| 		// multiple forwarded connections; it's a best-effort attempt at supporting | ||||
| 		// old clients that don't generate request ids.  If there are concurrent | ||||
| 		// new connections, it's possible that 1 connection gets streams whose IDs | ||||
| 		// are not consecutive (e.g. 5 and 9 instead of 5 and 7). | ||||
| 		streamType := stream.Headers().Get(api.StreamType) | ||||
| 		switch streamType { | ||||
| 		case api.StreamTypeError: | ||||
| 			requestID = strconv.Itoa(int(stream.Identifier())) | ||||
| 		case api.StreamTypeData: | ||||
| 			requestID = strconv.Itoa(int(stream.Identifier()) - 2) | ||||
| 		} | ||||
|  | ||||
| 		glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier()) | ||||
| 	} | ||||
| 	return requestID | ||||
| } | ||||
|  | ||||
| // run is the main loop for the portForwardStreamHandler. It processes new | ||||
| // streams, invoking portForward for each complete stream pair. The loop exits | ||||
| // when the httpstream.Connection is closed. | ||||
| func (h *portForwardStreamHandler) run() { | ||||
| 	glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn) | ||||
| Loop: | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-h.conn.CloseChan(): | ||||
| 			glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn) | ||||
| 			break Loop | ||||
| 		case stream := <-h.streamChan: | ||||
| 			requestID := h.requestID(stream) | ||||
| 			streamType := stream.Headers().Get(api.StreamType) | ||||
| 			glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType) | ||||
|  | ||||
| 			p, created := h.getStreamPair(requestID) | ||||
| 			if created { | ||||
| 				go h.monitorStreamPair(p, time.After(h.streamCreationTimeout)) | ||||
| 			} | ||||
| 			if complete, err := p.add(stream); err != nil { | ||||
| 				msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err) | ||||
| 				utilruntime.HandleError(errors.New(msg)) | ||||
| 				p.printError(msg) | ||||
| 			} else if complete { | ||||
| 				go h.portForward(p) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // portForward invokes the portForwardStreamHandler's forwarder.PortForward | ||||
| // function for the given stream pair. | ||||
| func (h *portForwardStreamHandler) portForward(p *portForwardStreamPair) { | ||||
| 	defer p.dataStream.Close() | ||||
| 	defer p.errorStream.Close() | ||||
|  | ||||
| 	portString := p.dataStream.Headers().Get(api.PortHeader) | ||||
| 	port, _ := strconv.ParseUint(portString, 10, 16) | ||||
|  | ||||
| 	glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString) | ||||
| 	err := h.forwarder.PortForward(h.pod, h.uid, uint16(port), p.dataStream) | ||||
| 	glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err) | ||||
| 		utilruntime.HandleError(msg) | ||||
| 		fmt.Fprint(p.errorStream, msg.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // portForwardStreamPair represents the error and data streams for a port | ||||
| // forwarding request. | ||||
| type portForwardStreamPair struct { | ||||
| 	lock        sync.RWMutex | ||||
| 	requestID   string | ||||
| 	dataStream  httpstream.Stream | ||||
| 	errorStream httpstream.Stream | ||||
| 	complete    chan struct{} | ||||
| } | ||||
|  | ||||
| // newPortForwardPair creates a new portForwardStreamPair. | ||||
| func newPortForwardPair(requestID string) *portForwardStreamPair { | ||||
| 	return &portForwardStreamPair{ | ||||
| 		requestID: requestID, | ||||
| 		complete:  make(chan struct{}), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // add adds the stream to the portForwardStreamPair. If the pair already | ||||
| // contains a stream for the new stream's type, an error is returned. add | ||||
| // returns true if both the data and error streams for this pair have been | ||||
| // received. | ||||
| func (p *portForwardStreamPair) add(stream httpstream.Stream) (bool, error) { | ||||
| 	p.lock.Lock() | ||||
| 	defer p.lock.Unlock() | ||||
|  | ||||
| 	switch stream.Headers().Get(api.StreamType) { | ||||
| 	case api.StreamTypeError: | ||||
| 		if p.errorStream != nil { | ||||
| 			return false, errors.New("error stream already assigned") | ||||
| 		} | ||||
| 		p.errorStream = stream | ||||
| 	case api.StreamTypeData: | ||||
| 		if p.dataStream != nil { | ||||
| 			return false, errors.New("data stream already assigned") | ||||
| 		} | ||||
| 		p.dataStream = stream | ||||
| 	} | ||||
|  | ||||
| 	complete := p.errorStream != nil && p.dataStream != nil | ||||
| 	if complete { | ||||
| 		close(p.complete) | ||||
| 	} | ||||
| 	return complete, nil | ||||
| } | ||||
|  | ||||
| // printError writes s to p.errorStream if p.errorStream has been set. | ||||
| func (p *portForwardStreamPair) printError(s string) { | ||||
| 	p.lock.RLock() | ||||
| 	defer p.lock.RUnlock() | ||||
| 	if p.errorStream != nil { | ||||
| 		fmt.Fprint(p.errorStream, s) | ||||
| 		runtime.HandleError(err) | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
|   | ||||
							
								
								
									
										191
									
								
								pkg/kubelet/server/portforward/websocket.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										191
									
								
								pkg/kubelet/server/portforward/websocket.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,191 @@ | ||||
| /* | ||||
| Copyright 2016 The Kubernetes Authors. | ||||
|  | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
|  | ||||
|     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| */ | ||||
|  | ||||
| package portforward | ||||
|  | ||||
| import ( | ||||
| 	"encoding/binary" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/golang/glog" | ||||
|  | ||||
| 	"k8s.io/apimachinery/pkg/types" | ||||
| 	"k8s.io/apimachinery/pkg/util/runtime" | ||||
| 	"k8s.io/apiserver/pkg/server/httplog" | ||||
| 	"k8s.io/apiserver/pkg/util/wsstream" | ||||
| 	"k8s.io/kubernetes/pkg/api" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	dataChannel = iota | ||||
| 	errorChannel | ||||
|  | ||||
| 	v4BinaryWebsocketProtocol = "v4." + wsstream.ChannelWebSocketProtocol | ||||
| 	v4Base64WebsocketProtocol = "v4." + wsstream.Base64ChannelWebSocketProtocol | ||||
| ) | ||||
|  | ||||
| // options contains details about which streams are required for | ||||
| // port forwarding. | ||||
| type V4Options struct { | ||||
| 	Ports []int32 | ||||
| } | ||||
|  | ||||
| // newOptions creates a new options from the Request. | ||||
| func NewV4Options(req *http.Request) (*V4Options, error) { | ||||
| 	if !wsstream.IsWebSocketRequest(req) { | ||||
| 		return &V4Options{}, nil | ||||
| 	} | ||||
|  | ||||
| 	portStrings := req.URL.Query()[api.PortHeader] | ||||
| 	if len(portStrings) == 0 { | ||||
| 		return nil, fmt.Errorf("query parameter %q is required", api.PortHeader) | ||||
| 	} | ||||
|  | ||||
| 	ports := make([]int32, 0, len(portStrings)) | ||||
| 	for _, portString := range portStrings { | ||||
| 		if len(portString) == 0 { | ||||
| 			return nil, fmt.Errorf("query parameter %q cannot be empty", api.PortHeader) | ||||
| 		} | ||||
| 		for _, p := range strings.Split(portString, ",") { | ||||
| 			port, err := strconv.ParseUint(p, 10, 16) | ||||
| 			if err != nil { | ||||
| 				return nil, fmt.Errorf("unable to parse %q as a port: %v", portString, err) | ||||
| 			} | ||||
| 			if port < 1 { | ||||
| 				return nil, fmt.Errorf("port %q must be > 0", portString) | ||||
| 			} | ||||
| 			ports = append(ports, int32(port)) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return &V4Options{ | ||||
| 		Ports: ports, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| // handleWebSocketStreams handles requests to forward ports to a pod via | ||||
| // a PortForwarder. A pair of streams are created per port (DATA n, | ||||
| // ERROR n+1). The associated port is written to each stream as a unsigned 16 | ||||
| // bit integer in little endian format. | ||||
| func handleWebSocketStreams(req *http.Request, w http.ResponseWriter, portForwarder PortForwarder, podName string, uid types.UID, opts *V4Options, supportedPortForwardProtocols []string, idleTimeout, streamCreationTimeout time.Duration) error { | ||||
| 	channels := make([]wsstream.ChannelType, 0, len(opts.Ports)*2) | ||||
| 	for i := 0; i < len(opts.Ports); i++ { | ||||
| 		channels = append(channels, wsstream.ReadWriteChannel, wsstream.WriteChannel) | ||||
| 	} | ||||
| 	conn := wsstream.NewConn(map[string]wsstream.ChannelProtocolConfig{ | ||||
| 		"": { | ||||
| 			Binary:   true, | ||||
| 			Channels: channels, | ||||
| 		}, | ||||
| 		v4BinaryWebsocketProtocol: { | ||||
| 			Binary:   true, | ||||
| 			Channels: channels, | ||||
| 		}, | ||||
| 		v4Base64WebsocketProtocol: { | ||||
| 			Binary:   false, | ||||
| 			Channels: channels, | ||||
| 		}, | ||||
| 	}) | ||||
| 	conn.SetIdleTimeout(idleTimeout) | ||||
| 	_, streams, err := conn.Open(httplog.Unlogged(w), req) | ||||
| 	if err != nil { | ||||
| 		err = fmt.Errorf("Unable to upgrade websocket connection: %v", err) | ||||
| 		return err | ||||
| 	} | ||||
| 	defer conn.Close() | ||||
| 	streamPairs := make([]*websocketStreamPair, len(opts.Ports)) | ||||
| 	for i := range streamPairs { | ||||
| 		streamPair := websocketStreamPair{ | ||||
| 			port:        opts.Ports[i], | ||||
| 			dataStream:  streams[i*2+dataChannel], | ||||
| 			errorStream: streams[i*2+errorChannel], | ||||
| 		} | ||||
| 		streamPairs[i] = &streamPair | ||||
|  | ||||
| 		portBytes := make([]byte, 2) | ||||
| 		// port is always positive so conversion is allowable | ||||
| 		binary.LittleEndian.PutUint16(portBytes, uint16(streamPair.port)) | ||||
| 		streamPair.dataStream.Write(portBytes) | ||||
| 		streamPair.errorStream.Write(portBytes) | ||||
| 	} | ||||
| 	h := &websocketStreamHandler{ | ||||
| 		conn:        conn, | ||||
| 		streamPairs: streamPairs, | ||||
| 		pod:         podName, | ||||
| 		uid:         uid, | ||||
| 		forwarder:   portForwarder, | ||||
| 	} | ||||
| 	h.run() | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // websocketStreamPair represents the error and data streams for a port | ||||
| // forwarding request. | ||||
| type websocketStreamPair struct { | ||||
| 	port        int32 | ||||
| 	dataStream  io.ReadWriteCloser | ||||
| 	errorStream io.WriteCloser | ||||
| } | ||||
|  | ||||
| // websocketStreamHandler is capable of processing a single port forward | ||||
| // request over a websocket connection | ||||
| type websocketStreamHandler struct { | ||||
| 	conn        *wsstream.Conn | ||||
| 	ports       []int32 | ||||
| 	streamPairs []*websocketStreamPair | ||||
| 	pod         string | ||||
| 	uid         types.UID | ||||
| 	forwarder   PortForwarder | ||||
| } | ||||
|  | ||||
| // run invokes the websocketStreamHandler's forwarder.PortForward | ||||
| // function for the given stream pair. | ||||
| func (h *websocketStreamHandler) run() { | ||||
| 	wg := sync.WaitGroup{} | ||||
| 	wg.Add(len(h.streamPairs)) | ||||
|  | ||||
| 	for _, pair := range h.streamPairs { | ||||
| 		p := pair | ||||
| 		go func() { | ||||
| 			defer wg.Done() | ||||
| 			h.portForward(p) | ||||
| 		}() | ||||
| 	} | ||||
|  | ||||
| 	wg.Wait() | ||||
| } | ||||
|  | ||||
| func (h *websocketStreamHandler) portForward(p *websocketStreamPair) { | ||||
| 	defer p.dataStream.Close() | ||||
| 	defer p.errorStream.Close() | ||||
|  | ||||
| 	glog.V(5).Infof("(conn=%p) invoking forwarder.PortForward for port %d", h.conn, p.port) | ||||
| 	err := h.forwarder.PortForward(h.pod, h.uid, p.port, p.dataStream) | ||||
| 	glog.V(5).Infof("(conn=%p) done invoking forwarder.PortForward for port %d", h.conn, p.port) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", p.port, h.pod, h.uid, err) | ||||
| 		runtime.HandleError(msg) | ||||
| 		fmt.Fprint(p.errorStream, msg.Error()) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										101
									
								
								pkg/kubelet/server/portforward/websocket_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										101
									
								
								pkg/kubelet/server/portforward/websocket_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,101 @@ | ||||
| /* | ||||
| Copyright 2016 The Kubernetes Authors. | ||||
|  | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
|  | ||||
|     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| */ | ||||
|  | ||||
| package portforward | ||||
|  | ||||
| import ( | ||||
| 	"net/http" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestV4Options(t *testing.T) { | ||||
| 	tests := map[string]struct { | ||||
| 		url           string | ||||
| 		websocket     bool | ||||
| 		expectedOpts  *V4Options | ||||
| 		expectedError string | ||||
| 	}{ | ||||
| 		"non-ws request": { | ||||
| 			url:          "http://example.com", | ||||
| 			expectedOpts: &V4Options{}, | ||||
| 		}, | ||||
| 		"missing port": { | ||||
| 			url:           "http://example.com", | ||||
| 			websocket:     true, | ||||
| 			expectedError: `query parameter "port" is required`, | ||||
| 		}, | ||||
| 		"unable to parse port": { | ||||
| 			url:           "http://example.com?port=abc", | ||||
| 			websocket:     true, | ||||
| 			expectedError: `unable to parse "abc" as a port: strconv.ParseUint: parsing "abc": invalid syntax`, | ||||
| 		}, | ||||
| 		"negative port": { | ||||
| 			url:           "http://example.com?port=-1", | ||||
| 			websocket:     true, | ||||
| 			expectedError: `unable to parse "-1" as a port: strconv.ParseUint: parsing "-1": invalid syntax`, | ||||
| 		}, | ||||
| 		"one port": { | ||||
| 			url:       "http://example.com?port=80", | ||||
| 			websocket: true, | ||||
| 			expectedOpts: &V4Options{ | ||||
| 				Ports: []int32{80}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		"multiple ports": { | ||||
| 			url:       "http://example.com?port=80,90,100", | ||||
| 			websocket: true, | ||||
| 			expectedOpts: &V4Options{ | ||||
| 				Ports: []int32{80, 90, 100}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		"multiple port": { | ||||
| 			url:       "http://example.com?port=80&port=90", | ||||
| 			websocket: true, | ||||
| 			expectedOpts: &V4Options{ | ||||
| 				Ports: []int32{80, 90}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	for name, test := range tests { | ||||
| 		req, err := http.NewRequest(http.MethodGet, test.url, nil) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("%s: invalid url %q err=%q", name, test.url, err) | ||||
| 			continue | ||||
| 		} | ||||
| 		if test.websocket { | ||||
| 			req.Header.Set("Connection", "Upgrade") | ||||
| 			req.Header.Set("Upgrade", "websocket") | ||||
| 		} | ||||
| 		opts, err := NewV4Options(req) | ||||
| 		if len(test.expectedError) > 0 { | ||||
| 			if err == nil { | ||||
| 				t.Errorf("%s: expected err=%q, but it was nil", name, test.expectedError) | ||||
| 			} | ||||
| 			if e, a := test.expectedError, err.Error(); e != a { | ||||
| 				t.Errorf("%s: expected err=%q, got %q", name, e, a) | ||||
| 			} | ||||
| 			continue | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			t.Errorf("%s: unexpected error %v", name, err) | ||||
| 			continue | ||||
| 		} | ||||
| 		if !reflect.DeepEqual(test.expectedOpts, opts) { | ||||
| 			t.Errorf("%s: expected options %#v, got %#v", name, test.expectedOpts, err) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @@ -172,7 +172,7 @@ type HostInterface interface { | ||||
| 	AttachContainer(name string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan term.Size) error | ||||
| 	GetKubeletContainerLogs(podFullName, containerName string, logOptions *v1.PodLogOptions, stdout, stderr io.Writer) error | ||||
| 	ServeLogs(w http.ResponseWriter, req *http.Request) | ||||
| 	PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error | ||||
| 	PortForward(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error | ||||
| 	StreamingConnectionIdleTimeout() time.Duration | ||||
| 	ResyncInterval() time.Duration | ||||
| 	GetHostname() string | ||||
| @@ -184,7 +184,7 @@ type HostInterface interface { | ||||
| 	ListVolumesForPod(podUID types.UID) (map[string]volume.Volume, bool) | ||||
| 	GetExec(podFullName string, podUID types.UID, containerName string, cmd []string, streamOpts remotecommand.Options) (*url.URL, error) | ||||
| 	GetAttach(podFullName string, podUID types.UID, containerName string, streamOpts remotecommand.Options) (*url.URL, error) | ||||
| 	GetPortForward(podName, podNamespace string, podUID types.UID) (*url.URL, error) | ||||
| 	GetPortForward(podName, podNamespace string, podUID types.UID, portForwardOpts portforward.V4Options) (*url.URL, error) | ||||
| } | ||||
|  | ||||
| // NewServer initializes and configures a kubelet.Server object to handle HTTP requests. | ||||
| @@ -335,9 +335,15 @@ func (s *Server) InstallDebuggingHandlers(criHandler http.Handler) { | ||||
| 	ws = new(restful.WebService) | ||||
| 	ws. | ||||
| 		Path("/portForward") | ||||
| 	ws.Route(ws.GET("/{podNamespace}/{podID}"). | ||||
| 		To(s.getPortForward). | ||||
| 		Operation("getPortForward")) | ||||
| 	ws.Route(ws.POST("/{podNamespace}/{podID}"). | ||||
| 		To(s.getPortForward). | ||||
| 		Operation("getPortForward")) | ||||
| 	ws.Route(ws.GET("/{podNamespace}/{podID}/{uid}"). | ||||
| 		To(s.getPortForward). | ||||
| 		Operation("getPortForward")) | ||||
| 	ws.Route(ws.POST("/{podNamespace}/{podID}/{uid}"). | ||||
| 		To(s.getPortForward). | ||||
| 		Operation("getPortForward")) | ||||
| @@ -562,7 +568,7 @@ func (s *Server) getSpec(request *restful.Request, response *restful.Response) { | ||||
| 	response.WriteEntity(info) | ||||
| } | ||||
|  | ||||
| type requestParams struct { | ||||
| type execRequestParams struct { | ||||
| 	podNamespace  string | ||||
| 	podName       string | ||||
| 	podUID        types.UID | ||||
| @@ -570,8 +576,8 @@ type requestParams struct { | ||||
| 	cmd           []string | ||||
| } | ||||
|  | ||||
| func getRequestParams(req *restful.Request) requestParams { | ||||
| 	return requestParams{ | ||||
| func getExecRequestParams(req *restful.Request) execRequestParams { | ||||
| 	return execRequestParams{ | ||||
| 		podNamespace:  req.PathParameter("podNamespace"), | ||||
| 		podName:       req.PathParameter("podID"), | ||||
| 		podUID:        types.UID(req.PathParameter("uid")), | ||||
| @@ -580,9 +586,23 @@ func getRequestParams(req *restful.Request) requestParams { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type portForwardRequestParams struct { | ||||
| 	podNamespace string | ||||
| 	podName      string | ||||
| 	podUID       types.UID | ||||
| } | ||||
|  | ||||
| func getPortForwardRequestParams(req *restful.Request) portForwardRequestParams { | ||||
| 	return portForwardRequestParams{ | ||||
| 		podNamespace: req.PathParameter("podNamespace"), | ||||
| 		podName:      req.PathParameter("podID"), | ||||
| 		podUID:       types.UID(req.PathParameter("uid")), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // getAttach handles requests to attach to a container. | ||||
| func (s *Server) getAttach(request *restful.Request, response *restful.Response) { | ||||
| 	params := getRequestParams(request) | ||||
| 	params := getExecRequestParams(request) | ||||
| 	streamOpts, err := remotecommand.NewOptions(request.Request) | ||||
| 	if err != nil { | ||||
| 		utilruntime.HandleError(err) | ||||
| @@ -620,7 +640,7 @@ func (s *Server) getAttach(request *restful.Request, response *restful.Response) | ||||
|  | ||||
| // getExec handles requests to run a command inside a container. | ||||
| func (s *Server) getExec(request *restful.Request, response *restful.Response) { | ||||
| 	params := getRequestParams(request) | ||||
| 	params := getExecRequestParams(request) | ||||
| 	streamOpts, err := remotecommand.NewOptions(request.Request) | ||||
| 	if err != nil { | ||||
| 		utilruntime.HandleError(err) | ||||
| @@ -659,7 +679,7 @@ func (s *Server) getExec(request *restful.Request, response *restful.Response) { | ||||
|  | ||||
| // getRun handles requests to run a command inside a container. | ||||
| func (s *Server) getRun(request *restful.Request, response *restful.Response) { | ||||
| 	params := getRequestParams(request) | ||||
| 	params := getExecRequestParams(request) | ||||
| 	pod, ok := s.host.GetPodByName(params.podNamespace, params.podName) | ||||
| 	if !ok { | ||||
| 		response.WriteError(http.StatusNotFound, fmt.Errorf("pod does not exist")) | ||||
| @@ -693,7 +713,14 @@ func writeJsonResponse(response *restful.Response, data []byte) { | ||||
| // getPortForward handles a new restful port forward request. It determines the | ||||
| // pod name and uid and then calls ServePortForward. | ||||
| func (s *Server) getPortForward(request *restful.Request, response *restful.Response) { | ||||
| 	params := getRequestParams(request) | ||||
| 	params := getPortForwardRequestParams(request) | ||||
|  | ||||
| 	portForwardOptions, err := portforward.NewV4Options(request.Request) | ||||
| 	if err != nil { | ||||
| 		utilruntime.HandleError(err) | ||||
| 		response.WriteError(http.StatusBadRequest, err) | ||||
| 		return | ||||
| 	} | ||||
| 	pod, ok := s.host.GetPodByName(params.podNamespace, params.podName) | ||||
| 	if !ok { | ||||
| 		response.WriteError(http.StatusNotFound, fmt.Errorf("pod does not exist")) | ||||
| @@ -704,7 +731,7 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	redirect, err := s.host.GetPortForward(pod.Name, pod.Namespace, pod.UID) | ||||
| 	redirect, err := s.host.GetPortForward(pod.Name, pod.Namespace, pod.UID, *portForwardOptions) | ||||
| 	if err != nil { | ||||
| 		streaming.WriteError(err, response.ResponseWriter) | ||||
| 		return | ||||
| @@ -719,8 +746,10 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp | ||||
| 		s.host, | ||||
| 		kubecontainer.GetPodFullName(pod), | ||||
| 		params.podUID, | ||||
| 		portForwardOptions, | ||||
| 		s.host.StreamingConnectionIdleTimeout(), | ||||
| 		remotecommand.DefaultStreamCreationTimeout) | ||||
| 		remotecommand.DefaultStreamCreationTimeout, | ||||
| 		portforward.SupportedProtocols) | ||||
| } | ||||
|  | ||||
| // ServeHTTP responds to HTTP requests on the Kubelet. | ||||
|   | ||||
| @@ -52,6 +52,7 @@ import ( | ||||
| 	"k8s.io/kubernetes/pkg/kubelet/cm" | ||||
| 	kubecontainer "k8s.io/kubernetes/pkg/kubelet/container" | ||||
| 	kubecontainertesting "k8s.io/kubernetes/pkg/kubelet/container/testing" | ||||
| 	"k8s.io/kubernetes/pkg/kubelet/server/portforward" | ||||
| 	"k8s.io/kubernetes/pkg/kubelet/server/remotecommand" | ||||
| 	"k8s.io/kubernetes/pkg/kubelet/server/stats" | ||||
| 	"k8s.io/kubernetes/pkg/util/term" | ||||
| @@ -73,7 +74,7 @@ type fakeKubelet struct { | ||||
| 	runFunc                            func(podFullName string, uid types.UID, containerName string, cmd []string) ([]byte, error) | ||||
| 	execFunc                           func(pod string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error | ||||
| 	attachFunc                         func(pod string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool) error | ||||
| 	portForwardFunc                    func(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error | ||||
| 	portForwardFunc                    func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error | ||||
| 	containerLogsFunc                  func(podFullName, containerName string, logOptions *v1.PodLogOptions, stdout, stderr io.Writer) error | ||||
| 	streamingConnectionIdleTimeoutFunc func() time.Duration | ||||
| 	hostnameFunc                       func() string | ||||
| @@ -139,7 +140,7 @@ func (fk *fakeKubelet) AttachContainer(name string, uid types.UID, container str | ||||
| 	return fk.attachFunc(name, uid, container, in, out, err, tty) | ||||
| } | ||||
|  | ||||
| func (fk *fakeKubelet) PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error { | ||||
| func (fk *fakeKubelet) PortForward(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { | ||||
| 	return fk.portForwardFunc(name, uid, port, stream) | ||||
| } | ||||
|  | ||||
| @@ -151,7 +152,7 @@ func (fk *fakeKubelet) GetAttach(podFullName string, podUID types.UID, container | ||||
| 	return fk.redirectURL, nil | ||||
| } | ||||
|  | ||||
| func (fk *fakeKubelet) GetPortForward(podName, podNamespace string, podUID types.UID) (*url.URL, error) { | ||||
| func (fk *fakeKubelet) GetPortForward(podName, podNamespace string, podUID types.UID, portForwardOpts portforward.V4Options) (*url.URL, error) { | ||||
| 	return fk.redirectURL, nil | ||||
| } | ||||
|  | ||||
| @@ -1503,7 +1504,7 @@ func TestServePortForward(t *testing.T) { | ||||
|  | ||||
| 		portForwardFuncDone := make(chan struct{}) | ||||
|  | ||||
| 		fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error { | ||||
| 		fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { | ||||
| 			defer close(portForwardFuncDone) | ||||
|  | ||||
| 			if e, a := expectedPodName, name; e != a { | ||||
| @@ -1514,11 +1515,11 @@ func TestServePortForward(t *testing.T) { | ||||
| 				t.Fatalf("%d: uid: expected '%v', got '%v'", i, e, a) | ||||
| 			} | ||||
|  | ||||
| 			p, err := strconv.ParseUint(test.port, 10, 16) | ||||
| 			p, err := strconv.ParseInt(test.port, 10, 32) | ||||
| 			if err != nil { | ||||
| 				t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err) | ||||
| 			} | ||||
| 			if e, a := uint16(p), port; e != a { | ||||
| 			if e, a := int32(p), port; e != a { | ||||
| 				t.Fatalf("%d: port: expected '%v', got '%v'", i, e, a) | ||||
| 			} | ||||
|  | ||||
|   | ||||
							
								
								
									
										331
									
								
								pkg/kubelet/server/server_websocket_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										331
									
								
								pkg/kubelet/server/server_websocket_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,331 @@ | ||||
| /* | ||||
| Copyright 2016 The Kubernetes Authors. | ||||
|  | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
|  | ||||
|     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| */ | ||||
|  | ||||
| package server | ||||
|  | ||||
| import ( | ||||
| 	"encoding/binary" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"strconv" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"golang.org/x/net/websocket" | ||||
|  | ||||
| 	"k8s.io/apimachinery/pkg/types" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	dataChannel = iota | ||||
| 	errorChannel | ||||
| ) | ||||
|  | ||||
| func TestServeWSPortForward(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		port          string | ||||
| 		uid           bool | ||||
| 		clientData    string | ||||
| 		containerData string | ||||
| 		shouldError   bool | ||||
| 	}{ | ||||
| 		{port: "", shouldError: true}, | ||||
| 		{port: "abc", shouldError: true}, | ||||
| 		{port: "-1", shouldError: true}, | ||||
| 		{port: "65536", shouldError: true}, | ||||
| 		{port: "0", shouldError: true}, | ||||
| 		{port: "1", shouldError: false}, | ||||
| 		{port: "8000", shouldError: false}, | ||||
| 		{port: "8000", clientData: "client data", containerData: "container data", shouldError: false}, | ||||
| 		{port: "65535", shouldError: false}, | ||||
| 		{port: "65535", uid: true, shouldError: false}, | ||||
| 	} | ||||
|  | ||||
| 	podNamespace := "other" | ||||
| 	podName := "foo" | ||||
| 	expectedPodName := getPodName(podName, podNamespace) | ||||
| 	expectedUid := "9b01b80f-8fb4-11e4-95ab-4200af06647" | ||||
|  | ||||
| 	for i, test := range tests { | ||||
| 		fw := newServerTest() | ||||
| 		defer fw.testHTTPServer.Close() | ||||
|  | ||||
| 		fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration { | ||||
| 			return 0 | ||||
| 		} | ||||
|  | ||||
| 		portForwardFuncDone := make(chan struct{}) | ||||
|  | ||||
| 		fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { | ||||
| 			defer close(portForwardFuncDone) | ||||
|  | ||||
| 			if e, a := expectedPodName, name; e != a { | ||||
| 				t.Fatalf("%d: pod name: expected '%v', got '%v'", i, e, a) | ||||
| 			} | ||||
|  | ||||
| 			if e, a := expectedUid, uid; test.uid && e != string(a) { | ||||
| 				t.Fatalf("%d: uid: expected '%v', got '%v'", i, e, a) | ||||
| 			} | ||||
|  | ||||
| 			p, err := strconv.ParseInt(test.port, 10, 32) | ||||
| 			if err != nil { | ||||
| 				t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err) | ||||
| 			} | ||||
| 			if e, a := int32(p), port; e != a { | ||||
| 				t.Fatalf("%d: port: expected '%v', got '%v'", i, e, a) | ||||
| 			} | ||||
|  | ||||
| 			if test.clientData != "" { | ||||
| 				fromClient := make([]byte, 32) | ||||
| 				n, err := stream.Read(fromClient) | ||||
| 				if err != nil { | ||||
| 					t.Fatalf("%d: error reading client data: %v", i, err) | ||||
| 				} | ||||
| 				if e, a := test.clientData, string(fromClient[0:n]); e != a { | ||||
| 					t.Fatalf("%d: client data: expected to receive '%v', got '%v'", i, e, a) | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			if test.containerData != "" { | ||||
| 				_, err := stream.Write([]byte(test.containerData)) | ||||
| 				if err != nil { | ||||
| 					t.Fatalf("%d: error writing container data: %v", i, err) | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			return nil | ||||
| 		} | ||||
|  | ||||
| 		var url string | ||||
| 		if test.uid { | ||||
| 			url = fmt.Sprintf("ws://%s/portForward/%s/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, expectedUid, test.port) | ||||
| 		} else { | ||||
| 			url = fmt.Sprintf("ws://%s/portForward/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, test.port) | ||||
| 		} | ||||
|  | ||||
| 		ws, err := websocket.Dial(url, "", "http://127.0.0.1/") | ||||
| 		if test.shouldError { | ||||
| 			if err == nil { | ||||
| 				t.Fatalf("%d: websocket dial expected err", i) | ||||
| 			} | ||||
| 			continue | ||||
| 		} else if err != nil { | ||||
| 			t.Fatalf("%d: websocket dial unexpected err: %v", i, err) | ||||
| 		} | ||||
|  | ||||
| 		defer ws.Close() | ||||
|  | ||||
| 		p, err := strconv.ParseUint(test.port, 10, 16) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err) | ||||
| 		} | ||||
| 		p16 := uint16(p) | ||||
|  | ||||
| 		channel, data, err := wsRead(ws) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("%d: read failed: expected no error: got %v", i, err) | ||||
| 		} | ||||
| 		if channel != dataChannel { | ||||
| 			t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, dataChannel) | ||||
| 		} | ||||
| 		if len(data) != binary.Size(p16) { | ||||
| 			t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(p16)) | ||||
| 		} | ||||
| 		if e, a := p16, binary.LittleEndian.Uint16(data); e != a { | ||||
| 			t.Fatalf("%d: wrong data: got %q: expected %s", i, data, test.port) | ||||
| 		} | ||||
|  | ||||
| 		channel, data, err = wsRead(ws) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("%d: read succeeded: expected no error: got %v", i, err) | ||||
| 		} | ||||
| 		if channel != errorChannel { | ||||
| 			t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, errorChannel) | ||||
| 		} | ||||
| 		if len(data) != binary.Size(p16) { | ||||
| 			t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(p16)) | ||||
| 		} | ||||
| 		if e, a := p16, binary.LittleEndian.Uint16(data); e != a { | ||||
| 			t.Fatalf("%d: wrong data: got %q: expected %s", i, data, test.port) | ||||
| 		} | ||||
|  | ||||
| 		if test.clientData != "" { | ||||
| 			println("writing the client data") | ||||
| 			err := wsWrite(ws, dataChannel, []byte(test.clientData)) | ||||
| 			if err != nil { | ||||
| 				t.Fatalf("%d: unexpected error writing client data: %v", i, err) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		if test.containerData != "" { | ||||
| 			channel, data, err = wsRead(ws) | ||||
| 			if err != nil { | ||||
| 				t.Fatalf("%d: unexpected error reading container data: %v", i, err) | ||||
| 			} | ||||
| 			if e, a := test.containerData, string(data); e != a { | ||||
| 				t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		<-portForwardFuncDone | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestServeWSMultiplePortForward(t *testing.T) { | ||||
| 	portsText := []string{"7000,8000", "9000"} | ||||
| 	ports := []uint16{7000, 8000, 9000} | ||||
| 	podNamespace := "other" | ||||
| 	podName := "foo" | ||||
| 	expectedPodName := getPodName(podName, podNamespace) | ||||
|  | ||||
| 	fw := newServerTest() | ||||
| 	defer fw.testHTTPServer.Close() | ||||
|  | ||||
| 	fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration { | ||||
| 		return 0 | ||||
| 	} | ||||
|  | ||||
| 	portForwardWG := sync.WaitGroup{} | ||||
| 	portForwardWG.Add(len(ports)) | ||||
|  | ||||
| 	portsMutex := sync.Mutex{} | ||||
| 	portsForwarded := map[int32]struct{}{} | ||||
|  | ||||
| 	fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { | ||||
| 		defer portForwardWG.Done() | ||||
|  | ||||
| 		if e, a := expectedPodName, name; e != a { | ||||
| 			t.Fatalf("%d: pod name: expected '%v', got '%v'", port, e, a) | ||||
| 		} | ||||
|  | ||||
| 		portsMutex.Lock() | ||||
| 		portsForwarded[port] = struct{}{} | ||||
| 		portsMutex.Unlock() | ||||
|  | ||||
| 		fromClient := make([]byte, 32) | ||||
| 		n, err := stream.Read(fromClient) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("%d: error reading client data: %v", port, err) | ||||
| 		} | ||||
| 		if e, a := fmt.Sprintf("client data on port %d", port), string(fromClient[0:n]); e != a { | ||||
| 			t.Fatalf("%d: client data: expected to receive '%v', got '%v'", port, e, a) | ||||
| 		} | ||||
|  | ||||
| 		_, err = stream.Write([]byte(fmt.Sprintf("container data on port %d", port))) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("%d: error writing container data: %v", port, err) | ||||
| 		} | ||||
|  | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	url := fmt.Sprintf("ws://%s/portForward/%s/%s?", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName) | ||||
| 	for _, port := range portsText { | ||||
| 		url = url + fmt.Sprintf("port=%s&", port) | ||||
| 	} | ||||
|  | ||||
| 	ws, err := websocket.Dial(url, "", "http://127.0.0.1/") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("websocket dial unexpected err: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	defer ws.Close() | ||||
|  | ||||
| 	for i, port := range ports { | ||||
| 		channel, data, err := wsRead(ws) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("%d: read failed: expected no error: got %v", i, err) | ||||
| 		} | ||||
| 		if int(channel) != i*2+dataChannel { | ||||
| 			t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, i*2+dataChannel) | ||||
| 		} | ||||
| 		if len(data) != binary.Size(port) { | ||||
| 			t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(port)) | ||||
| 		} | ||||
| 		if e, a := port, binary.LittleEndian.Uint16(data); e != a { | ||||
| 			t.Fatalf("%d: wrong data: got %q: expected %d", i, data, port) | ||||
| 		} | ||||
|  | ||||
| 		channel, data, err = wsRead(ws) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("%d: read succeeded: expected no error: got %v", i, err) | ||||
| 		} | ||||
| 		if int(channel) != i*2+errorChannel { | ||||
| 			t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, i*2+errorChannel) | ||||
| 		} | ||||
| 		if len(data) != binary.Size(port) { | ||||
| 			t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(port)) | ||||
| 		} | ||||
| 		if e, a := port, binary.LittleEndian.Uint16(data); e != a { | ||||
| 			t.Fatalf("%d: wrong data: got %q: expected %d", i, data, port) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	for i, port := range ports { | ||||
| 		println("writing the client data", port) | ||||
| 		err := wsWrite(ws, byte(i*2+dataChannel), []byte(fmt.Sprintf("client data on port %d", port))) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("%d: unexpected error writing client data: %v", i, err) | ||||
| 		} | ||||
|  | ||||
| 		channel, data, err := wsRead(ws) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("%d: unexpected error reading container data: %v", i, err) | ||||
| 		} | ||||
|  | ||||
| 		if int(channel) != i*2+dataChannel { | ||||
| 			t.Fatalf("%d: wrong channel: got %q: expected %q", port, channel, i*2+dataChannel) | ||||
| 		} | ||||
| 		if e, a := fmt.Sprintf("container data on port %d", port), string(data); e != a { | ||||
| 			t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	portForwardWG.Wait() | ||||
|  | ||||
| 	portsMutex.Lock() | ||||
| 	defer portsMutex.Unlock() | ||||
| 	if len(ports) != len(portsForwarded) { | ||||
| 		t.Fatalf("expected to forward %d ports; got %v", len(ports), portsForwarded) | ||||
| 	} | ||||
| } | ||||
| func wsWrite(conn *websocket.Conn, channel byte, data []byte) error { | ||||
| 	frame := make([]byte, len(data)+1) | ||||
| 	frame[0] = channel | ||||
| 	copy(frame[1:], data) | ||||
| 	err := websocket.Message.Send(conn, frame) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func wsRead(conn *websocket.Conn) (byte, []byte, error) { | ||||
| 	for { | ||||
| 		var data []byte | ||||
| 		err := websocket.Message.Receive(conn, &data) | ||||
| 		if err != nil { | ||||
| 			return 0, nil, err | ||||
| 		} | ||||
|  | ||||
| 		if len(data) == 0 { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		channel := data[0] | ||||
| 		data = data[1:] | ||||
|  | ||||
| 		return channel, data, err | ||||
| 	} | ||||
| } | ||||
| @@ -80,7 +80,12 @@ type Config struct { | ||||
| 	// The streaming protocols the server supports (understands and permits).  See | ||||
| 	// k8s.io/kubernetes/pkg/kubelet/server/remotecommand/constants.go for available protocols. | ||||
| 	// Only used for SPDY streaming. | ||||
| 	SupportedProtocols []string | ||||
| 	SupportedRemoteCommandProtocols []string | ||||
|  | ||||
| 	// The streaming protocols the server supports (understands and permits).  See | ||||
| 	// k8s.io/kubernetes/pkg/kubelet/server/portforward/constants.go for available protocols. | ||||
| 	// Only used for SPDY streaming. | ||||
| 	SupportedPortForwardProtocols []string | ||||
|  | ||||
| 	// The config for serving over TLS. If nil, TLS will not be used. | ||||
| 	TLSConfig *tls.Config | ||||
| @@ -89,9 +94,10 @@ type Config struct { | ||||
| // DefaultConfig provides default values for server Config. The DefaultConfig is partial, so | ||||
| // some fields like Addr must still be provided. | ||||
| var DefaultConfig = Config{ | ||||
| 	StreamIdleTimeout:     4 * time.Hour, | ||||
| 	StreamCreationTimeout: remotecommand.DefaultStreamCreationTimeout, | ||||
| 	SupportedProtocols:    remotecommand.SupportedStreamingProtocols, | ||||
| 	StreamIdleTimeout:               4 * time.Hour, | ||||
| 	StreamCreationTimeout:           remotecommand.DefaultStreamCreationTimeout, | ||||
| 	SupportedRemoteCommandProtocols: remotecommand.SupportedStreamingProtocols, | ||||
| 	SupportedPortForwardProtocols:   portforward.SupportedProtocols, | ||||
| } | ||||
|  | ||||
| // TODO(timstclair): Add auth(n/z) interface & handling. | ||||
| @@ -248,7 +254,7 @@ func (s *server) serveExec(req *restful.Request, resp *restful.Response) { | ||||
| 		streamOpts, | ||||
| 		s.config.StreamIdleTimeout, | ||||
| 		s.config.StreamCreationTimeout, | ||||
| 		s.config.SupportedProtocols) | ||||
| 		s.config.SupportedRemoteCommandProtocols) | ||||
| } | ||||
|  | ||||
| func (s *server) serveAttach(req *restful.Request, resp *restful.Response) { | ||||
| @@ -280,7 +286,7 @@ func (s *server) serveAttach(req *restful.Request, resp *restful.Response) { | ||||
| 		streamOpts, | ||||
| 		s.config.StreamIdleTimeout, | ||||
| 		s.config.StreamCreationTimeout, | ||||
| 		s.config.SupportedProtocols) | ||||
| 		s.config.SupportedRemoteCommandProtocols) | ||||
| } | ||||
|  | ||||
| func (s *server) servePortForward(req *restful.Request, resp *restful.Response) { | ||||
| @@ -296,14 +302,22 @@ func (s *server) servePortForward(req *restful.Request, resp *restful.Response) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	portForwardOptions, err := portforward.NewV4Options(req.Request) | ||||
| 	if err != nil { | ||||
| 		resp.WriteError(http.StatusBadRequest, err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	portforward.ServePortForward( | ||||
| 		resp.ResponseWriter, | ||||
| 		req.Request, | ||||
| 		s.runtime, | ||||
| 		pf.PodSandboxId, | ||||
| 		"", // unused: podUID | ||||
| 		portForwardOptions, | ||||
| 		s.config.StreamIdleTimeout, | ||||
| 		s.config.StreamCreationTimeout) | ||||
| 		s.config.StreamCreationTimeout, | ||||
| 		s.config.SupportedPortForwardProtocols) | ||||
| } | ||||
|  | ||||
| // criAdapter wraps the Runtime functions to conform to the remotecommand interfaces. | ||||
| @@ -324,6 +338,6 @@ func (a *criAdapter) AttachContainer(podName string, podUID types.UID, container | ||||
| 	return a.Attach(container, in, out, err, tty, resize) | ||||
| } | ||||
|  | ||||
| func (a *criAdapter) PortForward(podName string, podUID types.UID, port uint16, stream io.ReadWriteCloser) error { | ||||
| 	return a.Runtime.PortForward(podName, int32(port), stream) | ||||
| func (a *criAdapter) PortForward(podName string, podUID types.UID, port int32, stream io.ReadWriteCloser) error { | ||||
| 	return a.Runtime.PortForward(podName, port, stream) | ||||
| } | ||||
|   | ||||
| @@ -240,7 +240,7 @@ func TestServePortForward(t *testing.T) { | ||||
|  | ||||
| 	exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", reqURL) | ||||
| 	require.NoError(t, err) | ||||
| 	streamConn, _, err := exec.Dial(kubeletportforward.PortForwardProtocolV1Name) | ||||
| 	streamConn, _, err := exec.Dial(kubeletportforward.ProtocolV1Name) | ||||
| 	require.NoError(t, err) | ||||
| 	defer streamConn.Close() | ||||
|  | ||||
|   | ||||
| @@ -43,12 +43,14 @@ go_test( | ||||
|     deps = [ | ||||
|         "//pkg/api:go_default_library", | ||||
|         "//pkg/api/testing:go_default_library", | ||||
|         "//pkg/kubelet/client:go_default_library", | ||||
|         "//vendor:k8s.io/apimachinery/pkg/api/errors", | ||||
|         "//vendor:k8s.io/apimachinery/pkg/api/resource", | ||||
|         "//vendor:k8s.io/apimachinery/pkg/apis/meta/v1", | ||||
|         "//vendor:k8s.io/apimachinery/pkg/fields", | ||||
|         "//vendor:k8s.io/apimachinery/pkg/labels", | ||||
|         "//vendor:k8s.io/apimachinery/pkg/runtime", | ||||
|         "//vendor:k8s.io/apimachinery/pkg/types", | ||||
|         "//vendor:k8s.io/apiserver/pkg/endpoints/request", | ||||
|     ], | ||||
| ) | ||||
|   | ||||
| @@ -165,9 +165,10 @@ func (r *PortForwardREST) New() runtime.Object { | ||||
| 	return &api.Pod{} | ||||
| } | ||||
|  | ||||
| // NewConnectOptions returns nil since portforward doesn't take additional parameters | ||||
| // NewConnectOptions returns the versioned object that represents the | ||||
| // portforward parameters | ||||
| func (r *PortForwardREST) NewConnectOptions() (runtime.Object, bool, string) { | ||||
| 	return nil, false, "" | ||||
| 	return &api.PodPortForwardOptions{}, false, "" | ||||
| } | ||||
|  | ||||
| // ConnectMethods returns the methods supported by portforward | ||||
| @@ -177,7 +178,11 @@ func (r *PortForwardREST) ConnectMethods() []string { | ||||
|  | ||||
| // Connect returns a handler for the pod portforward proxy | ||||
| func (r *PortForwardREST) Connect(ctx genericapirequest.Context, name string, opts runtime.Object, responder rest.Responder) (http.Handler, error) { | ||||
| 	location, transport, err := pod.PortForwardLocation(r.Store, r.KubeletConn, ctx, name) | ||||
| 	portForwardOpts, ok := opts.(*api.PodPortForwardOptions) | ||||
| 	if !ok { | ||||
| 		return nil, fmt.Errorf("invalid options object: %#v", opts) | ||||
| 	} | ||||
| 	location, transport, err := pod.PortForwardLocation(r.Store, r.KubeletConn, ctx, name, portForwardOpts) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|   | ||||
| @@ -383,6 +383,14 @@ func streamParams(params url.Values, opts runtime.Object) error { | ||||
| 		if opts.TTY { | ||||
| 			params.Add(api.ExecTTYParam, "1") | ||||
| 		} | ||||
| 	case *api.PodPortForwardOptions: | ||||
| 		if len(opts.Ports) > 0 { | ||||
| 			ports := make([]string, len(opts.Ports)) | ||||
| 			for i, p := range opts.Ports { | ||||
| 				ports[i] = strconv.FormatInt(int64(p), 10) | ||||
| 			} | ||||
| 			params.Add(api.PortHeader, strings.Join(ports, ",")) | ||||
| 		} | ||||
| 	default: | ||||
| 		return fmt.Errorf("Unknown object for streaming: %v", opts) | ||||
| 	} | ||||
| @@ -477,6 +485,7 @@ func PortForwardLocation( | ||||
| 	connInfo client.ConnectionInfoGetter, | ||||
| 	ctx genericapirequest.Context, | ||||
| 	name string, | ||||
| 	opts *api.PodPortForwardOptions, | ||||
| ) (*url.URL, http.RoundTripper, error) { | ||||
| 	pod, err := getPod(getter, ctx, name) | ||||
| 	if err != nil { | ||||
| @@ -492,10 +501,15 @@ func PortForwardLocation( | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
| 	params := url.Values{} | ||||
| 	if err := streamParams(params, opts); err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
| 	loc := &url.URL{ | ||||
| 		Scheme: nodeInfo.Scheme, | ||||
| 		Host:   net.JoinHostPort(nodeInfo.Hostname, nodeInfo.Port), | ||||
| 		Path:   fmt.Sprintf("/portForward/%s/%s", pod.Namespace, pod.Name), | ||||
| 		Scheme:   nodeInfo.Scheme, | ||||
| 		Host:     net.JoinHostPort(nodeInfo.Hostname, nodeInfo.Port), | ||||
| 		Path:     fmt.Sprintf("/portForward/%s/%s", pod.Namespace, pod.Name), | ||||
| 		RawQuery: params.Encode(), | ||||
| 	} | ||||
| 	return loc, nodeInfo.Transport, nil | ||||
| } | ||||
|   | ||||
| @@ -17,6 +17,7 @@ limitations under the License. | ||||
| package pod | ||||
|  | ||||
| import ( | ||||
| 	"net/url" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
|  | ||||
| @@ -26,9 +27,11 @@ import ( | ||||
| 	"k8s.io/apimachinery/pkg/fields" | ||||
| 	"k8s.io/apimachinery/pkg/labels" | ||||
| 	"k8s.io/apimachinery/pkg/runtime" | ||||
| 	"k8s.io/apimachinery/pkg/types" | ||||
| 	genericapirequest "k8s.io/apiserver/pkg/endpoints/request" | ||||
| 	"k8s.io/kubernetes/pkg/api" | ||||
| 	apitesting "k8s.io/kubernetes/pkg/api/testing" | ||||
| 	"k8s.io/kubernetes/pkg/kubelet/client" | ||||
| ) | ||||
|  | ||||
| func TestMatchPod(t *testing.T) { | ||||
| @@ -333,3 +336,69 @@ func TestSelectableFieldLabelConversions(t *testing.T) { | ||||
| 		nil, | ||||
| 	) | ||||
| } | ||||
|  | ||||
| type mockConnectionInfoGetter struct { | ||||
| 	info *client.ConnectionInfo | ||||
| } | ||||
|  | ||||
| func (g mockConnectionInfoGetter) GetConnectionInfo(nodeName types.NodeName) (*client.ConnectionInfo, error) { | ||||
| 	return g.info, nil | ||||
| } | ||||
|  | ||||
| func TestPortForwardLocation(t *testing.T) { | ||||
| 	ctx := genericapirequest.NewDefaultContext() | ||||
| 	tcs := []struct { | ||||
| 		in          *api.Pod | ||||
| 		info        *client.ConnectionInfo | ||||
| 		opts        *api.PodPortForwardOptions | ||||
| 		expectedErr error | ||||
| 		expectedURL *url.URL | ||||
| 	}{ | ||||
| 		{ | ||||
| 			in: &api.Pod{ | ||||
| 				Spec: api.PodSpec{}, | ||||
| 			}, | ||||
| 			opts:        &api.PodPortForwardOptions{}, | ||||
| 			expectedErr: errors.NewBadRequest("pod test does not have a host assigned"), | ||||
| 		}, | ||||
| 		{ | ||||
| 			in: &api.Pod{ | ||||
| 				ObjectMeta: metav1.ObjectMeta{ | ||||
| 					Namespace: "ns", | ||||
| 					Name:      "pod1", | ||||
| 				}, | ||||
| 				Spec: api.PodSpec{ | ||||
| 					NodeName: "node1", | ||||
| 				}, | ||||
| 			}, | ||||
| 			info:        &client.ConnectionInfo{}, | ||||
| 			opts:        &api.PodPortForwardOptions{}, | ||||
| 			expectedURL: &url.URL{Host: ":", Path: "/portForward/ns/pod1"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			in: &api.Pod{ | ||||
| 				ObjectMeta: metav1.ObjectMeta{ | ||||
| 					Namespace: "ns", | ||||
| 					Name:      "pod1", | ||||
| 				}, | ||||
| 				Spec: api.PodSpec{ | ||||
| 					NodeName: "node1", | ||||
| 				}, | ||||
| 			}, | ||||
| 			info:        &client.ConnectionInfo{}, | ||||
| 			opts:        &api.PodPortForwardOptions{Ports: []int32{80}}, | ||||
| 			expectedURL: &url.URL{Host: ":", Path: "/portForward/ns/pod1", RawQuery: "port=80"}, | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, tc := range tcs { | ||||
| 		getter := &mockPodGetter{tc.in} | ||||
| 		connectionGetter := &mockConnectionInfoGetter{tc.info} | ||||
| 		loc, _, err := PortForwardLocation(getter, connectionGetter, ctx, "test", tc.opts) | ||||
| 		if !reflect.DeepEqual(err, tc.expectedErr) { | ||||
| 			t.Errorf("expected %v, got %v", tc.expectedErr, err) | ||||
| 		} | ||||
| 		if !reflect.DeepEqual(loc, tc.expectedURL) { | ||||
| 			t.Errorf("expected %v, got %v", tc.expectedURL, loc) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -169,6 +169,7 @@ go_library( | ||||
|         "//vendor:github.com/onsi/gomega", | ||||
|         "//vendor:github.com/stretchr/testify/assert", | ||||
|         "//vendor:golang.org/x/crypto/ssh", | ||||
|         "//vendor:golang.org/x/net/websocket", | ||||
|         "//vendor:google.golang.org/api/compute/v1", | ||||
|         "//vendor:google.golang.org/api/googleapi", | ||||
|         "//vendor:gopkg.in/inf.v0", | ||||
|   | ||||
| @@ -17,6 +17,8 @@ limitations under the License. | ||||
| package e2e | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/binary" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"io/ioutil" | ||||
| @@ -28,6 +30,7 @@ import ( | ||||
| 	"syscall" | ||||
| 	"time" | ||||
|  | ||||
| 	"golang.org/x/net/websocket" | ||||
| 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" | ||||
| 	"k8s.io/apimachinery/pkg/util/wait" | ||||
| 	"k8s.io/kubernetes/pkg/api/v1" | ||||
| @@ -36,6 +39,7 @@ import ( | ||||
| 	testutils "k8s.io/kubernetes/test/utils" | ||||
|  | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| @@ -368,36 +372,144 @@ func doTestMustConnectSendDisconnect(bindAddress string, f *framework.Framework) | ||||
| 	verifyLogMessage(logOutput, "^Done$") | ||||
| } | ||||
|  | ||||
| func doTestOverWebSockets(bindAddress string, f *framework.Framework) { | ||||
| 	config, err := framework.LoadConfig() | ||||
| 	Expect(err).NotTo(HaveOccurred(), "unable to get base config") | ||||
|  | ||||
| 	By("creating the pod") | ||||
| 	pod := pfPod("def", "10", "10", "100", fmt.Sprintf("%s", bindAddress)) | ||||
| 	if _, err := f.ClientSet.Core().Pods(f.Namespace.Name).Create(pod); err != nil { | ||||
| 		framework.Failf("Couldn't create pod: %v", err) | ||||
| 	} | ||||
| 	if err := f.WaitForPodReady(pod.Name); err != nil { | ||||
| 		framework.Failf("Pod did not start running: %v", err) | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		logs, err := framework.GetPodLogs(f.ClientSet, f.Namespace.Name, pod.Name, "portforwardtester") | ||||
| 		if err != nil { | ||||
| 			framework.Logf("Error getting pod log: %v", err) | ||||
| 		} else { | ||||
| 			framework.Logf("Pod log:\n%s", logs) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	req := f.ClientSet.Core().RESTClient().Get(). | ||||
| 		Namespace(f.Namespace.Name). | ||||
| 		Resource("pods"). | ||||
| 		Name(pod.Name). | ||||
| 		Suffix("portforward"). | ||||
| 		Param("ports", "80") | ||||
|  | ||||
| 	url := req.URL() | ||||
| 	ws, err := framework.OpenWebSocketForURL(url, config, []string{"v4.channel.k8s.io"}) | ||||
| 	if err != nil { | ||||
| 		framework.Failf("Failed to open websocket to %s: %v", url.String(), err) | ||||
| 	} | ||||
| 	defer ws.Close() | ||||
|  | ||||
| 	Eventually(func() error { | ||||
| 		channel, msg, err := wsRead(ws) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("Failed to read completely from websocket %s: %v", url.String(), err) | ||||
| 		} | ||||
| 		if channel != 0 { | ||||
| 			return fmt.Errorf("Got message from server that didn't start with channel 0 (data): %v", msg) | ||||
| 		} | ||||
| 		if p := binary.LittleEndian.Uint16(msg); p != 80 { | ||||
| 			return fmt.Errorf("Received the wrong port: %d", p) | ||||
| 		} | ||||
| 		return nil | ||||
| 	}, time.Minute, 10*time.Second).Should(BeNil()) | ||||
|  | ||||
| 	Eventually(func() error { | ||||
| 		channel, msg, err := wsRead(ws) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("Failed to read completely from websocket %s: %v", url.String(), err) | ||||
| 		} | ||||
| 		if channel != 1 { | ||||
| 			return fmt.Errorf("Got message from server that didn't start with channel 1 (error): %v", msg) | ||||
| 		} | ||||
| 		if p := binary.LittleEndian.Uint16(msg); p != 80 { | ||||
| 			return fmt.Errorf("Received the wrong port: %d", p) | ||||
| 		} | ||||
| 		return nil | ||||
| 	}, time.Minute, 10*time.Second).Should(BeNil()) | ||||
|  | ||||
| 	By("sending the expected data to the local port") | ||||
| 	err = wsWrite(ws, 0, []byte("def")) | ||||
| 	if err != nil { | ||||
| 		framework.Failf("Failed to write to websocket %s: %v", url.String(), err) | ||||
| 	} | ||||
|  | ||||
| 	By("reading data from the local port") | ||||
| 	buf := bytes.Buffer{} | ||||
| 	expectedData := bytes.Repeat([]byte("x"), 100) | ||||
| 	Eventually(func() error { | ||||
| 		channel, msg, err := wsRead(ws) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("Failed to read completely from websocket %s: %v", url.String(), err) | ||||
| 		} | ||||
| 		if channel != 0 { | ||||
| 			return fmt.Errorf("Got message from server that didn't start with channel 0 (data): %v", msg) | ||||
| 		} | ||||
| 		buf.Write(msg) | ||||
| 		if bytes.Equal(expectedData, buf.Bytes()) { | ||||
| 			return fmt.Errorf("Expected %q from server, got %q", expectedData, buf.Bytes()) | ||||
| 		} | ||||
| 		return nil | ||||
| 	}, time.Minute, 10*time.Second).Should(BeNil()) | ||||
|  | ||||
| 	By("verifying logs") | ||||
| 	logOutput, err := framework.GetPodLogs(f.ClientSet, f.Namespace.Name, pod.Name, "portforwardtester") | ||||
| 	if err != nil { | ||||
| 		framework.Failf("Error retrieving pod logs: %v", err) | ||||
| 	} | ||||
| 	verifyLogMessage(logOutput, "^Accepted client connection$") | ||||
| 	verifyLogMessage(logOutput, "^Received expected client data$") | ||||
| } | ||||
|  | ||||
| var _ = framework.KubeDescribe("Port forwarding", func() { | ||||
| 	f := framework.NewDefaultFramework("port-forwarding") | ||||
|  | ||||
| 	framework.KubeDescribe("With a server  listening on 0.0.0.0 that expects a client request", func() { | ||||
| 		It("should support a client that connects, sends no data, and disconnects", func() { | ||||
| 			doTestMustConnectSendNothing("0.0.0.0", f) | ||||
| 	framework.KubeDescribe("With a server listening on 0.0.0.0", func() { | ||||
| 		framework.KubeDescribe("that expects a client request", func() { | ||||
| 			It("should support a client that connects, sends no data, and disconnects", func() { | ||||
| 				doTestMustConnectSendNothing("0.0.0.0", f) | ||||
| 			}) | ||||
| 			It("should support a client that connects, sends data, and disconnects", func() { | ||||
| 				doTestMustConnectSendDisconnect("0.0.0.0", f) | ||||
| 			}) | ||||
| 		}) | ||||
| 		It("should support a client that connects, sends data, and disconnects", func() { | ||||
| 			doTestMustConnectSendDisconnect("0.0.0.0", f) | ||||
|  | ||||
| 		framework.KubeDescribe("that expects no client request", func() { | ||||
| 			It("should support a client that connects, sends data, and disconnects", func() { | ||||
| 				doTestConnectSendDisconnect("0.0.0.0", f) | ||||
| 			}) | ||||
| 		}) | ||||
|  | ||||
| 		It("should support forwarding over websockets", func() { | ||||
| 			doTestOverWebSockets("0.0.0.0", f) | ||||
| 		}) | ||||
| 	}) | ||||
|  | ||||
| 	framework.KubeDescribe("With a server  listening on 0.0.0.0 that expects no client request", func() { | ||||
| 		It("should support a client that connects, sends data, and disconnects", func() { | ||||
| 			doTestConnectSendDisconnect("0.0.0.0", f) | ||||
| 	framework.KubeDescribe("With a server listening on localhost", func() { | ||||
| 		framework.KubeDescribe("that expects a client request", func() { | ||||
| 			It("should support a client that connects, sends no data, and disconnects [Conformance]", func() { | ||||
| 				doTestMustConnectSendNothing("localhost", f) | ||||
| 			}) | ||||
| 			It("should support a client that connects, sends data, and disconnects [Conformance]", func() { | ||||
| 				doTestMustConnectSendDisconnect("localhost", f) | ||||
| 			}) | ||||
| 		}) | ||||
| 	}) | ||||
|  | ||||
| 	framework.KubeDescribe("With a server  listening on localhost that expects a client request", func() { | ||||
| 		It("should support a client that connects, sends no data, and disconnects [Conformance]", func() { | ||||
| 			doTestMustConnectSendNothing("localhost", f) | ||||
| 		framework.KubeDescribe("that expects no client request", func() { | ||||
| 			It("should support a client that connects, sends data, and disconnects [Conformance]", func() { | ||||
| 				doTestConnectSendDisconnect("localhost", f) | ||||
| 			}) | ||||
| 		}) | ||||
| 		It("should support a client that connects, sends data, and disconnects [Conformance]", func() { | ||||
| 			doTestMustConnectSendDisconnect("localhost", f) | ||||
| 		}) | ||||
| 	}) | ||||
|  | ||||
| 	framework.KubeDescribe("With a server  listening on localhost that expects no client request", func() { | ||||
| 		It("should support a client that connects, sends data, and disconnects [Conformance]", func() { | ||||
| 			doTestConnectSendDisconnect("localhost", f) | ||||
| 		It("should support forwarding over websockets", func() { | ||||
| 			doTestOverWebSockets("localhost", f) | ||||
| 		}) | ||||
| 	}) | ||||
| }) | ||||
| @@ -412,3 +524,30 @@ func verifyLogMessage(log, expected string) { | ||||
| 	} | ||||
| 	framework.Failf("Missing %q from log: %s", expected, log) | ||||
| } | ||||
|  | ||||
| func wsRead(conn *websocket.Conn) (byte, []byte, error) { | ||||
| 	for { | ||||
| 		var data []byte | ||||
| 		err := websocket.Message.Receive(conn, &data) | ||||
| 		if err != nil { | ||||
| 			return 0, nil, err | ||||
| 		} | ||||
|  | ||||
| 		if len(data) == 0 { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		channel := data[0] | ||||
| 		data = data[1:] | ||||
|  | ||||
| 		return channel, data, err | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func wsWrite(conn *websocket.Conn, channel byte, data []byte) error { | ||||
| 	frame := make([]byte, len(data)+1) | ||||
| 	frame[0] = channel | ||||
| 	copy(frame[1:], data) | ||||
| 	err := websocket.Message.Send(conn, frame) | ||||
| 	return err | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Kubernetes Submit Queue
					Kubernetes Submit Queue