mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 02:28:09 +00:00 
			
		
		
		
	Add OSS stub functions for Self-Managed Static Roles (#28199)
This commit is contained in:
		| @@ -199,6 +199,15 @@ func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) { | ||||
| 	return db.(*sql.DB), nil | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) getStaticConnection(ctx context.Context, username, password string) (*sql.DB, error) { | ||||
| 	db, err := p.StaticConnection(ctx, username, password) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return db, nil | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) UpdateUser(ctx context.Context, req dbplugin.UpdateUserRequest) (dbplugin.UpdateUserResponse, error) { | ||||
| 	if req.Username == "" { | ||||
| 		return dbplugin.UpdateUserResponse{}, fmt.Errorf("missing username") | ||||
| @@ -209,17 +218,17 @@ func (p *PostgreSQL) UpdateUser(ctx context.Context, req dbplugin.UpdateUserRequ | ||||
|  | ||||
| 	merr := &multierror.Error{} | ||||
| 	if req.Password != nil { | ||||
| 		err := p.changeUserPassword(ctx, req.Username, req.Password) | ||||
| 		err := p.changeUserPassword(ctx, req.Username, req.Password, req.SelfManagedPassword) | ||||
| 		merr = multierror.Append(merr, err) | ||||
| 	} | ||||
| 	if req.Expiration != nil { | ||||
| 		err := p.changeUserExpiration(ctx, req.Username, req.Expiration) | ||||
| 		err := p.changeUserExpiration(ctx, req.Username, req.Expiration, req.SelfManagedPassword) | ||||
| 		merr = multierror.Append(merr, err) | ||||
| 	} | ||||
| 	return dbplugin.UpdateUserResponse{}, merr.ErrorOrNil() | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) changeUserPassword(ctx context.Context, username string, changePass *dbplugin.ChangePassword) error { | ||||
| func (p *PostgreSQL) changeUserPassword(ctx context.Context, username string, changePass *dbplugin.ChangePassword, selfManagedPass string) error { | ||||
| 	stmts := changePass.Statements.Commands | ||||
| 	if len(stmts) == 0 { | ||||
| 		stmts = []string{defaultChangePasswordStatement} | ||||
| @@ -233,9 +242,18 @@ func (p *PostgreSQL) changeUserPassword(ctx context.Context, username string, ch | ||||
| 	p.Lock() | ||||
| 	defer p.Unlock() | ||||
|  | ||||
| 	db, err := p.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("unable to get connection: %w", err) | ||||
| 	var db *sql.DB | ||||
| 	var err error | ||||
| 	if selfManagedPass == "" { | ||||
| 		db, err = p.getConnection(ctx) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("unable to get connection: %w", err) | ||||
| 		} | ||||
| 	} else { | ||||
| 		db, err = p.getStaticConnection(ctx, username, selfManagedPass) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("unable to get static connection from cache: %w", err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Check if the role exists | ||||
| @@ -285,7 +303,7 @@ func (p *PostgreSQL) changeUserPassword(ctx context.Context, username string, ch | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) changeUserExpiration(ctx context.Context, username string, changeExp *dbplugin.ChangeExpiration) error { | ||||
| func (p *PostgreSQL) changeUserExpiration(ctx context.Context, username string, changeExp *dbplugin.ChangeExpiration, selfManagedPass string) error { | ||||
| 	p.Lock() | ||||
| 	defer p.Unlock() | ||||
|  | ||||
| @@ -294,9 +312,18 @@ func (p *PostgreSQL) changeUserExpiration(ctx context.Context, username string, | ||||
| 		renewStmts = []string{defaultExpirationStatement} | ||||
| 	} | ||||
|  | ||||
| 	db, err := p.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	var db *sql.DB | ||||
| 	var err error | ||||
| 	if selfManagedPass == "" { | ||||
| 		db, err = p.getConnection(ctx) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("unable to get connection: %w", err) | ||||
| 		} | ||||
| 	} else { | ||||
| 		db, err = p.getStaticConnection(ctx, username, selfManagedPass) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("unable to get static connection from cache: %w", err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	tx, err := db.BeginTx(ctx, nil) | ||||
|   | ||||
| @@ -640,6 +640,94 @@ func TestPostgreSQL_Initialize_CloudGCP(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // TestPostgreSQL_Initialize_SelfManaged_OSS tests the initialization of | ||||
| // the self-managed flow and ensures an error is returned on OSS. | ||||
| func TestPostgreSQL_Initialize_SelfManaged_OSS(t *testing.T) { | ||||
| 	cleanup, url := postgresql.PrepareTestContainerSelfManaged(t) | ||||
| 	defer cleanup() | ||||
|  | ||||
| 	connURL := fmt.Sprintf("postgresql://{{username}}:{{password}}@%s/postgres?sslmode=disable", url.Host) | ||||
|  | ||||
| 	testCases := []struct { | ||||
| 		name              string | ||||
| 		connectionDetails map[string]interface{} | ||||
| 		wantErr           bool | ||||
| 		errContains       string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "no parameters set", | ||||
| 			connectionDetails: map[string]interface{}{ | ||||
| 				"connection_url": connURL, | ||||
| 				"self_managed":   false, | ||||
| 				"username":       "", | ||||
| 				"password":       "", | ||||
| 			}, | ||||
| 			wantErr:     true, | ||||
| 			errContains: "must either provide username/password or set self-managed to 'true'", | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "both sets of parameters set", | ||||
| 			connectionDetails: map[string]interface{}{ | ||||
| 				"connection_url": connURL, | ||||
| 				"self_managed":   true, | ||||
| 				"username":       "test", | ||||
| 				"password":       "test", | ||||
| 			}, | ||||
| 			wantErr:     true, | ||||
| 			errContains: "cannot use both self-managed and vault-managed workflows", | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "either username/password with self-managed", | ||||
| 			connectionDetails: map[string]interface{}{ | ||||
| 				"connection_url": connURL, | ||||
| 				"self_managed":   true, | ||||
| 				"username":       "test", | ||||
| 				"password":       "", | ||||
| 			}, | ||||
| 			wantErr:     true, | ||||
| 			errContains: "cannot use both self-managed and vault-managed workflows", | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "cache not implemented", | ||||
| 			connectionDetails: map[string]interface{}{ | ||||
| 				"connection_url": connURL, | ||||
| 				"self_managed":   true, | ||||
| 				"username":       "", | ||||
| 				"password":       "", | ||||
| 			}, | ||||
| 			wantErr:     true, | ||||
| 			errContains: "self-managed static roles only available in Vault Enterprise", | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			req := dbplugin.InitializeRequest{ | ||||
| 				Config:           tc.connectionDetails, | ||||
| 				VerifyConnection: true, | ||||
| 			} | ||||
|  | ||||
| 			db := new() | ||||
| 			_, err := dbtesting.VerifyInitialize(t, db, req) | ||||
| 			if err == nil && tc.wantErr { | ||||
| 				t.Fatalf("got: %s, wantErr: %t", err, tc.wantErr) | ||||
| 			} | ||||
|  | ||||
| 			if err != nil && !strings.Contains(err.Error(), tc.errContains) { | ||||
| 				t.Fatalf("expected error: %s, received error: %s", tc.errContains, err) | ||||
| 			} | ||||
|  | ||||
| 			if !tc.wantErr && !db.Initialized { | ||||
| 				t.Fatal("Database should be initialized") | ||||
| 			} | ||||
|  | ||||
| 			if err := db.Close(); err != nil { | ||||
| 				t.Fatalf("err closing DB: %s", err) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // TestPostgreSQL_PasswordAuthentication tests that the default "password_authentication" is "none", and that | ||||
| // an error is returned if an invalid "password_authentication" is provided. | ||||
| func TestPostgreSQL_PasswordAuthentication(t *testing.T) { | ||||
| @@ -1045,6 +1133,37 @@ func TestUpdateUser_Password(t *testing.T) { | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // TestUpdateUser_SelfManaged_OSS checks basic validation | ||||
| // for self-managed fields and confirms an error is returned on OSS | ||||
| func TestUpdateUser_SelfManaged_OSS(t *testing.T) { | ||||
| 	// Shared test container for speed - there should not be any overlap between the tests | ||||
| 	db, cleanup := getPostgreSQL(t, nil) | ||||
| 	defer cleanup() | ||||
|  | ||||
| 	updateReq := dbplugin.UpdateUserRequest{ | ||||
| 		Username: "static", | ||||
| 		Password: &dbplugin.ChangePassword{ | ||||
| 			NewPassword: "somenewpassword", | ||||
| 			Statements: dbplugin.Statements{ | ||||
| 				Commands: nil, | ||||
| 			}, | ||||
| 		}, | ||||
| 		SelfManagedPassword: "test", | ||||
| 	} | ||||
|  | ||||
| 	expectedErr := "self-managed static roles only available in Vault Enterprise" | ||||
|  | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||
| 	defer cancel() | ||||
| 	_, err := db.UpdateUser(ctx, updateReq) | ||||
| 	if err == nil { | ||||
| 		t.Fatalf("err expected, got nil") | ||||
| 	} | ||||
| 	if !strings.Contains(err.Error(), expectedErr) { | ||||
| 		t.Fatalf("err expected: %s, got: %s", expectedErr, err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestUpdateUser_Expiration(t *testing.T) { | ||||
| 	type testCase struct { | ||||
| 		initialExpiration  time.Time | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 vinay-gopalan
					vinay-gopalan