|
|
@@ -45,6 +45,7 @@ type OauthCacheItem struct {
|
|
45
|
45
|
Username string `json:"username"`
|
|
46
|
46
|
Name string `json:"name"`
|
|
47
|
47
|
Email string `json:"email"`
|
|
|
48
|
+ Verifier string `json:"verifier"` // used for oauth pkce
|
|
48
|
49
|
}
|
|
49
|
50
|
|
|
50
|
51
|
func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
|
|
|
@@ -92,19 +93,32 @@ func (os *OauthService) DeleteOauthCache(key string) {
|
|
92
|
93
|
OauthCache.Delete(key)
|
|
93
|
94
|
}
|
|
94
|
95
|
|
|
95
|
|
-func (os *OauthService) BeginAuth(op string) (error error, code, url string) {
|
|
96
|
|
- code = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
|
|
|
96
|
+func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url string) {
|
|
|
97
|
+ state = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
|
|
|
98
|
+ verifier = ""
|
|
97
|
99
|
if op == string(model.OauthTypeWebauth) {
|
|
98
|
|
- url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + code
|
|
|
100
|
+ url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state
|
|
99
|
101
|
//url = "http://localhost:8888/_admin/#/oauth/" + code
|
|
100
|
|
- return nil, code, url
|
|
|
102
|
+ return nil, state, verifier, url
|
|
101
|
103
|
}
|
|
102
|
|
- err, _, oauthConfig := os.GetOauthConfig(op)
|
|
|
104
|
+ err, oauthInfo, oauthConfig := os.GetOauthConfig(op)
|
|
103
|
105
|
if err == nil {
|
|
104
|
|
- return err, code, oauthConfig.AuthCodeURL(code)
|
|
|
106
|
+ extras := make([]oauth2.AuthCodeOption, 0, 3)
|
|
|
107
|
+ if oauthInfo.PkceEnable != nil && *oauthInfo.PkceEnable {
|
|
|
108
|
+ extras = append(extras, oauth2.AccessTypeOffline)
|
|
|
109
|
+ verifier = oauth2.GenerateVerifier()
|
|
|
110
|
+ switch oauthInfo.PkceMethod {
|
|
|
111
|
+ case model.PKCEMethodS256:
|
|
|
112
|
+ extras = append(extras, oauth2.S256ChallengeOption(verifier))
|
|
|
113
|
+ case model.PKCEMethodPlain:
|
|
|
114
|
+ // oauth2 does not have a plain challenge option, so we add it manually
|
|
|
115
|
+ extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier))
|
|
|
116
|
+ }
|
|
|
117
|
+ }
|
|
|
118
|
+ return err, state, verifier, oauthConfig.AuthCodeURL(state, extras...)
|
|
105
|
119
|
}
|
|
106
|
120
|
|
|
107
|
|
- return err, code, ""
|
|
|
121
|
+ return err, state, verifier, ""
|
|
108
|
122
|
}
|
|
109
|
123
|
|
|
110
|
124
|
// Method to fetch OIDC configuration dynamically
|
|
|
@@ -207,15 +221,20 @@ func getHTTPClientWithProxy() *http.Client {
|
|
207
|
221
|
return http.DefaultClient
|
|
208
|
222
|
}
|
|
209
|
223
|
|
|
210
|
|
-func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, userEndpoint string, userData interface{}) (err error, client *http.Client) {
|
|
|
224
|
+func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, verifier string, userEndpoint string, userData interface{}) (err error, client *http.Client) {
|
|
211
|
225
|
|
|
212
|
226
|
// 设置代理客户端
|
|
213
|
227
|
httpClient := getHTTPClientWithProxy()
|
|
214
|
228
|
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
|
|
215
|
229
|
|
|
|
230
|
+ var exchangeOpts []oauth2.AuthCodeOption
|
|
|
231
|
+ if verifier != "" {
|
|
|
232
|
+ exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(verifier)}
|
|
|
233
|
+ }
|
|
|
234
|
+
|
|
216
|
235
|
// 使用 code 换取 token
|
|
217
|
236
|
var token *oauth2.Token
|
|
218
|
|
- token, err = oauthConfig.Exchange(ctx, code)
|
|
|
237
|
+ token, err = oauthConfig.Exchange(ctx, code, exchangeOpts...)
|
|
219
|
238
|
if err != nil {
|
|
220
|
239
|
global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
|
|
221
|
240
|
return errors.New("GetOauthTokenError"), nil
|
|
|
@@ -244,9 +263,9 @@ func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, us
|
|
244
|
263
|
}
|
|
245
|
264
|
|
|
246
|
265
|
// githubCallback github回调
|
|
247
|
|
-func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string) (error, *model.OauthUser) {
|
|
|
266
|
+func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string, verifier string) (error, *model.OauthUser) {
|
|
248
|
267
|
var user = &model.GithubUser{}
|
|
249
|
|
- err, client := os.callbackBase(oauthConfig, code, model.UserEndpointGithub, user)
|
|
|
268
|
+ err, client := os.callbackBase(oauthConfig, code, verifier, model.UserEndpointGithub, user)
|
|
250
|
269
|
if err != nil {
|
|
251
|
270
|
return err, nil
|
|
252
|
271
|
}
|
|
|
@@ -258,16 +277,16 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string)
|
|
258
|
277
|
}
|
|
259
|
278
|
|
|
260
|
279
|
// oidcCallback oidc回调, 通过code获取用户信息
|
|
261
|
|
-func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser) {
|
|
|
280
|
+func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, verifier string, userInfoEndpoint string) (error, *model.OauthUser) {
|
|
262
|
281
|
var user = &model.OidcUser{}
|
|
263
|
|
- if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil {
|
|
|
282
|
+ if err, _ := os.callbackBase(oauthConfig, code, verifier, userInfoEndpoint, user); err != nil {
|
|
264
|
283
|
return err, nil
|
|
265
|
284
|
}
|
|
266
|
285
|
return nil, user.ToOauthUser()
|
|
267
|
286
|
}
|
|
268
|
287
|
|
|
269
|
288
|
// Callback: Get user information by code and op(Oauth provider)
|
|
270
|
|
-func (os *OauthService) Callback(code string, op string) (err error, oauthUser *model.OauthUser) {
|
|
|
289
|
+func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUser *model.OauthUser) {
|
|
271
|
290
|
var oauthInfo *model.Oauth
|
|
272
|
291
|
var oauthConfig *oauth2.Config
|
|
273
|
292
|
err, oauthInfo, oauthConfig = os.GetOauthConfig(op)
|
|
|
@@ -278,13 +297,13 @@ func (os *OauthService) Callback(code string, op string) (err error, oauthUser *
|
|
278
|
297
|
oauthType := oauthInfo.OauthType
|
|
279
|
298
|
switch oauthType {
|
|
280
|
299
|
case model.OauthTypeGithub:
|
|
281
|
|
- err, oauthUser = os.githubCallback(oauthConfig, code)
|
|
|
300
|
+ err, oauthUser = os.githubCallback(oauthConfig, code, verifier)
|
|
282
|
301
|
case model.OauthTypeOidc, model.OauthTypeGoogle:
|
|
283
|
302
|
err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
|
|
284
|
303
|
if err != nil {
|
|
285
|
304
|
return err, nil
|
|
286
|
305
|
}
|
|
287
|
|
- err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo)
|
|
|
306
|
+ err, oauthUser = os.oidcCallback(oauthConfig, code, verifier, endpoint.UserInfo)
|
|
288
|
307
|
default:
|
|
289
|
308
|
return errors.New("unsupported OAuth type"), nil
|
|
290
|
309
|
}
|