ouath.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. package api
  2. import (
  3. "Gwen/global"
  4. "Gwen/http/request/api"
  5. "Gwen/http/response"
  6. apiResp "Gwen/http/response/api"
  7. "Gwen/model"
  8. "Gwen/service"
  9. "github.com/gin-gonic/gin"
  10. "net/http"
  11. "strconv"
  12. "strings"
  13. )
  14. type Oauth struct {
  15. }
  16. // OidcAuth
  17. // @Tags Oauth
  18. // @Summary OidcAuth
  19. // @Description OidcAuth
  20. // @Accept json
  21. // @Produce json
  22. // @Success 200 {object} apiResp.LoginRes
  23. // @Failure 500 {object} response.ErrorResponse
  24. // @Router /oidc/auth [post]
  25. func (o *Oauth) OidcAuth(c *gin.Context) {
  26. f := &api.OidcAuthRequest{}
  27. err := c.ShouldBindJSON(&f)
  28. if err != nil {
  29. response.Error(c, response.TranslateMsg(c, "ParamsError")+err.Error())
  30. return
  31. }
  32. //fmt.Println(f)
  33. if f.Op != model.OauthTypeWebauth && f.Op != model.OauthTypeGoogle && f.Op != model.OauthTypeGithub && f.Op != model.OauthTypeOidc {
  34. response.Error(c, response.TranslateMsg(c, "ParamsError"))
  35. return
  36. }
  37. err, code, url := service.AllService.OauthService.BeginAuth(f.Op)
  38. if err != nil {
  39. response.Error(c, response.TranslateMsg(c, err.Error()))
  40. return
  41. }
  42. service.AllService.OauthService.SetOauthCache(code, &service.OauthCacheItem{
  43. Action: service.OauthActionTypeLogin,
  44. Id: f.Id,
  45. Op: f.Op,
  46. Uuid: f.Uuid,
  47. DeviceName: f.DeviceInfo.Name,
  48. DeviceOs: f.DeviceInfo.Os,
  49. DeviceType: f.DeviceInfo.Type,
  50. }, 5*60)
  51. //fmt.Println("code url", code, url)
  52. c.JSON(http.StatusOK, gin.H{
  53. "code": code,
  54. "url": url,
  55. })
  56. }
  57. func (o *Oauth) OidcAuthQueryPre(c *gin.Context) (*model.User, *model.UserToken) {
  58. var u *model.User
  59. var ut *model.UserToken
  60. q := &api.OidcAuthQuery{}
  61. // 解析查询参数并处理错误
  62. if err := c.ShouldBindQuery(q); err != nil {
  63. response.Error(c, response.TranslateMsg(c, "ParamsError")+": "+err.Error())
  64. return nil, nil
  65. }
  66. // 获取 OAuth 缓存
  67. v := service.AllService.OauthService.GetOauthCache(q.Code)
  68. if v == nil {
  69. response.Error(c, response.TranslateMsg(c, "OauthExpired"))
  70. return nil, nil
  71. }
  72. // 如果 UserId 为 0,说明还在授权中
  73. if v.UserId == 0 {
  74. c.JSON(http.StatusOK, gin.H{"message": "Authorization in progress, please login and bind"})
  75. return nil, nil
  76. }
  77. // 获取用户信息
  78. u = service.AllService.UserService.InfoById(v.UserId)
  79. if u == nil {
  80. response.Error(c, response.TranslateMsg(c, "UserNotFound"))
  81. return nil, nil
  82. }
  83. // 删除 OAuth 缓存
  84. service.AllService.OauthService.DeleteOauthCache(q.Code)
  85. // 创建登录日志并生成用户令牌
  86. ut = service.AllService.UserService.Login(u, &model.LoginLog{
  87. UserId: u.Id,
  88. Client: v.DeviceType,
  89. Uuid: v.Uuid,
  90. Ip: c.ClientIP(),
  91. Type: model.LoginLogTypeOauth,
  92. Platform: v.DeviceOs,
  93. })
  94. if ut == nil {
  95. response.Error(c, response.TranslateMsg(c, "LoginFailed"))
  96. return nil, nil
  97. }
  98. // 返回用户令牌
  99. return u, ut
  100. }
  101. // OidcAuthQuery
  102. // @Tags Oauth
  103. // @Summary OidcAuthQuery
  104. // @Description OidcAuthQuery
  105. // @Accept json
  106. // @Produce json
  107. // @Success 200 {object} apiResp.LoginRes
  108. // @Failure 500 {object} response.ErrorResponse
  109. // @Router /oidc/auth-query [get]
  110. func (o *Oauth) OidcAuthQuery(c *gin.Context) {
  111. u, ut := o.OidcAuthQueryPre(c)
  112. if u == nil || ut == nil {
  113. return
  114. }
  115. c.JSON(http.StatusOK, apiResp.LoginRes{
  116. AccessToken: ut.Token,
  117. Type: "access_token",
  118. User: *(&apiResp.UserPayload{}).FromUser(u),
  119. })
  120. }
  121. // OauthCallback 回调
  122. // @Tags Oauth
  123. // @Summary OauthCallback
  124. // @Description OauthCallback
  125. // @Accept json
  126. // @Produce json
  127. // @Success 200 {object} apiResp.LoginRes
  128. // @Failure 500 {object} response.ErrorResponse
  129. // @Router /oauth/callback [get]
  130. func (o *Oauth) OauthCallback(c *gin.Context) {
  131. state := c.Query("state")
  132. if state == "" {
  133. c.String(http.StatusInternalServerError, response.TranslateParamMsg(c, "ParamIsEmpty", "state"))
  134. return
  135. }
  136. cacheKey := state
  137. //从缓存中获取
  138. v := service.AllService.OauthService.GetOauthCache(cacheKey)
  139. if v == nil {
  140. c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthExpired"))
  141. return
  142. }
  143. ty := v.Op
  144. ac := v.Action
  145. var u *model.User
  146. openid := ""
  147. thirdName := ""
  148. //fmt.Println("ty ac ", ty, ac)
  149. if ty == model.OauthTypeGithub {
  150. code := c.Query("code")
  151. err, userData := service.AllService.OauthService.GithubCallback(code)
  152. if err != nil {
  153. c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
  154. return
  155. }
  156. openid = strconv.Itoa(userData.Id)
  157. thirdName = userData.Login
  158. } else if ty == model.OauthTypeGoogle {
  159. code := c.Query("code")
  160. err, userData := service.AllService.OauthService.GoogleCallback(code)
  161. if err != nil {
  162. c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
  163. return
  164. }
  165. openid = userData.Email
  166. //将空格替换成_
  167. thirdName = strings.Replace(userData.Name, " ", "_", -1)
  168. } else if ty == model.OauthTypeOidc {
  169. code := c.Query("code")
  170. err, userData := service.AllService.OauthService.OidcCallback(code)
  171. if err != nil {
  172. c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
  173. return
  174. }
  175. openid = userData.Sub
  176. thirdName = userData.PreferredUsername
  177. } else {
  178. c.String(http.StatusInternalServerError, response.TranslateMsg(c, "ParamsError"))
  179. return
  180. }
  181. if ac == service.OauthActionTypeBind {
  182. //fmt.Println("bind", ty, userData)
  183. utr := service.AllService.OauthService.UserThirdInfo(ty, openid)
  184. if utr.UserId > 0 {
  185. c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthHasBindOtherUser"))
  186. return
  187. }
  188. //绑定
  189. u = service.AllService.UserService.InfoById(v.UserId)
  190. if u == nil {
  191. c.String(http.StatusInternalServerError, response.TranslateMsg(c, "ItemNotFound"))
  192. return
  193. }
  194. //绑定
  195. err := service.AllService.OauthService.BindOauthUser(ty, openid, thirdName, v.UserId)
  196. if err != nil {
  197. c.String(http.StatusInternalServerError, response.TranslateMsg(c, "BindFail"))
  198. return
  199. }
  200. c.String(http.StatusOK, response.TranslateMsg(c, "BindSuccess"))
  201. return
  202. } else if ac == service.OauthActionTypeLogin {
  203. //登录
  204. if v.UserId != 0 {
  205. c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthHasBeenSuccess"))
  206. return
  207. }
  208. u = service.AllService.UserService.InfoByGithubId(openid)
  209. if u == nil {
  210. oa := service.AllService.OauthService.InfoByOp(ty)
  211. if !*oa.AutoRegister {
  212. //c.String(http.StatusInternalServerError, "还未绑定用户,请先绑定")
  213. v.ThirdName = thirdName
  214. v.ThirdOpenId = openid
  215. url := global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/bind/" + cacheKey
  216. c.Redirect(http.StatusFound, url)
  217. return
  218. }
  219. //自动注册
  220. u = service.AllService.UserService.RegisterByOauth(ty, thirdName, openid)
  221. if u.Id == 0 {
  222. c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthRegisterFailed"))
  223. return
  224. }
  225. }
  226. v.UserId = u.Id
  227. service.AllService.OauthService.SetOauthCache(cacheKey, v, 0)
  228. // 如果是webadmin,登录成功后跳转到webadmin
  229. if v.DeviceType == "webadmin" {
  230. /*service.AllService.UserService.Login(u, &model.LoginLog{
  231. UserId: u.Id,
  232. Client: "webadmin",
  233. Uuid: "", //must be empty
  234. Ip: c.ClientIP(),
  235. Type: model.LoginLogTypeOauth,
  236. Platform: v.DeviceOs,
  237. })*/
  238. url := global.Config.Rustdesk.ApiServer + "/_admin/#/"
  239. c.Redirect(http.StatusFound, url)
  240. return
  241. }
  242. c.String(http.StatusOK, response.TranslateMsg(c, "OauthSuccess"))
  243. return
  244. } else {
  245. c.String(http.StatusInternalServerError, response.TranslateMsg(c, "ParamsError"))
  246. return
  247. }
  248. }