mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-02 03:27:54 +00:00
Adding transit logical backend
This commit is contained in:
35
builtin/logical/transit/backend.go
Normal file
35
builtin/logical/transit/backend.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package transit
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func Factory(map[string]string) (logical.Backend, error) {
|
||||
return Backend(), nil
|
||||
}
|
||||
|
||||
func Backend() *framework.Backend {
|
||||
var b backend
|
||||
b.Backend = &framework.Backend{
|
||||
PathsSpecial: &logical.Paths{
|
||||
Root: []string{
|
||||
"policy/*",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathPolicy(),
|
||||
pathEncrypt(),
|
||||
pathDecrypt(),
|
||||
},
|
||||
|
||||
Secrets: []*framework.Secret{},
|
||||
}
|
||||
|
||||
return b.Backend
|
||||
}
|
||||
|
||||
type backend struct {
|
||||
*framework.Backend
|
||||
}
|
||||
132
builtin/logical/transit/backend_test.go
Normal file
132
builtin/logical/transit/backend_test.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package transit
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
logicaltest "github.com/hashicorp/vault/logical/testing"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
const (
|
||||
testPlaintext = "the quick brown fox"
|
||||
)
|
||||
|
||||
func TestBackend_basic(t *testing.T) {
|
||||
decryptData := make(map[string]interface{})
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: Backend(),
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepWritePolicy(t, "test"),
|
||||
testAccStepReadPolicy(t, "test", false),
|
||||
testAccStepEncrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepDeletePolicy(t, "test"),
|
||||
testAccStepReadPolicy(t, "test", true),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func testAccStepWritePolicy(t *testing.T, name string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.WriteOperation,
|
||||
Path: "policy/" + name,
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepDeletePolicy(t *testing.T, name string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.DeleteOperation,
|
||||
Path: "policy/" + name,
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadPolicy(t *testing.T, name string, expectNone bool) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "policy/" + name,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp == nil && !expectNone {
|
||||
return fmt.Errorf("missing response")
|
||||
} else if expectNone {
|
||||
if resp != nil {
|
||||
return fmt.Errorf("response when expecting none")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
var d struct {
|
||||
Name string `mapstructure:"name"`
|
||||
Key []byte `mapstructure:"key"`
|
||||
CipherMode string `mapstructure:"cipher_mode"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if d.Name != name {
|
||||
return fmt.Errorf("bad: %#v", d)
|
||||
}
|
||||
if d.CipherMode != "aes-gcm" {
|
||||
return fmt.Errorf("bad: %#v", d)
|
||||
}
|
||||
if len(d.Key) != 32 {
|
||||
return fmt.Errorf("bad: %#v", d)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepEncrypt(
|
||||
t *testing.T, name, plaintext string, decryptData map[string]interface{}) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.WriteOperation,
|
||||
Path: "encrypt/" + name,
|
||||
Data: map[string]interface{}{
|
||||
"plaintext": base64.StdEncoding.EncodeToString([]byte(plaintext)),
|
||||
},
|
||||
Check: func(resp *logical.Response) error {
|
||||
var d struct {
|
||||
Ciphertext string `mapstructure:"ciphertext"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
if d.Ciphertext == "" {
|
||||
return fmt.Errorf("missing ciphertext")
|
||||
}
|
||||
decryptData["ciphertext"] = d.Ciphertext
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepDecrypt(
|
||||
t *testing.T, name, plaintext string, decryptData map[string]interface{}) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.WriteOperation,
|
||||
Path: "decrypt/" + name,
|
||||
Data: decryptData,
|
||||
Check: func(resp *logical.Response) error {
|
||||
var d struct {
|
||||
Plaintext string `mapstructure:"plaintext"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Decode the base64
|
||||
plainRaw, err := base64.StdEncoding.DecodeString(d.Plaintext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if string(plainRaw) != plaintext {
|
||||
return fmt.Errorf("plaintext mismatch: %s expect: %s", plainRaw, plaintext)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
100
builtin/logical/transit/path_decrypt.go
Normal file
100
builtin/logical/transit/path_decrypt.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package transit
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathDecrypt() *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: `decrypt/(?P<name>\w+)`,
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the policy",
|
||||
},
|
||||
|
||||
"ciphertext": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Ciphertext value to decrypt",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.WriteOperation: pathDecryptWrite,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func pathDecryptWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
value := d.Get("ciphertext").(string)
|
||||
if len(value) == 0 {
|
||||
return logical.ErrorResponse("missing ciphertext to decrypt"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Get the policy
|
||||
p, err := getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Error if invalid policy
|
||||
if p == nil {
|
||||
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Guard against a potentially invalid cipher-mode
|
||||
switch p.CipherMode {
|
||||
case "aes-gcm":
|
||||
default:
|
||||
return logical.ErrorResponse("unsupported cipher mode"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Verify the prefix
|
||||
if !strings.HasPrefix(value, "vault:v0:") {
|
||||
return logical.ErrorResponse("invalid ciphertext"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Decode the base64
|
||||
decoded, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(value, "vault:v0:"))
|
||||
if err != nil {
|
||||
return logical.ErrorResponse("invalid ciphertext"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Setup the cipher
|
||||
aesCipher, err := aes.NewCipher(p.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Setup the GCM AEAD
|
||||
gcm, err := cipher.NewGCM(aesCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Extract the nonce and ciphertext
|
||||
nonce := decoded[:gcm.NonceSize()]
|
||||
ciphertext := decoded[gcm.NonceSize():]
|
||||
|
||||
// Verify and Decrypt
|
||||
plain, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse("invalid ciphertext"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Generate the response
|
||||
resp := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"plaintext": base64.StdEncoding.EncodeToString(plain),
|
||||
},
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
104
builtin/logical/transit/path_encrypt.go
Normal file
104
builtin/logical/transit/path_encrypt.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package transit
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathEncrypt() *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: `encrypt/(?P<name>\w+)`,
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the policy",
|
||||
},
|
||||
|
||||
"plaintext": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Plaintext value to encrypt",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.WriteOperation: pathEncryptWrite,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func pathEncryptWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
value := d.Get("plaintext").(string)
|
||||
if len(value) == 0 {
|
||||
return logical.ErrorResponse("missing plaintext to encrypt"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Decode the plaintext value
|
||||
plaintext, err := base64.StdEncoding.DecodeString(value)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse("failed to decode plaintext as base64"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Get the policy
|
||||
p, err := getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Error if invalid policy
|
||||
if p == nil {
|
||||
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Guard against a potentially invalid cipher-mode
|
||||
switch p.CipherMode {
|
||||
case "aes-gcm":
|
||||
default:
|
||||
return logical.ErrorResponse("unsupported cipher mode"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Setup the cipher
|
||||
aesCipher, err := aes.NewCipher(p.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Setup the GCM AEAD
|
||||
gcm, err := cipher.NewGCM(aesCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Compute random nonce
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
_, err = rand.Read(nonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Encrypt and tag with GCM
|
||||
out := gcm.Seal(nil, nonce, plaintext, nil)
|
||||
|
||||
// Place the encrypted data after the nonce
|
||||
full := append(nonce, out...)
|
||||
|
||||
// Convert to base64
|
||||
encoded := base64.StdEncoding.EncodeToString(full)
|
||||
|
||||
// Prepend some information
|
||||
encoded = "vault:v0:" + encoded
|
||||
|
||||
// Generate the response
|
||||
resp := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"ciphertext": encoded,
|
||||
},
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
140
builtin/logical/transit/path_policy.go
Normal file
140
builtin/logical/transit/path_policy.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package transit
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
// Policy is the struct used to store metadata
|
||||
type Policy struct {
|
||||
Name string `json:"name"`
|
||||
Key []byte `json:"key"`
|
||||
CipherMode string `json:"cipher"`
|
||||
}
|
||||
|
||||
func (p *Policy) Serialize() ([]byte, error) {
|
||||
return json.Marshal(p)
|
||||
}
|
||||
|
||||
func DeserializePolicy(buf []byte) (*Policy, error) {
|
||||
p := new(Policy)
|
||||
if err := json.Unmarshal(buf, p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func getPolicy(req *logical.Request, name string) (*Policy, error) {
|
||||
// Check if the policy already exists
|
||||
raw, err := req.Storage.Get("policy/" + name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if raw == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Decode the policy
|
||||
p, err := DeserializePolicy(raw.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func pathPolicy() *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: `policy/(?P<name>\w+)`,
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the policy",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.WriteOperation: pathPolicyWrite,
|
||||
logical.DeleteOperation: pathPolicyDelete,
|
||||
logical.ReadOperation: pathPolicyRead,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func pathPolicyWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
// Check if the policy already exists
|
||||
existing, err := getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if existing != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Create the policy object
|
||||
p := &Policy{
|
||||
Name: name,
|
||||
CipherMode: "aes-gcm",
|
||||
}
|
||||
|
||||
// Generate a 256bit key
|
||||
p.Key = make([]byte, 32)
|
||||
_, err = rand.Read(p.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Encode the policy
|
||||
buf, err := p.Serialize()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write the policy into storage
|
||||
err = req.Storage.Put(&logical.StorageEntry{
|
||||
Key: "policy/" + name,
|
||||
Value: buf,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func pathPolicyRead(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
p, err := getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if p == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Return the response
|
||||
resp := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"name": p.Name,
|
||||
"key": p.Key,
|
||||
"cipher_mode": p.CipherMode,
|
||||
},
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func pathPolicyDelete(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
err := req.Storage.Delete("policy/" + name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/aws"
|
||||
"github.com/hashicorp/vault/builtin/logical/consul"
|
||||
"github.com/hashicorp/vault/builtin/logical/transit"
|
||||
|
||||
"github.com/hashicorp/vault/audit"
|
||||
tokenDisk "github.com/hashicorp/vault/builtin/token/disk"
|
||||
@@ -51,8 +52,9 @@ func Commands(metaPtr *command.Meta) map[string]cli.CommandFactory {
|
||||
"github": credGitHub.Factory,
|
||||
},
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"aws": aws.Factory,
|
||||
"consul": consul.Factory,
|
||||
"aws": aws.Factory,
|
||||
"consul": consul.Factory,
|
||||
"transit": transit.Factory,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user