Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ var (
streamFlag bool // Enable streaming output
compactMode bool // Enable compact output mode
scriptMCPConfig *config.Config // Used to override config in script mode
approveToolRun bool

// Session management
saveSessionPath string
Expand Down Expand Up @@ -199,7 +200,6 @@ func InitConfig() {
viper.Set("hooks", hooksConfig)
}
}

}

// LoadConfigWithEnvSubstitution loads a config file with environment variable substitution
Expand Down Expand Up @@ -284,6 +284,8 @@ func init() {
BoolVar(&compactMode, "compact", false, "enable compact output mode without fancy styling")
rootCmd.PersistentFlags().
BoolVar(&noHooks, "no-hooks", false, "disable all hooks execution")
rootCmd.PersistentFlags().
BoolVar(&approveToolRun, "approve-tool-run", false, "enable requiring user approval for every tool call")

// Session management flags
rootCmd.PersistentFlags().
Expand Down Expand Up @@ -329,6 +331,7 @@ func init() {
viper.BindPFlag("num-gpu-layers", rootCmd.PersistentFlags().Lookup("num-gpu-layers"))
viper.BindPFlag("main-gpu", rootCmd.PersistentFlags().Lookup("main-gpu"))
viper.BindPFlag("tls-skip-verify", rootCmd.PersistentFlags().Lookup("tls-skip-verify"))
viper.BindPFlag("approve-tool-run", rootCmd.PersistentFlags().Lookup("approve-tool-run"))

// Defaults are already set in flag definitions, no need to duplicate in viper

Expand Down Expand Up @@ -427,7 +430,8 @@ func runNormalMode(ctx context.Context) error {
debugLogger = bufferedLogger
}

mcpAgent, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{ModelConfig: modelConfig,
mcpAgent, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{
ModelConfig: modelConfig,
MCPConfig: mcpConfig,
SystemPrompt: systemPrompt,
MaxSteps: viper.GetInt("max-steps"),
Expand Down Expand Up @@ -725,7 +729,8 @@ func runNormalMode(ctx context.Context) error {
return fmt.Errorf("--quiet flag can only be used with --prompt/-p")
}

return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager, hookExecutor)
approveToolRun := viper.GetBool("approve-tool-run")
return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager, hookExecutor, approveToolRun)
}

// AgenticLoopConfig configures the behavior of the unified agentic loop
Expand All @@ -734,6 +739,7 @@ type AgenticLoopConfig struct {
IsInteractive bool // true for interactive mode, false for non-interactive
InitialPrompt string // initial prompt for non-interactive mode
ContinueAfterRun bool // true to continue to interactive mode after initial run (--no-exit)
ApproveToolRun bool // only used in interactive mode

// UI configuration
Quiet bool // suppress all output except final response
Expand Down Expand Up @@ -1083,7 +1089,27 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
currentSpinner.Start()
}
},
streamingCallback, // Add streaming callback as the last parameter
// Add streaming callback handler
streamingCallback,
// Tool call approval handler - called before tool execution to get user approval
func(toolName, toolArgs string) (bool, error) {
if !config.IsInteractive || !config.ApproveToolRun {
return true, nil
}
if currentSpinner != nil {
currentSpinner.Stop()
currentSpinner = nil
}
allow, err := cli.GetToolApproval(toolName, toolArgs)
if err != nil {
return false, err
}
// Start spinner again for tool calls
currentSpinner = ui.NewSpinner("Thinking...")
currentSpinner.Start()

return allow, nil
},
)

// Make sure spinner is stopped if still running
Expand Down Expand Up @@ -1286,6 +1312,7 @@ func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.C
IsInteractive: false,
InitialPrompt: prompt,
ContinueAfterRun: noExit,
ApproveToolRun: false,
Quiet: quiet,
ServerNames: serverNames,
ToolNames: toolNames,
Expand All @@ -1298,12 +1325,13 @@ func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.C
}

// runInteractiveMode handles the interactive mode execution
func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message, sessionManager *session.Manager, hookExecutor *hooks.Executor) error {
func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message, sessionManager *session.Manager, hookExecutor *hooks.Executor, approveToolRun bool) error {
// Configure and run unified agentic loop
config := AgenticLoopConfig{
IsInteractive: true,
InitialPrompt: "",
ContinueAfterRun: false,
ApproveToolRun: approveToolRun,
Quiet: false,
ServerNames: serverNames,
ToolNames: toolNames,
Expand Down
31 changes: 24 additions & 7 deletions internal/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import (
"context"
"encoding/json"
"fmt"
"strings"
"time"

tea "github.com/charmbracelet/bubbletea"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/tool"
Expand All @@ -12,8 +15,6 @@ import (
"github.com/mark3labs/mcphost/internal/config"
"github.com/mark3labs/mcphost/internal/models"
"github.com/mark3labs/mcphost/internal/tools"
"strings"
"time"
)

// AgentConfig is the config for agent.
Expand Down Expand Up @@ -44,6 +45,9 @@ type StreamingResponseHandler func(content string)
// ToolCallContentHandler is a function type for handling content that accompanies tool calls
type ToolCallContentHandler func(content string)

// ToolApprovalHandler is a function type for handling user approval of tool calls
type ToolApprovalHandler func(toolName, toolArgs string) (bool, error)

// Agent is the agent with real-time tool call display.
type Agent struct {
toolManager *tools.MCPToolManager
Expand Down Expand Up @@ -106,15 +110,15 @@ type GenerateWithLoopResult struct {

// GenerateWithLoop processes messages with a custom loop that displays tool calls in real-time
func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message,
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler) (*GenerateWithLoopResult, error) {

return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult, onResponse, onToolCallContent, nil)
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onToolApproval ToolApprovalHandler,
) (*GenerateWithLoopResult, error) {
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult, onResponse, onToolCallContent, nil, onToolApproval)
}

// GenerateWithLoopAndStreaming processes messages with a custom loop that displays tool calls in real-time and supports streaming callbacks
func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []*schema.Message,
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onStreamingResponse StreamingResponseHandler) (*GenerateWithLoopResult, error) {

onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onStreamingResponse StreamingResponseHandler, onToolApproval ToolApprovalHandler,
) (*GenerateWithLoopResult, error) {
// Create a copy of messages to avoid modifying the original
workingMessages := make([]*schema.Message, len(messages))
copy(workingMessages, messages)
Expand Down Expand Up @@ -176,6 +180,19 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []*sc

// Handle tool calls
for _, toolCall := range response.ToolCalls {
if onToolApproval != nil {
approved, err := onToolApproval(toolCall.Function.Name, toolCall.Function.Arguments)
if err != nil {
return nil, err
}
if !approved {
rejectedMsg := fmt.Sprintf("The user did not allow tool call %s. Reason: User cancelled.", toolCall.Function.Name)
toolMessage := schema.ToolMessage(rejectedMsg, toolCall.ID)
workingMessages = append(workingMessages, toolMessage)
continue
}
}

// Notify about tool call
if onToolCall != nil {
onToolCall(toolCall.Function.Name, toolCall.Function.Arguments)
Expand Down
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ type Config struct {
Stream *bool `json:"stream,omitempty" yaml:"stream,omitempty"`
Theme any `json:"theme" yaml:"theme"`
MarkdownTheme any `json:"markdown-theme" yaml:"markdown-theme"`
ApproveToolRun bool `json:"approve-tool-run" yaml:"approve-tool-run"`

// Model generation parameters
MaxTokens int `json:"max-tokens,omitempty" yaml:"max-tokens,omitempty"`
Expand Down
21 changes: 15 additions & 6 deletions internal/ui/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ import (
"golang.org/x/term"
)

var (
promptStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12"))
)
var promptStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12"))

// CLI handles the command line interface with improved message rendering
type CLI struct {
Expand Down Expand Up @@ -83,7 +81,6 @@ func (c *CLI) GetPrompt() (string, error) {
// Run as a tea program
p := tea.NewProgram(input)
finalModel, err := p.Run()

if err != nil {
return "", err
}
Expand Down Expand Up @@ -151,7 +148,6 @@ func (c *CLI) DisplayAssistantMessageWithModel(message, modelName string) error

// DisplayToolCallMessage displays a tool call in progress
func (c *CLI) DisplayToolCallMessage(toolName, toolArgs string) {

c.messageContainer.messages = nil // clear previous messages (they should have been printed already)
c.lastStreamHeight = 0 // Reset last stream height for new prompt

Expand Down Expand Up @@ -331,6 +327,20 @@ func (c *CLI) IsSlashCommand(input string) bool {
return strings.HasPrefix(input, "/")
}

func (c *CLI) GetToolApproval(toolName, toolArgs string) (bool, error) {
input := NewToolApprovalInput(toolName, toolArgs, c.width)
p := tea.NewProgram(input)
finalModel, err := p.Run()
if err != nil {
return false, err
}

if finalInput, ok := finalModel.(*ToolApprovalInput); ok {
return finalInput.approved, nil
}
return false, fmt.Errorf("GetToolApproval: unexpected error type")
}

// SlashCommandResult represents the result of handling a slash command
type SlashCommandResult struct {
Handled bool
Expand Down Expand Up @@ -377,7 +387,6 @@ func (c *CLI) ClearMessages() {

// displayContainer renders and displays the message container
func (c *CLI) displayContainer() {

// Add left padding to the entire container
content := c.messageContainer.Render()

Expand Down
135 changes: 135 additions & 0 deletions internal/ui/tool_approval_input.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package ui

import (
"fmt"
"strings"

"github.com/charmbracelet/bubbles/textarea"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)

type ToolApprovalInput struct {
textarea textarea.Model
toolName string
toolArgs string
width int
selected bool // true when "yes" is highlighted and false when "no" is
approved bool
done bool
}

func NewToolApprovalInput(toolName, toolArgs string, width int) *ToolApprovalInput {
ta := textarea.New()
ta.Placeholder = ""
ta.ShowLineNumbers = false
ta.CharLimit = 1000
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
ta.SetHeight(4) // Default to 3 lines like huh
ta.Focus()

// Style the textarea to match huh theme
ta.FocusedStyle.Base = lipgloss.NewStyle()
ta.FocusedStyle.Placeholder = lipgloss.NewStyle().Foreground(lipgloss.Color("240"))
ta.FocusedStyle.Text = lipgloss.NewStyle().Foreground(lipgloss.Color("252"))
ta.FocusedStyle.Prompt = lipgloss.NewStyle()
ta.FocusedStyle.CursorLine = lipgloss.NewStyle()
ta.Cursor.Style = lipgloss.NewStyle().Foreground(lipgloss.Color("39"))

return &ToolApprovalInput{
textarea: ta,
toolName: toolName,
toolArgs: toolArgs,
width: width,
selected: true,
}
}

func (t *ToolApprovalInput) Init() tea.Cmd {
return textarea.Blink
}

func (t *ToolApprovalInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.String() {
case "y", "Y":
t.approved = true
t.done = true
return t, tea.Quit
case "n", "N":
t.approved = false
t.done = true
return t, tea.Quit
case "left":
t.selected = true
return t, nil
case "right":
t.selected = false
return t, nil
case "enter":
t.approved = t.selected
t.done = true
return t, tea.Quit
case "esc", "ctrl+c":
t.approved = false
t.done = true
return t, tea.Quit
}
}
return t, nil
}

func (t *ToolApprovalInput) View() string {
if t.done {
return "we are done"
}
// Add left padding to entire component (2 spaces like other UI elements)
containerStyle := lipgloss.NewStyle().PaddingLeft(2)

// Title
titleStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("252")).
MarginBottom(1)

// Input box with huh-like styling
inputBoxStyle := lipgloss.NewStyle().
Border(lipgloss.ThickBorder()).
BorderLeft(true).
BorderRight(false).
BorderTop(false).
BorderBottom(false).
BorderForeground(lipgloss.Color("39")).
PaddingLeft(1).
Width(t.width - 2) // Account for container padding

// Style for the currently selected/highlighted option
selectedStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("42")). // Bright green
Bold(true).
Underline(true)

// Style for the unselected/unhighlighted option
unselectedStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("240")) // Dark gray

// Build the view
var view strings.Builder
view.WriteString(titleStyle.Render("Allow tool execution"))
view.WriteString("\n")
details := fmt.Sprintf("Tool: %s\nArguments: %s\n\n", t.toolName, t.toolArgs)
view.WriteString(details)
view.WriteString("Allow tool execution: ")

var yesText, noText string
if t.selected {
yesText = selectedStyle.Render("[y]es")
noText = unselectedStyle.Render("[n]o")
} else {
yesText = unselectedStyle.Render("[y]es")
noText = selectedStyle.Render("[n]o")
}
view.WriteString(yesText + "/" + noText + "\n")

return containerStyle.Render(inputBoxStyle.Render(view.String()))
}
2 changes: 2 additions & 0 deletions sdk/mcphost.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ func (m *MCPHost) Prompt(ctx context.Context, message string) (string, error) {
nil, // onToolResult
nil, // onResponse
nil, // onToolCallContent
nil, // onToolApproval
)
if err != nil {
return "", err
Expand Down Expand Up @@ -171,6 +172,7 @@ func (m *MCPHost) PromptWithCallbacks(
nil, // onResponse
nil, // onToolCallContent
onStreaming,
nil, // onToolApproval
)
if err != nil {
return "", err
Expand Down
Loading