Browse Source

Merge branch 'oauth_re' of https://github.com/IamTaoChen/rustdesk-api into IamTaoChen-oauth_re

ljw 1 year ago
parent
commit
8af01c859c

+ 4 - 1
Dockerfile.dev

@@ -42,8 +42,11 @@ RUN if [ "$COUNTRY" = "CN" ] ; then \
42
     fi && \
42
     fi && \
43
     apk update && apk add --no-cache git
43
     apk update && apk add --no-cache git
44
 
44
 
45
+ARG FREONTEND_GIT_REPO=https://github.com/lejianwen/rustdesk-api-web.git
46
+ARG FRONTEND_GIT_BRANCH=master
45
 # Clone the frontend repository
47
 # Clone the frontend repository
46
-RUN git clone https://github.com/lejianwen/rustdesk-api-web .
48
+
49
+RUN git clone -b $FRONTEND_GIT_BRANCH $FREONTEND_GIT_REPO .
47
 
50
 
48
 # Install required tools without caching index to minimize image size
51
 # Install required tools without caching index to minimize image size
49
 RUN if [ "$COUNTRY" = "CN" ] ; then \
52
 RUN if [ "$COUNTRY" = "CN" ] ; then \

+ 2 - 0
docker-compose-dev.yaml

@@ -5,6 +5,8 @@ services:
5
       dockerfile: Dockerfile.dev
5
       dockerfile: Dockerfile.dev
6
       args:
6
       args:
7
         COUNTRY: CN
7
         COUNTRY: CN
8
+        FREONTEND_GIT_REPO: https://github.com/lejianwen/rustdesk-api-web.git
9
+        FRONTEND_GIT_BRANCH: master
8
     # image: lejianwen/rustdesk-api
10
     # image: lejianwen/rustdesk-api
9
     container_name: rustdesk-api
11
     container_name: rustdesk-api
10
     environment:
12
     environment:

+ 12 - 22
http/controller/admin/login.go

@@ -11,7 +11,6 @@ import (
11
 	"Gwen/service"
11
 	"Gwen/service"
12
 	"fmt"
12
 	"fmt"
13
 	"github.com/gin-gonic/gin"
13
 	"github.com/gin-gonic/gin"
14
-	"gorm.io/gorm"
15
 )
14
 )
16
 
15
 
17
 type Login struct {
16
 type Login struct {
@@ -60,12 +59,7 @@ func (ct *Login) Login(c *gin.Context) {
60
 		Platform: f.Platform,
59
 		Platform: f.Platform,
61
 	})
60
 	})
62
 
61
 
63
-	response.Success(c, &adResp.LoginPayload{
64
-		Token:      ut.Token,
65
-		Username:   u.Username,
66
-		RouteNames: service.AllService.UserService.RouteNames(u),
67
-		Nickname:   u.Nickname,
68
-	})
62
+	responseLoginSuccess(c, u, ut.Token)
69
 }
63
 }
70
 
64
 
71
 // Logout 登出
65
 // Logout 登出
@@ -96,13 +90,7 @@ func (ct *Login) Logout(c *gin.Context) {
96
 // @Failure 500 {object} response.ErrorResponse
90
 // @Failure 500 {object} response.ErrorResponse
97
 // @Router /admin/login-options [post]
91
 // @Router /admin/login-options [post]
98
 func (ct *Login) LoginOptions(c *gin.Context) {
92
 func (ct *Login) LoginOptions(c *gin.Context) {
99
-	res := service.AllService.OauthService.List(1, 100, func(tx *gorm.DB) {
100
-		tx.Select("op").Order("id")
101
-	})
102
-	var ops []string
103
-	for _, v := range res.Oauths {
104
-		ops = append(ops, v.Op)
105
-	}
93
+	ops := service.AllService.OauthService.GetOauthProviders()
106
 	response.Success(c, gin.H{
94
 	response.Success(c, gin.H{
107
 		"ops":      ops,
95
 		"ops":      ops,
108
 		"register": global.Config.App.Register,
96
 		"register": global.Config.App.Register,
@@ -163,12 +151,14 @@ func (ct *Login) OidcAuthQuery(c *gin.Context) {
163
 	if ut == nil {
151
 	if ut == nil {
164
 		return
152
 		return
165
 	}
153
 	}
166
-	//fmt.Println("u:", u)
167
-	//fmt.Println("ut:", ut)
168
-	response.Success(c, &adResp.LoginPayload{
169
-		Token:      ut.Token,
170
-		Username:   u.Username,
171
-		RouteNames: service.AllService.UserService.RouteNames(u),
172
-		Nickname:   u.Nickname,
173
-	})
154
+	responseLoginSuccess(c, u, ut.Token)
174
 }
155
 }
156
+
157
+
158
+func responseLoginSuccess(c *gin.Context, u *model.User, token string) {
159
+	lp := &adResp.LoginPayload{}
160
+	lp.FromUser(u)
161
+	lp.Token = token
162
+	lp.RouteNames = service.AllService.UserService.RouteNames(u)
163
+	response.Success(c, lp)
164
+}

+ 13 - 29
http/controller/admin/oauth.go

@@ -5,7 +5,6 @@ import (
5
 	"Gwen/http/request/admin"
5
 	"Gwen/http/request/admin"
6
 	adminReq "Gwen/http/request/admin"
6
 	adminReq "Gwen/http/request/admin"
7
 	"Gwen/http/response"
7
 	"Gwen/http/response"
8
-	"Gwen/model"
9
 	"Gwen/service"
8
 	"Gwen/service"
10
 	"github.com/gin-gonic/gin"
9
 	"github.com/gin-gonic/gin"
11
 	"strconv"
10
 	"strconv"
@@ -96,21 +95,23 @@ func (o *Oauth) BindConfirm(c *gin.Context) {
96
 		response.Fail(c, 101, response.TranslateMsg(c, "ParamsError"))
95
 		response.Fail(c, 101, response.TranslateMsg(c, "ParamsError"))
97
 		return
96
 		return
98
 	}
97
 	}
99
-	v := service.AllService.OauthService.GetOauthCache(j.Code)
100
-	if v == nil {
98
+	oauthService := service.AllService.OauthService
99
+	oauthCache := oauthService.GetOauthCache(j.Code)
100
+	if oauthCache == nil {
101
 		response.Fail(c, 101, response.TranslateMsg(c, "OauthExpired"))
101
 		response.Fail(c, 101, response.TranslateMsg(c, "OauthExpired"))
102
 		return
102
 		return
103
 	}
103
 	}
104
-	u := service.AllService.UserService.CurUser(c)
105
-	err = service.AllService.OauthService.BindOauthUser(v.Op, v.ThirdOpenId, v.ThirdName, u.Id)
104
+	oauthUser := oauthCache.ToOauthUser()
105
+	user := service.AllService.UserService.CurUser(c)
106
+	err = oauthService.BindOauthUser(user.Id, oauthUser, oauthCache.Op)
106
 	if err != nil {
107
 	if err != nil {
107
 		response.Fail(c, 101, response.TranslateMsg(c, "BindFail"))
108
 		response.Fail(c, 101, response.TranslateMsg(c, "BindFail"))
108
 		return
109
 		return
109
 	}
110
 	}
110
 
111
 
111
-	v.UserId = u.Id
112
-	service.AllService.OauthService.SetOauthCache(j.Code, v, 0)
113
-	response.Success(c, v)
112
+	oauthCache.UserId = user.Id
113
+	oauthService.SetOauthCache(j.Code, oauthCache, 0)
114
+	response.Success(c, oauthCache)
114
 }
115
 }
115
 
116
 
116
 func (o *Oauth) Unbind(c *gin.Context) {
117
 func (o *Oauth) Unbind(c *gin.Context) {
@@ -126,28 +127,11 @@ func (o *Oauth) Unbind(c *gin.Context) {
126
 		response.Fail(c, 101, response.TranslateMsg(c, "ItemNotFound"))
127
 		response.Fail(c, 101, response.TranslateMsg(c, "ItemNotFound"))
127
 		return
128
 		return
128
 	}
129
 	}
129
-	if f.Op == model.OauthTypeGithub {
130
-		err = service.AllService.OauthService.UnBindGithubUser(u.Id)
131
-		if err != nil {
132
-			response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
133
-			return
134
-		}
135
-	}
136
-	if f.Op == model.OauthTypeGoogle {
137
-		err = service.AllService.OauthService.UnBindGoogleUser(u.Id)
138
-		if err != nil {
139
-			response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
140
-			return
141
-		}
142
-	}
143
-	if f.Op == model.OauthTypeOidc {
144
-		err = service.AllService.OauthService.UnBindOidcUser(u.Id)
145
-		if err != nil {
146
-			response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
147
-			return
148
-		}
130
+	err = service.AllService.OauthService.UnBindOauthUser(u.Id, f.Op)
131
+	if err != nil {
132
+		response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
133
+		return
149
 	}
134
 	}
150
-
151
 	response.Success(c, nil)
135
 	response.Success(c, nil)
152
 }
136
 }
153
 
137
 

+ 51 - 15
http/controller/admin/user.go

@@ -10,6 +10,7 @@ import (
10
 	"github.com/gin-gonic/gin"
10
 	"github.com/gin-gonic/gin"
11
 	"gorm.io/gorm"
11
 	"gorm.io/gorm"
12
 	"strconv"
12
 	"strconv"
13
+	"time"
13
 )
14
 )
14
 
15
 
15
 type User struct {
16
 type User struct {
@@ -216,12 +217,7 @@ func (ct *User) Current(c *gin.Context) {
216
 	u := service.AllService.UserService.CurUser(c)
217
 	u := service.AllService.UserService.CurUser(c)
217
 	token, _ := c.Get("token")
218
 	token, _ := c.Get("token")
218
 	t := token.(string)
219
 	t := token.(string)
219
-	response.Success(c, &adResp.LoginPayload{
220
-		Token:      t,
221
-		Username:   u.Username,
222
-		RouteNames: service.AllService.UserService.RouteNames(u),
223
-		Nickname:   u.Nickname,
224
-	})
220
+	responseLoginSuccess(c, u, t)
225
 }
221
 }
226
 
222
 
227
 // ChangeCurPwd 修改当前用户密码
223
 // ChangeCurPwd 修改当前用户密码
@@ -286,10 +282,10 @@ func (ct *User) MyOauth(c *gin.Context) {
286
 	var res []*adResp.UserOauthItem
282
 	var res []*adResp.UserOauthItem
287
 	for _, oa := range oal.Oauths {
283
 	for _, oa := range oal.Oauths {
288
 		item := &adResp.UserOauthItem{
284
 		item := &adResp.UserOauthItem{
289
-			ThirdType: oa.Op,
285
+			Op: oa.Op,
290
 		}
286
 		}
291
 		for _, ut := range uts {
287
 		for _, ut := range uts {
292
-			if ut.ThirdType == oa.Op {
288
+			if ut.Op == oa.Op {
293
 				item.Status = 1
289
 				item.Status = 1
294
 				break
290
 				break
295
 			}
291
 			}
@@ -299,6 +295,51 @@ func (ct *User) MyOauth(c *gin.Context) {
299
 	response.Success(c, res)
295
 	response.Success(c, res)
300
 }
296
 }
301
 
297
 
298
+// List 列表
299
+// @Tags 设备
300
+// @Summary 设备列表
301
+// @Description 设备列表
302
+// @Accept  json
303
+// @Produce  json
304
+// @Param page query int false "页码"
305
+// @Param page_size query int false "页大小"
306
+// @Param time_ago query int false "时间"
307
+// @Param id query string false "ID"
308
+// @Param hostname query string false "主机名"
309
+// @Param uuids query string false "uuids 用逗号分隔"
310
+// @Success 200 {object} response.Response{data=model.PeerList}
311
+// @Failure 500 {object} response.Response
312
+// @Router /admin/user/myPeer [get]
313
+// @Security token
314
+func (ct *User) MyPeer(c *gin.Context) {
315
+	query := &admin.PeerQuery{}
316
+	if err := c.ShouldBindQuery(query); err != nil {
317
+		response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
318
+		return
319
+	}
320
+	u := service.AllService.UserService.CurUser(c)
321
+	res := service.AllService.PeerService.ListFilterByUserId(query.Page, query.PageSize, func(tx *gorm.DB) {
322
+		if query.TimeAgo > 0 {
323
+			lt := time.Now().Unix() - int64(query.TimeAgo)
324
+			tx.Where("last_online_time < ?", lt)
325
+		}
326
+		if query.TimeAgo < 0 {
327
+			lt := time.Now().Unix() + int64(query.TimeAgo)
328
+			tx.Where("last_online_time > ?", lt)
329
+		}
330
+		if query.Id != "" {
331
+			tx.Where("id like ?", "%"+query.Id+"%")
332
+		}
333
+		if query.Hostname != "" {
334
+			tx.Where("hostname like ?", "%"+query.Hostname+"%")
335
+		}
336
+		if query.Uuids != "" {
337
+			tx.Where("uuid in (?)", query.Uuids)
338
+		}
339
+	}, u.Id)
340
+	response.Success(c, res)
341
+}
342
+
302
 // groupUsers
343
 // groupUsers
303
 func (ct *User) GroupUsers(c *gin.Context) {
344
 func (ct *User) GroupUsers(c *gin.Context) {
304
 	q := &admin.GroupUsersQuery{}
345
 	q := &admin.GroupUsersQuery{}
@@ -345,7 +386,7 @@ func (ct *User) Register(c *gin.Context) {
345
 		response.Fail(c, 101, errList[0])
386
 		response.Fail(c, 101, errList[0])
346
 		return
387
 		return
347
 	}
388
 	}
348
-	u := service.AllService.UserService.Register(f.Username, f.Password)
389
+	u := service.AllService.UserService.Register(f.Username, f.Email, f.Password)
349
 	if u == nil || u.Id == 0 {
390
 	if u == nil || u.Id == 0 {
350
 		response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed"))
391
 		response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed"))
351
 		return
392
 		return
@@ -358,10 +399,5 @@ func (ct *User) Register(c *gin.Context) {
358
 		Ip:     c.ClientIP(),
399
 		Ip:     c.ClientIP(),
359
 		Type:   model.LoginLogTypeAccount,
400
 		Type:   model.LoginLogTypeAccount,
360
 	})
401
 	})
361
-	response.Success(c, &adResp.LoginPayload{
362
-		Token:      ut.Token,
363
-		Username:   u.Username,
364
-		RouteNames: service.AllService.UserService.RouteNames(u),
365
-		Nickname:   u.Nickname,
366
-	})
402
+	responseLoginSuccess(c, u, ut.Token)
367
 }
403
 }

+ 5 - 16
http/controller/api/login.go

@@ -60,6 +60,7 @@ func (l *Login) Login(c *gin.Context) {
60
 	ut := service.AllService.UserService.Login(u, &model.LoginLog{
60
 	ut := service.AllService.UserService.Login(u, &model.LoginLog{
61
 		UserId:   u.Id,
61
 		UserId:   u.Id,
62
 		Client:   f.DeviceInfo.Type,
62
 		Client:   f.DeviceInfo.Type,
63
+		DeviceId: f.Id,
63
 		Uuid:     f.Uuid,
64
 		Uuid:     f.Uuid,
64
 		Ip:       c.ClientIP(),
65
 		Ip:       c.ClientIP(),
65
 		Type:     model.LoginLogTypeAccount,
66
 		Type:     model.LoginLogTypeAccount,
@@ -83,22 +84,10 @@ func (l *Login) Login(c *gin.Context) {
83
 // @Failure 500 {object} response.ErrorResponse
84
 // @Failure 500 {object} response.ErrorResponse
84
 // @Router /login-options [get]
85
 // @Router /login-options [get]
85
 func (l *Login) LoginOptions(c *gin.Context) {
86
 func (l *Login) LoginOptions(c *gin.Context) {
86
-	oauthOks := []string{}
87
-	err, _ := service.AllService.OauthService.GetOauthConfig(model.OauthTypeGithub)
88
-	if err == nil {
89
-		oauthOks = append(oauthOks, model.OauthTypeGithub)
90
-	}
91
-	err, _ = service.AllService.OauthService.GetOauthConfig(model.OauthTypeGoogle)
92
-	if err == nil {
93
-		oauthOks = append(oauthOks, model.OauthTypeGoogle)
94
-	}
95
-	err, _ = service.AllService.OauthService.GetOauthConfig(model.OauthTypeOidc)
96
-	if err == nil {
97
-		oauthOks = append(oauthOks, model.OauthTypeOidc)
98
-	}
99
-	oauthOks = append(oauthOks, model.OauthTypeWebauth)
87
+	ops := service.AllService.OauthService.GetOauthProviders()
88
+	ops = append(ops, model.OauthTypeWebauth)
100
 	var oidcItems []map[string]string
89
 	var oidcItems []map[string]string
101
-	for _, v := range oauthOks {
90
+	for _, v := range ops {
102
 		oidcItems = append(oidcItems, map[string]string{"name": v})
91
 		oidcItems = append(oidcItems, map[string]string{"name": v})
103
 	}
92
 	}
104
 	common, err := json.Marshal(oidcItems)
93
 	common, err := json.Marshal(oidcItems)
@@ -108,7 +97,7 @@ func (l *Login) LoginOptions(c *gin.Context) {
108
 	}
97
 	}
109
 	var res []string
98
 	var res []string
110
 	res = append(res, "common-oidc/"+string(common))
99
 	res = append(res, "common-oidc/"+string(common))
111
-	for _, v := range oauthOks {
100
+	for _, v := range ops {
112
 		res = append(res, "oidc/"+v)
101
 		res = append(res, "oidc/"+v)
113
 	}
102
 	}
114
 	c.JSON(http.StatusOK, res)
103
 	c.JSON(http.StatusOK, res)

+ 38 - 69
http/controller/api/ouath.go

@@ -9,8 +9,6 @@ import (
9
 	"Gwen/service"
9
 	"Gwen/service"
10
 	"github.com/gin-gonic/gin"
10
 	"github.com/gin-gonic/gin"
11
 	"net/http"
11
 	"net/http"
12
-	"strconv"
13
-	"strings"
14
 )
12
 )
15
 
13
 
16
 type Oauth struct {
14
 type Oauth struct {
@@ -32,13 +30,11 @@ func (o *Oauth) OidcAuth(c *gin.Context) {
32
 		response.Error(c, response.TranslateMsg(c, "ParamsError")+err.Error())
30
 		response.Error(c, response.TranslateMsg(c, "ParamsError")+err.Error())
33
 		return
31
 		return
34
 	}
32
 	}
35
-	//fmt.Println(f)
36
-	if f.Op != model.OauthTypeWebauth && f.Op != model.OauthTypeGoogle && f.Op != model.OauthTypeGithub && f.Op != model.OauthTypeOidc {
37
-		response.Error(c, response.TranslateMsg(c, "ParamsError"))
38
-		return
39
-	}
40
 
33
 
41
-	err, code, url := service.AllService.OauthService.BeginAuth(f.Op)
34
+	oauthService := service.AllService.OauthService
35
+	var code string
36
+	var url string
37
+	err, code, url = oauthService.BeginAuth(f.Op)
42
 	if err != nil {
38
 	if err != nil {
43
 		response.Error(c, response.TranslateMsg(c, err.Error()))
39
 		response.Error(c, response.TranslateMsg(c, err.Error()))
44
 		return
40
 		return
@@ -98,6 +94,7 @@ func (o *Oauth) OidcAuthQueryPre(c *gin.Context) (*model.User, *model.UserToken)
98
 	ut = service.AllService.UserService.Login(u, &model.LoginLog{
94
 	ut = service.AllService.UserService.Login(u, &model.LoginLog{
99
 		UserId:   u.Id,
95
 		UserId:   u.Id,
100
 		Client:   v.DeviceType,
96
 		Client:   v.DeviceType,
97
+		DeviceId: v.Id,
101
 		Uuid:     v.Uuid,
98
 		Uuid:     v.Uuid,
102
 		Ip:       c.ClientIP(),
99
 		Ip:       c.ClientIP(),
103
 		Type:     model.LoginLogTypeOauth,
100
 		Type:     model.LoginLogTypeOauth,
@@ -149,70 +146,43 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
149
 		c.String(http.StatusInternalServerError, response.TranslateParamMsg(c, "ParamIsEmpty", "state"))
146
 		c.String(http.StatusInternalServerError, response.TranslateParamMsg(c, "ParamIsEmpty", "state"))
150
 		return
147
 		return
151
 	}
148
 	}
152
-
153
 	cacheKey := state
149
 	cacheKey := state
150
+	oauthService := service.AllService.OauthService
154
 	//从缓存中获取
151
 	//从缓存中获取
155
-	v := service.AllService.OauthService.GetOauthCache(cacheKey)
156
-	if v == nil {
152
+	oauthCache := oauthService.GetOauthCache(cacheKey)
153
+	if oauthCache == nil {
157
 		c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthExpired"))
154
 		c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthExpired"))
158
 		return
155
 		return
159
 	}
156
 	}
160
-
161
-	ty := v.Op
162
-	ac := v.Action
163
-	var u *model.User
164
-	openid := ""
165
-	thirdName := ""
166
-	//fmt.Println("ty ac ", ty, ac)
167
-
168
-	if ty == model.OauthTypeGithub {
169
-		code := c.Query("code")
170
-		err, userData := service.AllService.OauthService.GithubCallback(code)
171
-		if err != nil {
172
-			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
173
-			return
174
-		}
175
-		openid = strconv.Itoa(userData.Id)
176
-		thirdName = userData.Login
177
-	} else if ty == model.OauthTypeGoogle {
178
-		code := c.Query("code")
179
-		err, userData := service.AllService.OauthService.GoogleCallback(code)
180
-		if err != nil {
181
-			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
182
-			return
183
-		}
184
-		openid = userData.Email
185
-		//将空格替换成_
186
-		thirdName = strings.Replace(userData.Name, " ", "_", -1)
187
-	} else if ty == model.OauthTypeOidc {
188
-		code := c.Query("code")
189
-		err, userData := service.AllService.OauthService.OidcCallback(code)
190
-		if err != nil {
191
-			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
192
-			return
193
-		}
194
-		openid = userData.Sub
195
-		thirdName = userData.PreferredUsername
196
-	} else {
197
-		c.String(http.StatusInternalServerError, response.TranslateMsg(c, "ParamsError"))
157
+	op := oauthCache.Op
158
+	action := oauthCache.Action
159
+	var user *model.User
160
+	// 获取用户信息
161
+	code := c.Query("code")
162
+	err, oauthUser := oauthService.Callback(code, op)
163
+	if err != nil {
164
+		c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
198
 		return
165
 		return
199
 	}
166
 	}
200
-	if ac == service.OauthActionTypeBind {
167
+	userId := oauthCache.UserId
168
+	openid := oauthUser.OpenId
169
+	if action == service.OauthActionTypeBind {
201
 
170
 
202
 		//fmt.Println("bind", ty, userData)
171
 		//fmt.Println("bind", ty, userData)
203
-		utr := service.AllService.OauthService.UserThirdInfo(ty, openid)
172
+		// 检查此openid是否已经绑定过
173
+		utr := oauthService.UserThirdInfo(op, openid)
204
 		if utr.UserId > 0 {
174
 		if utr.UserId > 0 {
205
 			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthHasBindOtherUser"))
175
 			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthHasBindOtherUser"))
206
 			return
176
 			return
207
 		}
177
 		}
208
 		//绑定
178
 		//绑定
209
-		u = service.AllService.UserService.InfoById(v.UserId)
210
-		if u == nil {
179
+		user = service.AllService.UserService.InfoById(userId)
180
+		if user == nil {
211
 			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "ItemNotFound"))
181
 			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "ItemNotFound"))
212
 			return
182
 			return
213
 		}
183
 		}
214
 		//绑定
184
 		//绑定
215
-		err := service.AllService.OauthService.BindOauthUser(ty, openid, thirdName, v.UserId)
185
+		err := oauthService.BindOauthUser(userId, oauthUser, op)
216
 		if err != nil {
186
 		if err != nil {
217
 			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "BindFail"))
187
 			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "BindFail"))
218
 			return
188
 			return
@@ -220,42 +190,41 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
220
 		c.String(http.StatusOK, response.TranslateMsg(c, "BindSuccess"))
190
 		c.String(http.StatusOK, response.TranslateMsg(c, "BindSuccess"))
221
 		return
191
 		return
222
 
192
 
223
-	} else if ac == service.OauthActionTypeLogin {
193
+	} else if action == service.OauthActionTypeLogin {
224
 		//登录
194
 		//登录
225
-		if v.UserId != 0 {
195
+		if userId != 0 {
226
 			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthHasBeenSuccess"))
196
 			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthHasBeenSuccess"))
227
 			return
197
 			return
228
 		}
198
 		}
229
-		u = service.AllService.UserService.InfoByGithubId(openid)
230
-		if u == nil {
231
-			oa := service.AllService.OauthService.InfoByOp(ty)
232
-			if !*oa.AutoRegister {
199
+		user = service.AllService.UserService.InfoByOauthId(op, openid)
200
+		if user == nil {
201
+			oauthConfig := oauthService.InfoByOp(op)
202
+			if !*oauthConfig.AutoRegister {
233
 				//c.String(http.StatusInternalServerError, "还未绑定用户,请先绑定")
203
 				//c.String(http.StatusInternalServerError, "还未绑定用户,请先绑定")
234
-				v.ThirdName = thirdName
235
-				v.ThirdOpenId = openid
204
+				oauthCache.UpdateFromOauthUser(oauthUser)
236
 				url := global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/bind/" + cacheKey
205
 				url := global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/bind/" + cacheKey
237
 				c.Redirect(http.StatusFound, url)
206
 				c.Redirect(http.StatusFound, url)
238
 				return
207
 				return
239
 			}
208
 			}
240
 
209
 
241
 			//自动注册
210
 			//自动注册
242
-			u = service.AllService.UserService.RegisterByOauth(ty, thirdName, openid)
243
-			if u.Id == 0 {
244
-				c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthRegisterFailed"))
211
+			err, user = service.AllService.UserService.RegisterByOauth(oauthUser, op)
212
+			if err != nil {
213
+				c.String(http.StatusInternalServerError, response.TranslateMsg(c, err.Error()))
245
 				return
214
 				return
246
 			}
215
 			}
247
 		}
216
 		}
248
-		v.UserId = u.Id
249
-		service.AllService.OauthService.SetOauthCache(cacheKey, v, 0)
217
+		oauthCache.UserId = user.Id
218
+		oauthService.SetOauthCache(cacheKey, oauthCache, 0)
250
 		// 如果是webadmin,登录成功后跳转到webadmin
219
 		// 如果是webadmin,登录成功后跳转到webadmin
251
-		if v.DeviceType == "webadmin" {
220
+		if oauthCache.DeviceType == "webadmin" {
252
 			/*service.AllService.UserService.Login(u, &model.LoginLog{
221
 			/*service.AllService.UserService.Login(u, &model.LoginLog{
253
 				UserId:   u.Id,
222
 				UserId:   u.Id,
254
 				Client:   "webadmin",
223
 				Client:   "webadmin",
255
 				Uuid:     "", //must be empty
224
 				Uuid:     "", //must be empty
256
 				Ip:       c.ClientIP(),
225
 				Ip:       c.ClientIP(),
257
 				Type:     model.LoginLogTypeOauth,
226
 				Type:     model.LoginLogTypeOauth,
258
-				Platform: v.DeviceOs,
227
+				Platform: oauthService.DeviceOs,
259
 			})*/
228
 			})*/
260
 			url := global.Config.Rustdesk.ApiServer + "/_admin/#/"
229
 			url := global.Config.Rustdesk.ApiServer + "/_admin/#/"
261
 			c.Redirect(http.StatusFound, url)
230
 			c.Redirect(http.StatusFound, url)

+ 13 - 9
http/request/admin/oauth.go

@@ -1,6 +1,8 @@
1
 package admin
1
 package admin
2
 
2
 
3
-import "Gwen/model"
3
+import (
4
+	"Gwen/model"
5
+)
4
 
6
 
5
 type BindOauthForm struct {
7
 type BindOauthForm struct {
6
 	Op string `json:"op" binding:"required"`
8
 	Op string `json:"op" binding:"required"`
@@ -13,19 +15,21 @@ type UnBindOauthForm struct {
13
 	Op string `json:"op" binding:"required"`
15
 	Op string `json:"op" binding:"required"`
14
 }
16
 }
15
 type OauthForm struct {
17
 type OauthForm struct {
16
-	Id           uint   `json:"id"`
17
-	Op           string `json:"op" validate:"required"`
18
-	Issuer	     string `json:"issuer" validate:"omitempty,url"`
19
-	Scopes	   	 string `json:"scopes" validate:"omitempty"`
20
-	ClientId     string `json:"client_id" validate:"required"`
21
-	ClientSecret string `json:"client_secret" validate:"required"`
22
-	RedirectUrl  string `json:"redirect_url" validate:"required"`
23
-	AutoRegister *bool  `json:"auto_register"`
18
+	Id           uint   			`json:"id"`
19
+	Op           string 			`json:"op" validate:"omitempty"`
20
+	OauthType    string 			`json:"oauth_type" validate:"required"`
21
+	Issuer	     string 			`json:"issuer" validate:"omitempty,url"`
22
+	Scopes	   	 string 			`json:"scopes" validate:"omitempty"`
23
+	ClientId     string 			`json:"client_id" validate:"required"`
24
+	ClientSecret string 			`json:"client_secret" validate:"required"`
25
+	RedirectUrl  string 			`json:"redirect_url" validate:"required"`
26
+	AutoRegister *bool  			`json:"auto_register"`
24
 }
27
 }
25
 
28
 
26
 func (of *OauthForm) ToOauth() *model.Oauth {
29
 func (of *OauthForm) ToOauth() *model.Oauth {
27
 	oa := &model.Oauth{
30
 	oa := &model.Oauth{
28
 		Op:           of.Op,
31
 		Op:           of.Op,
32
+		OauthType:	  of.OauthType,
29
 		ClientId:     of.ClientId,
33
 		ClientId:     of.ClientId,
30
 		ClientSecret: of.ClientSecret,
34
 		ClientSecret: of.ClientSecret,
31
 		RedirectUrl:  of.RedirectUrl,
35
 		RedirectUrl:  of.RedirectUrl,

+ 11 - 7
http/request/admin/user.go

@@ -5,20 +5,22 @@ import (
5
 )
5
 )
6
 
6
 
7
 type UserForm struct {
7
 type UserForm struct {
8
-	Id       uint   `json:"id"`
9
-	Username string `json:"username" validate:"required,gte=4,lte=10"`
8
+	Id       uint   			`json:"id"`
9
+	Username string 			`json:"username" validate:"required,gte=4,lte=10"`
10
+	Email	 string           	`json:"email" validate:"required,email"`
10
 	//Password string           `json:"password" validate:"required,gte=4,lte=20"`
11
 	//Password string           `json:"password" validate:"required,gte=4,lte=20"`
11
-	Nickname string           `json:"nickname"`
12
-	Avatar   string           `json:"avatar"`
13
-	GroupId  uint             `json:"group_id" validate:"required"`
14
-	IsAdmin  *bool            `json:"is_admin" `
15
-	Status   model.StatusCode `json:"status" validate:"required,gte=0"`
12
+	Nickname string           	`json:"nickname"`
13
+	Avatar   string           	`json:"avatar"`
14
+	GroupId  uint             	`json:"group_id" validate:"required"`
15
+	IsAdmin  *bool            	`json:"is_admin" `
16
+	Status   model.StatusCode 	`json:"status" validate:"required,gte=0"`
16
 }
17
 }
17
 
18
 
18
 func (uf *UserForm) FromUser(user *model.User) *UserForm {
19
 func (uf *UserForm) FromUser(user *model.User) *UserForm {
19
 	uf.Id = user.Id
20
 	uf.Id = user.Id
20
 	uf.Username = user.Username
21
 	uf.Username = user.Username
21
 	uf.Nickname = user.Nickname
22
 	uf.Nickname = user.Nickname
23
+	uf.Email = user.Email
22
 	uf.Avatar = user.Avatar
24
 	uf.Avatar = user.Avatar
23
 	uf.GroupId = user.GroupId
25
 	uf.GroupId = user.GroupId
24
 	uf.IsAdmin = user.IsAdmin
26
 	uf.IsAdmin = user.IsAdmin
@@ -30,6 +32,7 @@ func (uf *UserForm) ToUser() *model.User {
30
 	user.Id = uf.Id
32
 	user.Id = uf.Id
31
 	user.Username = uf.Username
33
 	user.Username = uf.Username
32
 	user.Nickname = uf.Nickname
34
 	user.Nickname = uf.Nickname
35
+	user.Email = uf.Email
33
 	user.Avatar = uf.Avatar
36
 	user.Avatar = uf.Avatar
34
 	user.GroupId = uf.GroupId
37
 	user.GroupId = uf.GroupId
35
 	user.IsAdmin = uf.IsAdmin
38
 	user.IsAdmin = uf.IsAdmin
@@ -62,6 +65,7 @@ type GroupUsersQuery struct {
62
 
65
 
63
 type RegisterForm struct {
66
 type RegisterForm struct {
64
 	Username        string `json:"username" validate:"required,gte=4,lte=10"`
67
 	Username        string `json:"username" validate:"required,gte=4,lte=10"`
68
+	Email           string `json:"email" validate:"required,email"`
65
 	Password        string `json:"password" validate:"required,gte=4,lte=20"`
69
 	Password        string `json:"password" validate:"required,gte=4,lte=20"`
66
 	ConfirmPassword string `json:"confirm_password" validate:"required,gte=4,lte=20"`
70
 	ConfirmPassword string `json:"confirm_password" validate:"required,gte=4,lte=20"`
67
 }
71
 }

+ 12 - 3
http/response/admin/user.go

@@ -4,19 +4,28 @@ import "Gwen/model"
4
 
4
 
5
 type LoginPayload struct {
5
 type LoginPayload struct {
6
 	Username   string   `json:"username"`
6
 	Username   string   `json:"username"`
7
+	Email	   string   `json:"email"`
8
+	Avatar	   string   `json:"avatar"`
7
 	Token      string   `json:"token"`
9
 	Token      string   `json:"token"`
8
 	RouteNames []string `json:"route_names"`
10
 	RouteNames []string `json:"route_names"`
9
 	Nickname   string   `json:"nickname"`
11
 	Nickname   string   `json:"nickname"`
10
 }
12
 }
11
 
13
 
14
+func (lp *LoginPayload) FromUser(user *model.User) {
15
+	lp.Username = user.Username
16
+	lp.Email = user.Email
17
+	lp.Avatar = user.Avatar
18
+	lp.Nickname = user.Nickname
19
+}
20
+
12
 var UserRouteNames = []string{
21
 var UserRouteNames = []string{
13
-	"MyTagList", "MyAddressBookList", "MyInfo", "MyAddressBookCollection",
22
+	"MyTagList", "MyAddressBookList", "MyInfo", "MyAddressBookCollection", "MyPeer",
14
 }
23
 }
15
 var AdminRouteNames = []string{"*"}
24
 var AdminRouteNames = []string{"*"}
16
 
25
 
17
 type UserOauthItem struct {
26
 type UserOauthItem struct {
18
-	ThirdType string `json:"third_type"`
19
-	Status    int    `json:"status"`
27
+	Op 			string `json:"op"`
28
+	Status    	int    `json:"status"`
20
 }
29
 }
21
 
30
 
22
 type GroupUsersPayload struct {
31
 type GroupUsersPayload struct {

+ 1 - 0
http/response/api/user.go

@@ -29,6 +29,7 @@ type UserPayload struct {
29
 
29
 
30
 func (up *UserPayload) FromUser(user *model.User) *UserPayload {
30
 func (up *UserPayload) FromUser(user *model.User) *UserPayload {
31
 	up.Name = user.Username
31
 	up.Name = user.Username
32
+	up.Email = user.Email
32
 	up.IsAdmin = user.IsAdmin
33
 	up.IsAdmin = user.IsAdmin
33
 	up.Status = int(user.Status)
34
 	up.Status = int(user.Status)
34
 	up.Info = map[string]interface{}{}
35
 	up.Info = map[string]interface{}{}

+ 1 - 0
http/router/admin.go

@@ -53,6 +53,7 @@ func UserBind(rg *gin.RouterGroup) {
53
 		aR.GET("/current", cont.Current)
53
 		aR.GET("/current", cont.Current)
54
 		aR.POST("/changeCurPwd", cont.ChangeCurPwd)
54
 		aR.POST("/changeCurPwd", cont.ChangeCurPwd)
55
 		aR.POST("/myOauth", cont.MyOauth)
55
 		aR.POST("/myOauth", cont.MyOauth)
56
+		aR.GET("/myPeer", cont.MyPeer)
56
 		aR.POST("/groupUsers", cont.GroupUsers)
57
 		aR.POST("/groupUsers", cont.GroupUsers)
57
 	}
58
 	}
58
 	aRP := rg.Group("/user").Use(middleware.AdminPrivilege())
59
 	aRP := rg.Group("/user").Use(middleware.AdminPrivilege())

+ 1 - 0
model/loginLog.go

@@ -4,6 +4,7 @@ type LoginLog struct {
4
 	IdModel
4
 	IdModel
5
 	UserId      uint   `json:"user_id" gorm:"default:0;not null;"`
5
 	UserId      uint   `json:"user_id" gorm:"default:0;not null;"`
6
 	Client      string `json:"client"` //webadmin,webclient,app,
6
 	Client      string `json:"client"` //webadmin,webclient,app,
7
+	DeviceId	string `json:"device_id"`
7
 	Uuid        string `json:"uuid"`
8
 	Uuid        string `json:"uuid"`
8
 	Ip          string `json:"ip"`
9
 	Ip          string `json:"ip"`
9
 	Type        string `json:"type"`     //account,oauth
10
 	Type        string `json:"type"`     //account,oauth

+ 152 - 13
model/oauth.go

@@ -1,23 +1,162 @@
1
 package model
1
 package model
2
 
2
 
3
+import (
4
+	"strconv"
5
+	"strings"
6
+	"errors"
7
+)
8
+
9
+const OIDC_DEFAULT_SCOPES = "openid,profile,email"
10
+
11
+const (
12
+	// make sure the value shouldbe lowercase
13
+	OauthTypeGithub  string = "github"
14
+	OauthTypeGoogle  string = "google"
15
+	OauthTypeOidc    string = "oidc"
16
+	OauthTypeWebauth string = "webauth"
17
+)
18
+
19
+// Validate the oauth type
20
+func ValidateOauthType(oauthType string) error {
21
+	switch oauthType {
22
+	case OauthTypeGithub, OauthTypeGoogle, OauthTypeOidc, OauthTypeWebauth:
23
+		return nil
24
+	default:
25
+		return errors.New("invalid Oauth type")
26
+	}
27
+}
28
+
29
+const (
30
+	OauthNameGithub  string = "GitHub"
31
+	OauthNameGoogle  string = "Google"
32
+	OauthNameOidc    string = "OIDC"
33
+	OauthNameWebauth string = "WebAuth"
34
+)
35
+
36
+const (
37
+	UserEndpointGithub  string = "https://api.github.com/user"
38
+	IssuerGoogle 		string = "https://accounts.google.com"
39
+)
40
+
3
 type Oauth struct {
41
 type Oauth struct {
4
 	IdModel
42
 	IdModel
5
-	Op           string `json:"op"`
6
-	ClientId     string `json:"client_id"`
7
-	ClientSecret string `json:"client_secret"`
8
-	RedirectUrl  string `json:"redirect_url"`
9
-	AutoRegister *bool  `json:"auto_register"`
10
-	Scopes       string `json:"scopes"`
11
-	Issuer	     string `json:"issuer"`
43
+	Op           string 	`json:"op"`
44
+	OauthType    string 	`json:"oauth_type"`
45
+	ClientId     string 	`json:"client_id"`
46
+	ClientSecret string 	`json:"client_secret"`
47
+	RedirectUrl  string 	`json:"redirect_url"`
48
+	AutoRegister *bool  	`json:"auto_register"`
49
+	Scopes       string 	`json:"scopes"`
50
+	Issuer	     string 	`json:"issuer"`
12
 	TimeModel
51
 	TimeModel
13
 }
52
 }
14
 
53
 
15
-const (
16
-	OauthTypeGithub  = "github"
17
-	OauthTypeGoogle  = "google"
18
-	OauthTypeOidc    = "oidc"
19
-	OauthTypeWebauth = "webauth"
20
-)
54
+
55
+
56
+// Helper function to format oauth info, it's used in the update and create method
57
+func (oa *Oauth) FormatOauthInfo() error {
58
+	oauthType := strings.TrimSpace(oa.OauthType)
59
+	err := ValidateOauthType(oa.OauthType)
60
+	if err != nil {
61
+		return err
62
+	}
63
+	// check if the op is empty, set the default value
64
+	op := strings.TrimSpace(oa.Op)
65
+	if op == "" {
66
+		switch oauthType {
67
+		case OauthTypeGithub:
68
+			oa.Op = OauthNameGithub
69
+		case OauthTypeGoogle:
70
+			oa.Op = OauthNameGoogle
71
+		case OauthTypeOidc:
72
+			oa.Op = OauthNameOidc
73
+		case OauthTypeWebauth:
74
+			oa.Op = OauthNameWebauth
75
+		default:
76
+			oa.Op = oauthType
77
+		}
78
+	}
79
+	// check the issuer, if the oauth type is google and the issuer is empty, set the issuer to the default value
80
+	issuer := strings.TrimSpace(oa.Issuer)
81
+	// If the oauth type is google and the issuer is empty, set the issuer to the default value 
82
+	if oauthType == OauthTypeGoogle && issuer == "" {
83
+		oa.Issuer = IssuerGoogle
84
+	}
85
+	return nil
86
+}
87
+
88
+type OauthUser struct {
89
+	OpenId 			string 	`json:"open_id" gorm:"not null;index"`
90
+	Name   			string 	`json:"name"`
91
+	Username 		string 	`json:"username"`
92
+	Email  			string 	`json:"email"`
93
+	VerifiedEmail 	bool 	`json:"verified_email,omitempty"`
94
+	Picture			string 	`json:"picture,omitempty"`
95
+}
96
+
97
+func (ou *OauthUser) ToUser(user *User, overideUsername bool) {
98
+	if overideUsername {
99
+		user.Username = ou.Username
100
+	}
101
+	user.Email = ou.Email
102
+	user.Nickname = ou.Name
103
+	user.Avatar = ou.Picture
104
+}
105
+
106
+type OauthUserBase struct {
107
+	Name  string `json:"name"`
108
+	Email string `json:"email"`
109
+}
110
+
111
+type OidcUser struct {
112
+	OauthUserBase
113
+	Sub               string `json:"sub"`
114
+	VerifiedEmail     bool   `json:"email_verified"`
115
+	PreferredUsername string `json:"preferred_username"`
116
+	Picture           string `json:"picture"`
117
+}
118
+
119
+func (ou *OidcUser) ToOauthUser() *OauthUser {
120
+	var username string
121
+	// 使用 PreferredUsername,如果不存在,降级到 Email 前缀
122
+	if ou.PreferredUsername != "" {
123
+		username = ou.PreferredUsername
124
+	} else {
125
+		username = strings.ToLower(strings.Split(ou.Email, "@")[0])
126
+	}
127
+
128
+	return &OauthUser{
129
+		OpenId:        ou.Sub,
130
+		Name:          ou.Name,
131
+		Username:      username,
132
+		Email:         ou.Email,
133
+		VerifiedEmail: ou.VerifiedEmail,
134
+		Picture:       ou.Picture,
135
+	}
136
+}
137
+
138
+
139
+type GithubUser struct {
140
+	OauthUserBase
141
+	Id                int         `json:"id"`
142
+	Login             string      `json:"login"`
143
+	AvatarUrl         string      `json:"avatar_url"`
144
+	VerifiedEmail	  bool        `json:"verified_email"`
145
+}
146
+
147
+func (gu *GithubUser) ToOauthUser() *OauthUser {
148
+	username := strings.ToLower(gu.Login)
149
+	return &OauthUser{
150
+		OpenId: 		strconv.Itoa(gu.Id),
151
+		Name:   		gu.Name,
152
+		Username: 		username,
153
+		Email:  		gu.Email,
154
+		VerifiedEmail: 	gu.VerifiedEmail,
155
+		Picture:		gu.AvatarUrl,
156
+	}
157
+}
158
+
159
+
21
 
160
 
22
 type OauthList struct {
161
 type OauthList struct {
23
 	Oauths []*Oauth `json:"list"`
162
 	Oauths []*Oauth `json:"list"`

+ 16 - 0
model/user.go

@@ -1,8 +1,15 @@
1
 package model
1
 package model
2
 
2
 
3
+import (
4
+	"fmt"
5
+	"gorm.io/gorm"
6
+)
7
+
3
 type User struct {
8
 type User struct {
4
 	IdModel
9
 	IdModel
5
 	Username string     `json:"username" gorm:"default:'';not null;uniqueIndex"`
10
 	Username string     `json:"username" gorm:"default:'';not null;uniqueIndex"`
11
+	Email	string     	`json:"email" gorm:"default:'';not null;uniqueIndex"`
12
+	// Email	string     	`json:"email" `
6
 	Password string     `json:"-" gorm:"default:'';not null;"`
13
 	Password string     `json:"-" gorm:"default:'';not null;"`
7
 	Nickname string     `json:"nickname" gorm:"default:'';not null;"`
14
 	Nickname string     `json:"nickname" gorm:"default:'';not null;"`
8
 	Avatar   string     `json:"avatar" gorm:"default:'';not null;"`
15
 	Avatar   string     `json:"avatar" gorm:"default:'';not null;"`
@@ -12,6 +19,15 @@ type User struct {
12
 	TimeModel
19
 	TimeModel
13
 }
20
 }
14
 
21
 
22
+// BeforeSave 钩子用于确保 email 字段有合理的默认值
23
+func (u *User) BeforeSave(tx *gorm.DB) (err error) {
24
+    // 如果 email 为空,设置为默认值
25
+    if u.Email == "" {
26
+        u.Email = fmt.Sprintf("%s@example.com", u.Username)
27
+    }
28
+    return nil
29
+}
30
+
15
 type UserList struct {
31
 type UserList struct {
16
 	Users []*User `json:"list,omitempty"`
32
 	Users []*User `json:"list,omitempty"`
17
 	Pagination
33
 	Pagination

+ 19 - 6
model/userThird.go

@@ -1,12 +1,25 @@
1
 package model
1
 package model
2
 
2
 
3
+import (
4
+	"strings"
5
+)
6
+
3
 type UserThird struct {
7
 type UserThird struct {
4
 	IdModel
8
 	IdModel
5
-	UserId     uint   `json:"user_id" gorm:"not null;index"`
6
-	OpenId     string `json:"open_id" gorm:"not null;index"`
7
-	UnionId    string `json:"union_id" gorm:"not null;"`
8
-	ThirdType  string `json:"third_type" gorm:"not null;"`
9
-	ThirdEmail string `json:"third_email"`
10
-	ThirdName  string `json:"third_name"`
9
+	UserId     		uint   `	json:"user_id" gorm:"not null;index"`
10
+	OauthUser
11
+	// UnionId    		string `json:"union_id" gorm:"not null;"`
12
+	// OauthType  	   	string 		`json:"oauth_type" gorm:"not null;"`
13
+	OauthType  	   	string 		`json:"oauth_type"`
14
+	Op  			string 		`json:"op" gorm:"not null;"`
11
 	TimeModel
15
 	TimeModel
12
 }
16
 }
17
+
18
+func (u *UserThird) FromOauthUser(userId uint, oauthUser *OauthUser, oauthType string, op string) {
19
+	u.UserId 			= userId
20
+	u.OauthUser 		= *oauthUser
21
+	u.OauthType 		= oauthType
22
+	u.Op 				= op
23
+	// make sure email is lower case
24
+	u.Email 			= strings.ToLower(u.Email)
25
+}

+ 5 - 3
model/userToken.go

@@ -2,9 +2,11 @@ package model
2
 
2
 
3
 type UserToken struct {
3
 type UserToken struct {
4
 	IdModel
4
 	IdModel
5
-	UserId    uint   `json:"user_id" gorm:"default:0;not null;index"`
6
-	Token     string `json:"token" gorm:"default:'';not null;index"`
7
-	ExpiredAt int64  `json:"expired_at" gorm:"default:0;not null;"`
5
+	UserId    	uint   `json:"user_id" gorm:"default:0;not null;index"`
6
+	DeviceUuid 	string `json:"device_uuid" gorm:"default:'';omitempty;"`
7
+	DeviceId	string `json:"device_id" gorm:"default:'';omitempty;"`
8
+	Token     	string `json:"token" gorm:"default:'';not null;index"`
9
+	ExpiredAt 	int64  `json:"expired_at" gorm:"default:0;not null;"`
8
 	TimeModel
10
 	TimeModel
9
 }
11
 }
10
 
12
 

+ 252 - 277
service/oauth.go

@@ -9,17 +9,22 @@ import (
9
 	"errors"
9
 	"errors"
10
 	"golang.org/x/oauth2"
10
 	"golang.org/x/oauth2"
11
 	"golang.org/x/oauth2/github"
11
 	"golang.org/x/oauth2/github"
12
-	"golang.org/x/oauth2/google"
12
+	// "golang.org/x/oauth2/google"
13
 	"gorm.io/gorm"
13
 	"gorm.io/gorm"
14
-	"io"
14
+	// "io"
15
 	"net/http"
15
 	"net/http"
16
 	"net/url"
16
 	"net/url"
17
 	"strconv"
17
 	"strconv"
18
 	"strings"
18
 	"strings"
19
 	"sync"
19
 	"sync"
20
 	"time"
20
 	"time"
21
+	"fmt"
21
 )
22
 )
22
 
23
 
24
+
25
+type OauthService struct {
26
+}
27
+
23
 // Define a struct to parse the .well-known/openid-configuration response
28
 // Define a struct to parse the .well-known/openid-configuration response
24
 type OidcEndpoint struct {
29
 type OidcEndpoint struct {
25
 	Issuer   string `json:"issuer"`
30
 	Issuer   string `json:"issuer"`
@@ -28,73 +33,6 @@ type OidcEndpoint struct {
28
 	UserInfo string `json:"userinfo_endpoint"`
33
 	UserInfo string `json:"userinfo_endpoint"`
29
 }
34
 }
30
 
35
 
31
-type OauthService struct {
32
-}
33
-
34
-type GithubUserdata struct {
35
-	AvatarUrl         string      `json:"avatar_url"`
36
-	Bio               string      `json:"bio"`
37
-	Blog              string      `json:"blog"`
38
-	Collaborators     int         `json:"collaborators"`
39
-	Company           interface{} `json:"company"`
40
-	CreatedAt         time.Time   `json:"created_at"`
41
-	DiskUsage         int         `json:"disk_usage"`
42
-	Email             interface{} `json:"email"`
43
-	EventsUrl         string      `json:"events_url"`
44
-	Followers         int         `json:"followers"`
45
-	FollowersUrl      string      `json:"followers_url"`
46
-	Following         int         `json:"following"`
47
-	FollowingUrl      string      `json:"following_url"`
48
-	GistsUrl          string      `json:"gists_url"`
49
-	GravatarId        string      `json:"gravatar_id"`
50
-	Hireable          interface{} `json:"hireable"`
51
-	HtmlUrl           string      `json:"html_url"`
52
-	Id                int         `json:"id"`
53
-	Location          interface{} `json:"location"`
54
-	Login             string      `json:"login"`
55
-	Name              string      `json:"name"`
56
-	NodeId            string      `json:"node_id"`
57
-	NotificationEmail interface{} `json:"notification_email"`
58
-	OrganizationsUrl  string      `json:"organizations_url"`
59
-	OwnedPrivateRepos int         `json:"owned_private_repos"`
60
-	Plan              struct {
61
-		Collaborators int    `json:"collaborators"`
62
-		Name          string `json:"name"`
63
-		PrivateRepos  int    `json:"private_repos"`
64
-		Space         int    `json:"space"`
65
-	} `json:"plan"`
66
-	PrivateGists      int    `json:"private_gists"`
67
-	PublicGists       int    `json:"public_gists"`
68
-	PublicRepos       int    `json:"public_repos"`
69
-	ReceivedEventsUrl string `json:"received_events_url"`
70
-	ReposUrl          string `json:"repos_url"`
71
-	SiteAdmin         bool   `json:"site_admin"`
72
-	StarredUrl        string `json:"starred_url"`
73
-	SubscriptionsUrl  string `json:"subscriptions_url"`
74
-	TotalPrivateRepos int    `json:"total_private_repos"`
75
-	//TwitterUsername         interface{} `json:"twitter_username"`
76
-	TwoFactorAuthentication bool      `json:"two_factor_authentication"`
77
-	Type                    string    `json:"type"`
78
-	UpdatedAt               time.Time `json:"updated_at"`
79
-	Url                     string    `json:"url"`
80
-}
81
-type GoogleUserdata struct {
82
-	Email         string `json:"email"`
83
-	FamilyName    string `json:"family_name"`
84
-	GivenName     string `json:"given_name"`
85
-	Id            string `json:"id"`
86
-	Name          string `json:"name"`
87
-	Picture       string `json:"picture"`
88
-	VerifiedEmail bool   `json:"verified_email"`
89
-}
90
-type OidcUserdata struct {
91
-	Sub               string `json:"sub"`
92
-	Email             string `json:"email"`
93
-	VerifiedEmail     bool   `json:"email_verified"`
94
-	Name              string `json:"name"`
95
-	PreferredUsername string `json:"preferred_username"`
96
-}
97
-
98
 type OauthCacheItem struct {
36
 type OauthCacheItem struct {
99
 	UserId      uint   `json:"user_id"`
37
 	UserId      uint   `json:"user_id"`
100
 	Id          string `json:"id"` //rustdesk的设备ID
38
 	Id          string `json:"id"` //rustdesk的设备ID
@@ -104,9 +42,19 @@ type OauthCacheItem struct {
104
 	DeviceName  string `json:"device_name"`
42
 	DeviceName  string `json:"device_name"`
105
 	DeviceOs    string `json:"device_os"`
43
 	DeviceOs    string `json:"device_os"`
106
 	DeviceType  string `json:"device_type"`
44
 	DeviceType  string `json:"device_type"`
107
-	ThirdOpenId string `json:"third_open_id"`
108
-	ThirdName   string `json:"third_name"`
109
-	ThirdEmail  string `json:"third_email"`
45
+	OpenId 		string `json:"open_id"`
46
+	Username	string `json:"username"`
47
+	Name   		string `json:"name"`
48
+	Email  		string `json:"email"`
49
+}
50
+
51
+func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
52
+	return &model.OauthUser{
53
+		OpenId: oci.OpenId,
54
+		Username: oci.Username,
55
+		Name: oci.Name,
56
+		Email: oci.Email,
57
+	}
110
 }
58
 }
111
 
59
 
112
 var OauthCache = &sync.Map{}
60
 var OauthCache = &sync.Map{}
@@ -116,6 +64,14 @@ const (
116
 	OauthActionTypeBind  = "bind"
64
 	OauthActionTypeBind  = "bind"
117
 )
65
 )
118
 
66
 
67
+func (oa *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) {
68
+	oa.OpenId = oauthUser.OpenId
69
+	oa.Username = oauthUser.Username
70
+	oa.Name = oauthUser.Name
71
+	oa.Email = oauthUser.Email
72
+}
73
+
74
+
119
 func (os *OauthService) GetOauthCache(key string) *OauthCacheItem {
75
 func (os *OauthService) GetOauthCache(key string) *OauthCacheItem {
120
 	v, ok := OauthCache.Load(key)
76
 	v, ok := OauthCache.Load(key)
121
 	if !ok {
77
 	if !ok {
@@ -140,22 +96,21 @@ func (os *OauthService) DeleteOauthCache(key string) {
140
 
96
 
141
 func (os *OauthService) BeginAuth(op string) (error error, code, url string) {
97
 func (os *OauthService) BeginAuth(op string) (error error, code, url string) {
142
 	code = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
98
 	code = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
143
-
144
-	if op == model.OauthTypeWebauth {
99
+	if op == string(model.OauthTypeWebauth) {
145
 		url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + code
100
 		url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + code
146
 		//url = "http://localhost:8888/_admin/#/oauth/" + code
101
 		//url = "http://localhost:8888/_admin/#/oauth/" + code
147
 		return nil, code, url
102
 		return nil, code, url
148
 	}
103
 	}
149
-	err, conf := os.GetOauthConfig(op)
104
+	err, _, oauthConfig := os.GetOauthConfig(op)
150
 	if err == nil {
105
 	if err == nil {
151
-		return err, code, conf.AuthCodeURL(code)
106
+		return err, code, oauthConfig.AuthCodeURL(code)
152
 	}
107
 	}
153
 
108
 
154
 	return err, code, ""
109
 	return err, code, ""
155
 }
110
 }
156
 
111
 
157
 // Method to fetch OIDC configuration dynamically
112
 // Method to fetch OIDC configuration dynamically
158
-func FetchOidcConfig(issuer string) (error, OidcEndpoint) {
113
+func (os *OauthService) FetchOidcEndpoint(issuer string) (error, OidcEndpoint) {
159
 	configURL := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
114
 	configURL := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
160
 
115
 
161
 	// Get the HTTP client (with or without proxy based on configuration)
116
 	// Get the HTTP client (with or without proxy based on configuration)
@@ -179,76 +134,58 @@ func FetchOidcConfig(issuer string) (error, OidcEndpoint) {
179
 	return nil, endpoint
134
 	return nil, endpoint
180
 }
135
 }
181
 
136
 
182
-// GetOauthConfig retrieves the OAuth2 configuration based on the provider type
183
-func (os *OauthService) GetOauthConfig(op string) (error, *oauth2.Config) {
184
-	switch op {
185
-	case model.OauthTypeGithub:
186
-		return os.getGithubConfig()
187
-	case model.OauthTypeGoogle:
188
-		return os.getGoogleConfig()
189
-	case model.OauthTypeOidc:
190
-		return os.getOidcConfig()
191
-	default:
192
-		return errors.New("unsupported OAuth type"), nil
137
+func (os *OauthService) FetchOidcEndpointByOp(op string) (error, OidcEndpoint) {
138
+	oauthInfo := os.InfoByOp(op)
139
+	if oauthInfo.Issuer == "" {
140
+		return errors.New("issuer is empty"), OidcEndpoint{}
193
 	}
141
 	}
142
+	return os.FetchOidcEndpoint(oauthInfo.Issuer)
194
 }
143
 }
195
 
144
 
196
-// Helper function to get GitHub OAuth2 configuration
197
-func (os *OauthService) getGithubConfig() (error, *oauth2.Config) {
198
-	g := os.InfoByOp(model.OauthTypeGithub)
199
-	if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" || g.RedirectUrl == "" {
200
-		return errors.New("ConfigNotFound"), nil
201
-	}
202
-	return nil, &oauth2.Config{
203
-		ClientID:     g.ClientId,
204
-		ClientSecret: g.ClientSecret,
205
-		RedirectURL:  g.RedirectUrl,
206
-		Endpoint:     github.Endpoint,
207
-		Scopes:       []string{"read:user", "user:email"},
145
+// GetOauthConfig retrieves the OAuth2 configuration based on the provider name
146
+func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) {
147
+	err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op)
148
+	if err != nil {
149
+		return err, nil, nil
208
 	}
150
 	}
209
-}
210
-
211
-// Helper function to get Google OAuth2 configuration
212
-func (os *OauthService) getGoogleConfig() (error, *oauth2.Config) {
213
-	g := os.InfoByOp(model.OauthTypeGoogle)
214
-	if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" || g.RedirectUrl == "" {
215
-		return errors.New("ConfigNotFound"), nil
151
+	// Maybe should validate the oauthConfig here
152
+	oauthType := oauthInfo.OauthType
153
+	err = model.ValidateOauthType(oauthType)
154
+	if err != nil {
155
+		return err, nil, nil
216
 	}
156
 	}
217
-	return nil, &oauth2.Config{
218
-		ClientID:     g.ClientId,
219
-		ClientSecret: g.ClientSecret,
220
-		RedirectURL:  g.RedirectUrl,
221
-		Endpoint:     google.Endpoint,
222
-		Scopes:       []string{"https://www.googleapis.com/auth/userinfo.profile", "https://www.googleapis.com/auth/userinfo.email"},
157
+	switch oauthType {
158
+	case model.OauthTypeGithub:
159
+		oauthConfig.Endpoint = github.Endpoint
160
+		oauthConfig.Scopes = []string{"read:user", "user:email"}
161
+	case model.OauthTypeOidc, model.OauthTypeGoogle:
162
+		var endpoint OidcEndpoint
163
+		err, endpoint = os.FetchOidcEndpoint(oauthInfo.Issuer)
164
+		if err != nil {
165
+			return err, nil, nil
166
+		}
167
+		oauthConfig.Endpoint = oauth2.Endpoint{AuthURL:  endpoint.AuthURL,TokenURL: endpoint.TokenURL,}
168
+		oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
169
+	default:
170
+		return errors.New("unsupported OAuth type"), nil, nil
223
 	}
171
 	}
172
+	return nil, oauthInfo, oauthConfig
224
 }
173
 }
225
 
174
 
226
-// Helper function to get OIDC OAuth2 configuration
227
-func (os *OauthService) getOidcConfig() (error, *oauth2.Config) {
228
-	g := os.InfoByOp(model.OauthTypeOidc)
229
-	if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" || g.RedirectUrl == "" || g.Issuer == "" {
230
-		return errors.New("ConfigNotFound"), nil
175
+// GetOauthConfig retrieves the OAuth2 configuration based on the provider name
176
+func (os *OauthService) getOauthConfigGeneral(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) {
177
+	oauthInfo = os.InfoByOp(op)
178
+	if oauthInfo.Id == 0 || oauthInfo.ClientId == "" || oauthInfo.ClientSecret == "" {
179
+		return errors.New("ConfigNotFound"), nil, nil
231
 	}
180
 	}
232
-
233
-	// Set scopes
234
-	scopes := strings.TrimSpace(g.Scopes)
235
-	if scopes == "" {
236
-		scopes = "openid,profile,email"
181
+	// If the redirect URL is empty, use the default redirect URL
182
+	if oauthInfo.RedirectUrl == "" {
183
+		oauthInfo.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback"
237
 	}
184
 	}
238
-	scopeList := strings.Split(scopes, ",")
239
-	err, endpoint := FetchOidcConfig(g.Issuer)
240
-	if err != nil {
241
-		return err, nil
242
-	}
243
-	return nil, &oauth2.Config{
244
-		ClientID:     g.ClientId,
245
-		ClientSecret: g.ClientSecret,
246
-		RedirectURL:  g.RedirectUrl,
247
-		Endpoint: oauth2.Endpoint{
248
-			AuthURL:  endpoint.AuthURL,
249
-			TokenURL: endpoint.TokenURL,
250
-		},
251
-		Scopes: scopeList,
185
+	return nil, oauthInfo, &oauth2.Config{
186
+		ClientID:     oauthInfo.ClientId,
187
+		ClientSecret: oauthInfo.ClientSecret,
188
+		RedirectURL:  oauthInfo.RedirectUrl,
252
 	}
189
 	}
253
 }
190
 }
254
 
191
 
@@ -272,194 +209,153 @@ func getHTTPClientWithProxy() *http.Client {
272
 	return http.DefaultClient
209
 	return http.DefaultClient
273
 }
210
 }
274
 
211
 
275
-func (os *OauthService) GithubCallback(code string) (error error, userData *GithubUserdata) {
276
-	err, oauthConfig := os.GetOauthConfig(model.OauthTypeGithub)
277
-	if err != nil {
278
-		return err, nil
279
-	}
212
+func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, userEndpoint string, userData interface{}) (err error, client *http.Client) {
280
 
213
 
281
-	// 使用代理配置创建 HTTP 客户端
214
+	// 设置代理客户端
282
 	httpClient := getHTTPClientWithProxy()
215
 	httpClient := getHTTPClientWithProxy()
283
 	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
216
 	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
284
 
217
 
285
-	token, err := oauthConfig.Exchange(ctx, code)
218
+	// 使用 code 换取 token
219
+	var token *oauth2.Token
220
+	token, err = oauthConfig.Exchange(ctx, code)
286
 	if err != nil {
221
 	if err != nil {
287
 		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
222
 		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
288
-		error = errors.New("GetOauthTokenError")
289
-		return
223
+		return errors.New("GetOauthTokenError"), nil
290
 	}
224
 	}
291
 
225
 
292
-	// 使用带有代理的 HTTP 客户端获取用户信息
293
-	client := oauthConfig.Client(ctx, token)
294
-	resp, err := client.Get("https://api.github.com/user")
226
+	// 获取用户信息
227
+	client = oauthConfig.Client(ctx, token)
228
+	resp, err := client.Get(userEndpoint)
295
 	if err != nil {
229
 	if err != nil {
296
 		global.Logger.Warn("failed getting user info: ", err)
230
 		global.Logger.Warn("failed getting user info: ", err)
297
-		error = errors.New("GetOauthUserInfoError")
298
-		return
231
+		return errors.New("GetOauthUserInfoError"), nil
299
 	}
232
 	}
300
-	defer func(Body io.ReadCloser) {
301
-		err := Body.Close()
302
-		if err != nil {
303
-			global.Logger.Warn("failed closing response body: ", err)
233
+	defer func() {
234
+		if closeErr := resp.Body.Close(); closeErr != nil {
235
+			global.Logger.Warn("failed closing response body: ", closeErr)
304
 		}
236
 		}
305
-	}(resp.Body)
237
+	}()
306
 
238
 
307
 	// 解析用户信息
239
 	// 解析用户信息
308
-	if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil {
240
+	if err = json.NewDecoder(resp.Body).Decode(userData); err != nil {
309
 		global.Logger.Warn("failed decoding user info: ", err)
241
 		global.Logger.Warn("failed decoding user info: ", err)
310
-		error = errors.New("DecodeOauthUserInfoError")
311
-		return
242
+		return errors.New("DecodeOauthUserInfoError"), nil
312
 	}
243
 	}
313
-	return
244
+
245
+	return nil, client
314
 }
246
 }
315
 
247
 
316
-func (os *OauthService) GoogleCallback(code string) (error error, userData *GoogleUserdata) {
317
-	err, oauthConfig := os.GetOauthConfig(model.OauthTypeGoogle)
248
+// githubCallback github回调
249
+func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string) (error, *model.OauthUser) {
250
+	var user = &model.GithubUser{}
251
+	err, client := os.callbackBase(oauthConfig, code, model.UserEndpointGithub, user)
318
 	if err != nil {
252
 	if err != nil {
319
 		return err, nil
253
 		return err, nil
320
 	}
254
 	}
321
-
322
-	// 使用代理配置创建 HTTP 客户端
323
-	httpClient := getHTTPClientWithProxy()
324
-	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
325
-
326
-	token, err := oauthConfig.Exchange(ctx, code)
255
+	err = os.getGithubPrimaryEmail(client, user)
327
 	if err != nil {
256
 	if err != nil {
328
-		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
329
-		error = errors.New("GetOauthTokenError")
330
-		return
257
+		return err, nil
331
 	}
258
 	}
259
+	return nil, user.ToOauthUser()
260
+}
332
 
261
 
333
-	// 使用带有代理的 HTTP 客户端获取用户信息
334
-	client := oauthConfig.Client(ctx, token)
335
-	resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
336
-	if err != nil {
337
-		global.Logger.Warn("failed getting user info: ", err)
338
-		error = errors.New("GetOauthUserInfoError")
339
-		return
340
-	}
341
-	defer func(Body io.ReadCloser) {
342
-		err := Body.Close()
343
-		if err != nil {
344
-			global.Logger.Warn("failed closing response body: ", err)
345
-		}
346
-	}(resp.Body)
347
 
262
 
348
-	// 解析用户信息
349
-	if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil {
350
-		global.Logger.Warn("failed decoding user info: ", err)
351
-		error = errors.New("DecodeOauthUserInfoError")
352
-		return
263
+// oidcCallback oidc回调, 通过code获取用户信息
264
+func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser,) {
265
+	var user = &model.OidcUser{}
266
+	if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil {
267
+		return err, nil
353
 	}
268
 	}
354
-	return
269
+	return nil, user.ToOauthUser()
355
 }
270
 }
356
 
271
 
357
-func (os *OauthService) OidcCallback(code string) (error error, userData *OidcUserdata) {
358
-	err, oauthConfig := os.GetOauthConfig(model.OauthTypeOidc)
272
+// Callback: Get user information by code and op(Oauth provider)
273
+func (os *OauthService) Callback(code string, op string) (err error, oauthUser *model.OauthUser) {
274
+	var oauthInfo *model.Oauth
275
+	var oauthConfig *oauth2.Config
276
+	err, oauthInfo, oauthConfig = os.GetOauthConfig(op)
277
+	// oauthType is already validated in GetOauthConfig
359
 	if err != nil {
278
 	if err != nil {
360
 		return err, nil
279
 		return err, nil
361
 	}
280
 	}
362
-	// 使用代理配置创建 HTTP 客户端
363
-	httpClient := getHTTPClientWithProxy()
364
-	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
365
-
366
-	token, err := oauthConfig.Exchange(ctx, code)
367
-	if err != nil {
368
-		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
369
-		error = errors.New("GetOauthTokenError")
370
-		return
371
-	}
372
-
373
-	// 使用带有代理的 HTTP 客户端获取用户信息
374
-	client := oauthConfig.Client(ctx, token)
375
-	g := os.InfoByOp(model.OauthTypeOidc)
376
-	err, endpoint := FetchOidcConfig(g.Issuer)
377
-	if err != nil {
378
-		global.Logger.Warn("failed fetching OIDC configuration: ", err)
379
-		error = errors.New("FetchOidcConfigError")
380
-		return
381
-	}
382
-	resp, err := client.Get(endpoint.UserInfo)
383
-	if err != nil {
384
-		global.Logger.Warn("failed getting user info: ", err)
385
-		error = errors.New("GetOauthUserInfoError")
386
-		return
387
-	}
388
-	defer func(Body io.ReadCloser) {
389
-		err := Body.Close()
281
+	oauthType := oauthInfo.OauthType
282
+	switch oauthType {
283
+    case model.OauthTypeGithub:
284
+        err, oauthUser = os.githubCallback(oauthConfig, code)
285
+    case model.OauthTypeOidc, model.OauthTypeGoogle:
286
+		err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
390
 		if err != nil {
287
 		if err != nil {
391
-			global.Logger.Warn("failed closing response body: ", err)
288
+			return err, nil
392
 		}
289
 		}
393
-	}(resp.Body)
394
-
395
-	// 解析用户信息
396
-	if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil {
397
-		global.Logger.Warn("failed decoding user info: ", err)
398
-		error = errors.New("DecodeOauthUserInfoError")
399
-		return
400
-	}
401
-	return
290
+        err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo)
291
+    default:
292
+        return errors.New("unsupported OAuth type"), nil
293
+    }
294
+    return err, oauthUser
402
 }
295
 }
403
 
296
 
404
-func (os *OauthService) UserThirdInfo(op, openid string) *model.UserThird {
297
+
298
+func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird {
405
 	ut := &model.UserThird{}
299
 	ut := &model.UserThird{}
406
-	global.DB.Where("open_id = ? and third_type = ?", openid, op).First(ut)
300
+	global.DB.Where("open_id = ? and op = ?", openId, op).First(ut)
407
 	return ut
301
 	return ut
408
 }
302
 }
409
 
303
 
410
-func (os *OauthService) BindGithubUser(openid, username string, userId uint) error {
411
-	return os.BindOauthUser(model.OauthTypeGithub, openid, username, userId)
412
-}
413
-
414
-func (os *OauthService) BindGoogleUser(email, username string, userId uint) error {
415
-	return os.BindOauthUser(model.OauthTypeGoogle, email, username, userId)
416
-}
417
-
418
-func (os *OauthService) BindOidcUser(sub, username string, userId uint) error {
419
-	return os.BindOauthUser(model.OauthTypeOidc, sub, username, userId)
420
-}
421
-
422
-func (os *OauthService) BindOauthUser(thirdType, openid, username string, userId uint) error {
423
-	utr := &model.UserThird{
424
-		OpenId:    openid,
425
-		ThirdType: thirdType,
426
-		ThirdName: username,
427
-		UserId:    userId,
304
+// BindOauthUser: Bind third party account
305
+func (os *OauthService) BindOauthUser(userId uint, oauthUser *model.OauthUser, op string) error {
306
+	utr := &model.UserThird{}
307
+	err, oauthType := os.GetTypeByOp(op)
308
+	if err != nil {
309
+		return err
428
 	}
310
 	}
311
+	utr.FromOauthUser(userId, oauthUser, oauthType, op)
429
 	return global.DB.Create(utr).Error
312
 	return global.DB.Create(utr).Error
430
 }
313
 }
431
 
314
 
432
-func (os *OauthService) UnBindGithubUser(userid uint) error {
433
-	return os.UnBindThird(model.OauthTypeGithub, userid)
315
+// UnBindOauthUser: Unbind third party account
316
+func (os *OauthService) UnBindOauthUser(userId uint, op string) error {
317
+	return os.UnBindThird(op, userId)
434
 }
318
 }
435
-func (os *OauthService) UnBindGoogleUser(userid uint) error {
436
-	return os.UnBindThird(model.OauthTypeGoogle, userid)
437
-}
438
-func (os *OauthService) UnBindOidcUser(userid uint) error {
439
-	return os.UnBindThird(model.OauthTypeOidc, userid)
440
-}
441
-func (os *OauthService) UnBindThird(thirdType string, userid uint) error {
442
-	return global.DB.Where("user_id = ? and third_type = ?", userid, thirdType).Delete(&model.UserThird{}).Error
319
+
320
+// UnBindThird: Unbind third party account
321
+func (os *OauthService) UnBindThird(op string, userId uint) error {
322
+	return global.DB.Where("user_id = ? and op = ?", userId, op).Delete(&model.UserThird{}).Error
443
 }
323
 }
444
 
324
 
445
 // DeleteUserByUserId: When user is deleted, delete all third party bindings
325
 // DeleteUserByUserId: When user is deleted, delete all third party bindings
446
-func (os *OauthService) DeleteUserByUserId(userid uint) error {
447
-	return global.DB.Where("user_id = ?", userid).Delete(&model.UserThird{}).Error
326
+func (os *OauthService) DeleteUserByUserId(userId uint) error {
327
+	return global.DB.Where("user_id = ?", userId).Delete(&model.UserThird{}).Error
448
 }
328
 }
449
 
329
 
450
-// InfoById 根据id取用户信息
330
+// InfoById 根据id获取Oauth信息
451
 func (os *OauthService) InfoById(id uint) *model.Oauth {
331
 func (os *OauthService) InfoById(id uint) *model.Oauth {
452
-	u := &model.Oauth{}
453
-	global.DB.Where("id = ?", id).First(u)
454
-	return u
332
+	oauthInfo := &model.Oauth{}
333
+	global.DB.Where("id = ?", id).First(oauthInfo)
334
+	return oauthInfo
455
 }
335
 }
456
 
336
 
457
-// InfoByOp 根据op取用户信息
337
+// InfoByOp 根据op获取Oauth信息
458
 func (os *OauthService) InfoByOp(op string) *model.Oauth {
338
 func (os *OauthService) InfoByOp(op string) *model.Oauth {
459
-	u := &model.Oauth{}
460
-	global.DB.Where("op = ?", op).First(u)
461
-	return u
339
+	oauthInfo := &model.Oauth{}
340
+	global.DB.Where("op = ?", op).First(oauthInfo)
341
+	return oauthInfo
462
 }
342
 }
343
+
344
+// Helper function to get scopes by operation
345
+func (os *OauthService) getScopesByOp(op string) []string {
346
+    scopes := os.InfoByOp(op).Scopes
347
+	return os.constructScopes(scopes)
348
+}
349
+
350
+// Helper function to construct scopes
351
+func (os *OauthService) constructScopes(scopes string) []string {
352
+    scopes = strings.TrimSpace(scopes)
353
+    if scopes == "" {
354
+        scopes = model.OIDC_DEFAULT_SCOPES
355
+    }
356
+    return strings.Split(scopes, ",")
357
+}
358
+
463
 func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *model.OauthList) {
359
 func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *model.OauthList) {
464
 	res = &model.OauthList{}
360
 	res = &model.OauthList{}
465
 	res.Page = int64(page)
361
 	res.Page = int64(page)
@@ -474,16 +370,95 @@ func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res
474
 	return
370
 	return
475
 }
371
 }
476
 
372
 
373
+// GetTypeByOp 根据op获取OauthType
374
+func (os *OauthService) GetTypeByOp(op string) (error, string) {
375
+	oauthInfo := &model.Oauth{}
376
+	if global.DB.Where("op = ?", op).First(oauthInfo).Error != nil {
377
+		return fmt.Errorf("OAuth provider with op '%s' not found", op), ""
378
+	}
379
+	return nil, oauthInfo.OauthType
380
+}
381
+
382
+// ValidateOauthProvider 验证Oauth提供者是否正确
383
+func (os *OauthService) ValidateOauthProvider(op string) error {
384
+	if !os.IsOauthProviderExist(op) {
385
+		return fmt.Errorf("OAuth provider with op '%s' not found", op)
386
+	}
387
+	return nil
388
+}
389
+
390
+// IsOauthProviderExist 验证Oauth提供者是否存在
391
+func (os *OauthService) IsOauthProviderExist(op string) bool {
392
+	oauthInfo := &model.Oauth{}
393
+	// 使用 Gorm 的 Take 方法查找符合条件的记录
394
+	if err := global.DB.Where("op = ?", op).Take(oauthInfo).Error; err != nil {
395
+		return false
396
+	}
397
+	return true
398
+}
399
+
477
 // Create 创建
400
 // Create 创建
478
-func (os *OauthService) Create(u *model.Oauth) error {
479
-	res := global.DB.Create(u).Error
401
+func (os *OauthService) Create(oauthInfo *model.Oauth) error {
402
+	err := oauthInfo.FormatOauthInfo()
403
+	if err != nil {
404
+		return err
405
+	}
406
+	res := global.DB.Create(oauthInfo).Error
480
 	return res
407
 	return res
481
 }
408
 }
482
-func (os *OauthService) Delete(u *model.Oauth) error {
483
-	return global.DB.Delete(u).Error
409
+func (os *OauthService) Delete(oauthInfo *model.Oauth) error {
410
+	return global.DB.Delete(oauthInfo).Error
484
 }
411
 }
485
 
412
 
486
 // Update 更新
413
 // Update 更新
487
-func (os *OauthService) Update(u *model.Oauth) error {
488
-	return global.DB.Model(u).Updates(u).Error
414
+func (os *OauthService) Update(oauthInfo *model.Oauth) error {
415
+	err := oauthInfo.FormatOauthInfo()
416
+	if err != nil {
417
+		return err
418
+	}
419
+	return global.DB.Model(oauthInfo).Updates(oauthInfo).Error
489
 }
420
 }
421
+
422
+// GetOauthProviders 获取所有的provider
423
+func (os *OauthService) GetOauthProviders() []string {
424
+	var res []string
425
+	global.DB.Model(&model.Oauth{}).Pluck("op", &res)
426
+	return res
427
+}
428
+
429
+// getGithubPrimaryEmail: Get the primary email of the user from Github
430
+func (os *OauthService) getGithubPrimaryEmail(client *http.Client, githubUser *model.GithubUser) error {
431
+	// the client is already set with the token
432
+	resp, err := client.Get("https://api.github.com/user/emails")
433
+	if err != nil {
434
+		return fmt.Errorf("failed to fetch emails: %w", err)
435
+	}
436
+	defer resp.Body.Close()
437
+
438
+	// check the response status code
439
+	if resp.StatusCode != http.StatusOK {
440
+		return fmt.Errorf("failed to fetch emails: %s", resp.Status)
441
+	}
442
+
443
+	// decode the response
444
+	var emails []struct {
445
+		Email    string `json:"email"`
446
+		Primary  bool   `json:"primary"`
447
+		Verified bool   `json:"verified"`
448
+	}
449
+
450
+	if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil {
451
+		return fmt.Errorf("failed to decode response: %w", err)
452
+	}
453
+
454
+	// find the primary verified email
455
+	for _, e := range emails {
456
+		if e.Primary && e.Verified {
457
+			githubUser.Email = e.Email
458
+			githubUser.VerifiedEmail = e.Verified
459
+			return nil
460
+		}
461
+	}
462
+
463
+	return fmt.Errorf("no primary verified email found")
464
+}

+ 68 - 4
service/peer.go

@@ -26,15 +26,43 @@ func (ps *PeerService) InfoByRowId(id uint) *model.Peer {
26
 	return p
26
 	return p
27
 }
27
 }
28
 
28
 
29
+// FindByUserIdAndUuid 根据用户id和uuid查找peer
30
+func (ps *PeerService) FindByUserIdAndUuid(uuid string,userId uint) *model.Peer {
31
+	p := &model.Peer{}
32
+	global.DB.Where("uuid = ? and user_id = ?", uuid, userId).First(p)
33
+	return p
34
+}
35
+
29
 // UuidBindUserId 绑定用户id
36
 // UuidBindUserId 绑定用户id
30
-func (ps *PeerService) UuidBindUserId(uuid string, userId uint) {
37
+func (ps *PeerService) UuidBindUserId(deviceId string, uuid string, userId uint) {
31
 	peer := ps.FindByUuid(uuid)
38
 	peer := ps.FindByUuid(uuid)
39
+	// 如果存在则更新
32
 	if peer.RowId > 0 {
40
 	if peer.RowId > 0 {
33
 		peer.UserId = userId
41
 		peer.UserId = userId
34
 		ps.Update(peer)
42
 		ps.Update(peer)
43
+	} else {
44
+		// 不存在则创建
45
+		global.DB.Create(&model.Peer{
46
+			Id: 		deviceId,
47
+			Uuid:     	uuid,
48
+			UserId:   	userId,
49
+		})
50
+	}
51
+}
52
+
53
+// UuidUnbindUserId 解绑用户id, 用于用户注销
54
+func (ps *PeerService) UuidUnbindUserId(uuid string, userId uint) {
55
+	peer := ps.FindByUserIdAndUuid(uuid, userId)
56
+	if peer.RowId > 0 {
57
+		global.DB.Model(peer).Update("user_id", 0)
35
 	}
58
 	}
36
 }
59
 }
37
 
60
 
61
+// EraseUserId 清除用户id, 用于用户删除
62
+func (ps *PeerService) EraseUserId(userId uint) error {
63
+	return global.DB.Model(&model.Peer{}).Where("user_id = ?", userId).Update("user_id", 0).Error
64
+}
65
+
38
 // ListByUserIds 根据用户id取列表
66
 // ListByUserIds 根据用户id取列表
39
 func (ps *PeerService) ListByUserIds(userIds []uint, page, pageSize uint) (res *model.PeerList) {
67
 func (ps *PeerService) ListByUserIds(userIds []uint, page, pageSize uint) (res *model.PeerList) {
40
 	res = &model.PeerList{}
68
 	res = &model.PeerList{}
@@ -62,21 +90,57 @@ func (ps *PeerService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *
62
 	return
90
 	return
63
 }
91
 }
64
 
92
 
93
+// ListFilterByUserId 根据用户id过滤Peer列表
94
+func (ps *PeerService) ListFilterByUserId(page, pageSize uint, where func(tx *gorm.DB), userId uint) (res *model.PeerList) {
95
+	userWhere := func(tx *gorm.DB) {
96
+		tx.Where("user_id = ?", userId)
97
+		// 如果还有额外的筛选条件,执行它
98
+		if where != nil {
99
+			where(tx)
100
+		}
101
+	}
102
+	return ps.List(page, pageSize, userWhere)
103
+}
104
+
65
 // Create 创建
105
 // Create 创建
66
 func (ps *PeerService) Create(u *model.Peer) error {
106
 func (ps *PeerService) Create(u *model.Peer) error {
67
 	res := global.DB.Create(u).Error
107
 	res := global.DB.Create(u).Error
68
 	return res
108
 	return res
69
 }
109
 }
110
+
111
+// Delete 删除, 同时也应该删除token
70
 func (ps *PeerService) Delete(u *model.Peer) error {
112
 func (ps *PeerService) Delete(u *model.Peer) error {
71
-	return global.DB.Delete(u).Error
113
+	uuid := u.Uuid
114
+	err := global.DB.Delete(u).Error
115
+	if err != nil {
116
+		return err
117
+	}
118
+	// 删除token
119
+	return AllService.UserService.FlushTokenByUuid(uuid)
72
 }
120
 }
73
 
121
 
74
-// BatchDelete
122
+// GetUuidListByIDs 根据ids获取uuid列表
123
+func (ps *PeerService) GetUuidListByIDs(ids []uint) ([]string, error) {
124
+	var uuids []string
125
+	err := global.DB.Model(&model.Peer{}).
126
+		Where("row_id in (?)", ids).
127
+		Pluck("uuid", &uuids).Error
128
+	return uuids, err
129
+}
130
+
131
+// BatchDelete 批量删除, 同时也应该删除token
75
 func (ps *PeerService) BatchDelete(ids []uint) error {
132
 func (ps *PeerService) BatchDelete(ids []uint) error {
76
-	return global.DB.Where("row_id in (?)", ids).Delete(&model.Peer{}).Error
133
+	uuids, err := ps.GetUuidListByIDs(ids)
134
+	err = global.DB.Where("row_id in (?)", ids).Delete(&model.Peer{}).Error
135
+	if err != nil {
136
+		return err
137
+	}
138
+	// 删除token
139
+	return AllService.UserService.FlushTokenByUuids(uuids)
77
 }
140
 }
78
 
141
 
79
 // Update 更新
142
 // Update 更新
80
 func (ps *PeerService) Update(u *model.Peer) error {
143
 func (ps *PeerService) Update(u *model.Peer) error {
81
 	return global.DB.Model(u).Updates(u).Error
144
 	return global.DB.Model(u).Updates(u).Error
82
 }
145
 }
146
+

+ 131 - 66
service/user.go

@@ -10,6 +10,8 @@ import (
10
 	"math/rand"
10
 	"math/rand"
11
 	"strconv"
11
 	"strconv"
12
 	"time"
12
 	"time"
13
+	"strings"
14
+	"errors"
13
 )
15
 )
14
 
16
 
15
 type UserService struct {
17
 type UserService struct {
@@ -21,12 +23,20 @@ func (us *UserService) InfoById(id uint) *model.User {
21
 	global.DB.Where("id = ?", id).First(u)
23
 	global.DB.Where("id = ?", id).First(u)
22
 	return u
24
 	return u
23
 }
25
 }
26
+// InfoByUsername 根据用户名取用户信息
24
 func (us *UserService) InfoByUsername(un string) *model.User {
27
 func (us *UserService) InfoByUsername(un string) *model.User {
25
 	u := &model.User{}
28
 	u := &model.User{}
26
 	global.DB.Where("username = ?", un).First(u)
29
 	global.DB.Where("username = ?", un).First(u)
27
 	return u
30
 	return u
28
 }
31
 }
29
 
32
 
33
+// InfoByEmail 根据邮箱取用户信息
34
+func (us *UserService) InfoByEmail(email string) *model.User {
35
+	u := &model.User{}
36
+	global.DB.Where("email = ?", email).First(u)
37
+	return u
38
+}
39
+
30
 // InfoByOpenid 根据openid取用户信息
40
 // InfoByOpenid 根据openid取用户信息
31
 func (us *UserService) InfoByOpenid(openid string) *model.User {
41
 func (us *UserService) InfoByOpenid(openid string) *model.User {
32
 	u := &model.User{}
42
 	u := &model.User{}
@@ -65,15 +75,17 @@ func (us *UserService) GenerateToken(u *model.User) string {
65
 func (us *UserService) Login(u *model.User, llog *model.LoginLog) *model.UserToken {
75
 func (us *UserService) Login(u *model.User, llog *model.LoginLog) *model.UserToken {
66
 	token := us.GenerateToken(u)
76
 	token := us.GenerateToken(u)
67
 	ut := &model.UserToken{
77
 	ut := &model.UserToken{
68
-		UserId:    u.Id,
69
-		Token:     token,
70
-		ExpiredAt: time.Now().Add(time.Hour * 24 * 7).Unix(),
78
+		UserId:    	u.Id,
79
+		Token:     	token,
80
+		DeviceUuid: llog.Uuid,
81
+		DeviceId:   llog.DeviceId,
82
+		ExpiredAt: 	time.Now().Add(time.Hour * 24 * 7).Unix(),
71
 	}
83
 	}
72
 	global.DB.Create(ut)
84
 	global.DB.Create(ut)
73
 	llog.UserTokenId = ut.UserId
85
 	llog.UserTokenId = ut.UserId
74
 	global.DB.Create(llog)
86
 	global.DB.Create(llog)
75
 	if llog.Uuid != "" {
87
 	if llog.Uuid != "" {
76
-		AllService.PeerService.UuidBindUserId(llog.Uuid, u.Id)
88
+		AllService.PeerService.UuidBindUserId(llog.DeviceId, llog.Uuid, u.Id)
77
 	}
89
 	}
78
 	return ut
90
 	return ut
79
 }
91
 }
@@ -140,18 +152,42 @@ func (us *UserService) CheckUserEnable(u *model.User) bool {
140
 
152
 
141
 // Create 创建
153
 // Create 创建
142
 func (us *UserService) Create(u *model.User) error {
154
 func (us *UserService) Create(u *model.User) error {
155
+	// The initial username should be formatted, and the username should be unique
156
+	u.Username = us.formatUsername(u.Username)
143
 	u.Password = us.EncryptPassword(u.Password)
157
 	u.Password = us.EncryptPassword(u.Password)
144
 	res := global.DB.Create(u).Error
158
 	res := global.DB.Create(u).Error
145
 	return res
159
 	return res
146
 }
160
 }
147
 
161
 
148
-// Logout 退出登录
162
+// GetUuidByToken 根据token和user取uuid
163
+func (us *UserService) GetUuidByToken(u *model.User, token string) string {
164
+	ut := &model.UserToken{}
165
+	err :=global.DB.Where("user_id = ? and token = ?", u.Id, token).First(ut).Error
166
+	if err != nil {
167
+		return ""
168
+	}
169
+	return ut.DeviceUuid
170
+}
171
+
172
+// Logout 退出登录 -> 删除token, 解绑uuid
149
 func (us *UserService) Logout(u *model.User, token string) error {
173
 func (us *UserService) Logout(u *model.User, token string) error {
150
-	return global.DB.Where("user_id = ? and token = ?", u.Id, token).Delete(&model.UserToken{}).Error
174
+	uuid := us.GetUuidByToken(u, token)
175
+	err := global.DB.Where("user_id = ? and token = ?", u.Id, token).Delete(&model.UserToken{}).Error
176
+	if err != nil {
177
+		return err
178
+	}
179
+	if uuid != "" {
180
+		AllService.PeerService.UuidUnbindUserId(uuid, u.Id)
181
+	}
182
+	return nil
151
 }
183
 }
152
 
184
 
153
 // Delete 删除用户和oauth信息
185
 // Delete 删除用户和oauth信息
154
 func (us *UserService) Delete(u *model.User) error {
186
 func (us *UserService) Delete(u *model.User) error {
187
+	userCount := us.getAdminUserCount()
188
+	if userCount <= 1 && us.IsAdmin(u) {
189
+		return errors.New("The last admin user cannot be deleted")
190
+	}
155
 	tx := global.DB.Begin()
191
 	tx := global.DB.Begin()
156
 	// 删除用户
192
 	// 删除用户
157
 	if err := tx.Delete(u).Error; err != nil {
193
 	if err := tx.Delete(u).Error; err != nil {
@@ -179,11 +215,25 @@ func (us *UserService) Delete(u *model.User) error {
179
 		return err
215
 		return err
180
 	}
216
 	}
181
 	tx.Commit()
217
 	tx.Commit()
218
+	// 删除关联的peer
219
+	if err := AllService.PeerService.EraseUserId(u.Id); err != nil {
220
+		tx.Rollback()
221
+		return err
222
+	}
182
 	return nil
223
 	return nil
183
 }
224
 }
184
 
225
 
185
 // Update 更新
226
 // Update 更新
186
 func (us *UserService) Update(u *model.User) error {
227
 func (us *UserService) Update(u *model.User) error {
228
+	currentUser := us.InfoById(u.Id)
229
+	// 如果当前用户是管理员并且 IsAdmin 不为空,进行检查
230
+	if us.IsAdmin(currentUser) {
231
+		adminCount := us.getAdminUserCount()
232
+		// 如果这是唯一的管理员,确保不能禁用或取消管理员权限
233
+		if adminCount <= 1 && ( !us.IsAdmin(u) || u.Status == model.COMMON_STATUS_DISABLED) {
234
+			return errors.New("The last admin user cannot be disabled or demoted")
235
+		}
236
+	}
187
 	return global.DB.Model(u).Updates(u).Error
237
 	return global.DB.Model(u).Updates(u).Error
188
 }
238
 }
189
 
239
 
@@ -192,6 +242,16 @@ func (us *UserService) FlushToken(u *model.User) error {
192
 	return global.DB.Where("user_id = ?", u.Id).Delete(&model.UserToken{}).Error
242
 	return global.DB.Where("user_id = ?", u.Id).Delete(&model.UserToken{}).Error
193
 }
243
 }
194
 
244
 
245
+// FlushTokenByUuid 清空token
246
+func (us *UserService) FlushTokenByUuid(uuid string) error {
247
+	return global.DB.Where("device_uuid = ?", uuid).Delete(&model.UserToken{}).Error
248
+}
249
+
250
+// FlushTokenByUuids 清空token
251
+func (us *UserService) FlushTokenByUuids(uuids []string) error {
252
+	return global.DB.Where("device_uuid in (?)", uuids).Delete(&model.UserToken{}).Error
253
+}
254
+
195
 // UpdatePassword 更新密码
255
 // UpdatePassword 更新密码
196
 func (us *UserService) UpdatePassword(u *model.User, password string) error {
256
 func (us *UserService) UpdatePassword(u *model.User, password string) error {
197
 	u.Password = us.EncryptPassword(password)
257
 	u.Password = us.EncryptPassword(password)
@@ -216,24 +276,9 @@ func (us *UserService) RouteNames(u *model.User) []string {
216
 	return adResp.UserRouteNames
276
 	return adResp.UserRouteNames
217
 }
277
 }
218
 
278
 
219
-// InfoByGithubId 根据githubid取用户信息
220
-func (us *UserService) InfoByGithubId(githubId string) *model.User {
221
-	return us.InfoByOauthId(model.OauthTypeGithub, githubId)
222
-}
223
-
224
-// InfoByGoogleEmail 根据googleid取用户信息
225
-func (us *UserService) InfoByGoogleEmail(email string) *model.User {
226
-	return us.InfoByOauthId(model.OauthTypeGithub, email)
227
-}
228
-
229
-// InfoByOidcSub 根据oidc取用户信息
230
-func (us *UserService) InfoByOidcSub(sub string) *model.User {
231
-	return us.InfoByOauthId(model.OauthTypeOidc, sub)
232
-}
233
-
234
-// InfoByOauthId 根据oauth取用户信息
235
-func (us *UserService) InfoByOauthId(thirdType, uid string) *model.User {
236
-	ut := AllService.OauthService.UserThirdInfo(thirdType, uid)
279
+// InfoByOauthId 根据oauth的name和openId取用户信息
280
+func (us *UserService) InfoByOauthId(op string, openId string) *model.User {
281
+	ut := AllService.OauthService.UserThirdInfo(op, openId)
237
 	if ut.Id == 0 {
282
 	if ut.Id == 0 {
238
 		return nil
283
 		return nil
239
 	}
284
 	}
@@ -244,55 +289,52 @@ func (us *UserService) InfoByOauthId(thirdType, uid string) *model.User {
244
 	return u
289
 	return u
245
 }
290
 }
246
 
291
 
247
-// RegisterByGithub 注册
248
-func (us *UserService) RegisterByGithub(githubName string, githubId string) *model.User {
249
-	return us.RegisterByOauth(model.OauthTypeGithub, githubName, githubId)
250
-}
251
-
252
-// RegisterByGoogle 注册
253
-func (us *UserService) RegisterByGoogle(name string, email string) *model.User {
254
-	return us.RegisterByOauth(model.OauthTypeGoogle, name, email)
255
-}
256
-
257
-// RegisterByOidc 注册, use PreferredUsername as username, sub as openid
258
-func (us *UserService) RegisterByOidc(PreferredUsername string, sub string) *model.User {
259
-	return us.RegisterByOauth(model.OauthTypeOidc, PreferredUsername, sub)
260
-}
261
-
262
 // RegisterByOauth 注册
292
 // RegisterByOauth 注册
263
-func (us *UserService) RegisterByOauth(thirdType, thirdName, uid string) *model.User {
293
+func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser , op string) (error, *model.User) {
264
 	global.Lock.Lock("registerByOauth")
294
 	global.Lock.Lock("registerByOauth")
265
 	defer global.Lock.UnLock("registerByOauth")
295
 	defer global.Lock.UnLock("registerByOauth")
266
-	ut := AllService.OauthService.UserThirdInfo(thirdType, uid)
296
+	ut := AllService.OauthService.UserThirdInfo(op, oauthUser.OpenId)
267
 	if ut.Id != 0 {
297
 	if ut.Id != 0 {
268
-		u := &model.User{}
269
-		global.DB.Where("id = ?", ut.UserId).First(u)
270
-		return u
271
-	}
272
-
273
-	tx := global.DB.Begin()
274
-	ut = &model.UserThird{
275
-		OpenId:    uid,
276
-		ThirdName: thirdName,
277
-		ThirdType: thirdType,
298
+		return nil, us.InfoById(ut.UserId)
278
 	}
299
 	}
279
-
280
-	username := us.GenerateUsernameByOauth(thirdName)
281
-	u := &model.User{
282
-		Username: username,
283
-		GroupId:  1,
300
+	//check if this email has been registered 
301
+	email := oauthUser.Email
302
+	err, oauthType := AllService.OauthService.GetTypeByOp(op)
303
+	if err != nil {
304
+		return err, nil
284
 	}
305
 	}
285
-	tx.Create(u)
286
-	if u.Id == 0 {
287
-		tx.Rollback()
288
-		return u
306
+	// if email is empty, use username and op as email
307
+	if email == "" {
308
+		email = oauthUser.Username + "@" + op
309
+	} 
310
+	email = strings.ToLower(email)
311
+	// update email to oauthUser, in case it contain upper case
312
+	oauthUser.Email = email
313
+	user := us.InfoByEmail(email)
314
+	tx := global.DB.Begin()
315
+	if user.Id != 0 {
316
+		ut.FromOauthUser(user.Id, oauthUser, oauthType, op)
317
+	} else {
318
+		ut = &model.UserThird{}
319
+		ut.FromOauthUser(0, oauthUser, oauthType, op)
320
+		// The initial username should be formatted
321
+		username := us.formatUsername(oauthUser.Username)
322
+		usernameUnique := us.GenerateUsernameByOauth(username)
323
+		user = &model.User{
324
+			Username: usernameUnique,
325
+			GroupId:  1,
326
+		}
327
+		oauthUser.ToUser(user, false)
328
+		tx.Create(user)
329
+		if user.Id == 0 {
330
+			tx.Rollback()
331
+			return errors.New("OauthRegisterFailed"), user
332
+		}
333
+		ut.UserId = user.Id
289
 	}
334
 	}
290
-
291
-	ut.UserId = u.Id
292
 	tx.Create(ut)
335
 	tx.Create(ut)
293
-
294
 	tx.Commit()
336
 	tx.Commit()
295
-	return u
337
+	return nil, user
296
 }
338
 }
297
 
339
 
298
 // GenerateUsernameByOauth 生成用户名
340
 // GenerateUsernameByOauth 生成用户名
@@ -314,7 +356,7 @@ func (us *UserService) UserThirdsByUserId(userId uint) (res []*model.UserThird)
314
 
356
 
315
 func (us *UserService) UserThirdInfo(userId uint, op string) *model.UserThird {
357
 func (us *UserService) UserThirdInfo(userId uint, op string) *model.UserThird {
316
 	ut := &model.UserThird{}
358
 	ut := &model.UserThird{}
317
-	global.DB.Where("user_id = ? and third_type = ?", userId, op).First(ut)
359
+	global.DB.Where("user_id = ? and op = ?", userId, op).First(ut)
318
 	return ut
360
 	return ut
319
 }
361
 }
320
 
362
 
@@ -348,9 +390,11 @@ func (us *UserService) IsPasswordEmptyByUser(u *model.User) bool {
348
 	return us.IsPasswordEmptyById(u.Id)
390
 	return us.IsPasswordEmptyById(u.Id)
349
 }
391
 }
350
 
392
 
351
-func (us *UserService) Register(username string, password string) *model.User {
393
+// Register 注册
394
+func (us *UserService) Register(username string, email string, password string) *model.User {
352
 	u := &model.User{
395
 	u := &model.User{
353
 		Username: username,
396
 		Username: username,
397
+		Email:    email,
354
 		Password: us.EncryptPassword(password),
398
 		Password: us.EncryptPassword(password),
355
 		GroupId:  1,
399
 		GroupId:  1,
356
 	}
400
 	}
@@ -381,3 +425,24 @@ func (us *UserService) TokenInfoById(id uint) *model.UserToken {
381
 func (us *UserService) DeleteToken(l *model.UserToken) error {
425
 func (us *UserService) DeleteToken(l *model.UserToken) error {
382
 	return global.DB.Delete(l).Error
426
 	return global.DB.Delete(l).Error
383
 }
427
 }
428
+
429
+// Helper functions, used for formatting username
430
+func (us *UserService) formatUsername(username string) string {
431
+	username = strings.ReplaceAll(username, " ", "")
432
+	username = strings.ToLower(username)
433
+	return username
434
+}
435
+
436
+//  Helper functions, getUserCount
437
+func (us *UserService) getUserCount() int64 {
438
+	var count int64
439
+	global.DB.Model(&model.User{}).Count(&count)
440
+	return count
441
+}
442
+
443
+// helper functions, getAdminUserCount
444
+func (us *UserService) getAdminUserCount() int64 {
445
+	var count int64
446
+	global.DB.Model(&model.User{}).Where("is_admin = ?", true).Count(&count)
447
+	return count
448
+}