mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-30 18:17:55 +00:00 
			
		
		
		
	 861454e0ed
			
		
	
	861454e0ed
	
	
	
		
			
			* Swap sdk/helper libs to go-secure-stdlib * Migrate to go-secure-stdlib reloadutil * Migrate to go-secure-stdlib kv-builder * Migrate to go-secure-stdlib gatedwriter
		
			
				
	
	
		
			261 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			261 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package cassandra
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"fmt"
 | |
| 	"strings"
 | |
| 
 | |
| 	"github.com/hashicorp/vault/sdk/helper/template"
 | |
| 
 | |
| 	"github.com/gocql/gocql"
 | |
| 	multierror "github.com/hashicorp/go-multierror"
 | |
| 	"github.com/hashicorp/go-secure-stdlib/strutil"
 | |
| 	dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
 | |
| 	"github.com/hashicorp/vault/sdk/database/helper/dbutil"
 | |
| )
 | |
| 
 | |
// Default CQL statements, used whenever a request supplies no statements of
// its own. The {{username}} and {{password}} placeholders are substituted
// before a statement is executed.
const (
	defaultUserCreationCQL   = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;`
	defaultUserDeletionCQL   = `DROP USER '{{username}}';`
	defaultChangePasswordCQL = `ALTER USER '{{username}}' WITH PASSWORD '{{password}}';`
)

// cassandraTypeName is the type name this plugin reports for itself.
const cassandraTypeName = "cassandra"

// defaultUserNameTemplate is the username template applied when the
// configuration does not provide a "username_template". As written, it joins
// the display name and role name (each truncated to 15), a 20-character
// random component and the unix time with underscores behind a "v_" prefix,
// then truncates the result to 100, replaces "-" with "_" and lowercases it.
const defaultUserNameTemplate = `{{ printf "v_%s_%s_%s_%s" (.DisplayName | truncate 15) (.RoleName | truncate 15) (random 20) (unix_time) | truncate 100 | replace "-" "_" | lowercase }}`
 | |
| 
 | |
| var _ dbplugin.Database = &Cassandra{}
 | |
| 
 | |
| // Cassandra is an implementation of Database interface
 | |
| type Cassandra struct {
 | |
| 	*cassandraConnectionProducer
 | |
| 
 | |
| 	usernameProducer template.StringTemplate
 | |
| }
 | |
| 
 | |
| // New returns a new Cassandra instance
 | |
| func New() (interface{}, error) {
 | |
| 	db := new()
 | |
| 	dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
 | |
| 
 | |
| 	return dbType, nil
 | |
| }
 | |
| 
 | |
| func new() *Cassandra {
 | |
| 	connProducer := &cassandraConnectionProducer{}
 | |
| 	connProducer.Type = cassandraTypeName
 | |
| 
 | |
| 	return &Cassandra{
 | |
| 		cassandraConnectionProducer: connProducer,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Type returns the TypeName for this backend
 | |
| func (c *Cassandra) Type() (string, error) {
 | |
| 	return cassandraTypeName, nil
 | |
| }
 | |
| 
 | |
| func (c *Cassandra) getConnection(ctx context.Context) (*gocql.Session, error) {
 | |
| 	session, err := c.Connection(ctx)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return session.(*gocql.Session), nil
 | |
| }
 | |
| 
 | |
| func (c *Cassandra) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
 | |
| 	usernameTemplate, err := strutil.GetString(req.Config, "username_template")
 | |
| 	if err != nil {
 | |
| 		return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve username_template: %w", err)
 | |
| 	}
 | |
| 	if usernameTemplate == "" {
 | |
| 		usernameTemplate = defaultUserNameTemplate
 | |
| 	}
 | |
| 
 | |
| 	up, err := template.NewTemplate(template.Template(usernameTemplate))
 | |
| 	if err != nil {
 | |
| 		return dbplugin.InitializeResponse{}, fmt.Errorf("unable to initialize username template: %w", err)
 | |
| 	}
 | |
| 	c.usernameProducer = up
 | |
| 
 | |
| 	_, err = c.usernameProducer.Generate(dbplugin.UsernameMetadata{})
 | |
| 	if err != nil {
 | |
| 		return dbplugin.InitializeResponse{}, fmt.Errorf("invalid username template: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	err = c.cassandraConnectionProducer.Initialize(ctx, req)
 | |
| 	if err != nil {
 | |
| 		return dbplugin.InitializeResponse{}, fmt.Errorf("failed to initialize: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	resp := dbplugin.InitializeResponse{
 | |
| 		Config: req.Config,
 | |
| 	}
 | |
| 	return resp, nil
 | |
| }
 | |
| 
 | |
| // NewUser generates the username/password on the underlying Cassandra secret backend as instructed by
 | |
| // the statements provided.
 | |
| func (c *Cassandra) NewUser(ctx context.Context, req dbplugin.NewUserRequest) (dbplugin.NewUserResponse, error) {
 | |
| 	c.Lock()
 | |
| 	defer c.Unlock()
 | |
| 
 | |
| 	session, err := c.getConnection(ctx)
 | |
| 	if err != nil {
 | |
| 		return dbplugin.NewUserResponse{}, err
 | |
| 	}
 | |
| 
 | |
| 	creationCQL := req.Statements.Commands
 | |
| 	if len(creationCQL) == 0 {
 | |
| 		creationCQL = []string{defaultUserCreationCQL}
 | |
| 	}
 | |
| 
 | |
| 	rollbackCQL := req.RollbackStatements.Commands
 | |
| 	if len(rollbackCQL) == 0 {
 | |
| 		rollbackCQL = []string{defaultUserDeletionCQL}
 | |
| 	}
 | |
| 
 | |
| 	username, err := c.usernameProducer.Generate(req.UsernameConfig)
 | |
| 	if err != nil {
 | |
| 		return dbplugin.NewUserResponse{}, err
 | |
| 	}
 | |
| 
 | |
| 	for _, stmt := range creationCQL {
 | |
| 		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
 | |
| 			query = strings.TrimSpace(query)
 | |
| 			if len(query) == 0 {
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			m := map[string]string{
 | |
| 				"username": username,
 | |
| 				"password": req.Password,
 | |
| 			}
 | |
| 			err = session.
 | |
| 				Query(dbutil.QueryHelper(query, m)).
 | |
| 				WithContext(ctx).
 | |
| 				Exec()
 | |
| 			if err != nil {
 | |
| 				rollbackErr := rollbackUser(ctx, session, username, rollbackCQL)
 | |
| 				if rollbackErr != nil {
 | |
| 					err = multierror.Append(err, rollbackErr)
 | |
| 				}
 | |
| 				return dbplugin.NewUserResponse{}, err
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	resp := dbplugin.NewUserResponse{
 | |
| 		Username: username,
 | |
| 	}
 | |
| 	return resp, nil
 | |
| }
 | |
| 
 | |
| func rollbackUser(ctx context.Context, session *gocql.Session, username string, rollbackCQL []string) error {
 | |
| 	for _, stmt := range rollbackCQL {
 | |
| 		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
 | |
| 			query = strings.TrimSpace(query)
 | |
| 			if len(query) == 0 {
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			m := map[string]string{
 | |
| 				"username": username,
 | |
| 			}
 | |
| 			err := session.
 | |
| 				Query(dbutil.QueryHelper(query, m)).
 | |
| 				WithContext(ctx).
 | |
| 				Exec()
 | |
| 			if err != nil {
 | |
| 				return fmt.Errorf("failed to roll back user %s: %w", username, err)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (c *Cassandra) UpdateUser(ctx context.Context, req dbplugin.UpdateUserRequest) (dbplugin.UpdateUserResponse, error) {
 | |
| 	if req.Password == nil && req.Expiration == nil {
 | |
| 		return dbplugin.UpdateUserResponse{}, fmt.Errorf("no changes requested")
 | |
| 	}
 | |
| 
 | |
| 	if req.Password != nil {
 | |
| 		err := c.changeUserPassword(ctx, req.Username, req.Password)
 | |
| 		return dbplugin.UpdateUserResponse{}, err
 | |
| 	}
 | |
| 	// Expiration is no-op
 | |
| 	return dbplugin.UpdateUserResponse{}, nil
 | |
| }
 | |
| 
 | |
| func (c *Cassandra) changeUserPassword(ctx context.Context, username string, changePass *dbplugin.ChangePassword) error {
 | |
| 	session, err := c.getConnection(ctx)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	rotateCQL := changePass.Statements.Commands
 | |
| 	if len(rotateCQL) == 0 {
 | |
| 		rotateCQL = []string{defaultChangePasswordCQL}
 | |
| 	}
 | |
| 
 | |
| 	var result *multierror.Error
 | |
| 	for _, stmt := range rotateCQL {
 | |
| 		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
 | |
| 			query = strings.TrimSpace(query)
 | |
| 			if len(query) == 0 {
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			m := map[string]string{
 | |
| 				"username": username,
 | |
| 				"password": changePass.NewPassword,
 | |
| 			}
 | |
| 			err := session.
 | |
| 				Query(dbutil.QueryHelper(query, m)).
 | |
| 				WithContext(ctx).
 | |
| 				Exec()
 | |
| 			result = multierror.Append(result, err)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return result.ErrorOrNil()
 | |
| }
 | |
| 
 | |
| // DeleteUser attempts to drop the specified user.
 | |
| func (c *Cassandra) DeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) {
 | |
| 	c.Lock()
 | |
| 	defer c.Unlock()
 | |
| 
 | |
| 	session, err := c.getConnection(ctx)
 | |
| 	if err != nil {
 | |
| 		return dbplugin.DeleteUserResponse{}, err
 | |
| 	}
 | |
| 
 | |
| 	revocationCQL := req.Statements.Commands
 | |
| 	if len(revocationCQL) == 0 {
 | |
| 		revocationCQL = []string{defaultUserDeletionCQL}
 | |
| 	}
 | |
| 
 | |
| 	var result *multierror.Error
 | |
| 	for _, stmt := range revocationCQL {
 | |
| 		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
 | |
| 			query = strings.TrimSpace(query)
 | |
| 			if len(query) == 0 {
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			m := map[string]string{
 | |
| 				"username": req.Username,
 | |
| 			}
 | |
| 			err := session.
 | |
| 				Query(dbutil.QueryHelper(query, m)).
 | |
| 				WithContext(ctx).
 | |
| 				Exec()
 | |
| 
 | |
| 			result = multierror.Append(result, err)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return dbplugin.DeleteUserResponse{}, result.ErrorOrNil()
 | |
| }
 |