ljw 1 год назад
Родитель
Сommit
daeae19194
9 измененных файлов с 170 добавлено и 175 удалено
  1. 16 1
      cmd/apimain.go
  2. 8 5
      http/controller/admin/oauth.go
  3. 1 1
      http/controller/api/ouath.go
  4. 8 8
      http/request/admin/user.go
  5. 37 54
      model/oauth.go
  6. 9 14
      model/user.go
  7. 11 10
      model/userThird.go
  8. 38 42
      service/oauth.go
  9. 42 40
      service/user.go

+ 16 - 1
cmd/apimain.go

@@ -101,7 +101,7 @@ func main() {
101
 }
101
 }
102
 
102
 
103
 func DatabaseAutoUpdate() {
103
 func DatabaseAutoUpdate() {
104
-	version := 244
104
+	version := 245
105
 
105
 
106
 	db := global.DB
106
 	db := global.DB
107
 
107
 
@@ -146,6 +146,21 @@ func DatabaseAutoUpdate() {
146
 		if v.Version < uint(version) {
146
 		if v.Version < uint(version) {
147
 			Migrate(uint(version))
147
 			Migrate(uint(version))
148
 		}
148
 		}
149
+		// 245迁移
150
+		if v.Version < 245 {
151
+			//oauths 表的 oauth_type 字段设置为 op同样的值
152
+			db.Exec("update oauths set oauth_type = op")
153
+			db.Exec("update oauths set issuer = 'https://accounts.google.com' where op = 'google' and issuer = ''")
154
+			db.Exec("update user_thirds set oauth_type = third_type, op = third_type")
155
+			//通过email迁移旧的google授权
156
+			uts := make([]model.UserThird, 0)
157
+			db.Where("oauth_type = ?", "google").Find(&uts)
158
+			for _, ut := range uts {
159
+				if ut.UserId > 0 {
160
+					db.Model(&model.User{}).Where("id = ?", ut.UserId).Update("email", ut.OpenId)
161
+				}
162
+			}
163
+		}
149
 	}
164
 	}
150
 
165
 
151
 }
166
 }

+ 8 - 5
http/controller/admin/oauth.go

@@ -180,15 +180,18 @@ func (o *Oauth) Create(c *gin.Context) {
180
 		response.Fail(c, 101, errList[0])
180
 		response.Fail(c, 101, errList[0])
181
 		return
181
 		return
182
 	}
182
 	}
183
-
184
-	ex := service.AllService.OauthService.InfoByOp(f.Op)
183
+	u := f.ToOauth()
184
+	err := u.FormatOauthInfo()
185
+	if err != nil {
186
+		response.Fail(c, 101, response.TranslateMsg(c, "ParamsError")+err.Error())
187
+		return
188
+	}
189
+	ex := service.AllService.OauthService.InfoByOp(u.Op)
185
 	if ex.Id > 0 {
190
 	if ex.Id > 0 {
186
 		response.Fail(c, 101, response.TranslateMsg(c, "ItemExists"))
191
 		response.Fail(c, 101, response.TranslateMsg(c, "ItemExists"))
187
 		return
192
 		return
188
 	}
193
 	}
189
-
190
-	u := f.ToOauth()
191
-	err := service.AllService.OauthService.Create(u)
194
+	err = service.AllService.OauthService.Create(u)
192
 	if err != nil {
195
 	if err != nil {
193
 		response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
196
 		response.Fail(c, 101, response.TranslateMsg(c, "OperationFailed")+err.Error())
194
 		return
197
 		return

+ 1 - 1
http/controller/api/ouath.go

@@ -217,7 +217,7 @@ func (o *Oauth) OauthCallback(c *gin.Context) {
217
 		oauthCache.UserId = user.Id
217
 		oauthCache.UserId = user.Id
218
 		oauthService.SetOauthCache(cacheKey, oauthCache, 0)
218
 		oauthService.SetOauthCache(cacheKey, oauthCache, 0)
219
 		// 如果是webadmin,登录成功后跳转到webadmin
219
 		// 如果是webadmin,登录成功后跳转到webadmin
220
-		if oauthCache.DeviceType == "webadmin" {
220
+		if oauthCache.DeviceType == model.LoginLogClientWebAdmin {
221
 			/*service.AllService.UserService.Login(u, &model.LoginLog{
221
 			/*service.AllService.UserService.Login(u, &model.LoginLog{
222
 				UserId:   u.Id,
222
 				UserId:   u.Id,
223
 				Client:   "webadmin",
223
 				Client:   "webadmin",

+ 8 - 8
http/request/admin/user.go

@@ -5,15 +5,15 @@ import (
5
 )
5
 )
6
 
6
 
7
 type UserForm struct {
7
 type UserForm struct {
8
-	Id       uint   			`json:"id"`
9
-	Username string 			`json:"username" validate:"required,gte=4,lte=10"`
10
-	Email	 string           	`json:"email" validate:"required,email"`
8
+	Id       uint   `json:"id"`
9
+	Username string `json:"username" validate:"required,gte=4,lte=10"`
10
+	Email    string `json:"email"` //validate:"required,email" email不强制
11
 	//Password string           `json:"password" validate:"required,gte=4,lte=20"`
11
 	//Password string           `json:"password" validate:"required,gte=4,lte=20"`
12
-	Nickname string           	`json:"nickname"`
13
-	Avatar   string           	`json:"avatar"`
14
-	GroupId  uint             	`json:"group_id" validate:"required"`
15
-	IsAdmin  *bool            	`json:"is_admin" `
16
-	Status   model.StatusCode 	`json:"status" validate:"required,gte=0"`
12
+	Nickname string           `json:"nickname"`
13
+	Avatar   string           `json:"avatar"`
14
+	GroupId  uint             `json:"group_id" validate:"required"`
15
+	IsAdmin  *bool            `json:"is_admin" `
16
+	Status   model.StatusCode `json:"status" validate:"required,gte=0"`
17
 }
17
 }
18
 
18
 
19
 func (uf *UserForm) FromUser(user *model.User) *UserForm {
19
 func (uf *UserForm) FromUser(user *model.User) *UserForm {

+ 37 - 54
model/oauth.go

@@ -1,9 +1,9 @@
1
 package model
1
 package model
2
 
2
 
3
 import (
3
 import (
4
+	"errors"
4
 	"strconv"
5
 	"strconv"
5
 	"strings"
6
 	"strings"
6
-	"errors"
7
 )
7
 )
8
 
8
 
9
 const OIDC_DEFAULT_SCOPES = "openid,profile,email"
9
 const OIDC_DEFAULT_SCOPES = "openid,profile,email"
@@ -27,32 +27,23 @@ func ValidateOauthType(oauthType string) error {
27
 }
27
 }
28
 
28
 
29
 const (
29
 const (
30
-	OauthNameGithub  string = "GitHub"
31
-	OauthNameGoogle  string = "Google"
32
-	OauthNameOidc    string = "OIDC"
33
-	OauthNameWebauth string = "WebAuth"
34
-)
35
-
36
-const (
37
-	UserEndpointGithub  string = "https://api.github.com/user"
38
-	IssuerGoogle 		string = "https://accounts.google.com"
30
+	UserEndpointGithub string = "https://api.github.com/user"
31
+	IssuerGoogle       string = "https://accounts.google.com"
39
 )
32
 )
40
 
33
 
41
 type Oauth struct {
34
 type Oauth struct {
42
 	IdModel
35
 	IdModel
43
-	Op           string 	`json:"op"`
44
-	OauthType    string 	`json:"oauth_type"`
45
-	ClientId     string 	`json:"client_id"`
46
-	ClientSecret string 	`json:"client_secret"`
47
-	RedirectUrl  string 	`json:"redirect_url"`
48
-	AutoRegister *bool  	`json:"auto_register"`
49
-	Scopes       string 	`json:"scopes"`
50
-	Issuer	     string 	`json:"issuer"`
36
+	Op           string `json:"op"`
37
+	OauthType    string `json:"oauth_type"`
38
+	ClientId     string `json:"client_id"`
39
+	ClientSecret string `json:"client_secret"`
40
+	RedirectUrl  string `json:"redirect_url"`
41
+	AutoRegister *bool  `json:"auto_register"`
42
+	Scopes       string `json:"scopes"`
43
+	Issuer       string `json:"issuer"`
51
 	TimeModel
44
 	TimeModel
52
 }
45
 }
53
 
46
 
54
-
55
-
56
 // Helper function to format oauth info, it's used in the update and create method
47
 // Helper function to format oauth info, it's used in the update and create method
57
 func (oa *Oauth) FormatOauthInfo() error {
48
 func (oa *Oauth) FormatOauthInfo() error {
58
 	oauthType := strings.TrimSpace(oa.OauthType)
49
 	oauthType := strings.TrimSpace(oa.OauthType)
@@ -60,25 +51,20 @@ func (oa *Oauth) FormatOauthInfo() error {
60
 	if err != nil {
51
 	if err != nil {
61
 		return err
52
 		return err
62
 	}
53
 	}
54
+	switch oauthType {
55
+	case OauthTypeGithub:
56
+		oa.Op = OauthTypeGithub
57
+	case OauthTypeGoogle:
58
+		oa.Op = OauthTypeGoogle
59
+	}
63
 	// check if the op is empty, set the default value
60
 	// check if the op is empty, set the default value
64
 	op := strings.TrimSpace(oa.Op)
61
 	op := strings.TrimSpace(oa.Op)
65
-	if op == "" {
66
-		switch oauthType {
67
-		case OauthTypeGithub:
68
-			oa.Op = OauthNameGithub
69
-		case OauthTypeGoogle:
70
-			oa.Op = OauthNameGoogle
71
-		case OauthTypeOidc:
72
-			oa.Op = OauthNameOidc
73
-		case OauthTypeWebauth:
74
-			oa.Op = OauthNameWebauth
75
-		default:
76
-			oa.Op = oauthType
77
-		}
62
+	if op == "" && oauthType == OauthTypeOidc {
63
+		oa.Op = OauthTypeOidc
78
 	}
64
 	}
79
 	// check the issuer, if the oauth type is google and the issuer is empty, set the issuer to the default value
65
 	// check the issuer, if the oauth type is google and the issuer is empty, set the issuer to the default value
80
 	issuer := strings.TrimSpace(oa.Issuer)
66
 	issuer := strings.TrimSpace(oa.Issuer)
81
-	// If the oauth type is google and the issuer is empty, set the issuer to the default value 
67
+	// If the oauth type is google and the issuer is empty, set the issuer to the default value
82
 	if oauthType == OauthTypeGoogle && issuer == "" {
68
 	if oauthType == OauthTypeGoogle && issuer == "" {
83
 		oa.Issuer = IssuerGoogle
69
 		oa.Issuer = IssuerGoogle
84
 	}
70
 	}
@@ -86,12 +72,12 @@ func (oa *Oauth) FormatOauthInfo() error {
86
 }
72
 }
87
 
73
 
88
 type OauthUser struct {
74
 type OauthUser struct {
89
-	OpenId 			string 	`json:"open_id" gorm:"not null;index"`
90
-	Name   			string 	`json:"name"`
91
-	Username 		string 	`json:"username"`
92
-	Email  			string 	`json:"email"`
93
-	VerifiedEmail 	bool 	`json:"verified_email,omitempty"`
94
-	Picture			string 	`json:"picture,omitempty"`
75
+	OpenId        string `json:"open_id" gorm:"not null;index"`
76
+	Name          string `json:"name"`
77
+	Username      string `json:"username"`
78
+	Email         string `json:"email"`
79
+	VerifiedEmail bool   `json:"verified_email,omitempty"`
80
+	Picture       string `json:"picture,omitempty"`
95
 }
81
 }
96
 
82
 
97
 func (ou *OauthUser) ToUser(user *User, overideUsername bool) {
83
 func (ou *OauthUser) ToUser(user *User, overideUsername bool) {
@@ -122,7 +108,7 @@ func (ou *OidcUser) ToOauthUser() *OauthUser {
122
 	if ou.PreferredUsername != "" {
108
 	if ou.PreferredUsername != "" {
123
 		username = ou.PreferredUsername
109
 		username = ou.PreferredUsername
124
 	} else {
110
 	} else {
125
-		username = strings.ToLower(strings.Split(ou.Email, "@")[0])
111
+		username = strings.ToLower(ou.Email)
126
 	}
112
 	}
127
 
113
 
128
 	return &OauthUser{
114
 	return &OauthUser{
@@ -135,29 +121,26 @@ func (ou *OidcUser) ToOauthUser() *OauthUser {
135
 	}
121
 	}
136
 }
122
 }
137
 
123
 
138
-
139
 type GithubUser struct {
124
 type GithubUser struct {
140
 	OauthUserBase
125
 	OauthUserBase
141
-	Id                int         `json:"id"`
142
-	Login             string      `json:"login"`
143
-	AvatarUrl         string      `json:"avatar_url"`
144
-	VerifiedEmail	  bool        `json:"verified_email"`
126
+	Id            int    `json:"id"`
127
+	Login         string `json:"login"`
128
+	AvatarUrl     string `json:"avatar_url"`
129
+	VerifiedEmail bool   `json:"verified_email"`
145
 }
130
 }
146
 
131
 
147
 func (gu *GithubUser) ToOauthUser() *OauthUser {
132
 func (gu *GithubUser) ToOauthUser() *OauthUser {
148
 	username := strings.ToLower(gu.Login)
133
 	username := strings.ToLower(gu.Login)
149
 	return &OauthUser{
134
 	return &OauthUser{
150
-		OpenId: 		strconv.Itoa(gu.Id),
151
-		Name:   		gu.Name,
152
-		Username: 		username,
153
-		Email:  		gu.Email,
154
-		VerifiedEmail: 	gu.VerifiedEmail,
155
-		Picture:		gu.AvatarUrl,
135
+		OpenId:        strconv.Itoa(gu.Id),
136
+		Name:          gu.Name,
137
+		Username:      username,
138
+		Email:         gu.Email,
139
+		VerifiedEmail: gu.VerifiedEmail,
140
+		Picture:       gu.AvatarUrl,
156
 	}
141
 	}
157
 }
142
 }
158
 
143
 
159
-
160
-
161
 type OauthList struct {
144
 type OauthList struct {
162
 	Oauths []*Oauth `json:"list"`
145
 	Oauths []*Oauth `json:"list"`
163
 	Pagination
146
 	Pagination

+ 9 - 14
model/user.go

@@ -1,14 +1,9 @@
1
 package model
1
 package model
2
 
2
 
3
-import (
4
-	"fmt"
5
-	"gorm.io/gorm"
6
-)
7
-
8
 type User struct {
3
 type User struct {
9
 	IdModel
4
 	IdModel
10
-	Username string     `json:"username" gorm:"default:'';not null;uniqueIndex"`
11
-	Email	string     	`json:"email" gorm:"default:'';not null;uniqueIndex"`
5
+	Username string `json:"username" gorm:"default:'';not null;uniqueIndex"`
6
+	Email    string `json:"email" gorm:"default:'';not null;index"`
12
 	// Email	string     	`json:"email" `
7
 	// Email	string     	`json:"email" `
13
 	Password string     `json:"-" gorm:"default:'';not null;"`
8
 	Password string     `json:"-" gorm:"default:'';not null;"`
14
 	Nickname string     `json:"nickname" gorm:"default:'';not null;"`
9
 	Nickname string     `json:"nickname" gorm:"default:'';not null;"`
@@ -20,13 +15,13 @@ type User struct {
20
 }
15
 }
21
 
16
 
22
 // BeforeSave 钩子用于确保 email 字段有合理的默认值
17
 // BeforeSave 钩子用于确保 email 字段有合理的默认值
23
-func (u *User) BeforeSave(tx *gorm.DB) (err error) {
24
-    // 如果 email 为空,设置为默认值
25
-    if u.Email == "" {
26
-        u.Email = fmt.Sprintf("%s@example.com", u.Username)
27
-    }
28
-    return nil
29
-}
18
+//func (u *User) BeforeSave(tx *gorm.DB) (err error) {
19
+//	// 如果 email 为空,设置为默认值
20
+//	if u.Email == "" {
21
+//		u.Email = fmt.Sprintf("%s@example.com", u.Username)
22
+//	}
23
+//	return nil
24
+//}
30
 
25
 
31
 type UserList struct {
26
 type UserList struct {
32
 	Users []*User `json:"list,omitempty"`
27
 	Users []*User `json:"list,omitempty"`

+ 11 - 10
model/userThird.go

@@ -6,20 +6,21 @@ import (
6
 
6
 
7
 type UserThird struct {
7
 type UserThird struct {
8
 	IdModel
8
 	IdModel
9
-	UserId     		uint   `	json:"user_id" gorm:"not null;index"`
9
+	UserId uint `json:"user_id" gorm:"not null;index"`
10
 	OauthUser
10
 	OauthUser
11
-	// UnionId    		string `json:"union_id" gorm:"not null;"`
11
+	UnionId string `json:"union_id" gorm:"default:'';not null;"`
12
 	// OauthType  	   	string 		`json:"oauth_type" gorm:"not null;"`
12
 	// OauthType  	   	string 		`json:"oauth_type" gorm:"not null;"`
13
-	OauthType  	   	string 		`json:"oauth_type"`
14
-	Op  			string 		`json:"op" gorm:"not null;"`
13
+	ThirdType string `json:"third_type" gorm:"default:'';not null;"` //deprecated
14
+	OauthType string `json:"oauth_type" gorm:"default:'';not null;"`
15
+	Op        string `json:"op" gorm:"default:'';not null;"`
15
 	TimeModel
16
 	TimeModel
16
 }
17
 }
17
 
18
 
18
 func (u *UserThird) FromOauthUser(userId uint, oauthUser *OauthUser, oauthType string, op string) {
19
 func (u *UserThird) FromOauthUser(userId uint, oauthUser *OauthUser, oauthType string, op string) {
19
-	u.UserId 			= userId
20
-	u.OauthUser 		= *oauthUser
21
-	u.OauthType 		= oauthType
22
-	u.Op 				= op
20
+	u.UserId = userId
21
+	u.OauthUser = *oauthUser
22
+	u.OauthType = oauthType
23
+	u.Op = op
23
 	// make sure email is lower case
24
 	// make sure email is lower case
24
-	u.Email 			= strings.ToLower(u.Email)
25
-}
25
+	u.Email = strings.ToLower(u.Email)
26
+}

+ 38 - 42
service/oauth.go

@@ -12,16 +12,15 @@ import (
12
 	// "golang.org/x/oauth2/google"
12
 	// "golang.org/x/oauth2/google"
13
 	"gorm.io/gorm"
13
 	"gorm.io/gorm"
14
 	// "io"
14
 	// "io"
15
+	"fmt"
15
 	"net/http"
16
 	"net/http"
16
 	"net/url"
17
 	"net/url"
17
 	"strconv"
18
 	"strconv"
18
 	"strings"
19
 	"strings"
19
 	"sync"
20
 	"sync"
20
 	"time"
21
 	"time"
21
-	"fmt"
22
 )
22
 )
23
 
23
 
24
-
25
 type OauthService struct {
24
 type OauthService struct {
26
 }
25
 }
27
 
26
 
@@ -34,26 +33,26 @@ type OidcEndpoint struct {
34
 }
33
 }
35
 
34
 
36
 type OauthCacheItem struct {
35
 type OauthCacheItem struct {
37
-	UserId      uint   `json:"user_id"`
38
-	Id          string `json:"id"` //rustdesk的设备ID
39
-	Op          string `json:"op"`
40
-	Action      string `json:"action"`
41
-	Uuid        string `json:"uuid"`
42
-	DeviceName  string `json:"device_name"`
43
-	DeviceOs    string `json:"device_os"`
44
-	DeviceType  string `json:"device_type"`
45
-	OpenId 		string `json:"open_id"`
46
-	Username	string `json:"username"`
47
-	Name   		string `json:"name"`
48
-	Email  		string `json:"email"`
36
+	UserId     uint   `json:"user_id"`
37
+	Id         string `json:"id"` //rustdesk的设备ID
38
+	Op         string `json:"op"`
39
+	Action     string `json:"action"`
40
+	Uuid       string `json:"uuid"`
41
+	DeviceName string `json:"device_name"`
42
+	DeviceOs   string `json:"device_os"`
43
+	DeviceType string `json:"device_type"`
44
+	OpenId     string `json:"open_id"`
45
+	Username   string `json:"username"`
46
+	Name       string `json:"name"`
47
+	Email      string `json:"email"`
49
 }
48
 }
50
 
49
 
51
 func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
50
 func (oci *OauthCacheItem) ToOauthUser() *model.OauthUser {
52
 	return &model.OauthUser{
51
 	return &model.OauthUser{
53
-		OpenId: oci.OpenId,
52
+		OpenId:   oci.OpenId,
54
 		Username: oci.Username,
53
 		Username: oci.Username,
55
-		Name: oci.Name,
56
-		Email: oci.Email,
54
+		Name:     oci.Name,
55
+		Email:    oci.Email,
57
 	}
56
 	}
58
 }
57
 }
59
 
58
 
@@ -64,14 +63,13 @@ const (
64
 	OauthActionTypeBind  = "bind"
63
 	OauthActionTypeBind  = "bind"
65
 )
64
 )
66
 
65
 
67
-func (oa *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) {
68
-	oa.OpenId = oauthUser.OpenId
69
-	oa.Username = oauthUser.Username
70
-	oa.Name = oauthUser.Name
71
-	oa.Email = oauthUser.Email
66
+func (oci *OauthCacheItem) UpdateFromOauthUser(oauthUser *model.OauthUser) {
67
+	oci.OpenId = oauthUser.OpenId
68
+	oci.Username = oauthUser.Username
69
+	oci.Name = oauthUser.Name
70
+	oci.Email = oauthUser.Email
72
 }
71
 }
73
 
72
 
74
-
75
 func (os *OauthService) GetOauthCache(key string) *OauthCacheItem {
73
 func (os *OauthService) GetOauthCache(key string) *OauthCacheItem {
76
 	v, ok := OauthCache.Load(key)
74
 	v, ok := OauthCache.Load(key)
77
 	if !ok {
75
 	if !ok {
@@ -164,7 +162,7 @@ func (os *OauthService) GetOauthConfig(op string) (err error, oauthInfo *model.O
164
 		if err != nil {
162
 		if err != nil {
165
 			return err, nil, nil
163
 			return err, nil, nil
166
 		}
164
 		}
167
-		oauthConfig.Endpoint = oauth2.Endpoint{AuthURL:  endpoint.AuthURL,TokenURL: endpoint.TokenURL,}
165
+		oauthConfig.Endpoint = oauth2.Endpoint{AuthURL: endpoint.AuthURL, TokenURL: endpoint.TokenURL}
168
 		oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
166
 		oauthConfig.Scopes = os.constructScopes(oauthInfo.Scopes)
169
 	default:
167
 	default:
170
 		return errors.New("unsupported OAuth type"), nil, nil
168
 		return errors.New("unsupported OAuth type"), nil, nil
@@ -259,9 +257,8 @@ func (os *OauthService) githubCallback(oauthConfig *oauth2.Config, code string)
259
 	return nil, user.ToOauthUser()
257
 	return nil, user.ToOauthUser()
260
 }
258
 }
261
 
259
 
262
-
263
 // oidcCallback oidc回调, 通过code获取用户信息
260
 // oidcCallback oidc回调, 通过code获取用户信息
264
-func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser,) {
261
+func (os *OauthService) oidcCallback(oauthConfig *oauth2.Config, code string, userInfoEndpoint string) (error, *model.OauthUser) {
265
 	var user = &model.OidcUser{}
262
 	var user = &model.OidcUser{}
266
 	if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil {
263
 	if err, _ := os.callbackBase(oauthConfig, code, userInfoEndpoint, user); err != nil {
267
 		return err, nil
264
 		return err, nil
@@ -280,21 +277,20 @@ func (os *OauthService) Callback(code string, op string) (err error, oauthUser *
280
 	}
277
 	}
281
 	oauthType := oauthInfo.OauthType
278
 	oauthType := oauthInfo.OauthType
282
 	switch oauthType {
279
 	switch oauthType {
283
-    case model.OauthTypeGithub:
284
-        err, oauthUser = os.githubCallback(oauthConfig, code)
285
-    case model.OauthTypeOidc, model.OauthTypeGoogle:
280
+	case model.OauthTypeGithub:
281
+		err, oauthUser = os.githubCallback(oauthConfig, code)
282
+	case model.OauthTypeOidc, model.OauthTypeGoogle:
286
 		err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
283
 		err, endpoint := os.FetchOidcEndpoint(oauthInfo.Issuer)
287
 		if err != nil {
284
 		if err != nil {
288
 			return err, nil
285
 			return err, nil
289
 		}
286
 		}
290
-        err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo)
291
-    default:
292
-        return errors.New("unsupported OAuth type"), nil
293
-    }
294
-    return err, oauthUser
287
+		err, oauthUser = os.oidcCallback(oauthConfig, code, endpoint.UserInfo)
288
+	default:
289
+		return errors.New("unsupported OAuth type"), nil
290
+	}
291
+	return err, oauthUser
295
 }
292
 }
296
 
293
 
297
-
298
 func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird {
294
 func (os *OauthService) UserThirdInfo(op string, openId string) *model.UserThird {
299
 	ut := &model.UserThird{}
295
 	ut := &model.UserThird{}
300
 	global.DB.Where("open_id = ? and op = ?", openId, op).First(ut)
296
 	global.DB.Where("open_id = ? and op = ?", openId, op).First(ut)
@@ -343,17 +339,17 @@ func (os *OauthService) InfoByOp(op string) *model.Oauth {
343
 
339
 
344
 // Helper function to get scopes by operation
340
 // Helper function to get scopes by operation
345
 func (os *OauthService) getScopesByOp(op string) []string {
341
 func (os *OauthService) getScopesByOp(op string) []string {
346
-    scopes := os.InfoByOp(op).Scopes
342
+	scopes := os.InfoByOp(op).Scopes
347
 	return os.constructScopes(scopes)
343
 	return os.constructScopes(scopes)
348
 }
344
 }
349
 
345
 
350
 // Helper function to construct scopes
346
 // Helper function to construct scopes
351
 func (os *OauthService) constructScopes(scopes string) []string {
347
 func (os *OauthService) constructScopes(scopes string) []string {
352
-    scopes = strings.TrimSpace(scopes)
353
-    if scopes == "" {
354
-        scopes = model.OIDC_DEFAULT_SCOPES
355
-    }
356
-    return strings.Split(scopes, ",")
348
+	scopes = strings.TrimSpace(scopes)
349
+	if scopes == "" {
350
+		scopes = model.OIDC_DEFAULT_SCOPES
351
+	}
352
+	return strings.Split(scopes, ",")
357
 }
353
 }
358
 
354
 
359
 func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *model.OauthList) {
355
 func (os *OauthService) List(page, pageSize uint, where func(tx *gorm.DB)) (res *model.OauthList) {
@@ -461,4 +457,4 @@ func (os *OauthService) getGithubPrimaryEmail(client *http.Client, githubUser *m
461
 	}
457
 	}
462
 
458
 
463
 	return fmt.Errorf("no primary verified email found")
459
 	return fmt.Errorf("no primary verified email found")
464
-}
460
+}

+ 42 - 40
service/user.go

@@ -5,13 +5,13 @@ import (
5
 	adResp "Gwen/http/response/admin"
5
 	adResp "Gwen/http/response/admin"
6
 	"Gwen/model"
6
 	"Gwen/model"
7
 	"Gwen/utils"
7
 	"Gwen/utils"
8
+	"errors"
8
 	"github.com/gin-gonic/gin"
9
 	"github.com/gin-gonic/gin"
9
 	"gorm.io/gorm"
10
 	"gorm.io/gorm"
10
 	"math/rand"
11
 	"math/rand"
11
 	"strconv"
12
 	"strconv"
12
-	"time"
13
 	"strings"
13
 	"strings"
14
-	"errors"
14
+	"time"
15
 )
15
 )
16
 
16
 
17
 type UserService struct {
17
 type UserService struct {
@@ -23,6 +23,7 @@ func (us *UserService) InfoById(id uint) *model.User {
23
 	global.DB.Where("id = ?", id).First(u)
23
 	global.DB.Where("id = ?", id).First(u)
24
 	return u
24
 	return u
25
 }
25
 }
26
+
26
 // InfoByUsername 根据用户名取用户信息
27
 // InfoByUsername 根据用户名取用户信息
27
 func (us *UserService) InfoByUsername(un string) *model.User {
28
 func (us *UserService) InfoByUsername(un string) *model.User {
28
 	u := &model.User{}
29
 	u := &model.User{}
@@ -75,11 +76,11 @@ func (us *UserService) GenerateToken(u *model.User) string {
75
 func (us *UserService) Login(u *model.User, llog *model.LoginLog) *model.UserToken {
76
 func (us *UserService) Login(u *model.User, llog *model.LoginLog) *model.UserToken {
76
 	token := us.GenerateToken(u)
77
 	token := us.GenerateToken(u)
77
 	ut := &model.UserToken{
78
 	ut := &model.UserToken{
78
-		UserId:    	u.Id,
79
-		Token:     	token,
79
+		UserId:     u.Id,
80
+		Token:      token,
80
 		DeviceUuid: llog.Uuid,
81
 		DeviceUuid: llog.Uuid,
81
 		DeviceId:   llog.DeviceId,
82
 		DeviceId:   llog.DeviceId,
82
-		ExpiredAt: 	time.Now().Add(time.Hour * 24 * 7).Unix(),
83
+		ExpiredAt:  time.Now().Add(time.Hour * 24 * 7).Unix(),
83
 	}
84
 	}
84
 	global.DB.Create(ut)
85
 	global.DB.Create(ut)
85
 	llog.UserTokenId = ut.UserId
86
 	llog.UserTokenId = ut.UserId
@@ -162,7 +163,7 @@ func (us *UserService) Create(u *model.User) error {
162
 // GetUuidByToken 根据token和user取uuid
163
 // GetUuidByToken 根据token和user取uuid
163
 func (us *UserService) GetUuidByToken(u *model.User, token string) string {
164
 func (us *UserService) GetUuidByToken(u *model.User, token string) string {
164
 	ut := &model.UserToken{}
165
 	ut := &model.UserToken{}
165
-	err :=global.DB.Where("user_id = ? and token = ?", u.Id, token).First(ut).Error
166
+	err := global.DB.Where("user_id = ? and token = ?", u.Id, token).First(ut).Error
166
 	if err != nil {
167
 	if err != nil {
167
 		return ""
168
 		return ""
168
 	}
169
 	}
@@ -214,12 +215,12 @@ func (us *UserService) Delete(u *model.User) error {
214
 		tx.Rollback()
215
 		tx.Rollback()
215
 		return err
216
 		return err
216
 	}
217
 	}
217
-	tx.Commit()
218
 	// 删除关联的peer
218
 	// 删除关联的peer
219
 	if err := AllService.PeerService.EraseUserId(u.Id); err != nil {
219
 	if err := AllService.PeerService.EraseUserId(u.Id); err != nil {
220
 		tx.Rollback()
220
 		tx.Rollback()
221
 		return err
221
 		return err
222
 	}
222
 	}
223
+	tx.Commit()
223
 	return nil
224
 	return nil
224
 }
225
 }
225
 
226
 
@@ -230,7 +231,7 @@ func (us *UserService) Update(u *model.User) error {
230
 	if us.IsAdmin(currentUser) {
231
 	if us.IsAdmin(currentUser) {
231
 		adminCount := us.getAdminUserCount()
232
 		adminCount := us.getAdminUserCount()
232
 		// 如果这是唯一的管理员,确保不能禁用或取消管理员权限
233
 		// 如果这是唯一的管理员,确保不能禁用或取消管理员权限
233
-		if adminCount <= 1 && ( !us.IsAdmin(u) || u.Status == model.COMMON_STATUS_DISABLED) {
234
+		if adminCount <= 1 && (!us.IsAdmin(u) || u.Status == model.COMMON_STATUS_DISABLED) {
234
 			return errors.New("The last admin user cannot be disabled or demoted")
235
 			return errors.New("The last admin user cannot be disabled or demoted")
235
 		}
236
 		}
236
 	}
237
 	}
@@ -290,48 +291,49 @@ func (us *UserService) InfoByOauthId(op string, openId string) *model.User {
290
 }
291
 }
291
 
292
 
292
 // RegisterByOauth 注册
293
 // RegisterByOauth 注册
293
-func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser , op string) (error, *model.User) {
294
+func (us *UserService) RegisterByOauth(oauthUser *model.OauthUser, op string) (error, *model.User) {
294
 	global.Lock.Lock("registerByOauth")
295
 	global.Lock.Lock("registerByOauth")
295
 	defer global.Lock.UnLock("registerByOauth")
296
 	defer global.Lock.UnLock("registerByOauth")
296
 	ut := AllService.OauthService.UserThirdInfo(op, oauthUser.OpenId)
297
 	ut := AllService.OauthService.UserThirdInfo(op, oauthUser.OpenId)
297
 	if ut.Id != 0 {
298
 	if ut.Id != 0 {
298
 		return nil, us.InfoById(ut.UserId)
299
 		return nil, us.InfoById(ut.UserId)
299
 	}
300
 	}
300
-	//check if this email has been registered 
301
-	email := oauthUser.Email
302
 	err, oauthType := AllService.OauthService.GetTypeByOp(op)
301
 	err, oauthType := AllService.OauthService.GetTypeByOp(op)
303
 	if err != nil {
302
 	if err != nil {
304
 		return err, nil
303
 		return err, nil
305
 	}
304
 	}
306
-	// if email is empty, use username and op as email
307
-	if email == "" {
308
-		email = oauthUser.Username + "@" + op
309
-	} 
310
-	email = strings.ToLower(email)
311
-	// update email to oauthUser, in case it contain upper case
312
-	oauthUser.Email = email
313
-	user := us.InfoByEmail(email)
314
-	tx := global.DB.Begin()
315
-	if user.Id != 0 {
316
-		ut.FromOauthUser(user.Id, oauthUser, oauthType, op)
317
-	} else {
318
-		ut = &model.UserThird{}
319
-		ut.FromOauthUser(0, oauthUser, oauthType, op)
320
-		// The initial username should be formatted
321
-		username := us.formatUsername(oauthUser.Username)
322
-		usernameUnique := us.GenerateUsernameByOauth(username)
323
-		user = &model.User{
324
-			Username: usernameUnique,
325
-			GroupId:  1,
326
-		}
327
-		oauthUser.ToUser(user, false)
328
-		tx.Create(user)
329
-		if user.Id == 0 {
330
-			tx.Rollback()
331
-			return errors.New("OauthRegisterFailed"), user
305
+	//check if this email has been registered
306
+	email := oauthUser.Email
307
+	// only email is not empty
308
+	if email != "" {
309
+		email = strings.ToLower(email)
310
+		// update email to oauthUser, in case it contain upper case
311
+		oauthUser.Email = email
312
+		user := us.InfoByEmail(email)
313
+		if user.Id != 0 {
314
+			ut.FromOauthUser(user.Id, oauthUser, oauthType, op)
315
+			global.DB.Create(ut)
316
+			return nil, user
332
 		}
317
 		}
333
-		ut.UserId = user.Id
334
 	}
318
 	}
319
+
320
+	tx := global.DB.Begin()
321
+	ut = &model.UserThird{}
322
+	ut.FromOauthUser(0, oauthUser, oauthType, op)
323
+	// The initial username should be formatted
324
+	username := us.formatUsername(oauthUser.Username)
325
+	usernameUnique := us.GenerateUsernameByOauth(username)
326
+	user := &model.User{
327
+		Username: usernameUnique,
328
+		GroupId:  1,
329
+	}
330
+	oauthUser.ToUser(user, false)
331
+	tx.Create(user)
332
+	if user.Id == 0 {
333
+		tx.Rollback()
334
+		return errors.New("OauthRegisterFailed"), user
335
+	}
336
+	ut.UserId = user.Id
335
 	tx.Create(ut)
337
 	tx.Create(ut)
336
 	tx.Commit()
338
 	tx.Commit()
337
 	return nil, user
339
 	return nil, user
@@ -433,7 +435,7 @@ func (us *UserService) formatUsername(username string) string {
433
 	return username
435
 	return username
434
 }
436
 }
435
 
437
 
436
-//  Helper functions, getUserCount
438
+// Helper functions, getUserCount
437
 func (us *UserService) getUserCount() int64 {
439
 func (us *UserService) getUserCount() int64 {
438
 	var count int64
440
 	var count int64
439
 	global.DB.Model(&model.User{}).Count(&count)
441
 	global.DB.Model(&model.User{}).Count(&count)
@@ -445,4 +447,4 @@ func (us *UserService) getAdminUserCount() int64 {
445
 	var count int64
447
 	var count int64
446
 	global.DB.Model(&model.User{}).Where("is_admin = ?", true).Count(&count)
448
 	global.DB.Model(&model.User{}).Where("is_admin = ?", true).Count(&count)
447
 	return count
449
 	return count
448
-}
450
+}