client.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. package gateway
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "log/slog"
  7. "net/http"
  8. "strings"
  9. "sync"
  10. "time"
  11. "github.com/coder/websocket"
  12. "github.com/coder/websocket/wsjson"
  13. "realtime-gateway/internal/channel"
  14. "realtime-gateway/internal/config"
  15. "realtime-gateway/internal/model"
  16. "realtime-gateway/internal/plugin"
  17. "realtime-gateway/internal/router"
  18. "realtime-gateway/internal/session"
  19. )
  20. type client struct {
  21. conn *websocket.Conn
  22. logger *slog.Logger
  23. cfg config.GatewayConfig
  24. hub *router.Hub
  25. channels *channel.Manager
  26. plugins *plugin.Bus
  27. session *session.Session
  28. auth config.AuthConfig
  29. writeMu sync.Mutex
  30. }
  31. func serveClient(
  32. w http.ResponseWriter,
  33. r *http.Request,
  34. logger *slog.Logger,
  35. cfg config.Config,
  36. hub *router.Hub,
  37. channels *channel.Manager,
  38. plugins *plugin.Bus,
  39. sessions *session.Manager,
  40. ) {
  41. conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
  42. InsecureSkipVerify: true,
  43. })
  44. if err != nil {
  45. logger.Error("websocket accept failed", "error", err)
  46. return
  47. }
  48. sess := sessions.Create()
  49. c := &client{
  50. conn: conn,
  51. logger: logger.With("sessionId", sess.ID),
  52. cfg: cfg.Gateway,
  53. hub: hub,
  54. channels: channels,
  55. plugins: plugins,
  56. session: sess,
  57. auth: cfg.Auth,
  58. }
  59. hub.Register(c, nil)
  60. defer func() {
  61. if sess.ChannelID != "" {
  62. channels.Unbind(sess.ChannelID, sess.Role)
  63. }
  64. hub.Unregister(sess.ID)
  65. sessions.Delete(sess.ID)
  66. _ = conn.Close(websocket.StatusNormalClosure, "session closed")
  67. }()
  68. if err := c.run(r.Context()); err != nil && !errors.Is(err, context.Canceled) {
  69. c.logger.Warn("client closed", "error", err)
  70. }
  71. }
  72. func (c *client) ID() string {
  73. return c.session.ID
  74. }
  75. func (c *client) Send(message model.ServerMessage) error {
  76. c.writeMu.Lock()
  77. defer c.writeMu.Unlock()
  78. ctx, cancel := context.WithTimeout(context.Background(), c.cfg.WriteWait())
  79. defer cancel()
  80. return wsjson.Write(ctx, c.conn, message)
  81. }
  82. func (c *client) run(ctx context.Context) error {
  83. if err := c.Send(model.ServerMessage{
  84. Type: "welcome",
  85. SessionID: c.session.ID,
  86. }); err != nil {
  87. return err
  88. }
  89. pingCtx, cancelPing := context.WithCancel(ctx)
  90. defer cancelPing()
  91. go c.pingLoop(pingCtx)
  92. for {
  93. readCtx, cancel := context.WithTimeout(ctx, c.cfg.PongWait())
  94. var message model.ClientMessage
  95. err := wsjson.Read(readCtx, c.conn, &message)
  96. cancel()
  97. if err != nil {
  98. return err
  99. }
  100. if err := c.handleMessage(message); err != nil {
  101. _ = c.Send(model.ServerMessage{
  102. Type: "error",
  103. Error: err.Error(),
  104. })
  105. }
  106. }
  107. }
  108. func (c *client) handleMessage(message model.ClientMessage) error {
  109. switch message.Type {
  110. case "authenticate":
  111. return c.handleAuthenticate(message)
  112. case "join_channel":
  113. return c.handleJoinChannel(message)
  114. case "subscribe":
  115. return c.handleSubscribe(message)
  116. case "publish":
  117. return c.handlePublish(message)
  118. case "snapshot":
  119. return c.handleSnapshot(message)
  120. default:
  121. return errors.New("unsupported message type")
  122. }
  123. }
  124. func (c *client) handleJoinChannel(message model.ClientMessage) error {
  125. if strings.TrimSpace(message.ChannelID) == "" {
  126. return errors.New("channelId is required")
  127. }
  128. snapshot, err := c.channels.Join(message.ChannelID, message.Token, message.Role)
  129. if err != nil {
  130. return err
  131. }
  132. if c.session.ChannelID != "" {
  133. c.channels.Unbind(c.session.ChannelID, c.session.Role)
  134. }
  135. if err := c.channels.Bind(snapshot.ID, message.Role); err != nil {
  136. return err
  137. }
  138. c.session.Role = message.Role
  139. c.session.Authenticated = true
  140. c.session.ChannelID = snapshot.ID
  141. c.session.Subscriptions = nil
  142. c.hub.UpdateSubscriptions(c.session.ID, nil)
  143. return c.Send(model.ServerMessage{
  144. Type: "joined_channel",
  145. SessionID: c.session.ID,
  146. State: json.RawMessage([]byte(
  147. `{"channelId":"` + snapshot.ID + `","deliveryMode":"` + snapshot.DeliveryMode + `"}`,
  148. )),
  149. })
  150. }
  151. func (c *client) handleAuthenticate(message model.ClientMessage) error {
  152. if !authorize(c.auth, message.Role, message.Token) {
  153. return errors.New("authentication failed")
  154. }
  155. c.session.Role = message.Role
  156. c.session.Authenticated = true
  157. return c.Send(model.ServerMessage{
  158. Type: "authenticated",
  159. SessionID: c.session.ID,
  160. })
  161. }
  162. func (c *client) handleSubscribe(message model.ClientMessage) error {
  163. if !c.session.Authenticated && !c.auth.AllowAnonymousConsumers {
  164. return errors.New("consumer must authenticate before subscribe")
  165. }
  166. subscriptions := normalizeSubscriptions(c.session.ChannelID, message.Subscriptions)
  167. c.session.Subscriptions = subscriptions
  168. c.hub.UpdateSubscriptions(c.session.ID, subscriptions)
  169. return c.Send(model.ServerMessage{
  170. Type: "subscribed",
  171. SessionID: c.session.ID,
  172. })
  173. }
  174. func (c *client) handlePublish(message model.ClientMessage) error {
  175. if !c.session.Authenticated {
  176. return errors.New("authentication required")
  177. }
  178. if c.session.Role != model.RoleProducer && c.session.Role != model.RoleController {
  179. return errors.New("publish is only allowed for producer or controller")
  180. }
  181. if message.Envelope == nil {
  182. return errors.New("envelope is required")
  183. }
  184. envelope := *message.Envelope
  185. if envelope.Source.Kind == "" {
  186. envelope.Source.Kind = c.session.Role
  187. }
  188. if c.session.ChannelID != "" {
  189. envelope.Target.ChannelID = c.session.ChannelID
  190. }
  191. deliveryMode := channel.DeliveryModeCacheLatest
  192. if envelope.Target.ChannelID != "" {
  193. deliveryMode = c.channels.DeliveryMode(envelope.Target.ChannelID)
  194. }
  195. result := c.hub.Publish(envelope, deliveryMode)
  196. if !result.Dropped {
  197. c.plugins.Publish(envelope)
  198. }
  199. return c.Send(model.ServerMessage{
  200. Type: "published",
  201. SessionID: c.session.ID,
  202. })
  203. }
  204. func (c *client) handleSnapshot(message model.ClientMessage) error {
  205. if len(message.Subscriptions) == 0 || message.Subscriptions[0].DeviceID == "" {
  206. return errors.New("snapshot requires deviceId in first subscription")
  207. }
  208. channelID := message.Subscriptions[0].ChannelID
  209. if channelID == "" {
  210. channelID = c.session.ChannelID
  211. }
  212. state, ok := c.hub.Snapshot(channelID, message.Subscriptions[0].DeviceID)
  213. if !ok {
  214. return errors.New("snapshot not found")
  215. }
  216. return c.Send(model.ServerMessage{
  217. Type: "snapshot",
  218. SessionID: c.session.ID,
  219. State: json.RawMessage(state),
  220. })
  221. }
  222. func (c *client) pingLoop(ctx context.Context) {
  223. ticker := time.NewTicker(c.cfg.PingInterval())
  224. defer ticker.Stop()
  225. for {
  226. select {
  227. case <-ctx.Done():
  228. return
  229. case <-ticker.C:
  230. pingCtx, cancel := context.WithTimeout(ctx, c.cfg.WriteWait())
  231. _ = c.conn.Ping(pingCtx)
  232. cancel()
  233. }
  234. }
  235. }
  236. func normalizeSubscriptions(channelID string, subscriptions []model.Subscription) []model.Subscription {
  237. items := make([]model.Subscription, 0, len(subscriptions))
  238. for _, entry := range subscriptions {
  239. if channelID != "" && strings.TrimSpace(entry.ChannelID) == "" {
  240. entry.ChannelID = channelID
  241. }
  242. items = append(items, entry)
  243. }
  244. return items
  245. }