mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 18:48:08 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			198 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			198 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package ssh
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"crypto/rand"
 | |
| 	"crypto/rsa"
 | |
| 	"crypto/x509"
 | |
| 	"encoding/base64"
 | |
| 	"encoding/pem"
 | |
| 	"fmt"
 | |
| 	"net"
 | |
| 	"strings"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/hashicorp/vault/logical"
 | |
| 
 | |
| 	"golang.org/x/crypto/ssh"
 | |
| )
 | |
| 
 | |
| // Creates a SSH session object which can be used to run commands
 | |
| // in the target machine. The session will use public key authentication
 | |
| // method with port 22.
 | |
| func createSSHPublicKeysSession(username, ipAddr string, port int, hostKey string) (*ssh.Session, error) {
 | |
| 	if username == "" {
 | |
| 		return nil, fmt.Errorf("missing username")
 | |
| 	}
 | |
| 	if ipAddr == "" {
 | |
| 		return nil, fmt.Errorf("missing ip address")
 | |
| 	}
 | |
| 	if hostKey == "" {
 | |
| 		return nil, fmt.Errorf("missing host key")
 | |
| 	}
 | |
| 	signer, err := ssh.ParsePrivateKey([]byte(hostKey))
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("parsing Private Key failed: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	config := &ssh.ClientConfig{
 | |
| 		User: username,
 | |
| 		Auth: []ssh.AuthMethod{
 | |
| 			ssh.PublicKeys(signer),
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", ipAddr, port), config)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	if client == nil {
 | |
| 		return nil, fmt.Errorf("invalid client object: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	session, err := client.NewSession()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return session, nil
 | |
| }
 | |
| 
 | |
| // Creates a new RSA key pair with key length of 2048.
 | |
| // The private key will be of pem format and the public key will be
 | |
| // of OpenSSH format.
 | |
| func generateRSAKeys(keyBits int) (publicKeyRsa string, privateKeyRsa string, err error) {
 | |
| 	privateKey, err := rsa.GenerateKey(rand.Reader, keyBits)
 | |
| 	if err != nil {
 | |
| 		return "", "", fmt.Errorf("error generating RSA key-pair: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	privateKeyRsa = string(pem.EncodeToMemory(&pem.Block{
 | |
| 		Type:  "RSA PRIVATE KEY",
 | |
| 		Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
 | |
| 	}))
 | |
| 
 | |
| 	sshPublicKey, err := ssh.NewPublicKey(privateKey.Public())
 | |
| 	if err != nil {
 | |
| 		return "", "", fmt.Errorf("error generating RSA key-pair: %s", err)
 | |
| 	}
 | |
| 	publicKeyRsa = "ssh-rsa " + base64.StdEncoding.EncodeToString(sshPublicKey.Marshal())
 | |
| 	return
 | |
| }
 | |
| 
 | |
| // Installs or uninstalls the dynamic key in the remote host. The parameterized script
 | |
| // will install or uninstall the key. The remote host is assumed to be Linux,
 | |
| // and hence the path of the authorized_keys file is hard coded to resemble Linux.
 | |
| // Installing and uninstalling the keys means that the public key is appended or
 | |
| // removed from authorized_keys file.
 | |
| // The param 'install' if false, uninstalls the key.
 | |
| func installPublicKeyInTarget(adminUser, publicKeyFileName, username, ip string, port int, hostkey string, install bool) error {
 | |
| 	session, err := createSSHPublicKeysSession(adminUser, ip, port, hostkey)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("unable to create SSH Session using public keys: %s", err)
 | |
| 	}
 | |
| 	if session == nil {
 | |
| 		return fmt.Errorf("invalid session object")
 | |
| 	}
 | |
| 	defer session.Close()
 | |
| 
 | |
| 	authKeysFileName := fmt.Sprintf("/home/%s/.ssh/authorized_keys", username)
 | |
| 	scriptFileName := fmt.Sprintf("%s.sh", publicKeyFileName)
 | |
| 
 | |
| 	var installOption string
 | |
| 	if install {
 | |
| 		installOption = "install"
 | |
| 	} else {
 | |
| 		installOption = "uninstall"
 | |
| 	}
 | |
| 	// Give execute permissions to install script, run and delete it.
 | |
| 	chmodCmd := fmt.Sprintf("chmod +x %s", scriptFileName)
 | |
| 	scriptCmd := fmt.Sprintf("./%s %s %s %s", scriptFileName, installOption, publicKeyFileName, authKeysFileName)
 | |
| 	rmCmd := fmt.Sprintf("rm -f %s", scriptFileName)
 | |
| 	targetCmd := fmt.Sprintf("%s;%s;%s", chmodCmd, scriptCmd, rmCmd)
 | |
| 
 | |
| 	session.Run(targetCmd)
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // Takes an IP address and role name and checks if the IP is part
 | |
| // of CIDR blocks belonging to the role.
 | |
| func roleContainsIP(s logical.Storage, roleName string, ip string) (bool, error) {
 | |
| 	if roleName == "" {
 | |
| 		return false, fmt.Errorf("missing role name")
 | |
| 	}
 | |
| 
 | |
| 	if ip == "" {
 | |
| 		return false, fmt.Errorf("missing ip")
 | |
| 	}
 | |
| 
 | |
| 	roleEntry, err := s.Get(fmt.Sprintf("roles/%s", roleName))
 | |
| 	if err != nil {
 | |
| 		return false, fmt.Errorf("error retrieving role '%s'", err)
 | |
| 	}
 | |
| 	if roleEntry == nil {
 | |
| 		return false, fmt.Errorf("role '%s' not found", roleName)
 | |
| 	}
 | |
| 
 | |
| 	var role sshRole
 | |
| 	if err := roleEntry.DecodeJSON(&role); err != nil {
 | |
| 		return false, fmt.Errorf("error decoding role '%s'", roleName)
 | |
| 	}
 | |
| 
 | |
| 	if matched, err := cidrContainsIP(ip, role.CIDRList); err != nil {
 | |
| 		return false, err
 | |
| 	} else {
 | |
| 		return matched, nil
 | |
| 	}
 | |
| }
 | |
| 
 | |
// cidrContainsIP reports whether ip falls inside any of the CIDR blocks in
// cidrList, a comma-separated list such as "10.0.0.0/8,192.168.0.0/16".
//
// Whitespace around each entry is ignored. Empty entries, such as those
// produced by a trailing comma or an empty list, are skipped.
//
// Entries are examined in order. A match returns true immediately, even if a
// later entry is malformed. Otherwise the first entry that is not valid CIDR
// notation produces an error.
//
// An ip that does not parse as an IP address matches nothing.
func cidrContainsIP(ip, cidrList string) (bool, error) {
	// Parse once rather than per entry. A nil result (unparseable ip) is
	// never contained by any network.
	parsedIP := net.ParseIP(ip)
	for _, item := range strings.Split(cidrList, ",") {
		item = strings.TrimSpace(item)
		if item == "" {
			continue
		}
		_, cidrIPNet, err := net.ParseCIDR(item)
		if err != nil {
			return false, fmt.Errorf("invalid CIDR entry '%s'", item)
		}
		if cidrIPNet.Contains(parsedIP) {
			return true, nil
		}
	}
	return false, nil
}
 | |
| 
 | |
| func scpUpload(username, ip string, port int, hostkey, fileName, fileContent string) error {
 | |
| 	signer, err := ssh.ParsePrivateKey([]byte(hostkey))
 | |
| 	clientConfig := &ssh.ClientConfig{
 | |
| 		User: username,
 | |
| 		Auth: []ssh.AuthMethod{
 | |
| 			ssh.PublicKeys(signer),
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	connfunc := func() (net.Conn, error) {
 | |
| 		c, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", ip, port), 15*time.Second)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 
 | |
| 		if tcpConn, ok := c.(*net.TCPConn); ok {
 | |
| 			tcpConn.SetKeepAlive(true)
 | |
| 			tcpConn.SetKeepAlivePeriod(5 * time.Second)
 | |
| 		}
 | |
| 
 | |
| 		return c, nil
 | |
| 	}
 | |
| 	config := &SSHCommConfig{
 | |
| 		SSHConfig:    clientConfig,
 | |
| 		Connection:   connfunc,
 | |
| 		Pty:          false,
 | |
| 		DisableAgent: true,
 | |
| 	}
 | |
| 	comm, err := SSHCommNew(fmt.Sprintf("%s:%d", ip, port), config)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("error connecting to target: %s", err)
 | |
| 	}
 | |
| 	comm.Upload(fileName, bytes.NewBufferString(fileContent), nil)
 | |
| 	return nil
 | |
| }
 | 
