session_store.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. package postgres
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "time"
  7. "github.com/jackc/pgx/v5"
  8. )
  // Session is a single game_sessions row, enriched by the queries in this
// file with columns from the related event_releases and events rows.
//
// Pointer-typed fields receive nil when the corresponding column is NULL.
// Field order follows the column order of the session SELECT lists in this
// file.
type Session struct {
	// Identifiers from game_sessions.
	ID, SessionPublicID     string
	UserID                  string
	EventID, EventReleaseID string

	// From event_releases: release_public_id, config_label, manifest_url and
	// manifest_checksum_sha256 (stored in ManifestChecksum).
	ReleasePublicID, ConfigLabel, ManifestURL, ManifestChecksum *string

	// Client and routing details from game_sessions.
	DeviceKey, ClientType string
	RouteCode             *string

	// Lifecycle state. StartSession moves Status from 'launched' to
	// 'running'; FinishSession overwrites it with the caller's value.
	Status                string
	SessionTokenHash      string
	SessionTokenExpiresAt time.Time
	LaunchedAt            time.Time
	StartedAt, EndedAt    *time.Time

	// From events: event_public_id and display_name.
	EventPublicID, EventDisplayName *string
}
  // FinishSessionParams identifies the session to finish and the status value
// that FinishSession writes verbatim into game_sessions.status.
type FinishSessionParams struct {
	// SessionID matches game_sessions.id (the internal id, not the public id);
	// Status is the value stored as the session's new status.
	SessionID, Status string
}
  35. func (s *Store) GetSessionByPublicID(ctx context.Context, sessionPublicID string) (*Session, error) {
  36. row := s.pool.QueryRow(ctx, `
  37. SELECT
  38. gs.id,
  39. gs.session_public_id,
  40. gs.user_id,
  41. gs.event_id,
  42. gs.event_release_id,
  43. er.release_public_id,
  44. er.config_label,
  45. er.manifest_url,
  46. er.manifest_checksum_sha256,
  47. gs.device_key,
  48. gs.client_type,
  49. gs.route_code,
  50. gs.status,
  51. gs.session_token_hash,
  52. gs.session_token_expires_at,
  53. gs.launched_at,
  54. gs.started_at,
  55. gs.ended_at,
  56. e.event_public_id,
  57. e.display_name
  58. FROM game_sessions gs
  59. JOIN events e ON e.id = gs.event_id
  60. JOIN event_releases er ON er.id = gs.event_release_id
  61. WHERE gs.session_public_id = $1
  62. LIMIT 1
  63. `, sessionPublicID)
  64. return scanSession(row)
  65. }
  66. func (s *Store) GetSessionByPublicIDForUpdate(ctx context.Context, tx Tx, sessionPublicID string) (*Session, error) {
  67. row := tx.QueryRow(ctx, `
  68. SELECT
  69. gs.id,
  70. gs.session_public_id,
  71. gs.user_id,
  72. gs.event_id,
  73. gs.event_release_id,
  74. er.release_public_id,
  75. er.config_label,
  76. er.manifest_url,
  77. er.manifest_checksum_sha256,
  78. gs.device_key,
  79. gs.client_type,
  80. gs.route_code,
  81. gs.status,
  82. gs.session_token_hash,
  83. gs.session_token_expires_at,
  84. gs.launched_at,
  85. gs.started_at,
  86. gs.ended_at,
  87. e.event_public_id,
  88. e.display_name
  89. FROM game_sessions gs
  90. JOIN events e ON e.id = gs.event_id
  91. JOIN event_releases er ON er.id = gs.event_release_id
  92. WHERE gs.session_public_id = $1
  93. FOR UPDATE
  94. `, sessionPublicID)
  95. return scanSession(row)
  96. }
  97. func (s *Store) ListSessionsByUserID(ctx context.Context, userID string, limit int) ([]Session, error) {
  98. if limit <= 0 || limit > 100 {
  99. limit = 20
  100. }
  101. rows, err := s.pool.Query(ctx, `
  102. SELECT
  103. gs.id,
  104. gs.session_public_id,
  105. gs.user_id,
  106. gs.event_id,
  107. gs.event_release_id,
  108. er.release_public_id,
  109. er.config_label,
  110. er.manifest_url,
  111. er.manifest_checksum_sha256,
  112. gs.device_key,
  113. gs.client_type,
  114. gs.route_code,
  115. gs.status,
  116. gs.session_token_hash,
  117. gs.session_token_expires_at,
  118. gs.launched_at,
  119. gs.started_at,
  120. gs.ended_at,
  121. e.event_public_id,
  122. e.display_name
  123. FROM game_sessions gs
  124. JOIN events e ON e.id = gs.event_id
  125. JOIN event_releases er ON er.id = gs.event_release_id
  126. WHERE gs.user_id = $1
  127. ORDER BY gs.created_at DESC
  128. LIMIT $2
  129. `, userID, limit)
  130. if err != nil {
  131. return nil, fmt.Errorf("list sessions by user id: %w", err)
  132. }
  133. defer rows.Close()
  134. var sessions []Session
  135. for rows.Next() {
  136. session, err := scanSessionFromRows(rows)
  137. if err != nil {
  138. return nil, err
  139. }
  140. sessions = append(sessions, *session)
  141. }
  142. if err := rows.Err(); err != nil {
  143. return nil, fmt.Errorf("iterate sessions by user id: %w", err)
  144. }
  145. return sessions, nil
  146. }
  147. func (s *Store) ListSessionsByUserAndEvent(ctx context.Context, userID, eventID string, limit int) ([]Session, error) {
  148. if limit <= 0 || limit > 100 {
  149. limit = 20
  150. }
  151. rows, err := s.pool.Query(ctx, `
  152. SELECT
  153. gs.id,
  154. gs.session_public_id,
  155. gs.user_id,
  156. gs.event_id,
  157. gs.event_release_id,
  158. er.release_public_id,
  159. er.config_label,
  160. er.manifest_url,
  161. er.manifest_checksum_sha256,
  162. gs.device_key,
  163. gs.client_type,
  164. gs.route_code,
  165. gs.status,
  166. gs.session_token_hash,
  167. gs.session_token_expires_at,
  168. gs.launched_at,
  169. gs.started_at,
  170. gs.ended_at,
  171. e.event_public_id,
  172. e.display_name
  173. FROM game_sessions gs
  174. JOIN events e ON e.id = gs.event_id
  175. JOIN event_releases er ON er.id = gs.event_release_id
  176. WHERE gs.user_id = $1
  177. AND gs.event_id = $2
  178. ORDER BY gs.created_at DESC
  179. LIMIT $3
  180. `, userID, eventID, limit)
  181. if err != nil {
  182. return nil, fmt.Errorf("list sessions by user and event: %w", err)
  183. }
  184. defer rows.Close()
  185. var sessions []Session
  186. for rows.Next() {
  187. session, err := scanSessionFromRows(rows)
  188. if err != nil {
  189. return nil, err
  190. }
  191. sessions = append(sessions, *session)
  192. }
  193. if err := rows.Err(); err != nil {
  194. return nil, fmt.Errorf("iterate sessions by user and event: %w", err)
  195. }
  196. return sessions, nil
  197. }
  198. func (s *Store) StartSession(ctx context.Context, tx Tx, sessionID string) error {
  199. _, err := tx.Exec(ctx, `
  200. UPDATE game_sessions
  201. SET status = CASE WHEN status = 'launched' THEN 'running' ELSE status END,
  202. started_at = COALESCE(started_at, NOW())
  203. WHERE id = $1
  204. `, sessionID)
  205. if err != nil {
  206. return fmt.Errorf("start session: %w", err)
  207. }
  208. return nil
  209. }
  210. func (s *Store) FinishSession(ctx context.Context, tx Tx, params FinishSessionParams) error {
  211. _, err := tx.Exec(ctx, `
  212. UPDATE game_sessions
  213. SET status = $2,
  214. started_at = COALESCE(started_at, NOW()),
  215. ended_at = COALESCE(ended_at, NOW())
  216. WHERE id = $1
  217. `, params.SessionID, params.Status)
  218. if err != nil {
  219. return fmt.Errorf("finish session: %w", err)
  220. }
  221. return nil
  222. }
  223. func scanSession(row pgx.Row) (*Session, error) {
  224. var session Session
  225. err := row.Scan(
  226. &session.ID,
  227. &session.SessionPublicID,
  228. &session.UserID,
  229. &session.EventID,
  230. &session.EventReleaseID,
  231. &session.ReleasePublicID,
  232. &session.ConfigLabel,
  233. &session.ManifestURL,
  234. &session.ManifestChecksum,
  235. &session.DeviceKey,
  236. &session.ClientType,
  237. &session.RouteCode,
  238. &session.Status,
  239. &session.SessionTokenHash,
  240. &session.SessionTokenExpiresAt,
  241. &session.LaunchedAt,
  242. &session.StartedAt,
  243. &session.EndedAt,
  244. &session.EventPublicID,
  245. &session.EventDisplayName,
  246. )
  247. if errors.Is(err, pgx.ErrNoRows) {
  248. return nil, nil
  249. }
  250. if err != nil {
  251. return nil, fmt.Errorf("scan session: %w", err)
  252. }
  253. return &session, nil
  254. }
  255. func scanSessionFromRows(rows pgx.Rows) (*Session, error) {
  256. var session Session
  257. err := rows.Scan(
  258. &session.ID,
  259. &session.SessionPublicID,
  260. &session.UserID,
  261. &session.EventID,
  262. &session.EventReleaseID,
  263. &session.ReleasePublicID,
  264. &session.ConfigLabel,
  265. &session.ManifestURL,
  266. &session.ManifestChecksum,
  267. &session.DeviceKey,
  268. &session.ClientType,
  269. &session.RouteCode,
  270. &session.Status,
  271. &session.SessionTokenHash,
  272. &session.SessionTokenExpiresAt,
  273. &session.LaunchedAt,
  274. &session.StartedAt,
  275. &session.EndedAt,
  276. &session.EventPublicID,
  277. &session.EventDisplayName,
  278. )
  279. if err != nil {
  280. return nil, fmt.Errorf("scan session row: %w", err)
  281. }
  282. return &session, nil
  283. }