Files
eko/internal/server/api/helpers.go
2025-02-02 19:23:50 +02:00

207 lines
5.7 KiB
Go

package api
import (
"context"
"log"
"strings"
"time"
"github.com/kyren223/eko/internal/data"
"github.com/kyren223/eko/internal/packet"
"github.com/kyren223/eko/internal/server/session"
"github.com/kyren223/eko/pkg/assert"
"github.com/kyren223/eko/pkg/snowflake"
)
const hex = "0123456789abcdefABCDEF"
func isValidHexColor(color string) (bool, string) {
if len(color) != 7 {
return false, "color must be hex with length of 7"
}
if color[0] != '#' {
return false, "color must start with '#'"
}
for _, c := range color[1:] {
if !strings.ContainsRune(hex, c) {
return false, "color must start with '#' and contain exactly 6 digits 0-9, a-f, A-F"
}
}
return true, ""
}
func IsNetworkAdmin(ctx context.Context, queries *data.Queries, userId, networkId snowflake.ID) (bool, error) {
userNetwork, err := queries.GetMemberById(ctx, data.GetMemberByIdParams{
NetworkID: networkId,
UserID: userId,
})
if err != nil {
return false, err
}
isAdmin := userNetwork.IsAdmin && userNetwork.IsMember && !userNetwork.IsBanned
return isAdmin, nil
}
func NetworkPropagateWithFilter(
ctx context.Context, sess *session.Session,
network snowflake.ID, payload packet.Payload,
filter func(userId snowflake.ID) (pass bool),
) packet.Payload {
var sessions []snowflake.ID
sess.Manager().UseSessions(func(s map[snowflake.ID]*session.Session) {
sessions = make([]snowflake.ID, 0, len(s)-1)
for key := range s {
if key != sess.ID() && filter(key) {
sessions = append(sessions, key)
}
}
})
queries := data.New(db)
sessions, err := queries.FilterUsersInNetwork(ctx, data.FilterUsersInNetworkParams{
NetworkID: network,
Users: sessions,
})
if err != nil {
log.Println("database error in propagate:", err)
return &ErrInternalError
}
for _, sessionId := range sessions {
session := sess.Manager().Session(sessionId)
if session == nil {
continue
}
timeout := 1 * time.Second
context, cancel := context.WithTimeout(context.Background(), timeout)
go func() {
defer cancel()
pkt := packet.NewPacket(packet.NewMsgPackEncoder(payload))
if ok := session.Write(context, pkt); !ok {
log.Println(sess.Addr(), "propagation to", session.Addr(), "failed")
}
}()
}
return payload
}
func NetworkPropagate(
ctx context.Context, sess *session.Session,
network snowflake.ID, payload packet.Payload,
) packet.Payload {
return NetworkPropagateWithFilter(ctx, sess, network, payload, func(userId snowflake.ID) bool {
return true
})
}
func SplitMembersAndUsers(membersAndUsers []data.GetNetworkMembersRow) ([]data.Member, []data.User) {
members := make([]data.Member, 0, len(membersAndUsers))
users := make([]data.User, 0, len(membersAndUsers))
for _, memberAndUser := range membersAndUsers {
members = append(members, memberAndUser.Member)
users = append(users, memberAndUser.User)
}
return members, users
}
func UserPropagate(
ctx context.Context, sess *session.Session,
user snowflake.ID, payload packet.Payload,
) packet.Payload {
session := sess.Manager().Session(user)
if session == nil {
log.Println(sess.Addr(), "propagation to user", user, "failed due to session being nil")
return payload
}
timeout := 1 * time.Second
context, cancel := context.WithTimeout(context.Background(), timeout)
go func() {
defer cancel()
pkt := packet.NewPacket(packet.NewMsgPackEncoder(payload))
if ok := session.Write(context, pkt); !ok {
log.Println(sess.Addr(), "propagation to", session.Addr(), "failed")
}
}()
return payload
}
const getNotificationsQuery = `-- name: GetNotifications :many
WITH
entries(source, lastId) AS (
VALUES /*SLICE:pair*/?
),
permitted_frequencies AS (
SELECT f.id, m.is_admin
FROM frequencies f
JOIN entries e ON f.id = e.source
LEFT JOIN members m
ON m.user_id = ?
AND m.network_id = f.network_id
WHERE m.is_member = true AND (f.perms != 0 OR m.is_admin = true)
)
SELECT
e.source,
CASE
WHEN COUNT(m.id) = 0 THEN NULL
ELSE SUM(CASE WHEN (m.ping = 0 OR (m.ping = 1 AND pf.is_admin = true) OR m.ping = ?) THEN 1 ELSE 0 END)
-- 0 is @everyone, 1 is @admins, otherwise it's user_id
END AS pings
FROM entries e
LEFT JOIN messages m ON m.id > e.lastId
AND (m.frequency_id = e.source OR
(m.receiver_id = e.source AND m.sender_id = ?) OR
(m.sender_id = e.source AND m.receiver_id = ?))
JOIN permitted_frequencies pf ON e.source = pf.id
GROUP BY e.source, e.lastId;
`
func getNotifications(ctx context.Context, arg *packet.GetNotifications, ping snowflake.ID) (packet.NotificationsInfo, error) {
assert.Assert(
len(arg.Source) == len(arg.LastReadId),
"Source and LastReadID must match",
"source", arg.Source, "last read id", arg.LastReadId,
)
query := getNotificationsQuery
var queryParams []interface{}
if len(arg.Source) > 0 {
for i := 0; i < len(arg.Source); i++ {
queryParams = append(queryParams, arg.Source[i])
queryParams = append(queryParams, arg.LastReadId[i])
}
query = strings.Replace(query, "/*SLICE:pair*/?", strings.Repeat(",(?,?)", len(arg.Source))[1:], 1)
} else {
assert.Never("source and id must not be empty")
}
queryParams = append(queryParams, ping, ping, ping, ping)
rows, err := db.QueryContext(ctx, query, queryParams...)
if err != nil {
return packet.NotificationsInfo{}, err
}
defer rows.Close()
var items packet.NotificationsInfo
for rows.Next() {
var source *snowflake.ID
var pings *int64
if err := rows.Scan(&source, &pings); err != nil {
return packet.NotificationsInfo{}, err
}
items.Source = append(items.Source, *source)
items.Pings = append(items.Pings, pings)
}
if err := rows.Close(); err != nil {
return packet.NotificationsInfo{}, err
}
if err := rows.Err(); err != nil {
return packet.NotificationsInfo{}, err
}
return items, nil
}