Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 88 additions & 14 deletions network-api/network-api.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,12 @@ func Register(router *msgpackrouter.Router) {
_ = router.RegisterMethod("tcp/connectSSL", tcpConnectSSL)

_ = router.RegisterMethod("udp/connect", udpConnect)
_ = router.RegisterMethod("udp/beginPacket", udpBeginPacket)
_ = router.RegisterMethod("udp/write", udpWrite)
_ = router.RegisterMethod("udp/awaitRead", udpAwaitRead)
_ = router.RegisterMethod("udp/endPacket", udpEndPacket)
_ = router.RegisterMethod("udp/awaitPacket", udpAwaitPacket)
_ = router.RegisterMethod("udp/read", udpRead)
_ = router.RegisterMethod("udp/dropPacket", udpDropPacket)
_ = router.RegisterMethod("udp/close", udpClose)
}

Expand All @@ -58,6 +61,8 @@ var liveConnections = make(map[uint]net.Conn)
var liveListeners = make(map[uint]net.Listener)
var liveUdpConnections = make(map[uint]net.PacketConn)
var udpReadBuffers = make(map[uint][]byte)
var udpWriteTargets = make(map[uint]*net.UDPAddr)
var udpWriteBuffers = make(map[uint][]byte)
var nextConnectionID atomic.Uint32

// takeLockAndGenerateNextID generates a new unique ID for a connection or listener.
Expand Down Expand Up @@ -375,9 +380,9 @@ func udpConnect(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (
return id, nil
}

func udpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
if len(params) != 4 {
return nil, []any{1, "Invalid number of parameters, expected udpConnId, dest address, dest port, payload"}
func udpBeginPacket(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
if len(params) != 3 {
return nil, []any{1, "Invalid number of parameters, expected udpConnId, dest address, dest port"}
}
id, ok := msgpackrpc.ToUint(params[0])
if !ok {
Expand All @@ -391,9 +396,33 @@ func udpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_r
if !ok {
return nil, []any{1, "Invalid parameter type, expected uint16 for server port"}
}
data, ok := params[3].([]byte)

lock.RLock()
defer lock.RUnlock()
if _, ok := liveUdpConnections[id]; !ok {
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
}
targetAddr := net.JoinHostPort(targetIP, fmt.Sprintf("%d", targetPort))
addr, err := net.ResolveUDPAddr("udp", targetAddr) // TODO: This is inefficient, implement some caching
if err != nil {
return nil, []any{3, "Failed to resolve target address: " + err.Error()}
}
udpWriteTargets[id] = addr
udpWriteBuffers[id] = nil
return true, nil
}

func udpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
if len(params) != 2 {
return nil, []any{1, "Invalid number of parameters, expected udpConnId, payload"}
}
id, ok := msgpackrpc.ToUint(params[0])
if !ok {
if dataStr, ok := params[3].(string); ok {
return nil, []any{1, "Invalid parameter type, expected int for UDP connection ID"}
}
data, ok := params[1].([]byte)
if !ok {
if dataStr, ok := params[1].(string); ok {
data = []byte(dataStr)
} else {
// If data is not []byte or string, return an error
Expand All @@ -402,25 +431,52 @@ func udpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_r
}

lock.RLock()
udpConn, ok := liveUdpConnections[id]
udpBuffer, ok := udpWriteBuffers[id]
if ok {
udpWriteBuffers[id] = append(udpBuffer, data...)
}
lock.RUnlock()
if !ok {
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
}
return len(data), nil
}

targetAddr := net.JoinHostPort(targetIP, fmt.Sprintf("%d", targetPort))
addr, err := net.ResolveUDPAddr("udp", targetAddr) // TODO: This is inefficient, implement some caching
if err != nil {
return nil, []any{3, "Failed to resolve target address: " + err.Error()}
func udpEndPacket(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
if len(params) != 1 {
return nil, []any{1, "Invalid number of parameters, expected expected udpConnId"}
}
id, buffExists := msgpackrpc.ToUint(params[0])
if !buffExists {
return nil, []any{1, "Invalid parameter type, expected int for UDP connection ID"}
}

var udpBuffer []byte
var udpAddr *net.UDPAddr
lock.RLock()
udpConn, connExists := liveUdpConnections[id]
if connExists {
udpBuffer, buffExists = udpWriteBuffers[id]
udpAddr = udpWriteTargets[id]
delete(udpWriteBuffers, id)
delete(udpWriteTargets, id)
}
lock.RUnlock()
if !connExists {
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
}
if n, err := udpConn.WriteTo(data, addr); err != nil {
if !buffExists {
return nil, []any{3, fmt.Sprintf("No UDP packet begun for ID: %d", id)}
}

if n, err := udpConn.WriteTo(udpBuffer, udpAddr); err != nil {
return nil, []any{4, "Failed to write to UDP connection: " + err.Error()}
} else {
return n, nil
}
}

func udpAwaitRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
func udpAwaitPacket(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
if len(params) != 1 && len(params) != 2 {
return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID[, optional timeout in ms])"}
}
Expand Down Expand Up @@ -472,6 +528,24 @@ func udpAwaitRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any)
return []any{n, host, port}, nil
}

func udpDropPacket(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
if len(params) != 1 && len(params) != 2 {
return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID[, optional timeout in ms])"}
}
id, ok := msgpackrpc.ToUint(params[0])
if !ok {
return nil, []any{1, "Invalid parameter type, expected uint for UDP connection ID"}
}

lock.RLock()
delete(udpReadBuffers, id)
lock.RUnlock()
if !ok {
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
}
return true, nil
}

func udpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
if len(params) != 2 && len(params) != 3 {
return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID, max bytes to read)"}
Expand All @@ -494,7 +568,7 @@ func udpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_re
udpReadBuffers[id] = buffer[maxBytes:]
n = maxBytes
} else {
udpReadBuffers[id] = nil
delete(udpReadBuffers, id)
}
}
lock.Unlock()
Expand Down
82 changes: 65 additions & 17 deletions network-api/network-api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,12 +248,18 @@ func TestUDPNetworkAPI(t *testing.T) {
require.NotEqual(t, conn1, conn2)

{
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9900, []byte("Hello")})
res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9900})
require.Nil(t, err)
require.True(t, res.(bool))
res, err = udpWrite(ctx, nil, []any{conn1, []byte("Hello")})
require.Nil(t, err)
require.Equal(t, 5, res)
res, err = udpEndPacket(ctx, nil, []any{conn1})
require.Nil(t, err)
require.Equal(t, 5, res)
}
{
res, err := udpAwaitRead(ctx, nil, []any{conn2})
res, err := udpAwaitPacket(ctx, nil, []any{conn2})
require.Nil(t, err)
require.Equal(t, []any{5, "127.0.0.1", 9800}, res)

Expand All @@ -262,26 +268,44 @@ func TestUDPNetworkAPI(t *testing.T) {
require.Equal(t, []uint8("Hello"), res2)
}
{
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9900, []byte("One")})
res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9900})
require.Nil(t, err)
require.True(t, res.(bool))
res, err = udpWrite(ctx, nil, []any{conn1, []byte("On")})
require.Nil(t, err)
require.Equal(t, 2, res)
res, err = udpWrite(ctx, nil, []any{conn1, []byte("e")})
require.Nil(t, err)
require.Equal(t, 1, res)
res, err = udpEndPacket(ctx, nil, []any{conn1})
require.Nil(t, err)
require.Equal(t, 3, res)
}
{
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9900, []byte("Two")})
res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9900})
require.Nil(t, err)
require.True(t, res.(bool))
res, err = udpWrite(ctx, nil, []any{conn1, []byte("Two")})
require.Nil(t, err)
require.Equal(t, 3, res)
res, err = udpEndPacket(ctx, nil, []any{conn1})
require.Nil(t, err)
require.Equal(t, 3, res)
}
{
res, err := udpAwaitRead(ctx, nil, []any{conn2})
res, err := udpAwaitPacket(ctx, nil, []any{conn2})
require.Nil(t, err)
require.Equal(t, []any{3, "127.0.0.1", 9800}, res)

res2, err := udpRead(ctx, nil, []any{conn2, 100})
// A partial read of a packet is allowed
res2, err := udpRead(ctx, nil, []any{conn2, 2})
require.Nil(t, err)
require.Equal(t, []uint8("One"), res2)
require.Equal(t, []uint8("On"), res2)
}
{
res, err := udpAwaitRead(ctx, nil, []any{conn2})
// Even if the previous packet was only partially read,
// the next packet can be received
res, err := udpAwaitPacket(ctx, nil, []any{conn2})
require.Nil(t, err)
require.Equal(t, []any{3, "127.0.0.1", 9800}, res)

Expand Down Expand Up @@ -311,12 +335,18 @@ func TestUDPNetworkUnboundClientAPI(t *testing.T) {
require.NotEqual(t, conn1, conn2)

{
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9901, []byte("Hello")})
res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9901})
require.Nil(t, err)
require.True(t, res.(bool))
res, err = udpWrite(ctx, nil, []any{conn1, []byte("Hello")})
require.Nil(t, err)
require.Equal(t, 5, res)
res, err = udpEndPacket(ctx, nil, []any{conn1})
require.Nil(t, err)
require.Equal(t, 5, res)
}
{
res, err := udpAwaitRead(ctx, nil, []any{conn2})
res, err := udpAwaitPacket(ctx, nil, []any{conn2})
require.Nil(t, err)
require.Equal(t, 5, res.([]any)[0])

Expand All @@ -329,17 +359,29 @@ func TestUDPNetworkUnboundClientAPI(t *testing.T) {
require.Equal(t, []uint8("llo"), res2)
}
{
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9901, []byte("One")})
res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9901})
require.Nil(t, err)
require.True(t, res.(bool))
res, err = udpWrite(ctx, nil, []any{conn1, []byte("One")})
require.Nil(t, err)
require.Equal(t, 3, res)
res, err = udpEndPacket(ctx, nil, []any{conn1})
require.Nil(t, err)
require.Equal(t, 3, res)
}
{
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9901, []byte("Two")})
res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9901})
require.Nil(t, err)
require.True(t, res.(bool))
res, err = udpWrite(ctx, nil, []any{conn1, []byte("Two")})
require.Nil(t, err)
require.Equal(t, 3, res)
res, err = udpEndPacket(ctx, nil, []any{conn1})
require.Nil(t, err)
require.Equal(t, 3, res)
}
{
res, err := udpAwaitRead(ctx, nil, []any{conn2})
res, err := udpAwaitPacket(ctx, nil, []any{conn2})
require.Nil(t, err)
require.Equal(t, 3, res.([]any)[0])

Expand All @@ -348,7 +390,7 @@ func TestUDPNetworkUnboundClientAPI(t *testing.T) {
require.Equal(t, []uint8("One"), res2)
}
{
res, err := udpAwaitRead(ctx, nil, []any{conn2})
res, err := udpAwaitPacket(ctx, nil, []any{conn2})
require.Nil(t, err)
require.Equal(t, 3, res.([]any)[0])

Expand All @@ -360,19 +402,25 @@ func TestUDPNetworkUnboundClientAPI(t *testing.T) {
// Check timeouts
go func() {
time.Sleep(200 * time.Millisecond)
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9901, []byte("Three")})
res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9901})
require.Nil(t, err)
require.True(t, res.(bool))
res, err = udpWrite(ctx, nil, []any{conn1, []byte("Three")})
require.Nil(t, err)
require.Equal(t, 5, res)
res, err = udpEndPacket(ctx, nil, []any{conn1})
require.Nil(t, err)
require.Equal(t, 5, res)
}()
{
start := time.Now()
res, err := udpAwaitRead(ctx, nil, []any{conn2, 10})
res, err := udpAwaitPacket(ctx, nil, []any{conn2, 10})
require.Less(t, time.Since(start), 20*time.Millisecond)
require.Equal(t, []any{5, "Timeout"}, err)
require.Nil(t, res)
}
{
res, err := udpAwaitRead(ctx, nil, []any{conn2, 0})
res, err := udpAwaitPacket(ctx, nil, []any{conn2, 0})
require.Nil(t, err)
require.Equal(t, 5, res.([]any)[0])

Expand Down