package postgres import ( "context" "encoding/json" "errors" "fmt" "net/http" "time" "cmr-backend/internal/apperr" "github.com/jackc/pgx/v5" ) type SMSCodeMeta struct { ID string CodeHash string ExpiresAt time.Time CooldownUntil time.Time } type CreateSMSCodeParams struct { Scene string CountryCode string Mobile string ClientType string DeviceKey string CodeHash string ProviderName string ProviderDebug map[string]any ExpiresAt time.Time CooldownUntil time.Time } type CreateMobileIdentityParams struct { UserID string IdentityType string Provider string ProviderSubj string CountryCode string Mobile string } type CreateIdentityParams struct { UserID string IdentityType string Provider string ProviderSubj string CountryCode *string Mobile *string ProfileJSON string } type CreateRefreshTokenParams struct { UserID string ClientType string DeviceKey string TokenHash string ExpiresAt time.Time } type RefreshTokenRecord struct { ID string UserID string ClientType string DeviceKey *string ExpiresAt time.Time IsRevoked bool } func (s *Store) GetLatestSMSCodeMeta(ctx context.Context, countryCode, mobile, clientType, scene string) (*SMSCodeMeta, error) { row := s.pool.QueryRow(ctx, ` SELECT id, code_hash, expires_at, cooldown_until FROM auth_sms_codes WHERE country_code = $1 AND mobile = $2 AND client_type = $3 AND scene = $4 ORDER BY created_at DESC LIMIT 1 `, countryCode, mobile, clientType, scene) var record SMSCodeMeta err := row.Scan(&record.ID, &record.CodeHash, &record.ExpiresAt, &record.CooldownUntil) if errors.Is(err, pgx.ErrNoRows) { return nil, nil } if err != nil { return nil, fmt.Errorf("query latest sms code meta: %w", err) } return &record, nil } func (s *Store) CreateSMSCode(ctx context.Context, params CreateSMSCodeParams) error { payload, err := json.Marshal(map[string]any{ "provider": params.ProviderName, "debug": params.ProviderDebug, }) if err != nil { return err } _, err = s.pool.Exec(ctx, ` INSERT INTO auth_sms_codes ( scene, country_code, mobile, client_type, device_key, code_hash, provider_payload_jsonb, expires_at, cooldown_until ) VALUES ($1, $2, $3, $4, $5, $6, $7::jsonb, $8, $9) `, params.Scene, params.CountryCode, params.Mobile, params.ClientType, params.DeviceKey, params.CodeHash, string(payload), params.ExpiresAt, params.CooldownUntil) if err != nil { return fmt.Errorf("insert sms code: %w", err) } return nil } func (s *Store) GetLatestValidSMSCode(ctx context.Context, countryCode, mobile, clientType, scene string) (*SMSCodeMeta, error) { row := s.pool.QueryRow(ctx, ` SELECT id, code_hash, expires_at, cooldown_until FROM auth_sms_codes WHERE country_code = $1 AND mobile = $2 AND client_type = $3 AND scene = $4 AND consumed_at IS NULL AND expires_at > NOW() ORDER BY created_at DESC LIMIT 1 `, countryCode, mobile, clientType, scene) var record SMSCodeMeta err := row.Scan(&record.ID, &record.CodeHash, &record.ExpiresAt, &record.CooldownUntil) if errors.Is(err, pgx.ErrNoRows) { return nil, nil } if err != nil { return nil, fmt.Errorf("query latest valid sms code: %w", err) } return &record, nil } func (s *Store) ConsumeSMSCode(ctx context.Context, tx Tx, id string) (bool, error) { commandTag, err := tx.Exec(ctx, ` UPDATE auth_sms_codes SET consumed_at = NOW() WHERE id = $1 AND consumed_at IS NULL `, id) if err != nil { return false, fmt.Errorf("consume sms code: %w", err) } return commandTag.RowsAffected() == 1, nil } func (s *Store) CreateMobileIdentity(ctx context.Context, tx Tx, params CreateMobileIdentityParams) error { countryCode := params.CountryCode mobile := params.Mobile return s.CreateIdentity(ctx, tx, CreateIdentityParams{ UserID: params.UserID, IdentityType: params.IdentityType, Provider: params.Provider, ProviderSubj: params.ProviderSubj, CountryCode: &countryCode, Mobile: &mobile, ProfileJSON: "{}", }) } func (s *Store) CreateIdentity(ctx context.Context, tx Tx, params CreateIdentityParams) error { _, err := tx.Exec(ctx, ` INSERT INTO login_identities ( user_id, identity_type, provider, provider_subject, country_code, mobile, status, profile_jsonb ) VALUES ($1, $2, $3, $4, $5, $6, 'active', $7::jsonb) ON CONFLICT (provider, provider_subject) DO NOTHING `, params.UserID, params.IdentityType, params.Provider, params.ProviderSubj, params.CountryCode, params.Mobile, zeroJSON(params.ProfileJSON)) if err != nil { return fmt.Errorf("create identity: %w", err) } return nil } func (s *Store) FindUserByProviderSubject(ctx context.Context, tx Tx, provider, providerSubject string) (*User, error) { row := tx.QueryRow(ctx, ` SELECT u.id, u.user_public_id, u.status, u.nickname, u.avatar_url FROM users u JOIN login_identities li ON li.user_id = u.id WHERE li.provider = $1 AND li.provider_subject = $2 AND li.status = 'active' LIMIT 1 `, provider, providerSubject) return scanUser(row) } func (s *Store) CreateRefreshToken(ctx context.Context, tx Tx, params CreateRefreshTokenParams) (string, error) { row := tx.QueryRow(ctx, ` INSERT INTO auth_refresh_tokens (user_id, client_type, device_key, token_hash, expires_at) VALUES ($1, $2, NULLIF($3, ''), $4, $5) RETURNING id `, params.UserID, params.ClientType, params.DeviceKey, params.TokenHash, params.ExpiresAt) var id string if err := row.Scan(&id); err != nil { return "", fmt.Errorf("create refresh token: %w", err) } return id, nil } func (s *Store) GetRefreshTokenForUpdate(ctx context.Context, tx Tx, tokenHash string) (*RefreshTokenRecord, error) { row := tx.QueryRow(ctx, ` SELECT id, user_id, client_type, device_key, expires_at, revoked_at IS NOT NULL FROM auth_refresh_tokens WHERE token_hash = $1 FOR UPDATE `, tokenHash) var record RefreshTokenRecord err := row.Scan(&record.ID, &record.UserID, &record.ClientType, &record.DeviceKey, &record.ExpiresAt, &record.IsRevoked) if errors.Is(err, pgx.ErrNoRows) { return nil, nil } if err != nil { return nil, fmt.Errorf("query refresh token for update: %w", err) } return &record, nil } func (s *Store) RotateRefreshToken(ctx context.Context, tx Tx, oldTokenID, newTokenID string) error { _, err := tx.Exec(ctx, ` UPDATE auth_refresh_tokens SET revoked_at = NOW(), replaced_by_token_id = $2 WHERE id = $1 `, oldTokenID, newTokenID) if err != nil { return fmt.Errorf("rotate refresh token: %w", err) } return nil } func (s *Store) RevokeRefreshToken(ctx context.Context, tokenHash string) error { commandTag, err := s.pool.Exec(ctx, ` UPDATE auth_refresh_tokens SET revoked_at = COALESCE(revoked_at, NOW()) WHERE token_hash = $1 `, tokenHash) if err != nil { return fmt.Errorf("revoke refresh token: %w", err) } if commandTag.RowsAffected() == 0 { return apperr.New(http.StatusNotFound, "refresh_token_not_found", "refresh token not found") } return nil } func (s *Store) RevokeRefreshTokensByUserID(ctx context.Context, tx Tx, userID string) error { _, err := tx.Exec(ctx, ` UPDATE auth_refresh_tokens SET revoked_at = COALESCE(revoked_at, NOW()) WHERE user_id = $1 `, userID) if err != nil { return fmt.Errorf("revoke refresh tokens by user id: %w", err) } return nil } func (s *Store) TransferNonMobileIdentities(ctx context.Context, tx Tx, sourceUserID, targetUserID string) error { if sourceUserID == targetUserID { return nil } _, err := tx.Exec(ctx, ` INSERT INTO login_identities ( user_id, identity_type, provider, provider_subject, country_code, mobile, status, profile_jsonb, created_at, updated_at ) SELECT $2, li.identity_type, li.provider, li.provider_subject, li.country_code, li.mobile, li.status, li.profile_jsonb, li.created_at, li.updated_at FROM login_identities li WHERE li.user_id = $1 AND li.provider <> 'mobile' ON CONFLICT (provider, provider_subject) DO NOTHING `, sourceUserID, targetUserID) if err != nil { return fmt.Errorf("copy non-mobile identities: %w", err) } _, err = tx.Exec(ctx, ` DELETE FROM login_identities WHERE user_id = $1 AND provider <> 'mobile' `, sourceUserID) if err != nil { return fmt.Errorf("delete source non-mobile identities: %w", err) } return nil } func zeroJSON(value string) string { if value == "" { return "{}" } return value }