mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 11:08:10 +00:00
Allow vault ssh to accept ssh commands in any ssh compatible format (#4710)
* Allow vault ssh to accept ssh commands in any ssh compatible format Previously vault ssh required ssh commands to be in the format `username@hostname <flags> command`. While this works just fine for human users this breaks a lot of automation workflows and is not compatible with the options that the ssh client supports. Motivation We currently run ansible which uses vault ssh to connect to hosts. Ansible generates ssh commands with the format `ssh <flags> -o User=username hostname command`. While this is a valid ssh command it currently breaks with vault because vault expects the format to be `username@hostname`. To work around this we currently use a wrapper script to parse the correct username being set by ansible and translate this into a vault ssh compatible `username@hostname` format Changes * You can now specify arguments in any order that ssh client allows. All arguments are passed directly to the ssh command and the format isn't modified in any way. * The username and port are parsed from the specified ssh command. It will accept all of the options supported by the ssh command and also will properly prefer `-p` and `user@` if both options are specified. * The ssh port is only added from the vault credentials if it hasn't been specified on the command line
This commit is contained in:
committed by
Jeff Mitchell
parent
b5c0f5b1c5
commit
caf3b94335
176
command/ssh.go
176
command/ssh.go
@@ -243,17 +243,27 @@ func (c *SSHCommand) Run(args []string) int {
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract the username and IP.
|
// Extract the hostname, username and port from the ssh command
|
||||||
username, hostname, ip, err := c.userHostAndIP(args[0])
|
hostname, username, port, err := c.parseSSHCommand(args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.UI.Error(fmt.Sprintf("Error parsing user and IP: %s", err))
|
c.UI.Error(fmt.Sprintf("Error parsing the ssh command: %q", err))
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// The rest of the args are ssh args
|
// Use the current user if no user was specified in the ssh command
|
||||||
sshArgs := []string{}
|
if username == "" {
|
||||||
if len(args) > 1 {
|
u, err := user.Current()
|
||||||
sshArgs = args[1:]
|
if err != nil {
|
||||||
|
c.UI.Error(fmt.Sprintf("Error getting the current user: %q", err))
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
username = u.Username
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, err := c.resolveHostname(hostname)
|
||||||
|
if err != nil {
|
||||||
|
c.UI.Error(fmt.Sprintf("Error resolving the ssh hostname: %q", err))
|
||||||
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the client in the command
|
// Set the client in the command
|
||||||
@@ -329,11 +339,11 @@ func (c *SSHCommand) Run(args []string) int {
|
|||||||
|
|
||||||
switch strings.ToLower(c.flagMode) {
|
switch strings.ToLower(c.flagMode) {
|
||||||
case ssh.KeyTypeCA:
|
case ssh.KeyTypeCA:
|
||||||
return c.handleTypeCA(username, hostname, ip, sshArgs)
|
return c.handleTypeCA(username, ip, port, args)
|
||||||
case ssh.KeyTypeOTP:
|
case ssh.KeyTypeOTP:
|
||||||
return c.handleTypeOTP(username, hostname, ip, sshArgs)
|
return c.handleTypeOTP(username, ip, port, args)
|
||||||
case ssh.KeyTypeDynamic:
|
case ssh.KeyTypeDynamic:
|
||||||
return c.handleTypeDynamic(username, ip, sshArgs)
|
return c.handleTypeDynamic(username, ip, port, args)
|
||||||
default:
|
default:
|
||||||
c.UI.Error(fmt.Sprintf("Unknown SSH mode: %s", c.flagMode))
|
c.UI.Error(fmt.Sprintf("Unknown SSH mode: %s", c.flagMode))
|
||||||
return 1
|
return 1
|
||||||
@@ -341,7 +351,7 @@ func (c *SSHCommand) Run(args []string) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleTypeCA is used to handle SSH logins using the "CA" key type.
|
// handleTypeCA is used to handle SSH logins using the "CA" key type.
|
||||||
func (c *SSHCommand) handleTypeCA(username, hostname, ip string, sshArgs []string) int {
|
func (c *SSHCommand) handleTypeCA(username, ip, port string, sshArgs []string) int {
|
||||||
// Read the key from disk
|
// Read the key from disk
|
||||||
publicKey, err := ioutil.ReadFile(c.flagPublicKeyPath)
|
publicKey, err := ioutil.ReadFile(c.flagPublicKeyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -460,10 +470,6 @@ func (c *SSHCommand) handleTypeCA(username, hostname, ip string, sshArgs []strin
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
args = append(args,
|
|
||||||
username+"@"+hostname,
|
|
||||||
)
|
|
||||||
|
|
||||||
// Add extra user defined ssh arguments
|
// Add extra user defined ssh arguments
|
||||||
args = append(args, sshArgs...)
|
args = append(args, sshArgs...)
|
||||||
|
|
||||||
@@ -493,7 +499,7 @@ func (c *SSHCommand) handleTypeCA(username, hostname, ip string, sshArgs []strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleTypeOTP is used to handle SSH logins using the "otp" key type.
|
// handleTypeOTP is used to handle SSH logins using the "otp" key type.
|
||||||
func (c *SSHCommand) handleTypeOTP(username, hostname string, ip string, sshArgs []string) int {
|
func (c *SSHCommand) handleTypeOTP(username, ip, port string, sshArgs []string) int {
|
||||||
secret, cred, err := c.generateCredential(username, ip)
|
secret, cred, err := c.generateCredential(username, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.UI.Error(fmt.Sprintf("failed to generate credential: %s", err))
|
c.UI.Error(fmt.Sprintf("failed to generate credential: %s", err))
|
||||||
@@ -543,10 +549,13 @@ func (c *SSHCommand) handleTypeOTP(username, hostname string, ip string, sshArgs
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If a port wasn't specified in the ssh arguments lets use the port we got back from vault
|
||||||
|
if port == "" {
|
||||||
|
args = append(args, "-p", cred.Port)
|
||||||
|
}
|
||||||
|
|
||||||
args = append(args,
|
args = append(args,
|
||||||
"-o StrictHostKeyChecking="+c.flagStrictHostKeyChecking,
|
"-o StrictHostKeyChecking="+c.flagStrictHostKeyChecking,
|
||||||
"-p", cred.Port,
|
|
||||||
username+"@"+hostname,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Add the rest of the ssh args appended by the user
|
// Add the rest of the ssh args appended by the user
|
||||||
@@ -585,7 +594,7 @@ func (c *SSHCommand) handleTypeOTP(username, hostname string, ip string, sshArgs
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleTypeDynamic is used to handle SSH logins using the "dyanmic" key type.
|
// handleTypeDynamic is used to handle SSH logins using the "dyanmic" key type.
|
||||||
func (c *SSHCommand) handleTypeDynamic(username, ip string, sshArgs []string) int {
|
func (c *SSHCommand) handleTypeDynamic(username, ip, port string, sshArgs []string) int {
|
||||||
// Generate the credential
|
// Generate the credential
|
||||||
secret, cred, err := c.generateCredential(username, ip)
|
secret, cred, err := c.generateCredential(username, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -610,13 +619,20 @@ func (c *SSHCommand) handleTypeDynamic(username, ip string, sshArgs []string) in
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
args := append([]string{
|
args := make([]string, 0)
|
||||||
|
// If a port wasn't specified in the ssh arguments lets use the port we got back from vault
|
||||||
|
if port == "" {
|
||||||
|
args = append(args, "-p", cred.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
args = append(args,
|
||||||
"-i", keyPath,
|
"-i", keyPath,
|
||||||
"-o UserKnownHostsFile=" + c.flagUserKnownHostsFile,
|
"-o UserKnownHostsFile="+c.flagUserKnownHostsFile,
|
||||||
"-o StrictHostKeyChecking=" + c.flagStrictHostKeyChecking,
|
"-o StrictHostKeyChecking="+c.flagStrictHostKeyChecking,
|
||||||
"-p", cred.Port,
|
)
|
||||||
username + "@" + ip,
|
|
||||||
}, sshArgs...)
|
// Add extra user defined ssh arguments
|
||||||
|
args = append(args, sshArgs...)
|
||||||
|
|
||||||
cmd := exec.Command("ssh", args...)
|
cmd := exec.Command("ssh", args...)
|
||||||
cmd.Stdin = os.Stdin
|
cmd.Stdin = os.Stdin
|
||||||
@@ -745,37 +761,95 @@ func (c *SSHCommand) defaultRole(mountPoint, ip string) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// userAndIP takes an argument in the format foo@1.2.3.4 and separates the IP
|
// Finds the hostname, username (optional) and port (optional) from any valid ssh command
|
||||||
// and user parts, returning any errors.
|
// Supports usrname@hostname but also specifying valid ssh flags like -o User=username,
|
||||||
func (c *SSHCommand) userHostAndIP(s string) (string, string, string, error) {
|
// -o Port=2222 and -p 2222 anywhere in the command
|
||||||
// split the parameter username@ip
|
func (c *SSHCommand) parseSSHCommand(args []string) (hostname string, username string, port string, err error) {
|
||||||
input := strings.Split(s, "@")
|
lastArg := ""
|
||||||
var username, address string
|
|
||||||
|
|
||||||
// If only IP is mentioned and username is skipped, assume username to
|
for _, i := range args {
|
||||||
// be the current username. Vault SSH role's default username could have
|
arg := lastArg
|
||||||
// been used, but in order to retain the consistency with SSH command,
|
lastArg = ""
|
||||||
// current username is employed.
|
|
||||||
switch len(input) {
|
// If -p has been specified then this is our ssh port
|
||||||
case 1:
|
if arg == "-p" {
|
||||||
u, err := user.Current()
|
port = i
|
||||||
if err != nil {
|
continue
|
||||||
return "", "", "", errors.Wrap(err, "failed to fetch current user")
|
|
||||||
}
|
}
|
||||||
username, address = u.Username, input[0]
|
|
||||||
case 2:
|
|
||||||
username, address = input[0], input[1]
|
|
||||||
default:
|
|
||||||
return "", "", "", fmt.Errorf("invalid arguments: %q", s)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// this is an ssh option, lets see if User or Port have been set and use it
|
||||||
|
if arg == "-o" {
|
||||||
|
split := strings.Split(i, "=")
|
||||||
|
key := split[0]
|
||||||
|
// Incase the value contains = signs we want to get all of them
|
||||||
|
value := strings.Join(split[1:], " ")
|
||||||
|
|
||||||
|
if key == "User" {
|
||||||
|
// Don't overwrite the user if it is already set by username@hostname
|
||||||
|
// This matches the behaviour for how regular ssh reponds when both are specified
|
||||||
|
if username == "" {
|
||||||
|
username = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if key == "Port" {
|
||||||
|
// Don't overwrite the port if it is already set by -p
|
||||||
|
// This matches the behaviour for how regular ssh reponds when both are specified
|
||||||
|
if port == "" {
|
||||||
|
port = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// This isn't an ssh argument that we care about. Lets keep on parsing the command
|
||||||
|
if arg != "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this is an ssh argument we want to look at the value
|
||||||
|
if strings.HasPrefix(i, "-") {
|
||||||
|
lastArg = i
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have gotten this far it means this is a bare argument
|
||||||
|
// The first bare argument is the hostname
|
||||||
|
// The second bare argument is the command to run on the remote host
|
||||||
|
|
||||||
|
// If the hostname hasn't been set yet than it means we have found the first bare argument
|
||||||
|
if hostname == "" {
|
||||||
|
if strings.Contains(i, "@") {
|
||||||
|
split := strings.Split(i, "@")
|
||||||
|
username = split[0]
|
||||||
|
hostname = split[1]
|
||||||
|
} else {
|
||||||
|
hostname = i
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
// The second bare argument is the command to run on the remote host.
|
||||||
|
// We need to break out and stop parsing arugments now
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
if hostname == "" {
|
||||||
|
return "", "", "", errors.Wrap(
|
||||||
|
err,
|
||||||
|
fmt.Sprintf("failed to find a hostname in ssh command %q", strings.Join(args, " ")),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return hostname, username, port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SSHCommand) resolveHostname(hostname string) (ip string, err error) {
|
||||||
// Resolving domain names to IP address on the client side.
|
// Resolving domain names to IP address on the client side.
|
||||||
// Vault only deals with IP addresses.
|
// Vault only deals with IP addresses.
|
||||||
ipAddr, err := net.ResolveIPAddr("ip", address)
|
ipAddr, err := net.ResolveIPAddr("ip", hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", errors.Wrap(err, "failed to resolve IP address")
|
return "", errors.Wrap(err, "failed to resolve IP address")
|
||||||
}
|
}
|
||||||
ip := ipAddr.String()
|
ip = ipAddr.String()
|
||||||
|
return ip, nil
|
||||||
return username, address, ip, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,3 +21,136 @@ func TestSSHCommand_Run(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
t.Skip("Need a way to setup target infrastructure")
|
t.Skip("Need a way to setup target infrastructure")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseSSHCommand(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, cmd := testSSHCommand(t)
|
||||||
|
var tests = []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
hostname string
|
||||||
|
username string
|
||||||
|
port string
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"Parse just a hostname",
|
||||||
|
[]string{
|
||||||
|
"hostname",
|
||||||
|
},
|
||||||
|
"hostname",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Parse the standard username@hostname",
|
||||||
|
[]string{
|
||||||
|
"username@hostname",
|
||||||
|
},
|
||||||
|
"hostname",
|
||||||
|
"username",
|
||||||
|
"",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Parse the username out of -o User=username",
|
||||||
|
[]string{
|
||||||
|
"-o", "User=username",
|
||||||
|
"hostname",
|
||||||
|
},
|
||||||
|
"hostname",
|
||||||
|
"username",
|
||||||
|
"",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"If the username is specified with -o User=username and realname@hostname prefer realname@",
|
||||||
|
[]string{
|
||||||
|
"-o", "User=username",
|
||||||
|
"realname@hostname",
|
||||||
|
},
|
||||||
|
"hostname",
|
||||||
|
"realname",
|
||||||
|
"",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Parse the port out of -o Port=2222",
|
||||||
|
[]string{
|
||||||
|
"-o", "Port=2222",
|
||||||
|
"hostname",
|
||||||
|
},
|
||||||
|
"hostname",
|
||||||
|
"",
|
||||||
|
"2222",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Parse the port out of -p 2222",
|
||||||
|
[]string{
|
||||||
|
"-p", "2222",
|
||||||
|
"hostname",
|
||||||
|
},
|
||||||
|
"hostname",
|
||||||
|
"",
|
||||||
|
"2222",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"If port is defined with -o Port=2222 and -p 2244 prefer -p",
|
||||||
|
[]string{
|
||||||
|
"-p", "2244",
|
||||||
|
"-o", "Port=2222",
|
||||||
|
"hostname",
|
||||||
|
},
|
||||||
|
"hostname",
|
||||||
|
"",
|
||||||
|
"2244",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Ssh args with a command",
|
||||||
|
[]string{
|
||||||
|
"hostname",
|
||||||
|
"command",
|
||||||
|
},
|
||||||
|
"hostname",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Flags after the ssh command are not pased because they are part of the command",
|
||||||
|
[]string{
|
||||||
|
"username@hostname",
|
||||||
|
"command",
|
||||||
|
"-p 22",
|
||||||
|
},
|
||||||
|
"hostname",
|
||||||
|
"username",
|
||||||
|
"",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
hostname, username, port, err := cmd.parseSSHCommand(test.args)
|
||||||
|
if err != test.err {
|
||||||
|
t.Errorf("got error: %q want %q", err, test.err)
|
||||||
|
}
|
||||||
|
if hostname != test.hostname {
|
||||||
|
t.Errorf("got hostname: %q want %q", hostname, test.hostname)
|
||||||
|
}
|
||||||
|
if username != test.username {
|
||||||
|
t.Errorf("got username: %q want %q", username, test.username)
|
||||||
|
}
|
||||||
|
if port != test.port {
|
||||||
|
t.Errorf("got port: %q want %q", port, test.port)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user