diff --git a/pkg/tunnel/controllers/relay.go b/pkg/tunnel/controllers/relay.go index 31594ba..fae2ce4 100644 --- a/pkg/tunnel/controllers/relay.go +++ b/pkg/tunnel/controllers/relay.go @@ -21,4 +21,6 @@ type Relay interface { SetOnConnect(onConnect func(ctx context.Context, tunnelName, agentName string, conn Connection) error) // SetOnDisconnect sets a callback that is invoked when a connection is closed. SetOnDisconnect(onDisconnect func(ctx context.Context, agentName, id string) error) + // SetOnShutdown sets a callback that is invoked when the relay is shutting down. + SetOnShutdown(onShutdown func()) } diff --git a/pkg/tunnel/controllers/tunnel_reconciler.go b/pkg/tunnel/controllers/tunnel_reconciler.go index ce05461..b349f50 100644 --- a/pkg/tunnel/controllers/tunnel_reconciler.go +++ b/pkg/tunnel/controllers/tunnel_reconciler.go @@ -3,10 +3,12 @@ package controllers import ( "context" "fmt" + "log/slog" "slices" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/util/retry" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/builder" @@ -25,11 +27,13 @@ type TunnelReconciler struct { } func NewTunnelReconciler(c client.Client, relay Relay, labelSelector string) *TunnelReconciler { - return &TunnelReconciler{ + r := &TunnelReconciler{ client: c, relay: relay, labelSelector: labelSelector, } + relay.SetOnShutdown(r.RemoveRelayAddress) + return r } func (r *TunnelReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { @@ -102,3 +106,55 @@ func (r *TunnelReconciler) SetupWithManager(mgr ctrl.Manager) error { For(&corev1alpha2.Tunnel{}, builder.WithPredicates(&predicate.ResourceVersionChangedPredicate{}, ls)). Complete(r) } + +func (r *TunnelReconciler) RemoveRelayAddress() { + ctx := context.Background() + + // Build the same label selector we filter on during watch. + lss, err := metav1.ParseToLabelSelector(r.labelSelector) + if err != nil { + slog.Error("Failed to parse label selector during shutdown cleanup", slog.Any("error", err)) + return + } + sel, err := metav1.LabelSelectorAsSelector(lss) + if err != nil { + slog.Error("Failed to build label selector during shutdown cleanup", slog.Any("error", err)) + return + } + + var list corev1alpha2.TunnelList + if err := r.client.List(ctx, &list, &client.ListOptions{LabelSelector: sel}); err != nil { + slog.Error("Failed to list tunnels during shutdown cleanup", slog.Any("error", err)) + return + } + + relayAddr := r.relay.Address().String() + for _, t := range list.Items { + // Skip if there's nothing to remove. + if !slices.Contains(t.Status.Addresses, relayAddr) { + continue + } + + key := types.NamespacedName{Namespace: t.Namespace, Name: t.Name} + err := retry.RetryOnConflict(retry.DefaultRetry, func() error { + var latest corev1alpha2.Tunnel + if err := r.client.Get(ctx, key, &latest); err != nil { + return err + } + + // Filter out this relay's address. + filtered := latest.Status.Addresses[:0] + for _, a := range latest.Status.Addresses { + if a != relayAddr { + filtered = append(filtered, a) + } + } + latest.Status.Addresses = filtered + + return r.client.Status().Update(ctx, &latest) + }) + if err != nil { + slog.Error("Failed to remove relay address from tunnel during shutdown cleanup", slog.Any("error", err), slog.String("tunnel", key.String())) + } + } +} diff --git a/pkg/tunnel/controllers/tunnel_reconciler_test.go b/pkg/tunnel/controllers/tunnel_reconciler_test.go index 1c97b65..4455417 100644 --- a/pkg/tunnel/controllers/tunnel_reconciler_test.go +++ b/pkg/tunnel/controllers/tunnel_reconciler_test.go @@ -43,6 +43,7 @@ func TestTunnelReconciler(t *testing.T) { relay.On("SetCredentials", "tun-1", "secret-token").Once() relay.On("SetRelayAddresses", "tun-1", mock.Anything).Once() relay.On("SetEgressGateway", mock.Anything).Return().Once() + relay.On("SetOnShutdown", mock.Anything).Return().Once() r := controllers.NewTunnelReconciler(c, relay, "") @@ -53,6 +54,59 @@ func TestTunnelReconciler(t *testing.T) { relay.AssertExpectations(t) } +func TestTunnelReconciler_OnShutdownRemovesAddress(t *testing.T) { + scheme := runtime.NewScheme() + require.NoError(t, corev1alpha2.Install(scheme)) + + relayAddr := netip.MustParseAddrPort("1.1.1.1:443") + + tunnel := &corev1alpha2.Tunnel{ + ObjectMeta: metav1.ObjectMeta{Name: "tun-1"}, + Status: corev1alpha2.TunnelStatus{ + // Seed with our relay address plus another one to ensure only ours is removed. + Addresses: []string{relayAddr.String(), "2.2.2.2:443"}, + }, + } + + c := fakeclient.NewClientBuilder(). + WithScheme(scheme). + WithStatusSubresource(&corev1alpha2.Tunnel{}). + WithObjects(tunnel). + Build() + + relay := &mockRelay{} + relay.On("Address").Return(relayAddr) + + var onShutdown func() + relay. + On("SetOnShutdown", mock.Anything). + Run(func(args mock.Arguments) { + onShutdown = args.Get(0).(func()) + }). + Return(). + Once() + + // We don't need other relay expectations for this test. + controllers.NewTunnelReconciler(c, relay, "") + + // Sanity: ensure the tunnel initially contains the relay address. + var before corev1alpha2.Tunnel + require.NoError(t, c.Get(context.Background(), types.NamespacedName{Name: "tun-1"}, &before)) + require.Contains(t, before.Status.Addresses, relayAddr.String()) + + // Invoke the captured shutdown hook. + require.NotNil(t, onShutdown, "onShutdown should be captured from SetOnShutdown") + onShutdown() + + // After shutdown, our relay address should be removed from the status. + var after corev1alpha2.Tunnel + require.NoError(t, c.Get(context.Background(), types.NamespacedName{Name: "tun-1"}, &after)) + require.NotContains(t, after.Status.Addresses, relayAddr.String()) + require.ElementsMatch(t, []string{"2.2.2.2:443"}, after.Status.Addresses) + + relay.AssertExpectations(t) +} + func testLogr(t *testing.T) logr.Logger { if testing.Verbose() { l := stdr.New(log.New(os.Stdout, "", log.LstdFlags)) @@ -94,3 +148,7 @@ func (m *mockRelay) SetOnConnect(onConnect func(ctx context.Context, tunnelName, func (m *mockRelay) SetOnDisconnect(onDisconnect func(ctx context.Context, agentName, id string) error) { m.Called(onDisconnect) } + +func (m *mockRelay) SetOnShutdown(onShutdown func()) { + m.Called(onShutdown) +} diff --git a/pkg/tunnel/relay.go b/pkg/tunnel/relay.go index a236a4a..fc39535 100644 --- a/pkg/tunnel/relay.go +++ b/pkg/tunnel/relay.go @@ -50,6 +50,7 @@ type Relay struct { agents *haxmap.Map[string, string] // map[connectionID]agentName onConnect func(ctx context.Context, tunnelName, agentName string, conn controllers.Connection) error onDisconnect func(ctx context.Context, agentName, id string) error + onShutdown func() } func NewRelay(name string, pc net.PacketConn, cert tls.Certificate, handler *icx.Handler, idHasher *hasher.Hasher, router router.Router) *Relay { @@ -112,6 +113,14 @@ func (r *Relay) SetOnDisconnect(onDisconnect func(ctx context.Context, agentName r.onDisconnect = onDisconnect } +// SetOnShutdown sets a callback that is invoked when the relay is shutting down. +func (r *Relay) SetOnShutdown(onShutdown func()) { + r.mu.Lock() + defer r.mu.Unlock() + + r.onShutdown = onShutdown +} + // Start starts the relay. func (r *Relay) Start(ctx context.Context) error { ln, err := quic.ListenEarly( @@ -178,6 +187,15 @@ func (r *Relay) Start(ctx context.Context) error { return err } + // Invoke shutdown callback if set. + r.mu.Lock() + onShutdown := r.onShutdown + r.mu.Unlock() + + if onShutdown != nil { + onShutdown() + } + return nil }