diff --git a/build.sh b/build.sh index f109eda..50a94ed 100755 --- a/build.sh +++ b/build.sh @@ -1,5 +1,5 @@ # build without debug -go build -ldflags "-w" -o databunker ./src/bunker.go ./src/qldb.go ./src/xtokens_db.go \ +go build -ldflags "-w" -o databunker ./src/bunker.go ./src/xtokens_db.go \ ./src/utils.go ./src/cryptor.go ./src/notify.go \ ./src/audit_db.go ./src/audit_api.go \ ./src/sms.go ./src/email.go \ diff --git a/src/audit_db.go b/src/audit_db.go index 9dce244..3842893 100644 --- a/src/audit_db.go +++ b/src/audit_db.go @@ -8,6 +8,7 @@ import ( "time" uuid "github.com/hashicorp/go-uuid" + "github.com/paranoidguy/databunker/src/storage" "go.mongodb.org/mongo-driver/bson" ) @@ -37,7 +38,7 @@ func auditApp(title string, record string, app string, mode string, address stri return &auditEvent{Title: title, Mode: mode, Who: address, Record: record, Status: "ok", When: int32(time.Now().Unix())} } -func (event auditEvent) submit(db dbcon) { +func (event auditEvent) submit(db *dbcon) { //fmt.Println("submit event to audit!!!!!!!!!!") /* bdoc, err := bson.Marshal(event) @@ -86,12 +87,12 @@ func (event auditEvent) submit(db dbcon) { if len(event.After) > 0 { bdoc["after"] = event.After } - db.createRecord(TblName.Audit, &bdoc) + db.store.CreateRecord(storage.TblName.Audit, &bdoc) } func (dbobj dbcon) getAuditEvents(userTOKEN string, offset int32, limit int32) ([]byte, int64, error) { //var results []*auditEvent - count, err := dbobj.countRecords(TblName.Audit, "record", userTOKEN) + count, err := dbobj.store.CountRecords(storage.TblName.Audit, "record", userTOKEN) if err != nil { return nil, 0, err } @@ -99,7 +100,7 @@ func (dbobj dbcon) getAuditEvents(userTOKEN string, offset int32, limit int32) ( return []byte("[]"), 0, err } var results []bson.M - records, err := dbobj.getList(TblName.Audit, "record", userTOKEN, offset, limit) + records, err := dbobj.store.GetList(storage.TblName.Audit, "record", userTOKEN, offset, limit) if err != nil { return nil, 0, err } @@ -130,7 +131,7 @@ func (dbobj dbcon) getAuditEvents(userTOKEN string, offset int32, limit int32) ( func (dbobj dbcon) getAuditEvent(atoken string) (string, []byte, error) { //var results []*auditEvent - record, err := dbobj.getRecord(TblName.Audit, "atoken", atoken) + record, err := dbobj.store.GetRecord(storage.TblName.Audit, "atoken", atoken) if err != nil { return "", nil, err } diff --git a/src/bunker.go b/src/bunker.go index 992f96c..b16cad2 100644 --- a/src/bunker.go +++ b/src/bunker.go @@ -4,6 +4,7 @@ package main import ( "context" + "crypto/md5" "encoding/hex" "encoding/json" "flag" @@ -20,34 +21,16 @@ import ( "github.com/gobuffalo/packr" "github.com/julienschmidt/httprouter" "github.com/kelseyhightower/envconfig" + "github.com/paranoidguy/databunker/src/storage" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" yaml "gopkg.in/yaml.v2" ) -// Tbl is used to store table id -type Tbl int - -// listTbls used to store list of tables -type listTbls struct { - Users Tbl - Audit Tbl - Xtokens Tbl - Consent Tbl - Sessions Tbl - Requests Tbl - Sharedrecords Tbl -} - -// TblName is enum of tables -var TblName = &listTbls{ - Users: 0, - Audit: 1, - Xtokens: 2, - Consent: 3, - Sessions: 4, - Requests: 5, - Sharedrecords: 6, +type dbcon struct { + store storage.DBStorage + masterKey []byte + hash []byte } // Config is u sed to store application configuration @@ -99,7 +82,7 @@ type Config struct { // mainEnv struct stores global structures type mainEnv struct { - db dbcon + db *dbcon conf Config stopChan chan struct{} } @@ -164,7 +147,7 @@ func (e mainEnv) backupDB(w http.ResponseWriter, r *http.Request, ps httprouter. return } w.WriteHeader(200) - e.db.backupDB(w) + e.db.store.BackupDB(w) } // setupRouter() setup HTTP Router object. @@ -260,7 +243,7 @@ func readFile(cfg *Config, filepath *string) error { confFile = *filepath } } - fmt.Printf("Databunker conf file is: %s\n", confFile) + fmt.Printf("Databunker configuration file is: %s\n", confFile) f, err := os.Open(confFile) if err != nil { return err @@ -284,7 +267,7 @@ func (e mainEnv) dbCleanupDo() { log.Printf("db cleanup timeout\n") exp, _ := parseExpiration0(e.conf.Policy.MaxAuditRetentionPeriod) if exp > 0 { - e.db.deleteExpired0(TblName.Audit, exp) + e.db.store.DeleteExpired0(storage.TblName.Audit, exp) } notifyURL := e.conf.Notification.ConsentNotificationURL e.db.expireConsentRecords(notifyURL) @@ -344,17 +327,19 @@ func logRequest(handler http.Handler) http.Handler { }) } -func setupDB(dbPtr *string) (dbcon, string, error) { +func setupDB(dbPtr *string) (*dbcon, string, error) { fmt.Printf("\nDatabunker init\n\n") masterKey, err := generateMasterKey() + hash := md5.Sum(masterKey) fmt.Printf("Master key: %x\n\n", masterKey) fmt.Printf("Init databunker.db\n\n") - db, _ := newDB(masterKey, dbPtr) - err = db.initDB() + store, _ := storage.NewDBStorage(dbPtr) + err = store.InitDB() if err != nil { //log.Panic("error %s", err.Error()) log.Fatalf("db init error %s", err.Error()) } + db := &dbcon{store, masterKey, hash[:]} rootToken, err := db.createRootXtoken() if err != nil { //log.Panic("error %s", err.Error()) @@ -364,51 +349,62 @@ func setupDB(dbPtr *string) (dbcon, string, error) { return db, rootToken, err } +func getMasterKey(masterKeyPtr *string) []byte { + masterKeyStr := "" + if masterKeyPtr != nil && len(*masterKeyPtr) > 0 { + masterKeyStr = *masterKeyPtr + } else { + masterKeyStr = os.Getenv("DATABUNKER_MASTERKEY") + } + if len(masterKeyStr) != 48 { + fmt.Printf("Failed to decode master key: bad length\n") + os.Exit(0) + } + masterKey, err := hex.DecodeString(masterKeyStr) + if err != nil { + fmt.Printf("Failed to decode master key: %s\n", err) + os.Exit(0) + } + return masterKey +} + // main application function func main() { rand.Seed(time.Now().UnixNano()) lockMemory() //fmt.Printf("%+v\n", cfg) - initPtr := flag.Bool("init", false, "a bool") + initPtr := flag.Bool("init", false, "generate master key and init database") + startPtr := flag.Bool("start", false, "start databunker service. User DATABUNKER_MASTERKEY environment variable.") masterKeyPtr := flag.String("masterkey", "", "master key") dbPtr := flag.String("db", "", "database file") - confPtr := flag.String("conf", "", "configuration file") + confPtr := flag.String("conf", "", "configuration file name") flag.Parse() var cfg Config readFile(&cfg, confPtr) readEnv(&cfg) - - var err error - var masterKey []byte if *initPtr { db, _, _ := setupDB(dbPtr) - db.closeDB() + db.store.CloseDB() os.Exit(0) } - if dbExists(dbPtr) == false { - fmt.Printf("\ndatabunker.db file is missing.\n\n") - fmt.Println(`Run "./databunker -init" for the first time to init database.`) + if storage.DBExists(dbPtr) == false { + fmt.Printf("\nDatabase is not initialized.\n\n") + fmt.Println(`Run "databunker -init" for the first time to generate keys and init database.`) fmt.Println("") os.Exit(0) } - if masterKeyPtr != nil && len(*masterKeyPtr) > 0 { - if len(*masterKeyPtr) != 48 { - fmt.Printf("Failed to decode master key: bad length\n") - os.Exit(0) - } - masterKey, err = hex.DecodeString(*masterKeyPtr) - if err != nil { - fmt.Printf("Failed to decode master key: %s\n", err) - os.Exit(0) - } - } else { - fmt.Println(`Masterkey is missing.`) - fmt.Println(`Run "./databunker -masterkey key"`) + if masterKeyPtr == nil || *startPtr == false { + fmt.Println(`Run "databunker -start" will load DATABUNKER_MASTERKEY environment variable.`) + fmt.Println(`For testing "databunker -masterkey MASTER_KEY_VALUE" can be used. Not recommended for production.`) + fmt.Println("") os.Exit(0) } - db, _ := newDB(masterKey, dbPtr) - db.initUserApps() + masterKey := getMasterKey(masterKeyPtr) + store, _ := storage.OpenDB(dbPtr) + store.InitUserApps() + hash := md5.Sum(masterKey) + db := &dbcon{store, masterKey, hash[:]} e := mainEnv{db, cfg, make(chan struct{})} e.dbCleanup() fmt.Printf("host %s\n", cfg.Server.Host+":"+cfg.Server.Port) @@ -424,7 +420,7 @@ func main() { close(e.stopChan) time.Sleep(1) srv.Shutdown(context.TODO()) - db.closeDB() + db.store.CloseDB() }() if _, err := os.Stat(cfg.Ssl.SslCertificate); !os.IsNotExist(err) { diff --git a/src/bunker_test.go b/src/bunker_test.go index f453aef..e5b74bd 100644 --- a/src/bunker_test.go +++ b/src/bunker_test.go @@ -92,7 +92,7 @@ func init() { fmt.Printf("error %s", err.Error()) } rootToken = myRootToken - db.initUserApps() + db.store.InitUserApps() var cfg2 Config cfile := "../databunker.yaml" readFile(&cfg2, &cfile) diff --git a/src/consent_db.go b/src/consent_db.go index ba9d398..30cf450 100644 --- a/src/consent_db.go +++ b/src/consent_db.go @@ -7,6 +7,7 @@ import ( "time" "github.com/fatih/structs" + "github.com/paranoidguy/databunker/src/storage" "go.mongodb.org/mongo-driver/bson" ) @@ -52,13 +53,13 @@ func (dbobj dbcon) createConsentRecord(userTOKEN string, mode string, usercode s bdoc["lastmodifiedby"] = lastmodifiedby if len(userTOKEN) > 0 { // first check if this consent exists, then update - raw, err := dbobj.getRecord2(TblName.Consent, "token", userTOKEN, "brief", brief) + raw, err := dbobj.store.GetRecord2(storage.TblName.Consent, "token", userTOKEN, "brief", brief) if err != nil { fmt.Printf("error to find:%s", err) return false, err } if raw != nil { - dbobj.updateRecord2(TblName.Consent, "token", userTOKEN, "brief", brief, &bdoc, nil) + dbobj.store.UpdateRecord2(storage.TblName.Consent, "token", userTOKEN, "brief", brief, &bdoc, nil) if status != raw["status"].(string) { // status changed return true, nil @@ -66,13 +67,13 @@ func (dbobj dbcon) createConsentRecord(userTOKEN string, mode string, usercode s return false, nil } } else { - raw, err := dbobj.getRecord2(TblName.Consent, "who", usercode, "brief", brief) + raw, err := dbobj.store.GetRecord2(storage.TblName.Consent, "who", usercode, "brief", brief) if err != nil { fmt.Printf("error to find:%s", err) return false, err } if raw != nil { - dbobj.updateRecord2(TblName.Consent, "who", usercode, "brief", brief, &bdoc, nil) + dbobj.store.UpdateRecord2(storage.TblName.Consent, "who", usercode, "brief", brief, &bdoc, nil) if status != raw["status"].(string) { // status changed return true, nil @@ -103,7 +104,7 @@ func (dbobj dbcon) createConsentRecord(userTOKEN string, mode string, usercode s Lastmodifiedby: lastmodifiedby, } // in any case - insert record - _, err := dbobj.createRecord(TblName.Consent, structs.Map(ev)) + _, err := dbobj.store.CreateRecord(storage.TblName.Consent, structs.Map(ev)) if err != nil { fmt.Printf("error to insert record: %s\n", err) return false, err @@ -115,7 +116,7 @@ func (dbobj dbcon) createConsentRecord(userTOKEN string, mode string, usercode s func (dbobj dbcon) linkConsentRecords(userTOKEN string, mode string, usercode string) error { bdoc := bson.M{} bdoc["token"] = userTOKEN - _, err := dbobj.updateRecord2(TblName.Consent, "token", "", "who", usercode, &bdoc, nil) + _, err := dbobj.store.UpdateRecord2(storage.TblName.Consent, "token", "", "who", usercode, &bdoc, nil) return err } @@ -135,9 +136,9 @@ func (dbobj dbcon) withdrawConsentRecord(userTOKEN string, brief string, mode st bdoc["lastmodifiedby"] = lastmodifiedby if len(userTOKEN) > 0 { fmt.Printf("%s %s\n", userTOKEN, brief) - dbobj.updateRecord2(TblName.Consent, "token", userTOKEN, "brief", brief, &bdoc, nil) + dbobj.store.UpdateRecord2(storage.TblName.Consent, "token", userTOKEN, "brief", brief, &bdoc, nil) } else { - dbobj.updateRecord2(TblName.Consent, "who", usercode, "brief", brief, &bdoc, nil) + dbobj.store.UpdateRecord2(storage.TblName.Consent, "who", usercode, "brief", brief, &bdoc, nil) } return nil } @@ -145,7 +146,7 @@ func (dbobj dbcon) withdrawConsentRecord(userTOKEN string, brief string, mode st // link consent to user? func (dbobj dbcon) listConsentRecords(userTOKEN string) ([]byte, int, error) { - records, err := dbobj.getList(TblName.Consent, "token", userTOKEN, 0, 0) + records, err := dbobj.store.GetList(storage.TblName.Consent, "token", userTOKEN, 0, 0) if err != nil { return nil, 0, err } @@ -162,7 +163,7 @@ func (dbobj dbcon) listConsentRecords(userTOKEN string) ([]byte, int, error) { } func (dbobj dbcon) viewConsentRecord(userTOKEN string, brief string) ([]byte, error) { - record, err := dbobj.getRecord2(TblName.Consent, "token", userTOKEN, "brief", brief) + record, err := dbobj.store.GetRecord2(storage.TblName.Consent, "token", userTOKEN, "brief", brief) if record == nil || err != nil { return nil, err } @@ -176,14 +177,14 @@ func (dbobj dbcon) viewConsentRecord(userTOKEN string, brief string) ([]byte, er func (dbobj dbcon) filterConsentRecords(brief string, offset int32, limit int32) ([]byte, int64, error) { //var results []*auditEvent - count, err := dbobj.countRecords(TblName.Consent, "brief", brief) + count, err := dbobj.store.CountRecords(storage.TblName.Consent, "brief", brief) if err != nil { return nil, 0, err } if count == 0 { return []byte("[]"), 0, err } - records, err := dbobj.getList(TblName.Consent, "brief", brief, offset, limit) + records, err := dbobj.store.GetList(storage.TblName.Consent, "brief", brief, offset, limit) if err != nil { return nil, 0, err } @@ -201,7 +202,7 @@ func (dbobj dbcon) filterConsentRecords(brief string, offset int32, limit int32) } func (dbobj dbcon) getConsentBriefs() ([]byte, int64, error) { - records, err := dbobj.getUniqueList(TblName.Consent, "brief") + records, err := dbobj.store.GetUniqueList(storage.TblName.Consent, "brief") if err != nil { return nil, 0, err } @@ -223,7 +224,7 @@ func (dbobj dbcon) getConsentBriefs() ([]byte, int64, error) { } func (dbobj dbcon) expireConsentRecords(notifyURL string) error { - records, err := dbobj.getExpiring(TblName.Consent, "status", "yes") + records, err := dbobj.store.GetExpiring(storage.TblName.Consent, "status", "yes") if err != nil { return err } @@ -238,11 +239,11 @@ func (dbobj dbcon) expireConsentRecords(notifyURL string) error { fmt.Printf("This consent record is expired: %s - %s\n", userTOKEN, brief) if len(userTOKEN) > 0 { fmt.Printf("%s %s\n", userTOKEN, brief) - dbobj.updateRecord2(TblName.Consent, "token", userTOKEN, "brief", brief, &bdoc, nil) + dbobj.store.UpdateRecord2(storage.TblName.Consent, "token", userTOKEN, "brief", brief, &bdoc, nil) notifyConsentChange(notifyURL, brief, "expired", "token", userTOKEN) } else { usercode := rec["who"].(string) - dbobj.updateRecord2(TblName.Consent, "who", usercode, "brief", brief, &bdoc, nil) + dbobj.store.UpdateRecord2(storage.TblName.Consent, "who", usercode, "brief", brief, &bdoc, nil) notifyConsentChange(notifyURL, brief, "expired", rec["mode"].(string), usercode) } diff --git a/src/requests_db.go b/src/requests_db.go index f4c64d0..4840c11 100644 --- a/src/requests_db.go +++ b/src/requests_db.go @@ -5,6 +5,7 @@ import ( "time" uuid "github.com/hashicorp/go-uuid" + "github.com/paranoidguy/databunker/src/storage" "go.mongodb.org/mongo-driver/bson" ) @@ -44,18 +45,18 @@ func (dbobj dbcon) saveUserRequest(action string, token string, app string, brie if len(brief) > 0 { bdoc["brief"] = brief } - _, err := dbobj.createRecord(TblName.Requests, &bdoc) + _, err := dbobj.store.CreateRecord(storage.TblName.Requests, &bdoc) return rtoken, err } func (dbobj dbcon) getRequests(status string, offset int32, limit int32) ([]byte, int64, error) { //var results []*auditEvent - count, err := dbobj.countRecords(TblName.Requests, "status", status) + count, err := dbobj.store.CountRecords(storage.TblName.Requests, "status", status) if err != nil { return nil, 0, err } var results []bson.M - records, err := dbobj.getList(TblName.Requests, "status", status, offset, limit) + records, err := dbobj.store.GetList(storage.TblName.Requests, "status", status, offset, limit) if err != nil { return nil, 0, err } @@ -77,7 +78,7 @@ func (dbobj dbcon) getRequests(status string, offset int32, limit int32) ([]byte } func (dbobj dbcon) getRequest(rtoken string) (bson.M, error) { - record, err := dbobj.getRecord(TblName.Requests, "rtoken", rtoken) + record, err := dbobj.store.GetRecord(storage.TblName.Requests, "rtoken", rtoken) if err != nil { return record, err } @@ -106,5 +107,5 @@ func (dbobj dbcon) updateRequestStatus(rtoken string, status string) { bdoc := bson.M{} bdoc["status"] = status //fmt.Printf("op json: %s\n", update) - dbobj.updateRecord(TblName.Requests, "rtoken", rtoken, &bdoc) + dbobj.store.UpdateRecord(storage.TblName.Requests, "rtoken", rtoken, &bdoc) } diff --git a/src/sessions_api.go b/src/sessions_api.go index 8e31b8a..a287848 100644 --- a/src/sessions_api.go +++ b/src/sessions_api.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/julienschmidt/httprouter" + "github.com/paranoidguy/databunker/src/storage" "go.mongodb.org/mongo-driver/bson" ) @@ -81,7 +82,7 @@ func (e mainEnv) getUserSessions(w http.ResponseWriter, r *http.Request, ps http mode := ps.ByName("mode") if mode == "session" { - e.db.deleteExpired(TblName.Sessions, "session", address) + e.db.store.DeleteExpired(storage.TblName.Sessions, "session", address) e.getSession(w, r, address) return } @@ -114,7 +115,7 @@ func (e mainEnv) getUserSessions(w http.ResponseWriter, r *http.Request, ps http if e.enforceAuth(w, r, event) == "" { return } - e.db.deleteExpired(TblName.Sessions, "token", userTOKEN) + e.db.store.DeleteExpired(storage.TblName.Sessions, "token", userTOKEN) args := r.URL.Query() var offset int32 var limit int32 = 10 diff --git a/src/sessions_db.go b/src/sessions_db.go index 2e4ce52..df899e1 100644 --- a/src/sessions_db.go +++ b/src/sessions_db.go @@ -7,6 +7,7 @@ import ( "time" uuid "github.com/hashicorp/go-uuid" + "github.com/paranoidguy/databunker/src/storage" "go.mongodb.org/mongo-driver/bson" ) @@ -39,7 +40,7 @@ func (dbobj dbcon) createSessionRecord(userTOKEN string, expiration string, data bdoc["endtime"] = endtime bdoc["when"] = now bdoc["data"] = encodedStr - _, err = dbobj.createRecord(TblName.Sessions, bdoc) + _, err = dbobj.store.CreateRecord(storage.TblName.Sessions, bdoc) if err != nil { return "", err } @@ -47,7 +48,7 @@ func (dbobj dbcon) createSessionRecord(userTOKEN string, expiration string, data } func (dbobj dbcon) getUserSession(sessionUUID string) (int32, []byte, string, error) { - record, err := dbobj.getRecord(TblName.Sessions, "session", sessionUUID) + record, err := dbobj.store.GetRecord(storage.TblName.Sessions, "session", sessionUUID) if err != nil { return 0, nil, "", err } @@ -82,12 +83,12 @@ func (dbobj dbcon) getUserSessionsByToken(userTOKEN string, offset int32, limit return nil, 0, err } - count, err := dbobj.countRecords(TblName.Sessions, "token", userTOKEN) + count, err := dbobj.store.CountRecords(storage.TblName.Sessions, "token", userTOKEN) if err != nil { return nil, 0, err } - records, err := dbobj.getList(TblName.Sessions, "token", userTOKEN, offset, limit) + records, err := dbobj.store.GetList(storage.TblName.Sessions, "token", userTOKEN, offset, limit) if err != nil { return nil, 0, err } diff --git a/src/sharedrecords_db.go b/src/sharedrecords_db.go index 4ecb9f6..06654ee 100644 --- a/src/sharedrecords_db.go +++ b/src/sharedrecords_db.go @@ -7,6 +7,7 @@ import ( "time" uuid "github.com/hashicorp/go-uuid" + "github.com/paranoidguy/databunker/src/storage" "go.mongodb.org/mongo-driver/bson" ) @@ -57,7 +58,7 @@ func (dbobj dbcon) saveSharedRecord(userTOKEN string, fields string, expiration if len(session) > 0 { bdoc["session"] = session } - _, err = dbobj.createRecord(TblName.Sharedrecords, bdoc) + _, err = dbobj.store.CreateRecord(storage.TblName.Sharedrecords, bdoc) if err != nil { return "", err } @@ -69,7 +70,7 @@ func (dbobj dbcon) getSharedRecord(recordUUID string) (checkRecordResult, error) if isValidUUID(recordUUID) == false { return result, errors.New("failed to authenticate") } - record, err := dbobj.getRecord(TblName.Sharedrecords, "record", recordUUID) + record, err := dbobj.store.GetRecord(storage.TblName.Sharedrecords, "record", recordUUID) if record == nil || err != nil { return result, errors.New("failed to authenticate") } diff --git a/src/qldb.go b/src/storage/storage.go similarity index 78% rename from src/qldb.go rename to src/storage/storage.go index c69d5a7..4efd8d7 100644 --- a/src/qldb.go +++ b/src/storage/storage.go @@ -1,19 +1,8 @@ -package main - -// github.com/mattn/go-sqlite3 - -// This project is using the following golang internal database: -// https://godoc.org/modernc.org/ql - -// go build modernc.org/ql/ql -// go install modernc.org/ql/ql - -// https://stackoverflow.com/questions/21986780/is-it-possible-to-retrieve-a-column-value-by-name-using-golang-database-sql +package storage // https://stackoverflow.com/questions/21986780/is-it-possible-to-retrieve-a-column-value-by-name-using-golang-database-sql import ( - "crypto/md5" "database/sql" "fmt" "log" @@ -33,13 +22,38 @@ var ( knownApps []string ) -type dbcon struct { - db *sql.DB - masterKey []byte - hash []byte +// Tbl is used to store table id +type Tbl int + +// listTbls used to store list of tables +type listTbls struct { + Users Tbl + Audit Tbl + Xtokens Tbl + Consent Tbl + Sessions Tbl + Requests Tbl + Sharedrecords Tbl } -func dbExists(filepath *string) bool { +// TblName is enum of tables +var TblName = &listTbls{ + Users: 0, + Audit: 1, + Xtokens: 2, + Consent: 3, + Sessions: 4, + Requests: 5, + Sharedrecords: 6, +} + +// DBStorage struct is used to store database object +type DBStorage struct { + db *sql.DB +} + +// DBExists function checks if database exists +func DBExists(filepath *string) bool { dbfile := "./databunker.db" if filepath != nil { if len(*filepath) > 0 { @@ -52,8 +66,8 @@ func dbExists(filepath *string) bool { return true } -func newDB(masterKey []byte, filepath *string) (dbcon, error) { - dbobj := dbcon{nil, nil, nil} +// OpenDB function opens the database +func OpenDB(filepath *string) (DBStorage, error) { dbfile := "./databunker.db" if filepath != nil { if len(*filepath) > 0 { @@ -90,8 +104,7 @@ func newDB(masterKey []byte, filepath *string) (dbcon, error) { if err != nil { log.Fatalf("Error on vacuum database command") } - hash := md5.Sum(masterKey) - dbobj = dbcon{db, masterKey, hash[:]} + dbobj := DBStorage{db} // load all table names q := "select name from sqlite_master where type ='table'" @@ -111,7 +124,8 @@ func newDB(masterKey []byte, filepath *string) (dbcon, error) { return dbobj, nil } -func (dbobj dbcon) initDB() error { +// InitDB function creates tables and indexes +func (dbobj DBStorage) InitDB() error { var err error if err = initUsers(dbobj.db); err != nil { return err @@ -137,11 +151,13 @@ func (dbobj dbcon) initDB() error { return nil } -func (dbobj dbcon) closeDB() { +// CloseDB function closes the open database +func (dbobj DBStorage) CloseDB() { dbobj.db.Close() } -func (dbobj dbcon) backupDB(w http.ResponseWriter) { +// BackupDB function backups existing databsae and prints database structure to http.ResponseWriter +func (dbobj DBStorage) BackupDB(w http.ResponseWriter) { err := sqlite3dump.DumpDB(dbobj.db, w) if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -149,7 +165,8 @@ func (dbobj dbcon) backupDB(w http.ResponseWriter) { } } -func (dbobj dbcon) initUserApps() error { +// InitUserApps initialises list of databases. +func (dbobj DBStorage) InitUserApps() error { return nil } @@ -283,7 +300,8 @@ func getTable(t Tbl) string { return "users" } -func (dbobj dbcon) createRecordInTable(tbl string, data interface{}) (int, error) { +// CreateRecordInTable creates new record +func (dbobj DBStorage) CreateRecordInTable(tbl string, data interface{}) (int, error) { fields, values := decodeFieldsValues(data) valuesInQ := "$1" for idx := range values { @@ -309,13 +327,15 @@ func (dbobj dbcon) createRecordInTable(tbl string, data interface{}) (int, error return 1, nil } -func (dbobj dbcon) createRecord(t Tbl, data interface{}) (int, error) { +// CreateRecord creates new record +func (dbobj DBStorage) CreateRecord(t Tbl, data interface{}) (int, error) { //if reflect.TypeOf(value) == reflect.TypeOf("string") tbl := getTable(t) - return dbobj.createRecordInTable(tbl, data) + return dbobj.CreateRecordInTable(tbl, data) } -func (dbobj dbcon) countRecords(t Tbl, keyName string, keyValue string) (int64, error) { +// CountRecords returns number of records that match filter +func (dbobj DBStorage) CountRecords(t Tbl, keyName string, keyValue string) (int64, error) { tbl := getTable(t) q := "select count(*) from " + tbl + " WHERE " + escapeName(keyName) + "=$1" fmt.Printf("q: %s\n", q) @@ -338,18 +358,21 @@ func (dbobj dbcon) countRecords(t Tbl, keyName string, keyValue string) (int64, return int64(count), nil } -func (dbobj dbcon) updateRecord(t Tbl, keyName string, keyValue string, bdoc *bson.M) (int64, error) { +// UpdateRecord updates database record +func (dbobj DBStorage) UpdateRecord(t Tbl, keyName string, keyValue string, bdoc *bson.M) (int64, error) { table := getTable(t) filter := escapeName(keyName) + "=\"" + keyValue + "\"" return dbobj.updateRecordInTableDo(table, filter, bdoc, nil) } -func (dbobj dbcon) updateRecordInTable(table string, keyName string, keyValue string, bdoc *bson.M) (int64, error) { +// UpdateRecordInTable updates database record +func (dbobj DBStorage) UpdateRecordInTable(table string, keyName string, keyValue string, bdoc *bson.M) (int64, error) { filter := escapeName(keyName) + "=\"" + keyValue + "\"" return dbobj.updateRecordInTableDo(table, filter, bdoc, nil) } -func (dbobj dbcon) updateRecord2(t Tbl, keyName string, keyValue string, +// UpdateRecord2 updates database record +func (dbobj DBStorage) UpdateRecord2(t Tbl, keyName string, keyValue string, keyName2 string, keyValue2 string, bdoc *bson.M, bdel *bson.M) (int64, error) { table := getTable(t) filter := escapeName(keyName) + "=\"" + keyValue + "\" AND " + @@ -357,14 +380,15 @@ func (dbobj dbcon) updateRecord2(t Tbl, keyName string, keyValue string, return dbobj.updateRecordInTableDo(table, filter, bdoc, bdel) } -func (dbobj dbcon) updateRecordInTable2(table string, keyName string, +// UpdateRecordInTable2 updates database record +func (dbobj DBStorage) UpdateRecordInTable2(table string, keyName string, keyValue string, keyName2 string, keyValue2 string, bdoc *bson.M, bdel *bson.M) (int64, error) { filter := escapeName(keyName) + "=\"" + keyValue + "\" AND " + escapeName(keyName2) + "=\"" + keyValue2 + "\"" return dbobj.updateRecordInTableDo(table, filter, bdoc, bdel) } -func (dbobj dbcon) updateRecordInTableDo(table string, filter string, bdoc *bson.M, bdel *bson.M) (int64, error) { +func (dbobj DBStorage) updateRecordInTableDo(table string, filter string, bdoc *bson.M, bdel *bson.M) (int64, error) { op, values := decodeForUpdate(bdoc, bdel) q := "update " + table + " SET " + op + " WHERE " + filter fmt.Printf("q: %s\n", q) @@ -385,7 +409,8 @@ func (dbobj dbcon) updateRecordInTableDo(table string, filter string, bdoc *bson return num, err } -func (dbobj dbcon) getRecord(t Tbl, keyName string, keyValue string) (bson.M, error) { +// GetRecord returns specific record from database +func (dbobj DBStorage) GetRecord(t Tbl, keyName string, keyValue string) (bson.M, error) { table := getTable(t) q := "select * from " + table + " WHERE " + escapeName(keyName) + "=$1" values := make([]interface{}, 0) @@ -393,14 +418,16 @@ func (dbobj dbcon) getRecord(t Tbl, keyName string, keyValue string) (bson.M, er return dbobj.getRecordInTableDo(q, values) } -func (dbobj dbcon) getRecordInTable(table string, keyName string, keyValue string) (bson.M, error) { +// GetRecordInTable returns specific record from database +func (dbobj DBStorage) GetRecordInTable(table string, keyName string, keyValue string) (bson.M, error) { q := "select * from " + table + " WHERE " + escapeName(keyName) + "=$1" values := make([]interface{}, 0) values = append(values, keyValue) return dbobj.getRecordInTableDo(q, values) } -func (dbobj dbcon) getRecord2(t Tbl, keyName string, keyValue string, +// GetRecord2 returns specific record from database +func (dbobj DBStorage) GetRecord2(t Tbl, keyName string, keyValue string, keyName2 string, keyValue2 string) (bson.M, error) { table := getTable(t) q := "select * from " + table + " WHERE " + escapeName(keyName) + "=$1 AND " + @@ -411,7 +438,7 @@ func (dbobj dbcon) getRecord2(t Tbl, keyName string, keyValue string, return dbobj.getRecordInTableDo(q, values) } -func (dbobj dbcon) getRecordInTableDo(q string, values []interface{}) (bson.M, error) { +func (dbobj DBStorage) getRecordInTableDo(q string, values []interface{}) (bson.M, error) { fmt.Printf("query: %s\n", q) tx, err := dbobj.db.Begin() @@ -490,14 +517,14 @@ func (dbobj dbcon) getRecordInTableDo(q string, values []interface{}) (bson.M, e return recBson, nil } -/* -func (dbobj dbcon) deleteRecord(t Tbl, keyName string, keyValue string) (int64, error) { +// DeleteRecord deletes record in database +func (dbobj DBStorage) DeleteRecord(t Tbl, keyName string, keyValue string) (int64, error) { tbl := getTable(t) - return dbobj.deleteRecordInTable(tbl, keyName, keyValue) + return dbobj.DeleteRecordInTable(tbl, keyName, keyValue) } -*/ -func (dbobj dbcon) deleteRecordInTable(table string, keyName string, keyValue string) (int64, error) { +// DeleteRecordInTable deletes record in database +func (dbobj DBStorage) DeleteRecordInTable(table string, keyName string, keyValue string) (int64, error) { q := "delete from " + table + " WHERE " + escapeName(keyName) + "=$1" fmt.Printf("q: %s\n", q) @@ -517,12 +544,13 @@ func (dbobj dbcon) deleteRecordInTable(table string, keyName string, keyValue st return num, err } -func (dbobj dbcon) deleteRecord2(t Tbl, keyName string, keyValue string, keyName2 string, keyValue2 string) (int64, error) { +// DeleteRecord2 deletes record in database +func (dbobj DBStorage) DeleteRecord2(t Tbl, keyName string, keyValue string, keyName2 string, keyValue2 string) (int64, error) { tbl := getTable(t) return dbobj.deleteRecordInTable2(tbl, keyName, keyValue, keyName2, keyValue2) } -func (dbobj dbcon) deleteRecordInTable2(table string, keyName string, keyValue string, keyName2 string, keyValue2 string) (int64, error) { +func (dbobj DBStorage) deleteRecordInTable2(table string, keyName string, keyValue string, keyName2 string, keyValue2 string) (int64, error) { q := "delete from " + table + " WHERE " + escapeName(keyName) + "=$1 AND " + escapeName(keyName2) + "=$2" fmt.Printf("q: %s\n", q) @@ -544,14 +572,14 @@ func (dbobj dbcon) deleteRecordInTable2(table string, keyName string, keyValue s } /* -func (dbobj dbcon) deleteDuplicate2(t Tbl, keyName string, keyValue string, keyName2 string, keyValue2 string) (int64, error) { +func (dbobj DBStorage) deleteDuplicate2(t Tbl, keyName string, keyValue string, keyName2 string, keyValue2 string) (int64, error) { tbl := getTable(t) return dbobj.deleteDuplicateInTable2(tbl, keyName, keyValue, keyName2, keyValue2) } */ /* -func (dbobj dbcon) deleteDuplicateInTable2(table string, keyName string, keyValue string, keyName2 string, keyValue2 string) (int64, error) { +func (dbobj DBStorage) deleteDuplicateInTable2(table string, keyName string, keyValue string, keyName2 string, keyValue2 string) (int64, error) { q := "delete from " + table + " where " + escapeName(keyName) + "=$1 AND " + escapeName(keyName2) + "=$2 AND rowid not in " + "(select min(rowid) from " + table + " where " + escapeName(keyName) + "=$3 AND " + @@ -575,7 +603,8 @@ func (dbobj dbcon) deleteDuplicateInTable2(table string, keyName string, keyValu } */ -func (dbobj dbcon) deleteExpired0(t Tbl, expt int32) (int64, error) { +// DeleteExpired0 deletes expired records in database +func (dbobj DBStorage) DeleteExpired0(t Tbl, expt int32) (int64, error) { table := getTable(t) now := int32(time.Now().Unix()) q := fmt.Sprintf("delete from %s WHERE `when`>0 AND `when`<%d", table, now-expt) @@ -598,7 +627,8 @@ func (dbobj dbcon) deleteExpired0(t Tbl, expt int32) (int64, error) { return num, err } -func (dbobj dbcon) deleteExpired(t Tbl, keyName string, keyValue string) (int64, error) { +// DeleteExpired deletes expired records in database +func (dbobj DBStorage) DeleteExpired(t Tbl, keyName string, keyValue string) (int64, error) { table := getTable(t) q := "delete from " + table + " WHERE endtime>0 AND endtime<$1 AND " + escapeName(keyName) + "=$2" fmt.Printf("q: %s\n", q) @@ -620,7 +650,8 @@ func (dbobj dbcon) deleteExpired(t Tbl, keyName string, keyValue string) (int64, return num, err } -func (dbobj dbcon) cleanupRecord(t Tbl, keyName string, keyValue string, data interface{}) (int64, error) { +// CleanupRecord nullifies specific feilds in records in database +func (dbobj DBStorage) CleanupRecord(t Tbl, keyName string, keyValue string, data interface{}) (int64, error) { tbl := getTable(t) cleanup := decodeForCleanup(data) q := "update " + tbl + " SET " + cleanup + " WHERE " + escapeName(keyName) + "=$1" @@ -642,7 +673,8 @@ func (dbobj dbcon) cleanupRecord(t Tbl, keyName string, keyValue string, data in return num, err } -func (dbobj dbcon) getExpiring(t Tbl, keyName string, keyValue string) ([]bson.M, error) { +// GetExpiring get records that are expiring +func (dbobj DBStorage) GetExpiring(t Tbl, keyName string, keyValue string) ([]bson.M, error) { table := getTable(t) now := int32(time.Now().Unix()) q := fmt.Sprintf("select * from %s WHERE endtime>0 AND endtime<%d AND %s=$1", table, now, escapeName(keyName)) @@ -650,7 +682,8 @@ func (dbobj dbcon) getExpiring(t Tbl, keyName string, keyValue string) ([]bson.M return dbobj.getListDo(q, keyValue) } -func (dbobj dbcon) getUniqueList(t Tbl, keyName string) ([]bson.M, error) { +// GetUniqueList returns a unique list of values from specific column in database +func (dbobj DBStorage) GetUniqueList(t Tbl, keyName string) ([]bson.M, error) { table := getTable(t) keyName = escapeName(keyName) q := "select distinct " + keyName + " from " + table + " ORDER BY " + keyName @@ -658,7 +691,8 @@ func (dbobj dbcon) getUniqueList(t Tbl, keyName string) ([]bson.M, error) { return dbobj.getListDo(q, "") } -func (dbobj dbcon) getList(t Tbl, keyName string, keyValue string, start int32, limit int32) ([]bson.M, error) { +// GetList is used to return list of rows. It can be used to return values using pager. +func (dbobj DBStorage) GetList(t Tbl, keyName string, keyValue string, start int32, limit int32) ([]bson.M, error) { table := getTable(t) if limit > 100 { limit = 100 @@ -675,7 +709,7 @@ func (dbobj dbcon) getList(t Tbl, keyName string, keyValue string, start int32, return dbobj.getListDo(q, keyValue) } -func (dbobj dbcon) getListDo(q string, keyValue string) ([]bson.M, error) { +func (dbobj DBStorage) getListDo(q string, keyValue string) ([]bson.M, error) { tx, err := dbobj.db.Begin() if err != nil { return nil, err @@ -753,11 +787,13 @@ func (dbobj dbcon) getListDo(q string, keyValue string) ([]bson.M, error) { return results, nil } -func (dbobj dbcon) getAllTables() ([]string, error) { +// GetAllTables returns all tables that exists in database +func (dbobj DBStorage) GetAllTables() ([]string, error) { return knownApps, nil } -func (dbobj dbcon) validateNewApp(appName string) bool { +// ValidateNewApp function check if app name can be part of the table name +func (dbobj DBStorage) ValidateNewApp(appName string) bool { if contains(knownApps, appName) == true { return true } @@ -785,7 +821,8 @@ func execQueries(db *sql.DB, queries []string) error { return nil } -func (dbobj dbcon) indexNewApp(appName string) { +// IndexNewApp creates a new app table and creates indexes for it. +func (dbobj DBStorage) IndexNewApp(appName string) { if contains(knownApps, appName) == false { // it is a new app, create an index fmt.Printf("This is a new app, creating table & index for: %s\n", appName) @@ -922,3 +959,13 @@ func initSessions(db *sql.DB) error { `CREATE INDEX sessions_session ON sessions (session);`} return execQueries(db, queries) } + +func contains(slice []string, item string) bool { + set := make(map[string]struct{}, len(slice)) + for _, s := range slice { + set[s] = struct{}{} + } + + _, ok := set[item] + return ok +} diff --git a/src/userapps_api.go b/src/userapps_api.go index 5c694e1..7cc85a4 100644 --- a/src/userapps_api.go +++ b/src/userapps_api.go @@ -25,7 +25,7 @@ func (e mainEnv) userappNew(w http.ResponseWriter, r *http.Request, ps httproute returnError(w, r, "bad appname", 405, nil, event) return } - if e.db.validateNewApp("app_"+appName) == false { + if e.db.store.ValidateNewApp("app_"+appName) == false { returnError(w, r, "db limitation", 405, nil, event) return } diff --git a/src/userapps_db.go b/src/userapps_db.go index e392108..fbe431d 100644 --- a/src/userapps_db.go +++ b/src/userapps_db.go @@ -14,7 +14,7 @@ import ( func (dbobj dbcon) getUserApp(userTOKEN string, appName string) ([]byte, error) { - record, err := dbobj.getRecordInTable("app_"+appName, "token", userTOKEN) + record, err := dbobj.store.GetRecordInTable("app_"+appName, "token", userTOKEN) if err != nil { return nil, err } @@ -31,7 +31,7 @@ func (dbobj dbcon) createAppRecord(jsonData []byte, userTOKEN string, appName st if err != nil { return userTOKEN, err } - dbobj.indexNewApp("app_" + appName) + dbobj.store.IndexNewApp("app_" + appName) //var bdoc interface{} bdoc := bson.M{} @@ -46,14 +46,14 @@ func (dbobj dbcon) createAppRecord(jsonData []byte, userTOKEN string, appName st event.Record = userTOKEN } //fmt.Println("creating new app") - record, err := dbobj.getRecordInTable("app_"+appName, "token", userTOKEN) + record, err := dbobj.store.GetRecordInTable("app_"+appName, "token", userTOKEN) if err != nil { return userTOKEN, err } if record != nil { - _, err = dbobj.updateRecordInTable("app_"+appName, "token", userTOKEN, &bdoc) + _, err = dbobj.store.UpdateRecordInTable("app_"+appName, "token", userTOKEN, &bdoc) } else { - _, err = dbobj.createRecordInTable("app_"+appName, bdoc) + _, err = dbobj.store.CreateRecordInTable("app_"+appName, bdoc) } return userTOKEN, err } @@ -72,7 +72,7 @@ func (dbobj dbcon) updateAppRecord(jsonDataPatch []byte, userTOKEN string, appNa return userTOKEN, err } - record, err := dbobj.getRecordInTable("app_"+appName, "token", userTOKEN) + record, err := dbobj.store.GetRecordInTable("app_"+appName, "token", userTOKEN) if err != nil { return userTOKEN, err } @@ -113,7 +113,7 @@ func (dbobj dbcon) updateAppRecord(jsonDataPatch []byte, userTOKEN string, appNa // here I add md5 of the original record to filter // to make sure this record was not change by other thread fmt.Println("update user app") - result, err := dbobj.updateRecordInTable2("app_"+appName, "token", userTOKEN, "md5", sig, &bdoc, nil) + result, err := dbobj.store.UpdateRecordInTable2("app_"+appName, "token", userTOKEN, "md5", sig, &bdoc, nil) if err != nil { return userTOKEN, err } @@ -138,14 +138,14 @@ func (dbobj dbcon) listUserApps(userTOKEN string) ([]byte, error) { // not found return nil, err } - allCollections, err := dbobj.getAllTables() + allCollections, err := dbobj.store.GetAllTables() if err != nil { return nil, err } var result []string for _, colName := range allCollections { if strings.HasPrefix(colName, "app_") { - record, err := dbobj.getRecordInTable(colName, "token", userTOKEN) + record, err := dbobj.store.GetRecordInTable(colName, "token", userTOKEN) if err != nil { return nil, err } @@ -164,7 +164,7 @@ func (dbobj dbcon) listUserApps(userTOKEN string) ([]byte, error) { func (dbobj dbcon) listAllAppsOnly() ([]string, error) { //fmt.Println("dump list of collections") - allCollections, err := dbobj.getAllTables() + allCollections, err := dbobj.store.GetAllTables() if err != nil { return nil, err } @@ -179,7 +179,7 @@ func (dbobj dbcon) listAllAppsOnly() ([]string, error) { func (dbobj dbcon) listAllApps() ([]byte, error) { //fmt.Println("dump list of collections") - allCollections, err := dbobj.getAllTables() + allCollections, err := dbobj.store.GetAllTables() if err != nil { return nil, err } diff --git a/src/users_api.go b/src/users_api.go index 7d03f73..bd9bde2 100644 --- a/src/users_api.go +++ b/src/users_api.go @@ -6,6 +6,7 @@ import ( "net/http" "github.com/julienschmidt/httprouter" + "github.com/paranoidguy/databunker/src/storage" ) func (e mainEnv) userNew(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { @@ -75,17 +76,17 @@ func (e mainEnv) userNew(w http.ResponseWriter, r *http.Request, ps httprouter.P } if len(parsedData.emailIdx) > 0 && len(parsedData.phoneIdx) > 0 { // delete duplicate consent records for user - records, _ := e.db.getList(TblName.Consent, "who", parsedData.emailIdx, 0, 0) + records, _ := e.db.store.GetList(storage.TblName.Consent, "who", parsedData.emailIdx, 0, 0) var briefCodes []string for _, val := range records { //fmt.Printf("adding brief code: %s\n", val["brief"].(string)) briefCodes = append(briefCodes, val["brief"].(string)) } - records, _ = e.db.getList(TblName.Consent, "who", parsedData.phoneIdx, 0, 0) + records, _ = e.db.store.GetList(storage.TblName.Consent, "who", parsedData.phoneIdx, 0, 0) for _, val := range records { //fmt.Printf("XXX checking brief code for duplicates: %s\n", val["brief"].(string)) if contains(briefCodes, val["brief"].(string)) == true { - e.db.deleteRecord2(TblName.Consent, "token", userTOKEN, "who", parsedData.phoneIdx) + e.db.store.DeleteRecord2(storage.TblName.Consent, "token", userTOKEN, "who", parsedData.phoneIdx) } } } diff --git a/src/users_db.go b/src/users_db.go index 2cea144..ac42d62 100644 --- a/src/users_db.go +++ b/src/users_db.go @@ -11,6 +11,7 @@ import ( jsonpatch "github.com/evanphx/json-patch" uuid "github.com/hashicorp/go-uuid" + "github.com/paranoidguy/databunker/src/storage" "go.mongodb.org/mongo-driver/bson" ) @@ -56,7 +57,7 @@ func (dbobj dbcon) createUserRecord(parsedData userJSON, event *auditEvent) (str event.Record = userTOKEN } //fmt.Println("creating new user") - _, err = dbobj.createRecord(TblName.Users, bdoc) + _, err = dbobj.store.CreateRecord(storage.TblName.Users, bdoc) if err != nil { fmt.Printf("error in create!\n") return "", err @@ -72,7 +73,7 @@ func (dbobj dbcon) generateTempLoginCode(userTOKEN string) int32 { expired := int32(time.Now().Unix()) + 60 bdoc["tempcodeexp"] = expired //fmt.Printf("op json: %s\n", update) - dbobj.updateRecord(TblName.Users, "token", userTOKEN, &bdoc) + dbobj.store.UpdateRecord(storage.TblName.Users, "token", userTOKEN, &bdoc) return rnd } @@ -84,7 +85,7 @@ func (dbobj dbcon) generateDemoLoginCode(userTOKEN string) int32 { expired := int32(time.Now().Unix()) + 60 bdoc["tempcodeexp"] = expired //fmt.Printf("op json: %s\n", update) - dbobj.updateRecord(TblName.Users, "token", userTOKEN, &bdoc) + dbobj.store.UpdateRecord(storage.TblName.Users, "token", userTOKEN, &bdoc) return rnd } @@ -239,7 +240,7 @@ func (dbobj dbcon) updateUserRecordDo(jsonDataPatch []byte, userTOKEN string, ev //filter2 := bson.D{{"token", userTOKEN}, {"md5", sig}} //fmt.Printf("op json: %s\n", update) - result, err := dbobj.updateRecord2(TblName.Users, "token", userTOKEN, "md5", sig, &bdoc, &bdel) + result, err := dbobj.store.UpdateRecord2(storage.TblName.Users, "token", userTOKEN, "md5", sig, &bdoc, &bdel) if err != nil { return nil, nil, false, err } @@ -257,7 +258,7 @@ func (dbobj dbcon) updateUserRecordDo(jsonDataPatch []byte, userTOKEN string, ev } func (dbobj dbcon) lookupUserRecord(userTOKEN string) (bson.M, error) { - return dbobj.getRecord(TblName.Users, "token", userTOKEN) + return dbobj.store.GetRecord(storage.TblName.Users, "token", userTOKEN) } func (dbobj dbcon) lookupUserRecordByIndex(indexName string, indexValue string, conf Config) (bson.M, error) { @@ -271,7 +272,7 @@ func (dbobj dbcon) lookupUserRecordByIndex(indexName string, indexValue string, } idxStringHashHex := hashString(dbobj.hash, indexValue) fmt.Printf("loading by %s, value: %s\n", indexName, indexValue) - return dbobj.getRecord(TblName.Users, indexName+"idx", idxStringHashHex) + return dbobj.store.GetRecord(storage.TblName.Users, indexName+"idx", idxStringHashHex) } func (dbobj dbcon) getUser(userTOKEN string) ([]byte, error) { @@ -341,11 +342,11 @@ func (dbobj dbcon) deleteUserRecord(userTOKEN string) (bool, error) { // delete all user app records for _, appName := range userApps { appNameFull := "app_" + appName - dbobj.deleteRecordInTable(appNameFull, "token", userTOKEN) + dbobj.store.DeleteRecordInTable(appNameFull, "token", userTOKEN) } //delete in audit - dbobj.deleteRecordInTable("audit", "record", userTOKEN) - dbobj.deleteRecordInTable("sessions", "token", userTOKEN) + dbobj.store.DeleteRecord(storage.TblName.Audit, "record", userTOKEN) + dbobj.store.DeleteRecord(storage.TblName.Sessions, "token", userTOKEN) // cleanup user record bdel := bson.M{} bdel["data"] = "" @@ -353,7 +354,7 @@ func (dbobj dbcon) deleteUserRecord(userTOKEN string) (bool, error) { bdel["loginidx"] = "" bdel["emailidx"] = "" bdel["phoneidx"] = "" - result, err := dbobj.cleanupRecord(TblName.Users, "token", userTOKEN, bdel) + result, err := dbobj.store.CleanupRecord(storage.TblName.Users, "token", userTOKEN, bdel) if err != nil { return false, err } @@ -375,7 +376,7 @@ func (dbobj dbcon) wipeRecord(userTOKEN string) (bool, error) { dbobj.deleteRecordInTable(appNameFull, "token", userTOKEN) } // delete user record - result, err := dbobj.deleteRecord(TblName.Users, "token", userTOKEN) + result, err := dbobj.store.DeleteRecord(storage.TblName.Users, "token", userTOKEN) if err != nil { return false, err } diff --git a/src/xtokens_db.go b/src/xtokens_db.go index 723e813..0636cf9 100644 --- a/src/xtokens_db.go +++ b/src/xtokens_db.go @@ -6,13 +6,14 @@ import ( "time" uuid "github.com/hashicorp/go-uuid" + "github.com/paranoidguy/databunker/src/storage" "go.mongodb.org/mongo-driver/bson" ) var rootXTOKEN string func (dbobj dbcon) getRootXtoken() (string, error) { - record, err := dbobj.getRecord(TblName.Xtokens, "type", "root") + record, err := dbobj.store.GetRecord(storage.TblName.Xtokens, "type", "root") if record == nil || err != nil { return "", err } @@ -34,7 +35,7 @@ func (dbobj dbcon) createRootXtoken() (string, error) { bdoc := bson.M{} bdoc["xtoken"] = hashString(dbobj.hash, rootToken) bdoc["type"] = "root" - _, err = dbobj.createRecord(TblName.Xtokens, bdoc) + _, err = dbobj.store.CreateRecord(storage.TblName.Xtokens, bdoc) if err != nil { return rootToken, err } @@ -59,7 +60,7 @@ func (dbobj dbcon) generateUserLoginXtoken(userTOKEN string) (string, error) { bdoc["xtoken"] = hashString(dbobj.hash, tokenUUID) bdoc["type"] = "login" bdoc["endtime"] = expired - _, err = dbobj.createRecord(TblName.Xtokens, bdoc) + _, err = dbobj.store.CreateRecord(storage.TblName.Xtokens, bdoc) return tokenUUID, err } @@ -74,7 +75,7 @@ func (dbobj dbcon) checkXtoken(xtokenUUID string) bool { fmt.Println("It is a root token") return true } - record, err := dbobj.getRecord(TblName.Xtokens, "xtoken", xtokenHashed) + record, err := dbobj.store.GetRecord(storage.TblName.Xtokens, "xtoken", xtokenHashed) if record == nil || err != nil { return false } @@ -99,7 +100,7 @@ func (dbobj dbcon) checkUserAuthXToken(xtokenUUID string) (tokenAuthResult, erro result.name = "root" return result, nil } - record, err := dbobj.getRecord(TblName.Xtokens, "xtoken", xtokenHashed) + record, err := dbobj.store.GetRecord(storage.TblName.Xtokens, "xtoken", xtokenHashed) if record == nil || err != nil { return result, errors.New("failed to authenticate") }