mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-02 03:27:54 +00:00
Cleanly close SSH connections
This commit is contained in:
@@ -57,6 +57,16 @@ func SSHCommNew(address string, config *SSHCommConfig) (result *comm, err error)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *comm) Close() error {
|
||||
var err error
|
||||
if c.conn != nil {
|
||||
err = c.conn.Close()
|
||||
}
|
||||
c.conn = nil
|
||||
c.client = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *comm) Upload(path string, input io.Reader, fi *os.FileInfo) error {
|
||||
// The target directory and file for talking the SCP protocol
|
||||
target_dir := filepath.Dir(path)
|
||||
@@ -74,7 +84,7 @@ func (c *comm) Upload(path string, input io.Reader, fi *os.FileInfo) error {
|
||||
return c.scpSession("scp -vt "+target_dir, scpFunc)
|
||||
}
|
||||
|
||||
func (c *comm) newSession() (session *ssh.Session, err error) {
|
||||
func (c *comm) NewSession() (session *ssh.Session, err error) {
|
||||
if c.client == nil {
|
||||
err = errors.New("client not available")
|
||||
} else {
|
||||
@@ -93,15 +103,13 @@ func (c *comm) newSession() (session *ssh.Session, err error) {
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (c *comm) reconnect() (err error) {
|
||||
func (c *comm) reconnect() error {
|
||||
// Close previous connection.
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.Close()
|
||||
}
|
||||
|
||||
// Set the conn and client to nil since we'll recreate it
|
||||
c.conn = nil
|
||||
c.client = nil
|
||||
|
||||
var err error
|
||||
c.conn, err = c.config.Connection()
|
||||
if err != nil {
|
||||
// Explicitly set this to the REAL nil. Connection() can return
|
||||
@@ -112,19 +120,21 @@ func (c *comm) reconnect() (err error) {
|
||||
// http://golang.org/doc/faq#nil_error
|
||||
c.conn = nil
|
||||
log.Printf("reconnection error: %s", err)
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig)
|
||||
if err != nil {
|
||||
log.Printf("handshake error: %s", err)
|
||||
c.Close()
|
||||
return err
|
||||
}
|
||||
if sshConn != nil {
|
||||
c.client = ssh.NewClient(sshConn, sshChan, req)
|
||||
}
|
||||
c.connectToAgent()
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *comm) connectToAgent() {
|
||||
@@ -146,12 +156,12 @@ func (c *comm) connectToAgent() {
|
||||
log.Printf("[ERROR] could not connect to local agent socket: %s", socketLocation)
|
||||
return
|
||||
}
|
||||
defer agentConn.Close()
|
||||
|
||||
// create agent and add in auth
|
||||
forwardingAgent := agent.NewClient(agentConn)
|
||||
if forwardingAgent == nil {
|
||||
log.Printf("[ERROR] Could not create agent client")
|
||||
agentConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -162,7 +172,7 @@ func (c *comm) connectToAgent() {
|
||||
agent.ForwardToAgent(c.client, forwardingAgent)
|
||||
|
||||
// Setup a session to request agent forwarding
|
||||
session, err := c.newSession()
|
||||
session, err := c.NewSession()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -177,7 +187,7 @@ func (c *comm) connectToAgent() {
|
||||
}
|
||||
|
||||
func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error {
|
||||
session, err := c.newSession()
|
||||
session, err := c.NewSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user