Added session-based request rate limiting, using a token bucket strategy

2025-07-14 12:39:46 +03:00
parent d78eb1fbc4
commit 0a0915a977
4 changed files with 138 additions and 7 deletions
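
In short, each incoming request now has to pay a type-specific token cost out of its session's bucket before it is processed; if the bucket cannot cover the cost, the request is dropped, and the drop shows up in the metrics through the new dropped label. A minimal sketch of that flow (the handle wrapper is illustrative only; the real wiring is in processPacket and session.Session in the diffs below):

func handle(sess *session.Session, pkt packet.Packet) bool {
	cost := TokensPerRequest(pkt.Type()) // ~1 token per ms of expected server time
	if !sess.RateLimiter().Take(cost) {
		return false // bucket can't cover the cost: drop the request
	}
	// decode, process and respond as before
	return true
}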

View File

@@ -11,7 +11,7 @@ var RequestsProcessed = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: namespace,
Name: "requests_processed_total",
Help: "The total number of processed requests",
}, []string{"request_type"})
}, []string{"request_type", "dropped"})
// var RequestsInProgress = promauto.NewCounterVec(prometheus.CounterOpts{
// Namespace: namespace,
@@ -24,7 +24,7 @@ var RequestProcessingDuration = promauto.NewHistogramVec(prometheus.HistogramOpt
Name: "request_processing_duration_seconds",
Help: "The duration in seconds it took to process a request",
NativeHistogramBucketFactor: 1.00271,
}, []string{"request_type"})
}, []string{"request_type", "dropped"})
// var RequestProcessingDuration = promauto.NewSummaryVec(prometheus.SummaryOpts{
// Namespace: namespace,

View File

@@ -301,13 +301,21 @@ func (server *server) handleConnection(conn net.Conn) {
for request := range framer.Out {
start := time.Now().UTC()
processPacket(localCtx, sess, request)
dropped := processPacket(localCtx, sess, request)
duration := time.Since(start)
labels := prometheus.Labels{"request_type": request.Type().String()}
labels := prometheus.Labels{
"request_type": request.Type().String(),
"dropped": strconv.FormatBool(dropped),
}
metrics.RequestsProcessed.With(labels).Inc()
metrics.RequestProcessingDuration.With(labels).Observe(float64(duration.Seconds()))
slog.InfoContext(ctx, "processed request", "request_type", request.Type().String(), "duration", duration.String(), "duration_ns", duration.Nanoseconds())
if dropped {
slog.InfoContext(ctx, "dropped request", "request_type", request.Type().String(), "duration", duration.String(), "duration_ns", duration.Nanoseconds())
} else {
slog.InfoContext(ctx, "processed request", "request_type", request.Type().String(), "duration", duration.String(), "duration_ns", duration.Nanoseconds())
}
}
slog.InfoContext(ctx, "processor done")
}()
@@ -359,7 +367,12 @@ func (server *server) handleConnection(conn net.Conn) {
<-done
}
func processPacket(ctx context.Context, sess *session.Session, pkt packet.Packet) {
func processPacket(ctx context.Context, sess *session.Session, pkt packet.Packet) bool {
tokens := TokensPerRequest(pkt.Type())
if !sess.RateLimiter().Take(tokens) {
return false // Rate limit was hit
}
var response packet.Payload
request, err := pkt.DecodedPayload()
@@ -374,6 +387,8 @@ func processPacket(ctx context.Context, sess *session.Session, pkt packet.Packet
ok := sess.Write(ctx, response)
assert.Assert(ok, "context is never done and write will panic if queue is closed")
}
return true
}
func processRequest(ctx context.Context, sess *session.Session, request packet.Payload) packet.Payload {
@@ -532,3 +547,46 @@ func sendTosInfo(ctx context.Context, sess *session.Session) bool {
}
return sess.Write(ctx, payload)
}
func TokensPerRequest(requestType packet.PacketType) float64 {
// 1 token corresponds to roughly 1ms of server time, so a rate of
// 1 token per second budgets about 1ms of server time per second.
// The idea is that with 1000 users, each user gets ~1ms of server time.
// This is the baseline; individual requests may cost less or more.
switch requestType {
case packet.PacketAcceptTos:
return 0.15
case packet.PacketGetNonce:
return 0.1
case packet.PacketAuthenticate:
return 1.5
case packet.PacketDeviceAnalytics:
return 0.2 // arbitrary
// TODO: once I get more data for these, add them
case packet.PacketBlockUser:
case packet.PacketCreateFrequency:
case packet.PacketCreateNetwork:
case packet.PacketDeleteFrequency:
case packet.PacketDeleteMessage:
case packet.PacketDeleteNetwork:
case packet.PacketEditMessage:
case packet.PacketGetBannedMembers:
case packet.PacketGetUserData:
case packet.PacketGetUsers:
case packet.PacketRequestMessages:
case packet.PacketSendMessage:
case packet.PacketSetLastReadMessages:
case packet.PacketSetMember:
case packet.PacketSetUserData:
case packet.PacketSwapFrequencies:
case packet.PacketTransferNetwork:
case packet.PacketTrustUser:
case packet.PacketUpdateFrequency:
case packet.PacketUpdateNetwork:
}
return 1
}
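
For scale, these costs read against the 1 token ≈ 1ms baseline above: Authenticate at 1.5 tokens is budgeted as roughly 1.5ms of server time per call, GetNonce at 0.1 tokens as about 0.1ms, and so on. The packet types listed under the TODO have empty cases, and Go cases do not fall through, so they currently end at the trailing return 1 and cost the default of 1 token each.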

View File

@@ -12,12 +12,18 @@ import (
"github.com/kyren223/eko/internal/packet"
"github.com/kyren223/eko/internal/server/ctxkeys"
"github.com/kyren223/eko/pkg/assert"
"github.com/kyren223/eko/pkg/rate"
"github.com/kyren223/eko/pkg/snowflake"
)
const (
WriteQueueSize = 10
NonceSize = 32
DefaultRate = 0.1 // ms per second
DefaultLimit = 3 // ms burst
AuthenticatedRate = 1 // ms per second
AuthenticatedLimit = 20 // ms burst
)
type SessionManager interface {
@@ -45,7 +51,9 @@ type Session struct {
isTosAccepted bool
pubKey ed25519.PublicKey
id snowflake.ID
mu sync.RWMutex
rl rate.Limiter
mu sync.RWMutex
}
func NewSession(
@@ -68,6 +76,7 @@ func NewSession(
id: snowflake.InvalidID,
challengeMu: sync.Mutex{},
isTosAccepted: false,
rl: rate.NewLimiter(DefaultRate, DefaultLimit),
mu: sync.RWMutex{},
}
return session
@@ -78,6 +87,10 @@ func (s *Session) Addr() *net.TCPAddr {
return s.addr
}
func (s *Session) RateLimiter() *rate.Limiter {
return &s.rl
}
func (s *Session) IsTosAccepted() bool {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -115,6 +128,8 @@ func (s *Session) Promote(userId snowflake.ID, pubKey ed25519.PublicKey) {
defer s.mu.Unlock()
s.id = userId
s.pubKey = pubKey
s.rl.SetLimit(AuthenticatedLimit)
s.rl.SetRate(AuthenticatedRate)
}
func (s *Session) Manager() SessionManager {
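
Combined with the costs above, the new constants give an unauthenticated session 0.1 tokens per second with a burst of 3: enough for two back-to-back Authenticate requests (2 × 1.5 = 3 tokens) and then roughly one more every 15 seconds, or about thirty GetNonce requests up front. After Promote switches the limiter to 1 token per second with a burst of 20, the session can burst around 20 default-cost requests and then sustain roughly one per second.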

pkg/rate/rate.go (new file, 58 lines)
View File

@@ -0,0 +1,58 @@
package rate
import "time"
type Limiter struct {
limit float64
rate float64
lastRefill time.Time
tokens float64
}
// NewLimiter returns a Limiter that refills rate tokens per second, up to limit tokens.
func NewLimiter(rate float64, limit float64) Limiter {
return Limiter{
limit: limit,
rate: rate,
lastRefill: time.Now().UTC(),
tokens: limit,
}
}
func (rl *Limiter) Fill() {
rl.update()
rl.tokens = rl.limit
}
func (rl *Limiter) SetRate(rate float64) {
rl.update()
rl.rate = rate
}
func (rl *Limiter) SetLimit(limit float64) {
rl.update()
rl.limit = limit
}
func (rl *Limiter) Take(tokens float64) bool {
rl.update()
has := rl.Has(tokens)
if has {
rl.tokens -= tokens
return true
}
return false
}
func (rl *Limiter) Has(tokens float64) bool {
rl.update()
return rl.tokens >= tokens
}
func (rl *Limiter) update() {
now := time.Now().UTC()
elapsed := now.Sub(rl.lastRefill).Seconds()
rl.lastRefill = now
// Refill for the elapsed time, but never above the bucket's limit.
rl.tokens = min(rl.tokens+elapsed*rl.rate, rl.limit)
}