Implemented joining functionality

This commit is contained in:
2025-01-04 18:35:44 +02:00
parent feb8562dd0
commit 614cc4d9da
8 changed files with 212 additions and 7 deletions

View File

@@ -46,7 +46,6 @@ const (
type Model struct {
name string
privKey ed25519.PrivateKey
id snowflake.ID
loading loadscreen.Model
timer timer.Model
@@ -66,7 +65,6 @@ func New(privKey ed25519.PrivateKey, name string) Model {
m := Model{
name: name,
privKey: privKey,
id: 0,
loading: loadscreen.New(connectingToServer),
timer: newTimer(initialTimeout),
timeout: initialTimeout,
@@ -134,7 +132,7 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
func (m *Model) updateNotConnected(msg tea.Msg) tea.Cmd {
switch msg := msg.(type) {
case gateway.ConnectionEstablished:
m.id = snowflake.ID(msg)
state.State.UserID = (*snowflake.ID)(&msg)
m.connected = true
m.timeout = initialTimeout
return m.timer.Stop()
@@ -176,6 +174,7 @@ func (m *Model) updateConnected(msg tea.Msg) tea.Cmd {
gateway.Disconnect()
case gateway.ConnectionLost:
state.State.UserID = nil
m.connected = false
m.timeout = initialTimeout
return tea.Batch(gateway.Connect(m.privKey, connectionTimeout), m.loading.Init())
@@ -185,7 +184,24 @@ func (m *Model) updateConnected(msg tea.Msg) tea.Cmd {
state.State.Networks = msg.Networks
} else {
networks := state.State.Networks
networks = append(networks, msg.Networks...)
for _, newNetwork := range msg.Networks {
add := true
for i, existingNetwork := range networks {
if existingNetwork.ID == newNetwork.ID {
add = false
if newNetwork.Position == -1 {
newNetwork.Position = existingNetwork.Position
}
networks[i] = newNetwork
break
}
}
if add {
networks = append(networks, newNetwork)
}
}
networks = slices.DeleteFunc(networks, func(network packet.FullNetwork) bool {
return slices.Contains(msg.RemoveNetworks, network.ID)
})

View File

@@ -2,16 +2,19 @@ package networkjoin
import (
"errors"
"strconv"
"strings"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/kyren223/eko/internal/client/gateway"
"github.com/kyren223/eko/internal/client/ui/colors"
"github.com/kyren223/eko/internal/client/ui/core/state"
"github.com/kyren223/eko/internal/client/ui/field"
"github.com/kyren223/eko/internal/client/ui/layouts/flex"
"github.com/kyren223/eko/internal/packet"
"github.com/kyren223/eko/pkg/assert"
"github.com/kyren223/eko/pkg/snowflake"
)
var (
@@ -70,6 +73,9 @@ func New() Model {
if strings.TrimSpace(s) == "" {
return errors.New("cannot be empty")
}
if _, err := strconv.ParseInt(s, 10, 64); err != nil {
return err
}
return nil
}
@@ -154,10 +160,19 @@ func (m *Model) Select() tea.Cmd {
if m.name.Input.Err != nil {
return nil
}
name := m.name.Input.Value()
id, err := strconv.ParseInt(name, 10, 64)
assert.NoError(err, "input is already validated to be valid")
// TODO: change this to actual invite packet
request := packet.CreateNetwork{
Name: m.name.Input.Value(),
yes := true
request := packet.SetNetworkUser{
Member: &yes,
Admin: nil,
Muted: nil,
Banned: nil,
BanReason: nil,
Network: snowflake.ID(id),
User: *state.State.UserID,
}
return gateway.Send(&request)
}

View File

@@ -14,6 +14,8 @@ type Frequency struct {
}
type state struct {
UserID *snowflake.ID
// Key is either a frequency or receiver
FrequencyState map[snowflake.ID]Frequency
LastFrequency map[snowflake.ID]snowflake.ID // key is network
@@ -23,6 +25,7 @@ type state struct {
}
var State state = state{
UserID: nil,
FrequencyState: map[snowflake.ID]Frequency{},
LastFrequency: map[snowflake.ID]snowflake.ID{},
Messages: map[snowflake.ID]*btree.BTreeG[data.Message]{},

View File

@@ -56,6 +56,34 @@ func (q *Queries) GetNetworkBannedUsers(ctx context.Context, networkID snowflake
return items, nil
}
const getNetworkMemberById = `-- name: GetNetworkMemberById :one
SELECT user_id, network_id, joined_at, is_member, is_admin, is_muted, is_banned, ban_reason, position
FROM users_networks
WHERE users_networks.network_id = ? AND users_networks.user_id = ?
`
type GetNetworkMemberByIdParams struct {
NetworkID snowflake.ID
UserID snowflake.ID
}
func (q *Queries) GetNetworkMemberById(ctx context.Context, arg GetNetworkMemberByIdParams) (UserNetwork, error) {
row := q.db.QueryRowContext(ctx, getNetworkMemberById, arg.NetworkID, arg.UserID)
var i UserNetwork
err := row.Scan(
&i.UserID,
&i.NetworkID,
&i.JoinedAt,
&i.IsMember,
&i.IsAdmin,
&i.IsMuted,
&i.IsBanned,
&i.BanReason,
&i.Position,
)
return i, err
}
const getNetworkMembers = `-- name: GetNetworkMembers :many
SELECT
users.id, users.name, users.public_key, users.description, users.is_public_dm, users.is_deleted,

View File

@@ -411,3 +411,113 @@ func DeleteNetwork(ctx context.Context, sess *session.Session, request *packet.D
Set: false,
}
}
func SetNetworkUser(ctx context.Context, sess *session.Session, request packet.SetNetworkUser) packet.Payload {
queries := data.New(db)
network, err := queries.GetNetworkById(ctx, request.Network)
if err != nil {
log.Println("database error 1:", err)
return &ErrInternalError
}
member, err := queries.GetNetworkMemberById(ctx, data.GetNetworkMemberByIdParams{
NetworkID: request.Network,
UserID: request.User,
})
wantsToJoin := request.Member != nil && *request.Member && request.User == sess.ID()
if err == sql.ErrNoRows && network.IsPublic && wantsToJoin {
_, err = queries.SetNetworkUser(ctx, data.SetNetworkUserParams{
UserID: request.User,
NetworkID: request.Network,
IsMember: true,
IsAdmin: false,
IsMuted: false,
IsBanned: false,
BanReason: nil,
})
if err != nil {
log.Println("database error 2:", err)
return &ErrInternalError
}
payload, err := GetSingleNetworkInfo(ctx, queries, network)
if err != nil {
log.Println("database error 3:", err)
return &ErrInternalError
}
return payload
}
if err != nil {
log.Println("database error 4:", err)
return &ErrInternalError
}
isSessAdmin, err := IsNetworkAdmin(ctx, queries, sess.ID(), request.Network)
if err != nil {
log.Println("database error 5:", err)
return &ErrInternalError
}
isMember := member.IsMember
isAdmin := member.IsAdmin
isMuted := member.IsMuted
IsBanned := member.IsBanned
banReason := member.BanReason
if request.Member != nil && !IsBanned {
isLeave := !*request.Member && request.User == sess.ID()
isKick := !*request.Member && isSessAdmin
if request.User != network.OwnerID && (isLeave || isKick) {
isMember = false
isAdmin = false // Important for security
}
isJoin := *request.Member && request.User == sess.ID() && network.IsPublic
if isJoin {
isMember = true
}
} else if request.Admin != nil {
if network.OwnerID == sess.ID() && request.User != sess.ID() {
isAdmin = *request.Admin
}
} else if request.Muted != nil {
notSelf := request.User != sess.ID()
notOwner := request.User != network.OwnerID
if isSessAdmin && notSelf && notOwner {
isMuted = *request.Muted
}
} else if request.Banned != nil {
notSelf := request.User != sess.ID()
notOwner := request.User != network.OwnerID
if isSessAdmin && notSelf && notOwner {
IsBanned = *request.Banned
banReason = request.BanReason
isAdmin = false // Important for security
}
}
_, err = queries.SetNetworkUser(ctx, data.SetNetworkUserParams{
UserID: request.User,
NetworkID: request.Network,
IsMember: isMember,
IsAdmin: isAdmin,
IsMuted: isMuted,
IsBanned: IsBanned,
BanReason: banReason,
})
if err != nil {
log.Println("database error 6:", err)
return &ErrInternalError
}
payload, err := GetSingleNetworkInfo(ctx, queries, network)
if err != nil {
log.Println("database error 7:", err)
return &ErrInternalError
}
return payload
}

View File

@@ -5,6 +5,7 @@ import (
"strings"
"github.com/kyren223/eko/internal/data"
"github.com/kyren223/eko/internal/packet"
"github.com/kyren223/eko/pkg/snowflake"
)
@@ -40,3 +41,28 @@ func IsNetworkAdmin(ctx context.Context, queries *data.Queries, userId, networkI
isAdmin := userNetwork.IsAdmin && userNetwork.IsMember && !userNetwork.IsBanned
return isAdmin, nil
}
func GetSingleNetworkInfo(ctx context.Context, queries *data.Queries, network data.Network) (packet.Payload, error) {
frequencies, err := queries.GetNetworkFrequencies(ctx, network.ID)
if err != nil {
return nil, err
}
members, err := queries.GetNetworkMembers(ctx, network.ID)
if err != nil {
return nil, err
}
fullNetwork := packet.FullNetwork{
Network: network,
Frequencies: frequencies,
Members: members,
Position: -1,
}
return &packet.NetworksInfo{
Networks: []packet.FullNetwork{fullNetwork},
RemoveNetworks: nil,
Set: false,
}, nil
}

View File

@@ -319,6 +319,8 @@ func processRequest(ctx context.Context, sess *session.Session, request packet.P
response = timeout(500*time.Millisecond, api.DeleteNetwork, ctx, sess, request)
case *packet.SwapUserNetworks:
response = timeout(5*time.Millisecond, api.SwapUserNetworks, ctx, sess, request)
case *packet.SetNetworkUser:
response = timeout(5*time.Millisecond, api.SetNetworkUser, ctx, sess, request)
case *packet.CreateFrequency:
response = timeout(5*time.Millisecond, api.CreateFrequency, ctx, sess, request)

View File

@@ -16,6 +16,11 @@ FROM users_networks
JOIN users ON users.id = users_networks.user_id
WHERE users_networks.network_id = ? AND is_member = true;
-- name: GetNetworkMemberById :one
SELECT *
FROM users_networks
WHERE users_networks.network_id = ? AND users_networks.user_id = ?;
-- name: GetUserNetworks :many
SELECT sqlc.embed(networks), users_networks.position FROM networks
JOIN users_networks ON networks.id = users_networks.network_id