apimain.go 6.5 KB


  1. package main
import (
	"fmt"
	"reflect"
	"strings"

	"Gwen/config"
	"Gwen/global"
	"Gwen/http"
	"Gwen/lib/cache"
	"Gwen/lib/lock"
	"Gwen/lib/logger"
	"Gwen/lib/orm"
	"Gwen/lib/upload"
	"Gwen/model"
	"Gwen/service"

	"github.com/go-playground/locales/en"
	"github.com/go-playground/locales/zh_Hans_CN"
	ut "github.com/go-playground/universal-translator"
	"github.com/go-playground/validator/v10"
	zh_translations "github.com/go-playground/validator/v10/translations/zh"
	"github.com/go-redis/redis/v8"
)
  22. // @title 管理系统API
  23. // @version 1.0
  24. // @description 接口
  25. // @basePath /api
  26. // @securityDefinitions.apikey token
  27. // @in header
  28. // @name api-token
  29. // @securitydefinitions.apikey BearerAuth
  30. // @in header
  31. // @name Authorization
  32. func main() {
  33. //配置解析
  34. global.Viper = config.Init(&global.Config, func() {
  35. fmt.Println(global.Config)
  36. })
  37. //日志
  38. global.Logger = logger.New(&logger.Config{
  39. Path: global.Config.Logger.Path,
  40. Level: global.Config.Logger.Level,
  41. ReportCaller: global.Config.Logger.ReportCaller,
  42. })
  43. //redis
  44. global.Redis = redis.NewClient(&redis.Options{
  45. Addr: global.Config.Redis.Addr,
  46. Password: global.Config.Redis.Password,
  47. DB: global.Config.Redis.Db,
  48. })
  49. //cache
  50. if global.Config.Cache.Type == cache.TypeFile {
  51. fc := cache.NewFileCache()
  52. fc.SetDir(global.Config.Cache.FileDir)
  53. global.Cache = fc
  54. } else if global.Config.Cache.Type == cache.TypeRedis {
  55. global.Cache = cache.NewRedis(&redis.Options{
  56. Addr: global.Config.Cache.RedisAddr,
  57. Password: global.Config.Cache.RedisPwd,
  58. DB: global.Config.Cache.RedisDb,
  59. })
  60. }
  61. //gorm
  62. if global.Config.Gorm.Type == config.TypeMysql {
  63. dns := global.Config.Mysql.Username + ":" + global.Config.Mysql.Password + "@(" + global.Config.Mysql.Addr + ")/" + global.Config.Mysql.Dbname + "?charset=utf8mb4&parseTime=True&loc=Local"
  64. global.DB = orm.NewMysql(&orm.MysqlConfig{
  65. Dns: dns,
  66. MaxIdleConns: global.Config.Gorm.MaxIdleConns,
  67. MaxOpenConns: global.Config.Gorm.MaxOpenConns,
  68. })
  69. } else {
  70. //sqlite
  71. global.DB = orm.NewSqlite(&orm.SqliteConfig{
  72. MaxIdleConns: global.Config.Gorm.MaxIdleConns,
  73. MaxOpenConns: global.Config.Gorm.MaxOpenConns,
  74. })
  75. }
  76. DatabaseAutoUpdate()
  77. //validator
  78. ApiInitValidator()
  79. //oss
  80. global.Oss = &upload.Oss{
  81. AccessKeyId: global.Config.Oss.AccessKeyId,
  82. AccessKeySecret: global.Config.Oss.AccessKeySecret,
  83. Host: global.Config.Oss.Host,
  84. CallbackUrl: global.Config.Oss.CallbackUrl,
  85. ExpireTime: global.Config.Oss.ExpireTime,
  86. MaxByte: global.Config.Oss.MaxByte,
  87. }
  88. //jwt
  89. //fmt.Println(global.Config.Jwt.PrivateKey)
  90. //global.Jwt = jwt.NewJwt(global.Config.Jwt.PrivateKey, global.Config.Jwt.ExpireDuration*time.Second)
  91. //locker
  92. global.Lock = lock.NewLocal()
  93. //gin
  94. http.ApiInit()
  95. }
  96. func ApiInitValidator() {
  97. validate := validator.New()
  98. enT := en.New()
  99. cn := zh_Hans_CN.New()
  100. uni := ut.New(enT, cn)
  101. trans, _ := uni.GetTranslator("cn")
  102. err := zh_translations.RegisterDefaultTranslations(validate, trans)
  103. if err != nil {
  104. //退出
  105. panic(err)
  106. }
  107. validate.RegisterTagNameFunc(func(field reflect.StructField) string {
  108. label := field.Tag.Get("label")
  109. if label == "" {
  110. return field.Name
  111. }
  112. return label
  113. })
  114. global.Validator.Validate = validate
  115. global.Validator.VTrans = trans
  116. global.Validator.ValidStruct = func(i interface{}) []string {
  117. err := global.Validator.Validate.Struct(i)
  118. errList := make([]string, 0, 10)
  119. if err != nil {
  120. if _, ok := err.(*validator.InvalidValidationError); ok {
  121. errList = append(errList, err.Error())
  122. return errList
  123. }
  124. for _, err2 := range err.(validator.ValidationErrors) {
  125. errList = append(errList, err2.Translate(global.Validator.VTrans))
  126. }
  127. }
  128. return errList
  129. }
  130. global.Validator.ValidVar = func(field interface{}, tag string) []string {
  131. err := global.Validator.Validate.Var(field, tag)
  132. fmt.Println(err)
  133. errList := make([]string, 0, 10)
  134. if err != nil {
  135. if _, ok := err.(*validator.InvalidValidationError); ok {
  136. errList = append(errList, err.Error())
  137. return errList
  138. }
  139. for _, err2 := range err.(validator.ValidationErrors) {
  140. errList = append(errList, err2.Translate(global.Validator.VTrans))
  141. }
  142. }
  143. return errList
  144. }
  145. }
  146. func DatabaseAutoUpdate() {
  147. version := 100
  148. db := global.DB
  149. if global.Config.Gorm.Type == config.TypeMysql {
  150. //检查存不存在数据库,不存在则创建
  151. dbName := db.Migrator().CurrentDatabase()
  152. fmt.Println("dbName", dbName)
  153. if dbName == "" {
  154. dbName = global.Config.Mysql.Dbname
  155. // 移除 DSN 中的数据库名称,以便初始连接时不指定数据库
  156. dsnWithoutDB := global.Config.Mysql.Username + ":" + global.Config.Mysql.Password + "@(" + global.Config.Mysql.Addr + ")/?charset=utf8mb4&parseTime=True&loc=Local"
  157. //新链接
  158. dbWithoutDB := orm.NewMysql(&orm.MysqlConfig{
  159. Dns: dsnWithoutDB,
  160. })
  161. // 获取底层的 *sql.DB 对象,并确保在程序退出时关闭连接
  162. sqlDBWithoutDB, err := dbWithoutDB.DB()
  163. if err != nil {
  164. fmt.Printf("获取底层 *sql.DB 对象失败: %v\n", err)
  165. return
  166. }
  167. defer func() {
  168. if err := sqlDBWithoutDB.Close(); err != nil {
  169. fmt.Printf("关闭连接失败: %v\n", err)
  170. }
  171. }()
  172. err = dbWithoutDB.Exec("CREATE DATABASE IF NOT EXISTS " + dbName + " DEFAULT CHARSET utf8mb4").Error
  173. if err != nil {
  174. fmt.Println(err)
  175. return
  176. }
  177. }
  178. }
  179. if !db.Migrator().HasTable(&model.Version{}) {
  180. Migrate(uint(version))
  181. } else {
  182. //查找最后一个version
  183. var v model.Version
  184. db.Last(&v)
  185. if v.Version < uint(version) {
  186. Migrate(uint(version))
  187. }
  188. }
  189. }
  190. func Migrate(version uint) {
  191. fmt.Println("migrating....", version)
  192. err := global.DB.AutoMigrate(
  193. &model.Version{},
  194. &model.User{},
  195. &model.UserToken{},
  196. &model.Tag{},
  197. &model.AddressBook{},
  198. &model.Peer{},
  199. &model.Group{},
  200. )
  201. if err != nil {
  202. fmt.Println("migrate err :=>", err)
  203. }
  204. global.DB.Create(&model.Version{Version: version})
  205. //如果是初次则创建一个默认用户
  206. var vc int64
  207. global.DB.Model(&model.Version{}).Count(&vc)
  208. if vc == 1 {
  209. group := &model.Group{
  210. Name: "默认组",
  211. Type: model.GroupTypeDefault,
  212. }
  213. service.AllService.GroupService.Create(group)
  214. groupShare := &model.Group{
  215. Name: "共享组",
  216. Type: model.GroupTypeShare,
  217. }
  218. service.AllService.GroupService.Create(groupShare)
  219. //是true
  220. is_admin := true
  221. admin := &model.User{
  222. Username: "admin",
  223. Nickname: "管理员",
  224. Status: model.COMMON_STATUS_ENABLE,
  225. IsAdmin: &is_admin,
  226. GroupId: 1,
  227. }
  228. admin.Password = service.AllService.UserService.EncryptPassword("admin")
  229. global.DB.Create(admin)
  230. }
  231. }