Просмотр исходного кода

feat(login): Captcha upgrade and add the function to ban IP addresses (#250)

lejianwen месяцев назад: 8
Родитель
Сommit
f19109cdf8

+ 9 - 0
cmd/apimain.go

@@ -18,6 +18,7 @@ import (
18 18
 	"github.com/spf13/cobra"
19 19
 	"os"
20 20
 	"strconv"
21
+	"time"
21 22
 )
22 23
 
23 24
 // @title 管理系统API
@@ -175,8 +176,16 @@ func InitGlobal() {
175 176
 	//service
176 177
 	service.New(&global.Config, global.DB, global.Logger, global.Jwt, global.Lock)
177 178
 
179
+	global.LoginLimiter = utils.NewLoginLimiter(utils.SecurityPolicy{
180
+		CaptchaThreshold: global.Config.App.CaptchaThreshold,
181
+		BanThreshold:     global.Config.App.BanThreshold,
182
+		AttemptsWindow:   10 * time.Minute,
183
+		BanDuration:      30 * time.Minute,
184
+	})
185
+	global.LoginLimiter.RegisterProvider(utils.B64StringCaptchaProvider{})
178 186
 	DatabaseAutoUpdate()
179 187
 }
188
+
180 189
 func DatabaseAutoUpdate() {
181 190
 	version := 262
182 191
 

+ 3 - 0
conf/config.yaml

@@ -2,10 +2,13 @@ lang: "zh-CN"
2 2
 app:
3 3
   web-client: 1  # 1:启用 0:禁用
4 4
   register: false #是否开启注册
5
+  captcha-threshold: 3 #   <0:disabled, 0 always, >0:enabled
6
+  ban-threshold: 0 # 0:disabled, >0:enabled
5 7
   show-swagger: 0 # 1:启用 0:禁用
6 8
   token-expire: 168h
7 9
   web-sso: true #web auth sso
8 10
   disable-pwd-login: false #禁用密码登录
11
+
9 12
 admin:
10 13
   title: "RustDesk Api Admin"
11 14
   hello-file: "./conf/admin/hello.html"  #优先使用file

+ 8 - 6
config/config.go

@@ -14,12 +14,14 @@ const (
14 14
 )
15 15
 
16 16
 type App struct {
17
-	WebClient       int           `mapstructure:"web-client"`
18
-	Register        bool          `mapstructure:"register"`
19
-	ShowSwagger     int           `mapstructure:"show-swagger"`
20
-	TokenExpire     time.Duration `mapstructure:"token-expire"`
21
-	WebSso          bool          `mapstructure:"web-sso"`
22
-	DisablePwdLogin bool          `mapstructure:"disable-pwd-login"`
17
+	WebClient        int           `mapstructure:"web-client"`
18
+	Register         bool          `mapstructure:"register"`
19
+	ShowSwagger      int           `mapstructure:"show-swagger"`
20
+	TokenExpire      time.Duration `mapstructure:"token-expire"`
21
+	WebSso           bool          `mapstructure:"web-sso"`
22
+	DisablePwdLogin  bool          `mapstructure:"disable-pwd-login"`
23
+	CaptchaThreshold int           `mapstructure:"captcha-threshold"`
24
+	BanThreshold     int           `mapstructure:"ban-threshold"`
23 25
 }
24 26
 type Admin struct {
25 27
 	Title     string `mapstructure:"title"`

+ 6 - 4
global/global.go

@@ -10,6 +10,7 @@ import (
10 10
 	"github.com/lejianwen/rustdesk-api/v2/lib/jwt"
11 11
 	"github.com/lejianwen/rustdesk-api/v2/lib/lock"
12 12
 	"github.com/lejianwen/rustdesk-api/v2/lib/upload"
13
+	"github.com/lejianwen/rustdesk-api/v2/utils"
13 14
 	"github.com/nicksnyder/go-i18n/v2/i18n"
14 15
 	"github.com/sirupsen/logrus"
15 16
 	"github.com/spf13/viper"
@@ -31,8 +32,9 @@ var (
31 32
 		ValidStruct func(*gin.Context, interface{}) []string
32 33
 		ValidVar    func(ctx *gin.Context, field interface{}, tag string) []string
33 34
 	}
34
-	Oss       *upload.Oss
35
-	Jwt       *jwt.Jwt
36
-	Lock      lock.Locker
37
-	Localizer func(lang string) *i18n.Localizer
35
+	Oss          *upload.Oss
36
+	Jwt          *jwt.Jwt
37
+	Lock         lock.Locker
38
+	Localizer    func(lang string) *i18n.Localizer
39
+	LoginLimiter *utils.LoginLimiter
38 40
 )

+ 49 - 142
http/controller/admin/login.go

@@ -11,135 +11,11 @@ import (
11 11
 	adResp "github.com/lejianwen/rustdesk-api/v2/http/response/admin"
12 12
 	"github.com/lejianwen/rustdesk-api/v2/model"
13 13
 	"github.com/lejianwen/rustdesk-api/v2/service"
14
-	"github.com/mojocn/base64Captcha"
15
-	"sync"
16
-	"time"
17 14
 )
18 15
 
19 16
 type Login struct {
20 17
 }
21 18
 
22
-// Captcha 验证码结构
23
-type Captcha struct {
24
-	Id        string    `json:"id"`  // 验证码 ID
25
-	B64       string    `json:"b64"` // base64 验证码
26
-	Code      string    `json:"-"`   // 验证码内容
27
-	ExpiresAt time.Time `json:"-"`   // 过期时间
28
-}
29
-type LoginLimiter struct {
30
-	mu        sync.RWMutex
31
-	failCount map[string]int       // 记录每个 IP 的失败次数
32
-	timestamp map[string]time.Time // 记录每个 IP 的最后失败时间
33
-	captchas  map[string]Captcha   // 每个 IP 的验证码
34
-	threshold int                  // 失败阈值
35
-	expiry    time.Duration        // 失败记录过期时间
36
-}
37
-
38
-func NewLoginLimiter(threshold int, expiry time.Duration) *LoginLimiter {
39
-	return &LoginLimiter{
40
-		failCount: make(map[string]int),
41
-		timestamp: make(map[string]time.Time),
42
-		captchas:  make(map[string]Captcha),
43
-		threshold: threshold,
44
-		expiry:    expiry,
45
-	}
46
-}
47
-
48
-// RecordFailure 记录登录失败
49
-func (l *LoginLimiter) RecordFailure(ip string) {
50
-	l.mu.Lock()
51
-	defer l.mu.Unlock()
52
-
53
-	// 如果该 IP 的记录已经过期,重置计数
54
-	if lastTime, exists := l.timestamp[ip]; exists && time.Since(lastTime) > l.expiry {
55
-		l.failCount[ip] = 0
56
-	}
57
-
58
-	// 更新失败次数和时间戳
59
-	l.failCount[ip]++
60
-	l.timestamp[ip] = time.Now()
61
-}
62
-
63
-// NeedsCaptcha 检查是否需要验证码
64
-func (l *LoginLimiter) NeedsCaptcha(ip string) bool {
65
-	l.mu.RLock()
66
-	defer l.mu.RUnlock()
67
-
68
-	// 检查记录是否存在且未过期
69
-	if lastTime, exists := l.timestamp[ip]; exists && time.Since(lastTime) <= l.expiry {
70
-		return l.failCount[ip] >= l.threshold
71
-	}
72
-	return false
73
-}
74
-
75
-// GenerateCaptcha 为指定 IP 生成验证码
76
-func (l *LoginLimiter) GenerateCaptcha(ip string) Captcha {
77
-	l.mu.Lock()
78
-	defer l.mu.Unlock()
79
-
80
-	capd := base64Captcha.NewDriverString(50, 150, 5, 10, 4, "1234567890abcdefghijklmnopqrstuvwxyz", nil, nil, nil)
81
-	b64cap := base64Captcha.NewCaptcha(capd, base64Captcha.DefaultMemStore)
82
-	id, b64s, answer, err := b64cap.Generate()
83
-	if err != nil {
84
-		global.Logger.Error("Generate captcha failed: " + err.Error())
85
-		return Captcha{}
86
-	}
87
-	// 保存验证码到对应 IP
88
-	l.captchas[ip] = Captcha{
89
-		Id:        id,
90
-		B64:       b64s,
91
-		Code:      answer,
92
-		ExpiresAt: time.Now().Add(5 * time.Minute),
93
-	}
94
-	return l.captchas[ip]
95
-}
96
-
97
-// VerifyCaptcha 验证指定 IP 的验证码
98
-func (l *LoginLimiter) VerifyCaptcha(ip, code string) bool {
99
-	l.mu.RLock()
100
-	defer l.mu.RUnlock()
101
-
102
-	// 检查验证码是否存在且未过期
103
-	if captcha, exists := l.captchas[ip]; exists && time.Now().Before(captcha.ExpiresAt) {
104
-		return captcha.Code == code
105
-	}
106
-	return false
107
-}
108
-
109
-// RemoveCaptcha 移除指定 IP 的验证码
110
-func (l *LoginLimiter) RemoveCaptcha(ip string) {
111
-	l.mu.Lock()
112
-	defer l.mu.Unlock()
113
-
114
-	delete(l.captchas, ip)
115
-}
116
-
117
-// CleanupExpired 清理过期的记录
118
-func (l *LoginLimiter) CleanupExpired() {
119
-	l.mu.Lock()
120
-	defer l.mu.Unlock()
121
-
122
-	now := time.Now()
123
-	for ip, lastTime := range l.timestamp {
124
-		if now.Sub(lastTime) > l.expiry {
125
-			delete(l.failCount, ip)
126
-			delete(l.timestamp, ip)
127
-			delete(l.captchas, ip)
128
-		}
129
-	}
130
-}
131
-
132
-func (l *LoginLimiter) RemoveRecord(ip string) {
133
-	l.mu.Lock()
134
-	defer l.mu.Unlock()
135
-
136
-	delete(l.failCount, ip)
137
-	delete(l.timestamp, ip)
138
-	delete(l.captchas, ip)
139
-}
140
-
141
-var loginLimiter = NewLoginLimiter(3, 5*time.Minute)
142
-
143 19
 // Login 登录
144 20
 // @Tags 登录
145 21
 // @Summary 登录
@@ -156,10 +32,16 @@ func (ct *Login) Login(c *gin.Context) {
156 32
 		response.Fail(c, 101, response.TranslateMsg(c, "PwdLoginDisabled"))
157 33
 		return
158 34
 	}
35
+
36
+	// 检查登录限制
37
+	loginLimiter := global.LoginLimiter
38
+	clientIp := c.ClientIP()
39
+	_, needCaptcha := loginLimiter.CheckSecurityStatus(clientIp)
40
+
159 41
 	f := &admin.Login{}
160 42
 	err := c.ShouldBindJSON(f)
161
-	clientIp := c.ClientIP()
162 43
 	if err != nil {
44
+		loginLimiter.RecordFailedAttempt(clientIp)
163 45
 		global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), clientIp))
164 46
 		response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
165 47
 		return
@@ -167,13 +49,14 @@ func (ct *Login) Login(c *gin.Context) {
167 49
 
168 50
 	errList := global.Validator.ValidStruct(c, f)
169 51
 	if len(errList) > 0 {
52
+		loginLimiter.RecordFailedAttempt(clientIp)
170 53
 		global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), clientIp))
171 54
 		response.Fail(c, 101, errList[0])
172 55
 		return
173 56
 	}
174 57
 
175 58
 	// 检查是否需要验证码
176
-	if loginLimiter.NeedsCaptcha(clientIp) {
59
+	if needCaptcha {
177 60
 		if f.Captcha == "" || !loginLimiter.VerifyCaptcha(clientIp, f.Captcha) {
178 61
 			response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError"))
179 62
 			return
@@ -184,17 +67,22 @@ func (ct *Login) Login(c *gin.Context) {
184 67
 
185 68
 	if u.Id == 0 {
186 69
 		global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), clientIp))
187
-		loginLimiter.RecordFailure(clientIp)
188
-		if loginLimiter.NeedsCaptcha(clientIp) {
189
-			loginLimiter.RemoveCaptcha(clientIp)
70
+		loginLimiter.RecordFailedAttempt(clientIp)
71
+		// 移除验证码,重新生成
72
+		loginLimiter.RemoveCaptcha(clientIp)
73
+		if _, needCaptcha = loginLimiter.CheckSecurityStatus(clientIp); needCaptcha {
74
+			response.Fail(c, 110, response.TranslateMsg(c, "UsernameOrPasswordError"))
75
+		} else {
76
+			response.Fail(c, 101, response.TranslateMsg(c, "UsernameOrPasswordError"))
190 77
 		}
191
-		response.Fail(c, 101, response.TranslateMsg(c, "UsernameOrPasswordError"))
192 78
 		return
193 79
 	}
194 80
 
195 81
 	if !service.AllService.UserService.CheckUserEnable(u) {
196
-		if loginLimiter.NeedsCaptcha(clientIp) {
82
+		if needCaptcha {
197 83
 			loginLimiter.RemoveCaptcha(clientIp)
84
+			response.Fail(c, 110, response.TranslateMsg(c, "UserDisabled"))
85
+			return
198 86
 		}
199 87
 		response.Fail(c, 101, response.TranslateMsg(c, "UserDisabled"))
200 88
 		return
@@ -209,23 +97,36 @@ func (ct *Login) Login(c *gin.Context) {
209 97
 		Platform: f.Platform,
210 98
 	})
211 99
 
212
-	// 成功后清除记录
213
-	loginLimiter.RemoveRecord(clientIp)
214
-
215
-	// 清理过期记录
216
-	go loginLimiter.CleanupExpired()
217
-
100
+	// 登录成功,清除登录限制
101
+	loginLimiter.RemoveAttempts(clientIp)
218 102
 	responseLoginSuccess(c, u, ut.Token)
219 103
 }
220 104
 func (ct *Login) Captcha(c *gin.Context) {
105
+	loginLimiter := global.LoginLimiter
221 106
 	clientIp := c.ClientIP()
222
-	if !loginLimiter.NeedsCaptcha(clientIp) {
107
+	banned, needCaptcha := loginLimiter.CheckSecurityStatus(clientIp)
108
+	if banned {
109
+		response.Fail(c, 101, response.TranslateMsg(c, "LoginBanned"))
110
+		return
111
+	}
112
+	if !needCaptcha {
223 113
 		response.Fail(c, 101, response.TranslateMsg(c, "NoCaptchaRequired"))
224 114
 		return
225 115
 	}
226
-	captcha := loginLimiter.GenerateCaptcha(clientIp)
116
+	err, captcha := loginLimiter.RequireCaptcha(clientIp)
117
+	if err != nil {
118
+		response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")+err.Error())
119
+		return
120
+	}
121
+	err, b64 := loginLimiter.DrawCaptcha(captcha.Content)
122
+	if err != nil {
123
+		response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError")+err.Error())
124
+		return
125
+	}
227 126
 	response.Success(c, gin.H{
228
-		"captcha": captcha,
127
+		"captcha": gin.H{
128
+			"b64": b64,
129
+		},
229 130
 	})
230 131
 }
231 132
 
@@ -257,12 +158,18 @@ func (ct *Login) Logout(c *gin.Context) {
257 158
 // @Failure 500 {object} response.ErrorResponse
258 159
 // @Router /admin/login-options [post]
259 160
 func (ct *Login) LoginOptions(c *gin.Context) {
260
-	ip := c.ClientIP()
161
+	loginLimiter := global.LoginLimiter
162
+	clientIp := c.ClientIP()
163
+	banned, needCaptcha := loginLimiter.CheckSecurityStatus(clientIp)
164
+	if banned {
165
+		response.Fail(c, 101, response.TranslateMsg(c, "LoginBanned"))
166
+		return
167
+	}
261 168
 	ops := service.AllService.OauthService.GetOauthProviders()
262 169
 	response.Success(c, gin.H{
263 170
 		"ops":          ops,
264 171
 		"register":     global.Config.App.Register,
265
-		"need_captcha": loginLimiter.NeedsCaptcha(ip),
172
+		"need_captcha": needCaptcha,
266 173
 	})
267 174
 }
268 175
 

+ 8 - 0
http/controller/api/login.go

@@ -31,10 +31,16 @@ func (l *Login) Login(c *gin.Context) {
31 31
 		response.Error(c, response.TranslateMsg(c, "PwdLoginDisabled"))
32 32
 		return
33 33
 	}
34
+
35
+	// 检查登录限制
36
+	loginLimiter := global.LoginLimiter
37
+	clientIp := c.ClientIP()
38
+
34 39
 	f := &api.LoginForm{}
35 40
 	err := c.ShouldBindJSON(f)
36 41
 	//fmt.Println(f)
37 42
 	if err != nil {
43
+		loginLimiter.RecordFailedAttempt(clientIp)
38 44
 		global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), c.ClientIP()))
39 45
 		response.Error(c, response.TranslateMsg(c, "ParamsError")+err.Error())
40 46
 		return
@@ -42,6 +48,7 @@ func (l *Login) Login(c *gin.Context) {
42 48
 
43 49
 	errList := global.Validator.ValidStruct(c, f)
44 50
 	if len(errList) > 0 {
51
+		loginLimiter.RecordFailedAttempt(clientIp)
45 52
 		global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), c.ClientIP()))
46 53
 		response.Error(c, errList[0])
47 54
 		return
@@ -50,6 +57,7 @@ func (l *Login) Login(c *gin.Context) {
50 57
 	u := service.AllService.UserService.InfoByUsernamePassword(f.Username, f.Password)
51 58
 
52 59
 	if u.Id == 0 {
60
+		loginLimiter.RecordFailedAttempt(clientIp)
53 61
 		global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), c.ClientIP()))
54 62
 		response.Error(c, response.TranslateMsg(c, "UsernameOrPasswordError"))
55 63
 		return

+ 1 - 1
http/http.go

@@ -33,7 +33,7 @@ func ApiInit() {
33 33
 	g.NoRoute(func(c *gin.Context) {
34 34
 		c.String(http.StatusNotFound, "404 not found")
35 35
 	})
36
-	g.Use(middleware.Logger(), gin.Recovery())
36
+	g.Use(middleware.Logger(), middleware.Limiter(), gin.Recovery())
37 37
 	router.WebInit(g)
38 38
 	router.Init(g)
39 39
 	router.ApiInit(g)

+ 22 - 0
http/middleware/limiter.go

@@ -0,0 +1,22 @@
1
+package middleware
2
+
3
+import (
4
+	"github.com/gin-gonic/gin"
5
+	"github.com/lejianwen/rustdesk-api/v2/global"
6
+	"github.com/lejianwen/rustdesk-api/v2/http/response"
7
+	"net/http"
8
+)
9
+
10
+func Limiter() gin.HandlerFunc {
11
+	return func(c *gin.Context) {
12
+		loginLimiter := global.LoginLimiter
13
+		clientIp := c.ClientIP()
14
+		banned, _ := loginLimiter.CheckSecurityStatus(clientIp)
15
+		if banned {
16
+			response.Fail(c, http.StatusLocked, response.TranslateMsg(c, "Banned"))
17
+			c.Abort()
18
+			return
19
+		}
20
+		c.Next()
21
+	}
22
+}

+ 6 - 1
resources/i18n/en.toml

@@ -142,4 +142,9 @@ other = "Password login disabled."
142 142
 [CannotShareToSelf]
143 143
 description = "Cannot share to self."
144 144
 one = "Cannot share to self."
145
-other = "Cannot share to self."
145
+other = "Cannot share to self."
146
+
147
+[Banned]
148
+description = "Banned."
149
+one = "Banned."
150
+other = "Banned."

+ 6 - 1
resources/i18n/es.toml

@@ -151,4 +151,9 @@ other = "Inicio de sesión con contraseña deshabilitado."
151 151
 [CannotShareToSelf]
152 152
 description = "Cannot share to self."
153 153
 one = "No se puede compartir con uno mismo."
154
-other = "No se puede compartir con uno mismo."
154
+other = "No se puede compartir con uno mismo."
155
+
156
+[Banned]
157
+description = "Banned."
158
+one = "Prohibido."
159
+other = "Prohibido."

+ 6 - 1
resources/i18n/fr.toml

@@ -151,4 +151,9 @@ other = "Connexion par mot de passe désactivée."
151 151
 [CannotShareToSelf]
152 152
 description = "Cannot share to self."
153 153
 one = "Impossible de partager avec soi-même."
154
-other = "Impossible de partager avec soi-même."
154
+other = "Impossible de partager avec soi-même."
155
+
156
+[Banned]
157
+description = "Banned."
158
+one = "Banni."
159
+other = "Banni."

+ 6 - 1
resources/i18n/ko.toml

@@ -145,4 +145,9 @@ other = "비밀번호 로그인이 비활성화되었습니다."
145 145
 [CannotShareToSelf]
146 146
 description = "Cannot share to self."
147 147
 one = "자기 자신에게 공유할 수 없습니다."
148
-other = "자기 자신에게 공유할 수 없습니다."
148
+other = "자기 자신에게 공유할 수 없습니다."
149
+
150
+[Banned]
151
+description = "Banned."
152
+one = "금지됨."
153
+other = "금지됨."

+ 6 - 1
resources/i18n/ru.toml

@@ -151,4 +151,9 @@ other = "Вход по паролю отключен."
151 151
 [CannotShareToSelf]
152 152
 description = "Cannot share to self."
153 153
 one = "Нельзя поделиться с собой."
154
-other = "Нельзя поделиться с собой."
154
+other = "Нельзя поделиться с собой."
155
+
156
+[Banned]
157
+description = "Banned."
158
+one = "Заблокировано."
159
+other = "Заблокировано."

+ 6 - 1
resources/i18n/zh_CN.toml

@@ -144,4 +144,9 @@ other = "密码登录已禁用。"
144 144
 [CannotShareToSelf]
145 145
 description = "Cannot share to self."
146 146
 one = "不能共享给自己。"
147
-other = "不能共享给自己。"
147
+other = "不能共享给自己。"
148
+
149
+[Banned]
150
+description = "Banned."
151
+one = "已被封禁。"
152
+other = "已被封禁。"

+ 6 - 1
resources/i18n/zh_TW.toml

@@ -144,4 +144,9 @@ other = "密碼登錄已禁用。"
144 144
 [CannotShareToSelf]
145 145
 description = "Cannot share to self."
146 146
 one = "無法共享給自己。"
147
-other = "無法共享給自己。"
147
+other = "無法共享給自己。"
148
+
149
+[Banned]
150
+description = "Banned."
151
+one = "禁止使用。"
152
+other = "禁止使用。"

+ 48 - 0
utils/captcha.go

@@ -0,0 +1,48 @@
1
+package utils
2
+
3
+import (
4
+	"github.com/mojocn/base64Captcha"
5
+	"time"
6
+)
7
+
8
+var capdString = base64Captcha.NewDriverString(50, 150, 5, 10, 4, "123456789abcdefghijklmnopqrstuvwxyz", nil, nil, nil)
9
+
10
+var capdMath = base64Captcha.NewDriverMath(50, 150, 5, 10, nil, nil, nil)
11
+
12
+type B64StringCaptchaProvider struct{}
13
+
14
+func (p B64StringCaptchaProvider) Generate(ip string) (string, string, error) {
15
+	_, content, answer := capdString.GenerateIdQuestionAnswer()
16
+	return content, answer, nil
17
+}
18
+
19
+func (p B64StringCaptchaProvider) Expiration() time.Duration {
20
+	return 5 * time.Minute
21
+}
22
+func (p B64StringCaptchaProvider) Draw(content string) (string, error) {
23
+	item, err := capdString.DrawCaptcha(content)
24
+	if err != nil {
25
+		return "", err
26
+	}
27
+	b64str := item.EncodeB64string()
28
+	return b64str, nil
29
+}
30
+
31
+type B64MathCaptchaProvider struct{}
32
+
33
+func (p B64MathCaptchaProvider) Generate(ip string) (string, string, error) {
34
+	_, content, answer := capdMath.GenerateIdQuestionAnswer()
35
+	return content, answer, nil
36
+}
37
+
38
+func (p B64MathCaptchaProvider) Expiration() time.Duration {
39
+	return 5 * time.Minute
40
+}
41
+func (p B64MathCaptchaProvider) Draw(content string) (string, error) {
42
+	item, err := capdMath.DrawCaptcha(content)
43
+	if err != nil {
44
+		return "", err
45
+	}
46
+	b64str := item.EncodeB64string()
47
+	return b64str, nil
48
+}

+ 305 - 0
utils/login_limiter.go

@@ -0,0 +1,305 @@
1
+package utils
2
+
3
+import (
4
+	"errors"
5
+	"sync"
6
+	"time"
7
+)
8
+
9
+// 安全策略配置
10
+type SecurityPolicy struct {
11
+	CaptchaThreshold int // 尝试失败次数达到验证码阈值,小于0表示不启用, 0表示强制启用
12
+	BanThreshold     int // 尝试失败次数达到封禁阈值,为0表示不启用
13
+	AttemptsWindow   time.Duration
14
+	BanDuration      time.Duration
15
+}
16
+
17
+// 验证码提供者接口
18
+type CaptchaProvider interface {
19
+	Generate(ip string) (string, string, error)
20
+	//Validate(ip, code string) bool
21
+	Expiration() time.Duration           // 验证码过期时间, 应该小于 AttemptsWindow
22
+	Draw(content string) (string, error) // 绘制验证码
23
+}
24
+
25
+// 验证码元数据
26
+type CaptchaMeta struct {
27
+	Content   string
28
+	Answer    string
29
+	ExpiresAt time.Time
30
+}
31
+
32
+// IP封禁记录
33
+type BanRecord struct {
34
+	ExpiresAt time.Time
35
+	Reason    string
36
+}
37
+
38
+// 登录限制器
39
+type LoginLimiter struct {
40
+	mu          sync.Mutex
41
+	policy      SecurityPolicy
42
+	attempts    map[string][]time.Time //
43
+	captchas    map[string]CaptchaMeta
44
+	bannedIPs   map[string]BanRecord
45
+	provider    CaptchaProvider
46
+	cleanupStop chan struct{}
47
+}
48
+
49
+var defaultSecurityPolicy = SecurityPolicy{
50
+	CaptchaThreshold: 3,
51
+	BanThreshold:     5,
52
+	AttemptsWindow:   5 * time.Minute,
53
+	BanDuration:      30 * time.Minute,
54
+}
55
+
56
+func NewLoginLimiter(policy SecurityPolicy) *LoginLimiter {
57
+	// 设置默认值
58
+	if policy.AttemptsWindow == 0 {
59
+		policy.AttemptsWindow = 5 * time.Minute
60
+	}
61
+	if policy.BanDuration == 0 {
62
+		policy.BanDuration = 30 * time.Minute
63
+	}
64
+
65
+	ll := &LoginLimiter{
66
+		policy:      policy,
67
+		attempts:    make(map[string][]time.Time),
68
+		captchas:    make(map[string]CaptchaMeta),
69
+		bannedIPs:   make(map[string]BanRecord),
70
+		cleanupStop: make(chan struct{}),
71
+	}
72
+	go ll.cleanupRoutine()
73
+	return ll
74
+}
75
+
76
+// 注册验证码提供者
77
+func (ll *LoginLimiter) RegisterProvider(p CaptchaProvider) {
78
+	ll.mu.Lock()
79
+	defer ll.mu.Unlock()
80
+	ll.provider = p
81
+}
82
+
83
+// isDisabled 检查是否禁用登录限制
84
+func (ll *LoginLimiter) isDisabled() bool {
85
+	return ll.policy.CaptchaThreshold < 0 && ll.policy.BanThreshold == 0
86
+}
87
+
88
+// 记录登录失败尝试
89
+func (ll *LoginLimiter) RecordFailedAttempt(ip string) {
90
+	if ll.isDisabled() {
91
+		return
92
+	}
93
+	ll.mu.Lock()
94
+	defer ll.mu.Unlock()
95
+
96
+	if banned, _ := ll.isBanned(ip); banned {
97
+		return
98
+	}
99
+
100
+	now := time.Now()
101
+	windowStart := now.Add(-ll.policy.AttemptsWindow)
102
+
103
+	// 清理过期尝试
104
+	validAttempts := ll.pruneAttempts(ip, windowStart)
105
+
106
+	// 记录新尝试
107
+	validAttempts = append(validAttempts, now)
108
+	ll.attempts[ip] = validAttempts
109
+
110
+	// 检查封禁条件
111
+	if ll.policy.BanThreshold > 0 && len(validAttempts) >= ll.policy.BanThreshold {
112
+		ll.banIP(ip, "excessive failed attempts")
113
+		return
114
+	}
115
+
116
+	return
117
+}
118
+
119
+// 生成验证码
120
+func (ll *LoginLimiter) RequireCaptcha(ip string) (error, CaptchaMeta) {
121
+	ll.mu.Lock()
122
+	defer ll.mu.Unlock()
123
+
124
+	if ll.provider == nil {
125
+		return errors.New("no captcha provider available"), CaptchaMeta{}
126
+	}
127
+
128
+	content, answer, err := ll.provider.Generate(ip)
129
+	if err != nil {
130
+		return err, CaptchaMeta{}
131
+	}
132
+
133
+	// 存储验证码
134
+	ll.captchas[ip] = CaptchaMeta{
135
+		Content:   content,
136
+		Answer:    answer,
137
+		ExpiresAt: time.Now().Add(ll.provider.Expiration()),
138
+	}
139
+
140
+	return nil, ll.captchas[ip]
141
+}
142
+
143
+// 验证验证码
144
+func (ll *LoginLimiter) VerifyCaptcha(ip, answer string) bool {
145
+	ll.mu.Lock()
146
+	defer ll.mu.Unlock()
147
+
148
+	// 查找匹配验证码
149
+	if ll.provider == nil {
150
+		return false
151
+	}
152
+
153
+	// 获取并验证验证码
154
+	captcha, exists := ll.captchas[ip]
155
+	if !exists {
156
+		return false
157
+	}
158
+
159
+	// 清理过期验证码
160
+	if time.Now().After(captcha.ExpiresAt) {
161
+		delete(ll.captchas, ip)
162
+		return false
163
+	}
164
+
165
+	// 验证并清理状态
166
+	if answer == captcha.Answer {
167
+		delete(ll.captchas, ip)
168
+		return true
169
+	}
170
+
171
+	return false
172
+}
173
+
174
+func (ll *LoginLimiter) DrawCaptcha(content string) (err error, str string) {
175
+	str, err = ll.provider.Draw(content)
176
+	return
177
+}
178
+
179
+func (ll *LoginLimiter) RemoveCaptcha(ip string) {
180
+	ll.mu.Lock()
181
+	defer ll.mu.Unlock()
182
+
183
+	_, exists := ll.captchas[ip]
184
+	if exists {
185
+		delete(ll.captchas, ip)
186
+	}
187
+}
188
+
189
+// 清除记录窗口
190
+func (ll *LoginLimiter) RemoveAttempts(ip string) {
191
+	ll.mu.Lock()
192
+	defer ll.mu.Unlock()
193
+
194
+	_, exists := ll.attempts[ip]
195
+	if exists {
196
+		delete(ll.attempts, ip)
197
+	}
198
+}
199
+
200
+// CheckSecurityStatus 检查安全状态
201
+func (ll *LoginLimiter) CheckSecurityStatus(ip string) (banned bool, captchaRequired bool) {
202
+	if ll.isDisabled() {
203
+		return
204
+	}
205
+	ll.mu.Lock()
206
+	defer ll.mu.Unlock()
207
+
208
+	// 检查封禁状态
209
+	if banned, _ = ll.isBanned(ip); banned {
210
+		return
211
+	}
212
+
213
+	// 清理过期数据
214
+	ll.pruneAttempts(ip, time.Now().Add(-ll.policy.AttemptsWindow))
215
+	ll.pruneCaptchas(ip)
216
+
217
+	// 检查验证码要求
218
+	captchaRequired = len(ll.attempts[ip]) >= ll.policy.CaptchaThreshold
219
+
220
+	return
221
+}
222
+
223
+// 后台清理任务
224
+func (ll *LoginLimiter) cleanupRoutine() {
225
+	ticker := time.NewTicker(1 * time.Minute)
226
+	defer ticker.Stop()
227
+
228
+	for {
229
+		select {
230
+		case <-ticker.C:
231
+			ll.cleanupExpired()
232
+		case <-ll.cleanupStop:
233
+			return
234
+		}
235
+	}
236
+}
237
+
238
+// 内部工具方法
239
+func (ll *LoginLimiter) isBanned(ip string) (bool, BanRecord) {
240
+	record, exists := ll.bannedIPs[ip]
241
+	if !exists {
242
+		return false, BanRecord{}
243
+	}
244
+	if time.Now().After(record.ExpiresAt) {
245
+		delete(ll.bannedIPs, ip)
246
+		return false, BanRecord{}
247
+	}
248
+	return true, record
249
+}
250
+
251
+func (ll *LoginLimiter) banIP(ip, reason string) {
252
+	ll.bannedIPs[ip] = BanRecord{
253
+		ExpiresAt: time.Now().Add(ll.policy.BanDuration),
254
+		Reason:    reason,
255
+	}
256
+	delete(ll.attempts, ip)
257
+	delete(ll.captchas, ip)
258
+}
259
+
260
+func (ll *LoginLimiter) pruneAttempts(ip string, cutoff time.Time) []time.Time {
261
+	var valid []time.Time
262
+	for _, t := range ll.attempts[ip] {
263
+		if t.After(cutoff) {
264
+			valid = append(valid, t)
265
+		}
266
+	}
267
+	if len(valid) == 0 {
268
+		delete(ll.attempts, ip)
269
+	} else {
270
+		ll.attempts[ip] = valid
271
+	}
272
+	return valid
273
+}
274
+
275
+func (ll *LoginLimiter) pruneCaptchas(ip string) {
276
+	if captcha, exists := ll.captchas[ip]; exists {
277
+		if time.Now().After(captcha.ExpiresAt) {
278
+			delete(ll.captchas, ip)
279
+		}
280
+	}
281
+}
282
+
283
+func (ll *LoginLimiter) cleanupExpired() {
284
+	ll.mu.Lock()
285
+	defer ll.mu.Unlock()
286
+
287
+	now := time.Now()
288
+
289
+	// 清理封禁记录
290
+	for ip, record := range ll.bannedIPs {
291
+		if now.After(record.ExpiresAt) {
292
+			delete(ll.bannedIPs, ip)
293
+		}
294
+	}
295
+
296
+	// 清理尝试记录
297
+	for ip := range ll.attempts {
298
+		ll.pruneAttempts(ip, now.Add(-ll.policy.AttemptsWindow))
299
+	}
300
+
301
+	// 清理验证码
302
+	for ip := range ll.captchas {
303
+		ll.pruneCaptchas(ip)
304
+	}
305
+}

+ 286 - 0
utils/login_limiter_test.go

@@ -0,0 +1,286 @@
1
+package utils
2
+
3
+import (
4
+	"fmt"
5
+	"testing"
6
+	"time"
7
+)
8
+
9
+type MockCaptchaProvider struct{}
10
+
11
+func (p *MockCaptchaProvider) Generate(ip string) (string, string, error) {
12
+	return "CONTENT", "MOCK", nil
13
+}
14
+
15
+func (p *MockCaptchaProvider) Validate(ip, code string) bool {
16
+	return code == "MOCK"
17
+}
18
+
19
+func (p *MockCaptchaProvider) Expiration() time.Duration {
20
+	return 2 * time.Second
21
+}
22
+func (p *MockCaptchaProvider) Draw(content string) (string, error) {
23
+	return "MOCK", nil
24
+}
25
+
26
+func TestSecurityWorkflow(t *testing.T) {
27
+	policy := SecurityPolicy{
28
+		CaptchaThreshold: 3,
29
+		BanThreshold:     5,
30
+		AttemptsWindow:   5 * time.Minute,
31
+		BanDuration:      5 * time.Minute,
32
+	}
33
+	limiter := NewLoginLimiter(policy)
34
+	ip := "192.168.1.100"
35
+
36
+	// 测试正常失败记录
37
+	for i := 0; i < 3; i++ {
38
+		limiter.RecordFailedAttempt(ip)
39
+	}
40
+	isBanned, capRequired := limiter.CheckSecurityStatus(ip)
41
+	fmt.Printf("IP: %s, Banned: %v, Captcha Required: %v\n", ip, isBanned, capRequired)
42
+	if isBanned {
43
+		t.Error("IP should not be banned yet")
44
+	}
45
+	if !capRequired {
46
+		t.Error("Captcha should be required")
47
+	}
48
+	// 测试触发封禁
49
+	for i := 0; i < 3; i++ {
50
+		limiter.RecordFailedAttempt(ip)
51
+		isBanned, capRequired = limiter.CheckSecurityStatus(ip)
52
+		fmt.Printf("IP: %s, Banned: %v, Captcha Required: %v\n", ip, isBanned, capRequired)
53
+	}
54
+
55
+	// 测试封禁状态
56
+	if isBanned, _ = limiter.CheckSecurityStatus(ip); !isBanned {
57
+		t.Error("IP should be banned")
58
+	}
59
+}
60
+
61
+func TestCaptchaFlow(t *testing.T) {
62
+	policy := SecurityPolicy{CaptchaThreshold: 2}
63
+	limiter := NewLoginLimiter(policy)
64
+	limiter.RegisterProvider(&MockCaptchaProvider{})
65
+	ip := "10.0.0.1"
66
+
67
+	// 触发验证码要求
68
+	limiter.RecordFailedAttempt(ip)
69
+	limiter.RecordFailedAttempt(ip)
70
+
71
+	// 检查状态
72
+	if _, need := limiter.CheckSecurityStatus(ip); !need {
73
+		t.Error("应该需要验证码")
74
+	}
75
+
76
+	// 生成验证码
77
+	err, capc := limiter.RequireCaptcha(ip)
78
+	if err != nil {
79
+		t.Fatalf("生成验证码失败: %v", err)
80
+	}
81
+	fmt.Printf("验证码内容: %#v\n", capc)
82
+
83
+	// 验证成功
84
+	if !limiter.VerifyCaptcha(ip, capc.Answer) {
85
+		t.Error("验证码应该验证成功")
86
+	}
87
+
88
+	limiter.RemoveAttempts(ip)
89
+	// 验证后状态
90
+	if banned, need := limiter.CheckSecurityStatus(ip); banned || need {
91
+		t.Error("验证成功后应该重置状态")
92
+	}
93
+}
94
+
95
+func TestCaptchaMustFlow(t *testing.T) {
96
+	policy := SecurityPolicy{CaptchaThreshold: 0}
97
+	limiter := NewLoginLimiter(policy)
98
+	limiter.RegisterProvider(&MockCaptchaProvider{})
99
+	ip := "10.0.0.1"
100
+
101
+	// 检查状态
102
+	if _, need := limiter.CheckSecurityStatus(ip); !need {
103
+		t.Error("应该需要验证码")
104
+	}
105
+
106
+	// 生成验证码
107
+	err, capc := limiter.RequireCaptcha(ip)
108
+	if err != nil {
109
+		t.Fatalf("生成验证码失败: %v", err)
110
+	}
111
+	fmt.Printf("验证码内容: %#v\n", capc)
112
+
113
+	// 验证成功
114
+	if !limiter.VerifyCaptcha(ip, capc.Answer) {
115
+		t.Error("验证码应该验证成功")
116
+	}
117
+
118
+	// 验证后状态
119
+	if _, need := limiter.CheckSecurityStatus(ip); !need {
120
+		t.Error("应该需要验证码")
121
+	}
122
+}
123
+func TestAttemptTimeout(t *testing.T) {
124
+	policy := SecurityPolicy{CaptchaThreshold: 2, AttemptsWindow: 1 * time.Second}
125
+	limiter := NewLoginLimiter(policy)
126
+	limiter.RegisterProvider(&MockCaptchaProvider{})
127
+	ip := "10.0.0.1"
128
+
129
+	// 触发验证码要求
130
+	limiter.RecordFailedAttempt(ip)
131
+	limiter.RecordFailedAttempt(ip)
132
+
133
+	// 检查状态
134
+	if _, need := limiter.CheckSecurityStatus(ip); !need {
135
+		t.Error("应该需要验证码")
136
+	}
137
+
138
+	// 生成验证码
139
+	err, _ := limiter.RequireCaptcha(ip)
140
+	if err != nil {
141
+		t.Fatalf("生成验证码失败: %v", err)
142
+	}
143
+	// 等待超过 AttemptsWindow
144
+	time.Sleep(2 * time.Second)
145
+	// 触发验证码要求
146
+	limiter.RecordFailedAttempt(ip)
147
+
148
+	// 检查状态
149
+	if _, need := limiter.CheckSecurityStatus(ip); need {
150
+		t.Error("不应该需要验证码")
151
+	}
152
+}
153
+
154
+func TestCaptchaTimeout(t *testing.T) {
155
+	policy := SecurityPolicy{CaptchaThreshold: 2}
156
+	limiter := NewLoginLimiter(policy)
157
+	limiter.RegisterProvider(&MockCaptchaProvider{})
158
+	ip := "10.0.0.1"
159
+
160
+	// 触发验证码要求
161
+	limiter.RecordFailedAttempt(ip)
162
+	limiter.RecordFailedAttempt(ip)
163
+
164
+	// 检查状态
165
+	if _, need := limiter.CheckSecurityStatus(ip); !need {
166
+		t.Error("应该需要验证码")
167
+	}
168
+
169
+	// 生成验证码
170
+	err, _ := limiter.RequireCaptcha(ip)
171
+	if err != nil {
172
+		t.Fatalf("生成验证码失败: %v", err)
173
+	}
174
+
175
+	// 等待超过 CaptchaValidPeriod
176
+	time.Sleep(3 * time.Second)
177
+
178
+	code := "MOCK"
179
+	// 验证成功
180
+	if limiter.VerifyCaptcha(ip, code) {
181
+		t.Error("验证码应该已过期")
182
+	}
183
+
184
+}
185
+
186
+func TestBanFlow(t *testing.T) {
187
+	policy := SecurityPolicy{BanThreshold: 5}
188
+	limiter := NewLoginLimiter(policy)
189
+	ip := "10.0.0.1"
190
+	// 触发ban
191
+	for i := 0; i < 5; i++ {
192
+		limiter.RecordFailedAttempt(ip)
193
+	}
194
+
195
+	// 检查状态
196
+	if banned, _ := limiter.CheckSecurityStatus(ip); !banned {
197
+		t.Error("should be banned")
198
+	}
199
+}
200
+func TestBanDisableFlow(t *testing.T) {
201
+	policy := SecurityPolicy{BanThreshold: 0}
202
+	limiter := NewLoginLimiter(policy)
203
+	ip := "10.0.0.1"
204
+	// 触发ban
205
+	for i := 0; i < 5; i++ {
206
+		limiter.RecordFailedAttempt(ip)
207
+	}
208
+
209
+	// 检查状态
210
+	if banned, _ := limiter.CheckSecurityStatus(ip); banned {
211
+		t.Error("should not be banned")
212
+	}
213
+}
214
+func TestBanTimeout(t *testing.T) {
215
+	policy := SecurityPolicy{BanThreshold: 5, BanDuration: 1 * time.Second}
216
+	limiter := NewLoginLimiter(policy)
217
+	ip := "10.0.0.1"
218
+	// 触发ban
219
+	// 触发ban
220
+	for i := 0; i < 5; i++ {
221
+		limiter.RecordFailedAttempt(ip)
222
+	}
223
+
224
+	time.Sleep(2 * time.Second)
225
+
226
+	// 检查状态
227
+	if banned, _ := limiter.CheckSecurityStatus(ip); banned {
228
+		t.Error("should not be banned")
229
+	}
230
+}
231
+
232
+func TestLimiterDisabled(t *testing.T) {
233
+	policy := SecurityPolicy{BanThreshold: 0, CaptchaThreshold: -1}
234
+	limiter := NewLoginLimiter(policy)
235
+	ip := "10.0.0.1"
236
+	// 触发ban
237
+	for i := 0; i < 5; i++ {
238
+		limiter.RecordFailedAttempt(ip)
239
+	}
240
+
241
+	// 检查状态
242
+	if banned, capNeed := limiter.CheckSecurityStatus(ip); banned || capNeed {
243
+		fmt.Printf("IP: %s, Banned: %v, Captcha Required: %v\n", ip, banned, capNeed)
244
+		t.Error("should not be banned or need captcha")
245
+	}
246
+}
247
+
248
+func TestB64CaptchaFlow(t *testing.T) {
249
+	limiter := NewLoginLimiter(defaultSecurityPolicy)
250
+	limiter.RegisterProvider(B64StringCaptchaProvider{})
251
+	ip := "10.0.0.1"
252
+
253
+	// 触发验证码要求
254
+	limiter.RecordFailedAttempt(ip)
255
+	limiter.RecordFailedAttempt(ip)
256
+	limiter.RecordFailedAttempt(ip)
257
+
258
+	// 检查状态
259
+	if _, need := limiter.CheckSecurityStatus(ip); !need {
260
+		t.Error("应该需要验证码")
261
+	}
262
+
263
+	// 生成验证码
264
+	err, capc := limiter.RequireCaptcha(ip)
265
+	if err != nil {
266
+		t.Fatalf("生成验证码失败: %v", err)
267
+	}
268
+	fmt.Printf("验证码内容: %#v\n", capc)
269
+
270
+	//draw
271
+	err, b64 := limiter.DrawCaptcha(capc.Content)
272
+	if err != nil {
273
+		t.Fatalf("绘制验证码失败: %v", err)
274
+	}
275
+	fmt.Printf("验证码内容: %#v\n", b64)
276
+
277
+	// 验证成功
278
+	if !limiter.VerifyCaptcha(ip, capc.Answer) {
279
+		t.Error("验证码应该验证成功")
280
+	}
281
+	limiter.RemoveAttempts(ip)
282
+	// 验证后状态
283
+	if banned, need := limiter.CheckSecurityStatus(ip); banned || need {
284
+		t.Error("验证成功后应该重置状态")
285
+	}
286
+}