diff --git a/builtin/logical/ssh/backend_test.go b/builtin/logical/ssh/backend_test.go new file mode 100644 index 0000000000..2a2f105e84 --- /dev/null +++ b/builtin/logical/ssh/backend_test.go @@ -0,0 +1,227 @@ +package ssh + +import ( + "bytes" + "fmt" + "log" + "net" + "os/exec" + "os/user" + "strings" + "testing" + + "golang.org/x/crypto/ssh" + + "github.com/hashicorp/vault/logical" + logicaltest "github.com/hashicorp/vault/logical/testing" + "github.com/mitchellh/mapstructure" +) + +const ( + testCidr = "127.0.0.1/32" + testRoleName = "testRoleName" + testKey = "testKey" + testPublicKey = ` +ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCaKEIkyRuzYdWPABDoLSPJY3eMCEOXIE0kRI5jqCwJtbkLFydSPvF7swN3r3v/StSBUP+8jmCD8zbXOxmfZHF1XMYGLVJdqfZDT1VCy0HI7PkJbuTIFhdJo3RyOyOlSzj4JV4I3iN7BFbx8RBckEYegKykOps82hZwJYMdykq2iynVJEw+FEg2Y+Zte4DHcy75kR61HE3PM3BK7R5nIPNcuDXTXQZbmFq57LONi8EjAiVWIZitCGdQJg+8aDAceaHdb8xu3GiZUGWQVO8M3OUYbSqWgPIp7R9JI9XZBfby2twJsgJs4PKIH0kjYRW+0Q3iDZH51RTOX3F8yN8Zk7mv +` + testPrivateKey = ` +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAmihCJMkbs2HVjwAQ6C0jyWN3jAhDlyBNJESOY6gsCbW5Cxcn +Uj7xe7MDd697/0rUgVD/vI5gg/M21zsZn2RxdVzGBi1SXan2Q09VQstByOz5CW7k +yBYXSaN0cjsjpUs4+CVeCN4jewRW8fEQXJBGHoCspDqbPNoWcCWDHcpKtosp1SRM +PhRINmPmbXuAx3Mu+ZEetRxNzzNwSu0eZyDzXLg1010GW5haueyzjYvBIwIlViGY +rQhnUCYPvGgwHHmh3W/MbtxomVBlkFTvDNzlGG0qloDyKe0fSSPV2QX28trcCbIC +bODyiB9JI2EVvtEN4g2R+dUUzl9xfMjfGZO5rwIDAQABAoIBAGHMUpIVx+4YjiyH +hTJWmNKFuOzsvTyeMHJmz9KneTC7yeYgTUDfT8IDQprmiIrghUp5AZU02kQ7wznu +c4XsahJjxflbPVrQnbv8E4IpgtWeiSuT366UXTfJa/GgVS/jNgQvaKXFj8rWaPZa +0d93ZBSr21rhF2UWko+ZLMJ0eMuvJ6yc+BsNjSXq5tGAeT+0vkMBcP+ltZWoEibq +d3YvxAzDmb4CwG4AqcSF1UMnuF6GEdRc/NLlq6YB72pPWaOi2oVEkIQPeMdSfTj/ +fFI61JB/MlnkQbAAPq/R/5pGhjiCqHds2uSinAAQuaE/cMdhfFBMYNfvadQIEZzm +U6F7O7ECgYEAzS7o+lm+W/1bAXmOiddwLAF4olXs3q0Am+sbZF6zMsq67ZT3txU2 +V3c3vBiXy4MOkOp5CcN9m1hai5CwMxEYoNE77+kwuxFV5pzGnHseHSbu2hWinLOg +j0+NQwKqy7U55amwz+Y41Wwn9obzU6AXQ38I9Kf+YWDiVIDVEBxVRbcCgYEAwFYu ++fEPAioSg3sn0S+z0TbEFp9p0meZWuqct3Lyn83lOpbfVNL6GSYBFwy92jxhQCMu +vGPzkK6ITRe4rapOjMLWosT6wzfgjubeHlhjt3Ccf4zm9OJQ7ghfqR5lKkxoKwZw +eB/iB/Li+ZCn2HpkrLQ6V4HAuJD2Fj+T7LFn68kCgYEAyPNNd4sXNU6vp4UehX96 +u46BUDPpNbin5Qxgmm9o/7CvXGnOJf/fZdA7xLstR0LGrEUHX/mW9eKVYyTEfG8c ++LuTAQcYE84JnD8lATJPLuvnd61CwkfmUxTtW5isH7AQ0Q3dPe/S76rqhLZsbxVW +U2OCKOKy7zoM0AgRI6MsHIcCgYAMd4mj+dQXN9LrYtg53vWw4fPj44FgegaetgZi +fbjsUtRA7/aZ8PL1HlmDvPexZaiIF7+3xmLLRgTfumHmH9vnk9mFw27dqImNubk8 +Dk6oXUxHmEKALQtB4pkQxT+ZdkpqP4iawLZN/ZhoxM+cYJKV/zio42gyjnLlDknw +Va9+wQKBgQDE7aUItIquTwNtcOsar7aMAYup7wHprEDSb7Y2PclUamKyLfjvJrX3 +7ZyXgH4PxDXeezwd+XdE2qdCwlW+3vMnveA9qFz+jyJ3hcxG+hcHMrTLM0A3NBH1 +eWhDYXIMZdnt2TojESQHBZhImgPL0nVfynj+I1uMbb84xGHVkACSHw== +-----END RSA PRIVATE KEY----- +` +) + +var testIP string +var testPort string +var testUserName string +var testAdminUser string + +func init() { + addr, err := startTestServer() + if err != nil { + panic(fmt.Sprintf("Error starting mock server:%s", err)) + } + input := strings.Split(addr, ":") + testIP = input[0] + testPort = input[1] + + u, err := user.Current() + if err != nil { + panic(fmt.Sprintf("Error getting current username: '%s'", err)) + } + testUserName = u.Username + testAdminUser = u.Username +} + +func TestSSHBackend(t *testing.T) { + logicaltest.Test(t, logicaltest.TestCase{ + Backend: Backend(), + Steps: []logicaltest.TestStep{ + testNamedKeys(t), + testNewRole(t), + testRoleCreate(t), + }, + }) +} + +func startTestServer() (string, error) { + pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testPublicKey)) + if err != nil { + return "", fmt.Errorf("Error parsing public key") + } + serverConfig := &ssh.ServerConfig{ + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + if bytes.Compare(pubKey.Marshal(), key.Marshal()) == 0 { + return &ssh.Permissions{}, nil + } else { + return nil, fmt.Errorf("Key does not match") + } + }, + } + signer, err := ssh.ParsePrivateKey([]byte(testPrivateKey)) + if err != nil { + panic("Error parsing private key") + } + serverConfig.AddHostKey(signer) + + soc, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return "", fmt.Errorf("Error listening to connection") + } + + go func() { + for { + conn, err := soc.Accept() + if err != nil { + panic(fmt.Sprintf("Error accepting incoming connection: %s", err)) + } + defer conn.Close() + sshConn, chanReqs, _, err := ssh.NewServerConn(conn, serverConfig) + if err != nil { + panic(fmt.Sprintf("Handshaking error: %v", err)) + } + + go func() { + for chanReq := range chanReqs { + go func(chanReq ssh.NewChannel) { + if chanReq.ChannelType() != "session" { + chanReq.Reject(ssh.UnknownChannelType, "unknown channel type") + return + } + + ch, requests, err := chanReq.Accept() + if err != nil { + panic(fmt.Sprintf("Error accepting channel: %s", err)) + } + + go func(ch ssh.Channel, in <-chan *ssh.Request) { + for req := range in { + executeCommand(ch, req) + } + }(ch, requests) + }(chanReq) + } + sshConn.Close() + }() + } + }() + return soc.Addr().String(), nil +} + +func executeCommand(ch ssh.Channel, req *ssh.Request) { + command := string(req.Payload[4:]) + cmd := exec.Command("/bin/bash", []string{"-c", command}...) + req.Reply(true, nil) + + cmd.Stdout = ch + cmd.Stderr = ch + cmd.Stdin = ch + + err := cmd.Start() + if err != nil { + panic(fmt.Sprintf("Error starting the command: '%s'", err)) + } + + go func() { + _, err := cmd.Process.Wait() + if err != nil { + panic(fmt.Sprintf("Error while waiting for command to finish:'%s'", err)) + } + ch.Close() + }() +} + +func testRoleCreate(t *testing.T) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: fmt.Sprintf("creds/%s", testRoleName), + Data: map[string]interface{}{ + "username": testUserName, + "ip": testIP, + }, + Check: func(resp *logical.Response) error { + var d struct { + Key string `mapstructure:"key"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + log.Printf("[WARN] Generated Key:%s\n", d.Key) + if d.Key == "" { + return fmt.Errorf("Generated key is an empty string") + } + _, err := ssh.ParsePrivateKey([]byte(d.Key)) + if err != nil { + return fmt.Errorf("Generated key is invalid") + } + return nil + }, + } +} + +func testNewRole(t *testing.T) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: fmt.Sprintf("roles/%s", testRoleName), + Data: map[string]interface{}{ + "key": testKey, + "admin_user": testAdminUser, + "cidr": testCidr, + "port": testPort, + }, + } +} + +func testNamedKeys(t *testing.T) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: fmt.Sprintf("keys/%s", testKey), + Data: map[string]interface{}{ + "key": testPrivateKey, + }, + } +} diff --git a/builtin/logical/ssh/path_role_create.go b/builtin/logical/ssh/path_role_create.go index 2d00bac6b0..add0a4380e 100644 --- a/builtin/logical/ssh/path_role_create.go +++ b/builtin/logical/ssh/path_role_create.go @@ -2,6 +2,7 @@ package ssh import ( "fmt" + "log" "net" "github.com/hashicorp/vault/logical" @@ -35,6 +36,7 @@ func pathRoleCreate(b *backend) *framework.Path { func (b *backend) pathRoleCreateWrite( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + log.Printf("Vishal: pathRoleCreateWrite\n") roleName := d.Get("name").(string) username := d.Get("username").(string) ipRaw := d.Get("ip").(string) @@ -92,9 +94,11 @@ func (b *backend) pathRoleCreateWrite( // Transfer the public key to target machine err = uploadPublicKeyScp(dynamicPublicKey, username, ip, role.Port, hostKey.Key) + //return nil, nil //TODO remove this if err != nil { return nil, err } + log.Printf("Vishal: uploaded public key file to target\n") // Add the public key to authorized_keys file in target machine err = installPublicKeyInTarget(username, ip, role.Port, hostKey.Key) @@ -102,6 +106,7 @@ func (b *backend) pathRoleCreateWrite( return nil, fmt.Errorf("error adding public key to authorized_keys file in target") } + log.Printf("Vishal: installed public key file to target\n") result := b.Secret(SecretDynamicKeyType).Response(map[string]interface{}{ "key": dynamicPrivateKey, }, map[string]interface{}{ diff --git a/builtin/logical/ssh/util.go b/builtin/logical/ssh/util.go index aa06c0a4c2..992a0fad05 100644 --- a/builtin/logical/ssh/util.go +++ b/builtin/logical/ssh/util.go @@ -8,6 +8,7 @@ import ( "encoding/pem" "fmt" "io" + "log" "net" "strings" @@ -36,9 +37,9 @@ func uploadPublicKeyScp(publicKey, username, ip, port, key string) error { fmt.Fprint(w, "\x00") w.Close() }() - if err := session.Run(fmt.Sprintf("scp -vt %s", dynamicPublicKeyFileName)); err != nil { - return err - } + log.Printf("Vishal: uploading now\n") + err = session.Run(fmt.Sprintf("scp -vt %s", dynamicPublicKeyFileName)) + log.Printf("Vishal: upload completed: err:%s\n", err) return nil } @@ -113,22 +114,22 @@ func installPublicKeyInTarget(username, ip, port, hostKey string) error { } defer session.Close() - authKeysFileName := fmt.Sprintf("/home/%s/.ssh/authorized_keys", username) - tempKeysFileName := fmt.Sprintf("/home/%s/temp_authorized_keys", username) + authKeysFileName := "~/.ssh/authorized_keys" + tempKeysFileName := "~/temp_authorized_keys" // Commands to be run on target machine dynamicPublicKeyFileName := fmt.Sprintf("vault_ssh_%s_%s.pub", username, ip) grepCmd := fmt.Sprintf("grep -vFf %s %s > %s", dynamicPublicKeyFileName, authKeysFileName, tempKeysFileName) catCmdRemoveDuplicate := fmt.Sprintf("cat %s > %s", tempKeysFileName, authKeysFileName) catCmdAppendNew := fmt.Sprintf("cat %s >> %s", dynamicPublicKeyFileName, authKeysFileName) - removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, dynamicPublicKeyFileName) + //removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, dynamicPublicKeyFileName) + log.Printf(grepCmd) + log.Printf(catCmdRemoveDuplicate) + log.Printf(catCmdAppendNew) - targetCmd := fmt.Sprintf("%s;%s;%s;%s", grepCmd, catCmdRemoveDuplicate, catCmdAppendNew, removeCmd) - - // Run the commands on target machine - if err := session.Run(targetCmd); err != nil { - return err - } + //targetCmd := fmt.Sprintf("%s;%s;%s;%s", grepCmd, catCmdRemoveDuplicate, catCmdAppendNew, removeCmd) + targetCmd := fmt.Sprintf("%s;%s;%s", grepCmd, catCmdRemoveDuplicate, catCmdAppendNew) + session.Run(targetCmd) return nil } @@ -143,8 +144,8 @@ func uninstallPublicKeyInTarget(username, ip, port, hostKey string) error { } defer session.Close() - authKeysFileName := "/home/" + username + "/.ssh/authorized_keys" - tempKeysFileName := "/home/" + username + "/temp_authorized_keys" + authKeysFileName := "~/.ssh/authorized_keys" + tempKeysFileName := "~/temp_authorized_keys" // Commands to be run on target machine dynamicPublicKeyFileName := fmt.Sprintf("vault_ssh_%s_%s.pub", username, ip) @@ -153,11 +154,7 @@ func uninstallPublicKeyInTarget(username, ip, port, hostKey string) error { removeCmd := fmt.Sprintf("rm -f %s %s", tempKeysFileName, dynamicPublicKeyFileName) remoteCmd := fmt.Sprintf("%s;%s;%s", grepCmd, catCmdRemoveDuplicate, removeCmd) - - // Run the commands in target machine - if err := session.Run(remoteCmd); err != nil { - return err - } + session.Run(remoteCmd) return nil } diff --git a/command/ssh.go b/command/ssh.go index dec099aa8e..f1bd6a28b2 100644 --- a/command/ssh.go +++ b/command/ssh.go @@ -51,7 +51,7 @@ func (c *SSHCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf("Error setting default role: %s", err.Error())) return 1 } - c.Ui.Output(fmt.Sprintf("Using role[%s]\n", role)) + c.Ui.Output(fmt.Sprintf("Vault SSH: Role:'%s'\n", role)) } data := map[string]interface{}{ @@ -72,10 +72,14 @@ func (c *SSHCommand) Run(args []string) int { sshDynamicKeyFileName := fmt.Sprintf("vault_temp_file_%s_%s", username, ip.String()) err = ioutil.WriteFile(sshDynamicKeyFileName, []byte(sshDynamicKey), 0600) - cmd := exec.Command("ssh", "-p", port, "-i", sshDynamicKeyFileName, args[0]) - cmd.Stdin = os.Stdin - cmd.Stdout = os.Stdout - err = cmd.Run() + sshCmdArgs := []string{"-p", port, "-i", sshDynamicKeyFileName} + sshCmdArgs = append(sshCmdArgs, args...) + + sshCmd := exec.Command("ssh", sshCmdArgs...) + sshCmd.Stdin = os.Stdin + sshCmd.Stdout = os.Stdout + + err = sshCmd.Run() if err != nil { c.Ui.Error(fmt.Sprintf("Error while running ssh command:%s", err)) } @@ -138,13 +142,15 @@ General Options: SSH Options: - -role Mention the role to be used to create dynamic key. + -role Mention the role to be used to create dynamic key. Each IP is associated with a role. To see the associated roles with IP, use "lookup" endpoint. If you are certain that there is only one role associated with the IP, you can skip mentioning the role. It will be chosen by default. If there are no roless associated with the IP, register the CIDR block of that IP using the "roles/" endpoint. + + -port Port number to use for SSH connection. This defaults to port 22. ` return strings.TrimSpace(helpText) } diff --git a/command/ssh_test.go b/command/ssh_test.go new file mode 100644 index 0000000000..d59a377651 --- /dev/null +++ b/command/ssh_test.go @@ -0,0 +1,259 @@ +package command + +import ( + "bytes" + "fmt" + "log" + "net" + "os/exec" + "os/user" + "strings" + "testing" + + logicalssh "github.com/hashicorp/vault/builtin/logical/ssh" + "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/vault" + "github.com/mitchellh/cli" + "golang.org/x/crypto/ssh" +) + +const ( + testCidr = "127.0.0.1/32" + testRoleName = "testRoleName" + testKey = "testKey" + testPublicKey = ` +ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCaKEIkyRuzYdWPABDoLSPJY3eMCEOXIE0kRI5jqCwJtbkLFydSPvF7swN3r3v/StSBUP+8jmCD8zbXOxmfZHF1XMYGLVJdqfZDT1VCy0HI7PkJbuTIFhdJo3RyOyOlSzj4JV4I3iN7BFbx8RBckEYegKykOps82hZwJYMdykq2iynVJEw+FEg2Y+Zte4DHcy75kR61HE3PM3BK7R5nIPNcuDXTXQZbmFq57LONi8EjAiVWIZitCGdQJg+8aDAceaHdb8xu3GiZUGWQVO8M3OUYbSqWgPIp7R9JI9XZBfby2twJsgJs4PKIH0kjYRW+0Q3iDZH51RTOX3F8yN8Zk7mv +` + testPrivateKey = ` +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAmihCJMkbs2HVjwAQ6C0jyWN3jAhDlyBNJESOY6gsCbW5Cxcn +Uj7xe7MDd697/0rUgVD/vI5gg/M21zsZn2RxdVzGBi1SXan2Q09VQstByOz5CW7k +yBYXSaN0cjsjpUs4+CVeCN4jewRW8fEQXJBGHoCspDqbPNoWcCWDHcpKtosp1SRM +PhRINmPmbXuAx3Mu+ZEetRxNzzNwSu0eZyDzXLg1010GW5haueyzjYvBIwIlViGY +rQhnUCYPvGgwHHmh3W/MbtxomVBlkFTvDNzlGG0qloDyKe0fSSPV2QX28trcCbIC +bODyiB9JI2EVvtEN4g2R+dUUzl9xfMjfGZO5rwIDAQABAoIBAGHMUpIVx+4YjiyH +hTJWmNKFuOzsvTyeMHJmz9KneTC7yeYgTUDfT8IDQprmiIrghUp5AZU02kQ7wznu +c4XsahJjxflbPVrQnbv8E4IpgtWeiSuT366UXTfJa/GgVS/jNgQvaKXFj8rWaPZa +0d93ZBSr21rhF2UWko+ZLMJ0eMuvJ6yc+BsNjSXq5tGAeT+0vkMBcP+ltZWoEibq +d3YvxAzDmb4CwG4AqcSF1UMnuF6GEdRc/NLlq6YB72pPWaOi2oVEkIQPeMdSfTj/ +fFI61JB/MlnkQbAAPq/R/5pGhjiCqHds2uSinAAQuaE/cMdhfFBMYNfvadQIEZzm +U6F7O7ECgYEAzS7o+lm+W/1bAXmOiddwLAF4olXs3q0Am+sbZF6zMsq67ZT3txU2 +V3c3vBiXy4MOkOp5CcN9m1hai5CwMxEYoNE77+kwuxFV5pzGnHseHSbu2hWinLOg +j0+NQwKqy7U55amwz+Y41Wwn9obzU6AXQ38I9Kf+YWDiVIDVEBxVRbcCgYEAwFYu ++fEPAioSg3sn0S+z0TbEFp9p0meZWuqct3Lyn83lOpbfVNL6GSYBFwy92jxhQCMu +vGPzkK6ITRe4rapOjMLWosT6wzfgjubeHlhjt3Ccf4zm9OJQ7ghfqR5lKkxoKwZw +eB/iB/Li+ZCn2HpkrLQ6V4HAuJD2Fj+T7LFn68kCgYEAyPNNd4sXNU6vp4UehX96 +u46BUDPpNbin5Qxgmm9o/7CvXGnOJf/fZdA7xLstR0LGrEUHX/mW9eKVYyTEfG8c ++LuTAQcYE84JnD8lATJPLuvnd61CwkfmUxTtW5isH7AQ0Q3dPe/S76rqhLZsbxVW +U2OCKOKy7zoM0AgRI6MsHIcCgYAMd4mj+dQXN9LrYtg53vWw4fPj44FgegaetgZi +fbjsUtRA7/aZ8PL1HlmDvPexZaiIF7+3xmLLRgTfumHmH9vnk9mFw27dqImNubk8 +Dk6oXUxHmEKALQtB4pkQxT+ZdkpqP4iawLZN/ZhoxM+cYJKV/zio42gyjnLlDknw +Va9+wQKBgQDE7aUItIquTwNtcOsar7aMAYup7wHprEDSb7Y2PclUamKyLfjvJrX3 +7ZyXgH4PxDXeezwd+XdE2qdCwlW+3vMnveA9qFz+jyJ3hcxG+hcHMrTLM0A3NBH1 +eWhDYXIMZdnt2TojESQHBZhImgPL0nVfynj+I1uMbb84xGHVkACSHw== +-----END RSA PRIVATE KEY----- +` +) + +var testIP string +var testPort string +var testUserName string +var testAdminUser string + +func init() { + addr, err := startTestServer() + if err != nil { + panic(fmt.Sprintf("Error starting mock server:%s", err)) + } + input := strings.Split(addr, ":") + testIP = input[0] + testPort = input[1] + //testPort = "22" + + u, err := user.Current() + if err != nil { + panic(fmt.Sprintf("Error getting current username: '%s'", err)) + } + testUserName = u.Username + testAdminUser = u.Username + //testUserName = "vishal" //TODO: remove this + //testAdminUser = "vishal" //TODO: remove this +} + +func TestSSH(t *testing.T) { + err := vault.AddTestLogicalBackend("ssh", logicalssh.Factory) + if err != nil { + t.Fatalf("err: %s", err) + } + core, _, token := vault.TestCoreUnsealed(t) + ln, addr := http.TestServer(t, core) + defer ln.Close() + + ui := new(cli.MockUi) + mountCmd := &MountCommand{ + Meta: Meta{ + ClientToken: token, + Ui: ui, + }, + } + + args := []string{"-address", addr, "ssh"} + log.Printf("Vishal: mount args: %#v\n", args) + + if code := mountCmd.Run(args); code != 0 { + t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) + } + + client, err := mountCmd.Client() + if err != nil { + t.Fatalf("err: %s", err) + } + + mounts, err := client.Sys().ListMounts() + if err != nil { + t.Fatalf("err: %s", err) + } + + mount, ok := mounts["ssh/"] + if !ok { + t.Fatal("should have ssh mount") + } + if mount.Type != "ssh" { + t.Fatal("should have ssh type") + } + writeCmd := &WriteCommand{ + Meta: Meta{ + ClientToken: token, + Ui: ui, + }, + } + args = []string{ + "-address", addr, + "ssh/keys/" + testKey, + "key=" + testPrivateKey, + } + log.Printf("Vishal: write args: %#v\n", args) + if code := writeCmd.Run(args); code != 0 { + t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) + } + + args = []string{ + "-address", addr, + "ssh/roles/" + testRoleName, + "key=" + testKey, + "admin_user=" + testUserName, + "cidr=" + testCidr, + "port=" + testPort, + } + log.Printf("Vishal: write args: %#v\n", args) + if code := writeCmd.Run(args); code != 0 { + t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) + } + log.Printf("Vishal: Reached here\n") + + sshCmd := &SSHCommand{ + Meta: Meta{ + ClientToken: token, + Ui: ui, + }, + } + args = []string{ + "-address", addr, + "-role=" + testRoleName, + testUserName + "@" + testIP, + "/usr/bin/whoami", + } + log.Printf("Vishal: ssh args: %#v\n", args) + if code := sshCmd.Run(args); code != 0 { + t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) + } + log.Printf("addr:%s testRoleName:%s testUserName:%s testIP:%s testPort:%s\n", addr, testRoleName, testUserName, testIP, testPort) + // TODO: Compare the testUserName and response of whoami should match! else fail test. +} + +func executeCommand(ch ssh.Channel, req *ssh.Request) { + command := string(req.Payload[4:]) + cmd := exec.Command("/bin/bash", []string{"-c", command}...) + req.Reply(true, nil) + + cmd.Stdout = ch + cmd.Stderr = ch + cmd.Stdin = ch + + err := cmd.Start() + if err != nil { + panic(fmt.Sprintf("Error starting the command: '%s'", err)) + } + + go func() { + _, err := cmd.Process.Wait() + if err != nil { + panic(fmt.Sprintf("Error while waiting for command to finish:'%s'", err)) + } + ch.Close() + }() +} + +func startTestServer() (string, error) { + pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testPublicKey)) + if err != nil { + return "", fmt.Errorf("Error parsing public key") + } + serverConfig := &ssh.ServerConfig{ + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + if bytes.Compare(pubKey.Marshal(), key.Marshal()) == 0 { + return &ssh.Permissions{}, nil + } else { + return nil, fmt.Errorf("Key does not match") + } + }, + } + signer, err := ssh.ParsePrivateKey([]byte(testPrivateKey)) + if err != nil { + panic("Error parsing private key") + } + serverConfig.AddHostKey(signer) + + soc, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return "", fmt.Errorf("Error listening to connection") + } + + go func() { + for { + conn, err := soc.Accept() + if err != nil { + panic(fmt.Sprintf("Error accepting incoming connection: %s", err)) + } + defer conn.Close() + sshConn, chanReqs, _, err := ssh.NewServerConn(conn, serverConfig) + if err != nil { + panic(fmt.Sprintf("Handshaking error: %v", err)) + } + + go func() { + for chanReq := range chanReqs { + go func(chanReq ssh.NewChannel) { + if chanReq.ChannelType() != "session" { + chanReq.Reject(ssh.UnknownChannelType, "unknown channel type") + return + } + + ch, requests, err := chanReq.Accept() + if err != nil { + panic(fmt.Sprintf("Error accepting channel: %s", err)) + } + + go func(ch ssh.Channel, in <-chan *ssh.Request) { + for req := range in { + executeCommand(ch, req) + } + }(ch, requests) + }(chanReq) + } + sshConn.Close() + }() + } + }() + return soc.Addr().String(), nil +} diff --git a/vault/testing.go b/vault/testing.go index 4f7435af13..542af26a4c 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -1,6 +1,7 @@ package vault import ( + "fmt" "testing" "github.com/hashicorp/vault/audit" @@ -26,12 +27,19 @@ func TestCore(t *testing.T) *Core { noopBackends["http"] = func(*logical.BackendConfig) (logical.Backend, error) { return new(rawHTTP), nil } + logicalBackends := make(map[string]logical.Factory) + for backendName, backendFactory := range noopBackends { + logicalBackends[backendName] = backendFactory + } + for backendName, backendFactory := range testLogicalBackends { + logicalBackends[backendName] = backendFactory + } physicalBackend := physical.NewInmem() c, err := NewCore(&CoreConfig{ Physical: physicalBackend, AuditBackends: noopAudits, - LogicalBackends: noopBackends, + LogicalBackends: logicalBackends, CredentialBackends: noopBackends, DisableMlock: true, }) @@ -83,6 +91,21 @@ func TestKeyCopy(key []byte) []byte { return result } +var testLogicalBackends = map[string]logical.Factory{} + +// This adds a logical backend for the test core. This needs to be +// invoked before the test core is created. +func AddTestLogicalBackend(name string, factory logical.Factory) error { + if name == "" { + return fmt.Errorf("Missing backend name") + } + if factory == nil { + return fmt.Errorf("Missing backend factory function") + } + testLogicalBackends[name] = factory + return nil +} + type noopAudit struct{} func (n *noopAudit) LogRequest(a *logical.Auth, r *logical.Request, e error) error {