From 756be6976d71d3b671f72cedc2460ebe464e36ba Mon Sep 17 00:00:00 2001 From: Vishal Nayak Date: Mon, 29 Jun 2015 22:00:08 -0400 Subject: [PATCH] Refactoring changes --- api/logical.go | 3 - api/ssh.go | 33 ++++++++++- builtin/logical/ssh/path_lookup.go | 6 +- builtin/logical/ssh/path_role_create.go | 65 ++++++++++----------- builtin/logical/ssh/secret_ssh_key.go | 9 ++- builtin/logical/ssh/util.go | 47 +++++++++++++--- command/ssh.go | 75 ++++++++++++++++--------- 7 files changed, 157 insertions(+), 81 deletions(-) diff --git a/api/logical.go b/api/logical.go index 20bd1ea59f..a633df4b96 100644 --- a/api/logical.go +++ b/api/logical.go @@ -1,7 +1,5 @@ package api -import "log" - // Logical is used to perform logical backend operations on Vault. type Logical struct { c *Client @@ -27,7 +25,6 @@ func (c *Logical) Read(path string) (*Secret, error) { } func (c *Logical) Write(path string, data map[string]interface{}) (*Secret, error) { - log.Printf("Vishal: api.logical.Write(): invoking Put() on %#v\n", path) r := c.c.NewRequest("PUT", "/v1/"+path) if err := r.SetJSONBody(data); err != nil { return nil, err diff --git a/api/ssh.go b/api/ssh.go index cfe8c1bb67..ee0544430c 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -1,6 +1,9 @@ package api -import "fmt" +import ( + "encoding/json" + "fmt" +) type Ssh struct { c *Client @@ -10,8 +13,8 @@ func (c *Client) Ssh() *Ssh { return &Ssh{c: c} } -func (c *Ssh) KeyCreate(data map[string]interface{}) (*Secret, error) { - r := c.c.NewRequest("PUT", fmt.Sprintf("/v1/ssh/creds/web")) +func (c *Ssh) KeyCreate(role string, data map[string]interface{}) (*Secret, error) { + r := c.c.NewRequest("PUT", fmt.Sprintf("/v1/ssh/creds/"+role)) if err := r.SetJSONBody(data); err != nil { return nil, err } @@ -24,3 +27,27 @@ func (c *Ssh) KeyCreate(data map[string]interface{}) (*Secret, error) { return ParseSecret(resp.Body) } + +func (c *Ssh) Lookup(data map[string]interface{}) (*SshRoles, error) { + r := c.c.NewRequest("PUT", "/v1/ssh/lookup") + if err := r.SetJSONBody(data); err != nil { + return nil, err + } + + resp, err := c.c.RawRequest(r) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var roles SshRoles + dec := json.NewDecoder(resp.Body) + if err := dec.Decode(&roles); err != nil { + return nil, err + } + return &roles, nil +} + +type SshRoles struct { + Data map[string]interface{} `json:"data"` +} diff --git a/builtin/logical/ssh/path_lookup.go b/builtin/logical/ssh/path_lookup.go index c65f759675..a6f7f368de 100644 --- a/builtin/logical/ssh/path_lookup.go +++ b/builtin/logical/ssh/path_lookup.go @@ -1,10 +1,10 @@ package ssh import ( - "encoding/json" "fmt" "log" "net" + "strings" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -62,10 +62,8 @@ func containsIP(s logical.Storage, roleName string, ip string) (bool, error) { if err := roleEntry.DecodeJSON(&role); err != nil { return false, fmt.Errorf("error decoding role '%s'", roleName) } - var cidrEntry sshCIDR ipMatched := false - json.Unmarshal([]byte(role.CIDR), &cidrEntry) - for _, item := range cidrEntry.CIDR { + for _, item := range strings.Split(role.CIDR, ",") { log.Println(item) _, cidrIPNet, _ := net.ParseCIDR(item) ipMatched = cidrIPNet.Contains(net.ParseIP(ip)) diff --git a/builtin/logical/ssh/path_role_create.go b/builtin/logical/ssh/path_role_create.go index ebe580b4f3..c3ce3a3144 100644 --- a/builtin/logical/ssh/path_role_create.go +++ b/builtin/logical/ssh/path_role_create.go @@ -2,14 +2,10 @@ package ssh import ( "bytes" - "encoding/json" "fmt" - "io" "io/ioutil" "log" "net" - "os" - "path/filepath" "strings" "github.com/hashicorp/vault/logical" @@ -72,11 +68,8 @@ func (b *backend) pathRoleCreateWrite( } ip := ipAddr.String() - var cidrEntry sshCIDR ipMatched := false - log.Printf("Vishal: role.CIDR:%v\n", role.CIDR) - json.Unmarshal([]byte(role.CIDR), &cidrEntry) - for _, item := range cidrEntry.CIDR { + for _, item := range strings.Split(role.CIDR, ",") { log.Println(item) _, cidrIPNet, _ := net.ParseCIDR(item) ipMatched = cidrIPNet.Contains(ipAddr) @@ -113,39 +106,39 @@ func (b *backend) pathRoleCreateWrite( ioutil.WriteFile(otkPrivateKeyFileName, []byte(dynamicPrivateKey), 0600) ioutil.WriteFile(otkPublicKeyFileName, []byte(dynamicPublicKey), 0644) + uploadFileScp(otkPublicKeyFileName, username, ip, hostKey.Key) /* - scpCmd := "scp -i " + hostKeyFileName + " " + otkPublicKeyFileName + " " + username + "@" + ip + ":~;" - localCmdString := strings.Join([]string{ - scpCmd, - }, "") - //run the commands on vault server - err = exec_command(localCmdString) - if err != nil { - fmt.Errorf("Running command failed " + err.Error()) + otkPublicKeyFileNameBase := filepath.Base(otkPublicKeyFileName) + otkPublicKeyFile, _ := os.Open(otkPublicKeyFileName) + otkPublicKeyStat, err := otkPublicKeyFile.Stat() + if os.IsNotExist(err) { + return nil, fmt.Errorf("File does not exist") } + session := createSSHPublicKeysSession(username, ip, hostKey.Key) + if session == nil { + return nil, fmt.Errorf("Invalid session object") + } + go func() { + w, _ := session.StdinPipe() + fmt.Fprintln(w, "C0644", otkPublicKeyStat.Size(), otkPublicKeyFileNameBase) + io.Copy(w, otkPublicKeyFile) + fmt.Fprint(w, "\x00") + w.Close() + }() + if err := session.Run(fmt.Sprintf("scp -vt %s", otkPublicKeyFileNameBase)); err != nil { + panic("Failed to run: " + err.Error()) + } + session.Close() */ - otkPublicKeyFileNameBase := filepath.Base(otkPublicKeyFileName) - otkPublicKeyFile, _ := os.Open(otkPublicKeyFileName) - otkPublicKeyStat, _ := otkPublicKeyFile.Stat() - if otkPublicKeyStat.Size() <= 0 { - //return - } - session := createSSHPublicKeysSession(username, ip, hostKey.Key) - go func() { - w, _ := session.StdinPipe() - fmt.Fprintln(w, "C0644", otkPublicKeyStat.Size(), otkPublicKeyFileNameBase) - io.Copy(w, otkPublicKeyFile) - fmt.Fprint(w, "\x00") - w.Close() - }() - if err := session.Run(fmt.Sprintf("scp -vt %s", otkPublicKeyFileNameBase)); err != nil { - panic("Failed to run: " + err.Error()) - } - session.Close() - //connect to target machine - session = createSSHPublicKeysSession(username, ip, hostKey.Key) + session, err := createSSHPublicKeysSession(username, ip, hostKey.Key) + if err != nil { + return nil, fmt.Errorf("Unable to create SSH Session using public keys: %s", err) + } + if session == nil { + return nil, fmt.Errorf("Invalid session object") + } var buf bytes.Buffer session.Stdout = &buf diff --git a/builtin/logical/ssh/secret_ssh_key.go b/builtin/logical/ssh/secret_ssh_key.go index d162b3c812..00e1ddc172 100644 --- a/builtin/logical/ssh/secret_ssh_key.go +++ b/builtin/logical/ssh/secret_ssh_key.go @@ -126,7 +126,14 @@ func (b *backend) secretSshKeyRevoke(req *logical.Request, d *framework.FieldDat }, "") //connect to target machine - session := createSSHPublicKeysSession(username, ip, hostKey.Key) + session, err := createSSHPublicKeysSession(username, ip, hostKey.Key) + if err != nil { + return nil, fmt.Errorf("Unable to create SSH Session using public keys: %s", err) + } + if session == nil { + return nil, fmt.Errorf("Invalid session object") + } + var buf bytes.Buffer session.Stdout = &buf if err := session.Run(remoteCmdString); err != nil { diff --git a/builtin/logical/ssh/util.go b/builtin/logical/ssh/util.go index af1856d09d..affec2c9ef 100644 --- a/builtin/logical/ssh/util.go +++ b/builtin/logical/ssh/util.go @@ -7,9 +7,11 @@ import ( "encoding/base64" "encoding/pem" "fmt" + "io" "log" "os" "os/exec" + "path/filepath" "golang.org/x/crypto/ssh" ) @@ -22,10 +24,41 @@ func exec_command(cmdString string) error { return nil } -func createSSHPublicKeysSession(username string, ipAddr string, hostKey string) *ssh.Session { +func uploadFileScp(fileName, username, ip, key string) error { + nameBase := filepath.Base(fileName) + file, err := os.Open(fileName) + if err != nil { + return fmt.Errorf("Unable to open file") + } + stat, err := file.Stat() + if os.IsNotExist(err) { + return fmt.Errorf("File does not exist") + } + session, err := createSSHPublicKeysSession(username, ip, key) + if err != nil { + return fmt.Errorf("Unable to create SSH Session using public keys: %s", err) + } + if session == nil { + return fmt.Errorf("Invalid session object") + } + defer session.Close() + go func() { + w, _ := session.StdinPipe() + fmt.Fprintln(w, "C0644", stat.Size(), nameBase) + io.Copy(w, file) + fmt.Fprint(w, "\x00") + w.Close() + }() + if err := session.Run(fmt.Sprintf("scp -vt %s", nameBase)); err != nil { + return fmt.Errorf("Failed to run: %s", err) + } + return nil +} + +func createSSHPublicKeysSession(username string, ipAddr string, hostKey string) (*ssh.Session, error) { signer, err := ssh.ParsePrivateKey([]byte(hostKey)) if err != nil { - fmt.Errorf("Parsing Private Key failed: " + err.Error()) + return nil, fmt.Errorf("Parsing Private Key failed: %s", err) } config := &ssh.ClientConfig{ @@ -37,17 +70,17 @@ func createSSHPublicKeysSession(username string, ipAddr string, hostKey string) client, err := ssh.Dial("tcp", ipAddr+":22", config) if err != nil { - fmt.Errorf("Dial Failed: " + err.Error()) + return nil, fmt.Errorf("Dial Failed: %s", err) } if client == nil { - fmt.Errorf("SSH Dial to target failed: ", err.Error()) + return nil, fmt.Errorf("Invalid client object: %s", err) } session, err := client.NewSession() if err != nil { - fmt.Errorf("NewSession failed: " + err.Error()) + return nil, fmt.Errorf("Creating new client session failed: %s", err) } - return session + return session, nil } func removeFile(fileName string) { @@ -63,8 +96,6 @@ func removeFile(fileName string) { if err != nil { log.Printf(fmt.Sprintf("Failed: %s", err)) return - } else { - log.Printf("Successful\n") } } } diff --git a/command/ssh.go b/command/ssh.go index 4c89ce5b0f..b2cf7cc24a 100644 --- a/command/ssh.go +++ b/command/ssh.go @@ -16,59 +16,82 @@ type SshCommand struct { } func (c *SshCommand) Run(args []string) int { - log.SetFlags(log.LstdFlags | log.Lshortfile) - log.Printf("Vishal: SshCommand.Run: args:%#v len(args):%d\n", args, len(args)) - flags := c.Meta.FlagSet("ssh", FlagSetDefault) var role string + flags := c.Meta.FlagSet("ssh", FlagSetDefault) flags.StringVar(&role, "role", "", "") flags.Usage = func() { c.Ui.Error(c.Help()) } if err := flags.Parse(args); err != nil { return 1 } - log.Printf("Vishal: Role:%s\n", role) args = flags.Args() if len(args) < 1 { c.Ui.Error("ssh expects at least one argument") return 2 } + client, err := c.Client() if err != nil { c.Ui.Error(fmt.Sprintf("Error initializing client: %s", err)) return 2 } - log.Printf("Vishal: sshCommand.Run: args[0]: %#v\n", args[0]) input := strings.Split(args[0], "@") username := input[0] - ipAddr, err := net.ResolveIPAddr("ip4", input[1]) - log.Printf("Vishal: ssh.Ssh ipAddr_resolved: %#v\n", ipAddr.String()) - data := map[string]interface{}{ - "username": username, - "ip": ipAddr.String(), - } - - keySecret, err := client.Ssh().KeyCreate(data) + ip, err := net.ResolveIPAddr("ip4", input[1]) if err != nil { - c.Ui.Error(fmt.Sprintf("Error getting key for establishing SSH session", err)) + c.Ui.Error(fmt.Sprintf("Error resolving IP Address: %s", err)) return 2 } - sshOneTimeKey := string(keySecret.Data["key"].(string)) - log.Printf("Vishal: command.ssh.Run returned! len(key):%d\n", len(sshOneTimeKey)) - ag := strings.Split(args[0], "@") - sshOtkFileName := "vault_ssh_otk_" + ag[0] + "_" + ag[1] + ".pem" - err = ioutil.WriteFile(sshOtkFileName, []byte(sshOneTimeKey), 0400) - //if sshOneTimeKey is empty, fail - //Establish a session directly from client to the target using the one time key received without making the vault server the middle guy:w + + if role == "" { + data := map[string]interface{}{ + "ip": ip.String(), + } + secret, err := client.Logical().Write("ssh/lookup", data) + if err != nil { + c.Ui.Error(fmt.Sprintf("Error finding roles for IP:%s Error:%s", ip.String(), err)) + return 1 + } + + if secret.Data["roles"] == nil { + c.Ui.Error(fmt.Sprintf("IP '%s' not registered under any role", ip.String())) + return 1 + } + + if len(secret.Data["roles"].([]interface{})) == 1 { + role = secret.Data["roles"].([]interface{})[0].(string) + c.Ui.Output(fmt.Sprintf("Using role[%s]\n", role)) + } else { + c.Ui.Error(fmt.Sprintf("Multiple roles for IP '%s'. Select one of '%s' using '-role' option", ip, secret.Data["roles"])) + return 1 + } + } + + data := map[string]interface{}{ + "username": username, + "ip": ip.String(), + } + keySecret, err := client.Ssh().KeyCreate(role, data) + if err != nil { + c.Ui.Error(fmt.Sprintf("Error getting key for SSH session:%s", err)) + return 2 + } + + sshDynamicKey := string(keySecret.Data["key"].(string)) + if len(sshDynamicKey) == 0 { + c.Ui.Error(fmt.Sprintf("Invalid key")) + return 2 + } + sshDynamicKeyFileName := "vault_ssh_key_" + username + "_" + ip.String() + ".pem" + err = ioutil.WriteFile(sshDynamicKeyFileName, []byte(sshDynamicKey), 0600) sshBinary, err := exec.LookPath("ssh") if err != nil { - log.Printf("ssh binary not found in PATH\n") + c.Ui.Error("ssh binary not found in PATH\n") + return 2 } sshEnv := os.Environ() - sshNew := "ssh -i " + sshOtkFileName + " " + args[0] - log.Printf("Vishal: sshNew:%#v\n", sshNew) - sshCmdArgs := []string{"ssh", "-i", sshOtkFileName, args[0]} - //defer os.Remove("vault_ssh_otk_" + args[0] + ".pem") + sshCmdArgs := []string{"ssh", "-i", sshDynamicKeyFileName, args[0]} if err := syscall.Exec(sshBinary, sshCmdArgs, sshEnv); err != nil { log.Printf("Execution failed: sshCommand: " + err.Error())