mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 11:08:10 +00:00
[VAULT-3379] Add support for contained DBs in MSSQL root rotation and lease revocation (#12839)
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
_ "github.com/denisenkom/go-mssqldb"
|
||||
@@ -30,6 +31,9 @@ type MSSQL struct {
|
||||
*connutil.SQLConnectionProducer
|
||||
|
||||
usernameProducer template.StringTemplate
|
||||
|
||||
// A flag to let us know to skip cross DB queries and server login checks
|
||||
containedDB bool
|
||||
}
|
||||
|
||||
func New() (interface{}, error) {
|
||||
@@ -94,6 +98,20 @@ func (m *MSSQL) Initialize(ctx context.Context, req dbplugin.InitializeRequest)
|
||||
return dbplugin.InitializeResponse{}, fmt.Errorf("invalid username template - did you reference a field that isn't available? : %w", err)
|
||||
}
|
||||
|
||||
containedDB := false
|
||||
containedDBRaw, err := strutil.GetString(req.Config, "contained_db")
|
||||
if err != nil {
|
||||
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve contained_db: %w", err)
|
||||
}
|
||||
if containedDBRaw != "" {
|
||||
containedDB, err = strconv.ParseBool(containedDBRaw)
|
||||
if err != nil {
|
||||
return dbplugin.InitializeResponse{}, fmt.Errorf("parsing error: incorrect boolean operator provided for contained_db: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
m.containedDB = containedDB
|
||||
|
||||
resp := dbplugin.InitializeResponse{
|
||||
Config: newConf,
|
||||
}
|
||||
@@ -201,6 +219,19 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if DB is contained
|
||||
if m.containedDB {
|
||||
revokeStmt, err := db.PrepareContext(ctx, fmt.Sprintf("DROP USER IF EXISTS [%s]", username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer revokeStmt.Close()
|
||||
if _, err := revokeStmt.ExecContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// First disable server login
|
||||
disableStmt, err := db.PrepareContext(ctx, fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username))
|
||||
if err != nil {
|
||||
@@ -311,7 +342,7 @@ func (m *MSSQL) UpdateUser(ctx context.Context, req dbplugin.UpdateUserRequest)
|
||||
|
||||
func (m *MSSQL) updateUserPass(ctx context.Context, username string, changePass *dbplugin.ChangePassword) error {
|
||||
stmts := changePass.Statements.Commands
|
||||
if len(stmts) == 0 {
|
||||
if len(stmts) == 0 && !m.containedDB {
|
||||
stmts = []string{alterLoginSQL}
|
||||
}
|
||||
|
||||
@@ -329,12 +360,16 @@ func (m *MSSQL) updateUserPass(ctx context.Context, username string, changePass
|
||||
return err
|
||||
}
|
||||
|
||||
var exists bool
|
||||
// Since contained DB users do not have server logins, we
|
||||
// only query for a login if DB is not a contained DB
|
||||
if !m.containedDB {
|
||||
var exists bool
|
||||
|
||||
err = db.QueryRowContext(ctx, "SELECT 1 FROM master.sys.server_principals where name = N'$1'", username).Scan(&exists)
|
||||
err = db.QueryRowContext(ctx, "SELECT 1 FROM master.sys.server_principals where name = N'$1'", username).Scan(&exists)
|
||||
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
|
||||
@@ -42,6 +42,15 @@ func TestInitialize(t *testing.T) {
|
||||
VerifyConnection: true,
|
||||
},
|
||||
},
|
||||
"contained_db set": {
|
||||
dbplugin.InitializeRequest{
|
||||
Config: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"contained_db": "true",
|
||||
},
|
||||
VerifyConnection: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
@@ -265,6 +274,26 @@ func TestUpdateUser_password(t *testing.T) {
|
||||
}
|
||||
|
||||
assertCredsExist(t, connURL, dbUser, test.expectedPassword)
|
||||
|
||||
// Delete user at the end of each test
|
||||
deleteReq := dbplugin.DeleteUserRequest{
|
||||
Username: dbUser,
|
||||
}
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
deleteResp, err := db.DeleteUser(ctx, deleteReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete user: %s", err)
|
||||
}
|
||||
|
||||
// Protect against future fields that aren't specified
|
||||
expectedDeleteResp := dbplugin.DeleteUserResponse{}
|
||||
if !reflect.DeepEqual(deleteResp, expectedDeleteResp) {
|
||||
t.Fatalf("Fields missing from expected response: Actual: %#v", deleteResp)
|
||||
}
|
||||
|
||||
assertCredsDoNotExist(t, connURL, dbUser, initPassword)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user