mirror of
https://github.com/Kyren223/eko.git
synced 2026-04-27 15:33:55 +00:00
Implemented joining functionality
This commit is contained in:
@@ -46,7 +46,6 @@ const (
|
||||
type Model struct {
|
||||
name string
|
||||
privKey ed25519.PrivateKey
|
||||
id snowflake.ID
|
||||
|
||||
loading loadscreen.Model
|
||||
timer timer.Model
|
||||
@@ -66,7 +65,6 @@ func New(privKey ed25519.PrivateKey, name string) Model {
|
||||
m := Model{
|
||||
name: name,
|
||||
privKey: privKey,
|
||||
id: 0,
|
||||
loading: loadscreen.New(connectingToServer),
|
||||
timer: newTimer(initialTimeout),
|
||||
timeout: initialTimeout,
|
||||
@@ -134,7 +132,7 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
func (m *Model) updateNotConnected(msg tea.Msg) tea.Cmd {
|
||||
switch msg := msg.(type) {
|
||||
case gateway.ConnectionEstablished:
|
||||
m.id = snowflake.ID(msg)
|
||||
state.State.UserID = (*snowflake.ID)(&msg)
|
||||
m.connected = true
|
||||
m.timeout = initialTimeout
|
||||
return m.timer.Stop()
|
||||
@@ -176,6 +174,7 @@ func (m *Model) updateConnected(msg tea.Msg) tea.Cmd {
|
||||
gateway.Disconnect()
|
||||
|
||||
case gateway.ConnectionLost:
|
||||
state.State.UserID = nil
|
||||
m.connected = false
|
||||
m.timeout = initialTimeout
|
||||
return tea.Batch(gateway.Connect(m.privKey, connectionTimeout), m.loading.Init())
|
||||
@@ -185,7 +184,24 @@ func (m *Model) updateConnected(msg tea.Msg) tea.Cmd {
|
||||
state.State.Networks = msg.Networks
|
||||
} else {
|
||||
networks := state.State.Networks
|
||||
networks = append(networks, msg.Networks...)
|
||||
|
||||
for _, newNetwork := range msg.Networks {
|
||||
add := true
|
||||
for i, existingNetwork := range networks {
|
||||
if existingNetwork.ID == newNetwork.ID {
|
||||
add = false
|
||||
if newNetwork.Position == -1 {
|
||||
newNetwork.Position = existingNetwork.Position
|
||||
}
|
||||
networks[i] = newNetwork
|
||||
break
|
||||
}
|
||||
}
|
||||
if add {
|
||||
networks = append(networks, newNetwork)
|
||||
}
|
||||
}
|
||||
|
||||
networks = slices.DeleteFunc(networks, func(network packet.FullNetwork) bool {
|
||||
return slices.Contains(msg.RemoveNetworks, network.ID)
|
||||
})
|
||||
|
||||
@@ -2,16 +2,19 @@ package networkjoin
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/kyren223/eko/internal/client/gateway"
|
||||
"github.com/kyren223/eko/internal/client/ui/colors"
|
||||
"github.com/kyren223/eko/internal/client/ui/core/state"
|
||||
"github.com/kyren223/eko/internal/client/ui/field"
|
||||
"github.com/kyren223/eko/internal/client/ui/layouts/flex"
|
||||
"github.com/kyren223/eko/internal/packet"
|
||||
"github.com/kyren223/eko/pkg/assert"
|
||||
"github.com/kyren223/eko/pkg/snowflake"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -70,6 +73,9 @@ func New() Model {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return errors.New("cannot be empty")
|
||||
}
|
||||
if _, err := strconv.ParseInt(s, 10, 64); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -154,10 +160,19 @@ func (m *Model) Select() tea.Cmd {
|
||||
if m.name.Input.Err != nil {
|
||||
return nil
|
||||
}
|
||||
name := m.name.Input.Value()
|
||||
id, err := strconv.ParseInt(name, 10, 64)
|
||||
assert.NoError(err, "input is already validated to be valid")
|
||||
|
||||
// TODO: change this to actual invite packet
|
||||
request := packet.CreateNetwork{
|
||||
Name: m.name.Input.Value(),
|
||||
yes := true
|
||||
request := packet.SetNetworkUser{
|
||||
Member: &yes,
|
||||
Admin: nil,
|
||||
Muted: nil,
|
||||
Banned: nil,
|
||||
BanReason: nil,
|
||||
Network: snowflake.ID(id),
|
||||
User: *state.State.UserID,
|
||||
}
|
||||
return gateway.Send(&request)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,8 @@ type Frequency struct {
|
||||
}
|
||||
|
||||
type state struct {
|
||||
UserID *snowflake.ID
|
||||
|
||||
// Key is either a frequency or receiver
|
||||
FrequencyState map[snowflake.ID]Frequency
|
||||
LastFrequency map[snowflake.ID]snowflake.ID // key is network
|
||||
@@ -23,6 +25,7 @@ type state struct {
|
||||
}
|
||||
|
||||
var State state = state{
|
||||
UserID: nil,
|
||||
FrequencyState: map[snowflake.ID]Frequency{},
|
||||
LastFrequency: map[snowflake.ID]snowflake.ID{},
|
||||
Messages: map[snowflake.ID]*btree.BTreeG[data.Message]{},
|
||||
|
||||
@@ -56,6 +56,34 @@ func (q *Queries) GetNetworkBannedUsers(ctx context.Context, networkID snowflake
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getNetworkMemberById = `-- name: GetNetworkMemberById :one
|
||||
SELECT user_id, network_id, joined_at, is_member, is_admin, is_muted, is_banned, ban_reason, position
|
||||
FROM users_networks
|
||||
WHERE users_networks.network_id = ? AND users_networks.user_id = ?
|
||||
`
|
||||
|
||||
type GetNetworkMemberByIdParams struct {
|
||||
NetworkID snowflake.ID
|
||||
UserID snowflake.ID
|
||||
}
|
||||
|
||||
func (q *Queries) GetNetworkMemberById(ctx context.Context, arg GetNetworkMemberByIdParams) (UserNetwork, error) {
|
||||
row := q.db.QueryRowContext(ctx, getNetworkMemberById, arg.NetworkID, arg.UserID)
|
||||
var i UserNetwork
|
||||
err := row.Scan(
|
||||
&i.UserID,
|
||||
&i.NetworkID,
|
||||
&i.JoinedAt,
|
||||
&i.IsMember,
|
||||
&i.IsAdmin,
|
||||
&i.IsMuted,
|
||||
&i.IsBanned,
|
||||
&i.BanReason,
|
||||
&i.Position,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getNetworkMembers = `-- name: GetNetworkMembers :many
|
||||
SELECT
|
||||
users.id, users.name, users.public_key, users.description, users.is_public_dm, users.is_deleted,
|
||||
|
||||
@@ -411,3 +411,113 @@ func DeleteNetwork(ctx context.Context, sess *session.Session, request *packet.D
|
||||
Set: false,
|
||||
}
|
||||
}
|
||||
|
||||
func SetNetworkUser(ctx context.Context, sess *session.Session, request packet.SetNetworkUser) packet.Payload {
|
||||
queries := data.New(db)
|
||||
|
||||
network, err := queries.GetNetworkById(ctx, request.Network)
|
||||
if err != nil {
|
||||
log.Println("database error 1:", err)
|
||||
return &ErrInternalError
|
||||
}
|
||||
|
||||
member, err := queries.GetNetworkMemberById(ctx, data.GetNetworkMemberByIdParams{
|
||||
NetworkID: request.Network,
|
||||
UserID: request.User,
|
||||
})
|
||||
|
||||
wantsToJoin := request.Member != nil && *request.Member && request.User == sess.ID()
|
||||
if err == sql.ErrNoRows && network.IsPublic && wantsToJoin {
|
||||
_, err = queries.SetNetworkUser(ctx, data.SetNetworkUserParams{
|
||||
UserID: request.User,
|
||||
NetworkID: request.Network,
|
||||
IsMember: true,
|
||||
IsAdmin: false,
|
||||
IsMuted: false,
|
||||
IsBanned: false,
|
||||
BanReason: nil,
|
||||
})
|
||||
if err != nil {
|
||||
log.Println("database error 2:", err)
|
||||
return &ErrInternalError
|
||||
}
|
||||
|
||||
payload, err := GetSingleNetworkInfo(ctx, queries, network)
|
||||
if err != nil {
|
||||
log.Println("database error 3:", err)
|
||||
return &ErrInternalError
|
||||
}
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Println("database error 4:", err)
|
||||
return &ErrInternalError
|
||||
}
|
||||
|
||||
isSessAdmin, err := IsNetworkAdmin(ctx, queries, sess.ID(), request.Network)
|
||||
if err != nil {
|
||||
log.Println("database error 5:", err)
|
||||
return &ErrInternalError
|
||||
}
|
||||
|
||||
isMember := member.IsMember
|
||||
isAdmin := member.IsAdmin
|
||||
isMuted := member.IsMuted
|
||||
IsBanned := member.IsBanned
|
||||
banReason := member.BanReason
|
||||
|
||||
if request.Member != nil && !IsBanned {
|
||||
isLeave := !*request.Member && request.User == sess.ID()
|
||||
isKick := !*request.Member && isSessAdmin
|
||||
if request.User != network.OwnerID && (isLeave || isKick) {
|
||||
isMember = false
|
||||
isAdmin = false // Important for security
|
||||
}
|
||||
isJoin := *request.Member && request.User == sess.ID() && network.IsPublic
|
||||
if isJoin {
|
||||
isMember = true
|
||||
}
|
||||
} else if request.Admin != nil {
|
||||
if network.OwnerID == sess.ID() && request.User != sess.ID() {
|
||||
isAdmin = *request.Admin
|
||||
}
|
||||
} else if request.Muted != nil {
|
||||
notSelf := request.User != sess.ID()
|
||||
notOwner := request.User != network.OwnerID
|
||||
if isSessAdmin && notSelf && notOwner {
|
||||
isMuted = *request.Muted
|
||||
}
|
||||
} else if request.Banned != nil {
|
||||
notSelf := request.User != sess.ID()
|
||||
notOwner := request.User != network.OwnerID
|
||||
if isSessAdmin && notSelf && notOwner {
|
||||
IsBanned = *request.Banned
|
||||
banReason = request.BanReason
|
||||
isAdmin = false // Important for security
|
||||
}
|
||||
}
|
||||
|
||||
_, err = queries.SetNetworkUser(ctx, data.SetNetworkUserParams{
|
||||
UserID: request.User,
|
||||
NetworkID: request.Network,
|
||||
IsMember: isMember,
|
||||
IsAdmin: isAdmin,
|
||||
IsMuted: isMuted,
|
||||
IsBanned: IsBanned,
|
||||
BanReason: banReason,
|
||||
})
|
||||
if err != nil {
|
||||
log.Println("database error 6:", err)
|
||||
return &ErrInternalError
|
||||
}
|
||||
|
||||
payload, err := GetSingleNetworkInfo(ctx, queries, network)
|
||||
if err != nil {
|
||||
log.Println("database error 7:", err)
|
||||
return &ErrInternalError
|
||||
}
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/kyren223/eko/internal/data"
|
||||
"github.com/kyren223/eko/internal/packet"
|
||||
"github.com/kyren223/eko/pkg/snowflake"
|
||||
)
|
||||
|
||||
@@ -40,3 +41,28 @@ func IsNetworkAdmin(ctx context.Context, queries *data.Queries, userId, networkI
|
||||
isAdmin := userNetwork.IsAdmin && userNetwork.IsMember && !userNetwork.IsBanned
|
||||
return isAdmin, nil
|
||||
}
|
||||
|
||||
func GetSingleNetworkInfo(ctx context.Context, queries *data.Queries, network data.Network) (packet.Payload, error) {
|
||||
frequencies, err := queries.GetNetworkFrequencies(ctx, network.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
members, err := queries.GetNetworkMembers(ctx, network.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fullNetwork := packet.FullNetwork{
|
||||
Network: network,
|
||||
Frequencies: frequencies,
|
||||
Members: members,
|
||||
Position: -1,
|
||||
}
|
||||
|
||||
return &packet.NetworksInfo{
|
||||
Networks: []packet.FullNetwork{fullNetwork},
|
||||
RemoveNetworks: nil,
|
||||
Set: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -319,6 +319,8 @@ func processRequest(ctx context.Context, sess *session.Session, request packet.P
|
||||
response = timeout(500*time.Millisecond, api.DeleteNetwork, ctx, sess, request)
|
||||
case *packet.SwapUserNetworks:
|
||||
response = timeout(5*time.Millisecond, api.SwapUserNetworks, ctx, sess, request)
|
||||
case *packet.SetNetworkUser:
|
||||
response = timeout(5*time.Millisecond, api.SetNetworkUser, ctx, sess, request)
|
||||
|
||||
case *packet.CreateFrequency:
|
||||
response = timeout(5*time.Millisecond, api.CreateFrequency, ctx, sess, request)
|
||||
|
||||
@@ -16,6 +16,11 @@ FROM users_networks
|
||||
JOIN users ON users.id = users_networks.user_id
|
||||
WHERE users_networks.network_id = ? AND is_member = true;
|
||||
|
||||
-- name: GetNetworkMemberById :one
|
||||
SELECT *
|
||||
FROM users_networks
|
||||
WHERE users_networks.network_id = ? AND users_networks.user_id = ?;
|
||||
|
||||
-- name: GetUserNetworks :many
|
||||
SELECT sqlc.embed(networks), users_networks.position FROM networks
|
||||
JOIN users_networks ON networks.id = users_networks.network_id
|
||||
|
||||
Reference in New Issue
Block a user