Tao Chen месяцев назад: 11
Родитель
Сommit
6f55c5b642

+ 7 - 1
cmd/apimain.go

@@ -166,7 +166,7 @@ func InitGlobal() {
166 166
 	global.Lock = lock.NewLocal()
167 167
 }
168 168
 func DatabaseAutoUpdate() {
169
-	version := 260
169
+	version := 261
170 170
 
171 171
 	db := global.DB
172 172
 
@@ -210,6 +210,12 @@ func DatabaseAutoUpdate() {
210 210
 		if v.Version < uint(version) {
211 211
 			Migrate(uint(version))
212 212
 		}
213
+		// 261迁移
214
+		if v.Version < 261 {
215
+			// 在oauths表中添加pkce_enable 和 pkce_method 字段
216
+			db.Exec("ALTER TABLE oauths ADD COLUMN pkce_enable TINYINT(1) NOT NULL DEFAULT 0")
217
+			db.Exec("ALTER TABLE oauths ADD COLUMN pkce_method VARCHAR(20) NOT NULL DEFAULT 'S256'")
218
+		}
213 219
 		// 245迁移
214 220
 		if v.Version < 245 {
215 221
 			//oauths 表的 oauth_type 字段设置为 op同样的值

+ 4 - 3
http/controller/admin/login.go

@@ -283,13 +283,13 @@ func (ct *Login) OidcAuth(c *gin.Context) {
283 283
 		return
284 284
 	}
285 285
 
286
-	err, code, url := service.AllService.OauthService.BeginAuth(f.Op)
286
+	err, state, verifier, url := service.AllService.OauthService.BeginAuth(f.Op)
287 287
 	if err != nil {
288 288
 		response.Error(c, response.TranslateMsg(c, err.Error()))
289 289
 		return
290 290
 	}
291 291
 
292
-	service.AllService.OauthService.SetOauthCache(code, &service.OauthCacheItem{
292
+	service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{
293 293
 		Action:     service.OauthActionTypeLogin,
294 294
 		Op:         f.Op,
295 295
 		Id:         f.Id,
@@ -297,10 +297,11 @@ func (ct *Login) OidcAuth(c *gin.Context) {
297 297
 		// DeviceOs: ct.Platform(c),
298 298
 		DeviceOs: f.DeviceInfo.Os,
299 299
 		Uuid:     f.Uuid,
300
+		Verifier: verifier,
300 301
 	}, 5*60)
301 302
 
302 303
 	response.Success(c, gin.H{
303
-		"code": code,
304
+		"code": state,
304 305
 		"url":  url,
305 306
 	})
306 307
 }

+ 6 - 5
http/controller/admin/oauth.go

@@ -43,20 +43,21 @@ func (o *Oauth) ToBind(c *gin.Context) {
43 43
 		return
44 44
 	}
45 45
 
46
-	err, code, url := service.AllService.OauthService.BeginAuth(f.Op)
46
+	err, state, verifier, url := service.AllService.OauthService.BeginAuth(f.Op)
47 47
 	if err != nil {
48 48
 		response.Error(c, response.TranslateMsg(c, err.Error()))
49 49
 		return
50 50
 	}
51 51
 
52
-	service.AllService.OauthService.SetOauthCache(code, &service.OauthCacheItem{
52
+	service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{
53 53
 		Action: service.OauthActionTypeBind,
54
-		Op:     f.Op,
55
-		UserId: u.Id,
54
+		Op:     	f.Op,
55
+		UserId: 	u.Id,
56
+		Verifier: 	verifier,
56 57
 	}, 5*60)
57 58
 
58 59
 	response.Success(c, gin.H{
59
-		"code": code,
60
+		"code": state,
60 61
 		"url":  url,
61 62
 	})
62 63
 }

+ 8 - 5
http/controller/api/ouath.go

@@ -32,15 +32,16 @@ func (o *Oauth) OidcAuth(c *gin.Context) {
32 32
 	}
33 33
 
34 34
 	oauthService := service.AllService.OauthService
35
-	var code string
35
+	var state string
36 36
 	var url string
37
-	err, code, url = oauthService.BeginAuth(f.Op)
37
+	var verifier string
38
+	err, state, verifier, url = oauthService.BeginAuth(f.Op)
38 39
 	if err != nil {
39 40
 		response.Error(c, response.TranslateMsg(c, err.Error()))
40 41
 		return
41 42
 	}
42 43
 
43
-	service.AllService.OauthService.SetOauthCache(code, &service.OauthCacheItem{
44
+	service.AllService.OauthService.SetOauthCache(state, &service.OauthCacheItem{
44 45
 		Action:     service.OauthActionTypeLogin,
45 46
 		Id:         f.Id,
46 47
 		Op:         f.Op,
@@ -48,10 +49,11 @@ func (o *Oauth) OidcAuth(c *gin.Context) {
48 49
 		DeviceName: f.DeviceInfo.Name,
49 50
 		DeviceOs:   f.DeviceInfo.Os,
50 51
 		DeviceType: f.DeviceInfo.Type,
52
+		Verifier:   verifier,
51 53
 	}, 5*60)
52 54
 	//fmt.Println("code url", code, url)
53 55
 	c.JSON(http.StatusOK, gin.H{
54
-		"code": code,
56
+		"code": state,
55 57
 		"url":  url,
56 58
 	})
57 59
 }
@@ -156,10 +158,11 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
156 158
 	}
157 159
 	op := oauthCache.Op
158 160
 	action := oauthCache.Action
161
+	verifier := oauthCache.Verifier
159 162
 	var user *model.User
160 163
 	// 获取用户信息
161 164
 	code := c.Query("code")
162
-	err, oauthUser := oauthService.Callback(code, op)
165
+	err, oauthUser := oauthService.Callback(code, verifier, op)
163 166
 	if err != nil {
164 167
 		c.String(http.StatusInternalServerError, response.TranslateMsg(c, "OauthFailed")+response.TranslateMsg(c, err.Error()))
165 168
 		return

+ 4 - 0
http/request/admin/oauth.go

@@ -24,6 +24,8 @@ type OauthForm struct {
24 24
 	ClientSecret string `json:"client_secret" validate:"required"`
25 25
 	RedirectUrl  string `json:"redirect_url" validate:"required"`
26 26
 	AutoRegister *bool  `json:"auto_register"`
27
+	PkceEnable   *bool  `json:"pkce_enable"`
28
+	PkceMethod   string `json:"pkce_method"`
27 29
 }
28 30
 
29 31
 func (of *OauthForm) ToOauth() *model.Oauth {
@@ -36,6 +38,8 @@ func (of *OauthForm) ToOauth() *model.Oauth {
36 38
 		AutoRegister: of.AutoRegister,
37 39
 		Issuer:       of.Issuer,
38 40
 		Scopes:       of.Scopes,
41
+		PkceEnable:   of.PkceEnable,
42
+		PkceMethod:   of.PkceMethod,
39 43
 	}
40 44
 	oa.Id = of.Id
41 45
 	return oa

+ 11 - 0
model/oauth.go

@@ -14,6 +14,8 @@ const (
14 14
 	OauthTypeGoogle  string = "google"
15 15
 	OauthTypeOidc    string = "oidc"
16 16
 	OauthTypeWebauth string = "webauth"
17
+	PKCEMethodS256   string = "S256"
18
+	PKCEMethodPlain  string = "plain"
17 19
 )
18 20
 
19 21
 // Validate the oauth type
@@ -41,6 +43,8 @@ type Oauth struct {
41 43
 	AutoRegister *bool  `json:"auto_register"`
42 44
 	Scopes       string `json:"scopes"`
43 45
 	Issuer       string `json:"issuer"`
46
+	PkceEnable	 *bool  `json:"pkce_enable"`
47
+	PkceMethod	 string `json:"pkce_method"`
44 48
 	TimeModel
45 49
 }
46 50
 
@@ -68,6 +72,13 @@ func (oa *Oauth) FormatOauthInfo() error {
68 72
 	if oauthType == OauthTypeGoogle && issuer == "" {
69 73
 		oa.Issuer = IssuerGoogle
70 74
 	}
75
+	if oa.PkceEnable == nil {
76
+		oa.PkceEnable = new(bool)
77
+		*oa.PkceEnable = false
78
+	}
79
+	if oa.PkceMethod == "" {
80
+		oa.PkceMethod = PKCEMethodS256
81
+	}
71 82
 	return nil
72 83
 }
73 84
 

+ 35 - 16
service/oauth.go

@@ -45,6 +45,7 @@ type OauthCacheItem struct {
45 45
 	Username   string `json:"username"`
46 46
 	Name       string `json:"name"`
47 47
 	Email      string `json:"email"`
48
+	Verifier   string `json:"verifier"`  // used for oauth pkce
48 49
 }
49 50
 
50 51
 func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
@@ -92,19 +93,32 @@ func (os *OauthService) DeleteOauthCache(key string) {
92 93
 	OauthCache.Delete(key)
93 94
 }
94 95
 
95
-func (os *OauthService) BeginAuth(op string) (error error, code, url string) {
96
-	code = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
96
+func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url string) {
97
+	state = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
98
+	verifier = ""
97 99
 	if op == string(model.OauthTypeWebauth) {
98
-		url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + code
100
+		url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state
99 101
 		//url = "http://localhost:8888/_admin/#/oauth/" + code
100
-		return nil, code, url
102
+		return nil, state, verifier, url
101 103
 	}
102
-	err, _, oauthConfig := os.GetOauthConfig(op)
104
+	err, oauthInfo, oauthConfig := os.GetOauthConfig(op)
103 105
 	if err == nil {
104
-		return err, code, oauthConfig.AuthCodeURL(code)
106
+		extras := make([]oauth2.AuthCodeOption, 0, 3)
107
+		if oauthInfo.PkceEnable != nil && *oauthInfo.PkceEnable {
108
+			extras = append(extras, oauth2.AccessTypeOffline)
109
+			verifier = oauth2.GenerateVerifier()
110
+			switch oauthInfo.PkceMethod {
111
+			case model.PKCEMethodS256:
112
+				extras = append(extras, oauth2.S256ChallengeOption(verifier))
113
+			case model.PKCEMethodPlain:
114
+				// oauth2 does not have a plain challenge option, so we add it manually
115
+				extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier))
116
+			}
117
+		}
118
+		return err, state, verifier, oauthConfig.AuthCodeURL(state, extras...)
105 119
 	}
106 120
 
107
-	return err, code, ""
121
+	return err, state, verifier, ""
108 122
 }
109 123
 
110 124
 // Method to fetch OIDC configuration dynamically
@@ -207,15 +221,20 @@ func getHTTPClientWithProxy() *http.Client {
207 221
 	return http.DefaultClient
208 222
 }
209 223
 
210
-func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, userEndpoint string, userData interface{}) (err error, client *http.Client) {
224
+func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, verifier string, userEndpoint string, userData interface{}) (err error, client *http.Client) {
211 225
 
212 226
 	// 设置代理客户端
213 227
 	httpClient := getHTTPClientWithProxy()
214 228
 	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
215 229
 
230
+	var exchangeOpts []oauth2.AuthCodeOption
231
+	if verifier != "" {
232
+		exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(verifier)}
233
+	}
234
+
216 235
 	// 使用 code 换取 token
217 236
 	var token *oauth2.Token
218
-	token, err = oauthConfig.Exchange(ctx, code)
237
+	token, err = oauthConfig.Exchange(ctx, code, exchangeOpts...)
219 238
 	if err != nil {
220 239
 		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
221 240
 		return errors.New("GetOauthTokenError"), nil
@@ -244,9 +263,9 @@ func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, us
244 263
 }
245 264
 
246 265
 // githubCallback github回调
247
-func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string) (error, *model.OauthUser) {
266
+func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string, verifier string) (error, *model.OauthUser) {
248 267
 	var user = &model.GithubUser{}
249
-	err, client := os.callbackBase(oauthConfig, code, model.UserEndpointGithub, user)
268
+	err, client := os.callbackBase(oauthConfig, code, verifier, model.UserEndpointGithub, user)
250 269
 	if err != nil {
251 270
 		return err, nil
252 271
 	}
@@ -258,16 +277,16 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string)
258 277
 }
259 278
 
260 279
 // oidcCallback oidc回调, 通过code获取用户信息
261
-func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser) {
280
+func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, verifier string, userInfoEndpoint string) (error, *model.OauthUser) {
262 281
 	var user = &model.OidcUser{}
263
-	if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil {
282
+	if err, _ := os.callbackBase(oauthConfig, code, verifier, userInfoEndpoint, user); err != nil {
264 283
 		return err, nil
265 284
 	}
266 285
 	return nil, user.ToOauthUser()
267 286
 }
268 287
 
269 288
 // Callback: Get user information by code and op(Oauth provider)
270
-func (os *OauthService) Callback(code string, op string) (err error, oauthUser *model.OauthUser) {
289
+func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUser *model.OauthUser) {
271 290
 	var oauthInfo *model.Oauth
272 291
 	var oauthConfig *oauth2.Config
273 292
 	err, oauthInfo, oauthConfig = os.GetOauthConfig(op)
@@ -278,13 +297,13 @@ func (os *OauthService) Callback(code string, op string) (err error, oauthUser *
278 297
 	oauthType := oauthInfo.OauthType
279 298
 	switch oauthType {
280 299
 	case model.OauthTypeGithub:
281
-		err, oauthUser = os.githubCallback(oauthConfig, code)
300
+		err, oauthUser = os.githubCallback(oauthConfig, code, verifier)
282 301
 	case model.OauthTypeOidc, model.OauthTypeGoogle:
283 302
 		err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
284 303
 		if err != nil {
285 304
 			return err, nil
286 305
 		}
287
-		err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo)
306
+		err, oauthUser = os.oidcCallback(oauthConfig, code, verifier, endpoint.UserInfo)
288 307
 	default:
289 308
 		return errors.New("unsupported OAuth type"), nil
290 309
 	}