login_limiter.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. package utils
  2. import (
  3. "errors"
  4. "sync"
  5. "time"
  6. )
  7. // 安全策略配置
  8. type SecurityPolicy struct {
  9. CaptchaThreshold int // 尝试失败次数达到验证码阈值,小于0表示不启用, 0表示强制启用
  10. BanThreshold int // 尝试失败次数达到封禁阈值,为0表示不启用
  11. AttemptsWindow time.Duration
  12. BanDuration time.Duration
  13. }
  14. // 验证码提供者接口
  15. type CaptchaProvider interface {
  16. Generate() (id string, content string, answer string, err error)
  17. //Validate(ip, code string) bool
  18. Expiration() time.Duration // 验证码过期时间, 应该小于 AttemptsWindow
  19. Draw(content string) (string, error) // 绘制验证码
  20. }
  21. // 验证码元数据
  22. type CaptchaMeta struct {
  23. Id string
  24. Content string
  25. Answer string
  26. ExpiresAt time.Time
  27. }
  28. // IP封禁记录
  29. type BanRecord struct {
  30. ExpiresAt time.Time
  31. Reason string
  32. }
  33. // 登录限制器
  34. type LoginLimiter struct {
  35. mu sync.Mutex
  36. policy SecurityPolicy
  37. attempts map[string][]time.Time //
  38. captchas map[string]CaptchaMeta
  39. bannedIPs map[string]BanRecord
  40. provider CaptchaProvider
  41. cleanupStop chan struct{}
  42. }
  43. var defaultSecurityPolicy = SecurityPolicy{
  44. CaptchaThreshold: 3,
  45. BanThreshold: 5,
  46. AttemptsWindow: 5 * time.Minute,
  47. BanDuration: 30 * time.Minute,
  48. }
  49. func NewLoginLimiter(policy SecurityPolicy) *LoginLimiter {
  50. // 设置默认值
  51. if policy.AttemptsWindow == 0 {
  52. policy.AttemptsWindow = 5 * time.Minute
  53. }
  54. if policy.BanDuration == 0 {
  55. policy.BanDuration = 30 * time.Minute
  56. }
  57. ll := &LoginLimiter{
  58. policy: policy,
  59. attempts: make(map[string][]time.Time),
  60. captchas: make(map[string]CaptchaMeta),
  61. bannedIPs: make(map[string]BanRecord),
  62. cleanupStop: make(chan struct{}),
  63. }
  64. go ll.cleanupRoutine()
  65. return ll
  66. }
  67. // 注册验证码提供者
  68. func (ll *LoginLimiter) RegisterProvider(p CaptchaProvider) {
  69. ll.mu.Lock()
  70. defer ll.mu.Unlock()
  71. ll.provider = p
  72. }
  73. // isDisabled 检查是否禁用登录限制
  74. func (ll *LoginLimiter) isDisabled() bool {
  75. return ll.policy.CaptchaThreshold < 0 && ll.policy.BanThreshold == 0
  76. }
  77. // 记录登录失败尝试
  78. func (ll *LoginLimiter) RecordFailedAttempt(ip string) {
  79. if ll.isDisabled() {
  80. return
  81. }
  82. ll.mu.Lock()
  83. defer ll.mu.Unlock()
  84. if banned, _ := ll.isBanned(ip); banned {
  85. return
  86. }
  87. now := time.Now()
  88. windowStart := now.Add(-ll.policy.AttemptsWindow)
  89. // 清理过期尝试
  90. validAttempts := ll.pruneAttempts(ip, windowStart)
  91. // 记录新尝试
  92. validAttempts = append(validAttempts, now)
  93. ll.attempts[ip] = validAttempts
  94. // 检查封禁条件
  95. if ll.policy.BanThreshold > 0 && len(validAttempts) >= ll.policy.BanThreshold {
  96. ll.banIP(ip, "excessive failed attempts")
  97. return
  98. }
  99. return
  100. }
  101. // 生成验证码
  102. func (ll *LoginLimiter) RequireCaptcha() (error, CaptchaMeta) {
  103. ll.mu.Lock()
  104. defer ll.mu.Unlock()
  105. if ll.provider == nil {
  106. return errors.New("no captcha provider available"), CaptchaMeta{}
  107. }
  108. id, content, answer, err := ll.provider.Generate()
  109. if err != nil {
  110. return err, CaptchaMeta{}
  111. }
  112. // 存储验证码
  113. ll.captchas[id] = CaptchaMeta{
  114. Id: id,
  115. Content: content,
  116. Answer: answer,
  117. ExpiresAt: time.Now().Add(ll.provider.Expiration()),
  118. }
  119. return nil, ll.captchas[id]
  120. }
  121. // 验证验证码
  122. func (ll *LoginLimiter) VerifyCaptcha(id, answer string) bool {
  123. ll.mu.Lock()
  124. defer ll.mu.Unlock()
  125. // 查找匹配验证码
  126. if ll.provider == nil {
  127. return false
  128. }
  129. // 获取并验证验证码
  130. captcha, exists := ll.captchas[id]
  131. if !exists {
  132. return false
  133. }
  134. // 清理过期验证码
  135. if time.Now().After(captcha.ExpiresAt) {
  136. delete(ll.captchas, id)
  137. return false
  138. }
  139. // 验证并清理状态
  140. if answer == captcha.Answer {
  141. delete(ll.captchas, id)
  142. return true
  143. }
  144. return false
  145. }
  146. func (ll *LoginLimiter) DrawCaptcha(content string) (err error, str string) {
  147. str, err = ll.provider.Draw(content)
  148. return
  149. }
  150. // 清除记录窗口
  151. func (ll *LoginLimiter) RemoveAttempts(ip string) {
  152. ll.mu.Lock()
  153. defer ll.mu.Unlock()
  154. _, exists := ll.attempts[ip]
  155. if exists {
  156. delete(ll.attempts, ip)
  157. }
  158. }
  159. // CheckSecurityStatus 检查安全状态
  160. func (ll *LoginLimiter) CheckSecurityStatus(ip string) (banned bool, captchaRequired bool) {
  161. if ll.isDisabled() {
  162. return
  163. }
  164. ll.mu.Lock()
  165. defer ll.mu.Unlock()
  166. // 检查封禁状态
  167. if banned, _ = ll.isBanned(ip); banned {
  168. return
  169. }
  170. // 清理过期数据
  171. ll.pruneAttempts(ip, time.Now().Add(-ll.policy.AttemptsWindow))
  172. // 检查验证码要求
  173. captchaRequired = len(ll.attempts[ip]) >= ll.policy.CaptchaThreshold
  174. return
  175. }
  176. // 后台清理任务
  177. func (ll *LoginLimiter) cleanupRoutine() {
  178. ticker := time.NewTicker(1 * time.Minute)
  179. defer ticker.Stop()
  180. for {
  181. select {
  182. case <-ticker.C:
  183. ll.cleanupExpired()
  184. case <-ll.cleanupStop:
  185. return
  186. }
  187. }
  188. }
  189. // 内部工具方法
  190. func (ll *LoginLimiter) isBanned(ip string) (bool, BanRecord) {
  191. record, exists := ll.bannedIPs[ip]
  192. if !exists {
  193. return false, BanRecord{}
  194. }
  195. if time.Now().After(record.ExpiresAt) {
  196. delete(ll.bannedIPs, ip)
  197. return false, BanRecord{}
  198. }
  199. return true, record
  200. }
  201. func (ll *LoginLimiter) banIP(ip, reason string) {
  202. ll.bannedIPs[ip] = BanRecord{
  203. ExpiresAt: time.Now().Add(ll.policy.BanDuration),
  204. Reason: reason,
  205. }
  206. delete(ll.attempts, ip)
  207. delete(ll.captchas, ip)
  208. }
  209. func (ll *LoginLimiter) pruneAttempts(ip string, cutoff time.Time) []time.Time {
  210. var valid []time.Time
  211. for _, t := range ll.attempts[ip] {
  212. if t.After(cutoff) {
  213. valid = append(valid, t)
  214. }
  215. }
  216. if len(valid) == 0 {
  217. delete(ll.attempts, ip)
  218. } else {
  219. ll.attempts[ip] = valid
  220. }
  221. return valid
  222. }
  223. func (ll *LoginLimiter) pruneCaptchas(id string) {
  224. if captcha, exists := ll.captchas[id]; exists {
  225. if time.Now().After(captcha.ExpiresAt) {
  226. delete(ll.captchas, id)
  227. }
  228. }
  229. }
  230. func (ll *LoginLimiter) cleanupExpired() {
  231. ll.mu.Lock()
  232. defer ll.mu.Unlock()
  233. now := time.Now()
  234. // 清理封禁记录
  235. for ip, record := range ll.bannedIPs {
  236. if now.After(record.ExpiresAt) {
  237. delete(ll.bannedIPs, ip)
  238. }
  239. }
  240. // 清理尝试记录
  241. for ip := range ll.attempts {
  242. ll.pruneAttempts(ip, now.Add(-ll.policy.AttemptsWindow))
  243. }
  244. // 清理验证码
  245. for id := range ll.captchas {
  246. ll.pruneCaptchas(id)
  247. }
  248. }