Added create network to server api

This commit is contained in:
2024-11-27 17:34:28 +02:00
parent 6a34d7be01
commit be22bb89db
7 changed files with 280 additions and 70 deletions

View File

@@ -64,7 +64,7 @@ func (q *Queries) DeleteNetwork(ctx context.Context, id snowflake.ID) error {
return err
}
const getBannedUsersInNetwork = `-- name: GetBannedUsersInNetwork :many
const getNetworkBannedUsers = `-- name: GetNetworkBannedUsers :many
SELECT
users.id, users.name, users.public_key, users.description, users.is_public_dm, users.is_deleted,
users_networks.ban_reason
@@ -73,20 +73,20 @@ JOIN users ON users.id = users_networks.user_id
WHERE users_networks.network_id = ?
`
type GetBannedUsersInNetworkRow struct {
type GetNetworkBannedUsersRow struct {
User User
BanReason *string
}
func (q *Queries) GetBannedUsersInNetwork(ctx context.Context, networkID snowflake.ID) ([]GetBannedUsersInNetworkRow, error) {
rows, err := q.db.QueryContext(ctx, getBannedUsersInNetwork, networkID)
func (q *Queries) GetNetworkBannedUsers(ctx context.Context, networkID snowflake.ID) ([]GetNetworkBannedUsersRow, error) {
rows, err := q.db.QueryContext(ctx, getNetworkBannedUsers, networkID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetBannedUsersInNetworkRow
var items []GetNetworkBannedUsersRow
for rows.Next() {
var i GetBannedUsersInNetworkRow
var i GetNetworkBannedUsersRow
if err := rows.Scan(
&i.User.ID,
&i.User.Name,
@@ -129,6 +129,57 @@ func (q *Queries) GetNetworkById(ctx context.Context, id snowflake.ID) (Network,
return i, err
}
const getNetworkMembers = `-- name: GetNetworkMembers :many
SELECT
users.id, users.name, users.public_key, users.description, users.is_public_dm, users.is_deleted,
users_networks.joined_at,
users_networks.is_admin,
users_networks.is_muted
FROM users_networks
JOIN users ON users.id = users_networks.user_id
WHERE users_networks.network_id = ? AND is_member = true
`
type GetNetworkMembersRow struct {
User User
JoinedAt string
IsAdmin bool
IsMuted bool
}
func (q *Queries) GetNetworkMembers(ctx context.Context, networkID snowflake.ID) ([]GetNetworkMembersRow, error) {
rows, err := q.db.QueryContext(ctx, getNetworkMembers, networkID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetNetworkMembersRow
for rows.Next() {
var i GetNetworkMembersRow
if err := rows.Scan(
&i.User.ID,
&i.User.Name,
&i.User.PublicKey,
&i.User.Description,
&i.User.IsPublicDM,
&i.User.IsDeleted,
&i.JoinedAt,
&i.IsAdmin,
&i.IsMuted,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getPublicNetworks = `-- name: GetPublicNetworks :many
SELECT id, owner_id, name, icon, bg_hex_color, fg_hex_color, is_public FROM networks
WHERE is_public = true
@@ -165,57 +216,6 @@ func (q *Queries) GetPublicNetworks(ctx context.Context) ([]Network, error) {
return items, nil
}
const getUsersInNetwork = `-- name: GetUsersInNetwork :many
SELECT
users.id, users.name, users.public_key, users.description, users.is_public_dm, users.is_deleted,
users_networks.joined_at,
users_networks.is_admin,
users_networks.is_muted
FROM users_networks
JOIN users ON users.id = users_networks.user_id
WHERE users_networks.network_id = ? AND is_member = true
`
type GetUsersInNetworkRow struct {
User User
JoinedAt string
IsAdmin bool
IsMuted bool
}
func (q *Queries) GetUsersInNetwork(ctx context.Context, networkID snowflake.ID) ([]GetUsersInNetworkRow, error) {
rows, err := q.db.QueryContext(ctx, getUsersInNetwork, networkID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetUsersInNetworkRow
for rows.Next() {
var i GetUsersInNetworkRow
if err := rows.Scan(
&i.User.ID,
&i.User.Name,
&i.User.PublicKey,
&i.User.Description,
&i.User.IsPublicDM,
&i.User.IsDeleted,
&i.JoinedAt,
&i.IsAdmin,
&i.IsMuted,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const setNetworkIcon = `-- name: SetNetworkIcon :one
UPDATE networks SET
icon = ?,
@@ -306,6 +306,58 @@ func (q *Queries) SetNetworkName(ctx context.Context, arg SetNetworkNameParams)
return i, err
}
const setNetworkUser = `-- name: SetNetworkUser :one
INSERT INTO users_networks (
user_id, network_id,
is_member, is_admin, is_muted,
is_banned, ban_reason
) VALUES (
?, ?,
?, ?, ?,
?, ?
)
ON CONFLICT DO
UPDATE SET
is_member = ?, is_admin = ?, is_muted = ?,
is_banned = ?, ban_reason = ?
WHERE user_id = ? AND network_id = ?
RETURNING user_id, network_id, joined_at, is_member, is_admin, is_muted, is_banned, ban_reason
`
type SetNetworkUserParams struct {
UserID snowflake.ID
NetworkID snowflake.ID
IsMember bool
IsAdmin bool
IsMuted bool
IsBanned bool
BanReason *string
}
func (q *Queries) SetNetworkUser(ctx context.Context, arg SetNetworkUserParams) (UsersNetwork, error) {
row := q.db.QueryRowContext(ctx, setNetworkUser,
arg.UserID,
arg.NetworkID,
arg.IsMember,
arg.IsAdmin,
arg.IsMuted,
arg.IsBanned,
arg.BanReason,
)
var i UsersNetwork
err := row.Scan(
&i.UserID,
&i.NetworkID,
&i.JoinedAt,
&i.IsMember,
&i.IsAdmin,
&i.IsMuted,
&i.IsBanned,
&i.BanReason,
)
return i, err
}
const transferNetwork = `-- name: TransferNetwork :one
UPDATE networks SET
owner_id = ?

View File

@@ -26,8 +26,8 @@ func (m *CreateNetwork) Type() PacketType {
}
type UpdateNetwork struct {
Network snowflake.ID
CreateNetwork
Network snowflake.ID
}
func (m *UpdateNetwork) Type() PacketType {
@@ -52,25 +52,34 @@ func (m *DeleteNetwork) Type() PacketType {
}
type SetNetworkUser struct {
Network snowflake.ID
User snowflake.ID
Member *bool
Admin *bool
Muted *bool
Banned *bool
BanReason *string
Network snowflake.ID
User snowflake.ID
}
func (m *SetNetworkUser) Type() PacketType {
return PacketSetNetworkUser
}
type Member struct {
JoinedAt string
User data.User
IsAdmin bool
IsMuted bool
}
type FullNetwork struct {
Network data.Network
Frequencies []data.Frequency
Members []Member
}
type NetworksInfo struct {
Networks []struct {
Network data.Network
Frequencies []data.Frequency
Members []data.User
}
Networks []FullNetwork
}
func (m *NetworksInfo) Type() PacketType {
@@ -78,9 +87,9 @@ func (m *NetworksInfo) Type() PacketType {
}
type CreateFrequency struct {
Network snowflake.ID
Name string
HexColor string
Network snowflake.ID
Perms int
}
@@ -89,9 +98,9 @@ func (m *CreateFrequency) Type() PacketType {
}
type UpdateFrequency struct {
Frequency snowflake.ID
Name string
HexColor string
Frequency snowflake.ID
Perms byte
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/ed25519"
"database/sql"
"fmt"
"log"
"strconv"
"strings"
@@ -14,6 +15,8 @@ import (
"github.com/kyren223/eko/pkg/snowflake"
)
var internalError = packet.Error{Error: "internal server error"}
func SendMessage(ctx context.Context, sess *session.Session, request *packet.SendMessage) packet.Payload {
if (request.ReceiverID != nil) == (request.FrequencyID != nil) {
return &packet.Error{Error: "either receiver id or frequency id must exist"}
@@ -34,7 +37,7 @@ func SendMessage(ctx context.Context, sess *session.Session, request *packet.Sen
})
if err != nil {
log.Println(sess.Addr(), "database error:", err, "in SendMessage")
return &packet.Error{Error: "internal server error"}
return &internalError
}
return &packet.MessagesInfo{Messages: []data.Message{message}}
@@ -79,3 +82,91 @@ func CreateOrGetUser(ctx context.Context, node *snowflake.Node, pubKey ed25519.P
}
return user, nil
}
func CreateNetwork(ctx context.Context, sess *session.Session, request packet.CreateNetwork) packet.Payload {
name := strings.TrimSpace(request.Name)
if name == "" {
return &packet.Error{Error: "server name must not be blank"}
}
if len(request.Icon) > MaxIconSize {
return &packet.Error{Error: fmt.Sprintf("icon is too large, must be smaller than %v bytes", MaxIconSize)}
}
if ok, err := isValidHexColor(request.BgHexColor); !ok {
return &packet.Error{Error: err}
}
if ok, err := isValidHexColor(request.FgHexColor); !ok {
return &packet.Error{Error: err}
}
tx, err := db.BeginTx(ctx, nil)
if err != nil {
log.Println("database error:", err)
return &internalError
}
queries := data.New(db)
qtx := queries.WithTx(tx)
network, err := qtx.CreateNetwork(ctx, data.CreateNetworkParams{
ID: sess.Manager().Node().Generate(),
OwnerID: sess.ID(),
Name: name,
IsPublic: request.IsPublic,
Icon: request.Icon,
BgHexColor: request.BgHexColor,
FgHexColor: request.FgHexColor,
})
if err != nil {
log.Println("database error:", err)
return &internalError
}
frequency, err := qtx.CreateFrequency(ctx, data.CreateFrequencyParams{
ID: sess.Manager().Node().Generate(),
NetworkID: network.ID,
Name: DefaultFrequencyName,
HexColor: nil,
Perms: PermReadWrite,
})
if err != nil {
log.Println("database error:", err)
return &internalError
}
networkUser, err := qtx.SetNetworkUser(ctx, data.SetNetworkUserParams{
UserID: network.OwnerID,
NetworkID: network.ID,
IsMember: true,
IsAdmin: true,
IsMuted: false,
IsBanned: false,
BanReason: nil,
})
if err != nil {
log.Println("database error:", err)
return &internalError
}
user, err := qtx.GetUserById(ctx, network.OwnerID)
if err != nil {
log.Println("database error:", err)
return &internalError
}
fullNetwork := packet.FullNetwork{
Network: network,
Frequencies: []data.Frequency{frequency},
Members: []packet.Member{{
JoinedAt: networkUser.JoinedAt,
User: user,
IsAdmin: networkUser.IsAdmin,
IsMuted: networkUser.IsMuted,
}},
}
return &packet.NetworksInfo{
Networks: []packet.FullNetwork{fullNetwork},
}
}

View File

@@ -0,0 +1,23 @@
package api
import "strings"
const hex = "0123456789abcdefABCDEF"
func isValidHexColor(color string) (bool, string) {
if len(color) != 7 {
return false, "color must be hex with length of 7"
}
if color[0] != '#' {
return false, "color must start with '#'"
}
for _, c := range color {
if !strings.ContainsRune(hex, c) {
return false, "color must start with '#' and contain exactly 6 digits 0-9, a-f, A-F"
}
}
return true, ""
}

View File

@@ -0,0 +1,15 @@
package api
const (
MaxIconSize = 16
DefaultFrequencyName = "main"
)
const (
PermNoAccess = 0 + iota
PermRead
PermReadWrite
PermMax
)

View File

@@ -289,6 +289,8 @@ func processRequest(ctx context.Context, sess *session.Session, request packet.P
// TODO: add a way to measure the time each request/response took and log it
// Potentially even separate time for code vs DB operations
switch request := request.(type) {
case *packet.CreateNetwork:
return timeout(20*time.Millisecond, api.CreateNetwork, ctx, sess, request)
case *packet.SendMessage:
return timeout(20*time.Millisecond, api.SendMessage, ctx, sess, request)
case *packet.RequestMessages:
@@ -303,6 +305,7 @@ func timeout[T packet.Payload](
apiRequest func(context.Context, *session.Session, T) packet.Payload,
ctx context.Context, sess *session.Session, request T,
) packet.Payload {
// TODO: Remove the channel and just wait directly?
responseChan := make(chan packet.Payload)
ctx, cancel := context.WithTimeout(ctx, timeoutDuration)
defer cancel()

View File

@@ -16,7 +16,7 @@ INSERT INTO networks (
)
RETURNING *;
-- name: GetBannedUsersInNetwork :many
-- name: GetNetworkBannedUsers :many
SELECT
sqlc.embed(users),
users_networks.ban_reason
@@ -24,7 +24,7 @@ FROM users_networks
JOIN users ON users.id = users_networks.user_id
WHERE users_networks.network_id = ?;
-- name: GetUsersInNetwork :many
-- name: GetNetworkMembers :many
SELECT
sqlc.embed(users),
users_networks.joined_at,
@@ -62,3 +62,20 @@ RETURNING *;
-- name: DeleteNetwork :exec
DELETE FROM networks WHERE id = ?;
-- name: SetNetworkUser :one
INSERT INTO users_networks (
user_id, network_id,
is_member, is_admin, is_muted,
is_banned, ban_reason
) VALUES (
?, ?,
?, ?, ?,
?, ?
)
ON CONFLICT DO
UPDATE SET
is_member = ?, is_admin = ?, is_muted = ?,
is_banned = ?, ban_reason = ?
WHERE user_id = ? AND network_id = ?
RETURNING *;