Tao Chen 1 год назад
Родитель
Сommit
6698877761
2 измененных файлов с 143 добавлено и 76 удалено
  1. 19 1
      model/oauth.go
  2. 124 75
      service/oauth.go

+ 19 - 1
model/oauth.go

@@ -13,6 +13,18 @@ const (
13 13
 	OauthTypeWebauth string = "webauth"
14 14
 )
15 15
 
16
+const (
17
+	OauthNameGithub  string = "GitHub"
18
+	OauthNameGoogle  string = "Google"
19
+	OauthNameOidc    string = "OIDC"
20
+	OauthNameWebauth string = "WebAuth"
21
+)
22
+
23
+const (
24
+	UserEndpointGithub  string = "https://api.github.com/user"
25
+	UserEndpointGoogle  string = "https://www.googleapis.com/oauth2/v3/userinfo"
26
+	UserEndpointOidc    string = ""
27
+)
16 28
 
17 29
 type Oauth struct {
18 30
 	IdModel
@@ -33,6 +45,7 @@ type OauthUser struct {
33 45
 	Username 		string 	`json:"username"`
34 46
 	Email  			string 	`json:"email"`
35 47
 	VerifiedEmail 	bool 	`json:"verified_email,omitempty"`
48
+	Picture			string 	`json:"picture,omitempty"`
36 49
 }
37 50
 
38 51
 func (ou *OauthUser) ToUser(user *User, overideUsername bool) {
@@ -56,6 +69,7 @@ type OidcUser struct {
56 69
 	Sub               string `json:"sub"`
57 70
 	VerifiedEmail     bool   `json:"email_verified"`
58 71
 	PreferredUsername string `json:"preferred_username"`
72
+	Picture		   	  string `json:"picture"`
59 73
 }
60 74
 
61 75
 func (ou *OidcUser) ToOauthUser() *OauthUser {
@@ -65,6 +79,7 @@ func (ou *OidcUser) ToOauthUser() *OauthUser {
65 79
 		Username: 		ou.PreferredUsername,
66 80
 		Email:  		ou.Email,
67 81
 		VerifiedEmail: 	ou.VerifiedEmail,
82
+		Picture:		ou.Picture,
68 83
 	}
69 84
 }
70 85
 
@@ -84,6 +99,7 @@ func (gu *GoogleUser) ToOauthUser() *OauthUser {
84 99
 		Username: 		gu.GivenName,
85 100
 		Email:  		gu.Email,
86 101
 		VerifiedEmail: 	gu.VerifiedEmail,
102
+		Picture:		gu.Picture,
87 103
 	}	
88 104
 }
89 105
 
@@ -92,6 +108,8 @@ type GithubUser struct {
92 108
 	OauthUserBase
93 109
 	Id                int         `json:"id"`
94 110
 	Login             string      `json:"login"`
111
+	AvatarUrl         string      `json:"avatar_url"`
112
+	VerifiedEmail	  bool        `json:"verified_email"`
95 113
 }
96 114
 
97 115
 func (gu *GithubUser) ToOauthUser() *OauthUser {
@@ -100,7 +118,7 @@ func (gu *GithubUser) ToOauthUser() *OauthUser {
100 118
 		Name:   		gu.Name,
101 119
 		Username: 		gu.Login,
102 120
 		Email:  		gu.Email,
103
-		VerifiedEmail: 	true,
121
+		VerifiedEmail: 	gu.VerifiedEmail,
104 122
 	}
105 123
 }
106 124
 

+ 124 - 75
service/oauth.go

@@ -106,15 +106,14 @@ func (os *OauthService) DeleteOauthCache(key string) {
106 106
 
107 107
 func (os *OauthService) BeginAuth(op string) (error error, code, url string) {
108 108
 	code = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
109
-
110 109
 	if op == string(model.OauthTypeWebauth) {
111 110
 		url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + code
112 111
 		//url = "http://localhost:8888/_admin/#/oauth/" + code
113 112
 		return nil, code, url
114 113
 	}
115
-	err, _, conf := os.GetOauthConfig(op)
114
+	err, _, oauthConfig := os.GetOauthConfig(op)
116 115
 	if err == nil {
117
-		return err, code, conf.AuthCodeURL(code)
116
+		return err, code, oauthConfig.AuthCodeURL(code)
118 117
 	}
119 118
 
120 119
 	return err, code, ""
@@ -154,16 +153,17 @@ func (os *OauthService) FetchOidcEndpointByOp(op string) (error, OidcEndpoint) {
154 153
 }
155 154
 
156 155
 // GetOauthConfig retrieves the OAuth2 configuration based on the provider name
157
-func (os *OauthService) GetOauthConfig(op string) (err error, oauthType string, oauthConfig *oauth2.Config) {
158
-	err = os.ValidateOauthProvider(op)
156
+func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) {
157
+	err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op)
159 158
 	if err != nil {
160
-		return err, "", nil
159
+		return err, nil, nil
161 160
 	}
162
-	err, oauthType, oauthConfig = os.getOauthConfigGeneral(op)
161
+	// Maybe should validate the oauthConfig here
162
+	oauthType := oauthInfo.OauthType
163
+	err = os.ValidateOauthType(oauthType)
163 164
 	if err != nil {
164
-		return err, oauthType, nil
165
+		return err, nil, nil
165 166
 	}
166
-	// Maybe should validate the oauthConfig here
167 167
 	switch oauthType {
168 168
 	case model.OauthTypeGithub:
169 169
 		oauthConfig.Endpoint = github.Endpoint
@@ -172,32 +172,33 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthType string,
172 172
 		oauthConfig.Endpoint = google.Endpoint
173 173
 		oauthConfig.Scopes = []string{"https://www.googleapis.com/auth/userinfo.profile", "https://www.googleapis.com/auth/userinfo.email"}
174 174
 	case model.OauthTypeOidc:
175
-		err, endpoint := os.FetchOidcEndpointByOp(op)
175
+		var endpoint OidcEndpoint
176
+		err, endpoint = os.FetchOidcEndpoint(oauthInfo.Issuer)
176 177
 		if err != nil {
177
-			return err,oauthType, nil
178
+			return err, nil, nil
178 179
 		}
179 180
 		oauthConfig.Endpoint = oauth2.Endpoint{AuthURL:  endpoint.AuthURL,TokenURL: endpoint.TokenURL,}
180
-		oauthConfig.Scopes = os.getScopesByOp(op)
181
+		oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
181 182
 	default:
182
-		return errors.New("unsupported OAuth type"), oauthType, nil
183
+		return errors.New("unsupported OAuth type"), nil, nil
183 184
 	}
184
-	return nil, oauthType, oauthConfig
185
+	return nil, oauthInfo, oauthConfig
185 186
 }
186 187
 
187 188
 // GetOauthConfig retrieves the OAuth2 configuration based on the provider name
188
-func (os *OauthService) getOauthConfigGeneral(op string) (err error, oauthType string, oauthConfig *oauth2.Config) {
189
-	g := os.InfoByOp(op)
190
-	if g.Id == 0 || g.ClientId == "" || g.ClientSecret == "" {
191
-		return errors.New("ConfigNotFound"), "", nil
189
+func (os *OauthService) getOauthConfigGeneral(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) {
190
+	oauthInfo = os.InfoByOp(op)
191
+	if oauthInfo.Id == 0 || oauthInfo.ClientId == "" || oauthInfo.ClientSecret == "" {
192
+		return errors.New("ConfigNotFound"), nil, nil
192 193
 	}
193 194
 	// If the redirect URL is empty, use the default redirect URL
194
-	if g.RedirectUrl == "" {
195
-		g.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback"
195
+	if oauthInfo.RedirectUrl == "" {
196
+		oauthInfo.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback"
196 197
 	}
197
-	return nil, g.OauthType, &oauth2.Config{
198
-		ClientID:     g.ClientId,
199
-		ClientSecret: g.ClientSecret,
200
-		RedirectURL:  g.RedirectUrl,
198
+	return nil, oauthInfo, &oauth2.Config{
199
+		ClientID:     oauthInfo.ClientId,
200
+		ClientSecret: oauthInfo.ClientSecret,
201
+		RedirectURL:  oauthInfo.RedirectUrl,
201 202
 	}
202 203
 }
203 204
 
@@ -221,40 +222,26 @@ func getHTTPClientWithProxy() *http.Client {
221 222
 	return http.DefaultClient
222 223
 }
223 224
 
224
-func (os *OauthService) callbackBase(op string, code string, userEndpoint string, userData interface{}) error {
225
-	err, oauthType, oauthConfig := os.GetOauthConfig(op)
226
-	if err != nil {
227
-		return err
228
-	}
229
-	
230
-	// If the OAuth type is OIDC and the user endpoint is empty
231
-	// Fetch the OIDC configuration and get the user endpoint
232
-	if oauthType == model.OauthTypeOidc && userEndpoint == "" {
233
-		err, endpoint := os.FetchOidcEndpointByOp(op)
234
-		if err != nil {
235
-			global.Logger.Warn("failed fetching OIDC configuration: ", err)
236
-			return errors.New("FetchOidcEndpointError")
237
-		}
238
-		userEndpoint = endpoint.UserInfo
239
-	}
225
+func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, userEndpoint string, userData interface{}) (err error, client *http.Client) {
240 226
 
241 227
 	// 设置代理客户端
242 228
 	httpClient := getHTTPClientWithProxy()
243 229
 	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
244 230
 
245 231
 	// 使用 code 换取 token
246
-	token, err := oauthConfig.Exchange(ctx, code)
232
+	var token *oauth2.Token
233
+	token, err = oauthConfig.Exchange(ctx, code)
247 234
 	if err != nil {
248 235
 		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
249
-		return errors.New("GetOauthTokenError")
236
+		return errors.New("GetOauthTokenError"), nil
250 237
 	}
251 238
 
252 239
 	// 获取用户信息
253
-	client := oauthConfig.Client(ctx, token)
240
+	client = oauthConfig.Client(ctx, token)
254 241
 	resp, err := client.Get(userEndpoint)
255 242
 	if err != nil {
256 243
 		global.Logger.Warn("failed getting user info: ", err)
257
-		return errors.New("GetOauthUserInfoError")
244
+		return errors.New("GetOauthUserInfoError"), nil
258 245
 	}
259 246
 	defer func() {
260 247
 		if closeErr := resp.Body.Close(); closeErr != nil {
@@ -265,36 +252,39 @@ func (os *OauthService) callbackBase(op string, code string, userEndpoint string
265 252
 	// 解析用户信息
266 253
 	if err = json.NewDecoder(resp.Body).Decode(userData); err != nil {
267 254
 		global.Logger.Warn("failed decoding user info: ", err)
268
-		return errors.New("DecodeOauthUserInfoError")
255
+		return errors.New("DecodeOauthUserInfoError"), nil
269 256
 	}
270 257
 
271
-	return nil
258
+	return nil, client
272 259
 }
273 260
 
274 261
 // githubCallback github回调
275
-func (os *OauthService) githubCallback(code string) (error, *model.OauthUser) {
262
+func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string) (error, *model.OauthUser) {
276 263
 	var user = &model.GithubUser{}
277
-	const userEndpoint = "https://api.github.com/user"
278
-	if err := os.callbackBase(model.OauthTypeGithub, code, userEndpoint, user); err != nil {
264
+	err, client := os.callbackBase(oauthConfig, code, model.UserEndpointGithub, user)
265
+	if err != nil {
266
+		return err, nil
267
+	}
268
+	err = os.getGithubPrimaryEmail(client, user)
269
+	if err != nil {
279 270
 		return err, nil
280 271
 	}
281 272
 	return nil, user.ToOauthUser()
282 273
 }
283 274
 
284 275
 // googleCallback google回调
285
-func (os *OauthService) googleCallback(code string) (error, *model.OauthUser) {
276
+func (os *OauthService) googleCallback(oauthConfig *oauth2.Config, code string) (error, *model.OauthUser) {
286 277
 	var user = &model.GoogleUser{}
287
-	const userEndpoint = "https://www.googleapis.com/oauth2/v2/userinfo"
288
-	if err := os.callbackBase(model.OauthTypeGoogle, code, userEndpoint, user); err != nil {
278
+	if err, _ := os.callbackBase(oauthConfig, code, model.UserEndpointGoogle, user); err != nil {
289 279
 		return err, nil
290 280
 	}
291 281
 	return nil, user.ToOauthUser()
292 282
 }
293 283
 
294 284
 // oidcCallback oidc回调, 通过code获取用户信息
295
-func (os *OauthService) oidcCallback(code string, op string) (error, *model.OauthUser,) {
285
+func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser,) {
296 286
 	var user = &model.OidcUser{}
297
-	if err := os.callbackBase(op, code, "", user); err != nil {
287
+	if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil {
298 288
 		return err, nil
299 289
 	}
300 290
 	return nil, user.ToOauthUser()
@@ -302,22 +292,28 @@ func (os *OauthService) oidcCallback(code string, op string) (error, *model.Oaut
302 292
 
303 293
 // Callback: Get user information by code and op(Oauth provider)
304 294
 func (os *OauthService) Callback(code string, op string) (err error, oauthUser *model.OauthUser) {
305
-    oauthType := os.GetTypeByOp(op)
306
-    if err = os.ValidateOauthType(oauthType); err != nil {
307
-        return err, nil
308
-    }
309
-    
310
-    switch oauthType {
295
+	var oauthInfo *model.Oauth
296
+	var oauthConfig *oauth2.Config
297
+	err, oauthInfo, oauthConfig = os.GetOauthConfig(op)
298
+	// oauthType is already validated in GetOauthConfig
299
+	if err != nil {
300
+		return err, nil
301
+	}
302
+	oauthType := oauthInfo.OauthType
303
+	switch oauthType {
311 304
     case model.OauthTypeGithub:
312
-        err, oauthUser = os.githubCallback(code)
305
+        err, oauthUser = os.githubCallback(oauthConfig, code)
313 306
     case model.OauthTypeGoogle:
314
-        err, oauthUser = os.googleCallback(code)
307
+        err, oauthUser = os.googleCallback(oauthConfig, code)
315 308
     case model.OauthTypeOidc:
316
-        err, oauthUser = os.oidcCallback(code, op)
309
+		err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
310
+		if err != nil {
311
+			return err, nil
312
+		}
313
+        err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo)
317 314
     default:
318 315
         return errors.New("unsupported OAuth type"), nil
319 316
     }
320
-    
321 317
     return err, oauthUser
322 318
 }
323 319
 
@@ -331,7 +327,10 @@ func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird
331 327
 // BindOauthUser: Bind third party account
332 328
 func (os *OauthService) BindOauthUser(userId uint, oauthUser *model.OauthUser, op string) error {
333 329
 	utr := &model.UserThird{}
334
-	oauthType := os.GetTypeByOp(op)
330
+	err, oauthType := os.GetTypeByOp(op)
331
+	if err != nil {
332
+		return err
333
+	}
335 334
 	utr.FromOauthUser(userId, oauthUser, oauthType, op)
336 335
 	return global.DB.Create(utr).Error
337 336
 }
@@ -368,14 +367,18 @@ func (os *OauthService) InfoByOp(op string) *model.Oauth {
368 367
 // Helper function to get scopes by operation
369 368
 func (os *OauthService) getScopesByOp(op string) []string {
370 369
     scopes := os.InfoByOp(op).Scopes
371
-    scopes = strings.TrimSpace(scopes) // 这里使用 `=` 而不是 `:=`,避免重新声明变量
370
+	return os.constructScopes(scopes)
371
+}
372
+
373
+// Helper function to construct scopes
374
+func (os *OauthService) constructScopes(scopes string) []string {
375
+    scopes = strings.TrimSpace(scopes)
372 376
     if scopes == "" {
373 377
         scopes = "openid,profile,email"
374 378
     }
375 379
     return strings.Split(scopes, ",")
376 380
 }
377 381
 
378
-
379 382
 func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *model.OauthList) {
380 383
 	res = &model.OauthList{}
381 384
 	res.Page = int64(page)
@@ -391,21 +394,30 @@ func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res
391 394
 }
392 395
 
393 396
 // GetTypeByOp 根据op获取OauthType
394
-func (os *OauthService) GetTypeByOp(op string) string {
397
+func (os *OauthService) GetTypeByOp(op string) (error, string) {
395 398
 	oauthInfo := &model.Oauth{}
396 399
 	if global.DB.Where("op = ?", op).First(oauthInfo).Error != nil {
397
-		return ""
400
+		return fmt.Errorf("OAuth provider with op '%s' not found", op), ""
398 401
 	}
399
-	return oauthInfo.OauthType
402
+	return nil, oauthInfo.OauthType
400 403
 }
401 404
 
405
+// ValidateOauthProvider 验证Oauth提供者是否正确
402 406
 func (os *OauthService) ValidateOauthProvider(op string) error {
407
+	if !os.IsOauthProviderExist(op) {
408
+		return fmt.Errorf("OAuth provider with op '%s' not found", op)
409
+	}
410
+	return nil
411
+}
412
+
413
+// IsOauthProviderExist 验证Oauth提供者是否存在
414
+func (os *OauthService) IsOauthProviderExist(op string) bool {
403 415
 	oauthInfo := &model.Oauth{}
404
-    // 使用 Gorm 的 Take 方法查找符合条件的记录
405
-    if err := global.DB.Where("op = ?", op).Take(oauthInfo).Error; err != nil {
406
-        return fmt.Errorf("OAuth provider with op '%s' not found: %w", op, err)
407
-    }
408
-    return nil
416
+	// 使用 Gorm 的 Take 方法查找符合条件的记录
417
+	if err := global.DB.Where("op = ?", op).Take(oauthInfo).Error; err != nil {
418
+		return false
419
+	}
420
+	return true
409 421
 }
410 422
 
411 423
 // Create 创建
@@ -427,4 +439,41 @@ func (os *OauthService) GetOauthProviders() []string {
427 439
 	var res []string
428 440
 	global.DB.Model(&model.Oauth{}).Pluck("op", &res)
429 441
 	return res
442
+}
443
+
444
+// getGithubPrimaryEmail: Get the primary email of the user from Github
445
+func (os *OauthService) getGithubPrimaryEmail(client *http.Client, githubUser *model.GithubUser) error {
446
+	// the client is already set with the token
447
+	resp, err := client.Get("https://api.github.com/user/emails")
448
+	if err != nil {
449
+		return fmt.Errorf("failed to fetch emails: %w", err)
450
+	}
451
+	defer resp.Body.Close()
452
+
453
+	// check the response status code
454
+	if resp.StatusCode != http.StatusOK {
455
+		return fmt.Errorf("failed to fetch emails: %s", resp.Status)
456
+	}
457
+
458
+	// decode the response
459
+	var emails []struct {
460
+		Email    string `json:"email"`
461
+		Primary  bool   `json:"primary"`
462
+		Verified bool   `json:"verified"`
463
+	}
464
+
465
+	if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil {
466
+		return fmt.Errorf("failed to decode response: %w", err)
467
+	}
468
+
469
+	// find the primary verified email
470
+	for _, e := range emails {
471
+		if e.Primary && e.Verified {
472
+			githubUser.Email = e.Email
473
+			githubUser.VerifiedEmail = e.Verified
474
+			return nil
475
+		}
476
+	}
477
+
478
+	return fmt.Errorf("no primary verified email found")
430 479
 }