diff --git a/pkg/database/compact_event.go b/pkg/database/compact_event.go index f5949f0..d2eae63 100644 --- a/pkg/database/compact_event.go +++ b/pkg/database/compact_event.go @@ -52,6 +52,20 @@ const ( TagElementPubkeySerial = 0x01 // Pubkey serial reference (5 bytes) TagElementEventSerial = 0x02 // Event ID serial reference (5 bytes) TagElementEventIdFull = 0x03 // Full event ID (32 bytes) - for unknown refs + + // Sanity limits to prevent OOM from corrupt data + MaxTagsPerEvent = 10000 // Maximum number of tags in an event + MaxTagElements = 100 // Maximum elements in a single tag + MaxContentLength = 10 << 20 // 10MB max content + MaxTagElementLength = 1 << 20 // 1MB max for a single tag element +) + +var ( + ErrTooManyTags = errors.New("corrupt data: too many tags") + ErrTooManyTagElems = errors.New("corrupt data: too many tag elements") + ErrContentTooLarge = errors.New("corrupt data: content too large") + ErrTagElementTooLong = errors.New("corrupt data: tag element too long") + ErrUnknownTagElemType = errors.New("corrupt data: unknown tag element type") ) // SerialResolver is an interface for resolving serials during compact encoding/decoding. @@ -287,12 +301,15 @@ func UnmarshalCompactEvent(data []byte, eventId []byte, resolver SerialResolver) if nTags, err = varint.Decode(r); chk.E(err) { return nil, err } + if nTags > MaxTagsPerEvent { + return nil, ErrTooManyTags // Don't log - caller handles gracefully + } if nTags > 0 { ev.Tags = tag.NewSWithCap(int(nTags)) for i := uint64(0); i < nTags; i++ { var t *tag.T - if t, err = decodeCompactTag(r, resolver); chk.E(err) { - return nil, err + if t, err = decodeCompactTag(r, resolver); err != nil { + return nil, err // Don't log corruption errors } *ev.Tags = append(*ev.Tags, t) } @@ -303,6 +320,9 @@ func UnmarshalCompactEvent(data []byte, eventId []byte, resolver SerialResolver) if contentLen, err = varint.Decode(r); chk.E(err) { return nil, err } + if contentLen > MaxContentLength { + return nil, ErrContentTooLarge + } ev.Content = make([]byte, contentLen) if _, err = io.ReadFull(r, ev.Content); chk.E(err) { return nil, err @@ -320,16 +340,19 @@ func UnmarshalCompactEvent(data []byte, eventId []byte, resolver SerialResolver) // decodeCompactTag decodes a single tag from compact format. func decodeCompactTag(r io.Reader, resolver SerialResolver) (t *tag.T, err error) { var nElems uint64 - if nElems, err = varint.Decode(r); chk.E(err) { + if nElems, err = varint.Decode(r); err != nil { return nil, err } + if nElems > MaxTagElements { + return nil, ErrTooManyTagElems + } t = tag.NewWithCap(int(nElems)) for i := uint64(0); i < nElems; i++ { var elem []byte - if elem, err = decodeTagElement(r, resolver); chk.E(err) { - return nil, err + if elem, err = decodeTagElement(r, resolver); err != nil { + return nil, err // Don't log corruption errors } t.T = append(t.T, elem) } @@ -350,9 +373,12 @@ func decodeTagElement(r io.Reader, resolver SerialResolver) (elem []byte, err er case TagElementRaw: // Raw bytes: varint length + data var length uint64 - if length, err = varint.Decode(r); chk.E(err) { + if length, err = varint.Decode(r); err != nil { return nil, err } + if length > MaxTagElementLength { + return nil, ErrTagElementTooLong + } elem = make([]byte, length) if _, err = io.ReadFull(r, elem); err != nil { return nil, err @@ -402,7 +428,7 @@ func decodeTagElement(r io.Reader, resolver SerialResolver) (elem []byte, err er return elem, nil default: - return nil, errors.New("unknown tag element type flag") + return nil, ErrUnknownTagElemType } }