| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277 |
- package gateway
- import (
- "context"
- "encoding/json"
- "errors"
- "log/slog"
- "net/http"
- "strings"
- "sync"
- "time"
- "github.com/coder/websocket"
- "github.com/coder/websocket/wsjson"
- "realtime-gateway/internal/channel"
- "realtime-gateway/internal/config"
- "realtime-gateway/internal/model"
- "realtime-gateway/internal/plugin"
- "realtime-gateway/internal/router"
- "realtime-gateway/internal/session"
- )
- type client struct {
- conn *websocket.Conn
- logger *slog.Logger
- cfg config.GatewayConfig
- hub *router.Hub
- channels *channel.Manager
- plugins *plugin.Bus
- session *session.Session
- auth config.AuthConfig
- writeMu sync.Mutex
- }
- func serveClient(
- w http.ResponseWriter,
- r *http.Request,
- logger *slog.Logger,
- cfg config.Config,
- hub *router.Hub,
- channels *channel.Manager,
- plugins *plugin.Bus,
- sessions *session.Manager,
- ) {
- conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
- InsecureSkipVerify: true,
- })
- if err != nil {
- logger.Error("websocket accept failed", "error", err)
- return
- }
- sess := sessions.Create()
- c := &client{
- conn: conn,
- logger: logger.With("sessionId", sess.ID),
- cfg: cfg.Gateway,
- hub: hub,
- channels: channels,
- plugins: plugins,
- session: sess,
- auth: cfg.Auth,
- }
- hub.Register(c, nil)
- defer func() {
- if sess.ChannelID != "" {
- channels.Unbind(sess.ChannelID, sess.Role)
- }
- hub.Unregister(sess.ID)
- sessions.Delete(sess.ID)
- _ = conn.Close(websocket.StatusNormalClosure, "session closed")
- }()
- if err := c.run(r.Context()); err != nil && !errors.Is(err, context.Canceled) {
- c.logger.Warn("client closed", "error", err)
- }
- }
- func (c *client) ID() string {
- return c.session.ID
- }
- func (c *client) Send(message model.ServerMessage) error {
- c.writeMu.Lock()
- defer c.writeMu.Unlock()
- ctx, cancel := context.WithTimeout(context.Background(), c.cfg.WriteWait())
- defer cancel()
- return wsjson.Write(ctx, c.conn, message)
- }
- func (c *client) run(ctx context.Context) error {
- if err := c.Send(model.ServerMessage{
- Type: "welcome",
- SessionID: c.session.ID,
- }); err != nil {
- return err
- }
- pingCtx, cancelPing := context.WithCancel(ctx)
- defer cancelPing()
- go c.pingLoop(pingCtx)
- for {
- readCtx, cancel := context.WithTimeout(ctx, c.cfg.PongWait())
- var message model.ClientMessage
- err := wsjson.Read(readCtx, c.conn, &message)
- cancel()
- if err != nil {
- return err
- }
- if err := c.handleMessage(message); err != nil {
- _ = c.Send(model.ServerMessage{
- Type: "error",
- Error: err.Error(),
- })
- }
- }
- }
- func (c *client) handleMessage(message model.ClientMessage) error {
- switch message.Type {
- case "authenticate":
- return c.handleAuthenticate(message)
- case "join_channel":
- return c.handleJoinChannel(message)
- case "subscribe":
- return c.handleSubscribe(message)
- case "publish":
- return c.handlePublish(message)
- case "snapshot":
- return c.handleSnapshot(message)
- default:
- return errors.New("unsupported message type")
- }
- }
- func (c *client) handleJoinChannel(message model.ClientMessage) error {
- if strings.TrimSpace(message.ChannelID) == "" {
- return errors.New("channelId is required")
- }
- snapshot, err := c.channels.Join(message.ChannelID, message.Token, message.Role)
- if err != nil {
- return err
- }
- if c.session.ChannelID != "" {
- c.channels.Unbind(c.session.ChannelID, c.session.Role)
- }
- if err := c.channels.Bind(snapshot.ID, message.Role); err != nil {
- return err
- }
- c.session.Role = message.Role
- c.session.Authenticated = true
- c.session.ChannelID = snapshot.ID
- c.session.Subscriptions = nil
- c.hub.UpdateSubscriptions(c.session.ID, nil)
- return c.Send(model.ServerMessage{
- Type: "joined_channel",
- SessionID: c.session.ID,
- State: json.RawMessage([]byte(
- `{"channelId":"` + snapshot.ID + `","deliveryMode":"` + snapshot.DeliveryMode + `"}`,
- )),
- })
- }
- func (c *client) handleAuthenticate(message model.ClientMessage) error {
- if !authorize(c.auth, message.Role, message.Token) {
- return errors.New("authentication failed")
- }
- c.session.Role = message.Role
- c.session.Authenticated = true
- return c.Send(model.ServerMessage{
- Type: "authenticated",
- SessionID: c.session.ID,
- })
- }
- func (c *client) handleSubscribe(message model.ClientMessage) error {
- if !c.session.Authenticated && !c.auth.AllowAnonymousConsumers {
- return errors.New("consumer must authenticate before subscribe")
- }
- subscriptions := normalizeSubscriptions(c.session.ChannelID, message.Subscriptions)
- c.session.Subscriptions = subscriptions
- c.hub.UpdateSubscriptions(c.session.ID, subscriptions)
- return c.Send(model.ServerMessage{
- Type: "subscribed",
- SessionID: c.session.ID,
- })
- }
- func (c *client) handlePublish(message model.ClientMessage) error {
- if !c.session.Authenticated {
- return errors.New("authentication required")
- }
- if c.session.Role != model.RoleProducer && c.session.Role != model.RoleController {
- return errors.New("publish is only allowed for producer or controller")
- }
- if message.Envelope == nil {
- return errors.New("envelope is required")
- }
- envelope := *message.Envelope
- if envelope.Source.Kind == "" {
- envelope.Source.Kind = c.session.Role
- }
- if c.session.ChannelID != "" {
- envelope.Target.ChannelID = c.session.ChannelID
- }
- deliveryMode := channel.DeliveryModeCacheLatest
- if envelope.Target.ChannelID != "" {
- deliveryMode = c.channels.DeliveryMode(envelope.Target.ChannelID)
- }
- result := c.hub.Publish(envelope, deliveryMode)
- if !result.Dropped {
- c.plugins.Publish(envelope)
- }
- return c.Send(model.ServerMessage{
- Type: "published",
- SessionID: c.session.ID,
- })
- }
- func (c *client) handleSnapshot(message model.ClientMessage) error {
- if len(message.Subscriptions) == 0 || message.Subscriptions[0].DeviceID == "" {
- return errors.New("snapshot requires deviceId in first subscription")
- }
- channelID := message.Subscriptions[0].ChannelID
- if channelID == "" {
- channelID = c.session.ChannelID
- }
- state, ok := c.hub.Snapshot(channelID, message.Subscriptions[0].DeviceID)
- if !ok {
- return errors.New("snapshot not found")
- }
- return c.Send(model.ServerMessage{
- Type: "snapshot",
- SessionID: c.session.ID,
- State: json.RawMessage(state),
- })
- }
- func (c *client) pingLoop(ctx context.Context) {
- ticker := time.NewTicker(c.cfg.PingInterval())
- defer ticker.Stop()
- for {
- select {
- case <-ctx.Done():
- return
- case <-ticker.C:
- pingCtx, cancel := context.WithTimeout(ctx, c.cfg.WriteWait())
- _ = c.conn.Ping(pingCtx)
- cancel()
- }
- }
- }
- func normalizeSubscriptions(channelID string, subscriptions []model.Subscription) []model.Subscription {
- items := make([]model.Subscription, 0, len(subscriptions))
- for _, entry := range subscriptions {
- if channelID != "" && strings.TrimSpace(entry.ChannelID) == "" {
- entry.ChannelID = channelID
- }
- items = append(items, entry)
- }
- return items
- }
|