Merge branch 'master-oss' into sethvargo/cli-magic

This commit is contained in:
Jeff Mitchell
2018-01-03 14:02:31 -05:00
1088 changed files with 95212 additions and 45132 deletions

1
.gitignore vendored
View File

@@ -64,6 +64,7 @@ tags
# compiled output
ui/dist
ui/tmp
ui/root
# dependencies
ui/node_modules

View File

@@ -1,27 +1,254 @@
## 0.8.4 (Unreleased)
## 0.9.2 (Unreleased)
IMPROVEMENTS:
* physical/s3: Allow using paths with S3 for non-AWS deployments [GH-3730]
* physical/s3: Add ability to disable SSL for non-AWS deployments [GH-3730]
## 0.9.1 (December 21st, 2017)
DEPRECATIONS/CHANGES:
* AppRole Case Sensitivity: In prior versions of Vault, `list` operations
against AppRole roles would require preserving case in the role name, even
though most other operations within AppRole are case-insensitive with
respect to the role name. This has been fixed; existing roles will behave as
they have in the past, but new roles will act case-insensitively in these
cases.
* Token Auth Backend Roles parameter types: For `allowed_policies` and
`disallowed_policies` in role definitions in the token auth backend, input
can now be a comma-separated string or an array of strings. Reading a role
will now return arrays for these parameters.
* Transit key exporting: You can now mark a key in the `transit` backend as
`exportable` at any time, rather than just at creation time; however, once
this value is set, it still cannot be unset.
* PKI Secret Backend Roles parameter types: For `allowed_domains` and
`key_usage` in role definitions in the PKI secret backend, input
can now be a comma-separated string or an array of strings. Reading a role
will now return arrays for these parameters.
* SSH Dynamic Keys Method Defaults to 2048-bit Keys: When using the dynamic
key method in the SSH backend, the default is now to use 2048-bit keys if no
specific key bit size is specified.
* Consul Secret Backend lease handling: The `consul` secret backend can now
accept both strings and integer numbers of seconds for its lease value. The
value returned on a role read will be an integer number of seconds instead
of a human-friendly string.
* Unprintable characters not allowed in API paths: Unprintable characters are
no longer allowed in names in the API (paths and path parameters), with an
extra restriction on whitespace characters. Allowed characters are those
that are considered printable by Unicode plus spaces.
FEATURES:
* **Transit Backup/Restore**: The `transit` backend now supports a backup
operation that can export a given key, including all key versions and
configuration, as well as a restore operation allowing import into another
Vault.
* **gRPC Database Plugins**: Database plugins now use gRPC for transport,
allowing them to be written in other languages.
* **Nomad Secret Backend**: Nomad ACL tokens can now be generated and revoked
using Vault.
* **TLS Cert Auth Backend Improvements**: The `cert` auth backend can now
match against custom certificate extensions via exact or glob matching, and
additionally supports max_ttl and periodic token toggles.
IMPROVEMENTS:
* auth/cert: Support custom certificate constraints [GH-3634]
* auth/cert: Support setting `max_ttl` and `period` [GH-3642]
* audit/file: Setting a file mode of `0000` will now disable Vault from
automatically `chmod`ing the log file [GH-3649]
* auth/github: The legacy MFA system can now be used with the GitHub auth
backend [GH-3696]
* auth/okta: The legacy MFA system can now be used with the Okta auth backend
[GH-3653]
* auth/token: `allowed_policies` and `disallowed_policies` can now be specified
as a comma-separated string or an array of strings [GH-3641]
* command/server: The log level can now be specified with `VAULT_LOG_LEVEL`
[GH-3721]
* core: Period values from auth backends will now be checked and applied to the
TTL value directly by core on login and renewal requests [GH-3677]
* database/mongodb: Add optional `write_concern` parameter, which can be set
during database configuration. This establishes a session-wide [write
concern](https://docs.mongodb.com/manual/reference/write-concern/) for the
lifecycle of the mount [GH-3646]
* http: Request path containing non-printable characters will return 400 - Bad
Request [GH-3697]
* mfa/okta: Filter a given email address as a login filter, allowing operation
when login email and account email are different
* plugins: Make Vault more resilient when unsealing when plugins are
unavailable [GH-3686]
* secret/pki: `allowed_domains` and `key_usage` can now be specified
as a comma-separated string or an array of strings [GH-3642]
* secret/ssh: Allow 4096-bit keys to be used in dynamic key method [GH-3593]
* secret/consul: The Consul secret backend now uses the value of `lease` set
on the role, if set, when renewing a secret. [GH-3796]
* storage/mysql: Don't attempt database creation if it exists, which can help
under certain permissions constraints [GH-3716]
BUG FIXES:
* api/status (enterprise): Fix status reporting when using an auto seal
* auth/approle: Fix case-sensitive/insensitive comparison issue [GH-3665]
* auth/cert: Return `allowed_names` on role read [GH-3654]
* auth/ldap: Fix incorrect control information being sent [GH-3402] [GH-3496]
[GH-3625] [GH-3656]
* core: Fix seal status reporting when using an autoseal
* core: Add creation path to wrap info for a control group token
* core: Fix potential panic that could occur using plugins when a node
transitioned from active to standby [GH-3638]
* core: Fix memory ballooning when a connection would connect to the cluster
port and then go away -- redux! [GH-3680]
* core: Replace recursive token revocation logic with depth-first logic, which
can avoid hitting stack depth limits in extreme cases [GH-2348]
* core: When doing a read on configured audited-headers, properly handle case
insensitivity [GH-3701]
* core/pkcs11 (enterprise): Fix panic when PKCS#11 library is not readable
* database/mysql: Allow the creation statement to use commands that are not yet
supported by the prepare statement protocol [GH-3619]
* plugin/auth-gcp: Fix IAM roles when using `allow_gce_inference` [VPAG-19]
## 0.9.0.1 (November 21st, 2017) (Enterprise Only)
IMPROVEMENTS:
* auth/gcp: Support seal wrapping of configuration parameters
* auth/kubernetes: Support seal wrapping of configuration parameters
BUG FIXES:
* Fix an upgrade issue with some physical backends when migrating from legacy
HSM stored key support to the new Seal Wrap mechanism
## 0.9.0 (November 14th, 2017)
DEPRECATIONS/CHANGES:
* HSM config parameter requirements: When using Vault with an HSM, a new
paramter is required: `hmac_key_label`. This performs a similar function to
`key_label` but for the HMAC key Vault will use. Vault will generate a
suitable key if this value is specified and `generate_key` is set true.
* API HTTP client behavior: When calling `NewClient` the API no longer
modifies the provided client/transport. In particular this means it will no
longer enable redirection limiting and HTTP/2 support on custom clients. It
is suggested that if you want to make changes to an HTTP client that you use
one created by `DefaultConfig` as a starting point.
* AWS EC2 client nonce behavior: The client nonce generated by the backend
that gets returned along with the authentication response will be audited in
plaintext. If this is undesired, the clients can choose to supply a custom
nonce to the login endpoint. The custom nonce set by the client will from
now on, not be returned back with the authentication response, and hence not
audit logged.
* AWS Auth role options: The API will now error when trying to create or
update a role with the mutually-exclusive options
`disallow_reauthentication` and `allow_instance_migration`.
* SSH CA role read changes: When reading back a role from the `ssh` backend,
the TTL/max TTL values will now be an integer number of seconds rather than
a string. This better matches the API elsewhere in Vault.
* SSH role list changes: When listing roles from the `ssh` backend via the API,
the response data will additionally return a `key_info` map that will contain
a map of each key with a corresponding object containing the `key_type`.
* More granularity in audit logs: Audit request and response entires are still
in RFC3339 format but now have a granularity of nanoseconds.
* High availability related values have been moved out of the `storage` and
`ha_storage` stanzas, and into the top-level configuration. `redirect_addr`
has been renamed to `api_addr`. The stanzas still support accepting
HA-related values to maintain backward compatibility, but top-level values
will take precedence.
* A new `seal` stanza has been added to the configuration file, which is
optional and enables configuration of the seal type to use for additional
data protection, such as using HSM or Cloud KMS solutions to encrypt and
decrypt data.
FEATURES:
* **RSA Support for Transit Backend**: Transit backend can now generate RSA
keys which can be used for encryption and signing. [GH-3489]
* **Identity System**: Now in open source and with significant enhancements,
Identity is an integrated system for understanding users across tokens and
enabling easier management of users directly and via groups.
* **External Groups in Identity**: Vault can now automatically assign users
and systems to groups in Identity based on their membership in external
groups.
* **Seal Wrap / FIPS 140-2 Compatibility (Enterprise)**: Vault can now take
advantage of FIPS 140-2-certified HSMs to ensure that Critical Security
Parameters are protected in a compliant fashion. Vault's implementation has
received a statement of compliance from Leidos.
* **Control Groups (Enterprise)**: Require multiple members of an Identity
group to authorize a requested action before it is allowed to run.
* **Cloud Auto-Unseal (Enterprise)**: Automatically unseal Vault using AWS KMS
and GCP CKMS.
* **Sentinel Integration (Enterprise)**: Take advantage of HashiCorp Sentinel
to create extremely flexible access control policies -- even on
unauthenticated endpoints.
* **Barrier Rekey Support for Auto-Unseal (Enterprise)**: When using auto-unsealing
functionality, the `rekey` operation is now supported; it uses recovery keys
to authorize the master key rekey.
* **Operation Token for Disaster Recovery Actions (Enterprise)**: When using
Disaster Recovery replication, a token can be created that can be used to
authorize actions such as promotion and updating primary information, rather
than using recovery keys.
* **Trigger Auto-Unseal with Recovery Keys (Enterprise)**: When using
auto-unsealing, a request to unseal Vault can be triggered by a threshold of
recovery keys, rather than requiring the Vault process to be restarted.
* **UI Redesign (Enterprise)**: All new experience for the Vault Enterprise
UI. The look and feel has been completely redesigned to give users a better
experience and make managing secrets fast and easy.
* **UI: SSH Secret Backend (Enterprise)**: Configure an SSH secret backend,
create and browse roles. And use them to sign keys or generate one time
passwords.
* **UI: AWS Secret Backend (Enterprise)**: You can now configure the AWS
backend via the Vault Enterprise UI. In addition you can create roles,
browse the roles and Generate IAM Credentials from them in the UI.
IMPROVEMENTS:
* api: Add ability to set custom headers on each call [GH-3394]
* command/server: Add config option to disable requesting client certificates
[GH-3373]
* secret/cassandra: Work around Cassandra ignoring consistency levels for a
user listing query [GH-3469]
* secret/pki: Allow entering URLs for `pki` as both comma-separated strings and JSON
arrays [GH-3409]
* secret/transit: Sign and verify operations now support a `none` hash
algorithm to allow signing/verifying pre-hashed data [GH-3448]
* core: Disallow mounting underneath an existing path, not just over [GH-2919]
* physical/file: Use `700` as permissions when creating directories. The files
themselves were `600` and are all encrypted, but this doesn't hurt.
* secret/aws: Add ability to use custom IAM/STS endpoints [GH-3416]
* secret/cassandra: Work around Cassandra ignoring consistency levels for a
user listing query [GH-3469]
* secret/pki: Private keys can now be marshalled as PKCS#8 [GH-3518]
* secret/pki: Allow entering URLs for `pki` as both comma-separated strings and JSON
arrays [GH-3409]
* secret/ssh: Role TTL/max TTL can now be specified as either a string or an
integer [GH-3507]
* secret/transit: Sign and verify operations now support a `none` hash
algorithm to allow signing/verifying pre-hashed data [GH-3448]
* secret/database: Add the ability to glob allowed roles in the Database Backend [GH-3387]
* ui (enterprise): Support for RSA keys in the transit backend
* ui (enterprise): Support for DR Operation Token generation, promoting, and
updating primary on DR Secondary clusters
BUG FIXES:
* api: Fix panic when setting a custom HTTP client but with a nil transport
[GH-3437]
[GH-3435] [GH-3437]
* api: Fix authing to the `cert` backend when the CA for the client cert is
not known to the server's listener [GH-2946]
* auth/approle: Create role ID index during read if a role is missing one [GH-3561]
* auth/aws: Don't allow mutually exclusive options [GH-3291]
* auth/radius: Fix logging in in some situations [GH-3461]
* core: Fix memleak when a connection would connect to the cluster port and
then go away [GH-3513]
* core: Fix panic if a single-use token is used to step-down or seal [GH-3497]
* core: Set rather than add headers to prevent some duplicated headers in
responses when requests were forwarded to the active node [GH-3485]
* physical/etcd3: Fix some listing issues due to how etcd3 does prefix
matching [GH-3406]
* physical/etcd3: Fix case where standbys can lose their etcd client lease
[GH-3031]
* physical/file: Fix listing when underscores are the first component of a
path [GH-3476]
* plugins: Allow response errors to be returned from backend plugins [GH-3412]
* secret/transit: Fix panic if the length of the input ciphertext was less
than the expected nonce length [GH-3521]
* ui (enterprise): Reinstate support for generic secret backends - this was
erroneously removed in a previous release
## 0.8.3 (September 19th, 2017)
@@ -117,7 +344,7 @@ IMPROVEMENTS:
* audit/file: Allow specifying `stdout` as the `file_path` to log to standard
output [GH-3235]
* auth/aws: Allow wildcards in `bound_iam_principal_id` [GH-3213]
* auth/aws: Allow wildcards in `bound_iam_principal_arn` [GH-3213]
* auth/okta: Compare groups case-insensitively since Okta is only
case-preserving [GH-3240]
* auth/okta: Standardize Okta configuration APIs across backends [GH-3245]

View File

@@ -79,7 +79,7 @@ vet:
prep: fmtcheck
@sh -c "'$(CURDIR)/scripts/goversioncheck.sh' '$(GO_VERSION_MIN)'"
go generate $(go list ./... | grep -v /vendor/)
cp .hooks/* .git/hooks/
@if [ -d .git/hooks ]; then cp .hooks/* .git/hooks/; fi
# bootstrap the build by downloading additional tools
bootstrap:
@@ -92,8 +92,11 @@ proto:
protoc -I helper/forwarding -I vault -I ../../.. vault/*.proto --go_out=plugins=grpc:vault
protoc -I helper/storagepacker helper/storagepacker/types.proto --go_out=plugins=grpc:helper/storagepacker
protoc -I helper/forwarding -I vault -I ../../.. helper/forwarding/types.proto --go_out=plugins=grpc:helper/forwarding
protoc -I physical physical/types.proto --go_out=plugins=grpc:physical
protoc -I helper/identity -I ../../.. helper/identity/types.proto --go_out=plugins=grpc:helper/identity
protoc builtin/logical/database/dbplugin/*.proto --go_out=plugins=grpc:.
sed -i -e 's/Idp/IDP/' -e 's/Url/URL/' -e 's/Id/ID/' -e 's/EntityId/EntityID/' -e 's/Api/API/' -e 's/Qr/QR/' -e 's/protobuf:"/sentinel:"" protobuf:"/' helper/identity/types.pb.go helper/storagepacker/types.pb.go
sed -i -e 's/Iv/IV/' -e 's/Hmac/HMAC/' physical/types.pb.go
fmtcheck:
@sh -c "'$(CURDIR)/scripts/gofmtcheck.sh'"

View File

@@ -1,8 +1,10 @@
Vault [![Build Status](https://travis-ci.org/hashicorp/vault.svg)](https://travis-ci.org/hashicorp/vault) [![Join the chat at https://gitter.im/hashicorp-vault/Lobby](https://badges.gitter.im/hashicorp-vault/Lobby.svg)](https://gitter.im/hashicorp-vault/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![vault enterprise](https://img.shields.io/badge/vault-enterprise-yellow.svg?colorB=7c8797&colorA=000000)](https://www.hashicorp.com/products/vault/?utm_source=github&utm_medium=banner&utm_campaign=github-vault-enterprise)
=========
# Vault [![Build Status](https://travis-ci.org/hashicorp/vault.svg)](https://travis-ci.org/hashicorp/vault) [![Join the chat at https://gitter.im/hashicorp-vault/Lobby](https://badges.gitter.im/hashicorp-vault/Lobby.svg)](https://gitter.im/hashicorp-vault/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![vault enterprise](https://img.shields.io/badge/vault-enterprise-yellow.svg?colorB=7c8797&colorA=000000)](https://www.hashicorp.com/products/vault/?utm_source=github&utm_medium=banner&utm_campaign=github-vault-enterprise)
----
**Please note**: We take Vault's security and our users' trust very seriously. If you believe you have found a security issue in Vault, _please responsibly disclose_ by contacting us at [security@hashicorp.com](mailto:security@hashicorp.com).
=========
----
- Website: https://www.vaultproject.io
- IRC: `#vault-tool` on Freenode

View File

@@ -5,8 +5,6 @@ import (
"net"
"net/http"
"testing"
"golang.org/x/net/http2"
)
// testHTTPServer creates a test HTTP server that handles requests until
@@ -19,9 +17,6 @@ func testHTTPServer(
}
server := &http.Server{Handler: handler}
if err := http2.ConfigureServer(server, nil); err != nil {
t.Fatal(err)
}
go server.Serve(ln)
config := DefaultConfig()

View File

@@ -13,12 +13,12 @@ import (
"sync"
"time"
"golang.org/x/net/http2"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-rootcerts"
"github.com/hashicorp/vault/helper/parseutil"
"github.com/sethgrid/pester"
"golang.org/x/net/http2"
)
const EnvVaultAddress = "VAULT_ADDR"
@@ -43,24 +43,31 @@ type WrappingLookupFunc func(operation, path string) string
// Config is used to configure the creation of the client.
type Config struct {
modifyLock sync.RWMutex
// Address is the address of the Vault server. This should be a complete
// URL such as "http://vault.example.com". If you need a custom SSL
// cert or want to enable insecure mode, you need to specify a custom
// HttpClient.
Address string
// HttpClient is the HTTP client to use, which will currently always have the
// same values as http.DefaultClient. This is used to control redirect behavior.
// HttpClient is the HTTP client to use. Vault sets sane defaults for the
// http.Client and its associated http.Transport created in DefaultConfig.
// If you must modify Vault's defaults, it is suggested that you start with
// that client and modify as needed rather than start with an empty client
// (or http.DefaultClient).
HttpClient *http.Client
redirectSetup sync.Once
// MaxRetries controls the maximum number of times to retry when a 5xx error
// occurs. Set to 0 or less to disable retrying. Defaults to 0.
MaxRetries int
// Timeout is for setting custom timeout parameter in the HttpClient
Timeout time.Duration
// If there is an error when creating the configuration, this will be the
// error
Error error
}
// TLSConfig contains the parameters needed to configure TLS on the HTTP client
@@ -93,60 +100,91 @@ type TLSConfig struct {
//
// The default Address is https://127.0.0.1:8200, but this can be overridden by
// setting the `VAULT_ADDR` environment variable.
//
// If an error is encountered, this will return nil.
func DefaultConfig() *Config {
config := &Config{
Address: "https://127.0.0.1:8200",
HttpClient: cleanhttp.DefaultClient(),
}
config.HttpClient.Timeout = time.Second * 60
transport := config.HttpClient.Transport.(*http.Transport)
transport.TLSHandshakeTimeout = 10 * time.Second
transport.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
}
if err := http2.ConfigureTransport(transport); err != nil {
config.Error = err
return config
}
if v := os.Getenv(EnvVaultAddress); v != "" {
config.Address = v
if err := config.ReadEnvironment(); err != nil {
config.Error = err
return config
}
// Ensure redirects are not automatically followed
// Note that this is sane for the API client as it has its own
// redirect handling logic (and thus also for command/meta),
// but in e.g. http_test actual redirect handling is necessary
config.HttpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
// Returning this value causes the Go net library to not close the
// response body and to nil out the error. Otherwise pester tries
// three times on every redirect because it sees an error from this
// function (to prevent redirects) passing through to it.
return http.ErrUseLastResponse
}
return config
}
// ConfigureTLS takes a set of TLS configurations and applies those to the the HTTP client.
// ConfigureTLS takes a set of TLS configurations and applies those to the the
// HTTP client.
func (c *Config) ConfigureTLS(t *TLSConfig) error {
if c.HttpClient == nil {
c.HttpClient = DefaultConfig().HttpClient
}
clientTLSConfig := c.HttpClient.Transport.(*http.Transport).TLSClientConfig
var clientCert tls.Certificate
foundClientCert := false
if t.CACert != "" || t.CAPath != "" || t.ClientCert != "" || t.ClientKey != "" || t.Insecure {
if t.ClientCert != "" && t.ClientKey != "" {
var err error
clientCert, err = tls.LoadX509KeyPair(t.ClientCert, t.ClientKey)
if err != nil {
return err
}
foundClientCert = true
} else if t.ClientCert != "" || t.ClientKey != "" {
return fmt.Errorf("Both client cert and client key must be provided")
switch {
case t.ClientCert != "" && t.ClientKey != "":
var err error
clientCert, err = tls.LoadX509KeyPair(t.ClientCert, t.ClientKey)
if err != nil {
return err
}
foundClientCert = true
case t.ClientCert != "" || t.ClientKey != "":
return fmt.Errorf("Both client cert and client key must be provided")
}
if t.CACert != "" || t.CAPath != "" {
rootConfig := &rootcerts.Config{
CAFile: t.CACert,
CAPath: t.CAPath,
}
if err := rootcerts.ConfigureTLS(clientTLSConfig, rootConfig); err != nil {
return err
}
}
clientTLSConfig := c.HttpClient.Transport.(*http.Transport).TLSClientConfig
rootConfig := &rootcerts.Config{
CAFile: t.CACert,
CAPath: t.CAPath,
if t.Insecure {
clientTLSConfig.InsecureSkipVerify = true
}
if err := rootcerts.ConfigureTLS(clientTLSConfig, rootConfig); err != nil {
return err
}
clientTLSConfig.InsecureSkipVerify = t.Insecure
if foundClientCert {
clientTLSConfig.Certificates = []tls.Certificate{clientCert}
// We use this function to ignore the server's preferential list of
// CAs, otherwise any CA used for the cert auth backend must be in the
// server's CA pool
clientTLSConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &clientCert, nil
}
}
if t.TLSServerName != "" {
clientTLSConfig.ServerName = t.TLSServerName
}
@@ -154,9 +192,8 @@ func (c *Config) ConfigureTLS(t *TLSConfig) error {
return nil
}
// ReadEnvironment reads configuration information from the
// environment. If there is an error, no configuration value
// is updated.
// ReadEnvironment reads configuration information from the environment. If
// there is an error, no configuration value is updated.
func (c *Config) ReadEnvironment() error {
var envAddress string
var envCACert string
@@ -218,6 +255,10 @@ func (c *Config) ReadEnvironment() error {
TLSServerName: envTLSServerName,
Insecure: envInsecure,
}
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
if err := c.ConfigureTLS(t); err != nil {
return err
}
@@ -237,10 +278,9 @@ func (c *Config) ReadEnvironment() error {
return nil
}
// Client is the client to the Vault API. Create a client with NewClient. Note:
// it is not safe to modify client configuration from multiple goroutines at
// once. Set configuration first, then run requests.
// Client is the client to the Vault API. Create a client with NewClient.
type Client struct {
modifyLock sync.RWMutex
addr *url.URL
config *Config
token string
@@ -250,24 +290,29 @@ type Client struct {
policyOverride bool
}
// SetMFACreds sets the MFA credentials supplied either via the environment
// variable or via the command line.
func (c *Client) SetMFACreds(creds []string) {
c.mfaCreds = creds
}
// NewClient returns a new client for the given configuration.
//
// If the configuration is nil, Vault will use configuration from
// DefaultConfig(), which is the recommended starting configuration.
//
// If the environment variable `VAULT_TOKEN` is present, the token will be
// automatically added to the client. Otherwise, you must manually call
// `SetToken()`.
func NewClient(c *Config) (*Client, error) {
if c == nil {
c = DefaultConfig()
if err := c.ReadEnvironment(); err != nil {
return nil, fmt.Errorf("error reading environment: %v", err)
}
def := DefaultConfig()
if def == nil {
return nil, fmt.Errorf("could not create/read default configuration")
}
if def.Error != nil {
return nil, errwrap.Wrapf("error encountered setting up default configuration: {{err}}", def.Error)
}
if c == nil {
c = def
}
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
u, err := url.Parse(c.Address)
if err != nil {
@@ -275,41 +320,19 @@ func NewClient(c *Config) (*Client, error) {
}
if c.HttpClient == nil {
c.HttpClient = DefaultConfig().HttpClient
c.HttpClient = def.HttpClient
}
if c.HttpClient.Transport == nil {
c.HttpClient.Transport = cleanhttp.DefaultTransport()
c.HttpClient.Transport = def.HttpClient.Transport
}
if tp, ok := c.HttpClient.Transport.(*http.Transport); ok {
if err := http2.ConfigureTransport(tp); err != nil {
return nil, err
}
}
redirFunc := func() {
// Ensure redirects are not automatically followed
// Note that this is sane for the API client as it has its own
// redirect handling logic (and thus also for command/meta),
// but in e.g. http_test actual redirect handling is necessary
c.HttpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
// Returning this value causes the Go net library to not close the
// response body and to nil out the error. Otherwise pester tries
// three times on every redirect because it sees an error from this
// function (to prevent redirects) passing through to it.
return http.ErrUseLastResponse
}
}
c.redirectSetup.Do(redirFunc)
client := &Client{
addr: u,
config: c,
}
if token := os.Getenv(EnvVaultToken); token != "" {
client.SetToken(token)
client.token = token
}
return client, nil
@@ -319,6 +342,9 @@ func NewClient(c *Config) (*Client, error) {
// "<Scheme>://<Host>:<Port>". Setting this on a client will override the
// value of VAULT_ADDR environment variable.
func (c *Client) SetAddress(addr string) error {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
var err error
if c.addr, err = url.Parse(addr); err != nil {
return fmt.Errorf("failed to set address: %v", err)
@@ -329,56 +355,112 @@ func (c *Client) SetAddress(addr string) error {
// Address returns the Vault URL the client is configured to connect to
func (c *Client) Address() string {
c.modifyLock.RLock()
defer c.modifyLock.RUnlock()
return c.addr.String()
}
// SetMaxRetries sets the number of retries that will be used in the case of certain errors
func (c *Client) SetMaxRetries(retries int) {
c.modifyLock.RLock()
c.config.modifyLock.Lock()
defer c.config.modifyLock.Unlock()
c.modifyLock.RUnlock()
c.config.MaxRetries = retries
}
// SetClientTimeout sets the client request timeout
func (c *Client) SetClientTimeout(timeout time.Duration) {
c.modifyLock.RLock()
c.config.modifyLock.Lock()
defer c.config.modifyLock.Unlock()
c.modifyLock.RUnlock()
c.config.Timeout = timeout
}
// SetWrappingLookupFunc sets a lookup function that returns desired wrap TTLs
// for a given operation and path
func (c *Client) SetWrappingLookupFunc(lookupFunc WrappingLookupFunc) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.wrappingLookupFunc = lookupFunc
}
// SetMFACreds sets the MFA credentials supplied either via the environment
// variable or via the command line.
func (c *Client) SetMFACreds(creds []string) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.mfaCreds = creds
}
// Token returns the access token being used by this client. It will
// return the empty string if there is no token set.
func (c *Client) Token() string {
c.modifyLock.RLock()
defer c.modifyLock.RUnlock()
return c.token
}
// SetToken sets the token directly. This won't perform any auth
// verification, it simply sets the token properly for future requests.
func (c *Client) SetToken(v string) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.token = v
}
// ClearToken deletes the token if it is set or does nothing otherwise.
func (c *Client) ClearToken() {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.token = ""
}
// SetHeaders sets the headers to be used for future requests.
func (c *Client) SetHeaders(headers http.Header) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.headers = headers
}
// Clone creates a copy of this client.
// Clone creates a new client with the same configuration. Note that the same
// underlying http.Client is used; modifying the client from more than one
// goroutine at once may not be safe, so modify the client as needed and then
// clone.
func (c *Client) Clone() (*Client, error) {
return NewClient(c.config)
c.modifyLock.RLock()
c.config.modifyLock.RLock()
config := c.config
c.modifyLock.RUnlock()
newConfig := &Config{
Address: config.Address,
HttpClient: config.HttpClient,
MaxRetries: config.MaxRetries,
Timeout: config.Timeout,
}
config.modifyLock.RUnlock()
return NewClient(newConfig)
}
// SetPolicyOverride sets whether requests should be sent with the policy
// override flag to request overriding soft-mandatory Sentinel policies (both
// RGPs and EGPs)
func (c *Client) SetPolicyOverride(override bool) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.policyOverride = override
}
@@ -386,6 +468,9 @@ func (c *Client) SetPolicyOverride(override bool) {
// configured for this client. This is an advanced method and generally
// doesn't need to be called externally.
func (c *Client) NewRequest(method, requestPath string) *Request {
c.modifyLock.RLock()
defer c.modifyLock.RUnlock()
// if SRV records exist (see https://tools.ietf.org/html/draft-andrews-http-srv-02), lookup the SRV
// record and take the highest match; this is not designed for high-availability, just discovery
var host string = c.addr.Host
@@ -442,6 +527,11 @@ func (c *Client) NewRequest(method, requestPath string) *Request {
// a Vault server not configured with this client. This is an advanced operation
// that generally won't need to be called externally.
func (c *Client) RawRequest(r *Request) (*Response, error) {
c.modifyLock.RLock()
c.config.modifyLock.RLock()
defer c.config.modifyLock.RUnlock()
c.modifyLock.RUnlock()
redirectCount := 0
START:
req, err := r.ToHTTP()

View File

@@ -163,8 +163,8 @@ func TestClientEnvSettings(t *testing.T) {
if len(tlsConfig.RootCAs.Subjects()) == 0 {
t.Fatalf("bad: expected a cert pool with at least one subject")
}
if len(tlsConfig.Certificates) != 1 {
t.Fatalf("bad: expected client tls config to have a client certificate")
if tlsConfig.GetClientCertificate == nil {
t.Fatalf("bad: expected client tls config to have a certificate getter")
}
if tlsConfig.InsecureSkipVerify != true {
t.Fatalf("bad: %v", tlsConfig.InsecureSkipVerify)
@@ -213,3 +213,16 @@ func TestClientNonTransportRoundTripper(t *testing.T) {
t.Fatal(err)
}
}
func TestClone(t *testing.T) {
client1, err1 := NewClient(nil)
if err1 != nil {
t.Fatalf("NewClient failed: %v", err1)
}
client2, err2 := client1.Clone()
if err2 != nil {
t.Fatalf("Clone failed: %v", err2)
}
_ = client2
}

View File

@@ -50,12 +50,13 @@ var (
type Renewer struct {
l sync.Mutex
client *Client
secret *Secret
grace time.Duration
random *rand.Rand
doneCh chan error
renewCh chan *RenewOutput
client *Client
secret *Secret
grace time.Duration
random *rand.Rand
increment int
doneCh chan error
renewCh chan *RenewOutput
stopped bool
stopCh chan struct{}
@@ -79,6 +80,11 @@ type RenewerInput struct {
// RenewBuffer is the size of the buffered channel where renew messages are
// dispatched.
RenewBuffer int
// The new TTL, in seconds, that should be set on the lease. The TTL set
// here may or may not be honored by the vault server, based on Vault
// configuration or any associated max TTL values.
Increment int
}
// RenewOutput is the metadata returned to the client (if it's listening) to
@@ -120,12 +126,13 @@ func (c *Client) NewRenewer(i *RenewerInput) (*Renewer, error) {
}
return &Renewer{
client: c,
secret: secret,
grace: grace,
random: random,
doneCh: make(chan error, 1),
renewCh: make(chan *RenewOutput, renewBuffer),
client: c,
secret: secret,
grace: grace,
increment: i.Increment,
random: random,
doneCh: make(chan error, 1),
renewCh: make(chan *RenewOutput, renewBuffer),
stopped: false,
stopCh: make(chan struct{}),
@@ -245,7 +252,7 @@ func (r *Renewer) renewLease() error {
}
// Renew the lease.
renewal, err := client.Sys().Renew(leaseID, 0)
renewal, err := client.Sys().Renew(leaseID, r.increment)
if err != nil {
return err
}

View File

@@ -224,6 +224,7 @@ func (s *Secret) TokenTTL() (time.Duration, error) {
// available in WrappedAccessor.
type SecretWrapInfo struct {
Token string `json:"token"`
Accessor string `json:"accessor"`
TTL int `json:"ttl"`
CreationTime time.Time `json:"creation_time"`
CreationPath string `json:"creation_path"`

View File

@@ -26,6 +26,7 @@ func TestParseSecret(t *testing.T) {
],
"wrap_info": {
"token": "token",
"accessor": "accessor",
"ttl": 60,
"creation_time": "2016-06-07T15:52:10-04:00",
"wrapped_accessor": "abcd1234"
@@ -51,6 +52,7 @@ func TestParseSecret(t *testing.T) {
},
WrapInfo: &api.SecretWrapInfo{
Token: "token",
Accessor: "accessor",
TTL: 60,
CreationTime: rawTime,
WrappedAccessor: "abcd1234",

View File

@@ -87,6 +87,7 @@ type EnableAuthOptions struct {
Config AuthConfigInput `json:"config" structs:"config"`
Local bool `json:"local" structs:"local"`
PluginName string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty"`
SealWrap bool `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"`
}
type AuthConfigInput struct {
@@ -99,6 +100,7 @@ type AuthMount struct {
Accessor string `json:"accessor" structs:"accessor" mapstructure:"accessor"`
Config AuthConfigOutput `json:"config" structs:"config" mapstructure:"config"`
Local bool `json:"local" structs:"local" mapstructure:"local"`
SealWrap bool `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"`
}
type AuthConfigOutput struct {

View File

@@ -1,7 +1,15 @@
package api
func (c *Sys) GenerateRootStatus() (*GenerateRootStatusResponse, error) {
r := c.c.NewRequest("GET", "/v1/sys/generate-root/attempt")
return c.generateRootStatusCommon("/v1/sys/generate-root/attempt")
}
func (c *Sys) GenerateDROperationTokenStatus() (*GenerateRootStatusResponse, error) {
return c.generateRootStatusCommon("/v1/sys/replication/dr/secondary/generate-operation-token/attempt")
}
func (c *Sys) generateRootStatusCommon(path string) (*GenerateRootStatusResponse, error) {
r := c.c.NewRequest("GET", path)
resp, err := c.c.RawRequest(r)
if err != nil {
return nil, err
@@ -14,12 +22,20 @@ func (c *Sys) GenerateRootStatus() (*GenerateRootStatusResponse, error) {
}
func (c *Sys) GenerateRootInit(otp, pgpKey string) (*GenerateRootStatusResponse, error) {
return c.generateRootInitCommon("/v1/sys/generate-root/attempt", otp, pgpKey)
}
func (c *Sys) GenerateDROperationTokenInit(otp, pgpKey string) (*GenerateRootStatusResponse, error) {
return c.generateRootInitCommon("/v1/sys/replication/dr/secondary/generate-operation-token/attempt", otp, pgpKey)
}
func (c *Sys) generateRootInitCommon(path, otp, pgpKey string) (*GenerateRootStatusResponse, error) {
body := map[string]interface{}{
"otp": otp,
"pgp_key": pgpKey,
}
r := c.c.NewRequest("PUT", "/v1/sys/generate-root/attempt")
r := c.c.NewRequest("PUT", path)
if err := r.SetJSONBody(body); err != nil {
return nil, err
}
@@ -36,7 +52,15 @@ func (c *Sys) GenerateRootInit(otp, pgpKey string) (*GenerateRootStatusResponse,
}
func (c *Sys) GenerateRootCancel() error {
r := c.c.NewRequest("DELETE", "/v1/sys/generate-root/attempt")
return c.generateRootCancelCommon("/v1/sys/generate-root/attempt")
}
func (c *Sys) GenerateDROperationTokenCancel() error {
return c.generateRootCancelCommon("/v1/sys/replication/dr/secondary/generate-operation-token/attempt")
}
func (c *Sys) generateRootCancelCommon(path string) error {
r := c.c.NewRequest("DELETE", path)
resp, err := c.c.RawRequest(r)
if err == nil {
defer resp.Body.Close()
@@ -45,12 +69,20 @@ func (c *Sys) GenerateRootCancel() error {
}
func (c *Sys) GenerateRootUpdate(shard, nonce string) (*GenerateRootStatusResponse, error) {
return c.generateRootUpdateCommon("/v1/sys/generate-root/update", shard, nonce)
}
func (c *Sys) GenerateDROperationTokenUpdate(shard, nonce string) (*GenerateRootStatusResponse, error) {
return c.generateRootUpdateCommon("/v1/sys/replication/dr/secondary/generate-operation-token/update", shard, nonce)
}
func (c *Sys) generateRootUpdateCommon(path, shard, nonce string) (*GenerateRootStatusResponse, error) {
body := map[string]interface{}{
"key": shard,
"nonce": nonce,
}
r := c.c.NewRequest("PUT", "/v1/sys/generate-root/update")
r := c.c.NewRequest("PUT", path)
if err := r.SetJSONBody(body); err != nil {
return nil, err
}
@@ -72,6 +104,7 @@ type GenerateRootStatusResponse struct {
Progress int
Required int
Complete bool
EncodedToken string `json:"encoded_token"`
EncodedRootToken string `json:"encoded_root_token"`
PGPFingerprint string `json:"pgp_fingerprint"`
}

View File

@@ -125,6 +125,7 @@ type MountInput struct {
Config MountConfigInput `json:"config" structs:"config"`
Local bool `json:"local" structs:"local"`
PluginName string `json:"plugin_name,omitempty" structs:"plugin_name"`
SealWrap bool `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"`
}
type MountConfigInput struct {
@@ -132,7 +133,6 @@ type MountConfigInput struct {
MaxLeaseTTL string `json:"max_lease_ttl" structs:"max_lease_ttl" mapstructure:"max_lease_ttl"`
ForceNoCache bool `json:"force_no_cache" structs:"force_no_cache" mapstructure:"force_no_cache"`
PluginName string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty" mapstructure:"plugin_name"`
SealWrap bool `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"`
}
type MountOutput struct {
@@ -141,6 +141,7 @@ type MountOutput struct {
Accessor string `json:"accessor" structs:"accessor"`
Config MountConfigOutput `json:"config" structs:"config"`
Local bool `json:"local" structs:"local"`
SealWrap bool `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"`
}
type MountConfigOutput struct {
@@ -148,5 +149,4 @@ type MountConfigOutput struct {
MaxLeaseTTL int `json:"max_lease_ttl" structs:"max_lease_ttl" mapstructure:"max_lease_ttl"`
ForceNoCache bool `json:"force_no_cache" structs:"force_no_cache" mapstructure:"force_no_cache"`
PluginName string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty" mapstructure:"plugin_name"`
SealWrap bool `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"`
}

View File

@@ -49,12 +49,14 @@ func sealStatusRequest(c *Sys, r *Request) (*SealStatusResponse, error) {
}
type SealStatusResponse struct {
Sealed bool `json:"sealed"`
T int `json:"t"`
N int `json:"n"`
Progress int `json:"progress"`
Nonce string `json:"nonce"`
Version string `json:"version"`
ClusterName string `json:"cluster_name,omitempty"`
ClusterID string `json:"cluster_id,omitempty"`
Type string `json:"type"`
Sealed bool `json:"sealed"`
T int `json:"t"`
N int `json:"n"`
Progress int `json:"progress"`
Nonce string `json:"nonce"`
Version string `json:"version"`
ClusterName string `json:"cluster_name,omitempty"`
ClusterID string `json:"cluster_id,omitempty"`
RecoverySeal bool `json:"recovery_seal"`
}

View File

@@ -146,7 +146,7 @@ func (f *AuditFormatter) FormatRequest(
}
if !config.OmitTime {
reqEntry.Time = time.Now().UTC().Format(time.RFC3339)
reqEntry.Time = time.Now().UTC().Format(time.RFC3339Nano)
}
return f.AuditFormatWriter.WriteRequest(w, reqEntry)
@@ -242,12 +242,13 @@ func (f *AuditFormatter) FormatResponse(
// Cache and restore accessor in the response
if resp != nil {
var accessor, wrappedAccessor string
var accessor, wrappedAccessor, wrappingAccessor string
if !config.HMACAccessor && resp != nil && resp.Auth != nil && resp.Auth.Accessor != "" {
accessor = resp.Auth.Accessor
}
if !config.HMACAccessor && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.WrappedAccessor != "" {
wrappedAccessor = resp.WrapInfo.WrappedAccessor
wrappingAccessor = resp.WrapInfo.Accessor
}
if err := Hash(salt, resp); err != nil {
return err
@@ -258,6 +259,9 @@ func (f *AuditFormatter) FormatResponse(
if wrappedAccessor != "" {
resp.WrapInfo.WrappedAccessor = wrappedAccessor
}
if wrappingAccessor != "" {
resp.WrapInfo.Accessor = wrappingAccessor
}
}
}
@@ -301,6 +305,7 @@ func (f *AuditFormatter) FormatResponse(
respWrapInfo = &AuditResponseWrapInfo{
TTL: int(resp.WrapInfo.TTL / time.Second),
Token: token,
Accessor: resp.WrapInfo.Accessor,
CreationTime: resp.WrapInfo.CreationTime.Format(time.RFC3339Nano),
CreationPath: resp.WrapInfo.CreationPath,
WrappedAccessor: resp.WrapInfo.WrappedAccessor,
@@ -347,7 +352,7 @@ func (f *AuditFormatter) FormatResponse(
}
if !config.OmitTime {
respEntry.Time = time.Now().UTC().Format(time.RFC3339)
respEntry.Time = time.Now().UTC().Format(time.RFC3339Nano)
}
return f.AuditFormatWriter.WriteResponse(w, respEntry)
@@ -412,6 +417,7 @@ type AuditSecret struct {
type AuditResponseWrapInfo struct {
TTL int `json:"ttl"`
Token string `json:"token"`
Accessor string `json:"accessor"`
CreationTime string `json:"creation_time"`
CreationPath string `json:"creation_path"`
WrappedAccessor string `json:"wrapped_accessor,omitempty"`

View File

@@ -93,6 +93,7 @@ func Hash(salter *salt.Salt, raw interface{}) error {
}
s.Token = fn(s.Token)
s.Accessor = fn(s.Accessor)
if s.WrappedAccessor != "" {
s.WrappedAccessor = fn(s.WrappedAccessor)

View File

@@ -148,6 +148,7 @@ func TestHash(t *testing.T) {
WrapInfo: &wrapping.ResponseWrapInfo{
TTL: 60,
Token: "bar",
Accessor: "flimflam",
CreationTime: now,
WrappedAccessor: "bar",
},
@@ -160,6 +161,7 @@ func TestHash(t *testing.T) {
WrapInfo: &wrapping.ResponseWrapInfo{
TTL: 60,
Token: "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317",
Accessor: "hmac-sha256:7c9c6fe666d0af73b3ebcfbfabe6885015558213208e6635ba104047b22f6390",
CreationTime: now,
WrappedAccessor: "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317",
},
@@ -206,6 +208,11 @@ func TestHash(t *testing.T) {
if err := Hash(localSalt, tc.Input); err != nil {
t.Fatalf("err: %s\n\n%s", err, input)
}
if _, ok := tc.Input.(*logical.Response); ok {
if !reflect.DeepEqual(tc.Input.(*logical.Response).WrapInfo, tc.Output.(*logical.Response).WrapInfo) {
t.Fatalf("bad:\nInput:\n%s\nTest case input:\n%#v\nTest case output\n%#v", input, tc.Input.(*logical.Response).WrapInfo, tc.Output.(*logical.Response).WrapInfo)
}
}
if !reflect.DeepEqual(tc.Input, tc.Output) {
t.Fatalf("bad:\nInput:\n%s\nTest case input:\n%#v\nTest case output\n%#v", input, tc.Input, tc.Output)
}

View File

@@ -75,7 +75,9 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
if err != nil {
return nil, err
}
mode = os.FileMode(m)
if m != 0 {
mode = os.FileMode(m)
}
}
b := &Backend{
@@ -247,13 +249,15 @@ func (b *Backend) open() error {
}
// Change the file mode in case the log file already existed. We special
// case /dev/null since we can't chmod it
// case /dev/null since we can't chmod it and bypass if the mode is zero
switch b.path {
case "/dev/null":
default:
err = os.Chmod(b.path, b.mode)
if err != nil {
return err
if b.mode != 0 {
err = os.Chmod(b.path, b.mode)
if err != nil {
return err
}
}
}

View File

@@ -141,7 +141,7 @@ func testAccStepMapUserIdCidr(t *testing.T, cidr string) logicaltest.TestStep {
func testAccLogin(t *testing.T, display string) logicaltest.TestStep {
checkTTL := func(resp *logical.Response) error {
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
return fmt.Errorf("invalid TTL")
return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
}
return nil
}
@@ -165,7 +165,7 @@ func testAccLogin(t *testing.T, display string) logicaltest.TestStep {
func testAccLoginAppIDInPath(t *testing.T, display string) logicaltest.TestStep {
checkTTL := func(resp *logical.Response) error {
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
return fmt.Errorf("invalid TTL")
return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
}
return nil
}

View File

@@ -3,7 +3,6 @@ package approle
import (
"fmt"
"strings"
"time"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
@@ -52,7 +51,7 @@ func (b *backend) pathLoginUpdateAliasLookahead(req *logical.Request, data *fram
func (b *backend) pathLoginUpdate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
role, roleName, metadata, _, err := b.validateCredentials(req, data)
if err != nil || role == nil {
return logical.ErrorResponse(fmt.Sprintf("failed to validate SecretID: %s", err)), nil
return logical.ErrorResponse(fmt.Sprintf("failed to validate credentials: %v", err)), nil
}
// Always include the role name, for later filtering
@@ -68,20 +67,13 @@ func (b *backend) pathLoginUpdate(req *logical.Request, data *framework.FieldDat
Policies: role.Policies,
LeaseOptions: logical.LeaseOptions{
Renewable: true,
TTL: role.TokenTTL,
},
Alias: &logical.Alias{
Name: role.RoleID,
},
}
// If 'Period' is set, use the value of 'Period' as the TTL.
// Otherwise, set the normal TokenTTL.
if role.Period > time.Duration(0) {
auth.TTL = role.Period
} else {
auth.TTL = role.TokenTTL
}
return &logical.Response{
Auth: auth,
}, nil
@@ -94,8 +86,12 @@ func (b *backend) pathLoginRenew(req *logical.Request, data *framework.FieldData
return nil, fmt.Errorf("failed to fetch role_name during renewal")
}
lock := b.roleLock(roleName)
lock.RLock()
defer lock.RUnlock()
// Ensure that the Role still exists.
role, err := b.roleEntry(req.Storage, roleName)
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, fmt.Errorf("failed to validate role %s during renewal:%s", roleName, err)
}
@@ -103,16 +99,12 @@ func (b *backend) pathLoginRenew(req *logical.Request, data *framework.FieldData
return nil, fmt.Errorf("role %s does not exist during renewal", roleName)
}
// If 'Period' is set on the Role, the token should never expire.
// Replenish the TTL with 'Period's value.
if role.Period > time.Duration(0) {
// If 'Period' was updated after the token was issued,
// token will bear the updated 'Period' value as its TTL.
req.Auth.TTL = role.Period
return &logical.Response{Auth: req.Auth}, nil
} else {
return framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data)
resp, err := framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data)
if err != nil {
return nil, err
}
resp.Auth.Period = role.Period
return resp, nil
}
const pathLoginHelpSys = "Issue a token based on the credentials supplied"

View File

@@ -57,6 +57,10 @@ type roleStorageEntry struct {
// value is not modified on the role. If the `Period` in the role is modified,
// a token will pick up the new value during its next renewal.
Period time.Duration `json:"period" mapstructure:"period" structs:"period"`
// LowerCaseRoleName enforces the lower casing of role names for all the
// roles that get created since this field was introduced.
LowerCaseRoleName bool `json:"lower_case_role_name" mapstructure:"lower_case_role_name" structs:"lower_case_role_name"`
}
// roleIDStorageEntry represents the reverse mapping from RoleID to Role
@@ -509,10 +513,20 @@ the role.`,
// pathRoleExistenceCheck returns whether the role with the given name exists or not.
func (b *backend) pathRoleExistenceCheck(req *logical.Request, data *framework.FieldData) (bool, error) {
role, err := b.roleEntry(req.Storage, data.Get("role_name").(string))
roleName := data.Get("role_name").(string)
if roleName == "" {
return false, fmt.Errorf("missing role_name")
}
lock := b.roleLock(roleName)
lock.RLock()
defer lock.RUnlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return false, err
}
return role != nil, nil
}
@@ -537,13 +551,21 @@ func (b *backend) pathRoleSecretIDList(req *logical.Request, data *framework.Fie
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.RLock()
defer lock.RUnlock()
// Get the role entry
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
if role == nil {
return logical.ErrorResponse(fmt.Sprintf("role %s does not exist", roleName)), nil
return logical.ErrorResponse(fmt.Sprintf("role %q does not exist", roleName)), nil
}
if role.LowerCaseRoleName {
roleName = strings.ToLower(roleName)
}
// Guard the list operation with an outer lock
@@ -552,7 +574,7 @@ func (b *backend) pathRoleSecretIDList(req *logical.Request, data *framework.Fie
roleNameHMAC, err := createHMAC(role.HMACKey, roleName)
if err != nil {
return nil, fmt.Errorf("failed to create HMAC of role_name: %s", err)
return nil, fmt.Errorf("failed to create HMAC of role_name: %v", err)
}
// Listing works one level at a time. Get the first level of data
@@ -618,9 +640,8 @@ func validateRoleConstraints(role *roleStorageEntry) error {
return nil
}
// setRoleEntry grabs a write lock and stores the options on an role into the
// storage. Also creates a reverse index from the role's RoleID to the role
// itself.
// setRoleEntry persists the role and creates an index from roleID to role
// name.
func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleStorageEntry, previousRoleID string) error {
if roleName == "" {
return fmt.Errorf("missing role name")
@@ -641,7 +662,7 @@ func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleSto
return err
}
if entry == nil {
return fmt.Errorf("failed to create storage entry for role %s", roleName)
return fmt.Errorf("failed to create storage entry for role %q", roleName)
}
// Check if the index from the role_id to role already exists
@@ -680,7 +701,7 @@ func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleSto
})
}
// roleEntry grabs the read lock and fetches the options of an role from the storage
// roleEntry reads the role from storage
func (b *backend) roleEntry(s logical.Storage, roleName string) (*roleStorageEntry, error) {
if roleName == "" {
return nil, fmt.Errorf("missing role_name")
@@ -688,11 +709,6 @@ func (b *backend) roleEntry(s logical.Storage, roleName string) (*roleStorageEnt
var role roleStorageEntry
lock := b.roleLock(roleName)
lock.RLock()
defer lock.RUnlock()
if entry, err := s.Get("role/" + strings.ToLower(roleName)); err != nil {
return nil, err
} else if entry == nil {
@@ -712,6 +728,10 @@ func (b *backend) pathRoleCreateUpdate(req *logical.Request, data *framework.Fie
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
// Check if the role already exists
role, err := b.roleEntry(req.Storage, roleName)
if err != nil {
@@ -722,13 +742,14 @@ func (b *backend) pathRoleCreateUpdate(req *logical.Request, data *framework.Fie
if role == nil && req.Operation == logical.CreateOperation {
hmacKey, err := uuid.GenerateUUID()
if err != nil {
return nil, fmt.Errorf("failed to create role_id: %s\n", err)
return nil, fmt.Errorf("failed to create role_id: %v\n", err)
}
role = &roleStorageEntry{
HMACKey: hmacKey,
HMACKey: hmacKey,
LowerCaseRoleName: true,
}
} else if role == nil {
return nil, fmt.Errorf("role entry not found during update operation")
return logical.ErrorResponse(fmt.Sprintf("invalid role name")), nil
}
previousRoleID := role.RoleID
@@ -737,12 +758,12 @@ func (b *backend) pathRoleCreateUpdate(req *logical.Request, data *framework.Fie
} else if req.Operation == logical.CreateOperation {
roleID, err := uuid.GenerateUUID()
if err != nil {
return nil, fmt.Errorf("failed to generate role_id: %s\n", err)
return nil, fmt.Errorf("failed to generate role_id: %v\n", err)
}
role.RoleID = roleID
}
if role.RoleID == "" {
return logical.ErrorResponse("invalid role_id"), nil
return logical.ErrorResponse("invalid role_id supplied, or failed to generate a role_id"), nil
}
if bindSecretIDRaw, ok := data.GetOk("bind_secret_id"); ok {
@@ -780,7 +801,7 @@ func (b *backend) pathRoleCreateUpdate(req *logical.Request, data *framework.Fie
role.Period = time.Second * time.Duration(data.Get("period").(int))
}
if role.Period > b.System().MaxLeaseTTL() {
return logical.ErrorResponse(fmt.Sprintf("'period' of '%s' is greater than the backend's maximum lease TTL of '%s'", role.Period.String(), b.System().MaxLeaseTTL().String())), nil
return logical.ErrorResponse(fmt.Sprintf("period of %q is greater than the backend's maximum lease TTL of %q", role.Period.String(), b.System().MaxLeaseTTL().String())), nil
}
if secretIDNumUsesRaw, ok := data.GetOk("secret_id_num_uses"); ok {
@@ -843,32 +864,78 @@ func (b *backend) pathRoleRead(req *logical.Request, data *framework.FieldData)
return logical.ErrorResponse("missing role_name"), nil
}
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
lock := b.roleLock(roleName)
lock.RLock()
lockRelease := lock.RUnlock
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
lockRelease()
return nil, err
} else if role == nil {
return nil, nil
} else {
// Convert the 'time.Duration' values to second.
role.SecretIDTTL /= time.Second
role.TokenTTL /= time.Second
role.TokenMaxTTL /= time.Second
role.Period /= time.Second
// Create a map of data to be returned and remove sensitive information from it
data := structs.New(role).Map()
delete(data, "role_id")
delete(data, "hmac_key")
resp := &logical.Response{
Data: data,
}
if err := validateRoleConstraints(role); err != nil {
resp.AddWarning("Role does not have any constraints set on it. Updates to this role will require a constraint to be set")
}
return resp, nil
}
if role == nil {
lockRelease()
return nil, nil
}
respData := map[string]interface{}{
"bind_secret_id": role.BindSecretID,
"bound_cidr_list": role.BoundCIDRList,
"period": role.Period / time.Second,
"policies": role.Policies,
"secret_id_num_uses": role.SecretIDNumUses,
"secret_id_ttl": role.SecretIDTTL / time.Second,
"token_max_ttl": role.TokenMaxTTL / time.Second,
"token_num_uses": role.TokenNumUses,
"token_ttl": role.TokenTTL / time.Second,
}
resp := &logical.Response{
Data: respData,
}
if err := validateRoleConstraints(role); err != nil {
resp.AddWarning("Role does not have any constraints set on it. Updates to this role will require a constraint to be set")
}
// For sanity, verify that the index still exists. If the index is missing,
// add one and return a warning so it can be reported.
roleIDIndex, err := b.roleIDEntry(req.Storage, role.RoleID)
if err != nil {
lockRelease()
return nil, err
}
if roleIDIndex == nil {
// Switch to a write lock
lock.RUnlock()
lock.Lock()
lockRelease = lock.Unlock
// Check again if the index is missing
roleIDIndex, err = b.roleIDEntry(req.Storage, role.RoleID)
if err != nil {
lockRelease()
return nil, err
}
if roleIDIndex == nil {
// Create a new index
err = b.setRoleIDEntry(req.Storage, role.RoleID, &roleIDStorageEntry{
Name: roleName,
})
if err != nil {
lockRelease()
return nil, fmt.Errorf("failed to create secondary index for role_id %q: %v", role.RoleID, err)
}
resp.AddWarning("Role identifier was missing an index back to role name. A new index has been added. Please report this observation.")
}
}
lockRelease()
return resp, nil
}
// pathRoleDelete removes the role from the storage
@@ -878,6 +945,10 @@ func (b *backend) pathRoleDelete(req *logical.Request, data *framework.FieldData
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -886,19 +957,14 @@ func (b *backend) pathRoleDelete(req *logical.Request, data *framework.FieldData
return nil, nil
}
// Acquire the lock before deleting the secrets.
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
// Just before the role is deleted, remove all the SecretIDs issued as part of the role.
if err = b.flushRoleSecrets(req.Storage, roleName, role.HMACKey); err != nil {
return nil, fmt.Errorf("failed to invalidate the secrets belonging to role '%s': %s", roleName, err)
return nil, fmt.Errorf("failed to invalidate the secrets belonging to role %q: %v", roleName, err)
}
// Delete the reverse mapping from RoleID to the role
if err = b.roleIDEntryDelete(req.Storage, role.RoleID); err != nil {
return nil, fmt.Errorf("failed to delete the mapping from RoleID to role '%s': %s", roleName, err)
return nil, fmt.Errorf("failed to delete the mapping from RoleID to role %q: %v", roleName, err)
}
// After deleting the SecretIDs and the RoleID, delete the role itself
@@ -921,25 +987,33 @@ func (b *backend) pathRoleSecretIDLookupUpdate(req *logical.Request, data *frame
return logical.ErrorResponse("missing secret_id"), nil
}
lock := b.roleLock(roleName)
lock.RLock()
defer lock.RUnlock()
// Fetch the role
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
if role == nil {
return nil, fmt.Errorf("role %s does not exist", roleName)
return nil, fmt.Errorf("role %q does not exist", roleName)
}
if role.LowerCaseRoleName {
roleName = strings.ToLower(roleName)
}
// Create the HMAC of the secret ID using the per-role HMAC key
secretIDHMAC, err := createHMAC(role.HMACKey, secretID)
if err != nil {
return nil, fmt.Errorf("failed to create HMAC of secret_id: %s", err)
return nil, fmt.Errorf("failed to create HMAC of secret_id: %v", err)
}
// Create the HMAC of the roleName using the per-role HMAC key
roleNameHMAC, err := createHMAC(role.HMACKey, roleName)
if err != nil {
return nil, fmt.Errorf("failed to create HMAC of role_name: %s", err)
return nil, fmt.Errorf("failed to create HMAC of role_name: %v", err)
}
// Create the index at which the secret_id would've been stored
@@ -996,22 +1070,26 @@ func (b *backend) pathRoleSecretIDDestroyUpdateDelete(req *logical.Request, data
return logical.ErrorResponse("missing secret_id"), nil
}
roleLock := b.roleLock(roleName)
roleLock.RLock()
defer roleLock.RUnlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
if role == nil {
return nil, fmt.Errorf("role %s does not exist", roleName)
return nil, fmt.Errorf("role %q does not exist", roleName)
}
secretIDHMAC, err := createHMAC(role.HMACKey, secretID)
if err != nil {
return nil, fmt.Errorf("failed to create HMAC of secret_id: %s", err)
return nil, fmt.Errorf("failed to create HMAC of secret_id: %v", err)
}
roleNameHMAC, err := createHMAC(role.HMACKey, roleName)
if err != nil {
return nil, fmt.Errorf("failed to create HMAC of role_name: %s", err)
return nil, fmt.Errorf("failed to create HMAC of role_name: %v", err)
}
entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, secretIDHMAC)
@@ -1036,7 +1114,7 @@ func (b *backend) pathRoleSecretIDDestroyUpdateDelete(req *logical.Request, data
// Delete the storage entry that corresponds to the SecretID
if err := req.Storage.Delete(entryIndex); err != nil {
return nil, fmt.Errorf("failed to delete SecretID: %s", err)
return nil, fmt.Errorf("failed to delete secret_id: %v", err)
}
return nil, nil
@@ -1059,12 +1137,16 @@ func (b *backend) pathRoleSecretIDAccessorLookupUpdate(req *logical.Request, dat
// Get the role details to fetch the RoleID and accessor to get
// the HMACed SecretID.
lock := b.roleLock(roleName)
lock.RLock()
defer lock.RUnlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
if role == nil {
return nil, fmt.Errorf("role %s does not exist", roleName)
return nil, fmt.Errorf("role %q does not exist", roleName)
}
accessorEntry, err := b.secretIDAccessorEntry(req.Storage, secretIDAccessor)
@@ -1072,12 +1154,12 @@ func (b *backend) pathRoleSecretIDAccessorLookupUpdate(req *logical.Request, dat
return nil, err
}
if accessorEntry == nil {
return nil, fmt.Errorf("failed to find accessor entry for secret_id_accessor:%s\n", secretIDAccessor)
return nil, fmt.Errorf("failed to find accessor entry for secret_id_accessor: %q\n", secretIDAccessor)
}
roleNameHMAC, err := createHMAC(role.HMACKey, roleName)
if err != nil {
return nil, fmt.Errorf("failed to create HMAC of role_name: %s", err)
return nil, fmt.Errorf("failed to create HMAC of role_name: %v", err)
}
entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, accessorEntry.SecretIDHMAC)
@@ -1105,7 +1187,7 @@ func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(req *logical.Reque
return nil, err
}
if role == nil {
return nil, fmt.Errorf("role %s does not exist", roleName)
return nil, fmt.Errorf("role %q does not exist", roleName)
}
accessorEntry, err := b.secretIDAccessorEntry(req.Storage, secretIDAccessor)
@@ -1113,12 +1195,12 @@ func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(req *logical.Reque
return nil, err
}
if accessorEntry == nil {
return nil, fmt.Errorf("failed to find accessor entry for secret_id_accessor:%s\n", secretIDAccessor)
return nil, fmt.Errorf("failed to find accessor entry for secret_id_accessor: %q\n", secretIDAccessor)
}
roleNameHMAC, err := createHMAC(role.HMACKey, roleName)
if err != nil {
return nil, fmt.Errorf("failed to create HMAC of role_name: %s", err)
return nil, fmt.Errorf("failed to create HMAC of role_name: %v", err)
}
entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, accessorEntry.SecretIDHMAC)
@@ -1134,7 +1216,7 @@ func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(req *logical.Reque
// Delete the storage entry that corresponds to the SecretID
if err := req.Storage.Delete(entryIndex); err != nil {
return nil, fmt.Errorf("failed to delete SecretID: %s", err)
return nil, fmt.Errorf("failed to delete secret_id: %v", err)
}
return nil, nil
@@ -1146,6 +1228,11 @@ func (b *backend) pathRoleBoundCIDRListUpdate(req *logical.Request, data *framew
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
// Re-read the role after grabbing the lock
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -1154,11 +1241,6 @@ func (b *backend) pathRoleBoundCIDRListUpdate(req *logical.Request, data *framew
return nil, nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role.BoundCIDRList = strings.TrimSpace(data.Get("bound_cidr_list").(string))
if role.BoundCIDRList == "" {
return logical.ErrorResponse("missing bound_cidr_list"), nil
@@ -1167,7 +1249,7 @@ func (b *backend) pathRoleBoundCIDRListUpdate(req *logical.Request, data *framew
if role.BoundCIDRList != "" {
valid, err := cidrutil.ValidateCIDRListString(role.BoundCIDRList, ",")
if err != nil {
return nil, fmt.Errorf("failed to validate CIDR blocks: %q", err)
return nil, fmt.Errorf("failed to validate CIDR blocks: %v", err)
}
if !valid {
return logical.ErrorResponse("failed to validate CIDR blocks"), nil
@@ -1183,6 +1265,10 @@ func (b *backend) pathRoleBoundCIDRListRead(req *logical.Request, data *framewor
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
@@ -1202,6 +1288,10 @@ func (b *backend) pathRoleBoundCIDRListDelete(req *logical.Request, data *framew
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -1210,11 +1300,6 @@ func (b *backend) pathRoleBoundCIDRListDelete(req *logical.Request, data *framew
return nil, nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
// Deleting a field implies setting the value to it's default value.
role.BoundCIDRList = data.GetDefaultOrZero("bound_cidr_list").(string)
@@ -1227,6 +1312,10 @@ func (b *backend) pathRoleBindSecretIDUpdate(req *logical.Request, data *framewo
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -1235,11 +1324,6 @@ func (b *backend) pathRoleBindSecretIDUpdate(req *logical.Request, data *framewo
return nil, nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
if bindSecretIDRaw, ok := data.GetOk("bind_secret_id"); ok {
role.BindSecretID = bindSecretIDRaw.(bool)
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
@@ -1254,6 +1338,10 @@ func (b *backend) pathRoleBindSecretIDRead(req *logical.Request, data *framework
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
@@ -1273,6 +1361,10 @@ func (b *backend) pathRoleBindSecretIDDelete(req *logical.Request, data *framewo
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -1281,11 +1373,6 @@ func (b *backend) pathRoleBindSecretIDDelete(req *logical.Request, data *framewo
return nil, nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
// Deleting a field implies setting the value to it's default value.
role.BindSecretID = data.GetDefaultOrZero("bind_secret_id").(bool)
@@ -1298,6 +1385,10 @@ func (b *backend) pathRolePoliciesUpdate(req *logical.Request, data *framework.F
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -1311,11 +1402,6 @@ func (b *backend) pathRolePoliciesUpdate(req *logical.Request, data *framework.F
return logical.ErrorResponse("missing policies"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role.Policies = policyutil.ParsePolicies(policiesRaw)
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
@@ -1327,6 +1413,10 @@ func (b *backend) pathRolePoliciesRead(req *logical.Request, data *framework.Fie
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
@@ -1346,6 +1436,10 @@ func (b *backend) pathRolePoliciesDelete(req *logical.Request, data *framework.F
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -1354,11 +1448,6 @@ func (b *backend) pathRolePoliciesDelete(req *logical.Request, data *framework.F
return nil, nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role.Policies = []string{}
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
@@ -1370,6 +1459,10 @@ func (b *backend) pathRoleSecretIDNumUsesUpdate(req *logical.Request, data *fram
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -1378,11 +1471,6 @@ func (b *backend) pathRoleSecretIDNumUsesUpdate(req *logical.Request, data *fram
return nil, nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
if numUsesRaw, ok := data.GetOk("secret_id_num_uses"); ok {
role.SecretIDNumUses = numUsesRaw.(int)
if role.SecretIDNumUses < 0 {
@@ -1400,6 +1488,10 @@ func (b *backend) pathRoleRoleIDUpdate(req *logical.Request, data *framework.Fie
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -1408,11 +1500,6 @@ func (b *backend) pathRoleRoleIDUpdate(req *logical.Request, data *framework.Fie
return nil, nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
previousRoleID := role.RoleID
role.RoleID = data.Get("role_id").(string)
if role.RoleID == "" {
@@ -1428,6 +1515,10 @@ func (b *backend) pathRoleRoleIDRead(req *logical.Request, data *framework.Field
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
@@ -1447,6 +1538,10 @@ func (b *backend) pathRoleSecretIDNumUsesRead(req *logical.Request, data *framew
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
@@ -1466,6 +1561,10 @@ func (b *backend) pathRoleSecretIDNumUsesDelete(req *logical.Request, data *fram
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -1474,11 +1573,6 @@ func (b *backend) pathRoleSecretIDNumUsesDelete(req *logical.Request, data *fram
return nil, nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role.SecretIDNumUses = data.GetDefaultOrZero("secret_id_num_uses").(int)
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
@@ -1490,6 +1584,10 @@ func (b *backend) pathRoleSecretIDTTLUpdate(req *logical.Request, data *framewor
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -1498,11 +1596,6 @@ func (b *backend) pathRoleSecretIDTTLUpdate(req *logical.Request, data *framewor
return nil, nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
if secretIDTTLRaw, ok := data.GetOk("secret_id_ttl"); ok {
role.SecretIDTTL = time.Second * time.Duration(secretIDTTLRaw.(int))
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
@@ -1517,6 +1610,10 @@ func (b *backend) pathRoleSecretIDTTLRead(req *logical.Request, data *framework.
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
@@ -1537,6 +1634,10 @@ func (b *backend) pathRoleSecretIDTTLDelete(req *logical.Request, data *framewor
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -1545,11 +1646,6 @@ func (b *backend) pathRoleSecretIDTTLDelete(req *logical.Request, data *framewor
return nil, nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role.SecretIDTTL = time.Second * time.Duration(data.GetDefaultOrZero("secret_id_ttl").(int))
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
@@ -1561,6 +1657,10 @@ func (b *backend) pathRolePeriodUpdate(req *logical.Request, data *framework.Fie
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -1569,15 +1669,10 @@ func (b *backend) pathRolePeriodUpdate(req *logical.Request, data *framework.Fie
return nil, nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
if periodRaw, ok := data.GetOk("period"); ok {
role.Period = time.Second * time.Duration(periodRaw.(int))
if role.Period > b.System().MaxLeaseTTL() {
return logical.ErrorResponse(fmt.Sprintf("'period' of '%s' is greater than the backend's maximum lease TTL of '%s'", role.Period.String(), b.System().MaxLeaseTTL().String())), nil
return logical.ErrorResponse(fmt.Sprintf("period of %q is greater than the backend's maximum lease TTL of %q", role.Period.String(), b.System().MaxLeaseTTL().String())), nil
}
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
} else {
@@ -1591,6 +1686,10 @@ func (b *backend) pathRolePeriodRead(req *logical.Request, data *framework.Field
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
@@ -1611,6 +1710,10 @@ func (b *backend) pathRolePeriodDelete(req *logical.Request, data *framework.Fie
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -1619,11 +1722,6 @@ func (b *backend) pathRolePeriodDelete(req *logical.Request, data *framework.Fie
return nil, nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role.Period = time.Second * time.Duration(data.GetDefaultOrZero("period").(int))
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
@@ -1635,6 +1733,10 @@ func (b *backend) pathRoleTokenNumUsesUpdate(req *logical.Request, data *framewo
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -1643,11 +1745,6 @@ func (b *backend) pathRoleTokenNumUsesUpdate(req *logical.Request, data *framewo
return nil, nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
if tokenNumUsesRaw, ok := data.GetOk("token_num_uses"); ok {
role.TokenNumUses = tokenNumUsesRaw.(int)
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
@@ -1662,6 +1759,10 @@ func (b *backend) pathRoleTokenNumUsesRead(req *logical.Request, data *framework
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
@@ -1681,6 +1782,10 @@ func (b *backend) pathRoleTokenNumUsesDelete(req *logical.Request, data *framewo
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -1689,11 +1794,6 @@ func (b *backend) pathRoleTokenNumUsesDelete(req *logical.Request, data *framewo
return nil, nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role.TokenNumUses = data.GetDefaultOrZero("token_num_uses").(int)
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
@@ -1705,6 +1805,10 @@ func (b *backend) pathRoleTokenTTLUpdate(req *logical.Request, data *framework.F
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -1713,11 +1817,6 @@ func (b *backend) pathRoleTokenTTLUpdate(req *logical.Request, data *framework.F
return nil, nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
if tokenTTLRaw, ok := data.GetOk("token_ttl"); ok {
role.TokenTTL = time.Second * time.Duration(tokenTTLRaw.(int))
if role.TokenMaxTTL > time.Duration(0) && role.TokenTTL > role.TokenMaxTTL {
@@ -1735,6 +1834,10 @@ func (b *backend) pathRoleTokenTTLRead(req *logical.Request, data *framework.Fie
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
@@ -1755,6 +1858,10 @@ func (b *backend) pathRoleTokenTTLDelete(req *logical.Request, data *framework.F
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -1763,11 +1870,6 @@ func (b *backend) pathRoleTokenTTLDelete(req *logical.Request, data *framework.F
return nil, nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role.TokenTTL = time.Second * time.Duration(data.GetDefaultOrZero("token_ttl").(int))
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
@@ -1779,6 +1881,10 @@ func (b *backend) pathRoleTokenMaxTTLUpdate(req *logical.Request, data *framewor
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -1787,11 +1893,6 @@ func (b *backend) pathRoleTokenMaxTTLUpdate(req *logical.Request, data *framewor
return nil, nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
if tokenMaxTTLRaw, ok := data.GetOk("token_max_ttl"); ok {
role.TokenMaxTTL = time.Second * time.Duration(tokenMaxTTLRaw.(int))
if role.TokenMaxTTL > time.Duration(0) && role.TokenTTL > role.TokenMaxTTL {
@@ -1809,6 +1910,10 @@ func (b *backend) pathRoleTokenMaxTTLRead(req *logical.Request, data *framework.
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
@@ -1829,6 +1934,10 @@ func (b *backend) pathRoleTokenMaxTTLDelete(req *logical.Request, data *framewor
return logical.ErrorResponse("missing role_name"), nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
@@ -1837,11 +1946,6 @@ func (b *backend) pathRoleTokenMaxTTLDelete(req *logical.Request, data *framewor
return nil, nil
}
lock := b.roleLock(roleName)
lock.Lock()
defer lock.Unlock()
role.TokenMaxTTL = time.Second * time.Duration(data.GetDefaultOrZero("token_max_ttl").(int))
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
@@ -1850,7 +1954,7 @@ func (b *backend) pathRoleTokenMaxTTLDelete(req *logical.Request, data *framewor
func (b *backend) pathRoleSecretIDUpdate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
secretID, err := uuid.GenerateUUID()
if err != nil {
return nil, fmt.Errorf("failed to generate SecretID:%s", err)
return nil, fmt.Errorf("failed to generate secret_id: %v", err)
}
return b.handleRoleSecretIDCommon(req, data, secretID)
}
@@ -1869,12 +1973,16 @@ func (b *backend) handleRoleSecretIDCommon(req *logical.Request, data *framework
return logical.ErrorResponse("missing secret_id"), nil
}
lock := b.roleLock(roleName)
lock.RLock()
defer lock.RUnlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
if role == nil {
return logical.ErrorResponse(fmt.Sprintf("role %s does not exist", roleName)), nil
return logical.ErrorResponse(fmt.Sprintf("role %q does not exist", roleName)), nil
}
if !role.BindSecretID {
@@ -1887,7 +1995,7 @@ func (b *backend) handleRoleSecretIDCommon(req *logical.Request, data *framework
if cidrList != "" {
valid, err := cidrutil.ValidateCIDRListString(cidrList, ",")
if err != nil {
return nil, fmt.Errorf("failed to validate CIDR blocks: %q", err)
return nil, fmt.Errorf("failed to validate CIDR blocks: %v", err)
}
if !valid {
return logical.ErrorResponse("failed to validate CIDR blocks"), nil
@@ -1913,8 +2021,12 @@ func (b *backend) handleRoleSecretIDCommon(req *logical.Request, data *framework
return logical.ErrorResponse(fmt.Sprintf("failed to parse metadata: %v", err)), nil
}
if role.LowerCaseRoleName {
roleName = strings.ToLower(roleName)
}
if secretIDStorage, err = b.registerSecretIDEntry(req.Storage, roleName, secretID, role.HMACKey, secretIDStorage); err != nil {
return nil, fmt.Errorf("failed to store SecretID: %s", err)
return nil, fmt.Errorf("failed to store secret_id: %v", err)
}
return &logical.Response{

View File

@@ -2,6 +2,7 @@ package approle
import (
"reflect"
"strings"
"testing"
"time"
@@ -10,6 +11,248 @@ import (
"github.com/mitchellh/mapstructure"
)
func TestApprole_RoleNameLowerCasing(t *testing.T) {
var resp *logical.Response
var err error
var roleID, secretID string
b, storage := createBackendWithStorage(t)
// Save a role with out LowerCaseRoleName set
role := &roleStorageEntry{
RoleID: "testroleid",
HMACKey: "testhmackey",
Policies: []string{"default"},
BindSecretID: true,
}
err = b.setRoleEntry(storage, "testRoleName", role, "")
if err != nil {
t.Fatal(err)
}
secretIDReq := &logical.Request{
Path: "role/testRoleName/secret-id",
Operation: logical.UpdateOperation,
Storage: storage,
}
resp, err = b.HandleRequest(secretIDReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
}
secretID = resp.Data["secret_id"].(string)
roleID = "testroleid"
// Regular login flow. This should succeed.
resp, err = b.HandleRequest(&logical.Request{
Path: "login",
Operation: logical.UpdateOperation,
Storage: storage,
Data: map[string]interface{}{
"role_id": roleID,
"secret_id": secretID,
},
})
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
}
// Lower case the role name when generating the secret id
secretIDReq.Path = "role/testrolename/secret-id"
resp, err = b.HandleRequest(secretIDReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
}
secretID = resp.Data["secret_id"].(string)
// Login should fail
resp, err = b.HandleRequest(&logical.Request{
Path: "login",
Operation: logical.UpdateOperation,
Storage: storage,
Data: map[string]interface{}{
"role_id": roleID,
"secret_id": secretID,
},
})
if err != nil {
t.Fatal(err)
}
if resp == nil || !resp.IsError() {
t.Fatalf("expected an error")
}
// Delete the role and create it again. This time don't directly persist
// it, but route the request to the creation handler so that it sets the
// LowerCaseRoleName to true.
resp, err = b.HandleRequest(&logical.Request{
Path: "role/testRoleName",
Operation: logical.DeleteOperation,
Storage: storage,
})
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
}
roleReq := &logical.Request{
Path: "role/testRoleName",
Operation: logical.CreateOperation,
Storage: storage,
Data: map[string]interface{}{
"bind_secret_id": true,
},
}
resp, err = b.HandleRequest(roleReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
}
// Create secret id with lower cased role name
resp, err = b.HandleRequest(&logical.Request{
Path: "role/testrolename/secret-id",
Operation: logical.UpdateOperation,
Storage: storage,
})
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
}
secretID = resp.Data["secret_id"].(string)
resp, err = b.HandleRequest(&logical.Request{
Path: "role/testrolename/role-id",
Operation: logical.ReadOperation,
Storage: storage,
})
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
}
roleID = resp.Data["role_id"].(string)
// Login should pass
resp, err = b.HandleRequest(&logical.Request{
Path: "login",
Operation: logical.UpdateOperation,
Storage: storage,
Data: map[string]interface{}{
"role_id": roleID,
"secret_id": secretID,
},
})
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr:%v", resp, err)
}
// Lookup of secret ID should work in case-insensitive manner
resp, err = b.HandleRequest(&logical.Request{
Path: "role/testrolename/secret-id/lookup",
Operation: logical.UpdateOperation,
Storage: storage,
Data: map[string]interface{}{
"secret_id": secretID,
},
})
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
}
if resp == nil {
t.Fatalf("failed to lookup secret IDs")
}
// Listing of secret IDs should work in case-insensitive manner
resp, err = b.HandleRequest(&logical.Request{
Path: "role/testrolename/secret-id",
Operation: logical.ListOperation,
Storage: storage,
})
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
}
if len(resp.Data["keys"].([]string)) != 1 {
t.Fatalf("failed to list secret IDs")
}
}
func TestAppRole_RoleReadSetIndex(t *testing.T) {
var resp *logical.Response
var err error
b, storage := createBackendWithStorage(t)
roleReq := &logical.Request{
Path: "role/testrole",
Operation: logical.CreateOperation,
Storage: storage,
Data: map[string]interface{}{
"bind_secret_id": true,
},
}
// Create a role
resp, err = b.HandleRequest(roleReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\n err: %v\n", resp, err)
}
roleIDReq := &logical.Request{
Path: "role/testrole/role-id",
Operation: logical.ReadOperation,
Storage: storage,
}
// Get the role ID
resp, err = b.HandleRequest(roleIDReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\n err: %v\n", resp, err)
}
roleID := resp.Data["role_id"].(string)
// Delete the role ID index
err = b.roleIDEntryDelete(storage, roleID)
if err != nil {
t.Fatal(err)
}
// Read the role again. This should add the index and return a warning
roleReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(roleReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\n err: %v\n", resp, err)
}
// Check if the warning is being returned
if !strings.Contains(resp.Warnings[0], "Role identifier was missing an index back to role name.") {
t.Fatalf("bad: expected a warning in the response")
}
roleIDIndex, err := b.roleIDEntry(storage, roleID)
if err != nil {
t.Fatal(err)
}
// Check if the index has been successfully created
if roleIDIndex == nil || roleIDIndex.Name != "testrole" {
t.Fatalf("bad: expected role to have an index")
}
roleReq.Operation = logical.UpdateOperation
roleReq.Data = map[string]interface{}{
"bind_secret_id": true,
"policies": "default",
}
// Check if updating and reading of roles work and that there are no lock
// contentions dangling due to previous operation
resp, err = b.HandleRequest(roleReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\n err: %v\n", resp, err)
}
roleReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(roleReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\n err: %v\n", resp, err)
}
}
func TestAppRole_CIDRSubset(t *testing.T) {
var resp *logical.Response
var err error

View File

@@ -75,15 +75,19 @@ func (b *backend) validateRoleID(s logical.Storage, roleID string) (*roleStorage
return nil, "", err
}
if roleIDIndex == nil {
return nil, "", fmt.Errorf("failed to find secondary index for role_id %q\n", roleID)
return nil, "", fmt.Errorf("invalid role_id %q\n", roleID)
}
lock := b.roleLock(roleIDIndex.Name)
lock.RLock()
defer lock.RUnlock()
role, err := b.roleEntry(s, roleIDIndex.Name)
if err != nil {
return nil, "", err
}
if role == nil {
return nil, "", fmt.Errorf("role %q referred by the SecretID does not exist", roleIDIndex.Name)
return nil, "", fmt.Errorf("role %q referred by the role_id %q does not exist anymore", roleIDIndex.Name, roleID)
}
return role, roleIDIndex.Name, nil
@@ -121,6 +125,10 @@ func (b *backend) validateCredentials(req *logical.Request, data *framework.Fiel
return nil, "", metadata, "", fmt.Errorf("missing secret_id")
}
if role.LowerCaseRoleName {
roleName = strings.ToLower(roleName)
}
// Check if the SecretID supplied is valid. If use limit was specified
// on the SecretID, it will be decremented in this call.
var valid bool

View File

@@ -99,6 +99,9 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
LocalStorage: []string{
"whitelist/identity/",
},
SealWrapStorage: []string{
"config/client",
},
},
Paths: []*framework.Path{
pathLogin(b),

View File

@@ -1125,6 +1125,11 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndWhitelistIdentity(t *testing.
t.Fatalf("instance ID not present in the response object")
}
_, ok := resp.Auth.Metadata["nonce"]
if ok {
t.Fatalf("client nonce should not have been returned")
}
loginInput["nonce"] = "changed-vault-client-nonce"
// try to login again with changed nonce
resp, err = b.HandleRequest(loginRequest)
@@ -1159,7 +1164,9 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndWhitelistIdentity(t *testing.
t.Fatalf("failed to delete whitelist identity")
}
// Allow a fresh login.
// Allow a fresh login without supplying the nonce
delete(loginInput, "nonce")
resp, err = b.HandleRequest(loginRequest)
if err != nil {
t.Fatal(err)
@@ -1167,6 +1174,11 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndWhitelistIdentity(t *testing.
if resp == nil || resp.Auth == nil || resp.IsError() {
t.Fatalf("login attempt failed")
}
_, ok = resp.Auth.Metadata["nonce"]
if !ok {
t.Fatalf("expected nonce to be returned")
}
}
func TestBackend_pathStsConfig(t *testing.T) {

View File

@@ -34,6 +34,11 @@ func GenerateLoginData(accessKey, secretKey, sessionToken, headerValue string) (
return nil, fmt.Errorf("could not compile valid credential providers from static config, environment, shared, or instance metadata")
}
_, err = creds.Get()
if err != nil {
return nil, fmt.Errorf("failed to retrieve credentials from credential chain: %v", err)
}
// Use the credentials we've found to construct an STS session
stsSession, err := session.NewSessionWithOptions(session.Options{
Config: aws.Config{Credentials: creds},

View File

@@ -643,7 +643,7 @@ func (b *backend) pathLoginUpdateEc2(
return logical.ErrorResponse(err.Error()), nil
}
// Don't let subsequent login attempts to bypass in initial
// Don't let subsequent login attempts to bypass the initial
// intent of disabling reauthentication, despite the properties
// of role getting updated. For example: Role has the value set
// to 'false', a role-tag login sets the value to 'true', then
@@ -693,7 +693,6 @@ func (b *backend) pathLoginUpdateEc2(
if roleTagResp != nil {
// Role tag is enabled on the role.
//
// Overwrite the policies with the ones returned from processing the role tag
// If there are no policies on the role tag, policies on the role are inherited.
@@ -777,8 +776,9 @@ func (b *backend) pathLoginUpdateEc2(
},
}
// Return the nonce only if reauthentication is allowed
if !disallowReauthentication {
// Return the nonce only if reauthentication is allowed and if the nonce
// was not supplied by the user.
if !disallowReauthentication && !clientNonceSupplied {
// Echo the client nonce back. If nonce param was not supplied
// to the endpoint at all (setting it to empty string does not
// qualify here), callers should extract out the nonce from
@@ -786,23 +786,15 @@ func (b *backend) pathLoginUpdateEc2(
resp.Auth.Metadata["nonce"] = clientNonce
}
if roleEntry.Period > time.Duration(0) {
resp.Auth.TTL = roleEntry.Period
} else {
// Cap the TTL value.
shortestTTL := b.System().DefaultLeaseTTL()
if roleEntry.TTL > time.Duration(0) && roleEntry.TTL < shortestTTL {
shortestTTL = roleEntry.TTL
if roleEntry.MaxTTL > time.Duration(0) {
// Cap TTL to shortestMaxTTL
if resp.Auth.TTL > shortestMaxTTL {
resp.AddWarning(fmt.Sprintf("Effective TTL of '%s' exceeded the effective max_ttl of '%s'; TTL value is capped accordingly", (resp.Auth.TTL / time.Second), (shortestMaxTTL / time.Second)))
resp.Auth.TTL = shortestMaxTTL
}
if shortestMaxTTL < shortestTTL {
resp.AddWarning(fmt.Sprintf("Effective ttl of %q exceeded the effective max_ttl of %q; ttl value is capped appropriately", (shortestTTL / time.Second).String(), (shortestMaxTTL / time.Second).String()))
shortestTTL = shortestMaxTTL
}
resp.Auth.TTL = shortestTTL
}
return resp, nil
}
// handleRoleTagLogin is used to fetch the role tag of the instance and
@@ -985,13 +977,12 @@ func (b *backend) pathLoginRenewIam(
}
}
// If 'Period' is set on the role, then the token should never expire.
if roleEntry.Period > time.Duration(0) {
req.Auth.TTL = roleEntry.Period
return &logical.Response{Auth: req.Auth}, nil
} else {
return framework.LeaseExtend(roleEntry.TTL, roleEntry.MaxTTL, b.System())(req, data)
resp, err := framework.LeaseExtend(roleEntry.TTL, roleEntry.MaxTTL, b.System())(req, data)
if err != nil {
return nil, err
}
resp.Auth.Period = roleEntry.Period
return resp, nil
}
func (b *backend) pathLoginRenewEc2(
@@ -1072,24 +1063,12 @@ func (b *backend) pathLoginRenewEc2(
return nil, err
}
// If 'Period' is set on the role, then the token should never expire. Role
// tag does not have a 'Period' field. So, regarless of whether the token
// was issued using a role login or a role tag login, the period set on the
// role should take effect.
if roleEntry.Period > time.Duration(0) {
req.Auth.TTL = roleEntry.Period
return &logical.Response{Auth: req.Auth}, nil
} else {
// Cap the TTL value
shortestTTL := b.System().DefaultLeaseTTL()
if roleEntry.TTL > time.Duration(0) && roleEntry.TTL < shortestTTL {
shortestTTL = roleEntry.TTL
}
if shortestMaxTTL < shortestTTL {
shortestTTL = shortestMaxTTL
}
return framework.LeaseExtend(shortestTTL, shortestMaxTTL, b.System())(req, data)
resp, err := framework.LeaseExtend(roleEntry.TTL, shortestMaxTTL, b.System())(req, data)
if err != nil {
return nil, err
}
resp.Auth.Period = roleEntry.Period
return resp, nil
}
func (b *backend) pathLoginUpdateIam(
@@ -1238,7 +1217,7 @@ func (b *backend) pathLoginUpdateIam(
policies := roleEntry.Policies
inferredEntityType := ""
inferredEntityId := ""
inferredEntityID := ""
if roleEntry.InferredEntityType == ec2EntityType {
instance, err := b.validateInstance(req.Storage, entity.SessionInfo, roleEntry.InferredAWSRegion, callerID.Account)
if err != nil {
@@ -1264,7 +1243,7 @@ func (b *backend) pathLoginUpdateIam(
}
inferredEntityType = ec2EntityType
inferredEntityId = entity.SessionInfo
inferredEntityID = entity.SessionInfo
}
resp := &logical.Response{
@@ -1277,7 +1256,7 @@ func (b *backend) pathLoginUpdateIam(
"client_user_id": callerUniqueId,
"auth_type": iamAuthType,
"inferred_entity_type": inferredEntityType,
"inferred_entity_id": inferredEntityId,
"inferred_entity_id": inferredEntityID,
"inferred_aws_region": roleEntry.InferredAWSRegion,
"account_id": entity.AccountNumber,
},
@@ -1295,25 +1274,18 @@ func (b *backend) pathLoginUpdateIam(
},
}
if roleEntry.Period > time.Duration(0) {
resp.Auth.TTL = roleEntry.Period
} else {
shortestTTL := b.System().DefaultLeaseTTL()
if roleEntry.TTL > time.Duration(0) && roleEntry.TTL < shortestTTL {
shortestTTL = roleEntry.TTL
if roleEntry.MaxTTL > time.Duration(0) {
// Cap maxTTL to the sysview's max TTL
maxTTL := roleEntry.MaxTTL
if maxTTL > b.System().MaxLeaseTTL() {
maxTTL = b.System().MaxLeaseTTL()
}
maxTTL := b.System().MaxLeaseTTL()
if roleEntry.MaxTTL > time.Duration(0) && roleEntry.MaxTTL < maxTTL {
maxTTL = roleEntry.MaxTTL
// Cap TTL to MaxTTL
if resp.Auth.TTL > maxTTL {
resp.AddWarning(fmt.Sprintf("Effective TTL of '%s' exceeded the effective max_ttl of '%s'; TTL value is capped accordingly", (resp.Auth.TTL / time.Second), (maxTTL / time.Second)))
resp.Auth.TTL = maxTTL
}
if shortestTTL > maxTTL {
resp.AddWarning(fmt.Sprintf("Effective TTL of %q exceeded the effective max_ttl of %q; TTL value is capped accordingly", (shortestTTL / time.Second).String(), (maxTTL / time.Second).String()))
shortestTTL = maxTTL
}
resp.Auth.TTL = shortestTTL
}
return resp, nil
@@ -1333,11 +1305,11 @@ func hasValuesForEc2Auth(data *framework.FieldData) (bool, bool) {
func hasValuesForIamAuth(data *framework.FieldData) (bool, bool) {
_, hasRequestMethod := data.GetOk("iam_http_request_method")
_, hasRequestUrl := data.GetOk("iam_request_url")
_, hasRequestURL := data.GetOk("iam_request_url")
_, hasRequestBody := data.GetOk("iam_request_body")
_, hasRequestHeaders := data.GetOk("iam_request_headers")
return (hasRequestMethod && hasRequestUrl && hasRequestBody && hasRequestHeaders),
(hasRequestMethod || hasRequestUrl || hasRequestBody || hasRequestHeaders)
return (hasRequestMethod && hasRequestURL && hasRequestBody && hasRequestHeaders),
(hasRequestMethod || hasRequestURL || hasRequestBody || hasRequestHeaders)
}
func parseIamArn(iamArn string) (*iamEntity, error) {

View File

@@ -663,6 +663,10 @@ func (b *backend) pathRoleCreateUpdate(
roleEntry.AllowInstanceMigration = data.Get("allow_instance_migration").(bool)
}
if roleEntry.AllowInstanceMigration && roleEntry.DisallowReauthentication {
return logical.ErrorResponse("cannot specify both disallow_reauthentication=true and allow_instance_migration=true"), nil
}
var resp logical.Response
ttlRaw, ok := data.GetOk("ttl")

View File

@@ -124,6 +124,10 @@ func (b *backend) pathRoleTagUpdate(
resp.AddWarning("Role does not allow instance migration. Login will not be allowed with this tag unless the role value is updated.")
}
if disallowReauthentication && allowInstanceMigration {
return logical.ErrorResponse("cannot set both disallow_reauthentication and allow_instance_migration"), nil
}
// max_ttl for the role tag should be less than the max_ttl set on the role.
maxTTL := time.Duration(data.Get("max_ttl").(int)) * time.Second

View File

@@ -66,12 +66,25 @@ func TestBackend_pathRoleEc2(t *testing.T) {
Data: data,
Storage: storage,
})
if resp != nil && resp.IsError() {
t.Fatalf("failed to create role: %s", resp.Data["error"])
}
if err != nil {
t.Fatal(err)
}
if resp == nil || !resp.IsError() {
t.Fatalf("expected failure to create role with both allow_instance_migration true and disallow_reauthentication true")
}
data["disallow_reauthentication"] = false
resp, err = b.HandleRequest(&logical.Request{
Operation: logical.UpdateOperation,
Path: "role/ami-abcd123",
Data: data,
Storage: storage,
})
if err != nil {
t.Fatal(err)
}
if resp != nil && resp.IsError() {
t.Fatalf("failure to update role: %v", resp.Data["error"])
}
resp, err = b.HandleRequest(&logical.Request{
Operation: logical.ReadOperation,
Path: "role/ami-abcd123",
@@ -80,8 +93,12 @@ func TestBackend_pathRoleEc2(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if !resp.Data["allow_instance_migration"].(bool) || !resp.Data["disallow_reauthentication"].(bool) {
t.Fatal("bad: expected:true got:false\n")
if !resp.Data["allow_instance_migration"].(bool) {
t.Fatal("bad: expected allow_instance_migration:true got:false\n")
}
if resp.Data["disallow_reauthentication"].(bool) {
t.Fatal("bad: expected disallow_reauthentication: false got:true\n")
}
// add another entry, to test listing of role entries
@@ -529,7 +546,7 @@ func TestAwsEc2_RoleCrud(t *testing.T) {
"ttl": "10m",
"max_ttl": "20m",
"policies": "testpolicy1,testpolicy2",
"disallow_reauthentication": true,
"disallow_reauthentication": false,
"hmac_key": "testhmackey",
"period": "1m",
}
@@ -567,7 +584,7 @@ func TestAwsEc2_RoleCrud(t *testing.T) {
"ttl": time.Duration(600),
"max_ttl": time.Duration(1200),
"policies": []string{"testpolicy1", "testpolicy2"},
"disallow_reauthentication": true,
"disallow_reauthentication": false,
"period": time.Duration(60),
}

View File

@@ -587,7 +587,7 @@ func TestBackend_CRLs(t *testing.T) {
func testFactory(t *testing.T) logical.Backend {
b, err := Factory(&logical.BackendConfig{
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: 300 * time.Second,
DefaultLeaseTTLVal: 1000 * time.Second,
MaxLeaseTTLVal: 1800 * time.Second,
},
StorageView: &logical.InmemStorage{},
@@ -619,9 +619,9 @@ func TestBackend_CertWrites(t *testing.T) {
tc := logicaltest.TestCase{
Backend: testFactory(t),
Steps: []logicaltest.TestStep{
testAccStepCert(t, "aaa", ca1, "foo", "", false),
testAccStepCert(t, "bbb", ca2, "foo", "", false),
testAccStepCert(t, "ccc", ca3, "foo", "", true),
testAccStepCert(t, "aaa", ca1, "foo", "", "", false),
testAccStepCert(t, "bbb", ca2, "foo", "", "", false),
testAccStepCert(t, "ccc", ca3, "foo", "", "", true),
},
}
tc.Steps = append(tc.Steps, testAccStepListCerts(t, []string{"aaa", "bbb"})...)
@@ -642,16 +642,18 @@ func TestBackend_basic_CA(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{
Backend: testFactory(t),
Steps: []logicaltest.TestStep{
testAccStepCert(t, "web", ca, "foo", "", false),
testAccStepCert(t, "web", ca, "foo", "", "", false),
testAccStepLogin(t, connState),
testAccStepCertLease(t, "web", ca, "foo"),
testAccStepCertTTL(t, "web", ca, "foo"),
testAccStepLogin(t, connState),
testAccStepCertMaxTTL(t, "web", ca, "foo"),
testAccStepLogin(t, connState),
testAccStepCertNoLease(t, "web", ca, "foo"),
testAccStepLoginDefaultLease(t, connState),
testAccStepCert(t, "web", ca, "foo", "*.example.com", false),
testAccStepCert(t, "web", ca, "foo", "*.example.com", "", false),
testAccStepLogin(t, connState),
testAccStepCert(t, "web", ca, "foo", "*.invalid.com", false),
testAccStepCert(t, "web", ca, "foo", "*.invalid.com", "", false),
testAccStepLoginInvalid(t, connState),
},
})
@@ -700,11 +702,68 @@ func TestBackend_basic_singleCert(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{
Backend: testFactory(t),
Steps: []logicaltest.TestStep{
testAccStepCert(t, "web", ca, "foo", "", false),
testAccStepCert(t, "web", ca, "foo", "", "", false),
testAccStepLogin(t, connState),
testAccStepCert(t, "web", ca, "foo", "example.com", false),
testAccStepCert(t, "web", ca, "foo", "example.com", "", false),
testAccStepLogin(t, connState),
testAccStepCert(t, "web", ca, "foo", "invalid", false),
testAccStepCert(t, "web", ca, "foo", "invalid", "", false),
testAccStepLoginInvalid(t, connState),
testAccStepCert(t, "web", ca, "foo", "", "1.2.3.4:invalid", false),
testAccStepLoginInvalid(t, connState),
},
})
}
// Test a self-signed client with custom extensions (root CA) that is trusted
func TestBackend_extensions_singleCert(t *testing.T) {
connState, err := testConnState(
"test-fixtures/root/rootcawextcert.pem",
"test-fixtures/root/rootcawextkey.pem",
"test-fixtures/root/rootcacert.pem",
)
if err != nil {
t.Fatalf("error testing connection state: %v", err)
}
ca, err := ioutil.ReadFile("test-fixtures/root/rootcacert.pem")
if err != nil {
t.Fatalf("err: %v", err)
}
logicaltest.Test(t, logicaltest.TestCase{
Backend: testFactory(t),
Steps: []logicaltest.TestStep{
testAccStepCert(t, "web", ca, "foo", "", "2.1.1.1:A UTF8String Extension", false),
testAccStepLogin(t, connState),
testAccStepCert(t, "web", ca, "foo", "", "2.1.1.1:*,2.1.1.2:A UTF8*", false),
testAccStepLogin(t, connState),
testAccStepCert(t, "web", ca, "foo", "", "1.2.3.45:*", false),
testAccStepLoginInvalid(t, connState),
testAccStepCert(t, "web", ca, "foo", "", "2.1.1.1:The Wrong Value", false),
testAccStepLoginInvalid(t, connState),
testAccStepCert(t, "web", ca, "foo", "", "2.1.1.1:*,2.1.1.2:The Wrong Value", false),
testAccStepLoginInvalid(t, connState),
testAccStepCert(t, "web", ca, "foo", "", "2.1.1.1:", false),
testAccStepLoginInvalid(t, connState),
testAccStepCert(t, "web", ca, "foo", "", "2.1.1.1:,2.1.1.2:*", false),
testAccStepLoginInvalid(t, connState),
testAccStepCert(t, "web", ca, "foo", "example.com", "2.1.1.1:A UTF8String Extension", false),
testAccStepLogin(t, connState),
testAccStepCert(t, "web", ca, "foo", "example.com", "2.1.1.1:*,2.1.1.2:A UTF8*", false),
testAccStepLogin(t, connState),
testAccStepCert(t, "web", ca, "foo", "example.com", "1.2.3.45:*", false),
testAccStepLoginInvalid(t, connState),
testAccStepCert(t, "web", ca, "foo", "example.com", "2.1.1.1:The Wrong Value", false),
testAccStepLoginInvalid(t, connState),
testAccStepCert(t, "web", ca, "foo", "example.com", "2.1.1.1:*,2.1.1.2:The Wrong Value", false),
testAccStepLoginInvalid(t, connState),
testAccStepCert(t, "web", ca, "foo", "invalid", "2.1.1.1:A UTF8String Extension", false),
testAccStepLoginInvalid(t, connState),
testAccStepCert(t, "web", ca, "foo", "invalid", "2.1.1.1:*,2.1.1.2:A UTF8*", false),
testAccStepLoginInvalid(t, connState),
testAccStepCert(t, "web", ca, "foo", "invalid", "1.2.3.45:*", false),
testAccStepLoginInvalid(t, connState),
testAccStepCert(t, "web", ca, "foo", "invalid", "2.1.1.1:The Wrong Value", false),
testAccStepLoginInvalid(t, connState),
testAccStepCert(t, "web", ca, "foo", "invalid", "2.1.1.1:*,2.1.1.2:The Wrong Value", false),
testAccStepLoginInvalid(t, connState),
},
})
@@ -724,9 +783,9 @@ func TestBackend_mixed_constraints(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{
Backend: testFactory(t),
Steps: []logicaltest.TestStep{
testAccStepCert(t, "1unconstrained", ca, "foo", "", false),
testAccStepCert(t, "2matching", ca, "foo", "*.example.com,whatever", false),
testAccStepCert(t, "3invalid", ca, "foo", "invalid", false),
testAccStepCert(t, "1unconstrained", ca, "foo", "", "", false),
testAccStepCert(t, "2matching", ca, "foo", "*.example.com,whatever", "", false),
testAccStepCert(t, "3invalid", ca, "foo", "invalid", "", false),
testAccStepLogin(t, connState),
// Assumes CertEntries are processed in alphabetical order (due to store.List), so we only match 2matching if 1unconstrained doesn't match
testAccStepLoginWithName(t, connState, "2matching"),
@@ -826,7 +885,7 @@ func testAccStepLoginDefaultLease(t *testing.T, connState tls.ConnectionState) l
Unauthenticated: true,
ConnState: &connState,
Check: func(resp *logical.Response) error {
if resp.Auth.TTL != 300*time.Second {
if resp.Auth.TTL != 1000*time.Second {
t.Fatalf("bad lease length: %#v", resp.Auth)
}
@@ -906,17 +965,18 @@ func testAccStepListCerts(
}
func testAccStepCert(
t *testing.T, name string, cert []byte, policies string, allowedNames string, expectError bool) logicaltest.TestStep {
t *testing.T, name string, cert []byte, policies string, allowedNames string, requiredExtensions string, expectError bool) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "certs/" + name,
ErrorOk: expectError,
Data: map[string]interface{}{
"certificate": string(cert),
"policies": policies,
"display_name": name,
"allowed_names": allowedNames,
"lease": 1000,
"certificate": string(cert),
"policies": policies,
"display_name": name,
"allowed_names": allowedNames,
"required_extensions": requiredExtensions,
"lease": 1000,
},
Check: func(resp *logical.Response) error {
if resp == nil && expectError {
@@ -955,6 +1015,21 @@ func testAccStepCertTTL(
}
}
func testAccStepCertMaxTTL(
t *testing.T, name string, cert []byte, policies string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "certs/" + name,
Data: map[string]interface{}{
"certificate": string(cert),
"policies": policies,
"display_name": name,
"ttl": "1000s",
"max_ttl": "1200s",
},
}
}
func testAccStepCertNoLease(
t *testing.T, name string, cert []byte, policies string) logicaltest.TestStep {
return logicaltest.TestStep{

View File

@@ -45,6 +45,13 @@ Must be x509 PEM encoded.`,
At least one must exist in either the Common Name or SANs. Supports globbing.`,
},
"required_extensions": &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: `A comma-separated string or array of extensions
formatted as "oid:value". Expects the extension value to be some type of ASN1 encoded string.
All values much match. Supports globbing on "value".`,
},
"display_name": &framework.FieldSchema{
Type: framework.TypeString,
Description: `The display name to use for clients using this
@@ -67,6 +74,19 @@ seconds. Defaults to system/backend default TTL.`,
Description: `TTL for tokens issued by this backend.
Defaults to system/backend default TTL time.`,
},
"max_ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: `Duration in either an integer number of seconds (3600) or
an integer time unit (60m) after which the
issued token can no longer be renewed.`,
},
"period": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: `If set, indicates that the token generated using this role
should never expire. The token should be renewed within the
duration specified by this value. At each renewal, the token's
TTL will be set to the value of this parameter.`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@@ -124,17 +144,14 @@ func (b *backend) pathCertRead(
return nil, nil
}
duration := cert.TTL
if duration == 0 {
duration = b.System().DefaultLeaseTTL()
}
return &logical.Response{
Data: map[string]interface{}{
"certificate": cert.Certificate,
"display_name": cert.DisplayName,
"policies": cert.Policies,
"ttl": duration / time.Second,
"ttl": cert.TTL / time.Second,
"max_ttl": cert.MaxTTL / time.Second,
"period": cert.Period / time.Second,
},
}, nil
}
@@ -146,6 +163,48 @@ func (b *backend) pathCertWrite(
displayName := d.Get("display_name").(string)
policies := policyutil.ParsePolicies(d.Get("policies"))
allowedNames := d.Get("allowed_names").([]string)
requiredExtensions := d.Get("required_extensions").([]string)
var resp logical.Response
// Parse the ttl (or lease duration)
systemDefaultTTL := b.System().DefaultLeaseTTL()
ttl := time.Duration(d.Get("ttl").(int)) * time.Second
if ttl == 0 {
ttl = time.Duration(d.Get("lease").(int)) * time.Second
}
if ttl > systemDefaultTTL {
resp.AddWarning(fmt.Sprintf("Given ttl of %d seconds is greater than current mount/system default of %d seconds", ttl/time.Second, systemDefaultTTL/time.Second))
}
if ttl < time.Duration(0) {
return logical.ErrorResponse("ttl cannot be negative"), nil
}
// Parse max_ttl
systemMaxTTL := b.System().MaxLeaseTTL()
maxTTL := time.Duration(d.Get("max_ttl").(int)) * time.Second
if maxTTL > systemMaxTTL {
resp.AddWarning(fmt.Sprintf("Given max_ttl of %d seconds is greater than current mount/system default of %d seconds", maxTTL/time.Second, systemMaxTTL/time.Second))
}
if maxTTL < time.Duration(0) {
return logical.ErrorResponse("max_ttl cannot be negative"), nil
}
if maxTTL != 0 && ttl > maxTTL {
return logical.ErrorResponse("ttl should be shorter than max_ttl"), nil
}
// Parse period
period := time.Duration(d.Get("period").(int)) * time.Second
if period > systemMaxTTL {
resp.AddWarning(fmt.Sprintf("Given period of %d seconds is greater than the backend's maximum TTL of %d seconds", period/time.Second, systemMaxTTL/time.Second))
}
if period < time.Duration(0) {
return logical.ErrorResponse("period cannot be negative"), nil
}
// Default the display name to the certificate name if not given
if displayName == "" {
@@ -172,24 +231,15 @@ func (b *backend) pathCertWrite(
}
certEntry := &CertEntry{
Name: name,
Certificate: certificate,
DisplayName: displayName,
Policies: policies,
AllowedNames: allowedNames,
}
// Parse the lease duration or default to backend/system default
maxTTL := b.System().MaxLeaseTTL()
ttl := time.Duration(d.Get("ttl").(int)) * time.Second
if ttl == time.Duration(0) {
ttl = time.Second * time.Duration(d.Get("lease").(int))
}
if ttl > maxTTL {
return logical.ErrorResponse(fmt.Sprintf("Given TTL of %d seconds greater than current mount/system default of %d seconds", ttl/time.Second, maxTTL/time.Second)), nil
}
if ttl > time.Duration(0) {
certEntry.TTL = ttl
Name: name,
Certificate: certificate,
DisplayName: displayName,
Policies: policies,
AllowedNames: allowedNames,
RequiredExtensions: requiredExtensions,
TTL: ttl,
MaxTTL: maxTTL,
Period: period,
}
// Store it
@@ -200,16 +250,24 @@ func (b *backend) pathCertWrite(
if err := req.Storage.Put(entry); err != nil {
return nil, err
}
return nil, nil
if len(resp.Warnings) == 0 {
return nil, nil
}
return &resp, nil
}
type CertEntry struct {
Name string
Certificate string
DisplayName string
Policies []string
TTL time.Duration
AllowedNames []string
Name string
Certificate string
DisplayName string
Policies []string
TTL time.Duration
MaxTTL time.Duration
Period time.Duration
AllowedNames []string
RequiredExtensions []string
}
const pathCertHelpSyn = `

View File

@@ -4,11 +4,13 @@ import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"strings"
"time"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/policyutil"
@@ -84,9 +86,9 @@ func (b *backend) pathLogin(
skid := base64.StdEncoding.EncodeToString(clientCerts[0].SubjectKeyId)
akid := base64.StdEncoding.EncodeToString(clientCerts[0].AuthorityKeyId)
// Generate a response
resp := &logical.Response{
Auth: &logical.Auth{
Period: matched.Entry.Period,
InternalData: map[string]interface{}{
"subject_key_id": skid,
"authority_key_id": akid,
@@ -108,6 +110,22 @@ func (b *backend) pathLogin(
},
},
}
if matched.Entry.MaxTTL > time.Duration(0) {
// Cap maxTTL to the sysview's max TTL
maxTTL := matched.Entry.MaxTTL
if maxTTL > b.System().MaxLeaseTTL() {
maxTTL = b.System().MaxLeaseTTL()
}
// Cap TTL to MaxTTL
if resp.Auth.TTL > maxTTL {
resp.AddWarning(fmt.Sprintf("Effective TTL of '%s' exceeded the effective max_ttl of '%s'; TTL value is capped accordingly", (resp.Auth.TTL / time.Second), (maxTTL / time.Second)))
resp.Auth.TTL = maxTTL
}
}
// Generate a response
return resp, nil
}
@@ -134,7 +152,7 @@ func (b *backend) pathLoginRenew(
clientCerts := req.Connection.ConnState.PeerCertificates
if len(clientCerts) == 0 {
return nil, fmt.Errorf("no client certificate found")
return logical.ErrorResponse("no client certificate found"), nil
}
skid := base64.StdEncoding.EncodeToString(clientCerts[0].SubjectKeyId)
akid := base64.StdEncoding.EncodeToString(clientCerts[0].AuthorityKeyId)
@@ -160,7 +178,12 @@ func (b *backend) pathLoginRenew(
return nil, fmt.Errorf("policies have changed, not renewing")
}
return framework.LeaseExtend(cert.TTL, 0, b.System())(req, d)
resp, err := framework.LeaseExtend(cert.TTL, cert.MaxTTL, b.System())(req, d)
if err != nil {
return nil, err
}
resp.Auth.Period = cert.Period
return resp, nil
}
func (b *backend) verifyCredentials(req *logical.Request, d *framework.FieldData) (*ParsedCert, *logical.Response, error) {
@@ -237,28 +260,70 @@ func (b *backend) verifyCredentials(req *logical.Request, d *framework.FieldData
}
func (b *backend) matchesConstraints(clientCert *x509.Certificate, trustedChain []*x509.Certificate, config *ParsedCert) bool {
return !b.checkForChainInCRLs(trustedChain) &&
b.matchesNames(clientCert, config) &&
b.matchesCertificateExtenions(clientCert, config)
}
// matchesNames verifies that the certificate matches at least one configured
// allowed name
func (b *backend) matchesNames(clientCert *x509.Certificate, config *ParsedCert) bool {
// Default behavior (no names) is to allow all names
nameMatched := len(config.Entry.AllowedNames) == 0
if len(config.Entry.AllowedNames) == 0 {
return true
}
// At least one pattern must match at least one name if any patterns are specified
for _, allowedName := range config.Entry.AllowedNames {
if glob.Glob(allowedName, clientCert.Subject.CommonName) {
nameMatched = true
return true
}
for _, name := range clientCert.DNSNames {
if glob.Glob(allowedName, name) {
nameMatched = true
return true
}
}
for _, name := range clientCert.EmailAddresses {
if glob.Glob(allowedName, name) {
nameMatched = true
return true
}
}
}
return false
}
return !b.checkForChainInCRLs(trustedChain) && nameMatched
// matchesCertificateExtenions verifies that the certificate matches configured
// required extensions
func (b *backend) matchesCertificateExtenions(clientCert *x509.Certificate, config *ParsedCert) bool {
// If no required extensions, nothing to check here
if len(config.Entry.RequiredExtensions) == 0 {
return true
}
// Fail fast if we have required extensions but no extensions on the cert
if len(clientCert.Extensions) == 0 {
return false
}
// Build Client Extensions Map for Constraint Matching
// x509 Writes Extensions in ASN1 with a bitstring tag, which results in the field
// including its ASN.1 type tag bytes. For the sake of simplicity, assume string type
// and drop the tag bytes. And get the number of bytes from the tag.
clientExtMap := make(map[string]string, len(clientCert.Extensions))
for _, ext := range clientCert.Extensions {
var parsedValue string
asn1.Unmarshal(ext.Value, &parsedValue)
clientExtMap[ext.Id.String()] = parsedValue
}
// If any of the required extensions don't match the constraint fails
for _, requiredExt := range config.Entry.RequiredExtensions {
reqExt := strings.SplitN(requiredExt, ":", 2)
clientExtValue, clientExtValueOk := clientExtMap[reqExt[0]]
if !clientExtValueOk || !glob.Glob(reqExt[1], clientExtValue) {
return false
}
}
return true
}
// loadTrustedCerts is used to load all the trusted certificates from the backend

View File

@@ -0,0 +1 @@
92223EAFBBEE17A3

View File

@@ -0,0 +1,21 @@
[ req ]
default_bits = 2048
encrypt_key = no
prompt = no
default_md = sha256
req_extensions = req_v3
distinguished_name = dn
[ dn ]
CN = example.com
[ req_v3 ]
subjectAltName = @alt_names
2.1.1.1=ASN1:UTF8String:A UTF8String Extension
2.1.1.2=ASN1:UTF8:A UTF8 Extension
2.1.1.3=ASN1:IA5:An IA5 Extension
2.1.1.4=ASN1:VISIBLE:A Visible Extension
[ alt_names ]
DNS.1 = example.com
IP.1 = 127.0.0.1

View File

@@ -0,0 +1,19 @@
-----BEGIN CERTIFICATE REQUEST-----
MIIDAzCCAesCAQAwFjEUMBIGA1UEAwwLZXhhbXBsZS5jb20wggEiMA0GCSqGSIb3
DQEBAQUAA4IBDwAwggEKAoIBAQDM2PrLyK/wVQIcnK362ZylDrIVMjFQzps/0AxM
ke+8MNPMArBlSAhnZus6qb0nN0nJrDLkHQgYqnSvK9N7VUv/xFblEcOLBlciLhyN
Wkm92+q/M/xOvUVmnYkN3XgTI5QNxF7ZWDFHmwCNV27RraQZou0hG7yvyoILLMQE
3MnMCNM1nZ9JIuBMcRsZLGqQ1XNaQljboRVIUjimzkcfYyTruhLosTIbwForp78J
MzHHqVjtLJXPqUnRMS7KhGMj1f2mIswQzCv6F2PWEzNBbP4Gb67znKikKDs0RgyL
RyfizFNFJSC58XntK8jwHK1D8W3UepFf4K8xNFnhPoKWtWfJAgMBAAGggacwgaQG
CSqGSIb3DQEJDjGBljCBkzAcBgNVHREEFTATggtleGFtcGxlLmNvbYcEfwAAATAf
BgNRAQEEGAwWQSBVVEY4U3RyaW5nIEV4dGVuc2lvbjAZBgNRAQIEEgwQQSBVVEY4
IEV4dGVuc2lvbjAZBgNRAQMEEhYQQW4gSUE1IEV4dGVuc2lvbjAcBgNRAQQEFRoT
QSBWaXNpYmxlIEV4dGVuc2lvbjANBgkqhkiG9w0BAQsFAAOCAQEAtYjewBcqAXxk
tDY0lpZid6ZvfngdDlDZX0vrs3zNppKNe5Sl+jsoDOexqTA7HQA/y1ru117sAEeB
yiqMeZ7oPk8b3w+BZUpab7p2qPMhZypKl93y/jGXGscc3jRbUBnym9S91PSq6wUd
f2aigSqFc9+ywFVdx5PnnZUfcrUQ2a+AweYEkGOzXX2Ga+Ige8grDMCzRgCoP5cW
kM5ghwZp5wYIBGrKBU9iDcBlmnNhYaGWf+dD00JtVDPNn2bJnCsJHIO0nklZgnrS
fli8VQ1nYPkONdkiRYLt6//6at1iNDoDgsVCChtlVkLpxFIKcDFUHlffZsc1kMFI
HTX579k8hA==
-----END CERTIFICATE REQUEST-----

View File

@@ -0,0 +1,20 @@
-----BEGIN CERTIFICATE-----
MIIDRjCCAi6gAwIBAgIJAJIiPq+77hejMA0GCSqGSIb3DQEBCwUAMBYxFDASBgNV
BAMTC2V4YW1wbGUuY29tMB4XDTE3MTEyOTE5MTgwM1oXDTI3MTEyNzE5MTgwM1ow
FjEUMBIGA1UEAwwLZXhhbXBsZS5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw
ggEKAoIBAQDM2PrLyK/wVQIcnK362ZylDrIVMjFQzps/0AxMke+8MNPMArBlSAhn
Zus6qb0nN0nJrDLkHQgYqnSvK9N7VUv/xFblEcOLBlciLhyNWkm92+q/M/xOvUVm
nYkN3XgTI5QNxF7ZWDFHmwCNV27RraQZou0hG7yvyoILLMQE3MnMCNM1nZ9JIuBM
cRsZLGqQ1XNaQljboRVIUjimzkcfYyTruhLosTIbwForp78JMzHHqVjtLJXPqUnR
MS7KhGMj1f2mIswQzCv6F2PWEzNBbP4Gb67znKikKDs0RgyLRyfizFNFJSC58Xnt
K8jwHK1D8W3UepFf4K8xNFnhPoKWtWfJAgMBAAGjgZYwgZMwHAYDVR0RBBUwE4IL
ZXhhbXBsZS5jb22HBH8AAAEwHwYDUQEBBBgMFkEgVVRGOFN0cmluZyBFeHRlbnNp
b24wGQYDUQECBBIMEEEgVVRGOCBFeHRlbnNpb24wGQYDUQEDBBIWEEFuIElBNSBF
eHRlbnNpb24wHAYDUQEEBBUaE0EgVmlzaWJsZSBFeHRlbnNpb24wDQYJKoZIhvcN
AQELBQADggEBAGU/iA6saupEaGn/veVNCknFGDL7pst5D6eX/y9atXlBOdJe7ZJJ
XQRkeHJldA0khVpzH7Ryfi+/25WDuNz+XTZqmb4ppeV8g9amtqBwxziQ9UUwYrza
eDBqdXBaYp/iHUEHoceX4F44xuo80BIqwF0lD9TFNUFoILnF26ajhKX0xkGaiKTH
6SbjBfHoQVMzOHokVRWregmgNycV+MAI9Ne9XkIZvdOYeNlcS9drZeJI3szkiaxB
WWaWaAr5UU2Z0yUCZnAIDMRcIiUbSEjIDz504sSuCzTctMOxWZu0r/0UrXRzwZZi
HAaKm3MUmBh733ChP4rTB58nr5DEr5rJ9P8=
-----END CERTIFICATE-----

View File

@@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDM2PrLyK/wVQIc
nK362ZylDrIVMjFQzps/0AxMke+8MNPMArBlSAhnZus6qb0nN0nJrDLkHQgYqnSv
K9N7VUv/xFblEcOLBlciLhyNWkm92+q/M/xOvUVmnYkN3XgTI5QNxF7ZWDFHmwCN
V27RraQZou0hG7yvyoILLMQE3MnMCNM1nZ9JIuBMcRsZLGqQ1XNaQljboRVIUjim
zkcfYyTruhLosTIbwForp78JMzHHqVjtLJXPqUnRMS7KhGMj1f2mIswQzCv6F2PW
EzNBbP4Gb67znKikKDs0RgyLRyfizFNFJSC58XntK8jwHK1D8W3UepFf4K8xNFnh
PoKWtWfJAgMBAAECggEAW7hLkzMok9N8PpNo0wjcuor58cOnkSbxHIFrAF3XmcvD
CXWqxa6bFLFgYcPejdCTmVkg8EKPfXvVAxn8dxyaCss+nRJ3G6ibGxLKdgAXRItT
cIk2T4svp+KhmzOur+MeR4vFbEuwxP8CIEclt3yoHVJ2Gnzw30UtNRO2MPcq48/C
ZODGeBqUif1EGjDAvlqu5kl/pcDBJ3ctIZdVUMYYW4R9JtzKsmwhX7CRCBm8k5hG
2uzn8AKwpuVtfWcnX59UUmHGJ8mjETuNLARRAwWBWhl8f7wckmi+PKERJGEM2QE5
/Voy0p22zmQ3waS8LgiI7YHCAEFqjVWNziVGdR36gQKBgQDxkpfkEsfa5PieIaaF
iQOO0rrjEJ9MBOQqmTDeclmDPNkM9qvCF/dqpJfOtliYFxd7JJ3OR2wKrBb5vGHt
qIB51Rnm9aDTM4OUEhnhvbPlERD0W+yWYXWRvqyHz0GYwEFGQ83h95GC/qfTosqy
LEzYLDafiPeNP+DG/HYRljAxUwKBgQDZFOWHEcZkSFPLNZiksHqs90OR2zIFxZcx
SrbkjqXjRjehWEAwgpvQ/quSBxrE2E8xXgVm90G1JpWzxjUfKKQRM6solQeEpnwY
kCy2Ozij/TtbLNRlU65UQ+nMto8KTSIyJbxxdOZxYdtJAJQp1FJO1a1WC11z4+zh
lnLV1O5S8wKBgQCDf/QU4DBQtNGtas315Oa96XJ4RkUgoYz+r1NN09tsOERC7UgE
KP2y3JQSn2pMqE1M6FrKvlBO4uzC10xLja0aJOmrssvwDBu1D8FtA9IYgJjFHAEG
v1i7lJrgdu7TUtx1flVli1l3gF4lM3m5UaonBrJZV7rB9iLKzwUKf8IOJwKBgFt/
QktPA6brEV56Za8sr1hOFA3bLNdf9B0Tl8j4ExWbWAFKeCu6MUDCxsAS/IZxgdeW
AILovqpC7CBM78EFWTni5EaDohqYLYAQ7LeWeIYuSyFf4Nogjj74LQha/iliX4Jx
g17y3dp2W34Gn2yOEG8oAxpcSfR54jMnPZnBWP5fAoGBAMNAd3oa/xq9A5v719ik
naD7PdrjBdhnPk4egzMDv54y6pCFlvFbEiBduBWTmiVa7dSzhYtmEbri2WrgARlu
vkfTnVH9E8Hnm4HTbNn+ebxrofq1AOAvdApSoslsOP1NT9J6zB89RzChJyzjbIQR
Gevrutb4uO9qpB1jDVoMmGde
-----END PRIVATE KEY-----

View File

@@ -5,6 +5,7 @@ import (
"github.com/google/go-github/github"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault/helper/mfa"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"golang.org/x/oauth2"
@@ -35,11 +36,11 @@ func Backend() *backend {
}
allPaths := append(b.TeamMap.Paths(), b.UserMap.Paths()...)
b.Backend = &framework.Backend{
Help: backendHelp,
PathsSpecial: &logical.Paths{
Root: mfa.MFARootPaths(),
Unauthenticated: []string{
"login",
},
@@ -47,9 +48,7 @@ func Backend() *backend {
Paths: append([]*framework.Path{
pathConfig(&b),
pathLogin(&b),
}, allPaths...),
}, append(allPaths, mfa.MFAPaths(b.Backend, pathLogin(&b))...)...),
AuthRenew: b.pathLoginRenew,
BackendType: logical.TypeCredential,
}

View File

@@ -74,7 +74,7 @@ func (b *backend) pathLogin(
return logical.ErrorResponse(fmt.Sprintf("error sanitizing TTLs: %s", err)), nil
}
return &logical.Response{
resp := &logical.Response{
Auth: &logical.Auth{
InternalData: map[string]interface{}{
"token": token,
@@ -93,7 +93,18 @@ func (b *backend) pathLogin(
Name: *verifyResp.User.Login,
},
},
}, nil
}
for _, teamName := range verifyResp.TeamNames {
if teamName == "" {
continue
}
resp.Auth.GroupAliases = append(resp.Auth.GroupAliases, &logical.Alias{
Name: teamName,
})
}
return resp, nil
}
func (b *backend) pathLoginRenew(
@@ -125,7 +136,22 @@ func (b *backend) pathLoginRenew(
if err != nil {
return nil, err
}
return framework.LeaseExtend(config.TTL, config.MaxTTL, b.System())(req, d)
resp, err := framework.LeaseExtend(config.TTL, config.MaxTTL, b.System())(req, d)
if err != nil {
return nil, err
}
// Remove old aliases
resp.Auth.GroupAliases = nil
for _, teamName := range verifyResp.TeamNames {
resp.Auth.GroupAliases = append(resp.Auth.GroupAliases, &logical.Alias{
Name: teamName,
})
}
return resp, nil
}
func (b *backend) verifyCredentials(req *logical.Request, token string) (*verifyCredentialsResp, *logical.Response, error) {
@@ -233,14 +259,16 @@ func (b *backend) verifyCredentials(req *logical.Request, token string) (*verify
}
return &verifyCredentialsResp{
User: user,
Org: org,
Policies: append(groupPoliciesList, userPoliciesList...),
User: user,
Org: org,
Policies: append(groupPoliciesList, userPoliciesList...),
TeamNames: teamNames,
}, nil, nil
}
type verifyCredentialsResp struct {
User *github.User
Org *github.Organization
Policies []string
User *github.User
Org *github.Organization
Policies []string
TeamNames []string
}

View File

@@ -31,6 +31,10 @@ func Backend() *backend {
Unauthenticated: []string{
"login/*",
},
SealWrapStorage: []string{
"config",
},
},
Paths: append([]*framework.Path{
@@ -88,22 +92,22 @@ func EscapeLDAPValue(input string) string {
return input
}
func (b *backend) Login(req *logical.Request, username string, password string) ([]string, *logical.Response, error) {
func (b *backend) Login(req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) {
cfg, err := b.Config(req)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
if cfg == nil {
return nil, logical.ErrorResponse("ldap backend not configured"), nil
return nil, logical.ErrorResponse("ldap backend not configured"), nil, nil
}
c, err := cfg.DialLDAP()
if err != nil {
return nil, logical.ErrorResponse(err.Error()), nil
return nil, logical.ErrorResponse(err.Error()), nil, nil
}
if c == nil {
return nil, logical.ErrorResponse("invalid connection returned from LDAP dial"), nil
return nil, logical.ErrorResponse("invalid connection returned from LDAP dial"), nil, nil
}
// Clean connection
@@ -111,7 +115,7 @@ func (b *backend) Login(req *logical.Request, username string, password string)
userBindDN, err := b.getUserBindDN(cfg, c, username)
if err != nil {
return nil, logical.ErrorResponse(err.Error()), nil
return nil, logical.ErrorResponse(err.Error()), nil, nil
}
if b.Logger().IsDebug() {
@@ -119,7 +123,7 @@ func (b *backend) Login(req *logical.Request, username string, password string)
}
if cfg.DenyNullBind && len(password) == 0 {
return nil, logical.ErrorResponse("password cannot be of zero length when passwordless binds are being denied"), nil
return nil, logical.ErrorResponse("password cannot be of zero length when passwordless binds are being denied"), nil, nil
}
// Try to bind as the login user. This is where the actual authentication takes place.
@@ -129,14 +133,14 @@ func (b *backend) Login(req *logical.Request, username string, password string)
err = c.UnauthenticatedBind(userBindDN)
}
if err != nil {
return nil, logical.ErrorResponse(fmt.Sprintf("LDAP bind failed: %v", err)), nil
return nil, logical.ErrorResponse(fmt.Sprintf("LDAP bind failed: %v", err)), nil, nil
}
// We re-bind to the BindDN if it's defined because we assume
// the BindDN should be the one to search, not the user logging in.
if cfg.BindDN != "" && cfg.BindPassword != "" {
if err := c.Bind(cfg.BindDN, cfg.BindPassword); err != nil {
return nil, logical.ErrorResponse(fmt.Sprintf("Encountered an error while attempting to re-bind with the BindDN User: %s", err.Error())), nil
return nil, logical.ErrorResponse(fmt.Sprintf("Encountered an error while attempting to re-bind with the BindDN User: %s", err.Error())), nil, nil
}
if b.Logger().IsDebug() {
b.Logger().Debug("auth/ldap: Re-Bound to original BindDN")
@@ -145,12 +149,12 @@ func (b *backend) Login(req *logical.Request, username string, password string)
userDN, err := b.getUserDN(cfg, c, userBindDN)
if err != nil {
return nil, logical.ErrorResponse(err.Error()), nil
return nil, logical.ErrorResponse(err.Error()), nil, nil
}
ldapGroups, err := b.getLdapGroups(cfg, c, userDN, username)
if err != nil {
return nil, logical.ErrorResponse(err.Error()), nil
return nil, logical.ErrorResponse(err.Error()), nil, nil
}
if b.Logger().IsDebug() {
b.Logger().Debug("auth/ldap: Groups fetched from server", "num_server_groups", len(ldapGroups), "server_groups", ldapGroups)
@@ -199,10 +203,10 @@ func (b *backend) Login(req *logical.Request, username string, password string)
}
ldapResponse.Data["error"] = errStr
return nil, ldapResponse, nil
return nil, ldapResponse, nil, nil
}
return policies, ldapResponse, nil
return policies, ldapResponse, allGroups, nil
}
/*

View File

@@ -55,7 +55,7 @@ func (b *backend) pathLogin(
username := d.Get("username").(string)
password := d.Get("password").(string)
policies, resp, err := b.Login(req, username, password)
policies, resp, groupNames, err := b.Login(req, username, password)
// Handle an internal error
if err != nil {
return nil, err
@@ -87,6 +87,15 @@ func (b *backend) pathLogin(
Name: username,
},
}
for _, groupName := range groupNames {
if groupName == "" {
continue
}
resp.Auth.GroupAliases = append(resp.Auth.GroupAliases, &logical.Alias{
Name: groupName,
})
}
return resp, nil
}
@@ -96,7 +105,7 @@ func (b *backend) pathLoginRenew(
username := req.Auth.Metadata["username"]
password := req.Auth.InternalData["password"].(string)
loginPolicies, resp, err := b.Login(req, username, password)
loginPolicies, resp, groupNames, err := b.Login(req, username, password)
if len(loginPolicies) == 0 {
return resp, err
}
@@ -105,7 +114,21 @@ func (b *backend) pathLoginRenew(
return nil, fmt.Errorf("policies have changed, not renewing")
}
return framework.LeaseExtend(0, 0, b.System())(req, d)
resp, err = framework.LeaseExtend(0, 0, b.System())(req, d)
if err != nil {
return nil, err
}
// Remove old aliases
resp.Auth.GroupAliases = nil
for _, groupName := range groupNames {
resp.Auth.GroupAliases = append(resp.Auth.GroupAliases, &logical.Alias{
Name: groupName,
})
}
return resp, nil
}
const pathLoginSyn = `

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"github.com/chrismalek/oktasdk-go/okta"
"github.com/hashicorp/vault/helper/mfa"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@@ -22,9 +23,14 @@ func Backend() *backend {
Help: backendHelp,
PathsSpecial: &logical.Paths{
Root: mfa.MFARootPaths(),
Unauthenticated: []string{
"login/*",
},
SealWrapStorage: []string{
"config",
},
},
Paths: append([]*framework.Path{
@@ -33,8 +39,9 @@ func Backend() *backend {
pathGroups(&b),
pathUsersList(&b),
pathGroupsList(&b),
pathLogin(&b),
}),
},
mfa.MFAPaths(b.Backend, pathLogin(&b))...,
),
AuthRenew: b.pathLoginRenew,
BackendType: logical.TypeCredential,
@@ -47,13 +54,13 @@ type backend struct {
*framework.Backend
}
func (b *backend) Login(req *logical.Request, username string, password string) ([]string, *logical.Response, error) {
func (b *backend) Login(req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) {
cfg, err := b.Config(req.Storage)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
if cfg == nil {
return nil, logical.ErrorResponse("Okta auth method not configured"), nil
return nil, logical.ErrorResponse("Okta auth method not configured"), nil, nil
}
client := cfg.OktaClient()
@@ -71,16 +78,16 @@ func (b *backend) Login(req *logical.Request, username string, password string)
"password": password,
})
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
var result authResult
rsp, err := client.Do(authReq, &result)
if err != nil {
return nil, logical.ErrorResponse(fmt.Sprintf("Okta auth failed: %v", err)), nil
return nil, logical.ErrorResponse(fmt.Sprintf("Okta auth failed: %v", err)), nil, nil
}
if rsp == nil {
return nil, logical.ErrorResponse("okta auth method unexpected failure"), nil
return nil, logical.ErrorResponse("okta auth method unexpected failure"), nil, nil
}
oktaResponse := &logical.Response{
@@ -92,7 +99,7 @@ func (b *backend) Login(req *logical.Request, username string, password string)
if cfg.Token != "" {
oktaGroups, err := b.getOktaGroups(client, &result.Embedded.User)
if err != nil {
return nil, logical.ErrorResponse(fmt.Sprintf("okta failure retrieving groups: %v", err)), nil
return nil, logical.ErrorResponse(fmt.Sprintf("okta failure retrieving groups: %v", err)), nil, nil
}
if len(oktaGroups) == 0 {
errString := fmt.Sprintf(
@@ -142,10 +149,10 @@ func (b *backend) Login(req *logical.Request, username string, password string)
}
oktaResponse.Data["error"] = errStr
return nil, oktaResponse, nil
return nil, oktaResponse, nil, nil
}
return policies, oktaResponse, nil
return policies, oktaResponse, allGroups, nil
}
func (b *backend) getOktaGroups(client *okta.Client, user *okta.User) ([]string, error) {

View File

@@ -38,6 +38,15 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
"password": password,
}
mfa_method, ok := m["method"]
if ok {
data["method"] = mfa_method
}
mfa_passcode, ok := m["passcode"]
if ok {
data["passcode"] = mfa_passcode
}
path := fmt.Sprintf("auth/%s/login/%s", mount, username)
secret, err := c.Logical().Write(path, data)
if err != nil {

View File

@@ -57,7 +57,7 @@ func (b *backend) pathLogin(
username := d.Get("username").(string)
password := d.Get("password").(string)
policies, resp, err := b.Login(req, username, password)
policies, resp, groupNames, err := b.Login(req, username, password)
// Handle an internal error
if err != nil {
return nil, err
@@ -96,6 +96,16 @@ func (b *backend) pathLogin(
Name: username,
},
}
for _, groupName := range groupNames {
if groupName == "" {
continue
}
resp.Auth.GroupAliases = append(resp.Auth.GroupAliases, &logical.Alias{
Name: groupName,
})
}
return resp, nil
}
@@ -105,7 +115,7 @@ func (b *backend) pathLoginRenew(
username := req.Auth.Metadata["username"]
password := req.Auth.InternalData["password"].(string)
loginPolicies, resp, err := b.Login(req, username, password)
loginPolicies, resp, groupNames, err := b.Login(req, username, password)
if len(loginPolicies) == 0 {
return resp, err
}
@@ -119,7 +129,22 @@ func (b *backend) pathLoginRenew(
return nil, err
}
return framework.LeaseExtend(cfg.TTL, cfg.MaxTTL, b.System())(req, d)
resp, err = framework.LeaseExtend(cfg.TTL, cfg.MaxTTL, b.System())(req, d)
if err != nil {
return nil, err
}
// Remove old aliases
resp.Auth.GroupAliases = nil
for _, groupName := range groupNames {
resp.Auth.GroupAliases = append(resp.Auth.GroupAliases, &logical.Alias{
Name: groupName,
})
}
return resp, nil
}
func (b *backend) getConfig(req *logical.Request) (*ConfigEntry, error) {

View File

@@ -26,6 +26,10 @@ func Backend() *backend {
"login",
"login/*",
},
SealWrapStorage: []string{
"config",
},
},
Paths: append([]*framework.Path{

View File

@@ -39,7 +39,7 @@ func pathConfig(b *backend) *framework.Path {
"read_timeout": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Default: 10,
Description: "Number of seconds before response times out (default: 10). Note: kept for backwards compatibility, currently unused.",
Description: "Number of seconds before response times out (default: 10)",
},
"nas_port": &framework.FieldSchema{
Type: framework.TypeInt,

View File

@@ -154,7 +154,9 @@ func (b *backend) RadiusLogin(req *logical.Request, username string, password st
Timeout: time.Duration(cfg.DialTimeout) * time.Second,
},
}
received, err := client.Exchange(context.Background(), packet, hostport)
ctx, cancelFunc := context.WithTimeout(context.Background(), time.Duration(cfg.ReadTimeout)*time.Second)
received, err := client.Exchange(ctx, packet, hostport)
cancelFunc()
if err != nil {
return nil, logical.ErrorResponse(err.Error()), nil
}

View File

@@ -25,6 +25,9 @@ func Backend() *backend {
LocalStorage: []string{
framework.WALPrefix,
},
SealWrapStorage: []string{
"config/root",
},
},
Paths: []*framework.Path{

View File

@@ -13,8 +13,9 @@ import (
"github.com/hashicorp/vault/logical"
)
func getRootConfig(s logical.Storage) (*aws.Config, error) {
func getRootConfig(s logical.Storage, clientType string) (*aws.Config, error) {
credsConfig := &awsutil.CredentialsConfig{}
var endpoint string
entry, err := s.Get("config/root")
if err != nil {
@@ -29,6 +30,12 @@ func getRootConfig(s logical.Storage) (*aws.Config, error) {
credsConfig.AccessKey = config.AccessKey
credsConfig.SecretKey = config.SecretKey
credsConfig.Region = config.Region
switch {
case clientType == "iam" && config.IAMEndpoint != "":
endpoint = *aws.String(config.IAMEndpoint)
case clientType == "sts" && config.STSEndpoint != "":
endpoint = *aws.String(config.STSEndpoint)
}
}
if credsConfig.Region == "" {
@@ -51,16 +58,19 @@ func getRootConfig(s logical.Storage) (*aws.Config, error) {
return &aws.Config{
Credentials: creds,
Region: aws.String(credsConfig.Region),
Endpoint: &endpoint,
HTTPClient: cleanhttp.DefaultClient(),
}, nil
}
func clientIAM(s logical.Storage) (*iam.IAM, error) {
awsConfig, err := getRootConfig(s)
awsConfig, err := getRootConfig(s, "iam")
if err != nil {
return nil, err
}
client := iam.New(session.New(awsConfig))
if client == nil {
return nil, fmt.Errorf("could not obtain iam client")
}
@@ -68,11 +78,12 @@ func clientIAM(s logical.Storage) (*iam.IAM, error) {
}
func clientSTS(s logical.Storage) (*sts.STS, error) {
awsConfig, err := getRootConfig(s)
awsConfig, err := getRootConfig(s, "sts")
if err != nil {
return nil, err
}
client := sts.New(session.New(awsConfig))
if client == nil {
return nil, fmt.Errorf("could not obtain sts client")
}

View File

@@ -23,6 +23,14 @@ func pathConfigRoot() *framework.Path {
Type: framework.TypeString,
Description: "Region for API calls.",
},
"iam_endpoint": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Endpoint to custom IAM server URL",
},
"sts_endpoint": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Endpoint to custom STS server URL",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@@ -37,11 +45,15 @@ func pathConfigRoot() *framework.Path {
func pathConfigRootWrite(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
region := data.Get("region").(string)
iamendpoint := data.Get("iam_endpoint").(string)
stsendpoint := data.Get("sts_endpoint").(string)
entry, err := logical.StorageEntryJSON("config/root", rootConfig{
AccessKey: data.Get("access_key").(string),
SecretKey: data.Get("secret_key").(string),
Region: region,
AccessKey: data.Get("access_key").(string),
SecretKey: data.Get("secret_key").(string),
IAMEndpoint: iamendpoint,
STSEndpoint: stsendpoint,
Region: region,
})
if err != nil {
return nil, err
@@ -55,9 +67,11 @@ func pathConfigRootWrite(
}
type rootConfig struct {
AccessKey string `json:"access_key"`
SecretKey string `json:"secret_key"`
Region string `json:"region"`
AccessKey string `json:"access_key"`
SecretKey string `json:"secret_key"`
IAMEndpoint string `json:"iam_endpoint"`
STSEndpoint string `json:"sts_endpoint"`
Region string `json:"region"`
}
const pathConfigRootHelpSyn = `

View File

@@ -25,6 +25,12 @@ func Backend() *backend {
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
PathsSpecial: &logical.Paths{
SealWrapStorage: []string{
"config/connection",
},
},
Paths: []*framework.Path{
pathConfigConnection(&b),
pathRoles(&b),

View File

@@ -4,73 +4,76 @@ import (
"fmt"
"log"
"os"
"strconv"
"sync"
"testing"
"time"
"github.com/gocql/gocql"
"github.com/hashicorp/vault/logical"
logicaltest "github.com/hashicorp/vault/logical/testing"
"github.com/mitchellh/mapstructure"
dockertest "gopkg.in/ory-am/dockertest.v2"
dockertest "gopkg.in/ory-am/dockertest.v3"
)
var (
testImagePull sync.Once
)
func prepareTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cid dockertest.ContainerID, retURL string) {
func prepareCassandraTestContainer(t *testing.T) (func(), string, int) {
if os.Getenv("CASSANDRA_HOST") != "" {
return "", os.Getenv("CASSANDRA_HOST")
return func() {}, os.Getenv("CASSANDRA_HOST"), 0
}
// Without this the checks for whether the container has started seem to
// never actually pass. There's really no reason to expose the test
// containers, so don't.
dockertest.BindDockerToLocalhost = "yep"
testImagePull.Do(func() {
dockertest.Pull("cassandra")
})
pool, err := dockertest.NewPool("")
if err != nil {
t.Fatalf("Failed to connect to docker: %s", err)
}
cwd, _ := os.Getwd()
cassandraMountPath := fmt.Sprintf("%s/test-fixtures/:/etc/cassandra/", cwd)
cid, connErr := dockertest.ConnectToCassandra("latest", 60, 1000*time.Millisecond, func(connURL string) bool {
// This will cause a validation to run
resp, err := b.HandleRequest(&logical.Request{
Storage: s,
Operation: logical.UpdateOperation,
Path: "config/connection",
Data: map[string]interface{}{
"hosts": connURL,
"username": "cassandra",
"password": "cassandra",
"protocol_version": 3,
},
})
if err != nil || (resp != nil && resp.IsError()) {
// It's likely not up and running yet, so return false and try again
return false
}
retURL = connURL
return true
}, []string{"-v", cwd + "/test-fixtures/:/etc/cassandra/"}...)
if connErr != nil {
if cid != "" {
cid.KillRemove()
}
t.Fatalf("could not connect to database: %v", connErr)
ro := &dockertest.RunOptions{
Repository: "cassandra",
Tag: "latest",
Env: []string{"CASSANDRA_BROADCAST_ADDRESS=127.0.0.1"},
Mounts: []string{cassandraMountPath},
}
return
}
func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) {
err := cid.KillRemove()
resource, err := pool.RunWithOptions(ro)
if err != nil {
t.Fatal(err)
t.Fatalf("Could not start local cassandra docker container: %s", err)
}
cleanup := func() {
err := pool.Purge(resource)
if err != nil {
t.Fatalf("Failed to cleanup local container: %s", err)
}
}
port, _ := strconv.Atoi(resource.GetPort("9042/tcp"))
address := fmt.Sprintf("127.0.0.1:%d", port)
// exponential backoff-retry
if err = pool.Retry(func() error {
clusterConfig := gocql.NewCluster(address)
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
Username: "cassandra",
Password: "cassandra",
}
clusterConfig.ProtoVersion = 4
clusterConfig.Port = port
session, err := clusterConfig.CreateSession()
if err != nil {
return fmt.Errorf("error creating session: %s", err)
}
defer session.Close()
return nil
}); err != nil {
cleanup()
t.Fatalf("Could not connect to cassandra docker container: %s", err)
}
return cleanup, address, port
}
func TestBackend_basic(t *testing.T) {
@@ -84,10 +87,8 @@ func TestBackend_basic(t *testing.T) {
t.Fatal(err)
}
cid, hostname := prepareTestContainer(t, config.StorageView, b)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
cleanup, hostname, _ := prepareCassandraTestContainer(t)
defer cleanup()
logicaltest.Test(t, logicaltest.TestCase{
Backend: b,
@@ -110,10 +111,8 @@ func TestBackend_roleCrud(t *testing.T) {
t.Fatal(err)
}
cid, hostname := prepareTestContainer(t, config.StorageView, b)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
cleanup, hostname, _ := prepareCassandraTestContainer(t)
defer cleanup()
logicaltest.Test(t, logicaltest.TestCase{
Backend: b,

View File

@@ -421,7 +421,7 @@ seed_provider:
parameters:
# seeds is actually a comma-delimited list of addresses.
# Ex: "<ip1>,<ip2>,<ip3>"
- seeds: "172.17.0.3"
- seeds: "127.0.0.1"
# For workloads with more data than can fit in memory, Cassandra's
# bottleneck will be reads that need to fetch data from
@@ -572,7 +572,7 @@ ssl_storage_port: 7001
#
# Setting listen_address to 0.0.0.0 is always wrong.
#
listen_address: 172.17.0.3
listen_address: 172.17.0.2
# Set listen_address OR listen_interface, not both. Interfaces must correspond
# to a single address, IP aliasing is not supported.
@@ -586,7 +586,7 @@ listen_address: 172.17.0.3
# Address to broadcast to other Cassandra nodes
# Leaving this blank will set it to the same value as listen_address
broadcast_address: 172.17.0.3
broadcast_address: 127.0.0.1
# When using multiple physical network interfaces, set this
# to true to listen on broadcast_address in addition to
@@ -668,7 +668,7 @@ rpc_port: 9160
# be set to 0.0.0.0. If left blank, this will be set to the value of
# rpc_address. If rpc_address is set to 0.0.0.0, broadcast_rpc_address must
# be set.
broadcast_rpc_address: 172.17.0.3
broadcast_rpc_address: 127.0.0.1
# enable or disable keepalive on rpc/native connections
rpc_keepalive: true

View File

@@ -16,6 +16,12 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Backend() *backend {
var b backend
b.Backend = &framework.Backend{
PathsSpecial: &logical.Paths{
SealWrapStorage: []string{
"config/access",
},
},
Paths: []*framework.Path{
pathConfigAccess(),
pathListRoles(&b),

View File

@@ -433,12 +433,8 @@ func testAccStepReadPolicy(t *testing.T, name string, policy string, lease time.
return fmt.Errorf("mismatch: %s %s", out, policy)
}
leaseRaw := resp.Data["lease"].(string)
l, err := time.ParseDuration(leaseRaw)
if err != nil {
return err
}
if l != lease {
l := resp.Data["lease"].(int64)
if lease != time.Second*time.Duration(l) {
return fmt.Errorf("mismatch: %v %v", l, lease)
}
return nil

View File

@@ -44,7 +44,7 @@ Defaults to 'client'.`,
},
"lease": &framework.FieldSchema{
Type: framework.TypeString,
Type: framework.TypeDurationSecond,
Description: "Lease time of the role.",
},
},
@@ -91,7 +91,7 @@ func pathRolesRead(
// Generate the response
resp := &logical.Response{
Data: map[string]interface{}{
"lease": result.Lease.String(),
"lease": int64(result.Lease.Seconds()),
"token_type": result.TokenType,
},
}
@@ -130,13 +130,9 @@ func pathRolesWrite(
}
var lease time.Duration
leaseParam := d.Get("lease").(string)
if leaseParam != "" {
lease, err = time.ParseDuration(leaseParam)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"error parsing given lease of %s: %s", leaseParam, err)), nil
}
leaseParamRaw, ok := d.GetOk("lease")
if ok {
lease = time.Second * time.Duration(leaseParamRaw.(int))
}
entry, err := logical.StorageEntryJSON("policy/"+name, roleConfig{

View File

@@ -11,9 +11,9 @@ import (
func pathToken(b *backend) *framework.Path {
return &framework.Path{
Pattern: "creds/" + framework.GenericNameRegex("name"),
Pattern: "creds/" + framework.GenericNameRegex("role"),
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
"role": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the role",
},
@@ -27,14 +27,14 @@ func pathToken(b *backend) *framework.Path {
func (b *backend) pathTokenRead(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
role := d.Get("role").(string)
entry, err := req.Storage.Get("policy/" + name)
entry, err := req.Storage.Get("policy/" + role)
if err != nil {
return nil, fmt.Errorf("error retrieving role: %s", err)
}
if entry == nil {
return logical.ErrorResponse(fmt.Sprintf("Role '%s' not found", name)), nil
return logical.ErrorResponse(fmt.Sprintf("role %q not found", role)), nil
}
var result roleConfig
@@ -56,7 +56,7 @@ func (b *backend) pathTokenRead(
}
// Generate a name for the token
tokenName := fmt.Sprintf("Vault %s %s %d", name, req.DisplayName, time.Now().UnixNano())
tokenName := fmt.Sprintf("Vault %s %s %d", role, req.DisplayName, time.Now().UnixNano())
// Create it
token, _, err := c.ACL().Create(&api.ACLEntry{
@@ -73,6 +73,7 @@ func (b *backend) pathTokenRead(
"token": token,
}, map[string]interface{}{
"token": token,
"role": role,
})
s.Secret.TTL = result.Lease

View File

@@ -1,6 +1,8 @@
package consul
import (
"fmt"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@@ -26,8 +28,30 @@ func secretToken(b *backend) *framework.Secret {
func (b *backend) secretTokenRenew(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
roleRaw, ok := req.Secret.InternalData["role"]
if !ok || roleRaw == nil {
return framework.LeaseExtend(0, 0, b.System())(req, d)
}
return framework.LeaseExtend(0, 0, b.System())(req, d)
role, ok := roleRaw.(string)
if !ok {
return framework.LeaseExtend(0, 0, b.System())(req, d)
}
entry, err := req.Storage.Get("policy/" + role)
if err != nil {
return nil, fmt.Errorf("error retrieving role: %s", err)
}
if entry == nil {
return logical.ErrorResponse(fmt.Sprintf("issuing role %q not found", role)), nil
}
var result roleConfig
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
return framework.LeaseExtend(result.Lease, 0, b.System())(req, d)
}
func secretTokenRevoke(

View File

@@ -1,6 +1,7 @@
package database
import (
"context"
"fmt"
"net/rpc"
"strings"
@@ -28,6 +29,12 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
PathsSpecial: &logical.Paths{
SealWrapStorage: []string{
"config/*",
},
},
Paths: []*framework.Path{
pathListPluginConnection(&b),
pathConfigurePluginConnection(&b),
@@ -81,7 +88,7 @@ func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) {
// This function creates a new db object from the stored configuration and
// caches it in the connections map. The caller of this function needs to hold
// the backend's write lock
func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) {
func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, name string) (dbplugin.Database, error) {
db, ok := b.connections[name]
if ok {
return db, nil
@@ -97,7 +104,7 @@ func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.
return nil, err
}
err = db.Initialize(config.ConnectionDetails, true)
err = db.Initialize(ctx, config.ConnectionDetails, true)
if err != nil {
return nil, err
}
@@ -124,6 +131,21 @@ func (b *databaseBackend) DatabaseConfig(s logical.Storage, name string) (*Datab
return &config, nil
}
type upgradeStatements struct {
// This json tag has a typo in it, the new version does not. This
// necessitates this upgrade logic.
CreationStatements string `json:"creation_statments"`
RevocationStatements string `json:"revocation_statements"`
RollbackStatements string `json:"rollback_statements"`
RenewStatements string `json:"renew_statements"`
}
type upgradeCheck struct {
// This json tag has a typo in it, the new version does not. This
// necessitates this upgrade logic.
Statements upgradeStatements `json:"statments"`
}
func (b *databaseBackend) Role(s logical.Storage, roleName string) (*roleEntry, error) {
entry, err := s.Get("role/" + roleName)
if err != nil {
@@ -133,11 +155,24 @@ func (b *databaseBackend) Role(s logical.Storage, roleName string) (*roleEntry,
return nil, nil
}
var upgradeCh upgradeCheck
if err := entry.DecodeJSON(&upgradeCh); err != nil {
return nil, err
}
var result roleEntry
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
empty := upgradeCheck{}
if upgradeCh != empty {
result.Statements.CreationStatements = upgradeCh.Statements.CreationStatements
result.Statements.RevocationStatements = upgradeCh.Statements.RevocationStatements
result.Statements.RollbackStatements = upgradeCh.Statements.RollbackStatements
result.Statements.RenewStatements = upgradeCh.Statements.RenewStatements
}
return &result, nil
}
@@ -164,7 +199,8 @@ func (b *databaseBackend) clearConnection(name string) {
func (b *databaseBackend) closeIfShutdown(name string, err error) {
// Plugin has shutdown, close it so next call can reconnect.
if err == rpc.ErrShutdown {
switch err {
case rpc.ErrShutdown, dbplugin.ErrPluginShutdown:
b.Lock()
b.clearConnection(name)
b.Unlock()

View File

@@ -116,6 +116,55 @@ func TestBackend_PluginMain(t *testing.T) {
postgresql.Run(apiClientMeta.GetTLSConfig())
}
func TestBackend_RoleUpgrade(t *testing.T) {
storage := &logical.InmemStorage{}
backend := &databaseBackend{}
roleEnt := &roleEntry{
Statements: dbplugin.Statements{
CreationStatements: "test",
},
}
entry, err := logical.StorageEntryJSON("role/test", roleEnt)
if err != nil {
t.Fatal(err)
}
if err := storage.Put(entry); err != nil {
t.Fatal(err)
}
role, err := backend.Role(storage, "test")
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(role, roleEnt) {
t.Fatal("bad role %#v", role)
}
// Upgrade case
badJSON := `{"statments":{"creation_statments":"test","revocation_statements":"","rollback_statements":"","renew_statements":""}}`
entry = &logical.StorageEntry{
Key: "role/test",
Value: []byte(badJSON),
}
if err := storage.Put(entry); err != nil {
t.Fatal(err)
}
role, err = backend.Role(storage, "test")
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(role, roleEnt) {
t.Fatal("bad role %#v", role)
}
}
func TestBackend_config_connection(t *testing.T) {
var resp *logical.Response
var err error
@@ -488,9 +537,11 @@ func TestBackend_roleCrud(t *testing.T) {
RevocationStatements: defaultRevocationSQL,
}
var actual dbplugin.Statements
if err := mapstructure.Decode(resp.Data, &actual); err != nil {
t.Fatal(err)
actual := dbplugin.Statements{
CreationStatements: resp.Data["creation_statements"].(string),
RevocationStatements: resp.Data["revocation_statements"].(string),
RollbackStatements: resp.Data["rollback_statements"].(string),
RenewStatements: resp.Data["renew_statements"].(string),
}
if !reflect.DeepEqual(expected, actual) {
@@ -609,6 +660,40 @@ func TestBackend_allowedRoles(t *testing.T) {
t.Fatalf("expected error to be:%s got:%#v\n", logical.ErrPermissionDenied, err)
}
// update connection with glob allowed roles connection
data = map[string]interface{}{
"connection_url": connURL,
"plugin_name": "postgresql-database-plugin",
"allowed_roles": "allow*",
}
req = &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/plugin-test",
Storage: config.StorageView,
Data: data,
}
resp, err = b.HandleRequest(req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Get creds, should work.
data = map[string]interface{}{}
req = &logical.Request{
Operation: logical.ReadOperation,
Path: "creds/allowed",
Storage: config.StorageView,
Data: data,
}
credsResp, err = b.HandleRequest(req)
if err != nil || (credsResp != nil && credsResp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
}
if !testCredsExist(t, credsResp, connURL) {
t.Fatalf("Creds should exist")
}
// update connection with * allowed roles connection
data = map[string]interface{}{
"connection_url": connURL,

View File

@@ -1,10 +1,8 @@
package dbplugin
import (
"fmt"
"net/rpc"
"errors"
"sync"
"time"
"github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/helper/pluginutil"
@@ -17,11 +15,11 @@ type DatabasePluginClient struct {
client *plugin.Client
sync.Mutex
*databasePluginRPCClient
Database
}
func (dc *DatabasePluginClient) Close() error {
err := dc.databasePluginRPCClient.Close()
err := dc.Database.Close()
dc.client.Kill()
return err
@@ -55,79 +53,20 @@ func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginR
// We should have a database type now. This feels like a normal interface
// implementation but is in fact over an RPC connection.
databaseRPC := raw.(*databasePluginRPCClient)
var db Database
switch raw.(type) {
case *gRPCClient:
db = raw.(*gRPCClient)
case *databasePluginRPCClient:
logger.Warn("database: plugin is using deprecated net RPC transport, recompile plugin to upgrade to gRPC", "plugin", pluginRunner.Name)
db = raw.(*databasePluginRPCClient)
default:
return nil, errors.New("unsupported client type")
}
// Wrap RPC implimentation in DatabasePluginClient
return &DatabasePluginClient{
client: client,
databasePluginRPCClient: databaseRPC,
client: client,
Database: db,
}, nil
}
// ---- RPC client domain ----
// databasePluginRPCClient implements Database and is used on the client to
// make RPC calls to a plugin.
type databasePluginRPCClient struct {
client *rpc.Client
}
func (dr *databasePluginRPCClient) Type() (string, error) {
var dbType string
err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
return fmt.Sprintf("plugin-%s", dbType), err
}
func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
req := CreateUserRequest{
Statements: statements,
UsernameConfig: usernameConfig,
Expiration: expiration,
}
var resp CreateUserResponse
err = dr.client.Call("Plugin.CreateUser", req, &resp)
return resp.Username, resp.Password, err
}
func (dr *databasePluginRPCClient) RenewUser(statements Statements, username string, expiration time.Time) error {
req := RenewUserRequest{
Statements: statements,
Username: username,
Expiration: expiration,
}
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username string) error {
req := RevokeUserRequest{
Statements: statements,
Username: username,
}
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}, verifyConnection bool) error {
req := InitializeRequest{
Config: conf,
VerifyConnection: verifyConnection,
}
err := dr.client.Call("Plugin.Initialize", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) Close() error {
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
return err
}

View File

@@ -0,0 +1,556 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: builtin/logical/database/dbplugin/database.proto
/*
Package dbplugin is a generated protocol buffer package.
It is generated from these files:
builtin/logical/database/dbplugin/database.proto
It has these top-level messages:
InitializeRequest
CreateUserRequest
RenewUserRequest
RevokeUserRequest
Statements
UsernameConfig
CreateUserResponse
TypeResponse
Empty
*/
package dbplugin
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import google_protobuf "github.com/golang/protobuf/ptypes/timestamp"
import (
context "golang.org/x/net/context"
grpc "google.golang.org/grpc"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type InitializeRequest struct {
Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"`
VerifyConnection bool `protobuf:"varint,2,opt,name=verify_connection,json=verifyConnection" json:"verify_connection,omitempty"`
}
func (m *InitializeRequest) Reset() { *m = InitializeRequest{} }
func (m *InitializeRequest) String() string { return proto.CompactTextString(m) }
func (*InitializeRequest) ProtoMessage() {}
func (*InitializeRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func (m *InitializeRequest) GetConfig() []byte {
if m != nil {
return m.Config
}
return nil
}
func (m *InitializeRequest) GetVerifyConnection() bool {
if m != nil {
return m.VerifyConnection
}
return false
}
type CreateUserRequest struct {
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
UsernameConfig *UsernameConfig `protobuf:"bytes,2,opt,name=username_config,json=usernameConfig" json:"username_config,omitempty"`
Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"`
}
func (m *CreateUserRequest) Reset() { *m = CreateUserRequest{} }
func (m *CreateUserRequest) String() string { return proto.CompactTextString(m) }
func (*CreateUserRequest) ProtoMessage() {}
func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
func (m *CreateUserRequest) GetStatements() *Statements {
if m != nil {
return m.Statements
}
return nil
}
func (m *CreateUserRequest) GetUsernameConfig() *UsernameConfig {
if m != nil {
return m.UsernameConfig
}
return nil
}
func (m *CreateUserRequest) GetExpiration() *google_protobuf.Timestamp {
if m != nil {
return m.Expiration
}
return nil
}
type RenewUserRequest struct {
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"`
Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"`
}
func (m *RenewUserRequest) Reset() { *m = RenewUserRequest{} }
func (m *RenewUserRequest) String() string { return proto.CompactTextString(m) }
func (*RenewUserRequest) ProtoMessage() {}
func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
func (m *RenewUserRequest) GetStatements() *Statements {
if m != nil {
return m.Statements
}
return nil
}
func (m *RenewUserRequest) GetUsername() string {
if m != nil {
return m.Username
}
return ""
}
func (m *RenewUserRequest) GetExpiration() *google_protobuf.Timestamp {
if m != nil {
return m.Expiration
}
return nil
}
type RevokeUserRequest struct {
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"`
}
func (m *RevokeUserRequest) Reset() { *m = RevokeUserRequest{} }
func (m *RevokeUserRequest) String() string { return proto.CompactTextString(m) }
func (*RevokeUserRequest) ProtoMessage() {}
func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
func (m *RevokeUserRequest) GetStatements() *Statements {
if m != nil {
return m.Statements
}
return nil
}
func (m *RevokeUserRequest) GetUsername() string {
if m != nil {
return m.Username
}
return ""
}
type Statements struct {
CreationStatements string `protobuf:"bytes,1,opt,name=creation_statements,json=creationStatements" json:"creation_statements,omitempty"`
RevocationStatements string `protobuf:"bytes,2,opt,name=revocation_statements,json=revocationStatements" json:"revocation_statements,omitempty"`
RollbackStatements string `protobuf:"bytes,3,opt,name=rollback_statements,json=rollbackStatements" json:"rollback_statements,omitempty"`
RenewStatements string `protobuf:"bytes,4,opt,name=renew_statements,json=renewStatements" json:"renew_statements,omitempty"`
}
func (m *Statements) Reset() { *m = Statements{} }
func (m *Statements) String() string { return proto.CompactTextString(m) }
func (*Statements) ProtoMessage() {}
func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }
func (m *Statements) GetCreationStatements() string {
if m != nil {
return m.CreationStatements
}
return ""
}
func (m *Statements) GetRevocationStatements() string {
if m != nil {
return m.RevocationStatements
}
return ""
}
func (m *Statements) GetRollbackStatements() string {
if m != nil {
return m.RollbackStatements
}
return ""
}
func (m *Statements) GetRenewStatements() string {
if m != nil {
return m.RenewStatements
}
return ""
}
type UsernameConfig struct {
DisplayName string `protobuf:"bytes,1,opt,name=DisplayName" json:"DisplayName,omitempty"`
RoleName string `protobuf:"bytes,2,opt,name=RoleName" json:"RoleName,omitempty"`
}
func (m *UsernameConfig) Reset() { *m = UsernameConfig{} }
func (m *UsernameConfig) String() string { return proto.CompactTextString(m) }
func (*UsernameConfig) ProtoMessage() {}
func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} }
func (m *UsernameConfig) GetDisplayName() string {
if m != nil {
return m.DisplayName
}
return ""
}
func (m *UsernameConfig) GetRoleName() string {
if m != nil {
return m.RoleName
}
return ""
}
type CreateUserResponse struct {
Username string `protobuf:"bytes,1,opt,name=username" json:"username,omitempty"`
Password string `protobuf:"bytes,2,opt,name=password" json:"password,omitempty"`
}
func (m *CreateUserResponse) Reset() { *m = CreateUserResponse{} }
func (m *CreateUserResponse) String() string { return proto.CompactTextString(m) }
func (*CreateUserResponse) ProtoMessage() {}
func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} }
func (m *CreateUserResponse) GetUsername() string {
if m != nil {
return m.Username
}
return ""
}
func (m *CreateUserResponse) GetPassword() string {
if m != nil {
return m.Password
}
return ""
}
type TypeResponse struct {
Type string `protobuf:"bytes,1,opt,name=type" json:"type,omitempty"`
}
func (m *TypeResponse) Reset() { *m = TypeResponse{} }
func (m *TypeResponse) String() string { return proto.CompactTextString(m) }
func (*TypeResponse) ProtoMessage() {}
func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} }
func (m *TypeResponse) GetType() string {
if m != nil {
return m.Type
}
return ""
}
type Empty struct {
}
func (m *Empty) Reset() { *m = Empty{} }
func (m *Empty) String() string { return proto.CompactTextString(m) }
func (*Empty) ProtoMessage() {}
func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} }
func init() {
proto.RegisterType((*InitializeRequest)(nil), "dbplugin.InitializeRequest")
proto.RegisterType((*CreateUserRequest)(nil), "dbplugin.CreateUserRequest")
proto.RegisterType((*RenewUserRequest)(nil), "dbplugin.RenewUserRequest")
proto.RegisterType((*RevokeUserRequest)(nil), "dbplugin.RevokeUserRequest")
proto.RegisterType((*Statements)(nil), "dbplugin.Statements")
proto.RegisterType((*UsernameConfig)(nil), "dbplugin.UsernameConfig")
proto.RegisterType((*CreateUserResponse)(nil), "dbplugin.CreateUserResponse")
proto.RegisterType((*TypeResponse)(nil), "dbplugin.TypeResponse")
proto.RegisterType((*Empty)(nil), "dbplugin.Empty")
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// Client API for Database service
type DatabaseClient interface {
Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error)
CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error)
RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error)
RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error)
Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error)
Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error)
}
type databaseClient struct {
cc *grpc.ClientConn
}
func NewDatabaseClient(cc *grpc.ClientConn) DatabaseClient {
return &databaseClient{cc}
}
func (c *databaseClient) Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error) {
out := new(TypeResponse)
err := grpc.Invoke(ctx, "/dbplugin.Database/Type", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) {
out := new(CreateUserResponse)
err := grpc.Invoke(ctx, "/dbplugin.Database/CreateUser", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := grpc.Invoke(ctx, "/dbplugin.Database/RenewUser", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := grpc.Invoke(ctx, "/dbplugin.Database/RevokeUser", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := grpc.Invoke(ctx, "/dbplugin.Database/Close", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// Server API for Database service
type DatabaseServer interface {
Type(context.Context, *Empty) (*TypeResponse, error)
CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error)
RenewUser(context.Context, *RenewUserRequest) (*Empty, error)
RevokeUser(context.Context, *RevokeUserRequest) (*Empty, error)
Initialize(context.Context, *InitializeRequest) (*Empty, error)
Close(context.Context, *Empty) (*Empty, error)
}
func RegisterDatabaseServer(s *grpc.Server, srv DatabaseServer) {
s.RegisterService(&_Database_serviceDesc, srv)
}
func _Database_Type_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).Type(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/Type",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).Type(ctx, req.(*Empty))
}
return interceptor(ctx, in, info, handler)
}
func _Database_CreateUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(CreateUserRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).CreateUser(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/CreateUser",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).CreateUser(ctx, req.(*CreateUserRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Database_RenewUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RenewUserRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).RenewUser(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/RenewUser",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).RenewUser(ctx, req.(*RenewUserRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Database_RevokeUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RevokeUserRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).RevokeUser(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/RevokeUser",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).RevokeUser(ctx, req.(*RevokeUserRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(InitializeRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).Initialize(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/Initialize",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Database_Close_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).Close(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/Close",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).Close(ctx, req.(*Empty))
}
return interceptor(ctx, in, info, handler)
}
var _Database_serviceDesc = grpc.ServiceDesc{
ServiceName: "dbplugin.Database",
HandlerType: (*DatabaseServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Type",
Handler: _Database_Type_Handler,
},
{
MethodName: "CreateUser",
Handler: _Database_CreateUser_Handler,
},
{
MethodName: "RenewUser",
Handler: _Database_RenewUser_Handler,
},
{
MethodName: "RevokeUser",
Handler: _Database_RevokeUser_Handler,
},
{
MethodName: "Initialize",
Handler: _Database_Initialize_Handler,
},
{
MethodName: "Close",
Handler: _Database_Close_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "builtin/logical/database/dbplugin/database.proto",
}
func init() { proto.RegisterFile("builtin/logical/database/dbplugin/database.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 548 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x54, 0xcf, 0x6e, 0xd3, 0x4e,
0x10, 0x96, 0xdb, 0xb4, 0xbf, 0x64, 0x5a, 0x35, 0xc9, 0xfe, 0x4a, 0x15, 0x19, 0x24, 0x22, 0x9f,
0x5a, 0x21, 0xd9, 0xa8, 0xe5, 0x80, 0xb8, 0xa1, 0x14, 0x21, 0x24, 0x94, 0x83, 0x69, 0x25, 0x6e,
0xd1, 0xda, 0x99, 0x44, 0xab, 0x3a, 0xbb, 0xc6, 0xbb, 0x4e, 0x09, 0x4f, 0xc3, 0xe3, 0x70, 0xe2,
0x1d, 0x78, 0x13, 0xe4, 0x75, 0xd6, 0xbb, 0xf9, 0x73, 0xab, 0xb8, 0x79, 0x66, 0xbe, 0x6f, 0xf6,
0xf3, 0xb7, 0x33, 0x0b, 0xaf, 0x93, 0x92, 0x65, 0x8a, 0xf1, 0x28, 0x13, 0x73, 0x96, 0xd2, 0x2c,
0x9a, 0x52, 0x45, 0x13, 0x2a, 0x31, 0x9a, 0x26, 0x79, 0x56, 0xce, 0x19, 0x6f, 0x32, 0x61, 0x5e,
0x08, 0x25, 0x48, 0xdb, 0x14, 0xfc, 0x97, 0x73, 0x21, 0xe6, 0x19, 0x46, 0x3a, 0x9f, 0x94, 0xb3,
0x48, 0xb1, 0x05, 0x4a, 0x45, 0x17, 0x79, 0x0d, 0x0d, 0xbe, 0x42, 0xff, 0x13, 0x67, 0x8a, 0xd1,
0x8c, 0xfd, 0xc0, 0x18, 0xbf, 0x95, 0x28, 0x15, 0xb9, 0x80, 0xe3, 0x54, 0xf0, 0x19, 0x9b, 0x0f,
0xbc, 0xa1, 0x77, 0x79, 0x1a, 0xaf, 0x23, 0xf2, 0x0a, 0xfa, 0x4b, 0x2c, 0xd8, 0x6c, 0x35, 0x49,
0x05, 0xe7, 0x98, 0x2a, 0x26, 0xf8, 0xe0, 0x60, 0xe8, 0x5d, 0xb6, 0xe3, 0x5e, 0x5d, 0x18, 0x35,
0xf9, 0xe0, 0x97, 0x07, 0xfd, 0x51, 0x81, 0x54, 0xe1, 0xbd, 0xc4, 0xc2, 0xb4, 0x7e, 0x03, 0x20,
0x15, 0x55, 0xb8, 0x40, 0xae, 0xa4, 0x6e, 0x7f, 0x72, 0x7d, 0x1e, 0x1a, 0xbd, 0xe1, 0x97, 0xa6,
0x16, 0x3b, 0x38, 0xf2, 0x1e, 0xba, 0xa5, 0xc4, 0x82, 0xd3, 0x05, 0x4e, 0xd6, 0xca, 0x0e, 0x34,
0x75, 0x60, 0xa9, 0xf7, 0x6b, 0xc0, 0x48, 0xd7, 0xe3, 0xb3, 0x72, 0x23, 0x26, 0xef, 0x00, 0xf0,
0x7b, 0xce, 0x0a, 0xaa, 0x45, 0x1f, 0x6a, 0xb6, 0x1f, 0xd6, 0xf6, 0x84, 0xc6, 0x9e, 0xf0, 0xce,
0xd8, 0x13, 0x3b, 0xe8, 0xe0, 0xa7, 0x07, 0xbd, 0x18, 0x39, 0x3e, 0x3e, 0xfd, 0x4f, 0x7c, 0x68,
0x1b, 0x61, 0xfa, 0x17, 0x3a, 0x71, 0x13, 0x3f, 0x49, 0x22, 0x42, 0x3f, 0xc6, 0xa5, 0x78, 0xc0,
0x7f, 0x2a, 0x31, 0xf8, 0xed, 0x01, 0x58, 0x1a, 0x89, 0xe0, 0xff, 0xb4, 0xba, 0x62, 0x26, 0xf8,
0x64, 0xeb, 0xa4, 0x4e, 0x4c, 0x4c, 0xc9, 0x21, 0xdc, 0xc0, 0xb3, 0x02, 0x97, 0x22, 0xdd, 0xa1,
0xd4, 0x07, 0x9d, 0xdb, 0xe2, 0xe6, 0x29, 0x85, 0xc8, 0xb2, 0x84, 0xa6, 0x0f, 0x2e, 0xe5, 0xb0,
0x3e, 0xc5, 0x94, 0x1c, 0xc2, 0x15, 0xf4, 0x8a, 0xea, 0xba, 0x5c, 0x74, 0x4b, 0xa3, 0xbb, 0x3a,
0x6f, 0xa1, 0xc1, 0x18, 0xce, 0x36, 0x07, 0x87, 0x0c, 0xe1, 0xe4, 0x96, 0xc9, 0x3c, 0xa3, 0xab,
0x71, 0xe5, 0x40, 0xfd, 0x2f, 0x6e, 0xaa, 0x32, 0x28, 0x16, 0x19, 0x8e, 0x1d, 0x83, 0x4c, 0x1c,
0x7c, 0x06, 0xe2, 0x0e, 0xbd, 0xcc, 0x05, 0x97, 0xb8, 0x61, 0xa9, 0xb7, 0x75, 0xeb, 0x3e, 0xb4,
0x73, 0x2a, 0xe5, 0xa3, 0x28, 0xa6, 0xa6, 0x9b, 0x89, 0x83, 0x00, 0x4e, 0xef, 0x56, 0x39, 0x36,
0x7d, 0x08, 0xb4, 0xd4, 0x2a, 0x37, 0x3d, 0xf4, 0x77, 0xf0, 0x1f, 0x1c, 0x7d, 0x58, 0xe4, 0x6a,
0x75, 0xfd, 0xe7, 0x00, 0xda, 0xb7, 0xeb, 0x87, 0x80, 0x44, 0xd0, 0xaa, 0x98, 0xa4, 0x6b, 0xaf,
0x5b, 0xa3, 0xfc, 0x0b, 0x9b, 0xd8, 0x68, 0xfd, 0x11, 0xc0, 0x0a, 0x27, 0xcf, 0x2d, 0x6a, 0x67,
0x87, 0xfd, 0x17, 0xfb, 0x8b, 0xeb, 0x46, 0x6f, 0xa1, 0xd3, 0xec, 0x0a, 0xf1, 0x2d, 0x74, 0x7b,
0x81, 0xfc, 0x6d, 0x69, 0xd5, 0xfc, 0xdb, 0x19, 0x76, 0x25, 0xec, 0x4c, 0xf6, 0x5e, 0xae, 0x7d,
0xc7, 0x5c, 0xee, 0xce, 0xeb, 0xb6, 0xcb, 0xbd, 0x82, 0xa3, 0x51, 0x26, 0xe4, 0x1e, 0xb3, 0xb6,
0x13, 0xc9, 0xb1, 0x5e, 0xc3, 0x9b, 0xbf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x8c, 0x55, 0x84, 0x56,
0x94, 0x05, 0x00, 0x00,
}

View File

@@ -0,0 +1,58 @@
syntax = "proto3";
package dbplugin;
import "google/protobuf/timestamp.proto";
message InitializeRequest {
bytes config = 1;
bool verify_connection = 2;
}
message CreateUserRequest {
Statements statements = 1;
UsernameConfig username_config = 2;
google.protobuf.Timestamp expiration = 3;
}
message RenewUserRequest {
Statements statements = 1;
string username = 2;
google.protobuf.Timestamp expiration = 3;
}
message RevokeUserRequest {
Statements statements = 1;
string username = 2;
}
message Statements {
string creation_statements = 1;
string revocation_statements = 2;
string rollback_statements = 3;
string renew_statements = 4;
}
message UsernameConfig {
string DisplayName = 1;
string RoleName = 2;
}
message CreateUserResponse {
string username = 1;
string password = 2;
}
message TypeResponse {
string type = 1;
}
message Empty {}
service Database {
rpc Type(Empty) returns (TypeResponse);
rpc CreateUser(CreateUserRequest) returns (CreateUserResponse);
rpc RenewUser(RenewUserRequest) returns (Empty);
rpc RevokeUser(RevokeUserRequest) returns (Empty);
rpc Initialize(InitializeRequest) returns (Empty);
rpc Close(Empty) returns (Empty);
}

View File

@@ -1,6 +1,7 @@
package dbplugin
import (
"context"
"time"
metrics "github.com/armon/go-metrics"
@@ -15,55 +16,56 @@ type databaseTracingMiddleware struct {
next Database
logger log.Logger
typeStr string
typeStr string
transport string
}
func (mw *databaseTracingMiddleware) Type() (string, error) {
return mw.next.Type()
}
func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (mw *databaseTracingMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
defer func(then time.Time) {
mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr)
return mw.next.CreateUser(statements, usernameConfig, expiration)
mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
}
func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
func (mw *databaseTracingMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
defer func(then time.Time) {
mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr)
return mw.next.RenewUser(statements, username, expiration)
mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr, "transport", mw.transport)
return mw.next.RenewUser(ctx, statements, username, expiration)
}
func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) {
func (mw *databaseTracingMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
defer func(then time.Time) {
mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr)
return mw.next.RevokeUser(statements, username)
mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.RevokeUser(ctx, statements, username)
}
func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) {
func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
defer func(then time.Time) {
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then))
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "verify", verifyConnection, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr)
return mw.next.Initialize(conf, verifyConnection)
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.Initialize(ctx, conf, verifyConnection)
}
func (mw *databaseTracingMiddleware) Close() (err error) {
defer func(then time.Time) {
mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr)
mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.Close()
}
@@ -81,7 +83,7 @@ func (mw *databaseMetricsMiddleware) Type() (string, error) {
return mw.next.Type()
}
func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (mw *databaseMetricsMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "CreateUser"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now)
@@ -94,10 +96,10 @@ func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameC
metrics.IncrCounter([]string{"database", "CreateUser"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1)
return mw.next.CreateUser(statements, usernameConfig, expiration)
return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
}
func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
func (mw *databaseMetricsMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "RenewUser"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now)
@@ -110,10 +112,10 @@ func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username s
metrics.IncrCounter([]string{"database", "RenewUser"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1)
return mw.next.RenewUser(statements, username, expiration)
return mw.next.RenewUser(ctx, statements, username, expiration)
}
func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username string) (err error) {
func (mw *databaseMetricsMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "RevokeUser"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now)
@@ -126,10 +128,10 @@ func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username
metrics.IncrCounter([]string{"database", "RevokeUser"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1)
return mw.next.RevokeUser(statements, username)
return mw.next.RevokeUser(ctx, statements, username)
}
func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) {
func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "Initialize"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now)
@@ -142,7 +144,7 @@ func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, ver
metrics.IncrCounter([]string{"database", "Initialize"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)
return mw.next.Initialize(conf, verifyConnection)
return mw.next.Initialize(ctx, conf, verifyConnection)
}
func (mw *databaseMetricsMiddleware) Close() (err error) {

View File

@@ -0,0 +1,198 @@
package dbplugin
import (
"context"
"encoding/json"
"errors"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"github.com/golang/protobuf/ptypes"
)
var (
ErrPluginShutdown = errors.New("plugin shutdown")
)
// ---- gRPC Server domain ----
type gRPCServer struct {
impl Database
}
func (s *gRPCServer) Type(context.Context, *Empty) (*TypeResponse, error) {
t, err := s.impl.Type()
if err != nil {
return nil, err
}
return &TypeResponse{
Type: t,
}, nil
}
func (s *gRPCServer) CreateUser(ctx context.Context, req *CreateUserRequest) (*CreateUserResponse, error) {
e, err := ptypes.Timestamp(req.Expiration)
if err != nil {
return nil, err
}
u, p, err := s.impl.CreateUser(ctx, *req.Statements, *req.UsernameConfig, e)
return &CreateUserResponse{
Username: u,
Password: p,
}, err
}
func (s *gRPCServer) RenewUser(ctx context.Context, req *RenewUserRequest) (*Empty, error) {
e, err := ptypes.Timestamp(req.Expiration)
if err != nil {
return nil, err
}
err = s.impl.RenewUser(ctx, *req.Statements, req.Username, e)
return &Empty{}, err
}
func (s *gRPCServer) RevokeUser(ctx context.Context, req *RevokeUserRequest) (*Empty, error) {
err := s.impl.RevokeUser(ctx, *req.Statements, req.Username)
return &Empty{}, err
}
func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) {
config := map[string]interface{}{}
err := json.Unmarshal(req.Config, &config)
if err != nil {
return nil, err
}
err = s.impl.Initialize(ctx, config, req.VerifyConnection)
return &Empty{}, err
}
func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) {
s.impl.Close()
return &Empty{}, nil
}
// ---- gRPC client domain ----
type gRPCClient struct {
client DatabaseClient
clientConn *grpc.ClientConn
}
func (c gRPCClient) Type() (string, error) {
// If the plugin has already shutdown, this will hang forever so we give it
// a one second timeout.
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return "", ErrPluginShutdown
}
resp, err := c.client.Type(ctx, &Empty{})
if err != nil {
return "", err
}
return resp.Type, err
}
func (c gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
t, err := ptypes.TimestampProto(expiration)
if err != nil {
return "", "", err
}
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return "", "", ErrPluginShutdown
}
resp, err := c.client.CreateUser(ctx, &CreateUserRequest{
Statements: &statements,
UsernameConfig: &usernameConfig,
Expiration: t,
})
if err != nil {
return "", "", err
}
return resp.Username, resp.Password, err
}
func (c *gRPCClient) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error {
t, err := ptypes.TimestampProto(expiration)
if err != nil {
return err
}
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return ErrPluginShutdown
}
_, err = c.client.RenewUser(ctx, &RenewUserRequest{
Statements: &statements,
Username: username,
Expiration: t,
})
return err
}
func (c *gRPCClient) RevokeUser(ctx context.Context, statements Statements, username string) error {
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return ErrPluginShutdown
}
_, err := c.client.RevokeUser(ctx, &RevokeUserRequest{
Statements: &statements,
Username: username,
})
return err
}
func (c *gRPCClient) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error {
configRaw, err := json.Marshal(config)
if err != nil {
return err
}
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return ErrPluginShutdown
}
_, err = c.client.Initialize(ctx, &InitializeRequest{
Config: configRaw,
VerifyConnection: verifyConnection,
})
return err
}
func (c *gRPCClient) Close() error {
// If the plugin has already shutdown, this will hang forever so we give it
// a one second timeout.
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
_, err := c.client.Close(ctx, &Empty{})
return err
}
return nil
}

View File

@@ -0,0 +1,139 @@
package dbplugin
import (
"context"
"fmt"
"net/rpc"
"time"
)
// ---- RPC server domain ----
// databasePluginRPCServer implements an RPC version of Database and is run
// inside a plugin. It wraps an underlying implementation of Database.
type databasePluginRPCServer struct {
impl Database
}
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
var err error
*resp, err = ds.impl.Type()
return err
}
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequestRPC, resp *CreateUserResponse) error {
var err error
resp.Username, resp.Password, err = ds.impl.CreateUser(context.Background(), args.Statements, args.UsernameConfig, args.Expiration)
return err
}
func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequestRPC, _ *struct{}) error {
err := ds.impl.RenewUser(context.Background(), args.Statements, args.Username, args.Expiration)
return err
}
func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequestRPC, _ *struct{}) error {
err := ds.impl.RevokeUser(context.Background(), args.Statements, args.Username)
return err
}
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error {
err := ds.impl.Initialize(context.Background(), args.Config, args.VerifyConnection)
return err
}
func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
ds.impl.Close()
return nil
}
// ---- RPC client domain ----
// databasePluginRPCClient implements Database and is used on the client to
// make RPC calls to a plugin.
type databasePluginRPCClient struct {
client *rpc.Client
}
func (dr *databasePluginRPCClient) Type() (string, error) {
var dbType string
err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
return fmt.Sprintf("plugin-%s", dbType), err
}
func (dr *databasePluginRPCClient) CreateUser(_ context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
req := CreateUserRequestRPC{
Statements: statements,
UsernameConfig: usernameConfig,
Expiration: expiration,
}
var resp CreateUserResponse
err = dr.client.Call("Plugin.CreateUser", req, &resp)
return resp.Username, resp.Password, err
}
func (dr *databasePluginRPCClient) RenewUser(_ context.Context, statements Statements, username string, expiration time.Time) error {
req := RenewUserRequestRPC{
Statements: statements,
Username: username,
Expiration: expiration,
}
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error {
req := RevokeUserRequestRPC{
Statements: statements,
Username: username,
}
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error {
req := InitializeRequestRPC{
Config: conf,
VerifyConnection: verifyConnection,
}
err := dr.client.Call("Plugin.Initialize", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) Close() error {
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
return err
}
// ---- RPC Request Args Domain ----
type InitializeRequestRPC struct {
Config map[string]interface{}
VerifyConnection bool
}
type CreateUserRequestRPC struct {
Statements Statements
UsernameConfig UsernameConfig
Expiration time.Time
}
type RenewUserRequestRPC struct {
Statements Statements
Username string
Expiration time.Time
}
type RevokeUserRequestRPC struct {
Statements Statements
Username string
}

View File

@@ -1,10 +1,13 @@
package dbplugin
import (
"context"
"fmt"
"net/rpc"
"time"
"google.golang.org/grpc"
"github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/helper/pluginutil"
log "github.com/mgutz/logxi/v1"
@@ -13,29 +16,14 @@ import (
// Database is the interface that all database objects must implement.
type Database interface {
Type() (string, error)
CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error)
RenewUser(statements Statements, username string, expiration time.Time) error
RevokeUser(statements Statements, username string) error
CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error)
RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error
RevokeUser(ctx context.Context, statements Statements, username string) error
Initialize(config map[string]interface{}, verifyConnection bool) error
Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error
Close() error
}
// Statements set in role creation and passed into the database type's functions.
type Statements struct {
CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"`
RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"`
RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"`
RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"`
}
// UsernameConfig is used to configure prefixes for the username to be
// generated.
type UsernameConfig struct {
DisplayName string
RoleName string
}
// PluginFactory is used to build plugin database types. It wraps the database
// object in a logging and metrics middleware.
func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
@@ -45,6 +33,7 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
return nil, err
}
var transport string
var db Database
if pluginRunner.Builtin {
// Plugin is builtin so we can retrieve an instance of the interface
@@ -60,12 +49,24 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
return nil, fmt.Errorf("unsuported database type: %s", pluginName)
}
transport = "builtin"
} else {
// create a DatabasePluginClient instance
db, err = newPluginClient(sys, pluginRunner, logger)
if err != nil {
return nil, err
}
// Switch on the underlying database client type to get the transport
// method.
switch db.(*DatabasePluginClient).Database.(type) {
case *gRPCClient:
transport = "gRPC"
case *databasePluginRPCClient:
transport = "netRPC"
}
}
typeStr, err := db.Type()
@@ -82,9 +83,10 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
// Wrap with tracing middleware
if logger.IsTrace() {
db = &databaseTracingMiddleware{
next: db,
typeStr: typeStr,
logger: logger,
transport: transport,
next: db,
typeStr: typeStr,
logger: logger,
}
}
@@ -115,33 +117,14 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e
return &databasePluginRPCClient{client: c}, nil
}
// ---- RPC Request Args Domain ----
type InitializeRequest struct {
Config map[string]interface{}
VerifyConnection bool
func (d DatabasePlugin) GRPCServer(s *grpc.Server) error {
RegisterDatabaseServer(s, &gRPCServer{impl: d.impl})
return nil
}
type CreateUserRequest struct {
Statements Statements
UsernameConfig UsernameConfig
Expiration time.Time
}
type RenewUserRequest struct {
Statements Statements
Username string
Expiration time.Time
}
type RevokeUserRequest struct {
Statements Statements
Username string
}
// ---- RPC Response Args Domain ----
type CreateUserResponse struct {
Username string
Password string
func (DatabasePlugin) GRPCClient(c *grpc.ClientConn) (interface{}, error) {
return &gRPCClient{
client: NewDatabaseClient(c),
clientConn: c,
}, nil
}

View File

@@ -1,11 +1,13 @@
package dbplugin_test
import (
"context"
"errors"
"os"
"testing"
"time"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/helper/pluginutil"
vaulthttp "github.com/hashicorp/vault/http"
@@ -20,7 +22,7 @@ type mockPlugin struct {
}
func (m *mockPlugin) Type() (string, error) { return "mock", nil }
func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (m *mockPlugin) CreateUser(_ context.Context, statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
err = errors.New("err")
if usernameConf.DisplayName == "" || expiration.IsZero() {
return "", "", err
@@ -34,7 +36,7 @@ func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbp
return usernameConf.DisplayName, "test", nil
}
func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
func (m *mockPlugin) RenewUser(_ context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
err := errors.New("err")
if username == "" || expiration.IsZero() {
return err
@@ -46,7 +48,7 @@ func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string,
return nil
}
func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) error {
func (m *mockPlugin) RevokeUser(_ context.Context, statements dbplugin.Statements, username string) error {
err := errors.New("err")
if username == "" {
return err
@@ -59,7 +61,7 @@ func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string)
delete(m.users, username)
return nil
}
func (m *mockPlugin) Initialize(conf map[string]interface{}, _ bool) error {
func (m *mockPlugin) Initialize(_ context.Context, conf map[string]interface{}, _ bool) error {
err := errors.New("err")
if len(conf) != 1 {
return err
@@ -80,14 +82,15 @@ func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) {
cores := cluster.Cores
sys := vault.TestDynamicSystemView(cores[0].Core)
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_Main")
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_GRPC_Main")
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin-netRPC", "TestPlugin_NetRPC_Main")
return cluster, sys
}
// This is not an actual test case, it's a helper function that will be executed
// by the go-plugin client via an exec call.
func TestPlugin_Main(t *testing.T) {
func TestPlugin_GRPC_Main(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
return
}
@@ -105,6 +108,30 @@ func TestPlugin_Main(t *testing.T) {
plugins.Serve(plugin, apiClientMeta.GetTLSConfig())
}
// This is not an actual test case, it's a helper function that will be executed
// by the go-plugin client via an exec call.
func TestPlugin_NetRPC_Main(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
return
}
p := &mockPlugin{
users: make(map[string][]string),
}
args := []string{"--tls-skip-verify=true"}
apiClientMeta := &pluginutil.APIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(args)
tlsProvider := pluginutil.VaultPluginTLSProvider(apiClientMeta.GetTLSConfig())
serveConf := dbplugin.ServeConfig(p, tlsProvider)
serveConf.GRPCServer = nil
plugin.Serve(serveConf)
}
func TestPlugin_Initialize(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
@@ -118,7 +145,7 @@ func TestPlugin_Initialize(t *testing.T) {
"test": 1,
}
err = dbRaw.Initialize(connectionDetails, true)
err = dbRaw.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -143,7 +170,7 @@ func TestPlugin_CreateUser(t *testing.T) {
"test": 1,
}
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -153,7 +180,7 @@ func TestPlugin_CreateUser(t *testing.T) {
RoleName: "test",
}
us, pw, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -163,7 +190,7 @@ func TestPlugin_CreateUser(t *testing.T) {
// try and save the same user again to verify it saved the first time, this
// should return an error
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err == nil {
t.Fatal("expected an error, user wasn't created correctly")
}
@@ -182,7 +209,7 @@ func TestPlugin_RenewUser(t *testing.T) {
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -192,12 +219,12 @@ func TestPlugin_RenewUser(t *testing.T) {
RoleName: "test",
}
us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.RenewUser(dbplugin.Statements{}, us, time.Now().Add(time.Minute))
err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -216,7 +243,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -226,19 +253,159 @@ func TestPlugin_RevokeUser(t *testing.T) {
RoleName: "test",
}
us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
// Test default revoke statememts
err = db.RevokeUser(dbplugin.Statements{}, us)
err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us)
if err != nil {
t.Fatalf("err: %s", err)
}
// Try adding the same username back so we can verify it was removed
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
}
// Test the code is still compatible with an old netRPC plugin
func TestPlugin_NetRPC_Initialize(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
dbRaw, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
connectionDetails := map[string]interface{}{
"test": 1,
}
err = dbRaw.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
err = dbRaw.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestPlugin_NetRPC_CreateUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
defer db.Close()
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
usernameConf := dbplugin.UsernameConfig{
DisplayName: "test",
RoleName: "test",
}
us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
if us != "test" || pw != "test" {
t.Fatal("expected username and password to be 'test'")
}
// try and save the same user again to verify it saved the first time, this
// should return an error
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err == nil {
t.Fatal("expected an error, user wasn't created correctly")
}
}
func TestPlugin_NetRPC_RenewUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
defer db.Close()
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
usernameConf := dbplugin.UsernameConfig{
DisplayName: "test",
RoleName: "test",
}
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestPlugin_NetRPC_RevokeUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
defer db.Close()
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
usernameConf := dbplugin.UsernameConfig{
DisplayName: "test",
RoleName: "test",
}
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
// Test default revoke statememts
err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us)
if err != nil {
t.Fatalf("err: %s", err)
}
// Try adding the same username back so we can verify it was removed
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@@ -10,6 +10,10 @@ import (
// Database implementation in a databasePluginRPCServer object and starts a
// RPC server.
func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
plugin.Serve(ServeConfig(db, tlsProvider))
}
func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig {
dbPlugin := &DatabasePlugin{
impl: db,
}
@@ -19,53 +23,10 @@ func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
"database": dbPlugin,
}
plugin.Serve(&plugin.ServeConfig{
return &plugin.ServeConfig{
HandshakeConfig: handshakeConfig,
Plugins: pluginMap,
TLSProvider: tlsProvider,
})
}
// ---- RPC server domain ----
// databasePluginRPCServer implements an RPC version of Database and is run
// inside a plugin. It wraps an underlying implementation of Database.
type databasePluginRPCServer struct {
impl Database
}
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
var err error
*resp, err = ds.impl.Type()
return err
}
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error {
var err error
resp.Username, resp.Password, err = ds.impl.CreateUser(args.Statements, args.UsernameConfig, args.Expiration)
return err
}
func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequest, _ *struct{}) error {
err := ds.impl.RenewUser(args.Statements, args.Username, args.Expiration)
return err
}
func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct{}) error {
err := ds.impl.RevokeUser(args.Statements, args.Username)
return err
}
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequest, _ *struct{}) error {
err := ds.impl.Initialize(args.Config, args.VerifyConnection)
return err
}
func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
ds.impl.Close()
return nil
GRPCServer: plugin.DefaultGRPCServer,
}
}

View File

@@ -1,6 +1,7 @@
package database
import (
"context"
"errors"
"fmt"
@@ -62,7 +63,7 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc {
b.clearConnection(name)
// Execute plugin again, we don't need the object so throw away.
_, err := b.createDBObj(req.Storage, name)
_, err := b.createDBObj(context.TODO(), req.Storage, name)
if err != nil {
return nil, err
}
@@ -230,7 +231,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
}
err = db.Initialize(config.ConnectionDetails, verifyConnection)
err = db.Initialize(context.TODO(), config.ConnectionDetails, verifyConnection)
if err != nil {
db.Close()
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil

View File

@@ -1,6 +1,7 @@
package database
import (
"context"
"fmt"
"time"
@@ -49,7 +50,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
// If role name isn't in the database's allowed roles, send back a
// permission denied.
if !strutil.StrListContains(dbConfig.AllowedRoles, "*") && !strutil.StrListContains(dbConfig.AllowedRoles, name) {
if !strutil.StrListContains(dbConfig.AllowedRoles, "*") && !strutil.StrListContainsGlob(dbConfig.AllowedRoles, name) {
return nil, logical.ErrPermissionDenied
}
@@ -66,7 +67,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
unlockFunc = b.Unlock
// Create a new DB object
db, err = b.createDBObj(req.Storage, role.DBName)
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
if err != nil {
unlockFunc()
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
@@ -81,7 +82,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
}
// Create the user
username, password, err := db.CreateUser(role.Statements, usernameConfig, expiration)
username, password, err := db.CreateUser(context.TODO(), role.Statements, usernameConfig, expiration)
// Unlock
unlockFunc()
if err != nil {

View File

@@ -181,7 +181,7 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc {
type roleEntry struct {
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
Statements dbplugin.Statements `json:"statments" mapstructure:"statements" structs:"statments"`
Statements dbplugin.Statements `json:"statements" mapstructure:"statements" structs:"statements"`
DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"`
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"`
}

View File

@@ -1,6 +1,7 @@
package database
import (
"context"
"fmt"
"github.com/hashicorp/vault/logical"
@@ -60,7 +61,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
unlockFunc = b.Unlock
// Create a new DB object
db, err = b.createDBObj(req.Storage, role.DBName)
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
if err != nil {
unlockFunc()
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
@@ -69,7 +70,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
// Make sure we increase the VALID UNTIL endpoint for this user.
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
err := db.RenewUser(role.Statements, username, expireTime)
err := db.RenewUser(context.TODO(), role.Statements, username, expireTime)
// Unlock
unlockFunc()
if err != nil {
@@ -119,14 +120,14 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc {
unlockFunc = b.Unlock
// Create a new DB object
db, err = b.createDBObj(req.Storage, role.DBName)
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
if err != nil {
unlockFunc()
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
}
}
err = db.RevokeUser(role.Statements, username)
err = db.RevokeUser(context.TODO(), role.Statements, username)
// Unlock
unlockFunc()
if err != nil {

View File

@@ -24,6 +24,12 @@ func Backend() *framework.Backend {
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
PathsSpecial: &logical.Paths{
SealWrapStorage: []string{
"config/connection",
},
},
Paths: []*framework.Path{
pathConfigConnection(&b),
pathConfigLease(&b),

View File

@@ -24,6 +24,12 @@ func Backend() *backend {
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
PathsSpecial: &logical.Paths{
SealWrapStorage: []string{
"config/connection",
},
},
Paths: []*framework.Path{
pathConfigConnection(&b),
pathConfigLease(&b),

View File

@@ -24,6 +24,12 @@ func Backend() *backend {
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
PathsSpecial: &logical.Paths{
SealWrapStorage: []string{
"config/connection",
},
},
Paths: []*framework.Path{
pathConfigConnection(&b),
pathConfigLease(&b),

View File

@@ -0,0 +1,68 @@
package nomad
import (
"github.com/hashicorp/nomad/api"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
return nil, err
}
return b, nil
}
func Backend() *backend {
var b backend
b.Backend = &framework.Backend{
PathsSpecial: &logical.Paths{
SealWrapStorage: []string{
"config/access",
},
},
Paths: []*framework.Path{
pathConfigAccess(&b),
pathConfigLease(&b),
pathListRoles(&b),
pathRoles(&b),
pathCredsCreate(&b),
},
Secrets: []*framework.Secret{
secretToken(&b),
},
BackendType: logical.TypeLogical,
}
return &b
}
type backend struct {
*framework.Backend
}
func (b *backend) client(s logical.Storage) (*api.Client, error) {
conf, err := b.readConfigAccess(s)
if err != nil {
return nil, err
}
nomadConf := api.DefaultConfig()
if conf != nil {
if conf.Address != "" {
nomadConf.Address = conf.Address
}
if conf.Token != "" {
nomadConf.SecretID = conf.Token
}
}
client, err := api.NewClient(nomadConf)
if err != nil {
return nil, err
}
return client, nil
}

View File

@@ -0,0 +1,302 @@
package nomad
import (
"fmt"
"os"
"reflect"
"testing"
"time"
nomadapi "github.com/hashicorp/nomad/api"
"github.com/hashicorp/vault/logical"
"github.com/mitchellh/mapstructure"
dockertest "gopkg.in/ory-am/dockertest.v3"
)
func prepareTestContainer(t *testing.T) (cleanup func(), retAddress string, nomadToken string) {
nomadToken = os.Getenv("NOMAD_TOKEN")
retAddress = os.Getenv("NOMAD_ADDR")
if retAddress != "" {
return func() {}, retAddress, nomadToken
}
pool, err := dockertest.NewPool("")
if err != nil {
t.Fatalf("Failed to connect to docker: %s", err)
}
dockerOptions := &dockertest.RunOptions{
Repository: "djenriquez/nomad",
Tag: "latest",
Cmd: []string{"agent", "-dev"},
Env: []string{`NOMAD_LOCAL_CONFIG=bind_addr = "0.0.0.0" acl { enabled = true }`},
}
resource, err := pool.RunWithOptions(dockerOptions)
if err != nil {
t.Fatalf("Could not start local Nomad docker container: %s", err)
}
cleanup = func() {
err := pool.Purge(resource)
if err != nil {
t.Fatalf("Failed to cleanup local container: %s", err)
}
}
retAddress = fmt.Sprintf("http://localhost:%s/", resource.GetPort("4646/tcp"))
// Give Nomad time to initialize
time.Sleep(5000 * time.Millisecond)
// exponential backoff-retry
if err = pool.Retry(func() error {
var err error
nomadapiConfig := nomadapi.DefaultConfig()
nomadapiConfig.Address = retAddress
nomad, err := nomadapi.NewClient(nomadapiConfig)
if err != nil {
return err
}
aclbootstrap, _, err := nomad.ACLTokens().Bootstrap(nil)
if err != nil {
t.Fatalf("err: %v", err)
}
nomadToken = aclbootstrap.SecretID
t.Logf("[WARN] Generated Master token: %s", nomadToken)
policy := &nomadapi.ACLPolicy{
Name: "test",
Description: "test",
Rules: `namespace "default" {
policy = "read"
}
`,
}
anonPolicy := &nomadapi.ACLPolicy{
Name: "anonymous",
Description: "Deny all access for anonymous requests",
Rules: `namespace "default" {
policy = "deny"
}
agent {
policy = "deny"
}
node {
policy = "deny"
}
`,
}
nomadAuthConfig := nomadapi.DefaultConfig()
nomadAuthConfig.Address = retAddress
nomadAuthConfig.SecretID = nomadToken
nomadAuth, err := nomadapi.NewClient(nomadAuthConfig)
_, err = nomadAuth.ACLPolicies().Upsert(policy, nil)
if err != nil {
t.Fatal(err)
}
_, err = nomadAuth.ACLPolicies().Upsert(anonPolicy, nil)
if err != nil {
t.Fatal(err)
}
return err
}); err != nil {
cleanup()
t.Fatalf("Could not connect to docker: %s", err)
}
return cleanup, retAddress, nomadToken
}
func TestBackend_config_access(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
if err != nil {
t.Fatal(err)
}
cleanup, connURL, connToken := prepareTestContainer(t)
defer cleanup()
connData := map[string]interface{}{
"address": connURL,
"token": connToken,
}
confReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/access",
Storage: config.StorageView,
Data: connData,
}
resp, err := b.HandleRequest(confReq)
if err != nil || (resp != nil && resp.IsError()) || resp != nil {
t.Fatalf("failed to write configuration: resp:%#v err:%s", resp, err)
}
confReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(confReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("failed to write configuration: resp:%#v err:%s", resp, err)
}
expected := map[string]interface{}{
"address": connData["address"].(string),
}
if !reflect.DeepEqual(expected, resp.Data) {
t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data)
}
if resp.Data["token"] != nil {
t.Fatalf("token should not be set in the response")
}
}
func TestBackend_renew_revoke(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
if err != nil {
t.Fatal(err)
}
cleanup, connURL, connToken := prepareTestContainer(t)
defer cleanup()
connData := map[string]interface{}{
"address": connURL,
"token": connToken,
}
req := &logical.Request{
Storage: config.StorageView,
Operation: logical.UpdateOperation,
Path: "config/access",
Data: connData,
}
resp, err := b.HandleRequest(req)
if err != nil {
t.Fatal(err)
}
req.Path = "role/test"
req.Data = map[string]interface{}{
"policies": []string{"policy"},
"lease": "6h",
}
resp, err = b.HandleRequest(req)
if err != nil {
t.Fatal(err)
}
req.Operation = logical.ReadOperation
req.Path = "creds/test"
resp, err = b.HandleRequest(req)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("resp nil")
}
if resp.IsError() {
t.Fatalf("resp is error: %v", resp.Error())
}
generatedSecret := resp.Secret
generatedSecret.IssueTime = time.Now()
generatedSecret.TTL = 6 * time.Hour
var d struct {
Token string `mapstructure:"secret_id"`
Accessor string `mapstructure:"accessor_id"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
t.Fatal(err)
}
t.Logf("[WARN] Generated token: %s with accesor %s", d.Token, d.Accessor)
// Build a client and verify that the credentials work
nomadapiConfig := nomadapi.DefaultConfig()
nomadapiConfig.Address = connData["address"].(string)
nomadapiConfig.SecretID = d.Token
client, err := nomadapi.NewClient(nomadapiConfig)
if err != nil {
t.Fatal(err)
}
t.Log("[WARN] Verifying that the generated token works...")
_, err = client.Agent().Members, nil
if err != nil {
t.Fatal(err)
}
req.Operation = logical.RenewOperation
req.Secret = generatedSecret
resp, err = b.HandleRequest(req)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("got nil response from renew")
}
req.Operation = logical.RevokeOperation
resp, err = b.HandleRequest(req)
if err != nil {
t.Fatal(err)
}
// Build a management client and verify that the token does not exist anymore
nomadmgmtConfig := nomadapi.DefaultConfig()
nomadmgmtConfig.Address = connData["address"].(string)
nomadmgmtConfig.SecretID = connData["token"].(string)
mgmtclient, err := nomadapi.NewClient(nomadmgmtConfig)
q := &nomadapi.QueryOptions{
Namespace: "default",
}
t.Log("[WARN] Verifying that the generated token does not exist...")
_, _, err = mgmtclient.ACLTokens().Info(d.Accessor, q)
if err == nil {
t.Fatal("err: expected error")
}
}
func TestBackend_CredsCreateEnvVar(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
if err != nil {
t.Fatal(err)
}
cleanup, connURL, connToken := prepareTestContainer(t)
defer cleanup()
req := logical.TestRequest(t, logical.UpdateOperation, "role/test")
req.Data = map[string]interface{}{
"policies": []string{"policy"},
"lease": "6h",
}
resp, err := b.HandleRequest(req)
if err != nil {
t.Fatal(err)
}
os.Setenv("NOMAD_TOKEN", connToken)
defer os.Unsetenv("NOMAD_TOKEN")
os.Setenv("NOMAD_ADDR", connURL)
defer os.Unsetenv("NOMAD_ADDR")
req.Operation = logical.ReadOperation
req.Path = "creds/test"
resp, err = b.HandleRequest(req)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("resp nil")
}
if resp.IsError() {
t.Fatalf("resp is error: %v", resp.Error())
}
}

View File

@@ -0,0 +1,121 @@
package nomad
import (
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
const configAccessKey = "config/access"
func pathConfigAccess(b *backend) *framework.Path {
return &framework.Path{
Pattern: "config/access",
Fields: map[string]*framework.FieldSchema{
"address": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Nomad server address",
},
"token": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Token for API calls",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathConfigAccessRead,
logical.CreateOperation: b.pathConfigAccessWrite,
logical.UpdateOperation: b.pathConfigAccessWrite,
logical.DeleteOperation: b.pathConfigAccessDelete,
},
ExistenceCheck: b.configExistenceCheck,
}
}
func (b *backend) configExistenceCheck(req *logical.Request, data *framework.FieldData) (bool, error) {
entry, err := b.readConfigAccess(req.Storage)
if err != nil {
return false, err
}
return entry != nil, nil
}
func (b *backend) readConfigAccess(storage logical.Storage) (*accessConfig, error) {
entry, err := storage.Get(configAccessKey)
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
conf := &accessConfig{}
if err := entry.DecodeJSON(conf); err != nil {
return nil, errwrap.Wrapf("error reading nomad access configuration: {{err}}", err)
}
return conf, nil
}
func (b *backend) pathConfigAccessRead(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
conf, err := b.readConfigAccess(req.Storage)
if err != nil {
return nil, err
}
if conf == nil {
return nil, nil
}
return &logical.Response{
Data: map[string]interface{}{
"address": conf.Address,
},
}, nil
}
func (b *backend) pathConfigAccessWrite(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
conf, err := b.readConfigAccess(req.Storage)
if err != nil {
return nil, err
}
if conf == nil {
conf = &accessConfig{}
}
address, ok := data.GetOk("address")
if ok {
conf.Address = address.(string)
}
token, ok := data.GetOk("token")
if ok {
conf.Token = token.(string)
}
entry, err := logical.StorageEntryJSON("config/access", conf)
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
return nil, err
}
return nil, nil
}
func (b *backend) pathConfigAccessDelete(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
if err := req.Storage.Delete(configAccessKey); err != nil {
return nil, err
}
return nil, nil
}
type accessConfig struct {
Address string `json:"address"`
Token string `json:"token"`
}

View File

@@ -0,0 +1,109 @@
package nomad
import (
"time"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
const leaseConfigKey = "config/lease"
func pathConfigLease(b *backend) *framework.Path {
return &framework.Path{
Pattern: "config/lease",
Fields: map[string]*framework.FieldSchema{
"ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: "Duration before which the issued token needs renewal",
},
"max_ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: `Duration after which the issued token should not be allowed to be renewed`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathLeaseRead,
logical.UpdateOperation: b.pathLeaseUpdate,
logical.DeleteOperation: b.pathLeaseDelete,
},
HelpSynopsis: pathConfigLeaseHelpSyn,
HelpDescription: pathConfigLeaseHelpDesc,
}
}
// Sets the lease configuration parameters
func (b *backend) pathLeaseUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entry, err := logical.StorageEntryJSON("config/lease", &configLease{
TTL: time.Second * time.Duration(d.Get("ttl").(int)),
MaxTTL: time.Second * time.Duration(d.Get("max_ttl").(int)),
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
return nil, err
}
return nil, nil
}
func (b *backend) pathLeaseDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
if err := req.Storage.Delete(leaseConfigKey); err != nil {
return nil, err
}
return nil, nil
}
// Returns the lease configuration parameters
func (b *backend) pathLeaseRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
lease, err := b.LeaseConfig(req.Storage)
if err != nil {
return nil, err
}
if lease == nil {
return nil, nil
}
return &logical.Response{
Data: map[string]interface{}{
"ttl": int64(lease.TTL.Seconds()),
"max_ttl": int64(lease.MaxTTL.Seconds()),
},
}, nil
}
// Lease returns the lease information
func (b *backend) LeaseConfig(s logical.Storage) (*configLease, error) {
entry, err := s.Get(leaseConfigKey)
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
var result configLease
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
return &result, nil
}
// Lease configuration information for the secrets issued by this backend
type configLease struct {
TTL time.Duration `json:"ttl" mapstructure:"ttl"`
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl"`
}
var pathConfigLeaseHelpSyn = "Configure the lease parameters for generated tokens"
var pathConfigLeaseHelpDesc = `
Sets the ttl and max_ttl values for the secrets to be issued by this backend.
Both ttl and max_ttl takes in an integer number of seconds as input as well as
inputs like "1h".
`

View File

@@ -0,0 +1,80 @@
package nomad
import (
"fmt"
"time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/nomad/api"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func pathCredsCreate(b *backend) *framework.Path {
return &framework.Path{
Pattern: "creds/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the role",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathTokenRead,
},
}
}
func (b *backend) pathTokenRead(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
role, err := b.Role(req.Storage, name)
if err != nil {
return nil, errwrap.Wrapf("error retrieving role: {{err}}", err)
}
if role == nil {
return logical.ErrorResponse(fmt.Sprintf("role %q not found", name)), nil
}
// Determine if we have a lease configuration
leaseConfig, err := b.LeaseConfig(req.Storage)
if err != nil {
return nil, err
}
if leaseConfig == nil {
leaseConfig = &configLease{}
}
// Get the nomad client
c, err := b.client(req.Storage)
if err != nil {
return nil, err
}
// Generate a name for the token
tokenName := fmt.Sprintf("vault-%s-%s-%d", name, req.DisplayName, time.Now().UnixNano())
// Create it
token, _, err := c.ACLTokens().Create(&api.ACLToken{
Name: tokenName,
Type: role.TokenType,
Policies: role.Policies,
Global: role.Global,
}, nil)
if err != nil {
return nil, err
}
// Use the helper to create the secret
resp := b.Secret(SecretTokenType).Response(map[string]interface{}{
"secret_id": token.SecretID,
"accessor_id": token.AccessorID,
}, map[string]interface{}{
"accessor_id": token.AccessorID,
})
resp.Secret.TTL = leaseConfig.TTL
return resp, nil
}

View File

@@ -0,0 +1,189 @@
package nomad
import (
"errors"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func pathListRoles(b *backend) *framework.Path {
return &framework.Path{
Pattern: "role/?$",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: b.pathRoleList,
},
}
}
func pathRoles(b *backend) *framework.Path {
return &framework.Path{
Pattern: "role/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the role",
},
"policies": &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: "Comma-separated string or list of policies as previously created in Nomad. Required for 'client' token.",
},
"global": &framework.FieldSchema{
Type: framework.TypeBool,
Description: "Boolean value describing if the token should be global or not. Defaults to false.",
},
"type": &framework.FieldSchema{
Type: framework.TypeString,
Default: "client",
Description: `Which type of token to create: 'client'
or 'management'. If a 'management' token,
the "policies" parameter is not required.
Defaults to 'client'.`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathRolesRead,
logical.CreateOperation: b.pathRolesWrite,
logical.UpdateOperation: b.pathRolesWrite,
logical.DeleteOperation: b.pathRolesDelete,
},
ExistenceCheck: b.rolesExistenceCheck,
}
}
// Establishes dichotomy of request operation between CreateOperation and UpdateOperation.
// Returning 'true' forces an UpdateOperation, CreateOperation otherwise.
func (b *backend) rolesExistenceCheck(req *logical.Request, d *framework.FieldData) (bool, error) {
name := d.Get("name").(string)
entry, err := b.Role(req.Storage, name)
if err != nil {
return false, err
}
return entry != nil, nil
}
func (b *backend) Role(storage logical.Storage, name string) (*roleConfig, error) {
if name == "" {
return nil, errors.New("invalid role name")
}
entry, err := storage.Get("role/" + name)
if err != nil {
return nil, errwrap.Wrapf("error retrieving role: {{err}}", err)
}
if entry == nil {
return nil, nil
}
var result roleConfig
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
return &result, nil
}
func (b *backend) pathRoleList(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entries, err := req.Storage.List("role/")
if err != nil {
return nil, err
}
return logical.ListResponse(entries), nil
}
func (b *backend) pathRolesRead(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
role, err := b.Role(req.Storage, name)
if err != nil {
return nil, err
}
if role == nil {
return nil, nil
}
// Generate the response
resp := &logical.Response{
Data: map[string]interface{}{
"type": role.TokenType,
"global": role.Global,
"policies": role.Policies,
},
}
return resp, nil
}
func (b *backend) pathRolesWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
role, err := b.Role(req.Storage, name)
if err != nil {
return nil, err
}
if role == nil {
role = new(roleConfig)
}
policies, ok := d.GetOk("policies")
if ok {
role.Policies = policies.([]string)
}
role.TokenType = d.Get("type").(string)
switch role.TokenType {
case "client":
if len(role.Policies) == 0 {
return logical.ErrorResponse(
"policies cannot be empty when using client tokens"), nil
}
case "management":
if len(role.Policies) != 0 {
return logical.ErrorResponse(
"policies should be empty when using management tokens"), nil
}
default:
return logical.ErrorResponse(
`type must be "client" or "management"`), nil
}
global, ok := d.GetOk("global")
if ok {
role.Global = global.(bool)
}
entry, err := logical.StorageEntryJSON("role/"+name, role)
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
return nil, err
}
return nil, nil
}
func (b *backend) pathRolesDelete(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
if err := req.Storage.Delete("role/" + name); err != nil {
return nil, err
}
return nil, nil
}
type roleConfig struct {
Policies []string `json:"policies"`
TokenType string `json:"type"`
Global bool `json:"global"`
}

View File

@@ -0,0 +1,68 @@
package nomad
import (
"errors"
"fmt"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
const (
SecretTokenType = "token"
)
func secretToken(b *backend) *framework.Secret {
return &framework.Secret{
Type: SecretTokenType,
Fields: map[string]*framework.FieldSchema{
"token": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Request token",
},
},
Renew: b.secretTokenRenew,
Revoke: b.secretTokenRevoke,
}
}
func (b *backend) secretTokenRenew(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
lease, err := b.LeaseConfig(req.Storage)
if err != nil {
return nil, err
}
if lease == nil {
lease = &configLease{}
}
return framework.LeaseExtend(lease.TTL, lease.MaxTTL, b.System())(req, d)
}
func (b *backend) secretTokenRevoke(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
c, err := b.client(req.Storage)
if err != nil {
return nil, err
}
if c == nil {
return nil, fmt.Errorf("error getting Nomad client")
}
accessorIDRaw, ok := req.Secret.InternalData["accessor_id"]
if !ok {
return nil, fmt.Errorf("accessor_id is missing on the lease")
}
accessorID, ok := accessorIDRaw.(string)
if !ok {
return nil, errors.New("unable to convert accessor_id")
}
_, err = c.ACLTokens().Delete(accessorID, nil)
if err != nil {
return nil, err
}
return nil, nil
}

View File

@@ -1463,7 +1463,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
//t.Logf("test step %d\nrole vals: %#v\n", stepCount, roleVals)
stepCount++
//t.Logf("test step %d\nissue vals: %#v\n", stepCount, issueTestStep)
roleTestStep.Data = structs.New(roleVals).Map()
roleTestStep.Data = roleVals.ToResponseData()
roleTestStep.Data["generate_lease"] = false
ret = append(ret, roleTestStep)
issueTestStep.Data = structs.New(issueVals).Map()
@@ -1594,38 +1594,38 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
roleVals.CodeSigningFlag = false
roleVals.EmailProtectionFlag = false
var usage string
var usage []string
if mathRand.Int()%2 == 1 {
usage = usage + ",DigitalSignature"
usage = append(usage, "DigitalSignature")
}
if mathRand.Int()%2 == 1 {
usage = usage + ",ContentCoMmitment"
usage = append(usage, "ContentCoMmitment")
}
if mathRand.Int()%2 == 1 {
usage = usage + ",KeyEncipherment"
usage = append(usage, "KeyEncipherment")
}
if mathRand.Int()%2 == 1 {
usage = usage + ",DataEncipherment"
usage = append(usage, "DataEncipherment")
}
if mathRand.Int()%2 == 1 {
usage = usage + ",KeyAgreemEnt"
usage = append(usage, "KeyAgreemEnt")
}
if mathRand.Int()%2 == 1 {
usage = usage + ",CertSign"
usage = append(usage, "CertSign")
}
if mathRand.Int()%2 == 1 {
usage = usage + ",CRLSign"
usage = append(usage, "CRLSign")
}
if mathRand.Int()%2 == 1 {
usage = usage + ",EncipherOnly"
usage = append(usage, "EncipherOnly")
}
if mathRand.Int()%2 == 1 {
usage = usage + ",DecipherOnly"
usage = append(usage, "DecipherOnly")
}
roleVals.KeyUsage = usage
parsedKeyUsage := parseKeyUsages(roleVals.KeyUsage)
if parsedKeyUsage == 0 && usage != "" {
if parsedKeyUsage == 0 && len(usage) != 0 {
panic("parsed key usages was zero")
}
parsedKeyUsageUnderTest = parsedKeyUsage
@@ -1759,10 +1759,10 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
commonNames.Localhost = true
addCnTests()
roleVals.AllowedDomains = "foobar.com"
roleVals.AllowedDomains = []string{"foobar.com"}
addCnTests()
roleVals.AllowedDomains = "example.com"
roleVals.AllowedDomains = []string{"example.com"}
roleVals.AllowSubdomains = true
commonNames.SubDomain = true
commonNames.Wildcard = true
@@ -1770,13 +1770,13 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
commonNames.SubSubdomainWildcard = true
addCnTests()
roleVals.AllowedDomains = "foobar.com,example.com"
roleVals.AllowedDomains = []string{"foobar.com", "example.com"}
commonNames.SecondDomain = true
roleVals.AllowBareDomains = true
commonNames.BareDomain = true
addCnTests()
roleVals.AllowedDomains = "foobar.com,*example.com"
roleVals.AllowedDomains = []string{"foobar.com", "*example.com"}
roleVals.AllowGlobDomains = true
commonNames.GlobDomain = true
addCnTests()

View File

@@ -17,14 +17,14 @@ func (b *backend) getGenerationParams(
case "internal":
default:
errorResp = logical.ErrorResponse(
`The "exported" path parameter must be "internal" or "exported"`)
`the "exported" path parameter must be "internal" or "exported"`)
return
}
format = getFormat(data)
if format == "" {
errorResp = logical.ErrorResponse(
`The "format" path parameter must be "pem", "der", or "pem_bundle"`)
`the "format" path parameter must be "pem", "der", "der_pkcs", or "pem_bundle"`)
return
}

View File

@@ -2,6 +2,7 @@ package pki
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
@@ -9,6 +10,7 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/base64"
"encoding/pem"
"fmt"
"net"
@@ -16,6 +18,7 @@ import (
"strings"
"time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/helper/parseutil"
@@ -372,9 +375,9 @@ func validateNames(req *logical.Request, names []string, role *roleEntry) string
}
}
if role.AllowedDomains != "" {
if len(role.AllowedDomains) > 0 {
valid := false
for _, currDomain := range strings.Split(role.AllowedDomains, ",") {
for _, currDomain := range role.AllowedDomains {
// If there is, say, a trailing comma, ignore it
if currDomain == "" {
continue
@@ -1183,3 +1186,66 @@ NameCheck:
return fmt.Errorf("name %q disallowed by CA's permitted DNS domains", badName)
}
func convertRespToPKCS8(resp *logical.Response) error {
privRaw, ok := resp.Data["private_key"]
if !ok {
return nil
}
priv, ok := privRaw.(string)
if !ok {
return fmt.Errorf("error converting response to pkcs8: could not parse original value as string")
}
privKeyTypeRaw, ok := resp.Data["private_key_type"]
if !ok {
return fmt.Errorf("error converting response to pkcs8: %q not found in response", "private_key_type")
}
privKeyType, ok := privKeyTypeRaw.(certutil.PrivateKeyType)
if !ok {
return fmt.Errorf("error converting response to pkcs8: could not parse original type value as string")
}
var keyData []byte
var pemUsed bool
var err error
var signer crypto.Signer
block, _ := pem.Decode([]byte(priv))
if block == nil {
keyData, err = base64.StdEncoding.DecodeString(priv)
if err != nil {
return errwrap.Wrapf("error converting response to pkcs8: error decoding original value: {{err}}", err)
}
} else {
keyData = block.Bytes
pemUsed = true
}
switch privKeyType {
case certutil.RSAPrivateKey:
signer, err = x509.ParsePKCS1PrivateKey(keyData)
case certutil.ECPrivateKey:
signer, err = x509.ParseECPrivateKey(keyData)
default:
return fmt.Errorf("unknown private key type %q", privKeyType)
}
if err != nil {
return errwrap.Wrapf("error converting response to pkcs8: error parsing previous key: {{err}}", err)
}
keyData, err = certutil.MarshalPKCS8PrivateKey(signer)
if err != nil {
return errwrap.Wrapf("error converting response to pkcs8: error marshaling pkcs8 key: {{err}}", err)
}
if pemUsed {
block.Type = "PRIVATE KEY"
block.Bytes = keyData
resp.Data["private_key"] = string(pem.EncodeToMemory(block))
} else {
resp.Data["private_key"] = base64.StdEncoding.EncodeToString(keyData)
}
return nil
}

View File

@@ -22,6 +22,17 @@ key and issuing cert will be appended to the
certificate pem. Defaults to "pem".`,
}
fields["private_key_format"] = &framework.FieldSchema{
Type: framework.TypeString,
Default: "der",
Description: `Format for the returned private key.
Generally the default will be controlled by the "format"
parameter as either base64-encoded DER or PEM-encoded DER.
However, this can be set to "pkcs8" to have the returned
private key contain base64-encoded pkcs8 or PEM-encoded
pkcs8 instead. Defaults to "der".`,
}
fields["ip_sans"] = &framework.FieldSchema{
Type: framework.TypeString,
Description: `The requested IP SANs, if any, in a

View File

@@ -106,6 +106,13 @@ func (b *backend) pathGenerateIntermediate(
}
}
if data.Get("private_key_format").(string) == "pkcs8" {
err = convertRespToPKCS8(resp)
if err != nil {
return nil, err
}
}
cb := &certutil.CertBundle{}
cb.PrivateKey = csrb.PrivateKey
cb.PrivateKeyType = csrb.PrivateKeyType

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/logical"
@@ -163,7 +164,7 @@ func (b *backend) pathIssueSignCert(
format := getFormat(data)
if format == "" {
return logical.ErrorResponse(
`The "format" path parameter must be "pem", "der", or "pem_bundle"`), nil
`the "format" path parameter must be "pem", "der", or "pem_bundle"`), nil
}
var caErr error
@@ -171,10 +172,10 @@ func (b *backend) pathIssueSignCert(
switch caErr.(type) {
case errutil.UserError:
return nil, errutil.UserError{Err: fmt.Sprintf(
"Could not fetch the CA certificate (was one set?): %s", caErr)}
"could not fetch the CA certificate (was one set?): %s", caErr)}
case errutil.InternalError:
return nil, errutil.InternalError{Err: fmt.Sprintf(
"Error fetching CA certificate: %s", caErr)}
"error fetching CA certificate: %s", caErr)}
}
var parsedBundle *certutil.ParsedCertBundle
@@ -195,12 +196,12 @@ func (b *backend) pathIssueSignCert(
signingCB, err := signingBundle.ToCertBundle()
if err != nil {
return nil, fmt.Errorf("Error converting raw signing bundle to cert bundle: %s", err)
return nil, errwrap.Wrapf("error converting raw signing bundle to cert bundle: {{err}}", err)
}
cb, err := parsedBundle.ToCertBundle()
if err != nil {
return nil, fmt.Errorf("Error converting raw cert bundle to cert bundle: %s", err)
return nil, errwrap.Wrapf("error converting raw cert bundle to cert bundle: {{err}}", err)
}
respData := map[string]interface{}{
@@ -267,6 +268,13 @@ func (b *backend) pathIssueSignCert(
resp.Secret.TTL = parsedBundle.Certificate.NotAfter.Sub(time.Now())
}
if data.Get("private_key_format").(string) == "pkcs8" {
err = convertRespToPKCS8(resp)
if err != nil {
return nil, err
}
}
if !role.NoStore {
err = req.Storage.Put(&logical.StorageEntry{
Key: "certs/" + normalizeSerial(cb.SerialNumber),

View File

@@ -6,7 +6,6 @@ import (
"strings"
"time"
"github.com/fatih/structs"
"github.com/hashicorp/vault/helper/parseutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
@@ -57,13 +56,12 @@ name in a request`,
},
"allowed_domains": &framework.FieldSchema{
Type: framework.TypeString,
Default: "",
Type: framework.TypeCommaStringSlice,
Description: `If set, clients can request certificates for
subdomains directly beneath these domains, including
the wildcard subdomains. See the documentation for more
information. This parameter accepts a comma-separated list
of domains.`,
information. This parameter accepts a comma-separated
string or list of domains.`,
},
"allow_bare_domains": &framework.FieldSchema{
@@ -158,14 +156,14 @@ the key_type.`,
},
"key_usage": &framework.FieldSchema{
Type: framework.TypeString,
Default: "DigitalSignature,KeyAgreement,KeyEncipherment",
Description: `A comma-separated set of key usages (not extended
Type: framework.TypeCommaStringSlice,
Default: []string{"DigitalSignature", "KeyAgreement", "KeyEncipherment"},
Description: `A comma-separated string or list of key usages (not extended
key usages). Valid values can be found at
https://golang.org/pkg/crypto/x509/#KeyUsage
-- simply drop the "KeyUsage" part of the name.
To remove all key usages from being set, set
this value to an empty string.`,
this value to an empty list.`,
},
"use_csr_common_name": &framework.FieldSchema{
@@ -217,8 +215,8 @@ leases adversely affect the startup time of Vault.`,
Default: false,
Description: `
If set, certificates issued/signed against this role will not be stored in the
in the storage backend. This can improve performance when issuing large numbers
of certificates. However, certificates issued in this way cannot be enumerated
storage backend. This can improve performance when issuing large numbers of
certificates. However, certificates issued in this way cannot be enumerated
or revoked, so this option is recommended only for certificates that are
non-sensitive, or extremely short-lived. This option implies a value of "false"
for "generate_lease".`,
@@ -267,23 +265,21 @@ func (b *backend) getRole(s logical.Storage, n string) (*roleEntry, error) {
result.AllowBareDomains = true
modified = true
}
if result.AllowedDomainsOld != "" {
result.AllowedDomains = strings.Split(result.AllowedDomainsOld, ",")
result.AllowedDomainsOld = ""
modified = true
}
if result.AllowedBaseDomain != "" {
found := false
allowedDomains := strings.Split(result.AllowedDomains, ",")
if len(allowedDomains) != 0 {
for _, v := range allowedDomains {
if v == result.AllowedBaseDomain {
found = true
break
}
for _, v := range result.AllowedDomains {
if v == result.AllowedBaseDomain {
found = true
break
}
}
if !found {
if result.AllowedDomains == "" {
result.AllowedDomains = result.AllowedBaseDomain
} else {
result.AllowedDomains += "," + result.AllowedBaseDomain
}
result.AllowedDomains = append(result.AllowedDomains, result.AllowedBaseDomain)
}
result.AllowedBaseDomain = ""
modified = true
@@ -299,13 +295,23 @@ func (b *backend) getRole(s logical.Storage, n string) (*roleEntry, error) {
modified = true
}
// Upgrade key usages
if result.KeyUsageOld != "" {
result.KeyUsage = strings.Split(result.KeyUsageOld, ",")
result.KeyUsageOld = ""
modified = true
}
if modified {
jsonEntry, err := logical.StorageEntryJSON("role/"+n, &result)
if err != nil {
return nil, err
}
if err := s.Put(jsonEntry); err != nil {
return nil, err
// Only perform upgrades on replication primary
if !strings.Contains(err.Error(), logical.ErrReadOnly.Error()) {
return nil, err
}
}
}
@@ -351,18 +357,8 @@ func (b *backend) pathRoleRead(
}
resp := &logical.Response{
Data: structs.New(role).Map(),
Data: role.ToResponseData(),
}
if resp.Data == nil {
return nil, fmt.Errorf("error converting role data to response")
}
// These values are deprecated and the entries are migrated on read
delete(resp.Data, "lease")
delete(resp.Data, "lease_max")
delete(resp.Data, "allowed_base_domain")
return resp, nil
}
@@ -385,7 +381,7 @@ func (b *backend) pathRoleCreate(
MaxTTL: data.Get("max_ttl").(string),
TTL: (time.Duration(data.Get("ttl").(int)) * time.Second).String(),
AllowLocalhost: data.Get("allow_localhost").(bool),
AllowedDomains: data.Get("allowed_domains").(string),
AllowedDomains: data.Get("allowed_domains").([]string),
AllowBareDomains: data.Get("allow_bare_domains").(bool),
AllowSubdomains: data.Get("allow_subdomains").(bool),
AllowGlobDomains: data.Get("allow_glob_domains").(bool),
@@ -400,7 +396,7 @@ func (b *backend) pathRoleCreate(
KeyBits: data.Get("key_bits").(int),
UseCSRCommonName: data.Get("use_csr_common_name").(bool),
UseCSRSANs: data.Get("use_csr_sans").(bool),
KeyUsage: data.Get("key_usage").(string),
KeyUsage: data.Get("key_usage").([]string),
OU: data.Get("ou").(string),
Organization: data.Get("organization").(string),
GenerateLease: new(bool),
@@ -473,10 +469,9 @@ func (b *backend) pathRoleCreate(
return nil, nil
}
func parseKeyUsages(input string) int {
func parseKeyUsages(input []string) int {
var parsedKeyUsages x509.KeyUsage
splitKeyUsage := strings.Split(input, ",")
for _, k := range splitKeyUsage {
for _, k := range input {
switch strings.ToLower(strings.TrimSpace(k)) {
case "digitalsignature":
parsedKeyUsages |= x509.KeyUsageDigitalSignature
@@ -503,40 +498,77 @@ func parseKeyUsages(input string) int {
}
type roleEntry struct {
LeaseMax string `json:"lease_max" structs:"lease_max" mapstructure:"lease_max"`
Lease string `json:"lease" structs:"lease" mapstructure:"lease"`
MaxTTL string `json:"max_ttl" structs:"max_ttl" mapstructure:"max_ttl"`
TTL string `json:"ttl" structs:"ttl" mapstructure:"ttl"`
AllowLocalhost bool `json:"allow_localhost" structs:"allow_localhost" mapstructure:"allow_localhost"`
AllowedBaseDomain string `json:"allowed_base_domain" structs:"allowed_base_domain" mapstructure:"allowed_base_domain"`
AllowedDomains string `json:"allowed_domains" structs:"allowed_domains" mapstructure:"allowed_domains"`
AllowBaseDomain bool `json:"allow_base_domain" structs:"allow_base_domain" mapstructure:"allow_base_domain"`
AllowBareDomains bool `json:"allow_bare_domains" structs:"allow_bare_domains" mapstructure:"allow_bare_domains"`
AllowTokenDisplayName bool `json:"allow_token_displayname" structs:"allow_token_displayname" mapstructure:"allow_token_displayname"`
AllowSubdomains bool `json:"allow_subdomains" structs:"allow_subdomains" mapstructure:"allow_subdomains"`
AllowGlobDomains bool `json:"allow_glob_domains" structs:"allow_glob_domains" mapstructure:"allow_glob_domains"`
AllowAnyName bool `json:"allow_any_name" structs:"allow_any_name" mapstructure:"allow_any_name"`
EnforceHostnames bool `json:"enforce_hostnames" structs:"enforce_hostnames" mapstructure:"enforce_hostnames"`
AllowIPSANs bool `json:"allow_ip_sans" structs:"allow_ip_sans" mapstructure:"allow_ip_sans"`
ServerFlag bool `json:"server_flag" structs:"server_flag" mapstructure:"server_flag"`
ClientFlag bool `json:"client_flag" structs:"client_flag" mapstructure:"client_flag"`
CodeSigningFlag bool `json:"code_signing_flag" structs:"code_signing_flag" mapstructure:"code_signing_flag"`
EmailProtectionFlag bool `json:"email_protection_flag" structs:"email_protection_flag" mapstructure:"email_protection_flag"`
UseCSRCommonName bool `json:"use_csr_common_name" structs:"use_csr_common_name" mapstructure:"use_csr_common_name"`
UseCSRSANs bool `json:"use_csr_sans" structs:"use_csr_sans" mapstructure:"use_csr_sans"`
KeyType string `json:"key_type" structs:"key_type" mapstructure:"key_type"`
KeyBits int `json:"key_bits" structs:"key_bits" mapstructure:"key_bits"`
MaxPathLength *int `json:",omitempty" structs:"max_path_length,omitempty" mapstructure:"max_path_length"`
KeyUsage string `json:"key_usage" structs:"key_usage" mapstructure:"key_usage"`
OU string `json:"ou" structs:"ou" mapstructure:"ou"`
Organization string `json:"organization" structs:"organization" mapstructure:"organization"`
GenerateLease *bool `json:"generate_lease,omitempty" structs:"generate_lease,omitempty"`
NoStore bool `json:"no_store" structs:"no_store" mapstructure:"no_store"`
LeaseMax string `json:"lease_max"`
Lease string `json:"lease"`
MaxTTL string `json:"max_ttl" mapstructure:"max_ttl"`
TTL string `json:"ttl" mapstructure:"ttl"`
AllowLocalhost bool `json:"allow_localhost" mapstructure:"allow_localhost"`
AllowedBaseDomain string `json:"allowed_base_domain" mapstructure:"allowed_base_domain"`
AllowedDomainsOld string `json:"allowed_domains,omit_empty"`
AllowedDomains []string `json:"allowed_domains_list" mapstructure:"allowed_domains"`
AllowBaseDomain bool `json:"allow_base_domain"`
AllowBareDomains bool `json:"allow_bare_domains" mapstructure:"allow_bare_domains"`
AllowTokenDisplayName bool `json:"allow_token_displayname" mapstructure:"allow_token_displayname"`
AllowSubdomains bool `json:"allow_subdomains" mapstructure:"allow_subdomains"`
AllowGlobDomains bool `json:"allow_glob_domains" mapstructure:"allow_glob_domains"`
AllowAnyName bool `json:"allow_any_name" mapstructure:"allow_any_name"`
EnforceHostnames bool `json:"enforce_hostnames" mapstructure:"enforce_hostnames"`
AllowIPSANs bool `json:"allow_ip_sans" mapstructure:"allow_ip_sans"`
ServerFlag bool `json:"server_flag" mapstructure:"server_flag"`
ClientFlag bool `json:"client_flag" mapstructure:"client_flag"`
CodeSigningFlag bool `json:"code_signing_flag" mapstructure:"code_signing_flag"`
EmailProtectionFlag bool `json:"email_protection_flag" mapstructure:"email_protection_flag"`
UseCSRCommonName bool `json:"use_csr_common_name" mapstructure:"use_csr_common_name"`
UseCSRSANs bool `json:"use_csr_sans" mapstructure:"use_csr_sans"`
KeyType string `json:"key_type" mapstructure:"key_type"`
KeyBits int `json:"key_bits" mapstructure:"key_bits"`
MaxPathLength *int `json:",omitempty" mapstructure:"max_path_length"`
KeyUsageOld string `json:"key_usage,omitempty"`
KeyUsage []string `json:"key_usage_list" mapstructure:"key_usage"`
OU string `json:"ou" mapstructure:"ou"`
Organization string `json:"organization" mapstructure:"organization"`
GenerateLease *bool `json:"generate_lease,omitempty"`
NoStore bool `json:"no_store" mapstructure:"no_store"`
// Used internally for signing intermediates
AllowExpirationPastCA bool
}
func (r *roleEntry) ToResponseData() map[string]interface{} {
responseData := map[string]interface{}{
"ttl": r.TTL,
"max_ttl": r.MaxTTL,
"allow_localhost": r.AllowLocalhost,
"allowed_domains": r.AllowedDomains,
"allow_bare_domains": r.AllowBareDomains,
"allow_token_displayname": r.AllowTokenDisplayName,
"allow_subdomains": r.AllowSubdomains,
"allow_glob_domains": r.AllowGlobDomains,
"allow_any_name": r.AllowAnyName,
"enforce_hostnames": r.EnforceHostnames,
"allow_ip_sans": r.AllowIPSANs,
"server_flag": r.ServerFlag,
"client_flag": r.ClientFlag,
"code_signing_flag": r.CodeSigningFlag,
"email_protection_flag": r.EmailProtectionFlag,
"use_csr_common_name": r.UseCSRCommonName,
"use_csr_sans": r.UseCSRSANs,
"key_type": r.KeyType,
"key_bits": r.KeyBits,
"key_usage": r.KeyUsage,
"ou": r.OU,
"organization": r.Organization,
"no_store": r.NoStore,
}
if r.MaxPathLength != nil {
responseData["max_path_length"] = r.MaxPathLength
}
if r.GenerateLease != nil {
responseData["generate_lease"] = r.GenerateLease
}
return responseData
}
const pathListRolesHelpSyn = `List the existing roles in this backend`
const pathListRolesHelpDesc = `Roles will be listed by the role name.`

View File

@@ -120,6 +120,181 @@ func TestPki_RoleGenerateLease(t *testing.T) {
}
}
func TestPki_RoleKeyUsage(t *testing.T) {
var resp *logical.Response
var err error
b, storage := createBackendWithStorage(t)
roleData := map[string]interface{}{
"allowed_domains": "myvault.com",
"ttl": "5h",
"key_usage": []string{"KeyEncipherment", "DigitalSignature"},
}
roleReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "roles/testrole",
Storage: storage,
Data: roleData,
}
resp, err = b.HandleRequest(roleReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v resp: %#v", err, resp)
}
roleReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(roleReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v resp: %#v", err, resp)
}
keyUsage := resp.Data["key_usage"].([]string)
if len(keyUsage) != 2 {
t.Fatalf("key_usage should have 2 values")
}
// Check that old key usage value is nil
var role roleEntry
err = mapstructure.Decode(resp.Data, &role)
if err != nil {
t.Fatal(err)
}
if role.KeyUsageOld != "" {
t.Fatalf("old key usage storage value should be blank")
}
// Make it explicit
role.KeyUsageOld = "KeyEncipherment,DigitalSignature"
role.KeyUsage = nil
entry, err := logical.StorageEntryJSON("role/testrole", role)
if err != nil {
t.Fatal(err)
}
if err := storage.Put(entry); err != nil {
t.Fatal(err)
}
// Reading should upgrade key_usage
resp, err = b.HandleRequest(roleReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v resp: %#v", err, resp)
}
keyUsage = resp.Data["key_usage"].([]string)
if len(keyUsage) != 2 {
t.Fatalf("key_usage should have 2 values")
}
// Read back from storage to ensure upgrade
entry, err = storage.Get("role/testrole")
if err != nil {
t.Fatalf("err: %v", err)
}
if entry == nil {
t.Fatalf("role should not be nil")
}
var result roleEntry
if err := entry.DecodeJSON(&result); err != nil {
t.Fatalf("err: %v", err)
}
if result.KeyUsageOld != "" {
t.Fatal("old key usage value should be blank")
}
if len(result.KeyUsage) != 2 {
t.Fatal("key_usage should have 2 values")
}
}
func TestPki_RoleAllowedDomains(t *testing.T) {
var resp *logical.Response
var err error
b, storage := createBackendWithStorage(t)
roleData := map[string]interface{}{
"allowed_domains": []string{"foobar.com", "*example.com"},
"ttl": "5h",
}
roleReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "roles/testrole",
Storage: storage,
Data: roleData,
}
resp, err = b.HandleRequest(roleReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v resp: %#v", err, resp)
}
roleReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(roleReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v resp: %#v", err, resp)
}
allowedDomains := resp.Data["allowed_domains"].([]string)
if len(allowedDomains) != 2 {
t.Fatalf("allowed_domains should have 2 values")
}
// Check that old key usage value is nil
var role roleEntry
err = mapstructure.Decode(resp.Data, &role)
if err != nil {
t.Fatal(err)
}
if role.AllowedDomainsOld != "" {
t.Fatalf("old allowed_domains storage value should be blank")
}
// Make it explicit
role.AllowedDomainsOld = "foobar.com,*example.com"
role.AllowedDomains = nil
entry, err := logical.StorageEntryJSON("role/testrole", role)
if err != nil {
t.Fatal(err)
}
if err := storage.Put(entry); err != nil {
t.Fatal(err)
}
// Reading should upgrade key_usage
resp, err = b.HandleRequest(roleReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v resp: %#v", err, resp)
}
allowedDomains = resp.Data["allowed_domains"].([]string)
if len(allowedDomains) != 2 {
t.Fatalf("allowed_domains should have 2 values")
}
// Read back from storage to ensure upgrade
entry, err = storage.Get("role/testrole")
if err != nil {
t.Fatalf("err: %v", err)
}
if entry == nil {
t.Fatalf("role should not be nil")
}
var result roleEntry
if err := entry.DecodeJSON(&result); err != nil {
t.Fatalf("err: %v", err)
}
if result.AllowedDomainsOld != "" {
t.Fatal("old allowed_domains value should be blank")
}
if len(result.AllowedDomains) != 2 {
t.Fatal("allowed_domains should have 2 values")
}
}
func TestPki_RoleNoStore(t *testing.T) {
var resp *logical.Response
var err error

View File

@@ -149,7 +149,7 @@ func (b *backend) pathCAGenerateRoot(
cb, err := parsedBundle.ToCertBundle()
if err != nil {
return nil, fmt.Errorf("error converting raw cert bundle to cert bundle: %s", err)
return nil, errwrap.Wrapf("error converting raw cert bundle to cert bundle: {{err}}", err)
}
resp := &logical.Response{
@@ -188,6 +188,13 @@ func (b *backend) pathCAGenerateRoot(
}
}
if data.Get("private_key_format").(string) == "pkcs8" {
err = convertRespToPKCS8(resp)
if err != nil {
return nil, err
}
}
// Store it as the CA bundle
entry, err = logical.StorageEntryJSON("config/ca_bundle", cb)
if err != nil {
@@ -205,7 +212,7 @@ func (b *backend) pathCAGenerateRoot(
Value: parsedBundle.CertificateBytes,
})
if err != nil {
return nil, fmt.Errorf("Unable to store certificate locally: %v", err)
return nil, errwrap.Wrapf("unable to store certificate locally: {{err}}", err)
}
// For ease of later use, also store just the certificate at a known

View File

@@ -25,6 +25,12 @@ func Backend(conf *logical.BackendConfig) *backend {
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
PathsSpecial: &logical.Paths{
SealWrapStorage: []string{
"config/connection",
},
},
Paths: []*framework.Path{
pathConfigConnection(&b),
pathConfigLease(&b),

View File

@@ -26,6 +26,12 @@ func Backend() *backend {
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
PathsSpecial: &logical.Paths{
SealWrapStorage: []string{
"config/connection",
},
},
Paths: []*framework.Path{
pathConfigConnection(&b),
pathConfigLease(&b),

View File

@@ -271,6 +271,31 @@ func TestSSHBackend_Lookup(t *testing.T) {
})
}
func TestSSHBackend_RoleList(t *testing.T) {
testOTPRoleData := map[string]interface{}{
"key_type": testOTPKeyType,
"default_user": testUserName,
"cidr_list": testCIDRList,
}
resp1 := map[string]interface{}{}
resp2 := map[string]interface{}{
"keys": []string{testOTPRoleName},
"key_info": map[string]interface{}{
testOTPRoleName: map[string]interface{}{
"key_type": testOTPKeyType,
},
},
}
logicaltest.Test(t, logicaltest.TestCase{
Factory: testingFactory,
Steps: []logicaltest.TestStep{
testRoleList(t, resp1),
testRoleWrite(t, testOTPRoleName, testOTPRoleData),
testRoleList(t, resp2),
},
})
}
func TestSSHBackend_DynamicKeyCreate(t *testing.T) {
testDynamicRoleData := map[string]interface{}{
"key_type": testDynamicKeyType,
@@ -962,6 +987,25 @@ func testRoleWrite(t *testing.T, name string, data map[string]interface{}) logic
}
}
func testRoleList(t *testing.T, expected map[string]interface{}) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ListOperation,
Path: "roles",
Check: func(resp *logical.Response) error {
if resp == nil {
return fmt.Errorf("nil response")
}
if resp.Data == nil {
return fmt.Errorf("nil data")
}
if !reflect.DeepEqual(resp.Data, expected) {
return fmt.Errorf("Invalid response:\nactual:%#v\nexpected is %#v", resp.Data, expected)
}
return nil
},
}
}
func testRoleRead(t *testing.T, roleName string, expected map[string]interface{}) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,

View File

@@ -83,7 +83,7 @@ func pathRoles(b *backend) *framework.Path {
Description: `
[Required for Dynamic type] [Not applicable for OTP type] [Not applicable for CA type]
Admin user at remote host. The shared key being registered should be
for this user and should have root privileges. Everytime a dynamic
for this user and should have root privileges. Everytime a dynamic
credential is being generated for other users, Vault uses this admin
username to login to remote host and install the generated credential
for the other user.`,
@@ -175,7 +175,7 @@ func pathRoles(b *backend) *framework.Path {
`,
},
"ttl": &framework.FieldSchema{
Type: framework.TypeString,
Type: framework.TypeDurationSecond,
Description: `
[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type]
The lease duration if no specific lease duration is
@@ -184,7 +184,7 @@ func pathRoles(b *backend) *framework.Path {
the value of max_ttl.`,
},
"max_ttl": &framework.FieldSchema{
Type: framework.TypeString,
Type: framework.TypeDurationSecond,
Description: `
[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type]
The maximum allowed lease duration
@@ -386,15 +386,15 @@ func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (*
return logical.ErrorResponse("missing admin username"), nil
}
// This defaults to 1024 and it can also be 2048.
// This defaults to 1024 and it can also be 2048 and 4096.
keyBits := d.Get("key_bits").(int)
if keyBits != 0 && keyBits != 1024 && keyBits != 2048 {
if keyBits != 0 && keyBits != 1024 && keyBits != 2048 && keyBits != 4096 {
return logical.ErrorResponse("invalid key_bits field"), nil
}
// If user has not set this field, default it to 1024
// If user has not set this field, default it to 2048
if keyBits == 0 {
keyBits = 1024
keyBits = 2048
}
// Store all the fields required by dynamic key type
@@ -433,9 +433,9 @@ func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (*
}
func (b *backend) createCARole(allowedUsers, defaultUser string, data *framework.FieldData) (*sshRole, *logical.Response) {
ttl := time.Duration(data.Get("ttl").(int)) * time.Second
maxTTL := time.Duration(data.Get("max_ttl").(int)) * time.Second
role := &sshRole{
MaxTTL: data.Get("max_ttl").(string),
TTL: data.Get("ttl").(string),
AllowedCriticalOptions: data.Get("allowed_critical_options").(string),
AllowedExtensions: data.Get("allowed_extensions").(string),
AllowUserCertificates: data.Get("allow_user_certificates").(bool),
@@ -457,44 +457,12 @@ func (b *backend) createCARole(allowedUsers, defaultUser string, data *framework
defaultCriticalOptions := convertMapToStringValue(data.Get("default_critical_options").(map[string]interface{}))
defaultExtensions := convertMapToStringValue(data.Get("default_extensions").(map[string]interface{}))
var maxTTL time.Duration
maxSystemTTL := b.System().MaxLeaseTTL()
if len(role.MaxTTL) == 0 {
maxTTL = maxSystemTTL
} else {
var err error
maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL)
if err != nil {
return nil, logical.ErrorResponse(fmt.Sprintf(
"Invalid max ttl: %s", err))
}
}
if maxTTL > maxSystemTTL {
return nil, logical.ErrorResponse("Requested max TTL is higher than backend maximum")
if ttl != 0 && maxTTL != 0 && ttl > maxTTL {
return nil, logical.ErrorResponse(
`"ttl" value must be less than "max_ttl" when both are specified`)
}
ttl := b.System().DefaultLeaseTTL()
if len(role.TTL) != 0 {
var err error
ttl, err = parseutil.ParseDurationSecond(role.TTL)
if err != nil {
return nil, logical.ErrorResponse(fmt.Sprintf(
"Invalid ttl: %s", err))
}
}
if ttl > maxTTL {
// If they are using the system default, cap it to the role max;
// if it was specified on the command line, make it an error
if len(role.TTL) == 0 {
ttl = maxTTL
} else {
return nil, logical.ErrorResponse(
`"ttl" value must be less than "max_ttl" and/or backend default max lease TTL value`,
)
}
}
// Persist clamped TTLs
// Persist TTLs
role.TTL = ttl.String()
role.MaxTTL = maxTTL.String()
role.DefaultCriticalOptions = defaultCriticalOptions
@@ -520,13 +488,115 @@ func (b *backend) getRole(s logical.Storage, n string) (*sshRole, error) {
return &result, nil
}
// parseRole converts a sshRole object into its map[string]interface representation,
// with appropriate values for each KeyType. If the KeyType is invalid, it will retun
// an error.
func (b *backend) parseRole(role *sshRole) (map[string]interface{}, error) {
var result map[string]interface{}
switch role.KeyType {
case KeyTypeOTP:
result = map[string]interface{}{
"default_user": role.DefaultUser,
"cidr_list": role.CIDRList,
"exclude_cidr_list": role.ExcludeCIDRList,
"key_type": role.KeyType,
"port": role.Port,
"allowed_users": role.AllowedUsers,
}
case KeyTypeCA:
ttl, err := parseutil.ParseDurationSecond(role.TTL)
if err != nil {
return nil, err
}
maxTTL, err := parseutil.ParseDurationSecond(role.MaxTTL)
if err != nil {
return nil, err
}
result = map[string]interface{}{
"allowed_users": role.AllowedUsers,
"allowed_domains": role.AllowedDomains,
"default_user": role.DefaultUser,
"ttl": int64(ttl.Seconds()),
"max_ttl": int64(maxTTL.Seconds()),
"allowed_critical_options": role.AllowedCriticalOptions,
"allowed_extensions": role.AllowedExtensions,
"allow_user_certificates": role.AllowUserCertificates,
"allow_host_certificates": role.AllowHostCertificates,
"allow_bare_domains": role.AllowBareDomains,
"allow_subdomains": role.AllowSubdomains,
"allow_user_key_ids": role.AllowUserKeyIDs,
"key_id_format": role.KeyIDFormat,
"key_type": role.KeyType,
"default_critical_options": role.DefaultCriticalOptions,
"default_extensions": role.DefaultExtensions,
}
case KeyTypeDynamic:
result = map[string]interface{}{
"key": role.KeyName,
"admin_user": role.AdminUser,
"default_user": role.DefaultUser,
"cidr_list": role.CIDRList,
"exclude_cidr_list": role.ExcludeCIDRList,
"port": role.Port,
"key_type": role.KeyType,
"key_bits": role.KeyBits,
"allowed_users": role.AllowedUsers,
"key_option_specs": role.KeyOptionSpecs,
// Returning install script will make the output look messy.
// But this is one way for clients to see the script that is
// being used to install the key. If there is some problem,
// the script can be modified and configured by clients.
"install_script": role.InstallScript,
}
default:
return nil, fmt.Errorf("invalid key type: %v", role.KeyType)
}
return result, nil
}
func (b *backend) pathRoleList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entries, err := req.Storage.List("roles/")
if err != nil {
return nil, err
}
return logical.ListResponse(entries), nil
keyInfo := map[string]interface{}{}
for _, entry := range entries {
role, err := b.getRole(req.Storage, entry)
if err != nil {
// On error, log warning and continue
if b.Logger().IsWarn() {
b.Logger().Warn("ssh: error getting role info", "role", entry, "error", err)
}
continue
}
if role == nil {
// On empty role, log warning and continue
if b.Logger().IsWarn() {
b.Logger().Warn("ssh: no role info found", "role", entry)
}
continue
}
roleInfo, err := b.parseRole(role)
if err != nil {
if b.Logger().IsWarn() {
b.Logger().Warn("ssh: error parsing role info", "role", entry, "error", err)
}
continue
}
if keyType, ok := roleInfo["key_type"]; ok {
keyInfo[entry] = map[string]interface{}{
"key_type": keyType,
}
}
}
return logical.ListResponseWithInfo(entries, keyInfo), nil
}
func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
@@ -538,60 +608,14 @@ func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*l
return nil, nil
}
// Return information should be based on the key type of the role
if role.KeyType == KeyTypeOTP {
return &logical.Response{
Data: map[string]interface{}{
"default_user": role.DefaultUser,
"cidr_list": role.CIDRList,
"exclude_cidr_list": role.ExcludeCIDRList,
"key_type": role.KeyType,
"port": role.Port,
"allowed_users": role.AllowedUsers,
},
}, nil
} else if role.KeyType == KeyTypeCA {
return &logical.Response{
Data: map[string]interface{}{
"allowed_users": role.AllowedUsers,
"allowed_domains": role.AllowedDomains,
"default_user": role.DefaultUser,
"max_ttl": role.MaxTTL,
"ttl": role.TTL,
"allowed_critical_options": role.AllowedCriticalOptions,
"allowed_extensions": role.AllowedExtensions,
"allow_user_certificates": role.AllowUserCertificates,
"allow_host_certificates": role.AllowHostCertificates,
"allow_bare_domains": role.AllowBareDomains,
"allow_subdomains": role.AllowSubdomains,
"allow_user_key_ids": role.AllowUserKeyIDs,
"key_id_format": role.KeyIDFormat,
"key_type": role.KeyType,
"default_critical_options": role.DefaultCriticalOptions,
"default_extensions": role.DefaultExtensions,
},
}, nil
} else {
return &logical.Response{
Data: map[string]interface{}{
"key": role.KeyName,
"admin_user": role.AdminUser,
"default_user": role.DefaultUser,
"cidr_list": role.CIDRList,
"exclude_cidr_list": role.ExcludeCIDRList,
"port": role.Port,
"key_type": role.KeyType,
"key_bits": role.KeyBits,
"allowed_users": role.AllowedUsers,
"key_option_specs": role.KeyOptionSpecs,
// Returning install script will make the output look messy.
// But this is one way for clients to see the script that is
// being used to install the key. If there is some problem,
// the script can be modified and configured by clients.
"install_script": role.InstallScript,
},
}, nil
roleInfo, err := b.parseRole(role)
if err != nil {
return nil, err
}
return &logical.Response{
Data: roleInfo,
}, nil
}
func (b *backend) pathRoleDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {

View File

@@ -43,7 +43,7 @@ func pathSign(b *backend) *framework.Path {
Description: `The desired role with configuration for this request.`,
},
"ttl": &framework.FieldSchema{
Type: framework.TypeString,
Type: framework.TypeDurationSecond,
Description: `The requested Time To Live for the SSH certificate;
sets the expiration date. If not specified
the role default, backend default, or system
@@ -345,40 +345,34 @@ func (b *backend) calculateExtensions(data *framework.FieldData, role *sshRole)
}
func (b *backend) calculateTTL(data *framework.FieldData, role *sshRole) (time.Duration, error) {
var ttl, maxTTL time.Duration
var ttlField string
ttlFieldInt, ok := data.GetOk("ttl")
if !ok {
ttlField = role.TTL
} else {
ttlField = ttlFieldInt.(string)
}
var err error
if len(ttlField) == 0 {
ttlRaw, specifiedTTL := data.GetOk("ttl")
if specifiedTTL {
ttl = time.Duration(ttlRaw.(int)) * time.Second
} else {
ttl, err = parseutil.ParseDurationSecond(role.TTL)
if err != nil {
return 0, err
}
}
if ttl == 0 {
ttl = b.System().DefaultLeaseTTL()
} else {
var err error
ttl, err = parseutil.ParseDurationSecond(ttlField)
if err != nil {
return 0, fmt.Errorf("invalid requested ttl: %s", err)
}
}
if len(role.MaxTTL) == 0 {
maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL)
if err != nil {
return 0, err
}
if maxTTL == 0 {
maxTTL = b.System().MaxLeaseTTL()
} else {
var err error
maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL)
if err != nil {
return 0, fmt.Errorf("invalid requested max ttl: %s", err)
}
}
if ttl > maxTTL {
// Don't error if they were using system defaults, only error if
// they specifically chose a bad TTL
if len(ttlField) == 0 {
if !specifiedTTL {
ttl = maxTTL
} else {
return 0, fmt.Errorf("ttl is larger than maximum allowed (%d)", maxTTL/time.Second)

View File

@@ -43,6 +43,8 @@ func Backend(conf *logical.BackendConfig) *backend {
b.pathHMAC(),
b.pathSign(),
b.pathVerify(),
b.pathBackup(),
b.pathRestore(),
},
Secrets: []*framework.Secret{},

View File

@@ -38,6 +38,191 @@ func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) {
return b, config.StorageView
}
func TestTransit_RSA(t *testing.T) {
testTransit_RSA(t, "rsa-2048")
testTransit_RSA(t, "rsa-4096")
}
func testTransit_RSA(t *testing.T, keyType string) {
var resp *logical.Response
var err error
b, storage := createBackendWithStorage(t)
keyReq := &logical.Request{
Path: "keys/rsa",
Operation: logical.UpdateOperation,
Data: map[string]interface{}{
"type": keyType,
},
Storage: storage,
}
resp, err = b.HandleRequest(keyReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" // "the quick brown fox"
encryptReq := &logical.Request{
Path: "encrypt/rsa",
Operation: logical.UpdateOperation,
Storage: storage,
Data: map[string]interface{}{
"plaintext": plaintext,
},
}
resp, err = b.HandleRequest(encryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
ciphertext1 := resp.Data["ciphertext"].(string)
decryptReq := &logical.Request{
Path: "decrypt/rsa",
Operation: logical.UpdateOperation,
Storage: storage,
Data: map[string]interface{}{
"ciphertext": ciphertext1,
},
}
resp, err = b.HandleRequest(decryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
decryptedPlaintext := resp.Data["plaintext"]
if plaintext != decryptedPlaintext {
t.Fatalf("bad: plaintext; expected: %q\nactual: %q", plaintext, decryptedPlaintext)
}
// Rotate the key
rotateReq := &logical.Request{
Path: "keys/rsa/rotate",
Operation: logical.UpdateOperation,
Storage: storage,
}
resp, err = b.HandleRequest(rotateReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
// Encrypt again
resp, err = b.HandleRequest(encryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
ciphertext2 := resp.Data["ciphertext"].(string)
if ciphertext1 == ciphertext2 {
t.Fatalf("expected different ciphertexts")
}
// See if the older ciphertext can still be decrypted
resp, err = b.HandleRequest(decryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
if resp.Data["plaintext"].(string) != plaintext {
t.Fatal("failed to decrypt old ciphertext after rotating the key")
}
// Decrypt the new ciphertext
decryptReq.Data = map[string]interface{}{
"ciphertext": ciphertext2,
}
resp, err = b.HandleRequest(decryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
if resp.Data["plaintext"].(string) != plaintext {
t.Fatal("failed to decrypt ciphertext after rotating the key")
}
signReq := &logical.Request{
Path: "sign/rsa",
Operation: logical.UpdateOperation,
Storage: storage,
Data: map[string]interface{}{
"input": plaintext,
},
}
resp, err = b.HandleRequest(signReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
signature := resp.Data["signature"].(string)
verifyReq := &logical.Request{
Path: "verify/rsa",
Operation: logical.UpdateOperation,
Storage: storage,
Data: map[string]interface{}{
"input": plaintext,
"signature": signature,
},
}
resp, err = b.HandleRequest(verifyReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
if !resp.Data["valid"].(bool) {
t.Fatalf("failed to verify the RSA signature")
}
signReq.Data = map[string]interface{}{
"input": plaintext,
"algorithm": "invalid",
}
resp, err = b.HandleRequest(signReq)
if err != nil {
t.Fatal(err)
}
if resp == nil || !resp.IsError() {
t.Fatal("expected an error response")
}
signReq.Data = map[string]interface{}{
"input": plaintext,
"algorithm": "sha2-512",
}
resp, err = b.HandleRequest(signReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
signature = resp.Data["signature"].(string)
verifyReq.Data = map[string]interface{}{
"input": plaintext,
"signature": signature,
}
resp, err = b.HandleRequest(verifyReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
if resp.Data["valid"].(bool) {
t.Fatalf("expected validation to fail")
}
verifyReq.Data = map[string]interface{}{
"input": plaintext,
"signature": signature,
"algorithm": "sha2-512",
}
resp, err = b.HandleRequest(verifyReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
if !resp.Data["valid"].(bool) {
t.Fatalf("failed to verify the RSA signature")
}
}
func TestBackend_basic(t *testing.T) {
decryptData := make(map[string]interface{})
logicaltest.Test(t, logicaltest.TestCase{
@@ -634,7 +819,7 @@ func TestKeyUpgrade(t *testing.T) {
if p.Key != nil ||
p.Keys == nil ||
len(p.Keys) != 1 ||
!reflect.DeepEqual(p.Keys[1].Key, key) {
!reflect.DeepEqual(p.Keys[strconv.Itoa(1)].Key, key) {
t.Errorf("bad key migration, result is %#v", p.Keys)
}
}
@@ -1091,3 +1276,38 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) {
// Wait for them all to finish
wg.Wait()
}
func TestBadInput(t *testing.T) {
var b *backend
sysView := logical.TestSystemView()
storage := &logical.InmemStorage{}
b = Backend(&logical.BackendConfig{
StorageView: storage,
System: sysView,
})
req := &logical.Request{
Storage: storage,
Operation: logical.UpdateOperation,
Path: "keys/test",
}
resp, err := b.HandleRequest(req)
if err != nil {
t.Fatal(err)
}
if resp != nil {
t.Fatal("expected nil response")
}
req.Path = "decrypt/test"
req.Data = map[string]interface{}{
"ciphertext": "vault:v1:abcd",
}
_, err = b.HandleRequest(req)
if err == nil {
t.Fatal("expected error")
}
}

Some files were not shown because too many files have changed in this diff Show More