login.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. package admin
  2. import (
  3. "fmt"
  4. "github.com/gin-gonic/gin"
  5. "github.com/lejianwen/rustdesk-api/v2/global"
  6. "github.com/lejianwen/rustdesk-api/v2/http/controller/api"
  7. "github.com/lejianwen/rustdesk-api/v2/http/request/admin"
  8. apiReq "github.com/lejianwen/rustdesk-api/v2/http/request/api"
  9. "github.com/lejianwen/rustdesk-api/v2/http/response"
  10. adResp "github.com/lejianwen/rustdesk-api/v2/http/response/admin"
  11. "github.com/lejianwen/rustdesk-api/v2/model"
  12. "github.com/lejianwen/rustdesk-api/v2/service"
  13. "github.com/mojocn/base64Captcha"
  14. "sync"
  15. "time"
  16. )
  17. type Login struct {
  18. }
  19. // Captcha 验证码结构
  20. type Captcha struct {
  21. Id string `json:"id"` // 验证码 ID
  22. B64 string `json:"b64"` // base64 验证码
  23. Code string `json:"-"` // 验证码内容
  24. ExpiresAt time.Time `json:"-"` // 过期时间
  25. }
  26. type LoginLimiter struct {
  27. mu sync.RWMutex
  28. failCount map[string]int // 记录每个 IP 的失败次数
  29. timestamp map[string]time.Time // 记录每个 IP 的最后失败时间
  30. captchas map[string]Captcha // 每个 IP 的验证码
  31. threshold int // 失败阈值
  32. expiry time.Duration // 失败记录过期时间
  33. }
  34. func NewLoginLimiter(threshold int, expiry time.Duration) *LoginLimiter {
  35. return &LoginLimiter{
  36. failCount: make(map[string]int),
  37. timestamp: make(map[string]time.Time),
  38. captchas: make(map[string]Captcha),
  39. threshold: threshold,
  40. expiry: expiry,
  41. }
  42. }
  43. // RecordFailure 记录登录失败
  44. func (l *LoginLimiter) RecordFailure(ip string) {
  45. l.mu.Lock()
  46. defer l.mu.Unlock()
  47. // 如果该 IP 的记录已经过期,重置计数
  48. if lastTime, exists := l.timestamp[ip]; exists && time.Since(lastTime) > l.expiry {
  49. l.failCount[ip] = 0
  50. }
  51. // 更新失败次数和时间戳
  52. l.failCount[ip]++
  53. l.timestamp[ip] = time.Now()
  54. }
  55. // NeedsCaptcha 检查是否需要验证码
  56. func (l *LoginLimiter) NeedsCaptcha(ip string) bool {
  57. l.mu.RLock()
  58. defer l.mu.RUnlock()
  59. // 检查记录是否存在且未过期
  60. if lastTime, exists := l.timestamp[ip]; exists && time.Since(lastTime) <= l.expiry {
  61. return l.failCount[ip] >= l.threshold
  62. }
  63. return false
  64. }
  65. // GenerateCaptcha 为指定 IP 生成验证码
  66. func (l *LoginLimiter) GenerateCaptcha(ip string) Captcha {
  67. l.mu.Lock()
  68. defer l.mu.Unlock()
  69. capd := base64Captcha.NewDriverString(50, 150, 5, 10, 4, "1234567890abcdefghijklmnopqrstuvwxyz", nil, nil, nil)
  70. b64cap := base64Captcha.NewCaptcha(capd, base64Captcha.DefaultMemStore)
  71. id, b64s, answer, err := b64cap.Generate()
  72. if err != nil {
  73. global.Logger.Error("Generate captcha failed: " + err.Error())
  74. return Captcha{}
  75. }
  76. // 保存验证码到对应 IP
  77. l.captchas[ip] = Captcha{
  78. Id: id,
  79. B64: b64s,
  80. Code: answer,
  81. ExpiresAt: time.Now().Add(5 * time.Minute),
  82. }
  83. return l.captchas[ip]
  84. }
  85. // VerifyCaptcha 验证指定 IP 的验证码
  86. func (l *LoginLimiter) VerifyCaptcha(ip, code string) bool {
  87. l.mu.RLock()
  88. defer l.mu.RUnlock()
  89. // 检查验证码是否存在且未过期
  90. if captcha, exists := l.captchas[ip]; exists && time.Now().Before(captcha.ExpiresAt) {
  91. return captcha.Code == code
  92. }
  93. return false
  94. }
  95. // RemoveCaptcha 移除指定 IP 的验证码
  96. func (l *LoginLimiter) RemoveCaptcha(ip string) {
  97. l.mu.Lock()
  98. defer l.mu.Unlock()
  99. delete(l.captchas, ip)
  100. }
  101. // CleanupExpired 清理过期的记录
  102. func (l *LoginLimiter) CleanupExpired() {
  103. l.mu.Lock()
  104. defer l.mu.Unlock()
  105. now := time.Now()
  106. for ip, lastTime := range l.timestamp {
  107. if now.Sub(lastTime) > l.expiry {
  108. delete(l.failCount, ip)
  109. delete(l.timestamp, ip)
  110. delete(l.captchas, ip)
  111. }
  112. }
  113. }
  114. func (l *LoginLimiter) RemoveRecord(ip string) {
  115. l.mu.Lock()
  116. defer l.mu.Unlock()
  117. delete(l.failCount, ip)
  118. delete(l.timestamp, ip)
  119. delete(l.captchas, ip)
  120. }
  121. var loginLimiter = NewLoginLimiter(3, 5*time.Minute)
  122. // Login 登录
  123. // @Tags 登录
  124. // @Summary 登录
  125. // @Description 登录
  126. // @Accept json
  127. // @Produce json
  128. // @Param body body admin.Login true "登录信息"
  129. // @Success 200 {object} response.Response{data=adResp.LoginPayload}
  130. // @Failure 500 {object} response.Response
  131. // @Router /admin/login [post]
  132. // @Security token
  133. func (ct *Login) Login(c *gin.Context) {
  134. if global.Config.App.DisablePwdLogin {
  135. response.Fail(c, 101, response.TranslateMsg(c, "PwdLoginDisabled"))
  136. return
  137. }
  138. f := &admin.Login{}
  139. err := c.ShouldBindJSON(f)
  140. clientIp := c.ClientIP()
  141. if err != nil {
  142. global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), clientIp))
  143. response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
  144. return
  145. }
  146. errList := global.Validator.ValidStruct(c, f)
  147. if len(errList) > 0 {
  148. global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "ParamsError", c.RemoteIP(), clientIp))
  149. response.Fail(c, 101, errList[0])
  150. return
  151. }
  152. // 检查是否需要验证码
  153. if loginLimiter.NeedsCaptcha(clientIp) {
  154. if f.Captcha == "" || !loginLimiter.VerifyCaptcha(clientIp, f.Captcha) {
  155. response.Fail(c, 101, response.TranslateMsg(c, "CaptchaError"))
  156. return
  157. }
  158. }
  159. u := service.AllService.UserService.InfoByUsernamePassword(f.Username, f.Password)
  160. if u.Id == 0 {
  161. global.Logger.Warn(fmt.Sprintf("Login Fail: %s %s %s", "UsernameOrPasswordError", c.RemoteIP(), clientIp))
  162. loginLimiter.RecordFailure(clientIp)
  163. if loginLimiter.NeedsCaptcha(clientIp) {
  164. loginLimiter.RemoveCaptcha(clientIp)
  165. }
  166. response.Fail(c, 101, response.TranslateMsg(c, "UsernameOrPasswordError"))
  167. return
  168. }
  169. if !service.AllService.UserService.CheckUserEnable(u) {
  170. if loginLimiter.NeedsCaptcha(clientIp) {
  171. loginLimiter.RemoveCaptcha(clientIp)
  172. }
  173. response.Fail(c, 101, response.TranslateMsg(c, "UserDisabled"))
  174. return
  175. }
  176. ut := service.AllService.UserService.Login(u, &model.LoginLog{
  177. UserId: u.Id,
  178. Client: model.LoginLogClientWebAdmin,
  179. Uuid: "", //must be empty
  180. Ip: clientIp,
  181. Type: model.LoginLogTypeAccount,
  182. Platform: f.Platform,
  183. })
  184. // 成功后清除记录
  185. loginLimiter.RemoveRecord(clientIp)
  186. // 清理过期记录
  187. go loginLimiter.CleanupExpired()
  188. responseLoginSuccess(c, u, ut.Token)
  189. }
  190. func (ct *Login) Captcha(c *gin.Context) {
  191. clientIp := c.ClientIP()
  192. if !loginLimiter.NeedsCaptcha(clientIp) {
  193. response.Fail(c, 101, response.TranslateMsg(c, "NoCaptchaRequired"))
  194. return
  195. }
  196. captcha := loginLimiter.GenerateCaptcha(clientIp)
  197. response.Success(c, gin.H{
  198. "captcha": captcha,
  199. })
  200. }
  201. // Logout 登出
  202. // @Tags 登录
  203. // @Summary 登出
  204. // @Description 登出
  205. // @Accept json
  206. // @Produce json
  207. // @Success 200 {object} response.Response
  208. // @Failure 500 {object} response.Response
  209. // @Router /admin/logout [post]
  210. func (ct *Login) Logout(c *gin.Context) {
  211. u := service.AllService.UserService.CurUser(c)
  212. token, ok := c.Get("token")
  213. if ok {
  214. service.AllService.UserService.Logout(u, token.(string))
  215. }
  216. response.Success(c, nil)
  217. }
  218. // LoginOptions
  219. // @Tags 登录
  220. // @Summary 登录选项
  221. // @Description 登录选项
  222. // @Accept json
  223. // @Produce json
  224. // @Success 200 {object} []string
  225. // @Failure 500 {object} response.ErrorResponse
  226. // @Router /admin/login-options [post]
  227. func (ct *Login) LoginOptions(c *gin.Context) {
  228. ip := c.ClientIP()
  229. ops := service.AllService.OauthService.GetOauthProviders()
  230. response.Success(c, gin.H{
  231. "ops": ops,
  232. "register": global.Config.App.Register,
  233. "need_captcha": loginLimiter.NeedsCaptcha(ip),
  234. })
  235. }
  236. // OidcAuth
  237. // @Tags Oauth
  238. // @Summary OidcAuth
  239. // @Description OidcAuth
  240. // @Accept json
  241. // @Produce json
  242. // @Router /admin/oidc/auth [post]
  243. func (ct *Login) OidcAuth(c *gin.Context) {
  244. // o := &api.Oauth{}
  245. // o.OidcAuth(c)
  246. f := &apiReq.OidcAuthRequest{}
  247. err := c.ShouldBindJSON(f)
  248. if err != nil {
  249. response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
  250. return
  251. }
  252. err, state, verifier, nonce, url := service.AllService.OauthService.BeginAuth(f.Op)
  253. if err != nil {
  254. response.Error(c, response.TranslateMsg(c, err.Error()))
  255. return
  256. }
  257. service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{
  258. Action: service.OauthActionTypeLogin,
  259. Op: f.Op,
  260. Id: f.Id,
  261. DeviceType: "webadmin",
  262. // DeviceOs: ct.Platform(c),
  263. DeviceOs: f.DeviceInfo.Os,
  264. Uuid: f.Uuid,
  265. Verifier: verifier,
  266. Nonce: nonce,
  267. }, 5*60)
  268. response.Success(c, gin.H{
  269. "code": state,
  270. "url": url,
  271. })
  272. }
  273. // OidcAuthQuery
  274. // @Tags Oauth
  275. // @Summary OidcAuthQuery
  276. // @Description OidcAuthQuery
  277. // @Accept json
  278. // @Produce json
  279. // @Success 200 {object} response.Response{data=adResp.LoginPayload}
  280. // @Failure 500 {object} response.Response
  281. // @Router /admin/oidc/auth-query [get]
  282. func (ct *Login) OidcAuthQuery(c *gin.Context) {
  283. o := &api.Oauth{}
  284. u, ut := o.OidcAuthQueryPre(c)
  285. if ut == nil {
  286. return
  287. }
  288. responseLoginSuccess(c, u, ut.Token)
  289. }
  290. func responseLoginSuccess(c *gin.Context, u *model.User, token string) {
  291. lp := &adResp.LoginPayload{}
  292. lp.FromUser(u)
  293. lp.Token = token
  294. lp.RouteNames = service.AllService.UserService.RouteNames(u)
  295. response.Success(c, lp)
  296. }