From cb4ff2d2beda7ccbae0e7b9a333a4d8e2c3a505a Mon Sep 17 00:00:00 2001 From: root Date: Mon, 7 Jun 2021 20:48:25 +0000 Subject: [PATCH] format code with gofmt --- src/agreements_api.go | 34 +-- src/agreements_db.go | 30 ++- src/audit_api.go | 45 ++-- src/audit_db.go | 79 +++--- src/autocontext/context.go | 13 +- src/bunker.go | 46 ++-- src/captcha.go | 215 ++++++++-------- src/conf.go | 88 +++---- src/consent_test.go | 14 +- src/cryptor.go | 77 +++--- src/email.go | 71 +++--- src/expiration_api.go | 3 +- src/lbasis_api.go | 204 ++++++++-------- src/lbasis_db.go | 39 ++- src/pactivities_api.go | 2 +- src/pactivities_db.go | 2 +- src/requests_api.go | 10 +- src/schema.go | 444 +++++++++++++++++----------------- src/sessions_api.go | 38 +-- src/sessions_db.go | 7 +- src/sms.go | 2 +- src/storage/mysql-storage.go | 274 +++++++++++---------- src/storage/sqlite-storage.go | 60 +++-- src/storage/storage.go | 201 ++++++++------- src/users_api.go | 16 +- src/users_db.go | 162 ++++++------- src/utils.go | 52 ++-- src/xtokens_test.go | 18 +- 28 files changed, 1117 insertions(+), 1129 deletions(-) diff --git a/src/agreements_api.go b/src/agreements_api.go index bc82058..edd65e8 100644 --- a/src/agreements_api.go +++ b/src/agreements_api.go @@ -49,7 +49,7 @@ func (e mainEnv) agreementAccept(w http.ResponseWriter, r *http.Request, ps http userTOKEN = address } else { userBson, err := e.db.lookupUserRecordByIndex(mode, address, e.conf) - if err != nil { + if err != nil { returnError(w, r, "internal error", 405, err, event) return } @@ -83,10 +83,10 @@ func (e mainEnv) agreementAccept(w http.ResponseWriter, r *http.Request, ps http } if value, ok := records["expiration"]; ok { switch records["expiration"].(type) { - case string: - expiration, _ = parseExpiration(value.(string)) - case float64: - expiration = int32(value.(float64)) + case string: + expiration, _ = parseExpiration(value.(string)) + case float64: + expiration = int32(value.(float64)) } } if value, ok := records["starttime"]; ok { @@ -107,15 +107,15 @@ func (e mainEnv) agreementAccept(w http.ResponseWriter, r *http.Request, ps http e.db.acceptAgreement(userTOKEN, mode, address, brief, status, agreementmethod, referencecode, lastmodifiedby, starttime, expiration) /* - notifyURL := e.conf.Notification.NotificationURL - if newStatus == true && len(notifyURL) > 0 { - // change notificate on new record or if status change - if len(userTOKEN) > 0 { - notifyConsentChange(notifyURL, brief, status, "token", userTOKEN) - } else { - notifyConsentChange(notifyURL, brief, status, mode, address) + notifyURL := e.conf.Notification.NotificationURL + if newStatus == true && len(notifyURL) > 0 { + // change notificate on new record or if status change + if len(userTOKEN) > 0 { + notifyConsentChange(notifyURL, brief, status, "token", userTOKEN) + } else { + notifyConsentChange(notifyURL, brief, status, mode, address) + } } - } */ w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(200) @@ -146,7 +146,7 @@ func (e mainEnv) agreementWithdraw(w http.ResponseWriter, r *http.Request, ps ht } if lbasis == nil { returnError(w, r, "not found", 405, nil, event) - return + return } userTOKEN := "" authResult := "" @@ -252,9 +252,9 @@ func (e mainEnv) agreementRevokeAll(w http.ResponseWriter, r *http.Request, ps h } if exists == false { returnError(w, r, "not found", 405, nil, nil) - return + return } - e.db.revokeLegalBasis(brief); + e.db.revokeLegalBasis(brief) w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(200) w.Write([]byte(`{"status":"ok"}`)) @@ -312,7 +312,7 @@ func (e mainEnv) getUserAgreements(w http.ResponseWriter, r *http.Request, ps ht var err error if len(userTOKEN) > 0 { resultJSON, numRecords, err = e.db.listAgreementRecords(userTOKEN) - } else{ + } else { resultJSON, numRecords, err = e.db.listAgreementRecordsByIdentity(address) } if err != nil { diff --git a/src/agreements_db.go b/src/agreements_db.go index e81201b..61c1d61 100644 --- a/src/agreements_db.go +++ b/src/agreements_db.go @@ -114,7 +114,6 @@ func (dbobj dbcon) withdrawAgreement(userTOKEN string, brief string, mode string return nil } - func (dbobj dbcon) listAgreementRecords(userTOKEN string) ([]byte, int, error) { records, err := dbobj.store.GetList(storage.TblName.Agreements, "token", userTOKEN, 0, 0, "") if err != nil { @@ -133,23 +132,22 @@ func (dbobj dbcon) listAgreementRecords(userTOKEN string) ([]byte, int, error) { } func (dbobj dbcon) listAgreementRecordsByIdentity(identity string) ([]byte, int, error) { - records, err := dbobj.store.GetList(storage.TblName.Agreements, "who", identity, 0, 0, "") - if err != nil { - return nil, 0, err - } - count := len(records) - if count == 0 { - return []byte("[]"), 0, err - } - resultJSON, err := json.Marshal(records) - if err != nil { - return nil, 0, err - } - //fmt.Printf("Found multiple documents (array of pointers): %+v\n", results) - return resultJSON, count, nil + records, err := dbobj.store.GetList(storage.TblName.Agreements, "who", identity, 0, 0, "") + if err != nil { + return nil, 0, err + } + count := len(records) + if count == 0 { + return []byte("[]"), 0, err + } + resultJSON, err := json.Marshal(records) + if err != nil { + return nil, 0, err + } + //fmt.Printf("Found multiple documents (array of pointers): %+v\n", results) + return resultJSON, count, nil } - func (dbobj dbcon) viewAgreementRecord(userTOKEN string, brief string) ([]byte, error) { record, err := dbobj.store.GetRecord2(storage.TblName.Agreements, "token", userTOKEN, "brief", brief) if record == nil || err != nil { diff --git a/src/audit_api.go b/src/audit_api.go index 52f1cfa..14a6817 100644 --- a/src/audit_api.go +++ b/src/audit_api.go @@ -41,31 +41,30 @@ func (e mainEnv) getAuditEvents(w http.ResponseWriter, r *http.Request, ps httpr func (e mainEnv) getAdminAuditEvents(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { authResult := e.enforceAdmin(w, r) - if authResult == "" { - return - } - var offset int32 - var limit int32 = 10 - args := r.URL.Query() - if value, ok := args["offset"]; ok { - offset = atoi(value[0]) - } - if value, ok := args["limit"]; ok { - limit = atoi(value[0]) - } - resultJSON, counter, err := e.db.getAdminAuditEvents(offset, limit) - if err != nil { - returnError(w, r, "internal error", 405, err, nil) - return - } - //fmt.Printf("Total count of events: %d\n", counter) - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(200) - str := fmt.Sprintf(`{"status":"ok","total":%d,"rows":%s}`, counter, resultJSON) - w.Write([]byte(str)) + if authResult == "" { + return + } + var offset int32 + var limit int32 = 10 + args := r.URL.Query() + if value, ok := args["offset"]; ok { + offset = atoi(value[0]) + } + if value, ok := args["limit"]; ok { + limit = atoi(value[0]) + } + resultJSON, counter, err := e.db.getAdminAuditEvents(offset, limit) + if err != nil { + returnError(w, r, "internal error", 405, err, nil) + return + } + //fmt.Printf("Total count of events: %d\n", counter) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(200) + str := fmt.Sprintf(`{"status":"ok","total":%d,"rows":%s}`, counter, resultJSON) + w.Write([]byte(str)) } - func (e mainEnv) getAuditEvent(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { atoken := ps.ByName("atoken") event := audit("view audit event", atoken, "token", atoken) diff --git a/src/audit_db.go b/src/audit_db.go index 86299d4..589af51 100644 --- a/src/audit_db.go +++ b/src/audit_db.go @@ -120,49 +120,48 @@ func (dbobj dbcon) getAuditEvents(userTOKEN string, offset int32, limit int32) ( } func (dbobj dbcon) getAdminAuditEvents(offset int32, limit int32) ([]byte, int64, error) { - count, err := dbobj.store.CountRecords0(storage.TblName.Audit) - if err != nil { - return nil, 0, err - } - if count == 0 { - return []byte("[]"), 0, err - } - var results []bson.M - records, err := dbobj.store.GetList0(storage.TblName.Audit, offset, limit, "when") - if err != nil { - return nil, 0, err - } + count, err := dbobj.store.CountRecords0(storage.TblName.Audit) + if err != nil { + return nil, 0, err + } + if count == 0 { + return []byte("[]"), 0, err + } + var results []bson.M + records, err := dbobj.store.GetList0(storage.TblName.Audit, offset, limit, "when") + if err != nil { + return nil, 0, err + } code := dbobj.GetCode() - for _, element := range records { - element["more"] = false - if _, ok := element["before"]; ok { - element["more"] = true - element["before"] = "" - } - if _, ok := element["after"]; ok { - element["more"] = true - element["after"] = "" - } - if _, ok := element["debug"]; ok { - element["more"] = true - element["debug"] = "" - } - if _, ok := element["record"]; ok { - element["record"], _ = basicStringDecrypt(element["record"].(string), dbobj.masterKey, code) - } - if _, ok := element["who"]; ok { - element["who"], _ = basicStringDecrypt(element["who"].(string), dbobj.masterKey, code) - } - results = append(results, element) - } - resultJSON, err := json.Marshal(records) - if err != nil { - return nil, 0, err - } - return resultJSON, count, nil + for _, element := range records { + element["more"] = false + if _, ok := element["before"]; ok { + element["more"] = true + element["before"] = "" + } + if _, ok := element["after"]; ok { + element["more"] = true + element["after"] = "" + } + if _, ok := element["debug"]; ok { + element["more"] = true + element["debug"] = "" + } + if _, ok := element["record"]; ok { + element["record"], _ = basicStringDecrypt(element["record"].(string), dbobj.masterKey, code) + } + if _, ok := element["who"]; ok { + element["who"], _ = basicStringDecrypt(element["who"].(string), dbobj.masterKey, code) + } + results = append(results, element) + } + resultJSON, err := json.Marshal(records) + if err != nil { + return nil, 0, err + } + return resultJSON, count, nil } - func (dbobj dbcon) getAuditEvent(atoken string) (string, []byte, error) { //var results []*auditEvent record, err := dbobj.store.GetRecord(storage.TblName.Audit, "atoken", atoken) diff --git a/src/autocontext/context.go b/src/autocontext/context.go index 2b67504..8a3fff0 100644 --- a/src/autocontext/context.go +++ b/src/autocontext/context.go @@ -1,17 +1,17 @@ package autocontext import ( - "fmt" "errors" - "regexp" + "fmt" "net/http" - "sync" + "regexp" "runtime" + "sync" ) var ( - contextMutex sync.Mutex - contextData = make(map[string]map[string]interface{}) + contextMutex sync.Mutex + contextData = make(map[string]map[string]interface{}) regexServeHTTP = regexp.MustCompile("\\.ServeHTTP\\(0x[a-fA-F0-9]+, 0x[a-fA-F0-9]+, 0x[a-fA-F0-9]+, (0x[a-fA-F0-9]+)\\)") ) @@ -37,7 +37,7 @@ func Get(r *http.Request, key string) interface{} { return nil } -// GetAuto ruturns value from current *http.Request context. It is automatically extracted from stacktrace. +// GetAuto ruturns value from current *http.Request context. It is automatically extracted from stacktrace. func GetAuto(key string) interface{} { addr, err := getRequestAddress() if err != nil { @@ -79,4 +79,3 @@ func getRequestAddress() (string, error) { //fmt.Printf("*** extracted address from stacktrace: %s\n", match[1]) return match[1], nil } - diff --git a/src/bunker.go b/src/bunker.go index 55ee385..169c7d1 100644 --- a/src/bunker.go +++ b/src/bunker.go @@ -22,10 +22,10 @@ import ( "github.com/gobuffalo/packr" "github.com/julienschmidt/httprouter" "github.com/kelseyhightower/envconfig" - "github.com/securitybunker/databunker/src/autocontext" - "github.com/securitybunker/databunker/src/storage" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/securitybunker/databunker/src/autocontext" + "github.com/securitybunker/databunker/src/storage" yaml "gopkg.in/yaml.v2" ) @@ -110,12 +110,12 @@ type mainEnv struct { // userJSON used to parse user POST type userJSON struct { - jsonData []byte - loginIdx string - emailIdx string - phoneIdx string + jsonData []byte + loginIdx string + emailIdx string + phoneIdx string customIdx string - token string + token string } type tokenAuthResult struct { @@ -255,12 +255,12 @@ func (e mainEnv) setupRouter() *httprouter.Router { } else { w.WriteHeader(200) captcha, err := generateCaptcha() - if err != nil { - w.WriteHeader(501) - } else { - data2 := bytes.ReplaceAll(data, []byte("%CAPTCHAURL%"), []byte(captcha)) - w.Write(data2) - } + if err != nil { + w.WriteHeader(501) + } else { + data2 := bytes.ReplaceAll(data, []byte("%CAPTCHAURL%"), []byte(captcha)) + w.Write(data2) + } } }) router.GET("/site/*filepath", func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { @@ -295,7 +295,7 @@ func (e mainEnv) setupRouter() *httprouter.Router { header := w.Header() header.Set("Access-Control-Allow-Methods", "POST, PUT, DELETE") header.Set("Access-Control-Allow-Origin", "*") - header.Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-Bunker-Token"); + header.Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-Bunker-Token") //} // Adjust status code to 204 w.WriteHeader(http.StatusNoContent) @@ -402,7 +402,7 @@ func logRequest(handler http.Handler) http.Handler { if HealthCheckerCounter == 0 { log.Printf("%d %s %s skiping %s\n", w2.Code, r.Method, r.URL, r.Header.Get("User-Agent")) HealthCheckerCounter = 1 - } else if (HealthCheckerCounter == 100) { + } else if HealthCheckerCounter == 100 { HealthCheckerCounter = 0 } else { HealthCheckerCounter = HealthCheckerCounter + 1 @@ -446,13 +446,13 @@ func setupDB(dbPtr *string, masterKeyPtr *string, customRootToken string) (*dbco //log.Panic("error %s", err.Error()) fmt.Printf("error %s", err.Error()) } - log.Println("Creating default entities: core-send-email-on-login and core-send-sms-on-login"); + log.Println("Creating default entities: core-send-email-on-login and core-send-sms-on-login") db.createLegalBasis("core-send-email-on-login", "", "login", "Send email on login", - "Confirm to allow sending access code using 3rd party email gateway", "consent", - "This consent is required to give you our service.", "active", true, true); + "Confirm to allow sending access code using 3rd party email gateway", "consent", + "This consent is required to give you our service.", "active", true, true) db.createLegalBasis("core-send-sms-on-login", "", "login", "Send SMS on login", - "Confirm to allow sending access code using 3rd party SMS gateway", "consent", - "This consent is required to give you our service.", "active", true, true); + "Confirm to allow sending access code using 3rd party SMS gateway", "consent", + "This consent is required to give you our service.", "active", true, true) fmt.Printf("\nAPI Root token: %s\n\n", rootToken) return db, rootToken, err } @@ -475,8 +475,8 @@ func masterkeyGet(masterKeyPtr *string) ([]byte, error) { masterKeyStr = os.Getenv("DATABUNKER_MASTERKEY") } if len(masterKeyStr) == 0 { - return nil, errors.New("Master key environment variable/parameter is missing") - } + return nil, errors.New("Master key environment variable/parameter is missing") + } if len(masterKeyStr) != 48 { return nil, errors.New("Master key length is wrong") } @@ -508,7 +508,7 @@ func main() { readEnv(&cfg) customRootToken := "" if *demoPtr { - customRootToken = "DEMO" + customRootToken = "DEMO" } else if variableProvided("DATABUNKER_ROOTTOKEN", rootTokenKeyPtr) == true { if rootTokenKeyPtr != nil && len(*rootTokenKeyPtr) > 0 { customRootToken = *rootTokenKeyPtr diff --git a/src/captcha.go b/src/captcha.go index 1203908..9166adb 100644 --- a/src/captcha.go +++ b/src/captcha.go @@ -1,127 +1,126 @@ package main import ( - "log" - "fmt" - "time" - "errors" - "regexp" - "net/http" - "image/png" - "crypto/aes" - "crypto/cipher" - "encoding/hex" - "github.com/julienschmidt/httprouter" - "github.com/gobuffalo/packr" - "github.com/afocus/captcha" + "crypto/aes" + "crypto/cipher" + "encoding/hex" + "errors" + "fmt" + "github.com/afocus/captcha" + "github.com/gobuffalo/packr" + "github.com/julienschmidt/httprouter" + "image/png" + "log" + "net/http" + "regexp" + "time" ) var ( - comic []byte - captchaKey = make([]byte, 16) - regexCaptcha = regexp.MustCompile("^([a-zA-Z0-9]+):([0-9]+)$") + comic []byte + captchaKey = make([]byte, 16) + regexCaptcha = regexp.MustCompile("^([a-zA-Z0-9]+):([0-9]+)$") ) func (e mainEnv) showCaptcha(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - log.Printf("Starting showCaptcha fn") - code := ps.ByName("code") - if len(code) == 0 { - err := errors.New("Bad code") - returnError(w, r, "bad code", 405, err, nil) - return - } - s, err := decryptCaptcha(code) - if err != nil { - returnError(w, r, err.Error(), 405, err, nil) - return - } - log.Printf("Decoded captcha: %s", s) - //box := packr.NewBox("../ui") - //comic, err := box.Find("site/comic.ttf") - //if err != nil { - // returnError(w, r, err.Error(), 405, err, nil) - // return - //} - cap := captcha.New() - cap.SetSize(128, 64) - cap.AddFontFromBytes(comic) - img := cap.CreateCustom(s) - w.WriteHeader(200) - png.Encode(w, img) + log.Printf("Starting showCaptcha fn") + code := ps.ByName("code") + if len(code) == 0 { + err := errors.New("Bad code") + returnError(w, r, "bad code", 405, err, nil) + return + } + s, err := decryptCaptcha(code) + if err != nil { + returnError(w, r, err.Error(), 405, err, nil) + return + } + log.Printf("Decoded captcha: %s", s) + //box := packr.NewBox("../ui") + //comic, err := box.Find("site/comic.ttf") + //if err != nil { + // returnError(w, r, err.Error(), 405, err, nil) + // return + //} + cap := captcha.New() + cap.SetSize(128, 64) + cap.AddFontFromBytes(comic) + img := cap.CreateCustom(s) + w.WriteHeader(200) + png.Encode(w, img) } func initCaptcha(h [16]byte) { - var err error - copy(captchaKey[:], h[:]) - box := packr.NewBox("../ui") - comic, err = box.Find("site/comic.ttf") - if err != nil { - log.Fatalf("Failed to load font") - return - } - //captchaKey = h + var err error + copy(captchaKey[:], h[:]) + box := packr.NewBox("../ui") + comic, err = box.Find("site/comic.ttf") + if err != nil { + log.Fatalf("Failed to load font") + return + } + //captchaKey = h } func generateCaptcha() (string, error) { - code := randNum(6) - //log.Printf("Generate captcha code: %d", code) - now := int32(time.Now().Unix()) - s := fmt.Sprintf("%d:%d", code, now) - plaintext := []byte(s) - //log.Printf("Going to encrypt %s", plaintext) - nonce := []byte("$DataBunker$") - block, err := aes.NewCipher(captchaKey) - if err != nil { - log.Printf("error in aes.NewCipher %s", err) - return "", err - } - aesgcm, err := cipher.NewGCM(block) - if err != nil { - log.Printf("error in cipher.NewGCM: %s", err) - return "", err - } - ciphertext := aesgcm.Seal(nil, nonce, []byte(plaintext), nil) - result := hex.EncodeToString(ciphertext) - //log.Printf("Encoded captcha: %s", result) - //log.Printf("ciphertext : %s", result) - return result, nil + code := randNum(6) + //log.Printf("Generate captcha code: %d", code) + now := int32(time.Now().Unix()) + s := fmt.Sprintf("%d:%d", code, now) + plaintext := []byte(s) + //log.Printf("Going to encrypt %s", plaintext) + nonce := []byte("$DataBunker$") + block, err := aes.NewCipher(captchaKey) + if err != nil { + log.Printf("error in aes.NewCipher %s", err) + return "", err + } + aesgcm, err := cipher.NewGCM(block) + if err != nil { + log.Printf("error in cipher.NewGCM: %s", err) + return "", err + } + ciphertext := aesgcm.Seal(nil, nonce, []byte(plaintext), nil) + result := hex.EncodeToString(ciphertext) + //log.Printf("Encoded captcha: %s", result) + //log.Printf("ciphertext : %s", result) + return result, nil } func decryptCaptcha(data string) (string, error) { - if len(data) > 100 { - return "", errors.New("Ciphertext too long") - } - ciphertext, err := hex.DecodeString(data) - if err != nil { - return "", err - } - nonce := []byte("$DataBunker$") - block, err := aes.NewCipher(captchaKey) - if err != nil { - return "", err - } - aesgcm, err := cipher.NewGCM(block) - if err != nil { - return "", err - } - plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil) - if err != nil { - return "", err - } - match := regexCaptcha.FindStringSubmatch(string(plaintext)) - if len(match) != 3 { - return "", errors.New("Failed to parse captcha") - } - code := match[1] - t := atoi(match[2]) - // check if time expired - now := int32(time.Now().Unix()) - if now > (t+120) { - return "", errors.New("Captcha expired") - } - if t > now { - return "", errors.New("Captcha from the future") - } - return code, nil + if len(data) > 100 { + return "", errors.New("Ciphertext too long") + } + ciphertext, err := hex.DecodeString(data) + if err != nil { + return "", err + } + nonce := []byte("$DataBunker$") + block, err := aes.NewCipher(captchaKey) + if err != nil { + return "", err + } + aesgcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return "", err + } + match := regexCaptcha.FindStringSubmatch(string(plaintext)) + if len(match) != 3 { + return "", errors.New("Failed to parse captcha") + } + code := match[1] + t := atoi(match[2]) + // check if time expired + now := int32(time.Now().Unix()) + if now > (t + 120) { + return "", errors.New("Captcha expired") + } + if t > now { + return "", errors.New("Captcha from the future") + } + return code, nil } - diff --git a/src/conf.go b/src/conf.go index 583ca93..7be8656 100644 --- a/src/conf.go +++ b/src/conf.go @@ -1,65 +1,65 @@ package main import ( - "encoding/json" - "fmt" - "net/http" + "encoding/json" + "fmt" + "net/http" - "github.com/julienschmidt/httprouter" -); + "github.com/julienschmidt/httprouter" +) -func (e mainEnv) setupConfRouter(router *httprouter.Router ) *httprouter.Router { - router.GET("/v1/sys/configuration", e.configurationDump) - router.GET("/v1/sys/uiconfiguration", e.uiConfigurationDump) - router.GET("/v1/sys/cookiesettings", e.cookieSettings) - return router +func (e mainEnv) setupConfRouter(router *httprouter.Router) *httprouter.Router { + router.GET("/v1/sys/configuration", e.configurationDump) + router.GET("/v1/sys/uiconfiguration", e.uiConfigurationDump) + router.GET("/v1/sys/cookiesettings", e.cookieSettings) + return router } func (e mainEnv) cookieSettings(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - resultJSON, scriptsJSON, _, err := e.db.getLegalBasisCookieConf() - if err != nil { - returnError(w, r, "internal error", 405, err, nil) - return - } - resultUIConfJSON, _ := json.Marshal(e.conf.UI) - finalJSON := fmt.Sprintf(`{"status":"ok","ui":%s,"rows":%s,"scripts":%s}`, resultUIConfJSON, resultJSON, scriptsJSON) - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(200) - w.Write([]byte(finalJSON)) + resultJSON, scriptsJSON, _, err := e.db.getLegalBasisCookieConf() + if err != nil { + returnError(w, r, "internal error", 405, err, nil) + return + } + resultUIConfJSON, _ := json.Marshal(e.conf.UI) + finalJSON := fmt.Sprintf(`{"status":"ok","ui":%s,"rows":%s,"scripts":%s}`, resultUIConfJSON, resultJSON, scriptsJSON) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(200) + w.Write([]byte(finalJSON)) } func (e mainEnv) configurationDump(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - if e.enforceAuth(w, r, nil) == "" { - return - } - resultJSON, _ := json.Marshal(e.conf) - finalJSON := fmt.Sprintf(`{"status":"ok","configuration":%s}`, resultJSON) - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(200) - w.Write([]byte(finalJSON)) + if e.enforceAuth(w, r, nil) == "" { + return + } + resultJSON, _ := json.Marshal(e.conf) + finalJSON := fmt.Sprintf(`{"status":"ok","configuration":%s}`, resultJSON) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(200) + w.Write([]byte(finalJSON)) } // UI configuration dump API call. func (e mainEnv) uiConfigurationDump(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - if len(e.conf.Notification.MagicSyncURL) != 0 && - len(e.conf.Notification.MagicSyncToken) != 0 { - e.conf.UI.MagicLookup = true - } else { - e.conf.UI.MagicLookup = false - } - resultJSON, _ := json.Marshal(e.conf.UI) - finalJSON := fmt.Sprintf(`{"status":"ok","ui":%s}`, resultJSON) - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(200) - w.Write([]byte(finalJSON)) + if len(e.conf.Notification.MagicSyncURL) != 0 && + len(e.conf.Notification.MagicSyncToken) != 0 { + e.conf.UI.MagicLookup = true + } else { + e.conf.UI.MagicLookup = false + } + resultJSON, _ := json.Marshal(e.conf.UI) + finalJSON := fmt.Sprintf(`{"status":"ok","ui":%s}`, resultJSON) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(200) + w.Write([]byte(finalJSON)) } func (e mainEnv) globalUserDelete(userTOKEN string) { - // not implemented + // not implemented } func (dbobj dbcon) GetTenantAdmin(cfg Config) string { - return cfg.Generic.AdminEmail + return cfg.Generic.AdminEmail } func (e mainEnv) pluginUserDelete(pluginid string, userTOKEN string) { @@ -71,10 +71,10 @@ func (e mainEnv) pluginUserLookup(email string) { } func (dbobj dbcon) GlobalUserChangeEmail(oldEmail string, newEmail string) { - // not implemented + // not implemented } func (dbobj dbcon) GetCode() []byte { - code := dbobj.hash[4:12] - return code + code := dbobj.hash[4:12] + return code } diff --git a/src/consent_test.go b/src/consent_test.go index fe725eb..78a170c 100644 --- a/src/consent_test.go +++ b/src/consent_test.go @@ -151,13 +151,13 @@ func TestConsentCreateWithdraw(t *testing.T) { t.Fatalf("Wrong number of briefs") } /* - raw, _ = helpGetAllUsersByBrief(brief) - if _, ok := raw["status"]; !ok || raw["status"].(string) != "ok" { - t.Fatalf("Failed to get user consents") - } - if raw["total"].(float64) != 1 { - t.Fatalf("Wrong number of briefs") - } + raw, _ = helpGetAllUsersByBrief(brief) + if _, ok := raw["status"]; !ok || raw["status"].(string) != "ok" { + t.Fatalf("Failed to get user consents") + } + if raw["total"].(float64) != 1 { + t.Fatalf("Wrong number of briefs") + } */ } diff --git a/src/cryptor.go b/src/cryptor.go index 4e301d8..e7ca3c7 100644 --- a/src/cryptor.go +++ b/src/cryptor.go @@ -5,8 +5,8 @@ import ( "crypto/cipher" "crypto/rand" "encoding/base64" - "log" "io" + "log" ) // shamir secret split @@ -78,45 +78,44 @@ func encrypt(masterKey []byte, userKey []byte, plaintext []byte) ([]byte, error) } func basicStringEncrypt(plaintext string, masterKey []byte, code []byte) (string, error) { - //log.Printf("Going to encrypt %s", plaintext) - nonce := []byte("$DataBunker$") - key := append(masterKey, code...) - block, err := aes.NewCipher(key) - if err != nil { - log.Printf("error in aes.NewCipher %s", err) - return "", err - } - aesgcm, err := cipher.NewGCM(block) - if err != nil { - log.Printf("error in cipher.NewGCM: %s", err) - return "", err - } - ciphertext := aesgcm.Seal(nil, nonce, []byte(plaintext), nil) - result := base64.StdEncoding.EncodeToString(ciphertext) - //log.Printf("ciphertext : %s", result) - return result, nil + //log.Printf("Going to encrypt %s", plaintext) + nonce := []byte("$DataBunker$") + key := append(masterKey, code...) + block, err := aes.NewCipher(key) + if err != nil { + log.Printf("error in aes.NewCipher %s", err) + return "", err + } + aesgcm, err := cipher.NewGCM(block) + if err != nil { + log.Printf("error in cipher.NewGCM: %s", err) + return "", err + } + ciphertext := aesgcm.Seal(nil, nonce, []byte(plaintext), nil) + result := base64.StdEncoding.EncodeToString(ciphertext) + //log.Printf("ciphertext : %s", result) + return result, nil } func basicStringDecrypt(data string, masterKey []byte, code []byte) (string, error) { - ciphertext, err := base64.StdEncoding.DecodeString(data) - if err != nil { - return "", err - } - nonce := []byte("$DataBunker$") - key := append(masterKey, code...) - block, err := aes.NewCipher(key) - if err != nil { - return "", err - } - aesgcm, err := cipher.NewGCM(block) - if err != nil { - return "", err - } - plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil) - if err != nil { - return "", err - } - //log.Printf("decrypt result : %s", string(plaintext)) - return string(plaintext), err + ciphertext, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return "", err + } + nonce := []byte("$DataBunker$") + key := append(masterKey, code...) + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + aesgcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return "", err + } + //log.Printf("decrypt result : %s", string(plaintext)) + return string(plaintext), err } - diff --git a/src/email.go b/src/email.go index 684aade..f5b091b 100644 --- a/src/email.go +++ b/src/email.go @@ -1,46 +1,45 @@ package main import ( - "fmt" - "net/smtp" - "strconv" - "strings" + "fmt" + "net/smtp" + "strconv" + "strings" ) func sendCodeByEmail(code int32, address string, cfg Config) { - Dest := []string{address} - Subject := "Access Code" - bodyMessage := "Access code is " + strconv.Itoa(int((code))) - msg := "From: " + cfg.SMTP.Sender + "\n" + - "To: " + strings.Join(Dest, ",") + "\n" + - "Subject: " + Subject + "\n" + bodyMessage - auth := smtp.PlainAuth("", cfg.SMTP.User, cfg.SMTP.Pass, cfg.SMTP.Server) - err := smtp.SendMail(cfg.SMTP.Server+":"+cfg.SMTP.Port, - auth, cfg.SMTP.User, Dest, []byte(msg)) - if err != nil { - fmt.Printf("smtp error: %s", err) - return - } - fmt.Println("Mail sent successfully!") + Dest := []string{address} + Subject := "Access Code" + bodyMessage := "Access code is " + strconv.Itoa(int((code))) + msg := "From: " + cfg.SMTP.Sender + "\n" + + "To: " + strings.Join(Dest, ",") + "\n" + + "Subject: " + Subject + "\n" + bodyMessage + auth := smtp.PlainAuth("", cfg.SMTP.User, cfg.SMTP.Pass, cfg.SMTP.Server) + err := smtp.SendMail(cfg.SMTP.Server+":"+cfg.SMTP.Port, + auth, cfg.SMTP.User, Dest, []byte(msg)) + if err != nil { + fmt.Printf("smtp error: %s", err) + return + } + fmt.Println("Mail sent successfully!") } func adminEmailAlert(action string, adminEmail string, cfg Config) { - if len(adminEmail) == 0 { - return - } - Dest := []string{adminEmail} - Subject := "Data Subject request received" - bodyMessage := "Request: " + action - msg := "From: " + cfg.SMTP.Sender + "\n" + - "To: " + strings.Join(Dest, ",") + "\n" + - "Subject: " + Subject + "\n" + bodyMessage - auth := smtp.PlainAuth("", cfg.SMTP.User, cfg.SMTP.Pass, cfg.SMTP.Server) - err := smtp.SendMail(cfg.SMTP.Server+":"+cfg.SMTP.Port, - auth, cfg.SMTP.User, Dest, []byte(msg)) - if err != nil { - fmt.Printf("smtp error: %s", err) - return - } - fmt.Println("Mail sent successfully!") + if len(adminEmail) == 0 { + return + } + Dest := []string{adminEmail} + Subject := "Data Subject request received" + bodyMessage := "Request: " + action + msg := "From: " + cfg.SMTP.Sender + "\n" + + "To: " + strings.Join(Dest, ",") + "\n" + + "Subject: " + Subject + "\n" + bodyMessage + auth := smtp.PlainAuth("", cfg.SMTP.User, cfg.SMTP.Pass, cfg.SMTP.Server) + err := smtp.SendMail(cfg.SMTP.Server+":"+cfg.SMTP.Port, + auth, cfg.SMTP.User, Dest, []byte(msg)) + if err != nil { + fmt.Printf("smtp error: %s", err) + return + } + fmt.Println("Mail sent successfully!") } - diff --git a/src/expiration_api.go b/src/expiration_api.go index d122bd7..eb5923b 100644 --- a/src/expiration_api.go +++ b/src/expiration_api.go @@ -5,8 +5,8 @@ import ( "net/http" uuid "github.com/hashicorp/go-uuid" - "github.com/securitybunker/databunker/src/storage" "github.com/julienschmidt/httprouter" + "github.com/securitybunker/databunker/src/storage" "go.mongodb.org/mongo-driver/bson" ) @@ -206,4 +206,3 @@ func (e mainEnv) expInitiate(w http.ResponseWriter, r *http.Request, ps httprout w.WriteHeader(200) w.Write([]byte(finalJSON)) } - diff --git a/src/lbasis_api.go b/src/lbasis_api.go index 8792773..02aa543 100644 --- a/src/lbasis_api.go +++ b/src/lbasis_api.go @@ -1,116 +1,116 @@ package main import ( - "fmt" - "net/http" - "reflect" + "fmt" + "net/http" + "reflect" - "github.com/julienschmidt/httprouter" - //"go.mongodb.org/mongo-driver/bson" + "github.com/julienschmidt/httprouter" + //"go.mongodb.org/mongo-driver/bson" ) func (e mainEnv) createLegalBasis(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - brief := ps.ByName("brief") - authResult := e.enforceAdmin(w, r) - if authResult == "" { - return - } - brief = normalizeBrief(brief) - if isValidBrief(brief) == false { - returnError(w, r, "bad brief format", 405, nil, nil) - return - } - records, err := getJSONPostData(r) - if err != nil { - returnError(w, r, "failed to decode request body", 405, err, nil) - return - } - newbrief := getStringValue(records["brief"]) - if len(newbrief) > 0 && newbrief != brief { - if isValidBrief(newbrief) == false { - returnError(w, r, "bad brief format", 405, nil, nil) - return - } - } - status := getStringValue(records["status"]) - module := getStringValue(records["module"]) - fulldesc := getStringValue(records["fulldesc"]) - shortdesc := getStringValue(records["shortdesc"]) - basistype := getStringValue(records["basistype"]) - requiredmsg := getStringValue(records["requiredmsg"]) - usercontrol := false - requiredflag := false - if status != "disabled" { - status = "active" - } - if value, ok := records["usercontrol"]; ok { - if reflect.TypeOf(value).Kind() == reflect.Bool { - usercontrol = value.(bool) - } else { - num := value.(int32) - if num > 0 { - usercontrol = true - } - } - } - if value, ok := records["requiredflag"]; ok { - if reflect.TypeOf(value).Kind() == reflect.Bool { - requiredflag = value.(bool) - } else { - num := value.(int32) - if num > 0 { - requiredflag = true - } - } - } - e.db.createLegalBasis(brief, newbrief, module, shortdesc, fulldesc, basistype, requiredmsg, status, usercontrol, requiredflag) - /* - notifyURL := e.conf.Notification.NotificationURL - if newStatus == true && len(notifyURL) > 0 { - // change notificate on new record or if status change - if len(userTOKEN) > 0 { - notifyConsentChange(notifyURL, brief, status, "token", userTOKEN) - } else { - notifyConsentChange(notifyURL, brief, status, mode, address) - } - } - */ - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(200) - w.Write([]byte(`{"status":"ok"}`)) + brief := ps.ByName("brief") + authResult := e.enforceAdmin(w, r) + if authResult == "" { + return + } + brief = normalizeBrief(brief) + if isValidBrief(brief) == false { + returnError(w, r, "bad brief format", 405, nil, nil) + return + } + records, err := getJSONPostData(r) + if err != nil { + returnError(w, r, "failed to decode request body", 405, err, nil) + return + } + newbrief := getStringValue(records["brief"]) + if len(newbrief) > 0 && newbrief != brief { + if isValidBrief(newbrief) == false { + returnError(w, r, "bad brief format", 405, nil, nil) + return + } + } + status := getStringValue(records["status"]) + module := getStringValue(records["module"]) + fulldesc := getStringValue(records["fulldesc"]) + shortdesc := getStringValue(records["shortdesc"]) + basistype := getStringValue(records["basistype"]) + requiredmsg := getStringValue(records["requiredmsg"]) + usercontrol := false + requiredflag := false + if status != "disabled" { + status = "active" + } + if value, ok := records["usercontrol"]; ok { + if reflect.TypeOf(value).Kind() == reflect.Bool { + usercontrol = value.(bool) + } else { + num := value.(int32) + if num > 0 { + usercontrol = true + } + } + } + if value, ok := records["requiredflag"]; ok { + if reflect.TypeOf(value).Kind() == reflect.Bool { + requiredflag = value.(bool) + } else { + num := value.(int32) + if num > 0 { + requiredflag = true + } + } + } + e.db.createLegalBasis(brief, newbrief, module, shortdesc, fulldesc, basistype, requiredmsg, status, usercontrol, requiredflag) + /* + notifyURL := e.conf.Notification.NotificationURL + if newStatus == true && len(notifyURL) > 0 { + // change notificate on new record or if status change + if len(userTOKEN) > 0 { + notifyConsentChange(notifyURL, brief, status, "token", userTOKEN) + } else { + notifyConsentChange(notifyURL, brief, status, mode, address) + } + } + */ + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(200) + w.Write([]byte(`{"status":"ok"}`)) } func (e mainEnv) deleteLegalBasis(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - brief := ps.ByName("brief") - authResult := e.enforceAdmin(w, r) - if authResult == "" { - return - } - brief = normalizeBrief(brief) - if isValidBrief(brief) == false { - returnError(w, r, "bad brief format", 405, nil, nil) - return - } - e.db.unlinkProcessingActivityBrief(brief) - e.db.deleteLegalBasis(brief); - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(200) - w.Write([]byte(`{"status":"ok"}`)) + brief := ps.ByName("brief") + authResult := e.enforceAdmin(w, r) + if authResult == "" { + return + } + brief = normalizeBrief(brief) + if isValidBrief(brief) == false { + returnError(w, r, "bad brief format", 405, nil, nil) + return + } + e.db.unlinkProcessingActivityBrief(brief) + e.db.deleteLegalBasis(brief) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(200) + w.Write([]byte(`{"status":"ok"}`)) } func (e mainEnv) listLegalBasisRecords(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - authResult := e.enforceAdmin(w, r) - if authResult == "" { - return - } - resultJSON, numRecords, err := e.db.getLegalBasisRecords() - if err != nil { - returnError(w, r, "internal error", 405, err, nil) - return - } - fmt.Printf("Total count of rows: %d\n", numRecords) - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(200) - str := fmt.Sprintf(`{"status":"ok","total":%d,"rows":%s}`, numRecords, resultJSON) - w.Write([]byte(str)) + authResult := e.enforceAdmin(w, r) + if authResult == "" { + return + } + resultJSON, numRecords, err := e.db.getLegalBasisRecords() + if err != nil { + returnError(w, r, "internal error", 405, err, nil) + return + } + fmt.Printf("Total count of rows: %d\n", numRecords) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(200) + str := fmt.Sprintf(`{"status":"ok","total":%d,"rows":%s}`, numRecords, resultJSON) + w.Write([]byte(str)) } diff --git a/src/lbasis_db.go b/src/lbasis_db.go index f5c3992..0992ace 100644 --- a/src/lbasis_db.go +++ b/src/lbasis_db.go @@ -11,24 +11,24 @@ import ( ) type legalBasis struct { - Brief string `json:"brief" structs:"brief"` - Status string `json:"status" structs:"status"` - Module string `json:"module,omitempty" structs:"module,omitempty"` - Shortdesc string `json:"shortdesc,omitempty" structs:"shortdesc,omitempty"` - Fulldesc string `json:"fulldesc,omitempty" structs:"fulldesc,omitempty"` - Basistype string `json:"basistype,omitempty" structs:"basistype"` - Requiredmsg string `json:"requiredmsg,omitempty" structs:"requiredmsg,omitempty"` - Usercontrol bool `json:"usercontrol" structs:"usercontrol"` - Requiredflag bool `json:"requiredflag" structs:"requiredflag"` - Creationtime int32 `json:"creationtime" structs:"creationtime"` + Brief string `json:"brief" structs:"brief"` + Status string `json:"status" structs:"status"` + Module string `json:"module,omitempty" structs:"module,omitempty"` + Shortdesc string `json:"shortdesc,omitempty" structs:"shortdesc,omitempty"` + Fulldesc string `json:"fulldesc,omitempty" structs:"fulldesc,omitempty"` + Basistype string `json:"basistype,omitempty" structs:"basistype"` + Requiredmsg string `json:"requiredmsg,omitempty" structs:"requiredmsg,omitempty"` + Usercontrol bool `json:"usercontrol" structs:"usercontrol"` + Requiredflag bool `json:"requiredflag" structs:"requiredflag"` + Creationtime int32 `json:"creationtime" structs:"creationtime"` } -func (dbobj dbcon) createLegalBasis(brief string, newbrief string, module string, shortdesc string, - fulldesc string, basistype string, requiredmsg string, status string, +func (dbobj dbcon) createLegalBasis(brief string, newbrief string, module string, shortdesc string, + fulldesc string, basistype string, requiredmsg string, status string, usercontrol bool, requiredflag bool) (bool, error) { bdoc := bson.M{} bdoc["basistype"] = basistype - bdoc["module"] = module + bdoc["module"] = module bdoc["shortdesc"] = shortdesc bdoc["fulldesc"] = fulldesc if requiredflag == true { @@ -36,7 +36,7 @@ func (dbobj dbcon) createLegalBasis(brief string, newbrief string, module string } else { bdoc["requiredmsg"] = "" } - bdoc["status"] = status; + bdoc["status"] = status bdoc["usercontrol"] = usercontrol bdoc["requiredflag"] = requiredflag raw, err := dbobj.store.GetRecord(storage.TblName.Legalbasis, "brief", brief) @@ -46,7 +46,7 @@ func (dbobj dbcon) createLegalBasis(brief string, newbrief string, module string } if raw != nil { if len(newbrief) > 0 && newbrief != brief { - bdoc["brief"] = newbrief; + bdoc["brief"] = newbrief } dbobj.store.UpdateRecord(storage.TblName.Legalbasis, "brief", brief, &bdoc) return true, nil @@ -77,7 +77,7 @@ func (dbobj dbcon) deleteLegalBasis(brief string) (bool, error) { bdoc := bson.M{} now := int32(time.Now().Unix()) bdoc["when"] = now - bdoc["status"] = "revoked" + bdoc["status"] = "revoked" dbobj.store.UpdateRecord2(storage.TblName.Agreements, "brief", brief, "status", "yes", &bdoc, nil) bdoc2 := bson.M{} bdoc2["status"] = "deleted" @@ -85,7 +85,6 @@ func (dbobj dbcon) deleteLegalBasis(brief string) (bool, error) { return true, nil } - func (dbobj dbcon) revokeLegalBasis(brief string) (bool, error) { // look up for user with this legal basis bdoc := bson.M{} @@ -98,7 +97,7 @@ func (dbobj dbcon) revokeLegalBasis(brief string) (bool, error) { } func (dbobj dbcon) getLegalBasisCookieConf() ([]byte, []byte, int, error) { - records, err := dbobj.store.GetList(storage.TblName.Legalbasis, "status", "active", 0,0, "requiredflag") + records, err := dbobj.store.GetList(storage.TblName.Legalbasis, "status", "active", 0, 0, "requiredflag") if err != nil { return nil, nil, 0, err } @@ -139,8 +138,8 @@ func (dbobj dbcon) getLegalBasisCookieConf() ([]byte, []byte, int, error) { } if len(found) > 0 { bdoc := bson.M{} - bdoc["script"]= record["script"] - bdoc["briefs"] = found; + bdoc["script"] = record["script"] + bdoc["briefs"] = found //fmt.Println("appending bdoc script") scripts = append(scripts, bdoc) } diff --git a/src/pactivities_api.go b/src/pactivities_api.go index 4acdc16..7d40401 100644 --- a/src/pactivities_api.go +++ b/src/pactivities_api.go @@ -109,7 +109,7 @@ func (e mainEnv) pactivityLink(w http.ResponseWriter, r *http.Request, ps httpro } if exists == false { returnError(w, r, "not found", 405, nil, nil) - return + return } _, err = e.db.linkProcessingActivity(activity, brief) if err != nil { diff --git a/src/pactivities_db.go b/src/pactivities_db.go index a5d3987..42d2531 100644 --- a/src/pactivities_db.go +++ b/src/pactivities_db.go @@ -38,7 +38,7 @@ func (dbobj dbcon) createProcessingActivity(activity string, newactivity string, } if raw != nil { if len(newactivity) > 0 && newactivity != activity { - bdoc["activity"] = newactivity; + bdoc["activity"] = newactivity } _, err = dbobj.store.UpdateRecord(storage.TblName.Processingactivities, "activity", activity, &bdoc) return false, err diff --git a/src/requests_api.go b/src/requests_api.go index a780a99..2fb2b22 100644 --- a/src/requests_api.go +++ b/src/requests_api.go @@ -3,8 +3,8 @@ package main import ( "errors" "fmt" - "strings" "net/http" + "strings" "github.com/julienschmidt/httprouter" "go.mongodb.org/mongo-driver/bson" @@ -176,7 +176,7 @@ func (e mainEnv) approveUserRequest(w http.ResponseWriter, r *http.Request, ps h returnError(w, r, "failed to decode request body", 405, err, event) return } - reason := getStringValue(records["reason"]); + reason := getStringValue(records["reason"]) requestInfo, err := e.db.getRequest(request) if err != nil { returnError(w, r, "internal error", 405, err, event) @@ -193,7 +193,7 @@ func (e mainEnv) approveUserRequest(w http.ResponseWriter, r *http.Request, ps h action := getStringValue(requestInfo["action"]) status := getStringValue(requestInfo["status"]) if status != "open" { - returnError(w, r, "wrong status: " + status, 405, err, event) + returnError(w, r, "wrong status: "+status, 405, err, event) return } userJSON, userBSON, err := e.db.getUser(userTOKEN) @@ -267,7 +267,7 @@ func (e mainEnv) cancelUserRequest(w http.ResponseWriter, r *http.Request, ps ht returnError(w, r, "failed to decode request body", 405, err, event) return } - reason := getStringValue(records["reason"]); + reason := getStringValue(records["reason"]) requestInfo, err := e.db.getRequest(request) if err != nil { returnError(w, r, "internal error", 405, err, event) @@ -286,7 +286,7 @@ func (e mainEnv) cancelUserRequest(w http.ResponseWriter, r *http.Request, ps ht return } if requestInfo["status"].(string) != "open" { - returnError(w, r, "wrong status: " + requestInfo["status"].(string), 405, err, event) + returnError(w, r, "wrong status: "+requestInfo["status"].(string), 405, err, event) return } resultJSON, err := e.db.getUserJson(userTOKEN) diff --git a/src/schema.go b/src/schema.go index a9aac84..e813cf4 100644 --- a/src/schema.go +++ b/src/schema.go @@ -1,19 +1,19 @@ package main import ( - "context" - "errors" - "encoding/json" - "fmt" - "io/ioutil" - "os" - "path/filepath" - "strings" - "strconv" + "context" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strconv" + "strings" - "github.com/securitybunker/jsonschema" - jsonpatch "github.com/evanphx/json-patch" - jptr "github.com/qri-io/jsonpointer" + jsonpatch "github.com/evanphx/json-patch" + jptr "github.com/qri-io/jsonpointer" + "github.com/securitybunker/jsonschema" ) var userSchema *jsonschema.Schema @@ -24,224 +24,224 @@ type IsLocked bool type IsPreserve bool func loadUserSchema(cfg Config, confFile *string) error { - fileSchema := cfg.Generic.UserRecordSchema - parentDir := "" - if confFile != nil && len(*confFile) > 0 { - parentDir = filepath.Base(*confFile) - if parentDir != "." { - parentDir = "" - } - } - if len(fileSchema) == 0 { - return nil - } - if strings.HasPrefix(fileSchema, "./") { - _, err := os.Stat(cfg.Generic.UserRecordSchema) - if os.IsNotExist(err) && confFile != nil { - fileSchema = parentDir + fileSchema[2:] + fileSchema := cfg.Generic.UserRecordSchema + parentDir := "" + if confFile != nil && len(*confFile) > 0 { + parentDir = filepath.Base(*confFile) + if parentDir != "." { + parentDir = "" + } } - } else { - fileSchema = parentDir + fileSchema - } - _, err := os.Stat(fileSchema) - if os.IsNotExist(err) { - return err - } - schemaData, err := ioutil.ReadFile(fileSchema) - if err != nil { - return err - } - rs := &jsonschema.Schema{} - jsonschema.LoadDraft2019_09() - jsonschema.RegisterKeyword("admin", newIsAdmin) - jsonschema.RegisterKeyword("locked", newIsLocked) - jsonschema.RegisterKeyword("preserve", newIsPreserve) - err = rs.UnmarshalJSON(schemaData) - if err != nil { - return err - } - userSchema = rs - return nil + if len(fileSchema) == 0 { + return nil + } + if strings.HasPrefix(fileSchema, "./") { + _, err := os.Stat(cfg.Generic.UserRecordSchema) + if os.IsNotExist(err) && confFile != nil { + fileSchema = parentDir + fileSchema[2:] + } + } else { + fileSchema = parentDir + fileSchema + } + _, err := os.Stat(fileSchema) + if os.IsNotExist(err) { + return err + } + schemaData, err := ioutil.ReadFile(fileSchema) + if err != nil { + return err + } + rs := &jsonschema.Schema{} + jsonschema.LoadDraft2019_09() + jsonschema.RegisterKeyword("admin", newIsAdmin) + jsonschema.RegisterKeyword("locked", newIsLocked) + jsonschema.RegisterKeyword("preserve", newIsPreserve) + err = rs.UnmarshalJSON(schemaData) + if err != nil { + return err + } + userSchema = rs + return nil } func UserSchemaEnabled() bool { - if userSchema == nil { - return false - } - return true + if userSchema == nil { + return false + } + return true } func validateUserRecord(record []byte) error { - if userSchema == nil { - return nil - } - var doc interface{} - if err := json.Unmarshal(record, &doc); err != nil { - return err - } - result := userSchema.Validate(nil, doc) - if len(*result.Errs) > 0 { - return (*result.Errs)[0] - } - return nil + if userSchema == nil { + return nil + } + var doc interface{} + if err := json.Unmarshal(record, &doc); err != nil { + return err + } + result := userSchema.Validate(nil, doc) + if len(*result.Errs) > 0 { + return (*result.Errs)[0] + } + return nil } func validateUserRecordChange(oldRecord []byte, newRecord []byte, authResult string) (bool, error) { - if userSchema == nil { - return false, nil - } - var oldDoc interface{} - var newDoc interface{} - if err := json.Unmarshal(oldRecord, &oldDoc); err != nil { - return false, err - } - if err := json.Unmarshal(newRecord, &newDoc); err != nil { - return false, err - } - result := userSchema.Validate(nil, newDoc) - //if len(*result.Errs) > 0 { - // return (*result.Errs)[0] - //} - result2 := userSchema.Validate(nil, oldDoc) - if len(*result2.Errs) > 0 { - return false, (*result.Errs)[0] - } - if result.ExtendedResults == nil { - return false, nil - } - adminRecordChanged := false - for _, r := range *result.ExtendedResults { - fmt.Printf("path: %s key: %s data: %v\n", r.PropertyPath, r.Key, r.Value) - if r.Key == "locked" || (r.Key == "admin" && authResult == "login" && adminRecordChanged == false) { - pointer, _ := jptr.Parse(r.PropertyPath) - data1, _ := pointer.Eval(oldDoc) - data1Binary, _ := json.Marshal(data1) - data2, _ := pointer.Eval(newDoc) - data2Binary, _ := json.Marshal(data2) - if !jsonpatch.Equal(data1Binary, data2Binary) { - if r.Key == "locked" { - fmt.Printf("Locked value changed. Old: %s New %s\n", data1Binary, data2Binary) - return false, errors.New("User schema check error. Locked value changed: "+r.PropertyPath) - } else { - fmt.Printf("Admin value changed. Approval required. Old: %s New %s\n", data1Binary, data2Binary) - adminRecordChanged = true - } - } - } - } - return adminRecordChanged, nil + if userSchema == nil { + return false, nil + } + var oldDoc interface{} + var newDoc interface{} + if err := json.Unmarshal(oldRecord, &oldDoc); err != nil { + return false, err + } + if err := json.Unmarshal(newRecord, &newDoc); err != nil { + return false, err + } + result := userSchema.Validate(nil, newDoc) + //if len(*result.Errs) > 0 { + // return (*result.Errs)[0] + //} + result2 := userSchema.Validate(nil, oldDoc) + if len(*result2.Errs) > 0 { + return false, (*result.Errs)[0] + } + if result.ExtendedResults == nil { + return false, nil + } + adminRecordChanged := false + for _, r := range *result.ExtendedResults { + fmt.Printf("path: %s key: %s data: %v\n", r.PropertyPath, r.Key, r.Value) + if r.Key == "locked" || (r.Key == "admin" && authResult == "login" && adminRecordChanged == false) { + pointer, _ := jptr.Parse(r.PropertyPath) + data1, _ := pointer.Eval(oldDoc) + data1Binary, _ := json.Marshal(data1) + data2, _ := pointer.Eval(newDoc) + data2Binary, _ := json.Marshal(data2) + if !jsonpatch.Equal(data1Binary, data2Binary) { + if r.Key == "locked" { + fmt.Printf("Locked value changed. Old: %s New %s\n", data1Binary, data2Binary) + return false, errors.New("User schema check error. Locked value changed: " + r.PropertyPath) + } else { + fmt.Printf("Admin value changed. Approval required. Old: %s New %s\n", data1Binary, data2Binary) + adminRecordChanged = true + } + } + } + } + return adminRecordChanged, nil } func cleanupRecord(record []byte) ([]byte, map[string]interface{}) { - if userSchema == nil { - return nil, nil - } - var doc interface{} - if err := json.Unmarshal(record, &doc); err != nil { - return nil, nil - } - result := userSchema.Validate(nil, doc) - if result.ExtendedResults == nil { - return nil, nil - } - - doc1 := make(map[string]interface{}) - doc2 := make([]interface{},1) - nested := func(path string, data interface{}) { - currentStr := &doc1 - currentNum := &doc2 - keys := strings.Split(path, "/") - fmt.Printf("path: %s\n", path) - for i,k := range keys { - if len(k) == 0 { - continue - } - if kNum, err := strconv.Atoi(k); err == nil { - if (i+1) == len(keys) { - (*currentNum)[kNum] = data - } else if kNxt, err := strconv.Atoi(keys[i+1]); err == nil { - if (*currentNum)[kNum] == nil { - v := make([]interface{}, kNxt+1) - (*currentNum)[kNum] = v - currentNum = &v - } else { - v := (*currentNum)[kNum].([]interface{}) - for (len(v) < kNxt+1) { - v = append(v, nil) - } - (*currentNum)[kNum] = v - currentNum = &v - } - } else { - if (*currentNum)[kNum] == nil { - v := make(map[string]interface{}) - (*currentNum)[kNum] = v - currentStr = &v - } else { - v := (*currentNum)[kNum].(map[string]interface{}) - currentStr = &v - } - } - } else { - if (i+1) == len(keys) { - (*currentStr)[k] = data - } else if kNxt, err := strconv.Atoi(keys[i+1]); err == nil { - if _, ok := (*currentStr)[k]; !ok { - v := make([]interface{}, kNxt+1) - (*currentStr)[k] = v - currentNum = &v - } else { - v := (*currentStr)[k].([]interface{}) - for (len(v) < kNxt+1) { - v = append(v, nil) - } - (*currentStr)[k] = v - currentNum = &v - } - } else { - if _, ok := (*currentStr)[k]; !ok { - v := make(map[string]interface{}) - (*currentStr)[k] = v - currentStr = &v - } else { - v := (*currentNum)[kNum].(map[string]interface{}) - currentStr = &v - } - } - } - } - } + if userSchema == nil { + return nil, nil + } + var doc interface{} + if err := json.Unmarshal(record, &doc); err != nil { + return nil, nil + } + result := userSchema.Validate(nil, doc) + if result.ExtendedResults == nil { + return nil, nil + } - found := false - for _, r := range *result.ExtendedResults { - fmt.Printf("path: %s key: %s data: %v\n", r.PropertyPath, r.Key, r.Value) - if r.Key == "preserve" { - //pointer, _ := jptr.Parse(r.PropertyPath) - //data1, _ := pointer.Eval(oldDoc) - nested(r.PropertyPath, r.Value) - found = true - } - } - if found == false { - return nil, nil - } - //fmt.Printf("final doc1 %v\n", doc1) - dataBinary, _ := json.Marshal(doc1) - //fmt.Println(err) - fmt.Printf("data bin %s\n", dataBinary) - return dataBinary, doc1 + doc1 := make(map[string]interface{}) + doc2 := make([]interface{}, 1) + nested := func(path string, data interface{}) { + currentStr := &doc1 + currentNum := &doc2 + keys := strings.Split(path, "/") + fmt.Printf("path: %s\n", path) + for i, k := range keys { + if len(k) == 0 { + continue + } + if kNum, err := strconv.Atoi(k); err == nil { + if (i + 1) == len(keys) { + (*currentNum)[kNum] = data + } else if kNxt, err := strconv.Atoi(keys[i+1]); err == nil { + if (*currentNum)[kNum] == nil { + v := make([]interface{}, kNxt+1) + (*currentNum)[kNum] = v + currentNum = &v + } else { + v := (*currentNum)[kNum].([]interface{}) + for len(v) < kNxt+1 { + v = append(v, nil) + } + (*currentNum)[kNum] = v + currentNum = &v + } + } else { + if (*currentNum)[kNum] == nil { + v := make(map[string]interface{}) + (*currentNum)[kNum] = v + currentStr = &v + } else { + v := (*currentNum)[kNum].(map[string]interface{}) + currentStr = &v + } + } + } else { + if (i + 1) == len(keys) { + (*currentStr)[k] = data + } else if kNxt, err := strconv.Atoi(keys[i+1]); err == nil { + if _, ok := (*currentStr)[k]; !ok { + v := make([]interface{}, kNxt+1) + (*currentStr)[k] = v + currentNum = &v + } else { + v := (*currentStr)[k].([]interface{}) + for len(v) < kNxt+1 { + v = append(v, nil) + } + (*currentStr)[k] = v + currentNum = &v + } + } else { + if _, ok := (*currentStr)[k]; !ok { + v := make(map[string]interface{}) + (*currentStr)[k] = v + currentStr = &v + } else { + v := (*currentNum)[kNum].(map[string]interface{}) + currentStr = &v + } + } + } + } + } + + found := false + for _, r := range *result.ExtendedResults { + fmt.Printf("path: %s key: %s data: %v\n", r.PropertyPath, r.Key, r.Value) + if r.Key == "preserve" { + //pointer, _ := jptr.Parse(r.PropertyPath) + //data1, _ := pointer.Eval(oldDoc) + nested(r.PropertyPath, r.Value) + found = true + } + } + if found == false { + return nil, nil + } + //fmt.Printf("final doc1 %v\n", doc1) + dataBinary, _ := json.Marshal(doc1) + //fmt.Println(err) + fmt.Printf("data bin %s\n", dataBinary) + return dataBinary, doc1 } /*******************************************************************/ // Admin keyword. Any change in this record requires admin approval. func newIsAdmin() jsonschema.Keyword { - return new(IsAdmin) + return new(IsAdmin) } // Validate implements jsonschema.Keyword func (f *IsAdmin) Validate(propPath string, data interface{}, errs *[]jsonschema.KeyError) { - fmt.Printf("Validate: %s -> %v\n", propPath, data) + fmt.Printf("Validate: %s -> %v\n", propPath, data) } // Register implements jsonschema.Keyword @@ -250,25 +250,25 @@ func (f *IsAdmin) Register(uri string, registry *jsonschema.SchemaRegistry) { // Resolve implements jsonschema.Keyword func (f *IsAdmin) Resolve(pointer jptr.Pointer, uri string) *jsonschema.Schema { - fmt.Printf("Resolve %s\n", uri) - return nil + fmt.Printf("Resolve %s\n", uri) + return nil } func (f *IsAdmin) ValidateKeyword(ctx context.Context, currentState *jsonschema.ValidationState, data interface{}) { - //fmt.Printf("ValidateKeyword admin %s => %v\n", currentState.InstanceLocation.String(), data) - currentState.AddExtendedResult("admin", data) + //fmt.Printf("ValidateKeyword admin %s => %v\n", currentState.InstanceLocation.String(), data) + currentState.AddExtendedResult("admin", data) } /*******************************************************************/ // Locked keyword - meaningin that value should never be changed after record creation func newIsLocked() jsonschema.Keyword { - return new(IsLocked) + return new(IsLocked) } // Validate implements jsonschema.Keyword func (f *IsLocked) Validate(propPath string, data interface{}, errs *[]jsonschema.KeyError) { - fmt.Printf("Validate: %s -> %v\n", propPath, data) + fmt.Printf("Validate: %s -> %v\n", propPath, data) } // Register implements jsonschema.Keyword @@ -277,25 +277,25 @@ func (f *IsLocked) Register(uri string, registry *jsonschema.SchemaRegistry) { // Resolve implements jsonschema.Keyword func (f *IsLocked) Resolve(pointer jptr.Pointer, uri string) *jsonschema.Schema { - fmt.Printf("Resolve %s\n", uri) - return nil + fmt.Printf("Resolve %s\n", uri) + return nil } func (f *IsLocked) ValidateKeyword(ctx context.Context, currentState *jsonschema.ValidationState, data interface{}) { - //fmt.Printf("ValidateKeyword locked %s => %v\n", currentState.InstanceLocation.String(), data) - currentState.AddExtendedResult("locked", data) + //fmt.Printf("ValidateKeyword locked %s => %v\n", currentState.InstanceLocation.String(), data) + currentState.AddExtendedResult("locked", data) } /*******************************************************************/ // Preserve keyword - meaningin that value should never be deleted (after user is delete it is left) func newIsPreserve() jsonschema.Keyword { - return new(IsPreserve) + return new(IsPreserve) } // Validate implements jsonschema.Keyword func (f *IsPreserve) Validate(propPath string, data interface{}, errs *[]jsonschema.KeyError) { - fmt.Printf("Validate: %s -> %v\n", propPath, data) + fmt.Printf("Validate: %s -> %v\n", propPath, data) } // Register implements jsonschema.Keyword @@ -304,11 +304,11 @@ func (f *IsPreserve) Register(uri string, registry *jsonschema.SchemaRegistry) { // Resolve implements jsonschema.Keyword func (f *IsPreserve) Resolve(pointer jptr.Pointer, uri string) *jsonschema.Schema { - fmt.Printf("Resolve %s\n", uri) - return nil + fmt.Printf("Resolve %s\n", uri) + return nil } func (f *IsPreserve) ValidateKeyword(ctx context.Context, currentState *jsonschema.ValidationState, data interface{}) { - //fmt.Printf("ValidateKeyword preserve %s => %v\n", currentState.InstanceLocation.String(), data) - currentState.AddExtendedResult("preserve", data) + //fmt.Printf("ValidateKeyword preserve %s => %v\n", currentState.InstanceLocation.String(), data) + currentState.AddExtendedResult("preserve", data) } diff --git a/src/sessions_api.go b/src/sessions_api.go index a44494b..bcf8ce2 100644 --- a/src/sessions_api.go +++ b/src/sessions_api.go @@ -3,19 +3,21 @@ package main import ( "encoding/json" "fmt" - "net/http" - "strings" uuid "github.com/hashicorp/go-uuid" "github.com/julienschmidt/httprouter" "github.com/securitybunker/databunker/src/storage" "go.mongodb.org/mongo-driver/bson" + "net/http" + "strings" ) func (e mainEnv) createSession(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { session := ps.ByName("session") var event *auditEvent defer func() { - if event != nil { event.submit(e.db) } + if event != nil { + event.submit(e.db) + } }() if enforceUUID(w, session, event) == false { //returnError(w, r, "bad session format", nil, event) @@ -68,21 +70,21 @@ func (e mainEnv) createSession(w http.ResponseWriter, r *http.Request, ps httpro } func (e mainEnv) deleteSession(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - session := ps.ByName("session") - event := audit("delete session", session, "session", session) - defer func() { event.submit(e.db) }() - if enforceUUID(w, session, event) == false { - //returnError(w, r, "bad session format", nil, event) - return - } - authResult := e.enforceAdmin(w, r) - if authResult == "" { - return - } + session := ps.ByName("session") + event := audit("delete session", session, "session", session) + defer func() { event.submit(e.db) }() + if enforceUUID(w, session, event) == false { + //returnError(w, r, "bad session format", nil, event) + return + } + authResult := e.enforceAdmin(w, r) + if authResult == "" { + return + } e.db.deleteSession(session) w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(200) - fmt.Fprintf(w, `{"status":"ok"}`) + w.WriteHeader(200) + fmt.Fprintf(w, `{"status":"ok"}`) } func (e mainEnv) newUserSession(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { @@ -206,7 +208,9 @@ func (e mainEnv) getSession(w http.ResponseWriter, r *http.Request, ps httproute session := ps.ByName("session") var event *auditEvent defer func() { - if event != nil { event.submit(e.db) } + if event != nil { + event.submit(e.db) + } }() when, record, userTOKEN, err := e.db.getSession(session) if err != nil { diff --git a/src/sessions_db.go b/src/sessions_db.go index bca2735..d9827b9 100644 --- a/src/sessions_db.go +++ b/src/sessions_db.go @@ -72,15 +72,15 @@ func (dbobj dbcon) getSession(sessionUUID string) (int32, []byte, string, error) recordKey0 := record["key"].(string) recordKey, err := base64.StdEncoding.DecodeString(recordKey0) if err != nil { - return 0, nil, "", err + return 0, nil, "", err } encData, err := base64.StdEncoding.DecodeString(encData0) if err != nil { - return 0, nil, "", err + return 0, nil, "", err } decrypted, err := decrypt(dbobj.masterKey, recordKey, encData) if err != nil { - return 0, nil, "", err + return 0, nil, "", err } return when, decrypted, userTOKEN, err } @@ -114,4 +114,3 @@ func (dbobj dbcon) deleteSession(sessionUUID string) (bool, error) { dbobj.store.DeleteRecord(storage.TblName.Sessions, "session", sessionUUID) return true, nil } - diff --git a/src/sms.go b/src/sms.go index e9f7a12..8f3c7d2 100644 --- a/src/sms.go +++ b/src/sms.go @@ -20,7 +20,7 @@ func sendCodeByPhoneDo(domain string, client *http.Client, code int32, address s log.Printf("SMS gateway provider URL is missing") return } - msg := "Databunker code "+strconv.Itoa(int(code)) + msg := "Databunker code " + strconv.Itoa(int(code)) finalUrl := cfg.Sms.Url finalUrl = strings.ReplaceAll(finalUrl, "_PHONE_", url.QueryEscape(address)) finalUrl = strings.ReplaceAll(finalUrl, "_FROM_", url.QueryEscape(cfg.Sms.From)) diff --git a/src/storage/mysql-storage.go b/src/storage/mysql-storage.go index d62a344..5b54a05 100644 --- a/src/storage/mysql-storage.go +++ b/src/storage/mysql-storage.go @@ -18,13 +18,12 @@ import ( ) var ( - allTables []string + allTables []string ) - // MySQL struct is used to store database object type MySQLDB struct { - db *sql.DB + db *sql.DB } func (dbobj MySQLDB) getConnectionString(dbname *string) string { @@ -165,7 +164,7 @@ func (dbobj *MySQLDB) InitDB(dbname *string) error { } func (dbobj MySQLDB) Ping() error { - return dbobj.db.Ping() + return dbobj.db.Ping() } // CloseDB function closes the open database @@ -346,25 +345,25 @@ func (dbobj MySQLDB) CreateRecord(t Tbl, data interface{}) (int, error) { // CountRecords0 returns number of records in table func (dbobj MySQLDB) CountRecords0(t Tbl) (int64, error) { - tbl := GetTable(t) - q := "select count(*) from " + tbl - //fmt.Printf("q: %s\n", q) + tbl := GetTable(t) + q := "select count(*) from " + tbl + //fmt.Printf("q: %s\n", q) - tx, err := dbobj.db.Begin() - if err != nil { - return 0, err - } - defer tx.Rollback() - row := tx.QueryRow(q) - var count int - err = row.Scan(&count) - if err != nil { - return 0, err - } - if err = tx.Commit(); err != nil { - return 0, err - } - return int64(count), nil + tx, err := dbobj.db.Begin() + if err != nil { + return 0, err + } + defer tx.Rollback() + row := tx.QueryRow(q) + var count int + err = row.Scan(&count) + if err != nil { + return 0, err + } + if err = tx.Commit(); err != nil { + return 0, err + } + return int64(count), nil } // CountRecords returns number of records that match filter @@ -444,19 +443,19 @@ func (dbobj MySQLDB) updateRecordInTableDo(table string, filter string, bdoc *bs // Lookup record by multiple fields func (dbobj MySQLDB) LookupRecord(t Tbl, row bson.M) (bson.M, error) { - table := GetTable(t) - q := "select * from " + table + " WHERE " - num := 1 - values := make([]interface{}, 0) - for keyName, keyValue := range row { - q = q + dbobj.escapeName(keyName) + "=?" - if num < len(row) { - q = q + " AND " - } - values = append(values, keyValue) - num = num + 1 - } - return dbobj.getRecordInTableDo(q, values) + table := GetTable(t) + q := "select * from " + table + " WHERE " + num := 1 + values := make([]interface{}, 0) + for keyName, keyValue := range row { + q = q + dbobj.escapeName(keyName) + "=?" + if num < len(row) { + q = q + " AND " + } + values = append(values, keyValue) + num = num + 1 + } + return dbobj.getRecordInTableDo(q, values) } // GetRecord returns specific record from database @@ -730,7 +729,7 @@ func (dbobj MySQLDB) GetExpiring(t Tbl, keyName string, keyValue string) ([]bson table := GetTable(t) now := int32(time.Now().Unix()) q := fmt.Sprintf("select * from %s WHERE endtime>0 AND endtime<%d AND %s=?", - table, now, dbobj.escapeName(keyName)) + table, now, dbobj.escapeName(keyName)) fmt.Printf("q: %s\n", q) values := make([]interface{}, 0) values = append(values, keyValue) @@ -907,13 +906,13 @@ func (dbobj MySQLDB) IndexNewApp(appName string) { // it is a new app, create an index log.Printf("This is a new app, creating table & index for: %s\n", appName) queries := []string{ - `CREATE TABLE IF NOT EXISTS ` + appName + ` (`+ - `token TINYTEXT,`+ - `md5 TINYTEXT,`+ - `rofields TINYTEXT,`+ - `data TEXT,`+ - `status TINYTEXT,`+ - "`when` int);", + `CREATE TABLE IF NOT EXISTS ` + appName + ` (` + + `token TINYTEXT,` + + `md5 TINYTEXT,` + + `rofields TINYTEXT,` + + `data TEXT,` + + `status TINYTEXT,` + + "`when` int);", "CREATE UNIQUE INDEX " + appName + "_token ON " + appName + " (token(36));"} err := dbobj.execQueries(queries) if err == nil { @@ -925,20 +924,20 @@ func (dbobj MySQLDB) IndexNewApp(appName string) { func (dbobj MySQLDB) initUsers() error { queries := []string{ - `CREATE TABLE IF NOT EXISTS users (`+ - `token TINYTEXT,`+ - "`key` TINYTEXT,"+ - `md5 TINYTEXT,`+ - `loginidx TINYTEXT,`+ - `emailidx TINYTEXT,`+ - `phoneidx TINYTEXT,`+ - `customidx TINYTEXT,`+ - `expstatus TINYTEXT,`+ - `exptoken TINYTEXT,`+ - `endtime int,`+ - `tempcodeexp int,`+ - `tempcode int,`+ - `data TEXT);`, + `CREATE TABLE IF NOT EXISTS users (` + + `token TINYTEXT,` + + "`key` TINYTEXT," + + `md5 TINYTEXT,` + + `loginidx TINYTEXT,` + + `emailidx TINYTEXT,` + + `phoneidx TINYTEXT,` + + `customidx TINYTEXT,` + + `expstatus TINYTEXT,` + + `exptoken TINYTEXT,` + + `endtime int,` + + `tempcodeexp int,` + + `tempcode int,` + + `data TEXT);`, `CREATE UNIQUE INDEX users_token ON users (token(36));`, `CREATE INDEX users_login ON users (loginidx(36));`, `CREATE INDEX users_email ON users (emailidx(36));`, @@ -951,13 +950,13 @@ func (dbobj MySQLDB) initUsers() error { func (dbobj MySQLDB) initXTokens() error { queries := []string{ - `CREATE TABLE IF NOT EXISTS xtokens (`+ - `xtoken TINYTEXT,`+ - `token TINYTEXT,`+ - `type TINYTEXT,`+ - `app TINYTEXT,`+ - `fields TINYTEXT,`+ - `endtime int);`, + `CREATE TABLE IF NOT EXISTS xtokens (` + + `xtoken TINYTEXT,` + + `token TINYTEXT,` + + `type TINYTEXT,` + + `app TINYTEXT,` + + `fields TINYTEXT,` + + `endtime int);`, `CREATE UNIQUE INDEX xtokens_xtoken ON xtokens (xtoken(36));`, `CREATE INDEX xtokens_type ON xtokens (type(20));`} return dbobj.execQueries(queries) @@ -965,35 +964,35 @@ func (dbobj MySQLDB) initXTokens() error { func (dbobj MySQLDB) initSharedRecords() error { queries := []string{ - `CREATE TABLE IF NOT EXISTS sharedrecords (`+ - `token TINYTEXT,`+ - `record TINYTEXT,`+ - `partner TINYTEXT,`+ - `session TINYTEXT,`+ - `app TINYTEXT,`+ - `fields TINYTEXT,`+ - `endtime int,`+ - "`when` int);", + `CREATE TABLE IF NOT EXISTS sharedrecords (` + + `token TINYTEXT,` + + `record TINYTEXT,` + + `partner TINYTEXT,` + + `session TINYTEXT,` + + `app TINYTEXT,` + + `fields TINYTEXT,` + + `endtime int,` + + "`when` int);", `CREATE INDEX sharedrecords_record ON sharedrecords (record(36));`} return dbobj.execQueries(queries) } func (dbobj MySQLDB) initAudit() error { queries := []string{ - `CREATE TABLE IF NOT EXISTS audit (`+ - `atoken TINYTEXT,`+ - `identity TINYTEXT,`+ - `record TINYTEXT,`+ - `who TINYTEXT,`+ - `mode TINYTEXT,`+ - `app TINYTEXT,`+ - `title TINYTEXT,`+ - `status TINYTEXT,`+ - `msg TINYTEXT,`+ - `debug TINYTEXT,`+ - "`before` TEXT,"+ - `after TEXT,`+ - "`when` int);", + `CREATE TABLE IF NOT EXISTS audit (` + + `atoken TINYTEXT,` + + `identity TINYTEXT,` + + `record TINYTEXT,` + + `who TINYTEXT,` + + `mode TINYTEXT,` + + `app TINYTEXT,` + + `title TINYTEXT,` + + `status TINYTEXT,` + + `msg TINYTEXT,` + + `debug TINYTEXT,` + + "`before` TEXT," + + `after TEXT,` + + "`when` int);", `CREATE INDEX audit_atoken ON audit (atoken(36));`, `CREATE INDEX audit_record ON audit (record(36));`} return dbobj.execQueries(queries) @@ -1001,17 +1000,17 @@ func (dbobj MySQLDB) initAudit() error { func (dbobj MySQLDB) initRequests() error { queries := []string{ - `CREATE TABLE IF NOT EXISTS requests (`+ - `rtoken TINYTEXT,`+ - `token TINYTEXT,`+ - `app TINYTEXT,`+ - `brief TINYTEXT,`+ - `action TINYTEXT,`+ - `status TINYTEXT,`+ - "`change` TINYTEXT,"+ - `reason TINYTEXT,`+ - `creationtime int,`+ - "`when` int);", + `CREATE TABLE IF NOT EXISTS requests (` + + `rtoken TINYTEXT,` + + `token TINYTEXT,` + + `app TINYTEXT,` + + `brief TINYTEXT,` + + `action TINYTEXT,` + + `status TINYTEXT,` + + "`change` TINYTEXT," + + `reason TINYTEXT,` + + `creationtime int,` + + "`when` int);", `CREATE INDEX requests_rtoken ON requests (rtoken(36));`, `CREATE INDEX requests_token ON requests (token(36));`, `CREATE INDEX requests_status ON requests (status(20));`} @@ -1020,50 +1019,50 @@ func (dbobj MySQLDB) initRequests() error { func (dbobj MySQLDB) initProcessingactivities() error { queries := []string{ - `CREATE TABLE IF NOT EXISTS processingactivities (`+ - `activity TINYTEXT,`+ - `title TINYTEXT,`+ - `script TEXT,`+ - `fulldesc TINYTEXT,`+ - `legalbasis TINYTEXT,`+ - `applicableto TINYTEXT,`+ - `creationtime int);`, + `CREATE TABLE IF NOT EXISTS processingactivities (` + + `activity TINYTEXT,` + + `title TINYTEXT,` + + `script TEXT,` + + `fulldesc TINYTEXT,` + + `legalbasis TINYTEXT,` + + `applicableto TINYTEXT,` + + `creationtime int);`, `CREATE INDEX processingactivities_activity ON processingactivities (activity(36));`} return dbobj.execQueries(queries) } func (dbobj MySQLDB) initLegalbasis() error { queries := []string{ - `CREATE TABLE IF NOT EXISTS legalbasis (`+ - `brief TINYTEXT,`+ - `status TINYTEXT,`+ - `module TINYTEXT,`+ - `shortdesc TINYTEXT,`+ - `fulldesc TEXT,`+ - `basistype TINYTEXT,`+ - `requiredmsg TINYTEXT,`+ - `usercontrol BOOLEAN,`+ - `requiredflag BOOLEAN,`+ - `creationtime int);`, + `CREATE TABLE IF NOT EXISTS legalbasis (` + + `brief TINYTEXT,` + + `status TINYTEXT,` + + `module TINYTEXT,` + + `shortdesc TINYTEXT,` + + `fulldesc TEXT,` + + `basistype TINYTEXT,` + + `requiredmsg TINYTEXT,` + + `usercontrol BOOLEAN,` + + `requiredflag BOOLEAN,` + + `creationtime int);`, `CREATE INDEX legalbasis_brief ON legalbasis (brief(36));`} return dbobj.execQueries(queries) } func (dbobj MySQLDB) initAgreements() error { queries := []string{ - `CREATE TABLE IF NOT EXISTS agreements (`+ - `who TINYTEXT,`+ - `mode TINYTEXT,`+ - `token TINYTEXT,`+ - `brief TINYTEXT,`+ - `status TINYTEXT,`+ - `referencecode TINYTEXT,`+ - `lastmodifiedby TINYTEXT,`+ - `agreementmethod TINYTEXT,`+ - `creationtime int,`+ - `starttime int,`+ - `endtime int,`+ - "`when` int);", + `CREATE TABLE IF NOT EXISTS agreements (` + + `who TINYTEXT,` + + `mode TINYTEXT,` + + `token TINYTEXT,` + + `brief TINYTEXT,` + + `status TINYTEXT,` + + `referencecode TINYTEXT,` + + `lastmodifiedby TINYTEXT,` + + `agreementmethod TINYTEXT,` + + `creationtime int,` + + `starttime int,` + + `endtime int,` + + "`when` int);", `CREATE INDEX agreements_token ON agreements (token(36));`, `CREATE INDEX agreements_brief ON agreements (brief(36));`} return dbobj.execQueries(queries) @@ -1071,15 +1070,14 @@ func (dbobj MySQLDB) initAgreements() error { func (dbobj MySQLDB) initSessions() error { queries := []string{ - `CREATE TABLE IF NOT EXISTS sessions (`+ - `token TINYTEXT,`+ - `session TINYTEXT,`+ - "`key` TINYTEXT,"+ - `data TEXT,`+ - `endtime int,`+ - "`when` int);", + `CREATE TABLE IF NOT EXISTS sessions (` + + `token TINYTEXT,` + + `session TINYTEXT,` + + "`key` TINYTEXT," + + `data TEXT,` + + `endtime int,` + + "`when` int);", `CREATE INDEX sessions_a_token ON sessions (token(36));`, `CREATE INDEX sessions_a_session ON sessions (session(36));`} return dbobj.execQueries(queries) } - diff --git a/src/storage/sqlite-storage.go b/src/storage/sqlite-storage.go index 07e8257..11abe63 100644 --- a/src/storage/sqlite-storage.go +++ b/src/storage/sqlite-storage.go @@ -47,14 +47,14 @@ func (dbobj SQLiteDB) DBExists(filepath *string) bool { } err = db.Ping() if err != nil { - return false - } + return false + } dbobj2 := SQLiteDB{db} - record, err := dbobj2.GetRecord2(TblName.Xtokens, "token", "", "type", "root") - if record == nil || err != nil { + record, err := dbobj2.GetRecord2(TblName.Xtokens, "token", "", "type", "root") + if record == nil || err != nil { dbobj2.CloseDB() - return false - } + return false + } dbobj2.CloseDB() return true } @@ -162,9 +162,8 @@ func (dbobj *SQLiteDB) InitDB(filepath *string) error { return nil } - func (dbobj SQLiteDB) Ping() error { - return dbobj.db.Ping() + return dbobj.db.Ping() } // CloseDB function closes the open database @@ -299,8 +298,8 @@ func (dbobj SQLiteDB) decodeForUpdate(bdoc *bson.M, bdel *bson.M) (string, []int } func (dbobj SQLiteDB) Exec(q string) error { - _, err := dbobj.db.Exec(q) - return err + _, err := dbobj.db.Exec(q) + return err } // CreateRecordInTable creates new record @@ -363,26 +362,26 @@ func (dbobj SQLiteDB) CountRecords0(t Tbl) (int64, error) { // CountRecords returns number of records that match filter func (dbobj SQLiteDB) CountRecords(t Tbl, keyName string, keyValue string) (int64, error) { - tbl := GetTable(t) - q := "select count(*) from " + tbl + " WHERE " + dbobj.escapeName(keyName) + "=$1" - //fmt.Printf("q: %s\n", q) + tbl := GetTable(t) + q := "select count(*) from " + tbl + " WHERE " + dbobj.escapeName(keyName) + "=$1" + //fmt.Printf("q: %s\n", q) - tx, err := dbobj.db.Begin() - if err != nil { - return 0, err - } - defer tx.Rollback() - row := tx.QueryRow(q, keyValue) - // Columns - var count int - err = row.Scan(&count) - if err != nil { - return 0, err - } - if err = tx.Commit(); err != nil { - return 0, err - } - return int64(count), nil + tx, err := dbobj.db.Begin() + if err != nil { + return 0, err + } + defer tx.Rollback() + row := tx.QueryRow(q, keyValue) + // Columns + var count int + err = row.Scan(&count) + if err != nil { + return 0, err + } + if err = tx.Commit(); err != nil { + return 0, err + } + return int64(count), nil } // UpdateRecord updates database record @@ -720,7 +719,7 @@ func (dbobj SQLiteDB) GetExpiring(t Tbl, keyName string, keyValue string) ([]bso table := GetTable(t) now := int32(time.Now().Unix()) q := fmt.Sprintf("select * from %s WHERE endtime>0 AND endtime<%d AND %s=$1", - table, now, dbobj.escapeName(keyName)) + table, now, dbobj.escapeName(keyName)) //fmt.Printf("q: %s\n", q) values := make([]interface{}, 0) values = append(values, keyValue) @@ -1065,4 +1064,3 @@ func (dbobj SQLiteDB) initSessions() error { `CREATE INDEX sessions_session ON sessions (session);`} return dbobj.execQueries(queries) } - diff --git a/src/storage/storage.go b/src/storage/storage.go index 2632634..916023f 100644 --- a/src/storage/storage.go +++ b/src/storage/storage.go @@ -1,9 +1,9 @@ package storage import ( - "net/http" - "os" - "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson" + "net/http" + "os" ) // Tbl is used to store table id @@ -11,131 +11,130 @@ type Tbl int // listTbls used to store list of tables type listTbls struct { - Users Tbl - Audit Tbl - Xtokens Tbl - Sessions Tbl - Requests Tbl - Legalbasis Tbl - Agreements Tbl - Sharedrecords Tbl - Processingactivities Tbl + Users Tbl + Audit Tbl + Xtokens Tbl + Sessions Tbl + Requests Tbl + Legalbasis Tbl + Agreements Tbl + Sharedrecords Tbl + Processingactivities Tbl } // TblName is enum of tables var TblName = &listTbls{ - Users: 0, - Audit: 1, - Xtokens: 2, - Sessions: 3, - Requests: 4, - Legalbasis: 5, - Agreements: 6, - Sharedrecords: 7, - Processingactivities: 8, + Users: 0, + Audit: 1, + Xtokens: 2, + Sessions: 3, + Requests: 4, + Legalbasis: 5, + Agreements: 6, + Sharedrecords: 7, + Processingactivities: 8, } func GetTable(t Tbl) string { - switch t { - case TblName.Users: - return "users" - case TblName.Audit: - return "audit" - case TblName.Xtokens: - return "xtokens" - case TblName.Sessions: - return "sessions" - case TblName.Requests: - return "requests" - case TblName.Legalbasis: - return "legalbasis" - case TblName.Agreements: - return "agreements" - case TblName.Sharedrecords: - return "sharedrecords" - case TblName.Processingactivities: - return "processingactivities" - } - return "users" + switch t { + case TblName.Users: + return "users" + case TblName.Audit: + return "audit" + case TblName.Xtokens: + return "xtokens" + case TblName.Sessions: + return "sessions" + case TblName.Requests: + return "requests" + case TblName.Legalbasis: + return "legalbasis" + case TblName.Agreements: + return "agreements" + case TblName.Sharedrecords: + return "sharedrecords" + case TblName.Processingactivities: + return "processingactivities" + } + return "users" } type BackendDB interface { - DBExists(*string) bool - OpenDB(*string) error - InitDB(*string) (error) - CreateTestDB() string - Ping() error - CloseDB() - BackupDB(http.ResponseWriter) - InitUserApps() error - IndexNewApp(string) - Exec(string) error - CreateRecordInTable(string, interface{}) (int, error) - CreateRecord(Tbl, interface{}) (int, error) - CountRecords0(Tbl) (int64, error) - CountRecords(Tbl, string, string) (int64, error) - UpdateRecord(Tbl, string, string, *bson.M) (int64, error) - UpdateRecordInTable(string, string, string, *bson.M) (int64, error) - UpdateRecord2(Tbl, string, string, string, string, *bson.M, *bson.M) (int64, error) - UpdateRecordInTable2(string, string, string, string, string, *bson.M, *bson.M) (int64, error) - LookupRecord(Tbl, bson.M) (bson.M, error) - GetRecord(Tbl, string, string) (bson.M, error) - GetRecordInTable(string, string, string) (bson.M, error) - GetRecord2(Tbl, string, string, string, string) (bson.M, error) - DeleteRecord(Tbl, string, string) (int64, error) - DeleteRecordInTable(string, string, string) (int64, error) - DeleteRecord2(Tbl, string, string, string, string) (int64, error) - DeleteExpired0(Tbl, int32) (int64, error) - DeleteExpired(Tbl, string, string) (int64, error) - CleanupRecord(Tbl, string, string, interface{}) (int64, error) - GetExpiring(Tbl, string, string) ([]bson.M, error) - GetUniqueList(Tbl, string) ([]bson.M, error) - GetList0(Tbl, int32, int32, string) ([]bson.M, error) - GetList(Tbl, string, string, int32, int32, string) ([]bson.M, error) - GetAllTables() ([]string, error) - ValidateNewApp(appName string) bool + DBExists(*string) bool + OpenDB(*string) error + InitDB(*string) error + CreateTestDB() string + Ping() error + CloseDB() + BackupDB(http.ResponseWriter) + InitUserApps() error + IndexNewApp(string) + Exec(string) error + CreateRecordInTable(string, interface{}) (int, error) + CreateRecord(Tbl, interface{}) (int, error) + CountRecords0(Tbl) (int64, error) + CountRecords(Tbl, string, string) (int64, error) + UpdateRecord(Tbl, string, string, *bson.M) (int64, error) + UpdateRecordInTable(string, string, string, *bson.M) (int64, error) + UpdateRecord2(Tbl, string, string, string, string, *bson.M, *bson.M) (int64, error) + UpdateRecordInTable2(string, string, string, string, string, *bson.M, *bson.M) (int64, error) + LookupRecord(Tbl, bson.M) (bson.M, error) + GetRecord(Tbl, string, string) (bson.M, error) + GetRecordInTable(string, string, string) (bson.M, error) + GetRecord2(Tbl, string, string, string, string) (bson.M, error) + DeleteRecord(Tbl, string, string) (int64, error) + DeleteRecordInTable(string, string, string) (int64, error) + DeleteRecord2(Tbl, string, string, string, string) (int64, error) + DeleteExpired0(Tbl, int32) (int64, error) + DeleteExpired(Tbl, string, string) (int64, error) + CleanupRecord(Tbl, string, string, interface{}) (int64, error) + GetExpiring(Tbl, string, string) ([]bson.M, error) + GetUniqueList(Tbl, string) ([]bson.M, error) + GetList0(Tbl, int32, int32, string) ([]bson.M, error) + GetList(Tbl, string, string, int32, int32, string) ([]bson.M, error) + GetAllTables() ([]string, error) + ValidateNewApp(appName string) bool } func getDBObj() BackendDB { - host := os.Getenv("MYSQL_HOST") - var db BackendDB - if len(host) > 0 { - db = &MySQLDB{} - } else { - db = &SQLiteDB{} - } - return db + host := os.Getenv("MYSQL_HOST") + var db BackendDB + if len(host) > 0 { + db = &MySQLDB{} + } else { + db = &SQLiteDB{} + } + return db } // InitDB function creates tables and indexes func InitDB(dbname *string) (BackendDB, error) { - db := getDBObj() - err := db.InitDB(dbname) - return db, err + db := getDBObj() + err := db.InitDB(dbname) + return db, err } func OpenDB(dbname *string) (BackendDB, error) { - db := getDBObj() - err := db.OpenDB(dbname) - return db, err + db := getDBObj() + err := db.OpenDB(dbname) + return db, err } func DBExists(filepath *string) bool { - db := getDBObj() - return db.DBExists(filepath) + db := getDBObj() + return db.DBExists(filepath) } func CreateTestDB() string { - db := getDBObj() - return db.CreateTestDB() + db := getDBObj() + return db.CreateTestDB() } func contains(slice []string, item string) bool { - set := make(map[string]struct{}, len(slice)) - for _, s := range slice { - set[s] = struct{}{} - } - _, ok := set[item] - return ok + set := make(map[string]struct{}, len(slice)) + for _, s := range slice { + set[s] = struct{}{} + } + _, ok := set[item] + return ok } - diff --git a/src/users_api.go b/src/users_api.go index 1758380..047d287 100644 --- a/src/users_api.go +++ b/src/users_api.go @@ -206,11 +206,11 @@ func (e mainEnv) userChange(w http.ResponseWriter, r *http.Request, ps httproute } adminRecordChanged := false if UserSchemaEnabled() { - adminRecordChanged, err = e.db.validateUserRecordChange(userJSON, parsedData.jsonData, userTOKEN, authResult) - if err != nil { - returnError(w, r, "schema validation error: " + err.Error(), 405, err, event) - return - } + adminRecordChanged, err = e.db.validateUserRecordChange(userJSON, parsedData.jsonData, userTOKEN, authResult) + if err != nil { + returnError(w, r, "schema validation error: "+err.Error(), 405, err, event) + return + } } if authResult == "login" { event.Title = "user change-profile request" @@ -312,13 +312,13 @@ func (e mainEnv) userPrelogin(w http.ResponseWriter, r *http.Request, ps httprou event := audit("user prelogin by "+mode, address, mode, address) defer func() { event.submit(e.db) }() - code0, err := decryptCaptcha(captcha) - if err != nil || code0 != code { + code0, err := decryptCaptcha(captcha) + if err != nil || code0 != code { w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(200) fmt.Fprintf(w, `{"status":"error","result":"captcha-error"}`) return - } + } if mode != "phone" && mode != "email" { returnError(w, r, "bad mode", 405, nil, event) return diff --git a/src/users_db.go b/src/users_db.go index caee800..c6d0d60 100644 --- a/src/users_db.go +++ b/src/users_db.go @@ -120,26 +120,26 @@ func (dbobj dbcon) validateUserRecordChange(oldUserJSON []byte, jsonDataPatch [] } func (dbobj dbcon) updateUserRecord(jsonDataPatch []byte, userTOKEN string, userBSON bson.M, event *auditEvent, conf Config) ([]byte, []byte, bool, error) { - oldJSON, newJSON, lookupErr, err := dbobj.updateUserRecordDo(jsonDataPatch, userTOKEN, userBSON, event, conf) - if lookupErr == true { - return oldJSON, newJSON, lookupErr, err - } - if err == nil { - return oldJSON, newJSON, lookupErr, nil - } - // load one more time user BSON structure - userBSON2, err := dbobj.lookupUserRecord(userTOKEN) - if userBSON2 == nil || err != nil { - return nil, nil, true, err - } - oldJSON, newJSON, lookupErr, err = dbobj.updateUserRecordDo(jsonDataPatch, userTOKEN, userBSON2, event, conf) - if lookupErr == true { - return oldJSON, newJSON, lookupErr, err - } - if err == nil { - return oldJSON, newJSON, lookupErr, nil - } - return nil, nil, false, err + oldJSON, newJSON, lookupErr, err := dbobj.updateUserRecordDo(jsonDataPatch, userTOKEN, userBSON, event, conf) + if lookupErr == true { + return oldJSON, newJSON, lookupErr, err + } + if err == nil { + return oldJSON, newJSON, lookupErr, nil + } + // load one more time user BSON structure + userBSON2, err := dbobj.lookupUserRecord(userTOKEN) + if userBSON2 == nil || err != nil { + return nil, nil, true, err + } + oldJSON, newJSON, lookupErr, err = dbobj.updateUserRecordDo(jsonDataPatch, userTOKEN, userBSON2, event, conf) + if lookupErr == true { + return oldJSON, newJSON, lookupErr, err + } + if err == nil { + return oldJSON, newJSON, lookupErr, nil + } + return nil, nil, false, err } func (dbobj dbcon) updateUserRecordDo(jsonDataPatch []byte, userTOKEN string, oldUserBson bson.M, event *auditEvent, conf Config) ([]byte, []byte, bool, error) { @@ -160,14 +160,14 @@ func (dbobj dbcon) updateUserRecordDo(jsonDataPatch []byte, userTOKEN string, ol return nil, nil, false, err } var raw2 map[string]interface{} - err = json.Unmarshal(decrypted, &raw2) - if err != nil { + err = json.Unmarshal(decrypted, &raw2) + if err != nil { return nil, nil, false, err - } + } oldEmail := "" - if _, ok := raw2["email"]; ok { - oldEmail = normalizeEmail(raw2["email"].(string)) - } + if _, ok := raw2["email"]; ok { + oldEmail = normalizeEmail(raw2["email"].(string)) + } // merge //fmt.Printf("old json: %s\n", decrypted) //fmt.Printf("json patch: %s\n", jsonDataPatch) @@ -214,8 +214,8 @@ func (dbobj dbcon) updateUserRecordDo(jsonDataPatch []byte, userTOKEN string, ol } else { //fmt.Println("index value changed!") } - //} else { - // fmt.Println("old or new is empty") + //} else { + // fmt.Println("old or new is empty") } } if len(newIdxFinalValue) > 0 && actionCode == 1 { @@ -289,34 +289,34 @@ func (dbobj dbcon) lookupUserRecordByIndex(indexName string, indexValue string, } func (dbobj dbcon) getUserJson(userTOKEN string) ([]byte, error) { - userBson, err := dbobj.lookupUserRecord(userTOKEN) - if userBson == nil || err != nil { - // not found - return nil, err - } - if _, ok := userBson["key"]; !ok { - return []byte("{}"), nil - } - userKey := userBson["key"].(string) - recordKey, err := base64.StdEncoding.DecodeString(userKey) - if err != nil { - return nil, err - } - var decrypted []byte - if _, ok := userBson["data"]; ok { - encData0 := userBson["data"].(string) - if len(encData0) > 0 { - encData, err := base64.StdEncoding.DecodeString(encData0) - if err != nil { - return nil, err - } - decrypted, err = decrypt(dbobj.masterKey, recordKey, encData) - if err != nil { - return nil, err - } - } - } - return decrypted, err + userBson, err := dbobj.lookupUserRecord(userTOKEN) + if userBson == nil || err != nil { + // not found + return nil, err + } + if _, ok := userBson["key"]; !ok { + return []byte("{}"), nil + } + userKey := userBson["key"].(string) + recordKey, err := base64.StdEncoding.DecodeString(userKey) + if err != nil { + return nil, err + } + var decrypted []byte + if _, ok := userBson["data"]; ok { + encData0 := userBson["data"].(string) + if len(encData0) > 0 { + encData, err := base64.StdEncoding.DecodeString(encData0) + if err != nil { + return nil, err + } + decrypted, err = decrypt(dbobj.masterKey, recordKey, encData) + if err != nil { + return nil, err + } + } + } + return decrypted, err } func (dbobj dbcon) getUser(userTOKEN string) ([]byte, bson.M, error) { @@ -351,31 +351,31 @@ func (dbobj dbcon) getUser(userTOKEN string) ([]byte, bson.M, error) { } func (dbobj dbcon) getUserJsonByIndex(indexValue string, indexName string, conf Config) ([]byte, string, error) { - userBson, err := dbobj.lookupUserRecordByIndex(indexName, indexValue, conf) - if userBson == nil || err != nil { - return nil, "", err - } - // decrypt record - userKey := userBson["key"].(string) - recordKey, err := base64.StdEncoding.DecodeString(userKey) - if err != nil { - return nil, "", err - } - var decrypted []byte - if _, ok := userBson["data"]; ok { - encData0 := userBson["data"].(string) - if len(encData0) > 0 { - encData, err := base64.StdEncoding.DecodeString(encData0) - if err != nil { - return nil, "", err - } - decrypted, err = decrypt(dbobj.masterKey, recordKey, encData) - if err != nil { - return nil, "", err - } - } - } - return decrypted, userBson["token"].(string), err + userBson, err := dbobj.lookupUserRecordByIndex(indexName, indexValue, conf) + if userBson == nil || err != nil { + return nil, "", err + } + // decrypt record + userKey := userBson["key"].(string) + recordKey, err := base64.StdEncoding.DecodeString(userKey) + if err != nil { + return nil, "", err + } + var decrypted []byte + if _, ok := userBson["data"]; ok { + encData0 := userBson["data"].(string) + if len(encData0) > 0 { + encData, err := base64.StdEncoding.DecodeString(encData0) + if err != nil { + return nil, "", err + } + decrypted, err = decrypt(dbobj.masterKey, recordKey, encData) + if err != nil { + return nil, "", err + } + } + } + return decrypted, userBson["token"].(string), err } func (dbobj dbcon) getUserByIndex(indexValue string, indexName string, conf Config) ([]byte, string, bson.M, error) { diff --git a/src/utils.go b/src/utils.go index 1abe510..b420f4b 100644 --- a/src/utils.go +++ b/src/utils.go @@ -61,20 +61,20 @@ func getStringValue(r interface{}) string { return "" } switch r.(type) { - case string: - return strings.TrimSpace(r.(string)) - case []uint8: - return strings.TrimSpace(string(r.([]uint8))) + case string: + return strings.TrimSpace(r.(string)) + case []uint8: + return strings.TrimSpace(string(r.([]uint8))) } return "" } func getIntValue(r interface{}) int { switch r.(type) { - case int: - return r.(int) - case int32: - return int(r.(int32)) + case int: + return r.(int) + case int32: + return int(r.(int32)) } return 0 } @@ -469,11 +469,11 @@ func getJSONPostData(r *http.Request) (map[string]interface{}, error) { return nil, err } } else if strings.HasPrefix(cType, "application/xml") { - err = json.Unmarshal(body, &records) - if err != nil { - log.Printf("Error in xml/json decode %s", err) - return nil, err - } + err = json.Unmarshal(body, &records) + if err != nil { + log.Printf("Error in xml/json decode %s", err) + return nil, err + } } else { log.Printf("Ignore wrong content type: %s", cType) maxStrLen := 200 @@ -488,19 +488,19 @@ func getJSONPostData(r *http.Request) (map[string]interface{}, error) { func getIndexString(val interface{}) string { switch val.(type) { - case nil: - return "" - case string: - return strings.TrimSpace(val.(string)) - case []uint8: - return strings.TrimSpace(string(val.([]uint8))) - case int: - return strconv.Itoa(val.(int)) - case int64: - return fmt.Sprintf("%v", val.(int64)) - case float64: - return strconv.Itoa(int(val.(float64))) - } + case nil: + return "" + case string: + return strings.TrimSpace(val.(string)) + case []uint8: + return strings.TrimSpace(string(val.([]uint8))) + case int: + return strconv.Itoa(val.(int)) + case int64: + return fmt.Sprintf("%v", val.(int64)) + case float64: + return strconv.Itoa(int(val.(float64))) + } return "" } diff --git a/src/xtokens_test.go b/src/xtokens_test.go index 88cf927..8eea1e5 100644 --- a/src/xtokens_test.go +++ b/src/xtokens_test.go @@ -56,11 +56,11 @@ func helpCancelUserRequest(rtoken string) (map[string]interface{}, error) { func TestUserLoginDelete(t *testing.T) { raw, err := helpCreateLBasis("contract1", `{"basistype":"contract","usercontrol":false}`) if err != nil { - t.Fatalf("error: %s", err) - } - if _, ok := raw["status"]; !ok || raw["status"].(string) != "ok" { - t.Fatalf("Failed to create lbasis") - } + t.Fatalf("error: %s", err) + } + if _, ok := raw["status"]; !ok || raw["status"].(string) != "ok" { + t.Fatalf("Failed to create lbasis") + } email := "test@securitybunker.io" jsonData := `{"email":"test@securitybunker.io","phone":"22346622","fname":"Yuli","lname":"Str","tz":"323xxxxx","password":"123456","address":"Y-d habanim 7","city":"Petah-Tiqva","btest":true,"numtest":123,"testnul":null}` raw, err = helpCreateUser(jsonData) @@ -185,10 +185,10 @@ func TestUserLoginDelete(t *testing.T) { } } helpApproveUserRequest(rtoken0) - raw, _ = helpGetUserRequests() - if raw["total"].(float64) != 0 { - t.Fatalf("Wrong number of user requests for admin to approve/reject/s\n") - } + raw, _ = helpGetUserRequests() + if raw["total"].(float64) != 0 { + t.Fatalf("Wrong number of user requests for admin to approve/reject/s\n") + } // user should be deleted now raw10, _ := helpGetUserAppList(userTOKEN) if len(raw10["apps"].([]interface{})) != 0 {