diff --git a/README.md b/README.md index d9d4c01e..c4bb363a 100644 --- a/README.md +++ b/README.md @@ -479,6 +479,27 @@ In case multi-cluster support is enabled (default) and you have access to multip +### Prompts + +Prompts guide the LLM through multi-step workflows using existing tools. They provide structured instructions for complex operations that benefit from AI reasoning and interpretation. + + + +
+ +core + +- **cluster_health_check** - Guide for performing comprehensive health check on Kubernetes/OpenShift clusters. Provides step-by-step instructions for examining cluster operators, nodes, pods, workloads, storage, and events to identify issues affecting cluster stability. + - `check_events` (`string`) - Include recent warning events in the health check (may increase execution time). Valid values: 'true', 'false', 'yes', 'no', '1', '0'. Default: 'true' + - `output_format` (`string`) - Output format for results: 'text' (human-readable) or 'json' (machine-readable). Valid values: 'text', 'json'. Default: 'text' + - `verbose` (`string`) - Enable detailed output with additional context and resource-level details. Valid values: 'true', 'false', 'yes', 'no', '1', '0'. Default: 'false' + - `namespace` (`string`) - Limit health check to specific namespace (optional, defaults to all namespaces). Valid values: any Kubernetes namespace name or leave empty for all namespaces + +
+ + + + ## 🧑‍💻 Development ### Running with mcp-inspector diff --git a/internal/tools/update-readme/main.go b/internal/tools/update-readme/main.go index 1a9ba276..fab200ce 100644 --- a/internal/tools/update-readme/main.go +++ b/internal/tools/update-readme/main.go @@ -10,8 +10,10 @@ import ( "strings" internalk8s "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" + "github.com/containers/kubernetes-mcp-server/pkg/promptsets" "github.com/containers/kubernetes-mcp-server/pkg/toolsets" + _ "github.com/containers/kubernetes-mcp-server/pkg/promptsets/core" _ "github.com/containers/kubernetes-mcp-server/pkg/toolsets/config" _ "github.com/containers/kubernetes-mcp-server/pkg/toolsets/core" _ "github.com/containers/kubernetes-mcp-server/pkg/toolsets/helm" @@ -90,6 +92,35 @@ func main() { toolsetTools.String(), ) + // Available Promptset Prompts + promptsetsList := promptsets.PromptSets() + promptsetPrompts := strings.Builder{} + for _, promptset := range promptsetsList { + prompts := promptset.GetPrompts(&OpenShift{}) + if len(prompts) == 0 { + continue + } + promptsetPrompts.WriteString("
\n\n" + promptset.GetName() + "\n\n") + for _, prompt := range prompts { + promptsetPrompts.WriteString(fmt.Sprintf("- **%s** - %s\n", prompt.Name, prompt.Description)) + for _, arg := range prompt.Arguments { + promptsetPrompts.WriteString(fmt.Sprintf(" - `%s` (`%s`)", arg.Name, "string")) + if arg.Required { + promptsetPrompts.WriteString(" **(required)**") + } + promptsetPrompts.WriteString(fmt.Sprintf(" - %s\n", arg.Description)) + } + promptsetPrompts.WriteString("\n") + } + promptsetPrompts.WriteString("
\n\n") + } + updated = replaceBetweenMarkers( + updated, + "", + "", + promptsetPrompts.String(), + ) + if err := os.WriteFile(localReadmePath, []byte(updated), 0o644); err != nil { panic(err) } diff --git a/pkg/api/prompts.go b/pkg/api/prompts.go new file mode 100644 index 00000000..5cd5b436 --- /dev/null +++ b/pkg/api/prompts.go @@ -0,0 +1,36 @@ +package api + +import ( + internalk8s "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" +) + +// ServerPrompt represents a prompt that can be provided to the MCP server +type ServerPrompt struct { + Name string + Description string + Arguments []PromptArgument + GetMessages func(arguments map[string]string) []PromptMessage +} + +// PromptArgument defines an argument that can be passed to a prompt +type PromptArgument struct { + Name string + Description string + Required bool +} + +// PromptMessage represents a message in a prompt +type PromptMessage struct { + Role string // "user" or "assistant" + Content string +} + +// PromptSet groups related prompts together +type PromptSet interface { + // GetName returns the name of the prompt set + GetName() string + // GetDescription returns a description of what this prompt set provides + GetDescription() string + // GetPrompts returns all prompts in this set + GetPrompts(o internalk8s.Openshift) []ServerPrompt +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 5601e7f0..f58eb17a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -31,6 +31,7 @@ type StaticConfig struct { // When true, disable tools annotated with destructiveHint=true DisableDestructive bool `toml:"disable_destructive,omitempty"` Toolsets []string `toml:"toolsets,omitempty"` + Promptsets []string `toml:"promptsets,omitempty"` EnabledTools []string `toml:"enabled_tools,omitempty"` DisabledTools []string `toml:"disabled_tools,omitempty"` diff --git a/pkg/mcp/mcp.go b/pkg/mcp/mcp.go index 6a4a6d2f..a0e7dd7b 100644 --- a/pkg/mcp/mcp.go +++ b/pkg/mcp/mcp.go @@ -9,12 +9,14 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" authenticationapiv1 "k8s.io/api/authentication/v1" + "k8s.io/klog/v2" "k8s.io/utils/ptr" "github.com/containers/kubernetes-mcp-server/pkg/api" "github.com/containers/kubernetes-mcp-server/pkg/config" internalk8s "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" "github.com/containers/kubernetes-mcp-server/pkg/output" + "github.com/containers/kubernetes-mcp-server/pkg/promptsets" "github.com/containers/kubernetes-mcp-server/pkg/toolsets" "github.com/containers/kubernetes-mcp-server/pkg/version" ) @@ -27,6 +29,7 @@ type Configuration struct { *config.StaticConfig listOutput output.Output toolsets []api.Toolset + promptsets []api.PromptSet } func (c *Configuration) Toolsets() []api.Toolset { @@ -38,6 +41,23 @@ func (c *Configuration) Toolsets() []api.Toolset { return c.toolsets } +func (c *Configuration) Promptsets() []api.PromptSet { + if c.promptsets == nil { + // Default to core if no promptsets configured + promptsetNames := c.StaticConfig.Promptsets + if len(promptsetNames) == 0 { + promptsetNames = []string{"core"} + } + for _, promptset := range promptsetNames { + ps := promptsets.PromptSetFromString(promptset) + if ps != nil { + c.promptsets = append(c.promptsets, ps) + } + } + } + return c.promptsets +} + func (c *Configuration) ListOutput() output.Output { if c.listOutput == nil { c.listOutput = output.FromString(c.StaticConfig.ListOutput) @@ -77,7 +97,7 @@ func NewServer(configuration Configuration) (*Server, error) { }, &mcp.ServerOptions{ HasResources: false, - HasPrompts: false, + HasPrompts: true, HasTools: true, }), } @@ -165,11 +185,42 @@ func (s *Server) reloadKubernetesClusterProvider() error { s.server.AddTool(goSdkTool, goSdkToolHandler) } + // Register prompts + if err := s.registerPrompts(p); err != nil { + klog.Warningf("Failed to register prompts: %v", err) + // Don't fail the whole reload if prompts fail + } + // start new watch s.p.WatchTargets(s.reloadKubernetesClusterProvider) return nil } +// registerPrompts loads and registers all prompts with the MCP server +func (s *Server) registerPrompts(p internalk8s.Provider) error { + allPrompts := make([]api.ServerPrompt, 0) + for _, ps := range s.configuration.Promptsets() { + prompts := ps.GetPrompts(p) + allPrompts = append(allPrompts, prompts...) + klog.V(5).Infof("Loaded %d prompts from promptset '%s'", len(prompts), ps.GetName()) + } + + goSdkPrompts, goSdkHandlers, err := ServerPromptToGoSdkPrompt(s, allPrompts) + if err != nil { + return fmt.Errorf("failed to convert prompts: %v", err) + } + + // Register each prompt with its handler + for name, prompt := range goSdkPrompts { + handler := goSdkHandlers[name] + s.server.AddPrompt(prompt, handler) + } + + klog.V(3).Infof("Registered %d prompts", len(goSdkPrompts)) + + return nil +} + func (s *Server) ServeStdio(ctx context.Context) error { return s.server.Run(ctx, &mcp.LoggingTransport{Transport: &mcp.StdioTransport{}, Writer: os.Stderr}) } diff --git a/pkg/mcp/mcp_watch_test.go b/pkg/mcp/mcp_watch_test.go index 68287279..a162760d 100644 --- a/pkg/mcp/mcp_watch_test.go +++ b/pkg/mcp/mcp_watch_test.go @@ -54,12 +54,33 @@ func (s *WatchKubeConfigSuite) WaitForNotification() *mcp.JSONRPCNotification { return notification } +// WaitForToolsNotification waits for a tools/list_changed notification specifically +func (s *WatchKubeConfigSuite) WaitForToolsNotification() *mcp.JSONRPCNotification { + withTimeout, cancel := context.WithTimeout(s.T().Context(), 5*time.Second) + defer cancel() + var notification *mcp.JSONRPCNotification + s.OnNotification(func(n mcp.JSONRPCNotification) { + if n.Method == "notifications/tools/list_changed" { + notification = &n + } + }) + for notification == nil { + select { + case <-withTimeout.Done(): + s.FailNow("timeout waiting for tools/list_changed notification") + default: + time.Sleep(100 * time.Millisecond) + } + } + return notification +} + func (s *WatchKubeConfigSuite) TestNotifiesToolsChange() { // Given s.InitMcpClient() // When s.WriteKubeconfig() - notification := s.WaitForNotification() + notification := s.WaitForToolsNotification() // Then s.NotNil(notification, "WatchKubeConfig did not notify") s.Equal("notifications/tools/list_changed", notification.Method, "WatchKubeConfig did not notify tools change") diff --git a/pkg/mcp/modules.go b/pkg/mcp/modules.go index 464eefc8..49632641 100644 --- a/pkg/mcp/modules.go +++ b/pkg/mcp/modules.go @@ -1,5 +1,6 @@ package mcp +import _ "github.com/containers/kubernetes-mcp-server/pkg/promptsets/core" import _ "github.com/containers/kubernetes-mcp-server/pkg/toolsets/config" import _ "github.com/containers/kubernetes-mcp-server/pkg/toolsets/core" import _ "github.com/containers/kubernetes-mcp-server/pkg/toolsets/helm" diff --git a/pkg/mcp/prompts.go b/pkg/mcp/prompts.go new file mode 100644 index 00000000..2f079522 --- /dev/null +++ b/pkg/mcp/prompts.go @@ -0,0 +1,75 @@ +package mcp + +import ( + "context" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/containers/kubernetes-mcp-server/pkg/api" +) + +// ServerPromptToGoSdkPrompt converts our internal ServerPrompt to go-sdk Prompt format +func ServerPromptToGoSdkPrompt(s *Server, prompts []api.ServerPrompt) (map[string]*mcp.Prompt, map[string]mcp.PromptHandler, error) { + goSdkPrompts := make(map[string]*mcp.Prompt) + goSdkHandlers := make(map[string]mcp.PromptHandler) + + for _, prompt := range prompts { + // Convert arguments to PromptArgument pointers + var arguments []*mcp.PromptArgument + for _, arg := range prompt.Arguments { + arguments = append(arguments, &mcp.PromptArgument{ + Name: arg.Name, + Description: arg.Description, + Required: arg.Required, + }) + } + + goSdkPrompt := &mcp.Prompt{ + Name: prompt.Name, + Description: prompt.Description, + Arguments: arguments, + } + + // Create the prompt handler + handler := createPromptHandler(s, prompt) + + goSdkPrompts[prompt.Name] = goSdkPrompt + goSdkHandlers[prompt.Name] = handler + } + + return goSdkPrompts, goSdkHandlers, nil +} + +// createPromptHandler creates a handler function for a prompt +func createPromptHandler(s *Server, prompt api.ServerPrompt) mcp.PromptHandler { + return func(ctx context.Context, request *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + // Get arguments from the request + params := request.GetParams() + arguments := make(map[string]string) + if params != nil { + // Cast to concrete type to access Arguments field + if getPromptParams, ok := params.(*mcp.GetPromptParams); ok && getPromptParams != nil && getPromptParams.Arguments != nil { + arguments = getPromptParams.Arguments + } + } + + // Get messages from the prompt + promptMessages := prompt.GetMessages(arguments) + + // Convert to mcp-go format - need to use pointers + messages := make([]*mcp.PromptMessage, 0, len(promptMessages)) + for _, msg := range promptMessages { + messages = append(messages, &mcp.PromptMessage{ + Role: mcp.Role(msg.Role), + Content: &mcp.TextContent{ + Text: msg.Content, + }, + }) + } + + return &mcp.GetPromptResult{ + Description: prompt.Description, + Messages: messages, + }, nil + } +} diff --git a/pkg/mcp/prompts_test.go b/pkg/mcp/prompts_test.go new file mode 100644 index 00000000..f74d16ab --- /dev/null +++ b/pkg/mcp/prompts_test.go @@ -0,0 +1,202 @@ +package mcp + +import ( + "context" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/containers/kubernetes-mcp-server/pkg/api" +) + +func TestServerPromptToGoSdkPrompt(t *testing.T) { + t.Run("Converts empty prompt list", func(t *testing.T) { + // Given + prompts := []api.ServerPrompt{} + + // When + resultPrompts, resultHandlers, err := ServerPromptToGoSdkPrompt(nil, prompts) + + // Then + require.NoError(t, err) + assert.Empty(t, resultPrompts) + assert.Empty(t, resultHandlers) + }) + + t.Run("Converts single prompt correctly", func(t *testing.T) { + // Given + prompts := []api.ServerPrompt{ + { + Name: "test_prompt", + Description: "Test prompt description", + Arguments: []api.PromptArgument{ + { + Name: "arg1", + Description: "Argument 1", + Required: true, + }, + }, + GetMessages: func(arguments map[string]string) []api.PromptMessage { + return []api.PromptMessage{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "Hi there"}, + } + }, + }, + } + + // When + resultPrompts, resultHandlers, err := ServerPromptToGoSdkPrompt(nil, prompts) + + // Then + require.NoError(t, err) + require.Len(t, resultPrompts, 1) + require.Len(t, resultHandlers, 1) + + prompt := resultPrompts["test_prompt"] + require.NotNil(t, prompt) + assert.Equal(t, "test_prompt", prompt.Name) + assert.Equal(t, "Test prompt description", prompt.Description) + require.Len(t, prompt.Arguments, 1) + + arg := prompt.Arguments[0] + assert.Equal(t, "arg1", arg.Name) + assert.Equal(t, "Argument 1", arg.Description) + assert.True(t, arg.Required) + }) + + t.Run("Converts multiple prompts correctly", func(t *testing.T) { + // Given + prompts := []api.ServerPrompt{ + { + Name: "prompt1", + Description: "First prompt", + Arguments: []api.PromptArgument{}, + GetMessages: func(arguments map[string]string) []api.PromptMessage { + return []api.PromptMessage{{Role: "user", Content: "test1"}} + }, + }, + { + Name: "prompt2", + Description: "Second prompt", + Arguments: []api.PromptArgument{}, + GetMessages: func(arguments map[string]string) []api.PromptMessage { + return []api.PromptMessage{{Role: "user", Content: "test2"}} + }, + }, + } + + // When + resultPrompts, resultHandlers, err := ServerPromptToGoSdkPrompt(nil, prompts) + + // Then + require.NoError(t, err) + assert.Len(t, resultPrompts, 2) + assert.Len(t, resultHandlers, 2) + assert.NotNil(t, resultPrompts["prompt1"]) + assert.NotNil(t, resultPrompts["prompt2"]) + }) +} + +func TestCreatePromptHandler(t *testing.T) { + t.Run("Handler returns correct messages", func(t *testing.T) { + // Given + prompt := api.ServerPrompt{ + Name: "test", + Description: "Test prompt", + Arguments: []api.PromptArgument{}, + GetMessages: func(arguments map[string]string) []api.PromptMessage { + return []api.PromptMessage{ + {Role: "user", Content: "Test message"}, + {Role: "assistant", Content: "Test response"}, + } + }, + } + + handler := createPromptHandler(nil, prompt) + + // Create request with empty arguments + request := &mcp.GetPromptRequest{ + Params: &mcp.GetPromptParams{ + Name: "test", + Arguments: map[string]string{}, + }, + } + + // When + result, err := handler(context.Background(), request) + + // Then + require.NoError(t, err) + assert.Equal(t, "Test prompt", result.Description) + require.Len(t, result.Messages, 2) + assert.Equal(t, mcp.Role("user"), result.Messages[0].Role) + textContent := result.Messages[0].Content.(*mcp.TextContent) + assert.Equal(t, "Test message", textContent.Text) + assert.Equal(t, mcp.Role("assistant"), result.Messages[1].Role) + textContent2 := result.Messages[1].Content.(*mcp.TextContent) + assert.Equal(t, "Test response", textContent2.Text) + }) + + t.Run("Handler uses provided arguments", func(t *testing.T) { + // Given + prompt := api.ServerPrompt{ + Name: "test", + Description: "Test prompt", + Arguments: []api.PromptArgument{ + {Name: "param1", Description: "Parameter 1", Required: false}, + }, + GetMessages: func(arguments map[string]string) []api.PromptMessage { + value := arguments["param1"] + return []api.PromptMessage{ + {Role: "user", Content: "Value is: " + value}, + } + }, + } + + handler := createPromptHandler(nil, prompt) + + // Create request with arguments + request := &mcp.GetPromptRequest{ + Params: &mcp.GetPromptParams{ + Name: "test", + Arguments: map[string]string{"param1": "test_value"}, + }, + } + + // When + result, err := handler(context.Background(), request) + + // Then + require.NoError(t, err) + require.Len(t, result.Messages, 1) + textContent := result.Messages[0].Content.(*mcp.TextContent) + assert.Equal(t, "Value is: test_value", textContent.Text) + }) + + t.Run("Handler handles nil arguments", func(t *testing.T) { + // Given + prompt := api.ServerPrompt{ + Name: "test", + Description: "Test prompt", + Arguments: []api.PromptArgument{}, + GetMessages: func(arguments map[string]string) []api.PromptMessage { + return []api.PromptMessage{{Role: "user", Content: "test"}} + }, + } + + handler := createPromptHandler(nil, prompt) + + // Create request with no params + request := &mcp.GetPromptRequest{} + + // When + result, err := handler(context.Background(), request) + + // Then + require.NoError(t, err) + require.Len(t, result.Messages, 1) + }) +} diff --git a/pkg/promptsets/core/health_check.go b/pkg/promptsets/core/health_check.go new file mode 100644 index 00000000..aea7f5e0 --- /dev/null +++ b/pkg/promptsets/core/health_check.go @@ -0,0 +1,328 @@ +package core + +import ( + "fmt" + "strings" + + "github.com/containers/kubernetes-mcp-server/pkg/api" +) + +const ( + // Health check configuration constants + defaultRestartThreshold = 5 + eventLookbackMinutes = 30 + maxWarningEvents = 20 +) + +// isVerboseEnabled checks if the verbose flag is enabled. +// It accepts "true", "1", "yes", or "y" (case-insensitive) as truthy values. +func isVerboseEnabled(value string) bool { + switch strings.ToLower(value) { + case "true", "1", "yes", "y": + return true + default: + return false + } +} + +// isBooleanEnabled checks if a boolean flag is enabled. +// It accepts "true", "1", "yes", or "y" (case-insensitive) as truthy values. +// If the value is empty and a default is provided, it returns the default value. +func isBooleanEnabled(value string, defaultValue bool) bool { + if value == "" { + return defaultValue + } + switch strings.ToLower(value) { + case "true", "1", "yes", "y": + return true + case "false", "0", "no", "n": + return false + default: + return defaultValue + } +} + +// getEmojiInstructions returns emoji usage instructions based on output format. +// Emojis are only recommended for text format output. +func getEmojiInstructions(outputFormat string) string { + if outputFormat == "json" { + return "" + } + return "\n- Use emojis for visual clarity: ✅ (healthy), ⚠️ (warning), ❌ (critical)" +} + +// initHealthCheckPrompts creates prompts for cluster health diagnostics. +// These prompts guide LLMs to systematically check cluster components using existing tools. +func initHealthCheckPrompts() []api.ServerPrompt { + return []api.ServerPrompt{ + { + Name: "cluster_health_check", + Description: "Guide for performing comprehensive health check on Kubernetes/OpenShift clusters. Provides step-by-step instructions for examining cluster operators, nodes, pods, workloads, storage, and events to identify issues affecting cluster stability.", + Arguments: []api.PromptArgument{ + { + Name: "check_events", + Description: "Include recent warning events in the health check (may increase execution time). Valid values: 'true', 'false', 'yes', 'no', '1', '0'. Default: 'true'", + Required: false, + }, + { + Name: "output_format", + Description: "Output format for results: 'text' (human-readable) or 'json' (machine-readable). Valid values: 'text', 'json'. Default: 'text'", + Required: false, + }, + { + Name: "verbose", + Description: "Enable detailed output with additional context and resource-level details. Valid values: 'true', 'false', 'yes', 'no', '1', '0'. Default: 'false'", + Required: false, + }, + { + Name: "namespace", + Description: "Limit health check to specific namespace (optional, defaults to all namespaces). Valid values: any Kubernetes namespace name or leave empty for all namespaces", + Required: false, + }, + }, + GetMessages: func(arguments map[string]string) []api.PromptMessage { + verbose := isVerboseEnabled(arguments["verbose"]) + namespace := arguments["namespace"] + checkEvents := isBooleanEnabled(arguments["check_events"], true) // Default to true + outputFormat := arguments["output_format"] + if outputFormat == "" { + outputFormat = "text" // Default to text + } + + return buildHealthCheckPromptMessages(verbose, namespace, checkEvents, outputFormat) + }, + }, + } +} + +// buildHealthCheckPromptMessages constructs the prompt messages for cluster health checks. +// It adapts the instructions based on verbose mode, namespace filtering, event checking, and output format. +func buildHealthCheckPromptMessages(verbose bool, namespace string, checkEvents bool, outputFormat string) []api.PromptMessage { + scopeMsg := "across all namespaces" + podListInstruction := "- Use pods_list to get all pods" + + if namespace != "" { + scopeMsg = fmt.Sprintf("in namespace '%s'", namespace) + podListInstruction = fmt.Sprintf("- Use pods_list_in_namespace with namespace '%s'", namespace) + } + + verboseMsg := "" + if verbose { + verboseMsg = "\n\nFor verbose mode, include additional details such as:\n" + + "- Specific error messages from conditions\n" + + "- Resource-level details (CPU/memory pressure types)\n" + + "- Individual pod and deployment names\n" + + "- Event messages and timestamps" + } + + // Construct the event display range dynamically using maxWarningEvents + eventDisplayRange := fmt.Sprintf("10-%d", maxWarningEvents) + + // Build events section conditionally + eventsCheckSection := "" + eventsOutputSection := "" + eventsToolMention := "" + if checkEvents { + eventsCheckSection = fmt.Sprintf(` + +## 6. Check Recent Events +- Use events_list to get cluster events +- Filter for: + * Type = Warning + * Timestamp within last %d minutes +- Limit to %s most recent warnings +- Include event message and involved object`, eventLookbackMinutes, eventDisplayRange) + + eventsOutputSection = fmt.Sprintf(` + +### Recent Events +[Warning events from last %d minutes]`, eventLookbackMinutes) + + eventsToolMention = ", events_list, etc." + } + + // Build output format instructions + outputFormatInstructions := "" + if outputFormat == "json" { + outputFormatInstructions = ` + +## Output Format + +Structure your health check report as a JSON object with the following schema: + +` + "```json" + ` +{ + "cluster_type": "Kubernetes|OpenShift", + "cluster_version": "version string if determinable", + "check_time": "ISO 8601 timestamp", + "scope": "all namespaces|namespace: ", + "cluster_operators": { + "total": 0, + "degraded": 0, + "unavailable": 0, + "progressing": 0, + "issues": [] + }, + "node_health": { + "total": 0, + "not_ready": 0, + "unschedulable": 0, + "under_pressure": 0, + "issues": [] + }, + "pod_health": { + "total": 0, + "failed": 0, + "crash_looping": 0, + "image_pull_errors": 0, + "high_restarts": 0, + "issues": [] + }, + "workload_controllers": { + "deployments": {"total": 0, "unhealthy": 0}, + "statefulsets": {"total": 0, "unhealthy": 0}, + "daemonsets": {"total": 0, "unhealthy": 0}, + "issues": [] + }, + "storage": { + "total": 0, + "bound": 0, + "unbound": 0, + "issues": [] + },` + eventsOutputSection + ` + "summary": { + "critical_issues": 0, + "warnings": 0, + "overall_status": "healthy|has_warnings|has_critical_issues" + } +} +` + "```" + } else { + // Text format (default) + outputFormatInstructions = fmt.Sprintf(` + +## Output Format + +Structure your health check report as follows: + +`+"```"+` +================================================ +Cluster Health Check Report +================================================ +Cluster Type: [Kubernetes/OpenShift] +Cluster Version: [if determinable] +Check Time: [current timestamp] +Scope: [all namespaces / specific namespace] + +### Cluster Operators (OpenShift only) +[Status with counts and specific issues] + +### Node Health +[Status with counts: total, not ready, unschedulable, under pressure] + +### Pod Health +[Status with counts: total, failed, crash looping, image pull errors, high restarts] + +### Workload Controllers +[Status for Deployments, StatefulSets, DaemonSets] + +### Storage +[PVC status: total, bound, pending/other]%s + +================================================ +Summary +================================================ +Critical Issues: [count] +Warnings: [count] + +[Overall assessment: healthy / has warnings / has critical issues] +`+"```", eventsOutputSection) + } + + userMessage := fmt.Sprintf(`Please perform a comprehensive health check on the Kubernetes cluster %s. + +Follow these steps systematically: + +## 1. Check Cluster-Level Components + +### For OpenShift Clusters: +- Use resources_list with apiVersion=config.openshift.io/v1 and kind=ClusterOperator to check cluster operator health +- Look for operators with: + * Degraded=True (CRITICAL) + * Available=False (CRITICAL) + * Progressing=True (WARNING) + +### For All Kubernetes Clusters: +- Verify if this is an OpenShift cluster by checking for OpenShift-specific resources +- Note the cluster type in your report + +## 2. Check Node Health +- Use resources_list with apiVersion=v1 and kind=Node to examine all nodes +- Check each node for: + * Ready condition != True (CRITICAL) + * Unschedulable spec field = true (WARNING) + * MemoryPressure, DiskPressure, or PIDPressure conditions = True (WARNING) +- Count total nodes and categorize issues + +## 3. Check Pod Health +%s +- Identify problematic pods: + * Phase = Failed or Pending (CRITICAL) + * Container state waiting with reason: + - CrashLoopBackOff (CRITICAL) + - ImagePullBackOff or ErrImagePull (CRITICAL) + * RestartCount > %d (WARNING - configurable threshold) +- Group issues by type and count occurrences%s + +## 4. Check Workload Controllers +- Use resources_list for each workload type: + * apiVersion=apps/v1, kind=Deployment + * apiVersion=apps/v1, kind=StatefulSet + * apiVersion=apps/v1, kind=DaemonSet +- For each controller, compare: + * spec.replicas vs status.readyReplicas (Deployment/StatefulSet) + * status.desiredNumberScheduled vs status.numberReady (DaemonSet) + * Report mismatches as WARNINGs + +## 5. Check Storage +- Use resources_list with apiVersion=v1 and kind=PersistentVolumeClaim +- Identify PVCs not in Bound phase (WARNING) +- Note namespace and PVC name for each issue%s%s + +## Health Status Definitions + +- **CRITICAL**: Issues requiring immediate attention (e.g., pods failing, nodes not ready, degraded operators) +- **WARNING**: Issues that should be monitored (e.g., high restarts, progressing operators, resource pressure) +- **HEALTHY**: No issues detected + +## Important Notes + +- Use the existing tools (resources_list, pods_list%s) +- Be efficient: don't call the same tool multiple times unnecessarily +- If a resource type doesn't exist (e.g., ClusterOperator on vanilla K8s), skip it gracefully +- Provide clear, actionable insights in your summary%s + +### Common apiVersion Values + +When using resources_list, specify the correct apiVersion for each resource type: +- Core resources: apiVersion=v1 (Pod, Service, Node, PersistentVolumeClaim, ConfigMap, Secret, Namespace) +- Apps: apiVersion=apps/v1 (Deployment, StatefulSet, DaemonSet, ReplicaSet) +- Batch: apiVersion=batch/v1 (Job, CronJob) +- RBAC: apiVersion=rbac.authorization.k8s.io/v1 (Role, RoleBinding, ClusterRole, ClusterRoleBinding) +- Networking: apiVersion=networking.k8s.io/v1 (Ingress, NetworkPolicy) +- OpenShift Config: apiVersion=config.openshift.io/v1 (ClusterOperator, ClusterVersion) +- OpenShift Routes: apiVersion=route.openshift.io/v1 (Route)`, scopeMsg, podListInstruction, defaultRestartThreshold, verboseMsg, eventsCheckSection, outputFormatInstructions, eventsToolMention, getEmojiInstructions(outputFormat)) + + assistantMessage := `I'll perform a comprehensive cluster health check following the systematic approach outlined. Let me start by gathering information about the cluster components.` + + return []api.PromptMessage{ + { + Role: "user", + Content: userMessage, + }, + { + Role: "assistant", + Content: assistantMessage, + }, + } +} diff --git a/pkg/promptsets/core/health_check_test.go b/pkg/promptsets/core/health_check_test.go new file mode 100644 index 00000000..305666fa --- /dev/null +++ b/pkg/promptsets/core/health_check_test.go @@ -0,0 +1,479 @@ +package core + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIsVerboseEnabled(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"true lowercase", "true", true}, + {"true capitalized", "True", true}, + {"true uppercase", "TRUE", true}, + {"numeric 1", "1", true}, + {"yes lowercase", "yes", true}, + {"yes capitalized", "Yes", true}, + {"yes uppercase", "YES", true}, + {"y lowercase", "y", true}, + {"y uppercase", "Y", true}, + {"false", "false", false}, + {"0", "0", false}, + {"no", "no", false}, + {"empty string", "", false}, + {"random string", "random", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isVerboseEnabled(tt.input) + assert.Equal(t, tt.expected, result, "isVerboseEnabled(%q) should return %v", tt.input, tt.expected) + }) + } +} + +func TestInitHealthCheckPrompts(t *testing.T) { + // When + prompts := initHealthCheckPrompts() + + // Then + require.Len(t, prompts, 1) + assert.Equal(t, "cluster_health_check", prompts[0].Name) + assert.Contains(t, prompts[0].Description, "comprehensive health check") + assert.Len(t, prompts[0].Arguments, 4) + + // Check arguments + assert.Equal(t, "check_events", prompts[0].Arguments[0].Name) + assert.False(t, prompts[0].Arguments[0].Required) + + assert.Equal(t, "output_format", prompts[0].Arguments[1].Name) + assert.False(t, prompts[0].Arguments[1].Required) + + assert.Equal(t, "verbose", prompts[0].Arguments[2].Name) + assert.False(t, prompts[0].Arguments[2].Required) + + assert.Equal(t, "namespace", prompts[0].Arguments[3].Name) + assert.False(t, prompts[0].Arguments[3].Required) +} + +func TestBuildHealthCheckPromptMessages(t *testing.T) { + t.Run("Default messages with no arguments", func(t *testing.T) { + // When - checkEvents=true (default), outputFormat="text" (default) + messages := buildHealthCheckPromptMessages(false, "", true, "text") + + // Then + require.Len(t, messages, 2) + assert.Equal(t, "user", messages[0].Role) + assert.Equal(t, "assistant", messages[1].Role) + + // Check user message content + userContent := messages[0].Content + assert.Contains(t, userContent, "across all namespaces") + assert.Contains(t, userContent, "Use pods_list to get all pods") + assert.Contains(t, userContent, "resources_list") + assert.Contains(t, userContent, "Check Recent Events") + assert.NotContains(t, userContent, "pods_list_in_namespace") + + // Check assistant message + assert.Contains(t, messages[1].Content, "comprehensive cluster health check") + }) + + t.Run("Messages with namespace filter", func(t *testing.T) { + // When + messages := buildHealthCheckPromptMessages(false, "test-namespace", true, "text") + + // Then + require.Len(t, messages, 2) + + userContent := messages[0].Content + assert.Contains(t, userContent, "in namespace 'test-namespace'") + assert.NotContains(t, userContent, "across all namespaces") + assert.Contains(t, userContent, "Use pods_list_in_namespace with namespace 'test-namespace'") + assert.NotContains(t, userContent, "Use pods_list to get all pods") + }) + + t.Run("Messages with verbose mode", func(t *testing.T) { + // When + messages := buildHealthCheckPromptMessages(true, "", true, "text") + + // Then + require.Len(t, messages, 2) + + userContent := messages[0].Content + assert.Contains(t, userContent, "For verbose mode") + assert.Contains(t, userContent, "Specific error messages") + assert.Contains(t, userContent, "Resource-level details") + assert.Contains(t, userContent, "Individual pod and deployment names") + }) + + t.Run("Messages with both verbose and namespace", func(t *testing.T) { + // When + messages := buildHealthCheckPromptMessages(true, "prod", true, "text") + + // Then + require.Len(t, messages, 2) + + userContent := messages[0].Content + assert.Contains(t, userContent, "in namespace 'prod'") + assert.Contains(t, userContent, "For verbose mode") + }) + + t.Run("Messages with checkEvents disabled", func(t *testing.T) { + // When + messages := buildHealthCheckPromptMessages(false, "", false, "text") + + // Then + require.Len(t, messages, 2) + + userContent := messages[0].Content + assert.NotContains(t, userContent, "Check Recent Events") + assert.NotContains(t, userContent, "events_list") + }) + + t.Run("Messages with JSON output format", func(t *testing.T) { + // When + messages := buildHealthCheckPromptMessages(false, "", true, "json") + + // Then + require.Len(t, messages, 2) + + userContent := messages[0].Content + assert.Contains(t, userContent, "JSON object") + assert.Contains(t, userContent, "cluster_type") + assert.Contains(t, userContent, "node_health") + assert.Contains(t, userContent, "pod_health") + assert.NotContains(t, userContent, "✅") + assert.NotContains(t, userContent, "⚠️") + }) + + t.Run("User message contains all required sections", func(t *testing.T) { + // When - with checkEvents enabled + messages := buildHealthCheckPromptMessages(false, "", true, "text") + + // Then + userContent := messages[0].Content + + // Check for all main sections + sections := []string{ + "## 1. Check Cluster-Level Components", + "## 2. Check Node Health", + "## 3. Check Pod Health", + "## 4. Check Workload Controllers", + "## 5. Check Storage", + "## 6. Check Recent Events", + "## Output Format", + "## Health Status Definitions", + "## Important Notes", + } + + for _, section := range sections { + assert.Contains(t, userContent, section, "Missing section: %s", section) + } + }) + + t.Run("User message contains critical tool references", func(t *testing.T) { + // When + messages := buildHealthCheckPromptMessages(false, "", true, "text") + + // Then + userContent := messages[0].Content + + // Check for tool names + tools := []string{ + "resources_list", + "pods_list", + } + + for _, tool := range tools { + assert.Contains(t, userContent, tool, "Missing tool reference: %s", tool) + } + }) + + t.Run("User message contains health check criteria", func(t *testing.T) { + // When + messages := buildHealthCheckPromptMessages(false, "", true, "text") + + // Then + userContent := messages[0].Content + + // Check for critical conditions + criteria := []string{ + "Degraded=True (CRITICAL)", + "Available=False (CRITICAL)", + "Ready condition != True (CRITICAL)", + "CrashLoopBackOff (CRITICAL)", + "ImagePullBackOff", + "RestartCount > 5 (WARNING", + "MemoryPressure", + "DiskPressure", + } + + for _, criterion := range criteria { + assert.Contains(t, userContent, criterion, "Missing criterion: %s", criterion) + } + }) + + t.Run("User message contains workload types with apiVersions", func(t *testing.T) { + // When + messages := buildHealthCheckPromptMessages(false, "", true, "text") + + // Then + userContent := messages[0].Content + + // Check for apiVersion + kind pairs + resourceSpecs := []string{ + "apiVersion=apps/v1, kind=Deployment", + "apiVersion=apps/v1, kind=StatefulSet", + "apiVersion=apps/v1, kind=DaemonSet", + "apiVersion=config.openshift.io/v1 and kind=ClusterOperator", + "apiVersion=v1 and kind=Node", + "apiVersion=v1 and kind=PersistentVolumeClaim", + } + + for _, spec := range resourceSpecs { + assert.Contains(t, userContent, spec, "Missing resource spec: %s", spec) + } + }) + + t.Run("User message contains output format template", func(t *testing.T) { + // When + messages := buildHealthCheckPromptMessages(false, "", true, "text") + + // Then + userContent := messages[0].Content + + // Check for report structure + reportElements := []string{ + "Cluster Health Check Report", + "Cluster Type:", + "### Cluster Operators", + "### Node Health", + "### Pod Health", + "### Workload Controllers", + "### Storage", + "### Recent Events", + "Summary", + "Critical Issues:", + "Warnings:", + } + + for _, element := range reportElements { + assert.Contains(t, userContent, element, "Missing report element: %s", element) + } + }) + + t.Run("User message does not reference non-existent tools", func(t *testing.T) { + // When + messages := buildHealthCheckPromptMessages(false, "", true, "text") + + // Then + userContent := messages[0].Content + + // Make sure we're not referencing the old tool name + assert.NotContains(t, userContent, "pods_list_in_all_namespaces") + }) +} + +func TestGetMessagesWithArguments(t *testing.T) { + // Given + prompts := initHealthCheckPrompts() + require.Len(t, prompts, 1) + + getMessages := prompts[0].GetMessages + + t.Run("With no arguments", func(t *testing.T) { + // When + messages := getMessages(map[string]string{}) + + // Then + require.Len(t, messages, 2) + userContent := messages[0].Content + assert.Contains(t, userContent, "across all namespaces") + assert.NotContains(t, userContent, "For verbose mode") + // Default is checkEvents=true + assert.Contains(t, userContent, "Check Recent Events") + }) + + t.Run("With verbose=true", func(t *testing.T) { + // When + messages := getMessages(map[string]string{"verbose": "true"}) + + // Then + require.Len(t, messages, 2) + userContent := messages[0].Content + assert.Contains(t, userContent, "For verbose mode") + }) + + t.Run("With verbose=false", func(t *testing.T) { + // When + messages := getMessages(map[string]string{"verbose": "false"}) + + // Then + require.Len(t, messages, 2) + userContent := messages[0].Content + assert.NotContains(t, userContent, "For verbose mode") + }) + + t.Run("With namespace", func(t *testing.T) { + // When + messages := getMessages(map[string]string{"namespace": "kube-system"}) + + // Then + require.Len(t, messages, 2) + userContent := messages[0].Content + assert.Contains(t, userContent, "in namespace 'kube-system'") + }) + + t.Run("With all arguments", func(t *testing.T) { + // When + messages := getMessages(map[string]string{ + "verbose": "true", + "namespace": "default", + "check_events": "false", + "output_format": "json", + }) + + // Then + require.Len(t, messages, 2) + userContent := messages[0].Content + assert.Contains(t, userContent, "For verbose mode") + assert.Contains(t, userContent, "in namespace 'default'") + assert.NotContains(t, userContent, "Check Recent Events") + assert.Contains(t, userContent, "JSON object") + }) + + t.Run("With check_events=false", func(t *testing.T) { + // When + messages := getMessages(map[string]string{"check_events": "false"}) + + // Then + require.Len(t, messages, 2) + userContent := messages[0].Content + assert.NotContains(t, userContent, "Check Recent Events") + }) + + t.Run("With output_format=json", func(t *testing.T) { + // When + messages := getMessages(map[string]string{"output_format": "json"}) + + // Then + require.Len(t, messages, 2) + userContent := messages[0].Content + assert.Contains(t, userContent, "JSON object") + assert.Contains(t, userContent, "cluster_type") + }) +} + +func TestHealthCheckPromptCompleteness(t *testing.T) { + // This test ensures the prompt covers all essential aspects + + messages := buildHealthCheckPromptMessages(false, "", true, "text") + userContent := messages[0].Content + + t.Run("Covers all Kubernetes resource types", func(t *testing.T) { + resourceTypes := []string{ + "Node", + "Pod", + "Deployment", + "StatefulSet", + "DaemonSet", + "PersistentVolumeClaim", + "ClusterOperator", // OpenShift specific + } + + for _, rt := range resourceTypes { + assert.Contains(t, userContent, rt, "Missing resource type: %s", rt) + } + }) + + t.Run("Provides clear severity levels", func(t *testing.T) { + assert.Contains(t, userContent, "CRITICAL") + assert.Contains(t, userContent, "WARNING") + assert.Contains(t, userContent, "HEALTHY") + }) + + t.Run("Includes efficiency guidelines", func(t *testing.T) { + assert.Contains(t, userContent, "Be efficient") + assert.Contains(t, userContent, "don't call the same tool multiple times unnecessarily") + }) + + t.Run("Handles OpenShift gracefully", func(t *testing.T) { + assert.Contains(t, userContent, "For OpenShift Clusters") + assert.Contains(t, userContent, "For All Kubernetes Clusters") + assert.Contains(t, userContent, "skip it gracefully") + }) + + t.Run("Instructions are clear and actionable", func(t *testing.T) { + // Check that the prompt uses imperative language + imperativeVerbs := []string{"Use", "Check", "Look for", "Verify", "Identify", "Compare"} + foundVerbs := 0 + for _, verb := range imperativeVerbs { + if strings.Contains(userContent, verb) { + foundVerbs++ + } + } + assert.Greater(t, foundVerbs, 3, "Prompt should use clear imperative language") + }) + + t.Run("Includes apiVersion reference section", func(t *testing.T) { + assert.Contains(t, userContent, "Common apiVersion Values") + assert.Contains(t, userContent, "apiVersion=config.openshift.io/v1") + assert.Contains(t, userContent, "apiVersion=apps/v1") + assert.Contains(t, userContent, "apiVersion=v1") + assert.Contains(t, userContent, "ClusterOperator, ClusterVersion") + }) +} + +func TestIsBooleanEnabled(t *testing.T) { + tests := []struct { + name string + input string + defaultValue bool + expected bool + }{ + {"empty with default true", "", true, true}, + {"empty with default false", "", false, false}, + {"true lowercase", "true", false, true}, + {"true uppercase", "TRUE", false, true}, + {"false lowercase", "false", true, false}, + {"false uppercase", "FALSE", true, false}, + {"yes", "yes", false, true}, + {"no", "no", true, false}, + {"1", "1", false, true}, + {"0", "0", true, false}, + {"invalid with default true", "invalid", true, true}, + {"invalid with default false", "invalid", false, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isBooleanEnabled(tt.input, tt.defaultValue) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestGetEmojiInstructions(t *testing.T) { + t.Run("Returns emoji instructions for text format", func(t *testing.T) { + result := getEmojiInstructions("text") + assert.Contains(t, result, "✅") + assert.Contains(t, result, "⚠️") + assert.Contains(t, result, "❌") + }) + + t.Run("Returns empty for json format", func(t *testing.T) { + result := getEmojiInstructions("json") + assert.Empty(t, result) + }) + + t.Run("Returns emoji instructions for other formats", func(t *testing.T) { + result := getEmojiInstructions("yaml") + assert.Contains(t, result, "✅") + }) +} diff --git a/pkg/promptsets/core/promptset.go b/pkg/promptsets/core/promptset.go new file mode 100644 index 00000000..8bd35785 --- /dev/null +++ b/pkg/promptsets/core/promptset.go @@ -0,0 +1,39 @@ +package core + +import ( + "github.com/containers/kubernetes-mcp-server/pkg/api" + internalk8s "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" + "github.com/containers/kubernetes-mcp-server/pkg/promptsets" +) + +const ( + Name = "core" + Description = "Core prompts for common Kubernetes/OpenShift operations including cluster health diagnostics" +) + +type PromptSet struct{} + +func (t *PromptSet) GetName() string { + return Name +} + +func (t *PromptSet) GetDescription() string { + return Description +} + +func (t *PromptSet) GetPrompts(o internalk8s.Openshift) []api.ServerPrompt { + prompts := make([]api.ServerPrompt, 0) + + // Health check prompts + prompts = append(prompts, initHealthCheckPrompts()...) + + // Future: Add more prompts here + // prompts = append(prompts, initTroubleshootingPrompts(o)...) + // prompts = append(prompts, initDeploymentPrompts(o)...) + + return prompts +} + +func init() { + promptsets.Register(&PromptSet{}) +} diff --git a/pkg/promptsets/promptsets.go b/pkg/promptsets/promptsets.go new file mode 100644 index 00000000..e140aa0d --- /dev/null +++ b/pkg/promptsets/promptsets.go @@ -0,0 +1,50 @@ +package promptsets + +import ( + "slices" + "strings" + + "github.com/containers/kubernetes-mcp-server/pkg/api" +) + +var promptsets []api.PromptSet + +// Clear removes all registered promptsets, TESTING PURPOSES ONLY. +func Clear() { + promptsets = []api.PromptSet{} +} + +// Register adds a promptset to the registry +func Register(promptset api.PromptSet) { + promptsets = append(promptsets, promptset) +} + +// PromptSets returns all registered promptsets +func PromptSets() []api.PromptSet { + return promptsets +} + +// PromptSetFromString returns a PromptSet by name, or nil if not found +func PromptSetFromString(name string) api.PromptSet { + for _, ps := range PromptSets() { + if ps.GetName() == strings.TrimSpace(name) { + return ps + } + } + return nil +} + +// AllPromptSets returns all available promptsets +func AllPromptSets() []api.PromptSet { + return PromptSets() +} + +// GetPromptSetNames returns names of all registered promptsets +func GetPromptSetNames() []string { + names := make([]string, 0, len(promptsets)) + for _, ps := range promptsets { + names = append(names, ps.GetName()) + } + slices.Sort(names) + return names +} diff --git a/pkg/promptsets/promptsets_test.go b/pkg/promptsets/promptsets_test.go new file mode 100644 index 00000000..31764361 --- /dev/null +++ b/pkg/promptsets/promptsets_test.go @@ -0,0 +1,138 @@ +package promptsets + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + + "github.com/containers/kubernetes-mcp-server/pkg/api" + internalk8s "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" +) + +type PromptSetsSuite struct { + suite.Suite +} + +func (s *PromptSetsSuite) SetupTest() { + // Clear the registry before each test + Clear() +} + +func (s *PromptSetsSuite) TestRegister() { + // Given + testPS := &testPromptSet{name: "test"} + + // When + Register(testPS) + + // Then + assert.Equal(s.T(), 1, len(PromptSets())) + assert.Equal(s.T(), testPS, PromptSets()[0]) +} + +func (s *PromptSetsSuite) TestPromptSetFromString() { + s.Run("Returns nil if promptset not found", func() { + // When + ps := PromptSetFromString("nonexistent") + + // Then + assert.Nil(s.T(), ps) + }) + + s.Run("Returns the correct promptset if found", func() { + // Given + testPS := &testPromptSet{name: "test"} + Register(testPS) + + // When + ps := PromptSetFromString("test") + + // Then + assert.Equal(s.T(), testPS, ps) + assert.Equal(s.T(), "test", ps.GetName()) + }) + + s.Run("Returns the correct promptset if found after trimming spaces", func() { + // Given + testPS := &testPromptSet{name: "test"} + Register(testPS) + + // When + ps := PromptSetFromString(" test ") + + // Then + assert.Equal(s.T(), testPS, ps) + }) +} + +func (s *PromptSetsSuite) TestAllPromptSets() { + // Given + testPS1 := &testPromptSet{name: "test1"} + testPS2 := &testPromptSet{name: "test2"} + Register(testPS1) + Register(testPS2) + + // When + all := AllPromptSets() + + // Then + assert.Equal(s.T(), 2, len(all)) + assert.Contains(s.T(), all, testPS1) + assert.Contains(s.T(), all, testPS2) +} + +func (s *PromptSetsSuite) TestGetPromptSetNames() { + s.Run("Returns empty slice when no promptsets registered", func() { + // When + names := GetPromptSetNames() + + // Then + assert.Empty(s.T(), names) + }) + + s.Run("Returns sorted names of all registered promptsets", func() { + // Given + Register(&testPromptSet{name: "zebra"}) + Register(&testPromptSet{name: "alpha"}) + Register(&testPromptSet{name: "beta"}) + + // When + names := GetPromptSetNames() + + // Then + assert.Equal(s.T(), []string{"alpha", "beta", "zebra"}, names) + }) +} + +func TestPromptSets(t *testing.T) { + suite.Run(t, new(PromptSetsSuite)) +} + +// Test helper +type testPromptSet struct { + name string +} + +func (t *testPromptSet) GetName() string { + return t.name +} + +func (t *testPromptSet) GetDescription() string { + return "Test promptset" +} + +func (t *testPromptSet) GetPrompts(o internalk8s.Openshift) []api.ServerPrompt { + return []api.ServerPrompt{ + { + Name: "test_prompt", + Description: "Test prompt", + Arguments: []api.PromptArgument{}, + GetMessages: func(arguments map[string]string) []api.PromptMessage { + return []api.PromptMessage{ + {Role: "user", Content: "test"}, + } + }, + }, + } +}