mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-04 04:28:08 +00:00 
			
		
		
		
	* Swap sdk/helper libs to go-secure-stdlib * Migrate to go-secure-stdlib reloadutil * Migrate to go-secure-stdlib kv-builder * Migrate to go-secure-stdlib gatedwriter
		
			
				
	
	
		
			263 lines
		
	
	
		
			6.0 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			263 lines
		
	
	
		
			6.0 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package mssql
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"database/sql"
 | 
						|
	"fmt"
 | 
						|
	"sort"
 | 
						|
	"strconv"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	metrics "github.com/armon/go-metrics"
 | 
						|
	_ "github.com/denisenkom/go-mssqldb"
 | 
						|
	log "github.com/hashicorp/go-hclog"
 | 
						|
	"github.com/hashicorp/go-secure-stdlib/strutil"
 | 
						|
	"github.com/hashicorp/vault/sdk/physical"
 | 
						|
)
 | 
						|
 | 
						|
// Verify MSSQLBackend satisfies the correct interfaces
 | 
						|
var _ physical.Backend = (*MSSQLBackend)(nil)
 | 
						|
 | 
						|
type MSSQLBackend struct {
 | 
						|
	dbTable    string
 | 
						|
	client     *sql.DB
 | 
						|
	statements map[string]*sql.Stmt
 | 
						|
	logger     log.Logger
 | 
						|
	permitPool *physical.PermitPool
 | 
						|
}
 | 
						|
 | 
						|
func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
 | 
						|
	username, ok := conf["username"]
 | 
						|
	if !ok {
 | 
						|
		username = ""
 | 
						|
	}
 | 
						|
 | 
						|
	password, ok := conf["password"]
 | 
						|
	if !ok {
 | 
						|
		password = ""
 | 
						|
	}
 | 
						|
 | 
						|
	server, ok := conf["server"]
 | 
						|
	if !ok || server == "" {
 | 
						|
		return nil, fmt.Errorf("missing server")
 | 
						|
	}
 | 
						|
 | 
						|
	port, ok := conf["port"]
 | 
						|
	if !ok {
 | 
						|
		port = ""
 | 
						|
	}
 | 
						|
 | 
						|
	maxParStr, ok := conf["max_parallel"]
 | 
						|
	var maxParInt int
 | 
						|
	var err error
 | 
						|
	if ok {
 | 
						|
		maxParInt, err = strconv.Atoi(maxParStr)
 | 
						|
		if err != nil {
 | 
						|
			return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
 | 
						|
		}
 | 
						|
		if logger.IsDebug() {
 | 
						|
			logger.Debug("max_parallel set", "max_parallel", maxParInt)
 | 
						|
		}
 | 
						|
	} else {
 | 
						|
		maxParInt = physical.DefaultParallelOperations
 | 
						|
	}
 | 
						|
 | 
						|
	database, ok := conf["database"]
 | 
						|
	if !ok {
 | 
						|
		database = "Vault"
 | 
						|
	}
 | 
						|
 | 
						|
	table, ok := conf["table"]
 | 
						|
	if !ok {
 | 
						|
		table = "Vault"
 | 
						|
	}
 | 
						|
 | 
						|
	appname, ok := conf["appname"]
 | 
						|
	if !ok {
 | 
						|
		appname = "Vault"
 | 
						|
	}
 | 
						|
 | 
						|
	connectionTimeout, ok := conf["connectiontimeout"]
 | 
						|
	if !ok {
 | 
						|
		connectionTimeout = "30"
 | 
						|
	}
 | 
						|
 | 
						|
	logLevel, ok := conf["loglevel"]
 | 
						|
	if !ok {
 | 
						|
		logLevel = "0"
 | 
						|
	}
 | 
						|
 | 
						|
	schema, ok := conf["schema"]
 | 
						|
	if !ok || schema == "" {
 | 
						|
		schema = "dbo"
 | 
						|
	}
 | 
						|
 | 
						|
	connectionString := fmt.Sprintf("server=%s;app name=%s;connection timeout=%s;log=%s", server, appname, connectionTimeout, logLevel)
 | 
						|
	if username != "" {
 | 
						|
		connectionString += ";user id=" + username
 | 
						|
	}
 | 
						|
 | 
						|
	if password != "" {
 | 
						|
		connectionString += ";password=" + password
 | 
						|
	}
 | 
						|
 | 
						|
	if port != "" {
 | 
						|
		connectionString += ";port=" + port
 | 
						|
	}
 | 
						|
 | 
						|
	db, err := sql.Open("mssql", connectionString)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("failed to connect to mssql: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	db.SetMaxOpenConns(maxParInt)
 | 
						|
 | 
						|
	if _, err := db.Exec("IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = '" + database + "') CREATE DATABASE " + database); err != nil {
 | 
						|
		return nil, fmt.Errorf("failed to create mssql database: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	dbTable := database + "." + schema + "." + table
 | 
						|
	createQuery := "IF NOT EXISTS(SELECT 1 FROM " + database + ".INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' AND TABLE_NAME='" + table + "' AND TABLE_SCHEMA='" + schema +
 | 
						|
		"') CREATE TABLE " + dbTable + " (Path VARCHAR(512) PRIMARY KEY, Value VARBINARY(MAX))"
 | 
						|
 | 
						|
	if schema != "dbo" {
 | 
						|
 | 
						|
		var num int
 | 
						|
		err = db.QueryRow("SELECT 1 FROM " + database + ".sys.schemas WHERE name = '" + schema + "'").Scan(&num)
 | 
						|
 | 
						|
		switch {
 | 
						|
		case err == sql.ErrNoRows:
 | 
						|
			if _, err := db.Exec("USE " + database + "; EXEC ('CREATE SCHEMA " + schema + "')"); err != nil {
 | 
						|
				return nil, fmt.Errorf("failed to create mssql schema: %w", err)
 | 
						|
			}
 | 
						|
 | 
						|
		case err != nil:
 | 
						|
			return nil, fmt.Errorf("failed to check if mssql schema exists: %w", err)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	if _, err := db.Exec(createQuery); err != nil {
 | 
						|
		return nil, fmt.Errorf("failed to create mssql table: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	m := &MSSQLBackend{
 | 
						|
		dbTable:    dbTable,
 | 
						|
		client:     db,
 | 
						|
		statements: make(map[string]*sql.Stmt),
 | 
						|
		logger:     logger,
 | 
						|
		permitPool: physical.NewPermitPool(maxParInt),
 | 
						|
	}
 | 
						|
 | 
						|
	statements := map[string]string{
 | 
						|
		"put": "IF EXISTS(SELECT 1 FROM " + dbTable + " WHERE Path = ?) UPDATE " + dbTable + " SET Value = ? WHERE Path = ?" +
 | 
						|
			" ELSE INSERT INTO " + dbTable + " VALUES(?, ?)",
 | 
						|
		"get":    "SELECT Value FROM " + dbTable + " WHERE Path = ?",
 | 
						|
		"delete": "DELETE FROM " + dbTable + " WHERE Path = ?",
 | 
						|
		"list":   "SELECT Path FROM " + dbTable + " WHERE Path LIKE ?",
 | 
						|
	}
 | 
						|
 | 
						|
	for name, query := range statements {
 | 
						|
		if err := m.prepare(name, query); err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return m, nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *MSSQLBackend) prepare(name, query string) error {
 | 
						|
	stmt, err := m.client.Prepare(query)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("failed to prepare %q: %w", name, err)
 | 
						|
	}
 | 
						|
 | 
						|
	m.statements[name] = stmt
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *MSSQLBackend) Put(ctx context.Context, entry *physical.Entry) error {
 | 
						|
	defer metrics.MeasureSince([]string{"mssql", "put"}, time.Now())
 | 
						|
 | 
						|
	m.permitPool.Acquire()
 | 
						|
	defer m.permitPool.Release()
 | 
						|
 | 
						|
	_, err := m.statements["put"].Exec(entry.Key, entry.Value, entry.Key, entry.Key, entry.Value)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *MSSQLBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
 | 
						|
	defer metrics.MeasureSince([]string{"mssql", "get"}, time.Now())
 | 
						|
 | 
						|
	m.permitPool.Acquire()
 | 
						|
	defer m.permitPool.Release()
 | 
						|
 | 
						|
	var result []byte
 | 
						|
	err := m.statements["get"].QueryRow(key).Scan(&result)
 | 
						|
	if err == sql.ErrNoRows {
 | 
						|
		return nil, nil
 | 
						|
	}
 | 
						|
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	ent := &physical.Entry{
 | 
						|
		Key:   key,
 | 
						|
		Value: result,
 | 
						|
	}
 | 
						|
 | 
						|
	return ent, nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *MSSQLBackend) Delete(ctx context.Context, key string) error {
 | 
						|
	defer metrics.MeasureSince([]string{"mssql", "delete"}, time.Now())
 | 
						|
 | 
						|
	m.permitPool.Acquire()
 | 
						|
	defer m.permitPool.Release()
 | 
						|
 | 
						|
	_, err := m.statements["delete"].Exec(key)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *MSSQLBackend) List(ctx context.Context, prefix string) ([]string, error) {
 | 
						|
	defer metrics.MeasureSince([]string{"mssql", "list"}, time.Now())
 | 
						|
 | 
						|
	m.permitPool.Acquire()
 | 
						|
	defer m.permitPool.Release()
 | 
						|
 | 
						|
	likePrefix := prefix + "%"
 | 
						|
	rows, err := m.statements["list"].Query(likePrefix)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	var keys []string
 | 
						|
	for rows.Next() {
 | 
						|
		var key string
 | 
						|
		err = rows.Scan(&key)
 | 
						|
		if err != nil {
 | 
						|
			return nil, fmt.Errorf("failed to scan rows: %w", err)
 | 
						|
		}
 | 
						|
 | 
						|
		key = strings.TrimPrefix(key, prefix)
 | 
						|
		if i := strings.Index(key, "/"); i == -1 {
 | 
						|
			keys = append(keys, key)
 | 
						|
		} else if i != -1 {
 | 
						|
			keys = strutil.AppendIfMissing(keys, string(key[:i+1]))
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	sort.Strings(keys)
 | 
						|
 | 
						|
	return keys, nil
 | 
						|
}
 |