apimain.go 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. package main
  2. import (
  3. "fmt"
  4. "os"
  5. "strconv"
  6. "time"
  7. "github.com/go-redis/redis/v8"
  8. "github.com/lejianwen/rustdesk-api/v2/config"
  9. "github.com/lejianwen/rustdesk-api/v2/global"
  10. "github.com/lejianwen/rustdesk-api/v2/http"
  11. "github.com/lejianwen/rustdesk-api/v2/lib/cache"
  12. "github.com/lejianwen/rustdesk-api/v2/lib/jwt"
  13. "github.com/lejianwen/rustdesk-api/v2/lib/lock"
  14. "github.com/lejianwen/rustdesk-api/v2/lib/logger"
  15. "github.com/lejianwen/rustdesk-api/v2/lib/orm"
  16. "github.com/lejianwen/rustdesk-api/v2/lib/upload"
  17. "github.com/lejianwen/rustdesk-api/v2/model"
  18. "github.com/lejianwen/rustdesk-api/v2/service"
  19. "github.com/lejianwen/rustdesk-api/v2/utils"
  20. "github.com/nicksnyder/go-i18n/v2/i18n"
  21. "github.com/spf13/cobra"
  22. )
  // DatabaseVersion is the schema version this binary expects. DatabaseAutoUpdate
// runs Migrate whenever the last recorded version is lower than this value.
const (
	DatabaseVersion = 265
)

// @title 管理系统API
// @version 1.0
// @description 接口
// @basePath /api
// @securityDefinitions.apikey token
// @in header
// @name api-token
// @securitydefinitions.apikey BearerAuth
// @in header
// @name Authorization
  34. var rootCmd = &cobra.Command{
  35. Use: "apimain",
  36. Short: "RUSTDESK API SERVER",
  37. PersistentPreRun: func(cmd *cobra.Command, args []string) {
  38. InitGlobal()
  39. },
  40. Run: func(cmd *cobra.Command, args []string) {
  41. global.Logger.Info("API SERVER START")
  42. http.ApiInit()
  43. },
  44. }
  45. var resetPwdCmd = &cobra.Command{
  46. Use: "reset-admin-pwd [pwd]",
  47. Example: "reset-admin-pwd 123456",
  48. Short: "Reset Admin Password",
  49. Args: cobra.ExactArgs(1),
  50. Run: func(cmd *cobra.Command, args []string) {
  51. pwd := args[0]
  52. admin := service.AllService.UserService.InfoById(1)
  53. if admin.Id == 0 {
  54. global.Logger.Warn("user not found! ")
  55. return
  56. }
  57. err := service.AllService.UserService.UpdatePassword(admin, pwd)
  58. if err != nil {
  59. global.Logger.Error("reset password fail! ", err)
  60. return
  61. }
  62. global.Logger.Info("reset password success! ")
  63. },
  64. }
  65. var resetUserPwdCmd = &cobra.Command{
  66. Use: "reset-pwd [userId] [pwd]",
  67. Example: "reset-pwd 2 123456",
  68. Short: "Reset User Password",
  69. Args: cobra.ExactArgs(2),
  70. Run: func(cmd *cobra.Command, args []string) {
  71. userId := args[0]
  72. pwd := args[1]
  73. uid, err := strconv.Atoi(userId)
  74. if err != nil {
  75. global.Logger.Warn("userId must be int!")
  76. return
  77. }
  78. if uid <= 0 {
  79. global.Logger.Warn("userId must be greater than 0! ")
  80. return
  81. }
  82. u := service.AllService.UserService.InfoById(uint(uid))
  83. if u.Id == 0 {
  84. global.Logger.Warn("user not found! ")
  85. return
  86. }
  87. err = service.AllService.UserService.UpdatePassword(u, pwd)
  88. if err != nil {
  89. global.Logger.Warn("reset password fail! ", err)
  90. return
  91. }
  92. global.Logger.Info("reset password success!")
  93. },
  94. }
  95. func init() {
  96. rootCmd.PersistentFlags().StringVarP(&global.ConfigPath, "config", "c", "./conf/config.yaml", "choose config file")
  97. rootCmd.AddCommand(resetPwdCmd, resetUserPwdCmd)
  98. }
  // main runs the CLI and exits with status 1 if the selected command fails.
func main() {
	err := rootCmd.Execute()
	if err == nil {
		return
	}
	// global.Logger is created by InitGlobal, which only runs in the persistent
	// pre-run hook. Cobra rejects bad flags or a wrong argument count before that
	// hook runs, so the logger can still be nil here. Calling it would panic
	// instead of exiting cleanly.
	if global.Logger != nil {
		global.Logger.Error(err)
	} else {
		fmt.Fprintln(os.Stderr, err)
	}
	os.Exit(1)
}
  // InitGlobal populates the process-wide dependencies in package global from
// the config file at global.ConfigPath, then runs DatabaseAutoUpdate.
//
// The steps run in a fixed order because later steps use earlier results:
// config, logger (passed to the DB constructors), i18n, Redis client, cache,
// database, validator, OSS, JWT, lock, services (which receive the DB, logger,
// JWT and lock) and the login limiter.
func InitGlobal() {
	// Configuration.
	global.Viper = config.Init(&global.Config, global.ConfigPath)

	// Logger.
	global.Logger = logger.New(&logger.Config{
		Path:         global.Config.Logger.Path,
		Level:        global.Config.Logger.Level,
		ReportCaller: global.Config.Logger.ReportCaller,
	})
	global.InitI18n()

	// Redis client.
	global.Redis = redis.NewClient(&redis.Options{
		Addr:     global.Config.Redis.Addr,
		Password: global.Config.Redis.Password,
		DB:       global.Config.Redis.Db,
	})

	// Cache backend.
	switch global.Config.Cache.Type {
	case cache.TypeFile:
		fc := cache.NewFileCache()
		fc.SetDir(global.Config.Cache.FileDir)
		global.Cache = fc
	case cache.TypeRedis:
		global.Cache = cache.NewRedis(&redis.Options{
			Addr:     global.Config.Cache.RedisAddr,
			Password: global.Config.Cache.RedisPwd,
			DB:       global.Config.Cache.RedisDb,
		})
	default:
		// global.Cache is not assigned for this type. Report it now rather than
		// leaving the misconfiguration silent until something uses the cache.
		global.Logger.Warn("unsupported cache type: ", global.Config.Cache.Type)
	}

	// Database.
	switch global.Config.Gorm.Type {
	case config.TypeMysql:
		dsn := fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&tls=%s",
			global.Config.Mysql.Username,
			global.Config.Mysql.Password,
			global.Config.Mysql.Addr,
			global.Config.Mysql.Dbname,
			global.Config.Mysql.Tls,
		)
		global.DB = orm.NewMysql(&orm.MysqlConfig{
			Dsn:          dsn,
			MaxIdleConns: global.Config.Gorm.MaxIdleConns,
			MaxOpenConns: global.Config.Gorm.MaxOpenConns,
		}, global.Logger)
	case config.TypePostgresql:
		dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
			global.Config.Postgresql.Host,
			global.Config.Postgresql.Port,
			global.Config.Postgresql.User,
			global.Config.Postgresql.Password,
			global.Config.Postgresql.Dbname,
			global.Config.Postgresql.Sslmode,
			global.Config.Postgresql.TimeZone,
		)
		global.DB = orm.NewPostgresql(&orm.PostgresqlConfig{
			Dsn:          dsn,
			MaxIdleConns: global.Config.Gorm.MaxIdleConns,
			MaxOpenConns: global.Config.Gorm.MaxOpenConns,
		}, global.Logger)
	default:
		// Any other gorm type falls back to SQLite.
		global.DB = orm.NewSqlite(&orm.SqliteConfig{
			MaxIdleConns: global.Config.Gorm.MaxIdleConns,
			MaxOpenConns: global.Config.Gorm.MaxOpenConns,
		}, global.Logger)
	}

	// Request validator.
	global.ApiInitValidator()

	// Object storage (OSS) settings.
	global.Oss = &upload.Oss{
		AccessKeyId:     global.Config.Oss.AccessKeyId,
		AccessKeySecret: global.Config.Oss.AccessKeySecret,
		Host:            global.Config.Oss.Host,
		CallbackUrl:     global.Config.Oss.CallbackUrl,
		ExpireTime:      global.Config.Oss.ExpireTime,
		MaxByte:         global.Config.Oss.MaxByte,
	}

	// JWT signer.
	global.Jwt = jwt.NewJwt(global.Config.Jwt.Key, global.Config.Jwt.ExpireDuration)

	// In-process locker.
	global.Lock = lock.NewLocal()

	// Services and login rate limiting.
	service.New(&global.Config, global.DB, global.Logger, global.Jwt, global.Lock)
	global.LoginLimiter = utils.NewLoginLimiter(utils.SecurityPolicy{
		CaptchaThreshold: global.Config.App.CaptchaThreshold,
		BanThreshold:     global.Config.App.BanThreshold,
		AttemptsWindow:   10 * time.Minute,
		BanDuration:      30 * time.Minute,
	})
	global.LoginLimiter.RegisterProvider(utils.B64StringCaptchaProvider{})

	DatabaseAutoUpdate()
}
  // DatabaseAutoUpdate brings the database up to DatabaseVersion.
//
// For MySQL, if the current connection reports no selected database, it opens
// a temporary connection without a database name and creates the configured
// database. It then runs Migrate when the versions table is missing or its last
// row is older than DatabaseVersion. Finally it applies one-off data fixes for
// installations whose recorded version was below 245 or 246. Errors from those
// fixes are logged rather than silently dropped.
func DatabaseAutoUpdate() {
	version := DatabaseVersion
	db := global.DB
	if global.Config.Gorm.Type == config.TypeMysql {
		// Create the database if it does not exist yet.
		dbName := db.Migrator().CurrentDatabase()
		if dbName == "" {
			dbName = global.Config.Mysql.Dbname
			// The name is spliced into CREATE DATABASE as a backtick-quoted
			// identifier (it cannot be a bind parameter). Quoting lets names such
			// as "rustdesk-api" work. A backtick inside the name would break out
			// of the quoting, so such names are rejected.
			for _, r := range dbName {
				if r == '`' {
					global.Logger.Errorf("invalid mysql database name %q", dbName)
					return
				}
			}
			// Connect without a database name. Pass the same tls setting that the
			// main connection in InitGlobal uses, so this bootstrap connection
			// negotiates TLS the same way.
			dsnWithoutDB := fmt.Sprintf("%s:%s@(%s)/?charset=utf8mb4&parseTime=True&loc=Local&tls=%s",
				global.Config.Mysql.Username,
				global.Config.Mysql.Password,
				global.Config.Mysql.Addr,
				global.Config.Mysql.Tls,
			)
			dbWithoutDB := orm.NewMysql(&orm.MysqlConfig{
				Dsn: dsnWithoutDB,
			}, global.Logger)
			// Get the underlying *sql.DB so the temporary pool is closed when this function returns.
			sqlDBWithoutDB, err := dbWithoutDB.DB()
			if err != nil {
				global.Logger.Errorf("获取底层 *sql.DB 对象失败: %v", err)
				return
			}
			defer func() {
				if err := sqlDBWithoutDB.Close(); err != nil {
					global.Logger.Errorf("关闭连接失败: %v", err)
				}
			}()
			err = dbWithoutDB.Exec("CREATE DATABASE IF NOT EXISTS `" + dbName + "` DEFAULT CHARSET utf8mb4").Error
			if err != nil {
				global.Logger.Error(err)
				return
			}
		}
	}

	// Fresh database: create everything; no data fixes are needed.
	if !db.Migrator().HasTable(&model.Version{}) {
		Migrate(uint(version))
		return
	}

	// Read the most recently recorded version.
	var v model.Version
	db.Last(&v)
	if v.Version < uint(version) {
		Migrate(uint(version))
	}

	// execLogged runs one data-fix statement and logs any error it returns.
	execLogged := func(sql string) {
		if err := db.Exec(sql).Error; err != nil {
			global.Logger.Errorf("data migration %q failed: %v", sql, err)
		}
	}

	// Data fixes for version 245.
	if v.Version < 245 {
		// Fill oauths.oauth_type with the same value as op.
		execLogged("update oauths set oauth_type = op")
		execLogged("update oauths set issuer = 'https://accounts.google.com' where op = 'google'")
		execLogged("update user_thirds set oauth_type = third_type, op = third_type")
		// Migrate legacy Google bindings by email: copy each binding's open_id
		// into the linked user's email column.
		uts := make([]model.UserThird, 0)
		if err := db.Where("oauth_type = ?", "google").Find(&uts).Error; err != nil {
			global.Logger.Errorf("loading google user_thirds failed: %v", err)
		}
		for _, ut := range uts {
			if ut.UserId == 0 {
				continue
			}
			err := db.Model(&model.User{}).Where("id = ?", ut.UserId).Update("email", ut.OpenId).Error
			if err != nil {
				global.Logger.Errorf("migrating email for user %v failed: %v", ut.UserId, err)
			}
		}
	}
	// Data fix for version 246.
	if v.Version < 246 {
		execLogged("update oauths set issuer = 'https://accounts.google.com' where op = 'google' and issuer is null")
	}
}
  262. func Migrate(version uint) {
  263. global.Logger.Info("Migrating....", version)
  264. err := global.DB.AutoMigrate(
  265. &model.Version{},
  266. &model.User{},
  267. &model.UserToken{},
  268. &model.Tag{},
  269. &model.AddressBook{},
  270. &model.Peer{},
  271. &model.Group{},
  272. &model.UserThird{},
  273. &model.Oauth{},
  274. &model.LoginLog{},
  275. &model.ShareRecord{},
  276. &model.AuditConn{},
  277. &model.AuditFile{},
  278. &model.AddressBookCollection{},
  279. &model.AddressBookCollectionRule{},
  280. &model.ServerCmd{},
  281. &model.DeviceGroup{},
  282. )
  283. if err != nil {
  284. global.Logger.Error("migrate err :=>", err)
  285. }
  286. global.DB.Create(&model.Version{Version: version})
  287. //如果是初次则创建一个默认用户
  288. var vc int64
  289. global.DB.Model(&model.Version{}).Count(&vc)
  290. if vc == 1 {
  291. localizer := global.Localizer("")
  292. defaultGroup, _ := localizer.LocalizeMessage(&i18n.Message{
  293. ID: "DefaultGroup",
  294. })
  295. group := &model.Group{
  296. Name: defaultGroup,
  297. Type: model.GroupTypeDefault,
  298. }
  299. service.AllService.GroupService.Create(group)
  300. shareGroup, _ := localizer.LocalizeMessage(&i18n.Message{
  301. ID: "ShareGroup",
  302. })
  303. groupShare := &model.Group{
  304. Name: shareGroup,
  305. Type: model.GroupTypeShare,
  306. }
  307. service.AllService.GroupService.Create(groupShare)
  308. //是true
  309. is_admin := true
  310. admin := &model.User{
  311. Username: "admin",
  312. Nickname: "Admin",
  313. Status: model.COMMON_STATUS_ENABLE,
  314. IsAdmin: &is_admin,
  315. GroupId: 1,
  316. }
  317. // 生成随机密码
  318. pwd := utils.RandomString(8)
  319. global.Logger.Info("Admin Password Is: ", pwd)
  320. var err error
  321. admin.Password, err = utils.EncryptPassword(pwd)
  322. if err != nil {
  323. global.Logger.Fatalf("failed to generate admin password: %v", err)
  324. }
  325. global.DB.Create(admin)
  326. }
  327. }