Просмотр исходного кода

feat!: Add JWT

- `RUSTDESK_API_JWT_KEY`如果设置,将会启用JWT,token自动续期功能将失效
- 此功能是为了server端校验token的合法性
lejianwen 1 год назад
Родитель
Сommit
edb095ab0b
7 измененных файлов с 39 добавлено и 22 удалено
  1. 3 1
      cmd/apimain.go
  2. 4 3
      conf/config.yaml
  3. 0 0
      conf/jwt_pri.pem
  4. 1 1
      config/jwt.go
  5. 16 1
      http/middleware/rustauth.go
  6. 8 16
      lib/jwt/jwt.go
  7. 7 0
      service/user.go

+ 3 - 1
cmd/apimain.go

@@ -5,6 +5,7 @@ import (
5
 	"Gwen/global"
5
 	"Gwen/global"
6
 	"Gwen/http"
6
 	"Gwen/http"
7
 	"Gwen/lib/cache"
7
 	"Gwen/lib/cache"
8
+	"Gwen/lib/jwt"
8
 	"Gwen/lib/lock"
9
 	"Gwen/lib/lock"
9
 	"Gwen/lib/logger"
10
 	"Gwen/lib/logger"
10
 	"Gwen/lib/orm"
11
 	"Gwen/lib/orm"
@@ -17,6 +18,7 @@ import (
17
 	"github.com/spf13/cobra"
18
 	"github.com/spf13/cobra"
18
 	"os"
19
 	"os"
19
 	"strconv"
20
 	"strconv"
21
+	"time"
20
 )
22
 )
21
 
23
 
22
 // @title 管理系统API
24
 // @title 管理系统API
@@ -163,7 +165,7 @@ func InitGlobal() {
163
 
165
 
164
 	//jwt
166
 	//jwt
165
 	//fmt.Println(global.Config.Jwt.PrivateKey)
167
 	//fmt.Println(global.Config.Jwt.PrivateKey)
166
-	//global.Jwt = jwt.NewJwt(global.Config.Jwt.PrivateKey, global.Config.Jwt.ExpireDuration*time.Second)
168
+	global.Jwt = jwt.NewJwt(global.Config.Jwt.Key, global.Config.Jwt.ExpireDuration*time.Second)
167
 
169
 
168
 	//locker
170
 	//locker
169
 	global.Lock = lock.NewLocal()
171
 	global.Lock = lock.NewLocal()

+ 4 - 3
conf/config.yaml

@@ -36,6 +36,9 @@ logger:
36
 proxy:
36
 proxy:
37
   enable: false
37
   enable: false
38
   host: "http://127.0.0.1:1080"
38
   host: "http://127.0.0.1:1080"
39
+jwt:
40
+  key: ""
41
+  expire-duration: 360000
39
 redis:
42
 redis:
40
   addr: "127.0.0.1:6379"
43
   addr: "127.0.0.1:6379"
41
   password: ""
44
   password: ""
@@ -53,6 +56,4 @@ oss:
53
   callback-url: ""
56
   callback-url: ""
54
   expire-time: 30
57
   expire-time: 30
55
   max-byte: 10240
58
   max-byte: 10240
56
-jwt:
57
-  private-key: "./conf/jwt_pri.pem"
58
-  expire-duration: 360000
59
+

+ 0 - 0
conf/jwt_pri.pem


+ 1 - 1
config/jwt.go

@@ -3,6 +3,6 @@ package config
3
 import "time"
3
 import "time"
4
 
4
 
5
 type Jwt struct {
5
 type Jwt struct {
6
-	PrivateKey     string        `mapstructure:"private-key"`
6
+	Key            string        `mapstructure:"key"`
7
 	ExpireDuration time.Duration `mapstructure:"expire-duration"`
7
 	ExpireDuration time.Duration `mapstructure:"expire-duration"`
8
 }
8
 }

+ 16 - 1
http/middleware/rustauth.go

@@ -1,6 +1,7 @@
1
 package middleware
1
 package middleware
2
 
2
 
3
 import (
3
 import (
4
+	"Gwen/global"
4
 	"Gwen/service"
5
 	"Gwen/service"
5
 	"github.com/gin-gonic/gin"
6
 	"github.com/gin-gonic/gin"
6
 )
7
 )
@@ -27,7 +28,21 @@ func RustAuth() gin.HandlerFunc {
27
 		//提取token,格式是Bearer {token}
28
 		//提取token,格式是Bearer {token}
28
 		//这里只是简单的提取
29
 		//这里只是简单的提取
29
 		token = token[7:]
30
 		token = token[7:]
31
+
30
 		//验证token
32
 		//验证token
33
+
34
+		//检查是否设置了jwt key
35
+		if global.Config.Jwt.Key != "" {
36
+			uid, _ := service.AllService.UserService.VerifyJWT(token)
37
+			if uid == 0 {
38
+				c.JSON(401, gin.H{
39
+					"error": "Unauthorized",
40
+				})
41
+				c.Abort()
42
+				return
43
+			}
44
+		}
45
+
31
 		user, ut := service.AllService.UserService.InfoByAccessToken(token)
46
 		user, ut := service.AllService.UserService.InfoByAccessToken(token)
32
 		if user.Id == 0 {
47
 		if user.Id == 0 {
33
 			c.JSON(401, gin.H{
48
 			c.JSON(401, gin.H{
@@ -38,7 +53,7 @@ func RustAuth() gin.HandlerFunc {
38
 		}
53
 		}
39
 		if !service.AllService.UserService.CheckUserEnable(user) {
54
 		if !service.AllService.UserService.CheckUserEnable(user) {
40
 			c.JSON(401, gin.H{
55
 			c.JSON(401, gin.H{
41
-				"error": "账号已被禁用",
56
+				"error": "Unauthorized",
42
 			})
57
 			})
43
 			c.Abort()
58
 			c.Abort()
44
 			return
59
 			return

+ 8 - 16
lib/jwt/jwt.go

@@ -1,14 +1,13 @@
1
 package jwt
1
 package jwt
2
 
2
 
3
 import (
3
 import (
4
-	"crypto/rsa"
4
+	"fmt"
5
 	"github.com/golang-jwt/jwt/v5"
5
 	"github.com/golang-jwt/jwt/v5"
6
-	"os"
7
 	"time"
6
 	"time"
8
 )
7
 )
9
 
8
 
10
 type Jwt struct {
9
 type Jwt struct {
11
-	privateKey          *rsa.PrivateKey
10
+	Key                 []byte
12
 	TokenExpireDuration time.Duration
11
 	TokenExpireDuration time.Duration
13
 }
12
 }
14
 
13
 
@@ -17,31 +16,24 @@ type UserClaims struct {
17
 	jwt.RegisteredClaims
16
 	jwt.RegisteredClaims
18
 }
17
 }
19
 
18
 
20
-func NewJwt(privateKeyFile string, tokenExpireDuration time.Duration) *Jwt {
21
-	privateKeyContent, err := os.ReadFile(privateKeyFile)
22
-	if err != nil {
23
-		panic(err)
24
-	}
25
-	privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(privateKeyContent)
26
-	if err != nil {
27
-		panic(err)
28
-	}
19
+func NewJwt(key string, tokenExpireDuration time.Duration) *Jwt {
29
 	return &Jwt{
20
 	return &Jwt{
30
-		privateKey:          privateKey,
21
+		Key:                 []byte(key),
31
 		TokenExpireDuration: tokenExpireDuration,
22
 		TokenExpireDuration: tokenExpireDuration,
32
 	}
23
 	}
33
 }
24
 }
34
 
25
 
35
 func (s *Jwt) GenerateToken(userId uint) string {
26
 func (s *Jwt) GenerateToken(userId uint) string {
36
-	t := jwt.NewWithClaims(jwt.SigningMethodRS256,
27
+	t := jwt.NewWithClaims(jwt.SigningMethodHS256,
37
 		UserClaims{
28
 		UserClaims{
38
 			UserId: userId,
29
 			UserId: userId,
39
 			RegisteredClaims: jwt.RegisteredClaims{
30
 			RegisteredClaims: jwt.RegisteredClaims{
40
 				ExpiresAt: jwt.NewNumericDate(time.Now().Add(s.TokenExpireDuration)),
31
 				ExpiresAt: jwt.NewNumericDate(time.Now().Add(s.TokenExpireDuration)),
41
 			},
32
 			},
42
 		})
33
 		})
43
-	token, err := t.SignedString(s.privateKey)
34
+	token, err := t.SignedString(s.Key)
44
 	if err != nil {
35
 	if err != nil {
36
+		fmt.Println(err)
45
 		return ""
37
 		return ""
46
 	}
38
 	}
47
 	return token
39
 	return token
@@ -49,7 +41,7 @@ func (s *Jwt) GenerateToken(userId uint) string {
49
 
41
 
50
 func (s *Jwt) ParseToken(tokenString string) (uint, error) {
42
 func (s *Jwt) ParseToken(tokenString string) (uint, error) {
51
 	token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
43
 	token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
52
-		return s.privateKey.Public(), nil
44
+		return s.Key, nil
53
 	})
45
 	})
54
 	if err != nil {
46
 	if err != nil {
55
 		return 0, err
47
 		return 0, err

+ 7 - 0
service/user.go

@@ -68,6 +68,9 @@ func (us *UserService) InfoByAccessToken(token string) (*model.User, *model.User
68
 
68
 
69
 // GenerateToken 生成token
69
 // GenerateToken 生成token
70
 func (us *UserService) GenerateToken(u *model.User) string {
70
 func (us *UserService) GenerateToken(u *model.User) string {
71
+	if global.Config.Jwt.Key != "" {
72
+		return global.Jwt.GenerateToken(u.Id)
73
+	}
71
 	return utils.Md5(u.Username + time.Now().String())
74
 	return utils.Md5(u.Username + time.Now().String())
72
 }
75
 }
73
 
76
 
@@ -461,3 +464,7 @@ func (us *UserService) AutoRefreshAccessToken(ut *model.UserToken) {
461
 func (us *UserService) BatchDeleteUserToken(ids []uint) error {
464
 func (us *UserService) BatchDeleteUserToken(ids []uint) error {
462
 	return global.DB.Where("id in ?", ids).Delete(&model.UserToken{}).Error
465
 	return global.DB.Where("id in ?", ids).Delete(&model.UserToken{}).Error
463
 }
466
 }
467
+
468
+func (us *UserService) VerifyJWT(token string) (uint, error) {
469
+	return global.Jwt.ParseToken(token)
470
+}