apimain.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. package main
  import (
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/go-redis/redis/v8"
	"github.com/lejianwen/rustdesk-api/v2/config"
	"github.com/lejianwen/rustdesk-api/v2/global"
	"github.com/lejianwen/rustdesk-api/v2/http"
	"github.com/lejianwen/rustdesk-api/v2/lib/cache"
	"github.com/lejianwen/rustdesk-api/v2/lib/jwt"
	"github.com/lejianwen/rustdesk-api/v2/lib/lock"
	"github.com/lejianwen/rustdesk-api/v2/lib/logger"
	"github.com/lejianwen/rustdesk-api/v2/lib/orm"
	"github.com/lejianwen/rustdesk-api/v2/lib/upload"
	"github.com/lejianwen/rustdesk-api/v2/model"
	"github.com/lejianwen/rustdesk-api/v2/service"
	"github.com/lejianwen/rustdesk-api/v2/utils"
	"github.com/nicksnyder/go-i18n/v2/i18n"
	"github.com/spf13/cobra"
)
  22. // @title 管理系统API
  23. // @version 1.0
  24. // @description 接口
  25. // @basePath /api
  26. // @securityDefinitions.apikey token
  27. // @in header
  28. // @name api-token
  29. // @securitydefinitions.apikey BearerAuth
  30. // @in header
  31. // @name Authorization
  32. var rootCmd = &cobra.Command{
  33. Use: "apimain",
  34. Short: "RUSTDESK API SERVER",
  35. PersistentPreRun: func(cmd *cobra.Command, args []string) {
  36. InitGlobal()
  37. },
  38. Run: func(cmd *cobra.Command, args []string) {
  39. global.Logger.Info("API SERVER START")
  40. http.ApiInit()
  41. },
  42. }
  43. var resetPwdCmd = &cobra.Command{
  44. Use: "reset-admin-pwd [pwd]",
  45. Example: "reset-admin-pwd 123456",
  46. Short: "Reset Admin Password",
  47. Args: cobra.ExactArgs(1),
  48. Run: func(cmd *cobra.Command, args []string) {
  49. pwd := args[0]
  50. admin := service.AllService.UserService.InfoById(1)
  51. if admin.Id == 0 {
  52. global.Logger.Warn("user not found! ")
  53. return
  54. }
  55. err := service.AllService.UserService.UpdatePassword(admin, pwd)
  56. if err != nil {
  57. global.Logger.Error("reset password fail! ", err)
  58. return
  59. }
  60. global.Logger.Info("reset password success! ")
  61. },
  62. }
  63. var resetUserPwdCmd = &cobra.Command{
  64. Use: "reset-pwd [userId] [pwd]",
  65. Example: "reset-pwd 2 123456",
  66. Short: "Reset User Password",
  67. Args: cobra.ExactArgs(2),
  68. Run: func(cmd *cobra.Command, args []string) {
  69. userId := args[0]
  70. pwd := args[1]
  71. uid, err := strconv.Atoi(userId)
  72. if err != nil {
  73. global.Logger.Warn("userId must be int!")
  74. return
  75. }
  76. if uid <= 0 {
  77. global.Logger.Warn("userId must be greater than 0! ")
  78. return
  79. }
  80. u := service.AllService.UserService.InfoById(uint(uid))
  81. if u.Id == 0 {
  82. global.Logger.Warn("user not found! ")
  83. return
  84. }
  85. err = service.AllService.UserService.UpdatePassword(u, pwd)
  86. if err != nil {
  87. global.Logger.Warn("reset password fail! ", err)
  88. return
  89. }
  90. global.Logger.Info("reset password success!")
  91. },
  92. }
  93. func init() {
  94. rootCmd.PersistentFlags().StringVarP(&global.ConfigPath, "config", "c", "./conf/config.yaml", "choose config file")
  95. rootCmd.AddCommand(resetPwdCmd, resetUserPwdCmd)
  96. }
  97. func main() {
  98. if err := rootCmd.Execute(); err != nil {
  99. global.Logger.Error(err)
  100. os.Exit(1)
  101. }
  102. }
  103. func InitGlobal() {
  104. //配置解析
  105. global.Viper = config.Init(&global.Config, global.ConfigPath)
  106. //日志
  107. global.Logger = logger.New(&logger.Config{
  108. Path: global.Config.Logger.Path,
  109. Level: global.Config.Logger.Level,
  110. ReportCaller: global.Config.Logger.ReportCaller,
  111. })
  112. global.InitI18n()
  113. //redis
  114. global.Redis = redis.NewClient(&redis.Options{
  115. Addr: global.Config.Redis.Addr,
  116. Password: global.Config.Redis.Password,
  117. DB: global.Config.Redis.Db,
  118. })
  119. //cache
  120. if global.Config.Cache.Type == cache.TypeFile {
  121. fc := cache.NewFileCache()
  122. fc.SetDir(global.Config.Cache.FileDir)
  123. global.Cache = fc
  124. } else if global.Config.Cache.Type == cache.TypeRedis {
  125. global.Cache = cache.NewRedis(&redis.Options{
  126. Addr: global.Config.Cache.RedisAddr,
  127. Password: global.Config.Cache.RedisPwd,
  128. DB: global.Config.Cache.RedisDb,
  129. })
  130. }
  131. //gorm
  132. if global.Config.Gorm.Type == config.TypeMysql {
  133. dns := global.Config.Mysql.Username + ":" + global.Config.Mysql.Password + "@(" + global.Config.Mysql.Addr + ")/" + global.Config.Mysql.Dbname + "?charset=utf8mb4&parseTime=True&loc=Local"
  134. global.DB = orm.NewMysql(&orm.MysqlConfig{
  135. Dns: dns,
  136. MaxIdleConns: global.Config.Gorm.MaxIdleConns,
  137. MaxOpenConns: global.Config.Gorm.MaxOpenConns,
  138. })
  139. } else {
  140. //sqlite
  141. global.DB = orm.NewSqlite(&orm.SqliteConfig{
  142. MaxIdleConns: global.Config.Gorm.MaxIdleConns,
  143. MaxOpenConns: global.Config.Gorm.MaxOpenConns,
  144. })
  145. }
  146. //validator
  147. global.ApiInitValidator()
  148. //oss
  149. global.Oss = &upload.Oss{
  150. AccessKeyId: global.Config.Oss.AccessKeyId,
  151. AccessKeySecret: global.Config.Oss.AccessKeySecret,
  152. Host: global.Config.Oss.Host,
  153. CallbackUrl: global.Config.Oss.CallbackUrl,
  154. ExpireTime: global.Config.Oss.ExpireTime,
  155. MaxByte: global.Config.Oss.MaxByte,
  156. }
  157. //jwt
  158. //fmt.Println(global.Config.Jwt.PrivateKey)
  159. global.Jwt = jwt.NewJwt(global.Config.Jwt.Key, global.Config.Jwt.ExpireDuration)
  160. //locker
  161. global.Lock = lock.NewLocal()
  162. //service
  163. service.New(&global.Config, global.DB, global.Logger, global.Jwt, global.Lock)
  164. global.LoginLimiter = utils.NewLoginLimiter(utils.SecurityPolicy{
  165. CaptchaThreshold: global.Config.App.CaptchaThreshold,
  166. BanThreshold: global.Config.App.BanThreshold,
  167. AttemptsWindow: 10 * time.Minute,
  168. BanDuration: 30 * time.Minute,
  169. })
  170. global.LoginLimiter.RegisterProvider(utils.B64StringCaptchaProvider{})
  171. DatabaseAutoUpdate()
  172. }
  173. func DatabaseAutoUpdate() {
  174. version := 262
  175. db := global.DB
  176. if global.Config.Gorm.Type == config.TypeMysql {
  177. //检查存不存在数据库,不存在则创建
  178. dbName := db.Migrator().CurrentDatabase()
  179. if dbName == "" {
  180. dbName = global.Config.Mysql.Dbname
  181. // 移除 DSN 中的数据库名称,以便初始连接时不指定数据库
  182. dsnWithoutDB := global.Config.Mysql.Username + ":" + global.Config.Mysql.Password + "@(" + global.Config.Mysql.Addr + ")/?charset=utf8mb4&parseTime=True&loc=Local"
  183. //新链接
  184. dbWithoutDB := orm.NewMysql(&orm.MysqlConfig{
  185. Dns: dsnWithoutDB,
  186. })
  187. // 获取底层的 *sql.DB 对象,并确保在程序退出时关闭连接
  188. sqlDBWithoutDB, err := dbWithoutDB.DB()
  189. if err != nil {
  190. global.Logger.Errorf("获取底层 *sql.DB 对象失败: %v", err)
  191. return
  192. }
  193. defer func() {
  194. if err := sqlDBWithoutDB.Close(); err != nil {
  195. global.Logger.Errorf("关闭连接失败: %v", err)
  196. }
  197. }()
  198. err = dbWithoutDB.Exec("CREATE DATABASE IF NOT EXISTS " + dbName + " DEFAULT CHARSET utf8mb4").Error
  199. if err != nil {
  200. global.Logger.Error(err)
  201. return
  202. }
  203. }
  204. }
  205. if !db.Migrator().HasTable(&model.Version{}) {
  206. Migrate(uint(version))
  207. } else {
  208. //查找最后一个version
  209. var v model.Version
  210. db.Last(&v)
  211. if v.Version < uint(version) {
  212. Migrate(uint(version))
  213. }
  214. // 245迁移
  215. if v.Version < 245 {
  216. //oauths 表的 oauth_type 字段设置为 op同样的值
  217. db.Exec("update oauths set oauth_type = op")
  218. db.Exec("update oauths set issuer = 'https://accounts.google.com' where op = 'google'")
  219. db.Exec("update user_thirds set oauth_type = third_type, op = third_type")
  220. //通过email迁移旧的google授权
  221. uts := make([]model.UserThird, 0)
  222. db.Where("oauth_type = ?", "google").Find(&uts)
  223. for _, ut := range uts {
  224. if ut.UserId > 0 {
  225. db.Model(&model.User{}).Where("id = ?", ut.UserId).Update("email", ut.OpenId)
  226. }
  227. }
  228. }
  229. if v.Version < 246 {
  230. db.Exec("update oauths set issuer = 'https://accounts.google.com' where op = 'google' and issuer is null")
  231. }
  232. }
  233. }
  234. func Migrate(version uint) {
  235. global.Logger.Info("Migrating....", version)
  236. err := global.DB.AutoMigrate(
  237. &model.Version{},
  238. &model.User{},
  239. &model.UserToken{},
  240. &model.Tag{},
  241. &model.AddressBook{},
  242. &model.Peer{},
  243. &model.Group{},
  244. &model.UserThird{},
  245. &model.Oauth{},
  246. &model.LoginLog{},
  247. &model.ShareRecord{},
  248. &model.AuditConn{},
  249. &model.AuditFile{},
  250. &model.AddressBookCollection{},
  251. &model.AddressBookCollectionRule{},
  252. &model.ServerCmd{},
  253. &model.DeviceGroup{},
  254. )
  255. if err != nil {
  256. global.Logger.Error("migrate err :=>", err)
  257. }
  258. global.DB.Create(&model.Version{Version: version})
  259. //如果是初次则创建一个默认用户
  260. var vc int64
  261. global.DB.Model(&model.Version{}).Count(&vc)
  262. if vc == 1 {
  263. localizer := global.Localizer("")
  264. defaultGroup, _ := localizer.LocalizeMessage(&i18n.Message{
  265. ID: "DefaultGroup",
  266. })
  267. group := &model.Group{
  268. Name: defaultGroup,
  269. Type: model.GroupTypeDefault,
  270. }
  271. service.AllService.GroupService.Create(group)
  272. shareGroup, _ := localizer.LocalizeMessage(&i18n.Message{
  273. ID: "ShareGroup",
  274. })
  275. groupShare := &model.Group{
  276. Name: shareGroup,
  277. Type: model.GroupTypeShare,
  278. }
  279. service.AllService.GroupService.Create(groupShare)
  280. //是true
  281. is_admin := true
  282. admin := &model.User{
  283. Username: "admin",
  284. Nickname: "Admin",
  285. Status: model.COMMON_STATUS_ENABLE,
  286. IsAdmin: &is_admin,
  287. GroupId: 1,
  288. }
  289. // 生成随机密码
  290. pwd := utils.RandomString(8)
  291. global.Logger.Info("Admin Password Is: ", pwd)
  292. admin.Password = service.AllService.UserService.EncryptPassword(pwd)
  293. global.DB.Create(admin)
  294. }
  295. }