login_limiter.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. package utils
  2. import (
  3. "errors"
  4. "sync"
  5. "time"
  6. )
  // SecurityPolicy configures when a LoginLimiter demands a captcha and when
// it bans an IP outright, based on failed attempts within a sliding window.
type SecurityPolicy struct {
	// CaptchaThreshold is the failed-attempt count at which a captcha becomes
	// required. Negative disables the captcha check; 0 makes it mandatory.
	CaptchaThreshold int

	// BanThreshold is the failed-attempt count at which the IP is banned.
	// 0 disables banning (the ban check only fires for positive values).
	BanThreshold int

	// AttemptsWindow is how far back failed attempts are counted.
	// NewLoginLimiter substitutes a default when it is zero.
	AttemptsWindow time.Duration

	// BanDuration is how long a ban stays in force.
	// NewLoginLimiter substitutes a default when it is zero.
	BanDuration time.Duration
}
  // CaptchaProvider is the pluggable source of captcha challenges used by
// LoginLimiter.
type CaptchaProvider interface {
	// Generate creates a challenge for ip and returns its content (what is
	// shown to the user), the expected answer, and any error.
	Generate(ip string) (string, string, error)

	// Expiration is the lifetime of a generated captcha. It should be
	// shorter than SecurityPolicy.AttemptsWindow.
	Expiration() time.Duration

	// Draw renders captcha content into a string form, for example an
	// encoded image (format is up to the implementation; confirm with the
	// concrete provider).
	Draw(content string) (string, error)
}
  21. // 验证码元数据
  22. type CaptchaMeta struct {
  23. Content string
  24. Answer string
  25. ExpiresAt time.Time
  26. }
  27. // IP封禁记录
  28. type BanRecord struct {
  29. ExpiresAt time.Time
  30. Reason string
  31. }
  32. // 登录限制器
  33. type LoginLimiter struct {
  34. mu sync.Mutex
  35. policy SecurityPolicy
  36. attempts map[string][]time.Time //
  37. captchas map[string]CaptchaMeta
  38. bannedIPs map[string]BanRecord
  39. provider CaptchaProvider
  40. cleanupStop chan struct{}
  41. }
  42. var defaultSecurityPolicy = SecurityPolicy{
  43. CaptchaThreshold: 3,
  44. BanThreshold: 5,
  45. AttemptsWindow: 5 * time.Minute,
  46. BanDuration: 30 * time.Minute,
  47. }
  48. func NewLoginLimiter(policy SecurityPolicy) *LoginLimiter {
  49. // 设置默认值
  50. if policy.AttemptsWindow == 0 {
  51. policy.AttemptsWindow = 5 * time.Minute
  52. }
  53. if policy.BanDuration == 0 {
  54. policy.BanDuration = 30 * time.Minute
  55. }
  56. ll := &LoginLimiter{
  57. policy: policy,
  58. attempts: make(map[string][]time.Time),
  59. captchas: make(map[string]CaptchaMeta),
  60. bannedIPs: make(map[string]BanRecord),
  61. cleanupStop: make(chan struct{}),
  62. }
  63. go ll.cleanupRoutine()
  64. return ll
  65. }
  66. // 注册验证码提供者
  67. func (ll *LoginLimiter) RegisterProvider(p CaptchaProvider) {
  68. ll.mu.Lock()
  69. defer ll.mu.Unlock()
  70. ll.provider = p
  71. }
  72. // isDisabled 检查是否禁用登录限制
  73. func (ll *LoginLimiter) isDisabled() bool {
  74. return ll.policy.CaptchaThreshold < 0 && ll.policy.BanThreshold == 0
  75. }
  76. // 记录登录失败尝试
  77. func (ll *LoginLimiter) RecordFailedAttempt(ip string) {
  78. if ll.isDisabled() {
  79. return
  80. }
  81. ll.mu.Lock()
  82. defer ll.mu.Unlock()
  83. if banned, _ := ll.isBanned(ip); banned {
  84. return
  85. }
  86. now := time.Now()
  87. windowStart := now.Add(-ll.policy.AttemptsWindow)
  88. // 清理过期尝试
  89. validAttempts := ll.pruneAttempts(ip, windowStart)
  90. // 记录新尝试
  91. validAttempts = append(validAttempts, now)
  92. ll.attempts[ip] = validAttempts
  93. // 检查封禁条件
  94. if ll.policy.BanThreshold > 0 && len(validAttempts) >= ll.policy.BanThreshold {
  95. ll.banIP(ip, "excessive failed attempts")
  96. return
  97. }
  98. return
  99. }
  100. // 生成验证码
  101. func (ll *LoginLimiter) RequireCaptcha(ip string) (error, CaptchaMeta) {
  102. ll.mu.Lock()
  103. defer ll.mu.Unlock()
  104. if ll.provider == nil {
  105. return errors.New("no captcha provider available"), CaptchaMeta{}
  106. }
  107. content, answer, err := ll.provider.Generate(ip)
  108. if err != nil {
  109. return err, CaptchaMeta{}
  110. }
  111. // 存储验证码
  112. ll.captchas[ip] = CaptchaMeta{
  113. Content: content,
  114. Answer: answer,
  115. ExpiresAt: time.Now().Add(ll.provider.Expiration()),
  116. }
  117. return nil, ll.captchas[ip]
  118. }
  119. // 验证验证码
  120. func (ll *LoginLimiter) VerifyCaptcha(ip, answer string) bool {
  121. ll.mu.Lock()
  122. defer ll.mu.Unlock()
  123. // 查找匹配验证码
  124. if ll.provider == nil {
  125. return false
  126. }
  127. // 获取并验证验证码
  128. captcha, exists := ll.captchas[ip]
  129. if !exists {
  130. return false
  131. }
  132. // 清理过期验证码
  133. if time.Now().After(captcha.ExpiresAt) {
  134. delete(ll.captchas, ip)
  135. return false
  136. }
  137. // 验证并清理状态
  138. if answer == captcha.Answer {
  139. delete(ll.captchas, ip)
  140. return true
  141. }
  142. return false
  143. }
  144. func (ll *LoginLimiter) DrawCaptcha(content string) (err error, str string) {
  145. str, err = ll.provider.Draw(content)
  146. return
  147. }
  148. func (ll *LoginLimiter) RemoveCaptcha(ip string) {
  149. ll.mu.Lock()
  150. defer ll.mu.Unlock()
  151. _, exists := ll.captchas[ip]
  152. if exists {
  153. delete(ll.captchas, ip)
  154. }
  155. }
  156. // 清除记录窗口
  157. func (ll *LoginLimiter) RemoveAttempts(ip string) {
  158. ll.mu.Lock()
  159. defer ll.mu.Unlock()
  160. _, exists := ll.attempts[ip]
  161. if exists {
  162. delete(ll.attempts, ip)
  163. }
  164. }
  165. // CheckSecurityStatus 检查安全状态
  166. func (ll *LoginLimiter) CheckSecurityStatus(ip string) (banned bool, captchaRequired bool) {
  167. if ll.isDisabled() {
  168. return
  169. }
  170. ll.mu.Lock()
  171. defer ll.mu.Unlock()
  172. // 检查封禁状态
  173. if banned, _ = ll.isBanned(ip); banned {
  174. return
  175. }
  176. // 清理过期数据
  177. ll.pruneAttempts(ip, time.Now().Add(-ll.policy.AttemptsWindow))
  178. ll.pruneCaptchas(ip)
  179. // 检查验证码要求
  180. captchaRequired = len(ll.attempts[ip]) >= ll.policy.CaptchaThreshold
  181. return
  182. }
  183. // 后台清理任务
  184. func (ll *LoginLimiter) cleanupRoutine() {
  185. ticker := time.NewTicker(1 * time.Minute)
  186. defer ticker.Stop()
  187. for {
  188. select {
  189. case <-ticker.C:
  190. ll.cleanupExpired()
  191. case <-ll.cleanupStop:
  192. return
  193. }
  194. }
  195. }
  196. // 内部工具方法
  197. func (ll *LoginLimiter) isBanned(ip string) (bool, BanRecord) {
  198. record, exists := ll.bannedIPs[ip]
  199. if !exists {
  200. return false, BanRecord{}
  201. }
  202. if time.Now().After(record.ExpiresAt) {
  203. delete(ll.bannedIPs, ip)
  204. return false, BanRecord{}
  205. }
  206. return true, record
  207. }
  208. func (ll *LoginLimiter) banIP(ip, reason string) {
  209. ll.bannedIPs[ip] = BanRecord{
  210. ExpiresAt: time.Now().Add(ll.policy.BanDuration),
  211. Reason: reason,
  212. }
  213. delete(ll.attempts, ip)
  214. delete(ll.captchas, ip)
  215. }
  216. func (ll *LoginLimiter) pruneAttempts(ip string, cutoff time.Time) []time.Time {
  217. var valid []time.Time
  218. for _, t := range ll.attempts[ip] {
  219. if t.After(cutoff) {
  220. valid = append(valid, t)
  221. }
  222. }
  223. if len(valid) == 0 {
  224. delete(ll.attempts, ip)
  225. } else {
  226. ll.attempts[ip] = valid
  227. }
  228. return valid
  229. }
  230. func (ll *LoginLimiter) pruneCaptchas(ip string) {
  231. if captcha, exists := ll.captchas[ip]; exists {
  232. if time.Now().After(captcha.ExpiresAt) {
  233. delete(ll.captchas, ip)
  234. }
  235. }
  236. }
  237. func (ll *LoginLimiter) cleanupExpired() {
  238. ll.mu.Lock()
  239. defer ll.mu.Unlock()
  240. now := time.Now()
  241. // 清理封禁记录
  242. for ip, record := range ll.bannedIPs {
  243. if now.After(record.ExpiresAt) {
  244. delete(ll.bannedIPs, ip)
  245. }
  246. }
  247. // 清理尝试记录
  248. for ip := range ll.attempts {
  249. ll.pruneAttempts(ip, now.Add(-ll.policy.AttemptsWindow))
  250. }
  251. // 清理验证码
  252. for ip := range ll.captchas {
  253. ll.pruneCaptchas(ip)
  254. }
  255. }