8 changed files with 292 additions and 164 deletions
@ -0,0 +1,78 @@ |
|||||||
|
package app |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"time" |
||||||
|
|
||||||
|
"lol.mleku.dev/chk" |
||||||
|
"lol.mleku.dev/log" |
||||||
|
"next.orly.dev/pkg/acl" |
||||||
|
"next.orly.dev/pkg/encoders/envelopes/authenvelope" |
||||||
|
"next.orly.dev/pkg/encoders/envelopes/countenvelope" |
||||||
|
"next.orly.dev/pkg/utils/normalize" |
||||||
|
) |
||||||
|
|
||||||
|
// HandleCount processes a COUNT envelope by parsing the request, verifying
|
||||||
|
// permissions, invoking the database CountEvents for each provided filter, and
|
||||||
|
// responding with a COUNT response containing the aggregate count.
|
||||||
|
func (l *Listener) HandleCount(msg []byte) (err error) { |
||||||
|
log.D.F("HandleCount: START processing from %s", l.remote) |
||||||
|
|
||||||
|
// Parse the COUNT request
|
||||||
|
env := countenvelope.New() |
||||||
|
if _, err = env.Unmarshal(msg); chk.E(err) { |
||||||
|
return normalize.Error.Errorf(err.Error()) |
||||||
|
} |
||||||
|
log.D.C(func() string { return fmt.Sprintf("COUNT sub=%s filters=%d", env.Subscription, len(env.Filters)) }) |
||||||
|
|
||||||
|
// If ACL is active, send a challenge (same as REQ path)
|
||||||
|
if acl.Registry.Active.Load() != "none" { |
||||||
|
if err = authenvelope.NewChallengeWith(l.challenge.Load()).Write(l); chk.E(err) { |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Check read permissions
|
||||||
|
accessLevel := acl.Registry.GetAccessLevel(l.authedPubkey.Load(), l.remote) |
||||||
|
switch accessLevel { |
||||||
|
case "none": |
||||||
|
return errors.New("auth required: user not authed or has no read access") |
||||||
|
default: |
||||||
|
// allowed to read
|
||||||
|
} |
||||||
|
|
||||||
|
// Use a bounded context for counting
|
||||||
|
ctx, cancel := context.WithTimeout(l.ctx, 30*time.Second) |
||||||
|
defer cancel() |
||||||
|
|
||||||
|
// Aggregate count across all provided filters
|
||||||
|
var total int |
||||||
|
var approx bool // database returns false per implementation
|
||||||
|
for _, f := range env.Filters { |
||||||
|
if f == nil { |
||||||
|
continue |
||||||
|
} |
||||||
|
var cnt int |
||||||
|
var a bool |
||||||
|
cnt, a, err = l.D.CountEvents(ctx, f) |
||||||
|
if chk.E(err) { |
||||||
|
return |
||||||
|
} |
||||||
|
total += cnt |
||||||
|
approx = approx || a |
||||||
|
} |
||||||
|
|
||||||
|
// Build and send COUNT response
|
||||||
|
var res *countenvelope.Response |
||||||
|
if res, err = countenvelope.NewResponseFrom(env.Subscription, total, approx); chk.E(err) { |
||||||
|
return |
||||||
|
} |
||||||
|
if err = res.Write(l); chk.E(err) { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
log.D.F("HandleCount: COMPLETED processing from %s count=%d approx=%v", l.remote, total, approx) |
||||||
|
return nil |
||||||
|
} |
||||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -0,0 +1,44 @@ |
|||||||
|
package database |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
|
||||||
|
"next.orly.dev/pkg/encoders/filter" |
||||||
|
) |
||||||
|
|
||||||
|
// CountEvents mirrors the initial selection logic of QueryEvents but stops
|
||||||
|
// once we have identified candidate event serials (id/pk/ts). It returns the
|
||||||
|
// count of those serials. The `approx` flag is always false as requested.
|
||||||
|
func (d *D) CountEvents(c context.Context, f *filter.F) ( |
||||||
|
count int, approx bool, err error, |
||||||
|
) { |
||||||
|
approx = false |
||||||
|
if f == nil { |
||||||
|
return 0, false, nil |
||||||
|
} |
||||||
|
|
||||||
|
// If explicit Ids are provided, count how many of them resolve to serials.
|
||||||
|
if f.Ids != nil && f.Ids.Len() > 0 { |
||||||
|
var serials map[string]interface{} |
||||||
|
// Use type inference without importing extra packages by discarding the
|
||||||
|
// concrete value type via a two-step assignment.
|
||||||
|
if tmp, idErr := d.GetSerialsByIds(f.Ids); idErr != nil { |
||||||
|
return 0, false, idErr |
||||||
|
} else { |
||||||
|
// Reassign to a map with empty interface values to avoid referencing
|
||||||
|
// the concrete Uint40 type here.
|
||||||
|
serials = make(map[string]interface{}, len(tmp)) |
||||||
|
for k := range tmp { |
||||||
|
serials[k] = struct{}{} |
||||||
|
} |
||||||
|
} |
||||||
|
return len(serials), false, nil |
||||||
|
} |
||||||
|
|
||||||
|
// Otherwise, query for candidate Id/Pubkey/Timestamp triplets and count them.
|
||||||
|
if idPkTs, qErr := d.QueryForIds(c, f); qErr != nil { |
||||||
|
return 0, false, qErr |
||||||
|
} else { |
||||||
|
return len(idPkTs), false, nil |
||||||
|
} |
||||||
|
} |
||||||
Loading…
Reference in new issue