lejianwen месяцев назад: 11
Родитель
Сommit
f0b4b0d7c6
2 измененных файлов с 92 добавлено и 79 удалено
  1. 2 0
      go.mod
  2. 90 79
      service/oauth.go

+ 2 - 0
go.mod

@@ -36,9 +36,11 @@ require (
36
 	github.com/bytedance/sonic v1.8.0 // indirect
36
 	github.com/bytedance/sonic v1.8.0 // indirect
37
 	github.com/cespare/xxhash/v2 v2.1.2 // indirect
37
 	github.com/cespare/xxhash/v2 v2.1.2 // indirect
38
 	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
38
 	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
39
+	github.com/coreos/go-oidc/v3 v3.12.0 // indirect
39
 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
40
 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
40
 	github.com/gin-contrib/sse v0.1.0 // indirect
41
 	github.com/gin-contrib/sse v0.1.0 // indirect
41
 	github.com/go-asn1-ber/asn1-ber v1.5.7 // indirect
42
 	github.com/go-asn1-ber/asn1-ber v1.5.7 // indirect
43
+	github.com/go-jose/go-jose/v4 v4.0.2 // indirect
42
 	github.com/go-ldap/ldap/v3 v3.4.10 // indirect
44
 	github.com/go-ldap/ldap/v3 v3.4.10 // indirect
43
 	github.com/go-openapi/jsonpointer v0.19.5 // indirect
45
 	github.com/go-openapi/jsonpointer v0.19.5 // indirect
44
 	github.com/go-openapi/jsonreference v0.19.6 // indirect
46
 	github.com/go-openapi/jsonreference v0.19.6 // indirect

+ 90 - 79
service/oauth.go

@@ -4,6 +4,7 @@ import (
4
 	"context"
4
 	"context"
5
 	"encoding/json"
5
 	"encoding/json"
6
 	"errors"
6
 	"errors"
7
+	"github.com/coreos/go-oidc/v3/oidc"
7
 	"github.com/lejianwen/rustdesk-api/v2/global"
8
 	"github.com/lejianwen/rustdesk-api/v2/global"
8
 	"github.com/lejianwen/rustdesk-api/v2/model"
9
 	"github.com/lejianwen/rustdesk-api/v2/model"
9
 	"github.com/lejianwen/rustdesk-api/v2/utils"
10
 	"github.com/lejianwen/rustdesk-api/v2/utils"
@@ -45,7 +46,7 @@ type OauthCacheItem struct {
45
 	Username   string `json:"username"`
46
 	Username   string `json:"username"`
46
 	Name       string `json:"name"`
47
 	Name       string `json:"name"`
47
 	Email      string `json:"email"`
48
 	Email      string `json:"email"`
48
-	Verifier   string `json:"verifier"`  // used for oauth pkce
49
+	Verifier   string `json:"verifier"` // used for oauth pkce
49
 }
50
 }
50
 
51
 
51
 func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
52
 func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
@@ -82,10 +83,9 @@ func (os *OauthService) GetOauthCache(key string) *OauthCacheItem {
82
 func (os *OauthService) SetOauthCache(key string, item *OauthCacheItem, expire uint) {
83
 func (os *OauthService) SetOauthCache(key string, item *OauthCacheItem, expire uint) {
83
 	OauthCache.Store(key, item)
84
 	OauthCache.Store(key, item)
84
 	if expire > 0 {
85
 	if expire > 0 {
85
-		go func() {
86
-			time.Sleep(time.Duration(expire) * time.Second)
86
+		time.AfterFunc(time.Duration(expire)*time.Second, func() {
87
 			os.DeleteOauthCache(key)
87
 			os.DeleteOauthCache(key)
88
-		}()
88
+		})
89
 	}
89
 	}
90
 }
90
 }
91
 
91
 
@@ -96,12 +96,12 @@ func (os *OauthService) DeleteOauthCache(key string) {
96
 func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url string) {
96
 func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url string) {
97
 	state = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
97
 	state = utils.RandomString(10) + strconv.FormatInt(time.Now().Unix(), 10)
98
 	verifier = ""
98
 	verifier = ""
99
-	if op == string(model.OauthTypeWebauth) {
99
+	if op == model.OauthTypeWebauth {
100
 		url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state
100
 		url = global.Config.Rustdesk.ApiServer + "/_admin/#/oauth/" + state
101
 		//url = "http://localhost:8888/_admin/#/oauth/" + code
101
 		//url = "http://localhost:8888/_admin/#/oauth/" + code
102
 		return nil, state, verifier, url
102
 		return nil, state, verifier, url
103
 	}
103
 	}
104
-	err, oauthInfo, oauthConfig := os.GetOauthConfig(op)
104
+	err, oauthInfo, oauthConfig, _ := os.GetOauthConfig(op)
105
 	if err == nil {
105
 	if err == nil {
106
 		extras := make([]oauth2.AuthCodeOption, 0, 3)
106
 		extras := make([]oauth2.AuthCodeOption, 0, 3)
107
 		if oauthInfo.PkceEnable != nil && *oauthInfo.PkceEnable {
107
 		if oauthInfo.PkceEnable != nil && *oauthInfo.PkceEnable {
@@ -121,88 +121,80 @@ func (os *OauthService) BeginAuth(op string) (error error, state, verifier, url
121
 	return err, state, verifier, ""
121
 	return err, state, verifier, ""
122
 }
122
 }
123
 
123
 
124
-// Method to fetch OIDC configuration dynamically
125
-func (os *OauthService) FetchOidcEndpoint(issuer string) (error, OidcEndpoint) {
126
-	configURL := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
124
+func (os *OauthService) FetchOidcProvider(issuer string) (error, *oidc.Provider) {
127
 
125
 
128
 	// Get the HTTP client (with or without proxy based on configuration)
126
 	// Get the HTTP client (with or without proxy based on configuration)
129
 	client := getHTTPClientWithProxy()
127
 	client := getHTTPClientWithProxy()
130
 
128
 
131
-	resp, err := client.Get(configURL)
132
-	if err != nil {
133
-		return errors.New("failed to fetch OIDC configuration"), OidcEndpoint{}
134
-	}
135
-	defer resp.Body.Close()
136
-
137
-	if resp.StatusCode != http.StatusOK {
138
-		return errors.New("OIDC configuration not found, status code: %d"), OidcEndpoint{}
139
-	}
129
+	ctx := oidc.ClientContext(context.Background(), client)
140
 
130
 
141
-	var endpoint OidcEndpoint
142
-	if err := json.NewDecoder(resp.Body).Decode(&endpoint); err != nil {
143
-		return errors.New("failed to parse OIDC configuration"), OidcEndpoint{}
131
+	provider, err := oidc.NewProvider(ctx, issuer)
132
+	if err != nil {
133
+		return err, nil
144
 	}
134
 	}
145
 
135
 
146
-	return nil, endpoint
136
+	return nil, provider
147
 }
137
 }
148
 
138
 
149
-func (os *OauthService) FetchOidcEndpointByOp(op string) (error, OidcEndpoint) {
150
-	oauthInfo := os.InfoByOp(op)
151
-	if oauthInfo.Issuer == "" {
152
-		return errors.New("issuer is empty"), OidcEndpoint{}
153
-	}
154
-	return os.FetchOidcEndpoint(oauthInfo.Issuer)
139
+func (os *OauthService) GithubProvider() *oidc.Provider {
140
+	return (&oidc.ProviderConfig{
141
+		IssuerURL:     "",
142
+		AuthURL:       github.Endpoint.AuthURL,
143
+		TokenURL:      github.Endpoint.TokenURL,
144
+		DeviceAuthURL: github.Endpoint.DeviceAuthURL,
145
+		UserInfoURL:   model.UserEndpointGithub,
146
+		JWKSURL:       "",
147
+		Algorithms:    nil,
148
+	}).NewProvider(context.Background())
155
 }
149
 }
156
 
150
 
157
 // GetOauthConfig retrieves the OAuth2 configuration based on the provider name
151
 // GetOauthConfig retrieves the OAuth2 configuration based on the provider name
158
-func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) {
159
-	err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op)
160
-	if err != nil {
161
-		return err, nil, nil
152
+func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config, provider *oidc.Provider) {
153
+	//err, oauthInfo, oauthConfig = os.getOauthConfigGeneral(op)
154
+	oauthInfo = os.InfoByOp(op)
155
+	if oauthInfo.Id == 0 || oauthInfo.ClientId == "" || oauthInfo.ClientSecret == "" {
156
+		return errors.New("ConfigNotFound"), nil, nil, nil
162
 	}
157
 	}
158
+	// If the redirect URL is empty, use the default redirect URL
159
+	if oauthInfo.RedirectUrl == "" {
160
+		oauthInfo.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback"
161
+	}
162
+	oauthConfig = &oauth2.Config{
163
+		ClientID:     oauthInfo.ClientId,
164
+		ClientSecret: oauthInfo.ClientSecret,
165
+		RedirectURL:  oauthInfo.RedirectUrl,
166
+	}
167
+
163
 	// Maybe should validate the oauthConfig here
168
 	// Maybe should validate the oauthConfig here
164
 	oauthType := oauthInfo.OauthType
169
 	oauthType := oauthInfo.OauthType
165
 	err = model.ValidateOauthType(oauthType)
170
 	err = model.ValidateOauthType(oauthType)
166
 	if err != nil {
171
 	if err != nil {
167
-		return err, nil, nil
172
+		return err, nil, nil, nil
168
 	}
173
 	}
169
 	switch oauthType {
174
 	switch oauthType {
170
 	case model.OauthTypeGithub:
175
 	case model.OauthTypeGithub:
171
 		oauthConfig.Endpoint = github.Endpoint
176
 		oauthConfig.Endpoint = github.Endpoint
172
 		oauthConfig.Scopes = []string{"read:user", "user:email"}
177
 		oauthConfig.Scopes = []string{"read:user", "user:email"}
178
+		provider = os.GithubProvider()
179
+	//case model.OauthTypeGoogle: //google单独出来,可以少一次FetchOidcEndpoint请求
180
+	//	oauthConfig.Endpoint = google.Endpoint
181
+	//	oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
173
 	case model.OauthTypeOidc, model.OauthTypeGoogle:
182
 	case model.OauthTypeOidc, model.OauthTypeGoogle:
174
-		var endpoint OidcEndpoint
175
-		err, endpoint = os.FetchOidcEndpoint(oauthInfo.Issuer)
183
+		err, provider = os.FetchOidcProvider(oauthInfo.Issuer)
176
 		if err != nil {
184
 		if err != nil {
177
-			return err, nil, nil
185
+			return err, nil, nil, nil
178
 		}
186
 		}
179
-		oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL, TokenURL: endpoint.TokenURL}
187
+		oauthConfig.Endpoint = provider.Endpoint()
180
 		oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
188
 		oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
181
 	default:
189
 	default:
182
-		return errors.New("unsupported OAuth type"), nil, nil
183
-	}
184
-	return nil, oauthInfo, oauthConfig
185
-}
186
-
187
-// GetOauthConfig retrieves the OAuth2 configuration based on the provider name
188
-func (os *OauthService) getOauthConfigGeneral(op string) (err error, oauthInfo *model.Oauth, oauthConfig *oauth2.Config) {
189
-	oauthInfo = os.InfoByOp(op)
190
-	if oauthInfo.Id == 0 || oauthInfo.ClientId == "" || oauthInfo.ClientSecret == "" {
191
-		return errors.New("ConfigNotFound"), nil, nil
192
-	}
193
-	// If the redirect URL is empty, use the default redirect URL
194
-	if oauthInfo.RedirectUrl == "" {
195
-		oauthInfo.RedirectUrl = global.Config.Rustdesk.ApiServer + "/api/oidc/callback"
196
-	}
197
-	return nil, oauthInfo, &oauth2.Config{
198
-		ClientID:     oauthInfo.ClientId,
199
-		ClientSecret: oauthInfo.ClientSecret,
200
-		RedirectURL:  oauthInfo.RedirectUrl,
190
+		return errors.New("unsupported OAuth type"), nil, nil, nil
201
 	}
191
 	}
192
+	return nil, oauthInfo, oauthConfig, provider
202
 }
193
 }
203
 
194
 
204
 func getHTTPClientWithProxy() *http.Client {
195
 func getHTTPClientWithProxy() *http.Client {
205
-	//todo add timeout
196
+	//add timeout 30s
197
+	timeout := time.Duration(60) * time.Second
206
 	if global.Config.Proxy.Enable {
198
 	if global.Config.Proxy.Enable {
207
 		if global.Config.Proxy.Host == "" {
199
 		if global.Config.Proxy.Host == "" {
208
 			global.Logger.Warn("Proxy is enabled but proxy host is empty.")
200
 			global.Logger.Warn("Proxy is enabled but proxy host is empty.")
@@ -216,33 +208,58 @@ func getHTTPClientWithProxy() *http.Client {
216
 		transport := &http.Transport{
208
 		transport := &http.Transport{
217
 			Proxy: http.ProxyURL(proxyURL),
209
 			Proxy: http.ProxyURL(proxyURL),
218
 		}
210
 		}
219
-		return &http.Client{Transport: transport}
211
+		return &http.Client{Transport: transport, Timeout: timeout}
220
 	}
212
 	}
221
 	return http.DefaultClient
213
 	return http.DefaultClient
222
 }
214
 }
223
-
224
-func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, verifier string, userEndpoint string, userData interface{}) (err error, client *http.Client) {
215
+func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, provider *oidc.Provider, code string, verifier string, nonce string, userData interface{}) (err error, client *http.Client) {
225
 
216
 
226
 	// 设置代理客户端
217
 	// 设置代理客户端
227
 	httpClient := getHTTPClientWithProxy()
218
 	httpClient := getHTTPClientWithProxy()
228
 	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
219
 	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
229
 
220
 
230
-	var exchangeOpts []oauth2.AuthCodeOption
221
+	exchangeOpts := make([]oauth2.AuthCodeOption, 0, 1)
231
 	if verifier != "" {
222
 	if verifier != "" {
232
-		exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(verifier)}
223
+		exchangeOpts = append(exchangeOpts, oauth2.VerifierOption(verifier))
233
 	}
224
 	}
234
 
225
 
235
-	// 使用 code 换取 token
236
-	var token *oauth2.Token
237
-	token, err = oauthConfig.Exchange(ctx, code, exchangeOpts...)
226
+	token, err := oauthConfig.Exchange(ctx, code, exchangeOpts...)
227
+
238
 	if err != nil {
228
 	if err != nil {
239
 		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
229
 		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
240
 		return errors.New("GetOauthTokenError"), nil
230
 		return errors.New("GetOauthTokenError"), nil
241
 	}
231
 	}
242
 
232
 
233
+	// 获取 ID Token, github没有id_token
234
+	rawIDToken, ok := token.Extra("id_token").(string)
235
+	if ok && rawIDToken != "" {
236
+		// 验证 ID Token
237
+		v := provider.Verifier(&oidc.Config{ClientID: oauthConfig.ClientID})
238
+		idToken, err2 := v.Verify(ctx, rawIDToken)
239
+		if err2 != nil {
240
+			global.Logger.Warn("IdTokenVerifyError: ", err2)
241
+			return errors.New("IdTokenVerifyError"), nil
242
+		}
243
+		if nonce != "" {
244
+			// 验证 nonce
245
+			var claims struct {
246
+				Nonce string `json:"nonce"`
247
+			}
248
+			if err2 = idToken.Claims(&claims); err2 != nil {
249
+				global.Logger.Warn("Failed to parse ID Token claims: ", err)
250
+				return errors.New("IDTokenClaimsError"), nil
251
+			}
252
+
253
+			if claims.Nonce != nonce {
254
+				global.Logger.Warn("Nonce does not match")
255
+				return errors.New("NonceDoesNotMatch"), nil
256
+			}
257
+		}
258
+	}
259
+
243
 	// 获取用户信息
260
 	// 获取用户信息
244
 	client = oauthConfig.Client(ctx, token)
261
 	client = oauthConfig.Client(ctx, token)
245
-	resp, err := client.Get(userEndpoint)
262
+	resp, err := client.Get(provider.UserInfoEndpoint())
246
 	if err != nil {
263
 	if err != nil {
247
 		global.Logger.Warn("failed getting user info: ", err)
264
 		global.Logger.Warn("failed getting user info: ", err)
248
 		return errors.New("GetOauthUserInfoError"), nil
265
 		return errors.New("GetOauthUserInfoError"), nil
@@ -263,9 +280,9 @@ func (os *OauthService) callbackBase(oauthConfig *oauth2.Config, code string, ve
263
 }
280
 }
264
 
281
 
265
 // githubCallback github回调
282
 // githubCallback github回调
266
-func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string, verifier string) (error, *model.OauthUser) {
283
+func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code string, verifier string) (error, *model.OauthUser) {
267
 	var user = &model.GithubUser{}
284
 	var user = &model.GithubUser{}
268
-	err, client := os.callbackBase(oauthConfig, code, verifier, model.UserEndpointGithub, user)
285
+	err, client := os.callbackBase(oauthConfig, provider, code, verifier, "", user)
269
 	if err != nil {
286
 	if err != nil {
270
 		return err, nil
287
 		return err, nil
271
 	}
288
 	}
@@ -277,9 +294,9 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string,
277
 }
294
 }
278
 
295
 
279
 // oidcCallback oidc回调, 通过code获取用户信息
296
 // oidcCallback oidc回调, 通过code获取用户信息
280
-func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, verifier string, userInfoEndpoint string) (error, *model.OauthUser) {
297
+func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, provider *oidc.Provider, code string, verifier string) (error, *model.OauthUser) {
281
 	var user = &model.OidcUser{}
298
 	var user = &model.OidcUser{}
282
-	if err, _ := os.callbackBase(oauthConfig, code, verifier, userInfoEndpoint, user); err != nil {
299
+	if err, _ := os.callbackBase(oauthConfig, provider, code, verifier, "", user); err != nil {
283
 		return err, nil
300
 		return err, nil
284
 	}
301
 	}
285
 	return nil, user.ToOauthUser()
302
 	return nil, user.ToOauthUser()
@@ -287,9 +304,7 @@ func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, ve
287
 
304
 
288
 // Callback: Get user information by code and op(Oauth provider)
305
 // Callback: Get user information by code and op(Oauth provider)
289
 func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUser *model.OauthUser) {
306
 func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUser *model.OauthUser) {
290
-	var oauthInfo *model.Oauth
291
-	var oauthConfig *oauth2.Config
292
-	err, oauthInfo, oauthConfig = os.GetOauthConfig(op)
307
+	err, oauthInfo, oauthConfig, provider := os.GetOauthConfig(op)
293
 	// oauthType is already validated in GetOauthConfig
308
 	// oauthType is already validated in GetOauthConfig
294
 	if err != nil {
309
 	if err != nil {
295
 		return err, nil
310
 		return err, nil
@@ -297,13 +312,9 @@ func (os *OauthService) Callback(code, verifier, op string) (err error, oauthUse
297
 	oauthType := oauthInfo.OauthType
312
 	oauthType := oauthInfo.OauthType
298
 	switch oauthType {
313
 	switch oauthType {
299
 	case model.OauthTypeGithub:
314
 	case model.OauthTypeGithub:
300
-		err, oauthUser = os.githubCallback(oauthConfig, code, verifier)
315
+		err, oauthUser = os.githubCallback(oauthConfig, provider, code, verifier)
301
 	case model.OauthTypeOidc, model.OauthTypeGoogle:
316
 	case model.OauthTypeOidc, model.OauthTypeGoogle:
302
-		err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
303
-		if err != nil {
304
-			return err, nil
305
-		}
306
-		err, oauthUser = os.oidcCallback(oauthConfig, code, verifier, endpoint.UserInfo)
317
+		err, oauthUser = os.oidcCallback(oauthConfig, provider, code, verifier)
307
 	default:
318
 	default:
308
 		return errors.New("unsupported OAuth type"), nil
319
 		return errors.New("unsupported OAuth type"), nil
309
 	}
320
 	}