Skip to content
Draft
18 changes: 14 additions & 4 deletions cmd/vmcp/app/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/stacklok/toolhive/pkg/groups"
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/vmcp/aggregator"
vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth"
"github.com/stacklok/toolhive/pkg/vmcp/auth/factory"
vmcpclient "github.com/stacklok/toolhive/pkg/vmcp/client"
"github.com/stacklok/toolhive/pkg/vmcp/config"
vmcprouter "github.com/stacklok/toolhive/pkg/vmcp/router"
Expand Down Expand Up @@ -213,8 +213,15 @@ func runServe(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("failed to create groups manager: %w", err)
}

// Create outgoing authentication registry from configuration
logger.Info("Initializing outgoing authentication")
outgoingRegistry, err := factory.NewOutgoingAuthRegistry(ctx, cfg.OutgoingAuth)
if err != nil {
return fmt.Errorf("failed to create outgoing authentication registry: %w", err)
}

// Create backend discoverer
discoverer := aggregator.NewCLIBackendDiscoverer(workloadsManager, groupsManager)
discoverer := aggregator.NewCLIBackendDiscoverer(workloadsManager, groupsManager, cfg.OutgoingAuth)

// Discover backends from the configured group
logger.Infof("Discovering backends in group: %s", cfg.GroupRef)
Expand All @@ -230,7 +237,10 @@ func runServe(cmd *cobra.Command, _ []string) error {
logger.Infof("Discovered %d backends", len(backends))

// Create backend client
backendClient := vmcpclient.NewHTTPBackendClient()
backendClient, err := vmcpclient.NewHTTPBackendClient(outgoingRegistry)
if err != nil {
return fmt.Errorf("failed to create backend client: %w", err)
}

// Create conflict resolver based on configuration
// Use the factory method that handles all strategies
Expand Down Expand Up @@ -264,7 +274,7 @@ func runServe(cmd *cobra.Command, _ []string) error {
// Setup authentication middleware
logger.Infof("Setting up incoming authentication (type: %s)", cfg.IncomingAuth.Type)

authMiddleware, authInfoHandler, err := vmcpauth.NewIncomingAuthMiddleware(ctx, cfg.IncomingAuth)
authMiddleware, authInfoHandler, err := factory.NewIncomingAuthMiddleware(ctx, cfg.IncomingAuth)
if err != nil {
return fmt.Errorf("failed to create authentication middleware: %w", err)
}
Expand Down
45 changes: 44 additions & 1 deletion pkg/vmcp/aggregator/cli_discoverer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/stacklok/toolhive/pkg/groups"
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/vmcp"
"github.com/stacklok/toolhive/pkg/vmcp/config"
"github.com/stacklok/toolhive/pkg/workloads"
)

Expand All @@ -16,14 +17,23 @@ import (
type cliBackendDiscoverer struct {
workloadsManager workloads.Manager
groupsManager groups.Manager
authConfig *config.OutgoingAuthConfig
}

// NewCLIBackendDiscoverer creates a new CLI-based backend discoverer.
// It discovers workloads from Docker/Podman containers managed by ToolHive.
func NewCLIBackendDiscoverer(workloadsManager workloads.Manager, groupsManager groups.Manager) BackendDiscoverer {
//
// The authConfig parameter configures authentication for discovered backends.
// If nil, backends will have no authentication configured.
func NewCLIBackendDiscoverer(
workloadsManager workloads.Manager,
groupsManager groups.Manager,
authConfig *config.OutgoingAuthConfig,
) BackendDiscoverer {
return &cliBackendDiscoverer{
workloadsManager: workloadsManager,
groupsManager: groupsManager,
authConfig: authConfig,
}
}

Expand Down Expand Up @@ -92,6 +102,16 @@ func (d *cliBackendDiscoverer) Discover(ctx context.Context, groupRef string) ([
Metadata: make(map[string]string),
}

// Apply authentication configuration if provided
if d.authConfig != nil {
authStrategy, authMetadata := d.resolveAuthConfig(name)
backend.AuthStrategy = authStrategy
backend.AuthMetadata = authMetadata
if authStrategy != "" {
logger.Debugf("Backend %s configured with auth strategy: %s", name, authStrategy)
}
}

// Copy user labels to metadata first
for k, v := range workload.Labels {
backend.Metadata[k] = v
Expand All @@ -116,6 +136,29 @@ func (d *cliBackendDiscoverer) Discover(ctx context.Context, groupRef string) ([
return backends, nil
}

// resolveAuthConfig determines the authentication strategy and metadata for a backend.
// It checks for backend-specific configuration first, then falls back to default.
func (d *cliBackendDiscoverer) resolveAuthConfig(backendID string) (string, map[string]any) {
if d.authConfig == nil {
return "", nil
}

// Check for backend-specific configuration
if strategy, exists := d.authConfig.Backends[backendID]; exists && strategy != nil {
logger.Debugf("Using backend-specific auth strategy for %s: %s", backendID, strategy.Type)
return strategy.Type, strategy.Metadata
}

// Fall back to default configuration
if d.authConfig.Default != nil {
logger.Debugf("Using default auth strategy for %s: %s", backendID, d.authConfig.Default.Type)
return d.authConfig.Default.Type, d.authConfig.Default.Metadata
}

// No authentication configured
return "", nil
}

// mapWorkloadStatusToHealth converts a workload status to a backend health status.
func mapWorkloadStatusToHealth(status rt.WorkloadStatus) vmcp.BackendHealthStatus {
switch status {
Expand Down
18 changes: 9 additions & 9 deletions pkg/vmcp/aggregator/cli_discoverer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) {
mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil)
mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil)

discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups)
discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil)
backends, err := discoverer.Discover(context.Background(), testGroupName)

require.NoError(t, err)
Expand Down Expand Up @@ -79,7 +79,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) {
mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "running-workload").Return(runningWorkload, nil)
mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped-workload").Return(stoppedWorkload, nil)

discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups)
discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil)
backends, err := discoverer.Discover(context.Background(), testGroupName)

require.NoError(t, err)
Expand Down Expand Up @@ -108,7 +108,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) {
mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workloadWithURL, nil)
mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workloadWithoutURL, nil)

discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups)
discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil)
backends, err := discoverer.Discover(context.Background(), testGroupName)

require.NoError(t, err)
Expand All @@ -133,7 +133,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) {
mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil)
mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil)

discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups)
discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil)
backends, err := discoverer.Discover(context.Background(), testGroupName)

require.NoError(t, err)
Expand All @@ -150,7 +150,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) {

mockGroups.EXPECT().Exists(gomock.Any(), "nonexistent-group").Return(false, nil)

discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups)
discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil)
backends, err := discoverer.Discover(context.Background(), "nonexistent-group")

require.Error(t, err)
Expand All @@ -168,7 +168,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) {

mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(false, errors.New("database error"))

discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups)
discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil)
backends, err := discoverer.Discover(context.Background(), testGroupName)

require.Error(t, err)
Expand All @@ -187,7 +187,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) {
mockGroups.EXPECT().Exists(gomock.Any(), "empty-group").Return(true, nil)
mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), "empty-group").Return([]string{}, nil)

discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups)
discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil)
backends, err := discoverer.Discover(context.Background(), "empty-group")

require.NoError(t, err)
Expand All @@ -214,7 +214,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) {
mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped1").Return(stoppedWorkload, nil)
mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "error1").Return(errorWorkload, nil)

discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups)
discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil)
backends, err := discoverer.Discover(context.Background(), testGroupName)

require.NoError(t, err)
Expand All @@ -240,7 +240,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) {
mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "failing-workload").
Return(core.Workload{}, errors.New("workload query failed"))

discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups)
discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil)
backends, err := discoverer.Discover(context.Background(), testGroupName)

require.NoError(t, err)
Expand Down
45 changes: 30 additions & 15 deletions pkg/vmcp/auth/auth.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Package auth provides authentication for Virtual MCP Server.
//
// This package defines:
// - OutgoingAuthenticator: Authenticates vMCP to backend servers
// - OutgoingAuthRegistry: Registry for managing backend authentication strategies
// - Strategy: Pluggable authentication strategies for backends
//
// Incoming authentication uses pkg/auth middleware (OIDC, local, anonymous)
Expand All @@ -17,24 +17,39 @@ import (
"github.com/stacklok/toolhive/pkg/auth"
)

// OutgoingAuthenticator handles authentication to backend MCP servers.
// This is responsible for obtaining and injecting appropriate credentials
// for each backend based on its authentication strategy.
// OutgoingAuthRegistry manages authentication strategies for outgoing requests to backend MCP servers.
// This is a registry that stores and retrieves Strategy implementations.
//
// The specific authentication strategies and their behavior will be defined
// during implementation based on the design decisions documented in the
// Virtual MCP Server proposal.
type OutgoingAuthenticator interface {
// AuthenticateRequest adds authentication to an outgoing backend request.
// The strategy and metadata are provided in the BackendTarget.
AuthenticateRequest(ctx context.Context, req *http.Request, strategy string, metadata map[string]any) error

// GetStrategy returns the authentication strategy handler for a given strategy name.
// This enables extensibility - new strategies can be registered.
// The registry supports dynamic strategy registration, allowing custom authentication
// strategies to be added at runtime. Once registered, strategies can be retrieved
// by name and used to authenticate requests to backends.
//
// Responsibilities:
// - Maintain registry of available strategies
// - Retrieve strategies by name
// - Register new strategies dynamically
//
// This registry does NOT perform authentication itself. Authentication is performed
// by Strategy implementations retrieved from this registry.
//
// Usage Pattern:
// 1. Register strategies during application initialization
// 2. Resolve strategy once at client creation time (cold path)
// 3. Call strategy.Authenticate() directly per-request (hot path)
//
// Thread-safety: Implementations must be safe for concurrent access.
type OutgoingAuthRegistry interface {
// GetStrategy retrieves an authentication strategy by name.
// Returns an error if the strategy is not found.
GetStrategy(name string) (Strategy, error)

// RegisterStrategy registers a new authentication strategy.
// This allows custom auth strategies to be added at runtime.
// The strategy name must match the name returned by strategy.Name().
// Returns an error if:
// - name is empty
// - strategy is nil
// - a strategy with the same name is already registered
// - strategy.Name() does not match the registration name
RegisterStrategy(name string, strategy Strategy) error
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package auth
package factory

import (
"context"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package auth
package factory

import (
"context"
Expand Down
Loading