Tao Chen месяцев назад: 11
Родитель
Сommit
c75320f4f4

+ 7 - 1
cmd/apimain.go

@@ -166,7 +166,7 @@ func InitGlobal() {
166
 	global.Lock = lock.NewLocal()
166
 	global.Lock = lock.NewLocal()
167
 }
167
 }
168
 func DatabaseAutoUpdate() {
168
 func DatabaseAutoUpdate() {
169
-	version := 260
169
+	version := 261
170
 
170
 
171
 	db := global.DB
171
 	db := global.DB
172
 
172
 
@@ -210,6 +210,12 @@ func DatabaseAutoUpdate() {
210
 		if v.Version < uint(version) {
210
 		if v.Version < uint(version) {
211
 			Migrate(uint(version))
211
 			Migrate(uint(version))
212
 		}
212
 		}
213
+		// 261迁移
214
+		if v.Version < 261 {
215
+			// 在oauths表中添加pkce_enable 和 pkce_method 字段
216
+			db.Exec("ALTER TABLE oauths ADD COLUMN pkce_enable TINYINT(1) NOT NULL DEFAULT 0")
217
+			db.Exec("ALTER TABLE oauths ADD COLUMN pkce_method VARCHAR(20) NOT NULL DEFAULT 'S256'")
218
+		}
213
 		// 245迁移
219
 		// 245迁移
214
 		if v.Version < 245 {
220
 		if v.Version < 245 {
215
 			//oauths 表的 oauth_type 字段设置为 op同样的值
221
 			//oauths 表的 oauth_type 字段设置为 op同样的值

+ 4 - 3
http/controller/admin/login.go

@@ -283,13 +283,13 @@ func (ct *Login) OidcAuth(c *gin.Context) {
283
 		return
283
 		return
284
 	}
284
 	}
285
 
285
 
286
-	err, code, url := service.AllService.OauthService.BeginAuth(f.Op)
286
+	err, state, verifier, url := service.AllService.OauthService.BeginAuth(f.Op)
287
 	if err != nil {
287
 	if err != nil {
288
 		response.Error(c, response.TranslateMsg(c, err.Error()))
288
 		response.Error(c, response.TranslateMsg(c, err.Error()))
289
 		return
289
 		return
290
 	}
290
 	}
291
 
291
 
292
-	service.AllService.OauthService.SetOauthCache(code, &service.OauthCacheItem{
292
+	service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{
293
 		Action:     service.OauthActionTypeLogin,
293
 		Action:     service.OauthActionTypeLogin,
294
 		Op:         f.Op,
294
 		Op:         f.Op,
295
 		Id:         f.Id,
295
 		Id:         f.Id,
@@ -297,10 +297,11 @@ func (ct *Login) OidcAuth(c *gin.Context) {
297
 		// DeviceOs: ct.Platform(c),
297
 		// DeviceOs: ct.Platform(c),
298
 		DeviceOs: f.DeviceInfo.Os,
298
 		DeviceOs: f.DeviceInfo.Os,
299
 		Uuid:     f.Uuid,
299
 		Uuid:     f.Uuid,
300
+		Verifier: verifier,
300
 	}, 5*60)
301
 	}, 5*60)
301
 
302
 
302
 	response.Success(c, gin.H{
303
 	response.Success(c, gin.H{
303
-		"code": code,
304
+		"code": state,
304
 		"url":  url,
305
 		"url":  url,
305
 	})
306
 	})
306
 }
307
 }

+ 6 - 5
http/controller/admin/oauth.go

@@ -43,20 +43,21 @@ func (o *Oauth) ToBind(c *gin.Context) {
43
 		return
43
 		return
44
 	}
44
 	}
45
 
45
 
46
-	err, code, url := service.AllService.OauthService.BeginAuth(f.Op)
46
+	err, state, verifier, url := service.AllService.OauthService.BeginAuth(f.Op)
47
 	if err != nil {
47
 	if err != nil {
48
 		response.Error(c, response.TranslateMsg(c, err.Error()))
48
 		response.Error(c, response.TranslateMsg(c, err.Error()))
49
 		return
49
 		return
50
 	}
50
 	}
51
 
51
 
52
-	service.AllService.OauthService.SetOauthCache(code, &service.OauthCacheItem{
52
+	service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{
53
 		Action: service.OauthActionTypeBind,
53
 		Action: service.OauthActionTypeBind,
54
-		Op:     f.Op,
55
-		UserId: u.Id,
54
+		Op:     	f.Op,
55
+		UserId: 	u.Id,
56
+		Verifier: 	verifier,
56
 	}, 5*60)
57
 	}, 5*60)
57
 
58
 
58
 	response.Success(c, gin.H{
59
 	response.Success(c, gin.H{
59
-		"code": code,
60
+		"code": state,
60
 		"url":  url,
61
 		"url":  url,
61
 	})
62
 	})
62
 }
63
 }

+ 8 - 5
http/controller/api/ouath.go

@@ -32,15 +32,16 @@ func (o *Oauth) OidcAuth(c *gin.Context) {
32
 	}
32
 	}
33
 
33
 
34
 	oauthService := service.AllService.OauthService
34
 	oauthService := service.AllService.OauthService
35
-	var code string
35
+	var state string
36
 	var url string
36
 	var url string
37
-	err, code, url = oauthService.BeginAuth(f.Op)
37
+	var verifier string
38
+	err, state, verifier, url = oauthService.BeginAuth(f.Op)
38
 	if err != nil {
39
 	if err != nil {
39
 		response.Error(c, response.TranslateMsg(c, err.Error()))
40
 		response.Error(c, response.TranslateMsg(c, err.Error()))
40
 		return
41
 		return
41
 	}
42
 	}
42
 
43
 
43
-	service.AllService.OauthService.SetOauthCache(code, &service.OauthCacheItem{
44
+	service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{
44
 		Action:     service.OauthActionTypeLogin,
45
 		Action:     service.OauthActionTypeLogin,
45
 		Id:         f.Id,
46
 		Id:         f.Id,
46
 		Op:         f.Op,
47
 		Op:         f.Op,
@@ -48,10 +49,11 @@ func (o *Oauth) OidcAuth(c *gin.Context) {
48
 		DeviceName: f.DeviceInfo.Name,
49
 		DeviceName: f.DeviceInfo.Name,
49
 		DeviceOs:   f.DeviceInfo.Os,
50
 		DeviceOs:   f.DeviceInfo.Os,
50
 		DeviceType: f.DeviceInfo.Type,
51
 		DeviceType: f.DeviceInfo.Type,
52
+		Verifier:   verifier,
51
 	}, 5*60)
53
 	}, 5*60)
52
 	//fmt.Println("code url", code, url)
54
 	//fmt.Println("code url", code, url)
53
 	c.JSON(http.StatusOK, gin.H{
55
 	c.JSON(http.StatusOK, gin.H{
54
-		"code": code,
56
+		"code": state,
55
 		"url":  url,
57
 		"url":  url,
56
 	})
58
 	})
57
 }
59
 }
@@ -156,10 +158,11 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
156
 	}
158
 	}
157
 	op := oauthCache.Op
159
 	op := oauthCache.Op
158
 	action := oauthCache.Action
160
 	action := oauthCache.Action
161
+	verifier := oauthCache.Verifier
159
 	var user *model.User
162
 	var user *model.User
160
 	// 获取用户信息
163
 	// 获取用户信息
161
 	code := c.Query("code")
164
 	code := c.Query("code")
162
-	err, oauthUser := oauthService.Callback(code, op)
165
+	err, oauthUser := oauthService.Callback(code, verifier, op)
163
 	if err != nil {
166
 	if err != nil {
164
 		c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
167
 		c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
165
 		return
168
 		return

+ 4 - 0
http/request/admin/oauth.go

@@ -24,6 +24,8 @@ type OauthForm struct {
24
 	ClientSecret string `json:"client_secret" validate:"required"`
24
 	ClientSecret string `json:"client_secret" validate:"required"`
25
 	RedirectUrl  string `json:"redirect_url" validate:"required"`
25
 	RedirectUrl  string `json:"redirect_url" validate:"required"`
26
 	AutoRegister *bool  `json:"auto_register"`
26
 	AutoRegister *bool  `json:"auto_register"`
27
+	PkceEnable   *bool  `json:"pkce_enable"`
28
+	PkceMethod   string `json:"pkce_method"`
27
 }
29
 }
28
 
30
 
29
 func (of *OauthForm) ToOauth() *model.Oauth {
31
 func (of *OauthForm) ToOauth() *model.Oauth {
@@ -36,6 +38,8 @@ func (of *OauthForm) ToOauth() *model.Oauth {
36
 		AutoRegister: of.AutoRegister,
38
 		AutoRegister: of.AutoRegister,
37
 		Issuer:       of.Issuer,
39
 		Issuer:       of.Issuer,
38
 		Scopes:       of.Scopes,
40
 		Scopes:       of.Scopes,
41
+		PkceEnable:   of.PkceEnable,
42
+		PkceMethod:   of.PkceMethod,
39
 	}
43
 	}
40
 	oa.Id = of.Id
44
 	oa.Id = of.Id
41
 	return oa
45
 	return oa

+ 11 - 0
model/oauth.go

@@ -14,6 +14,8 @@ const (
14
 	OauthTypeGoogle  string = "google"
14
 	OauthTypeGoogle  string = "google"
15
 	OauthTypeOidc    string = "oidc"
15
 	OauthTypeOidc    string = "oidc"
16
 	OauthTypeWebauth string = "webauth"
16
 	OauthTypeWebauth string = "webauth"
17
+	PKCEMethodS256   string = "S256"
18
+	PKCEMethodPlain  string = "plain"
17
 )
19
 )
18
 
20
 
19
 // Validate the oauth type
21
 // Validate the oauth type
@@ -41,6 +43,8 @@ type Oauth struct {
41
 	AutoRegister *bool  `json:"auto_register"`
43
 	AutoRegister *bool  `json:"auto_register"`
42
 	Scopes       string `json:"scopes"`
44
 	Scopes       string `json:"scopes"`
43
 	Issuer       string `json:"issuer"`
45
 	Issuer       string `json:"issuer"`
46
+	PkceEnable	 *bool  `json:"pkce_enable"`
47
+	PkceMethod	 string `json:"pkce_method"`
44
 	TimeModel
48
 	TimeModel
45
 }
49
 }
46
 
50
 
@@ -68,6 +72,13 @@ func (oa *Oauth) FormatOauthInfo() error {
68
 	if oauthType == OauthTypeGoogle && issuer == "" {
72
 	if oauthType == OauthTypeGoogle && issuer == "" {
69
 		oa.Issuer = IssuerGoogle
73
 		oa.Issuer = IssuerGoogle
70
 	}
74
 	}
75
+	if oa.PkceEnable == nil {
76
+		oa.PkceEnable = new(bool)
77
+		*oa.PkceEnable = false
78
+	}
79
+	if oa.PkceMethod == "" {
80
+		oa.PkceMethod = PKCEMethodS256
81
+	}
71
 	return nil
82
 	return nil
72
 }
83
 }
73
 
84
 

+ 35 - 16
service/oauth.go

@@ -45,6 +45,7 @@ type OauthCacheItem struct {
45
 	Username   string `json:"username"`
45
 	Username   string `json:"username"`
46
 	Name       string `json:"name"`
46
 	Name       string `json:"name"`
47
 	Email      string `json:"email"`
47
 	Email      string `json:"email"`
48
+	Verifier   string `json:"verifier"`  // used for oauth pkce
48
 }
49
 }
49
 
50
 
50
 func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
51
 func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
@@ -92,19 +93,32 @@ func (os *OauthService) DeleteOauthCache(key string) {
92
 	OauthCache.Delete(key)
93
 	OauthCache.Delete(key)
93
 }
94
 }
94
 
95
 
95
-func (os *OauthService) BeginAuth(op string) (error error, code, url string) {
96
-	code = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
96
+func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url string) {
97
+	state = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
98
+	verifier = ""
97
 	if op == string(model.OauthTypeWebauth) {
99
 	if op == string(model.OauthTypeWebauth) {
98
-		url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + code
100
+		url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state
99
 		//url = "http://localhost:8888/_admin/#/oauth/" + code
101
 		//url = "http://localhost:8888/_admin/#/oauth/" + code
100
-		return nil, code, url
102
+		return nil, state, verifier, url
101
 	}
103
 	}
102
-	err, _, oauthConfig := os.GetOauthConfig(op)
104
+	err, oauthInfo, oauthConfig := os.GetOauthConfig(op)
103
 	if err == nil {
105
 	if err == nil {
104
-		return err, code, oauthConfig.AuthCodeURL(code)
106
+		extras := make([]oauth2.AuthCodeOption, 0, 3)
107
+		if oauthInfo.PkceEnable != nil && *oauthInfo.PkceEnable {
108
+			extras = append(extras, oauth2.AccessTypeOffline)
109
+			verifier = oauth2.GenerateVerifier()
110
+			switch oauthInfo.PkceMethod {
111
+			case model.PKCEMethodS256:
112
+				extras = append(extras, oauth2.S256ChallengeOption(verifier))
113
+			case model.PKCEMethodPlain:
114
+				// oauth2 does not have a plain challenge option, so we add it manually
115
+				extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier))
116
+			}
117
+		}
118
+		return err, state, verifier, oauthConfig.AuthCodeURL(state, extras...)
105
 	}
119
 	}
106
 
120
 
107
-	return err, code, ""
121
+	return err, state, verifier, ""
108
 }
122
 }
109
 
123
 
110
 // Method to fetch OIDC configuration dynamically
124
 // Method to fetch OIDC configuration dynamically
@@ -207,15 +221,20 @@ func getHTTPClientWithProxy() *http.Client {
207
 	return http.DefaultClient
221
 	return http.DefaultClient
208
 }
222
 }
209
 
223
 
210
-func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, userEndpoint string, userData interface{}) (err error, client *http.Client) {
224
+func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, verifier string, userEndpoint string, userData interface{}) (err error, client *http.Client) {
211
 
225
 
212
 	// 设置代理客户端
226
 	// 设置代理客户端
213
 	httpClient := getHTTPClientWithProxy()
227
 	httpClient := getHTTPClientWithProxy()
214
 	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
228
 	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
215
 
229
 
230
+	var exchangeOpts []oauth2.AuthCodeOption
231
+	if verifier != "" {
232
+		exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(verifier)}
233
+	}
234
+
216
 	// 使用 code 换取 token
235
 	// 使用 code 换取 token
217
 	var token *oauth2.Token
236
 	var token *oauth2.Token
218
-	token, err = oauthConfig.Exchange(ctx, code)
237
+	token, err = oauthConfig.Exchange(ctx, code, exchangeOpts...)
219
 	if err != nil {
238
 	if err != nil {
220
 		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
239
 		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
221
 		return errors.New("GetOauthTokenError"), nil
240
 		return errors.New("GetOauthTokenError"), nil
@@ -244,9 +263,9 @@ func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, us
244
 }
263
 }
245
 
264
 
246
 // githubCallback github回调
265
 // githubCallback github回调
247
-func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string) (error, *model.OauthUser) {
266
+func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string, verifier string) (error, *model.OauthUser) {
248
 	var user = &model.GithubUser{}
267
 	var user = &model.GithubUser{}
249
-	err, client := os.callbackBase(oauthConfig, code, model.UserEndpointGithub, user)
268
+	err, client := os.callbackBase(oauthConfig, code, verifier, model.UserEndpointGithub, user)
250
 	if err != nil {
269
 	if err != nil {
251
 		return err, nil
270
 		return err, nil
252
 	}
271
 	}
@@ -258,16 +277,16 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string)
258
 }
277
 }
259
 
278
 
260
 // oidcCallback oidc回调, 通过code获取用户信息
279
 // oidcCallback oidc回调, 通过code获取用户信息
261
-func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser) {
280
+func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, verifier string, userInfoEndpoint string) (error, *model.OauthUser) {
262
 	var user = &model.OidcUser{}
281
 	var user = &model.OidcUser{}
263
-	if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil {
282
+	if err, _ := os.callbackBase(oauthConfig, code, verifier, userInfoEndpoint, user); err != nil {
264
 		return err, nil
283
 		return err, nil
265
 	}
284
 	}
266
 	return nil, user.ToOauthUser()
285
 	return nil, user.ToOauthUser()
267
 }
286
 }
268
 
287
 
269
 // Callback: Get user information by code and op(Oauth provider)
288
 // Callback: Get user information by code and op(Oauth provider)
270
-func (os *OauthService) Callback(code string, op string) (err error, oauthUser *model.OauthUser) {
289
+func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUser *model.OauthUser) {
271
 	var oauthInfo *model.Oauth
290
 	var oauthInfo *model.Oauth
272
 	var oauthConfig *oauth2.Config
291
 	var oauthConfig *oauth2.Config
273
 	err, oauthInfo, oauthConfig = os.GetOauthConfig(op)
292
 	err, oauthInfo, oauthConfig = os.GetOauthConfig(op)
@@ -278,13 +297,13 @@ func (os *OauthService) Callback(code string, op string) (err error, oauthUser *
278
 	oauthType := oauthInfo.OauthType
297
 	oauthType := oauthInfo.OauthType
279
 	switch oauthType {
298
 	switch oauthType {
280
 	case model.OauthTypeGithub:
299
 	case model.OauthTypeGithub:
281
-		err, oauthUser = os.githubCallback(oauthConfig, code)
300
+		err, oauthUser = os.githubCallback(oauthConfig, code, verifier)
282
 	case model.OauthTypeOidc, model.OauthTypeGoogle:
301
 	case model.OauthTypeOidc, model.OauthTypeGoogle:
283
 		err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
302
 		err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
284
 		if err != nil {
303
 		if err != nil {
285
 			return err, nil
304
 			return err, nil
286
 		}
305
 		}
287
-		err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo)
306
+		err, oauthUser = os.oidcCallback(oauthConfig, code, verifier, endpoint.UserInfo)
288
 	default:
307
 	default:
289
 		return errors.New("unsupported OAuth type"), nil
308
 		return errors.New("unsupported OAuth type"), nil
290
 	}
309
 	}