|
|
@@ -4,6 +4,7 @@ import (
|
|
4
|
4
|
"context"
|
|
5
|
5
|
"encoding/json"
|
|
6
|
6
|
"errors"
|
|
|
7
|
+ "github.com/coreos/go-oidc/v3/oidc"
|
|
7
|
8
|
"github.com/lejianwen/rustdesk-api/v2/global"
|
|
8
|
9
|
"github.com/lejianwen/rustdesk-api/v2/model"
|
|
9
|
10
|
"github.com/lejianwen/rustdesk-api/v2/utils"
|
|
|
@@ -45,7 +46,7 @@ type OauthCacheItem struct {
|
|
45
|
46
|
Username string `json:"username"`
|
|
46
|
47
|
Name string `json:"name"`
|
|
47
|
48
|
Email string `json:"email"`
|
|
48
|
|
- Verifier string `json:"verifier"` // used for oauth pkce
|
|
|
49
|
+ Verifier string `json:"verifier"` // used for oauth pkce
|
|
49
|
50
|
}
|
|
50
|
51
|
|
|
51
|
52
|
func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
|
|
|
@@ -82,10 +83,9 @@ func (os *OauthService) GetOauthCache(key string) *OauthCacheItem {
|
|
82
|
83
|
func (os *OauthService) SetOauthCache(key string, item *OauthCacheItem, expire uint) {
|
|
83
|
84
|
OauthCache.Store(key, item)
|
|
84
|
85
|
if expire > 0 {
|
|
85
|
|
- go func() {
|
|
86
|
|
- time.Sleep(time.Duration(expire) * time.Second)
|
|
|
86
|
+ time.AfterFunc(time.Duration(expire)*time.Second, func() {
|
|
87
|
87
|
os.DeleteOauthCache(key)
|
|
88
|
|
- }()
|
|
|
88
|
+ })
|
|
89
|
89
|
}
|
|
90
|
90
|
}
|
|
91
|
91
|
|
|
|
@@ -96,12 +96,12 @@ func (os *OauthService) DeleteOauthCache(key string) {
|
|
96
|
96
|
func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url string) {
|
|
97
|
97
|
state = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
|
|
98
|
98
|
verifier = ""
|
|
99
|
|
- if op == string(model.OauthTypeWebauth) {
|
|
|
99
|
+ if op == model.OauthTypeWebauth {
|
|
100
|
100
|
url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state
|
|
101
|
101
|
//url = "http://localhost:8888/_admin/#/oauth/" + code
|
|
102
|
102
|
return nil, state, verifier, url
|
|
103
|
103
|
}
|
|
104
|
|
- err, oauthInfo, oauthConfig := os.GetOauthConfig(op)
|
|
|
104
|
+ err, oauthInfo, oauthConfig, _ := os.GetOauthConfig(op)
|
|
105
|
105
|
if err == nil {
|
|
106
|
106
|
extras := make([]oauth2.AuthCodeOption, 0, 3)
|
|
107
|
107
|
if oauthInfo.PkceEnable != nil && *oauthInfo.PkceEnable {
|
|
|
@@ -121,88 +121,80 @@ func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url
|
|
121
|
121
|
return err, state, verifier, ""
|
|
122
|
122
|
}
|
|
123
|
123
|
|
|
124
|
|
-// Method to fetch OIDC configuration dynamically
|
|
125
|
|
-func (os *OauthService) FetchOidcEndpoint(issuer string) (error, OidcEndpoint) {
|
|
126
|
|
- configURL := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
|
|
|
124
|
+func (os *OauthService) FetchOidcProvider(issuer string) (error, *oidc.Provider) {
|
|
127
|
125
|
|
|
128
|
126
|
// Get the HTTP client (with or without proxy based on configuration)
|
|
129
|
127
|
client := getHTTPClientWithProxy()
|
|
130
|
128
|
|
|
131
|
|
- resp, err := client.Get(configURL)
|
|
132
|
|
- if err != nil {
|
|
133
|
|
- return errors.New("failed to fetch OIDC configuration"), OidcEndpoint{}
|
|
134
|
|
- }
|
|
135
|
|
- defer resp.Body.Close()
|
|
136
|
|
-
|
|
137
|
|
- if resp.StatusCode != http.StatusOK {
|
|
138
|
|
- return errors.New("OIDC configuration not found, status code: %d"), OidcEndpoint{}
|
|
139
|
|
- }
|
|
|
129
|
+ ctx := oidc.ClientContext(context.Background(), client)
|
|
140
|
130
|
|
|
141
|
|
- var endpoint OidcEndpoint
|
|
142
|
|
- if err := json.NewDecoder(resp.Body).Decode(&endpoint); err != nil {
|
|
143
|
|
- return errors.New("failed to parse OIDC configuration"), OidcEndpoint{}
|
|
|
131
|
+ provider, err := oidc.NewProvider(ctx, issuer)
|
|
|
132
|
+ if err != nil {
|
|
|
133
|
+ return err, nil
|
|
144
|
134
|
}
|
|
145
|
135
|
|
|
146
|
|
- return nil, endpoint
|
|
|
136
|
+ return nil, provider
|
|
147
|
137
|
}
|
|
148
|
138
|
|
|
149
|
|
-func (os *OauthService) FetchOidcEndpointByOp(op string) (error, OidcEndpoint) {
|
|
150
|
|
- oauthInfo := os.InfoByOp(op)
|
|
151
|
|
- if oauthInfo.Issuer == "" {
|
|
152
|
|
- return errors.New("issuer is empty"), OidcEndpoint{}
|
|
153
|
|
- }
|
|
154
|
|
- return os.FetchOidcEndpoint(oauthInfo.Issuer)
|
|
|
139
|
+func (os *OauthService) GithubProvider() *oidc.Provider {
|
|
|
140
|
+ return (&oidc.ProviderConfig{
|
|
|
141
|
+ IssuerURL: "",
|
|
|
142
|
+ AuthURL: github.Endpoint.AuthURL,
|
|
|
143
|
+ TokenURL: github.Endpoint.TokenURL,
|
|
|
144
|
+ DeviceAuthURL: github.Endpoint.DeviceAuthURL,
|
|
|
145
|
+ UserInfoURL: model.UserEndpointGithub,
|
|
|
146
|
+ JWKSURL: "",
|
|
|
147
|
+ Algorithms: nil,
|
|
|
148
|
+ }).NewProvider(context.Background())
|
|
155
|
149
|
}
|
|
156
|
150
|
|
|
157
|
151
|
// GetOauthConfig retrieves the OAuth2 configuration based on the provider name
|
|
158
|
|
-func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) {
|
|
159
|
|
- err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op)
|
|
160
|
|
- if err != nil {
|
|
161
|
|
- return err, nil, nil
|
|
|
152
|
+func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config, provider *oidc.Provider) {
|
|
|
153
|
+ //err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op)
|
|
|
154
|
+ oauthInfo = os.InfoByOp(op)
|
|
|
155
|
+ if oauthInfo.Id == 0 || oauthInfo.ClientId == "" || oauthInfo.ClientSecret == "" {
|
|
|
156
|
+ return errors.New("ConfigNotFound"), nil, nil, nil
|
|
162
|
157
|
}
|
|
|
158
|
+ // If the redirect URL is empty, use the default redirect URL
|
|
|
159
|
+ if oauthInfo.RedirectUrl == "" {
|
|
|
160
|
+ oauthInfo.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback"
|
|
|
161
|
+ }
|
|
|
162
|
+ oauthConfig = &oauth2.Config{
|
|
|
163
|
+ ClientID: oauthInfo.ClientId,
|
|
|
164
|
+ ClientSecret: oauthInfo.ClientSecret,
|
|
|
165
|
+ RedirectURL: oauthInfo.RedirectUrl,
|
|
|
166
|
+ }
|
|
|
167
|
+
|
|
163
|
168
|
// Maybe should validate the oauthConfig here
|
|
164
|
169
|
oauthType := oauthInfo.OauthType
|
|
165
|
170
|
err = model.ValidateOauthType(oauthType)
|
|
166
|
171
|
if err != nil {
|
|
167
|
|
- return err, nil, nil
|
|
|
172
|
+ return err, nil, nil, nil
|
|
168
|
173
|
}
|
|
169
|
174
|
switch oauthType {
|
|
170
|
175
|
case model.OauthTypeGithub:
|
|
171
|
176
|
oauthConfig.Endpoint = github.Endpoint
|
|
172
|
177
|
oauthConfig.Scopes = []string{"read:user", "user:email"}
|
|
|
178
|
+ provider = os.GithubProvider()
|
|
|
179
|
+ //case model.OauthTypeGoogle: //google单独出来,可以少一次FetchOidcEndpoint请求
|
|
|
180
|
+ // oauthConfig.Endpoint = google.Endpoint
|
|
|
181
|
+ // oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
|
|
173
|
182
|
case model.OauthTypeOidc, model.OauthTypeGoogle:
|
|
174
|
|
- var endpoint OidcEndpoint
|
|
175
|
|
- err, endpoint = os.FetchOidcEndpoint(oauthInfo.Issuer)
|
|
|
183
|
+ err, provider = os.FetchOidcProvider(oauthInfo.Issuer)
|
|
176
|
184
|
if err != nil {
|
|
177
|
|
- return err, nil, nil
|
|
|
185
|
+ return err, nil, nil, nil
|
|
178
|
186
|
}
|
|
179
|
|
- oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL, TokenURL: endpoint.TokenURL}
|
|
|
187
|
+ oauthConfig.Endpoint = provider.Endpoint()
|
|
180
|
188
|
oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
|
|
181
|
189
|
default:
|
|
182
|
|
- return errors.New("unsupported OAuth type"), nil, nil
|
|
183
|
|
- }
|
|
184
|
|
- return nil, oauthInfo, oauthConfig
|
|
185
|
|
-}
|
|
186
|
|
-
|
|
187
|
|
-// GetOauthConfig retrieves the OAuth2 configuration based on the provider name
|
|
188
|
|
-func (os *OauthService) getOauthConfigGeneral(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) {
|
|
189
|
|
- oauthInfo = os.InfoByOp(op)
|
|
190
|
|
- if oauthInfo.Id == 0 || oauthInfo.ClientId == "" || oauthInfo.ClientSecret == "" {
|
|
191
|
|
- return errors.New("ConfigNotFound"), nil, nil
|
|
192
|
|
- }
|
|
193
|
|
- // If the redirect URL is empty, use the default redirect URL
|
|
194
|
|
- if oauthInfo.RedirectUrl == "" {
|
|
195
|
|
- oauthInfo.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback"
|
|
196
|
|
- }
|
|
197
|
|
- return nil, oauthInfo, &oauth2.Config{
|
|
198
|
|
- ClientID: oauthInfo.ClientId,
|
|
199
|
|
- ClientSecret: oauthInfo.ClientSecret,
|
|
200
|
|
- RedirectURL: oauthInfo.RedirectUrl,
|
|
|
190
|
+ return errors.New("unsupported OAuth type"), nil, nil, nil
|
|
201
|
191
|
}
|
|
|
192
|
+ return nil, oauthInfo, oauthConfig, provider
|
|
202
|
193
|
}
|
|
203
|
194
|
|
|
204
|
195
|
func getHTTPClientWithProxy() *http.Client {
|
|
205
|
|
- //todo add timeout
|
|
|
196
|
+ //add timeout 30s
|
|
|
197
|
+ timeout := time.Duration(60) * time.Second
|
|
206
|
198
|
if global.Config.Proxy.Enable {
|
|
207
|
199
|
if global.Config.Proxy.Host == "" {
|
|
208
|
200
|
global.Logger.Warn("Proxy is enabled but proxy host is empty.")
|
|
|
@@ -216,33 +208,58 @@ func getHTTPClientWithProxy() *http.Client {
|
|
216
|
208
|
transport := &http.Transport{
|
|
217
|
209
|
Proxy: http.ProxyURL(proxyURL),
|
|
218
|
210
|
}
|
|
219
|
|
- return &http.Client{Transport: transport}
|
|
|
211
|
+ return &http.Client{Transport: transport, Timeout: timeout}
|
|
220
|
212
|
}
|
|
221
|
213
|
return http.DefaultClient
|
|
222
|
214
|
}
|
|
223
|
|
-
|
|
224
|
|
-func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, verifier string, userEndpoint string, userData interface{}) (err error, client *http.Client) {
|
|
|
215
|
+func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, provider *oidc.Provider, code string, verifier string, nonce string, userData interface{}) (err error, client *http.Client) {
|
|
225
|
216
|
|
|
226
|
217
|
// 设置代理客户端
|
|
227
|
218
|
httpClient := getHTTPClientWithProxy()
|
|
228
|
219
|
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
|
|
229
|
220
|
|
|
230
|
|
- var exchangeOpts []oauth2.AuthCodeOption
|
|
|
221
|
+ exchangeOpts := make([]oauth2.AuthCodeOption, 0, 1)
|
|
231
|
222
|
if verifier != "" {
|
|
232
|
|
- exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(verifier)}
|
|
|
223
|
+ exchangeOpts = append(exchangeOpts, oauth2.VerifierOption(verifier))
|
|
233
|
224
|
}
|
|
234
|
225
|
|
|
235
|
|
- // 使用 code 换取 token
|
|
236
|
|
- var token *oauth2.Token
|
|
237
|
|
- token, err = oauthConfig.Exchange(ctx, code, exchangeOpts...)
|
|
|
226
|
+ token, err := oauthConfig.Exchange(ctx, code, exchangeOpts...)
|
|
|
227
|
+
|
|
238
|
228
|
if err != nil {
|
|
239
|
229
|
global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
|
|
240
|
230
|
return errors.New("GetOauthTokenError"), nil
|
|
241
|
231
|
}
|
|
242
|
232
|
|
|
|
233
|
+ // 获取 ID Token, github没有id_token
|
|
|
234
|
+ rawIDToken, ok := token.Extra("id_token").(string)
|
|
|
235
|
+ if ok && rawIDToken != "" {
|
|
|
236
|
+ // 验证 ID Token
|
|
|
237
|
+ v := provider.Verifier(&oidc.Config{ClientID: oauthConfig.ClientID})
|
|
|
238
|
+ idToken, err2 := v.Verify(ctx, rawIDToken)
|
|
|
239
|
+ if err2 != nil {
|
|
|
240
|
+ global.Logger.Warn("IdTokenVerifyError: ", err2)
|
|
|
241
|
+ return errors.New("IdTokenVerifyError"), nil
|
|
|
242
|
+ }
|
|
|
243
|
+ if nonce != "" {
|
|
|
244
|
+ // 验证 nonce
|
|
|
245
|
+ var claims struct {
|
|
|
246
|
+ Nonce string `json:"nonce"`
|
|
|
247
|
+ }
|
|
|
248
|
+ if err2 = idToken.Claims(&claims); err2 != nil {
|
|
|
249
|
+ global.Logger.Warn("Failed to parse ID Token claims: ", err)
|
|
|
250
|
+ return errors.New("IDTokenClaimsError"), nil
|
|
|
251
|
+ }
|
|
|
252
|
+
|
|
|
253
|
+ if claims.Nonce != nonce {
|
|
|
254
|
+ global.Logger.Warn("Nonce does not match")
|
|
|
255
|
+ return errors.New("NonceDoesNotMatch"), nil
|
|
|
256
|
+ }
|
|
|
257
|
+ }
|
|
|
258
|
+ }
|
|
|
259
|
+
|
|
243
|
260
|
// 获取用户信息
|
|
244
|
261
|
client = oauthConfig.Client(ctx, token)
|
|
245
|
|
- resp, err := client.Get(userEndpoint)
|
|
|
262
|
+ resp, err := client.Get(provider.UserInfoEndpoint())
|
|
246
|
263
|
if err != nil {
|
|
247
|
264
|
global.Logger.Warn("failed getting user info: ", err)
|
|
248
|
265
|
return errors.New("GetOauthUserInfoError"), nil
|
|
|
@@ -263,9 +280,9 @@ func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, ve
|
|
263
|
280
|
}
|
|
264
|
281
|
|
|
265
|
282
|
// githubCallback github回调
|
|
266
|
|
-func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string, verifier string) (error, *model.OauthUser) {
|
|
|
283
|
+func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code string, verifier string) (error, *model.OauthUser) {
|
|
267
|
284
|
var user = &model.GithubUser{}
|
|
268
|
|
- err, client := os.callbackBase(oauthConfig, code, verifier, model.UserEndpointGithub, user)
|
|
|
285
|
+ err, client := os.callbackBase(oauthConfig, provider, code, verifier, "", user)
|
|
269
|
286
|
if err != nil {
|
|
270
|
287
|
return err, nil
|
|
271
|
288
|
}
|
|
|
@@ -277,9 +294,9 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string,
|
|
277
|
294
|
}
|
|
278
|
295
|
|
|
279
|
296
|
// oidcCallback oidc回调, 通过code获取用户信息
|
|
280
|
|
-func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, verifier string, userInfoEndpoint string) (error, *model.OauthUser) {
|
|
|
297
|
+func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code string, verifier string) (error, *model.OauthUser) {
|
|
281
|
298
|
var user = &model.OidcUser{}
|
|
282
|
|
- if err, _ := os.callbackBase(oauthConfig, code, verifier, userInfoEndpoint, user); err != nil {
|
|
|
299
|
+ if err, _ := os.callbackBase(oauthConfig, provider, code, verifier, "", user); err != nil {
|
|
283
|
300
|
return err, nil
|
|
284
|
301
|
}
|
|
285
|
302
|
return nil, user.ToOauthUser()
|
|
|
@@ -287,9 +304,7 @@ func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, ve
|
|
287
|
304
|
|
|
288
|
305
|
// Callback: Get user information by code and op(Oauth provider)
|
|
289
|
306
|
func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUser *model.OauthUser) {
|
|
290
|
|
- var oauthInfo *model.Oauth
|
|
291
|
|
- var oauthConfig *oauth2.Config
|
|
292
|
|
- err, oauthInfo, oauthConfig = os.GetOauthConfig(op)
|
|
|
307
|
+ err, oauthInfo, oauthConfig, provider := os.GetOauthConfig(op)
|
|
293
|
308
|
// oauthType is already validated in GetOauthConfig
|
|
294
|
309
|
if err != nil {
|
|
295
|
310
|
return err, nil
|
|
|
@@ -297,13 +312,9 @@ func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUse
|
|
297
|
312
|
oauthType := oauthInfo.OauthType
|
|
298
|
313
|
switch oauthType {
|
|
299
|
314
|
case model.OauthTypeGithub:
|
|
300
|
|
- err, oauthUser = os.githubCallback(oauthConfig, code, verifier)
|
|
|
315
|
+ err, oauthUser = os.githubCallback(oauthConfig, provider, code, verifier)
|
|
301
|
316
|
case model.OauthTypeOidc, model.OauthTypeGoogle:
|
|
302
|
|
- err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
|
|
303
|
|
- if err != nil {
|
|
304
|
|
- return err, nil
|
|
305
|
|
- }
|
|
306
|
|
- err, oauthUser = os.oidcCallback(oauthConfig, code, verifier, endpoint.UserInfo)
|
|
|
317
|
+ err, oauthUser = os.oidcCallback(oauthConfig, provider, code, verifier)
|
|
307
|
318
|
default:
|
|
308
|
319
|
return errors.New("unsupported OAuth type"), nil
|
|
309
|
320
|
}
|