Просмотр исходного кода

Merge branch 'oauth_re' of https://github.com/IamTaoChen/rustdesk-api into IamTaoChen-oauth_re

ljw 1 год назад
Родитель
Сommit
4321a41cd7

+ 4 - 1
Dockerfile.dev

@@ -42,8 +42,11 @@ RUN if [ "$COUNTRY" = "CN" ] ; then \
42 42
     fi && \
43 43
     apk update && apk add --no-cache git
44 44
 
45
+ARG FREONTEND_GIT_REPO=https://github.com/lejianwen/rustdesk-api-web.git
46
+ARG FRONTEND_GIT_BRANCH=master
45 47
 # Clone the frontend repository
46
-RUN git clone https://github.com/lejianwen/rustdesk-api-web .
48
+
49
+RUN git clone -b $FRONTEND_GIT_BRANCH $FREONTEND_GIT_REPO .
47 50
 
48 51
 # Install required tools without caching index to minimize image size
49 52
 RUN if [ "$COUNTRY" = "CN" ] ; then \

+ 2 - 0
docker-compose-dev.yaml

@@ -5,6 +5,8 @@ services:
5 5
       dockerfile: Dockerfile.dev
6 6
       args:
7 7
         COUNTRY: CN
8
+        FREONTEND_GIT_REPO: https://github.com/lejianwen/rustdesk-api-web.git
9
+        FRONTEND_GIT_BRANCH: master
8 10
     # image: lejianwen/rustdesk-api
9 11
     container_name: rustdesk-api
10 12
     environment:

+ 12 - 22
http/controller/admin/login.go

@@ -11,7 +11,6 @@ import (
11 11
 	"Gwen/service"
12 12
 	"fmt"
13 13
 	"github.com/gin-gonic/gin"
14
-	"gorm.io/gorm"
15 14
 )
16 15
 
17 16
 type Login struct {
@@ -60,12 +59,7 @@ func (ct *Login) Login(c *gin.Context) {
60 59
 		Platform: f.Platform,
61 60
 	})
62 61
 
63
-	response.Success(c, &adResp.LoginPayload{
64
-		Token:      ut.Token,
65
-		Username:   u.Username,
66
-		RouteNames: service.AllService.UserService.RouteNames(u),
67
-		Nickname:   u.Nickname,
68
-	})
62
+	responseLoginSuccess(c, u, ut.Token)
69 63
 }
70 64
 
71 65
 // Logout 登出
@@ -96,13 +90,7 @@ func (ct *Login) Logout(c *gin.Context) {
96 90
 // @Failure 500 {object} response.ErrorResponse
97 91
 // @Router /admin/login-options [post]
98 92
 func (ct *Login) LoginOptions(c *gin.Context) {
99
-	res := service.AllService.OauthService.List(1, 100, func(tx *gorm.DB) {
100
-		tx.Select("op").Order("id")
101
-	})
102
-	var ops []string
103
-	for _, v := range res.Oauths {
104
-		ops = append(ops, v.Op)
105
-	}
93
+	ops := service.AllService.OauthService.GetOauthProviders()
106 94
 	response.Success(c, gin.H{
107 95
 		"ops":      ops,
108 96
 		"register": global.Config.App.Register,
@@ -163,12 +151,14 @@ func (ct *Login) OidcAuthQuery(c *gin.Context) {
163 151
 	if ut == nil {
164 152
 		return
165 153
 	}
166
-	//fmt.Println("u:", u)
167
-	//fmt.Println("ut:", ut)
168
-	response.Success(c, &adResp.LoginPayload{
169
-		Token:      ut.Token,
170
-		Username:   u.Username,
171
-		RouteNames: service.AllService.UserService.RouteNames(u),
172
-		Nickname:   u.Nickname,
173
-	})
154
+	responseLoginSuccess(c, u, ut.Token)
174 155
 }
156
+
157
+
158
+func responseLoginSuccess(c *gin.Context, u *model.User, token string) {
159
+	lp := &adResp.LoginPayload{}
160
+	lp.FromUser(u)
161
+	lp.Token = token
162
+	lp.RouteNames = service.AllService.UserService.RouteNames(u)
163
+	response.Success(c, lp)
164
+}

+ 13 - 29
http/controller/admin/oauth.go

@@ -5,7 +5,6 @@ import (
5 5
 	"Gwen/http/request/admin"
6 6
 	adminReq "Gwen/http/request/admin"
7 7
 	"Gwen/http/response"
8
-	"Gwen/model"
9 8
 	"Gwen/service"
10 9
 	"github.com/gin-gonic/gin"
11 10
 	"strconv"
@@ -96,21 +95,23 @@ func (o *Oauth) BindConfirm(c *gin.Context) {
96 95
 		response.Fail(c, 101, response.TranslateMsg(c, "ParamsError"))
97 96
 		return
98 97
 	}
99
-	v := service.AllService.OauthService.GetOauthCache(j.Code)
100
-	if v == nil {
98
+	oauthService := service.AllService.OauthService
99
+	oauthCache := oauthService.GetOauthCache(j.Code)
100
+	if oauthCache == nil {
101 101
 		response.Fail(c, 101, response.TranslateMsg(c, "OauthExpired"))
102 102
 		return
103 103
 	}
104
-	u := service.AllService.UserService.CurUser(c)
105
-	err = service.AllService.OauthService.BindOauthUser(v.Op, v.ThirdOpenId, v.ThirdName, u.Id)
104
+	oauthUser := oauthCache.ToOauthUser()
105
+	user := service.AllService.UserService.CurUser(c)
106
+	err = oauthService.BindOauthUser(user.Id, oauthUser, oauthCache.Op)
106 107
 	if err != nil {
107 108
 		response.Fail(c, 101, response.TranslateMsg(c, "BindFail"))
108 109
 		return
109 110
 	}
110 111
 
111
-	v.UserId = u.Id
112
-	service.AllService.OauthService.SetOauthCache(j.Code, v, 0)
113
-	response.Success(c, v)
112
+	oauthCache.UserId = user.Id
113
+	oauthService.SetOauthCache(j.Code, oauthCache, 0)
114
+	response.Success(c, oauthCache)
114 115
 }
115 116
 
116 117
 func (o *Oauth) Unbind(c *gin.Context) {
@@ -126,28 +127,11 @@ func (o *Oauth) Unbind(c *gin.Context) {
126 127
 		response.Fail(c, 101, response.TranslateMsg(c, "ItemNotFound"))
127 128
 		return
128 129
 	}
129
-	if f.Op == model.OauthTypeGithub {
130
-		err = service.AllService.OauthService.UnBindGithubUser(u.Id)
131
-		if err != nil {
132
-			response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
133
-			return
134
-		}
135
-	}
136
-	if f.Op == model.OauthTypeGoogle {
137
-		err = service.AllService.OauthService.UnBindGoogleUser(u.Id)
138
-		if err != nil {
139
-			response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
140
-			return
141
-		}
142
-	}
143
-	if f.Op == model.OauthTypeOidc {
144
-		err = service.AllService.OauthService.UnBindOidcUser(u.Id)
145
-		if err != nil {
146
-			response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
147
-			return
148
-		}
130
+	err = service.AllService.OauthService.UnBindOauthUser(u.Id, f.Op)
131
+	if err != nil {
132
+		response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
133
+		return
149 134
 	}
150
-
151 135
 	response.Success(c, nil)
152 136
 }
153 137
 

+ 51 - 15
http/controller/admin/user.go

@@ -10,6 +10,7 @@ import (
10 10
 	"github.com/gin-gonic/gin"
11 11
 	"gorm.io/gorm"
12 12
 	"strconv"
13
+	"time"
13 14
 )
14 15
 
15 16
 type User struct {
@@ -216,12 +217,7 @@ func (ct *User) Current(c *gin.Context) {
216 217
 	u := service.AllService.UserService.CurUser(c)
217 218
 	token, _ := c.Get("token")
218 219
 	t := token.(string)
219
-	response.Success(c, &adResp.LoginPayload{
220
-		Token:      t,
221
-		Username:   u.Username,
222
-		RouteNames: service.AllService.UserService.RouteNames(u),
223
-		Nickname:   u.Nickname,
224
-	})
220
+	responseLoginSuccess(c, u, t)
225 221
 }
226 222
 
227 223
 // ChangeCurPwd 修改当前用户密码
@@ -286,10 +282,10 @@ func (ct *User) MyOauth(c *gin.Context) {
286 282
 	var res []*adResp.UserOauthItem
287 283
 	for _, oa := range oal.Oauths {
288 284
 		item := &adResp.UserOauthItem{
289
-			ThirdType: oa.Op,
285
+			Op: oa.Op,
290 286
 		}
291 287
 		for _, ut := range uts {
292
-			if ut.ThirdType == oa.Op {
288
+			if ut.Op == oa.Op {
293 289
 				item.Status = 1
294 290
 				break
295 291
 			}
@@ -299,6 +295,51 @@ func (ct *User) MyOauth(c *gin.Context) {
299 295
 	response.Success(c, res)
300 296
 }
301 297
 
298
+// List 列表
299
+// @Tags 设备
300
+// @Summary 设备列表
301
+// @Description 设备列表
302
+// @Accept  json
303
+// @Produce  json
304
+// @Param page query int false "页码"
305
+// @Param page_size query int false "页大小"
306
+// @Param time_ago query int false "时间"
307
+// @Param id query string false "ID"
308
+// @Param hostname query string false "主机名"
309
+// @Param uuids query string false "uuids 用逗号分隔"
310
+// @Success 200 {object} response.Response{data=model.PeerList}
311
+// @Failure 500 {object} response.Response
312
+// @Router /admin/user/myPeer [get]
313
+// @Security token
314
+func (ct *User) MyPeer(c *gin.Context) {
315
+	query := &admin.PeerQuery{}
316
+	if err := c.ShouldBindQuery(query); err != nil {
317
+		response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
318
+		return
319
+	}
320
+	u := service.AllService.UserService.CurUser(c)
321
+	res := service.AllService.PeerService.ListFilterByUserId(query.Page, query.PageSize, func(tx *gorm.DB) {
322
+		if query.TimeAgo > 0 {
323
+			lt := time.Now().Unix() - int64(query.TimeAgo)
324
+			tx.Where("last_online_time < ?", lt)
325
+		}
326
+		if query.TimeAgo < 0 {
327
+			lt := time.Now().Unix() + int64(query.TimeAgo)
328
+			tx.Where("last_online_time > ?", lt)
329
+		}
330
+		if query.Id != "" {
331
+			tx.Where("id like ?", "%"+query.Id+"%")
332
+		}
333
+		if query.Hostname != "" {
334
+			tx.Where("hostname like ?", "%"+query.Hostname+"%")
335
+		}
336
+		if query.Uuids != "" {
337
+			tx.Where("uuid in (?)", query.Uuids)
338
+		}
339
+	}, u.Id)
340
+	response.Success(c, res)
341
+}
342
+
302 343
 // groupUsers
303 344
 func (ct *User) GroupUsers(c *gin.Context) {
304 345
 	q := &admin.GroupUsersQuery{}
@@ -345,7 +386,7 @@ func (ct *User) Register(c *gin.Context) {
345 386
 		response.Fail(c, 101, errList[0])
346 387
 		return
347 388
 	}
348
-	u := service.AllService.UserService.Register(f.Username, f.Password)
389
+	u := service.AllService.UserService.Register(f.Username, f.Email, f.Password)
349 390
 	if u == nil || u.Id == 0 {
350 391
 		response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed"))
351 392
 		return
@@ -358,10 +399,5 @@ func (ct *User) Register(c *gin.Context) {
358 399
 		Ip:     c.ClientIP(),
359 400
 		Type:   model.LoginLogTypeAccount,
360 401
 	})
361
-	response.Success(c, &adResp.LoginPayload{
362
-		Token:      ut.Token,
363
-		Username:   u.Username,
364
-		RouteNames: service.AllService.UserService.RouteNames(u),
365
-		Nickname:   u.Nickname,
366
-	})
402
+	responseLoginSuccess(c, u, ut.Token)
367 403
 }

+ 5 - 16
http/controller/api/login.go

@@ -60,6 +60,7 @@ func (l *Login) Login(c *gin.Context) {
60 60
 	ut := service.AllService.UserService.Login(u, &model.LoginLog{
61 61
 		UserId:   u.Id,
62 62
 		Client:   f.DeviceInfo.Type,
63
+		DeviceId: f.Id,
63 64
 		Uuid:     f.Uuid,
64 65
 		Ip:       c.ClientIP(),
65 66
 		Type:     model.LoginLogTypeAccount,
@@ -83,22 +84,10 @@ func (l *Login) Login(c *gin.Context) {
83 84
 // @Failure 500 {object} response.ErrorResponse
84 85
 // @Router /login-options [get]
85 86
 func (l *Login) LoginOptions(c *gin.Context) {
86
-	oauthOks := []string{}
87
-	err, _ := service.AllService.OauthService.GetOauthConfig(model.OauthTypeGithub)
88
-	if err == nil {
89
-		oauthOks = append(oauthOks, model.OauthTypeGithub)
90
-	}
91
-	err, _ = service.AllService.OauthService.GetOauthConfig(model.OauthTypeGoogle)
92
-	if err == nil {
93
-		oauthOks = append(oauthOks, model.OauthTypeGoogle)
94
-	}
95
-	err, _ = service.AllService.OauthService.GetOauthConfig(model.OauthTypeOidc)
96
-	if err == nil {
97
-		oauthOks = append(oauthOks, model.OauthTypeOidc)
98
-	}
99
-	oauthOks = append(oauthOks, model.OauthTypeWebauth)
87
+	ops := service.AllService.OauthService.GetOauthProviders()
88
+	ops = append(ops, model.OauthTypeWebauth)
100 89
 	var oidcItems []map[string]string
101
-	for _, v := range oauthOks {
90
+	for _, v := range ops {
102 91
 		oidcItems = append(oidcItems, map[string]string{"name": v})
103 92
 	}
104 93
 	common, err := json.Marshal(oidcItems)
@@ -108,7 +97,7 @@ func (l *Login) LoginOptions(c *gin.Context) {
108 97
 	}
109 98
 	var res []string
110 99
 	res = append(res, "common-oidc/"+string(common))
111
-	for _, v := range oauthOks {
100
+	for _, v := range ops {
112 101
 		res = append(res, "oidc/"+v)
113 102
 	}
114 103
 	c.JSON(http.StatusOK, res)

+ 38 - 69
http/controller/api/ouath.go

@@ -9,8 +9,6 @@ import (
9 9
 	"Gwen/service"
10 10
 	"github.com/gin-gonic/gin"
11 11
 	"net/http"
12
-	"strconv"
13
-	"strings"
14 12
 )
15 13
 
16 14
 type Oauth struct {
@@ -32,13 +30,11 @@ func (o *Oauth) OidcAuth(c *gin.Context) {
32 30
 		response.Error(c, response.TranslateMsg(c, "ParamsError")+err.Error())
33 31
 		return
34 32
 	}
35
-	//fmt.Println(f)
36
-	if f.Op != model.OauthTypeWebauth && f.Op != model.OauthTypeGoogle && f.Op != model.OauthTypeGithub && f.Op != model.OauthTypeOidc {
37
-		response.Error(c, response.TranslateMsg(c, "ParamsError"))
38
-		return
39
-	}
40 33
 
41
-	err, code, url := service.AllService.OauthService.BeginAuth(f.Op)
34
+	oauthService := service.AllService.OauthService
35
+	var code string
36
+	var url string
37
+	err, code, url = oauthService.BeginAuth(f.Op)
42 38
 	if err != nil {
43 39
 		response.Error(c, response.TranslateMsg(c, err.Error()))
44 40
 		return
@@ -98,6 +94,7 @@ func (o *Oauth) OidcAuthQueryPre(c *gin.Context) (*model.User, *model.UserToken)
98 94
 	ut = service.AllService.UserService.Login(u, &model.LoginLog{
99 95
 		UserId:   u.Id,
100 96
 		Client:   v.DeviceType,
97
+		DeviceId: v.Id,
101 98
 		Uuid:     v.Uuid,
102 99
 		Ip:       c.ClientIP(),
103 100
 		Type:     model.LoginLogTypeOauth,
@@ -149,70 +146,43 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
149 146
 		c.String(http.StatusInternalServerError, response.TranslateParamMsg(c, "ParamIsEmpty", "state"))
150 147
 		return
151 148
 	}
152
-
153 149
 	cacheKey := state
150
+	oauthService := service.AllService.OauthService
154 151
 	//从缓存中获取
155
-	v := service.AllService.OauthService.GetOauthCache(cacheKey)
156
-	if v == nil {
152
+	oauthCache := oauthService.GetOauthCache(cacheKey)
153
+	if oauthCache == nil {
157 154
 		c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthExpired"))
158 155
 		return
159 156
 	}
160
-
161
-	ty := v.Op
162
-	ac := v.Action
163
-	var u *model.User
164
-	openid := ""
165
-	thirdName := ""
166
-	//fmt.Println("ty ac ", ty, ac)
167
-
168
-	if ty == model.OauthTypeGithub {
169
-		code := c.Query("code")
170
-		err, userData := service.AllService.OauthService.GithubCallback(code)
171
-		if err != nil {
172
-			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
173
-			return
174
-		}
175
-		openid = strconv.Itoa(userData.Id)
176
-		thirdName = userData.Login
177
-	} else if ty == model.OauthTypeGoogle {
178
-		code := c.Query("code")
179
-		err, userData := service.AllService.OauthService.GoogleCallback(code)
180
-		if err != nil {
181
-			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
182
-			return
183
-		}
184
-		openid = userData.Email
185
-		//将空格替换成_
186
-		thirdName = strings.Replace(userData.Name, " ", "_", -1)
187
-	} else if ty == model.OauthTypeOidc {
188
-		code := c.Query("code")
189
-		err, userData := service.AllService.OauthService.OidcCallback(code)
190
-		if err != nil {
191
-			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
192
-			return
193
-		}
194
-		openid = userData.Sub
195
-		thirdName = userData.PreferredUsername
196
-	} else {
197
-		c.String(http.StatusInternalServerError, response.TranslateMsg(c, "ParamsError"))
157
+	op := oauthCache.Op
158
+	action := oauthCache.Action
159
+	var user *model.User
160
+	// 获取用户信息
161
+	code := c.Query("code")
162
+	err, oauthUser := oauthService.Callback(code, op)
163
+	if err != nil {
164
+		c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
198 165
 		return
199 166
 	}
200
-	if ac == service.OauthActionTypeBind {
167
+	userId := oauthCache.UserId
168
+	openid := oauthUser.OpenId
169
+	if action == service.OauthActionTypeBind {
201 170
 
202 171
 		//fmt.Println("bind", ty, userData)
203
-		utr := service.AllService.OauthService.UserThirdInfo(ty, openid)
172
+		// 检查此openid是否已经绑定过
173
+		utr := oauthService.UserThirdInfo(op, openid)
204 174
 		if utr.UserId > 0 {
205 175
 			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthHasBindOtherUser"))
206 176
 			return
207 177
 		}
208 178
 		//绑定
209
-		u = service.AllService.UserService.InfoById(v.UserId)
210
-		if u == nil {
179
+		user = service.AllService.UserService.InfoById(userId)
180
+		if user == nil {
211 181
 			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "ItemNotFound"))
212 182
 			return
213 183
 		}
214 184
 		//绑定
215
-		err := service.AllService.OauthService.BindOauthUser(ty, openid, thirdName, v.UserId)
185
+		err := oauthService.BindOauthUser(userId, oauthUser, op)
216 186
 		if err != nil {
217 187
 			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "BindFail"))
218 188
 			return
@@ -220,42 +190,41 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
220 190
 		c.String(http.StatusOK, response.TranslateMsg(c, "BindSuccess"))
221 191
 		return
222 192
 
223
-	} else if ac == service.OauthActionTypeLogin {
193
+	} else if action == service.OauthActionTypeLogin {
224 194
 		//登录
225
-		if v.UserId != 0 {
195
+		if userId != 0 {
226 196
 			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthHasBeenSuccess"))
227 197
 			return
228 198
 		}
229
-		u = service.AllService.UserService.InfoByGithubId(openid)
230
-		if u == nil {
231
-			oa := service.AllService.OauthService.InfoByOp(ty)
232
-			if !*oa.AutoRegister {
199
+		user = service.AllService.UserService.InfoByOauthId(op, openid)
200
+		if user == nil {
201
+			oauthConfig := oauthService.InfoByOp(op)
202
+			if !*oauthConfig.AutoRegister {
233 203
 				//c.String(http.StatusInternalServerError, "还未绑定用户,请先绑定")
234
-				v.ThirdName = thirdName
235
-				v.ThirdOpenId = openid
204
+				oauthCache.UpdateFromOauthUser(oauthUser)
236 205
 				url := global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/bind/" + cacheKey
237 206
 				c.Redirect(http.StatusFound, url)
238 207
 				return
239 208
 			}
240 209
 
241 210
 			//自动注册
242
-			u = service.AllService.UserService.RegisterByOauth(ty, thirdName, openid)
243
-			if u.Id == 0 {
244
-				c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthRegisterFailed"))
211
+			err, user = service.AllService.UserService.RegisterByOauth(oauthUser, op)
212
+			if err != nil {
213
+				c.String(http.StatusInternalServerError, response.TranslateMsg(c, err.Error()))
245 214
 				return
246 215
 			}
247 216
 		}
248
-		v.UserId = u.Id
249
-		service.AllService.OauthService.SetOauthCache(cacheKey, v, 0)
217
+		oauthCache.UserId = user.Id
218
+		oauthService.SetOauthCache(cacheKey, oauthCache, 0)
250 219
 		// 如果是webadmin,登录成功后跳转到webadmin
251
-		if v.DeviceType == "webadmin" {
220
+		if oauthCache.DeviceType == "webadmin" {
252 221
 			/*service.AllService.UserService.Login(u, &model.LoginLog{
253 222
 				UserId:   u.Id,
254 223
 				Client:   "webadmin",
255 224
 				Uuid:     "", //must be empty
256 225
 				Ip:       c.ClientIP(),
257 226
 				Type:     model.LoginLogTypeOauth,
258
-				Platform: v.DeviceOs,
227
+				Platform: oauthService.DeviceOs,
259 228
 			})*/
260 229
 			url := global.Config.Rustdesk.ApiServer + "/_admin/#/"
261 230
 			c.Redirect(http.StatusFound, url)

+ 13 - 9
http/request/admin/oauth.go

@@ -1,6 +1,8 @@
1 1
 package admin
2 2
 
3
-import "Gwen/model"
3
+import (
4
+	"Gwen/model"
5
+)
4 6
 
5 7
 type BindOauthForm struct {
6 8
 	Op string `json:"op" binding:"required"`
@@ -13,19 +15,21 @@ type UnBindOauthForm struct {
13 15
 	Op string `json:"op" binding:"required"`
14 16
 }
15 17
 type OauthForm struct {
16
-	Id           uint   `json:"id"`
17
-	Op           string `json:"op" validate:"required"`
18
-	Issuer	     string `json:"issuer" validate:"omitempty,url"`
19
-	Scopes	   	 string `json:"scopes" validate:"omitempty"`
20
-	ClientId     string `json:"client_id" validate:"required"`
21
-	ClientSecret string `json:"client_secret" validate:"required"`
22
-	RedirectUrl  string `json:"redirect_url" validate:"required"`
23
-	AutoRegister *bool  `json:"auto_register"`
18
+	Id           uint   			`json:"id"`
19
+	Op           string 			`json:"op" validate:"omitempty"`
20
+	OauthType    string 			`json:"oauth_type" validate:"required"`
21
+	Issuer	     string 			`json:"issuer" validate:"omitempty,url"`
22
+	Scopes	   	 string 			`json:"scopes" validate:"omitempty"`
23
+	ClientId     string 			`json:"client_id" validate:"required"`
24
+	ClientSecret string 			`json:"client_secret" validate:"required"`
25
+	RedirectUrl  string 			`json:"redirect_url" validate:"required"`
26
+	AutoRegister *bool  			`json:"auto_register"`
24 27
 }
25 28
 
26 29
 func (of *OauthForm) ToOauth() *model.Oauth {
27 30
 	oa := &model.Oauth{
28 31
 		Op:           of.Op,
32
+		OauthType:	  of.OauthType,
29 33
 		ClientId:     of.ClientId,
30 34
 		ClientSecret: of.ClientSecret,
31 35
 		RedirectUrl:  of.RedirectUrl,

+ 11 - 7
http/request/admin/user.go

@@ -5,20 +5,22 @@ import (
5 5
 )
6 6
 
7 7
 type UserForm struct {
8
-	Id       uint   `json:"id"`
9
-	Username string `json:"username" validate:"required,gte=4,lte=10"`
8
+	Id       uint   			`json:"id"`
9
+	Username string 			`json:"username" validate:"required,gte=4,lte=10"`
10
+	Email	 string           	`json:"email" validate:"required,email"`
10 11
 	//Password string           `json:"password" validate:"required,gte=4,lte=20"`
11
-	Nickname string           `json:"nickname"`
12
-	Avatar   string           `json:"avatar"`
13
-	GroupId  uint             `json:"group_id" validate:"required"`
14
-	IsAdmin  *bool            `json:"is_admin" `
15
-	Status   model.StatusCode `json:"status" validate:"required,gte=0"`
12
+	Nickname string           	`json:"nickname"`
13
+	Avatar   string           	`json:"avatar"`
14
+	GroupId  uint             	`json:"group_id" validate:"required"`
15
+	IsAdmin  *bool            	`json:"is_admin" `
16
+	Status   model.StatusCode 	`json:"status" validate:"required,gte=0"`
16 17
 }
17 18
 
18 19
 func (uf *UserForm) FromUser(user *model.User) *UserForm {
19 20
 	uf.Id = user.Id
20 21
 	uf.Username = user.Username
21 22
 	uf.Nickname = user.Nickname
23
+	uf.Email = user.Email
22 24
 	uf.Avatar = user.Avatar
23 25
 	uf.GroupId = user.GroupId
24 26
 	uf.IsAdmin = user.IsAdmin
@@ -30,6 +32,7 @@ func (uf *UserForm) ToUser() *model.User {
30 32
 	user.Id = uf.Id
31 33
 	user.Username = uf.Username
32 34
 	user.Nickname = uf.Nickname
35
+	user.Email = uf.Email
33 36
 	user.Avatar = uf.Avatar
34 37
 	user.GroupId = uf.GroupId
35 38
 	user.IsAdmin = uf.IsAdmin
@@ -62,6 +65,7 @@ type GroupUsersQuery struct {
62 65
 
63 66
 type RegisterForm struct {
64 67
 	Username        string `json:"username" validate:"required,gte=4,lte=10"`
68
+	Email           string `json:"email" validate:"required,email"`
65 69
 	Password        string `json:"password" validate:"required,gte=4,lte=20"`
66 70
 	ConfirmPassword string `json:"confirm_password" validate:"required,gte=4,lte=20"`
67 71
 }

+ 12 - 3
http/response/admin/user.go

@@ -4,19 +4,28 @@ import "Gwen/model"
4 4
 
5 5
 type LoginPayload struct {
6 6
 	Username   string   `json:"username"`
7
+	Email	   string   `json:"email"`
8
+	Avatar	   string   `json:"avatar"`
7 9
 	Token      string   `json:"token"`
8 10
 	RouteNames []string `json:"route_names"`
9 11
 	Nickname   string   `json:"nickname"`
10 12
 }
11 13
 
14
+func (lp *LoginPayload) FromUser(user *model.User) {
15
+	lp.Username = user.Username
16
+	lp.Email = user.Email
17
+	lp.Avatar = user.Avatar
18
+	lp.Nickname = user.Nickname
19
+}
20
+
12 21
 var UserRouteNames = []string{
13
-	"MyTagList", "MyAddressBookList", "MyInfo", "MyAddressBookCollection",
22
+	"MyTagList", "MyAddressBookList", "MyInfo", "MyAddressBookCollection", "MyPeer",
14 23
 }
15 24
 var AdminRouteNames = []string{"*"}
16 25
 
17 26
 type UserOauthItem struct {
18
-	ThirdType string `json:"third_type"`
19
-	Status    int    `json:"status"`
27
+	Op 			string `json:"op"`
28
+	Status    	int    `json:"status"`
20 29
 }
21 30
 
22 31
 type GroupUsersPayload struct {

+ 1 - 0
http/response/api/user.go

@@ -29,6 +29,7 @@ type UserPayload struct {
29 29
 
30 30
 func (up *UserPayload) FromUser(user *model.User) *UserPayload {
31 31
 	up.Name = user.Username
32
+	up.Email = user.Email
32 33
 	up.IsAdmin = user.IsAdmin
33 34
 	up.Status = int(user.Status)
34 35
 	up.Info = map[string]interface{}{}

+ 1 - 0
http/router/admin.go

@@ -53,6 +53,7 @@ func UserBind(rg *gin.RouterGroup) {
53 53
 		aR.GET("/current", cont.Current)
54 54
 		aR.POST("/changeCurPwd", cont.ChangeCurPwd)
55 55
 		aR.POST("/myOauth", cont.MyOauth)
56
+		aR.GET("/myPeer", cont.MyPeer)
56 57
 		aR.POST("/groupUsers", cont.GroupUsers)
57 58
 	}
58 59
 	aRP := rg.Group("/user").Use(middleware.AdminPrivilege())

+ 1 - 0
model/loginLog.go

@@ -4,6 +4,7 @@ type LoginLog struct {
4 4
 	IdModel
5 5
 	UserId      uint   `json:"user_id" gorm:"default:0;not null;"`
6 6
 	Client      string `json:"client"` //webadmin,webclient,app,
7
+	DeviceId	string `json:"device_id"`
7 8
 	Uuid        string `json:"uuid"`
8 9
 	Ip          string `json:"ip"`
9 10
 	Type        string `json:"type"`     //account,oauth

+ 152 - 13
model/oauth.go

@@ -1,23 +1,162 @@
1 1
 package model
2 2
 
3
+import (
4
+	"strconv"
5
+	"strings"
6
+	"errors"
7
+)
8
+
9
+const OIDC_DEFAULT_SCOPES = "openid,profile,email"
10
+
11
+const (
12
+	// make sure the value shouldbe lowercase
13
+	OauthTypeGithub  string = "github"
14
+	OauthTypeGoogle  string = "google"
15
+	OauthTypeOidc    string = "oidc"
16
+	OauthTypeWebauth string = "webauth"
17
+)
18
+
19
+// Validate the oauth type
20
+func ValidateOauthType(oauthType string) error {
21
+	switch oauthType {
22
+	case OauthTypeGithub, OauthTypeGoogle, OauthTypeOidc, OauthTypeWebauth:
23
+		return nil
24
+	default:
25
+		return errors.New("invalid Oauth type")
26
+	}
27
+}
28
+
29
+const (
30
+	OauthNameGithub  string = "GitHub"
31
+	OauthNameGoogle  string = "Google"
32
+	OauthNameOidc    string = "OIDC"
33
+	OauthNameWebauth string = "WebAuth"
34
+)
35
+
36
+const (
37
+	UserEndpointGithub  string = "https://api.github.com/user"
38
+	IssuerGoogle 		string = "https://accounts.google.com"
39
+)
40
+
3 41
 type Oauth struct {
4 42
 	IdModel
5
-	Op           string `json:"op"`
6
-	ClientId     string `json:"client_id"`
7
-	ClientSecret string `json:"client_secret"`
8
-	RedirectUrl  string `json:"redirect_url"`
9
-	AutoRegister *bool  `json:"auto_register"`
10
-	Scopes       string `json:"scopes"`
11
-	Issuer	     string `json:"issuer"`
43
+	Op           string 	`json:"op"`
44
+	OauthType    string 	`json:"oauth_type"`
45
+	ClientId     string 	`json:"client_id"`
46
+	ClientSecret string 	`json:"client_secret"`
47
+	RedirectUrl  string 	`json:"redirect_url"`
48
+	AutoRegister *bool  	`json:"auto_register"`
49
+	Scopes       string 	`json:"scopes"`
50
+	Issuer	     string 	`json:"issuer"`
12 51
 	TimeModel
13 52
 }
14 53
 
15
-const (
16
-	OauthTypeGithub  = "github"
17
-	OauthTypeGoogle  = "google"
18
-	OauthTypeOidc    = "oidc"
19
-	OauthTypeWebauth = "webauth"
20
-)
54
+
55
+
56
+// Helper function to format oauth info, it's used in the update and create method
57
+func (oa *Oauth) FormatOauthInfo() error {
58
+	oauthType := strings.TrimSpace(oa.OauthType)
59
+	err := ValidateOauthType(oa.OauthType)
60
+	if err != nil {
61
+		return err
62
+	}
63
+	// check if the op is empty, set the default value
64
+	op := strings.TrimSpace(oa.Op)
65
+	if op == "" {
66
+		switch oauthType {
67
+		case OauthTypeGithub:
68
+			oa.Op = OauthNameGithub
69
+		case OauthTypeGoogle:
70
+			oa.Op = OauthNameGoogle
71
+		case OauthTypeOidc:
72
+			oa.Op = OauthNameOidc
73
+		case OauthTypeWebauth:
74
+			oa.Op = OauthNameWebauth
75
+		default:
76
+			oa.Op = oauthType
77
+		}
78
+	}
79
+	// check the issuer, if the oauth type is google and the issuer is empty, set the issuer to the default value
80
+	issuer := strings.TrimSpace(oa.Issuer)
81
+	// If the oauth type is google and the issuer is empty, set the issuer to the default value 
82
+	if oauthType == OauthTypeGoogle && issuer == "" {
83
+		oa.Issuer = IssuerGoogle
84
+	}
85
+	return nil
86
+}
87
+
88
+type OauthUser struct {
89
+	OpenId 			string 	`json:"open_id" gorm:"not null;index"`
90
+	Name   			string 	`json:"name"`
91
+	Username 		string 	`json:"username"`
92
+	Email  			string 	`json:"email"`
93
+	VerifiedEmail 	bool 	`json:"verified_email,omitempty"`
94
+	Picture			string 	`json:"picture,omitempty"`
95
+}
96
+
97
+func (ou *OauthUser) ToUser(user *User, overideUsername bool) {
98
+	if overideUsername {
99
+		user.Username = ou.Username
100
+	}
101
+	user.Email = ou.Email
102
+	user.Nickname = ou.Name
103
+	user.Avatar = ou.Picture
104
+}
105
+
106
+type OauthUserBase struct {
107
+	Name  string `json:"name"`
108
+	Email string `json:"email"`
109
+}
110
+
111
+type OidcUser struct {
112
+	OauthUserBase
113
+	Sub               string `json:"sub"`
114
+	VerifiedEmail     bool   `json:"email_verified"`
115
+	PreferredUsername string `json:"preferred_username"`
116
+	Picture           string `json:"picture"`
117
+}
118
+
119
+func (ou *OidcUser) ToOauthUser() *OauthUser {
120
+	var username string
121
+	// 使用 PreferredUsername,如果不存在,降级到 Email 前缀
122
+	if ou.PreferredUsername != "" {
123
+		username = ou.PreferredUsername
124
+	} else {
125
+		username = strings.ToLower(strings.Split(ou.Email, "@")[0])
126
+	}
127
+
128
+	return &OauthUser{
129
+		OpenId:        ou.Sub,
130
+		Name:          ou.Name,
131
+		Username:      username,
132
+		Email:         ou.Email,
133
+		VerifiedEmail: ou.VerifiedEmail,
134
+		Picture:       ou.Picture,
135
+	}
136
+}
137
+
138
+
139
+type GithubUser struct {
140
+	OauthUserBase
141
+	Id                int         `json:"id"`
142
+	Login             string      `json:"login"`
143
+	AvatarUrl         string      `json:"avatar_url"`
144
+	VerifiedEmail	  bool        `json:"verified_email"`
145
+}
146
+
147
+func (gu *GithubUser) ToOauthUser() *OauthUser {
148
+	username := strings.ToLower(gu.Login)
149
+	return &OauthUser{
150
+		OpenId: 		strconv.Itoa(gu.Id),
151
+		Name:   		gu.Name,
152
+		Username: 		username,
153
+		Email:  		gu.Email,
154
+		VerifiedEmail: 	gu.VerifiedEmail,
155
+		Picture:		gu.AvatarUrl,
156
+	}
157
+}
158
+
159
+
21 160
 
22 161
 type OauthList struct {
23 162
 	Oauths []*Oauth `json:"list"`

+ 16 - 0
model/user.go

@@ -1,8 +1,15 @@
1 1
 package model
2 2
 
3
+import (
4
+	"fmt"
5
+	"gorm.io/gorm"
6
+)
7
+
3 8
 type User struct {
4 9
 	IdModel
5 10
 	Username string     `json:"username" gorm:"default:'';not null;uniqueIndex"`
11
+	Email	string     	`json:"email" gorm:"default:'';not null;uniqueIndex"`
12
+	// Email	string     	`json:"email" `
6 13
 	Password string     `json:"-" gorm:"default:'';not null;"`
7 14
 	Nickname string     `json:"nickname" gorm:"default:'';not null;"`
8 15
 	Avatar   string     `json:"avatar" gorm:"default:'';not null;"`
@@ -12,6 +19,15 @@ type User struct {
12 19
 	TimeModel
13 20
 }
14 21
 
22
+// BeforeSave 钩子用于确保 email 字段有合理的默认值
23
+func (u *User) BeforeSave(tx *gorm.DB) (err error) {
24
+    // 如果 email 为空,设置为默认值
25
+    if u.Email == "" {
26
+        u.Email = fmt.Sprintf("%s@example.com", u.Username)
27
+    }
28
+    return nil
29
+}
30
+
15 31
 type UserList struct {
16 32
 	Users []*User `json:"list,omitempty"`
17 33
 	Pagination

+ 19 - 6
model/userThird.go

@@ -1,12 +1,25 @@
1 1
 package model
2 2
 
3
+import (
4
+	"strings"
5
+)
6
+
3 7
 type UserThird struct {
4 8
 	IdModel
5
-	UserId     uint   `json:"user_id" gorm:"not null;index"`
6
-	OpenId     string `json:"open_id" gorm:"not null;index"`
7
-	UnionId    string `json:"union_id" gorm:"not null;"`
8
-	ThirdType  string `json:"third_type" gorm:"not null;"`
9
-	ThirdEmail string `json:"third_email"`
10
-	ThirdName  string `json:"third_name"`
9
+	UserId     		uint   `	json:"user_id" gorm:"not null;index"`
10
+	OauthUser
11
+	// UnionId    		string `json:"union_id" gorm:"not null;"`
12
+	// OauthType  	   	string 		`json:"oauth_type" gorm:"not null;"`
13
+	OauthType  	   	string 		`json:"oauth_type"`
14
+	Op  			string 		`json:"op" gorm:"not null;"`
11 15
 	TimeModel
12 16
 }
17
+
18
+func (u *UserThird) FromOauthUser(userId uint, oauthUser *OauthUser, oauthType string, op string) {
19
+	u.UserId 			= userId
20
+	u.OauthUser 		= *oauthUser
21
+	u.OauthType 		= oauthType
22
+	u.Op 				= op
23
+	// make sure email is lower case
24
+	u.Email 			= strings.ToLower(u.Email)
25
+}

+ 5 - 3
model/userToken.go

@@ -2,9 +2,11 @@ package model
2 2
 
3 3
 type UserToken struct {
4 4
 	IdModel
5
-	UserId    uint   `json:"user_id" gorm:"default:0;not null;index"`
6
-	Token     string `json:"token" gorm:"default:'';not null;index"`
7
-	ExpiredAt int64  `json:"expired_at" gorm:"default:0;not null;"`
5
+	UserId    	uint   `json:"user_id" gorm:"default:0;not null;index"`
6
+	DeviceUuid 	string `json:"device_uuid" gorm:"default:'';omitempty;"`
7
+	DeviceId	string `json:"device_id" gorm:"default:'';omitempty;"`
8
+	Token     	string `json:"token" gorm:"default:'';not null;index"`
9
+	ExpiredAt 	int64  `json:"expired_at" gorm:"default:0;not null;"`
8 10
 	TimeModel
9 11
 }
10 12
 

+ 252 - 277
service/oauth.go

@@ -9,17 +9,22 @@ import (
9 9
 	"errors"
10 10
 	"golang.org/x/oauth2"
11 11
 	"golang.org/x/oauth2/github"
12
-	"golang.org/x/oauth2/google"
12
+	// "golang.org/x/oauth2/google"
13 13
 	"gorm.io/gorm"
14
-	"io"
14
+	// "io"
15 15
 	"net/http"
16 16
 	"net/url"
17 17
 	"strconv"
18 18
 	"strings"
19 19
 	"sync"
20 20
 	"time"
21
+	"fmt"
21 22
 )
22 23
 
24
+
25
+type OauthService struct {
26
+}
27
+
23 28
 // Define a struct to parse the .well-known/openid-configuration response
24 29
 type OidcEndpoint struct {
25 30
 	Issuer   string `json:"issuer"`
@@ -28,73 +33,6 @@ type OidcEndpoint struct {
28 33
 	UserInfo string `json:"userinfo_endpoint"`
29 34
 }
30 35
 
31
-type OauthService struct {
32
-}
33
-
34
-type GithubUserdata struct {
35
-	AvatarUrl         string      `json:"avatar_url"`
36
-	Bio               string      `json:"bio"`
37
-	Blog              string      `json:"blog"`
38
-	Collaborators     int         `json:"collaborators"`
39
-	Company           interface{} `json:"company"`
40
-	CreatedAt         time.Time   `json:"created_at"`
41
-	DiskUsage         int         `json:"disk_usage"`
42
-	Email             interface{} `json:"email"`
43
-	EventsUrl         string      `json:"events_url"`
44
-	Followers         int         `json:"followers"`
45
-	FollowersUrl      string      `json:"followers_url"`
46
-	Following         int         `json:"following"`
47
-	FollowingUrl      string      `json:"following_url"`
48
-	GistsUrl          string      `json:"gists_url"`
49
-	GravatarId        string      `json:"gravatar_id"`
50
-	Hireable          interface{} `json:"hireable"`
51
-	HtmlUrl           string      `json:"html_url"`
52
-	Id                int         `json:"id"`
53
-	Location          interface{} `json:"location"`
54
-	Login             string      `json:"login"`
55
-	Name              string      `json:"name"`
56
-	NodeId            string      `json:"node_id"`
57
-	NotificationEmail interface{} `json:"notification_email"`
58
-	OrganizationsUrl  string      `json:"organizations_url"`
59
-	OwnedPrivateRepos int         `json:"owned_private_repos"`
60
-	Plan              struct {
61
-		Collaborators int    `json:"collaborators"`
62
-		Name          string `json:"name"`
63
-		PrivateRepos  int    `json:"private_repos"`
64
-		Space         int    `json:"space"`
65
-	} `json:"plan"`
66
-	PrivateGists      int    `json:"private_gists"`
67
-	PublicGists       int    `json:"public_gists"`
68
-	PublicRepos       int    `json:"public_repos"`
69
-	ReceivedEventsUrl string `json:"received_events_url"`
70
-	ReposUrl          string `json:"repos_url"`
71
-	SiteAdmin         bool   `json:"site_admin"`
72
-	StarredUrl        string `json:"starred_url"`
73
-	SubscriptionsUrl  string `json:"subscriptions_url"`
74
-	TotalPrivateRepos int    `json:"total_private_repos"`
75
-	//TwitterUsername         interface{} `json:"twitter_username"`
76
-	TwoFactorAuthentication bool      `json:"two_factor_authentication"`
77
-	Type                    string    `json:"type"`
78
-	UpdatedAt               time.Time `json:"updated_at"`
79
-	Url                     string    `json:"url"`
80
-}
81
-type GoogleUserdata struct {
82
-	Email         string `json:"email"`
83
-	FamilyName    string `json:"family_name"`
84
-	GivenName     string `json:"given_name"`
85
-	Id            string `json:"id"`
86
-	Name          string `json:"name"`
87
-	Picture       string `json:"picture"`
88
-	VerifiedEmail bool   `json:"verified_email"`
89
-}
90
-type OidcUserdata struct {
91
-	Sub               string `json:"sub"`
92
-	Email             string `json:"email"`
93
-	VerifiedEmail     bool   `json:"email_verified"`
94
-	Name              string `json:"name"`
95
-	PreferredUsername string `json:"preferred_username"`
96
-}
97
-
98 36
 type OauthCacheItem struct {
99 37
 	UserId      uint   `json:"user_id"`
100 38
 	Id          string `json:"id"` //rustdesk的设备ID
@@ -104,9 +42,19 @@ type OauthCacheItem struct {
104 42
 	DeviceName  string `json:"device_name"`
105 43
 	DeviceOs    string `json:"device_os"`
106 44
 	DeviceType  string `json:"device_type"`
107
-	ThirdOpenId string `json:"third_open_id"`
108
-	ThirdName   string `json:"third_name"`
109
-	ThirdEmail  string `json:"third_email"`
45
+	OpenId 		string `json:"open_id"`
46
+	Username	string `json:"username"`
47
+	Name   		string `json:"name"`
48
+	Email  		string `json:"email"`
49
+}
50
+
51
+func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
52
+	return &model.OauthUser{
53
+		OpenId: oci.OpenId,
54
+		Username: oci.Username,
55
+		Name: oci.Name,
56
+		Email: oci.Email,
57
+	}
110 58
 }
111 59
 
112 60
 var OauthCache = &sync.Map{}
@@ -116,6 +64,14 @@ const (
116 64
 	OauthActionTypeBind  = "bind"
117 65
 )
118 66
 
67
+func (oa *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) {
68
+	oa.OpenId = oauthUser.OpenId
69
+	oa.Username = oauthUser.Username
70
+	oa.Name = oauthUser.Name
71
+	oa.Email = oauthUser.Email
72
+}
73
+
74
+
119 75
 func (os *OauthService) GetOauthCache(key string) *OauthCacheItem {
120 76
 	v, ok := OauthCache.Load(key)
121 77
 	if !ok {
@@ -140,22 +96,21 @@ func (os *OauthService) DeleteOauthCache(key string) {
140 96
 
141 97
 func (os *OauthService) BeginAuth(op string) (error error, code, url string) {
142 98
 	code = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
143
-
144
-	if op == model.OauthTypeWebauth {
99
+	if op == string(model.OauthTypeWebauth) {
145 100
 		url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + code
146 101
 		//url = "http://localhost:8888/_admin/#/oauth/" + code
147 102
 		return nil, code, url
148 103
 	}
149
-	err, conf := os.GetOauthConfig(op)
104
+	err, _, oauthConfig := os.GetOauthConfig(op)
150 105
 	if err == nil {
151
-		return err, code, conf.AuthCodeURL(code)
106
+		return err, code, oauthConfig.AuthCodeURL(code)
152 107
 	}
153 108
 
154 109
 	return err, code, ""
155 110
 }
156 111
 
157 112
 // Method to fetch OIDC configuration dynamically
158
-func FetchOidcConfig(issuer string) (error, OidcEndpoint) {
113
+func (os *OauthService) FetchOidcEndpoint(issuer string) (error, OidcEndpoint) {
159 114
 	configURL := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
160 115
 
161 116
 	// Get the HTTP client (with or without proxy based on configuration)
@@ -179,76 +134,58 @@ func FetchOidcConfig(issuer string) (error, OidcEndpoint) {
179 134
 	return nil, endpoint
180 135
 }
181 136
 
182
-// GetOauthConfig retrieves the OAuth2 configuration based on the provider type
183
-func (os *OauthService) GetOauthConfig(op string) (error, *oauth2.Config) {
184
-	switch op {
185
-	case model.OauthTypeGithub:
186
-		return os.getGithubConfig()
187
-	case model.OauthTypeGoogle:
188
-		return os.getGoogleConfig()
189
-	case model.OauthTypeOidc:
190
-		return os.getOidcConfig()
191
-	default:
192
-		return errors.New("unsupported OAuth type"), nil
137
+func (os *OauthService) FetchOidcEndpointByOp(op string) (error, OidcEndpoint) {
138
+	oauthInfo := os.InfoByOp(op)
139
+	if oauthInfo.Issuer == "" {
140
+		return errors.New("issuer is empty"), OidcEndpoint{}
193 141
 	}
142
+	return os.FetchOidcEndpoint(oauthInfo.Issuer)
194 143
 }
195 144
 
196
-// Helper function to get GitHub OAuth2 configuration
197
-func (os *OauthService) getGithubConfig() (error, *oauth2.Config) {
198
-	g := os.InfoByOp(model.OauthTypeGithub)
199
-	if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" || g.RedirectUrl == "" {
200
-		return errors.New("ConfigNotFound"), nil
201
-	}
202
-	return nil, &oauth2.Config{
203
-		ClientID:     g.ClientId,
204
-		ClientSecret: g.ClientSecret,
205
-		RedirectURL:  g.RedirectUrl,
206
-		Endpoint:     github.Endpoint,
207
-		Scopes:       []string{"read:user", "user:email"},
145
+// GetOauthConfig retrieves the OAuth2 configuration based on the provider name
146
+func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) {
147
+	err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op)
148
+	if err != nil {
149
+		return err, nil, nil
208 150
 	}
209
-}
210
-
211
-// Helper function to get Google OAuth2 configuration
212
-func (os *OauthService) getGoogleConfig() (error, *oauth2.Config) {
213
-	g := os.InfoByOp(model.OauthTypeGoogle)
214
-	if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" || g.RedirectUrl == "" {
215
-		return errors.New("ConfigNotFound"), nil
151
+	// Maybe should validate the oauthConfig here
152
+	oauthType := oauthInfo.OauthType
153
+	err = model.ValidateOauthType(oauthType)
154
+	if err != nil {
155
+		return err, nil, nil
216 156
 	}
217
-	return nil, &oauth2.Config{
218
-		ClientID:     g.ClientId,
219
-		ClientSecret: g.ClientSecret,
220
-		RedirectURL:  g.RedirectUrl,
221
-		Endpoint:     google.Endpoint,
222
-		Scopes:       []string{"https://www.googleapis.com/auth/userinfo.profile", "https://www.googleapis.com/auth/userinfo.email"},
157
+	switch oauthType {
158
+	case model.OauthTypeGithub:
159
+		oauthConfig.Endpoint = github.Endpoint
160
+		oauthConfig.Scopes = []string{"read:user", "user:email"}
161
+	case model.OauthTypeOidc, model.OauthTypeGoogle:
162
+		var endpoint OidcEndpoint
163
+		err, endpoint = os.FetchOidcEndpoint(oauthInfo.Issuer)
164
+		if err != nil {
165
+			return err, nil, nil
166
+		}
167
+		oauthConfig.Endpoint = oauth2.Endpoint{AuthURL:  endpoint.AuthURL,TokenURL: endpoint.TokenURL,}
168
+		oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
169
+	default:
170
+		return errors.New("unsupported OAuth type"), nil, nil
223 171
 	}
172
+	return nil, oauthInfo, oauthConfig
224 173
 }
225 174
 
226
-// Helper function to get OIDC OAuth2 configuration
227
-func (os *OauthService) getOidcConfig() (error, *oauth2.Config) {
228
-	g := os.InfoByOp(model.OauthTypeOidc)
229
-	if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" || g.RedirectUrl == "" || g.Issuer == "" {
230
-		return errors.New("ConfigNotFound"), nil
175
+// GetOauthConfig retrieves the OAuth2 configuration based on the provider name
176
+func (os *OauthService) getOauthConfigGeneral(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) {
177
+	oauthInfo = os.InfoByOp(op)
178
+	if oauthInfo.Id == 0 || oauthInfo.ClientId == "" || oauthInfo.ClientSecret == "" {
179
+		return errors.New("ConfigNotFound"), nil, nil
231 180
 	}
232
-
233
-	// Set scopes
234
-	scopes := strings.TrimSpace(g.Scopes)
235
-	if scopes == "" {
236
-		scopes = "openid,profile,email"
181
+	// If the redirect URL is empty, use the default redirect URL
182
+	if oauthInfo.RedirectUrl == "" {
183
+		oauthInfo.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback"
237 184
 	}
238
-	scopeList := strings.Split(scopes, ",")
239
-	err, endpoint := FetchOidcConfig(g.Issuer)
240
-	if err != nil {
241
-		return err, nil
242
-	}
243
-	return nil, &oauth2.Config{
244
-		ClientID:     g.ClientId,
245
-		ClientSecret: g.ClientSecret,
246
-		RedirectURL:  g.RedirectUrl,
247
-		Endpoint: oauth2.Endpoint{
248
-			AuthURL:  endpoint.AuthURL,
249
-			TokenURL: endpoint.TokenURL,
250
-		},
251
-		Scopes: scopeList,
185
+	return nil, oauthInfo, &oauth2.Config{
186
+		ClientID:     oauthInfo.ClientId,
187
+		ClientSecret: oauthInfo.ClientSecret,
188
+		RedirectURL:  oauthInfo.RedirectUrl,
252 189
 	}
253 190
 }
254 191
 
@@ -272,194 +209,153 @@ func getHTTPClientWithProxy() *http.Client {
272 209
 	return http.DefaultClient
273 210
 }
274 211
 
275
-func (os *OauthService) GithubCallback(code string) (error error, userData *GithubUserdata) {
276
-	err, oauthConfig := os.GetOauthConfig(model.OauthTypeGithub)
277
-	if err != nil {
278
-		return err, nil
279
-	}
212
+func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, userEndpoint string, userData interface{}) (err error, client *http.Client) {
280 213
 
281
-	// 使用代理配置创建 HTTP 客户端
214
+	// 设置代理客户端
282 215
 	httpClient := getHTTPClientWithProxy()
283 216
 	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
284 217
 
285
-	token, err := oauthConfig.Exchange(ctx, code)
218
+	// 使用 code 换取 token
219
+	var token *oauth2.Token
220
+	token, err = oauthConfig.Exchange(ctx, code)
286 221
 	if err != nil {
287 222
 		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
288
-		error = errors.New("GetOauthTokenError")
289
-		return
223
+		return errors.New("GetOauthTokenError"), nil
290 224
 	}
291 225
 
292
-	// 使用带有代理的 HTTP 客户端获取用户信息
293
-	client := oauthConfig.Client(ctx, token)
294
-	resp, err := client.Get("https://api.github.com/user")
226
+	// 获取用户信息
227
+	client = oauthConfig.Client(ctx, token)
228
+	resp, err := client.Get(userEndpoint)
295 229
 	if err != nil {
296 230
 		global.Logger.Warn("failed getting user info: ", err)
297
-		error = errors.New("GetOauthUserInfoError")
298
-		return
231
+		return errors.New("GetOauthUserInfoError"), nil
299 232
 	}
300
-	defer func(Body io.ReadCloser) {
301
-		err := Body.Close()
302
-		if err != nil {
303
-			global.Logger.Warn("failed closing response body: ", err)
233
+	defer func() {
234
+		if closeErr := resp.Body.Close(); closeErr != nil {
235
+			global.Logger.Warn("failed closing response body: ", closeErr)
304 236
 		}
305
-	}(resp.Body)
237
+	}()
306 238
 
307 239
 	// 解析用户信息
308
-	if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil {
240
+	if err = json.NewDecoder(resp.Body).Decode(userData); err != nil {
309 241
 		global.Logger.Warn("failed decoding user info: ", err)
310
-		error = errors.New("DecodeOauthUserInfoError")
311
-		return
242
+		return errors.New("DecodeOauthUserInfoError"), nil
312 243
 	}
313
-	return
244
+
245
+	return nil, client
314 246
 }
315 247
 
316
-func (os *OauthService) GoogleCallback(code string) (error error, userData *GoogleUserdata) {
317
-	err, oauthConfig := os.GetOauthConfig(model.OauthTypeGoogle)
248
+// githubCallback github回调
249
+func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string) (error, *model.OauthUser) {
250
+	var user = &model.GithubUser{}
251
+	err, client := os.callbackBase(oauthConfig, code, model.UserEndpointGithub, user)
318 252
 	if err != nil {
319 253
 		return err, nil
320 254
 	}
321
-
322
-	// 使用代理配置创建 HTTP 客户端
323
-	httpClient := getHTTPClientWithProxy()
324
-	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
325
-
326
-	token, err := oauthConfig.Exchange(ctx, code)
255
+	err = os.getGithubPrimaryEmail(client, user)
327 256
 	if err != nil {
328
-		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
329
-		error = errors.New("GetOauthTokenError")
330
-		return
257
+		return err, nil
331 258
 	}
259
+	return nil, user.ToOauthUser()
260
+}
332 261
 
333
-	// 使用带有代理的 HTTP 客户端获取用户信息
334
-	client := oauthConfig.Client(ctx, token)
335
-	resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
336
-	if err != nil {
337
-		global.Logger.Warn("failed getting user info: ", err)
338
-		error = errors.New("GetOauthUserInfoError")
339
-		return
340
-	}
341
-	defer func(Body io.ReadCloser) {
342
-		err := Body.Close()
343
-		if err != nil {
344
-			global.Logger.Warn("failed closing response body: ", err)
345
-		}
346
-	}(resp.Body)
347 262
 
348
-	// 解析用户信息
349
-	if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil {
350
-		global.Logger.Warn("failed decoding user info: ", err)
351
-		error = errors.New("DecodeOauthUserInfoError")
352
-		return
263
+// oidcCallback oidc回调, 通过code获取用户信息
264
+func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser,) {
265
+	var user = &model.OidcUser{}
266
+	if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil {
267
+		return err, nil
353 268
 	}
354
-	return
269
+	return nil, user.ToOauthUser()
355 270
 }
356 271
 
357
-func (os *OauthService) OidcCallback(code string) (error error, userData *OidcUserdata) {
358
-	err, oauthConfig := os.GetOauthConfig(model.OauthTypeOidc)
272
+// Callback: Get user information by code and op(Oauth provider)
273
+func (os *OauthService) Callback(code string, op string) (err error, oauthUser *model.OauthUser) {
274
+	var oauthInfo *model.Oauth
275
+	var oauthConfig *oauth2.Config
276
+	err, oauthInfo, oauthConfig = os.GetOauthConfig(op)
277
+	// oauthType is already validated in GetOauthConfig
359 278
 	if err != nil {
360 279
 		return err, nil
361 280
 	}
362
-	// 使用代理配置创建 HTTP 客户端
363
-	httpClient := getHTTPClientWithProxy()
364
-	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
365
-
366
-	token, err := oauthConfig.Exchange(ctx, code)
367
-	if err != nil {
368
-		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
369
-		error = errors.New("GetOauthTokenError")
370
-		return
371
-	}
372
-
373
-	// 使用带有代理的 HTTP 客户端获取用户信息
374
-	client := oauthConfig.Client(ctx, token)
375
-	g := os.InfoByOp(model.OauthTypeOidc)
376
-	err, endpoint := FetchOidcConfig(g.Issuer)
377
-	if err != nil {
378
-		global.Logger.Warn("failed fetching OIDC configuration: ", err)
379
-		error = errors.New("FetchOidcConfigError")
380
-		return
381
-	}
382
-	resp, err := client.Get(endpoint.UserInfo)
383
-	if err != nil {
384
-		global.Logger.Warn("failed getting user info: ", err)
385
-		error = errors.New("GetOauthUserInfoError")
386
-		return
387
-	}
388
-	defer func(Body io.ReadCloser) {
389
-		err := Body.Close()
281
+	oauthType := oauthInfo.OauthType
282
+	switch oauthType {
283
+    case model.OauthTypeGithub:
284
+        err, oauthUser = os.githubCallback(oauthConfig, code)
285
+    case model.OauthTypeOidc, model.OauthTypeGoogle:
286
+		err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
390 287
 		if err != nil {
391
-			global.Logger.Warn("failed closing response body: ", err)
288
+			return err, nil
392 289
 		}
393
-	}(resp.Body)
394
-
395
-	// 解析用户信息
396
-	if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil {
397
-		global.Logger.Warn("failed decoding user info: ", err)
398
-		error = errors.New("DecodeOauthUserInfoError")
399
-		return
400
-	}
401
-	return
290
+        err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo)
291
+    default:
292
+        return errors.New("unsupported OAuth type"), nil
293
+    }
294
+    return err, oauthUser
402 295
 }
403 296
 
404
-func (os *OauthService) UserThirdInfo(op, openid string) *model.UserThird {
297
+
298
+func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird {
405 299
 	ut := &model.UserThird{}
406
-	global.DB.Where("open_id = ? and third_type = ?", openid, op).First(ut)
300
+	global.DB.Where("open_id = ? and op = ?", openId, op).First(ut)
407 301
 	return ut
408 302
 }
409 303
 
410
-func (os *OauthService) BindGithubUser(openid, username string, userId uint) error {
411
-	return os.BindOauthUser(model.OauthTypeGithub, openid, username, userId)
412
-}
413
-
414
-func (os *OauthService) BindGoogleUser(email, username string, userId uint) error {
415
-	return os.BindOauthUser(model.OauthTypeGoogle, email, username, userId)
416
-}
417
-
418
-func (os *OauthService) BindOidcUser(sub, username string, userId uint) error {
419
-	return os.BindOauthUser(model.OauthTypeOidc, sub, username, userId)
420
-}
421
-
422
-func (os *OauthService) BindOauthUser(thirdType, openid, username string, userId uint) error {
423
-	utr := &model.UserThird{
424
-		OpenId:    openid,
425
-		ThirdType: thirdType,
426
-		ThirdName: username,
427
-		UserId:    userId,
304
+// BindOauthUser: Bind third party account
305
+func (os *OauthService) BindOauthUser(userId uint, oauthUser *model.OauthUser, op string) error {
306
+	utr := &model.UserThird{}
307
+	err, oauthType := os.GetTypeByOp(op)
308
+	if err != nil {
309
+		return err
428 310
 	}
311
+	utr.FromOauthUser(userId, oauthUser, oauthType, op)
429 312
 	return global.DB.Create(utr).Error
430 313
 }
431 314
 
432
-func (os *OauthService) UnBindGithubUser(userid uint) error {
433
-	return os.UnBindThird(model.OauthTypeGithub, userid)
315
+// UnBindOauthUser: Unbind third party account
316
+func (os *OauthService) UnBindOauthUser(userId uint, op string) error {
317
+	return os.UnBindThird(op, userId)
434 318
 }
435
-func (os *OauthService) UnBindGoogleUser(userid uint) error {
436
-	return os.UnBindThird(model.OauthTypeGoogle, userid)
437
-}
438
-func (os *OauthService) UnBindOidcUser(userid uint) error {
439
-	return os.UnBindThird(model.OauthTypeOidc, userid)
440
-}
441
-func (os *OauthService) UnBindThird(thirdType string, userid uint) error {
442
-	return global.DB.Where("user_id = ? and third_type = ?", userid, thirdType).Delete(&model.UserThird{}).Error
319
+
320
+// UnBindThird: Unbind third party account
321
+func (os *OauthService) UnBindThird(op string, userId uint) error {
322
+	return global.DB.Where("user_id = ? and op = ?", userId, op).Delete(&model.UserThird{}).Error
443 323
 }
444 324
 
445 325
 // DeleteUserByUserId: When user is deleted, delete all third party bindings
446
-func (os *OauthService) DeleteUserByUserId(userid uint) error {
447
-	return global.DB.Where("user_id = ?", userid).Delete(&model.UserThird{}).Error
326
+func (os *OauthService) DeleteUserByUserId(userId uint) error {
327
+	return global.DB.Where("user_id = ?", userId).Delete(&model.UserThird{}).Error
448 328
 }
449 329
 
450
-// InfoById 根据id取用户信息
330
+// InfoById 根据id获取Oauth信息
451 331
 func (os *OauthService) InfoById(id uint) *model.Oauth {
452
-	u := &model.Oauth{}
453
-	global.DB.Where("id = ?", id).First(u)
454
-	return u
332
+	oauthInfo := &model.Oauth{}
333
+	global.DB.Where("id = ?", id).First(oauthInfo)
334
+	return oauthInfo
455 335
 }
456 336
 
457
-// InfoByOp 根据op取用户信息
337
+// InfoByOp 根据op获取Oauth信息
458 338
 func (os *OauthService) InfoByOp(op string) *model.Oauth {
459
-	u := &model.Oauth{}
460
-	global.DB.Where("op = ?", op).First(u)
461
-	return u
339
+	oauthInfo := &model.Oauth{}
340
+	global.DB.Where("op = ?", op).First(oauthInfo)
341
+	return oauthInfo
462 342
 }
343
+
344
+// Helper function to get scopes by operation
345
+func (os *OauthService) getScopesByOp(op string) []string {
346
+    scopes := os.InfoByOp(op).Scopes
347
+	return os.constructScopes(scopes)
348
+}
349
+
350
+// Helper function to construct scopes
351
+func (os *OauthService) constructScopes(scopes string) []string {
352
+    scopes = strings.TrimSpace(scopes)
353
+    if scopes == "" {
354
+        scopes = model.OIDC_DEFAULT_SCOPES
355
+    }
356
+    return strings.Split(scopes, ",")
357
+}
358
+
463 359
 func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *model.OauthList) {
464 360
 	res = &model.OauthList{}
465 361
 	res.Page = int64(page)
@@ -474,16 +370,95 @@ func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res
474 370
 	return
475 371
 }
476 372
 
373
+// GetTypeByOp 根据op获取OauthType
374
+func (os *OauthService) GetTypeByOp(op string) (error, string) {
375
+	oauthInfo := &model.Oauth{}
376
+	if global.DB.Where("op = ?", op).First(oauthInfo).Error != nil {
377
+		return fmt.Errorf("OAuth provider with op '%s' not found", op), ""
378
+	}
379
+	return nil, oauthInfo.OauthType
380
+}
381
+
382
+// ValidateOauthProvider 验证Oauth提供者是否正确
383
+func (os *OauthService) ValidateOauthProvider(op string) error {
384
+	if !os.IsOauthProviderExist(op) {
385
+		return fmt.Errorf("OAuth provider with op '%s' not found", op)
386
+	}
387
+	return nil
388
+}
389
+
390
+// IsOauthProviderExist 验证Oauth提供者是否存在
391
+func (os *OauthService) IsOauthProviderExist(op string) bool {
392
+	oauthInfo := &model.Oauth{}
393
+	// 使用 Gorm 的 Take 方法查找符合条件的记录
394
+	if err := global.DB.Where("op = ?", op).Take(oauthInfo).Error; err != nil {
395
+		return false
396
+	}
397
+	return true
398
+}
399
+
477 400
 // Create 创建
478
-func (os *OauthService) Create(u *model.Oauth) error {
479
-	res := global.DB.Create(u).Error
401
+func (os *OauthService) Create(oauthInfo *model.Oauth) error {
402
+	err := oauthInfo.FormatOauthInfo()
403
+	if err != nil {
404
+		return err
405
+	}
406
+	res := global.DB.Create(oauthInfo).Error
480 407
 	return res
481 408
 }
482
-func (os *OauthService) Delete(u *model.Oauth) error {
483
-	return global.DB.Delete(u).Error
409
+func (os *OauthService) Delete(oauthInfo *model.Oauth) error {
410
+	return global.DB.Delete(oauthInfo).Error
484 411
 }
485 412
 
486 413
 // Update 更新
487
-func (os *OauthService) Update(u *model.Oauth) error {
488
-	return global.DB.Model(u).Updates(u).Error
414
+func (os *OauthService) Update(oauthInfo *model.Oauth) error {
415
+	err := oauthInfo.FormatOauthInfo()
416
+	if err != nil {
417
+		return err
418
+	}
419
+	return global.DB.Model(oauthInfo).Updates(oauthInfo).Error
489 420
 }
421
+
422
+// GetOauthProviders 获取所有的provider
423
+func (os *OauthService) GetOauthProviders() []string {
424
+	var res []string
425
+	global.DB.Model(&model.Oauth{}).Pluck("op", &res)
426
+	return res
427
+}
428
+
429
+// getGithubPrimaryEmail: Get the primary email of the user from Github
430
+func (os *OauthService) getGithubPrimaryEmail(client *http.Client, githubUser *model.GithubUser) error {
431
+	// the client is already set with the token
432
+	resp, err := client.Get("https://api.github.com/user/emails")
433
+	if err != nil {
434
+		return fmt.Errorf("failed to fetch emails: %w", err)
435
+	}
436
+	defer resp.Body.Close()
437
+
438
+	// check the response status code
439
+	if resp.StatusCode != http.StatusOK {
440
+		return fmt.Errorf("failed to fetch emails: %s", resp.Status)
441
+	}
442
+
443
+	// decode the response
444
+	var emails []struct {
445
+		Email    string `json:"email"`
446
+		Primary  bool   `json:"primary"`
447
+		Verified bool   `json:"verified"`
448
+	}
449
+
450
+	if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil {
451
+		return fmt.Errorf("failed to decode response: %w", err)
452
+	}
453
+
454
+	// find the primary verified email
455
+	for _, e := range emails {
456
+		if e.Primary && e.Verified {
457
+			githubUser.Email = e.Email
458
+			githubUser.VerifiedEmail = e.Verified
459
+			return nil
460
+		}
461
+	}
462
+
463
+	return fmt.Errorf("no primary verified email found")
464
+}

+ 68 - 4
service/peer.go

@@ -26,15 +26,43 @@ func (ps *PeerService) InfoByRowId(id uint) *model.Peer {
26 26
 	return p
27 27
 }
28 28
 
29
+// FindByUserIdAndUuid 根据用户id和uuid查找peer
30
+func (ps *PeerService) FindByUserIdAndUuid(uuid string,userId uint) *model.Peer {
31
+	p := &model.Peer{}
32
+	global.DB.Where("uuid = ? and user_id = ?", uuid, userId).First(p)
33
+	return p
34
+}
35
+
29 36
 // UuidBindUserId 绑定用户id
30
-func (ps *PeerService) UuidBindUserId(uuid string, userId uint) {
37
+func (ps *PeerService) UuidBindUserId(deviceId string, uuid string, userId uint) {
31 38
 	peer := ps.FindByUuid(uuid)
39
+	// 如果存在则更新
32 40
 	if peer.RowId > 0 {
33 41
 		peer.UserId = userId
34 42
 		ps.Update(peer)
43
+	} else {
44
+		// 不存在则创建
45
+		global.DB.Create(&model.Peer{
46
+			Id: 		deviceId,
47
+			Uuid:     	uuid,
48
+			UserId:   	userId,
49
+		})
50
+	}
51
+}
52
+
53
+// UuidUnbindUserId 解绑用户id, 用于用户注销
54
+func (ps *PeerService) UuidUnbindUserId(uuid string, userId uint) {
55
+	peer := ps.FindByUserIdAndUuid(uuid, userId)
56
+	if peer.RowId > 0 {
57
+		global.DB.Model(peer).Update("user_id", 0)
35 58
 	}
36 59
 }
37 60
 
61
+// EraseUserId 清除用户id, 用于用户删除
62
+func (ps *PeerService) EraseUserId(userId uint) error {
63
+	return global.DB.Model(&model.Peer{}).Where("user_id = ?", userId).Update("user_id", 0).Error
64
+}
65
+
38 66
 // ListByUserIds 根据用户id取列表
39 67
 func (ps *PeerService) ListByUserIds(userIds []uint, page, pageSize uint) (res *model.PeerList) {
40 68
 	res = &model.PeerList{}
@@ -62,21 +90,57 @@ func (ps *PeerService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *
62 90
 	return
63 91
 }
64 92
 
93
+// ListFilterByUserId 根据用户id过滤Peer列表
94
+func (ps *PeerService) ListFilterByUserId(page, pageSize uint, where func(tx *gorm.DB), userId uint) (res *model.PeerList) {
95
+	userWhere := func(tx *gorm.DB) {
96
+		tx.Where("user_id = ?", userId)
97
+		// 如果还有额外的筛选条件,执行它
98
+		if where != nil {
99
+			where(tx)
100
+		}
101
+	}
102
+	return ps.List(page, pageSize, userWhere)
103
+}
104
+
65 105
 // Create 创建
66 106
 func (ps *PeerService) Create(u *model.Peer) error {
67 107
 	res := global.DB.Create(u).Error
68 108
 	return res
69 109
 }
110
+
111
+// Delete 删除, 同时也应该删除token
70 112
 func (ps *PeerService) Delete(u *model.Peer) error {
71
-	return global.DB.Delete(u).Error
113
+	uuid := u.Uuid
114
+	err := global.DB.Delete(u).Error
115
+	if err != nil {
116
+		return err
117
+	}
118
+	// 删除token
119
+	return AllService.UserService.FlushTokenByUuid(uuid)
72 120
 }
73 121
 
74
-// BatchDelete
122
+// GetUuidListByIDs 根据ids获取uuid列表
123
+func (ps *PeerService) GetUuidListByIDs(ids []uint) ([]string, error) {
124
+	var uuids []string
125
+	err := global.DB.Model(&model.Peer{}).
126
+		Where("row_id in (?)", ids).
127
+		Pluck("uuid", &uuids).Error
128
+	return uuids, err
129
+}
130
+
131
+// BatchDelete 批量删除, 同时也应该删除token
75 132
 func (ps *PeerService) BatchDelete(ids []uint) error {
76
-	return global.DB.Where("row_id in (?)", ids).Delete(&model.Peer{}).Error
133
+	uuids, err := ps.GetUuidListByIDs(ids)
134
+	err = global.DB.Where("row_id in (?)", ids).Delete(&model.Peer{}).Error
135
+	if err != nil {
136
+		return err
137
+	}
138
+	// 删除token
139
+	return AllService.UserService.FlushTokenByUuids(uuids)
77 140
 }
78 141
 
79 142
 // Update 更新
80 143
 func (ps *PeerService) Update(u *model.Peer) error {
81 144
 	return global.DB.Model(u).Updates(u).Error
82 145
 }
146
+

+ 131 - 66
service/user.go

@@ -10,6 +10,8 @@ import (
10 10
 	"math/rand"
11 11
 	"strconv"
12 12
 	"time"
13
+	"strings"
14
+	"errors"
13 15
 )
14 16
 
15 17
 type UserService struct {
@@ -21,12 +23,20 @@ func (us *UserService) InfoById(id uint) *model.User {
21 23
 	global.DB.Where("id = ?", id).First(u)
22 24
 	return u
23 25
 }
26
+// InfoByUsername 根据用户名取用户信息
24 27
 func (us *UserService) InfoByUsername(un string) *model.User {
25 28
 	u := &model.User{}
26 29
 	global.DB.Where("username = ?", un).First(u)
27 30
 	return u
28 31
 }
29 32
 
33
+// InfoByEmail 根据邮箱取用户信息
34
+func (us *UserService) InfoByEmail(email string) *model.User {
35
+	u := &model.User{}
36
+	global.DB.Where("email = ?", email).First(u)
37
+	return u
38
+}
39
+
30 40
 // InfoByOpenid 根据openid取用户信息
31 41
 func (us *UserService) InfoByOpenid(openid string) *model.User {
32 42
 	u := &model.User{}
@@ -65,15 +75,17 @@ func (us *UserService) GenerateToken(u *model.User) string {
65 75
 func (us *UserService) Login(u *model.User, llog *model.LoginLog) *model.UserToken {
66 76
 	token := us.GenerateToken(u)
67 77
 	ut := &model.UserToken{
68
-		UserId:    u.Id,
69
-		Token:     token,
70
-		ExpiredAt: time.Now().Add(time.Hour * 24 * 7).Unix(),
78
+		UserId:    	u.Id,
79
+		Token:     	token,
80
+		DeviceUuid: llog.Uuid,
81
+		DeviceId:   llog.DeviceId,
82
+		ExpiredAt: 	time.Now().Add(time.Hour * 24 * 7).Unix(),
71 83
 	}
72 84
 	global.DB.Create(ut)
73 85
 	llog.UserTokenId = ut.UserId
74 86
 	global.DB.Create(llog)
75 87
 	if llog.Uuid != "" {
76
-		AllService.PeerService.UuidBindUserId(llog.Uuid, u.Id)
88
+		AllService.PeerService.UuidBindUserId(llog.DeviceId, llog.Uuid, u.Id)
77 89
 	}
78 90
 	return ut
79 91
 }
@@ -140,18 +152,42 @@ func (us *UserService) CheckUserEnable(u *model.User) bool {
140 152
 
141 153
 // Create 创建
142 154
 func (us *UserService) Create(u *model.User) error {
155
+	// The initial username should be formatted, and the username should be unique
156
+	u.Username = us.formatUsername(u.Username)
143 157
 	u.Password = us.EncryptPassword(u.Password)
144 158
 	res := global.DB.Create(u).Error
145 159
 	return res
146 160
 }
147 161
 
148
-// Logout 退出登录
162
+// GetUuidByToken 根据token和user取uuid
163
+func (us *UserService) GetUuidByToken(u *model.User, token string) string {
164
+	ut := &model.UserToken{}
165
+	err :=global.DB.Where("user_id = ? and token = ?", u.Id, token).First(ut).Error
166
+	if err != nil {
167
+		return ""
168
+	}
169
+	return ut.DeviceUuid
170
+}
171
+
172
+// Logout 退出登录 -> 删除token, 解绑uuid
149 173
 func (us *UserService) Logout(u *model.User, token string) error {
150
-	return global.DB.Where("user_id = ? and token = ?", u.Id, token).Delete(&model.UserToken{}).Error
174
+	uuid := us.GetUuidByToken(u, token)
175
+	err := global.DB.Where("user_id = ? and token = ?", u.Id, token).Delete(&model.UserToken{}).Error
176
+	if err != nil {
177
+		return err
178
+	}
179
+	if uuid != "" {
180
+		AllService.PeerService.UuidUnbindUserId(uuid, u.Id)
181
+	}
182
+	return nil
151 183
 }
152 184
 
153 185
 // Delete 删除用户和oauth信息
154 186
 func (us *UserService) Delete(u *model.User) error {
187
+	userCount := us.getAdminUserCount()
188
+	if userCount <= 1 && us.IsAdmin(u) {
189
+		return errors.New("The last admin user cannot be deleted")
190
+	}
155 191
 	tx := global.DB.Begin()
156 192
 	// 删除用户
157 193
 	if err := tx.Delete(u).Error; err != nil {
@@ -179,11 +215,25 @@ func (us *UserService) Delete(u *model.User) error {
179 215
 		return err
180 216
 	}
181 217
 	tx.Commit()
218
+	// 删除关联的peer
219
+	if err := AllService.PeerService.EraseUserId(u.Id); err != nil {
220
+		tx.Rollback()
221
+		return err
222
+	}
182 223
 	return nil
183 224
 }
184 225
 
185 226
 // Update 更新
186 227
 func (us *UserService) Update(u *model.User) error {
228
+	currentUser := us.InfoById(u.Id)
229
+	// 如果当前用户是管理员并且 IsAdmin 不为空,进行检查
230
+	if us.IsAdmin(currentUser) {
231
+		adminCount := us.getAdminUserCount()
232
+		// 如果这是唯一的管理员,确保不能禁用或取消管理员权限
233
+		if adminCount <= 1 && ( !us.IsAdmin(u) || u.Status == model.COMMON_STATUS_DISABLED) {
234
+			return errors.New("The last admin user cannot be disabled or demoted")
235
+		}
236
+	}
187 237
 	return global.DB.Model(u).Updates(u).Error
188 238
 }
189 239
 
@@ -192,6 +242,16 @@ func (us *UserService) FlushToken(u *model.User) error {
192 242
 	return global.DB.Where("user_id = ?", u.Id).Delete(&model.UserToken{}).Error
193 243
 }
194 244
 
245
+// FlushTokenByUuid 清空token
246
+func (us *UserService) FlushTokenByUuid(uuid string) error {
247
+	return global.DB.Where("device_uuid = ?", uuid).Delete(&model.UserToken{}).Error
248
+}
249
+
250
+// FlushTokenByUuids 清空token
251
+func (us *UserService) FlushTokenByUuids(uuids []string) error {
252
+	return global.DB.Where("device_uuid in (?)", uuids).Delete(&model.UserToken{}).Error
253
+}
254
+
195 255
 // UpdatePassword 更新密码
196 256
 func (us *UserService) UpdatePassword(u *model.User, password string) error {
197 257
 	u.Password = us.EncryptPassword(password)
@@ -216,24 +276,9 @@ func (us *UserService) RouteNames(u *model.User) []string {
216 276
 	return adResp.UserRouteNames
217 277
 }
218 278
 
219
-// InfoByGithubId 根据githubid取用户信息
220
-func (us *UserService) InfoByGithubId(githubId string) *model.User {
221
-	return us.InfoByOauthId(model.OauthTypeGithub, githubId)
222
-}
223
-
224
-// InfoByGoogleEmail 根据googleid取用户信息
225
-func (us *UserService) InfoByGoogleEmail(email string) *model.User {
226
-	return us.InfoByOauthId(model.OauthTypeGithub, email)
227
-}
228
-
229
-// InfoByOidcSub 根据oidc取用户信息
230
-func (us *UserService) InfoByOidcSub(sub string) *model.User {
231
-	return us.InfoByOauthId(model.OauthTypeOidc, sub)
232
-}
233
-
234
-// InfoByOauthId 根据oauth取用户信息
235
-func (us *UserService) InfoByOauthId(thirdType, uid string) *model.User {
236
-	ut := AllService.OauthService.UserThirdInfo(thirdType, uid)
279
+// InfoByOauthId 根据oauth的name和openId取用户信息
280
+func (us *UserService) InfoByOauthId(op string, openId string) *model.User {
281
+	ut := AllService.OauthService.UserThirdInfo(op, openId)
237 282
 	if ut.Id == 0 {
238 283
 		return nil
239 284
 	}
@@ -244,55 +289,52 @@ func (us *UserService) InfoByOauthId(thirdType, uid string) *model.User {
244 289
 	return u
245 290
 }
246 291
 
247
-// RegisterByGithub 注册
248
-func (us *UserService) RegisterByGithub(githubName string, githubId string) *model.User {
249
-	return us.RegisterByOauth(model.OauthTypeGithub, githubName, githubId)
250
-}
251
-
252
-// RegisterByGoogle 注册
253
-func (us *UserService) RegisterByGoogle(name string, email string) *model.User {
254
-	return us.RegisterByOauth(model.OauthTypeGoogle, name, email)
255
-}
256
-
257
-// RegisterByOidc 注册, use PreferredUsername as username, sub as openid
258
-func (us *UserService) RegisterByOidc(PreferredUsername string, sub string) *model.User {
259
-	return us.RegisterByOauth(model.OauthTypeOidc, PreferredUsername, sub)
260
-}
261
-
262 292
 // RegisterByOauth 注册
263
-func (us *UserService) RegisterByOauth(thirdType, thirdName, uid string) *model.User {
293
+func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser , op string) (error, *model.User) {
264 294
 	global.Lock.Lock("registerByOauth")
265 295
 	defer global.Lock.UnLock("registerByOauth")
266
-	ut := AllService.OauthService.UserThirdInfo(thirdType, uid)
296
+	ut := AllService.OauthService.UserThirdInfo(op, oauthUser.OpenId)
267 297
 	if ut.Id != 0 {
268
-		u := &model.User{}
269
-		global.DB.Where("id = ?", ut.UserId).First(u)
270
-		return u
271
-	}
272
-
273
-	tx := global.DB.Begin()
274
-	ut = &model.UserThird{
275
-		OpenId:    uid,
276
-		ThirdName: thirdName,
277
-		ThirdType: thirdType,
298
+		return nil, us.InfoById(ut.UserId)
278 299
 	}
279
-
280
-	username := us.GenerateUsernameByOauth(thirdName)
281
-	u := &model.User{
282
-		Username: username,
283
-		GroupId:  1,
300
+	//check if this email has been registered 
301
+	email := oauthUser.Email
302
+	err, oauthType := AllService.OauthService.GetTypeByOp(op)
303
+	if err != nil {
304
+		return err, nil
284 305
 	}
285
-	tx.Create(u)
286
-	if u.Id == 0 {
287
-		tx.Rollback()
288
-		return u
306
+	// if email is empty, use username and op as email
307
+	if email == "" {
308
+		email = oauthUser.Username + "@" + op
309
+	} 
310
+	email = strings.ToLower(email)
311
+	// update email to oauthUser, in case it contain upper case
312
+	oauthUser.Email = email
313
+	user := us.InfoByEmail(email)
314
+	tx := global.DB.Begin()
315
+	if user.Id != 0 {
316
+		ut.FromOauthUser(user.Id, oauthUser, oauthType, op)
317
+	} else {
318
+		ut = &model.UserThird{}
319
+		ut.FromOauthUser(0, oauthUser, oauthType, op)
320
+		// The initial username should be formatted
321
+		username := us.formatUsername(oauthUser.Username)
322
+		usernameUnique := us.GenerateUsernameByOauth(username)
323
+		user = &model.User{
324
+			Username: usernameUnique,
325
+			GroupId:  1,
326
+		}
327
+		oauthUser.ToUser(user, false)
328
+		tx.Create(user)
329
+		if user.Id == 0 {
330
+			tx.Rollback()
331
+			return errors.New("OauthRegisterFailed"), user
332
+		}
333
+		ut.UserId = user.Id
289 334
 	}
290
-
291
-	ut.UserId = u.Id
292 335
 	tx.Create(ut)
293
-
294 336
 	tx.Commit()
295
-	return u
337
+	return nil, user
296 338
 }
297 339
 
298 340
 // GenerateUsernameByOauth 生成用户名
@@ -314,7 +356,7 @@ func (us *UserService) UserThirdsByUserId(userId uint) (res []*model.UserThird)
314 356
 
315 357
 func (us *UserService) UserThirdInfo(userId uint, op string) *model.UserThird {
316 358
 	ut := &model.UserThird{}
317
-	global.DB.Where("user_id = ? and third_type = ?", userId, op).First(ut)
359
+	global.DB.Where("user_id = ? and op = ?", userId, op).First(ut)
318 360
 	return ut
319 361
 }
320 362
 
@@ -348,9 +390,11 @@ func (us *UserService) IsPasswordEmptyByUser(u *model.User) bool {
348 390
 	return us.IsPasswordEmptyById(u.Id)
349 391
 }
350 392
 
351
-func (us *UserService) Register(username string, password string) *model.User {
393
+// Register 注册
394
+func (us *UserService) Register(username string, email string, password string) *model.User {
352 395
 	u := &model.User{
353 396
 		Username: username,
397
+		Email:    email,
354 398
 		Password: us.EncryptPassword(password),
355 399
 		GroupId:  1,
356 400
 	}
@@ -381,3 +425,24 @@ func (us *UserService) TokenInfoById(id uint) *model.UserToken {
381 425
 func (us *UserService) DeleteToken(l *model.UserToken) error {
382 426
 	return global.DB.Delete(l).Error
383 427
 }
428
+
429
+// Helper functions, used for formatting username
430
+func (us *UserService) formatUsername(username string) string {
431
+	username = strings.ReplaceAll(username, " ", "")
432
+	username = strings.ToLower(username)
433
+	return username
434
+}
435
+
436
+//  Helper functions, getUserCount
437
+func (us *UserService) getUserCount() int64 {
438
+	var count int64
439
+	global.DB.Model(&model.User{}).Count(&count)
440
+	return count
441
+}
442
+
443
+// helper functions, getAdminUserCount
444
+func (us *UserService) getAdminUserCount() int64 {
445
+	var count int64
446
+	global.DB.Model(&model.User{}).Where("is_admin = ?", true).Count(&count)
447
+	return count
448
+}