From 9b1b55684b34e35aeab59723ab6ef69374ed7cd1 Mon Sep 17 00:00:00 2001 From: yuli Date: Thu, 27 Feb 2025 23:36:54 +0200 Subject: [PATCH] review record expiration code --- src/expiration_api.go | 5 ++--- src/sessions_api.go | 8 ++++---- src/sessions_db.go | 15 +++------------ src/sharedrecords_api.go | 16 +--------------- src/sharedrecords_db.go | 14 ++------------ src/utils/utils.go | 14 +++++--------- 6 files changed, 17 insertions(+), 55 deletions(-) diff --git a/src/expiration_api.go b/src/expiration_api.go index e84b024..b8d39db 100644 --- a/src/expiration_api.go +++ b/src/expiration_api.go @@ -149,9 +149,8 @@ func (e mainEnv) expStart(w http.ResponseWriter, r *http.Request, ps httprouter. utils.ReturnError(w, r, "failed to decode request body", 405, err, event) return } - expirationStr := utils.GetStringValue(postData["expiration"]) - expiration := utils.SetExpiration(e.conf.Policy.MaxUserRetentionPeriod, expirationStr) - endtime, _ := utils.ParseExpiration(expiration) + endtime := utils.SetExpiration(e.conf.Policy.MaxUserRetentionPeriod, postData["expiration"]) + // fmt.Printf("Set exp time: %d\n", endtime) status := utils.GetStringValue(postData["status"]) if len(status) == 0 { status = "wait" diff --git a/src/sessions_api.go b/src/sessions_api.go index 2863fa0..23fcfb5 100644 --- a/src/sessions_api.go +++ b/src/sessions_api.go @@ -38,9 +38,9 @@ func (e mainEnv) sessionCreate(w http.ResponseWriter, r *http.Request, ps httpro utils.ReturnError(w, r, "empty body", 405, nil, event) return } - expirationStr := utils.GetStringValue(postData["expiration"]) - expiration := utils.SetExpiration(e.conf.Policy.MaxSessionRetentionPeriod, expirationStr) - log.Printf("Record expiration: %s", expiration) + expiration := utils.SetExpiration(e.conf.Policy.MaxSessionRetentionPeriod, postData["expiration"]) + // now := int32(time.Now().Unix()) + // log.Printf("Record expiration: %d now %d", expiration, now) userToken := utils.GetStringValue(postData["token"]) userLogin := utils.GetStringValue(postData["login"]) userEmail := utils.GetStringValue(postData["email"]) @@ -123,7 +123,7 @@ func (e mainEnv) sessionNewOld(w http.ResponseWriter, r *http.Request, ps httpro } expirationStr := utils.GetStringValue(postData["expiration"]) expiration := utils.SetExpiration(e.conf.Policy.MaxSessionRetentionPeriod, expirationStr) - log.Printf("Record expiration: %s", expiration) + log.Printf("Record expiration: %d", expiration) jsonData, err := json.Marshal(postData) if err != nil { utils.ReturnError(w, r, "internal error", 405, err, event) diff --git a/src/sessions_db.go b/src/sessions_db.go index 4d57d98..6888616 100644 --- a/src/sessions_db.go +++ b/src/sessions_db.go @@ -16,17 +16,7 @@ type sessionEvent struct { Data string `json:"data"` } -func (dbobj dbcon) createSessionRecord(sessionUUID string, userTOKEN string, expiration string, data []byte) (string, error) { - var endtime int32 - var err error - now := int32(time.Now().Unix()) - if len(expiration) > 0 { - endtime, err = utils.ParseExpiration(expiration) - if err != nil { - return "", err - } - //log.Printf("expiration set to: %d, now: %d", endtime, now) - } +func (dbobj dbcon) createSessionRecord(sessionUUID string, userTOKEN string, endtime int32, data []byte) (string, error) { recordKey, err := utils.GenerateRecordKey() if err != nil { return "", err @@ -36,6 +26,7 @@ func (dbobj dbcon) createSessionRecord(sessionUUID string, userTOKEN string, exp return "", err } encodedStr := base64.StdEncoding.EncodeToString(encoded) + now := int32(time.Now().Unix()) bdoc := bson.M{} bdoc["token"] = userTOKEN bdoc["session"] = sessionUUID @@ -65,7 +56,7 @@ func (dbobj dbcon) getSession(sessionUUID string) (int32, []byte, string, error) } // check expiration now := int32(time.Now().Unix()) - //log.Printf("getSession checking now: %d exp %d", now, record["endtime"].(int32)) + // fmt.Printf("getSession checking now: %d exp %d\n", now, record["endtime"].(int32)) if now > record["endtime"].(int32) { return 0, nil, "", errors.New("session expired") } diff --git a/src/sharedrecords_api.go b/src/sharedrecords_api.go index 3b64a74..f97fe09 100644 --- a/src/sharedrecords_api.go +++ b/src/sharedrecords_api.go @@ -5,7 +5,6 @@ import ( "fmt" "log" "net/http" - "reflect" "strings" "github.com/julienschmidt/httprouter" @@ -33,26 +32,13 @@ func (e mainEnv) sharedRecordCreate(w http.ResponseWriter, r *http.Request, ps h session := utils.GetStringValue(postData["session"]) partner := utils.GetStringValue(postData["partner"]) appName := utils.GetStringValue(postData["app"]) - expiration := e.conf.Policy.MaxShareableRecordRetentionPeriod - if len(appName) > 0 { appName = strings.ToLower(appName) if utils.CheckValidApp(appName) == false { utils.ReturnError(w, r, "unknown app name", 405, nil, event) } } - if value, ok := postData["expiration"]; ok { - if reflect.TypeOf(value) == reflect.TypeOf("string") { - expiration = utils.SetExpiration(e.conf.Policy.MaxShareableRecordRetentionPeriod, value.(string)) - } else { - utils.ReturnError(w, r, "failed to parse expiration field", 405, err, event) - return - } - } - if len(expiration) == 0 { - // using default expiration time for record - expiration = "1m" - } + expiration := utils.SetExpiration(e.conf.Policy.MaxShareableRecordRetentionPeriod, postData["expiration"]) recordUUID, err := e.db.saveSharedRecord(userTOKEN, fields, expiration, session, appName, partner, e.conf) if err != nil { utils.ReturnError(w, r, err.Error(), 405, err, event) diff --git a/src/sharedrecords_db.go b/src/sharedrecords_db.go index d3b8c54..6e76d5d 100644 --- a/src/sharedrecords_db.go +++ b/src/sharedrecords_db.go @@ -2,7 +2,6 @@ package main import ( "errors" - "log" "strings" "time" @@ -12,25 +11,16 @@ import ( "go.mongodb.org/mongo-driver/bson" ) -func (dbobj dbcon) saveSharedRecord(userTOKEN string, fields string, expiration string, session string, appName string, partner string, conf Config) (string, error) { +func (dbobj dbcon) saveSharedRecord(userTOKEN string, fields string, endtime int32, session string, appName string, partner string, conf Config) (string, error) { if utils.CheckValidUUID(userTOKEN) == false { return "", errors.New("bad uuid") } - if len(expiration) == 0 { - return "", errors.New("failed to parse expiration") - } if len(appName) > 0 { apps, _ := dbobj.listAllApps(conf) if strings.Contains(string(apps), appName) == false { return "", errors.New("app not found") } } - - log.Printf("Expiration is : %s\n", expiration) - start, err := utils.ParseExpiration(expiration) - if err != nil { - return "", err - } recordUUID, err := uuid.GenerateUUID() if err != nil { return "", err @@ -40,7 +30,7 @@ func (dbobj dbcon) saveSharedRecord(userTOKEN string, fields string, expiration bdoc["token"] = userTOKEN bdoc["record"] = recordUUID bdoc["when"] = now - bdoc["endtime"] = start + bdoc["endtime"] = endtime if len(fields) > 0 { bdoc["fields"] = fields } diff --git a/src/utils/utils.go b/src/utils/utils.go index bc633b5..a09f480 100644 --- a/src/utils/utils.go +++ b/src/utils/utils.go @@ -170,7 +170,6 @@ func GetExpirationNum(val interface{}) int32 { case string: expiration := val.(string) match := regexExpiration.FindStringSubmatch(expiration) - log.Printf("match: %v", match) // expiration format: 10d, 10h, 10m, 10s if len(match) == 2 { num = Atoi(match[1]) @@ -314,10 +313,7 @@ func Atoi(s string) int32 { return int32(n) } -func SetExpiration(maxExpiration string, userExpiration string) string { - if len(userExpiration) == 0 { - return maxExpiration - } +func SetExpiration(maxExpiration interface{}, userExpiration interface{}) int32 { userExpirationNum, _ := ParseExpiration(userExpiration) maxExpirationNum, _ := ParseExpiration(maxExpiration) if maxExpirationNum == 0 { @@ -325,12 +321,12 @@ func SetExpiration(maxExpiration string, userExpiration string) string { maxExpirationNum, _ = ParseExpiration(maxExpiration) } if userExpirationNum == 0 { - return maxExpiration + return maxExpirationNum } if userExpirationNum > maxExpirationNum { - return maxExpiration + return maxExpirationNum } - return userExpiration + return userExpirationNum } func ParseExpiration0(expiration string) (int32, error) { @@ -357,7 +353,7 @@ func ParseExpiration0(expiration string) (int32, error) { } func ParseExpiration(expiration interface{}) (int32, error) { - now := int32(time.Now().Unix()) + 10 + now := int32(time.Now().Unix()) result := GetExpirationNum(expiration) if result == 0 { return 0, nil