mirror of
				https://github.com/optim-enterprises-bv/kubernetes.git
				synced 2025-11-04 04:08:16 +00:00 
			
		
		
		
	Merge pull request #28942 from kubernetes/revert-28805-ssh-dial-timeout
Revert "Add a customized ssh dialer that will timeout"
This commit is contained in:
		@@ -111,7 +111,7 @@ func makeSSHTunnel(user string, signer ssh.Signer, host string) (*SSHTunnel, err
 | 
			
		||||
 | 
			
		||||
func (s *SSHTunnel) Open() error {
 | 
			
		||||
	var err error
 | 
			
		||||
	s.client, err = defaultTimeoutDialer.Dial("tcp", net.JoinHostPort(s.Host, s.SSHPort), s.Config)
 | 
			
		||||
	s.client, err = realTimeoutDialer.Dial("tcp", net.JoinHostPort(s.Host, s.SSHPort), s.Config)
 | 
			
		||||
	tunnelOpenCounter.Inc()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		tunnelOpenFailCounter.Inc()
 | 
			
		||||
@@ -154,9 +154,21 @@ type sshDialer interface {
 | 
			
		||||
	Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// timeoutDialer implements a Dial() method that will timeout. The golang
 | 
			
		||||
// Real implementation of sshDialer
 | 
			
		||||
type realSSHDialer struct{}
 | 
			
		||||
 | 
			
		||||
var _ sshDialer = &realSSHDialer{}
 | 
			
		||||
 | 
			
		||||
func (d *realSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
 | 
			
		||||
	return ssh.Dial(network, addr, config)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// timeoutDialer wraps an sshDialer with a timeout around Dial(). The golang
 | 
			
		||||
// ssh library can hang indefinitely inside the Dial() call (see issue #23835).
 | 
			
		||||
// Wrapping all Dial() calls with a conservative timeout provides safety against
 | 
			
		||||
// getting stuck on that.
 | 
			
		||||
type timeoutDialer struct {
 | 
			
		||||
	dialer  sshDialer
 | 
			
		||||
	timeout time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -164,32 +176,30 @@ type timeoutDialer struct {
 | 
			
		||||
// seconds). This timeout is only intended to catch otherwise uncaught hangs.
 | 
			
		||||
const sshDialTimeout = 150 * time.Second
 | 
			
		||||
 | 
			
		||||
var defaultTimeoutDialer sshDialer = &timeoutDialer{sshDialTimeout}
 | 
			
		||||
var realTimeoutDialer sshDialer = &timeoutDialer{&realSSHDialer{}, sshDialTimeout}
 | 
			
		||||
 | 
			
		||||
func (d *timeoutDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
 | 
			
		||||
	conn, err := net.Dial(network, addr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	conn.SetDeadline(time.Now().Add(d.timeout))
 | 
			
		||||
	// set to 0 so that conn will not time out after Dial.
 | 
			
		||||
	defer func() {
 | 
			
		||||
		conn.SetDeadline(time.Time{})
 | 
			
		||||
	var client *ssh.Client
 | 
			
		||||
	errCh := make(chan error, 1)
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer runtime.HandleCrash()
 | 
			
		||||
		var err error
 | 
			
		||||
		client, err = d.dialer.Dial(network, addr, config)
 | 
			
		||||
		errCh <- err
 | 
			
		||||
	}()
 | 
			
		||||
	// if conn times out, the NewClientConn will close it, so we will not end up
 | 
			
		||||
	// with hanging goroutines or open file descriptors.
 | 
			
		||||
	c, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	select {
 | 
			
		||||
	case err := <-errCh:
 | 
			
		||||
		return client, err
 | 
			
		||||
	case <-time.After(d.timeout):
 | 
			
		||||
		return nil, fmt.Errorf("timed out dialing %s:%s", network, addr)
 | 
			
		||||
	}
 | 
			
		||||
	return ssh.NewClient(c, chans, reqs), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RunSSHCommand returns the stdout, stderr, and exit code from running cmd on
 | 
			
		||||
// host as specific user, along with any SSH-level error.
 | 
			
		||||
// If user=="", it will default (like SSH) to os.Getenv("USER")
 | 
			
		||||
func RunSSHCommand(cmd, user, host string, signer ssh.Signer) (string, string, int, error) {
 | 
			
		||||
	return runSSHCommand(defaultTimeoutDialer, cmd, user, host, signer, true)
 | 
			
		||||
	return runSSHCommand(realTimeoutDialer, cmd, user, host, signer, true)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Internal implementation of runSSHCommand, for testing
 | 
			
		||||
 
 | 
			
		||||
@@ -329,49 +329,38 @@ func TestSSHUser(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type slowDialer struct {
 | 
			
		||||
	delay time.Duration
 | 
			
		||||
	err   error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *slowDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
 | 
			
		||||
	time.Sleep(s.delay)
 | 
			
		||||
	if s.err != nil {
 | 
			
		||||
		return nil, s.err
 | 
			
		||||
	}
 | 
			
		||||
	return &ssh.Client{}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTimeoutDialer(t *testing.T) {
 | 
			
		||||
	testCases := []struct {
 | 
			
		||||
		delay             time.Duration
 | 
			
		||||
		timeout           time.Duration
 | 
			
		||||
		err               error
 | 
			
		||||
		expectedErrString string
 | 
			
		||||
	}{
 | 
			
		||||
		// should cause ssh.Dial to timeout.
 | 
			
		||||
		{0, "i/o timeout"},
 | 
			
		||||
		// should succeed
 | 
			
		||||
		{1 * time.Second, ""},
 | 
			
		||||
		// delay > timeout should cause ssh.Dial to timeout.
 | 
			
		||||
		{1 * time.Second, 0, nil, "timed out dialing"},
 | 
			
		||||
		// delay < timeout should return the result of the call to the dialer.
 | 
			
		||||
		{0, 1 * time.Second, nil, ""},
 | 
			
		||||
		{0, 1 * time.Second, fmt.Errorf("test dial error"), "test dial error"},
 | 
			
		||||
	}
 | 
			
		||||
	for _, tc := range testCases {
 | 
			
		||||
		// setup
 | 
			
		||||
		private, _, err := GenerateKey(2048)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Errorf("unexpected error: %v", err)
 | 
			
		||||
			t.FailNow()
 | 
			
		||||
		}
 | 
			
		||||
		server, err := runTestSSHServer("foo", "bar")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Errorf("unexpected error: %v", err)
 | 
			
		||||
			t.FailNow()
 | 
			
		||||
		}
 | 
			
		||||
		privateData := EncodePrivateKey(private)
 | 
			
		||||
		tunnel, err := NewSSHTunnelFromBytes("foo", privateData, server.Host)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Errorf("unexpected error: %v", err)
 | 
			
		||||
			t.FailNow()
 | 
			
		||||
		}
 | 
			
		||||
		tunnel.SSHPort = server.Port
 | 
			
		||||
 | 
			
		||||
		// test the dialer
 | 
			
		||||
		dialer := &timeoutDialer{tc.timeout}
 | 
			
		||||
		client, err := dialer.Dial("tcp", net.JoinHostPort(tunnel.Host, tunnel.SSHPort), tunnel.Config)
 | 
			
		||||
		dialer := &timeoutDialer{&slowDialer{tc.delay, tc.err}, tc.timeout}
 | 
			
		||||
		_, err := dialer.Dial("tcp", "addr:port", &ssh.ClientConfig{})
 | 
			
		||||
		if len(tc.expectedErrString) == 0 && err != nil ||
 | 
			
		||||
			!strings.Contains(fmt.Sprint(err), tc.expectedErrString) {
 | 
			
		||||
			t.Errorf("Expected error to contain %q; got %v", tc.expectedErrString, err)
 | 
			
		||||
		}
 | 
			
		||||
		if len(tc.expectedErrString) == 0 {
 | 
			
		||||
			// verify the connection doesn't timeout after the handshake is done.
 | 
			
		||||
			time.Sleep(tc.timeout + 1*time.Second)
 | 
			
		||||
			if _, _, err := client.OpenChannel("direct-tcpip", nil); err != nil {
 | 
			
		||||
				t.Errorf("unexpected error %v", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user