package gateway

import (
	"context"
	"encoding/json"
	"errors"
	"log/slog"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/coder/websocket"
	"github.com/coder/websocket/wsjson"

	"realtime-gateway/internal/channel"
	"realtime-gateway/internal/config"
	"realtime-gateway/internal/model"
	"realtime-gateway/internal/plugin"
	"realtime-gateway/internal/router"
	"realtime-gateway/internal/session"
)

// client holds the per-connection state for one WebSocket peer. The message
// handlers run on the read loop in run; Send serializes writes with writeMu.
// Send may also be called by the hub (which receives the client in
// serveClient), presumably from other goroutines — confirm in router.Hub.
type client struct {
	conn     *websocket.Conn
	logger   *slog.Logger
	cfg      config.GatewayConfig
	hub      *router.Hub
	channels *channel.Manager
	plugins  *plugin.Bus
	session  *session.Session
	auth     config.AuthConfig
	writeMu  sync.Mutex

	// release undoes the current channel binding using exactly the channel
	// ID and role that were passed to Bind. It is nil while the client is
	// not bound. Capturing the role here keeps Unbind symmetric with Bind
	// even if a later "authenticate" message changes session.Role.
	release func()
}

// serveClient upgrades the request to a WebSocket and registers the resulting
// client with the hub. It then serves the client until the connection or the
// request context ends. On return it releases any channel binding,
// unregisters from the hub, deletes the session and closes the socket.
func serveClient(
	w http.ResponseWriter,
	r *http.Request,
	logger *slog.Logger,
	cfg config.Config,
	hub *router.Hub,
	channels *channel.Manager,
	plugins *plugin.Bus,
	sessions *session.Manager,
) {
	conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
		// NOTE(review): this disables the Origin check, so any web page can
		// open a connection using the visitor's browser. Consider
		// configuration-driven OriginPatterns instead.
		InsecureSkipVerify: true,
	})
	if err != nil {
		logger.Error("websocket accept failed", "error", err)
		return
	}

	sess := sessions.Create()
	c := &client{
		conn:     conn,
		logger:   logger.With("sessionId", sess.ID),
		cfg:      cfg.Gateway,
		hub:      hub,
		channels: channels,
		plugins:  plugins,
		session:  sess,
		auth:     cfg.Auth,
	}

	hub.Register(c, nil)
	defer func() {
		c.releaseChannel()
		hub.Unregister(sess.ID)
		sessions.Delete(sess.ID)
		// Best effort: the peer may already be gone.
		_ = conn.Close(websocket.StatusNormalClosure, "session closed")
	}()

	if err := c.run(r.Context()); err != nil && !errors.Is(err, context.Canceled) {
		c.logger.Warn("client closed", "error", err)
	}
}

// ID returns the session ID that identifies this client to the hub.
func (c *client) ID() string {
	return c.session.ID
}

// Send writes message to the peer as JSON. Each write is bounded by the
// configured WriteWait and serialized with writeMu.
func (c *client) Send(message model.ServerMessage) error {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()

	ctx, cancel := context.WithTimeout(context.Background(), c.cfg.WriteWait())
	defer cancel()
	return wsjson.Write(ctx, c.conn, message)
}

// run sends the welcome message and starts the keepalive pinger. It then reads
// and dispatches client messages until a read fails (connection closed, ctx
// done, or an undecodable message).
//
// Reads deliberately carry no per-message deadline. Pong control frames never
// complete a data read, so a deadline here would disconnect healthy clients
// that simply have nothing to send. Liveness is enforced by pingLoop, which
// closes the connection when a pong does not arrive in time.
func (c *client) run(ctx context.Context) error {
	if err := c.Send(model.ServerMessage{
		Type:      "welcome",
		SessionID: c.session.ID,
	}); err != nil {
		return err
	}

	pingCtx, cancelPing := context.WithCancel(ctx)
	defer cancelPing()
	go c.pingLoop(pingCtx)

	for {
		var message model.ClientMessage
		if err := wsjson.Read(ctx, c.conn, &message); err != nil {
			return err
		}
		if err := c.handleMessage(message); err != nil {
			// Handler errors are reported to the peer but do not end the
			// session; the send itself is best effort.
			_ = c.Send(model.ServerMessage{
				Type:  "error",
				Error: err.Error(),
			})
		}
	}
}

// handleMessage dispatches a client message to the handler for its Type.
// It returns an error for unknown types.
func (c *client) handleMessage(message model.ClientMessage) error {
	switch message.Type {
	case "authenticate":
		return c.handleAuthenticate(message)
	case "join_channel":
		return c.handleJoinChannel(message)
	case "subscribe":
		return c.handleSubscribe(message)
	case "publish":
		return c.handlePublish(message)
	case "snapshot":
		return c.handleSnapshot(message)
	default:
		return errors.New("unsupported message type")
	}
}

// releaseChannel undoes the current channel binding, if any. It is safe to
// call when the client is not bound.
func (c *client) releaseChannel() {
	if c.release == nil {
		return
	}
	c.release()
	c.release = nil
}

// handleJoinChannel validates the join with the channel manager (channel ID,
// token, role). It moves the client's binding from any previous channel to
// the joined one, marks the session authenticated with the requested role and
// clears its subscriptions. It then replies with "joined_channel", whose state
// carries the channel ID and delivery mode.
//
// If Bind fails, the client is left unbound (the previous channel has already
// been released) and the session's channel ID is cleared.
func (c *client) handleJoinChannel(message model.ClientMessage) error {
	if strings.TrimSpace(message.ChannelID) == "" {
		return errors.New("channelId is required")
	}
	snapshot, err := c.channels.Join(message.ChannelID, message.Token, message.Role)
	if err != nil {
		return err
	}

	// Leave the previous channel before binding the new one, and clear the
	// channel ID right away. Otherwise a failed Bind would leave the session
	// pointing at a channel it is no longer bound to.
	c.releaseChannel()
	c.session.ChannelID = ""

	if err := c.channels.Bind(snapshot.ID, message.Role); err != nil {
		return err
	}
	boundID, boundRole := snapshot.ID, message.Role
	c.release = func() { c.channels.Unbind(boundID, boundRole) }

	c.session.Role = message.Role
	c.session.Authenticated = true
	c.session.ChannelID = snapshot.ID
	c.session.Subscriptions = nil
	c.hub.UpdateSubscriptions(c.session.ID, nil)

	// Encode with encoding/json so IDs containing quotes or backslashes
	// cannot produce malformed or injected JSON.
	state, err := json.Marshal(struct {
		ChannelID    string `json:"channelId"`
		DeliveryMode string `json:"deliveryMode"`
	}{
		ChannelID:    string(snapshot.ID),
		DeliveryMode: string(snapshot.DeliveryMode),
	})
	if err != nil {
		return err
	}
	return c.Send(model.ServerMessage{
		Type:      "joined_channel",
		SessionID: c.session.ID,
		State:     json.RawMessage(state),
	})
}

// handleAuthenticate checks the role/token pair against the gateway auth
// config. On success it records the role, marks the session authenticated and
// replies with "authenticated". Any existing channel binding keeps the role it
// was bound with.
func (c *client) handleAuthenticate(message model.ClientMessage) error {
	if !authorize(c.auth, message.Role, message.Token) {
		return errors.New("authentication failed")
	}
	c.session.Role = message.Role
	c.session.Authenticated = true
	return c.Send(model.ServerMessage{
		Type:      "authenticated",
		SessionID: c.session.ID,
	})
}

// handleSubscribe replaces the session's subscriptions and pushes them to the
// hub. Entries without a channel ID default to the session's current channel.
// Unauthenticated sessions may subscribe only when anonymous consumers are
// allowed.
func (c *client) handleSubscribe(message model.ClientMessage) error {
	if !c.session.Authenticated && !c.auth.AllowAnonymousConsumers {
		return errors.New("consumer must authenticate before subscribe")
	}
	subscriptions := normalizeSubscriptions(c.session.ChannelID, message.Subscriptions)
	c.session.Subscriptions = subscriptions
	c.hub.UpdateSubscriptions(c.session.ID, subscriptions)
	return c.Send(model.ServerMessage{
		Type:      "subscribed",
		SessionID: c.session.ID,
	})
}

// handlePublish publishes the message's envelope through the hub. Only
// authenticated producers or controllers may publish.
//
// An empty Source.Kind defaults to the session role. When the session is
// bound to a channel, the envelope's target channel is forced to it. The
// delivery mode comes from the target channel, or is cache-latest when no
// channel is set. The envelope is forwarded to the plugin bus only if the hub
// did not drop it.
func (c *client) handlePublish(message model.ClientMessage) error {
	if !c.session.Authenticated {
		return errors.New("authentication required")
	}
	if c.session.Role != model.RoleProducer && c.session.Role != model.RoleController {
		return errors.New("publish is only allowed for producer or controller")
	}
	if message.Envelope == nil {
		return errors.New("envelope is required")
	}

	envelope := *message.Envelope
	if envelope.Source.Kind == "" {
		envelope.Source.Kind = c.session.Role
	}
	if c.session.ChannelID != "" {
		envelope.Target.ChannelID = c.session.ChannelID
	}

	deliveryMode := channel.DeliveryModeCacheLatest
	if envelope.Target.ChannelID != "" {
		deliveryMode = c.channels.DeliveryMode(envelope.Target.ChannelID)
	}

	result := c.hub.Publish(envelope, deliveryMode)
	if !result.Dropped {
		c.plugins.Publish(envelope)
	}
	return c.Send(model.ServerMessage{
		Type:      "published",
		SessionID: c.session.ID,
	})
}

// handleSnapshot replies with the hub's stored state for the device named in
// the first subscription entry. That entry's channel ID is used if set,
// otherwise the session's current channel.
func (c *client) handleSnapshot(message model.ClientMessage) error {
	if len(message.Subscriptions) == 0 || message.Subscriptions[0].DeviceID == "" {
		return errors.New("snapshot requires deviceId in first subscription")
	}
	first := message.Subscriptions[0]
	channelID := first.ChannelID
	if channelID == "" {
		channelID = c.session.ChannelID
	}
	state, ok := c.hub.Snapshot(channelID, first.DeviceID)
	if !ok {
		return errors.New("snapshot not found")
	}
	return c.Send(model.ServerMessage{
		Type:      "snapshot",
		SessionID: c.session.ID,
		State:     json.RawMessage(state),
	})
}

// pingLoop pings the peer every PingInterval until ctx is done. If a ping
// fails (no pong within PongWait, or the ping cannot be written) while ctx is
// still live, the peer is treated as dead. The connection is then closed
// without a handshake, which unblocks the reader in run.
func (c *client) pingLoop(ctx context.Context) {
	ticker := time.NewTicker(c.cfg.PingInterval())
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			pingCtx, cancel := context.WithTimeout(ctx, c.cfg.PongWait())
			err := c.conn.Ping(pingCtx)
			cancel()
			if err == nil {
				continue
			}
			if ctx.Err() != nil {
				// Session is shutting down; the ping failure is expected.
				return
			}
			c.logger.Warn("ping failed; closing connection", "error", err)
			// Best effort: the read loop will observe the closed connection.
			_ = c.conn.CloseNow()
			return
		}
	}
}

// normalizeSubscriptions returns a copy of subscriptions in which entries with
// a blank channel ID are assigned channelID. When channelID is empty, entries
// are copied unchanged. The result is never nil.
func normalizeSubscriptions(channelID string, subscriptions []model.Subscription) []model.Subscription {
	items := make([]model.Subscription, 0, len(subscriptions))
	for _, entry := range subscriptions {
		if channelID != "" && strings.TrimSpace(entry.ChannelID) == "" {
			entry.ChannelID = channelID
		}
		items = append(items, entry)
	}
	return items
}