mirror of
https://github.com/ccfos/nightingale.git
synced 2026-03-03 01:49:38 +00:00
1239 lines
30 KiB
Go
1239 lines
30 KiB
Go
package models
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"os"
|
||
"reflect"
|
||
"regexp"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
"unicode"
|
||
|
||
"github.com/ccfos/nightingale/v6/pkg/ctx"
|
||
"github.com/ccfos/nightingale/v6/pkg/logx"
|
||
"github.com/ccfos/nightingale/v6/pkg/ormx"
|
||
"github.com/ccfos/nightingale/v6/pkg/poster"
|
||
"github.com/ccfos/nightingale/v6/pkg/secu"
|
||
"github.com/ccfos/nightingale/v6/storage"
|
||
"github.com/redis/go-redis/v9"
|
||
|
||
"github.com/pkg/errors"
|
||
"github.com/tidwall/gjson"
|
||
"github.com/toolkits/pkg/logger"
|
||
"github.com/toolkits/pkg/slice"
|
||
"github.com/toolkits/pkg/str"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// Notification channel identifiers.
const (
	Dingtalk          = "dingtalk"
	Wecom             = "wecom"
	Feishu            = "feishu"
	FeishuCard        = "feishucard"
	Discord           = "discord"
	MattermostWebhook = "mattermostwebhook"
	MattermostBot     = "mattermostbot"
	SlackWebhook      = "slackwebhook"
	SlackBot          = "slackbot"
	Mm                = "mm"
	Telegram          = "telegram"
	Email             = "email"
	EmailSubject      = "mailsubject"
	Lark              = "lark"
	LarkCard          = "larkcard"
	Phone             = "phone"
	Jira              = "jira"
	JSMAlert          = "jsm_alert"
)

// Keys under which per-channel tokens/URLs are looked up in User.Contacts.
const (
	DingtalkKey  = "dingtalk_robot_token"
	WecomKey     = "wecom_robot_token"
	FeishuKey    = "feishu_robot_token"
	MmKey        = "mm_webhook_url"
	TelegramKey  = "telegram_robot_token"
	LarkKey      = "lark_robot_token"
	PagerDutyKey = "pagerduty_key"
)

// Service domains and special domain identifiers.
const (
	DingtalkDomain = "oapi.dingtalk.com"
	WecomDomain    = "qyapi.weixin.qq.com"
	FeishuDomain   = "open.feishu.cn"
	LarkDomain     = "open.larksuite.com"

	// FeishuCardDomain shares the Feishu host; the card variant is
	// distinguished only by the query parameter.
	FeishuCardDomain = "open.feishu.cn?card=1"
	LarkCardDomain   = "open.larksuite.com?card=1"
	TelegramDomain   = "api.telegram.org"
	IbexDomain       = "ibex"
	DefaultDomain    = "default"
)

var (
	// DefaultChannels is the default set of notification channels.
	DefaultChannels = []string{Dingtalk, Wecom, Feishu, Mm, Telegram, Email, FeishuCard, Lark, LarkCard}
	// DefaultContacts is the default set of contact keys.
	DefaultContacts = []string{DingtalkKey, WecomKey, FeishuKey, MmKey, TelegramKey, LarkKey}
)
|
||
|
||
// User is a user account persisted in the "users" table (see TableName).
// Fields tagged gorm:"-" exist only in memory and are not persisted.
type User struct {
	Id       int64  `json:"id" gorm:"primaryKey"`
	Username string `json:"username"`
	Nickname string `json:"nickname"`
	// Password is never serialized to JSON. Local users store the CryptoPass
	// output; SSO users without a local password hold the "******" placeholder.
	Password string `json:"-"`
	// Phone may hold an "enc:"-prefixed ciphertext when phone encryption is enabled.
	Phone    string `json:"phone"`
	Email    string `json:"email"`
	Portrait string `json:"portrait"`
	// Roles is the space-separated role list; this field is persisted to the database.
	Roles string `json:"-"`
	// RolesLst is the parsed form of Roles; this field is exchanged with the frontend.
	RolesLst []string `json:"roles" gorm:"-"`
	// TeamsLst is a convenience field for mapping teams; used by neither the frontend nor the database.
	TeamsLst []int64 `json:"-" gorm:"-"`
	// Contacts holds a JSON object whose content is a map[string]string structure.
	Contacts ormx.JSONObj `json:"contacts"`
	// Maintainer controls whether messages are sent to administrators: 0 = not send, 1 = send.
	Maintainer int    `json:"maintainer"`
	CreateAt   int64  `json:"create_at"`
	CreateBy   string `json:"create_by"`
	UpdateAt   int64  `json:"update_at"`
	UpdateBy   string `json:"update_by"`
	// Belong names the SSO source that manages the user; empty for local users.
	Belong string `json:"belong"`
	// Admin is a convenience flag for the frontend (derived from RolesLst).
	Admin          bool            `json:"admin" gorm:"-"`
	UserGroupsRes  []*UserGroupRes `json:"user_groups" gorm:"-"`
	BusiGroupsRes  []*BusiGroupRes `json:"busi_groups" gorm:"-"`
	LastActiveTime int64           `json:"last_active_time"`
}
|
||
|
||
// UserGroupRes is a lightweight (id, name) reference to a user group, used to
// populate User.UserGroupsRes (UserGets fills it from the user_group table).
type UserGroupRes struct {
	Id   int64  `json:"id"`
	Name string `json:"name"`
}
|
||
|
||
// BusiGroupRes is a lightweight (id, name) reference to a business group, used
// to populate User.BusiGroupsRes (UserGets fills it from the busi_group table).
type BusiGroupRes struct {
	Id   int64  `json:"id"`
	Name string `json:"name"`
}
|
||
|
||
func (u *User) TableName() string {
|
||
return "users"
|
||
}
|
||
|
||
func (u *User) String() string {
|
||
bs, err := u.Contacts.MarshalJSON()
|
||
if err != nil {
|
||
return err.Error()
|
||
}
|
||
|
||
return fmt.Sprintf("<id:%d username:%s nickname:%s email:%s phone:%s contacts:%s>", u.Id, u.Username, u.Nickname, u.Email, u.Phone, string(bs))
|
||
}
|
||
|
||
func (u *User) IsAdmin() bool {
|
||
for i := 0; i < len(u.RolesLst); i++ {
|
||
if u.RolesLst[i] == AdminRole {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// has group permission
|
||
func (u *User) CheckGroupPermission(ctx *ctx.Context, groupIds []int64) error {
|
||
if !u.IsAdmin() {
|
||
ids, err := MyGroupIdsMap(ctx, u.Id)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
for _, id := range groupIds {
|
||
if _, ok := ids[id]; ok {
|
||
return nil
|
||
}
|
||
}
|
||
|
||
return errors.New("forbidden")
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// stripInvisibleChars returns s with every non-printable rune removed.
// "Printable" follows unicode.IsPrint: letters, marks, numbers, punctuation,
// symbols and the ASCII space. Tab, newline and carriage return are kept
// explicitly.
//
// Zero-width spaces, byte-order marks and control characters are dropped.
// Invalid UTF-8 bytes come out as U+FFFD.
func stripInvisibleChars(s string) string {
	keep := func(r rune) bool {
		switch r {
		case ' ', '\t', '\n', '\r':
			return true
		default:
			return unicode.IsPrint(r)
		}
	}

	var b strings.Builder
	b.Grow(len(s))
	for _, r := range s {
		if keep(r) {
			b.WriteRune(r)
		}
	}
	return b.String()
}
|
||
|
||
// stripInvisibleCharsFromContacts removes invisible characters from Contacts JSON values
|
||
func stripInvisibleCharsFromContacts(contacts ormx.JSONObj) ormx.JSONObj {
|
||
if len(contacts) == 0 {
|
||
return contacts
|
||
}
|
||
|
||
var contactsMap map[string]string
|
||
if err := json.Unmarshal(contacts, &contactsMap); err != nil {
|
||
return contacts
|
||
}
|
||
|
||
for k, v := range contactsMap {
|
||
contactsMap[k] = stripInvisibleChars(v)
|
||
}
|
||
|
||
result, err := json.Marshal(contactsMap)
|
||
if err != nil {
|
||
return contacts
|
||
}
|
||
|
||
return ormx.JSONObj(result)
|
||
}
|
||
|
||
func (u *User) Verify() error {
|
||
u.Username = strings.TrimSpace(u.Username)
|
||
|
||
if u.Username == "" {
|
||
return errors.New("Username is blank")
|
||
}
|
||
|
||
if str.Dangerous(u.Username) {
|
||
return errors.New("Username has invalid characters")
|
||
}
|
||
|
||
if str.Dangerous(u.Nickname) {
|
||
return errors.New("Nickname has invalid characters")
|
||
}
|
||
|
||
if u.Phone != "" && !str.IsPhone(u.Phone) {
|
||
return errors.New("Phone invalid")
|
||
}
|
||
|
||
if u.Email != "" && !str.IsMail(u.Email) {
|
||
return errors.New("Email invalid")
|
||
}
|
||
|
||
// Strip invisible characters from Contacts values
|
||
u.Contacts = stripInvisibleCharsFromContacts(u.Contacts)
|
||
|
||
if u.Phone != "" {
|
||
return u.EncryptPhone()
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (u *User) UpdateSsoFields(sso string, nickname, phone, email string) []interface{} {
|
||
u.UpdateAt = time.Now().Unix()
|
||
|
||
if nickname != "" {
|
||
u.Nickname = nickname
|
||
}
|
||
if phone != "" {
|
||
u.Phone = phone
|
||
}
|
||
if email != "" {
|
||
u.Email = email
|
||
}
|
||
u.UpdateBy = sso
|
||
u.Belong = sso
|
||
|
||
updatedFields := []interface{}{"nickname", "phone", "email", "update_by", "belong"}
|
||
return updatedFields
|
||
}
|
||
|
||
func (u *User) UpdateSsoFieldsWithRoles(sso string, nickname, phone, email string, roles []string) []interface{} {
|
||
updatedFields := u.UpdateSsoFields(sso, nickname, phone, email)
|
||
|
||
if len(roles) == 0 {
|
||
return updatedFields
|
||
}
|
||
|
||
u.Roles = strings.Join(roles, " ")
|
||
u.RolesLst = roles
|
||
|
||
return append(updatedFields, "roles")
|
||
}
|
||
|
||
func (u *User) FullSsoFields(sso, username, nickname, phone, email string, defaultRoles []string) {
|
||
now := time.Now().Unix()
|
||
|
||
u.Username = username
|
||
u.Password = "******"
|
||
u.Nickname = nickname
|
||
u.Phone = phone
|
||
u.Email = email
|
||
u.Portrait = ""
|
||
u.Roles = strings.Join(defaultRoles, " ")
|
||
u.RolesLst = defaultRoles
|
||
u.Contacts = []byte("{}")
|
||
u.CreateAt = now
|
||
u.UpdateAt = now
|
||
u.CreateBy = sso
|
||
u.UpdateBy = sso
|
||
u.Belong = sso
|
||
}
|
||
|
||
func (u *User) FullSsoFieldsWithTeams(sso, username, nickname, phone, email string, defaultRoles []string,
|
||
teams []int64) {
|
||
u.FullSsoFields(sso, username, nickname, phone, email, defaultRoles)
|
||
u.TeamsLst = teams
|
||
}
|
||
|
||
func (u *User) Add(ctx *ctx.Context) error {
|
||
user, err := UserGetByUsername(ctx, u.Username)
|
||
if err != nil {
|
||
return errors.WithMessage(err, "failed to query user")
|
||
}
|
||
|
||
if user != nil {
|
||
return errors.New("Username already exists")
|
||
}
|
||
|
||
now := time.Now().Unix()
|
||
u.CreateAt = now
|
||
u.UpdateAt = now
|
||
return Insert(ctx, u)
|
||
}
|
||
|
||
func (u *User) Update(ctx *ctx.Context, selectField interface{}, selectFields ...interface{}) error {
|
||
if u.Belong == "" {
|
||
if err := u.Verify(); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return DB(ctx).Model(u).Select(selectField, selectFields...).Updates(u).Error
|
||
}
|
||
|
||
func (u *User) UpdateAllFields(ctx *ctx.Context) error {
|
||
if err := u.Verify(); err != nil {
|
||
return err
|
||
}
|
||
|
||
u.UpdateAt = time.Now().Unix()
|
||
return DB(ctx).Model(u).Select("*").Updates(u).Error
|
||
}
|
||
|
||
func (u *User) UpdatePassword(ctx *ctx.Context, password, updateBy string) error {
|
||
return DB(ctx).Model(u).Updates(map[string]interface{}{
|
||
"password": password,
|
||
"update_at": time.Now().Unix(),
|
||
"update_by": updateBy,
|
||
}).Error
|
||
}
|
||
|
||
func (u *User) AddToUserGroups(ctx *ctx.Context, userGroupIds []int64) error {
|
||
|
||
count := len(userGroupIds)
|
||
for i := 0; i < count; i++ {
|
||
err := UserGroupMemberAdd(ctx, userGroupIds[i], u.Id)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func UpdateUserLastActiveTime(ctx *ctx.Context, userId int64, lastActiveTime int64) error {
|
||
return DB(ctx).Model(&User{}).Where("id = ?", userId).Updates(map[string]interface{}{
|
||
"last_active_time": lastActiveTime,
|
||
"update_at": time.Now().Unix(),
|
||
}).Error
|
||
}
|
||
|
||
func (u *User) Del(ctx *ctx.Context) error {
|
||
return DB(ctx).Transaction(func(tx *gorm.DB) error {
|
||
if err := tx.Where("user_id=?", u.Id).Delete(&UserGroupMember{}).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
if err := tx.Where("id=?", u.Id).Delete(&User{}).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
if err := tx.Where("username=?", u.Username).Delete(&UserToken{}).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
})
|
||
}
|
||
|
||
func (u *User) ChangePassword(ctx *ctx.Context, oldpass, newpass string) error {
|
||
// SSO 用户(ldap/oidc/cas/oauth2/dingtalk等)且未设置本地密码,不支持本地修改密码
|
||
if u.Belong != "" && u.Password == "******" {
|
||
return fmt.Errorf("SSO user(%s) cannot change password locally, please change password in %s", u.Username, u.Belong)
|
||
}
|
||
|
||
_oldpass, err := CryptoPass(ctx, oldpass)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
_newpass, err := CryptoPass(ctx, newpass)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if u.Password != _oldpass {
|
||
return errors.New("Incorrect old password")
|
||
}
|
||
|
||
return u.UpdatePassword(ctx, _newpass, u.Username)
|
||
}
|
||
|
||
func UserGet(ctx *ctx.Context, where string, args ...interface{}) (*User, error) {
|
||
var lst []*User
|
||
err := DB(ctx).Where(where, args...).Find(&lst).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if len(lst) == 0 {
|
||
return nil, nil
|
||
}
|
||
|
||
lst[0].RolesLst = strings.Fields(lst[0].Roles)
|
||
lst[0].Admin = lst[0].IsAdmin()
|
||
lst[0].DecryptPhone() // 解密手机号
|
||
|
||
return lst[0], nil
|
||
}
|
||
|
||
func UsersGet(ctx *ctx.Context, where string, args ...interface{}) ([]*User, error) {
|
||
var lst []*User
|
||
err := DB(ctx).Where(where, args...).Find(&lst).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
for _, user := range lst {
|
||
user.RolesLst = strings.Fields(user.Roles)
|
||
user.Admin = user.IsAdmin()
|
||
user.DecryptPhone() // 解密手机号
|
||
}
|
||
|
||
return lst, nil
|
||
}
|
||
|
||
func UserMapGet(ctx *ctx.Context, where string, args ...interface{}) map[string]*User {
|
||
lst, err := UsersGet(ctx, where, args...)
|
||
if err != nil {
|
||
logger.Errorf("UsersGet err: %v", err)
|
||
return nil
|
||
}
|
||
um := make(map[string]*User, len(lst))
|
||
for _, user := range lst {
|
||
um[user.Username] = user
|
||
}
|
||
return um
|
||
}
|
||
|
||
// UserNicknameMap returns a deduplicated username -> nickname map.
|
||
func UserNicknameMap(ctx *ctx.Context, names []string) map[string]string {
|
||
m := make(map[string]string)
|
||
if len(names) == 0 {
|
||
return m
|
||
}
|
||
seen := make(map[string]struct{}, len(names))
|
||
unique := make([]string, 0, len(names))
|
||
for _, name := range names {
|
||
if name == "" {
|
||
continue
|
||
}
|
||
if _, ok := seen[name]; ok {
|
||
continue
|
||
}
|
||
seen[name] = struct{}{}
|
||
unique = append(unique, name)
|
||
}
|
||
if len(unique) == 0 {
|
||
return m
|
||
}
|
||
users := UserMapGet(ctx, "username in (?)", unique)
|
||
for username, user := range users {
|
||
m[username] = user.Nickname
|
||
}
|
||
return m
|
||
}
|
||
|
||
// FillUpdateByNicknames fills the UpdateByNickname field for each element in items
|
||
// by looking up the UpdateBy username. Supports both []T and []*T slices.
|
||
func FillUpdateByNicknames[T any](ctx *ctx.Context, items []T) {
|
||
if len(items) == 0 {
|
||
return
|
||
}
|
||
|
||
elemType := reflect.TypeOf(items).Elem()
|
||
isPtr := elemType.Kind() == reflect.Ptr
|
||
if isPtr {
|
||
elemType = elemType.Elem()
|
||
}
|
||
|
||
updateByField, ok1 := elemType.FieldByName("UpdateBy")
|
||
nicknameField, ok2 := elemType.FieldByName("UpdateByNickname")
|
||
if !ok1 || !ok2 {
|
||
return
|
||
}
|
||
|
||
names := make([]string, 0, len(items))
|
||
for i := range items {
|
||
v := reflect.ValueOf(&items[i]).Elem()
|
||
if isPtr {
|
||
if v.IsNil() {
|
||
continue
|
||
}
|
||
v = v.Elem()
|
||
}
|
||
names = append(names, v.FieldByIndex(updateByField.Index).String())
|
||
}
|
||
|
||
nm := UserNicknameMap(ctx, names)
|
||
|
||
for i := range items {
|
||
v := reflect.ValueOf(&items[i]).Elem()
|
||
if isPtr {
|
||
if v.IsNil() {
|
||
continue
|
||
}
|
||
v = v.Elem()
|
||
}
|
||
updateBy := v.FieldByIndex(updateByField.Index).String()
|
||
v.FieldByIndex(nicknameField.Index).SetString(nm[updateBy])
|
||
}
|
||
}
|
||
|
||
func UserGetByUsername(ctx *ctx.Context, username string) (*User, error) {
|
||
return UserGet(ctx, "username=?", username)
|
||
}
|
||
|
||
func UserGetById(ctx *ctx.Context, id int64) (*User, error) {
|
||
return UserGet(ctx, "id=?", id)
|
||
}
|
||
|
||
func CountAdminUsers(ctx *ctx.Context) (int64, error) {
|
||
var count int64
|
||
err := DB(ctx).Model(&User{}).Where("roles LIKE ?", "%"+AdminRole+"%").Count(&count).Error
|
||
return count, err
|
||
}
|
||
|
||
func UsersGetByGroupIds(ctx *ctx.Context, groupIds []int64) ([]User, error) {
|
||
if len(groupIds) == 0 {
|
||
return nil, nil
|
||
}
|
||
|
||
userIds, err := GroupsMemberIds(ctx, groupIds)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
users, err := UserGetsByIds(ctx, userIds)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return users, nil
|
||
}
|
||
|
||
func InitRoot(ctx *ctx.Context) bool {
|
||
user, err := UserGetByUsername(ctx, "root")
|
||
if err != nil {
|
||
fmt.Println("failed to query user root:", err)
|
||
os.Exit(1)
|
||
}
|
||
|
||
if user == nil {
|
||
return false
|
||
}
|
||
|
||
if len(user.Password) > 31 {
|
||
// already done before
|
||
return false
|
||
}
|
||
|
||
// 查询用户个数
|
||
count, err := Count(DB(ctx).Model(&User{}))
|
||
if err != nil {
|
||
fmt.Println("failed to count user:", err)
|
||
os.Exit(1)
|
||
}
|
||
|
||
if count == 1 {
|
||
// 说明数据库只有一个 root 用户,并且 root 用户密码没有加密,需要初始化 salt
|
||
InitSalt(ctx)
|
||
}
|
||
|
||
newPass, err := CryptoPass(ctx, user.Password)
|
||
if err != nil {
|
||
fmt.Println("failed to crypto pass:", err)
|
||
os.Exit(1)
|
||
}
|
||
|
||
err = DB(ctx).Model(user).Update("password", newPass).Error
|
||
if err != nil {
|
||
fmt.Println("failed to update root password:", err)
|
||
os.Exit(1)
|
||
}
|
||
|
||
fmt.Println("root password init done")
|
||
return true
|
||
}
|
||
|
||
func reachLoginFailCount(ctx *ctx.Context, redisObj storage.Redis, username string, count int64) (bool, error) {
|
||
key := "/userlogin/errorcount/" + username
|
||
val, err := redisObj.Get(ctx.GetContext(), key).Result()
|
||
if err == redis.Nil {
|
||
return false, nil
|
||
}
|
||
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
|
||
c, err := strconv.ParseInt(val, 10, 64)
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
|
||
return c >= count, nil
|
||
}
|
||
|
||
func incrLoginFailCount(ctx *ctx.Context, redisObj storage.Redis, username string, seconds int64) {
|
||
key := "/userlogin/errorcount/" + username
|
||
duration := time.Duration(seconds) * time.Second
|
||
|
||
val, err := redisObj.Get(ctx.GetContext(), key).Result()
|
||
if err == redis.Nil {
|
||
redisObj.Set(ctx.GetContext(), key, "1", duration)
|
||
return
|
||
}
|
||
|
||
if err != nil {
|
||
logx.Warningf(ctx.Ctx, "login_fail_count: failed to get redis value. key:%s, error:%s", key, err)
|
||
redisObj.Set(ctx.GetContext(), key, "1", duration)
|
||
return
|
||
}
|
||
|
||
count, err := strconv.ParseInt(val, 10, 64)
|
||
if err != nil {
|
||
logx.Warningf(ctx.Ctx, "login_fail_count: failed to parse int64. key:%s, error:%s", key, err)
|
||
redisObj.Set(ctx.GetContext(), key, "1", duration)
|
||
return
|
||
}
|
||
|
||
count++
|
||
redisObj.Set(ctx.GetContext(), key, fmt.Sprintf("%d", count), duration)
|
||
}
|
||
|
||
func PassLogin(ctx *ctx.Context, redis storage.Redis, username, pass string) (*User, error) {
|
||
// 300 5 meaning: 300 seconds, 5 times
|
||
val, err := ConfigsGet(ctx, "login_fail_count")
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var (
|
||
needCheck = val != "" // DB 里有配置,说明启用了这个 feature
|
||
seconds int64
|
||
count int64
|
||
)
|
||
|
||
if needCheck {
|
||
pair := strings.Fields(val)
|
||
if len(pair) != 2 {
|
||
logx.Warningf(ctx.Ctx, "login_fail_count config invalid: %s", val)
|
||
needCheck = false
|
||
} else {
|
||
seconds, err = strconv.ParseInt(pair[0], 10, 64)
|
||
if err != nil {
|
||
logx.Warningf(ctx.Ctx, "login_fail_count seconds invalid: %s", pair[0])
|
||
needCheck = false
|
||
}
|
||
|
||
count, err = strconv.ParseInt(pair[1], 10, 64)
|
||
if err != nil {
|
||
logx.Warningf(ctx.Ctx, "login_fail_count count invalid: %s", pair[1])
|
||
needCheck = false
|
||
}
|
||
}
|
||
}
|
||
|
||
if needCheck {
|
||
reach, err := reachLoginFailCount(ctx, redis, username, count)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if reach {
|
||
return nil, fmt.Errorf("reach login fail count")
|
||
}
|
||
}
|
||
|
||
user, err := UserGetByUsername(ctx, username)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if user == nil {
|
||
if needCheck {
|
||
incrLoginFailCount(ctx, redis, username, seconds)
|
||
}
|
||
|
||
return nil, fmt.Errorf("Username or password invalid")
|
||
}
|
||
|
||
loginPass, err := CryptoPass(ctx, pass)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if loginPass != user.Password {
|
||
if needCheck {
|
||
incrLoginFailCount(ctx, redis, username, seconds)
|
||
}
|
||
return nil, fmt.Errorf("Username or password invalid")
|
||
}
|
||
|
||
return user, nil
|
||
}
|
||
|
||
func UserTotal(ctx *ctx.Context, query string, stime, etime int64) (num int64, err error) {
|
||
db := DB(ctx).Model(&User{})
|
||
|
||
if stime != 0 && etime != 0 {
|
||
db = db.Where("last_active_time between ? and ?", stime, etime)
|
||
}
|
||
|
||
if query != "" {
|
||
q := "%" + query + "%"
|
||
num, err = Count(db.Where("username like ? or nickname like ? or phone like ? or email like ?", q, q, q, q))
|
||
} else {
|
||
num, err = Count(db)
|
||
}
|
||
|
||
if err != nil {
|
||
return num, errors.WithMessage(err, "failed to count user")
|
||
}
|
||
|
||
return num, nil
|
||
}
|
||
|
||
// whitespaceRegex matches runs of whitespace. It is compiled once at package
// init so validateOrderField does not re-parse it per call.
var whitespaceRegex = regexp.MustCompile(`\s+`)

// validOrderRegex accepts a bare identifier or a single "table.column" pair,
// where each part starts with a letter or underscore.
var validOrderRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?$`)
|
||
|
||
func validateOrderField(order string, defaultField string) string {
|
||
// 空值检查
|
||
if order == "" {
|
||
return defaultField
|
||
}
|
||
|
||
// 长度检查
|
||
if len(order) > 64 {
|
||
logger.Warningf("SQL injection attempt detected: order field too long (%d chars)", len(order))
|
||
return defaultField
|
||
}
|
||
|
||
// 移除所有空白字符
|
||
order = whitespaceRegex.ReplaceAllString(order, "")
|
||
if order == "" {
|
||
return defaultField
|
||
}
|
||
|
||
// 检查危险字符
|
||
orderLower := strings.ToLower(order)
|
||
if strings.ContainsAny(order, "();,'\"` --/*\\=+-*/><|&^~") ||
|
||
strings.Contains(orderLower, "0x") || strings.Contains(orderLower, "0b") {
|
||
logger.Warningf("SQL injection attempt detected: contains dangerous characters")
|
||
return defaultField
|
||
}
|
||
|
||
// 使用正则表达式验证格式:只允许字母开头的字段名,可选择性包含表名
|
||
if !validOrderRegex.MatchString(order) {
|
||
logger.Warningf("SQL injection attempt detected: invalid order field format")
|
||
return defaultField
|
||
}
|
||
|
||
return order
|
||
}
|
||
|
||
func UserGets(ctx *ctx.Context, query string, limit, offset int, stime, etime int64,
|
||
order string, desc bool, usernames, phones, emails []string) ([]User, error) {
|
||
|
||
session := DB(ctx)
|
||
|
||
if stime != 0 && etime != 0 {
|
||
session = session.Where("last_active_time between ? and ?", stime, etime)
|
||
}
|
||
|
||
order = validateOrderField(order, "username")
|
||
|
||
if desc {
|
||
order = order + " desc"
|
||
} else {
|
||
order = order + " asc"
|
||
}
|
||
|
||
session = session.Order(order)
|
||
|
||
if len(usernames) > 0 {
|
||
session = session.Where("username in (?)", usernames)
|
||
}
|
||
|
||
if len(phones) > 0 {
|
||
session = session.Where("phone in (?)", phones)
|
||
}
|
||
|
||
if len(emails) > 0 {
|
||
session = session.Where("email in (?)", emails)
|
||
}
|
||
|
||
if query != "" {
|
||
q := "%" + query + "%"
|
||
session = session.Where("username like ? or nickname like ? or phone like ? or email like ?", q, q, q, q)
|
||
}
|
||
|
||
var users []User
|
||
err := session.Limit(limit).Offset(offset).Find(&users).Error
|
||
if err != nil {
|
||
return users, errors.WithMessage(err, "failed to query user")
|
||
}
|
||
|
||
for i := range users {
|
||
users[i].RolesLst = strings.Fields(users[i].Roles)
|
||
users[i].Admin = users[i].IsAdmin()
|
||
users[i].Password = ""
|
||
users[i].DecryptPhone() // 解密手机号
|
||
|
||
// query for user group information
|
||
var userGroupIDs []int64
|
||
userGroupIDs, err = MyGroupIds(ctx, users[i].Id)
|
||
if err != nil {
|
||
return users, errors.WithMessage(err, "failed to query group_ids")
|
||
}
|
||
|
||
if err = DB(ctx).Table("user_group").Where("id IN (?)", userGroupIDs).
|
||
Find(&users[i].UserGroupsRes).Error; err != nil {
|
||
return users, errors.WithMessage(err, "failed to query user_groups")
|
||
}
|
||
|
||
// query business group information
|
||
var busiGroupIDs []int64
|
||
busiGroupIDs, err = BusiGroupIds(ctx, userGroupIDs)
|
||
if err != nil {
|
||
return users, errors.WithMessage(err, "failed to query busi_group_id")
|
||
}
|
||
|
||
if err = DB(ctx).Table("busi_group").Where("id IN (?)", busiGroupIDs).
|
||
Find(&users[i].BusiGroupsRes).Error; err != nil {
|
||
return users, errors.WithMessage(err, "failed to query busi_groups")
|
||
}
|
||
}
|
||
|
||
return users, nil
|
||
}
|
||
|
||
func UserGetAll(ctx *ctx.Context) ([]*User, error) {
|
||
if !ctx.IsCenter {
|
||
lst, err := poster.GetByUrls[[]*User](ctx, "/v1/n9e/users")
|
||
return lst, err
|
||
}
|
||
|
||
var lst []*User
|
||
err := DB(ctx).Find(&lst).Error
|
||
if err == nil {
|
||
for i := 0; i < len(lst); i++ {
|
||
lst[i].RolesLst = strings.Fields(lst[i].Roles)
|
||
lst[i].Admin = lst[i].IsAdmin()
|
||
lst[i].DecryptPhone() // 解密手机号
|
||
}
|
||
}
|
||
return lst, err
|
||
}
|
||
|
||
func UserGetsByIds(ctx *ctx.Context, ids []int64) ([]User, error) {
|
||
if len(ids) == 0 {
|
||
return []User{}, nil
|
||
}
|
||
|
||
var lst []User
|
||
err := DB(ctx).Where("id in ?", ids).Order("username").Find(&lst).Error
|
||
if err == nil {
|
||
for i := 0; i < len(lst); i++ {
|
||
lst[i].RolesLst = strings.Fields(lst[i].Roles)
|
||
lst[i].Admin = lst[i].IsAdmin()
|
||
lst[i].DecryptPhone() // 解密手机号
|
||
}
|
||
}
|
||
|
||
return lst, err
|
||
}
|
||
|
||
func UserGetsBySso(ctx *ctx.Context, sso string) (map[string]*User, error) {
|
||
session := DB(ctx).Where("belong=?", sso).Order("username")
|
||
|
||
var users []User
|
||
err := session.Find(&users).Error
|
||
if err != nil {
|
||
return nil, errors.WithMessage(err, "failed to query user")
|
||
}
|
||
|
||
usersMap := make(map[string]*User, len(users))
|
||
for i, user := range users {
|
||
usersMap[user.Username] = &users[i]
|
||
}
|
||
|
||
return usersMap, nil
|
||
}
|
||
|
||
func UserDelByIds(ctx *ctx.Context, userIds []int64) error {
|
||
return DB(ctx).Transaction(func(tx *gorm.DB) error {
|
||
if err := tx.Where("user_id in ?", userIds).Delete(&UserGroupMember{}).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
if err := tx.Where("id in ?", userIds).Delete(&User{}).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
})
|
||
}
|
||
|
||
func (u *User) CanModifyUserGroup(ctx *ctx.Context, ug *UserGroup) (bool, error) {
|
||
// 我是管理员,自然可以
|
||
if u.IsAdmin() {
|
||
return true, nil
|
||
}
|
||
|
||
// 我是创建者,自然可以
|
||
if ug.CreateBy == u.Username {
|
||
return true, nil
|
||
}
|
||
|
||
// 我是成员,也可以吧,简单搞
|
||
num, err := UserGroupMemberCount(ctx, "user_id=? and group_id=?", u.Id, ug.Id)
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
|
||
return num > 0, nil
|
||
}
|
||
|
||
func (u *User) CanDoBusiGroup(ctx *ctx.Context, bg *BusiGroup, permFlag ...string) (bool, error) {
|
||
if u.IsAdmin() {
|
||
return true, nil
|
||
}
|
||
|
||
// 我在任意一个UserGroup里,就有权限
|
||
ugids, err := UserGroupIdsOfBusiGroup(ctx, bg.Id, permFlag...)
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
|
||
if len(ugids) == 0 {
|
||
return false, nil
|
||
}
|
||
|
||
num, err := UserGroupMemberCount(ctx, "user_id = ? and group_id in ?", u.Id, ugids)
|
||
return num > 0, err
|
||
}
|
||
|
||
func (u *User) CheckPerm(ctx *ctx.Context, operation string) (bool, error) {
|
||
if u.IsAdmin() {
|
||
return true, nil
|
||
}
|
||
|
||
return RoleHasOperation(ctx, u.RolesLst, operation)
|
||
}
|
||
|
||
func UserStatistics(ctx *ctx.Context) (*Statistics, error) {
|
||
if !ctx.IsCenter {
|
||
s, err := poster.GetByUrls[*Statistics](ctx, "/v1/n9e/statistic?name=user")
|
||
return s, err
|
||
}
|
||
|
||
session := DB(ctx).Model(&User{}).Select("count(*) as total", "max(update_at) as last_updated")
|
||
|
||
var stats []*Statistics
|
||
err := session.Find(&stats).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return stats[0], nil
|
||
}
|
||
|
||
func (u *User) NopriIdents(ctx *ctx.Context, idents []string) ([]string, error) {
|
||
if u.IsAdmin() {
|
||
return []string{}, nil
|
||
}
|
||
|
||
ugids, err := MyGroupIds(ctx, u.Id)
|
||
if err != nil {
|
||
return []string{}, err
|
||
}
|
||
|
||
if len(ugids) == 0 {
|
||
return idents, nil
|
||
}
|
||
|
||
bgids, err := BusiGroupIds(ctx, ugids, "rw")
|
||
if err != nil {
|
||
return []string{}, err
|
||
}
|
||
|
||
if len(bgids) == 0 {
|
||
return idents, nil
|
||
}
|
||
|
||
var allowedIdents []string
|
||
sub := DB(ctx).Model(&Target{}).Distinct("target.ident").
|
||
Joins("join target_busi_group on target.ident = target_busi_group.target_ident").
|
||
Where("target_busi_group.group_id in (?)", bgids)
|
||
err = DB(ctx).Model(&Target{}).Where("ident in (?)", sub).Pluck("ident", &allowedIdents).Error
|
||
if err != nil {
|
||
return []string{}, err
|
||
}
|
||
|
||
return slice.SubString(idents, allowedIdents), nil
|
||
}
|
||
|
||
// BusiGroups lists business groups visible to the user. Results are ordered
// by name, capped at limit, and filtered by a "name like %query%" match.
//
// Admins get all groups (as do callers passing all[0] == true). Otherwise
// the user only gets groups they reach as a member, via MyGroupIds ->
// BusiGroupIds.
//
// When the name search finds nothing and query is non-empty, query is retried
// as a target ident, and that target's business groups are returned.
func (u *User) BusiGroups(ctx *ctx.Context, limit int, query string, all ...bool) ([]BusiGroup, error) {
	session := DB(ctx).Order("name").Limit(limit)

	var lst []BusiGroup
	if u.IsAdmin() || (len(all) > 0 && all[0]) {
		err := session.Where("name like ?", "%"+query+"%").Find(&lst).Error
		if err != nil {
			return lst, err
		}

		if len(lst) == 0 && len(query) > 0 {
			// Hidden feature, deliberately not advertised: the query may be a
			// target ident, which is why the SQL above found nothing. Try
			// looking it up as an ident instead.
			var t *Target
			t, err = TargetGet(ctx, "ident=?", query)
			if err != nil {
				return lst, err
			}

			if t == nil {
				return lst, nil
			}

			t.GroupIds, err = TargetGroupIdsGetByIdent(ctx, t.Ident)
			if err != nil {
				return nil, err
			}
			err = DB(ctx).Order("name").Limit(limit).Where("id in ?", t.GroupIds).Find(&lst).Error
		}

		return lst, err
	}

	userGroupIds, err := MyGroupIds(ctx, u.Id)
	if err != nil {
		return nil, errors.WithMessage(err, "failed to get MyGroupIds")
	}

	busiGroupIds, err := BusiGroupIds(ctx, userGroupIds)
	if err != nil {
		return nil, errors.WithMessage(err, "failed to get BusiGroupIds")
	}

	if len(busiGroupIds) == 0 {
		return lst, nil
	}

	err = session.Where("id in ?", busiGroupIds).Where("name like ?", "%"+query+"%").Find(&lst).Error
	if err != nil {
		return nil, err
	}

	if len(lst) == 0 && len(query) > 0 {
		// Same ident fallback as the admin branch. It is gated on
		// t.MatchGroupId(busiGroupIds...), presumably "the target is in one of
		// the user's business groups".
		// NOTE(review): on a match this returns all of the target's groups,
		// including any outside busiGroupIds; confirm that is intended.
		var t *Target
		t, err = TargetGet(ctx, "ident=?", query)
		if err != nil {
			return lst, err
		}

		if t == nil {
			return lst, nil
		}

		t.GroupIds, err = TargetGroupIdsGetByIdent(ctx, t.Ident)
		if err != nil {
			return nil, err
		}
		if t != nil && t.MatchGroupId(busiGroupIds...) {
			err = DB(ctx).Order("name").Limit(limit).Where("id in ?", t.GroupIds).Find(&lst).Error
		}
	}

	return lst, err
}
|
||
|
||
func (u *User) UserGroups(ctx *ctx.Context, limit int, query string) ([]UserGroup, error) {
|
||
session := DB(ctx).Order("name").Limit(limit)
|
||
|
||
var lst []UserGroup
|
||
if u.IsAdmin() {
|
||
err := session.Where("name like ?", "%"+query+"%").Find(&lst).Error
|
||
if err != nil {
|
||
return lst, err
|
||
}
|
||
|
||
var user *User
|
||
if len(lst) == 0 && len(query) > 0 {
|
||
// 隐藏功能,一般人不告诉,哈哈。query可能是给的用户名,所以上面的sql没有查到,当做user来查一下试试
|
||
user, err = UserGetByUsername(ctx, query)
|
||
if user == nil {
|
||
return lst, err
|
||
}
|
||
var ids []int64
|
||
ids, err = MyGroupIds(ctx, user.Id)
|
||
if err != nil || len(ids) == 0 {
|
||
return lst, err
|
||
}
|
||
lst, err = UserGroupGetByIds(ctx, ids)
|
||
}
|
||
return lst, err
|
||
}
|
||
|
||
ids, err := MyGroupIds(ctx, u.Id)
|
||
if err != nil {
|
||
return nil, errors.WithMessage(err, "failed to get MyGroupIds")
|
||
}
|
||
|
||
if len(ids) > 0 {
|
||
session = session.Where("id in ? or create_by = ?", ids, u.Username)
|
||
} else {
|
||
session = session.Where("create_by = ?", u.Username)
|
||
}
|
||
|
||
if len(query) > 0 {
|
||
session = session.Where("name like ?", "%"+query+"%")
|
||
}
|
||
|
||
err = session.Find(&lst).Error
|
||
return lst, err
|
||
}
|
||
|
||
func (u *User) ExtractToken(key string) (string, bool) {
|
||
bs, err := u.Contacts.MarshalJSON()
|
||
if err != nil {
|
||
logger.Errorf("handle_notice: failed to marshal contacts: %v", err)
|
||
return "", false
|
||
}
|
||
|
||
switch key {
|
||
case Dingtalk:
|
||
ret := gjson.GetBytes(bs, DingtalkKey)
|
||
return ret.String(), ret.Exists()
|
||
case Wecom:
|
||
ret := gjson.GetBytes(bs, WecomKey)
|
||
return ret.String(), ret.Exists()
|
||
case Feishu, FeishuCard:
|
||
ret := gjson.GetBytes(bs, FeishuKey)
|
||
return ret.String(), ret.Exists()
|
||
case Mm:
|
||
ret := gjson.GetBytes(bs, MmKey)
|
||
return ret.String(), ret.Exists()
|
||
case Telegram:
|
||
ret := gjson.GetBytes(bs, TelegramKey)
|
||
return ret.String(), ret.Exists()
|
||
case Email:
|
||
return u.Email, u.Email != ""
|
||
case Lark, LarkCard:
|
||
ret := gjson.GetBytes(bs, LarkKey)
|
||
return ret.String(), ret.Exists()
|
||
case Phone:
|
||
return u.Phone, u.Phone != ""
|
||
default:
|
||
ret := gjson.GetBytes(bs, key)
|
||
return ret.String(), ret.Exists()
|
||
}
|
||
}
|
||
|
||
func (u *User) FindSameContact(email, phone string) string {
|
||
if u.Email != "" && u.Email == email {
|
||
return "email"
|
||
}
|
||
|
||
if u.Phone != "" && u.Phone == phone {
|
||
return "phone"
|
||
}
|
||
|
||
return ""
|
||
}
|
||
|
||
// AddUserAndGroups Add a user and add it to multiple groups in a single transaction
|
||
func (u *User) AddUserAndGroups(ctx *ctx.Context, coverTeams bool) error {
|
||
|
||
// Try to add a user
|
||
if err := u.Add(ctx); err != nil {
|
||
return errors.WithMessage(err, "failed to add user")
|
||
}
|
||
|
||
// Try to add a group for the user
|
||
if err := UserGroupMemberSyncByUser(ctx, u, coverTeams); err != nil {
|
||
return errors.WithMessage(err, "failed to add user to groups")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (u *User) EncryptPhone() (err error) {
|
||
// 从缓存获取手机号加密配置
|
||
enabled, publicKey, _, _, loaded := GetPhoneEncryptionConfigFromCache()
|
||
if !loaded {
|
||
// 如果缓存未加载,记录日志但不阻止保存
|
||
logger.Infof("Phone encryption config cache not loaded, user: %s", u.Username)
|
||
return nil
|
||
}
|
||
|
||
// 检查是否启用了手机号加密
|
||
if enabled && u.Phone != "" {
|
||
// 检查手机号是否已经加密(避免重复加密)
|
||
if len(u.Phone) > 4 && u.Phone[:4] == "enc:" {
|
||
// 已经加密,跳过
|
||
return nil
|
||
}
|
||
|
||
encryptedPhone, err := secu.EncryptValue(u.Phone, publicKey)
|
||
if err != nil {
|
||
logger.Warningf("Failed to encrypt phone: %v, user: %s", err, u.Username)
|
||
return nil
|
||
}
|
||
|
||
u.Phone = encryptedPhone
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// DecryptPhone 解密用户手机号(如果已加密)
|
||
func (u *User) DecryptPhone() {
|
||
if u.Phone == "" {
|
||
return
|
||
}
|
||
|
||
// 检查手机号是否是加密格式(有 "enc:" 前缀)
|
||
if len(u.Phone) <= 4 || u.Phone[:4] != "enc:" {
|
||
// 不是加密格式,不需要解密
|
||
return
|
||
}
|
||
|
||
// 从缓存获取手机号加密配置
|
||
enabled, _, privateKey, password, loaded := GetPhoneEncryptionConfigFromCache()
|
||
if !loaded || !enabled {
|
||
// 如果缓存未加载或未启用加密,不解密
|
||
return
|
||
}
|
||
|
||
// 对手机号进行解密
|
||
decryptedPhone, err := secu.Decrypt(u.Phone, privateKey, password)
|
||
if err != nil {
|
||
// 如果解密失败,记录错误但保持原样
|
||
logger.Warningf("Failed to decrypt phone for user %s: %v", u.Username, err)
|
||
return
|
||
}
|
||
u.Phone = decryptedPhone
|
||
}
|