mirror of
				https://github.com/optim-enterprises-bv/kubernetes.git
				synced 2025-10-31 18:28:13 +00:00 
			
		
		
		
	Don't include user data in CRI streaming redirect URLs
This commit is contained in:
		| @@ -606,7 +606,7 @@ func (s *Server) getAttach(request *restful.Request, response *restful.Response) | |||||||
| 	podFullName := kubecontainer.GetPodFullName(pod) | 	podFullName := kubecontainer.GetPodFullName(pod) | ||||||
| 	redirect, err := s.host.GetAttach(podFullName, params.podUID, params.containerName, *streamOpts) | 	redirect, err := s.host.GetAttach(podFullName, params.podUID, params.containerName, *streamOpts) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		response.WriteError(streaming.HTTPStatus(err), err) | 		streaming.WriteError(err, response.ResponseWriter) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if redirect != nil { | 	if redirect != nil { | ||||||
| @@ -644,7 +644,7 @@ func (s *Server) getExec(request *restful.Request, response *restful.Response) { | |||||||
| 	podFullName := kubecontainer.GetPodFullName(pod) | 	podFullName := kubecontainer.GetPodFullName(pod) | ||||||
| 	redirect, err := s.host.GetExec(podFullName, params.podUID, params.containerName, params.cmd, *streamOpts) | 	redirect, err := s.host.GetExec(podFullName, params.podUID, params.containerName, params.cmd, *streamOpts) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		response.WriteError(streaming.HTTPStatus(err), err) | 		streaming.WriteError(err, response.ResponseWriter) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if redirect != nil { | 	if redirect != nil { | ||||||
| @@ -714,7 +714,7 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp | |||||||
|  |  | ||||||
| 	redirect, err := s.host.GetPortForward(pod.Name, pod.Namespace, pod.UID) | 	redirect, err := s.host.GetPortForward(pod.Name, pod.Namespace, pod.UID) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		response.WriteError(streaming.HTTPStatus(err), err) | 		streaming.WriteError(err, response.ResponseWriter) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if redirect != nil { | 	if redirect != nil { | ||||||
|   | |||||||
| @@ -12,14 +12,15 @@ go_library( | |||||||
|     name = "go_default_library", |     name = "go_default_library", | ||||||
|     srcs = [ |     srcs = [ | ||||||
|         "errors.go", |         "errors.go", | ||||||
|  |         "request_cache.go", | ||||||
|         "server.go", |         "server.go", | ||||||
|     ], |     ], | ||||||
|     tags = ["automanaged"], |     tags = ["automanaged"], | ||||||
|     deps = [ |     deps = [ | ||||||
|         "//pkg/api:go_default_library", |  | ||||||
|         "//pkg/kubelet/api/v1alpha1/runtime:go_default_library", |         "//pkg/kubelet/api/v1alpha1/runtime:go_default_library", | ||||||
|         "//pkg/kubelet/server/portforward:go_default_library", |         "//pkg/kubelet/server/portforward:go_default_library", | ||||||
|         "//pkg/kubelet/server/remotecommand:go_default_library", |         "//pkg/kubelet/server/remotecommand:go_default_library", | ||||||
|  |         "//pkg/util/clock:go_default_library", | ||||||
|         "//pkg/util/term:go_default_library", |         "//pkg/util/term:go_default_library", | ||||||
|         "//vendor:github.com/emicklei/go-restful", |         "//vendor:github.com/emicklei/go-restful", | ||||||
|         "//vendor:google.golang.org/grpc", |         "//vendor:google.golang.org/grpc", | ||||||
| @@ -30,7 +31,10 @@ go_library( | |||||||
|  |  | ||||||
| go_test( | go_test( | ||||||
|     name = "go_default_test", |     name = "go_default_test", | ||||||
|     srcs = ["server_test.go"], |     srcs = [ | ||||||
|  |         "request_cache_test.go", | ||||||
|  |         "server_test.go", | ||||||
|  |     ], | ||||||
|     library = ":go_default_library", |     library = ":go_default_library", | ||||||
|     tags = ["automanaged"], |     tags = ["automanaged"], | ||||||
|     deps = [ |     deps = [ | ||||||
| @@ -43,6 +47,7 @@ go_test( | |||||||
|         "//vendor:github.com/stretchr/testify/assert", |         "//vendor:github.com/stretchr/testify/assert", | ||||||
|         "//vendor:github.com/stretchr/testify/require", |         "//vendor:github.com/stretchr/testify/require", | ||||||
|         "//vendor:k8s.io/client-go/pkg/api", |         "//vendor:k8s.io/client-go/pkg/api", | ||||||
|  |         "//vendor:k8s.io/client-go/pkg/util/clock", | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -19,6 +19,7 @@ package streaming | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"strconv" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"google.golang.org/grpc" | 	"google.golang.org/grpc" | ||||||
| @@ -33,12 +34,27 @@ func ErrorTimeout(op string, timeout time.Duration) error { | |||||||
| 	return grpc.Errorf(codes.DeadlineExceeded, fmt.Sprintf("%s timed out after %s", op, timeout.String())) | 	return grpc.Errorf(codes.DeadlineExceeded, fmt.Sprintf("%s timed out after %s", op, timeout.String())) | ||||||
| } | } | ||||||
|  |  | ||||||
| // Translates a CRI streaming error into an HTTP status code. | // The error returned when the maximum number of in-flight requests is exceeded. | ||||||
| func HTTPStatus(err error) int { | func ErrorTooManyInFlight() error { | ||||||
|  | 	return grpc.Errorf(codes.ResourceExhausted, "maximum number of in-flight requests exceeded") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Translates a CRI streaming error into an appropriate HTTP response. | ||||||
|  | func WriteError(err error, w http.ResponseWriter) error { | ||||||
|  | 	var status int | ||||||
| 	switch grpc.Code(err) { | 	switch grpc.Code(err) { | ||||||
| 	case codes.NotFound: | 	case codes.NotFound: | ||||||
| 		return http.StatusNotFound | 		status = http.StatusNotFound | ||||||
|  | 	case codes.ResourceExhausted: | ||||||
|  | 		// We only expect to hit this if there is a DoS, so we just wait the full TTL. | ||||||
|  | 		// If this is ever hit in steady-state operations, consider increasing the MaxInFlight requests, | ||||||
|  | 		// or plumbing through the time to next expiration. | ||||||
|  | 		w.Header().Set("Retry-After", strconv.Itoa(int(CacheTTL.Seconds()))) | ||||||
|  | 		status = http.StatusTooManyRequests | ||||||
| 	default: | 	default: | ||||||
| 		return http.StatusInternalServerError | 		status = http.StatusInternalServerError | ||||||
| 	} | 	} | ||||||
|  | 	w.WriteHeader(status) | ||||||
|  | 	_, writeErr := w.Write([]byte(err.Error())) | ||||||
|  | 	return writeErr | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										146
									
								
								pkg/kubelet/server/streaming/request_cache.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										146
									
								
								pkg/kubelet/server/streaming/request_cache.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,146 @@ | |||||||
|  | /* | ||||||
|  | Copyright 2016 The Kubernetes Authors. | ||||||
|  |  | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  |  | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  |  | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | */ | ||||||
|  |  | ||||||
|  | package streaming | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"container/list" | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"encoding/base64" | ||||||
|  | 	"fmt" | ||||||
|  | 	"math" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"k8s.io/kubernetes/pkg/util/clock" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	// Timeout after which tokens become invalid. | ||||||
|  | 	CacheTTL = 1 * time.Minute | ||||||
|  | 	// The maximum number of in-flight requests to allow. | ||||||
|  | 	MaxInFlight = 1000 | ||||||
|  | 	// Length of the random base64 encoded token identifying the request. | ||||||
|  | 	TokenLen = 8 | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // requestCache caches streaming (exec/attach/port-forward) requests and generates a single-use | ||||||
|  | // random token for their retrieval. The requestCache is used for building streaming URLs without | ||||||
|  | // the need to encode every request parameter in the URL. | ||||||
|  | type requestCache struct { | ||||||
|  | 	// clock is used to obtain the current time | ||||||
|  | 	clock clock.Clock | ||||||
|  |  | ||||||
|  | 	// tokens maps the generate token to the request for fast retrieval. | ||||||
|  | 	tokens map[string]*list.Element | ||||||
|  | 	// ll maintains an age-ordered request list for faster garbage collection of expired requests. | ||||||
|  | 	ll *list.List | ||||||
|  |  | ||||||
|  | 	lock sync.Mutex | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Type representing an *ExecRequest, *AttachRequest, or *PortForwardRequest. | ||||||
|  | type request interface{} | ||||||
|  |  | ||||||
|  | type cacheEntry struct { | ||||||
|  | 	token      string | ||||||
|  | 	req        request | ||||||
|  | 	expireTime time.Time | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newRequestCache() *requestCache { | ||||||
|  | 	return &requestCache{ | ||||||
|  | 		clock:  clock.RealClock{}, | ||||||
|  | 		ll:     list.New(), | ||||||
|  | 		tokens: make(map[string]*list.Element), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Insert the given request into the cache and returns the token used for fetching it out. | ||||||
|  | func (c *requestCache) Insert(req request) (token string, err error) { | ||||||
|  | 	c.lock.Lock() | ||||||
|  | 	defer c.lock.Unlock() | ||||||
|  |  | ||||||
|  | 	// Remove expired entries. | ||||||
|  | 	c.gc() | ||||||
|  | 	// If the cache is full, reject the request. | ||||||
|  | 	if c.ll.Len() == MaxInFlight { | ||||||
|  | 		return "", ErrorTooManyInFlight() | ||||||
|  | 	} | ||||||
|  | 	token, err = c.uniqueToken() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 	ele := c.ll.PushFront(&cacheEntry{token, req, c.clock.Now().Add(CacheTTL)}) | ||||||
|  |  | ||||||
|  | 	c.tokens[token] = ele | ||||||
|  | 	return token, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Consume the token (remove it from the cache) and return the cached request, if found. | ||||||
|  | func (c *requestCache) Consume(token string) (req request, found bool) { | ||||||
|  | 	c.lock.Lock() | ||||||
|  | 	defer c.lock.Unlock() | ||||||
|  | 	ele, ok := c.tokens[token] | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil, false | ||||||
|  | 	} | ||||||
|  | 	c.ll.Remove(ele) | ||||||
|  | 	delete(c.tokens, token) | ||||||
|  |  | ||||||
|  | 	entry := ele.Value.(*cacheEntry) | ||||||
|  | 	if c.clock.Now().After(entry.expireTime) { | ||||||
|  | 		// Entry already expired. | ||||||
|  | 		return nil, false | ||||||
|  | 	} | ||||||
|  | 	return entry.req, true | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // uniqueToken generates a random URL-safe token and ensures uniqueness. | ||||||
|  | func (c *requestCache) uniqueToken() (string, error) { | ||||||
|  | 	const maxTries = 10 | ||||||
|  | 	// Number of bytes to be TokenLen when base64 encoded. | ||||||
|  | 	tokenSize := math.Ceil(float64(TokenLen) * 6 / 8) | ||||||
|  | 	rawToken := make([]byte, int(tokenSize)) | ||||||
|  | 	for i := 0; i < maxTries; i++ { | ||||||
|  | 		if _, err := rand.Read(rawToken); err != nil { | ||||||
|  | 			return "", err | ||||||
|  | 		} | ||||||
|  | 		encoded := base64.RawURLEncoding.EncodeToString(rawToken) | ||||||
|  | 		token := encoded[:TokenLen] | ||||||
|  | 		// If it's unique, return it. Otherwise retry. | ||||||
|  | 		if _, exists := c.tokens[encoded]; !exists { | ||||||
|  | 			return token, nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return "", fmt.Errorf("failed to generate unique token") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Must be write-locked prior to calling. | ||||||
|  | func (c *requestCache) gc() { | ||||||
|  | 	now := c.clock.Now() | ||||||
|  | 	for c.ll.Len() > 0 { | ||||||
|  | 		oldest := c.ll.Back() | ||||||
|  | 		entry := oldest.Value.(*cacheEntry) | ||||||
|  | 		if !now.After(entry.expireTime) { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Oldest value is expired; remove it. | ||||||
|  | 		c.ll.Remove(oldest) | ||||||
|  | 		delete(c.tokens, entry.token) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										221
									
								
								pkg/kubelet/server/streaming/request_cache_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										221
									
								
								pkg/kubelet/server/streaming/request_cache_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,221 @@ | |||||||
|  | /* | ||||||
|  | Copyright 2016 The Kubernetes Authors. | ||||||
|  |  | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  |  | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  |  | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | */ | ||||||
|  |  | ||||||
|  | package streaming | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 	"strconv" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/stretchr/testify/assert" | ||||||
|  | 	"github.com/stretchr/testify/require" | ||||||
|  |  | ||||||
|  | 	"k8s.io/client-go/pkg/util/clock" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestInsert(t *testing.T) { | ||||||
|  | 	c, _ := newTestCache() | ||||||
|  |  | ||||||
|  | 	// Insert normal | ||||||
|  | 	oldestTok, err := c.Insert(nextRequest()) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	assert.Len(t, oldestTok, TokenLen) | ||||||
|  | 	assertCacheSize(t, c, 1) | ||||||
|  |  | ||||||
|  | 	// Insert until full | ||||||
|  | 	for i := 0; i < MaxInFlight-2; i++ { | ||||||
|  | 		tok, err := c.Insert(nextRequest()) | ||||||
|  | 		require.NoError(t, err) | ||||||
|  | 		assert.Len(t, tok, TokenLen) | ||||||
|  | 	} | ||||||
|  | 	assertCacheSize(t, c, MaxInFlight-1) | ||||||
|  |  | ||||||
|  | 	newestReq := nextRequest() | ||||||
|  | 	newestTok, err := c.Insert(newestReq) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	assert.Len(t, newestTok, TokenLen) | ||||||
|  | 	assertCacheSize(t, c, MaxInFlight) | ||||||
|  | 	require.Contains(t, c.tokens, oldestTok, "oldest request should still be cached") | ||||||
|  |  | ||||||
|  | 	// Consume newest token. | ||||||
|  | 	req, ok := c.Consume(newestTok) | ||||||
|  | 	assert.True(t, ok, "newest request should still be cached") | ||||||
|  | 	assert.Equal(t, newestReq, req) | ||||||
|  | 	require.Contains(t, c.tokens, oldestTok, "oldest request should still be cached") | ||||||
|  |  | ||||||
|  | 	// Insert again (still full) | ||||||
|  | 	tok, err := c.Insert(nextRequest()) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	assert.Len(t, tok, TokenLen) | ||||||
|  | 	assertCacheSize(t, c, MaxInFlight) | ||||||
|  |  | ||||||
|  | 	// Insert again (should evict) | ||||||
|  | 	_, err = c.Insert(nextRequest()) | ||||||
|  | 	assert.Error(t, err, "should reject further requests") | ||||||
|  | 	errResponse := httptest.NewRecorder() | ||||||
|  | 	require.NoError(t, WriteError(err, errResponse)) | ||||||
|  | 	assert.Equal(t, errResponse.Code, http.StatusTooManyRequests) | ||||||
|  | 	assert.Equal(t, strconv.Itoa(int(CacheTTL.Seconds())), errResponse.HeaderMap.Get("Retry-After")) | ||||||
|  |  | ||||||
|  | 	assertCacheSize(t, c, MaxInFlight) | ||||||
|  | 	_, ok = c.Consume(oldestTok) | ||||||
|  | 	assert.True(t, ok, "oldest request should be valid") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestConsume(t *testing.T) { | ||||||
|  | 	c, clock := newTestCache() | ||||||
|  |  | ||||||
|  | 	{ // Insert & consume. | ||||||
|  | 		req := nextRequest() | ||||||
|  | 		tok, err := c.Insert(req) | ||||||
|  | 		require.NoError(t, err) | ||||||
|  | 		assertCacheSize(t, c, 1) | ||||||
|  |  | ||||||
|  | 		cachedReq, ok := c.Consume(tok) | ||||||
|  | 		assert.True(t, ok) | ||||||
|  | 		assert.Equal(t, req, cachedReq) | ||||||
|  | 		assertCacheSize(t, c, 0) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	{ // Insert & consume out of order | ||||||
|  | 		req1 := nextRequest() | ||||||
|  | 		tok1, err := c.Insert(req1) | ||||||
|  | 		require.NoError(t, err) | ||||||
|  | 		assertCacheSize(t, c, 1) | ||||||
|  |  | ||||||
|  | 		req2 := nextRequest() | ||||||
|  | 		tok2, err := c.Insert(req2) | ||||||
|  | 		require.NoError(t, err) | ||||||
|  | 		assertCacheSize(t, c, 2) | ||||||
|  |  | ||||||
|  | 		cachedReq2, ok := c.Consume(tok2) | ||||||
|  | 		assert.True(t, ok) | ||||||
|  | 		assert.Equal(t, req2, cachedReq2) | ||||||
|  | 		assertCacheSize(t, c, 1) | ||||||
|  |  | ||||||
|  | 		cachedReq1, ok := c.Consume(tok1) | ||||||
|  | 		assert.True(t, ok) | ||||||
|  | 		assert.Equal(t, req1, cachedReq1) | ||||||
|  | 		assertCacheSize(t, c, 0) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	{ // Consume a second time | ||||||
|  | 		req := nextRequest() | ||||||
|  | 		tok, err := c.Insert(req) | ||||||
|  | 		require.NoError(t, err) | ||||||
|  | 		assertCacheSize(t, c, 1) | ||||||
|  |  | ||||||
|  | 		cachedReq, ok := c.Consume(tok) | ||||||
|  | 		assert.True(t, ok) | ||||||
|  | 		assert.Equal(t, req, cachedReq) | ||||||
|  | 		assertCacheSize(t, c, 0) | ||||||
|  |  | ||||||
|  | 		_, ok = c.Consume(tok) | ||||||
|  | 		assert.False(t, ok) | ||||||
|  | 		assertCacheSize(t, c, 0) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	{ // Consume without insert | ||||||
|  | 		_, ok := c.Consume("fooBAR") | ||||||
|  | 		assert.False(t, ok) | ||||||
|  | 		assertCacheSize(t, c, 0) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	{ // Consume expired | ||||||
|  | 		tok, err := c.Insert(nextRequest()) | ||||||
|  | 		require.NoError(t, err) | ||||||
|  | 		assertCacheSize(t, c, 1) | ||||||
|  |  | ||||||
|  | 		clock.Step(2 * CacheTTL) | ||||||
|  |  | ||||||
|  | 		_, ok := c.Consume(tok) | ||||||
|  | 		assert.False(t, ok) | ||||||
|  | 		assertCacheSize(t, c, 0) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestGC(t *testing.T) { | ||||||
|  | 	c, clock := newTestCache() | ||||||
|  |  | ||||||
|  | 	// When empty | ||||||
|  | 	c.gc() | ||||||
|  | 	assertCacheSize(t, c, 0) | ||||||
|  |  | ||||||
|  | 	tok1, err := c.Insert(nextRequest()) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	assertCacheSize(t, c, 1) | ||||||
|  | 	clock.Step(10 * time.Second) | ||||||
|  | 	tok2, err := c.Insert(nextRequest()) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	assertCacheSize(t, c, 2) | ||||||
|  |  | ||||||
|  | 	// expired: tok1, tok2 | ||||||
|  | 	// non-expired: tok3, tok4 | ||||||
|  | 	clock.Step(2 * CacheTTL) | ||||||
|  | 	tok3, err := c.Insert(nextRequest()) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	assertCacheSize(t, c, 1) | ||||||
|  | 	clock.Step(10 * time.Second) | ||||||
|  | 	tok4, err := c.Insert(nextRequest()) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	assertCacheSize(t, c, 2) | ||||||
|  |  | ||||||
|  | 	_, ok := c.Consume(tok1) | ||||||
|  | 	assert.False(t, ok) | ||||||
|  | 	_, ok = c.Consume(tok2) | ||||||
|  | 	assert.False(t, ok) | ||||||
|  | 	_, ok = c.Consume(tok3) | ||||||
|  | 	assert.True(t, ok) | ||||||
|  | 	_, ok = c.Consume(tok4) | ||||||
|  | 	assert.True(t, ok) | ||||||
|  |  | ||||||
|  | 	// When full, nothing is expired. | ||||||
|  | 	for i := 0; i < MaxInFlight; i++ { | ||||||
|  | 		_, err := c.Insert(nextRequest()) | ||||||
|  | 		require.NoError(t, err) | ||||||
|  | 	} | ||||||
|  | 	assertCacheSize(t, c, MaxInFlight) | ||||||
|  |  | ||||||
|  | 	// When everything is expired | ||||||
|  | 	clock.Step(2 * CacheTTL) | ||||||
|  | 	_, err = c.Insert(nextRequest()) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	assertCacheSize(t, c, 1) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newTestCache() (*requestCache, *clock.FakeClock) { | ||||||
|  | 	c := newRequestCache() | ||||||
|  | 	fakeClock := clock.NewFakeClock(time.Now()) | ||||||
|  | 	c.clock = fakeClock | ||||||
|  | 	return c, fakeClock | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func assertCacheSize(t *testing.T, cache *requestCache, expectedSize int) { | ||||||
|  | 	tokenLen := len(cache.tokens) | ||||||
|  | 	llLen := cache.ll.Len() | ||||||
|  | 	assert.Equal(t, tokenLen, llLen, "inconsistent cache size! len(tokens)=%d; len(ll)=%d", tokenLen, llLen) | ||||||
|  | 	assert.Equal(t, expectedSize, tokenLen, "unexpected cache size!") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var requestUID = 0 | ||||||
|  |  | ||||||
|  | func nextRequest() interface{} { | ||||||
|  | 	requestUID++ | ||||||
|  | 	return requestUID | ||||||
|  | } | ||||||
| @@ -25,10 +25,12 @@ import ( | |||||||
| 	"path" | 	"path" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"google.golang.org/grpc" | ||||||
|  | 	"google.golang.org/grpc/codes" | ||||||
|  |  | ||||||
| 	restful "github.com/emicklei/go-restful" | 	restful "github.com/emicklei/go-restful" | ||||||
|  |  | ||||||
| 	"k8s.io/apimachinery/pkg/types" | 	"k8s.io/apimachinery/pkg/types" | ||||||
| 	"k8s.io/kubernetes/pkg/api" |  | ||||||
| 	runtimeapi "k8s.io/kubernetes/pkg/kubelet/api/v1alpha1/runtime" | 	runtimeapi "k8s.io/kubernetes/pkg/kubelet/api/v1alpha1/runtime" | ||||||
| 	"k8s.io/kubernetes/pkg/kubelet/server/portforward" | 	"k8s.io/kubernetes/pkg/kubelet/server/portforward" | ||||||
| 	"k8s.io/kubernetes/pkg/kubelet/server/remotecommand" | 	"k8s.io/kubernetes/pkg/kubelet/server/remotecommand" | ||||||
| @@ -97,6 +99,7 @@ func NewServer(config Config, runtime Runtime) (Server, error) { | |||||||
| 	s := &server{ | 	s := &server{ | ||||||
| 		config:  config, | 		config:  config, | ||||||
| 		runtime: &criAdapter{runtime}, | 		runtime: &criAdapter{runtime}, | ||||||
|  | 		cache:   newRequestCache(), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if s.config.BaseURL == nil { | 	if s.config.BaseURL == nil { | ||||||
| @@ -114,9 +117,9 @@ func NewServer(config Config, runtime Runtime) (Server, error) { | |||||||
| 		path    string | 		path    string | ||||||
| 		handler restful.RouteFunction | 		handler restful.RouteFunction | ||||||
| 	}{ | 	}{ | ||||||
| 		{"/exec/{containerID}", s.serveExec}, | 		{"/exec/{token}", s.serveExec}, | ||||||
| 		{"/attach/{containerID}", s.serveAttach}, | 		{"/attach/{token}", s.serveAttach}, | ||||||
| 		{"/portforward/{podSandboxID}", s.servePortForward}, | 		{"/portforward/{token}", s.servePortForward}, | ||||||
| 	} | 	} | ||||||
| 	// If serving relative to a base path, set that here. | 	// If serving relative to a base path, set that here. | ||||||
| 	pathPrefix := path.Dir(s.config.BaseURL.Path) | 	pathPrefix := path.Dir(s.config.BaseURL.Path) | ||||||
| @@ -139,37 +142,45 @@ type server struct { | |||||||
| 	config  Config | 	config  Config | ||||||
| 	runtime *criAdapter | 	runtime *criAdapter | ||||||
| 	handler http.Handler | 	handler http.Handler | ||||||
|  | 	cache   *requestCache | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *server) GetExec(req *runtimeapi.ExecRequest) (*runtimeapi.ExecResponse, error) { | func (s *server) GetExec(req *runtimeapi.ExecRequest) (*runtimeapi.ExecResponse, error) { | ||||||
| 	url := s.buildURL("exec", req.GetContainerId(), streamOpts{ | 	if req.GetContainerId() == "" { | ||||||
| 		stdin:   req.GetStdin(), | 		return nil, grpc.Errorf(codes.InvalidArgument, "missing required container_id") | ||||||
| 		stdout:  true, | 	} | ||||||
| 		stderr:  !req.GetTty(), // For TTY connections, both stderr is combined with stdout. | 	token, err := s.cache.Insert(req) | ||||||
| 		tty:     req.GetTty(), | 	if err != nil { | ||||||
| 		command: req.GetCmd(), | 		return nil, err | ||||||
| 	}) | 	} | ||||||
| 	return &runtimeapi.ExecResponse{ | 	return &runtimeapi.ExecResponse{ | ||||||
| 		Url: &url, | 		Url: s.buildURL("exec", token), | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *server) GetAttach(req *runtimeapi.AttachRequest) (*runtimeapi.AttachResponse, error) { | func (s *server) GetAttach(req *runtimeapi.AttachRequest) (*runtimeapi.AttachResponse, error) { | ||||||
| 	url := s.buildURL("attach", req.GetContainerId(), streamOpts{ | 	if req.GetContainerId() == "" { | ||||||
| 		stdin:  req.GetStdin(), | 		return nil, grpc.Errorf(codes.InvalidArgument, "missing required container_id") | ||||||
| 		stdout: true, | 	} | ||||||
| 		stderr: !req.GetTty(), // For TTY connections, both stderr is combined with stdout. | 	token, err := s.cache.Insert(req) | ||||||
| 		tty:    req.GetTty(), | 	if err != nil { | ||||||
| 	}) | 		return nil, err | ||||||
|  | 	} | ||||||
| 	return &runtimeapi.AttachResponse{ | 	return &runtimeapi.AttachResponse{ | ||||||
| 		Url: &url, | 		Url: s.buildURL("attach", token), | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *server) GetPortForward(req *runtimeapi.PortForwardRequest) (*runtimeapi.PortForwardResponse, error) { | func (s *server) GetPortForward(req *runtimeapi.PortForwardRequest) (*runtimeapi.PortForwardResponse, error) { | ||||||
| 	url := s.buildURL("portforward", req.GetPodSandboxId(), streamOpts{}) | 	if req.GetPodSandboxId() == "" { | ||||||
|  | 		return nil, grpc.Errorf(codes.InvalidArgument, "missing required pod_sandbox_id") | ||||||
|  | 	} | ||||||
|  | 	token, err := s.cache.Insert(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
| 	return &runtimeapi.PortForwardResponse{ | 	return &runtimeapi.PortForwardResponse{ | ||||||
| 		Url: &url, | 		Url: s.buildURL("portforward", token), | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -200,63 +211,32 @@ func (s *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |||||||
| 	s.handler.ServeHTTP(w, r) | 	s.handler.ServeHTTP(w, r) | ||||||
| } | } | ||||||
|  |  | ||||||
| type streamOpts struct { | func (s *server) buildURL(method, token string) *string { | ||||||
| 	stdin  bool | 	loc := s.config.BaseURL.ResolveReference(&url.URL{ | ||||||
| 	stdout bool | 		Path: path.Join(method, token), | ||||||
| 	stderr bool | 	}).String() | ||||||
| 	tty    bool | 	return &loc | ||||||
|  |  | ||||||
| 	command []string |  | ||||||
| 	port    []int32 |  | ||||||
| } |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	urlParamStdin   = api.ExecStdinParam |  | ||||||
| 	urlParamStdout  = api.ExecStdoutParam |  | ||||||
| 	urlParamStderr  = api.ExecStderrParam |  | ||||||
| 	urlParamTTY     = api.ExecTTYParam |  | ||||||
| 	urlParamCommand = api.ExecCommandParamm |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func (s *server) buildURL(method, id string, opts streamOpts) string { |  | ||||||
| 	loc := &url.URL{ |  | ||||||
| 		Path: path.Join(method, id), |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	query := url.Values{} |  | ||||||
| 	if opts.stdin { |  | ||||||
| 		query.Add(urlParamStdin, "1") |  | ||||||
| 	} |  | ||||||
| 	if opts.stdout { |  | ||||||
| 		query.Add(urlParamStdout, "1") |  | ||||||
| 	} |  | ||||||
| 	if opts.stderr { |  | ||||||
| 		query.Add(urlParamStderr, "1") |  | ||||||
| 	} |  | ||||||
| 	if opts.tty { |  | ||||||
| 		query.Add(urlParamTTY, "1") |  | ||||||
| 	} |  | ||||||
| 	for _, c := range opts.command { |  | ||||||
| 		query.Add(urlParamCommand, c) |  | ||||||
| 	} |  | ||||||
| 	loc.RawQuery = query.Encode() |  | ||||||
|  |  | ||||||
| 	return s.config.BaseURL.ResolveReference(loc).String() |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *server) serveExec(req *restful.Request, resp *restful.Response) { | func (s *server) serveExec(req *restful.Request, resp *restful.Response) { | ||||||
| 	containerID := req.PathParameter("containerID") | 	token := req.PathParameter("token") | ||||||
| 	if containerID == "" { | 	cachedRequest, ok := s.cache.Consume(token) | ||||||
| 		resp.WriteError(http.StatusBadRequest, errors.New("missing required containerID path parameter")) | 	if !ok { | ||||||
|  | 		http.NotFound(resp.ResponseWriter, req.Request) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	exec, ok := cachedRequest.(*runtimeapi.ExecRequest) | ||||||
|  | 	if !ok { | ||||||
|  | 		http.NotFound(resp.ResponseWriter, req.Request) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	streamOpts, err := remotecommand.NewOptions(req.Request) | 	streamOpts := &remotecommand.Options{ | ||||||
| 	if err != nil { | 		Stdin:  exec.GetStdin(), | ||||||
| 		resp.WriteError(http.StatusBadRequest, err) | 		Stdout: true, | ||||||
| 		return | 		Stderr: !exec.GetTty(), | ||||||
|  | 		TTY:    exec.GetTty(), | ||||||
| 	} | 	} | ||||||
| 	cmd := req.Request.URL.Query()[api.ExecCommandParamm] |  | ||||||
|  |  | ||||||
| 	remotecommand.ServeExec( | 	remotecommand.ServeExec( | ||||||
| 		resp.ResponseWriter, | 		resp.ResponseWriter, | ||||||
| @@ -264,8 +244,8 @@ func (s *server) serveExec(req *restful.Request, resp *restful.Response) { | |||||||
| 		s.runtime, | 		s.runtime, | ||||||
| 		"", // unused: podName | 		"", // unused: podName | ||||||
| 		"", // unusued: podUID | 		"", // unusued: podUID | ||||||
| 		containerID, | 		exec.GetContainerId(), | ||||||
| 		cmd, | 		exec.GetCmd(), | ||||||
| 		streamOpts, | 		streamOpts, | ||||||
| 		s.config.StreamIdleTimeout, | 		s.config.StreamIdleTimeout, | ||||||
| 		s.config.StreamCreationTimeout, | 		s.config.StreamCreationTimeout, | ||||||
| @@ -273,25 +253,31 @@ func (s *server) serveExec(req *restful.Request, resp *restful.Response) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (s *server) serveAttach(req *restful.Request, resp *restful.Response) { | func (s *server) serveAttach(req *restful.Request, resp *restful.Response) { | ||||||
| 	containerID := req.PathParameter("containerID") | 	token := req.PathParameter("token") | ||||||
| 	if containerID == "" { | 	cachedRequest, ok := s.cache.Consume(token) | ||||||
| 		resp.WriteError(http.StatusBadRequest, errors.New("missing required containerID path parameter")) | 	if !ok { | ||||||
|  | 		http.NotFound(resp.ResponseWriter, req.Request) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	attach, ok := cachedRequest.(*runtimeapi.AttachRequest) | ||||||
|  | 	if !ok { | ||||||
|  | 		http.NotFound(resp.ResponseWriter, req.Request) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	streamOpts, err := remotecommand.NewOptions(req.Request) | 	streamOpts := &remotecommand.Options{ | ||||||
| 	if err != nil { | 		Stdin:  attach.GetStdin(), | ||||||
| 		resp.WriteError(http.StatusBadRequest, err) | 		Stdout: true, | ||||||
| 		return | 		Stderr: !attach.GetTty(), | ||||||
|  | 		TTY:    attach.GetTty(), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	remotecommand.ServeAttach( | 	remotecommand.ServeAttach( | ||||||
| 		resp.ResponseWriter, | 		resp.ResponseWriter, | ||||||
| 		req.Request, | 		req.Request, | ||||||
| 		s.runtime, | 		s.runtime, | ||||||
| 		"", // unused: podName | 		"", // unused: podName | ||||||
| 		"", // unusued: podUID | 		"", // unusued: podUID | ||||||
| 		containerID, | 		attach.GetContainerId(), | ||||||
| 		streamOpts, | 		streamOpts, | ||||||
| 		s.config.StreamIdleTimeout, | 		s.config.StreamIdleTimeout, | ||||||
| 		s.config.StreamCreationTimeout, | 		s.config.StreamCreationTimeout, | ||||||
| @@ -299,9 +285,15 @@ func (s *server) serveAttach(req *restful.Request, resp *restful.Response) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (s *server) servePortForward(req *restful.Request, resp *restful.Response) { | func (s *server) servePortForward(req *restful.Request, resp *restful.Response) { | ||||||
| 	podSandboxID := req.PathParameter("podSandboxID") | 	token := req.PathParameter("token") | ||||||
| 	if podSandboxID == "" { | 	cachedRequest, ok := s.cache.Consume(token) | ||||||
| 		resp.WriteError(http.StatusBadRequest, errors.New("missing required podSandboxID path parameter")) | 	if !ok { | ||||||
|  | 		http.NotFound(resp.ResponseWriter, req.Request) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	pf, ok := cachedRequest.(*runtimeapi.PortForwardRequest) | ||||||
|  | 	if !ok { | ||||||
|  | 		http.NotFound(resp.ResponseWriter, req.Request) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -309,7 +301,7 @@ func (s *server) servePortForward(req *restful.Request, resp *restful.Response) | |||||||
| 		resp.ResponseWriter, | 		resp.ResponseWriter, | ||||||
| 		req.Request, | 		req.Request, | ||||||
| 		s.runtime, | 		s.runtime, | ||||||
| 		podSandboxID, | 		pf.GetPodSandboxId(), | ||||||
| 		"", // unused: podUID | 		"", // unused: podUID | ||||||
| 		s.config.StreamIdleTimeout, | 		s.config.StreamIdleTimeout, | ||||||
| 		s.config.StreamCreationTimeout) | 		s.config.StreamCreationTimeout) | ||||||
|   | |||||||
| @@ -18,12 +18,12 @@ package streaming | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"crypto/tls" | 	"crypto/tls" | ||||||
| 	"fmt" |  | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"strconv" | 	"strconv" | ||||||
|  | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| @@ -46,18 +46,18 @@ const ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestGetExec(t *testing.T) { | func TestGetExec(t *testing.T) { | ||||||
| 	testcases := []struct { | 	type testcase struct { | ||||||
| 		cmd           []string | 		cmd   []string | ||||||
| 		tty           bool | 		tty   bool | ||||||
| 		stdin         bool | 		stdin bool | ||||||
| 		expectedQuery string |  | ||||||
| 	}{ |  | ||||||
| 		{[]string{"echo", "foo"}, false, false, "?command=echo&command=foo&error=1&output=1"}, |  | ||||||
| 		{[]string{"date"}, true, false, "?command=date&output=1&tty=1"}, |  | ||||||
| 		{[]string{"date"}, false, true, "?command=date&error=1&input=1&output=1"}, |  | ||||||
| 		{[]string{"date"}, true, true, "?command=date&input=1&output=1&tty=1"}, |  | ||||||
| 	} | 	} | ||||||
| 	server, err := NewServer(Config{ | 	testcases := []testcase{ | ||||||
|  | 		{[]string{"echo", "foo"}, false, false}, | ||||||
|  | 		{[]string{"date"}, true, false}, | ||||||
|  | 		{[]string{"date"}, false, true}, | ||||||
|  | 		{[]string{"date"}, true, true}, | ||||||
|  | 	} | ||||||
|  | 	serv, err := NewServer(Config{ | ||||||
| 		Addr: testAddr, | 		Addr: testAddr, | ||||||
| 	}, nil) | 	}, nil) | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
| @@ -79,6 +79,14 @@ func TestGetExec(t *testing.T) { | |||||||
| 	}, nil) | 	}, nil) | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	assertRequestToken := func(test testcase, cache *requestCache, token string) { | ||||||
|  | 		req, ok := cache.Consume(token) | ||||||
|  | 		require.True(t, ok, "token %s not found! testcase=%+v", token, test) | ||||||
|  | 		assert.Equal(t, testContainerID, req.(*runtimeapi.ExecRequest).GetContainerId(), "testcase=%+v", test) | ||||||
|  | 		assert.Equal(t, test.cmd, req.(*runtimeapi.ExecRequest).GetCmd(), "testcase=%+v", test) | ||||||
|  | 		assert.Equal(t, test.tty, req.(*runtimeapi.ExecRequest).GetTty(), "testcase=%+v", test) | ||||||
|  | 		assert.Equal(t, test.stdin, req.(*runtimeapi.ExecRequest).GetStdin(), "testcase=%+v", test) | ||||||
|  | 	} | ||||||
| 	containerID := testContainerID | 	containerID := testContainerID | ||||||
| 	for _, test := range testcases { | 	for _, test := range testcases { | ||||||
| 		request := &runtimeapi.ExecRequest{ | 		request := &runtimeapi.ExecRequest{ | ||||||
| @@ -87,38 +95,47 @@ func TestGetExec(t *testing.T) { | |||||||
| 			Tty:         &test.tty, | 			Tty:         &test.tty, | ||||||
| 			Stdin:       &test.stdin, | 			Stdin:       &test.stdin, | ||||||
| 		} | 		} | ||||||
| 		// Non-TLS | 		{ // Non-TLS | ||||||
| 		resp, err := server.GetExec(request) | 			resp, err := serv.GetExec(request) | ||||||
| 		assert.NoError(t, err, "testcase=%+v", test) | 			assert.NoError(t, err, "testcase=%+v", test) | ||||||
| 		expectedURL := "http://" + testAddr + "/exec/" + testContainerID + test.expectedQuery | 			expectedURL := "http://" + testAddr + "/exec/" | ||||||
| 		assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test) | 			assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test) | ||||||
|  | 			token := strings.TrimPrefix(resp.GetUrl(), expectedURL) | ||||||
|  | 			assertRequestToken(test, serv.(*server).cache, token) | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		// TLS | 		{ // TLS | ||||||
| 		resp, err = tlsServer.GetExec(request) | 			resp, err := tlsServer.GetExec(request) | ||||||
| 		assert.NoError(t, err, "testcase=%+v", test) | 			assert.NoError(t, err, "testcase=%+v", test) | ||||||
| 		expectedURL = "https://" + testAddr + "/exec/" + testContainerID + test.expectedQuery | 			expectedURL := "https://" + testAddr + "/exec/" | ||||||
| 		assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test) | 			assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test) | ||||||
|  | 			token := strings.TrimPrefix(resp.GetUrl(), expectedURL) | ||||||
|  | 			assertRequestToken(test, tlsServer.(*server).cache, token) | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		// Path prefix | 		{ // Path prefix | ||||||
| 		resp, err = prefixServer.GetExec(request) | 			resp, err := prefixServer.GetExec(request) | ||||||
| 		assert.NoError(t, err, "testcase=%+v", test) | 			assert.NoError(t, err, "testcase=%+v", test) | ||||||
| 		expectedURL = "http://" + testAddr + "/" + pathPrefix + "/exec/" + testContainerID + test.expectedQuery | 			expectedURL := "http://" + testAddr + "/" + pathPrefix + "/exec/" | ||||||
| 		assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test) | 			assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test) | ||||||
|  | 			token := strings.TrimPrefix(resp.GetUrl(), expectedURL) | ||||||
|  | 			assertRequestToken(test, prefixServer.(*server).cache, token) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestGetAttach(t *testing.T) { | func TestGetAttach(t *testing.T) { | ||||||
| 	testcases := []struct { | 	type testcase struct { | ||||||
| 		tty           bool | 		tty   bool | ||||||
| 		stdin         bool | 		stdin bool | ||||||
| 		expectedQuery string |  | ||||||
| 	}{ |  | ||||||
| 		{false, false, "?error=1&output=1"}, |  | ||||||
| 		{true, false, "?output=1&tty=1"}, |  | ||||||
| 		{false, true, "?error=1&input=1&output=1"}, |  | ||||||
| 		{true, true, "?input=1&output=1&tty=1"}, |  | ||||||
| 	} | 	} | ||||||
| 	server, err := NewServer(Config{ | 	testcases := []testcase{ | ||||||
|  | 		{false, false}, | ||||||
|  | 		{true, false}, | ||||||
|  | 		{false, true}, | ||||||
|  | 		{true, true}, | ||||||
|  | 	} | ||||||
|  | 	serv, err := NewServer(Config{ | ||||||
| 		Addr: testAddr, | 		Addr: testAddr, | ||||||
| 	}, nil) | 	}, nil) | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
| @@ -129,6 +146,13 @@ func TestGetAttach(t *testing.T) { | |||||||
| 	}, nil) | 	}, nil) | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	assertRequestToken := func(test testcase, cache *requestCache, token string) { | ||||||
|  | 		req, ok := cache.Consume(token) | ||||||
|  | 		require.True(t, ok, "token %s not found! testcase=%+v", token, test) | ||||||
|  | 		assert.Equal(t, testContainerID, req.(*runtimeapi.AttachRequest).GetContainerId(), "testcase=%+v", test) | ||||||
|  | 		assert.Equal(t, test.tty, req.(*runtimeapi.AttachRequest).GetTty(), "testcase=%+v", test) | ||||||
|  | 		assert.Equal(t, test.stdin, req.(*runtimeapi.AttachRequest).GetStdin(), "testcase=%+v", test) | ||||||
|  | 	} | ||||||
| 	containerID := testContainerID | 	containerID := testContainerID | ||||||
| 	for _, test := range testcases { | 	for _, test := range testcases { | ||||||
| 		request := &runtimeapi.AttachRequest{ | 		request := &runtimeapi.AttachRequest{ | ||||||
| @@ -136,17 +160,23 @@ func TestGetAttach(t *testing.T) { | |||||||
| 			Stdin:       &test.stdin, | 			Stdin:       &test.stdin, | ||||||
| 			Tty:         &test.tty, | 			Tty:         &test.tty, | ||||||
| 		} | 		} | ||||||
| 		// Non-TLS | 		{ // Non-TLS | ||||||
| 		resp, err := server.GetAttach(request) | 			resp, err := serv.GetAttach(request) | ||||||
| 		assert.NoError(t, err, "testcase=%+v", test) | 			assert.NoError(t, err, "testcase=%+v", test) | ||||||
| 		expectedURL := "http://" + testAddr + "/attach/" + testContainerID + test.expectedQuery | 			expectedURL := "http://" + testAddr + "/attach/" | ||||||
| 		assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test) | 			assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test) | ||||||
|  | 			token := strings.TrimPrefix(resp.GetUrl(), expectedURL) | ||||||
|  | 			assertRequestToken(test, serv.(*server).cache, token) | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		// TLS | 		{ // TLS | ||||||
| 		resp, err = tlsServer.GetAttach(request) | 			resp, err := tlsServer.GetAttach(request) | ||||||
| 		assert.NoError(t, err, "testcase=%+v", test) | 			assert.NoError(t, err, "testcase=%+v", test) | ||||||
| 		expectedURL = "https://" + testAddr + "/attach/" + testContainerID + test.expectedQuery | 			expectedURL := "https://" + testAddr + "/attach/" | ||||||
| 		assert.Equal(t, expectedURL, resp.GetUrl(), "testcase=%+v", test) | 			assert.Contains(t, resp.GetUrl(), expectedURL, "testcase=%+v", test) | ||||||
|  | 			token := strings.TrimPrefix(resp.GetUrl(), expectedURL) | ||||||
|  | 			assertRequestToken(test, tlsServer.(*server).cache, token) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -157,26 +187,36 @@ func TestGetPortForward(t *testing.T) { | |||||||
| 		Port:         []int32{1, 2, 3, 4}, | 		Port:         []int32{1, 2, 3, 4}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Non-TLS | 	{ // Non-TLS | ||||||
| 	server, err := NewServer(Config{ | 		serv, err := NewServer(Config{ | ||||||
| 		Addr: testAddr, | 			Addr: testAddr, | ||||||
| 	}, nil) | 		}, nil) | ||||||
| 	assert.NoError(t, err) | 		assert.NoError(t, err) | ||||||
| 	resp, err := server.GetPortForward(request) | 		resp, err := serv.GetPortForward(request) | ||||||
| 	assert.NoError(t, err) | 		assert.NoError(t, err) | ||||||
| 	expectedURL := "http://" + testAddr + "/portforward/" + testPodSandboxID | 		expectedURL := "http://" + testAddr + "/portforward/" | ||||||
| 	assert.Equal(t, expectedURL, resp.GetUrl()) | 		assert.True(t, strings.HasPrefix(resp.GetUrl(), expectedURL)) | ||||||
|  | 		token := strings.TrimPrefix(resp.GetUrl(), expectedURL) | ||||||
|  | 		req, ok := serv.(*server).cache.Consume(token) | ||||||
|  | 		require.True(t, ok, "token %s not found!", token) | ||||||
|  | 		assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).GetPodSandboxId()) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// TLS | 	{ // TLS | ||||||
| 	tlsServer, err := NewServer(Config{ | 		tlsServer, err := NewServer(Config{ | ||||||
| 		Addr:      testAddr, | 			Addr:      testAddr, | ||||||
| 		TLSConfig: &tls.Config{}, | 			TLSConfig: &tls.Config{}, | ||||||
| 	}, nil) | 		}, nil) | ||||||
| 	assert.NoError(t, err) | 		assert.NoError(t, err) | ||||||
| 	resp, err = tlsServer.GetPortForward(request) | 		resp, err := tlsServer.GetPortForward(request) | ||||||
| 	assert.NoError(t, err) | 		assert.NoError(t, err) | ||||||
| 	expectedURL = "https://" + testAddr + "/portforward/" + testPodSandboxID | 		expectedURL := "https://" + testAddr + "/portforward/" | ||||||
| 	assert.Equal(t, expectedURL, resp.GetUrl()) | 		assert.True(t, strings.HasPrefix(resp.GetUrl(), expectedURL)) | ||||||
|  | 		token := strings.TrimPrefix(resp.GetUrl(), expectedURL) | ||||||
|  | 		req, ok := tlsServer.(*server).cache.Consume(token) | ||||||
|  | 		require.True(t, ok, "token %s not found!", token) | ||||||
|  | 		assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).GetPodSandboxId()) | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestServeExec(t *testing.T) { | func TestServeExec(t *testing.T) { | ||||||
| @@ -188,21 +228,18 @@ func TestServeAttach(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestServePortForward(t *testing.T) { | func TestServePortForward(t *testing.T) { | ||||||
| 	rt := newFakeRuntime(t) | 	s, testServer := startTestServer(t) | ||||||
| 	s, err := NewServer(DefaultConfig, rt) |  | ||||||
| 	require.NoError(t, err) |  | ||||||
| 	testServer := httptest.NewServer(s) |  | ||||||
| 	defer testServer.Close() | 	defer testServer.Close() | ||||||
|  |  | ||||||
| 	testURL, err := url.Parse(testServer.URL) | 	podSandboxID := testPodSandboxID | ||||||
|  | 	resp, err := s.GetPortForward(&runtimeapi.PortForwardRequest{ | ||||||
|  | 		PodSandboxId: &podSandboxID, | ||||||
|  | 	}) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	reqURL, err := url.Parse(resp.GetUrl()) | ||||||
| 	require.NoError(t, err) | 	require.NoError(t, err) | ||||||
| 	loc := &url.URL{ |  | ||||||
| 		Scheme: testURL.Scheme, |  | ||||||
| 		Host:   testURL.Host, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	loc.Path = fmt.Sprintf("/%s/%s", "portforward", testPodSandboxID) | 	exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", reqURL) | ||||||
| 	exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", loc) |  | ||||||
| 	require.NoError(t, err) | 	require.NoError(t, err) | ||||||
| 	streamConn, _, err := exec.Dial(kubeletportforward.PortForwardProtocolV1Name) | 	streamConn, _, err := exec.Dial(kubeletportforward.PortForwardProtocolV1Name) | ||||||
| 	require.NoError(t, err) | 	require.NoError(t, err) | ||||||
| @@ -227,22 +264,30 @@ func TestServePortForward(t *testing.T) { | |||||||
| // Run the remote command test. | // Run the remote command test. | ||||||
| // commandType is either "exec" or "attach". | // commandType is either "exec" or "attach". | ||||||
| func runRemoteCommandTest(t *testing.T, commandType string) { | func runRemoteCommandTest(t *testing.T, commandType string) { | ||||||
| 	rt := newFakeRuntime(t) | 	s, testServer := startTestServer(t) | ||||||
| 	s, err := NewServer(DefaultConfig, rt) |  | ||||||
| 	require.NoError(t, err) |  | ||||||
| 	testServer := httptest.NewServer(s) |  | ||||||
| 	defer testServer.Close() | 	defer testServer.Close() | ||||||
|  |  | ||||||
| 	testURL, err := url.Parse(testServer.URL) | 	var reqURL *url.URL | ||||||
| 	require.NoError(t, err) | 	stdin := true | ||||||
| 	query := url.Values{} | 	containerID := testContainerID | ||||||
| 	query.Add(urlParamStdin, "1") | 	switch commandType { | ||||||
| 	query.Add(urlParamStdout, "1") | 	case "exec": | ||||||
| 	query.Add(urlParamStderr, "1") | 		resp, err := s.GetExec(&runtimeapi.ExecRequest{ | ||||||
| 	loc := &url.URL{ | 			ContainerId: &containerID, | ||||||
| 		Scheme:   testURL.Scheme, | 			Cmd:         []string{"echo"}, | ||||||
| 		Host:     testURL.Host, | 			Stdin:       &stdin, | ||||||
| 		RawQuery: query.Encode(), | 		}) | ||||||
|  | 		require.NoError(t, err) | ||||||
|  | 		reqURL, err = url.Parse(resp.GetUrl()) | ||||||
|  | 		require.NoError(t, err) | ||||||
|  | 	case "attach": | ||||||
|  | 		resp, err := s.GetAttach(&runtimeapi.AttachRequest{ | ||||||
|  | 			ContainerId: &containerID, | ||||||
|  | 			Stdin:       &stdin, | ||||||
|  | 		}) | ||||||
|  | 		require.NoError(t, err) | ||||||
|  | 		reqURL, err = url.Parse(resp.GetUrl()) | ||||||
|  | 		require.NoError(t, err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	wg := sync.WaitGroup{} | 	wg := sync.WaitGroup{} | ||||||
| @@ -254,8 +299,7 @@ func runRemoteCommandTest(t *testing.T, commandType string) { | |||||||
|  |  | ||||||
| 	go func() { | 	go func() { | ||||||
| 		defer wg.Done() | 		defer wg.Done() | ||||||
| 		loc.Path = fmt.Sprintf("/%s/%s", commandType, testContainerID) | 		exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", reqURL) | ||||||
| 		exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", loc) |  | ||||||
| 		require.NoError(t, err) | 		require.NoError(t, err) | ||||||
|  |  | ||||||
| 		opts := remotecommand.StreamOptions{ | 		opts := remotecommand.StreamOptions{ | ||||||
| @@ -275,6 +319,36 @@ func runRemoteCommandTest(t *testing.T, commandType string) { | |||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	wg.Wait() | 	wg.Wait() | ||||||
|  |  | ||||||
|  | 	// Repeat request with the same URL should be a 404. | ||||||
|  | 	resp, err := http.Get(reqURL.String()) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	assert.Equal(t, http.StatusNotFound, resp.StatusCode) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func startTestServer(t *testing.T) (Server, *httptest.Server) { | ||||||
|  | 	var s Server | ||||||
|  | 	testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 		s.ServeHTTP(w, r) | ||||||
|  | 	})) | ||||||
|  | 	cleanup := true | ||||||
|  | 	defer func() { | ||||||
|  | 		if cleanup { | ||||||
|  | 			testServer.Close() | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  |  | ||||||
|  | 	testURL, err := url.Parse(testServer.URL) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	rt := newFakeRuntime(t) | ||||||
|  | 	config := DefaultConfig | ||||||
|  | 	config.BaseURL = testURL | ||||||
|  | 	s, err = NewServer(config, rt) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	cleanup = false // Caller must close the test server. | ||||||
|  | 	return s, testServer | ||||||
| } | } | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
|   | |||||||
| @@ -391,6 +391,14 @@ var _ = framework.KubeDescribe("Kubectl client", func() { | |||||||
| 				framework.Failf("Unexpected kubectl exec output. Wanted %q, got %q", e, a) | 				framework.Failf("Unexpected kubectl exec output. Wanted %q, got %q", e, a) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | 			By("executing a very long command in the container") | ||||||
|  | 			veryLongData := make([]rune, 20000) | ||||||
|  | 			for i := 0; i < len(veryLongData); i++ { | ||||||
|  | 				veryLongData[i] = 'a' | ||||||
|  | 			} | ||||||
|  | 			execOutput = framework.RunKubectlOrDie("exec", fmt.Sprintf("--namespace=%v", ns), simplePodName, "echo", string(veryLongData)) | ||||||
|  | 			Expect(string(veryLongData)).To(Equal(strings.TrimSpace(execOutput)), "Unexpected kubectl exec output") | ||||||
|  |  | ||||||
| 			By("executing a command in the container with noninteractive stdin") | 			By("executing a command in the container with noninteractive stdin") | ||||||
| 			execOutput = framework.NewKubectlCommand("exec", fmt.Sprintf("--namespace=%v", ns), "-i", simplePodName, "cat"). | 			execOutput = framework.NewKubectlCommand("exec", fmt.Sprintf("--namespace=%v", ns), "-i", simplePodName, "cat"). | ||||||
| 				WithStdinData("abcd1234"). | 				WithStdinData("abcd1234"). | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Tim St. Clair
					Tim St. Clair