mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 02:28:09 +00:00 
			
		
		
		
	Vault SSH: Review Rework
This commit is contained in:
		
							
								
								
									
										22
									
								
								api/ssh.go
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								api/ssh.go
									
									
									
									
									
								
							| @@ -4,27 +4,21 @@ import "fmt" | |||||||
|  |  | ||||||
| // SSH is used to return a client to invoke operations on SSH backend. | // SSH is used to return a client to invoke operations on SSH backend. | ||||||
| type SSH struct { | type SSH struct { | ||||||
| 	c *Client | 	c    *Client | ||||||
|  | 	Path string | ||||||
| } | } | ||||||
|  |  | ||||||
| // SSH is used to return the client for logical-backend API calls. | // SSH is used to return the client for logical-backend API calls. | ||||||
| func (c *Client) SSH() *SSH { | func (c *Client) SSH(path string) *SSH { | ||||||
| 	return &SSH{c: c} | 	return &SSH{ | ||||||
| } | 		c:    c, | ||||||
|  | 		Path: path, | ||||||
| // Invokes the SSH backend API to revoke a key identified by its lease ID. |  | ||||||
| func (c *SSH) KeyRevoke(id string) error { |  | ||||||
| 	r := c.c.NewRequest("PUT", "/v1/sys/revoke/"+id) |  | ||||||
| 	resp, err := c.c.RawRequest(r) |  | ||||||
| 	if err == nil { |  | ||||||
| 		defer resp.Body.Close() |  | ||||||
| 	} | 	} | ||||||
| 	return err |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // Invokes the SSH backend API to create a dynamic key or an OTP | // Invokes the SSH backend API to create a dynamic key or an OTP | ||||||
| func (c *SSH) KeyCreate(role string, data map[string]interface{}) (*Secret, error) { | func (c *SSH) Credential(role string, data map[string]interface{}) (*Secret, error) { | ||||||
| 	r := c.c.NewRequest("PUT", fmt.Sprintf("/v1/ssh/creds/%s", role)) | 	r := c.c.NewRequest("PUT", fmt.Sprintf("/v1/%s/creds/%s", c.Path, role)) | ||||||
| 	if err := r.SetJSONBody(data); err != nil { | 	if err := r.SetJSONBody(data); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -8,6 +8,11 @@ import ( | |||||||
| 	"github.com/hashicorp/vault/logical/framework" | 	"github.com/hashicorp/vault/logical/framework" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | type backend struct { | ||||||
|  | 	*framework.Backend | ||||||
|  | 	salt *salt.Salt | ||||||
|  | } | ||||||
|  |  | ||||||
| func Factory(conf *logical.BackendConfig) (logical.Backend, error) { | func Factory(conf *logical.BackendConfig) (logical.Backend, error) { | ||||||
| 	b, err := Backend(conf) | 	b, err := Backend(conf) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -54,11 +59,6 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) { | |||||||
| 	return b.Backend, nil | 	return b.Backend, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| type backend struct { |  | ||||||
| 	*framework.Backend |  | ||||||
| 	salt *salt.Salt |  | ||||||
| } |  | ||||||
|  |  | ||||||
| const backendHelp = ` | const backendHelp = ` | ||||||
| The SSH backend generates keys to eatablish SSH connection | The SSH backend generates keys to eatablish SSH connection | ||||||
| with remote hosts. There are two options to create the keys: | with remote hosts. There are two options to create the keys: | ||||||
|   | |||||||
| @@ -8,6 +8,11 @@ import ( | |||||||
| 	"github.com/hashicorp/vault/logical/framework" | 	"github.com/hashicorp/vault/logical/framework" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | type configLease struct { | ||||||
|  | 	Lease    time.Duration | ||||||
|  | 	LeaseMax time.Duration | ||||||
|  | } | ||||||
|  |  | ||||||
| func pathConfigLease(b *backend) *framework.Path { | func pathConfigLease(b *backend) *framework.Path { | ||||||
| 	return &framework.Path{ | 	return &framework.Path{ | ||||||
| 		Pattern: "config/lease", | 		Pattern: "config/lease", | ||||||
| @@ -89,11 +94,6 @@ func (b *backend) Lease(s logical.Storage) (*configLease, error) { | |||||||
| 	return &result, nil | 	return &result, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| type configLease struct { |  | ||||||
| 	Lease    time.Duration |  | ||||||
| 	LeaseMax time.Duration |  | ||||||
| } |  | ||||||
|  |  | ||||||
| const pathConfigLeaseHelpSyn = ` | const pathConfigLeaseHelpSyn = ` | ||||||
| Configure the default lease information for SSH dynamic keys. | Configure the default lease information for SSH dynamic keys. | ||||||
| ` | ` | ||||||
|   | |||||||
| @@ -3,12 +3,22 @@ package ssh | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net" | 	"net" | ||||||
|  | 	"strconv" | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/vault/helper/uuid" | 	"github.com/hashicorp/vault/helper/uuid" | ||||||
| 	"github.com/hashicorp/vault/logical" | 	"github.com/hashicorp/vault/logical" | ||||||
| 	"github.com/hashicorp/vault/logical/framework" | 	"github.com/hashicorp/vault/logical/framework" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | type sshOTP struct { | ||||||
|  | 	Username string `json:"username"` | ||||||
|  | 	IP       string `json:"ip"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type sshCIDR struct { | ||||||
|  | 	CIDR []string | ||||||
|  | } | ||||||
|  |  | ||||||
| func pathCredsCreate(b *backend) *framework.Path { | func pathCredsCreate(b *backend) *framework.Path { | ||||||
| 	return &framework.Path{ | 	return &framework.Path{ | ||||||
| 		Pattern: "creds/(?P<name>[-\\w]+)", | 		Pattern: "creds/(?P<name>[-\\w]+)", | ||||||
| @@ -85,25 +95,10 @@ func (b *backend) pathCredsCreateWrite( | |||||||
| 	var result *logical.Response | 	var result *logical.Response | ||||||
| 	if role.KeyType == KeyTypeOTP { | 	if role.KeyType == KeyTypeOTP { | ||||||
| 		// Generate salted OTP | 		// Generate salted OTP | ||||||
| 		otp := uuid.GenerateUUID() | 		otp, err := b.GenerateOTPCredential(req, username, ip) | ||||||
| 		otpSalted := b.salt.SaltID(otp) |  | ||||||
| 		entry, err := req.Storage.Get("otp/" + otpSalted) |  | ||||||
| 		// Make sure that new OTP is not replacing an existing one |  | ||||||
| 		for err == nil && entry != nil { |  | ||||||
| 			otp := uuid.GenerateUUID() |  | ||||||
| 			otpSalted := b.salt.SaltID(otp) |  | ||||||
| 			entry, err = req.Storage.Get("otp/" + otpSalted) |  | ||||||
| 		} |  | ||||||
| 		entry, err = logical.StorageEntryJSON("otp/"+otpSalted, sshOTP{ |  | ||||||
| 			Username: username, |  | ||||||
| 			IP:       ip, |  | ||||||
| 		}) |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		if err := req.Storage.Put(entry); err != nil { |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 		result = b.Secret(SecretOTPType).Response(map[string]interface{}{ | 		result = b.Secret(SecretOTPType).Response(map[string]interface{}{ | ||||||
| 			"key_type": role.KeyType, | 			"key_type": role.KeyType, | ||||||
| 			"key":      otp, | 			"key":      otp, | ||||||
| @@ -111,39 +106,10 @@ func (b *backend) pathCredsCreateWrite( | |||||||
| 			"otp": otp, | 			"otp": otp, | ||||||
| 		}) | 		}) | ||||||
| 	} else if role.KeyType == KeyTypeDynamic { | 	} else if role.KeyType == KeyTypeDynamic { | ||||||
| 		// Fetch the host key to be used for dynamic key installation | 		dynamicPublicKey, dynamicPrivateKey, err := b.GenerateDynamicCredential(req, &role, username, ip) | ||||||
| 		keyEntry, err := req.Storage.Get(fmt.Sprintf("keys/%s", role.KeyName)) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, fmt.Errorf("key '%s' not found error:%s", role.KeyName, err) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		if keyEntry == nil { |  | ||||||
| 			return nil, fmt.Errorf("key '%s' not found", role.KeyName, err) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		var hostKey sshHostKey |  | ||||||
| 		if err := keyEntry.DecodeJSON(&hostKey); err != nil { |  | ||||||
| 			return nil, fmt.Errorf("error reading the host key: %s", err) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Generate RSA key pair |  | ||||||
| 		dynamicPublicKey, dynamicPrivateKey, err := generateRSAKeys() |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, fmt.Errorf("error generating key: %s", err) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Transfer the public key to target machine |  | ||||||
| 		err = uploadPublicKeyScp(dynamicPublicKey, username, ip, role.Port, hostKey.Key) |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Add the public key to authorized_keys file in target machine |  | ||||||
| 		err = installPublicKeyInTarget(role.AdminUser, username, ip, role.Port, hostKey.Key) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, fmt.Errorf("error adding public key to authorized_keys file in target") |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		result = b.Secret(SecretDynamicKeyType).Response(map[string]interface{}{ | 		result = b.Secret(SecretDynamicKeyType).Response(map[string]interface{}{ | ||||||
| 			"key":      dynamicPrivateKey, | 			"key":      dynamicPrivateKey, | ||||||
| 			"key_type": role.KeyType, | 			"key_type": role.KeyType, | ||||||
| @@ -168,13 +134,74 @@ func (b *backend) pathCredsCreateWrite( | |||||||
| 	return result, nil | 	return result, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| type sshOTP struct { | // Generates a RSA key pair and installs it in the remote target | ||||||
| 	Username string `json:"username"` | func (b *backend) GenerateDynamicCredential(req *logical.Request, role *sshRole, username, ip string) (string, string, error) { | ||||||
| 	IP       string `json:"ip"` | 	// Fetch the host key to be used for dynamic key installation | ||||||
|  | 	keyEntry, err := req.Storage.Get(fmt.Sprintf("keys/%s", role.KeyName)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", "", fmt.Errorf("key '%s' not found error:%s", role.KeyName, err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if keyEntry == nil { | ||||||
|  | 		return "", "", fmt.Errorf("key '%s' not found", role.KeyName, err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var hostKey sshHostKey | ||||||
|  | 	if err := keyEntry.DecodeJSON(&hostKey); err != nil { | ||||||
|  | 		return "", "", fmt.Errorf("error reading the host key: %s", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Generate RSA key pair | ||||||
|  | 	keyBits, err := strconv.Atoi(role.KeyBits) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", "", fmt.Errorf("error reading key bit size: %s", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	dynamicPublicKey, dynamicPrivateKey, err := generateRSAKeys(keyBits) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", "", fmt.Errorf("error generating key: %s", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Transfer the public key to target machine | ||||||
|  | 	publicKeyFileName := uuid.GenerateUUID() | ||||||
|  | 	err = uploadPublicKeyScp(dynamicPublicKey, publicKeyFileName, username, ip, role.Port, hostKey.Key) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", "", err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Add the public key to authorized_keys file in target machine | ||||||
|  | 	err = installPublicKeyInTarget(role.AdminUser, publicKeyFileName, username, ip, role.Port, hostKey.Key) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", "", fmt.Errorf("error adding public key to authorized_keys file in target") | ||||||
|  | 	} | ||||||
|  | 	return dynamicPublicKey, dynamicPrivateKey, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| type sshCIDR struct { | // Generates an OTP and creates an entry for the same in storage backend. | ||||||
| 	CIDR []string | func (b *backend) GenerateOTPCredential(req *logical.Request, username, ip string) (string, error) { | ||||||
|  | 	otp := uuid.GenerateUUID() | ||||||
|  | 	otpSalted := b.salt.SaltID(otp) | ||||||
|  | 	entry, err := req.Storage.Get("otp/" + otpSalted) | ||||||
|  | 	// Make sure that new OTP is not replacing an existing one | ||||||
|  | 	for err == nil && entry != nil { | ||||||
|  | 		otp = uuid.GenerateUUID() | ||||||
|  | 		otpSalted = b.salt.SaltID(otp) | ||||||
|  | 		entry, err = req.Storage.Get("otp/" + otpSalted) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return "", err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	entry, err = logical.StorageEntryJSON("otp/"+otpSalted, sshOTP{ | ||||||
|  | 		Username: username, | ||||||
|  | 		IP:       ip, | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 	if err := req.Storage.Put(entry); err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 	return otp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| const pathCredsCreateHelpSyn = ` | const pathCredsCreateHelpSyn = ` | ||||||
|   | |||||||
| @@ -3,10 +3,16 @@ package ssh | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/crypto/ssh" | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/vault/logical" | 	"github.com/hashicorp/vault/logical" | ||||||
| 	"github.com/hashicorp/vault/logical/framework" | 	"github.com/hashicorp/vault/logical/framework" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | type sshHostKey struct { | ||||||
|  | 	Key string | ||||||
|  | } | ||||||
|  |  | ||||||
| func pathKeys(b *backend) *framework.Path { | func pathKeys(b *backend) *framework.Path { | ||||||
| 	return &framework.Path{ | 	return &framework.Path{ | ||||||
| 		Pattern: "keys/(?P<name>[-\\w]+)", | 		Pattern: "keys/(?P<name>[-\\w]+)", | ||||||
| @@ -62,6 +68,11 @@ func (b *backend) pathKeysWrite(req *logical.Request, d *framework.FieldData) (* | |||||||
| 	keyName := d.Get("name").(string) | 	keyName := d.Get("name").(string) | ||||||
| 	keyString := d.Get("key").(string) | 	keyString := d.Get("key").(string) | ||||||
|  |  | ||||||
|  | 	signer, err := ssh.ParsePrivateKey([]byte(keyString)) | ||||||
|  | 	if err != nil || signer == nil { | ||||||
|  | 		return logical.ErrorResponse("Invalid key"), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if keyString == "" { | 	if keyString == "" { | ||||||
| 		return logical.ErrorResponse("Missing key"), nil | 		return logical.ErrorResponse("Missing key"), nil | ||||||
| 	} | 	} | ||||||
| @@ -80,10 +91,6 @@ func (b *backend) pathKeysWrite(req *logical.Request, d *framework.FieldData) (* | |||||||
| 	return nil, nil | 	return nil, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| type sshHostKey struct { |  | ||||||
| 	Key string |  | ||||||
| } |  | ||||||
|  |  | ||||||
| const pathKeysSyn = ` | const pathKeysSyn = ` | ||||||
| Register a shared key which can be used to install dynamic key | Register a shared key which can be used to install dynamic key | ||||||
| in remote machine. | in remote machine. | ||||||
|   | |||||||
| @@ -3,12 +3,17 @@ package ssh | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net" | 	"net" | ||||||
|  | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/vault/logical" | 	"github.com/hashicorp/vault/logical" | ||||||
| 	"github.com/hashicorp/vault/logical/framework" | 	"github.com/hashicorp/vault/logical/framework" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | const KeyTypeOTP = "otp" | ||||||
|  | const KeyTypeDynamic = "dynamic" | ||||||
|  | const KeyBitsRSA = "2048" | ||||||
|  |  | ||||||
| func pathRoles(b *backend) *framework.Path { | func pathRoles(b *backend) *framework.Path { | ||||||
| 	return &framework.Path{ | 	return &framework.Path{ | ||||||
| 		Pattern: "roles/(?P<name>[-\\w]+)", | 		Pattern: "roles/(?P<name>[-\\w]+)", | ||||||
| @@ -41,6 +46,10 @@ func pathRoles(b *backend) *framework.Path { | |||||||
| 				Type:        framework.TypeString, | 				Type:        framework.TypeString, | ||||||
| 				Description: "one-time-password or dynamic-key", | 				Description: "one-time-password or dynamic-key", | ||||||
| 			}, | 			}, | ||||||
|  | 			"key_bits": &framework.FieldSchema{ | ||||||
|  | 				Type:        framework.TypeString, | ||||||
|  | 				Description: "number of bits in keys", | ||||||
|  | 			}, | ||||||
| 		}, | 		}, | ||||||
|  |  | ||||||
| 		Callbacks: map[logical.Operation]framework.OperationFunc{ | 		Callbacks: map[logical.Operation]framework.OperationFunc{ | ||||||
| @@ -54,127 +63,104 @@ func pathRoles(b *backend) *framework.Path { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func createOTPRole(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { |  | ||||||
| 	roleName := d.Get("name").(string) |  | ||||||
| 	if roleName == "" { |  | ||||||
| 		return logical.ErrorResponse("Missing role name"), nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	cidr := d.Get("cidr").(string) |  | ||||||
| 	if cidr == "" { |  | ||||||
| 		return logical.ErrorResponse("Missing cidr blocks"), nil |  | ||||||
| 	} |  | ||||||
| 	for _, item := range strings.Split(cidr, ",") { |  | ||||||
| 		_, _, err := net.ParseCIDR(item) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return logical.ErrorResponse(fmt.Sprintf("Invalid cidr entry '%s'", item)), nil |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	adminUser := d.Get("admin_user").(string) |  | ||||||
| 	defaultUser := d.Get("default_user").(string) |  | ||||||
| 	if defaultUser == "" && adminUser != "" { |  | ||||||
| 		defaultUser = adminUser |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	port := d.Get("port").(string) |  | ||||||
| 	if port == "" { |  | ||||||
| 		port = "22" |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	entry, err := logical.StorageEntryJSON(fmt.Sprintf("policy/%s", roleName), sshRole{ |  | ||||||
| 		AdminUser:   adminUser, |  | ||||||
| 		DefaultUser: defaultUser, |  | ||||||
| 		CIDR:        cidr, |  | ||||||
| 		Port:        port, |  | ||||||
| 		KeyType:     KeyTypeOTP, |  | ||||||
| 	}) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if err := req.Storage.Put(entry); err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return nil, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func createDynamicKeyRole(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { |  | ||||||
| 	roleName := d.Get("name").(string) |  | ||||||
| 	if roleName == "" { |  | ||||||
| 		return logical.ErrorResponse("Missing role name"), nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	keyName := d.Get("key").(string) |  | ||||||
| 	if keyName == "" { |  | ||||||
| 		return logical.ErrorResponse("Missing key name"), nil |  | ||||||
| 	} |  | ||||||
| 	keyEntry, err := req.Storage.Get(fmt.Sprintf("keys/%s", keyName)) |  | ||||||
| 	if err != nil || keyEntry == nil { |  | ||||||
| 		return logical.ErrorResponse(fmt.Sprintf("Invalid 'key': '%s'", keyName)), nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	adminUser := d.Get("admin_user").(string) |  | ||||||
| 	if adminUser == "" { |  | ||||||
| 		return logical.ErrorResponse("Missing admin username"), nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	cidr := d.Get("cidr").(string) |  | ||||||
| 	if cidr == "" { |  | ||||||
| 		return logical.ErrorResponse("Missing cidr blocks"), nil |  | ||||||
| 	} |  | ||||||
| 	for _, item := range strings.Split(cidr, ",") { |  | ||||||
| 		_, _, err := net.ParseCIDR(item) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return logical.ErrorResponse(fmt.Sprintf("Invalid cidr entry '%s'", item)), nil |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	defaultUser := d.Get("default_user").(string) |  | ||||||
| 	if defaultUser == "" { |  | ||||||
| 		defaultUser = adminUser |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	port := d.Get("port").(string) |  | ||||||
| 	if port == "" { |  | ||||||
| 		port = "22" |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	entry, err := logical.StorageEntryJSON(fmt.Sprintf("policy/%s", roleName), sshRole{ |  | ||||||
| 		KeyName:     keyName, |  | ||||||
| 		AdminUser:   adminUser, |  | ||||||
| 		DefaultUser: defaultUser, |  | ||||||
| 		CIDR:        cidr, |  | ||||||
| 		Port:        port, |  | ||||||
| 		KeyType:     KeyTypeDynamic, |  | ||||||
| 	}) |  | ||||||
|  |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if err := req.Storage.Put(entry); err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return nil, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 	roleName := d.Get("name").(string) | ||||||
|  | 	if roleName == "" { | ||||||
|  | 		return logical.ErrorResponse("Missing role name"), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	cidr := d.Get("cidr").(string) | ||||||
|  | 	if cidr == "" { | ||||||
|  | 		return logical.ErrorResponse("Missing cidr blocks"), nil | ||||||
|  | 	} | ||||||
|  | 	for _, item := range strings.Split(cidr, ",") { | ||||||
|  | 		_, _, err := net.ParseCIDR(item) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return logical.ErrorResponse(fmt.Sprintf("Invalid cidr entry '%s'", item)), nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	port := d.Get("port").(string) | ||||||
|  | 	if port == "" { | ||||||
|  | 		port = "22" | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	keyType := d.Get("key_type").(string) | 	keyType := d.Get("key_type").(string) | ||||||
| 	if keyType == "" { | 	if keyType == "" { | ||||||
| 		return logical.ErrorResponse("Missing key type"), nil | 		return logical.ErrorResponse("Missing key type"), nil | ||||||
| 	} | 	} | ||||||
| 	keyType = strings.ToLower(keyType) | 	keyType = strings.ToLower(keyType) | ||||||
|  |  | ||||||
|  | 	var entry *logical.StorageEntry | ||||||
|  | 	var err error | ||||||
| 	if keyType == KeyTypeOTP { | 	if keyType == KeyTypeOTP { | ||||||
| 		return createOTPRole(req, d) | 		adminUser := d.Get("admin_user").(string) | ||||||
|  | 		if adminUser != "" { | ||||||
|  | 			return logical.ErrorResponse("Admin user not required for OTP type"), nil | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		defaultUser := d.Get("default_user").(string) | ||||||
|  | 		if defaultUser == "" { | ||||||
|  | 			return logical.ErrorResponse("Missing default user"), nil | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		entry, err = logical.StorageEntryJSON(fmt.Sprintf("policy/%s", roleName), sshRole{ | ||||||
|  | 			DefaultUser: defaultUser, | ||||||
|  | 			CIDR:        cidr, | ||||||
|  | 			KeyType:     KeyTypeOTP, | ||||||
|  | 		}) | ||||||
| 	} else if keyType == KeyTypeDynamic { | 	} else if keyType == KeyTypeDynamic { | ||||||
| 		return createDynamicKeyRole(req, d) | 		keyName := d.Get("key").(string) | ||||||
|  | 		if keyName == "" { | ||||||
|  | 			return logical.ErrorResponse("Missing key name"), nil | ||||||
|  | 		} | ||||||
|  | 		keyEntry, err := req.Storage.Get(fmt.Sprintf("keys/%s", keyName)) | ||||||
|  | 		if err != nil || keyEntry == nil { | ||||||
|  | 			return logical.ErrorResponse(fmt.Sprintf("Invalid 'key': '%s'", keyName)), nil | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		adminUser := d.Get("admin_user").(string) | ||||||
|  | 		if adminUser == "" { | ||||||
|  | 			return logical.ErrorResponse("Missing admin username"), nil | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		defaultUser := d.Get("default_user").(string) | ||||||
|  | 		if defaultUser == "" { | ||||||
|  | 			defaultUser = adminUser | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		keyBits := d.Get("key_bits").(string) | ||||||
|  | 		if keyBits != "" { | ||||||
|  | 			_, err := strconv.Atoi(keyBits) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return logical.ErrorResponse("Key bits should be an integer"), nil | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		if keyBits == "" { | ||||||
|  | 			keyBits = KeyBitsRSA | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		entry, err = logical.StorageEntryJSON(fmt.Sprintf("policy/%s", roleName), sshRole{ | ||||||
|  | 			KeyName:     keyName, | ||||||
|  | 			AdminUser:   adminUser, | ||||||
|  | 			DefaultUser: defaultUser, | ||||||
|  | 			CIDR:        cidr, | ||||||
|  | 			Port:        port, | ||||||
|  | 			KeyType:     KeyTypeDynamic, | ||||||
|  | 			KeyBits:     keyBits, | ||||||
|  | 		}) | ||||||
| 	} else { | 	} else { | ||||||
| 		return logical.ErrorResponse("Invalid key type"), nil | 		return logical.ErrorResponse("Invalid key type"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := req.Storage.Put(entry); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return nil, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
| @@ -216,15 +202,13 @@ func (b *backend) pathRoleDelete(req *logical.Request, d *framework.FieldData) ( | |||||||
| type sshRole struct { | type sshRole struct { | ||||||
| 	KeyType     string `json:"key_type"` | 	KeyType     string `json:"key_type"` | ||||||
| 	KeyName     string `json:"key"` | 	KeyName     string `json:"key"` | ||||||
|  | 	KeyBits     string `json:"key_bits"` | ||||||
| 	AdminUser   string `json:"admin_user"` | 	AdminUser   string `json:"admin_user"` | ||||||
| 	DefaultUser string `json:"default_user"` | 	DefaultUser string `json:"default_user"` | ||||||
| 	CIDR        string `json:"cidr"` | 	CIDR        string `json:"cidr"` | ||||||
| 	Port        string `json:"port"` | 	Port        string `json:"port"` | ||||||
| } | } | ||||||
|  |  | ||||||
| const KeyTypeOTP = "otp" |  | ||||||
| const KeyTypeDynamic = "dynamic" |  | ||||||
|  |  | ||||||
| const pathRoleHelpSyn = ` | const pathRoleHelpSyn = ` | ||||||
| Manage the 'roles' that can be created with this backend. | Manage the 'roles' that can be created with this backend. | ||||||
| ` | ` | ||||||
|   | |||||||
| @@ -41,11 +41,11 @@ func (b *backend) pathVerifyWrite(req *logical.Request, d *framework.FieldData) | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return &logical.Response{ | 	return &logical.Response{ | ||||||
| 		Data: map[string]interface{}{ | 		Data: map[string]interface{}{ | ||||||
| 			"username": otpEntry.Username, | 			"username": otpEntry.Username, | ||||||
| 			"ip":       otpEntry.IP, | 			"ip":       otpEntry.IP, | ||||||
| 			"valid":    "yes", |  | ||||||
| 		}, | 		}, | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/hashicorp/vault/helper/uuid" | ||||||
| 	"github.com/hashicorp/vault/logical" | 	"github.com/hashicorp/vault/logical" | ||||||
| 	"github.com/hashicorp/vault/logical/framework" | 	"github.com/hashicorp/vault/logical/framework" | ||||||
| ) | ) | ||||||
| @@ -24,7 +25,7 @@ func secretDynamicKey(b *backend) *framework.Secret { | |||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 		DefaultDuration:    10 * time.Minute, | 		DefaultDuration:    10 * time.Minute, | ||||||
| 		DefaultGracePeriod: 5 * time.Minute, | 		DefaultGracePeriod: 2 * time.Minute, | ||||||
| 		Renew:              b.secretDynamicKeyRenew, | 		Renew:              b.secretDynamicKeyRenew, | ||||||
| 		Revoke:             b.secretDynamicKeyRevoke, | 		Revoke:             b.secretDynamicKeyRevoke, | ||||||
| 	} | 	} | ||||||
| @@ -105,13 +106,14 @@ func (b *backend) secretDynamicKeyRevoke(req *logical.Request, d *framework.Fiel | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Transfer the dynamic public key to target machine and use it to remove the entry from authorized_keys file | 	// Transfer the dynamic public key to target machine and use it to remove the entry from authorized_keys file | ||||||
| 	err = uploadPublicKeyScp(dynamicPublicKey, username, ip, port, hostKey.Key) | 	dynamicPublicKeyFileName := uuid.GenerateUUID() | ||||||
|  | 	err = uploadPublicKeyScp(dynamicPublicKey, dynamicPublicKeyFileName, username, ip, port, hostKey.Key) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("public key transfer failed: %s", err) | 		return nil, fmt.Errorf("public key transfer failed: %s", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Remove the public key from authorized_keys file in target machine | 	// Remove the public key from authorized_keys file in target machine | ||||||
| 	err = uninstallPublicKeyInTarget(adminUser, username, ip, port, hostKey.Key) | 	err = uninstallPublicKeyInTarget(adminUser, dynamicPublicKeyFileName, username, ip, port, hostKey.Key) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("error removing public key from authorized_keys file in target") | 		return nil, fmt.Errorf("error removing public key from authorized_keys file in target") | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -20,24 +20,11 @@ func secretOTP(b *backend) *framework.Secret { | |||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 		DefaultDuration:    10 * time.Minute, | 		DefaultDuration:    10 * time.Minute, | ||||||
| 		DefaultGracePeriod: 5 * time.Minute, | 		DefaultGracePeriod: 2 * time.Minute, | ||||||
| 		Renew:              b.secretOTPRenew, |  | ||||||
| 		Revoke:             b.secretOTPRevoke, | 		Revoke:             b.secretOTPRevoke, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *backend) secretOTPRenew(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { |  | ||||||
| 	lease, err := b.Lease(req.Storage) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	if lease == nil { |  | ||||||
| 		lease = &configLease{Lease: 1 * time.Hour} |  | ||||||
| 	} |  | ||||||
| 	f := framework.LeaseExtend(lease.Lease, lease.LeaseMax, false) |  | ||||||
| 	return f(req, d) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (b *backend) secretOTPRevoke(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | func (b *backend) secretOTPRevoke(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
| 	otpRaw, ok := req.Secret.InternalData["otp"] | 	otpRaw, ok := req.Secret.InternalData["otp"] | ||||||
| 	if !ok { | 	if !ok { | ||||||
|   | |||||||
| @@ -20,8 +20,7 @@ import ( | |||||||
| // session with the target. Uses the public key authentication method | // session with the target. Uses the public key authentication method | ||||||
| // and hence the parameter 'key' takes in the private key. The fileName | // and hence the parameter 'key' takes in the private key. The fileName | ||||||
| // parameter takes an absolute path. | // parameter takes an absolute path. | ||||||
| func uploadPublicKeyScp(publicKey, username, ip, port, key string) error { | func uploadPublicKeyScp(publicKey, publicKeyFileName, username, ip, port, key string) error { | ||||||
| 	dynamicPublicKeyFileName := fmt.Sprintf("vault_ssh_%s_%s.pub", username, ip) |  | ||||||
| 	session, err := createSSHPublicKeysSession(username, ip, port, key) | 	session, err := createSSHPublicKeysSession(username, ip, port, key) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| @@ -32,12 +31,12 @@ func uploadPublicKeyScp(publicKey, username, ip, port, key string) error { | |||||||
| 	defer session.Close() | 	defer session.Close() | ||||||
| 	go func() { | 	go func() { | ||||||
| 		w, _ := session.StdinPipe() | 		w, _ := session.StdinPipe() | ||||||
| 		fmt.Fprintln(w, "C0644", len(publicKey), dynamicPublicKeyFileName) | 		fmt.Fprintln(w, "C0644", len(publicKey), publicKeyFileName) | ||||||
| 		io.Copy(w, strings.NewReader(publicKey)) | 		io.Copy(w, strings.NewReader(publicKey)) | ||||||
| 		fmt.Fprint(w, "\x00") | 		fmt.Fprint(w, "\x00") | ||||||
| 		w.Close() | 		w.Close() | ||||||
| 	}() | 	}() | ||||||
| 	err = session.Run(fmt.Sprintf("scp -vt %s", dynamicPublicKeyFileName)) | 	err = session.Run(fmt.Sprintf("scp -vt %s", publicKeyFileName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("public key upload failed") | 		return fmt.Errorf("public key upload failed") | ||||||
| 	} | 	} | ||||||
| @@ -87,8 +86,8 @@ func createSSHPublicKeysSession(username, ipAddr, port, hostKey string) (*ssh.Se | |||||||
| // Creates a new RSA key pair with key length of 2048. | // Creates a new RSA key pair with key length of 2048. | ||||||
| // The private key will be of pem format and the public key will be | // The private key will be of pem format and the public key will be | ||||||
| // of OpenSSH format. | // of OpenSSH format. | ||||||
| func generateRSAKeys() (publicKeyRsa string, privateKeyRsa string, err error) { | func generateRSAKeys(keyBits int) (publicKeyRsa string, privateKeyRsa string, err error) { | ||||||
| 	privateKey, err := rsa.GenerateKey(rand.Reader, 2048) | 	privateKey, err := rsa.GenerateKey(rand.Reader, keyBits) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", "", fmt.Errorf("error generating RSA key-pair: %s", err) | 		return "", "", fmt.Errorf("error generating RSA key-pair: %s", err) | ||||||
| 	} | 	} | ||||||
| @@ -108,7 +107,7 @@ func generateRSAKeys() (publicKeyRsa string, privateKeyRsa string, err error) { | |||||||
|  |  | ||||||
| // Concatenates the public present in that target machine's home | // Concatenates the public present in that target machine's home | ||||||
| // folder to ~/.ssh/authorized_keys file | // folder to ~/.ssh/authorized_keys file | ||||||
| func installPublicKeyInTarget(adminUser, username, ip, port, hostKey string) error { | func installPublicKeyInTarget(adminUser, publicKeyFileName, username, ip, port, hostKey string) error { | ||||||
| 	session, err := createSSHPublicKeysSession(adminUser, ip, port, hostKey) | 	session, err := createSSHPublicKeysSession(adminUser, ip, port, hostKey) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("unable to create SSH Session using public keys: %s", err) | 		return fmt.Errorf("unable to create SSH Session using public keys: %s", err) | ||||||
| @@ -122,11 +121,10 @@ func installPublicKeyInTarget(adminUser, username, ip, port, hostKey string) err | |||||||
| 	tempKeysFileName := fmt.Sprintf("/home/%s/temp_authorized_keys", username) | 	tempKeysFileName := fmt.Sprintf("/home/%s/temp_authorized_keys", username) | ||||||
|  |  | ||||||
| 	// Commands to be run on target machine | 	// Commands to be run on target machine | ||||||
| 	dynamicPublicKeyFileName := fmt.Sprintf("vault_ssh_%s_%s.pub", username, ip) | 	grepCmd := fmt.Sprintf("grep -vFf %s %s > %s", publicKeyFileName, authKeysFileName, tempKeysFileName) | ||||||
| 	grepCmd := fmt.Sprintf("grep -vFf %s %s > %s", dynamicPublicKeyFileName, authKeysFileName, tempKeysFileName) |  | ||||||
| 	catCmdRemoveDuplicate := fmt.Sprintf("cat %s > %s", tempKeysFileName, authKeysFileName) | 	catCmdRemoveDuplicate := fmt.Sprintf("cat %s > %s", tempKeysFileName, authKeysFileName) | ||||||
| 	catCmdAppendNew := fmt.Sprintf("cat %s >> %s", dynamicPublicKeyFileName, authKeysFileName) | 	catCmdAppendNew := fmt.Sprintf("cat %s >> %s", publicKeyFileName, authKeysFileName) | ||||||
| 	removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, dynamicPublicKeyFileName) | 	removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, publicKeyFileName) | ||||||
|  |  | ||||||
| 	targetCmd := fmt.Sprintf("%s;%s;%s;%s", grepCmd, catCmdRemoveDuplicate, catCmdAppendNew, removeCmd) | 	targetCmd := fmt.Sprintf("%s;%s;%s;%s", grepCmd, catCmdRemoveDuplicate, catCmdAppendNew, removeCmd) | ||||||
| 	session.Run(targetCmd) | 	session.Run(targetCmd) | ||||||
| @@ -135,7 +133,7 @@ func installPublicKeyInTarget(adminUser, username, ip, port, hostKey string) err | |||||||
|  |  | ||||||
| // Removes the installed public key from the authorized_keys file | // Removes the installed public key from the authorized_keys file | ||||||
| // in target machine | // in target machine | ||||||
| func uninstallPublicKeyInTarget(adminUser, username, ip, port, hostKey string) error { | func uninstallPublicKeyInTarget(adminUser, publicKeyFileName, username, ip, port, hostKey string) error { | ||||||
| 	session, err := createSSHPublicKeysSession(adminUser, ip, port, hostKey) | 	session, err := createSSHPublicKeysSession(adminUser, ip, port, hostKey) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("unable to create SSH Session using public keys: %s", err) | 		return fmt.Errorf("unable to create SSH Session using public keys: %s", err) | ||||||
| @@ -149,10 +147,9 @@ func uninstallPublicKeyInTarget(adminUser, username, ip, port, hostKey string) e | |||||||
| 	tempKeysFileName := fmt.Sprintf("/home/%s/temp_authorized_keys", username) | 	tempKeysFileName := fmt.Sprintf("/home/%s/temp_authorized_keys", username) | ||||||
|  |  | ||||||
| 	// Commands to be run on target machine | 	// Commands to be run on target machine | ||||||
| 	dynamicPublicKeyFileName := fmt.Sprintf("vault_ssh_%s_%s.pub", username, ip) | 	grepCmd := fmt.Sprintf("grep -vFf %s %s > %s", publicKeyFileName, authKeysFileName, tempKeysFileName) | ||||||
| 	grepCmd := fmt.Sprintf("grep -vFf %s %s > %s", dynamicPublicKeyFileName, authKeysFileName, tempKeysFileName) |  | ||||||
| 	catCmdRemoveDuplicate := fmt.Sprintf("cat %s > %s", tempKeysFileName, authKeysFileName) | 	catCmdRemoveDuplicate := fmt.Sprintf("cat %s > %s", tempKeysFileName, authKeysFileName) | ||||||
| 	removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, dynamicPublicKeyFileName) | 	removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, publicKeyFileName) | ||||||
|  |  | ||||||
| 	remoteCmd := fmt.Sprintf("%s;%s;%s", grepCmd, catCmdRemoveDuplicate, removeCmd) | 	remoteCmd := fmt.Sprintf("%s;%s;%s", grepCmd, catCmdRemoveDuplicate, removeCmd) | ||||||
| 	session.Run(remoteCmd) | 	session.Run(remoteCmd) | ||||||
|   | |||||||
| @@ -6,9 +6,9 @@ import ( | |||||||
| 	"net" | 	"net" | ||||||
| 	"os" | 	"os" | ||||||
| 	"os/exec" | 	"os/exec" | ||||||
|  | 	"os/user" | ||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/vault/api" |  | ||||||
| 	"github.com/hashicorp/vault/builtin/logical/ssh" | 	"github.com/hashicorp/vault/builtin/logical/ssh" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -19,13 +19,16 @@ type SSHCommand struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (c *SSHCommand) Run(args []string) int { | func (c *SSHCommand) Run(args []string) int { | ||||||
| 	var role string | 	var role, port, path string | ||||||
| 	var port string | 	var noExec bool | ||||||
| 	var sshCmdArgs []string | 	var sshCmdArgs []string | ||||||
| 	var sshDynamicKeyFileName string | 	var sshDynamicKeyFileName string | ||||||
| 	flags := c.Meta.FlagSet("ssh", FlagSetDefault) | 	flags := c.Meta.FlagSet("ssh", FlagSetDefault) | ||||||
| 	flags.StringVar(&role, "role", "", "") | 	flags.StringVar(&role, "role", "", "") | ||||||
| 	flags.StringVar(&port, "port", "22", "") | 	flags.StringVar(&port, "port", "22", "") | ||||||
|  | 	flags.StringVar(&path, "path", "ssh", "") | ||||||
|  | 	flags.BoolVar(&noExec, "no-exec", false, "") | ||||||
|  |  | ||||||
| 	flags.Usage = func() { c.Ui.Error(c.Help()) } | 	flags.Usage = func() { c.Ui.Error(c.Help()) } | ||||||
| 	if err := flags.Parse(args); err != nil { | 	if err := flags.Parse(args); err != nil { | ||||||
| 		return 1 | 		return 1 | ||||||
| @@ -43,23 +46,33 @@ func (c *SSHCommand) Run(args []string) int { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	input := strings.Split(args[0], "@") | 	input := strings.Split(args[0], "@") | ||||||
| 	if len(input) != 2 { | 	var username string | ||||||
|  | 	var ipAddr string | ||||||
|  | 	if len(input) == 1 { | ||||||
|  | 		u, err := user.Current() | ||||||
|  | 		if err != nil { | ||||||
|  | 			c.Ui.Error(fmt.Sprintf("Error fetching username: '%s'", err)) | ||||||
|  | 		} | ||||||
|  | 		username = u.Username | ||||||
|  | 		ipAddr = input[0] | ||||||
|  | 	} else if len(input) == 2 { | ||||||
|  | 		username = input[0] | ||||||
|  | 		ipAddr = input[1] | ||||||
|  | 	} else { | ||||||
| 		c.Ui.Error(fmt.Sprintf("Invalid parameter: %s", args[0])) | 		c.Ui.Error(fmt.Sprintf("Invalid parameter: %s", args[0])) | ||||||
| 		return 2 | 		return 2 | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	username := input[0] | 	ip, err := net.ResolveIPAddr("ip", ipAddr) | ||||||
|  |  | ||||||
| 	ip, err := net.ResolveIPAddr("ip", input[1]) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.Ui.Error(fmt.Sprintf("Error resolving IP Address: %s", err)) | 		c.Ui.Error(fmt.Sprintf("Error resolving IP Address: %s", err)) | ||||||
| 		return 2 | 		return 2 | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if role == "" { | 	if role == "" { | ||||||
| 		role, err = setDefaultRole(client, ip.String()) | 		role, err = c.defaultRole(path, ip.String()) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			c.Ui.Error(fmt.Sprintf("Error setting default role: %s", err.Error())) | 			c.Ui.Error(fmt.Sprintf("Error setting default role: '%s'", err)) | ||||||
| 			return 1 | 			return 1 | ||||||
| 		} | 		} | ||||||
| 		c.Ui.Output(fmt.Sprintf("Vault SSH: Role:'%s'\n", role)) | 		c.Ui.Output(fmt.Sprintf("Vault SSH: Role:'%s'\n", role)) | ||||||
| @@ -70,12 +83,17 @@ func (c *SSHCommand) Run(args []string) int { | |||||||
| 		"ip":       ip.String(), | 		"ip":       ip.String(), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	keySecret, err := client.SSH().KeyCreate(role, data) | 	keySecret, err := client.SSH(path).Credential(role, data) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.Ui.Error(fmt.Sprintf("Error getting key for SSH session:%s", err)) | 		c.Ui.Error(fmt.Sprintf("Error getting key for SSH session:%s", err)) | ||||||
| 		return 2 | 		return 2 | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if noExec { | ||||||
|  | 		c.Ui.Output(fmt.Sprintf("IP:%s\nUsername: %s\nKey:%s\n", ip.String(), username, keySecret.Data["key"])) | ||||||
|  | 		return 0 | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if keySecret.Data["key_type"].(string) == ssh.KeyTypeDynamic { | 	if keySecret.Data["key_type"].(string) == ssh.KeyTypeDynamic { | ||||||
| 		sshDynamicKey := string(keySecret.Data["key"].(string)) | 		sshDynamicKey := string(keySecret.Data["key"].(string)) | ||||||
| 		if len(sshDynamicKey) == 0 { | 		if len(sshDynamicKey) == 0 { | ||||||
| @@ -87,9 +105,8 @@ func (c *SSHCommand) Run(args []string) int { | |||||||
| 		sshCmdArgs = append(sshCmdArgs, []string{"-i", sshDynamicKeyFileName}...) | 		sshCmdArgs = append(sshCmdArgs, []string{"-i", sshDynamicKeyFileName}...) | ||||||
|  |  | ||||||
| 	} else if keySecret.Data["key_type"].(string) == ssh.KeyTypeOTP { | 	} else if keySecret.Data["key_type"].(string) == ssh.KeyTypeOTP { | ||||||
| 		fmt.Printf("OTP for the session is %s\n", string(keySecret.Data["key"].(string))) | 		c.Ui.Output(fmt.Sprintf("OTP for the session is %s\n", string(keySecret.Data["key"].(string)))) | ||||||
| 	} else { | 	} else { | ||||||
| 		// Intentionally not mentioning the exact error |  | ||||||
| 		c.Ui.Error("Error creating key") | 		c.Ui.Error("Error creating key") | ||||||
| 	} | 	} | ||||||
| 	sshCmdArgs = append(sshCmdArgs, []string{"-p", port}...) | 	sshCmdArgs = append(sshCmdArgs, []string{"-p", port}...) | ||||||
| @@ -107,25 +124,30 @@ func (c *SSHCommand) Run(args []string) int { | |||||||
| 	if keySecret.Data["key_type"].(string) == ssh.KeyTypeDynamic { | 	if keySecret.Data["key_type"].(string) == ssh.KeyTypeDynamic { | ||||||
| 		err = os.Remove(sshDynamicKeyFileName) | 		err = os.Remove(sshDynamicKeyFileName) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			// Intentionally not mentioning the exact error | 			c.Ui.Error(fmt.Sprintf("Error deleting key file: %s", err)) | ||||||
| 			c.Ui.Error("Error cleaning up") |  | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err = client.SSH().KeyRevoke(keySecret.LeaseID) | 	err = client.Sys().Revoke(keySecret.LeaseID) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		// Intentionally not mentioning the exact error | 		c.Ui.Error(fmt.Sprintf("Error revoking the key: %s", err)) | ||||||
| 		c.Ui.Error("Error cleaning up") |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return 0 | 	return 0 | ||||||
| } | } | ||||||
|  |  | ||||||
| func setDefaultRole(client *api.Client, ip string) (string, error) { | // If user did not provide the role with which SSH connection has | ||||||
|  | // to be established and if there is only one role associated with | ||||||
|  | // the IP, it is used by default. | ||||||
|  | func (c *SSHCommand) defaultRole(path, ip string) (string, error) { | ||||||
| 	data := map[string]interface{}{ | 	data := map[string]interface{}{ | ||||||
| 		"ip": ip, | 		"ip": ip, | ||||||
| 	} | 	} | ||||||
| 	secret, err := client.Logical().Write("ssh/lookup", data) | 	client, err := c.Client() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 	secret, err := client.Logical().Write(path+"/lookup", data) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", fmt.Errorf("Error finding roles for IP '%s':%s", ip, err) | 		return "", fmt.Errorf("Error finding roles for IP '%s':%s", ip, err) | ||||||
|  |  | ||||||
| @@ -135,13 +157,18 @@ func setDefaultRole(client *api.Client, ip string) (string, error) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if secret.Data["roles"] == nil { | 	if secret.Data["roles"] == nil { | ||||||
| 		return "", fmt.Errorf("IP '%s' not registered under any role", ip) | 		return "", fmt.Errorf("No matching roles found for IP '%s'", ip) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if len(secret.Data["roles"].([]interface{})) == 1 { | 	if len(secret.Data["roles"].([]interface{})) == 1 { | ||||||
| 		return secret.Data["roles"].([]interface{})[0].(string), nil | 		return secret.Data["roles"].([]interface{})[0].(string), nil | ||||||
| 	} else { | 	} else { | ||||||
| 		return "", fmt.Errorf("Multiple roles for IP '%s'. Select one of '%s' using '-role' option", ip, secret.Data["roles"]) | 		var roleNames string | ||||||
|  | 		for _, item := range secret.Data["roles"].([]interface{}) { | ||||||
|  | 			roleNames += item.(string) + ", " | ||||||
|  | 		} | ||||||
|  | 		roleNames = strings.TrimRight(roleNames, ", ") | ||||||
|  | 		return "", fmt.Errorf("IP '%s' has multiple roles.\nSelect a role using '-role' option.\nPossible roles: [%s]\nNote that all roles may not be permitted, based on ACLs.", ip, roleNames) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -160,21 +187,31 @@ Usage: vault ssh [options] username@ip | |||||||
|   that SSH backend is mounted and at least one 'role' be registed |   that SSH backend is mounted and at least one 'role' be registed | ||||||
|   with vault at priori. |   with vault at priori. | ||||||
|  |  | ||||||
|  |   For setting up SSH backends with one-time-passwords, installation | ||||||
|  |   of agent in target machines is required.  | ||||||
|  |   See [https://github.com/hashicorp/vault-ssh-agent] | ||||||
|  |  | ||||||
| General Options: | General Options: | ||||||
|  |  | ||||||
|   ` + generalOptionsUsage() + ` |   ` + generalOptionsUsage() + ` | ||||||
|  |  | ||||||
| SSH Options: | SSH Options: | ||||||
|  |  | ||||||
|   -role                 Mention the role to be used to create dynamic key. |   -role                 Role to be used to create the key. | ||||||
|   			Each IP is associated with a role. To see the associated |   			Each IP is associated with a role. To see the associated | ||||||
| 			roles with IP, use "lookup" endpoint. If you are certain that | 			roles with IP, use "lookup" endpoint. If you are certain that | ||||||
| 			there is only one role associated with the IP, you can | 			there is only one role associated with the IP, you can | ||||||
| 			skip mentioning the role. It will be chosen by default. | 			skip mentioning the role. It will be chosen by default. | ||||||
| 			If there are no roless associated with the IP, register | 			If there are no roles associated with the IP, register | ||||||
| 			the CIDR block of that IP using the "roles/" endpoint. | 			the CIDR block of that IP using the "roles/" endpoint. | ||||||
|  |  | ||||||
|   -port                 Port number to use for SSH connection. This defaults to port 22. |   -port                 Port number to use for SSH connection. This defaults to port 22. | ||||||
|  |  | ||||||
|  |   -no-exec		Shows the credentials but does not establish connection. | ||||||
|  |  | ||||||
|  |   -path			Mount point of SSH backend. If the backend is mounted at | ||||||
|  |   			'ssh', which is the default as well, this parameter can | ||||||
|  | 			be skipped. | ||||||
| ` | ` | ||||||
| 	return strings.TrimSpace(helpText) | 	return strings.TrimSpace(helpText) | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 vishalnayak
					vishalnayak