mirror of
https://github.com/Kyren223/eko.git
synced 2025-09-05 21:18:14 +00:00
Added session based request rate limiting, uses token bucket strategy
This commit is contained in:
@@ -11,7 +11,7 @@ var RequestsProcessed = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: namespace,
|
||||
Name: "requests_processed_total",
|
||||
Help: "The total number of processed requests",
|
||||
}, []string{"request_type"})
|
||||
}, []string{"request_type", "dropped"})
|
||||
|
||||
// var RequestsInProgress = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
// Namespace: namespace,
|
||||
@@ -24,7 +24,7 @@ var RequestProcessingDuration = promauto.NewHistogramVec(prometheus.HistogramOpt
|
||||
Name: "request_processing_duration_seconds",
|
||||
Help: "The duration in seconds it took to process a request",
|
||||
NativeHistogramBucketFactor: 1.00271,
|
||||
}, []string{"request_type"})
|
||||
}, []string{"request_type", "dropped"})
|
||||
|
||||
// var RequestProcessingDuration = promauto.NewSummaryVec(prometheus.SummaryOpts{
|
||||
// Namespace: namespace,
|
||||
|
@@ -301,13 +301,21 @@ func (server *server) handleConnection(conn net.Conn) {
|
||||
|
||||
for request := range framer.Out {
|
||||
start := time.Now().UTC()
|
||||
processPacket(localCtx, sess, request)
|
||||
dropped := processPacket(localCtx, sess, request)
|
||||
duration := time.Since(start)
|
||||
|
||||
labels := prometheus.Labels{"request_type": request.Type().String()}
|
||||
labels := prometheus.Labels{
|
||||
"request_type": request.Type().String(),
|
||||
"dropped": strconv.FormatBool(dropped),
|
||||
}
|
||||
metrics.RequestsProcessed.With(labels).Inc()
|
||||
metrics.RequestProcessingDuration.With(labels).Observe(float64(duration.Seconds()))
|
||||
slog.InfoContext(ctx, "processed request", "request_type", request.Type().String(), "duration", duration.String(), "duration_ns", duration.Nanoseconds())
|
||||
if dropped {
|
||||
slog.InfoContext(ctx, "dropped request", "request_type", request.Type().String(), "duration", duration.String(), "duration_ns", duration.Nanoseconds())
|
||||
} else {
|
||||
slog.InfoContext(ctx, "processed request", "request_type", request.Type().String(), "duration", duration.String(), "duration_ns", duration.Nanoseconds())
|
||||
}
|
||||
|
||||
}
|
||||
slog.InfoContext(ctx, "processor done")
|
||||
}()
|
||||
@@ -359,7 +367,12 @@ func (server *server) handleConnection(conn net.Conn) {
|
||||
<-done
|
||||
}
|
||||
|
||||
func processPacket(ctx context.Context, sess *session.Session, pkt packet.Packet) {
|
||||
func processPacket(ctx context.Context, sess *session.Session, pkt packet.Packet) bool {
|
||||
tokens := TokensPerRequest(pkt.Type())
|
||||
if !sess.RateLimiter().Take(tokens) {
|
||||
return false // Rate limit was hit
|
||||
}
|
||||
|
||||
var response packet.Payload
|
||||
|
||||
request, err := pkt.DecodedPayload()
|
||||
@@ -374,6 +387,8 @@ func processPacket(ctx context.Context, sess *session.Session, pkt packet.Packet
|
||||
ok := sess.Write(ctx, response)
|
||||
assert.Assert(ok, "context is never done and write will panic if queue is closed")
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func processRequest(ctx context.Context, sess *session.Session, request packet.Payload) packet.Payload {
|
||||
@@ -532,3 +547,46 @@ func sendTosInfo(ctx context.Context, sess *session.Session) bool {
|
||||
}
|
||||
return sess.Write(ctx, payload)
|
||||
}
|
||||
|
||||
func TokensPerRequest(requestType packet.PacketType) float64 {
|
||||
// 1 token means 1 token per second, which is equivalent to 1ms
|
||||
// The idea is that for 1000 users, each user has 1ms of server time
|
||||
// This is the baseline but requests may take less/more
|
||||
|
||||
switch requestType {
|
||||
|
||||
case packet.PacketAcceptTos:
|
||||
return 0.15
|
||||
case packet.PacketGetNonce:
|
||||
return 0.1
|
||||
case packet.PacketAuthenticate:
|
||||
return 1.5
|
||||
case packet.PacketDeviceAnalytics:
|
||||
return 0.2 // arbitrary
|
||||
|
||||
// TODO: once I get more data for these, add them
|
||||
case packet.PacketBlockUser:
|
||||
case packet.PacketCreateFrequency:
|
||||
case packet.PacketCreateNetwork:
|
||||
case packet.PacketDeleteFrequency:
|
||||
case packet.PacketDeleteMessage:
|
||||
case packet.PacketDeleteNetwork:
|
||||
case packet.PacketEditMessage:
|
||||
case packet.PacketGetBannedMembers:
|
||||
case packet.PacketGetUserData:
|
||||
case packet.PacketGetUsers:
|
||||
case packet.PacketRequestMessages:
|
||||
case packet.PacketSendMessage:
|
||||
case packet.PacketSetLastReadMessages:
|
||||
case packet.PacketSetMember:
|
||||
case packet.PacketSetUserData:
|
||||
case packet.PacketSwapFrequencies:
|
||||
case packet.PacketTransferNetwork:
|
||||
case packet.PacketTrustUser:
|
||||
case packet.PacketUpdateFrequency:
|
||||
case packet.PacketUpdateNetwork:
|
||||
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
@@ -12,12 +12,18 @@ import (
|
||||
"github.com/kyren223/eko/internal/packet"
|
||||
"github.com/kyren223/eko/internal/server/ctxkeys"
|
||||
"github.com/kyren223/eko/pkg/assert"
|
||||
"github.com/kyren223/eko/pkg/rate"
|
||||
"github.com/kyren223/eko/pkg/snowflake"
|
||||
)
|
||||
|
||||
const (
|
||||
WriteQueueSize = 10
|
||||
NonceSize = 32
|
||||
|
||||
DefaultRate = 0.1 // ms per second
|
||||
DefaultLimit = 3 // ms burst
|
||||
AuthenticatedRate = 1 // ms per second
|
||||
AuthenticatedLimit = 20 // ms burst
|
||||
)
|
||||
|
||||
type SessionManager interface {
|
||||
@@ -45,7 +51,9 @@ type Session struct {
|
||||
isTosAccepted bool
|
||||
pubKey ed25519.PublicKey
|
||||
id snowflake.ID
|
||||
mu sync.RWMutex
|
||||
rl rate.Limiter
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewSession(
|
||||
@@ -68,6 +76,7 @@ func NewSession(
|
||||
id: snowflake.InvalidID,
|
||||
challengeMu: sync.Mutex{},
|
||||
isTosAccepted: false,
|
||||
rl: rate.NewLimiter(DefaultRate, DefaultLimit),
|
||||
mu: sync.RWMutex{},
|
||||
}
|
||||
return session
|
||||
@@ -78,6 +87,10 @@ func (s *Session) Addr() *net.TCPAddr {
|
||||
return s.addr
|
||||
}
|
||||
|
||||
func (s *Session) RateLimiter() *rate.Limiter {
|
||||
return &s.rl
|
||||
}
|
||||
|
||||
func (s *Session) IsTosAccepted() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
@@ -115,6 +128,8 @@ func (s *Session) Promote(userId snowflake.ID, pubKey ed25519.PublicKey) {
|
||||
defer s.mu.Unlock()
|
||||
s.id = userId
|
||||
s.pubKey = pubKey
|
||||
s.rl.SetLimit(AuthenticatedLimit)
|
||||
s.rl.SetRate(AuthenticatedRate)
|
||||
}
|
||||
|
||||
func (s *Session) Manager() SessionManager {
|
||||
|
58
pkg/rate/rate.go
Normal file
58
pkg/rate/rate.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package rate
|
||||
|
||||
import "time"
|
||||
|
||||
type Limiter struct {
|
||||
limit float64
|
||||
rate float64
|
||||
|
||||
lastRefill time.Time
|
||||
tokens float64
|
||||
}
|
||||
|
||||
// rate refills limiter rate tokens per second
|
||||
func NewLimiter(rate float64, limit float64) Limiter {
|
||||
return Limiter{
|
||||
limit: limit,
|
||||
rate: rate,
|
||||
lastRefill: time.Now().UTC(),
|
||||
tokens: limit,
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *Limiter) Fill() {
|
||||
rl.update()
|
||||
rl.tokens = rl.limit
|
||||
}
|
||||
|
||||
func (rl *Limiter) SetRate(rate float64) {
|
||||
rl.update()
|
||||
rl.rate = rate
|
||||
}
|
||||
|
||||
func (rl *Limiter) SetLimit(limit float64) {
|
||||
rl.update()
|
||||
rl.limit = limit
|
||||
}
|
||||
|
||||
func (rl *Limiter) Take(tokens float64) bool {
|
||||
rl.update()
|
||||
has := rl.Has(tokens)
|
||||
if has {
|
||||
rl.tokens -= tokens
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (rl *Limiter) Has(tokens float64) bool {
|
||||
rl.update()
|
||||
return rl.tokens >= tokens
|
||||
}
|
||||
|
||||
func (rl *Limiter) update() {
|
||||
lastRefill := rl.lastRefill
|
||||
rl.lastRefill = time.Now().UTC()
|
||||
|
||||
rl.tokens += min(time.Since(lastRefill).Seconds()*rl.rate, rl.limit)
|
||||
}
|
Reference in New Issue
Block a user