Compare commits

...

9 Commits

Author SHA1 Message Date
Yening Qin
7568119667 Merge branch 'release-21' into update-workflow 2026-01-22 19:44:42 +08:00
ning
c93694a2a9 refactor: update init metrics tpl 2026-01-21 19:45:57 +08:00
ning
cfb8c3b66a refactor: update doris check max rows 2026-01-21 16:03:04 +08:00
ning
cb5e62b7bb fix save workflow execution 2026-01-20 21:28:51 +08:00
yuansheng
ebfde8d6a0 refactor: record_rule support writeback_enabled (#3048) 2026-01-20 19:32:09 +08:00
ning
b4dcaebf83 refactor: update doris check max rows 2026-01-20 16:34:50 +08:00
huangjie
fa491e313a sso add feishu (#3046) 2026-01-19 14:12:38 +08:00
ning
40722d2ff3 update workflow status 2026-01-16 19:56:24 +08:00
ning
4fe2b5042f refactor: update trigger value 2026-01-14 19:41:32 +08:00
15 changed files with 1703 additions and 34 deletions

View File

@@ -182,7 +182,6 @@ func (e *WorkflowEngine) executeDAG(nodeMap map[string]*models.WorkflowNode, con
result.Status = models.ExecutionStatusFailed
result.ErrorNode = nodeID
result.Message = fmt.Sprintf("node %s failed: %s", node.Name, nodeResult.Error)
return result
}
}

View File

@@ -271,10 +271,8 @@ func Init(ctx *ctx.Context, builtinIntegrationsDir string) {
}
for _, metric := range metrics {
if metric.UUID == 0 {
time.Sleep(time.Microsecond)
metric.UUID = time.Now().UnixMicro()
}
time.Sleep(time.Microsecond)
metric.UUID = time.Now().UnixMicro()
metric.ID = metric.UUID
metric.CreatedBy = SYSTEM
metric.UpdatedBy = SYSTEM

View File

@@ -251,10 +251,12 @@ func (rt *Router) Config(r *gin.Engine) {
pages.GET("/auth/redirect/cas", rt.loginRedirectCas)
pages.GET("/auth/redirect/oauth", rt.loginRedirectOAuth)
pages.GET("/auth/redirect/dingtalk", rt.loginRedirectDingTalk)
pages.GET("/auth/redirect/feishu", rt.loginRedirectFeiShu)
pages.GET("/auth/callback", rt.loginCallback)
pages.GET("/auth/callback/cas", rt.loginCallbackCas)
pages.GET("/auth/callback/oauth", rt.loginCallbackOAuth)
pages.GET("/auth/callback/dingtalk", rt.loginCallbackDingTalk)
pages.GET("/auth/callback/feishu", rt.loginCallbackFeiShu)
pages.GET("/auth/perms", rt.allPerms)
pages.GET("/metrics/desc", rt.metricsDescGetFile)
@@ -705,6 +707,7 @@ func (rt *Router) Config(r *gin.Engine) {
service.GET("/event-pipelines", rt.eventPipelinesListByService)
service.POST("/event-pipeline/:id/trigger", rt.triggerEventPipelineByService)
service.POST("/event-pipeline/:id/stream", rt.streamEventPipelineByService)
service.POST("/event-pipeline-execution", rt.eventPipelineExecutionAdd)
// 手机号加密存储配置接口
service.POST("/users/phone/encrypt", rt.usersPhoneEncrypt)

View File

@@ -621,3 +621,18 @@ func (rt *Router) streamEventPipelineByService(c *gin.Context) {
ginx.NewRender(c).Data(result, nil)
}
// eventPipelineExecutionAdd receives a Pipeline execution record synced
// from an edge node and persists it. The record must carry a non-empty ID
// and a positive pipeline_id; anything else is rejected with 400.
func (rt *Router) eventPipelineExecutionAdd(c *gin.Context) {
	var execution models.EventPipelineExecution
	ginx.BindJSON(c, &execution)

	switch {
	case execution.ID == "":
		ginx.Bomb(http.StatusBadRequest, "id is required")
	case execution.PipelineID <= 0:
		ginx.Bomb(http.StatusBadRequest, "pipeline_id is required")
	}

	ginx.NewRender(c).Message(models.DB(rt.Ctx).Create(&execution).Error)
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/ccfos/nightingale/v6/models"
"github.com/ccfos/nightingale/v6/pkg/cas"
"github.com/ccfos/nightingale/v6/pkg/dingtalk"
"github.com/ccfos/nightingale/v6/pkg/feishu"
"github.com/ccfos/nightingale/v6/pkg/ldapx"
"github.com/ccfos/nightingale/v6/pkg/oauth2x"
"github.com/ccfos/nightingale/v6/pkg/oidcx"
@@ -519,6 +520,85 @@ func (rt *Router) loginCallbackDingTalk(c *gin.Context) {
}
// loginRedirectFeiShu builds the FeiShu (Lark) SSO authorize URL for the
// login page. If the request already carries a logged-in user, the caller
// is simply handed back its redirect target; if FeiShu SSO is disabled or
// unconfigured, an empty string is returned.
func (rt *Router) loginRedirectFeiShu(c *gin.Context) {
	redirect := ginx.QueryStr(c, "redirect", "/")

	if v, ok := c.Get("userid"); ok {
		user, err := models.UserGetById(rt.Ctx, v.(int64))
		ginx.Dangerous(err)
		if user == nil {
			ginx.Bomb(200, "user not found")
		}
		// Session already authenticated: no need to start the SSO dance.
		if user.Username != "" {
			ginx.NewRender(c).Data(redirect, nil)
			return
		}
	}

	// FeiShu SSO not configured or switched off.
	if rt.Sso.FeiShu == nil || !rt.Sso.FeiShu.Enable {
		ginx.NewRender(c).Data("", nil)
		return
	}

	authorizeURL, err := rt.Sso.FeiShu.Authorize(rt.Redis, redirect)
	ginx.Dangerous(err)
	ginx.NewRender(c).Data(authorizeURL, err)
}
// loginCallbackFeiShu handles the FeiShu (Lark) OAuth callback: it
// exchanges code/state for the SSO user profile, provisions or updates the
// local user, issues JWT tokens, and returns the post-login redirect.
func (rt *Router) loginCallbackFeiShu(c *gin.Context) {
	code := ginx.QueryStr(c, "code", "")
	state := ginx.QueryStr(c, "state", "")

	ret, err := rt.Sso.FeiShu.Callback(rt.Redis, c.Request.Context(), code, state)
	if err != nil {
		logger.Errorf("sso_callback FeiShu fail. code:%s, state:%s, get ret: %+v. error: %v", code, state, ret, err)
		ginx.NewRender(c).Data(CallbackOutput{}, err)
		return
	}

	// Look up an existing local account by the SSO-provided username.
	user, err := models.UserGet(rt.Ctx, "username=?", ret.Username)
	ginx.Dangerous(err)

	if user != nil {
		// Existing user: optionally overwrite local profile attributes with
		// the SSO-provided ones when CoverAttributes is enabled.
		if rt.Sso.FeiShu != nil && rt.Sso.FeiShu.FeiShuConfig != nil && rt.Sso.FeiShu.FeiShuConfig.CoverAttributes {
			updatedFields := user.UpdateSsoFields(feishu.SsoTypeName, ret.Nickname, ret.Phone, ret.Email)
			ginx.Dangerous(user.Update(rt.Ctx, "update_at", updatedFields...))
		}
	} else {
		// First login via FeiShu: auto-provision a user with the configured
		// default roles (empty role list when no config is present).
		user = new(models.User)
		defaultRoles := []string{}
		if rt.Sso.FeiShu != nil && rt.Sso.FeiShu.FeiShuConfig != nil {
			defaultRoles = rt.Sso.FeiShu.FeiShuConfig.DefaultRoles
		}
		user.FullSsoFields(feishu.SsoTypeName, ret.Username, ret.Nickname, ret.Phone, ret.Email, defaultRoles)
		// create user from feishu
		ginx.Dangerous(user.Add(rt.Ctx))
	}

	// set user login state
	userIdentity := fmt.Sprintf("%d-%s", user.Id, user.Username)
	ts, err := rt.createTokens(rt.HTTP.JWTAuth.SigningKey, userIdentity)
	ginx.Dangerous(err)
	ginx.Dangerous(rt.createAuth(c.Request.Context(), userIdentity, ts))

	// Never bounce the user back to the login page itself.
	redirect := "/"
	if ret.Redirect != "/login" {
		redirect = ret.Redirect
	}

	ginx.NewRender(c).Data(CallbackOutput{
		Redirect:     redirect,
		User:         user,
		AccessToken:  ts.AccessToken,
		RefreshToken: ts.RefreshToken,
	}, nil)
}
func (rt *Router) loginCallbackOAuth(c *gin.Context) {
code := ginx.QueryStr(c, "code", "")
state := ginx.QueryStr(c, "state", "")
@@ -569,10 +649,11 @@ type SsoConfigOutput struct {
CasDisplayName string `json:"casDisplayName"`
OauthDisplayName string `json:"oauthDisplayName"`
DingTalkDisplayName string `json:"dingTalkDisplayName"`
FeiShuDisplayName string `json:"feishuDisplayName"`
}
func (rt *Router) ssoConfigNameGet(c *gin.Context) {
var oidcDisplayName, casDisplayName, oauthDisplayName, dingTalkDisplayName string
var oidcDisplayName, casDisplayName, oauthDisplayName, dingTalkDisplayName, feiShuDisplayName string
if rt.Sso.OIDC != nil {
oidcDisplayName = rt.Sso.OIDC.GetDisplayName()
}
@@ -589,11 +670,16 @@ func (rt *Router) ssoConfigNameGet(c *gin.Context) {
dingTalkDisplayName = rt.Sso.DingTalk.GetDisplayName()
}
if rt.Sso.FeiShu != nil {
feiShuDisplayName = rt.Sso.FeiShu.GetDisplayName()
}
ginx.NewRender(c).Data(SsoConfigOutput{
OidcDisplayName: oidcDisplayName,
CasDisplayName: casDisplayName,
OauthDisplayName: oauthDisplayName,
DingTalkDisplayName: dingTalkDisplayName,
FeiShuDisplayName: feiShuDisplayName,
}, nil)
}
@@ -608,6 +694,7 @@ func (rt *Router) ssoConfigGets(c *gin.Context) {
// TODO: dingTalkExist 为了兼容当前前端配置, 后期单点登陆统一调整后不在预先设置默认内容
dingTalkExist := false
feiShuExist := false
for _, config := range lst {
var ssoReqConfig models.SsoConfig
ssoReqConfig.Id = config.Id
@@ -618,6 +705,10 @@ func (rt *Router) ssoConfigGets(c *gin.Context) {
dingTalkExist = true
err := json.Unmarshal([]byte(config.Content), &ssoReqConfig.SettingJson)
ginx.Dangerous(err)
case feishu.SsoTypeName:
feiShuExist = true
err := json.Unmarshal([]byte(config.Content), &ssoReqConfig.SettingJson)
ginx.Dangerous(err)
default:
ssoReqConfig.Content = config.Content
}
@@ -630,6 +721,11 @@ func (rt *Router) ssoConfigGets(c *gin.Context) {
ssoConfig.Name = dingtalk.SsoTypeName
ssoConfigs = append(ssoConfigs, ssoConfig)
}
if !feiShuExist {
var ssoConfig models.SsoConfig
ssoConfig.Name = feishu.SsoTypeName
ssoConfigs = append(ssoConfigs, ssoConfig)
}
ginx.NewRender(c).Data(ssoConfigs, nil)
}
@@ -657,6 +753,23 @@ func (rt *Router) ssoConfigUpdate(c *gin.Context) {
err = f.Update(rt.Ctx)
}
ginx.Dangerous(err)
case feishu.SsoTypeName:
f.Name = ssoConfig.Name
setting, err := json.Marshal(ssoConfig.SettingJson)
ginx.Dangerous(err)
f.Content = string(setting)
f.UpdateAt = time.Now().Unix()
sso, err := f.Query(rt.Ctx)
if !errors.Is(err, gorm.ErrRecordNotFound) {
ginx.Dangerous(err)
}
if errors.Is(err, gorm.ErrRecordNotFound) {
err = f.Create(rt.Ctx)
} else {
f.Id = sso.Id
err = f.Update(rt.Ctx)
}
ginx.Dangerous(err)
default:
f.Id = ssoConfig.Id
f.Name = ssoConfig.Name
@@ -695,6 +808,14 @@ func (rt *Router) ssoConfigUpdate(c *gin.Context) {
rt.Sso.DingTalk = dingtalk.New(config)
}
rt.Sso.DingTalk.Reload(config)
case feishu.SsoTypeName:
var config feishu.Config
err := json.Unmarshal([]byte(f.Content), &config)
ginx.Dangerous(err)
if rt.Sso.FeiShu == nil {
rt.Sso.FeiShu = feishu.New(config)
}
rt.Sso.FeiShu.Reload(config)
}
ginx.NewRender(c).Message(nil)

View File

@@ -12,6 +12,7 @@ import (
"github.com/ccfos/nightingale/v6/pkg/cas"
"github.com/ccfos/nightingale/v6/pkg/ctx"
"github.com/ccfos/nightingale/v6/pkg/dingtalk"
"github.com/ccfos/nightingale/v6/pkg/feishu"
"github.com/ccfos/nightingale/v6/pkg/ldapx"
"github.com/ccfos/nightingale/v6/pkg/oauth2x"
"github.com/ccfos/nightingale/v6/pkg/oidcx"
@@ -27,6 +28,7 @@ type SsoClient struct {
CAS *cas.SsoClient
OAuth2 *oauth2x.SsoClient
DingTalk *dingtalk.SsoClient
FeiShu *feishu.SsoClient
LastUpdateTime int64
configCache *memsto.ConfigCache
configLastUpdateTime int64
@@ -203,6 +205,13 @@ func Init(center cconf.Center, ctx *ctx.Context, configCache *memsto.ConfigCache
log.Fatalf("init %s failed: %s", dingtalk.SsoTypeName, err)
}
ssoClient.DingTalk = dingtalk.New(config)
case feishu.SsoTypeName:
var config feishu.Config
err := json.Unmarshal([]byte(cfg.Content), &config)
if err != nil {
log.Fatalf("init %s failed: %s", feishu.SsoTypeName, err)
}
ssoClient.FeiShu = feishu.New(config)
}
}
@@ -291,6 +300,22 @@ func (s *SsoClient) reload(ctx *ctx.Context) error {
s.DingTalk = nil
}
if feiShuConfig, ok := ssoConfigMap[feishu.SsoTypeName]; ok {
var config feishu.Config
err := json.Unmarshal([]byte(feiShuConfig.Content), &config)
if err != nil {
logger.Warningf("reload %s failed: %s", feishu.SsoTypeName, err)
} else {
if s.FeiShu != nil {
s.FeiShu.Reload(config)
} else {
s.FeiShu = feishu.New(config)
}
}
} else {
s.FeiShu = nil
}
s.LastUpdateTime = lastUpdateTime
s.configLastUpdateTime = lastCacheUpdateTime
return nil

324
dskit/doris/sql_analyzer.go Normal file
View File

@@ -0,0 +1,324 @@
package doris
import (
"regexp"
"strings"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/ast"
_ "github.com/pingcap/tidb/pkg/parser/test_driver" // required for parser
)
// mapAccessPattern matches Doris map/array access syntax like `col['key']` or col["key"];
// the bracketed accessor is stripped before parsing because the TiDB parser rejects it.
var mapAccessPattern = regexp.MustCompile(`\[['"]\w+['"]\]`)

// castStringPattern matches Doris CAST(... AS STRING) syntax (case-insensitive);
// it is rewritten to the MySQL-compatible CAST(... AS CHAR) form.
var castStringPattern = regexp.MustCompile(`(?i)\bAS\s+STRING\b`)

// Grafana-style macro patterns ($__timeGroup, $__timeFilter, $__interval);
// replaced with parseable placeholders before handing the SQL to the parser.
var timeGroupPattern = regexp.MustCompile(`\$__timeGroup\([^)]+\)`)
var timeFilterPattern = regexp.MustCompile(`\$__timeFilter\([^)]+\)`)
var intervalPattern = regexp.MustCompile(`\$__interval`)
// SQLAnalyzeResult holds the top-level analysis result of a SQL statement,
// used to decide whether a row-count probe is needed before execution.
type SQLAnalyzeResult struct {
	IsSelectLike bool   // whether the statement is a SELECT-like query (SELECT or UNION/INTERSECT/EXCEPT)
	HasTopAgg    bool   // whether the top-level query (or one of its CTEs) has aggregate functions
	LimitConst   *int64 // top-level LIMIT constant value (nil if no LIMIT or non-constant)
}
// AnalyzeSQL parses a SQL statement and extracts its top-level features:
// whether it is SELECT-like, whether the top level aggregates, and any
// constant LIMIT. Doris-only syntax is stripped first so the TiDB parser
// can handle the text. Only the first parsed statement is inspected.
func AnalyzeSQL(sql string) (*SQLAnalyzeResult, error) {
	cleaned := preprocessDorisSQL(sql)

	stmts, _, err := parser.New().Parse(cleaned, "", "")
	if err != nil {
		return nil, err
	}

	res := &SQLAnalyzeResult{}
	if len(stmts) == 0 {
		return res, nil
	}

	switch s := stmts[0].(type) {
	case *ast.SelectStmt:
		res.IsSelectLike = true
		analyzeSelectStmt(s, res)
	case *ast.SetOprStmt: // UNION / INTERSECT / EXCEPT
		res.IsSelectLike = true
		analyzeSetOprStmt(s, res)
	}
	// Any other statement kind keeps the zero value (IsSelectLike=false).
	return res, nil
}
// analyzeSelectStmt fills result with the top-level features of a plain
// SELECT: aggregate usage (in its own select list, or failing that in any
// CTE) and a constant LIMIT value when present.
func analyzeSelectStmt(sel *ast.SelectStmt, result *SQLAnalyzeResult) {
	// Top-level select list first; selectHasAggregate performs exactly the
	// per-field aggregate scan for a *ast.SelectStmt.
	result.HasTopAgg = selectHasAggregate(sel)

	// A CTE that aggregates usually bounds the final row count as well.
	if !result.HasTopAgg && sel.With != nil {
		for _, cte := range sel.With.CTEs {
			if selectHasAggregate(cte.Query) {
				result.HasTopAgg = true
				break
			}
		}
	}

	// Record the LIMIT only when it is a literal constant.
	if sel.Limit != nil && sel.Limit.Count != nil {
		if v, ok := extractConstValue(sel.Limit.Count); ok {
			result.LimitConst = &v
		}
	}
}
// selectHasAggregate reports whether the given node — a SELECT, a set
// operation (UNION/...), or a SubqueryExpr wrapping either — has an
// aggregate function in some select list. Unknown node kinds yield false.
func selectHasAggregate(node ast.Node) bool {
	switch n := node.(type) {
	case *ast.SelectStmt:
		if n.Fields == nil {
			return false
		}
		for _, f := range n.Fields.Fields {
			if f.Expr != nil && hasAggregateFunc(f.Expr) {
				return true
			}
		}
	case *ast.SetOprStmt:
		// For UNION-like statements, any aggregating branch counts.
		if n.SelectList == nil {
			return false
		}
		for _, branch := range n.SelectList.Selects {
			if selectHasAggregate(branch) {
				return true
			}
		}
	case *ast.SubqueryExpr:
		// CTE queries arrive wrapped in a SubqueryExpr.
		return n.Query != nil && selectHasAggregate(n.Query)
	}
	return false
}
// analyzeSetOprStmt analyzes UNION/INTERSECT/EXCEPT statements.
//
// The outermost LIMIT belongs to the whole set operation. HasTopAgg is set
// only when every branch aggregates; if any branch may return raw rows,
// the result is unbounded and the caller must still probe the row count.
func analyzeSetOprStmt(setOpr *ast.SetOprStmt, result *SQLAnalyzeResult) {
	// UNION's LIMIT is at the outermost level.
	if setOpr.Limit != nil && setOpr.Limit.Count != nil {
		if val, ok := extractConstValue(setOpr.Limit.Count); ok {
			result.LimitConst = &val
		}
	}

	if setOpr.SelectList == nil || len(setOpr.SelectList.Selects) == 0 {
		return
	}
	result.HasTopAgg = allBranchesAggregate(setOpr)
}

// allBranchesAggregate reports whether every SELECT branch of a set
// operation has an aggregate in its select list, recursing into nested set
// operations. Branches whose shape cannot be established (nil field list,
// unknown node type) are treated as non-aggregating, so the row-count
// check is kept — previously such branches were silently skipped and
// counted as aggregating, which could wrongly bypass the probe.
func allBranchesAggregate(setOpr *ast.SetOprStmt) bool {
	if setOpr.SelectList == nil || len(setOpr.SelectList.Selects) == 0 {
		return false
	}
	for _, sel := range setOpr.SelectList.Selects {
		switch branch := sel.(type) {
		case *ast.SelectStmt:
			if branch.Fields == nil {
				return false
			}
			hasAgg := false
			for _, field := range branch.Fields.Fields {
				if field.Expr != nil && hasAggregateFunc(field.Expr) {
					hasAgg = true
					break
				}
			}
			if !hasAgg {
				return false
			}
		case *ast.SetOprStmt:
			if !allBranchesAggregate(branch) {
				return false
			}
		default:
			// Unknown branch type: assume it may return raw rows.
			return false
		}
	}
	return true
}
// hasAggregateFunc reports whether expr itself contains an aggregate
// function call, without descending into subqueries.
func hasAggregateFunc(expr ast.ExprNode) bool {
	var c aggregateChecker
	expr.Accept(&c)
	return c.found
}
// aggregateChecker implements ast.Visitor to find aggregate functions in
// an expression tree. Subqueries are deliberately not entered, so only the
// current query level is inspected.
type aggregateChecker struct {
	found bool // set once the first aggregate function is seen
}

// Enter implements ast.Visitor; returning true as the second value tells
// the walker to skip the node's children.
func (c *aggregateChecker) Enter(n ast.Node) (ast.Node, bool) {
	if c.found {
		return n, true // stop traversal
	}
	switch node := n.(type) {
	case *ast.SubqueryExpr:
		return n, true // don't enter subquery
	case *ast.AggregateFuncExpr:
		c.found = true
		return n, true
	case *ast.FuncCallExpr:
		// Check for Doris-specific aggregate/statistic functions
		// (FnName.L is the parser's lowercased name; the lookup table
		// uses uppercase keys).
		funcName := strings.ToUpper(node.FnName.L)
		if isDorisAggregateFunc(funcName) {
			c.found = true
			return n, true
		}
	}
	return n, false // continue traversal
}

// Leave implements ast.Visitor; nothing to do on the way back up.
func (c *aggregateChecker) Leave(n ast.Node) (ast.Node, bool) {
	return n, true
}
// dorisAggFuncs is the lookup set of Doris aggregate / statistic / window
// function names, keyed by UPPERCASE name. Declared at package level so
// the table is built once instead of being re-allocated on every call to
// isDorisAggregateFunc (which runs for each function call in a query).
var dorisAggFuncs = map[string]bool{
	// Standard aggregates (in case parser doesn't recognize them)
	"COUNT":     true,
	"SUM":       true,
	"AVG":       true,
	"MIN":       true,
	"MAX":       true,
	"ANY":       true,
	"ANY_VALUE": true,
	// HLL related
	"HLL_UNION_AGG":   true,
	"HLL_RAW_AGG":     true,
	"HLL_CARDINALITY": true,
	"HLL_UNION":       true,
	"HLL_HASH":        true,
	// Bitmap related
	"BITMAP_UNION":         true,
	"BITMAP_UNION_COUNT":   true,
	"BITMAP_INTERSECT":     true,
	"BITMAP_COUNT":         true,
	"BITMAP_AND_COUNT":     true,
	"BITMAP_OR_COUNT":      true,
	"BITMAP_XOR_COUNT":     true,
	"BITMAP_AND_NOT_COUNT": true,
	// Other aggregates
	"PERCENTILE":            true,
	"PERCENTILE_APPROX":     true,
	"APPROX_COUNT_DISTINCT": true,
	"NDV":                   true,
	"COLLECT_LIST":          true,
	"COLLECT_SET":           true,
	"GROUP_CONCAT":          true,
	"GROUP_BIT_AND":         true,
	"GROUP_BIT_OR":          true,
	"GROUP_BIT_XOR":         true,
	"GROUPING":              true,
	"GROUPING_ID":           true,
	// Statistical functions
	"STDDEV":      true,
	"STDDEV_POP":  true,
	"STDDEV_SAMP": true,
	"STD":         true,
	"VARIANCE":    true,
	"VAR_POP":     true,
	"VAR_SAMP":    true,
	"COVAR_POP":   true,
	"COVAR_SAMP":  true,
	"CORR":        true,
	// Window functions that are also aggregates
	"FIRST_VALUE":  true,
	"LAST_VALUE":   true,
	"LAG":          true,
	"LEAD":         true,
	"ROW_NUMBER":   true,
	"RANK":         true,
	"DENSE_RANK":   true,
	"NTILE":        true,
	"CUME_DIST":    true,
	"PERCENT_RANK": true,
}

// isDorisAggregateFunc reports whether funcName (expected to already be
// uppercase) is a Doris-specific aggregate/statistic function.
func isDorisAggregateFunc(funcName string) bool {
	return dorisAggFuncs[funcName]
}
// extractConstValue returns the integer value of a constant expression and
// whether extraction succeeded. Integer, unsigned, and float literals are
// accepted (floats are truncated toward zero by the int64 conversion).
func extractConstValue(expr ast.ExprNode) (int64, bool) {
	valueExpr, ok := expr.(ast.ValueExpr)
	if !ok {
		return 0, false
	}
	switch val := valueExpr.GetValue().(type) {
	case int64:
		return val, true
	case uint64:
		return int64(val), true
	case float64:
		return int64(val), true
	case int:
		return int64(val), true
	}
	return 0, false
}
// preprocessDorisSQL rewrites Doris-specific syntax into a form the TiDB
// parser accepts. The statement's structure (select list, LIMIT, function
// calls) is preserved so the later analysis still sees the real shape.
func preprocessDorisSQL(sql string) string {
	// Strip map/variant/json field access like ['key'] or ["key"].
	out := mapAccessPattern.ReplaceAllString(sql, "")
	// Doris allows CAST(... AS STRING); the MySQL dialect wants CHAR.
	out = castStringPattern.ReplaceAllString(out, "AS CHAR")
	// Substitute Grafana-style macros with parseable placeholders.
	out = timeGroupPattern.ReplaceAllString(out, "ts")
	out = timeFilterPattern.ReplaceAllString(out, "1=1")
	return intervalPattern.ReplaceAllString(out, "60")
}
// NeedsRowCountCheck decides how a SQL query should be guarded against
// oversized result sets.
//
// Returns (needsCheck, directReject, rejectReason):
//   - needsCheck: the caller should run a probe row count before executing;
//   - directReject: reject immediately (currently never set);
//   - rejectReason: human-readable reason when rejecting.
func NeedsRowCountCheck(sql string, maxQueryRows int) (bool, bool, string) {
	result, err := AnalyzeSQL(sql)
	switch {
	case err != nil:
		// Unparseable SQL: fall back to the probe check.
		return true, false, ""
	case !result.IsSelectLike:
		// Not a SELECT query, skip check.
		return false, false, ""
	case result.HasTopAgg:
		// Rule 1: top-level aggregation bounds the output; skip check.
		return false, false, ""
	case result.LimitConst != nil && *result.LimitConst <= int64(maxQueryRows):
		// Rule 2: a small constant LIMIT caps the result; skip check.
		return false, false, ""
	}
	// Otherwise probe (including LIMIT > maxRows, since the actual result
	// may still be smaller).
	return true, false, ""
}

View File

@@ -0,0 +1,784 @@
package doris
import (
"testing"
)
// TestAnalyzeSQL_AggregateQueries verifies that AnalyzeSQL flags top-level
// (and CTE-level) aggregate usage and classifies statements as SELECT-like
// across standard aggregates, macro-laden dashboard queries, and plain
// non-aggregating SELECTs.
func TestAnalyzeSQL_AggregateQueries(t *testing.T) {
	tests := []struct {
		name         string
		sql          string
		wantHasAgg   bool
		wantIsSelect bool
	}{
		// Standard aggregate functions - should skip check
		{
			name:         "COUNT(*)",
			sql:          "SELECT COUNT(*) AS `cnt`, FLOOR(UNIX_TIMESTAMP(event_date) DIV 10) * 10 AS `time`, CAST(`labels`['event'] AS STRING) AS `labels.event` FROM `db_insight_doris`.`ewall_event` WHERE `event_date` BETWEEN FROM_UNIXTIME(1768965669) AND FROM_UNIXTIME(1768965969) GROUP BY `time`, `labels.event` ORDER BY `time` ASC",
			wantHasAgg:   true,
			wantIsSelect: true,
		},
		{
			name:         "COUNT with column",
			sql:          "SELECT COUNT(id) FROM users",
			wantHasAgg:   true,
			wantIsSelect: true,
		},
		{
			name:         "SUM function",
			sql:          "SELECT SUM(amount) FROM orders",
			wantHasAgg:   true,
			wantIsSelect: true,
		},
		{
			name:         "AVG function",
			sql:          "SELECT AVG(price) FROM products",
			wantHasAgg:   true,
			wantIsSelect: true,
		},
		{
			name:         "MIN function",
			sql:          "SELECT MIN(created_at) FROM logs",
			wantHasAgg:   true,
			wantIsSelect: true,
		},
		{
			name:         "MAX function",
			sql:          "SELECT MAX(score) FROM results",
			wantHasAgg:   true,
			wantIsSelect: true,
		},
		{
			name:         "Multiple aggregates",
			sql:          "SELECT COUNT(*), SUM(amount), AVG(price) FROM orders",
			wantHasAgg:   true,
			wantIsSelect: true,
		},
		{
			name:         "Aggregate with GROUP BY",
			sql:          "SELECT user_id, COUNT(*) FROM orders GROUP BY user_id",
			wantHasAgg:   true,
			wantIsSelect: true,
		},
		{
			name:         "Aggregate with WHERE and GROUP BY",
			sql:          "SELECT category, SUM(sales) FROM products WHERE status = 'active' GROUP BY category",
			wantHasAgg:   true,
			wantIsSelect: true,
		},
		{
			name:         "Aggregate with HAVING",
			sql:          "SELECT user_id, COUNT(*) as cnt FROM orders GROUP BY user_id HAVING cnt > 10",
			wantHasAgg:   true,
			wantIsSelect: true,
		},
		// macro queries with aggregates
		{
			name:         "COUNT with timeGroup",
			sql:          "SELECT COUNT(*) AS `cnt`, $__timeGroup(timestamp,$__interval) AS `time` FROM `apm`.`traces_span` WHERE (`service_name` = 'demo-logic-server') AND $__timeFilter(`timestamp`) GROUP BY `time` ORDER BY `time` ASC",
			wantHasAgg:   true,
			wantIsSelect: true,
		},
		{
			name:         "CTE with ratio calculation",
			sql:          "WITH `time_totals` AS (SELECT $__timeGroup(timestamp,$__interval) AS `time`, COUNT(*) AS `total_count` FROM `apm`.`traces_span` WHERE $__timeFilter(`timestamp`) GROUP BY `time`), `time_counts` AS (SELECT ANY_VALUE(`service_name`) AS `service_name`, $__timeGroup(timestamp,$__interval) AS `time`, COUNT(*) AS `count` FROM `apm`.`traces_span` WHERE (`service_name` = 'demo-logic-server') AND $__timeFilter(`timestamp`) GROUP BY `time`) SELECT tc.`service_name`, tc.`time`, ROUND(tc.`count` * 100.0 / tt.`total_count`, 2) AS `ratio` FROM `time_counts` tc JOIN `time_totals` tt ON tc.`time` = tt.`time` ORDER BY tc.`time` ASC",
			wantHasAgg:   true, // CTE has aggregate functions
			wantIsSelect: true,
		},
		{
			name:         "CTE with top values and ratio",
			sql:          "WITH `top_values` AS (SELECT `service_name` FROM `apm`.`traces_span` WHERE $__timeFilter(`timestamp`) GROUP BY `service_name` ORDER BY COUNT(*) DESC LIMIT 5), `time_totals` AS (SELECT $__timeGroup(timestamp,$__interval) AS `time`, COUNT(*) AS `total_count` FROM `apm`.`traces_span` WHERE $__timeFilter(`timestamp`) GROUP BY `time`), `time_counts` AS (SELECT `service_name`, $__timeGroup(timestamp,$__interval) AS `time`, COUNT(*) AS `count` FROM `apm`.`traces_span` WHERE $__timeFilter(`timestamp`) AND `service_name` IN (SELECT `service_name` FROM `top_values`) GROUP BY `service_name`, `time`) SELECT tc.`service_name`, tc.`time`, ROUND(tc.`count` * 100.0 / tt.`total_count`, 2) AS `ratio` FROM `time_counts` tc JOIN `time_totals` tt ON tc.`time` = tt.`time` ORDER BY tc.`time` ASC",
			wantHasAgg:   true, // CTE has aggregate functions
			wantIsSelect: true,
		},
		{
			name:         "PERCENTILE_APPROX with timeGroup",
			sql:          "SELECT PERCENTILE_APPROX(`duration`, 0.95) AS `p95`, $__timeGroup(timestamp,$__interval) AS `time` FROM `apm`.`traces_span` WHERE $__timeFilter(`timestamp`) GROUP BY `time` ORDER BY `time` ASC",
			wantHasAgg:   true,
			wantIsSelect: true,
		},
		{
			name:         "COUNT DISTINCT with timeGroup",
			sql:          "SELECT COUNT(DISTINCT `duration`) AS `unique_count`, $__timeGroup(timestamp,$__interval) AS `time` FROM `apm`.`traces_span` WHERE $__timeFilter(`timestamp`) GROUP BY `time` ORDER BY `time` ASC",
			wantHasAgg:   true,
			wantIsSelect: true,
		},
		{
			name:         "CASE WHEN with COUNT and ROUND",
			sql:          "SELECT ROUND(COUNT(CASE WHEN `duration` IS NOT NULL THEN 1 END) * 100.0 / COUNT(*), 2) AS `exist_ratio`, $__timeGroup(timestamp,$__interval) AS `time` FROM `apm`.`traces_span` WHERE $__timeFilter(`timestamp`) GROUP BY `time` ORDER BY `time` ASC",
			wantHasAgg:   true,
			wantIsSelect: true,
		},
		{
			name:         "AVG with timeGroup",
			sql:          "SELECT AVG(`duration`) AS `avg`, $__timeGroup(timestamp,$__interval) AS `time` FROM `apm`.`traces_span` WHERE $__timeFilter(`timestamp`) GROUP BY `time` ORDER BY `time` ASC",
			wantHasAgg:   true,
			wantIsSelect: true,
		},
		{
			name:         "Simple COUNT with timeFilter",
			sql:          "SELECT COUNT(*) AS `cnt` FROM `apm`.`traces_span` WHERE (`span_name` = 'GET /backend/detail') AND $__timeFilter(`timestamp`)",
			wantHasAgg:   true,
			wantIsSelect: true,
		},
		{
			name:         "CTE with CROSS JOIN ratio",
			sql:          "WITH `total` AS (SELECT COUNT(*) AS `total_count` FROM `apm`.`traces_span` WHERE $__timeFilter(`timestamp`)), `value_counts` AS (SELECT ANY_VALUE(`span_kind`) AS `span_kind`, COUNT(*) AS `count` FROM `apm`.`traces_span` WHERE (`span_kind` = 'SPAN_KIND_SERVER') AND $__timeFilter(`timestamp`)) SELECT vc.`span_kind`, vc.`count` AS `count`, ROUND(vc.`count` * 100.0 / t.`total_count`, 2) AS `ratio` FROM `value_counts` vc CROSS JOIN `total` t ORDER BY vc.`count` DESC;",
			wantHasAgg:   true, // CTE has aggregate functions
			wantIsSelect: true,
		},
		// Non-aggregate queries - should not skip check
		{
			name:         "Simple SELECT *",
			sql:          "SELECT * FROM users",
			wantHasAgg:   false,
			wantIsSelect: true,
		},
		{
			name:         "SELECT with columns",
			sql:          "SELECT id, name, email FROM users",
			wantHasAgg:   false,
			wantIsSelect: true,
		},
		{
			name:         "SELECT with WHERE",
			sql:          "SELECT * FROM users WHERE status = 'active'",
			wantHasAgg:   false,
			wantIsSelect: true,
		},
		{
			name:         "SELECT with JOIN",
			sql:          "SELECT u.name, o.amount FROM users u JOIN orders o ON u.id = o.user_id",
			wantHasAgg:   false,
			wantIsSelect: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := AnalyzeSQL(tt.sql)
			if err != nil {
				t.Fatalf("AnalyzeSQL() error = %v", err)
			}
			if result.HasTopAgg != tt.wantHasAgg {
				t.Errorf("name: %s, HasTopAgg = %v, want %v", tt.name, result.HasTopAgg, tt.wantHasAgg)
			}
			if result.IsSelectLike != tt.wantIsSelect {
				t.Errorf("IsSelectLike = %v, want %v", result.IsSelectLike, tt.wantIsSelect)
			}
		})
	}
}
// TestAnalyzeSQL_SubqueryWithAggregate verifies that aggregates buried in
// subqueries do not mark the outer query as aggregating: only the
// top-level select list (or a CTE) counts.
func TestAnalyzeSQL_SubqueryWithAggregate(t *testing.T) {
	// Aggregate in subquery should NOT skip check for main query
	tests := []struct {
		name       string
		sql        string
		wantHasAgg bool
	}{
		{
			name:       "Aggregate in subquery only",
			sql:        "SELECT * FROM (SELECT user_id, COUNT(*) as cnt FROM orders GROUP BY user_id) t",
			wantHasAgg: false, // top-level has no aggregate
		},
		{
			name:       "Aggregate in WHERE subquery",
			sql:        "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders GROUP BY user_id HAVING COUNT(*) > 5)",
			wantHasAgg: false, // top-level has no aggregate
		},
		{
			name:       "Both top-level and subquery aggregates",
			sql:        "SELECT COUNT(*) FROM (SELECT user_id FROM orders GROUP BY user_id) t",
			wantHasAgg: true, // top-level has aggregate
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := AnalyzeSQL(tt.sql)
			if err != nil {
				t.Fatalf("AnalyzeSQL() error = %v", err)
			}
			if result.HasTopAgg != tt.wantHasAgg {
				t.Errorf("HasTopAgg = %v, want %v", result.HasTopAgg, tt.wantHasAgg)
			}
		})
	}
}
// TestAnalyzeSQL_LimitQueries verifies extraction of the top-level LIMIT
// constant (including LIMIT 0/1 and LIMIT ... OFFSET), and that LimitConst
// stays nil when the query has no LIMIT.
func TestAnalyzeSQL_LimitQueries(t *testing.T) {
	tests := []struct {
		name         string
		sql          string
		wantLimit    *int64
		wantIsSelect bool
	}{
		{
			name:         "LIMIT 10",
			sql:          "SELECT * FROM users LIMIT 10",
			wantLimit:    ptr(int64(10)),
			wantIsSelect: true,
		},
		{
			name:         "LIMIT 100",
			sql:          "SELECT * FROM users LIMIT 100",
			wantLimit:    ptr(int64(100)),
			wantIsSelect: true,
		},
		{
			name:         "LIMIT 1000",
			sql:          "SELECT * FROM users LIMIT 1000",
			wantLimit:    ptr(int64(1000)),
			wantIsSelect: true,
		},
		{
			name:         "LIMIT with OFFSET",
			sql:          "SELECT * FROM users LIMIT 50 OFFSET 100",
			wantLimit:    ptr(int64(50)),
			wantIsSelect: true,
		},
		{
			name:         "No LIMIT",
			sql:          "SELECT * FROM users",
			wantLimit:    nil,
			wantIsSelect: true,
		},
		{
			name:         "LIMIT 0",
			sql:          "SELECT * FROM users LIMIT 0",
			wantLimit:    ptr(int64(0)),
			wantIsSelect: true,
		},
		{
			name:         "LIMIT 1",
			sql:          "SELECT * FROM users LIMIT 1",
			wantLimit:    ptr(int64(1)),
			wantIsSelect: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := AnalyzeSQL(tt.sql)
			if err != nil {
				t.Fatalf("AnalyzeSQL() error = %v", err)
			}
			if result.IsSelectLike != tt.wantIsSelect {
				t.Errorf("IsSelectLike = %v, want %v", result.IsSelectLike, tt.wantIsSelect)
			}
			if tt.wantLimit == nil {
				if result.LimitConst != nil {
					t.Errorf("LimitConst = %v, want nil", *result.LimitConst)
				}
			} else {
				if result.LimitConst == nil {
					t.Errorf("LimitConst = nil, want %v", *tt.wantLimit)
				} else if *result.LimitConst != *tt.wantLimit {
					t.Errorf("LimitConst = %v, want %v", *result.LimitConst, *tt.wantLimit)
				}
			}
		})
	}
}
// TestAnalyzeSQL_UnionQueries verifies set-operation handling: HasTopAgg
// is set only when every UNION branch aggregates, and the outermost LIMIT
// of a UNION is picked up.
func TestAnalyzeSQL_UnionQueries(t *testing.T) {
	tests := []struct {
		name       string
		sql        string
		wantHasAgg bool
		wantLimit  *int64
	}{
		{
			name:       "UNION without aggregate",
			sql:        "SELECT id, name FROM users UNION SELECT id, name FROM admins",
			wantHasAgg: false,
			wantLimit:  nil,
		},
		{
			name:       "UNION ALL without aggregate",
			sql:        "SELECT * FROM users UNION ALL SELECT * FROM admins",
			wantHasAgg: false,
			wantLimit:  nil,
		},
		{
			name:       "UNION with aggregate in all branches",
			sql:        "SELECT COUNT(*) FROM users UNION SELECT COUNT(*) FROM admins",
			wantHasAgg: true,
			wantLimit:  nil,
		},
		{
			name:       "UNION with aggregate in one branch only",
			sql:        "SELECT COUNT(*) FROM users UNION SELECT id FROM admins",
			wantHasAgg: false, // not all branches have aggregate
			wantLimit:  nil,
		},
		{
			name:       "UNION with outer LIMIT",
			sql:        "SELECT * FROM users UNION SELECT * FROM admins LIMIT 100",
			wantHasAgg: false,
			wantLimit:  ptr(int64(100)),
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := AnalyzeSQL(tt.sql)
			if err != nil {
				t.Fatalf("AnalyzeSQL() error = %v", err)
			}
			if result.HasTopAgg != tt.wantHasAgg {
				t.Errorf("HasTopAgg = %v, want %v", result.HasTopAgg, tt.wantHasAgg)
			}
			if tt.wantLimit == nil {
				if result.LimitConst != nil {
					t.Errorf("LimitConst = %v, want nil", *result.LimitConst)
				}
			} else {
				if result.LimitConst == nil {
					t.Errorf("LimitConst = nil, want %v", *tt.wantLimit)
				} else if *result.LimitConst != *tt.wantLimit {
					t.Errorf("LimitConst = %v, want %v", *result.LimitConst, *tt.wantLimit)
				}
			}
		})
	}
}
// TestAnalyzeSQL_NonSelectStatements verifies that non-SELECT statements
// (SHOW, DESCRIBE, ...) are classified as not SELECT-like; statements the
// parser rejects outright are tolerated.
func TestAnalyzeSQL_NonSelectStatements(t *testing.T) {
	tests := []struct {
		name         string
		sql          string
		wantIsSelect bool
	}{
		{
			name:         "SHOW DATABASES",
			sql:          "SHOW DATABASES",
			wantIsSelect: false,
		},
		{
			name:         "SHOW TABLES",
			sql:          "SHOW TABLES",
			wantIsSelect: false,
		},
		{
			name:         "DESCRIBE table",
			sql:          "DESCRIBE users",
			wantIsSelect: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := AnalyzeSQL(tt.sql)
			if err != nil {
				// Some statements may not be parseable, which is fine
				return
			}
			if result.IsSelectLike != tt.wantIsSelect {
				t.Errorf("IsSelectLike = %v, want %v", result.IsSelectLike, tt.wantIsSelect)
			}
		})
	}
}
// TestNeedsRowCountCheck verifies the decision rules: aggregates and
// LIMIT <= maxRows skip the probe; missing LIMIT, LIMIT > maxRows, and
// subquery-only aggregation still require it. directReject is never set.
func TestNeedsRowCountCheck(t *testing.T) {
	maxRows := 500
	tests := []struct {
		name          string
		sql           string
		wantNeedCheck bool
		wantReject    bool
	}{
		// Should skip check (needsCheck = false)
		{
			name:          "Aggregate COUNT(*)",
			sql:           "SELECT COUNT(*) FROM users",
			wantNeedCheck: false,
			wantReject:    false,
		},
		{
			name:          "Aggregate SUM",
			sql:           "SELECT SUM(amount) FROM orders",
			wantNeedCheck: false,
			wantReject:    false,
		},
		{
			name:          "Aggregate with GROUP BY",
			sql:           "SELECT user_id, COUNT(*) FROM orders GROUP BY user_id",
			wantNeedCheck: false,
			wantReject:    false,
		},
		{
			name:          "LIMIT equal to max",
			sql:           "SELECT * FROM users LIMIT 500",
			wantNeedCheck: false,
			wantReject:    false,
		},
		{
			name:          "LIMIT less than max",
			sql:           "SELECT * FROM users LIMIT 100",
			wantNeedCheck: false,
			wantReject:    false,
		},
		{
			name:          "LIMIT 1",
			sql:           "SELECT * FROM users LIMIT 1",
			wantNeedCheck: false,
			wantReject:    false,
		},
		// LIMIT > maxRows still needs probe check (actual result might be smaller)
		{
			name:          "LIMIT exceeds max",
			sql:           "SELECT * FROM users LIMIT 1000",
			wantNeedCheck: true,
			wantReject:    false,
		},
		{
			name:          "LIMIT much larger than max",
			sql:           "SELECT * FROM users LIMIT 10000",
			wantNeedCheck: true,
			wantReject:    false,
		},
		// Should execute probe check (needsCheck = true)
		{
			name:          "No LIMIT no aggregate",
			sql:           "SELECT * FROM users",
			wantNeedCheck: true,
			wantReject:    false,
		},
		{
			name:          "SELECT with WHERE no LIMIT",
			sql:           "SELECT * FROM users WHERE status = 'active'",
			wantNeedCheck: true,
			wantReject:    false,
		},
		{
			name:          "SELECT with JOIN no LIMIT",
			sql:           "SELECT u.*, o.* FROM users u JOIN orders o ON u.id = o.user_id",
			wantNeedCheck: true,
			wantReject:    false,
		},
		{
			name:          "Aggregate in subquery only",
			sql:           "SELECT * FROM (SELECT user_id, COUNT(*) as cnt FROM orders GROUP BY user_id) t",
			wantNeedCheck: true,
			wantReject:    false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			needsCheck, directReject, _ := NeedsRowCountCheck(tt.sql, maxRows)
			if needsCheck != tt.wantNeedCheck {
				t.Errorf("needsCheck = %v, want %v", needsCheck, tt.wantNeedCheck)
			}
			if directReject != tt.wantReject {
				t.Errorf("directReject = %v, want %v", directReject, tt.wantReject)
			}
		})
	}
}
// TestNeedsRowCountCheck_DorisSpecificFunctions verifies that Doris-only
// aggregates (HLL, bitmap, approx/percentile, GROUP_CONCAT) are treated as
// aggregation and therefore skip the row-count probe.
func TestNeedsRowCountCheck_DorisSpecificFunctions(t *testing.T) {
	maxRows := 500
	tests := []struct {
		name          string
		sql           string
		wantNeedCheck bool
	}{
		// Doris HLL functions
		{
			name:          "HLL_UNION_AGG",
			sql:           "SELECT HLL_UNION_AGG(hll_col) FROM user_stats",
			wantNeedCheck: false,
		},
		{
			name:          "HLL_CARDINALITY",
			sql:           "SELECT HLL_CARDINALITY(hll_col) FROM user_stats",
			wantNeedCheck: false,
		},
		// Doris Bitmap functions
		{
			name:          "BITMAP_UNION_COUNT",
			sql:           "SELECT BITMAP_UNION_COUNT(bitmap_col) FROM user_tags",
			wantNeedCheck: false,
		},
		{
			name:          "BITMAP_UNION",
			sql:           "SELECT BITMAP_UNION(bitmap_col) FROM user_tags GROUP BY category",
			wantNeedCheck: false,
		},
		// Other Doris aggregate functions
		{
			name:          "APPROX_COUNT_DISTINCT",
			sql:           "SELECT APPROX_COUNT_DISTINCT(user_id) FROM events",
			wantNeedCheck: false,
		},
		{
			name:          "GROUP_CONCAT",
			sql:           "SELECT GROUP_CONCAT(name) FROM users GROUP BY department",
			wantNeedCheck: false,
		},
		{
			name:          "PERCENTILE_APPROX",
			sql:           "SELECT PERCENTILE_APPROX(latency, 0.99) FROM requests",
			wantNeedCheck: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			needsCheck, _, _ := NeedsRowCountCheck(tt.sql, maxRows)
			if needsCheck != tt.wantNeedCheck {
				t.Errorf("needsCheck = %v, want %v (should skip check for Doris aggregate functions)", needsCheck, tt.wantNeedCheck)
			}
		})
	}
}
// TestNeedsRowCountCheck_ComplexQueries covers CTEs, joins, subqueries and
// DISTINCT: an aggregate in the query (or a LIMIT within bounds) skips the
// probe; everything else still requires it.
func TestNeedsRowCountCheck_ComplexQueries(t *testing.T) {
	const maxRows = 500

	cases := []struct {
		name       string
		sql        string
		wantCheck  bool
		wantReject bool
	}{
		{
			name:      "CTE with aggregate",
			sql:       "WITH user_counts AS (SELECT user_id, COUNT(*) as cnt FROM orders GROUP BY user_id) SELECT * FROM user_counts",
			wantCheck: false, // aggregate inside the CTE, probe not needed
		},
		{
			name:      "Complex JOIN with aggregate",
			sql:       "SELECT u.department, COUNT(*) FROM users u JOIN orders o ON u.id = o.user_id GROUP BY u.department",
			wantCheck: false, // aggregated output
		},
		{
			name:      "Nested subquery",
			sql:       "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders WHERE amount > 100)",
			wantCheck: true,
		},
		{
			name:      "DISTINCT query",
			sql:       "SELECT DISTINCT category FROM products",
			wantCheck: true, // DISTINCT is not an aggregate
		},
		{
			name:      "ORDER BY with LIMIT",
			sql:       "SELECT * FROM users ORDER BY created_at DESC LIMIT 100",
			wantCheck: false, // LIMIT within bounds
		},
		{
			name:      "Multiple aggregates in single query",
			sql:       "SELECT COUNT(*), SUM(amount), AVG(amount), MIN(amount), MAX(amount) FROM orders",
			wantCheck: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotCheck, gotReject, _ := NeedsRowCountCheck(tc.sql, maxRows)
			if gotCheck != tc.wantCheck {
				t.Errorf("needsCheck = %v, want %v", gotCheck, tc.wantCheck)
			}
			if gotReject != tc.wantReject {
				t.Errorf("directReject = %v, want %v", gotReject, tc.wantReject)
			}
		})
	}
}
// TestNeedsRowCountCheck_EdgeCases exercises boundary and formatting cases:
// LIMIT 0, LIMIT just above the cap, trailing semicolons, surrounding
// whitespace and keyword casing.
func TestNeedsRowCountCheck_EdgeCases(t *testing.T) {
	const maxRows = 500

	cases := []struct {
		name       string
		sql        string
		wantCheck  bool
		wantReject bool
	}{
		{name: "Empty-ish LIMIT 0", sql: "SELECT * FROM users LIMIT 0", wantCheck: false},
		// 501 > 500: the probe is still required at the boundary.
		{name: "LIMIT at boundary", sql: "SELECT * FROM users LIMIT 501", wantCheck: true},
		{name: "SELECT with trailing semicolon", sql: "SELECT * FROM users;", wantCheck: true},
		{name: "SELECT with extra whitespace", sql: " SELECT * FROM users ", wantCheck: true},
		{name: "Lowercase keywords", sql: "select count(*) from users", wantCheck: false},
		{name: "Mixed case keywords", sql: "Select Count(*) From users", wantCheck: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotCheck, gotReject, _ := NeedsRowCountCheck(tc.sql, maxRows)
			if gotCheck != tc.wantCheck {
				t.Errorf("needsCheck = %v, want %v", gotCheck, tc.wantCheck)
			}
			if gotReject != tc.wantReject {
				t.Errorf("directReject = %v, want %v", gotReject, tc.wantReject)
			}
		})
	}
}
// TestNeedsRowCountCheck_DifferentMaxRows checks that the decision respects
// the caller-supplied cap rather than a hard-coded one: the same LIMIT flips
// between "skip" and "probe" as maxRows moves across it.
func TestNeedsRowCountCheck_DifferentMaxRows(t *testing.T) {
	cases := []struct {
		name       string
		sql        string
		maxRows    int
		wantCheck  bool
		wantReject bool
	}{
		// LIMIT above the cap: probe required.
		{name: "LIMIT 100 with maxRows 50", sql: "SELECT * FROM users LIMIT 100", maxRows: 50, wantCheck: true},
		{name: "LIMIT 100 with maxRows 100", sql: "SELECT * FROM users LIMIT 100", maxRows: 100, wantCheck: false},
		{name: "LIMIT 100 with maxRows 200", sql: "SELECT * FROM users LIMIT 100", maxRows: 200, wantCheck: false},
		{name: "No LIMIT with maxRows 1000", sql: "SELECT * FROM users", maxRows: 1000, wantCheck: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotCheck, gotReject, _ := NeedsRowCountCheck(tc.sql, tc.maxRows)
			if gotCheck != tc.wantCheck {
				t.Errorf("needsCheck = %v, want %v", gotCheck, tc.wantCheck)
			}
			if gotReject != tc.wantReject {
				t.Errorf("directReject = %v, want %v", gotReject, tc.wantReject)
			}
		})
	}
}
// TestSummary_SkipProbeCheck logs a human-readable summary of which SQL
// shapes skip the probe query and which still require it. It never fails;
// unexpected classifications are flagged in the log output.
func TestSummary_SkipProbeCheck(t *testing.T) {
	const maxRows = 500

	type sample struct {
		category string
		sql      string
	}

	shouldSkip := []sample{
		// Standard aggregates
		{"Aggregate - COUNT(*)", "SELECT COUNT(*) FROM users"},
		{"Aggregate - COUNT(col)", "SELECT COUNT(id) FROM users"},
		{"Aggregate - SUM", "SELECT SUM(amount) FROM orders"},
		{"Aggregate - AVG", "SELECT AVG(price) FROM products"},
		{"Aggregate - MIN", "SELECT MIN(created_at) FROM logs"},
		{"Aggregate - MAX", "SELECT MAX(score) FROM results"},
		{"Aggregate - GROUP BY", "SELECT user_id, COUNT(*) FROM orders GROUP BY user_id"},
		{"Aggregate - HAVING", "SELECT user_id, SUM(amount) FROM orders GROUP BY user_id HAVING SUM(amount) > 1000"},
		// Doris-specific aggregates
		{"Doris - HLL_UNION_AGG", "SELECT HLL_UNION_AGG(hll_col) FROM stats"},
		{"Doris - BITMAP_UNION_COUNT", "SELECT BITMAP_UNION_COUNT(bitmap_col) FROM tags"},
		{"Doris - APPROX_COUNT_DISTINCT", "SELECT APPROX_COUNT_DISTINCT(user_id) FROM events"},
		{"Doris - GROUP_CONCAT", "SELECT GROUP_CONCAT(name) FROM users GROUP BY dept"},
		// LIMIT within bounds
		{"LIMIT - Equal to max", "SELECT * FROM users LIMIT 500"},
		{"LIMIT - Less than max", "SELECT * FROM users LIMIT 100"},
		{"LIMIT - With OFFSET", "SELECT * FROM users LIMIT 100 OFFSET 50"},
		{"LIMIT - Value 1", "SELECT * FROM users LIMIT 1"},
		{"LIMIT - Value 0", "SELECT * FROM users LIMIT 0"},
	}

	t.Log("=== SQL patterns that SKIP probe check (no extra query needed) ===")
	for _, s := range shouldSkip {
		needsCheck, _, _ := NeedsRowCountCheck(s.sql, maxRows)
		status := "✓ SKIP"
		if needsCheck {
			status = "✗ NEEDS CHECK (unexpected)"
		}
		t.Logf(" %s: %s\n SQL: %s", status, s.category, s.sql)
	}

	mustCheck := []sample{
		{"No LIMIT - Simple SELECT", "SELECT * FROM users"},
		{"No LIMIT - With WHERE", "SELECT * FROM users WHERE status = 'active'"},
		{"No LIMIT - With JOIN", "SELECT u.*, o.* FROM users u JOIN orders o ON u.id = o.user_id"},
		{"No LIMIT - Subquery with agg", "SELECT * FROM (SELECT user_id, COUNT(*) FROM orders GROUP BY user_id) t"},
		{"No LIMIT - DISTINCT", "SELECT DISTINCT category FROM products"},
		{"LIMIT > max (actual may be smaller)", "SELECT * FROM users LIMIT 1000"},
		{"LIMIT >> max", "SELECT * FROM users LIMIT 10000"},
	}

	t.Log("\n=== SQL patterns that NEED probe check ===")
	for _, s := range mustCheck {
		needsCheck, _, _ := NeedsRowCountCheck(s.sql, maxRows)
		status := "✓ NEEDS CHECK"
		if !needsCheck {
			status = "✗ SKIP (unexpected)"
		}
		t.Logf(" %s: %s\n SQL: %s", status, s.category, s.sql)
	}
}
// ptr returns a pointer to a fresh copy of v. Handy for populating
// optional *int64 fields in test fixtures.
func ptr(v int64) *int64 {
	out := v
	return &out
}

View File

@@ -73,35 +73,44 @@ func (d *Doris) QueryTimeseries(ctx context.Context, query *QueryParam) ([]types
}
// CheckMaxQueryRows guards against result sets larger than the configured
// maximum. Static SQL analysis first decides whether any check is needed at
// all (aggregates and in-bounds LIMITs are trivially safe); only the
// remaining queries pay for a probe query (LIMIT maxRows+1), which is
// cheaper than a full COUNT(*).
func (d *Doris) CheckMaxQueryRows(ctx context.Context, database, sql string) error {
	limit := d.MaxQueryRows
	if limit == 0 {
		limit = 500 // default cap when not configured
	}

	trimmed := strings.TrimSpace(strings.TrimSuffix(strings.TrimSpace(sql), ";"))

	// Skip the probe entirely when analysis proves the query cannot
	// exceed the cap.
	if needsCheck, _, _ := NeedsRowCountCheck(trimmed, limit); !needsCheck {
		return nil
	}

	// Otherwise run the cheap probe query.
	return d.probeRowCount(ctx, database, trimmed, limit)
}
// probeRowCount uses threshold probing to check row count
// It reads at most maxRows+1 rows, which is O(maxRows) instead of O(totalRows) for COUNT(*)
// Doris optimizes LIMIT queries by stopping scan early once limit is reached
func (d *Doris) probeRowCount(ctx context.Context, database, sql string, maxRows int) error {
timeoutCtx, cancel := d.createTimeoutContext(ctx)
defer cancel()
cleanedSQL := strings.ReplaceAll(sql, ";", "")
checkQuery := fmt.Sprintf("SELECT COUNT(*) as count FROM (%s) AS subquery;", cleanedSQL)
// Probe SQL: only need to check if exceeds threshold, not actual data
probeSQL := fmt.Sprintf("SELECT 1 FROM (%s) AS __probe_chk LIMIT %d", sql, maxRows+1)
// 执行计数查询
results, err := d.ExecQuery(timeoutCtx, database, checkQuery)
results, err := d.ExecQuery(timeoutCtx, database, probeSQL)
if err != nil {
return err
}
if len(results) > 0 {
if count, exists := results[0]["count"]; exists {
v, err := sqlbase.ParseFloat64Value(count)
if err != nil {
return err
}
maxQueryRows := d.MaxQueryRows
if maxQueryRows == 0 {
maxQueryRows = 500
}
if v > float64(maxQueryRows) {
return fmt.Errorf("query result rows count %d exceeds the maximum limit %d", int(v), maxQueryRows)
}
}
// If returned rows > maxRows, it exceeds the limit
if len(results) > maxRows {
return fmt.Errorf("query result rows count exceeds the maximum limit %d", maxRows)
}
return nil

12
go.mod
View File

@@ -33,6 +33,7 @@ require (
github.com/jinzhu/copier v0.4.0
github.com/json-iterator/go v1.1.12
github.com/koding/multiconfig v0.0.0-20171124222453-69c27309b2d7
github.com/larksuite/oapi-sdk-go/v3 v3.5.1
github.com/lib/pq v1.10.9
github.com/mailru/easyjson v0.7.7
github.com/mattn/go-isatty v0.0.19
@@ -42,6 +43,7 @@ require (
github.com/opensearch-project/opensearch-go/v2 v2.3.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pelletier/go-toml/v2 v2.0.8
github.com/pingcap/tidb/pkg/parser v0.0.0-20260120034856-e15515e804da
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.20.5
github.com/prometheus/common v0.60.1
@@ -101,6 +103,9 @@ require (
github.com/jcmturner/gofork v1.7.6 // indirect
github.com/jcmturner/gokrb5/v8 v8.4.4 // indirect
github.com/jcmturner/rpc/v2 v2.0.3 // indirect
github.com/pingcap/errors v0.11.5-0.20250523034308-74f78ae071ee // indirect
github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect
github.com/pingcap/log v1.1.0 // indirect
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rogpeppe/go-internal v1.13.1 // indirect
@@ -108,10 +113,13 @@ require (
github.com/valyala/fastrand v1.1.0 // indirect
github.com/valyala/histogram v1.2.0 // indirect
github.com/yuin/gopher-lua v1.1.1 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
golang.org/x/sync v0.18.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
modernc.org/libc v1.22.5 // indirect
modernc.org/mathutil v1.5.0 // indirect
modernc.org/mathutil v1.6.0 // indirect
modernc.org/memory v1.5.0 // indirect
modernc.org/sqlite v1.23.1 // indirect
)
@@ -135,7 +143,7 @@ require (
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/go-sql-driver/mysql v1.6.0
github.com/go-sql-driver/mysql v1.7.1
github.com/goccy/go-json v0.10.2 // indirect
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/grafana/regexp v0.0.0-20221122212121-6b5c0a4cb7fd // indirect

39
go.sum
View File

@@ -101,6 +101,7 @@ github.com/aws/aws-sdk-go-v2/service/sso v1.12.10/go.mod h1:ouy2P4z6sJN70fR3ka3w
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.10/go.mod h1:AFvkxc8xfBe8XA+5St5XIHHrQQtkxqrRincx4hmMHOk=
github.com/aws/aws-sdk-go-v2/service/sts v1.19.0/go.mod h1:BgQOMsg8av8jset59jelyPW7NoZcZXLVpDsXunGDrk8=
github.com/aws/smithy-go v1.13.5/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bitly/go-simplejson v0.5.1 h1:xgwPbetQScXt1gh9BmoJ6j9JMr3TElvuIyjR8pgdoow=
@@ -195,8 +196,9 @@ github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
@@ -243,6 +245,7 @@ github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORR
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grafana/regexp v0.0.0-20221122212121-6b5c0a4cb7fd h1:PpuIBO5P3e9hpqBD0O/HjhShYuM6XE0i/lbE6J94kww=
github.com/grafana/regexp v0.0.0-20221122212121-6b5c0a4cb7fd/go.mod h1:M5qHK+eWfAv8VR/265dIuEpL3fNfeC21tXXp9itM24A=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
@@ -315,6 +318,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/larksuite/oapi-sdk-go/v3 v3.5.1 h1:gX4dz92YU70inuIX+ug+PBe64eHToIN9rHB4Vupv5Eg=
github.com/larksuite/oapi-sdk-go/v3 v3.5.1/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI=
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
@@ -361,9 +366,19 @@ github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZ
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU=
github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pingcap/errors v0.11.0/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
github.com/pingcap/errors v0.11.5-0.20250523034308-74f78ae071ee h1:/IDPbpzkzA97t1/Z1+C3KlxbevjMeaI6BQYxvivu4u8=
github.com/pingcap/errors v0.11.5-0.20250523034308-74f78ae071ee/go.mod h1:X2r9ueLEUZgtx2cIogM0v4Zj5uvvzhuuiu7Pn8HzMPg=
github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 h1:tdMsjOqUR7YXHoBitzdebTvOjs/swniBTOLy5XiMtuE=
github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86/go.mod h1:exzhVYca3WRtd6gclGNErRWb1qEgff3LYta0LvRmON4=
github.com/pingcap/log v1.1.0 h1:ELiPxACz7vdo1qAvvaWJg1NrYFoY6gqAh/+Uo6aXdD8=
github.com/pingcap/log v1.1.0/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4=
github.com/pingcap/tidb/pkg/parser v0.0.0-20260120034856-e15515e804da h1:PhkRZgMWdq9kTsu7vtVbcDs+SBXjHfFj84027WVZCzI=
github.com/pingcap/tidb/pkg/parser v0.0.0-20260120034856-e15515e804da/go.mod h1:oHE+ub2QaDERd+UNHe4z2BhFV2jZrm7VNOe6atR9AF4=
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU=
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -392,7 +407,6 @@ github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5X
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/redis/go-redis/v9 v9.0.2 h1:BA426Zqe/7r56kCcvxYLWe1mkaz71LKF77GwgFzSxfE=
github.com/redis/go-redis/v9 v9.0.2/go.mod h1:/xDTe9EF1LM61hek62Poq2nzQSGj0xSrEtEHbBQevps=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
@@ -467,13 +481,24 @@ go.opentelemetry.io/otel v1.32.0 h1:WnBN+Xjcteh0zdk01SVqV55d/m62NJLJdIyb4y/WO5U=
go.opentelemetry.io/otel v1.32.0/go.mod h1:00DCVSB0RQcnzlwyTfqtxSm+DRr9hpYrHjNGiBHVQIg=
go.opentelemetry.io/otel/trace v1.32.0 h1:WIC9mYrXf8TmY/EXuULKc8hR17vE+Hjv2cssQDe03fM=
go.opentelemetry.io/otel/trace v1.32.0/go.mod h1:+i4rkvCraA+tG6AzwloGaCtkx53Fa+L+V8e9a7YvhT8=
go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q=
go.uber.org/automaxprocs v1.5.2 h1:2LxUOGiR3O6tw8ui5sZa2LAaHnsviZdVOUZw4fvbnME=
go.uber.org/automaxprocs v1.5.2/go.mod h1:eRbA25aqJrxAbsLO0xy5jVwPt7FQnRgjW+efnwa1WM0=
go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.19.0/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
@@ -507,6 +532,7 @@ golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
@@ -623,6 +649,8 @@ golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200509030707-2212a7e161a5/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
@@ -668,6 +696,9 @@ gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkp
gopkg.in/ini.v1 v1.56.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI=
gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
@@ -693,8 +724,8 @@ honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWh
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE=
modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY=
modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4=
modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo=
modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM=

View File

@@ -6,7 +6,9 @@ import (
"fmt"
"github.com/ccfos/nightingale/v6/pkg/ctx"
"gorm.io/gorm"
"github.com/ccfos/nightingale/v6/pkg/poster"
"gorm.io/gorm"
)
// 执行状态常量
@@ -94,6 +96,9 @@ func (e *EventPipelineExecution) GetInputsSnapshot() (map[string]string, error)
// CreateEventPipelineExecution creates an execution record. On non-center
// nodes the record is forwarded to the center over HTTP instead of being
// written to the local database.
func CreateEventPipelineExecution(c *ctx.Context, execution *EventPipelineExecution) error {
	if !c.IsCenter {
		// Edge node: relay to the center's ingest endpoint.
		return poster.PostByUrls(c, "/v1/n9e/event-pipeline-execution", execution)
	}
	return DB(c).Create(execution).Error
}

View File

@@ -44,6 +44,7 @@ type QueryConfig struct {
Exp string `json:"exp"`
WriteDatasourceId int64 `json:"write_datasource_id"`
Delay int `json:"delay"`
WritebackEnabled bool `json:"writeback_enabled"` // 是否写入与查询数据源相同的数据源
}
type Query struct {

345
pkg/feishu/feishu.go Normal file
View File

@@ -0,0 +1,345 @@
package feishu
import (
"bytes"
"context"
"fmt"
"net/url"
"strings"
"sync"
"time"
"github.com/ccfos/nightingale/v6/storage"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/toolkits/pkg/logger"
lark "github.com/larksuite/oapi-sdk-go/v3"
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
larkauthen "github.com/larksuite/oapi-sdk-go/v3/service/authen/v1"
larkcontact "github.com/larksuite/oapi-sdk-go/v3/service/contact/v3"
)
// defaultAuthURL is the feishu OAuth authorization endpoint used when the
// config does not override AuthURL.
const defaultAuthURL = "https://accounts.feishu.cn/open-apis/authen/v1/authorize"

// SsoTypeName identifies this SSO provider type.
const SsoTypeName = "feishu"
// SsoClient is the feishu SSO provider. The embedded RWMutex guards the
// mutable fields because Reload can swap the config and SDK client at
// runtime while handlers read them.
type SsoClient struct {
	Enable       bool
	FeiShuConfig *Config `json:"-"`
	// NOTE(review): storing a context in a struct is discouraged in Go;
	// confirm whether any caller actually relies on this field.
	Ctx    context.Context
	client *lark.Client // SDK v3 client; nil until Reload succeeds
	sync.RWMutex
}
// Config holds the feishu SSO settings loaded from system configuration.
type Config struct {
	Enable          bool     `json:"enable"`
	AuthURL         string   `json:"auth_url"`     // overrides the authorization endpoint; empty means defaultAuthURL
	DisplayName     string   `json:"display_name"` // label shown for this login option
	AppID           string   `json:"app_id"`
	AppSecret       string   `json:"app_secret"`
	RedirectURL     string   `json:"redirect_url"`     // OAuth callback URL registered with feishu
	UsernameField   string   `json:"username_field"`   // name, email, phone
	FeiShuEndpoint  string   `json:"feishu_endpoint"`  // feishu API endpoint; defaults to open.feishu.cn
	Proxy           string   `json:"proxy"`
	CoverAttributes bool     `json:"cover_attributes"`
	DefaultRoles    []string `json:"default_roles"`
}
type CallbackOutput struct {
Redirect string `json:"redirect"`
Msg string `json:"msg"`
AccessToken string `json:"accessToken"`
Username string `json:"Username"`
Nickname string `json:"Nickname"`
Phone string `yaml:"Phone"`
Email string `yaml:"Email"`
}
// wrapStateKey namespaces an OAuth state value for storage in Redis so it
// cannot collide with keys written by other subsystems.
func wrapStateKey(key string) string {
	const prefix = "n9e_feishu_oauth_"
	return prefix + key
}
// createClient builds a feishu SDK v3 client from the config.
//
// NOTE(review): when FeiShuEndpoint is set this mutates the package-global
// lark.FeishuBaseUrl, which affects every lark client in the process —
// confirm this is acceptable if multiple endpoints could ever be configured.
func (c *Config) createClient() (*lark.Client, error) {
	opts := []lark.ClientOptionFunc{
		lark.WithLogLevel(larkcore.LogLevelInfo),
		lark.WithEnableTokenCache(true), // let the SDK cache and refresh tokens itself
	}
	if c.FeiShuEndpoint != "" {
		lark.FeishuBaseUrl = c.FeiShuEndpoint
	}
	// Construct the v3 client; NewClient itself does not fail, hence the
	// always-nil error.
	client := lark.NewClient(
		c.AppID,
		c.AppSecret,
		opts...,
	)
	return client, nil
}
// New builds a feishu SSO client from cf. When SSO is disabled the returned
// client is inert (Enable is false and no SDK client is created).
func New(cf Config) *SsoClient {
	s := &SsoClient{}
	if cf.Enable {
		s.Reload(cf)
	}
	return s
}
// AuthCodeURL assembles the feishu authorization URL carrying app_id,
// redirect_uri and the opaque state token. It fails when no redirect URL
// is configured.
func (s *SsoClient) AuthCodeURL(state string) (string, error) {
	if s.FeiShuConfig.RedirectURL == "" {
		return "", errors.New("FeiShu OAuth RedirectURL is empty")
	}

	endpoint := defaultAuthURL
	if s.FeiShuConfig.AuthURL != "" {
		endpoint = s.FeiShuConfig.AuthURL
	}

	params := url.Values{}
	params.Set("app_id", s.FeiShuConfig.AppID)
	params.Set("state", state)
	params.Set("redirect_uri", s.FeiShuConfig.RedirectURL)

	// Append with '&' when the configured endpoint already carries a query.
	sep := "?"
	if strings.Contains(endpoint, "?") {
		sep = "&"
	}
	return endpoint + sep + params.Encode(), nil
}
// GetUserToken exchanges an OAuth authorization code for the user's access
// token and user_id using the feishu SDK v3 authen service.
//
// Returns (accessToken, userID, error); both strings are guaranteed
// non-empty when the error is nil.
func (s *SsoClient) GetUserToken(code string) (string, string, error) {
	if s.client == nil {
		return "", "", errors.New("feishu client is not initialized")
	}
	ctx := context.Background()
	// Build the token request: grant_type=authorization_code plus the code
	// received on the OAuth callback.
	req := larkauthen.NewCreateAccessTokenReqBuilder().
		Body(larkauthen.NewCreateAccessTokenReqBodyBuilder().
			GrantType("authorization_code").
			Code(code).
			Build()).
		Build()
	resp, err := s.client.Authen.AccessToken.Create(ctx, req)
	if err != nil {
		return "", "", fmt.Errorf("feishu get access token error: %w", err)
	}
	// Check the API-level status embedded in the response body.
	if !resp.Success() {
		return "", "", fmt.Errorf("feishu api error: code=%d, msg=%s", resp.Code, resp.Msg)
	}
	if resp.Data == nil {
		return "", "", errors.New("feishu api returned empty data")
	}
	// SDK response fields are pointers; flatten them defensively.
	userID := ""
	if resp.Data.UserId != nil {
		userID = *resp.Data.UserId
	}
	if userID == "" {
		return "", "", errors.New("feishu api returned empty user_id")
	}
	accessToken := ""
	if resp.Data.AccessToken != nil {
		accessToken = *resp.Data.AccessToken
	}
	if accessToken == "" {
		return "", "", errors.New("feishu api returned empty access_token")
	}
	return accessToken, userID, nil
}
// GetUserInfo fetches the user's profile by user_id using the feishu SDK v3
// contact service.
//
// Note: the SDK manages its own token internally, so the user access token
// is not required here.
func (s *SsoClient) GetUserInfo(userID string) (*larkcontact.GetUserRespData, error) {
	if s.client == nil {
		return nil, errors.New("feishu client is not initialized")
	}
	ctx := context.Background()
	// Look the user up by user_id (as opposed to open_id/union_id).
	req := larkcontact.NewGetUserReqBuilder().
		UserId(userID).
		UserIdType(larkcontact.UserIdTypeUserId).
		Build()
	resp, err := s.client.Contact.User.Get(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("feishu get user detail error: %w", err)
	}
	// Check the API-level status embedded in the response body.
	if !resp.Success() {
		return nil, fmt.Errorf("feishu api error: code=%d, msg=%s", resp.Code, resp.Msg)
	}
	// Guarantee a usable User for callers.
	if resp.Data == nil || resp.Data.User == nil {
		return nil, errors.New("feishu api returned empty user data")
	}
	return resp.Data, nil
}
// Reload swaps in a new configuration under the write lock and, when the
// integration is enabled with credentials present, rebuilds the SDK client.
// On client-creation failure the error is logged and any previously created
// client is kept as-is.
func (s *SsoClient) Reload(feishuConfig Config) {
	s.Lock()
	defer s.Unlock()
	s.Enable = feishuConfig.Enable
	s.FeiShuConfig = &feishuConfig
	// Rebuild the SDK client with the new credentials.
	if feishuConfig.Enable && feishuConfig.AppID != "" && feishuConfig.AppSecret != "" {
		client, err := feishuConfig.createClient()
		if err != nil {
			logger.Errorf("create feishu client error: %v", err)
		} else {
			s.client = client
		}
	}
}
// GetDisplayName reports the configured display name for the feishu login
// option, or "" when feishu SSO is disabled.
func (s *SsoClient) GetDisplayName() string {
	s.RLock()
	defer s.RUnlock()

	if s.Enable {
		return s.FeiShuConfig.DisplayName
	}
	return ""
}
// Authorize starts the OAuth flow: it stores the post-login redirect target
// in Redis under a fresh state token and returns the feishu authorization
// URL the user should be sent to.
//
// Improvements: the auth URL is now built before the state is persisted, so
// a configuration error (e.g. empty RedirectURL) no longer leaves an
// orphaned Redis key; the redundant time.Duration(...) conversion is
// replaced with the idiomatic duration literal (same 5-minute TTL).
func (s *SsoClient) Authorize(redis storage.Redis, redirect string) (string, error) {
	state := uuid.New().String()

	// Build the URL first; fail fast on configuration errors.
	s.RLock()
	authURL, err := s.AuthCodeURL(state)
	s.RUnlock()
	if err != nil {
		return "", err
	}

	// The state expires after 5 minutes; the callback must arrive before then.
	ctx := context.Background()
	if err := redis.Set(ctx, wrapStateKey(state), redirect, 5*time.Minute).Err(); err != nil {
		return "", err
	}

	return authURL, nil
}
// Callback completes the OAuth flow: it exchanges the code for an access
// token and user_id, loads the user profile, resolves the redirect stored
// by Authorize, and maps the profile onto a CallbackOutput according to the
// UsernameField configuration (name, phone, or email — the default).
//
// Fix: deleteRedirect was previously called even when redis was nil, which
// panics; the delete is now guarded by the same redis != nil check as the
// fetch above it.
func (s *SsoClient) Callback(redis storage.Redis, ctx context.Context, code, state string) (*CallbackOutput, error) {
	// Exchange the authorization code for the user's token and user_id.
	accessToken, userID, err := s.GetUserToken(code)
	if err != nil {
		return nil, fmt.Errorf("feishu GetUserToken error: %s", err)
	}

	// Load the full user profile.
	userData, err := s.GetUserInfo(userID)
	if err != nil {
		return nil, fmt.Errorf("feishu GetUserInfo error: %s", err)
	}

	// Resolve (and then clean up) the redirect target stored by Authorize.
	redirect := ""
	if redis != nil {
		redirect, err = fetchRedirect(redis, ctx, state)
		if err != nil {
			logger.Errorf("get redirect err:%v code:%s state:%s", err, code, state)
		}
		// Best-effort delete of the one-time state key.
		if err := deleteRedirect(redis, ctx, state); err != nil {
			logger.Errorf("delete redirect err:%v code:%s state:%s", err, code, state)
		}
	}
	if redirect == "" {
		redirect = "/"
	}

	var callbackOutput CallbackOutput

	// Defensive re-check; GetUserInfo already guarantees a non-nil User.
	if userData == nil || userData.User == nil {
		return nil, fmt.Errorf("feishu GetUserInfo failed, user data is nil")
	}
	user := userData.User
	logger.Debugf("feishu get user info userID %s result %+v", userID, user)

	// SDK profile fields are pointers; flatten them defensively.
	username := ""
	if user.UserId != nil {
		username = *user.UserId
	}
	if username == "" {
		return nil, errors.New("feishu user_id is empty")
	}
	nickname := ""
	if user.Name != nil {
		nickname = *user.Name
	}
	phone := ""
	if user.Mobile != nil {
		phone = *user.Mobile
	}
	email := ""
	if user.Email != nil {
		email = *user.Email
	}
	// Fall back to the enterprise email when the personal one is absent.
	if email == "" && user.EnterpriseEmail != nil {
		email = *user.EnterpriseEmail
	}

	callbackOutput.Redirect = redirect
	callbackOutput.AccessToken = accessToken

	// Pick the n9e username per configuration; the chosen field must be set.
	switch s.FeiShuConfig.UsernameField {
	case "name":
		if nickname == "" {
			return nil, errors.New("feishu user name is empty")
		}
		callbackOutput.Username = nickname
	case "phone":
		if phone == "" {
			return nil, errors.New("feishu user phone is empty")
		}
		callbackOutput.Username = phone
	default:
		if email == "" {
			return nil, errors.New("feishu user email is empty")
		}
		callbackOutput.Username = email
	}

	callbackOutput.Nickname = nickname
	callbackOutput.Email = email
	callbackOutput.Phone = phone
	return &callbackOutput, nil
}
// fetchRedirect loads the redirect target stored for the given OAuth state.
func fetchRedirect(redis storage.Redis, ctx context.Context, state string) (string, error) {
	key := wrapStateKey(state)
	return redis.Get(ctx, key).Result()
}
// deleteRedirect removes the one-time state key after the callback has
// consumed it.
func deleteRedirect(redis storage.Redis, ctx context.Context, state string) error {
	key := wrapStateKey(state)
	return redis.Del(ctx, key).Err()
}

View File

@@ -184,7 +184,8 @@ func (s *Set) updateDBTargetTs(ident string, now int64) {
func (s *Set) updateTargetsUpdateTs(lst []string, now int64, redis storage.Redis) error {
if redis == nil {
return fmt.Errorf("redis is nil")
logger.Debugf("update_ts: redis is nil")
return nil
}
newMap := make(map[string]interface{}, len(lst))