40 changed files with 9078 additions and 46 deletions
@ -0,0 +1,520 @@
@@ -0,0 +1,520 @@
|
||||
package main |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"sort" |
||||
"sync" |
||||
"time" |
||||
|
||||
"git.mleku.dev/mleku/nostr/encoders/event" |
||||
"git.mleku.dev/mleku/nostr/encoders/filter" |
||||
"git.mleku.dev/mleku/nostr/encoders/hex" |
||||
"git.mleku.dev/mleku/nostr/encoders/kind" |
||||
"git.mleku.dev/mleku/nostr/encoders/tag" |
||||
"git.mleku.dev/mleku/nostr/interfaces/signer/p8k" |
||||
"lukechampine.com/frand" |
||||
"next.orly.dev/pkg/database" |
||||
) |
||||
|
||||
const ( |
||||
// GraphBenchNumPubkeys is the number of pubkeys to generate for graph benchmark
|
||||
GraphBenchNumPubkeys = 100000 |
||||
// GraphBenchMinFollows is the minimum number of follows per pubkey
|
||||
GraphBenchMinFollows = 1 |
||||
// GraphBenchMaxFollows is the maximum number of follows per pubkey
|
||||
GraphBenchMaxFollows = 1000 |
||||
// GraphBenchSeed is the deterministic seed for frand PRNG (fits in uint64)
|
||||
GraphBenchSeed uint64 = 0x4E6F737472 // "Nostr" in hex
|
||||
// GraphBenchTraversalDepth is the depth of graph traversal (3 = third degree)
|
||||
GraphBenchTraversalDepth = 3 |
||||
) |
||||
|
||||
// GraphTraversalBenchmark benchmarks graph traversal using NIP-01 style queries
|
||||
type GraphTraversalBenchmark struct { |
||||
config *BenchmarkConfig |
||||
db *database.D |
||||
results []*BenchmarkResult |
||||
mu sync.RWMutex |
||||
|
||||
// Cached data for the benchmark
|
||||
pubkeys [][]byte // 100k pubkeys as 32-byte arrays
|
||||
signers []*p8k.Signer // signers for each pubkey
|
||||
follows [][]int // follows[i] = list of indices that pubkey[i] follows
|
||||
rng *frand.RNG // deterministic PRNG
|
||||
} |
||||
|
||||
// NewGraphTraversalBenchmark creates a new graph traversal benchmark
|
||||
func NewGraphTraversalBenchmark(config *BenchmarkConfig, db *database.D) *GraphTraversalBenchmark { |
||||
return &GraphTraversalBenchmark{ |
||||
config: config, |
||||
db: db, |
||||
results: make([]*BenchmarkResult, 0), |
||||
rng: frand.NewCustom(make([]byte, 32), 1024, 12), // ChaCha12 with seed buffer
|
||||
} |
||||
} |
||||
|
||||
// initializeDeterministicRNG initializes the PRNG with deterministic seed
|
||||
func (g *GraphTraversalBenchmark) initializeDeterministicRNG() { |
||||
// Create seed buffer from GraphBenchSeed (uint64 spread across 8 bytes)
|
||||
seedBuf := make([]byte, 32) |
||||
seed := GraphBenchSeed |
||||
seedBuf[0] = byte(seed >> 56) |
||||
seedBuf[1] = byte(seed >> 48) |
||||
seedBuf[2] = byte(seed >> 40) |
||||
seedBuf[3] = byte(seed >> 32) |
||||
seedBuf[4] = byte(seed >> 24) |
||||
seedBuf[5] = byte(seed >> 16) |
||||
seedBuf[6] = byte(seed >> 8) |
||||
seedBuf[7] = byte(seed) |
||||
g.rng = frand.NewCustom(seedBuf, 1024, 12) |
||||
} |
||||
|
||||
// generatePubkeys generates deterministic pubkeys using frand
|
||||
func (g *GraphTraversalBenchmark) generatePubkeys() { |
||||
fmt.Printf("Generating %d deterministic pubkeys...\n", GraphBenchNumPubkeys) |
||||
start := time.Now() |
||||
|
||||
g.initializeDeterministicRNG() |
||||
g.pubkeys = make([][]byte, GraphBenchNumPubkeys) |
||||
g.signers = make([]*p8k.Signer, GraphBenchNumPubkeys) |
||||
|
||||
for i := 0; i < GraphBenchNumPubkeys; i++ { |
||||
// Generate deterministic 32-byte secret key from PRNG
|
||||
secretKey := make([]byte, 32) |
||||
g.rng.Read(secretKey) |
||||
|
||||
// Create signer from secret key
|
||||
signer := p8k.MustNew() |
||||
if err := signer.InitSec(secretKey); err != nil { |
||||
panic(fmt.Sprintf("failed to init signer %d: %v", i, err)) |
||||
} |
||||
|
||||
g.signers[i] = signer |
||||
g.pubkeys[i] = make([]byte, 32) |
||||
copy(g.pubkeys[i], signer.Pub()) |
||||
|
||||
if (i+1)%10000 == 0 { |
||||
fmt.Printf(" Generated %d/%d pubkeys...\n", i+1, GraphBenchNumPubkeys) |
||||
} |
||||
} |
||||
|
||||
fmt.Printf("Generated %d pubkeys in %v\n", GraphBenchNumPubkeys, time.Since(start)) |
||||
} |
||||
|
||||
// generateFollowGraph generates the random follow graph with deterministic PRNG
|
||||
func (g *GraphTraversalBenchmark) generateFollowGraph() { |
||||
fmt.Printf("Generating follow graph (1-%d follows per pubkey)...\n", GraphBenchMaxFollows) |
||||
start := time.Now() |
||||
|
||||
// Reset RNG to ensure deterministic follow graph
|
||||
g.initializeDeterministicRNG() |
||||
// Skip the bytes used for pubkey generation
|
||||
skipBuf := make([]byte, 32*GraphBenchNumPubkeys) |
||||
g.rng.Read(skipBuf) |
||||
|
||||
g.follows = make([][]int, GraphBenchNumPubkeys) |
||||
|
||||
totalFollows := 0 |
||||
for i := 0; i < GraphBenchNumPubkeys; i++ { |
||||
// Determine number of follows for this pubkey (1 to 1000)
|
||||
numFollows := int(g.rng.Uint64n(uint64(GraphBenchMaxFollows-GraphBenchMinFollows+1))) + GraphBenchMinFollows |
||||
|
||||
// Generate random follow indices (excluding self)
|
||||
followSet := make(map[int]struct{}) |
||||
for len(followSet) < numFollows { |
||||
followIdx := int(g.rng.Uint64n(uint64(GraphBenchNumPubkeys))) |
||||
if followIdx != i { |
||||
followSet[followIdx] = struct{}{} |
||||
} |
||||
} |
||||
|
||||
// Convert to slice
|
||||
g.follows[i] = make([]int, 0, numFollows) |
||||
for idx := range followSet { |
||||
g.follows[i] = append(g.follows[i], idx) |
||||
} |
||||
totalFollows += numFollows |
||||
|
||||
if (i+1)%10000 == 0 { |
||||
fmt.Printf(" Generated follow lists for %d/%d pubkeys...\n", i+1, GraphBenchNumPubkeys) |
||||
} |
||||
} |
||||
|
||||
avgFollows := float64(totalFollows) / float64(GraphBenchNumPubkeys) |
||||
fmt.Printf("Generated follow graph in %v (avg %.1f follows/pubkey, total %d follows)\n", |
||||
time.Since(start), avgFollows, totalFollows) |
||||
} |
||||
|
||||
// createFollowListEvents creates kind 3 follow list events in the database
|
||||
func (g *GraphTraversalBenchmark) createFollowListEvents() { |
||||
fmt.Println("Creating follow list events in database...") |
||||
start := time.Now() |
||||
|
||||
ctx := context.Background() |
||||
baseTime := time.Now().Unix() |
||||
|
||||
var mu sync.Mutex |
||||
var wg sync.WaitGroup |
||||
var successCount, errorCount int64 |
||||
latencies := make([]time.Duration, 0, GraphBenchNumPubkeys) |
||||
|
||||
// Use worker pool for parallel event creation
|
||||
numWorkers := g.config.ConcurrentWorkers |
||||
if numWorkers < 1 { |
||||
numWorkers = 4 |
||||
} |
||||
|
||||
workChan := make(chan int, numWorkers*2) |
||||
|
||||
// Rate limiter: cap at 20,000 events/second
|
||||
perWorkerRate := 20000.0 / float64(numWorkers) |
||||
|
||||
for w := 0; w < numWorkers; w++ { |
||||
wg.Add(1) |
||||
go func() { |
||||
defer wg.Done() |
||||
|
||||
workerLimiter := NewRateLimiter(perWorkerRate) |
||||
|
||||
for i := range workChan { |
||||
workerLimiter.Wait() |
||||
|
||||
ev := event.New() |
||||
ev.Kind = kind.FollowList.K |
||||
ev.CreatedAt = baseTime + int64(i) |
||||
ev.Content = []byte("") |
||||
ev.Tags = tag.NewS() |
||||
|
||||
// Add p tags for all follows
|
||||
for _, followIdx := range g.follows[i] { |
||||
pubkeyHex := hex.Enc(g.pubkeys[followIdx]) |
||||
ev.Tags.Append(tag.NewFromAny("p", pubkeyHex)) |
||||
} |
||||
|
||||
// Sign the event
|
||||
if err := ev.Sign(g.signers[i]); err != nil { |
||||
mu.Lock() |
||||
errorCount++ |
||||
mu.Unlock() |
||||
ev.Free() |
||||
continue |
||||
} |
||||
|
||||
// Save to database
|
||||
eventStart := time.Now() |
||||
_, err := g.db.SaveEvent(ctx, ev) |
||||
latency := time.Since(eventStart) |
||||
|
||||
mu.Lock() |
||||
if err != nil { |
||||
errorCount++ |
||||
} else { |
||||
successCount++ |
||||
latencies = append(latencies, latency) |
||||
} |
||||
mu.Unlock() |
||||
|
||||
ev.Free() |
||||
} |
||||
}() |
||||
} |
||||
|
||||
// Send work
|
||||
for i := 0; i < GraphBenchNumPubkeys; i++ { |
||||
workChan <- i |
||||
if (i+1)%10000 == 0 { |
||||
fmt.Printf(" Queued %d/%d follow list events...\n", i+1, GraphBenchNumPubkeys) |
||||
} |
||||
} |
||||
close(workChan) |
||||
wg.Wait() |
||||
|
||||
duration := time.Since(start) |
||||
eventsPerSec := float64(successCount) / duration.Seconds() |
||||
|
||||
// Calculate latency stats
|
||||
var avgLatency, p95Latency, p99Latency time.Duration |
||||
if len(latencies) > 0 { |
||||
sort.Slice(latencies, func(i, j int) bool { return latencies[i] < latencies[j] }) |
||||
avgLatency = calculateAvgLatency(latencies) |
||||
p95Latency = calculatePercentileLatency(latencies, 0.95) |
||||
p99Latency = calculatePercentileLatency(latencies, 0.99) |
||||
} |
||||
|
||||
fmt.Printf("Created %d follow list events in %v (%.2f events/sec, errors: %d)\n", |
||||
successCount, duration, eventsPerSec, errorCount) |
||||
fmt.Printf(" Avg latency: %v, P95: %v, P99: %v\n", avgLatency, p95Latency, p99Latency) |
||||
|
||||
// Record result for event creation phase
|
||||
result := &BenchmarkResult{ |
||||
TestName: "Graph Setup (Follow Lists)", |
||||
Duration: duration, |
||||
TotalEvents: int(successCount), |
||||
EventsPerSecond: eventsPerSec, |
||||
AvgLatency: avgLatency, |
||||
P95Latency: p95Latency, |
||||
P99Latency: p99Latency, |
||||
ConcurrentWorkers: numWorkers, |
||||
MemoryUsed: getMemUsage(), |
||||
SuccessRate: float64(successCount) / float64(GraphBenchNumPubkeys) * 100, |
||||
} |
||||
|
||||
g.mu.Lock() |
||||
g.results = append(g.results, result) |
||||
g.mu.Unlock() |
||||
} |
||||
|
||||
// runThirdDegreeTraversal runs the third-degree graph traversal benchmark
|
||||
func (g *GraphTraversalBenchmark) runThirdDegreeTraversal() { |
||||
fmt.Printf("\n=== Third-Degree Graph Traversal Benchmark ===\n") |
||||
fmt.Printf("Traversing 3 degrees of follows for each of %d pubkeys...\n", GraphBenchNumPubkeys) |
||||
|
||||
start := time.Now() |
||||
ctx := context.Background() |
||||
|
||||
var mu sync.Mutex |
||||
var wg sync.WaitGroup |
||||
var totalQueries int64 |
||||
var totalPubkeysFound int64 |
||||
queryLatencies := make([]time.Duration, 0, GraphBenchNumPubkeys*3) |
||||
traversalLatencies := make([]time.Duration, 0, GraphBenchNumPubkeys) |
||||
|
||||
// Sample a subset for detailed traversal (full 100k would take too long)
|
||||
sampleSize := 1000 |
||||
if sampleSize > GraphBenchNumPubkeys { |
||||
sampleSize = GraphBenchNumPubkeys |
||||
} |
||||
|
||||
// Deterministic sampling
|
||||
g.initializeDeterministicRNG() |
||||
sampleIndices := make([]int, sampleSize) |
||||
for i := 0; i < sampleSize; i++ { |
||||
sampleIndices[i] = int(g.rng.Uint64n(uint64(GraphBenchNumPubkeys))) |
||||
} |
||||
|
||||
fmt.Printf("Sampling %d pubkeys for traversal...\n", sampleSize) |
||||
|
||||
numWorkers := g.config.ConcurrentWorkers |
||||
if numWorkers < 1 { |
||||
numWorkers = 4 |
||||
} |
||||
|
||||
workChan := make(chan int, numWorkers*2) |
||||
|
||||
for w := 0; w < numWorkers; w++ { |
||||
wg.Add(1) |
||||
go func() { |
||||
defer wg.Done() |
||||
|
||||
for startIdx := range workChan { |
||||
traversalStart := time.Now() |
||||
foundPubkeys := make(map[string]struct{}) |
||||
|
||||
// Start with the initial pubkey
|
||||
currentLevel := [][]byte{g.pubkeys[startIdx]} |
||||
startPubkeyHex := hex.Enc(g.pubkeys[startIdx]) |
||||
foundPubkeys[startPubkeyHex] = struct{}{} |
||||
|
||||
// Traverse 3 degrees
|
||||
for depth := 0; depth < GraphBenchTraversalDepth; depth++ { |
||||
if len(currentLevel) == 0 { |
||||
break |
||||
} |
||||
|
||||
nextLevel := make([][]byte, 0) |
||||
|
||||
// Query follow lists for all pubkeys at current level
|
||||
// Batch queries for efficiency
|
||||
batchSize := 100 |
||||
for batchStart := 0; batchStart < len(currentLevel); batchStart += batchSize { |
||||
batchEnd := batchStart + batchSize |
||||
if batchEnd > len(currentLevel) { |
||||
batchEnd = len(currentLevel) |
||||
} |
||||
|
||||
batch := currentLevel[batchStart:batchEnd] |
||||
|
||||
// Build filter for kind 3 events from these pubkeys
|
||||
f := filter.New() |
||||
f.Kinds = kind.NewS(kind.FollowList) |
||||
f.Authors = tag.NewWithCap(len(batch)) |
||||
for _, pk := range batch { |
||||
// Authors.T expects raw byte slices (pubkeys)
|
||||
f.Authors.T = append(f.Authors.T, pk) |
||||
} |
||||
|
||||
queryStart := time.Now() |
||||
events, err := g.db.QueryEvents(ctx, f) |
||||
queryLatency := time.Since(queryStart) |
||||
|
||||
mu.Lock() |
||||
totalQueries++ |
||||
queryLatencies = append(queryLatencies, queryLatency) |
||||
mu.Unlock() |
||||
|
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
// Extract followed pubkeys from p tags
|
||||
for _, ev := range events { |
||||
for _, t := range *ev.Tags { |
||||
if len(t.T) >= 2 && string(t.T[0]) == "p" { |
||||
pubkeyHex := string(t.ValueHex()) |
||||
if _, exists := foundPubkeys[pubkeyHex]; !exists { |
||||
foundPubkeys[pubkeyHex] = struct{}{} |
||||
// Decode hex to bytes for next level
|
||||
if pkBytes, err := hex.Dec(pubkeyHex); err == nil { |
||||
nextLevel = append(nextLevel, pkBytes) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
ev.Free() |
||||
} |
||||
} |
||||
|
||||
currentLevel = nextLevel |
||||
} |
||||
|
||||
traversalLatency := time.Since(traversalStart) |
||||
|
||||
mu.Lock() |
||||
totalPubkeysFound += int64(len(foundPubkeys)) |
||||
traversalLatencies = append(traversalLatencies, traversalLatency) |
||||
mu.Unlock() |
||||
} |
||||
}() |
||||
} |
||||
|
||||
// Send work
|
||||
for _, idx := range sampleIndices { |
||||
workChan <- idx |
||||
} |
||||
close(workChan) |
||||
wg.Wait() |
||||
|
||||
duration := time.Since(start) |
||||
|
||||
// Calculate statistics
|
||||
var avgQueryLatency, p95QueryLatency, p99QueryLatency time.Duration |
||||
if len(queryLatencies) > 0 { |
||||
sort.Slice(queryLatencies, func(i, j int) bool { return queryLatencies[i] < queryLatencies[j] }) |
||||
avgQueryLatency = calculateAvgLatency(queryLatencies) |
||||
p95QueryLatency = calculatePercentileLatency(queryLatencies, 0.95) |
||||
p99QueryLatency = calculatePercentileLatency(queryLatencies, 0.99) |
||||
} |
||||
|
||||
var avgTraversalLatency, p95TraversalLatency, p99TraversalLatency time.Duration |
||||
if len(traversalLatencies) > 0 { |
||||
sort.Slice(traversalLatencies, func(i, j int) bool { return traversalLatencies[i] < traversalLatencies[j] }) |
||||
avgTraversalLatency = calculateAvgLatency(traversalLatencies) |
||||
p95TraversalLatency = calculatePercentileLatency(traversalLatencies, 0.95) |
||||
p99TraversalLatency = calculatePercentileLatency(traversalLatencies, 0.99) |
||||
} |
||||
|
||||
avgPubkeysPerTraversal := float64(totalPubkeysFound) / float64(sampleSize) |
||||
traversalsPerSec := float64(sampleSize) / duration.Seconds() |
||||
queriesPerSec := float64(totalQueries) / duration.Seconds() |
||||
|
||||
fmt.Printf("\n=== Graph Traversal Results ===\n") |
||||
fmt.Printf("Traversals completed: %d\n", sampleSize) |
||||
fmt.Printf("Total queries: %d (%.2f queries/sec)\n", totalQueries, queriesPerSec) |
||||
fmt.Printf("Avg pubkeys found per traversal: %.1f\n", avgPubkeysPerTraversal) |
||||
fmt.Printf("Total duration: %v\n", duration) |
||||
fmt.Printf("\nQuery Latencies:\n") |
||||
fmt.Printf(" Avg: %v, P95: %v, P99: %v\n", avgQueryLatency, p95QueryLatency, p99QueryLatency) |
||||
fmt.Printf("\nFull Traversal Latencies (3 degrees):\n") |
||||
fmt.Printf(" Avg: %v, P95: %v, P99: %v\n", avgTraversalLatency, p95TraversalLatency, p99TraversalLatency) |
||||
fmt.Printf("Traversals/sec: %.2f\n", traversalsPerSec) |
||||
|
||||
// Record result for traversal phase
|
||||
result := &BenchmarkResult{ |
||||
TestName: "Graph Traversal (3 Degrees)", |
||||
Duration: duration, |
||||
TotalEvents: int(totalQueries), |
||||
EventsPerSecond: traversalsPerSec, |
||||
AvgLatency: avgTraversalLatency, |
||||
P90Latency: calculatePercentileLatency(traversalLatencies, 0.90), |
||||
P95Latency: p95TraversalLatency, |
||||
P99Latency: p99TraversalLatency, |
||||
Bottom10Avg: calculateBottom10Avg(traversalLatencies), |
||||
ConcurrentWorkers: numWorkers, |
||||
MemoryUsed: getMemUsage(), |
||||
SuccessRate: 100.0, |
||||
} |
||||
|
||||
g.mu.Lock() |
||||
g.results = append(g.results, result) |
||||
g.mu.Unlock() |
||||
|
||||
// Also record query performance separately
|
||||
queryResult := &BenchmarkResult{ |
||||
TestName: "Graph Queries (Follow Lists)", |
||||
Duration: duration, |
||||
TotalEvents: int(totalQueries), |
||||
EventsPerSecond: queriesPerSec, |
||||
AvgLatency: avgQueryLatency, |
||||
P90Latency: calculatePercentileLatency(queryLatencies, 0.90), |
||||
P95Latency: p95QueryLatency, |
||||
P99Latency: p99QueryLatency, |
||||
Bottom10Avg: calculateBottom10Avg(queryLatencies), |
||||
ConcurrentWorkers: numWorkers, |
||||
MemoryUsed: getMemUsage(), |
||||
SuccessRate: 100.0, |
||||
} |
||||
|
||||
g.mu.Lock() |
||||
g.results = append(g.results, queryResult) |
||||
g.mu.Unlock() |
||||
} |
||||
|
||||
// RunSuite runs the complete graph traversal benchmark suite
|
||||
func (g *GraphTraversalBenchmark) RunSuite() { |
||||
fmt.Println("\n╔════════════════════════════════════════════════════════╗") |
||||
fmt.Println("║ GRAPH TRAVERSAL BENCHMARK (100k Pubkeys) ║") |
||||
fmt.Println("╚════════════════════════════════════════════════════════╝") |
||||
|
||||
// Step 1: Generate pubkeys
|
||||
g.generatePubkeys() |
||||
|
||||
// Step 2: Generate follow graph
|
||||
g.generateFollowGraph() |
||||
|
||||
// Step 3: Create follow list events in database
|
||||
g.createFollowListEvents() |
||||
|
||||
// Step 4: Run third-degree traversal benchmark
|
||||
g.runThirdDegreeTraversal() |
||||
|
||||
fmt.Printf("\n=== Graph Traversal Benchmark Complete ===\n\n") |
||||
} |
||||
|
||||
// GetResults returns the benchmark results
|
||||
func (g *GraphTraversalBenchmark) GetResults() []*BenchmarkResult { |
||||
g.mu.RLock() |
||||
defer g.mu.RUnlock() |
||||
return g.results |
||||
} |
||||
|
||||
// PrintResults prints the benchmark results
|
||||
func (g *GraphTraversalBenchmark) PrintResults() { |
||||
g.mu.RLock() |
||||
defer g.mu.RUnlock() |
||||
|
||||
for _, result := range g.results { |
||||
fmt.Printf("\nTest: %s\n", result.TestName) |
||||
fmt.Printf("Duration: %v\n", result.Duration) |
||||
fmt.Printf("Total Events/Queries: %d\n", result.TotalEvents) |
||||
fmt.Printf("Events/sec: %.2f\n", result.EventsPerSecond) |
||||
fmt.Printf("Success Rate: %.1f%%\n", result.SuccessRate) |
||||
fmt.Printf("Concurrent Workers: %d\n", result.ConcurrentWorkers) |
||||
fmt.Printf("Memory Used: %d MB\n", result.MemoryUsed/(1024*1024)) |
||||
fmt.Printf("Avg Latency: %v\n", result.AvgLatency) |
||||
fmt.Printf("P90 Latency: %v\n", result.P90Latency) |
||||
fmt.Printf("P95 Latency: %v\n", result.P95Latency) |
||||
fmt.Printf("P99 Latency: %v\n", result.P99Latency) |
||||
fmt.Printf("Bottom 10%% Avg Latency: %v\n", result.Bottom10Avg) |
||||
} |
||||
} |
||||
@ -0,0 +1,572 @@
@@ -0,0 +1,572 @@
|
||||
package main |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"sort" |
||||
"sync" |
||||
"time" |
||||
|
||||
"git.mleku.dev/mleku/nostr/encoders/envelopes/eventenvelope" |
||||
"git.mleku.dev/mleku/nostr/encoders/event" |
||||
"git.mleku.dev/mleku/nostr/encoders/filter" |
||||
"git.mleku.dev/mleku/nostr/encoders/hex" |
||||
"git.mleku.dev/mleku/nostr/encoders/kind" |
||||
"git.mleku.dev/mleku/nostr/encoders/tag" |
||||
"git.mleku.dev/mleku/nostr/interfaces/signer/p8k" |
||||
"git.mleku.dev/mleku/nostr/ws" |
||||
"lukechampine.com/frand" |
||||
) |
||||
|
||||
// NetworkGraphTraversalBenchmark benchmarks graph traversal using NIP-01 queries over WebSocket
|
||||
type NetworkGraphTraversalBenchmark struct { |
||||
relayURL string |
||||
relay *ws.Client |
||||
results []*BenchmarkResult |
||||
mu sync.RWMutex |
||||
workers int |
||||
|
||||
// Cached data for the benchmark
|
||||
pubkeys [][]byte // 100k pubkeys as 32-byte arrays
|
||||
signers []*p8k.Signer // signers for each pubkey
|
||||
follows [][]int // follows[i] = list of indices that pubkey[i] follows
|
||||
rng *frand.RNG // deterministic PRNG
|
||||
} |
||||
|
||||
// NewNetworkGraphTraversalBenchmark creates a new network graph traversal benchmark
|
||||
func NewNetworkGraphTraversalBenchmark(relayURL string, workers int) *NetworkGraphTraversalBenchmark { |
||||
return &NetworkGraphTraversalBenchmark{ |
||||
relayURL: relayURL, |
||||
workers: workers, |
||||
results: make([]*BenchmarkResult, 0), |
||||
rng: frand.NewCustom(make([]byte, 32), 1024, 12), // ChaCha12 with seed buffer
|
||||
} |
||||
} |
||||
|
||||
// Connect establishes WebSocket connection to the relay
|
||||
func (n *NetworkGraphTraversalBenchmark) Connect(ctx context.Context) error { |
||||
var err error |
||||
n.relay, err = ws.RelayConnect(ctx, n.relayURL) |
||||
if err != nil { |
||||
return fmt.Errorf("failed to connect to relay %s: %w", n.relayURL, err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// Close closes the relay connection
|
||||
func (n *NetworkGraphTraversalBenchmark) Close() { |
||||
if n.relay != nil { |
||||
n.relay.Close() |
||||
} |
||||
} |
||||
|
||||
// initializeDeterministicRNG initializes the PRNG with deterministic seed
|
||||
func (n *NetworkGraphTraversalBenchmark) initializeDeterministicRNG() { |
||||
// Create seed buffer from GraphBenchSeed (uint64 spread across 8 bytes)
|
||||
seedBuf := make([]byte, 32) |
||||
seed := GraphBenchSeed |
||||
seedBuf[0] = byte(seed >> 56) |
||||
seedBuf[1] = byte(seed >> 48) |
||||
seedBuf[2] = byte(seed >> 40) |
||||
seedBuf[3] = byte(seed >> 32) |
||||
seedBuf[4] = byte(seed >> 24) |
||||
seedBuf[5] = byte(seed >> 16) |
||||
seedBuf[6] = byte(seed >> 8) |
||||
seedBuf[7] = byte(seed) |
||||
n.rng = frand.NewCustom(seedBuf, 1024, 12) |
||||
} |
||||
|
||||
// generatePubkeys generates deterministic pubkeys using frand
|
||||
func (n *NetworkGraphTraversalBenchmark) generatePubkeys() { |
||||
fmt.Printf("Generating %d deterministic pubkeys...\n", GraphBenchNumPubkeys) |
||||
start := time.Now() |
||||
|
||||
n.initializeDeterministicRNG() |
||||
n.pubkeys = make([][]byte, GraphBenchNumPubkeys) |
||||
n.signers = make([]*p8k.Signer, GraphBenchNumPubkeys) |
||||
|
||||
for i := 0; i < GraphBenchNumPubkeys; i++ { |
||||
// Generate deterministic 32-byte secret key from PRNG
|
||||
secretKey := make([]byte, 32) |
||||
n.rng.Read(secretKey) |
||||
|
||||
// Create signer from secret key
|
||||
signer := p8k.MustNew() |
||||
if err := signer.InitSec(secretKey); err != nil { |
||||
panic(fmt.Sprintf("failed to init signer %d: %v", i, err)) |
||||
} |
||||
|
||||
n.signers[i] = signer |
||||
n.pubkeys[i] = make([]byte, 32) |
||||
copy(n.pubkeys[i], signer.Pub()) |
||||
|
||||
if (i+1)%10000 == 0 { |
||||
fmt.Printf(" Generated %d/%d pubkeys...\n", i+1, GraphBenchNumPubkeys) |
||||
} |
||||
} |
||||
|
||||
fmt.Printf("Generated %d pubkeys in %v\n", GraphBenchNumPubkeys, time.Since(start)) |
||||
} |
||||
|
||||
// generateFollowGraph generates the random follow graph with deterministic PRNG
|
||||
func (n *NetworkGraphTraversalBenchmark) generateFollowGraph() { |
||||
fmt.Printf("Generating follow graph (1-%d follows per pubkey)...\n", GraphBenchMaxFollows) |
||||
start := time.Now() |
||||
|
||||
// Reset RNG to ensure deterministic follow graph
|
||||
n.initializeDeterministicRNG() |
||||
// Skip the bytes used for pubkey generation
|
||||
skipBuf := make([]byte, 32*GraphBenchNumPubkeys) |
||||
n.rng.Read(skipBuf) |
||||
|
||||
n.follows = make([][]int, GraphBenchNumPubkeys) |
||||
|
||||
totalFollows := 0 |
||||
for i := 0; i < GraphBenchNumPubkeys; i++ { |
||||
// Determine number of follows for this pubkey (1 to 1000)
|
||||
numFollows := int(n.rng.Uint64n(uint64(GraphBenchMaxFollows-GraphBenchMinFollows+1))) + GraphBenchMinFollows |
||||
|
||||
// Generate random follow indices (excluding self)
|
||||
followSet := make(map[int]struct{}) |
||||
for len(followSet) < numFollows { |
||||
followIdx := int(n.rng.Uint64n(uint64(GraphBenchNumPubkeys))) |
||||
if followIdx != i { |
||||
followSet[followIdx] = struct{}{} |
||||
} |
||||
} |
||||
|
||||
// Convert to slice
|
||||
n.follows[i] = make([]int, 0, numFollows) |
||||
for idx := range followSet { |
||||
n.follows[i] = append(n.follows[i], idx) |
||||
} |
||||
totalFollows += numFollows |
||||
|
||||
if (i+1)%10000 == 0 { |
||||
fmt.Printf(" Generated follow lists for %d/%d pubkeys...\n", i+1, GraphBenchNumPubkeys) |
||||
} |
||||
} |
||||
|
||||
avgFollows := float64(totalFollows) / float64(GraphBenchNumPubkeys) |
||||
fmt.Printf("Generated follow graph in %v (avg %.1f follows/pubkey, total %d follows)\n", |
||||
time.Since(start), avgFollows, totalFollows) |
||||
} |
||||
|
||||
// createFollowListEvents creates kind 3 follow list events via WebSocket
|
||||
func (n *NetworkGraphTraversalBenchmark) createFollowListEvents(ctx context.Context) { |
||||
fmt.Println("Creating follow list events via WebSocket...") |
||||
start := time.Now() |
||||
|
||||
baseTime := time.Now().Unix() |
||||
|
||||
var mu sync.Mutex |
||||
var wg sync.WaitGroup |
||||
var successCount, errorCount int64 |
||||
latencies := make([]time.Duration, 0, GraphBenchNumPubkeys) |
||||
|
||||
// Use worker pool for parallel event creation
|
||||
numWorkers := n.workers |
||||
if numWorkers < 1 { |
||||
numWorkers = 4 |
||||
} |
||||
|
||||
workChan := make(chan int, numWorkers*2) |
||||
|
||||
// Rate limiter: cap at 1000 events/second per relay (to avoid overwhelming)
|
||||
perWorkerRate := 1000.0 / float64(numWorkers) |
||||
|
||||
for w := 0; w < numWorkers; w++ { |
||||
wg.Add(1) |
||||
go func() { |
||||
defer wg.Done() |
||||
|
||||
workerLimiter := NewRateLimiter(perWorkerRate) |
||||
|
||||
for i := range workChan { |
||||
workerLimiter.Wait() |
||||
|
||||
ev := event.New() |
||||
ev.Kind = kind.FollowList.K |
||||
ev.CreatedAt = baseTime + int64(i) |
||||
ev.Content = []byte("") |
||||
ev.Tags = tag.NewS() |
||||
|
||||
// Add p tags for all follows
|
||||
for _, followIdx := range n.follows[i] { |
||||
pubkeyHex := hex.Enc(n.pubkeys[followIdx]) |
||||
ev.Tags.Append(tag.NewFromAny("p", pubkeyHex)) |
||||
} |
||||
|
||||
// Sign the event
|
||||
if err := ev.Sign(n.signers[i]); err != nil { |
||||
mu.Lock() |
||||
errorCount++ |
||||
mu.Unlock() |
||||
ev.Free() |
||||
continue |
||||
} |
||||
|
||||
// Publish via WebSocket
|
||||
eventStart := time.Now() |
||||
errCh := n.relay.Write(eventenvelope.NewSubmissionWith(ev).Marshal(nil)) |
||||
|
||||
// Wait for write to complete
|
||||
select { |
||||
case err := <-errCh: |
||||
latency := time.Since(eventStart) |
||||
mu.Lock() |
||||
if err != nil { |
||||
errorCount++ |
||||
} else { |
||||
successCount++ |
||||
latencies = append(latencies, latency) |
||||
} |
||||
mu.Unlock() |
||||
case <-ctx.Done(): |
||||
mu.Lock() |
||||
errorCount++ |
||||
mu.Unlock() |
||||
} |
||||
|
||||
ev.Free() |
||||
} |
||||
}() |
||||
} |
||||
|
||||
// Send work
|
||||
for i := 0; i < GraphBenchNumPubkeys; i++ { |
||||
workChan <- i |
||||
if (i+1)%10000 == 0 { |
||||
fmt.Printf(" Queued %d/%d follow list events...\n", i+1, GraphBenchNumPubkeys) |
||||
} |
||||
} |
||||
close(workChan) |
||||
wg.Wait() |
||||
|
||||
duration := time.Since(start) |
||||
eventsPerSec := float64(successCount) / duration.Seconds() |
||||
|
||||
// Calculate latency stats
|
||||
var avgLatency, p90Latency, p95Latency, p99Latency time.Duration |
||||
if len(latencies) > 0 { |
||||
sort.Slice(latencies, func(i, j int) bool { return latencies[i] < latencies[j] }) |
||||
avgLatency = calculateAvgLatency(latencies) |
||||
p90Latency = calculatePercentileLatency(latencies, 0.90) |
||||
p95Latency = calculatePercentileLatency(latencies, 0.95) |
||||
p99Latency = calculatePercentileLatency(latencies, 0.99) |
||||
} |
||||
|
||||
fmt.Printf("Created %d follow list events in %v (%.2f events/sec, errors: %d)\n", |
||||
successCount, duration, eventsPerSec, errorCount) |
||||
fmt.Printf(" Avg latency: %v, P95: %v, P99: %v\n", avgLatency, p95Latency, p99Latency) |
||||
|
||||
// Record result for event creation phase
|
||||
result := &BenchmarkResult{ |
||||
TestName: "Graph Setup (Follow Lists)", |
||||
Duration: duration, |
||||
TotalEvents: int(successCount), |
||||
EventsPerSecond: eventsPerSec, |
||||
AvgLatency: avgLatency, |
||||
P90Latency: p90Latency, |
||||
P95Latency: p95Latency, |
||||
P99Latency: p99Latency, |
||||
Bottom10Avg: calculateBottom10Avg(latencies), |
||||
ConcurrentWorkers: numWorkers, |
||||
MemoryUsed: getMemUsage(), |
||||
SuccessRate: float64(successCount) / float64(GraphBenchNumPubkeys) * 100, |
||||
} |
||||
|
||||
n.mu.Lock() |
||||
n.results = append(n.results, result) |
||||
n.mu.Unlock() |
||||
} |
||||
|
||||
// runThirdDegreeTraversal runs the third-degree graph traversal benchmark via WebSocket
|
||||
func (n *NetworkGraphTraversalBenchmark) runThirdDegreeTraversal(ctx context.Context) { |
||||
fmt.Printf("\n=== Third-Degree Graph Traversal Benchmark (Network) ===\n") |
||||
fmt.Printf("Traversing 3 degrees of follows via WebSocket...\n") |
||||
|
||||
start := time.Now() |
||||
|
||||
var mu sync.Mutex |
||||
var wg sync.WaitGroup |
||||
var totalQueries int64 |
||||
var totalPubkeysFound int64 |
||||
queryLatencies := make([]time.Duration, 0, 10000) |
||||
traversalLatencies := make([]time.Duration, 0, 1000) |
||||
|
||||
// Sample a subset for detailed traversal
|
||||
sampleSize := 1000 |
||||
if sampleSize > GraphBenchNumPubkeys { |
||||
sampleSize = GraphBenchNumPubkeys |
||||
} |
||||
|
||||
// Deterministic sampling
|
||||
n.initializeDeterministicRNG() |
||||
sampleIndices := make([]int, sampleSize) |
||||
for i := 0; i < sampleSize; i++ { |
||||
sampleIndices[i] = int(n.rng.Uint64n(uint64(GraphBenchNumPubkeys))) |
||||
} |
||||
|
||||
fmt.Printf("Sampling %d pubkeys for traversal...\n", sampleSize) |
||||
|
||||
numWorkers := n.workers |
||||
if numWorkers < 1 { |
||||
numWorkers = 4 |
||||
} |
||||
|
||||
workChan := make(chan int, numWorkers*2) |
||||
|
||||
for w := 0; w < numWorkers; w++ { |
||||
wg.Add(1) |
||||
go func() { |
||||
defer wg.Done() |
||||
|
||||
for startIdx := range workChan { |
||||
traversalStart := time.Now() |
||||
foundPubkeys := make(map[string]struct{}) |
||||
|
||||
// Start with the initial pubkey
|
||||
currentLevel := [][]byte{n.pubkeys[startIdx]} |
||||
startPubkeyHex := hex.Enc(n.pubkeys[startIdx]) |
||||
foundPubkeys[startPubkeyHex] = struct{}{} |
||||
|
||||
// Traverse 3 degrees
|
||||
for depth := 0; depth < GraphBenchTraversalDepth; depth++ { |
||||
if len(currentLevel) == 0 { |
||||
break |
||||
} |
||||
|
||||
nextLevel := make([][]byte, 0) |
||||
|
||||
// Query follow lists for all pubkeys at current level
|
||||
// Batch queries for efficiency
|
||||
batchSize := 50 |
||||
for batchStart := 0; batchStart < len(currentLevel); batchStart += batchSize { |
||||
batchEnd := batchStart + batchSize |
||||
if batchEnd > len(currentLevel) { |
||||
batchEnd = len(currentLevel) |
||||
} |
||||
|
||||
batch := currentLevel[batchStart:batchEnd] |
||||
|
||||
// Build filter for kind 3 events from these pubkeys
|
||||
f := filter.New() |
||||
f.Kinds = kind.NewS(kind.FollowList) |
||||
f.Authors = tag.NewWithCap(len(batch)) |
||||
for _, pk := range batch { |
||||
f.Authors.T = append(f.Authors.T, pk) |
||||
} |
||||
|
||||
queryStart := time.Now() |
||||
|
||||
// Subscribe and collect results
|
||||
sub, err := n.relay.Subscribe(ctx, filter.NewS(f)) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
// Collect events with timeout
|
||||
timeout := time.After(5 * time.Second) |
||||
events := make([]*event.E, 0) |
||||
collectLoop: |
||||
for { |
||||
select { |
||||
case ev := <-sub.Events: |
||||
if ev != nil { |
||||
events = append(events, ev) |
||||
} |
||||
case <-sub.EndOfStoredEvents: |
||||
break collectLoop |
||||
case <-timeout: |
||||
break collectLoop |
||||
case <-ctx.Done(): |
||||
break collectLoop |
||||
} |
||||
} |
||||
sub.Unsub() |
||||
|
||||
queryLatency := time.Since(queryStart) |
||||
|
||||
mu.Lock() |
||||
totalQueries++ |
||||
queryLatencies = append(queryLatencies, queryLatency) |
||||
mu.Unlock() |
||||
|
||||
// Extract followed pubkeys from p tags
|
||||
for _, ev := range events { |
||||
for _, t := range *ev.Tags { |
||||
if len(t.T) >= 2 && string(t.T[0]) == "p" { |
||||
pubkeyHex := string(t.ValueHex()) |
||||
if _, exists := foundPubkeys[pubkeyHex]; !exists { |
||||
foundPubkeys[pubkeyHex] = struct{}{} |
||||
// Decode hex to bytes for next level
|
||||
if pkBytes, err := hex.Dec(pubkeyHex); err == nil { |
||||
nextLevel = append(nextLevel, pkBytes) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
ev.Free() |
||||
} |
||||
} |
||||
|
||||
currentLevel = nextLevel |
||||
} |
||||
|
||||
traversalLatency := time.Since(traversalStart) |
||||
|
||||
mu.Lock() |
||||
totalPubkeysFound += int64(len(foundPubkeys)) |
||||
traversalLatencies = append(traversalLatencies, traversalLatency) |
||||
mu.Unlock() |
||||
} |
||||
}() |
||||
} |
||||
|
||||
// Send work
|
||||
for _, idx := range sampleIndices { |
||||
workChan <- idx |
||||
} |
||||
close(workChan) |
||||
wg.Wait() |
||||
|
||||
duration := time.Since(start) |
||||
|
||||
// Calculate statistics
|
||||
var avgQueryLatency, p90QueryLatency, p95QueryLatency, p99QueryLatency time.Duration |
||||
if len(queryLatencies) > 0 { |
||||
sort.Slice(queryLatencies, func(i, j int) bool { return queryLatencies[i] < queryLatencies[j] }) |
||||
avgQueryLatency = calculateAvgLatency(queryLatencies) |
||||
p90QueryLatency = calculatePercentileLatency(queryLatencies, 0.90) |
||||
p95QueryLatency = calculatePercentileLatency(queryLatencies, 0.95) |
||||
p99QueryLatency = calculatePercentileLatency(queryLatencies, 0.99) |
||||
} |
||||
|
||||
var avgTraversalLatency, p90TraversalLatency, p95TraversalLatency, p99TraversalLatency time.Duration |
||||
if len(traversalLatencies) > 0 { |
||||
sort.Slice(traversalLatencies, func(i, j int) bool { return traversalLatencies[i] < traversalLatencies[j] }) |
||||
avgTraversalLatency = calculateAvgLatency(traversalLatencies) |
||||
p90TraversalLatency = calculatePercentileLatency(traversalLatencies, 0.90) |
||||
p95TraversalLatency = calculatePercentileLatency(traversalLatencies, 0.95) |
||||
p99TraversalLatency = calculatePercentileLatency(traversalLatencies, 0.99) |
||||
} |
||||
|
||||
avgPubkeysPerTraversal := float64(totalPubkeysFound) / float64(sampleSize) |
||||
traversalsPerSec := float64(sampleSize) / duration.Seconds() |
||||
queriesPerSec := float64(totalQueries) / duration.Seconds() |
||||
|
||||
fmt.Printf("\n=== Graph Traversal Results (Network) ===\n") |
||||
fmt.Printf("Traversals completed: %d\n", sampleSize) |
||||
fmt.Printf("Total queries: %d (%.2f queries/sec)\n", totalQueries, queriesPerSec) |
||||
fmt.Printf("Avg pubkeys found per traversal: %.1f\n", avgPubkeysPerTraversal) |
||||
fmt.Printf("Total duration: %v\n", duration) |
||||
fmt.Printf("\nQuery Latencies:\n") |
||||
fmt.Printf(" Avg: %v, P95: %v, P99: %v\n", avgQueryLatency, p95QueryLatency, p99QueryLatency) |
||||
fmt.Printf("\nFull Traversal Latencies (3 degrees):\n") |
||||
fmt.Printf(" Avg: %v, P95: %v, P99: %v\n", avgTraversalLatency, p95TraversalLatency, p99TraversalLatency) |
||||
fmt.Printf("Traversals/sec: %.2f\n", traversalsPerSec) |
||||
|
||||
// Record result for traversal phase
|
||||
result := &BenchmarkResult{ |
||||
TestName: "Graph Traversal (3 Degrees)", |
||||
Duration: duration, |
||||
TotalEvents: int(totalQueries), |
||||
EventsPerSecond: traversalsPerSec, |
||||
AvgLatency: avgTraversalLatency, |
||||
P90Latency: p90TraversalLatency, |
||||
P95Latency: p95TraversalLatency, |
||||
P99Latency: p99TraversalLatency, |
||||
Bottom10Avg: calculateBottom10Avg(traversalLatencies), |
||||
ConcurrentWorkers: numWorkers, |
||||
MemoryUsed: getMemUsage(), |
||||
SuccessRate: 100.0, |
||||
} |
||||
|
||||
n.mu.Lock() |
||||
n.results = append(n.results, result) |
||||
n.mu.Unlock() |
||||
|
||||
// Also record query performance separately
|
||||
queryResult := &BenchmarkResult{ |
||||
TestName: "Graph Queries (Follow Lists)", |
||||
Duration: duration, |
||||
TotalEvents: int(totalQueries), |
||||
EventsPerSecond: queriesPerSec, |
||||
AvgLatency: avgQueryLatency, |
||||
P90Latency: p90QueryLatency, |
||||
P95Latency: p95QueryLatency, |
||||
P99Latency: p99QueryLatency, |
||||
Bottom10Avg: calculateBottom10Avg(queryLatencies), |
||||
ConcurrentWorkers: numWorkers, |
||||
MemoryUsed: getMemUsage(), |
||||
SuccessRate: 100.0, |
||||
} |
||||
|
||||
n.mu.Lock() |
||||
n.results = append(n.results, queryResult) |
||||
n.mu.Unlock() |
||||
} |
||||
|
||||
// RunSuite runs the complete network graph traversal benchmark suite
|
||||
func (n *NetworkGraphTraversalBenchmark) RunSuite(ctx context.Context) error { |
||||
fmt.Println("\n╔════════════════════════════════════════════════════════╗") |
||||
fmt.Println("║ NETWORK GRAPH TRAVERSAL BENCHMARK (100k Pubkeys) ║") |
||||
fmt.Printf("║ Relay: %-46s ║\n", n.relayURL) |
||||
fmt.Println("╚════════════════════════════════════════════════════════╝") |
||||
|
||||
// Step 1: Generate pubkeys
|
||||
n.generatePubkeys() |
||||
|
||||
// Step 2: Generate follow graph
|
||||
n.generateFollowGraph() |
||||
|
||||
// Step 3: Connect to relay
|
||||
fmt.Printf("\nConnecting to relay: %s\n", n.relayURL) |
||||
if err := n.Connect(ctx); err != nil { |
||||
return fmt.Errorf("failed to connect: %w", err) |
||||
} |
||||
defer n.Close() |
||||
fmt.Println("Connected successfully!") |
||||
|
||||
// Step 4: Create follow list events via WebSocket
|
||||
n.createFollowListEvents(ctx) |
||||
|
||||
// Small delay to ensure events are processed
|
||||
fmt.Println("\nWaiting for events to be processed...") |
||||
time.Sleep(5 * time.Second) |
||||
|
||||
// Step 5: Run third-degree traversal benchmark
|
||||
n.runThirdDegreeTraversal(ctx) |
||||
|
||||
fmt.Printf("\n=== Network Graph Traversal Benchmark Complete ===\n\n") |
||||
return nil |
||||
} |
||||
|
||||
// GetResults returns the benchmark results
|
||||
func (n *NetworkGraphTraversalBenchmark) GetResults() []*BenchmarkResult { |
||||
n.mu.RLock() |
||||
defer n.mu.RUnlock() |
||||
return n.results |
||||
} |
||||
|
||||
// PrintResults prints the benchmark results
|
||||
func (n *NetworkGraphTraversalBenchmark) PrintResults() { |
||||
n.mu.RLock() |
||||
defer n.mu.RUnlock() |
||||
|
||||
for _, result := range n.results { |
||||
fmt.Printf("\nTest: %s\n", result.TestName) |
||||
fmt.Printf("Duration: %v\n", result.Duration) |
||||
fmt.Printf("Total Events/Queries: %d\n", result.TotalEvents) |
||||
fmt.Printf("Events/sec: %.2f\n", result.EventsPerSecond) |
||||
fmt.Printf("Success Rate: %.1f%%\n", result.SuccessRate) |
||||
fmt.Printf("Concurrent Workers: %d\n", result.ConcurrentWorkers) |
||||
fmt.Printf("Memory Used: %d MB\n", result.MemoryUsed/(1024*1024)) |
||||
fmt.Printf("Avg Latency: %v\n", result.AvgLatency) |
||||
fmt.Printf("P90 Latency: %v\n", result.P90Latency) |
||||
fmt.Printf("P95 Latency: %v\n", result.P95Latency) |
||||
fmt.Printf("P99 Latency: %v\n", result.P99Latency) |
||||
fmt.Printf("Bottom 10%% Avg Latency: %v\n", result.Bottom10Avg) |
||||
} |
||||
} |
||||
@ -0,0 +1,347 @@
@@ -0,0 +1,347 @@
|
||||
# Graph Query Implementation Phases |
||||
|
||||
This document provides a clear breakdown of implementation phases for NIP-XX Graph Queries. |
||||
|
||||
--- |
||||
|
||||
## Phase 0: Filter Extension Parsing (Foundation) ✅ COMPLETE |
||||
|
||||
**Goal**: Enable the nostr library to correctly "ignore" unknown filter fields per NIP-01, while preserving them for relay-level processing. |
||||
|
||||
### Deliverables (Completed) |
||||
- [x] Modified `filter.F` struct with `Extra` field |
||||
- [x] Modified `Unmarshal()` to skip unknown keys |
||||
- [x] `skipJSONValue()` helper function |
||||
- [x] `graph.ExtractFromFilter()` function |
||||
- [x] Integration in `handle-req.go` |
||||
- [x] Rate limiter with token bucket for graph queries |
||||
|
||||
--- |
||||
|
||||
## Phase 1: E-Tag Graph Index ✅ COMPLETE |
||||
|
||||
**Goal**: Create bidirectional indexes for event-to-event references (e-tags). |
||||
|
||||
### Index Key Structure |
||||
|
||||
``` |
||||
Event-Event Graph (Forward): eeg |
||||
eeg|source_event_serial(5)|target_event_serial(5)|kind(2)|direction(1) = 16 bytes |
||||
|
||||
Event-Event Graph (Reverse): gee |
||||
gee|target_event_serial(5)|kind(2)|direction(1)|source_event_serial(5) = 16 bytes |
||||
``` |
||||
|
||||
### Direction Constants |
||||
- `EdgeDirectionETagOut = 0` - Event references another event (outbound) |
||||
- `EdgeDirectionETagIn = 1` - Event is referenced by another (inbound) |
||||
|
||||
### Deliverables (Completed) |
||||
- [x] Index key definitions for eeg/gee (`pkg/database/indexes/keys.go`) |
||||
- [x] Direction constants for e-tags (`pkg/database/indexes/types/letter.go`) |
||||
- [x] E-tag graph creation in SaveEvent (`pkg/database/save-event.go`) |
||||
- [x] Tests for e-tag graph creation (`pkg/database/etag-graph_test.go`) |
||||
|
||||
**Key Bug Fix**: Buffer reuse in transaction required copying key bytes before writing second key to prevent overwrite. |
||||
|
||||
--- |
||||
|
||||
## Phase 2: Graph Traversal Primitives ✅ COMPLETE |
||||
|
||||
**Goal**: Implement pure index-based graph traversal functions. |
||||
|
||||
### 2.1 Core traversal functions |
||||
|
||||
**File**: `pkg/database/graph-traversal.go` |
||||
|
||||
```go |
||||
// Core primitives (no event decoding required) |
||||
func (d *D) GetPTagsFromEventSerial(eventSerial *types.Uint40) ([]*types.Uint40, error) |
||||
func (d *D) GetETagsFromEventSerial(eventSerial *types.Uint40) ([]*types.Uint40, error) |
||||
func (d *D) GetReferencingEvents(targetSerial *types.Uint40, kinds []uint16) ([]*types.Uint40, error) |
||||
func (d *D) GetFollowsFromPubkeySerial(pubkeySerial *types.Uint40) ([]*types.Uint40, error) |
||||
func (d *D) GetFollowersOfPubkeySerial(pubkeySerial *types.Uint40) ([]*types.Uint40, error) |
||||
func (d *D) GetPubkeyHexFromSerial(serial *types.Uint40) (string, error) |
||||
func (d *D) GetEventIDFromSerial(serial *types.Uint40) (string, error) |
||||
``` |
||||
|
||||
### 2.2 GraphResult struct |
||||
|
||||
**File**: `pkg/database/graph-result.go` |
||||
|
||||
```go |
||||
// GraphResult contains depth-organized traversal results |
||||
type GraphResult struct { |
||||
PubkeysByDepth map[int][]string // depth -> pubkeys first discovered at that depth |
||||
EventsByDepth map[int][]string // depth -> events discovered at that depth |
||||
FirstSeenPubkey map[string]int // pubkey hex -> depth where first seen |
||||
FirstSeenEvent map[string]int // event hex -> depth where first seen |
||||
TotalPubkeys int |
||||
TotalEvents int |
||||
InboundRefs map[uint16]map[string][]string // kind -> target -> []referencing_ids |
||||
OutboundRefs map[uint16]map[string][]string // kind -> source -> []referenced_ids |
||||
} |
||||
|
||||
func (r *GraphResult) ToDepthArrays() [][]string // For pubkey results |
||||
func (r *GraphResult) ToEventDepthArrays() [][]string // For event results |
||||
func (r *GraphResult) GetInboundRefsSorted(kind uint16) []RefAggregation |
||||
func (r *GraphResult) GetOutboundRefsSorted(kind uint16) []RefAggregation |
||||
``` |
||||
|
||||
### Deliverables (Completed) |
||||
- [x] Core traversal functions in `graph-traversal.go` |
||||
- [x] GraphResult struct with ToDepthArrays() and ToEventDepthArrays() |
||||
- [x] RefAggregation struct with sorted accessors |
||||
- [x] Tests in `graph-result_test.go` and `graph-traversal_test.go` |
||||
|
||||
--- |
||||
|
||||
## Phase 3: High-Level Traversals ✅ COMPLETE |
||||
|
||||
**Goal**: Implement the graph query methods (follows, followers, mentions, thread). |
||||
|
||||
### 3.1 Follow graph traversal |
||||
|
||||
**File**: `pkg/database/graph-follows.go` |
||||
|
||||
```go |
||||
// TraverseFollows performs BFS traversal of the follow graph |
||||
// Returns pubkeys grouped by first-discovered depth (no duplicates across depths) |
||||
func (d *D) TraverseFollows(seedPubkey []byte, maxDepth int) (*GraphResult, error) |
||||
|
||||
// TraverseFollowers performs BFS traversal to find who follows the seed pubkey |
||||
func (d *D) TraverseFollowers(seedPubkey []byte, maxDepth int) (*GraphResult, error) |
||||
|
||||
// Hex convenience wrappers |
||||
func (d *D) TraverseFollowsFromHex(seedPubkeyHex string, maxDepth int) (*GraphResult, error) |
||||
func (d *D) TraverseFollowersFromHex(seedPubkeyHex string, maxDepth int) (*GraphResult, error) |
||||
``` |
||||
|
||||
### 3.2 Other traversals |
||||
|
||||
**File**: `pkg/database/graph-mentions.go` |
||||
```go |
||||
func (d *D) FindMentions(pubkey []byte, kinds []uint16) (*GraphResult, error) |
||||
func (d *D) FindMentionsFromHex(pubkeyHex string, kinds []uint16) (*GraphResult, error) |
||||
func (d *D) FindMentionsByPubkeys(pubkeySerials []*types.Uint40, kinds []uint16) (*GraphResult, error) |
||||
``` |
||||
|
||||
**File**: `pkg/database/graph-thread.go` |
||||
```go |
||||
func (d *D) TraverseThread(seedEventID []byte, maxDepth int, direction string) (*GraphResult, error) |
||||
func (d *D) TraverseThreadFromHex(seedEventIDHex string, maxDepth int, direction string) (*GraphResult, error) |
||||
func (d *D) GetThreadReplies(eventID []byte, kinds []uint16) (*GraphResult, error) |
||||
func (d *D) GetThreadParents(eventID []byte) (*GraphResult, error) |
||||
``` |
||||
|
||||
### 3.3 Ref aggregation |
||||
|
||||
**File**: `pkg/database/graph-refs.go` |
||||
```go |
||||
func (d *D) AddInboundRefsToResult(result *GraphResult, depth int, kinds []uint16) error |
||||
func (d *D) AddOutboundRefsToResult(result *GraphResult, depth int, kinds []uint16) error |
||||
func (d *D) CollectRefsForPubkeys(pubkeySerials []*types.Uint40, refKinds []uint16, eventKinds []uint16) (*GraphResult, error) |
||||
``` |
||||
|
||||
### Deliverables (Completed) |
||||
- [x] TraverseFollows with early termination (2 consecutive empty depths) |
||||
- [x] TraverseFollowers |
||||
- [x] FindMentions and FindMentionsByPubkeys |
||||
- [x] TraverseThread with bidirectional traversal |
||||
- [x] Inbound/Outbound ref aggregation |
||||
- [x] Tests in `graph-follows_test.go` |
||||
|
||||
--- |
||||
|
||||
## Phase 4: Graph Query Handler and Response Generation ✅ COMPLETE |
||||
|
||||
**Goal**: Wire up the REQ handler to execute graph queries and generate relay-signed response events. |
||||
|
||||
### 4.1 Response Event Generation |
||||
|
||||
**Key Design Decision**: All graph query responses are returned as **relay-signed events**, enabling: |
||||
- Standard client validation (no special handling) |
||||
- Result caching and storage on relays |
||||
- Cryptographic proof of origin |
||||
|
||||
### 4.2 Response Kinds (Implemented) |
||||
|
||||
| Kind | Name | Description | |
||||
|------|------|-------------| |
||||
| 39000 | Graph Follows | Response for follows/followers queries | |
||||
| 39001 | Graph Mentions | Response for mentions queries | |
||||
| 39002 | Graph Thread | Response for thread traversal queries | |
||||
|
||||
### 4.3 Implementation Files |
||||
|
||||
**New files:** |
||||
- `pkg/protocol/graph/executor.go` - Executes graph queries and generates signed responses |
||||
- `pkg/database/graph-adapter.go` - Adapts database to `graph.GraphDatabase` interface |
||||
|
||||
**Modified files:** |
||||
- `app/server.go` - Added `graphExecutor` field |
||||
- `app/main.go` - Initialize graph executor on startup |
||||
- `app/handle-req.go` - Execute graph queries and return results |
||||
|
||||
### 4.4 Response Format (Implemented) |
||||
|
||||
The response is a relay-signed event with JSON content: |
||||
|
||||
```go |
||||
type ResponseContent struct { |
||||
PubkeysByDepth [][]string `json:"pubkeys_by_depth,omitempty"` |
||||
EventsByDepth [][]string `json:"events_by_depth,omitempty"` |
||||
TotalPubkeys int `json:"total_pubkeys,omitempty"` |
||||
TotalEvents int `json:"total_events,omitempty"` |
||||
} |
||||
``` |
||||
|
||||
**Example response event:** |
||||
```json |
||||
{ |
||||
"kind": 39000, |
||||
"pubkey": "<relay_identity_pubkey>", |
||||
"created_at": 1704067200, |
||||
"tags": [ |
||||
["method", "follows"], |
||||
["seed", "<seed_pubkey_hex>"], |
||||
["depth", "2"] |
||||
], |
||||
"content": "{\"pubkeys_by_depth\":[[\"pk1\",\"pk2\"],[\"pk3\",\"pk4\"]],\"total_pubkeys\":4}", |
||||
"sig": "<relay_signature>" |
||||
} |
||||
|
||||
### Deliverables (Completed) |
||||
- [x] Graph executor with query routing (`pkg/protocol/graph/executor.go`) |
||||
- [x] Response event generation with relay signature |
||||
- [x] GraphDatabase interface and adapter |
||||
- [x] Integration in `handle-req.go` |
||||
- [x] All tests passing |
||||
|
||||
--- |
||||
|
||||
## Phase 5: Migration & Configuration |
||||
|
||||
**Goal**: Enable backfilling and configuration. |
||||
|
||||
### 5.1 E-tag graph backfill migration |
||||
|
||||
**File**: `pkg/database/migrations.go` |
||||
|
||||
```go |
||||
func (d *D) MigrateETagGraph() error { |
||||
// Iterate all events |
||||
// Extract e-tags |
||||
// Create eeg/gee edges for targets that exist |
||||
} |
||||
``` |
||||
|
||||
### 5.2 Configuration |
||||
|
||||
**File**: `app/config/config.go` |
||||
|
||||
Add: |
||||
- `ORLY_GRAPH_QUERIES_ENABLED` - enable/disable feature |
||||
- `ORLY_GRAPH_MAX_DEPTH` - maximum traversal depth (default 16) |
||||
- `ORLY_GRAPH_RATE_LIMIT` - queries per minute per connection |
||||
|
||||
### 5.3 NIP-11 advertisement |
||||
|
||||
Update relay info document to advertise support and limits. |
||||
|
||||
### Deliverables |
||||
- [ ] Backfill migration |
||||
- [ ] Configuration options |
||||
- [ ] NIP-11 advertisement |
||||
- [ ] Documentation updates |
||||
|
||||
--- |
||||
|
||||
## Summary: Implementation Order |
||||
|
||||
| Phase | Description | Status | Dependencies | |
||||
|-------|-------------|--------|--------------| |
||||
| **0** | Filter extension parsing | ✅ Complete | None | |
||||
| **1** | E-tag graph index | ✅ Complete | Phase 0 | |
||||
| **2** | Graph traversal primitives | ✅ Complete | Phase 1 | |
||||
| **3** | High-level traversals | ✅ Complete | Phase 2 | |
||||
| **4** | Graph query handler | ✅ Complete | Phase 3 | |
||||
| **5** | Migration & configuration | Pending | Phase 4 | |
||||
|
||||
--- |
||||
|
||||
## Response Format Summary |
||||
|
||||
### Graph-Only Query (no kinds filter) |
||||
|
||||
**Request:** |
||||
```json |
||||
["REQ", "sub", {"_graph": {"method": "follows", "seed": "abc...", "depth": 2}}] |
||||
``` |
||||
|
||||
**Response:** Single signed event with depth arrays |
||||
```json |
||||
["EVENT", "sub", { |
||||
"kind": 39000, |
||||
"pubkey": "<relay_pubkey>", |
||||
"content": "[[\"depth1_pk1\",\"depth1_pk2\"],[\"depth2_pk3\",\"depth2_pk4\"]]", |
||||
"tags": [["d","follows:abc...:2"],["method","follows"],["seed","abc..."],["depth","2"]], |
||||
"sig": "..." |
||||
}] |
||||
["EOSE", "sub"] |
||||
``` |
||||
|
||||
### Query with Event Filters |
||||
|
||||
**Request:** |
||||
```json |
||||
["REQ", "sub", {"_graph": {"method": "follows", "seed": "abc...", "depth": 2}, "kinds": [0]}] |
||||
``` |
||||
|
||||
**Response:** Graph result + events in depth order |
||||
``` |
||||
["EVENT", "sub", <kind-39000 graph result>] |
||||
["EVENT", "sub", <kind-0 for depth-1 pubkey>] |
||||
["EVENT", "sub", <kind-0 for depth-1 pubkey>] |
||||
["EVENT", "sub", <kind-0 for depth-2 pubkey>] |
||||
... |
||||
["EOSE", "sub"] |
||||
``` |
||||
|
||||
### Query with Reference Aggregation |
||||
|
||||
**Request:** |
||||
```json |
||||
["REQ", "sub", {"_graph": {"method": "follows", "seed": "abc...", "depth": 1, "inbound_refs": [{"kinds": [7]}]}}] |
||||
``` |
||||
|
||||
**Response:** Graph result + refs sorted by count (descending) |
||||
``` |
||||
["EVENT", "sub", <kind-39000 with ref summaries>] |
||||
["EVENT", "sub", <kind-39001 target with 523 refs>] |
||||
["EVENT", "sub", <kind-39001 target with 312 refs>] |
||||
["EVENT", "sub", <kind-39001 target with 1 ref>] |
||||
["EOSE", "sub"] |
||||
``` |
||||
|
||||
--- |
||||
|
||||
## Testing Strategy |
||||
|
||||
### Unit Tests |
||||
- Filter parsing with unknown fields |
||||
- Index key encoding/decoding |
||||
- Traversal primitives |
||||
- Result depth array generation |
||||
- Reference sorting |
||||
|
||||
### Integration Tests |
||||
- Full graph query round-trip |
||||
- Response format validation |
||||
- Signature verification |
||||
- Backward compatibility (non-graph REQs still work) |
||||
|
||||
### Performance Tests |
||||
- Traversal latency at various depths |
||||
- Memory usage for large graphs |
||||
- Comparison with event-decoding approach |
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,612 @@
@@ -0,0 +1,612 @@
|
||||
# NIP-XX: Graph Queries |
||||
|
||||
`draft` `optional` |
||||
|
||||
This NIP defines an extension to the REQ message filter that enables efficient social graph traversal queries without requiring clients to fetch and decode large numbers of events. |
||||
|
||||
## Motivation |
||||
|
||||
Nostr's social graph is encoded in event tags: |
||||
- **Follow relationships**: Kind-3 events with `p` tags listing followed pubkeys |
||||
- **Event references**: `e` tags linking replies, reactions, reposts to their targets |
||||
- **Mentions**: `p` tags in any event kind referencing other users |
||||
|
||||
Clients building social features (timelines, notifications, discovery) must currently: |
||||
1. Fetch kind-3 events for each user |
||||
2. Decode JSON to extract `p` tags |
||||
3. Recursively fetch more events for multi-hop queries |
||||
4. Aggregate and count references client-side |
||||
|
||||
This is inefficient, especially for: |
||||
- **Multi-hop follow graphs** (friends-of-friends) |
||||
- **Reaction/reply counts** on posts |
||||
- **Thread traversal** for long conversations |
||||
- **Follower discovery** (who follows this user?) |
||||
|
||||
Relays with graph-indexed storage can answer these queries orders of magnitude faster by traversing indexes directly without event decoding. |
||||
|
||||
## Protocol Extension |
||||
|
||||
### Filter Extension: `_graph` |
||||
|
||||
The `_graph` field is added to REQ filters. Per NIP-01, unknown fields are ignored by relays that don't support this extension, ensuring backward compatibility. |
||||
|
||||
```json |
||||
["REQ", "<subscription_id>", { |
||||
"_graph": { |
||||
"method": "<method>", |
||||
"seed": "<hex>", |
||||
"depth": <number>, |
||||
"inbound_refs": [<ref_spec>, ...], |
||||
"outbound_refs": [<ref_spec>, ...] |
||||
}, |
||||
"kinds": [<kind>, ...] |
||||
}] |
||||
``` |
||||
|
||||
### Fields |
||||
|
||||
#### `method` (required) |
||||
|
||||
The graph traversal method to execute: |
||||
|
||||
| Method | Seed Type | Description | |
||||
|--------|-----------|-------------| |
||||
| `follows` | pubkey | Traverse outbound follow relationships via kind-3 `p` tags | |
||||
| `followers` | pubkey | Find pubkeys whose kind-3 events contain `p` tag to seed | |
||||
| `mentions` | pubkey | Find events with `p` tag referencing seed pubkey | |
||||
| `thread` | event ID | Traverse reply chain via `e` tags | |
||||
|
||||
#### `seed` (required) |
||||
|
||||
64-character hex string. Interpretation depends on `method`: |
||||
- For `follows`, `followers`, `mentions`: pubkey hex |
||||
- For `thread`: event ID hex |
||||
|
||||
#### `depth` (optional) |
||||
|
||||
Maximum traversal depth. Integer from 1-16. Default: 1. |
||||
|
||||
- `depth: 1` returns direct connections only |
||||
- `depth: 2` returns connections and their connections (friends-of-friends) |
||||
- Higher depths expand the graph further |
||||
|
||||
**Early termination**: Traversal stops before reaching `depth` if two consecutive depth levels yield no new pubkeys. This prevents unnecessary work when the graph is exhausted. |
||||
|
||||
#### `inbound_refs` (optional) |
||||
|
||||
Array of reference specifications for finding events that **reference** discovered events (via `e` tags). Used to find reactions, replies, reposts, zaps, etc. |
||||
|
||||
```json |
||||
"inbound_refs": [ |
||||
{"kinds": [7], "from_depth": 1}, |
||||
{"kinds": [1, 6], "from_depth": 0} |
||||
] |
||||
``` |
||||
|
||||
#### `outbound_refs` (optional) |
||||
|
||||
Array of reference specifications for finding events **referenced by** discovered events (via `e` tags). Used to find what posts are being replied to, quoted, etc. |
||||
|
||||
```json |
||||
"outbound_refs": [ |
||||
{"kinds": [1], "from_depth": 1} |
||||
] |
||||
``` |
||||
|
||||
#### Reference Specification (`ref_spec`) |
||||
|
||||
```json |
||||
{ |
||||
"kinds": [<kind>, ...], |
||||
"from_depth": <number> |
||||
} |
||||
``` |
||||
|
||||
- `kinds`: Event kinds to match (required, non-empty array) |
||||
- `from_depth`: Only apply this filter from this depth onwards (optional, default: 0) |
||||
|
||||
**Semantics:** |
||||
- Multiple `ref_spec` objects in an array have **AND** semantics (all must match) |
||||
- Multiple kinds within a single `ref_spec` have **OR** semantics (any kind matches) |
||||
- `from_depth: 0` includes references to/from the seed itself |
||||
- `from_depth: 1` starts from first-hop connections |
||||
|
||||
#### `kinds` (standard filter field) |
||||
|
||||
When present alongside `_graph`, specifies which event kinds to return for discovered pubkeys (e.g., kind-0 profiles, kind-1 notes). |
||||
|
||||
## Response Format |
||||
|
||||
### Relay-Signed Result Events |
||||
|
||||
All graph query responses are returned as **signed Nostr events** created by the relay using its identity key. This design provides several benefits: |
||||
|
||||
1. **Standard validation**: Clients validate the response like any normal event - no special handling needed |
||||
2. **Caching**: Results can be stored on relays and retrieved later |
||||
3. **Transparency**: The relay's pubkey identifies who produced the result |
||||
4. **Cryptographic binding**: The signature proves the result came from a specific relay |
||||
|
||||
### Response Kinds |
||||
|
||||
| Kind | Name | Description | |
||||
|------|------|-------------| |
||||
| 39000 | Graph Follows | Response for follows/followers queries | |
||||
| 39001 | Graph Mentions | Response for mentions queries | |
||||
| 39002 | Graph Thread | Response for thread traversal queries | |
||||
|
||||
These are application-specific kinds in the 39000-39999 range. |
||||
|
||||
--- |
||||
|
||||
## Simple Query Response (graph-only filter) |
||||
|
||||
When a REQ contains **only** the `_graph` field (no `kinds`, `authors`, or other filter fields), the relay returns a single signed event containing the graph traversal results organized by depth. |
||||
|
||||
### Request Format |
||||
|
||||
```json |
||||
["REQ", "<sub>", { |
||||
"_graph": { |
||||
"method": "follows", |
||||
"seed": "<pubkey_hex>", |
||||
"depth": 3 |
||||
} |
||||
}] |
||||
``` |
||||
|
||||
### Response: Kind 39000 Graph Result Event |
||||
|
||||
```json |
||||
{ |
||||
"kind": 39000, |
||||
"pubkey": "<relay_identity_pubkey>", |
||||
"created_at": <timestamp>, |
||||
"tags": [ |
||||
["method", "follows"], |
||||
["seed", "<seed_hex>"], |
||||
["depth", "3"] |
||||
], |
||||
"content": "{\"pubkeys_by_depth\":[[\"pubkey1\",\"pubkey2\"],[\"pubkey3\",\"pubkey4\"]],\"total_pubkeys\":4}", |
||||
"id": "<event_id>", |
||||
"sig": "<relay_signature>" |
||||
} |
||||
``` |
||||
|
||||
### Content Structure |
||||
|
||||
The `content` field contains a JSON object with depth arrays: |
||||
|
||||
```json |
||||
{ |
||||
"pubkeys_by_depth": [ |
||||
["<pubkey_depth_1>", "<pubkey_depth_1>", ...], |
||||
["<pubkey_depth_2>", "<pubkey_depth_2>", ...], |
||||
["<pubkey_depth_3>", "<pubkey_depth_3>", ...] |
||||
], |
||||
"total_pubkeys": 150 |
||||
} |
||||
``` |
||||
|
||||
For event-based queries (mentions, thread), the structure is: |
||||
|
||||
```json |
||||
{ |
||||
"events_by_depth": [ |
||||
["<event_id_depth_1>", ...], |
||||
["<event_id_depth_2>", ...] |
||||
], |
||||
"total_events": 42 |
||||
} |
||||
``` |
||||
|
||||
**Key properties:** |
||||
- **Array index = depth - 1**: Index 0 contains depth-1 pubkeys (direct follows) |
||||
- **Unique per depth**: Each pubkey/event appears only at the depth where it was **first discovered** |
||||
- **No duplicates**: A pubkey in depth 1 will NOT appear in depth 2 or 3 |
||||
- **Hex format**: All pubkeys and event IDs are 64-character lowercase hex strings |
||||
|
||||
### Example |
||||
|
||||
Alice follows Bob and Carol. Bob follows Dave. Carol follows Dave and Eve. |
||||
|
||||
Request: |
||||
```json |
||||
["REQ", "follow-net", { |
||||
"_graph": { |
||||
"method": "follows", |
||||
"seed": "<alice_pubkey>", |
||||
"depth": 2 |
||||
} |
||||
}] |
||||
``` |
||||
|
||||
Response: |
||||
```json |
||||
["EVENT", "follow-net", { |
||||
"kind": 39000, |
||||
"pubkey": "<relay_pubkey>", |
||||
"created_at": 1704067200, |
||||
"tags": [ |
||||
["method", "follows"], |
||||
["seed", "<alice_pubkey>"], |
||||
["depth", "2"] |
||||
], |
||||
"content": "{\"pubkeys_by_depth\":[[\"<bob_pubkey>\",\"<carol_pubkey>\"],[\"<dave_pubkey>\",\"<eve_pubkey>\"]],\"total_pubkeys\":4}", |
||||
"sig": "<signature>" |
||||
}] |
||||
["EOSE", "follow-net"] |
||||
``` |
||||
|
||||
**Interpretation:** |
||||
- Depth 1 (index 0): Bob, Carol (Alice's direct follows) |
||||
- Depth 2 (index 1): Dave, Eve (friends-of-friends, excluding Bob and Carol) |
||||
- Note: Dave appears only once even though both Bob and Carol follow Dave |
||||
|
||||
--- |
||||
|
||||
## Query with Additional Filters |
||||
|
||||
When the REQ includes both `_graph` AND other filter fields (like `kinds`), the relay: |
||||
|
||||
1. Executes the graph traversal to discover pubkeys |
||||
2. Fetches the requested events for those pubkeys |
||||
3. Returns events in **ascending depth order** |
||||
|
||||
### Request Format |
||||
|
||||
```json |
||||
["REQ", "<sub>", { |
||||
"_graph": { |
||||
"method": "follows", |
||||
"seed": "<pubkey_hex>", |
||||
"depth": 2 |
||||
}, |
||||
"kinds": [0, 1] |
||||
}] |
||||
``` |
||||
|
||||
### Response |
||||
|
||||
``` |
||||
["EVENT", "<sub>", <kind-39000 graph result event>] |
||||
["EVENT", "<sub>", <kind-0 profile for depth-1 pubkey>] |
||||
["EVENT", "<sub>", <kind-1 note for depth-1 pubkey>] |
||||
... (all depth-1 events) |
||||
["EVENT", "<sub>", <kind-0 profile for depth-2 pubkey>] |
||||
["EVENT", "<sub>", <kind-1 note for depth-2 pubkey>] |
||||
... (all depth-2 events) |
||||
["EOSE", "<sub>"] |
||||
``` |
||||
|
||||
The graph result event (kind 39000) is sent first, allowing clients to know the complete graph structure before receiving individual events. |
||||
|
||||
--- |
||||
|
||||
## Query with Reference Aggregation (Planned) |
||||
|
||||
> **Note:** Reference aggregation is planned for a future implementation phase. The following describes the intended behavior. |
||||
|
||||
When `inbound_refs` or `outbound_refs` are specified, the response will include aggregated reference data **sorted by count descending** (most referenced first). |
||||
|
||||
### Request Format |
||||
|
||||
```json |
||||
["REQ", "popular-posts", { |
||||
"_graph": { |
||||
"method": "follows", |
||||
"seed": "<pubkey_hex>", |
||||
"depth": 1, |
||||
"inbound_refs": [ |
||||
{"kinds": [7], "from_depth": 1} |
||||
] |
||||
} |
||||
}] |
||||
``` |
||||
|
||||
### Response (Planned) |
||||
|
||||
``` |
||||
["EVENT", "popular-posts", <kind-39000 graph result with ref summaries>] |
||||
["EVENT", "popular-posts", <aggregated ref event with 523 reactions>] |
||||
["EVENT", "popular-posts", <aggregated ref event with 312 reactions>] |
||||
... |
||||
["EVENT", "popular-posts", <aggregated ref event with 1 reaction>] |
||||
["EOSE", "popular-posts"] |
||||
``` |
||||
|
||||
### Kind 39001: Graph Mentions Result |
||||
|
||||
Used for `mentions` queries. Contains events that mention the seed pubkey: |
||||
|
||||
```json |
||||
{ |
||||
"kind": 39001, |
||||
"pubkey": "<relay_pubkey>", |
||||
"created_at": <timestamp>, |
||||
"tags": [ |
||||
["method", "mentions"], |
||||
["seed", "<seed_pubkey_hex>"], |
||||
["depth", "1"] |
||||
], |
||||
"content": "{\"events_by_depth\":[[\"<event_id_1>\",\"<event_id_2>\",...]],\"total_events\":42}", |
||||
"sig": "<signature>" |
||||
} |
||||
``` |
||||
|
||||
### Kind 39002: Graph Thread Result |
||||
|
||||
Used for `thread` queries. Contains events in a reply thread: |
||||
|
||||
```json |
||||
{ |
||||
"kind": 39002, |
||||
"pubkey": "<relay_pubkey>", |
||||
"created_at": <timestamp>, |
||||
"tags": [ |
||||
["method", "thread"], |
||||
["seed", "<seed_event_id_hex>"], |
||||
["depth", "10"] |
||||
], |
||||
"content": "{\"events_by_depth\":[[\"<reply_id_1>\",...],[\"<reply_id_2>\",...]],\"total_events\":156}", |
||||
"sig": "<signature>" |
||||
} |
||||
``` |
||||
|
||||
### Reference Aggregation (Future) |
||||
|
||||
When `inbound_refs` or `outbound_refs` are specified, the response includes aggregated reference data sorted by count descending. This feature is planned for a future implementation phase. |
||||
|
||||
--- |
||||
|
||||
## Examples |
||||
|
||||
### Example 1: Get Follow Network (Graph Only) |
||||
|
||||
Get Alice's 2-hop follow network as a single signed event: |
||||
|
||||
```json |
||||
["REQ", "follow-network", { |
||||
"_graph": { |
||||
"method": "follows", |
||||
"seed": "abc123...def456", |
||||
"depth": 2 |
||||
} |
||||
}] |
||||
``` |
||||
|
||||
**Response:** |
||||
```json |
||||
["EVENT", "follow-network", { |
||||
"kind": 39000, |
||||
"pubkey": "<relay_pubkey>", |
||||
"tags": [ |
||||
["method", "follows"], |
||||
["seed", "abc123...def456"], |
||||
["depth", "2"] |
||||
], |
||||
"content": "{\"pubkeys_by_depth\":[[\"pub1\",\"pub2\",...150 pubkeys],[\"pub151\",\"pub152\",...3420 pubkeys]],\"total_pubkeys\":3570}", |
||||
"sig": "<signature>" |
||||
}] |
||||
["EOSE", "follow-network"] |
||||
``` |
||||
|
||||
The content JSON object contains: |
||||
- `pubkeys_by_depth[0]`: 150 pubkeys (depth 1 - direct follows) |
||||
- `pubkeys_by_depth[1]`: 3420 pubkeys (depth 2 - friends-of-friends, excluding depth 1) |
||||
- `total_pubkeys`: 3570 (total unique pubkeys discovered) |
||||
|
||||
### Example 2: Follow Network with Profiles |
||||
|
||||
```json |
||||
["REQ", "follow-profiles", { |
||||
"_graph": { |
||||
"method": "follows", |
||||
"seed": "abc123...def456", |
||||
"depth": 2 |
||||
}, |
||||
"kinds": [0] |
||||
}] |
||||
``` |
||||
|
||||
**Response:** |
||||
``` |
||||
["EVENT", "follow-profiles", <kind-39000 graph result>] |
||||
["EVENT", "follow-profiles", <kind-0 for depth-1 follow>] |
||||
... (150 depth-1 profiles) |
||||
["EVENT", "follow-profiles", <kind-0 for depth-2 follow>] |
||||
... (3420 depth-2 profiles) |
||||
["EOSE", "follow-profiles"] |
||||
``` |
||||
|
||||
### Example 3: Popular Posts by Reactions |
||||
|
||||
Find reactions to posts by Alice's follows, sorted by popularity: |
||||
|
||||
```json |
||||
["REQ", "popular-posts", { |
||||
"_graph": { |
||||
"method": "follows", |
||||
"seed": "abc123...def456", |
||||
"depth": 1, |
||||
"inbound_refs": [ |
||||
{"kinds": [7], "from_depth": 1} |
||||
] |
||||
} |
||||
}] |
||||
``` |
||||
|
||||
**Response:** Most-reacted posts first, down to posts with only 1 reaction. |
||||
|
||||
### Example 4: Thread Traversal |
||||
|
||||
Fetch a complete reply thread: |
||||
|
||||
```json |
||||
["REQ", "thread", { |
||||
"_graph": { |
||||
"method": "thread", |
||||
"seed": "root_event_id_hex", |
||||
"depth": 10, |
||||
"inbound_refs": [ |
||||
{"kinds": [1], "from_depth": 0} |
||||
] |
||||
} |
||||
}] |
||||
``` |
||||
|
||||
### Example 5: Who Follows Me? |
||||
|
||||
Find pubkeys that follow Alice: |
||||
|
||||
```json |
||||
["REQ", "my-followers", { |
||||
"_graph": { |
||||
"method": "followers", |
||||
"seed": "alice_pubkey_hex", |
||||
"depth": 1 |
||||
} |
||||
}] |
||||
``` |
||||
|
||||
**Response:** Single kind-39000 event with follower pubkeys in content. |
||||
|
||||
### Example 6: Reactions AND Reposts (AND semantics) |
||||
|
||||
Find posts with both reactions and reposts: |
||||
|
||||
```json |
||||
["REQ", "engaged-posts", { |
||||
"_graph": { |
||||
"method": "follows", |
||||
"seed": "abc123...def456", |
||||
"depth": 1, |
||||
"inbound_refs": [ |
||||
{"kinds": [7], "from_depth": 1}, |
||||
{"kinds": [6], "from_depth": 1} |
||||
] |
||||
} |
||||
}] |
||||
``` |
||||
|
||||
This returns only posts that have **both** kind-7 reactions AND kind-6 reposts. |
||||
|
||||
### Example 7: Reactions OR Reposts (OR semantics) |
||||
|
||||
Find posts with either reactions or reposts: |
||||
|
||||
```json |
||||
["REQ", "any-engagement", { |
||||
"_graph": { |
||||
"method": "follows", |
||||
"seed": "abc123...def456", |
||||
"depth": 1, |
||||
"inbound_refs": [ |
||||
{"kinds": [6, 7], "from_depth": 1} |
||||
] |
||||
} |
||||
}] |
||||
``` |
||||
|
||||
--- |
||||
|
||||
## Client Implementation Notes |
||||
|
||||
### Validating Graph Results |
||||
|
||||
Graph result events are signed by the relay's identity key. Clients should: |
||||
|
||||
1. Verify the signature as with any event |
||||
2. Optionally verify the relay pubkey matches the connected relay |
||||
3. Parse the `content` JSON to extract depth-organized results |
||||
|
||||
### Caching Results |
||||
|
||||
Because graph results are standard signed events, clients can: |
||||
|
||||
1. Store results locally for offline access |
||||
2. Optionally publish results to relays for sharing |
||||
3. Use the `method`, `seed`, and `depth` tags to identify equivalent queries |
||||
4. Compare `created_at` timestamps to determine freshness |
||||
|
||||
### Trust Considerations |
||||
|
||||
The relay is asserting "this is what the graph looks like from my perspective." Clients may want to: |
||||
|
||||
1. Query multiple relays and compare results |
||||
2. Prefer relays they trust for graph queries |
||||
3. Use the response as a starting point and verify critical paths independently |
||||
|
||||
--- |
||||
|
||||
## Relay Implementation Notes |
||||
|
||||
### Index Requirements |
||||
|
||||
Efficient implementation requires bidirectional graph indexes: |
||||
|
||||
**Pubkey Graph:** |
||||
- Event → Pubkey edges (author relationship, `p` tag references) |
||||
- Pubkey → Event edges (reverse lookup) |
||||
|
||||
**Event Graph:** |
||||
- Event → Event edges (`e` tag references) |
||||
- Event → Event reverse edges (what references this event) |
||||
|
||||
Both indexes should include: |
||||
- Event kind (for filtering) |
||||
- Direction (author vs tag, inbound vs outbound) |
||||
|
||||
### Query Execution |
||||
|
||||
1. **Resolve seed**: Convert seed hex to internal identifier |
||||
2. **BFS traversal**: Traverse graph to specified depth, tracking first-seen depth |
||||
3. **Deduplication**: Each pubkey appears only at its first-discovered depth |
||||
4. **Collect refs**: If `inbound_refs`/`outbound_refs` specified, scan reference indexes |
||||
5. **Aggregate**: Group references by target/source, count occurrences |
||||
6. **Sort**: Order by count descending (for refs) |
||||
7. **Sign response**: Create and sign relay events with identity key |
||||
|
||||
### Performance Considerations |
||||
|
||||
- Use serial-based internal identifiers (5-byte) instead of full 32-byte IDs |
||||
- Pre-compute common aggregations if possible |
||||
- Set reasonable limits on depth (default max: 16) and result counts |
||||
- Consider caching frequent queries |
||||
- Use rate limiting to prevent abuse |
||||
|
||||
--- |
||||
|
||||
## Backward Compatibility |
||||
|
||||
- Relays not supporting this NIP will ignore the `_graph` field per NIP-01 |
||||
- Clients should detect support via NIP-11 relay information document |
||||
- Response events (39000, 39001, 39002) are standard Nostr events |
||||
|
||||
## NIP-11 Advertisement |
||||
|
||||
Relays supporting this NIP should advertise it: |
||||
|
||||
```json |
||||
{ |
||||
"supported_nips": [1, "XX"], |
||||
"limitation": { |
||||
"graph_query_max_depth": 16 |
||||
} |
||||
} |
||||
``` |
||||
|
||||
## Security Considerations |
||||
|
||||
- **Rate limiting**: Graph queries can be expensive; relays should rate limit |
||||
- **Depth limits**: Maximum depth should be capped (recommended: 16) |
||||
- **Result limits**: Large follow graphs can return many results; consider size limits |
||||
- **Authentication**: Relays may require NIP-42 auth for graph queries |
||||
|
||||
## References |
||||
|
||||
- [NIP-01](https://github.com/nostr-protocol/nips/blob/master/01.md): Basic protocol |
||||
- [NIP-02](https://github.com/nostr-protocol/nips/blob/master/02.md): Follow lists (kind 3) |
||||
- [NIP-11](https://github.com/nostr-protocol/nips/blob/master/11.md): Relay information |
||||
- [NIP-33](https://github.com/nostr-protocol/nips/blob/master/33.md): Parameterized replaceable events |
||||
- [NIP-42](https://github.com/nostr-protocol/nips/blob/master/42.md): Authentication |
||||
@ -0,0 +1,460 @@
@@ -0,0 +1,460 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database |
||||
|
||||
import ( |
||||
"bytes" |
||||
"context" |
||||
"testing" |
||||
|
||||
"github.com/dgraph-io/badger/v4" |
||||
"next.orly.dev/pkg/database/indexes" |
||||
"next.orly.dev/pkg/database/indexes/types" |
||||
"git.mleku.dev/mleku/nostr/encoders/event" |
||||
"git.mleku.dev/mleku/nostr/encoders/hex" |
||||
"git.mleku.dev/mleku/nostr/encoders/tag" |
||||
) |
||||
|
||||
func TestETagGraphEdgeCreation(t *testing.T) { |
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
defer cancel() |
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info") |
||||
if err != nil { |
||||
t.Fatalf("Failed to create database: %v", err) |
||||
} |
||||
defer db.Close() |
||||
|
||||
// Create a parent event (the post being replied to)
|
||||
parentPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001") |
||||
parentID := make([]byte, 32) |
||||
parentID[0] = 0x10 |
||||
parentSig := make([]byte, 64) |
||||
parentSig[0] = 0x10 |
||||
|
||||
parentEvent := &event.E{ |
||||
ID: parentID, |
||||
Pubkey: parentPubkey, |
||||
CreatedAt: 1234567890, |
||||
Kind: 1, |
||||
Content: []byte("This is the parent post"), |
||||
Sig: parentSig, |
||||
Tags: &tag.S{}, |
||||
} |
||||
_, err = db.SaveEvent(ctx, parentEvent) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save parent event: %v", err) |
||||
} |
||||
|
||||
// Create a reply event with e-tag pointing to parent
|
||||
replyPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002") |
||||
replyID := make([]byte, 32) |
||||
replyID[0] = 0x20 |
||||
replySig := make([]byte, 64) |
||||
replySig[0] = 0x20 |
||||
|
||||
replyEvent := &event.E{ |
||||
ID: replyID, |
||||
Pubkey: replyPubkey, |
||||
CreatedAt: 1234567891, |
||||
Kind: 1, |
||||
Content: []byte("This is a reply"), |
||||
Sig: replySig, |
||||
Tags: tag.NewS( |
||||
tag.NewFromAny("e", hex.Enc(parentID)), |
||||
), |
||||
} |
||||
_, err = db.SaveEvent(ctx, replyEvent) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save reply event: %v", err) |
||||
} |
||||
|
||||
// Get serials for both events
|
||||
parentSerial, err := db.GetSerialById(parentID) |
||||
if err != nil { |
||||
t.Fatalf("Failed to get parent serial: %v", err) |
||||
} |
||||
replySerial, err := db.GetSerialById(replyID) |
||||
if err != nil { |
||||
t.Fatalf("Failed to get reply serial: %v", err) |
||||
} |
||||
|
||||
t.Logf("Parent serial: %d, Reply serial: %d", parentSerial.Get(), replySerial.Get()) |
||||
|
||||
// Verify forward edge exists (reply -> parent)
|
||||
forwardFound := false |
||||
prefix := []byte(indexes.EventEventGraphPrefix) |
||||
|
||||
err = db.View(func(txn *badger.Txn) error { |
||||
it := txn.NewIterator(badger.DefaultIteratorOptions) |
||||
defer it.Close() |
||||
|
||||
for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() { |
||||
item := it.Item() |
||||
key := item.KeyCopy(nil) |
||||
|
||||
// Decode the key
|
||||
srcSer, tgtSer, kind, direction := indexes.EventEventGraphVars() |
||||
keyReader := bytes.NewReader(key) |
||||
if err := indexes.EventEventGraphDec(srcSer, tgtSer, kind, direction).UnmarshalRead(keyReader); err != nil { |
||||
t.Logf("Failed to decode key: %v", err) |
||||
continue |
||||
} |
||||
|
||||
// Check if this is our edge
|
||||
if srcSer.Get() == replySerial.Get() && tgtSer.Get() == parentSerial.Get() { |
||||
forwardFound = true |
||||
if direction.Letter() != types.EdgeDirectionETagOut { |
||||
t.Errorf("Expected direction %d, got %d", types.EdgeDirectionETagOut, direction.Letter()) |
||||
} |
||||
if kind.Get() != 1 { |
||||
t.Errorf("Expected kind 1, got %d", kind.Get()) |
||||
} |
||||
} |
||||
} |
||||
return nil |
||||
}) |
||||
if err != nil { |
||||
t.Fatalf("View failed: %v", err) |
||||
} |
||||
if !forwardFound { |
||||
t.Error("Forward edge (reply -> parent) should exist") |
||||
} |
||||
|
||||
// Verify reverse edge exists (parent <- reply)
|
||||
reverseFound := false |
||||
prefix = []byte(indexes.GraphEventEventPrefix) |
||||
|
||||
err = db.View(func(txn *badger.Txn) error { |
||||
it := txn.NewIterator(badger.DefaultIteratorOptions) |
||||
defer it.Close() |
||||
|
||||
for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() { |
||||
item := it.Item() |
||||
key := item.KeyCopy(nil) |
||||
|
||||
// Decode the key
|
||||
tgtSer, kind, direction, srcSer := indexes.GraphEventEventVars() |
||||
keyReader := bytes.NewReader(key) |
||||
if err := indexes.GraphEventEventDec(tgtSer, kind, direction, srcSer).UnmarshalRead(keyReader); err != nil { |
||||
t.Logf("Failed to decode key: %v", err) |
||||
continue |
||||
} |
||||
|
||||
t.Logf("Found gee edge: tgt=%d kind=%d dir=%d src=%d", |
||||
tgtSer.Get(), kind.Get(), direction.Letter(), srcSer.Get()) |
||||
|
||||
// Check if this is our edge
|
||||
if tgtSer.Get() == parentSerial.Get() && srcSer.Get() == replySerial.Get() { |
||||
reverseFound = true |
||||
if direction.Letter() != types.EdgeDirectionETagIn { |
||||
t.Errorf("Expected direction %d, got %d", types.EdgeDirectionETagIn, direction.Letter()) |
||||
} |
||||
if kind.Get() != 1 { |
||||
t.Errorf("Expected kind 1, got %d", kind.Get()) |
||||
} |
||||
} |
||||
} |
||||
return nil |
||||
}) |
||||
if err != nil { |
||||
t.Fatalf("View failed: %v", err) |
||||
} |
||||
if !reverseFound { |
||||
t.Error("Reverse edge (parent <- reply) should exist") |
||||
} |
||||
} |
||||
|
||||
func TestETagGraphMultipleReplies(t *testing.T) { |
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
defer cancel() |
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info") |
||||
if err != nil { |
||||
t.Fatalf("Failed to create database: %v", err) |
||||
} |
||||
defer db.Close() |
||||
|
||||
// Create a parent event
|
||||
parentPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001") |
||||
parentID := make([]byte, 32) |
||||
parentID[0] = 0x10 |
||||
parentSig := make([]byte, 64) |
||||
parentSig[0] = 0x10 |
||||
|
||||
parentEvent := &event.E{ |
||||
ID: parentID, |
||||
Pubkey: parentPubkey, |
||||
CreatedAt: 1234567890, |
||||
Kind: 1, |
||||
Content: []byte("Parent post"), |
||||
Sig: parentSig, |
||||
Tags: &tag.S{}, |
||||
} |
||||
_, err = db.SaveEvent(ctx, parentEvent) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save parent: %v", err) |
||||
} |
||||
|
||||
// Create multiple replies
|
||||
numReplies := 5 |
||||
for i := 0; i < numReplies; i++ { |
||||
replyPubkey := make([]byte, 32) |
||||
replyPubkey[0] = byte(i + 0x20) |
||||
replyID := make([]byte, 32) |
||||
replyID[0] = byte(i + 0x30) |
||||
replySig := make([]byte, 64) |
||||
replySig[0] = byte(i + 0x30) |
||||
|
||||
replyEvent := &event.E{ |
||||
ID: replyID, |
||||
Pubkey: replyPubkey, |
||||
CreatedAt: int64(1234567891 + i), |
||||
Kind: 1, |
||||
Content: []byte("Reply"), |
||||
Sig: replySig, |
||||
Tags: tag.NewS( |
||||
tag.NewFromAny("e", hex.Enc(parentID)), |
||||
), |
||||
} |
||||
_, err := db.SaveEvent(ctx, replyEvent) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save reply %d: %v", i, err) |
||||
} |
||||
} |
||||
|
||||
// Count inbound edges to parent
|
||||
parentSerial, err := db.GetSerialById(parentID) |
||||
if err != nil { |
||||
t.Fatalf("Failed to get parent serial: %v", err) |
||||
} |
||||
|
||||
inboundCount := 0 |
||||
prefix := []byte(indexes.GraphEventEventPrefix) |
||||
|
||||
err = db.View(func(txn *badger.Txn) error { |
||||
it := txn.NewIterator(badger.DefaultIteratorOptions) |
||||
defer it.Close() |
||||
|
||||
for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() { |
||||
item := it.Item() |
||||
key := item.KeyCopy(nil) |
||||
|
||||
tgtSer, kind, direction, srcSer := indexes.GraphEventEventVars() |
||||
keyReader := bytes.NewReader(key) |
||||
if err := indexes.GraphEventEventDec(tgtSer, kind, direction, srcSer).UnmarshalRead(keyReader); err != nil { |
||||
continue |
||||
} |
||||
|
||||
if tgtSer.Get() == parentSerial.Get() { |
||||
inboundCount++ |
||||
} |
||||
} |
||||
return nil |
||||
}) |
||||
if err != nil { |
||||
t.Fatalf("View failed: %v", err) |
||||
} |
||||
|
||||
if inboundCount != numReplies { |
||||
t.Errorf("Expected %d inbound edges, got %d", numReplies, inboundCount) |
||||
} |
||||
} |
||||
|
||||
func TestETagGraphDifferentKinds(t *testing.T) { |
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
defer cancel() |
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info") |
||||
if err != nil { |
||||
t.Fatalf("Failed to create database: %v", err) |
||||
} |
||||
defer db.Close() |
||||
|
||||
// Create a parent event (kind 1 - note)
|
||||
parentPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001") |
||||
parentID := make([]byte, 32) |
||||
parentID[0] = 0x10 |
||||
parentSig := make([]byte, 64) |
||||
parentSig[0] = 0x10 |
||||
|
||||
parentEvent := &event.E{ |
||||
ID: parentID, |
||||
Pubkey: parentPubkey, |
||||
CreatedAt: 1234567890, |
||||
Kind: 1, |
||||
Content: []byte("A note"), |
||||
Sig: parentSig, |
||||
Tags: &tag.S{}, |
||||
} |
||||
_, err = db.SaveEvent(ctx, parentEvent) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save parent: %v", err) |
||||
} |
||||
|
||||
// Create a reaction (kind 7)
|
||||
reactionPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002") |
||||
reactionID := make([]byte, 32) |
||||
reactionID[0] = 0x20 |
||||
reactionSig := make([]byte, 64) |
||||
reactionSig[0] = 0x20 |
||||
|
||||
reactionEvent := &event.E{ |
||||
ID: reactionID, |
||||
Pubkey: reactionPubkey, |
||||
CreatedAt: 1234567891, |
||||
Kind: 7, |
||||
Content: []byte("+"), |
||||
Sig: reactionSig, |
||||
Tags: tag.NewS( |
||||
tag.NewFromAny("e", hex.Enc(parentID)), |
||||
), |
||||
} |
||||
_, err = db.SaveEvent(ctx, reactionEvent) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save reaction: %v", err) |
||||
} |
||||
|
||||
// Create a repost (kind 6)
|
||||
repostPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003") |
||||
repostID := make([]byte, 32) |
||||
repostID[0] = 0x30 |
||||
repostSig := make([]byte, 64) |
||||
repostSig[0] = 0x30 |
||||
|
||||
repostEvent := &event.E{ |
||||
ID: repostID, |
||||
Pubkey: repostPubkey, |
||||
CreatedAt: 1234567892, |
||||
Kind: 6, |
||||
Content: []byte(""), |
||||
Sig: repostSig, |
||||
Tags: tag.NewS( |
||||
tag.NewFromAny("e", hex.Enc(parentID)), |
||||
), |
||||
} |
||||
_, err = db.SaveEvent(ctx, repostEvent) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save repost: %v", err) |
||||
} |
||||
|
||||
// Query inbound edges by kind
|
||||
parentSerial, err := db.GetSerialById(parentID) |
||||
if err != nil { |
||||
t.Fatalf("Failed to get parent serial: %v", err) |
||||
} |
||||
|
||||
kindCounts := make(map[uint16]int) |
||||
prefix := []byte(indexes.GraphEventEventPrefix) |
||||
|
||||
err = db.View(func(txn *badger.Txn) error { |
||||
it := txn.NewIterator(badger.DefaultIteratorOptions) |
||||
defer it.Close() |
||||
|
||||
for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() { |
||||
item := it.Item() |
||||
key := item.KeyCopy(nil) |
||||
|
||||
tgtSer, kind, direction, srcSer := indexes.GraphEventEventVars() |
||||
keyReader := bytes.NewReader(key) |
||||
if err := indexes.GraphEventEventDec(tgtSer, kind, direction, srcSer).UnmarshalRead(keyReader); err != nil { |
||||
continue |
||||
} |
||||
|
||||
if tgtSer.Get() == parentSerial.Get() { |
||||
kindCounts[kind.Get()]++ |
||||
} |
||||
} |
||||
return nil |
||||
}) |
||||
if err != nil { |
||||
t.Fatalf("View failed: %v", err) |
||||
} |
||||
|
||||
// Verify we have edges for each kind
|
||||
if kindCounts[7] != 1 { |
||||
t.Errorf("Expected 1 kind-7 (reaction) edge, got %d", kindCounts[7]) |
||||
} |
||||
if kindCounts[6] != 1 { |
||||
t.Errorf("Expected 1 kind-6 (repost) edge, got %d", kindCounts[6]) |
||||
} |
||||
} |
||||
|
||||
func TestETagGraphUnknownTarget(t *testing.T) { |
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
defer cancel() |
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info") |
||||
if err != nil { |
||||
t.Fatalf("Failed to create database: %v", err) |
||||
} |
||||
defer db.Close() |
||||
|
||||
// Create an event with e-tag pointing to non-existent event
|
||||
unknownID := make([]byte, 32) |
||||
unknownID[0] = 0xFF |
||||
unknownID[31] = 0xFF |
||||
|
||||
replyPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001") |
||||
replyID := make([]byte, 32) |
||||
replyID[0] = 0x10 |
||||
replySig := make([]byte, 64) |
||||
replySig[0] = 0x10 |
||||
|
||||
replyEvent := &event.E{ |
||||
ID: replyID, |
||||
Pubkey: replyPubkey, |
||||
CreatedAt: 1234567890, |
||||
Kind: 1, |
||||
Content: []byte("Reply to unknown"), |
||||
Sig: replySig, |
||||
Tags: tag.NewS( |
||||
tag.NewFromAny("e", hex.Enc(unknownID)), |
||||
), |
||||
} |
||||
_, err = db.SaveEvent(ctx, replyEvent) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save reply: %v", err) |
||||
} |
||||
|
||||
// Verify event was saved
|
||||
replySerial, err := db.GetSerialById(replyID) |
||||
if err != nil { |
||||
t.Fatalf("Failed to get reply serial: %v", err) |
||||
} |
||||
if replySerial == nil { |
||||
t.Fatal("Reply serial should exist") |
||||
} |
||||
|
||||
// Verify no forward edge was created (since target doesn't exist)
|
||||
edgeCount := 0 |
||||
prefix := []byte(indexes.EventEventGraphPrefix) |
||||
|
||||
err = db.View(func(txn *badger.Txn) error { |
||||
it := txn.NewIterator(badger.DefaultIteratorOptions) |
||||
defer it.Close() |
||||
|
||||
for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() { |
||||
item := it.Item() |
||||
key := item.KeyCopy(nil) |
||||
|
||||
srcSer, _, _, _ := indexes.EventEventGraphVars() |
||||
keyReader := bytes.NewReader(key) |
||||
if err := indexes.EventEventGraphDec(srcSer, new(types.Uint40), new(types.Uint16), new(types.Letter)).UnmarshalRead(keyReader); err != nil { |
||||
continue |
||||
} |
||||
|
||||
if srcSer.Get() == replySerial.Get() { |
||||
edgeCount++ |
||||
} |
||||
} |
||||
return nil |
||||
}) |
||||
if err != nil { |
||||
t.Fatalf("View failed: %v", err) |
||||
} |
||||
|
||||
if edgeCount != 0 { |
||||
t.Errorf("Expected no edges for unknown target, got %d", edgeCount) |
||||
} |
||||
} |
||||
@ -0,0 +1,42 @@
@@ -0,0 +1,42 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database |
||||
|
||||
import ( |
||||
"next.orly.dev/pkg/protocol/graph" |
||||
) |
||||
|
||||
// GraphAdapter wraps a database instance and implements graph.GraphDatabase interface.
|
||||
// This allows the graph executor to call database traversal methods without
|
||||
// the database package importing the graph package.
|
||||
type GraphAdapter struct { |
||||
db *D |
||||
} |
||||
|
||||
// NewGraphAdapter creates a new GraphAdapter wrapping the given database.
|
||||
func NewGraphAdapter(db *D) *GraphAdapter { |
||||
return &GraphAdapter{db: db} |
||||
} |
||||
|
||||
// TraverseFollows implements graph.GraphDatabase.
|
||||
func (a *GraphAdapter) TraverseFollows(seedPubkey []byte, maxDepth int) (graph.GraphResultI, error) { |
||||
return a.db.TraverseFollows(seedPubkey, maxDepth) |
||||
} |
||||
|
||||
// TraverseFollowers implements graph.GraphDatabase.
|
||||
func (a *GraphAdapter) TraverseFollowers(seedPubkey []byte, maxDepth int) (graph.GraphResultI, error) { |
||||
return a.db.TraverseFollowers(seedPubkey, maxDepth) |
||||
} |
||||
|
||||
// FindMentions implements graph.GraphDatabase.
|
||||
func (a *GraphAdapter) FindMentions(pubkey []byte, kinds []uint16) (graph.GraphResultI, error) { |
||||
return a.db.FindMentions(pubkey, kinds) |
||||
} |
||||
|
||||
// TraverseThread implements graph.GraphDatabase.
|
||||
func (a *GraphAdapter) TraverseThread(seedEventID []byte, maxDepth int, direction string) (graph.GraphResultI, error) { |
||||
return a.db.TraverseThread(seedEventID, maxDepth, direction) |
||||
} |
||||
|
||||
// Verify GraphAdapter implements graph.GraphDatabase
|
||||
var _ graph.GraphDatabase = (*GraphAdapter)(nil) |
||||
@ -0,0 +1,199 @@
@@ -0,0 +1,199 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database |
||||
|
||||
import ( |
||||
"lol.mleku.dev/log" |
||||
"next.orly.dev/pkg/database/indexes/types" |
||||
"git.mleku.dev/mleku/nostr/encoders/hex" |
||||
) |
||||
|
||||
// TraverseFollows performs BFS traversal of the follow graph starting from a seed pubkey.
|
||||
// Returns pubkeys grouped by first-discovered depth (no duplicates across depths).
|
||||
//
|
||||
// The traversal works by:
|
||||
// 1. Starting with the seed pubkey at depth 0 (not included in results)
|
||||
// 2. For each pubkey at the current depth, find their kind-3 contact list
|
||||
// 3. Extract p-tags from the contact list to get follows
|
||||
// 4. Add new (unseen) follows to the next depth
|
||||
// 5. Continue until maxDepth is reached or no new pubkeys are found
|
||||
//
|
||||
// Early termination occurs if two consecutive depths yield no new pubkeys.
|
||||
func (d *D) TraverseFollows(seedPubkey []byte, maxDepth int) (*GraphResult, error) { |
||||
result := NewGraphResult() |
||||
|
||||
if len(seedPubkey) != 32 { |
||||
return result, ErrPubkeyNotFound |
||||
} |
||||
|
||||
// Get seed pubkey serial
|
||||
seedSerial, err := d.GetPubkeySerial(seedPubkey) |
||||
if err != nil { |
||||
log.D.F("TraverseFollows: seed pubkey not in database: %s", hex.Enc(seedPubkey)) |
||||
return result, nil // Not an error - just no results
|
||||
} |
||||
|
||||
// Track visited pubkeys by serial to avoid cycles
|
||||
visited := make(map[uint64]bool) |
||||
visited[seedSerial.Get()] = true // Mark seed as visited but don't add to results
|
||||
|
||||
// Current frontier (pubkeys to process at this depth)
|
||||
currentFrontier := []*types.Uint40{seedSerial} |
||||
|
||||
// Track consecutive empty depths for early termination
|
||||
consecutiveEmptyDepths := 0 |
||||
|
||||
for currentDepth := 1; currentDepth <= maxDepth; currentDepth++ { |
||||
var nextFrontier []*types.Uint40 |
||||
newPubkeysAtDepth := 0 |
||||
|
||||
for _, pubkeySerial := range currentFrontier { |
||||
// Get follows for this pubkey
|
||||
follows, err := d.GetFollowsFromPubkeySerial(pubkeySerial) |
||||
if err != nil { |
||||
log.D.F("TraverseFollows: error getting follows for serial %d: %v", pubkeySerial.Get(), err) |
||||
continue |
||||
} |
||||
|
||||
for _, followSerial := range follows { |
||||
// Skip if already visited
|
||||
if visited[followSerial.Get()] { |
||||
continue |
||||
} |
||||
visited[followSerial.Get()] = true |
||||
|
||||
// Get pubkey hex for result
|
||||
pubkeyHex, err := d.GetPubkeyHexFromSerial(followSerial) |
||||
if err != nil { |
||||
log.D.F("TraverseFollows: error getting pubkey hex for serial %d: %v", followSerial.Get(), err) |
||||
continue |
||||
} |
||||
|
||||
// Add to results at this depth
|
||||
result.AddPubkeyAtDepth(pubkeyHex, currentDepth) |
||||
newPubkeysAtDepth++ |
||||
|
||||
// Add to next frontier for further traversal
|
||||
nextFrontier = append(nextFrontier, followSerial) |
||||
} |
||||
} |
||||
|
||||
log.T.F("TraverseFollows: depth %d found %d new pubkeys", currentDepth, newPubkeysAtDepth) |
||||
|
||||
// Check for early termination
|
||||
if newPubkeysAtDepth == 0 { |
||||
consecutiveEmptyDepths++ |
||||
if consecutiveEmptyDepths >= 2 { |
||||
log.T.F("TraverseFollows: early termination at depth %d (2 consecutive empty depths)", currentDepth) |
||||
break |
||||
} |
||||
} else { |
||||
consecutiveEmptyDepths = 0 |
||||
} |
||||
|
||||
// Move to next depth
|
||||
currentFrontier = nextFrontier |
||||
} |
||||
|
||||
log.D.F("TraverseFollows: completed with %d total pubkeys across %d depths", |
||||
result.TotalPubkeys, len(result.PubkeysByDepth)) |
||||
|
||||
return result, nil |
||||
} |
||||
|
||||
// TraverseFollowers performs BFS traversal to find who follows the seed pubkey.
|
||||
// This is the reverse of TraverseFollows - it finds users whose kind-3 lists
|
||||
// contain the target pubkey(s).
|
||||
//
|
||||
// At each depth:
|
||||
// - Depth 1: Users who directly follow the seed
|
||||
// - Depth 2: Users who follow anyone at depth 1 (followers of followers)
|
||||
// - etc.
|
||||
func (d *D) TraverseFollowers(seedPubkey []byte, maxDepth int) (*GraphResult, error) { |
||||
result := NewGraphResult() |
||||
|
||||
if len(seedPubkey) != 32 { |
||||
return result, ErrPubkeyNotFound |
||||
} |
||||
|
||||
// Get seed pubkey serial
|
||||
seedSerial, err := d.GetPubkeySerial(seedPubkey) |
||||
if err != nil { |
||||
log.D.F("TraverseFollowers: seed pubkey not in database: %s", hex.Enc(seedPubkey)) |
||||
return result, nil |
||||
} |
||||
|
||||
// Track visited pubkeys
|
||||
visited := make(map[uint64]bool) |
||||
visited[seedSerial.Get()] = true |
||||
|
||||
// Current frontier
|
||||
currentFrontier := []*types.Uint40{seedSerial} |
||||
|
||||
consecutiveEmptyDepths := 0 |
||||
|
||||
for currentDepth := 1; currentDepth <= maxDepth; currentDepth++ { |
||||
var nextFrontier []*types.Uint40 |
||||
newPubkeysAtDepth := 0 |
||||
|
||||
for _, targetSerial := range currentFrontier { |
||||
// Get followers of this pubkey
|
||||
followers, err := d.GetFollowersOfPubkeySerial(targetSerial) |
||||
if err != nil { |
||||
log.D.F("TraverseFollowers: error getting followers for serial %d: %v", targetSerial.Get(), err) |
||||
continue |
||||
} |
||||
|
||||
for _, followerSerial := range followers { |
||||
if visited[followerSerial.Get()] { |
||||
continue |
||||
} |
||||
visited[followerSerial.Get()] = true |
||||
|
||||
pubkeyHex, err := d.GetPubkeyHexFromSerial(followerSerial) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
result.AddPubkeyAtDepth(pubkeyHex, currentDepth) |
||||
newPubkeysAtDepth++ |
||||
nextFrontier = append(nextFrontier, followerSerial) |
||||
} |
||||
} |
||||
|
||||
log.T.F("TraverseFollowers: depth %d found %d new pubkeys", currentDepth, newPubkeysAtDepth) |
||||
|
||||
if newPubkeysAtDepth == 0 { |
||||
consecutiveEmptyDepths++ |
||||
if consecutiveEmptyDepths >= 2 { |
||||
break |
||||
} |
||||
} else { |
||||
consecutiveEmptyDepths = 0 |
||||
} |
||||
|
||||
currentFrontier = nextFrontier |
||||
} |
||||
|
||||
log.D.F("TraverseFollowers: completed with %d total pubkeys", result.TotalPubkeys) |
||||
|
||||
return result, nil |
||||
} |
||||
|
||||
// TraverseFollowsFromHex is a convenience wrapper that accepts hex-encoded pubkey.
|
||||
func (d *D) TraverseFollowsFromHex(seedPubkeyHex string, maxDepth int) (*GraphResult, error) { |
||||
seedPubkey, err := hex.Dec(seedPubkeyHex) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return d.TraverseFollows(seedPubkey, maxDepth) |
||||
} |
||||
|
||||
// TraverseFollowersFromHex is a convenience wrapper that accepts hex-encoded pubkey.
|
||||
func (d *D) TraverseFollowersFromHex(seedPubkeyHex string, maxDepth int) (*GraphResult, error) { |
||||
seedPubkey, err := hex.Dec(seedPubkeyHex) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return d.TraverseFollowers(seedPubkey, maxDepth) |
||||
} |
||||
@ -0,0 +1,318 @@
@@ -0,0 +1,318 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database |
||||
|
||||
import ( |
||||
"context" |
||||
"testing" |
||||
|
||||
"git.mleku.dev/mleku/nostr/encoders/event" |
||||
"git.mleku.dev/mleku/nostr/encoders/hex" |
||||
"git.mleku.dev/mleku/nostr/encoders/tag" |
||||
) |
||||
|
||||
func TestTraverseFollows(t *testing.T) { |
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
defer cancel() |
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info") |
||||
if err != nil { |
||||
t.Fatalf("Failed to create database: %v", err) |
||||
} |
||||
defer db.Close() |
||||
|
||||
// Create a simple follow graph:
|
||||
// Alice -> Bob, Carol
|
||||
// Bob -> David, Eve
|
||||
// Carol -> Eve, Frank
|
||||
//
|
||||
// Expected depth 1 from Alice: Bob, Carol
|
||||
// Expected depth 2 from Alice: David, Eve, Frank (Eve deduplicated)
|
||||
|
||||
alice, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001") |
||||
bob, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002") |
||||
carol, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003") |
||||
david, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000004") |
||||
eve, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000005") |
||||
frank, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000006") |
||||
|
||||
// Create Alice's follow list (kind 3)
|
||||
aliceContactID := make([]byte, 32) |
||||
aliceContactID[0] = 0x10 |
||||
aliceContactSig := make([]byte, 64) |
||||
aliceContactSig[0] = 0x10 |
||||
aliceContact := &event.E{ |
||||
ID: aliceContactID, |
||||
Pubkey: alice, |
||||
CreatedAt: 1234567890, |
||||
Kind: 3, |
||||
Content: []byte(""), |
||||
Sig: aliceContactSig, |
||||
Tags: tag.NewS( |
||||
tag.NewFromAny("p", hex.Enc(bob)), |
||||
tag.NewFromAny("p", hex.Enc(carol)), |
||||
), |
||||
} |
||||
_, err = db.SaveEvent(ctx, aliceContact) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save Alice's contact list: %v", err) |
||||
} |
||||
|
||||
// Create Bob's follow list
|
||||
bobContactID := make([]byte, 32) |
||||
bobContactID[0] = 0x20 |
||||
bobContactSig := make([]byte, 64) |
||||
bobContactSig[0] = 0x20 |
||||
bobContact := &event.E{ |
||||
ID: bobContactID, |
||||
Pubkey: bob, |
||||
CreatedAt: 1234567891, |
||||
Kind: 3, |
||||
Content: []byte(""), |
||||
Sig: bobContactSig, |
||||
Tags: tag.NewS( |
||||
tag.NewFromAny("p", hex.Enc(david)), |
||||
tag.NewFromAny("p", hex.Enc(eve)), |
||||
), |
||||
} |
||||
_, err = db.SaveEvent(ctx, bobContact) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save Bob's contact list: %v", err) |
||||
} |
||||
|
||||
// Create Carol's follow list
|
||||
carolContactID := make([]byte, 32) |
||||
carolContactID[0] = 0x30 |
||||
carolContactSig := make([]byte, 64) |
||||
carolContactSig[0] = 0x30 |
||||
carolContact := &event.E{ |
||||
ID: carolContactID, |
||||
Pubkey: carol, |
||||
CreatedAt: 1234567892, |
||||
Kind: 3, |
||||
Content: []byte(""), |
||||
Sig: carolContactSig, |
||||
Tags: tag.NewS( |
||||
tag.NewFromAny("p", hex.Enc(eve)), |
||||
tag.NewFromAny("p", hex.Enc(frank)), |
||||
), |
||||
} |
||||
_, err = db.SaveEvent(ctx, carolContact) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save Carol's contact list: %v", err) |
||||
} |
||||
|
||||
// Traverse follows from Alice with depth 2
|
||||
result, err := db.TraverseFollows(alice, 2) |
||||
if err != nil { |
||||
t.Fatalf("TraverseFollows failed: %v", err) |
||||
} |
||||
|
||||
// Check depth 1: should have Bob and Carol
|
||||
depth1 := result.GetPubkeysAtDepth(1) |
||||
if len(depth1) != 2 { |
||||
t.Errorf("Expected 2 pubkeys at depth 1, got %d", len(depth1)) |
||||
} |
||||
|
||||
depth1Set := make(map[string]bool) |
||||
for _, pk := range depth1 { |
||||
depth1Set[pk] = true |
||||
} |
||||
if !depth1Set[hex.Enc(bob)] { |
||||
t.Error("Bob should be at depth 1") |
||||
} |
||||
if !depth1Set[hex.Enc(carol)] { |
||||
t.Error("Carol should be at depth 1") |
||||
} |
||||
|
||||
// Check depth 2: should have David, Eve, Frank (Eve deduplicated)
|
||||
depth2 := result.GetPubkeysAtDepth(2) |
||||
if len(depth2) != 3 { |
||||
t.Errorf("Expected 3 pubkeys at depth 2, got %d: %v", len(depth2), depth2) |
||||
} |
||||
|
||||
depth2Set := make(map[string]bool) |
||||
for _, pk := range depth2 { |
||||
depth2Set[pk] = true |
||||
} |
||||
if !depth2Set[hex.Enc(david)] { |
||||
t.Error("David should be at depth 2") |
||||
} |
||||
if !depth2Set[hex.Enc(eve)] { |
||||
t.Error("Eve should be at depth 2") |
||||
} |
||||
if !depth2Set[hex.Enc(frank)] { |
||||
t.Error("Frank should be at depth 2") |
||||
} |
||||
|
||||
// Verify total count
|
||||
if result.TotalPubkeys != 5 { |
||||
t.Errorf("Expected 5 total pubkeys, got %d", result.TotalPubkeys) |
||||
} |
||||
|
||||
// Verify ToDepthArrays output
|
||||
arrays := result.ToDepthArrays() |
||||
if len(arrays) != 2 { |
||||
t.Errorf("Expected 2 depth arrays, got %d", len(arrays)) |
||||
} |
||||
if len(arrays[0]) != 2 { |
||||
t.Errorf("Expected 2 pubkeys in depth 1 array, got %d", len(arrays[0])) |
||||
} |
||||
if len(arrays[1]) != 3 { |
||||
t.Errorf("Expected 3 pubkeys in depth 2 array, got %d", len(arrays[1])) |
||||
} |
||||
} |
||||
|
||||
func TestTraverseFollowsDepth1(t *testing.T) { |
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
defer cancel() |
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info") |
||||
if err != nil { |
||||
t.Fatalf("Failed to create database: %v", err) |
||||
} |
||||
defer db.Close() |
||||
|
||||
alice, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001") |
||||
bob, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002") |
||||
carol, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003") |
||||
|
||||
// Create Alice's follow list
|
||||
aliceContactID := make([]byte, 32) |
||||
aliceContactID[0] = 0x10 |
||||
aliceContactSig := make([]byte, 64) |
||||
aliceContactSig[0] = 0x10 |
||||
aliceContact := &event.E{ |
||||
ID: aliceContactID, |
||||
Pubkey: alice, |
||||
CreatedAt: 1234567890, |
||||
Kind: 3, |
||||
Content: []byte(""), |
||||
Sig: aliceContactSig, |
||||
Tags: tag.NewS( |
||||
tag.NewFromAny("p", hex.Enc(bob)), |
||||
tag.NewFromAny("p", hex.Enc(carol)), |
||||
), |
||||
} |
||||
_, err = db.SaveEvent(ctx, aliceContact) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save contact list: %v", err) |
||||
} |
||||
|
||||
// Traverse with depth 1 only
|
||||
result, err := db.TraverseFollows(alice, 1) |
||||
if err != nil { |
||||
t.Fatalf("TraverseFollows failed: %v", err) |
||||
} |
||||
|
||||
if result.TotalPubkeys != 2 { |
||||
t.Errorf("Expected 2 pubkeys, got %d", result.TotalPubkeys) |
||||
} |
||||
|
||||
arrays := result.ToDepthArrays() |
||||
if len(arrays) != 1 { |
||||
t.Errorf("Expected 1 depth array for depth 1 query, got %d", len(arrays)) |
||||
} |
||||
} |
||||
|
||||
func TestTraverseFollowersBasic(t *testing.T) { |
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
defer cancel() |
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info") |
||||
if err != nil { |
||||
t.Fatalf("Failed to create database: %v", err) |
||||
} |
||||
defer db.Close() |
||||
|
||||
// Create scenario: Bob and Carol follow Alice
|
||||
alice, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001") |
||||
bob, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002") |
||||
carol, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003") |
||||
|
||||
// Bob's contact list includes Alice
|
||||
bobContactID := make([]byte, 32) |
||||
bobContactID[0] = 0x10 |
||||
bobContactSig := make([]byte, 64) |
||||
bobContactSig[0] = 0x10 |
||||
bobContact := &event.E{ |
||||
ID: bobContactID, |
||||
Pubkey: bob, |
||||
CreatedAt: 1234567890, |
||||
Kind: 3, |
||||
Content: []byte(""), |
||||
Sig: bobContactSig, |
||||
Tags: tag.NewS( |
||||
tag.NewFromAny("p", hex.Enc(alice)), |
||||
), |
||||
} |
||||
_, err = db.SaveEvent(ctx, bobContact) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save Bob's contact list: %v", err) |
||||
} |
||||
|
||||
// Carol's contact list includes Alice
|
||||
carolContactID := make([]byte, 32) |
||||
carolContactID[0] = 0x20 |
||||
carolContactSig := make([]byte, 64) |
||||
carolContactSig[0] = 0x20 |
||||
carolContact := &event.E{ |
||||
ID: carolContactID, |
||||
Pubkey: carol, |
||||
CreatedAt: 1234567891, |
||||
Kind: 3, |
||||
Content: []byte(""), |
||||
Sig: carolContactSig, |
||||
Tags: tag.NewS( |
||||
tag.NewFromAny("p", hex.Enc(alice)), |
||||
), |
||||
} |
||||
_, err = db.SaveEvent(ctx, carolContact) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save Carol's contact list: %v", err) |
||||
} |
||||
|
||||
// Find Alice's followers
|
||||
result, err := db.TraverseFollowers(alice, 1) |
||||
if err != nil { |
||||
t.Fatalf("TraverseFollowers failed: %v", err) |
||||
} |
||||
|
||||
if result.TotalPubkeys != 2 { |
||||
t.Errorf("Expected 2 followers, got %d", result.TotalPubkeys) |
||||
} |
||||
|
||||
followers := result.GetPubkeysAtDepth(1) |
||||
followerSet := make(map[string]bool) |
||||
for _, pk := range followers { |
||||
followerSet[pk] = true |
||||
} |
||||
if !followerSet[hex.Enc(bob)] { |
||||
t.Error("Bob should be a follower") |
||||
} |
||||
if !followerSet[hex.Enc(carol)] { |
||||
t.Error("Carol should be a follower") |
||||
} |
||||
} |
||||
|
||||
func TestTraverseFollowsNonExistent(t *testing.T) { |
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
defer cancel() |
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info") |
||||
if err != nil { |
||||
t.Fatalf("Failed to create database: %v", err) |
||||
} |
||||
defer db.Close() |
||||
|
||||
// Try to traverse from a pubkey that doesn't exist
|
||||
nonExistent, _ := hex.Dec("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") |
||||
result, err := db.TraverseFollows(nonExistent, 2) |
||||
if err != nil { |
||||
t.Fatalf("TraverseFollows should not error for non-existent pubkey: %v", err) |
||||
} |
||||
|
||||
if result.TotalPubkeys != 0 { |
||||
t.Errorf("Expected 0 pubkeys for non-existent seed, got %d", result.TotalPubkeys) |
||||
} |
||||
} |
||||
@ -0,0 +1,91 @@
@@ -0,0 +1,91 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database |
||||
|
||||
import ( |
||||
"lol.mleku.dev/log" |
||||
"next.orly.dev/pkg/database/indexes/types" |
||||
"git.mleku.dev/mleku/nostr/encoders/hex" |
||||
) |
||||
|
||||
// FindMentions finds events that mention a pubkey via p-tags.
|
||||
// This returns events grouped by depth, where depth represents how the events relate:
|
||||
// - Depth 1: Events that directly mention the seed pubkey
|
||||
// - Depth 2+: Not typically used for mentions (reserved for future expansion)
|
||||
//
|
||||
// The kinds parameter filters which event kinds to include (e.g., [1] for notes only,
|
||||
// [1,7] for notes and reactions, etc.)
|
||||
func (d *D) FindMentions(pubkey []byte, kinds []uint16) (*GraphResult, error) { |
||||
result := NewGraphResult() |
||||
|
||||
if len(pubkey) != 32 { |
||||
return result, ErrPubkeyNotFound |
||||
} |
||||
|
||||
// Get pubkey serial
|
||||
pubkeySerial, err := d.GetPubkeySerial(pubkey) |
||||
if err != nil { |
||||
log.D.F("FindMentions: pubkey not in database: %s", hex.Enc(pubkey)) |
||||
return result, nil |
||||
} |
||||
|
||||
// Find all events that reference this pubkey
|
||||
eventSerials, err := d.GetEventsReferencingPubkey(pubkeySerial, kinds) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// Add each event at depth 1
|
||||
for _, eventSerial := range eventSerials { |
||||
eventIDHex, err := d.GetEventIDFromSerial(eventSerial) |
||||
if err != nil { |
||||
log.D.F("FindMentions: error getting event ID for serial %d: %v", eventSerial.Get(), err) |
||||
continue |
||||
} |
||||
result.AddEventAtDepth(eventIDHex, 1) |
||||
} |
||||
|
||||
log.D.F("FindMentions: found %d events mentioning pubkey %s", result.TotalEvents, hex.Enc(pubkey)) |
||||
|
||||
return result, nil |
||||
} |
||||
|
||||
// FindMentionsFromHex is a convenience wrapper that accepts hex-encoded pubkey.
|
||||
func (d *D) FindMentionsFromHex(pubkeyHex string, kinds []uint16) (*GraphResult, error) { |
||||
pubkey, err := hex.Dec(pubkeyHex) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return d.FindMentions(pubkey, kinds) |
||||
} |
||||
|
||||
// FindMentionsByPubkeys returns events that mention any of the given pubkeys.
|
||||
// Useful for finding mentions across a set of followed accounts.
|
||||
func (d *D) FindMentionsByPubkeys(pubkeySerials []*types.Uint40, kinds []uint16) (*GraphResult, error) { |
||||
result := NewGraphResult() |
||||
|
||||
seen := make(map[uint64]bool) |
||||
|
||||
for _, pubkeySerial := range pubkeySerials { |
||||
eventSerials, err := d.GetEventsReferencingPubkey(pubkeySerial, kinds) |
||||
if err != nil { |
||||
log.D.F("FindMentionsByPubkeys: error for serial %d: %v", pubkeySerial.Get(), err) |
||||
continue |
||||
} |
||||
|
||||
for _, eventSerial := range eventSerials { |
||||
if seen[eventSerial.Get()] { |
||||
continue |
||||
} |
||||
seen[eventSerial.Get()] = true |
||||
|
||||
eventIDHex, err := d.GetEventIDFromSerial(eventSerial) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
result.AddEventAtDepth(eventIDHex, 1) |
||||
} |
||||
} |
||||
|
||||
return result, nil |
||||
} |
||||
@ -0,0 +1,206 @@
@@ -0,0 +1,206 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database |
||||
|
||||
import ( |
||||
"lol.mleku.dev/log" |
||||
"next.orly.dev/pkg/database/indexes/types" |
||||
) |
||||
|
||||
// AddInboundRefsToResult collects inbound references (events that reference discovered items)
|
||||
// for events at a specific depth in the result.
|
||||
//
|
||||
// For example, if you have a follows graph result and want to find all kind-7 reactions
|
||||
// to posts by users at depth 1, this collects those reactions and adds them to result.InboundRefs.
|
||||
//
|
||||
// Parameters:
|
||||
// - result: The graph result to augment with ref data
|
||||
// - depth: The depth at which to collect refs (0 = all depths)
|
||||
// - kinds: Event kinds to collect (e.g., [7] for reactions, [6] for reposts)
|
||||
func (d *D) AddInboundRefsToResult(result *GraphResult, depth int, kinds []uint16) error { |
||||
// Determine which events to find refs for
|
||||
var targetEventIDs []string |
||||
|
||||
if depth == 0 { |
||||
// Collect for all depths
|
||||
targetEventIDs = result.GetAllEvents() |
||||
} else { |
||||
targetEventIDs = result.GetEventsAtDepth(depth) |
||||
} |
||||
|
||||
// Also collect refs for events authored by pubkeys in the result
|
||||
// This is common for "find reactions to posts by my follows" queries
|
||||
pubkeys := result.GetAllPubkeys() |
||||
for _, pubkeyHex := range pubkeys { |
||||
pubkeySerial, err := d.PubkeyHexToSerial(pubkeyHex) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
// Get events authored by this pubkey
|
||||
// For efficiency, limit to relevant event kinds that might have reactions
|
||||
authoredEvents, err := d.GetEventsByAuthor(pubkeySerial, []uint16{1, 30023}) // notes and articles
|
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
for _, eventSerial := range authoredEvents { |
||||
eventIDHex, err := d.GetEventIDFromSerial(eventSerial) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
// Add to target list if not already tracking
|
||||
if !result.HasEvent(eventIDHex) { |
||||
targetEventIDs = append(targetEventIDs, eventIDHex) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// For each target event, find referencing events
|
||||
for _, eventIDHex := range targetEventIDs { |
||||
eventSerial, err := d.EventIDHexToSerial(eventIDHex) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
refSerials, err := d.GetReferencingEvents(eventSerial, kinds) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
for _, refSerial := range refSerials { |
||||
refEventIDHex, err := d.GetEventIDFromSerial(refSerial) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
// Get the kind of the referencing event
|
||||
// For now, use the first kind in the filter (assumes single kind queries)
|
||||
// TODO: Look up actual event kind from index if needed
|
||||
if len(kinds) > 0 { |
||||
result.AddInboundRef(kinds[0], eventIDHex, refEventIDHex) |
||||
} |
||||
} |
||||
} |
||||
|
||||
log.D.F("AddInboundRefsToResult: collected refs for %d target events", len(targetEventIDs)) |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// AddOutboundRefsToResult collects outbound references (events referenced by discovered items).
|
||||
//
|
||||
// For example, find all events that posts by users at depth 1 reference (quoted posts, replied-to posts).
|
||||
func (d *D) AddOutboundRefsToResult(result *GraphResult, depth int, kinds []uint16) error { |
||||
// Determine source events
|
||||
var sourceEventIDs []string |
||||
|
||||
if depth == 0 { |
||||
sourceEventIDs = result.GetAllEvents() |
||||
} else { |
||||
sourceEventIDs = result.GetEventsAtDepth(depth) |
||||
} |
||||
|
||||
// Also include events authored by pubkeys in result
|
||||
pubkeys := result.GetAllPubkeys() |
||||
for _, pubkeyHex := range pubkeys { |
||||
pubkeySerial, err := d.PubkeyHexToSerial(pubkeyHex) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
authoredEvents, err := d.GetEventsByAuthor(pubkeySerial, kinds) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
for _, eventSerial := range authoredEvents { |
||||
eventIDHex, err := d.GetEventIDFromSerial(eventSerial) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
if !result.HasEvent(eventIDHex) { |
||||
sourceEventIDs = append(sourceEventIDs, eventIDHex) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// For each source event, find referenced events
|
||||
for _, eventIDHex := range sourceEventIDs { |
||||
eventSerial, err := d.EventIDHexToSerial(eventIDHex) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
refSerials, err := d.GetETagsFromEventSerial(eventSerial) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
for _, refSerial := range refSerials { |
||||
refEventIDHex, err := d.GetEventIDFromSerial(refSerial) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
// Use first kind for categorization
|
||||
if len(kinds) > 0 { |
||||
result.AddOutboundRef(kinds[0], eventIDHex, refEventIDHex) |
||||
} |
||||
} |
||||
} |
||||
|
||||
log.D.F("AddOutboundRefsToResult: collected refs from %d source events", len(sourceEventIDs)) |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// CollectRefsForPubkeys collects inbound references to events by specific pubkeys.
|
||||
// This is useful for "find all reactions to posts by these users" queries.
|
||||
//
|
||||
// Parameters:
|
||||
// - pubkeySerials: The pubkeys whose events should be checked for refs
|
||||
// - refKinds: Event kinds to collect (e.g., [7] for reactions)
|
||||
// - eventKinds: Event kinds to check for refs (e.g., [1] for notes)
|
||||
func (d *D) CollectRefsForPubkeys( |
||||
pubkeySerials []*types.Uint40, |
||||
refKinds []uint16, |
||||
eventKinds []uint16, |
||||
) (*GraphResult, error) { |
||||
result := NewGraphResult() |
||||
|
||||
for _, pubkeySerial := range pubkeySerials { |
||||
// Get events by this author
|
||||
authoredEvents, err := d.GetEventsByAuthor(pubkeySerial, eventKinds) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
for _, eventSerial := range authoredEvents { |
||||
eventIDHex, err := d.GetEventIDFromSerial(eventSerial) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
// Find refs to this event
|
||||
refSerials, err := d.GetReferencingEvents(eventSerial, refKinds) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
for _, refSerial := range refSerials { |
||||
refEventIDHex, err := d.GetEventIDFromSerial(refSerial) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
// Add to result
|
||||
if len(refKinds) > 0 { |
||||
result.AddInboundRef(refKinds[0], eventIDHex, refEventIDHex) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
return result, nil |
||||
} |
||||
@ -0,0 +1,327 @@
@@ -0,0 +1,327 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database |
||||
|
||||
import ( |
||||
"sort" |
||||
) |
||||
|
||||
// GraphResult contains depth-organized traversal results for graph queries.
|
||||
// It tracks pubkeys and events discovered at each depth level, ensuring
|
||||
// each entity appears only at the depth where it was first discovered.
|
||||
type GraphResult struct { |
||||
// PubkeysByDepth maps depth -> pubkeys first discovered at that depth.
|
||||
// Each pubkey appears ONLY in the array for the depth where it was first seen.
|
||||
// Depth 1 = direct connections, Depth 2 = connections of connections, etc.
|
||||
PubkeysByDepth map[int][]string |
||||
|
||||
// EventsByDepth maps depth -> event IDs discovered at that depth.
|
||||
// Used for thread traversal queries.
|
||||
EventsByDepth map[int][]string |
||||
|
||||
// FirstSeenPubkey tracks which depth each pubkey was first discovered.
|
||||
// Key is pubkey hex, value is the depth (1-indexed).
|
||||
FirstSeenPubkey map[string]int |
||||
|
||||
// FirstSeenEvent tracks which depth each event was first discovered.
|
||||
// Key is event ID hex, value is the depth (1-indexed).
|
||||
FirstSeenEvent map[string]int |
||||
|
||||
// TotalPubkeys is the count of unique pubkeys discovered across all depths.
|
||||
TotalPubkeys int |
||||
|
||||
// TotalEvents is the count of unique events discovered across all depths.
|
||||
TotalEvents int |
||||
|
||||
// InboundRefs tracks inbound references (events that reference discovered items).
|
||||
// Structure: kind -> target_id -> []referencing_event_ids
|
||||
InboundRefs map[uint16]map[string][]string |
||||
|
||||
// OutboundRefs tracks outbound references (events referenced by discovered items).
|
||||
// Structure: kind -> source_id -> []referenced_event_ids
|
||||
OutboundRefs map[uint16]map[string][]string |
||||
} |
||||
|
||||
// NewGraphResult creates a new initialized GraphResult.
|
||||
func NewGraphResult() *GraphResult { |
||||
return &GraphResult{ |
||||
PubkeysByDepth: make(map[int][]string), |
||||
EventsByDepth: make(map[int][]string), |
||||
FirstSeenPubkey: make(map[string]int), |
||||
FirstSeenEvent: make(map[string]int), |
||||
InboundRefs: make(map[uint16]map[string][]string), |
||||
OutboundRefs: make(map[uint16]map[string][]string), |
||||
} |
||||
} |
||||
|
||||
// AddPubkeyAtDepth adds a pubkey to the result at the specified depth if not already seen.
|
||||
// Returns true if the pubkey was added (first time seen), false if already exists.
|
||||
func (r *GraphResult) AddPubkeyAtDepth(pubkeyHex string, depth int) bool { |
||||
if _, exists := r.FirstSeenPubkey[pubkeyHex]; exists { |
||||
return false |
||||
} |
||||
|
||||
r.FirstSeenPubkey[pubkeyHex] = depth |
||||
r.PubkeysByDepth[depth] = append(r.PubkeysByDepth[depth], pubkeyHex) |
||||
r.TotalPubkeys++ |
||||
return true |
||||
} |
||||
|
||||
// AddEventAtDepth adds an event ID to the result at the specified depth if not already seen.
|
||||
// Returns true if the event was added (first time seen), false if already exists.
|
||||
func (r *GraphResult) AddEventAtDepth(eventIDHex string, depth int) bool { |
||||
if _, exists := r.FirstSeenEvent[eventIDHex]; exists { |
||||
return false |
||||
} |
||||
|
||||
r.FirstSeenEvent[eventIDHex] = depth |
||||
r.EventsByDepth[depth] = append(r.EventsByDepth[depth], eventIDHex) |
||||
r.TotalEvents++ |
||||
return true |
||||
} |
||||
|
||||
// HasPubkey returns true if the pubkey has been discovered at any depth.
|
||||
func (r *GraphResult) HasPubkey(pubkeyHex string) bool { |
||||
_, exists := r.FirstSeenPubkey[pubkeyHex] |
||||
return exists |
||||
} |
||||
|
||||
// HasEvent returns true if the event has been discovered at any depth.
|
||||
func (r *GraphResult) HasEvent(eventIDHex string) bool { |
||||
_, exists := r.FirstSeenEvent[eventIDHex] |
||||
return exists |
||||
} |
||||
|
||||
// GetPubkeyDepth returns the depth at which a pubkey was first discovered.
|
||||
// Returns 0 if the pubkey was not found.
|
||||
func (r *GraphResult) GetPubkeyDepth(pubkeyHex string) int { |
||||
return r.FirstSeenPubkey[pubkeyHex] |
||||
} |
||||
|
||||
// GetEventDepth returns the depth at which an event was first discovered.
|
||||
// Returns 0 if the event was not found.
|
||||
func (r *GraphResult) GetEventDepth(eventIDHex string) int { |
||||
return r.FirstSeenEvent[eventIDHex] |
||||
} |
||||
|
||||
// GetDepthsSorted returns all depths that have pubkeys, sorted ascending.
|
||||
func (r *GraphResult) GetDepthsSorted() []int { |
||||
depths := make([]int, 0, len(r.PubkeysByDepth)) |
||||
for d := range r.PubkeysByDepth { |
||||
depths = append(depths, d) |
||||
} |
||||
sort.Ints(depths) |
||||
return depths |
||||
} |
||||
|
||||
// GetEventDepthsSorted returns all depths that have events, sorted ascending.
|
||||
func (r *GraphResult) GetEventDepthsSorted() []int { |
||||
depths := make([]int, 0, len(r.EventsByDepth)) |
||||
for d := range r.EventsByDepth { |
||||
depths = append(depths, d) |
||||
} |
||||
sort.Ints(depths) |
||||
return depths |
||||
} |
||||
|
||||
// ToDepthArrays converts the result to the response format: array of arrays.
|
||||
// Index 0 = depth 1 pubkeys, Index 1 = depth 2 pubkeys, etc.
|
||||
// Empty arrays are included for depths with no pubkeys to maintain index alignment.
|
||||
func (r *GraphResult) ToDepthArrays() [][]string { |
||||
if len(r.PubkeysByDepth) == 0 { |
||||
return [][]string{} |
||||
} |
||||
|
||||
// Find the maximum depth
|
||||
maxDepth := 0 |
||||
for d := range r.PubkeysByDepth { |
||||
if d > maxDepth { |
||||
maxDepth = d |
||||
} |
||||
} |
||||
|
||||
// Create result array with entries for each depth
|
||||
result := make([][]string, maxDepth) |
||||
for i := 0; i < maxDepth; i++ { |
||||
depth := i + 1 // depths are 1-indexed
|
||||
if pubkeys, exists := r.PubkeysByDepth[depth]; exists { |
||||
result[i] = pubkeys |
||||
} else { |
||||
result[i] = []string{} // Empty array for depths with no pubkeys
|
||||
} |
||||
} |
||||
return result |
||||
} |
||||
|
||||
// ToEventDepthArrays converts event results to the response format: array of arrays.
|
||||
// Index 0 = depth 1 events, Index 1 = depth 2 events, etc.
|
||||
func (r *GraphResult) ToEventDepthArrays() [][]string { |
||||
if len(r.EventsByDepth) == 0 { |
||||
return [][]string{} |
||||
} |
||||
|
||||
maxDepth := 0 |
||||
for d := range r.EventsByDepth { |
||||
if d > maxDepth { |
||||
maxDepth = d |
||||
} |
||||
} |
||||
|
||||
result := make([][]string, maxDepth) |
||||
for i := 0; i < maxDepth; i++ { |
||||
depth := i + 1 |
||||
if events, exists := r.EventsByDepth[depth]; exists { |
||||
result[i] = events |
||||
} else { |
||||
result[i] = []string{} |
||||
} |
||||
} |
||||
return result |
||||
} |
||||
|
||||
// AddInboundRef records an inbound reference from a referencing event to a target.
|
||||
func (r *GraphResult) AddInboundRef(kind uint16, targetIDHex string, referencingEventIDHex string) { |
||||
if r.InboundRefs[kind] == nil { |
||||
r.InboundRefs[kind] = make(map[string][]string) |
||||
} |
||||
r.InboundRefs[kind][targetIDHex] = append(r.InboundRefs[kind][targetIDHex], referencingEventIDHex) |
||||
} |
||||
|
||||
// AddOutboundRef records an outbound reference from a source event to a referenced event.
|
||||
func (r *GraphResult) AddOutboundRef(kind uint16, sourceIDHex string, referencedEventIDHex string) { |
||||
if r.OutboundRefs[kind] == nil { |
||||
r.OutboundRefs[kind] = make(map[string][]string) |
||||
} |
||||
r.OutboundRefs[kind][sourceIDHex] = append(r.OutboundRefs[kind][sourceIDHex], referencedEventIDHex) |
||||
} |
||||
|
||||
// RefAggregation represents aggregated reference data for a single target/source.
|
||||
type RefAggregation struct { |
||||
// TargetEventID is the event ID being referenced (for inbound) or referencing (for outbound)
|
||||
TargetEventID string |
||||
|
||||
// TargetAuthor is the author pubkey of the target event (if known)
|
||||
TargetAuthor string |
||||
|
||||
// TargetDepth is the depth at which this target was discovered in the graph
|
||||
TargetDepth int |
||||
|
||||
// RefKind is the kind of the referencing events
|
||||
RefKind uint16 |
||||
|
||||
// RefCount is the number of references to/from this target
|
||||
RefCount int |
||||
|
||||
// RefEventIDs is the list of event IDs that reference this target
|
||||
RefEventIDs []string |
||||
} |
||||
|
||||
// GetInboundRefsSorted returns inbound refs for a kind, sorted by count descending.
|
||||
func (r *GraphResult) GetInboundRefsSorted(kind uint16) []RefAggregation { |
||||
kindRefs := r.InboundRefs[kind] |
||||
if kindRefs == nil { |
||||
return nil |
||||
} |
||||
|
||||
aggs := make([]RefAggregation, 0, len(kindRefs)) |
||||
for targetID, refs := range kindRefs { |
||||
agg := RefAggregation{ |
||||
TargetEventID: targetID, |
||||
TargetDepth: r.GetEventDepth(targetID), |
||||
RefKind: kind, |
||||
RefCount: len(refs), |
||||
RefEventIDs: refs, |
||||
} |
||||
aggs = append(aggs, agg) |
||||
} |
||||
|
||||
// Sort by count descending
|
||||
sort.Slice(aggs, func(i, j int) bool { |
||||
return aggs[i].RefCount > aggs[j].RefCount |
||||
}) |
||||
|
||||
return aggs |
||||
} |
||||
|
||||
// GetOutboundRefsSorted returns outbound refs for a kind, sorted by count descending.
|
||||
func (r *GraphResult) GetOutboundRefsSorted(kind uint16) []RefAggregation { |
||||
kindRefs := r.OutboundRefs[kind] |
||||
if kindRefs == nil { |
||||
return nil |
||||
} |
||||
|
||||
aggs := make([]RefAggregation, 0, len(kindRefs)) |
||||
for sourceID, refs := range kindRefs { |
||||
agg := RefAggregation{ |
||||
TargetEventID: sourceID, |
||||
TargetDepth: r.GetEventDepth(sourceID), |
||||
RefKind: kind, |
||||
RefCount: len(refs), |
||||
RefEventIDs: refs, |
||||
} |
||||
aggs = append(aggs, agg) |
||||
} |
||||
|
||||
sort.Slice(aggs, func(i, j int) bool { |
||||
return aggs[i].RefCount > aggs[j].RefCount |
||||
}) |
||||
|
||||
return aggs |
||||
} |
||||
|
||||
// GetAllPubkeys returns all pubkeys discovered across all depths.
|
||||
func (r *GraphResult) GetAllPubkeys() []string { |
||||
all := make([]string, 0, r.TotalPubkeys) |
||||
for _, pubkeys := range r.PubkeysByDepth { |
||||
all = append(all, pubkeys...) |
||||
} |
||||
return all |
||||
} |
||||
|
||||
// GetAllEvents returns all event IDs discovered across all depths.
|
||||
func (r *GraphResult) GetAllEvents() []string { |
||||
all := make([]string, 0, r.TotalEvents) |
||||
for _, events := range r.EventsByDepth { |
||||
all = append(all, events...) |
||||
} |
||||
return all |
||||
} |
||||
|
||||
// GetPubkeysAtDepth returns pubkeys at a specific depth, or empty slice if none.
|
||||
func (r *GraphResult) GetPubkeysAtDepth(depth int) []string { |
||||
if pubkeys, exists := r.PubkeysByDepth[depth]; exists { |
||||
return pubkeys |
||||
} |
||||
return []string{} |
||||
} |
||||
|
||||
// GetEventsAtDepth returns events at a specific depth, or empty slice if none.
|
||||
func (r *GraphResult) GetEventsAtDepth(depth int) []string { |
||||
if events, exists := r.EventsByDepth[depth]; exists { |
||||
return events |
||||
} |
||||
return []string{} |
||||
} |
||||
|
||||
// Interface methods for external package access (e.g., pkg/protocol/graph)
|
||||
// These allow the graph executor to extract data without direct struct access.
|
||||
|
||||
// GetPubkeysByDepth returns the PubkeysByDepth map for external access.
|
||||
func (r *GraphResult) GetPubkeysByDepth() map[int][]string { |
||||
return r.PubkeysByDepth |
||||
} |
||||
|
||||
// GetEventsByDepth returns the EventsByDepth map for external access.
|
||||
func (r *GraphResult) GetEventsByDepth() map[int][]string { |
||||
return r.EventsByDepth |
||||
} |
||||
|
||||
// GetTotalPubkeys returns the total pubkey count for external access.
|
||||
func (r *GraphResult) GetTotalPubkeys() int { |
||||
return r.TotalPubkeys |
||||
} |
||||
|
||||
// GetTotalEvents returns the total event count for external access.
|
||||
func (r *GraphResult) GetTotalEvents() int { |
||||
return r.TotalEvents |
||||
} |
||||
@ -0,0 +1,191 @@
@@ -0,0 +1,191 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database |
||||
|
||||
import ( |
||||
"lol.mleku.dev/log" |
||||
"next.orly.dev/pkg/database/indexes/types" |
||||
"git.mleku.dev/mleku/nostr/encoders/hex" |
||||
) |
||||
|
||||
// TraverseThread performs BFS traversal of thread structure via e-tags.
|
||||
// Starting from a seed event, it finds all replies/references at each depth.
|
||||
//
|
||||
// The traversal works bidirectionally:
|
||||
// - Forward: Events that the seed references (parents, quoted posts)
|
||||
// - Backward: Events that reference the seed (replies, reactions, reposts)
|
||||
//
|
||||
// Parameters:
|
||||
// - seedEventID: The event ID to start traversal from
|
||||
// - maxDepth: Maximum depth to traverse
|
||||
// - direction: "both" (default), "inbound" (replies to seed), "outbound" (seed's references)
|
||||
func (d *D) TraverseThread(seedEventID []byte, maxDepth int, direction string) (*GraphResult, error) { |
||||
result := NewGraphResult() |
||||
|
||||
if len(seedEventID) != 32 { |
||||
return result, ErrEventNotFound |
||||
} |
||||
|
||||
// Get seed event serial
|
||||
seedSerial, err := d.GetSerialById(seedEventID) |
||||
if err != nil { |
||||
log.D.F("TraverseThread: seed event not in database: %s", hex.Enc(seedEventID)) |
||||
return result, nil |
||||
} |
||||
|
||||
// Normalize direction
|
||||
if direction == "" { |
||||
direction = "both" |
||||
} |
||||
|
||||
// Track visited events
|
||||
visited := make(map[uint64]bool) |
||||
visited[seedSerial.Get()] = true |
||||
|
||||
// Current frontier
|
||||
currentFrontier := []*types.Uint40{seedSerial} |
||||
|
||||
consecutiveEmptyDepths := 0 |
||||
|
||||
for currentDepth := 1; currentDepth <= maxDepth; currentDepth++ { |
||||
var nextFrontier []*types.Uint40 |
||||
newEventsAtDepth := 0 |
||||
|
||||
for _, eventSerial := range currentFrontier { |
||||
// Get inbound references (events that reference this event)
|
||||
if direction == "both" || direction == "inbound" { |
||||
inboundSerials, err := d.GetReferencingEvents(eventSerial, nil) |
||||
if err != nil { |
||||
log.D.F("TraverseThread: error getting inbound refs for serial %d: %v", eventSerial.Get(), err) |
||||
} else { |
||||
for _, refSerial := range inboundSerials { |
||||
if visited[refSerial.Get()] { |
||||
continue |
||||
} |
||||
visited[refSerial.Get()] = true |
||||
|
||||
eventIDHex, err := d.GetEventIDFromSerial(refSerial) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
result.AddEventAtDepth(eventIDHex, currentDepth) |
||||
newEventsAtDepth++ |
||||
nextFrontier = append(nextFrontier, refSerial) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Get outbound references (events this event references)
|
||||
if direction == "both" || direction == "outbound" { |
||||
outboundSerials, err := d.GetETagsFromEventSerial(eventSerial) |
||||
if err != nil { |
||||
log.D.F("TraverseThread: error getting outbound refs for serial %d: %v", eventSerial.Get(), err) |
||||
} else { |
||||
for _, refSerial := range outboundSerials { |
||||
if visited[refSerial.Get()] { |
||||
continue |
||||
} |
||||
visited[refSerial.Get()] = true |
||||
|
||||
eventIDHex, err := d.GetEventIDFromSerial(refSerial) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
result.AddEventAtDepth(eventIDHex, currentDepth) |
||||
newEventsAtDepth++ |
||||
nextFrontier = append(nextFrontier, refSerial) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
log.T.F("TraverseThread: depth %d found %d new events", currentDepth, newEventsAtDepth) |
||||
|
||||
if newEventsAtDepth == 0 { |
||||
consecutiveEmptyDepths++ |
||||
if consecutiveEmptyDepths >= 2 { |
||||
break |
||||
} |
||||
} else { |
||||
consecutiveEmptyDepths = 0 |
||||
} |
||||
|
||||
currentFrontier = nextFrontier |
||||
} |
||||
|
||||
log.D.F("TraverseThread: completed with %d total events", result.TotalEvents) |
||||
|
||||
return result, nil |
||||
} |
||||
|
||||
// TraverseThreadFromHex is a convenience wrapper that accepts hex-encoded event ID.
|
||||
func (d *D) TraverseThreadFromHex(seedEventIDHex string, maxDepth int, direction string) (*GraphResult, error) { |
||||
seedEventID, err := hex.Dec(seedEventIDHex) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return d.TraverseThread(seedEventID, maxDepth, direction) |
||||
} |
||||
|
||||
// GetThreadReplies finds all direct replies to an event.
|
||||
// This is a convenience method that returns events at depth 1 with inbound direction.
|
||||
func (d *D) GetThreadReplies(eventID []byte, kinds []uint16) (*GraphResult, error) { |
||||
result := NewGraphResult() |
||||
|
||||
if len(eventID) != 32 { |
||||
return result, ErrEventNotFound |
||||
} |
||||
|
||||
eventSerial, err := d.GetSerialById(eventID) |
||||
if err != nil { |
||||
return result, nil |
||||
} |
||||
|
||||
// Get events that reference this event
|
||||
replySerials, err := d.GetReferencingEvents(eventSerial, kinds) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
for _, replySerial := range replySerials { |
||||
eventIDHex, err := d.GetEventIDFromSerial(replySerial) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
result.AddEventAtDepth(eventIDHex, 1) |
||||
} |
||||
|
||||
return result, nil |
||||
} |
||||
|
||||
// GetThreadParents finds events that a given event references (its parents/quotes).
|
||||
func (d *D) GetThreadParents(eventID []byte) (*GraphResult, error) { |
||||
result := NewGraphResult() |
||||
|
||||
if len(eventID) != 32 { |
||||
return result, ErrEventNotFound |
||||
} |
||||
|
||||
eventSerial, err := d.GetSerialById(eventID) |
||||
if err != nil { |
||||
return result, nil |
||||
} |
||||
|
||||
// Get events that this event references
|
||||
parentSerials, err := d.GetETagsFromEventSerial(eventSerial) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
for _, parentSerial := range parentSerials { |
||||
eventIDHex, err := d.GetEventIDFromSerial(parentSerial) |
||||
if err != nil { |
||||
continue |
||||
} |
||||
result.AddEventAtDepth(eventIDHex, 1) |
||||
} |
||||
|
||||
return result, nil |
||||
} |
||||
@ -0,0 +1,560 @@
@@ -0,0 +1,560 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database |
||||
|
||||
import ( |
||||
"bytes" |
||||
"errors" |
||||
|
||||
"github.com/dgraph-io/badger/v4" |
||||
"lol.mleku.dev/chk" |
||||
"lol.mleku.dev/log" |
||||
"next.orly.dev/pkg/database/indexes" |
||||
"next.orly.dev/pkg/database/indexes/types" |
||||
"git.mleku.dev/mleku/nostr/encoders/hex" |
||||
) |
||||
|
||||
// Graph traversal errors
|
||||
var ( |
||||
ErrPubkeyNotFound = errors.New("pubkey not found in database") |
||||
ErrEventNotFound = errors.New("event not found in database") |
||||
) |
||||
|
||||
// GetPTagsFromEventSerial extracts p-tag pubkey serials from an event by its serial.
|
||||
// This is a pure index-based operation - no event decoding required.
|
||||
// It scans the epg (event-pubkey-graph) index for p-tag edges.
|
||||
func (d *D) GetPTagsFromEventSerial(eventSerial *types.Uint40) ([]*types.Uint40, error) { |
||||
var pubkeySerials []*types.Uint40 |
||||
|
||||
// Build prefix: epg|event_serial
|
||||
prefix := new(bytes.Buffer) |
||||
prefix.Write([]byte(indexes.EventPubkeyGraphPrefix)) |
||||
if err := eventSerial.MarshalWrite(prefix); chk.E(err) { |
||||
return nil, err |
||||
} |
||||
searchPrefix := prefix.Bytes() |
||||
|
||||
err := d.View(func(txn *badger.Txn) error { |
||||
opts := badger.DefaultIteratorOptions |
||||
opts.PrefetchValues = false |
||||
opts.Prefix = searchPrefix |
||||
|
||||
it := txn.NewIterator(opts) |
||||
defer it.Close() |
||||
|
||||
for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() { |
||||
key := it.Item().KeyCopy(nil) |
||||
|
||||
// Decode key: epg(3)|event_serial(5)|pubkey_serial(5)|kind(2)|direction(1)
|
||||
if len(key) != 16 { |
||||
continue |
||||
} |
||||
|
||||
// Extract direction to filter for p-tags only
|
||||
direction := key[15] |
||||
if direction != types.EdgeDirectionPTagOut { |
||||
continue // Skip author edges, only want p-tag edges
|
||||
} |
||||
|
||||
// Extract pubkey serial (bytes 8-12)
|
||||
pubkeySerial := new(types.Uint40) |
||||
serialReader := bytes.NewReader(key[8:13]) |
||||
if err := pubkeySerial.UnmarshalRead(serialReader); chk.E(err) { |
||||
continue |
||||
} |
||||
|
||||
pubkeySerials = append(pubkeySerials, pubkeySerial) |
||||
} |
||||
return nil |
||||
}) |
||||
|
||||
return pubkeySerials, err |
||||
} |
||||
|
||||
// GetETagsFromEventSerial extracts e-tag event serials from an event by its serial.
|
||||
// This is a pure index-based operation - no event decoding required.
|
||||
// It scans the eeg (event-event-graph) index for outbound e-tag edges.
|
||||
func (d *D) GetETagsFromEventSerial(eventSerial *types.Uint40) ([]*types.Uint40, error) { |
||||
var targetSerials []*types.Uint40 |
||||
|
||||
// Build prefix: eeg|source_event_serial
|
||||
prefix := new(bytes.Buffer) |
||||
prefix.Write([]byte(indexes.EventEventGraphPrefix)) |
||||
if err := eventSerial.MarshalWrite(prefix); chk.E(err) { |
||||
return nil, err |
||||
} |
||||
searchPrefix := prefix.Bytes() |
||||
|
||||
err := d.View(func(txn *badger.Txn) error { |
||||
opts := badger.DefaultIteratorOptions |
||||
opts.PrefetchValues = false |
||||
opts.Prefix = searchPrefix |
||||
|
||||
it := txn.NewIterator(opts) |
||||
defer it.Close() |
||||
|
||||
for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() { |
||||
key := it.Item().KeyCopy(nil) |
||||
|
||||
// Decode key: eeg(3)|source_serial(5)|target_serial(5)|kind(2)|direction(1)
|
||||
if len(key) != 16 { |
||||
continue |
||||
} |
||||
|
||||
// Extract target serial (bytes 8-12)
|
||||
targetSerial := new(types.Uint40) |
||||
serialReader := bytes.NewReader(key[8:13]) |
||||
if err := targetSerial.UnmarshalRead(serialReader); chk.E(err) { |
||||
continue |
||||
} |
||||
|
||||
targetSerials = append(targetSerials, targetSerial) |
||||
} |
||||
return nil |
||||
}) |
||||
|
||||
return targetSerials, err |
||||
} |
||||
|
||||
// GetReferencingEvents finds all events that reference a target event via e-tags.
|
||||
// Optionally filters by event kinds. Uses the gee (reverse e-tag graph) index.
|
||||
func (d *D) GetReferencingEvents(targetSerial *types.Uint40, kinds []uint16) ([]*types.Uint40, error) { |
||||
var sourceSerials []*types.Uint40 |
||||
|
||||
if len(kinds) == 0 { |
||||
// No kind filter - scan all kinds
|
||||
prefix := new(bytes.Buffer) |
||||
prefix.Write([]byte(indexes.GraphEventEventPrefix)) |
||||
if err := targetSerial.MarshalWrite(prefix); chk.E(err) { |
||||
return nil, err |
||||
} |
||||
searchPrefix := prefix.Bytes() |
||||
|
||||
err := d.View(func(txn *badger.Txn) error { |
||||
opts := badger.DefaultIteratorOptions |
||||
opts.PrefetchValues = false |
||||
opts.Prefix = searchPrefix |
||||
|
||||
it := txn.NewIterator(opts) |
||||
defer it.Close() |
||||
|
||||
for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() { |
||||
key := it.Item().KeyCopy(nil) |
||||
|
||||
// Decode key: gee(3)|target_serial(5)|kind(2)|direction(1)|source_serial(5)
|
||||
if len(key) != 16 { |
||||
continue |
||||
} |
||||
|
||||
// Extract source serial (bytes 11-15)
|
||||
sourceSerial := new(types.Uint40) |
||||
serialReader := bytes.NewReader(key[11:16]) |
||||
if err := sourceSerial.UnmarshalRead(serialReader); chk.E(err) { |
||||
continue |
||||
} |
||||
|
||||
sourceSerials = append(sourceSerials, sourceSerial) |
||||
} |
||||
return nil |
||||
}) |
||||
return sourceSerials, err |
||||
} |
||||
|
||||
// With kind filter - scan each kind's prefix
|
||||
for _, k := range kinds { |
||||
kind := new(types.Uint16) |
||||
kind.Set(k) |
||||
|
||||
direction := new(types.Letter) |
||||
direction.Set(types.EdgeDirectionETagIn) |
||||
|
||||
prefix := new(bytes.Buffer) |
||||
if err := indexes.GraphEventEventEnc(targetSerial, kind, direction, nil).MarshalWrite(prefix); chk.E(err) { |
||||
return nil, err |
||||
} |
||||
searchPrefix := prefix.Bytes() |
||||
|
||||
err := d.View(func(txn *badger.Txn) error { |
||||
opts := badger.DefaultIteratorOptions |
||||
opts.PrefetchValues = false |
||||
opts.Prefix = searchPrefix |
||||
|
||||
it := txn.NewIterator(opts) |
||||
defer it.Close() |
||||
|
||||
for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() { |
||||
key := it.Item().KeyCopy(nil) |
||||
|
||||
// Extract source serial (last 5 bytes)
|
||||
if len(key) < 5 { |
||||
continue |
||||
} |
||||
sourceSerial := new(types.Uint40) |
||||
serialReader := bytes.NewReader(key[len(key)-5:]) |
||||
if err := sourceSerial.UnmarshalRead(serialReader); chk.E(err) { |
||||
continue |
||||
} |
||||
|
||||
sourceSerials = append(sourceSerials, sourceSerial) |
||||
} |
||||
return nil |
||||
}) |
||||
if chk.E(err) { |
||||
return nil, err |
||||
} |
||||
} |
||||
|
||||
return sourceSerials, nil |
||||
} |
||||
|
||||
// FindEventByAuthorAndKind finds the most recent event of a specific kind by an author.
|
||||
// This is used to find kind-3 contact lists for follow graph traversal.
|
||||
// Returns nil, nil if no matching event is found.
|
||||
func (d *D) FindEventByAuthorAndKind(authorSerial *types.Uint40, kind uint16) (*types.Uint40, error) { |
||||
var eventSerial *types.Uint40 |
||||
|
||||
// First, get the full pubkey from the serial
|
||||
pubkey, err := d.GetPubkeyBySerial(authorSerial) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// Build prefix for kind-pubkey index: kpc|kind|pubkey_hash
|
||||
pubHash := new(types.PubHash) |
||||
if err := pubHash.FromPubkey(pubkey); chk.E(err) { |
||||
return nil, err |
||||
} |
||||
|
||||
kindType := new(types.Uint16) |
||||
kindType.Set(kind) |
||||
|
||||
prefix := new(bytes.Buffer) |
||||
prefix.Write([]byte(indexes.KindPubkeyPrefix)) |
||||
if err := kindType.MarshalWrite(prefix); chk.E(err) { |
||||
return nil, err |
||||
} |
||||
if err := pubHash.MarshalWrite(prefix); chk.E(err) { |
||||
return nil, err |
||||
} |
||||
searchPrefix := prefix.Bytes() |
||||
|
||||
err = d.View(func(txn *badger.Txn) error { |
||||
opts := badger.DefaultIteratorOptions |
||||
opts.PrefetchValues = false |
||||
opts.Prefix = searchPrefix |
||||
opts.Reverse = true // Most recent first (highest created_at)
|
||||
|
||||
it := txn.NewIterator(opts) |
||||
defer it.Close() |
||||
|
||||
// Seek to end of prefix range for reverse iteration
|
||||
seekKey := make([]byte, len(searchPrefix)+8+5) // prefix + max timestamp + max serial
|
||||
copy(seekKey, searchPrefix) |
||||
for i := len(searchPrefix); i < len(seekKey); i++ { |
||||
seekKey[i] = 0xFF |
||||
} |
||||
|
||||
it.Seek(seekKey) |
||||
if !it.ValidForPrefix(searchPrefix) { |
||||
// Try going to the first valid key if seek went past
|
||||
it.Rewind() |
||||
it.Seek(searchPrefix) |
||||
} |
||||
|
||||
if it.ValidForPrefix(searchPrefix) { |
||||
key := it.Item().KeyCopy(nil) |
||||
|
||||
// Decode key: kpc(3)|kind(2)|pubkey_hash(8)|created_at(8)|serial(5)
|
||||
// Total: 26 bytes
|
||||
if len(key) < 26 { |
||||
return nil |
||||
} |
||||
|
||||
// Extract serial (last 5 bytes)
|
||||
eventSerial = new(types.Uint40) |
||||
serialReader := bytes.NewReader(key[len(key)-5:]) |
||||
if err := eventSerial.UnmarshalRead(serialReader); chk.E(err) { |
||||
return err |
||||
} |
||||
} |
||||
return nil |
||||
}) |
||||
|
||||
return eventSerial, err |
||||
} |
||||
|
||||
// GetPubkeyHexFromSerial converts a pubkey serial to its hex string representation.
|
||||
func (d *D) GetPubkeyHexFromSerial(serial *types.Uint40) (string, error) { |
||||
pubkey, err := d.GetPubkeyBySerial(serial) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
return hex.Enc(pubkey), nil |
||||
} |
||||
|
||||
// GetEventIDFromSerial converts an event serial to its hex ID string.
|
||||
func (d *D) GetEventIDFromSerial(serial *types.Uint40) (string, error) { |
||||
eventID, err := d.GetEventIdBySerial(serial) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
return hex.Enc(eventID), nil |
||||
} |
||||
|
||||
// GetEventsReferencingPubkey finds all events that reference a pubkey via p-tags.
|
||||
// Uses the peg (pubkey-event-graph) index with direction filter for inbound p-tags.
|
||||
// Optionally filters by event kinds.
|
||||
func (d *D) GetEventsReferencingPubkey(pubkeySerial *types.Uint40, kinds []uint16) ([]*types.Uint40, error) { |
||||
var eventSerials []*types.Uint40 |
||||
|
||||
if len(kinds) == 0 { |
||||
// No kind filter - we need to scan common kinds since direction comes after kind in the key
|
||||
// Use same approach as QueryPTagGraph
|
||||
commonKinds := []uint16{1, 6, 7, 9735, 10002, 3, 4, 5, 30023} |
||||
kinds = commonKinds |
||||
} |
||||
|
||||
for _, k := range kinds { |
||||
kind := new(types.Uint16) |
||||
kind.Set(k) |
||||
|
||||
direction := new(types.Letter) |
||||
direction.Set(types.EdgeDirectionPTagIn) // Inbound p-tags
|
||||
|
||||
prefix := new(bytes.Buffer) |
||||
if err := indexes.PubkeyEventGraphEnc(pubkeySerial, kind, direction, nil).MarshalWrite(prefix); chk.E(err) { |
||||
return nil, err |
||||
} |
||||
searchPrefix := prefix.Bytes() |
||||
|
||||
err := d.View(func(txn *badger.Txn) error { |
||||
opts := badger.DefaultIteratorOptions |
||||
opts.PrefetchValues = false |
||||
opts.Prefix = searchPrefix |
||||
|
||||
it := txn.NewIterator(opts) |
||||
defer it.Close() |
||||
|
||||
for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() { |
||||
key := it.Item().KeyCopy(nil) |
||||
|
||||
// Key format: peg(3)|pubkey_serial(5)|kind(2)|direction(1)|event_serial(5) = 16 bytes
|
||||
if len(key) != 16 { |
||||
continue |
||||
} |
||||
|
||||
// Extract event serial (last 5 bytes)
|
||||
eventSerial := new(types.Uint40) |
||||
serialReader := bytes.NewReader(key[11:16]) |
||||
if err := eventSerial.UnmarshalRead(serialReader); chk.E(err) { |
||||
continue |
||||
} |
||||
|
||||
eventSerials = append(eventSerials, eventSerial) |
||||
} |
||||
return nil |
||||
}) |
||||
if chk.E(err) { |
||||
return nil, err |
||||
} |
||||
} |
||||
|
||||
return eventSerials, nil |
||||
} |
||||
|
||||
// GetEventsByAuthor finds all events authored by a pubkey.
|
||||
// Uses the peg (pubkey-event-graph) index with direction filter for author edges.
|
||||
// Optionally filters by event kinds.
|
||||
func (d *D) GetEventsByAuthor(authorSerial *types.Uint40, kinds []uint16) ([]*types.Uint40, error) { |
||||
var eventSerials []*types.Uint40 |
||||
|
||||
if len(kinds) == 0 { |
||||
// No kind filter - scan for author direction across common kinds
|
||||
// This is less efficient but necessary without kind filter
|
||||
commonKinds := []uint16{0, 1, 3, 6, 7, 30023, 10002} |
||||
kinds = commonKinds |
||||
} |
||||
|
||||
for _, k := range kinds { |
||||
kind := new(types.Uint16) |
||||
kind.Set(k) |
||||
|
||||
direction := new(types.Letter) |
||||
direction.Set(types.EdgeDirectionAuthor) // Author edges
|
||||
|
||||
prefix := new(bytes.Buffer) |
||||
if err := indexes.PubkeyEventGraphEnc(authorSerial, kind, direction, nil).MarshalWrite(prefix); chk.E(err) { |
||||
return nil, err |
||||
} |
||||
searchPrefix := prefix.Bytes() |
||||
|
||||
err := d.View(func(txn *badger.Txn) error { |
||||
opts := badger.DefaultIteratorOptions |
||||
opts.PrefetchValues = false |
||||
opts.Prefix = searchPrefix |
||||
|
||||
it := txn.NewIterator(opts) |
||||
defer it.Close() |
||||
|
||||
for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() { |
||||
key := it.Item().KeyCopy(nil) |
||||
|
||||
// Key format: peg(3)|pubkey_serial(5)|kind(2)|direction(1)|event_serial(5) = 16 bytes
|
||||
if len(key) != 16 { |
||||
continue |
||||
} |
||||
|
||||
// Extract event serial (last 5 bytes)
|
||||
eventSerial := new(types.Uint40) |
||||
serialReader := bytes.NewReader(key[11:16]) |
||||
if err := eventSerial.UnmarshalRead(serialReader); chk.E(err) { |
||||
continue |
||||
} |
||||
|
||||
eventSerials = append(eventSerials, eventSerial) |
||||
} |
||||
return nil |
||||
}) |
||||
if chk.E(err) { |
||||
return nil, err |
||||
} |
||||
} |
||||
|
||||
return eventSerials, nil |
||||
} |
||||
|
||||
// GetFollowsFromPubkeySerial returns the pubkey serials that a user follows.
|
||||
// This extracts p-tags from the user's kind-3 contact list event.
|
||||
// Returns an empty slice if no kind-3 event is found.
|
||||
func (d *D) GetFollowsFromPubkeySerial(pubkeySerial *types.Uint40) ([]*types.Uint40, error) { |
||||
// Find the kind-3 event for this pubkey
|
||||
contactEventSerial, err := d.FindEventByAuthorAndKind(pubkeySerial, 3) |
||||
if err != nil { |
||||
log.D.F("GetFollowsFromPubkeySerial: error finding kind-3 for serial %d: %v", pubkeySerial.Get(), err) |
||||
return nil, nil // No kind-3 event found is not an error
|
||||
} |
||||
if contactEventSerial == nil { |
||||
log.T.F("GetFollowsFromPubkeySerial: no kind-3 event found for serial %d", pubkeySerial.Get()) |
||||
return nil, nil |
||||
} |
||||
|
||||
// Extract p-tags from the contact list event
|
||||
follows, err := d.GetPTagsFromEventSerial(contactEventSerial) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
log.T.F("GetFollowsFromPubkeySerial: found %d follows for serial %d", len(follows), pubkeySerial.Get()) |
||||
return follows, nil |
||||
} |
||||
|
||||
// GetFollowersOfPubkeySerial returns the pubkey serials of users who follow a given pubkey.
|
||||
// This finds all kind-3 events that have a p-tag referencing the target pubkey.
|
||||
func (d *D) GetFollowersOfPubkeySerial(targetSerial *types.Uint40) ([]*types.Uint40, error) { |
||||
// Find all kind-3 events that reference this pubkey via p-tag
|
||||
kind3Events, err := d.GetEventsReferencingPubkey(targetSerial, []uint16{3}) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// Extract the author serials from these events
|
||||
var followerSerials []*types.Uint40 |
||||
seen := make(map[uint64]bool) |
||||
|
||||
for _, eventSerial := range kind3Events { |
||||
// Get the author of this kind-3 event
|
||||
// We need to look up the event to get its author
|
||||
// Use the epg index to find the author edge
|
||||
authorSerial, err := d.GetEventAuthorSerial(eventSerial) |
||||
if err != nil { |
||||
log.D.F("GetFollowersOfPubkeySerial: couldn't get author for event %d: %v", eventSerial.Get(), err) |
||||
continue |
||||
} |
||||
|
||||
// Deduplicate (a user might have multiple kind-3 events)
|
||||
if seen[authorSerial.Get()] { |
||||
continue |
||||
} |
||||
seen[authorSerial.Get()] = true |
||||
followerSerials = append(followerSerials, authorSerial) |
||||
} |
||||
|
||||
log.T.F("GetFollowersOfPubkeySerial: found %d followers for serial %d", len(followerSerials), targetSerial.Get()) |
||||
return followerSerials, nil |
||||
} |
||||
|
||||
// GetEventAuthorSerial finds the author pubkey serial for an event.
|
||||
// Uses the epg (event-pubkey-graph) index with author direction.
|
||||
func (d *D) GetEventAuthorSerial(eventSerial *types.Uint40) (*types.Uint40, error) { |
||||
var authorSerial *types.Uint40 |
||||
|
||||
// Build prefix: epg|event_serial
|
||||
prefix := new(bytes.Buffer) |
||||
prefix.Write([]byte(indexes.EventPubkeyGraphPrefix)) |
||||
if err := eventSerial.MarshalWrite(prefix); chk.E(err) { |
||||
return nil, err |
||||
} |
||||
searchPrefix := prefix.Bytes() |
||||
|
||||
err := d.View(func(txn *badger.Txn) error { |
||||
opts := badger.DefaultIteratorOptions |
||||
opts.PrefetchValues = false |
||||
opts.Prefix = searchPrefix |
||||
|
||||
it := txn.NewIterator(opts) |
||||
defer it.Close() |
||||
|
||||
for it.Seek(searchPrefix); it.ValidForPrefix(searchPrefix); it.Next() { |
||||
key := it.Item().KeyCopy(nil) |
||||
|
||||
// Decode key: epg(3)|event_serial(5)|pubkey_serial(5)|kind(2)|direction(1)
|
||||
if len(key) != 16 { |
||||
continue |
||||
} |
||||
|
||||
// Check direction - we want author (0)
|
||||
direction := key[15] |
||||
if direction != types.EdgeDirectionAuthor { |
||||
continue |
||||
} |
||||
|
||||
// Extract pubkey serial (bytes 8-12)
|
||||
authorSerial = new(types.Uint40) |
||||
serialReader := bytes.NewReader(key[8:13]) |
||||
if err := authorSerial.UnmarshalRead(serialReader); chk.E(err) { |
||||
continue |
||||
} |
||||
|
||||
return nil // Found the author
|
||||
} |
||||
return ErrEventNotFound |
||||
}) |
||||
|
||||
return authorSerial, err |
||||
} |
||||
|
||||
// PubkeyHexToSerial converts a pubkey hex string to its serial, if it exists.
|
||||
// Returns an error if the pubkey is not in the database.
|
||||
func (d *D) PubkeyHexToSerial(pubkeyHex string) (*types.Uint40, error) { |
||||
pubkeyBytes, err := hex.Dec(pubkeyHex) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if len(pubkeyBytes) != 32 { |
||||
return nil, errors.New("invalid pubkey length") |
||||
} |
||||
return d.GetPubkeySerial(pubkeyBytes) |
||||
} |
||||
|
||||
// EventIDHexToSerial converts an event ID hex string to its serial, if it exists.
|
||||
// Returns an error if the event is not in the database.
|
||||
func (d *D) EventIDHexToSerial(eventIDHex string) (*types.Uint40, error) { |
||||
eventIDBytes, err := hex.Dec(eventIDHex) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if len(eventIDBytes) != 32 { |
||||
return nil, errors.New("invalid event ID length") |
||||
} |
||||
return d.GetSerialById(eventIDBytes) |
||||
} |
||||
@ -0,0 +1,547 @@
@@ -0,0 +1,547 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package database |
||||
|
||||
import ( |
||||
"context" |
||||
"testing" |
||||
|
||||
"git.mleku.dev/mleku/nostr/encoders/event" |
||||
"git.mleku.dev/mleku/nostr/encoders/hex" |
||||
"git.mleku.dev/mleku/nostr/encoders/tag" |
||||
) |
||||
|
||||
func TestGetPTagsFromEventSerial(t *testing.T) { |
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
defer cancel() |
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info") |
||||
if err != nil { |
||||
t.Fatalf("Failed to create database: %v", err) |
||||
} |
||||
defer db.Close() |
||||
|
||||
// Create an author pubkey
|
||||
authorPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001") |
||||
|
||||
// Create p-tag target pubkeys
|
||||
target1, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002") |
||||
target2, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003") |
||||
|
||||
// Create event with p-tags
|
||||
eventID := make([]byte, 32) |
||||
eventID[0] = 0x10 |
||||
eventSig := make([]byte, 64) |
||||
eventSig[0] = 0x10 |
||||
|
||||
ev := &event.E{ |
||||
ID: eventID, |
||||
Pubkey: authorPubkey, |
||||
CreatedAt: 1234567890, |
||||
Kind: 1, |
||||
Content: []byte("Test event with p-tags"), |
||||
Sig: eventSig, |
||||
Tags: tag.NewS( |
||||
tag.NewFromAny("p", hex.Enc(target1)), |
||||
tag.NewFromAny("p", hex.Enc(target2)), |
||||
), |
||||
} |
||||
|
||||
_, err = db.SaveEvent(ctx, ev) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save event: %v", err) |
||||
} |
||||
|
||||
// Get the event serial
|
||||
eventSerial, err := db.GetSerialById(eventID) |
||||
if err != nil { |
||||
t.Fatalf("Failed to get event serial: %v", err) |
||||
} |
||||
|
||||
// Get p-tags from event serial
|
||||
ptagSerials, err := db.GetPTagsFromEventSerial(eventSerial) |
||||
if err != nil { |
||||
t.Fatalf("GetPTagsFromEventSerial failed: %v", err) |
||||
} |
||||
|
||||
// Should have 2 p-tags
|
||||
if len(ptagSerials) != 2 { |
||||
t.Errorf("Expected 2 p-tag serials, got %d", len(ptagSerials)) |
||||
} |
||||
|
||||
// Verify the pubkeys
|
||||
for _, serial := range ptagSerials { |
||||
pubkey, err := db.GetPubkeyBySerial(serial) |
||||
if err != nil { |
||||
t.Errorf("Failed to get pubkey for serial: %v", err) |
||||
continue |
||||
} |
||||
pubkeyHex := hex.Enc(pubkey) |
||||
if pubkeyHex != hex.Enc(target1) && pubkeyHex != hex.Enc(target2) { |
||||
t.Errorf("Unexpected pubkey: %s", pubkeyHex) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestGetETagsFromEventSerial(t *testing.T) { |
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
defer cancel() |
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info") |
||||
if err != nil { |
||||
t.Fatalf("Failed to create database: %v", err) |
||||
} |
||||
defer db.Close() |
||||
|
||||
// Create a parent event
|
||||
parentPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001") |
||||
parentID := make([]byte, 32) |
||||
parentID[0] = 0x10 |
||||
parentSig := make([]byte, 64) |
||||
parentSig[0] = 0x10 |
||||
|
||||
parentEvent := &event.E{ |
||||
ID: parentID, |
||||
Pubkey: parentPubkey, |
||||
CreatedAt: 1234567890, |
||||
Kind: 1, |
||||
Content: []byte("Parent post"), |
||||
Sig: parentSig, |
||||
Tags: &tag.S{}, |
||||
} |
||||
_, err = db.SaveEvent(ctx, parentEvent) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save parent event: %v", err) |
||||
} |
||||
|
||||
// Create a reply event with e-tag
|
||||
replyPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002") |
||||
replyID := make([]byte, 32) |
||||
replyID[0] = 0x20 |
||||
replySig := make([]byte, 64) |
||||
replySig[0] = 0x20 |
||||
|
||||
replyEvent := &event.E{ |
||||
ID: replyID, |
||||
Pubkey: replyPubkey, |
||||
CreatedAt: 1234567891, |
||||
Kind: 1, |
||||
Content: []byte("Reply"), |
||||
Sig: replySig, |
||||
Tags: tag.NewS( |
||||
tag.NewFromAny("e", hex.Enc(parentID)), |
||||
), |
||||
} |
||||
_, err = db.SaveEvent(ctx, replyEvent) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save reply event: %v", err) |
||||
} |
||||
|
||||
// Get e-tags from reply
|
||||
replySerial, _ := db.GetSerialById(replyID) |
||||
etagSerials, err := db.GetETagsFromEventSerial(replySerial) |
||||
if err != nil { |
||||
t.Fatalf("GetETagsFromEventSerial failed: %v", err) |
||||
} |
||||
|
||||
if len(etagSerials) != 1 { |
||||
t.Errorf("Expected 1 e-tag serial, got %d", len(etagSerials)) |
||||
} |
||||
|
||||
// Verify the target event
|
||||
if len(etagSerials) > 0 { |
||||
targetEventID, err := db.GetEventIdBySerial(etagSerials[0]) |
||||
if err != nil { |
||||
t.Fatalf("Failed to get event ID from serial: %v", err) |
||||
} |
||||
if hex.Enc(targetEventID) != hex.Enc(parentID) { |
||||
t.Errorf("Expected parent ID, got %s", hex.Enc(targetEventID)) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestGetReferencingEvents(t *testing.T) { |
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
defer cancel() |
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info") |
||||
if err != nil { |
||||
t.Fatalf("Failed to create database: %v", err) |
||||
} |
||||
defer db.Close() |
||||
|
||||
// Create a parent event
|
||||
parentPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001") |
||||
parentID := make([]byte, 32) |
||||
parentID[0] = 0x10 |
||||
parentSig := make([]byte, 64) |
||||
parentSig[0] = 0x10 |
||||
|
||||
parentEvent := &event.E{ |
||||
ID: parentID, |
||||
Pubkey: parentPubkey, |
||||
CreatedAt: 1234567890, |
||||
Kind: 1, |
||||
Content: []byte("Parent post"), |
||||
Sig: parentSig, |
||||
Tags: &tag.S{}, |
||||
} |
||||
_, err = db.SaveEvent(ctx, parentEvent) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save parent event: %v", err) |
||||
} |
||||
|
||||
// Create multiple replies and reactions
|
||||
for i := 0; i < 3; i++ { |
||||
replyPubkey := make([]byte, 32) |
||||
replyPubkey[0] = byte(0x20 + i) |
||||
replyID := make([]byte, 32) |
||||
replyID[0] = byte(0x30 + i) |
||||
replySig := make([]byte, 64) |
||||
replySig[0] = byte(0x30 + i) |
||||
|
||||
var evKind uint16 = 1 // Reply
|
||||
if i == 2 { |
||||
evKind = 7 // Reaction
|
||||
} |
||||
|
||||
replyEvent := &event.E{ |
||||
ID: replyID, |
||||
Pubkey: replyPubkey, |
||||
CreatedAt: int64(1234567891 + i), |
||||
Kind: evKind, |
||||
Content: []byte("Response"), |
||||
Sig: replySig, |
||||
Tags: tag.NewS( |
||||
tag.NewFromAny("e", hex.Enc(parentID)), |
||||
), |
||||
} |
||||
_, err = db.SaveEvent(ctx, replyEvent) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save reply %d: %v", i, err) |
||||
} |
||||
} |
||||
|
||||
// Get parent serial
|
||||
parentSerial, _ := db.GetSerialById(parentID) |
||||
|
||||
// Test without kind filter
|
||||
refs, err := db.GetReferencingEvents(parentSerial, nil) |
||||
if err != nil { |
||||
t.Fatalf("GetReferencingEvents failed: %v", err) |
||||
} |
||||
if len(refs) != 3 { |
||||
t.Errorf("Expected 3 referencing events, got %d", len(refs)) |
||||
} |
||||
|
||||
// Test with kind filter (only replies)
|
||||
refs, err = db.GetReferencingEvents(parentSerial, []uint16{1}) |
||||
if err != nil { |
||||
t.Fatalf("GetReferencingEvents with kind filter failed: %v", err) |
||||
} |
||||
if len(refs) != 2 { |
||||
t.Errorf("Expected 2 kind-1 referencing events, got %d", len(refs)) |
||||
} |
||||
|
||||
// Test with kind filter (only reactions)
|
||||
refs, err = db.GetReferencingEvents(parentSerial, []uint16{7}) |
||||
if err != nil { |
||||
t.Fatalf("GetReferencingEvents with kind 7 filter failed: %v", err) |
||||
} |
||||
if len(refs) != 1 { |
||||
t.Errorf("Expected 1 kind-7 referencing event, got %d", len(refs)) |
||||
} |
||||
} |
||||
|
||||
func TestGetFollowsFromPubkeySerial(t *testing.T) { |
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
defer cancel() |
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info") |
||||
if err != nil { |
||||
t.Fatalf("Failed to create database: %v", err) |
||||
} |
||||
defer db.Close() |
||||
|
||||
// Create author and their follows
|
||||
authorPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001") |
||||
follow1, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002") |
||||
follow2, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003") |
||||
follow3, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000004") |
||||
|
||||
// Create kind-3 contact list
|
||||
eventID := make([]byte, 32) |
||||
eventID[0] = 0x10 |
||||
eventSig := make([]byte, 64) |
||||
eventSig[0] = 0x10 |
||||
|
||||
contactList := &event.E{ |
||||
ID: eventID, |
||||
Pubkey: authorPubkey, |
||||
CreatedAt: 1234567890, |
||||
Kind: 3, |
||||
Content: []byte(""), |
||||
Sig: eventSig, |
||||
Tags: tag.NewS( |
||||
tag.NewFromAny("p", hex.Enc(follow1)), |
||||
tag.NewFromAny("p", hex.Enc(follow2)), |
||||
tag.NewFromAny("p", hex.Enc(follow3)), |
||||
), |
||||
} |
||||
_, err = db.SaveEvent(ctx, contactList) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save contact list: %v", err) |
||||
} |
||||
|
||||
// Get author serial
|
||||
authorSerial, err := db.GetPubkeySerial(authorPubkey) |
||||
if err != nil { |
||||
t.Fatalf("Failed to get author serial: %v", err) |
||||
} |
||||
|
||||
// Get follows
|
||||
follows, err := db.GetFollowsFromPubkeySerial(authorSerial) |
||||
if err != nil { |
||||
t.Fatalf("GetFollowsFromPubkeySerial failed: %v", err) |
||||
} |
||||
|
||||
if len(follows) != 3 { |
||||
t.Errorf("Expected 3 follows, got %d", len(follows)) |
||||
} |
||||
|
||||
// Verify the follows are correct
|
||||
expectedFollows := map[string]bool{ |
||||
hex.Enc(follow1): false, |
||||
hex.Enc(follow2): false, |
||||
hex.Enc(follow3): false, |
||||
} |
||||
for _, serial := range follows { |
||||
pubkey, err := db.GetPubkeyBySerial(serial) |
||||
if err != nil { |
||||
t.Errorf("Failed to get pubkey from serial: %v", err) |
||||
continue |
||||
} |
||||
pkHex := hex.Enc(pubkey) |
||||
if _, exists := expectedFollows[pkHex]; exists { |
||||
expectedFollows[pkHex] = true |
||||
} else { |
||||
t.Errorf("Unexpected follow: %s", pkHex) |
||||
} |
||||
} |
||||
for pk, found := range expectedFollows { |
||||
if !found { |
||||
t.Errorf("Expected follow not found: %s", pk) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestGraphResult(t *testing.T) { |
||||
result := NewGraphResult() |
||||
|
||||
// Add pubkeys at different depths
|
||||
result.AddPubkeyAtDepth("pubkey1", 1) |
||||
result.AddPubkeyAtDepth("pubkey2", 1) |
||||
result.AddPubkeyAtDepth("pubkey3", 2) |
||||
result.AddPubkeyAtDepth("pubkey4", 2) |
||||
result.AddPubkeyAtDepth("pubkey5", 3) |
||||
|
||||
// Try to add duplicate
|
||||
added := result.AddPubkeyAtDepth("pubkey1", 2) |
||||
if added { |
||||
t.Error("Should not add duplicate pubkey") |
||||
} |
||||
|
||||
// Verify counts
|
||||
if result.TotalPubkeys != 5 { |
||||
t.Errorf("Expected 5 total pubkeys, got %d", result.TotalPubkeys) |
||||
} |
||||
|
||||
// Verify depth tracking
|
||||
if result.GetPubkeyDepth("pubkey1") != 1 { |
||||
t.Errorf("pubkey1 should be at depth 1") |
||||
} |
||||
if result.GetPubkeyDepth("pubkey3") != 2 { |
||||
t.Errorf("pubkey3 should be at depth 2") |
||||
} |
||||
|
||||
// Verify HasPubkey
|
||||
if !result.HasPubkey("pubkey1") { |
||||
t.Error("Should have pubkey1") |
||||
} |
||||
if result.HasPubkey("nonexistent") { |
||||
t.Error("Should not have nonexistent pubkey") |
||||
} |
||||
|
||||
// Verify ToDepthArrays
|
||||
arrays := result.ToDepthArrays() |
||||
if len(arrays) != 3 { |
||||
t.Errorf("Expected 3 depth arrays, got %d", len(arrays)) |
||||
} |
||||
if len(arrays[0]) != 2 { |
||||
t.Errorf("Expected 2 pubkeys at depth 1, got %d", len(arrays[0])) |
||||
} |
||||
if len(arrays[1]) != 2 { |
||||
t.Errorf("Expected 2 pubkeys at depth 2, got %d", len(arrays[1])) |
||||
} |
||||
if len(arrays[2]) != 1 { |
||||
t.Errorf("Expected 1 pubkey at depth 3, got %d", len(arrays[2])) |
||||
} |
||||
} |
||||
|
||||
func TestGraphResultRefs(t *testing.T) { |
||||
result := NewGraphResult() |
||||
|
||||
// Add some pubkeys
|
||||
result.AddPubkeyAtDepth("pubkey1", 1) |
||||
result.AddEventAtDepth("event1", 1) |
||||
|
||||
// Add inbound refs (kind 7 reactions)
|
||||
result.AddInboundRef(7, "event1", "reaction1") |
||||
result.AddInboundRef(7, "event1", "reaction2") |
||||
result.AddInboundRef(7, "event1", "reaction3") |
||||
|
||||
// Get sorted refs
|
||||
refs := result.GetInboundRefsSorted(7) |
||||
if len(refs) != 1 { |
||||
t.Fatalf("Expected 1 aggregation, got %d", len(refs)) |
||||
} |
||||
if refs[0].RefCount != 3 { |
||||
t.Errorf("Expected 3 refs, got %d", refs[0].RefCount) |
||||
} |
||||
if refs[0].TargetEventID != "event1" { |
||||
t.Errorf("Expected event1, got %s", refs[0].TargetEventID) |
||||
} |
||||
} |
||||
|
||||
func TestGetFollowersOfPubkeySerial(t *testing.T) { |
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
defer cancel() |
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info") |
||||
if err != nil { |
||||
t.Fatalf("Failed to create database: %v", err) |
||||
} |
||||
defer db.Close() |
||||
|
||||
// Create target pubkey (the one being followed)
|
||||
targetPubkey, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001") |
||||
|
||||
// Create followers
|
||||
follower1, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000002") |
||||
follower2, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000003") |
||||
|
||||
// Create kind-3 contact lists for followers
|
||||
for i, followerPubkey := range [][]byte{follower1, follower2} { |
||||
eventID := make([]byte, 32) |
||||
eventID[0] = byte(0x10 + i) |
||||
eventSig := make([]byte, 64) |
||||
eventSig[0] = byte(0x10 + i) |
||||
|
||||
contactList := &event.E{ |
||||
ID: eventID, |
||||
Pubkey: followerPubkey, |
||||
CreatedAt: int64(1234567890 + i), |
||||
Kind: 3, |
||||
Content: []byte(""), |
||||
Sig: eventSig, |
||||
Tags: tag.NewS( |
||||
tag.NewFromAny("p", hex.Enc(targetPubkey)), |
||||
), |
||||
} |
||||
_, err = db.SaveEvent(ctx, contactList) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save contact list %d: %v", i, err) |
||||
} |
||||
} |
||||
|
||||
// Get target serial
|
||||
targetSerial, err := db.GetPubkeySerial(targetPubkey) |
||||
if err != nil { |
||||
t.Fatalf("Failed to get target serial: %v", err) |
||||
} |
||||
|
||||
// Get followers
|
||||
followers, err := db.GetFollowersOfPubkeySerial(targetSerial) |
||||
if err != nil { |
||||
t.Fatalf("GetFollowersOfPubkeySerial failed: %v", err) |
||||
} |
||||
|
||||
if len(followers) != 2 { |
||||
t.Errorf("Expected 2 followers, got %d", len(followers)) |
||||
} |
||||
|
||||
// Verify the followers
|
||||
expectedFollowers := map[string]bool{ |
||||
hex.Enc(follower1): false, |
||||
hex.Enc(follower2): false, |
||||
} |
||||
for _, serial := range followers { |
||||
pubkey, err := db.GetPubkeyBySerial(serial) |
||||
if err != nil { |
||||
t.Errorf("Failed to get pubkey from serial: %v", err) |
||||
continue |
||||
} |
||||
pkHex := hex.Enc(pubkey) |
||||
if _, exists := expectedFollowers[pkHex]; exists { |
||||
expectedFollowers[pkHex] = true |
||||
} else { |
||||
t.Errorf("Unexpected follower: %s", pkHex) |
||||
} |
||||
} |
||||
for pk, found := range expectedFollowers { |
||||
if !found { |
||||
t.Errorf("Expected follower not found: %s", pk) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestPubkeyHexToSerial(t *testing.T) { |
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
defer cancel() |
||||
|
||||
db, err := New(ctx, cancel, t.TempDir(), "info") |
||||
if err != nil { |
||||
t.Fatalf("Failed to create database: %v", err) |
||||
} |
||||
defer db.Close() |
||||
|
||||
// Create a pubkey by saving an event
|
||||
pubkeyBytes, _ := hex.Dec("0000000000000000000000000000000000000000000000000000000000000001") |
||||
eventID := make([]byte, 32) |
||||
eventID[0] = 0x10 |
||||
eventSig := make([]byte, 64) |
||||
eventSig[0] = 0x10 |
||||
|
||||
ev := &event.E{ |
||||
ID: eventID, |
||||
Pubkey: pubkeyBytes, |
||||
CreatedAt: 1234567890, |
||||
Kind: 1, |
||||
Content: []byte("Test"), |
||||
Sig: eventSig, |
||||
Tags: &tag.S{}, |
||||
} |
||||
_, err = db.SaveEvent(ctx, ev) |
||||
if err != nil { |
||||
t.Fatalf("Failed to save event: %v", err) |
||||
} |
||||
|
||||
// Convert hex to serial
|
||||
pubkeyHex := hex.Enc(pubkeyBytes) |
||||
serial, err := db.PubkeyHexToSerial(pubkeyHex) |
||||
if err != nil { |
||||
t.Fatalf("PubkeyHexToSerial failed: %v", err) |
||||
} |
||||
if serial == nil { |
||||
t.Fatal("Expected non-nil serial") |
||||
} |
||||
|
||||
// Convert back and verify
|
||||
backToHex, err := db.GetPubkeyHexFromSerial(serial) |
||||
if err != nil { |
||||
t.Fatalf("GetPubkeyHexFromSerial failed: %v", err) |
||||
} |
||||
if backToHex != pubkeyHex { |
||||
t.Errorf("Round-trip failed: %s != %s", backToHex, pubkeyHex) |
||||
} |
||||
} |
||||
@ -0,0 +1,202 @@
@@ -0,0 +1,202 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
// Package graph implements NIP-XX Graph Query protocol support.
|
||||
// This file contains the executor that runs graph traversal queries.
|
||||
package graph |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"strconv" |
||||
"time" |
||||
|
||||
"lol.mleku.dev/chk" |
||||
"lol.mleku.dev/log" |
||||
|
||||
"git.mleku.dev/mleku/nostr/encoders/event" |
||||
"git.mleku.dev/mleku/nostr/encoders/hex" |
||||
"git.mleku.dev/mleku/nostr/encoders/tag" |
||||
"git.mleku.dev/mleku/nostr/interfaces/signer" |
||||
"git.mleku.dev/mleku/nostr/interfaces/signer/p8k" |
||||
) |
||||
|
||||
// Response kinds for graph queries (ephemeral range, relay-signed)
|
||||
const ( |
||||
KindGraphFollows = 39000 // Response for follows/followers queries
|
||||
KindGraphMentions = 39001 // Response for mentions queries
|
||||
KindGraphThread = 39002 // Response for thread traversal queries
|
||||
) |
||||
|
||||
// GraphResultI is the interface that database.GraphResult implements.
|
||||
// This allows the executor to work with the database result without importing it.
|
||||
type GraphResultI interface { |
||||
ToDepthArrays() [][]string |
||||
ToEventDepthArrays() [][]string |
||||
GetAllPubkeys() []string |
||||
GetAllEvents() []string |
||||
GetPubkeysByDepth() map[int][]string |
||||
GetEventsByDepth() map[int][]string |
||||
GetTotalPubkeys() int |
||||
GetTotalEvents() int |
||||
} |
||||
|
||||
// GraphDatabase defines the interface for graph traversal operations.
|
||||
// This is implemented by the database package.
|
||||
type GraphDatabase interface { |
||||
// TraverseFollows performs BFS traversal of follow graph
|
||||
TraverseFollows(seedPubkey []byte, maxDepth int) (GraphResultI, error) |
||||
// TraverseFollowers performs BFS traversal to find followers
|
||||
TraverseFollowers(seedPubkey []byte, maxDepth int) (GraphResultI, error) |
||||
// FindMentions finds events mentioning a pubkey
|
||||
FindMentions(pubkey []byte, kinds []uint16) (GraphResultI, error) |
||||
// TraverseThread performs BFS traversal of thread structure
|
||||
TraverseThread(seedEventID []byte, maxDepth int, direction string) (GraphResultI, error) |
||||
} |
||||
|
||||
// Executor handles graph query execution and response generation.
|
||||
type Executor struct { |
||||
db GraphDatabase |
||||
relaySigner signer.I |
||||
relayPubkey []byte |
||||
} |
||||
|
||||
// NewExecutor creates a new graph query executor.
|
||||
// The secretKey should be the 32-byte relay identity secret key.
|
||||
func NewExecutor(db GraphDatabase, secretKey []byte) (*Executor, error) { |
||||
s, err := p8k.New() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if err = s.InitSec(secretKey); err != nil { |
||||
return nil, err |
||||
} |
||||
return &Executor{ |
||||
db: db, |
||||
relaySigner: s, |
||||
relayPubkey: s.Pub(), |
||||
}, nil |
||||
} |
||||
|
||||
// Execute runs a graph query and returns a relay-signed event with results.
|
||||
func (e *Executor) Execute(q *Query) (*event.E, error) { |
||||
var result GraphResultI |
||||
var err error |
||||
var responseKind uint16 |
||||
|
||||
// Decode seed (hex string to bytes)
|
||||
seedBytes, err := hex.Dec(q.Seed) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// Execute the appropriate traversal
|
||||
switch q.Method { |
||||
case "follows": |
||||
responseKind = KindGraphFollows |
||||
result, err = e.db.TraverseFollows(seedBytes, q.Depth) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
log.D.F("graph executor: follows traversal returned %d pubkeys", result.GetTotalPubkeys()) |
||||
|
||||
case "followers": |
||||
responseKind = KindGraphFollows |
||||
result, err = e.db.TraverseFollowers(seedBytes, q.Depth) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
log.D.F("graph executor: followers traversal returned %d pubkeys", result.GetTotalPubkeys()) |
||||
|
||||
case "mentions": |
||||
responseKind = KindGraphMentions |
||||
// Mentions don't use depth traversal, just find direct mentions
|
||||
// Convert RefSpec kinds to uint16 for the database call
|
||||
var kinds []uint16 |
||||
if len(q.InboundRefs) > 0 { |
||||
for _, rs := range q.InboundRefs { |
||||
for _, k := range rs.Kinds { |
||||
kinds = append(kinds, uint16(k)) |
||||
} |
||||
} |
||||
} else { |
||||
kinds = []uint16{1} // Default to kind 1 (notes)
|
||||
} |
||||
result, err = e.db.FindMentions(seedBytes, kinds) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
log.D.F("graph executor: mentions query returned %d events", result.GetTotalEvents()) |
||||
|
||||
case "thread": |
||||
responseKind = KindGraphThread |
||||
result, err = e.db.TraverseThread(seedBytes, q.Depth, "both") |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
log.D.F("graph executor: thread traversal returned %d events", result.GetTotalEvents()) |
||||
|
||||
default: |
||||
return nil, ErrInvalidMethod |
||||
} |
||||
|
||||
// Generate response event
|
||||
return e.generateResponse(q, result, responseKind) |
||||
} |
||||
|
||||
// generateResponse creates a relay-signed event containing the query results.
|
||||
func (e *Executor) generateResponse(q *Query, result GraphResultI, responseKind uint16) (*event.E, error) { |
||||
// Build content as JSON with depth arrays
|
||||
var content ResponseContent |
||||
|
||||
if q.Method == "follows" || q.Method == "followers" { |
||||
// For pubkey-based queries, use pubkeys_by_depth
|
||||
content.PubkeysByDepth = result.ToDepthArrays() |
||||
content.TotalPubkeys = result.GetTotalPubkeys() |
||||
} else { |
||||
// For event-based queries, use events_by_depth
|
||||
content.EventsByDepth = result.ToEventDepthArrays() |
||||
content.TotalEvents = result.GetTotalEvents() |
||||
} |
||||
|
||||
contentBytes, err := json.Marshal(content) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// Build tags
|
||||
tags := tag.NewS( |
||||
tag.NewFromAny("method", q.Method), |
||||
tag.NewFromAny("seed", q.Seed), |
||||
tag.NewFromAny("depth", strconv.Itoa(q.Depth)), |
||||
) |
||||
|
||||
// Create event
|
||||
ev := &event.E{ |
||||
Kind: responseKind, |
||||
CreatedAt: time.Now().Unix(), |
||||
Tags: tags, |
||||
Content: contentBytes, |
||||
} |
||||
|
||||
// Sign with relay identity
|
||||
if err = ev.Sign(e.relaySigner); chk.E(err) { |
||||
return nil, err |
||||
} |
||||
|
||||
return ev, nil |
||||
} |
||||
|
||||
// ResponseContent is the JSON structure for graph query responses.
|
||||
type ResponseContent struct { |
||||
// PubkeysByDepth contains arrays of pubkeys at each depth (1-indexed)
|
||||
// Each pubkey appears ONLY at the depth where it was first discovered.
|
||||
PubkeysByDepth [][]string `json:"pubkeys_by_depth,omitempty"` |
||||
|
||||
// EventsByDepth contains arrays of event IDs at each depth (1-indexed)
|
||||
EventsByDepth [][]string `json:"events_by_depth,omitempty"` |
||||
|
||||
// TotalPubkeys is the total count of unique pubkeys discovered
|
||||
TotalPubkeys int `json:"total_pubkeys,omitempty"` |
||||
|
||||
// TotalEvents is the total count of unique events discovered
|
||||
TotalEvents int `json:"total_events,omitempty"` |
||||
} |
||||
@ -0,0 +1,183 @@
@@ -0,0 +1,183 @@
|
||||
// Package graph implements NIP-XX Graph Query protocol support.
|
||||
// It provides types and functions for parsing and validating graph traversal queries.
|
||||
package graph |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"errors" |
||||
|
||||
"git.mleku.dev/mleku/nostr/encoders/filter" |
||||
) |
||||
|
||||
// Query represents a graph traversal query from a _graph filter extension.
|
||||
type Query struct { |
||||
// Method is the traversal method: "follows", "followers", "mentions", "thread"
|
||||
Method string `json:"method"` |
||||
|
||||
// Seed is the starting point for traversal (pubkey hex or event ID hex)
|
||||
Seed string `json:"seed"` |
||||
|
||||
// Depth is the maximum traversal depth (1-16, default: 1)
|
||||
Depth int `json:"depth,omitempty"` |
||||
|
||||
// InboundRefs specifies which inbound references to collect
|
||||
// (events that reference discovered events via e-tags)
|
||||
InboundRefs []RefSpec `json:"inbound_refs,omitempty"` |
||||
|
||||
// OutboundRefs specifies which outbound references to collect
|
||||
// (events referenced by discovered events via e-tags)
|
||||
OutboundRefs []RefSpec `json:"outbound_refs,omitempty"` |
||||
} |
||||
|
||||
// RefSpec specifies which event references to include in results.
|
||||
type RefSpec struct { |
||||
// Kinds is the list of event kinds to match (OR semantics within this spec)
|
||||
Kinds []int `json:"kinds"` |
||||
|
||||
// FromDepth specifies the minimum depth at which to collect refs (default: 0)
|
||||
// 0 = include refs from seed itself
|
||||
// 1 = start from first-hop connections
|
||||
FromDepth int `json:"from_depth,omitempty"` |
||||
} |
||||
|
||||
// Validation errors
|
||||
var ( |
||||
ErrMissingMethod = errors.New("_graph.method is required") |
||||
ErrInvalidMethod = errors.New("_graph.method must be one of: follows, followers, mentions, thread") |
||||
ErrMissingSeed = errors.New("_graph.seed is required") |
||||
ErrInvalidSeed = errors.New("_graph.seed must be a 64-character hex string") |
||||
ErrDepthTooHigh = errors.New("_graph.depth cannot exceed 16") |
||||
ErrEmptyRefSpecKinds = errors.New("ref spec kinds array cannot be empty") |
||||
) |
||||
|
||||
// Valid method names
|
||||
var validMethods = map[string]bool{ |
||||
"follows": true, |
||||
"followers": true, |
||||
"mentions": true, |
||||
"thread": true, |
||||
} |
||||
|
||||
// Validate checks the query for correctness and applies defaults.
|
||||
func (q *Query) Validate() error { |
||||
// Method is required
|
||||
if q.Method == "" { |
||||
return ErrMissingMethod |
||||
} |
||||
if !validMethods[q.Method] { |
||||
return ErrInvalidMethod |
||||
} |
||||
|
||||
// Seed is required
|
||||
if q.Seed == "" { |
||||
return ErrMissingSeed |
||||
} |
||||
if len(q.Seed) != 64 { |
||||
return ErrInvalidSeed |
||||
} |
||||
// Validate hex characters
|
||||
for _, c := range q.Seed { |
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) { |
||||
return ErrInvalidSeed |
||||
} |
||||
} |
||||
|
||||
// Apply depth defaults and limits
|
||||
if q.Depth < 1 { |
||||
q.Depth = 1 |
||||
} |
||||
if q.Depth > 16 { |
||||
return ErrDepthTooHigh |
||||
} |
||||
|
||||
// Validate ref specs
|
||||
for _, rs := range q.InboundRefs { |
||||
if len(rs.Kinds) == 0 { |
||||
return ErrEmptyRefSpecKinds |
||||
} |
||||
} |
||||
for _, rs := range q.OutboundRefs { |
||||
if len(rs.Kinds) == 0 { |
||||
return ErrEmptyRefSpecKinds |
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// HasInboundRefs returns true if the query includes inbound reference collection.
|
||||
func (q *Query) HasInboundRefs() bool { |
||||
return len(q.InboundRefs) > 0 |
||||
} |
||||
|
||||
// HasOutboundRefs returns true if the query includes outbound reference collection.
|
||||
func (q *Query) HasOutboundRefs() bool { |
||||
return len(q.OutboundRefs) > 0 |
||||
} |
||||
|
||||
// HasRefs returns true if the query includes any reference collection.
|
||||
func (q *Query) HasRefs() bool { |
||||
return q.HasInboundRefs() || q.HasOutboundRefs() |
||||
} |
||||
|
||||
// InboundKindsAtDepth returns a set of kinds that should be collected at the given depth.
|
||||
// It aggregates all RefSpecs where from_depth <= depth.
|
||||
func (q *Query) InboundKindsAtDepth(depth int) map[int]bool { |
||||
kinds := make(map[int]bool) |
||||
for _, rs := range q.InboundRefs { |
||||
if rs.FromDepth <= depth { |
||||
for _, k := range rs.Kinds { |
||||
kinds[k] = true |
||||
} |
||||
} |
||||
} |
||||
return kinds |
||||
} |
||||
|
||||
// OutboundKindsAtDepth returns a set of kinds that should be collected at the given depth.
|
||||
func (q *Query) OutboundKindsAtDepth(depth int) map[int]bool { |
||||
kinds := make(map[int]bool) |
||||
for _, rs := range q.OutboundRefs { |
||||
if rs.FromDepth <= depth { |
||||
for _, k := range rs.Kinds { |
||||
kinds[k] = true |
||||
} |
||||
} |
||||
} |
||||
return kinds |
||||
} |
||||
|
||||
// ExtractFromFilter checks if a filter has a _graph extension and parses it.
|
||||
// Returns nil if no _graph field is present.
|
||||
// Returns an error if _graph is present but invalid.
|
||||
func ExtractFromFilter(f *filter.F) (*Query, error) { |
||||
if f == nil || f.Extra == nil { |
||||
return nil, nil |
||||
} |
||||
|
||||
raw, ok := f.Extra["_graph"] |
||||
if !ok { |
||||
return nil, nil |
||||
} |
||||
|
||||
var q Query |
||||
if err := json.Unmarshal(raw, &q); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if err := q.Validate(); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &q, nil |
||||
} |
||||
|
||||
// IsGraphQuery returns true if the filter contains a _graph extension.
|
||||
// This is a quick check that doesn't parse the full query.
|
||||
func IsGraphQuery(f *filter.F) bool { |
||||
if f == nil || f.Extra == nil { |
||||
return false |
||||
} |
||||
_, ok := f.Extra["_graph"] |
||||
return ok |
||||
} |
||||
@ -0,0 +1,397 @@
@@ -0,0 +1,397 @@
|
||||
package graph |
||||
|
||||
import ( |
||||
"testing" |
||||
|
||||
"git.mleku.dev/mleku/nostr/encoders/filter" |
||||
) |
||||
|
||||
func TestQueryValidate(t *testing.T) { |
||||
validSeed := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" |
||||
|
||||
tests := []struct { |
||||
name string |
||||
query Query |
||||
wantErr error |
||||
}{ |
||||
{ |
||||
name: "valid follows query", |
||||
query: Query{ |
||||
Method: "follows", |
||||
Seed: validSeed, |
||||
Depth: 2, |
||||
}, |
||||
wantErr: nil, |
||||
}, |
||||
{ |
||||
name: "valid followers query", |
||||
query: Query{ |
||||
Method: "followers", |
||||
Seed: validSeed, |
||||
}, |
||||
wantErr: nil, |
||||
}, |
||||
{ |
||||
name: "valid mentions query", |
||||
query: Query{ |
||||
Method: "mentions", |
||||
Seed: validSeed, |
||||
Depth: 1, |
||||
}, |
||||
wantErr: nil, |
||||
}, |
||||
{ |
||||
name: "valid thread query", |
||||
query: Query{ |
||||
Method: "thread", |
||||
Seed: validSeed, |
||||
Depth: 10, |
||||
}, |
||||
wantErr: nil, |
||||
}, |
||||
{ |
||||
name: "valid query with inbound refs", |
||||
query: Query{ |
||||
Method: "follows", |
||||
Seed: validSeed, |
||||
Depth: 2, |
||||
InboundRefs: []RefSpec{ |
||||
{Kinds: []int{7}, FromDepth: 1}, |
||||
}, |
||||
}, |
||||
wantErr: nil, |
||||
}, |
||||
{ |
||||
name: "valid query with multiple ref specs", |
||||
query: Query{ |
||||
Method: "follows", |
||||
Seed: validSeed, |
||||
InboundRefs: []RefSpec{ |
||||
{Kinds: []int{7}, FromDepth: 1}, |
||||
{Kinds: []int{6}, FromDepth: 1}, |
||||
}, |
||||
OutboundRefs: []RefSpec{ |
||||
{Kinds: []int{1}, FromDepth: 0}, |
||||
}, |
||||
}, |
||||
wantErr: nil, |
||||
}, |
||||
{ |
||||
name: "missing method", |
||||
query: Query{Seed: validSeed}, |
||||
wantErr: ErrMissingMethod, |
||||
}, |
||||
{ |
||||
name: "invalid method", |
||||
query: Query{ |
||||
Method: "invalid", |
||||
Seed: validSeed, |
||||
}, |
||||
wantErr: ErrInvalidMethod, |
||||
}, |
||||
{ |
||||
name: "missing seed", |
||||
query: Query{ |
||||
Method: "follows", |
||||
}, |
||||
wantErr: ErrMissingSeed, |
||||
}, |
||||
{ |
||||
name: "seed too short", |
||||
query: Query{ |
||||
Method: "follows", |
||||
Seed: "abc123", |
||||
}, |
||||
wantErr: ErrInvalidSeed, |
||||
}, |
||||
{ |
||||
name: "seed with invalid characters", |
||||
query: Query{ |
||||
Method: "follows", |
||||
Seed: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdeg", |
||||
}, |
||||
wantErr: ErrInvalidSeed, |
||||
}, |
||||
{ |
||||
name: "depth too high", |
||||
query: Query{ |
||||
Method: "follows", |
||||
Seed: validSeed, |
||||
Depth: 17, |
||||
}, |
||||
wantErr: ErrDepthTooHigh, |
||||
}, |
||||
{ |
||||
name: "empty ref spec kinds", |
||||
query: Query{ |
||||
Method: "follows", |
||||
Seed: validSeed, |
||||
InboundRefs: []RefSpec{ |
||||
{Kinds: []int{}, FromDepth: 1}, |
||||
}, |
||||
}, |
||||
wantErr: ErrEmptyRefSpecKinds, |
||||
}, |
||||
} |
||||
|
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
err := tt.query.Validate() |
||||
if tt.wantErr == nil { |
||||
if err != nil { |
||||
t.Errorf("unexpected error: %v", err) |
||||
} |
||||
} else { |
||||
if err != tt.wantErr { |
||||
t.Errorf("error = %v, want %v", err, tt.wantErr) |
||||
} |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestQueryDefaults(t *testing.T) { |
||||
validSeed := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" |
||||
|
||||
q := Query{ |
||||
Method: "follows", |
||||
Seed: validSeed, |
||||
Depth: 0, // Should default to 1
|
||||
} |
||||
|
||||
err := q.Validate() |
||||
if err != nil { |
||||
t.Fatalf("unexpected error: %v", err) |
||||
} |
||||
|
||||
if q.Depth != 1 { |
||||
t.Errorf("Depth = %d, want 1 (default)", q.Depth) |
||||
} |
||||
} |
||||
|
||||
func TestKindsAtDepth(t *testing.T) { |
||||
q := Query{ |
||||
Method: "follows", |
||||
Seed: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", |
||||
Depth: 3, |
||||
InboundRefs: []RefSpec{ |
||||
{Kinds: []int{7}, FromDepth: 0}, // From seed
|
||||
{Kinds: []int{6, 16}, FromDepth: 1}, // From depth 1
|
||||
{Kinds: []int{9735}, FromDepth: 2}, // From depth 2
|
||||
}, |
||||
OutboundRefs: []RefSpec{ |
||||
{Kinds: []int{1}, FromDepth: 1}, |
||||
}, |
||||
} |
||||
|
||||
// Test inbound kinds at depth 0
|
||||
kinds0 := q.InboundKindsAtDepth(0) |
||||
if !kinds0[7] || kinds0[6] || kinds0[9735] { |
||||
t.Errorf("InboundKindsAtDepth(0) = %v, want only kind 7", kinds0) |
||||
} |
||||
|
||||
// Test inbound kinds at depth 1
|
||||
kinds1 := q.InboundKindsAtDepth(1) |
||||
if !kinds1[7] || !kinds1[6] || !kinds1[16] || kinds1[9735] { |
||||
t.Errorf("InboundKindsAtDepth(1) = %v, want kinds 7, 6, 16", kinds1) |
||||
} |
||||
|
||||
// Test inbound kinds at depth 2
|
||||
kinds2 := q.InboundKindsAtDepth(2) |
||||
if !kinds2[7] || !kinds2[6] || !kinds2[16] || !kinds2[9735] { |
||||
t.Errorf("InboundKindsAtDepth(2) = %v, want all kinds", kinds2) |
||||
} |
||||
|
||||
// Test outbound kinds at depth 0
|
||||
outKinds0 := q.OutboundKindsAtDepth(0) |
||||
if len(outKinds0) != 0 { |
||||
t.Errorf("OutboundKindsAtDepth(0) = %v, want empty", outKinds0) |
||||
} |
||||
|
||||
// Test outbound kinds at depth 1
|
||||
outKinds1 := q.OutboundKindsAtDepth(1) |
||||
if !outKinds1[1] { |
||||
t.Errorf("OutboundKindsAtDepth(1) = %v, want kind 1", outKinds1) |
||||
} |
||||
} |
||||
|
||||
func TestExtractFromFilter(t *testing.T) { |
||||
tests := []struct { |
||||
name string |
||||
filterJSON string |
||||
wantQuery bool |
||||
wantErr bool |
||||
}{ |
||||
{ |
||||
name: "filter with valid graph query", |
||||
filterJSON: `{"kinds":[1],"_graph":{"method":"follows","seed":"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef","depth":2}}`, |
||||
wantQuery: true, |
||||
wantErr: false, |
||||
}, |
||||
{ |
||||
name: "filter without graph query", |
||||
filterJSON: `{"kinds":[1,7]}`, |
||||
wantQuery: false, |
||||
wantErr: false, |
||||
}, |
||||
{ |
||||
name: "filter with invalid graph query (missing method)", |
||||
filterJSON: `{"kinds":[1],"_graph":{"seed":"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}}`, |
||||
wantQuery: false, |
||||
wantErr: true, |
||||
}, |
||||
{ |
||||
name: "filter with complex graph query", |
||||
filterJSON: `{"kinds":[0],"_graph":{"method":"follows","seed":"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef","depth":3,"inbound_refs":[{"kinds":[7],"from_depth":1}]}}`, |
||||
wantQuery: true, |
||||
wantErr: false, |
||||
}, |
||||
} |
||||
|
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
f := &filter.F{} |
||||
_, err := f.Unmarshal([]byte(tt.filterJSON)) |
||||
if err != nil { |
||||
t.Fatalf("failed to unmarshal filter: %v", err) |
||||
} |
||||
|
||||
q, err := ExtractFromFilter(f) |
||||
|
||||
if tt.wantErr { |
||||
if err == nil { |
||||
t.Error("expected error, got nil") |
||||
} |
||||
return |
||||
} |
||||
|
||||
if err != nil { |
||||
t.Errorf("unexpected error: %v", err) |
||||
return |
||||
} |
||||
|
||||
if tt.wantQuery && q == nil { |
||||
t.Error("expected query, got nil") |
||||
} |
||||
if !tt.wantQuery && q != nil { |
||||
t.Errorf("expected nil query, got %+v", q) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestIsGraphQuery(t *testing.T) { |
||||
tests := []struct { |
||||
name string |
||||
filterJSON string |
||||
want bool |
||||
}{ |
||||
{ |
||||
name: "filter with graph query", |
||||
filterJSON: `{"kinds":[1],"_graph":{"method":"follows","seed":"abc"}}`, |
||||
want: true, |
||||
}, |
||||
{ |
||||
name: "filter without graph query", |
||||
filterJSON: `{"kinds":[1,7]}`, |
||||
want: false, |
||||
}, |
||||
{ |
||||
name: "filter with other extension", |
||||
filterJSON: `{"kinds":[1],"_custom":"value"}`, |
||||
want: false, |
||||
}, |
||||
} |
||||
|
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
f := &filter.F{} |
||||
_, err := f.Unmarshal([]byte(tt.filterJSON)) |
||||
if err != nil { |
||||
t.Fatalf("failed to unmarshal filter: %v", err) |
||||
} |
||||
|
||||
got := IsGraphQuery(f) |
||||
if got != tt.want { |
||||
t.Errorf("IsGraphQuery() = %v, want %v", got, tt.want) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestQueryHasRefs(t *testing.T) { |
||||
tests := []struct { |
||||
name string |
||||
query Query |
||||
hasInbound bool |
||||
hasOutbound bool |
||||
hasRefs bool |
||||
}{ |
||||
{ |
||||
name: "no refs", |
||||
query: Query{ |
||||
Method: "follows", |
||||
Seed: "abc", |
||||
}, |
||||
hasInbound: false, |
||||
hasOutbound: false, |
||||
hasRefs: false, |
||||
}, |
||||
{ |
||||
name: "only inbound refs", |
||||
query: Query{ |
||||
Method: "follows", |
||||
Seed: "abc", |
||||
InboundRefs: []RefSpec{ |
||||
{Kinds: []int{7}}, |
||||
}, |
||||
}, |
||||
hasInbound: true, |
||||
hasOutbound: false, |
||||
hasRefs: true, |
||||
}, |
||||
{ |
||||
name: "only outbound refs", |
||||
query: Query{ |
||||
Method: "follows", |
||||
Seed: "abc", |
||||
OutboundRefs: []RefSpec{ |
||||
{Kinds: []int{1}}, |
||||
}, |
||||
}, |
||||
hasInbound: false, |
||||
hasOutbound: true, |
||||
hasRefs: true, |
||||
}, |
||||
{ |
||||
name: "both refs", |
||||
query: Query{ |
||||
Method: "follows", |
||||
Seed: "abc", |
||||
InboundRefs: []RefSpec{ |
||||
{Kinds: []int{7}}, |
||||
}, |
||||
OutboundRefs: []RefSpec{ |
||||
{Kinds: []int{1}}, |
||||
}, |
||||
}, |
||||
hasInbound: true, |
||||
hasOutbound: true, |
||||
hasRefs: true, |
||||
}, |
||||
} |
||||
|
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
if got := tt.query.HasInboundRefs(); got != tt.hasInbound { |
||||
t.Errorf("HasInboundRefs() = %v, want %v", got, tt.hasInbound) |
||||
} |
||||
if got := tt.query.HasOutboundRefs(); got != tt.hasOutbound { |
||||
t.Errorf("HasOutboundRefs() = %v, want %v", got, tt.hasOutbound) |
||||
} |
||||
if got := tt.query.HasRefs(); got != tt.hasRefs { |
||||
t.Errorf("HasRefs() = %v, want %v", got, tt.hasRefs) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
@ -0,0 +1,282 @@
@@ -0,0 +1,282 @@
|
||||
package graph |
||||
|
||||
import ( |
||||
"context" |
||||
"sync" |
||||
"time" |
||||
) |
||||
|
||||
// RateLimiter implements a token bucket rate limiter with adaptive throttling
|
||||
// based on graph query complexity. It allows cooperative scheduling by inserting
|
||||
// pauses between operations to allow other work to proceed.
|
||||
type RateLimiter struct { |
||||
mu sync.Mutex |
||||
|
||||
// Token bucket parameters
|
||||
tokens float64 // Current available tokens
|
||||
maxTokens float64 // Maximum token capacity
|
||||
refillRate float64 // Tokens per second to add
|
||||
lastRefill time.Time // Last time tokens were refilled
|
||||
|
||||
// Throttling parameters
|
||||
baseDelay time.Duration // Minimum delay between operations
|
||||
maxDelay time.Duration // Maximum delay for complex queries
|
||||
depthFactor float64 // Multiplier per depth level
|
||||
limitFactor float64 // Multiplier based on result limit
|
||||
} |
||||
|
||||
// RateLimiterConfig configures the rate limiter behavior.
|
||||
type RateLimiterConfig struct { |
||||
// MaxTokens is the maximum number of tokens in the bucket (default: 100)
|
||||
MaxTokens float64 |
||||
|
||||
// RefillRate is tokens added per second (default: 10)
|
||||
RefillRate float64 |
||||
|
||||
// BaseDelay is the minimum delay between operations (default: 1ms)
|
||||
BaseDelay time.Duration |
||||
|
||||
// MaxDelay is the maximum delay for complex queries (default: 100ms)
|
||||
MaxDelay time.Duration |
||||
|
||||
// DepthFactor is the cost multiplier per depth level (default: 2.0)
|
||||
// A depth-3 query costs 2^3 = 8x more tokens than depth-1
|
||||
DepthFactor float64 |
||||
|
||||
// LimitFactor is additional cost per 100 results requested (default: 0.1)
|
||||
LimitFactor float64 |
||||
} |
||||
|
||||
// DefaultRateLimiterConfig returns sensible defaults for the rate limiter.
|
||||
func DefaultRateLimiterConfig() RateLimiterConfig { |
||||
return RateLimiterConfig{ |
||||
MaxTokens: 100.0, |
||||
RefillRate: 10.0, // Refills fully in 10 seconds
|
||||
BaseDelay: 1 * time.Millisecond, |
||||
MaxDelay: 100 * time.Millisecond, |
||||
DepthFactor: 2.0, |
||||
LimitFactor: 0.1, |
||||
} |
||||
} |
||||
|
||||
// NewRateLimiter creates a new rate limiter with the given configuration.
|
||||
func NewRateLimiter(cfg RateLimiterConfig) *RateLimiter { |
||||
if cfg.MaxTokens <= 0 { |
||||
cfg.MaxTokens = DefaultRateLimiterConfig().MaxTokens |
||||
} |
||||
if cfg.RefillRate <= 0 { |
||||
cfg.RefillRate = DefaultRateLimiterConfig().RefillRate |
||||
} |
||||
if cfg.BaseDelay <= 0 { |
||||
cfg.BaseDelay = DefaultRateLimiterConfig().BaseDelay |
||||
} |
||||
if cfg.MaxDelay <= 0 { |
||||
cfg.MaxDelay = DefaultRateLimiterConfig().MaxDelay |
||||
} |
||||
if cfg.DepthFactor <= 0 { |
||||
cfg.DepthFactor = DefaultRateLimiterConfig().DepthFactor |
||||
} |
||||
if cfg.LimitFactor <= 0 { |
||||
cfg.LimitFactor = DefaultRateLimiterConfig().LimitFactor |
||||
} |
||||
|
||||
return &RateLimiter{ |
||||
tokens: cfg.MaxTokens, |
||||
maxTokens: cfg.MaxTokens, |
||||
refillRate: cfg.RefillRate, |
||||
lastRefill: time.Now(), |
||||
baseDelay: cfg.BaseDelay, |
||||
maxDelay: cfg.MaxDelay, |
||||
depthFactor: cfg.DepthFactor, |
||||
limitFactor: cfg.LimitFactor, |
||||
} |
||||
} |
||||
|
||||
// QueryCost calculates the token cost for a graph query based on its complexity.
|
||||
// Higher depths and larger limits cost exponentially more tokens.
|
||||
func (rl *RateLimiter) QueryCost(q *Query) float64 { |
||||
if q == nil { |
||||
return 1.0 |
||||
} |
||||
|
||||
// Base cost is exponential in depth: depthFactor^depth
|
||||
// This models the exponential growth of traversal work
|
||||
cost := 1.0 |
||||
for i := 0; i < q.Depth; i++ { |
||||
cost *= rl.depthFactor |
||||
} |
||||
|
||||
// Add cost for reference collection (adds ~50% per ref spec)
|
||||
refCost := float64(len(q.InboundRefs)+len(q.OutboundRefs)) * 0.5 |
||||
cost += refCost |
||||
|
||||
return cost |
||||
} |
||||
|
||||
// OperationCost calculates the token cost for a single traversal operation.
|
||||
// This is used during query execution for per-operation throttling.
|
||||
func (rl *RateLimiter) OperationCost(depth int, nodesAtDepth int) float64 { |
||||
// Cost increases with depth and number of nodes to process
|
||||
depthMultiplier := 1.0 |
||||
for i := 0; i < depth; i++ { |
||||
depthMultiplier *= rl.depthFactor |
||||
} |
||||
|
||||
// More nodes at this depth = more work
|
||||
nodeFactor := 1.0 + float64(nodesAtDepth)*0.01 |
||||
|
||||
return depthMultiplier * nodeFactor |
||||
} |
||||
|
||||
// refillTokens adds tokens based on elapsed time since last refill.
|
||||
func (rl *RateLimiter) refillTokens() { |
||||
now := time.Now() |
||||
elapsed := now.Sub(rl.lastRefill).Seconds() |
||||
rl.lastRefill = now |
||||
|
||||
rl.tokens += elapsed * rl.refillRate |
||||
if rl.tokens > rl.maxTokens { |
||||
rl.tokens = rl.maxTokens |
||||
} |
||||
} |
||||
|
||||
// Acquire tries to acquire tokens for a query. If not enough tokens are available,
|
||||
// it waits until they become available or the context is cancelled.
|
||||
// Returns the delay that was applied, or an error if context was cancelled.
|
||||
func (rl *RateLimiter) Acquire(ctx context.Context, cost float64) (time.Duration, error) { |
||||
rl.mu.Lock() |
||||
defer rl.mu.Unlock() |
||||
|
||||
rl.refillTokens() |
||||
|
||||
var totalDelay time.Duration |
||||
|
||||
// Wait until we have enough tokens
|
||||
for rl.tokens < cost { |
||||
// Calculate how long we need to wait for tokens to refill
|
||||
tokensNeeded := cost - rl.tokens |
||||
waitTime := time.Duration(tokensNeeded/rl.refillRate*1000) * time.Millisecond |
||||
|
||||
// Clamp to max delay
|
||||
if waitTime > rl.maxDelay { |
||||
waitTime = rl.maxDelay |
||||
} |
||||
if waitTime < rl.baseDelay { |
||||
waitTime = rl.baseDelay |
||||
} |
||||
|
||||
// Release lock while waiting
|
||||
rl.mu.Unlock() |
||||
|
||||
select { |
||||
case <-ctx.Done(): |
||||
rl.mu.Lock() |
||||
return totalDelay, ctx.Err() |
||||
case <-time.After(waitTime): |
||||
} |
||||
|
||||
totalDelay += waitTime |
||||
rl.mu.Lock() |
||||
rl.refillTokens() |
||||
} |
||||
|
||||
// Consume tokens
|
||||
rl.tokens -= cost |
||||
return totalDelay, nil |
||||
} |
||||
|
||||
// TryAcquire attempts to acquire tokens without waiting.
|
||||
// Returns true if successful, false if insufficient tokens.
|
||||
func (rl *RateLimiter) TryAcquire(cost float64) bool { |
||||
rl.mu.Lock() |
||||
defer rl.mu.Unlock() |
||||
|
||||
rl.refillTokens() |
||||
|
||||
if rl.tokens >= cost { |
||||
rl.tokens -= cost |
||||
return true |
||||
} |
||||
return false |
||||
} |
||||
|
||||
// Pause inserts a cooperative delay to allow other work to proceed.
|
||||
// The delay is proportional to the current depth and load.
|
||||
// This should be called periodically during long-running traversals.
|
||||
func (rl *RateLimiter) Pause(ctx context.Context, depth int, itemsProcessed int) error { |
||||
// Calculate adaptive delay based on depth and progress
|
||||
// Deeper traversals and more processed items = longer pauses
|
||||
delay := rl.baseDelay |
||||
|
||||
// Increase delay with depth
|
||||
for i := 0; i < depth; i++ { |
||||
delay += rl.baseDelay |
||||
} |
||||
|
||||
// Add extra delay every N items to allow other work
|
||||
if itemsProcessed > 0 && itemsProcessed%100 == 0 { |
||||
delay += rl.baseDelay * 5 |
||||
} |
||||
|
||||
// Cap at max delay
|
||||
if delay > rl.maxDelay { |
||||
delay = rl.maxDelay |
||||
} |
||||
|
||||
select { |
||||
case <-ctx.Done(): |
||||
return ctx.Err() |
||||
case <-time.After(delay): |
||||
return nil |
||||
} |
||||
} |
||||
|
||||
// AvailableTokens returns the current number of available tokens.
|
||||
func (rl *RateLimiter) AvailableTokens() float64 { |
||||
rl.mu.Lock() |
||||
defer rl.mu.Unlock() |
||||
rl.refillTokens() |
||||
return rl.tokens |
||||
} |
||||
|
||||
// Throttler provides a simple interface for cooperative scheduling during traversal.
|
||||
// It wraps the rate limiter and provides depth-aware throttling.
|
||||
type Throttler struct { |
||||
rl *RateLimiter |
||||
depth int |
||||
itemsProcessed int |
||||
} |
||||
|
||||
// NewThrottler creates a throttler for a specific traversal operation.
|
||||
func NewThrottler(rl *RateLimiter, depth int) *Throttler { |
||||
return &Throttler{ |
||||
rl: rl, |
||||
depth: depth, |
||||
} |
||||
} |
||||
|
||||
// Tick should be called after processing each item.
|
||||
// It tracks progress and inserts pauses as needed.
|
||||
func (t *Throttler) Tick(ctx context.Context) error { |
||||
t.itemsProcessed++ |
||||
|
||||
// Insert cooperative pause periodically
|
||||
// More frequent pauses at higher depths
|
||||
interval := 50 |
||||
if t.depth >= 2 { |
||||
interval = 25 |
||||
} |
||||
if t.depth >= 4 { |
||||
interval = 10 |
||||
} |
||||
|
||||
if t.itemsProcessed%interval == 0 { |
||||
return t.rl.Pause(ctx, t.depth, t.itemsProcessed) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// Complete marks the throttler as complete and returns stats.
|
||||
func (t *Throttler) Complete() (itemsProcessed int) { |
||||
return t.itemsProcessed |
||||
} |
||||
@ -0,0 +1,267 @@
@@ -0,0 +1,267 @@
|
||||
package graph |
||||
|
||||
import ( |
||||
"context" |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
func TestRateLimiterQueryCost(t *testing.T) { |
||||
rl := NewRateLimiter(DefaultRateLimiterConfig()) |
||||
|
||||
tests := []struct { |
||||
name string |
||||
query *Query |
||||
minCost float64 |
||||
maxCost float64 |
||||
}{ |
||||
{ |
||||
name: "nil query", |
||||
query: nil, |
||||
minCost: 1.0, |
||||
maxCost: 1.0, |
||||
}, |
||||
{ |
||||
name: "depth 1 no refs", |
||||
query: &Query{ |
||||
Method: "follows", |
||||
Seed: "abc", |
||||
Depth: 1, |
||||
}, |
||||
minCost: 1.5, // depthFactor^1 = 2
|
||||
maxCost: 2.5, |
||||
}, |
||||
{ |
||||
name: "depth 2 no refs", |
||||
query: &Query{ |
||||
Method: "follows", |
||||
Seed: "abc", |
||||
Depth: 2, |
||||
}, |
||||
minCost: 3.5, // depthFactor^2 = 4
|
||||
maxCost: 4.5, |
||||
}, |
||||
{ |
||||
name: "depth 3 no refs", |
||||
query: &Query{ |
||||
Method: "follows", |
||||
Seed: "abc", |
||||
Depth: 3, |
||||
}, |
||||
minCost: 7.5, // depthFactor^3 = 8
|
||||
maxCost: 8.5, |
||||
}, |
||||
{ |
||||
name: "depth 2 with inbound refs", |
||||
query: &Query{ |
||||
Method: "follows", |
||||
Seed: "abc", |
||||
Depth: 2, |
||||
InboundRefs: []RefSpec{ |
||||
{Kinds: []int{7}}, |
||||
}, |
||||
}, |
||||
minCost: 4.0, // 4 + 0.5 = 4.5
|
||||
maxCost: 5.0, |
||||
}, |
||||
{ |
||||
name: "depth 2 with both refs", |
||||
query: &Query{ |
||||
Method: "follows", |
||||
Seed: "abc", |
||||
Depth: 2, |
||||
InboundRefs: []RefSpec{ |
||||
{Kinds: []int{7}}, |
||||
}, |
||||
OutboundRefs: []RefSpec{ |
||||
{Kinds: []int{1}}, |
||||
}, |
||||
}, |
||||
minCost: 4.5, // 4 + 0.5 + 0.5 = 5
|
||||
maxCost: 5.5, |
||||
}, |
||||
} |
||||
|
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
cost := rl.QueryCost(tt.query) |
||||
if cost < tt.minCost || cost > tt.maxCost { |
||||
t.Errorf("QueryCost() = %v, want between %v and %v", cost, tt.minCost, tt.maxCost) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestRateLimiterOperationCost(t *testing.T) { |
||||
rl := NewRateLimiter(DefaultRateLimiterConfig()) |
||||
|
||||
// Depth 0, 1 node
|
||||
cost0 := rl.OperationCost(0, 1) |
||||
if cost0 < 1.0 || cost0 > 1.1 { |
||||
t.Errorf("OperationCost(0, 1) = %v, want ~1.01", cost0) |
||||
} |
||||
|
||||
// Depth 1, 1 node
|
||||
cost1 := rl.OperationCost(1, 1) |
||||
if cost1 < 2.0 || cost1 > 2.1 { |
||||
t.Errorf("OperationCost(1, 1) = %v, want ~2.02", cost1) |
||||
} |
||||
|
||||
// Depth 2, 100 nodes
|
||||
cost2 := rl.OperationCost(2, 100) |
||||
if cost2 < 8.0 { |
||||
t.Errorf("OperationCost(2, 100) = %v, want > 8", cost2) |
||||
} |
||||
} |
||||
|
||||
func TestRateLimiterAcquire(t *testing.T) { |
||||
cfg := DefaultRateLimiterConfig() |
||||
cfg.MaxTokens = 10 |
||||
cfg.RefillRate = 100 // Fast refill for testing
|
||||
rl := NewRateLimiter(cfg) |
||||
|
||||
ctx := context.Background() |
||||
|
||||
// Should acquire immediately when tokens available
|
||||
delay, err := rl.Acquire(ctx, 5) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error: %v", err) |
||||
} |
||||
if delay > time.Millisecond*10 { |
||||
t.Errorf("expected minimal delay, got %v", delay) |
||||
} |
||||
|
||||
// Check remaining tokens
|
||||
remaining := rl.AvailableTokens() |
||||
if remaining > 6 { |
||||
t.Errorf("expected ~5 tokens remaining, got %v", remaining) |
||||
} |
||||
} |
||||
|
||||
func TestRateLimiterTryAcquire(t *testing.T) { |
||||
cfg := DefaultRateLimiterConfig() |
||||
cfg.MaxTokens = 10 |
||||
rl := NewRateLimiter(cfg) |
||||
|
||||
// Should succeed with enough tokens
|
||||
if !rl.TryAcquire(5) { |
||||
t.Error("TryAcquire(5) should succeed with 10 tokens") |
||||
} |
||||
|
||||
// Should succeed again
|
||||
if !rl.TryAcquire(5) { |
||||
t.Error("TryAcquire(5) should succeed with 5 tokens") |
||||
} |
||||
|
||||
// Should fail with insufficient tokens
|
||||
if rl.TryAcquire(1) { |
||||
t.Error("TryAcquire(1) should fail with 0 tokens") |
||||
} |
||||
} |
||||
|
||||
func TestRateLimiterContextCancellation(t *testing.T) { |
||||
cfg := DefaultRateLimiterConfig() |
||||
cfg.MaxTokens = 1 |
||||
cfg.RefillRate = 0.1 // Very slow refill
|
||||
rl := NewRateLimiter(cfg) |
||||
|
||||
// Drain tokens
|
||||
rl.TryAcquire(1) |
||||
|
||||
// Create cancellable context
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) |
||||
defer cancel() |
||||
|
||||
// Try to acquire - should be cancelled
|
||||
_, err := rl.Acquire(ctx, 10) |
||||
if err != context.DeadlineExceeded { |
||||
t.Errorf("expected DeadlineExceeded, got %v", err) |
||||
} |
||||
} |
||||
|
||||
func TestRateLimiterRefill(t *testing.T) { |
||||
cfg := DefaultRateLimiterConfig() |
||||
cfg.MaxTokens = 10 |
||||
cfg.RefillRate = 1000 // 1000 tokens per second
|
||||
rl := NewRateLimiter(cfg) |
||||
|
||||
// Drain tokens
|
||||
rl.TryAcquire(10) |
||||
|
||||
// Wait for refill
|
||||
time.Sleep(15 * time.Millisecond) |
||||
|
||||
// Should have some tokens now
|
||||
available := rl.AvailableTokens() |
||||
if available < 5 { |
||||
t.Errorf("expected >= 5 tokens after 15ms at 1000/s, got %v", available) |
||||
} |
||||
if available > 10 { |
||||
t.Errorf("expected <= 10 tokens (max), got %v", available) |
||||
} |
||||
} |
||||
|
||||
func TestRateLimiterPause(t *testing.T) { |
||||
rl := NewRateLimiter(DefaultRateLimiterConfig()) |
||||
ctx := context.Background() |
||||
|
||||
start := time.Now() |
||||
err := rl.Pause(ctx, 1, 0) |
||||
elapsed := time.Since(start) |
||||
|
||||
if err != nil { |
||||
t.Fatalf("unexpected error: %v", err) |
||||
} |
||||
|
||||
// Should have paused for at least baseDelay
|
||||
if elapsed < rl.baseDelay { |
||||
t.Errorf("pause duration %v < baseDelay %v", elapsed, rl.baseDelay) |
||||
} |
||||
} |
||||
|
||||
func TestThrottler(t *testing.T) { |
||||
cfg := DefaultRateLimiterConfig() |
||||
cfg.BaseDelay = 100 * time.Microsecond // Short for testing
|
||||
rl := NewRateLimiter(cfg) |
||||
|
||||
throttler := NewThrottler(rl, 1) |
||||
ctx := context.Background() |
||||
|
||||
// Process items
|
||||
for i := 0; i < 100; i++ { |
||||
if err := throttler.Tick(ctx); err != nil { |
||||
t.Fatalf("unexpected error at tick %d: %v", i, err) |
||||
} |
||||
} |
||||
|
||||
processed := throttler.Complete() |
||||
if processed != 100 { |
||||
t.Errorf("expected 100 items processed, got %d", processed) |
||||
} |
||||
} |
||||
|
||||
func TestThrottlerContextCancellation(t *testing.T) { |
||||
cfg := DefaultRateLimiterConfig() |
||||
rl := NewRateLimiter(cfg) |
||||
|
||||
throttler := NewThrottler(rl, 2) // depth 2 = more frequent pauses
|
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
|
||||
// Process some items
|
||||
for i := 0; i < 20; i++ { |
||||
throttler.Tick(ctx) |
||||
} |
||||
|
||||
// Cancel context
|
||||
cancel() |
||||
|
||||
// Next tick that would pause should return error
|
||||
for i := 0; i < 100; i++ { |
||||
if err := throttler.Tick(ctx); err != nil { |
||||
// Expected - context was cancelled
|
||||
return |
||||
} |
||||
} |
||||
// If we get here without error, the throttler didn't check context
|
||||
// This is acceptable if no pause was needed
|
||||
} |
||||
Loading…
Reference in new issue