Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pkg/tunnel/controllers/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@ type Relay interface {
SetOnConnect(onConnect func(ctx context.Context, tunnelName, agentName string, conn Connection) error)
// SetOnDisconnect sets a callback that is invoked when a connection is closed.
SetOnDisconnect(onDisconnect func(ctx context.Context, agentName, id string) error)
// SetOnShutdown sets a callback that is invoked when the relay is shutting down.
SetOnShutdown(onShutdown func())
}
58 changes: 57 additions & 1 deletion pkg/tunnel/controllers/tunnel_reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package controllers
import (
"context"
"fmt"
"log/slog"
"slices"

apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/util/retry"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/builder"
Expand All @@ -25,11 +27,13 @@ type TunnelReconciler struct {
}

func NewTunnelReconciler(c client.Client, relay Relay, labelSelector string) *TunnelReconciler {
return &TunnelReconciler{
r := &TunnelReconciler{
client: c,
relay: relay,
labelSelector: labelSelector,
}
relay.SetOnShutdown(r.RemoveRelayAddress)
return r
}

func (r *TunnelReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
Expand Down Expand Up @@ -102,3 +106,55 @@ func (r *TunnelReconciler) SetupWithManager(mgr ctrl.Manager) error {
For(&corev1alpha2.Tunnel{}, builder.WithPredicates(&predicate.ResourceVersionChangedPredicate{}, ls)).
Complete(r)
}

func (r *TunnelReconciler) RemoveRelayAddress() {
ctx := context.Background()

// Build the same label selector we filter on during watch.
lss, err := metav1.ParseToLabelSelector(r.labelSelector)
if err != nil {
slog.Error("Failed to parse label selector during shutdown cleanup", slog.Any("error", err))
return
}
sel, err := metav1.LabelSelectorAsSelector(lss)
if err != nil {
slog.Error("Failed to build label selector during shutdown cleanup", slog.Any("error", err))
return
}

var list corev1alpha2.TunnelList
if err := r.client.List(ctx, &list, &client.ListOptions{LabelSelector: sel}); err != nil {
slog.Error("Failed to list tunnels during shutdown cleanup", slog.Any("error", err))
return
}

relayAddr := r.relay.Address().String()
for _, t := range list.Items {
// Skip if there's nothing to remove.
if !slices.Contains(t.Status.Addresses, relayAddr) {
continue
}

key := types.NamespacedName{Namespace: t.Namespace, Name: t.Name}
err := retry.RetryOnConflict(retry.DefaultRetry, func() error {
var latest corev1alpha2.Tunnel
if err := r.client.Get(ctx, key, &latest); err != nil {
return err
}

// Filter out this relay's address.
filtered := latest.Status.Addresses[:0]
for _, a := range latest.Status.Addresses {
if a != relayAddr {
filtered = append(filtered, a)
}
}
latest.Status.Addresses = filtered

return r.client.Status().Update(ctx, &latest)
})
if err != nil {
slog.Error("Failed to remove relay address from tunnel during shutdown cleanup", slog.Any("error", err), slog.String("tunnel", key.String()))
}
}
}
58 changes: 58 additions & 0 deletions pkg/tunnel/controllers/tunnel_reconciler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func TestTunnelReconciler(t *testing.T) {
relay.On("SetCredentials", "tun-1", "secret-token").Once()
relay.On("SetRelayAddresses", "tun-1", mock.Anything).Once()
relay.On("SetEgressGateway", mock.Anything).Return().Once()
relay.On("SetOnShutdown", mock.Anything).Return().Once()

r := controllers.NewTunnelReconciler(c, relay, "")

Expand All @@ -53,6 +54,59 @@ func TestTunnelReconciler(t *testing.T) {
relay.AssertExpectations(t)
}

func TestTunnelReconciler_OnShutdownRemovesAddress(t *testing.T) {
scheme := runtime.NewScheme()
require.NoError(t, corev1alpha2.Install(scheme))

relayAddr := netip.MustParseAddrPort("1.1.1.1:443")

tunnel := &corev1alpha2.Tunnel{
ObjectMeta: metav1.ObjectMeta{Name: "tun-1"},
Status: corev1alpha2.TunnelStatus{
// Seed with our relay address plus another one to ensure only ours is removed.
Addresses: []string{relayAddr.String(), "2.2.2.2:443"},
},
}

c := fakeclient.NewClientBuilder().
WithScheme(scheme).
WithStatusSubresource(&corev1alpha2.Tunnel{}).
WithObjects(tunnel).
Build()

relay := &mockRelay{}
relay.On("Address").Return(relayAddr)

var onShutdown func()
relay.
On("SetOnShutdown", mock.Anything).
Run(func(args mock.Arguments) {
onShutdown = args.Get(0).(func())
}).
Return().
Once()

// We don't need other relay expectations for this test.
controllers.NewTunnelReconciler(c, relay, "")

// Sanity: ensure the tunnel initially contains the relay address.
var before corev1alpha2.Tunnel
require.NoError(t, c.Get(context.Background(), types.NamespacedName{Name: "tun-1"}, &before))
require.Contains(t, before.Status.Addresses, relayAddr.String())

// Invoke the captured shutdown hook.
require.NotNil(t, onShutdown, "onShutdown should be captured from SetOnShutdown")
onShutdown()

// After shutdown, our relay address should be removed from the status.
var after corev1alpha2.Tunnel
require.NoError(t, c.Get(context.Background(), types.NamespacedName{Name: "tun-1"}, &after))
require.NotContains(t, after.Status.Addresses, relayAddr.String())
require.ElementsMatch(t, []string{"2.2.2.2:443"}, after.Status.Addresses)

relay.AssertExpectations(t)
}

func testLogr(t *testing.T) logr.Logger {
if testing.Verbose() {
l := stdr.New(log.New(os.Stdout, "", log.LstdFlags))
Expand Down Expand Up @@ -94,3 +148,7 @@ func (m *mockRelay) SetOnConnect(onConnect func(ctx context.Context, tunnelName,
func (m *mockRelay) SetOnDisconnect(onDisconnect func(ctx context.Context, agentName, id string) error) {
m.Called(onDisconnect)
}

func (m *mockRelay) SetOnShutdown(onShutdown func()) {
m.Called(onShutdown)
}
18 changes: 18 additions & 0 deletions pkg/tunnel/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type Relay struct {
agents *haxmap.Map[string, string] // map[connectionID]agentName
onConnect func(ctx context.Context, tunnelName, agentName string, conn controllers.Connection) error
onDisconnect func(ctx context.Context, agentName, id string) error
onShutdown func()
}

func NewRelay(name string, pc net.PacketConn, cert tls.Certificate, handler *icx.Handler, idHasher *hasher.Hasher, router router.Router) *Relay {
Expand Down Expand Up @@ -112,6 +113,14 @@ func (r *Relay) SetOnDisconnect(onDisconnect func(ctx context.Context, agentName
r.onDisconnect = onDisconnect
}

// SetOnShutdown sets a callback that is invoked when the relay is shutting down.
func (r *Relay) SetOnShutdown(onShutdown func()) {
r.mu.Lock()
defer r.mu.Unlock()

r.onShutdown = onShutdown
}

// Start starts the relay.
func (r *Relay) Start(ctx context.Context) error {
ln, err := quic.ListenEarly(
Expand Down Expand Up @@ -178,6 +187,15 @@ func (r *Relay) Start(ctx context.Context) error {
return err
}

// Invoke shutdown callback if set.
r.mu.Lock()
onShutdown := r.onShutdown
r.mu.Unlock()

if onShutdown != nil {
onShutdown()
}

return nil
}

Expand Down