Browse Source

Update WebSocket implementation to use Gorilla WebSocket library

- Replaced the existing `github.com/coder/websocket` package with `github.com/gorilla/websocket` for improved functionality and compatibility.
- Adjusted WebSocket connection handling, including message reading and writing, to align with the new library's API.
- Enhanced error handling and logging for WebSocket operations.
- Bumped version to v0.20.0 to reflect the changes made.
main
mleku 2 months ago
parent
commit
88ebf6eccc
No known key found for this signature in database
  1. 119
      app/handle-websocket.go
  2. 13
      app/listener.go
  3. 17
      app/publisher.go
  4. 2
      go.mod
  5. 4
      go.sum
  6. 47
      pkg/acl/follows.go
  7. 67
      pkg/protocol/ws/connection.go
  8. 31
      pkg/protocol/ws/connection_options.go
  9. 2
      pkg/version/version

119
app/handle-websocket.go

@ -7,7 +7,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/coder/websocket" "github.com/gorilla/websocket"
"lol.mleku.dev/chk" "lol.mleku.dev/chk"
"lol.mleku.dev/log" "lol.mleku.dev/log"
"next.orly.dev/pkg/encoders/envelopes/authenvelope" "next.orly.dev/pkg/encoders/envelopes/authenvelope"
@ -24,21 +24,16 @@ const (
// ClientMessageSizeLimit is the maximum message size that clients can handle // ClientMessageSizeLimit is the maximum message size that clients can handle
// This is set to 100MB to allow large messages // This is set to 100MB to allow large messages
ClientMessageSizeLimit = 100 * 1024 * 1024 // 100MB ClientMessageSizeLimit = 100 * 1024 * 1024 // 100MB
// CloseMessage denotes a close control message. The optional message
// payload contains a numeric code and text. Use the FormatCloseMessage
// function to format a close message payload.
CloseMessage = 8
// PingMessage denotes a ping control message. The optional message payload
// is UTF-8 encoded text.
PingMessage = 9
// PongMessage denotes a pong control message. The optional message payload
// is UTF-8 encoded text.
PongMessage = 10
) )
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true // Allow all origins for proxy compatibility
},
}
func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) { func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
remote := GetRemoteFromReq(r) remote := GetRemoteFromReq(r)
@ -62,16 +57,12 @@ whitelist:
defer cancel() defer cancel()
var err error var err error
var conn *websocket.Conn var conn *websocket.Conn
// Configure WebSocket accept options for proxy compatibility
acceptOptions := &websocket.AcceptOptions{
OriginPatterns: []string{"*"}, // Allow all origins for proxy compatibility
// Don't check origin when behind a proxy - let the proxy handle it
InsecureSkipVerify: true,
// Try to set a higher compression threshold to allow larger messages
CompressionMode: websocket.CompressionDisabled,
}
if conn, err = websocket.Accept(w, r, acceptOptions); chk.E(err) { // Configure upgrader for this connection
upgrader.ReadBufferSize = int(DefaultMaxMessageSize)
upgrader.WriteBufferSize = int(DefaultMaxMessageSize)
if conn, err = upgrader.Upgrade(w, r, nil); chk.E(err) {
log.E.F("websocket accept failed from %s: %v", remote, err) log.E.F("websocket accept failed from %s: %v", remote, err)
return return
} }
@ -80,7 +71,7 @@ whitelist:
// Set read limit immediately after connection is established // Set read limit immediately after connection is established
conn.SetReadLimit(DefaultMaxMessageSize) conn.SetReadLimit(DefaultMaxMessageSize)
log.D.F("set read limit to %d bytes (%d MB) for %s", DefaultMaxMessageSize, DefaultMaxMessageSize/units.Mb, remote) log.D.F("set read limit to %d bytes (%d MB) for %s", DefaultMaxMessageSize, DefaultMaxMessageSize/units.Mb, remote)
defer conn.CloseNow() defer conn.Close()
listener := &Listener{ listener := &Listener{
ctx: ctx, ctx: ctx,
Server: s, Server: s,
@ -109,6 +100,16 @@ whitelist:
log.D.F("AUTH challenge sent successfully to %s", remote) log.D.F("AUTH challenge sent successfully to %s", remote)
} }
ticker := time.NewTicker(DefaultPingWait) ticker := time.NewTicker(DefaultPingWait)
// Set pong handler
conn.SetPongHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(DefaultPongWait))
return nil
})
// Set ping handler
conn.SetPingHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(DefaultPongWait))
return conn.WriteControl(websocket.PongMessage, []byte{}, time.Now().Add(DefaultWriteTimeout))
})
// Don't pass cancel to Pinger - it should not be able to cancel the connection context // Don't pass cancel to Pinger - it should not be able to cancel the connection context
go s.Pinger(ctx, conn, ticker) go s.Pinger(ctx, conn, ticker)
defer func() { defer func() {
@ -154,12 +155,19 @@ whitelist:
return return
} }
var typ websocket.MessageType var typ int
var msg []byte var msg []byte
log.T.F("waiting for message from %s", remote) log.T.F("waiting for message from %s", remote)
// Set read deadline for context cancellation
deadline := time.Now().Add(DefaultPongWait)
if ctx.Err() != nil {
return
}
conn.SetReadDeadline(deadline)
// Block waiting for message; rely on pings and context cancellation to detect dead peers // Block waiting for message; rely on pings and context cancellation to detect dead peers
typ, msg, err = conn.Read(ctx) typ, msg, err = conn.ReadMessage()
if err != nil { if err != nil {
// Check if the error is due to context cancellation // Check if the error is due to context cancellation
@ -180,50 +188,40 @@ whitelist:
return return
} }
// Handle message too big errors specifically // Handle message too big errors specifically
if strings.Contains(err.Error(), "MessageTooBig") || if strings.Contains(err.Error(), "message too large") ||
strings.Contains(err.Error(), "read limited at") { strings.Contains(err.Error(), "read limited at") {
log.D.F("client %s hit message size limit: %v", remote, err) log.D.F("client %s hit message size limit: %v", remote, err)
// Don't log this as an error since it's a client-side limit // Don't log this as an error since it's a client-side limit
// Just close the connection gracefully // Just close the connection gracefully
return return
} }
status := websocket.CloseStatus(err) // Check for websocket close errors
switch status { if websocket.IsCloseError(err, websocket.CloseNormalClosure,
case websocket.StatusNormalClosure, websocket.CloseGoingAway,
websocket.StatusGoingAway, websocket.CloseNoStatusReceived,
websocket.StatusNoStatusRcvd, websocket.CloseAbnormalClosure,
websocket.StatusAbnormalClosure, websocket.CloseUnsupportedData,
websocket.StatusProtocolError: websocket.CloseInvalidFramePayloadData) {
log.T.F( log.T.F("connection from %s closed: %v", remote, err)
"connection from %s closed with status: %v", remote, status, } else if websocket.IsCloseError(err, websocket.CloseMessageTooBig) {
)
case websocket.StatusMessageTooBig:
log.D.F("client %s sent message too big: %v", remote, err) log.D.F("client %s sent message too big: %v", remote, err)
default: } else {
log.E.F("unexpected close error from %s: %v", remote, err) log.E.F("unexpected close error from %s: %v", remote, err)
} }
return return
} }
if typ == PingMessage { if typ == websocket.PingMessage {
log.D.F("received PING from %s, sending PONG", remote) log.D.F("received PING from %s, sending PONG", remote)
// Create a write context with timeout for pong response // Create a write context with timeout for pong response
writeCtx, writeCancel := context.WithTimeout( deadline := time.Now().Add(DefaultWriteTimeout)
ctx, DefaultWriteTimeout, conn.SetWriteDeadline(deadline)
)
pongStart := time.Now() pongStart := time.Now()
if err = conn.Write(writeCtx, PongMessage, msg); chk.E(err) { if err = conn.WriteControl(websocket.PongMessage, msg, deadline); chk.E(err) {
pongDuration := time.Since(pongStart) pongDuration := time.Since(pongStart)
log.E.F( log.E.F(
"failed to send PONG to %s after %v: %v", remote, "failed to send PONG to %s after %v: %v", remote,
pongDuration, err, pongDuration, err,
) )
if writeCtx.Err() != nil {
log.E.F(
"PONG write timeout to %s after %v (limit=%v)", remote,
pongDuration, DefaultWriteTimeout,
)
}
writeCancel()
return return
} }
pongDuration := time.Since(pongStart) pongDuration := time.Since(pongStart)
@ -231,7 +229,6 @@ whitelist:
if pongDuration > time.Millisecond*50 { if pongDuration > time.Millisecond*50 {
log.D.F("SLOW PONG to %s: %v (>50ms)", remote, pongDuration) log.D.F("SLOW PONG to %s: %v (>50ms)", remote, pongDuration)
} }
writeCancel()
continue continue
} }
// Log message size for debugging // Log message size for debugging
@ -260,26 +257,18 @@ func (s *Server) Pinger(
pingCount++ pingCount++
log.D.F("sending PING #%d", pingCount) log.D.F("sending PING #%d", pingCount)
// Create a write context with timeout for ping operation // Set write deadline for ping operation
pingCtx, pingCancel := context.WithTimeout(ctx, DefaultWriteTimeout) deadline := time.Now().Add(DefaultWriteTimeout)
conn.SetWriteDeadline(deadline)
pingStart := time.Now() pingStart := time.Now()
if err = conn.Ping(pingCtx); err != nil { if err = conn.WriteControl(websocket.PingMessage, []byte{}, deadline); err != nil {
pingDuration := time.Since(pingStart) pingDuration := time.Since(pingStart)
log.E.F( log.E.F(
"PING #%d FAILED after %v: %v", pingCount, pingDuration, "PING #%d FAILED after %v: %v", pingCount, pingDuration,
err, err,
) )
if pingCtx.Err() != nil {
log.E.F(
"PING #%d timeout after %v (limit=%v)", pingCount,
pingDuration, DefaultWriteTimeout,
)
}
chk.E(err) chk.E(err)
pingCancel()
return return
} }
@ -289,8 +278,6 @@ func (s *Server) Pinger(
if pingDuration > time.Millisecond*100 { if pingDuration > time.Millisecond*100 {
log.D.F("SLOW PING #%d: %v (>100ms)", pingCount, pingDuration) log.D.F("SLOW PING #%d: %v (>100ms)", pingCount, pingDuration)
} }
pingCancel()
case <-ctx.Done(): case <-ctx.Done():
log.T.F("pinger context cancelled after %d pings", pingCount) log.T.F("pinger context cancelled after %d pings", pingCount)
return return

13
app/listener.go

@ -3,9 +3,10 @@ package app
import ( import (
"context" "context"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/coder/websocket" "github.com/gorilla/websocket"
"lol.mleku.dev/chk" "lol.mleku.dev/chk"
"lol.mleku.dev/log" "lol.mleku.dev/log"
"next.orly.dev/pkg/acl" "next.orly.dev/pkg/acl"
@ -54,14 +55,12 @@ func (l *Listener) Write(p []byte) (n int, err error) {
// Use a separate context with timeout for writes to prevent race conditions // Use a separate context with timeout for writes to prevent race conditions
// where the main connection context gets cancelled while writing events // where the main connection context gets cancelled while writing events
writeCtx, cancel := context.WithTimeout( deadline := time.Now().Add(DefaultWriteTimeout)
context.Background(), DefaultWriteTimeout, l.conn.SetWriteDeadline(deadline)
)
defer cancel()
// Attempt the write operation // Attempt the write operation
writeStart := time.Now() writeStart := time.Now()
if err = l.conn.Write(writeCtx, websocket.MessageText, p); err != nil { if err = l.conn.WriteMessage(websocket.TextMessage, p); err != nil {
writeDuration := time.Since(writeStart) writeDuration := time.Since(writeStart)
totalDuration := time.Since(start) totalDuration := time.Since(start)
@ -72,7 +71,7 @@ func (l *Listener) Write(p []byte) (n int, err error) {
) )
// Check if this is a context timeout // Check if this is a context timeout
if writeCtx.Err() != nil { if strings.Contains(err.Error(), "timeout") || strings.Contains(err.Error(), "deadline") {
log.E.F( log.E.F(
"ws->%s write timeout after %v (limit=%v)", l.remote, "ws->%s write timeout after %v (limit=%v)", l.remote,
writeDuration, DefaultWriteTimeout, writeDuration, DefaultWriteTimeout,

17
app/publisher.go

@ -3,10 +3,11 @@ package app
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"sync" "sync"
"time" "time"
"github.com/coder/websocket" "github.com/gorilla/websocket"
"lol.mleku.dev/chk" "lol.mleku.dev/chk"
"lol.mleku.dev/log" "lol.mleku.dev/log"
"next.orly.dev/pkg/acl" "next.orly.dev/pkg/acl"
@ -270,15 +271,11 @@ func (p *P) Deliver(ev *event.E) {
// Use a separate context with timeout for writes to prevent race conditions // Use a separate context with timeout for writes to prevent race conditions
// where the publisher context gets cancelled while writing events // where the publisher context gets cancelled while writing events
writeCtx, cancel := context.WithTimeout( deadline := time.Now().Add(DefaultWriteTimeout)
context.Background(), DefaultWriteTimeout, d.w.SetWriteDeadline(deadline)
)
defer cancel()
deliveryStart := time.Now() deliveryStart := time.Now()
if err = d.w.Write( if err = d.w.WriteMessage(websocket.TextMessage, msgData); err != nil {
writeCtx, websocket.MessageText, msgData,
); err != nil {
deliveryDuration := time.Since(deliveryStart) deliveryDuration := time.Since(deliveryStart)
// Log detailed failure information // Log detailed failure information
@ -286,7 +283,7 @@ func (p *P) Deliver(ev *event.E) {
hex.Enc(ev.ID), d.sub.remote, d.id, deliveryDuration, err) hex.Enc(ev.ID), d.sub.remote, d.id, deliveryDuration, err)
// Check for timeout specifically // Check for timeout specifically
if writeCtx.Err() != nil { if strings.Contains(err.Error(), "timeout") || strings.Contains(err.Error(), "deadline") {
log.E.F("subscription delivery TIMEOUT: event=%s to=%s after %v (limit=%v)", log.E.F("subscription delivery TIMEOUT: event=%s to=%s after %v (limit=%v)",
hex.Enc(ev.ID), d.sub.remote, deliveryDuration, DefaultWriteTimeout) hex.Enc(ev.ID), d.sub.remote, deliveryDuration, DefaultWriteTimeout)
} }
@ -296,7 +293,7 @@ func (p *P) Deliver(ev *event.E) {
// On error, remove the subscriber connection safely // On error, remove the subscriber connection safely
p.removeSubscriber(d.w) p.removeSubscriber(d.w)
_ = d.w.CloseNow() _ = d.w.Close()
continue continue
} }

2
go.mod

@ -4,9 +4,9 @@ go 1.25.0
require ( require (
github.com/adrg/xdg v0.5.3 github.com/adrg/xdg v0.5.3
github.com/coder/websocket v1.8.14
github.com/davecgh/go-spew v1.1.1 github.com/davecgh/go-spew v1.1.1
github.com/dgraph-io/badger/v4 v4.8.0 github.com/dgraph-io/badger/v4 v4.8.0
github.com/gorilla/websocket v1.5.3
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0
github.com/klauspost/cpuid/v2 v2.3.0 github.com/klauspost/cpuid/v2 v2.3.0
github.com/pkg/profile v1.7.0 github.com/pkg/profile v1.7.0

4
go.sum

@ -13,8 +13,6 @@ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5P
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -45,6 +43,8 @@ github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8I
github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik=
github.com/google/pprof v0.0.0-20251007162407-5df77e3f7d1d h1:KJIErDwbSHjnp/SGzE5ed8Aol7JsKiI5X7yWKAtzhM0= github.com/google/pprof v0.0.0-20251007162407-5df77e3f7d1d h1:KJIErDwbSHjnp/SGzE5ed8Aol7JsKiI5X7yWKAtzhM0=
github.com/google/pprof v0.0.0-20251007162407-5df77e3f7d1d/go.mod h1:I6V7YzU0XDpsHqbsyrghnFZLO1gwK6NPTNvmetQIk9U= github.com/google/pprof v0.0.0-20251007162407-5df77e3f7d1d/go.mod h1:I6V7YzU0XDpsHqbsyrghnFZLO1gwK6NPTNvmetQIk9U=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w=
github.com/ianlancetaylor/demangle v0.0.0-20230524184225-eabc099b10ab/go.mod h1:gx7rwoVhcfuVKG5uya9Hs3Sxj7EIvldVofAWIUtGouw= github.com/ianlancetaylor/demangle v0.0.0-20230524184225-eabc099b10ab/go.mod h1:gx7rwoVhcfuVKG5uya9Hs3Sxj7EIvldVofAWIUtGouw=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=

47
pkg/acl/follows.go

@ -10,7 +10,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/coder/websocket" "github.com/gorilla/websocket"
"lol.mleku.dev/chk" "lol.mleku.dev/chk"
"lol.mleku.dev/errorf" "lol.mleku.dev/errorf"
"lol.mleku.dev/log" "lol.mleku.dev/log"
@ -396,12 +396,15 @@ func (f *Follows) startEventSubscriptions(ctx context.Context) {
headers.Set("Origin", "https://orly.dev") headers.Set("Origin", "https://orly.dev")
// Use proper WebSocket dial options // Use proper WebSocket dial options
dialOptions := &websocket.DialOptions{ dialer := websocket.Dialer{
HTTPHeader: headers, HandshakeTimeout: 10 * time.Second,
} }
c, _, err := websocket.Dial(connCtx, u, dialOptions) c, resp, err := dialer.DialContext(connCtx, u, headers)
cancel() cancel()
if resp != nil {
resp.Body.Close()
}
if err != nil { if err != nil {
log.W.F("follows syncer: dial %s failed: %v", u, err) log.W.F("follows syncer: dial %s failed: %v", u, err)
@ -480,13 +483,12 @@ func (f *Follows) startEventSubscriptions(ctx context.Context) {
req := reqenvelope.NewFrom([]byte(subID), ff) req := reqenvelope.NewFrom([]byte(subID), ff)
reqBytes := req.Marshal(nil) reqBytes := req.Marshal(nil)
log.T.F("follows syncer: outbound REQ to %s: %s", u, string(reqBytes)) log.T.F("follows syncer: outbound REQ to %s: %s", u, string(reqBytes))
if err = c.Write( c.SetWriteDeadline(time.Now().Add(10 * time.Second))
ctx, websocket.MessageText, reqBytes, if err = c.WriteMessage(websocket.TextMessage, reqBytes); chk.E(err) {
); chk.E(err) {
log.W.F( log.W.F(
"follows syncer: failed to send event REQ to %s: %v", u, err, "follows syncer: failed to send event REQ to %s: %v", u, err,
) )
_ = c.Close(websocket.StatusInternalError, "write failed") _ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "write failed"), time.Now().Add(time.Second))
continue continue
} }
log.T.F( log.T.F(
@ -501,11 +503,12 @@ func (f *Follows) startEventSubscriptions(ctx context.Context) {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
_ = c.Close(websocket.StatusNormalClosure, "ctx done") _ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "ctx done"), time.Now().Add(time.Second))
return return
case <-keepaliveTicker.C: case <-keepaliveTicker.C:
// Send ping to keep connection alive // Send ping to keep connection alive
if err := c.Ping(ctx); err != nil { c.SetWriteDeadline(time.Now().Add(5 * time.Second))
if err := c.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second)); err != nil {
log.T.F("follows syncer: ping failed for %s: %v", u, err) log.T.F("follows syncer: ping failed for %s: %v", u, err)
break readLoop break readLoop
} }
@ -513,11 +516,10 @@ func (f *Follows) startEventSubscriptions(ctx context.Context) {
continue continue
default: default:
// Set a read timeout to avoid hanging // Set a read timeout to avoid hanging
readCtx, readCancel := context.WithTimeout(ctx, 60*time.Second) c.SetReadDeadline(time.Now().Add(60 * time.Second))
_, data, err := c.Read(readCtx) _, data, err := c.ReadMessage()
readCancel()
if err != nil { if err != nil {
_ = c.Close(websocket.StatusNormalClosure, "read err") _ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "read err"), time.Now().Add(time.Second))
break readLoop break readLoop
} }
label, rem, err := envelopes.Identify(data) label, rem, err := envelopes.Identify(data)
@ -714,16 +716,19 @@ func (f *Follows) fetchFollowListsFromRelay(relayURL string, authors [][]byte) {
headers.Set("Origin", "https://orly.dev") headers.Set("Origin", "https://orly.dev")
// Use proper WebSocket dial options // Use proper WebSocket dial options
dialOptions := &websocket.DialOptions{ dialer := websocket.Dialer{
HTTPHeader: headers, HandshakeTimeout: 10 * time.Second,
} }
c, _, err := websocket.Dial(ctx, relayURL, dialOptions) c, resp, err := dialer.DialContext(ctx, relayURL, headers)
if resp != nil {
resp.Body.Close()
}
if err != nil { if err != nil {
log.W.F("follows syncer: failed to connect to %s for follow list fetch: %v", relayURL, err) log.W.F("follows syncer: failed to connect to %s for follow list fetch: %v", relayURL, err)
return return
} }
defer c.Close(websocket.StatusNormalClosure, "follow list fetch complete") defer c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "follow list fetch complete"), time.Now().Add(time.Second))
log.I.F("follows syncer: fetching follow lists from relay %s", relayURL) log.I.F("follows syncer: fetching follow lists from relay %s", relayURL)
@ -746,7 +751,8 @@ func (f *Follows) fetchFollowListsFromRelay(relayURL string, authors [][]byte) {
req := reqenvelope.NewFrom([]byte(subID), ff) req := reqenvelope.NewFrom([]byte(subID), ff)
reqBytes := req.Marshal(nil) reqBytes := req.Marshal(nil)
log.T.F("follows syncer: outbound REQ to %s: %s", relayURL, string(reqBytes)) log.T.F("follows syncer: outbound REQ to %s: %s", relayURL, string(reqBytes))
if err = c.Write(ctx, websocket.MessageText, reqBytes); chk.E(err) { c.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err = c.WriteMessage(websocket.TextMessage, reqBytes); chk.E(err) {
log.W.F("follows syncer: failed to send follow list REQ to %s: %v", relayURL, err) log.W.F("follows syncer: failed to send follow list REQ to %s: %v", relayURL, err)
return return
} }
@ -769,7 +775,8 @@ func (f *Follows) fetchFollowListsFromRelay(relayURL string, authors [][]byte) {
default: default:
} }
_, data, err := c.Read(ctx) c.SetReadDeadline(time.Now().Add(10 * time.Second))
_, data, err := c.ReadMessage()
if err != nil { if err != nil {
log.T.F("follows syncer: error reading events from %s: %v", relayURL, err) log.T.F("follows syncer: error reading events from %s: %v", relayURL, err)
goto processEvents goto processEvents

67
pkg/protocol/ws/connection.go

@ -3,21 +3,19 @@ package ws
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"time" "time"
"github.com/gorilla/websocket"
"lol.mleku.dev/errorf" "lol.mleku.dev/errorf"
"next.orly.dev/pkg/utils/units" "next.orly.dev/pkg/utils/units"
ws "github.com/coder/websocket"
) )
// Connection represents a websocket connection to a Nostr relay. // Connection represents a websocket connection to a Nostr relay.
type Connection struct { type Connection struct {
conn *ws.Conn conn *websocket.Conn
} }
// NewConnection creates a new websocket connection to a Nostr relay. // NewConnection creates a new websocket connection to a Nostr relay.
@ -25,10 +23,23 @@ func NewConnection(
ctx context.Context, url string, reqHeader http.Header, ctx context.Context, url string, reqHeader http.Header,
tlsConfig *tls.Config, tlsConfig *tls.Config,
) (c *Connection, err error) { ) (c *Connection, err error) {
var conn *ws.Conn var conn *websocket.Conn
if conn, _, err = ws.Dial( var resp *http.Response
ctx, url, getConnectionOptions(reqHeader, tlsConfig), dialer := getConnectionOptions(reqHeader, tlsConfig)
); err != nil {
// Prepare headers with default User-Agent if not present
headers := reqHeader
if headers == nil {
headers = make(http.Header)
}
if headers.Get("User-Agent") == "" {
headers.Set("User-Agent", "github.com/nbd-wtf/go-nostr")
}
if conn, resp, err = dialer.DialContext(ctx, url, headers); err != nil {
if resp != nil {
resp.Body.Close()
}
return return
} }
conn.SetReadLimit(33 * units.Mb) conn.SetReadLimit(33 * units.Mb)
@ -41,7 +52,14 @@ func NewConnection(
func (c *Connection) WriteMessage( func (c *Connection) WriteMessage(
ctx context.Context, data []byte, ctx context.Context, data []byte,
) (err error) { ) (err error) {
if err = c.conn.Write(ctx, ws.MessageText, data); err != nil { deadline := time.Now().Add(10 * time.Second)
if ctx != nil {
if d, ok := ctx.Deadline(); ok {
deadline = d
}
}
c.conn.SetWriteDeadline(deadline)
if err = c.conn.WriteMessage(websocket.TextMessage, data); err != nil {
err = errorf.E("failed to write message: %w", err) err = errorf.E("failed to write message: %w", err)
return return
} }
@ -52,11 +70,22 @@ func (c *Connection) WriteMessage(
func (c *Connection) ReadMessage( func (c *Connection) ReadMessage(
ctx context.Context, buf io.Writer, ctx context.Context, buf io.Writer,
) (err error) { ) (err error) {
var reader io.Reader deadline := time.Now().Add(60 * time.Second)
if _, reader, err = c.conn.Reader(ctx); err != nil { if ctx != nil {
if d, ok := ctx.Deadline(); ok {
deadline = d
}
}
c.conn.SetReadDeadline(deadline)
messageType, reader, err := c.conn.NextReader()
if err != nil {
err = fmt.Errorf("failed to get reader: %w", err) err = fmt.Errorf("failed to get reader: %w", err)
return return
} }
if messageType != websocket.TextMessage && messageType != websocket.BinaryMessage {
err = fmt.Errorf("unexpected message type: %d", messageType)
return
}
if _, err = io.Copy(buf, reader); err != nil { if _, err = io.Copy(buf, reader); err != nil {
err = fmt.Errorf("failed to read message: %w", err) err = fmt.Errorf("failed to read message: %w", err)
return return
@ -66,14 +95,18 @@ func (c *Connection) ReadMessage(
// Close closes the websocket connection. // Close closes the websocket connection.
func (c *Connection) Close() error { func (c *Connection) Close() error {
return c.conn.Close(ws.StatusNormalClosure, "") c.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second))
return c.conn.Close()
} }
// Ping sends a ping message to the websocket connection. // Ping sends a ping message to the websocket connection.
func (c *Connection) Ping(ctx context.Context) error { func (c *Connection) Ping(ctx context.Context) error {
ctx, cancel := context.WithTimeoutCause( deadline := time.Now().Add(800 * time.Millisecond)
ctx, time.Millisecond*800, errors.New("ping took too long"), if ctx != nil {
) if d, ok := ctx.Deadline(); ok {
defer cancel() deadline = d
return c.conn.Ping(ctx) }
}
c.conn.SetWriteDeadline(deadline)
return c.conn.WriteControl(websocket.PingMessage, []byte{}, deadline)
} }

31
pkg/protocol/ws/connection_options.go

@ -5,32 +5,21 @@ package ws
import ( import (
"crypto/tls" "crypto/tls"
"net/http" "net/http"
"net/textproto" "time"
ws "github.com/coder/websocket" "github.com/gorilla/websocket"
) )
var defaultConnectionOptions = &ws.DialOptions{
CompressionMode: ws.CompressionContextTakeover,
HTTPHeader: http.Header{
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"github.com/nbd-wtf/go-nostr"},
},
}
func getConnectionOptions( func getConnectionOptions(
requestHeader http.Header, tlsConfig *tls.Config, requestHeader http.Header, tlsConfig *tls.Config,
) *ws.DialOptions { ) *websocket.Dialer {
if requestHeader == nil && tlsConfig == nil { dialer := &websocket.Dialer{
return defaultConnectionOptions ReadBufferSize: 1024,
} WriteBufferSize: 1024,
return &ws.DialOptions{
HTTPHeader: requestHeader,
CompressionMode: ws.CompressionContextTakeover,
HTTPClient: &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig, TLSClientConfig: tlsConfig,
}, HandshakeTimeout: 10 * time.Second,
},
} }
// Headers are passed directly to DialContext, not set on Dialer
// The User-Agent header will be set when calling DialContext if not present
return dialer
} }

2
pkg/version/version

@ -1 +1 @@
v0.19.9 v0.20.0
Loading…
Cancel
Save