mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-02 11:38:02 +00:00
@@ -7,6 +7,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -21,6 +22,7 @@ import (
|
||||
|
||||
// Verify MSSQLBackend satisfies the correct interfaces
|
||||
var _ physical.Backend = (*MSSQLBackend)(nil)
|
||||
var identifierRegex = regexp.MustCompile(`^[\p{L}_][\p{L}\p{Nd}@#$_]*$`)
|
||||
|
||||
type MSSQLBackend struct {
|
||||
dbTable string
|
||||
@@ -30,6 +32,13 @@ type MSSQLBackend struct {
|
||||
permitPool *physical.PermitPool
|
||||
}
|
||||
|
||||
func isInvalidIdentifier(name string) bool {
|
||||
if !identifierRegex.MatchString(name) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
|
||||
username, ok := conf["username"]
|
||||
if !ok {
|
||||
@@ -71,11 +80,19 @@ func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backen
|
||||
database = "Vault"
|
||||
}
|
||||
|
||||
if isInvalidIdentifier(database) {
|
||||
return nil, fmt.Errorf("invalid database name")
|
||||
}
|
||||
|
||||
table, ok := conf["table"]
|
||||
if !ok {
|
||||
table = "Vault"
|
||||
}
|
||||
|
||||
if isInvalidIdentifier(table) {
|
||||
return nil, fmt.Errorf("invalid table name")
|
||||
}
|
||||
|
||||
appname, ok := conf["appname"]
|
||||
if !ok {
|
||||
appname = "Vault"
|
||||
@@ -96,6 +113,10 @@ func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backen
|
||||
schema = "dbo"
|
||||
}
|
||||
|
||||
if isInvalidIdentifier(schema) {
|
||||
return nil, fmt.Errorf("invalid schema name")
|
||||
}
|
||||
|
||||
connectionString := fmt.Sprintf("server=%s;app name=%s;connection timeout=%s;log=%s", server, appname, connectionTimeout, logLevel)
|
||||
if username != "" {
|
||||
connectionString += ";user id=" + username
|
||||
@@ -116,18 +137,17 @@ func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backen
|
||||
|
||||
db.SetMaxOpenConns(maxParInt)
|
||||
|
||||
if _, err := db.Exec("IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = '" + database + "') CREATE DATABASE " + database); err != nil {
|
||||
if _, err := db.Exec("IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = ?) CREATE DATABASE "+database, database); err != nil {
|
||||
return nil, fmt.Errorf("failed to create mssql database: %w", err)
|
||||
}
|
||||
|
||||
dbTable := database + "." + schema + "." + table
|
||||
createQuery := "IF NOT EXISTS(SELECT 1 FROM " + database + ".INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' AND TABLE_NAME='" + table + "' AND TABLE_SCHEMA='" + schema +
|
||||
"') CREATE TABLE " + dbTable + " (Path VARCHAR(512) PRIMARY KEY, Value VARBINARY(MAX))"
|
||||
createQuery := "IF NOT EXISTS(SELECT 1 FROM " + database + ".INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' AND TABLE_NAME=? AND TABLE_SCHEMA=?) CREATE TABLE " + dbTable + " (Path VARCHAR(512) PRIMARY KEY, Value VARBINARY(MAX))"
|
||||
|
||||
if schema != "dbo" {
|
||||
|
||||
var num int
|
||||
err = db.QueryRow("SELECT 1 FROM " + database + ".sys.schemas WHERE name = '" + schema + "'").Scan(&num)
|
||||
err = db.QueryRow("SELECT 1 FROM "+database+".sys.schemas WHERE name = ?", schema).Scan(&num)
|
||||
|
||||
switch {
|
||||
case err == sql.ErrNoRows:
|
||||
@@ -140,7 +160,7 @@ func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backen
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createQuery); err != nil {
|
||||
if _, err := db.Exec(createQuery, table, schema); err != nil {
|
||||
return nil, fmt.Errorf("failed to create mssql table: %w", err)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user