Просмотр исходного кода

Merge pull request #17 from Ogannesson/master

Add proxy option for Google Oauthon
1 год назад
Родитель
Сommit
dc8fcdf214
4 измененных файлов с 55 добавлено и 9 удалено
  1. 4 1
      conf/config.yaml
  2. 1 0
      config/config.go
  3. 6 0
      config/proxy.go
  4. 44 8
      service/oauth.go

+ 4 - 1
conf/config.yaml

@@ -44,4 +44,7 @@ oss:
44
   max-byte: 10240
44
   max-byte: 10240
45
 jwt:
45
 jwt:
46
   private-key: "./conf/jwt_pri.pem"
46
   private-key: "./conf/jwt_pri.pem"
47
-  expire-duration: 360000
47
+  expire-duration: 360000
48
+proxy:
49
+  enable: false
50
+  host: ""

+ 1 - 0
config/config.go

@@ -30,6 +30,7 @@ type Config struct {
30
 	Oss      Oss
30
 	Oss      Oss
31
 	Jwt      Jwt
31
 	Jwt      Jwt
32
 	Rustdesk Rustdesk
32
 	Rustdesk Rustdesk
33
+	Proxy    Proxy
33
 }
34
 }
34
 
35
 
35
 // Init 初始化配置
36
 // Init 初始化配置

+ 6 - 0
config/proxy.go

@@ -0,0 +1,6 @@
1
+package config
2
+
3
+type Proxy struct {
4
+	Enable bool   `mapstructure:"enable"`
5
+	Host   string `mapstructure:"host"`
6
+}

+ 44 - 8
service/oauth.go

@@ -12,6 +12,8 @@ import (
12
 	"golang.org/x/oauth2/google"
12
 	"golang.org/x/oauth2/google"
13
 	"gorm.io/gorm"
13
 	"gorm.io/gorm"
14
 	"io"
14
 	"io"
15
+	"net/http"
16
+	"net/url"
15
 	"strconv"
17
 	"strconv"
16
 	"sync"
18
 	"sync"
17
 	"time"
19
 	"time"
@@ -166,20 +168,44 @@ func (os *OauthService) GetOauthConfig(op string) (error, *oauth2.Config) {
166
 	return errors.New("ConfigNotFound"), nil
168
 	return errors.New("ConfigNotFound"), nil
167
 }
169
 }
168
 
170
 
171
+func getHTTPClientWithProxy() *http.Client {
172
+	if global.Config.Proxy.Enable {
173
+		if global.Config.Proxy.Host == "" {
174
+			global.Logger.Warn("Proxy is enabled but proxy host is empty.")
175
+			return http.DefaultClient
176
+		}
177
+		proxyURL, err := url.Parse(global.Config.Proxy.Host)
178
+		if err != nil {
179
+			global.Logger.Warn("Invalid proxy URL: ", err)
180
+			return http.DefaultClient
181
+		}
182
+		transport := &http.Transport{
183
+			Proxy: http.ProxyURL(proxyURL),
184
+		}
185
+		return &http.Client{Transport: transport}
186
+	}
187
+	return http.DefaultClient
188
+}
189
+
169
 func (os *OauthService) GithubCallback(code string) (error error, userData *GithubUserdata) {
190
 func (os *OauthService) GithubCallback(code string) (error error, userData *GithubUserdata) {
170
 	err, oauthConfig := os.GetOauthConfig(model.OauthTypeGithub)
191
 	err, oauthConfig := os.GetOauthConfig(model.OauthTypeGithub)
171
 	if err != nil {
192
 	if err != nil {
172
 		return err, nil
193
 		return err, nil
173
 	}
194
 	}
174
-	token, err := oauthConfig.Exchange(context.Background(), code)
195
+
196
+	// 使用代理配置创建 HTTP 客户端
197
+	httpClient := getHTTPClientWithProxy()
198
+	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
199
+
200
+	token, err := oauthConfig.Exchange(ctx, code)
175
 	if err != nil {
201
 	if err != nil {
176
 		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
202
 		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
177
 		error = errors.New("GetOauthTokenError")
203
 		error = errors.New("GetOauthTokenError")
178
 		return
204
 		return
179
 	}
205
 	}
180
 
206
 
181
-	// 创建一个 HTTP 客户端,并将 access_token 添加到 Authorization 头中
182
-	client := oauthConfig.Client(context.Background(), token)
207
+	// 使用带有代理的 HTTP 客户端获取用户信息
208
+	client := oauthConfig.Client(ctx, token)
183
 	resp, err := client.Get("https://api.github.com/user")
209
 	resp, err := client.Get("https://api.github.com/user")
184
 	if err != nil {
210
 	if err != nil {
185
 		global.Logger.Warn("failed getting user info: ", err)
211
 		global.Logger.Warn("failed getting user info: ", err)
@@ -193,7 +219,7 @@ func (os *OauthService) GithubCallback(code string) (error error, userData *Gith
193
 		}
219
 		}
194
 	}(resp.Body)
220
 	}(resp.Body)
195
 
221
 
196
-	// 在这里处理 GitHub 用户信息
222
+	// 解析用户信息
197
 	if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil {
223
 	if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil {
198
 		global.Logger.Warn("failed decoding user info: ", err)
224
 		global.Logger.Warn("failed decoding user info: ", err)
199
 		error = errors.New("DecodeOauthUserInfoError")
225
 		error = errors.New("DecodeOauthUserInfoError")
@@ -204,14 +230,23 @@ func (os *OauthService) GithubCallback(code string) (error error, userData *Gith
204
 
230
 
205
 func (os *OauthService) GoogleCallback(code string) (error error, userData *GoogleUserdata) {
231
 func (os *OauthService) GoogleCallback(code string) (error error, userData *GoogleUserdata) {
206
 	err, oauthConfig := os.GetOauthConfig(model.OauthTypeGoogle)
232
 	err, oauthConfig := os.GetOauthConfig(model.OauthTypeGoogle)
207
-	token, err := oauthConfig.Exchange(context.Background(), code)
233
+	if err != nil {
234
+		return err, nil
235
+	}
236
+
237
+	// 使用代理配置创建 HTTP 客户端
238
+	httpClient := getHTTPClientWithProxy()
239
+	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
240
+
241
+	token, err := oauthConfig.Exchange(ctx, code)
208
 	if err != nil {
242
 	if err != nil {
209
 		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
243
 		global.Logger.Warn("oauthConfig.Exchange() failed: ", err)
210
 		error = errors.New("GetOauthTokenError")
244
 		error = errors.New("GetOauthTokenError")
211
 		return
245
 		return
212
 	}
246
 	}
213
-	// 创建 HTTP 客户端,并将 access_token 添加到 Authorization 头中
214
-	client := oauthConfig.Client(context.Background(), token)
247
+
248
+	// 使用带有代理的 HTTP 客户端获取用户信息
249
+	client := oauthConfig.Client(ctx, token)
215
 	resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
250
 	resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
216
 	if err != nil {
251
 	if err != nil {
217
 		global.Logger.Warn("failed getting user info: ", err)
252
 		global.Logger.Warn("failed getting user info: ", err)
@@ -225,8 +260,9 @@ func (os *OauthService) GoogleCallback(code string) (error error, userData *Goog
225
 		}
260
 		}
226
 	}(resp.Body)
261
 	}(resp.Body)
227
 
262
 
263
+	// 解析用户信息
228
 	if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil {
264
 	if err = json.NewDecoder(resp.Body).Decode(&userData); err != nil {
229
-		global.Logger.Warn("failed decoding user info: %s\n", err)
265
+		global.Logger.Warn("failed decoding user info: ", err)
230
 		error = errors.New("DecodeOauthUserInfoError")
266
 		error = errors.New("DecodeOauthUserInfoError")
231
 		return
267
 		return
232
 	}
268
 	}