Browse Source
- Add LoadMonitor interface in pkg/interfaces/loadmonitor/ for database load metrics - Implement PIDController with filtered derivative to suppress high-frequency noise - Proportional (P): immediate response to current error - Integral (I): eliminates steady-state offset with anti-windup clamping - Derivative (D): rate-of-change prediction with low-pass filtering - Create BadgerLoadMonitor tracking L0 tables, compaction score, and cache hit ratio - Create Neo4jLoadMonitor tracking query semaphore usage and latencies - Add AdaptiveRateLimiter combining PID controllers for reads and writes - Configure via environment variables: - ORLY_RATE_LIMIT_ENABLED: enable/disable rate limiting - ORLY_RATE_LIMIT_TARGET_MB: target memory limit (default 1500MB) - ORLY_RATE_LIMIT_*_K[PID]: PID gains for reads/writes - ORLY_RATE_LIMIT_MAX_*_MS: maximum delays - ORLY_RATE_LIMIT_*_TARGET: setpoints for reads/writes - Integrate rate limiter into Server struct and lifecycle management - Add comprehensive unit tests for PID controller behavior Files modified: - app/config/config.go: Add rate limiting configuration options - app/main.go: Initialize and start/stop rate limiter - app/server.go: Add rateLimiter field to Server struct - main.go: Create rate limiter with appropriate monitor - pkg/run/run.go: Pass disabled limiter for test instances - pkg/interfaces/loadmonitor/: New LoadMonitor interface - pkg/ratelimit/: New PID controller and limiter implementation 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>main
12 changed files with 1511 additions and 13 deletions
@ -0,0 +1,58 @@
@@ -0,0 +1,58 @@
|
||||
// Package loadmonitor defines the interface for database load monitoring.
|
||||
// This allows different database backends to provide their own load metrics
|
||||
// while the rate limiter remains database-agnostic.
|
||||
package loadmonitor |
||||
|
||||
import "time" |
||||
|
||||
// Metrics contains load metrics from a database backend.
|
||||
// All values are normalized to 0.0-1.0 where 0 means no load and 1 means at capacity.
|
||||
type Metrics struct { |
||||
// MemoryPressure indicates memory usage relative to a target limit (0.0-1.0+).
|
||||
// Values above 1.0 indicate the target has been exceeded.
|
||||
MemoryPressure float64 |
||||
|
||||
// WriteLoad indicates the write-side load level (0.0-1.0).
|
||||
// For Badger: L0 tables and compaction score
|
||||
// For Neo4j: active write transactions
|
||||
WriteLoad float64 |
||||
|
||||
// ReadLoad indicates the read-side load level (0.0-1.0).
|
||||
// For Badger: cache hit ratio (inverted)
|
||||
// For Neo4j: active read transactions
|
||||
ReadLoad float64 |
||||
|
||||
// QueryLatency is the recent average query latency.
|
||||
QueryLatency time.Duration |
||||
|
||||
// WriteLatency is the recent average write latency.
|
||||
WriteLatency time.Duration |
||||
|
||||
// Timestamp is when these metrics were collected.
|
||||
Timestamp time.Time |
||||
} |
||||
|
||||
// Monitor defines the interface for database load monitoring.
|
||||
// Implementations are database-specific (Badger, Neo4j, etc.).
|
||||
type Monitor interface { |
||||
// GetMetrics returns the current load metrics.
|
||||
// This should be efficient as it may be called frequently.
|
||||
GetMetrics() Metrics |
||||
|
||||
// RecordQueryLatency records a query latency sample for averaging.
|
||||
RecordQueryLatency(latency time.Duration) |
||||
|
||||
// RecordWriteLatency records a write latency sample for averaging.
|
||||
RecordWriteLatency(latency time.Duration) |
||||
|
||||
// SetMemoryTarget sets the target memory limit in bytes.
|
||||
// Memory pressure is calculated relative to this target.
|
||||
SetMemoryTarget(bytes uint64) |
||||
|
||||
// Start begins background metric collection.
|
||||
// Returns a channel that will be closed when the monitor is stopped.
|
||||
Start() <-chan struct{} |
||||
|
||||
// Stop halts background metric collection.
|
||||
Stop() |
||||
} |
||||
@ -0,0 +1,237 @@
@@ -0,0 +1,237 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package ratelimit |
||||
|
||||
import ( |
||||
"runtime" |
||||
"sync" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"github.com/dgraph-io/badger/v4" |
||||
"next.orly.dev/pkg/interfaces/loadmonitor" |
||||
) |
||||
|
||||
// BadgerMonitor implements loadmonitor.Monitor for the Badger database.
|
||||
// It collects metrics from Badger's LSM tree, caches, and Go runtime.
|
||||
type BadgerMonitor struct { |
||||
db *badger.DB |
||||
|
||||
// Target memory for pressure calculation
|
||||
targetMemoryBytes atomic.Uint64 |
||||
|
||||
// Latency tracking with exponential moving average
|
||||
queryLatencyNs atomic.Int64 |
||||
writeLatencyNs atomic.Int64 |
||||
latencyAlpha float64 // EMA coefficient (default 0.1)
|
||||
|
||||
// Cached metrics (updated by background goroutine)
|
||||
metricsLock sync.RWMutex |
||||
cachedMetrics loadmonitor.Metrics |
||||
lastL0Tables int |
||||
lastL0Score float64 |
||||
|
||||
// Background collection
|
||||
stopChan chan struct{} |
||||
stopped chan struct{} |
||||
interval time.Duration |
||||
} |
||||
|
||||
// Compile-time check that BadgerMonitor implements loadmonitor.Monitor
|
||||
var _ loadmonitor.Monitor = (*BadgerMonitor)(nil) |
||||
|
||||
// NewBadgerMonitor creates a new Badger load monitor.
|
||||
// The updateInterval controls how often metrics are collected (default 100ms).
|
||||
func NewBadgerMonitor(db *badger.DB, updateInterval time.Duration) *BadgerMonitor { |
||||
if updateInterval <= 0 { |
||||
updateInterval = 100 * time.Millisecond |
||||
} |
||||
|
||||
m := &BadgerMonitor{ |
||||
db: db, |
||||
latencyAlpha: 0.1, // 10% new, 90% old for smooth EMA
|
||||
stopChan: make(chan struct{}), |
||||
stopped: make(chan struct{}), |
||||
interval: updateInterval, |
||||
} |
||||
|
||||
// Set a default target (1.5GB)
|
||||
m.targetMemoryBytes.Store(1500 * 1024 * 1024) |
||||
|
||||
return m |
||||
} |
||||
|
||||
// GetMetrics returns the current load metrics.
|
||||
func (m *BadgerMonitor) GetMetrics() loadmonitor.Metrics { |
||||
m.metricsLock.RLock() |
||||
defer m.metricsLock.RUnlock() |
||||
return m.cachedMetrics |
||||
} |
||||
|
||||
// RecordQueryLatency records a query latency sample using exponential moving average.
|
||||
func (m *BadgerMonitor) RecordQueryLatency(latency time.Duration) { |
||||
ns := latency.Nanoseconds() |
||||
for { |
||||
old := m.queryLatencyNs.Load() |
||||
if old == 0 { |
||||
if m.queryLatencyNs.CompareAndSwap(0, ns) { |
||||
return |
||||
} |
||||
continue |
||||
} |
||||
// EMA: new = alpha * sample + (1-alpha) * old
|
||||
newVal := int64(m.latencyAlpha*float64(ns) + (1-m.latencyAlpha)*float64(old)) |
||||
if m.queryLatencyNs.CompareAndSwap(old, newVal) { |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
// RecordWriteLatency records a write latency sample using exponential moving average.
|
||||
func (m *BadgerMonitor) RecordWriteLatency(latency time.Duration) { |
||||
ns := latency.Nanoseconds() |
||||
for { |
||||
old := m.writeLatencyNs.Load() |
||||
if old == 0 { |
||||
if m.writeLatencyNs.CompareAndSwap(0, ns) { |
||||
return |
||||
} |
||||
continue |
||||
} |
||||
// EMA: new = alpha * sample + (1-alpha) * old
|
||||
newVal := int64(m.latencyAlpha*float64(ns) + (1-m.latencyAlpha)*float64(old)) |
||||
if m.writeLatencyNs.CompareAndSwap(old, newVal) { |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
// SetMemoryTarget sets the target memory limit in bytes.
|
||||
func (m *BadgerMonitor) SetMemoryTarget(bytes uint64) { |
||||
m.targetMemoryBytes.Store(bytes) |
||||
} |
||||
|
||||
// Start begins background metric collection.
|
||||
func (m *BadgerMonitor) Start() <-chan struct{} { |
||||
go m.collectLoop() |
||||
return m.stopped |
||||
} |
||||
|
||||
// Stop halts background metric collection.
|
||||
func (m *BadgerMonitor) Stop() { |
||||
close(m.stopChan) |
||||
<-m.stopped |
||||
} |
||||
|
||||
// collectLoop periodically collects metrics from Badger.
|
||||
func (m *BadgerMonitor) collectLoop() { |
||||
defer close(m.stopped) |
||||
|
||||
ticker := time.NewTicker(m.interval) |
||||
defer ticker.Stop() |
||||
|
||||
for { |
||||
select { |
||||
case <-m.stopChan: |
||||
return |
||||
case <-ticker.C: |
||||
m.updateMetrics() |
||||
} |
||||
} |
||||
} |
||||
|
||||
// updateMetrics collects current metrics from Badger and runtime.
|
||||
func (m *BadgerMonitor) updateMetrics() { |
||||
if m.db == nil || m.db.IsClosed() { |
||||
return |
||||
} |
||||
|
||||
metrics := loadmonitor.Metrics{ |
||||
Timestamp: time.Now(), |
||||
} |
||||
|
||||
// Calculate memory pressure from Go runtime
|
||||
var memStats runtime.MemStats |
||||
runtime.ReadMemStats(&memStats) |
||||
|
||||
targetBytes := m.targetMemoryBytes.Load() |
||||
if targetBytes > 0 { |
||||
// Use HeapAlloc as primary memory metric
|
||||
// This represents the actual live heap objects
|
||||
metrics.MemoryPressure = float64(memStats.HeapAlloc) / float64(targetBytes) |
||||
} |
||||
|
||||
// Get Badger LSM tree information for write load
|
||||
levels := m.db.Levels() |
||||
var l0Tables int |
||||
var maxScore float64 |
||||
|
||||
for _, level := range levels { |
||||
if level.Level == 0 { |
||||
l0Tables = level.NumTables |
||||
} |
||||
if level.Score > maxScore { |
||||
maxScore = level.Score |
||||
} |
||||
} |
||||
|
||||
// Calculate write load based on L0 tables and compaction score
|
||||
// L0 tables stall at NumLevelZeroTablesStall (default 16)
|
||||
// We consider write pressure high when approaching that limit
|
||||
const l0StallThreshold = 16 |
||||
l0Load := float64(l0Tables) / float64(l0StallThreshold) |
||||
if l0Load > 1.0 { |
||||
l0Load = 1.0 |
||||
} |
||||
|
||||
// Compaction score > 1.0 means compaction is needed
|
||||
// We blend L0 tables and compaction score for write load
|
||||
compactionLoad := maxScore / 2.0 // Score of 2.0 = fully loaded
|
||||
if compactionLoad > 1.0 { |
||||
compactionLoad = 1.0 |
||||
} |
||||
|
||||
// Blend: 60% L0 (immediate backpressure), 40% compaction score
|
||||
metrics.WriteLoad = 0.6*l0Load + 0.4*compactionLoad |
||||
|
||||
// Calculate read load from cache metrics
|
||||
blockMetrics := m.db.BlockCacheMetrics() |
||||
indexMetrics := m.db.IndexCacheMetrics() |
||||
|
||||
var blockHitRatio, indexHitRatio float64 |
||||
if blockMetrics != nil { |
||||
blockHitRatio = blockMetrics.Ratio() |
||||
} |
||||
if indexMetrics != nil { |
||||
indexHitRatio = indexMetrics.Ratio() |
||||
} |
||||
|
||||
// Average cache hit ratio (0 = no hits = high load, 1 = all hits = low load)
|
||||
avgHitRatio := (blockHitRatio + indexHitRatio) / 2.0 |
||||
|
||||
// Invert: low hit ratio = high read load
|
||||
// Use 0.5 as the threshold (below 50% hit ratio is concerning)
|
||||
if avgHitRatio < 0.5 { |
||||
metrics.ReadLoad = 1.0 - avgHitRatio*2 // 0% hits = 1.0 load, 50% hits = 0.0 load
|
||||
} else { |
||||
metrics.ReadLoad = 0 // Above 50% hit ratio = minimal load
|
||||
} |
||||
|
||||
// Store latencies
|
||||
metrics.QueryLatency = time.Duration(m.queryLatencyNs.Load()) |
||||
metrics.WriteLatency = time.Duration(m.writeLatencyNs.Load()) |
||||
|
||||
// Update cached metrics
|
||||
m.metricsLock.Lock() |
||||
m.cachedMetrics = metrics |
||||
m.lastL0Tables = l0Tables |
||||
m.lastL0Score = maxScore |
||||
m.metricsLock.Unlock() |
||||
} |
||||
|
||||
// GetL0Stats returns L0-specific statistics for debugging.
|
||||
func (m *BadgerMonitor) GetL0Stats() (tables int, score float64) { |
||||
m.metricsLock.RLock() |
||||
defer m.metricsLock.RUnlock() |
||||
return m.lastL0Tables, m.lastL0Score |
||||
} |
||||
@ -0,0 +1,56 @@
@@ -0,0 +1,56 @@
|
||||
//go:build !(js && wasm)
|
||||
|
||||
package ratelimit |
||||
|
||||
import ( |
||||
"time" |
||||
|
||||
"github.com/dgraph-io/badger/v4" |
||||
"github.com/neo4j/neo4j-go-driver/v5/neo4j" |
||||
"next.orly.dev/pkg/interfaces/loadmonitor" |
||||
) |
||||
|
||||
// NewBadgerLimiter creates a rate limiter configured for a Badger database.
|
||||
// It automatically creates a BadgerMonitor for the provided database.
|
||||
func NewBadgerLimiter(config Config, db *badger.DB) *Limiter { |
||||
monitor := NewBadgerMonitor(db, 100*time.Millisecond) |
||||
return NewLimiter(config, monitor) |
||||
} |
||||
|
||||
// NewNeo4jLimiter creates a rate limiter configured for a Neo4j database.
|
||||
// It automatically creates a Neo4jMonitor for the provided driver.
|
||||
// querySem should be the semaphore used to limit concurrent queries.
|
||||
// maxConcurrency is typically 10 (matching the semaphore size).
|
||||
func NewNeo4jLimiter( |
||||
config Config, |
||||
driver neo4j.DriverWithContext, |
||||
querySem chan struct{}, |
||||
maxConcurrency int, |
||||
) *Limiter { |
||||
monitor := NewNeo4jMonitor(driver, querySem, maxConcurrency, 100*time.Millisecond) |
||||
return NewLimiter(config, monitor) |
||||
} |
||||
|
||||
// NewDisabledLimiter creates a rate limiter that is disabled.
|
||||
// This is useful when rate limiting is not configured.
|
||||
func NewDisabledLimiter() *Limiter { |
||||
config := DefaultConfig() |
||||
config.Enabled = false |
||||
return NewLimiter(config, nil) |
||||
} |
||||
|
||||
// MonitorFromBadgerDB creates a BadgerMonitor from a Badger database.
|
||||
// Exported for use when you need to create the monitor separately.
|
||||
func MonitorFromBadgerDB(db *badger.DB) loadmonitor.Monitor { |
||||
return NewBadgerMonitor(db, 100*time.Millisecond) |
||||
} |
||||
|
||||
// MonitorFromNeo4jDriver creates a Neo4jMonitor from a Neo4j driver.
|
||||
// Exported for use when you need to create the monitor separately.
|
||||
func MonitorFromNeo4jDriver( |
||||
driver neo4j.DriverWithContext, |
||||
querySem chan struct{}, |
||||
maxConcurrency int, |
||||
) loadmonitor.Monitor { |
||||
return NewNeo4jMonitor(driver, querySem, maxConcurrency, 100*time.Millisecond) |
||||
} |
||||
@ -0,0 +1,409 @@
@@ -0,0 +1,409 @@
|
||||
package ratelimit |
||||
|
||||
import ( |
||||
"context" |
||||
"sync" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"next.orly.dev/pkg/interfaces/loadmonitor" |
||||
) |
||||
|
||||
// OperationType distinguishes between read and write operations
|
||||
// for applying different rate limiting strategies.
|
||||
type OperationType int |
||||
|
||||
const ( |
||||
// Read operations (REQ queries)
|
||||
Read OperationType = iota |
||||
// Write operations (EVENT saves, imports)
|
||||
Write |
||||
) |
||||
|
||||
// String returns a human-readable name for the operation type.
|
||||
func (o OperationType) String() string { |
||||
switch o { |
||||
case Read: |
||||
return "read" |
||||
case Write: |
||||
return "write" |
||||
default: |
||||
return "unknown" |
||||
} |
||||
} |
||||
|
||||
// Config holds configuration for the adaptive rate limiter.
|
||||
type Config struct { |
||||
// Enabled controls whether rate limiting is active.
|
||||
Enabled bool |
||||
|
||||
// TargetMemoryMB is the target memory limit in megabytes.
|
||||
// Memory pressure is calculated relative to this target.
|
||||
TargetMemoryMB int |
||||
|
||||
// WriteSetpoint is the target process variable for writes (0.0-1.0).
|
||||
// Default: 0.85 (throttle when load exceeds 85%)
|
||||
WriteSetpoint float64 |
||||
|
||||
// ReadSetpoint is the target process variable for reads (0.0-1.0).
|
||||
// Default: 0.90 (more tolerant for reads)
|
||||
ReadSetpoint float64 |
||||
|
||||
// PID gains for writes
|
||||
WriteKp float64 |
||||
WriteKi float64 |
||||
WriteKd float64 |
||||
|
||||
// PID gains for reads
|
||||
ReadKp float64 |
||||
ReadKi float64 |
||||
ReadKd float64 |
||||
|
||||
// MaxWriteDelayMs is the maximum delay for write operations in milliseconds.
|
||||
MaxWriteDelayMs int |
||||
|
||||
// MaxReadDelayMs is the maximum delay for read operations in milliseconds.
|
||||
MaxReadDelayMs int |
||||
|
||||
// MetricUpdateInterval is how often to poll the load monitor.
|
||||
MetricUpdateInterval time.Duration |
||||
|
||||
// MemoryWeight is the weight given to memory pressure in process variable (0.0-1.0).
|
||||
// The remaining weight is given to the load metric.
|
||||
// Default: 0.7 (70% memory, 30% load)
|
||||
MemoryWeight float64 |
||||
} |
||||
|
||||
// DefaultConfig returns a default configuration for the rate limiter.
|
||||
func DefaultConfig() Config { |
||||
return Config{ |
||||
Enabled: true, |
||||
TargetMemoryMB: 1500, // 1.5GB target
|
||||
WriteSetpoint: 0.85, |
||||
ReadSetpoint: 0.90, |
||||
WriteKp: 0.5, |
||||
WriteKi: 0.1, |
||||
WriteKd: 0.05, |
||||
ReadKp: 0.3, |
||||
ReadKi: 0.05, |
||||
ReadKd: 0.02, |
||||
MaxWriteDelayMs: 1000, // 1 second max
|
||||
MaxReadDelayMs: 500, // 500ms max
|
||||
MetricUpdateInterval: 100 * time.Millisecond, |
||||
MemoryWeight: 0.7, |
||||
} |
||||
} |
||||
|
||||
// NewConfigFromValues creates a Config from individual configuration values.
|
||||
// This is useful when loading configuration from environment variables.
|
||||
func NewConfigFromValues( |
||||
enabled bool, |
||||
targetMB int, |
||||
writeKp, writeKi, writeKd float64, |
||||
readKp, readKi, readKd float64, |
||||
maxWriteMs, maxReadMs int, |
||||
writeTarget, readTarget float64, |
||||
) Config { |
||||
return Config{ |
||||
Enabled: enabled, |
||||
TargetMemoryMB: targetMB, |
||||
WriteSetpoint: writeTarget, |
||||
ReadSetpoint: readTarget, |
||||
WriteKp: writeKp, |
||||
WriteKi: writeKi, |
||||
WriteKd: writeKd, |
||||
ReadKp: readKp, |
||||
ReadKi: readKi, |
||||
ReadKd: readKd, |
||||
MaxWriteDelayMs: maxWriteMs, |
||||
MaxReadDelayMs: maxReadMs, |
||||
MetricUpdateInterval: 100 * time.Millisecond, |
||||
MemoryWeight: 0.7, |
||||
} |
||||
} |
||||
|
||||
// Limiter implements adaptive rate limiting using PID control.
|
||||
// It monitors database load metrics and computes appropriate delays
|
||||
// to keep the system within its target operating range.
|
||||
type Limiter struct { |
||||
config Config |
||||
monitor loadmonitor.Monitor |
||||
|
||||
// PID controllers for reads and writes
|
||||
writePID *PIDController |
||||
readPID *PIDController |
||||
|
||||
// Cached metrics (updated periodically)
|
||||
metricsLock sync.RWMutex |
||||
currentMetrics loadmonitor.Metrics |
||||
|
||||
// Statistics
|
||||
totalWriteDelayMs atomic.Int64 |
||||
totalReadDelayMs atomic.Int64 |
||||
writeThrottles atomic.Int64 |
||||
readThrottles atomic.Int64 |
||||
|
||||
// Lifecycle
|
||||
ctx context.Context |
||||
cancel context.CancelFunc |
||||
stopOnce sync.Once |
||||
stopped chan struct{} |
||||
wg sync.WaitGroup |
||||
} |
||||
|
||||
// NewLimiter creates a new adaptive rate limiter.
|
||||
// If monitor is nil, the limiter will be disabled.
|
||||
func NewLimiter(config Config, monitor loadmonitor.Monitor) *Limiter { |
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
|
||||
l := &Limiter{ |
||||
config: config, |
||||
monitor: monitor, |
||||
ctx: ctx, |
||||
cancel: cancel, |
||||
stopped: make(chan struct{}), |
||||
} |
||||
|
||||
// Create PID controllers with configured gains
|
||||
l.writePID = NewPIDController( |
||||
config.WriteKp, config.WriteKi, config.WriteKd, |
||||
config.WriteSetpoint, |
||||
0.2, // Strong filtering for writes
|
||||
-2.0, float64(config.MaxWriteDelayMs)/1000.0*2, // Anti-windup limits
|
||||
0, float64(config.MaxWriteDelayMs)/1000.0, |
||||
) |
||||
|
||||
l.readPID = NewPIDController( |
||||
config.ReadKp, config.ReadKi, config.ReadKd, |
||||
config.ReadSetpoint, |
||||
0.15, // Very strong filtering for reads
|
||||
-1.0, float64(config.MaxReadDelayMs)/1000.0*2, |
||||
0, float64(config.MaxReadDelayMs)/1000.0, |
||||
) |
||||
|
||||
// Set memory target on monitor
|
||||
if monitor != nil && config.TargetMemoryMB > 0 { |
||||
monitor.SetMemoryTarget(uint64(config.TargetMemoryMB) * 1024 * 1024) |
||||
} |
||||
|
||||
return l |
||||
} |
||||
|
||||
// Start begins the rate limiter's background metric collection.
|
||||
func (l *Limiter) Start() { |
||||
if l.monitor == nil || !l.config.Enabled { |
||||
return |
||||
} |
||||
|
||||
// Start the monitor
|
||||
l.monitor.Start() |
||||
|
||||
// Start metric update loop
|
||||
l.wg.Add(1) |
||||
go l.updateLoop() |
||||
} |
||||
|
||||
// updateLoop periodically fetches metrics from the monitor.
|
||||
func (l *Limiter) updateLoop() { |
||||
defer l.wg.Done() |
||||
|
||||
ticker := time.NewTicker(l.config.MetricUpdateInterval) |
||||
defer ticker.Stop() |
||||
|
||||
for { |
||||
select { |
||||
case <-l.ctx.Done(): |
||||
return |
||||
case <-ticker.C: |
||||
if l.monitor != nil { |
||||
metrics := l.monitor.GetMetrics() |
||||
l.metricsLock.Lock() |
||||
l.currentMetrics = metrics |
||||
l.metricsLock.Unlock() |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Stop halts the rate limiter.
|
||||
func (l *Limiter) Stop() { |
||||
l.stopOnce.Do(func() { |
||||
l.cancel() |
||||
if l.monitor != nil { |
||||
l.monitor.Stop() |
||||
} |
||||
l.wg.Wait() |
||||
close(l.stopped) |
||||
}) |
||||
} |
||||
|
||||
// Stopped returns a channel that closes when the limiter has stopped.
|
||||
func (l *Limiter) Stopped() <-chan struct{} { |
||||
return l.stopped |
||||
} |
||||
|
||||
// Wait blocks until the rate limiter permits the operation to proceed.
|
||||
// It returns the delay that was applied, or 0 if no delay was needed.
|
||||
// If the context is cancelled, it returns immediately.
|
||||
func (l *Limiter) Wait(ctx context.Context, opType OperationType) time.Duration { |
||||
if !l.config.Enabled || l.monitor == nil { |
||||
return 0 |
||||
} |
||||
|
||||
delay := l.ComputeDelay(opType) |
||||
if delay <= 0 { |
||||
return 0 |
||||
} |
||||
|
||||
// Apply the delay
|
||||
select { |
||||
case <-ctx.Done(): |
||||
return 0 |
||||
case <-time.After(delay): |
||||
return delay |
||||
} |
||||
} |
||||
|
||||
// ComputeDelay calculates the recommended delay for an operation.
|
||||
// This can be used to check the delay without actually waiting.
|
||||
func (l *Limiter) ComputeDelay(opType OperationType) time.Duration { |
||||
if !l.config.Enabled || l.monitor == nil { |
||||
return 0 |
||||
} |
||||
|
||||
// Get current metrics
|
||||
l.metricsLock.RLock() |
||||
metrics := l.currentMetrics |
||||
l.metricsLock.RUnlock() |
||||
|
||||
// Compute process variable as weighted combination of memory and load
|
||||
var loadMetric float64 |
||||
switch opType { |
||||
case Write: |
||||
loadMetric = metrics.WriteLoad |
||||
case Read: |
||||
loadMetric = metrics.ReadLoad |
||||
} |
||||
|
||||
// Combine memory pressure and load
|
||||
// Process variable = memoryWeight * memoryPressure + (1-memoryWeight) * loadMetric
|
||||
pv := l.config.MemoryWeight*metrics.MemoryPressure + (1-l.config.MemoryWeight)*loadMetric |
||||
|
||||
// Select the appropriate PID controller
|
||||
var delaySec float64 |
||||
switch opType { |
||||
case Write: |
||||
delaySec = l.writePID.Update(pv) |
||||
if delaySec > 0 { |
||||
l.writeThrottles.Add(1) |
||||
l.totalWriteDelayMs.Add(int64(delaySec * 1000)) |
||||
} |
||||
case Read: |
||||
delaySec = l.readPID.Update(pv) |
||||
if delaySec > 0 { |
||||
l.readThrottles.Add(1) |
||||
l.totalReadDelayMs.Add(int64(delaySec * 1000)) |
||||
} |
||||
} |
||||
|
||||
if delaySec <= 0 { |
||||
return 0 |
||||
} |
||||
|
||||
return time.Duration(delaySec * float64(time.Second)) |
||||
} |
||||
|
||||
// RecordLatency records an operation latency for the monitor.
|
||||
func (l *Limiter) RecordLatency(opType OperationType, latency time.Duration) { |
||||
if l.monitor == nil { |
||||
return |
||||
} |
||||
|
||||
switch opType { |
||||
case Write: |
||||
l.monitor.RecordWriteLatency(latency) |
||||
case Read: |
||||
l.monitor.RecordQueryLatency(latency) |
||||
} |
||||
} |
||||
|
||||
// Stats returns rate limiter statistics.
|
||||
type Stats struct { |
||||
WriteThrottles int64 |
||||
ReadThrottles int64 |
||||
TotalWriteDelayMs int64 |
||||
TotalReadDelayMs int64 |
||||
CurrentMetrics loadmonitor.Metrics |
||||
WritePIDState PIDState |
||||
ReadPIDState PIDState |
||||
} |
||||
|
||||
// PIDState contains the internal state of a PID controller.
|
||||
type PIDState struct { |
||||
Integral float64 |
||||
PrevError float64 |
||||
PrevFilteredError float64 |
||||
} |
||||
|
||||
// GetStats returns current rate limiter statistics.
|
||||
func (l *Limiter) GetStats() Stats { |
||||
l.metricsLock.RLock() |
||||
metrics := l.currentMetrics |
||||
l.metricsLock.RUnlock() |
||||
|
||||
wIntegral, wPrevErr, wPrevFiltered := l.writePID.GetState() |
||||
rIntegral, rPrevErr, rPrevFiltered := l.readPID.GetState() |
||||
|
||||
return Stats{ |
||||
WriteThrottles: l.writeThrottles.Load(), |
||||
ReadThrottles: l.readThrottles.Load(), |
||||
TotalWriteDelayMs: l.totalWriteDelayMs.Load(), |
||||
TotalReadDelayMs: l.totalReadDelayMs.Load(), |
||||
CurrentMetrics: metrics, |
||||
WritePIDState: PIDState{ |
||||
Integral: wIntegral, |
||||
PrevError: wPrevErr, |
||||
PrevFilteredError: wPrevFiltered, |
||||
}, |
||||
ReadPIDState: PIDState{ |
||||
Integral: rIntegral, |
||||
PrevError: rPrevErr, |
||||
PrevFilteredError: rPrevFiltered, |
||||
}, |
||||
} |
||||
} |
||||
|
||||
// Reset clears all PID controller state and statistics.
|
||||
func (l *Limiter) Reset() { |
||||
l.writePID.Reset() |
||||
l.readPID.Reset() |
||||
l.writeThrottles.Store(0) |
||||
l.readThrottles.Store(0) |
||||
l.totalWriteDelayMs.Store(0) |
||||
l.totalReadDelayMs.Store(0) |
||||
} |
||||
|
||||
// IsEnabled returns whether rate limiting is active.
|
||||
func (l *Limiter) IsEnabled() bool { |
||||
return l.config.Enabled && l.monitor != nil |
||||
} |
||||
|
||||
// UpdateConfig updates the rate limiter configuration.
|
||||
// This is useful for dynamic tuning.
|
||||
func (l *Limiter) UpdateConfig(config Config) { |
||||
l.config = config |
||||
|
||||
// Update PID controllers
|
||||
l.writePID.SetSetpoint(config.WriteSetpoint) |
||||
l.writePID.SetGains(config.WriteKp, config.WriteKi, config.WriteKd) |
||||
l.writePID.OutputMax = float64(config.MaxWriteDelayMs) / 1000.0 |
||||
|
||||
l.readPID.SetSetpoint(config.ReadSetpoint) |
||||
l.readPID.SetGains(config.ReadKp, config.ReadKi, config.ReadKd) |
||||
l.readPID.OutputMax = float64(config.MaxReadDelayMs) / 1000.0 |
||||
|
||||
// Update memory target
|
||||
if l.monitor != nil && config.TargetMemoryMB > 0 { |
||||
l.monitor.SetMemoryTarget(uint64(config.TargetMemoryMB) * 1024 * 1024) |
||||
} |
||||
} |
||||
@ -0,0 +1,259 @@
@@ -0,0 +1,259 @@
|
||||
package ratelimit |
||||
|
||||
import ( |
||||
"context" |
||||
"runtime" |
||||
"sync" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"github.com/neo4j/neo4j-go-driver/v5/neo4j" |
||||
"next.orly.dev/pkg/interfaces/loadmonitor" |
||||
) |
||||
|
||||
// Neo4jMonitor implements loadmonitor.Monitor for Neo4j database.
|
||||
// Since Neo4j driver doesn't expose detailed metrics, we track:
|
||||
// - Memory pressure via Go runtime
|
||||
// - Query concurrency via the semaphore
|
||||
// - Latency via recording
|
||||
type Neo4jMonitor struct { |
||||
driver neo4j.DriverWithContext |
||||
querySem chan struct{} // Reference to the query semaphore
|
||||
|
||||
// Target memory for pressure calculation
|
||||
targetMemoryBytes atomic.Uint64 |
||||
|
||||
// Latency tracking with exponential moving average
|
||||
queryLatencyNs atomic.Int64 |
||||
writeLatencyNs atomic.Int64 |
||||
latencyAlpha float64 // EMA coefficient (default 0.1)
|
||||
|
||||
// Concurrency tracking
|
||||
activeReads atomic.Int32 |
||||
activeWrites atomic.Int32 |
||||
maxConcurrency int |
||||
|
||||
// Cached metrics (updated by background goroutine)
|
||||
metricsLock sync.RWMutex |
||||
cachedMetrics loadmonitor.Metrics |
||||
|
||||
// Background collection
|
||||
stopChan chan struct{} |
||||
stopped chan struct{} |
||||
interval time.Duration |
||||
} |
||||
|
||||
// Compile-time check that Neo4jMonitor implements loadmonitor.Monitor
|
||||
var _ loadmonitor.Monitor = (*Neo4jMonitor)(nil) |
||||
|
||||
// NewNeo4jMonitor creates a new Neo4j load monitor.
|
||||
// The querySem should be the same semaphore used for limiting concurrent queries.
|
||||
// maxConcurrency is the maximum concurrent query limit (typically 10).
|
||||
func NewNeo4jMonitor( |
||||
driver neo4j.DriverWithContext, |
||||
querySem chan struct{}, |
||||
maxConcurrency int, |
||||
updateInterval time.Duration, |
||||
) *Neo4jMonitor { |
||||
if updateInterval <= 0 { |
||||
updateInterval = 100 * time.Millisecond |
||||
} |
||||
if maxConcurrency <= 0 { |
||||
maxConcurrency = 10 |
||||
} |
||||
|
||||
m := &Neo4jMonitor{ |
||||
driver: driver, |
||||
querySem: querySem, |
||||
maxConcurrency: maxConcurrency, |
||||
latencyAlpha: 0.1, // 10% new, 90% old for smooth EMA
|
||||
stopChan: make(chan struct{}), |
||||
stopped: make(chan struct{}), |
||||
interval: updateInterval, |
||||
} |
||||
|
||||
// Set a default target (1.5GB)
|
||||
m.targetMemoryBytes.Store(1500 * 1024 * 1024) |
||||
|
||||
return m |
||||
} |
||||
|
||||
// GetMetrics returns the current load metrics.
|
||||
func (m *Neo4jMonitor) GetMetrics() loadmonitor.Metrics { |
||||
m.metricsLock.RLock() |
||||
defer m.metricsLock.RUnlock() |
||||
return m.cachedMetrics |
||||
} |
||||
|
||||
// RecordQueryLatency records a query latency sample using exponential moving average.
|
||||
func (m *Neo4jMonitor) RecordQueryLatency(latency time.Duration) { |
||||
ns := latency.Nanoseconds() |
||||
for { |
||||
old := m.queryLatencyNs.Load() |
||||
if old == 0 { |
||||
if m.queryLatencyNs.CompareAndSwap(0, ns) { |
||||
return |
||||
} |
||||
continue |
||||
} |
||||
// EMA: new = alpha * sample + (1-alpha) * old
|
||||
newVal := int64(m.latencyAlpha*float64(ns) + (1-m.latencyAlpha)*float64(old)) |
||||
if m.queryLatencyNs.CompareAndSwap(old, newVal) { |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
// RecordWriteLatency records a write latency sample using exponential moving average.
|
||||
func (m *Neo4jMonitor) RecordWriteLatency(latency time.Duration) { |
||||
ns := latency.Nanoseconds() |
||||
for { |
||||
old := m.writeLatencyNs.Load() |
||||
if old == 0 { |
||||
if m.writeLatencyNs.CompareAndSwap(0, ns) { |
||||
return |
||||
} |
||||
continue |
||||
} |
||||
// EMA: new = alpha * sample + (1-alpha) * old
|
||||
newVal := int64(m.latencyAlpha*float64(ns) + (1-m.latencyAlpha)*float64(old)) |
||||
if m.writeLatencyNs.CompareAndSwap(old, newVal) { |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
// SetMemoryTarget sets the target memory limit in bytes.
|
||||
func (m *Neo4jMonitor) SetMemoryTarget(bytes uint64) { |
||||
m.targetMemoryBytes.Store(bytes) |
||||
} |
||||
|
||||
// Start begins background metric collection.
|
||||
func (m *Neo4jMonitor) Start() <-chan struct{} { |
||||
go m.collectLoop() |
||||
return m.stopped |
||||
} |
||||
|
||||
// Stop halts background metric collection.
|
||||
func (m *Neo4jMonitor) Stop() { |
||||
close(m.stopChan) |
||||
<-m.stopped |
||||
} |
||||
|
||||
// collectLoop periodically collects metrics.
|
||||
func (m *Neo4jMonitor) collectLoop() { |
||||
defer close(m.stopped) |
||||
|
||||
ticker := time.NewTicker(m.interval) |
||||
defer ticker.Stop() |
||||
|
||||
for { |
||||
select { |
||||
case <-m.stopChan: |
||||
return |
||||
case <-ticker.C: |
||||
m.updateMetrics() |
||||
} |
||||
} |
||||
} |
||||
|
||||
// updateMetrics collects current metrics.
|
||||
func (m *Neo4jMonitor) updateMetrics() { |
||||
metrics := loadmonitor.Metrics{ |
||||
Timestamp: time.Now(), |
||||
} |
||||
|
||||
// Calculate memory pressure from Go runtime
|
||||
var memStats runtime.MemStats |
||||
runtime.ReadMemStats(&memStats) |
||||
|
||||
targetBytes := m.targetMemoryBytes.Load() |
||||
if targetBytes > 0 { |
||||
// Use HeapAlloc as primary memory metric
|
||||
metrics.MemoryPressure = float64(memStats.HeapAlloc) / float64(targetBytes) |
||||
} |
||||
|
||||
// Calculate load from semaphore usage
|
||||
// querySem is a buffered channel - count how many slots are taken
|
||||
if m.querySem != nil { |
||||
usedSlots := len(m.querySem) |
||||
concurrencyLoad := float64(usedSlots) / float64(m.maxConcurrency) |
||||
if concurrencyLoad > 1.0 { |
||||
concurrencyLoad = 1.0 |
||||
} |
||||
// Both read and write use the same semaphore
|
||||
metrics.WriteLoad = concurrencyLoad |
||||
metrics.ReadLoad = concurrencyLoad |
||||
} |
||||
|
||||
// Add latency-based load adjustment
|
||||
// High latency indicates the database is struggling
|
||||
queryLatencyNs := m.queryLatencyNs.Load() |
||||
writeLatencyNs := m.writeLatencyNs.Load() |
||||
|
||||
// Consider > 500ms query latency as concerning
|
||||
const latencyThresholdNs = 500 * 1e6 // 500ms
|
||||
if queryLatencyNs > 0 { |
||||
latencyLoad := float64(queryLatencyNs) / float64(latencyThresholdNs) |
||||
if latencyLoad > 1.0 { |
||||
latencyLoad = 1.0 |
||||
} |
||||
// Blend concurrency and latency for read load
|
||||
metrics.ReadLoad = 0.5*metrics.ReadLoad + 0.5*latencyLoad |
||||
} |
||||
|
||||
if writeLatencyNs > 0 { |
||||
latencyLoad := float64(writeLatencyNs) / float64(latencyThresholdNs) |
||||
if latencyLoad > 1.0 { |
||||
latencyLoad = 1.0 |
||||
} |
||||
// Blend concurrency and latency for write load
|
||||
metrics.WriteLoad = 0.5*metrics.WriteLoad + 0.5*latencyLoad |
||||
} |
||||
|
||||
// Store latencies
|
||||
metrics.QueryLatency = time.Duration(queryLatencyNs) |
||||
metrics.WriteLatency = time.Duration(writeLatencyNs) |
||||
|
||||
// Update cached metrics
|
||||
m.metricsLock.Lock() |
||||
m.cachedMetrics = metrics |
||||
m.metricsLock.Unlock() |
||||
} |
||||
|
||||
// IncrementActiveReads tracks an active read operation.
|
||||
// Call this when starting a read, and call the returned function when done.
|
||||
func (m *Neo4jMonitor) IncrementActiveReads() func() { |
||||
m.activeReads.Add(1) |
||||
return func() { |
||||
m.activeReads.Add(-1) |
||||
} |
||||
} |
||||
|
||||
// IncrementActiveWrites tracks an active write operation.
|
||||
// Call this when starting a write, and call the returned function when done.
|
||||
func (m *Neo4jMonitor) IncrementActiveWrites() func() { |
||||
m.activeWrites.Add(1) |
||||
return func() { |
||||
m.activeWrites.Add(-1) |
||||
} |
||||
} |
||||
|
||||
// GetConcurrencyStats returns current concurrency statistics for debugging.
|
||||
func (m *Neo4jMonitor) GetConcurrencyStats() (reads, writes int32, semUsed int) { |
||||
reads = m.activeReads.Load() |
||||
writes = m.activeWrites.Load() |
||||
if m.querySem != nil { |
||||
semUsed = len(m.querySem) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// CheckConnectivity performs a connectivity check to Neo4j.
|
||||
// This can be used to verify the database is responsive.
|
||||
func (m *Neo4jMonitor) CheckConnectivity(ctx context.Context) error { |
||||
if m.driver == nil { |
||||
return nil |
||||
} |
||||
return m.driver.VerifyConnectivity(ctx) |
||||
} |
||||
@ -0,0 +1,218 @@
@@ -0,0 +1,218 @@
|
||||
// Package ratelimit provides adaptive rate limiting using PID control.
|
||||
// The PID controller uses proportional, integral, and derivative terms
|
||||
// with a low-pass filter on the derivative to suppress high-frequency noise.
|
||||
package ratelimit |
||||
|
||||
import ( |
||||
"math" |
||||
"sync" |
||||
"time" |
||||
) |
||||
|
||||
// PIDController implements a PID controller with filtered derivative.
|
||||
// It is designed for rate limiting database operations based on load metrics.
|
||||
//
|
||||
// The controller computes a delay recommendation based on:
|
||||
// - Proportional (P): Immediate response to current error
|
||||
// - Integral (I): Accumulated error to eliminate steady-state offset
|
||||
// - Derivative (D): Rate of change prediction (filtered to reduce noise)
|
||||
//
|
||||
// The filtered derivative uses a low-pass filter to attenuate high-frequency
|
||||
// noise that would otherwise cause erratic control behavior.
|
||||
type PIDController struct { |
||||
// Gains
|
||||
Kp float64 // Proportional gain
|
||||
Ki float64 // Integral gain
|
||||
Kd float64 // Derivative gain
|
||||
|
||||
// Setpoint is the target process variable value (e.g., 0.85 for 85% of target memory).
|
||||
// The controller drives the process variable toward this setpoint.
|
||||
Setpoint float64 |
||||
|
||||
// DerivativeFilterAlpha is the low-pass filter coefficient for the derivative term.
|
||||
// Range: 0.0-1.0, where lower values provide stronger filtering.
|
||||
// Recommended: 0.2 for strong filtering, 0.5 for moderate filtering.
|
||||
DerivativeFilterAlpha float64 |
||||
|
||||
// Integral limits for anti-windup
|
||||
IntegralMax float64 |
||||
IntegralMin float64 |
||||
|
||||
// Output limits
|
||||
OutputMin float64 // Minimum output (typically 0 = no delay)
|
||||
OutputMax float64 // Maximum output (max delay in seconds)
|
||||
|
||||
// Internal state (protected by mutex)
|
||||
mu sync.Mutex |
||||
integral float64 |
||||
prevError float64 |
||||
prevFilteredError float64 |
||||
lastUpdate time.Time |
||||
initialized bool |
||||
} |
||||
|
||||
// DefaultPIDControllerForWrites creates a PID controller tuned for write operations.
|
||||
// Writes benefit from aggressive integral and moderate proportional response.
|
||||
func DefaultPIDControllerForWrites() *PIDController { |
||||
return &PIDController{ |
||||
Kp: 0.5, // Moderate proportional response
|
||||
Ki: 0.1, // Steady integral to eliminate offset
|
||||
Kd: 0.05, // Small derivative for prediction
|
||||
Setpoint: 0.85, // Target 85% of memory limit
|
||||
DerivativeFilterAlpha: 0.2, // Strong filtering (20% new, 80% old)
|
||||
IntegralMax: 10.0, // Anti-windup: max 10 seconds accumulated
|
||||
IntegralMin: -2.0, // Allow small negative for faster recovery
|
||||
OutputMin: 0.0, // No delay minimum
|
||||
OutputMax: 1.0, // Max 1 second delay per write
|
||||
} |
||||
} |
||||
|
||||
// DefaultPIDControllerForReads creates a PID controller tuned for read operations.
|
||||
// Reads should be more responsive but with less aggressive throttling.
|
||||
func DefaultPIDControllerForReads() *PIDController { |
||||
return &PIDController{ |
||||
Kp: 0.3, // Lower proportional (reads are more important)
|
||||
Ki: 0.05, // Lower integral (don't accumulate as aggressively)
|
||||
Kd: 0.02, // Very small derivative
|
||||
Setpoint: 0.90, // Target 90% (more tolerant of memory use)
|
||||
DerivativeFilterAlpha: 0.15, // Very strong filtering
|
||||
IntegralMax: 5.0, // Lower anti-windup limit
|
||||
IntegralMin: -1.0, // Allow small negative
|
||||
OutputMin: 0.0, // No delay minimum
|
||||
OutputMax: 0.5, // Max 500ms delay per read
|
||||
} |
||||
} |
||||
|
||||
// NewPIDController creates a new PID controller with custom parameters.
|
||||
func NewPIDController( |
||||
kp, ki, kd float64, |
||||
setpoint float64, |
||||
derivativeFilterAlpha float64, |
||||
integralMin, integralMax float64, |
||||
outputMin, outputMax float64, |
||||
) *PIDController { |
||||
return &PIDController{ |
||||
Kp: kp, |
||||
Ki: ki, |
||||
Kd: kd, |
||||
Setpoint: setpoint, |
||||
DerivativeFilterAlpha: derivativeFilterAlpha, |
||||
IntegralMin: integralMin, |
||||
IntegralMax: integralMax, |
||||
OutputMin: outputMin, |
||||
OutputMax: outputMax, |
||||
} |
||||
} |
||||
|
||||
// Update computes the PID output based on the current process variable.
|
||||
// The process variable should be in the range [0.0, 1.0+] representing load level.
|
||||
//
|
||||
// Returns the recommended delay in seconds. A value of 0 means no delay needed.
|
||||
func (p *PIDController) Update(processVariable float64) float64 { |
||||
p.mu.Lock() |
||||
defer p.mu.Unlock() |
||||
|
||||
now := time.Now() |
||||
|
||||
// Initialize on first call
|
||||
if !p.initialized { |
||||
p.lastUpdate = now |
||||
p.prevError = processVariable - p.Setpoint |
||||
p.prevFilteredError = p.prevError |
||||
p.initialized = true |
||||
return 0 // No delay on first call
|
||||
} |
||||
|
||||
// Calculate time delta
|
||||
dt := now.Sub(p.lastUpdate).Seconds() |
||||
if dt <= 0 { |
||||
dt = 0.001 // Minimum 1ms to avoid division by zero
|
||||
} |
||||
p.lastUpdate = now |
||||
|
||||
// Calculate current error (positive when above setpoint = need to throttle)
|
||||
error := processVariable - p.Setpoint |
||||
|
||||
// Proportional term: immediate response to current error
|
||||
pTerm := p.Kp * error |
||||
|
||||
// Integral term: accumulate error over time
|
||||
// Apply anti-windup by clamping the integral
|
||||
p.integral += error * dt |
||||
p.integral = clamp(p.integral, p.IntegralMin, p.IntegralMax) |
||||
iTerm := p.Ki * p.integral |
||||
|
||||
// Derivative term with low-pass filter
|
||||
// Apply exponential moving average to filter high-frequency noise:
|
||||
// filtered = alpha * new + (1 - alpha) * old
|
||||
// This is equivalent to a first-order low-pass filter
|
||||
filteredError := p.DerivativeFilterAlpha*error + (1-p.DerivativeFilterAlpha)*p.prevFilteredError |
||||
|
||||
// Derivative of the filtered error
|
||||
var dTerm float64 |
||||
if dt > 0 { |
||||
dTerm = p.Kd * (filteredError - p.prevFilteredError) / dt |
||||
} |
||||
|
||||
// Update previous values for next iteration
|
||||
p.prevError = error |
||||
p.prevFilteredError = filteredError |
||||
|
||||
// Compute total output and clamp to limits
|
||||
output := pTerm + iTerm + dTerm |
||||
output = clamp(output, p.OutputMin, p.OutputMax) |
||||
|
||||
// Only return positive delays (throttle when above setpoint)
|
||||
if output < 0 { |
||||
return 0 |
||||
} |
||||
return output |
||||
} |
||||
|
||||
// Reset clears the controller state, useful when conditions change significantly.
|
||||
func (p *PIDController) Reset() { |
||||
p.mu.Lock() |
||||
defer p.mu.Unlock() |
||||
|
||||
p.integral = 0 |
||||
p.prevError = 0 |
||||
p.prevFilteredError = 0 |
||||
p.initialized = false |
||||
} |
||||
|
||||
// SetSetpoint updates the target setpoint.
|
||||
func (p *PIDController) SetSetpoint(setpoint float64) { |
||||
p.mu.Lock() |
||||
defer p.mu.Unlock() |
||||
p.Setpoint = setpoint |
||||
} |
||||
|
||||
// SetGains updates the PID gains.
|
||||
func (p *PIDController) SetGains(kp, ki, kd float64) { |
||||
p.mu.Lock() |
||||
defer p.mu.Unlock() |
||||
p.Kp = kp |
||||
p.Ki = ki |
||||
p.Kd = kd |
||||
} |
||||
|
||||
// GetState returns the current internal state for monitoring/debugging.
|
||||
func (p *PIDController) GetState() (integral, prevError, prevFilteredError float64) { |
||||
p.mu.Lock() |
||||
defer p.mu.Unlock() |
||||
return p.integral, p.prevError, p.prevFilteredError |
||||
} |
||||
|
||||
// clamp restricts a value to the range [min, max].
|
||||
func clamp(value, min, max float64) float64 { |
||||
if math.IsNaN(value) { |
||||
return 0 |
||||
} |
||||
if value < min { |
||||
return min |
||||
} |
||||
if value > max { |
||||
return max |
||||
} |
||||
return value |
||||
} |
||||
@ -0,0 +1,176 @@
@@ -0,0 +1,176 @@
|
||||
package ratelimit |
||||
|
||||
import ( |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
func TestPIDController_BasicOperation(t *testing.T) { |
||||
pid := DefaultPIDControllerForWrites() |
||||
|
||||
// First call should return 0 (initialization)
|
||||
delay := pid.Update(0.5) |
||||
if delay != 0 { |
||||
t.Errorf("expected 0 delay on first call, got %v", delay) |
||||
} |
||||
|
||||
// Sleep a bit to ensure dt > 0
|
||||
time.Sleep(10 * time.Millisecond) |
||||
|
||||
// Process variable below setpoint (0.5 < 0.85) should return 0 delay
|
||||
delay = pid.Update(0.5) |
||||
if delay != 0 { |
||||
t.Errorf("expected 0 delay when below setpoint, got %v", delay) |
||||
} |
||||
|
||||
// Process variable above setpoint should return positive delay
|
||||
time.Sleep(10 * time.Millisecond) |
||||
delay = pid.Update(0.95) // 0.95 > 0.85 setpoint
|
||||
if delay <= 0 { |
||||
t.Errorf("expected positive delay when above setpoint, got %v", delay) |
||||
} |
||||
} |
||||
|
||||
func TestPIDController_IntegralAccumulation(t *testing.T) { |
||||
pid := NewPIDController( |
||||
0.5, 0.5, 0.0, // High Ki, no Kd
|
||||
0.5, // setpoint
|
||||
0.2, // filter alpha
|
||||
-10, 10, // integral bounds
|
||||
0, 1.0, // output bounds
|
||||
) |
||||
|
||||
// Initialize
|
||||
pid.Update(0.5) |
||||
time.Sleep(10 * time.Millisecond) |
||||
|
||||
// Continuously above setpoint should accumulate integral
|
||||
for i := 0; i < 10; i++ { |
||||
time.Sleep(10 * time.Millisecond) |
||||
pid.Update(0.8) // 0.3 above setpoint
|
||||
} |
||||
|
||||
integral, _, _ := pid.GetState() |
||||
if integral <= 0 { |
||||
t.Errorf("expected positive integral after sustained error, got %v", integral) |
||||
} |
||||
} |
||||
|
||||
func TestPIDController_FilteredDerivative(t *testing.T) { |
||||
pid := NewPIDController( |
||||
0.0, 0.0, 1.0, // Only Kd
|
||||
0.5, // setpoint
|
||||
0.5, // 50% filtering
|
||||
-10, 10, |
||||
0, 1.0, |
||||
) |
||||
|
||||
// Initialize with low value
|
||||
pid.Update(0.5) |
||||
time.Sleep(10 * time.Millisecond) |
||||
|
||||
// Second call with same value - derivative should be near zero
|
||||
pid.Update(0.5) |
||||
_, _, prevFiltered := pid.GetState() |
||||
|
||||
time.Sleep(10 * time.Millisecond) |
||||
|
||||
// Big jump - filtered derivative should be dampened
|
||||
delay := pid.Update(1.0) |
||||
|
||||
// The filtered derivative should cause some response, but dampened
|
||||
// Since we only have Kd=1.0 and alpha=0.5, the response should be modest
|
||||
if delay < 0 { |
||||
t.Errorf("expected non-negative delay, got %v", delay) |
||||
} |
||||
|
||||
_, _, newFiltered := pid.GetState() |
||||
// Filtered error should have moved toward the new error but not fully
|
||||
if newFiltered <= prevFiltered { |
||||
t.Errorf("filtered error should increase with rising process variable") |
||||
} |
||||
} |
||||
|
||||
func TestPIDController_AntiWindup(t *testing.T) { |
||||
pid := NewPIDController( |
||||
0.0, 1.0, 0.0, // Only Ki
|
||||
0.5, // setpoint
|
||||
0.2, // filter alpha
|
||||
-1.0, 1.0, // tight integral bounds
|
||||
0, 10.0, // wide output bounds
|
||||
) |
||||
|
||||
// Initialize
|
||||
pid.Update(0.5) |
||||
|
||||
// Drive the integral to its limit
|
||||
for i := 0; i < 100; i++ { |
||||
time.Sleep(1 * time.Millisecond) |
||||
pid.Update(1.0) // Large positive error
|
||||
} |
||||
|
||||
integral, _, _ := pid.GetState() |
||||
if integral > 1.0 { |
||||
t.Errorf("integral should be clamped at 1.0, got %v", integral) |
||||
} |
||||
} |
||||
|
||||
func TestPIDController_Reset(t *testing.T) { |
||||
pid := DefaultPIDControllerForWrites() |
||||
|
||||
// Build up some state
|
||||
pid.Update(0.5) |
||||
time.Sleep(10 * time.Millisecond) |
||||
pid.Update(0.9) |
||||
time.Sleep(10 * time.Millisecond) |
||||
pid.Update(0.95) |
||||
|
||||
// Reset
|
||||
pid.Reset() |
||||
|
||||
integral, prevErr, prevFiltered := pid.GetState() |
||||
if integral != 0 || prevErr != 0 || prevFiltered != 0 { |
||||
t.Errorf("expected all state to be zero after reset") |
||||
} |
||||
|
||||
// Next call should behave like first call
|
||||
delay := pid.Update(0.9) |
||||
if delay != 0 { |
||||
t.Errorf("expected 0 delay on first call after reset, got %v", delay) |
||||
} |
||||
} |
||||
|
||||
func TestPIDController_SetGains(t *testing.T) { |
||||
pid := DefaultPIDControllerForWrites() |
||||
|
||||
// Change gains
|
||||
pid.SetGains(1.0, 0.5, 0.1) |
||||
|
||||
if pid.Kp != 1.0 || pid.Ki != 0.5 || pid.Kd != 0.1 { |
||||
t.Errorf("gains not updated correctly") |
||||
} |
||||
} |
||||
|
||||
func TestPIDController_SetSetpoint(t *testing.T) { |
||||
pid := DefaultPIDControllerForWrites() |
||||
|
||||
pid.SetSetpoint(0.7) |
||||
|
||||
if pid.Setpoint != 0.7 { |
||||
t.Errorf("setpoint not updated, got %v", pid.Setpoint) |
||||
} |
||||
} |
||||
|
||||
func TestDefaultControllers(t *testing.T) { |
||||
writePID := DefaultPIDControllerForWrites() |
||||
readPID := DefaultPIDControllerForReads() |
||||
|
||||
// Write controller should have higher gains and lower setpoint
|
||||
if writePID.Kp <= readPID.Kp { |
||||
t.Errorf("write Kp should be higher than read Kp") |
||||
} |
||||
|
||||
if writePID.Setpoint >= readPID.Setpoint { |
||||
t.Errorf("write setpoint should be lower than read setpoint") |
||||
} |
||||
} |
||||
Loading…
Reference in new issue