| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310 |
- package postgres
- import (
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "net/http"
- "time"
- "cmr-backend/internal/apperr"
- "github.com/jackc/pgx/v5"
- )
// SMSCodeMeta is the subset of an auth_sms_codes row returned by the SMS code
// lookups in this file (columns id, code_hash, expires_at, cooldown_until).
type SMSCodeMeta struct {
	// ID is the row's id column.
	ID string
	// CodeHash is the stored code_hash; only a hash is written and read here,
	// never the code itself.
	CodeHash string
	// ExpiresAt is the expires_at column; GetLatestValidSMSCode only returns
	// rows where it is still in the future.
	ExpiresAt time.Time
	// CooldownUntil is the cooldown_until column.
	// NOTE(review): presumably the earliest time another code may be sent to the
	// same number — confirm against the service layer.
	CooldownUntil time.Time
}
// CreateSMSCodeParams holds the values CreateSMSCode inserts into a new
// auth_sms_codes row.
type CreateSMSCodeParams struct {
	// Scene, CountryCode, Mobile and ClientType are the lookup key later used by
	// GetLatestSMSCodeMeta and GetLatestValidSMSCode.
	Scene       string
	CountryCode string
	Mobile      string
	ClientType  string
	// DeviceKey is stored verbatim; an empty string is NOT converted to NULL
	// (unlike CreateRefreshTokenParams.DeviceKey).
	DeviceKey string
	// CodeHash is written to code_hash.
	CodeHash string
	// ProviderName and ProviderDebug are serialized together as
	// {"provider": ProviderName, "debug": ProviderDebug} into
	// provider_payload_jsonb. A nil ProviderDebug encodes as JSON null.
	ProviderName  string
	ProviderDebug map[string]any
	// ExpiresAt and CooldownUntil are written to expires_at and cooldown_until.
	ExpiresAt     time.Time
	CooldownUntil time.Time
}
// CreateMobileIdentityParams is the input to CreateMobileIdentity. It mirrors
// CreateIdentityParams, except that CountryCode and Mobile are required
// (non-pointer) and no profile JSON is accepted; CreateMobileIdentity always
// stores "{}" as the profile.
type CreateMobileIdentityParams struct {
	UserID       string
	IdentityType string
	Provider     string
	// ProviderSubj is written to login_identities.provider_subject.
	ProviderSubj string
	CountryCode  string
	Mobile       string
}
// CreateIdentityParams holds the values CreateIdentity inserts into
// login_identities.
type CreateIdentityParams struct {
	UserID       string
	IdentityType string
	Provider     string
	// ProviderSubj is written to provider_subject; together with Provider it is
	// the ON CONFLICT key used by CreateIdentity.
	ProviderSubj string
	// CountryCode and Mobile are optional; nil pointers are handed to the driver
	// as-is (presumably stored as SQL NULL — confirm with pgx encoding rules).
	CountryCode *string
	Mobile      *string
	// ProfileJSON is cast to jsonb; an empty string is replaced with "{}".
	ProfileJSON string
}
// CreateRefreshTokenParams holds the values CreateRefreshToken inserts into
// auth_refresh_tokens.
type CreateRefreshTokenParams struct {
	UserID     string
	ClientType string
	// DeviceKey is stored as NULL when empty (NULLIF in the INSERT).
	DeviceKey string
	// TokenHash is written to token_hash and is the lookup key used by
	// GetRefreshTokenForUpdate and RevokeRefreshToken.
	TokenHash string
	ExpiresAt time.Time
}
// RefreshTokenRecord is an auth_refresh_tokens row as loaded by
// GetRefreshTokenForUpdate.
type RefreshTokenRecord struct {
	ID         string
	UserID     string
	ClientType string
	// DeviceKey is nil when the column is NULL (CreateRefreshToken stores empty
	// device keys as NULL).
	DeviceKey *string
	ExpiresAt time.Time
	// IsRevoked is computed as "revoked_at IS NOT NULL"; the timestamp itself is
	// not loaded.
	IsRevoked bool
}
- func (s *Store) GetLatestSMSCodeMeta(ctx context.Context, countryCode, mobile, clientType, scene string) (*SMSCodeMeta, error) {
- row := s.pool.QueryRow(ctx, `
- SELECT id, code_hash, expires_at, cooldown_until
- FROM auth_sms_codes
- WHERE country_code = $1 AND mobile = $2 AND client_type = $3 AND scene = $4
- ORDER BY created_at DESC
- LIMIT 1
- `, countryCode, mobile, clientType, scene)
- var record SMSCodeMeta
- err := row.Scan(&record.ID, &record.CodeHash, &record.ExpiresAt, &record.CooldownUntil)
- if errors.Is(err, pgx.ErrNoRows) {
- return nil, nil
- }
- if err != nil {
- return nil, fmt.Errorf("query latest sms code meta: %w", err)
- }
- return &record, nil
- }
- func (s *Store) CreateSMSCode(ctx context.Context, params CreateSMSCodeParams) error {
- payload, err := json.Marshal(map[string]any{
- "provider": params.ProviderName,
- "debug": params.ProviderDebug,
- })
- if err != nil {
- return err
- }
- _, err = s.pool.Exec(ctx, `
- INSERT INTO auth_sms_codes (
- scene, country_code, mobile, client_type, device_key, code_hash,
- provider_payload_jsonb, expires_at, cooldown_until
- )
- VALUES ($1, $2, $3, $4, $5, $6, $7::jsonb, $8, $9)
- `, params.Scene, params.CountryCode, params.Mobile, params.ClientType, params.DeviceKey, params.CodeHash, string(payload), params.ExpiresAt, params.CooldownUntil)
- if err != nil {
- return fmt.Errorf("insert sms code: %w", err)
- }
- return nil
- }
- func (s *Store) GetLatestValidSMSCode(ctx context.Context, countryCode, mobile, clientType, scene string) (*SMSCodeMeta, error) {
- row := s.pool.QueryRow(ctx, `
- SELECT id, code_hash, expires_at, cooldown_until
- FROM auth_sms_codes
- WHERE country_code = $1
- AND mobile = $2
- AND client_type = $3
- AND scene = $4
- AND consumed_at IS NULL
- AND expires_at > NOW()
- ORDER BY created_at DESC
- LIMIT 1
- `, countryCode, mobile, clientType, scene)
- var record SMSCodeMeta
- err := row.Scan(&record.ID, &record.CodeHash, &record.ExpiresAt, &record.CooldownUntil)
- if errors.Is(err, pgx.ErrNoRows) {
- return nil, nil
- }
- if err != nil {
- return nil, fmt.Errorf("query latest valid sms code: %w", err)
- }
- return &record, nil
- }
- func (s *Store) ConsumeSMSCode(ctx context.Context, tx Tx, id string) (bool, error) {
- commandTag, err := tx.Exec(ctx, `
- UPDATE auth_sms_codes
- SET consumed_at = NOW()
- WHERE id = $1 AND consumed_at IS NULL
- `, id)
- if err != nil {
- return false, fmt.Errorf("consume sms code: %w", err)
- }
- return commandTag.RowsAffected() == 1, nil
- }
- func (s *Store) CreateMobileIdentity(ctx context.Context, tx Tx, params CreateMobileIdentityParams) error {
- countryCode := params.CountryCode
- mobile := params.Mobile
- return s.CreateIdentity(ctx, tx, CreateIdentityParams{
- UserID: params.UserID,
- IdentityType: params.IdentityType,
- Provider: params.Provider,
- ProviderSubj: params.ProviderSubj,
- CountryCode: &countryCode,
- Mobile: &mobile,
- ProfileJSON: "{}",
- })
- }
- func (s *Store) CreateIdentity(ctx context.Context, tx Tx, params CreateIdentityParams) error {
- _, err := tx.Exec(ctx, `
- INSERT INTO login_identities (
- user_id, identity_type, provider, provider_subject, country_code, mobile, status, profile_jsonb
- )
- VALUES ($1, $2, $3, $4, $5, $6, 'active', $7::jsonb)
- ON CONFLICT (provider, provider_subject) DO NOTHING
- `, params.UserID, params.IdentityType, params.Provider, params.ProviderSubj, params.CountryCode, params.Mobile, zeroJSON(params.ProfileJSON))
- if err != nil {
- return fmt.Errorf("create identity: %w", err)
- }
- return nil
- }
- func (s *Store) FindUserByProviderSubject(ctx context.Context, tx Tx, provider, providerSubject string) (*User, error) {
- row := tx.QueryRow(ctx, `
- SELECT u.id, u.user_public_id, u.status, u.nickname, u.avatar_url
- FROM users u
- JOIN login_identities li ON li.user_id = u.id
- WHERE li.provider = $1
- AND li.provider_subject = $2
- AND li.status = 'active'
- LIMIT 1
- `, provider, providerSubject)
- return scanUser(row)
- }
- func (s *Store) CreateRefreshToken(ctx context.Context, tx Tx, params CreateRefreshTokenParams) (string, error) {
- row := tx.QueryRow(ctx, `
- INSERT INTO auth_refresh_tokens (user_id, client_type, device_key, token_hash, expires_at)
- VALUES ($1, $2, NULLIF($3, ''), $4, $5)
- RETURNING id
- `, params.UserID, params.ClientType, params.DeviceKey, params.TokenHash, params.ExpiresAt)
- var id string
- if err := row.Scan(&id); err != nil {
- return "", fmt.Errorf("create refresh token: %w", err)
- }
- return id, nil
- }
- func (s *Store) GetRefreshTokenForUpdate(ctx context.Context, tx Tx, tokenHash string) (*RefreshTokenRecord, error) {
- row := tx.QueryRow(ctx, `
- SELECT id, user_id, client_type, device_key, expires_at, revoked_at IS NOT NULL
- FROM auth_refresh_tokens
- WHERE token_hash = $1
- FOR UPDATE
- `, tokenHash)
- var record RefreshTokenRecord
- err := row.Scan(&record.ID, &record.UserID, &record.ClientType, &record.DeviceKey, &record.ExpiresAt, &record.IsRevoked)
- if errors.Is(err, pgx.ErrNoRows) {
- return nil, nil
- }
- if err != nil {
- return nil, fmt.Errorf("query refresh token for update: %w", err)
- }
- return &record, nil
- }
- func (s *Store) RotateRefreshToken(ctx context.Context, tx Tx, oldTokenID, newTokenID string) error {
- _, err := tx.Exec(ctx, `
- UPDATE auth_refresh_tokens
- SET revoked_at = NOW(), replaced_by_token_id = $2
- WHERE id = $1
- `, oldTokenID, newTokenID)
- if err != nil {
- return fmt.Errorf("rotate refresh token: %w", err)
- }
- return nil
- }
- func (s *Store) RevokeRefreshToken(ctx context.Context, tokenHash string) error {
- commandTag, err := s.pool.Exec(ctx, `
- UPDATE auth_refresh_tokens
- SET revoked_at = COALESCE(revoked_at, NOW())
- WHERE token_hash = $1
- `, tokenHash)
- if err != nil {
- return fmt.Errorf("revoke refresh token: %w", err)
- }
- if commandTag.RowsAffected() == 0 {
- return apperr.New(http.StatusNotFound, "refresh_token_not_found", "refresh token not found")
- }
- return nil
- }
- func (s *Store) RevokeRefreshTokensByUserID(ctx context.Context, tx Tx, userID string) error {
- _, err := tx.Exec(ctx, `
- UPDATE auth_refresh_tokens
- SET revoked_at = COALESCE(revoked_at, NOW())
- WHERE user_id = $1
- `, userID)
- if err != nil {
- return fmt.Errorf("revoke refresh tokens by user id: %w", err)
- }
- return nil
- }
- func (s *Store) TransferNonMobileIdentities(ctx context.Context, tx Tx, sourceUserID, targetUserID string) error {
- if sourceUserID == targetUserID {
- return nil
- }
- _, err := tx.Exec(ctx, `
- INSERT INTO login_identities (
- user_id, identity_type, provider, provider_subject, country_code, mobile, status, profile_jsonb, created_at, updated_at
- )
- SELECT
- $2,
- li.identity_type,
- li.provider,
- li.provider_subject,
- li.country_code,
- li.mobile,
- li.status,
- li.profile_jsonb,
- li.created_at,
- li.updated_at
- FROM login_identities li
- WHERE li.user_id = $1
- AND li.provider <> 'mobile'
- ON CONFLICT (provider, provider_subject) DO NOTHING
- `, sourceUserID, targetUserID)
- if err != nil {
- return fmt.Errorf("copy non-mobile identities: %w", err)
- }
- _, err = tx.Exec(ctx, `
- DELETE FROM login_identities
- WHERE user_id = $1
- AND provider <> 'mobile'
- `, sourceUserID)
- if err != nil {
- return fmt.Errorf("delete source non-mobile identities: %w", err)
- }
- return nil
- }
// zeroJSON returns value unchanged, except that the empty string becomes "{}",
// so the result can be cast to jsonb. Whitespace-only or otherwise non-empty
// input is passed through untouched.
func zeroJSON(value string) string {
	result := value
	if len(result) == 0 {
		result = "{}"
	}
	return result
}
|