auth_store.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. package postgres
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "net/http"
  8. "time"
  9. "cmr-backend/internal/apperr"
  10. "github.com/jackc/pgx/v5"
  11. )
  12. type SMSCodeMeta struct {
  13. ID string
  14. CodeHash string
  15. ExpiresAt time.Time
  16. CooldownUntil time.Time
  17. }
  18. type CreateSMSCodeParams struct {
  19. Scene string
  20. CountryCode string
  21. Mobile string
  22. ClientType string
  23. DeviceKey string
  24. CodeHash string
  25. ProviderName string
  26. ProviderDebug map[string]any
  27. ExpiresAt time.Time
  28. CooldownUntil time.Time
  29. }
  30. type CreateMobileIdentityParams struct {
  31. UserID string
  32. IdentityType string
  33. Provider string
  34. ProviderSubj string
  35. CountryCode string
  36. Mobile string
  37. }
  38. type CreateIdentityParams struct {
  39. UserID string
  40. IdentityType string
  41. Provider string
  42. ProviderSubj string
  43. CountryCode *string
  44. Mobile *string
  45. ProfileJSON string
  46. }
  47. type CreateRefreshTokenParams struct {
  48. UserID string
  49. ClientType string
  50. DeviceKey string
  51. TokenHash string
  52. ExpiresAt time.Time
  53. }
  54. type RefreshTokenRecord struct {
  55. ID string
  56. UserID string
  57. ClientType string
  58. DeviceKey *string
  59. ExpiresAt time.Time
  60. IsRevoked bool
  61. }
  62. func (s *Store) GetLatestSMSCodeMeta(ctx context.Context, countryCode, mobile, clientType, scene string) (*SMSCodeMeta, error) {
  63. row := s.pool.QueryRow(ctx, `
  64. SELECT id, code_hash, expires_at, cooldown_until
  65. FROM auth_sms_codes
  66. WHERE country_code = $1 AND mobile = $2 AND client_type = $3 AND scene = $4
  67. ORDER BY created_at DESC
  68. LIMIT 1
  69. `, countryCode, mobile, clientType, scene)
  70. var record SMSCodeMeta
  71. err := row.Scan(&record.ID, &record.CodeHash, &record.ExpiresAt, &record.CooldownUntil)
  72. if errors.Is(err, pgx.ErrNoRows) {
  73. return nil, nil
  74. }
  75. if err != nil {
  76. return nil, fmt.Errorf("query latest sms code meta: %w", err)
  77. }
  78. return &record, nil
  79. }
  80. func (s *Store) CreateSMSCode(ctx context.Context, params CreateSMSCodeParams) error {
  81. payload, err := json.Marshal(map[string]any{
  82. "provider": params.ProviderName,
  83. "debug": params.ProviderDebug,
  84. })
  85. if err != nil {
  86. return err
  87. }
  88. _, err = s.pool.Exec(ctx, `
  89. INSERT INTO auth_sms_codes (
  90. scene, country_code, mobile, client_type, device_key, code_hash,
  91. provider_payload_jsonb, expires_at, cooldown_until
  92. )
  93. VALUES ($1, $2, $3, $4, $5, $6, $7::jsonb, $8, $9)
  94. `, params.Scene, params.CountryCode, params.Mobile, params.ClientType, params.DeviceKey, params.CodeHash, string(payload), params.ExpiresAt, params.CooldownUntil)
  95. if err != nil {
  96. return fmt.Errorf("insert sms code: %w", err)
  97. }
  98. return nil
  99. }
  100. func (s *Store) GetLatestValidSMSCode(ctx context.Context, countryCode, mobile, clientType, scene string) (*SMSCodeMeta, error) {
  101. row := s.pool.QueryRow(ctx, `
  102. SELECT id, code_hash, expires_at, cooldown_until
  103. FROM auth_sms_codes
  104. WHERE country_code = $1
  105. AND mobile = $2
  106. AND client_type = $3
  107. AND scene = $4
  108. AND consumed_at IS NULL
  109. AND expires_at > NOW()
  110. ORDER BY created_at DESC
  111. LIMIT 1
  112. `, countryCode, mobile, clientType, scene)
  113. var record SMSCodeMeta
  114. err := row.Scan(&record.ID, &record.CodeHash, &record.ExpiresAt, &record.CooldownUntil)
  115. if errors.Is(err, pgx.ErrNoRows) {
  116. return nil, nil
  117. }
  118. if err != nil {
  119. return nil, fmt.Errorf("query latest valid sms code: %w", err)
  120. }
  121. return &record, nil
  122. }
  123. func (s *Store) ConsumeSMSCode(ctx context.Context, tx Tx, id string) (bool, error) {
  124. commandTag, err := tx.Exec(ctx, `
  125. UPDATE auth_sms_codes
  126. SET consumed_at = NOW()
  127. WHERE id = $1 AND consumed_at IS NULL
  128. `, id)
  129. if err != nil {
  130. return false, fmt.Errorf("consume sms code: %w", err)
  131. }
  132. return commandTag.RowsAffected() == 1, nil
  133. }
  134. func (s *Store) CreateMobileIdentity(ctx context.Context, tx Tx, params CreateMobileIdentityParams) error {
  135. countryCode := params.CountryCode
  136. mobile := params.Mobile
  137. return s.CreateIdentity(ctx, tx, CreateIdentityParams{
  138. UserID: params.UserID,
  139. IdentityType: params.IdentityType,
  140. Provider: params.Provider,
  141. ProviderSubj: params.ProviderSubj,
  142. CountryCode: &countryCode,
  143. Mobile: &mobile,
  144. ProfileJSON: "{}",
  145. })
  146. }
  147. func (s *Store) CreateIdentity(ctx context.Context, tx Tx, params CreateIdentityParams) error {
  148. _, err := tx.Exec(ctx, `
  149. INSERT INTO login_identities (
  150. user_id, identity_type, provider, provider_subject, country_code, mobile, status, profile_jsonb
  151. )
  152. VALUES ($1, $2, $3, $4, $5, $6, 'active', $7::jsonb)
  153. ON CONFLICT (provider, provider_subject) DO NOTHING
  154. `, params.UserID, params.IdentityType, params.Provider, params.ProviderSubj, params.CountryCode, params.Mobile, zeroJSON(params.ProfileJSON))
  155. if err != nil {
  156. return fmt.Errorf("create identity: %w", err)
  157. }
  158. return nil
  159. }
  160. func (s *Store) FindUserByProviderSubject(ctx context.Context, tx Tx, provider, providerSubject string) (*User, error) {
  161. row := tx.QueryRow(ctx, `
  162. SELECT u.id, u.user_public_id, u.status, u.nickname, u.avatar_url
  163. FROM users u
  164. JOIN login_identities li ON li.user_id = u.id
  165. WHERE li.provider = $1
  166. AND li.provider_subject = $2
  167. AND li.status = 'active'
  168. LIMIT 1
  169. `, provider, providerSubject)
  170. return scanUser(row)
  171. }
  172. func (s *Store) CreateRefreshToken(ctx context.Context, tx Tx, params CreateRefreshTokenParams) (string, error) {
  173. row := tx.QueryRow(ctx, `
  174. INSERT INTO auth_refresh_tokens (user_id, client_type, device_key, token_hash, expires_at)
  175. VALUES ($1, $2, NULLIF($3, ''), $4, $5)
  176. RETURNING id
  177. `, params.UserID, params.ClientType, params.DeviceKey, params.TokenHash, params.ExpiresAt)
  178. var id string
  179. if err := row.Scan(&id); err != nil {
  180. return "", fmt.Errorf("create refresh token: %w", err)
  181. }
  182. return id, nil
  183. }
  184. func (s *Store) GetRefreshTokenForUpdate(ctx context.Context, tx Tx, tokenHash string) (*RefreshTokenRecord, error) {
  185. row := tx.QueryRow(ctx, `
  186. SELECT id, user_id, client_type, device_key, expires_at, revoked_at IS NOT NULL
  187. FROM auth_refresh_tokens
  188. WHERE token_hash = $1
  189. FOR UPDATE
  190. `, tokenHash)
  191. var record RefreshTokenRecord
  192. err := row.Scan(&record.ID, &record.UserID, &record.ClientType, &record.DeviceKey, &record.ExpiresAt, &record.IsRevoked)
  193. if errors.Is(err, pgx.ErrNoRows) {
  194. return nil, nil
  195. }
  196. if err != nil {
  197. return nil, fmt.Errorf("query refresh token for update: %w", err)
  198. }
  199. return &record, nil
  200. }
  201. func (s *Store) RotateRefreshToken(ctx context.Context, tx Tx, oldTokenID, newTokenID string) error {
  202. _, err := tx.Exec(ctx, `
  203. UPDATE auth_refresh_tokens
  204. SET revoked_at = NOW(), replaced_by_token_id = $2
  205. WHERE id = $1
  206. `, oldTokenID, newTokenID)
  207. if err != nil {
  208. return fmt.Errorf("rotate refresh token: %w", err)
  209. }
  210. return nil
  211. }
  212. func (s *Store) RevokeRefreshToken(ctx context.Context, tokenHash string) error {
  213. commandTag, err := s.pool.Exec(ctx, `
  214. UPDATE auth_refresh_tokens
  215. SET revoked_at = COALESCE(revoked_at, NOW())
  216. WHERE token_hash = $1
  217. `, tokenHash)
  218. if err != nil {
  219. return fmt.Errorf("revoke refresh token: %w", err)
  220. }
  221. if commandTag.RowsAffected() == 0 {
  222. return apperr.New(http.StatusNotFound, "refresh_token_not_found", "refresh token not found")
  223. }
  224. return nil
  225. }
  226. func (s *Store) RevokeRefreshTokensByUserID(ctx context.Context, tx Tx, userID string) error {
  227. _, err := tx.Exec(ctx, `
  228. UPDATE auth_refresh_tokens
  229. SET revoked_at = COALESCE(revoked_at, NOW())
  230. WHERE user_id = $1
  231. `, userID)
  232. if err != nil {
  233. return fmt.Errorf("revoke refresh tokens by user id: %w", err)
  234. }
  235. return nil
  236. }
  237. func (s *Store) TransferNonMobileIdentities(ctx context.Context, tx Tx, sourceUserID, targetUserID string) error {
  238. if sourceUserID == targetUserID {
  239. return nil
  240. }
  241. _, err := tx.Exec(ctx, `
  242. INSERT INTO login_identities (
  243. user_id, identity_type, provider, provider_subject, country_code, mobile, status, profile_jsonb, created_at, updated_at
  244. )
  245. SELECT
  246. $2,
  247. li.identity_type,
  248. li.provider,
  249. li.provider_subject,
  250. li.country_code,
  251. li.mobile,
  252. li.status,
  253. li.profile_jsonb,
  254. li.created_at,
  255. li.updated_at
  256. FROM login_identities li
  257. WHERE li.user_id = $1
  258. AND li.provider <> 'mobile'
  259. ON CONFLICT (provider, provider_subject) DO NOTHING
  260. `, sourceUserID, targetUserID)
  261. if err != nil {
  262. return fmt.Errorf("copy non-mobile identities: %w", err)
  263. }
  264. _, err = tx.Exec(ctx, `
  265. DELETE FROM login_identities
  266. WHERE user_id = $1
  267. AND provider <> 'mobile'
  268. `, sourceUserID)
  269. if err != nil {
  270. return fmt.Errorf("delete source non-mobile identities: %w", err)
  271. }
  272. return nil
  273. }
  274. func zeroJSON(value string) string {
  275. if value == "" {
  276. return "{}"
  277. }
  278. return value
  279. }