plugins/database: use context with plugins that use database/sql package (#3691)

This commit is contained in:
Brian Kassouf
2017-12-15 10:26:17 -08:00
committed by GitHub
parent d1b12356d8
commit 1eec51abff
3 changed files with 47 additions and 47 deletions

View File

@@ -120,7 +120,7 @@ func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, u
} }
// Start a transaction // Start a transaction
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
@@ -133,7 +133,7 @@ func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, u
continue continue
} }
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username, "name": username,
"password": password, "password": password,
"expiration": expirationStr, "expiration": expirationStr,
@@ -142,7 +142,7 @@ func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, u
return "", "", err return "", "", err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return "", "", err return "", "", err
} }
} }
@@ -164,7 +164,7 @@ func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, us
} }
// Start a transaction // Start a transaction
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
@@ -178,12 +178,12 @@ func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, us
} }
// Renew user's valid until property field // Renew user's valid until property field
stmt, err := tx.Prepare("ALTER USER " + username + " VALID UNTIL " + "'" + expirationStr + "'") stmt, err := tx.PrepareContext(ctx, "ALTER USER "+username+" VALID UNTIL "+"'"+expirationStr+"'")
if err != nil { if err != nil {
return err return err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return err
} }
@@ -209,7 +209,7 @@ func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, u
} }
// Start a transaction // Start a transaction
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
@@ -222,14 +222,14 @@ func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, u
continue continue
} }
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username, "name": username,
})) }))
if err != nil { if err != nil {
return err return err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return err
} }
} }
@@ -250,30 +250,30 @@ func (h *HANA) revokeUserDefault(ctx context.Context, username string) error {
} }
// Start a transaction // Start a transaction
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback() defer tx.Rollback()
// Disable server login for user // Disable server login for user
disableStmt, err := tx.Prepare(fmt.Sprintf("ALTER USER %s DEACTIVATE USER NOW", username)) disableStmt, err := tx.PrepareContext(ctx, fmt.Sprintf("ALTER USER %s DEACTIVATE USER NOW", username))
if err != nil { if err != nil {
return err return err
} }
defer disableStmt.Close() defer disableStmt.Close()
if _, err := disableStmt.Exec(); err != nil { if _, err := disableStmt.ExecContext(ctx); err != nil {
return err return err
} }
// Invalidates current sessions and performs soft drop (drop if no dependencies) // Invalidates current sessions and performs soft drop (drop if no dependencies)
// if hard drop is desired, custom revoke statements should be written for role // if hard drop is desired, custom revoke statements should be written for role
dropStmt, err := tx.Prepare(fmt.Sprintf("DROP USER %s RESTRICT", username)) dropStmt, err := tx.PrepareContext(ctx, fmt.Sprintf("DROP USER %s RESTRICT", username))
if err != nil { if err != nil {
return err return err
} }
defer dropStmt.Close() defer dropStmt.Close()
if _, err := dropStmt.Exec(); err != nil { if _, err := dropStmt.ExecContext(ctx); err != nil {
return err return err
} }

View File

@@ -105,7 +105,7 @@ func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements,
} }
// Start a transaction // Start a transaction
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
@@ -118,7 +118,7 @@ func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements,
continue continue
} }
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username, "name": username,
"password": password, "password": password,
"expiration": expirationStr, "expiration": expirationStr,
@@ -127,7 +127,7 @@ func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements,
return "", "", err return "", "", err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return "", "", err return "", "", err
} }
} }
@@ -161,7 +161,7 @@ func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements,
} }
// Start a transaction // Start a transaction
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
@@ -174,14 +174,14 @@ func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements,
continue continue
} }
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username, "name": username,
})) }))
if err != nil { if err != nil {
return err return err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return err
} }
} }
@@ -202,12 +202,12 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
} }
// First disable server login // First disable server login
disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username)) disableStmt, err := db.PrepareContext(ctx, fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username))
if err != nil { if err != nil {
return err return err
} }
defer disableStmt.Close() defer disableStmt.Close()
if _, err := disableStmt.Exec(); err != nil { if _, err := disableStmt.ExecContext(ctx); err != nil {
return err return err
} }
@@ -215,14 +215,14 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
// sessions. There cannot be any active sessions before we drop the logins // sessions. There cannot be any active sessions before we drop the logins
// This isn't done in a transaction because even if we fail along the way, // This isn't done in a transaction because even if we fail along the way,
// we want to remove as much access as possible // we want to remove as much access as possible
sessionStmt, err := db.Prepare(fmt.Sprintf( sessionStmt, err := db.PrepareContext(ctx, fmt.Sprintf(
"SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username)) "SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username))
if err != nil { if err != nil {
return err return err
} }
defer sessionStmt.Close() defer sessionStmt.Close()
sessionRows, err := sessionStmt.Query() sessionRows, err := sessionStmt.QueryContext(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -243,13 +243,13 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
// we need to drop the database users before we can drop the login and the role // we need to drop the database users before we can drop the login and the role
// This isn't done in a transaction because even if we fail along the way, // This isn't done in a transaction because even if we fail along the way,
// we want to remove as much access as possible // we want to remove as much access as possible
stmt, err := db.Prepare(fmt.Sprintf("EXEC master.dbo.sp_msloginmappings '%s';", username)) stmt, err := db.PrepareContext(ctx, fmt.Sprintf("EXEC master.dbo.sp_msloginmappings '%s';", username))
if err != nil { if err != nil {
return err return err
} }
defer stmt.Close() defer stmt.Close()
rows, err := stmt.Query() rows, err := stmt.QueryContext(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -269,13 +269,13 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
// many permissions as possible right now // many permissions as possible right now
var lastStmtError error var lastStmtError error
for _, query := range revokeStmts { for _, query := range revokeStmts {
stmt, err := db.Prepare(query) stmt, err := db.PrepareContext(ctx, query)
if err != nil { if err != nil {
lastStmtError = err lastStmtError = err
continue continue
} }
defer stmt.Close() defer stmt.Close()
_, err = stmt.Exec() _, err = stmt.ExecContext(ctx)
if err != nil { if err != nil {
lastStmtError = err lastStmtError = err
} }
@@ -290,12 +290,12 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
} }
// Drop this login // Drop this login
stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username)) stmt, err = db.PrepareContext(ctx, fmt.Sprintf(dropLoginSQL, username, username))
if err != nil { if err != nil {
return err return err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return err
} }

View File

@@ -109,7 +109,7 @@ func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Stateme
} }
// Start a transaction // Start a transaction
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return "", "", err return "", "", err
@@ -126,7 +126,7 @@ func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Stateme
continue continue
} }
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username, "name": username,
"password": password, "password": password,
"expiration": expirationStr, "expiration": expirationStr,
@@ -136,7 +136,7 @@ func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Stateme
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return "", "", err return "", "", err
} }
@@ -165,7 +165,7 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen
return err return err
} }
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
@@ -183,7 +183,7 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen
if len(query) == 0 { if len(query) == 0 {
continue continue
} }
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username, "name": username,
"expiration": expirationStr, "expiration": expirationStr,
})) }))
@@ -192,7 +192,7 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return err
} }
} }
@@ -222,7 +222,7 @@ func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationS
return err return err
} }
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
@@ -236,7 +236,7 @@ func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationS
continue continue
} }
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username, "name": username,
})) }))
if err != nil { if err != nil {
@@ -244,7 +244,7 @@ func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationS
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return err
} }
} }
@@ -264,7 +264,7 @@ func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) err
// Check if the role exists // Check if the role exists
var exists bool var exists bool
err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) err = db.QueryRowContext(ctx, "SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return err return err
} }
@@ -277,13 +277,13 @@ func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) err
// the role // the role
// This isn't done in a transaction because even if we fail along the way, // This isn't done in a transaction because even if we fail along the way,
// we want to remove as much access as possible // we want to remove as much access as possible
stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;") stmt, err := db.PrepareContext(ctx, "SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
if err != nil { if err != nil {
return err return err
} }
defer stmt.Close() defer stmt.Close()
rows, err := stmt.Query(username) rows, err := stmt.QueryContext(ctx, username)
if err != nil { if err != nil {
return err return err
} }
@@ -325,7 +325,7 @@ func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) err
// get the current database name so we can issue a REVOKE CONNECT for // get the current database name so we can issue a REVOKE CONNECT for
// this username // this username
var dbname sql.NullString var dbname sql.NullString
if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil { if err := db.QueryRowContext(ctx, "SELECT current_database();").Scan(&dbname); err != nil {
return err return err
} }
@@ -340,13 +340,13 @@ func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) err
// many permissions as possible right now // many permissions as possible right now
var lastStmtError error var lastStmtError error
for _, query := range revocationStmts { for _, query := range revocationStmts {
stmt, err := db.Prepare(query) stmt, err := db.PrepareContext(ctx, query)
if err != nil { if err != nil {
lastStmtError = err lastStmtError = err
continue continue
} }
defer stmt.Close() defer stmt.Close()
_, err = stmt.Exec() _, err = stmt.ExecContext(ctx)
if err != nil { if err != nil {
lastStmtError = err lastStmtError = err
} }
@@ -361,13 +361,13 @@ func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) err
} }
// Drop this user // Drop this user
stmt, err = db.Prepare(fmt.Sprintf( stmt, err = db.PrepareContext(ctx, fmt.Sprintf(
`DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username))) `DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username)))
if err != nil { if err != nil {
return err return err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return err
} }