mirror of
				https://github.com/optim-enterprises-bv/kubernetes.git
				synced 2025-10-31 18:28:13 +00:00 
			
		
		
		
	Don't include user data in CRI streaming redirect URLs
This commit is contained in:
		| @@ -606,7 +606,7 @@ func (s *Server) getAttach(request *restful.Request, response *restful.Response) | ||||
| 	podFullName := kubecontainer.GetPodFullName(pod) | ||||
| 	redirect, err := s.host.GetAttach(podFullName, params.podUID, params.containerName, *streamOpts) | ||||
| 	if err != nil { | ||||
| 		response.WriteError(streaming.HTTPStatus(err), err) | ||||
| 		streaming.WriteError(err, response.ResponseWriter) | ||||
| 		return | ||||
| 	} | ||||
| 	if redirect != nil { | ||||
| @@ -644,7 +644,7 @@ func (s *Server) getExec(request *restful.Request, response *restful.Response) { | ||||
| 	podFullName := kubecontainer.GetPodFullName(pod) | ||||
| 	redirect, err := s.host.GetExec(podFullName, params.podUID, params.containerName, params.cmd, *streamOpts) | ||||
| 	if err != nil { | ||||
| 		response.WriteError(streaming.HTTPStatus(err), err) | ||||
| 		streaming.WriteError(err, response.ResponseWriter) | ||||
| 		return | ||||
| 	} | ||||
| 	if redirect != nil { | ||||
| @@ -714,7 +714,7 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp | ||||
|  | ||||
| 	redirect, err := s.host.GetPortForward(pod.Name, pod.Namespace, pod.UID) | ||||
| 	if err != nil { | ||||
| 		response.WriteError(streaming.HTTPStatus(err), err) | ||||
| 		streaming.WriteError(err, response.ResponseWriter) | ||||
| 		return | ||||
| 	} | ||||
| 	if redirect != nil { | ||||
|   | ||||
| @@ -12,14 +12,15 @@ go_library( | ||||
|     name = "go_default_library", | ||||
|     srcs = [ | ||||
|         "errors.go", | ||||
|         "request_cache.go", | ||||
|         "server.go", | ||||
|     ], | ||||
|     tags = ["automanaged"], | ||||
|     deps = [ | ||||
|         "//pkg/api:go_default_library", | ||||
|         "//pkg/kubelet/api/v1alpha1/runtime:go_default_library", | ||||
|         "//pkg/kubelet/server/portforward:go_default_library", | ||||
|         "//pkg/kubelet/server/remotecommand:go_default_library", | ||||
|         "//pkg/util/clock:go_default_library", | ||||
|         "//pkg/util/term:go_default_library", | ||||
|         "//vendor:github.com/emicklei/go-restful", | ||||
|         "//vendor:google.golang.org/grpc", | ||||
| @@ -30,7 +31,10 @@ go_library( | ||||
|  | ||||
| go_test( | ||||
|     name = "go_default_test", | ||||
|     srcs = ["server_test.go"], | ||||
|     srcs = [ | ||||
|         "request_cache_test.go", | ||||
|         "server_test.go", | ||||
|     ], | ||||
|     library = ":go_default_library", | ||||
|     tags = ["automanaged"], | ||||
|     deps = [ | ||||
| @@ -43,6 +47,7 @@ go_test( | ||||
|         "//vendor:github.com/stretchr/testify/assert", | ||||
|         "//vendor:github.com/stretchr/testify/require", | ||||
|         "//vendor:k8s.io/client-go/pkg/api", | ||||
|         "//vendor:k8s.io/client-go/pkg/util/clock", | ||||
|     ], | ||||
| ) | ||||
|  | ||||
|   | ||||
| @@ -19,6 +19,7 @@ package streaming | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"google.golang.org/grpc" | ||||
| @@ -33,12 +34,27 @@ func ErrorTimeout(op string, timeout time.Duration) error { | ||||
| 	return grpc.Errorf(codes.DeadlineExceeded, fmt.Sprintf("%s timed out after %s", op, timeout.String())) | ||||
| } | ||||
|  | ||||
| // Translates a CRI streaming error into an HTTP status code. | ||||
| func HTTPStatus(err error) int { | ||||
| // The error returned when the maximum number of in-flight requests is exceeded. | ||||
| func ErrorTooManyInFlight() error { | ||||
| 	return grpc.Errorf(codes.ResourceExhausted, "maximum number of in-flight requests exceeded") | ||||
| } | ||||
|  | ||||
| // Translates a CRI streaming error into an appropriate HTTP response. | ||||
| func WriteError(err error, w http.ResponseWriter) error { | ||||
| 	var status int | ||||
| 	switch grpc.Code(err) { | ||||
| 	case codes.NotFound: | ||||
| 		return http.StatusNotFound | ||||
| 		status = http.StatusNotFound | ||||
| 	case codes.ResourceExhausted: | ||||
| 		// We only expect to hit this if there is a DoS, so we just wait the full TTL. | ||||
| 		// If this is ever hit in steady-state operations, consider increasing the MaxInFlight requests, | ||||
| 		// or plumbing through the time to next expiration. | ||||
| 		w.Header().Set("Retry-After", strconv.Itoa(int(CacheTTL.Seconds()))) | ||||
| 		status = http.StatusTooManyRequests | ||||
| 	default: | ||||
| 		return http.StatusInternalServerError | ||||
| 		status = http.StatusInternalServerError | ||||
| 	} | ||||
| 	w.WriteHeader(status) | ||||
| 	_, writeErr := w.Write([]byte(err.Error())) | ||||
| 	return writeErr | ||||
| } | ||||
|   | ||||
							
								
								
									
										146
									
								
								pkg/kubelet/server/streaming/request_cache.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										146
									
								
								pkg/kubelet/server/streaming/request_cache.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,146 @@ | ||||
| /* | ||||
| Copyright 2016 The Kubernetes Authors. | ||||
|  | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
|  | ||||
|     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| */ | ||||
|  | ||||
| package streaming | ||||
|  | ||||
| import ( | ||||
| 	"container/list" | ||||
| 	"crypto/rand" | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"math" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"k8s.io/kubernetes/pkg/util/clock" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	// Timeout after which tokens become invalid. | ||||
| 	CacheTTL = 1 * time.Minute | ||||
| 	// The maximum number of in-flight requests to allow. | ||||
| 	MaxInFlight = 1000 | ||||
| 	// Length of the random base64 encoded token identifying the request. | ||||
| 	TokenLen = 8 | ||||
| ) | ||||
|  | ||||
| // requestCache caches streaming (exec/attach/port-forward) requests and generates a single-use | ||||
| // random token for their retrieval. The requestCache is used for building streaming URLs without | ||||
| // the need to encode every request parameter in the URL. | ||||
| type requestCache struct { | ||||
| 	// clock is used to obtain the current time | ||||
| 	clock clock.Clock | ||||
|  | ||||
| 	// tokens maps the generate token to the request for fast retrieval. | ||||
| 	tokens map[string]*list.Element | ||||
| 	// ll maintains an age-ordered request list for faster garbage collection of expired requests. | ||||
| 	ll *list.List | ||||
|  | ||||
| 	lock sync.Mutex | ||||
| } | ||||
|  | ||||
| // Type representing an *ExecRequest, *AttachRequest, or *PortForwardRequest. | ||||
| type request interface{} | ||||
|  | ||||
| type cacheEntry struct { | ||||
| 	token      string | ||||
| 	req        request | ||||
| 	expireTime time.Time | ||||
| } | ||||
|  | ||||
| func newRequestCache() *requestCache { | ||||
| 	return &requestCache{ | ||||
| 		clock:  clock.RealClock{}, | ||||
| 		ll:     list.New(), | ||||
| 		tokens: make(map[string]*list.Element), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Insert the given request into the cache and returns the token used for fetching it out. | ||||
| func (c *requestCache) Insert(req request) (token string, err error) { | ||||
| 	c.lock.Lock() | ||||
| 	defer c.lock.Unlock() | ||||
|  | ||||
| 	// Remove expired entries. | ||||
| 	c.gc() | ||||
| 	// If the cache is full, reject the request. | ||||
| 	if c.ll.Len() == MaxInFlight { | ||||
| 		return "", ErrorTooManyInFlight() | ||||
| 	} | ||||
| 	token, err = c.uniqueToken() | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	ele := c.ll.PushFront(&cacheEntry{token, req, c.clock.Now().Add(CacheTTL)}) | ||||
|  | ||||
| 	c.tokens[token] = ele | ||||
| 	return token, nil | ||||
| } | ||||
|  | ||||
| // Consume the token (remove it from the cache) and return the cached request, if found. | ||||
| func (c *requestCache) Consume(token string) (req request, found bool) { | ||||
| 	c.lock.Lock() | ||||
| 	defer c.lock.Unlock() | ||||
| 	ele, ok := c.tokens[token] | ||||
| 	if !ok { | ||||
| 		return nil, false | ||||
| 	} | ||||
| 	c.ll.Remove(ele) | ||||
| 	delete(c.tokens, token) | ||||
|  | ||||
| 	entry := ele.Value.(*cacheEntry) | ||||
| 	if c.clock.Now().After(entry.expireTime) { | ||||
| 		// Entry already expired. | ||||
| 		return nil, false | ||||
| 	} | ||||
| 	return entry.req, true | ||||
| } | ||||
|  | ||||
| // uniqueToken generates a random URL-safe token and ensures uniqueness. | ||||
| func (c *requestCache) uniqueToken() (string, error) { | ||||
| 	const maxTries = 10 | ||||
| 	// Number of bytes to be TokenLen when base64 encoded. | ||||
| 	tokenSize := math.Ceil(float64(TokenLen) * 6 / 8) | ||||
| 	rawToken := make([]byte, int(tokenSize)) | ||||
| 	for i := 0; i < maxTries; i++ { | ||||
| 		if _, err := rand.Read(rawToken); err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
| 		encoded := base64.RawURLEncoding.EncodeToString(rawToken) | ||||
| 		token := encoded[:TokenLen] | ||||
| 		// If it's unique, return it. Otherwise retry. | ||||
| 		if _, exists := c.tokens[encoded]; !exists { | ||||
| 			return token, nil | ||||
| 		} | ||||
| 	} | ||||
| 	return "", fmt.Errorf("failed to generate unique token") | ||||
| } | ||||
|  | ||||
| // Must be write-locked prior to calling. | ||||
| func (c *requestCache) gc() { | ||||
| 	now := c.clock.Now() | ||||
| 	for c.ll.Len() > 0 { | ||||
| 		oldest := c.ll.Back() | ||||
| 		entry := oldest.Value.(*cacheEntry) | ||||
| 		if !now.After(entry.expireTime) { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// Oldest value is expired; remove it. | ||||
| 		c.ll.Remove(oldest) | ||||
| 		delete(c.tokens, entry.token) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										221
									
								
								pkg/kubelet/server/streaming/request_cache_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										221
									
								
								pkg/kubelet/server/streaming/request_cache_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,221 @@ | ||||
| /* | ||||
| Copyright 2016 The Kubernetes Authors. | ||||
|  | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
|  | ||||
|     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| */ | ||||
|  | ||||
| package streaming | ||||
|  | ||||
| import ( | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"strconv" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"github.com/stretchr/testify/require" | ||||
|  | ||||
| 	"k8s.io/client-go/pkg/util/clock" | ||||
| ) | ||||
|  | ||||
| func TestInsert(t *testing.T) { | ||||
| 	c, _ := newTestCache() | ||||
|  | ||||
| 	// Insert normal | ||||
| 	oldestTok, err := c.Insert(nextRequest()) | ||||
| 	require.NoError(t, err) | ||||
| 	assert.Len(t, oldestTok, TokenLen) | ||||
| 	assertCacheSize(t, c, 1) | ||||
|  | ||||
| 	// Insert until full | ||||
| 	for i := 0; i < MaxInFlight-2; i++ { | ||||
| 		tok, err := c.Insert(nextRequest()) | ||||
| 		require.NoError(t, err) | ||||
| 		assert.Len(t, tok, TokenLen) | ||||
| 	} | ||||
| 	assertCacheSize(t, c, MaxInFlight-1) | ||||
|  | ||||
| 	newestReq := nextRequest() | ||||
| 	newestTok, err := c.Insert(newestReq) | ||||
| 	require.NoError(t, err) | ||||
| 	assert.Len(t, newestTok, TokenLen) | ||||
| 	assertCacheSize(t, c, MaxInFlight) | ||||
| 	require.Contains(t, c.tokens, oldestTok, "oldest request should still be cached") | ||||
|  | ||||
| 	// Consume newest token. | ||||
| 	req, ok := c.Consume(newestTok) | ||||
| 	assert.True(t, ok, "newest request should still be cached") | ||||
| 	assert.Equal(t, newestReq, req) | ||||
| 	require.Contains(t, c.tokens, oldestTok, "oldest request should still be cached") | ||||
|  | ||||
| 	// Insert again (still full) | ||||
| 	tok, err := c.Insert(nextRequest()) | ||||
| 	require.NoError(t, err) | ||||
| 	assert.Len(t, tok, TokenLen) | ||||
| 	assertCacheSize(t, c, MaxInFlight) | ||||
|  | ||||
| 	// Insert again (should evict) | ||||
| 	_, err = c.Insert(nextRequest()) | ||||
| 	assert.Error(t, err, "should reject further requests") | ||||
| 	errResponse := httptest.NewRecorder() | ||||
| 	require.NoError(t, WriteError(err, errResponse)) | ||||
| 	assert.Equal(t, errResponse.Code, http.StatusTooManyRequests) | ||||
| 	assert.Equal(t, strconv.Itoa(int(CacheTTL.Seconds())), errResponse.HeaderMap.Get("Retry-After")) | ||||
|  | ||||
| 	assertCacheSize(t, c, MaxInFlight) | ||||
| 	_, ok = c.Consume(oldestTok) | ||||
| 	assert.True(t, ok, "oldest request should be valid") | ||||
| } | ||||
|  | ||||
| func TestConsume(t *testing.T) { | ||||
| 	c, clock := newTestCache() | ||||
|  | ||||
| 	{ // Insert & consume. | ||||
| 		req := nextRequest() | ||||
| 		tok, err := c.Insert(req) | ||||
| 		require.NoError(t, err) | ||||
| 		assertCacheSize(t, c, 1) | ||||
|  | ||||
| 		cachedReq, ok := c.Consume(tok) | ||||
| 		assert.True(t, ok) | ||||
| 		assert.Equal(t, req, cachedReq) | ||||
| 		assertCacheSize(t, c, 0) | ||||
| 	} | ||||
|  | ||||
| 	{ // Insert & consume out of order | ||||
| 		req1 := nextRequest() | ||||
| 		tok1, err := c.Insert(req1) | ||||
| 		require.NoError(t, err) | ||||
| 		assertCacheSize(t, c, 1) | ||||
|  | ||||
| 		req2 := nextRequest() | ||||
| 		tok2, err := c.Insert(req2) | ||||
| 		require.NoError(t, err) | ||||
| 		assertCacheSize(t, c, 2) | ||||
|  | ||||
| 		cachedReq2, ok := c.Consume(tok2) | ||||
| 		assert.True(t, ok) | ||||
| 		assert.Equal(t, req2, cachedReq2) | ||||
| 		assertCacheSize(t, c, 1) | ||||
|  | ||||
| 		cachedReq1, ok := c.Consume(tok1) | ||||
| 		assert.True(t, ok) | ||||
| 		assert.Equal(t, req1, cachedReq1) | ||||
| 		assertCacheSize(t, c, 0) | ||||
| 	} | ||||
|  | ||||
| 	{ // Consume a second time | ||||
| 		req := nextRequest() | ||||
| 		tok, err := c.Insert(req) | ||||
| 		require.NoError(t, err) | ||||
| 		assertCacheSize(t, c, 1) | ||||
|  | ||||
| 		cachedReq, ok := c.Consume(tok) | ||||
| 		assert.True(t, ok) | ||||
| 		assert.Equal(t, req, cachedReq) | ||||
| 		assertCacheSize(t, c, 0) | ||||
|  | ||||
| 		_, ok = c.Consume(tok) | ||||
| 		assert.False(t, ok) | ||||
| 		assertCacheSize(t, c, 0) | ||||
| 	} | ||||
|  | ||||
| 	{ // Consume without insert | ||||
| 		_, ok := c.Consume("fooBAR") | ||||
| 		assert.False(t, ok) | ||||
| 		assertCacheSize(t, c, 0) | ||||
| 	} | ||||
|  | ||||
| 	{ // Consume expired | ||||
| 		tok, err := c.Insert(nextRequest()) | ||||
| 		require.NoError(t, err) | ||||
| 		assertCacheSize(t, c, 1) | ||||
|  | ||||
| 		clock.Step(2 * CacheTTL) | ||||
|  | ||||
| 		_, ok := c.Consume(tok) | ||||
| 		assert.False(t, ok) | ||||
| 		assertCacheSize(t, c, 0) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestGC(t *testing.T) { | ||||
| 	c, clock := newTestCache() | ||||
|  | ||||
| 	// When empty | ||||
| 	c.gc() | ||||
| 	assertCacheSize(t, c, 0) | ||||
|  | ||||
| 	tok1, err := c.Insert(nextRequest()) | ||||
| 	require.NoError(t, err) | ||||
| 	assertCacheSize(t, c, 1) | ||||
| 	clock.Step(10 * time.Second) | ||||
| 	tok2, err := c.Insert(nextRequest()) | ||||
| 	require.NoError(t, err) | ||||
| 	assertCacheSize(t, c, 2) | ||||
|  | ||||
| 	// expired: tok1, tok2 | ||||
| 	// non-expired: tok3, tok4 | ||||
| 	clock.Step(2 * CacheTTL) | ||||
| 	tok3, err := c.Insert(nextRequest()) | ||||
| 	require.NoError(t, err) | ||||
| 	assertCacheSize(t, c, 1) | ||||
| 	clock.Step(10 * time.Second) | ||||
| 	tok4, err := c.Insert(nextRequest()) | ||||
| 	require.NoError(t, err) | ||||
| 	assertCacheSize(t, c, 2) | ||||
|  | ||||
| 	_, ok := c.Consume(tok1) | ||||
| 	assert.False(t, ok) | ||||
| 	_, ok = c.Consume(tok2) | ||||
| 	assert.False(t, ok) | ||||
| 	_, ok = c.Consume(tok3) | ||||
| 	assert.True(t, ok) | ||||
| 	_, ok = c.Consume(tok4) | ||||
| 	assert.True(t, ok) | ||||
|  | ||||
| 	// When full, nothing is expired. | ||||
| 	for i := 0; i < MaxInFlight; i++ { | ||||
| 		_, err := c.Insert(nextRequest()) | ||||
| 		require.NoError(t, err) | ||||
| 	} | ||||
| 	assertCacheSize(t, c, MaxInFlight) | ||||
|  | ||||
| 	// When everything is expired | ||||
| 	clock.Step(2 * CacheTTL) | ||||
| 	_, err = c.Insert(nextRequest()) | ||||
| 	require.NoError(t, err) | ||||
| 	assertCacheSize(t, c, 1) | ||||
| } | ||||
|  | ||||
| func newTestCache() (*requestCache, *clock.FakeClock) { | ||||
| 	c := newRequestCache() | ||||
| 	fakeClock := clock.NewFakeClock(time.Now()) | ||||
| 	c.clock = fakeClock | ||||
| 	return c, fakeClock | ||||
| } | ||||
|  | ||||
| func assertCacheSize(t *testing.T, cache *requestCache, expectedSize int) { | ||||
| 	tokenLen := len(cache.tokens) | ||||
| 	llLen := cache.ll.Len() | ||||
| 	assert.Equal(t, tokenLen, llLen, "inconsistent cache size! len(tokens)=%d; len(ll)=%d", tokenLen, llLen) | ||||
| 	assert.Equal(t, expectedSize, tokenLen, "unexpected cache size!") | ||||
| } | ||||
|  | ||||
| var requestUID = 0 | ||||
|  | ||||
| func nextRequest() interface{} { | ||||
| 	requestUID++ | ||||
| 	return requestUID | ||||
| } | ||||
| @@ -25,10 +25,12 @@ import ( | ||||
| 	"path" | ||||
| 	"time" | ||||
|  | ||||
| 	"google.golang.org/grpc" | ||||
| 	"google.golang.org/grpc/codes" | ||||
|  | ||||
| 	restful "github.com/emicklei/go-restful" | ||||
|  | ||||
| 	"k8s.io/apimachinery/pkg/types" | ||||
| 	"k8s.io/kubernetes/pkg/api" | ||||
| 	runtimeapi "k8s.io/kubernetes/pkg/kubelet/api/v1alpha1/runtime" | ||||
| 	"k8s.io/kubernetes/pkg/kubelet/server/portforward" | ||||
| 	"k8s.io/kubernetes/pkg/kubelet/server/remotecommand" | ||||
| @@ -97,6 +99,7 @@ func NewServer(config Config, runtime Runtime) (Server, error) { | ||||
| 	s := &server{ | ||||
| 		config:  config, | ||||
| 		runtime: &criAdapter{runtime}, | ||||
| 		cache:   newRequestCache(), | ||||
| 	} | ||||
|  | ||||
| 	if s.config.BaseURL == nil { | ||||
| @@ -114,9 +117,9 @@ func NewServer(config Config, runtime Runtime) (Server, error) { | ||||
| 		path    string | ||||
| 		handler restful.RouteFunction | ||||
| 	}{ | ||||
| 		{"/exec/{containerID}", s.serveExec}, | ||||
| 		{"/attach/{containerID}", s.serveAttach}, | ||||
| 		{"/portforward/{podSandboxID}", s.servePortForward}, | ||||
| 		{"/exec/{token}", s.serveExec}, | ||||
| 		{"/attach/{token}", s.serveAttach}, | ||||
| 		{"/portforward/{token}", s.servePortForward}, | ||||
| 	} | ||||
| 	// If serving relative to a base path, set that here. | ||||
| 	pathPrefix := path.Dir(s.config.BaseURL.Path) | ||||
| @@ -139,37 +142,45 @@ type server struct { | ||||
| 	config  Config | ||||
| 	runtime *criAdapter | ||||
| 	handler http.Handler | ||||
| 	cache   *requestCache | ||||
| } | ||||
|  | ||||
| func (s *server) GetExec(req *runtimeapi.ExecRequest) (*runtimeapi.ExecResponse, error) { | ||||
| 	url := s.buildURL("exec", req.GetContainerId(), streamOpts{ | ||||
| 		stdin:   req.GetStdin(), | ||||
| 		stdout:  true, | ||||
| 		stderr:  !req.GetTty(), // For TTY connections, both stderr is combined with stdout. | ||||
| 		tty:     req.GetTty(), | ||||
| 		command: req.GetCmd(), | ||||
| 	}) | ||||
| 	if req.GetContainerId() == "" { | ||||
| 		return nil, grpc.Errorf(codes.InvalidArgument, "missing required container_id") | ||||
| 	} | ||||
| 	token, err := s.cache.Insert(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &runtimeapi.ExecResponse{ | ||||
| 		Url: &url, | ||||
| 		Url: s.buildURL("exec", token), | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (s *server) GetAttach(req *runtimeapi.AttachRequest) (*runtimeapi.AttachResponse, error) { | ||||
| 	url := s.buildURL("attach", req.GetContainerId(), streamOpts{ | ||||
| 		stdin:  req.GetStdin(), | ||||
| 		stdout: true, | ||||
| 		stderr: !req.GetTty(), // For TTY connections, both stderr is combined with stdout. | ||||
| 		tty:    req.GetTty(), | ||||
| 	}) | ||||
| 	if req.GetContainerId() == "" { | ||||
| 		return nil, grpc.Errorf(codes.InvalidArgument, "missing required container_id") | ||||
| 	} | ||||
| 	token, err := s.cache.Insert(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &runtimeapi.AttachResponse{ | ||||
| 		Url: &url, | ||||
| 		Url: s.buildURL("attach", token), | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (s *server) GetPortForward(req *runtimeapi.PortForwardRequest) (*runtimeapi.PortForwardResponse, error) { | ||||
| 	url := s.buildURL("portforward", req.GetPodSandboxId(), streamOpts{}) | ||||
| 	if req.GetPodSandboxId() == "" { | ||||
| 		return nil, grpc.Errorf(codes.InvalidArgument, "missing required pod_sandbox_id") | ||||
| 	} | ||||
| 	token, err := s.cache.Insert(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &runtimeapi.PortForwardResponse{ | ||||
| 		Url: &url, | ||||
| 		Url: s.buildURL("portforward", token), | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| @@ -200,63 +211,32 @@ func (s *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { | ||||
| 	s.handler.ServeHTTP(w, r) | ||||
| } | ||||
|  | ||||
| type streamOpts struct { | ||||
| 	stdin  bool | ||||
| 	stdout bool | ||||
| 	stderr bool | ||||
| 	tty    bool | ||||
|  | ||||
| 	command []string | ||||
| 	port    []int32 | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	urlParamStdin   = api.ExecStdinParam | ||||
| 	urlParamStdout  = api.ExecStdoutParam | ||||
| 	urlParamStderr  = api.ExecStderrParam | ||||
| 	urlParamTTY     = api.ExecTTYParam | ||||
| 	urlParamCommand = api.ExecCommandParamm | ||||
| ) | ||||
|  | ||||
| func (s *server) buildURL(method, id string, opts streamOpts) string { | ||||
| 	loc := &url.URL{ | ||||
| 		Path: path.Join(method, id), | ||||
| 	} | ||||
|  | ||||
| 	query := url.Values{} | ||||
| 	if opts.stdin { | ||||
| 		query.Add(urlParamStdin, "1") | ||||
| 	} | ||||
| 	if opts.stdout { | ||||
| 		query.Add(urlParamStdout, "1") | ||||
| 	} | ||||
| 	if opts.stderr { | ||||
| 		query.Add(urlParamStderr, "1") | ||||
| 	} | ||||
| 	if opts.tty { | ||||
| 		query.Add(urlParamTTY, "1") | ||||
| 	} | ||||
| 	for _, c := range opts.command { | ||||
| 		query.Add(urlParamCommand, c) | ||||
| 	} | ||||
| 	loc.RawQuery = query.Encode() | ||||
|  | ||||
| 	return s.config.BaseURL.ResolveReference(loc).String() | ||||
| func (s *server) buildURL(method, token string) *string { | ||||
| 	loc := s.config.BaseURL.ResolveReference(&url.URL{ | ||||
| 		Path: path.Join(method, token), | ||||
| 	}).String() | ||||
| 	return &loc | ||||
| } | ||||
|  | ||||
| func (s *server) serveExec(req *restful.Request, resp *restful.Response) { | ||||
| 	containerID := req.PathParameter("containerID") | ||||
| 	if containerID == "" { | ||||
| 		resp.WriteError(http.StatusBadRequest, errors.New("missing required containerID path parameter")) | ||||
| 	token := req.PathParameter("token") | ||||
| 	cachedRequest, ok := s.cache.Consume(token) | ||||
| 	if !ok { | ||||
| 		http.NotFound(resp.ResponseWriter, req.Request) | ||||
| 		return | ||||
| 	} | ||||
| 	exec, ok := cachedRequest.(*runtimeapi.ExecRequest) | ||||
| 	if !ok { | ||||
| 		http.NotFound(resp.ResponseWriter, req.Request) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	streamOpts, err := remotecommand.NewOptions(req.Request) | ||||
| 	if err != nil { | ||||
| 		resp.WriteError(http.StatusBadRequest, err) | ||||
| 		return | ||||
| 	streamOpts := &remotecommand.Options{ | ||||
| 		Stdin:  exec.GetStdin(), | ||||
| 		Stdout: true, | ||||
| 		Stderr: !exec.GetTty(), | ||||
| 		TTY:    exec.GetTty(), | ||||
| 	} | ||||
| 	cmd := req.Request.URL.Query()[api.ExecCommandParamm] | ||||
|  | ||||
| 	remotecommand.ServeExec( | ||||
| 		resp.ResponseWriter, | ||||
| @@ -264,8 +244,8 @@ func (s *server) serveExec(req *restful.Request, resp *restful.Response) { | ||||
| 		s.runtime, | ||||
| 		"", // unused: podName | ||||
| 		"", // unusued: podUID | ||||
| 		containerID, | ||||
| 		cmd, | ||||
| 		exec.GetContainerId(), | ||||
| 		exec.GetCmd(), | ||||
| 		streamOpts, | ||||
| 		s.config.StreamIdleTimeout, | ||||
| 		s.config.StreamCreationTimeout, | ||||
| @@ -273,25 +253,31 @@ func (s *server) serveExec(req *restful.Request, resp *restful.Response) { | ||||
| } | ||||
|  | ||||
| func (s *server) serveAttach(req *restful.Request, resp *restful.Response) { | ||||
| 	containerID := req.PathParameter("containerID") | ||||
| 	if containerID == "" { | ||||
| 		resp.WriteError(http.StatusBadRequest, errors.New("missing required containerID path parameter")) | ||||
| 	token := req.PathParameter("token") | ||||
| 	cachedRequest, ok := s.cache.Consume(token) | ||||
| 	if !ok { | ||||
| 		http.NotFound(resp.ResponseWriter, req.Request) | ||||
| 		return | ||||
| 	} | ||||
| 	attach, ok := cachedRequest.(*runtimeapi.AttachRequest) | ||||
| 	if !ok { | ||||
| 		http.NotFound(resp.ResponseWriter, req.Request) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	streamOpts, err := remotecommand.NewOptions(req.Request) | ||||
| 	if err != nil { | ||||
| 		resp.WriteError(http.StatusBadRequest, err) | ||||
| 		return | ||||
| 	streamOpts := &remotecommand.Options{ | ||||
| 		Stdin:  attach.GetStdin(), | ||||
| 		Stdout: true, | ||||
| 		Stderr: !attach.GetTty(), | ||||
| 		TTY:    attach.GetTty(), | ||||
| 	} | ||||
|  | ||||
| 	remotecommand.ServeAttach( | ||||
| 		resp.ResponseWriter, | ||||
| 		req.Request, | ||||
| 		s.runtime, | ||||
| 		"", // unused: podName | ||||
| 		"", // unusued: podUID | ||||
| 		containerID, | ||||
| 		attach.GetContainerId(), | ||||
| 		streamOpts, | ||||
| 		s.config.StreamIdleTimeout, | ||||
| 		s.config.StreamCreationTimeout, | ||||
| @@ -299,9 +285,15 @@ func (s *server) serveAttach(req *restful.Request, resp *restful.Response) { | ||||
| } | ||||
|  | ||||
| func (s *server) servePortForward(req *restful.Request, resp *restful.Response) { | ||||
| 	podSandboxID := req.PathParameter("podSandboxID") | ||||
| 	if podSandboxID == "" { | ||||
| 		resp.WriteError(http.StatusBadRequest, errors.New("missing required podSandboxID path parameter")) | ||||
| 	token := req.PathParameter("token") | ||||
| 	cachedRequest, ok := s.cache.Consume(token) | ||||
| 	if !ok { | ||||
| 		http.NotFound(resp.ResponseWriter, req.Request) | ||||
| 		return | ||||
| 	} | ||||
| 	pf, ok := cachedRequest.(*runtimeapi.PortForwardRequest) | ||||
| 	if !ok { | ||||
| 		http.NotFound(resp.ResponseWriter, req.Request) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| @@ -309,7 +301,7 @@ func (s *server) servePortForward(req *restful.Request, resp *restful.Response) | ||||
| 		resp.ResponseWriter, | ||||
| 		req.Request, | ||||
| 		s.runtime, | ||||
| 		podSandboxID, | ||||
| 		pf.GetPodSandboxId(), | ||||
| 		"", // unused: podUID | ||||
| 		s.config.StreamIdleTimeout, | ||||
| 		s.config.StreamCreationTimeout) | ||||
|   | ||||
| @@ -18,12 +18,12 @@ package streaming | ||||
|  | ||||
| import ( | ||||
| 	"crypto/tls" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"net/url" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
|  | ||||
| @@ -46,18 +46,18 @@ const ( | ||||
| ) | ||||
|  | ||||
| func TestGetExec(t *testing.T) { | ||||
| 	testcases := []struct { | ||||
| 	type testcase struct { | ||||
| 		cmd   []string | ||||
| 		tty   bool | ||||
| 		stdin bool | ||||
| 		expectedQuery string | ||||
| 	}{ | ||||
| 		{[]string{"echo", "foo"}, false, false, "?command=echo&command=foo&error=1&output=1"}, | ||||
| 		{[]string{"date"}, true, false, "?command=date&output=1&tty=1"}, | ||||
| 		{[]string{"date"}, false, true, "?command=date&error=1&input=1&output=1"}, | ||||
| 		{[]string{"date"}, true, true, "?command=date&input=1&output=1&tty=1"}, | ||||
| 	} | ||||
| 	server, err := NewServer(Config{ | ||||
| 	testcases := []testcase{ | ||||
| 		{[]string{"echo", "foo"}, false, false}, | ||||
| 		{[]string{"date"}, true, false}, | ||||
| 		{[]string{"date"}, false, true}, | ||||
| 		{[]string{"date"}, true, true}, | ||||
| 	} | ||||
| 	serv, err := NewServer(Config{ | ||||
| 		Addr: testAddr, | ||||
| 	}, nil) | ||||
| 	assert.NoError(t, err) | ||||
| @@ -79,6 +79,14 @@ func TestGetExec(t *testing.T) { | ||||
| 	}, nil) | ||||
| 	assert.NoError(t, err) | ||||
|  | ||||
| 	assertRequestToken := func(test testcase, cache *requestCache, token string) { | ||||
| 		req, ok := cache.Consume(token) | ||||
| 		require.True(t, ok, "token %s not found! testcase=%+v", token, test) | ||||
| 		assert.Equal(t, testContainerID, req.(*runtimeapi.ExecRequest).GetContainerId(), "testcase=%+v", test) | ||||
| 		assert.Equal(t, test.cmd, req.(*runtimeapi.ExecRequest).GetCmd(), "testcase=%+v", test) | ||||
| 		assert.Equal(t, test.tty, req.(*runtimeapi.ExecRequest).GetTty(), "testcase=%+v", test) | ||||
| 		assert.Equal(t, test.stdin, req.(*runtimeapi.ExecRequest).GetStdin(), "testcase=%+v", test) | ||||
| 	} | ||||
| 	containerID := testContainerID | ||||
| 	for _, test := range testcases { | ||||
| 		request := &runtimeapi.ExecRequest{ | ||||
| @@ -87,38 +95,47 @@ func TestGetExec(t *testing.T) { | ||||
| 			Tty:         &test.tty, | ||||
| 			Stdin:       &test.stdin, | ||||
| 		} | ||||
| 		// Non-TLS | ||||
| 		resp, err := server.GetExec(request) | ||||
| 		{ // Non-TLS | ||||
| 			resp, err := serv.GetExec(request) | ||||
| 			assert.NoError(t, err, "testcase=%+v", test) | ||||
| 		expectedURL := "http://" + testAddr + "/exec/" + testContainerID + test.expectedQuery | ||||
| 		assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test) | ||||
| 			expectedURL := "http://" + testAddr + "/exec/" | ||||
| 			assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test) | ||||
| 			token := strings.TrimPrefix(resp.GetUrl(), expectedURL) | ||||
| 			assertRequestToken(test, serv.(*server).cache, token) | ||||
| 		} | ||||
|  | ||||
| 		// TLS | ||||
| 		resp, err = tlsServer.GetExec(request) | ||||
| 		{ // TLS | ||||
| 			resp, err := tlsServer.GetExec(request) | ||||
| 			assert.NoError(t, err, "testcase=%+v", test) | ||||
| 		expectedURL = "https://" + testAddr + "/exec/" + testContainerID + test.expectedQuery | ||||
| 		assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test) | ||||
| 			expectedURL := "https://" + testAddr + "/exec/" | ||||
| 			assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test) | ||||
| 			token := strings.TrimPrefix(resp.GetUrl(), expectedURL) | ||||
| 			assertRequestToken(test, tlsServer.(*server).cache, token) | ||||
| 		} | ||||
|  | ||||
| 		// Path prefix | ||||
| 		resp, err = prefixServer.GetExec(request) | ||||
| 		{ // Path prefix | ||||
| 			resp, err := prefixServer.GetExec(request) | ||||
| 			assert.NoError(t, err, "testcase=%+v", test) | ||||
| 		expectedURL = "http://" + testAddr + "/" + pathPrefix + "/exec/" + testContainerID + test.expectedQuery | ||||
| 		assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test) | ||||
| 			expectedURL := "http://" + testAddr + "/" + pathPrefix + "/exec/" | ||||
| 			assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test) | ||||
| 			token := strings.TrimPrefix(resp.GetUrl(), expectedURL) | ||||
| 			assertRequestToken(test, prefixServer.(*server).cache, token) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestGetAttach(t *testing.T) { | ||||
| 	testcases := []struct { | ||||
| 	type testcase struct { | ||||
| 		tty   bool | ||||
| 		stdin bool | ||||
| 		expectedQuery string | ||||
| 	}{ | ||||
| 		{false, false, "?error=1&output=1"}, | ||||
| 		{true, false, "?output=1&tty=1"}, | ||||
| 		{false, true, "?error=1&input=1&output=1"}, | ||||
| 		{true, true, "?input=1&output=1&tty=1"}, | ||||
| 	} | ||||
| 	server, err := NewServer(Config{ | ||||
| 	testcases := []testcase{ | ||||
| 		{false, false}, | ||||
| 		{true, false}, | ||||
| 		{false, true}, | ||||
| 		{true, true}, | ||||
| 	} | ||||
| 	serv, err := NewServer(Config{ | ||||
| 		Addr: testAddr, | ||||
| 	}, nil) | ||||
| 	assert.NoError(t, err) | ||||
| @@ -129,6 +146,13 @@ func TestGetAttach(t *testing.T) { | ||||
| 	}, nil) | ||||
| 	assert.NoError(t, err) | ||||
|  | ||||
| 	assertRequestToken := func(test testcase, cache *requestCache, token string) { | ||||
| 		req, ok := cache.Consume(token) | ||||
| 		require.True(t, ok, "token %s not found! testcase=%+v", token, test) | ||||
| 		assert.Equal(t, testContainerID, req.(*runtimeapi.AttachRequest).GetContainerId(), "testcase=%+v", test) | ||||
| 		assert.Equal(t, test.tty, req.(*runtimeapi.AttachRequest).GetTty(), "testcase=%+v", test) | ||||
| 		assert.Equal(t, test.stdin, req.(*runtimeapi.AttachRequest).GetStdin(), "testcase=%+v", test) | ||||
| 	} | ||||
| 	containerID := testContainerID | ||||
| 	for _, test := range testcases { | ||||
| 		request := &runtimeapi.AttachRequest{ | ||||
| @@ -136,17 +160,23 @@ func TestGetAttach(t *testing.T) { | ||||
| 			Stdin:       &test.stdin, | ||||
| 			Tty:         &test.tty, | ||||
| 		} | ||||
| 		// Non-TLS | ||||
| 		resp, err := server.GetAttach(request) | ||||
| 		{ // Non-TLS | ||||
| 			resp, err := serv.GetAttach(request) | ||||
| 			assert.NoError(t, err, "testcase=%+v", test) | ||||
| 		expectedURL := "http://" + testAddr + "/attach/" + testContainerID + test.expectedQuery | ||||
| 		assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test) | ||||
| 			expectedURL := "http://" + testAddr + "/attach/" | ||||
| 			assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test) | ||||
| 			token := strings.TrimPrefix(resp.GetUrl(), expectedURL) | ||||
| 			assertRequestToken(test, serv.(*server).cache, token) | ||||
| 		} | ||||
|  | ||||
| 		// TLS | ||||
| 		resp, err = tlsServer.GetAttach(request) | ||||
| 		{ // TLS | ||||
| 			resp, err := tlsServer.GetAttach(request) | ||||
| 			assert.NoError(t, err, "testcase=%+v", test) | ||||
| 		expectedURL = "https://" + testAddr + "/attach/" + testContainerID + test.expectedQuery | ||||
| 		assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test) | ||||
| 			expectedURL := "https://" + testAddr + "/attach/" | ||||
| 			assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test) | ||||
| 			token := strings.TrimPrefix(resp.GetUrl(), expectedURL) | ||||
| 			assertRequestToken(test, tlsServer.(*server).cache, token) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -157,26 +187,36 @@ func TestGetPortForward(t *testing.T) { | ||||
| 		Port:         []int32{1, 2, 3, 4}, | ||||
| 	} | ||||
|  | ||||
| 	// Non-TLS | ||||
| 	server, err := NewServer(Config{ | ||||
| 	{ // Non-TLS | ||||
| 		serv, err := NewServer(Config{ | ||||
| 			Addr: testAddr, | ||||
| 		}, nil) | ||||
| 		assert.NoError(t, err) | ||||
| 	resp, err := server.GetPortForward(request) | ||||
| 		resp, err := serv.GetPortForward(request) | ||||
| 		assert.NoError(t, err) | ||||
| 	expectedURL := "http://" + testAddr + "/portforward/" + testPodSandboxID | ||||
| 	assert.Equal(t, expectedURL, resp.GetUrl()) | ||||
| 		expectedURL := "http://" + testAddr + "/portforward/" | ||||
| 		assert.True(t, strings.HasPrefix(resp.GetUrl(), expectedURL)) | ||||
| 		token := strings.TrimPrefix(resp.GetUrl(), expectedURL) | ||||
| 		req, ok := serv.(*server).cache.Consume(token) | ||||
| 		require.True(t, ok, "token %s not found!", token) | ||||
| 		assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).GetPodSandboxId()) | ||||
| 	} | ||||
|  | ||||
| 	// TLS | ||||
| 	{ // TLS | ||||
| 		tlsServer, err := NewServer(Config{ | ||||
| 			Addr:      testAddr, | ||||
| 			TLSConfig: &tls.Config{}, | ||||
| 		}, nil) | ||||
| 		assert.NoError(t, err) | ||||
| 	resp, err = tlsServer.GetPortForward(request) | ||||
| 		resp, err := tlsServer.GetPortForward(request) | ||||
| 		assert.NoError(t, err) | ||||
| 	expectedURL = "https://" + testAddr + "/portforward/" + testPodSandboxID | ||||
| 	assert.Equal(t, expectedURL, resp.GetUrl()) | ||||
| 		expectedURL := "https://" + testAddr + "/portforward/" | ||||
| 		assert.True(t, strings.HasPrefix(resp.GetUrl(), expectedURL)) | ||||
| 		token := strings.TrimPrefix(resp.GetUrl(), expectedURL) | ||||
| 		req, ok := tlsServer.(*server).cache.Consume(token) | ||||
| 		require.True(t, ok, "token %s not found!", token) | ||||
| 		assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).GetPodSandboxId()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestServeExec(t *testing.T) { | ||||
| @@ -188,21 +228,18 @@ func TestServeAttach(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestServePortForward(t *testing.T) { | ||||
| 	rt := newFakeRuntime(t) | ||||
| 	s, err := NewServer(DefaultConfig, rt) | ||||
| 	require.NoError(t, err) | ||||
| 	testServer := httptest.NewServer(s) | ||||
| 	s, testServer := startTestServer(t) | ||||
| 	defer testServer.Close() | ||||
|  | ||||
| 	testURL, err := url.Parse(testServer.URL) | ||||
| 	podSandboxID := testPodSandboxID | ||||
| 	resp, err := s.GetPortForward(&runtimeapi.PortForwardRequest{ | ||||
| 		PodSandboxId: &podSandboxID, | ||||
| 	}) | ||||
| 	require.NoError(t, err) | ||||
| 	reqURL, err := url.Parse(resp.GetUrl()) | ||||
| 	require.NoError(t, err) | ||||
| 	loc := &url.URL{ | ||||
| 		Scheme: testURL.Scheme, | ||||
| 		Host:   testURL.Host, | ||||
| 	} | ||||
|  | ||||
| 	loc.Path = fmt.Sprintf("/%s/%s", "portforward", testPodSandboxID) | ||||
| 	exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", loc) | ||||
| 	exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", reqURL) | ||||
| 	require.NoError(t, err) | ||||
| 	streamConn, _, err := exec.Dial(kubeletportforward.PortForwardProtocolV1Name) | ||||
| 	require.NoError(t, err) | ||||
| @@ -227,22 +264,30 @@ func TestServePortForward(t *testing.T) { | ||||
| // Run the remote command test. | ||||
| // commandType is either "exec" or "attach". | ||||
| func runRemoteCommandTest(t *testing.T, commandType string) { | ||||
| 	rt := newFakeRuntime(t) | ||||
| 	s, err := NewServer(DefaultConfig, rt) | ||||
| 	require.NoError(t, err) | ||||
| 	testServer := httptest.NewServer(s) | ||||
| 	s, testServer := startTestServer(t) | ||||
| 	defer testServer.Close() | ||||
|  | ||||
| 	testURL, err := url.Parse(testServer.URL) | ||||
| 	var reqURL *url.URL | ||||
| 	stdin := true | ||||
| 	containerID := testContainerID | ||||
| 	switch commandType { | ||||
| 	case "exec": | ||||
| 		resp, err := s.GetExec(&runtimeapi.ExecRequest{ | ||||
| 			ContainerId: &containerID, | ||||
| 			Cmd:         []string{"echo"}, | ||||
| 			Stdin:       &stdin, | ||||
| 		}) | ||||
| 		require.NoError(t, err) | ||||
| 		reqURL, err = url.Parse(resp.GetUrl()) | ||||
| 		require.NoError(t, err) | ||||
| 	case "attach": | ||||
| 		resp, err := s.GetAttach(&runtimeapi.AttachRequest{ | ||||
| 			ContainerId: &containerID, | ||||
| 			Stdin:       &stdin, | ||||
| 		}) | ||||
| 		require.NoError(t, err) | ||||
| 		reqURL, err = url.Parse(resp.GetUrl()) | ||||
| 		require.NoError(t, err) | ||||
| 	query := url.Values{} | ||||
| 	query.Add(urlParamStdin, "1") | ||||
| 	query.Add(urlParamStdout, "1") | ||||
| 	query.Add(urlParamStderr, "1") | ||||
| 	loc := &url.URL{ | ||||
| 		Scheme:   testURL.Scheme, | ||||
| 		Host:     testURL.Host, | ||||
| 		RawQuery: query.Encode(), | ||||
| 	} | ||||
|  | ||||
| 	wg := sync.WaitGroup{} | ||||
| @@ -254,8 +299,7 @@ func runRemoteCommandTest(t *testing.T, commandType string) { | ||||
|  | ||||
| 	go func() { | ||||
| 		defer wg.Done() | ||||
| 		loc.Path = fmt.Sprintf("/%s/%s", commandType, testContainerID) | ||||
| 		exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", loc) | ||||
| 		exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", reqURL) | ||||
| 		require.NoError(t, err) | ||||
|  | ||||
| 		opts := remotecommand.StreamOptions{ | ||||
| @@ -275,6 +319,36 @@ func runRemoteCommandTest(t *testing.T, commandType string) { | ||||
| 	}() | ||||
|  | ||||
| 	wg.Wait() | ||||
|  | ||||
| 	// Repeat request with the same URL should be a 404. | ||||
| 	resp, err := http.Get(reqURL.String()) | ||||
| 	require.NoError(t, err) | ||||
| 	assert.Equal(t, http.StatusNotFound, resp.StatusCode) | ||||
| } | ||||
|  | ||||
| func startTestServer(t *testing.T) (Server, *httptest.Server) { | ||||
| 	var s Server | ||||
| 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		s.ServeHTTP(w, r) | ||||
| 	})) | ||||
| 	cleanup := true | ||||
| 	defer func() { | ||||
| 		if cleanup { | ||||
| 			testServer.Close() | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	testURL, err := url.Parse(testServer.URL) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	rt := newFakeRuntime(t) | ||||
| 	config := DefaultConfig | ||||
| 	config.BaseURL = testURL | ||||
| 	s, err = NewServer(config, rt) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	cleanup = false // Caller must close the test server. | ||||
| 	return s, testServer | ||||
| } | ||||
|  | ||||
| const ( | ||||
|   | ||||
| @@ -391,6 +391,14 @@ var _ = framework.KubeDescribe("Kubectl client", func() { | ||||
| 				framework.Failf("Unexpected kubectl exec output. Wanted %q, got %q", e, a) | ||||
| 			} | ||||
|  | ||||
| 			By("executing a very long command in the container") | ||||
| 			veryLongData := make([]rune, 20000) | ||||
| 			for i := 0; i < len(veryLongData); i++ { | ||||
| 				veryLongData[i] = 'a' | ||||
| 			} | ||||
| 			execOutput = framework.RunKubectlOrDie("exec", fmt.Sprintf("--namespace=%v", ns), simplePodName, "echo", string(veryLongData)) | ||||
| 			Expect(string(veryLongData)).To(Equal(strings.TrimSpace(execOutput)), "Unexpected kubectl exec output") | ||||
|  | ||||
| 			By("executing a command in the container with noninteractive stdin") | ||||
| 			execOutput = framework.NewKubectlCommand("exec", fmt.Sprintf("--namespace=%v", ns), "-i", simplePodName, "cat"). | ||||
| 				WithStdinData("abcd1234"). | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Tim St. Clair
					Tim St. Clair