-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathsocket.go
More file actions
374 lines (329 loc) · 12.8 KB
/
socket.go
File metadata and controls
374 lines (329 loc) · 12.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
package velaros
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"sync"
"time"
"github.com/coder/websocket"
"github.com/google/uuid"
)
// MessageType represents the type of a WebSocket message (text or binary). This is
// a type alias for websocket.MessageType from the github.com/coder/websocket package,
// so values of either type can be used interchangeably without conversion.
type MessageType = websocket.MessageType
// Message type constants, re-exported from github.com/coder/websocket so code in
// this package's API can refer to them without importing that package.
const (
// MessageText marks a text message; it is identical to websocket.MessageText.
MessageText MessageType = websocket.MessageText
// MessageBinary marks a binary message; it is identical to websocket.MessageBinary.
MessageBinary MessageType = websocket.MessageBinary
)
// SocketMessage represents a WebSocket message at the transport layer, containing
// the message type, raw data, processed data, and metadata. This is used by
// SocketConnection implementations to pass messages between the WebSocket layer
// and the routing layer.
type SocketMessage struct {
// Type is the message type (MessageText or MessageBinary).
Type MessageType
// RawData holds the message bytes as produced by the connection's Read. It is
// copied into the InboundMessage when a message is handled; Socket.SendWithContext
// leaves it nil on outgoing messages.
// NOTE(review): presumably the unmodified wire payload — confirm against the
// SocketConnection implementations.
RawData []byte
// Data holds the message payload. On read it is copied into the InboundMessage;
// on send it carries the bytes passed to Socket.Send / Socket.SendWithContext.
Data []byte
// Meta carries implementation-defined metadata alongside the message. On read it
// is copied into the InboundMessage; Socket.SendWithContext leaves it nil.
Meta map[string]any
}
// SocketConnection is an interface for WebSocket connection implementations. This
// allows Velaros to work with different WebSocket libraries or custom connection types.
// The framework provides WebSocketConnection for the standard github.com/coder/websocket
// library.
type SocketConnection interface {
// Read returns the next message or an error. Socket passes itself as ctx, and
// Socket.Close cancels that context, so a pending Read should return once the
// socket is closed. Socket.HandleNextMessageWithNode classifies errors using
// websocket.CloseStatus, context.Canceled, io.EOF, io.ErrUnexpectedEOF and
// net.ErrClosed.
Read(ctx context.Context) (*SocketMessage, error)
// Write sends msg, honoring ctx for cancellation and deadlines. Socket.Send
// populates only msg.Type and msg.Data.
Write(ctx context.Context, msg *SocketMessage) error
// Close closes the underlying connection with the given status and reason.
// NOTE(review): not called from this file; presumably invoked by the router
// after the UseClose handlers finish — confirm.
Close(status Status, reason string) error
}
// receiverEntry holds a registered receiver for a specific handler within a node.
// Receivers are used for multi-message conversations: a handler registers a receiver
// via ReceiveInto, and the next same-ID message is delivered to it at exactly this
// node+handlerIndex rather than spawning a new handler instance.
type receiverEntry struct {
// node is the handler node at which the receiver was registered.
node *HandlerNode
// handlerIndex identifies the registering handler within node (presumably an
// index into the node's handler list — confirm). GetReceiverForNode only
// returns ch when both node and handlerIndex match.
handlerIndex int
// ch is the channel the next same-ID message is delivered on. Socket.Close
// closes it when the socket shuts down.
ch chan *InboundMessage
}
// Socket represents a WebSocket connection and manages its lifecycle, message
// interception, and connection-level storage. It implements the context.Context
// interface to support cancellation and deadlines.
//
// Socket is primarily used internally by the router but is exported to allow
// advanced use cases and custom integrations. Most users will interact with
// Socket indirectly through Context methods like SetOnSocket and GetFromSocket.
//
// Key responsibilities:
// - Connection lifecycle management (open, close, done signaling)
// - Thread-safe connection-level value storage
// - Message receiver registration for multi-message conversations
// - Message-ID locking to serialize same-ID messages through the handler chain
// - Access to original HTTP upgrade request headers
type Socket struct {
// id is the UUID assigned by NewSocket.
id string
// connectionInfo describes the upgrade request (headers, remote address).
// It may be nil; Headers and RemoteAddr check for that.
connectionInfo *ConnectionInfo
// connection is the transport used for reading and writing messages.
connection SocketConnection
// receiversMx guards receivers.
receiversMx sync.Mutex
// receivers maps a message ID to the receiver registered for it.
receivers map[string]*receiverEntry
// messageIDLocksMx guards messageIDLocks.
messageIDLocksMx sync.Mutex
// messageIDLocks maps a message ID that is currently being processed to a
// channel that is closed when that ID's lock is released.
messageIDLocks map[string]chan struct{}
// associatedValuesMx guards associatedValues; Get/MustGet take the read lock.
associatedValuesMx sync.RWMutex
// associatedValues holds connection-level values stored via Set.
associatedValues map[string]any
// closeMu guards closed; Close also writes the close* fields below while
// holding it.
closeMu sync.Mutex
// closed is set by the first Close call; later calls are ignored.
closed bool
// closeStatus, closeStatusSource and closeReason record the arguments of the
// first Close call.
closeStatus Status
closeStatusSource CloseSource
closeReason string
// ctx backs the context.Context methods (Deadline, Done, Err, Value);
// cancelCtx is invoked by Close.
ctx context.Context
cancelCtx context.CancelFunc
}
// Compile-time assertion that *Socket implements context.Context. A typed nil
// pointer is enough for the check, so no Socket value is built at package init.
var _ context.Context = (*Socket)(nil)
// NewSocket wraps conn in a new Socket; it is mainly used by the router. Each
// socket gets a freshly generated UUID as its ID, empty receiver, message-ID
// lock and value maps, and its own cancelable context derived from
// context.Background. That context is canceled when the socket is closed.
func NewSocket(info *ConnectionInfo, conn SocketConnection) *Socket {
	socketCtx, cancel := context.WithCancel(context.Background())
	return &Socket{
		id:               uuid.NewString(),
		connectionInfo:   info,
		connection:       conn,
		receivers:        make(map[string]*receiverEntry),
		messageIDLocks:   make(map[string]chan struct{}),
		associatedValues: make(map[string]any),
		ctx:              socketCtx,
		cancelCtx:        cancel,
	}
}
// ID reports the socket's unique identifier. NewSocket assigns it once and
// nothing in Socket modifies it afterwards.
func (s *Socket) ID() (id string) {
	id = s.id
	return id
}
// Headers exposes the HTTP headers of the WebSocket upgrade request, e.g. for
// reading auth tokens, cookies or custom handshake headers. The stored header map
// itself is returned (not a copy). When no connection info or headers were
// recorded, a new empty http.Header is returned instead of nil.
func (s *Socket) Headers() http.Header {
	info := s.connectionInfo
	if info == nil || info.Headers == nil {
		return http.Header{}
	}
	return info.Headers
}
// RemoteAddr reports the client's network address as recorded in the connection
// info (typically "IP:port", depending on the underlying connection). It returns
// the empty string when no connection info is available.
func (s *Socket) RemoteAddr() string {
	if s.connectionInfo == nil {
		return ""
	}
	return s.connectionInfo.RemoteAddr
}
// Close marks the socket as closed with the given status code, reason, and source
// (client or server). Only the first call has any effect; later calls return
// immediately. It is safe to call from multiple goroutines.
//
// On the first call it records status/reason/source and closes and removes every
// registered receiver channel. It also wakes every goroutine waiting on a
// message-ID lock, then cancels the socket's context so Done is closed. The
// actual connection close happens after UseClose handlers complete.
func (s *Socket) Close(status Status, reason string, source CloseSource) {
	s.closeMu.Lock()
	defer s.closeMu.Unlock()

	if s.closed {
		return
	}
	s.closed = true
	s.closeStatus = status
	s.closeReason = reason
	s.closeStatusSource = source

	// Closing the receiver channels unblocks handlers waiting for a follow-up
	// message; removing them ensures no later lookup hands out a closed channel.
	// NOTE(review): GetReceiverForNode returns the channel after releasing
	// receiversMx. If a delivery sends on it after this close, the send panics.
	// Confirm the delivery path guards against that.
	s.receiversMx.Lock()
	for id, entry := range s.receivers {
		close(entry.ch)
		delete(s.receivers, id)
	}
	s.receiversMx.Unlock()

	// Wake every goroutine blocked on a message-ID lock before discarding the map.
	// If the map were only replaced, releaseMessageIDLock would later look the ID up
	// in the new, empty map, never find it, and never close the old channel. The
	// waiters would then block forever.
	// No channel can be closed twice: both this loop and releaseMessageIDLock
	// remove a channel from the map under messageIDLocksMx before it is closed,
	// and only the one that removed it closes it.
	s.messageIDLocksMx.Lock()
	for _, ch := range s.messageIDLocks {
		close(ch)
	}
	s.messageIDLocks = map[string]chan struct{}{}
	s.messageIDLocksMx.Unlock()

	s.cancelCtx()
}
// IsClosed reports whether Close has been called on this socket. The flag is
// read under closeMu, so any goroutine may call it.
func (s *Socket) IsClosed() bool {
	s.closeMu.Lock()
	closed := s.closed
	s.closeMu.Unlock()
	return closed
}
// Send writes data to the connection as a message of the given type (MessageText
// or MessageBinary), using a background context with no deadline. It is a
// low-level method; most users should use Context.Send instead.
func (s *Socket) Send(messageType MessageType, data []byte) error {
	background := context.Background()
	return s.SendWithContext(background, messageType, data)
}
// SendWithContext writes data to the connection as a message of the given type.
// ctx is handed to the connection's Write for cancellation and deadlines. Only
// the Type and Data fields of the outgoing SocketMessage are populated.
func (s *Socket) SendWithContext(ctx context.Context, messageType MessageType, data []byte) error {
	outgoing := &SocketMessage{Type: messageType, Data: data}
	return s.connection.Write(ctx, outgoing)
}
// Set stores value under key in the socket's connection-level storage, replacing
// any previous value. Values live as long as the socket. The write happens under
// the storage's write lock. Prefer Context.SetOnSocket over calling this directly.
func (s *Socket) Set(key string, value any) {
	mx := &s.associatedValuesMx
	mx.Lock()
	s.associatedValues[key] = value
	mx.Unlock()
}
// Get looks up key in the socket's connection-level storage. It reports the
// stored value and true when present, or nil and false otherwise. The lookup
// happens under the storage's read lock. Prefer Context.GetFromSocket over
// calling this directly.
func (s *Socket) Get(key string) (any, bool) {
	s.associatedValuesMx.RLock()
	defer s.associatedValuesMx.RUnlock()
	value, found := s.associatedValues[key]
	return value, found
}
// MustGet is like Get but panics (with a "key ... not found" message) when key is
// not present in the socket's connection-level storage. Prefer
// Context.MustGetFromSocket over calling this directly.
func (s *Socket) MustGet(key string) any {
	value, found := s.Get(key)
	if found {
		return value
	}
	panic(fmt.Sprintf("key %s not found", key))
}
// Delete drops key from the socket's connection-level storage; deleting a missing
// key is a no-op. The removal happens under the storage's write lock. Prefer
// Context.DeleteFromSocket over calling this directly.
func (s *Socket) Delete(key string) {
	s.associatedValuesMx.Lock()
	defer s.associatedValuesMx.Unlock()
	delete(s.associatedValues, key)
}
// HandleNextMessageWithNode reads the next message from the connection and
// dispatches it through the handler chain starting at node. It is an internal
// method used by the router's read loop.
//
// The read uses the socket itself as its context, so closing the socket aborts
// it. On success the message is handled on a new goroutine and true is returned
// immediately. On any read error the socket is closed and false is returned:
//   - the peer sent a close frame: its status, ClientCloseSource
//   - context.Canceled (the socket is already closing): StatusGoingAway, ServerCloseSource
//   - io.EOF, io.ErrUnexpectedEOF, net.ErrClosed: StatusGoingAway, ClientCloseSource
//   - any other error (e.g. a TCP connection reset): StatusInternalError, ServerCloseSource
func (s *Socket) HandleNextMessageWithNode(node *HandlerNode) bool {
	msg, err := s.connection.Read(s)
	if err != nil {
		switch closeStatus := websocket.CloseStatus(err); {
		case closeStatus != -1:
			s.Close(Status(closeStatus), "", ClientCloseSource)
		case errors.Is(err, context.Canceled):
			s.Close(StatusGoingAway, "", ServerCloseSource)
		case errors.Is(err, io.EOF), errors.Is(err, io.ErrUnexpectedEOF), errors.Is(err, net.ErrClosed):
			s.Close(StatusGoingAway, "", ClientCloseSource)
		default:
			// Unclassified transport errors used to panic here. They can be triggered
			// remotely (e.g. by resetting the connection). The panic aborted the read
			// loop without the socket ever being closed, so no close status was
			// recorded and the socket's context was never canceled. Close the socket
			// like every other failure path instead. The reason is left empty, as in
			// the other branches, because arbitrary error text may exceed the
			// 123-byte limit on WebSocket close reasons.
			s.Close(Status(websocket.StatusInternalError), "", ServerCloseSource)
		}
		return false
	}

	go func() {
		inboundMsg := inboundMessageFromPool()
		inboundMsg.RawData = msg.RawData
		inboundMsg.Data = msg.Data
		inboundMsg.Meta = msg.Meta
		ctx := NewContextWithNodeAndMessageType(s, inboundMsg, node, msg.Type)
		ctx.Next()
		// A message that had an ID assigned holds that ID's lock until the whole
		// chain has run; release it so the next same-ID message can proceed.
		// NOTE(review): the lock is presumably acquired in Context.SetMessageID.
		if ctx.message != nil && ctx.message.hasSetID {
			s.releaseMessageIDLock(ctx.message.ID)
		}
		ctx.finalize()
	}()
	return true
}
// HandleOpen runs the open lifecycle handlers beginning at node, using a fresh
// pooled inbound message, then finalizes the context. The router calls it when a
// new connection is established.
func (s *Socket) HandleOpen(node *HandlerNode) {
	msg := inboundMessageFromPool()
	ctx := NewContextWithNode(s, msg, node)
	ctx.Next()
	ctx.finalize()
}
// HandleClose runs the close lifecycle handlers beginning at node, using a fresh
// pooled inbound message, then finalizes the context. The router calls it while a
// connection is shutting down.
func (s *Socket) HandleClose(node *HandlerNode) {
	msg := inboundMessageFromPool()
	ctx := NewContextWithNode(s, msg, node)
	ctx.Next()
	ctx.finalize()
}
// GetReceiverForNode returns the receiver channel registered for message ID id,
// but only when it was registered at exactly this node pointer and handler index.
// Otherwise it returns (nil, false). Restricting the match this way means a
// same-ID continuation is consumed only by the handler that asked for it, while
// middleware on earlier nodes still runs for it.
func (s *Socket) GetReceiverForNode(id string, node *HandlerNode, handlerIndex int) (chan *InboundMessage, bool) {
	s.receiversMx.Lock()
	defer s.receiversMx.Unlock()
	if entry, found := s.receivers[id]; found && entry.node == node && entry.handlerIndex == handlerIndex {
		return entry.ch, true
	}
	return nil, false
}
// AddReceiver registers ch as the receiver for message ID id at the given handler
// node and index. The next same-ID message is then delivered to ch instead of
// invoking the handler again. An existing registration for id is replaced (its
// channel is not closed). Receivers are meant to be removed once delivery happens
// (see RemoveReceiver).
func (s *Socket) AddReceiver(id string, node *HandlerNode, handlerIndex int, ch chan *InboundMessage) {
	entry := &receiverEntry{node: node, handlerIndex: handlerIndex, ch: ch}
	s.receiversMx.Lock()
	defer s.receiversMx.Unlock()
	s.receivers[id] = entry
}
// RemoveReceiver drops the receiver registered for message ID id, if any; it does
// nothing when none is registered. It serves both as cleanup after a
// self-consuming delivery and as a safety net from Context.cancelCtx().
func (s *Socket) RemoveReceiver(id string) {
	s.receiversMx.Lock()
	delete(s.receivers, id)
	s.receiversMx.Unlock()
}
// acquireMessageIDLock tries to claim the lock for message ID id in one critical
// section. It returns (ch, true) in either of two cases:
//   - Free ID: a new lock channel is created, stored and returned. The caller
//     now owns the lock and proceeds without blocking.
//   - Held ID: (ch, false) is returned with the existing channel. The caller
//     should wait for it to close and then try again.
func (s *Socket) acquireMessageIDLock(id string) (chan struct{}, bool) {
	s.messageIDLocksMx.Lock()
	defer s.messageIDLocksMx.Unlock()
	existing, held := s.messageIDLocks[id]
	if held {
		return existing, false
	}
	fresh := make(chan struct{})
	s.messageIDLocks[id] = fresh
	return fresh, true
}
// releaseMessageIDLock frees the lock for message ID id. The entry is removed from
// the map under messageIDLocksMx, and its channel is then closed outside the mutex.
// Closing it wakes every goroutine waiting in SetMessageID; they retry
// acquireMessageIDLock and race to become the next active message. If no lock is
// held for id, nothing happens.
func (s *Socket) releaseMessageIDLock(id string) {
	s.messageIDLocksMx.Lock()
	ch, held := s.messageIDLocks[id]
	delete(s.messageIDLocks, id)
	s.messageIDLocksMx.Unlock()
	if !held {
		return
	}
	close(ch)
}
// Deadline implements context.Context by delegating to the socket's internal
// context. That context comes from context.WithCancel without a deadline, so ok
// is presumably always false unless the construction changes.
func (s *Socket) Deadline() (deadline time.Time, ok bool) {
	deadline, ok = s.ctx.Deadline()
	return deadline, ok
}
// Done implements context.Context. The returned channel is closed once Close
// cancels the socket's internal context.
func (s *Socket) Done() <-chan struct{} {
	internal := s.ctx
	return internal.Done()
}
// Err implements context.Context. It returns nil while the socket is open, and
// context.Canceled once Close has canceled the socket's internal context.
func (s *Socket) Err() error {
	internal := s.ctx
	return internal.Err()
}
// Value implements context.Context by looking key up in the socket's internal
// context. Connection-level values stored with Set are not visible here; use Get.
func (s *Socket) Value(key any) any {
	internal := s.ctx
	return internal.Value(key)
}