Cleanly close SSH connections

This commit is contained in:
Chi Vinh Le
2016-01-19 07:59:08 +01:00
parent 0e2a0cd5b5
commit 555834f83d
2 changed files with 38 additions and 63 deletions

View File

@@ -57,6 +57,16 @@ func SSHCommNew(address string, config *SSHCommConfig) (result *comm, err error)
return
}
func (c *comm) Close() error {
var err error
if c.conn != nil {
err = c.conn.Close()
}
c.conn = nil
c.client = nil
return err
}
func (c *comm) Upload(path string, input io.Reader, fi *os.FileInfo) error {
// The target directory and file for talking the SCP protocol
target_dir := filepath.Dir(path)
@@ -74,7 +84,7 @@ func (c *comm) Upload(path string, input io.Reader, fi *os.FileInfo) error {
return c.scpSession("scp -vt "+target_dir, scpFunc)
}
func (c *comm) newSession() (session *ssh.Session, err error) {
func (c *comm) NewSession() (session *ssh.Session, err error) {
if c.client == nil {
err = errors.New("client not available")
} else {
@@ -93,15 +103,13 @@ func (c *comm) newSession() (session *ssh.Session, err error) {
return session, nil
}
func (c *comm) reconnect() (err error) {
func (c *comm) reconnect() error {
// Close previous connection.
if c.conn != nil {
c.conn.Close()
c.Close()
}
// Set the conn and client to nil since we'll recreate it
c.conn = nil
c.client = nil
var err error
c.conn, err = c.config.Connection()
if err != nil {
// Explicitly set this to the REAL nil. Connection() can return
@@ -112,19 +120,21 @@ func (c *comm) reconnect() (err error) {
// http://golang.org/doc/faq#nil_error
c.conn = nil
log.Printf("reconnection error: %s", err)
return
return err
}
sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig)
if err != nil {
log.Printf("handshake error: %s", err)
c.Close()
return err
}
if sshConn != nil {
c.client = ssh.NewClient(sshConn, sshChan, req)
}
c.connectToAgent()
return
return nil
}
func (c *comm) connectToAgent() {
@@ -146,12 +156,12 @@ func (c *comm) connectToAgent() {
log.Printf("[ERROR] could not connect to local agent socket: %s", socketLocation)
return
}
defer agentConn.Close()
// create agent and add in auth
forwardingAgent := agent.NewClient(agentConn)
if forwardingAgent == nil {
log.Printf("[ERROR] Could not create agent client")
agentConn.Close()
return
}
@@ -162,7 +172,7 @@ func (c *comm) connectToAgent() {
agent.ForwardToAgent(c.client, forwardingAgent)
// Setup a session to request agent forwarding
session, err := c.newSession()
session, err := c.NewSession()
if err != nil {
return
}
@@ -177,7 +187,7 @@ func (c *comm) connectToAgent() {
}
func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error {
session, err := c.newSession()
session, err := c.NewSession()
if err != nil {
return err
}