mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-10-29 17:52:32 +00:00
Merge branch 'master-oss' into sethvargo/cli-magic
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -64,6 +64,7 @@ tags
|
||||
# compiled output
|
||||
ui/dist
|
||||
ui/tmp
|
||||
ui/root
|
||||
|
||||
# dependencies
|
||||
ui/node_modules
|
||||
|
||||
245
CHANGELOG.md
245
CHANGELOG.md
@@ -1,27 +1,254 @@
|
||||
## 0.8.4 (Unreleased)
|
||||
## 0.9.2 (Unreleased)
|
||||
|
||||
IMPROVEMENTS:
|
||||
|
||||
* physical/s3: Allow using paths with S3 for non-AWS deployments [GH-3730]
|
||||
* physical/s3: Add ability to disable SSL for non-AWS deployments [GH-3730]
|
||||
|
||||
## 0.9.1 (December 21st, 2017)
|
||||
|
||||
DEPRECATIONS/CHANGES:
|
||||
|
||||
* AppRole Case Sensitivity: In prior versions of Vault, `list` operations
|
||||
against AppRole roles would require preserving case in the role name, even
|
||||
though most other operations within AppRole are case-insensitive with
|
||||
respect to the role name. This has been fixed; existing roles will behave as
|
||||
they have in the past, but new roles will act case-insensitively in these
|
||||
cases.
|
||||
* Token Auth Backend Roles parameter types: For `allowed_policies` and
|
||||
`disallowed_policies` in role definitions in the token auth backend, input
|
||||
can now be a comma-separated string or an array of strings. Reading a role
|
||||
will now return arrays for these parameters.
|
||||
* Transit key exporting: You can now mark a key in the `transit` backend as
|
||||
`exportable` at any time, rather than just at creation time; however, once
|
||||
this value is set, it still cannot be unset.
|
||||
* PKI Secret Backend Roles parameter types: For `allowed_domains` and
|
||||
`key_usage` in role definitions in the PKI secret backend, input
|
||||
can now be a comma-separated string or an array of strings. Reading a role
|
||||
will now return arrays for these parameters.
|
||||
* SSH Dynamic Keys Method Defaults to 2048-bit Keys: When using the dynamic
|
||||
key method in the SSH backend, the default is now to use 2048-bit keys if no
|
||||
specific key bit size is specified.
|
||||
* Consul Secret Backend lease handling: The `consul` secret backend can now
|
||||
accept both strings and integer numbers of seconds for its lease value. The
|
||||
value returned on a role read will be an integer number of seconds instead
|
||||
of a human-friendly string.
|
||||
* Unprintable characters not allowed in API paths: Unprintable characters are
|
||||
no longer allowed in names in the API (paths and path parameters), with an
|
||||
extra restriction on whitespace characters. Allowed characters are those
|
||||
that are considered printable by Unicode plus spaces.
|
||||
|
||||
FEATURES:
|
||||
|
||||
* **Transit Backup/Restore**: The `transit` backend now supports a backup
|
||||
operation that can export a given key, including all key versions and
|
||||
configuration, as well as a restore operation allowing import into another
|
||||
Vault.
|
||||
* **gRPC Database Plugins**: Database plugins now use gRPC for transport,
|
||||
allowing them to be written in other languages.
|
||||
* **Nomad Secret Backend**: Nomad ACL tokens can now be generated and revoked
|
||||
using Vault.
|
||||
* **TLS Cert Auth Backend Improvements**: The `cert` auth backend can now
|
||||
match against custom certificate extensions via exact or glob matching, and
|
||||
additionally supports max_ttl and periodic token toggles.
|
||||
|
||||
IMPROVEMENTS:
|
||||
|
||||
* auth/cert: Support custom certificate constraints [GH-3634]
|
||||
* auth/cert: Support setting `max_ttl` and `period` [GH-3642]
|
||||
* audit/file: Setting a file mode of `0000` will now disable Vault from
|
||||
automatically `chmod`ing the log file [GH-3649]
|
||||
* auth/github: The legacy MFA system can now be used with the GitHub auth
|
||||
backend [GH-3696]
|
||||
* auth/okta: The legacy MFA system can now be used with the Okta auth backend
|
||||
[GH-3653]
|
||||
* auth/token: `allowed_policies` and `disallowed_policies` can now be specified
|
||||
as a comma-separated string or an array of strings [GH-3641]
|
||||
* command/server: The log level can now be specified with `VAULT_LOG_LEVEL`
|
||||
[GH-3721]
|
||||
* core: Period values from auth backends will now be checked and applied to the
|
||||
TTL value directly by core on login and renewal requests [GH-3677]
|
||||
* database/mongodb: Add optional `write_concern` parameter, which can be set
|
||||
during database configuration. This establishes a session-wide [write
|
||||
concern](https://docs.mongodb.com/manual/reference/write-concern/) for the
|
||||
lifecycle of the mount [GH-3646]
|
||||
* http: Request path containing non-printable characters will return 400 - Bad
|
||||
Request [GH-3697]
|
||||
* mfa/okta: Filter a given email address as a login filter, allowing operation
|
||||
when login email and account email are different
|
||||
* plugins: Make Vault more resilient when unsealing when plugins are
|
||||
unavailable [GH-3686]
|
||||
* secret/pki: `allowed_domains` and `key_usage` can now be specified
|
||||
as a comma-separated string or an array of strings [GH-3642]
|
||||
* secret/ssh: Allow 4096-bit keys to be used in dynamic key method [GH-3593]
|
||||
* secret/consul: The Consul secret backend now uses the value of `lease` set
|
||||
on the role, if set, when renewing a secret. [GH-3796]
|
||||
* storage/mysql: Don't attempt database creation if it exists, which can help
|
||||
under certain permissions constraints [GH-3716]
|
||||
|
||||
BUG FIXES:
|
||||
|
||||
* api/status (enterprise): Fix status reporting when using an auto seal
|
||||
* auth/approle: Fix case-sensitive/insensitive comparison issue [GH-3665]
|
||||
* auth/cert: Return `allowed_names` on role read [GH-3654]
|
||||
* auth/ldap: Fix incorrect control information being sent [GH-3402] [GH-3496]
|
||||
[GH-3625] [GH-3656]
|
||||
* core: Fix seal status reporting when using an autoseal
|
||||
* core: Add creation path to wrap info for a control group token
|
||||
* core: Fix potential panic that could occur using plugins when a node
|
||||
transitioned from active to standby [GH-3638]
|
||||
* core: Fix memory ballooning when a connection would connect to the cluster
|
||||
port and then go away -- redux! [GH-3680]
|
||||
* core: Replace recursive token revocation logic with depth-first logic, which
|
||||
can avoid hitting stack depth limits in extreme cases [GH-2348]
|
||||
* core: When doing a read on configured audited-headers, properly handle case
|
||||
insensitivity [GH-3701]
|
||||
* core/pkcs11 (enterprise): Fix panic when PKCS#11 library is not readable
|
||||
* database/mysql: Allow the creation statement to use commands that are not yet
|
||||
supported by the prepare statement protocol [GH-3619]
|
||||
* plugin/auth-gcp: Fix IAM roles when using `allow_gce_inference` [VPAG-19]
|
||||
|
||||
## 0.9.0.1 (November 21st, 2017) (Enterprise Only)
|
||||
|
||||
IMPROVEMENTS:
|
||||
|
||||
* auth/gcp: Support seal wrapping of configuration parameters
|
||||
* auth/kubernetes: Support seal wrapping of configuration parameters
|
||||
|
||||
BUG FIXES:
|
||||
|
||||
* Fix an upgrade issue with some physical backends when migrating from legacy
|
||||
HSM stored key support to the new Seal Wrap mechanism
|
||||
|
||||
## 0.9.0 (November 14th, 2017)
|
||||
|
||||
DEPRECATIONS/CHANGES:
|
||||
|
||||
* HSM config parameter requirements: When using Vault with an HSM, a new
|
||||
paramter is required: `hmac_key_label`. This performs a similar function to
|
||||
`key_label` but for the HMAC key Vault will use. Vault will generate a
|
||||
suitable key if this value is specified and `generate_key` is set true.
|
||||
* API HTTP client behavior: When calling `NewClient` the API no longer
|
||||
modifies the provided client/transport. In particular this means it will no
|
||||
longer enable redirection limiting and HTTP/2 support on custom clients. It
|
||||
is suggested that if you want to make changes to an HTTP client that you use
|
||||
one created by `DefaultConfig` as a starting point.
|
||||
* AWS EC2 client nonce behavior: The client nonce generated by the backend
|
||||
that gets returned along with the authentication response will be audited in
|
||||
plaintext. If this is undesired, the clients can choose to supply a custom
|
||||
nonce to the login endpoint. The custom nonce set by the client will from
|
||||
now on, not be returned back with the authentication response, and hence not
|
||||
audit logged.
|
||||
* AWS Auth role options: The API will now error when trying to create or
|
||||
update a role with the mutually-exclusive options
|
||||
`disallow_reauthentication` and `allow_instance_migration`.
|
||||
* SSH CA role read changes: When reading back a role from the `ssh` backend,
|
||||
the TTL/max TTL values will now be an integer number of seconds rather than
|
||||
a string. This better matches the API elsewhere in Vault.
|
||||
* SSH role list changes: When listing roles from the `ssh` backend via the API,
|
||||
the response data will additionally return a `key_info` map that will contain
|
||||
a map of each key with a corresponding object containing the `key_type`.
|
||||
* More granularity in audit logs: Audit request and response entires are still
|
||||
in RFC3339 format but now have a granularity of nanoseconds.
|
||||
* High availability related values have been moved out of the `storage` and
|
||||
`ha_storage` stanzas, and into the top-level configuration. `redirect_addr`
|
||||
has been renamed to `api_addr`. The stanzas still support accepting
|
||||
HA-related values to maintain backward compatibility, but top-level values
|
||||
will take precedence.
|
||||
* A new `seal` stanza has been added to the configuration file, which is
|
||||
optional and enables configuration of the seal type to use for additional
|
||||
data protection, such as using HSM or Cloud KMS solutions to encrypt and
|
||||
decrypt data.
|
||||
|
||||
FEATURES:
|
||||
|
||||
* **RSA Support for Transit Backend**: Transit backend can now generate RSA
|
||||
keys which can be used for encryption and signing. [GH-3489]
|
||||
* **Identity System**: Now in open source and with significant enhancements,
|
||||
Identity is an integrated system for understanding users across tokens and
|
||||
enabling easier management of users directly and via groups.
|
||||
* **External Groups in Identity**: Vault can now automatically assign users
|
||||
and systems to groups in Identity based on their membership in external
|
||||
groups.
|
||||
* **Seal Wrap / FIPS 140-2 Compatibility (Enterprise)**: Vault can now take
|
||||
advantage of FIPS 140-2-certified HSMs to ensure that Critical Security
|
||||
Parameters are protected in a compliant fashion. Vault's implementation has
|
||||
received a statement of compliance from Leidos.
|
||||
* **Control Groups (Enterprise)**: Require multiple members of an Identity
|
||||
group to authorize a requested action before it is allowed to run.
|
||||
* **Cloud Auto-Unseal (Enterprise)**: Automatically unseal Vault using AWS KMS
|
||||
and GCP CKMS.
|
||||
* **Sentinel Integration (Enterprise)**: Take advantage of HashiCorp Sentinel
|
||||
to create extremely flexible access control policies -- even on
|
||||
unauthenticated endpoints.
|
||||
* **Barrier Rekey Support for Auto-Unseal (Enterprise)**: When using auto-unsealing
|
||||
functionality, the `rekey` operation is now supported; it uses recovery keys
|
||||
to authorize the master key rekey.
|
||||
* **Operation Token for Disaster Recovery Actions (Enterprise)**: When using
|
||||
Disaster Recovery replication, a token can be created that can be used to
|
||||
authorize actions such as promotion and updating primary information, rather
|
||||
than using recovery keys.
|
||||
* **Trigger Auto-Unseal with Recovery Keys (Enterprise)**: When using
|
||||
auto-unsealing, a request to unseal Vault can be triggered by a threshold of
|
||||
recovery keys, rather than requiring the Vault process to be restarted.
|
||||
* **UI Redesign (Enterprise)**: All new experience for the Vault Enterprise
|
||||
UI. The look and feel has been completely redesigned to give users a better
|
||||
experience and make managing secrets fast and easy.
|
||||
* **UI: SSH Secret Backend (Enterprise)**: Configure an SSH secret backend,
|
||||
create and browse roles. And use them to sign keys or generate one time
|
||||
passwords.
|
||||
* **UI: AWS Secret Backend (Enterprise)**: You can now configure the AWS
|
||||
backend via the Vault Enterprise UI. In addition you can create roles,
|
||||
browse the roles and Generate IAM Credentials from them in the UI.
|
||||
|
||||
IMPROVEMENTS:
|
||||
|
||||
* api: Add ability to set custom headers on each call [GH-3394]
|
||||
* command/server: Add config option to disable requesting client certificates
|
||||
[GH-3373]
|
||||
* secret/cassandra: Work around Cassandra ignoring consistency levels for a
|
||||
user listing query [GH-3469]
|
||||
* secret/pki: Allow entering URLs for `pki` as both comma-separated strings and JSON
|
||||
arrays [GH-3409]
|
||||
* secret/transit: Sign and verify operations now support a `none` hash
|
||||
algorithm to allow signing/verifying pre-hashed data [GH-3448]
|
||||
* core: Disallow mounting underneath an existing path, not just over [GH-2919]
|
||||
* physical/file: Use `700` as permissions when creating directories. The files
|
||||
themselves were `600` and are all encrypted, but this doesn't hurt.
|
||||
* secret/aws: Add ability to use custom IAM/STS endpoints [GH-3416]
|
||||
* secret/cassandra: Work around Cassandra ignoring consistency levels for a
|
||||
user listing query [GH-3469]
|
||||
* secret/pki: Private keys can now be marshalled as PKCS#8 [GH-3518]
|
||||
* secret/pki: Allow entering URLs for `pki` as both comma-separated strings and JSON
|
||||
arrays [GH-3409]
|
||||
* secret/ssh: Role TTL/max TTL can now be specified as either a string or an
|
||||
integer [GH-3507]
|
||||
* secret/transit: Sign and verify operations now support a `none` hash
|
||||
algorithm to allow signing/verifying pre-hashed data [GH-3448]
|
||||
* secret/database: Add the ability to glob allowed roles in the Database Backend [GH-3387]
|
||||
* ui (enterprise): Support for RSA keys in the transit backend
|
||||
* ui (enterprise): Support for DR Operation Token generation, promoting, and
|
||||
updating primary on DR Secondary clusters
|
||||
|
||||
BUG FIXES:
|
||||
|
||||
* api: Fix panic when setting a custom HTTP client but with a nil transport
|
||||
[GH-3437]
|
||||
[GH-3435] [GH-3437]
|
||||
* api: Fix authing to the `cert` backend when the CA for the client cert is
|
||||
not known to the server's listener [GH-2946]
|
||||
* auth/approle: Create role ID index during read if a role is missing one [GH-3561]
|
||||
* auth/aws: Don't allow mutually exclusive options [GH-3291]
|
||||
* auth/radius: Fix logging in in some situations [GH-3461]
|
||||
* core: Fix memleak when a connection would connect to the cluster port and
|
||||
then go away [GH-3513]
|
||||
* core: Fix panic if a single-use token is used to step-down or seal [GH-3497]
|
||||
* core: Set rather than add headers to prevent some duplicated headers in
|
||||
responses when requests were forwarded to the active node [GH-3485]
|
||||
* physical/etcd3: Fix some listing issues due to how etcd3 does prefix
|
||||
matching [GH-3406]
|
||||
* physical/etcd3: Fix case where standbys can lose their etcd client lease
|
||||
[GH-3031]
|
||||
* physical/file: Fix listing when underscores are the first component of a
|
||||
path [GH-3476]
|
||||
* plugins: Allow response errors to be returned from backend plugins [GH-3412]
|
||||
* secret/transit: Fix panic if the length of the input ciphertext was less
|
||||
than the expected nonce length [GH-3521]
|
||||
* ui (enterprise): Reinstate support for generic secret backends - this was
|
||||
erroneously removed in a previous release
|
||||
|
||||
## 0.8.3 (September 19th, 2017)
|
||||
|
||||
@@ -117,7 +344,7 @@ IMPROVEMENTS:
|
||||
|
||||
* audit/file: Allow specifying `stdout` as the `file_path` to log to standard
|
||||
output [GH-3235]
|
||||
* auth/aws: Allow wildcards in `bound_iam_principal_id` [GH-3213]
|
||||
* auth/aws: Allow wildcards in `bound_iam_principal_arn` [GH-3213]
|
||||
* auth/okta: Compare groups case-insensitively since Okta is only
|
||||
case-preserving [GH-3240]
|
||||
* auth/okta: Standardize Okta configuration APIs across backends [GH-3245]
|
||||
|
||||
5
Makefile
5
Makefile
@@ -79,7 +79,7 @@ vet:
|
||||
prep: fmtcheck
|
||||
@sh -c "'$(CURDIR)/scripts/goversioncheck.sh' '$(GO_VERSION_MIN)'"
|
||||
go generate $(go list ./... | grep -v /vendor/)
|
||||
cp .hooks/* .git/hooks/
|
||||
@if [ -d .git/hooks ]; then cp .hooks/* .git/hooks/; fi
|
||||
|
||||
# bootstrap the build by downloading additional tools
|
||||
bootstrap:
|
||||
@@ -92,8 +92,11 @@ proto:
|
||||
protoc -I helper/forwarding -I vault -I ../../.. vault/*.proto --go_out=plugins=grpc:vault
|
||||
protoc -I helper/storagepacker helper/storagepacker/types.proto --go_out=plugins=grpc:helper/storagepacker
|
||||
protoc -I helper/forwarding -I vault -I ../../.. helper/forwarding/types.proto --go_out=plugins=grpc:helper/forwarding
|
||||
protoc -I physical physical/types.proto --go_out=plugins=grpc:physical
|
||||
protoc -I helper/identity -I ../../.. helper/identity/types.proto --go_out=plugins=grpc:helper/identity
|
||||
protoc builtin/logical/database/dbplugin/*.proto --go_out=plugins=grpc:.
|
||||
sed -i -e 's/Idp/IDP/' -e 's/Url/URL/' -e 's/Id/ID/' -e 's/EntityId/EntityID/' -e 's/Api/API/' -e 's/Qr/QR/' -e 's/protobuf:"/sentinel:"" protobuf:"/' helper/identity/types.pb.go helper/storagepacker/types.pb.go
|
||||
sed -i -e 's/Iv/IV/' -e 's/Hmac/HMAC/' physical/types.pb.go
|
||||
|
||||
fmtcheck:
|
||||
@sh -c "'$(CURDIR)/scripts/gofmtcheck.sh'"
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
Vault [](https://travis-ci.org/hashicorp/vault) [](https://gitter.im/hashicorp-vault/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [](https://www.hashicorp.com/products/vault/?utm_source=github&utm_medium=banner&utm_campaign=github-vault-enterprise)
|
||||
=========
|
||||
# Vault [](https://travis-ci.org/hashicorp/vault) [](https://gitter.im/hashicorp-vault/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [](https://www.hashicorp.com/products/vault/?utm_source=github&utm_medium=banner&utm_campaign=github-vault-enterprise)
|
||||
|
||||
----
|
||||
|
||||
**Please note**: We take Vault's security and our users' trust very seriously. If you believe you have found a security issue in Vault, _please responsibly disclose_ by contacting us at [security@hashicorp.com](mailto:security@hashicorp.com).
|
||||
|
||||
=========
|
||||
----
|
||||
|
||||
- Website: https://www.vaultproject.io
|
||||
- IRC: `#vault-tool` on Freenode
|
||||
|
||||
@@ -5,8 +5,6 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
// testHTTPServer creates a test HTTP server that handles requests until
|
||||
@@ -19,9 +17,6 @@ func testHTTPServer(
|
||||
}
|
||||
|
||||
server := &http.Server{Handler: handler}
|
||||
if err := http2.ConfigureServer(server, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
go server.Serve(ln)
|
||||
|
||||
config := DefaultConfig()
|
||||
|
||||
236
api/client.go
236
api/client.go
@@ -13,12 +13,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
"github.com/hashicorp/go-rootcerts"
|
||||
"github.com/hashicorp/vault/helper/parseutil"
|
||||
"github.com/sethgrid/pester"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
const EnvVaultAddress = "VAULT_ADDR"
|
||||
@@ -43,24 +43,31 @@ type WrappingLookupFunc func(operation, path string) string
|
||||
|
||||
// Config is used to configure the creation of the client.
|
||||
type Config struct {
|
||||
modifyLock sync.RWMutex
|
||||
|
||||
// Address is the address of the Vault server. This should be a complete
|
||||
// URL such as "http://vault.example.com". If you need a custom SSL
|
||||
// cert or want to enable insecure mode, you need to specify a custom
|
||||
// HttpClient.
|
||||
Address string
|
||||
|
||||
// HttpClient is the HTTP client to use, which will currently always have the
|
||||
// same values as http.DefaultClient. This is used to control redirect behavior.
|
||||
// HttpClient is the HTTP client to use. Vault sets sane defaults for the
|
||||
// http.Client and its associated http.Transport created in DefaultConfig.
|
||||
// If you must modify Vault's defaults, it is suggested that you start with
|
||||
// that client and modify as needed rather than start with an empty client
|
||||
// (or http.DefaultClient).
|
||||
HttpClient *http.Client
|
||||
|
||||
redirectSetup sync.Once
|
||||
|
||||
// MaxRetries controls the maximum number of times to retry when a 5xx error
|
||||
// occurs. Set to 0 or less to disable retrying. Defaults to 0.
|
||||
MaxRetries int
|
||||
|
||||
// Timeout is for setting custom timeout parameter in the HttpClient
|
||||
Timeout time.Duration
|
||||
|
||||
// If there is an error when creating the configuration, this will be the
|
||||
// error
|
||||
Error error
|
||||
}
|
||||
|
||||
// TLSConfig contains the parameters needed to configure TLS on the HTTP client
|
||||
@@ -93,60 +100,91 @@ type TLSConfig struct {
|
||||
//
|
||||
// The default Address is https://127.0.0.1:8200, but this can be overridden by
|
||||
// setting the `VAULT_ADDR` environment variable.
|
||||
//
|
||||
// If an error is encountered, this will return nil.
|
||||
func DefaultConfig() *Config {
|
||||
config := &Config{
|
||||
Address: "https://127.0.0.1:8200",
|
||||
HttpClient: cleanhttp.DefaultClient(),
|
||||
}
|
||||
config.HttpClient.Timeout = time.Second * 60
|
||||
|
||||
transport := config.HttpClient.Transport.(*http.Transport)
|
||||
transport.TLSHandshakeTimeout = 10 * time.Second
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
if err := http2.ConfigureTransport(transport); err != nil {
|
||||
config.Error = err
|
||||
return config
|
||||
}
|
||||
|
||||
if v := os.Getenv(EnvVaultAddress); v != "" {
|
||||
config.Address = v
|
||||
if err := config.ReadEnvironment(); err != nil {
|
||||
config.Error = err
|
||||
return config
|
||||
}
|
||||
|
||||
// Ensure redirects are not automatically followed
|
||||
// Note that this is sane for the API client as it has its own
|
||||
// redirect handling logic (and thus also for command/meta),
|
||||
// but in e.g. http_test actual redirect handling is necessary
|
||||
config.HttpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
// Returning this value causes the Go net library to not close the
|
||||
// response body and to nil out the error. Otherwise pester tries
|
||||
// three times on every redirect because it sees an error from this
|
||||
// function (to prevent redirects) passing through to it.
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// ConfigureTLS takes a set of TLS configurations and applies those to the the HTTP client.
|
||||
// ConfigureTLS takes a set of TLS configurations and applies those to the the
|
||||
// HTTP client.
|
||||
func (c *Config) ConfigureTLS(t *TLSConfig) error {
|
||||
if c.HttpClient == nil {
|
||||
c.HttpClient = DefaultConfig().HttpClient
|
||||
}
|
||||
clientTLSConfig := c.HttpClient.Transport.(*http.Transport).TLSClientConfig
|
||||
|
||||
var clientCert tls.Certificate
|
||||
foundClientCert := false
|
||||
if t.CACert != "" || t.CAPath != "" || t.ClientCert != "" || t.ClientKey != "" || t.Insecure {
|
||||
if t.ClientCert != "" && t.ClientKey != "" {
|
||||
var err error
|
||||
clientCert, err = tls.LoadX509KeyPair(t.ClientCert, t.ClientKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
foundClientCert = true
|
||||
} else if t.ClientCert != "" || t.ClientKey != "" {
|
||||
return fmt.Errorf("Both client cert and client key must be provided")
|
||||
|
||||
switch {
|
||||
case t.ClientCert != "" && t.ClientKey != "":
|
||||
var err error
|
||||
clientCert, err = tls.LoadX509KeyPair(t.ClientCert, t.ClientKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
foundClientCert = true
|
||||
case t.ClientCert != "" || t.ClientKey != "":
|
||||
return fmt.Errorf("Both client cert and client key must be provided")
|
||||
}
|
||||
|
||||
if t.CACert != "" || t.CAPath != "" {
|
||||
rootConfig := &rootcerts.Config{
|
||||
CAFile: t.CACert,
|
||||
CAPath: t.CAPath,
|
||||
}
|
||||
if err := rootcerts.ConfigureTLS(clientTLSConfig, rootConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
clientTLSConfig := c.HttpClient.Transport.(*http.Transport).TLSClientConfig
|
||||
rootConfig := &rootcerts.Config{
|
||||
CAFile: t.CACert,
|
||||
CAPath: t.CAPath,
|
||||
if t.Insecure {
|
||||
clientTLSConfig.InsecureSkipVerify = true
|
||||
}
|
||||
if err := rootcerts.ConfigureTLS(clientTLSConfig, rootConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clientTLSConfig.InsecureSkipVerify = t.Insecure
|
||||
|
||||
if foundClientCert {
|
||||
clientTLSConfig.Certificates = []tls.Certificate{clientCert}
|
||||
// We use this function to ignore the server's preferential list of
|
||||
// CAs, otherwise any CA used for the cert auth backend must be in the
|
||||
// server's CA pool
|
||||
clientTLSConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return &clientCert, nil
|
||||
}
|
||||
}
|
||||
|
||||
if t.TLSServerName != "" {
|
||||
clientTLSConfig.ServerName = t.TLSServerName
|
||||
}
|
||||
@@ -154,9 +192,8 @@ func (c *Config) ConfigureTLS(t *TLSConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadEnvironment reads configuration information from the
|
||||
// environment. If there is an error, no configuration value
|
||||
// is updated.
|
||||
// ReadEnvironment reads configuration information from the environment. If
|
||||
// there is an error, no configuration value is updated.
|
||||
func (c *Config) ReadEnvironment() error {
|
||||
var envAddress string
|
||||
var envCACert string
|
||||
@@ -218,6 +255,10 @@ func (c *Config) ReadEnvironment() error {
|
||||
TLSServerName: envTLSServerName,
|
||||
Insecure: envInsecure,
|
||||
}
|
||||
|
||||
c.modifyLock.Lock()
|
||||
defer c.modifyLock.Unlock()
|
||||
|
||||
if err := c.ConfigureTLS(t); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -237,10 +278,9 @@ func (c *Config) ReadEnvironment() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Client is the client to the Vault API. Create a client with NewClient. Note:
|
||||
// it is not safe to modify client configuration from multiple goroutines at
|
||||
// once. Set configuration first, then run requests.
|
||||
// Client is the client to the Vault API. Create a client with NewClient.
|
||||
type Client struct {
|
||||
modifyLock sync.RWMutex
|
||||
addr *url.URL
|
||||
config *Config
|
||||
token string
|
||||
@@ -250,24 +290,29 @@ type Client struct {
|
||||
policyOverride bool
|
||||
}
|
||||
|
||||
// SetMFACreds sets the MFA credentials supplied either via the environment
|
||||
// variable or via the command line.
|
||||
func (c *Client) SetMFACreds(creds []string) {
|
||||
c.mfaCreds = creds
|
||||
}
|
||||
|
||||
// NewClient returns a new client for the given configuration.
|
||||
//
|
||||
// If the configuration is nil, Vault will use configuration from
|
||||
// DefaultConfig(), which is the recommended starting configuration.
|
||||
//
|
||||
// If the environment variable `VAULT_TOKEN` is present, the token will be
|
||||
// automatically added to the client. Otherwise, you must manually call
|
||||
// `SetToken()`.
|
||||
func NewClient(c *Config) (*Client, error) {
|
||||
if c == nil {
|
||||
c = DefaultConfig()
|
||||
if err := c.ReadEnvironment(); err != nil {
|
||||
return nil, fmt.Errorf("error reading environment: %v", err)
|
||||
}
|
||||
def := DefaultConfig()
|
||||
if def == nil {
|
||||
return nil, fmt.Errorf("could not create/read default configuration")
|
||||
}
|
||||
if def.Error != nil {
|
||||
return nil, errwrap.Wrapf("error encountered setting up default configuration: {{err}}", def.Error)
|
||||
}
|
||||
|
||||
if c == nil {
|
||||
c = def
|
||||
}
|
||||
|
||||
c.modifyLock.Lock()
|
||||
defer c.modifyLock.Unlock()
|
||||
|
||||
u, err := url.Parse(c.Address)
|
||||
if err != nil {
|
||||
@@ -275,41 +320,19 @@ func NewClient(c *Config) (*Client, error) {
|
||||
}
|
||||
|
||||
if c.HttpClient == nil {
|
||||
c.HttpClient = DefaultConfig().HttpClient
|
||||
c.HttpClient = def.HttpClient
|
||||
}
|
||||
if c.HttpClient.Transport == nil {
|
||||
c.HttpClient.Transport = cleanhttp.DefaultTransport()
|
||||
c.HttpClient.Transport = def.HttpClient.Transport
|
||||
}
|
||||
|
||||
if tp, ok := c.HttpClient.Transport.(*http.Transport); ok {
|
||||
if err := http2.ConfigureTransport(tp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
redirFunc := func() {
|
||||
// Ensure redirects are not automatically followed
|
||||
// Note that this is sane for the API client as it has its own
|
||||
// redirect handling logic (and thus also for command/meta),
|
||||
// but in e.g. http_test actual redirect handling is necessary
|
||||
c.HttpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
// Returning this value causes the Go net library to not close the
|
||||
// response body and to nil out the error. Otherwise pester tries
|
||||
// three times on every redirect because it sees an error from this
|
||||
// function (to prevent redirects) passing through to it.
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
}
|
||||
|
||||
c.redirectSetup.Do(redirFunc)
|
||||
|
||||
client := &Client{
|
||||
addr: u,
|
||||
config: c,
|
||||
}
|
||||
|
||||
if token := os.Getenv(EnvVaultToken); token != "" {
|
||||
client.SetToken(token)
|
||||
client.token = token
|
||||
}
|
||||
|
||||
return client, nil
|
||||
@@ -319,6 +342,9 @@ func NewClient(c *Config) (*Client, error) {
|
||||
// "<Scheme>://<Host>:<Port>". Setting this on a client will override the
|
||||
// value of VAULT_ADDR environment variable.
|
||||
func (c *Client) SetAddress(addr string) error {
|
||||
c.modifyLock.Lock()
|
||||
defer c.modifyLock.Unlock()
|
||||
|
||||
var err error
|
||||
if c.addr, err = url.Parse(addr); err != nil {
|
||||
return fmt.Errorf("failed to set address: %v", err)
|
||||
@@ -329,56 +355,112 @@ func (c *Client) SetAddress(addr string) error {
|
||||
|
||||
// Address returns the Vault URL the client is configured to connect to
|
||||
func (c *Client) Address() string {
|
||||
c.modifyLock.RLock()
|
||||
defer c.modifyLock.RUnlock()
|
||||
|
||||
return c.addr.String()
|
||||
}
|
||||
|
||||
// SetMaxRetries sets the number of retries that will be used in the case of certain errors
|
||||
func (c *Client) SetMaxRetries(retries int) {
|
||||
c.modifyLock.RLock()
|
||||
c.config.modifyLock.Lock()
|
||||
defer c.config.modifyLock.Unlock()
|
||||
c.modifyLock.RUnlock()
|
||||
|
||||
c.config.MaxRetries = retries
|
||||
}
|
||||
|
||||
// SetClientTimeout sets the client request timeout
|
||||
func (c *Client) SetClientTimeout(timeout time.Duration) {
|
||||
c.modifyLock.RLock()
|
||||
c.config.modifyLock.Lock()
|
||||
defer c.config.modifyLock.Unlock()
|
||||
c.modifyLock.RUnlock()
|
||||
|
||||
c.config.Timeout = timeout
|
||||
}
|
||||
|
||||
// SetWrappingLookupFunc sets a lookup function that returns desired wrap TTLs
|
||||
// for a given operation and path
|
||||
func (c *Client) SetWrappingLookupFunc(lookupFunc WrappingLookupFunc) {
|
||||
c.modifyLock.Lock()
|
||||
defer c.modifyLock.Unlock()
|
||||
|
||||
c.wrappingLookupFunc = lookupFunc
|
||||
}
|
||||
|
||||
// SetMFACreds sets the MFA credentials supplied either via the environment
|
||||
// variable or via the command line.
|
||||
func (c *Client) SetMFACreds(creds []string) {
|
||||
c.modifyLock.Lock()
|
||||
defer c.modifyLock.Unlock()
|
||||
|
||||
c.mfaCreds = creds
|
||||
}
|
||||
|
||||
// Token returns the access token being used by this client. It will
|
||||
// return the empty string if there is no token set.
|
||||
func (c *Client) Token() string {
|
||||
c.modifyLock.RLock()
|
||||
defer c.modifyLock.RUnlock()
|
||||
|
||||
return c.token
|
||||
}
|
||||
|
||||
// SetToken sets the token directly. This won't perform any auth
|
||||
// verification, it simply sets the token properly for future requests.
|
||||
func (c *Client) SetToken(v string) {
|
||||
c.modifyLock.Lock()
|
||||
defer c.modifyLock.Unlock()
|
||||
|
||||
c.token = v
|
||||
}
|
||||
|
||||
// ClearToken deletes the token if it is set or does nothing otherwise.
|
||||
func (c *Client) ClearToken() {
|
||||
c.modifyLock.Lock()
|
||||
defer c.modifyLock.Unlock()
|
||||
|
||||
c.token = ""
|
||||
}
|
||||
|
||||
// SetHeaders sets the headers to be used for future requests.
|
||||
func (c *Client) SetHeaders(headers http.Header) {
|
||||
c.modifyLock.Lock()
|
||||
defer c.modifyLock.Unlock()
|
||||
|
||||
c.headers = headers
|
||||
}
|
||||
|
||||
// Clone creates a copy of this client.
|
||||
// Clone creates a new client with the same configuration. Note that the same
|
||||
// underlying http.Client is used; modifying the client from more than one
|
||||
// goroutine at once may not be safe, so modify the client as needed and then
|
||||
// clone.
|
||||
func (c *Client) Clone() (*Client, error) {
|
||||
return NewClient(c.config)
|
||||
c.modifyLock.RLock()
|
||||
c.config.modifyLock.RLock()
|
||||
config := c.config
|
||||
c.modifyLock.RUnlock()
|
||||
|
||||
newConfig := &Config{
|
||||
Address: config.Address,
|
||||
HttpClient: config.HttpClient,
|
||||
MaxRetries: config.MaxRetries,
|
||||
Timeout: config.Timeout,
|
||||
}
|
||||
config.modifyLock.RUnlock()
|
||||
|
||||
return NewClient(newConfig)
|
||||
}
|
||||
|
||||
// SetPolicyOverride sets whether requests should be sent with the policy
|
||||
// override flag to request overriding soft-mandatory Sentinel policies (both
|
||||
// RGPs and EGPs)
|
||||
func (c *Client) SetPolicyOverride(override bool) {
|
||||
c.modifyLock.Lock()
|
||||
defer c.modifyLock.Unlock()
|
||||
|
||||
c.policyOverride = override
|
||||
}
|
||||
|
||||
@@ -386,6 +468,9 @@ func (c *Client) SetPolicyOverride(override bool) {
|
||||
// configured for this client. This is an advanced method and generally
|
||||
// doesn't need to be called externally.
|
||||
func (c *Client) NewRequest(method, requestPath string) *Request {
|
||||
c.modifyLock.RLock()
|
||||
defer c.modifyLock.RUnlock()
|
||||
|
||||
// if SRV records exist (see https://tools.ietf.org/html/draft-andrews-http-srv-02), lookup the SRV
|
||||
// record and take the highest match; this is not designed for high-availability, just discovery
|
||||
var host string = c.addr.Host
|
||||
@@ -442,6 +527,11 @@ func (c *Client) NewRequest(method, requestPath string) *Request {
|
||||
// a Vault server not configured with this client. This is an advanced operation
|
||||
// that generally won't need to be called externally.
|
||||
func (c *Client) RawRequest(r *Request) (*Response, error) {
|
||||
c.modifyLock.RLock()
|
||||
c.config.modifyLock.RLock()
|
||||
defer c.config.modifyLock.RUnlock()
|
||||
c.modifyLock.RUnlock()
|
||||
|
||||
redirectCount := 0
|
||||
START:
|
||||
req, err := r.ToHTTP()
|
||||
|
||||
@@ -163,8 +163,8 @@ func TestClientEnvSettings(t *testing.T) {
|
||||
if len(tlsConfig.RootCAs.Subjects()) == 0 {
|
||||
t.Fatalf("bad: expected a cert pool with at least one subject")
|
||||
}
|
||||
if len(tlsConfig.Certificates) != 1 {
|
||||
t.Fatalf("bad: expected client tls config to have a client certificate")
|
||||
if tlsConfig.GetClientCertificate == nil {
|
||||
t.Fatalf("bad: expected client tls config to have a certificate getter")
|
||||
}
|
||||
if tlsConfig.InsecureSkipVerify != true {
|
||||
t.Fatalf("bad: %v", tlsConfig.InsecureSkipVerify)
|
||||
@@ -213,3 +213,16 @@ func TestClientNonTransportRoundTripper(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClone(t *testing.T) {
|
||||
client1, err1 := NewClient(nil)
|
||||
if err1 != nil {
|
||||
t.Fatalf("NewClient failed: %v", err1)
|
||||
}
|
||||
client2, err2 := client1.Clone()
|
||||
if err2 != nil {
|
||||
t.Fatalf("Clone failed: %v", err2)
|
||||
}
|
||||
|
||||
_ = client2
|
||||
}
|
||||
|
||||
@@ -50,12 +50,13 @@ var (
|
||||
type Renewer struct {
|
||||
l sync.Mutex
|
||||
|
||||
client *Client
|
||||
secret *Secret
|
||||
grace time.Duration
|
||||
random *rand.Rand
|
||||
doneCh chan error
|
||||
renewCh chan *RenewOutput
|
||||
client *Client
|
||||
secret *Secret
|
||||
grace time.Duration
|
||||
random *rand.Rand
|
||||
increment int
|
||||
doneCh chan error
|
||||
renewCh chan *RenewOutput
|
||||
|
||||
stopped bool
|
||||
stopCh chan struct{}
|
||||
@@ -79,6 +80,11 @@ type RenewerInput struct {
|
||||
// RenewBuffer is the size of the buffered channel where renew messages are
|
||||
// dispatched.
|
||||
RenewBuffer int
|
||||
|
||||
// The new TTL, in seconds, that should be set on the lease. The TTL set
|
||||
// here may or may not be honored by the vault server, based on Vault
|
||||
// configuration or any associated max TTL values.
|
||||
Increment int
|
||||
}
|
||||
|
||||
// RenewOutput is the metadata returned to the client (if it's listening) to
|
||||
@@ -120,12 +126,13 @@ func (c *Client) NewRenewer(i *RenewerInput) (*Renewer, error) {
|
||||
}
|
||||
|
||||
return &Renewer{
|
||||
client: c,
|
||||
secret: secret,
|
||||
grace: grace,
|
||||
random: random,
|
||||
doneCh: make(chan error, 1),
|
||||
renewCh: make(chan *RenewOutput, renewBuffer),
|
||||
client: c,
|
||||
secret: secret,
|
||||
grace: grace,
|
||||
increment: i.Increment,
|
||||
random: random,
|
||||
doneCh: make(chan error, 1),
|
||||
renewCh: make(chan *RenewOutput, renewBuffer),
|
||||
|
||||
stopped: false,
|
||||
stopCh: make(chan struct{}),
|
||||
@@ -245,7 +252,7 @@ func (r *Renewer) renewLease() error {
|
||||
}
|
||||
|
||||
// Renew the lease.
|
||||
renewal, err := client.Sys().Renew(leaseID, 0)
|
||||
renewal, err := client.Sys().Renew(leaseID, r.increment)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -224,6 +224,7 @@ func (s *Secret) TokenTTL() (time.Duration, error) {
|
||||
// available in WrappedAccessor.
|
||||
type SecretWrapInfo struct {
|
||||
Token string `json:"token"`
|
||||
Accessor string `json:"accessor"`
|
||||
TTL int `json:"ttl"`
|
||||
CreationTime time.Time `json:"creation_time"`
|
||||
CreationPath string `json:"creation_path"`
|
||||
|
||||
@@ -26,6 +26,7 @@ func TestParseSecret(t *testing.T) {
|
||||
],
|
||||
"wrap_info": {
|
||||
"token": "token",
|
||||
"accessor": "accessor",
|
||||
"ttl": 60,
|
||||
"creation_time": "2016-06-07T15:52:10-04:00",
|
||||
"wrapped_accessor": "abcd1234"
|
||||
@@ -51,6 +52,7 @@ func TestParseSecret(t *testing.T) {
|
||||
},
|
||||
WrapInfo: &api.SecretWrapInfo{
|
||||
Token: "token",
|
||||
Accessor: "accessor",
|
||||
TTL: 60,
|
||||
CreationTime: rawTime,
|
||||
WrappedAccessor: "abcd1234",
|
||||
|
||||
@@ -87,6 +87,7 @@ type EnableAuthOptions struct {
|
||||
Config AuthConfigInput `json:"config" structs:"config"`
|
||||
Local bool `json:"local" structs:"local"`
|
||||
PluginName string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty"`
|
||||
SealWrap bool `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"`
|
||||
}
|
||||
|
||||
type AuthConfigInput struct {
|
||||
@@ -99,6 +100,7 @@ type AuthMount struct {
|
||||
Accessor string `json:"accessor" structs:"accessor" mapstructure:"accessor"`
|
||||
Config AuthConfigOutput `json:"config" structs:"config" mapstructure:"config"`
|
||||
Local bool `json:"local" structs:"local" mapstructure:"local"`
|
||||
SealWrap bool `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"`
|
||||
}
|
||||
|
||||
type AuthConfigOutput struct {
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
package api
|
||||
|
||||
func (c *Sys) GenerateRootStatus() (*GenerateRootStatusResponse, error) {
|
||||
r := c.c.NewRequest("GET", "/v1/sys/generate-root/attempt")
|
||||
return c.generateRootStatusCommon("/v1/sys/generate-root/attempt")
|
||||
}
|
||||
|
||||
func (c *Sys) GenerateDROperationTokenStatus() (*GenerateRootStatusResponse, error) {
|
||||
return c.generateRootStatusCommon("/v1/sys/replication/dr/secondary/generate-operation-token/attempt")
|
||||
}
|
||||
|
||||
func (c *Sys) generateRootStatusCommon(path string) (*GenerateRootStatusResponse, error) {
|
||||
r := c.c.NewRequest("GET", path)
|
||||
resp, err := c.c.RawRequest(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -14,12 +22,20 @@ func (c *Sys) GenerateRootStatus() (*GenerateRootStatusResponse, error) {
|
||||
}
|
||||
|
||||
func (c *Sys) GenerateRootInit(otp, pgpKey string) (*GenerateRootStatusResponse, error) {
|
||||
return c.generateRootInitCommon("/v1/sys/generate-root/attempt", otp, pgpKey)
|
||||
}
|
||||
|
||||
func (c *Sys) GenerateDROperationTokenInit(otp, pgpKey string) (*GenerateRootStatusResponse, error) {
|
||||
return c.generateRootInitCommon("/v1/sys/replication/dr/secondary/generate-operation-token/attempt", otp, pgpKey)
|
||||
}
|
||||
|
||||
func (c *Sys) generateRootInitCommon(path, otp, pgpKey string) (*GenerateRootStatusResponse, error) {
|
||||
body := map[string]interface{}{
|
||||
"otp": otp,
|
||||
"pgp_key": pgpKey,
|
||||
}
|
||||
|
||||
r := c.c.NewRequest("PUT", "/v1/sys/generate-root/attempt")
|
||||
r := c.c.NewRequest("PUT", path)
|
||||
if err := r.SetJSONBody(body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -36,7 +52,15 @@ func (c *Sys) GenerateRootInit(otp, pgpKey string) (*GenerateRootStatusResponse,
|
||||
}
|
||||
|
||||
func (c *Sys) GenerateRootCancel() error {
|
||||
r := c.c.NewRequest("DELETE", "/v1/sys/generate-root/attempt")
|
||||
return c.generateRootCancelCommon("/v1/sys/generate-root/attempt")
|
||||
}
|
||||
|
||||
func (c *Sys) GenerateDROperationTokenCancel() error {
|
||||
return c.generateRootCancelCommon("/v1/sys/replication/dr/secondary/generate-operation-token/attempt")
|
||||
}
|
||||
|
||||
func (c *Sys) generateRootCancelCommon(path string) error {
|
||||
r := c.c.NewRequest("DELETE", path)
|
||||
resp, err := c.c.RawRequest(r)
|
||||
if err == nil {
|
||||
defer resp.Body.Close()
|
||||
@@ -45,12 +69,20 @@ func (c *Sys) GenerateRootCancel() error {
|
||||
}
|
||||
|
||||
func (c *Sys) GenerateRootUpdate(shard, nonce string) (*GenerateRootStatusResponse, error) {
|
||||
return c.generateRootUpdateCommon("/v1/sys/generate-root/update", shard, nonce)
|
||||
}
|
||||
|
||||
func (c *Sys) GenerateDROperationTokenUpdate(shard, nonce string) (*GenerateRootStatusResponse, error) {
|
||||
return c.generateRootUpdateCommon("/v1/sys/replication/dr/secondary/generate-operation-token/update", shard, nonce)
|
||||
}
|
||||
|
||||
func (c *Sys) generateRootUpdateCommon(path, shard, nonce string) (*GenerateRootStatusResponse, error) {
|
||||
body := map[string]interface{}{
|
||||
"key": shard,
|
||||
"nonce": nonce,
|
||||
}
|
||||
|
||||
r := c.c.NewRequest("PUT", "/v1/sys/generate-root/update")
|
||||
r := c.c.NewRequest("PUT", path)
|
||||
if err := r.SetJSONBody(body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -72,6 +104,7 @@ type GenerateRootStatusResponse struct {
|
||||
Progress int
|
||||
Required int
|
||||
Complete bool
|
||||
EncodedToken string `json:"encoded_token"`
|
||||
EncodedRootToken string `json:"encoded_root_token"`
|
||||
PGPFingerprint string `json:"pgp_fingerprint"`
|
||||
}
|
||||
|
||||
@@ -125,6 +125,7 @@ type MountInput struct {
|
||||
Config MountConfigInput `json:"config" structs:"config"`
|
||||
Local bool `json:"local" structs:"local"`
|
||||
PluginName string `json:"plugin_name,omitempty" structs:"plugin_name"`
|
||||
SealWrap bool `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"`
|
||||
}
|
||||
|
||||
type MountConfigInput struct {
|
||||
@@ -132,7 +133,6 @@ type MountConfigInput struct {
|
||||
MaxLeaseTTL string `json:"max_lease_ttl" structs:"max_lease_ttl" mapstructure:"max_lease_ttl"`
|
||||
ForceNoCache bool `json:"force_no_cache" structs:"force_no_cache" mapstructure:"force_no_cache"`
|
||||
PluginName string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty" mapstructure:"plugin_name"`
|
||||
SealWrap bool `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"`
|
||||
}
|
||||
|
||||
type MountOutput struct {
|
||||
@@ -141,6 +141,7 @@ type MountOutput struct {
|
||||
Accessor string `json:"accessor" structs:"accessor"`
|
||||
Config MountConfigOutput `json:"config" structs:"config"`
|
||||
Local bool `json:"local" structs:"local"`
|
||||
SealWrap bool `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"`
|
||||
}
|
||||
|
||||
type MountConfigOutput struct {
|
||||
@@ -148,5 +149,4 @@ type MountConfigOutput struct {
|
||||
MaxLeaseTTL int `json:"max_lease_ttl" structs:"max_lease_ttl" mapstructure:"max_lease_ttl"`
|
||||
ForceNoCache bool `json:"force_no_cache" structs:"force_no_cache" mapstructure:"force_no_cache"`
|
||||
PluginName string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty" mapstructure:"plugin_name"`
|
||||
SealWrap bool `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"`
|
||||
}
|
||||
|
||||
@@ -49,12 +49,14 @@ func sealStatusRequest(c *Sys, r *Request) (*SealStatusResponse, error) {
|
||||
}
|
||||
|
||||
type SealStatusResponse struct {
|
||||
Sealed bool `json:"sealed"`
|
||||
T int `json:"t"`
|
||||
N int `json:"n"`
|
||||
Progress int `json:"progress"`
|
||||
Nonce string `json:"nonce"`
|
||||
Version string `json:"version"`
|
||||
ClusterName string `json:"cluster_name,omitempty"`
|
||||
ClusterID string `json:"cluster_id,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Sealed bool `json:"sealed"`
|
||||
T int `json:"t"`
|
||||
N int `json:"n"`
|
||||
Progress int `json:"progress"`
|
||||
Nonce string `json:"nonce"`
|
||||
Version string `json:"version"`
|
||||
ClusterName string `json:"cluster_name,omitempty"`
|
||||
ClusterID string `json:"cluster_id,omitempty"`
|
||||
RecoverySeal bool `json:"recovery_seal"`
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ func (f *AuditFormatter) FormatRequest(
|
||||
}
|
||||
|
||||
if !config.OmitTime {
|
||||
reqEntry.Time = time.Now().UTC().Format(time.RFC3339)
|
||||
reqEntry.Time = time.Now().UTC().Format(time.RFC3339Nano)
|
||||
}
|
||||
|
||||
return f.AuditFormatWriter.WriteRequest(w, reqEntry)
|
||||
@@ -242,12 +242,13 @@ func (f *AuditFormatter) FormatResponse(
|
||||
|
||||
// Cache and restore accessor in the response
|
||||
if resp != nil {
|
||||
var accessor, wrappedAccessor string
|
||||
var accessor, wrappedAccessor, wrappingAccessor string
|
||||
if !config.HMACAccessor && resp != nil && resp.Auth != nil && resp.Auth.Accessor != "" {
|
||||
accessor = resp.Auth.Accessor
|
||||
}
|
||||
if !config.HMACAccessor && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.WrappedAccessor != "" {
|
||||
wrappedAccessor = resp.WrapInfo.WrappedAccessor
|
||||
wrappingAccessor = resp.WrapInfo.Accessor
|
||||
}
|
||||
if err := Hash(salt, resp); err != nil {
|
||||
return err
|
||||
@@ -258,6 +259,9 @@ func (f *AuditFormatter) FormatResponse(
|
||||
if wrappedAccessor != "" {
|
||||
resp.WrapInfo.WrappedAccessor = wrappedAccessor
|
||||
}
|
||||
if wrappingAccessor != "" {
|
||||
resp.WrapInfo.Accessor = wrappingAccessor
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -301,6 +305,7 @@ func (f *AuditFormatter) FormatResponse(
|
||||
respWrapInfo = &AuditResponseWrapInfo{
|
||||
TTL: int(resp.WrapInfo.TTL / time.Second),
|
||||
Token: token,
|
||||
Accessor: resp.WrapInfo.Accessor,
|
||||
CreationTime: resp.WrapInfo.CreationTime.Format(time.RFC3339Nano),
|
||||
CreationPath: resp.WrapInfo.CreationPath,
|
||||
WrappedAccessor: resp.WrapInfo.WrappedAccessor,
|
||||
@@ -347,7 +352,7 @@ func (f *AuditFormatter) FormatResponse(
|
||||
}
|
||||
|
||||
if !config.OmitTime {
|
||||
respEntry.Time = time.Now().UTC().Format(time.RFC3339)
|
||||
respEntry.Time = time.Now().UTC().Format(time.RFC3339Nano)
|
||||
}
|
||||
|
||||
return f.AuditFormatWriter.WriteResponse(w, respEntry)
|
||||
@@ -412,6 +417,7 @@ type AuditSecret struct {
|
||||
type AuditResponseWrapInfo struct {
|
||||
TTL int `json:"ttl"`
|
||||
Token string `json:"token"`
|
||||
Accessor string `json:"accessor"`
|
||||
CreationTime string `json:"creation_time"`
|
||||
CreationPath string `json:"creation_path"`
|
||||
WrappedAccessor string `json:"wrapped_accessor,omitempty"`
|
||||
|
||||
@@ -93,6 +93,7 @@ func Hash(salter *salt.Salt, raw interface{}) error {
|
||||
}
|
||||
|
||||
s.Token = fn(s.Token)
|
||||
s.Accessor = fn(s.Accessor)
|
||||
|
||||
if s.WrappedAccessor != "" {
|
||||
s.WrappedAccessor = fn(s.WrappedAccessor)
|
||||
|
||||
@@ -148,6 +148,7 @@ func TestHash(t *testing.T) {
|
||||
WrapInfo: &wrapping.ResponseWrapInfo{
|
||||
TTL: 60,
|
||||
Token: "bar",
|
||||
Accessor: "flimflam",
|
||||
CreationTime: now,
|
||||
WrappedAccessor: "bar",
|
||||
},
|
||||
@@ -160,6 +161,7 @@ func TestHash(t *testing.T) {
|
||||
WrapInfo: &wrapping.ResponseWrapInfo{
|
||||
TTL: 60,
|
||||
Token: "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317",
|
||||
Accessor: "hmac-sha256:7c9c6fe666d0af73b3ebcfbfabe6885015558213208e6635ba104047b22f6390",
|
||||
CreationTime: now,
|
||||
WrappedAccessor: "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317",
|
||||
},
|
||||
@@ -206,6 +208,11 @@ func TestHash(t *testing.T) {
|
||||
if err := Hash(localSalt, tc.Input); err != nil {
|
||||
t.Fatalf("err: %s\n\n%s", err, input)
|
||||
}
|
||||
if _, ok := tc.Input.(*logical.Response); ok {
|
||||
if !reflect.DeepEqual(tc.Input.(*logical.Response).WrapInfo, tc.Output.(*logical.Response).WrapInfo) {
|
||||
t.Fatalf("bad:\nInput:\n%s\nTest case input:\n%#v\nTest case output\n%#v", input, tc.Input.(*logical.Response).WrapInfo, tc.Output.(*logical.Response).WrapInfo)
|
||||
}
|
||||
}
|
||||
if !reflect.DeepEqual(tc.Input, tc.Output) {
|
||||
t.Fatalf("bad:\nInput:\n%s\nTest case input:\n%#v\nTest case output\n%#v", input, tc.Input, tc.Output)
|
||||
}
|
||||
|
||||
@@ -75,7 +75,9 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mode = os.FileMode(m)
|
||||
if m != 0 {
|
||||
mode = os.FileMode(m)
|
||||
}
|
||||
}
|
||||
|
||||
b := &Backend{
|
||||
@@ -247,13 +249,15 @@ func (b *Backend) open() error {
|
||||
}
|
||||
|
||||
// Change the file mode in case the log file already existed. We special
|
||||
// case /dev/null since we can't chmod it
|
||||
// case /dev/null since we can't chmod it and bypass if the mode is zero
|
||||
switch b.path {
|
||||
case "/dev/null":
|
||||
default:
|
||||
err = os.Chmod(b.path, b.mode)
|
||||
if err != nil {
|
||||
return err
|
||||
if b.mode != 0 {
|
||||
err = os.Chmod(b.path, b.mode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -141,7 +141,7 @@ func testAccStepMapUserIdCidr(t *testing.T, cidr string) logicaltest.TestStep {
|
||||
func testAccLogin(t *testing.T, display string) logicaltest.TestStep {
|
||||
checkTTL := func(resp *logical.Response) error {
|
||||
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
|
||||
return fmt.Errorf("invalid TTL")
|
||||
return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -165,7 +165,7 @@ func testAccLogin(t *testing.T, display string) logicaltest.TestStep {
|
||||
func testAccLoginAppIDInPath(t *testing.T, display string) logicaltest.TestStep {
|
||||
checkTTL := func(resp *logical.Response) error {
|
||||
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
|
||||
return fmt.Errorf("invalid TTL")
|
||||
return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package approle
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
@@ -52,7 +51,7 @@ func (b *backend) pathLoginUpdateAliasLookahead(req *logical.Request, data *fram
|
||||
func (b *backend) pathLoginUpdate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
role, roleName, metadata, _, err := b.validateCredentials(req, data)
|
||||
if err != nil || role == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("failed to validate SecretID: %s", err)), nil
|
||||
return logical.ErrorResponse(fmt.Sprintf("failed to validate credentials: %v", err)), nil
|
||||
}
|
||||
|
||||
// Always include the role name, for later filtering
|
||||
@@ -68,20 +67,13 @@ func (b *backend) pathLoginUpdate(req *logical.Request, data *framework.FieldDat
|
||||
Policies: role.Policies,
|
||||
LeaseOptions: logical.LeaseOptions{
|
||||
Renewable: true,
|
||||
TTL: role.TokenTTL,
|
||||
},
|
||||
Alias: &logical.Alias{
|
||||
Name: role.RoleID,
|
||||
},
|
||||
}
|
||||
|
||||
// If 'Period' is set, use the value of 'Period' as the TTL.
|
||||
// Otherwise, set the normal TokenTTL.
|
||||
if role.Period > time.Duration(0) {
|
||||
auth.TTL = role.Period
|
||||
} else {
|
||||
auth.TTL = role.TokenTTL
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Auth: auth,
|
||||
}, nil
|
||||
@@ -94,8 +86,12 @@ func (b *backend) pathLoginRenew(req *logical.Request, data *framework.FieldData
|
||||
return nil, fmt.Errorf("failed to fetch role_name during renewal")
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
// Ensure that the Role still exists.
|
||||
role, err := b.roleEntry(req.Storage, roleName)
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate role %s during renewal:%s", roleName, err)
|
||||
}
|
||||
@@ -103,16 +99,12 @@ func (b *backend) pathLoginRenew(req *logical.Request, data *framework.FieldData
|
||||
return nil, fmt.Errorf("role %s does not exist during renewal", roleName)
|
||||
}
|
||||
|
||||
// If 'Period' is set on the Role, the token should never expire.
|
||||
// Replenish the TTL with 'Period's value.
|
||||
if role.Period > time.Duration(0) {
|
||||
// If 'Period' was updated after the token was issued,
|
||||
// token will bear the updated 'Period' value as its TTL.
|
||||
req.Auth.TTL = role.Period
|
||||
return &logical.Response{Auth: req.Auth}, nil
|
||||
} else {
|
||||
return framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data)
|
||||
resp, err := framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Auth.Period = role.Period
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
const pathLoginHelpSys = "Issue a token based on the credentials supplied"
|
||||
|
||||
@@ -57,6 +57,10 @@ type roleStorageEntry struct {
|
||||
// value is not modified on the role. If the `Period` in the role is modified,
|
||||
// a token will pick up the new value during its next renewal.
|
||||
Period time.Duration `json:"period" mapstructure:"period" structs:"period"`
|
||||
|
||||
// LowerCaseRoleName enforces the lower casing of role names for all the
|
||||
// roles that get created since this field was introduced.
|
||||
LowerCaseRoleName bool `json:"lower_case_role_name" mapstructure:"lower_case_role_name" structs:"lower_case_role_name"`
|
||||
}
|
||||
|
||||
// roleIDStorageEntry represents the reverse mapping from RoleID to Role
|
||||
@@ -509,10 +513,20 @@ the role.`,
|
||||
|
||||
// pathRoleExistenceCheck returns whether the role with the given name exists or not.
|
||||
func (b *backend) pathRoleExistenceCheck(req *logical.Request, data *framework.FieldData) (bool, error) {
|
||||
role, err := b.roleEntry(req.Storage, data.Get("role_name").(string))
|
||||
roleName := data.Get("role_name").(string)
|
||||
if roleName == "" {
|
||||
return false, fmt.Errorf("missing role_name")
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return role != nil, nil
|
||||
}
|
||||
|
||||
@@ -537,13 +551,21 @@ func (b *backend) pathRoleSecretIDList(req *logical.Request, data *framework.Fie
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
// Get the role entry
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("role %s does not exist", roleName)), nil
|
||||
return logical.ErrorResponse(fmt.Sprintf("role %q does not exist", roleName)), nil
|
||||
}
|
||||
|
||||
if role.LowerCaseRoleName {
|
||||
roleName = strings.ToLower(roleName)
|
||||
}
|
||||
|
||||
// Guard the list operation with an outer lock
|
||||
@@ -552,7 +574,7 @@ func (b *backend) pathRoleSecretIDList(req *logical.Request, data *framework.Fie
|
||||
|
||||
roleNameHMAC, err := createHMAC(role.HMACKey, roleName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HMAC of role_name: %s", err)
|
||||
return nil, fmt.Errorf("failed to create HMAC of role_name: %v", err)
|
||||
}
|
||||
|
||||
// Listing works one level at a time. Get the first level of data
|
||||
@@ -618,9 +640,8 @@ func validateRoleConstraints(role *roleStorageEntry) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setRoleEntry grabs a write lock and stores the options on an role into the
|
||||
// storage. Also creates a reverse index from the role's RoleID to the role
|
||||
// itself.
|
||||
// setRoleEntry persists the role and creates an index from roleID to role
|
||||
// name.
|
||||
func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleStorageEntry, previousRoleID string) error {
|
||||
if roleName == "" {
|
||||
return fmt.Errorf("missing role name")
|
||||
@@ -641,7 +662,7 @@ func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleSto
|
||||
return err
|
||||
}
|
||||
if entry == nil {
|
||||
return fmt.Errorf("failed to create storage entry for role %s", roleName)
|
||||
return fmt.Errorf("failed to create storage entry for role %q", roleName)
|
||||
}
|
||||
|
||||
// Check if the index from the role_id to role already exists
|
||||
@@ -680,7 +701,7 @@ func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleSto
|
||||
})
|
||||
}
|
||||
|
||||
// roleEntry grabs the read lock and fetches the options of an role from the storage
|
||||
// roleEntry reads the role from storage
|
||||
func (b *backend) roleEntry(s logical.Storage, roleName string) (*roleStorageEntry, error) {
|
||||
if roleName == "" {
|
||||
return nil, fmt.Errorf("missing role_name")
|
||||
@@ -688,11 +709,6 @@ func (b *backend) roleEntry(s logical.Storage, roleName string) (*roleStorageEnt
|
||||
|
||||
var role roleStorageEntry
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if entry, err := s.Get("role/" + strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if entry == nil {
|
||||
@@ -712,6 +728,10 @@ func (b *backend) pathRoleCreateUpdate(req *logical.Request, data *framework.Fie
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
// Check if the role already exists
|
||||
role, err := b.roleEntry(req.Storage, roleName)
|
||||
if err != nil {
|
||||
@@ -722,13 +742,14 @@ func (b *backend) pathRoleCreateUpdate(req *logical.Request, data *framework.Fie
|
||||
if role == nil && req.Operation == logical.CreateOperation {
|
||||
hmacKey, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create role_id: %s\n", err)
|
||||
return nil, fmt.Errorf("failed to create role_id: %v\n", err)
|
||||
}
|
||||
role = &roleStorageEntry{
|
||||
HMACKey: hmacKey,
|
||||
HMACKey: hmacKey,
|
||||
LowerCaseRoleName: true,
|
||||
}
|
||||
} else if role == nil {
|
||||
return nil, fmt.Errorf("role entry not found during update operation")
|
||||
return logical.ErrorResponse(fmt.Sprintf("invalid role name")), nil
|
||||
}
|
||||
|
||||
previousRoleID := role.RoleID
|
||||
@@ -737,12 +758,12 @@ func (b *backend) pathRoleCreateUpdate(req *logical.Request, data *framework.Fie
|
||||
} else if req.Operation == logical.CreateOperation {
|
||||
roleID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate role_id: %s\n", err)
|
||||
return nil, fmt.Errorf("failed to generate role_id: %v\n", err)
|
||||
}
|
||||
role.RoleID = roleID
|
||||
}
|
||||
if role.RoleID == "" {
|
||||
return logical.ErrorResponse("invalid role_id"), nil
|
||||
return logical.ErrorResponse("invalid role_id supplied, or failed to generate a role_id"), nil
|
||||
}
|
||||
|
||||
if bindSecretIDRaw, ok := data.GetOk("bind_secret_id"); ok {
|
||||
@@ -780,7 +801,7 @@ func (b *backend) pathRoleCreateUpdate(req *logical.Request, data *framework.Fie
|
||||
role.Period = time.Second * time.Duration(data.Get("period").(int))
|
||||
}
|
||||
if role.Period > b.System().MaxLeaseTTL() {
|
||||
return logical.ErrorResponse(fmt.Sprintf("'period' of '%s' is greater than the backend's maximum lease TTL of '%s'", role.Period.String(), b.System().MaxLeaseTTL().String())), nil
|
||||
return logical.ErrorResponse(fmt.Sprintf("period of %q is greater than the backend's maximum lease TTL of %q", role.Period.String(), b.System().MaxLeaseTTL().String())), nil
|
||||
}
|
||||
|
||||
if secretIDNumUsesRaw, ok := data.GetOk("secret_id_num_uses"); ok {
|
||||
@@ -843,32 +864,78 @@ func (b *backend) pathRoleRead(req *logical.Request, data *framework.FieldData)
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
lock := b.roleLock(roleName)
|
||||
lock.RLock()
|
||||
lockRelease := lock.RUnlock
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
lockRelease()
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
return nil, nil
|
||||
} else {
|
||||
// Convert the 'time.Duration' values to second.
|
||||
role.SecretIDTTL /= time.Second
|
||||
role.TokenTTL /= time.Second
|
||||
role.TokenMaxTTL /= time.Second
|
||||
role.Period /= time.Second
|
||||
|
||||
// Create a map of data to be returned and remove sensitive information from it
|
||||
data := structs.New(role).Map()
|
||||
delete(data, "role_id")
|
||||
delete(data, "hmac_key")
|
||||
|
||||
resp := &logical.Response{
|
||||
Data: data,
|
||||
}
|
||||
|
||||
if err := validateRoleConstraints(role); err != nil {
|
||||
resp.AddWarning("Role does not have any constraints set on it. Updates to this role will require a constraint to be set")
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
if role == nil {
|
||||
lockRelease()
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
respData := map[string]interface{}{
|
||||
"bind_secret_id": role.BindSecretID,
|
||||
"bound_cidr_list": role.BoundCIDRList,
|
||||
"period": role.Period / time.Second,
|
||||
"policies": role.Policies,
|
||||
"secret_id_num_uses": role.SecretIDNumUses,
|
||||
"secret_id_ttl": role.SecretIDTTL / time.Second,
|
||||
"token_max_ttl": role.TokenMaxTTL / time.Second,
|
||||
"token_num_uses": role.TokenNumUses,
|
||||
"token_ttl": role.TokenTTL / time.Second,
|
||||
}
|
||||
|
||||
resp := &logical.Response{
|
||||
Data: respData,
|
||||
}
|
||||
|
||||
if err := validateRoleConstraints(role); err != nil {
|
||||
resp.AddWarning("Role does not have any constraints set on it. Updates to this role will require a constraint to be set")
|
||||
}
|
||||
|
||||
// For sanity, verify that the index still exists. If the index is missing,
|
||||
// add one and return a warning so it can be reported.
|
||||
roleIDIndex, err := b.roleIDEntry(req.Storage, role.RoleID)
|
||||
if err != nil {
|
||||
lockRelease()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if roleIDIndex == nil {
|
||||
// Switch to a write lock
|
||||
lock.RUnlock()
|
||||
lock.Lock()
|
||||
lockRelease = lock.Unlock
|
||||
|
||||
// Check again if the index is missing
|
||||
roleIDIndex, err = b.roleIDEntry(req.Storage, role.RoleID)
|
||||
if err != nil {
|
||||
lockRelease()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if roleIDIndex == nil {
|
||||
// Create a new index
|
||||
err = b.setRoleIDEntry(req.Storage, role.RoleID, &roleIDStorageEntry{
|
||||
Name: roleName,
|
||||
})
|
||||
if err != nil {
|
||||
lockRelease()
|
||||
return nil, fmt.Errorf("failed to create secondary index for role_id %q: %v", role.RoleID, err)
|
||||
}
|
||||
resp.AddWarning("Role identifier was missing an index back to role name. A new index has been added. Please report this observation.")
|
||||
}
|
||||
}
|
||||
|
||||
lockRelease()
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// pathRoleDelete removes the role from the storage
|
||||
@@ -878,6 +945,10 @@ func (b *backend) pathRoleDelete(req *logical.Request, data *framework.FieldData
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -886,19 +957,14 @@ func (b *backend) pathRoleDelete(req *logical.Request, data *framework.FieldData
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Acquire the lock before deleting the secrets.
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
// Just before the role is deleted, remove all the SecretIDs issued as part of the role.
|
||||
if err = b.flushRoleSecrets(req.Storage, roleName, role.HMACKey); err != nil {
|
||||
return nil, fmt.Errorf("failed to invalidate the secrets belonging to role '%s': %s", roleName, err)
|
||||
return nil, fmt.Errorf("failed to invalidate the secrets belonging to role %q: %v", roleName, err)
|
||||
}
|
||||
|
||||
// Delete the reverse mapping from RoleID to the role
|
||||
if err = b.roleIDEntryDelete(req.Storage, role.RoleID); err != nil {
|
||||
return nil, fmt.Errorf("failed to delete the mapping from RoleID to role '%s': %s", roleName, err)
|
||||
return nil, fmt.Errorf("failed to delete the mapping from RoleID to role %q: %v", roleName, err)
|
||||
}
|
||||
|
||||
// After deleting the SecretIDs and the RoleID, delete the role itself
|
||||
@@ -921,25 +987,33 @@ func (b *backend) pathRoleSecretIDLookupUpdate(req *logical.Request, data *frame
|
||||
return logical.ErrorResponse("missing secret_id"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
// Fetch the role
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return nil, fmt.Errorf("role %s does not exist", roleName)
|
||||
return nil, fmt.Errorf("role %q does not exist", roleName)
|
||||
}
|
||||
|
||||
if role.LowerCaseRoleName {
|
||||
roleName = strings.ToLower(roleName)
|
||||
}
|
||||
|
||||
// Create the HMAC of the secret ID using the per-role HMAC key
|
||||
secretIDHMAC, err := createHMAC(role.HMACKey, secretID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HMAC of secret_id: %s", err)
|
||||
return nil, fmt.Errorf("failed to create HMAC of secret_id: %v", err)
|
||||
}
|
||||
|
||||
// Create the HMAC of the roleName using the per-role HMAC key
|
||||
roleNameHMAC, err := createHMAC(role.HMACKey, roleName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HMAC of role_name: %s", err)
|
||||
return nil, fmt.Errorf("failed to create HMAC of role_name: %v", err)
|
||||
}
|
||||
|
||||
// Create the index at which the secret_id would've been stored
|
||||
@@ -996,22 +1070,26 @@ func (b *backend) pathRoleSecretIDDestroyUpdateDelete(req *logical.Request, data
|
||||
return logical.ErrorResponse("missing secret_id"), nil
|
||||
}
|
||||
|
||||
roleLock := b.roleLock(roleName)
|
||||
roleLock.RLock()
|
||||
defer roleLock.RUnlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return nil, fmt.Errorf("role %s does not exist", roleName)
|
||||
return nil, fmt.Errorf("role %q does not exist", roleName)
|
||||
}
|
||||
|
||||
secretIDHMAC, err := createHMAC(role.HMACKey, secretID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HMAC of secret_id: %s", err)
|
||||
return nil, fmt.Errorf("failed to create HMAC of secret_id: %v", err)
|
||||
}
|
||||
|
||||
roleNameHMAC, err := createHMAC(role.HMACKey, roleName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HMAC of role_name: %s", err)
|
||||
return nil, fmt.Errorf("failed to create HMAC of role_name: %v", err)
|
||||
}
|
||||
|
||||
entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, secretIDHMAC)
|
||||
@@ -1036,7 +1114,7 @@ func (b *backend) pathRoleSecretIDDestroyUpdateDelete(req *logical.Request, data
|
||||
|
||||
// Delete the storage entry that corresponds to the SecretID
|
||||
if err := req.Storage.Delete(entryIndex); err != nil {
|
||||
return nil, fmt.Errorf("failed to delete SecretID: %s", err)
|
||||
return nil, fmt.Errorf("failed to delete secret_id: %v", err)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
@@ -1059,12 +1137,16 @@ func (b *backend) pathRoleSecretIDAccessorLookupUpdate(req *logical.Request, dat
|
||||
// Get the role details to fetch the RoleID and accessor to get
|
||||
// the HMACed SecretID.
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return nil, fmt.Errorf("role %s does not exist", roleName)
|
||||
return nil, fmt.Errorf("role %q does not exist", roleName)
|
||||
}
|
||||
|
||||
accessorEntry, err := b.secretIDAccessorEntry(req.Storage, secretIDAccessor)
|
||||
@@ -1072,12 +1154,12 @@ func (b *backend) pathRoleSecretIDAccessorLookupUpdate(req *logical.Request, dat
|
||||
return nil, err
|
||||
}
|
||||
if accessorEntry == nil {
|
||||
return nil, fmt.Errorf("failed to find accessor entry for secret_id_accessor:%s\n", secretIDAccessor)
|
||||
return nil, fmt.Errorf("failed to find accessor entry for secret_id_accessor: %q\n", secretIDAccessor)
|
||||
}
|
||||
|
||||
roleNameHMAC, err := createHMAC(role.HMACKey, roleName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HMAC of role_name: %s", err)
|
||||
return nil, fmt.Errorf("failed to create HMAC of role_name: %v", err)
|
||||
}
|
||||
|
||||
entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, accessorEntry.SecretIDHMAC)
|
||||
@@ -1105,7 +1187,7 @@ func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(req *logical.Reque
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return nil, fmt.Errorf("role %s does not exist", roleName)
|
||||
return nil, fmt.Errorf("role %q does not exist", roleName)
|
||||
}
|
||||
|
||||
accessorEntry, err := b.secretIDAccessorEntry(req.Storage, secretIDAccessor)
|
||||
@@ -1113,12 +1195,12 @@ func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(req *logical.Reque
|
||||
return nil, err
|
||||
}
|
||||
if accessorEntry == nil {
|
||||
return nil, fmt.Errorf("failed to find accessor entry for secret_id_accessor:%s\n", secretIDAccessor)
|
||||
return nil, fmt.Errorf("failed to find accessor entry for secret_id_accessor: %q\n", secretIDAccessor)
|
||||
}
|
||||
|
||||
roleNameHMAC, err := createHMAC(role.HMACKey, roleName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HMAC of role_name: %s", err)
|
||||
return nil, fmt.Errorf("failed to create HMAC of role_name: %v", err)
|
||||
}
|
||||
|
||||
entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, accessorEntry.SecretIDHMAC)
|
||||
@@ -1134,7 +1216,7 @@ func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(req *logical.Reque
|
||||
|
||||
// Delete the storage entry that corresponds to the SecretID
|
||||
if err := req.Storage.Delete(entryIndex); err != nil {
|
||||
return nil, fmt.Errorf("failed to delete SecretID: %s", err)
|
||||
return nil, fmt.Errorf("failed to delete secret_id: %v", err)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
@@ -1146,6 +1228,11 @@ func (b *backend) pathRoleBoundCIDRListUpdate(req *logical.Request, data *framew
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
// Re-read the role after grabbing the lock
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1154,11 +1241,6 @@ func (b *backend) pathRoleBoundCIDRListUpdate(req *logical.Request, data *framew
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role.BoundCIDRList = strings.TrimSpace(data.Get("bound_cidr_list").(string))
|
||||
if role.BoundCIDRList == "" {
|
||||
return logical.ErrorResponse("missing bound_cidr_list"), nil
|
||||
@@ -1167,7 +1249,7 @@ func (b *backend) pathRoleBoundCIDRListUpdate(req *logical.Request, data *framew
|
||||
if role.BoundCIDRList != "" {
|
||||
valid, err := cidrutil.ValidateCIDRListString(role.BoundCIDRList, ",")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate CIDR blocks: %q", err)
|
||||
return nil, fmt.Errorf("failed to validate CIDR blocks: %v", err)
|
||||
}
|
||||
if !valid {
|
||||
return logical.ErrorResponse("failed to validate CIDR blocks"), nil
|
||||
@@ -1183,6 +1265,10 @@ func (b *backend) pathRoleBoundCIDRListRead(req *logical.Request, data *framewor
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
@@ -1202,6 +1288,10 @@ func (b *backend) pathRoleBoundCIDRListDelete(req *logical.Request, data *framew
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1210,11 +1300,6 @@ func (b *backend) pathRoleBoundCIDRListDelete(req *logical.Request, data *framew
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
// Deleting a field implies setting the value to it's default value.
|
||||
role.BoundCIDRList = data.GetDefaultOrZero("bound_cidr_list").(string)
|
||||
|
||||
@@ -1227,6 +1312,10 @@ func (b *backend) pathRoleBindSecretIDUpdate(req *logical.Request, data *framewo
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1235,11 +1324,6 @@ func (b *backend) pathRoleBindSecretIDUpdate(req *logical.Request, data *framewo
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
if bindSecretIDRaw, ok := data.GetOk("bind_secret_id"); ok {
|
||||
role.BindSecretID = bindSecretIDRaw.(bool)
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
@@ -1254,6 +1338,10 @@ func (b *backend) pathRoleBindSecretIDRead(req *logical.Request, data *framework
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
@@ -1273,6 +1361,10 @@ func (b *backend) pathRoleBindSecretIDDelete(req *logical.Request, data *framewo
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1281,11 +1373,6 @@ func (b *backend) pathRoleBindSecretIDDelete(req *logical.Request, data *framewo
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
// Deleting a field implies setting the value to it's default value.
|
||||
role.BindSecretID = data.GetDefaultOrZero("bind_secret_id").(bool)
|
||||
|
||||
@@ -1298,6 +1385,10 @@ func (b *backend) pathRolePoliciesUpdate(req *logical.Request, data *framework.F
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1311,11 +1402,6 @@ func (b *backend) pathRolePoliciesUpdate(req *logical.Request, data *framework.F
|
||||
return logical.ErrorResponse("missing policies"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role.Policies = policyutil.ParsePolicies(policiesRaw)
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
@@ -1327,6 +1413,10 @@ func (b *backend) pathRolePoliciesRead(req *logical.Request, data *framework.Fie
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
@@ -1346,6 +1436,10 @@ func (b *backend) pathRolePoliciesDelete(req *logical.Request, data *framework.F
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1354,11 +1448,6 @@ func (b *backend) pathRolePoliciesDelete(req *logical.Request, data *framework.F
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role.Policies = []string{}
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
@@ -1370,6 +1459,10 @@ func (b *backend) pathRoleSecretIDNumUsesUpdate(req *logical.Request, data *fram
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1378,11 +1471,6 @@ func (b *backend) pathRoleSecretIDNumUsesUpdate(req *logical.Request, data *fram
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
if numUsesRaw, ok := data.GetOk("secret_id_num_uses"); ok {
|
||||
role.SecretIDNumUses = numUsesRaw.(int)
|
||||
if role.SecretIDNumUses < 0 {
|
||||
@@ -1400,6 +1488,10 @@ func (b *backend) pathRoleRoleIDUpdate(req *logical.Request, data *framework.Fie
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1408,11 +1500,6 @@ func (b *backend) pathRoleRoleIDUpdate(req *logical.Request, data *framework.Fie
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
previousRoleID := role.RoleID
|
||||
role.RoleID = data.Get("role_id").(string)
|
||||
if role.RoleID == "" {
|
||||
@@ -1428,6 +1515,10 @@ func (b *backend) pathRoleRoleIDRead(req *logical.Request, data *framework.Field
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
@@ -1447,6 +1538,10 @@ func (b *backend) pathRoleSecretIDNumUsesRead(req *logical.Request, data *framew
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
@@ -1466,6 +1561,10 @@ func (b *backend) pathRoleSecretIDNumUsesDelete(req *logical.Request, data *fram
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1474,11 +1573,6 @@ func (b *backend) pathRoleSecretIDNumUsesDelete(req *logical.Request, data *fram
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role.SecretIDNumUses = data.GetDefaultOrZero("secret_id_num_uses").(int)
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
@@ -1490,6 +1584,10 @@ func (b *backend) pathRoleSecretIDTTLUpdate(req *logical.Request, data *framewor
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1498,11 +1596,6 @@ func (b *backend) pathRoleSecretIDTTLUpdate(req *logical.Request, data *framewor
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
if secretIDTTLRaw, ok := data.GetOk("secret_id_ttl"); ok {
|
||||
role.SecretIDTTL = time.Second * time.Duration(secretIDTTLRaw.(int))
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
@@ -1517,6 +1610,10 @@ func (b *backend) pathRoleSecretIDTTLRead(req *logical.Request, data *framework.
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
@@ -1537,6 +1634,10 @@ func (b *backend) pathRoleSecretIDTTLDelete(req *logical.Request, data *framewor
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1545,11 +1646,6 @@ func (b *backend) pathRoleSecretIDTTLDelete(req *logical.Request, data *framewor
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role.SecretIDTTL = time.Second * time.Duration(data.GetDefaultOrZero("secret_id_ttl").(int))
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
@@ -1561,6 +1657,10 @@ func (b *backend) pathRolePeriodUpdate(req *logical.Request, data *framework.Fie
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1569,15 +1669,10 @@ func (b *backend) pathRolePeriodUpdate(req *logical.Request, data *framework.Fie
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
if periodRaw, ok := data.GetOk("period"); ok {
|
||||
role.Period = time.Second * time.Duration(periodRaw.(int))
|
||||
if role.Period > b.System().MaxLeaseTTL() {
|
||||
return logical.ErrorResponse(fmt.Sprintf("'period' of '%s' is greater than the backend's maximum lease TTL of '%s'", role.Period.String(), b.System().MaxLeaseTTL().String())), nil
|
||||
return logical.ErrorResponse(fmt.Sprintf("period of %q is greater than the backend's maximum lease TTL of %q", role.Period.String(), b.System().MaxLeaseTTL().String())), nil
|
||||
}
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
} else {
|
||||
@@ -1591,6 +1686,10 @@ func (b *backend) pathRolePeriodRead(req *logical.Request, data *framework.Field
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
@@ -1611,6 +1710,10 @@ func (b *backend) pathRolePeriodDelete(req *logical.Request, data *framework.Fie
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1619,11 +1722,6 @@ func (b *backend) pathRolePeriodDelete(req *logical.Request, data *framework.Fie
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role.Period = time.Second * time.Duration(data.GetDefaultOrZero("period").(int))
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
@@ -1635,6 +1733,10 @@ func (b *backend) pathRoleTokenNumUsesUpdate(req *logical.Request, data *framewo
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1643,11 +1745,6 @@ func (b *backend) pathRoleTokenNumUsesUpdate(req *logical.Request, data *framewo
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
if tokenNumUsesRaw, ok := data.GetOk("token_num_uses"); ok {
|
||||
role.TokenNumUses = tokenNumUsesRaw.(int)
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
@@ -1662,6 +1759,10 @@ func (b *backend) pathRoleTokenNumUsesRead(req *logical.Request, data *framework
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
@@ -1681,6 +1782,10 @@ func (b *backend) pathRoleTokenNumUsesDelete(req *logical.Request, data *framewo
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1689,11 +1794,6 @@ func (b *backend) pathRoleTokenNumUsesDelete(req *logical.Request, data *framewo
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role.TokenNumUses = data.GetDefaultOrZero("token_num_uses").(int)
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
@@ -1705,6 +1805,10 @@ func (b *backend) pathRoleTokenTTLUpdate(req *logical.Request, data *framework.F
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1713,11 +1817,6 @@ func (b *backend) pathRoleTokenTTLUpdate(req *logical.Request, data *framework.F
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
if tokenTTLRaw, ok := data.GetOk("token_ttl"); ok {
|
||||
role.TokenTTL = time.Second * time.Duration(tokenTTLRaw.(int))
|
||||
if role.TokenMaxTTL > time.Duration(0) && role.TokenTTL > role.TokenMaxTTL {
|
||||
@@ -1735,6 +1834,10 @@ func (b *backend) pathRoleTokenTTLRead(req *logical.Request, data *framework.Fie
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
@@ -1755,6 +1858,10 @@ func (b *backend) pathRoleTokenTTLDelete(req *logical.Request, data *framework.F
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1763,11 +1870,6 @@ func (b *backend) pathRoleTokenTTLDelete(req *logical.Request, data *framework.F
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role.TokenTTL = time.Second * time.Duration(data.GetDefaultOrZero("token_ttl").(int))
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
@@ -1779,6 +1881,10 @@ func (b *backend) pathRoleTokenMaxTTLUpdate(req *logical.Request, data *framewor
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1787,11 +1893,6 @@ func (b *backend) pathRoleTokenMaxTTLUpdate(req *logical.Request, data *framewor
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
if tokenMaxTTLRaw, ok := data.GetOk("token_max_ttl"); ok {
|
||||
role.TokenMaxTTL = time.Second * time.Duration(tokenMaxTTLRaw.(int))
|
||||
if role.TokenMaxTTL > time.Duration(0) && role.TokenTTL > role.TokenMaxTTL {
|
||||
@@ -1809,6 +1910,10 @@ func (b *backend) pathRoleTokenMaxTTLRead(req *logical.Request, data *framework.
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
|
||||
return nil, err
|
||||
} else if role == nil {
|
||||
@@ -1829,6 +1934,10 @@ func (b *backend) pathRoleTokenMaxTTLDelete(req *logical.Request, data *framewor
|
||||
return logical.ErrorResponse("missing role_name"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1837,11 +1946,6 @@ func (b *backend) pathRoleTokenMaxTTLDelete(req *logical.Request, data *framewor
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
role.TokenMaxTTL = time.Second * time.Duration(data.GetDefaultOrZero("token_max_ttl").(int))
|
||||
|
||||
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
|
||||
@@ -1850,7 +1954,7 @@ func (b *backend) pathRoleTokenMaxTTLDelete(req *logical.Request, data *framewor
|
||||
func (b *backend) pathRoleSecretIDUpdate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
secretID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate SecretID:%s", err)
|
||||
return nil, fmt.Errorf("failed to generate secret_id: %v", err)
|
||||
}
|
||||
return b.handleRoleSecretIDCommon(req, data, secretID)
|
||||
}
|
||||
@@ -1869,12 +1973,16 @@ func (b *backend) handleRoleSecretIDCommon(req *logical.Request, data *framework
|
||||
return logical.ErrorResponse("missing secret_id"), nil
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleName)
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("role %s does not exist", roleName)), nil
|
||||
return logical.ErrorResponse(fmt.Sprintf("role %q does not exist", roleName)), nil
|
||||
}
|
||||
|
||||
if !role.BindSecretID {
|
||||
@@ -1887,7 +1995,7 @@ func (b *backend) handleRoleSecretIDCommon(req *logical.Request, data *framework
|
||||
if cidrList != "" {
|
||||
valid, err := cidrutil.ValidateCIDRListString(cidrList, ",")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate CIDR blocks: %q", err)
|
||||
return nil, fmt.Errorf("failed to validate CIDR blocks: %v", err)
|
||||
}
|
||||
if !valid {
|
||||
return logical.ErrorResponse("failed to validate CIDR blocks"), nil
|
||||
@@ -1913,8 +2021,12 @@ func (b *backend) handleRoleSecretIDCommon(req *logical.Request, data *framework
|
||||
return logical.ErrorResponse(fmt.Sprintf("failed to parse metadata: %v", err)), nil
|
||||
}
|
||||
|
||||
if role.LowerCaseRoleName {
|
||||
roleName = strings.ToLower(roleName)
|
||||
}
|
||||
|
||||
if secretIDStorage, err = b.registerSecretIDEntry(req.Storage, roleName, secretID, role.HMACKey, secretIDStorage); err != nil {
|
||||
return nil, fmt.Errorf("failed to store SecretID: %s", err)
|
||||
return nil, fmt.Errorf("failed to store secret_id: %v", err)
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
|
||||
@@ -2,6 +2,7 @@ package approle
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -10,6 +11,248 @@ import (
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
func TestApprole_RoleNameLowerCasing(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
var roleID, secretID string
|
||||
|
||||
b, storage := createBackendWithStorage(t)
|
||||
|
||||
// Save a role with out LowerCaseRoleName set
|
||||
role := &roleStorageEntry{
|
||||
RoleID: "testroleid",
|
||||
HMACKey: "testhmackey",
|
||||
Policies: []string{"default"},
|
||||
BindSecretID: true,
|
||||
}
|
||||
err = b.setRoleEntry(storage, "testRoleName", role, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secretIDReq := &logical.Request{
|
||||
Path: "role/testRoleName/secret-id",
|
||||
Operation: logical.UpdateOperation,
|
||||
Storage: storage,
|
||||
}
|
||||
resp, err = b.HandleRequest(secretIDReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
|
||||
}
|
||||
secretID = resp.Data["secret_id"].(string)
|
||||
roleID = "testroleid"
|
||||
|
||||
// Regular login flow. This should succeed.
|
||||
resp, err = b.HandleRequest(&logical.Request{
|
||||
Path: "login",
|
||||
Operation: logical.UpdateOperation,
|
||||
Storage: storage,
|
||||
Data: map[string]interface{}{
|
||||
"role_id": roleID,
|
||||
"secret_id": secretID,
|
||||
},
|
||||
})
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
|
||||
}
|
||||
|
||||
// Lower case the role name when generating the secret id
|
||||
secretIDReq.Path = "role/testrolename/secret-id"
|
||||
resp, err = b.HandleRequest(secretIDReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
|
||||
}
|
||||
secretID = resp.Data["secret_id"].(string)
|
||||
|
||||
// Login should fail
|
||||
resp, err = b.HandleRequest(&logical.Request{
|
||||
Path: "login",
|
||||
Operation: logical.UpdateOperation,
|
||||
Storage: storage,
|
||||
Data: map[string]interface{}{
|
||||
"role_id": roleID,
|
||||
"secret_id": secretID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil || !resp.IsError() {
|
||||
t.Fatalf("expected an error")
|
||||
}
|
||||
|
||||
// Delete the role and create it again. This time don't directly persist
|
||||
// it, but route the request to the creation handler so that it sets the
|
||||
// LowerCaseRoleName to true.
|
||||
resp, err = b.HandleRequest(&logical.Request{
|
||||
Path: "role/testRoleName",
|
||||
Operation: logical.DeleteOperation,
|
||||
Storage: storage,
|
||||
})
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
|
||||
}
|
||||
|
||||
roleReq := &logical.Request{
|
||||
Path: "role/testRoleName",
|
||||
Operation: logical.CreateOperation,
|
||||
Storage: storage,
|
||||
Data: map[string]interface{}{
|
||||
"bind_secret_id": true,
|
||||
},
|
||||
}
|
||||
resp, err = b.HandleRequest(roleReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
|
||||
}
|
||||
|
||||
// Create secret id with lower cased role name
|
||||
resp, err = b.HandleRequest(&logical.Request{
|
||||
Path: "role/testrolename/secret-id",
|
||||
Operation: logical.UpdateOperation,
|
||||
Storage: storage,
|
||||
})
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
|
||||
}
|
||||
secretID = resp.Data["secret_id"].(string)
|
||||
|
||||
resp, err = b.HandleRequest(&logical.Request{
|
||||
Path: "role/testrolename/role-id",
|
||||
Operation: logical.ReadOperation,
|
||||
Storage: storage,
|
||||
})
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
|
||||
}
|
||||
roleID = resp.Data["role_id"].(string)
|
||||
|
||||
// Login should pass
|
||||
resp, err = b.HandleRequest(&logical.Request{
|
||||
Path: "login",
|
||||
Operation: logical.UpdateOperation,
|
||||
Storage: storage,
|
||||
Data: map[string]interface{}{
|
||||
"role_id": roleID,
|
||||
"secret_id": secretID,
|
||||
},
|
||||
})
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\nerr:%v", resp, err)
|
||||
}
|
||||
|
||||
// Lookup of secret ID should work in case-insensitive manner
|
||||
resp, err = b.HandleRequest(&logical.Request{
|
||||
Path: "role/testrolename/secret-id/lookup",
|
||||
Operation: logical.UpdateOperation,
|
||||
Storage: storage,
|
||||
Data: map[string]interface{}{
|
||||
"secret_id": secretID,
|
||||
},
|
||||
})
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("failed to lookup secret IDs")
|
||||
}
|
||||
|
||||
// Listing of secret IDs should work in case-insensitive manner
|
||||
resp, err = b.HandleRequest(&logical.Request{
|
||||
Path: "role/testrolename/secret-id",
|
||||
Operation: logical.ListOperation,
|
||||
Storage: storage,
|
||||
})
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\nerr: %v", resp, err)
|
||||
}
|
||||
|
||||
if len(resp.Data["keys"].([]string)) != 1 {
|
||||
t.Fatalf("failed to list secret IDs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppRole_RoleReadSetIndex(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
|
||||
b, storage := createBackendWithStorage(t)
|
||||
|
||||
roleReq := &logical.Request{
|
||||
Path: "role/testrole",
|
||||
Operation: logical.CreateOperation,
|
||||
Storage: storage,
|
||||
Data: map[string]interface{}{
|
||||
"bind_secret_id": true,
|
||||
},
|
||||
}
|
||||
|
||||
// Create a role
|
||||
resp, err = b.HandleRequest(roleReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\n err: %v\n", resp, err)
|
||||
}
|
||||
|
||||
roleIDReq := &logical.Request{
|
||||
Path: "role/testrole/role-id",
|
||||
Operation: logical.ReadOperation,
|
||||
Storage: storage,
|
||||
}
|
||||
|
||||
// Get the role ID
|
||||
resp, err = b.HandleRequest(roleIDReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\n err: %v\n", resp, err)
|
||||
}
|
||||
roleID := resp.Data["role_id"].(string)
|
||||
|
||||
// Delete the role ID index
|
||||
err = b.roleIDEntryDelete(storage, roleID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the role again. This should add the index and return a warning
|
||||
roleReq.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(roleReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\n err: %v\n", resp, err)
|
||||
}
|
||||
|
||||
// Check if the warning is being returned
|
||||
if !strings.Contains(resp.Warnings[0], "Role identifier was missing an index back to role name.") {
|
||||
t.Fatalf("bad: expected a warning in the response")
|
||||
}
|
||||
|
||||
roleIDIndex, err := b.roleIDEntry(storage, roleID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Check if the index has been successfully created
|
||||
if roleIDIndex == nil || roleIDIndex.Name != "testrole" {
|
||||
t.Fatalf("bad: expected role to have an index")
|
||||
}
|
||||
|
||||
roleReq.Operation = logical.UpdateOperation
|
||||
roleReq.Data = map[string]interface{}{
|
||||
"bind_secret_id": true,
|
||||
"policies": "default",
|
||||
}
|
||||
|
||||
// Check if updating and reading of roles work and that there are no lock
|
||||
// contentions dangling due to previous operation
|
||||
resp, err = b.HandleRequest(roleReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\n err: %v\n", resp, err)
|
||||
}
|
||||
roleReq.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(roleReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\n err: %v\n", resp, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppRole_CIDRSubset(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
|
||||
@@ -75,15 +75,19 @@ func (b *backend) validateRoleID(s logical.Storage, roleID string) (*roleStorage
|
||||
return nil, "", err
|
||||
}
|
||||
if roleIDIndex == nil {
|
||||
return nil, "", fmt.Errorf("failed to find secondary index for role_id %q\n", roleID)
|
||||
return nil, "", fmt.Errorf("invalid role_id %q\n", roleID)
|
||||
}
|
||||
|
||||
lock := b.roleLock(roleIDIndex.Name)
|
||||
lock.RLock()
|
||||
defer lock.RUnlock()
|
||||
|
||||
role, err := b.roleEntry(s, roleIDIndex.Name)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if role == nil {
|
||||
return nil, "", fmt.Errorf("role %q referred by the SecretID does not exist", roleIDIndex.Name)
|
||||
return nil, "", fmt.Errorf("role %q referred by the role_id %q does not exist anymore", roleIDIndex.Name, roleID)
|
||||
}
|
||||
|
||||
return role, roleIDIndex.Name, nil
|
||||
@@ -121,6 +125,10 @@ func (b *backend) validateCredentials(req *logical.Request, data *framework.Fiel
|
||||
return nil, "", metadata, "", fmt.Errorf("missing secret_id")
|
||||
}
|
||||
|
||||
if role.LowerCaseRoleName {
|
||||
roleName = strings.ToLower(roleName)
|
||||
}
|
||||
|
||||
// Check if the SecretID supplied is valid. If use limit was specified
|
||||
// on the SecretID, it will be decremented in this call.
|
||||
var valid bool
|
||||
|
||||
@@ -99,6 +99,9 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||
LocalStorage: []string{
|
||||
"whitelist/identity/",
|
||||
},
|
||||
SealWrapStorage: []string{
|
||||
"config/client",
|
||||
},
|
||||
},
|
||||
Paths: []*framework.Path{
|
||||
pathLogin(b),
|
||||
|
||||
@@ -1125,6 +1125,11 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndWhitelistIdentity(t *testing.
|
||||
t.Fatalf("instance ID not present in the response object")
|
||||
}
|
||||
|
||||
_, ok := resp.Auth.Metadata["nonce"]
|
||||
if ok {
|
||||
t.Fatalf("client nonce should not have been returned")
|
||||
}
|
||||
|
||||
loginInput["nonce"] = "changed-vault-client-nonce"
|
||||
// try to login again with changed nonce
|
||||
resp, err = b.HandleRequest(loginRequest)
|
||||
@@ -1159,7 +1164,9 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndWhitelistIdentity(t *testing.
|
||||
t.Fatalf("failed to delete whitelist identity")
|
||||
}
|
||||
|
||||
// Allow a fresh login.
|
||||
// Allow a fresh login without supplying the nonce
|
||||
delete(loginInput, "nonce")
|
||||
|
||||
resp, err = b.HandleRequest(loginRequest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1167,6 +1174,11 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndWhitelistIdentity(t *testing.
|
||||
if resp == nil || resp.Auth == nil || resp.IsError() {
|
||||
t.Fatalf("login attempt failed")
|
||||
}
|
||||
|
||||
_, ok = resp.Auth.Metadata["nonce"]
|
||||
if !ok {
|
||||
t.Fatalf("expected nonce to be returned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_pathStsConfig(t *testing.T) {
|
||||
|
||||
@@ -34,6 +34,11 @@ func GenerateLoginData(accessKey, secretKey, sessionToken, headerValue string) (
|
||||
return nil, fmt.Errorf("could not compile valid credential providers from static config, environment, shared, or instance metadata")
|
||||
}
|
||||
|
||||
_, err = creds.Get()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve credentials from credential chain: %v", err)
|
||||
}
|
||||
|
||||
// Use the credentials we've found to construct an STS session
|
||||
stsSession, err := session.NewSessionWithOptions(session.Options{
|
||||
Config: aws.Config{Credentials: creds},
|
||||
|
||||
@@ -643,7 +643,7 @@ func (b *backend) pathLoginUpdateEc2(
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
// Don't let subsequent login attempts to bypass in initial
|
||||
// Don't let subsequent login attempts to bypass the initial
|
||||
// intent of disabling reauthentication, despite the properties
|
||||
// of role getting updated. For example: Role has the value set
|
||||
// to 'false', a role-tag login sets the value to 'true', then
|
||||
@@ -693,7 +693,6 @@ func (b *backend) pathLoginUpdateEc2(
|
||||
|
||||
if roleTagResp != nil {
|
||||
// Role tag is enabled on the role.
|
||||
//
|
||||
|
||||
// Overwrite the policies with the ones returned from processing the role tag
|
||||
// If there are no policies on the role tag, policies on the role are inherited.
|
||||
@@ -777,8 +776,9 @@ func (b *backend) pathLoginUpdateEc2(
|
||||
},
|
||||
}
|
||||
|
||||
// Return the nonce only if reauthentication is allowed
|
||||
if !disallowReauthentication {
|
||||
// Return the nonce only if reauthentication is allowed and if the nonce
|
||||
// was not supplied by the user.
|
||||
if !disallowReauthentication && !clientNonceSupplied {
|
||||
// Echo the client nonce back. If nonce param was not supplied
|
||||
// to the endpoint at all (setting it to empty string does not
|
||||
// qualify here), callers should extract out the nonce from
|
||||
@@ -786,23 +786,15 @@ func (b *backend) pathLoginUpdateEc2(
|
||||
resp.Auth.Metadata["nonce"] = clientNonce
|
||||
}
|
||||
|
||||
if roleEntry.Period > time.Duration(0) {
|
||||
resp.Auth.TTL = roleEntry.Period
|
||||
} else {
|
||||
// Cap the TTL value.
|
||||
shortestTTL := b.System().DefaultLeaseTTL()
|
||||
if roleEntry.TTL > time.Duration(0) && roleEntry.TTL < shortestTTL {
|
||||
shortestTTL = roleEntry.TTL
|
||||
if roleEntry.MaxTTL > time.Duration(0) {
|
||||
// Cap TTL to shortestMaxTTL
|
||||
if resp.Auth.TTL > shortestMaxTTL {
|
||||
resp.AddWarning(fmt.Sprintf("Effective TTL of '%s' exceeded the effective max_ttl of '%s'; TTL value is capped accordingly", (resp.Auth.TTL / time.Second), (shortestMaxTTL / time.Second)))
|
||||
resp.Auth.TTL = shortestMaxTTL
|
||||
}
|
||||
if shortestMaxTTL < shortestTTL {
|
||||
resp.AddWarning(fmt.Sprintf("Effective ttl of %q exceeded the effective max_ttl of %q; ttl value is capped appropriately", (shortestTTL / time.Second).String(), (shortestMaxTTL / time.Second).String()))
|
||||
shortestTTL = shortestMaxTTL
|
||||
}
|
||||
resp.Auth.TTL = shortestTTL
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
|
||||
}
|
||||
|
||||
// handleRoleTagLogin is used to fetch the role tag of the instance and
|
||||
@@ -985,13 +977,12 @@ func (b *backend) pathLoginRenewIam(
|
||||
}
|
||||
}
|
||||
|
||||
// If 'Period' is set on the role, then the token should never expire.
|
||||
if roleEntry.Period > time.Duration(0) {
|
||||
req.Auth.TTL = roleEntry.Period
|
||||
return &logical.Response{Auth: req.Auth}, nil
|
||||
} else {
|
||||
return framework.LeaseExtend(roleEntry.TTL, roleEntry.MaxTTL, b.System())(req, data)
|
||||
resp, err := framework.LeaseExtend(roleEntry.TTL, roleEntry.MaxTTL, b.System())(req, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Auth.Period = roleEntry.Period
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathLoginRenewEc2(
|
||||
@@ -1072,24 +1063,12 @@ func (b *backend) pathLoginRenewEc2(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If 'Period' is set on the role, then the token should never expire. Role
|
||||
// tag does not have a 'Period' field. So, regarless of whether the token
|
||||
// was issued using a role login or a role tag login, the period set on the
|
||||
// role should take effect.
|
||||
if roleEntry.Period > time.Duration(0) {
|
||||
req.Auth.TTL = roleEntry.Period
|
||||
return &logical.Response{Auth: req.Auth}, nil
|
||||
} else {
|
||||
// Cap the TTL value
|
||||
shortestTTL := b.System().DefaultLeaseTTL()
|
||||
if roleEntry.TTL > time.Duration(0) && roleEntry.TTL < shortestTTL {
|
||||
shortestTTL = roleEntry.TTL
|
||||
}
|
||||
if shortestMaxTTL < shortestTTL {
|
||||
shortestTTL = shortestMaxTTL
|
||||
}
|
||||
return framework.LeaseExtend(shortestTTL, shortestMaxTTL, b.System())(req, data)
|
||||
resp, err := framework.LeaseExtend(roleEntry.TTL, shortestMaxTTL, b.System())(req, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Auth.Period = roleEntry.Period
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathLoginUpdateIam(
|
||||
@@ -1238,7 +1217,7 @@ func (b *backend) pathLoginUpdateIam(
|
||||
policies := roleEntry.Policies
|
||||
|
||||
inferredEntityType := ""
|
||||
inferredEntityId := ""
|
||||
inferredEntityID := ""
|
||||
if roleEntry.InferredEntityType == ec2EntityType {
|
||||
instance, err := b.validateInstance(req.Storage, entity.SessionInfo, roleEntry.InferredAWSRegion, callerID.Account)
|
||||
if err != nil {
|
||||
@@ -1264,7 +1243,7 @@ func (b *backend) pathLoginUpdateIam(
|
||||
}
|
||||
|
||||
inferredEntityType = ec2EntityType
|
||||
inferredEntityId = entity.SessionInfo
|
||||
inferredEntityID = entity.SessionInfo
|
||||
}
|
||||
|
||||
resp := &logical.Response{
|
||||
@@ -1277,7 +1256,7 @@ func (b *backend) pathLoginUpdateIam(
|
||||
"client_user_id": callerUniqueId,
|
||||
"auth_type": iamAuthType,
|
||||
"inferred_entity_type": inferredEntityType,
|
||||
"inferred_entity_id": inferredEntityId,
|
||||
"inferred_entity_id": inferredEntityID,
|
||||
"inferred_aws_region": roleEntry.InferredAWSRegion,
|
||||
"account_id": entity.AccountNumber,
|
||||
},
|
||||
@@ -1295,25 +1274,18 @@ func (b *backend) pathLoginUpdateIam(
|
||||
},
|
||||
}
|
||||
|
||||
if roleEntry.Period > time.Duration(0) {
|
||||
resp.Auth.TTL = roleEntry.Period
|
||||
} else {
|
||||
shortestTTL := b.System().DefaultLeaseTTL()
|
||||
if roleEntry.TTL > time.Duration(0) && roleEntry.TTL < shortestTTL {
|
||||
shortestTTL = roleEntry.TTL
|
||||
if roleEntry.MaxTTL > time.Duration(0) {
|
||||
// Cap maxTTL to the sysview's max TTL
|
||||
maxTTL := roleEntry.MaxTTL
|
||||
if maxTTL > b.System().MaxLeaseTTL() {
|
||||
maxTTL = b.System().MaxLeaseTTL()
|
||||
}
|
||||
|
||||
maxTTL := b.System().MaxLeaseTTL()
|
||||
if roleEntry.MaxTTL > time.Duration(0) && roleEntry.MaxTTL < maxTTL {
|
||||
maxTTL = roleEntry.MaxTTL
|
||||
// Cap TTL to MaxTTL
|
||||
if resp.Auth.TTL > maxTTL {
|
||||
resp.AddWarning(fmt.Sprintf("Effective TTL of '%s' exceeded the effective max_ttl of '%s'; TTL value is capped accordingly", (resp.Auth.TTL / time.Second), (maxTTL / time.Second)))
|
||||
resp.Auth.TTL = maxTTL
|
||||
}
|
||||
|
||||
if shortestTTL > maxTTL {
|
||||
resp.AddWarning(fmt.Sprintf("Effective TTL of %q exceeded the effective max_ttl of %q; TTL value is capped accordingly", (shortestTTL / time.Second).String(), (maxTTL / time.Second).String()))
|
||||
shortestTTL = maxTTL
|
||||
}
|
||||
|
||||
resp.Auth.TTL = shortestTTL
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
@@ -1333,11 +1305,11 @@ func hasValuesForEc2Auth(data *framework.FieldData) (bool, bool) {
|
||||
|
||||
func hasValuesForIamAuth(data *framework.FieldData) (bool, bool) {
|
||||
_, hasRequestMethod := data.GetOk("iam_http_request_method")
|
||||
_, hasRequestUrl := data.GetOk("iam_request_url")
|
||||
_, hasRequestURL := data.GetOk("iam_request_url")
|
||||
_, hasRequestBody := data.GetOk("iam_request_body")
|
||||
_, hasRequestHeaders := data.GetOk("iam_request_headers")
|
||||
return (hasRequestMethod && hasRequestUrl && hasRequestBody && hasRequestHeaders),
|
||||
(hasRequestMethod || hasRequestUrl || hasRequestBody || hasRequestHeaders)
|
||||
return (hasRequestMethod && hasRequestURL && hasRequestBody && hasRequestHeaders),
|
||||
(hasRequestMethod || hasRequestURL || hasRequestBody || hasRequestHeaders)
|
||||
}
|
||||
|
||||
func parseIamArn(iamArn string) (*iamEntity, error) {
|
||||
|
||||
@@ -663,6 +663,10 @@ func (b *backend) pathRoleCreateUpdate(
|
||||
roleEntry.AllowInstanceMigration = data.Get("allow_instance_migration").(bool)
|
||||
}
|
||||
|
||||
if roleEntry.AllowInstanceMigration && roleEntry.DisallowReauthentication {
|
||||
return logical.ErrorResponse("cannot specify both disallow_reauthentication=true and allow_instance_migration=true"), nil
|
||||
}
|
||||
|
||||
var resp logical.Response
|
||||
|
||||
ttlRaw, ok := data.GetOk("ttl")
|
||||
|
||||
@@ -124,6 +124,10 @@ func (b *backend) pathRoleTagUpdate(
|
||||
resp.AddWarning("Role does not allow instance migration. Login will not be allowed with this tag unless the role value is updated.")
|
||||
}
|
||||
|
||||
if disallowReauthentication && allowInstanceMigration {
|
||||
return logical.ErrorResponse("cannot set both disallow_reauthentication and allow_instance_migration"), nil
|
||||
}
|
||||
|
||||
// max_ttl for the role tag should be less than the max_ttl set on the role.
|
||||
maxTTL := time.Duration(data.Get("max_ttl").(int)) * time.Second
|
||||
|
||||
|
||||
@@ -66,12 +66,25 @@ func TestBackend_pathRoleEc2(t *testing.T) {
|
||||
Data: data,
|
||||
Storage: storage,
|
||||
})
|
||||
if resp != nil && resp.IsError() {
|
||||
t.Fatalf("failed to create role: %s", resp.Data["error"])
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil || !resp.IsError() {
|
||||
t.Fatalf("expected failure to create role with both allow_instance_migration true and disallow_reauthentication true")
|
||||
}
|
||||
data["disallow_reauthentication"] = false
|
||||
resp, err = b.HandleRequest(&logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "role/ami-abcd123",
|
||||
Data: data,
|
||||
Storage: storage,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp != nil && resp.IsError() {
|
||||
t.Fatalf("failure to update role: %v", resp.Data["error"])
|
||||
}
|
||||
resp, err = b.HandleRequest(&logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "role/ami-abcd123",
|
||||
@@ -80,8 +93,12 @@ func TestBackend_pathRoleEc2(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !resp.Data["allow_instance_migration"].(bool) || !resp.Data["disallow_reauthentication"].(bool) {
|
||||
t.Fatal("bad: expected:true got:false\n")
|
||||
if !resp.Data["allow_instance_migration"].(bool) {
|
||||
t.Fatal("bad: expected allow_instance_migration:true got:false\n")
|
||||
}
|
||||
|
||||
if resp.Data["disallow_reauthentication"].(bool) {
|
||||
t.Fatal("bad: expected disallow_reauthentication: false got:true\n")
|
||||
}
|
||||
|
||||
// add another entry, to test listing of role entries
|
||||
@@ -529,7 +546,7 @@ func TestAwsEc2_RoleCrud(t *testing.T) {
|
||||
"ttl": "10m",
|
||||
"max_ttl": "20m",
|
||||
"policies": "testpolicy1,testpolicy2",
|
||||
"disallow_reauthentication": true,
|
||||
"disallow_reauthentication": false,
|
||||
"hmac_key": "testhmackey",
|
||||
"period": "1m",
|
||||
}
|
||||
@@ -567,7 +584,7 @@ func TestAwsEc2_RoleCrud(t *testing.T) {
|
||||
"ttl": time.Duration(600),
|
||||
"max_ttl": time.Duration(1200),
|
||||
"policies": []string{"testpolicy1", "testpolicy2"},
|
||||
"disallow_reauthentication": true,
|
||||
"disallow_reauthentication": false,
|
||||
"period": time.Duration(60),
|
||||
}
|
||||
|
||||
|
||||
@@ -587,7 +587,7 @@ func TestBackend_CRLs(t *testing.T) {
|
||||
func testFactory(t *testing.T) logical.Backend {
|
||||
b, err := Factory(&logical.BackendConfig{
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: 300 * time.Second,
|
||||
DefaultLeaseTTLVal: 1000 * time.Second,
|
||||
MaxLeaseTTLVal: 1800 * time.Second,
|
||||
},
|
||||
StorageView: &logical.InmemStorage{},
|
||||
@@ -619,9 +619,9 @@ func TestBackend_CertWrites(t *testing.T) {
|
||||
tc := logicaltest.TestCase{
|
||||
Backend: testFactory(t),
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepCert(t, "aaa", ca1, "foo", "", false),
|
||||
testAccStepCert(t, "bbb", ca2, "foo", "", false),
|
||||
testAccStepCert(t, "ccc", ca3, "foo", "", true),
|
||||
testAccStepCert(t, "aaa", ca1, "foo", "", "", false),
|
||||
testAccStepCert(t, "bbb", ca2, "foo", "", "", false),
|
||||
testAccStepCert(t, "ccc", ca3, "foo", "", "", true),
|
||||
},
|
||||
}
|
||||
tc.Steps = append(tc.Steps, testAccStepListCerts(t, []string{"aaa", "bbb"})...)
|
||||
@@ -642,16 +642,18 @@ func TestBackend_basic_CA(t *testing.T) {
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: testFactory(t),
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepCert(t, "web", ca, "foo", "", false),
|
||||
testAccStepCert(t, "web", ca, "foo", "", "", false),
|
||||
testAccStepLogin(t, connState),
|
||||
testAccStepCertLease(t, "web", ca, "foo"),
|
||||
testAccStepCertTTL(t, "web", ca, "foo"),
|
||||
testAccStepLogin(t, connState),
|
||||
testAccStepCertMaxTTL(t, "web", ca, "foo"),
|
||||
testAccStepLogin(t, connState),
|
||||
testAccStepCertNoLease(t, "web", ca, "foo"),
|
||||
testAccStepLoginDefaultLease(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "*.example.com", false),
|
||||
testAccStepCert(t, "web", ca, "foo", "*.example.com", "", false),
|
||||
testAccStepLogin(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "*.invalid.com", false),
|
||||
testAccStepCert(t, "web", ca, "foo", "*.invalid.com", "", false),
|
||||
testAccStepLoginInvalid(t, connState),
|
||||
},
|
||||
})
|
||||
@@ -700,11 +702,68 @@ func TestBackend_basic_singleCert(t *testing.T) {
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: testFactory(t),
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepCert(t, "web", ca, "foo", "", false),
|
||||
testAccStepCert(t, "web", ca, "foo", "", "", false),
|
||||
testAccStepLogin(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "example.com", false),
|
||||
testAccStepCert(t, "web", ca, "foo", "example.com", "", false),
|
||||
testAccStepLogin(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "invalid", false),
|
||||
testAccStepCert(t, "web", ca, "foo", "invalid", "", false),
|
||||
testAccStepLoginInvalid(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "", "1.2.3.4:invalid", false),
|
||||
testAccStepLoginInvalid(t, connState),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Test a self-signed client with custom extensions (root CA) that is trusted
|
||||
func TestBackend_extensions_singleCert(t *testing.T) {
|
||||
connState, err := testConnState(
|
||||
"test-fixtures/root/rootcawextcert.pem",
|
||||
"test-fixtures/root/rootcawextkey.pem",
|
||||
"test-fixtures/root/rootcacert.pem",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("error testing connection state: %v", err)
|
||||
}
|
||||
ca, err := ioutil.ReadFile("test-fixtures/root/rootcacert.pem")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: testFactory(t),
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepCert(t, "web", ca, "foo", "", "2.1.1.1:A UTF8String Extension", false),
|
||||
testAccStepLogin(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "", "2.1.1.1:*,2.1.1.2:A UTF8*", false),
|
||||
testAccStepLogin(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "", "1.2.3.45:*", false),
|
||||
testAccStepLoginInvalid(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "", "2.1.1.1:The Wrong Value", false),
|
||||
testAccStepLoginInvalid(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "", "2.1.1.1:*,2.1.1.2:The Wrong Value", false),
|
||||
testAccStepLoginInvalid(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "", "2.1.1.1:", false),
|
||||
testAccStepLoginInvalid(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "", "2.1.1.1:,2.1.1.2:*", false),
|
||||
testAccStepLoginInvalid(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "example.com", "2.1.1.1:A UTF8String Extension", false),
|
||||
testAccStepLogin(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "example.com", "2.1.1.1:*,2.1.1.2:A UTF8*", false),
|
||||
testAccStepLogin(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "example.com", "1.2.3.45:*", false),
|
||||
testAccStepLoginInvalid(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "example.com", "2.1.1.1:The Wrong Value", false),
|
||||
testAccStepLoginInvalid(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "example.com", "2.1.1.1:*,2.1.1.2:The Wrong Value", false),
|
||||
testAccStepLoginInvalid(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "invalid", "2.1.1.1:A UTF8String Extension", false),
|
||||
testAccStepLoginInvalid(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "invalid", "2.1.1.1:*,2.1.1.2:A UTF8*", false),
|
||||
testAccStepLoginInvalid(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "invalid", "1.2.3.45:*", false),
|
||||
testAccStepLoginInvalid(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "invalid", "2.1.1.1:The Wrong Value", false),
|
||||
testAccStepLoginInvalid(t, connState),
|
||||
testAccStepCert(t, "web", ca, "foo", "invalid", "2.1.1.1:*,2.1.1.2:The Wrong Value", false),
|
||||
testAccStepLoginInvalid(t, connState),
|
||||
},
|
||||
})
|
||||
@@ -724,9 +783,9 @@ func TestBackend_mixed_constraints(t *testing.T) {
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: testFactory(t),
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepCert(t, "1unconstrained", ca, "foo", "", false),
|
||||
testAccStepCert(t, "2matching", ca, "foo", "*.example.com,whatever", false),
|
||||
testAccStepCert(t, "3invalid", ca, "foo", "invalid", false),
|
||||
testAccStepCert(t, "1unconstrained", ca, "foo", "", "", false),
|
||||
testAccStepCert(t, "2matching", ca, "foo", "*.example.com,whatever", "", false),
|
||||
testAccStepCert(t, "3invalid", ca, "foo", "invalid", "", false),
|
||||
testAccStepLogin(t, connState),
|
||||
// Assumes CertEntries are processed in alphabetical order (due to store.List), so we only match 2matching if 1unconstrained doesn't match
|
||||
testAccStepLoginWithName(t, connState, "2matching"),
|
||||
@@ -826,7 +885,7 @@ func testAccStepLoginDefaultLease(t *testing.T, connState tls.ConnectionState) l
|
||||
Unauthenticated: true,
|
||||
ConnState: &connState,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp.Auth.TTL != 300*time.Second {
|
||||
if resp.Auth.TTL != 1000*time.Second {
|
||||
t.Fatalf("bad lease length: %#v", resp.Auth)
|
||||
}
|
||||
|
||||
@@ -906,17 +965,18 @@ func testAccStepListCerts(
|
||||
}
|
||||
|
||||
func testAccStepCert(
|
||||
t *testing.T, name string, cert []byte, policies string, allowedNames string, expectError bool) logicaltest.TestStep {
|
||||
t *testing.T, name string, cert []byte, policies string, allowedNames string, requiredExtensions string, expectError bool) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "certs/" + name,
|
||||
ErrorOk: expectError,
|
||||
Data: map[string]interface{}{
|
||||
"certificate": string(cert),
|
||||
"policies": policies,
|
||||
"display_name": name,
|
||||
"allowed_names": allowedNames,
|
||||
"lease": 1000,
|
||||
"certificate": string(cert),
|
||||
"policies": policies,
|
||||
"display_name": name,
|
||||
"allowed_names": allowedNames,
|
||||
"required_extensions": requiredExtensions,
|
||||
"lease": 1000,
|
||||
},
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp == nil && expectError {
|
||||
@@ -955,6 +1015,21 @@ func testAccStepCertTTL(
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepCertMaxTTL(
|
||||
t *testing.T, name string, cert []byte, policies string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "certs/" + name,
|
||||
Data: map[string]interface{}{
|
||||
"certificate": string(cert),
|
||||
"policies": policies,
|
||||
"display_name": name,
|
||||
"ttl": "1000s",
|
||||
"max_ttl": "1200s",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepCertNoLease(
|
||||
t *testing.T, name string, cert []byte, policies string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
|
||||
@@ -45,6 +45,13 @@ Must be x509 PEM encoded.`,
|
||||
At least one must exist in either the Common Name or SANs. Supports globbing.`,
|
||||
},
|
||||
|
||||
"required_extensions": &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: `A comma-separated string or array of extensions
|
||||
formatted as "oid:value". Expects the extension value to be some type of ASN1 encoded string.
|
||||
All values much match. Supports globbing on "value".`,
|
||||
},
|
||||
|
||||
"display_name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: `The display name to use for clients using this
|
||||
@@ -67,6 +74,19 @@ seconds. Defaults to system/backend default TTL.`,
|
||||
Description: `TTL for tokens issued by this backend.
|
||||
Defaults to system/backend default TTL time.`,
|
||||
},
|
||||
"max_ttl": &framework.FieldSchema{
|
||||
Type: framework.TypeDurationSecond,
|
||||
Description: `Duration in either an integer number of seconds (3600) or
|
||||
an integer time unit (60m) after which the
|
||||
issued token can no longer be renewed.`,
|
||||
},
|
||||
"period": &framework.FieldSchema{
|
||||
Type: framework.TypeDurationSecond,
|
||||
Description: `If set, indicates that the token generated using this role
|
||||
should never expire. The token should be renewed within the
|
||||
duration specified by this value. At each renewal, the token's
|
||||
TTL will be set to the value of this parameter.`,
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
@@ -124,17 +144,14 @@ func (b *backend) pathCertRead(
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
duration := cert.TTL
|
||||
if duration == 0 {
|
||||
duration = b.System().DefaultLeaseTTL()
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"certificate": cert.Certificate,
|
||||
"display_name": cert.DisplayName,
|
||||
"policies": cert.Policies,
|
||||
"ttl": duration / time.Second,
|
||||
"ttl": cert.TTL / time.Second,
|
||||
"max_ttl": cert.MaxTTL / time.Second,
|
||||
"period": cert.Period / time.Second,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
@@ -146,6 +163,48 @@ func (b *backend) pathCertWrite(
|
||||
displayName := d.Get("display_name").(string)
|
||||
policies := policyutil.ParsePolicies(d.Get("policies"))
|
||||
allowedNames := d.Get("allowed_names").([]string)
|
||||
requiredExtensions := d.Get("required_extensions").([]string)
|
||||
|
||||
var resp logical.Response
|
||||
|
||||
// Parse the ttl (or lease duration)
|
||||
systemDefaultTTL := b.System().DefaultLeaseTTL()
|
||||
ttl := time.Duration(d.Get("ttl").(int)) * time.Second
|
||||
if ttl == 0 {
|
||||
ttl = time.Duration(d.Get("lease").(int)) * time.Second
|
||||
}
|
||||
if ttl > systemDefaultTTL {
|
||||
resp.AddWarning(fmt.Sprintf("Given ttl of %d seconds is greater than current mount/system default of %d seconds", ttl/time.Second, systemDefaultTTL/time.Second))
|
||||
}
|
||||
|
||||
if ttl < time.Duration(0) {
|
||||
return logical.ErrorResponse("ttl cannot be negative"), nil
|
||||
}
|
||||
|
||||
// Parse max_ttl
|
||||
systemMaxTTL := b.System().MaxLeaseTTL()
|
||||
maxTTL := time.Duration(d.Get("max_ttl").(int)) * time.Second
|
||||
if maxTTL > systemMaxTTL {
|
||||
resp.AddWarning(fmt.Sprintf("Given max_ttl of %d seconds is greater than current mount/system default of %d seconds", maxTTL/time.Second, systemMaxTTL/time.Second))
|
||||
}
|
||||
|
||||
if maxTTL < time.Duration(0) {
|
||||
return logical.ErrorResponse("max_ttl cannot be negative"), nil
|
||||
}
|
||||
|
||||
if maxTTL != 0 && ttl > maxTTL {
|
||||
return logical.ErrorResponse("ttl should be shorter than max_ttl"), nil
|
||||
}
|
||||
|
||||
// Parse period
|
||||
period := time.Duration(d.Get("period").(int)) * time.Second
|
||||
if period > systemMaxTTL {
|
||||
resp.AddWarning(fmt.Sprintf("Given period of %d seconds is greater than the backend's maximum TTL of %d seconds", period/time.Second, systemMaxTTL/time.Second))
|
||||
}
|
||||
|
||||
if period < time.Duration(0) {
|
||||
return logical.ErrorResponse("period cannot be negative"), nil
|
||||
}
|
||||
|
||||
// Default the display name to the certificate name if not given
|
||||
if displayName == "" {
|
||||
@@ -172,24 +231,15 @@ func (b *backend) pathCertWrite(
|
||||
}
|
||||
|
||||
certEntry := &CertEntry{
|
||||
Name: name,
|
||||
Certificate: certificate,
|
||||
DisplayName: displayName,
|
||||
Policies: policies,
|
||||
AllowedNames: allowedNames,
|
||||
}
|
||||
|
||||
// Parse the lease duration or default to backend/system default
|
||||
maxTTL := b.System().MaxLeaseTTL()
|
||||
ttl := time.Duration(d.Get("ttl").(int)) * time.Second
|
||||
if ttl == time.Duration(0) {
|
||||
ttl = time.Second * time.Duration(d.Get("lease").(int))
|
||||
}
|
||||
if ttl > maxTTL {
|
||||
return logical.ErrorResponse(fmt.Sprintf("Given TTL of %d seconds greater than current mount/system default of %d seconds", ttl/time.Second, maxTTL/time.Second)), nil
|
||||
}
|
||||
if ttl > time.Duration(0) {
|
||||
certEntry.TTL = ttl
|
||||
Name: name,
|
||||
Certificate: certificate,
|
||||
DisplayName: displayName,
|
||||
Policies: policies,
|
||||
AllowedNames: allowedNames,
|
||||
RequiredExtensions: requiredExtensions,
|
||||
TTL: ttl,
|
||||
MaxTTL: maxTTL,
|
||||
Period: period,
|
||||
}
|
||||
|
||||
// Store it
|
||||
@@ -200,16 +250,24 @@ func (b *backend) pathCertWrite(
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
|
||||
if len(resp.Warnings) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
type CertEntry struct {
|
||||
Name string
|
||||
Certificate string
|
||||
DisplayName string
|
||||
Policies []string
|
||||
TTL time.Duration
|
||||
AllowedNames []string
|
||||
Name string
|
||||
Certificate string
|
||||
DisplayName string
|
||||
Policies []string
|
||||
TTL time.Duration
|
||||
MaxTTL time.Duration
|
||||
Period time.Duration
|
||||
AllowedNames []string
|
||||
RequiredExtensions []string
|
||||
}
|
||||
|
||||
const pathCertHelpSyn = `
|
||||
|
||||
@@ -4,11 +4,13 @@ import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
"github.com/hashicorp/vault/helper/policyutil"
|
||||
@@ -84,9 +86,9 @@ func (b *backend) pathLogin(
|
||||
skid := base64.StdEncoding.EncodeToString(clientCerts[0].SubjectKeyId)
|
||||
akid := base64.StdEncoding.EncodeToString(clientCerts[0].AuthorityKeyId)
|
||||
|
||||
// Generate a response
|
||||
resp := &logical.Response{
|
||||
Auth: &logical.Auth{
|
||||
Period: matched.Entry.Period,
|
||||
InternalData: map[string]interface{}{
|
||||
"subject_key_id": skid,
|
||||
"authority_key_id": akid,
|
||||
@@ -108,6 +110,22 @@ func (b *backend) pathLogin(
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if matched.Entry.MaxTTL > time.Duration(0) {
|
||||
// Cap maxTTL to the sysview's max TTL
|
||||
maxTTL := matched.Entry.MaxTTL
|
||||
if maxTTL > b.System().MaxLeaseTTL() {
|
||||
maxTTL = b.System().MaxLeaseTTL()
|
||||
}
|
||||
|
||||
// Cap TTL to MaxTTL
|
||||
if resp.Auth.TTL > maxTTL {
|
||||
resp.AddWarning(fmt.Sprintf("Effective TTL of '%s' exceeded the effective max_ttl of '%s'; TTL value is capped accordingly", (resp.Auth.TTL / time.Second), (maxTTL / time.Second)))
|
||||
resp.Auth.TTL = maxTTL
|
||||
}
|
||||
}
|
||||
|
||||
// Generate a response
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -134,7 +152,7 @@ func (b *backend) pathLoginRenew(
|
||||
|
||||
clientCerts := req.Connection.ConnState.PeerCertificates
|
||||
if len(clientCerts) == 0 {
|
||||
return nil, fmt.Errorf("no client certificate found")
|
||||
return logical.ErrorResponse("no client certificate found"), nil
|
||||
}
|
||||
skid := base64.StdEncoding.EncodeToString(clientCerts[0].SubjectKeyId)
|
||||
akid := base64.StdEncoding.EncodeToString(clientCerts[0].AuthorityKeyId)
|
||||
@@ -160,7 +178,12 @@ func (b *backend) pathLoginRenew(
|
||||
return nil, fmt.Errorf("policies have changed, not renewing")
|
||||
}
|
||||
|
||||
return framework.LeaseExtend(cert.TTL, 0, b.System())(req, d)
|
||||
resp, err := framework.LeaseExtend(cert.TTL, cert.MaxTTL, b.System())(req, d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Auth.Period = cert.Period
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (b *backend) verifyCredentials(req *logical.Request, d *framework.FieldData) (*ParsedCert, *logical.Response, error) {
|
||||
@@ -237,28 +260,70 @@ func (b *backend) verifyCredentials(req *logical.Request, d *framework.FieldData
|
||||
}
|
||||
|
||||
func (b *backend) matchesConstraints(clientCert *x509.Certificate, trustedChain []*x509.Certificate, config *ParsedCert) bool {
|
||||
return !b.checkForChainInCRLs(trustedChain) &&
|
||||
b.matchesNames(clientCert, config) &&
|
||||
b.matchesCertificateExtenions(clientCert, config)
|
||||
}
|
||||
|
||||
// matchesNames verifies that the certificate matches at least one configured
|
||||
// allowed name
|
||||
func (b *backend) matchesNames(clientCert *x509.Certificate, config *ParsedCert) bool {
|
||||
// Default behavior (no names) is to allow all names
|
||||
nameMatched := len(config.Entry.AllowedNames) == 0
|
||||
if len(config.Entry.AllowedNames) == 0 {
|
||||
return true
|
||||
}
|
||||
// At least one pattern must match at least one name if any patterns are specified
|
||||
for _, allowedName := range config.Entry.AllowedNames {
|
||||
if glob.Glob(allowedName, clientCert.Subject.CommonName) {
|
||||
nameMatched = true
|
||||
return true
|
||||
}
|
||||
|
||||
for _, name := range clientCert.DNSNames {
|
||||
if glob.Glob(allowedName, name) {
|
||||
nameMatched = true
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for _, name := range clientCert.EmailAddresses {
|
||||
if glob.Glob(allowedName, name) {
|
||||
nameMatched = true
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return !b.checkForChainInCRLs(trustedChain) && nameMatched
|
||||
// matchesCertificateExtenions verifies that the certificate matches configured
|
||||
// required extensions
|
||||
func (b *backend) matchesCertificateExtenions(clientCert *x509.Certificate, config *ParsedCert) bool {
|
||||
// If no required extensions, nothing to check here
|
||||
if len(config.Entry.RequiredExtensions) == 0 {
|
||||
return true
|
||||
}
|
||||
// Fail fast if we have required extensions but no extensions on the cert
|
||||
if len(clientCert.Extensions) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Build Client Extensions Map for Constraint Matching
|
||||
// x509 Writes Extensions in ASN1 with a bitstring tag, which results in the field
|
||||
// including its ASN.1 type tag bytes. For the sake of simplicity, assume string type
|
||||
// and drop the tag bytes. And get the number of bytes from the tag.
|
||||
clientExtMap := make(map[string]string, len(clientCert.Extensions))
|
||||
for _, ext := range clientCert.Extensions {
|
||||
var parsedValue string
|
||||
asn1.Unmarshal(ext.Value, &parsedValue)
|
||||
clientExtMap[ext.Id.String()] = parsedValue
|
||||
}
|
||||
// If any of the required extensions don't match the constraint fails
|
||||
for _, requiredExt := range config.Entry.RequiredExtensions {
|
||||
reqExt := strings.SplitN(requiredExt, ":", 2)
|
||||
clientExtValue, clientExtValueOk := clientExtMap[reqExt[0]]
|
||||
if !clientExtValueOk || !glob.Glob(reqExt[1], clientExtValue) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// loadTrustedCerts is used to load all the trusted certificates from the backend
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
92223EAFBBEE17A3
|
||||
21
builtin/credential/cert/test-fixtures/root/rootcawext.cnf
Normal file
21
builtin/credential/cert/test-fixtures/root/rootcawext.cnf
Normal file
@@ -0,0 +1,21 @@
|
||||
[ req ]
|
||||
default_bits = 2048
|
||||
encrypt_key = no
|
||||
prompt = no
|
||||
default_md = sha256
|
||||
req_extensions = req_v3
|
||||
distinguished_name = dn
|
||||
|
||||
[ dn ]
|
||||
CN = example.com
|
||||
|
||||
[ req_v3 ]
|
||||
subjectAltName = @alt_names
|
||||
2.1.1.1=ASN1:UTF8String:A UTF8String Extension
|
||||
2.1.1.2=ASN1:UTF8:A UTF8 Extension
|
||||
2.1.1.3=ASN1:IA5:An IA5 Extension
|
||||
2.1.1.4=ASN1:VISIBLE:A Visible Extension
|
||||
|
||||
[ alt_names ]
|
||||
DNS.1 = example.com
|
||||
IP.1 = 127.0.0.1
|
||||
19
builtin/credential/cert/test-fixtures/root/rootcawext.csr
Normal file
19
builtin/credential/cert/test-fixtures/root/rootcawext.csr
Normal file
@@ -0,0 +1,19 @@
|
||||
-----BEGIN CERTIFICATE REQUEST-----
|
||||
MIIDAzCCAesCAQAwFjEUMBIGA1UEAwwLZXhhbXBsZS5jb20wggEiMA0GCSqGSIb3
|
||||
DQEBAQUAA4IBDwAwggEKAoIBAQDM2PrLyK/wVQIcnK362ZylDrIVMjFQzps/0AxM
|
||||
ke+8MNPMArBlSAhnZus6qb0nN0nJrDLkHQgYqnSvK9N7VUv/xFblEcOLBlciLhyN
|
||||
Wkm92+q/M/xOvUVmnYkN3XgTI5QNxF7ZWDFHmwCNV27RraQZou0hG7yvyoILLMQE
|
||||
3MnMCNM1nZ9JIuBMcRsZLGqQ1XNaQljboRVIUjimzkcfYyTruhLosTIbwForp78J
|
||||
MzHHqVjtLJXPqUnRMS7KhGMj1f2mIswQzCv6F2PWEzNBbP4Gb67znKikKDs0RgyL
|
||||
RyfizFNFJSC58XntK8jwHK1D8W3UepFf4K8xNFnhPoKWtWfJAgMBAAGggacwgaQG
|
||||
CSqGSIb3DQEJDjGBljCBkzAcBgNVHREEFTATggtleGFtcGxlLmNvbYcEfwAAATAf
|
||||
BgNRAQEEGAwWQSBVVEY4U3RyaW5nIEV4dGVuc2lvbjAZBgNRAQIEEgwQQSBVVEY4
|
||||
IEV4dGVuc2lvbjAZBgNRAQMEEhYQQW4gSUE1IEV4dGVuc2lvbjAcBgNRAQQEFRoT
|
||||
QSBWaXNpYmxlIEV4dGVuc2lvbjANBgkqhkiG9w0BAQsFAAOCAQEAtYjewBcqAXxk
|
||||
tDY0lpZid6ZvfngdDlDZX0vrs3zNppKNe5Sl+jsoDOexqTA7HQA/y1ru117sAEeB
|
||||
yiqMeZ7oPk8b3w+BZUpab7p2qPMhZypKl93y/jGXGscc3jRbUBnym9S91PSq6wUd
|
||||
f2aigSqFc9+ywFVdx5PnnZUfcrUQ2a+AweYEkGOzXX2Ga+Ige8grDMCzRgCoP5cW
|
||||
kM5ghwZp5wYIBGrKBU9iDcBlmnNhYaGWf+dD00JtVDPNn2bJnCsJHIO0nklZgnrS
|
||||
fli8VQ1nYPkONdkiRYLt6//6at1iNDoDgsVCChtlVkLpxFIKcDFUHlffZsc1kMFI
|
||||
HTX579k8hA==
|
||||
-----END CERTIFICATE REQUEST-----
|
||||
@@ -0,0 +1,20 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDRjCCAi6gAwIBAgIJAJIiPq+77hejMA0GCSqGSIb3DQEBCwUAMBYxFDASBgNV
|
||||
BAMTC2V4YW1wbGUuY29tMB4XDTE3MTEyOTE5MTgwM1oXDTI3MTEyNzE5MTgwM1ow
|
||||
FjEUMBIGA1UEAwwLZXhhbXBsZS5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw
|
||||
ggEKAoIBAQDM2PrLyK/wVQIcnK362ZylDrIVMjFQzps/0AxMke+8MNPMArBlSAhn
|
||||
Zus6qb0nN0nJrDLkHQgYqnSvK9N7VUv/xFblEcOLBlciLhyNWkm92+q/M/xOvUVm
|
||||
nYkN3XgTI5QNxF7ZWDFHmwCNV27RraQZou0hG7yvyoILLMQE3MnMCNM1nZ9JIuBM
|
||||
cRsZLGqQ1XNaQljboRVIUjimzkcfYyTruhLosTIbwForp78JMzHHqVjtLJXPqUnR
|
||||
MS7KhGMj1f2mIswQzCv6F2PWEzNBbP4Gb67znKikKDs0RgyLRyfizFNFJSC58Xnt
|
||||
K8jwHK1D8W3UepFf4K8xNFnhPoKWtWfJAgMBAAGjgZYwgZMwHAYDVR0RBBUwE4IL
|
||||
ZXhhbXBsZS5jb22HBH8AAAEwHwYDUQEBBBgMFkEgVVRGOFN0cmluZyBFeHRlbnNp
|
||||
b24wGQYDUQECBBIMEEEgVVRGOCBFeHRlbnNpb24wGQYDUQEDBBIWEEFuIElBNSBF
|
||||
eHRlbnNpb24wHAYDUQEEBBUaE0EgVmlzaWJsZSBFeHRlbnNpb24wDQYJKoZIhvcN
|
||||
AQELBQADggEBAGU/iA6saupEaGn/veVNCknFGDL7pst5D6eX/y9atXlBOdJe7ZJJ
|
||||
XQRkeHJldA0khVpzH7Ryfi+/25WDuNz+XTZqmb4ppeV8g9amtqBwxziQ9UUwYrza
|
||||
eDBqdXBaYp/iHUEHoceX4F44xuo80BIqwF0lD9TFNUFoILnF26ajhKX0xkGaiKTH
|
||||
6SbjBfHoQVMzOHokVRWregmgNycV+MAI9Ne9XkIZvdOYeNlcS9drZeJI3szkiaxB
|
||||
WWaWaAr5UU2Z0yUCZnAIDMRcIiUbSEjIDz504sSuCzTctMOxWZu0r/0UrXRzwZZi
|
||||
HAaKm3MUmBh733ChP4rTB58nr5DEr5rJ9P8=
|
||||
-----END CERTIFICATE-----
|
||||
28
builtin/credential/cert/test-fixtures/root/rootcawextkey.pem
Normal file
28
builtin/credential/cert/test-fixtures/root/rootcawextkey.pem
Normal file
@@ -0,0 +1,28 @@
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDM2PrLyK/wVQIc
|
||||
nK362ZylDrIVMjFQzps/0AxMke+8MNPMArBlSAhnZus6qb0nN0nJrDLkHQgYqnSv
|
||||
K9N7VUv/xFblEcOLBlciLhyNWkm92+q/M/xOvUVmnYkN3XgTI5QNxF7ZWDFHmwCN
|
||||
V27RraQZou0hG7yvyoILLMQE3MnMCNM1nZ9JIuBMcRsZLGqQ1XNaQljboRVIUjim
|
||||
zkcfYyTruhLosTIbwForp78JMzHHqVjtLJXPqUnRMS7KhGMj1f2mIswQzCv6F2PW
|
||||
EzNBbP4Gb67znKikKDs0RgyLRyfizFNFJSC58XntK8jwHK1D8W3UepFf4K8xNFnh
|
||||
PoKWtWfJAgMBAAECggEAW7hLkzMok9N8PpNo0wjcuor58cOnkSbxHIFrAF3XmcvD
|
||||
CXWqxa6bFLFgYcPejdCTmVkg8EKPfXvVAxn8dxyaCss+nRJ3G6ibGxLKdgAXRItT
|
||||
cIk2T4svp+KhmzOur+MeR4vFbEuwxP8CIEclt3yoHVJ2Gnzw30UtNRO2MPcq48/C
|
||||
ZODGeBqUif1EGjDAvlqu5kl/pcDBJ3ctIZdVUMYYW4R9JtzKsmwhX7CRCBm8k5hG
|
||||
2uzn8AKwpuVtfWcnX59UUmHGJ8mjETuNLARRAwWBWhl8f7wckmi+PKERJGEM2QE5
|
||||
/Voy0p22zmQ3waS8LgiI7YHCAEFqjVWNziVGdR36gQKBgQDxkpfkEsfa5PieIaaF
|
||||
iQOO0rrjEJ9MBOQqmTDeclmDPNkM9qvCF/dqpJfOtliYFxd7JJ3OR2wKrBb5vGHt
|
||||
qIB51Rnm9aDTM4OUEhnhvbPlERD0W+yWYXWRvqyHz0GYwEFGQ83h95GC/qfTosqy
|
||||
LEzYLDafiPeNP+DG/HYRljAxUwKBgQDZFOWHEcZkSFPLNZiksHqs90OR2zIFxZcx
|
||||
SrbkjqXjRjehWEAwgpvQ/quSBxrE2E8xXgVm90G1JpWzxjUfKKQRM6solQeEpnwY
|
||||
kCy2Ozij/TtbLNRlU65UQ+nMto8KTSIyJbxxdOZxYdtJAJQp1FJO1a1WC11z4+zh
|
||||
lnLV1O5S8wKBgQCDf/QU4DBQtNGtas315Oa96XJ4RkUgoYz+r1NN09tsOERC7UgE
|
||||
KP2y3JQSn2pMqE1M6FrKvlBO4uzC10xLja0aJOmrssvwDBu1D8FtA9IYgJjFHAEG
|
||||
v1i7lJrgdu7TUtx1flVli1l3gF4lM3m5UaonBrJZV7rB9iLKzwUKf8IOJwKBgFt/
|
||||
QktPA6brEV56Za8sr1hOFA3bLNdf9B0Tl8j4ExWbWAFKeCu6MUDCxsAS/IZxgdeW
|
||||
AILovqpC7CBM78EFWTni5EaDohqYLYAQ7LeWeIYuSyFf4Nogjj74LQha/iliX4Jx
|
||||
g17y3dp2W34Gn2yOEG8oAxpcSfR54jMnPZnBWP5fAoGBAMNAd3oa/xq9A5v719ik
|
||||
naD7PdrjBdhnPk4egzMDv54y6pCFlvFbEiBduBWTmiVa7dSzhYtmEbri2WrgARlu
|
||||
vkfTnVH9E8Hnm4HTbNn+ebxrofq1AOAvdApSoslsOP1NT9J6zB89RzChJyzjbIQR
|
||||
Gevrutb4uO9qpB1jDVoMmGde
|
||||
-----END PRIVATE KEY-----
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/google/go-github/github"
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
"github.com/hashicorp/vault/helper/mfa"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
"golang.org/x/oauth2"
|
||||
@@ -35,11 +36,11 @@ func Backend() *backend {
|
||||
}
|
||||
|
||||
allPaths := append(b.TeamMap.Paths(), b.UserMap.Paths()...)
|
||||
|
||||
b.Backend = &framework.Backend{
|
||||
Help: backendHelp,
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
Root: mfa.MFARootPaths(),
|
||||
Unauthenticated: []string{
|
||||
"login",
|
||||
},
|
||||
@@ -47,9 +48,7 @@ func Backend() *backend {
|
||||
|
||||
Paths: append([]*framework.Path{
|
||||
pathConfig(&b),
|
||||
pathLogin(&b),
|
||||
}, allPaths...),
|
||||
|
||||
}, append(allPaths, mfa.MFAPaths(b.Backend, pathLogin(&b))...)...),
|
||||
AuthRenew: b.pathLoginRenew,
|
||||
BackendType: logical.TypeCredential,
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ func (b *backend) pathLogin(
|
||||
return logical.ErrorResponse(fmt.Sprintf("error sanitizing TTLs: %s", err)), nil
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
resp := &logical.Response{
|
||||
Auth: &logical.Auth{
|
||||
InternalData: map[string]interface{}{
|
||||
"token": token,
|
||||
@@ -93,7 +93,18 @@ func (b *backend) pathLogin(
|
||||
Name: *verifyResp.User.Login,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
for _, teamName := range verifyResp.TeamNames {
|
||||
if teamName == "" {
|
||||
continue
|
||||
}
|
||||
resp.Auth.GroupAliases = append(resp.Auth.GroupAliases, &logical.Alias{
|
||||
Name: teamName,
|
||||
})
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathLoginRenew(
|
||||
@@ -125,7 +136,22 @@ func (b *backend) pathLoginRenew(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return framework.LeaseExtend(config.TTL, config.MaxTTL, b.System())(req, d)
|
||||
|
||||
resp, err := framework.LeaseExtend(config.TTL, config.MaxTTL, b.System())(req, d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Remove old aliases
|
||||
resp.Auth.GroupAliases = nil
|
||||
|
||||
for _, teamName := range verifyResp.TeamNames {
|
||||
resp.Auth.GroupAliases = append(resp.Auth.GroupAliases, &logical.Alias{
|
||||
Name: teamName,
|
||||
})
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (b *backend) verifyCredentials(req *logical.Request, token string) (*verifyCredentialsResp, *logical.Response, error) {
|
||||
@@ -233,14 +259,16 @@ func (b *backend) verifyCredentials(req *logical.Request, token string) (*verify
|
||||
}
|
||||
|
||||
return &verifyCredentialsResp{
|
||||
User: user,
|
||||
Org: org,
|
||||
Policies: append(groupPoliciesList, userPoliciesList...),
|
||||
User: user,
|
||||
Org: org,
|
||||
Policies: append(groupPoliciesList, userPoliciesList...),
|
||||
TeamNames: teamNames,
|
||||
}, nil, nil
|
||||
}
|
||||
|
||||
type verifyCredentialsResp struct {
|
||||
User *github.User
|
||||
Org *github.Organization
|
||||
Policies []string
|
||||
User *github.User
|
||||
Org *github.Organization
|
||||
Policies []string
|
||||
TeamNames []string
|
||||
}
|
||||
|
||||
@@ -31,6 +31,10 @@ func Backend() *backend {
|
||||
Unauthenticated: []string{
|
||||
"login/*",
|
||||
},
|
||||
|
||||
SealWrapStorage: []string{
|
||||
"config",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: append([]*framework.Path{
|
||||
@@ -88,22 +92,22 @@ func EscapeLDAPValue(input string) string {
|
||||
return input
|
||||
}
|
||||
|
||||
func (b *backend) Login(req *logical.Request, username string, password string) ([]string, *logical.Response, error) {
|
||||
func (b *backend) Login(req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) {
|
||||
|
||||
cfg, err := b.Config(req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if cfg == nil {
|
||||
return nil, logical.ErrorResponse("ldap backend not configured"), nil
|
||||
return nil, logical.ErrorResponse("ldap backend not configured"), nil, nil
|
||||
}
|
||||
|
||||
c, err := cfg.DialLDAP()
|
||||
if err != nil {
|
||||
return nil, logical.ErrorResponse(err.Error()), nil
|
||||
return nil, logical.ErrorResponse(err.Error()), nil, nil
|
||||
}
|
||||
if c == nil {
|
||||
return nil, logical.ErrorResponse("invalid connection returned from LDAP dial"), nil
|
||||
return nil, logical.ErrorResponse("invalid connection returned from LDAP dial"), nil, nil
|
||||
}
|
||||
|
||||
// Clean connection
|
||||
@@ -111,7 +115,7 @@ func (b *backend) Login(req *logical.Request, username string, password string)
|
||||
|
||||
userBindDN, err := b.getUserBindDN(cfg, c, username)
|
||||
if err != nil {
|
||||
return nil, logical.ErrorResponse(err.Error()), nil
|
||||
return nil, logical.ErrorResponse(err.Error()), nil, nil
|
||||
}
|
||||
|
||||
if b.Logger().IsDebug() {
|
||||
@@ -119,7 +123,7 @@ func (b *backend) Login(req *logical.Request, username string, password string)
|
||||
}
|
||||
|
||||
if cfg.DenyNullBind && len(password) == 0 {
|
||||
return nil, logical.ErrorResponse("password cannot be of zero length when passwordless binds are being denied"), nil
|
||||
return nil, logical.ErrorResponse("password cannot be of zero length when passwordless binds are being denied"), nil, nil
|
||||
}
|
||||
|
||||
// Try to bind as the login user. This is where the actual authentication takes place.
|
||||
@@ -129,14 +133,14 @@ func (b *backend) Login(req *logical.Request, username string, password string)
|
||||
err = c.UnauthenticatedBind(userBindDN)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, logical.ErrorResponse(fmt.Sprintf("LDAP bind failed: %v", err)), nil
|
||||
return nil, logical.ErrorResponse(fmt.Sprintf("LDAP bind failed: %v", err)), nil, nil
|
||||
}
|
||||
|
||||
// We re-bind to the BindDN if it's defined because we assume
|
||||
// the BindDN should be the one to search, not the user logging in.
|
||||
if cfg.BindDN != "" && cfg.BindPassword != "" {
|
||||
if err := c.Bind(cfg.BindDN, cfg.BindPassword); err != nil {
|
||||
return nil, logical.ErrorResponse(fmt.Sprintf("Encountered an error while attempting to re-bind with the BindDN User: %s", err.Error())), nil
|
||||
return nil, logical.ErrorResponse(fmt.Sprintf("Encountered an error while attempting to re-bind with the BindDN User: %s", err.Error())), nil, nil
|
||||
}
|
||||
if b.Logger().IsDebug() {
|
||||
b.Logger().Debug("auth/ldap: Re-Bound to original BindDN")
|
||||
@@ -145,12 +149,12 @@ func (b *backend) Login(req *logical.Request, username string, password string)
|
||||
|
||||
userDN, err := b.getUserDN(cfg, c, userBindDN)
|
||||
if err != nil {
|
||||
return nil, logical.ErrorResponse(err.Error()), nil
|
||||
return nil, logical.ErrorResponse(err.Error()), nil, nil
|
||||
}
|
||||
|
||||
ldapGroups, err := b.getLdapGroups(cfg, c, userDN, username)
|
||||
if err != nil {
|
||||
return nil, logical.ErrorResponse(err.Error()), nil
|
||||
return nil, logical.ErrorResponse(err.Error()), nil, nil
|
||||
}
|
||||
if b.Logger().IsDebug() {
|
||||
b.Logger().Debug("auth/ldap: Groups fetched from server", "num_server_groups", len(ldapGroups), "server_groups", ldapGroups)
|
||||
@@ -199,10 +203,10 @@ func (b *backend) Login(req *logical.Request, username string, password string)
|
||||
}
|
||||
|
||||
ldapResponse.Data["error"] = errStr
|
||||
return nil, ldapResponse, nil
|
||||
return nil, ldapResponse, nil, nil
|
||||
}
|
||||
|
||||
return policies, ldapResponse, nil
|
||||
return policies, ldapResponse, allGroups, nil
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
@@ -55,7 +55,7 @@ func (b *backend) pathLogin(
|
||||
username := d.Get("username").(string)
|
||||
password := d.Get("password").(string)
|
||||
|
||||
policies, resp, err := b.Login(req, username, password)
|
||||
policies, resp, groupNames, err := b.Login(req, username, password)
|
||||
// Handle an internal error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -87,6 +87,15 @@ func (b *backend) pathLogin(
|
||||
Name: username,
|
||||
},
|
||||
}
|
||||
|
||||
for _, groupName := range groupNames {
|
||||
if groupName == "" {
|
||||
continue
|
||||
}
|
||||
resp.Auth.GroupAliases = append(resp.Auth.GroupAliases, &logical.Alias{
|
||||
Name: groupName,
|
||||
})
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -96,7 +105,7 @@ func (b *backend) pathLoginRenew(
|
||||
username := req.Auth.Metadata["username"]
|
||||
password := req.Auth.InternalData["password"].(string)
|
||||
|
||||
loginPolicies, resp, err := b.Login(req, username, password)
|
||||
loginPolicies, resp, groupNames, err := b.Login(req, username, password)
|
||||
if len(loginPolicies) == 0 {
|
||||
return resp, err
|
||||
}
|
||||
@@ -105,7 +114,21 @@ func (b *backend) pathLoginRenew(
|
||||
return nil, fmt.Errorf("policies have changed, not renewing")
|
||||
}
|
||||
|
||||
return framework.LeaseExtend(0, 0, b.System())(req, d)
|
||||
resp, err = framework.LeaseExtend(0, 0, b.System())(req, d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Remove old aliases
|
||||
resp.Auth.GroupAliases = nil
|
||||
|
||||
for _, groupName := range groupNames {
|
||||
resp.Auth.GroupAliases = append(resp.Auth.GroupAliases, &logical.Alias{
|
||||
Name: groupName,
|
||||
})
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
const pathLoginSyn = `
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/chrismalek/oktasdk-go/okta"
|
||||
"github.com/hashicorp/vault/helper/mfa"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
@@ -22,9 +23,14 @@ func Backend() *backend {
|
||||
Help: backendHelp,
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
Root: mfa.MFARootPaths(),
|
||||
|
||||
Unauthenticated: []string{
|
||||
"login/*",
|
||||
},
|
||||
SealWrapStorage: []string{
|
||||
"config",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: append([]*framework.Path{
|
||||
@@ -33,8 +39,9 @@ func Backend() *backend {
|
||||
pathGroups(&b),
|
||||
pathUsersList(&b),
|
||||
pathGroupsList(&b),
|
||||
pathLogin(&b),
|
||||
}),
|
||||
},
|
||||
mfa.MFAPaths(b.Backend, pathLogin(&b))...,
|
||||
),
|
||||
|
||||
AuthRenew: b.pathLoginRenew,
|
||||
BackendType: logical.TypeCredential,
|
||||
@@ -47,13 +54,13 @@ type backend struct {
|
||||
*framework.Backend
|
||||
}
|
||||
|
||||
func (b *backend) Login(req *logical.Request, username string, password string) ([]string, *logical.Response, error) {
|
||||
func (b *backend) Login(req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) {
|
||||
cfg, err := b.Config(req.Storage)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if cfg == nil {
|
||||
return nil, logical.ErrorResponse("Okta auth method not configured"), nil
|
||||
return nil, logical.ErrorResponse("Okta auth method not configured"), nil, nil
|
||||
}
|
||||
|
||||
client := cfg.OktaClient()
|
||||
@@ -71,16 +78,16 @@ func (b *backend) Login(req *logical.Request, username string, password string)
|
||||
"password": password,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
var result authResult
|
||||
rsp, err := client.Do(authReq, &result)
|
||||
if err != nil {
|
||||
return nil, logical.ErrorResponse(fmt.Sprintf("Okta auth failed: %v", err)), nil
|
||||
return nil, logical.ErrorResponse(fmt.Sprintf("Okta auth failed: %v", err)), nil, nil
|
||||
}
|
||||
if rsp == nil {
|
||||
return nil, logical.ErrorResponse("okta auth method unexpected failure"), nil
|
||||
return nil, logical.ErrorResponse("okta auth method unexpected failure"), nil, nil
|
||||
}
|
||||
|
||||
oktaResponse := &logical.Response{
|
||||
@@ -92,7 +99,7 @@ func (b *backend) Login(req *logical.Request, username string, password string)
|
||||
if cfg.Token != "" {
|
||||
oktaGroups, err := b.getOktaGroups(client, &result.Embedded.User)
|
||||
if err != nil {
|
||||
return nil, logical.ErrorResponse(fmt.Sprintf("okta failure retrieving groups: %v", err)), nil
|
||||
return nil, logical.ErrorResponse(fmt.Sprintf("okta failure retrieving groups: %v", err)), nil, nil
|
||||
}
|
||||
if len(oktaGroups) == 0 {
|
||||
errString := fmt.Sprintf(
|
||||
@@ -142,10 +149,10 @@ func (b *backend) Login(req *logical.Request, username string, password string)
|
||||
}
|
||||
|
||||
oktaResponse.Data["error"] = errStr
|
||||
return nil, oktaResponse, nil
|
||||
return nil, oktaResponse, nil, nil
|
||||
}
|
||||
|
||||
return policies, oktaResponse, nil
|
||||
return policies, oktaResponse, allGroups, nil
|
||||
}
|
||||
|
||||
func (b *backend) getOktaGroups(client *okta.Client, user *okta.User) ([]string, error) {
|
||||
|
||||
@@ -38,6 +38,15 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
|
||||
"password": password,
|
||||
}
|
||||
|
||||
mfa_method, ok := m["method"]
|
||||
if ok {
|
||||
data["method"] = mfa_method
|
||||
}
|
||||
mfa_passcode, ok := m["passcode"]
|
||||
if ok {
|
||||
data["passcode"] = mfa_passcode
|
||||
}
|
||||
|
||||
path := fmt.Sprintf("auth/%s/login/%s", mount, username)
|
||||
secret, err := c.Logical().Write(path, data)
|
||||
if err != nil {
|
||||
|
||||
@@ -57,7 +57,7 @@ func (b *backend) pathLogin(
|
||||
username := d.Get("username").(string)
|
||||
password := d.Get("password").(string)
|
||||
|
||||
policies, resp, err := b.Login(req, username, password)
|
||||
policies, resp, groupNames, err := b.Login(req, username, password)
|
||||
// Handle an internal error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -96,6 +96,16 @@ func (b *backend) pathLogin(
|
||||
Name: username,
|
||||
},
|
||||
}
|
||||
|
||||
for _, groupName := range groupNames {
|
||||
if groupName == "" {
|
||||
continue
|
||||
}
|
||||
resp.Auth.GroupAliases = append(resp.Auth.GroupAliases, &logical.Alias{
|
||||
Name: groupName,
|
||||
})
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -105,7 +115,7 @@ func (b *backend) pathLoginRenew(
|
||||
username := req.Auth.Metadata["username"]
|
||||
password := req.Auth.InternalData["password"].(string)
|
||||
|
||||
loginPolicies, resp, err := b.Login(req, username, password)
|
||||
loginPolicies, resp, groupNames, err := b.Login(req, username, password)
|
||||
if len(loginPolicies) == 0 {
|
||||
return resp, err
|
||||
}
|
||||
@@ -119,7 +129,22 @@ func (b *backend) pathLoginRenew(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return framework.LeaseExtend(cfg.TTL, cfg.MaxTTL, b.System())(req, d)
|
||||
resp, err = framework.LeaseExtend(cfg.TTL, cfg.MaxTTL, b.System())(req, d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Remove old aliases
|
||||
resp.Auth.GroupAliases = nil
|
||||
|
||||
for _, groupName := range groupNames {
|
||||
resp.Auth.GroupAliases = append(resp.Auth.GroupAliases, &logical.Alias{
|
||||
Name: groupName,
|
||||
})
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
|
||||
}
|
||||
|
||||
func (b *backend) getConfig(req *logical.Request) (*ConfigEntry, error) {
|
||||
|
||||
@@ -26,6 +26,10 @@ func Backend() *backend {
|
||||
"login",
|
||||
"login/*",
|
||||
},
|
||||
|
||||
SealWrapStorage: []string{
|
||||
"config",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: append([]*framework.Path{
|
||||
|
||||
@@ -39,7 +39,7 @@ func pathConfig(b *backend) *framework.Path {
|
||||
"read_timeout": &framework.FieldSchema{
|
||||
Type: framework.TypeDurationSecond,
|
||||
Default: 10,
|
||||
Description: "Number of seconds before response times out (default: 10). Note: kept for backwards compatibility, currently unused.",
|
||||
Description: "Number of seconds before response times out (default: 10)",
|
||||
},
|
||||
"nas_port": &framework.FieldSchema{
|
||||
Type: framework.TypeInt,
|
||||
|
||||
@@ -154,7 +154,9 @@ func (b *backend) RadiusLogin(req *logical.Request, username string, password st
|
||||
Timeout: time.Duration(cfg.DialTimeout) * time.Second,
|
||||
},
|
||||
}
|
||||
received, err := client.Exchange(context.Background(), packet, hostport)
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), time.Duration(cfg.ReadTimeout)*time.Second)
|
||||
received, err := client.Exchange(ctx, packet, hostport)
|
||||
cancelFunc()
|
||||
if err != nil {
|
||||
return nil, logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
@@ -25,6 +25,9 @@ func Backend() *backend {
|
||||
LocalStorage: []string{
|
||||
framework.WALPrefix,
|
||||
},
|
||||
SealWrapStorage: []string{
|
||||
"config/root",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
|
||||
@@ -13,8 +13,9 @@ import (
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
func getRootConfig(s logical.Storage) (*aws.Config, error) {
|
||||
func getRootConfig(s logical.Storage, clientType string) (*aws.Config, error) {
|
||||
credsConfig := &awsutil.CredentialsConfig{}
|
||||
var endpoint string
|
||||
|
||||
entry, err := s.Get("config/root")
|
||||
if err != nil {
|
||||
@@ -29,6 +30,12 @@ func getRootConfig(s logical.Storage) (*aws.Config, error) {
|
||||
credsConfig.AccessKey = config.AccessKey
|
||||
credsConfig.SecretKey = config.SecretKey
|
||||
credsConfig.Region = config.Region
|
||||
switch {
|
||||
case clientType == "iam" && config.IAMEndpoint != "":
|
||||
endpoint = *aws.String(config.IAMEndpoint)
|
||||
case clientType == "sts" && config.STSEndpoint != "":
|
||||
endpoint = *aws.String(config.STSEndpoint)
|
||||
}
|
||||
}
|
||||
|
||||
if credsConfig.Region == "" {
|
||||
@@ -51,16 +58,19 @@ func getRootConfig(s logical.Storage) (*aws.Config, error) {
|
||||
return &aws.Config{
|
||||
Credentials: creds,
|
||||
Region: aws.String(credsConfig.Region),
|
||||
Endpoint: &endpoint,
|
||||
HTTPClient: cleanhttp.DefaultClient(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func clientIAM(s logical.Storage) (*iam.IAM, error) {
|
||||
awsConfig, err := getRootConfig(s)
|
||||
awsConfig, err := getRootConfig(s, "iam")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := iam.New(session.New(awsConfig))
|
||||
|
||||
if client == nil {
|
||||
return nil, fmt.Errorf("could not obtain iam client")
|
||||
}
|
||||
@@ -68,11 +78,12 @@ func clientIAM(s logical.Storage) (*iam.IAM, error) {
|
||||
}
|
||||
|
||||
func clientSTS(s logical.Storage) (*sts.STS, error) {
|
||||
awsConfig, err := getRootConfig(s)
|
||||
awsConfig, err := getRootConfig(s, "sts")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := sts.New(session.New(awsConfig))
|
||||
|
||||
if client == nil {
|
||||
return nil, fmt.Errorf("could not obtain sts client")
|
||||
}
|
||||
|
||||
@@ -23,6 +23,14 @@ func pathConfigRoot() *framework.Path {
|
||||
Type: framework.TypeString,
|
||||
Description: "Region for API calls.",
|
||||
},
|
||||
"iam_endpoint": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Endpoint to custom IAM server URL",
|
||||
},
|
||||
"sts_endpoint": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Endpoint to custom STS server URL",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
@@ -37,11 +45,15 @@ func pathConfigRoot() *framework.Path {
|
||||
func pathConfigRootWrite(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
region := data.Get("region").(string)
|
||||
iamendpoint := data.Get("iam_endpoint").(string)
|
||||
stsendpoint := data.Get("sts_endpoint").(string)
|
||||
|
||||
entry, err := logical.StorageEntryJSON("config/root", rootConfig{
|
||||
AccessKey: data.Get("access_key").(string),
|
||||
SecretKey: data.Get("secret_key").(string),
|
||||
Region: region,
|
||||
AccessKey: data.Get("access_key").(string),
|
||||
SecretKey: data.Get("secret_key").(string),
|
||||
IAMEndpoint: iamendpoint,
|
||||
STSEndpoint: stsendpoint,
|
||||
Region: region,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -55,9 +67,11 @@ func pathConfigRootWrite(
|
||||
}
|
||||
|
||||
type rootConfig struct {
|
||||
AccessKey string `json:"access_key"`
|
||||
SecretKey string `json:"secret_key"`
|
||||
Region string `json:"region"`
|
||||
AccessKey string `json:"access_key"`
|
||||
SecretKey string `json:"secret_key"`
|
||||
IAMEndpoint string `json:"iam_endpoint"`
|
||||
STSEndpoint string `json:"sts_endpoint"`
|
||||
Region string `json:"region"`
|
||||
}
|
||||
|
||||
const pathConfigRootHelpSyn = `
|
||||
|
||||
@@ -25,6 +25,12 @@ func Backend() *backend {
|
||||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
SealWrapStorage: []string{
|
||||
"config/connection",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigConnection(&b),
|
||||
pathRoles(&b),
|
||||
|
||||
@@ -4,73 +4,76 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
logicaltest "github.com/hashicorp/vault/logical/testing"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
dockertest "gopkg.in/ory-am/dockertest.v2"
|
||||
dockertest "gopkg.in/ory-am/dockertest.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
testImagePull sync.Once
|
||||
)
|
||||
|
||||
func prepareTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cid dockertest.ContainerID, retURL string) {
|
||||
func prepareCassandraTestContainer(t *testing.T) (func(), string, int) {
|
||||
if os.Getenv("CASSANDRA_HOST") != "" {
|
||||
return "", os.Getenv("CASSANDRA_HOST")
|
||||
return func() {}, os.Getenv("CASSANDRA_HOST"), 0
|
||||
}
|
||||
|
||||
// Without this the checks for whether the container has started seem to
|
||||
// never actually pass. There's really no reason to expose the test
|
||||
// containers, so don't.
|
||||
dockertest.BindDockerToLocalhost = "yep"
|
||||
|
||||
testImagePull.Do(func() {
|
||||
dockertest.Pull("cassandra")
|
||||
})
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to docker: %s", err)
|
||||
}
|
||||
|
||||
cwd, _ := os.Getwd()
|
||||
cassandraMountPath := fmt.Sprintf("%s/test-fixtures/:/etc/cassandra/", cwd)
|
||||
|
||||
cid, connErr := dockertest.ConnectToCassandra("latest", 60, 1000*time.Millisecond, func(connURL string) bool {
|
||||
// This will cause a validation to run
|
||||
resp, err := b.HandleRequest(&logical.Request{
|
||||
Storage: s,
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/connection",
|
||||
Data: map[string]interface{}{
|
||||
"hosts": connURL,
|
||||
"username": "cassandra",
|
||||
"password": "cassandra",
|
||||
"protocol_version": 3,
|
||||
},
|
||||
})
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
// It's likely not up and running yet, so return false and try again
|
||||
return false
|
||||
}
|
||||
|
||||
retURL = connURL
|
||||
return true
|
||||
}, []string{"-v", cwd + "/test-fixtures/:/etc/cassandra/"}...)
|
||||
|
||||
if connErr != nil {
|
||||
if cid != "" {
|
||||
cid.KillRemove()
|
||||
}
|
||||
t.Fatalf("could not connect to database: %v", connErr)
|
||||
ro := &dockertest.RunOptions{
|
||||
Repository: "cassandra",
|
||||
Tag: "latest",
|
||||
Env: []string{"CASSANDRA_BROADCAST_ADDRESS=127.0.0.1"},
|
||||
Mounts: []string{cassandraMountPath},
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) {
|
||||
err := cid.KillRemove()
|
||||
resource, err := pool.RunWithOptions(ro)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
t.Fatalf("Could not start local cassandra docker container: %s", err)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
err := pool.Purge(resource)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to cleanup local container: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
port, _ := strconv.Atoi(resource.GetPort("9042/tcp"))
|
||||
address := fmt.Sprintf("127.0.0.1:%d", port)
|
||||
|
||||
// exponential backoff-retry
|
||||
if err = pool.Retry(func() error {
|
||||
clusterConfig := gocql.NewCluster(address)
|
||||
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
|
||||
Username: "cassandra",
|
||||
Password: "cassandra",
|
||||
}
|
||||
clusterConfig.ProtoVersion = 4
|
||||
clusterConfig.Port = port
|
||||
|
||||
session, err := clusterConfig.CreateSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating session: %s", err)
|
||||
}
|
||||
defer session.Close()
|
||||
return nil
|
||||
}); err != nil {
|
||||
cleanup()
|
||||
t.Fatalf("Could not connect to cassandra docker container: %s", err)
|
||||
}
|
||||
return cleanup, address, port
|
||||
}
|
||||
|
||||
func TestBackend_basic(t *testing.T) {
|
||||
@@ -84,10 +87,8 @@ func TestBackend_basic(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cid, hostname := prepareTestContainer(t, config.StorageView, b)
|
||||
if cid != "" {
|
||||
defer cleanupTestContainer(t, cid)
|
||||
}
|
||||
cleanup, hostname, _ := prepareCassandraTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: b,
|
||||
@@ -110,10 +111,8 @@ func TestBackend_roleCrud(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cid, hostname := prepareTestContainer(t, config.StorageView, b)
|
||||
if cid != "" {
|
||||
defer cleanupTestContainer(t, cid)
|
||||
}
|
||||
cleanup, hostname, _ := prepareCassandraTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: b,
|
||||
|
||||
@@ -421,7 +421,7 @@ seed_provider:
|
||||
parameters:
|
||||
# seeds is actually a comma-delimited list of addresses.
|
||||
# Ex: "<ip1>,<ip2>,<ip3>"
|
||||
- seeds: "172.17.0.3"
|
||||
- seeds: "127.0.0.1"
|
||||
|
||||
# For workloads with more data than can fit in memory, Cassandra's
|
||||
# bottleneck will be reads that need to fetch data from
|
||||
@@ -572,7 +572,7 @@ ssl_storage_port: 7001
|
||||
#
|
||||
# Setting listen_address to 0.0.0.0 is always wrong.
|
||||
#
|
||||
listen_address: 172.17.0.3
|
||||
listen_address: 172.17.0.2
|
||||
|
||||
# Set listen_address OR listen_interface, not both. Interfaces must correspond
|
||||
# to a single address, IP aliasing is not supported.
|
||||
@@ -586,7 +586,7 @@ listen_address: 172.17.0.3
|
||||
|
||||
# Address to broadcast to other Cassandra nodes
|
||||
# Leaving this blank will set it to the same value as listen_address
|
||||
broadcast_address: 172.17.0.3
|
||||
broadcast_address: 127.0.0.1
|
||||
|
||||
# When using multiple physical network interfaces, set this
|
||||
# to true to listen on broadcast_address in addition to
|
||||
@@ -668,7 +668,7 @@ rpc_port: 9160
|
||||
# be set to 0.0.0.0. If left blank, this will be set to the value of
|
||||
# rpc_address. If rpc_address is set to 0.0.0.0, broadcast_rpc_address must
|
||||
# be set.
|
||||
broadcast_rpc_address: 172.17.0.3
|
||||
broadcast_rpc_address: 127.0.0.1
|
||||
|
||||
# enable or disable keepalive on rpc/native connections
|
||||
rpc_keepalive: true
|
||||
|
||||
@@ -16,6 +16,12 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
func Backend() *backend {
|
||||
var b backend
|
||||
b.Backend = &framework.Backend{
|
||||
PathsSpecial: &logical.Paths{
|
||||
SealWrapStorage: []string{
|
||||
"config/access",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigAccess(),
|
||||
pathListRoles(&b),
|
||||
|
||||
@@ -433,12 +433,8 @@ func testAccStepReadPolicy(t *testing.T, name string, policy string, lease time.
|
||||
return fmt.Errorf("mismatch: %s %s", out, policy)
|
||||
}
|
||||
|
||||
leaseRaw := resp.Data["lease"].(string)
|
||||
l, err := time.ParseDuration(leaseRaw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if l != lease {
|
||||
l := resp.Data["lease"].(int64)
|
||||
if lease != time.Second*time.Duration(l) {
|
||||
return fmt.Errorf("mismatch: %v %v", l, lease)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -44,7 +44,7 @@ Defaults to 'client'.`,
|
||||
},
|
||||
|
||||
"lease": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Type: framework.TypeDurationSecond,
|
||||
Description: "Lease time of the role.",
|
||||
},
|
||||
},
|
||||
@@ -91,7 +91,7 @@ func pathRolesRead(
|
||||
// Generate the response
|
||||
resp := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"lease": result.Lease.String(),
|
||||
"lease": int64(result.Lease.Seconds()),
|
||||
"token_type": result.TokenType,
|
||||
},
|
||||
}
|
||||
@@ -130,13 +130,9 @@ func pathRolesWrite(
|
||||
}
|
||||
|
||||
var lease time.Duration
|
||||
leaseParam := d.Get("lease").(string)
|
||||
if leaseParam != "" {
|
||||
lease, err = time.ParseDuration(leaseParam)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"error parsing given lease of %s: %s", leaseParam, err)), nil
|
||||
}
|
||||
leaseParamRaw, ok := d.GetOk("lease")
|
||||
if ok {
|
||||
lease = time.Second * time.Duration(leaseParamRaw.(int))
|
||||
}
|
||||
|
||||
entry, err := logical.StorageEntryJSON("policy/"+name, roleConfig{
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
|
||||
func pathToken(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "creds/" + framework.GenericNameRegex("name"),
|
||||
Pattern: "creds/" + framework.GenericNameRegex("role"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
"role": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role",
|
||||
},
|
||||
@@ -27,14 +27,14 @@ func pathToken(b *backend) *framework.Path {
|
||||
|
||||
func (b *backend) pathTokenRead(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
role := d.Get("role").(string)
|
||||
|
||||
entry, err := req.Storage.Get("policy/" + name)
|
||||
entry, err := req.Storage.Get("policy/" + role)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving role: %s", err)
|
||||
}
|
||||
if entry == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("Role '%s' not found", name)), nil
|
||||
return logical.ErrorResponse(fmt.Sprintf("role %q not found", role)), nil
|
||||
}
|
||||
|
||||
var result roleConfig
|
||||
@@ -56,7 +56,7 @@ func (b *backend) pathTokenRead(
|
||||
}
|
||||
|
||||
// Generate a name for the token
|
||||
tokenName := fmt.Sprintf("Vault %s %s %d", name, req.DisplayName, time.Now().UnixNano())
|
||||
tokenName := fmt.Sprintf("Vault %s %s %d", role, req.DisplayName, time.Now().UnixNano())
|
||||
|
||||
// Create it
|
||||
token, _, err := c.ACL().Create(&api.ACLEntry{
|
||||
@@ -73,6 +73,7 @@ func (b *backend) pathTokenRead(
|
||||
"token": token,
|
||||
}, map[string]interface{}{
|
||||
"token": token,
|
||||
"role": role,
|
||||
})
|
||||
s.Secret.TTL = result.Lease
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package consul
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
@@ -26,8 +28,30 @@ func secretToken(b *backend) *framework.Secret {
|
||||
|
||||
func (b *backend) secretTokenRenew(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
roleRaw, ok := req.Secret.InternalData["role"]
|
||||
if !ok || roleRaw == nil {
|
||||
return framework.LeaseExtend(0, 0, b.System())(req, d)
|
||||
}
|
||||
|
||||
return framework.LeaseExtend(0, 0, b.System())(req, d)
|
||||
role, ok := roleRaw.(string)
|
||||
if !ok {
|
||||
return framework.LeaseExtend(0, 0, b.System())(req, d)
|
||||
}
|
||||
|
||||
entry, err := req.Storage.Get("policy/" + role)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving role: %s", err)
|
||||
}
|
||||
if entry == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("issuing role %q not found", role)), nil
|
||||
}
|
||||
|
||||
var result roleConfig
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return framework.LeaseExtend(result.Lease, 0, b.System())(req, d)
|
||||
}
|
||||
|
||||
func secretTokenRevoke(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/rpc"
|
||||
"strings"
|
||||
@@ -28,6 +29,12 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
|
||||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
SealWrapStorage: []string{
|
||||
"config/*",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathListPluginConnection(&b),
|
||||
pathConfigurePluginConnection(&b),
|
||||
@@ -81,7 +88,7 @@ func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) {
|
||||
// This function creates a new db object from the stored configuration and
|
||||
// caches it in the connections map. The caller of this function needs to hold
|
||||
// the backend's write lock
|
||||
func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) {
|
||||
func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, name string) (dbplugin.Database, error) {
|
||||
db, ok := b.connections[name]
|
||||
if ok {
|
||||
return db, nil
|
||||
@@ -97,7 +104,7 @@ func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = db.Initialize(config.ConnectionDetails, true)
|
||||
err = db.Initialize(ctx, config.ConnectionDetails, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -124,6 +131,21 @@ func (b *databaseBackend) DatabaseConfig(s logical.Storage, name string) (*Datab
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
type upgradeStatements struct {
|
||||
// This json tag has a typo in it, the new version does not. This
|
||||
// necessitates this upgrade logic.
|
||||
CreationStatements string `json:"creation_statments"`
|
||||
RevocationStatements string `json:"revocation_statements"`
|
||||
RollbackStatements string `json:"rollback_statements"`
|
||||
RenewStatements string `json:"renew_statements"`
|
||||
}
|
||||
|
||||
type upgradeCheck struct {
|
||||
// This json tag has a typo in it, the new version does not. This
|
||||
// necessitates this upgrade logic.
|
||||
Statements upgradeStatements `json:"statments"`
|
||||
}
|
||||
|
||||
func (b *databaseBackend) Role(s logical.Storage, roleName string) (*roleEntry, error) {
|
||||
entry, err := s.Get("role/" + roleName)
|
||||
if err != nil {
|
||||
@@ -133,11 +155,24 @@ func (b *databaseBackend) Role(s logical.Storage, roleName string) (*roleEntry,
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var upgradeCh upgradeCheck
|
||||
if err := entry.DecodeJSON(&upgradeCh); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result roleEntry
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
empty := upgradeCheck{}
|
||||
if upgradeCh != empty {
|
||||
result.Statements.CreationStatements = upgradeCh.Statements.CreationStatements
|
||||
result.Statements.RevocationStatements = upgradeCh.Statements.RevocationStatements
|
||||
result.Statements.RollbackStatements = upgradeCh.Statements.RollbackStatements
|
||||
result.Statements.RenewStatements = upgradeCh.Statements.RenewStatements
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
@@ -164,7 +199,8 @@ func (b *databaseBackend) clearConnection(name string) {
|
||||
|
||||
func (b *databaseBackend) closeIfShutdown(name string, err error) {
|
||||
// Plugin has shutdown, close it so next call can reconnect.
|
||||
if err == rpc.ErrShutdown {
|
||||
switch err {
|
||||
case rpc.ErrShutdown, dbplugin.ErrPluginShutdown:
|
||||
b.Lock()
|
||||
b.clearConnection(name)
|
||||
b.Unlock()
|
||||
|
||||
@@ -116,6 +116,55 @@ func TestBackend_PluginMain(t *testing.T) {
|
||||
postgresql.Run(apiClientMeta.GetTLSConfig())
|
||||
}
|
||||
|
||||
func TestBackend_RoleUpgrade(t *testing.T) {
|
||||
|
||||
storage := &logical.InmemStorage{}
|
||||
backend := &databaseBackend{}
|
||||
|
||||
roleEnt := &roleEntry{
|
||||
Statements: dbplugin.Statements{
|
||||
CreationStatements: "test",
|
||||
},
|
||||
}
|
||||
|
||||
entry, err := logical.StorageEntryJSON("role/test", roleEnt)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := storage.Put(entry); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
role, err := backend.Role(storage, "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(role, roleEnt) {
|
||||
t.Fatal("bad role %#v", role)
|
||||
}
|
||||
|
||||
// Upgrade case
|
||||
badJSON := `{"statments":{"creation_statments":"test","revocation_statements":"","rollback_statements":"","renew_statements":""}}`
|
||||
entry = &logical.StorageEntry{
|
||||
Key: "role/test",
|
||||
Value: []byte(badJSON),
|
||||
}
|
||||
if err := storage.Put(entry); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
role, err = backend.Role(storage, "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(role, roleEnt) {
|
||||
t.Fatal("bad role %#v", role)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestBackend_config_connection(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
@@ -488,9 +537,11 @@ func TestBackend_roleCrud(t *testing.T) {
|
||||
RevocationStatements: defaultRevocationSQL,
|
||||
}
|
||||
|
||||
var actual dbplugin.Statements
|
||||
if err := mapstructure.Decode(resp.Data, &actual); err != nil {
|
||||
t.Fatal(err)
|
||||
actual := dbplugin.Statements{
|
||||
CreationStatements: resp.Data["creation_statements"].(string),
|
||||
RevocationStatements: resp.Data["revocation_statements"].(string),
|
||||
RollbackStatements: resp.Data["rollback_statements"].(string),
|
||||
RenewStatements: resp.Data["renew_statements"].(string),
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
@@ -609,6 +660,40 @@ func TestBackend_allowedRoles(t *testing.T) {
|
||||
t.Fatalf("expected error to be:%s got:%#v\n", logical.ErrPermissionDenied, err)
|
||||
}
|
||||
|
||||
// update connection with glob allowed roles connection
|
||||
data = map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"plugin_name": "postgresql-database-plugin",
|
||||
"allowed_roles": "allow*",
|
||||
}
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/plugin-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Get creds, should work.
|
||||
data = map[string]interface{}{}
|
||||
req = &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "creds/allowed",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
credsResp, err = b.HandleRequest(req)
|
||||
if err != nil || (credsResp != nil && credsResp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
||||
}
|
||||
|
||||
if !testCredsExist(t, credsResp, connURL) {
|
||||
t.Fatalf("Creds should exist")
|
||||
}
|
||||
|
||||
// update connection with * allowed roles connection
|
||||
data = map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
package dbplugin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/rpc"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
@@ -17,11 +15,11 @@ type DatabasePluginClient struct {
|
||||
client *plugin.Client
|
||||
sync.Mutex
|
||||
|
||||
*databasePluginRPCClient
|
||||
Database
|
||||
}
|
||||
|
||||
func (dc *DatabasePluginClient) Close() error {
|
||||
err := dc.databasePluginRPCClient.Close()
|
||||
err := dc.Database.Close()
|
||||
dc.client.Kill()
|
||||
|
||||
return err
|
||||
@@ -55,79 +53,20 @@ func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginR
|
||||
|
||||
// We should have a database type now. This feels like a normal interface
|
||||
// implementation but is in fact over an RPC connection.
|
||||
databaseRPC := raw.(*databasePluginRPCClient)
|
||||
var db Database
|
||||
switch raw.(type) {
|
||||
case *gRPCClient:
|
||||
db = raw.(*gRPCClient)
|
||||
case *databasePluginRPCClient:
|
||||
logger.Warn("database: plugin is using deprecated net RPC transport, recompile plugin to upgrade to gRPC", "plugin", pluginRunner.Name)
|
||||
db = raw.(*databasePluginRPCClient)
|
||||
default:
|
||||
return nil, errors.New("unsupported client type")
|
||||
}
|
||||
|
||||
// Wrap RPC implimentation in DatabasePluginClient
|
||||
return &DatabasePluginClient{
|
||||
client: client,
|
||||
databasePluginRPCClient: databaseRPC,
|
||||
client: client,
|
||||
Database: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ---- RPC client domain ----
|
||||
|
||||
// databasePluginRPCClient implements Database and is used on the client to
|
||||
// make RPC calls to a plugin.
|
||||
type databasePluginRPCClient struct {
|
||||
client *rpc.Client
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Type() (string, error) {
|
||||
var dbType string
|
||||
err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
|
||||
|
||||
return fmt.Sprintf("plugin-%s", dbType), err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
req := CreateUserRequest{
|
||||
Statements: statements,
|
||||
UsernameConfig: usernameConfig,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
var resp CreateUserResponse
|
||||
err = dr.client.Call("Plugin.CreateUser", req, &resp)
|
||||
|
||||
return resp.Username, resp.Password, err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) RenewUser(statements Statements, username string, expiration time.Time) error {
|
||||
req := RenewUserRequest{
|
||||
Statements: statements,
|
||||
Username: username,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username string) error {
|
||||
req := RevokeUserRequest{
|
||||
Statements: statements,
|
||||
Username: username,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}, verifyConnection bool) error {
|
||||
req := InitializeRequest{
|
||||
Config: conf,
|
||||
VerifyConnection: verifyConnection,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.Initialize", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Close() error {
|
||||
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
556
builtin/logical/database/dbplugin/database.pb.go
Normal file
556
builtin/logical/database/dbplugin/database.pb.go
Normal file
@@ -0,0 +1,556 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// source: builtin/logical/database/dbplugin/database.proto
|
||||
|
||||
/*
|
||||
Package dbplugin is a generated protocol buffer package.
|
||||
|
||||
It is generated from these files:
|
||||
builtin/logical/database/dbplugin/database.proto
|
||||
|
||||
It has these top-level messages:
|
||||
InitializeRequest
|
||||
CreateUserRequest
|
||||
RenewUserRequest
|
||||
RevokeUserRequest
|
||||
Statements
|
||||
UsernameConfig
|
||||
CreateUserResponse
|
||||
TypeResponse
|
||||
Empty
|
||||
*/
|
||||
package dbplugin
|
||||
|
||||
import proto "github.com/golang/protobuf/proto"
|
||||
import fmt "fmt"
|
||||
import math "math"
|
||||
import google_protobuf "github.com/golang/protobuf/ptypes/timestamp"
|
||||
|
||||
import (
|
||||
context "golang.org/x/net/context"
|
||||
grpc "google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
var _ = math.Inf
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
|
||||
|
||||
type InitializeRequest struct {
|
||||
Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"`
|
||||
VerifyConnection bool `protobuf:"varint,2,opt,name=verify_connection,json=verifyConnection" json:"verify_connection,omitempty"`
|
||||
}
|
||||
|
||||
func (m *InitializeRequest) Reset() { *m = InitializeRequest{} }
|
||||
func (m *InitializeRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*InitializeRequest) ProtoMessage() {}
|
||||
func (*InitializeRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
|
||||
|
||||
func (m *InitializeRequest) GetConfig() []byte {
|
||||
if m != nil {
|
||||
return m.Config
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *InitializeRequest) GetVerifyConnection() bool {
|
||||
if m != nil {
|
||||
return m.VerifyConnection
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type CreateUserRequest struct {
|
||||
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
|
||||
UsernameConfig *UsernameConfig `protobuf:"bytes,2,opt,name=username_config,json=usernameConfig" json:"username_config,omitempty"`
|
||||
Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"`
|
||||
}
|
||||
|
||||
func (m *CreateUserRequest) Reset() { *m = CreateUserRequest{} }
|
||||
func (m *CreateUserRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*CreateUserRequest) ProtoMessage() {}
|
||||
func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
|
||||
|
||||
func (m *CreateUserRequest) GetStatements() *Statements {
|
||||
if m != nil {
|
||||
return m.Statements
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *CreateUserRequest) GetUsernameConfig() *UsernameConfig {
|
||||
if m != nil {
|
||||
return m.UsernameConfig
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *CreateUserRequest) GetExpiration() *google_protobuf.Timestamp {
|
||||
if m != nil {
|
||||
return m.Expiration
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type RenewUserRequest struct {
|
||||
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
|
||||
Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"`
|
||||
Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"`
|
||||
}
|
||||
|
||||
func (m *RenewUserRequest) Reset() { *m = RenewUserRequest{} }
|
||||
func (m *RenewUserRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*RenewUserRequest) ProtoMessage() {}
|
||||
func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
|
||||
|
||||
func (m *RenewUserRequest) GetStatements() *Statements {
|
||||
if m != nil {
|
||||
return m.Statements
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *RenewUserRequest) GetUsername() string {
|
||||
if m != nil {
|
||||
return m.Username
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *RenewUserRequest) GetExpiration() *google_protobuf.Timestamp {
|
||||
if m != nil {
|
||||
return m.Expiration
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type RevokeUserRequest struct {
|
||||
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
|
||||
Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"`
|
||||
}
|
||||
|
||||
func (m *RevokeUserRequest) Reset() { *m = RevokeUserRequest{} }
|
||||
func (m *RevokeUserRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*RevokeUserRequest) ProtoMessage() {}
|
||||
func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
|
||||
|
||||
func (m *RevokeUserRequest) GetStatements() *Statements {
|
||||
if m != nil {
|
||||
return m.Statements
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *RevokeUserRequest) GetUsername() string {
|
||||
if m != nil {
|
||||
return m.Username
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type Statements struct {
|
||||
CreationStatements string `protobuf:"bytes,1,opt,name=creation_statements,json=creationStatements" json:"creation_statements,omitempty"`
|
||||
RevocationStatements string `protobuf:"bytes,2,opt,name=revocation_statements,json=revocationStatements" json:"revocation_statements,omitempty"`
|
||||
RollbackStatements string `protobuf:"bytes,3,opt,name=rollback_statements,json=rollbackStatements" json:"rollback_statements,omitempty"`
|
||||
RenewStatements string `protobuf:"bytes,4,opt,name=renew_statements,json=renewStatements" json:"renew_statements,omitempty"`
|
||||
}
|
||||
|
||||
func (m *Statements) Reset() { *m = Statements{} }
|
||||
func (m *Statements) String() string { return proto.CompactTextString(m) }
|
||||
func (*Statements) ProtoMessage() {}
|
||||
func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }
|
||||
|
||||
func (m *Statements) GetCreationStatements() string {
|
||||
if m != nil {
|
||||
return m.CreationStatements
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Statements) GetRevocationStatements() string {
|
||||
if m != nil {
|
||||
return m.RevocationStatements
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Statements) GetRollbackStatements() string {
|
||||
if m != nil {
|
||||
return m.RollbackStatements
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Statements) GetRenewStatements() string {
|
||||
if m != nil {
|
||||
return m.RenewStatements
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type UsernameConfig struct {
|
||||
DisplayName string `protobuf:"bytes,1,opt,name=DisplayName" json:"DisplayName,omitempty"`
|
||||
RoleName string `protobuf:"bytes,2,opt,name=RoleName" json:"RoleName,omitempty"`
|
||||
}
|
||||
|
||||
func (m *UsernameConfig) Reset() { *m = UsernameConfig{} }
|
||||
func (m *UsernameConfig) String() string { return proto.CompactTextString(m) }
|
||||
func (*UsernameConfig) ProtoMessage() {}
|
||||
func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} }
|
||||
|
||||
func (m *UsernameConfig) GetDisplayName() string {
|
||||
if m != nil {
|
||||
return m.DisplayName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *UsernameConfig) GetRoleName() string {
|
||||
if m != nil {
|
||||
return m.RoleName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type CreateUserResponse struct {
|
||||
Username string `protobuf:"bytes,1,opt,name=username" json:"username,omitempty"`
|
||||
Password string `protobuf:"bytes,2,opt,name=password" json:"password,omitempty"`
|
||||
}
|
||||
|
||||
func (m *CreateUserResponse) Reset() { *m = CreateUserResponse{} }
|
||||
func (m *CreateUserResponse) String() string { return proto.CompactTextString(m) }
|
||||
func (*CreateUserResponse) ProtoMessage() {}
|
||||
func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} }
|
||||
|
||||
func (m *CreateUserResponse) GetUsername() string {
|
||||
if m != nil {
|
||||
return m.Username
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *CreateUserResponse) GetPassword() string {
|
||||
if m != nil {
|
||||
return m.Password
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type TypeResponse struct {
|
||||
Type string `protobuf:"bytes,1,opt,name=type" json:"type,omitempty"`
|
||||
}
|
||||
|
||||
func (m *TypeResponse) Reset() { *m = TypeResponse{} }
|
||||
func (m *TypeResponse) String() string { return proto.CompactTextString(m) }
|
||||
func (*TypeResponse) ProtoMessage() {}
|
||||
func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} }
|
||||
|
||||
func (m *TypeResponse) GetType() string {
|
||||
if m != nil {
|
||||
return m.Type
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type Empty struct {
|
||||
}
|
||||
|
||||
func (m *Empty) Reset() { *m = Empty{} }
|
||||
func (m *Empty) String() string { return proto.CompactTextString(m) }
|
||||
func (*Empty) ProtoMessage() {}
|
||||
func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} }
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*InitializeRequest)(nil), "dbplugin.InitializeRequest")
|
||||
proto.RegisterType((*CreateUserRequest)(nil), "dbplugin.CreateUserRequest")
|
||||
proto.RegisterType((*RenewUserRequest)(nil), "dbplugin.RenewUserRequest")
|
||||
proto.RegisterType((*RevokeUserRequest)(nil), "dbplugin.RevokeUserRequest")
|
||||
proto.RegisterType((*Statements)(nil), "dbplugin.Statements")
|
||||
proto.RegisterType((*UsernameConfig)(nil), "dbplugin.UsernameConfig")
|
||||
proto.RegisterType((*CreateUserResponse)(nil), "dbplugin.CreateUserResponse")
|
||||
proto.RegisterType((*TypeResponse)(nil), "dbplugin.TypeResponse")
|
||||
proto.RegisterType((*Empty)(nil), "dbplugin.Empty")
|
||||
}
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ context.Context
|
||||
var _ grpc.ClientConn
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
const _ = grpc.SupportPackageIsVersion4
|
||||
|
||||
// Client API for Database service
|
||||
|
||||
type DatabaseClient interface {
|
||||
Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error)
|
||||
CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error)
|
||||
RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error)
|
||||
RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error)
|
||||
Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error)
|
||||
Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error)
|
||||
}
|
||||
|
||||
type databaseClient struct {
|
||||
cc *grpc.ClientConn
|
||||
}
|
||||
|
||||
func NewDatabaseClient(cc *grpc.ClientConn) DatabaseClient {
|
||||
return &databaseClient{cc}
|
||||
}
|
||||
|
||||
func (c *databaseClient) Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error) {
|
||||
out := new(TypeResponse)
|
||||
err := grpc.Invoke(ctx, "/dbplugin.Database/Type", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *databaseClient) CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) {
|
||||
out := new(CreateUserResponse)
|
||||
err := grpc.Invoke(ctx, "/dbplugin.Database/CreateUser", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *databaseClient) RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error) {
|
||||
out := new(Empty)
|
||||
err := grpc.Invoke(ctx, "/dbplugin.Database/RenewUser", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *databaseClient) RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error) {
|
||||
out := new(Empty)
|
||||
err := grpc.Invoke(ctx, "/dbplugin.Database/RevokeUser", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) {
|
||||
out := new(Empty)
|
||||
err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *databaseClient) Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) {
|
||||
out := new(Empty)
|
||||
err := grpc.Invoke(ctx, "/dbplugin.Database/Close", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Server API for Database service
|
||||
|
||||
type DatabaseServer interface {
|
||||
Type(context.Context, *Empty) (*TypeResponse, error)
|
||||
CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error)
|
||||
RenewUser(context.Context, *RenewUserRequest) (*Empty, error)
|
||||
RevokeUser(context.Context, *RevokeUserRequest) (*Empty, error)
|
||||
Initialize(context.Context, *InitializeRequest) (*Empty, error)
|
||||
Close(context.Context, *Empty) (*Empty, error)
|
||||
}
|
||||
|
||||
func RegisterDatabaseServer(s *grpc.Server, srv DatabaseServer) {
|
||||
s.RegisterService(&_Database_serviceDesc, srv)
|
||||
}
|
||||
|
||||
func _Database_Type_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(Empty)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DatabaseServer).Type(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/dbplugin.Database/Type",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DatabaseServer).Type(ctx, req.(*Empty))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Database_CreateUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(CreateUserRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DatabaseServer).CreateUser(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/dbplugin.Database/CreateUser",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DatabaseServer).CreateUser(ctx, req.(*CreateUserRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Database_RenewUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RenewUserRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DatabaseServer).RenewUser(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/dbplugin.Database/RenewUser",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DatabaseServer).RenewUser(ctx, req.(*RenewUserRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Database_RevokeUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RevokeUserRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DatabaseServer).RevokeUser(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/dbplugin.Database/RevokeUser",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DatabaseServer).RevokeUser(ctx, req.(*RevokeUserRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(InitializeRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DatabaseServer).Initialize(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/dbplugin.Database/Initialize",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Database_Close_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(Empty)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DatabaseServer).Close(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/dbplugin.Database/Close",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DatabaseServer).Close(ctx, req.(*Empty))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
var _Database_serviceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "dbplugin.Database",
|
||||
HandlerType: (*DatabaseServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "Type",
|
||||
Handler: _Database_Type_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "CreateUser",
|
||||
Handler: _Database_CreateUser_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "RenewUser",
|
||||
Handler: _Database_RenewUser_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "RevokeUser",
|
||||
Handler: _Database_RevokeUser_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "Initialize",
|
||||
Handler: _Database_Initialize_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "Close",
|
||||
Handler: _Database_Close_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "builtin/logical/database/dbplugin/database.proto",
|
||||
}
|
||||
|
||||
func init() { proto.RegisterFile("builtin/logical/database/dbplugin/database.proto", fileDescriptor0) }
|
||||
|
||||
var fileDescriptor0 = []byte{
|
||||
// 548 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x54, 0xcf, 0x6e, 0xd3, 0x4e,
|
||||
0x10, 0x96, 0xdb, 0xb4, 0xbf, 0x64, 0x5a, 0x35, 0xc9, 0xfe, 0x4a, 0x15, 0x19, 0x24, 0x22, 0x9f,
|
||||
0x5a, 0x21, 0xd9, 0xa8, 0xe5, 0x80, 0xb8, 0xa1, 0x14, 0x21, 0x24, 0x94, 0x83, 0x69, 0x25, 0x6e,
|
||||
0xd1, 0xda, 0x99, 0x44, 0xab, 0x3a, 0xbb, 0xc6, 0xbb, 0x4e, 0x09, 0x4f, 0xc3, 0xe3, 0x70, 0xe2,
|
||||
0x1d, 0x78, 0x13, 0xe4, 0x75, 0xd6, 0xbb, 0xf9, 0x73, 0xab, 0xb8, 0x79, 0x66, 0xbe, 0x6f, 0xf6,
|
||||
0xf3, 0xb7, 0x33, 0x0b, 0xaf, 0x93, 0x92, 0x65, 0x8a, 0xf1, 0x28, 0x13, 0x73, 0x96, 0xd2, 0x2c,
|
||||
0x9a, 0x52, 0x45, 0x13, 0x2a, 0x31, 0x9a, 0x26, 0x79, 0x56, 0xce, 0x19, 0x6f, 0x32, 0x61, 0x5e,
|
||||
0x08, 0x25, 0x48, 0xdb, 0x14, 0xfc, 0x97, 0x73, 0x21, 0xe6, 0x19, 0x46, 0x3a, 0x9f, 0x94, 0xb3,
|
||||
0x48, 0xb1, 0x05, 0x4a, 0x45, 0x17, 0x79, 0x0d, 0x0d, 0xbe, 0x42, 0xff, 0x13, 0x67, 0x8a, 0xd1,
|
||||
0x8c, 0xfd, 0xc0, 0x18, 0xbf, 0x95, 0x28, 0x15, 0xb9, 0x80, 0xe3, 0x54, 0xf0, 0x19, 0x9b, 0x0f,
|
||||
0xbc, 0xa1, 0x77, 0x79, 0x1a, 0xaf, 0x23, 0xf2, 0x0a, 0xfa, 0x4b, 0x2c, 0xd8, 0x6c, 0x35, 0x49,
|
||||
0x05, 0xe7, 0x98, 0x2a, 0x26, 0xf8, 0xe0, 0x60, 0xe8, 0x5d, 0xb6, 0xe3, 0x5e, 0x5d, 0x18, 0x35,
|
||||
0xf9, 0xe0, 0x97, 0x07, 0xfd, 0x51, 0x81, 0x54, 0xe1, 0xbd, 0xc4, 0xc2, 0xb4, 0x7e, 0x03, 0x20,
|
||||
0x15, 0x55, 0xb8, 0x40, 0xae, 0xa4, 0x6e, 0x7f, 0x72, 0x7d, 0x1e, 0x1a, 0xbd, 0xe1, 0x97, 0xa6,
|
||||
0x16, 0x3b, 0x38, 0xf2, 0x1e, 0xba, 0xa5, 0xc4, 0x82, 0xd3, 0x05, 0x4e, 0xd6, 0xca, 0x0e, 0x34,
|
||||
0x75, 0x60, 0xa9, 0xf7, 0x6b, 0xc0, 0x48, 0xd7, 0xe3, 0xb3, 0x72, 0x23, 0x26, 0xef, 0x00, 0xf0,
|
||||
0x7b, 0xce, 0x0a, 0xaa, 0x45, 0x1f, 0x6a, 0xb6, 0x1f, 0xd6, 0xf6, 0x84, 0xc6, 0x9e, 0xf0, 0xce,
|
||||
0xd8, 0x13, 0x3b, 0xe8, 0xe0, 0xa7, 0x07, 0xbd, 0x18, 0x39, 0x3e, 0x3e, 0xfd, 0x4f, 0x7c, 0x68,
|
||||
0x1b, 0x61, 0xfa, 0x17, 0x3a, 0x71, 0x13, 0x3f, 0x49, 0x22, 0x42, 0x3f, 0xc6, 0xa5, 0x78, 0xc0,
|
||||
0x7f, 0x2a, 0x31, 0xf8, 0xed, 0x01, 0x58, 0x1a, 0x89, 0xe0, 0xff, 0xb4, 0xba, 0x62, 0x26, 0xf8,
|
||||
0x64, 0xeb, 0xa4, 0x4e, 0x4c, 0x4c, 0xc9, 0x21, 0xdc, 0xc0, 0xb3, 0x02, 0x97, 0x22, 0xdd, 0xa1,
|
||||
0xd4, 0x07, 0x9d, 0xdb, 0xe2, 0xe6, 0x29, 0x85, 0xc8, 0xb2, 0x84, 0xa6, 0x0f, 0x2e, 0xe5, 0xb0,
|
||||
0x3e, 0xc5, 0x94, 0x1c, 0xc2, 0x15, 0xf4, 0x8a, 0xea, 0xba, 0x5c, 0x74, 0x4b, 0xa3, 0xbb, 0x3a,
|
||||
0x6f, 0xa1, 0xc1, 0x18, 0xce, 0x36, 0x07, 0x87, 0x0c, 0xe1, 0xe4, 0x96, 0xc9, 0x3c, 0xa3, 0xab,
|
||||
0x71, 0xe5, 0x40, 0xfd, 0x2f, 0x6e, 0xaa, 0x32, 0x28, 0x16, 0x19, 0x8e, 0x1d, 0x83, 0x4c, 0x1c,
|
||||
0x7c, 0x06, 0xe2, 0x0e, 0xbd, 0xcc, 0x05, 0x97, 0xb8, 0x61, 0xa9, 0xb7, 0x75, 0xeb, 0x3e, 0xb4,
|
||||
0x73, 0x2a, 0xe5, 0xa3, 0x28, 0xa6, 0xa6, 0x9b, 0x89, 0x83, 0x00, 0x4e, 0xef, 0x56, 0x39, 0x36,
|
||||
0x7d, 0x08, 0xb4, 0xd4, 0x2a, 0x37, 0x3d, 0xf4, 0x77, 0xf0, 0x1f, 0x1c, 0x7d, 0x58, 0xe4, 0x6a,
|
||||
0x75, 0xfd, 0xe7, 0x00, 0xda, 0xb7, 0xeb, 0x87, 0x80, 0x44, 0xd0, 0xaa, 0x98, 0xa4, 0x6b, 0xaf,
|
||||
0x5b, 0xa3, 0xfc, 0x0b, 0x9b, 0xd8, 0x68, 0xfd, 0x11, 0xc0, 0x0a, 0x27, 0xcf, 0x2d, 0x6a, 0x67,
|
||||
0x87, 0xfd, 0x17, 0xfb, 0x8b, 0xeb, 0x46, 0x6f, 0xa1, 0xd3, 0xec, 0x0a, 0xf1, 0x2d, 0x74, 0x7b,
|
||||
0x81, 0xfc, 0x6d, 0x69, 0xd5, 0xfc, 0xdb, 0x19, 0x76, 0x25, 0xec, 0x4c, 0xf6, 0x5e, 0xae, 0x7d,
|
||||
0xc7, 0x5c, 0xee, 0xce, 0xeb, 0xb6, 0xcb, 0xbd, 0x82, 0xa3, 0x51, 0x26, 0xe4, 0x1e, 0xb3, 0xb6,
|
||||
0x13, 0xc9, 0xb1, 0x5e, 0xc3, 0x9b, 0xbf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x8c, 0x55, 0x84, 0x56,
|
||||
0x94, 0x05, 0x00, 0x00,
|
||||
}
|
||||
58
builtin/logical/database/dbplugin/database.proto
Normal file
58
builtin/logical/database/dbplugin/database.proto
Normal file
@@ -0,0 +1,58 @@
|
||||
syntax = "proto3";
|
||||
package dbplugin;
|
||||
|
||||
import "google/protobuf/timestamp.proto";
|
||||
|
||||
message InitializeRequest {
|
||||
bytes config = 1;
|
||||
bool verify_connection = 2;
|
||||
}
|
||||
|
||||
message CreateUserRequest {
|
||||
Statements statements = 1;
|
||||
UsernameConfig username_config = 2;
|
||||
google.protobuf.Timestamp expiration = 3;
|
||||
}
|
||||
|
||||
message RenewUserRequest {
|
||||
Statements statements = 1;
|
||||
string username = 2;
|
||||
google.protobuf.Timestamp expiration = 3;
|
||||
}
|
||||
|
||||
message RevokeUserRequest {
|
||||
Statements statements = 1;
|
||||
string username = 2;
|
||||
}
|
||||
|
||||
message Statements {
|
||||
string creation_statements = 1;
|
||||
string revocation_statements = 2;
|
||||
string rollback_statements = 3;
|
||||
string renew_statements = 4;
|
||||
}
|
||||
|
||||
message UsernameConfig {
|
||||
string DisplayName = 1;
|
||||
string RoleName = 2;
|
||||
}
|
||||
|
||||
message CreateUserResponse {
|
||||
string username = 1;
|
||||
string password = 2;
|
||||
}
|
||||
|
||||
message TypeResponse {
|
||||
string type = 1;
|
||||
}
|
||||
|
||||
message Empty {}
|
||||
|
||||
service Database {
|
||||
rpc Type(Empty) returns (TypeResponse);
|
||||
rpc CreateUser(CreateUserRequest) returns (CreateUserResponse);
|
||||
rpc RenewUser(RenewUserRequest) returns (Empty);
|
||||
rpc RevokeUser(RevokeUserRequest) returns (Empty);
|
||||
rpc Initialize(InitializeRequest) returns (Empty);
|
||||
rpc Close(Empty) returns (Empty);
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package dbplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
metrics "github.com/armon/go-metrics"
|
||||
@@ -15,55 +16,56 @@ type databaseTracingMiddleware struct {
|
||||
next Database
|
||||
logger log.Logger
|
||||
|
||||
typeStr string
|
||||
typeStr string
|
||||
transport string
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) Type() (string, error) {
|
||||
return mw.next.Type()
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (mw *databaseTracingMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr)
|
||||
return mw.next.CreateUser(statements, usernameConfig, expiration)
|
||||
mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr, "transport", mw.transport)
|
||||
return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
|
||||
func (mw *databaseTracingMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr)
|
||||
return mw.next.RenewUser(statements, username, expiration)
|
||||
mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr, "transport", mw.transport)
|
||||
return mw.next.RenewUser(ctx, statements, username, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) {
|
||||
func (mw *databaseTracingMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr)
|
||||
return mw.next.RevokeUser(statements, username)
|
||||
mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr, "transport", mw.transport)
|
||||
return mw.next.RevokeUser(ctx, statements, username)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) {
|
||||
func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then))
|
||||
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "verify", verifyConnection, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr)
|
||||
return mw.next.Initialize(conf, verifyConnection)
|
||||
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr, "transport", mw.transport)
|
||||
return mw.next.Initialize(ctx, conf, verifyConnection)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) Close() (err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr)
|
||||
mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr, "transport", mw.transport)
|
||||
return mw.next.Close()
|
||||
}
|
||||
|
||||
@@ -81,7 +83,7 @@ func (mw *databaseMetricsMiddleware) Type() (string, error) {
|
||||
return mw.next.Type()
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (mw *databaseMetricsMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "CreateUser"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now)
|
||||
@@ -94,10 +96,10 @@ func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameC
|
||||
|
||||
metrics.IncrCounter([]string{"database", "CreateUser"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1)
|
||||
return mw.next.CreateUser(statements, usernameConfig, expiration)
|
||||
return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
|
||||
func (mw *databaseMetricsMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "RenewUser"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now)
|
||||
@@ -110,10 +112,10 @@ func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username s
|
||||
|
||||
metrics.IncrCounter([]string{"database", "RenewUser"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1)
|
||||
return mw.next.RenewUser(statements, username, expiration)
|
||||
return mw.next.RenewUser(ctx, statements, username, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username string) (err error) {
|
||||
func (mw *databaseMetricsMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "RevokeUser"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now)
|
||||
@@ -126,10 +128,10 @@ func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username
|
||||
|
||||
metrics.IncrCounter([]string{"database", "RevokeUser"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1)
|
||||
return mw.next.RevokeUser(statements, username)
|
||||
return mw.next.RevokeUser(ctx, statements, username)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) {
|
||||
func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "Initialize"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now)
|
||||
@@ -142,7 +144,7 @@ func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, ver
|
||||
|
||||
metrics.IncrCounter([]string{"database", "Initialize"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)
|
||||
return mw.next.Initialize(conf, verifyConnection)
|
||||
return mw.next.Initialize(ctx, conf, verifyConnection)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) Close() (err error) {
|
||||
|
||||
198
builtin/logical/database/dbplugin/grpc_transport.go
Normal file
198
builtin/logical/database/dbplugin/grpc_transport.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package dbplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPluginShutdown = errors.New("plugin shutdown")
|
||||
)
|
||||
|
||||
// ---- gRPC Server domain ----
|
||||
|
||||
type gRPCServer struct {
|
||||
impl Database
|
||||
}
|
||||
|
||||
func (s *gRPCServer) Type(context.Context, *Empty) (*TypeResponse, error) {
|
||||
t, err := s.impl.Type()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TypeResponse{
|
||||
Type: t,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *gRPCServer) CreateUser(ctx context.Context, req *CreateUserRequest) (*CreateUserResponse, error) {
|
||||
e, err := ptypes.Timestamp(req.Expiration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u, p, err := s.impl.CreateUser(ctx, *req.Statements, *req.UsernameConfig, e)
|
||||
|
||||
return &CreateUserResponse{
|
||||
Username: u,
|
||||
Password: p,
|
||||
}, err
|
||||
}
|
||||
|
||||
func (s *gRPCServer) RenewUser(ctx context.Context, req *RenewUserRequest) (*Empty, error) {
|
||||
e, err := ptypes.Timestamp(req.Expiration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = s.impl.RenewUser(ctx, *req.Statements, req.Username, e)
|
||||
return &Empty{}, err
|
||||
}
|
||||
|
||||
func (s *gRPCServer) RevokeUser(ctx context.Context, req *RevokeUserRequest) (*Empty, error) {
|
||||
err := s.impl.RevokeUser(ctx, *req.Statements, req.Username)
|
||||
return &Empty{}, err
|
||||
}
|
||||
|
||||
func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) {
|
||||
config := map[string]interface{}{}
|
||||
|
||||
err := json.Unmarshal(req.Config, &config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.impl.Initialize(ctx, config, req.VerifyConnection)
|
||||
return &Empty{}, err
|
||||
}
|
||||
|
||||
func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) {
|
||||
s.impl.Close()
|
||||
return &Empty{}, nil
|
||||
}
|
||||
|
||||
// ---- gRPC client domain ----
|
||||
|
||||
type gRPCClient struct {
|
||||
client DatabaseClient
|
||||
clientConn *grpc.ClientConn
|
||||
}
|
||||
|
||||
func (c gRPCClient) Type() (string, error) {
|
||||
// If the plugin has already shutdown, this will hang forever so we give it
|
||||
// a one second timeout.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
switch c.clientConn.GetState() {
|
||||
case connectivity.Ready, connectivity.Idle:
|
||||
default:
|
||||
return "", ErrPluginShutdown
|
||||
}
|
||||
resp, err := c.client.Type(ctx, &Empty{})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return resp.Type, err
|
||||
}
|
||||
|
||||
func (c gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
t, err := ptypes.TimestampProto(expiration)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
switch c.clientConn.GetState() {
|
||||
case connectivity.Ready, connectivity.Idle:
|
||||
default:
|
||||
return "", "", ErrPluginShutdown
|
||||
}
|
||||
|
||||
resp, err := c.client.CreateUser(ctx, &CreateUserRequest{
|
||||
Statements: &statements,
|
||||
UsernameConfig: &usernameConfig,
|
||||
Expiration: t,
|
||||
})
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return resp.Username, resp.Password, err
|
||||
}
|
||||
|
||||
func (c *gRPCClient) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error {
|
||||
t, err := ptypes.TimestampProto(expiration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch c.clientConn.GetState() {
|
||||
case connectivity.Ready, connectivity.Idle:
|
||||
default:
|
||||
return ErrPluginShutdown
|
||||
}
|
||||
|
||||
_, err = c.client.RenewUser(ctx, &RenewUserRequest{
|
||||
Statements: &statements,
|
||||
Username: username,
|
||||
Expiration: t,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *gRPCClient) RevokeUser(ctx context.Context, statements Statements, username string) error {
|
||||
switch c.clientConn.GetState() {
|
||||
case connectivity.Ready, connectivity.Idle:
|
||||
default:
|
||||
return ErrPluginShutdown
|
||||
}
|
||||
_, err := c.client.RevokeUser(ctx, &RevokeUserRequest{
|
||||
Statements: &statements,
|
||||
Username: username,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *gRPCClient) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error {
|
||||
configRaw, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch c.clientConn.GetState() {
|
||||
case connectivity.Ready, connectivity.Idle:
|
||||
default:
|
||||
return ErrPluginShutdown
|
||||
}
|
||||
|
||||
_, err = c.client.Initialize(ctx, &InitializeRequest{
|
||||
Config: configRaw,
|
||||
VerifyConnection: verifyConnection,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *gRPCClient) Close() error {
|
||||
// If the plugin has already shutdown, this will hang forever so we give it
|
||||
// a one second timeout.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
switch c.clientConn.GetState() {
|
||||
case connectivity.Ready, connectivity.Idle:
|
||||
_, err := c.client.Close(ctx, &Empty{})
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
139
builtin/logical/database/dbplugin/netrpc_transport.go
Normal file
139
builtin/logical/database/dbplugin/netrpc_transport.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package dbplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/rpc"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---- RPC server domain ----
|
||||
|
||||
// databasePluginRPCServer implements an RPC version of Database and is run
|
||||
// inside a plugin. It wraps an underlying implementation of Database.
|
||||
type databasePluginRPCServer struct {
|
||||
impl Database
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
|
||||
var err error
|
||||
*resp, err = ds.impl.Type()
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequestRPC, resp *CreateUserResponse) error {
|
||||
var err error
|
||||
resp.Username, resp.Password, err = ds.impl.CreateUser(context.Background(), args.Statements, args.UsernameConfig, args.Expiration)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequestRPC, _ *struct{}) error {
|
||||
err := ds.impl.RenewUser(context.Background(), args.Statements, args.Username, args.Expiration)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequestRPC, _ *struct{}) error {
|
||||
err := ds.impl.RevokeUser(context.Background(), args.Statements, args.Username)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error {
|
||||
err := ds.impl.Initialize(context.Background(), args.Config, args.VerifyConnection)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
|
||||
ds.impl.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---- RPC client domain ----
|
||||
// databasePluginRPCClient implements Database and is used on the client to
|
||||
// make RPC calls to a plugin.
|
||||
type databasePluginRPCClient struct {
|
||||
client *rpc.Client
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Type() (string, error) {
|
||||
var dbType string
|
||||
err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
|
||||
|
||||
return fmt.Sprintf("plugin-%s", dbType), err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) CreateUser(_ context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
req := CreateUserRequestRPC{
|
||||
Statements: statements,
|
||||
UsernameConfig: usernameConfig,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
var resp CreateUserResponse
|
||||
err = dr.client.Call("Plugin.CreateUser", req, &resp)
|
||||
|
||||
return resp.Username, resp.Password, err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) RenewUser(_ context.Context, statements Statements, username string, expiration time.Time) error {
|
||||
req := RenewUserRequestRPC{
|
||||
Statements: statements,
|
||||
Username: username,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error {
|
||||
req := RevokeUserRequestRPC{
|
||||
Statements: statements,
|
||||
Username: username,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||
req := InitializeRequestRPC{
|
||||
Config: conf,
|
||||
VerifyConnection: verifyConnection,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.Initialize", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Close() error {
|
||||
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ---- RPC Request Args Domain ----
|
||||
|
||||
type InitializeRequestRPC struct {
|
||||
Config map[string]interface{}
|
||||
VerifyConnection bool
|
||||
}
|
||||
|
||||
type CreateUserRequestRPC struct {
|
||||
Statements Statements
|
||||
UsernameConfig UsernameConfig
|
||||
Expiration time.Time
|
||||
}
|
||||
|
||||
type RenewUserRequestRPC struct {
|
||||
Statements Statements
|
||||
Username string
|
||||
Expiration time.Time
|
||||
}
|
||||
|
||||
type RevokeUserRequestRPC struct {
|
||||
Statements Statements
|
||||
Username string
|
||||
}
|
||||
@@ -1,10 +1,13 @@
|
||||
package dbplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/rpc"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
@@ -13,29 +16,14 @@ import (
|
||||
// Database is the interface that all database objects must implement.
|
||||
type Database interface {
|
||||
Type() (string, error)
|
||||
CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error)
|
||||
RenewUser(statements Statements, username string, expiration time.Time) error
|
||||
RevokeUser(statements Statements, username string) error
|
||||
CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error)
|
||||
RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error
|
||||
RevokeUser(ctx context.Context, statements Statements, username string) error
|
||||
|
||||
Initialize(config map[string]interface{}, verifyConnection bool) error
|
||||
Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Statements set in role creation and passed into the database type's functions.
|
||||
type Statements struct {
|
||||
CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"`
|
||||
RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"`
|
||||
RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"`
|
||||
RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"`
|
||||
}
|
||||
|
||||
// UsernameConfig is used to configure prefixes for the username to be
|
||||
// generated.
|
||||
type UsernameConfig struct {
|
||||
DisplayName string
|
||||
RoleName string
|
||||
}
|
||||
|
||||
// PluginFactory is used to build plugin database types. It wraps the database
|
||||
// object in a logging and metrics middleware.
|
||||
func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
|
||||
@@ -45,6 +33,7 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var transport string
|
||||
var db Database
|
||||
if pluginRunner.Builtin {
|
||||
// Plugin is builtin so we can retrieve an instance of the interface
|
||||
@@ -60,12 +49,24 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
|
||||
return nil, fmt.Errorf("unsuported database type: %s", pluginName)
|
||||
}
|
||||
|
||||
transport = "builtin"
|
||||
|
||||
} else {
|
||||
// create a DatabasePluginClient instance
|
||||
db, err = newPluginClient(sys, pluginRunner, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Switch on the underlying database client type to get the transport
|
||||
// method.
|
||||
switch db.(*DatabasePluginClient).Database.(type) {
|
||||
case *gRPCClient:
|
||||
transport = "gRPC"
|
||||
case *databasePluginRPCClient:
|
||||
transport = "netRPC"
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
typeStr, err := db.Type()
|
||||
@@ -82,9 +83,10 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
|
||||
// Wrap with tracing middleware
|
||||
if logger.IsTrace() {
|
||||
db = &databaseTracingMiddleware{
|
||||
next: db,
|
||||
typeStr: typeStr,
|
||||
logger: logger,
|
||||
transport: transport,
|
||||
next: db,
|
||||
typeStr: typeStr,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,33 +117,14 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e
|
||||
return &databasePluginRPCClient{client: c}, nil
|
||||
}
|
||||
|
||||
// ---- RPC Request Args Domain ----
|
||||
|
||||
type InitializeRequest struct {
|
||||
Config map[string]interface{}
|
||||
VerifyConnection bool
|
||||
func (d DatabasePlugin) GRPCServer(s *grpc.Server) error {
|
||||
RegisterDatabaseServer(s, &gRPCServer{impl: d.impl})
|
||||
return nil
|
||||
}
|
||||
|
||||
type CreateUserRequest struct {
|
||||
Statements Statements
|
||||
UsernameConfig UsernameConfig
|
||||
Expiration time.Time
|
||||
}
|
||||
|
||||
type RenewUserRequest struct {
|
||||
Statements Statements
|
||||
Username string
|
||||
Expiration time.Time
|
||||
}
|
||||
|
||||
type RevokeUserRequest struct {
|
||||
Statements Statements
|
||||
Username string
|
||||
}
|
||||
|
||||
// ---- RPC Response Args Domain ----
|
||||
|
||||
type CreateUserResponse struct {
|
||||
Username string
|
||||
Password string
|
||||
func (DatabasePlugin) GRPCClient(c *grpc.ClientConn) (interface{}, error) {
|
||||
return &gRPCClient{
|
||||
client: NewDatabaseClient(c),
|
||||
clientConn: c,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package dbplugin_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
vaulthttp "github.com/hashicorp/vault/http"
|
||||
@@ -20,7 +22,7 @@ type mockPlugin struct {
|
||||
}
|
||||
|
||||
func (m *mockPlugin) Type() (string, error) { return "mock", nil }
|
||||
func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (m *mockPlugin) CreateUser(_ context.Context, statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
err = errors.New("err")
|
||||
if usernameConf.DisplayName == "" || expiration.IsZero() {
|
||||
return "", "", err
|
||||
@@ -34,7 +36,7 @@ func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbp
|
||||
|
||||
return usernameConf.DisplayName, "test", nil
|
||||
}
|
||||
func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
func (m *mockPlugin) RenewUser(_ context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
err := errors.New("err")
|
||||
if username == "" || expiration.IsZero() {
|
||||
return err
|
||||
@@ -46,7 +48,7 @@ func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string,
|
||||
|
||||
return nil
|
||||
}
|
||||
func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
func (m *mockPlugin) RevokeUser(_ context.Context, statements dbplugin.Statements, username string) error {
|
||||
err := errors.New("err")
|
||||
if username == "" {
|
||||
return err
|
||||
@@ -59,7 +61,7 @@ func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string)
|
||||
delete(m.users, username)
|
||||
return nil
|
||||
}
|
||||
func (m *mockPlugin) Initialize(conf map[string]interface{}, _ bool) error {
|
||||
func (m *mockPlugin) Initialize(_ context.Context, conf map[string]interface{}, _ bool) error {
|
||||
err := errors.New("err")
|
||||
if len(conf) != 1 {
|
||||
return err
|
||||
@@ -80,14 +82,15 @@ func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) {
|
||||
cores := cluster.Cores
|
||||
|
||||
sys := vault.TestDynamicSystemView(cores[0].Core)
|
||||
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_Main")
|
||||
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_GRPC_Main")
|
||||
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin-netRPC", "TestPlugin_NetRPC_Main")
|
||||
|
||||
return cluster, sys
|
||||
}
|
||||
|
||||
// This is not an actual test case, it's a helper function that will be executed
|
||||
// by the go-plugin client via an exec call.
|
||||
func TestPlugin_Main(t *testing.T) {
|
||||
func TestPlugin_GRPC_Main(t *testing.T) {
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
||||
return
|
||||
}
|
||||
@@ -105,6 +108,30 @@ func TestPlugin_Main(t *testing.T) {
|
||||
plugins.Serve(plugin, apiClientMeta.GetTLSConfig())
|
||||
}
|
||||
|
||||
// This is not an actual test case, it's a helper function that will be executed
|
||||
// by the go-plugin client via an exec call.
|
||||
func TestPlugin_NetRPC_Main(t *testing.T) {
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
p := &mockPlugin{
|
||||
users: make(map[string][]string),
|
||||
}
|
||||
|
||||
args := []string{"--tls-skip-verify=true"}
|
||||
|
||||
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
|
||||
tlsProvider := pluginutil.VaultPluginTLSProvider(apiClientMeta.GetTLSConfig())
|
||||
serveConf := dbplugin.ServeConfig(p, tlsProvider)
|
||||
serveConf.GRPCServer = nil
|
||||
|
||||
plugin.Serve(serveConf)
|
||||
}
|
||||
|
||||
func TestPlugin_Initialize(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
@@ -118,7 +145,7 @@ func TestPlugin_Initialize(t *testing.T) {
|
||||
"test": 1,
|
||||
}
|
||||
|
||||
err = dbRaw.Initialize(connectionDetails, true)
|
||||
err = dbRaw.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -143,7 +170,7 @@ func TestPlugin_CreateUser(t *testing.T) {
|
||||
"test": 1,
|
||||
}
|
||||
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -153,7 +180,7 @@ func TestPlugin_CreateUser(t *testing.T) {
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, pw, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -163,7 +190,7 @@ func TestPlugin_CreateUser(t *testing.T) {
|
||||
|
||||
// try and save the same user again to verify it saved the first time, this
|
||||
// should return an error
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("expected an error, user wasn't created correctly")
|
||||
}
|
||||
@@ -182,7 +209,7 @@ func TestPlugin_RenewUser(t *testing.T) {
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -192,12 +219,12 @@ func TestPlugin_RenewUser(t *testing.T) {
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(dbplugin.Statements{}, us, time.Now().Add(time.Minute))
|
||||
err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -216,7 +243,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -226,19 +253,159 @@ func TestPlugin_RevokeUser(t *testing.T) {
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(dbplugin.Statements{}, us)
|
||||
err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Try adding the same username back so we can verify it was removed
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test the code is still compatible with an old netRPC plugin
|
||||
func TestPlugin_NetRPC_Initialize(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
dbRaw, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
|
||||
err = dbRaw.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
err = dbRaw.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlugin_NetRPC_CreateUser(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
usernameConf := dbplugin.UsernameConfig{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if us != "test" || pw != "test" {
|
||||
t.Fatal("expected username and password to be 'test'")
|
||||
}
|
||||
|
||||
// try and save the same user again to verify it saved the first time, this
|
||||
// should return an error
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("expected an error, user wasn't created correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlugin_NetRPC_RenewUser(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
usernameConf := dbplugin.UsernameConfig{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlugin_NetRPC_RevokeUser(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
usernameConf := dbplugin.UsernameConfig{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Try adding the same username back so we can verify it was removed
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,10 @@ import (
|
||||
// Database implementation in a databasePluginRPCServer object and starts a
|
||||
// RPC server.
|
||||
func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
|
||||
plugin.Serve(ServeConfig(db, tlsProvider))
|
||||
}
|
||||
|
||||
func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig {
|
||||
dbPlugin := &DatabasePlugin{
|
||||
impl: db,
|
||||
}
|
||||
@@ -19,53 +23,10 @@ func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
|
||||
"database": dbPlugin,
|
||||
}
|
||||
|
||||
plugin.Serve(&plugin.ServeConfig{
|
||||
return &plugin.ServeConfig{
|
||||
HandshakeConfig: handshakeConfig,
|
||||
Plugins: pluginMap,
|
||||
TLSProvider: tlsProvider,
|
||||
})
|
||||
}
|
||||
|
||||
// ---- RPC server domain ----
|
||||
|
||||
// databasePluginRPCServer implements an RPC version of Database and is run
|
||||
// inside a plugin. It wraps an underlying implementation of Database.
|
||||
type databasePluginRPCServer struct {
|
||||
impl Database
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
|
||||
var err error
|
||||
*resp, err = ds.impl.Type()
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error {
|
||||
var err error
|
||||
resp.Username, resp.Password, err = ds.impl.CreateUser(args.Statements, args.UsernameConfig, args.Expiration)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequest, _ *struct{}) error {
|
||||
err := ds.impl.RenewUser(args.Statements, args.Username, args.Expiration)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct{}) error {
|
||||
err := ds.impl.RevokeUser(args.Statements, args.Username)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequest, _ *struct{}) error {
|
||||
err := ds.impl.Initialize(args.Config, args.VerifyConnection)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
|
||||
ds.impl.Close()
|
||||
return nil
|
||||
GRPCServer: plugin.DefaultGRPCServer,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
@@ -62,7 +63,7 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc {
|
||||
b.clearConnection(name)
|
||||
|
||||
// Execute plugin again, we don't need the object so throw away.
|
||||
_, err := b.createDBObj(req.Storage, name)
|
||||
_, err := b.createDBObj(context.TODO(), req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -230,7 +231,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
|
||||
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
|
||||
}
|
||||
|
||||
err = db.Initialize(config.ConnectionDetails, verifyConnection)
|
||||
err = db.Initialize(context.TODO(), config.ConnectionDetails, verifyConnection)
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -49,7 +50,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
|
||||
|
||||
// If role name isn't in the database's allowed roles, send back a
|
||||
// permission denied.
|
||||
if !strutil.StrListContains(dbConfig.AllowedRoles, "*") && !strutil.StrListContains(dbConfig.AllowedRoles, name) {
|
||||
if !strutil.StrListContains(dbConfig.AllowedRoles, "*") && !strutil.StrListContainsGlob(dbConfig.AllowedRoles, name) {
|
||||
return nil, logical.ErrPermissionDenied
|
||||
}
|
||||
|
||||
@@ -66,7 +67,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
|
||||
unlockFunc = b.Unlock
|
||||
|
||||
// Create a new DB object
|
||||
db, err = b.createDBObj(req.Storage, role.DBName)
|
||||
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
unlockFunc()
|
||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||
@@ -81,7 +82,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
|
||||
}
|
||||
|
||||
// Create the user
|
||||
username, password, err := db.CreateUser(role.Statements, usernameConfig, expiration)
|
||||
username, password, err := db.CreateUser(context.TODO(), role.Statements, usernameConfig, expiration)
|
||||
// Unlock
|
||||
unlockFunc()
|
||||
if err != nil {
|
||||
|
||||
@@ -181,7 +181,7 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc {
|
||||
|
||||
type roleEntry struct {
|
||||
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
|
||||
Statements dbplugin.Statements `json:"statments" mapstructure:"statements" structs:"statments"`
|
||||
Statements dbplugin.Statements `json:"statements" mapstructure:"statements" structs:"statements"`
|
||||
DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"`
|
||||
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"`
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
@@ -60,7 +61,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
|
||||
unlockFunc = b.Unlock
|
||||
|
||||
// Create a new DB object
|
||||
db, err = b.createDBObj(req.Storage, role.DBName)
|
||||
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
unlockFunc()
|
||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||
@@ -69,7 +70,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
|
||||
|
||||
// Make sure we increase the VALID UNTIL endpoint for this user.
|
||||
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
|
||||
err := db.RenewUser(role.Statements, username, expireTime)
|
||||
err := db.RenewUser(context.TODO(), role.Statements, username, expireTime)
|
||||
// Unlock
|
||||
unlockFunc()
|
||||
if err != nil {
|
||||
@@ -119,14 +120,14 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc {
|
||||
unlockFunc = b.Unlock
|
||||
|
||||
// Create a new DB object
|
||||
db, err = b.createDBObj(req.Storage, role.DBName)
|
||||
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
unlockFunc()
|
||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = db.RevokeUser(role.Statements, username)
|
||||
err = db.RevokeUser(context.TODO(), role.Statements, username)
|
||||
// Unlock
|
||||
unlockFunc()
|
||||
if err != nil {
|
||||
|
||||
@@ -24,6 +24,12 @@ func Backend() *framework.Backend {
|
||||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
SealWrapStorage: []string{
|
||||
"config/connection",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigConnection(&b),
|
||||
pathConfigLease(&b),
|
||||
|
||||
@@ -24,6 +24,12 @@ func Backend() *backend {
|
||||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
SealWrapStorage: []string{
|
||||
"config/connection",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigConnection(&b),
|
||||
pathConfigLease(&b),
|
||||
|
||||
@@ -24,6 +24,12 @@ func Backend() *backend {
|
||||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
SealWrapStorage: []string{
|
||||
"config/connection",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigConnection(&b),
|
||||
pathConfigLease(&b),
|
||||
|
||||
68
builtin/logical/nomad/backend.go
Normal file
68
builtin/logical/nomad/backend.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package nomad
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/nomad/api"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
if err := b.Setup(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func Backend() *backend {
|
||||
var b backend
|
||||
b.Backend = &framework.Backend{
|
||||
PathsSpecial: &logical.Paths{
|
||||
SealWrapStorage: []string{
|
||||
"config/access",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigAccess(&b),
|
||||
pathConfigLease(&b),
|
||||
pathListRoles(&b),
|
||||
pathRoles(&b),
|
||||
pathCredsCreate(&b),
|
||||
},
|
||||
|
||||
Secrets: []*framework.Secret{
|
||||
secretToken(&b),
|
||||
},
|
||||
BackendType: logical.TypeLogical,
|
||||
}
|
||||
|
||||
return &b
|
||||
}
|
||||
|
||||
type backend struct {
|
||||
*framework.Backend
|
||||
}
|
||||
|
||||
func (b *backend) client(s logical.Storage) (*api.Client, error) {
|
||||
conf, err := b.readConfigAccess(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nomadConf := api.DefaultConfig()
|
||||
if conf != nil {
|
||||
if conf.Address != "" {
|
||||
nomadConf.Address = conf.Address
|
||||
}
|
||||
if conf.Token != "" {
|
||||
nomadConf.SecretID = conf.Token
|
||||
}
|
||||
}
|
||||
|
||||
client, err := api.NewClient(nomadConf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
302
builtin/logical/nomad/backend_test.go
Normal file
302
builtin/logical/nomad/backend_test.go
Normal file
@@ -0,0 +1,302 @@
|
||||
package nomad
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
nomadapi "github.com/hashicorp/nomad/api"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
dockertest "gopkg.in/ory-am/dockertest.v3"
|
||||
)
|
||||
|
||||
func prepareTestContainer(t *testing.T) (cleanup func(), retAddress string, nomadToken string) {
|
||||
nomadToken = os.Getenv("NOMAD_TOKEN")
|
||||
|
||||
retAddress = os.Getenv("NOMAD_ADDR")
|
||||
|
||||
if retAddress != "" {
|
||||
return func() {}, retAddress, nomadToken
|
||||
}
|
||||
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to docker: %s", err)
|
||||
}
|
||||
|
||||
dockerOptions := &dockertest.RunOptions{
|
||||
Repository: "djenriquez/nomad",
|
||||
Tag: "latest",
|
||||
Cmd: []string{"agent", "-dev"},
|
||||
Env: []string{`NOMAD_LOCAL_CONFIG=bind_addr = "0.0.0.0" acl { enabled = true }`},
|
||||
}
|
||||
resource, err := pool.RunWithOptions(dockerOptions)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not start local Nomad docker container: %s", err)
|
||||
}
|
||||
|
||||
cleanup = func() {
|
||||
err := pool.Purge(resource)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to cleanup local container: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
retAddress = fmt.Sprintf("http://localhost:%s/", resource.GetPort("4646/tcp"))
|
||||
// Give Nomad time to initialize
|
||||
|
||||
time.Sleep(5000 * time.Millisecond)
|
||||
// exponential backoff-retry
|
||||
if err = pool.Retry(func() error {
|
||||
var err error
|
||||
nomadapiConfig := nomadapi.DefaultConfig()
|
||||
nomadapiConfig.Address = retAddress
|
||||
nomad, err := nomadapi.NewClient(nomadapiConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
aclbootstrap, _, err := nomad.ACLTokens().Bootstrap(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
nomadToken = aclbootstrap.SecretID
|
||||
t.Logf("[WARN] Generated Master token: %s", nomadToken)
|
||||
policy := &nomadapi.ACLPolicy{
|
||||
Name: "test",
|
||||
Description: "test",
|
||||
Rules: `namespace "default" {
|
||||
policy = "read"
|
||||
}
|
||||
`,
|
||||
}
|
||||
anonPolicy := &nomadapi.ACLPolicy{
|
||||
Name: "anonymous",
|
||||
Description: "Deny all access for anonymous requests",
|
||||
Rules: `namespace "default" {
|
||||
policy = "deny"
|
||||
}
|
||||
agent {
|
||||
policy = "deny"
|
||||
}
|
||||
node {
|
||||
policy = "deny"
|
||||
}
|
||||
`,
|
||||
}
|
||||
nomadAuthConfig := nomadapi.DefaultConfig()
|
||||
nomadAuthConfig.Address = retAddress
|
||||
nomadAuthConfig.SecretID = nomadToken
|
||||
nomadAuth, err := nomadapi.NewClient(nomadAuthConfig)
|
||||
_, err = nomadAuth.ACLPolicies().Upsert(policy, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = nomadAuth.ACLPolicies().Upsert(anonPolicy, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return err
|
||||
}); err != nil {
|
||||
cleanup()
|
||||
t.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
return cleanup, retAddress, nomadToken
|
||||
}
|
||||
|
||||
func TestBackend_config_access(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cleanup, connURL, connToken := prepareTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connData := map[string]interface{}{
|
||||
"address": connURL,
|
||||
"token": connToken,
|
||||
}
|
||||
|
||||
confReq := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/access",
|
||||
Storage: config.StorageView,
|
||||
Data: connData,
|
||||
}
|
||||
|
||||
resp, err := b.HandleRequest(confReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) || resp != nil {
|
||||
t.Fatalf("failed to write configuration: resp:%#v err:%s", resp, err)
|
||||
}
|
||||
|
||||
confReq.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(confReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("failed to write configuration: resp:%#v err:%s", resp, err)
|
||||
}
|
||||
|
||||
expected := map[string]interface{}{
|
||||
"address": connData["address"].(string),
|
||||
}
|
||||
if !reflect.DeepEqual(expected, resp.Data) {
|
||||
t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data)
|
||||
}
|
||||
if resp.Data["token"] != nil {
|
||||
t.Fatalf("token should not be set in the response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_renew_revoke(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cleanup, connURL, connToken := prepareTestContainer(t)
|
||||
defer cleanup()
|
||||
connData := map[string]interface{}{
|
||||
"address": connURL,
|
||||
"token": connToken,
|
||||
}
|
||||
|
||||
req := &logical.Request{
|
||||
Storage: config.StorageView,
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/access",
|
||||
Data: connData,
|
||||
}
|
||||
resp, err := b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req.Path = "role/test"
|
||||
req.Data = map[string]interface{}{
|
||||
"policies": []string{"policy"},
|
||||
"lease": "6h",
|
||||
}
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req.Operation = logical.ReadOperation
|
||||
req.Path = "creds/test"
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("resp nil")
|
||||
}
|
||||
if resp.IsError() {
|
||||
t.Fatalf("resp is error: %v", resp.Error())
|
||||
}
|
||||
|
||||
generatedSecret := resp.Secret
|
||||
generatedSecret.IssueTime = time.Now()
|
||||
generatedSecret.TTL = 6 * time.Hour
|
||||
|
||||
var d struct {
|
||||
Token string `mapstructure:"secret_id"`
|
||||
Accessor string `mapstructure:"accessor_id"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("[WARN] Generated token: %s with accesor %s", d.Token, d.Accessor)
|
||||
|
||||
// Build a client and verify that the credentials work
|
||||
nomadapiConfig := nomadapi.DefaultConfig()
|
||||
nomadapiConfig.Address = connData["address"].(string)
|
||||
nomadapiConfig.SecretID = d.Token
|
||||
client, err := nomadapi.NewClient(nomadapiConfig)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Log("[WARN] Verifying that the generated token works...")
|
||||
_, err = client.Agent().Members, nil
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req.Operation = logical.RenewOperation
|
||||
req.Secret = generatedSecret
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("got nil response from renew")
|
||||
}
|
||||
|
||||
req.Operation = logical.RevokeOperation
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Build a management client and verify that the token does not exist anymore
|
||||
nomadmgmtConfig := nomadapi.DefaultConfig()
|
||||
nomadmgmtConfig.Address = connData["address"].(string)
|
||||
nomadmgmtConfig.SecretID = connData["token"].(string)
|
||||
mgmtclient, err := nomadapi.NewClient(nomadmgmtConfig)
|
||||
|
||||
q := &nomadapi.QueryOptions{
|
||||
Namespace: "default",
|
||||
}
|
||||
|
||||
t.Log("[WARN] Verifying that the generated token does not exist...")
|
||||
_, _, err = mgmtclient.ACLTokens().Info(d.Accessor, q)
|
||||
if err == nil {
|
||||
t.Fatal("err: expected error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_CredsCreateEnvVar(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cleanup, connURL, connToken := prepareTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
req := logical.TestRequest(t, logical.UpdateOperation, "role/test")
|
||||
req.Data = map[string]interface{}{
|
||||
"policies": []string{"policy"},
|
||||
"lease": "6h",
|
||||
}
|
||||
resp, err := b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
os.Setenv("NOMAD_TOKEN", connToken)
|
||||
defer os.Unsetenv("NOMAD_TOKEN")
|
||||
os.Setenv("NOMAD_ADDR", connURL)
|
||||
defer os.Unsetenv("NOMAD_ADDR")
|
||||
|
||||
req.Operation = logical.ReadOperation
|
||||
req.Path = "creds/test"
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("resp nil")
|
||||
}
|
||||
if resp.IsError() {
|
||||
t.Fatalf("resp is error: %v", resp.Error())
|
||||
}
|
||||
}
|
||||
121
builtin/logical/nomad/path_config_access.go
Normal file
121
builtin/logical/nomad/path_config_access.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package nomad
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
const configAccessKey = "config/access"
|
||||
|
||||
func pathConfigAccess(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "config/access",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"address": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Nomad server address",
|
||||
},
|
||||
|
||||
"token": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Token for API calls",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathConfigAccessRead,
|
||||
logical.CreateOperation: b.pathConfigAccessWrite,
|
||||
logical.UpdateOperation: b.pathConfigAccessWrite,
|
||||
logical.DeleteOperation: b.pathConfigAccessDelete,
|
||||
},
|
||||
|
||||
ExistenceCheck: b.configExistenceCheck,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) configExistenceCheck(req *logical.Request, data *framework.FieldData) (bool, error) {
|
||||
entry, err := b.readConfigAccess(req.Storage)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return entry != nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) readConfigAccess(storage logical.Storage) (*accessConfig, error) {
|
||||
entry, err := storage.Get(configAccessKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
conf := &accessConfig{}
|
||||
if err := entry.DecodeJSON(conf); err != nil {
|
||||
return nil, errwrap.Wrapf("error reading nomad access configuration: {{err}}", err)
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathConfigAccessRead(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
conf, err := b.readConfigAccess(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if conf == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"address": conf.Address,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathConfigAccessWrite(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
conf, err := b.readConfigAccess(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if conf == nil {
|
||||
conf = &accessConfig{}
|
||||
}
|
||||
|
||||
address, ok := data.GetOk("address")
|
||||
if ok {
|
||||
conf.Address = address.(string)
|
||||
}
|
||||
token, ok := data.GetOk("token")
|
||||
if ok {
|
||||
conf.Token = token.(string)
|
||||
}
|
||||
|
||||
entry, err := logical.StorageEntryJSON("config/access", conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathConfigAccessDelete(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
if err := req.Storage.Delete(configAccessKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type accessConfig struct {
|
||||
Address string `json:"address"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
109
builtin/logical/nomad/path_config_lease.go
Normal file
109
builtin/logical/nomad/path_config_lease.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package nomad
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
const leaseConfigKey = "config/lease"
|
||||
|
||||
func pathConfigLease(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "config/lease",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"ttl": &framework.FieldSchema{
|
||||
Type: framework.TypeDurationSecond,
|
||||
Description: "Duration before which the issued token needs renewal",
|
||||
},
|
||||
"max_ttl": &framework.FieldSchema{
|
||||
Type: framework.TypeDurationSecond,
|
||||
Description: `Duration after which the issued token should not be allowed to be renewed`,
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathLeaseRead,
|
||||
logical.UpdateOperation: b.pathLeaseUpdate,
|
||||
logical.DeleteOperation: b.pathLeaseDelete,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathConfigLeaseHelpSyn,
|
||||
HelpDescription: pathConfigLeaseHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
// Sets the lease configuration parameters
|
||||
func (b *backend) pathLeaseUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
entry, err := logical.StorageEntryJSON("config/lease", &configLease{
|
||||
TTL: time.Second * time.Duration(d.Get("ttl").(int)),
|
||||
MaxTTL: time.Second * time.Duration(d.Get("max_ttl").(int)),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathLeaseDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
if err := req.Storage.Delete(leaseConfigKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Returns the lease configuration parameters
|
||||
func (b *backend) pathLeaseRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
lease, err := b.LeaseConfig(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lease == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"ttl": int64(lease.TTL.Seconds()),
|
||||
"max_ttl": int64(lease.MaxTTL.Seconds()),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Lease returns the lease information
|
||||
func (b *backend) LeaseConfig(s logical.Storage) (*configLease, error) {
|
||||
entry, err := s.Get(leaseConfigKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result configLease
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// Lease configuration information for the secrets issued by this backend
|
||||
type configLease struct {
|
||||
TTL time.Duration `json:"ttl" mapstructure:"ttl"`
|
||||
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl"`
|
||||
}
|
||||
|
||||
var pathConfigLeaseHelpSyn = "Configure the lease parameters for generated tokens"
|
||||
|
||||
var pathConfigLeaseHelpDesc = `
|
||||
Sets the ttl and max_ttl values for the secrets to be issued by this backend.
|
||||
Both ttl and max_ttl takes in an integer number of seconds as input as well as
|
||||
inputs like "1h".
|
||||
`
|
||||
80
builtin/logical/nomad/path_creds_create.go
Normal file
80
builtin/logical/nomad/path_creds_create.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package nomad
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/nomad/api"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathCredsCreate(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "creds/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathTokenRead,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) pathTokenRead(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
role, err := b.Role(req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error retrieving role: {{err}}", err)
|
||||
}
|
||||
if role == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("role %q not found", name)), nil
|
||||
}
|
||||
|
||||
// Determine if we have a lease configuration
|
||||
leaseConfig, err := b.LeaseConfig(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if leaseConfig == nil {
|
||||
leaseConfig = &configLease{}
|
||||
}
|
||||
|
||||
// Get the nomad client
|
||||
c, err := b.client(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Generate a name for the token
|
||||
tokenName := fmt.Sprintf("vault-%s-%s-%d", name, req.DisplayName, time.Now().UnixNano())
|
||||
|
||||
// Create it
|
||||
token, _, err := c.ACLTokens().Create(&api.ACLToken{
|
||||
Name: tokenName,
|
||||
Type: role.TokenType,
|
||||
Policies: role.Policies,
|
||||
Global: role.Global,
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use the helper to create the secret
|
||||
resp := b.Secret(SecretTokenType).Response(map[string]interface{}{
|
||||
"secret_id": token.SecretID,
|
||||
"accessor_id": token.AccessorID,
|
||||
}, map[string]interface{}{
|
||||
"accessor_id": token.AccessorID,
|
||||
})
|
||||
resp.Secret.TTL = leaseConfig.TTL
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
189
builtin/logical/nomad/path_roles.go
Normal file
189
builtin/logical/nomad/path_roles.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package nomad
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathListRoles(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "role/?$",
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ListOperation: b.pathRoleList,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func pathRoles(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "role/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role",
|
||||
},
|
||||
|
||||
"policies": &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: "Comma-separated string or list of policies as previously created in Nomad. Required for 'client' token.",
|
||||
},
|
||||
|
||||
"global": &framework.FieldSchema{
|
||||
Type: framework.TypeBool,
|
||||
Description: "Boolean value describing if the token should be global or not. Defaults to false.",
|
||||
},
|
||||
|
||||
"type": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Default: "client",
|
||||
Description: `Which type of token to create: 'client'
|
||||
or 'management'. If a 'management' token,
|
||||
the "policies" parameter is not required.
|
||||
Defaults to 'client'.`,
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathRolesRead,
|
||||
logical.CreateOperation: b.pathRolesWrite,
|
||||
logical.UpdateOperation: b.pathRolesWrite,
|
||||
logical.DeleteOperation: b.pathRolesDelete,
|
||||
},
|
||||
|
||||
ExistenceCheck: b.rolesExistenceCheck,
|
||||
}
|
||||
}
|
||||
|
||||
// Establishes dichotomy of request operation between CreateOperation and UpdateOperation.
|
||||
// Returning 'true' forces an UpdateOperation, CreateOperation otherwise.
|
||||
func (b *backend) rolesExistenceCheck(req *logical.Request, d *framework.FieldData) (bool, error) {
|
||||
name := d.Get("name").(string)
|
||||
entry, err := b.Role(req.Storage, name)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return entry != nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) Role(storage logical.Storage, name string) (*roleConfig, error) {
|
||||
if name == "" {
|
||||
return nil, errors.New("invalid role name")
|
||||
}
|
||||
|
||||
entry, err := storage.Get("role/" + name)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error retrieving role: {{err}}", err)
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result roleConfig
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleList(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
entries, err := req.Storage.List("role/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return logical.ListResponse(entries), nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRolesRead(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
role, err := b.Role(req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Generate the response
|
||||
resp := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"type": role.TokenType,
|
||||
"global": role.Global,
|
||||
"policies": role.Policies,
|
||||
},
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRolesWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
role, err := b.Role(req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
role = new(roleConfig)
|
||||
}
|
||||
|
||||
policies, ok := d.GetOk("policies")
|
||||
if ok {
|
||||
role.Policies = policies.([]string)
|
||||
}
|
||||
|
||||
role.TokenType = d.Get("type").(string)
|
||||
switch role.TokenType {
|
||||
case "client":
|
||||
if len(role.Policies) == 0 {
|
||||
return logical.ErrorResponse(
|
||||
"policies cannot be empty when using client tokens"), nil
|
||||
}
|
||||
case "management":
|
||||
if len(role.Policies) != 0 {
|
||||
return logical.ErrorResponse(
|
||||
"policies should be empty when using management tokens"), nil
|
||||
}
|
||||
default:
|
||||
return logical.ErrorResponse(
|
||||
`type must be "client" or "management"`), nil
|
||||
}
|
||||
|
||||
global, ok := d.GetOk("global")
|
||||
if ok {
|
||||
role.Global = global.(bool)
|
||||
}
|
||||
|
||||
entry, err := logical.StorageEntryJSON("role/"+name, role)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRolesDelete(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
if err := req.Storage.Delete("role/" + name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type roleConfig struct {
|
||||
Policies []string `json:"policies"`
|
||||
TokenType string `json:"type"`
|
||||
Global bool `json:"global"`
|
||||
}
|
||||
68
builtin/logical/nomad/secret_token.go
Normal file
68
builtin/logical/nomad/secret_token.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package nomad
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
const (
|
||||
SecretTokenType = "token"
|
||||
)
|
||||
|
||||
func secretToken(b *backend) *framework.Secret {
|
||||
return &framework.Secret{
|
||||
Type: SecretTokenType,
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"token": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Request token",
|
||||
},
|
||||
},
|
||||
|
||||
Renew: b.secretTokenRenew,
|
||||
Revoke: b.secretTokenRevoke,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) secretTokenRenew(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
lease, err := b.LeaseConfig(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lease == nil {
|
||||
lease = &configLease{}
|
||||
}
|
||||
|
||||
return framework.LeaseExtend(lease.TTL, lease.MaxTTL, b.System())(req, d)
|
||||
}
|
||||
|
||||
func (b *backend) secretTokenRevoke(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
c, err := b.client(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c == nil {
|
||||
return nil, fmt.Errorf("error getting Nomad client")
|
||||
}
|
||||
|
||||
accessorIDRaw, ok := req.Secret.InternalData["accessor_id"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("accessor_id is missing on the lease")
|
||||
}
|
||||
accessorID, ok := accessorIDRaw.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("unable to convert accessor_id")
|
||||
}
|
||||
_, err = c.ACLTokens().Delete(accessorID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1463,7 +1463,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
|
||||
//t.Logf("test step %d\nrole vals: %#v\n", stepCount, roleVals)
|
||||
stepCount++
|
||||
//t.Logf("test step %d\nissue vals: %#v\n", stepCount, issueTestStep)
|
||||
roleTestStep.Data = structs.New(roleVals).Map()
|
||||
roleTestStep.Data = roleVals.ToResponseData()
|
||||
roleTestStep.Data["generate_lease"] = false
|
||||
ret = append(ret, roleTestStep)
|
||||
issueTestStep.Data = structs.New(issueVals).Map()
|
||||
@@ -1594,38 +1594,38 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
|
||||
roleVals.CodeSigningFlag = false
|
||||
roleVals.EmailProtectionFlag = false
|
||||
|
||||
var usage string
|
||||
var usage []string
|
||||
if mathRand.Int()%2 == 1 {
|
||||
usage = usage + ",DigitalSignature"
|
||||
usage = append(usage, "DigitalSignature")
|
||||
}
|
||||
if mathRand.Int()%2 == 1 {
|
||||
usage = usage + ",ContentCoMmitment"
|
||||
usage = append(usage, "ContentCoMmitment")
|
||||
}
|
||||
if mathRand.Int()%2 == 1 {
|
||||
usage = usage + ",KeyEncipherment"
|
||||
usage = append(usage, "KeyEncipherment")
|
||||
}
|
||||
if mathRand.Int()%2 == 1 {
|
||||
usage = usage + ",DataEncipherment"
|
||||
usage = append(usage, "DataEncipherment")
|
||||
}
|
||||
if mathRand.Int()%2 == 1 {
|
||||
usage = usage + ",KeyAgreemEnt"
|
||||
usage = append(usage, "KeyAgreemEnt")
|
||||
}
|
||||
if mathRand.Int()%2 == 1 {
|
||||
usage = usage + ",CertSign"
|
||||
usage = append(usage, "CertSign")
|
||||
}
|
||||
if mathRand.Int()%2 == 1 {
|
||||
usage = usage + ",CRLSign"
|
||||
usage = append(usage, "CRLSign")
|
||||
}
|
||||
if mathRand.Int()%2 == 1 {
|
||||
usage = usage + ",EncipherOnly"
|
||||
usage = append(usage, "EncipherOnly")
|
||||
}
|
||||
if mathRand.Int()%2 == 1 {
|
||||
usage = usage + ",DecipherOnly"
|
||||
usage = append(usage, "DecipherOnly")
|
||||
}
|
||||
|
||||
roleVals.KeyUsage = usage
|
||||
parsedKeyUsage := parseKeyUsages(roleVals.KeyUsage)
|
||||
if parsedKeyUsage == 0 && usage != "" {
|
||||
if parsedKeyUsage == 0 && len(usage) != 0 {
|
||||
panic("parsed key usages was zero")
|
||||
}
|
||||
parsedKeyUsageUnderTest = parsedKeyUsage
|
||||
@@ -1759,10 +1759,10 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
|
||||
commonNames.Localhost = true
|
||||
addCnTests()
|
||||
|
||||
roleVals.AllowedDomains = "foobar.com"
|
||||
roleVals.AllowedDomains = []string{"foobar.com"}
|
||||
addCnTests()
|
||||
|
||||
roleVals.AllowedDomains = "example.com"
|
||||
roleVals.AllowedDomains = []string{"example.com"}
|
||||
roleVals.AllowSubdomains = true
|
||||
commonNames.SubDomain = true
|
||||
commonNames.Wildcard = true
|
||||
@@ -1770,13 +1770,13 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
|
||||
commonNames.SubSubdomainWildcard = true
|
||||
addCnTests()
|
||||
|
||||
roleVals.AllowedDomains = "foobar.com,example.com"
|
||||
roleVals.AllowedDomains = []string{"foobar.com", "example.com"}
|
||||
commonNames.SecondDomain = true
|
||||
roleVals.AllowBareDomains = true
|
||||
commonNames.BareDomain = true
|
||||
addCnTests()
|
||||
|
||||
roleVals.AllowedDomains = "foobar.com,*example.com"
|
||||
roleVals.AllowedDomains = []string{"foobar.com", "*example.com"}
|
||||
roleVals.AllowGlobDomains = true
|
||||
commonNames.GlobDomain = true
|
||||
addCnTests()
|
||||
|
||||
@@ -17,14 +17,14 @@ func (b *backend) getGenerationParams(
|
||||
case "internal":
|
||||
default:
|
||||
errorResp = logical.ErrorResponse(
|
||||
`The "exported" path parameter must be "internal" or "exported"`)
|
||||
`the "exported" path parameter must be "internal" or "exported"`)
|
||||
return
|
||||
}
|
||||
|
||||
format = getFormat(data)
|
||||
if format == "" {
|
||||
errorResp = logical.ErrorResponse(
|
||||
`The "format" path parameter must be "pem", "der", or "pem_bundle"`)
|
||||
`the "format" path parameter must be "pem", "der", "der_pkcs", or "pem_bundle"`)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package pki
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -16,6 +18,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
"github.com/hashicorp/vault/helper/errutil"
|
||||
"github.com/hashicorp/vault/helper/parseutil"
|
||||
@@ -372,9 +375,9 @@ func validateNames(req *logical.Request, names []string, role *roleEntry) string
|
||||
}
|
||||
}
|
||||
|
||||
if role.AllowedDomains != "" {
|
||||
if len(role.AllowedDomains) > 0 {
|
||||
valid := false
|
||||
for _, currDomain := range strings.Split(role.AllowedDomains, ",") {
|
||||
for _, currDomain := range role.AllowedDomains {
|
||||
// If there is, say, a trailing comma, ignore it
|
||||
if currDomain == "" {
|
||||
continue
|
||||
@@ -1183,3 +1186,66 @@ NameCheck:
|
||||
|
||||
return fmt.Errorf("name %q disallowed by CA's permitted DNS domains", badName)
|
||||
}
|
||||
|
||||
func convertRespToPKCS8(resp *logical.Response) error {
|
||||
privRaw, ok := resp.Data["private_key"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
priv, ok := privRaw.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("error converting response to pkcs8: could not parse original value as string")
|
||||
}
|
||||
|
||||
privKeyTypeRaw, ok := resp.Data["private_key_type"]
|
||||
if !ok {
|
||||
return fmt.Errorf("error converting response to pkcs8: %q not found in response", "private_key_type")
|
||||
}
|
||||
privKeyType, ok := privKeyTypeRaw.(certutil.PrivateKeyType)
|
||||
if !ok {
|
||||
return fmt.Errorf("error converting response to pkcs8: could not parse original type value as string")
|
||||
}
|
||||
|
||||
var keyData []byte
|
||||
var pemUsed bool
|
||||
var err error
|
||||
var signer crypto.Signer
|
||||
|
||||
block, _ := pem.Decode([]byte(priv))
|
||||
if block == nil {
|
||||
keyData, err = base64.StdEncoding.DecodeString(priv)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("error converting response to pkcs8: error decoding original value: {{err}}", err)
|
||||
}
|
||||
} else {
|
||||
keyData = block.Bytes
|
||||
pemUsed = true
|
||||
}
|
||||
|
||||
switch privKeyType {
|
||||
case certutil.RSAPrivateKey:
|
||||
signer, err = x509.ParsePKCS1PrivateKey(keyData)
|
||||
case certutil.ECPrivateKey:
|
||||
signer, err = x509.ParseECPrivateKey(keyData)
|
||||
default:
|
||||
return fmt.Errorf("unknown private key type %q", privKeyType)
|
||||
}
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("error converting response to pkcs8: error parsing previous key: {{err}}", err)
|
||||
}
|
||||
|
||||
keyData, err = certutil.MarshalPKCS8PrivateKey(signer)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("error converting response to pkcs8: error marshaling pkcs8 key: {{err}}", err)
|
||||
}
|
||||
|
||||
if pemUsed {
|
||||
block.Type = "PRIVATE KEY"
|
||||
block.Bytes = keyData
|
||||
resp.Data["private_key"] = string(pem.EncodeToMemory(block))
|
||||
} else {
|
||||
resp.Data["private_key"] = base64.StdEncoding.EncodeToString(keyData)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -22,6 +22,17 @@ key and issuing cert will be appended to the
|
||||
certificate pem. Defaults to "pem".`,
|
||||
}
|
||||
|
||||
fields["private_key_format"] = &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Default: "der",
|
||||
Description: `Format for the returned private key.
|
||||
Generally the default will be controlled by the "format"
|
||||
parameter as either base64-encoded DER or PEM-encoded DER.
|
||||
However, this can be set to "pkcs8" to have the returned
|
||||
private key contain base64-encoded pkcs8 or PEM-encoded
|
||||
pkcs8 instead. Defaults to "der".`,
|
||||
}
|
||||
|
||||
fields["ip_sans"] = &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: `The requested IP SANs, if any, in a
|
||||
|
||||
@@ -106,6 +106,13 @@ func (b *backend) pathGenerateIntermediate(
|
||||
}
|
||||
}
|
||||
|
||||
if data.Get("private_key_format").(string) == "pkcs8" {
|
||||
err = convertRespToPKCS8(resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
cb := &certutil.CertBundle{}
|
||||
cb.PrivateKey = csrb.PrivateKey
|
||||
cb.PrivateKeyType = csrb.PrivateKeyType
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
"github.com/hashicorp/vault/helper/errutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
@@ -163,7 +164,7 @@ func (b *backend) pathIssueSignCert(
|
||||
format := getFormat(data)
|
||||
if format == "" {
|
||||
return logical.ErrorResponse(
|
||||
`The "format" path parameter must be "pem", "der", or "pem_bundle"`), nil
|
||||
`the "format" path parameter must be "pem", "der", or "pem_bundle"`), nil
|
||||
}
|
||||
|
||||
var caErr error
|
||||
@@ -171,10 +172,10 @@ func (b *backend) pathIssueSignCert(
|
||||
switch caErr.(type) {
|
||||
case errutil.UserError:
|
||||
return nil, errutil.UserError{Err: fmt.Sprintf(
|
||||
"Could not fetch the CA certificate (was one set?): %s", caErr)}
|
||||
"could not fetch the CA certificate (was one set?): %s", caErr)}
|
||||
case errutil.InternalError:
|
||||
return nil, errutil.InternalError{Err: fmt.Sprintf(
|
||||
"Error fetching CA certificate: %s", caErr)}
|
||||
"error fetching CA certificate: %s", caErr)}
|
||||
}
|
||||
|
||||
var parsedBundle *certutil.ParsedCertBundle
|
||||
@@ -195,12 +196,12 @@ func (b *backend) pathIssueSignCert(
|
||||
|
||||
signingCB, err := signingBundle.ToCertBundle()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error converting raw signing bundle to cert bundle: %s", err)
|
||||
return nil, errwrap.Wrapf("error converting raw signing bundle to cert bundle: {{err}}", err)
|
||||
}
|
||||
|
||||
cb, err := parsedBundle.ToCertBundle()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error converting raw cert bundle to cert bundle: %s", err)
|
||||
return nil, errwrap.Wrapf("error converting raw cert bundle to cert bundle: {{err}}", err)
|
||||
}
|
||||
|
||||
respData := map[string]interface{}{
|
||||
@@ -267,6 +268,13 @@ func (b *backend) pathIssueSignCert(
|
||||
resp.Secret.TTL = parsedBundle.Certificate.NotAfter.Sub(time.Now())
|
||||
}
|
||||
|
||||
if data.Get("private_key_format").(string) == "pkcs8" {
|
||||
err = convertRespToPKCS8(resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if !role.NoStore {
|
||||
err = req.Storage.Put(&logical.StorageEntry{
|
||||
Key: "certs/" + normalizeSerial(cb.SerialNumber),
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fatih/structs"
|
||||
"github.com/hashicorp/vault/helper/parseutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
@@ -57,13 +56,12 @@ name in a request`,
|
||||
},
|
||||
|
||||
"allowed_domains": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Default: "",
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: `If set, clients can request certificates for
|
||||
subdomains directly beneath these domains, including
|
||||
the wildcard subdomains. See the documentation for more
|
||||
information. This parameter accepts a comma-separated list
|
||||
of domains.`,
|
||||
information. This parameter accepts a comma-separated
|
||||
string or list of domains.`,
|
||||
},
|
||||
|
||||
"allow_bare_domains": &framework.FieldSchema{
|
||||
@@ -158,14 +156,14 @@ the key_type.`,
|
||||
},
|
||||
|
||||
"key_usage": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Default: "DigitalSignature,KeyAgreement,KeyEncipherment",
|
||||
Description: `A comma-separated set of key usages (not extended
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Default: []string{"DigitalSignature", "KeyAgreement", "KeyEncipherment"},
|
||||
Description: `A comma-separated string or list of key usages (not extended
|
||||
key usages). Valid values can be found at
|
||||
https://golang.org/pkg/crypto/x509/#KeyUsage
|
||||
-- simply drop the "KeyUsage" part of the name.
|
||||
To remove all key usages from being set, set
|
||||
this value to an empty string.`,
|
||||
this value to an empty list.`,
|
||||
},
|
||||
|
||||
"use_csr_common_name": &framework.FieldSchema{
|
||||
@@ -217,8 +215,8 @@ leases adversely affect the startup time of Vault.`,
|
||||
Default: false,
|
||||
Description: `
|
||||
If set, certificates issued/signed against this role will not be stored in the
|
||||
in the storage backend. This can improve performance when issuing large numbers
|
||||
of certificates. However, certificates issued in this way cannot be enumerated
|
||||
storage backend. This can improve performance when issuing large numbers of
|
||||
certificates. However, certificates issued in this way cannot be enumerated
|
||||
or revoked, so this option is recommended only for certificates that are
|
||||
non-sensitive, or extremely short-lived. This option implies a value of "false"
|
||||
for "generate_lease".`,
|
||||
@@ -267,23 +265,21 @@ func (b *backend) getRole(s logical.Storage, n string) (*roleEntry, error) {
|
||||
result.AllowBareDomains = true
|
||||
modified = true
|
||||
}
|
||||
if result.AllowedDomainsOld != "" {
|
||||
result.AllowedDomains = strings.Split(result.AllowedDomainsOld, ",")
|
||||
result.AllowedDomainsOld = ""
|
||||
modified = true
|
||||
}
|
||||
if result.AllowedBaseDomain != "" {
|
||||
found := false
|
||||
allowedDomains := strings.Split(result.AllowedDomains, ",")
|
||||
if len(allowedDomains) != 0 {
|
||||
for _, v := range allowedDomains {
|
||||
if v == result.AllowedBaseDomain {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
for _, v := range result.AllowedDomains {
|
||||
if v == result.AllowedBaseDomain {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
if result.AllowedDomains == "" {
|
||||
result.AllowedDomains = result.AllowedBaseDomain
|
||||
} else {
|
||||
result.AllowedDomains += "," + result.AllowedBaseDomain
|
||||
}
|
||||
result.AllowedDomains = append(result.AllowedDomains, result.AllowedBaseDomain)
|
||||
}
|
||||
result.AllowedBaseDomain = ""
|
||||
modified = true
|
||||
@@ -299,13 +295,23 @@ func (b *backend) getRole(s logical.Storage, n string) (*roleEntry, error) {
|
||||
modified = true
|
||||
}
|
||||
|
||||
// Upgrade key usages
|
||||
if result.KeyUsageOld != "" {
|
||||
result.KeyUsage = strings.Split(result.KeyUsageOld, ",")
|
||||
result.KeyUsageOld = ""
|
||||
modified = true
|
||||
}
|
||||
|
||||
if modified {
|
||||
jsonEntry, err := logical.StorageEntryJSON("role/"+n, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.Put(jsonEntry); err != nil {
|
||||
return nil, err
|
||||
// Only perform upgrades on replication primary
|
||||
if !strings.Contains(err.Error(), logical.ErrReadOnly.Error()) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -351,18 +357,8 @@ func (b *backend) pathRoleRead(
|
||||
}
|
||||
|
||||
resp := &logical.Response{
|
||||
Data: structs.New(role).Map(),
|
||||
Data: role.ToResponseData(),
|
||||
}
|
||||
|
||||
if resp.Data == nil {
|
||||
return nil, fmt.Errorf("error converting role data to response")
|
||||
}
|
||||
|
||||
// These values are deprecated and the entries are migrated on read
|
||||
delete(resp.Data, "lease")
|
||||
delete(resp.Data, "lease_max")
|
||||
delete(resp.Data, "allowed_base_domain")
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -385,7 +381,7 @@ func (b *backend) pathRoleCreate(
|
||||
MaxTTL: data.Get("max_ttl").(string),
|
||||
TTL: (time.Duration(data.Get("ttl").(int)) * time.Second).String(),
|
||||
AllowLocalhost: data.Get("allow_localhost").(bool),
|
||||
AllowedDomains: data.Get("allowed_domains").(string),
|
||||
AllowedDomains: data.Get("allowed_domains").([]string),
|
||||
AllowBareDomains: data.Get("allow_bare_domains").(bool),
|
||||
AllowSubdomains: data.Get("allow_subdomains").(bool),
|
||||
AllowGlobDomains: data.Get("allow_glob_domains").(bool),
|
||||
@@ -400,7 +396,7 @@ func (b *backend) pathRoleCreate(
|
||||
KeyBits: data.Get("key_bits").(int),
|
||||
UseCSRCommonName: data.Get("use_csr_common_name").(bool),
|
||||
UseCSRSANs: data.Get("use_csr_sans").(bool),
|
||||
KeyUsage: data.Get("key_usage").(string),
|
||||
KeyUsage: data.Get("key_usage").([]string),
|
||||
OU: data.Get("ou").(string),
|
||||
Organization: data.Get("organization").(string),
|
||||
GenerateLease: new(bool),
|
||||
@@ -473,10 +469,9 @@ func (b *backend) pathRoleCreate(
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func parseKeyUsages(input string) int {
|
||||
func parseKeyUsages(input []string) int {
|
||||
var parsedKeyUsages x509.KeyUsage
|
||||
splitKeyUsage := strings.Split(input, ",")
|
||||
for _, k := range splitKeyUsage {
|
||||
for _, k := range input {
|
||||
switch strings.ToLower(strings.TrimSpace(k)) {
|
||||
case "digitalsignature":
|
||||
parsedKeyUsages |= x509.KeyUsageDigitalSignature
|
||||
@@ -503,40 +498,77 @@ func parseKeyUsages(input string) int {
|
||||
}
|
||||
|
||||
type roleEntry struct {
|
||||
LeaseMax string `json:"lease_max" structs:"lease_max" mapstructure:"lease_max"`
|
||||
Lease string `json:"lease" structs:"lease" mapstructure:"lease"`
|
||||
MaxTTL string `json:"max_ttl" structs:"max_ttl" mapstructure:"max_ttl"`
|
||||
TTL string `json:"ttl" structs:"ttl" mapstructure:"ttl"`
|
||||
AllowLocalhost bool `json:"allow_localhost" structs:"allow_localhost" mapstructure:"allow_localhost"`
|
||||
AllowedBaseDomain string `json:"allowed_base_domain" structs:"allowed_base_domain" mapstructure:"allowed_base_domain"`
|
||||
AllowedDomains string `json:"allowed_domains" structs:"allowed_domains" mapstructure:"allowed_domains"`
|
||||
AllowBaseDomain bool `json:"allow_base_domain" structs:"allow_base_domain" mapstructure:"allow_base_domain"`
|
||||
AllowBareDomains bool `json:"allow_bare_domains" structs:"allow_bare_domains" mapstructure:"allow_bare_domains"`
|
||||
AllowTokenDisplayName bool `json:"allow_token_displayname" structs:"allow_token_displayname" mapstructure:"allow_token_displayname"`
|
||||
AllowSubdomains bool `json:"allow_subdomains" structs:"allow_subdomains" mapstructure:"allow_subdomains"`
|
||||
AllowGlobDomains bool `json:"allow_glob_domains" structs:"allow_glob_domains" mapstructure:"allow_glob_domains"`
|
||||
AllowAnyName bool `json:"allow_any_name" structs:"allow_any_name" mapstructure:"allow_any_name"`
|
||||
EnforceHostnames bool `json:"enforce_hostnames" structs:"enforce_hostnames" mapstructure:"enforce_hostnames"`
|
||||
AllowIPSANs bool `json:"allow_ip_sans" structs:"allow_ip_sans" mapstructure:"allow_ip_sans"`
|
||||
ServerFlag bool `json:"server_flag" structs:"server_flag" mapstructure:"server_flag"`
|
||||
ClientFlag bool `json:"client_flag" structs:"client_flag" mapstructure:"client_flag"`
|
||||
CodeSigningFlag bool `json:"code_signing_flag" structs:"code_signing_flag" mapstructure:"code_signing_flag"`
|
||||
EmailProtectionFlag bool `json:"email_protection_flag" structs:"email_protection_flag" mapstructure:"email_protection_flag"`
|
||||
UseCSRCommonName bool `json:"use_csr_common_name" structs:"use_csr_common_name" mapstructure:"use_csr_common_name"`
|
||||
UseCSRSANs bool `json:"use_csr_sans" structs:"use_csr_sans" mapstructure:"use_csr_sans"`
|
||||
KeyType string `json:"key_type" structs:"key_type" mapstructure:"key_type"`
|
||||
KeyBits int `json:"key_bits" structs:"key_bits" mapstructure:"key_bits"`
|
||||
MaxPathLength *int `json:",omitempty" structs:"max_path_length,omitempty" mapstructure:"max_path_length"`
|
||||
KeyUsage string `json:"key_usage" structs:"key_usage" mapstructure:"key_usage"`
|
||||
OU string `json:"ou" structs:"ou" mapstructure:"ou"`
|
||||
Organization string `json:"organization" structs:"organization" mapstructure:"organization"`
|
||||
GenerateLease *bool `json:"generate_lease,omitempty" structs:"generate_lease,omitempty"`
|
||||
NoStore bool `json:"no_store" structs:"no_store" mapstructure:"no_store"`
|
||||
LeaseMax string `json:"lease_max"`
|
||||
Lease string `json:"lease"`
|
||||
MaxTTL string `json:"max_ttl" mapstructure:"max_ttl"`
|
||||
TTL string `json:"ttl" mapstructure:"ttl"`
|
||||
AllowLocalhost bool `json:"allow_localhost" mapstructure:"allow_localhost"`
|
||||
AllowedBaseDomain string `json:"allowed_base_domain" mapstructure:"allowed_base_domain"`
|
||||
AllowedDomainsOld string `json:"allowed_domains,omit_empty"`
|
||||
AllowedDomains []string `json:"allowed_domains_list" mapstructure:"allowed_domains"`
|
||||
AllowBaseDomain bool `json:"allow_base_domain"`
|
||||
AllowBareDomains bool `json:"allow_bare_domains" mapstructure:"allow_bare_domains"`
|
||||
AllowTokenDisplayName bool `json:"allow_token_displayname" mapstructure:"allow_token_displayname"`
|
||||
AllowSubdomains bool `json:"allow_subdomains" mapstructure:"allow_subdomains"`
|
||||
AllowGlobDomains bool `json:"allow_glob_domains" mapstructure:"allow_glob_domains"`
|
||||
AllowAnyName bool `json:"allow_any_name" mapstructure:"allow_any_name"`
|
||||
EnforceHostnames bool `json:"enforce_hostnames" mapstructure:"enforce_hostnames"`
|
||||
AllowIPSANs bool `json:"allow_ip_sans" mapstructure:"allow_ip_sans"`
|
||||
ServerFlag bool `json:"server_flag" mapstructure:"server_flag"`
|
||||
ClientFlag bool `json:"client_flag" mapstructure:"client_flag"`
|
||||
CodeSigningFlag bool `json:"code_signing_flag" mapstructure:"code_signing_flag"`
|
||||
EmailProtectionFlag bool `json:"email_protection_flag" mapstructure:"email_protection_flag"`
|
||||
UseCSRCommonName bool `json:"use_csr_common_name" mapstructure:"use_csr_common_name"`
|
||||
UseCSRSANs bool `json:"use_csr_sans" mapstructure:"use_csr_sans"`
|
||||
KeyType string `json:"key_type" mapstructure:"key_type"`
|
||||
KeyBits int `json:"key_bits" mapstructure:"key_bits"`
|
||||
MaxPathLength *int `json:",omitempty" mapstructure:"max_path_length"`
|
||||
KeyUsageOld string `json:"key_usage,omitempty"`
|
||||
KeyUsage []string `json:"key_usage_list" mapstructure:"key_usage"`
|
||||
OU string `json:"ou" mapstructure:"ou"`
|
||||
Organization string `json:"organization" mapstructure:"organization"`
|
||||
GenerateLease *bool `json:"generate_lease,omitempty"`
|
||||
NoStore bool `json:"no_store" mapstructure:"no_store"`
|
||||
|
||||
// Used internally for signing intermediates
|
||||
AllowExpirationPastCA bool
|
||||
}
|
||||
|
||||
func (r *roleEntry) ToResponseData() map[string]interface{} {
|
||||
responseData := map[string]interface{}{
|
||||
"ttl": r.TTL,
|
||||
"max_ttl": r.MaxTTL,
|
||||
"allow_localhost": r.AllowLocalhost,
|
||||
"allowed_domains": r.AllowedDomains,
|
||||
"allow_bare_domains": r.AllowBareDomains,
|
||||
"allow_token_displayname": r.AllowTokenDisplayName,
|
||||
"allow_subdomains": r.AllowSubdomains,
|
||||
"allow_glob_domains": r.AllowGlobDomains,
|
||||
"allow_any_name": r.AllowAnyName,
|
||||
"enforce_hostnames": r.EnforceHostnames,
|
||||
"allow_ip_sans": r.AllowIPSANs,
|
||||
"server_flag": r.ServerFlag,
|
||||
"client_flag": r.ClientFlag,
|
||||
"code_signing_flag": r.CodeSigningFlag,
|
||||
"email_protection_flag": r.EmailProtectionFlag,
|
||||
"use_csr_common_name": r.UseCSRCommonName,
|
||||
"use_csr_sans": r.UseCSRSANs,
|
||||
"key_type": r.KeyType,
|
||||
"key_bits": r.KeyBits,
|
||||
"key_usage": r.KeyUsage,
|
||||
"ou": r.OU,
|
||||
"organization": r.Organization,
|
||||
"no_store": r.NoStore,
|
||||
}
|
||||
if r.MaxPathLength != nil {
|
||||
responseData["max_path_length"] = r.MaxPathLength
|
||||
}
|
||||
if r.GenerateLease != nil {
|
||||
responseData["generate_lease"] = r.GenerateLease
|
||||
}
|
||||
return responseData
|
||||
}
|
||||
|
||||
const pathListRolesHelpSyn = `List the existing roles in this backend`
|
||||
|
||||
const pathListRolesHelpDesc = `Roles will be listed by the role name.`
|
||||
|
||||
@@ -120,6 +120,181 @@ func TestPki_RoleGenerateLease(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPki_RoleKeyUsage(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
b, storage := createBackendWithStorage(t)
|
||||
|
||||
roleData := map[string]interface{}{
|
||||
"allowed_domains": "myvault.com",
|
||||
"ttl": "5h",
|
||||
"key_usage": []string{"KeyEncipherment", "DigitalSignature"},
|
||||
}
|
||||
|
||||
roleReq := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "roles/testrole",
|
||||
Storage: storage,
|
||||
Data: roleData,
|
||||
}
|
||||
|
||||
resp, err = b.HandleRequest(roleReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: err: %v resp: %#v", err, resp)
|
||||
}
|
||||
|
||||
roleReq.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(roleReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: err: %v resp: %#v", err, resp)
|
||||
}
|
||||
|
||||
keyUsage := resp.Data["key_usage"].([]string)
|
||||
if len(keyUsage) != 2 {
|
||||
t.Fatalf("key_usage should have 2 values")
|
||||
}
|
||||
|
||||
// Check that old key usage value is nil
|
||||
var role roleEntry
|
||||
err = mapstructure.Decode(resp.Data, &role)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if role.KeyUsageOld != "" {
|
||||
t.Fatalf("old key usage storage value should be blank")
|
||||
}
|
||||
|
||||
// Make it explicit
|
||||
role.KeyUsageOld = "KeyEncipherment,DigitalSignature"
|
||||
role.KeyUsage = nil
|
||||
|
||||
entry, err := logical.StorageEntryJSON("role/testrole", role)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := storage.Put(entry); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Reading should upgrade key_usage
|
||||
resp, err = b.HandleRequest(roleReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: err: %v resp: %#v", err, resp)
|
||||
}
|
||||
|
||||
keyUsage = resp.Data["key_usage"].([]string)
|
||||
if len(keyUsage) != 2 {
|
||||
t.Fatalf("key_usage should have 2 values")
|
||||
}
|
||||
|
||||
// Read back from storage to ensure upgrade
|
||||
entry, err = storage.Get("role/testrole")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if entry == nil {
|
||||
t.Fatalf("role should not be nil")
|
||||
}
|
||||
var result roleEntry
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if result.KeyUsageOld != "" {
|
||||
t.Fatal("old key usage value should be blank")
|
||||
}
|
||||
if len(result.KeyUsage) != 2 {
|
||||
t.Fatal("key_usage should have 2 values")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPki_RoleAllowedDomains(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
b, storage := createBackendWithStorage(t)
|
||||
|
||||
roleData := map[string]interface{}{
|
||||
"allowed_domains": []string{"foobar.com", "*example.com"},
|
||||
"ttl": "5h",
|
||||
}
|
||||
|
||||
roleReq := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "roles/testrole",
|
||||
Storage: storage,
|
||||
Data: roleData,
|
||||
}
|
||||
|
||||
resp, err = b.HandleRequest(roleReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: err: %v resp: %#v", err, resp)
|
||||
}
|
||||
|
||||
roleReq.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(roleReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: err: %v resp: %#v", err, resp)
|
||||
}
|
||||
|
||||
allowedDomains := resp.Data["allowed_domains"].([]string)
|
||||
if len(allowedDomains) != 2 {
|
||||
t.Fatalf("allowed_domains should have 2 values")
|
||||
}
|
||||
|
||||
// Check that old key usage value is nil
|
||||
var role roleEntry
|
||||
err = mapstructure.Decode(resp.Data, &role)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if role.AllowedDomainsOld != "" {
|
||||
t.Fatalf("old allowed_domains storage value should be blank")
|
||||
}
|
||||
|
||||
// Make it explicit
|
||||
role.AllowedDomainsOld = "foobar.com,*example.com"
|
||||
role.AllowedDomains = nil
|
||||
|
||||
entry, err := logical.StorageEntryJSON("role/testrole", role)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := storage.Put(entry); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Reading should upgrade key_usage
|
||||
resp, err = b.HandleRequest(roleReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: err: %v resp: %#v", err, resp)
|
||||
}
|
||||
|
||||
allowedDomains = resp.Data["allowed_domains"].([]string)
|
||||
if len(allowedDomains) != 2 {
|
||||
t.Fatalf("allowed_domains should have 2 values")
|
||||
}
|
||||
|
||||
// Read back from storage to ensure upgrade
|
||||
entry, err = storage.Get("role/testrole")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if entry == nil {
|
||||
t.Fatalf("role should not be nil")
|
||||
}
|
||||
var result roleEntry
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if result.AllowedDomainsOld != "" {
|
||||
t.Fatal("old allowed_domains value should be blank")
|
||||
}
|
||||
if len(result.AllowedDomains) != 2 {
|
||||
t.Fatal("allowed_domains should have 2 values")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPki_RoleNoStore(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
|
||||
@@ -149,7 +149,7 @@ func (b *backend) pathCAGenerateRoot(
|
||||
|
||||
cb, err := parsedBundle.ToCertBundle()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error converting raw cert bundle to cert bundle: %s", err)
|
||||
return nil, errwrap.Wrapf("error converting raw cert bundle to cert bundle: {{err}}", err)
|
||||
}
|
||||
|
||||
resp := &logical.Response{
|
||||
@@ -188,6 +188,13 @@ func (b *backend) pathCAGenerateRoot(
|
||||
}
|
||||
}
|
||||
|
||||
if data.Get("private_key_format").(string) == "pkcs8" {
|
||||
err = convertRespToPKCS8(resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Store it as the CA bundle
|
||||
entry, err = logical.StorageEntryJSON("config/ca_bundle", cb)
|
||||
if err != nil {
|
||||
@@ -205,7 +212,7 @@ func (b *backend) pathCAGenerateRoot(
|
||||
Value: parsedBundle.CertificateBytes,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to store certificate locally: %v", err)
|
||||
return nil, errwrap.Wrapf("unable to store certificate locally: {{err}}", err)
|
||||
}
|
||||
|
||||
// For ease of later use, also store just the certificate at a known
|
||||
|
||||
@@ -25,6 +25,12 @@ func Backend(conf *logical.BackendConfig) *backend {
|
||||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
SealWrapStorage: []string{
|
||||
"config/connection",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigConnection(&b),
|
||||
pathConfigLease(&b),
|
||||
|
||||
@@ -26,6 +26,12 @@ func Backend() *backend {
|
||||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
SealWrapStorage: []string{
|
||||
"config/connection",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigConnection(&b),
|
||||
pathConfigLease(&b),
|
||||
|
||||
@@ -271,6 +271,31 @@ func TestSSHBackend_Lookup(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHBackend_RoleList(t *testing.T) {
|
||||
testOTPRoleData := map[string]interface{}{
|
||||
"key_type": testOTPKeyType,
|
||||
"default_user": testUserName,
|
||||
"cidr_list": testCIDRList,
|
||||
}
|
||||
resp1 := map[string]interface{}{}
|
||||
resp2 := map[string]interface{}{
|
||||
"keys": []string{testOTPRoleName},
|
||||
"key_info": map[string]interface{}{
|
||||
testOTPRoleName: map[string]interface{}{
|
||||
"key_type": testOTPKeyType,
|
||||
},
|
||||
},
|
||||
}
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Factory: testingFactory,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testRoleList(t, resp1),
|
||||
testRoleWrite(t, testOTPRoleName, testOTPRoleData),
|
||||
testRoleList(t, resp2),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHBackend_DynamicKeyCreate(t *testing.T) {
|
||||
testDynamicRoleData := map[string]interface{}{
|
||||
"key_type": testDynamicKeyType,
|
||||
@@ -962,6 +987,25 @@ func testRoleWrite(t *testing.T, name string, data map[string]interface{}) logic
|
||||
}
|
||||
}
|
||||
|
||||
func testRoleList(t *testing.T, expected map[string]interface{}) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ListOperation,
|
||||
Path: "roles",
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp == nil {
|
||||
return fmt.Errorf("nil response")
|
||||
}
|
||||
if resp.Data == nil {
|
||||
return fmt.Errorf("nil data")
|
||||
}
|
||||
if !reflect.DeepEqual(resp.Data, expected) {
|
||||
return fmt.Errorf("Invalid response:\nactual:%#v\nexpected is %#v", resp.Data, expected)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testRoleRead(t *testing.T, roleName string, expected map[string]interface{}) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
|
||||
@@ -83,7 +83,7 @@ func pathRoles(b *backend) *framework.Path {
|
||||
Description: `
|
||||
[Required for Dynamic type] [Not applicable for OTP type] [Not applicable for CA type]
|
||||
Admin user at remote host. The shared key being registered should be
|
||||
for this user and should have root privileges. Everytime a dynamic
|
||||
for this user and should have root privileges. Everytime a dynamic
|
||||
credential is being generated for other users, Vault uses this admin
|
||||
username to login to remote host and install the generated credential
|
||||
for the other user.`,
|
||||
@@ -175,7 +175,7 @@ func pathRoles(b *backend) *framework.Path {
|
||||
`,
|
||||
},
|
||||
"ttl": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Type: framework.TypeDurationSecond,
|
||||
Description: `
|
||||
[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type]
|
||||
The lease duration if no specific lease duration is
|
||||
@@ -184,7 +184,7 @@ func pathRoles(b *backend) *framework.Path {
|
||||
the value of max_ttl.`,
|
||||
},
|
||||
"max_ttl": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Type: framework.TypeDurationSecond,
|
||||
Description: `
|
||||
[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type]
|
||||
The maximum allowed lease duration
|
||||
@@ -386,15 +386,15 @@ func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (*
|
||||
return logical.ErrorResponse("missing admin username"), nil
|
||||
}
|
||||
|
||||
// This defaults to 1024 and it can also be 2048.
|
||||
// This defaults to 1024 and it can also be 2048 and 4096.
|
||||
keyBits := d.Get("key_bits").(int)
|
||||
if keyBits != 0 && keyBits != 1024 && keyBits != 2048 {
|
||||
if keyBits != 0 && keyBits != 1024 && keyBits != 2048 && keyBits != 4096 {
|
||||
return logical.ErrorResponse("invalid key_bits field"), nil
|
||||
}
|
||||
|
||||
// If user has not set this field, default it to 1024
|
||||
// If user has not set this field, default it to 2048
|
||||
if keyBits == 0 {
|
||||
keyBits = 1024
|
||||
keyBits = 2048
|
||||
}
|
||||
|
||||
// Store all the fields required by dynamic key type
|
||||
@@ -433,9 +433,9 @@ func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (*
|
||||
}
|
||||
|
||||
func (b *backend) createCARole(allowedUsers, defaultUser string, data *framework.FieldData) (*sshRole, *logical.Response) {
|
||||
ttl := time.Duration(data.Get("ttl").(int)) * time.Second
|
||||
maxTTL := time.Duration(data.Get("max_ttl").(int)) * time.Second
|
||||
role := &sshRole{
|
||||
MaxTTL: data.Get("max_ttl").(string),
|
||||
TTL: data.Get("ttl").(string),
|
||||
AllowedCriticalOptions: data.Get("allowed_critical_options").(string),
|
||||
AllowedExtensions: data.Get("allowed_extensions").(string),
|
||||
AllowUserCertificates: data.Get("allow_user_certificates").(bool),
|
||||
@@ -457,44 +457,12 @@ func (b *backend) createCARole(allowedUsers, defaultUser string, data *framework
|
||||
defaultCriticalOptions := convertMapToStringValue(data.Get("default_critical_options").(map[string]interface{}))
|
||||
defaultExtensions := convertMapToStringValue(data.Get("default_extensions").(map[string]interface{}))
|
||||
|
||||
var maxTTL time.Duration
|
||||
maxSystemTTL := b.System().MaxLeaseTTL()
|
||||
if len(role.MaxTTL) == 0 {
|
||||
maxTTL = maxSystemTTL
|
||||
} else {
|
||||
var err error
|
||||
maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL)
|
||||
if err != nil {
|
||||
return nil, logical.ErrorResponse(fmt.Sprintf(
|
||||
"Invalid max ttl: %s", err))
|
||||
}
|
||||
}
|
||||
if maxTTL > maxSystemTTL {
|
||||
return nil, logical.ErrorResponse("Requested max TTL is higher than backend maximum")
|
||||
if ttl != 0 && maxTTL != 0 && ttl > maxTTL {
|
||||
return nil, logical.ErrorResponse(
|
||||
`"ttl" value must be less than "max_ttl" when both are specified`)
|
||||
}
|
||||
|
||||
ttl := b.System().DefaultLeaseTTL()
|
||||
if len(role.TTL) != 0 {
|
||||
var err error
|
||||
ttl, err = parseutil.ParseDurationSecond(role.TTL)
|
||||
if err != nil {
|
||||
return nil, logical.ErrorResponse(fmt.Sprintf(
|
||||
"Invalid ttl: %s", err))
|
||||
}
|
||||
}
|
||||
if ttl > maxTTL {
|
||||
// If they are using the system default, cap it to the role max;
|
||||
// if it was specified on the command line, make it an error
|
||||
if len(role.TTL) == 0 {
|
||||
ttl = maxTTL
|
||||
} else {
|
||||
return nil, logical.ErrorResponse(
|
||||
`"ttl" value must be less than "max_ttl" and/or backend default max lease TTL value`,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Persist clamped TTLs
|
||||
// Persist TTLs
|
||||
role.TTL = ttl.String()
|
||||
role.MaxTTL = maxTTL.String()
|
||||
role.DefaultCriticalOptions = defaultCriticalOptions
|
||||
@@ -520,13 +488,115 @@ func (b *backend) getRole(s logical.Storage, n string) (*sshRole, error) {
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// parseRole converts a sshRole object into its map[string]interface representation,
|
||||
// with appropriate values for each KeyType. If the KeyType is invalid, it will retun
|
||||
// an error.
|
||||
func (b *backend) parseRole(role *sshRole) (map[string]interface{}, error) {
|
||||
var result map[string]interface{}
|
||||
|
||||
switch role.KeyType {
|
||||
case KeyTypeOTP:
|
||||
result = map[string]interface{}{
|
||||
"default_user": role.DefaultUser,
|
||||
"cidr_list": role.CIDRList,
|
||||
"exclude_cidr_list": role.ExcludeCIDRList,
|
||||
"key_type": role.KeyType,
|
||||
"port": role.Port,
|
||||
"allowed_users": role.AllowedUsers,
|
||||
}
|
||||
case KeyTypeCA:
|
||||
ttl, err := parseutil.ParseDurationSecond(role.TTL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
maxTTL, err := parseutil.ParseDurationSecond(role.MaxTTL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result = map[string]interface{}{
|
||||
"allowed_users": role.AllowedUsers,
|
||||
"allowed_domains": role.AllowedDomains,
|
||||
"default_user": role.DefaultUser,
|
||||
"ttl": int64(ttl.Seconds()),
|
||||
"max_ttl": int64(maxTTL.Seconds()),
|
||||
"allowed_critical_options": role.AllowedCriticalOptions,
|
||||
"allowed_extensions": role.AllowedExtensions,
|
||||
"allow_user_certificates": role.AllowUserCertificates,
|
||||
"allow_host_certificates": role.AllowHostCertificates,
|
||||
"allow_bare_domains": role.AllowBareDomains,
|
||||
"allow_subdomains": role.AllowSubdomains,
|
||||
"allow_user_key_ids": role.AllowUserKeyIDs,
|
||||
"key_id_format": role.KeyIDFormat,
|
||||
"key_type": role.KeyType,
|
||||
"default_critical_options": role.DefaultCriticalOptions,
|
||||
"default_extensions": role.DefaultExtensions,
|
||||
}
|
||||
case KeyTypeDynamic:
|
||||
result = map[string]interface{}{
|
||||
"key": role.KeyName,
|
||||
"admin_user": role.AdminUser,
|
||||
"default_user": role.DefaultUser,
|
||||
"cidr_list": role.CIDRList,
|
||||
"exclude_cidr_list": role.ExcludeCIDRList,
|
||||
"port": role.Port,
|
||||
"key_type": role.KeyType,
|
||||
"key_bits": role.KeyBits,
|
||||
"allowed_users": role.AllowedUsers,
|
||||
"key_option_specs": role.KeyOptionSpecs,
|
||||
// Returning install script will make the output look messy.
|
||||
// But this is one way for clients to see the script that is
|
||||
// being used to install the key. If there is some problem,
|
||||
// the script can be modified and configured by clients.
|
||||
"install_script": role.InstallScript,
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid key type: %v", role.KeyType)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
entries, err := req.Storage.List("roles/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return logical.ListResponse(entries), nil
|
||||
keyInfo := map[string]interface{}{}
|
||||
for _, entry := range entries {
|
||||
role, err := b.getRole(req.Storage, entry)
|
||||
if err != nil {
|
||||
// On error, log warning and continue
|
||||
if b.Logger().IsWarn() {
|
||||
b.Logger().Warn("ssh: error getting role info", "role", entry, "error", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if role == nil {
|
||||
// On empty role, log warning and continue
|
||||
if b.Logger().IsWarn() {
|
||||
b.Logger().Warn("ssh: no role info found", "role", entry)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
roleInfo, err := b.parseRole(role)
|
||||
if err != nil {
|
||||
if b.Logger().IsWarn() {
|
||||
b.Logger().Warn("ssh: error parsing role info", "role", entry, "error", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if keyType, ok := roleInfo["key_type"]; ok {
|
||||
keyInfo[entry] = map[string]interface{}{
|
||||
"key_type": keyType,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return logical.ListResponseWithInfo(entries, keyInfo), nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
@@ -538,60 +608,14 @@ func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*l
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Return information should be based on the key type of the role
|
||||
if role.KeyType == KeyTypeOTP {
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"default_user": role.DefaultUser,
|
||||
"cidr_list": role.CIDRList,
|
||||
"exclude_cidr_list": role.ExcludeCIDRList,
|
||||
"key_type": role.KeyType,
|
||||
"port": role.Port,
|
||||
"allowed_users": role.AllowedUsers,
|
||||
},
|
||||
}, nil
|
||||
} else if role.KeyType == KeyTypeCA {
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"allowed_users": role.AllowedUsers,
|
||||
"allowed_domains": role.AllowedDomains,
|
||||
"default_user": role.DefaultUser,
|
||||
"max_ttl": role.MaxTTL,
|
||||
"ttl": role.TTL,
|
||||
"allowed_critical_options": role.AllowedCriticalOptions,
|
||||
"allowed_extensions": role.AllowedExtensions,
|
||||
"allow_user_certificates": role.AllowUserCertificates,
|
||||
"allow_host_certificates": role.AllowHostCertificates,
|
||||
"allow_bare_domains": role.AllowBareDomains,
|
||||
"allow_subdomains": role.AllowSubdomains,
|
||||
"allow_user_key_ids": role.AllowUserKeyIDs,
|
||||
"key_id_format": role.KeyIDFormat,
|
||||
"key_type": role.KeyType,
|
||||
"default_critical_options": role.DefaultCriticalOptions,
|
||||
"default_extensions": role.DefaultExtensions,
|
||||
},
|
||||
}, nil
|
||||
} else {
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"key": role.KeyName,
|
||||
"admin_user": role.AdminUser,
|
||||
"default_user": role.DefaultUser,
|
||||
"cidr_list": role.CIDRList,
|
||||
"exclude_cidr_list": role.ExcludeCIDRList,
|
||||
"port": role.Port,
|
||||
"key_type": role.KeyType,
|
||||
"key_bits": role.KeyBits,
|
||||
"allowed_users": role.AllowedUsers,
|
||||
"key_option_specs": role.KeyOptionSpecs,
|
||||
// Returning install script will make the output look messy.
|
||||
// But this is one way for clients to see the script that is
|
||||
// being used to install the key. If there is some problem,
|
||||
// the script can be modified and configured by clients.
|
||||
"install_script": role.InstallScript,
|
||||
},
|
||||
}, nil
|
||||
roleInfo, err := b.parseRole(role)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: roleInfo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
|
||||
@@ -43,7 +43,7 @@ func pathSign(b *backend) *framework.Path {
|
||||
Description: `The desired role with configuration for this request.`,
|
||||
},
|
||||
"ttl": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Type: framework.TypeDurationSecond,
|
||||
Description: `The requested Time To Live for the SSH certificate;
|
||||
sets the expiration date. If not specified
|
||||
the role default, backend default, or system
|
||||
@@ -345,40 +345,34 @@ func (b *backend) calculateExtensions(data *framework.FieldData, role *sshRole)
|
||||
}
|
||||
|
||||
func (b *backend) calculateTTL(data *framework.FieldData, role *sshRole) (time.Duration, error) {
|
||||
|
||||
var ttl, maxTTL time.Duration
|
||||
var ttlField string
|
||||
ttlFieldInt, ok := data.GetOk("ttl")
|
||||
if !ok {
|
||||
ttlField = role.TTL
|
||||
} else {
|
||||
ttlField = ttlFieldInt.(string)
|
||||
}
|
||||
var err error
|
||||
|
||||
if len(ttlField) == 0 {
|
||||
ttlRaw, specifiedTTL := data.GetOk("ttl")
|
||||
if specifiedTTL {
|
||||
ttl = time.Duration(ttlRaw.(int)) * time.Second
|
||||
} else {
|
||||
ttl, err = parseutil.ParseDurationSecond(role.TTL)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if ttl == 0 {
|
||||
ttl = b.System().DefaultLeaseTTL()
|
||||
} else {
|
||||
var err error
|
||||
ttl, err = parseutil.ParseDurationSecond(ttlField)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid requested ttl: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(role.MaxTTL) == 0 {
|
||||
maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if maxTTL == 0 {
|
||||
maxTTL = b.System().MaxLeaseTTL()
|
||||
} else {
|
||||
var err error
|
||||
maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid requested max ttl: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if ttl > maxTTL {
|
||||
// Don't error if they were using system defaults, only error if
|
||||
// they specifically chose a bad TTL
|
||||
if len(ttlField) == 0 {
|
||||
if !specifiedTTL {
|
||||
ttl = maxTTL
|
||||
} else {
|
||||
return 0, fmt.Errorf("ttl is larger than maximum allowed (%d)", maxTTL/time.Second)
|
||||
|
||||
@@ -43,6 +43,8 @@ func Backend(conf *logical.BackendConfig) *backend {
|
||||
b.pathHMAC(),
|
||||
b.pathSign(),
|
||||
b.pathVerify(),
|
||||
b.pathBackup(),
|
||||
b.pathRestore(),
|
||||
},
|
||||
|
||||
Secrets: []*framework.Secret{},
|
||||
|
||||
@@ -38,6 +38,191 @@ func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) {
|
||||
return b, config.StorageView
|
||||
}
|
||||
|
||||
func TestTransit_RSA(t *testing.T) {
|
||||
testTransit_RSA(t, "rsa-2048")
|
||||
testTransit_RSA(t, "rsa-4096")
|
||||
}
|
||||
|
||||
func testTransit_RSA(t *testing.T, keyType string) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
b, storage := createBackendWithStorage(t)
|
||||
|
||||
keyReq := &logical.Request{
|
||||
Path: "keys/rsa",
|
||||
Operation: logical.UpdateOperation,
|
||||
Data: map[string]interface{}{
|
||||
"type": keyType,
|
||||
},
|
||||
Storage: storage,
|
||||
}
|
||||
|
||||
resp, err = b.HandleRequest(keyReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
|
||||
}
|
||||
|
||||
plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" // "the quick brown fox"
|
||||
|
||||
encryptReq := &logical.Request{
|
||||
Path: "encrypt/rsa",
|
||||
Operation: logical.UpdateOperation,
|
||||
Storage: storage,
|
||||
Data: map[string]interface{}{
|
||||
"plaintext": plaintext,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err = b.HandleRequest(encryptReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
|
||||
}
|
||||
|
||||
ciphertext1 := resp.Data["ciphertext"].(string)
|
||||
|
||||
decryptReq := &logical.Request{
|
||||
Path: "decrypt/rsa",
|
||||
Operation: logical.UpdateOperation,
|
||||
Storage: storage,
|
||||
Data: map[string]interface{}{
|
||||
"ciphertext": ciphertext1,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err = b.HandleRequest(decryptReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
|
||||
}
|
||||
|
||||
decryptedPlaintext := resp.Data["plaintext"]
|
||||
|
||||
if plaintext != decryptedPlaintext {
|
||||
t.Fatalf("bad: plaintext; expected: %q\nactual: %q", plaintext, decryptedPlaintext)
|
||||
}
|
||||
|
||||
// Rotate the key
|
||||
rotateReq := &logical.Request{
|
||||
Path: "keys/rsa/rotate",
|
||||
Operation: logical.UpdateOperation,
|
||||
Storage: storage,
|
||||
}
|
||||
resp, err = b.HandleRequest(rotateReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
|
||||
}
|
||||
|
||||
// Encrypt again
|
||||
resp, err = b.HandleRequest(encryptReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
|
||||
}
|
||||
ciphertext2 := resp.Data["ciphertext"].(string)
|
||||
|
||||
if ciphertext1 == ciphertext2 {
|
||||
t.Fatalf("expected different ciphertexts")
|
||||
}
|
||||
|
||||
// See if the older ciphertext can still be decrypted
|
||||
resp, err = b.HandleRequest(decryptReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
|
||||
}
|
||||
if resp.Data["plaintext"].(string) != plaintext {
|
||||
t.Fatal("failed to decrypt old ciphertext after rotating the key")
|
||||
}
|
||||
|
||||
// Decrypt the new ciphertext
|
||||
decryptReq.Data = map[string]interface{}{
|
||||
"ciphertext": ciphertext2,
|
||||
}
|
||||
resp, err = b.HandleRequest(decryptReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
|
||||
}
|
||||
if resp.Data["plaintext"].(string) != plaintext {
|
||||
t.Fatal("failed to decrypt ciphertext after rotating the key")
|
||||
}
|
||||
|
||||
signReq := &logical.Request{
|
||||
Path: "sign/rsa",
|
||||
Operation: logical.UpdateOperation,
|
||||
Storage: storage,
|
||||
Data: map[string]interface{}{
|
||||
"input": plaintext,
|
||||
},
|
||||
}
|
||||
resp, err = b.HandleRequest(signReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
|
||||
}
|
||||
signature := resp.Data["signature"].(string)
|
||||
|
||||
verifyReq := &logical.Request{
|
||||
Path: "verify/rsa",
|
||||
Operation: logical.UpdateOperation,
|
||||
Storage: storage,
|
||||
Data: map[string]interface{}{
|
||||
"input": plaintext,
|
||||
"signature": signature,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err = b.HandleRequest(verifyReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
|
||||
}
|
||||
if !resp.Data["valid"].(bool) {
|
||||
t.Fatalf("failed to verify the RSA signature")
|
||||
}
|
||||
|
||||
signReq.Data = map[string]interface{}{
|
||||
"input": plaintext,
|
||||
"algorithm": "invalid",
|
||||
}
|
||||
resp, err = b.HandleRequest(signReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil || !resp.IsError() {
|
||||
t.Fatal("expected an error response")
|
||||
}
|
||||
|
||||
signReq.Data = map[string]interface{}{
|
||||
"input": plaintext,
|
||||
"algorithm": "sha2-512",
|
||||
}
|
||||
resp, err = b.HandleRequest(signReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
|
||||
}
|
||||
signature = resp.Data["signature"].(string)
|
||||
|
||||
verifyReq.Data = map[string]interface{}{
|
||||
"input": plaintext,
|
||||
"signature": signature,
|
||||
}
|
||||
resp, err = b.HandleRequest(verifyReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
|
||||
}
|
||||
if resp.Data["valid"].(bool) {
|
||||
t.Fatalf("expected validation to fail")
|
||||
}
|
||||
|
||||
verifyReq.Data = map[string]interface{}{
|
||||
"input": plaintext,
|
||||
"signature": signature,
|
||||
"algorithm": "sha2-512",
|
||||
}
|
||||
resp, err = b.HandleRequest(verifyReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
|
||||
}
|
||||
if !resp.Data["valid"].(bool) {
|
||||
t.Fatalf("failed to verify the RSA signature")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_basic(t *testing.T) {
|
||||
decryptData := make(map[string]interface{})
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
@@ -634,7 +819,7 @@ func TestKeyUpgrade(t *testing.T) {
|
||||
if p.Key != nil ||
|
||||
p.Keys == nil ||
|
||||
len(p.Keys) != 1 ||
|
||||
!reflect.DeepEqual(p.Keys[1].Key, key) {
|
||||
!reflect.DeepEqual(p.Keys[strconv.Itoa(1)].Key, key) {
|
||||
t.Errorf("bad key migration, result is %#v", p.Keys)
|
||||
}
|
||||
}
|
||||
@@ -1091,3 +1276,38 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) {
|
||||
// Wait for them all to finish
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestBadInput(t *testing.T) {
|
||||
var b *backend
|
||||
sysView := logical.TestSystemView()
|
||||
storage := &logical.InmemStorage{}
|
||||
|
||||
b = Backend(&logical.BackendConfig{
|
||||
StorageView: storage,
|
||||
System: sysView,
|
||||
})
|
||||
|
||||
req := &logical.Request{
|
||||
Storage: storage,
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "keys/test",
|
||||
}
|
||||
|
||||
resp, err := b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil response")
|
||||
}
|
||||
|
||||
req.Path = "decrypt/test"
|
||||
req.Data = map[string]interface{}{
|
||||
"ciphertext": "vault:v1:abcd",
|
||||
}
|
||||
|
||||
_, err = b.HandleRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user