Browse Source

Update WebSocket implementation to use Gorilla WebSocket library

- Replaced the existing `github.com/coder/websocket` package with `github.com/gorilla/websocket` for improved functionality and compatibility.
- Adjusted WebSocket connection handling, including message reading and writing, to align with the new library's API.
- Enhanced error handling and logging for WebSocket operations.
- Bumped version to v0.20.0 to reflect the changes made.
main
mleku 2 months ago
parent
commit
88ebf6eccc
No known key found for this signature in database
  1. 119
      app/handle-websocket.go
  2. 13
      app/listener.go
  3. 17
      app/publisher.go
  4. 2
      go.mod
  5. 4
      go.sum
  6. 47
      pkg/acl/follows.go
  7. 67
      pkg/protocol/ws/connection.go
  8. 33
      pkg/protocol/ws/connection_options.go
  9. 2
      pkg/version/version

119
app/handle-websocket.go

@ -7,7 +7,7 @@ import ( @@ -7,7 +7,7 @@ import (
"strings"
"time"
"github.com/coder/websocket"
"github.com/gorilla/websocket"
"lol.mleku.dev/chk"
"lol.mleku.dev/log"
"next.orly.dev/pkg/encoders/envelopes/authenvelope"
@ -24,21 +24,16 @@ const ( @@ -24,21 +24,16 @@ const (
// ClientMessageSizeLimit is the maximum message size that clients can handle
// This is set to 100MB to allow large messages
ClientMessageSizeLimit = 100 * 1024 * 1024 // 100MB
// CloseMessage denotes a close control message. The optional message
// payload contains a numeric code and text. Use the FormatCloseMessage
// function to format a close message payload.
CloseMessage = 8
// PingMessage denotes a ping control message. The optional message payload
// is UTF-8 encoded text.
PingMessage = 9
// PongMessage denotes a pong control message. The optional message payload
// is UTF-8 encoded text.
PongMessage = 10
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true // Allow all origins for proxy compatibility
},
}
func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
remote := GetRemoteFromReq(r)
@ -62,16 +57,12 @@ whitelist: @@ -62,16 +57,12 @@ whitelist:
defer cancel()
var err error
var conn *websocket.Conn
// Configure WebSocket accept options for proxy compatibility
acceptOptions := &websocket.AcceptOptions{
OriginPatterns: []string{"*"}, // Allow all origins for proxy compatibility
// Don't check origin when behind a proxy - let the proxy handle it
InsecureSkipVerify: true,
// Try to set a higher compression threshold to allow larger messages
CompressionMode: websocket.CompressionDisabled,
}
if conn, err = websocket.Accept(w, r, acceptOptions); chk.E(err) {
// Configure upgrader for this connection
upgrader.ReadBufferSize = int(DefaultMaxMessageSize)
upgrader.WriteBufferSize = int(DefaultMaxMessageSize)
if conn, err = upgrader.Upgrade(w, r, nil); chk.E(err) {
log.E.F("websocket accept failed from %s: %v", remote, err)
return
}
@ -80,7 +71,7 @@ whitelist: @@ -80,7 +71,7 @@ whitelist:
// Set read limit immediately after connection is established
conn.SetReadLimit(DefaultMaxMessageSize)
log.D.F("set read limit to %d bytes (%d MB) for %s", DefaultMaxMessageSize, DefaultMaxMessageSize/units.Mb, remote)
defer conn.CloseNow()
defer conn.Close()
listener := &Listener{
ctx: ctx,
Server: s,
@ -109,6 +100,16 @@ whitelist: @@ -109,6 +100,16 @@ whitelist:
log.D.F("AUTH challenge sent successfully to %s", remote)
}
ticker := time.NewTicker(DefaultPingWait)
// Set pong handler
conn.SetPongHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(DefaultPongWait))
return nil
})
// Set ping handler
conn.SetPingHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(DefaultPongWait))
return conn.WriteControl(websocket.PongMessage, []byte{}, time.Now().Add(DefaultWriteTimeout))
})
// Don't pass cancel to Pinger - it should not be able to cancel the connection context
go s.Pinger(ctx, conn, ticker)
defer func() {
@ -154,12 +155,19 @@ whitelist: @@ -154,12 +155,19 @@ whitelist:
return
}
var typ websocket.MessageType
var typ int
var msg []byte
log.T.F("waiting for message from %s", remote)
// Set read deadline for context cancellation
deadline := time.Now().Add(DefaultPongWait)
if ctx.Err() != nil {
return
}
conn.SetReadDeadline(deadline)
// Block waiting for message; rely on pings and context cancellation to detect dead peers
typ, msg, err = conn.Read(ctx)
typ, msg, err = conn.ReadMessage()
if err != nil {
// Check if the error is due to context cancellation
@ -180,50 +188,40 @@ whitelist: @@ -180,50 +188,40 @@ whitelist:
return
}
// Handle message too big errors specifically
if strings.Contains(err.Error(), "MessageTooBig") ||
if strings.Contains(err.Error(), "message too large") ||
strings.Contains(err.Error(), "read limited at") {
log.D.F("client %s hit message size limit: %v", remote, err)
// Don't log this as an error since it's a client-side limit
// Just close the connection gracefully
return
}
status := websocket.CloseStatus(err)
switch status {
case websocket.StatusNormalClosure,
websocket.StatusGoingAway,
websocket.StatusNoStatusRcvd,
websocket.StatusAbnormalClosure,
websocket.StatusProtocolError:
log.T.F(
"connection from %s closed with status: %v", remote, status,
)
case websocket.StatusMessageTooBig:
// Check for websocket close errors
if websocket.IsCloseError(err, websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived,
websocket.CloseAbnormalClosure,
websocket.CloseUnsupportedData,
websocket.CloseInvalidFramePayloadData) {
log.T.F("connection from %s closed: %v", remote, err)
} else if websocket.IsCloseError(err, websocket.CloseMessageTooBig) {
log.D.F("client %s sent message too big: %v", remote, err)
default:
} else {
log.E.F("unexpected close error from %s: %v", remote, err)
}
return
}
if typ == PingMessage {
if typ == websocket.PingMessage {
log.D.F("received PING from %s, sending PONG", remote)
// Create a write context with timeout for pong response
writeCtx, writeCancel := context.WithTimeout(
ctx, DefaultWriteTimeout,
)
deadline := time.Now().Add(DefaultWriteTimeout)
conn.SetWriteDeadline(deadline)
pongStart := time.Now()
if err = conn.Write(writeCtx, PongMessage, msg); chk.E(err) {
if err = conn.WriteControl(websocket.PongMessage, msg, deadline); chk.E(err) {
pongDuration := time.Since(pongStart)
log.E.F(
"failed to send PONG to %s after %v: %v", remote,
pongDuration, err,
)
if writeCtx.Err() != nil {
log.E.F(
"PONG write timeout to %s after %v (limit=%v)", remote,
pongDuration, DefaultWriteTimeout,
)
}
writeCancel()
return
}
pongDuration := time.Since(pongStart)
@ -231,7 +229,6 @@ whitelist: @@ -231,7 +229,6 @@ whitelist:
if pongDuration > time.Millisecond*50 {
log.D.F("SLOW PONG to %s: %v (>50ms)", remote, pongDuration)
}
writeCancel()
continue
}
// Log message size for debugging
@ -260,26 +257,18 @@ func (s *Server) Pinger( @@ -260,26 +257,18 @@ func (s *Server) Pinger(
pingCount++
log.D.F("sending PING #%d", pingCount)
// Create a write context with timeout for ping operation
pingCtx, pingCancel := context.WithTimeout(ctx, DefaultWriteTimeout)
// Set write deadline for ping operation
deadline := time.Now().Add(DefaultWriteTimeout)
conn.SetWriteDeadline(deadline)
pingStart := time.Now()
if err = conn.Ping(pingCtx); err != nil {
if err = conn.WriteControl(websocket.PingMessage, []byte{}, deadline); err != nil {
pingDuration := time.Since(pingStart)
log.E.F(
"PING #%d FAILED after %v: %v", pingCount, pingDuration,
err,
)
if pingCtx.Err() != nil {
log.E.F(
"PING #%d timeout after %v (limit=%v)", pingCount,
pingDuration, DefaultWriteTimeout,
)
}
chk.E(err)
pingCancel()
return
}
@ -289,8 +278,6 @@ func (s *Server) Pinger( @@ -289,8 +278,6 @@ func (s *Server) Pinger(
if pingDuration > time.Millisecond*100 {
log.D.F("SLOW PING #%d: %v (>100ms)", pingCount, pingDuration)
}
pingCancel()
case <-ctx.Done():
log.T.F("pinger context cancelled after %d pings", pingCount)
return

13
app/listener.go

@ -3,9 +3,10 @@ package app @@ -3,9 +3,10 @@ package app
import (
"context"
"net/http"
"strings"
"time"
"github.com/coder/websocket"
"github.com/gorilla/websocket"
"lol.mleku.dev/chk"
"lol.mleku.dev/log"
"next.orly.dev/pkg/acl"
@ -54,14 +55,12 @@ func (l *Listener) Write(p []byte) (n int, err error) { @@ -54,14 +55,12 @@ func (l *Listener) Write(p []byte) (n int, err error) {
// Use a separate context with timeout for writes to prevent race conditions
// where the main connection context gets cancelled while writing events
writeCtx, cancel := context.WithTimeout(
context.Background(), DefaultWriteTimeout,
)
defer cancel()
deadline := time.Now().Add(DefaultWriteTimeout)
l.conn.SetWriteDeadline(deadline)
// Attempt the write operation
writeStart := time.Now()
if err = l.conn.Write(writeCtx, websocket.MessageText, p); err != nil {
if err = l.conn.WriteMessage(websocket.TextMessage, p); err != nil {
writeDuration := time.Since(writeStart)
totalDuration := time.Since(start)
@ -72,7 +71,7 @@ func (l *Listener) Write(p []byte) (n int, err error) { @@ -72,7 +71,7 @@ func (l *Listener) Write(p []byte) (n int, err error) {
)
// Check if this is a context timeout
if writeCtx.Err() != nil {
if strings.Contains(err.Error(), "timeout") || strings.Contains(err.Error(), "deadline") {
log.E.F(
"ws->%s write timeout after %v (limit=%v)", l.remote,
writeDuration, DefaultWriteTimeout,

17
app/publisher.go

@ -3,10 +3,11 @@ package app @@ -3,10 +3,11 @@ package app
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/coder/websocket"
"github.com/gorilla/websocket"
"lol.mleku.dev/chk"
"lol.mleku.dev/log"
"next.orly.dev/pkg/acl"
@ -270,15 +271,11 @@ func (p *P) Deliver(ev *event.E) { @@ -270,15 +271,11 @@ func (p *P) Deliver(ev *event.E) {
// Use a separate context with timeout for writes to prevent race conditions
// where the publisher context gets cancelled while writing events
writeCtx, cancel := context.WithTimeout(
context.Background(), DefaultWriteTimeout,
)
defer cancel()
deadline := time.Now().Add(DefaultWriteTimeout)
d.w.SetWriteDeadline(deadline)
deliveryStart := time.Now()
if err = d.w.Write(
writeCtx, websocket.MessageText, msgData,
); err != nil {
if err = d.w.WriteMessage(websocket.TextMessage, msgData); err != nil {
deliveryDuration := time.Since(deliveryStart)
// Log detailed failure information
@ -286,7 +283,7 @@ func (p *P) Deliver(ev *event.E) { @@ -286,7 +283,7 @@ func (p *P) Deliver(ev *event.E) {
hex.Enc(ev.ID), d.sub.remote, d.id, deliveryDuration, err)
// Check for timeout specifically
if writeCtx.Err() != nil {
if strings.Contains(err.Error(), "timeout") || strings.Contains(err.Error(), "deadline") {
log.E.F("subscription delivery TIMEOUT: event=%s to=%s after %v (limit=%v)",
hex.Enc(ev.ID), d.sub.remote, deliveryDuration, DefaultWriteTimeout)
}
@ -296,7 +293,7 @@ func (p *P) Deliver(ev *event.E) { @@ -296,7 +293,7 @@ func (p *P) Deliver(ev *event.E) {
// On error, remove the subscriber connection safely
p.removeSubscriber(d.w)
_ = d.w.CloseNow()
_ = d.w.Close()
continue
}

2
go.mod

@ -4,9 +4,9 @@ go 1.25.0 @@ -4,9 +4,9 @@ go 1.25.0
require (
github.com/adrg/xdg v0.5.3
github.com/coder/websocket v1.8.14
github.com/davecgh/go-spew v1.1.1
github.com/dgraph-io/badger/v4 v4.8.0
github.com/gorilla/websocket v1.5.3
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0
github.com/klauspost/cpuid/v2 v2.3.0
github.com/pkg/profile v1.7.0

4
go.sum

@ -13,8 +13,6 @@ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5P @@ -13,8 +13,6 @@ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5P
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -45,6 +43,8 @@ github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8I @@ -45,6 +43,8 @@ github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8I
github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik=
github.com/google/pprof v0.0.0-20251007162407-5df77e3f7d1d h1:KJIErDwbSHjnp/SGzE5ed8Aol7JsKiI5X7yWKAtzhM0=
github.com/google/pprof v0.0.0-20251007162407-5df77e3f7d1d/go.mod h1:I6V7YzU0XDpsHqbsyrghnFZLO1gwK6NPTNvmetQIk9U=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w=
github.com/ianlancetaylor/demangle v0.0.0-20230524184225-eabc099b10ab/go.mod h1:gx7rwoVhcfuVKG5uya9Hs3Sxj7EIvldVofAWIUtGouw=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=

47
pkg/acl/follows.go

@ -10,7 +10,7 @@ import ( @@ -10,7 +10,7 @@ import (
"sync"
"time"
"github.com/coder/websocket"
"github.com/gorilla/websocket"
"lol.mleku.dev/chk"
"lol.mleku.dev/errorf"
"lol.mleku.dev/log"
@ -396,12 +396,15 @@ func (f *Follows) startEventSubscriptions(ctx context.Context) { @@ -396,12 +396,15 @@ func (f *Follows) startEventSubscriptions(ctx context.Context) {
headers.Set("Origin", "https://orly.dev")
// Use proper WebSocket dial options
dialOptions := &websocket.DialOptions{
HTTPHeader: headers,
dialer := websocket.Dialer{
HandshakeTimeout: 10 * time.Second,
}
c, _, err := websocket.Dial(connCtx, u, dialOptions)
c, resp, err := dialer.DialContext(connCtx, u, headers)
cancel()
if resp != nil {
resp.Body.Close()
}
if err != nil {
log.W.F("follows syncer: dial %s failed: %v", u, err)
@ -480,13 +483,12 @@ func (f *Follows) startEventSubscriptions(ctx context.Context) { @@ -480,13 +483,12 @@ func (f *Follows) startEventSubscriptions(ctx context.Context) {
req := reqenvelope.NewFrom([]byte(subID), ff)
reqBytes := req.Marshal(nil)
log.T.F("follows syncer: outbound REQ to %s: %s", u, string(reqBytes))
if err = c.Write(
ctx, websocket.MessageText, reqBytes,
); chk.E(err) {
c.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err = c.WriteMessage(websocket.TextMessage, reqBytes); chk.E(err) {
log.W.F(
"follows syncer: failed to send event REQ to %s: %v", u, err,
)
_ = c.Close(websocket.StatusInternalError, "write failed")
_ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "write failed"), time.Now().Add(time.Second))
continue
}
log.T.F(
@ -501,11 +503,12 @@ func (f *Follows) startEventSubscriptions(ctx context.Context) { @@ -501,11 +503,12 @@ func (f *Follows) startEventSubscriptions(ctx context.Context) {
for {
select {
case <-ctx.Done():
_ = c.Close(websocket.StatusNormalClosure, "ctx done")
_ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "ctx done"), time.Now().Add(time.Second))
return
case <-keepaliveTicker.C:
// Send ping to keep connection alive
if err := c.Ping(ctx); err != nil {
c.SetWriteDeadline(time.Now().Add(5 * time.Second))
if err := c.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second)); err != nil {
log.T.F("follows syncer: ping failed for %s: %v", u, err)
break readLoop
}
@ -513,11 +516,10 @@ func (f *Follows) startEventSubscriptions(ctx context.Context) { @@ -513,11 +516,10 @@ func (f *Follows) startEventSubscriptions(ctx context.Context) {
continue
default:
// Set a read timeout to avoid hanging
readCtx, readCancel := context.WithTimeout(ctx, 60*time.Second)
_, data, err := c.Read(readCtx)
readCancel()
c.SetReadDeadline(time.Now().Add(60 * time.Second))
_, data, err := c.ReadMessage()
if err != nil {
_ = c.Close(websocket.StatusNormalClosure, "read err")
_ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "read err"), time.Now().Add(time.Second))
break readLoop
}
label, rem, err := envelopes.Identify(data)
@ -714,16 +716,19 @@ func (f *Follows) fetchFollowListsFromRelay(relayURL string, authors [][]byte) { @@ -714,16 +716,19 @@ func (f *Follows) fetchFollowListsFromRelay(relayURL string, authors [][]byte) {
headers.Set("Origin", "https://orly.dev")
// Use proper WebSocket dial options
dialOptions := &websocket.DialOptions{
HTTPHeader: headers,
dialer := websocket.Dialer{
HandshakeTimeout: 10 * time.Second,
}
c, _, err := websocket.Dial(ctx, relayURL, dialOptions)
c, resp, err := dialer.DialContext(ctx, relayURL, headers)
if resp != nil {
resp.Body.Close()
}
if err != nil {
log.W.F("follows syncer: failed to connect to %s for follow list fetch: %v", relayURL, err)
return
}
defer c.Close(websocket.StatusNormalClosure, "follow list fetch complete")
defer c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "follow list fetch complete"), time.Now().Add(time.Second))
log.I.F("follows syncer: fetching follow lists from relay %s", relayURL)
@ -746,7 +751,8 @@ func (f *Follows) fetchFollowListsFromRelay(relayURL string, authors [][]byte) { @@ -746,7 +751,8 @@ func (f *Follows) fetchFollowListsFromRelay(relayURL string, authors [][]byte) {
req := reqenvelope.NewFrom([]byte(subID), ff)
reqBytes := req.Marshal(nil)
log.T.F("follows syncer: outbound REQ to %s: %s", relayURL, string(reqBytes))
if err = c.Write(ctx, websocket.MessageText, reqBytes); chk.E(err) {
c.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err = c.WriteMessage(websocket.TextMessage, reqBytes); chk.E(err) {
log.W.F("follows syncer: failed to send follow list REQ to %s: %v", relayURL, err)
return
}
@ -769,7 +775,8 @@ func (f *Follows) fetchFollowListsFromRelay(relayURL string, authors [][]byte) { @@ -769,7 +775,8 @@ func (f *Follows) fetchFollowListsFromRelay(relayURL string, authors [][]byte) {
default:
}
_, data, err := c.Read(ctx)
c.SetReadDeadline(time.Now().Add(10 * time.Second))
_, data, err := c.ReadMessage()
if err != nil {
log.T.F("follows syncer: error reading events from %s: %v", relayURL, err)
goto processEvents

67
pkg/protocol/ws/connection.go

@ -3,21 +3,19 @@ package ws @@ -3,21 +3,19 @@ package ws
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net/http"
"time"
"github.com/gorilla/websocket"
"lol.mleku.dev/errorf"
"next.orly.dev/pkg/utils/units"
ws "github.com/coder/websocket"
)
// Connection represents a websocket connection to a Nostr relay.
type Connection struct {
conn *ws.Conn
conn *websocket.Conn
}
// NewConnection creates a new websocket connection to a Nostr relay.
@ -25,10 +23,23 @@ func NewConnection( @@ -25,10 +23,23 @@ func NewConnection(
ctx context.Context, url string, reqHeader http.Header,
tlsConfig *tls.Config,
) (c *Connection, err error) {
var conn *ws.Conn
if conn, _, err = ws.Dial(
ctx, url, getConnectionOptions(reqHeader, tlsConfig),
); err != nil {
var conn *websocket.Conn
var resp *http.Response
dialer := getConnectionOptions(reqHeader, tlsConfig)
// Prepare headers with default User-Agent if not present
headers := reqHeader
if headers == nil {
headers = make(http.Header)
}
if headers.Get("User-Agent") == "" {
headers.Set("User-Agent", "github.com/nbd-wtf/go-nostr")
}
if conn, resp, err = dialer.DialContext(ctx, url, headers); err != nil {
if resp != nil {
resp.Body.Close()
}
return
}
conn.SetReadLimit(33 * units.Mb)
@ -41,7 +52,14 @@ func NewConnection( @@ -41,7 +52,14 @@ func NewConnection(
func (c *Connection) WriteMessage(
ctx context.Context, data []byte,
) (err error) {
if err = c.conn.Write(ctx, ws.MessageText, data); err != nil {
deadline := time.Now().Add(10 * time.Second)
if ctx != nil {
if d, ok := ctx.Deadline(); ok {
deadline = d
}
}
c.conn.SetWriteDeadline(deadline)
if err = c.conn.WriteMessage(websocket.TextMessage, data); err != nil {
err = errorf.E("failed to write message: %w", err)
return
}
@ -52,11 +70,22 @@ func (c *Connection) WriteMessage( @@ -52,11 +70,22 @@ func (c *Connection) WriteMessage(
func (c *Connection) ReadMessage(
ctx context.Context, buf io.Writer,
) (err error) {
var reader io.Reader
if _, reader, err = c.conn.Reader(ctx); err != nil {
deadline := time.Now().Add(60 * time.Second)
if ctx != nil {
if d, ok := ctx.Deadline(); ok {
deadline = d
}
}
c.conn.SetReadDeadline(deadline)
messageType, reader, err := c.conn.NextReader()
if err != nil {
err = fmt.Errorf("failed to get reader: %w", err)
return
}
if messageType != websocket.TextMessage && messageType != websocket.BinaryMessage {
err = fmt.Errorf("unexpected message type: %d", messageType)
return
}
if _, err = io.Copy(buf, reader); err != nil {
err = fmt.Errorf("failed to read message: %w", err)
return
@ -66,14 +95,18 @@ func (c *Connection) ReadMessage( @@ -66,14 +95,18 @@ func (c *Connection) ReadMessage(
// Close closes the websocket connection.
func (c *Connection) Close() error {
return c.conn.Close(ws.StatusNormalClosure, "")
c.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second))
return c.conn.Close()
}
// Ping sends a ping message to the websocket connection.
func (c *Connection) Ping(ctx context.Context) error {
ctx, cancel := context.WithTimeoutCause(
ctx, time.Millisecond*800, errors.New("ping took too long"),
)
defer cancel()
return c.conn.Ping(ctx)
deadline := time.Now().Add(800 * time.Millisecond)
if ctx != nil {
if d, ok := ctx.Deadline(); ok {
deadline = d
}
}
c.conn.SetWriteDeadline(deadline)
return c.conn.WriteControl(websocket.PingMessage, []byte{}, deadline)
}

33
pkg/protocol/ws/connection_options.go

@ -5,32 +5,21 @@ package ws @@ -5,32 +5,21 @@ package ws
import (
"crypto/tls"
"net/http"
"net/textproto"
"time"
ws "github.com/coder/websocket"
"github.com/gorilla/websocket"
)
var defaultConnectionOptions = &ws.DialOptions{
CompressionMode: ws.CompressionContextTakeover,
HTTPHeader: http.Header{
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"github.com/nbd-wtf/go-nostr"},
},
}
func getConnectionOptions(
requestHeader http.Header, tlsConfig *tls.Config,
) *ws.DialOptions {
if requestHeader == nil && tlsConfig == nil {
return defaultConnectionOptions
}
return &ws.DialOptions{
HTTPHeader: requestHeader,
CompressionMode: ws.CompressionContextTakeover,
HTTPClient: &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
},
) *websocket.Dialer {
dialer := &websocket.Dialer{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
TLSClientConfig: tlsConfig,
HandshakeTimeout: 10 * time.Second,
}
// Headers are passed directly to DialContext, not set on Dialer
// The User-Agent header will be set when calling DialContext if not present
return dialer
}

2
pkg/version/version

@ -1 +1 @@ @@ -1 +1 @@
v0.19.9
v0.20.0
Loading…
Cancel
Save