mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-30 18:17:55 +00:00 
			
		
		
		
	Merge branch 'master-oss' into sethvargo/cli-magic
This commit is contained in:
		
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -64,6 +64,7 @@ tags | |||||||
| # compiled output | # compiled output | ||||||
| ui/dist | ui/dist | ||||||
| ui/tmp | ui/tmp | ||||||
|  | ui/root | ||||||
|  |  | ||||||
| # dependencies | # dependencies | ||||||
| ui/node_modules | ui/node_modules | ||||||
|   | |||||||
							
								
								
									
										245
									
								
								CHANGELOG.md
									
									
									
									
									
								
							
							
						
						
									
										245
									
								
								CHANGELOG.md
									
									
									
									
									
								
							| @@ -1,27 +1,254 @@ | |||||||
| ## 0.8.4 (Unreleased) | ## 0.9.2 (Unreleased) | ||||||
|  |  | ||||||
|  | IMPROVEMENTS: | ||||||
|  |  | ||||||
|  |  * physical/s3: Allow using paths with S3 for non-AWS deployments [GH-3730] | ||||||
|  |  * physical/s3: Add ability to disable SSL for non-AWS deployments [GH-3730] | ||||||
|  |  | ||||||
|  | ## 0.9.1 (December 21st, 2017) | ||||||
|  |  | ||||||
|  | DEPRECATIONS/CHANGES: | ||||||
|  |  | ||||||
|  |  * AppRole Case Sensitivity: In prior versions of Vault, `list` operations | ||||||
|  |    against AppRole roles would require preserving case in the role name, even | ||||||
|  |    though most other operations within AppRole are case-insensitive with | ||||||
|  |    respect to the role name. This has been fixed; existing roles will behave as | ||||||
|  |    they have in the past, but new roles will act case-insensitively in these | ||||||
|  |    cases. | ||||||
|  |  * Token Auth Backend Roles parameter types: For `allowed_policies` and | ||||||
|  |    `disallowed_policies` in role definitions in the token auth backend, input | ||||||
|  |    can now be a comma-separated string or an array of strings. Reading a role | ||||||
|  |    will now return arrays for these parameters. | ||||||
|  |  * Transit key exporting: You can now mark a key in the `transit` backend as | ||||||
|  |    `exportable` at any time, rather than just at creation time; however, once | ||||||
|  |    this value is set, it still cannot be unset. | ||||||
|  |  * PKI Secret Backend Roles parameter types: For `allowed_domains` and | ||||||
|  |    `key_usage` in role definitions in the PKI secret backend, input | ||||||
|  |    can now be a comma-separated string or an array of strings. Reading a role | ||||||
|  |    will now return arrays for these parameters. | ||||||
|  |  * SSH Dynamic Keys Method Defaults to 2048-bit Keys: When using the dynamic | ||||||
|  |    key method in the SSH backend, the default is now to use 2048-bit keys if no | ||||||
|  |    specific key bit size is specified. | ||||||
|  |  * Consul Secret Backend lease handling: The `consul` secret backend can now | ||||||
|  |    accept both strings and integer numbers of seconds for its lease value. The | ||||||
|  |    value returned on a role read will be an integer number of seconds instead | ||||||
|  |    of a human-friendly string. | ||||||
|  |  * Unprintable characters not allowed in API paths: Unprintable characters are | ||||||
|  |    no longer allowed in names in the API (paths and path parameters), with an | ||||||
|  |    extra restriction on whitespace characters. Allowed characters are those | ||||||
|  |    that are considered printable by Unicode plus spaces. | ||||||
|  |  | ||||||
|  | FEATURES: | ||||||
|  |  | ||||||
|  |  * **Transit Backup/Restore**: The `transit` backend now supports a backup | ||||||
|  |    operation that can export a given key, including all key versions and | ||||||
|  |    configuration, as well as a restore operation allowing import into another | ||||||
|  |    Vault. | ||||||
|  |  * **gRPC Database Plugins**: Database plugins now use gRPC for transport, | ||||||
|  |    allowing them to be written in other languages. | ||||||
|  |  * **Nomad Secret Backend**: Nomad ACL tokens can now be generated and revoked | ||||||
|  |    using Vault. | ||||||
|  |  * **TLS Cert Auth Backend Improvements**: The `cert` auth backend can now | ||||||
|  |    match against custom certificate extensions via exact or glob matching, and | ||||||
|  |    additionally supports max_ttl and periodic token toggles. | ||||||
|  |  | ||||||
|  | IMPROVEMENTS: | ||||||
|  |  | ||||||
|  |  * auth/cert: Support custom certificate constraints [GH-3634] | ||||||
|  |  * auth/cert: Support setting `max_ttl` and `period` [GH-3642] | ||||||
|  |  * audit/file: Setting a file mode of `0000` will now disable Vault from | ||||||
|  |    automatically `chmod`ing the log file [GH-3649] | ||||||
|  |  * auth/github: The legacy MFA system can now be used with the GitHub auth | ||||||
|  |    backend [GH-3696] | ||||||
|  |  * auth/okta: The legacy MFA system can now be used with the Okta auth backend | ||||||
|  |    [GH-3653] | ||||||
|  |  * auth/token: `allowed_policies` and `disallowed_policies` can now be specified | ||||||
|  |    as a comma-separated string or an array of strings [GH-3641] | ||||||
|  |  * command/server: The log level can now be specified with `VAULT_LOG_LEVEL` | ||||||
|  |    [GH-3721] | ||||||
|  |  * core: Period values from auth backends will now be checked and applied to the | ||||||
|  |    TTL value directly by core on login and renewal requests [GH-3677] | ||||||
|  |  * database/mongodb: Add optional `write_concern` parameter, which can be set | ||||||
|  |    during database configuration. This establishes a session-wide [write | ||||||
|  |    concern](https://docs.mongodb.com/manual/reference/write-concern/) for the | ||||||
|  |    lifecycle of the mount [GH-3646] | ||||||
|  |  * http: Request path containing non-printable characters will return 400 - Bad | ||||||
|  |    Request [GH-3697] | ||||||
|  |  * mfa/okta: Filter a given email address as a login filter, allowing operation | ||||||
|  |    when login email and account email are different | ||||||
|  |  * plugins: Make Vault more resilient when unsealing when plugins are | ||||||
|  |    unavailable [GH-3686] | ||||||
|  |  * secret/pki: `allowed_domains` and `key_usage` can now be specified | ||||||
|  |    as a comma-separated string or an array of strings [GH-3642] | ||||||
|  |  * secret/ssh: Allow 4096-bit keys to be used in dynamic key method [GH-3593] | ||||||
|  |  * secret/consul: The Consul secret backend now uses the value of `lease` set | ||||||
|  |    on the role, if set, when renewing a secret. [GH-3796] | ||||||
|  |  * storage/mysql: Don't attempt database creation if it exists, which can help | ||||||
|  |    under certain permissions constraints [GH-3716] | ||||||
|  |  | ||||||
|  | BUG FIXES: | ||||||
|  |  | ||||||
|  |  * api/status (enterprise): Fix status reporting when using an auto seal | ||||||
|  |  * auth/approle: Fix case-sensitive/insensitive comparison issue [GH-3665] | ||||||
|  |  * auth/cert: Return `allowed_names` on role read [GH-3654] | ||||||
|  |  * auth/ldap: Fix incorrect control information being sent [GH-3402] [GH-3496] | ||||||
|  |    [GH-3625] [GH-3656] | ||||||
|  |  * core: Fix seal status reporting when using an autoseal | ||||||
|  |  * core: Add creation path to wrap info for a control group token | ||||||
|  |  * core: Fix potential panic that could occur using plugins when a node | ||||||
|  |    transitioned from active to standby [GH-3638] | ||||||
|  |  * core: Fix memory ballooning when a connection would connect to the cluster | ||||||
|  |    port and then go away -- redux! [GH-3680] | ||||||
|  |  * core: Replace recursive token revocation logic with depth-first logic, which | ||||||
|  |    can avoid hitting stack depth limits in extreme cases [GH-2348] | ||||||
|  |  * core: When doing a read on configured audited-headers, properly handle case | ||||||
|  |    insensitivity [GH-3701] | ||||||
|  |  * core/pkcs11 (enterprise): Fix panic when PKCS#11 library is not readable | ||||||
|  |  * database/mysql: Allow the creation statement to use commands that are not yet | ||||||
|  |    supported by the prepare statement protocol [GH-3619] | ||||||
|  |  * plugin/auth-gcp: Fix IAM roles when using `allow_gce_inference` [VPAG-19] | ||||||
|  |  | ||||||
|  | ## 0.9.0.1 (November 21st, 2017) (Enterprise Only) | ||||||
|  |  | ||||||
|  | IMPROVEMENTS: | ||||||
|  |  | ||||||
|  |  * auth/gcp: Support seal wrapping of configuration parameters | ||||||
|  |  * auth/kubernetes: Support seal wrapping of configuration parameters | ||||||
|  |  | ||||||
|  | BUG FIXES: | ||||||
|  |  | ||||||
|  |  * Fix an upgrade issue with some physical backends when migrating from legacy | ||||||
|  |    HSM stored key support to the new Seal Wrap mechanism | ||||||
|  |  | ||||||
|  | ## 0.9.0 (November 14th, 2017) | ||||||
|  |  | ||||||
|  | DEPRECATIONS/CHANGES: | ||||||
|  |  | ||||||
|  |  * HSM config parameter requirements: When using Vault with an HSM, a new | ||||||
|  |    paramter is required: `hmac_key_label`.  This performs a similar function to | ||||||
|  |    `key_label` but for the HMAC key Vault will use. Vault will generate a | ||||||
|  |    suitable key if this value is specified and `generate_key` is set true. | ||||||
|  |  * API HTTP client behavior: When calling `NewClient` the API no longer | ||||||
|  |    modifies the provided client/transport. In particular this means it will no | ||||||
|  |    longer enable redirection limiting and HTTP/2 support on custom clients. It | ||||||
|  |    is suggested that if you want to make changes to an HTTP client that you use | ||||||
|  |    one created by `DefaultConfig` as a starting point. | ||||||
|  |  * AWS EC2 client nonce behavior: The client nonce generated by the backend | ||||||
|  |    that gets returned along with the authentication response will be audited in | ||||||
|  |    plaintext. If this is undesired, the clients can choose to supply a custom | ||||||
|  |    nonce to the login endpoint. The custom nonce set by the client will from | ||||||
|  |    now on, not be returned back with the authentication response, and hence not | ||||||
|  |    audit logged. | ||||||
|  |  * AWS Auth role options: The API will now error when trying to create or | ||||||
|  |    update a role with the mutually-exclusive options | ||||||
|  |    `disallow_reauthentication` and `allow_instance_migration`. | ||||||
|  |  * SSH CA role read changes: When reading back a role from the `ssh` backend, | ||||||
|  |    the TTL/max TTL values will now be an integer number of seconds rather than | ||||||
|  |    a string. This better matches the API elsewhere in Vault. | ||||||
|  |  * SSH role list changes: When listing roles from the `ssh` backend via the API, | ||||||
|  |    the response data will additionally return a `key_info` map that will contain | ||||||
|  |    a map of each key with a corresponding object containing the `key_type`. | ||||||
|  |  * More granularity in audit logs: Audit request and response entires are still | ||||||
|  |    in RFC3339 format but now have a granularity of nanoseconds. | ||||||
|  |  * High availability related values have been moved out of the `storage` and | ||||||
|  |    `ha_storage` stanzas, and into the top-level configuration. `redirect_addr` | ||||||
|  |    has been renamed to `api_addr`. The stanzas still support accepting | ||||||
|  |    HA-related values to maintain backward compatibility, but top-level values | ||||||
|  |    will take precedence. | ||||||
|  |  * A new `seal` stanza has been added to the configuration file, which is | ||||||
|  |    optional and enables configuration of the seal type to use for additional | ||||||
|  |    data protection, such as using HSM or Cloud KMS solutions to encrypt and | ||||||
|  |    decrypt data. | ||||||
|  |   | ||||||
|  | FEATURES: | ||||||
|  |  | ||||||
|  |  * **RSA Support for Transit Backend**: Transit backend can now generate RSA | ||||||
|  |    keys which can be used for encryption and signing. [GH-3489] | ||||||
|  |  * **Identity System**: Now in open source and with significant enhancements, | ||||||
|  |    Identity is an integrated system for understanding users across tokens and | ||||||
|  |    enabling easier management of users directly and via groups. | ||||||
|  |  * **External Groups in Identity**: Vault can now automatically assign users | ||||||
|  |    and systems to groups in Identity based on their membership in external | ||||||
|  |    groups. | ||||||
|  |  * **Seal Wrap / FIPS 140-2 Compatibility (Enterprise)**: Vault can now take | ||||||
|  |    advantage of FIPS 140-2-certified HSMs to ensure that Critical Security | ||||||
|  |    Parameters are protected in a compliant fashion. Vault's implementation has | ||||||
|  |    received a statement of compliance from Leidos. | ||||||
|  |  * **Control Groups (Enterprise)**: Require multiple members of an Identity | ||||||
|  |    group to authorize a requested action before it is allowed to run. | ||||||
|  |  * **Cloud Auto-Unseal (Enterprise)**: Automatically unseal Vault using AWS KMS | ||||||
|  |    and GCP CKMS. | ||||||
|  |  * **Sentinel Integration (Enterprise)**: Take advantage of HashiCorp Sentinel | ||||||
|  |    to create extremely flexible access control policies -- even on | ||||||
|  |    unauthenticated endpoints. | ||||||
|  |  * **Barrier Rekey Support for Auto-Unseal (Enterprise)**: When using auto-unsealing | ||||||
|  |    functionality, the `rekey` operation is now supported; it uses recovery keys | ||||||
|  |    to authorize the master key rekey. | ||||||
|  |  * **Operation Token for Disaster Recovery Actions (Enterprise)**: When using | ||||||
|  |    Disaster Recovery replication, a token can be created that can be used to | ||||||
|  |    authorize actions such as promotion and updating primary information, rather | ||||||
|  |    than using recovery keys. | ||||||
|  |  * **Trigger Auto-Unseal with Recovery Keys (Enterprise)**: When using | ||||||
|  |    auto-unsealing, a request to unseal Vault can be triggered by a threshold of | ||||||
|  |    recovery keys, rather than requiring the Vault process to be restarted. | ||||||
|  |  * **UI Redesign (Enterprise)**: All new experience for the Vault Enterprise | ||||||
|  |    UI. The look and feel has been completely redesigned to give users a better | ||||||
|  |    experience and make managing secrets fast and easy. | ||||||
|  |  * **UI: SSH Secret Backend (Enterprise)**: Configure an SSH secret backend, | ||||||
|  |    create and browse roles. And use them to sign keys or generate one time | ||||||
|  |    passwords. | ||||||
|  |  * **UI: AWS Secret Backend (Enterprise)**: You can now configure the AWS | ||||||
|  |    backend via the Vault Enterprise UI. In addition you can create roles, | ||||||
|  |    browse the roles and Generate IAM Credentials from them in the UI. | ||||||
|  |  | ||||||
| IMPROVEMENTS: | IMPROVEMENTS: | ||||||
|  |  | ||||||
|  * api: Add ability to set custom headers on each call [GH-3394] |  * api: Add ability to set custom headers on each call [GH-3394] | ||||||
|  * command/server: Add config option to disable requesting client certificates |  * command/server: Add config option to disable requesting client certificates | ||||||
|    [GH-3373] |    [GH-3373] | ||||||
|  * secret/cassandra: Work around Cassandra ignoring consistency levels for a |  * core: Disallow mounting underneath an existing path, not just over [GH-2919] | ||||||
|    user listing query [GH-3469] |  | ||||||
|  * secret/pki: Allow entering URLs for `pki` as both comma-separated strings and JSON |  | ||||||
|    arrays [GH-3409] |  | ||||||
|  * secret/transit: Sign and verify operations now support a `none` hash |  | ||||||
|    algorithm to allow signing/verifying pre-hashed data [GH-3448] |  | ||||||
|  * physical/file: Use `700` as permissions when creating directories. The files |  * physical/file: Use `700` as permissions when creating directories. The files | ||||||
|    themselves were `600` and are all encrypted, but this doesn't hurt. |    themselves were `600` and are all encrypted, but this doesn't hurt. | ||||||
|  |  * secret/aws: Add ability to use custom IAM/STS endpoints [GH-3416] | ||||||
|  |  * secret/cassandra: Work around Cassandra ignoring consistency levels for a | ||||||
|  |    user listing query [GH-3469] | ||||||
|  |  * secret/pki: Private keys can now be marshalled as PKCS#8 [GH-3518] | ||||||
|  |  * secret/pki: Allow entering URLs for `pki` as both comma-separated strings and JSON | ||||||
|  |    arrays [GH-3409] | ||||||
|  |  * secret/ssh: Role TTL/max TTL can now be specified as either a string or an | ||||||
|  |    integer [GH-3507] | ||||||
|  |  * secret/transit: Sign and verify operations now support a `none` hash | ||||||
|  |    algorithm to allow signing/verifying pre-hashed data [GH-3448] | ||||||
|  |  * secret/database: Add the ability to glob allowed roles in the Database Backend [GH-3387] | ||||||
|  |  * ui (enterprise): Support for RSA keys in the transit backend | ||||||
|  |  * ui (enterprise): Support for DR Operation Token generation, promoting, and | ||||||
|  |    updating primary on DR Secondary clusters | ||||||
|  |  | ||||||
| BUG FIXES: | BUG FIXES: | ||||||
|  |  | ||||||
|  * api: Fix panic when setting a custom HTTP client but with a nil transport |  * api: Fix panic when setting a custom HTTP client but with a nil transport | ||||||
|    [GH-3437] |    [GH-3435] [GH-3437] | ||||||
|  |  * api: Fix authing to the `cert` backend when the CA for the client cert is | ||||||
|  |    not known to the server's listener [GH-2946] | ||||||
|  |  * auth/approle: Create role ID index during read if a role is missing one [GH-3561] | ||||||
|  |  * auth/aws: Don't allow mutually exclusive options [GH-3291] | ||||||
|  * auth/radius: Fix logging in in some situations [GH-3461] |  * auth/radius: Fix logging in in some situations [GH-3461] | ||||||
|  |  * core: Fix memleak when a connection would connect to the cluster port and | ||||||
|  |    then go away [GH-3513] | ||||||
|  |  * core: Fix panic if a single-use token is used to step-down or seal [GH-3497] | ||||||
|  |  * core: Set rather than add headers to prevent some duplicated headers in | ||||||
|  |    responses when requests were forwarded to the active node [GH-3485] | ||||||
|  * physical/etcd3: Fix some listing issues due to how etcd3 does prefix |  * physical/etcd3: Fix some listing issues due to how etcd3 does prefix | ||||||
|    matching [GH-3406] |    matching [GH-3406] | ||||||
|  |  * physical/etcd3: Fix case where standbys can lose their etcd client lease | ||||||
|  |    [GH-3031] | ||||||
|  |  * physical/file: Fix listing when underscores are the first component of a | ||||||
|  |    path [GH-3476] | ||||||
|  * plugins: Allow response errors to be returned from backend plugins [GH-3412] |  * plugins: Allow response errors to be returned from backend plugins [GH-3412] | ||||||
|  |  * secret/transit: Fix panic if the length of the input ciphertext was less | ||||||
|  |    than the expected nonce length [GH-3521] | ||||||
|  |  * ui (enterprise): Reinstate support for generic secret backends - this was | ||||||
|  |    erroneously removed in a previous release | ||||||
|  |  | ||||||
| ## 0.8.3 (September 19th, 2017) | ## 0.8.3 (September 19th, 2017) | ||||||
|  |  | ||||||
| @@ -117,7 +344,7 @@ IMPROVEMENTS: | |||||||
|  |  | ||||||
|  * audit/file: Allow specifying `stdout` as the `file_path` to log to standard |  * audit/file: Allow specifying `stdout` as the `file_path` to log to standard | ||||||
|    output [GH-3235] |    output [GH-3235] | ||||||
|  * auth/aws: Allow wildcards in `bound_iam_principal_id` [GH-3213] |  * auth/aws: Allow wildcards in `bound_iam_principal_arn` [GH-3213] | ||||||
|  * auth/okta: Compare groups case-insensitively since Okta is only |  * auth/okta: Compare groups case-insensitively since Okta is only | ||||||
|    case-preserving [GH-3240] |    case-preserving [GH-3240] | ||||||
|  * auth/okta: Standardize Okta configuration APIs across backends [GH-3245] |  * auth/okta: Standardize Okta configuration APIs across backends [GH-3245] | ||||||
|   | |||||||
							
								
								
									
										5
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								Makefile
									
									
									
									
									
								
							| @@ -79,7 +79,7 @@ vet: | |||||||
| prep: fmtcheck | prep: fmtcheck | ||||||
| 	@sh -c "'$(CURDIR)/scripts/goversioncheck.sh' '$(GO_VERSION_MIN)'" | 	@sh -c "'$(CURDIR)/scripts/goversioncheck.sh' '$(GO_VERSION_MIN)'" | ||||||
| 	go generate $(go list ./... | grep -v /vendor/) | 	go generate $(go list ./... | grep -v /vendor/) | ||||||
| 	cp .hooks/* .git/hooks/ | 	@if [ -d .git/hooks ]; then cp .hooks/* .git/hooks/; fi | ||||||
|  |  | ||||||
| # bootstrap the build by downloading additional tools | # bootstrap the build by downloading additional tools | ||||||
| bootstrap: | bootstrap: | ||||||
| @@ -92,8 +92,11 @@ proto: | |||||||
| 	protoc -I helper/forwarding -I vault -I ../../.. vault/*.proto --go_out=plugins=grpc:vault | 	protoc -I helper/forwarding -I vault -I ../../.. vault/*.proto --go_out=plugins=grpc:vault | ||||||
| 	protoc -I helper/storagepacker helper/storagepacker/types.proto --go_out=plugins=grpc:helper/storagepacker | 	protoc -I helper/storagepacker helper/storagepacker/types.proto --go_out=plugins=grpc:helper/storagepacker | ||||||
| 	protoc -I helper/forwarding -I vault -I ../../.. helper/forwarding/types.proto --go_out=plugins=grpc:helper/forwarding | 	protoc -I helper/forwarding -I vault -I ../../.. helper/forwarding/types.proto --go_out=plugins=grpc:helper/forwarding | ||||||
|  | 	protoc -I physical physical/types.proto --go_out=plugins=grpc:physical | ||||||
| 	protoc -I helper/identity -I ../../.. helper/identity/types.proto --go_out=plugins=grpc:helper/identity | 	protoc -I helper/identity -I ../../.. helper/identity/types.proto --go_out=plugins=grpc:helper/identity | ||||||
|  | 	protoc  builtin/logical/database/dbplugin/*.proto --go_out=plugins=grpc:. | ||||||
| 	sed -i -e 's/Idp/IDP/' -e 's/Url/URL/' -e 's/Id/ID/' -e 's/EntityId/EntityID/' -e 's/Api/API/' -e 's/Qr/QR/' -e 's/protobuf:"/sentinel:"" protobuf:"/' helper/identity/types.pb.go helper/storagepacker/types.pb.go | 	sed -i -e 's/Idp/IDP/' -e 's/Url/URL/' -e 's/Id/ID/' -e 's/EntityId/EntityID/' -e 's/Api/API/' -e 's/Qr/QR/' -e 's/protobuf:"/sentinel:"" protobuf:"/' helper/identity/types.pb.go helper/storagepacker/types.pb.go | ||||||
|  | 	sed -i -e 's/Iv/IV/' -e 's/Hmac/HMAC/' physical/types.pb.go | ||||||
|  |  | ||||||
| fmtcheck: | fmtcheck: | ||||||
| 	@sh -c "'$(CURDIR)/scripts/gofmtcheck.sh'" | 	@sh -c "'$(CURDIR)/scripts/gofmtcheck.sh'" | ||||||
|   | |||||||
| @@ -1,8 +1,10 @@ | |||||||
| Vault [](https://travis-ci.org/hashicorp/vault) [](https://gitter.im/hashicorp-vault/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [](https://www.hashicorp.com/products/vault/?utm_source=github&utm_medium=banner&utm_campaign=github-vault-enterprise) | # Vault [](https://travis-ci.org/hashicorp/vault) [](https://gitter.im/hashicorp-vault/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [](https://www.hashicorp.com/products/vault/?utm_source=github&utm_medium=banner&utm_campaign=github-vault-enterprise) | ||||||
| ========= |  | ||||||
|  | ---- | ||||||
|  |  | ||||||
| **Please note**: We take Vault's security and our users' trust very seriously. If you believe you have found a security issue in Vault, _please responsibly disclose_ by contacting us at [security@hashicorp.com](mailto:security@hashicorp.com). | **Please note**: We take Vault's security and our users' trust very seriously. If you believe you have found a security issue in Vault, _please responsibly disclose_ by contacting us at [security@hashicorp.com](mailto:security@hashicorp.com). | ||||||
|  |  | ||||||
| ========= | ---- | ||||||
|  |  | ||||||
| -	Website: https://www.vaultproject.io | -	Website: https://www.vaultproject.io | ||||||
| -	IRC: `#vault-tool` on Freenode | -	IRC: `#vault-tool` on Freenode | ||||||
|   | |||||||
| @@ -5,8 +5,6 @@ import ( | |||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"golang.org/x/net/http2" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // testHTTPServer creates a test HTTP server that handles requests until | // testHTTPServer creates a test HTTP server that handles requests until | ||||||
| @@ -19,9 +17,6 @@ func testHTTPServer( | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	server := &http.Server{Handler: handler} | 	server := &http.Server{Handler: handler} | ||||||
| 	if err := http2.ConfigureServer(server, nil); err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	go server.Serve(ln) | 	go server.Serve(ln) | ||||||
|  |  | ||||||
| 	config := DefaultConfig() | 	config := DefaultConfig() | ||||||
|   | |||||||
							
								
								
									
										236
									
								
								api/client.go
									
									
									
									
									
								
							
							
						
						
									
										236
									
								
								api/client.go
									
									
									
									
									
								
							| @@ -13,12 +13,12 @@ import ( | |||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"golang.org/x/net/http2" | 	"github.com/hashicorp/errwrap" | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/go-cleanhttp" | 	"github.com/hashicorp/go-cleanhttp" | ||||||
| 	"github.com/hashicorp/go-rootcerts" | 	"github.com/hashicorp/go-rootcerts" | ||||||
| 	"github.com/hashicorp/vault/helper/parseutil" | 	"github.com/hashicorp/vault/helper/parseutil" | ||||||
| 	"github.com/sethgrid/pester" | 	"github.com/sethgrid/pester" | ||||||
|  | 	"golang.org/x/net/http2" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const EnvVaultAddress = "VAULT_ADDR" | const EnvVaultAddress = "VAULT_ADDR" | ||||||
| @@ -43,24 +43,31 @@ type WrappingLookupFunc func(operation, path string) string | |||||||
|  |  | ||||||
| // Config is used to configure the creation of the client. | // Config is used to configure the creation of the client. | ||||||
| type Config struct { | type Config struct { | ||||||
|  | 	modifyLock sync.RWMutex | ||||||
|  |  | ||||||
| 	// Address is the address of the Vault server. This should be a complete | 	// Address is the address of the Vault server. This should be a complete | ||||||
| 	// URL such as "http://vault.example.com". If you need a custom SSL | 	// URL such as "http://vault.example.com". If you need a custom SSL | ||||||
| 	// cert or want to enable insecure mode, you need to specify a custom | 	// cert or want to enable insecure mode, you need to specify a custom | ||||||
| 	// HttpClient. | 	// HttpClient. | ||||||
| 	Address string | 	Address string | ||||||
|  |  | ||||||
| 	// HttpClient is the HTTP client to use, which will currently always have the | 	// HttpClient is the HTTP client to use. Vault sets sane defaults for the | ||||||
| 	// same values as http.DefaultClient. This is used to control redirect behavior. | 	// http.Client and its associated http.Transport created in DefaultConfig. | ||||||
|  | 	// If you must modify Vault's defaults, it is suggested that you start with | ||||||
|  | 	// that client and modify as needed rather than start with an empty client | ||||||
|  | 	// (or http.DefaultClient). | ||||||
| 	HttpClient *http.Client | 	HttpClient *http.Client | ||||||
|  |  | ||||||
| 	redirectSetup sync.Once |  | ||||||
|  |  | ||||||
| 	// MaxRetries controls the maximum number of times to retry when a 5xx error | 	// MaxRetries controls the maximum number of times to retry when a 5xx error | ||||||
| 	// occurs. Set to 0 or less to disable retrying. Defaults to 0. | 	// occurs. Set to 0 or less to disable retrying. Defaults to 0. | ||||||
| 	MaxRetries int | 	MaxRetries int | ||||||
|  |  | ||||||
| 	// Timeout is for setting custom timeout parameter in the HttpClient | 	// Timeout is for setting custom timeout parameter in the HttpClient | ||||||
| 	Timeout time.Duration | 	Timeout time.Duration | ||||||
|  |  | ||||||
|  | 	// If there is an error when creating the configuration, this will be the | ||||||
|  | 	// error | ||||||
|  | 	Error error | ||||||
| } | } | ||||||
|  |  | ||||||
| // TLSConfig contains the parameters needed to configure TLS on the HTTP client | // TLSConfig contains the parameters needed to configure TLS on the HTTP client | ||||||
| @@ -93,60 +100,91 @@ type TLSConfig struct { | |||||||
| // | // | ||||||
| // The default Address is https://127.0.0.1:8200, but this can be overridden by | // The default Address is https://127.0.0.1:8200, but this can be overridden by | ||||||
| // setting the `VAULT_ADDR` environment variable. | // setting the `VAULT_ADDR` environment variable. | ||||||
|  | // | ||||||
|  | // If an error is encountered, this will return nil. | ||||||
| func DefaultConfig() *Config { | func DefaultConfig() *Config { | ||||||
| 	config := &Config{ | 	config := &Config{ | ||||||
| 		Address:    "https://127.0.0.1:8200", | 		Address:    "https://127.0.0.1:8200", | ||||||
| 		HttpClient: cleanhttp.DefaultClient(), | 		HttpClient: cleanhttp.DefaultClient(), | ||||||
| 	} | 	} | ||||||
| 	config.HttpClient.Timeout = time.Second * 60 | 	config.HttpClient.Timeout = time.Second * 60 | ||||||
|  |  | ||||||
| 	transport := config.HttpClient.Transport.(*http.Transport) | 	transport := config.HttpClient.Transport.(*http.Transport) | ||||||
| 	transport.TLSHandshakeTimeout = 10 * time.Second | 	transport.TLSHandshakeTimeout = 10 * time.Second | ||||||
| 	transport.TLSClientConfig = &tls.Config{ | 	transport.TLSClientConfig = &tls.Config{ | ||||||
| 		MinVersion: tls.VersionTLS12, | 		MinVersion: tls.VersionTLS12, | ||||||
| 	} | 	} | ||||||
|  | 	if err := http2.ConfigureTransport(transport); err != nil { | ||||||
|  | 		config.Error = err | ||||||
|  | 		return config | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if v := os.Getenv(EnvVaultAddress); v != "" { | 	if err := config.ReadEnvironment(); err != nil { | ||||||
| 		config.Address = v | 		config.Error = err | ||||||
|  | 		return config | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Ensure redirects are not automatically followed | ||||||
|  | 	// Note that this is sane for the API client as it has its own | ||||||
|  | 	// redirect handling logic (and thus also for command/meta), | ||||||
|  | 	// but in e.g. http_test actual redirect handling is necessary | ||||||
|  | 	config.HttpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { | ||||||
|  | 		// Returning this value causes the Go net library to not close the | ||||||
|  | 		// response body and to nil out the error. Otherwise pester tries | ||||||
|  | 		// three times on every redirect because it sees an error from this | ||||||
|  | 		// function (to prevent redirects) passing through to it. | ||||||
|  | 		return http.ErrUseLastResponse | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return config | 	return config | ||||||
| } | } | ||||||
|  |  | ||||||
| // ConfigureTLS takes a set of TLS configurations and applies those to the the HTTP client. | // ConfigureTLS takes a set of TLS configurations and applies those to the the | ||||||
|  | // HTTP client. | ||||||
| func (c *Config) ConfigureTLS(t *TLSConfig) error { | func (c *Config) ConfigureTLS(t *TLSConfig) error { | ||||||
| 	if c.HttpClient == nil { | 	if c.HttpClient == nil { | ||||||
| 		c.HttpClient = DefaultConfig().HttpClient | 		c.HttpClient = DefaultConfig().HttpClient | ||||||
| 	} | 	} | ||||||
|  | 	clientTLSConfig := c.HttpClient.Transport.(*http.Transport).TLSClientConfig | ||||||
|  |  | ||||||
| 	var clientCert tls.Certificate | 	var clientCert tls.Certificate | ||||||
| 	foundClientCert := false | 	foundClientCert := false | ||||||
| 	if t.CACert != "" || t.CAPath != "" || t.ClientCert != "" || t.ClientKey != "" || t.Insecure { |  | ||||||
| 		if t.ClientCert != "" && t.ClientKey != "" { | 	switch { | ||||||
| 			var err error | 	case t.ClientCert != "" && t.ClientKey != "": | ||||||
| 			clientCert, err = tls.LoadX509KeyPair(t.ClientCert, t.ClientKey) | 		var err error | ||||||
| 			if err != nil { | 		clientCert, err = tls.LoadX509KeyPair(t.ClientCert, t.ClientKey) | ||||||
| 				return err | 		if err != nil { | ||||||
| 			} | 			return err | ||||||
| 			foundClientCert = true | 		} | ||||||
| 		} else if t.ClientCert != "" || t.ClientKey != "" { | 		foundClientCert = true | ||||||
| 			return fmt.Errorf("Both client cert and client key must be provided") | 	case t.ClientCert != "" || t.ClientKey != "": | ||||||
|  | 		return fmt.Errorf("Both client cert and client key must be provided") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if t.CACert != "" || t.CAPath != "" { | ||||||
|  | 		rootConfig := &rootcerts.Config{ | ||||||
|  | 			CAFile: t.CACert, | ||||||
|  | 			CAPath: t.CAPath, | ||||||
|  | 		} | ||||||
|  | 		if err := rootcerts.ConfigureTLS(clientTLSConfig, rootConfig); err != nil { | ||||||
|  | 			return err | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	clientTLSConfig := c.HttpClient.Transport.(*http.Transport).TLSClientConfig | 	if t.Insecure { | ||||||
| 	rootConfig := &rootcerts.Config{ | 		clientTLSConfig.InsecureSkipVerify = true | ||||||
| 		CAFile: t.CACert, |  | ||||||
| 		CAPath: t.CAPath, |  | ||||||
| 	} | 	} | ||||||
| 	if err := rootcerts.ConfigureTLS(clientTLSConfig, rootConfig); err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	clientTLSConfig.InsecureSkipVerify = t.Insecure |  | ||||||
|  |  | ||||||
| 	if foundClientCert { | 	if foundClientCert { | ||||||
| 		clientTLSConfig.Certificates = []tls.Certificate{clientCert} | 		// We use this function to ignore the server's preferential list of | ||||||
|  | 		// CAs, otherwise any CA used for the cert auth backend must be in the | ||||||
|  | 		// server's CA pool | ||||||
|  | 		clientTLSConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { | ||||||
|  | 			return &clientCert, nil | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if t.TLSServerName != "" { | 	if t.TLSServerName != "" { | ||||||
| 		clientTLSConfig.ServerName = t.TLSServerName | 		clientTLSConfig.ServerName = t.TLSServerName | ||||||
| 	} | 	} | ||||||
| @@ -154,9 +192,8 @@ func (c *Config) ConfigureTLS(t *TLSConfig) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // ReadEnvironment reads configuration information from the | // ReadEnvironment reads configuration information from the environment. If | ||||||
| // environment. If there is an error, no configuration value | // there is an error, no configuration value is updated. | ||||||
| // is updated. |  | ||||||
| func (c *Config) ReadEnvironment() error { | func (c *Config) ReadEnvironment() error { | ||||||
| 	var envAddress string | 	var envAddress string | ||||||
| 	var envCACert string | 	var envCACert string | ||||||
| @@ -218,6 +255,10 @@ func (c *Config) ReadEnvironment() error { | |||||||
| 		TLSServerName: envTLSServerName, | 		TLSServerName: envTLSServerName, | ||||||
| 		Insecure:      envInsecure, | 		Insecure:      envInsecure, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	c.modifyLock.Lock() | ||||||
|  | 	defer c.modifyLock.Unlock() | ||||||
|  |  | ||||||
| 	if err := c.ConfigureTLS(t); err != nil { | 	if err := c.ConfigureTLS(t); err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @@ -237,10 +278,9 @@ func (c *Config) ReadEnvironment() error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // Client is the client to the Vault API. Create a client with NewClient. Note: | // Client is the client to the Vault API. Create a client with NewClient. | ||||||
| // it is not safe to modify client configuration from multiple goroutines at |  | ||||||
| // once. Set configuration first, then run requests. |  | ||||||
| type Client struct { | type Client struct { | ||||||
|  | 	modifyLock         sync.RWMutex | ||||||
| 	addr               *url.URL | 	addr               *url.URL | ||||||
| 	config             *Config | 	config             *Config | ||||||
| 	token              string | 	token              string | ||||||
| @@ -250,24 +290,29 @@ type Client struct { | |||||||
| 	policyOverride     bool | 	policyOverride     bool | ||||||
| } | } | ||||||
|  |  | ||||||
| // SetMFACreds sets the MFA credentials supplied either via the environment |  | ||||||
| // variable or via the command line. |  | ||||||
| func (c *Client) SetMFACreds(creds []string) { |  | ||||||
| 	c.mfaCreds = creds |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // NewClient returns a new client for the given configuration. | // NewClient returns a new client for the given configuration. | ||||||
| // | // | ||||||
|  | // If the configuration is nil, Vault will use configuration from | ||||||
|  | // DefaultConfig(), which is the recommended starting configuration. | ||||||
|  | // | ||||||
| // If the environment variable `VAULT_TOKEN` is present, the token will be | // If the environment variable `VAULT_TOKEN` is present, the token will be | ||||||
| // automatically added to the client. Otherwise, you must manually call | // automatically added to the client. Otherwise, you must manually call | ||||||
| // `SetToken()`. | // `SetToken()`. | ||||||
| func NewClient(c *Config) (*Client, error) { | func NewClient(c *Config) (*Client, error) { | ||||||
| 	if c == nil { | 	def := DefaultConfig() | ||||||
| 		c = DefaultConfig() | 	if def == nil { | ||||||
| 		if err := c.ReadEnvironment(); err != nil { | 		return nil, fmt.Errorf("could not create/read default configuration") | ||||||
| 			return nil, fmt.Errorf("error reading environment: %v", err) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
|  | 	if def.Error != nil { | ||||||
|  | 		return nil, errwrap.Wrapf("error encountered setting up default configuration: {{err}}", def.Error) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if c == nil { | ||||||
|  | 		c = def | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	c.modifyLock.Lock() | ||||||
|  | 	defer c.modifyLock.Unlock() | ||||||
|  |  | ||||||
| 	u, err := url.Parse(c.Address) | 	u, err := url.Parse(c.Address) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -275,41 +320,19 @@ func NewClient(c *Config) (*Client, error) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if c.HttpClient == nil { | 	if c.HttpClient == nil { | ||||||
| 		c.HttpClient = DefaultConfig().HttpClient | 		c.HttpClient = def.HttpClient | ||||||
| 	} | 	} | ||||||
| 	if c.HttpClient.Transport == nil { | 	if c.HttpClient.Transport == nil { | ||||||
| 		c.HttpClient.Transport = cleanhttp.DefaultTransport() | 		c.HttpClient.Transport = def.HttpClient.Transport | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if tp, ok := c.HttpClient.Transport.(*http.Transport); ok { |  | ||||||
| 		if err := http2.ConfigureTransport(tp); err != nil { |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	redirFunc := func() { |  | ||||||
| 		// Ensure redirects are not automatically followed |  | ||||||
| 		// Note that this is sane for the API client as it has its own |  | ||||||
| 		// redirect handling logic (and thus also for command/meta), |  | ||||||
| 		// but in e.g. http_test actual redirect handling is necessary |  | ||||||
| 		c.HttpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { |  | ||||||
| 			// Returning this value causes the Go net library to not close the |  | ||||||
| 			// response body and to nil out the error. Otherwise pester tries |  | ||||||
| 			// three times on every redirect because it sees an error from this |  | ||||||
| 			// function (to prevent redirects) passing through to it. |  | ||||||
| 			return http.ErrUseLastResponse |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	c.redirectSetup.Do(redirFunc) |  | ||||||
|  |  | ||||||
| 	client := &Client{ | 	client := &Client{ | ||||||
| 		addr:   u, | 		addr:   u, | ||||||
| 		config: c, | 		config: c, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if token := os.Getenv(EnvVaultToken); token != "" { | 	if token := os.Getenv(EnvVaultToken); token != "" { | ||||||
| 		client.SetToken(token) | 		client.token = token | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return client, nil | 	return client, nil | ||||||
| @@ -319,6 +342,9 @@ func NewClient(c *Config) (*Client, error) { | |||||||
| // "<Scheme>://<Host>:<Port>". Setting this on a client will override the | // "<Scheme>://<Host>:<Port>". Setting this on a client will override the | ||||||
| // value of VAULT_ADDR environment variable. | // value of VAULT_ADDR environment variable. | ||||||
| func (c *Client) SetAddress(addr string) error { | func (c *Client) SetAddress(addr string) error { | ||||||
|  | 	c.modifyLock.Lock() | ||||||
|  | 	defer c.modifyLock.Unlock() | ||||||
|  |  | ||||||
| 	var err error | 	var err error | ||||||
| 	if c.addr, err = url.Parse(addr); err != nil { | 	if c.addr, err = url.Parse(addr); err != nil { | ||||||
| 		return fmt.Errorf("failed to set address: %v", err) | 		return fmt.Errorf("failed to set address: %v", err) | ||||||
| @@ -329,56 +355,112 @@ func (c *Client) SetAddress(addr string) error { | |||||||
|  |  | ||||||
| // Address returns the Vault URL the client is configured to connect to | // Address returns the Vault URL the client is configured to connect to | ||||||
| func (c *Client) Address() string { | func (c *Client) Address() string { | ||||||
|  | 	c.modifyLock.RLock() | ||||||
|  | 	defer c.modifyLock.RUnlock() | ||||||
|  |  | ||||||
| 	return c.addr.String() | 	return c.addr.String() | ||||||
| } | } | ||||||
|  |  | ||||||
| // SetMaxRetries sets the number of retries that will be used in the case of certain errors | // SetMaxRetries sets the number of retries that will be used in the case of certain errors | ||||||
| func (c *Client) SetMaxRetries(retries int) { | func (c *Client) SetMaxRetries(retries int) { | ||||||
|  | 	c.modifyLock.RLock() | ||||||
|  | 	c.config.modifyLock.Lock() | ||||||
|  | 	defer c.config.modifyLock.Unlock() | ||||||
|  | 	c.modifyLock.RUnlock() | ||||||
|  |  | ||||||
| 	c.config.MaxRetries = retries | 	c.config.MaxRetries = retries | ||||||
| } | } | ||||||
|  |  | ||||||
| // SetClientTimeout sets the client request timeout | // SetClientTimeout sets the client request timeout | ||||||
| func (c *Client) SetClientTimeout(timeout time.Duration) { | func (c *Client) SetClientTimeout(timeout time.Duration) { | ||||||
|  | 	c.modifyLock.RLock() | ||||||
|  | 	c.config.modifyLock.Lock() | ||||||
|  | 	defer c.config.modifyLock.Unlock() | ||||||
|  | 	c.modifyLock.RUnlock() | ||||||
|  |  | ||||||
| 	c.config.Timeout = timeout | 	c.config.Timeout = timeout | ||||||
| } | } | ||||||
|  |  | ||||||
| // SetWrappingLookupFunc sets a lookup function that returns desired wrap TTLs | // SetWrappingLookupFunc sets a lookup function that returns desired wrap TTLs | ||||||
| // for a given operation and path | // for a given operation and path | ||||||
| func (c *Client) SetWrappingLookupFunc(lookupFunc WrappingLookupFunc) { | func (c *Client) SetWrappingLookupFunc(lookupFunc WrappingLookupFunc) { | ||||||
|  | 	c.modifyLock.Lock() | ||||||
|  | 	defer c.modifyLock.Unlock() | ||||||
|  |  | ||||||
| 	c.wrappingLookupFunc = lookupFunc | 	c.wrappingLookupFunc = lookupFunc | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // SetMFACreds sets the MFA credentials supplied either via the environment | ||||||
|  | // variable or via the command line. | ||||||
|  | func (c *Client) SetMFACreds(creds []string) { | ||||||
|  | 	c.modifyLock.Lock() | ||||||
|  | 	defer c.modifyLock.Unlock() | ||||||
|  |  | ||||||
|  | 	c.mfaCreds = creds | ||||||
|  | } | ||||||
|  |  | ||||||
| // Token returns the access token being used by this client. It will | // Token returns the access token being used by this client. It will | ||||||
| // return the empty string if there is no token set. | // return the empty string if there is no token set. | ||||||
| func (c *Client) Token() string { | func (c *Client) Token() string { | ||||||
|  | 	c.modifyLock.RLock() | ||||||
|  | 	defer c.modifyLock.RUnlock() | ||||||
|  |  | ||||||
| 	return c.token | 	return c.token | ||||||
| } | } | ||||||
|  |  | ||||||
| // SetToken sets the token directly. This won't perform any auth | // SetToken sets the token directly. This won't perform any auth | ||||||
| // verification, it simply sets the token properly for future requests. | // verification, it simply sets the token properly for future requests. | ||||||
| func (c *Client) SetToken(v string) { | func (c *Client) SetToken(v string) { | ||||||
|  | 	c.modifyLock.Lock() | ||||||
|  | 	defer c.modifyLock.Unlock() | ||||||
|  |  | ||||||
| 	c.token = v | 	c.token = v | ||||||
| } | } | ||||||
|  |  | ||||||
| // ClearToken deletes the token if it is set or does nothing otherwise. | // ClearToken deletes the token if it is set or does nothing otherwise. | ||||||
| func (c *Client) ClearToken() { | func (c *Client) ClearToken() { | ||||||
|  | 	c.modifyLock.Lock() | ||||||
|  | 	defer c.modifyLock.Unlock() | ||||||
|  |  | ||||||
| 	c.token = "" | 	c.token = "" | ||||||
| } | } | ||||||
|  |  | ||||||
| // SetHeaders sets the headers to be used for future requests. | // SetHeaders sets the headers to be used for future requests. | ||||||
| func (c *Client) SetHeaders(headers http.Header) { | func (c *Client) SetHeaders(headers http.Header) { | ||||||
|  | 	c.modifyLock.Lock() | ||||||
|  | 	defer c.modifyLock.Unlock() | ||||||
|  |  | ||||||
| 	c.headers = headers | 	c.headers = headers | ||||||
| } | } | ||||||
|  |  | ||||||
| // Clone creates a copy of this client. | // Clone creates a new client with the same configuration. Note that the same | ||||||
|  | // underlying http.Client is used; modifying the client from more than one | ||||||
|  | // goroutine at once may not be safe, so modify the client as needed and then | ||||||
|  | // clone. | ||||||
| func (c *Client) Clone() (*Client, error) { | func (c *Client) Clone() (*Client, error) { | ||||||
| 	return NewClient(c.config) | 	c.modifyLock.RLock() | ||||||
|  | 	c.config.modifyLock.RLock() | ||||||
|  | 	config := c.config | ||||||
|  | 	c.modifyLock.RUnlock() | ||||||
|  |  | ||||||
|  | 	newConfig := &Config{ | ||||||
|  | 		Address:    config.Address, | ||||||
|  | 		HttpClient: config.HttpClient, | ||||||
|  | 		MaxRetries: config.MaxRetries, | ||||||
|  | 		Timeout:    config.Timeout, | ||||||
|  | 	} | ||||||
|  | 	config.modifyLock.RUnlock() | ||||||
|  |  | ||||||
|  | 	return NewClient(newConfig) | ||||||
| } | } | ||||||
|  |  | ||||||
| // SetPolicyOverride sets whether requests should be sent with the policy | // SetPolicyOverride sets whether requests should be sent with the policy | ||||||
| // override flag to request overriding soft-mandatory Sentinel policies (both | // override flag to request overriding soft-mandatory Sentinel policies (both | ||||||
| // RGPs and EGPs) | // RGPs and EGPs) | ||||||
| func (c *Client) SetPolicyOverride(override bool) { | func (c *Client) SetPolicyOverride(override bool) { | ||||||
|  | 	c.modifyLock.Lock() | ||||||
|  | 	defer c.modifyLock.Unlock() | ||||||
|  |  | ||||||
| 	c.policyOverride = override | 	c.policyOverride = override | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -386,6 +468,9 @@ func (c *Client) SetPolicyOverride(override bool) { | |||||||
| // configured for this client. This is an advanced method and generally | // configured for this client. This is an advanced method and generally | ||||||
| // doesn't need to be called externally. | // doesn't need to be called externally. | ||||||
| func (c *Client) NewRequest(method, requestPath string) *Request { | func (c *Client) NewRequest(method, requestPath string) *Request { | ||||||
|  | 	c.modifyLock.RLock() | ||||||
|  | 	defer c.modifyLock.RUnlock() | ||||||
|  |  | ||||||
| 	// if SRV records exist (see https://tools.ietf.org/html/draft-andrews-http-srv-02), lookup the SRV | 	// if SRV records exist (see https://tools.ietf.org/html/draft-andrews-http-srv-02), lookup the SRV | ||||||
| 	// record and take the highest match; this is not designed for high-availability, just discovery | 	// record and take the highest match; this is not designed for high-availability, just discovery | ||||||
| 	var host string = c.addr.Host | 	var host string = c.addr.Host | ||||||
| @@ -442,6 +527,11 @@ func (c *Client) NewRequest(method, requestPath string) *Request { | |||||||
| // a Vault server not configured with this client. This is an advanced operation | // a Vault server not configured with this client. This is an advanced operation | ||||||
| // that generally won't need to be called externally. | // that generally won't need to be called externally. | ||||||
| func (c *Client) RawRequest(r *Request) (*Response, error) { | func (c *Client) RawRequest(r *Request) (*Response, error) { | ||||||
|  | 	c.modifyLock.RLock() | ||||||
|  | 	c.config.modifyLock.RLock() | ||||||
|  | 	defer c.config.modifyLock.RUnlock() | ||||||
|  | 	c.modifyLock.RUnlock() | ||||||
|  |  | ||||||
| 	redirectCount := 0 | 	redirectCount := 0 | ||||||
| START: | START: | ||||||
| 	req, err := r.ToHTTP() | 	req, err := r.ToHTTP() | ||||||
|   | |||||||
| @@ -163,8 +163,8 @@ func TestClientEnvSettings(t *testing.T) { | |||||||
| 	if len(tlsConfig.RootCAs.Subjects()) == 0 { | 	if len(tlsConfig.RootCAs.Subjects()) == 0 { | ||||||
| 		t.Fatalf("bad: expected a cert pool with at least one subject") | 		t.Fatalf("bad: expected a cert pool with at least one subject") | ||||||
| 	} | 	} | ||||||
| 	if len(tlsConfig.Certificates) != 1 { | 	if tlsConfig.GetClientCertificate == nil { | ||||||
| 		t.Fatalf("bad: expected client tls config to have a client certificate") | 		t.Fatalf("bad: expected client tls config to have a certificate getter") | ||||||
| 	} | 	} | ||||||
| 	if tlsConfig.InsecureSkipVerify != true { | 	if tlsConfig.InsecureSkipVerify != true { | ||||||
| 		t.Fatalf("bad: %v", tlsConfig.InsecureSkipVerify) | 		t.Fatalf("bad: %v", tlsConfig.InsecureSkipVerify) | ||||||
| @@ -213,3 +213,16 @@ func TestClientNonTransportRoundTripper(t *testing.T) { | |||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestClone(t *testing.T) { | ||||||
|  | 	client1, err1 := NewClient(nil) | ||||||
|  | 	if err1 != nil { | ||||||
|  | 		t.Fatalf("NewClient failed: %v", err1) | ||||||
|  | 	} | ||||||
|  | 	client2, err2 := client1.Clone() | ||||||
|  | 	if err2 != nil { | ||||||
|  | 		t.Fatalf("Clone failed: %v", err2) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_ = client2 | ||||||
|  | } | ||||||
|   | |||||||
| @@ -50,12 +50,13 @@ var ( | |||||||
| type Renewer struct { | type Renewer struct { | ||||||
| 	l sync.Mutex | 	l sync.Mutex | ||||||
|  |  | ||||||
| 	client  *Client | 	client    *Client | ||||||
| 	secret  *Secret | 	secret    *Secret | ||||||
| 	grace   time.Duration | 	grace     time.Duration | ||||||
| 	random  *rand.Rand | 	random    *rand.Rand | ||||||
| 	doneCh  chan error | 	increment int | ||||||
| 	renewCh chan *RenewOutput | 	doneCh    chan error | ||||||
|  | 	renewCh   chan *RenewOutput | ||||||
|  |  | ||||||
| 	stopped bool | 	stopped bool | ||||||
| 	stopCh  chan struct{} | 	stopCh  chan struct{} | ||||||
| @@ -79,6 +80,11 @@ type RenewerInput struct { | |||||||
| 	// RenewBuffer is the size of the buffered channel where renew messages are | 	// RenewBuffer is the size of the buffered channel where renew messages are | ||||||
| 	// dispatched. | 	// dispatched. | ||||||
| 	RenewBuffer int | 	RenewBuffer int | ||||||
|  |  | ||||||
|  | 	// The new TTL, in seconds, that should be set on the lease. The TTL set | ||||||
|  | 	// here may or may not be honored by the vault server, based on Vault | ||||||
|  | 	// configuration or any associated max TTL values. | ||||||
|  | 	Increment int | ||||||
| } | } | ||||||
|  |  | ||||||
| // RenewOutput is the metadata returned to the client (if it's listening) to | // RenewOutput is the metadata returned to the client (if it's listening) to | ||||||
| @@ -120,12 +126,13 @@ func (c *Client) NewRenewer(i *RenewerInput) (*Renewer, error) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return &Renewer{ | 	return &Renewer{ | ||||||
| 		client:  c, | 		client:    c, | ||||||
| 		secret:  secret, | 		secret:    secret, | ||||||
| 		grace:   grace, | 		grace:     grace, | ||||||
| 		random:  random, | 		increment: i.Increment, | ||||||
| 		doneCh:  make(chan error, 1), | 		random:    random, | ||||||
| 		renewCh: make(chan *RenewOutput, renewBuffer), | 		doneCh:    make(chan error, 1), | ||||||
|  | 		renewCh:   make(chan *RenewOutput, renewBuffer), | ||||||
|  |  | ||||||
| 		stopped: false, | 		stopped: false, | ||||||
| 		stopCh:  make(chan struct{}), | 		stopCh:  make(chan struct{}), | ||||||
| @@ -245,7 +252,7 @@ func (r *Renewer) renewLease() error { | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Renew the lease. | 		// Renew the lease. | ||||||
| 		renewal, err := client.Sys().Renew(leaseID, 0) | 		renewal, err := client.Sys().Renew(leaseID, r.increment) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -224,6 +224,7 @@ func (s *Secret) TokenTTL() (time.Duration, error) { | |||||||
| // available in WrappedAccessor. | // available in WrappedAccessor. | ||||||
| type SecretWrapInfo struct { | type SecretWrapInfo struct { | ||||||
| 	Token           string    `json:"token"` | 	Token           string    `json:"token"` | ||||||
|  | 	Accessor        string    `json:"accessor"` | ||||||
| 	TTL             int       `json:"ttl"` | 	TTL             int       `json:"ttl"` | ||||||
| 	CreationTime    time.Time `json:"creation_time"` | 	CreationTime    time.Time `json:"creation_time"` | ||||||
| 	CreationPath    string    `json:"creation_path"` | 	CreationPath    string    `json:"creation_path"` | ||||||
|   | |||||||
| @@ -26,6 +26,7 @@ func TestParseSecret(t *testing.T) { | |||||||
| 	], | 	], | ||||||
| 	"wrap_info": { | 	"wrap_info": { | ||||||
| 		"token": "token", | 		"token": "token", | ||||||
|  | 		"accessor": "accessor", | ||||||
| 		"ttl": 60, | 		"ttl": 60, | ||||||
| 		"creation_time": "2016-06-07T15:52:10-04:00", | 		"creation_time": "2016-06-07T15:52:10-04:00", | ||||||
| 		"wrapped_accessor": "abcd1234" | 		"wrapped_accessor": "abcd1234" | ||||||
| @@ -51,6 +52,7 @@ func TestParseSecret(t *testing.T) { | |||||||
| 		}, | 		}, | ||||||
| 		WrapInfo: &api.SecretWrapInfo{ | 		WrapInfo: &api.SecretWrapInfo{ | ||||||
| 			Token:           "token", | 			Token:           "token", | ||||||
|  | 			Accessor:        "accessor", | ||||||
| 			TTL:             60, | 			TTL:             60, | ||||||
| 			CreationTime:    rawTime, | 			CreationTime:    rawTime, | ||||||
| 			WrappedAccessor: "abcd1234", | 			WrappedAccessor: "abcd1234", | ||||||
|   | |||||||
| @@ -87,6 +87,7 @@ type EnableAuthOptions struct { | |||||||
| 	Config      AuthConfigInput `json:"config" structs:"config"` | 	Config      AuthConfigInput `json:"config" structs:"config"` | ||||||
| 	Local       bool            `json:"local" structs:"local"` | 	Local       bool            `json:"local" structs:"local"` | ||||||
| 	PluginName  string          `json:"plugin_name,omitempty" structs:"plugin_name,omitempty"` | 	PluginName  string          `json:"plugin_name,omitempty" structs:"plugin_name,omitempty"` | ||||||
|  | 	SealWrap    bool            `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"` | ||||||
| } | } | ||||||
|  |  | ||||||
| type AuthConfigInput struct { | type AuthConfigInput struct { | ||||||
| @@ -99,6 +100,7 @@ type AuthMount struct { | |||||||
| 	Accessor    string           `json:"accessor" structs:"accessor" mapstructure:"accessor"` | 	Accessor    string           `json:"accessor" structs:"accessor" mapstructure:"accessor"` | ||||||
| 	Config      AuthConfigOutput `json:"config" structs:"config" mapstructure:"config"` | 	Config      AuthConfigOutput `json:"config" structs:"config" mapstructure:"config"` | ||||||
| 	Local       bool             `json:"local" structs:"local" mapstructure:"local"` | 	Local       bool             `json:"local" structs:"local" mapstructure:"local"` | ||||||
|  | 	SealWrap    bool             `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"` | ||||||
| } | } | ||||||
|  |  | ||||||
| type AuthConfigOutput struct { | type AuthConfigOutput struct { | ||||||
|   | |||||||
| @@ -1,7 +1,15 @@ | |||||||
| package api | package api | ||||||
|  |  | ||||||
| func (c *Sys) GenerateRootStatus() (*GenerateRootStatusResponse, error) { | func (c *Sys) GenerateRootStatus() (*GenerateRootStatusResponse, error) { | ||||||
| 	r := c.c.NewRequest("GET", "/v1/sys/generate-root/attempt") | 	return c.generateRootStatusCommon("/v1/sys/generate-root/attempt") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Sys) GenerateDROperationTokenStatus() (*GenerateRootStatusResponse, error) { | ||||||
|  | 	return c.generateRootStatusCommon("/v1/sys/replication/dr/secondary/generate-operation-token/attempt") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Sys) generateRootStatusCommon(path string) (*GenerateRootStatusResponse, error) { | ||||||
|  | 	r := c.c.NewRequest("GET", path) | ||||||
| 	resp, err := c.c.RawRequest(r) | 	resp, err := c.c.RawRequest(r) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -14,12 +22,20 @@ func (c *Sys) GenerateRootStatus() (*GenerateRootStatusResponse, error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (c *Sys) GenerateRootInit(otp, pgpKey string) (*GenerateRootStatusResponse, error) { | func (c *Sys) GenerateRootInit(otp, pgpKey string) (*GenerateRootStatusResponse, error) { | ||||||
|  | 	return c.generateRootInitCommon("/v1/sys/generate-root/attempt", otp, pgpKey) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Sys) GenerateDROperationTokenInit(otp, pgpKey string) (*GenerateRootStatusResponse, error) { | ||||||
|  | 	return c.generateRootInitCommon("/v1/sys/replication/dr/secondary/generate-operation-token/attempt", otp, pgpKey) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Sys) generateRootInitCommon(path, otp, pgpKey string) (*GenerateRootStatusResponse, error) { | ||||||
| 	body := map[string]interface{}{ | 	body := map[string]interface{}{ | ||||||
| 		"otp":     otp, | 		"otp":     otp, | ||||||
| 		"pgp_key": pgpKey, | 		"pgp_key": pgpKey, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	r := c.c.NewRequest("PUT", "/v1/sys/generate-root/attempt") | 	r := c.c.NewRequest("PUT", path) | ||||||
| 	if err := r.SetJSONBody(body); err != nil { | 	if err := r.SetJSONBody(body); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @@ -36,7 +52,15 @@ func (c *Sys) GenerateRootInit(otp, pgpKey string) (*GenerateRootStatusResponse, | |||||||
| } | } | ||||||
|  |  | ||||||
| func (c *Sys) GenerateRootCancel() error { | func (c *Sys) GenerateRootCancel() error { | ||||||
| 	r := c.c.NewRequest("DELETE", "/v1/sys/generate-root/attempt") | 	return c.generateRootCancelCommon("/v1/sys/generate-root/attempt") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Sys) GenerateDROperationTokenCancel() error { | ||||||
|  | 	return c.generateRootCancelCommon("/v1/sys/replication/dr/secondary/generate-operation-token/attempt") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Sys) generateRootCancelCommon(path string) error { | ||||||
|  | 	r := c.c.NewRequest("DELETE", path) | ||||||
| 	resp, err := c.c.RawRequest(r) | 	resp, err := c.c.RawRequest(r) | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		defer resp.Body.Close() | 		defer resp.Body.Close() | ||||||
| @@ -45,12 +69,20 @@ func (c *Sys) GenerateRootCancel() error { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (c *Sys) GenerateRootUpdate(shard, nonce string) (*GenerateRootStatusResponse, error) { | func (c *Sys) GenerateRootUpdate(shard, nonce string) (*GenerateRootStatusResponse, error) { | ||||||
|  | 	return c.generateRootUpdateCommon("/v1/sys/generate-root/update", shard, nonce) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Sys) GenerateDROperationTokenUpdate(shard, nonce string) (*GenerateRootStatusResponse, error) { | ||||||
|  | 	return c.generateRootUpdateCommon("/v1/sys/replication/dr/secondary/generate-operation-token/update", shard, nonce) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Sys) generateRootUpdateCommon(path, shard, nonce string) (*GenerateRootStatusResponse, error) { | ||||||
| 	body := map[string]interface{}{ | 	body := map[string]interface{}{ | ||||||
| 		"key":   shard, | 		"key":   shard, | ||||||
| 		"nonce": nonce, | 		"nonce": nonce, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	r := c.c.NewRequest("PUT", "/v1/sys/generate-root/update") | 	r := c.c.NewRequest("PUT", path) | ||||||
| 	if err := r.SetJSONBody(body); err != nil { | 	if err := r.SetJSONBody(body); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @@ -72,6 +104,7 @@ type GenerateRootStatusResponse struct { | |||||||
| 	Progress         int | 	Progress         int | ||||||
| 	Required         int | 	Required         int | ||||||
| 	Complete         bool | 	Complete         bool | ||||||
|  | 	EncodedToken     string `json:"encoded_token"` | ||||||
| 	EncodedRootToken string `json:"encoded_root_token"` | 	EncodedRootToken string `json:"encoded_root_token"` | ||||||
| 	PGPFingerprint   string `json:"pgp_fingerprint"` | 	PGPFingerprint   string `json:"pgp_fingerprint"` | ||||||
| } | } | ||||||
|   | |||||||
| @@ -125,6 +125,7 @@ type MountInput struct { | |||||||
| 	Config      MountConfigInput `json:"config" structs:"config"` | 	Config      MountConfigInput `json:"config" structs:"config"` | ||||||
| 	Local       bool             `json:"local" structs:"local"` | 	Local       bool             `json:"local" structs:"local"` | ||||||
| 	PluginName  string           `json:"plugin_name,omitempty" structs:"plugin_name"` | 	PluginName  string           `json:"plugin_name,omitempty" structs:"plugin_name"` | ||||||
|  | 	SealWrap    bool             `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"` | ||||||
| } | } | ||||||
|  |  | ||||||
| type MountConfigInput struct { | type MountConfigInput struct { | ||||||
| @@ -132,7 +133,6 @@ type MountConfigInput struct { | |||||||
| 	MaxLeaseTTL     string `json:"max_lease_ttl" structs:"max_lease_ttl" mapstructure:"max_lease_ttl"` | 	MaxLeaseTTL     string `json:"max_lease_ttl" structs:"max_lease_ttl" mapstructure:"max_lease_ttl"` | ||||||
| 	ForceNoCache    bool   `json:"force_no_cache" structs:"force_no_cache" mapstructure:"force_no_cache"` | 	ForceNoCache    bool   `json:"force_no_cache" structs:"force_no_cache" mapstructure:"force_no_cache"` | ||||||
| 	PluginName      string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty" mapstructure:"plugin_name"` | 	PluginName      string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty" mapstructure:"plugin_name"` | ||||||
| 	SealWrap        bool   `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"` |  | ||||||
| } | } | ||||||
|  |  | ||||||
| type MountOutput struct { | type MountOutput struct { | ||||||
| @@ -141,6 +141,7 @@ type MountOutput struct { | |||||||
| 	Accessor    string            `json:"accessor" structs:"accessor"` | 	Accessor    string            `json:"accessor" structs:"accessor"` | ||||||
| 	Config      MountConfigOutput `json:"config" structs:"config"` | 	Config      MountConfigOutput `json:"config" structs:"config"` | ||||||
| 	Local       bool              `json:"local" structs:"local"` | 	Local       bool              `json:"local" structs:"local"` | ||||||
|  | 	SealWrap    bool              `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"` | ||||||
| } | } | ||||||
|  |  | ||||||
| type MountConfigOutput struct { | type MountConfigOutput struct { | ||||||
| @@ -148,5 +149,4 @@ type MountConfigOutput struct { | |||||||
| 	MaxLeaseTTL     int    `json:"max_lease_ttl" structs:"max_lease_ttl" mapstructure:"max_lease_ttl"` | 	MaxLeaseTTL     int    `json:"max_lease_ttl" structs:"max_lease_ttl" mapstructure:"max_lease_ttl"` | ||||||
| 	ForceNoCache    bool   `json:"force_no_cache" structs:"force_no_cache" mapstructure:"force_no_cache"` | 	ForceNoCache    bool   `json:"force_no_cache" structs:"force_no_cache" mapstructure:"force_no_cache"` | ||||||
| 	PluginName      string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty" mapstructure:"plugin_name"` | 	PluginName      string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty" mapstructure:"plugin_name"` | ||||||
| 	SealWrap        bool   `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"` |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -49,12 +49,14 @@ func sealStatusRequest(c *Sys, r *Request) (*SealStatusResponse, error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| type SealStatusResponse struct { | type SealStatusResponse struct { | ||||||
| 	Sealed      bool   `json:"sealed"` | 	Type         string `json:"type"` | ||||||
| 	T           int    `json:"t"` | 	Sealed       bool   `json:"sealed"` | ||||||
| 	N           int    `json:"n"` | 	T            int    `json:"t"` | ||||||
| 	Progress    int    `json:"progress"` | 	N            int    `json:"n"` | ||||||
| 	Nonce       string `json:"nonce"` | 	Progress     int    `json:"progress"` | ||||||
| 	Version     string `json:"version"` | 	Nonce        string `json:"nonce"` | ||||||
| 	ClusterName string `json:"cluster_name,omitempty"` | 	Version      string `json:"version"` | ||||||
| 	ClusterID   string `json:"cluster_id,omitempty"` | 	ClusterName  string `json:"cluster_name,omitempty"` | ||||||
|  | 	ClusterID    string `json:"cluster_id,omitempty"` | ||||||
|  | 	RecoverySeal bool   `json:"recovery_seal"` | ||||||
| } | } | ||||||
|   | |||||||
| @@ -146,7 +146,7 @@ func (f *AuditFormatter) FormatRequest( | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if !config.OmitTime { | 	if !config.OmitTime { | ||||||
| 		reqEntry.Time = time.Now().UTC().Format(time.RFC3339) | 		reqEntry.Time = time.Now().UTC().Format(time.RFC3339Nano) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return f.AuditFormatWriter.WriteRequest(w, reqEntry) | 	return f.AuditFormatWriter.WriteRequest(w, reqEntry) | ||||||
| @@ -242,12 +242,13 @@ func (f *AuditFormatter) FormatResponse( | |||||||
|  |  | ||||||
| 		// Cache and restore accessor in the response | 		// Cache and restore accessor in the response | ||||||
| 		if resp != nil { | 		if resp != nil { | ||||||
| 			var accessor, wrappedAccessor string | 			var accessor, wrappedAccessor, wrappingAccessor string | ||||||
| 			if !config.HMACAccessor && resp != nil && resp.Auth != nil && resp.Auth.Accessor != "" { | 			if !config.HMACAccessor && resp != nil && resp.Auth != nil && resp.Auth.Accessor != "" { | ||||||
| 				accessor = resp.Auth.Accessor | 				accessor = resp.Auth.Accessor | ||||||
| 			} | 			} | ||||||
| 			if !config.HMACAccessor && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.WrappedAccessor != "" { | 			if !config.HMACAccessor && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.WrappedAccessor != "" { | ||||||
| 				wrappedAccessor = resp.WrapInfo.WrappedAccessor | 				wrappedAccessor = resp.WrapInfo.WrappedAccessor | ||||||
|  | 				wrappingAccessor = resp.WrapInfo.Accessor | ||||||
| 			} | 			} | ||||||
| 			if err := Hash(salt, resp); err != nil { | 			if err := Hash(salt, resp); err != nil { | ||||||
| 				return err | 				return err | ||||||
| @@ -258,6 +259,9 @@ func (f *AuditFormatter) FormatResponse( | |||||||
| 			if wrappedAccessor != "" { | 			if wrappedAccessor != "" { | ||||||
| 				resp.WrapInfo.WrappedAccessor = wrappedAccessor | 				resp.WrapInfo.WrappedAccessor = wrappedAccessor | ||||||
| 			} | 			} | ||||||
|  | 			if wrappingAccessor != "" { | ||||||
|  | 				resp.WrapInfo.Accessor = wrappingAccessor | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -301,6 +305,7 @@ func (f *AuditFormatter) FormatResponse( | |||||||
| 		respWrapInfo = &AuditResponseWrapInfo{ | 		respWrapInfo = &AuditResponseWrapInfo{ | ||||||
| 			TTL:             int(resp.WrapInfo.TTL / time.Second), | 			TTL:             int(resp.WrapInfo.TTL / time.Second), | ||||||
| 			Token:           token, | 			Token:           token, | ||||||
|  | 			Accessor:        resp.WrapInfo.Accessor, | ||||||
| 			CreationTime:    resp.WrapInfo.CreationTime.Format(time.RFC3339Nano), | 			CreationTime:    resp.WrapInfo.CreationTime.Format(time.RFC3339Nano), | ||||||
| 			CreationPath:    resp.WrapInfo.CreationPath, | 			CreationPath:    resp.WrapInfo.CreationPath, | ||||||
| 			WrappedAccessor: resp.WrapInfo.WrappedAccessor, | 			WrappedAccessor: resp.WrapInfo.WrappedAccessor, | ||||||
| @@ -347,7 +352,7 @@ func (f *AuditFormatter) FormatResponse( | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if !config.OmitTime { | 	if !config.OmitTime { | ||||||
| 		respEntry.Time = time.Now().UTC().Format(time.RFC3339) | 		respEntry.Time = time.Now().UTC().Format(time.RFC3339Nano) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return f.AuditFormatWriter.WriteResponse(w, respEntry) | 	return f.AuditFormatWriter.WriteResponse(w, respEntry) | ||||||
| @@ -412,6 +417,7 @@ type AuditSecret struct { | |||||||
| type AuditResponseWrapInfo struct { | type AuditResponseWrapInfo struct { | ||||||
| 	TTL             int    `json:"ttl"` | 	TTL             int    `json:"ttl"` | ||||||
| 	Token           string `json:"token"` | 	Token           string `json:"token"` | ||||||
|  | 	Accessor        string `json:"accessor"` | ||||||
| 	CreationTime    string `json:"creation_time"` | 	CreationTime    string `json:"creation_time"` | ||||||
| 	CreationPath    string `json:"creation_path"` | 	CreationPath    string `json:"creation_path"` | ||||||
| 	WrappedAccessor string `json:"wrapped_accessor,omitempty"` | 	WrappedAccessor string `json:"wrapped_accessor,omitempty"` | ||||||
|   | |||||||
| @@ -93,6 +93,7 @@ func Hash(salter *salt.Salt, raw interface{}) error { | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		s.Token = fn(s.Token) | 		s.Token = fn(s.Token) | ||||||
|  | 		s.Accessor = fn(s.Accessor) | ||||||
|  |  | ||||||
| 		if s.WrappedAccessor != "" { | 		if s.WrappedAccessor != "" { | ||||||
| 			s.WrappedAccessor = fn(s.WrappedAccessor) | 			s.WrappedAccessor = fn(s.WrappedAccessor) | ||||||
|   | |||||||
| @@ -148,6 +148,7 @@ func TestHash(t *testing.T) { | |||||||
| 				WrapInfo: &wrapping.ResponseWrapInfo{ | 				WrapInfo: &wrapping.ResponseWrapInfo{ | ||||||
| 					TTL:             60, | 					TTL:             60, | ||||||
| 					Token:           "bar", | 					Token:           "bar", | ||||||
|  | 					Accessor:        "flimflam", | ||||||
| 					CreationTime:    now, | 					CreationTime:    now, | ||||||
| 					WrappedAccessor: "bar", | 					WrappedAccessor: "bar", | ||||||
| 				}, | 				}, | ||||||
| @@ -160,6 +161,7 @@ func TestHash(t *testing.T) { | |||||||
| 				WrapInfo: &wrapping.ResponseWrapInfo{ | 				WrapInfo: &wrapping.ResponseWrapInfo{ | ||||||
| 					TTL:             60, | 					TTL:             60, | ||||||
| 					Token:           "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317", | 					Token:           "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317", | ||||||
|  | 					Accessor:        "hmac-sha256:7c9c6fe666d0af73b3ebcfbfabe6885015558213208e6635ba104047b22f6390", | ||||||
| 					CreationTime:    now, | 					CreationTime:    now, | ||||||
| 					WrappedAccessor: "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317", | 					WrappedAccessor: "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317", | ||||||
| 				}, | 				}, | ||||||
| @@ -206,6 +208,11 @@ func TestHash(t *testing.T) { | |||||||
| 		if err := Hash(localSalt, tc.Input); err != nil { | 		if err := Hash(localSalt, tc.Input); err != nil { | ||||||
| 			t.Fatalf("err: %s\n\n%s", err, input) | 			t.Fatalf("err: %s\n\n%s", err, input) | ||||||
| 		} | 		} | ||||||
|  | 		if _, ok := tc.Input.(*logical.Response); ok { | ||||||
|  | 			if !reflect.DeepEqual(tc.Input.(*logical.Response).WrapInfo, tc.Output.(*logical.Response).WrapInfo) { | ||||||
|  | 				t.Fatalf("bad:\nInput:\n%s\nTest case input:\n%#v\nTest case output\n%#v", input, tc.Input.(*logical.Response).WrapInfo, tc.Output.(*logical.Response).WrapInfo) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
| 		if !reflect.DeepEqual(tc.Input, tc.Output) { | 		if !reflect.DeepEqual(tc.Input, tc.Output) { | ||||||
| 			t.Fatalf("bad:\nInput:\n%s\nTest case input:\n%#v\nTest case output\n%#v", input, tc.Input, tc.Output) | 			t.Fatalf("bad:\nInput:\n%s\nTest case input:\n%#v\nTest case output\n%#v", input, tc.Input, tc.Output) | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -75,7 +75,9 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) { | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		mode = os.FileMode(m) | 		if m != 0 { | ||||||
|  | 			mode = os.FileMode(m) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	b := &Backend{ | 	b := &Backend{ | ||||||
| @@ -247,13 +249,15 @@ func (b *Backend) open() error { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Change the file mode in case the log file already existed. We special | 	// Change the file mode in case the log file already existed. We special | ||||||
| 	// case /dev/null since we can't chmod it | 	// case /dev/null since we can't chmod it and bypass if the mode is zero | ||||||
| 	switch b.path { | 	switch b.path { | ||||||
| 	case "/dev/null": | 	case "/dev/null": | ||||||
| 	default: | 	default: | ||||||
| 		err = os.Chmod(b.path, b.mode) | 		if b.mode != 0 { | ||||||
| 		if err != nil { | 			err = os.Chmod(b.path, b.mode) | ||||||
| 			return err | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -141,7 +141,7 @@ func testAccStepMapUserIdCidr(t *testing.T, cidr string) logicaltest.TestStep { | |||||||
| func testAccLogin(t *testing.T, display string) logicaltest.TestStep { | func testAccLogin(t *testing.T, display string) logicaltest.TestStep { | ||||||
| 	checkTTL := func(resp *logical.Response) error { | 	checkTTL := func(resp *logical.Response) error { | ||||||
| 		if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" { | 		if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" { | ||||||
| 			return fmt.Errorf("invalid TTL") | 			return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL) | ||||||
| 		} | 		} | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| @@ -165,7 +165,7 @@ func testAccLogin(t *testing.T, display string) logicaltest.TestStep { | |||||||
| func testAccLoginAppIDInPath(t *testing.T, display string) logicaltest.TestStep { | func testAccLoginAppIDInPath(t *testing.T, display string) logicaltest.TestStep { | ||||||
| 	checkTTL := func(resp *logical.Response) error { | 	checkTTL := func(resp *logical.Response) error { | ||||||
| 		if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" { | 		if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" { | ||||||
| 			return fmt.Errorf("invalid TTL") | 			return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL) | ||||||
| 		} | 		} | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -3,7 +3,6 @@ package approle | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/vault/logical" | 	"github.com/hashicorp/vault/logical" | ||||||
| 	"github.com/hashicorp/vault/logical/framework" | 	"github.com/hashicorp/vault/logical/framework" | ||||||
| @@ -52,7 +51,7 @@ func (b *backend) pathLoginUpdateAliasLookahead(req *logical.Request, data *fram | |||||||
| func (b *backend) pathLoginUpdate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | func (b *backend) pathLoginUpdate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | ||||||
| 	role, roleName, metadata, _, err := b.validateCredentials(req, data) | 	role, roleName, metadata, _, err := b.validateCredentials(req, data) | ||||||
| 	if err != nil || role == nil { | 	if err != nil || role == nil { | ||||||
| 		return logical.ErrorResponse(fmt.Sprintf("failed to validate SecretID: %s", err)), nil | 		return logical.ErrorResponse(fmt.Sprintf("failed to validate credentials: %v", err)), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Always include the role name, for later filtering | 	// Always include the role name, for later filtering | ||||||
| @@ -68,20 +67,13 @@ func (b *backend) pathLoginUpdate(req *logical.Request, data *framework.FieldDat | |||||||
| 		Policies: role.Policies, | 		Policies: role.Policies, | ||||||
| 		LeaseOptions: logical.LeaseOptions{ | 		LeaseOptions: logical.LeaseOptions{ | ||||||
| 			Renewable: true, | 			Renewable: true, | ||||||
|  | 			TTL:       role.TokenTTL, | ||||||
| 		}, | 		}, | ||||||
| 		Alias: &logical.Alias{ | 		Alias: &logical.Alias{ | ||||||
| 			Name: role.RoleID, | 			Name: role.RoleID, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// If 'Period' is set, use the value of 'Period' as the TTL. |  | ||||||
| 	// Otherwise, set the normal TokenTTL. |  | ||||||
| 	if role.Period > time.Duration(0) { |  | ||||||
| 		auth.TTL = role.Period |  | ||||||
| 	} else { |  | ||||||
| 		auth.TTL = role.TokenTTL |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return &logical.Response{ | 	return &logical.Response{ | ||||||
| 		Auth: auth, | 		Auth: auth, | ||||||
| 	}, nil | 	}, nil | ||||||
| @@ -94,8 +86,12 @@ func (b *backend) pathLoginRenew(req *logical.Request, data *framework.FieldData | |||||||
| 		return nil, fmt.Errorf("failed to fetch role_name during renewal") | 		return nil, fmt.Errorf("failed to fetch role_name during renewal") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.RLock() | ||||||
|  | 	defer lock.RUnlock() | ||||||
|  |  | ||||||
| 	// Ensure that the Role still exists. | 	// Ensure that the Role still exists. | ||||||
| 	role, err := b.roleEntry(req.Storage, roleName) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to validate role %s during renewal:%s", roleName, err) | 		return nil, fmt.Errorf("failed to validate role %s during renewal:%s", roleName, err) | ||||||
| 	} | 	} | ||||||
| @@ -103,16 +99,12 @@ func (b *backend) pathLoginRenew(req *logical.Request, data *framework.FieldData | |||||||
| 		return nil, fmt.Errorf("role %s does not exist during renewal", roleName) | 		return nil, fmt.Errorf("role %s does not exist during renewal", roleName) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// If 'Period' is set on the Role, the token should never expire. | 	resp, err := framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data) | ||||||
| 	// Replenish the TTL with 'Period's value. | 	if err != nil { | ||||||
| 	if role.Period > time.Duration(0) { | 		return nil, err | ||||||
| 		// If 'Period' was updated after the token was issued, |  | ||||||
| 		// token will bear the updated 'Period' value as its TTL. |  | ||||||
| 		req.Auth.TTL = role.Period |  | ||||||
| 		return &logical.Response{Auth: req.Auth}, nil |  | ||||||
| 	} else { |  | ||||||
| 		return framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data) |  | ||||||
| 	} | 	} | ||||||
|  | 	resp.Auth.Period = role.Period | ||||||
|  | 	return resp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| const pathLoginHelpSys = "Issue a token based on the credentials supplied" | const pathLoginHelpSys = "Issue a token based on the credentials supplied" | ||||||
|   | |||||||
| @@ -57,6 +57,10 @@ type roleStorageEntry struct { | |||||||
| 	// value is not modified on the role. If the `Period` in the role is modified, | 	// value is not modified on the role. If the `Period` in the role is modified, | ||||||
| 	// a token will pick up the new value during its next renewal. | 	// a token will pick up the new value during its next renewal. | ||||||
| 	Period time.Duration `json:"period" mapstructure:"period" structs:"period"` | 	Period time.Duration `json:"period" mapstructure:"period" structs:"period"` | ||||||
|  |  | ||||||
|  | 	// LowerCaseRoleName enforces the lower casing of role names for all the | ||||||
|  | 	// roles that get created since this field was introduced. | ||||||
|  | 	LowerCaseRoleName bool `json:"lower_case_role_name" mapstructure:"lower_case_role_name" structs:"lower_case_role_name"` | ||||||
| } | } | ||||||
|  |  | ||||||
| // roleIDStorageEntry represents the reverse mapping from RoleID to Role | // roleIDStorageEntry represents the reverse mapping from RoleID to Role | ||||||
| @@ -509,10 +513,20 @@ the role.`, | |||||||
|  |  | ||||||
| // pathRoleExistenceCheck returns whether the role with the given name exists or not. | // pathRoleExistenceCheck returns whether the role with the given name exists or not. | ||||||
| func (b *backend) pathRoleExistenceCheck(req *logical.Request, data *framework.FieldData) (bool, error) { | func (b *backend) pathRoleExistenceCheck(req *logical.Request, data *framework.FieldData) (bool, error) { | ||||||
| 	role, err := b.roleEntry(req.Storage, data.Get("role_name").(string)) | 	roleName := data.Get("role_name").(string) | ||||||
|  | 	if roleName == "" { | ||||||
|  | 		return false, fmt.Errorf("missing role_name") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.RLock() | ||||||
|  | 	defer lock.RUnlock() | ||||||
|  |  | ||||||
|  | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return false, err | 		return false, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return role != nil, nil | 	return role != nil, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -537,13 +551,21 @@ func (b *backend) pathRoleSecretIDList(req *logical.Request, data *framework.Fie | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.RLock() | ||||||
|  | 	defer lock.RUnlock() | ||||||
|  |  | ||||||
| 	// Get the role entry | 	// Get the role entry | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	if role == nil { | 	if role == nil { | ||||||
| 		return logical.ErrorResponse(fmt.Sprintf("role %s does not exist", roleName)), nil | 		return logical.ErrorResponse(fmt.Sprintf("role %q does not exist", roleName)), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if role.LowerCaseRoleName { | ||||||
|  | 		roleName = strings.ToLower(roleName) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Guard the list operation with an outer lock | 	// Guard the list operation with an outer lock | ||||||
| @@ -552,7 +574,7 @@ func (b *backend) pathRoleSecretIDList(req *logical.Request, data *framework.Fie | |||||||
|  |  | ||||||
| 	roleNameHMAC, err := createHMAC(role.HMACKey, roleName) | 	roleNameHMAC, err := createHMAC(role.HMACKey, roleName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to create HMAC of role_name: %s", err) | 		return nil, fmt.Errorf("failed to create HMAC of role_name: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Listing works one level at a time. Get the first level of data | 	// Listing works one level at a time. Get the first level of data | ||||||
| @@ -618,9 +640,8 @@ func validateRoleConstraints(role *roleStorageEntry) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // setRoleEntry grabs a write lock and stores the options on an role into the | // setRoleEntry persists the role and creates an index from roleID to role | ||||||
| // storage. Also creates a reverse index from the role's RoleID to the role | // name. | ||||||
| // itself. |  | ||||||
| func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleStorageEntry, previousRoleID string) error { | func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleStorageEntry, previousRoleID string) error { | ||||||
| 	if roleName == "" { | 	if roleName == "" { | ||||||
| 		return fmt.Errorf("missing role name") | 		return fmt.Errorf("missing role name") | ||||||
| @@ -641,7 +662,7 @@ func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleSto | |||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	if entry == nil { | 	if entry == nil { | ||||||
| 		return fmt.Errorf("failed to create storage entry for role %s", roleName) | 		return fmt.Errorf("failed to create storage entry for role %q", roleName) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Check if the index from the role_id to role already exists | 	// Check if the index from the role_id to role already exists | ||||||
| @@ -680,7 +701,7 @@ func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleSto | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| // roleEntry grabs the read lock and fetches the options of an role from the storage | // roleEntry reads the role from storage | ||||||
| func (b *backend) roleEntry(s logical.Storage, roleName string) (*roleStorageEntry, error) { | func (b *backend) roleEntry(s logical.Storage, roleName string) (*roleStorageEntry, error) { | ||||||
| 	if roleName == "" { | 	if roleName == "" { | ||||||
| 		return nil, fmt.Errorf("missing role_name") | 		return nil, fmt.Errorf("missing role_name") | ||||||
| @@ -688,11 +709,6 @@ func (b *backend) roleEntry(s logical.Storage, roleName string) (*roleStorageEnt | |||||||
|  |  | ||||||
| 	var role roleStorageEntry | 	var role roleStorageEntry | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.RLock() |  | ||||||
| 	defer lock.RUnlock() |  | ||||||
|  |  | ||||||
| 	if entry, err := s.Get("role/" + strings.ToLower(roleName)); err != nil { | 	if entry, err := s.Get("role/" + strings.ToLower(roleName)); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} else if entry == nil { | 	} else if entry == nil { | ||||||
| @@ -712,6 +728,10 @@ func (b *backend) pathRoleCreateUpdate(req *logical.Request, data *framework.Fie | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	// Check if the role already exists | 	// Check if the role already exists | ||||||
| 	role, err := b.roleEntry(req.Storage, roleName) | 	role, err := b.roleEntry(req.Storage, roleName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -722,13 +742,14 @@ func (b *backend) pathRoleCreateUpdate(req *logical.Request, data *framework.Fie | |||||||
| 	if role == nil && req.Operation == logical.CreateOperation { | 	if role == nil && req.Operation == logical.CreateOperation { | ||||||
| 		hmacKey, err := uuid.GenerateUUID() | 		hmacKey, err := uuid.GenerateUUID() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, fmt.Errorf("failed to create role_id: %s\n", err) | 			return nil, fmt.Errorf("failed to create role_id: %v\n", err) | ||||||
| 		} | 		} | ||||||
| 		role = &roleStorageEntry{ | 		role = &roleStorageEntry{ | ||||||
| 			HMACKey: hmacKey, | 			HMACKey:           hmacKey, | ||||||
|  | 			LowerCaseRoleName: true, | ||||||
| 		} | 		} | ||||||
| 	} else if role == nil { | 	} else if role == nil { | ||||||
| 		return nil, fmt.Errorf("role entry not found during update operation") | 		return logical.ErrorResponse(fmt.Sprintf("invalid role name")), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	previousRoleID := role.RoleID | 	previousRoleID := role.RoleID | ||||||
| @@ -737,12 +758,12 @@ func (b *backend) pathRoleCreateUpdate(req *logical.Request, data *framework.Fie | |||||||
| 	} else if req.Operation == logical.CreateOperation { | 	} else if req.Operation == logical.CreateOperation { | ||||||
| 		roleID, err := uuid.GenerateUUID() | 		roleID, err := uuid.GenerateUUID() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, fmt.Errorf("failed to generate role_id: %s\n", err) | 			return nil, fmt.Errorf("failed to generate role_id: %v\n", err) | ||||||
| 		} | 		} | ||||||
| 		role.RoleID = roleID | 		role.RoleID = roleID | ||||||
| 	} | 	} | ||||||
| 	if role.RoleID == "" { | 	if role.RoleID == "" { | ||||||
| 		return logical.ErrorResponse("invalid role_id"), nil | 		return logical.ErrorResponse("invalid role_id supplied, or failed to generate a role_id"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if bindSecretIDRaw, ok := data.GetOk("bind_secret_id"); ok { | 	if bindSecretIDRaw, ok := data.GetOk("bind_secret_id"); ok { | ||||||
| @@ -780,7 +801,7 @@ func (b *backend) pathRoleCreateUpdate(req *logical.Request, data *framework.Fie | |||||||
| 		role.Period = time.Second * time.Duration(data.Get("period").(int)) | 		role.Period = time.Second * time.Duration(data.Get("period").(int)) | ||||||
| 	} | 	} | ||||||
| 	if role.Period > b.System().MaxLeaseTTL() { | 	if role.Period > b.System().MaxLeaseTTL() { | ||||||
| 		return logical.ErrorResponse(fmt.Sprintf("'period' of '%s' is greater than the backend's maximum lease TTL of '%s'", role.Period.String(), b.System().MaxLeaseTTL().String())), nil | 		return logical.ErrorResponse(fmt.Sprintf("period of %q is greater than the backend's maximum lease TTL of %q", role.Period.String(), b.System().MaxLeaseTTL().String())), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if secretIDNumUsesRaw, ok := data.GetOk("secret_id_num_uses"); ok { | 	if secretIDNumUsesRaw, ok := data.GetOk("secret_id_num_uses"); ok { | ||||||
| @@ -843,32 +864,78 @@ func (b *backend) pathRoleRead(req *logical.Request, data *framework.FieldData) | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.RLock() | ||||||
|  | 	lockRelease := lock.RUnlock | ||||||
|  |  | ||||||
|  | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		lockRelease() | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} else if role == nil { |  | ||||||
| 		return nil, nil |  | ||||||
| 	} else { |  | ||||||
| 		// Convert the 'time.Duration' values to second. |  | ||||||
| 		role.SecretIDTTL /= time.Second |  | ||||||
| 		role.TokenTTL /= time.Second |  | ||||||
| 		role.TokenMaxTTL /= time.Second |  | ||||||
| 		role.Period /= time.Second |  | ||||||
|  |  | ||||||
| 		// Create a map of data to be returned and remove sensitive information from it |  | ||||||
| 		data := structs.New(role).Map() |  | ||||||
| 		delete(data, "role_id") |  | ||||||
| 		delete(data, "hmac_key") |  | ||||||
|  |  | ||||||
| 		resp := &logical.Response{ |  | ||||||
| 			Data: data, |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		if err := validateRoleConstraints(role); err != nil { |  | ||||||
| 			resp.AddWarning("Role does not have any constraints set on it. Updates to this role will require a constraint to be set") |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		return resp, nil |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if role == nil { | ||||||
|  | 		lockRelease() | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	respData := map[string]interface{}{ | ||||||
|  | 		"bind_secret_id":     role.BindSecretID, | ||||||
|  | 		"bound_cidr_list":    role.BoundCIDRList, | ||||||
|  | 		"period":             role.Period / time.Second, | ||||||
|  | 		"policies":           role.Policies, | ||||||
|  | 		"secret_id_num_uses": role.SecretIDNumUses, | ||||||
|  | 		"secret_id_ttl":      role.SecretIDTTL / time.Second, | ||||||
|  | 		"token_max_ttl":      role.TokenMaxTTL / time.Second, | ||||||
|  | 		"token_num_uses":     role.TokenNumUses, | ||||||
|  | 		"token_ttl":          role.TokenTTL / time.Second, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	resp := &logical.Response{ | ||||||
|  | 		Data: respData, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := validateRoleConstraints(role); err != nil { | ||||||
|  | 		resp.AddWarning("Role does not have any constraints set on it. Updates to this role will require a constraint to be set") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// For sanity, verify that the index still exists. If the index is missing, | ||||||
|  | 	// add one and return a warning so it can be reported. | ||||||
|  | 	roleIDIndex, err := b.roleIDEntry(req.Storage, role.RoleID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		lockRelease() | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if roleIDIndex == nil { | ||||||
|  | 		// Switch to a write lock | ||||||
|  | 		lock.RUnlock() | ||||||
|  | 		lock.Lock() | ||||||
|  | 		lockRelease = lock.Unlock | ||||||
|  |  | ||||||
|  | 		// Check again if the index is missing | ||||||
|  | 		roleIDIndex, err = b.roleIDEntry(req.Storage, role.RoleID) | ||||||
|  | 		if err != nil { | ||||||
|  | 			lockRelease() | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if roleIDIndex == nil { | ||||||
|  | 			// Create a new index | ||||||
|  | 			err = b.setRoleIDEntry(req.Storage, role.RoleID, &roleIDStorageEntry{ | ||||||
|  | 				Name: roleName, | ||||||
|  | 			}) | ||||||
|  | 			if err != nil { | ||||||
|  | 				lockRelease() | ||||||
|  | 				return nil, fmt.Errorf("failed to create secondary index for role_id %q: %v", role.RoleID, err) | ||||||
|  | 			} | ||||||
|  | 			resp.AddWarning("Role identifier was missing an index back to role name. A new index has been added. Please report this observation.") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	lockRelease() | ||||||
|  |  | ||||||
|  | 	return resp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // pathRoleDelete removes the role from the storage | // pathRoleDelete removes the role from the storage | ||||||
| @@ -878,6 +945,10 @@ func (b *backend) pathRoleDelete(req *logical.Request, data *framework.FieldData | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -886,19 +957,14 @@ func (b *backend) pathRoleDelete(req *logical.Request, data *framework.FieldData | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Acquire the lock before deleting the secrets. |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	// Just before the role is deleted, remove all the SecretIDs issued as part of the role. | 	// Just before the role is deleted, remove all the SecretIDs issued as part of the role. | ||||||
| 	if err = b.flushRoleSecrets(req.Storage, roleName, role.HMACKey); err != nil { | 	if err = b.flushRoleSecrets(req.Storage, roleName, role.HMACKey); err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to invalidate the secrets belonging to role '%s': %s", roleName, err) | 		return nil, fmt.Errorf("failed to invalidate the secrets belonging to role %q: %v", roleName, err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Delete the reverse mapping from RoleID to the role | 	// Delete the reverse mapping from RoleID to the role | ||||||
| 	if err = b.roleIDEntryDelete(req.Storage, role.RoleID); err != nil { | 	if err = b.roleIDEntryDelete(req.Storage, role.RoleID); err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to delete the mapping from RoleID to role '%s': %s", roleName, err) | 		return nil, fmt.Errorf("failed to delete the mapping from RoleID to role %q: %v", roleName, err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// After deleting the SecretIDs and the RoleID, delete the role itself | 	// After deleting the SecretIDs and the RoleID, delete the role itself | ||||||
| @@ -921,25 +987,33 @@ func (b *backend) pathRoleSecretIDLookupUpdate(req *logical.Request, data *frame | |||||||
| 		return logical.ErrorResponse("missing secret_id"), nil | 		return logical.ErrorResponse("missing secret_id"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.RLock() | ||||||
|  | 	defer lock.RUnlock() | ||||||
|  |  | ||||||
| 	// Fetch the role | 	// Fetch the role | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	if role == nil { | 	if role == nil { | ||||||
| 		return nil, fmt.Errorf("role %s does not exist", roleName) | 		return nil, fmt.Errorf("role %q does not exist", roleName) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if role.LowerCaseRoleName { | ||||||
|  | 		roleName = strings.ToLower(roleName) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Create the HMAC of the secret ID using the per-role HMAC key | 	// Create the HMAC of the secret ID using the per-role HMAC key | ||||||
| 	secretIDHMAC, err := createHMAC(role.HMACKey, secretID) | 	secretIDHMAC, err := createHMAC(role.HMACKey, secretID) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to create HMAC of secret_id: %s", err) | 		return nil, fmt.Errorf("failed to create HMAC of secret_id: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Create the HMAC of the roleName using the per-role HMAC key | 	// Create the HMAC of the roleName using the per-role HMAC key | ||||||
| 	roleNameHMAC, err := createHMAC(role.HMACKey, roleName) | 	roleNameHMAC, err := createHMAC(role.HMACKey, roleName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to create HMAC of role_name: %s", err) | 		return nil, fmt.Errorf("failed to create HMAC of role_name: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Create the index at which the secret_id would've been stored | 	// Create the index at which the secret_id would've been stored | ||||||
| @@ -996,22 +1070,26 @@ func (b *backend) pathRoleSecretIDDestroyUpdateDelete(req *logical.Request, data | |||||||
| 		return logical.ErrorResponse("missing secret_id"), nil | 		return logical.ErrorResponse("missing secret_id"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	roleLock := b.roleLock(roleName) | ||||||
|  | 	roleLock.RLock() | ||||||
|  | 	defer roleLock.RUnlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	if role == nil { | 	if role == nil { | ||||||
| 		return nil, fmt.Errorf("role %s does not exist", roleName) | 		return nil, fmt.Errorf("role %q does not exist", roleName) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	secretIDHMAC, err := createHMAC(role.HMACKey, secretID) | 	secretIDHMAC, err := createHMAC(role.HMACKey, secretID) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to create HMAC of secret_id: %s", err) | 		return nil, fmt.Errorf("failed to create HMAC of secret_id: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	roleNameHMAC, err := createHMAC(role.HMACKey, roleName) | 	roleNameHMAC, err := createHMAC(role.HMACKey, roleName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to create HMAC of role_name: %s", err) | 		return nil, fmt.Errorf("failed to create HMAC of role_name: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, secretIDHMAC) | 	entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, secretIDHMAC) | ||||||
| @@ -1036,7 +1114,7 @@ func (b *backend) pathRoleSecretIDDestroyUpdateDelete(req *logical.Request, data | |||||||
|  |  | ||||||
| 	// Delete the storage entry that corresponds to the SecretID | 	// Delete the storage entry that corresponds to the SecretID | ||||||
| 	if err := req.Storage.Delete(entryIndex); err != nil { | 	if err := req.Storage.Delete(entryIndex); err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to delete SecretID: %s", err) | 		return nil, fmt.Errorf("failed to delete secret_id: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return nil, nil | 	return nil, nil | ||||||
| @@ -1059,12 +1137,16 @@ func (b *backend) pathRoleSecretIDAccessorLookupUpdate(req *logical.Request, dat | |||||||
| 	// Get the role details to fetch the RoleID and accessor to get | 	// Get the role details to fetch the RoleID and accessor to get | ||||||
| 	// the HMACed SecretID. | 	// the HMACed SecretID. | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.RLock() | ||||||
|  | 	defer lock.RUnlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	if role == nil { | 	if role == nil { | ||||||
| 		return nil, fmt.Errorf("role %s does not exist", roleName) | 		return nil, fmt.Errorf("role %q does not exist", roleName) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	accessorEntry, err := b.secretIDAccessorEntry(req.Storage, secretIDAccessor) | 	accessorEntry, err := b.secretIDAccessorEntry(req.Storage, secretIDAccessor) | ||||||
| @@ -1072,12 +1154,12 @@ func (b *backend) pathRoleSecretIDAccessorLookupUpdate(req *logical.Request, dat | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	if accessorEntry == nil { | 	if accessorEntry == nil { | ||||||
| 		return nil, fmt.Errorf("failed to find accessor entry for secret_id_accessor:%s\n", secretIDAccessor) | 		return nil, fmt.Errorf("failed to find accessor entry for secret_id_accessor: %q\n", secretIDAccessor) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	roleNameHMAC, err := createHMAC(role.HMACKey, roleName) | 	roleNameHMAC, err := createHMAC(role.HMACKey, roleName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to create HMAC of role_name: %s", err) | 		return nil, fmt.Errorf("failed to create HMAC of role_name: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, accessorEntry.SecretIDHMAC) | 	entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, accessorEntry.SecretIDHMAC) | ||||||
| @@ -1105,7 +1187,7 @@ func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(req *logical.Reque | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	if role == nil { | 	if role == nil { | ||||||
| 		return nil, fmt.Errorf("role %s does not exist", roleName) | 		return nil, fmt.Errorf("role %q does not exist", roleName) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	accessorEntry, err := b.secretIDAccessorEntry(req.Storage, secretIDAccessor) | 	accessorEntry, err := b.secretIDAccessorEntry(req.Storage, secretIDAccessor) | ||||||
| @@ -1113,12 +1195,12 @@ func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(req *logical.Reque | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	if accessorEntry == nil { | 	if accessorEntry == nil { | ||||||
| 		return nil, fmt.Errorf("failed to find accessor entry for secret_id_accessor:%s\n", secretIDAccessor) | 		return nil, fmt.Errorf("failed to find accessor entry for secret_id_accessor: %q\n", secretIDAccessor) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	roleNameHMAC, err := createHMAC(role.HMACKey, roleName) | 	roleNameHMAC, err := createHMAC(role.HMACKey, roleName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to create HMAC of role_name: %s", err) | 		return nil, fmt.Errorf("failed to create HMAC of role_name: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, accessorEntry.SecretIDHMAC) | 	entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, accessorEntry.SecretIDHMAC) | ||||||
| @@ -1134,7 +1216,7 @@ func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(req *logical.Reque | |||||||
|  |  | ||||||
| 	// Delete the storage entry that corresponds to the SecretID | 	// Delete the storage entry that corresponds to the SecretID | ||||||
| 	if err := req.Storage.Delete(entryIndex); err != nil { | 	if err := req.Storage.Delete(entryIndex); err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to delete SecretID: %s", err) | 		return nil, fmt.Errorf("failed to delete secret_id: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return nil, nil | 	return nil, nil | ||||||
| @@ -1146,6 +1228,11 @@ func (b *backend) pathRoleBoundCIDRListUpdate(req *logical.Request, data *framew | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
|  | 	// Re-read the role after grabbing the lock | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -1154,11 +1241,6 @@ func (b *backend) pathRoleBoundCIDRListUpdate(req *logical.Request, data *framew | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	role.BoundCIDRList = strings.TrimSpace(data.Get("bound_cidr_list").(string)) | 	role.BoundCIDRList = strings.TrimSpace(data.Get("bound_cidr_list").(string)) | ||||||
| 	if role.BoundCIDRList == "" { | 	if role.BoundCIDRList == "" { | ||||||
| 		return logical.ErrorResponse("missing bound_cidr_list"), nil | 		return logical.ErrorResponse("missing bound_cidr_list"), nil | ||||||
| @@ -1167,7 +1249,7 @@ func (b *backend) pathRoleBoundCIDRListUpdate(req *logical.Request, data *framew | |||||||
| 	if role.BoundCIDRList != "" { | 	if role.BoundCIDRList != "" { | ||||||
| 		valid, err := cidrutil.ValidateCIDRListString(role.BoundCIDRList, ",") | 		valid, err := cidrutil.ValidateCIDRListString(role.BoundCIDRList, ",") | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, fmt.Errorf("failed to validate CIDR blocks: %q", err) | 			return nil, fmt.Errorf("failed to validate CIDR blocks: %v", err) | ||||||
| 		} | 		} | ||||||
| 		if !valid { | 		if !valid { | ||||||
| 			return logical.ErrorResponse("failed to validate CIDR blocks"), nil | 			return logical.ErrorResponse("failed to validate CIDR blocks"), nil | ||||||
| @@ -1183,6 +1265,10 @@ func (b *backend) pathRoleBoundCIDRListRead(req *logical.Request, data *framewor | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} else if role == nil { | 	} else if role == nil { | ||||||
| @@ -1202,6 +1288,10 @@ func (b *backend) pathRoleBoundCIDRListDelete(req *logical.Request, data *framew | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -1210,11 +1300,6 @@ func (b *backend) pathRoleBoundCIDRListDelete(req *logical.Request, data *framew | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	// Deleting a field implies setting the value to it's default value. | 	// Deleting a field implies setting the value to it's default value. | ||||||
| 	role.BoundCIDRList = data.GetDefaultOrZero("bound_cidr_list").(string) | 	role.BoundCIDRList = data.GetDefaultOrZero("bound_cidr_list").(string) | ||||||
|  |  | ||||||
| @@ -1227,6 +1312,10 @@ func (b *backend) pathRoleBindSecretIDUpdate(req *logical.Request, data *framewo | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -1235,11 +1324,6 @@ func (b *backend) pathRoleBindSecretIDUpdate(req *logical.Request, data *framewo | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	if bindSecretIDRaw, ok := data.GetOk("bind_secret_id"); ok { | 	if bindSecretIDRaw, ok := data.GetOk("bind_secret_id"); ok { | ||||||
| 		role.BindSecretID = bindSecretIDRaw.(bool) | 		role.BindSecretID = bindSecretIDRaw.(bool) | ||||||
| 		return nil, b.setRoleEntry(req.Storage, roleName, role, "") | 		return nil, b.setRoleEntry(req.Storage, roleName, role, "") | ||||||
| @@ -1254,6 +1338,10 @@ func (b *backend) pathRoleBindSecretIDRead(req *logical.Request, data *framework | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.RLock() | ||||||
|  | 	defer lock.RUnlock() | ||||||
|  |  | ||||||
| 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} else if role == nil { | 	} else if role == nil { | ||||||
| @@ -1273,6 +1361,10 @@ func (b *backend) pathRoleBindSecretIDDelete(req *logical.Request, data *framewo | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -1281,11 +1373,6 @@ func (b *backend) pathRoleBindSecretIDDelete(req *logical.Request, data *framewo | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	// Deleting a field implies setting the value to it's default value. | 	// Deleting a field implies setting the value to it's default value. | ||||||
| 	role.BindSecretID = data.GetDefaultOrZero("bind_secret_id").(bool) | 	role.BindSecretID = data.GetDefaultOrZero("bind_secret_id").(bool) | ||||||
|  |  | ||||||
| @@ -1298,6 +1385,10 @@ func (b *backend) pathRolePoliciesUpdate(req *logical.Request, data *framework.F | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -1311,11 +1402,6 @@ func (b *backend) pathRolePoliciesUpdate(req *logical.Request, data *framework.F | |||||||
| 		return logical.ErrorResponse("missing policies"), nil | 		return logical.ErrorResponse("missing policies"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	role.Policies = policyutil.ParsePolicies(policiesRaw) | 	role.Policies = policyutil.ParsePolicies(policiesRaw) | ||||||
|  |  | ||||||
| 	return nil, b.setRoleEntry(req.Storage, roleName, role, "") | 	return nil, b.setRoleEntry(req.Storage, roleName, role, "") | ||||||
| @@ -1327,6 +1413,10 @@ func (b *backend) pathRolePoliciesRead(req *logical.Request, data *framework.Fie | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.RLock() | ||||||
|  | 	defer lock.RUnlock() | ||||||
|  |  | ||||||
| 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} else if role == nil { | 	} else if role == nil { | ||||||
| @@ -1346,6 +1436,10 @@ func (b *backend) pathRolePoliciesDelete(req *logical.Request, data *framework.F | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -1354,11 +1448,6 @@ func (b *backend) pathRolePoliciesDelete(req *logical.Request, data *framework.F | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	role.Policies = []string{} | 	role.Policies = []string{} | ||||||
|  |  | ||||||
| 	return nil, b.setRoleEntry(req.Storage, roleName, role, "") | 	return nil, b.setRoleEntry(req.Storage, roleName, role, "") | ||||||
| @@ -1370,6 +1459,10 @@ func (b *backend) pathRoleSecretIDNumUsesUpdate(req *logical.Request, data *fram | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -1378,11 +1471,6 @@ func (b *backend) pathRoleSecretIDNumUsesUpdate(req *logical.Request, data *fram | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	if numUsesRaw, ok := data.GetOk("secret_id_num_uses"); ok { | 	if numUsesRaw, ok := data.GetOk("secret_id_num_uses"); ok { | ||||||
| 		role.SecretIDNumUses = numUsesRaw.(int) | 		role.SecretIDNumUses = numUsesRaw.(int) | ||||||
| 		if role.SecretIDNumUses < 0 { | 		if role.SecretIDNumUses < 0 { | ||||||
| @@ -1400,6 +1488,10 @@ func (b *backend) pathRoleRoleIDUpdate(req *logical.Request, data *framework.Fie | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -1408,11 +1500,6 @@ func (b *backend) pathRoleRoleIDUpdate(req *logical.Request, data *framework.Fie | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	previousRoleID := role.RoleID | 	previousRoleID := role.RoleID | ||||||
| 	role.RoleID = data.Get("role_id").(string) | 	role.RoleID = data.Get("role_id").(string) | ||||||
| 	if role.RoleID == "" { | 	if role.RoleID == "" { | ||||||
| @@ -1428,6 +1515,10 @@ func (b *backend) pathRoleRoleIDRead(req *logical.Request, data *framework.Field | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.RLock() | ||||||
|  | 	defer lock.RUnlock() | ||||||
|  |  | ||||||
| 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} else if role == nil { | 	} else if role == nil { | ||||||
| @@ -1447,6 +1538,10 @@ func (b *backend) pathRoleSecretIDNumUsesRead(req *logical.Request, data *framew | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.RLock() | ||||||
|  | 	defer lock.RUnlock() | ||||||
|  |  | ||||||
| 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} else if role == nil { | 	} else if role == nil { | ||||||
| @@ -1466,6 +1561,10 @@ func (b *backend) pathRoleSecretIDNumUsesDelete(req *logical.Request, data *fram | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -1474,11 +1573,6 @@ func (b *backend) pathRoleSecretIDNumUsesDelete(req *logical.Request, data *fram | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	role.SecretIDNumUses = data.GetDefaultOrZero("secret_id_num_uses").(int) | 	role.SecretIDNumUses = data.GetDefaultOrZero("secret_id_num_uses").(int) | ||||||
|  |  | ||||||
| 	return nil, b.setRoleEntry(req.Storage, roleName, role, "") | 	return nil, b.setRoleEntry(req.Storage, roleName, role, "") | ||||||
| @@ -1490,6 +1584,10 @@ func (b *backend) pathRoleSecretIDTTLUpdate(req *logical.Request, data *framewor | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -1498,11 +1596,6 @@ func (b *backend) pathRoleSecretIDTTLUpdate(req *logical.Request, data *framewor | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	if secretIDTTLRaw, ok := data.GetOk("secret_id_ttl"); ok { | 	if secretIDTTLRaw, ok := data.GetOk("secret_id_ttl"); ok { | ||||||
| 		role.SecretIDTTL = time.Second * time.Duration(secretIDTTLRaw.(int)) | 		role.SecretIDTTL = time.Second * time.Duration(secretIDTTLRaw.(int)) | ||||||
| 		return nil, b.setRoleEntry(req.Storage, roleName, role, "") | 		return nil, b.setRoleEntry(req.Storage, roleName, role, "") | ||||||
| @@ -1517,6 +1610,10 @@ func (b *backend) pathRoleSecretIDTTLRead(req *logical.Request, data *framework. | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.RLock() | ||||||
|  | 	defer lock.RUnlock() | ||||||
|  |  | ||||||
| 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} else if role == nil { | 	} else if role == nil { | ||||||
| @@ -1537,6 +1634,10 @@ func (b *backend) pathRoleSecretIDTTLDelete(req *logical.Request, data *framewor | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -1545,11 +1646,6 @@ func (b *backend) pathRoleSecretIDTTLDelete(req *logical.Request, data *framewor | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	role.SecretIDTTL = time.Second * time.Duration(data.GetDefaultOrZero("secret_id_ttl").(int)) | 	role.SecretIDTTL = time.Second * time.Duration(data.GetDefaultOrZero("secret_id_ttl").(int)) | ||||||
|  |  | ||||||
| 	return nil, b.setRoleEntry(req.Storage, roleName, role, "") | 	return nil, b.setRoleEntry(req.Storage, roleName, role, "") | ||||||
| @@ -1561,6 +1657,10 @@ func (b *backend) pathRolePeriodUpdate(req *logical.Request, data *framework.Fie | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -1569,15 +1669,10 @@ func (b *backend) pathRolePeriodUpdate(req *logical.Request, data *framework.Fie | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	if periodRaw, ok := data.GetOk("period"); ok { | 	if periodRaw, ok := data.GetOk("period"); ok { | ||||||
| 		role.Period = time.Second * time.Duration(periodRaw.(int)) | 		role.Period = time.Second * time.Duration(periodRaw.(int)) | ||||||
| 		if role.Period > b.System().MaxLeaseTTL() { | 		if role.Period > b.System().MaxLeaseTTL() { | ||||||
| 			return logical.ErrorResponse(fmt.Sprintf("'period' of '%s' is greater than the backend's maximum lease TTL of '%s'", role.Period.String(), b.System().MaxLeaseTTL().String())), nil | 			return logical.ErrorResponse(fmt.Sprintf("period of %q is greater than the backend's maximum lease TTL of %q", role.Period.String(), b.System().MaxLeaseTTL().String())), nil | ||||||
| 		} | 		} | ||||||
| 		return nil, b.setRoleEntry(req.Storage, roleName, role, "") | 		return nil, b.setRoleEntry(req.Storage, roleName, role, "") | ||||||
| 	} else { | 	} else { | ||||||
| @@ -1591,6 +1686,10 @@ func (b *backend) pathRolePeriodRead(req *logical.Request, data *framework.Field | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.RLock() | ||||||
|  | 	defer lock.RUnlock() | ||||||
|  |  | ||||||
| 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} else if role == nil { | 	} else if role == nil { | ||||||
| @@ -1611,6 +1710,10 @@ func (b *backend) pathRolePeriodDelete(req *logical.Request, data *framework.Fie | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -1619,11 +1722,6 @@ func (b *backend) pathRolePeriodDelete(req *logical.Request, data *framework.Fie | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	role.Period = time.Second * time.Duration(data.GetDefaultOrZero("period").(int)) | 	role.Period = time.Second * time.Duration(data.GetDefaultOrZero("period").(int)) | ||||||
|  |  | ||||||
| 	return nil, b.setRoleEntry(req.Storage, roleName, role, "") | 	return nil, b.setRoleEntry(req.Storage, roleName, role, "") | ||||||
| @@ -1635,6 +1733,10 @@ func (b *backend) pathRoleTokenNumUsesUpdate(req *logical.Request, data *framewo | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -1643,11 +1745,6 @@ func (b *backend) pathRoleTokenNumUsesUpdate(req *logical.Request, data *framewo | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	if tokenNumUsesRaw, ok := data.GetOk("token_num_uses"); ok { | 	if tokenNumUsesRaw, ok := data.GetOk("token_num_uses"); ok { | ||||||
| 		role.TokenNumUses = tokenNumUsesRaw.(int) | 		role.TokenNumUses = tokenNumUsesRaw.(int) | ||||||
| 		return nil, b.setRoleEntry(req.Storage, roleName, role, "") | 		return nil, b.setRoleEntry(req.Storage, roleName, role, "") | ||||||
| @@ -1662,6 +1759,10 @@ func (b *backend) pathRoleTokenNumUsesRead(req *logical.Request, data *framework | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.RLock() | ||||||
|  | 	defer lock.RUnlock() | ||||||
|  |  | ||||||
| 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} else if role == nil { | 	} else if role == nil { | ||||||
| @@ -1681,6 +1782,10 @@ func (b *backend) pathRoleTokenNumUsesDelete(req *logical.Request, data *framewo | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -1689,11 +1794,6 @@ func (b *backend) pathRoleTokenNumUsesDelete(req *logical.Request, data *framewo | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	role.TokenNumUses = data.GetDefaultOrZero("token_num_uses").(int) | 	role.TokenNumUses = data.GetDefaultOrZero("token_num_uses").(int) | ||||||
|  |  | ||||||
| 	return nil, b.setRoleEntry(req.Storage, roleName, role, "") | 	return nil, b.setRoleEntry(req.Storage, roleName, role, "") | ||||||
| @@ -1705,6 +1805,10 @@ func (b *backend) pathRoleTokenTTLUpdate(req *logical.Request, data *framework.F | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -1713,11 +1817,6 @@ func (b *backend) pathRoleTokenTTLUpdate(req *logical.Request, data *framework.F | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	if tokenTTLRaw, ok := data.GetOk("token_ttl"); ok { | 	if tokenTTLRaw, ok := data.GetOk("token_ttl"); ok { | ||||||
| 		role.TokenTTL = time.Second * time.Duration(tokenTTLRaw.(int)) | 		role.TokenTTL = time.Second * time.Duration(tokenTTLRaw.(int)) | ||||||
| 		if role.TokenMaxTTL > time.Duration(0) && role.TokenTTL > role.TokenMaxTTL { | 		if role.TokenMaxTTL > time.Duration(0) && role.TokenTTL > role.TokenMaxTTL { | ||||||
| @@ -1735,6 +1834,10 @@ func (b *backend) pathRoleTokenTTLRead(req *logical.Request, data *framework.Fie | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.RLock() | ||||||
|  | 	defer lock.RUnlock() | ||||||
|  |  | ||||||
| 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} else if role == nil { | 	} else if role == nil { | ||||||
| @@ -1755,6 +1858,10 @@ func (b *backend) pathRoleTokenTTLDelete(req *logical.Request, data *framework.F | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -1763,11 +1870,6 @@ func (b *backend) pathRoleTokenTTLDelete(req *logical.Request, data *framework.F | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	role.TokenTTL = time.Second * time.Duration(data.GetDefaultOrZero("token_ttl").(int)) | 	role.TokenTTL = time.Second * time.Duration(data.GetDefaultOrZero("token_ttl").(int)) | ||||||
|  |  | ||||||
| 	return nil, b.setRoleEntry(req.Storage, roleName, role, "") | 	return nil, b.setRoleEntry(req.Storage, roleName, role, "") | ||||||
| @@ -1779,6 +1881,10 @@ func (b *backend) pathRoleTokenMaxTTLUpdate(req *logical.Request, data *framewor | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -1787,11 +1893,6 @@ func (b *backend) pathRoleTokenMaxTTLUpdate(req *logical.Request, data *framewor | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	if tokenMaxTTLRaw, ok := data.GetOk("token_max_ttl"); ok { | 	if tokenMaxTTLRaw, ok := data.GetOk("token_max_ttl"); ok { | ||||||
| 		role.TokenMaxTTL = time.Second * time.Duration(tokenMaxTTLRaw.(int)) | 		role.TokenMaxTTL = time.Second * time.Duration(tokenMaxTTLRaw.(int)) | ||||||
| 		if role.TokenMaxTTL > time.Duration(0) && role.TokenTTL > role.TokenMaxTTL { | 		if role.TokenMaxTTL > time.Duration(0) && role.TokenTTL > role.TokenMaxTTL { | ||||||
| @@ -1809,6 +1910,10 @@ func (b *backend) pathRoleTokenMaxTTLRead(req *logical.Request, data *framework. | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.RLock() | ||||||
|  | 	defer lock.RUnlock() | ||||||
|  |  | ||||||
| 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | 	if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} else if role == nil { | 	} else if role == nil { | ||||||
| @@ -1829,6 +1934,10 @@ func (b *backend) pathRoleTokenMaxTTLDelete(req *logical.Request, data *framewor | |||||||
| 		return logical.ErrorResponse("missing role_name"), nil | 		return logical.ErrorResponse("missing role_name"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -1837,11 +1946,6 @@ func (b *backend) pathRoleTokenMaxTTLDelete(req *logical.Request, data *framewor | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lock := b.roleLock(roleName) |  | ||||||
|  |  | ||||||
| 	lock.Lock() |  | ||||||
| 	defer lock.Unlock() |  | ||||||
|  |  | ||||||
| 	role.TokenMaxTTL = time.Second * time.Duration(data.GetDefaultOrZero("token_max_ttl").(int)) | 	role.TokenMaxTTL = time.Second * time.Duration(data.GetDefaultOrZero("token_max_ttl").(int)) | ||||||
|  |  | ||||||
| 	return nil, b.setRoleEntry(req.Storage, roleName, role, "") | 	return nil, b.setRoleEntry(req.Storage, roleName, role, "") | ||||||
| @@ -1850,7 +1954,7 @@ func (b *backend) pathRoleTokenMaxTTLDelete(req *logical.Request, data *framewor | |||||||
| func (b *backend) pathRoleSecretIDUpdate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | func (b *backend) pathRoleSecretIDUpdate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | ||||||
| 	secretID, err := uuid.GenerateUUID() | 	secretID, err := uuid.GenerateUUID() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to generate SecretID:%s", err) | 		return nil, fmt.Errorf("failed to generate secret_id: %v", err) | ||||||
| 	} | 	} | ||||||
| 	return b.handleRoleSecretIDCommon(req, data, secretID) | 	return b.handleRoleSecretIDCommon(req, data, secretID) | ||||||
| } | } | ||||||
| @@ -1869,12 +1973,16 @@ func (b *backend) handleRoleSecretIDCommon(req *logical.Request, data *framework | |||||||
| 		return logical.ErrorResponse("missing secret_id"), nil | 		return logical.ErrorResponse("missing secret_id"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleName) | ||||||
|  | 	lock.RLock() | ||||||
|  | 	defer lock.RUnlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | 	role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	if role == nil { | 	if role == nil { | ||||||
| 		return logical.ErrorResponse(fmt.Sprintf("role %s does not exist", roleName)), nil | 		return logical.ErrorResponse(fmt.Sprintf("role %q does not exist", roleName)), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if !role.BindSecretID { | 	if !role.BindSecretID { | ||||||
| @@ -1887,7 +1995,7 @@ func (b *backend) handleRoleSecretIDCommon(req *logical.Request, data *framework | |||||||
| 	if cidrList != "" { | 	if cidrList != "" { | ||||||
| 		valid, err := cidrutil.ValidateCIDRListString(cidrList, ",") | 		valid, err := cidrutil.ValidateCIDRListString(cidrList, ",") | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, fmt.Errorf("failed to validate CIDR blocks: %q", err) | 			return nil, fmt.Errorf("failed to validate CIDR blocks: %v", err) | ||||||
| 		} | 		} | ||||||
| 		if !valid { | 		if !valid { | ||||||
| 			return logical.ErrorResponse("failed to validate CIDR blocks"), nil | 			return logical.ErrorResponse("failed to validate CIDR blocks"), nil | ||||||
| @@ -1913,8 +2021,12 @@ func (b *backend) handleRoleSecretIDCommon(req *logical.Request, data *framework | |||||||
| 		return logical.ErrorResponse(fmt.Sprintf("failed to parse metadata: %v", err)), nil | 		return logical.ErrorResponse(fmt.Sprintf("failed to parse metadata: %v", err)), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if role.LowerCaseRoleName { | ||||||
|  | 		roleName = strings.ToLower(roleName) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if secretIDStorage, err = b.registerSecretIDEntry(req.Storage, roleName, secretID, role.HMACKey, secretIDStorage); err != nil { | 	if secretIDStorage, err = b.registerSecretIDEntry(req.Storage, roleName, secretID, role.HMACKey, secretIDStorage); err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to store SecretID: %s", err) | 		return nil, fmt.Errorf("failed to store secret_id: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return &logical.Response{ | 	return &logical.Response{ | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ package approle | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"reflect" | 	"reflect" | ||||||
|  | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| @@ -10,6 +11,248 @@ import ( | |||||||
| 	"github.com/mitchellh/mapstructure" | 	"github.com/mitchellh/mapstructure" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | func TestApprole_RoleNameLowerCasing(t *testing.T) { | ||||||
|  | 	var resp *logical.Response | ||||||
|  | 	var err error | ||||||
|  | 	var roleID, secretID string | ||||||
|  |  | ||||||
|  | 	b, storage := createBackendWithStorage(t) | ||||||
|  |  | ||||||
|  | 	// Save a role with out LowerCaseRoleName set | ||||||
|  | 	role := &roleStorageEntry{ | ||||||
|  | 		RoleID:       "testroleid", | ||||||
|  | 		HMACKey:      "testhmackey", | ||||||
|  | 		Policies:     []string{"default"}, | ||||||
|  | 		BindSecretID: true, | ||||||
|  | 	} | ||||||
|  | 	err = b.setRoleEntry(storage, "testRoleName", role, "") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	secretIDReq := &logical.Request{ | ||||||
|  | 		Path:      "role/testRoleName/secret-id", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Storage:   storage, | ||||||
|  | 	} | ||||||
|  | 	resp, err = b.HandleRequest(secretIDReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: resp: %#v\nerr: %v", resp, err) | ||||||
|  | 	} | ||||||
|  | 	secretID = resp.Data["secret_id"].(string) | ||||||
|  | 	roleID = "testroleid" | ||||||
|  |  | ||||||
|  | 	// Regular login flow. This should succeed. | ||||||
|  | 	resp, err = b.HandleRequest(&logical.Request{ | ||||||
|  | 		Path:      "login", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Storage:   storage, | ||||||
|  | 		Data: map[string]interface{}{ | ||||||
|  | 			"role_id":   roleID, | ||||||
|  | 			"secret_id": secretID, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: resp: %#v\nerr: %v", resp, err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Lower case the role name when generating the secret id | ||||||
|  | 	secretIDReq.Path = "role/testrolename/secret-id" | ||||||
|  | 	resp, err = b.HandleRequest(secretIDReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: resp: %#v\nerr: %v", resp, err) | ||||||
|  | 	} | ||||||
|  | 	secretID = resp.Data["secret_id"].(string) | ||||||
|  |  | ||||||
|  | 	// Login should fail | ||||||
|  | 	resp, err = b.HandleRequest(&logical.Request{ | ||||||
|  | 		Path:      "login", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Storage:   storage, | ||||||
|  | 		Data: map[string]interface{}{ | ||||||
|  | 			"role_id":   roleID, | ||||||
|  | 			"secret_id": secretID, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if resp == nil || !resp.IsError() { | ||||||
|  | 		t.Fatalf("expected an error") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Delete the role and create it again. This time don't directly persist | ||||||
|  | 	// it, but route the request to the creation handler so that it sets the | ||||||
|  | 	// LowerCaseRoleName to true. | ||||||
|  | 	resp, err = b.HandleRequest(&logical.Request{ | ||||||
|  | 		Path:      "role/testRoleName", | ||||||
|  | 		Operation: logical.DeleteOperation, | ||||||
|  | 		Storage:   storage, | ||||||
|  | 	}) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: resp: %#v\nerr: %v", resp, err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	roleReq := &logical.Request{ | ||||||
|  | 		Path:      "role/testRoleName", | ||||||
|  | 		Operation: logical.CreateOperation, | ||||||
|  | 		Storage:   storage, | ||||||
|  | 		Data: map[string]interface{}{ | ||||||
|  | 			"bind_secret_id": true, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	resp, err = b.HandleRequest(roleReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: resp: %#v\nerr: %v", resp, err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Create secret id with lower cased role name | ||||||
|  | 	resp, err = b.HandleRequest(&logical.Request{ | ||||||
|  | 		Path:      "role/testrolename/secret-id", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Storage:   storage, | ||||||
|  | 	}) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: resp: %#v\nerr: %v", resp, err) | ||||||
|  | 	} | ||||||
|  | 	secretID = resp.Data["secret_id"].(string) | ||||||
|  |  | ||||||
|  | 	resp, err = b.HandleRequest(&logical.Request{ | ||||||
|  | 		Path:      "role/testrolename/role-id", | ||||||
|  | 		Operation: logical.ReadOperation, | ||||||
|  | 		Storage:   storage, | ||||||
|  | 	}) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: resp: %#v\nerr: %v", resp, err) | ||||||
|  | 	} | ||||||
|  | 	roleID = resp.Data["role_id"].(string) | ||||||
|  |  | ||||||
|  | 	// Login should pass | ||||||
|  | 	resp, err = b.HandleRequest(&logical.Request{ | ||||||
|  | 		Path:      "login", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Storage:   storage, | ||||||
|  | 		Data: map[string]interface{}{ | ||||||
|  | 			"role_id":   roleID, | ||||||
|  | 			"secret_id": secretID, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: resp: %#v\nerr:%v", resp, err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Lookup of secret ID should work in case-insensitive manner | ||||||
|  | 	resp, err = b.HandleRequest(&logical.Request{ | ||||||
|  | 		Path:      "role/testrolename/secret-id/lookup", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Storage:   storage, | ||||||
|  | 		Data: map[string]interface{}{ | ||||||
|  | 			"secret_id": secretID, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: resp: %#v\nerr: %v", resp, err) | ||||||
|  | 	} | ||||||
|  | 	if resp == nil { | ||||||
|  | 		t.Fatalf("failed to lookup secret IDs") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Listing of secret IDs should work in case-insensitive manner | ||||||
|  | 	resp, err = b.HandleRequest(&logical.Request{ | ||||||
|  | 		Path:      "role/testrolename/secret-id", | ||||||
|  | 		Operation: logical.ListOperation, | ||||||
|  | 		Storage:   storage, | ||||||
|  | 	}) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: resp: %#v\nerr: %v", resp, err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if len(resp.Data["keys"].([]string)) != 1 { | ||||||
|  | 		t.Fatalf("failed to list secret IDs") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestAppRole_RoleReadSetIndex(t *testing.T) { | ||||||
|  | 	var resp *logical.Response | ||||||
|  | 	var err error | ||||||
|  |  | ||||||
|  | 	b, storage := createBackendWithStorage(t) | ||||||
|  |  | ||||||
|  | 	roleReq := &logical.Request{ | ||||||
|  | 		Path:      "role/testrole", | ||||||
|  | 		Operation: logical.CreateOperation, | ||||||
|  | 		Storage:   storage, | ||||||
|  | 		Data: map[string]interface{}{ | ||||||
|  | 			"bind_secret_id": true, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Create a role | ||||||
|  | 	resp, err = b.HandleRequest(roleReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: resp: %#v\n err: %v\n", resp, err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	roleIDReq := &logical.Request{ | ||||||
|  | 		Path:      "role/testrole/role-id", | ||||||
|  | 		Operation: logical.ReadOperation, | ||||||
|  | 		Storage:   storage, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Get the role ID | ||||||
|  | 	resp, err = b.HandleRequest(roleIDReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: resp: %#v\n err: %v\n", resp, err) | ||||||
|  | 	} | ||||||
|  | 	roleID := resp.Data["role_id"].(string) | ||||||
|  |  | ||||||
|  | 	// Delete the role ID index | ||||||
|  | 	err = b.roleIDEntryDelete(storage, roleID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Read the role again. This should add the index and return a warning | ||||||
|  | 	roleReq.Operation = logical.ReadOperation | ||||||
|  | 	resp, err = b.HandleRequest(roleReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: resp: %#v\n err: %v\n", resp, err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Check if the warning is being returned | ||||||
|  | 	if !strings.Contains(resp.Warnings[0], "Role identifier was missing an index back to role name.") { | ||||||
|  | 		t.Fatalf("bad: expected a warning in the response") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	roleIDIndex, err := b.roleIDEntry(storage, roleID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Check if the index has been successfully created | ||||||
|  | 	if roleIDIndex == nil || roleIDIndex.Name != "testrole" { | ||||||
|  | 		t.Fatalf("bad: expected role to have an index") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	roleReq.Operation = logical.UpdateOperation | ||||||
|  | 	roleReq.Data = map[string]interface{}{ | ||||||
|  | 		"bind_secret_id": true, | ||||||
|  | 		"policies":       "default", | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Check if updating and reading of roles work and that there are no lock | ||||||
|  | 	// contentions dangling due to previous operation | ||||||
|  | 	resp, err = b.HandleRequest(roleReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: resp: %#v\n err: %v\n", resp, err) | ||||||
|  | 	} | ||||||
|  | 	roleReq.Operation = logical.ReadOperation | ||||||
|  | 	resp, err = b.HandleRequest(roleReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: resp: %#v\n err: %v\n", resp, err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestAppRole_CIDRSubset(t *testing.T) { | func TestAppRole_CIDRSubset(t *testing.T) { | ||||||
| 	var resp *logical.Response | 	var resp *logical.Response | ||||||
| 	var err error | 	var err error | ||||||
|   | |||||||
| @@ -75,15 +75,19 @@ func (b *backend) validateRoleID(s logical.Storage, roleID string) (*roleStorage | |||||||
| 		return nil, "", err | 		return nil, "", err | ||||||
| 	} | 	} | ||||||
| 	if roleIDIndex == nil { | 	if roleIDIndex == nil { | ||||||
| 		return nil, "", fmt.Errorf("failed to find secondary index for role_id %q\n", roleID) | 		return nil, "", fmt.Errorf("invalid role_id %q\n", roleID) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	lock := b.roleLock(roleIDIndex.Name) | ||||||
|  | 	lock.RLock() | ||||||
|  | 	defer lock.RUnlock() | ||||||
|  |  | ||||||
| 	role, err := b.roleEntry(s, roleIDIndex.Name) | 	role, err := b.roleEntry(s, roleIDIndex.Name) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, "", err | 		return nil, "", err | ||||||
| 	} | 	} | ||||||
| 	if role == nil { | 	if role == nil { | ||||||
| 		return nil, "", fmt.Errorf("role %q referred by the SecretID does not exist", roleIDIndex.Name) | 		return nil, "", fmt.Errorf("role %q referred by the role_id %q does not exist anymore", roleIDIndex.Name, roleID) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return role, roleIDIndex.Name, nil | 	return role, roleIDIndex.Name, nil | ||||||
| @@ -121,6 +125,10 @@ func (b *backend) validateCredentials(req *logical.Request, data *framework.Fiel | |||||||
| 			return nil, "", metadata, "", fmt.Errorf("missing secret_id") | 			return nil, "", metadata, "", fmt.Errorf("missing secret_id") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		if role.LowerCaseRoleName { | ||||||
|  | 			roleName = strings.ToLower(roleName) | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		// Check if the SecretID supplied is valid. If use limit was specified | 		// Check if the SecretID supplied is valid. If use limit was specified | ||||||
| 		// on the SecretID, it will be decremented in this call. | 		// on the SecretID, it will be decremented in this call. | ||||||
| 		var valid bool | 		var valid bool | ||||||
|   | |||||||
| @@ -99,6 +99,9 @@ func Backend(conf *logical.BackendConfig) (*backend, error) { | |||||||
| 			LocalStorage: []string{ | 			LocalStorage: []string{ | ||||||
| 				"whitelist/identity/", | 				"whitelist/identity/", | ||||||
| 			}, | 			}, | ||||||
|  | 			SealWrapStorage: []string{ | ||||||
|  | 				"config/client", | ||||||
|  | 			}, | ||||||
| 		}, | 		}, | ||||||
| 		Paths: []*framework.Path{ | 		Paths: []*framework.Path{ | ||||||
| 			pathLogin(b), | 			pathLogin(b), | ||||||
|   | |||||||
| @@ -1125,6 +1125,11 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndWhitelistIdentity(t *testing. | |||||||
| 		t.Fatalf("instance ID not present in the response object") | 		t.Fatalf("instance ID not present in the response object") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	_, ok := resp.Auth.Metadata["nonce"] | ||||||
|  | 	if ok { | ||||||
|  | 		t.Fatalf("client nonce should not have been returned") | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	loginInput["nonce"] = "changed-vault-client-nonce" | 	loginInput["nonce"] = "changed-vault-client-nonce" | ||||||
| 	// try to login again with changed nonce | 	// try to login again with changed nonce | ||||||
| 	resp, err = b.HandleRequest(loginRequest) | 	resp, err = b.HandleRequest(loginRequest) | ||||||
| @@ -1159,7 +1164,9 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndWhitelistIdentity(t *testing. | |||||||
| 		t.Fatalf("failed to delete whitelist identity") | 		t.Fatalf("failed to delete whitelist identity") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Allow a fresh login. | 	// Allow a fresh login without supplying the nonce | ||||||
|  | 	delete(loginInput, "nonce") | ||||||
|  |  | ||||||
| 	resp, err = b.HandleRequest(loginRequest) | 	resp, err = b.HandleRequest(loginRequest) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| @@ -1167,6 +1174,11 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndWhitelistIdentity(t *testing. | |||||||
| 	if resp == nil || resp.Auth == nil || resp.IsError() { | 	if resp == nil || resp.Auth == nil || resp.IsError() { | ||||||
| 		t.Fatalf("login attempt failed") | 		t.Fatalf("login attempt failed") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	_, ok = resp.Auth.Metadata["nonce"] | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatalf("expected nonce to be returned") | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestBackend_pathStsConfig(t *testing.T) { | func TestBackend_pathStsConfig(t *testing.T) { | ||||||
|   | |||||||
| @@ -34,6 +34,11 @@ func GenerateLoginData(accessKey, secretKey, sessionToken, headerValue string) ( | |||||||
| 		return nil, fmt.Errorf("could not compile valid credential providers from static config, environment, shared, or instance metadata") | 		return nil, fmt.Errorf("could not compile valid credential providers from static config, environment, shared, or instance metadata") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	_, err = creds.Get() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to retrieve credentials from credential chain: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// Use the credentials we've found to construct an STS session | 	// Use the credentials we've found to construct an STS session | ||||||
| 	stsSession, err := session.NewSessionWithOptions(session.Options{ | 	stsSession, err := session.NewSessionWithOptions(session.Options{ | ||||||
| 		Config: aws.Config{Credentials: creds}, | 		Config: aws.Config{Credentials: creds}, | ||||||
|   | |||||||
| @@ -643,7 +643,7 @@ func (b *backend) pathLoginUpdateEc2( | |||||||
| 			return logical.ErrorResponse(err.Error()), nil | 			return logical.ErrorResponse(err.Error()), nil | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Don't let subsequent login attempts to bypass in initial | 		// Don't let subsequent login attempts to bypass the initial | ||||||
| 		// intent of disabling reauthentication, despite the properties | 		// intent of disabling reauthentication, despite the properties | ||||||
| 		// of role getting updated. For example: Role has the value set | 		// of role getting updated. For example: Role has the value set | ||||||
| 		// to 'false', a role-tag login sets the value to 'true', then | 		// to 'false', a role-tag login sets the value to 'true', then | ||||||
| @@ -693,7 +693,6 @@ func (b *backend) pathLoginUpdateEc2( | |||||||
|  |  | ||||||
| 	if roleTagResp != nil { | 	if roleTagResp != nil { | ||||||
| 		// Role tag is enabled on the role. | 		// Role tag is enabled on the role. | ||||||
| 		// |  | ||||||
|  |  | ||||||
| 		// Overwrite the policies with the ones returned from processing the role tag | 		// Overwrite the policies with the ones returned from processing the role tag | ||||||
| 		// If there are no policies on the role tag, policies on the role are inherited. | 		// If there are no policies on the role tag, policies on the role are inherited. | ||||||
| @@ -777,8 +776,9 @@ func (b *backend) pathLoginUpdateEc2( | |||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Return the nonce only if reauthentication is allowed | 	// Return the nonce only if reauthentication is allowed and if the nonce | ||||||
| 	if !disallowReauthentication { | 	// was not supplied by the user. | ||||||
|  | 	if !disallowReauthentication && !clientNonceSupplied { | ||||||
| 		// Echo the client nonce back. If nonce param was not supplied | 		// Echo the client nonce back. If nonce param was not supplied | ||||||
| 		// to the endpoint at all (setting it to empty string does not | 		// to the endpoint at all (setting it to empty string does not | ||||||
| 		// qualify here), callers should extract out the nonce from | 		// qualify here), callers should extract out the nonce from | ||||||
| @@ -786,23 +786,15 @@ func (b *backend) pathLoginUpdateEc2( | |||||||
| 		resp.Auth.Metadata["nonce"] = clientNonce | 		resp.Auth.Metadata["nonce"] = clientNonce | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if roleEntry.Period > time.Duration(0) { | 	if roleEntry.MaxTTL > time.Duration(0) { | ||||||
| 		resp.Auth.TTL = roleEntry.Period | 		// Cap TTL to shortestMaxTTL | ||||||
| 	} else { | 		if resp.Auth.TTL > shortestMaxTTL { | ||||||
| 		// Cap the TTL value. | 			resp.AddWarning(fmt.Sprintf("Effective TTL of '%s' exceeded the effective max_ttl of '%s'; TTL value is capped accordingly", (resp.Auth.TTL / time.Second), (shortestMaxTTL / time.Second))) | ||||||
| 		shortestTTL := b.System().DefaultLeaseTTL() | 			resp.Auth.TTL = shortestMaxTTL | ||||||
| 		if roleEntry.TTL > time.Duration(0) && roleEntry.TTL < shortestTTL { |  | ||||||
| 			shortestTTL = roleEntry.TTL |  | ||||||
| 		} | 		} | ||||||
| 		if shortestMaxTTL < shortestTTL { |  | ||||||
| 			resp.AddWarning(fmt.Sprintf("Effective ttl of %q exceeded the effective max_ttl of %q; ttl value is capped appropriately", (shortestTTL / time.Second).String(), (shortestMaxTTL / time.Second).String())) |  | ||||||
| 			shortestTTL = shortestMaxTTL |  | ||||||
| 		} |  | ||||||
| 		resp.Auth.TTL = shortestTTL |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return resp, nil | 	return resp, nil | ||||||
|  |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // handleRoleTagLogin is used to fetch the role tag of the instance and | // handleRoleTagLogin is used to fetch the role tag of the instance and | ||||||
| @@ -985,13 +977,12 @@ func (b *backend) pathLoginRenewIam( | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// If 'Period' is set on the role, then the token should never expire. | 	resp, err := framework.LeaseExtend(roleEntry.TTL, roleEntry.MaxTTL, b.System())(req, data) | ||||||
| 	if roleEntry.Period > time.Duration(0) { | 	if err != nil { | ||||||
| 		req.Auth.TTL = roleEntry.Period | 		return nil, err | ||||||
| 		return &logical.Response{Auth: req.Auth}, nil |  | ||||||
| 	} else { |  | ||||||
| 		return framework.LeaseExtend(roleEntry.TTL, roleEntry.MaxTTL, b.System())(req, data) |  | ||||||
| 	} | 	} | ||||||
|  | 	resp.Auth.Period = roleEntry.Period | ||||||
|  | 	return resp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *backend) pathLoginRenewEc2( | func (b *backend) pathLoginRenewEc2( | ||||||
| @@ -1072,24 +1063,12 @@ func (b *backend) pathLoginRenewEc2( | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// If 'Period' is set on the role, then the token should never expire. Role | 	resp, err := framework.LeaseExtend(roleEntry.TTL, shortestMaxTTL, b.System())(req, data) | ||||||
| 	// tag does not have a 'Period' field. So, regarless of whether the token | 	if err != nil { | ||||||
| 	// was issued using a role login or a role tag login, the period set on the | 		return nil, err | ||||||
| 	// role should take effect. |  | ||||||
| 	if roleEntry.Period > time.Duration(0) { |  | ||||||
| 		req.Auth.TTL = roleEntry.Period |  | ||||||
| 		return &logical.Response{Auth: req.Auth}, nil |  | ||||||
| 	} else { |  | ||||||
| 		// Cap the TTL value |  | ||||||
| 		shortestTTL := b.System().DefaultLeaseTTL() |  | ||||||
| 		if roleEntry.TTL > time.Duration(0) && roleEntry.TTL < shortestTTL { |  | ||||||
| 			shortestTTL = roleEntry.TTL |  | ||||||
| 		} |  | ||||||
| 		if shortestMaxTTL < shortestTTL { |  | ||||||
| 			shortestTTL = shortestMaxTTL |  | ||||||
| 		} |  | ||||||
| 		return framework.LeaseExtend(shortestTTL, shortestMaxTTL, b.System())(req, data) |  | ||||||
| 	} | 	} | ||||||
|  | 	resp.Auth.Period = roleEntry.Period | ||||||
|  | 	return resp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *backend) pathLoginUpdateIam( | func (b *backend) pathLoginUpdateIam( | ||||||
| @@ -1238,7 +1217,7 @@ func (b *backend) pathLoginUpdateIam( | |||||||
| 	policies := roleEntry.Policies | 	policies := roleEntry.Policies | ||||||
|  |  | ||||||
| 	inferredEntityType := "" | 	inferredEntityType := "" | ||||||
| 	inferredEntityId := "" | 	inferredEntityID := "" | ||||||
| 	if roleEntry.InferredEntityType == ec2EntityType { | 	if roleEntry.InferredEntityType == ec2EntityType { | ||||||
| 		instance, err := b.validateInstance(req.Storage, entity.SessionInfo, roleEntry.InferredAWSRegion, callerID.Account) | 		instance, err := b.validateInstance(req.Storage, entity.SessionInfo, roleEntry.InferredAWSRegion, callerID.Account) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @@ -1264,7 +1243,7 @@ func (b *backend) pathLoginUpdateIam( | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		inferredEntityType = ec2EntityType | 		inferredEntityType = ec2EntityType | ||||||
| 		inferredEntityId = entity.SessionInfo | 		inferredEntityID = entity.SessionInfo | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	resp := &logical.Response{ | 	resp := &logical.Response{ | ||||||
| @@ -1277,7 +1256,7 @@ func (b *backend) pathLoginUpdateIam( | |||||||
| 				"client_user_id":       callerUniqueId, | 				"client_user_id":       callerUniqueId, | ||||||
| 				"auth_type":            iamAuthType, | 				"auth_type":            iamAuthType, | ||||||
| 				"inferred_entity_type": inferredEntityType, | 				"inferred_entity_type": inferredEntityType, | ||||||
| 				"inferred_entity_id":   inferredEntityId, | 				"inferred_entity_id":   inferredEntityID, | ||||||
| 				"inferred_aws_region":  roleEntry.InferredAWSRegion, | 				"inferred_aws_region":  roleEntry.InferredAWSRegion, | ||||||
| 				"account_id":           entity.AccountNumber, | 				"account_id":           entity.AccountNumber, | ||||||
| 			}, | 			}, | ||||||
| @@ -1295,25 +1274,18 @@ func (b *backend) pathLoginUpdateIam( | |||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if roleEntry.Period > time.Duration(0) { | 	if roleEntry.MaxTTL > time.Duration(0) { | ||||||
| 		resp.Auth.TTL = roleEntry.Period | 		// Cap maxTTL to the sysview's max TTL | ||||||
| 	} else { | 		maxTTL := roleEntry.MaxTTL | ||||||
| 		shortestTTL := b.System().DefaultLeaseTTL() | 		if maxTTL > b.System().MaxLeaseTTL() { | ||||||
| 		if roleEntry.TTL > time.Duration(0) && roleEntry.TTL < shortestTTL { | 			maxTTL = b.System().MaxLeaseTTL() | ||||||
| 			shortestTTL = roleEntry.TTL |  | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		maxTTL := b.System().MaxLeaseTTL() | 		// Cap TTL to MaxTTL | ||||||
| 		if roleEntry.MaxTTL > time.Duration(0) && roleEntry.MaxTTL < maxTTL { | 		if resp.Auth.TTL > maxTTL { | ||||||
| 			maxTTL = roleEntry.MaxTTL | 			resp.AddWarning(fmt.Sprintf("Effective TTL of '%s' exceeded the effective max_ttl of '%s'; TTL value is capped accordingly", (resp.Auth.TTL / time.Second), (maxTTL / time.Second))) | ||||||
|  | 			resp.Auth.TTL = maxTTL | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		if shortestTTL > maxTTL { |  | ||||||
| 			resp.AddWarning(fmt.Sprintf("Effective TTL of %q exceeded the effective max_ttl of %q; TTL value is capped accordingly", (shortestTTL / time.Second).String(), (maxTTL / time.Second).String())) |  | ||||||
| 			shortestTTL = maxTTL |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		resp.Auth.TTL = shortestTTL |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return resp, nil | 	return resp, nil | ||||||
| @@ -1333,11 +1305,11 @@ func hasValuesForEc2Auth(data *framework.FieldData) (bool, bool) { | |||||||
|  |  | ||||||
| func hasValuesForIamAuth(data *framework.FieldData) (bool, bool) { | func hasValuesForIamAuth(data *framework.FieldData) (bool, bool) { | ||||||
| 	_, hasRequestMethod := data.GetOk("iam_http_request_method") | 	_, hasRequestMethod := data.GetOk("iam_http_request_method") | ||||||
| 	_, hasRequestUrl := data.GetOk("iam_request_url") | 	_, hasRequestURL := data.GetOk("iam_request_url") | ||||||
| 	_, hasRequestBody := data.GetOk("iam_request_body") | 	_, hasRequestBody := data.GetOk("iam_request_body") | ||||||
| 	_, hasRequestHeaders := data.GetOk("iam_request_headers") | 	_, hasRequestHeaders := data.GetOk("iam_request_headers") | ||||||
| 	return (hasRequestMethod && hasRequestUrl && hasRequestBody && hasRequestHeaders), | 	return (hasRequestMethod && hasRequestURL && hasRequestBody && hasRequestHeaders), | ||||||
| 		(hasRequestMethod || hasRequestUrl || hasRequestBody || hasRequestHeaders) | 		(hasRequestMethod || hasRequestURL || hasRequestBody || hasRequestHeaders) | ||||||
| } | } | ||||||
|  |  | ||||||
| func parseIamArn(iamArn string) (*iamEntity, error) { | func parseIamArn(iamArn string) (*iamEntity, error) { | ||||||
|   | |||||||
| @@ -663,6 +663,10 @@ func (b *backend) pathRoleCreateUpdate( | |||||||
| 		roleEntry.AllowInstanceMigration = data.Get("allow_instance_migration").(bool) | 		roleEntry.AllowInstanceMigration = data.Get("allow_instance_migration").(bool) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if roleEntry.AllowInstanceMigration && roleEntry.DisallowReauthentication { | ||||||
|  | 		return logical.ErrorResponse("cannot specify both disallow_reauthentication=true and allow_instance_migration=true"), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	var resp logical.Response | 	var resp logical.Response | ||||||
|  |  | ||||||
| 	ttlRaw, ok := data.GetOk("ttl") | 	ttlRaw, ok := data.GetOk("ttl") | ||||||
|   | |||||||
| @@ -124,6 +124,10 @@ func (b *backend) pathRoleTagUpdate( | |||||||
| 		resp.AddWarning("Role does not allow instance migration. Login will not be allowed with this tag unless the role value is updated.") | 		resp.AddWarning("Role does not allow instance migration. Login will not be allowed with this tag unless the role value is updated.") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if disallowReauthentication && allowInstanceMigration { | ||||||
|  | 		return logical.ErrorResponse("cannot set both disallow_reauthentication and allow_instance_migration"), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// max_ttl for the role tag should be less than the max_ttl set on the role. | 	// max_ttl for the role tag should be less than the max_ttl set on the role. | ||||||
| 	maxTTL := time.Duration(data.Get("max_ttl").(int)) * time.Second | 	maxTTL := time.Duration(data.Get("max_ttl").(int)) * time.Second | ||||||
|  |  | ||||||
|   | |||||||
| @@ -66,12 +66,25 @@ func TestBackend_pathRoleEc2(t *testing.T) { | |||||||
| 		Data:      data, | 		Data:      data, | ||||||
| 		Storage:   storage, | 		Storage:   storage, | ||||||
| 	}) | 	}) | ||||||
| 	if resp != nil && resp.IsError() { |  | ||||||
| 		t.Fatalf("failed to create role: %s", resp.Data["error"]) |  | ||||||
| 	} |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
|  | 	if resp == nil || !resp.IsError() { | ||||||
|  | 		t.Fatalf("expected failure to create role with both allow_instance_migration true and disallow_reauthentication true") | ||||||
|  | 	} | ||||||
|  | 	data["disallow_reauthentication"] = false | ||||||
|  | 	resp, err = b.HandleRequest(&logical.Request{ | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Path:      "role/ami-abcd123", | ||||||
|  | 		Data:      data, | ||||||
|  | 		Storage:   storage, | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if resp != nil && resp.IsError() { | ||||||
|  | 		t.Fatalf("failure to update role: %v", resp.Data["error"]) | ||||||
|  | 	} | ||||||
| 	resp, err = b.HandleRequest(&logical.Request{ | 	resp, err = b.HandleRequest(&logical.Request{ | ||||||
| 		Operation: logical.ReadOperation, | 		Operation: logical.ReadOperation, | ||||||
| 		Path:      "role/ami-abcd123", | 		Path:      "role/ami-abcd123", | ||||||
| @@ -80,8 +93,12 @@ func TestBackend_pathRoleEc2(t *testing.T) { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| 	if !resp.Data["allow_instance_migration"].(bool) || !resp.Data["disallow_reauthentication"].(bool) { | 	if !resp.Data["allow_instance_migration"].(bool) { | ||||||
| 		t.Fatal("bad: expected:true got:false\n") | 		t.Fatal("bad: expected allow_instance_migration:true got:false\n") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if resp.Data["disallow_reauthentication"].(bool) { | ||||||
|  | 		t.Fatal("bad: expected disallow_reauthentication: false got:true\n") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// add another entry, to test listing of role entries | 	// add another entry, to test listing of role entries | ||||||
| @@ -529,7 +546,7 @@ func TestAwsEc2_RoleCrud(t *testing.T) { | |||||||
| 		"ttl":                       "10m", | 		"ttl":                       "10m", | ||||||
| 		"max_ttl":                   "20m", | 		"max_ttl":                   "20m", | ||||||
| 		"policies":                  "testpolicy1,testpolicy2", | 		"policies":                  "testpolicy1,testpolicy2", | ||||||
| 		"disallow_reauthentication": true, | 		"disallow_reauthentication": false, | ||||||
| 		"hmac_key":                  "testhmackey", | 		"hmac_key":                  "testhmackey", | ||||||
| 		"period":                    "1m", | 		"period":                    "1m", | ||||||
| 	} | 	} | ||||||
| @@ -567,7 +584,7 @@ func TestAwsEc2_RoleCrud(t *testing.T) { | |||||||
| 		"ttl":                       time.Duration(600), | 		"ttl":                       time.Duration(600), | ||||||
| 		"max_ttl":                   time.Duration(1200), | 		"max_ttl":                   time.Duration(1200), | ||||||
| 		"policies":                  []string{"testpolicy1", "testpolicy2"}, | 		"policies":                  []string{"testpolicy1", "testpolicy2"}, | ||||||
| 		"disallow_reauthentication": true, | 		"disallow_reauthentication": false, | ||||||
| 		"period":                    time.Duration(60), | 		"period":                    time.Duration(60), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -587,7 +587,7 @@ func TestBackend_CRLs(t *testing.T) { | |||||||
| func testFactory(t *testing.T) logical.Backend { | func testFactory(t *testing.T) logical.Backend { | ||||||
| 	b, err := Factory(&logical.BackendConfig{ | 	b, err := Factory(&logical.BackendConfig{ | ||||||
| 		System: &logical.StaticSystemView{ | 		System: &logical.StaticSystemView{ | ||||||
| 			DefaultLeaseTTLVal: 300 * time.Second, | 			DefaultLeaseTTLVal: 1000 * time.Second, | ||||||
| 			MaxLeaseTTLVal:     1800 * time.Second, | 			MaxLeaseTTLVal:     1800 * time.Second, | ||||||
| 		}, | 		}, | ||||||
| 		StorageView: &logical.InmemStorage{}, | 		StorageView: &logical.InmemStorage{}, | ||||||
| @@ -619,9 +619,9 @@ func TestBackend_CertWrites(t *testing.T) { | |||||||
| 	tc := logicaltest.TestCase{ | 	tc := logicaltest.TestCase{ | ||||||
| 		Backend: testFactory(t), | 		Backend: testFactory(t), | ||||||
| 		Steps: []logicaltest.TestStep{ | 		Steps: []logicaltest.TestStep{ | ||||||
| 			testAccStepCert(t, "aaa", ca1, "foo", "", false), | 			testAccStepCert(t, "aaa", ca1, "foo", "", "", false), | ||||||
| 			testAccStepCert(t, "bbb", ca2, "foo", "", false), | 			testAccStepCert(t, "bbb", ca2, "foo", "", "", false), | ||||||
| 			testAccStepCert(t, "ccc", ca3, "foo", "", true), | 			testAccStepCert(t, "ccc", ca3, "foo", "", "", true), | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 	tc.Steps = append(tc.Steps, testAccStepListCerts(t, []string{"aaa", "bbb"})...) | 	tc.Steps = append(tc.Steps, testAccStepListCerts(t, []string{"aaa", "bbb"})...) | ||||||
| @@ -642,16 +642,18 @@ func TestBackend_basic_CA(t *testing.T) { | |||||||
| 	logicaltest.Test(t, logicaltest.TestCase{ | 	logicaltest.Test(t, logicaltest.TestCase{ | ||||||
| 		Backend: testFactory(t), | 		Backend: testFactory(t), | ||||||
| 		Steps: []logicaltest.TestStep{ | 		Steps: []logicaltest.TestStep{ | ||||||
| 			testAccStepCert(t, "web", ca, "foo", "", false), | 			testAccStepCert(t, "web", ca, "foo", "", "", false), | ||||||
| 			testAccStepLogin(t, connState), | 			testAccStepLogin(t, connState), | ||||||
| 			testAccStepCertLease(t, "web", ca, "foo"), | 			testAccStepCertLease(t, "web", ca, "foo"), | ||||||
| 			testAccStepCertTTL(t, "web", ca, "foo"), | 			testAccStepCertTTL(t, "web", ca, "foo"), | ||||||
| 			testAccStepLogin(t, connState), | 			testAccStepLogin(t, connState), | ||||||
|  | 			testAccStepCertMaxTTL(t, "web", ca, "foo"), | ||||||
|  | 			testAccStepLogin(t, connState), | ||||||
| 			testAccStepCertNoLease(t, "web", ca, "foo"), | 			testAccStepCertNoLease(t, "web", ca, "foo"), | ||||||
| 			testAccStepLoginDefaultLease(t, connState), | 			testAccStepLoginDefaultLease(t, connState), | ||||||
| 			testAccStepCert(t, "web", ca, "foo", "*.example.com", false), | 			testAccStepCert(t, "web", ca, "foo", "*.example.com", "", false), | ||||||
| 			testAccStepLogin(t, connState), | 			testAccStepLogin(t, connState), | ||||||
| 			testAccStepCert(t, "web", ca, "foo", "*.invalid.com", false), | 			testAccStepCert(t, "web", ca, "foo", "*.invalid.com", "", false), | ||||||
| 			testAccStepLoginInvalid(t, connState), | 			testAccStepLoginInvalid(t, connState), | ||||||
| 		}, | 		}, | ||||||
| 	}) | 	}) | ||||||
| @@ -700,11 +702,68 @@ func TestBackend_basic_singleCert(t *testing.T) { | |||||||
| 	logicaltest.Test(t, logicaltest.TestCase{ | 	logicaltest.Test(t, logicaltest.TestCase{ | ||||||
| 		Backend: testFactory(t), | 		Backend: testFactory(t), | ||||||
| 		Steps: []logicaltest.TestStep{ | 		Steps: []logicaltest.TestStep{ | ||||||
| 			testAccStepCert(t, "web", ca, "foo", "", false), | 			testAccStepCert(t, "web", ca, "foo", "", "", false), | ||||||
| 			testAccStepLogin(t, connState), | 			testAccStepLogin(t, connState), | ||||||
| 			testAccStepCert(t, "web", ca, "foo", "example.com", false), | 			testAccStepCert(t, "web", ca, "foo", "example.com", "", false), | ||||||
| 			testAccStepLogin(t, connState), | 			testAccStepLogin(t, connState), | ||||||
| 			testAccStepCert(t, "web", ca, "foo", "invalid", false), | 			testAccStepCert(t, "web", ca, "foo", "invalid", "", false), | ||||||
|  | 			testAccStepLoginInvalid(t, connState), | ||||||
|  | 			testAccStepCert(t, "web", ca, "foo", "", "1.2.3.4:invalid", false), | ||||||
|  | 			testAccStepLoginInvalid(t, connState), | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Test a self-signed client with custom extensions (root CA) that is trusted | ||||||
|  | func TestBackend_extensions_singleCert(t *testing.T) { | ||||||
|  | 	connState, err := testConnState( | ||||||
|  | 		"test-fixtures/root/rootcawextcert.pem", | ||||||
|  | 		"test-fixtures/root/rootcawextkey.pem", | ||||||
|  | 		"test-fixtures/root/rootcacert.pem", | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("error testing connection state: %v", err) | ||||||
|  | 	} | ||||||
|  | 	ca, err := ioutil.ReadFile("test-fixtures/root/rootcacert.pem") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("err: %v", err) | ||||||
|  | 	} | ||||||
|  | 	logicaltest.Test(t, logicaltest.TestCase{ | ||||||
|  | 		Backend: testFactory(t), | ||||||
|  | 		Steps: []logicaltest.TestStep{ | ||||||
|  | 			testAccStepCert(t, "web", ca, "foo", "", "2.1.1.1:A UTF8String Extension", false), | ||||||
|  | 			testAccStepLogin(t, connState), | ||||||
|  | 			testAccStepCert(t, "web", ca, "foo", "", "2.1.1.1:*,2.1.1.2:A UTF8*", false), | ||||||
|  | 			testAccStepLogin(t, connState), | ||||||
|  | 			testAccStepCert(t, "web", ca, "foo", "", "1.2.3.45:*", false), | ||||||
|  | 			testAccStepLoginInvalid(t, connState), | ||||||
|  | 			testAccStepCert(t, "web", ca, "foo", "", "2.1.1.1:The Wrong Value", false), | ||||||
|  | 			testAccStepLoginInvalid(t, connState), | ||||||
|  | 			testAccStepCert(t, "web", ca, "foo", "", "2.1.1.1:*,2.1.1.2:The Wrong Value", false), | ||||||
|  | 			testAccStepLoginInvalid(t, connState), | ||||||
|  | 			testAccStepCert(t, "web", ca, "foo", "", "2.1.1.1:", false), | ||||||
|  | 			testAccStepLoginInvalid(t, connState), | ||||||
|  | 			testAccStepCert(t, "web", ca, "foo", "", "2.1.1.1:,2.1.1.2:*", false), | ||||||
|  | 			testAccStepLoginInvalid(t, connState), | ||||||
|  | 			testAccStepCert(t, "web", ca, "foo", "example.com", "2.1.1.1:A UTF8String Extension", false), | ||||||
|  | 			testAccStepLogin(t, connState), | ||||||
|  | 			testAccStepCert(t, "web", ca, "foo", "example.com", "2.1.1.1:*,2.1.1.2:A UTF8*", false), | ||||||
|  | 			testAccStepLogin(t, connState), | ||||||
|  | 			testAccStepCert(t, "web", ca, "foo", "example.com", "1.2.3.45:*", false), | ||||||
|  | 			testAccStepLoginInvalid(t, connState), | ||||||
|  | 			testAccStepCert(t, "web", ca, "foo", "example.com", "2.1.1.1:The Wrong Value", false), | ||||||
|  | 			testAccStepLoginInvalid(t, connState), | ||||||
|  | 			testAccStepCert(t, "web", ca, "foo", "example.com", "2.1.1.1:*,2.1.1.2:The Wrong Value", false), | ||||||
|  | 			testAccStepLoginInvalid(t, connState), | ||||||
|  | 			testAccStepCert(t, "web", ca, "foo", "invalid", "2.1.1.1:A UTF8String Extension", false), | ||||||
|  | 			testAccStepLoginInvalid(t, connState), | ||||||
|  | 			testAccStepCert(t, "web", ca, "foo", "invalid", "2.1.1.1:*,2.1.1.2:A UTF8*", false), | ||||||
|  | 			testAccStepLoginInvalid(t, connState), | ||||||
|  | 			testAccStepCert(t, "web", ca, "foo", "invalid", "1.2.3.45:*", false), | ||||||
|  | 			testAccStepLoginInvalid(t, connState), | ||||||
|  | 			testAccStepCert(t, "web", ca, "foo", "invalid", "2.1.1.1:The Wrong Value", false), | ||||||
|  | 			testAccStepLoginInvalid(t, connState), | ||||||
|  | 			testAccStepCert(t, "web", ca, "foo", "invalid", "2.1.1.1:*,2.1.1.2:The Wrong Value", false), | ||||||
| 			testAccStepLoginInvalid(t, connState), | 			testAccStepLoginInvalid(t, connState), | ||||||
| 		}, | 		}, | ||||||
| 	}) | 	}) | ||||||
| @@ -724,9 +783,9 @@ func TestBackend_mixed_constraints(t *testing.T) { | |||||||
| 	logicaltest.Test(t, logicaltest.TestCase{ | 	logicaltest.Test(t, logicaltest.TestCase{ | ||||||
| 		Backend: testFactory(t), | 		Backend: testFactory(t), | ||||||
| 		Steps: []logicaltest.TestStep{ | 		Steps: []logicaltest.TestStep{ | ||||||
| 			testAccStepCert(t, "1unconstrained", ca, "foo", "", false), | 			testAccStepCert(t, "1unconstrained", ca, "foo", "", "", false), | ||||||
| 			testAccStepCert(t, "2matching", ca, "foo", "*.example.com,whatever", false), | 			testAccStepCert(t, "2matching", ca, "foo", "*.example.com,whatever", "", false), | ||||||
| 			testAccStepCert(t, "3invalid", ca, "foo", "invalid", false), | 			testAccStepCert(t, "3invalid", ca, "foo", "invalid", "", false), | ||||||
| 			testAccStepLogin(t, connState), | 			testAccStepLogin(t, connState), | ||||||
| 			// Assumes CertEntries are processed in alphabetical order (due to store.List), so we only match 2matching if 1unconstrained doesn't match | 			// Assumes CertEntries are processed in alphabetical order (due to store.List), so we only match 2matching if 1unconstrained doesn't match | ||||||
| 			testAccStepLoginWithName(t, connState, "2matching"), | 			testAccStepLoginWithName(t, connState, "2matching"), | ||||||
| @@ -826,7 +885,7 @@ func testAccStepLoginDefaultLease(t *testing.T, connState tls.ConnectionState) l | |||||||
| 		Unauthenticated: true, | 		Unauthenticated: true, | ||||||
| 		ConnState:       &connState, | 		ConnState:       &connState, | ||||||
| 		Check: func(resp *logical.Response) error { | 		Check: func(resp *logical.Response) error { | ||||||
| 			if resp.Auth.TTL != 300*time.Second { | 			if resp.Auth.TTL != 1000*time.Second { | ||||||
| 				t.Fatalf("bad lease length: %#v", resp.Auth) | 				t.Fatalf("bad lease length: %#v", resp.Auth) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| @@ -906,17 +965,18 @@ func testAccStepListCerts( | |||||||
| } | } | ||||||
|  |  | ||||||
| func testAccStepCert( | func testAccStepCert( | ||||||
| 	t *testing.T, name string, cert []byte, policies string, allowedNames string, expectError bool) logicaltest.TestStep { | 	t *testing.T, name string, cert []byte, policies string, allowedNames string, requiredExtensions string, expectError bool) logicaltest.TestStep { | ||||||
| 	return logicaltest.TestStep{ | 	return logicaltest.TestStep{ | ||||||
| 		Operation: logical.UpdateOperation, | 		Operation: logical.UpdateOperation, | ||||||
| 		Path:      "certs/" + name, | 		Path:      "certs/" + name, | ||||||
| 		ErrorOk:   expectError, | 		ErrorOk:   expectError, | ||||||
| 		Data: map[string]interface{}{ | 		Data: map[string]interface{}{ | ||||||
| 			"certificate":   string(cert), | 			"certificate":         string(cert), | ||||||
| 			"policies":      policies, | 			"policies":            policies, | ||||||
| 			"display_name":  name, | 			"display_name":        name, | ||||||
| 			"allowed_names": allowedNames, | 			"allowed_names":       allowedNames, | ||||||
| 			"lease":         1000, | 			"required_extensions": requiredExtensions, | ||||||
|  | 			"lease":               1000, | ||||||
| 		}, | 		}, | ||||||
| 		Check: func(resp *logical.Response) error { | 		Check: func(resp *logical.Response) error { | ||||||
| 			if resp == nil && expectError { | 			if resp == nil && expectError { | ||||||
| @@ -955,6 +1015,21 @@ func testAccStepCertTTL( | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func testAccStepCertMaxTTL( | ||||||
|  | 	t *testing.T, name string, cert []byte, policies string) logicaltest.TestStep { | ||||||
|  | 	return logicaltest.TestStep{ | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Path:      "certs/" + name, | ||||||
|  | 		Data: map[string]interface{}{ | ||||||
|  | 			"certificate":  string(cert), | ||||||
|  | 			"policies":     policies, | ||||||
|  | 			"display_name": name, | ||||||
|  | 			"ttl":          "1000s", | ||||||
|  | 			"max_ttl":      "1200s", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func testAccStepCertNoLease( | func testAccStepCertNoLease( | ||||||
| 	t *testing.T, name string, cert []byte, policies string) logicaltest.TestStep { | 	t *testing.T, name string, cert []byte, policies string) logicaltest.TestStep { | ||||||
| 	return logicaltest.TestStep{ | 	return logicaltest.TestStep{ | ||||||
|   | |||||||
| @@ -45,6 +45,13 @@ Must be x509 PEM encoded.`, | |||||||
| At least one must exist in either the Common Name or SANs. Supports globbing.`, | At least one must exist in either the Common Name or SANs. Supports globbing.`, | ||||||
| 			}, | 			}, | ||||||
|  |  | ||||||
|  | 			"required_extensions": &framework.FieldSchema{ | ||||||
|  | 				Type: framework.TypeCommaStringSlice, | ||||||
|  | 				Description: `A comma-separated string or array of extensions | ||||||
|  | formatted as "oid:value". Expects the extension value to be some type of ASN1 encoded string. | ||||||
|  | All values much match. Supports globbing on "value".`, | ||||||
|  | 			}, | ||||||
|  |  | ||||||
| 			"display_name": &framework.FieldSchema{ | 			"display_name": &framework.FieldSchema{ | ||||||
| 				Type: framework.TypeString, | 				Type: framework.TypeString, | ||||||
| 				Description: `The display name to use for clients using this | 				Description: `The display name to use for clients using this | ||||||
| @@ -67,6 +74,19 @@ seconds. Defaults to system/backend default TTL.`, | |||||||
| 				Description: `TTL for tokens issued by this backend. | 				Description: `TTL for tokens issued by this backend. | ||||||
| Defaults to system/backend default TTL time.`, | Defaults to system/backend default TTL time.`, | ||||||
| 			}, | 			}, | ||||||
|  | 			"max_ttl": &framework.FieldSchema{ | ||||||
|  | 				Type: framework.TypeDurationSecond, | ||||||
|  | 				Description: `Duration in either an integer number of seconds (3600) or | ||||||
|  | an integer time unit (60m) after which the  | ||||||
|  | issued token can no longer be renewed.`, | ||||||
|  | 			}, | ||||||
|  | 			"period": &framework.FieldSchema{ | ||||||
|  | 				Type: framework.TypeDurationSecond, | ||||||
|  | 				Description: `If set, indicates that the token generated using this role | ||||||
|  | should never expire. The token should be renewed within the | ||||||
|  | duration specified by this value. At each renewal, the token's | ||||||
|  | TTL will be set to the value of this parameter.`, | ||||||
|  | 			}, | ||||||
| 		}, | 		}, | ||||||
|  |  | ||||||
| 		Callbacks: map[logical.Operation]framework.OperationFunc{ | 		Callbacks: map[logical.Operation]framework.OperationFunc{ | ||||||
| @@ -124,17 +144,14 @@ func (b *backend) pathCertRead( | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	duration := cert.TTL |  | ||||||
| 	if duration == 0 { |  | ||||||
| 		duration = b.System().DefaultLeaseTTL() |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return &logical.Response{ | 	return &logical.Response{ | ||||||
| 		Data: map[string]interface{}{ | 		Data: map[string]interface{}{ | ||||||
| 			"certificate":  cert.Certificate, | 			"certificate":  cert.Certificate, | ||||||
| 			"display_name": cert.DisplayName, | 			"display_name": cert.DisplayName, | ||||||
| 			"policies":     cert.Policies, | 			"policies":     cert.Policies, | ||||||
| 			"ttl":          duration / time.Second, | 			"ttl":          cert.TTL / time.Second, | ||||||
|  | 			"max_ttl":      cert.MaxTTL / time.Second, | ||||||
|  | 			"period":       cert.Period / time.Second, | ||||||
| 		}, | 		}, | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
| @@ -146,6 +163,48 @@ func (b *backend) pathCertWrite( | |||||||
| 	displayName := d.Get("display_name").(string) | 	displayName := d.Get("display_name").(string) | ||||||
| 	policies := policyutil.ParsePolicies(d.Get("policies")) | 	policies := policyutil.ParsePolicies(d.Get("policies")) | ||||||
| 	allowedNames := d.Get("allowed_names").([]string) | 	allowedNames := d.Get("allowed_names").([]string) | ||||||
|  | 	requiredExtensions := d.Get("required_extensions").([]string) | ||||||
|  |  | ||||||
|  | 	var resp logical.Response | ||||||
|  |  | ||||||
|  | 	// Parse the ttl (or lease duration) | ||||||
|  | 	systemDefaultTTL := b.System().DefaultLeaseTTL() | ||||||
|  | 	ttl := time.Duration(d.Get("ttl").(int)) * time.Second | ||||||
|  | 	if ttl == 0 { | ||||||
|  | 		ttl = time.Duration(d.Get("lease").(int)) * time.Second | ||||||
|  | 	} | ||||||
|  | 	if ttl > systemDefaultTTL { | ||||||
|  | 		resp.AddWarning(fmt.Sprintf("Given ttl of %d seconds is greater than current mount/system default of %d seconds", ttl/time.Second, systemDefaultTTL/time.Second)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if ttl < time.Duration(0) { | ||||||
|  | 		return logical.ErrorResponse("ttl cannot be negative"), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Parse max_ttl | ||||||
|  | 	systemMaxTTL := b.System().MaxLeaseTTL() | ||||||
|  | 	maxTTL := time.Duration(d.Get("max_ttl").(int)) * time.Second | ||||||
|  | 	if maxTTL > systemMaxTTL { | ||||||
|  | 		resp.AddWarning(fmt.Sprintf("Given max_ttl of %d seconds is greater than current mount/system default of %d seconds", maxTTL/time.Second, systemMaxTTL/time.Second)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if maxTTL < time.Duration(0) { | ||||||
|  | 		return logical.ErrorResponse("max_ttl cannot be negative"), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if maxTTL != 0 && ttl > maxTTL { | ||||||
|  | 		return logical.ErrorResponse("ttl should be shorter than max_ttl"), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Parse period | ||||||
|  | 	period := time.Duration(d.Get("period").(int)) * time.Second | ||||||
|  | 	if period > systemMaxTTL { | ||||||
|  | 		resp.AddWarning(fmt.Sprintf("Given period of %d seconds is greater than the backend's maximum TTL of %d seconds", period/time.Second, systemMaxTTL/time.Second)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if period < time.Duration(0) { | ||||||
|  | 		return logical.ErrorResponse("period cannot be negative"), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// Default the display name to the certificate name if not given | 	// Default the display name to the certificate name if not given | ||||||
| 	if displayName == "" { | 	if displayName == "" { | ||||||
| @@ -172,24 +231,15 @@ func (b *backend) pathCertWrite( | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	certEntry := &CertEntry{ | 	certEntry := &CertEntry{ | ||||||
| 		Name:         name, | 		Name:               name, | ||||||
| 		Certificate:  certificate, | 		Certificate:        certificate, | ||||||
| 		DisplayName:  displayName, | 		DisplayName:        displayName, | ||||||
| 		Policies:     policies, | 		Policies:           policies, | ||||||
| 		AllowedNames: allowedNames, | 		AllowedNames:       allowedNames, | ||||||
| 	} | 		RequiredExtensions: requiredExtensions, | ||||||
|  | 		TTL:                ttl, | ||||||
| 	// Parse the lease duration or default to backend/system default | 		MaxTTL:             maxTTL, | ||||||
| 	maxTTL := b.System().MaxLeaseTTL() | 		Period:             period, | ||||||
| 	ttl := time.Duration(d.Get("ttl").(int)) * time.Second |  | ||||||
| 	if ttl == time.Duration(0) { |  | ||||||
| 		ttl = time.Second * time.Duration(d.Get("lease").(int)) |  | ||||||
| 	} |  | ||||||
| 	if ttl > maxTTL { |  | ||||||
| 		return logical.ErrorResponse(fmt.Sprintf("Given TTL of %d seconds greater than current mount/system default of %d seconds", ttl/time.Second, maxTTL/time.Second)), nil |  | ||||||
| 	} |  | ||||||
| 	if ttl > time.Duration(0) { |  | ||||||
| 		certEntry.TTL = ttl |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Store it | 	// Store it | ||||||
| @@ -200,16 +250,24 @@ func (b *backend) pathCertWrite( | |||||||
| 	if err := req.Storage.Put(entry); err != nil { | 	if err := req.Storage.Put(entry); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	return nil, nil |  | ||||||
|  | 	if len(resp.Warnings) == 0 { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &resp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| type CertEntry struct { | type CertEntry struct { | ||||||
| 	Name         string | 	Name               string | ||||||
| 	Certificate  string | 	Certificate        string | ||||||
| 	DisplayName  string | 	DisplayName        string | ||||||
| 	Policies     []string | 	Policies           []string | ||||||
| 	TTL          time.Duration | 	TTL                time.Duration | ||||||
| 	AllowedNames []string | 	MaxTTL             time.Duration | ||||||
|  | 	Period             time.Duration | ||||||
|  | 	AllowedNames       []string | ||||||
|  | 	RequiredExtensions []string | ||||||
| } | } | ||||||
|  |  | ||||||
| const pathCertHelpSyn = ` | const pathCertHelpSyn = ` | ||||||
|   | |||||||
| @@ -4,11 +4,13 @@ import ( | |||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"crypto/tls" | 	"crypto/tls" | ||||||
| 	"crypto/x509" | 	"crypto/x509" | ||||||
|  | 	"encoding/asn1" | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
| 	"encoding/pem" | 	"encoding/pem" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/vault/helper/certutil" | 	"github.com/hashicorp/vault/helper/certutil" | ||||||
| 	"github.com/hashicorp/vault/helper/policyutil" | 	"github.com/hashicorp/vault/helper/policyutil" | ||||||
| @@ -84,9 +86,9 @@ func (b *backend) pathLogin( | |||||||
| 	skid := base64.StdEncoding.EncodeToString(clientCerts[0].SubjectKeyId) | 	skid := base64.StdEncoding.EncodeToString(clientCerts[0].SubjectKeyId) | ||||||
| 	akid := base64.StdEncoding.EncodeToString(clientCerts[0].AuthorityKeyId) | 	akid := base64.StdEncoding.EncodeToString(clientCerts[0].AuthorityKeyId) | ||||||
|  |  | ||||||
| 	// Generate a response |  | ||||||
| 	resp := &logical.Response{ | 	resp := &logical.Response{ | ||||||
| 		Auth: &logical.Auth{ | 		Auth: &logical.Auth{ | ||||||
|  | 			Period: matched.Entry.Period, | ||||||
| 			InternalData: map[string]interface{}{ | 			InternalData: map[string]interface{}{ | ||||||
| 				"subject_key_id":   skid, | 				"subject_key_id":   skid, | ||||||
| 				"authority_key_id": akid, | 				"authority_key_id": akid, | ||||||
| @@ -108,6 +110,22 @@ func (b *backend) pathLogin( | |||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if matched.Entry.MaxTTL > time.Duration(0) { | ||||||
|  | 		// Cap maxTTL to the sysview's max TTL | ||||||
|  | 		maxTTL := matched.Entry.MaxTTL | ||||||
|  | 		if maxTTL > b.System().MaxLeaseTTL() { | ||||||
|  | 			maxTTL = b.System().MaxLeaseTTL() | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Cap TTL to MaxTTL | ||||||
|  | 		if resp.Auth.TTL > maxTTL { | ||||||
|  | 			resp.AddWarning(fmt.Sprintf("Effective TTL of '%s' exceeded the effective max_ttl of '%s'; TTL value is capped accordingly", (resp.Auth.TTL / time.Second), (maxTTL / time.Second))) | ||||||
|  | 			resp.Auth.TTL = maxTTL | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Generate a response | ||||||
| 	return resp, nil | 	return resp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -134,7 +152,7 @@ func (b *backend) pathLoginRenew( | |||||||
|  |  | ||||||
| 		clientCerts := req.Connection.ConnState.PeerCertificates | 		clientCerts := req.Connection.ConnState.PeerCertificates | ||||||
| 		if len(clientCerts) == 0 { | 		if len(clientCerts) == 0 { | ||||||
| 			return nil, fmt.Errorf("no client certificate found") | 			return logical.ErrorResponse("no client certificate found"), nil | ||||||
| 		} | 		} | ||||||
| 		skid := base64.StdEncoding.EncodeToString(clientCerts[0].SubjectKeyId) | 		skid := base64.StdEncoding.EncodeToString(clientCerts[0].SubjectKeyId) | ||||||
| 		akid := base64.StdEncoding.EncodeToString(clientCerts[0].AuthorityKeyId) | 		akid := base64.StdEncoding.EncodeToString(clientCerts[0].AuthorityKeyId) | ||||||
| @@ -160,7 +178,12 @@ func (b *backend) pathLoginRenew( | |||||||
| 		return nil, fmt.Errorf("policies have changed, not renewing") | 		return nil, fmt.Errorf("policies have changed, not renewing") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return framework.LeaseExtend(cert.TTL, 0, b.System())(req, d) | 	resp, err := framework.LeaseExtend(cert.TTL, cert.MaxTTL, b.System())(req, d) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	resp.Auth.Period = cert.Period | ||||||
|  | 	return resp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *backend) verifyCredentials(req *logical.Request, d *framework.FieldData) (*ParsedCert, *logical.Response, error) { | func (b *backend) verifyCredentials(req *logical.Request, d *framework.FieldData) (*ParsedCert, *logical.Response, error) { | ||||||
| @@ -237,28 +260,70 @@ func (b *backend) verifyCredentials(req *logical.Request, d *framework.FieldData | |||||||
| } | } | ||||||
|  |  | ||||||
| func (b *backend) matchesConstraints(clientCert *x509.Certificate, trustedChain []*x509.Certificate, config *ParsedCert) bool { | func (b *backend) matchesConstraints(clientCert *x509.Certificate, trustedChain []*x509.Certificate, config *ParsedCert) bool { | ||||||
|  | 	return !b.checkForChainInCRLs(trustedChain) && | ||||||
|  | 		b.matchesNames(clientCert, config) && | ||||||
|  | 		b.matchesCertificateExtenions(clientCert, config) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // matchesNames verifies that the certificate matches at least one configured | ||||||
|  | // allowed name | ||||||
|  | func (b *backend) matchesNames(clientCert *x509.Certificate, config *ParsedCert) bool { | ||||||
| 	// Default behavior (no names) is to allow all names | 	// Default behavior (no names) is to allow all names | ||||||
| 	nameMatched := len(config.Entry.AllowedNames) == 0 | 	if len(config.Entry.AllowedNames) == 0 { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
| 	// At least one pattern must match at least one name if any patterns are specified | 	// At least one pattern must match at least one name if any patterns are specified | ||||||
| 	for _, allowedName := range config.Entry.AllowedNames { | 	for _, allowedName := range config.Entry.AllowedNames { | ||||||
| 		if glob.Glob(allowedName, clientCert.Subject.CommonName) { | 		if glob.Glob(allowedName, clientCert.Subject.CommonName) { | ||||||
| 			nameMatched = true | 			return true | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		for _, name := range clientCert.DNSNames { | 		for _, name := range clientCert.DNSNames { | ||||||
| 			if glob.Glob(allowedName, name) { | 			if glob.Glob(allowedName, name) { | ||||||
| 				nameMatched = true | 				return true | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		for _, name := range clientCert.EmailAddresses { | 		for _, name := range clientCert.EmailAddresses { | ||||||
| 			if glob.Glob(allowedName, name) { | 			if glob.Glob(allowedName, name) { | ||||||
| 				nameMatched = true | 				return true | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
| 	return !b.checkForChainInCRLs(trustedChain) && nameMatched | // matchesCertificateExtenions verifies that the certificate matches configured | ||||||
|  | // required extensions | ||||||
|  | func (b *backend) matchesCertificateExtenions(clientCert *x509.Certificate, config *ParsedCert) bool { | ||||||
|  | 	// If no required extensions, nothing to check here | ||||||
|  | 	if len(config.Entry.RequiredExtensions) == 0 { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	// Fail fast if we have required extensions but no extensions on the cert | ||||||
|  | 	if len(clientCert.Extensions) == 0 { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Build Client Extensions Map for Constraint Matching | ||||||
|  | 	// x509 Writes Extensions in ASN1 with a bitstring tag, which results in the field | ||||||
|  | 	// including its ASN.1 type tag bytes. For the sake of simplicity, assume string type | ||||||
|  | 	// and drop the tag bytes. And get the number of bytes from the tag. | ||||||
|  | 	clientExtMap := make(map[string]string, len(clientCert.Extensions)) | ||||||
|  | 	for _, ext := range clientCert.Extensions { | ||||||
|  | 		var parsedValue string | ||||||
|  | 		asn1.Unmarshal(ext.Value, &parsedValue) | ||||||
|  | 		clientExtMap[ext.Id.String()] = parsedValue | ||||||
|  | 	} | ||||||
|  | 	// If any of the required extensions don't match the constraint fails | ||||||
|  | 	for _, requiredExt := range config.Entry.RequiredExtensions { | ||||||
|  | 		reqExt := strings.SplitN(requiredExt, ":", 2) | ||||||
|  | 		clientExtValue, clientExtValueOk := clientExtMap[reqExt[0]] | ||||||
|  | 		if !clientExtValueOk || !glob.Glob(reqExt[1], clientExtValue) { | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return true | ||||||
| } | } | ||||||
|  |  | ||||||
| // loadTrustedCerts is used to load all the trusted certificates from the backend | // loadTrustedCerts is used to load all the trusted certificates from the backend | ||||||
|   | |||||||
| @@ -0,0 +1 @@ | |||||||
|  | 92223EAFBBEE17A3 | ||||||
							
								
								
									
										21
									
								
								builtin/credential/cert/test-fixtures/root/rootcawext.cnf
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								builtin/credential/cert/test-fixtures/root/rootcawext.cnf
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | |||||||
|  | [ req ] | ||||||
|  | default_bits = 2048 | ||||||
|  | encrypt_key = no | ||||||
|  | prompt = no | ||||||
|  | default_md = sha256 | ||||||
|  | req_extensions = req_v3 | ||||||
|  | distinguished_name = dn | ||||||
|  |  | ||||||
|  | [ dn ] | ||||||
|  | CN = example.com | ||||||
|  |  | ||||||
|  | [ req_v3 ] | ||||||
|  | subjectAltName = @alt_names | ||||||
|  | 2.1.1.1=ASN1:UTF8String:A UTF8String Extension | ||||||
|  | 2.1.1.2=ASN1:UTF8:A UTF8 Extension | ||||||
|  | 2.1.1.3=ASN1:IA5:An IA5 Extension | ||||||
|  | 2.1.1.4=ASN1:VISIBLE:A Visible Extension | ||||||
|  |  | ||||||
|  | [ alt_names ] | ||||||
|  | DNS.1 = example.com | ||||||
|  | IP.1 = 127.0.0.1 | ||||||
							
								
								
									
										19
									
								
								builtin/credential/cert/test-fixtures/root/rootcawext.csr
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								builtin/credential/cert/test-fixtures/root/rootcawext.csr
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | |||||||
|  | -----BEGIN CERTIFICATE REQUEST----- | ||||||
|  | MIIDAzCCAesCAQAwFjEUMBIGA1UEAwwLZXhhbXBsZS5jb20wggEiMA0GCSqGSIb3 | ||||||
|  | DQEBAQUAA4IBDwAwggEKAoIBAQDM2PrLyK/wVQIcnK362ZylDrIVMjFQzps/0AxM | ||||||
|  | ke+8MNPMArBlSAhnZus6qb0nN0nJrDLkHQgYqnSvK9N7VUv/xFblEcOLBlciLhyN | ||||||
|  | Wkm92+q/M/xOvUVmnYkN3XgTI5QNxF7ZWDFHmwCNV27RraQZou0hG7yvyoILLMQE | ||||||
|  | 3MnMCNM1nZ9JIuBMcRsZLGqQ1XNaQljboRVIUjimzkcfYyTruhLosTIbwForp78J | ||||||
|  | MzHHqVjtLJXPqUnRMS7KhGMj1f2mIswQzCv6F2PWEzNBbP4Gb67znKikKDs0RgyL | ||||||
|  | RyfizFNFJSC58XntK8jwHK1D8W3UepFf4K8xNFnhPoKWtWfJAgMBAAGggacwgaQG | ||||||
|  | CSqGSIb3DQEJDjGBljCBkzAcBgNVHREEFTATggtleGFtcGxlLmNvbYcEfwAAATAf | ||||||
|  | BgNRAQEEGAwWQSBVVEY4U3RyaW5nIEV4dGVuc2lvbjAZBgNRAQIEEgwQQSBVVEY4 | ||||||
|  | IEV4dGVuc2lvbjAZBgNRAQMEEhYQQW4gSUE1IEV4dGVuc2lvbjAcBgNRAQQEFRoT | ||||||
|  | QSBWaXNpYmxlIEV4dGVuc2lvbjANBgkqhkiG9w0BAQsFAAOCAQEAtYjewBcqAXxk | ||||||
|  | tDY0lpZid6ZvfngdDlDZX0vrs3zNppKNe5Sl+jsoDOexqTA7HQA/y1ru117sAEeB | ||||||
|  | yiqMeZ7oPk8b3w+BZUpab7p2qPMhZypKl93y/jGXGscc3jRbUBnym9S91PSq6wUd | ||||||
|  | f2aigSqFc9+ywFVdx5PnnZUfcrUQ2a+AweYEkGOzXX2Ga+Ige8grDMCzRgCoP5cW | ||||||
|  | kM5ghwZp5wYIBGrKBU9iDcBlmnNhYaGWf+dD00JtVDPNn2bJnCsJHIO0nklZgnrS | ||||||
|  | fli8VQ1nYPkONdkiRYLt6//6at1iNDoDgsVCChtlVkLpxFIKcDFUHlffZsc1kMFI | ||||||
|  | HTX579k8hA== | ||||||
|  | -----END CERTIFICATE REQUEST----- | ||||||
| @@ -0,0 +1,20 @@ | |||||||
|  | -----BEGIN CERTIFICATE----- | ||||||
|  | MIIDRjCCAi6gAwIBAgIJAJIiPq+77hejMA0GCSqGSIb3DQEBCwUAMBYxFDASBgNV | ||||||
|  | BAMTC2V4YW1wbGUuY29tMB4XDTE3MTEyOTE5MTgwM1oXDTI3MTEyNzE5MTgwM1ow | ||||||
|  | FjEUMBIGA1UEAwwLZXhhbXBsZS5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw | ||||||
|  | ggEKAoIBAQDM2PrLyK/wVQIcnK362ZylDrIVMjFQzps/0AxMke+8MNPMArBlSAhn | ||||||
|  | Zus6qb0nN0nJrDLkHQgYqnSvK9N7VUv/xFblEcOLBlciLhyNWkm92+q/M/xOvUVm | ||||||
|  | nYkN3XgTI5QNxF7ZWDFHmwCNV27RraQZou0hG7yvyoILLMQE3MnMCNM1nZ9JIuBM | ||||||
|  | cRsZLGqQ1XNaQljboRVIUjimzkcfYyTruhLosTIbwForp78JMzHHqVjtLJXPqUnR | ||||||
|  | MS7KhGMj1f2mIswQzCv6F2PWEzNBbP4Gb67znKikKDs0RgyLRyfizFNFJSC58Xnt | ||||||
|  | K8jwHK1D8W3UepFf4K8xNFnhPoKWtWfJAgMBAAGjgZYwgZMwHAYDVR0RBBUwE4IL | ||||||
|  | ZXhhbXBsZS5jb22HBH8AAAEwHwYDUQEBBBgMFkEgVVRGOFN0cmluZyBFeHRlbnNp | ||||||
|  | b24wGQYDUQECBBIMEEEgVVRGOCBFeHRlbnNpb24wGQYDUQEDBBIWEEFuIElBNSBF | ||||||
|  | eHRlbnNpb24wHAYDUQEEBBUaE0EgVmlzaWJsZSBFeHRlbnNpb24wDQYJKoZIhvcN | ||||||
|  | AQELBQADggEBAGU/iA6saupEaGn/veVNCknFGDL7pst5D6eX/y9atXlBOdJe7ZJJ | ||||||
|  | XQRkeHJldA0khVpzH7Ryfi+/25WDuNz+XTZqmb4ppeV8g9amtqBwxziQ9UUwYrza | ||||||
|  | eDBqdXBaYp/iHUEHoceX4F44xuo80BIqwF0lD9TFNUFoILnF26ajhKX0xkGaiKTH | ||||||
|  | 6SbjBfHoQVMzOHokVRWregmgNycV+MAI9Ne9XkIZvdOYeNlcS9drZeJI3szkiaxB | ||||||
|  | WWaWaAr5UU2Z0yUCZnAIDMRcIiUbSEjIDz504sSuCzTctMOxWZu0r/0UrXRzwZZi | ||||||
|  | HAaKm3MUmBh733ChP4rTB58nr5DEr5rJ9P8= | ||||||
|  | -----END CERTIFICATE----- | ||||||
							
								
								
									
										28
									
								
								builtin/credential/cert/test-fixtures/root/rootcawextkey.pem
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								builtin/credential/cert/test-fixtures/root/rootcawextkey.pem
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,28 @@ | |||||||
|  | -----BEGIN PRIVATE KEY----- | ||||||
|  | MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDM2PrLyK/wVQIc | ||||||
|  | nK362ZylDrIVMjFQzps/0AxMke+8MNPMArBlSAhnZus6qb0nN0nJrDLkHQgYqnSv | ||||||
|  | K9N7VUv/xFblEcOLBlciLhyNWkm92+q/M/xOvUVmnYkN3XgTI5QNxF7ZWDFHmwCN | ||||||
|  | V27RraQZou0hG7yvyoILLMQE3MnMCNM1nZ9JIuBMcRsZLGqQ1XNaQljboRVIUjim | ||||||
|  | zkcfYyTruhLosTIbwForp78JMzHHqVjtLJXPqUnRMS7KhGMj1f2mIswQzCv6F2PW | ||||||
|  | EzNBbP4Gb67znKikKDs0RgyLRyfizFNFJSC58XntK8jwHK1D8W3UepFf4K8xNFnh | ||||||
|  | PoKWtWfJAgMBAAECggEAW7hLkzMok9N8PpNo0wjcuor58cOnkSbxHIFrAF3XmcvD | ||||||
|  | CXWqxa6bFLFgYcPejdCTmVkg8EKPfXvVAxn8dxyaCss+nRJ3G6ibGxLKdgAXRItT | ||||||
|  | cIk2T4svp+KhmzOur+MeR4vFbEuwxP8CIEclt3yoHVJ2Gnzw30UtNRO2MPcq48/C | ||||||
|  | ZODGeBqUif1EGjDAvlqu5kl/pcDBJ3ctIZdVUMYYW4R9JtzKsmwhX7CRCBm8k5hG | ||||||
|  | 2uzn8AKwpuVtfWcnX59UUmHGJ8mjETuNLARRAwWBWhl8f7wckmi+PKERJGEM2QE5 | ||||||
|  | /Voy0p22zmQ3waS8LgiI7YHCAEFqjVWNziVGdR36gQKBgQDxkpfkEsfa5PieIaaF | ||||||
|  | iQOO0rrjEJ9MBOQqmTDeclmDPNkM9qvCF/dqpJfOtliYFxd7JJ3OR2wKrBb5vGHt | ||||||
|  | qIB51Rnm9aDTM4OUEhnhvbPlERD0W+yWYXWRvqyHz0GYwEFGQ83h95GC/qfTosqy | ||||||
|  | LEzYLDafiPeNP+DG/HYRljAxUwKBgQDZFOWHEcZkSFPLNZiksHqs90OR2zIFxZcx | ||||||
|  | SrbkjqXjRjehWEAwgpvQ/quSBxrE2E8xXgVm90G1JpWzxjUfKKQRM6solQeEpnwY | ||||||
|  | kCy2Ozij/TtbLNRlU65UQ+nMto8KTSIyJbxxdOZxYdtJAJQp1FJO1a1WC11z4+zh | ||||||
|  | lnLV1O5S8wKBgQCDf/QU4DBQtNGtas315Oa96XJ4RkUgoYz+r1NN09tsOERC7UgE | ||||||
|  | KP2y3JQSn2pMqE1M6FrKvlBO4uzC10xLja0aJOmrssvwDBu1D8FtA9IYgJjFHAEG | ||||||
|  | v1i7lJrgdu7TUtx1flVli1l3gF4lM3m5UaonBrJZV7rB9iLKzwUKf8IOJwKBgFt/ | ||||||
|  | QktPA6brEV56Za8sr1hOFA3bLNdf9B0Tl8j4ExWbWAFKeCu6MUDCxsAS/IZxgdeW | ||||||
|  | AILovqpC7CBM78EFWTni5EaDohqYLYAQ7LeWeIYuSyFf4Nogjj74LQha/iliX4Jx | ||||||
|  | g17y3dp2W34Gn2yOEG8oAxpcSfR54jMnPZnBWP5fAoGBAMNAd3oa/xq9A5v719ik | ||||||
|  | naD7PdrjBdhnPk4egzMDv54y6pCFlvFbEiBduBWTmiVa7dSzhYtmEbri2WrgARlu | ||||||
|  | vkfTnVH9E8Hnm4HTbNn+ebxrofq1AOAvdApSoslsOP1NT9J6zB89RzChJyzjbIQR | ||||||
|  | Gevrutb4uO9qpB1jDVoMmGde | ||||||
|  | -----END PRIVATE KEY----- | ||||||
| @@ -5,6 +5,7 @@ import ( | |||||||
|  |  | ||||||
| 	"github.com/google/go-github/github" | 	"github.com/google/go-github/github" | ||||||
| 	"github.com/hashicorp/go-cleanhttp" | 	"github.com/hashicorp/go-cleanhttp" | ||||||
|  | 	"github.com/hashicorp/vault/helper/mfa" | ||||||
| 	"github.com/hashicorp/vault/logical" | 	"github.com/hashicorp/vault/logical" | ||||||
| 	"github.com/hashicorp/vault/logical/framework" | 	"github.com/hashicorp/vault/logical/framework" | ||||||
| 	"golang.org/x/oauth2" | 	"golang.org/x/oauth2" | ||||||
| @@ -35,11 +36,11 @@ func Backend() *backend { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	allPaths := append(b.TeamMap.Paths(), b.UserMap.Paths()...) | 	allPaths := append(b.TeamMap.Paths(), b.UserMap.Paths()...) | ||||||
|  |  | ||||||
| 	b.Backend = &framework.Backend{ | 	b.Backend = &framework.Backend{ | ||||||
| 		Help: backendHelp, | 		Help: backendHelp, | ||||||
|  |  | ||||||
| 		PathsSpecial: &logical.Paths{ | 		PathsSpecial: &logical.Paths{ | ||||||
|  | 			Root: mfa.MFARootPaths(), | ||||||
| 			Unauthenticated: []string{ | 			Unauthenticated: []string{ | ||||||
| 				"login", | 				"login", | ||||||
| 			}, | 			}, | ||||||
| @@ -47,9 +48,7 @@ func Backend() *backend { | |||||||
|  |  | ||||||
| 		Paths: append([]*framework.Path{ | 		Paths: append([]*framework.Path{ | ||||||
| 			pathConfig(&b), | 			pathConfig(&b), | ||||||
| 			pathLogin(&b), | 		}, append(allPaths, mfa.MFAPaths(b.Backend, pathLogin(&b))...)...), | ||||||
| 		}, allPaths...), |  | ||||||
|  |  | ||||||
| 		AuthRenew:   b.pathLoginRenew, | 		AuthRenew:   b.pathLoginRenew, | ||||||
| 		BackendType: logical.TypeCredential, | 		BackendType: logical.TypeCredential, | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -74,7 +74,7 @@ func (b *backend) pathLogin( | |||||||
| 		return logical.ErrorResponse(fmt.Sprintf("error sanitizing TTLs: %s", err)), nil | 		return logical.ErrorResponse(fmt.Sprintf("error sanitizing TTLs: %s", err)), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return &logical.Response{ | 	resp := &logical.Response{ | ||||||
| 		Auth: &logical.Auth{ | 		Auth: &logical.Auth{ | ||||||
| 			InternalData: map[string]interface{}{ | 			InternalData: map[string]interface{}{ | ||||||
| 				"token": token, | 				"token": token, | ||||||
| @@ -93,7 +93,18 @@ func (b *backend) pathLogin( | |||||||
| 				Name: *verifyResp.User.Login, | 				Name: *verifyResp.User.Login, | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 	}, nil | 	} | ||||||
|  |  | ||||||
|  | 	for _, teamName := range verifyResp.TeamNames { | ||||||
|  | 		if teamName == "" { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		resp.Auth.GroupAliases = append(resp.Auth.GroupAliases, &logical.Alias{ | ||||||
|  | 			Name: teamName, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return resp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *backend) pathLoginRenew( | func (b *backend) pathLoginRenew( | ||||||
| @@ -125,7 +136,22 @@ func (b *backend) pathLoginRenew( | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	return framework.LeaseExtend(config.TTL, config.MaxTTL, b.System())(req, d) |  | ||||||
|  | 	resp, err := framework.LeaseExtend(config.TTL, config.MaxTTL, b.System())(req, d) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Remove old aliases | ||||||
|  | 	resp.Auth.GroupAliases = nil | ||||||
|  |  | ||||||
|  | 	for _, teamName := range verifyResp.TeamNames { | ||||||
|  | 		resp.Auth.GroupAliases = append(resp.Auth.GroupAliases, &logical.Alias{ | ||||||
|  | 			Name: teamName, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return resp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *backend) verifyCredentials(req *logical.Request, token string) (*verifyCredentialsResp, *logical.Response, error) { | func (b *backend) verifyCredentials(req *logical.Request, token string) (*verifyCredentialsResp, *logical.Response, error) { | ||||||
| @@ -233,14 +259,16 @@ func (b *backend) verifyCredentials(req *logical.Request, token string) (*verify | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return &verifyCredentialsResp{ | 	return &verifyCredentialsResp{ | ||||||
| 		User:     user, | 		User:      user, | ||||||
| 		Org:      org, | 		Org:       org, | ||||||
| 		Policies: append(groupPoliciesList, userPoliciesList...), | 		Policies:  append(groupPoliciesList, userPoliciesList...), | ||||||
|  | 		TeamNames: teamNames, | ||||||
| 	}, nil, nil | 	}, nil, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| type verifyCredentialsResp struct { | type verifyCredentialsResp struct { | ||||||
| 	User     *github.User | 	User      *github.User | ||||||
| 	Org      *github.Organization | 	Org       *github.Organization | ||||||
| 	Policies []string | 	Policies  []string | ||||||
|  | 	TeamNames []string | ||||||
| } | } | ||||||
|   | |||||||
| @@ -31,6 +31,10 @@ func Backend() *backend { | |||||||
| 			Unauthenticated: []string{ | 			Unauthenticated: []string{ | ||||||
| 				"login/*", | 				"login/*", | ||||||
| 			}, | 			}, | ||||||
|  |  | ||||||
|  | 			SealWrapStorage: []string{ | ||||||
|  | 				"config", | ||||||
|  | 			}, | ||||||
| 		}, | 		}, | ||||||
|  |  | ||||||
| 		Paths: append([]*framework.Path{ | 		Paths: append([]*framework.Path{ | ||||||
| @@ -88,22 +92,22 @@ func EscapeLDAPValue(input string) string { | |||||||
| 	return input | 	return input | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *backend) Login(req *logical.Request, username string, password string) ([]string, *logical.Response, error) { | func (b *backend) Login(req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) { | ||||||
|  |  | ||||||
| 	cfg, err := b.Config(req) | 	cfg, err := b.Config(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, nil, err | 		return nil, nil, nil, err | ||||||
| 	} | 	} | ||||||
| 	if cfg == nil { | 	if cfg == nil { | ||||||
| 		return nil, logical.ErrorResponse("ldap backend not configured"), nil | 		return nil, logical.ErrorResponse("ldap backend not configured"), nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	c, err := cfg.DialLDAP() | 	c, err := cfg.DialLDAP() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, logical.ErrorResponse(err.Error()), nil | 		return nil, logical.ErrorResponse(err.Error()), nil, nil | ||||||
| 	} | 	} | ||||||
| 	if c == nil { | 	if c == nil { | ||||||
| 		return nil, logical.ErrorResponse("invalid connection returned from LDAP dial"), nil | 		return nil, logical.ErrorResponse("invalid connection returned from LDAP dial"), nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Clean connection | 	// Clean connection | ||||||
| @@ -111,7 +115,7 @@ func (b *backend) Login(req *logical.Request, username string, password string) | |||||||
|  |  | ||||||
| 	userBindDN, err := b.getUserBindDN(cfg, c, username) | 	userBindDN, err := b.getUserBindDN(cfg, c, username) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, logical.ErrorResponse(err.Error()), nil | 		return nil, logical.ErrorResponse(err.Error()), nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if b.Logger().IsDebug() { | 	if b.Logger().IsDebug() { | ||||||
| @@ -119,7 +123,7 @@ func (b *backend) Login(req *logical.Request, username string, password string) | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if cfg.DenyNullBind && len(password) == 0 { | 	if cfg.DenyNullBind && len(password) == 0 { | ||||||
| 		return nil, logical.ErrorResponse("password cannot be of zero length when passwordless binds are being denied"), nil | 		return nil, logical.ErrorResponse("password cannot be of zero length when passwordless binds are being denied"), nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Try to bind as the login user. This is where the actual authentication takes place. | 	// Try to bind as the login user. This is where the actual authentication takes place. | ||||||
| @@ -129,14 +133,14 @@ func (b *backend) Login(req *logical.Request, username string, password string) | |||||||
| 		err = c.UnauthenticatedBind(userBindDN) | 		err = c.UnauthenticatedBind(userBindDN) | ||||||
| 	} | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, logical.ErrorResponse(fmt.Sprintf("LDAP bind failed: %v", err)), nil | 		return nil, logical.ErrorResponse(fmt.Sprintf("LDAP bind failed: %v", err)), nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// We re-bind to the BindDN if it's defined because we assume | 	// We re-bind to the BindDN if it's defined because we assume | ||||||
| 	// the BindDN should be the one to search, not the user logging in. | 	// the BindDN should be the one to search, not the user logging in. | ||||||
| 	if cfg.BindDN != "" && cfg.BindPassword != "" { | 	if cfg.BindDN != "" && cfg.BindPassword != "" { | ||||||
| 		if err := c.Bind(cfg.BindDN, cfg.BindPassword); err != nil { | 		if err := c.Bind(cfg.BindDN, cfg.BindPassword); err != nil { | ||||||
| 			return nil, logical.ErrorResponse(fmt.Sprintf("Encountered an error while attempting to re-bind with the BindDN User: %s", err.Error())), nil | 			return nil, logical.ErrorResponse(fmt.Sprintf("Encountered an error while attempting to re-bind with the BindDN User: %s", err.Error())), nil, nil | ||||||
| 		} | 		} | ||||||
| 		if b.Logger().IsDebug() { | 		if b.Logger().IsDebug() { | ||||||
| 			b.Logger().Debug("auth/ldap: Re-Bound to original BindDN") | 			b.Logger().Debug("auth/ldap: Re-Bound to original BindDN") | ||||||
| @@ -145,12 +149,12 @@ func (b *backend) Login(req *logical.Request, username string, password string) | |||||||
|  |  | ||||||
| 	userDN, err := b.getUserDN(cfg, c, userBindDN) | 	userDN, err := b.getUserDN(cfg, c, userBindDN) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, logical.ErrorResponse(err.Error()), nil | 		return nil, logical.ErrorResponse(err.Error()), nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	ldapGroups, err := b.getLdapGroups(cfg, c, userDN, username) | 	ldapGroups, err := b.getLdapGroups(cfg, c, userDN, username) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, logical.ErrorResponse(err.Error()), nil | 		return nil, logical.ErrorResponse(err.Error()), nil, nil | ||||||
| 	} | 	} | ||||||
| 	if b.Logger().IsDebug() { | 	if b.Logger().IsDebug() { | ||||||
| 		b.Logger().Debug("auth/ldap: Groups fetched from server", "num_server_groups", len(ldapGroups), "server_groups", ldapGroups) | 		b.Logger().Debug("auth/ldap: Groups fetched from server", "num_server_groups", len(ldapGroups), "server_groups", ldapGroups) | ||||||
| @@ -199,10 +203,10 @@ func (b *backend) Login(req *logical.Request, username string, password string) | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		ldapResponse.Data["error"] = errStr | 		ldapResponse.Data["error"] = errStr | ||||||
| 		return nil, ldapResponse, nil | 		return nil, ldapResponse, nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return policies, ldapResponse, nil | 	return policies, ldapResponse, allGroups, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| /* | /* | ||||||
|   | |||||||
| @@ -55,7 +55,7 @@ func (b *backend) pathLogin( | |||||||
| 	username := d.Get("username").(string) | 	username := d.Get("username").(string) | ||||||
| 	password := d.Get("password").(string) | 	password := d.Get("password").(string) | ||||||
|  |  | ||||||
| 	policies, resp, err := b.Login(req, username, password) | 	policies, resp, groupNames, err := b.Login(req, username, password) | ||||||
| 	// Handle an internal error | 	// Handle an internal error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -87,6 +87,15 @@ func (b *backend) pathLogin( | |||||||
| 			Name: username, | 			Name: username, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	for _, groupName := range groupNames { | ||||||
|  | 		if groupName == "" { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		resp.Auth.GroupAliases = append(resp.Auth.GroupAliases, &logical.Alias{ | ||||||
|  | 			Name: groupName, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
| 	return resp, nil | 	return resp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -96,7 +105,7 @@ func (b *backend) pathLoginRenew( | |||||||
| 	username := req.Auth.Metadata["username"] | 	username := req.Auth.Metadata["username"] | ||||||
| 	password := req.Auth.InternalData["password"].(string) | 	password := req.Auth.InternalData["password"].(string) | ||||||
|  |  | ||||||
| 	loginPolicies, resp, err := b.Login(req, username, password) | 	loginPolicies, resp, groupNames, err := b.Login(req, username, password) | ||||||
| 	if len(loginPolicies) == 0 { | 	if len(loginPolicies) == 0 { | ||||||
| 		return resp, err | 		return resp, err | ||||||
| 	} | 	} | ||||||
| @@ -105,7 +114,21 @@ func (b *backend) pathLoginRenew( | |||||||
| 		return nil, fmt.Errorf("policies have changed, not renewing") | 		return nil, fmt.Errorf("policies have changed, not renewing") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return framework.LeaseExtend(0, 0, b.System())(req, d) | 	resp, err = framework.LeaseExtend(0, 0, b.System())(req, d) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Remove old aliases | ||||||
|  | 	resp.Auth.GroupAliases = nil | ||||||
|  |  | ||||||
|  | 	for _, groupName := range groupNames { | ||||||
|  | 		resp.Auth.GroupAliases = append(resp.Auth.GroupAliases, &logical.Alias{ | ||||||
|  | 			Name: groupName, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return resp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| const pathLoginSyn = ` | const pathLoginSyn = ` | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
|  |  | ||||||
| 	"github.com/chrismalek/oktasdk-go/okta" | 	"github.com/chrismalek/oktasdk-go/okta" | ||||||
|  | 	"github.com/hashicorp/vault/helper/mfa" | ||||||
| 	"github.com/hashicorp/vault/logical" | 	"github.com/hashicorp/vault/logical" | ||||||
| 	"github.com/hashicorp/vault/logical/framework" | 	"github.com/hashicorp/vault/logical/framework" | ||||||
| ) | ) | ||||||
| @@ -22,9 +23,14 @@ func Backend() *backend { | |||||||
| 		Help: backendHelp, | 		Help: backendHelp, | ||||||
|  |  | ||||||
| 		PathsSpecial: &logical.Paths{ | 		PathsSpecial: &logical.Paths{ | ||||||
|  | 			Root: mfa.MFARootPaths(), | ||||||
|  |  | ||||||
| 			Unauthenticated: []string{ | 			Unauthenticated: []string{ | ||||||
| 				"login/*", | 				"login/*", | ||||||
| 			}, | 			}, | ||||||
|  | 			SealWrapStorage: []string{ | ||||||
|  | 				"config", | ||||||
|  | 			}, | ||||||
| 		}, | 		}, | ||||||
|  |  | ||||||
| 		Paths: append([]*framework.Path{ | 		Paths: append([]*framework.Path{ | ||||||
| @@ -33,8 +39,9 @@ func Backend() *backend { | |||||||
| 			pathGroups(&b), | 			pathGroups(&b), | ||||||
| 			pathUsersList(&b), | 			pathUsersList(&b), | ||||||
| 			pathGroupsList(&b), | 			pathGroupsList(&b), | ||||||
| 			pathLogin(&b), | 		}, | ||||||
| 		}), | 			mfa.MFAPaths(b.Backend, pathLogin(&b))..., | ||||||
|  | 		), | ||||||
|  |  | ||||||
| 		AuthRenew:   b.pathLoginRenew, | 		AuthRenew:   b.pathLoginRenew, | ||||||
| 		BackendType: logical.TypeCredential, | 		BackendType: logical.TypeCredential, | ||||||
| @@ -47,13 +54,13 @@ type backend struct { | |||||||
| 	*framework.Backend | 	*framework.Backend | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *backend) Login(req *logical.Request, username string, password string) ([]string, *logical.Response, error) { | func (b *backend) Login(req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) { | ||||||
| 	cfg, err := b.Config(req.Storage) | 	cfg, err := b.Config(req.Storage) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, nil, err | 		return nil, nil, nil, err | ||||||
| 	} | 	} | ||||||
| 	if cfg == nil { | 	if cfg == nil { | ||||||
| 		return nil, logical.ErrorResponse("Okta auth method not configured"), nil | 		return nil, logical.ErrorResponse("Okta auth method not configured"), nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	client := cfg.OktaClient() | 	client := cfg.OktaClient() | ||||||
| @@ -71,16 +78,16 @@ func (b *backend) Login(req *logical.Request, username string, password string) | |||||||
| 		"password": password, | 		"password": password, | ||||||
| 	}) | 	}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, nil, err | 		return nil, nil, nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var result authResult | 	var result authResult | ||||||
| 	rsp, err := client.Do(authReq, &result) | 	rsp, err := client.Do(authReq, &result) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, logical.ErrorResponse(fmt.Sprintf("Okta auth failed: %v", err)), nil | 		return nil, logical.ErrorResponse(fmt.Sprintf("Okta auth failed: %v", err)), nil, nil | ||||||
| 	} | 	} | ||||||
| 	if rsp == nil { | 	if rsp == nil { | ||||||
| 		return nil, logical.ErrorResponse("okta auth method unexpected failure"), nil | 		return nil, logical.ErrorResponse("okta auth method unexpected failure"), nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	oktaResponse := &logical.Response{ | 	oktaResponse := &logical.Response{ | ||||||
| @@ -92,7 +99,7 @@ func (b *backend) Login(req *logical.Request, username string, password string) | |||||||
| 	if cfg.Token != "" { | 	if cfg.Token != "" { | ||||||
| 		oktaGroups, err := b.getOktaGroups(client, &result.Embedded.User) | 		oktaGroups, err := b.getOktaGroups(client, &result.Embedded.User) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, logical.ErrorResponse(fmt.Sprintf("okta failure retrieving groups: %v", err)), nil | 			return nil, logical.ErrorResponse(fmt.Sprintf("okta failure retrieving groups: %v", err)), nil, nil | ||||||
| 		} | 		} | ||||||
| 		if len(oktaGroups) == 0 { | 		if len(oktaGroups) == 0 { | ||||||
| 			errString := fmt.Sprintf( | 			errString := fmt.Sprintf( | ||||||
| @@ -142,10 +149,10 @@ func (b *backend) Login(req *logical.Request, username string, password string) | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		oktaResponse.Data["error"] = errStr | 		oktaResponse.Data["error"] = errStr | ||||||
| 		return nil, oktaResponse, nil | 		return nil, oktaResponse, nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return policies, oktaResponse, nil | 	return policies, oktaResponse, allGroups, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *backend) getOktaGroups(client *okta.Client, user *okta.User) ([]string, error) { | func (b *backend) getOktaGroups(client *okta.Client, user *okta.User) ([]string, error) { | ||||||
|   | |||||||
| @@ -38,6 +38,15 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro | |||||||
| 		"password": password, | 		"password": password, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	mfa_method, ok := m["method"] | ||||||
|  | 	if ok { | ||||||
|  | 		data["method"] = mfa_method | ||||||
|  | 	} | ||||||
|  | 	mfa_passcode, ok := m["passcode"] | ||||||
|  | 	if ok { | ||||||
|  | 		data["passcode"] = mfa_passcode | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	path := fmt.Sprintf("auth/%s/login/%s", mount, username) | 	path := fmt.Sprintf("auth/%s/login/%s", mount, username) | ||||||
| 	secret, err := c.Logical().Write(path, data) | 	secret, err := c.Logical().Write(path, data) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|   | |||||||
| @@ -57,7 +57,7 @@ func (b *backend) pathLogin( | |||||||
| 	username := d.Get("username").(string) | 	username := d.Get("username").(string) | ||||||
| 	password := d.Get("password").(string) | 	password := d.Get("password").(string) | ||||||
|  |  | ||||||
| 	policies, resp, err := b.Login(req, username, password) | 	policies, resp, groupNames, err := b.Login(req, username, password) | ||||||
| 	// Handle an internal error | 	// Handle an internal error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -96,6 +96,16 @@ func (b *backend) pathLogin( | |||||||
| 			Name: username, | 			Name: username, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	for _, groupName := range groupNames { | ||||||
|  | 		if groupName == "" { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		resp.Auth.GroupAliases = append(resp.Auth.GroupAliases, &logical.Alias{ | ||||||
|  | 			Name: groupName, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	return resp, nil | 	return resp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -105,7 +115,7 @@ func (b *backend) pathLoginRenew( | |||||||
| 	username := req.Auth.Metadata["username"] | 	username := req.Auth.Metadata["username"] | ||||||
| 	password := req.Auth.InternalData["password"].(string) | 	password := req.Auth.InternalData["password"].(string) | ||||||
|  |  | ||||||
| 	loginPolicies, resp, err := b.Login(req, username, password) | 	loginPolicies, resp, groupNames, err := b.Login(req, username, password) | ||||||
| 	if len(loginPolicies) == 0 { | 	if len(loginPolicies) == 0 { | ||||||
| 		return resp, err | 		return resp, err | ||||||
| 	} | 	} | ||||||
| @@ -119,7 +129,22 @@ func (b *backend) pathLoginRenew( | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return framework.LeaseExtend(cfg.TTL, cfg.MaxTTL, b.System())(req, d) | 	resp, err = framework.LeaseExtend(cfg.TTL, cfg.MaxTTL, b.System())(req, d) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Remove old aliases | ||||||
|  | 	resp.Auth.GroupAliases = nil | ||||||
|  |  | ||||||
|  | 	for _, groupName := range groupNames { | ||||||
|  | 		resp.Auth.GroupAliases = append(resp.Auth.GroupAliases, &logical.Alias{ | ||||||
|  | 			Name: groupName, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return resp, nil | ||||||
|  |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *backend) getConfig(req *logical.Request) (*ConfigEntry, error) { | func (b *backend) getConfig(req *logical.Request) (*ConfigEntry, error) { | ||||||
|   | |||||||
| @@ -26,6 +26,10 @@ func Backend() *backend { | |||||||
| 				"login", | 				"login", | ||||||
| 				"login/*", | 				"login/*", | ||||||
| 			}, | 			}, | ||||||
|  |  | ||||||
|  | 			SealWrapStorage: []string{ | ||||||
|  | 				"config", | ||||||
|  | 			}, | ||||||
| 		}, | 		}, | ||||||
|  |  | ||||||
| 		Paths: append([]*framework.Path{ | 		Paths: append([]*framework.Path{ | ||||||
|   | |||||||
| @@ -39,7 +39,7 @@ func pathConfig(b *backend) *framework.Path { | |||||||
| 			"read_timeout": &framework.FieldSchema{ | 			"read_timeout": &framework.FieldSchema{ | ||||||
| 				Type:        framework.TypeDurationSecond, | 				Type:        framework.TypeDurationSecond, | ||||||
| 				Default:     10, | 				Default:     10, | ||||||
| 				Description: "Number of seconds before response times out (default: 10). Note: kept for backwards compatibility, currently unused.", | 				Description: "Number of seconds before response times out (default: 10)", | ||||||
| 			}, | 			}, | ||||||
| 			"nas_port": &framework.FieldSchema{ | 			"nas_port": &framework.FieldSchema{ | ||||||
| 				Type:        framework.TypeInt, | 				Type:        framework.TypeInt, | ||||||
|   | |||||||
| @@ -154,7 +154,9 @@ func (b *backend) RadiusLogin(req *logical.Request, username string, password st | |||||||
| 			Timeout: time.Duration(cfg.DialTimeout) * time.Second, | 			Timeout: time.Duration(cfg.DialTimeout) * time.Second, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 	received, err := client.Exchange(context.Background(), packet, hostport) | 	ctx, cancelFunc := context.WithTimeout(context.Background(), time.Duration(cfg.ReadTimeout)*time.Second) | ||||||
|  | 	received, err := client.Exchange(ctx, packet, hostport) | ||||||
|  | 	cancelFunc() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, logical.ErrorResponse(err.Error()), nil | 		return nil, logical.ErrorResponse(err.Error()), nil | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -25,6 +25,9 @@ func Backend() *backend { | |||||||
| 			LocalStorage: []string{ | 			LocalStorage: []string{ | ||||||
| 				framework.WALPrefix, | 				framework.WALPrefix, | ||||||
| 			}, | 			}, | ||||||
|  | 			SealWrapStorage: []string{ | ||||||
|  | 				"config/root", | ||||||
|  | 			}, | ||||||
| 		}, | 		}, | ||||||
|  |  | ||||||
| 		Paths: []*framework.Path{ | 		Paths: []*framework.Path{ | ||||||
|   | |||||||
| @@ -13,8 +13,9 @@ import ( | |||||||
| 	"github.com/hashicorp/vault/logical" | 	"github.com/hashicorp/vault/logical" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func getRootConfig(s logical.Storage) (*aws.Config, error) { | func getRootConfig(s logical.Storage, clientType string) (*aws.Config, error) { | ||||||
| 	credsConfig := &awsutil.CredentialsConfig{} | 	credsConfig := &awsutil.CredentialsConfig{} | ||||||
|  | 	var endpoint string | ||||||
|  |  | ||||||
| 	entry, err := s.Get("config/root") | 	entry, err := s.Get("config/root") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -29,6 +30,12 @@ func getRootConfig(s logical.Storage) (*aws.Config, error) { | |||||||
| 		credsConfig.AccessKey = config.AccessKey | 		credsConfig.AccessKey = config.AccessKey | ||||||
| 		credsConfig.SecretKey = config.SecretKey | 		credsConfig.SecretKey = config.SecretKey | ||||||
| 		credsConfig.Region = config.Region | 		credsConfig.Region = config.Region | ||||||
|  | 		switch { | ||||||
|  | 		case clientType == "iam" && config.IAMEndpoint != "": | ||||||
|  | 			endpoint = *aws.String(config.IAMEndpoint) | ||||||
|  | 		case clientType == "sts" && config.STSEndpoint != "": | ||||||
|  | 			endpoint = *aws.String(config.STSEndpoint) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if credsConfig.Region == "" { | 	if credsConfig.Region == "" { | ||||||
| @@ -51,16 +58,19 @@ func getRootConfig(s logical.Storage) (*aws.Config, error) { | |||||||
| 	return &aws.Config{ | 	return &aws.Config{ | ||||||
| 		Credentials: creds, | 		Credentials: creds, | ||||||
| 		Region:      aws.String(credsConfig.Region), | 		Region:      aws.String(credsConfig.Region), | ||||||
|  | 		Endpoint:    &endpoint, | ||||||
| 		HTTPClient:  cleanhttp.DefaultClient(), | 		HTTPClient:  cleanhttp.DefaultClient(), | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func clientIAM(s logical.Storage) (*iam.IAM, error) { | func clientIAM(s logical.Storage) (*iam.IAM, error) { | ||||||
| 	awsConfig, err := getRootConfig(s) | 	awsConfig, err := getRootConfig(s, "iam") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	client := iam.New(session.New(awsConfig)) | 	client := iam.New(session.New(awsConfig)) | ||||||
|  |  | ||||||
| 	if client == nil { | 	if client == nil { | ||||||
| 		return nil, fmt.Errorf("could not obtain iam client") | 		return nil, fmt.Errorf("could not obtain iam client") | ||||||
| 	} | 	} | ||||||
| @@ -68,11 +78,12 @@ func clientIAM(s logical.Storage) (*iam.IAM, error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func clientSTS(s logical.Storage) (*sts.STS, error) { | func clientSTS(s logical.Storage) (*sts.STS, error) { | ||||||
| 	awsConfig, err := getRootConfig(s) | 	awsConfig, err := getRootConfig(s, "sts") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	client := sts.New(session.New(awsConfig)) | 	client := sts.New(session.New(awsConfig)) | ||||||
|  |  | ||||||
| 	if client == nil { | 	if client == nil { | ||||||
| 		return nil, fmt.Errorf("could not obtain sts client") | 		return nil, fmt.Errorf("could not obtain sts client") | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -23,6 +23,14 @@ func pathConfigRoot() *framework.Path { | |||||||
| 				Type:        framework.TypeString, | 				Type:        framework.TypeString, | ||||||
| 				Description: "Region for API calls.", | 				Description: "Region for API calls.", | ||||||
| 			}, | 			}, | ||||||
|  | 			"iam_endpoint": &framework.FieldSchema{ | ||||||
|  | 				Type:        framework.TypeString, | ||||||
|  | 				Description: "Endpoint to custom IAM server URL", | ||||||
|  | 			}, | ||||||
|  | 			"sts_endpoint": &framework.FieldSchema{ | ||||||
|  | 				Type:        framework.TypeString, | ||||||
|  | 				Description: "Endpoint to custom STS server URL", | ||||||
|  | 			}, | ||||||
| 		}, | 		}, | ||||||
|  |  | ||||||
| 		Callbacks: map[logical.Operation]framework.OperationFunc{ | 		Callbacks: map[logical.Operation]framework.OperationFunc{ | ||||||
| @@ -37,11 +45,15 @@ func pathConfigRoot() *framework.Path { | |||||||
| func pathConfigRootWrite( | func pathConfigRootWrite( | ||||||
| 	req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | 	req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | ||||||
| 	region := data.Get("region").(string) | 	region := data.Get("region").(string) | ||||||
|  | 	iamendpoint := data.Get("iam_endpoint").(string) | ||||||
|  | 	stsendpoint := data.Get("sts_endpoint").(string) | ||||||
|  |  | ||||||
| 	entry, err := logical.StorageEntryJSON("config/root", rootConfig{ | 	entry, err := logical.StorageEntryJSON("config/root", rootConfig{ | ||||||
| 		AccessKey: data.Get("access_key").(string), | 		AccessKey:   data.Get("access_key").(string), | ||||||
| 		SecretKey: data.Get("secret_key").(string), | 		SecretKey:   data.Get("secret_key").(string), | ||||||
| 		Region:    region, | 		IAMEndpoint: iamendpoint, | ||||||
|  | 		STSEndpoint: stsendpoint, | ||||||
|  | 		Region:      region, | ||||||
| 	}) | 	}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -55,9 +67,11 @@ func pathConfigRootWrite( | |||||||
| } | } | ||||||
|  |  | ||||||
| type rootConfig struct { | type rootConfig struct { | ||||||
| 	AccessKey string `json:"access_key"` | 	AccessKey   string `json:"access_key"` | ||||||
| 	SecretKey string `json:"secret_key"` | 	SecretKey   string `json:"secret_key"` | ||||||
| 	Region    string `json:"region"` | 	IAMEndpoint string `json:"iam_endpoint"` | ||||||
|  | 	STSEndpoint string `json:"sts_endpoint"` | ||||||
|  | 	Region      string `json:"region"` | ||||||
| } | } | ||||||
|  |  | ||||||
| const pathConfigRootHelpSyn = ` | const pathConfigRootHelpSyn = ` | ||||||
|   | |||||||
| @@ -25,6 +25,12 @@ func Backend() *backend { | |||||||
| 	b.Backend = &framework.Backend{ | 	b.Backend = &framework.Backend{ | ||||||
| 		Help: strings.TrimSpace(backendHelp), | 		Help: strings.TrimSpace(backendHelp), | ||||||
|  |  | ||||||
|  | 		PathsSpecial: &logical.Paths{ | ||||||
|  | 			SealWrapStorage: []string{ | ||||||
|  | 				"config/connection", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  |  | ||||||
| 		Paths: []*framework.Path{ | 		Paths: []*framework.Path{ | ||||||
| 			pathConfigConnection(&b), | 			pathConfigConnection(&b), | ||||||
| 			pathRoles(&b), | 			pathRoles(&b), | ||||||
|   | |||||||
| @@ -4,73 +4,76 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"log" | 	"log" | ||||||
| 	"os" | 	"os" | ||||||
|  | 	"strconv" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
|  | 	"github.com/gocql/gocql" | ||||||
| 	"github.com/hashicorp/vault/logical" | 	"github.com/hashicorp/vault/logical" | ||||||
| 	logicaltest "github.com/hashicorp/vault/logical/testing" | 	logicaltest "github.com/hashicorp/vault/logical/testing" | ||||||
| 	"github.com/mitchellh/mapstructure" | 	"github.com/mitchellh/mapstructure" | ||||||
| 	dockertest "gopkg.in/ory-am/dockertest.v2" | 	dockertest "gopkg.in/ory-am/dockertest.v3" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var ( | var ( | ||||||
| 	testImagePull sync.Once | 	testImagePull sync.Once | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func prepareTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cid dockertest.ContainerID, retURL string) { | func prepareCassandraTestContainer(t *testing.T) (func(), string, int) { | ||||||
| 	if os.Getenv("CASSANDRA_HOST") != "" { | 	if os.Getenv("CASSANDRA_HOST") != "" { | ||||||
| 		return "", os.Getenv("CASSANDRA_HOST") | 		return func() {}, os.Getenv("CASSANDRA_HOST"), 0 | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Without this the checks for whether the container has started seem to | 	pool, err := dockertest.NewPool("") | ||||||
| 	// never actually pass. There's really no reason to expose the test | 	if err != nil { | ||||||
| 	// containers, so don't. | 		t.Fatalf("Failed to connect to docker: %s", err) | ||||||
| 	dockertest.BindDockerToLocalhost = "yep" | 	} | ||||||
|  |  | ||||||
| 	testImagePull.Do(func() { |  | ||||||
| 		dockertest.Pull("cassandra") |  | ||||||
| 	}) |  | ||||||
|  |  | ||||||
| 	cwd, _ := os.Getwd() | 	cwd, _ := os.Getwd() | ||||||
|  | 	cassandraMountPath := fmt.Sprintf("%s/test-fixtures/:/etc/cassandra/", cwd) | ||||||
|  |  | ||||||
| 	cid, connErr := dockertest.ConnectToCassandra("latest", 60, 1000*time.Millisecond, func(connURL string) bool { | 	ro := &dockertest.RunOptions{ | ||||||
| 		// This will cause a validation to run | 		Repository: "cassandra", | ||||||
| 		resp, err := b.HandleRequest(&logical.Request{ | 		Tag:        "latest", | ||||||
| 			Storage:   s, | 		Env:        []string{"CASSANDRA_BROADCAST_ADDRESS=127.0.0.1"}, | ||||||
| 			Operation: logical.UpdateOperation, | 		Mounts:     []string{cassandraMountPath}, | ||||||
| 			Path:      "config/connection", |  | ||||||
| 			Data: map[string]interface{}{ |  | ||||||
| 				"hosts":            connURL, |  | ||||||
| 				"username":         "cassandra", |  | ||||||
| 				"password":         "cassandra", |  | ||||||
| 				"protocol_version": 3, |  | ||||||
| 			}, |  | ||||||
| 		}) |  | ||||||
| 		if err != nil || (resp != nil && resp.IsError()) { |  | ||||||
| 			// It's likely not up and running yet, so return false and try again |  | ||||||
| 			return false |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		retURL = connURL |  | ||||||
| 		return true |  | ||||||
| 	}, []string{"-v", cwd + "/test-fixtures/:/etc/cassandra/"}...) |  | ||||||
|  |  | ||||||
| 	if connErr != nil { |  | ||||||
| 		if cid != "" { |  | ||||||
| 			cid.KillRemove() |  | ||||||
| 		} |  | ||||||
| 		t.Fatalf("could not connect to database: %v", connErr) |  | ||||||
| 	} | 	} | ||||||
|  | 	resource, err := pool.RunWithOptions(ro) | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) { |  | ||||||
| 	err := cid.KillRemove() |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatalf("Could not start local cassandra docker container: %s", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	cleanup := func() { | ||||||
|  | 		err := pool.Purge(resource) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to cleanup local container: %s", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	port, _ := strconv.Atoi(resource.GetPort("9042/tcp")) | ||||||
|  | 	address := fmt.Sprintf("127.0.0.1:%d", port) | ||||||
|  |  | ||||||
|  | 	// exponential backoff-retry | ||||||
|  | 	if err = pool.Retry(func() error { | ||||||
|  | 		clusterConfig := gocql.NewCluster(address) | ||||||
|  | 		clusterConfig.Authenticator = gocql.PasswordAuthenticator{ | ||||||
|  | 			Username: "cassandra", | ||||||
|  | 			Password: "cassandra", | ||||||
|  | 		} | ||||||
|  | 		clusterConfig.ProtoVersion = 4 | ||||||
|  | 		clusterConfig.Port = port | ||||||
|  |  | ||||||
|  | 		session, err := clusterConfig.CreateSession() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return fmt.Errorf("error creating session: %s", err) | ||||||
|  | 		} | ||||||
|  | 		defer session.Close() | ||||||
|  | 		return nil | ||||||
|  | 	}); err != nil { | ||||||
|  | 		cleanup() | ||||||
|  | 		t.Fatalf("Could not connect to cassandra docker container: %s", err) | ||||||
|  | 	} | ||||||
|  | 	return cleanup, address, port | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestBackend_basic(t *testing.T) { | func TestBackend_basic(t *testing.T) { | ||||||
| @@ -84,10 +87,8 @@ func TestBackend_basic(t *testing.T) { | |||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	cid, hostname := prepareTestContainer(t, config.StorageView, b) | 	cleanup, hostname, _ := prepareCassandraTestContainer(t) | ||||||
| 	if cid != "" { | 	defer cleanup() | ||||||
| 		defer cleanupTestContainer(t, cid) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	logicaltest.Test(t, logicaltest.TestCase{ | 	logicaltest.Test(t, logicaltest.TestCase{ | ||||||
| 		Backend: b, | 		Backend: b, | ||||||
| @@ -110,10 +111,8 @@ func TestBackend_roleCrud(t *testing.T) { | |||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	cid, hostname := prepareTestContainer(t, config.StorageView, b) | 	cleanup, hostname, _ := prepareCassandraTestContainer(t) | ||||||
| 	if cid != "" { | 	defer cleanup() | ||||||
| 		defer cleanupTestContainer(t, cid) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	logicaltest.Test(t, logicaltest.TestCase{ | 	logicaltest.Test(t, logicaltest.TestCase{ | ||||||
| 		Backend: b, | 		Backend: b, | ||||||
|   | |||||||
| @@ -421,7 +421,7 @@ seed_provider: | |||||||
|       parameters: |       parameters: | ||||||
|           # seeds is actually a comma-delimited list of addresses. |           # seeds is actually a comma-delimited list of addresses. | ||||||
|           # Ex: "<ip1>,<ip2>,<ip3>" |           # Ex: "<ip1>,<ip2>,<ip3>" | ||||||
|           - seeds: "172.17.0.3" |           - seeds: "127.0.0.1" | ||||||
|  |  | ||||||
| # For workloads with more data than can fit in memory, Cassandra's | # For workloads with more data than can fit in memory, Cassandra's | ||||||
| # bottleneck will be reads that need to fetch data from | # bottleneck will be reads that need to fetch data from | ||||||
| @@ -572,7 +572,7 @@ ssl_storage_port: 7001 | |||||||
| # | # | ||||||
| # Setting listen_address to 0.0.0.0 is always wrong. | # Setting listen_address to 0.0.0.0 is always wrong. | ||||||
| # | # | ||||||
| listen_address: 172.17.0.3 | listen_address: 172.17.0.2 | ||||||
|  |  | ||||||
| # Set listen_address OR listen_interface, not both. Interfaces must correspond | # Set listen_address OR listen_interface, not both. Interfaces must correspond | ||||||
| # to a single address, IP aliasing is not supported. | # to a single address, IP aliasing is not supported. | ||||||
| @@ -586,7 +586,7 @@ listen_address: 172.17.0.3 | |||||||
|  |  | ||||||
| # Address to broadcast to other Cassandra nodes | # Address to broadcast to other Cassandra nodes | ||||||
| # Leaving this blank will set it to the same value as listen_address | # Leaving this blank will set it to the same value as listen_address | ||||||
| broadcast_address: 172.17.0.3 | broadcast_address: 127.0.0.1 | ||||||
|  |  | ||||||
| # When using multiple physical network interfaces, set this | # When using multiple physical network interfaces, set this | ||||||
| # to true to listen on broadcast_address in addition to | # to true to listen on broadcast_address in addition to | ||||||
| @@ -668,7 +668,7 @@ rpc_port: 9160 | |||||||
| # be set to 0.0.0.0. If left blank, this will be set to the value of | # be set to 0.0.0.0. If left blank, this will be set to the value of | ||||||
| # rpc_address. If rpc_address is set to 0.0.0.0, broadcast_rpc_address must | # rpc_address. If rpc_address is set to 0.0.0.0, broadcast_rpc_address must | ||||||
| # be set. | # be set. | ||||||
| broadcast_rpc_address: 172.17.0.3 | broadcast_rpc_address: 127.0.0.1 | ||||||
|  |  | ||||||
| # enable or disable keepalive on rpc/native connections | # enable or disable keepalive on rpc/native connections | ||||||
| rpc_keepalive: true | rpc_keepalive: true | ||||||
|   | |||||||
| @@ -16,6 +16,12 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) { | |||||||
| func Backend() *backend { | func Backend() *backend { | ||||||
| 	var b backend | 	var b backend | ||||||
| 	b.Backend = &framework.Backend{ | 	b.Backend = &framework.Backend{ | ||||||
|  | 		PathsSpecial: &logical.Paths{ | ||||||
|  | 			SealWrapStorage: []string{ | ||||||
|  | 				"config/access", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  |  | ||||||
| 		Paths: []*framework.Path{ | 		Paths: []*framework.Path{ | ||||||
| 			pathConfigAccess(), | 			pathConfigAccess(), | ||||||
| 			pathListRoles(&b), | 			pathListRoles(&b), | ||||||
|   | |||||||
| @@ -433,12 +433,8 @@ func testAccStepReadPolicy(t *testing.T, name string, policy string, lease time. | |||||||
| 				return fmt.Errorf("mismatch: %s %s", out, policy) | 				return fmt.Errorf("mismatch: %s %s", out, policy) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			leaseRaw := resp.Data["lease"].(string) | 			l := resp.Data["lease"].(int64) | ||||||
| 			l, err := time.ParseDuration(leaseRaw) | 			if lease != time.Second*time.Duration(l) { | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			if l != lease { |  | ||||||
| 				return fmt.Errorf("mismatch: %v %v", l, lease) | 				return fmt.Errorf("mismatch: %v %v", l, lease) | ||||||
| 			} | 			} | ||||||
| 			return nil | 			return nil | ||||||
|   | |||||||
| @@ -44,7 +44,7 @@ Defaults to 'client'.`, | |||||||
| 			}, | 			}, | ||||||
|  |  | ||||||
| 			"lease": &framework.FieldSchema{ | 			"lease": &framework.FieldSchema{ | ||||||
| 				Type:        framework.TypeString, | 				Type:        framework.TypeDurationSecond, | ||||||
| 				Description: "Lease time of the role.", | 				Description: "Lease time of the role.", | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| @@ -91,7 +91,7 @@ func pathRolesRead( | |||||||
| 	// Generate the response | 	// Generate the response | ||||||
| 	resp := &logical.Response{ | 	resp := &logical.Response{ | ||||||
| 		Data: map[string]interface{}{ | 		Data: map[string]interface{}{ | ||||||
| 			"lease":      result.Lease.String(), | 			"lease":      int64(result.Lease.Seconds()), | ||||||
| 			"token_type": result.TokenType, | 			"token_type": result.TokenType, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| @@ -130,13 +130,9 @@ func pathRolesWrite( | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var lease time.Duration | 	var lease time.Duration | ||||||
| 	leaseParam := d.Get("lease").(string) | 	leaseParamRaw, ok := d.GetOk("lease") | ||||||
| 	if leaseParam != "" { | 	if ok { | ||||||
| 		lease, err = time.ParseDuration(leaseParam) | 		lease = time.Second * time.Duration(leaseParamRaw.(int)) | ||||||
| 		if err != nil { |  | ||||||
| 			return logical.ErrorResponse(fmt.Sprintf( |  | ||||||
| 				"error parsing given lease of %s: %s", leaseParam, err)), nil |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	entry, err := logical.StorageEntryJSON("policy/"+name, roleConfig{ | 	entry, err := logical.StorageEntryJSON("policy/"+name, roleConfig{ | ||||||
|   | |||||||
| @@ -11,9 +11,9 @@ import ( | |||||||
|  |  | ||||||
| func pathToken(b *backend) *framework.Path { | func pathToken(b *backend) *framework.Path { | ||||||
| 	return &framework.Path{ | 	return &framework.Path{ | ||||||
| 		Pattern: "creds/" + framework.GenericNameRegex("name"), | 		Pattern: "creds/" + framework.GenericNameRegex("role"), | ||||||
| 		Fields: map[string]*framework.FieldSchema{ | 		Fields: map[string]*framework.FieldSchema{ | ||||||
| 			"name": &framework.FieldSchema{ | 			"role": &framework.FieldSchema{ | ||||||
| 				Type:        framework.TypeString, | 				Type:        framework.TypeString, | ||||||
| 				Description: "Name of the role", | 				Description: "Name of the role", | ||||||
| 			}, | 			}, | ||||||
| @@ -27,14 +27,14 @@ func pathToken(b *backend) *framework.Path { | |||||||
|  |  | ||||||
| func (b *backend) pathTokenRead( | func (b *backend) pathTokenRead( | ||||||
| 	req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | 	req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
| 	name := d.Get("name").(string) | 	role := d.Get("role").(string) | ||||||
|  |  | ||||||
| 	entry, err := req.Storage.Get("policy/" + name) | 	entry, err := req.Storage.Get("policy/" + role) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("error retrieving role: %s", err) | 		return nil, fmt.Errorf("error retrieving role: %s", err) | ||||||
| 	} | 	} | ||||||
| 	if entry == nil { | 	if entry == nil { | ||||||
| 		return logical.ErrorResponse(fmt.Sprintf("Role '%s' not found", name)), nil | 		return logical.ErrorResponse(fmt.Sprintf("role %q not found", role)), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var result roleConfig | 	var result roleConfig | ||||||
| @@ -56,7 +56,7 @@ func (b *backend) pathTokenRead( | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Generate a name for the token | 	// Generate a name for the token | ||||||
| 	tokenName := fmt.Sprintf("Vault %s %s %d", name, req.DisplayName, time.Now().UnixNano()) | 	tokenName := fmt.Sprintf("Vault %s %s %d", role, req.DisplayName, time.Now().UnixNano()) | ||||||
|  |  | ||||||
| 	// Create it | 	// Create it | ||||||
| 	token, _, err := c.ACL().Create(&api.ACLEntry{ | 	token, _, err := c.ACL().Create(&api.ACLEntry{ | ||||||
| @@ -73,6 +73,7 @@ func (b *backend) pathTokenRead( | |||||||
| 		"token": token, | 		"token": token, | ||||||
| 	}, map[string]interface{}{ | 	}, map[string]interface{}{ | ||||||
| 		"token": token, | 		"token": token, | ||||||
|  | 		"role":  role, | ||||||
| 	}) | 	}) | ||||||
| 	s.Secret.TTL = result.Lease | 	s.Secret.TTL = result.Lease | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,6 +1,8 @@ | |||||||
| package consul | package consul | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"fmt" | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/vault/logical" | 	"github.com/hashicorp/vault/logical" | ||||||
| 	"github.com/hashicorp/vault/logical/framework" | 	"github.com/hashicorp/vault/logical/framework" | ||||||
| ) | ) | ||||||
| @@ -26,8 +28,30 @@ func secretToken(b *backend) *framework.Secret { | |||||||
|  |  | ||||||
| func (b *backend) secretTokenRenew( | func (b *backend) secretTokenRenew( | ||||||
| 	req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | 	req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 	roleRaw, ok := req.Secret.InternalData["role"] | ||||||
|  | 	if !ok || roleRaw == nil { | ||||||
|  | 		return framework.LeaseExtend(0, 0, b.System())(req, d) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	return framework.LeaseExtend(0, 0, b.System())(req, d) | 	role, ok := roleRaw.(string) | ||||||
|  | 	if !ok { | ||||||
|  | 		return framework.LeaseExtend(0, 0, b.System())(req, d) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	entry, err := req.Storage.Get("policy/" + role) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("error retrieving role: %s", err) | ||||||
|  | 	} | ||||||
|  | 	if entry == nil { | ||||||
|  | 		return logical.ErrorResponse(fmt.Sprintf("issuing role %q not found", role)), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var result roleConfig | ||||||
|  | 	if err := entry.DecodeJSON(&result); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return framework.LeaseExtend(result.Lease, 0, b.System())(req, d) | ||||||
| } | } | ||||||
|  |  | ||||||
| func secretTokenRevoke( | func secretTokenRevoke( | ||||||
|   | |||||||
| @@ -1,6 +1,7 @@ | |||||||
| package database | package database | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/rpc" | 	"net/rpc" | ||||||
| 	"strings" | 	"strings" | ||||||
| @@ -28,6 +29,12 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { | |||||||
| 	b.Backend = &framework.Backend{ | 	b.Backend = &framework.Backend{ | ||||||
| 		Help: strings.TrimSpace(backendHelp), | 		Help: strings.TrimSpace(backendHelp), | ||||||
|  |  | ||||||
|  | 		PathsSpecial: &logical.Paths{ | ||||||
|  | 			SealWrapStorage: []string{ | ||||||
|  | 				"config/*", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  |  | ||||||
| 		Paths: []*framework.Path{ | 		Paths: []*framework.Path{ | ||||||
| 			pathListPluginConnection(&b), | 			pathListPluginConnection(&b), | ||||||
| 			pathConfigurePluginConnection(&b), | 			pathConfigurePluginConnection(&b), | ||||||
| @@ -81,7 +88,7 @@ func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) { | |||||||
| // This function creates a new db object from the stored configuration and | // This function creates a new db object from the stored configuration and | ||||||
| // caches it in the connections map. The caller of this function needs to hold | // caches it in the connections map. The caller of this function needs to hold | ||||||
| // the backend's write lock | // the backend's write lock | ||||||
| func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) { | func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, name string) (dbplugin.Database, error) { | ||||||
| 	db, ok := b.connections[name] | 	db, ok := b.connections[name] | ||||||
| 	if ok { | 	if ok { | ||||||
| 		return db, nil | 		return db, nil | ||||||
| @@ -97,7 +104,7 @@ func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin. | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err = db.Initialize(config.ConnectionDetails, true) | 	err = db.Initialize(ctx, config.ConnectionDetails, true) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @@ -124,6 +131,21 @@ func (b *databaseBackend) DatabaseConfig(s logical.Storage, name string) (*Datab | |||||||
| 	return &config, nil | 	return &config, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type upgradeStatements struct { | ||||||
|  | 	// This json tag has a typo in it, the new version does not. This | ||||||
|  | 	// necessitates this upgrade logic. | ||||||
|  | 	CreationStatements   string `json:"creation_statments"` | ||||||
|  | 	RevocationStatements string `json:"revocation_statements"` | ||||||
|  | 	RollbackStatements   string `json:"rollback_statements"` | ||||||
|  | 	RenewStatements      string `json:"renew_statements"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type upgradeCheck struct { | ||||||
|  | 	// This json tag has a typo in it, the new version does not. This | ||||||
|  | 	// necessitates this upgrade logic. | ||||||
|  | 	Statements upgradeStatements `json:"statments"` | ||||||
|  | } | ||||||
|  |  | ||||||
| func (b *databaseBackend) Role(s logical.Storage, roleName string) (*roleEntry, error) { | func (b *databaseBackend) Role(s logical.Storage, roleName string) (*roleEntry, error) { | ||||||
| 	entry, err := s.Get("role/" + roleName) | 	entry, err := s.Get("role/" + roleName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -133,11 +155,24 @@ func (b *databaseBackend) Role(s logical.Storage, roleName string) (*roleEntry, | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	var upgradeCh upgradeCheck | ||||||
|  | 	if err := entry.DecodeJSON(&upgradeCh); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	var result roleEntry | 	var result roleEntry | ||||||
| 	if err := entry.DecodeJSON(&result); err != nil { | 	if err := entry.DecodeJSON(&result); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	empty := upgradeCheck{} | ||||||
|  | 	if upgradeCh != empty { | ||||||
|  | 		result.Statements.CreationStatements = upgradeCh.Statements.CreationStatements | ||||||
|  | 		result.Statements.RevocationStatements = upgradeCh.Statements.RevocationStatements | ||||||
|  | 		result.Statements.RollbackStatements = upgradeCh.Statements.RollbackStatements | ||||||
|  | 		result.Statements.RenewStatements = upgradeCh.Statements.RenewStatements | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	return &result, nil | 	return &result, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -164,7 +199,8 @@ func (b *databaseBackend) clearConnection(name string) { | |||||||
|  |  | ||||||
| func (b *databaseBackend) closeIfShutdown(name string, err error) { | func (b *databaseBackend) closeIfShutdown(name string, err error) { | ||||||
| 	// Plugin has shutdown, close it so next call can reconnect. | 	// Plugin has shutdown, close it so next call can reconnect. | ||||||
| 	if err == rpc.ErrShutdown { | 	switch err { | ||||||
|  | 	case rpc.ErrShutdown, dbplugin.ErrPluginShutdown: | ||||||
| 		b.Lock() | 		b.Lock() | ||||||
| 		b.clearConnection(name) | 		b.clearConnection(name) | ||||||
| 		b.Unlock() | 		b.Unlock() | ||||||
|   | |||||||
| @@ -116,6 +116,55 @@ func TestBackend_PluginMain(t *testing.T) { | |||||||
| 	postgresql.Run(apiClientMeta.GetTLSConfig()) | 	postgresql.Run(apiClientMeta.GetTLSConfig()) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestBackend_RoleUpgrade(t *testing.T) { | ||||||
|  |  | ||||||
|  | 	storage := &logical.InmemStorage{} | ||||||
|  | 	backend := &databaseBackend{} | ||||||
|  |  | ||||||
|  | 	roleEnt := &roleEntry{ | ||||||
|  | 		Statements: dbplugin.Statements{ | ||||||
|  | 			CreationStatements: "test", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	entry, err := logical.StorageEntryJSON("role/test", roleEnt) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if err := storage.Put(entry); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	role, err := backend.Role(storage, "test") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !reflect.DeepEqual(role, roleEnt) { | ||||||
|  | 		t.Fatal("bad role %#v", role) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Upgrade case | ||||||
|  | 	badJSON := `{"statments":{"creation_statments":"test","revocation_statements":"","rollback_statements":"","renew_statements":""}}` | ||||||
|  | 	entry = &logical.StorageEntry{ | ||||||
|  | 		Key:   "role/test", | ||||||
|  | 		Value: []byte(badJSON), | ||||||
|  | 	} | ||||||
|  | 	if err := storage.Put(entry); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	role, err = backend.Role(storage, "test") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !reflect.DeepEqual(role, roleEnt) { | ||||||
|  | 		t.Fatal("bad role %#v", role) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestBackend_config_connection(t *testing.T) { | func TestBackend_config_connection(t *testing.T) { | ||||||
| 	var resp *logical.Response | 	var resp *logical.Response | ||||||
| 	var err error | 	var err error | ||||||
| @@ -488,9 +537,11 @@ func TestBackend_roleCrud(t *testing.T) { | |||||||
| 		RevocationStatements: defaultRevocationSQL, | 		RevocationStatements: defaultRevocationSQL, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var actual dbplugin.Statements | 	actual := dbplugin.Statements{ | ||||||
| 	if err := mapstructure.Decode(resp.Data, &actual); err != nil { | 		CreationStatements:   resp.Data["creation_statements"].(string), | ||||||
| 		t.Fatal(err) | 		RevocationStatements: resp.Data["revocation_statements"].(string), | ||||||
|  | 		RollbackStatements:   resp.Data["rollback_statements"].(string), | ||||||
|  | 		RenewStatements:      resp.Data["renew_statements"].(string), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if !reflect.DeepEqual(expected, actual) { | 	if !reflect.DeepEqual(expected, actual) { | ||||||
| @@ -609,6 +660,40 @@ func TestBackend_allowedRoles(t *testing.T) { | |||||||
| 		t.Fatalf("expected error to be:%s got:%#v\n", logical.ErrPermissionDenied, err) | 		t.Fatalf("expected error to be:%s got:%#v\n", logical.ErrPermissionDenied, err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// update connection with glob allowed roles connection | ||||||
|  | 	data = map[string]interface{}{ | ||||||
|  | 		"connection_url": connURL, | ||||||
|  | 		"plugin_name":    "postgresql-database-plugin", | ||||||
|  | 		"allowed_roles":  "allow*", | ||||||
|  | 	} | ||||||
|  | 	req = &logical.Request{ | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Path:      "config/plugin-test", | ||||||
|  | 		Storage:   config.StorageView, | ||||||
|  | 		Data:      data, | ||||||
|  | 	} | ||||||
|  | 	resp, err = b.HandleRequest(req) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("err:%s resp:%#v\n", err, resp) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Get creds, should work. | ||||||
|  | 	data = map[string]interface{}{} | ||||||
|  | 	req = &logical.Request{ | ||||||
|  | 		Operation: logical.ReadOperation, | ||||||
|  | 		Path:      "creds/allowed", | ||||||
|  | 		Storage:   config.StorageView, | ||||||
|  | 		Data:      data, | ||||||
|  | 	} | ||||||
|  | 	credsResp, err = b.HandleRequest(req) | ||||||
|  | 	if err != nil || (credsResp != nil && credsResp.IsError()) { | ||||||
|  | 		t.Fatalf("err:%s resp:%#v\n", err, credsResp) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !testCredsExist(t, credsResp, connURL) { | ||||||
|  | 		t.Fatalf("Creds should exist") | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// update connection with * allowed roles connection | 	// update connection with * allowed roles connection | ||||||
| 	data = map[string]interface{}{ | 	data = map[string]interface{}{ | ||||||
| 		"connection_url": connURL, | 		"connection_url": connURL, | ||||||
|   | |||||||
| @@ -1,10 +1,8 @@ | |||||||
| package dbplugin | package dbplugin | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"errors" | ||||||
| 	"net/rpc" |  | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/go-plugin" | 	"github.com/hashicorp/go-plugin" | ||||||
| 	"github.com/hashicorp/vault/helper/pluginutil" | 	"github.com/hashicorp/vault/helper/pluginutil" | ||||||
| @@ -17,11 +15,11 @@ type DatabasePluginClient struct { | |||||||
| 	client *plugin.Client | 	client *plugin.Client | ||||||
| 	sync.Mutex | 	sync.Mutex | ||||||
|  |  | ||||||
| 	*databasePluginRPCClient | 	Database | ||||||
| } | } | ||||||
|  |  | ||||||
| func (dc *DatabasePluginClient) Close() error { | func (dc *DatabasePluginClient) Close() error { | ||||||
| 	err := dc.databasePluginRPCClient.Close() | 	err := dc.Database.Close() | ||||||
| 	dc.client.Kill() | 	dc.client.Kill() | ||||||
|  |  | ||||||
| 	return err | 	return err | ||||||
| @@ -55,79 +53,20 @@ func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginR | |||||||
|  |  | ||||||
| 	// We should have a database type now. This feels like a normal interface | 	// We should have a database type now. This feels like a normal interface | ||||||
| 	// implementation but is in fact over an RPC connection. | 	// implementation but is in fact over an RPC connection. | ||||||
| 	databaseRPC := raw.(*databasePluginRPCClient) | 	var db Database | ||||||
|  | 	switch raw.(type) { | ||||||
|  | 	case *gRPCClient: | ||||||
|  | 		db = raw.(*gRPCClient) | ||||||
|  | 	case *databasePluginRPCClient: | ||||||
|  | 		logger.Warn("database: plugin is using deprecated net RPC transport, recompile plugin to upgrade to gRPC", "plugin", pluginRunner.Name) | ||||||
|  | 		db = raw.(*databasePluginRPCClient) | ||||||
|  | 	default: | ||||||
|  | 		return nil, errors.New("unsupported client type") | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// Wrap RPC implimentation in DatabasePluginClient | 	// Wrap RPC implimentation in DatabasePluginClient | ||||||
| 	return &DatabasePluginClient{ | 	return &DatabasePluginClient{ | ||||||
| 		client:                  client, | 		client:   client, | ||||||
| 		databasePluginRPCClient: databaseRPC, | 		Database: db, | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // ---- RPC client domain ---- |  | ||||||
|  |  | ||||||
| // databasePluginRPCClient implements Database and is used on the client to |  | ||||||
| // make RPC calls to a plugin. |  | ||||||
| type databasePluginRPCClient struct { |  | ||||||
| 	client *rpc.Client |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (dr *databasePluginRPCClient) Type() (string, error) { |  | ||||||
| 	var dbType string |  | ||||||
| 	err := dr.client.Call("Plugin.Type", struct{}{}, &dbType) |  | ||||||
|  |  | ||||||
| 	return fmt.Sprintf("plugin-%s", dbType), err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { |  | ||||||
| 	req := CreateUserRequest{ |  | ||||||
| 		Statements:     statements, |  | ||||||
| 		UsernameConfig: usernameConfig, |  | ||||||
| 		Expiration:     expiration, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	var resp CreateUserResponse |  | ||||||
| 	err = dr.client.Call("Plugin.CreateUser", req, &resp) |  | ||||||
|  |  | ||||||
| 	return resp.Username, resp.Password, err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (dr *databasePluginRPCClient) RenewUser(statements Statements, username string, expiration time.Time) error { |  | ||||||
| 	req := RenewUserRequest{ |  | ||||||
| 		Statements: statements, |  | ||||||
| 		Username:   username, |  | ||||||
| 		Expiration: expiration, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	err := dr.client.Call("Plugin.RenewUser", req, &struct{}{}) |  | ||||||
|  |  | ||||||
| 	return err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username string) error { |  | ||||||
| 	req := RevokeUserRequest{ |  | ||||||
| 		Statements: statements, |  | ||||||
| 		Username:   username, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{}) |  | ||||||
|  |  | ||||||
| 	return err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}, verifyConnection bool) error { |  | ||||||
| 	req := InitializeRequest{ |  | ||||||
| 		Config:           conf, |  | ||||||
| 		VerifyConnection: verifyConnection, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	err := dr.client.Call("Plugin.Initialize", req, &struct{}{}) |  | ||||||
|  |  | ||||||
| 	return err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (dr *databasePluginRPCClient) Close() error { |  | ||||||
| 	err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) |  | ||||||
|  |  | ||||||
| 	return err |  | ||||||
| } |  | ||||||
|   | |||||||
							
								
								
									
										556
									
								
								builtin/logical/database/dbplugin/database.pb.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										556
									
								
								builtin/logical/database/dbplugin/database.pb.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,556 @@ | |||||||
|  | // Code generated by protoc-gen-go. DO NOT EDIT. | ||||||
|  | // source: builtin/logical/database/dbplugin/database.proto | ||||||
|  |  | ||||||
|  | /* | ||||||
|  | Package dbplugin is a generated protocol buffer package. | ||||||
|  |  | ||||||
|  | It is generated from these files: | ||||||
|  | 	builtin/logical/database/dbplugin/database.proto | ||||||
|  |  | ||||||
|  | It has these top-level messages: | ||||||
|  | 	InitializeRequest | ||||||
|  | 	CreateUserRequest | ||||||
|  | 	RenewUserRequest | ||||||
|  | 	RevokeUserRequest | ||||||
|  | 	Statements | ||||||
|  | 	UsernameConfig | ||||||
|  | 	CreateUserResponse | ||||||
|  | 	TypeResponse | ||||||
|  | 	Empty | ||||||
|  | */ | ||||||
|  | package dbplugin | ||||||
|  |  | ||||||
|  | import proto "github.com/golang/protobuf/proto" | ||||||
|  | import fmt "fmt" | ||||||
|  | import math "math" | ||||||
|  | import google_protobuf "github.com/golang/protobuf/ptypes/timestamp" | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	context "golang.org/x/net/context" | ||||||
|  | 	grpc "google.golang.org/grpc" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // Reference imports to suppress errors if they are not otherwise used. | ||||||
|  | var _ = proto.Marshal | ||||||
|  | var _ = fmt.Errorf | ||||||
|  | var _ = math.Inf | ||||||
|  |  | ||||||
|  | // This is a compile-time assertion to ensure that this generated file | ||||||
|  | // is compatible with the proto package it is being compiled against. | ||||||
|  | // A compilation error at this line likely means your copy of the | ||||||
|  | // proto package needs to be updated. | ||||||
|  | const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package | ||||||
|  |  | ||||||
|  | type InitializeRequest struct { | ||||||
|  | 	Config           []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"` | ||||||
|  | 	VerifyConnection bool   `protobuf:"varint,2,opt,name=verify_connection,json=verifyConnection" json:"verify_connection,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *InitializeRequest) Reset()                    { *m = InitializeRequest{} } | ||||||
|  | func (m *InitializeRequest) String() string            { return proto.CompactTextString(m) } | ||||||
|  | func (*InitializeRequest) ProtoMessage()               {} | ||||||
|  | func (*InitializeRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } | ||||||
|  |  | ||||||
|  | func (m *InitializeRequest) GetConfig() []byte { | ||||||
|  | 	if m != nil { | ||||||
|  | 		return m.Config | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *InitializeRequest) GetVerifyConnection() bool { | ||||||
|  | 	if m != nil { | ||||||
|  | 		return m.VerifyConnection | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type CreateUserRequest struct { | ||||||
|  | 	Statements     *Statements                `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"` | ||||||
|  | 	UsernameConfig *UsernameConfig            `protobuf:"bytes,2,opt,name=username_config,json=usernameConfig" json:"username_config,omitempty"` | ||||||
|  | 	Expiration     *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *CreateUserRequest) Reset()                    { *m = CreateUserRequest{} } | ||||||
|  | func (m *CreateUserRequest) String() string            { return proto.CompactTextString(m) } | ||||||
|  | func (*CreateUserRequest) ProtoMessage()               {} | ||||||
|  | func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } | ||||||
|  |  | ||||||
|  | func (m *CreateUserRequest) GetStatements() *Statements { | ||||||
|  | 	if m != nil { | ||||||
|  | 		return m.Statements | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *CreateUserRequest) GetUsernameConfig() *UsernameConfig { | ||||||
|  | 	if m != nil { | ||||||
|  | 		return m.UsernameConfig | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *CreateUserRequest) GetExpiration() *google_protobuf.Timestamp { | ||||||
|  | 	if m != nil { | ||||||
|  | 		return m.Expiration | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type RenewUserRequest struct { | ||||||
|  | 	Statements *Statements                `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"` | ||||||
|  | 	Username   string                     `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"` | ||||||
|  | 	Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *RenewUserRequest) Reset()                    { *m = RenewUserRequest{} } | ||||||
|  | func (m *RenewUserRequest) String() string            { return proto.CompactTextString(m) } | ||||||
|  | func (*RenewUserRequest) ProtoMessage()               {} | ||||||
|  | func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} } | ||||||
|  |  | ||||||
|  | func (m *RenewUserRequest) GetStatements() *Statements { | ||||||
|  | 	if m != nil { | ||||||
|  | 		return m.Statements | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *RenewUserRequest) GetUsername() string { | ||||||
|  | 	if m != nil { | ||||||
|  | 		return m.Username | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *RenewUserRequest) GetExpiration() *google_protobuf.Timestamp { | ||||||
|  | 	if m != nil { | ||||||
|  | 		return m.Expiration | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type RevokeUserRequest struct { | ||||||
|  | 	Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"` | ||||||
|  | 	Username   string      `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *RevokeUserRequest) Reset()                    { *m = RevokeUserRequest{} } | ||||||
|  | func (m *RevokeUserRequest) String() string            { return proto.CompactTextString(m) } | ||||||
|  | func (*RevokeUserRequest) ProtoMessage()               {} | ||||||
|  | func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} } | ||||||
|  |  | ||||||
|  | func (m *RevokeUserRequest) GetStatements() *Statements { | ||||||
|  | 	if m != nil { | ||||||
|  | 		return m.Statements | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *RevokeUserRequest) GetUsername() string { | ||||||
|  | 	if m != nil { | ||||||
|  | 		return m.Username | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Statements struct { | ||||||
|  | 	CreationStatements   string `protobuf:"bytes,1,opt,name=creation_statements,json=creationStatements" json:"creation_statements,omitempty"` | ||||||
|  | 	RevocationStatements string `protobuf:"bytes,2,opt,name=revocation_statements,json=revocationStatements" json:"revocation_statements,omitempty"` | ||||||
|  | 	RollbackStatements   string `protobuf:"bytes,3,opt,name=rollback_statements,json=rollbackStatements" json:"rollback_statements,omitempty"` | ||||||
|  | 	RenewStatements      string `protobuf:"bytes,4,opt,name=renew_statements,json=renewStatements" json:"renew_statements,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *Statements) Reset()                    { *m = Statements{} } | ||||||
|  | func (m *Statements) String() string            { return proto.CompactTextString(m) } | ||||||
|  | func (*Statements) ProtoMessage()               {} | ||||||
|  | func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} } | ||||||
|  |  | ||||||
|  | func (m *Statements) GetCreationStatements() string { | ||||||
|  | 	if m != nil { | ||||||
|  | 		return m.CreationStatements | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *Statements) GetRevocationStatements() string { | ||||||
|  | 	if m != nil { | ||||||
|  | 		return m.RevocationStatements | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *Statements) GetRollbackStatements() string { | ||||||
|  | 	if m != nil { | ||||||
|  | 		return m.RollbackStatements | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *Statements) GetRenewStatements() string { | ||||||
|  | 	if m != nil { | ||||||
|  | 		return m.RenewStatements | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type UsernameConfig struct { | ||||||
|  | 	DisplayName string `protobuf:"bytes,1,opt,name=DisplayName" json:"DisplayName,omitempty"` | ||||||
|  | 	RoleName    string `protobuf:"bytes,2,opt,name=RoleName" json:"RoleName,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *UsernameConfig) Reset()                    { *m = UsernameConfig{} } | ||||||
|  | func (m *UsernameConfig) String() string            { return proto.CompactTextString(m) } | ||||||
|  | func (*UsernameConfig) ProtoMessage()               {} | ||||||
|  | func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} } | ||||||
|  |  | ||||||
|  | func (m *UsernameConfig) GetDisplayName() string { | ||||||
|  | 	if m != nil { | ||||||
|  | 		return m.DisplayName | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *UsernameConfig) GetRoleName() string { | ||||||
|  | 	if m != nil { | ||||||
|  | 		return m.RoleName | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type CreateUserResponse struct { | ||||||
|  | 	Username string `protobuf:"bytes,1,opt,name=username" json:"username,omitempty"` | ||||||
|  | 	Password string `protobuf:"bytes,2,opt,name=password" json:"password,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *CreateUserResponse) Reset()                    { *m = CreateUserResponse{} } | ||||||
|  | func (m *CreateUserResponse) String() string            { return proto.CompactTextString(m) } | ||||||
|  | func (*CreateUserResponse) ProtoMessage()               {} | ||||||
|  | func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} } | ||||||
|  |  | ||||||
|  | func (m *CreateUserResponse) GetUsername() string { | ||||||
|  | 	if m != nil { | ||||||
|  | 		return m.Username | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *CreateUserResponse) GetPassword() string { | ||||||
|  | 	if m != nil { | ||||||
|  | 		return m.Password | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type TypeResponse struct { | ||||||
|  | 	Type string `protobuf:"bytes,1,opt,name=type" json:"type,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *TypeResponse) Reset()                    { *m = TypeResponse{} } | ||||||
|  | func (m *TypeResponse) String() string            { return proto.CompactTextString(m) } | ||||||
|  | func (*TypeResponse) ProtoMessage()               {} | ||||||
|  | func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} } | ||||||
|  |  | ||||||
|  | func (m *TypeResponse) GetType() string { | ||||||
|  | 	if m != nil { | ||||||
|  | 		return m.Type | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Empty struct { | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *Empty) Reset()                    { *m = Empty{} } | ||||||
|  | func (m *Empty) String() string            { return proto.CompactTextString(m) } | ||||||
|  | func (*Empty) ProtoMessage()               {} | ||||||
|  | func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} } | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	proto.RegisterType((*InitializeRequest)(nil), "dbplugin.InitializeRequest") | ||||||
|  | 	proto.RegisterType((*CreateUserRequest)(nil), "dbplugin.CreateUserRequest") | ||||||
|  | 	proto.RegisterType((*RenewUserRequest)(nil), "dbplugin.RenewUserRequest") | ||||||
|  | 	proto.RegisterType((*RevokeUserRequest)(nil), "dbplugin.RevokeUserRequest") | ||||||
|  | 	proto.RegisterType((*Statements)(nil), "dbplugin.Statements") | ||||||
|  | 	proto.RegisterType((*UsernameConfig)(nil), "dbplugin.UsernameConfig") | ||||||
|  | 	proto.RegisterType((*CreateUserResponse)(nil), "dbplugin.CreateUserResponse") | ||||||
|  | 	proto.RegisterType((*TypeResponse)(nil), "dbplugin.TypeResponse") | ||||||
|  | 	proto.RegisterType((*Empty)(nil), "dbplugin.Empty") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Reference imports to suppress errors if they are not otherwise used. | ||||||
|  | var _ context.Context | ||||||
|  | var _ grpc.ClientConn | ||||||
|  |  | ||||||
|  | // This is a compile-time assertion to ensure that this generated file | ||||||
|  | // is compatible with the grpc package it is being compiled against. | ||||||
|  | const _ = grpc.SupportPackageIsVersion4 | ||||||
|  |  | ||||||
|  | // Client API for Database service | ||||||
|  |  | ||||||
|  | type DatabaseClient interface { | ||||||
|  | 	Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error) | ||||||
|  | 	CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) | ||||||
|  | 	RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error) | ||||||
|  | 	RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error) | ||||||
|  | 	Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) | ||||||
|  | 	Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type databaseClient struct { | ||||||
|  | 	cc *grpc.ClientConn | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewDatabaseClient(cc *grpc.ClientConn) DatabaseClient { | ||||||
|  | 	return &databaseClient{cc} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *databaseClient) Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error) { | ||||||
|  | 	out := new(TypeResponse) | ||||||
|  | 	err := grpc.Invoke(ctx, "/dbplugin.Database/Type", in, out, c.cc, opts...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return out, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *databaseClient) CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) { | ||||||
|  | 	out := new(CreateUserResponse) | ||||||
|  | 	err := grpc.Invoke(ctx, "/dbplugin.Database/CreateUser", in, out, c.cc, opts...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return out, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *databaseClient) RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error) { | ||||||
|  | 	out := new(Empty) | ||||||
|  | 	err := grpc.Invoke(ctx, "/dbplugin.Database/RenewUser", in, out, c.cc, opts...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return out, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *databaseClient) RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error) { | ||||||
|  | 	out := new(Empty) | ||||||
|  | 	err := grpc.Invoke(ctx, "/dbplugin.Database/RevokeUser", in, out, c.cc, opts...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return out, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) { | ||||||
|  | 	out := new(Empty) | ||||||
|  | 	err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return out, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *databaseClient) Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) { | ||||||
|  | 	out := new(Empty) | ||||||
|  | 	err := grpc.Invoke(ctx, "/dbplugin.Database/Close", in, out, c.cc, opts...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return out, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Server API for Database service | ||||||
|  |  | ||||||
|  | type DatabaseServer interface { | ||||||
|  | 	Type(context.Context, *Empty) (*TypeResponse, error) | ||||||
|  | 	CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error) | ||||||
|  | 	RenewUser(context.Context, *RenewUserRequest) (*Empty, error) | ||||||
|  | 	RevokeUser(context.Context, *RevokeUserRequest) (*Empty, error) | ||||||
|  | 	Initialize(context.Context, *InitializeRequest) (*Empty, error) | ||||||
|  | 	Close(context.Context, *Empty) (*Empty, error) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func RegisterDatabaseServer(s *grpc.Server, srv DatabaseServer) { | ||||||
|  | 	s.RegisterService(&_Database_serviceDesc, srv) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func _Database_Type_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { | ||||||
|  | 	in := new(Empty) | ||||||
|  | 	if err := dec(in); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if interceptor == nil { | ||||||
|  | 		return srv.(DatabaseServer).Type(ctx, in) | ||||||
|  | 	} | ||||||
|  | 	info := &grpc.UnaryServerInfo{ | ||||||
|  | 		Server:     srv, | ||||||
|  | 		FullMethod: "/dbplugin.Database/Type", | ||||||
|  | 	} | ||||||
|  | 	handler := func(ctx context.Context, req interface{}) (interface{}, error) { | ||||||
|  | 		return srv.(DatabaseServer).Type(ctx, req.(*Empty)) | ||||||
|  | 	} | ||||||
|  | 	return interceptor(ctx, in, info, handler) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func _Database_CreateUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { | ||||||
|  | 	in := new(CreateUserRequest) | ||||||
|  | 	if err := dec(in); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if interceptor == nil { | ||||||
|  | 		return srv.(DatabaseServer).CreateUser(ctx, in) | ||||||
|  | 	} | ||||||
|  | 	info := &grpc.UnaryServerInfo{ | ||||||
|  | 		Server:     srv, | ||||||
|  | 		FullMethod: "/dbplugin.Database/CreateUser", | ||||||
|  | 	} | ||||||
|  | 	handler := func(ctx context.Context, req interface{}) (interface{}, error) { | ||||||
|  | 		return srv.(DatabaseServer).CreateUser(ctx, req.(*CreateUserRequest)) | ||||||
|  | 	} | ||||||
|  | 	return interceptor(ctx, in, info, handler) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func _Database_RenewUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { | ||||||
|  | 	in := new(RenewUserRequest) | ||||||
|  | 	if err := dec(in); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if interceptor == nil { | ||||||
|  | 		return srv.(DatabaseServer).RenewUser(ctx, in) | ||||||
|  | 	} | ||||||
|  | 	info := &grpc.UnaryServerInfo{ | ||||||
|  | 		Server:     srv, | ||||||
|  | 		FullMethod: "/dbplugin.Database/RenewUser", | ||||||
|  | 	} | ||||||
|  | 	handler := func(ctx context.Context, req interface{}) (interface{}, error) { | ||||||
|  | 		return srv.(DatabaseServer).RenewUser(ctx, req.(*RenewUserRequest)) | ||||||
|  | 	} | ||||||
|  | 	return interceptor(ctx, in, info, handler) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func _Database_RevokeUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { | ||||||
|  | 	in := new(RevokeUserRequest) | ||||||
|  | 	if err := dec(in); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if interceptor == nil { | ||||||
|  | 		return srv.(DatabaseServer).RevokeUser(ctx, in) | ||||||
|  | 	} | ||||||
|  | 	info := &grpc.UnaryServerInfo{ | ||||||
|  | 		Server:     srv, | ||||||
|  | 		FullMethod: "/dbplugin.Database/RevokeUser", | ||||||
|  | 	} | ||||||
|  | 	handler := func(ctx context.Context, req interface{}) (interface{}, error) { | ||||||
|  | 		return srv.(DatabaseServer).RevokeUser(ctx, req.(*RevokeUserRequest)) | ||||||
|  | 	} | ||||||
|  | 	return interceptor(ctx, in, info, handler) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { | ||||||
|  | 	in := new(InitializeRequest) | ||||||
|  | 	if err := dec(in); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if interceptor == nil { | ||||||
|  | 		return srv.(DatabaseServer).Initialize(ctx, in) | ||||||
|  | 	} | ||||||
|  | 	info := &grpc.UnaryServerInfo{ | ||||||
|  | 		Server:     srv, | ||||||
|  | 		FullMethod: "/dbplugin.Database/Initialize", | ||||||
|  | 	} | ||||||
|  | 	handler := func(ctx context.Context, req interface{}) (interface{}, error) { | ||||||
|  | 		return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest)) | ||||||
|  | 	} | ||||||
|  | 	return interceptor(ctx, in, info, handler) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func _Database_Close_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { | ||||||
|  | 	in := new(Empty) | ||||||
|  | 	if err := dec(in); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if interceptor == nil { | ||||||
|  | 		return srv.(DatabaseServer).Close(ctx, in) | ||||||
|  | 	} | ||||||
|  | 	info := &grpc.UnaryServerInfo{ | ||||||
|  | 		Server:     srv, | ||||||
|  | 		FullMethod: "/dbplugin.Database/Close", | ||||||
|  | 	} | ||||||
|  | 	handler := func(ctx context.Context, req interface{}) (interface{}, error) { | ||||||
|  | 		return srv.(DatabaseServer).Close(ctx, req.(*Empty)) | ||||||
|  | 	} | ||||||
|  | 	return interceptor(ctx, in, info, handler) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var _Database_serviceDesc = grpc.ServiceDesc{ | ||||||
|  | 	ServiceName: "dbplugin.Database", | ||||||
|  | 	HandlerType: (*DatabaseServer)(nil), | ||||||
|  | 	Methods: []grpc.MethodDesc{ | ||||||
|  | 		{ | ||||||
|  | 			MethodName: "Type", | ||||||
|  | 			Handler:    _Database_Type_Handler, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			MethodName: "CreateUser", | ||||||
|  | 			Handler:    _Database_CreateUser_Handler, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			MethodName: "RenewUser", | ||||||
|  | 			Handler:    _Database_RenewUser_Handler, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			MethodName: "RevokeUser", | ||||||
|  | 			Handler:    _Database_RevokeUser_Handler, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			MethodName: "Initialize", | ||||||
|  | 			Handler:    _Database_Initialize_Handler, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			MethodName: "Close", | ||||||
|  | 			Handler:    _Database_Close_Handler, | ||||||
|  | 		}, | ||||||
|  | 	}, | ||||||
|  | 	Streams:  []grpc.StreamDesc{}, | ||||||
|  | 	Metadata: "builtin/logical/database/dbplugin/database.proto", | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func init() { proto.RegisterFile("builtin/logical/database/dbplugin/database.proto", fileDescriptor0) } | ||||||
|  |  | ||||||
|  | var fileDescriptor0 = []byte{ | ||||||
|  | 	// 548 bytes of a gzipped FileDescriptorProto | ||||||
|  | 	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x54, 0xcf, 0x6e, 0xd3, 0x4e, | ||||||
|  | 	0x10, 0x96, 0xdb, 0xb4, 0xbf, 0x64, 0x5a, 0x35, 0xc9, 0xfe, 0x4a, 0x15, 0x19, 0x24, 0x22, 0x9f, | ||||||
|  | 	0x5a, 0x21, 0xd9, 0xa8, 0xe5, 0x80, 0xb8, 0xa1, 0x14, 0x21, 0x24, 0x94, 0x83, 0x69, 0x25, 0x6e, | ||||||
|  | 	0xd1, 0xda, 0x99, 0x44, 0xab, 0x3a, 0xbb, 0xc6, 0xbb, 0x4e, 0x09, 0x4f, 0xc3, 0xe3, 0x70, 0xe2, | ||||||
|  | 	0x1d, 0x78, 0x13, 0xe4, 0x75, 0xd6, 0xbb, 0xf9, 0x73, 0xab, 0xb8, 0x79, 0x66, 0xbe, 0x6f, 0xf6, | ||||||
|  | 	0xf3, 0xb7, 0x33, 0x0b, 0xaf, 0x93, 0x92, 0x65, 0x8a, 0xf1, 0x28, 0x13, 0x73, 0x96, 0xd2, 0x2c, | ||||||
|  | 	0x9a, 0x52, 0x45, 0x13, 0x2a, 0x31, 0x9a, 0x26, 0x79, 0x56, 0xce, 0x19, 0x6f, 0x32, 0x61, 0x5e, | ||||||
|  | 	0x08, 0x25, 0x48, 0xdb, 0x14, 0xfc, 0x97, 0x73, 0x21, 0xe6, 0x19, 0x46, 0x3a, 0x9f, 0x94, 0xb3, | ||||||
|  | 	0x48, 0xb1, 0x05, 0x4a, 0x45, 0x17, 0x79, 0x0d, 0x0d, 0xbe, 0x42, 0xff, 0x13, 0x67, 0x8a, 0xd1, | ||||||
|  | 	0x8c, 0xfd, 0xc0, 0x18, 0xbf, 0x95, 0x28, 0x15, 0xb9, 0x80, 0xe3, 0x54, 0xf0, 0x19, 0x9b, 0x0f, | ||||||
|  | 	0xbc, 0xa1, 0x77, 0x79, 0x1a, 0xaf, 0x23, 0xf2, 0x0a, 0xfa, 0x4b, 0x2c, 0xd8, 0x6c, 0x35, 0x49, | ||||||
|  | 	0x05, 0xe7, 0x98, 0x2a, 0x26, 0xf8, 0xe0, 0x60, 0xe8, 0x5d, 0xb6, 0xe3, 0x5e, 0x5d, 0x18, 0x35, | ||||||
|  | 	0xf9, 0xe0, 0x97, 0x07, 0xfd, 0x51, 0x81, 0x54, 0xe1, 0xbd, 0xc4, 0xc2, 0xb4, 0x7e, 0x03, 0x20, | ||||||
|  | 	0x15, 0x55, 0xb8, 0x40, 0xae, 0xa4, 0x6e, 0x7f, 0x72, 0x7d, 0x1e, 0x1a, 0xbd, 0xe1, 0x97, 0xa6, | ||||||
|  | 	0x16, 0x3b, 0x38, 0xf2, 0x1e, 0xba, 0xa5, 0xc4, 0x82, 0xd3, 0x05, 0x4e, 0xd6, 0xca, 0x0e, 0x34, | ||||||
|  | 	0x75, 0x60, 0xa9, 0xf7, 0x6b, 0xc0, 0x48, 0xd7, 0xe3, 0xb3, 0x72, 0x23, 0x26, 0xef, 0x00, 0xf0, | ||||||
|  | 	0x7b, 0xce, 0x0a, 0xaa, 0x45, 0x1f, 0x6a, 0xb6, 0x1f, 0xd6, 0xf6, 0x84, 0xc6, 0x9e, 0xf0, 0xce, | ||||||
|  | 	0xd8, 0x13, 0x3b, 0xe8, 0xe0, 0xa7, 0x07, 0xbd, 0x18, 0x39, 0x3e, 0x3e, 0xfd, 0x4f, 0x7c, 0x68, | ||||||
|  | 	0x1b, 0x61, 0xfa, 0x17, 0x3a, 0x71, 0x13, 0x3f, 0x49, 0x22, 0x42, 0x3f, 0xc6, 0xa5, 0x78, 0xc0, | ||||||
|  | 	0x7f, 0x2a, 0x31, 0xf8, 0xed, 0x01, 0x58, 0x1a, 0x89, 0xe0, 0xff, 0xb4, 0xba, 0x62, 0x26, 0xf8, | ||||||
|  | 	0x64, 0xeb, 0xa4, 0x4e, 0x4c, 0x4c, 0xc9, 0x21, 0xdc, 0xc0, 0xb3, 0x02, 0x97, 0x22, 0xdd, 0xa1, | ||||||
|  | 	0xd4, 0x07, 0x9d, 0xdb, 0xe2, 0xe6, 0x29, 0x85, 0xc8, 0xb2, 0x84, 0xa6, 0x0f, 0x2e, 0xe5, 0xb0, | ||||||
|  | 	0x3e, 0xc5, 0x94, 0x1c, 0xc2, 0x15, 0xf4, 0x8a, 0xea, 0xba, 0x5c, 0x74, 0x4b, 0xa3, 0xbb, 0x3a, | ||||||
|  | 	0x6f, 0xa1, 0xc1, 0x18, 0xce, 0x36, 0x07, 0x87, 0x0c, 0xe1, 0xe4, 0x96, 0xc9, 0x3c, 0xa3, 0xab, | ||||||
|  | 	0x71, 0xe5, 0x40, 0xfd, 0x2f, 0x6e, 0xaa, 0x32, 0x28, 0x16, 0x19, 0x8e, 0x1d, 0x83, 0x4c, 0x1c, | ||||||
|  | 	0x7c, 0x06, 0xe2, 0x0e, 0xbd, 0xcc, 0x05, 0x97, 0xb8, 0x61, 0xa9, 0xb7, 0x75, 0xeb, 0x3e, 0xb4, | ||||||
|  | 	0x73, 0x2a, 0xe5, 0xa3, 0x28, 0xa6, 0xa6, 0x9b, 0x89, 0x83, 0x00, 0x4e, 0xef, 0x56, 0x39, 0x36, | ||||||
|  | 	0x7d, 0x08, 0xb4, 0xd4, 0x2a, 0x37, 0x3d, 0xf4, 0x77, 0xf0, 0x1f, 0x1c, 0x7d, 0x58, 0xe4, 0x6a, | ||||||
|  | 	0x75, 0xfd, 0xe7, 0x00, 0xda, 0xb7, 0xeb, 0x87, 0x80, 0x44, 0xd0, 0xaa, 0x98, 0xa4, 0x6b, 0xaf, | ||||||
|  | 	0x5b, 0xa3, 0xfc, 0x0b, 0x9b, 0xd8, 0x68, 0xfd, 0x11, 0xc0, 0x0a, 0x27, 0xcf, 0x2d, 0x6a, 0x67, | ||||||
|  | 	0x87, 0xfd, 0x17, 0xfb, 0x8b, 0xeb, 0x46, 0x6f, 0xa1, 0xd3, 0xec, 0x0a, 0xf1, 0x2d, 0x74, 0x7b, | ||||||
|  | 	0x81, 0xfc, 0x6d, 0x69, 0xd5, 0xfc, 0xdb, 0x19, 0x76, 0x25, 0xec, 0x4c, 0xf6, 0x5e, 0xae, 0x7d, | ||||||
|  | 	0xc7, 0x5c, 0xee, 0xce, 0xeb, 0xb6, 0xcb, 0xbd, 0x82, 0xa3, 0x51, 0x26, 0xe4, 0x1e, 0xb3, 0xb6, | ||||||
|  | 	0x13, 0xc9, 0xb1, 0x5e, 0xc3, 0x9b, 0xbf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x8c, 0x55, 0x84, 0x56, | ||||||
|  | 	0x94, 0x05, 0x00, 0x00, | ||||||
|  | } | ||||||
							
								
								
									
										58
									
								
								builtin/logical/database/dbplugin/database.proto
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								builtin/logical/database/dbplugin/database.proto
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,58 @@ | |||||||
|  | syntax = "proto3"; | ||||||
|  | package dbplugin; | ||||||
|  |  | ||||||
|  | import "google/protobuf/timestamp.proto"; | ||||||
|  |  | ||||||
|  | message InitializeRequest { | ||||||
|  | 	bytes config = 1; | ||||||
|  | 	bool verify_connection = 2; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | message CreateUserRequest { | ||||||
|  | 	Statements     statements = 1; | ||||||
|  | 	UsernameConfig username_config = 2; | ||||||
|  | 	google.protobuf.Timestamp expiration = 3; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | message RenewUserRequest { | ||||||
|  | 	Statements statements = 1; | ||||||
|  | 	string username = 2; | ||||||
|  | 	google.protobuf.Timestamp expiration = 3; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | message RevokeUserRequest { | ||||||
|  | 	Statements statements = 1; | ||||||
|  | 	string username = 2; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | message Statements { | ||||||
|  | 	string creation_statements = 1; | ||||||
|  | 	string revocation_statements = 2; | ||||||
|  | 	string rollback_statements  = 3; | ||||||
|  | 	string renew_statements = 4; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | message UsernameConfig { | ||||||
|  | 	string DisplayName = 1; | ||||||
|  | 	string RoleName = 2; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | message CreateUserResponse { | ||||||
|  | 	string username = 1; | ||||||
|  | 	string password = 2; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | message TypeResponse { | ||||||
|  |     string type = 1; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | message Empty {} | ||||||
|  |  | ||||||
|  | service Database { | ||||||
|  |     rpc Type(Empty) returns (TypeResponse); | ||||||
|  |     rpc CreateUser(CreateUserRequest) returns (CreateUserResponse); | ||||||
|  |     rpc RenewUser(RenewUserRequest) returns (Empty); | ||||||
|  |     rpc RevokeUser(RevokeUserRequest) returns (Empty); | ||||||
|  |     rpc Initialize(InitializeRequest) returns (Empty); | ||||||
|  |     rpc Close(Empty) returns (Empty); | ||||||
|  | } | ||||||
| @@ -1,6 +1,7 @@ | |||||||
| package dbplugin | package dbplugin | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	metrics "github.com/armon/go-metrics" | 	metrics "github.com/armon/go-metrics" | ||||||
| @@ -15,55 +16,56 @@ type databaseTracingMiddleware struct { | |||||||
| 	next   Database | 	next   Database | ||||||
| 	logger log.Logger | 	logger log.Logger | ||||||
|  |  | ||||||
| 	typeStr string | 	typeStr   string | ||||||
|  | 	transport string | ||||||
| } | } | ||||||
|  |  | ||||||
| func (mw *databaseTracingMiddleware) Type() (string, error) { | func (mw *databaseTracingMiddleware) Type() (string, error) { | ||||||
| 	return mw.next.Type() | 	return mw.next.Type() | ||||||
| } | } | ||||||
|  |  | ||||||
| func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { | func (mw *databaseTracingMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||||
| 	defer func(then time.Time) { | 	defer func(then time.Time) { | ||||||
| 		mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) | 		mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then)) | ||||||
| 	}(time.Now()) | 	}(time.Now()) | ||||||
|  |  | ||||||
| 	mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr) | 	mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr, "transport", mw.transport) | ||||||
| 	return mw.next.CreateUser(statements, usernameConfig, expiration) | 	return mw.next.CreateUser(ctx, statements, usernameConfig, expiration) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { | func (mw *databaseTracingMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) { | ||||||
| 	defer func(then time.Time) { | 	defer func(then time.Time) { | ||||||
| 		mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) | 		mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then)) | ||||||
| 	}(time.Now()) | 	}(time.Now()) | ||||||
|  |  | ||||||
| 	mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr) | 	mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr, "transport", mw.transport) | ||||||
| 	return mw.next.RenewUser(statements, username, expiration) | 	return mw.next.RenewUser(ctx, statements, username, expiration) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) { | func (mw *databaseTracingMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) { | ||||||
| 	defer func(then time.Time) { | 	defer func(then time.Time) { | ||||||
| 		mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) | 		mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then)) | ||||||
| 	}(time.Now()) | 	}(time.Now()) | ||||||
|  |  | ||||||
| 	mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr) | 	mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr, "transport", mw.transport) | ||||||
| 	return mw.next.RevokeUser(statements, username) | 	return mw.next.RevokeUser(ctx, statements, username) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) { | func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) { | ||||||
| 	defer func(then time.Time) { | 	defer func(then time.Time) { | ||||||
| 		mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then)) | 		mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "verify", verifyConnection, "err", err, "took", time.Since(then)) | ||||||
| 	}(time.Now()) | 	}(time.Now()) | ||||||
|  |  | ||||||
| 	mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr) | 	mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr, "transport", mw.transport) | ||||||
| 	return mw.next.Initialize(conf, verifyConnection) | 	return mw.next.Initialize(ctx, conf, verifyConnection) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (mw *databaseTracingMiddleware) Close() (err error) { | func (mw *databaseTracingMiddleware) Close() (err error) { | ||||||
| 	defer func(then time.Time) { | 	defer func(then time.Time) { | ||||||
| 		mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) | 		mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then)) | ||||||
| 	}(time.Now()) | 	}(time.Now()) | ||||||
|  |  | ||||||
| 	mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr) | 	mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr, "transport", mw.transport) | ||||||
| 	return mw.next.Close() | 	return mw.next.Close() | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -81,7 +83,7 @@ func (mw *databaseMetricsMiddleware) Type() (string, error) { | |||||||
| 	return mw.next.Type() | 	return mw.next.Type() | ||||||
| } | } | ||||||
|  |  | ||||||
| func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { | func (mw *databaseMetricsMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||||
| 	defer func(now time.Time) { | 	defer func(now time.Time) { | ||||||
| 		metrics.MeasureSince([]string{"database", "CreateUser"}, now) | 		metrics.MeasureSince([]string{"database", "CreateUser"}, now) | ||||||
| 		metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now) | 		metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now) | ||||||
| @@ -94,10 +96,10 @@ func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameC | |||||||
|  |  | ||||||
| 	metrics.IncrCounter([]string{"database", "CreateUser"}, 1) | 	metrics.IncrCounter([]string{"database", "CreateUser"}, 1) | ||||||
| 	metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1) | 	metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1) | ||||||
| 	return mw.next.CreateUser(statements, usernameConfig, expiration) | 	return mw.next.CreateUser(ctx, statements, usernameConfig, expiration) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { | func (mw *databaseMetricsMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) { | ||||||
| 	defer func(now time.Time) { | 	defer func(now time.Time) { | ||||||
| 		metrics.MeasureSince([]string{"database", "RenewUser"}, now) | 		metrics.MeasureSince([]string{"database", "RenewUser"}, now) | ||||||
| 		metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now) | 		metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now) | ||||||
| @@ -110,10 +112,10 @@ func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username s | |||||||
|  |  | ||||||
| 	metrics.IncrCounter([]string{"database", "RenewUser"}, 1) | 	metrics.IncrCounter([]string{"database", "RenewUser"}, 1) | ||||||
| 	metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1) | 	metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1) | ||||||
| 	return mw.next.RenewUser(statements, username, expiration) | 	return mw.next.RenewUser(ctx, statements, username, expiration) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username string) (err error) { | func (mw *databaseMetricsMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) { | ||||||
| 	defer func(now time.Time) { | 	defer func(now time.Time) { | ||||||
| 		metrics.MeasureSince([]string{"database", "RevokeUser"}, now) | 		metrics.MeasureSince([]string{"database", "RevokeUser"}, now) | ||||||
| 		metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now) | 		metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now) | ||||||
| @@ -126,10 +128,10 @@ func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username | |||||||
|  |  | ||||||
| 	metrics.IncrCounter([]string{"database", "RevokeUser"}, 1) | 	metrics.IncrCounter([]string{"database", "RevokeUser"}, 1) | ||||||
| 	metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1) | 	metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1) | ||||||
| 	return mw.next.RevokeUser(statements, username) | 	return mw.next.RevokeUser(ctx, statements, username) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) { | func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) { | ||||||
| 	defer func(now time.Time) { | 	defer func(now time.Time) { | ||||||
| 		metrics.MeasureSince([]string{"database", "Initialize"}, now) | 		metrics.MeasureSince([]string{"database", "Initialize"}, now) | ||||||
| 		metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now) | 		metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now) | ||||||
| @@ -142,7 +144,7 @@ func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, ver | |||||||
|  |  | ||||||
| 	metrics.IncrCounter([]string{"database", "Initialize"}, 1) | 	metrics.IncrCounter([]string{"database", "Initialize"}, 1) | ||||||
| 	metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1) | 	metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1) | ||||||
| 	return mw.next.Initialize(conf, verifyConnection) | 	return mw.next.Initialize(ctx, conf, verifyConnection) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (mw *databaseMetricsMiddleware) Close() (err error) { | func (mw *databaseMetricsMiddleware) Close() (err error) { | ||||||
|   | |||||||
							
								
								
									
										198
									
								
								builtin/logical/database/dbplugin/grpc_transport.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										198
									
								
								builtin/logical/database/dbplugin/grpc_transport.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,198 @@ | |||||||
|  | package dbplugin | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"google.golang.org/grpc" | ||||||
|  | 	"google.golang.org/grpc/connectivity" | ||||||
|  |  | ||||||
|  | 	"github.com/golang/protobuf/ptypes" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	ErrPluginShutdown = errors.New("plugin shutdown") | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // ---- gRPC Server domain ---- | ||||||
|  |  | ||||||
|  | type gRPCServer struct { | ||||||
|  | 	impl Database | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *gRPCServer) Type(context.Context, *Empty) (*TypeResponse, error) { | ||||||
|  | 	t, err := s.impl.Type() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &TypeResponse{ | ||||||
|  | 		Type: t, | ||||||
|  | 	}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *gRPCServer) CreateUser(ctx context.Context, req *CreateUserRequest) (*CreateUserResponse, error) { | ||||||
|  | 	e, err := ptypes.Timestamp(req.Expiration) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	u, p, err := s.impl.CreateUser(ctx, *req.Statements, *req.UsernameConfig, e) | ||||||
|  |  | ||||||
|  | 	return &CreateUserResponse{ | ||||||
|  | 		Username: u, | ||||||
|  | 		Password: p, | ||||||
|  | 	}, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *gRPCServer) RenewUser(ctx context.Context, req *RenewUserRequest) (*Empty, error) { | ||||||
|  | 	e, err := ptypes.Timestamp(req.Expiration) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	err = s.impl.RenewUser(ctx, *req.Statements, req.Username, e) | ||||||
|  | 	return &Empty{}, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *gRPCServer) RevokeUser(ctx context.Context, req *RevokeUserRequest) (*Empty, error) { | ||||||
|  | 	err := s.impl.RevokeUser(ctx, *req.Statements, req.Username) | ||||||
|  | 	return &Empty{}, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) { | ||||||
|  | 	config := map[string]interface{}{} | ||||||
|  |  | ||||||
|  | 	err := json.Unmarshal(req.Config, &config) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = s.impl.Initialize(ctx, config, req.VerifyConnection) | ||||||
|  | 	return &Empty{}, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) { | ||||||
|  | 	s.impl.Close() | ||||||
|  | 	return &Empty{}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---- gRPC client domain ---- | ||||||
|  |  | ||||||
|  | type gRPCClient struct { | ||||||
|  | 	client     DatabaseClient | ||||||
|  | 	clientConn *grpc.ClientConn | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c gRPCClient) Type() (string, error) { | ||||||
|  | 	// If the plugin has already shutdown, this will hang forever so we give it | ||||||
|  | 	// a one second timeout. | ||||||
|  | 	ctx, cancel := context.WithTimeout(context.Background(), time.Second) | ||||||
|  | 	defer cancel() | ||||||
|  |  | ||||||
|  | 	switch c.clientConn.GetState() { | ||||||
|  | 	case connectivity.Ready, connectivity.Idle: | ||||||
|  | 	default: | ||||||
|  | 		return "", ErrPluginShutdown | ||||||
|  | 	} | ||||||
|  | 	resp, err := c.client.Type(ctx, &Empty{}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return resp.Type, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||||
|  | 	t, err := ptypes.TimestampProto(expiration) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", "", err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	switch c.clientConn.GetState() { | ||||||
|  | 	case connectivity.Ready, connectivity.Idle: | ||||||
|  | 	default: | ||||||
|  | 		return "", "", ErrPluginShutdown | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	resp, err := c.client.CreateUser(ctx, &CreateUserRequest{ | ||||||
|  | 		Statements:     &statements, | ||||||
|  | 		UsernameConfig: &usernameConfig, | ||||||
|  | 		Expiration:     t, | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", "", err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return resp.Username, resp.Password, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *gRPCClient) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error { | ||||||
|  | 	t, err := ptypes.TimestampProto(expiration) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	switch c.clientConn.GetState() { | ||||||
|  | 	case connectivity.Ready, connectivity.Idle: | ||||||
|  | 	default: | ||||||
|  | 		return ErrPluginShutdown | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, err = c.client.RenewUser(ctx, &RenewUserRequest{ | ||||||
|  | 		Statements: &statements, | ||||||
|  | 		Username:   username, | ||||||
|  | 		Expiration: t, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *gRPCClient) RevokeUser(ctx context.Context, statements Statements, username string) error { | ||||||
|  | 	switch c.clientConn.GetState() { | ||||||
|  | 	case connectivity.Ready, connectivity.Idle: | ||||||
|  | 	default: | ||||||
|  | 		return ErrPluginShutdown | ||||||
|  | 	} | ||||||
|  | 	_, err := c.client.RevokeUser(ctx, &RevokeUserRequest{ | ||||||
|  | 		Statements: &statements, | ||||||
|  | 		Username:   username, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *gRPCClient) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error { | ||||||
|  | 	configRaw, err := json.Marshal(config) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	switch c.clientConn.GetState() { | ||||||
|  | 	case connectivity.Ready, connectivity.Idle: | ||||||
|  | 	default: | ||||||
|  | 		return ErrPluginShutdown | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, err = c.client.Initialize(ctx, &InitializeRequest{ | ||||||
|  | 		Config:           configRaw, | ||||||
|  | 		VerifyConnection: verifyConnection, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *gRPCClient) Close() error { | ||||||
|  | 	// If the plugin has already shutdown, this will hang forever so we give it | ||||||
|  | 	// a one second timeout. | ||||||
|  | 	ctx, cancel := context.WithTimeout(context.Background(), time.Second) | ||||||
|  | 	defer cancel() | ||||||
|  | 	switch c.clientConn.GetState() { | ||||||
|  | 	case connectivity.Ready, connectivity.Idle: | ||||||
|  | 		_, err := c.client.Close(ctx, &Empty{}) | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
							
								
								
									
										139
									
								
								builtin/logical/database/dbplugin/netrpc_transport.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										139
									
								
								builtin/logical/database/dbplugin/netrpc_transport.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,139 @@ | |||||||
|  | package dbplugin | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/rpc" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // ---- RPC server domain ---- | ||||||
|  |  | ||||||
|  | // databasePluginRPCServer implements an RPC version of Database and is run | ||||||
|  | // inside a plugin. It wraps an underlying implementation of Database. | ||||||
|  | type databasePluginRPCServer struct { | ||||||
|  | 	impl Database | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { | ||||||
|  | 	var err error | ||||||
|  | 	*resp, err = ds.impl.Type() | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequestRPC, resp *CreateUserResponse) error { | ||||||
|  | 	var err error | ||||||
|  | 	resp.Username, resp.Password, err = ds.impl.CreateUser(context.Background(), args.Statements, args.UsernameConfig, args.Expiration) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequestRPC, _ *struct{}) error { | ||||||
|  | 	err := ds.impl.RenewUser(context.Background(), args.Statements, args.Username, args.Expiration) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequestRPC, _ *struct{}) error { | ||||||
|  | 	err := ds.impl.RevokeUser(context.Background(), args.Statements, args.Username) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error { | ||||||
|  | 	err := ds.impl.Initialize(context.Background(), args.Config, args.VerifyConnection) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error { | ||||||
|  | 	ds.impl.Close() | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---- RPC client domain ---- | ||||||
|  | // databasePluginRPCClient implements Database and is used on the client to | ||||||
|  | // make RPC calls to a plugin. | ||||||
|  | type databasePluginRPCClient struct { | ||||||
|  | 	client *rpc.Client | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (dr *databasePluginRPCClient) Type() (string, error) { | ||||||
|  | 	var dbType string | ||||||
|  | 	err := dr.client.Call("Plugin.Type", struct{}{}, &dbType) | ||||||
|  |  | ||||||
|  | 	return fmt.Sprintf("plugin-%s", dbType), err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (dr *databasePluginRPCClient) CreateUser(_ context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||||
|  | 	req := CreateUserRequestRPC{ | ||||||
|  | 		Statements:     statements, | ||||||
|  | 		UsernameConfig: usernameConfig, | ||||||
|  | 		Expiration:     expiration, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var resp CreateUserResponse | ||||||
|  | 	err = dr.client.Call("Plugin.CreateUser", req, &resp) | ||||||
|  |  | ||||||
|  | 	return resp.Username, resp.Password, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (dr *databasePluginRPCClient) RenewUser(_ context.Context, statements Statements, username string, expiration time.Time) error { | ||||||
|  | 	req := RenewUserRequestRPC{ | ||||||
|  | 		Statements: statements, | ||||||
|  | 		Username:   username, | ||||||
|  | 		Expiration: expiration, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err := dr.client.Call("Plugin.RenewUser", req, &struct{}{}) | ||||||
|  |  | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error { | ||||||
|  | 	req := RevokeUserRequestRPC{ | ||||||
|  | 		Statements: statements, | ||||||
|  | 		Username:   username, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{}) | ||||||
|  |  | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error { | ||||||
|  | 	req := InitializeRequestRPC{ | ||||||
|  | 		Config:           conf, | ||||||
|  | 		VerifyConnection: verifyConnection, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err := dr.client.Call("Plugin.Initialize", req, &struct{}{}) | ||||||
|  |  | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (dr *databasePluginRPCClient) Close() error { | ||||||
|  | 	err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) | ||||||
|  |  | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---- RPC Request Args Domain ---- | ||||||
|  |  | ||||||
|  | type InitializeRequestRPC struct { | ||||||
|  | 	Config           map[string]interface{} | ||||||
|  | 	VerifyConnection bool | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type CreateUserRequestRPC struct { | ||||||
|  | 	Statements     Statements | ||||||
|  | 	UsernameConfig UsernameConfig | ||||||
|  | 	Expiration     time.Time | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type RenewUserRequestRPC struct { | ||||||
|  | 	Statements Statements | ||||||
|  | 	Username   string | ||||||
|  | 	Expiration time.Time | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type RevokeUserRequestRPC struct { | ||||||
|  | 	Statements Statements | ||||||
|  | 	Username   string | ||||||
|  | } | ||||||
| @@ -1,10 +1,13 @@ | |||||||
| package dbplugin | package dbplugin | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/rpc" | 	"net/rpc" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"google.golang.org/grpc" | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/go-plugin" | 	"github.com/hashicorp/go-plugin" | ||||||
| 	"github.com/hashicorp/vault/helper/pluginutil" | 	"github.com/hashicorp/vault/helper/pluginutil" | ||||||
| 	log "github.com/mgutz/logxi/v1" | 	log "github.com/mgutz/logxi/v1" | ||||||
| @@ -13,29 +16,14 @@ import ( | |||||||
| // Database is the interface that all database objects must implement. | // Database is the interface that all database objects must implement. | ||||||
| type Database interface { | type Database interface { | ||||||
| 	Type() (string, error) | 	Type() (string, error) | ||||||
| 	CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) | 	CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) | ||||||
| 	RenewUser(statements Statements, username string, expiration time.Time) error | 	RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error | ||||||
| 	RevokeUser(statements Statements, username string) error | 	RevokeUser(ctx context.Context, statements Statements, username string) error | ||||||
|  |  | ||||||
| 	Initialize(config map[string]interface{}, verifyConnection bool) error | 	Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error | ||||||
| 	Close() error | 	Close() error | ||||||
| } | } | ||||||
|  |  | ||||||
| // Statements set in role creation and passed into the database type's functions. |  | ||||||
| type Statements struct { |  | ||||||
| 	CreationStatements   string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` |  | ||||||
| 	RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` |  | ||||||
| 	RollbackStatements   string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"` |  | ||||||
| 	RenewStatements      string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // UsernameConfig is used to configure prefixes for the username to be |  | ||||||
| // generated. |  | ||||||
| type UsernameConfig struct { |  | ||||||
| 	DisplayName string |  | ||||||
| 	RoleName    string |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // PluginFactory is used to build plugin database types. It wraps the database | // PluginFactory is used to build plugin database types. It wraps the database | ||||||
| // object in a logging and metrics middleware. | // object in a logging and metrics middleware. | ||||||
| func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) { | func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) { | ||||||
| @@ -45,6 +33,7 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log. | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	var transport string | ||||||
| 	var db Database | 	var db Database | ||||||
| 	if pluginRunner.Builtin { | 	if pluginRunner.Builtin { | ||||||
| 		// Plugin is builtin so we can retrieve an instance of the interface | 		// Plugin is builtin so we can retrieve an instance of the interface | ||||||
| @@ -60,12 +49,24 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log. | |||||||
| 			return nil, fmt.Errorf("unsuported database type: %s", pluginName) | 			return nil, fmt.Errorf("unsuported database type: %s", pluginName) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		transport = "builtin" | ||||||
|  |  | ||||||
| 	} else { | 	} else { | ||||||
| 		// create a DatabasePluginClient instance | 		// create a DatabasePluginClient instance | ||||||
| 		db, err = newPluginClient(sys, pluginRunner, logger) | 		db, err = newPluginClient(sys, pluginRunner, logger) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Switch on the underlying database client type to get the transport | ||||||
|  | 		// method. | ||||||
|  | 		switch db.(*DatabasePluginClient).Database.(type) { | ||||||
|  | 		case *gRPCClient: | ||||||
|  | 			transport = "gRPC" | ||||||
|  | 		case *databasePluginRPCClient: | ||||||
|  | 			transport = "netRPC" | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	typeStr, err := db.Type() | 	typeStr, err := db.Type() | ||||||
| @@ -82,9 +83,10 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log. | |||||||
| 	// Wrap with tracing middleware | 	// Wrap with tracing middleware | ||||||
| 	if logger.IsTrace() { | 	if logger.IsTrace() { | ||||||
| 		db = &databaseTracingMiddleware{ | 		db = &databaseTracingMiddleware{ | ||||||
| 			next:    db, | 			transport: transport, | ||||||
| 			typeStr: typeStr, | 			next:      db, | ||||||
| 			logger:  logger, | 			typeStr:   typeStr, | ||||||
|  | 			logger:    logger, | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -115,33 +117,14 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e | |||||||
| 	return &databasePluginRPCClient{client: c}, nil | 	return &databasePluginRPCClient{client: c}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // ---- RPC Request Args Domain ---- | func (d DatabasePlugin) GRPCServer(s *grpc.Server) error { | ||||||
|  | 	RegisterDatabaseServer(s, &gRPCServer{impl: d.impl}) | ||||||
| type InitializeRequest struct { | 	return nil | ||||||
| 	Config           map[string]interface{} |  | ||||||
| 	VerifyConnection bool |  | ||||||
| } | } | ||||||
|  |  | ||||||
| type CreateUserRequest struct { | func (DatabasePlugin) GRPCClient(c *grpc.ClientConn) (interface{}, error) { | ||||||
| 	Statements     Statements | 	return &gRPCClient{ | ||||||
| 	UsernameConfig UsernameConfig | 		client:     NewDatabaseClient(c), | ||||||
| 	Expiration     time.Time | 		clientConn: c, | ||||||
| } | 	}, nil | ||||||
|  |  | ||||||
| type RenewUserRequest struct { |  | ||||||
| 	Statements Statements |  | ||||||
| 	Username   string |  | ||||||
| 	Expiration time.Time |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type RevokeUserRequest struct { |  | ||||||
| 	Statements Statements |  | ||||||
| 	Username   string |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ---- RPC Response Args Domain ---- |  | ||||||
|  |  | ||||||
| type CreateUserResponse struct { |  | ||||||
| 	Username string |  | ||||||
| 	Password string |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,11 +1,13 @@ | |||||||
| package dbplugin_test | package dbplugin_test | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"os" | 	"os" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	plugin "github.com/hashicorp/go-plugin" | ||||||
| 	"github.com/hashicorp/vault/builtin/logical/database/dbplugin" | 	"github.com/hashicorp/vault/builtin/logical/database/dbplugin" | ||||||
| 	"github.com/hashicorp/vault/helper/pluginutil" | 	"github.com/hashicorp/vault/helper/pluginutil" | ||||||
| 	vaulthttp "github.com/hashicorp/vault/http" | 	vaulthttp "github.com/hashicorp/vault/http" | ||||||
| @@ -20,7 +22,7 @@ type mockPlugin struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (m *mockPlugin) Type() (string, error) { return "mock", nil } | func (m *mockPlugin) Type() (string, error) { return "mock", nil } | ||||||
| func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | func (m *mockPlugin) CreateUser(_ context.Context, statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||||
| 	err = errors.New("err") | 	err = errors.New("err") | ||||||
| 	if usernameConf.DisplayName == "" || expiration.IsZero() { | 	if usernameConf.DisplayName == "" || expiration.IsZero() { | ||||||
| 		return "", "", err | 		return "", "", err | ||||||
| @@ -34,7 +36,7 @@ func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbp | |||||||
|  |  | ||||||
| 	return usernameConf.DisplayName, "test", nil | 	return usernameConf.DisplayName, "test", nil | ||||||
| } | } | ||||||
| func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { | func (m *mockPlugin) RenewUser(_ context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { | ||||||
| 	err := errors.New("err") | 	err := errors.New("err") | ||||||
| 	if username == "" || expiration.IsZero() { | 	if username == "" || expiration.IsZero() { | ||||||
| 		return err | 		return err | ||||||
| @@ -46,7 +48,7 @@ func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string, | |||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) error { | func (m *mockPlugin) RevokeUser(_ context.Context, statements dbplugin.Statements, username string) error { | ||||||
| 	err := errors.New("err") | 	err := errors.New("err") | ||||||
| 	if username == "" { | 	if username == "" { | ||||||
| 		return err | 		return err | ||||||
| @@ -59,7 +61,7 @@ func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) | |||||||
| 	delete(m.users, username) | 	delete(m.users, username) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| func (m *mockPlugin) Initialize(conf map[string]interface{}, _ bool) error { | func (m *mockPlugin) Initialize(_ context.Context, conf map[string]interface{}, _ bool) error { | ||||||
| 	err := errors.New("err") | 	err := errors.New("err") | ||||||
| 	if len(conf) != 1 { | 	if len(conf) != 1 { | ||||||
| 		return err | 		return err | ||||||
| @@ -80,14 +82,15 @@ func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) { | |||||||
| 	cores := cluster.Cores | 	cores := cluster.Cores | ||||||
|  |  | ||||||
| 	sys := vault.TestDynamicSystemView(cores[0].Core) | 	sys := vault.TestDynamicSystemView(cores[0].Core) | ||||||
| 	vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_Main") | 	vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_GRPC_Main") | ||||||
|  | 	vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin-netRPC", "TestPlugin_NetRPC_Main") | ||||||
|  |  | ||||||
| 	return cluster, sys | 	return cluster, sys | ||||||
| } | } | ||||||
|  |  | ||||||
| // This is not an actual test case, it's a helper function that will be executed | // This is not an actual test case, it's a helper function that will be executed | ||||||
| // by the go-plugin client via an exec call. | // by the go-plugin client via an exec call. | ||||||
| func TestPlugin_Main(t *testing.T) { | func TestPlugin_GRPC_Main(t *testing.T) { | ||||||
| 	if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { | 	if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -105,6 +108,30 @@ func TestPlugin_Main(t *testing.T) { | |||||||
| 	plugins.Serve(plugin, apiClientMeta.GetTLSConfig()) | 	plugins.Serve(plugin, apiClientMeta.GetTLSConfig()) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // This is not an actual test case, it's a helper function that will be executed | ||||||
|  | // by the go-plugin client via an exec call. | ||||||
|  | func TestPlugin_NetRPC_Main(t *testing.T) { | ||||||
|  | 	if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	p := &mockPlugin{ | ||||||
|  | 		users: make(map[string][]string), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	args := []string{"--tls-skip-verify=true"} | ||||||
|  |  | ||||||
|  | 	apiClientMeta := &pluginutil.APIClientMeta{} | ||||||
|  | 	flags := apiClientMeta.FlagSet() | ||||||
|  | 	flags.Parse(args) | ||||||
|  |  | ||||||
|  | 	tlsProvider := pluginutil.VaultPluginTLSProvider(apiClientMeta.GetTLSConfig()) | ||||||
|  | 	serveConf := dbplugin.ServeConfig(p, tlsProvider) | ||||||
|  | 	serveConf.GRPCServer = nil | ||||||
|  |  | ||||||
|  | 	plugin.Serve(serveConf) | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestPlugin_Initialize(t *testing.T) { | func TestPlugin_Initialize(t *testing.T) { | ||||||
| 	cluster, sys := getCluster(t) | 	cluster, sys := getCluster(t) | ||||||
| 	defer cluster.Cleanup() | 	defer cluster.Cleanup() | ||||||
| @@ -118,7 +145,7 @@ func TestPlugin_Initialize(t *testing.T) { | |||||||
| 		"test": 1, | 		"test": 1, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err = dbRaw.Initialize(connectionDetails, true) | 	err = dbRaw.Initialize(context.Background(), connectionDetails, true) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("err: %s", err) | 		t.Fatalf("err: %s", err) | ||||||
| 	} | 	} | ||||||
| @@ -143,7 +170,7 @@ func TestPlugin_CreateUser(t *testing.T) { | |||||||
| 		"test": 1, | 		"test": 1, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err = db.Initialize(connectionDetails, true) | 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("err: %s", err) | 		t.Fatalf("err: %s", err) | ||||||
| 	} | 	} | ||||||
| @@ -153,7 +180,7 @@ func TestPlugin_CreateUser(t *testing.T) { | |||||||
| 		RoleName:    "test", | 		RoleName:    "test", | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	us, pw, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | 	us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("err: %s", err) | 		t.Fatalf("err: %s", err) | ||||||
| 	} | 	} | ||||||
| @@ -163,7 +190,7 @@ func TestPlugin_CreateUser(t *testing.T) { | |||||||
|  |  | ||||||
| 	// try and save the same user again to verify it saved the first time, this | 	// try and save the same user again to verify it saved the first time, this | ||||||
| 	// should return an error | 	// should return an error | ||||||
| 	_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | 	_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		t.Fatal("expected an error, user wasn't created correctly") | 		t.Fatal("expected an error, user wasn't created correctly") | ||||||
| 	} | 	} | ||||||
| @@ -182,7 +209,7 @@ func TestPlugin_RenewUser(t *testing.T) { | |||||||
| 	connectionDetails := map[string]interface{}{ | 	connectionDetails := map[string]interface{}{ | ||||||
| 		"test": 1, | 		"test": 1, | ||||||
| 	} | 	} | ||||||
| 	err = db.Initialize(connectionDetails, true) | 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("err: %s", err) | 		t.Fatalf("err: %s", err) | ||||||
| 	} | 	} | ||||||
| @@ -192,12 +219,12 @@ func TestPlugin_RenewUser(t *testing.T) { | |||||||
| 		RoleName:    "test", | 		RoleName:    "test", | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | 	us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("err: %s", err) | 		t.Fatalf("err: %s", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err = db.RenewUser(dbplugin.Statements{}, us, time.Now().Add(time.Minute)) | 	err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("err: %s", err) | 		t.Fatalf("err: %s", err) | ||||||
| 	} | 	} | ||||||
| @@ -216,7 +243,7 @@ func TestPlugin_RevokeUser(t *testing.T) { | |||||||
| 	connectionDetails := map[string]interface{}{ | 	connectionDetails := map[string]interface{}{ | ||||||
| 		"test": 1, | 		"test": 1, | ||||||
| 	} | 	} | ||||||
| 	err = db.Initialize(connectionDetails, true) | 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("err: %s", err) | 		t.Fatalf("err: %s", err) | ||||||
| 	} | 	} | ||||||
| @@ -226,19 +253,159 @@ func TestPlugin_RevokeUser(t *testing.T) { | |||||||
| 		RoleName:    "test", | 		RoleName:    "test", | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | 	us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("err: %s", err) | 		t.Fatalf("err: %s", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Test default revoke statememts | 	// Test default revoke statememts | ||||||
| 	err = db.RevokeUser(dbplugin.Statements{}, us) | 	err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("err: %s", err) | 		t.Fatalf("err: %s", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Try adding the same username back so we can verify it was removed | 	// Try adding the same username back so we can verify it was removed | ||||||
| 	_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | 	_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("err: %s", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Test the code is still compatible with an old netRPC plugin | ||||||
|  | func TestPlugin_NetRPC_Initialize(t *testing.T) { | ||||||
|  | 	cluster, sys := getCluster(t) | ||||||
|  | 	defer cluster.Cleanup() | ||||||
|  |  | ||||||
|  | 	dbRaw, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("err: %s", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	connectionDetails := map[string]interface{}{ | ||||||
|  | 		"test": 1, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = dbRaw.Initialize(context.Background(), connectionDetails, true) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("err: %s", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = dbRaw.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("err: %s", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestPlugin_NetRPC_CreateUser(t *testing.T) { | ||||||
|  | 	cluster, sys := getCluster(t) | ||||||
|  | 	defer cluster.Cleanup() | ||||||
|  |  | ||||||
|  | 	db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("err: %s", err) | ||||||
|  | 	} | ||||||
|  | 	defer db.Close() | ||||||
|  |  | ||||||
|  | 	connectionDetails := map[string]interface{}{ | ||||||
|  | 		"test": 1, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("err: %s", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	usernameConf := dbplugin.UsernameConfig{ | ||||||
|  | 		DisplayName: "test", | ||||||
|  | 		RoleName:    "test", | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("err: %s", err) | ||||||
|  | 	} | ||||||
|  | 	if us != "test" || pw != "test" { | ||||||
|  | 		t.Fatal("expected username and password to be 'test'") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// try and save the same user again to verify it saved the first time, this | ||||||
|  | 	// should return an error | ||||||
|  | 	_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||||
|  | 	if err == nil { | ||||||
|  | 		t.Fatal("expected an error, user wasn't created correctly") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestPlugin_NetRPC_RenewUser(t *testing.T) { | ||||||
|  | 	cluster, sys := getCluster(t) | ||||||
|  | 	defer cluster.Cleanup() | ||||||
|  |  | ||||||
|  | 	db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("err: %s", err) | ||||||
|  | 	} | ||||||
|  | 	defer db.Close() | ||||||
|  |  | ||||||
|  | 	connectionDetails := map[string]interface{}{ | ||||||
|  | 		"test": 1, | ||||||
|  | 	} | ||||||
|  | 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("err: %s", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	usernameConf := dbplugin.UsernameConfig{ | ||||||
|  | 		DisplayName: "test", | ||||||
|  | 		RoleName:    "test", | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("err: %s", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("err: %s", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestPlugin_NetRPC_RevokeUser(t *testing.T) { | ||||||
|  | 	cluster, sys := getCluster(t) | ||||||
|  | 	defer cluster.Cleanup() | ||||||
|  |  | ||||||
|  | 	db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("err: %s", err) | ||||||
|  | 	} | ||||||
|  | 	defer db.Close() | ||||||
|  |  | ||||||
|  | 	connectionDetails := map[string]interface{}{ | ||||||
|  | 		"test": 1, | ||||||
|  | 	} | ||||||
|  | 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("err: %s", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	usernameConf := dbplugin.UsernameConfig{ | ||||||
|  | 		DisplayName: "test", | ||||||
|  | 		RoleName:    "test", | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("err: %s", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Test default revoke statememts | ||||||
|  | 	err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("err: %s", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Try adding the same username back so we can verify it was removed | ||||||
|  | 	_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("err: %s", err) | 		t.Fatalf("err: %s", err) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -10,6 +10,10 @@ import ( | |||||||
| // Database implementation in a databasePluginRPCServer object and starts a | // Database implementation in a databasePluginRPCServer object and starts a | ||||||
| // RPC server. | // RPC server. | ||||||
| func Serve(db Database, tlsProvider func() (*tls.Config, error)) { | func Serve(db Database, tlsProvider func() (*tls.Config, error)) { | ||||||
|  | 	plugin.Serve(ServeConfig(db, tlsProvider)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig { | ||||||
| 	dbPlugin := &DatabasePlugin{ | 	dbPlugin := &DatabasePlugin{ | ||||||
| 		impl: db, | 		impl: db, | ||||||
| 	} | 	} | ||||||
| @@ -19,53 +23,10 @@ func Serve(db Database, tlsProvider func() (*tls.Config, error)) { | |||||||
| 		"database": dbPlugin, | 		"database": dbPlugin, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	plugin.Serve(&plugin.ServeConfig{ | 	return &plugin.ServeConfig{ | ||||||
| 		HandshakeConfig: handshakeConfig, | 		HandshakeConfig: handshakeConfig, | ||||||
| 		Plugins:         pluginMap, | 		Plugins:         pluginMap, | ||||||
| 		TLSProvider:     tlsProvider, | 		TLSProvider:     tlsProvider, | ||||||
| 	}) | 		GRPCServer:      plugin.DefaultGRPCServer, | ||||||
| } | 	} | ||||||
|  |  | ||||||
| // ---- RPC server domain ---- |  | ||||||
|  |  | ||||||
| // databasePluginRPCServer implements an RPC version of Database and is run |  | ||||||
| // inside a plugin. It wraps an underlying implementation of Database. |  | ||||||
| type databasePluginRPCServer struct { |  | ||||||
| 	impl Database |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { |  | ||||||
| 	var err error |  | ||||||
| 	*resp, err = ds.impl.Type() |  | ||||||
| 	return err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error { |  | ||||||
| 	var err error |  | ||||||
| 	resp.Username, resp.Password, err = ds.impl.CreateUser(args.Statements, args.UsernameConfig, args.Expiration) |  | ||||||
|  |  | ||||||
| 	return err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequest, _ *struct{}) error { |  | ||||||
| 	err := ds.impl.RenewUser(args.Statements, args.Username, args.Expiration) |  | ||||||
|  |  | ||||||
| 	return err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct{}) error { |  | ||||||
| 	err := ds.impl.RevokeUser(args.Statements, args.Username) |  | ||||||
|  |  | ||||||
| 	return err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (ds *databasePluginRPCServer) Initialize(args *InitializeRequest, _ *struct{}) error { |  | ||||||
| 	err := ds.impl.Initialize(args.Config, args.VerifyConnection) |  | ||||||
|  |  | ||||||
| 	return err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error { |  | ||||||
| 	ds.impl.Close() |  | ||||||
| 	return nil |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,6 +1,7 @@ | |||||||
| package database | package database | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  |  | ||||||
| @@ -62,7 +63,7 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc { | |||||||
| 		b.clearConnection(name) | 		b.clearConnection(name) | ||||||
|  |  | ||||||
| 		// Execute plugin again, we don't need the object so throw away. | 		// Execute plugin again, we don't need the object so throw away. | ||||||
| 		_, err := b.createDBObj(req.Storage, name) | 		_, err := b.createDBObj(context.TODO(), req.Storage, name) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| @@ -230,7 +231,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { | |||||||
| 			return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil | 			return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		err = db.Initialize(config.ConnectionDetails, verifyConnection) | 		err = db.Initialize(context.TODO(), config.ConnectionDetails, verifyConnection) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			db.Close() | 			db.Close() | ||||||
| 			return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil | 			return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil | ||||||
|   | |||||||
| @@ -1,6 +1,7 @@ | |||||||
| package database | package database | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| @@ -49,7 +50,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { | |||||||
|  |  | ||||||
| 		// If role name isn't in the database's allowed roles, send back a | 		// If role name isn't in the database's allowed roles, send back a | ||||||
| 		// permission denied. | 		// permission denied. | ||||||
| 		if !strutil.StrListContains(dbConfig.AllowedRoles, "*") && !strutil.StrListContains(dbConfig.AllowedRoles, name) { | 		if !strutil.StrListContains(dbConfig.AllowedRoles, "*") && !strutil.StrListContainsGlob(dbConfig.AllowedRoles, name) { | ||||||
| 			return nil, logical.ErrPermissionDenied | 			return nil, logical.ErrPermissionDenied | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| @@ -66,7 +67,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { | |||||||
| 			unlockFunc = b.Unlock | 			unlockFunc = b.Unlock | ||||||
|  |  | ||||||
| 			// Create a new DB object | 			// Create a new DB object | ||||||
| 			db, err = b.createDBObj(req.Storage, role.DBName) | 			db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				unlockFunc() | 				unlockFunc() | ||||||
| 				return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) | 				return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) | ||||||
| @@ -81,7 +82,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Create the user | 		// Create the user | ||||||
| 		username, password, err := db.CreateUser(role.Statements, usernameConfig, expiration) | 		username, password, err := db.CreateUser(context.TODO(), role.Statements, usernameConfig, expiration) | ||||||
| 		// Unlock | 		// Unlock | ||||||
| 		unlockFunc() | 		unlockFunc() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|   | |||||||
| @@ -181,7 +181,7 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc { | |||||||
|  |  | ||||||
| type roleEntry struct { | type roleEntry struct { | ||||||
| 	DBName     string              `json:"db_name" mapstructure:"db_name" structs:"db_name"` | 	DBName     string              `json:"db_name" mapstructure:"db_name" structs:"db_name"` | ||||||
| 	Statements dbplugin.Statements `json:"statments" mapstructure:"statements" structs:"statments"` | 	Statements dbplugin.Statements `json:"statements" mapstructure:"statements" structs:"statements"` | ||||||
| 	DefaultTTL time.Duration       `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` | 	DefaultTTL time.Duration       `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` | ||||||
| 	MaxTTL     time.Duration       `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` | 	MaxTTL     time.Duration       `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,6 +1,7 @@ | |||||||
| package database | package database | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/vault/logical" | 	"github.com/hashicorp/vault/logical" | ||||||
| @@ -60,7 +61,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { | |||||||
| 			unlockFunc = b.Unlock | 			unlockFunc = b.Unlock | ||||||
|  |  | ||||||
| 			// Create a new DB object | 			// Create a new DB object | ||||||
| 			db, err = b.createDBObj(req.Storage, role.DBName) | 			db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				unlockFunc() | 				unlockFunc() | ||||||
| 				return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) | 				return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) | ||||||
| @@ -69,7 +70,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { | |||||||
|  |  | ||||||
| 		// Make sure we increase the VALID UNTIL endpoint for this user. | 		// Make sure we increase the VALID UNTIL endpoint for this user. | ||||||
| 		if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { | 		if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { | ||||||
| 			err := db.RenewUser(role.Statements, username, expireTime) | 			err := db.RenewUser(context.TODO(), role.Statements, username, expireTime) | ||||||
| 			// Unlock | 			// Unlock | ||||||
| 			unlockFunc() | 			unlockFunc() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @@ -119,14 +120,14 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { | |||||||
| 			unlockFunc = b.Unlock | 			unlockFunc = b.Unlock | ||||||
|  |  | ||||||
| 			// Create a new DB object | 			// Create a new DB object | ||||||
| 			db, err = b.createDBObj(req.Storage, role.DBName) | 			db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				unlockFunc() | 				unlockFunc() | ||||||
| 				return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) | 				return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		err = db.RevokeUser(role.Statements, username) | 		err = db.RevokeUser(context.TODO(), role.Statements, username) | ||||||
| 		// Unlock | 		// Unlock | ||||||
| 		unlockFunc() | 		unlockFunc() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|   | |||||||
| @@ -24,6 +24,12 @@ func Backend() *framework.Backend { | |||||||
| 	b.Backend = &framework.Backend{ | 	b.Backend = &framework.Backend{ | ||||||
| 		Help: strings.TrimSpace(backendHelp), | 		Help: strings.TrimSpace(backendHelp), | ||||||
|  |  | ||||||
|  | 		PathsSpecial: &logical.Paths{ | ||||||
|  | 			SealWrapStorage: []string{ | ||||||
|  | 				"config/connection", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  |  | ||||||
| 		Paths: []*framework.Path{ | 		Paths: []*framework.Path{ | ||||||
| 			pathConfigConnection(&b), | 			pathConfigConnection(&b), | ||||||
| 			pathConfigLease(&b), | 			pathConfigLease(&b), | ||||||
|   | |||||||
| @@ -24,6 +24,12 @@ func Backend() *backend { | |||||||
| 	b.Backend = &framework.Backend{ | 	b.Backend = &framework.Backend{ | ||||||
| 		Help: strings.TrimSpace(backendHelp), | 		Help: strings.TrimSpace(backendHelp), | ||||||
|  |  | ||||||
|  | 		PathsSpecial: &logical.Paths{ | ||||||
|  | 			SealWrapStorage: []string{ | ||||||
|  | 				"config/connection", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  |  | ||||||
| 		Paths: []*framework.Path{ | 		Paths: []*framework.Path{ | ||||||
| 			pathConfigConnection(&b), | 			pathConfigConnection(&b), | ||||||
| 			pathConfigLease(&b), | 			pathConfigLease(&b), | ||||||
|   | |||||||
| @@ -24,6 +24,12 @@ func Backend() *backend { | |||||||
| 	b.Backend = &framework.Backend{ | 	b.Backend = &framework.Backend{ | ||||||
| 		Help: strings.TrimSpace(backendHelp), | 		Help: strings.TrimSpace(backendHelp), | ||||||
|  |  | ||||||
|  | 		PathsSpecial: &logical.Paths{ | ||||||
|  | 			SealWrapStorage: []string{ | ||||||
|  | 				"config/connection", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  |  | ||||||
| 		Paths: []*framework.Path{ | 		Paths: []*framework.Path{ | ||||||
| 			pathConfigConnection(&b), | 			pathConfigConnection(&b), | ||||||
| 			pathConfigLease(&b), | 			pathConfigLease(&b), | ||||||
|   | |||||||
							
								
								
									
										68
									
								
								builtin/logical/nomad/backend.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								builtin/logical/nomad/backend.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,68 @@ | |||||||
|  | package nomad | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"github.com/hashicorp/nomad/api" | ||||||
|  | 	"github.com/hashicorp/vault/logical" | ||||||
|  | 	"github.com/hashicorp/vault/logical/framework" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func Factory(conf *logical.BackendConfig) (logical.Backend, error) { | ||||||
|  | 	b := Backend() | ||||||
|  | 	if err := b.Setup(conf); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return b, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Backend() *backend { | ||||||
|  | 	var b backend | ||||||
|  | 	b.Backend = &framework.Backend{ | ||||||
|  | 		PathsSpecial: &logical.Paths{ | ||||||
|  | 			SealWrapStorage: []string{ | ||||||
|  | 				"config/access", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  |  | ||||||
|  | 		Paths: []*framework.Path{ | ||||||
|  | 			pathConfigAccess(&b), | ||||||
|  | 			pathConfigLease(&b), | ||||||
|  | 			pathListRoles(&b), | ||||||
|  | 			pathRoles(&b), | ||||||
|  | 			pathCredsCreate(&b), | ||||||
|  | 		}, | ||||||
|  |  | ||||||
|  | 		Secrets: []*framework.Secret{ | ||||||
|  | 			secretToken(&b), | ||||||
|  | 		}, | ||||||
|  | 		BackendType: logical.TypeLogical, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &b | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type backend struct { | ||||||
|  | 	*framework.Backend | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *backend) client(s logical.Storage) (*api.Client, error) { | ||||||
|  | 	conf, err := b.readConfigAccess(s) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	nomadConf := api.DefaultConfig() | ||||||
|  | 	if conf != nil { | ||||||
|  | 		if conf.Address != "" { | ||||||
|  | 			nomadConf.Address = conf.Address | ||||||
|  | 		} | ||||||
|  | 		if conf.Token != "" { | ||||||
|  | 			nomadConf.SecretID = conf.Token | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	client, err := api.NewClient(nomadConf) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return client, nil | ||||||
|  | } | ||||||
							
								
								
									
										302
									
								
								builtin/logical/nomad/backend_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										302
									
								
								builtin/logical/nomad/backend_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,302 @@ | |||||||
|  | package nomad | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"os" | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	nomadapi "github.com/hashicorp/nomad/api" | ||||||
|  | 	"github.com/hashicorp/vault/logical" | ||||||
|  | 	"github.com/mitchellh/mapstructure" | ||||||
|  | 	dockertest "gopkg.in/ory-am/dockertest.v3" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func prepareTestContainer(t *testing.T) (cleanup func(), retAddress string, nomadToken string) { | ||||||
|  | 	nomadToken = os.Getenv("NOMAD_TOKEN") | ||||||
|  |  | ||||||
|  | 	retAddress = os.Getenv("NOMAD_ADDR") | ||||||
|  |  | ||||||
|  | 	if retAddress != "" { | ||||||
|  | 		return func() {}, retAddress, nomadToken | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	pool, err := dockertest.NewPool("") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Failed to connect to docker: %s", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	dockerOptions := &dockertest.RunOptions{ | ||||||
|  | 		Repository: "djenriquez/nomad", | ||||||
|  | 		Tag:        "latest", | ||||||
|  | 		Cmd:        []string{"agent", "-dev"}, | ||||||
|  | 		Env:        []string{`NOMAD_LOCAL_CONFIG=bind_addr = "0.0.0.0" acl { enabled = true }`}, | ||||||
|  | 	} | ||||||
|  | 	resource, err := pool.RunWithOptions(dockerOptions) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Could not start local Nomad docker container: %s", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	cleanup = func() { | ||||||
|  | 		err := pool.Purge(resource) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to cleanup local container: %s", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	retAddress = fmt.Sprintf("http://localhost:%s/", resource.GetPort("4646/tcp")) | ||||||
|  | 	// Give Nomad time to initialize | ||||||
|  |  | ||||||
|  | 	time.Sleep(5000 * time.Millisecond) | ||||||
|  | 	// exponential backoff-retry | ||||||
|  | 	if err = pool.Retry(func() error { | ||||||
|  | 		var err error | ||||||
|  | 		nomadapiConfig := nomadapi.DefaultConfig() | ||||||
|  | 		nomadapiConfig.Address = retAddress | ||||||
|  | 		nomad, err := nomadapi.NewClient(nomadapiConfig) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		aclbootstrap, _, err := nomad.ACLTokens().Bootstrap(nil) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("err: %v", err) | ||||||
|  | 		} | ||||||
|  | 		nomadToken = aclbootstrap.SecretID | ||||||
|  | 		t.Logf("[WARN] Generated Master token: %s", nomadToken) | ||||||
|  | 		policy := &nomadapi.ACLPolicy{ | ||||||
|  | 			Name:        "test", | ||||||
|  | 			Description: "test", | ||||||
|  | 			Rules: `namespace "default" { | ||||||
|  |         policy = "read" | ||||||
|  |       } | ||||||
|  |       `, | ||||||
|  | 		} | ||||||
|  | 		anonPolicy := &nomadapi.ACLPolicy{ | ||||||
|  | 			Name:        "anonymous", | ||||||
|  | 			Description: "Deny all access for anonymous requests", | ||||||
|  | 			Rules: `namespace "default" { | ||||||
|  |             policy = "deny" | ||||||
|  |         } | ||||||
|  |         agent { | ||||||
|  |             policy = "deny" | ||||||
|  |         } | ||||||
|  |         node { | ||||||
|  |             policy = "deny" | ||||||
|  |         } | ||||||
|  |         `, | ||||||
|  | 		} | ||||||
|  | 		nomadAuthConfig := nomadapi.DefaultConfig() | ||||||
|  | 		nomadAuthConfig.Address = retAddress | ||||||
|  | 		nomadAuthConfig.SecretID = nomadToken | ||||||
|  | 		nomadAuth, err := nomadapi.NewClient(nomadAuthConfig) | ||||||
|  | 		_, err = nomadAuth.ACLPolicies().Upsert(policy, nil) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatal(err) | ||||||
|  | 		} | ||||||
|  | 		_, err = nomadAuth.ACLPolicies().Upsert(anonPolicy, nil) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatal(err) | ||||||
|  | 		} | ||||||
|  | 		return err | ||||||
|  | 	}); err != nil { | ||||||
|  | 		cleanup() | ||||||
|  | 		t.Fatalf("Could not connect to docker: %s", err) | ||||||
|  | 	} | ||||||
|  | 	return cleanup, retAddress, nomadToken | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestBackend_config_access(t *testing.T) { | ||||||
|  | 	config := logical.TestBackendConfig() | ||||||
|  | 	config.StorageView = &logical.InmemStorage{} | ||||||
|  | 	b, err := Factory(config) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	cleanup, connURL, connToken := prepareTestContainer(t) | ||||||
|  | 	defer cleanup() | ||||||
|  |  | ||||||
|  | 	connData := map[string]interface{}{ | ||||||
|  | 		"address": connURL, | ||||||
|  | 		"token":   connToken, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	confReq := &logical.Request{ | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Path:      "config/access", | ||||||
|  | 		Storage:   config.StorageView, | ||||||
|  | 		Data:      connData, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	resp, err := b.HandleRequest(confReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) || resp != nil { | ||||||
|  | 		t.Fatalf("failed to write configuration: resp:%#v err:%s", resp, err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	confReq.Operation = logical.ReadOperation | ||||||
|  | 	resp, err = b.HandleRequest(confReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("failed to write configuration: resp:%#v err:%s", resp, err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	expected := map[string]interface{}{ | ||||||
|  | 		"address": connData["address"].(string), | ||||||
|  | 	} | ||||||
|  | 	if !reflect.DeepEqual(expected, resp.Data) { | ||||||
|  | 		t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data) | ||||||
|  | 	} | ||||||
|  | 	if resp.Data["token"] != nil { | ||||||
|  | 		t.Fatalf("token should not be set in the response") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestBackend_renew_revoke(t *testing.T) { | ||||||
|  | 	config := logical.TestBackendConfig() | ||||||
|  | 	config.StorageView = &logical.InmemStorage{} | ||||||
|  | 	b, err := Factory(config) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	cleanup, connURL, connToken := prepareTestContainer(t) | ||||||
|  | 	defer cleanup() | ||||||
|  | 	connData := map[string]interface{}{ | ||||||
|  | 		"address": connURL, | ||||||
|  | 		"token":   connToken, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	req := &logical.Request{ | ||||||
|  | 		Storage:   config.StorageView, | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Path:      "config/access", | ||||||
|  | 		Data:      connData, | ||||||
|  | 	} | ||||||
|  | 	resp, err := b.HandleRequest(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	req.Path = "role/test" | ||||||
|  | 	req.Data = map[string]interface{}{ | ||||||
|  | 		"policies": []string{"policy"}, | ||||||
|  | 		"lease":    "6h", | ||||||
|  | 	} | ||||||
|  | 	resp, err = b.HandleRequest(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	req.Operation = logical.ReadOperation | ||||||
|  | 	req.Path = "creds/test" | ||||||
|  | 	resp, err = b.HandleRequest(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if resp == nil { | ||||||
|  | 		t.Fatal("resp nil") | ||||||
|  | 	} | ||||||
|  | 	if resp.IsError() { | ||||||
|  | 		t.Fatalf("resp is error: %v", resp.Error()) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	generatedSecret := resp.Secret | ||||||
|  | 	generatedSecret.IssueTime = time.Now() | ||||||
|  | 	generatedSecret.TTL = 6 * time.Hour | ||||||
|  |  | ||||||
|  | 	var d struct { | ||||||
|  | 		Token    string `mapstructure:"secret_id"` | ||||||
|  | 		Accessor string `mapstructure:"accessor_id"` | ||||||
|  | 	} | ||||||
|  | 	if err := mapstructure.Decode(resp.Data, &d); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	t.Logf("[WARN] Generated token: %s with accesor %s", d.Token, d.Accessor) | ||||||
|  |  | ||||||
|  | 	// Build a client and verify that the credentials work | ||||||
|  | 	nomadapiConfig := nomadapi.DefaultConfig() | ||||||
|  | 	nomadapiConfig.Address = connData["address"].(string) | ||||||
|  | 	nomadapiConfig.SecretID = d.Token | ||||||
|  | 	client, err := nomadapi.NewClient(nomadapiConfig) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	t.Log("[WARN] Verifying that the generated token works...") | ||||||
|  | 	_, err = client.Agent().Members, nil | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	req.Operation = logical.RenewOperation | ||||||
|  | 	req.Secret = generatedSecret | ||||||
|  | 	resp, err = b.HandleRequest(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if resp == nil { | ||||||
|  | 		t.Fatal("got nil response from renew") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	req.Operation = logical.RevokeOperation | ||||||
|  | 	resp, err = b.HandleRequest(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Build a management client and verify that the token does not exist anymore | ||||||
|  | 	nomadmgmtConfig := nomadapi.DefaultConfig() | ||||||
|  | 	nomadmgmtConfig.Address = connData["address"].(string) | ||||||
|  | 	nomadmgmtConfig.SecretID = connData["token"].(string) | ||||||
|  | 	mgmtclient, err := nomadapi.NewClient(nomadmgmtConfig) | ||||||
|  |  | ||||||
|  | 	q := &nomadapi.QueryOptions{ | ||||||
|  | 		Namespace: "default", | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	t.Log("[WARN] Verifying that the generated token does not exist...") | ||||||
|  | 	_, _, err = mgmtclient.ACLTokens().Info(d.Accessor, q) | ||||||
|  | 	if err == nil { | ||||||
|  | 		t.Fatal("err: expected error") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestBackend_CredsCreateEnvVar(t *testing.T) { | ||||||
|  | 	config := logical.TestBackendConfig() | ||||||
|  | 	config.StorageView = &logical.InmemStorage{} | ||||||
|  | 	b, err := Factory(config) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	cleanup, connURL, connToken := prepareTestContainer(t) | ||||||
|  | 	defer cleanup() | ||||||
|  |  | ||||||
|  | 	req := logical.TestRequest(t, logical.UpdateOperation, "role/test") | ||||||
|  | 	req.Data = map[string]interface{}{ | ||||||
|  | 		"policies": []string{"policy"}, | ||||||
|  | 		"lease":    "6h", | ||||||
|  | 	} | ||||||
|  | 	resp, err := b.HandleRequest(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	os.Setenv("NOMAD_TOKEN", connToken) | ||||||
|  | 	defer os.Unsetenv("NOMAD_TOKEN") | ||||||
|  | 	os.Setenv("NOMAD_ADDR", connURL) | ||||||
|  | 	defer os.Unsetenv("NOMAD_ADDR") | ||||||
|  |  | ||||||
|  | 	req.Operation = logical.ReadOperation | ||||||
|  | 	req.Path = "creds/test" | ||||||
|  | 	resp, err = b.HandleRequest(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if resp == nil { | ||||||
|  | 		t.Fatal("resp nil") | ||||||
|  | 	} | ||||||
|  | 	if resp.IsError() { | ||||||
|  | 		t.Fatalf("resp is error: %v", resp.Error()) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										121
									
								
								builtin/logical/nomad/path_config_access.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										121
									
								
								builtin/logical/nomad/path_config_access.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,121 @@ | |||||||
|  | package nomad | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"github.com/hashicorp/errwrap" | ||||||
|  | 	"github.com/hashicorp/vault/logical" | ||||||
|  | 	"github.com/hashicorp/vault/logical/framework" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const configAccessKey = "config/access" | ||||||
|  |  | ||||||
|  | func pathConfigAccess(b *backend) *framework.Path { | ||||||
|  | 	return &framework.Path{ | ||||||
|  | 		Pattern: "config/access", | ||||||
|  | 		Fields: map[string]*framework.FieldSchema{ | ||||||
|  | 			"address": &framework.FieldSchema{ | ||||||
|  | 				Type:        framework.TypeString, | ||||||
|  | 				Description: "Nomad server address", | ||||||
|  | 			}, | ||||||
|  |  | ||||||
|  | 			"token": &framework.FieldSchema{ | ||||||
|  | 				Type:        framework.TypeString, | ||||||
|  | 				Description: "Token for API calls", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  |  | ||||||
|  | 		Callbacks: map[logical.Operation]framework.OperationFunc{ | ||||||
|  | 			logical.ReadOperation:   b.pathConfigAccessRead, | ||||||
|  | 			logical.CreateOperation: b.pathConfigAccessWrite, | ||||||
|  | 			logical.UpdateOperation: b.pathConfigAccessWrite, | ||||||
|  | 			logical.DeleteOperation: b.pathConfigAccessDelete, | ||||||
|  | 		}, | ||||||
|  |  | ||||||
|  | 		ExistenceCheck: b.configExistenceCheck, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *backend) configExistenceCheck(req *logical.Request, data *framework.FieldData) (bool, error) { | ||||||
|  | 	entry, err := b.readConfigAccess(req.Storage) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return entry != nil, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *backend) readConfigAccess(storage logical.Storage) (*accessConfig, error) { | ||||||
|  | 	entry, err := storage.Get(configAccessKey) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if entry == nil { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	conf := &accessConfig{} | ||||||
|  | 	if err := entry.DecodeJSON(conf); err != nil { | ||||||
|  | 		return nil, errwrap.Wrapf("error reading nomad access configuration: {{err}}", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return conf, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *backend) pathConfigAccessRead( | ||||||
|  | 	req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 	conf, err := b.readConfigAccess(req.Storage) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if conf == nil { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &logical.Response{ | ||||||
|  | 		Data: map[string]interface{}{ | ||||||
|  | 			"address": conf.Address, | ||||||
|  | 		}, | ||||||
|  | 	}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *backend) pathConfigAccessWrite( | ||||||
|  | 	req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 	conf, err := b.readConfigAccess(req.Storage) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if conf == nil { | ||||||
|  | 		conf = &accessConfig{} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	address, ok := data.GetOk("address") | ||||||
|  | 	if ok { | ||||||
|  | 		conf.Address = address.(string) | ||||||
|  | 	} | ||||||
|  | 	token, ok := data.GetOk("token") | ||||||
|  | 	if ok { | ||||||
|  | 		conf.Token = token.(string) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	entry, err := logical.StorageEntryJSON("config/access", conf) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if err := req.Storage.Put(entry); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *backend) pathConfigAccessDelete( | ||||||
|  | 	req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 	if err := req.Storage.Delete(configAccessKey); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type accessConfig struct { | ||||||
|  | 	Address string `json:"address"` | ||||||
|  | 	Token   string `json:"token"` | ||||||
|  | } | ||||||
							
								
								
									
										109
									
								
								builtin/logical/nomad/path_config_lease.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										109
									
								
								builtin/logical/nomad/path_config_lease.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,109 @@ | |||||||
|  | package nomad | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/hashicorp/vault/logical" | ||||||
|  | 	"github.com/hashicorp/vault/logical/framework" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const leaseConfigKey = "config/lease" | ||||||
|  |  | ||||||
|  | func pathConfigLease(b *backend) *framework.Path { | ||||||
|  | 	return &framework.Path{ | ||||||
|  | 		Pattern: "config/lease", | ||||||
|  | 		Fields: map[string]*framework.FieldSchema{ | ||||||
|  | 			"ttl": &framework.FieldSchema{ | ||||||
|  | 				Type:        framework.TypeDurationSecond, | ||||||
|  | 				Description: "Duration before which the issued token needs renewal", | ||||||
|  | 			}, | ||||||
|  | 			"max_ttl": &framework.FieldSchema{ | ||||||
|  | 				Type:        framework.TypeDurationSecond, | ||||||
|  | 				Description: `Duration after which the issued token should not be allowed to be renewed`, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  |  | ||||||
|  | 		Callbacks: map[logical.Operation]framework.OperationFunc{ | ||||||
|  | 			logical.ReadOperation:   b.pathLeaseRead, | ||||||
|  | 			logical.UpdateOperation: b.pathLeaseUpdate, | ||||||
|  | 			logical.DeleteOperation: b.pathLeaseDelete, | ||||||
|  | 		}, | ||||||
|  |  | ||||||
|  | 		HelpSynopsis:    pathConfigLeaseHelpSyn, | ||||||
|  | 		HelpDescription: pathConfigLeaseHelpDesc, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Sets the lease configuration parameters | ||||||
|  | func (b *backend) pathLeaseUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 	entry, err := logical.StorageEntryJSON("config/lease", &configLease{ | ||||||
|  | 		TTL:    time.Second * time.Duration(d.Get("ttl").(int)), | ||||||
|  | 		MaxTTL: time.Second * time.Duration(d.Get("max_ttl").(int)), | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if err := req.Storage.Put(entry); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *backend) pathLeaseDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 	if err := req.Storage.Delete(leaseConfigKey); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Returns the lease configuration parameters | ||||||
|  | func (b *backend) pathLeaseRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 	lease, err := b.LeaseConfig(req.Storage) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if lease == nil { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &logical.Response{ | ||||||
|  | 		Data: map[string]interface{}{ | ||||||
|  | 			"ttl":     int64(lease.TTL.Seconds()), | ||||||
|  | 			"max_ttl": int64(lease.MaxTTL.Seconds()), | ||||||
|  | 		}, | ||||||
|  | 	}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Lease returns the lease information | ||||||
|  | func (b *backend) LeaseConfig(s logical.Storage) (*configLease, error) { | ||||||
|  | 	entry, err := s.Get(leaseConfigKey) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if entry == nil { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var result configLease | ||||||
|  | 	if err := entry.DecodeJSON(&result); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &result, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Lease configuration information for the secrets issued by this backend | ||||||
|  | type configLease struct { | ||||||
|  | 	TTL    time.Duration `json:"ttl" mapstructure:"ttl"` | ||||||
|  | 	MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var pathConfigLeaseHelpSyn = "Configure the lease parameters for generated tokens" | ||||||
|  |  | ||||||
|  | var pathConfigLeaseHelpDesc = ` | ||||||
|  | Sets the ttl and max_ttl values for the secrets to be issued by this backend. | ||||||
|  | Both ttl and max_ttl takes in an integer number of seconds as input as well as | ||||||
|  | inputs like "1h". | ||||||
|  | ` | ||||||
							
								
								
									
										80
									
								
								builtin/logical/nomad/path_creds_create.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										80
									
								
								builtin/logical/nomad/path_creds_create.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,80 @@ | |||||||
|  | package nomad | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/hashicorp/errwrap" | ||||||
|  | 	"github.com/hashicorp/nomad/api" | ||||||
|  | 	"github.com/hashicorp/vault/logical" | ||||||
|  | 	"github.com/hashicorp/vault/logical/framework" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func pathCredsCreate(b *backend) *framework.Path { | ||||||
|  | 	return &framework.Path{ | ||||||
|  | 		Pattern: "creds/" + framework.GenericNameRegex("name"), | ||||||
|  | 		Fields: map[string]*framework.FieldSchema{ | ||||||
|  | 			"name": &framework.FieldSchema{ | ||||||
|  | 				Type:        framework.TypeString, | ||||||
|  | 				Description: "Name of the role", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  |  | ||||||
|  | 		Callbacks: map[logical.Operation]framework.OperationFunc{ | ||||||
|  | 			logical.ReadOperation: b.pathTokenRead, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *backend) pathTokenRead( | ||||||
|  | 	req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 	name := d.Get("name").(string) | ||||||
|  |  | ||||||
|  | 	role, err := b.Role(req.Storage, name) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, errwrap.Wrapf("error retrieving role: {{err}}", err) | ||||||
|  | 	} | ||||||
|  | 	if role == nil { | ||||||
|  | 		return logical.ErrorResponse(fmt.Sprintf("role %q not found", name)), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Determine if we have a lease configuration | ||||||
|  | 	leaseConfig, err := b.LeaseConfig(req.Storage) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if leaseConfig == nil { | ||||||
|  | 		leaseConfig = &configLease{} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Get the nomad client | ||||||
|  | 	c, err := b.client(req.Storage) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Generate a name for the token | ||||||
|  | 	tokenName := fmt.Sprintf("vault-%s-%s-%d", name, req.DisplayName, time.Now().UnixNano()) | ||||||
|  |  | ||||||
|  | 	// Create it | ||||||
|  | 	token, _, err := c.ACLTokens().Create(&api.ACLToken{ | ||||||
|  | 		Name:     tokenName, | ||||||
|  | 		Type:     role.TokenType, | ||||||
|  | 		Policies: role.Policies, | ||||||
|  | 		Global:   role.Global, | ||||||
|  | 	}, nil) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Use the helper to create the secret | ||||||
|  | 	resp := b.Secret(SecretTokenType).Response(map[string]interface{}{ | ||||||
|  | 		"secret_id":   token.SecretID, | ||||||
|  | 		"accessor_id": token.AccessorID, | ||||||
|  | 	}, map[string]interface{}{ | ||||||
|  | 		"accessor_id": token.AccessorID, | ||||||
|  | 	}) | ||||||
|  | 	resp.Secret.TTL = leaseConfig.TTL | ||||||
|  |  | ||||||
|  | 	return resp, nil | ||||||
|  | } | ||||||
							
								
								
									
										189
									
								
								builtin/logical/nomad/path_roles.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										189
									
								
								builtin/logical/nomad/path_roles.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,189 @@ | |||||||
|  | package nomad | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  |  | ||||||
|  | 	"github.com/hashicorp/errwrap" | ||||||
|  | 	"github.com/hashicorp/vault/logical" | ||||||
|  | 	"github.com/hashicorp/vault/logical/framework" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func pathListRoles(b *backend) *framework.Path { | ||||||
|  | 	return &framework.Path{ | ||||||
|  | 		Pattern: "role/?$", | ||||||
|  |  | ||||||
|  | 		Callbacks: map[logical.Operation]framework.OperationFunc{ | ||||||
|  | 			logical.ListOperation: b.pathRoleList, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func pathRoles(b *backend) *framework.Path { | ||||||
|  | 	return &framework.Path{ | ||||||
|  | 		Pattern: "role/" + framework.GenericNameRegex("name"), | ||||||
|  | 		Fields: map[string]*framework.FieldSchema{ | ||||||
|  | 			"name": &framework.FieldSchema{ | ||||||
|  | 				Type:        framework.TypeString, | ||||||
|  | 				Description: "Name of the role", | ||||||
|  | 			}, | ||||||
|  |  | ||||||
|  | 			"policies": &framework.FieldSchema{ | ||||||
|  | 				Type:        framework.TypeCommaStringSlice, | ||||||
|  | 				Description: "Comma-separated string or list of policies as previously created in Nomad. Required for 'client' token.", | ||||||
|  | 			}, | ||||||
|  |  | ||||||
|  | 			"global": &framework.FieldSchema{ | ||||||
|  | 				Type:        framework.TypeBool, | ||||||
|  | 				Description: "Boolean value describing if the token should be global or not. Defaults to false.", | ||||||
|  | 			}, | ||||||
|  |  | ||||||
|  | 			"type": &framework.FieldSchema{ | ||||||
|  | 				Type:    framework.TypeString, | ||||||
|  | 				Default: "client", | ||||||
|  | 				Description: `Which type of token to create: 'client' | ||||||
|  | or 'management'. If a 'management' token, | ||||||
|  | the "policies" parameter is not required. | ||||||
|  | Defaults to 'client'.`, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  |  | ||||||
|  | 		Callbacks: map[logical.Operation]framework.OperationFunc{ | ||||||
|  | 			logical.ReadOperation:   b.pathRolesRead, | ||||||
|  | 			logical.CreateOperation: b.pathRolesWrite, | ||||||
|  | 			logical.UpdateOperation: b.pathRolesWrite, | ||||||
|  | 			logical.DeleteOperation: b.pathRolesDelete, | ||||||
|  | 		}, | ||||||
|  |  | ||||||
|  | 		ExistenceCheck: b.rolesExistenceCheck, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Establishes dichotomy of request operation between CreateOperation and UpdateOperation. | ||||||
|  | // Returning 'true' forces an UpdateOperation, CreateOperation otherwise. | ||||||
|  | func (b *backend) rolesExistenceCheck(req *logical.Request, d *framework.FieldData) (bool, error) { | ||||||
|  | 	name := d.Get("name").(string) | ||||||
|  | 	entry, err := b.Role(req.Storage, name) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  | 	return entry != nil, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *backend) Role(storage logical.Storage, name string) (*roleConfig, error) { | ||||||
|  | 	if name == "" { | ||||||
|  | 		return nil, errors.New("invalid role name") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	entry, err := storage.Get("role/" + name) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, errwrap.Wrapf("error retrieving role: {{err}}", err) | ||||||
|  | 	} | ||||||
|  | 	if entry == nil { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var result roleConfig | ||||||
|  | 	if err := entry.DecodeJSON(&result); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return &result, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *backend) pathRoleList( | ||||||
|  | 	req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 	entries, err := req.Storage.List("role/") | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return logical.ListResponse(entries), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *backend) pathRolesRead( | ||||||
|  | 	req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 	name := d.Get("name").(string) | ||||||
|  |  | ||||||
|  | 	role, err := b.Role(req.Storage, name) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if role == nil { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Generate the response | ||||||
|  | 	resp := &logical.Response{ | ||||||
|  | 		Data: map[string]interface{}{ | ||||||
|  | 			"type":     role.TokenType, | ||||||
|  | 			"global":   role.Global, | ||||||
|  | 			"policies": role.Policies, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	return resp, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *backend) pathRolesWrite( | ||||||
|  | 	req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 	name := d.Get("name").(string) | ||||||
|  |  | ||||||
|  | 	role, err := b.Role(req.Storage, name) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if role == nil { | ||||||
|  | 		role = new(roleConfig) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	policies, ok := d.GetOk("policies") | ||||||
|  | 	if ok { | ||||||
|  | 		role.Policies = policies.([]string) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	role.TokenType = d.Get("type").(string) | ||||||
|  | 	switch role.TokenType { | ||||||
|  | 	case "client": | ||||||
|  | 		if len(role.Policies) == 0 { | ||||||
|  | 			return logical.ErrorResponse( | ||||||
|  | 				"policies cannot be empty when using client tokens"), nil | ||||||
|  | 		} | ||||||
|  | 	case "management": | ||||||
|  | 		if len(role.Policies) != 0 { | ||||||
|  | 			return logical.ErrorResponse( | ||||||
|  | 				"policies should be empty when using management tokens"), nil | ||||||
|  | 		} | ||||||
|  | 	default: | ||||||
|  | 		return logical.ErrorResponse( | ||||||
|  | 			`type must be "client" or "management"`), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	global, ok := d.GetOk("global") | ||||||
|  | 	if ok { | ||||||
|  | 		role.Global = global.(bool) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	entry, err := logical.StorageEntryJSON("role/"+name, role) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := req.Storage.Put(entry); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *backend) pathRolesDelete( | ||||||
|  | 	req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 	name := d.Get("name").(string) | ||||||
|  | 	if err := req.Storage.Delete("role/" + name); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type roleConfig struct { | ||||||
|  | 	Policies  []string `json:"policies"` | ||||||
|  | 	TokenType string   `json:"type"` | ||||||
|  | 	Global    bool     `json:"global"` | ||||||
|  | } | ||||||
							
								
								
									
										68
									
								
								builtin/logical/nomad/secret_token.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								builtin/logical/nomad/secret_token.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,68 @@ | |||||||
|  | package nomad | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  |  | ||||||
|  | 	"github.com/hashicorp/vault/logical" | ||||||
|  | 	"github.com/hashicorp/vault/logical/framework" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	SecretTokenType = "token" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func secretToken(b *backend) *framework.Secret { | ||||||
|  | 	return &framework.Secret{ | ||||||
|  | 		Type: SecretTokenType, | ||||||
|  | 		Fields: map[string]*framework.FieldSchema{ | ||||||
|  | 			"token": &framework.FieldSchema{ | ||||||
|  | 				Type:        framework.TypeString, | ||||||
|  | 				Description: "Request token", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  |  | ||||||
|  | 		Renew:  b.secretTokenRenew, | ||||||
|  | 		Revoke: b.secretTokenRevoke, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *backend) secretTokenRenew( | ||||||
|  | 	req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 	lease, err := b.LeaseConfig(req.Storage) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if lease == nil { | ||||||
|  | 		lease = &configLease{} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return framework.LeaseExtend(lease.TTL, lease.MaxTTL, b.System())(req, d) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *backend) secretTokenRevoke( | ||||||
|  | 	req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 	c, err := b.client(req.Storage) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if c == nil { | ||||||
|  | 		return nil, fmt.Errorf("error getting Nomad client") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	accessorIDRaw, ok := req.Secret.InternalData["accessor_id"] | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil, fmt.Errorf("accessor_id is missing on the lease") | ||||||
|  | 	} | ||||||
|  | 	accessorID, ok := accessorIDRaw.(string) | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil, errors.New("unable to convert accessor_id") | ||||||
|  | 	} | ||||||
|  | 	_, err = c.ACLTokens().Delete(accessorID, nil) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
| @@ -1463,7 +1463,7 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep { | |||||||
| 		//t.Logf("test step %d\nrole vals: %#v\n", stepCount, roleVals) | 		//t.Logf("test step %d\nrole vals: %#v\n", stepCount, roleVals) | ||||||
| 		stepCount++ | 		stepCount++ | ||||||
| 		//t.Logf("test step %d\nissue vals: %#v\n", stepCount, issueTestStep) | 		//t.Logf("test step %d\nissue vals: %#v\n", stepCount, issueTestStep) | ||||||
| 		roleTestStep.Data = structs.New(roleVals).Map() | 		roleTestStep.Data = roleVals.ToResponseData() | ||||||
| 		roleTestStep.Data["generate_lease"] = false | 		roleTestStep.Data["generate_lease"] = false | ||||||
| 		ret = append(ret, roleTestStep) | 		ret = append(ret, roleTestStep) | ||||||
| 		issueTestStep.Data = structs.New(issueVals).Map() | 		issueTestStep.Data = structs.New(issueVals).Map() | ||||||
| @@ -1594,38 +1594,38 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep { | |||||||
| 			roleVals.CodeSigningFlag = false | 			roleVals.CodeSigningFlag = false | ||||||
| 			roleVals.EmailProtectionFlag = false | 			roleVals.EmailProtectionFlag = false | ||||||
|  |  | ||||||
| 			var usage string | 			var usage []string | ||||||
| 			if mathRand.Int()%2 == 1 { | 			if mathRand.Int()%2 == 1 { | ||||||
| 				usage = usage + ",DigitalSignature" | 				usage = append(usage, "DigitalSignature") | ||||||
| 			} | 			} | ||||||
| 			if mathRand.Int()%2 == 1 { | 			if mathRand.Int()%2 == 1 { | ||||||
| 				usage = usage + ",ContentCoMmitment" | 				usage = append(usage, "ContentCoMmitment") | ||||||
| 			} | 			} | ||||||
| 			if mathRand.Int()%2 == 1 { | 			if mathRand.Int()%2 == 1 { | ||||||
| 				usage = usage + ",KeyEncipherment" | 				usage = append(usage, "KeyEncipherment") | ||||||
| 			} | 			} | ||||||
| 			if mathRand.Int()%2 == 1 { | 			if mathRand.Int()%2 == 1 { | ||||||
| 				usage = usage + ",DataEncipherment" | 				usage = append(usage, "DataEncipherment") | ||||||
| 			} | 			} | ||||||
| 			if mathRand.Int()%2 == 1 { | 			if mathRand.Int()%2 == 1 { | ||||||
| 				usage = usage + ",KeyAgreemEnt" | 				usage = append(usage, "KeyAgreemEnt") | ||||||
| 			} | 			} | ||||||
| 			if mathRand.Int()%2 == 1 { | 			if mathRand.Int()%2 == 1 { | ||||||
| 				usage = usage + ",CertSign" | 				usage = append(usage, "CertSign") | ||||||
| 			} | 			} | ||||||
| 			if mathRand.Int()%2 == 1 { | 			if mathRand.Int()%2 == 1 { | ||||||
| 				usage = usage + ",CRLSign" | 				usage = append(usage, "CRLSign") | ||||||
| 			} | 			} | ||||||
| 			if mathRand.Int()%2 == 1 { | 			if mathRand.Int()%2 == 1 { | ||||||
| 				usage = usage + ",EncipherOnly" | 				usage = append(usage, "EncipherOnly") | ||||||
| 			} | 			} | ||||||
| 			if mathRand.Int()%2 == 1 { | 			if mathRand.Int()%2 == 1 { | ||||||
| 				usage = usage + ",DecipherOnly" | 				usage = append(usage, "DecipherOnly") | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			roleVals.KeyUsage = usage | 			roleVals.KeyUsage = usage | ||||||
| 			parsedKeyUsage := parseKeyUsages(roleVals.KeyUsage) | 			parsedKeyUsage := parseKeyUsages(roleVals.KeyUsage) | ||||||
| 			if parsedKeyUsage == 0 && usage != "" { | 			if parsedKeyUsage == 0 && len(usage) != 0 { | ||||||
| 				panic("parsed key usages was zero") | 				panic("parsed key usages was zero") | ||||||
| 			} | 			} | ||||||
| 			parsedKeyUsageUnderTest = parsedKeyUsage | 			parsedKeyUsageUnderTest = parsedKeyUsage | ||||||
| @@ -1759,10 +1759,10 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep { | |||||||
| 		commonNames.Localhost = true | 		commonNames.Localhost = true | ||||||
| 		addCnTests() | 		addCnTests() | ||||||
|  |  | ||||||
| 		roleVals.AllowedDomains = "foobar.com" | 		roleVals.AllowedDomains = []string{"foobar.com"} | ||||||
| 		addCnTests() | 		addCnTests() | ||||||
|  |  | ||||||
| 		roleVals.AllowedDomains = "example.com" | 		roleVals.AllowedDomains = []string{"example.com"} | ||||||
| 		roleVals.AllowSubdomains = true | 		roleVals.AllowSubdomains = true | ||||||
| 		commonNames.SubDomain = true | 		commonNames.SubDomain = true | ||||||
| 		commonNames.Wildcard = true | 		commonNames.Wildcard = true | ||||||
| @@ -1770,13 +1770,13 @@ func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep { | |||||||
| 		commonNames.SubSubdomainWildcard = true | 		commonNames.SubSubdomainWildcard = true | ||||||
| 		addCnTests() | 		addCnTests() | ||||||
|  |  | ||||||
| 		roleVals.AllowedDomains = "foobar.com,example.com" | 		roleVals.AllowedDomains = []string{"foobar.com", "example.com"} | ||||||
| 		commonNames.SecondDomain = true | 		commonNames.SecondDomain = true | ||||||
| 		roleVals.AllowBareDomains = true | 		roleVals.AllowBareDomains = true | ||||||
| 		commonNames.BareDomain = true | 		commonNames.BareDomain = true | ||||||
| 		addCnTests() | 		addCnTests() | ||||||
|  |  | ||||||
| 		roleVals.AllowedDomains = "foobar.com,*example.com" | 		roleVals.AllowedDomains = []string{"foobar.com", "*example.com"} | ||||||
| 		roleVals.AllowGlobDomains = true | 		roleVals.AllowGlobDomains = true | ||||||
| 		commonNames.GlobDomain = true | 		commonNames.GlobDomain = true | ||||||
| 		addCnTests() | 		addCnTests() | ||||||
|   | |||||||
| @@ -17,14 +17,14 @@ func (b *backend) getGenerationParams( | |||||||
| 	case "internal": | 	case "internal": | ||||||
| 	default: | 	default: | ||||||
| 		errorResp = logical.ErrorResponse( | 		errorResp = logical.ErrorResponse( | ||||||
| 			`The "exported" path parameter must be "internal" or "exported"`) | 			`the "exported" path parameter must be "internal" or "exported"`) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	format = getFormat(data) | 	format = getFormat(data) | ||||||
| 	if format == "" { | 	if format == "" { | ||||||
| 		errorResp = logical.ErrorResponse( | 		errorResp = logical.ErrorResponse( | ||||||
| 			`The "format" path parameter must be "pem", "der", or "pem_bundle"`) | 			`the "format" path parameter must be "pem", "der", "der_pkcs", or "pem_bundle"`) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ package pki | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"crypto" | ||||||
| 	"crypto/ecdsa" | 	"crypto/ecdsa" | ||||||
| 	"crypto/rand" | 	"crypto/rand" | ||||||
| 	"crypto/rsa" | 	"crypto/rsa" | ||||||
| @@ -9,6 +10,7 @@ import ( | |||||||
| 	"crypto/x509" | 	"crypto/x509" | ||||||
| 	"crypto/x509/pkix" | 	"crypto/x509/pkix" | ||||||
| 	"encoding/asn1" | 	"encoding/asn1" | ||||||
|  | 	"encoding/base64" | ||||||
| 	"encoding/pem" | 	"encoding/pem" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net" | 	"net" | ||||||
| @@ -16,6 +18,7 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/hashicorp/errwrap" | ||||||
| 	"github.com/hashicorp/vault/helper/certutil" | 	"github.com/hashicorp/vault/helper/certutil" | ||||||
| 	"github.com/hashicorp/vault/helper/errutil" | 	"github.com/hashicorp/vault/helper/errutil" | ||||||
| 	"github.com/hashicorp/vault/helper/parseutil" | 	"github.com/hashicorp/vault/helper/parseutil" | ||||||
| @@ -372,9 +375,9 @@ func validateNames(req *logical.Request, names []string, role *roleEntry) string | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		if role.AllowedDomains != "" { | 		if len(role.AllowedDomains) > 0 { | ||||||
| 			valid := false | 			valid := false | ||||||
| 			for _, currDomain := range strings.Split(role.AllowedDomains, ",") { | 			for _, currDomain := range role.AllowedDomains { | ||||||
| 				// If there is, say, a trailing comma, ignore it | 				// If there is, say, a trailing comma, ignore it | ||||||
| 				if currDomain == "" { | 				if currDomain == "" { | ||||||
| 					continue | 					continue | ||||||
| @@ -1183,3 +1186,66 @@ NameCheck: | |||||||
|  |  | ||||||
| 	return fmt.Errorf("name %q disallowed by CA's permitted DNS domains", badName) | 	return fmt.Errorf("name %q disallowed by CA's permitted DNS domains", badName) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func convertRespToPKCS8(resp *logical.Response) error { | ||||||
|  | 	privRaw, ok := resp.Data["private_key"] | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	priv, ok := privRaw.(string) | ||||||
|  | 	if !ok { | ||||||
|  | 		return fmt.Errorf("error converting response to pkcs8: could not parse original value as string") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	privKeyTypeRaw, ok := resp.Data["private_key_type"] | ||||||
|  | 	if !ok { | ||||||
|  | 		return fmt.Errorf("error converting response to pkcs8: %q not found in response", "private_key_type") | ||||||
|  | 	} | ||||||
|  | 	privKeyType, ok := privKeyTypeRaw.(certutil.PrivateKeyType) | ||||||
|  | 	if !ok { | ||||||
|  | 		return fmt.Errorf("error converting response to pkcs8: could not parse original type value as string") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var keyData []byte | ||||||
|  | 	var pemUsed bool | ||||||
|  | 	var err error | ||||||
|  | 	var signer crypto.Signer | ||||||
|  |  | ||||||
|  | 	block, _ := pem.Decode([]byte(priv)) | ||||||
|  | 	if block == nil { | ||||||
|  | 		keyData, err = base64.StdEncoding.DecodeString(priv) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errwrap.Wrapf("error converting response to pkcs8: error decoding original value: {{err}}", err) | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		keyData = block.Bytes | ||||||
|  | 		pemUsed = true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	switch privKeyType { | ||||||
|  | 	case certutil.RSAPrivateKey: | ||||||
|  | 		signer, err = x509.ParsePKCS1PrivateKey(keyData) | ||||||
|  | 	case certutil.ECPrivateKey: | ||||||
|  | 		signer, err = x509.ParseECPrivateKey(keyData) | ||||||
|  | 	default: | ||||||
|  | 		return fmt.Errorf("unknown private key type %q", privKeyType) | ||||||
|  | 	} | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errwrap.Wrapf("error converting response to pkcs8: error parsing previous key: {{err}}", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	keyData, err = certutil.MarshalPKCS8PrivateKey(signer) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errwrap.Wrapf("error converting response to pkcs8: error marshaling pkcs8 key: {{err}}", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if pemUsed { | ||||||
|  | 		block.Type = "PRIVATE KEY" | ||||||
|  | 		block.Bytes = keyData | ||||||
|  | 		resp.Data["private_key"] = string(pem.EncodeToMemory(block)) | ||||||
|  | 	} else { | ||||||
|  | 		resp.Data["private_key"] = base64.StdEncoding.EncodeToString(keyData) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|   | |||||||
| @@ -22,6 +22,17 @@ key and issuing cert will be appended to the | |||||||
| certificate pem. Defaults to "pem".`, | certificate pem. Defaults to "pem".`, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	fields["private_key_format"] = &framework.FieldSchema{ | ||||||
|  | 		Type:    framework.TypeString, | ||||||
|  | 		Default: "der", | ||||||
|  | 		Description: `Format for the returned private key. | ||||||
|  | Generally the default will be controlled by the "format" | ||||||
|  | parameter as either base64-encoded DER or PEM-encoded DER. | ||||||
|  | However, this can be set to "pkcs8" to have the returned | ||||||
|  | private key contain base64-encoded pkcs8 or PEM-encoded | ||||||
|  | pkcs8 instead. Defaults to "der".`, | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	fields["ip_sans"] = &framework.FieldSchema{ | 	fields["ip_sans"] = &framework.FieldSchema{ | ||||||
| 		Type: framework.TypeString, | 		Type: framework.TypeString, | ||||||
| 		Description: `The requested IP SANs, if any, in a | 		Description: `The requested IP SANs, if any, in a | ||||||
|   | |||||||
| @@ -106,6 +106,13 @@ func (b *backend) pathGenerateIntermediate( | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if data.Get("private_key_format").(string) == "pkcs8" { | ||||||
|  | 		err = convertRespToPKCS8(resp) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	cb := &certutil.CertBundle{} | 	cb := &certutil.CertBundle{} | ||||||
| 	cb.PrivateKey = csrb.PrivateKey | 	cb.PrivateKey = csrb.PrivateKey | ||||||
| 	cb.PrivateKeyType = csrb.PrivateKeyType | 	cb.PrivateKeyType = csrb.PrivateKeyType | ||||||
|   | |||||||
| @@ -5,6 +5,7 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/hashicorp/errwrap" | ||||||
| 	"github.com/hashicorp/vault/helper/certutil" | 	"github.com/hashicorp/vault/helper/certutil" | ||||||
| 	"github.com/hashicorp/vault/helper/errutil" | 	"github.com/hashicorp/vault/helper/errutil" | ||||||
| 	"github.com/hashicorp/vault/logical" | 	"github.com/hashicorp/vault/logical" | ||||||
| @@ -163,7 +164,7 @@ func (b *backend) pathIssueSignCert( | |||||||
| 	format := getFormat(data) | 	format := getFormat(data) | ||||||
| 	if format == "" { | 	if format == "" { | ||||||
| 		return logical.ErrorResponse( | 		return logical.ErrorResponse( | ||||||
| 			`The "format" path parameter must be "pem", "der", or "pem_bundle"`), nil | 			`the "format" path parameter must be "pem", "der", or "pem_bundle"`), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var caErr error | 	var caErr error | ||||||
| @@ -171,10 +172,10 @@ func (b *backend) pathIssueSignCert( | |||||||
| 	switch caErr.(type) { | 	switch caErr.(type) { | ||||||
| 	case errutil.UserError: | 	case errutil.UserError: | ||||||
| 		return nil, errutil.UserError{Err: fmt.Sprintf( | 		return nil, errutil.UserError{Err: fmt.Sprintf( | ||||||
| 			"Could not fetch the CA certificate (was one set?): %s", caErr)} | 			"could not fetch the CA certificate (was one set?): %s", caErr)} | ||||||
| 	case errutil.InternalError: | 	case errutil.InternalError: | ||||||
| 		return nil, errutil.InternalError{Err: fmt.Sprintf( | 		return nil, errutil.InternalError{Err: fmt.Sprintf( | ||||||
| 			"Error fetching CA certificate: %s", caErr)} | 			"error fetching CA certificate: %s", caErr)} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var parsedBundle *certutil.ParsedCertBundle | 	var parsedBundle *certutil.ParsedCertBundle | ||||||
| @@ -195,12 +196,12 @@ func (b *backend) pathIssueSignCert( | |||||||
|  |  | ||||||
| 	signingCB, err := signingBundle.ToCertBundle() | 	signingCB, err := signingBundle.ToCertBundle() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("Error converting raw signing bundle to cert bundle: %s", err) | 		return nil, errwrap.Wrapf("error converting raw signing bundle to cert bundle: {{err}}", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	cb, err := parsedBundle.ToCertBundle() | 	cb, err := parsedBundle.ToCertBundle() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("Error converting raw cert bundle to cert bundle: %s", err) | 		return nil, errwrap.Wrapf("error converting raw cert bundle to cert bundle: {{err}}", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	respData := map[string]interface{}{ | 	respData := map[string]interface{}{ | ||||||
| @@ -267,6 +268,13 @@ func (b *backend) pathIssueSignCert( | |||||||
| 		resp.Secret.TTL = parsedBundle.Certificate.NotAfter.Sub(time.Now()) | 		resp.Secret.TTL = parsedBundle.Certificate.NotAfter.Sub(time.Now()) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if data.Get("private_key_format").(string) == "pkcs8" { | ||||||
|  | 		err = convertRespToPKCS8(resp) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if !role.NoStore { | 	if !role.NoStore { | ||||||
| 		err = req.Storage.Put(&logical.StorageEntry{ | 		err = req.Storage.Put(&logical.StorageEntry{ | ||||||
| 			Key:   "certs/" + normalizeSerial(cb.SerialNumber), | 			Key:   "certs/" + normalizeSerial(cb.SerialNumber), | ||||||
|   | |||||||
| @@ -6,7 +6,6 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/fatih/structs" |  | ||||||
| 	"github.com/hashicorp/vault/helper/parseutil" | 	"github.com/hashicorp/vault/helper/parseutil" | ||||||
| 	"github.com/hashicorp/vault/logical" | 	"github.com/hashicorp/vault/logical" | ||||||
| 	"github.com/hashicorp/vault/logical/framework" | 	"github.com/hashicorp/vault/logical/framework" | ||||||
| @@ -57,13 +56,12 @@ name in a request`, | |||||||
| 			}, | 			}, | ||||||
|  |  | ||||||
| 			"allowed_domains": &framework.FieldSchema{ | 			"allowed_domains": &framework.FieldSchema{ | ||||||
| 				Type:    framework.TypeString, | 				Type: framework.TypeCommaStringSlice, | ||||||
| 				Default: "", |  | ||||||
| 				Description: `If set, clients can request certificates for | 				Description: `If set, clients can request certificates for | ||||||
| subdomains directly beneath these domains, including | subdomains directly beneath these domains, including | ||||||
| the wildcard subdomains. See the documentation for more | the wildcard subdomains. See the documentation for more | ||||||
| information. This parameter accepts a comma-separated list | information. This parameter accepts a comma-separated  | ||||||
| of domains.`, | string or list of domains.`, | ||||||
| 			}, | 			}, | ||||||
|  |  | ||||||
| 			"allow_bare_domains": &framework.FieldSchema{ | 			"allow_bare_domains": &framework.FieldSchema{ | ||||||
| @@ -158,14 +156,14 @@ the key_type.`, | |||||||
| 			}, | 			}, | ||||||
|  |  | ||||||
| 			"key_usage": &framework.FieldSchema{ | 			"key_usage": &framework.FieldSchema{ | ||||||
| 				Type:    framework.TypeString, | 				Type:    framework.TypeCommaStringSlice, | ||||||
| 				Default: "DigitalSignature,KeyAgreement,KeyEncipherment", | 				Default: []string{"DigitalSignature", "KeyAgreement", "KeyEncipherment"}, | ||||||
| 				Description: `A comma-separated set of key usages (not extended | 				Description: `A comma-separated string or list of key usages (not extended | ||||||
| key usages). Valid values can be found at | key usages). Valid values can be found at | ||||||
| https://golang.org/pkg/crypto/x509/#KeyUsage | https://golang.org/pkg/crypto/x509/#KeyUsage | ||||||
| -- simply drop the "KeyUsage" part of the name. | -- simply drop the "KeyUsage" part of the name. | ||||||
| To remove all key usages from being set, set | To remove all key usages from being set, set | ||||||
| this value to an empty string.`, | this value to an empty list.`, | ||||||
| 			}, | 			}, | ||||||
|  |  | ||||||
| 			"use_csr_common_name": &framework.FieldSchema{ | 			"use_csr_common_name": &framework.FieldSchema{ | ||||||
| @@ -217,8 +215,8 @@ leases adversely affect the startup time of Vault.`, | |||||||
| 				Default: false, | 				Default: false, | ||||||
| 				Description: ` | 				Description: ` | ||||||
| If set, certificates issued/signed against this role will not be stored in the | If set, certificates issued/signed against this role will not be stored in the | ||||||
| in the storage backend. This can improve performance when issuing large numbers | storage backend. This can improve performance when issuing large numbers of  | ||||||
| of certificates. However, certificates issued in this way cannot be enumerated | certificates. However, certificates issued in this way cannot be enumerated | ||||||
| or revoked, so this option is recommended only for certificates that are | or revoked, so this option is recommended only for certificates that are | ||||||
| non-sensitive, or extremely short-lived. This option implies a value of "false" | non-sensitive, or extremely short-lived. This option implies a value of "false" | ||||||
| for "generate_lease".`, | for "generate_lease".`, | ||||||
| @@ -267,23 +265,21 @@ func (b *backend) getRole(s logical.Storage, n string) (*roleEntry, error) { | |||||||
| 		result.AllowBareDomains = true | 		result.AllowBareDomains = true | ||||||
| 		modified = true | 		modified = true | ||||||
| 	} | 	} | ||||||
|  | 	if result.AllowedDomainsOld != "" { | ||||||
|  | 		result.AllowedDomains = strings.Split(result.AllowedDomainsOld, ",") | ||||||
|  | 		result.AllowedDomainsOld = "" | ||||||
|  | 		modified = true | ||||||
|  | 	} | ||||||
| 	if result.AllowedBaseDomain != "" { | 	if result.AllowedBaseDomain != "" { | ||||||
| 		found := false | 		found := false | ||||||
| 		allowedDomains := strings.Split(result.AllowedDomains, ",") | 		for _, v := range result.AllowedDomains { | ||||||
| 		if len(allowedDomains) != 0 { | 			if v == result.AllowedBaseDomain { | ||||||
| 			for _, v := range allowedDomains { | 				found = true | ||||||
| 				if v == result.AllowedBaseDomain { | 				break | ||||||
| 					found = true |  | ||||||
| 					break |  | ||||||
| 				} |  | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		if !found { | 		if !found { | ||||||
| 			if result.AllowedDomains == "" { | 			result.AllowedDomains = append(result.AllowedDomains, result.AllowedBaseDomain) | ||||||
| 				result.AllowedDomains = result.AllowedBaseDomain |  | ||||||
| 			} else { |  | ||||||
| 				result.AllowedDomains += "," + result.AllowedBaseDomain |  | ||||||
| 			} |  | ||||||
| 		} | 		} | ||||||
| 		result.AllowedBaseDomain = "" | 		result.AllowedBaseDomain = "" | ||||||
| 		modified = true | 		modified = true | ||||||
| @@ -299,13 +295,23 @@ func (b *backend) getRole(s logical.Storage, n string) (*roleEntry, error) { | |||||||
| 		modified = true | 		modified = true | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Upgrade key usages | ||||||
|  | 	if result.KeyUsageOld != "" { | ||||||
|  | 		result.KeyUsage = strings.Split(result.KeyUsageOld, ",") | ||||||
|  | 		result.KeyUsageOld = "" | ||||||
|  | 		modified = true | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if modified { | 	if modified { | ||||||
| 		jsonEntry, err := logical.StorageEntryJSON("role/"+n, &result) | 		jsonEntry, err := logical.StorageEntryJSON("role/"+n, &result) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		if err := s.Put(jsonEntry); err != nil { | 		if err := s.Put(jsonEntry); err != nil { | ||||||
| 			return nil, err | 			// Only perform upgrades on replication primary | ||||||
|  | 			if !strings.Contains(err.Error(), logical.ErrReadOnly.Error()) { | ||||||
|  | 				return nil, err | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -351,18 +357,8 @@ func (b *backend) pathRoleRead( | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	resp := &logical.Response{ | 	resp := &logical.Response{ | ||||||
| 		Data: structs.New(role).Map(), | 		Data: role.ToResponseData(), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if resp.Data == nil { |  | ||||||
| 		return nil, fmt.Errorf("error converting role data to response") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// These values are deprecated and the entries are migrated on read |  | ||||||
| 	delete(resp.Data, "lease") |  | ||||||
| 	delete(resp.Data, "lease_max") |  | ||||||
| 	delete(resp.Data, "allowed_base_domain") |  | ||||||
|  |  | ||||||
| 	return resp, nil | 	return resp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -385,7 +381,7 @@ func (b *backend) pathRoleCreate( | |||||||
| 		MaxTTL:              data.Get("max_ttl").(string), | 		MaxTTL:              data.Get("max_ttl").(string), | ||||||
| 		TTL:                 (time.Duration(data.Get("ttl").(int)) * time.Second).String(), | 		TTL:                 (time.Duration(data.Get("ttl").(int)) * time.Second).String(), | ||||||
| 		AllowLocalhost:      data.Get("allow_localhost").(bool), | 		AllowLocalhost:      data.Get("allow_localhost").(bool), | ||||||
| 		AllowedDomains:      data.Get("allowed_domains").(string), | 		AllowedDomains:      data.Get("allowed_domains").([]string), | ||||||
| 		AllowBareDomains:    data.Get("allow_bare_domains").(bool), | 		AllowBareDomains:    data.Get("allow_bare_domains").(bool), | ||||||
| 		AllowSubdomains:     data.Get("allow_subdomains").(bool), | 		AllowSubdomains:     data.Get("allow_subdomains").(bool), | ||||||
| 		AllowGlobDomains:    data.Get("allow_glob_domains").(bool), | 		AllowGlobDomains:    data.Get("allow_glob_domains").(bool), | ||||||
| @@ -400,7 +396,7 @@ func (b *backend) pathRoleCreate( | |||||||
| 		KeyBits:             data.Get("key_bits").(int), | 		KeyBits:             data.Get("key_bits").(int), | ||||||
| 		UseCSRCommonName:    data.Get("use_csr_common_name").(bool), | 		UseCSRCommonName:    data.Get("use_csr_common_name").(bool), | ||||||
| 		UseCSRSANs:          data.Get("use_csr_sans").(bool), | 		UseCSRSANs:          data.Get("use_csr_sans").(bool), | ||||||
| 		KeyUsage:            data.Get("key_usage").(string), | 		KeyUsage:            data.Get("key_usage").([]string), | ||||||
| 		OU:                  data.Get("ou").(string), | 		OU:                  data.Get("ou").(string), | ||||||
| 		Organization:        data.Get("organization").(string), | 		Organization:        data.Get("organization").(string), | ||||||
| 		GenerateLease:       new(bool), | 		GenerateLease:       new(bool), | ||||||
| @@ -473,10 +469,9 @@ func (b *backend) pathRoleCreate( | |||||||
| 	return nil, nil | 	return nil, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func parseKeyUsages(input string) int { | func parseKeyUsages(input []string) int { | ||||||
| 	var parsedKeyUsages x509.KeyUsage | 	var parsedKeyUsages x509.KeyUsage | ||||||
| 	splitKeyUsage := strings.Split(input, ",") | 	for _, k := range input { | ||||||
| 	for _, k := range splitKeyUsage { |  | ||||||
| 		switch strings.ToLower(strings.TrimSpace(k)) { | 		switch strings.ToLower(strings.TrimSpace(k)) { | ||||||
| 		case "digitalsignature": | 		case "digitalsignature": | ||||||
| 			parsedKeyUsages |= x509.KeyUsageDigitalSignature | 			parsedKeyUsages |= x509.KeyUsageDigitalSignature | ||||||
| @@ -503,40 +498,77 @@ func parseKeyUsages(input string) int { | |||||||
| } | } | ||||||
|  |  | ||||||
| type roleEntry struct { | type roleEntry struct { | ||||||
| 	LeaseMax              string `json:"lease_max" structs:"lease_max" mapstructure:"lease_max"` | 	LeaseMax              string   `json:"lease_max"` | ||||||
| 	Lease                 string `json:"lease" structs:"lease" mapstructure:"lease"` | 	Lease                 string   `json:"lease"` | ||||||
| 	MaxTTL                string `json:"max_ttl" structs:"max_ttl" mapstructure:"max_ttl"` | 	MaxTTL                string   `json:"max_ttl" mapstructure:"max_ttl"` | ||||||
| 	TTL                   string `json:"ttl" structs:"ttl" mapstructure:"ttl"` | 	TTL                   string   `json:"ttl" mapstructure:"ttl"` | ||||||
| 	AllowLocalhost        bool   `json:"allow_localhost" structs:"allow_localhost" mapstructure:"allow_localhost"` | 	AllowLocalhost        bool     `json:"allow_localhost" mapstructure:"allow_localhost"` | ||||||
| 	AllowedBaseDomain     string `json:"allowed_base_domain" structs:"allowed_base_domain" mapstructure:"allowed_base_domain"` | 	AllowedBaseDomain     string   `json:"allowed_base_domain" mapstructure:"allowed_base_domain"` | ||||||
| 	AllowedDomains        string `json:"allowed_domains" structs:"allowed_domains" mapstructure:"allowed_domains"` | 	AllowedDomainsOld     string   `json:"allowed_domains,omit_empty"` | ||||||
| 	AllowBaseDomain       bool   `json:"allow_base_domain" structs:"allow_base_domain" mapstructure:"allow_base_domain"` | 	AllowedDomains        []string `json:"allowed_domains_list" mapstructure:"allowed_domains"` | ||||||
| 	AllowBareDomains      bool   `json:"allow_bare_domains" structs:"allow_bare_domains" mapstructure:"allow_bare_domains"` | 	AllowBaseDomain       bool     `json:"allow_base_domain"` | ||||||
| 	AllowTokenDisplayName bool   `json:"allow_token_displayname" structs:"allow_token_displayname" mapstructure:"allow_token_displayname"` | 	AllowBareDomains      bool     `json:"allow_bare_domains" mapstructure:"allow_bare_domains"` | ||||||
| 	AllowSubdomains       bool   `json:"allow_subdomains" structs:"allow_subdomains" mapstructure:"allow_subdomains"` | 	AllowTokenDisplayName bool     `json:"allow_token_displayname" mapstructure:"allow_token_displayname"` | ||||||
| 	AllowGlobDomains      bool   `json:"allow_glob_domains" structs:"allow_glob_domains" mapstructure:"allow_glob_domains"` | 	AllowSubdomains       bool     `json:"allow_subdomains" mapstructure:"allow_subdomains"` | ||||||
| 	AllowAnyName          bool   `json:"allow_any_name" structs:"allow_any_name" mapstructure:"allow_any_name"` | 	AllowGlobDomains      bool     `json:"allow_glob_domains" mapstructure:"allow_glob_domains"` | ||||||
| 	EnforceHostnames      bool   `json:"enforce_hostnames" structs:"enforce_hostnames" mapstructure:"enforce_hostnames"` | 	AllowAnyName          bool     `json:"allow_any_name" mapstructure:"allow_any_name"` | ||||||
| 	AllowIPSANs           bool   `json:"allow_ip_sans" structs:"allow_ip_sans" mapstructure:"allow_ip_sans"` | 	EnforceHostnames      bool     `json:"enforce_hostnames" mapstructure:"enforce_hostnames"` | ||||||
| 	ServerFlag            bool   `json:"server_flag" structs:"server_flag" mapstructure:"server_flag"` | 	AllowIPSANs           bool     `json:"allow_ip_sans" mapstructure:"allow_ip_sans"` | ||||||
| 	ClientFlag            bool   `json:"client_flag" structs:"client_flag" mapstructure:"client_flag"` | 	ServerFlag            bool     `json:"server_flag" mapstructure:"server_flag"` | ||||||
| 	CodeSigningFlag       bool   `json:"code_signing_flag" structs:"code_signing_flag" mapstructure:"code_signing_flag"` | 	ClientFlag            bool     `json:"client_flag" mapstructure:"client_flag"` | ||||||
| 	EmailProtectionFlag   bool   `json:"email_protection_flag" structs:"email_protection_flag" mapstructure:"email_protection_flag"` | 	CodeSigningFlag       bool     `json:"code_signing_flag" mapstructure:"code_signing_flag"` | ||||||
| 	UseCSRCommonName      bool   `json:"use_csr_common_name" structs:"use_csr_common_name" mapstructure:"use_csr_common_name"` | 	EmailProtectionFlag   bool     `json:"email_protection_flag" mapstructure:"email_protection_flag"` | ||||||
| 	UseCSRSANs            bool   `json:"use_csr_sans" structs:"use_csr_sans" mapstructure:"use_csr_sans"` | 	UseCSRCommonName      bool     `json:"use_csr_common_name" mapstructure:"use_csr_common_name"` | ||||||
| 	KeyType               string `json:"key_type" structs:"key_type" mapstructure:"key_type"` | 	UseCSRSANs            bool     `json:"use_csr_sans" mapstructure:"use_csr_sans"` | ||||||
| 	KeyBits               int    `json:"key_bits" structs:"key_bits" mapstructure:"key_bits"` | 	KeyType               string   `json:"key_type" mapstructure:"key_type"` | ||||||
| 	MaxPathLength         *int   `json:",omitempty" structs:"max_path_length,omitempty" mapstructure:"max_path_length"` | 	KeyBits               int      `json:"key_bits" mapstructure:"key_bits"` | ||||||
| 	KeyUsage              string `json:"key_usage" structs:"key_usage" mapstructure:"key_usage"` | 	MaxPathLength         *int     `json:",omitempty" mapstructure:"max_path_length"` | ||||||
| 	OU                    string `json:"ou" structs:"ou" mapstructure:"ou"` | 	KeyUsageOld           string   `json:"key_usage,omitempty"` | ||||||
| 	Organization          string `json:"organization" structs:"organization" mapstructure:"organization"` | 	KeyUsage              []string `json:"key_usage_list" mapstructure:"key_usage"` | ||||||
| 	GenerateLease         *bool  `json:"generate_lease,omitempty" structs:"generate_lease,omitempty"` | 	OU                    string   `json:"ou" mapstructure:"ou"` | ||||||
| 	NoStore               bool   `json:"no_store" structs:"no_store" mapstructure:"no_store"` | 	Organization          string   `json:"organization" mapstructure:"organization"` | ||||||
|  | 	GenerateLease         *bool    `json:"generate_lease,omitempty"` | ||||||
|  | 	NoStore               bool     `json:"no_store" mapstructure:"no_store"` | ||||||
|  |  | ||||||
| 	// Used internally for signing intermediates | 	// Used internally for signing intermediates | ||||||
| 	AllowExpirationPastCA bool | 	AllowExpirationPastCA bool | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (r *roleEntry) ToResponseData() map[string]interface{} { | ||||||
|  | 	responseData := map[string]interface{}{ | ||||||
|  | 		"ttl":                     r.TTL, | ||||||
|  | 		"max_ttl":                 r.MaxTTL, | ||||||
|  | 		"allow_localhost":         r.AllowLocalhost, | ||||||
|  | 		"allowed_domains":         r.AllowedDomains, | ||||||
|  | 		"allow_bare_domains":      r.AllowBareDomains, | ||||||
|  | 		"allow_token_displayname": r.AllowTokenDisplayName, | ||||||
|  | 		"allow_subdomains":        r.AllowSubdomains, | ||||||
|  | 		"allow_glob_domains":      r.AllowGlobDomains, | ||||||
|  | 		"allow_any_name":          r.AllowAnyName, | ||||||
|  | 		"enforce_hostnames":       r.EnforceHostnames, | ||||||
|  | 		"allow_ip_sans":           r.AllowIPSANs, | ||||||
|  | 		"server_flag":             r.ServerFlag, | ||||||
|  | 		"client_flag":             r.ClientFlag, | ||||||
|  | 		"code_signing_flag":       r.CodeSigningFlag, | ||||||
|  | 		"email_protection_flag":   r.EmailProtectionFlag, | ||||||
|  | 		"use_csr_common_name":     r.UseCSRCommonName, | ||||||
|  | 		"use_csr_sans":            r.UseCSRSANs, | ||||||
|  | 		"key_type":                r.KeyType, | ||||||
|  | 		"key_bits":                r.KeyBits, | ||||||
|  | 		"key_usage":               r.KeyUsage, | ||||||
|  | 		"ou":                      r.OU, | ||||||
|  | 		"organization":            r.Organization, | ||||||
|  | 		"no_store":                r.NoStore, | ||||||
|  | 	} | ||||||
|  | 	if r.MaxPathLength != nil { | ||||||
|  | 		responseData["max_path_length"] = r.MaxPathLength | ||||||
|  | 	} | ||||||
|  | 	if r.GenerateLease != nil { | ||||||
|  | 		responseData["generate_lease"] = r.GenerateLease | ||||||
|  | 	} | ||||||
|  | 	return responseData | ||||||
|  | } | ||||||
|  |  | ||||||
| const pathListRolesHelpSyn = `List the existing roles in this backend` | const pathListRolesHelpSyn = `List the existing roles in this backend` | ||||||
|  |  | ||||||
| const pathListRolesHelpDesc = `Roles will be listed by the role name.` | const pathListRolesHelpDesc = `Roles will be listed by the role name.` | ||||||
|   | |||||||
| @@ -120,6 +120,181 @@ func TestPki_RoleGenerateLease(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestPki_RoleKeyUsage(t *testing.T) { | ||||||
|  | 	var resp *logical.Response | ||||||
|  | 	var err error | ||||||
|  | 	b, storage := createBackendWithStorage(t) | ||||||
|  |  | ||||||
|  | 	roleData := map[string]interface{}{ | ||||||
|  | 		"allowed_domains": "myvault.com", | ||||||
|  | 		"ttl":             "5h", | ||||||
|  | 		"key_usage":       []string{"KeyEncipherment", "DigitalSignature"}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	roleReq := &logical.Request{ | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Path:      "roles/testrole", | ||||||
|  | 		Storage:   storage, | ||||||
|  | 		Data:      roleData, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	resp, err = b.HandleRequest(roleReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: err: %v resp: %#v", err, resp) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	roleReq.Operation = logical.ReadOperation | ||||||
|  | 	resp, err = b.HandleRequest(roleReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: err: %v resp: %#v", err, resp) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	keyUsage := resp.Data["key_usage"].([]string) | ||||||
|  | 	if len(keyUsage) != 2 { | ||||||
|  | 		t.Fatalf("key_usage should have 2 values") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Check that old key usage value is nil | ||||||
|  | 	var role roleEntry | ||||||
|  | 	err = mapstructure.Decode(resp.Data, &role) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if role.KeyUsageOld != "" { | ||||||
|  | 		t.Fatalf("old key usage storage value should be blank") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Make it explicit | ||||||
|  | 	role.KeyUsageOld = "KeyEncipherment,DigitalSignature" | ||||||
|  | 	role.KeyUsage = nil | ||||||
|  |  | ||||||
|  | 	entry, err := logical.StorageEntryJSON("role/testrole", role) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if err := storage.Put(entry); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Reading should upgrade key_usage | ||||||
|  | 	resp, err = b.HandleRequest(roleReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: err: %v resp: %#v", err, resp) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	keyUsage = resp.Data["key_usage"].([]string) | ||||||
|  | 	if len(keyUsage) != 2 { | ||||||
|  | 		t.Fatalf("key_usage should have 2 values") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Read back from storage to ensure upgrade | ||||||
|  | 	entry, err = storage.Get("role/testrole") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("err: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if entry == nil { | ||||||
|  | 		t.Fatalf("role should not be nil") | ||||||
|  | 	} | ||||||
|  | 	var result roleEntry | ||||||
|  | 	if err := entry.DecodeJSON(&result); err != nil { | ||||||
|  | 		t.Fatalf("err: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if result.KeyUsageOld != "" { | ||||||
|  | 		t.Fatal("old key usage value should be blank") | ||||||
|  | 	} | ||||||
|  | 	if len(result.KeyUsage) != 2 { | ||||||
|  | 		t.Fatal("key_usage should have 2 values") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestPki_RoleAllowedDomains(t *testing.T) { | ||||||
|  | 	var resp *logical.Response | ||||||
|  | 	var err error | ||||||
|  | 	b, storage := createBackendWithStorage(t) | ||||||
|  |  | ||||||
|  | 	roleData := map[string]interface{}{ | ||||||
|  | 		"allowed_domains": []string{"foobar.com", "*example.com"}, | ||||||
|  | 		"ttl":             "5h", | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	roleReq := &logical.Request{ | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Path:      "roles/testrole", | ||||||
|  | 		Storage:   storage, | ||||||
|  | 		Data:      roleData, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	resp, err = b.HandleRequest(roleReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: err: %v resp: %#v", err, resp) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	roleReq.Operation = logical.ReadOperation | ||||||
|  | 	resp, err = b.HandleRequest(roleReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: err: %v resp: %#v", err, resp) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	allowedDomains := resp.Data["allowed_domains"].([]string) | ||||||
|  | 	if len(allowedDomains) != 2 { | ||||||
|  | 		t.Fatalf("allowed_domains should have 2 values") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Check that old key usage value is nil | ||||||
|  | 	var role roleEntry | ||||||
|  | 	err = mapstructure.Decode(resp.Data, &role) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if role.AllowedDomainsOld != "" { | ||||||
|  | 		t.Fatalf("old allowed_domains storage value should be blank") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Make it explicit | ||||||
|  | 	role.AllowedDomainsOld = "foobar.com,*example.com" | ||||||
|  | 	role.AllowedDomains = nil | ||||||
|  |  | ||||||
|  | 	entry, err := logical.StorageEntryJSON("role/testrole", role) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if err := storage.Put(entry); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Reading should upgrade key_usage | ||||||
|  | 	resp, err = b.HandleRequest(roleReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: err: %v resp: %#v", err, resp) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	allowedDomains = resp.Data["allowed_domains"].([]string) | ||||||
|  | 	if len(allowedDomains) != 2 { | ||||||
|  | 		t.Fatalf("allowed_domains should have 2 values") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Read back from storage to ensure upgrade | ||||||
|  | 	entry, err = storage.Get("role/testrole") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("err: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if entry == nil { | ||||||
|  | 		t.Fatalf("role should not be nil") | ||||||
|  | 	} | ||||||
|  | 	var result roleEntry | ||||||
|  | 	if err := entry.DecodeJSON(&result); err != nil { | ||||||
|  | 		t.Fatalf("err: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if result.AllowedDomainsOld != "" { | ||||||
|  | 		t.Fatal("old allowed_domains value should be blank") | ||||||
|  | 	} | ||||||
|  | 	if len(result.AllowedDomains) != 2 { | ||||||
|  | 		t.Fatal("allowed_domains should have 2 values") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestPki_RoleNoStore(t *testing.T) { | func TestPki_RoleNoStore(t *testing.T) { | ||||||
| 	var resp *logical.Response | 	var resp *logical.Response | ||||||
| 	var err error | 	var err error | ||||||
|   | |||||||
| @@ -149,7 +149,7 @@ func (b *backend) pathCAGenerateRoot( | |||||||
|  |  | ||||||
| 	cb, err := parsedBundle.ToCertBundle() | 	cb, err := parsedBundle.ToCertBundle() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("error converting raw cert bundle to cert bundle: %s", err) | 		return nil, errwrap.Wrapf("error converting raw cert bundle to cert bundle: {{err}}", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	resp := &logical.Response{ | 	resp := &logical.Response{ | ||||||
| @@ -188,6 +188,13 @@ func (b *backend) pathCAGenerateRoot( | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if data.Get("private_key_format").(string) == "pkcs8" { | ||||||
|  | 		err = convertRespToPKCS8(resp) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// Store it as the CA bundle | 	// Store it as the CA bundle | ||||||
| 	entry, err = logical.StorageEntryJSON("config/ca_bundle", cb) | 	entry, err = logical.StorageEntryJSON("config/ca_bundle", cb) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -205,7 +212,7 @@ func (b *backend) pathCAGenerateRoot( | |||||||
| 		Value: parsedBundle.CertificateBytes, | 		Value: parsedBundle.CertificateBytes, | ||||||
| 	}) | 	}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("Unable to store certificate locally: %v", err) | 		return nil, errwrap.Wrapf("unable to store certificate locally: {{err}}", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// For ease of later use, also store just the certificate at a known | 	// For ease of later use, also store just the certificate at a known | ||||||
|   | |||||||
| @@ -25,6 +25,12 @@ func Backend(conf *logical.BackendConfig) *backend { | |||||||
| 	b.Backend = &framework.Backend{ | 	b.Backend = &framework.Backend{ | ||||||
| 		Help: strings.TrimSpace(backendHelp), | 		Help: strings.TrimSpace(backendHelp), | ||||||
|  |  | ||||||
|  | 		PathsSpecial: &logical.Paths{ | ||||||
|  | 			SealWrapStorage: []string{ | ||||||
|  | 				"config/connection", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  |  | ||||||
| 		Paths: []*framework.Path{ | 		Paths: []*framework.Path{ | ||||||
| 			pathConfigConnection(&b), | 			pathConfigConnection(&b), | ||||||
| 			pathConfigLease(&b), | 			pathConfigLease(&b), | ||||||
|   | |||||||
| @@ -26,6 +26,12 @@ func Backend() *backend { | |||||||
| 	b.Backend = &framework.Backend{ | 	b.Backend = &framework.Backend{ | ||||||
| 		Help: strings.TrimSpace(backendHelp), | 		Help: strings.TrimSpace(backendHelp), | ||||||
|  |  | ||||||
|  | 		PathsSpecial: &logical.Paths{ | ||||||
|  | 			SealWrapStorage: []string{ | ||||||
|  | 				"config/connection", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  |  | ||||||
| 		Paths: []*framework.Path{ | 		Paths: []*framework.Path{ | ||||||
| 			pathConfigConnection(&b), | 			pathConfigConnection(&b), | ||||||
| 			pathConfigLease(&b), | 			pathConfigLease(&b), | ||||||
|   | |||||||
| @@ -271,6 +271,31 @@ func TestSSHBackend_Lookup(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestSSHBackend_RoleList(t *testing.T) { | ||||||
|  | 	testOTPRoleData := map[string]interface{}{ | ||||||
|  | 		"key_type":     testOTPKeyType, | ||||||
|  | 		"default_user": testUserName, | ||||||
|  | 		"cidr_list":    testCIDRList, | ||||||
|  | 	} | ||||||
|  | 	resp1 := map[string]interface{}{} | ||||||
|  | 	resp2 := map[string]interface{}{ | ||||||
|  | 		"keys": []string{testOTPRoleName}, | ||||||
|  | 		"key_info": map[string]interface{}{ | ||||||
|  | 			testOTPRoleName: map[string]interface{}{ | ||||||
|  | 				"key_type": testOTPKeyType, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	logicaltest.Test(t, logicaltest.TestCase{ | ||||||
|  | 		Factory: testingFactory, | ||||||
|  | 		Steps: []logicaltest.TestStep{ | ||||||
|  | 			testRoleList(t, resp1), | ||||||
|  | 			testRoleWrite(t, testOTPRoleName, testOTPRoleData), | ||||||
|  | 			testRoleList(t, resp2), | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestSSHBackend_DynamicKeyCreate(t *testing.T) { | func TestSSHBackend_DynamicKeyCreate(t *testing.T) { | ||||||
| 	testDynamicRoleData := map[string]interface{}{ | 	testDynamicRoleData := map[string]interface{}{ | ||||||
| 		"key_type":     testDynamicKeyType, | 		"key_type":     testDynamicKeyType, | ||||||
| @@ -962,6 +987,25 @@ func testRoleWrite(t *testing.T, name string, data map[string]interface{}) logic | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func testRoleList(t *testing.T, expected map[string]interface{}) logicaltest.TestStep { | ||||||
|  | 	return logicaltest.TestStep{ | ||||||
|  | 		Operation: logical.ListOperation, | ||||||
|  | 		Path:      "roles", | ||||||
|  | 		Check: func(resp *logical.Response) error { | ||||||
|  | 			if resp == nil { | ||||||
|  | 				return fmt.Errorf("nil response") | ||||||
|  | 			} | ||||||
|  | 			if resp.Data == nil { | ||||||
|  | 				return fmt.Errorf("nil data") | ||||||
|  | 			} | ||||||
|  | 			if !reflect.DeepEqual(resp.Data, expected) { | ||||||
|  | 				return fmt.Errorf("Invalid response:\nactual:%#v\nexpected is %#v", resp.Data, expected) | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func testRoleRead(t *testing.T, roleName string, expected map[string]interface{}) logicaltest.TestStep { | func testRoleRead(t *testing.T, roleName string, expected map[string]interface{}) logicaltest.TestStep { | ||||||
| 	return logicaltest.TestStep{ | 	return logicaltest.TestStep{ | ||||||
| 		Operation: logical.ReadOperation, | 		Operation: logical.ReadOperation, | ||||||
|   | |||||||
| @@ -175,7 +175,7 @@ func pathRoles(b *backend) *framework.Path { | |||||||
| 				`, | 				`, | ||||||
| 			}, | 			}, | ||||||
| 			"ttl": &framework.FieldSchema{ | 			"ttl": &framework.FieldSchema{ | ||||||
| 				Type: framework.TypeString, | 				Type: framework.TypeDurationSecond, | ||||||
| 				Description: ` | 				Description: ` | ||||||
| 				[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type] | 				[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type] | ||||||
| 				The lease duration if no specific lease duration is | 				The lease duration if no specific lease duration is | ||||||
| @@ -184,7 +184,7 @@ func pathRoles(b *backend) *framework.Path { | |||||||
| 				the value of max_ttl.`, | 				the value of max_ttl.`, | ||||||
| 			}, | 			}, | ||||||
| 			"max_ttl": &framework.FieldSchema{ | 			"max_ttl": &framework.FieldSchema{ | ||||||
| 				Type: framework.TypeString, | 				Type: framework.TypeDurationSecond, | ||||||
| 				Description: ` | 				Description: ` | ||||||
| 				[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type] | 				[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type] | ||||||
| 				The maximum allowed lease duration | 				The maximum allowed lease duration | ||||||
| @@ -386,15 +386,15 @@ func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (* | |||||||
| 			return logical.ErrorResponse("missing admin username"), nil | 			return logical.ErrorResponse("missing admin username"), nil | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// This defaults to 1024 and it can also be 2048. | 		// This defaults to 1024 and it can also be 2048 and 4096. | ||||||
| 		keyBits := d.Get("key_bits").(int) | 		keyBits := d.Get("key_bits").(int) | ||||||
| 		if keyBits != 0 && keyBits != 1024 && keyBits != 2048 { | 		if keyBits != 0 && keyBits != 1024 && keyBits != 2048 && keyBits != 4096 { | ||||||
| 			return logical.ErrorResponse("invalid key_bits field"), nil | 			return logical.ErrorResponse("invalid key_bits field"), nil | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// If user has not set this field, default it to 1024 | 		// If user has not set this field, default it to 2048 | ||||||
| 		if keyBits == 0 { | 		if keyBits == 0 { | ||||||
| 			keyBits = 1024 | 			keyBits = 2048 | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Store all the fields required by dynamic key type | 		// Store all the fields required by dynamic key type | ||||||
| @@ -433,9 +433,9 @@ func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (* | |||||||
| } | } | ||||||
|  |  | ||||||
| func (b *backend) createCARole(allowedUsers, defaultUser string, data *framework.FieldData) (*sshRole, *logical.Response) { | func (b *backend) createCARole(allowedUsers, defaultUser string, data *framework.FieldData) (*sshRole, *logical.Response) { | ||||||
|  | 	ttl := time.Duration(data.Get("ttl").(int)) * time.Second | ||||||
|  | 	maxTTL := time.Duration(data.Get("max_ttl").(int)) * time.Second | ||||||
| 	role := &sshRole{ | 	role := &sshRole{ | ||||||
| 		MaxTTL: data.Get("max_ttl").(string), |  | ||||||
| 		TTL:    data.Get("ttl").(string), |  | ||||||
| 		AllowedCriticalOptions: data.Get("allowed_critical_options").(string), | 		AllowedCriticalOptions: data.Get("allowed_critical_options").(string), | ||||||
| 		AllowedExtensions:      data.Get("allowed_extensions").(string), | 		AllowedExtensions:      data.Get("allowed_extensions").(string), | ||||||
| 		AllowUserCertificates:  data.Get("allow_user_certificates").(bool), | 		AllowUserCertificates:  data.Get("allow_user_certificates").(bool), | ||||||
| @@ -457,44 +457,12 @@ func (b *backend) createCARole(allowedUsers, defaultUser string, data *framework | |||||||
| 	defaultCriticalOptions := convertMapToStringValue(data.Get("default_critical_options").(map[string]interface{})) | 	defaultCriticalOptions := convertMapToStringValue(data.Get("default_critical_options").(map[string]interface{})) | ||||||
| 	defaultExtensions := convertMapToStringValue(data.Get("default_extensions").(map[string]interface{})) | 	defaultExtensions := convertMapToStringValue(data.Get("default_extensions").(map[string]interface{})) | ||||||
|  |  | ||||||
| 	var maxTTL time.Duration | 	if ttl != 0 && maxTTL != 0 && ttl > maxTTL { | ||||||
| 	maxSystemTTL := b.System().MaxLeaseTTL() | 		return nil, logical.ErrorResponse( | ||||||
| 	if len(role.MaxTTL) == 0 { | 			`"ttl" value must be less than "max_ttl" when both are specified`) | ||||||
| 		maxTTL = maxSystemTTL |  | ||||||
| 	} else { |  | ||||||
| 		var err error |  | ||||||
| 		maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, logical.ErrorResponse(fmt.Sprintf( |  | ||||||
| 				"Invalid max ttl: %s", err)) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	if maxTTL > maxSystemTTL { |  | ||||||
| 		return nil, logical.ErrorResponse("Requested max TTL is higher than backend maximum") |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	ttl := b.System().DefaultLeaseTTL() | 	// Persist TTLs | ||||||
| 	if len(role.TTL) != 0 { |  | ||||||
| 		var err error |  | ||||||
| 		ttl, err = parseutil.ParseDurationSecond(role.TTL) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, logical.ErrorResponse(fmt.Sprintf( |  | ||||||
| 				"Invalid ttl: %s", err)) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	if ttl > maxTTL { |  | ||||||
| 		// If they are using the system default, cap it to the role max; |  | ||||||
| 		// if it was specified on the command line, make it an error |  | ||||||
| 		if len(role.TTL) == 0 { |  | ||||||
| 			ttl = maxTTL |  | ||||||
| 		} else { |  | ||||||
| 			return nil, logical.ErrorResponse( |  | ||||||
| 				`"ttl" value must be less than "max_ttl" and/or backend default max lease TTL value`, |  | ||||||
| 			) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Persist clamped TTLs |  | ||||||
| 	role.TTL = ttl.String() | 	role.TTL = ttl.String() | ||||||
| 	role.MaxTTL = maxTTL.String() | 	role.MaxTTL = maxTTL.String() | ||||||
| 	role.DefaultCriticalOptions = defaultCriticalOptions | 	role.DefaultCriticalOptions = defaultCriticalOptions | ||||||
| @@ -520,13 +488,115 @@ func (b *backend) getRole(s logical.Storage, n string) (*sshRole, error) { | |||||||
| 	return &result, nil | 	return &result, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // parseRole converts a sshRole object into its map[string]interface representation, | ||||||
|  | // with appropriate values for each KeyType. If the KeyType is invalid, it will retun | ||||||
|  | // an error. | ||||||
|  | func (b *backend) parseRole(role *sshRole) (map[string]interface{}, error) { | ||||||
|  | 	var result map[string]interface{} | ||||||
|  |  | ||||||
|  | 	switch role.KeyType { | ||||||
|  | 	case KeyTypeOTP: | ||||||
|  | 		result = map[string]interface{}{ | ||||||
|  | 			"default_user":      role.DefaultUser, | ||||||
|  | 			"cidr_list":         role.CIDRList, | ||||||
|  | 			"exclude_cidr_list": role.ExcludeCIDRList, | ||||||
|  | 			"key_type":          role.KeyType, | ||||||
|  | 			"port":              role.Port, | ||||||
|  | 			"allowed_users":     role.AllowedUsers, | ||||||
|  | 		} | ||||||
|  | 	case KeyTypeCA: | ||||||
|  | 		ttl, err := parseutil.ParseDurationSecond(role.TTL) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		maxTTL, err := parseutil.ParseDurationSecond(role.MaxTTL) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		result = map[string]interface{}{ | ||||||
|  | 			"allowed_users":            role.AllowedUsers, | ||||||
|  | 			"allowed_domains":          role.AllowedDomains, | ||||||
|  | 			"default_user":             role.DefaultUser, | ||||||
|  | 			"ttl":                      int64(ttl.Seconds()), | ||||||
|  | 			"max_ttl":                  int64(maxTTL.Seconds()), | ||||||
|  | 			"allowed_critical_options": role.AllowedCriticalOptions, | ||||||
|  | 			"allowed_extensions":       role.AllowedExtensions, | ||||||
|  | 			"allow_user_certificates":  role.AllowUserCertificates, | ||||||
|  | 			"allow_host_certificates":  role.AllowHostCertificates, | ||||||
|  | 			"allow_bare_domains":       role.AllowBareDomains, | ||||||
|  | 			"allow_subdomains":         role.AllowSubdomains, | ||||||
|  | 			"allow_user_key_ids":       role.AllowUserKeyIDs, | ||||||
|  | 			"key_id_format":            role.KeyIDFormat, | ||||||
|  | 			"key_type":                 role.KeyType, | ||||||
|  | 			"default_critical_options": role.DefaultCriticalOptions, | ||||||
|  | 			"default_extensions":       role.DefaultExtensions, | ||||||
|  | 		} | ||||||
|  | 	case KeyTypeDynamic: | ||||||
|  | 		result = map[string]interface{}{ | ||||||
|  | 			"key":               role.KeyName, | ||||||
|  | 			"admin_user":        role.AdminUser, | ||||||
|  | 			"default_user":      role.DefaultUser, | ||||||
|  | 			"cidr_list":         role.CIDRList, | ||||||
|  | 			"exclude_cidr_list": role.ExcludeCIDRList, | ||||||
|  | 			"port":              role.Port, | ||||||
|  | 			"key_type":          role.KeyType, | ||||||
|  | 			"key_bits":          role.KeyBits, | ||||||
|  | 			"allowed_users":     role.AllowedUsers, | ||||||
|  | 			"key_option_specs":  role.KeyOptionSpecs, | ||||||
|  | 			// Returning install script will make the output look messy. | ||||||
|  | 			// But this is one way for clients to see the script that is | ||||||
|  | 			// being used to install the key. If there is some problem, | ||||||
|  | 			// the script can be modified and configured by clients. | ||||||
|  | 			"install_script": role.InstallScript, | ||||||
|  | 		} | ||||||
|  | 	default: | ||||||
|  | 		return nil, fmt.Errorf("invalid key type: %v", role.KeyType) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return result, nil | ||||||
|  | } | ||||||
|  |  | ||||||
| func (b *backend) pathRoleList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | func (b *backend) pathRoleList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
| 	entries, err := req.Storage.List("roles/") | 	entries, err := req.Storage.List("roles/") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return logical.ListResponse(entries), nil | 	keyInfo := map[string]interface{}{} | ||||||
|  | 	for _, entry := range entries { | ||||||
|  | 		role, err := b.getRole(req.Storage, entry) | ||||||
|  | 		if err != nil { | ||||||
|  | 			// On error, log warning and continue | ||||||
|  | 			if b.Logger().IsWarn() { | ||||||
|  | 				b.Logger().Warn("ssh: error getting role info", "role", entry, "error", err) | ||||||
|  | 			} | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		if role == nil { | ||||||
|  | 			// On empty role, log warning and continue | ||||||
|  | 			if b.Logger().IsWarn() { | ||||||
|  | 				b.Logger().Warn("ssh: no role info found", "role", entry) | ||||||
|  | 			} | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		roleInfo, err := b.parseRole(role) | ||||||
|  | 		if err != nil { | ||||||
|  | 			if b.Logger().IsWarn() { | ||||||
|  | 				b.Logger().Warn("ssh: error parsing role info", "role", entry, "error", err) | ||||||
|  | 			} | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if keyType, ok := roleInfo["key_type"]; ok { | ||||||
|  | 			keyInfo[entry] = map[string]interface{}{ | ||||||
|  | 				"key_type": keyType, | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return logical.ListResponseWithInfo(entries, keyInfo), nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
| @@ -538,60 +608,14 @@ func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*l | |||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Return information should be based on the key type of the role | 	roleInfo, err := b.parseRole(role) | ||||||
| 	if role.KeyType == KeyTypeOTP { | 	if err != nil { | ||||||
| 		return &logical.Response{ | 		return nil, err | ||||||
| 			Data: map[string]interface{}{ |  | ||||||
| 				"default_user":      role.DefaultUser, |  | ||||||
| 				"cidr_list":         role.CIDRList, |  | ||||||
| 				"exclude_cidr_list": role.ExcludeCIDRList, |  | ||||||
| 				"key_type":          role.KeyType, |  | ||||||
| 				"port":              role.Port, |  | ||||||
| 				"allowed_users":     role.AllowedUsers, |  | ||||||
| 			}, |  | ||||||
| 		}, nil |  | ||||||
| 	} else if role.KeyType == KeyTypeCA { |  | ||||||
| 		return &logical.Response{ |  | ||||||
| 			Data: map[string]interface{}{ |  | ||||||
| 				"allowed_users":   role.AllowedUsers, |  | ||||||
| 				"allowed_domains": role.AllowedDomains, |  | ||||||
| 				"default_user":    role.DefaultUser, |  | ||||||
| 				"max_ttl":         role.MaxTTL, |  | ||||||
| 				"ttl":             role.TTL, |  | ||||||
| 				"allowed_critical_options": role.AllowedCriticalOptions, |  | ||||||
| 				"allowed_extensions":       role.AllowedExtensions, |  | ||||||
| 				"allow_user_certificates":  role.AllowUserCertificates, |  | ||||||
| 				"allow_host_certificates":  role.AllowHostCertificates, |  | ||||||
| 				"allow_bare_domains":       role.AllowBareDomains, |  | ||||||
| 				"allow_subdomains":         role.AllowSubdomains, |  | ||||||
| 				"allow_user_key_ids":       role.AllowUserKeyIDs, |  | ||||||
| 				"key_id_format":            role.KeyIDFormat, |  | ||||||
| 				"key_type":                 role.KeyType, |  | ||||||
| 				"default_critical_options": role.DefaultCriticalOptions, |  | ||||||
| 				"default_extensions":       role.DefaultExtensions, |  | ||||||
| 			}, |  | ||||||
| 		}, nil |  | ||||||
| 	} else { |  | ||||||
| 		return &logical.Response{ |  | ||||||
| 			Data: map[string]interface{}{ |  | ||||||
| 				"key":               role.KeyName, |  | ||||||
| 				"admin_user":        role.AdminUser, |  | ||||||
| 				"default_user":      role.DefaultUser, |  | ||||||
| 				"cidr_list":         role.CIDRList, |  | ||||||
| 				"exclude_cidr_list": role.ExcludeCIDRList, |  | ||||||
| 				"port":              role.Port, |  | ||||||
| 				"key_type":          role.KeyType, |  | ||||||
| 				"key_bits":          role.KeyBits, |  | ||||||
| 				"allowed_users":     role.AllowedUsers, |  | ||||||
| 				"key_option_specs":  role.KeyOptionSpecs, |  | ||||||
| 				// Returning install script will make the output look messy. |  | ||||||
| 				// But this is one way for clients to see the script that is |  | ||||||
| 				// being used to install the key. If there is some problem, |  | ||||||
| 				// the script can be modified and configured by clients. |  | ||||||
| 				"install_script": role.InstallScript, |  | ||||||
| 			}, |  | ||||||
| 		}, nil |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	return &logical.Response{ | ||||||
|  | 		Data: roleInfo, | ||||||
|  | 	}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *backend) pathRoleDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | func (b *backend) pathRoleDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
|   | |||||||
| @@ -43,7 +43,7 @@ func pathSign(b *backend) *framework.Path { | |||||||
| 				Description: `The desired role with configuration for this request.`, | 				Description: `The desired role with configuration for this request.`, | ||||||
| 			}, | 			}, | ||||||
| 			"ttl": &framework.FieldSchema{ | 			"ttl": &framework.FieldSchema{ | ||||||
| 				Type: framework.TypeString, | 				Type: framework.TypeDurationSecond, | ||||||
| 				Description: `The requested Time To Live for the SSH certificate; | 				Description: `The requested Time To Live for the SSH certificate; | ||||||
| sets the expiration date. If not specified | sets the expiration date. If not specified | ||||||
| the role default, backend default, or system | the role default, backend default, or system | ||||||
| @@ -345,40 +345,34 @@ func (b *backend) calculateExtensions(data *framework.FieldData, role *sshRole) | |||||||
| } | } | ||||||
|  |  | ||||||
| func (b *backend) calculateTTL(data *framework.FieldData, role *sshRole) (time.Duration, error) { | func (b *backend) calculateTTL(data *framework.FieldData, role *sshRole) (time.Duration, error) { | ||||||
|  |  | ||||||
| 	var ttl, maxTTL time.Duration | 	var ttl, maxTTL time.Duration | ||||||
| 	var ttlField string | 	var err error | ||||||
| 	ttlFieldInt, ok := data.GetOk("ttl") |  | ||||||
| 	if !ok { |  | ||||||
| 		ttlField = role.TTL |  | ||||||
| 	} else { |  | ||||||
| 		ttlField = ttlFieldInt.(string) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if len(ttlField) == 0 { | 	ttlRaw, specifiedTTL := data.GetOk("ttl") | ||||||
|  | 	if specifiedTTL { | ||||||
|  | 		ttl = time.Duration(ttlRaw.(int)) * time.Second | ||||||
|  | 	} else { | ||||||
|  | 		ttl, err = parseutil.ParseDurationSecond(role.TTL) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return 0, err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if ttl == 0 { | ||||||
| 		ttl = b.System().DefaultLeaseTTL() | 		ttl = b.System().DefaultLeaseTTL() | ||||||
| 	} else { |  | ||||||
| 		var err error |  | ||||||
| 		ttl, err = parseutil.ParseDurationSecond(ttlField) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return 0, fmt.Errorf("invalid requested ttl: %s", err) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if len(role.MaxTTL) == 0 { | 	maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0, err | ||||||
|  | 	} | ||||||
|  | 	if maxTTL == 0 { | ||||||
| 		maxTTL = b.System().MaxLeaseTTL() | 		maxTTL = b.System().MaxLeaseTTL() | ||||||
| 	} else { |  | ||||||
| 		var err error |  | ||||||
| 		maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return 0, fmt.Errorf("invalid requested max ttl: %s", err) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if ttl > maxTTL { | 	if ttl > maxTTL { | ||||||
| 		// Don't error if they were using system defaults, only error if | 		// Don't error if they were using system defaults, only error if | ||||||
| 		// they specifically chose a bad TTL | 		// they specifically chose a bad TTL | ||||||
| 		if len(ttlField) == 0 { | 		if !specifiedTTL { | ||||||
| 			ttl = maxTTL | 			ttl = maxTTL | ||||||
| 		} else { | 		} else { | ||||||
| 			return 0, fmt.Errorf("ttl is larger than maximum allowed (%d)", maxTTL/time.Second) | 			return 0, fmt.Errorf("ttl is larger than maximum allowed (%d)", maxTTL/time.Second) | ||||||
|   | |||||||
| @@ -43,6 +43,8 @@ func Backend(conf *logical.BackendConfig) *backend { | |||||||
| 			b.pathHMAC(), | 			b.pathHMAC(), | ||||||
| 			b.pathSign(), | 			b.pathSign(), | ||||||
| 			b.pathVerify(), | 			b.pathVerify(), | ||||||
|  | 			b.pathBackup(), | ||||||
|  | 			b.pathRestore(), | ||||||
| 		}, | 		}, | ||||||
|  |  | ||||||
| 		Secrets:     []*framework.Secret{}, | 		Secrets:     []*framework.Secret{}, | ||||||
|   | |||||||
| @@ -38,6 +38,191 @@ func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) { | |||||||
| 	return b, config.StorageView | 	return b, config.StorageView | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestTransit_RSA(t *testing.T) { | ||||||
|  | 	testTransit_RSA(t, "rsa-2048") | ||||||
|  | 	testTransit_RSA(t, "rsa-4096") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func testTransit_RSA(t *testing.T, keyType string) { | ||||||
|  | 	var resp *logical.Response | ||||||
|  | 	var err error | ||||||
|  | 	b, storage := createBackendWithStorage(t) | ||||||
|  |  | ||||||
|  | 	keyReq := &logical.Request{ | ||||||
|  | 		Path:      "keys/rsa", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Data: map[string]interface{}{ | ||||||
|  | 			"type": keyType, | ||||||
|  | 		}, | ||||||
|  | 		Storage: storage, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	resp, err = b.HandleRequest(keyReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: err: %v\nresp: %#v", err, resp) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" // "the quick brown fox" | ||||||
|  |  | ||||||
|  | 	encryptReq := &logical.Request{ | ||||||
|  | 		Path:      "encrypt/rsa", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Storage:   storage, | ||||||
|  | 		Data: map[string]interface{}{ | ||||||
|  | 			"plaintext": plaintext, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	resp, err = b.HandleRequest(encryptReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: err: %v\nresp: %#v", err, resp) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	ciphertext1 := resp.Data["ciphertext"].(string) | ||||||
|  |  | ||||||
|  | 	decryptReq := &logical.Request{ | ||||||
|  | 		Path:      "decrypt/rsa", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Storage:   storage, | ||||||
|  | 		Data: map[string]interface{}{ | ||||||
|  | 			"ciphertext": ciphertext1, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	resp, err = b.HandleRequest(decryptReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: err: %v\nresp: %#v", err, resp) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	decryptedPlaintext := resp.Data["plaintext"] | ||||||
|  |  | ||||||
|  | 	if plaintext != decryptedPlaintext { | ||||||
|  | 		t.Fatalf("bad: plaintext; expected: %q\nactual: %q", plaintext, decryptedPlaintext) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Rotate the key | ||||||
|  | 	rotateReq := &logical.Request{ | ||||||
|  | 		Path:      "keys/rsa/rotate", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Storage:   storage, | ||||||
|  | 	} | ||||||
|  | 	resp, err = b.HandleRequest(rotateReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: err: %v\nresp: %#v", err, resp) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Encrypt again | ||||||
|  | 	resp, err = b.HandleRequest(encryptReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: err: %v\nresp: %#v", err, resp) | ||||||
|  | 	} | ||||||
|  | 	ciphertext2 := resp.Data["ciphertext"].(string) | ||||||
|  |  | ||||||
|  | 	if ciphertext1 == ciphertext2 { | ||||||
|  | 		t.Fatalf("expected different ciphertexts") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// See if the older ciphertext can still be decrypted | ||||||
|  | 	resp, err = b.HandleRequest(decryptReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: err: %v\nresp: %#v", err, resp) | ||||||
|  | 	} | ||||||
|  | 	if resp.Data["plaintext"].(string) != plaintext { | ||||||
|  | 		t.Fatal("failed to decrypt old ciphertext after rotating the key") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Decrypt the new ciphertext | ||||||
|  | 	decryptReq.Data = map[string]interface{}{ | ||||||
|  | 		"ciphertext": ciphertext2, | ||||||
|  | 	} | ||||||
|  | 	resp, err = b.HandleRequest(decryptReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: err: %v\nresp: %#v", err, resp) | ||||||
|  | 	} | ||||||
|  | 	if resp.Data["plaintext"].(string) != plaintext { | ||||||
|  | 		t.Fatal("failed to decrypt ciphertext after rotating the key") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	signReq := &logical.Request{ | ||||||
|  | 		Path:      "sign/rsa", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Storage:   storage, | ||||||
|  | 		Data: map[string]interface{}{ | ||||||
|  | 			"input": plaintext, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	resp, err = b.HandleRequest(signReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: err: %v\nresp: %#v", err, resp) | ||||||
|  | 	} | ||||||
|  | 	signature := resp.Data["signature"].(string) | ||||||
|  |  | ||||||
|  | 	verifyReq := &logical.Request{ | ||||||
|  | 		Path:      "verify/rsa", | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Storage:   storage, | ||||||
|  | 		Data: map[string]interface{}{ | ||||||
|  | 			"input":     plaintext, | ||||||
|  | 			"signature": signature, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	resp, err = b.HandleRequest(verifyReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: err: %v\nresp: %#v", err, resp) | ||||||
|  | 	} | ||||||
|  | 	if !resp.Data["valid"].(bool) { | ||||||
|  | 		t.Fatalf("failed to verify the RSA signature") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	signReq.Data = map[string]interface{}{ | ||||||
|  | 		"input":     plaintext, | ||||||
|  | 		"algorithm": "invalid", | ||||||
|  | 	} | ||||||
|  | 	resp, err = b.HandleRequest(signReq) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if resp == nil || !resp.IsError() { | ||||||
|  | 		t.Fatal("expected an error response") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	signReq.Data = map[string]interface{}{ | ||||||
|  | 		"input":     plaintext, | ||||||
|  | 		"algorithm": "sha2-512", | ||||||
|  | 	} | ||||||
|  | 	resp, err = b.HandleRequest(signReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: err: %v\nresp: %#v", err, resp) | ||||||
|  | 	} | ||||||
|  | 	signature = resp.Data["signature"].(string) | ||||||
|  |  | ||||||
|  | 	verifyReq.Data = map[string]interface{}{ | ||||||
|  | 		"input":     plaintext, | ||||||
|  | 		"signature": signature, | ||||||
|  | 	} | ||||||
|  | 	resp, err = b.HandleRequest(verifyReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: err: %v\nresp: %#v", err, resp) | ||||||
|  | 	} | ||||||
|  | 	if resp.Data["valid"].(bool) { | ||||||
|  | 		t.Fatalf("expected validation to fail") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	verifyReq.Data = map[string]interface{}{ | ||||||
|  | 		"input":     plaintext, | ||||||
|  | 		"signature": signature, | ||||||
|  | 		"algorithm": "sha2-512", | ||||||
|  | 	} | ||||||
|  | 	resp, err = b.HandleRequest(verifyReq) | ||||||
|  | 	if err != nil || (resp != nil && resp.IsError()) { | ||||||
|  | 		t.Fatalf("bad: err: %v\nresp: %#v", err, resp) | ||||||
|  | 	} | ||||||
|  | 	if !resp.Data["valid"].(bool) { | ||||||
|  | 		t.Fatalf("failed to verify the RSA signature") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestBackend_basic(t *testing.T) { | func TestBackend_basic(t *testing.T) { | ||||||
| 	decryptData := make(map[string]interface{}) | 	decryptData := make(map[string]interface{}) | ||||||
| 	logicaltest.Test(t, logicaltest.TestCase{ | 	logicaltest.Test(t, logicaltest.TestCase{ | ||||||
| @@ -634,7 +819,7 @@ func TestKeyUpgrade(t *testing.T) { | |||||||
| 	if p.Key != nil || | 	if p.Key != nil || | ||||||
| 		p.Keys == nil || | 		p.Keys == nil || | ||||||
| 		len(p.Keys) != 1 || | 		len(p.Keys) != 1 || | ||||||
| 		!reflect.DeepEqual(p.Keys[1].Key, key) { | 		!reflect.DeepEqual(p.Keys[strconv.Itoa(1)].Key, key) { | ||||||
| 		t.Errorf("bad key migration, result is %#v", p.Keys) | 		t.Errorf("bad key migration, result is %#v", p.Keys) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -1091,3 +1276,38 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) { | |||||||
| 	// Wait for them all to finish | 	// Wait for them all to finish | ||||||
| 	wg.Wait() | 	wg.Wait() | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestBadInput(t *testing.T) { | ||||||
|  | 	var b *backend | ||||||
|  | 	sysView := logical.TestSystemView() | ||||||
|  | 	storage := &logical.InmemStorage{} | ||||||
|  |  | ||||||
|  | 	b = Backend(&logical.BackendConfig{ | ||||||
|  | 		StorageView: storage, | ||||||
|  | 		System:      sysView, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	req := &logical.Request{ | ||||||
|  | 		Storage:   storage, | ||||||
|  | 		Operation: logical.UpdateOperation, | ||||||
|  | 		Path:      "keys/test", | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	resp, err := b.HandleRequest(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if resp != nil { | ||||||
|  | 		t.Fatal("expected nil response") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	req.Path = "decrypt/test" | ||||||
|  | 	req.Data = map[string]interface{}{ | ||||||
|  | 		"ciphertext": "vault:v1:abcd", | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, err = b.HandleRequest(req) | ||||||
|  | 	if err == nil { | ||||||
|  | 		t.Fatal("expected error") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	 Jeff Mitchell
					Jeff Mitchell