diff --git a/command/commands.go b/command/commands.go index 5ded81cd72..9b6533afaf 100644 --- a/command/commands.go +++ b/command/commands.go @@ -52,6 +52,7 @@ import ( physManta "github.com/hashicorp/vault/physical/manta" physMSSQL "github.com/hashicorp/vault/physical/mssql" physMySQL "github.com/hashicorp/vault/physical/mysql" + physOCI "github.com/hashicorp/vault/physical/oci" physPostgreSQL "github.com/hashicorp/vault/physical/postgresql" physRaft "github.com/hashicorp/vault/physical/raft" physS3 "github.com/hashicorp/vault/physical/s3" @@ -142,6 +143,7 @@ var ( "manta": physManta.NewMantaBackend, "mssql": physMSSQL.NewMSSQLBackend, "mysql": physMySQL.NewMySQLBackend, + "oci": physOCI.NewBackend, "postgresql": physPostgreSQL.NewPostgreSQLBackend, "s3": physS3.NewS3Backend, "spanner": physSpanner.NewBackend, diff --git a/go.mod b/go.mod index 47a87735d7..b96b8de26b 100644 --- a/go.mod +++ b/go.mod @@ -101,6 +101,7 @@ require ( github.com/ncw/swift v1.0.47 github.com/oklog/run v1.0.0 github.com/onsi/ginkgo v1.7.0 // indirect + github.com/oracle/oci-go-sdk v5.15.0+incompatible github.com/ory/dockertest v3.3.4+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.8.1 diff --git a/go.sum b/go.sum index adab966a4d..fb50d94fba 100644 --- a/go.sum +++ b/go.sum @@ -91,6 +91,7 @@ github.com/circonus-labs/circonus-gometrics v2.3.1+incompatible h1:C29Ae4G5GtYyY github.com/circonus-labs/circonus-gometrics v2.3.1+incompatible/go.mod h1:nmEj6Dob7S7YxXgwXpfOuvO54S+tGdZdw9fuRZt25Ag= github.com/circonus-labs/circonusllhist v0.1.3 h1:TJH+oke8D16535+jHExHj4nQvzlZrj7ug5D7I/orNUA= github.com/circonus-labs/circonusllhist v0.1.3/go.mod h1:kMXHVDlOchFAehlya5ePtbp5jckzBHf4XRpQvBOLI+I= +github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudfoundry-community/go-cfclient v0.0.0-20190201205600-f136f9222381 h1:rdRS5BT13Iae9ssvcslol66gfOOXjaLYwqerEn/cl9s= github.com/cloudfoundry-community/go-cfclient v0.0.0-20190201205600-f136f9222381/go.mod h1:e5+USP2j8Le2M0Jo3qKPFnNhuo1wueU4nWHCXBOfQ14= @@ -435,6 +436,7 @@ github.com/mitchellh/go-testing-interface v1.0.0 h1:fzU/JVNcaqHQEcVFAKeR41fkiLdI github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= github.com/mitchellh/go-wordwrap v1.0.0/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo= github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= +github.com/mitchellh/iochan v1.0.0 h1:C+X3KsSTLFVBr/tK1eYN/vs4rJcvsiLU338UhYPJWeY= github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.1.2 h1:fmNYVwqnSfB9mZU6OS2O6GsXM+wcskZDuKQzvN1EDeE= @@ -473,6 +475,8 @@ github.com/opencontainers/runc v0.1.1/go.mod h1:qT5XzbpPznkRYVz/mWwUaVBUv2rmF59P github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= github.com/openzipkin/zipkin-go v0.1.3/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw= +github.com/oracle/oci-go-sdk v5.15.0+incompatible h1:rTlmaWEe255HczQJ2uOPM9xw3prU9jNk5GxPy+RFi3k= +github.com/oracle/oci-go-sdk v5.15.0+incompatible/go.mod h1:VQb79nF8Z2cwLkLS35ukwStZIg5F66tcBccjip/j888= github.com/ory/dockertest 
v3.3.4+incompatible h1:VrpM6Gqg7CrPm3bL4Wm1skO+zFWLbh7/Xb5kGEbJRh8= github.com/ory/dockertest v3.3.4+incompatible/go.mod h1:1vX4m9wsvi00u5bseYwXaSnhNrne+V0E6LAcBILJdPs= github.com/oxtoacart/bpool v0.0.0-20150712133111-4e1c5567d7c2 h1:CXwSGu/LYmbjEab5aMCs5usQRVBGThelUKBNnoSOuso= @@ -693,6 +697,7 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190328211700-ab21143f2384 h1:TFlARGu6Czu1z7q93HTxcP1P+/ZFC/IKythI5RzrnRg= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= diff --git a/physical/oci/oci.go b/physical/oci/oci.go new file mode 100644 index 0000000000..e470602008 --- /dev/null +++ b/physical/oci/oci.go @@ -0,0 +1,378 @@ +// Copyright © 2019, Oracle and/or its affiliates. +package oci + +import ( + "bytes" + "errors" + "fmt" + "github.com/armon/go-metrics" + "github.com/hashicorp/errwrap" + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/sdk/helper/strutil" + "github.com/hashicorp/vault/sdk/physical" + "github.com/oracle/oci-go-sdk/common" + "github.com/oracle/oci-go-sdk/common/auth" + "github.com/oracle/oci-go-sdk/objectstorage" + "golang.org/x/net/context" + "io/ioutil" + "net/http" + "sort" + "strconv" + "strings" + "time" +) + +// Verify Backend satisfies the correct interfaces +var _ physical.Backend = (*Backend)(nil) + +const ( + // Limits maximum outstanding requests + MaxNumberOfPermits = 256 +) + +var ( + metricDelete = []string{"oci", "delete"} + metricGet = []string{"oci", "get"} + metricList = []string{"oci", "list"} + metricPut = []string{"oci", "put"} + metricDeleteFull = []string{"oci", "deleteFull"} + metricGetFull = []string{"oci", "getFull"} + metricListFull = []string{"oci", "listFull"} + metricPutFull = []string{"oci", "putFull"} + + metricDeleteHa = []string{"oci", "deleteHa"} + metricGetHa = []string{"oci", "getHa"} + metricPutHa = []string{"oci", "putHa"} + + metricDeleteAcquirePool = []string{"oci", "deleteAcquirePool"} + metricGetAcquirePool = []string{"oci", "getAcquirePool"} + metricListAcquirePool = []string{"oci", "listAcquirePool"} + metricPutAcquirePool = []string{"oci", "putAcquirePool"} + + metricDeleteFailed = []string{"oci", "deleteFailed"} + metricGetFailed = []string{"oci", "getFailed"} + metricListFailed = []string{"oci", "listFailed"} + metricPutFailed = []string{"oci", "putFailed"} + metricHaWatchLockRetriable = []string{"oci", "haWatchLockRetriable"} + metricPermitsUsed = []string{"oci", "permitsUsed"} + + metric5xx = []string{"oci", "5xx"} +) + +type Backend struct { + client *objectstorage.ObjectStorageClient + bucketName string + logger log.Logger + permitPool *physical.PermitPool + namespaceName string + haEnabled bool + lockBucketName string +} + +func NewBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { + bucketName := conf["bucket_name"] + if bucketName == "" { + return nil, errors.New("missing bucket 
name") + } + + namespaceName := conf["namespace_name"] + if bucketName == "" { + return nil, errors.New("missing namespace name") + } + + lockBucketName := "" + haEnabled := false + var err error + haEnabledStr := conf["ha_enabled"] + if haEnabledStr != "" { + haEnabled, err = strconv.ParseBool(haEnabledStr) + if err != nil { + return nil, errwrap.Wrapf("failed to parse HA enabled: {{err}}", err) + } + + if haEnabled { + lockBucketName = conf["lock_bucket_name"] + if lockBucketName == "" { + return nil, errors.New("missing lock bucket name") + } + } + } + + authTypeAPIKeyBool := false + authTypeAPIKeyStr := conf["auth_type_api_key"] + if authTypeAPIKeyStr != "" { + authTypeAPIKeyBool, err = strconv.ParseBool(authTypeAPIKeyStr) + if err != nil { + return nil, errwrap.Wrapf("failed parsing auth_type_api_key parameter: {{err}}", err) + } + } + + var cp common.ConfigurationProvider + if authTypeAPIKeyBool { + cp = common.DefaultConfigProvider() + } else { + cp, err = auth.InstancePrincipalConfigurationProvider() + if err != nil { + return nil, errwrap.Wrapf("failed creating InstancePrincipalConfigurationProvider: {{err}}", err) + } + } + + objectStorageClient, err := objectstorage.NewObjectStorageClientWithConfigurationProvider(cp) + if err != nil { + return nil, errwrap.Wrapf("failed creating NewObjectStorageClientWithConfigurationProvider: {{err}}", err) + } + + logger.Debug("configuration", + "bucket_name", bucketName, + "namespace_name", namespaceName, + "ha_enabled", haEnabled, + "lock_bucket_name", lockBucketName, + "auth_type_api_key", authTypeAPIKeyBool, + ) + + return &Backend{ + client: &objectStorageClient, + bucketName: bucketName, + logger: logger, + permitPool: physical.NewPermitPool(MaxNumberOfPermits), + namespaceName: namespaceName, + haEnabled: haEnabled, + lockBucketName: lockBucketName, + }, nil +} + +func (o *Backend) Put(ctx context.Context, entry *physical.Entry) error { + o.logger.Debug("PUT started") + defer metrics.MeasureSince(metricPutFull, time.Now()) + startAcquirePool := time.Now() + metrics.SetGauge(metricPermitsUsed, float32(o.permitPool.CurrentPermits())) + o.permitPool.Acquire() + defer o.permitPool.Release() + metrics.MeasureSince(metricPutAcquirePool, startAcquirePool) + + defer metrics.MeasureSince(metricPut, time.Now()) + size := int64(len(entry.Value)) + opcClientRequestId, err := uuid.GenerateUUID() + if err != nil { + metrics.IncrCounter(metricPutFailed, 1) + o.logger.Error("failed to generate UUID") + return errwrap.Wrapf("failed to generate UUID: {{err}}", err) + } + + o.logger.Debug("PUT", "opc-client-request-id", opcClientRequestId) + request := objectstorage.PutObjectRequest{ + NamespaceName: &o.namespaceName, + BucketName: &o.bucketName, + ObjectName: &entry.Key, + ContentLength: &size, + PutObjectBody: ioutil.NopCloser(bytes.NewReader(entry.Value)), + OpcMeta: nil, + OpcClientRequestId: &opcClientRequestId, + } + + resp, err := o.client.PutObject(ctx, request) + if resp.RawResponse != nil && resp.RawResponse.Body != nil { + defer resp.RawResponse.Body.Close() + } + + if err != nil { + metrics.IncrCounter(metricPutFailed, 1) + return errwrap.Wrapf("failed to put data: {{err}}", err) + } + + o.logRequest("PUT", resp.RawResponse, resp.OpcClientRequestId, resp.OpcRequestId, err) + o.logger.Debug("PUT completed") + + return nil +} + +func (o *Backend) Get(ctx context.Context, key string) (*physical.Entry, error) { + o.logger.Debug("GET started") + defer metrics.MeasureSince(metricGetFull, time.Now()) + metrics.SetGauge(metricPermitsUsed, 
float32(o.permitPool.CurrentPermits())) + startAcquirePool := time.Now() + o.permitPool.Acquire() + defer o.permitPool.Release() + metrics.MeasureSince(metricGetAcquirePool, startAcquirePool) + + defer metrics.MeasureSince(metricGet, time.Now()) + opcClientRequestId, err := uuid.GenerateUUID() + if err != nil { + o.logger.Error("failed to generate UUID") + return nil, errwrap.Wrapf("failed to generate UUID: {{err}}", err) + } + o.logger.Debug("GET", "opc-client-request-id", opcClientRequestId) + request := objectstorage.GetObjectRequest{ + NamespaceName: &o.namespaceName, + BucketName: &o.bucketName, + ObjectName: &key, + OpcClientRequestId: &opcClientRequestId, + } + + resp, err := o.client.GetObject(ctx, request) + if resp.RawResponse != nil && resp.RawResponse.Body != nil { + defer resp.RawResponse.Body.Close() + } + o.logRequest("GET", resp.RawResponse, resp.OpcClientRequestId, resp.OpcRequestId, err) + + if err != nil { + if resp.RawResponse != nil && resp.RawResponse.StatusCode == http.StatusNotFound { + return nil, nil + } + metrics.IncrCounter(metricGetFailed, 1) + return nil, errwrap.Wrapf(fmt.Sprintf("failed to read Value: {{err}}"), err) + } + + body, err := ioutil.ReadAll(resp.Content) + if err != nil { + metrics.IncrCounter(metricGetFailed, 1) + return nil, errwrap.Wrapf("failed to decode Value into bytes: {{err}}", err) + } + + o.logger.Debug("GET completed") + + return &physical.Entry{ + Key: key, + Value: body, + }, nil +} + +func (o *Backend) Delete(ctx context.Context, key string) error { + o.logger.Debug("DELETE started") + defer metrics.MeasureSince(metricDeleteFull, time.Now()) + metrics.SetGauge(metricPermitsUsed, float32(o.permitPool.CurrentPermits())) + startAcquirePool := time.Now() + o.permitPool.Acquire() + defer o.permitPool.Release() + metrics.MeasureSince(metricDeleteAcquirePool, startAcquirePool) + + defer metrics.MeasureSince(metricDelete, time.Now()) + opcClientRequestId, err := uuid.GenerateUUID() + if err != nil { + o.logger.Error("Delete: error generating UUID") + return errwrap.Wrapf("failed to generate UUID: {{err}}", err) + } + o.logger.Debug("Delete", "opc-client-request-id", opcClientRequestId) + request := objectstorage.DeleteObjectRequest{ + NamespaceName: &o.namespaceName, + BucketName: &o.bucketName, + ObjectName: &key, + OpcClientRequestId: &opcClientRequestId, + } + + resp, err := o.client.DeleteObject(ctx, request) + if resp.RawResponse != nil && resp.RawResponse.Body != nil { + defer resp.RawResponse.Body.Close() + } + + o.logRequest("DELETE", resp.RawResponse, resp.OpcClientRequestId, resp.OpcRequestId, err) + + if err != nil { + if resp.RawResponse != nil && resp.RawResponse.StatusCode == http.StatusNotFound { + return nil + } + metrics.IncrCounter(metricDeleteFailed, 1) + return errwrap.Wrapf("failed to delete Key: {{err}}", err) + } + o.logger.Debug("DELETE completed") + + return nil +} + +func (o *Backend) List(ctx context.Context, prefix string) ([]string, error) { + o.logger.Debug("LIST started") + defer metrics.MeasureSince(metricListFull, time.Now()) + metrics.SetGauge(metricPermitsUsed, float32(o.permitPool.CurrentPermits())) + startAcquirePool := time.Now() + o.permitPool.Acquire() + defer o.permitPool.Release() + + metrics.MeasureSince(metricListAcquirePool, startAcquirePool) + defer metrics.MeasureSince(metricList, time.Now()) + var keys []string + delimiter := "/" + var start *string + + for { + opcClientRequestId, err := uuid.GenerateUUID() + if err != nil { + o.logger.Error("List: error generating UUID") + return nil, 
errwrap.Wrapf("failed to generate UUID {{err}}", err) + } + o.logger.Debug("LIST", "opc-client-request-id", opcClientRequestId) + request := objectstorage.ListObjectsRequest{ + NamespaceName: &o.namespaceName, + BucketName: &o.bucketName, + Prefix: &prefix, + Delimiter: &delimiter, + Start: start, + OpcClientRequestId: &opcClientRequestId, + } + + resp, err := o.client.ListObjects(ctx, request) + o.logRequest("LIST", resp.RawResponse, resp.OpcClientRequestId, resp.OpcRequestId, err) + + if err != nil { + metrics.IncrCounter(metricListFailed, 1) + return nil, errwrap.Wrapf("failed to list using prefix: {{err}}", err) + } + + for _, commonPrefix := range resp.Prefixes { + commonPrefix := strings.TrimPrefix(commonPrefix, prefix) + keys = append(keys, commonPrefix) + } + + for _, object := range resp.Objects { + key := strings.TrimPrefix(*object.Name, prefix) + keys = append(keys, key) + } + + // Duplicate keys are not expected + keys = strutil.RemoveDuplicates(keys, false) + + if resp.NextStartWith == nil { + resp.RawResponse.Body.Close() + break + } + + start = resp.NextStartWith + resp.RawResponse.Body.Close() + } + + sort.Strings(keys) + o.logger.Debug("LIST completed") + return keys, nil +} + +func (o *Backend) logRequest(operation string, response *http.Response, clientOpcRequestIdPtr *string, opcRequestIdPtr *string, err error) { + statusCode := 0 + clientOpcRequestId := " " + opcRequestId := " " + + if response != nil { + statusCode = response.StatusCode + if statusCode/100 == 5 { + metrics.IncrCounter(metric5xx, 1) + } + } + + if clientOpcRequestIdPtr != nil { + clientOpcRequestId = *clientOpcRequestIdPtr + } + + if opcRequestIdPtr != nil { + opcRequestId = *opcRequestIdPtr + } + + statusCodeStr := "No response" + if statusCode != 0 { + statusCodeStr = strconv.Itoa(statusCode) + } + + logLine := fmt.Sprintf("%s client:opc-request-id %s opc-request-id: %s status-code: %s", + operation, clientOpcRequestId, opcRequestId, statusCodeStr) + if err != nil && statusCode/100 == 5 { + o.logger.Error(logLine, "error", err) + } +} diff --git a/physical/oci/oci_ha.go b/physical/oci/oci_ha.go new file mode 100644 index 0000000000..a053a0ea48 --- /dev/null +++ b/physical/oci/oci_ha.go @@ -0,0 +1,549 @@ +// Copyright © 2019, Oracle and/or its affiliates. +package oci + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "github.com/armon/go-metrics" + "github.com/hashicorp/errwrap" + "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/sdk/physical" + "github.com/oracle/oci-go-sdk/objectstorage" + "io/ioutil" + "net/http" + "sync" + "sync/atomic" + "time" +) + +// The lock implementation below prioritizes ensuring that there are not 2 primary at any given point in time +// over high availability of the primary instance + +// Verify Backend satisfies the correct interfaces +var _ physical.HABackend = (*Backend)(nil) +var _ physical.Lock = (*Lock)(nil) + +const ( + // LockRenewInterval is the time to wait between lock renewals. + LockRenewInterval = 3 * time.Second + + // LockRetryInterval is the amount of time to wait if the lock fails before trying again. + LockRetryInterval = 5 * time.Second + + // LockWatchRetryInterval is the amount of time to wait if a watch fails before trying again. + LockWatchRetryInterval = 2 * time.Second + + // LockTTL is the default lock TTL. + LockTTL = 15 * time.Second + + // LockWatchRetryMax is the number of times to retry a failed watch before signaling that leadership is lost. 
+	LockWatchRetryMax = 4
+
+	// LockCacheMinAcceptableAge is the minimum cache age required before it is safe
+	// for a secondary instance to attempt to acquire the lock.
+	LockCacheMinAcceptableAge = 45 * time.Second
+
+	// LockWriteRetriesOnFailures is the number of retries that are made on write 5xx failures.
+	LockWriteRetriesOnFailures = 4
+
+	ObjectStorageCallsReadTimeout = 3 * time.Second
+
+	ObjectStorageCallsWriteTimeout = 3 * time.Second
+)
+
+type LockCache struct {
+	// ETag values are unique identifiers generated by the OCI service and changed every time the object is modified.
+	etag       string
+	lastUpdate time.Time
+	lockRecord *LockRecord
+}
+
+type Lock struct {
+	// backend is the underlying physical backend.
+	backend *Backend
+
+	// key is the name of the key, and value is its stored value.
+	key, value string
+
+	// held is a boolean indicating if the lock is currently held.
+	held bool
+
+	// identity is the internal identity of this lock holder (unique to this server instance).
+	identity string
+
+	internalLock sync.Mutex
+
+	// stopCh is the channel that stops all operations. It may be closed in the
+	// event of a leader loss or graceful shutdown. stopped is a boolean
+	// indicating if we are stopped - it exists to prevent double closing the
+	// channel. stopLock is a mutex protecting stopCh and stopped.
+	stopCh   chan struct{}
+	stopped  bool
+	stopLock sync.Mutex
+
+	lockRecordCache atomic.Value
+
+	// Allow modifying the Lock durations for ease of unit testing.
+	renewInterval      time.Duration
+	retryInterval      time.Duration
+	ttl                time.Duration
+	watchRetryInterval time.Duration
+	watchRetryMax      int
+}
+
+type LockRecord struct {
+	Key      string
+	Value    string
+	Identity string
+}
+
+var (
+	metricLockUnlock  = []string{"oci", "lock", "unlock"}
+	metricLockLock    = []string{"oci", "lock", "lock"}
+	metricLockValue   = []string{"oci", "lock", "Value"}
+	metricLeaderValue = []string{"oci", "leader", "Value"}
+)
+
+func (b *Backend) HAEnabled() bool {
+	return b.haEnabled
+}
+
+// LockWith is used for mutual exclusion based on the given key.
+func (b *Backend) LockWith(key, value string) (physical.Lock, error) {
+	identity, err := uuid.GenerateUUID()
+	if err != nil {
+		return nil, errwrap.Wrapf("Lock with: {{err}}", err)
+	}
+	return &Lock{
+		backend:  b,
+		key:      key,
+		value:    value,
+		identity: identity,
+		stopped:  true,
+
+		renewInterval:      LockRenewInterval,
+		retryInterval:      LockRetryInterval,
+		ttl:                LockTTL,
+		watchRetryInterval: LockWatchRetryInterval,
+		watchRetryMax:      LockWatchRetryMax,
+	}, nil
+}
+
+func (l *Lock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
+	l.backend.logger.Debug("Lock() called")
+	defer metrics.MeasureSince(metricLockLock, time.Now().UTC())
+	l.internalLock.Lock()
+	defer l.internalLock.Unlock()
+	if l.held {
+		return nil, errors.New("lock already held")
+	}
+
+	// Attempt to lock - this function blocks until a lock is acquired or an error
+	// occurs.
+	acquired, err := l.attemptLock(stopCh)
+	if err != nil {
+		return nil, errwrap.Wrapf("lock: {{err}}", err)
+	}
+	if !acquired {
+		return nil, nil
+	}
+
+	// We have the lock now
+	l.held = true
+
+	// Build the locks
+	l.stopLock.Lock()
+	l.stopCh = make(chan struct{})
+	l.stopped = false
+	l.stopLock.Unlock()
+
+	// Periodically renew and watch the lock
+	go l.renewLock()
+	go l.watchLock()
+
+	return l.stopCh, nil
+}
+
+// attemptLock attempts to acquire a lock. If the given channel is closed, the
+// acquisition attempt stops. This function returns when a lock is acquired,
+// an error occurs, or the stop channel is closed.
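+// Acquisition is driven by a ticker, so the first attempt happens only after
+// retryInterval has elapsed.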
+func (l *Lock) attemptLock(stopCh <-chan struct{}) (bool, error) { + l.backend.logger.Debug("AttemptLock() called") + ticker := time.NewTicker(l.retryInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + acquired, err := l.writeLock() + if err != nil { + return false, errwrap.Wrapf("attempt lock: {{err}}", err) + } + if !acquired { + continue + } + + return true, nil + case <-stopCh: + return false, nil + } + } +} + +// renewLock renews the given lock until the channel is closed. +func (l *Lock) renewLock() { + l.backend.logger.Debug("RenewLock() called") + ticker := time.NewTicker(l.renewInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + l.writeLock() + case <-l.stopCh: + return + } + } +} + +func loadLockRecordCache(l *Lock) *LockCache { + lockRecordCache := l.lockRecordCache.Load() + if lockRecordCache == nil { + return nil + } + return lockRecordCache.(*LockCache) +} + +// watchLock checks whether the lock has changed in the table and closes the +// leader channel accordingly. If an error occurs during the check, watchLock +// will retry the operation and then close the leader channel if it can't +// succeed after retries. +func (l *Lock) watchLock() { + l.backend.logger.Debug("WatchLock() called") + retries := 0 + ticker := time.NewTicker(l.watchRetryInterval) + defer ticker.Stop() + +OUTER: + for { + // Check if the channel is already closed + select { + case <-l.stopCh: + l.backend.logger.Debug("WatchLock():Stop lock signaled/closed.") + break OUTER + default: + } + + // Check if we've exceeded retries + if retries >= l.watchRetryMax-1 { + l.backend.logger.Debug("WatchLock: Failed to get lock data from object storage. Giving up the lease after max retries") + break OUTER + } + + // Wait for the timer + select { + case <-ticker.C: + case <-l.stopCh: + break OUTER + } + + lockRecordCache := loadLockRecordCache(l) + if (lockRecordCache == nil) || + (lockRecordCache.lockRecord == nil) || + (lockRecordCache.lockRecord.Identity != l.identity) || + (time.Now().Sub(lockRecordCache.lastUpdate) > l.ttl) { + l.backend.logger.Debug("WatchLock: Lock record cache is nil, stale or does not belong to self.") + break OUTER + } + + lockRecord, _, err := l.get(context.Background()) + if err != nil { + retries++ + l.backend.logger.Debug("WatchLock: Failed to get lock data from object storage. 
Retrying..") + metrics.SetGauge(metricHaWatchLockRetriable, 1) + continue + } + + if (lockRecord == nil) || (lockRecord.Identity != l.identity) { + l.backend.logger.Debug("WatchLock: Lock record cache is nil or does not belong to self.") + break OUTER + } + + // reset retries counter on success + retries = 0 + l.backend.logger.Debug("WatchLock() successful") + metrics.SetGauge(metricHaWatchLockRetriable, 0) + } + + l.stopLock.Lock() + defer l.stopLock.Unlock() + if !l.stopped { + l.stopped = true + l.backend.logger.Debug("Closing the stop channel to give up leadership.") + close(l.stopCh) + } +} + +func (l *Lock) Unlock() error { + l.backend.logger.Debug("Unlock() called") + defer metrics.MeasureSince(metricLockUnlock, time.Now().UTC()) + + l.internalLock.Lock() + defer l.internalLock.Unlock() + if !l.held { + return nil + } + + // Stop any existing locking or renewal attempts + l.stopLock.Lock() + if !l.stopped { + l.stopped = true + close(l.stopCh) + } + l.stopLock.Unlock() + + // We are no longer holding the lock + l.held = false + + // Get current lock record + currentLockRecord, etag, err := l.get(context.Background()) + if err != nil { + return errwrap.Wrapf("error reading lock record: {{err}}", err) + } + + if currentLockRecord != nil && currentLockRecord.Identity == l.identity { + + defer metrics.MeasureSince(metricDeleteHa, time.Now()) + opcClientRequestId, err := uuid.GenerateUUID() + if err != nil { + l.backend.logger.Debug("Unlock: error generating UUID") + return errwrap.Wrapf("failed to generate UUID: {{err}}", err) + } + l.backend.logger.Debug("Unlock", "opc-client-request-id", opcClientRequestId) + request := objectstorage.DeleteObjectRequest{ + NamespaceName: &l.backend.namespaceName, + BucketName: &l.backend.lockBucketName, + ObjectName: &l.key, + IfMatch: &etag, + OpcClientRequestId: &opcClientRequestId, + } + + response, err := l.backend.client.DeleteObject(context.Background(), request) + l.backend.logRequest("deleteHA", response.RawResponse, response.OpcClientRequestId, response.OpcRequestId, err) + + if err != nil { + metrics.IncrCounter(metricDeleteFailed, 1) + return errwrap.Wrapf("write lock: {{err}}", err) + } + } + + return nil +} + +func (l *Lock) Value() (bool, string, error) { + l.backend.logger.Debug("Value() called") + defer metrics.MeasureSince(metricLockValue, time.Now().UTC()) + + lockRecord, _, err := l.get(context.Background()) + if err != nil { + return false, "", err + } + if lockRecord == nil { + return false, "", err + } + return true, lockRecord.Value, nil +} + +// get retrieves the Value for the lock. 
+func (l *Lock) get(ctx context.Context) (*LockRecord, string, error) {
+	l.backend.logger.Debug("get() called")
+
+	// Read lock Key
+
+	defer metrics.MeasureSince(metricGetHa, time.Now())
+	opcClientRequestId, err := uuid.GenerateUUID()
+	if err != nil {
+		l.backend.logger.Error("getHa: error generating UUID")
+		return nil, "", errwrap.Wrapf("failed to generate UUID: {{err}}", err)
+	}
+	l.backend.logger.Debug("getHa", "opc-client-request-id", opcClientRequestId)
+
+	request := objectstorage.GetObjectRequest{
+		NamespaceName:      &l.backend.namespaceName,
+		BucketName:         &l.backend.lockBucketName,
+		ObjectName:         &l.key,
+		OpcClientRequestId: &opcClientRequestId,
+	}
+
+	ctx, cancel := context.WithTimeout(ctx, ObjectStorageCallsReadTimeout)
+	defer cancel()
+
+	response, err := l.backend.client.GetObject(ctx, request)
+	l.backend.logRequest("getHA", response.RawResponse, response.OpcClientRequestId, response.OpcRequestId, err)
+
+	if err != nil {
+		if response.RawResponse != nil && response.RawResponse.StatusCode == http.StatusNotFound {
+			return nil, "", nil
+		}
+
+		metrics.IncrCounter(metricGetFailed, 1)
+		l.backend.logger.Error("Error calling GET", "err", err)
+		return nil, "", errwrap.Wrapf(fmt.Sprintf("failed to read Value for %q: {{err}}", l.key), err)
+	}
+
+	defer response.RawResponse.Body.Close()
+
+	body, err := ioutil.ReadAll(response.Content)
+	if err != nil {
+		metrics.IncrCounter(metricGetFailed, 1)
+		l.backend.logger.Error("Error reading content", "err", err)
+		return nil, "", errwrap.Wrapf("failed to decode Value into bytes: {{err}}", err)
+	}
+
+	var lockRecord LockRecord
+	err = json.Unmarshal(body, &lockRecord)
+	if err != nil {
+		metrics.IncrCounter(metricGetFailed, 1)
+		l.backend.logger.Error("Error unmarshalling content", "err", err)
+		return nil, "", errwrap.Wrapf(fmt.Sprintf("failed to parse Value for %q: {{err}}", l.key), err)
+	}
+
+	return &lockRecord, *response.ETag, nil
+}
+
+func (l *Lock) writeLock() (bool, error) {
+	l.backend.logger.Debug("WriteLock() called")
+
+	// Create a cancelable context for the read and the (possible) update.
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	// The operation will be retried, and it could sit in a queue behind, say,
+	// the delete operation. To stop it, we cancel the context when the
+	// associated stopCh is closed.
+	go func() {
+		select {
+		case <-l.stopCh:
+			cancel()
+		case <-ctx.Done():
+		}
+	}()
+
+	lockRecordCache := loadLockRecordCache(l)
+	if (lockRecordCache == nil) || lockRecordCache.lockRecord == nil ||
+		lockRecordCache.lockRecord.Identity != l.identity ||
+		time.Now().Sub(lockRecordCache.lastUpdate) > l.ttl {
+		// case secondary
+		currentLockRecord, currentEtag, err := l.get(ctx)
+		if err != nil {
+			return false, errwrap.Wrapf("error reading lock record: {{err}}", err)
+		}
+
+		if (lockRecordCache == nil) || lockRecordCache.etag != currentEtag {
+			// update cached lock record
+			l.lockRecordCache.Store(&LockCache{
+				etag:       currentEtag,
+				lastUpdate: time.Now().UTC(),
+				lockRecord: currentLockRecord,
+			})
+
+			lockRecordCache = loadLockRecordCache(l)
+		}
+
+		// Current lock record being null implies that there is no leader. In this case we want to try acquiring the lock.
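+		// A non-nil record younger than LockCacheMinAcceptableAge, by contrast, is treated as a
+		// live leader, so this instance backs off instead of competing for the lock.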
+		if currentLockRecord != nil && time.Now().Sub(lockRecordCache.lastUpdate) < LockCacheMinAcceptableAge {
+			return false, nil
+		}
+		// cache is old enough and current, try acquiring lock as secondary
+	}
+
+	newLockRecord := &LockRecord{
+		Key:      l.key,
+		Value:    l.value,
+		Identity: l.identity,
+	}
+
+	newLockRecordJson, err := json.Marshal(newLockRecord)
+	if err != nil {
+		return false, errwrap.Wrapf("error marshalling lock record: {{err}}", err)
+	}
+
+	defer metrics.MeasureSince(metricPutHa, time.Now())
+
+	opcClientRequestId, err := uuid.GenerateUUID()
+	if err != nil {
+		l.backend.logger.Error("putHa: error generating UUID")
+		return false, errwrap.Wrapf("failed to generate UUID: {{err}}", err)
+	}
+	l.backend.logger.Debug("putHa", "opc-client-request-id", opcClientRequestId)
+	size := int64(len(newLockRecordJson))
+	putRequest := objectstorage.PutObjectRequest{
+		NamespaceName:      &l.backend.namespaceName,
+		BucketName:         &l.backend.lockBucketName,
+		ObjectName:         &l.key,
+		ContentLength:      &size,
+		PutObjectBody:      ioutil.NopCloser(bytes.NewReader(newLockRecordJson)),
+		OpcMeta:            nil,
+		OpcClientRequestId: &opcClientRequestId,
+	}
+
+	if lockRecordCache.etag == "" {
+		noneMatch := "*"
+		putRequest.IfNoneMatch = &noneMatch
+	} else {
+		putRequest.IfMatch = &lockRecordCache.etag
+	}
+
+	newEtag := ""
+	for i := 1; i <= LockWriteRetriesOnFailures; i++ {
+		writeCtx, writeCancel := context.WithTimeout(ctx, ObjectStorageCallsWriteTimeout)
+		defer writeCancel()
+
+		putObjectResponse, putObjectError := l.backend.client.PutObject(writeCtx, putRequest)
+		l.backend.logRequest("putHA", putObjectResponse.RawResponse, putObjectResponse.OpcClientRequestId, putObjectResponse.OpcRequestId, putObjectError)
+
+		if putObjectError == nil {
+			newEtag = *putObjectResponse.ETag
+			putObjectResponse.RawResponse.Body.Close()
+			// clear the error recorded by any earlier failed attempt now that the write has succeeded
+			err = nil
+			break
+		}
+
+		err = putObjectError
+
+		if putObjectResponse.RawResponse == nil {
+			metrics.IncrCounter(metricPutFailed, 1)
+			l.backend.logger.Error("PUT", "err", err)
+			break
+		}
+
+		putObjectResponse.RawResponse.Body.Close()
+
+		// Retry if the return code is 5xx
+		if (putObjectResponse.RawResponse.StatusCode / 100) == 5 {
+			metrics.IncrCounter(metricPutFailed, 1)
+			l.backend.logger.Warn("PUT. Retrying..", "err", err)
+			time.Sleep(time.Duration(100*i) * time.Millisecond)
+		} else {
+			l.backend.logger.Error("PUT", "err", err)
+			break
+		}
+	}
+
+	if err != nil {
+		return false, errwrap.Wrapf("write lock: {{err}}", err)
+	}
+
+	l.backend.logger.Debug("Lock written", "record", string(newLockRecordJson))
+
+	l.lockRecordCache.Store(&LockCache{
+		etag:       newEtag,
+		lastUpdate: time.Now().UTC(),
+		lockRecord: newLockRecord,
+	})
+
+	metrics.SetGauge(metricLeaderValue, 1)
+	return true, nil
+}
diff --git a/physical/oci/oci_ha_test.go b/physical/oci/oci_ha_test.go
new file mode 100644
index 0000000000..e3d5bc3e15
--- /dev/null
+++ b/physical/oci/oci_ha_test.go
@@ -0,0 +1,33 @@
+// Copyright © 2019, Oracle and/or its affiliates.
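+// These tests run against a real OCI tenancy and are skipped unless VAULT_ACC is set.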
+package oci + +import ( + "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/sdk/physical" + "github.com/oracle/oci-go-sdk/common" + "github.com/oracle/oci-go-sdk/objectstorage" + "os" + "testing" +) + +func TestOCIHABackend(t *testing.T) { + // Skip tests if we are not running acceptance tests + if os.Getenv("VAULT_ACC") == "" { + t.SkipNow() + } + bucketName, _ := uuid.GenerateUUID() + configProvider := common.DefaultConfigProvider() + objectStorageClient, _ := objectstorage.NewObjectStorageClientWithConfigurationProvider(configProvider) + namespaceName := getNamespaceName(objectStorageClient, t) + + createBucket(bucketName, getTenancyOcid(configProvider, t), namespaceName, objectStorageClient, t) + defer deleteBucket(namespaceName, bucketName, objectStorageClient, t) + + backend := createBackend(bucketName, namespaceName, "true", bucketName, t) + ha, ok := backend.(physical.HABackend) + if !ok { + t.Fatalf("does not implement") + } + + physical.ExerciseHABackend(t, ha, ha) +} diff --git a/physical/oci/oci_test.go b/physical/oci/oci_test.go new file mode 100644 index 0000000000..bf85f0b566 --- /dev/null +++ b/physical/oci/oci_test.go @@ -0,0 +1,88 @@ +// Copyright © 2019, Oracle and/or its affiliates. +package oci + +import ( + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/hashicorp/vault/sdk/physical" + "github.com/oracle/oci-go-sdk/common" + "github.com/oracle/oci-go-sdk/objectstorage" + "golang.org/x/net/context" + "os" + "testing" +) + +func TestOCIBackend(t *testing.T) { + // Skip tests if we are not running acceptance tests + if os.Getenv("VAULT_ACC") == "" { + t.SkipNow() + } + bucketName, _ := uuid.GenerateUUID() + configProvider := common.DefaultConfigProvider() + objectStorageClient, _ := objectstorage.NewObjectStorageClientWithConfigurationProvider(configProvider) + namespaceName := getNamespaceName(objectStorageClient, t) + + createBucket(bucketName, getTenancyOcid(configProvider, t), namespaceName, objectStorageClient, t) + defer deleteBucket(namespaceName, bucketName, objectStorageClient, t) + + backend := createBackend(bucketName, namespaceName, "false", "", t) + physical.ExerciseBackend(t, backend) + physical.ExerciseBackend_ListPrefix(t, backend) +} + +func createBucket(bucketName string, tenancyOcid string, namespaceName string, objectStorageClient objectstorage.ObjectStorageClient, t *testing.T) { + createBucketRequest := objectstorage.CreateBucketRequest{ + NamespaceName: &namespaceName, + } + createBucketRequest.CompartmentId = &tenancyOcid + createBucketRequest.Name = &bucketName + createBucketRequest.Metadata = make(map[string]string) + createBucketRequest.PublicAccessType = objectstorage.CreateBucketDetailsPublicAccessTypeNopublicaccess + _, err := objectStorageClient.CreateBucket(context.Background(), createBucketRequest) + if err != nil { + t.Fatalf("Failed to create bucket: %v", err) + } +} + +func deleteBucket(nameSpaceName string, bucketName string, objectStorageClient objectstorage.ObjectStorageClient, t *testing.T) { + request := objectstorage.DeleteBucketRequest{ + NamespaceName: &nameSpaceName, + BucketName: &bucketName, + } + _, err := objectStorageClient.DeleteBucket(context.Background(), request) + if err != nil { + t.Fatalf("Failed to delete bucket: %v", err) + } +} + +func getTenancyOcid(configProvider common.ConfigurationProvider, t *testing.T) (tenancyOcid string) { + tenancyOcid, err := configProvider.TenancyOCID() + if err != nil { + t.Fatalf("Failed 
to get tenancy ocid: %v", err) + } + return tenancyOcid +} + +func createBackend(bucketName string, namespaceName string, haEnabledStr string, lockBucketName string, t *testing.T) physical.Backend { + backend, err := NewBackend(map[string]string{ + "auth_type_api_key": "true", + "bucket_name": bucketName, + "namespace_name": namespaceName, + "ha_enabled": haEnabledStr, + "lock_bucket_name": lockBucketName, + }, logging.NewVaultLogger(log.Trace)) + if err != nil { + t.Fatalf("Failed to create new backend: %v", err) + } + return backend +} + +func getNamespaceName(objectStorageClient objectstorage.ObjectStorageClient, t *testing.T) string { + response, err := objectStorageClient.GetNamespace(context.Background(), objectstorage.GetNamespaceRequest{}) + if err != nil { + t.Fatalf("Failed to get namespaceName: %v", err) + } + nameSpaceName := *response.Value + return nameSpaceName +} diff --git a/sdk/physical/physical.go b/sdk/physical/physical.go index cb621282fb..ec98e90590 100644 --- a/sdk/physical/physical.go +++ b/sdk/physical/physical.go @@ -148,6 +148,11 @@ func (c *PermitPool) Release() { <-c.sem } +// Get number of requests in the permit pool +func (c *PermitPool) CurrentPermits() int { + return len(c.sem) +} + // Prefixes is a shared helper function returns all parent 'folders' for a // given vault key. // e.g. for 'foo/bar/baz', it returns ['foo', 'foo/bar'] diff --git a/vendor/github.com/hashicorp/vault/sdk/physical/physical.go b/vendor/github.com/hashicorp/vault/sdk/physical/physical.go index cb621282fb..ec98e90590 100644 --- a/vendor/github.com/hashicorp/vault/sdk/physical/physical.go +++ b/vendor/github.com/hashicorp/vault/sdk/physical/physical.go @@ -148,6 +148,11 @@ func (c *PermitPool) Release() { <-c.sem } +// Get number of requests in the permit pool +func (c *PermitPool) CurrentPermits() int { + return len(c.sem) +} + // Prefixes is a shared helper function returns all parent 'folders' for a // given vault key. // e.g. for 'foo/bar/baz', it returns ['foo', 'foo/bar']
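
Reviewer note: a minimal usage sketch of the new backend (hypothetical code, not part of this change; the bucket and namespace names below are placeholders). It constructs the backend with API-key auth, the same path the tests use, and round-trips an entry through Put and Get:

package main

import (
	"context"
	"fmt"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/physical/oci"
	"github.com/hashicorp/vault/sdk/helper/logging"
	"github.com/hashicorp/vault/sdk/physical"
)

func main() {
	// auth_type_api_key=true selects the local OCI SDK configuration
	// (e.g. ~/.oci/config) instead of instance-principal credentials.
	// With ha_enabled=true, lock_bucket_name would also be required.
	backend, err := oci.NewBackend(map[string]string{
		"bucket_name":       "my-vault-bucket", // placeholder
		"namespace_name":    "my-namespace",    // placeholder
		"auth_type_api_key": "true",
	}, logging.NewVaultLogger(log.Trace))
	if err != nil {
		panic(err)
	}

	ctx := context.Background()
	if err := backend.Put(ctx, &physical.Entry{Key: "foo/bar", Value: []byte("baz")}); err != nil {
		panic(err)
	}
	entry, err := backend.Get(ctx, "foo/bar")
	if err != nil {
		panic(err)
	}
	fmt.Println(string(entry.Value)) // prints "baz"
}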