refactor databases code and refactor master key

This commit is contained in:
Yuli
2020-02-22 20:42:01 +02:00
parent 7663e0ef61
commit 828f78fb3d
15 changed files with 231 additions and 179 deletions

View File

@@ -1,5 +1,5 @@
# build without debug
go build -ldflags "-w" -o databunker ./src/bunker.go ./src/qldb.go ./src/xtokens_db.go \
go build -ldflags "-w" -o databunker ./src/bunker.go ./src/xtokens_db.go \
./src/utils.go ./src/cryptor.go ./src/notify.go \
./src/audit_db.go ./src/audit_api.go \
./src/sms.go ./src/email.go \

View File

@@ -8,6 +8,7 @@ import (
"time"
uuid "github.com/hashicorp/go-uuid"
"github.com/paranoidguy/databunker/src/storage"
"go.mongodb.org/mongo-driver/bson"
)
@@ -37,7 +38,7 @@ func auditApp(title string, record string, app string, mode string, address stri
return &auditEvent{Title: title, Mode: mode, Who: address, Record: record, Status: "ok", When: int32(time.Now().Unix())}
}
func (event auditEvent) submit(db dbcon) {
func (event auditEvent) submit(db *dbcon) {
//fmt.Println("submit event to audit!!!!!!!!!!")
/*
bdoc, err := bson.Marshal(event)
@@ -86,12 +87,12 @@ func (event auditEvent) submit(db dbcon) {
if len(event.After) > 0 {
bdoc["after"] = event.After
}
db.createRecord(TblName.Audit, &bdoc)
db.store.CreateRecord(storage.TblName.Audit, &bdoc)
}
func (dbobj dbcon) getAuditEvents(userTOKEN string, offset int32, limit int32) ([]byte, int64, error) {
//var results []*auditEvent
count, err := dbobj.countRecords(TblName.Audit, "record", userTOKEN)
count, err := dbobj.store.CountRecords(storage.TblName.Audit, "record", userTOKEN)
if err != nil {
return nil, 0, err
}
@@ -99,7 +100,7 @@ func (dbobj dbcon) getAuditEvents(userTOKEN string, offset int32, limit int32) (
return []byte("[]"), 0, err
}
var results []bson.M
records, err := dbobj.getList(TblName.Audit, "record", userTOKEN, offset, limit)
records, err := dbobj.store.GetList(storage.TblName.Audit, "record", userTOKEN, offset, limit)
if err != nil {
return nil, 0, err
}
@@ -130,7 +131,7 @@ func (dbobj dbcon) getAuditEvents(userTOKEN string, offset int32, limit int32) (
func (dbobj dbcon) getAuditEvent(atoken string) (string, []byte, error) {
//var results []*auditEvent
record, err := dbobj.getRecord(TblName.Audit, "atoken", atoken)
record, err := dbobj.store.GetRecord(storage.TblName.Audit, "atoken", atoken)
if err != nil {
return "", nil, err
}

View File

@@ -4,6 +4,7 @@ package main
import (
"context"
"crypto/md5"
"encoding/hex"
"encoding/json"
"flag"
@@ -20,34 +21,16 @@ import (
"github.com/gobuffalo/packr"
"github.com/julienschmidt/httprouter"
"github.com/kelseyhightower/envconfig"
"github.com/paranoidguy/databunker/src/storage"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
yaml "gopkg.in/yaml.v2"
)
// Tbl is used to store table id
type Tbl int
// listTbls used to store list of tables
type listTbls struct {
Users Tbl
Audit Tbl
Xtokens Tbl
Consent Tbl
Sessions Tbl
Requests Tbl
Sharedrecords Tbl
}
// TblName is enum of tables
var TblName = &listTbls{
Users: 0,
Audit: 1,
Xtokens: 2,
Consent: 3,
Sessions: 4,
Requests: 5,
Sharedrecords: 6,
type dbcon struct {
store storage.DBStorage
masterKey []byte
hash []byte
}
// Config is u sed to store application configuration
@@ -99,7 +82,7 @@ type Config struct {
// mainEnv struct stores global structures
type mainEnv struct {
db dbcon
db *dbcon
conf Config
stopChan chan struct{}
}
@@ -164,7 +147,7 @@ func (e mainEnv) backupDB(w http.ResponseWriter, r *http.Request, ps httprouter.
return
}
w.WriteHeader(200)
e.db.backupDB(w)
e.db.store.BackupDB(w)
}
// setupRouter() setup HTTP Router object.
@@ -260,7 +243,7 @@ func readFile(cfg *Config, filepath *string) error {
confFile = *filepath
}
}
fmt.Printf("Databunker conf file is: %s\n", confFile)
fmt.Printf("Databunker configuration file is: %s\n", confFile)
f, err := os.Open(confFile)
if err != nil {
return err
@@ -284,7 +267,7 @@ func (e mainEnv) dbCleanupDo() {
log.Printf("db cleanup timeout\n")
exp, _ := parseExpiration0(e.conf.Policy.MaxAuditRetentionPeriod)
if exp > 0 {
e.db.deleteExpired0(TblName.Audit, exp)
e.db.store.DeleteExpired0(storage.TblName.Audit, exp)
}
notifyURL := e.conf.Notification.ConsentNotificationURL
e.db.expireConsentRecords(notifyURL)
@@ -344,17 +327,19 @@ func logRequest(handler http.Handler) http.Handler {
})
}
func setupDB(dbPtr *string) (dbcon, string, error) {
func setupDB(dbPtr *string) (*dbcon, string, error) {
fmt.Printf("\nDatabunker init\n\n")
masterKey, err := generateMasterKey()
hash := md5.Sum(masterKey)
fmt.Printf("Master key: %x\n\n", masterKey)
fmt.Printf("Init databunker.db\n\n")
db, _ := newDB(masterKey, dbPtr)
err = db.initDB()
store, _ := storage.NewDBStorage(dbPtr)
err = store.InitDB()
if err != nil {
//log.Panic("error %s", err.Error())
log.Fatalf("db init error %s", err.Error())
}
db := &dbcon{store, masterKey, hash[:]}
rootToken, err := db.createRootXtoken()
if err != nil {
//log.Panic("error %s", err.Error())
@@ -364,51 +349,62 @@ func setupDB(dbPtr *string) (dbcon, string, error) {
return db, rootToken, err
}
func getMasterKey(masterKeyPtr *string) []byte {
masterKeyStr := ""
if masterKeyPtr != nil && len(*masterKeyPtr) > 0 {
masterKeyStr = *masterKeyPtr
} else {
masterKeyStr = os.Getenv("DATABUNKER_MASTERKEY")
}
if len(masterKeyStr) != 48 {
fmt.Printf("Failed to decode master key: bad length\n")
os.Exit(0)
}
masterKey, err := hex.DecodeString(masterKeyStr)
if err != nil {
fmt.Printf("Failed to decode master key: %s\n", err)
os.Exit(0)
}
return masterKey
}
// main application function
func main() {
rand.Seed(time.Now().UnixNano())
lockMemory()
//fmt.Printf("%+v\n", cfg)
initPtr := flag.Bool("init", false, "a bool")
initPtr := flag.Bool("init", false, "generate master key and init database")
startPtr := flag.Bool("start", false, "start databunker service. User DATABUNKER_MASTERKEY environment variable.")
masterKeyPtr := flag.String("masterkey", "", "master key")
dbPtr := flag.String("db", "", "database file")
confPtr := flag.String("conf", "", "configuration file")
confPtr := flag.String("conf", "", "configuration file name")
flag.Parse()
var cfg Config
readFile(&cfg, confPtr)
readEnv(&cfg)
var err error
var masterKey []byte
if *initPtr {
db, _, _ := setupDB(dbPtr)
db.closeDB()
db.store.CloseDB()
os.Exit(0)
}
if dbExists(dbPtr) == false {
fmt.Printf("\ndatabunker.db file is missing.\n\n")
fmt.Println(`Run "./databunker -init" for the first time to init database.`)
if storage.DBExists(dbPtr) == false {
fmt.Printf("\nDatabase is not initialized.\n\n")
fmt.Println(`Run "databunker -init" for the first time to generate keys and init database.`)
fmt.Println("")
os.Exit(0)
}
if masterKeyPtr != nil && len(*masterKeyPtr) > 0 {
if len(*masterKeyPtr) != 48 {
fmt.Printf("Failed to decode master key: bad length\n")
os.Exit(0)
}
masterKey, err = hex.DecodeString(*masterKeyPtr)
if err != nil {
fmt.Printf("Failed to decode master key: %s\n", err)
os.Exit(0)
}
} else {
fmt.Println(`Masterkey is missing.`)
fmt.Println(`Run "./databunker -masterkey key"`)
if masterKeyPtr == nil || *startPtr == false {
fmt.Println(`Run "databunker -start" will load DATABUNKER_MASTERKEY environment variable.`)
fmt.Println(`For testing "databunker -masterkey MASTER_KEY_VALUE" can be used. Not recommended for production.`)
fmt.Println("")
os.Exit(0)
}
db, _ := newDB(masterKey, dbPtr)
db.initUserApps()
masterKey := getMasterKey(masterKeyPtr)
store, _ := storage.OpenDB(dbPtr)
store.InitUserApps()
hash := md5.Sum(masterKey)
db := &dbcon{store, masterKey, hash[:]}
e := mainEnv{db, cfg, make(chan struct{})}
e.dbCleanup()
fmt.Printf("host %s\n", cfg.Server.Host+":"+cfg.Server.Port)
@@ -424,7 +420,7 @@ func main() {
close(e.stopChan)
time.Sleep(1)
srv.Shutdown(context.TODO())
db.closeDB()
db.store.CloseDB()
}()
if _, err := os.Stat(cfg.Ssl.SslCertificate); !os.IsNotExist(err) {

View File

@@ -92,7 +92,7 @@ func init() {
fmt.Printf("error %s", err.Error())
}
rootToken = myRootToken
db.initUserApps()
db.store.InitUserApps()
var cfg2 Config
cfile := "../databunker.yaml"
readFile(&cfg2, &cfile)

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/fatih/structs"
"github.com/paranoidguy/databunker/src/storage"
"go.mongodb.org/mongo-driver/bson"
)
@@ -52,13 +53,13 @@ func (dbobj dbcon) createConsentRecord(userTOKEN string, mode string, usercode s
bdoc["lastmodifiedby"] = lastmodifiedby
if len(userTOKEN) > 0 {
// first check if this consent exists, then update
raw, err := dbobj.getRecord2(TblName.Consent, "token", userTOKEN, "brief", brief)
raw, err := dbobj.store.GetRecord2(storage.TblName.Consent, "token", userTOKEN, "brief", brief)
if err != nil {
fmt.Printf("error to find:%s", err)
return false, err
}
if raw != nil {
dbobj.updateRecord2(TblName.Consent, "token", userTOKEN, "brief", brief, &bdoc, nil)
dbobj.store.UpdateRecord2(storage.TblName.Consent, "token", userTOKEN, "brief", brief, &bdoc, nil)
if status != raw["status"].(string) {
// status changed
return true, nil
@@ -66,13 +67,13 @@ func (dbobj dbcon) createConsentRecord(userTOKEN string, mode string, usercode s
return false, nil
}
} else {
raw, err := dbobj.getRecord2(TblName.Consent, "who", usercode, "brief", brief)
raw, err := dbobj.store.GetRecord2(storage.TblName.Consent, "who", usercode, "brief", brief)
if err != nil {
fmt.Printf("error to find:%s", err)
return false, err
}
if raw != nil {
dbobj.updateRecord2(TblName.Consent, "who", usercode, "brief", brief, &bdoc, nil)
dbobj.store.UpdateRecord2(storage.TblName.Consent, "who", usercode, "brief", brief, &bdoc, nil)
if status != raw["status"].(string) {
// status changed
return true, nil
@@ -103,7 +104,7 @@ func (dbobj dbcon) createConsentRecord(userTOKEN string, mode string, usercode s
Lastmodifiedby: lastmodifiedby,
}
// in any case - insert record
_, err := dbobj.createRecord(TblName.Consent, structs.Map(ev))
_, err := dbobj.store.CreateRecord(storage.TblName.Consent, structs.Map(ev))
if err != nil {
fmt.Printf("error to insert record: %s\n", err)
return false, err
@@ -115,7 +116,7 @@ func (dbobj dbcon) createConsentRecord(userTOKEN string, mode string, usercode s
func (dbobj dbcon) linkConsentRecords(userTOKEN string, mode string, usercode string) error {
bdoc := bson.M{}
bdoc["token"] = userTOKEN
_, err := dbobj.updateRecord2(TblName.Consent, "token", "", "who", usercode, &bdoc, nil)
_, err := dbobj.store.UpdateRecord2(storage.TblName.Consent, "token", "", "who", usercode, &bdoc, nil)
return err
}
@@ -135,9 +136,9 @@ func (dbobj dbcon) withdrawConsentRecord(userTOKEN string, brief string, mode st
bdoc["lastmodifiedby"] = lastmodifiedby
if len(userTOKEN) > 0 {
fmt.Printf("%s %s\n", userTOKEN, brief)
dbobj.updateRecord2(TblName.Consent, "token", userTOKEN, "brief", brief, &bdoc, nil)
dbobj.store.UpdateRecord2(storage.TblName.Consent, "token", userTOKEN, "brief", brief, &bdoc, nil)
} else {
dbobj.updateRecord2(TblName.Consent, "who", usercode, "brief", brief, &bdoc, nil)
dbobj.store.UpdateRecord2(storage.TblName.Consent, "who", usercode, "brief", brief, &bdoc, nil)
}
return nil
}
@@ -145,7 +146,7 @@ func (dbobj dbcon) withdrawConsentRecord(userTOKEN string, brief string, mode st
// link consent to user?
func (dbobj dbcon) listConsentRecords(userTOKEN string) ([]byte, int, error) {
records, err := dbobj.getList(TblName.Consent, "token", userTOKEN, 0, 0)
records, err := dbobj.store.GetList(storage.TblName.Consent, "token", userTOKEN, 0, 0)
if err != nil {
return nil, 0, err
}
@@ -162,7 +163,7 @@ func (dbobj dbcon) listConsentRecords(userTOKEN string) ([]byte, int, error) {
}
func (dbobj dbcon) viewConsentRecord(userTOKEN string, brief string) ([]byte, error) {
record, err := dbobj.getRecord2(TblName.Consent, "token", userTOKEN, "brief", brief)
record, err := dbobj.store.GetRecord2(storage.TblName.Consent, "token", userTOKEN, "brief", brief)
if record == nil || err != nil {
return nil, err
}
@@ -176,14 +177,14 @@ func (dbobj dbcon) viewConsentRecord(userTOKEN string, brief string) ([]byte, er
func (dbobj dbcon) filterConsentRecords(brief string, offset int32, limit int32) ([]byte, int64, error) {
//var results []*auditEvent
count, err := dbobj.countRecords(TblName.Consent, "brief", brief)
count, err := dbobj.store.CountRecords(storage.TblName.Consent, "brief", brief)
if err != nil {
return nil, 0, err
}
if count == 0 {
return []byte("[]"), 0, err
}
records, err := dbobj.getList(TblName.Consent, "brief", brief, offset, limit)
records, err := dbobj.store.GetList(storage.TblName.Consent, "brief", brief, offset, limit)
if err != nil {
return nil, 0, err
}
@@ -201,7 +202,7 @@ func (dbobj dbcon) filterConsentRecords(brief string, offset int32, limit int32)
}
func (dbobj dbcon) getConsentBriefs() ([]byte, int64, error) {
records, err := dbobj.getUniqueList(TblName.Consent, "brief")
records, err := dbobj.store.GetUniqueList(storage.TblName.Consent, "brief")
if err != nil {
return nil, 0, err
}
@@ -223,7 +224,7 @@ func (dbobj dbcon) getConsentBriefs() ([]byte, int64, error) {
}
func (dbobj dbcon) expireConsentRecords(notifyURL string) error {
records, err := dbobj.getExpiring(TblName.Consent, "status", "yes")
records, err := dbobj.store.GetExpiring(storage.TblName.Consent, "status", "yes")
if err != nil {
return err
}
@@ -238,11 +239,11 @@ func (dbobj dbcon) expireConsentRecords(notifyURL string) error {
fmt.Printf("This consent record is expired: %s - %s\n", userTOKEN, brief)
if len(userTOKEN) > 0 {
fmt.Printf("%s %s\n", userTOKEN, brief)
dbobj.updateRecord2(TblName.Consent, "token", userTOKEN, "brief", brief, &bdoc, nil)
dbobj.store.UpdateRecord2(storage.TblName.Consent, "token", userTOKEN, "brief", brief, &bdoc, nil)
notifyConsentChange(notifyURL, brief, "expired", "token", userTOKEN)
} else {
usercode := rec["who"].(string)
dbobj.updateRecord2(TblName.Consent, "who", usercode, "brief", brief, &bdoc, nil)
dbobj.store.UpdateRecord2(storage.TblName.Consent, "who", usercode, "brief", brief, &bdoc, nil)
notifyConsentChange(notifyURL, brief, "expired", rec["mode"].(string), usercode)
}

View File

@@ -5,6 +5,7 @@ import (
"time"
uuid "github.com/hashicorp/go-uuid"
"github.com/paranoidguy/databunker/src/storage"
"go.mongodb.org/mongo-driver/bson"
)
@@ -44,18 +45,18 @@ func (dbobj dbcon) saveUserRequest(action string, token string, app string, brie
if len(brief) > 0 {
bdoc["brief"] = brief
}
_, err := dbobj.createRecord(TblName.Requests, &bdoc)
_, err := dbobj.store.CreateRecord(storage.TblName.Requests, &bdoc)
return rtoken, err
}
func (dbobj dbcon) getRequests(status string, offset int32, limit int32) ([]byte, int64, error) {
//var results []*auditEvent
count, err := dbobj.countRecords(TblName.Requests, "status", status)
count, err := dbobj.store.CountRecords(storage.TblName.Requests, "status", status)
if err != nil {
return nil, 0, err
}
var results []bson.M
records, err := dbobj.getList(TblName.Requests, "status", status, offset, limit)
records, err := dbobj.store.GetList(storage.TblName.Requests, "status", status, offset, limit)
if err != nil {
return nil, 0, err
}
@@ -77,7 +78,7 @@ func (dbobj dbcon) getRequests(status string, offset int32, limit int32) ([]byte
}
func (dbobj dbcon) getRequest(rtoken string) (bson.M, error) {
record, err := dbobj.getRecord(TblName.Requests, "rtoken", rtoken)
record, err := dbobj.store.GetRecord(storage.TblName.Requests, "rtoken", rtoken)
if err != nil {
return record, err
}
@@ -106,5 +107,5 @@ func (dbobj dbcon) updateRequestStatus(rtoken string, status string) {
bdoc := bson.M{}
bdoc["status"] = status
//fmt.Printf("op json: %s\n", update)
dbobj.updateRecord(TblName.Requests, "rtoken", rtoken, &bdoc)
dbobj.store.UpdateRecord(storage.TblName.Requests, "rtoken", rtoken, &bdoc)
}

View File

@@ -8,6 +8,7 @@ import (
"strings"
"github.com/julienschmidt/httprouter"
"github.com/paranoidguy/databunker/src/storage"
"go.mongodb.org/mongo-driver/bson"
)
@@ -81,7 +82,7 @@ func (e mainEnv) getUserSessions(w http.ResponseWriter, r *http.Request, ps http
mode := ps.ByName("mode")
if mode == "session" {
e.db.deleteExpired(TblName.Sessions, "session", address)
e.db.store.DeleteExpired(storage.TblName.Sessions, "session", address)
e.getSession(w, r, address)
return
}
@@ -114,7 +115,7 @@ func (e mainEnv) getUserSessions(w http.ResponseWriter, r *http.Request, ps http
if e.enforceAuth(w, r, event) == "" {
return
}
e.db.deleteExpired(TblName.Sessions, "token", userTOKEN)
e.db.store.DeleteExpired(storage.TblName.Sessions, "token", userTOKEN)
args := r.URL.Query()
var offset int32
var limit int32 = 10

View File

@@ -7,6 +7,7 @@ import (
"time"
uuid "github.com/hashicorp/go-uuid"
"github.com/paranoidguy/databunker/src/storage"
"go.mongodb.org/mongo-driver/bson"
)
@@ -39,7 +40,7 @@ func (dbobj dbcon) createSessionRecord(userTOKEN string, expiration string, data
bdoc["endtime"] = endtime
bdoc["when"] = now
bdoc["data"] = encodedStr
_, err = dbobj.createRecord(TblName.Sessions, bdoc)
_, err = dbobj.store.CreateRecord(storage.TblName.Sessions, bdoc)
if err != nil {
return "", err
}
@@ -47,7 +48,7 @@ func (dbobj dbcon) createSessionRecord(userTOKEN string, expiration string, data
}
func (dbobj dbcon) getUserSession(sessionUUID string) (int32, []byte, string, error) {
record, err := dbobj.getRecord(TblName.Sessions, "session", sessionUUID)
record, err := dbobj.store.GetRecord(storage.TblName.Sessions, "session", sessionUUID)
if err != nil {
return 0, nil, "", err
}
@@ -82,12 +83,12 @@ func (dbobj dbcon) getUserSessionsByToken(userTOKEN string, offset int32, limit
return nil, 0, err
}
count, err := dbobj.countRecords(TblName.Sessions, "token", userTOKEN)
count, err := dbobj.store.CountRecords(storage.TblName.Sessions, "token", userTOKEN)
if err != nil {
return nil, 0, err
}
records, err := dbobj.getList(TblName.Sessions, "token", userTOKEN, offset, limit)
records, err := dbobj.store.GetList(storage.TblName.Sessions, "token", userTOKEN, offset, limit)
if err != nil {
return nil, 0, err
}

View File

@@ -7,6 +7,7 @@ import (
"time"
uuid "github.com/hashicorp/go-uuid"
"github.com/paranoidguy/databunker/src/storage"
"go.mongodb.org/mongo-driver/bson"
)
@@ -57,7 +58,7 @@ func (dbobj dbcon) saveSharedRecord(userTOKEN string, fields string, expiration
if len(session) > 0 {
bdoc["session"] = session
}
_, err = dbobj.createRecord(TblName.Sharedrecords, bdoc)
_, err = dbobj.store.CreateRecord(storage.TblName.Sharedrecords, bdoc)
if err != nil {
return "", err
}
@@ -69,7 +70,7 @@ func (dbobj dbcon) getSharedRecord(recordUUID string) (checkRecordResult, error)
if isValidUUID(recordUUID) == false {
return result, errors.New("failed to authenticate")
}
record, err := dbobj.getRecord(TblName.Sharedrecords, "record", recordUUID)
record, err := dbobj.store.GetRecord(storage.TblName.Sharedrecords, "record", recordUUID)
if record == nil || err != nil {
return result, errors.New("failed to authenticate")
}

View File

@@ -1,19 +1,8 @@
package main
// github.com/mattn/go-sqlite3
// This project is using the following golang internal database:
// https://godoc.org/modernc.org/ql
// go build modernc.org/ql/ql
// go install modernc.org/ql/ql
// https://stackoverflow.com/questions/21986780/is-it-possible-to-retrieve-a-column-value-by-name-using-golang-database-sql
package storage
// https://stackoverflow.com/questions/21986780/is-it-possible-to-retrieve-a-column-value-by-name-using-golang-database-sql
import (
"crypto/md5"
"database/sql"
"fmt"
"log"
@@ -33,13 +22,38 @@ var (
knownApps []string
)
type dbcon struct {
db *sql.DB
masterKey []byte
hash []byte
// Tbl is used to store table id
type Tbl int
// listTbls used to store list of tables
type listTbls struct {
Users Tbl
Audit Tbl
Xtokens Tbl
Consent Tbl
Sessions Tbl
Requests Tbl
Sharedrecords Tbl
}
func dbExists(filepath *string) bool {
// TblName is enum of tables
var TblName = &listTbls{
Users: 0,
Audit: 1,
Xtokens: 2,
Consent: 3,
Sessions: 4,
Requests: 5,
Sharedrecords: 6,
}
// DBStorage struct is used to store database object
type DBStorage struct {
db *sql.DB
}
// DBExists function checks if database exists
func DBExists(filepath *string) bool {
dbfile := "./databunker.db"
if filepath != nil {
if len(*filepath) > 0 {
@@ -52,8 +66,8 @@ func dbExists(filepath *string) bool {
return true
}
func newDB(masterKey []byte, filepath *string) (dbcon, error) {
dbobj := dbcon{nil, nil, nil}
// OpenDB function opens the database
func OpenDB(filepath *string) (DBStorage, error) {
dbfile := "./databunker.db"
if filepath != nil {
if len(*filepath) > 0 {
@@ -90,8 +104,7 @@ func newDB(masterKey []byte, filepath *string) (dbcon, error) {
if err != nil {
log.Fatalf("Error on vacuum database command")
}
hash := md5.Sum(masterKey)
dbobj = dbcon{db, masterKey, hash[:]}
dbobj := DBStorage{db}
// load all table names
q := "select name from sqlite_master where type ='table'"
@@ -111,7 +124,8 @@ func newDB(masterKey []byte, filepath *string) (dbcon, error) {
return dbobj, nil
}
func (dbobj dbcon) initDB() error {
// InitDB function creates tables and indexes
func (dbobj DBStorage) InitDB() error {
var err error
if err = initUsers(dbobj.db); err != nil {
return err
@@ -137,11 +151,13 @@ func (dbobj dbcon) initDB() error {
return nil
}
func (dbobj dbcon) closeDB() {
// CloseDB function closes the open database
func (dbobj DBStorage) CloseDB() {
dbobj.db.Close()
}
func (dbobj dbcon) backupDB(w http.ResponseWriter) {
// BackupDB function backups existing databsae and prints database structure to http.ResponseWriter
func (dbobj DBStorage) BackupDB(w http.ResponseWriter) {
err := sqlite3dump.DumpDB(dbobj.db, w)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
@@ -149,7 +165,8 @@ func (dbobj dbcon) backupDB(w http.ResponseWriter) {
}
}
func (dbobj dbcon) initUserApps() error {
// InitUserApps initialises list of databases.
func (dbobj DBStorage) InitUserApps() error {
return nil
}
@@ -283,7 +300,8 @@ func getTable(t Tbl) string {
return "users"
}
func (dbobj dbcon) createRecordInTable(tbl string, data interface{}) (int, error) {
// CreateRecordInTable creates new record
func (dbobj DBStorage) CreateRecordInTable(tbl string, data interface{}) (int, error) {
fields, values := decodeFieldsValues(data)
valuesInQ := "$1"
for idx := range values {
@@ -309,13 +327,15 @@ func (dbobj dbcon) createRecordInTable(tbl string, data interface{}) (int, error
return 1, nil
}
func (dbobj dbcon) createRecord(t Tbl, data interface{}) (int, error) {
// CreateRecord creates new record
func (dbobj DBStorage) CreateRecord(t Tbl, data interface{}) (int, error) {
//if reflect.TypeOf(value) == reflect.TypeOf("string")
tbl := getTable(t)
return dbobj.createRecordInTable(tbl, data)
return dbobj.CreateRecordInTable(tbl, data)
}
func (dbobj dbcon) countRecords(t Tbl, keyName string, keyValue string) (int64, error) {
// CountRecords returns number of records that match filter
func (dbobj DBStorage) CountRecords(t Tbl, keyName string, keyValue string) (int64, error) {
tbl := getTable(t)
q := "select count(*) from " + tbl + " WHERE " + escapeName(keyName) + "=$1"
fmt.Printf("q: %s\n", q)
@@ -338,18 +358,21 @@ func (dbobj dbcon) countRecords(t Tbl, keyName string, keyValue string) (int64,
return int64(count), nil
}
func (dbobj dbcon) updateRecord(t Tbl, keyName string, keyValue string, bdoc *bson.M) (int64, error) {
// UpdateRecord updates database record
func (dbobj DBStorage) UpdateRecord(t Tbl, keyName string, keyValue string, bdoc *bson.M) (int64, error) {
table := getTable(t)
filter := escapeName(keyName) + "=\"" + keyValue + "\""
return dbobj.updateRecordInTableDo(table, filter, bdoc, nil)
}
func (dbobj dbcon) updateRecordInTable(table string, keyName string, keyValue string, bdoc *bson.M) (int64, error) {
// UpdateRecordInTable updates database record
func (dbobj DBStorage) UpdateRecordInTable(table string, keyName string, keyValue string, bdoc *bson.M) (int64, error) {
filter := escapeName(keyName) + "=\"" + keyValue + "\""
return dbobj.updateRecordInTableDo(table, filter, bdoc, nil)
}
func (dbobj dbcon) updateRecord2(t Tbl, keyName string, keyValue string,
// UpdateRecord2 updates database record
func (dbobj DBStorage) UpdateRecord2(t Tbl, keyName string, keyValue string,
keyName2 string, keyValue2 string, bdoc *bson.M, bdel *bson.M) (int64, error) {
table := getTable(t)
filter := escapeName(keyName) + "=\"" + keyValue + "\" AND " +
@@ -357,14 +380,15 @@ func (dbobj dbcon) updateRecord2(t Tbl, keyName string, keyValue string,
return dbobj.updateRecordInTableDo(table, filter, bdoc, bdel)
}
func (dbobj dbcon) updateRecordInTable2(table string, keyName string,
// UpdateRecordInTable2 updates database record
func (dbobj DBStorage) UpdateRecordInTable2(table string, keyName string,
keyValue string, keyName2 string, keyValue2 string, bdoc *bson.M, bdel *bson.M) (int64, error) {
filter := escapeName(keyName) + "=\"" + keyValue + "\" AND " +
escapeName(keyName2) + "=\"" + keyValue2 + "\""
return dbobj.updateRecordInTableDo(table, filter, bdoc, bdel)
}
func (dbobj dbcon) updateRecordInTableDo(table string, filter string, bdoc *bson.M, bdel *bson.M) (int64, error) {
func (dbobj DBStorage) updateRecordInTableDo(table string, filter string, bdoc *bson.M, bdel *bson.M) (int64, error) {
op, values := decodeForUpdate(bdoc, bdel)
q := "update " + table + " SET " + op + " WHERE " + filter
fmt.Printf("q: %s\n", q)
@@ -385,7 +409,8 @@ func (dbobj dbcon) updateRecordInTableDo(table string, filter string, bdoc *bson
return num, err
}
func (dbobj dbcon) getRecord(t Tbl, keyName string, keyValue string) (bson.M, error) {
// GetRecord returns specific record from database
func (dbobj DBStorage) GetRecord(t Tbl, keyName string, keyValue string) (bson.M, error) {
table := getTable(t)
q := "select * from " + table + " WHERE " + escapeName(keyName) + "=$1"
values := make([]interface{}, 0)
@@ -393,14 +418,16 @@ func (dbobj dbcon) getRecord(t Tbl, keyName string, keyValue string) (bson.M, er
return dbobj.getRecordInTableDo(q, values)
}
func (dbobj dbcon) getRecordInTable(table string, keyName string, keyValue string) (bson.M, error) {
// GetRecordInTable returns specific record from database
func (dbobj DBStorage) GetRecordInTable(table string, keyName string, keyValue string) (bson.M, error) {
q := "select * from " + table + " WHERE " + escapeName(keyName) + "=$1"
values := make([]interface{}, 0)
values = append(values, keyValue)
return dbobj.getRecordInTableDo(q, values)
}
func (dbobj dbcon) getRecord2(t Tbl, keyName string, keyValue string,
// GetRecord2 returns specific record from database
func (dbobj DBStorage) GetRecord2(t Tbl, keyName string, keyValue string,
keyName2 string, keyValue2 string) (bson.M, error) {
table := getTable(t)
q := "select * from " + table + " WHERE " + escapeName(keyName) + "=$1 AND " +
@@ -411,7 +438,7 @@ func (dbobj dbcon) getRecord2(t Tbl, keyName string, keyValue string,
return dbobj.getRecordInTableDo(q, values)
}
func (dbobj dbcon) getRecordInTableDo(q string, values []interface{}) (bson.M, error) {
func (dbobj DBStorage) getRecordInTableDo(q string, values []interface{}) (bson.M, error) {
fmt.Printf("query: %s\n", q)
tx, err := dbobj.db.Begin()
@@ -490,14 +517,14 @@ func (dbobj dbcon) getRecordInTableDo(q string, values []interface{}) (bson.M, e
return recBson, nil
}
/*
func (dbobj dbcon) deleteRecord(t Tbl, keyName string, keyValue string) (int64, error) {
// DeleteRecord deletes record in database
func (dbobj DBStorage) DeleteRecord(t Tbl, keyName string, keyValue string) (int64, error) {
tbl := getTable(t)
return dbobj.deleteRecordInTable(tbl, keyName, keyValue)
return dbobj.DeleteRecordInTable(tbl, keyName, keyValue)
}
*/
func (dbobj dbcon) deleteRecordInTable(table string, keyName string, keyValue string) (int64, error) {
// DeleteRecordInTable deletes record in database
func (dbobj DBStorage) DeleteRecordInTable(table string, keyName string, keyValue string) (int64, error) {
q := "delete from " + table + " WHERE " + escapeName(keyName) + "=$1"
fmt.Printf("q: %s\n", q)
@@ -517,12 +544,13 @@ func (dbobj dbcon) deleteRecordInTable(table string, keyName string, keyValue st
return num, err
}
func (dbobj dbcon) deleteRecord2(t Tbl, keyName string, keyValue string, keyName2 string, keyValue2 string) (int64, error) {
// DeleteRecord2 deletes record in database
func (dbobj DBStorage) DeleteRecord2(t Tbl, keyName string, keyValue string, keyName2 string, keyValue2 string) (int64, error) {
tbl := getTable(t)
return dbobj.deleteRecordInTable2(tbl, keyName, keyValue, keyName2, keyValue2)
}
func (dbobj dbcon) deleteRecordInTable2(table string, keyName string, keyValue string, keyName2 string, keyValue2 string) (int64, error) {
func (dbobj DBStorage) deleteRecordInTable2(table string, keyName string, keyValue string, keyName2 string, keyValue2 string) (int64, error) {
q := "delete from " + table + " WHERE " + escapeName(keyName) + "=$1 AND " +
escapeName(keyName2) + "=$2"
fmt.Printf("q: %s\n", q)
@@ -544,14 +572,14 @@ func (dbobj dbcon) deleteRecordInTable2(table string, keyName string, keyValue s
}
/*
func (dbobj dbcon) deleteDuplicate2(t Tbl, keyName string, keyValue string, keyName2 string, keyValue2 string) (int64, error) {
func (dbobj DBStorage) deleteDuplicate2(t Tbl, keyName string, keyValue string, keyName2 string, keyValue2 string) (int64, error) {
tbl := getTable(t)
return dbobj.deleteDuplicateInTable2(tbl, keyName, keyValue, keyName2, keyValue2)
}
*/
/*
func (dbobj dbcon) deleteDuplicateInTable2(table string, keyName string, keyValue string, keyName2 string, keyValue2 string) (int64, error) {
func (dbobj DBStorage) deleteDuplicateInTable2(table string, keyName string, keyValue string, keyName2 string, keyValue2 string) (int64, error) {
q := "delete from " + table + " where " + escapeName(keyName) + "=$1 AND " +
escapeName(keyName2) + "=$2 AND rowid not in " +
"(select min(rowid) from " + table + " where " + escapeName(keyName) + "=$3 AND " +
@@ -575,7 +603,8 @@ func (dbobj dbcon) deleteDuplicateInTable2(table string, keyName string, keyValu
}
*/
func (dbobj dbcon) deleteExpired0(t Tbl, expt int32) (int64, error) {
// DeleteExpired0 deletes expired records in database
func (dbobj DBStorage) DeleteExpired0(t Tbl, expt int32) (int64, error) {
table := getTable(t)
now := int32(time.Now().Unix())
q := fmt.Sprintf("delete from %s WHERE `when`>0 AND `when`<%d", table, now-expt)
@@ -598,7 +627,8 @@ func (dbobj dbcon) deleteExpired0(t Tbl, expt int32) (int64, error) {
return num, err
}
func (dbobj dbcon) deleteExpired(t Tbl, keyName string, keyValue string) (int64, error) {
// DeleteExpired deletes expired records in database
func (dbobj DBStorage) DeleteExpired(t Tbl, keyName string, keyValue string) (int64, error) {
table := getTable(t)
q := "delete from " + table + " WHERE endtime>0 AND endtime<$1 AND " + escapeName(keyName) + "=$2"
fmt.Printf("q: %s\n", q)
@@ -620,7 +650,8 @@ func (dbobj dbcon) deleteExpired(t Tbl, keyName string, keyValue string) (int64,
return num, err
}
func (dbobj dbcon) cleanupRecord(t Tbl, keyName string, keyValue string, data interface{}) (int64, error) {
// CleanupRecord nullifies specific feilds in records in database
func (dbobj DBStorage) CleanupRecord(t Tbl, keyName string, keyValue string, data interface{}) (int64, error) {
tbl := getTable(t)
cleanup := decodeForCleanup(data)
q := "update " + tbl + " SET " + cleanup + " WHERE " + escapeName(keyName) + "=$1"
@@ -642,7 +673,8 @@ func (dbobj dbcon) cleanupRecord(t Tbl, keyName string, keyValue string, data in
return num, err
}
func (dbobj dbcon) getExpiring(t Tbl, keyName string, keyValue string) ([]bson.M, error) {
// GetExpiring get records that are expiring
func (dbobj DBStorage) GetExpiring(t Tbl, keyName string, keyValue string) ([]bson.M, error) {
table := getTable(t)
now := int32(time.Now().Unix())
q := fmt.Sprintf("select * from %s WHERE endtime>0 AND endtime<%d AND %s=$1", table, now, escapeName(keyName))
@@ -650,7 +682,8 @@ func (dbobj dbcon) getExpiring(t Tbl, keyName string, keyValue string) ([]bson.M
return dbobj.getListDo(q, keyValue)
}
func (dbobj dbcon) getUniqueList(t Tbl, keyName string) ([]bson.M, error) {
// GetUniqueList returns a unique list of values from specific column in database
func (dbobj DBStorage) GetUniqueList(t Tbl, keyName string) ([]bson.M, error) {
table := getTable(t)
keyName = escapeName(keyName)
q := "select distinct " + keyName + " from " + table + " ORDER BY " + keyName
@@ -658,7 +691,8 @@ func (dbobj dbcon) getUniqueList(t Tbl, keyName string) ([]bson.M, error) {
return dbobj.getListDo(q, "")
}
func (dbobj dbcon) getList(t Tbl, keyName string, keyValue string, start int32, limit int32) ([]bson.M, error) {
// GetList is used to return list of rows. It can be used to return values using pager.
func (dbobj DBStorage) GetList(t Tbl, keyName string, keyValue string, start int32, limit int32) ([]bson.M, error) {
table := getTable(t)
if limit > 100 {
limit = 100
@@ -675,7 +709,7 @@ func (dbobj dbcon) getList(t Tbl, keyName string, keyValue string, start int32,
return dbobj.getListDo(q, keyValue)
}
func (dbobj dbcon) getListDo(q string, keyValue string) ([]bson.M, error) {
func (dbobj DBStorage) getListDo(q string, keyValue string) ([]bson.M, error) {
tx, err := dbobj.db.Begin()
if err != nil {
return nil, err
@@ -753,11 +787,13 @@ func (dbobj dbcon) getListDo(q string, keyValue string) ([]bson.M, error) {
return results, nil
}
func (dbobj dbcon) getAllTables() ([]string, error) {
// GetAllTables returns all tables that exists in database
func (dbobj DBStorage) GetAllTables() ([]string, error) {
return knownApps, nil
}
func (dbobj dbcon) validateNewApp(appName string) bool {
// ValidateNewApp function check if app name can be part of the table name
func (dbobj DBStorage) ValidateNewApp(appName string) bool {
if contains(knownApps, appName) == true {
return true
}
@@ -785,7 +821,8 @@ func execQueries(db *sql.DB, queries []string) error {
return nil
}
func (dbobj dbcon) indexNewApp(appName string) {
// IndexNewApp creates a new app table and creates indexes for it.
func (dbobj DBStorage) IndexNewApp(appName string) {
if contains(knownApps, appName) == false {
// it is a new app, create an index
fmt.Printf("This is a new app, creating table & index for: %s\n", appName)
@@ -922,3 +959,13 @@ func initSessions(db *sql.DB) error {
`CREATE INDEX sessions_session ON sessions (session);`}
return execQueries(db, queries)
}
func contains(slice []string, item string) bool {
set := make(map[string]struct{}, len(slice))
for _, s := range slice {
set[s] = struct{}{}
}
_, ok := set[item]
return ok
}

View File

@@ -25,7 +25,7 @@ func (e mainEnv) userappNew(w http.ResponseWriter, r *http.Request, ps httproute
returnError(w, r, "bad appname", 405, nil, event)
return
}
if e.db.validateNewApp("app_"+appName) == false {
if e.db.store.ValidateNewApp("app_"+appName) == false {
returnError(w, r, "db limitation", 405, nil, event)
return
}

View File

@@ -14,7 +14,7 @@ import (
func (dbobj dbcon) getUserApp(userTOKEN string, appName string) ([]byte, error) {
record, err := dbobj.getRecordInTable("app_"+appName, "token", userTOKEN)
record, err := dbobj.store.GetRecordInTable("app_"+appName, "token", userTOKEN)
if err != nil {
return nil, err
}
@@ -31,7 +31,7 @@ func (dbobj dbcon) createAppRecord(jsonData []byte, userTOKEN string, appName st
if err != nil {
return userTOKEN, err
}
dbobj.indexNewApp("app_" + appName)
dbobj.store.IndexNewApp("app_" + appName)
//var bdoc interface{}
bdoc := bson.M{}
@@ -46,14 +46,14 @@ func (dbobj dbcon) createAppRecord(jsonData []byte, userTOKEN string, appName st
event.Record = userTOKEN
}
//fmt.Println("creating new app")
record, err := dbobj.getRecordInTable("app_"+appName, "token", userTOKEN)
record, err := dbobj.store.GetRecordInTable("app_"+appName, "token", userTOKEN)
if err != nil {
return userTOKEN, err
}
if record != nil {
_, err = dbobj.updateRecordInTable("app_"+appName, "token", userTOKEN, &bdoc)
_, err = dbobj.store.UpdateRecordInTable("app_"+appName, "token", userTOKEN, &bdoc)
} else {
_, err = dbobj.createRecordInTable("app_"+appName, bdoc)
_, err = dbobj.store.CreateRecordInTable("app_"+appName, bdoc)
}
return userTOKEN, err
}
@@ -72,7 +72,7 @@ func (dbobj dbcon) updateAppRecord(jsonDataPatch []byte, userTOKEN string, appNa
return userTOKEN, err
}
record, err := dbobj.getRecordInTable("app_"+appName, "token", userTOKEN)
record, err := dbobj.store.GetRecordInTable("app_"+appName, "token", userTOKEN)
if err != nil {
return userTOKEN, err
}
@@ -113,7 +113,7 @@ func (dbobj dbcon) updateAppRecord(jsonDataPatch []byte, userTOKEN string, appNa
// here I add md5 of the original record to filter
// to make sure this record was not change by other thread
fmt.Println("update user app")
result, err := dbobj.updateRecordInTable2("app_"+appName, "token", userTOKEN, "md5", sig, &bdoc, nil)
result, err := dbobj.store.UpdateRecordInTable2("app_"+appName, "token", userTOKEN, "md5", sig, &bdoc, nil)
if err != nil {
return userTOKEN, err
}
@@ -138,14 +138,14 @@ func (dbobj dbcon) listUserApps(userTOKEN string) ([]byte, error) {
// not found
return nil, err
}
allCollections, err := dbobj.getAllTables()
allCollections, err := dbobj.store.GetAllTables()
if err != nil {
return nil, err
}
var result []string
for _, colName := range allCollections {
if strings.HasPrefix(colName, "app_") {
record, err := dbobj.getRecordInTable(colName, "token", userTOKEN)
record, err := dbobj.store.GetRecordInTable(colName, "token", userTOKEN)
if err != nil {
return nil, err
}
@@ -164,7 +164,7 @@ func (dbobj dbcon) listUserApps(userTOKEN string) ([]byte, error) {
func (dbobj dbcon) listAllAppsOnly() ([]string, error) {
//fmt.Println("dump list of collections")
allCollections, err := dbobj.getAllTables()
allCollections, err := dbobj.store.GetAllTables()
if err != nil {
return nil, err
}
@@ -179,7 +179,7 @@ func (dbobj dbcon) listAllAppsOnly() ([]string, error) {
func (dbobj dbcon) listAllApps() ([]byte, error) {
//fmt.Println("dump list of collections")
allCollections, err := dbobj.getAllTables()
allCollections, err := dbobj.store.GetAllTables()
if err != nil {
return nil, err
}

View File

@@ -6,6 +6,7 @@ import (
"net/http"
"github.com/julienschmidt/httprouter"
"github.com/paranoidguy/databunker/src/storage"
)
func (e mainEnv) userNew(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
@@ -75,17 +76,17 @@ func (e mainEnv) userNew(w http.ResponseWriter, r *http.Request, ps httprouter.P
}
if len(parsedData.emailIdx) > 0 && len(parsedData.phoneIdx) > 0 {
// delete duplicate consent records for user
records, _ := e.db.getList(TblName.Consent, "who", parsedData.emailIdx, 0, 0)
records, _ := e.db.store.GetList(storage.TblName.Consent, "who", parsedData.emailIdx, 0, 0)
var briefCodes []string
for _, val := range records {
//fmt.Printf("adding brief code: %s\n", val["brief"].(string))
briefCodes = append(briefCodes, val["brief"].(string))
}
records, _ = e.db.getList(TblName.Consent, "who", parsedData.phoneIdx, 0, 0)
records, _ = e.db.store.GetList(storage.TblName.Consent, "who", parsedData.phoneIdx, 0, 0)
for _, val := range records {
//fmt.Printf("XXX checking brief code for duplicates: %s\n", val["brief"].(string))
if contains(briefCodes, val["brief"].(string)) == true {
e.db.deleteRecord2(TblName.Consent, "token", userTOKEN, "who", parsedData.phoneIdx)
e.db.store.DeleteRecord2(storage.TblName.Consent, "token", userTOKEN, "who", parsedData.phoneIdx)
}
}
}

View File

@@ -11,6 +11,7 @@ import (
jsonpatch "github.com/evanphx/json-patch"
uuid "github.com/hashicorp/go-uuid"
"github.com/paranoidguy/databunker/src/storage"
"go.mongodb.org/mongo-driver/bson"
)
@@ -56,7 +57,7 @@ func (dbobj dbcon) createUserRecord(parsedData userJSON, event *auditEvent) (str
event.Record = userTOKEN
}
//fmt.Println("creating new user")
_, err = dbobj.createRecord(TblName.Users, bdoc)
_, err = dbobj.store.CreateRecord(storage.TblName.Users, bdoc)
if err != nil {
fmt.Printf("error in create!\n")
return "", err
@@ -72,7 +73,7 @@ func (dbobj dbcon) generateTempLoginCode(userTOKEN string) int32 {
expired := int32(time.Now().Unix()) + 60
bdoc["tempcodeexp"] = expired
//fmt.Printf("op json: %s\n", update)
dbobj.updateRecord(TblName.Users, "token", userTOKEN, &bdoc)
dbobj.store.UpdateRecord(storage.TblName.Users, "token", userTOKEN, &bdoc)
return rnd
}
@@ -84,7 +85,7 @@ func (dbobj dbcon) generateDemoLoginCode(userTOKEN string) int32 {
expired := int32(time.Now().Unix()) + 60
bdoc["tempcodeexp"] = expired
//fmt.Printf("op json: %s\n", update)
dbobj.updateRecord(TblName.Users, "token", userTOKEN, &bdoc)
dbobj.store.UpdateRecord(storage.TblName.Users, "token", userTOKEN, &bdoc)
return rnd
}
@@ -239,7 +240,7 @@ func (dbobj dbcon) updateUserRecordDo(jsonDataPatch []byte, userTOKEN string, ev
//filter2 := bson.D{{"token", userTOKEN}, {"md5", sig}}
//fmt.Printf("op json: %s\n", update)
result, err := dbobj.updateRecord2(TblName.Users, "token", userTOKEN, "md5", sig, &bdoc, &bdel)
result, err := dbobj.store.UpdateRecord2(storage.TblName.Users, "token", userTOKEN, "md5", sig, &bdoc, &bdel)
if err != nil {
return nil, nil, false, err
}
@@ -257,7 +258,7 @@ func (dbobj dbcon) updateUserRecordDo(jsonDataPatch []byte, userTOKEN string, ev
}
func (dbobj dbcon) lookupUserRecord(userTOKEN string) (bson.M, error) {
return dbobj.getRecord(TblName.Users, "token", userTOKEN)
return dbobj.store.GetRecord(storage.TblName.Users, "token", userTOKEN)
}
func (dbobj dbcon) lookupUserRecordByIndex(indexName string, indexValue string, conf Config) (bson.M, error) {
@@ -271,7 +272,7 @@ func (dbobj dbcon) lookupUserRecordByIndex(indexName string, indexValue string,
}
idxStringHashHex := hashString(dbobj.hash, indexValue)
fmt.Printf("loading by %s, value: %s\n", indexName, indexValue)
return dbobj.getRecord(TblName.Users, indexName+"idx", idxStringHashHex)
return dbobj.store.GetRecord(storage.TblName.Users, indexName+"idx", idxStringHashHex)
}
func (dbobj dbcon) getUser(userTOKEN string) ([]byte, error) {
@@ -341,11 +342,11 @@ func (dbobj dbcon) deleteUserRecord(userTOKEN string) (bool, error) {
// delete all user app records
for _, appName := range userApps {
appNameFull := "app_" + appName
dbobj.deleteRecordInTable(appNameFull, "token", userTOKEN)
dbobj.store.DeleteRecordInTable(appNameFull, "token", userTOKEN)
}
//delete in audit
dbobj.deleteRecordInTable("audit", "record", userTOKEN)
dbobj.deleteRecordInTable("sessions", "token", userTOKEN)
dbobj.store.DeleteRecord(storage.TblName.Audit, "record", userTOKEN)
dbobj.store.DeleteRecord(storage.TblName.Sessions, "token", userTOKEN)
// cleanup user record
bdel := bson.M{}
bdel["data"] = ""
@@ -353,7 +354,7 @@ func (dbobj dbcon) deleteUserRecord(userTOKEN string) (bool, error) {
bdel["loginidx"] = ""
bdel["emailidx"] = ""
bdel["phoneidx"] = ""
result, err := dbobj.cleanupRecord(TblName.Users, "token", userTOKEN, bdel)
result, err := dbobj.store.CleanupRecord(storage.TblName.Users, "token", userTOKEN, bdel)
if err != nil {
return false, err
}
@@ -375,7 +376,7 @@ func (dbobj dbcon) wipeRecord(userTOKEN string) (bool, error) {
dbobj.deleteRecordInTable(appNameFull, "token", userTOKEN)
}
// delete user record
result, err := dbobj.deleteRecord(TblName.Users, "token", userTOKEN)
result, err := dbobj.store.DeleteRecord(storage.TblName.Users, "token", userTOKEN)
if err != nil {
return false, err
}

View File

@@ -6,13 +6,14 @@ import (
"time"
uuid "github.com/hashicorp/go-uuid"
"github.com/paranoidguy/databunker/src/storage"
"go.mongodb.org/mongo-driver/bson"
)
var rootXTOKEN string
func (dbobj dbcon) getRootXtoken() (string, error) {
record, err := dbobj.getRecord(TblName.Xtokens, "type", "root")
record, err := dbobj.store.GetRecord(storage.TblName.Xtokens, "type", "root")
if record == nil || err != nil {
return "", err
}
@@ -34,7 +35,7 @@ func (dbobj dbcon) createRootXtoken() (string, error) {
bdoc := bson.M{}
bdoc["xtoken"] = hashString(dbobj.hash, rootToken)
bdoc["type"] = "root"
_, err = dbobj.createRecord(TblName.Xtokens, bdoc)
_, err = dbobj.store.CreateRecord(storage.TblName.Xtokens, bdoc)
if err != nil {
return rootToken, err
}
@@ -59,7 +60,7 @@ func (dbobj dbcon) generateUserLoginXtoken(userTOKEN string) (string, error) {
bdoc["xtoken"] = hashString(dbobj.hash, tokenUUID)
bdoc["type"] = "login"
bdoc["endtime"] = expired
_, err = dbobj.createRecord(TblName.Xtokens, bdoc)
_, err = dbobj.store.CreateRecord(storage.TblName.Xtokens, bdoc)
return tokenUUID, err
}
@@ -74,7 +75,7 @@ func (dbobj dbcon) checkXtoken(xtokenUUID string) bool {
fmt.Println("It is a root token")
return true
}
record, err := dbobj.getRecord(TblName.Xtokens, "xtoken", xtokenHashed)
record, err := dbobj.store.GetRecord(storage.TblName.Xtokens, "xtoken", xtokenHashed)
if record == nil || err != nil {
return false
}
@@ -99,7 +100,7 @@ func (dbobj dbcon) checkUserAuthXToken(xtokenUUID string) (tokenAuthResult, erro
result.name = "root"
return result, nil
}
record, err := dbobj.getRecord(TblName.Xtokens, "xtoken", xtokenHashed)
record, err := dbobj.store.GetRecord(storage.TblName.Xtokens, "xtoken", xtokenHashed)
if record == nil || err != nil {
return result, errors.New("failed to authenticate")
}