auth_service.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595
  1. package service
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"net/http"
	"strings"
	"time"

	"cmr-backend/internal/apperr"
	"cmr-backend/internal/platform/jwtx"
	"cmr-backend/internal/platform/security"
	"cmr-backend/internal/platform/wechatmini"
	"cmr-backend/internal/store/postgres"
)
  14. type AuthSettings struct {
  15. AppEnv string
  16. RefreshTTL time.Duration
  17. SMSCodeTTL time.Duration
  18. SMSCodeCooldown time.Duration
  19. SMSProvider string
  20. DevSMSCode string
  21. WechatMini *wechatmini.Client
  22. }
  23. type AuthService struct {
  24. cfg AuthSettings
  25. store *postgres.Store
  26. jwtManager *jwtx.Manager
  27. }
  28. type SendSMSCodeInput struct {
  29. CountryCode string `json:"countryCode"`
  30. Mobile string `json:"mobile"`
  31. ClientType string `json:"clientType"`
  32. DeviceKey string `json:"deviceKey"`
  33. Scene string `json:"scene"`
  34. }
  35. type SendSMSCodeResult struct {
  36. TTLSeconds int64 `json:"ttlSeconds"`
  37. CooldownSeconds int64 `json:"cooldownSeconds"`
  38. DevCode *string `json:"devCode,omitempty"`
  39. }
  40. type LoginSMSInput struct {
  41. CountryCode string `json:"countryCode"`
  42. Mobile string `json:"mobile"`
  43. Code string `json:"code"`
  44. ClientType string `json:"clientType"`
  45. DeviceKey string `json:"deviceKey"`
  46. }
  47. type LoginWechatMiniInput struct {
  48. Code string `json:"code"`
  49. ClientType string `json:"clientType"`
  50. DeviceKey string `json:"deviceKey"`
  51. }
  52. type BindMobileInput struct {
  53. UserID string `json:"-"`
  54. CountryCode string `json:"countryCode"`
  55. Mobile string `json:"mobile"`
  56. Code string `json:"code"`
  57. ClientType string `json:"clientType"`
  58. DeviceKey string `json:"deviceKey"`
  59. }
  60. type RefreshTokenInput struct {
  61. RefreshToken string `json:"refreshToken"`
  62. ClientType string `json:"clientType"`
  63. DeviceKey string `json:"deviceKey"`
  64. }
  65. type LogoutInput struct {
  66. RefreshToken string `json:"refreshToken"`
  67. UserID string `json:"-"`
  68. }
  69. type AuthUser struct {
  70. ID string `json:"id"`
  71. PublicID string `json:"publicId"`
  72. Status string `json:"status"`
  73. Nickname *string `json:"nickname,omitempty"`
  74. AvatarURL *string `json:"avatarUrl,omitempty"`
  75. }
  76. type AuthTokens struct {
  77. AccessToken string `json:"accessToken"`
  78. AccessTokenExpiresAt string `json:"accessTokenExpiresAt"`
  79. RefreshToken string `json:"refreshToken"`
  80. RefreshTokenExpiresAt string `json:"refreshTokenExpiresAt"`
  81. }
  82. type AuthResult struct {
  83. User AuthUser `json:"user"`
  84. Tokens AuthTokens `json:"tokens"`
  85. NewUser bool `json:"newUser"`
  86. }
  87. func NewAuthService(cfg AuthSettings, store *postgres.Store, jwtManager *jwtx.Manager) *AuthService {
  88. return &AuthService{
  89. cfg: cfg,
  90. store: store,
  91. jwtManager: jwtManager,
  92. }
  93. }
  94. func (s *AuthService) SendSMSCode(ctx context.Context, input SendSMSCodeInput) (*SendSMSCodeResult, error) {
  95. input.CountryCode = normalizeCountryCode(input.CountryCode)
  96. input.Mobile = normalizeMobile(input.Mobile)
  97. input.Scene = normalizeScene(input.Scene)
  98. if err := validateClientType(input.ClientType); err != nil {
  99. return nil, err
  100. }
  101. if input.Mobile == "" || input.DeviceKey == "" {
  102. return nil, apperr.New(http.StatusBadRequest, "invalid_params", "mobile and deviceKey are required")
  103. }
  104. latest, err := s.store.GetLatestSMSCodeMeta(ctx, input.CountryCode, input.Mobile, input.ClientType, input.Scene)
  105. if err != nil {
  106. return nil, err
  107. }
  108. now := time.Now().UTC()
  109. if latest != nil && latest.CooldownUntil.After(now) {
  110. return nil, apperr.New(http.StatusTooManyRequests, "sms_cooldown", "sms code sent too frequently")
  111. }
  112. code := s.cfg.DevSMSCode
  113. if code == "" {
  114. code, err = security.GenerateNumericCode(6)
  115. if err != nil {
  116. return nil, err
  117. }
  118. }
  119. expiresAt := now.Add(s.cfg.SMSCodeTTL)
  120. cooldownUntil := now.Add(s.cfg.SMSCodeCooldown)
  121. if err := s.store.CreateSMSCode(ctx, postgres.CreateSMSCodeParams{
  122. Scene: input.Scene,
  123. CountryCode: input.CountryCode,
  124. Mobile: input.Mobile,
  125. ClientType: input.ClientType,
  126. DeviceKey: input.DeviceKey,
  127. CodeHash: security.HashText(code),
  128. ProviderName: s.cfg.SMSProvider,
  129. ProviderDebug: map[string]any{"mode": s.cfg.SMSProvider},
  130. ExpiresAt: expiresAt,
  131. CooldownUntil: cooldownUntil,
  132. }); err != nil {
  133. return nil, err
  134. }
  135. result := &SendSMSCodeResult{
  136. TTLSeconds: int64(s.cfg.SMSCodeTTL.Seconds()),
  137. CooldownSeconds: int64(s.cfg.SMSCodeCooldown.Seconds()),
  138. }
  139. if strings.EqualFold(s.cfg.SMSProvider, "console") || strings.EqualFold(s.cfg.AppEnv, "development") {
  140. result.DevCode = &code
  141. }
  142. return result, nil
  143. }
  144. func (s *AuthService) LoginSMS(ctx context.Context, input LoginSMSInput) (*AuthResult, error) {
  145. input.CountryCode = normalizeCountryCode(input.CountryCode)
  146. input.Mobile = normalizeMobile(input.Mobile)
  147. input.Code = strings.TrimSpace(input.Code)
  148. if err := validateClientType(input.ClientType); err != nil {
  149. return nil, err
  150. }
  151. if input.Mobile == "" || input.DeviceKey == "" || input.Code == "" {
  152. return nil, apperr.New(http.StatusBadRequest, "invalid_params", "mobile, code and deviceKey are required")
  153. }
  154. codeRecord, err := s.store.GetLatestValidSMSCode(ctx, input.CountryCode, input.Mobile, input.ClientType, "login")
  155. if err != nil {
  156. return nil, err
  157. }
  158. if codeRecord == nil || codeRecord.CodeHash != security.HashText(input.Code) {
  159. return nil, apperr.New(http.StatusUnauthorized, "invalid_sms_code", "invalid sms code")
  160. }
  161. tx, err := s.store.Begin(ctx)
  162. if err != nil {
  163. return nil, err
  164. }
  165. defer tx.Rollback(ctx)
  166. consumed, err := s.store.ConsumeSMSCode(ctx, tx, codeRecord.ID)
  167. if err != nil {
  168. return nil, err
  169. }
  170. if !consumed {
  171. return nil, apperr.New(http.StatusUnauthorized, "invalid_sms_code", "sms code already used")
  172. }
  173. user, err := s.store.FindUserByMobile(ctx, tx, input.CountryCode, input.Mobile)
  174. if err != nil {
  175. return nil, err
  176. }
  177. newUser := false
  178. if user == nil {
  179. userPublicID, err := security.GeneratePublicID("usr")
  180. if err != nil {
  181. return nil, err
  182. }
  183. user, err = s.store.CreateUser(ctx, tx, postgres.CreateUserParams{
  184. PublicID: userPublicID,
  185. Status: "active",
  186. })
  187. if err != nil {
  188. return nil, err
  189. }
  190. if err := s.store.CreateMobileIdentity(ctx, tx, postgres.CreateMobileIdentityParams{
  191. UserID: user.ID,
  192. CountryCode: input.CountryCode,
  193. Mobile: input.Mobile,
  194. Provider: "mobile",
  195. ProviderSubj: input.CountryCode + ":" + input.Mobile,
  196. IdentityType: "mobile",
  197. }); err != nil {
  198. return nil, err
  199. }
  200. newUser = true
  201. }
  202. if err := s.store.TouchUserLogin(ctx, tx, user.ID); err != nil {
  203. return nil, err
  204. }
  205. result, err := s.issueAuthResult(ctx, tx, *user, input.ClientType, input.DeviceKey, newUser)
  206. if err != nil {
  207. return nil, err
  208. }
  209. if err := tx.Commit(ctx); err != nil {
  210. return nil, err
  211. }
  212. return result, nil
  213. }
  214. func (s *AuthService) Refresh(ctx context.Context, input RefreshTokenInput) (*AuthResult, error) {
  215. input.RefreshToken = strings.TrimSpace(input.RefreshToken)
  216. if err := validateClientType(input.ClientType); err != nil {
  217. return nil, err
  218. }
  219. if input.RefreshToken == "" {
  220. return nil, apperr.New(http.StatusBadRequest, "invalid_params", "refreshToken is required")
  221. }
  222. tx, err := s.store.Begin(ctx)
  223. if err != nil {
  224. return nil, err
  225. }
  226. defer tx.Rollback(ctx)
  227. record, err := s.store.GetRefreshTokenForUpdate(ctx, tx, security.HashText(input.RefreshToken))
  228. if err != nil {
  229. return nil, err
  230. }
  231. if record == nil || record.IsRevoked || record.ExpiresAt.Before(time.Now().UTC()) {
  232. return nil, apperr.New(http.StatusUnauthorized, "invalid_refresh_token", "refresh token is invalid or expired")
  233. }
  234. if input.ClientType != "" && input.ClientType != record.ClientType {
  235. return nil, apperr.New(http.StatusUnauthorized, "invalid_refresh_token", "refresh token client mismatch")
  236. }
  237. if input.DeviceKey != "" && record.DeviceKey != nil && input.DeviceKey != *record.DeviceKey {
  238. return nil, apperr.New(http.StatusUnauthorized, "invalid_refresh_token", "refresh token device mismatch")
  239. }
  240. user, err := s.store.GetUserByID(ctx, tx, record.UserID)
  241. if err != nil {
  242. return nil, err
  243. }
  244. if user == nil {
  245. return nil, apperr.New(http.StatusUnauthorized, "invalid_refresh_token", "refresh token user not found")
  246. }
  247. result, refreshTokenID, err := s.issueAuthResultWithRefreshID(ctx, tx, *user, record.ClientType, nullableStringValue(record.DeviceKey), false)
  248. if err != nil {
  249. return nil, err
  250. }
  251. if err := s.store.RotateRefreshToken(ctx, tx, record.ID, refreshTokenID); err != nil {
  252. return nil, err
  253. }
  254. if err := tx.Commit(ctx); err != nil {
  255. return nil, err
  256. }
  257. return result, nil
  258. }
  259. func (s *AuthService) LoginWechatMini(ctx context.Context, input LoginWechatMiniInput) (*AuthResult, error) {
  260. input.Code = strings.TrimSpace(input.Code)
  261. if err := validateClientType(input.ClientType); err != nil {
  262. return nil, err
  263. }
  264. if input.ClientType != "wechat" {
  265. return nil, apperr.New(http.StatusBadRequest, "invalid_client_type", "wechat mini login requires clientType=wechat")
  266. }
  267. if input.Code == "" || strings.TrimSpace(input.DeviceKey) == "" {
  268. return nil, apperr.New(http.StatusBadRequest, "invalid_params", "code and deviceKey are required")
  269. }
  270. if s.cfg.WechatMini == nil {
  271. return nil, apperr.New(http.StatusNotImplemented, "wechat_not_configured", "wechat mini provider is not configured")
  272. }
  273. session, err := s.cfg.WechatMini.ExchangeCode(ctx, input.Code)
  274. if err != nil {
  275. return nil, apperr.New(http.StatusUnauthorized, "wechat_login_failed", err.Error())
  276. }
  277. openIDSubject := session.AppID + ":" + session.OpenID
  278. unionIDSubject := strings.TrimSpace(session.UnionID)
  279. tx, err := s.store.Begin(ctx)
  280. if err != nil {
  281. return nil, err
  282. }
  283. defer tx.Rollback(ctx)
  284. user, err := s.store.FindUserByProviderSubject(ctx, tx, "wechat_mini", openIDSubject)
  285. if err != nil {
  286. return nil, err
  287. }
  288. if user == nil && unionIDSubject != "" {
  289. user, err = s.store.FindUserByProviderSubject(ctx, tx, "wechat_unionid", unionIDSubject)
  290. if err != nil {
  291. return nil, err
  292. }
  293. }
  294. newUser := false
  295. if user == nil {
  296. userPublicID, err := security.GeneratePublicID("usr")
  297. if err != nil {
  298. return nil, err
  299. }
  300. user, err = s.store.CreateUser(ctx, tx, postgres.CreateUserParams{
  301. PublicID: userPublicID,
  302. Status: "active",
  303. })
  304. if err != nil {
  305. return nil, err
  306. }
  307. newUser = true
  308. }
  309. profileJSON, err := json.Marshal(map[string]any{
  310. "appId": session.AppID,
  311. })
  312. if err != nil {
  313. return nil, err
  314. }
  315. if err := s.store.CreateIdentity(ctx, tx, postgres.CreateIdentityParams{
  316. UserID: user.ID,
  317. IdentityType: "wechat_mini_openid",
  318. Provider: "wechat_mini",
  319. ProviderSubj: openIDSubject,
  320. ProfileJSON: string(profileJSON),
  321. }); err != nil {
  322. return nil, err
  323. }
  324. if unionIDSubject != "" {
  325. if err := s.store.CreateIdentity(ctx, tx, postgres.CreateIdentityParams{
  326. UserID: user.ID,
  327. IdentityType: "wechat_unionid",
  328. Provider: "wechat_unionid",
  329. ProviderSubj: unionIDSubject,
  330. ProfileJSON: "{}",
  331. }); err != nil {
  332. return nil, err
  333. }
  334. }
  335. if err := s.store.TouchUserLogin(ctx, tx, user.ID); err != nil {
  336. return nil, err
  337. }
  338. result, err := s.issueAuthResult(ctx, tx, *user, input.ClientType, input.DeviceKey, newUser)
  339. if err != nil {
  340. return nil, err
  341. }
  342. if err := tx.Commit(ctx); err != nil {
  343. return nil, err
  344. }
  345. return result, nil
  346. }
  347. func (s *AuthService) BindMobile(ctx context.Context, input BindMobileInput) (*AuthResult, error) {
  348. input.CountryCode = normalizeCountryCode(input.CountryCode)
  349. input.Mobile = normalizeMobile(input.Mobile)
  350. input.Code = strings.TrimSpace(input.Code)
  351. if err := validateClientType(input.ClientType); err != nil {
  352. return nil, err
  353. }
  354. if input.UserID == "" || input.Mobile == "" || input.Code == "" || strings.TrimSpace(input.DeviceKey) == "" {
  355. return nil, apperr.New(http.StatusBadRequest, "invalid_params", "user, mobile, code and deviceKey are required")
  356. }
  357. codeRecord, err := s.store.GetLatestValidSMSCode(ctx, input.CountryCode, input.Mobile, input.ClientType, "bind_mobile")
  358. if err != nil {
  359. return nil, err
  360. }
  361. if codeRecord == nil || codeRecord.CodeHash != security.HashText(input.Code) {
  362. return nil, apperr.New(http.StatusUnauthorized, "invalid_sms_code", "invalid sms code")
  363. }
  364. tx, err := s.store.Begin(ctx)
  365. if err != nil {
  366. return nil, err
  367. }
  368. defer tx.Rollback(ctx)
  369. consumed, err := s.store.ConsumeSMSCode(ctx, tx, codeRecord.ID)
  370. if err != nil {
  371. return nil, err
  372. }
  373. if !consumed {
  374. return nil, apperr.New(http.StatusUnauthorized, "invalid_sms_code", "sms code already used")
  375. }
  376. currentUser, err := s.store.GetUserByID(ctx, tx, input.UserID)
  377. if err != nil {
  378. return nil, err
  379. }
  380. if currentUser == nil {
  381. return nil, apperr.New(http.StatusNotFound, "user_not_found", "current user not found")
  382. }
  383. mobileUser, err := s.store.FindUserByMobile(ctx, tx, input.CountryCode, input.Mobile)
  384. if err != nil {
  385. return nil, err
  386. }
  387. finalUser := currentUser
  388. newlyBound := false
  389. if mobileUser == nil {
  390. if err := s.store.CreateMobileIdentity(ctx, tx, postgres.CreateMobileIdentityParams{
  391. UserID: currentUser.ID,
  392. CountryCode: input.CountryCode,
  393. Mobile: input.Mobile,
  394. Provider: "mobile",
  395. ProviderSubj: input.CountryCode + ":" + input.Mobile,
  396. IdentityType: "mobile",
  397. }); err != nil {
  398. return nil, err
  399. }
  400. newlyBound = true
  401. } else if mobileUser.ID != currentUser.ID {
  402. if err := s.store.TransferNonMobileIdentities(ctx, tx, currentUser.ID, mobileUser.ID); err != nil {
  403. return nil, err
  404. }
  405. if err := s.store.RevokeRefreshTokensByUserID(ctx, tx, currentUser.ID); err != nil {
  406. return nil, err
  407. }
  408. if err := s.store.DeactivateUser(ctx, tx, currentUser.ID); err != nil {
  409. return nil, err
  410. }
  411. finalUser = mobileUser
  412. }
  413. if err := s.store.TouchUserLogin(ctx, tx, finalUser.ID); err != nil {
  414. return nil, err
  415. }
  416. result, err := s.issueAuthResult(ctx, tx, *finalUser, input.ClientType, input.DeviceKey, newlyBound)
  417. if err != nil {
  418. return nil, err
  419. }
  420. if err := tx.Commit(ctx); err != nil {
  421. return nil, err
  422. }
  423. return result, nil
  424. }
  425. func (s *AuthService) Logout(ctx context.Context, input LogoutInput) error {
  426. if strings.TrimSpace(input.RefreshToken) == "" {
  427. return nil
  428. }
  429. return s.store.RevokeRefreshToken(ctx, security.HashText(strings.TrimSpace(input.RefreshToken)))
  430. }
  431. func (s *AuthService) issueAuthResult(
  432. ctx context.Context,
  433. tx postgres.Tx,
  434. user postgres.User,
  435. clientType string,
  436. deviceKey string,
  437. newUser bool,
  438. ) (*AuthResult, error) {
  439. result, _, err := s.issueAuthResultWithRefreshID(ctx, tx, user, clientType, deviceKey, newUser)
  440. return result, err
  441. }
  442. func (s *AuthService) issueAuthResultWithRefreshID(
  443. ctx context.Context,
  444. tx postgres.Tx,
  445. user postgres.User,
  446. clientType string,
  447. deviceKey string,
  448. newUser bool,
  449. ) (*AuthResult, string, error) {
  450. accessToken, accessExpiresAt, err := s.jwtManager.IssueAccessToken(user.ID, user.PublicID)
  451. if err != nil {
  452. return nil, "", err
  453. }
  454. refreshToken, err := security.GenerateToken(32)
  455. if err != nil {
  456. return nil, "", err
  457. }
  458. refreshTokenHash := security.HashText(refreshToken)
  459. refreshExpiresAt := time.Now().UTC().Add(s.cfg.RefreshTTL)
  460. refreshID, err := s.store.CreateRefreshToken(ctx, tx, postgres.CreateRefreshTokenParams{
  461. UserID: user.ID,
  462. ClientType: clientType,
  463. DeviceKey: deviceKey,
  464. TokenHash: refreshTokenHash,
  465. ExpiresAt: refreshExpiresAt,
  466. })
  467. if err != nil {
  468. return nil, "", err
  469. }
  470. return &AuthResult{
  471. User: AuthUser{
  472. ID: user.ID,
  473. PublicID: user.PublicID,
  474. Status: user.Status,
  475. Nickname: user.Nickname,
  476. AvatarURL: user.AvatarURL,
  477. },
  478. Tokens: AuthTokens{
  479. AccessToken: accessToken,
  480. AccessTokenExpiresAt: accessExpiresAt.Format(time.RFC3339),
  481. RefreshToken: refreshToken,
  482. RefreshTokenExpiresAt: refreshExpiresAt.Format(time.RFC3339),
  483. },
  484. NewUser: newUser,
  485. }, refreshID, nil
  486. }
  487. func validateClientType(clientType string) error {
  488. switch clientType {
  489. case "app", "wechat":
  490. return nil
  491. default:
  492. return apperr.New(http.StatusBadRequest, "invalid_client_type", "clientType must be app or wechat")
  493. }
  494. }
  495. func normalizeCountryCode(value string) string {
  496. value = strings.TrimSpace(value)
  497. if value == "" {
  498. return "86"
  499. }
  500. return strings.TrimPrefix(value, "+")
  501. }
  502. func normalizeMobile(value string) string {
  503. value = strings.TrimSpace(value)
  504. value = strings.ReplaceAll(value, " ", "")
  505. value = strings.ReplaceAll(value, "-", "")
  506. return value
  507. }
  508. func normalizeScene(value string) string {
  509. value = strings.TrimSpace(value)
  510. if value == "" {
  511. return "login"
  512. }
  513. return value
  514. }
  515. func nullableStringValue(value *string) string {
  516. if value == nil {
  517. return ""
  518. }
  519. return *value
  520. }