Browse Source

re-construct oauth

Tao Chen 1 year ago
parent
commit
7707cc116f

+ 2 - 0
http/controller/admin/login.go

@@ -63,6 +63,8 @@ func (ct *Login) Login(c *gin.Context) {
63 63
 	response.Success(c, &adResp.LoginPayload{
64 64
 		Token:      ut.Token,
65 65
 		Username:   u.Username,
66
+		Email:      u.Email,
67
+		Avatar:     u.Avatar,
66 68
 		RouteNames: service.AllService.UserService.RouteNames(u),
67 69
 		Nickname:   u.Nickname,
68 70
 	})

+ 13 - 29
http/controller/admin/oauth.go

@@ -5,7 +5,6 @@ import (
5 5
 	"Gwen/http/request/admin"
6 6
 	adminReq "Gwen/http/request/admin"
7 7
 	"Gwen/http/response"
8
-	"Gwen/model"
9 8
 	"Gwen/service"
10 9
 	"github.com/gin-gonic/gin"
11 10
 	"strconv"
@@ -96,21 +95,23 @@ func (o *Oauth) BindConfirm(c *gin.Context) {
96 95
 		response.Fail(c, 101, response.TranslateMsg(c, "ParamsError"))
97 96
 		return
98 97
 	}
99
-	v := service.AllService.OauthService.GetOauthCache(j.Code)
100
-	if v == nil {
98
+	oauthService := service.AllService.OauthService
99
+	oauthCache := oauthService.GetOauthCache(j.Code)
100
+	if oauthCache == nil {
101 101
 		response.Fail(c, 101, response.TranslateMsg(c, "OauthExpired"))
102 102
 		return
103 103
 	}
104
-	u := service.AllService.UserService.CurUser(c)
105
-	err = service.AllService.OauthService.BindOauthUser(v.Op, v.ThirdOpenId, v.ThirdName, u.Id)
104
+	oauthUser := oauthCache.ToOauthUser()
105
+	user := service.AllService.UserService.CurUser(c)
106
+	err = oauthService.BindOauthUser(user.Id, oauthUser, oauthCache.Op)
106 107
 	if err != nil {
107 108
 		response.Fail(c, 101, response.TranslateMsg(c, "BindFail"))
108 109
 		return
109 110
 	}
110 111
 
111
-	v.UserId = u.Id
112
-	service.AllService.OauthService.SetOauthCache(j.Code, v, 0)
113
-	response.Success(c, v)
112
+	oauthCache.UserId = user.Id
113
+	oauthService.SetOauthCache(j.Code, oauthCache, 0)
114
+	response.Success(c, oauthCache)
114 115
 }
115 116
 
116 117
 func (o *Oauth) Unbind(c *gin.Context) {
@@ -126,28 +127,11 @@ func (o *Oauth) Unbind(c *gin.Context) {
126 127
 		response.Fail(c, 101, response.TranslateMsg(c, "ItemNotFound"))
127 128
 		return
128 129
 	}
129
-	if f.Op == model.OauthTypeGithub {
130
-		err = service.AllService.OauthService.UnBindGithubUser(u.Id)
131
-		if err != nil {
132
-			response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
133
-			return
134
-		}
135
-	}
136
-	if f.Op == model.OauthTypeGoogle {
137
-		err = service.AllService.OauthService.UnBindGoogleUser(u.Id)
138
-		if err != nil {
139
-			response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
140
-			return
141
-		}
142
-	}
143
-	if f.Op == model.OauthTypeOidc {
144
-		err = service.AllService.OauthService.UnBindOidcUser(u.Id)
145
-		if err != nil {
146
-			response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
147
-			return
148
-		}
130
+	err = service.AllService.OauthService.UnBindOauthUser(u.Id, f.Op)
131
+	if err != nil {
132
+		response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
133
+		return
149 134
 	}
150
-
151 135
 	response.Success(c, nil)
152 136
 }
153 137
 

+ 2 - 2
http/controller/admin/user.go

@@ -286,10 +286,10 @@ func (ct *User) MyOauth(c *gin.Context) {
286 286
 	var res []*adResp.UserOauthItem
287 287
 	for _, oa := range oal.Oauths {
288 288
 		item := &adResp.UserOauthItem{
289
-			ThirdType: oa.Op,
289
+			Op: oa.Op,
290 290
 		}
291 291
 		for _, ut := range uts {
292
-			if ut.ThirdType == oa.Op {
292
+			if ut.Op == oa.Op {
293 293
 				item.Status = 1
294 294
 				break
295 295
 			}

+ 4 - 16
http/controller/api/login.go

@@ -83,22 +83,10 @@ func (l *Login) Login(c *gin.Context) {
83 83
 // @Failure 500 {object} response.ErrorResponse
84 84
 // @Router /login-options [get]
85 85
 func (l *Login) LoginOptions(c *gin.Context) {
86
-	oauthOks := []string{}
87
-	err, _ := service.AllService.OauthService.GetOauthConfig(model.OauthTypeGithub)
88
-	if err == nil {
89
-		oauthOks = append(oauthOks, model.OauthTypeGithub)
90
-	}
91
-	err, _ = service.AllService.OauthService.GetOauthConfig(model.OauthTypeGoogle)
92
-	if err == nil {
93
-		oauthOks = append(oauthOks, model.OauthTypeGoogle)
94
-	}
95
-	err, _ = service.AllService.OauthService.GetOauthConfig(model.OauthTypeOidc)
96
-	if err == nil {
97
-		oauthOks = append(oauthOks, model.OauthTypeOidc)
98
-	}
99
-	oauthOks = append(oauthOks, model.OauthTypeWebauth)
86
+	ops := service.AllService.OauthService.GetOauthProviders()
87
+	ops = append(ops, model.OauthTypeWebauth)
100 88
 	var oidcItems []map[string]string
101
-	for _, v := range oauthOks {
89
+	for _, v := range ops {
102 90
 		oidcItems = append(oidcItems, map[string]string{"name": v})
103 91
 	}
104 92
 	common, err := json.Marshal(oidcItems)
@@ -108,7 +96,7 @@ func (l *Login) LoginOptions(c *gin.Context) {
108 96
 	}
109 97
 	var res []string
110 98
 	res = append(res, "common-oidc/"+string(common))
111
-	for _, v := range oauthOks {
99
+	for _, v := range ops {
112 100
 		res = append(res, "oidc/"+v)
113 101
 	}
114 102
 	c.JSON(http.StatusOK, res)

+ 40 - 66
http/controller/api/ouath.go

@@ -9,8 +9,6 @@ import (
9 9
 	"Gwen/service"
10 10
 	"github.com/gin-gonic/gin"
11 11
 	"net/http"
12
-	"strconv"
13
-	"strings"
14 12
 )
15 13
 
16 14
 type Oauth struct {
@@ -32,13 +30,17 @@ func (o *Oauth) OidcAuth(c *gin.Context) {
32 30
 		response.Error(c, response.TranslateMsg(c, "ParamsError")+err.Error())
33 31
 		return
34 32
 	}
35
-	//fmt.Println(f)
36
-	if f.Op != model.OauthTypeWebauth && f.Op != model.OauthTypeGoogle && f.Op != model.OauthTypeGithub && f.Op != model.OauthTypeOidc {
37
-		response.Error(c, response.TranslateMsg(c, "ParamsError"))
33
+
34
+	oauthService := service.AllService.OauthService
35
+	err = oauthService.ValidateOauthProvider(f.Op)
36
+	if err != nil {
37
+		response.Error(c, response.TranslateMsg(c, err.Error()))
38 38
 		return
39 39
 	}
40 40
 
41
-	err, code, url := service.AllService.OauthService.BeginAuth(f.Op)
41
+	var code string
42
+	var url string
43
+	err, code, url = oauthService.BeginAuth(f.Op)
42 44
 	if err != nil {
43 45
 		response.Error(c, response.TranslateMsg(c, err.Error()))
44 46
 		return
@@ -149,70 +151,43 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
149 151
 		c.String(http.StatusInternalServerError, response.TranslateParamMsg(c, "ParamIsEmpty", "state"))
150 152
 		return
151 153
 	}
152
-
153 154
 	cacheKey := state
155
+	oauthService := service.AllService.OauthService
154 156
 	//从缓存中获取
155
-	v := service.AllService.OauthService.GetOauthCache(cacheKey)
156
-	if v == nil {
157
+	oauthCache := oauthService.GetOauthCache(cacheKey)
158
+	if oauthCache == nil {
157 159
 		c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthExpired"))
158 160
 		return
159 161
 	}
160
-
161
-	ty := v.Op
162
-	ac := v.Action
163
-	var u *model.User
164
-	openid := ""
165
-	thirdName := ""
166
-	//fmt.Println("ty ac ", ty, ac)
167
-
168
-	if ty == model.OauthTypeGithub {
169
-		code := c.Query("code")
170
-		err, userData := service.AllService.OauthService.GithubCallback(code)
171
-		if err != nil {
172
-			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
173
-			return
174
-		}
175
-		openid = strconv.Itoa(userData.Id)
176
-		thirdName = userData.Login
177
-	} else if ty == model.OauthTypeGoogle {
178
-		code := c.Query("code")
179
-		err, userData := service.AllService.OauthService.GoogleCallback(code)
180
-		if err != nil {
181
-			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
182
-			return
183
-		}
184
-		openid = userData.Email
185
-		//将空格替换成_
186
-		thirdName = strings.Replace(userData.Name, " ", "_", -1)
187
-	} else if ty == model.OauthTypeOidc {
188
-		code := c.Query("code")
189
-		err, userData := service.AllService.OauthService.OidcCallback(code)
190
-		if err != nil {
191
-			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
192
-			return
193
-		}
194
-		openid = userData.Sub
195
-		thirdName = userData.PreferredUsername
196
-	} else {
197
-		c.String(http.StatusInternalServerError, response.TranslateMsg(c, "ParamsError"))
162
+	op := oauthCache.Op
163
+	action := oauthCache.Action
164
+	var user *model.User
165
+	// 获取用户信息
166
+	code := c.Query("code")
167
+	err, oauthUser := oauthService.Callback(code, op)
168
+	if err != nil {
169
+		c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
198 170
 		return
199 171
 	}
200
-	if ac == service.OauthActionTypeBind {
172
+	userId := oauthCache.UserId
173
+	openid := oauthUser.OpenId
174
+	if action == service.OauthActionTypeBind {
201 175
 
202 176
 		//fmt.Println("bind", ty, userData)
203
-		utr := service.AllService.OauthService.UserThirdInfo(ty, openid)
177
+		// 检查此openid是否已经绑定过
178
+		utr := oauthService.UserThirdInfo(op, openid)
204 179
 		if utr.UserId > 0 {
205 180
 			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthHasBindOtherUser"))
206 181
 			return
207 182
 		}
208 183
 		//绑定
209
-		u = service.AllService.UserService.InfoById(v.UserId)
210
-		if u == nil {
184
+		user = service.AllService.UserService.InfoById(userId)
185
+		if user == nil {
211 186
 			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "ItemNotFound"))
212 187
 			return
213 188
 		}
214 189
 		//绑定
215
-		err := service.AllService.OauthService.BindOauthUser(ty, openid, thirdName, v.UserId)
190
+		err := oauthService.BindOauthUser(userId, oauthUser, op)
216 191
 		if err != nil {
217 192
 			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "BindFail"))
218 193
 			return
@@ -220,42 +195,41 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
220 195
 		c.String(http.StatusOK, response.TranslateMsg(c, "BindSuccess"))
221 196
 		return
222 197
 
223
-	} else if ac == service.OauthActionTypeLogin {
198
+	} else if action == service.OauthActionTypeLogin {
224 199
 		//登录
225
-		if v.UserId != 0 {
200
+		if userId != 0 {
226 201
 			c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthHasBeenSuccess"))
227 202
 			return
228 203
 		}
229
-		u = service.AllService.UserService.InfoByGithubId(openid)
230
-		if u == nil {
231
-			oa := service.AllService.OauthService.InfoByOp(ty)
232
-			if !*oa.AutoRegister {
204
+		user = service.AllService.UserService.InfoByOauthId(op, openid)
205
+		if user == nil {
206
+			oauthConfig := oauthService.InfoByOp(op)
207
+			if !*oauthConfig.AutoRegister {
233 208
 				//c.String(http.StatusInternalServerError, "还未绑定用户,请先绑定")
234
-				v.ThirdName = thirdName
235
-				v.ThirdOpenId = openid
209
+				oauthCache.UpdateFromOauthUser(oauthUser)
236 210
 				url := global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/bind/" + cacheKey
237 211
 				c.Redirect(http.StatusFound, url)
238 212
 				return
239 213
 			}
240 214
 
241 215
 			//自动注册
242
-			u = service.AllService.UserService.RegisterByOauth(ty, thirdName, openid)
243
-			if u.Id == 0 {
216
+			user = service.AllService.UserService.RegisterByOauth(oauthUser, op)
217
+			if user.Id == 0 {
244 218
 				c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthRegisterFailed"))
245 219
 				return
246 220
 			}
247 221
 		}
248
-		v.UserId = u.Id
249
-		service.AllService.OauthService.SetOauthCache(cacheKey, v, 0)
222
+		oauthCache.UserId = user.Id
223
+		oauthService.SetOauthCache(cacheKey, oauthCache, 0)
250 224
 		// 如果是webadmin,登录成功后跳转到webadmin
251
-		if v.DeviceType == "webadmin" {
225
+		if oauthCache.DeviceType == "webadmin" {
252 226
 			/*service.AllService.UserService.Login(u, &model.LoginLog{
253 227
 				UserId:   u.Id,
254 228
 				Client:   "webadmin",
255 229
 				Uuid:     "", //must be empty
256 230
 				Ip:       c.ClientIP(),
257 231
 				Type:     model.LoginLogTypeOauth,
258
-				Platform: v.DeviceOs,
232
+				Platform: oauthService.DeviceOs,
259 233
 			})*/
260 234
 			url := global.Config.Rustdesk.ApiServer + "/_admin/#/"
261 235
 			c.Redirect(http.StatusFound, url)

+ 30 - 9
http/request/admin/oauth.go

@@ -1,6 +1,9 @@
1 1
 package admin
2 2
 
3
-import "Gwen/model"
3
+import (
4
+	"Gwen/model"
5
+	"strings"
6
+)
4 7
 
5 8
 type BindOauthForm struct {
6 9
 	Op string `json:"op" binding:"required"`
@@ -13,19 +16,37 @@ type UnBindOauthForm struct {
13 16
 	Op string `json:"op" binding:"required"`
14 17
 }
15 18
 type OauthForm struct {
16
-	Id           uint   `json:"id"`
17
-	Op           string `json:"op" validate:"required"`
18
-	Issuer	     string `json:"issuer" validate:"omitempty,url"`
19
-	Scopes	   	 string `json:"scopes" validate:"omitempty"`
20
-	ClientId     string `json:"client_id" validate:"required"`
21
-	ClientSecret string `json:"client_secret" validate:"required"`
22
-	RedirectUrl  string `json:"redirect_url" validate:"required"`
23
-	AutoRegister *bool  `json:"auto_register"`
19
+	Id           uint   			`json:"id"`
20
+	Op           string 			`json:"op" validate:"omitempty"`
21
+	OauthType    string 			`json:"oauth_type" validate:"required"`
22
+	Issuer	     string 			`json:"issuer" validate:"omitempty,url"`
23
+	Scopes	   	 string 			`json:"scopes" validate:"omitempty"`
24
+	ClientId     string 			`json:"client_id" validate:"required"`
25
+	ClientSecret string 			`json:"client_secret" validate:"required"`
26
+	RedirectUrl  string 			`json:"redirect_url" validate:"required"`
27
+	AutoRegister *bool  			`json:"auto_register"`
24 28
 }
25 29
 
26 30
 func (of *OauthForm) ToOauth() *model.Oauth {
31
+	op := strings.ToLower(of.Op)
32
+	op = strings.TrimSpace(op)
33
+	if op == "" {
34
+		switch of.OauthType {
35
+		case model.OauthTypeGithub:
36
+			of.Op = "GitHub"
37
+		case model.OauthTypeGoogle:
38
+			of.Op = "Google"
39
+		case model.OauthTypeOidc:
40
+			of.Op = "OIDC"
41
+		case model.OauthTypeWebauth:
42
+			of.Op = "WebAuth"
43
+		default:
44
+			of.Op = of.OauthType
45
+		}
46
+	}
27 47
 	oa := &model.Oauth{
28 48
 		Op:           of.Op,
49
+		OauthType:	  of.OauthType,
29 50
 		ClientId:     of.ClientId,
30 51
 		ClientSecret: of.ClientSecret,
31 52
 		RedirectUrl:  of.RedirectUrl,

+ 10 - 7
http/request/admin/user.go

@@ -5,20 +5,22 @@ import (
5 5
 )
6 6
 
7 7
 type UserForm struct {
8
-	Id       uint   `json:"id"`
9
-	Username string `json:"username" validate:"required,gte=4,lte=10"`
8
+	Id       uint   			`json:"id"`
9
+	Username string 			`json:"username" validate:"required,gte=4,lte=10"`
10
+	Email	 string           	`json:"email" validate:"required,email"`
10 11
 	//Password string           `json:"password" validate:"required,gte=4,lte=20"`
11
-	Nickname string           `json:"nickname"`
12
-	Avatar   string           `json:"avatar"`
13
-	GroupId  uint             `json:"group_id" validate:"required"`
14
-	IsAdmin  *bool            `json:"is_admin" `
15
-	Status   model.StatusCode `json:"status" validate:"required,gte=0"`
12
+	Nickname string           	`json:"nickname"`
13
+	Avatar   string           	`json:"avatar"`
14
+	GroupId  uint             	`json:"group_id" validate:"required"`
15
+	IsAdmin  *bool            	`json:"is_admin" `
16
+	Status   model.StatusCode 	`json:"status" validate:"required,gte=0"`
16 17
 }
17 18
 
18 19
 func (uf *UserForm) FromUser(user *model.User) *UserForm {
19 20
 	uf.Id = user.Id
20 21
 	uf.Username = user.Username
21 22
 	uf.Nickname = user.Nickname
23
+	uf.Email = user.Email
22 24
 	uf.Avatar = user.Avatar
23 25
 	uf.GroupId = user.GroupId
24 26
 	uf.IsAdmin = user.IsAdmin
@@ -30,6 +32,7 @@ func (uf *UserForm) ToUser() *model.User {
30 32
 	user.Id = uf.Id
31 33
 	user.Username = uf.Username
32 34
 	user.Nickname = uf.Nickname
35
+	user.Email = uf.Email
33 36
 	user.Avatar = uf.Avatar
34 37
 	user.GroupId = uf.GroupId
35 38
 	user.IsAdmin = uf.IsAdmin

+ 4 - 2
http/response/admin/user.go

@@ -4,6 +4,8 @@ import "Gwen/model"
4 4
 
5 5
 type LoginPayload struct {
6 6
 	Username   string   `json:"username"`
7
+	Email	   string   `json:"email"`
8
+	Avatar	   string   `json:"avatar"`
7 9
 	Token      string   `json:"token"`
8 10
 	RouteNames []string `json:"route_names"`
9 11
 	Nickname   string   `json:"nickname"`
@@ -15,8 +17,8 @@ var UserRouteNames = []string{
15 17
 var AdminRouteNames = []string{"*"}
16 18
 
17 19
 type UserOauthItem struct {
18
-	ThirdType string `json:"third_type"`
19
-	Status    int    `json:"status"`
20
+	Op 			string `json:"op"`
21
+	Status    	int    `json:"status"`
20 22
 }
21 23
 
22 24
 type GroupUsersPayload struct {

+ 1 - 0
http/response/api/user.go

@@ -29,6 +29,7 @@ type UserPayload struct {
29 29
 
30 30
 func (up *UserPayload) FromUser(user *model.User) *UserPayload {
31 31
 	up.Name = user.Username
32
+	up.Email = user.Email
32 33
 	up.IsAdmin = user.IsAdmin
33 34
 	up.Status = int(user.Status)
34 35
 	up.Info = map[string]interface{}{}

+ 100 - 13
model/oauth.go

@@ -1,23 +1,110 @@
1 1
 package model
2 2
 
3
+import (
4
+	"strconv"
5
+	"fmt"
6
+)
7
+
8
+
9
+const (
10
+	OauthTypeGithub  string = "github"
11
+	OauthTypeGoogle  string = "google"
12
+	OauthTypeOidc    string = "oidc"
13
+	OauthTypeWebauth string = "webauth"
14
+)
15
+
16
+
3 17
 type Oauth struct {
4 18
 	IdModel
5
-	Op           string `json:"op"`
6
-	ClientId     string `json:"client_id"`
7
-	ClientSecret string `json:"client_secret"`
8
-	RedirectUrl  string `json:"redirect_url"`
9
-	AutoRegister *bool  `json:"auto_register"`
10
-	Scopes       string `json:"scopes"`
11
-	Issuer	     string `json:"issuer"`
19
+	Op           string 	`json:"op"`
20
+	OauthType    string 	`json:"oauth_type"`
21
+	ClientId     string 	`json:"client_id"`
22
+	ClientSecret string 	`json:"client_secret"`
23
+	RedirectUrl  string 	`json:"redirect_url"`
24
+	AutoRegister *bool  	`json:"auto_register"`
25
+	Scopes       string 	`json:"scopes"`
26
+	Issuer	     string 	`json:"issuer"`
12 27
 	TimeModel
13 28
 }
14 29
 
15
-const (
16
-	OauthTypeGithub  = "github"
17
-	OauthTypeGoogle  = "google"
18
-	OauthTypeOidc    = "oidc"
19
-	OauthTypeWebauth = "webauth"
20
-)
30
+type OauthUser struct {
31
+	OpenId 			string 	`json:"open_id" gorm:"not null;index"`
32
+	Name   			string 	`json:"name"`
33
+	Username 		string 	`json:"username"`
34
+	Email  			string 	`json:"email"`
35
+	VerifiedEmail 	bool 	`json:"verified_email,omitempty"`
36
+}
37
+
38
+func (ou *OauthUser) ToUser(user *User, overideUsername bool) {
39
+	if overideUsername {
40
+		user.Username = ou.Username
41
+	}
42
+	user.Email = ou.Email
43
+	user.Nickname = ou.Name
44
+
45
+}
46
+
47
+
48
+type OauthUserBase struct {
49
+	Name  string `json:"name"`
50
+	Email string `json:"email"`
51
+}
52
+
53
+
54
+type OidcUser struct {
55
+	OauthUserBase
56
+	Sub               string `json:"sub"`
57
+	VerifiedEmail     bool   `json:"email_verified"`
58
+	PreferredUsername string `json:"preferred_username"`
59
+}
60
+
61
+func (ou *OidcUser) ToOauthUser() *OauthUser {
62
+	return &OauthUser{
63
+		OpenId: 		ou.Sub,
64
+		Name:   		ou.Name,
65
+		Username: 		ou.PreferredUsername,
66
+		Email:  		ou.Email,
67
+		VerifiedEmail: 	ou.VerifiedEmail,
68
+	}
69
+}
70
+
71
+type GoogleUser struct {
72
+	OauthUserBase
73
+	FamilyName    string `json:"family_name"`
74
+	GivenName     string `json:"given_name"`
75
+	Id            string `json:"id"`
76
+	Picture       string `json:"picture"`
77
+	VerifiedEmail bool   `json:"verified_email"`
78
+}
79
+
80
+func (gu *GoogleUser) ToOauthUser() *OauthUser {
81
+	return &OauthUser{
82
+		OpenId: 		gu.Id,
83
+		Name:   		fmt.Sprintf("%s %s", gu.GivenName, gu.FamilyName),
84
+		Username: 		gu.GivenName,
85
+		Email:  		gu.Email,
86
+		VerifiedEmail: 	gu.VerifiedEmail,
87
+	}	
88
+}
89
+
90
+
91
+type GithubUser struct {
92
+	OauthUserBase
93
+	Id                int         `json:"id"`
94
+	Login             string      `json:"login"`
95
+}
96
+
97
+func (gu *GithubUser) ToOauthUser() *OauthUser {
98
+	return &OauthUser{
99
+		OpenId: 		strconv.Itoa(gu.Id),
100
+		Name:   		gu.Name,
101
+		Username: 		gu.Login,
102
+		Email:  		gu.Email,
103
+		VerifiedEmail: 	true,
104
+	}
105
+}
106
+
107
+
21 108
 
22 109
 type OauthList struct {
23 110
 	Oauths []*Oauth `json:"list"`

+ 16 - 0
model/user.go

@@ -1,8 +1,15 @@
1 1
 package model
2 2
 
3
+import (
4
+	"fmt"
5
+	"gorm.io/gorm"
6
+)
7
+
3 8
 type User struct {
4 9
 	IdModel
5 10
 	Username string     `json:"username" gorm:"default:'';not null;uniqueIndex"`
11
+	Email	string     	`json:"email" gorm:"default:'';not null;uniqueIndex"`
12
+	// Email	string     	`json:"email" `
6 13
 	Password string     `json:"-" gorm:"default:'';not null;"`
7 14
 	Nickname string     `json:"nickname" gorm:"default:'';not null;"`
8 15
 	Avatar   string     `json:"avatar" gorm:"default:'';not null;"`
@@ -12,6 +19,15 @@ type User struct {
12 19
 	TimeModel
13 20
 }
14 21
 
22
+// BeforeSave 钩子用于确保 email 字段有合理的默认值
23
+func (u *User) BeforeSave(tx *gorm.DB) (err error) {
24
+    // 如果 email 为空,设置为默认值
25
+    if u.Email == "" {
26
+        u.Email = fmt.Sprintf("%s@example.com", u.Username)
27
+    }
28
+    return nil
29
+}
30
+
15 31
 type UserList struct {
16 32
 	Users []*User `json:"list,omitempty"`
17 33
 	Pagination

+ 13 - 6
model/userThird.go

@@ -2,11 +2,18 @@ package model
2 2
 
3 3
 type UserThird struct {
4 4
 	IdModel
5
-	UserId     uint   `json:"user_id" gorm:"not null;index"`
6
-	OpenId     string `json:"open_id" gorm:"not null;index"`
7
-	UnionId    string `json:"union_id" gorm:"not null;"`
8
-	ThirdType  string `json:"third_type" gorm:"not null;"`
9
-	ThirdEmail string `json:"third_email"`
10
-	ThirdName  string `json:"third_name"`
5
+	UserId     		uint   `	json:"user_id" gorm:"not null;index"`
6
+	OauthUser
7
+	// UnionId    		string `json:"union_id" gorm:"not null;"`
8
+	// OauthType  	   	string 		`json:"oauth_type" gorm:"not null;"`
9
+	OauthType  	   	string 		`json:"oauth_type"`
10
+	Op  			string 		`json:"op" gorm:"not null;"`
11 11
 	TimeModel
12 12
 }
13
+
14
+func (u *UserThird) FromOauthUser(userId uint, oauthUser *OauthUser, oauthType string, op string) {
15
+	u.UserId 			= userId
16
+	u.OauthUser 		= *oauthUser
17
+	u.OauthType 		= oauthType
18
+	u.Op 				= op
19
+}

+ 211 - 274
service/oauth.go

@@ -11,15 +11,20 @@ import (
11 11
 	"golang.org/x/oauth2/github"
12 12
 	"golang.org/x/oauth2/google"
13 13
 	"gorm.io/gorm"
14
-	"io"
14
+	// "io"
15 15
 	"net/http"
16 16
 	"net/url"
17 17
 	"strconv"
18 18
 	"strings"
19 19
 	"sync"
20 20
 	"time"
21
+	"fmt"
21 22
 )
22 23
 
24
+
25
+type OauthService struct {
26
+}
27
+
23 28
 // Define a struct to parse the .well-known/openid-configuration response
24 29
 type OidcEndpoint struct {
25 30
 	Issuer   string `json:"issuer"`
@@ -28,73 +33,6 @@ type OidcEndpoint struct {
28 33
 	UserInfo string `json:"userinfo_endpoint"`
29 34
 }
30 35
 
31
-type OauthService struct {
32
-}
33
-
34
-type GithubUserdata struct {
35
-	AvatarUrl         string      `json:"avatar_url"`
36
-	Bio               string      `json:"bio"`
37
-	Blog              string      `json:"blog"`
38
-	Collaborators     int         `json:"collaborators"`
39
-	Company           interface{} `json:"company"`
40
-	CreatedAt         time.Time   `json:"created_at"`
41
-	DiskUsage         int         `json:"disk_usage"`
42
-	Email             interface{} `json:"email"`
43
-	EventsUrl         string      `json:"events_url"`
44
-	Followers         int         `json:"followers"`
45
-	FollowersUrl      string      `json:"followers_url"`
46
-	Following         int         `json:"following"`
47
-	FollowingUrl      string      `json:"following_url"`
48
-	GistsUrl          string      `json:"gists_url"`
49
-	GravatarId        string      `json:"gravatar_id"`
50
-	Hireable          interface{} `json:"hireable"`
51
-	HtmlUrl           string      `json:"html_url"`
52
-	Id                int         `json:"id"`
53
-	Location          interface{} `json:"location"`
54
-	Login             string      `json:"login"`
55
-	Name              string      `json:"name"`
56
-	NodeId            string      `json:"node_id"`
57
-	NotificationEmail interface{} `json:"notification_email"`
58
-	OrganizationsUrl  string      `json:"organizations_url"`
59
-	OwnedPrivateRepos int         `json:"owned_private_repos"`
60
-	Plan              struct {
61
-		Collaborators int    `json:"collaborators"`
62
-		Name          string `json:"name"`
63
-		PrivateRepos  int    `json:"private_repos"`
64
-		Space         int    `json:"space"`
65
-	} `json:"plan"`
66
-	PrivateGists      int    `json:"private_gists"`
67
-	PublicGists       int    `json:"public_gists"`
68
-	PublicRepos       int    `json:"public_repos"`
69
-	ReceivedEventsUrl string `json:"received_events_url"`
70
-	ReposUrl          string `json:"repos_url"`
71
-	SiteAdmin         bool   `json:"site_admin"`
72
-	StarredUrl        string `json:"starred_url"`
73
-	SubscriptionsUrl  string `json:"subscriptions_url"`
74
-	TotalPrivateRepos int    `json:"total_private_repos"`
75
-	//TwitterUsername         interface{} `json:"twitter_username"`
76
-	TwoFactorAuthentication bool      `json:"two_factor_authentication"`
77
-	Type                    string    `json:"type"`
78
-	UpdatedAt               time.Time `json:"updated_at"`
79
-	Url                     string    `json:"url"`
80
-}
81
-type GoogleUserdata struct {
82
-	Email         string `json:"email"`
83
-	FamilyName    string `json:"family_name"`
84
-	GivenName     string `json:"given_name"`
85
-	Id            string `json:"id"`
86
-	Name          string `json:"name"`
87
-	Picture       string `json:"picture"`
88
-	VerifiedEmail bool   `json:"verified_email"`
89
-}
90
-type OidcUserdata struct {
91
-	Sub               string `json:"sub"`
92
-	Email             string `json:"email"`
93
-	VerifiedEmail     bool   `json:"email_verified"`
94
-	Name              string `json:"name"`
95
-	PreferredUsername string `json:"preferred_username"`
96
-}
97
-
98 36
 type OauthCacheItem struct {
99 37
 	UserId      uint   `json:"user_id"`
100 38
 	Id          string `json:"id"` //rustdesk的设备ID
@@ -104,9 +42,19 @@ type OauthCacheItem struct {
104 42
 	DeviceName  string `json:"device_name"`
105 43
 	DeviceOs    string `json:"device_os"`
106 44
 	DeviceType  string `json:"device_type"`
107
-	ThirdOpenId string `json:"third_open_id"`
108
-	ThirdName   string `json:"third_name"`
109
-	ThirdEmail  string `json:"third_email"`
45
+	OpenId 		string `json:"open_id"`
46
+	Username	string `json:"username"`
47
+	Name   		string `json:"name"`
48
+	Email  		string `json:"email"`
49
+}
50
+
51
+func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
52
+	return &model.OauthUser{
53
+		OpenId: oci.OpenId,
54
+		Username: oci.Username,
55
+		Name: oci.Name,
56
+		Email: oci.Email,
57
+	}
110 58
 }
111 59
 
112 60
 var OauthCache = &sync.Map{}
@@ -116,6 +64,24 @@ const (
116 64
 	OauthActionTypeBind  = "bind"
117 65
 )
118 66
 
67
+func (oa *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) {
68
+	oa.OpenId = oauthUser.OpenId
69
+	oa.Username = oauthUser.Username
70
+	oa.Name = oauthUser.Name
71
+	oa.Email = oauthUser.Email
72
+}
73
+
74
+// Validate the oauth type
75
+func (os *OauthService) ValidateOauthType(oauthType string) error {
76
+	switch oauthType {
77
+	case model.OauthTypeGithub, model.OauthTypeGoogle, model.OauthTypeOidc, model.OauthTypeWebauth:
78
+		return nil
79
+	default:
80
+		return errors.New("invalid Oauth type")
81
+	}
82
+}
83
+
84
+
119 85
 func (os *OauthService) GetOauthCache(key string) *OauthCacheItem {
120 86
 	v, ok := OauthCache.Load(key)
121 87
 	if !ok {
@@ -141,12 +107,12 @@ func (os *OauthService) DeleteOauthCache(key string) {
141 107
 func (os *OauthService) BeginAuth(op string) (error error, code, url string) {
142 108
 	code = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
143 109
 
144
-	if op == model.OauthTypeWebauth {
110
+	if op == string(model.OauthTypeWebauth) {
145 111
 		url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + code
146 112
 		//url = "http://localhost:8888/_admin/#/oauth/" + code
147 113
 		return nil, code, url
148 114
 	}
149
-	err, conf := os.GetOauthConfig(op)
115
+	err, _, conf := os.GetOauthConfig(op)
150 116
 	if err == nil {
151 117
 		return err, code, conf.AuthCodeURL(code)
152 118
 	}
@@ -155,7 +121,7 @@ func (os *OauthService) BeginAuth(op string) (error error, code, url string) {
155 121
 }
156 122
 
157 123
 // Method to fetch OIDC configuration dynamically
158
-func FetchOidcConfig(issuer string) (error, OidcEndpoint) {
124
+func (os *OauthService) FetchOidcEndpoint(issuer string) (error, OidcEndpoint) {
159 125
 	configURL := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
160 126
 
161 127
 	// Get the HTTP client (with or without proxy based on configuration)
@@ -179,76 +145,55 @@ func FetchOidcConfig(issuer string) (error, OidcEndpoint) {
179 145
 	return nil, endpoint
180 146
 }
181 147
 
182
-// GetOauthConfig retrieves the OAuth2 configuration based on the provider type
183
-func (os *OauthService) GetOauthConfig(op string) (error, *oauth2.Config) {
184
-	switch op {
185
-	case model.OauthTypeGithub:
186
-		return os.getGithubConfig()
187
-	case model.OauthTypeGoogle:
188
-		return os.getGoogleConfig()
189
-	case model.OauthTypeOidc:
190
-		return os.getOidcConfig()
191
-	default:
192
-		return errors.New("unsupported OAuth type"), nil
193
-	}
194
-}
195
-
196
-// Helper function to get GitHub OAuth2 configuration
197
-func (os *OauthService) getGithubConfig() (error, *oauth2.Config) {
198
-	g := os.InfoByOp(model.OauthTypeGithub)
199
-	if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" || g.RedirectUrl == "" {
200
-		return errors.New("ConfigNotFound"), nil
201
-	}
202
-	return nil, &oauth2.Config{
203
-		ClientID:     g.ClientId,
204
-		ClientSecret: g.ClientSecret,
205
-		RedirectURL:  g.RedirectUrl,
206
-		Endpoint:     github.Endpoint,
207
-		Scopes:       []string{"read:user", "user:email"},
148
+func (os *OauthService) FetchOidcEndpointByOp(op string) (error, OidcEndpoint) {
149
+	oauthInfo := os.InfoByOp(op)
150
+	if oauthInfo.Issuer == "" {
151
+		return errors.New("issuer is empty"), OidcEndpoint{}
208 152
 	}
153
+	return os.FetchOidcEndpoint(oauthInfo.Issuer)
209 154
 }
210 155
 
211
-// Helper function to get Google OAuth2 configuration
212
-func (os *OauthService) getGoogleConfig() (error, *oauth2.Config) {
213
-	g := os.InfoByOp(model.OauthTypeGoogle)
214
-	if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" || g.RedirectUrl == "" {
215
-		return errors.New("ConfigNotFound"), nil
156
+// GetOauthConfig retrieves the OAuth2 configuration based on the provider name
157
+func (os *OauthService) GetOauthConfig(op string) (err error, oauthType string, oauthConfig *oauth2.Config) {
158
+	err, oauthType, oauthConfig = os.getOauthConfigGeneral(op)
159
+	if err != nil {
160
+		return err, oauthType, nil
216 161
 	}
217
-	return nil, &oauth2.Config{
218
-		ClientID:     g.ClientId,
219
-		ClientSecret: g.ClientSecret,
220
-		RedirectURL:  g.RedirectUrl,
221
-		Endpoint:     google.Endpoint,
222
-		Scopes:       []string{"https://www.googleapis.com/auth/userinfo.profile", "https://www.googleapis.com/auth/userinfo.email"},
162
+	// Maybe should validate the oauthConfig here
163
+	switch oauthType {
164
+	case model.OauthTypeGithub:
165
+		oauthConfig.Endpoint = github.Endpoint
166
+		oauthConfig.Scopes = []string{"read:user", "user:email"}
167
+	case model.OauthTypeGoogle:
168
+		oauthConfig.Endpoint = google.Endpoint
169
+		oauthConfig.Scopes = []string{"https://www.googleapis.com/auth/userinfo.profile", "https://www.googleapis.com/auth/userinfo.email"}
170
+	case model.OauthTypeOidc:
171
+		err, endpoint := os.FetchOidcEndpointByOp(op)
172
+		if err != nil {
173
+			return err,oauthType, nil
174
+		}
175
+		oauthConfig.Endpoint = oauth2.Endpoint{AuthURL:  endpoint.AuthURL,TokenURL: endpoint.TokenURL,}
176
+		oauthConfig.Scopes = os.getScopesByOp(op)
177
+	default:
178
+		return errors.New("unsupported OAuth type"), oauthType, nil
223 179
 	}
180
+	return nil, oauthType, oauthConfig
224 181
 }
225 182
 
226
-// Helper function to get OIDC OAuth2 configuration
227
-func (os *OauthService) getOidcConfig() (error, *oauth2.Config) {
228
-	g := os.InfoByOp(model.OauthTypeOidc)
229
-	if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" || g.RedirectUrl == "" || g.Issuer == "" {
230
-		return errors.New("ConfigNotFound"), nil
183
+// GetOauthConfig retrieves the OAuth2 configuration based on the provider name
184
+func (os *OauthService) getOauthConfigGeneral(op string) (err error, oauthType string, oauthConfig *oauth2.Config) {
185
+	g := os.InfoByOp(op)
186
+	if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" {
187
+		return errors.New("ConfigNotFound"), "", nil
231 188
 	}
232
-
233
-	// Set scopes
234
-	scopes := strings.TrimSpace(g.Scopes)
235
-	if scopes == "" {
236
-		scopes = "openid,profile,email"
189
+	// If the redirect URL is empty, use the default redirect URL
190
+	if g.RedirectUrl == "" {
191
+		g.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback"
237 192
 	}
238
-	scopeList := strings.Split(scopes, ",")
239
-	err, endpoint := FetchOidcConfig(g.Issuer)
240
-	if err != nil {
241
-		return err, nil
242
-	}
243
-	return nil, &oauth2.Config{
193
+	return nil, g.OauthType, &oauth2.Config{
244 194
 		ClientID:     g.ClientId,
245 195
 		ClientSecret: g.ClientSecret,
246 196
 		RedirectURL:  g.RedirectUrl,
247
-		Endpoint: oauth2.Endpoint{
248
-			AuthURL:  endpoint.AuthURL,
249
-			TokenURL: endpoint.TokenURL,
250
-		},
251
-		Scopes: scopeList,
252 197
 	}
253 198
 }
254 199
 
@@ -272,194 +217,161 @@ func getHTTPClientWithProxy() *http.Client {
272 217
 	return http.DefaultClient
273 218
 }
274 219
 
275
-func (os *OauthService) GithubCallback(code string) (error error, userData *GithubUserdata) {
276
-	err, oauthConfig := os.GetOauthConfig(model.OauthTypeGithub)
277
-	if err != nil {
278
-		return err, nil
279
-	}
280
-
281
-	// 使用代理配置创建 HTTP 客户端
282
-	httpClient := getHTTPClientWithProxy()
283
-	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
284
-
285
-	token, err := oauthConfig.Exchange(ctx, code)
220
+func (os *OauthService) callbackBase(op string, code string, userEndpoint string, userData interface{}) error {
221
+	err, oauthType, oauthConfig := os.GetOauthConfig(op)
286 222
 	if err != nil {
287
-		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
288
-		error = errors.New("GetOauthTokenError")
289
-		return
223
+		return err
290 224
 	}
291
-
292
-	// 使用带有代理的 HTTP 客户端获取用户信息
293
-	client := oauthConfig.Client(ctx, token)
294
-	resp, err := client.Get("https://api.github.com/user")
295
-	if err != nil {
296
-		global.Logger.Warn("failed getting user info: ", err)
297
-		error = errors.New("GetOauthUserInfoError")
298
-		return
299
-	}
300
-	defer func(Body io.ReadCloser) {
301
-		err := Body.Close()
225
+	
226
+	// If the OAuth type is OIDC and the user endpoint is empty
227
+	// Fetch the OIDC configuration and get the user endpoint
228
+	if oauthType == model.OauthTypeOidc && userEndpoint == "" {
229
+		err, endpoint := os.FetchOidcEndpointByOp(op)
302 230
 		if err != nil {
303
-			global.Logger.Warn("failed closing response body: ", err)
231
+			global.Logger.Warn("failed fetching OIDC configuration: ", err)
232
+			return errors.New("FetchOidcEndpointError")
304 233
 		}
305
-	}(resp.Body)
306
-
307
-	// 解析用户信息
308
-	if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil {
309
-		global.Logger.Warn("failed decoding user info: ", err)
310
-		error = errors.New("DecodeOauthUserInfoError")
311
-		return
234
+		userEndpoint = endpoint.UserInfo
312 235
 	}
313
-	return
314
-}
315 236
 
316
-func (os *OauthService) GoogleCallback(code string) (error error, userData *GoogleUserdata) {
317
-	err, oauthConfig := os.GetOauthConfig(model.OauthTypeGoogle)
318
-	if err != nil {
319
-		return err, nil
320
-	}
321
-
322
-	// 使用代理配置创建 HTTP 客户端
237
+	// 设置代理客户端
323 238
 	httpClient := getHTTPClientWithProxy()
324 239
 	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
325 240
 
241
+	// 使用 code 换取 token
326 242
 	token, err := oauthConfig.Exchange(ctx, code)
327 243
 	if err != nil {
328 244
 		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
329
-		error = errors.New("GetOauthTokenError")
330
-		return
245
+		return errors.New("GetOauthTokenError")
331 246
 	}
332 247
 
333
-	// 使用带有代理的 HTTP 客户端获取用户信息
248
+	// 获取用户信息
334 249
 	client := oauthConfig.Client(ctx, token)
335
-	resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
250
+	resp, err := client.Get(userEndpoint)
336 251
 	if err != nil {
337 252
 		global.Logger.Warn("failed getting user info: ", err)
338
-		error = errors.New("GetOauthUserInfoError")
339
-		return
253
+		return errors.New("GetOauthUserInfoError")
340 254
 	}
341
-	defer func(Body io.ReadCloser) {
342
-		err := Body.Close()
343
-		if err != nil {
344
-			global.Logger.Warn("failed closing response body: ", err)
255
+	defer func() {
256
+		if closeErr := resp.Body.Close(); closeErr != nil {
257
+			global.Logger.Warn("failed closing response body: ", closeErr)
345 258
 		}
346
-	}(resp.Body)
259
+	}()
347 260
 
348 261
 	// 解析用户信息
349
-	if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil {
262
+	if err = json.NewDecoder(resp.Body).Decode(userData); err != nil {
350 263
 		global.Logger.Warn("failed decoding user info: ", err)
351
-		error = errors.New("DecodeOauthUserInfoError")
352
-		return
264
+		return errors.New("DecodeOauthUserInfoError")
353 265
 	}
354
-	return
266
+
267
+	return nil
355 268
 }
356 269
 
357
-func (os *OauthService) OidcCallback(code string) (error error, userData *OidcUserdata) {
358
-	err, oauthConfig := os.GetOauthConfig(model.OauthTypeOidc)
359
-	if err != nil {
270
+// githubCallback github回调
271
+func (os *OauthService) githubCallback(code string) (error, *model.OauthUser) {
272
+	var user = &model.GithubUser{}
273
+	const userEndpoint = "https://api.github.com/user"
274
+	if err := os.callbackBase(model.OauthTypeGithub, code, userEndpoint, user); err != nil {
360 275
 		return err, nil
361 276
 	}
362
-	// 使用代理配置创建 HTTP 客户端
363
-	httpClient := getHTTPClientWithProxy()
364
-	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
365
-
366
-	token, err := oauthConfig.Exchange(ctx, code)
367
-	if err != nil {
368
-		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
369
-		error = errors.New("GetOauthTokenError")
370
-		return
371
-	}
372
-
373
-	// 使用带有代理的 HTTP 客户端获取用户信息
374
-	client := oauthConfig.Client(ctx, token)
375
-	g := os.InfoByOp(model.OauthTypeOidc)
376
-	err, endpoint := FetchOidcConfig(g.Issuer)
377
-	if err != nil {
378
-		global.Logger.Warn("failed fetching OIDC configuration: ", err)
379
-		error = errors.New("FetchOidcConfigError")
380
-		return
381
-	}
382
-	resp, err := client.Get(endpoint.UserInfo)
383
-	if err != nil {
384
-		global.Logger.Warn("failed getting user info: ", err)
385
-		error = errors.New("GetOauthUserInfoError")
386
-		return
387
-	}
388
-	defer func(Body io.ReadCloser) {
389
-		err := Body.Close()
390
-		if err != nil {
391
-			global.Logger.Warn("failed closing response body: ", err)
392
-		}
393
-	}(resp.Body)
277
+	return nil, user.ToOauthUser()
278
+}
394 279
 
395
-	// 解析用户信息
396
-	if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil {
397
-		global.Logger.Warn("failed decoding user info: ", err)
398
-		error = errors.New("DecodeOauthUserInfoError")
399
-		return
280
+// googleCallback google回调
281
+func (os *OauthService) googleCallback(code string) (error, *model.OauthUser) {
282
+	var user = &model.GoogleUser{}
283
+	const userEndpoint = "https://www.googleapis.com/oauth2/v2/userinfo"
284
+	if err := os.callbackBase(model.OauthTypeGoogle, code, userEndpoint, user); err != nil {
285
+		return err, nil
400 286
 	}
401
-	return
287
+	return nil, user.ToOauthUser()
402 288
 }
403 289
 
404
-func (os *OauthService) UserThirdInfo(op, openid string) *model.UserThird {
405
-	ut := &model.UserThird{}
406
-	global.DB.Where("open_id = ? and third_type = ?", openid, op).First(ut)
407
-	return ut
290
+// oidcCallback oidc回调, 通过code获取用户信息
291
+func (os *OauthService) oidcCallback(code string, op string) (error, *model.OauthUser,) {
292
+	var user = &model.OidcUser{}
293
+	if err := os.callbackBase(op, code, "", user); err != nil {
294
+		return err, nil
295
+	}
296
+	return nil, user.ToOauthUser()
408 297
 }
409 298
 
410
-func (os *OauthService) BindGithubUser(openid, username string, userId uint) error {
411
-	return os.BindOauthUser(model.OauthTypeGithub, openid, username, userId)
299
+// Callback: Get user information by code and op(Oauth provider)
300
+func (os *OauthService) Callback(code string, op string) (err error, oauthUser *model.OauthUser) {
301
+    oauthType := os.GetTypeByOp(op)
302
+    if err = os.ValidateOauthType(oauthType); err != nil {
303
+        return err, nil
304
+    }
305
+    
306
+    switch oauthType {
307
+    case model.OauthTypeGithub:
308
+        err, oauthUser = os.githubCallback(code)
309
+    case model.OauthTypeGoogle:
310
+        err, oauthUser = os.googleCallback(code)
311
+    case model.OauthTypeOidc:
312
+        err, oauthUser = os.oidcCallback(code, op)
313
+    default:
314
+        return errors.New("unsupported OAuth type"), nil
315
+    }
316
+    
317
+    return err, oauthUser
412 318
 }
413 319
 
414
-func (os *OauthService) BindGoogleUser(email, username string, userId uint) error {
415
-	return os.BindOauthUser(model.OauthTypeGoogle, email, username, userId)
416
-}
417 320
 
418
-func (os *OauthService) BindOidcUser(sub, username string, userId uint) error {
419
-	return os.BindOauthUser(model.OauthTypeOidc, sub, username, userId)
321
+func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird {
322
+	ut := &model.UserThird{}
323
+	global.DB.Where("open_id = ? and op = ?", openId, op).First(ut)
324
+	return ut
420 325
 }
421 326
 
422
-func (os *OauthService) BindOauthUser(thirdType, openid, username string, userId uint) error {
423
-	utr := &model.UserThird{
424
-		OpenId:    openid,
425
-		ThirdType: thirdType,
426
-		ThirdName: username,
427
-		UserId:    userId,
428
-	}
327
+// BindOauthUser: Bind third party account
328
+func (os *OauthService) BindOauthUser(userId uint, oauthUser *model.OauthUser, op string) error {
329
+	utr := &model.UserThird{}
330
+	oauthType := os.GetTypeByOp(op)
331
+	utr.FromOauthUser(userId, oauthUser, oauthType, op)
429 332
 	return global.DB.Create(utr).Error
430 333
 }
431 334
 
432
-func (os *OauthService) UnBindGithubUser(userid uint) error {
433
-	return os.UnBindThird(model.OauthTypeGithub, userid)
335
+// UnBindOauthUser: Unbind third party account
336
+func (os *OauthService) UnBindOauthUser(userId uint, op string) error {
337
+	return os.UnBindThird(op, userId)
434 338
 }
435
-func (os *OauthService) UnBindGoogleUser(userid uint) error {
436
-	return os.UnBindThird(model.OauthTypeGoogle, userid)
437
-}
438
-func (os *OauthService) UnBindOidcUser(userid uint) error {
439
-	return os.UnBindThird(model.OauthTypeOidc, userid)
440
-}
441
-func (os *OauthService) UnBindThird(thirdType string, userid uint) error {
442
-	return global.DB.Where("user_id = ? and third_type = ?", userid, thirdType).Delete(&model.UserThird{}).Error
339
+
340
+// UnBindThird: Unbind third party account
341
+func (os *OauthService) UnBindThird(op string, userId uint) error {
342
+	return global.DB.Where("user_id = ? and op = ?", userId, op).Delete(&model.UserThird{}).Error
443 343
 }
444 344
 
445 345
 // DeleteUserByUserId: When user is deleted, delete all third party bindings
446
-func (os *OauthService) DeleteUserByUserId(userid uint) error {
447
-	return global.DB.Where("user_id = ?", userid).Delete(&model.UserThird{}).Error
346
+func (os *OauthService) DeleteUserByUserId(userId uint) error {
347
+	return global.DB.Where("user_id = ?", userId).Delete(&model.UserThird{}).Error
448 348
 }
449 349
 
450
-// InfoById 根据id取用户信息
350
+// InfoById 根据id获取Oauth信息
451 351
 func (os *OauthService) InfoById(id uint) *model.Oauth {
452
-	u := &model.Oauth{}
453
-	global.DB.Where("id = ?", id).First(u)
454
-	return u
352
+	oauthInfo := &model.Oauth{}
353
+	global.DB.Where("id = ?", id).First(oauthInfo)
354
+	return oauthInfo
455 355
 }
456 356
 
457
-// InfoByOp 根据op取用户信息
357
+// InfoByOp 根据op获取Oauth信息
458 358
 func (os *OauthService) InfoByOp(op string) *model.Oauth {
459
-	u := &model.Oauth{}
460
-	global.DB.Where("op = ?", op).First(u)
461
-	return u
359
+	oauthInfo := &model.Oauth{}
360
+	global.DB.Where("op = ?", op).First(oauthInfo)
361
+	return oauthInfo
462 362
 }
363
+
364
+// Helper function to get scopes by operation
365
+func (os *OauthService) getScopesByOp(op string) []string {
366
+    scopes := os.InfoByOp(op).Scopes
367
+    scopes = strings.TrimSpace(scopes) // 这里使用 `=` 而不是 `:=`,避免重新声明变量
368
+    if scopes == "" {
369
+        scopes = "openid,profile,email"
370
+    }
371
+    return strings.Split(scopes, ",")
372
+}
373
+
374
+
463 375
 func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *model.OauthList) {
464 376
 	res = &model.OauthList{}
465 377
 	res.Page = int64(page)
@@ -474,16 +386,41 @@ func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res
474 386
 	return
475 387
 }
476 388
 
389
+// GetTypeByOp 根据op获取OauthType
390
+func (os *OauthService) GetTypeByOp(op string) string {
391
+	oauthInfo := &model.Oauth{}
392
+	if global.DB.Where("op = ?", op).First(oauthInfo).Error != nil {
393
+		return ""
394
+	}
395
+	return oauthInfo.OauthType
396
+}
397
+
398
+func (os *OauthService) ValidateOauthProvider(op string) error {
399
+	oauthInfo := &model.Oauth{}
400
+    // 使用 Gorm 的 Take 方法查找符合条件的记录
401
+    if err := global.DB.Where("op = ?", op).Take(oauthInfo).Error; err != nil {
402
+        return fmt.Errorf("OAuth provider with op '%s' not found: %w", op, err)
403
+    }
404
+    return nil
405
+}
406
+
477 407
 // Create 创建
478
-func (os *OauthService) Create(u *model.Oauth) error {
479
-	res := global.DB.Create(u).Error
408
+func (os *OauthService) Create(oauthInfo *model.Oauth) error {
409
+	res := global.DB.Create(oauthInfo).Error
480 410
 	return res
481 411
 }
482
-func (os *OauthService) Delete(u *model.Oauth) error {
483
-	return global.DB.Delete(u).Error
412
+func (os *OauthService) Delete(oauthInfo *model.Oauth) error {
413
+	return global.DB.Delete(oauthInfo).Error
484 414
 }
485 415
 
486 416
 // Update 更新
487
-func (os *OauthService) Update(u *model.Oauth) error {
488
-	return global.DB.Model(u).Updates(u).Error
417
+func (os *OauthService) Update(oauthInfo *model.Oauth) error {
418
+	return global.DB.Model(oauthInfo).Updates(oauthInfo).Error
489 419
 }
420
+
421
+// GetOauthProviders 获取所有的provider
422
+func (os *OauthService) GetOauthProviders() []string {
423
+	var res []string
424
+	global.DB.Model(&model.Oauth{}).Pluck("op", &res)
425
+	return res
426
+}

+ 37 - 59
service/user.go

@@ -21,12 +21,20 @@ func (us *UserService) InfoById(id uint) *model.User {
21 21
 	global.DB.Where("id = ?", id).First(u)
22 22
 	return u
23 23
 }
24
+// InfoByUsername 根据用户名取用户信息
24 25
 func (us *UserService) InfoByUsername(un string) *model.User {
25 26
 	u := &model.User{}
26 27
 	global.DB.Where("username = ?", un).First(u)
27 28
 	return u
28 29
 }
29 30
 
31
+// InfoByEmail 根据邮箱取用户信息
32
+func (us *UserService) InfoByEmail(email string) *model.User {
33
+	u := &model.User{}
34
+	global.DB.Where("email = ?", email).First(u)
35
+	return u
36
+}
37
+
30 38
 // InfoByOpenid 根据openid取用户信息
31 39
 func (us *UserService) InfoByOpenid(openid string) *model.User {
32 40
 	u := &model.User{}
@@ -216,24 +224,9 @@ func (us *UserService) RouteNames(u *model.User) []string {
216 224
 	return adResp.UserRouteNames
217 225
 }
218 226
 
219
-// InfoByGithubId 根据githubid取用户信息
220
-func (us *UserService) InfoByGithubId(githubId string) *model.User {
221
-	return us.InfoByOauthId(model.OauthTypeGithub, githubId)
222
-}
223
-
224
-// InfoByGoogleEmail 根据googleid取用户信息
225
-func (us *UserService) InfoByGoogleEmail(email string) *model.User {
226
-	return us.InfoByOauthId(model.OauthTypeGithub, email)
227
-}
228
-
229
-// InfoByOidcSub 根据oidc取用户信息
230
-func (us *UserService) InfoByOidcSub(sub string) *model.User {
231
-	return us.InfoByOauthId(model.OauthTypeOidc, sub)
232
-}
233
-
234
-// InfoByOauthId 根据oauth取用户信息
235
-func (us *UserService) InfoByOauthId(thirdType, uid string) *model.User {
236
-	ut := AllService.OauthService.UserThirdInfo(thirdType, uid)
227
+// InfoByOauthId 根据oauth的name和openId取用户信息
228
+func (us *UserService) InfoByOauthId(op string, openId string) *model.User {
229
+	ut := AllService.OauthService.UserThirdInfo(op, openId)
237 230
 	if ut.Id == 0 {
238 231
 		return nil
239 232
 	}
@@ -244,55 +237,40 @@ func (us *UserService) InfoByOauthId(thirdType, uid string) *model.User {
244 237
 	return u
245 238
 }
246 239
 
247
-// RegisterByGithub 注册
248
-func (us *UserService) RegisterByGithub(githubName string, githubId string) *model.User {
249
-	return us.RegisterByOauth(model.OauthTypeGithub, githubName, githubId)
250
-}
251
-
252
-// RegisterByGoogle 注册
253
-func (us *UserService) RegisterByGoogle(name string, email string) *model.User {
254
-	return us.RegisterByOauth(model.OauthTypeGoogle, name, email)
255
-}
256
-
257
-// RegisterByOidc 注册, use PreferredUsername as username, sub as openid
258
-func (us *UserService) RegisterByOidc(PreferredUsername string, sub string) *model.User {
259
-	return us.RegisterByOauth(model.OauthTypeOidc, PreferredUsername, sub)
260
-}
261
-
262 240
 // RegisterByOauth 注册
263
-func (us *UserService) RegisterByOauth(thirdType, thirdName, uid string) *model.User {
241
+func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser , op string) *model.User {
264 242
 	global.Lock.Lock("registerByOauth")
265 243
 	defer global.Lock.UnLock("registerByOauth")
266
-	ut := AllService.OauthService.UserThirdInfo(thirdType, uid)
244
+	ut := AllService.OauthService.UserThirdInfo(op, oauthUser.OpenId)
267 245
 	if ut.Id != 0 {
268
-		u := &model.User{}
269
-		global.DB.Where("id = ?", ut.UserId).First(u)
270
-		return u
246
+		return us.InfoById(ut.UserId)
271 247
 	}
272
-
248
+	//check if this email has been registered 
249
+	email := oauthUser.Email
250
+	oauthType := AllService.OauthService.GetTypeByOp(op)
251
+	user := us.InfoByEmail(email)
273 252
 	tx := global.DB.Begin()
274
-	ut = &model.UserThird{
275
-		OpenId:    uid,
276
-		ThirdName: thirdName,
277
-		ThirdType: thirdType,
278
-	}
279
-
280
-	username := us.GenerateUsernameByOauth(thirdName)
281
-	u := &model.User{
282
-		Username: username,
283
-		GroupId:  1,
253
+	if user.Id != 0 {
254
+		ut.FromOauthUser(user.Id, oauthUser, oauthType, op)
255
+	} else {
256
+		ut = &model.UserThird{}
257
+		ut.FromOauthUser(0, oauthUser, oauthType, op)
258
+		usernameUnique := us.GenerateUsernameByOauth(oauthUser.Username)
259
+		user := &model.User{
260
+			Username: usernameUnique,
261
+			GroupId:  1,
262
+		}
263
+		oauthUser.ToUser(user, false)
264
+		tx.Create(user)
265
+		if user.Id == 0 {
266
+			tx.Rollback()
267
+			return user
268
+		}
269
+		ut.UserId = user.Id
284 270
 	}
285
-	tx.Create(u)
286
-	if u.Id == 0 {
287
-		tx.Rollback()
288
-		return u
289
-	}
290
-
291
-	ut.UserId = u.Id
292 271
 	tx.Create(ut)
293
-
294 272
 	tx.Commit()
295
-	return u
273
+	return user
296 274
 }
297 275
 
298 276
 // GenerateUsernameByOauth 生成用户名
@@ -314,7 +292,7 @@ func (us *UserService) UserThirdsByUserId(userId uint) (res []*model.UserThird)
314 292
 
315 293
 func (us *UserService) UserThirdInfo(userId uint, op string) *model.UserThird {
316 294
 	ut := &model.UserThird{}
317
-	global.DB.Where("user_id = ? and third_type = ?", userId, op).First(ut)
295
+	global.DB.Where("user_id = ? and op = ?", userId, op).First(ut)
318 296
 	return ut
319 297
 }
320 298