From c3b4057f80d381ebb7d1dc14f4d0ae6d558210b4 Mon Sep 17 00:00:00 2001 From: Teryl Taylor Date: Thu, 30 Oct 2025 06:38:37 -0600 Subject: [PATCH 01/20] refactor: refactor plugins to make them extensible. Signed-off-by: Teryl Taylor --- .../adr/016-plugin-framework-ai-middleware.md | 2 +- docs/docs/architecture/plugins.md | 4 +- docs/docs/using/plugins/index.md | 10 +- docs/docs/using/plugins/rust-plugins.md | 4 +- llms/plugins-llms.md | 2 +- mcpgateway/plugins/framework/__init__.py | 39 +- mcpgateway/plugins/framework/base.py | 221 ++--- mcpgateway/plugins/framework/constants.py | 5 +- .../plugins/framework/external/mcp/client.py | 188 ++-- .../framework/external/mcp/server/runtime.py | 232 +---- .../framework/external/mcp/server/server.py | 71 +- mcpgateway/plugins/framework/hook_registry.py | 203 +++++ mcpgateway/plugins/framework/loader/plugin.py | 1 + mcpgateway/plugins/framework/manager.py | 815 +++++------------- mcpgateway/plugins/framework/models.py | 297 +------ mcpgateway/plugins/framework/registry.py | 54 +- mcpgateway/plugins/framework/utils.py | 429 ++++----- mcpgateway/plugins/mcp/__init__.py | 8 + mcpgateway/plugins/mcp/entities/__init__.py | 49 ++ mcpgateway/plugins/mcp/entities/base.py | 212 +++++ mcpgateway/plugins/mcp/entities/models.py | 267 ++++++ mcpgateway/services/prompt_service.py | 19 +- mcpgateway/services/resource_service.py | 9 +- mcpgateway/services/tool_service.py | 12 +- .../plugin.py.jinja | 2 +- plugin_templates/native/plugin.py.jinja | 2 +- plugins/README.md | 4 +- .../ai_artifacts_normalizer.py | 6 +- plugins/altk_json_processor/json_processor.py | 6 +- .../argument_normalizer.py | 6 +- .../cached_tool_result/cached_tool_result.py | 6 +- plugins/circuit_breaker/circuit_breaker.py | 6 +- .../citation_validator/citation_validator.py | 6 +- plugins/code_formatter/code_formatter.py | 6 +- .../code_safety_linter/code_safety_linter.py | 6 +- .../content_moderation/content_moderation.py | 6 +- plugins/deny_filter/deny.py | 5 +- .../external/clamav_server/clamav_plugin.py | 6 +- .../llmguard/llmguardplugin/plugin.py | 7 +- .../external/opa/opapluginfilter/plugin.py | 6 +- .../file_type_allowlist.py | 6 +- .../harmful_content_detector.py | 6 +- plugins/header_injector/header_injector.py | 6 +- plugins/html_to_markdown/html_to_markdown.py | 6 +- plugins/json_repair/json_repair.py | 6 +- .../license_header_injector.py | 6 +- plugins/markdown_cleaner/markdown_cleaner.py | 6 +- .../output_length_guard.py | 6 +- plugins/pii_filter/pii_filter.py | 6 +- .../privacy_notice_injector.py | 6 +- plugins/rate_limiter/rate_limiter.py | 6 +- plugins/regex_filter/search_replace.py | 6 +- plugins/resource_filter/resource_filter.py | 6 +- .../response_cache_by_prompt.py | 6 +- .../retry_with_backoff/retry_with_backoff.py | 6 +- .../robots_license_guard.py | 6 +- .../safe_html_sanitizer.py | 6 +- plugins/schema_guard/schema_guard.py | 6 +- .../secrets_detection/secrets_detection.py | 6 +- plugins/sql_sanitizer/sql_sanitizer.py | 6 +- plugins/summarizer/summarizer.py | 6 +- .../timezone_translator.py | 6 +- plugins/url_reputation/url_reputation.py | 6 +- plugins/vault/vault_plugin.py | 8 +- .../virus_total_checker.py | 6 +- plugins/watchdog/watchdog.py | 6 +- .../webhook_notification.py | 6 +- plugins_rust/docs/implementation-guide.md | 2 +- .../test_resource_plugin_integration.py | 223 +++-- .../plugins/fixtures/plugins/context.py | 10 +- .../plugins/fixtures/plugins/error.py | 8 +- .../plugins/fixtures/plugins/headers.py | 10 +- .../plugins/fixtures/plugins/passthrough.py | 8 +- .../external/mcp/server/test_runtime.py | 2 + .../external/mcp/test_client_config.py | 15 +- .../external/mcp/test_client_stdio.py | 29 +- .../mcp/test_client_streamable_http.py | 3 +- .../framework/loader/test_plugin_loader.py | 3 +- .../plugins/framework/test_context.py | 11 +- .../plugins/framework/test_errors.py | 9 +- .../plugins/framework/test_manager.py | 37 +- .../framework/test_manager_extended.py | 299 +++---- .../plugins/framework/test_registry.py | 44 +- .../plugins/framework/test_resource_hooks.py | 147 ++-- .../plugins/framework/test_utils.py | 301 +++---- .../test_json_processor.py | 6 +- .../test_argument_normalizer.py | 6 +- .../test_cached_tool_result.py | 5 +- .../test_code_safety_linter.py | 4 +- .../test_content_moderation.py | 6 +- .../test_content_moderation_integration.py | 19 +- .../external_clamav/test_clamav_remote.py | 14 +- .../test_file_type_allowlist.py | 5 +- .../html_to_markdown/test_html_to_markdown.py | 4 +- .../plugins/json_repair/test_json_repair.py | 5 +- .../markdown_cleaner/test_markdown_cleaner.py | 4 +- .../test_output_length_guard.py | 5 +- .../plugins/pii_filter/test_pii_filter.py | 8 +- .../plugins/rate_limiter/test_rate_limiter.py | 4 +- .../resource_filter/test_resource_filter.py | 4 +- .../plugins/schema_guard/test_schema_guard.py | 4 +- .../url_reputation/test_url_reputation.py | 6 +- .../test_virus_total_checker.py | 14 +- .../test_webhook_integration.py | 15 +- .../test_webhook_notification.py | 6 +- .../services/test_resource_service_plugins.py | 246 +++--- .../mcpgateway/services/test_tool_service.py | 72 +- 107 files changed, 2549 insertions(+), 2477 deletions(-) create mode 100644 mcpgateway/plugins/framework/hook_registry.py create mode 100644 mcpgateway/plugins/mcp/__init__.py create mode 100644 mcpgateway/plugins/mcp/entities/__init__.py create mode 100644 mcpgateway/plugins/mcp/entities/base.py create mode 100644 mcpgateway/plugins/mcp/entities/models.py diff --git a/docs/docs/architecture/adr/016-plugin-framework-ai-middleware.md b/docs/docs/architecture/adr/016-plugin-framework-ai-middleware.md index b5803cd59..5b239c9c7 100644 --- a/docs/docs/architecture/adr/016-plugin-framework-ai-middleware.md +++ b/docs/docs/architecture/adr/016-plugin-framework-ai-middleware.md @@ -20,7 +20,7 @@ We implemented a comprehensive plugin framework with the following key architect ```python from mcpgateway.plugins.framework import Plugin -class MyInProcessPlugin(Plugin): +class MyInProcessPlugin(MCPPlugin): async def prompt_pre_fetch(self, payload, context): ... # in‑process logic diff --git a/docs/docs/architecture/plugins.md b/docs/docs/architecture/plugins.md index 819cbdebf..2f27b2e86 100644 --- a/docs/docs/architecture/plugins.md +++ b/docs/docs/architecture/plugins.md @@ -1330,7 +1330,7 @@ class PluginSettings(BaseModel): #### PII Filter Plugin (Native) ```python -class PIIFilterPlugin(Plugin): +class PIIFilterPlugin(MCPPlugin): """Detects and masks Personally Identifiable Information""" async def prompt_pre_fetch(self, payload: PromptPrehookPayload, @@ -1367,7 +1367,7 @@ class PIIFilterPlugin(Plugin): #### Resource Filter Plugin (Security) ```python -class ResourceFilterPlugin(Plugin): +class ResourceFilterPlugin(MCPPlugin): """Validates and filters resource requests""" async def resource_pre_fetch(self, payload: ResourcePreFetchPayload, diff --git a/docs/docs/using/plugins/index.md b/docs/docs/using/plugins/index.md index 0caf87132..89e36b7d4 100644 --- a/docs/docs/using/plugins/index.md +++ b/docs/docs/using/plugins/index.md @@ -89,7 +89,7 @@ Decide between a native (in‑process) or external (MCP) plugin: ```python from mcpgateway.plugins.framework import Plugin, PluginConfig, PluginContext, PromptPrehookPayload, PromptPrehookResult -class MyPlugin(Plugin): +class MyPlugin(MCPPlugin): def __init__(self, config: PluginConfig): super().__init__(config) @@ -539,7 +539,7 @@ from mcpgateway.plugins.framework import ( ResourcePostFetchResult ) -class MyPlugin(Plugin): +class MyPlugin(MCPPlugin): """Example plugin implementation.""" def __init__(self, config: PluginConfig): @@ -813,7 +813,7 @@ Metadata for other entities such as prompts and resources will be added in futur ### External Service Plugin Example ```python -class LLMGuardPlugin(Plugin): +class LLMGuardPlugin(MCPPlugin): """Example external service integration.""" def __init__(self, config: PluginConfig): @@ -901,7 +901,7 @@ default_config: # plugins/my_plugin/plugin.py from mcpgateway.plugins.framework import Plugin -class MyPlugin(Plugin): +class MyPlugin(MCPPlugin): # Implementation here pass ``` @@ -963,7 +963,7 @@ Errors inside a plugin should be raised as exceptions. The plugin manager will - Consider async operations for I/O ```python -class CachedPlugin(Plugin): +class CachedPlugin(MCPPlugin): def __init__(self, config): super().__init__(config) self._cache = {} diff --git a/docs/docs/using/plugins/rust-plugins.md b/docs/docs/using/plugins/rust-plugins.md index a10dfd9ce..a99c89735 100644 --- a/docs/docs/using/plugins/rust-plugins.md +++ b/docs/docs/using/plugins/rust-plugins.md @@ -496,7 +496,7 @@ try: except ImportError: RUST_AVAILABLE = False -class MyPlugin(Plugin): +class MyPlugin(MCPPlugin): def __init__(self, config): if RUST_AVAILABLE: self.impl = RustMyPlugin(config) @@ -624,7 +624,7 @@ If you have an existing Python plugin you want to optimize: You don't need to convert entire plugins at once: ```python -class MyPlugin(Plugin): +class MyPlugin(MCPPlugin): def __init__(self, config): # Use Rust for expensive operations if RUST_AVAILABLE: diff --git a/llms/plugins-llms.md b/llms/plugins-llms.md index c2a16c353..e31515872 100644 --- a/llms/plugins-llms.md +++ b/llms/plugins-llms.md @@ -179,7 +179,7 @@ from mcpgateway.plugins.framework import Plugin, PluginConfig, PluginContext from mcpgateway.plugins.framework import PromptPrehookPayload, PromptPrehookResult from mcpgateway.plugins.framework import PluginViolation -class MyGuard(Plugin): +class MyGuard(MCPPlugin): async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: if payload.args and any("forbidden" in v for v in payload.args.values() if isinstance(v, str)): return PromptPrehookResult( diff --git a/mcpgateway/plugins/framework/__init__.py b/mcpgateway/plugins/framework/__init__.py index db61745c7..c170aa35f 100644 --- a/mcpgateway/plugins/framework/__init__.py +++ b/mcpgateway/plugins/framework/__init__.py @@ -17,43 +17,30 @@ from mcpgateway.plugins.framework.base import Plugin from mcpgateway.plugins.framework.errors import PluginError, PluginViolationError from mcpgateway.plugins.framework.external.mcp.server import ExternalPluginServer +from mcpgateway.plugins.framework.hook_registry import HookRegistry, get_hook_registry from mcpgateway.plugins.framework.loader.config import ConfigLoader from mcpgateway.plugins.framework.loader.plugin import PluginLoader from mcpgateway.plugins.framework.manager import PluginManager from mcpgateway.plugins.framework.models import ( GlobalContext, - HttpHeaderPayload, - HttpHeaderPayloadResult, - HookType, + MCPServerConfig, PluginCondition, PluginConfig, PluginContext, PluginErrorModel, PluginMode, + PluginPayload, PluginResult, PluginViolation, - PromptPosthookPayload, - PromptPosthookResult, - PromptPrehookPayload, - PromptPrehookResult, - PromptResult, - ResourcePostFetchPayload, - ResourcePostFetchResult, - ResourcePreFetchPayload, - ResourcePreFetchResult, - ToolPostInvokePayload, - ToolPostInvokeResult, - ToolPreInvokePayload, - ToolPreInvokeResult, ) __all__ = [ "ConfigLoader", "ExternalPluginServer", "GlobalContext", - "HookType", - "HttpHeaderPayload", - "HttpHeaderPayloadResult", + "HookRegistry", + "get_hook_registry", + "MCPServerConfig", "Plugin", "PluginCondition", "PluginConfig", @@ -63,20 +50,8 @@ "PluginLoader", "PluginManager", "PluginMode", + "PluginPayload", "PluginResult", "PluginViolation", "PluginViolationError", - "PromptPosthookPayload", - "PromptPosthookResult", - "PromptPrehookPayload", - "PromptPrehookResult", - "PromptResult", - "ResourcePostFetchPayload", - "ResourcePostFetchResult", - "ResourcePreFetchPayload", - "ResourcePreFetchResult", - "ToolPostInvokePayload", - "ToolPostInvokeResult", - "ToolPreInvokePayload", - "ToolPreInvokeResult", ] diff --git a/mcpgateway/plugins/framework/base.py b/mcpgateway/plugins/framework/base.py index 28bd25481..a91739a44 100644 --- a/mcpgateway/plugins/framework/base.py +++ b/mcpgateway/plugins/framework/base.py @@ -2,7 +2,7 @@ """Location: ./mcpgateway/plugins/framework/base.py Copyright 2025 SPDX-License-Identifier: Apache-2.0 -Authors: Teryl Taylor, Mihai Criveti +-Authors: Teryl Taylor, Mihai Criveti Base plugin implementation. This module implements the base plugin object. @@ -17,27 +17,19 @@ """ # Standard +from typing import Awaitable, Callable, Optional, Union import uuid # First-Party +from mcpgateway.plugins.framework.errors import PluginError from mcpgateway.plugins.framework.models import ( - HookType, PluginCondition, PluginConfig, PluginContext, + PluginErrorModel, PluginMode, - PromptPosthookPayload, - PromptPosthookResult, - PromptPrehookPayload, - PromptPrehookResult, - ResourcePostFetchPayload, - ResourcePostFetchResult, - ResourcePreFetchPayload, - ResourcePreFetchResult, - ToolPostInvokePayload, - ToolPostInvokeResult, - ToolPreInvokePayload, - ToolPreInvokeResult, + PluginPayload, + PluginResult, ) @@ -45,7 +37,8 @@ class Plugin: """Base plugin object for pre/post processing of inputs and outputs at various locations throughout the server. Examples: - >>> from mcpgateway.plugins.framework import PluginConfig, HookType, PluginMode + >>> from mcpgateway.plugins.framework import PluginConfig, PluginMode + >>> from mcpgateway.plugins.mcp.entities import HookType >>> config = PluginConfig( ... name="test_plugin", ... description="Test plugin", @@ -68,14 +61,24 @@ class Plugin: True """ - def __init__(self, config: PluginConfig) -> None: + def __init__( + self, + config: PluginConfig, + hook_payloads: Optional[dict[str, PluginPayload]] = None, + hook_results: Optional[dict[str, PluginResult]] = None, + ) -> None: """Initialize a plugin with a configuration and context. Args: config: The plugin configuration + hook_payloads: optional mapping of hookpoints to payloads for the plugin. + Used for external plugins for converting json to pydantic. + hook_results: optional mapping of hookpoints to result types for the plugin. + Used for external plugins for converting json to pydantic. Examples: - >>> from mcpgateway.plugins.framework import PluginConfig, HookType + >>> from mcpgateway.plugins.framework import PluginConfig + >>> from mcpgateway.plugins.mcp.entities import HookType >>> config = PluginConfig( ... name="simple_plugin", ... description="Simple test", @@ -90,6 +93,8 @@ def __init__(self, config: PluginConfig) -> None: 'simple_plugin' """ self._config = config + self._hook_payloads = hook_payloads + self._hook_results = hook_results @property def priority(self) -> int: @@ -128,7 +133,7 @@ def name(self) -> str: return self._config.name @property - def hooks(self) -> list[HookType]: + def hooks(self) -> list[str]: """Return the plugin's currently configured hooks. Returns: @@ -157,111 +162,86 @@ def conditions(self) -> list[PluginCondition] | None: async def initialize(self) -> None: """Initialize the plugin.""" - async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: - """Plugin hook run before a prompt is retrieved and rendered. - - Args: - payload: The prompt payload to be analyzed. - context: contextual information about the hook call. Including why it was called. - - Raises: - NotImplementedError: needs to be implemented by sub class. - """ - raise NotImplementedError( - f"""'prompt_pre_fetch' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) + async def shutdown(self) -> None: + """Plugin cleanup code.""" - async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult: - """Plugin hook run after a prompt is rendered. + def json_to_payload(self, hook: str, payload: Union[str | dict]) -> PluginPayload: + """Converts a json payload to the proper pydantic payload object given a hook type. Used + mainly for serialization/deserialization of external plugin payloads. Args: - payload: The prompt payload to be analyzed. - context: Contextual information about the hook call. + hook: the hook type for which the payload needs converting. + payload: the payload as a string or dict. + + Returns: + A pydantic payload object corresponding to the hook type. Raises: - NotImplementedError: needs to be implemented by sub class. + PluginError: if no payload type is defined. """ - raise NotImplementedError( - f"""'prompt_post_fetch' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) + hook_payload_type: type[PluginPayload] | None = None - async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult: - """Plugin hook run before a tool is invoked. + # First try instance-level hook_payloads + if self._hook_payloads: + hook_payload_type = self._hook_payloads.get(hook, None) # type: ignore[assignment] - Args: - payload: The tool payload to be analyzed. - context: Contextual information about the hook call. + # Fall back to global registry + if not hook_payload_type: + # First-Party + from mcpgateway.plugins.framework.hook_registry import get_hook_registry - Raises: - NotImplementedError: needs to be implemented by sub class. - """ - raise NotImplementedError( - f"""'tool_pre_invoke' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) + registry = get_hook_registry() + hook_payload_type = registry.get_payload_type(hook) - async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: PluginContext) -> ToolPostInvokeResult: - """Plugin hook run after a tool is invoked. + if not hook_payload_type: + raise PluginError(error=PluginErrorModel(message=f"No payload defined for hook {hook}.", plugin_name=self.name)) - Args: - payload: The tool result payload to be analyzed. - context: Contextual information about the hook call. - - Raises: - NotImplementedError: needs to be implemented by sub class. - """ - raise NotImplementedError( - f"""'tool_post_invoke' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) + if isinstance(payload, str): + return hook_payload_type.model_validate_json(payload) + return hook_payload_type.model_validate(payload) - async def resource_pre_fetch(self, payload: ResourcePreFetchPayload, context: PluginContext) -> ResourcePreFetchResult: - """Plugin hook run before a resource is fetched. + def json_to_result(self, hook: str, result: Union[str | dict]) -> PluginResult: + """Converts a json result to the proper pydantic result object given a hook type. Used + mainly for serialization/deserialization of external plugin results. Args: - payload: The resource payload to be analyzed. - context: Contextual information about the hook call. + hook: the hook type for which the result needs converting. + result: the result as a string or dict. + + Returns: + A pydantic result object corresponding to the hook type. Raises: - NotImplementedError: needs to be implemented by sub class. + PluginError: if no result type is defined. """ - raise NotImplementedError( - f"""'resource_pre_fetch' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) + hook_result_type: type[PluginResult] | None = None - async def resource_post_fetch(self, payload: ResourcePostFetchPayload, context: PluginContext) -> ResourcePostFetchResult: - """Plugin hook run after a resource is fetched. + # First try instance-level hook_results + if self._hook_results: + hook_result_type = self._hook_results.get(hook, None) # type: ignore[assignment] - Args: - payload: The resource content payload to be analyzed. - context: Contextual information about the hook call. + # Fall back to global registry + if not hook_result_type: + # First-Party + from mcpgateway.plugins.framework.hook_registry import get_hook_registry - Raises: - NotImplementedError: needs to be implemented by sub class. - """ - raise NotImplementedError( - f"""'resource_post_fetch' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) + registry = get_hook_registry() + hook_result_type = registry.get_result_type(hook) - async def shutdown(self) -> None: - """Plugin cleanup code.""" + if not hook_result_type: + raise PluginError(error=PluginErrorModel(message=f"No result defined for hook {hook}.", plugin_name=self.name)) + + if isinstance(result, str): + return hook_result_type.model_validate_json(result) + return hook_result_type.model_validate(result) class PluginRef: """Plugin reference which contains a uuid. Examples: - >>> from mcpgateway.plugins.framework import PluginConfig, HookType, PluginMode + >>> from mcpgateway.plugins.framework import PluginConfig, PluginMode + >>> from mcpgateway.plugins.mcp.entities import HookType >>> config = PluginConfig( ... name="ref_test", ... description="Reference test", @@ -294,7 +274,8 @@ def __init__(self, plugin: Plugin): plugin: The plugin to reference. Examples: - >>> from mcpgateway.plugins.framework import PluginConfig, HookType + >>> from mcpgateway.plugins.framework import PluginConfig + >>> from mcpgateway.plugins.mcp.entities import HookType >>> config = PluginConfig( ... name="plugin_ref", ... description="Test", @@ -351,7 +332,7 @@ def name(self) -> str: return self._plugin.name @property - def hooks(self) -> list[HookType]: + def hooks(self) -> list[str]: """Returns the plugin's currently configured hooks. Returns: @@ -385,3 +366,47 @@ def mode(self) -> PluginMode: Plugin's mode. """ return self.plugin.mode + + +class HookRef: + """A Hook reference point with plugin and function.""" + + def __init__(self, hook: str, plugin_ref: PluginRef): + """Initialize a hook reference point. + + Args: + hook: name of the hook point. + plugin_ref: The reference to the plugin to hook. + """ + self._plugin_ref = plugin_ref + self._hook = hook + self._func: Callable[[PluginPayload, PluginContext], Awaitable[PluginResult]] = getattr(plugin_ref.plugin, hook) + if not self._func: + raise PluginError(error=PluginErrorModel(message=f"Plugin: {plugin_ref.plugin.name} has no hook: {hook}", plugin_name=plugin_ref.plugin.name)) + + @property + def plugin_ref(self) -> PluginRef: + """The reference to the plugin object. + + Returns: + A plugin reference. + """ + return self._plugin_ref + + @property + def name(self) -> str: + """The name of the hooking function. + + Returns: + A plugin name. + """ + return self._hook + + @property + def hook(self) -> Callable[[PluginPayload, PluginContext], Awaitable[PluginResult]]: + """The hooking function that can be invoked within the reference. + + Returns: + An awaitable hook function reference. + """ + return self._func diff --git a/mcpgateway/plugins/framework/constants.py b/mcpgateway/plugins/framework/constants.py index 155679c57..7c3d81e90 100644 --- a/mcpgateway/plugins/framework/constants.py +++ b/mcpgateway/plugins/framework/constants.py @@ -16,7 +16,6 @@ PYTHON_SUFFIX = ".py" URL = "url" SCRIPT = "script" -AFTER = "after" NAME = "name" PYTHON = "python" @@ -25,7 +24,6 @@ CONTEXT = "context" RESULT = "result" ERROR = "error" -GET_PLUGIN_CONFIG = "get_plugin_config" IGNORE_CONFIG_EXTERNAL = "ignore_config_external" # Global Context Metadata fields @@ -37,3 +35,6 @@ MCP_SERVER_NAME = "MCP Plugin Server" MCP_SERVER_INSTRUCTIONS = "External plugin server for MCP Gateway" GET_PLUGIN_CONFIGS = "get_plugin_configs" +GET_PLUGIN_CONFIG = "get_plugin_config" +HOOK_TYPE = "hook_type" +INVOKE_HOOK = "invoke_hook" diff --git a/mcpgateway/plugins/framework/external/mcp/client.py b/mcpgateway/plugins/framework/external/mcp/client.py index 1d8e60133..fcfb5e807 100644 --- a/mcpgateway/plugins/framework/external/mcp/client.py +++ b/mcpgateway/plugins/framework/external/mcp/client.py @@ -11,46 +11,48 @@ # Standard import asyncio from contextlib import AsyncExitStack +from functools import partial import json import logging import os -from typing import Any, Optional, Type, TypeVar +from typing import Any, Awaitable, Callable, Optional # Third-Party import httpx from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamablehttp_client -from pydantic import BaseModel +from mcp.types import TextContent # First-Party -from mcpgateway.plugins.framework.base import Plugin -from mcpgateway.plugins.framework.constants import CONTEXT, ERROR, GET_PLUGIN_CONFIG, IGNORE_CONFIG_EXTERNAL, NAME, PAYLOAD, PLUGIN_NAME, PYTHON, PYTHON_SUFFIX, RESULT +from mcpgateway.plugins.framework.base import HookRef, Plugin, PluginRef +from mcpgateway.plugins.framework.constants import ( + CONTEXT, + ERROR, + GET_PLUGIN_CONFIG, + HOOK_TYPE, + IGNORE_CONFIG_EXTERNAL, + INVOKE_HOOK, + NAME, + PAYLOAD, + PLUGIN_NAME, + PYTHON, + PYTHON_SUFFIX, + RESULT, +) from mcpgateway.plugins.framework.errors import convert_exception_to_error, PluginError from mcpgateway.plugins.framework.external.mcp.tls_utils import create_ssl_context +from mcpgateway.plugins.framework.hook_registry import get_hook_registry from mcpgateway.plugins.framework.models import ( - HookType, MCPClientTLSConfig, PluginConfig, PluginContext, PluginErrorModel, - PromptPosthookPayload, - PromptPosthookResult, - PromptPrehookPayload, - PromptPrehookResult, - ResourcePostFetchPayload, - ResourcePostFetchResult, - ResourcePreFetchPayload, - ResourcePreFetchResult, - ToolPostInvokePayload, - ToolPostInvokeResult, - ToolPreInvokePayload, - ToolPreInvokeResult, + PluginPayload, + PluginResult, ) from mcpgateway.schemas import TransportType -P = TypeVar("P", bound=BaseModel) - logger = logging.getLogger(__name__) @@ -81,8 +83,12 @@ async def initialize(self) -> None: if not self._config.mcp: raise PluginError(error=PluginErrorModel(message="The mcp section must be defined for external plugin", plugin_name=self.name)) if self._config.mcp.proto == TransportType.STDIO: + if not self._config.mcp.script: + raise PluginError(error=PluginErrorModel(message="STDIO transport requires script", plugin_name=self.name)) await self.__connect_to_stdio_server(self._config.mcp.script) elif self._config.mcp.proto == TransportType.STREAMABLEHTTP: + if not self._config.mcp.url: + raise PluginError(error=PluginErrorModel(message="STREAMABLEHTTP transport requires url", plugin_name=self.name)) await self.__connect_to_http_server(self._config.mcp.url) try: @@ -146,9 +152,6 @@ async def __connect_to_http_server(self, uri: str) -> None: Raises: PluginError: if there is an external connection error after all retries. """ - max_retries = 3 - base_delay = 1.0 - plugin_tls = self._config.mcp.tls if self._config and self._config.mcp else None tls_config = plugin_tls or MCPClientTLSConfig.from_env() @@ -188,37 +191,37 @@ def _tls_httpx_client_factory( return httpx.AsyncClient(**kwargs) + max_retries = 3 + base_delay = 1.0 + for attempt in range(max_retries): - logger.info(f"Connecting to external plugin server: {uri} (attempt {attempt + 1}/{max_retries})") try: - # Create a fresh exit stack for each attempt + client_factory = _tls_httpx_client_factory if tls_config else None async with AsyncExitStack() as temp_stack: - client_factory = _tls_httpx_client_factory if tls_config else None streamable_client = streamablehttp_client(uri, httpx_client_factory=client_factory) if client_factory else streamablehttp_client(uri) http_transport = await temp_stack.enter_async_context(streamable_client) http_client, write_func, _ = http_transport session = await temp_stack.enter_async_context(ClientSession(http_client, write_func)) - await session.initialize() - # List available tools response = await session.list_tools() tools = response.tools - logger.info("Successfully connected to plugin MCP server with tools: %s", " ".join([tool.name for tool in tools])) + logger.info( + "Successfully connected to plugin MCP server with tools: %s", + " ".join([tool.name for tool in tools]), + ) - # Success! Now move to the main exit stack client_factory = _tls_httpx_client_factory if tls_config else None streamable_client = streamablehttp_client(uri, httpx_client_factory=client_factory) if client_factory else streamablehttp_client(uri) http_transport = await self._exit_stack.enter_async_context(streamable_client) self._http, self._write, _ = http_transport self._session = await self._exit_stack.enter_async_context(ClientSession(self._http, self._write)) + await self._session.initialize() return - except Exception as e: logger.warning(f"Connection attempt {attempt + 1}/{max_retries} failed: {e}") - if attempt == max_retries - 1: # Final attempt failed error_msg = f"External plugin '{self.name}' connection failed after {max_retries} attempts: {uri} is not reachable. Please ensure the MCP server is running." @@ -230,12 +233,11 @@ def _tls_httpx_client_factory( logger.info(f"Retrying in {delay}s...") await asyncio.sleep(delay) - async def __invoke_hook(self, payload_result_model: Type[P], hook_type: HookType, payload: BaseModel, context: PluginContext) -> P: + async def invoke_hook(self, hook_type: str, payload: PluginPayload, context: PluginContext) -> PluginResult: """Invoke an external plugin hook using the MCP protocol. Args: - payload_result_model: The type of result payload for the hook. - hook_type: The type of hook invoked (i.e., prompt_pre_hook) + hook_type: The type of hook invoked (i.e., prompt_pre_fetch) payload: The payload to be passed to the hook. context: The plugin context passed to the run. @@ -245,18 +247,31 @@ async def __invoke_hook(self, payload_result_model: Type[P], hook_type: HookType Returns: The resulting payload from the plugin. """ + # Get the result type from the global registry + registry = get_hook_registry() + result_type = registry.get_result_type(hook_type) + if not result_type: + raise PluginError(error=PluginErrorModel(message=f"Hook type '{hook_type}' not registered in hook registry", plugin_name=self.name)) + + if not self._session: + raise PluginError(error=PluginErrorModel(message="Plugin session not initialized", plugin_name=self.name)) try: - result = await self._session.call_tool(hook_type, {PLUGIN_NAME: self.name, PAYLOAD: payload, CONTEXT: context}) + result = await self._session.call_tool(INVOKE_HOOK, {HOOK_TYPE: hook_type, PLUGIN_NAME: self.name, PAYLOAD: payload, CONTEXT: context}) for content in result.content: - res = json.loads(content.text) + if not isinstance(content, TextContent): + continue + try: + res = json.loads(content.text) + except json.decoder.JSONDecodeError: + raise PluginError(error=PluginErrorModel(message=f"Error trying to decode json: {content.text}", code="JSON_DECODE_ERROR", plugin_name=self.name)) if CONTEXT in res: cxt = PluginContext.model_validate(res[CONTEXT]) context.state = cxt.state context.metadata = cxt.metadata context.global_context.state = cxt.global_context.state if RESULT in res: - return payload_result_model.model_validate(res[RESULT]) + return result_type.model_validate(res[RESULT]) if ERROR in res: error = PluginErrorModel.model_validate(res[ERROR]) raise PluginError(error) @@ -268,83 +283,6 @@ async def __invoke_hook(self, payload_result_model: Type[P], hook_type: HookType raise PluginError(error=convert_exception_to_error(e, plugin_name=self.name)) raise PluginError(error=PluginErrorModel(message=f"Received invalid response. Result = {result}", plugin_name=self.name)) - async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: - """Plugin hook run before a prompt is retrieved and rendered. - - Args: - payload: The prompt payload to be analyzed. - context: contextual information about the hook call. Including why it was called. - - Returns: - The prompt prehook with name and arguments as modified or blocked by the plugin. - """ - - return await self.__invoke_hook(payload_result_model=PromptPrehookResult, hook_type=HookType.PROMPT_PRE_FETCH, payload=payload, context=context) - - async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult: - """Plugin hook run after a prompt is rendered. - - Args: - payload: The prompt payload to be analyzed. - context: Contextual information about the hook call. - - Returns: - A set of prompt messages as modified or blocked by the plugin. - """ - return await self.__invoke_hook(payload_result_model=PromptPosthookResult, hook_type=HookType.PROMPT_POST_FETCH, payload=payload, context=context) - - async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult: - """Plugin hook run before a tool is invoked. - - Args: - payload: The tool payload to be analyzed. - context: contextual information about the hook call. Including why it was called. - - Returns: - The tool prehook with name and arguments as modified or blocked by the plugin. - """ - - return await self.__invoke_hook(payload_result_model=ToolPreInvokeResult, hook_type=HookType.TOOL_PRE_INVOKE, payload=payload, context=context) - - async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: PluginContext) -> ToolPostInvokeResult: - """Plugin hook run after a tool is invoked. - - Args: - payload: The tool payload to be analyzed. - context: contextual information about the hook call. Including why it was called. - - Returns: - The tool posthook with name and arguments as modified or blocked by the plugin. - """ - - return await self.__invoke_hook(payload_result_model=ToolPostInvokeResult, hook_type=HookType.TOOL_POST_INVOKE, payload=payload, context=context) - - async def resource_pre_fetch(self, payload: ResourcePreFetchPayload, context: PluginContext) -> ResourcePreFetchResult: - """Plugin hook run before a resource is fetched. - - Args: - payload: The resource payload to be analyzed. - context: contextual information about the hook call. Including why it was called. - - Returns: - The resource prehook with name and arguments as modified or blocked by the plugin. - """ - - return await self.__invoke_hook(payload_result_model=ResourcePreFetchResult, hook_type=HookType.RESOURCE_PRE_FETCH, payload=payload, context=context) - - async def resource_post_fetch(self, payload: ResourcePostFetchPayload, context: PluginContext) -> ResourcePostFetchResult: - """Plugin hook run after a resource is fetched. - - Args: - payload: The resource payload to be analyzed. - context: contextual information about the hook call. Including why it was called. - - Returns: - The resource posthook with name and arguments as modified or blocked by the plugin. - """ - - return await self.__invoke_hook(payload_result_model=ResourcePostFetchResult, hook_type=HookType.RESOURCE_POST_FETCH, payload=payload, context=context) - async def __get_plugin_config(self) -> PluginConfig | None: """Retrieve plugin configuration for the current plugin on the remote MCP server. @@ -354,9 +292,13 @@ async def __get_plugin_config(self) -> PluginConfig | None: Returns: A plugin configuration for the current plugin from a remote MCP server. """ + if not self._session: + raise PluginError(error=PluginErrorModel(message="Plugin session not initialized", plugin_name=self.name)) try: configs = await self._session.call_tool(GET_PLUGIN_CONFIG, {NAME: self.name}) for content in configs.content: + if not isinstance(content, TextContent): + continue conf = json.loads(content.text) return PluginConfig.model_validate(conf) except Exception as e: @@ -369,3 +311,21 @@ async def shutdown(self) -> None: """Plugin cleanup code.""" if self._exit_stack: await self._exit_stack.aclose() + + +class ExternalHookRef(HookRef): + """A Hook reference point for external plugins.""" + + def __init__(self, hook: str, plugin_ref: PluginRef): + """Initialize a hook reference point for an external plugin. + + Args: + hook: name of the hook point. + plugin_ref: The reference to the plugin to hook. + """ + self._plugin_ref = plugin_ref + self._hook = hook + if hasattr(plugin_ref.plugin, INVOKE_HOOK): + self._func: Callable[[PluginPayload, PluginContext], Awaitable[PluginResult]] = partial(plugin_ref.plugin.invoke_hook, hook) + if not self._func: + raise PluginError(error=PluginErrorModel(message=f"Plugin: {plugin_ref.plugin.name} is not an external plugin", plugin_name=plugin_ref.plugin.name)) diff --git a/mcpgateway/plugins/framework/external/mcp/server/runtime.py b/mcpgateway/plugins/framework/external/mcp/server/runtime.py index 09b3a2ed1..5091fc517 100755 --- a/mcpgateway/plugins/framework/external/mcp/server/runtime.py +++ b/mcpgateway/plugins/framework/external/mcp/server/runtime.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # -*- coding: utf-8 -*- """Location: ./mcpgateway/plugins/framework/external/mcp/server/runtime.py Copyright 2025 @@ -29,32 +28,19 @@ # First-Party from mcpgateway.plugins.framework import ( ExternalPluginServer, - Plugin, - PluginContext, - PromptPosthookPayload, - PromptPosthookResult, - PromptPrehookPayload, - PromptPrehookResult, - ResourcePostFetchPayload, - ResourcePostFetchResult, - ResourcePreFetchPayload, - ResourcePreFetchResult, - ToolPostInvokePayload, - ToolPostInvokeResult, - ToolPreInvokePayload, - ToolPreInvokeResult, + MCPServerConfig, ) from mcpgateway.plugins.framework.constants import ( GET_PLUGIN_CONFIG, GET_PLUGIN_CONFIGS, + INVOKE_HOOK, MCP_SERVER_INSTRUCTIONS, MCP_SERVER_NAME, ) -from mcpgateway.plugins.framework.models import HookType, MCPServerConfig logger = logging.getLogger(__name__) -SERVER: ExternalPluginServer = None +SERVER: ExternalPluginServer | None = None # Module-level tool functions (extracted for testability) @@ -66,6 +52,8 @@ async def get_plugin_configs() -> list[dict]: Returns: JSON string containing list of plugin configuration dictionaries. """ + if not SERVER: + raise RuntimeError("Plugin server not initialized") return await SERVER.get_plugin_configs() @@ -78,175 +66,29 @@ async def get_plugin_config(name: str) -> dict: Returns: JSON string containing plugin configuration dictionary. """ - return await SERVER.get_plugin_config(name) + if not SERVER: + raise RuntimeError("Plugin server not initialized") + result = await SERVER.get_plugin_config(name) + if result is None: + return {} + return result -async def prompt_pre_fetch(plugin_name: str, payload: Dict[str, Any], context: Dict[str, Any]) -> dict: - """Execute prompt prefetch hook for a plugin. - - Args: - plugin_name: The name of the plugin to execute - payload: The prompt name and arguments to be analyzed - context: Contextual information required for execution - - Returns: - Result dictionary from the prompt prefetch hook. - """ - - def prompt_pre_fetch_func(plugin: Plugin, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: - """Wrapper function to invoke prompt prefetch on a plugin instance. - - Args: - plugin: The plugin instance to execute. - payload: The prompt prehook payload. - context: The plugin context. - - Returns: - Result from the plugin's prompt_pre_fetch method. - """ - return plugin.prompt_pre_fetch(payload, context) - - return await SERVER.invoke_hook(PromptPrehookPayload, prompt_pre_fetch_func, plugin_name, payload, context) - - -async def prompt_post_fetch(plugin_name: str, payload: Dict[str, Any], context: Dict[str, Any]) -> dict: - """Execute prompt postfetch hook for a plugin. - - Args: - plugin_name: The name of the plugin to execute - payload: The prompt payload to be analyzed - context: Contextual information - - Returns: - Result dictionary from the prompt postfetch hook. - """ - - def prompt_post_fetch_func(plugin: Plugin, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult: - """Wrapper function to invoke prompt postfetch on a plugin instance. - - Args: - plugin: The plugin instance to execute. - payload: The prompt posthook payload. - context: The plugin context. - - Returns: - Result from the plugin's prompt_post_fetch method. - """ - return plugin.prompt_post_fetch(payload, context) - - return await SERVER.invoke_hook(PromptPosthookPayload, prompt_post_fetch_func, plugin_name, payload, context) - - -async def tool_pre_invoke(plugin_name: str, payload: Dict[str, Any], context: Dict[str, Any]) -> dict: - """Execute tool pre-invoke hook for a plugin. - - Args: - plugin_name: The name of the plugin to execute - payload: The tool name and arguments to be analyzed - context: Contextual information - - Returns: - Result dictionary from the tool pre-invoke hook. - """ - - def tool_pre_invoke_func(plugin: Plugin, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult: - """Wrapper function to invoke tool pre-invoke on a plugin instance. - - Args: - plugin: The plugin instance to execute. - payload: The tool pre-invoke payload. - context: The plugin context. - - Returns: - Result from the plugin's tool_pre_invoke method. - """ - return plugin.tool_pre_invoke(payload, context) - - return await SERVER.invoke_hook(ToolPreInvokePayload, tool_pre_invoke_func, plugin_name, payload, context) - - -async def tool_post_invoke(plugin_name: str, payload: Dict[str, Any], context: Dict[str, Any]) -> dict: - """Execute tool post-invoke hook for a plugin. - - Args: - plugin_name: The name of the plugin to execute - payload: The tool result to be analyzed - context: Contextual information - - Returns: - Result dictionary from the tool post-invoke hook. - """ - - def tool_post_invoke_func(plugin: Plugin, payload: ToolPostInvokePayload, context: PluginContext) -> ToolPostInvokeResult: - """Wrapper function to invoke tool post-invoke on a plugin instance. - - Args: - plugin: The plugin instance to execute. - payload: The tool post-invoke payload. - context: The plugin context. - - Returns: - Result from the plugin's tool_post_invoke method. - """ - return plugin.tool_post_invoke(payload, context) - - return await SERVER.invoke_hook(ToolPostInvokePayload, tool_post_invoke_func, plugin_name, payload, context) - - -async def resource_pre_fetch(plugin_name: str, payload: Dict[str, Any], context: Dict[str, Any]) -> dict: - """Execute resource prefetch hook for a plugin. - - Args: - plugin_name: The name of the plugin to execute - payload: The resource name and arguments to be analyzed - context: Contextual information - - Returns: - Result dictionary from the resource prefetch hook. - """ - - def resource_pre_fetch_func(plugin: Plugin, payload: ResourcePreFetchPayload, context: PluginContext) -> ResourcePreFetchResult: - """Wrapper function to invoke resource prefetch on a plugin instance. - - Args: - plugin: The plugin instance to execute. - payload: The resource prefetch payload. - context: The plugin context. - - Returns: - Result from the plugin's resource_pre_fetch method. - """ - return plugin.resource_pre_fetch(payload, context) - - return await SERVER.invoke_hook(ResourcePreFetchPayload, resource_pre_fetch_func, plugin_name, payload, context) - - -async def resource_post_fetch(plugin_name: str, payload: Dict[str, Any], context: Dict[str, Any]) -> dict: - """Execute resource postfetch hook for a plugin. +async def invoke_hook(hook_type: str, plugin_name: str, payload: Dict[str, Any], context: Dict[str, Any]) -> dict: + """Execute a hook for a plugin. Args: + hook_type: The name or type of the hook. plugin_name: The name of the plugin to execute payload: The resource payload to be analyzed context: Contextual information Returns: - Result dictionary from the resource postfetch hook. + Result dictionary with payload, context and any error information. """ - - def resource_post_fetch_func(plugin: Plugin, payload: ResourcePostFetchPayload, context: PluginContext) -> ResourcePostFetchResult: - """Wrapper function to invoke resource postfetch on a plugin instance. - - Args: - plugin: The plugin instance to execute. - payload: The resource postfetch payload. - context: The plugin context. - - Returns: - Result from the plugin's resource_post_fetch method. - """ - return plugin.resource_post_fetch(payload, context) - - return await SERVER.invoke_hook(ResourcePostFetchPayload, resource_post_fetch_func, plugin_name, payload, context) + if not SERVER: + raise RuntimeError("Plugin server not initialized") + return await SERVER.invoke_hook(hook_type, plugin_name, payload, context) class SSLCapableFastMCP(FastMCP): @@ -288,7 +130,7 @@ def _get_ssl_config(self) -> dict: if tls.ca_bundle: ssl_config["ssl_ca_certs"] = tls.ca_bundle - ssl_config["ssl_cert_reqs"] = tls.ssl_cert_reqs + ssl_config["ssl_cert_reqs"] = str(tls.ssl_cert_reqs) if tls.keyfile_password: ssl_config["ssl_keyfile_password"] = tls.keyfile_password @@ -315,12 +157,12 @@ async def _start_health_check_server(self, health_port: int) -> None: health_port: Port number for the health check server. """ # Third-Party - from starlette.applications import Starlette # pylint: disable=import-outside-toplevel - from starlette.requests import Request # pylint: disable=import-outside-toplevel - from starlette.responses import JSONResponse # pylint: disable=import-outside-toplevel - from starlette.routing import Route # pylint: disable=import-outside-toplevel + from starlette.applications import Starlette + from starlette.requests import Request + from starlette.responses import JSONResponse + from starlette.routing import Route - async def health_check(request: Request): # pylint: disable=unused-argument + async def health_check(request: Request): """Health check endpoint for container orchestration. Args: @@ -350,11 +192,11 @@ async def run_streamable_http_async(self) -> None: # Add health check endpoint to main app # Third-Party - from starlette.requests import Request # pylint: disable=import-outside-toplevel - from starlette.responses import JSONResponse # pylint: disable=import-outside-toplevel - from starlette.routing import Route # pylint: disable=import-outside-toplevel + from starlette.requests import Request + from starlette.responses import JSONResponse + from starlette.routing import Route - async def health_check(request: Request): # pylint: disable=unused-argument + async def health_check(request: Request): """Health check endpoint for container orchestration. Args: @@ -379,7 +221,7 @@ async def health_check(request: Request): # pylint: disable=unused-argument config_kwargs.update(ssl_config) logger.info(f"Starting plugin server on {self.settings.host}:{self.settings.port}") - config = uvicorn.Config(**config_kwargs) + config = uvicorn.Config(**config_kwargs) # type: ignore[arg-type] server = uvicorn.Server(config) # If SSL is enabled, start a separate HTTP health check server @@ -412,7 +254,7 @@ async def run(): Raises: Exception: If plugin server initialization or execution fails. """ - global SERVER # pylint: disable=global-statement + global SERVER # Initialize plugin server SERVER = ExternalPluginServer() @@ -445,12 +287,7 @@ async def run(): # Register module-level tool functions with FastMCP mcp.tool(name=GET_PLUGIN_CONFIGS)(get_plugin_configs) mcp.tool(name=GET_PLUGIN_CONFIG)(get_plugin_config) - mcp.tool(name=HookType.PROMPT_PRE_FETCH.value)(prompt_pre_fetch) - mcp.tool(name=HookType.PROMPT_POST_FETCH.value)(prompt_post_fetch) - mcp.tool(name=HookType.TOOL_PRE_INVOKE.value)(tool_pre_invoke) - mcp.tool(name=HookType.TOOL_POST_INVOKE.value)(tool_post_invoke) - mcp.tool(name=HookType.RESOURCE_PRE_FETCH.value)(resource_pre_fetch) - mcp.tool(name=HookType.RESOURCE_POST_FETCH.value)(resource_post_fetch) + mcp.tool(name=INVOKE_HOOK)(invoke_hook) # Run with stdio transport logger.info("Starting MCP plugin server with FastMCP (stdio transport)") @@ -467,12 +304,7 @@ async def run(): # Register module-level tool functions with FastMCP mcp.tool(name=GET_PLUGIN_CONFIGS)(get_plugin_configs) mcp.tool(name=GET_PLUGIN_CONFIG)(get_plugin_config) - mcp.tool(name=HookType.PROMPT_PRE_FETCH.value)(prompt_pre_fetch) - mcp.tool(name=HookType.PROMPT_POST_FETCH.value)(prompt_post_fetch) - mcp.tool(name=HookType.TOOL_PRE_INVOKE.value)(tool_pre_invoke) - mcp.tool(name=HookType.TOOL_POST_INVOKE.value)(tool_post_invoke) - mcp.tool(name=HookType.RESOURCE_PRE_FETCH.value)(resource_pre_fetch) - mcp.tool(name=HookType.RESOURCE_POST_FETCH.value)(resource_post_fetch) + mcp.tool(name=INVOKE_HOOK)(invoke_hook) # Run with streamable-http transport logger.info("Starting MCP plugin server with FastMCP (HTTP transport)") diff --git a/mcpgateway/plugins/framework/external/mcp/server/server.py b/mcpgateway/plugins/framework/external/mcp/server/server.py index 78dba8ce9..218d2a383 100644 --- a/mcpgateway/plugins/framework/external/mcp/server/server.py +++ b/mcpgateway/plugins/framework/external/mcp/server/server.py @@ -2,34 +2,27 @@ """Location: ./mcpgateway/plugins/framework/external/mcp/server/server.py Copyright 2025 SPDX-License-Identifier: Apache-2.0 -Authors: Teryl Taylor - -Plugin MCP Server. - Fred Araujo +Authors: Fred Araujo, Teryl Taylor Module that contains plugin MCP server code to serve external plugins. """ # Standard -import asyncio import logging import os -from typing import Any, Callable, Dict, Type, TypeVar +from typing import Any, Dict, TypeVar # Third-Party from pydantic import BaseModel # First-Party -from mcpgateway.plugins.framework.base import Plugin from mcpgateway.plugins.framework.constants import CONTEXT, ERROR, PLUGIN_NAME, RESULT -from mcpgateway.plugins.framework.errors import convert_exception_to_error +from mcpgateway.plugins.framework.errors import convert_exception_to_error, PluginError from mcpgateway.plugins.framework.loader.config import ConfigLoader -from mcpgateway.plugins.framework.manager import DEFAULT_PLUGIN_TIMEOUT, PluginManager +from mcpgateway.plugins.framework.manager import PluginManager from mcpgateway.plugins.framework.models import ( MCPServerConfig, PluginContext, - PluginErrorModel, - PluginResult, ) P = TypeVar("P", bound=BaseModel) @@ -48,7 +41,7 @@ def __init__(self, config_path: str | None = None) -> None: If set, this attribute overrides the value in PLUGINS_CONFIG_PATH. Examples: - >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway.plugins/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") >>> server is not None True """ @@ -64,47 +57,46 @@ async def get_plugin_configs(self) -> list[dict]: Examples: >>> import asyncio - >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway.plugins/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") >>> plugins = asyncio.run(server.get_plugin_configs()) >>> len(plugins) > 0 True """ plugins: list[dict] = [] - for plug in self._config.plugins: - plugins.append(plug.model_dump()) + if self._config.plugins: + for plug in self._config.plugins: + plugins.append(plug.model_dump()) return plugins - async def get_plugin_config(self, name: str) -> dict: + async def get_plugin_config(self, name: str) -> dict | None: """Return a plugin configuration give a plugin name. Args: name: The name of the plugin of which to return the plugin configuration. Returns: - A list of plugin configurations. + A plugin configuration dict, or None if not found. Examples: >>> import asyncio - >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway.plugins/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") >>> c = asyncio.run(server.get_plugin_config(name = "DenyListPlugin")) >>> c is not None True >>> c["name"] == "DenyListPlugin" True """ - for plug in self._config.plugins: - if plug.name.lower() == name.lower(): - return plug.model_dump() + if self._config.plugins: + for plug in self._config.plugins: + if plug.name.lower() == name.lower(): + return plug.model_dump() return None - async def invoke_hook( - self, payload_model: Type[P], hook_function: Callable[[Plugin], Callable[[P, PluginContext], PluginResult]], plugin_name: str, payload: Dict[str, Any], context: Dict[str, Any] - ) -> dict: + async def invoke_hook(self, hook_type: str, plugin_name: str, payload: Dict[str, Any], context: Dict[str, Any]) -> dict: """Invoke a plugin hook. Args: - payload_model: The type of the payload accepted for the hook. - hook_function: The hook function to be invoked. + hook_type: The type of hook function to be invoked. plugin_name: The name of the plugin to execute. payload: The prompt name and arguments to be analyzed. context: The contextual and state information required for the execution of the hook. @@ -120,10 +112,10 @@ async def invoke_hook( >>> import os >>> os.environ["PYTHONPATH"] = "." >>> from mcpgateway.plugins.framework import GlobalContext, PromptPrehookPayload, PluginContext, PromptPrehookResult - >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway.plugins/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") >>> def prompt_pre_fetch_func(plugin: Plugin, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: ... return plugin.prompt_pre_fetch(payload, context) - >>> payload = PromptPrehookPayload(prompt_id="test_id", args={"user": "This is so innovative"}) + >>> payload = PromptPrehookPayload(name="test_prompt", args={"user": "This is so innovative"}) >>> context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) >>> initialized = asyncio.run(server.initialize()) >>> initialized @@ -135,21 +127,18 @@ async def invoke_hook( False """ global_plugin_manager = PluginManager() - plugin_timeout = global_plugin_manager.config.plugin_settings.plugin_timeout if global_plugin_manager.config else DEFAULT_PLUGIN_TIMEOUT - plugin = global_plugin_manager.get_plugin(plugin_name) result_payload: dict[str, Any] = {PLUGIN_NAME: plugin_name} try: - if plugin: - _payload = payload_model.model_validate(payload) - _context = PluginContext.model_validate(context) - result = await asyncio.wait_for(hook_function(plugin, _payload, _context), plugin_timeout) - result_payload[RESULT] = result.model_dump() - if not _context.is_empty(): - result_payload[CONTEXT] = _context.model_dump() - return result_payload - raise ValueError(f"Unable to retrieve plugin {plugin_name} to execute.") - except asyncio.TimeoutError: - result_payload[ERROR] = PluginErrorModel(message=f"Plugin {plugin_name} timed out from execution after {plugin_timeout} seconds.", plugin_name=plugin_name).model_dump() + _context = PluginContext.model_validate(context) + + result = await global_plugin_manager.invoke_hook_for_plugin(plugin_name, hook_type, payload, _context, payload_as_json=True) + + result_payload[RESULT] = result.model_dump() + if not _context.is_empty(): + result_payload[CONTEXT] = _context.model_dump() + return result_payload + except PluginError as pe: + result_payload[ERROR] = pe.error return result_payload except Exception as ex: logger.exception(ex) diff --git a/mcpgateway/plugins/framework/hook_registry.py b/mcpgateway/plugins/framework/hook_registry.py new file mode 100644 index 000000000..a10008cd7 --- /dev/null +++ b/mcpgateway/plugins/framework/hook_registry.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/framework/hook_registry.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Hook Registry. +This module provides a global registry for mapping hook types to their +corresponding payload and result Pydantic models. This enables external +plugins to properly serialize/deserialize payloads without needing direct +access to the specific plugin implementations. +""" + +# Standard +from typing import Dict, Optional, Type, Union + +# First-Party +from mcpgateway.plugins.framework.models import PluginPayload, PluginResult + + +class HookRegistry: + """Global registry for hook type metadata. + + This singleton registry maintains mappings between hook type names and their + associated Pydantic models for payloads and results. It enables dynamic + serialization/deserialization for external plugins. + + Examples: + >>> from mcpgateway.plugins.framework import PluginPayload, PluginResult + >>> registry = HookRegistry() + >>> registry.register_hook("test_hook", PluginPayload, PluginResult) + >>> registry.get_payload_type("test_hook") + + >>> registry.get_result_type("test_hook") + + """ + + _instance: Optional["HookRegistry"] = None + _hook_payloads: Dict[str, Type[PluginPayload]] = {} + _hook_results: Dict[str, Type[PluginResult]] = {} + + def __new__(cls) -> "HookRegistry": + """Ensure singleton pattern for the registry. + + Returns: + The singleton HookRegistry instance. + """ + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def register_hook( + self, + hook_type: str, + payload_class: Type[PluginPayload], + result_class: Type[PluginResult], + ) -> None: + """Register a hook type with its payload and result classes. + + Args: + hook_type: The hook type identifier (e.g., "prompt_pre_fetch"). + payload_class: The Pydantic model class for the hook's payload. + result_class: The Pydantic model class for the hook's result. + + Examples: + >>> registry = HookRegistry() + >>> from mcpgateway.plugins.framework import PluginPayload, PluginResult + >>> registry.register_hook("custom_hook", PluginPayload, PluginResult) + """ + self._hook_payloads[hook_type] = payload_class + self._hook_results[hook_type] = result_class + + def get_payload_type(self, hook_type: str) -> Optional[Type[PluginPayload]]: + """Get the payload class for a hook type. + + Args: + hook_type: The hook type identifier. + + Returns: + The Pydantic payload class, or None if not registered. + + Examples: + >>> registry = HookRegistry() + >>> registry.get_payload_type("unknown_hook") + """ + return self._hook_payloads.get(hook_type) + + def get_result_type(self, hook_type: str) -> Optional[Type[PluginResult]]: + """Get the result class for a hook type. + + Args: + hook_type: The hook type identifier. + + Returns: + The Pydantic result class, or None if not registered. + + Examples: + >>> registry = HookRegistry() + >>> registry.get_result_type("unknown_hook") + """ + return self._hook_results.get(hook_type) + + def json_to_payload(self, hook_type: str, payload: Union[str, dict]) -> PluginPayload: + """Convert JSON to the appropriate payload Pydantic model. + + Args: + hook_type: The hook type identifier. + payload: The payload as JSON string or dictionary. + + Returns: + The deserialized Pydantic payload object. + + Raises: + ValueError: If the hook type is not registered. + + Examples: + >>> registry = HookRegistry() + >>> from mcpgateway.plugins.framework import PluginPayload + >>> registry.register_hook("test", PluginPayload, PluginResult) + >>> payload = registry.json_to_payload("test", "{}") + """ + payload_class = self.get_payload_type(hook_type) + if not payload_class: + raise ValueError(f"No payload type registered for hook: {hook_type}") + + if isinstance(payload, str): + return payload_class.model_validate_json(payload) + return payload_class.model_validate(payload) + + def json_to_result(self, hook_type: str, result: Union[str, dict]) -> PluginResult: + """Convert JSON to the appropriate result Pydantic model. + + Args: + hook_type: The hook type identifier. + result: The result as JSON string or dictionary. + + Returns: + The deserialized Pydantic result object. + + Raises: + ValueError: If the hook type is not registered. + + Examples: + >>> registry = HookRegistry() + >>> from mcpgateway.plugins.framework import PluginResult + >>> registry.register_hook("test", PluginPayload, PluginResult) + >>> result = registry.json_to_result("test", '{"continue_processing": true}') + """ + result_class = self.get_result_type(hook_type) + if not result_class: + raise ValueError(f"No result type registered for hook: {hook_type}") + + if isinstance(result, str): + return result_class.model_validate_json(result) + return result_class.model_validate(result) + + def is_registered(self, hook_type: str) -> bool: + """Check if a hook type is registered. + + Args: + hook_type: The hook type identifier. + + Returns: + True if the hook is registered, False otherwise. + + Examples: + >>> registry = HookRegistry() + >>> registry.is_registered("unknown") + False + """ + return hook_type in self._hook_payloads and hook_type in self._hook_results + + def get_registered_hooks(self) -> list[str]: + """Get all registered hook types. + + Returns: + List of registered hook type identifiers. + + Examples: + >>> registry = HookRegistry() + >>> hooks = registry.get_registered_hooks() + >>> isinstance(hooks, list) + True + """ + return list(self._hook_payloads.keys()) + + +# Global singleton instance +_global_registry = HookRegistry() + + +def get_hook_registry() -> HookRegistry: + """Get the global hook registry instance. + + Returns: + The singleton HookRegistry instance. + + Examples: + >>> registry = get_hook_registry() + >>> isinstance(registry, HookRegistry) + True + """ + return _global_registry diff --git a/mcpgateway/plugins/framework/loader/plugin.py b/mcpgateway/plugins/framework/loader/plugin.py index c1dbdc170..1fd9bd9c0 100644 --- a/mcpgateway/plugins/framework/loader/plugin.py +++ b/mcpgateway/plugins/framework/loader/plugin.py @@ -72,6 +72,7 @@ def __register_plugin_type(self, kind: str) -> None: kind: The fully-qualified type of the plugin to be registered. """ if kind not in self._plugin_types: + plugin_type: Type[Plugin] if kind == EXTERNAL_PLUGIN_TYPE: plugin_type = ExternalPlugin else: diff --git a/mcpgateway/plugins/framework/manager.py b/mcpgateway/plugins/framework/manager.py index 9287effee..8ef940717 100644 --- a/mcpgateway/plugins/framework/manager.py +++ b/mcpgateway/plugins/framework/manager.py @@ -20,8 +20,9 @@ >>> # await manager.initialize() # Called in async context >>> # Create test payload and context - >>> from mcpgateway.plugins.framework.models import PromptPrehookPayload, GlobalContext - >>> payload = PromptPrehookPayload(prompt_id="test", name="test", args={"user": "input"}) + >>> from mcpgateway.plugins.framework.models import GlobalContext + >>> from mcpgateway.plugins.mcp.entities.models import PromptPrehookPayload + >>> payload = PromptPrehookPayload(name="test", args={"user": "input"}) >>> context = GlobalContext(request_id="123") >>> # result, contexts = await manager.prompt_pre_fetch(payload, context) # Called in async context """ @@ -30,61 +31,28 @@ import asyncio from copy import deepcopy import logging -import time -from typing import Any, Callable, Coroutine, Dict, Generic, Optional, Tuple, TypeVar +from typing import Any, Optional, Union # First-Party -from mcpgateway.plugins.framework.base import Plugin, PluginRef +from mcpgateway.plugins.framework.base import HookRef, Plugin from mcpgateway.plugins.framework.errors import convert_exception_to_error, PluginError, PluginViolationError from mcpgateway.plugins.framework.loader.config import ConfigLoader from mcpgateway.plugins.framework.loader.plugin import PluginLoader from mcpgateway.plugins.framework.models import ( Config, GlobalContext, - HookType, - PluginCondition, PluginContext, PluginContextTable, PluginErrorModel, PluginMode, + PluginPayload, PluginResult, - PromptPosthookPayload, - PromptPosthookResult, - PromptPrehookPayload, - PromptPrehookResult, - ResourcePostFetchPayload, - ResourcePostFetchResult, - ResourcePreFetchPayload, - ResourcePreFetchResult, - ToolPostInvokePayload, - ToolPostInvokeResult, - ToolPreInvokePayload, - ToolPreInvokeResult, ) from mcpgateway.plugins.framework.registry import PluginInstanceRegistry -from mcpgateway.plugins.framework.utils import ( - post_prompt_matches, - post_resource_matches, - post_tool_matches, - pre_prompt_matches, - pre_resource_matches, - pre_tool_matches, -) # Use standard logging to avoid circular imports (plugins -> services -> plugins) logger = logging.getLogger(__name__) -T = TypeVar( - "T", - PromptPosthookPayload, - PromptPrehookPayload, - ResourcePostFetchPayload, - ResourcePreFetchPayload, - ToolPostInvokePayload, - ToolPreInvokePayload, -) - - # Configuration constants DEFAULT_PLUGIN_TIMEOUT = 30 # seconds MAX_PAYLOAD_SIZE = 1_000_000 # 1MB @@ -100,7 +68,7 @@ class PayloadSizeError(ValueError): """Raised when a payload exceeds the maximum allowed size.""" -class PluginExecutor(Generic[T]): +class PluginExecutor: """Executes a list of plugins with timeout protection and error handling. This class manages the execution of plugins in priority order, handling: @@ -110,7 +78,7 @@ class PluginExecutor(Generic[T]): - Metadata aggregation from multiple plugins Examples: - >>> from mcpgateway.plugins.framework import PromptPrehookPayload + >>> from mcpgateway.plugins.mcp.entities.models import PromptPrehookPayload >>> executor = PluginExecutor[PromptPrehookPayload]() >>> # In async context: >>> # result, contexts = await executor.execute( @@ -134,22 +102,18 @@ def __init__(self, config: Optional[Config] = None, timeout: int = DEFAULT_PLUGI async def execute( self, - plugins: list[PluginRef], - payload: T, + hook_refs: list[HookRef], + payload: PluginPayload, global_context: GlobalContext, - plugin_run: Callable[[PluginRef, T, PluginContext], Coroutine[Any, Any, PluginResult[T]]], - compare: Callable[[T, list[PluginCondition], GlobalContext], bool], local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False, - ) -> tuple[PluginResult[T], PluginContextTable | None]: + ) -> tuple[PluginResult, PluginContextTable | None]: """Execute plugins in priority order with timeout protection. Args: plugins: List of plugins to execute, sorted by priority. payload: The payload to be processed by plugins. global_context: Shared context for all plugins containing request metadata. - plugin_run: Async function to execute a specific plugin hook. - compare: Function to check if plugin conditions match the current context. local_contexts: Optional existing contexts from previous hook executions. violations_as_exceptions: Raise violations as exceptions rather than as returns. @@ -165,39 +129,38 @@ async def execute( Examples: >>> # Execute plugins with timeout protection - >>> from mcpgateway.plugins.framework import HookType + >>> from mcpgateway.plugins.mcp.entities.models import HookType >>> executor = PluginExecutor(timeout=30) >>> # Assuming you have a registry instance: >>> # plugins = registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) >>> # In async context: >>> # result, contexts = await executor.execute( >>> # plugins=plugins, - >>> # payload=PromptPrehookPayload(prompt_id="123", args={}), + >>> # payload=PromptPrehookPayload(name="test", args={}), >>> # global_context=GlobalContext(request_id="123"), >>> # plugin_run=pre_prompt_fetch, >>> # compare=pre_prompt_matches >>> # ) """ - if not plugins: - return (PluginResult[T](modified_payload=None), None) + if not hook_refs: + return (PluginResult(modified_payload=None), None) # Validate payload size self._validate_payload_size(payload) res_local_contexts = {} - combined_metadata = {} - current_payload: T | None = None + combined_metadata: dict[str, Any] = {} + current_payload: PluginPayload | None = None - for pluginref in plugins: + for hook_ref in hook_refs: # Skip disabled plugins - if pluginref.mode == PluginMode.DISABLED: - logger.debug(f"Skipping disabled plugin {pluginref.name}") + if hook_ref.plugin_ref.mode == PluginMode.DISABLED: continue # Check if plugin conditions match current context - if pluginref.conditions and not compare(payload, pluginref.conditions, global_context): - logger.debug(f"Skipping plugin {pluginref.name} - conditions not met") - continue + # if pluginref.conditions and not compare(payload, pluginref.conditions, global_context): + # logger.debug(f"Skipping plugin {pluginref.name} - conditions not met") + # continue tmp_global_context = GlobalContext( request_id=global_context.request_id, @@ -208,7 +171,7 @@ async def execute( metadata={} if not global_context.metadata else deepcopy(global_context.metadata), ) # Get or create local context for this plugin - local_context_key = global_context.request_id + pluginref.uuid + local_context_key = global_context.request_id + hook_ref.plugin_ref.uuid if local_contexts and local_context_key in local_contexts: local_context = local_contexts[local_context_key] local_context.global_context = tmp_global_context @@ -216,68 +179,130 @@ async def execute( local_context = PluginContext(global_context=tmp_global_context) res_local_contexts[local_context_key] = local_context - try: - # Execute plugin with timeout protection - result = await self._execute_with_timeout(pluginref, plugin_run, current_payload or payload, local_context) - if local_context.global_context: - global_context.state.update(local_context.global_context.state) - global_context.metadata.update(local_context.global_context.metadata) - # Aggregate metadata from all plugins - if result.metadata: - combined_metadata.update(result.metadata) - - # Track payload modifications - if result.modified_payload is not None: - current_payload = result.modified_payload - - # Set plugin name in violation if present - if result.violation: - result.violation.plugin_name = pluginref.plugin.name - - # Handle plugin blocking the request - if not result.continue_processing: - if pluginref.plugin.mode == PluginMode.ENFORCE: - logger.warning(f"Plugin {pluginref.plugin.name} blocked request in enforce mode") - if violations_as_exceptions: - if result.violation: - plugin_name = result.violation.plugin_name - violation_reason = result.violation.reason - violation_desc = result.violation.description - violation_code = result.violation.code - raise PluginViolationError( - f"{plugin_run.__name__} blocked by plugin {plugin_name}: {violation_code} - {violation_reason} ({violation_desc})", violation=result.violation - ) - raise PluginViolationError(f"{plugin_run.__name__} blocked by plugin") - return (PluginResult[T](continue_processing=False, modified_payload=current_payload, violation=result.violation, metadata=combined_metadata), res_local_contexts) - if pluginref.plugin.mode == PluginMode.PERMISSIVE: - logger.warning(f"Plugin {pluginref.plugin.name} would block (permissive mode): {result.violation.description if result.violation else 'No description'}") - - except asyncio.TimeoutError: - logger.error(f"Plugin {pluginref.name} timed out after {self.timeout}s") - if self.config.plugin_settings.fail_on_plugin_error or pluginref.plugin.mode == PluginMode.ENFORCE: - raise PluginError(error=PluginErrorModel(message=f"Plugin {pluginref.name} exceeded {self.timeout}s timeout", plugin_name=pluginref.name)) - # In permissive or enforce_ignore_error mode, continue with next plugin - continue - except PluginViolationError: - raise - except PluginError as pe: - logger.error(f"Plugin {pluginref.name} failed with error: {str(pe)}", exc_info=True) - if self.config.plugin_settings.fail_on_plugin_error or pluginref.plugin.mode == PluginMode.ENFORCE: - raise - except Exception as e: - logger.error(f"Plugin {pluginref.name} failed with error: {str(e)}", exc_info=True) - if self.config.plugin_settings.fail_on_plugin_error or pluginref.plugin.mode == PluginMode.ENFORCE: - raise PluginError(error=convert_exception_to_error(e, pluginref.name)) - # In permissive or enforce_ignore_error mode, continue with next plugin - continue + # Execute plugin with timeout protection + result = await self.execute_plugin( + hook_ref, + current_payload or payload, + local_context, + violations_as_exceptions, + global_context, + combined_metadata, + ) + # Track payload modifications + if result.modified_payload is not None: + current_payload = result.modified_payload + if not result.continue_processing and hook_ref.plugin_ref.plugin.mode == PluginMode.ENFORCE: + return (result, res_local_contexts) + + return ( + PluginResult(continue_processing=True, modified_payload=current_payload, violation=None, metadata=combined_metadata), + res_local_contexts, + ) + + async def execute_plugin( + self, + hook_ref: HookRef, + payload: PluginPayload, + local_context: PluginContext, + violations_as_exceptions: bool, + global_context: Optional[GlobalContext] = None, + combined_metadata: Optional[dict[str, Any]] = None, + ) -> PluginResult: + """Execute a single plugin with timeout protection. + + Args: + hook_ref: Hooking structure that contains the plugin and hook. + payload: The payload to be processed by plugins. + local_context: local context. + violations_as_exceptions: Raise violations as exceptions rather than as returns. + global_context: Shared context for all plugins containing request metadata. + combined_metadata: combination of the metadata of all plugins. - return (PluginResult[T](continue_processing=True, modified_payload=current_payload, violation=None, metadata=combined_metadata), res_local_contexts) + Returns: + A tuple containing: + - PluginResult with processing status, modified payload, and metadata + - PluginContextTable with updated local contexts for each plugin - async def _execute_with_timeout(self, pluginref: PluginRef, plugin_run: Callable, payload: T, context: PluginContext) -> PluginResult[T]: + Raises: + PayloadSizeError: If the payload exceeds MAX_PAYLOAD_SIZE. + PluginError: If there is an error inside a plugin. + PluginViolationError: If a violation occurs and violation_as_exceptions is set. + """ + try: + # Execute plugin with timeout protection + result = await self._execute_with_timeout(hook_ref, payload, local_context) + if local_context.global_context and global_context: + global_context.state.update(local_context.global_context.state) + global_context.metadata.update(local_context.global_context.metadata) + # Aggregate metadata from all plugins + if result.metadata and combined_metadata is not None: + combined_metadata.update(result.metadata) + + # Track payload modifications + # if result.modified_payload is not None: + # current_payload = result.modified_payload + + # Set plugin name in violation if present + if result.violation: + result.violation.plugin_name = hook_ref.plugin_ref.plugin.name + + # Handle plugin blocking the request + if not result.continue_processing: + if hook_ref.plugin_ref.plugin.mode == PluginMode.ENFORCE: + logger.warning("Plugin %s blocked request in enforce mode", hook_ref.plugin_ref.plugin.name) + if violations_as_exceptions: + if result.violation: + plugin_name = result.violation.plugin_name + violation_reason = result.violation.reason + violation_desc = result.violation.description + violation_code = result.violation.code + raise PluginViolationError( + f"{hook_ref.name} blocked by plugin {plugin_name}: {violation_code} - {violation_reason} ({violation_desc})", + violation=result.violation, + ) + raise PluginViolationError(f"{hook_ref.name} blocked by plugin") + return PluginResult( + continue_processing=False, + modified_payload=payload, + violation=result.violation, + metadata=combined_metadata, + ) + if hook_ref.plugin_ref.plugin.mode == PluginMode.PERMISSIVE: + logger.warning( + "Plugin %s would block (permissive mode): %s", + hook_ref.plugin_ref.plugin.name, + result.violation.description if result.violation else "No description", + ) + return result + except asyncio.TimeoutError as exc: + logger.error("Plugin %s timed out after %ds", hook_ref.plugin_ref.name, self.timeout) + if (self.config and self.config.plugin_settings.fail_on_plugin_error) or hook_ref.plugin_ref.plugin.mode == PluginMode.ENFORCE: + raise PluginError( + error=PluginErrorModel( + message=f"Plugin {hook_ref.plugin_ref.name} exceeded {self.timeout}s timeout", + plugin_name=hook_ref.plugin_ref.name, + ) + ) from exc + # In permissive or enforce_ignore_error mode, continue with next plugin + except PluginViolationError: + raise + except PluginError as pe: + logger.error("Plugin %s failed with error: %s", hook_ref.plugin_ref.name, str(pe), exc_info=True) + if (self.config and self.config.plugin_settings.fail_on_plugin_error) or hook_ref.plugin_ref.plugin.mode == PluginMode.ENFORCE: + raise + except Exception as e: + logger.error("Plugin %s failed with error: %s", hook_ref.plugin_ref.name, str(e), exc_info=True) + if (self.config and self.config.plugin_settings.fail_on_plugin_error) or hook_ref.plugin_ref.plugin.mode == PluginMode.ENFORCE: + raise PluginError(error=convert_exception_to_error(e, hook_ref.plugin_ref.name)) from e + # In permissive or enforce_ignore_error mode, continue with next plugin + # Return a result indicating processing should continue despite the error + return PluginResult(continue_processing=True) + + async def _execute_with_timeout(self, hook_ref: HookRef, payload: PluginPayload, context: PluginContext) -> PluginResult: """Execute a plugin with timeout protection. Args: - pluginref: Reference to the plugin to execute. + hook_ref: Reference to the hook and plugin to execute. plugin_run: Function to execute the plugin. payload: Payload to process. context: Plugin execution context. @@ -288,7 +313,7 @@ async def _execute_with_timeout(self, pluginref: PluginRef, plugin_run: Callable Raises: asyncio.TimeoutError: If plugin exceeds timeout. """ - return await asyncio.wait_for(plugin_run(pluginref, payload, context), timeout=self.timeout) + return await asyncio.wait_for(hook_ref.hook(payload, context), timeout=self.timeout) def _validate_payload_size(self, payload: Any) -> None: """Validate that payload doesn't exceed size limits. @@ -312,154 +337,6 @@ def _validate_payload_size(self, payload: Any) -> None: raise PayloadSizeError(f"Result size {total_size} exceeds limit of {MAX_PAYLOAD_SIZE} bytes") -async def pre_prompt_fetch(plugin: PluginRef, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: - """Call plugin's prompt pre-fetch hook. - - Args: - plugin: The plugin to execute. - payload: The prompt payload to be analyzed. - context: Contextual information about the hook call. - - Returns: - The result of the plugin execution. - - Examples: - >>> from mcpgateway.plugins.framework.base import PluginRef - >>> from mcpgateway.plugins.framework import GlobalContext, Plugin, PromptPrehookPayload, PluginContext, GlobalContext - >>> # Assuming you have a plugin instance: - >>> # plugin_ref = PluginRef(my_plugin) - >>> payload = PromptPrehookPayload(prompt_id="123", args={"key": "value"}) - >>> context = PluginContext(global_context=GlobalContext(request_id="123")) - >>> # In async context: - >>> # result = await pre_prompt_fetch(plugin_ref, payload, context) - """ - return await plugin.plugin.prompt_pre_fetch(payload, context) - - -async def post_prompt_fetch(plugin: PluginRef, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult: - """Call plugin's prompt post-fetch hook. - - Args: - plugin: The plugin to execute. - payload: The prompt payload to be analyzed. - context: Contextual information about the hook call. - - Returns: - The result of the plugin execution. - - Examples: - >>> from mcpgateway.plugins.framework.base import PluginRef - >>> from mcpgateway.plugins.framework import GlobalContext, Plugin, PromptPosthookPayload, PluginContext, GlobalContext - >>> from mcpgateway.models import PromptResult - >>> # Assuming you have a plugin instance: - >>> # plugin_ref = PluginRef(my_plugin) - >>> result = PromptResult(messages=[]) - >>> payload = PromptPosthookPayload(prompt_id="123", result=result) - >>> context = PluginContext(global_context=GlobalContext(request_id="123")) - >>> # In async context: - >>> # result = await post_prompt_fetch(plugin_ref, payload, context) - """ - return await plugin.plugin.prompt_post_fetch(payload, context) - - -async def pre_tool_invoke(plugin: PluginRef, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult: - """Call plugin's tool pre-invoke hook. - - Args: - plugin: The plugin to execute. - payload: The tool payload to be analyzed. - context: Contextual information about the hook call. - - Returns: - The result of the plugin execution. - - Examples: - >>> from mcpgateway.plugins.framework.base import PluginRef - >>> from mcpgateway.plugins.framework import GlobalContext, Plugin, ToolPreInvokePayload, PluginContext, GlobalContext - >>> # Assuming you have a plugin instance: - >>> # plugin_ref = PluginRef(my_plugin) - >>> payload = ToolPreInvokePayload(name="calculator", args={"operation": "add", "a": 5, "b": 3}) - >>> context = PluginContext(global_context=GlobalContext(request_id="123")) - >>> # In async context: - >>> # result = await pre_tool_invoke(plugin_ref, payload, context) - """ - return await plugin.plugin.tool_pre_invoke(payload, context) - - -async def post_tool_invoke(plugin: PluginRef, payload: ToolPostInvokePayload, context: PluginContext) -> ToolPostInvokeResult: - """Call plugin's tool post-invoke hook. - - Args: - plugin: The plugin to execute. - payload: The tool result payload to be analyzed. - context: Contextual information about the hook call. - - Returns: - The result of the plugin execution. - - Examples: - >>> from mcpgateway.plugins.framework.base import PluginRef - >>> from mcpgateway.plugins.framework import GlobalContext, Plugin, ToolPostInvokePayload, PluginContext, GlobalContext - >>> # Assuming you have a plugin instance: - >>> # plugin_ref = PluginRef(my_plugin) - >>> payload = ToolPostInvokePayload(name="calculator", result={"result": 8, "status": "success"}) - >>> context = PluginContext(global_context=GlobalContext(request_id="123")) - >>> # In async context: - >>> # result = await post_tool_invoke(plugin_ref, payload, context) - """ - return await plugin.plugin.tool_post_invoke(payload, context) - - -async def pre_resource_fetch(plugin: PluginRef, payload: ResourcePreFetchPayload, context: PluginContext) -> ResourcePreFetchResult: - """Call plugin's resource pre-fetch hook. - - Args: - plugin: The plugin to execute. - payload: The resource payload to be analyzed. - context: The plugin context. - - Returns: - ResourcePreFetchResult with processing status. - - Examples: - >>> from mcpgateway.plugins.framework.base import PluginRef - >>> from mcpgateway.plugins.framework import GlobalContext, Plugin, ResourcePreFetchPayload, PluginContext, GlobalContext - >>> # Assuming you have a plugin instance: - >>> # plugin_ref = PluginRef(my_plugin) - >>> payload = ResourcePreFetchPayload(uri="file:///data.txt", metadata={"cache": True}) - >>> context = PluginContext(global_context=GlobalContext(request_id="123")) - >>> # In async context: - >>> # result = await pre_resource_fetch(plugin_ref, payload, context) - """ - return await plugin.plugin.resource_pre_fetch(payload, context) - - -async def post_resource_fetch(plugin: PluginRef, payload: ResourcePostFetchPayload, context: PluginContext) -> ResourcePostFetchResult: - """Call plugin's resource post-fetch hook. - - Args: - plugin: The plugin to execute. - payload: The resource content payload to be analyzed. - context: The plugin context. - - Returns: - ResourcePostFetchResult with processing status. - - Examples: - >>> from mcpgateway.plugins.framework.base import PluginRef - >>> from mcpgateway.plugins.framework import GlobalContext, Plugin, ResourcePostFetchPayload, PluginContext, GlobalContext - >>> from mcpgateway.models import ResourceContent - >>> # Assuming you have a plugin instance: - >>> # plugin_ref = PluginRef(my_plugin) - >>> content = ResourceContent(type="resource", id="res-1", uri="file:///data.txt", text="Data") - >>> payload = ResourcePostFetchPayload(uri="file:///data.txt", content=content) - >>> context = PluginContext(global_context=GlobalContext(request_id="123")) - >>> # In async context: - >>> # result = await post_resource_fetch(plugin_ref, payload, context) - """ - return await plugin.plugin.resource_post_fetch(payload, context) - - class PluginManager: """Plugin manager for managing the plugin lifecycle. @@ -483,8 +360,9 @@ class PluginManager: >>> # print(f"Loaded {manager.plugin_count} plugins") >>> >>> # Execute prompt hooks - >>> from mcpgateway.plugins.framework import PromptPrehookPayload, GlobalContext - >>> payload = PromptPrehookPayload(prompt_id="123", args={}) + >>> from mcpgateway.plugins.framework.models import GlobalContext + >>> from mcpgateway.plugins.mcp.entities.models import PromptPrehookPayload + >>> payload = PromptPrehookPayload(name="test", args={}) >>> context = GlobalContext(request_id="req-123") >>> # In async context: >>> # result, contexts = await manager.prompt_pre_fetch(payload, context) @@ -498,16 +376,7 @@ class PluginManager: _initialized: bool = False _registry: PluginInstanceRegistry = PluginInstanceRegistry() _config: Config | None = None - _pre_prompt_executor: PluginExecutor[PromptPrehookPayload] = PluginExecutor[PromptPrehookPayload]() - _post_prompt_executor: PluginExecutor[PromptPosthookPayload] = PluginExecutor[PromptPosthookPayload]() - _pre_tool_executor: PluginExecutor[ToolPreInvokePayload] = PluginExecutor[ToolPreInvokePayload]() - _post_tool_executor: PluginExecutor[ToolPostInvokePayload] = PluginExecutor[ToolPostInvokePayload]() - _resource_pre_executor: PluginExecutor[ResourcePreFetchPayload] = PluginExecutor[ResourcePreFetchPayload]() - _resource_post_executor: PluginExecutor[ResourcePostFetchPayload] = PluginExecutor[ResourcePostFetchPayload]() - - # Context cleanup tracking - _context_store: Dict[str, Tuple[PluginContextTable, float]] = {} - _last_cleanup: float = 0 + _executor: PluginExecutor = PluginExecutor() def __init__(self, config: str = "", timeout: int = DEFAULT_PLUGIN_TIMEOUT): """Initialize plugin manager. @@ -528,23 +397,8 @@ def __init__(self, config: str = "", timeout: int = DEFAULT_PLUGIN_TIMEOUT): self._config = ConfigLoader.load_config(config) # Update executor timeouts - self._pre_prompt_executor.timeout = timeout - self._post_prompt_executor.timeout = timeout - self._pre_tool_executor.timeout = timeout - self._post_tool_executor.timeout = timeout - self._resource_pre_executor.timeout = timeout - self._resource_post_executor.timeout = timeout - self._pre_prompt_executor.config = self._config - self._post_prompt_executor.config = self._config - self._pre_tool_executor.config = self._config - self._post_tool_executor.config = self._config - self._resource_pre_executor.config = self._config - self._resource_post_executor.config = self._config - - # Initialize context tracking if not already done - if not hasattr(self, "_context_store"): - self._context_store = {} - self._last_cleanup = time.time() + self._executor.config = self._config + self._executor.timeout = timeout @property def config(self) -> Config | None: @@ -620,20 +474,20 @@ async def initialize(self) -> None: if plugin: self._registry.register(plugin) loaded_count += 1 - logger.info(f"Loaded plugin: {plugin_config.name} (mode: {plugin_config.mode})") + logger.info("Loaded plugin: %s (mode: %s)", plugin_config.name, plugin_config.mode) else: raise ValueError(f"Unable to instantiate plugin: {plugin_config.name}") else: - logger.info(f"Plugin: {plugin_config.name} is disabled. Ignoring.") + logger.info("Plugin: %s is disabled. Ignoring.", plugin_config.name) except Exception as e: # Clean error message without stack trace spam - logger.error(f"Failed to load plugin '{plugin_config.name}': {str(e)}") + logger.error("Failed to load plugin %s: {%s}", plugin_config.name, str(e)) # Let it crash gracefully with a clean error - raise RuntimeError(f"Plugin initialization failed: {plugin_config.name} - {str(e)}") + raise RuntimeError(f"Plugin initialization failed: {plugin_config.name} - {str(e)}") from e self._initialized = True - logger.info(f"Plugin manager initialized with {loaded_count} plugins") + logger.info("Plugin manager initialized with %s plugins", loaded_count) async def shutdown(self) -> None: """Shutdown all plugins and cleanup resources. @@ -657,275 +511,30 @@ async def shutdown(self) -> None: await self._registry.shutdown() # Clear context store - self._context_store.clear() # Reset state self._initialized = False logger.info("Plugin manager shutdown complete") - async def _cleanup_old_contexts(self) -> None: - """Remove contexts older than CONTEXT_MAX_AGE to prevent memory leaks. - - This method is called periodically during hook execution to clean up - stale contexts that are no longer needed. - """ - current_time = time.time() - - # Only cleanup every CONTEXT_CLEANUP_INTERVAL seconds - if current_time - self._last_cleanup < CONTEXT_CLEANUP_INTERVAL: - return - - # Find expired contexts - expired_keys = [key for key, (_, timestamp) in self._context_store.items() if current_time - timestamp > CONTEXT_MAX_AGE] - - # Remove expired contexts - for key in expired_keys: - del self._context_store[key] - - if expired_keys: - logger.info(f"Cleaned up {len(expired_keys)} expired plugin contexts") - - self._last_cleanup = current_time - - async def prompt_pre_fetch( - self, payload: PromptPrehookPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False - ) -> tuple[PromptPrehookResult, PluginContextTable | None]: - """Execute pre-fetch hooks before a prompt is retrieved and rendered. - - Args: - payload: The prompt payload containing name and arguments. - global_context: Shared context for all plugins with request metadata. - local_contexts: Optional existing contexts from previous executions. - violations_as_exceptions: Raise violations as exceptions rather than as returns. - - Returns: - A tuple containing: - - PromptPrehookResult with processing status and modified payload - - PluginContextTable with updated contexts for post-fetch hook - - Raises: - PayloadSizeError: If payload exceeds size limits. - - Examples: - >>> manager = PluginManager("plugins/config.yaml") - >>> # In async context: - >>> # await manager.initialize() - >>> - >>> from mcpgateway.plugins.framework import PromptPrehookPayload, GlobalContext - >>> payload = PromptPrehookPayload( - ... prompt_id="123", - ... name="greeting", - ... args={"user": "Alice"} - ... ) - >>> context = GlobalContext( - ... request_id="req-123", - ... user="alice@example.com" - ... ) - >>> - >>> # In async context: - >>> # result, contexts = await manager.prompt_pre_fetch(payload, context) - >>> # if result.continue_processing: - >>> # # Proceed with prompt processing - >>> # modified_payload = result.modified_payload or payload - """ - # Cleanup old contexts periodically - await self._cleanup_old_contexts() - - # Get plugins configured for this hook - plugins = self._registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) - - # Execute plugins - result = await self._pre_prompt_executor.execute(plugins, payload, global_context, pre_prompt_fetch, pre_prompt_matches, local_contexts, violations_as_exceptions) - - # Store contexts for potential reuse - if result[1]: - self._context_store[global_context.request_id] = (result[1], time.time()) - - return result - - async def prompt_post_fetch( - self, payload: PromptPosthookPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False - ) -> tuple[PromptPosthookResult, PluginContextTable | None]: - """Execute post-fetch hooks after a prompt is rendered. - - Args: - payload: The prompt result payload containing rendered messages. - global_context: Shared context for all plugins with request metadata. - local_contexts: Optional contexts from pre-fetch hook execution. - violations_as_exceptions: Raise violations as exceptions rather than as returns. - - Returns: - A tuple containing: - - PromptPosthookResult with processing status and modified result - - PluginContextTable with final contexts - - Raises: - PayloadSizeError: If payload exceeds size limits. - - Examples: - >>> # Continuing from prompt_pre_fetch example - >>> from mcpgateway.models import PromptResult, Message, TextContent, Role - >>> from mcpgateway.plugins.framework import PromptPosthookPayload, GlobalContext - >>> - >>> # Create a proper Message with TextContent - >>> message = Message( - ... role=Role.USER, - ... content=TextContent(type="text", text="Hello") - ... ) - >>> prompt_result = PromptResult(messages=[message]) - >>> - >>> post_payload = PromptPosthookPayload( - ... prompt_id="123", - ... result=prompt_result - ... ) - >>> - >>> manager = PluginManager("plugins/config.yaml") - >>> context = GlobalContext(request_id="req-123") - >>> - >>> # In async context: - >>> # result, _ = await manager.prompt_post_fetch( - >>> # post_payload, - >>> # context, - >>> # contexts # From pre_fetch - >>> # ) - >>> # if result.modified_payload: - >>> # # Use modified result - >>> # final_result = result.modified_payload.result - """ - # Get plugins configured for this hook - plugins = self._registry.get_plugins_for_hook(HookType.PROMPT_POST_FETCH) - - # Execute plugins - result = await self._post_prompt_executor.execute(plugins, payload, global_context, post_prompt_fetch, post_prompt_matches, local_contexts, violations_as_exceptions) - - # Clean up stored context after post-fetch - if global_context.request_id in self._context_store: - del self._context_store[global_context.request_id] - - return result - - async def tool_pre_invoke( - self, payload: ToolPreInvokePayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False - ) -> tuple[ToolPreInvokeResult, PluginContextTable | None]: - """Execute pre-invoke hooks before a tool is invoked. - - Args: - payload: The tool payload containing name and arguments. - global_context: Shared context for all plugins with request metadata. - local_contexts: Optional existing contexts from previous executions. - violations_as_exceptions: Raise violations as exceptions rather than as returns. - - Returns: - A tuple containing: - - ToolPreInvokeResult with processing status and modified payload - - PluginContextTable with updated contexts for post-invoke hook - - Raises: - PayloadSizeError: If payload exceeds size limits. - - Examples: - >>> manager = PluginManager("plugins/config.yaml") - >>> # In async context: - >>> # await manager.initialize() - >>> - >>> from mcpgateway.plugins.framework import ToolPreInvokePayload, GlobalContext - >>> payload = ToolPreInvokePayload( - ... name="calculator", - ... args={"operation": "add", "a": 5, "b": 3} - ... ) - >>> context = GlobalContext( - ... request_id="req-123", - ... user="alice@example.com" - ... ) - >>> - >>> # In async context: - >>> # result, contexts = await manager.tool_pre_invoke(payload, context) - >>> # if result.continue_processing: - >>> # # Proceed with tool invocation - >>> # modified_payload = result.modified_payload or payload - """ - # Cleanup old contexts periodically - await self._cleanup_old_contexts() - - # Get plugins configured for this hook - plugins = self._registry.get_plugins_for_hook(HookType.TOOL_PRE_INVOKE) - - # Execute plugins - result = await self._pre_tool_executor.execute(plugins, payload, global_context, pre_tool_invoke, pre_tool_matches, local_contexts, violations_as_exceptions) - - # Store contexts for potential reuse - if result[1]: - self._context_store[global_context.request_id] = (result[1], time.time()) - - return result - - async def tool_post_invoke( - self, payload: ToolPostInvokePayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False - ) -> tuple[ToolPostInvokeResult, PluginContextTable | None]: - """Execute post-invoke hooks after a tool is invoked. - - Args: - payload: The tool result payload containing invocation results. - global_context: Shared context for all plugins with request metadata. - local_contexts: Optional contexts from pre-invoke hook execution. - violations_as_exceptions: Raise violations as exceptions rather than as returns. - - Returns: - A tuple containing: - - ToolPostInvokeResult with processing status and modified result - - PluginContextTable with final contexts - - Raises: - PayloadSizeError: If payload exceeds size limits. - - Examples: - >>> # Continuing from tool_pre_invoke example - >>> from mcpgateway.plugins.framework import ToolPostInvokePayload, GlobalContext - >>> - >>> post_payload = ToolPostInvokePayload( - ... name="calculator", - ... result={"result": 8, "status": "success"} - ... ) - >>> - >>> manager = PluginManager("plugins/config.yaml") - >>> context = GlobalContext(request_id="req-123") - >>> - >>> # In async context: - >>> # result, _ = await manager.tool_post_invoke( - >>> # post_payload, - >>> # context, - >>> # contexts # From pre_invoke - >>> # ) - >>> # if result.modified_payload: - >>> # # Use modified result - >>> # final_result = result.modified_payload.result - """ - # Get plugins configured for this hook - plugins = self._registry.get_plugins_for_hook(HookType.TOOL_POST_INVOKE) - - # Execute plugins - result = await self._post_tool_executor.execute(plugins, payload, global_context, post_tool_invoke, post_tool_matches, local_contexts, violations_as_exceptions) - - # Clean up stored context after post-invoke - if global_context.request_id in self._context_store: - del self._context_store[global_context.request_id] - - return result - - async def resource_pre_fetch( - self, payload: ResourcePreFetchPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False - ) -> tuple[ResourcePreFetchResult, PluginContextTable | None]: - """Execute pre-fetch hooks before a resource is fetched. + async def invoke_hook( + self, + hook_type: str, + payload: PluginPayload, + global_context: GlobalContext, + local_contexts: Optional[PluginContextTable] = None, + violations_as_exceptions: bool = False, + ) -> tuple[PluginResult, PluginContextTable | None]: + """Invoke a set of plugins configured for the hook point in priority order. Args: - payload: The resource payload containing URI and metadata. + payload: The plugin payload for which the plugins will analyze and modify. global_context: Shared context for all plugins with request metadata. local_contexts: Optional existing contexts from previous hook executions. violations_as_exceptions: Raise violations as exceptions rather than as returns. Returns: A tuple containing: - - ResourcePreFetchResult with processing status and modified payload + - PluginResult with processing status and modified payload - PluginContextTable with plugin contexts for state management Examples: @@ -940,58 +549,72 @@ async def resource_pre_fetch( >>> # uri = result.modified_payload.uri """ # Get plugins configured for this hook - plugins = self._registry.get_plugins_for_hook(HookType.RESOURCE_PRE_FETCH) + hook_refs = self._registry.get_hook_refs_for_hook(hook_type=hook_type) # Execute plugins - result = await self._resource_pre_executor.execute(plugins, payload, global_context, pre_resource_fetch, pre_resource_matches, local_contexts, violations_as_exceptions) - - # Store context for potential post-fetch - if result[1]: - self._context_store[global_context.request_id] = (result[1], time.time()) - - # Periodic cleanup - await self._cleanup_old_contexts() + result = await self._executor.execute(hook_refs, payload, global_context, local_contexts, violations_as_exceptions) return result - async def resource_post_fetch( - self, payload: ResourcePostFetchPayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False - ) -> tuple[ResourcePostFetchResult, PluginContextTable | None]: - """Execute post-fetch hooks after a resource is fetched. + async def invoke_hook_for_plugin( + self, + name: str, + hook_type: str, + payload: Union[PluginPayload, dict[str, Any], str], + context: PluginContext, + violations_as_exceptions: bool = False, + payload_as_json=False, + ) -> PluginResult: + """Invoke a specific hook for a single named plugin. + + This method allows direct invocation of a particular plugin's hook by name, + bypassing the normal priority-ordered execution. Useful for testing individual + plugins or when specific plugin behavior needs to be triggered independently. Args: - payload: The resource content payload containing fetched data. - global_context: Shared context for all plugins with request metadata. - local_contexts: Optional contexts from pre-fetch hook execution. - violations_as_exceptions: Raise violations as exceptions rather than as returns. + name: The name of the plugin to invoke. + hook_type: The type of hook to execute (e.g., "prompt_pre_fetch"). + payload: The plugin payload to be processed by the hook. + context: Plugin execution context with local and global state. + violations_as_exceptions: Raise violations as exceptions rather than returns. + payload_as_json: payload passed in as json rather than pydantic. Returns: - A tuple containing: - - ResourcePostFetchResult with processing status and modified content - - PluginContextTable with updated plugin contexts + PluginResult with processing status, modified payload, and metadata. + + Raises: + PluginError: If the plugin or hook type cannot be found in the registry. Examples: >>> manager = PluginManager("plugins/config.yaml") >>> # In async context: >>> # await manager.initialize() - >>> # from mcpgateway.models import ResourceContent - >>> # content = ResourceContent(type="resource",id="res-1", uri="file:///data.txt", text="Data") - >>> # payload = ResourcePostFetchPayload("file:///data.txt", content) - >>> # context = GlobalContext(request_id="123", server_id="srv1") - >>> # contexts = self._context_store.get("123") # From pre-fetch - >>> # result, _ = await manager.resource_post_fetch(payload, context, contexts) - >>> # if result.continue_processing: - >>> # # Use modified result - >>> # final_content = result.modified_payload.content + >>> # payload = PromptPrehookPayload(name="test", args={}) + >>> # context = PluginContext(global_context=GlobalContext(request_id="123")) + >>> # result = await manager.invoke_hook_for_plugin( + >>> # name="auth_plugin", + >>> # hook_type="prompt_pre_fetch", + >>> # payload=payload, + >>> # context=context + >>> # ) """ - # Get plugins configured for this hook - plugins = self._registry.get_plugins_for_hook(HookType.RESOURCE_POST_FETCH) - - # Execute plugins - result = await self._resource_post_executor.execute(plugins, payload, global_context, post_resource_fetch, post_resource_matches, local_contexts, violations_as_exceptions) - - # Clean up stored context after post-fetch - if global_context.request_id in self._context_store: - del self._context_store[global_context.request_id] - - return result + hook_ref = self._registry.get_plugin_hook_by_name(name, hook_type) + if not hook_ref: + raise PluginError( + error=PluginErrorModel( + message=f"Unable to find {hook_type} for plugin {name}. Make sure the plugin is registered.", + plugin_name=name, + ) + ) + if payload_as_json: + plugin = hook_ref.plugin_ref.plugin + # When payload_as_json=True, payload should be str or dict + if isinstance(payload, (str, dict)): + pydantic_payload = plugin.json_to_payload(hook_type, payload) + return await self._executor.execute_plugin(hook_ref, pydantic_payload, context, violations_as_exceptions) + else: + raise ValueError(f"When payload_as_json=True, payload must be str or dict, got {type(payload)}") + # When payload_as_json=False, payload should already be a PluginPayload + if not isinstance(payload, PluginPayload): + raise ValueError(f"When payload_as_json=False, payload must be a PluginPayload, got {type(payload)}") + return await self._executor.execute_plugin(hook_ref, payload, context, violations_as_exceptions) diff --git a/mcpgateway/plugins/framework/models.py b/mcpgateway/plugins/framework/models.py index 1d02eb3c9..c9e790d15 100644 --- a/mcpgateway/plugins/framework/models.py +++ b/mcpgateway/plugins/framework/models.py @@ -13,50 +13,33 @@ from enum import Enum import os from pathlib import Path -from typing import Any, Generic, Optional, Self, TypeVar +from typing import Any, Generic, Optional, Self, TypeAlias, TypeVar # Third-Party -from pydantic import BaseModel, Field, field_serializer, field_validator, model_validator, PrivateAttr, RootModel, ValidationInfo +from pydantic import ( + BaseModel, + Field, + field_serializer, + field_validator, + model_validator, + PrivateAttr, + ValidationInfo, +) # First-Party -from mcpgateway.models import PromptResult -from mcpgateway.plugins.framework.constants import AFTER, EXTERNAL_PLUGIN_TYPE, IGNORE_CONFIG_EXTERNAL, PYTHON_SUFFIX, SCRIPT, URL +from mcpgateway.plugins.framework.constants import ( + EXTERNAL_PLUGIN_TYPE, + IGNORE_CONFIG_EXTERNAL, + PYTHON_SUFFIX, + SCRIPT, + URL, +) from mcpgateway.schemas import TransportType from mcpgateway.validators import SecurityValidator T = TypeVar("T") -class HookType(str, Enum): - """MCP Forge Gateway hook points. - - Attributes: - prompt_pre_fetch: The prompt pre hook. - prompt_post_fetch: The prompt post hook. - tool_pre_invoke: The tool pre invoke hook. - tool_post_invoke: The tool post invoke hook. - resource_pre_fetch: The resource pre fetch hook. - resource_post_fetch: The resource post fetch hook. - - Examples: - >>> HookType.PROMPT_PRE_FETCH - - >>> HookType.PROMPT_PRE_FETCH.value - 'prompt_pre_fetch' - >>> HookType('prompt_post_fetch') - - >>> list(HookType) # doctest: +ELLIPSIS - [, , , , ...] - """ - - PROMPT_PRE_FETCH = "prompt_pre_fetch" - PROMPT_POST_FETCH = "prompt_post_fetch" - TOOL_PRE_INVOKE = "tool_pre_invoke" - TOOL_POST_INVOKE = "tool_post_invoke" - RESOURCE_PRE_FETCH = "resource_pre_fetch" - RESOURCE_POST_FETCH = "resource_post_fetch" - - class PluginMode(str, Enum): """Plugin modes of operation. @@ -262,7 +245,7 @@ class MCPTransportTLSConfigBase(BaseModel): ca_bundle: Optional[str] = Field(default=None, description="Path to CA bundle for verification") keyfile_password: Optional[str] = Field(default=None, description="Password for encrypted private key") - @field_validator("ca_bundle", "certfile", "keyfile", mode=AFTER) + @field_validator("ca_bundle", "certfile", "keyfile", mode="after") @classmethod def validate_path(cls, value: Optional[str]) -> Optional[str]: """Expand and validate file paths supplied in TLS configuration. @@ -284,7 +267,7 @@ def validate_path(cls, value: Optional[str]) -> Optional[str]: raise ValueError(f"TLS file path does not exist: {value}") return str(expanded) - @model_validator(mode=AFTER) + @model_validator(mode="after") def validate_cert_key(self) -> Self: # pylint: disable=bad-classmethod-argument """Ensure certificate and key options are consistent. @@ -421,7 +404,7 @@ class MCPServerConfig(BaseModel): tls (Optional[MCPServerTLSConfig]): Server-side TLS configuration. """ - host: str = Field(default="0.0.0.0", description="Server host to bind to") # nosec B104 + host: str = Field(default="0.0.0.0", description="Server host to bind to") port: int = Field(default=8000, description="Server port to bind to") tls: Optional[MCPServerTLSConfig] = Field(default=None, description="Server-side TLS configuration") @@ -499,7 +482,7 @@ class MCPClientConfig(BaseModel): script: Optional[str] = None tls: Optional[MCPClientTLSConfig] = None - @field_validator(URL, mode=AFTER) + @field_validator(URL, mode="after") @classmethod def validate_url(cls, url: str | None) -> str | None: """Validate a MCP url for streamable HTTP connections. @@ -518,7 +501,7 @@ def validate_url(cls, url: str | None) -> str | None: return result return url - @field_validator(SCRIPT, mode=AFTER) + @field_validator(SCRIPT, mode="after") @classmethod def validate_script(cls, script: str | None) -> str | None: """Validate an MCP stdio script. @@ -542,7 +525,7 @@ def validate_script(cls, script: str | None) -> str | None: raise ValueError(f"MCP server script {script} must have a .py or .sh suffix.") return script - @model_validator(mode=AFTER) + @model_validator(mode="after") def validate_tls_usage(self) -> Self: # pylint: disable=bad-classmethod-argument """Ensure TLS configuration is only used with HTTP-based transports. @@ -568,10 +551,10 @@ class PluginConfig(BaseModel): kind (str): The kind or type of plugin. Usually a fully qualified object type. namespace (str): The namespace where the plugin resides. version (str): version of the plugin. - hooks (list[str]): a list of the hook points where the plugin will be called. + hooks (list[str]): a list of the hook points where the plugin will be called. Default: []. tags (list[str]): a list of tags for making the plugin searchable. mode (bool): whether the plugin is active. - priority (int): indicates the order in which the plugin is run. Lower = higher priority. + priority (int): indicates the order in which the plugin is run. Lower = higher priority. Default: 100. conditions (Optional[list[PluginCondition]]): the conditions on which the plugin is run. applied_to (Optional[list[AppliedTo]]): the tools, fields, that the plugin is applied to. config (dict[str, Any]): the plugin specific configurations. @@ -584,16 +567,16 @@ class PluginConfig(BaseModel): kind: str namespace: Optional[str] = None version: Optional[str] = None - hooks: Optional[list[HookType]] = None - tags: Optional[list[str]] = None + hooks: list[str] = Field(default_factory=list) + tags: list[str] = Field(default_factory=list) mode: PluginMode = PluginMode.ENFORCE - priority: Optional[int] = None # Lower = higher priority - conditions: Optional[list[PluginCondition]] = None # When to apply + priority: int = 100 # Lower = higher priority + conditions: list[PluginCondition] = Field(default_factory=list) # When to apply applied_to: Optional[AppliedTo] = None # Fields to apply to. config: Optional[dict[str, Any]] = None mcp: Optional[MCPClientConfig] = None - @model_validator(mode=AFTER) + @model_validator(mode="after") def check_url_or_script_filled(self) -> Self: # pylint: disable=bad-classmethod-argument """Checks to see that at least one of url or script are set depending on MCP server configuration. @@ -613,7 +596,7 @@ def check_url_or_script_filled(self) -> Self: # pylint: disable=bad-classmethod raise ValueError(f"Plugin {self.name} must set transport type to either SSE or STREAMABLEHTTP or STDIO") return self - @model_validator(mode=AFTER) + @model_validator(mode="after") def check_config_and_external(self, info: ValidationInfo) -> Self: # pylint: disable=bad-classmethod-argument """Checks to see that a plugin's 'config' section is not defined if the kind is 'external'. This is because developers cannot override items in the plugin config section for external plugins. @@ -670,9 +653,9 @@ class PluginErrorModel(BaseModel): """ message: str + plugin_name: str code: Optional[str] = "" details: Optional[dict[str, Any]] = Field(default_factory=dict) - plugin_name: str class PluginViolation(BaseModel): @@ -765,61 +748,6 @@ class Config(BaseModel): server_settings: Optional[MCPServerConfig] = None -class PromptPrehookPayload(BaseModel): - """A prompt payload for a prompt prehook. - - Attributes: - prompt_id (str): The ID of the prompt template. - args (dic[str,str]): The prompt template arguments. - - Examples: - >>> payload = PromptPrehookPayload(prompt_id="123", args={"user": "alice"}) - >>> payload.prompt_id - '123' - >>> payload.args - {'user': 'alice'} - >>> payload2 = PromptPrehookPayload(prompt_id="empty") - >>> payload2.args - {} - >>> p = PromptPrehookPayload(prompt_id="123", args={"name": "Bob", "time": "morning"}) - >>> p.prompt_id - '123' - >>> p.args["name"] - 'Bob' - """ - - prompt_id: str - args: Optional[dict[str, str]] = Field(default_factory=dict) - - -class PromptPosthookPayload(BaseModel): - """A prompt payload for a prompt posthook. - - Attributes: - prompt_id (str): The prompt ID. - result (PromptResult): The prompt after its template is rendered. - - Examples: - >>> from mcpgateway.models import PromptResult, Message, TextContent - >>> msg = Message(role="user", content=TextContent(type="text", text="Hello World")) - >>> result = PromptResult(messages=[msg]) - >>> payload = PromptPosthookPayload(prompt_id="123", result=result) - >>> payload.prompt_id - '123' - >>> payload.result.messages[0].content.text - 'Hello World' - >>> from mcpgateway.models import PromptResult, Message, TextContent - >>> msg = Message(role="assistant", content=TextContent(type="text", text="Test output")) - >>> r = PromptResult(messages=[msg]) - >>> p = PromptPosthookPayload(prompt_id="123", result=r) - >>> p.prompt_id - '123' - """ - - prompt_id: str - result: PromptResult - - class PluginResult(BaseModel, Generic[T]): """A result of the plugin hook processing. The actual type is dependent on the hook. @@ -858,111 +786,6 @@ class PluginResult(BaseModel, Generic[T]): metadata: Optional[dict[str, Any]] = Field(default_factory=dict) -PromptPrehookResult = PluginResult[PromptPrehookPayload] -PromptPosthookResult = PluginResult[PromptPosthookPayload] - - -class HttpHeaderPayload(RootModel[dict[str, str]]): - """An HTTP dictionary of headers used in the pre/post HTTP forwarding hooks.""" - - def __iter__(self): - """Custom iterator function to override root attribute. - - Returns: - A custom iterator for header dictionary. - """ - return iter(self.root) - - def __getitem__(self, item: str) -> str: - """Custom getitem function to override root attribute. - - Args: - item: The http header key. - - Returns: - A custom accesser for the header dictionary. - """ - return self.root[item] - - def __setitem__(self, key: str, value: str) -> None: - """Custom setitem function to override root attribute. - - Args: - key: The http header key. - value: The http header value to be set. - """ - self.root[key] = value - - def __len__(self): - """Custom len function to override root attribute. - - Returns: - The len of the header dictionary. - """ - return len(self.root) - - -HttpHeaderPayloadResult = PluginResult[HttpHeaderPayload] - - -class ToolPreInvokePayload(BaseModel): - """A tool payload for a tool pre-invoke hook. - - Args: - name: The tool name. - args: The tool arguments for invocation. - headers: The http pass through headers. - - Examples: - >>> payload = ToolPreInvokePayload(name="test_tool", args={"input": "data"}) - >>> payload.name - 'test_tool' - >>> payload.args - {'input': 'data'} - >>> payload2 = ToolPreInvokePayload(name="empty") - >>> payload2.args - {} - >>> p = ToolPreInvokePayload(name="calculator", args={"operation": "add", "a": 5, "b": 3}) - >>> p.name - 'calculator' - >>> p.args["operation"] - 'add' - - """ - - name: str - args: Optional[dict[str, Any]] = Field(default_factory=dict) - headers: Optional[HttpHeaderPayload] = None - - -class ToolPostInvokePayload(BaseModel): - """A tool payload for a tool post-invoke hook. - - Args: - name: The tool name. - result: The tool invocation result. - - Examples: - >>> payload = ToolPostInvokePayload(name="calculator", result={"result": 8, "status": "success"}) - >>> payload.name - 'calculator' - >>> payload.result - {'result': 8, 'status': 'success'} - >>> p = ToolPostInvokePayload(name="analyzer", result={"confidence": 0.95, "sentiment": "positive"}) - >>> p.name - 'analyzer' - >>> p.result["confidence"] - 0.95 - """ - - name: str - result: Any - - -ToolPreInvokeResult = PluginResult[ToolPreInvokePayload] -ToolPostInvokeResult = PluginResult[ToolPostInvokePayload] - - class GlobalContext(BaseModel): """The global context, which shared across all plugins. @@ -1061,58 +884,4 @@ def is_empty(self) -> bool: PluginContextTable = dict[str, PluginContext] - -class ResourcePreFetchPayload(BaseModel): - """A resource payload for a resource pre-fetch hook. - - Attributes: - uri: The resource URI. - metadata: Optional metadata for the resource request. - - Examples: - >>> payload = ResourcePreFetchPayload(uri="file:///data.txt") - >>> payload.uri - 'file:///data.txt' - >>> payload2 = ResourcePreFetchPayload(uri="http://api/data", metadata={"Accept": "application/json"}) - >>> payload2.metadata - {'Accept': 'application/json'} - >>> p = ResourcePreFetchPayload(uri="file:///docs/readme.md", metadata={"version": "1.0"}) - >>> p.uri - 'file:///docs/readme.md' - >>> p.metadata["version"] - '1.0' - """ - - uri: str - metadata: Optional[dict[str, Any]] = Field(default_factory=dict) - - -class ResourcePostFetchPayload(BaseModel): - """A resource payload for a resource post-fetch hook. - - Attributes: - uri: The resource URI. - content: The fetched resource content. - - Examples: - >>> from mcpgateway.models import ResourceContent - >>> content = ResourceContent(type="resource", id="res-1", uri="file:///data.txt", - ... text="Hello World") - >>> payload = ResourcePostFetchPayload(uri="file:///data.txt", content=content) - >>> payload.uri - 'file:///data.txt' - >>> payload.content.text - 'Hello World' - >>> from mcpgateway.models import ResourceContent - >>> resource_content = ResourceContent(type="resource", id="res-2", uri="test://resource", text="Test data") - >>> p = ResourcePostFetchPayload(uri="test://resource", content=resource_content) - >>> p.uri - 'test://resource' - """ - - uri: str - content: Any - - -ResourcePreFetchResult = PluginResult[ResourcePreFetchPayload] -ResourcePostFetchResult = PluginResult[ResourcePostFetchPayload] +PluginPayload: TypeAlias = BaseModel diff --git a/mcpgateway/plugins/framework/registry.py b/mcpgateway/plugins/framework/registry.py index 519c26ada..0268b4c0f 100644 --- a/mcpgateway/plugins/framework/registry.py +++ b/mcpgateway/plugins/framework/registry.py @@ -14,8 +14,8 @@ from typing import Optional # First-Party -from mcpgateway.plugins.framework.base import Plugin, PluginRef -from mcpgateway.plugins.framework.models import HookType +from mcpgateway.plugins.framework.base import HookRef, Plugin, PluginRef +from mcpgateway.plugins.framework.external.mcp.client import ExternalHookRef, ExternalPlugin # Use standard logging to avoid circular imports (plugins -> services -> plugins) logger = logging.getLogger(__name__) @@ -25,7 +25,8 @@ class PluginInstanceRegistry: """Registry for managing loaded plugins. Examples: - >>> from mcpgateway.plugins.framework import Plugin, PluginConfig, HookType + >>> from mcpgateway.plugins.framework import Plugin, PluginConfig + >>> from mcpgateway.plugins.mcp.entities import HookType >>> registry = PluginInstanceRegistry() >>> config = PluginConfig( ... name="test", @@ -60,8 +61,9 @@ def __init__(self) -> None: 0 """ self._plugins: dict[str, PluginRef] = {} - self._hooks: dict[HookType, list[PluginRef]] = defaultdict(list) - self._priority_cache: dict[HookType, list[PluginRef]] = {} + self._hooks: dict[str, list[HookRef]] = defaultdict(list) + self._hooks_by_name: dict[str, dict[str, HookRef]] = {} + self._priority_cache: dict[str, list[HookRef]] = {} def register(self, plugin: Plugin) -> None: """Register a plugin instance. @@ -79,13 +81,24 @@ def register(self, plugin: Plugin) -> None: self._plugins[plugin.name] = plugin_ref + plugin_hooks = {} + + external = isinstance(plugin, ExternalPlugin) + # Register hooks for hook_type in plugin.hooks: - self._hooks[hook_type].append(plugin_ref) + hook_ref: HookRef + if external: + hook_ref = ExternalHookRef(hook_type, plugin_ref) + else: + hook_ref = HookRef(hook_type, plugin_ref) + self._hooks[hook_type].append(hook_ref) + plugin_hooks[hook_type] = hook_ref # Invalidate priority cache for this hook self._priority_cache.pop(hook_type, None) + self._hooks_by_name[plugin.name] = plugin_hooks - logger.info(f"Registered plugin: {plugin.name} with hooks: {[h.name for h in plugin.hooks]}") + logger.info(f"Registered plugin: {plugin.name} with hooks: {[h for h in plugin.hooks]}") def unregister(self, plugin_name: str) -> None: """Unregister a plugin given its name. @@ -102,9 +115,12 @@ def unregister(self, plugin_name: str) -> None: plugin = self._plugins.pop(plugin_name) # Remove from hooks for hook_type in plugin.hooks: - self._hooks[hook_type] = [p for p in self._hooks[hook_type] if p.name != plugin_name] + self._hooks[hook_type] = [p for p in self._hooks[hook_type] if p.plugin_ref.name != plugin_name] self._priority_cache.pop(hook_type, None) + # Remove from hooks by name + self._hooks_by_name.pop(plugin_name, None) + logger.info(f"Unregistered plugin: {plugin_name}") def get_plugin(self, name: str) -> Optional[PluginRef]: @@ -118,7 +134,23 @@ def get_plugin(self, name: str) -> Optional[PluginRef]: """ return self._plugins.get(name) - def get_plugins_for_hook(self, hook_type: HookType) -> list[PluginRef]: + def get_plugin_hook_by_name(self, name: str, hook_type: str) -> Optional[HookRef]: + """Gets a hook reference for a particular plugin and hook type. + + Args: + name: plugin name. + hook_type: the hook type. + + Returns: + A hook reference for the plugin or None if not found. + """ + if name in self._hooks_by_name: + hooks = self._hooks_by_name[name] + if hook_type in hooks: + return hooks[hook_type] + return None + + def get_hook_refs_for_hook(self, hook_type: str) -> list[HookRef]: """Get all plugins for a specific hook, sorted by priority. Args: @@ -128,8 +160,8 @@ def get_plugins_for_hook(self, hook_type: HookType) -> list[PluginRef]: A list of plugin instances. """ if hook_type not in self._priority_cache: - plugins = sorted(self._hooks[hook_type], key=lambda p: p.priority) - self._priority_cache[hook_type] = plugins + hook_refs = sorted(self._hooks[hook_type], key=lambda p: p.plugin_ref.priority) + self._priority_cache[hook_type] = hook_refs return self._priority_cache[hook_type] def get_all_plugins(self) -> list[PluginRef]: diff --git a/mcpgateway/plugins/framework/utils.py b/mcpgateway/plugins/framework/utils.py index 17f561fb1..50046277d 100644 --- a/mcpgateway/plugins/framework/utils.py +++ b/mcpgateway/plugins/framework/utils.py @@ -18,14 +18,17 @@ from mcpgateway.plugins.framework.models import ( GlobalContext, PluginCondition, - PromptPosthookPayload, - PromptPrehookPayload, - ResourcePostFetchPayload, - ResourcePreFetchPayload, - ToolPostInvokePayload, - ToolPreInvokePayload, ) +# from mcpgateway.plugins.mcp.entities import ( +# PromptPosthookPayload, +# PromptPrehookPayload, +# ResourcePostFetchPayload, +# ResourcePreFetchPayload, +# ToolPostInvokePayload, +# ToolPreInvokePayload, +# ) + @cache # noqa def import_module(mod_name: str) -> ModuleType: @@ -111,208 +114,212 @@ def matches(condition: PluginCondition, context: GlobalContext) -> bool: return True -def pre_prompt_matches(payload: PromptPrehookPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: - """Check for a match on pre-prompt hooks. - - Args: - payload: the prompt prehook payload. - conditions: the conditions on the plugin that are required for execution. - context: the global context. - - Returns: - True if the plugin matches criteria. - - Examples: - >>> from mcpgateway.plugins.framework import PluginCondition, PromptPrehookPayload, GlobalContext - >>> payload = PromptPrehookPayload(prompt_id="id1", args={}) - >>> cond = PluginCondition(prompts={"id1"}) - >>> ctx = GlobalContext(request_id="req1") - >>> pre_prompt_matches(payload, [cond], ctx) - True - >>> payload2 = PromptPrehookPayload(prompt_id="id2", args={}) - >>> pre_prompt_matches(payload2, [cond], ctx) - False - """ - current_result = True - for index, condition in enumerate(conditions): - if not matches(condition, context): - current_result = False - - if condition.prompts and payload.prompt_id not in condition.prompts: - current_result = False - if current_result: - return True - if index < len(conditions) - 1: - current_result = True - return current_result - - -def post_prompt_matches(payload: PromptPosthookPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: - """Check for a match on pre-prompt hooks. - - Args: - payload: the prompt posthook payload. - conditions: the conditions on the plugin that are required for execution. - context: the global context. - - Returns: - True if the plugin matches criteria. - """ - current_result = True - for index, condition in enumerate(conditions): - if not matches(condition, context): - current_result = False - - if condition.prompts and payload.prompt_id not in condition.prompts: - current_result = False - if current_result: - return True - if index < len(conditions) - 1: - current_result = True - return current_result - - -def pre_tool_matches(payload: ToolPreInvokePayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: - """Check for a match on pre-tool hooks. - - Args: - payload: the tool pre-invoke payload. - conditions: the conditions on the plugin that are required for execution. - context: the global context. - - Returns: - True if the plugin matches criteria. - - Examples: - >>> from mcpgateway.plugins.framework import PluginCondition, ToolPreInvokePayload, GlobalContext - >>> payload = ToolPreInvokePayload(name="calculator", args={}) - >>> cond = PluginCondition(tools={"calculator"}) - >>> ctx = GlobalContext(request_id="req1") - >>> pre_tool_matches(payload, [cond], ctx) - True - >>> payload2 = ToolPreInvokePayload(name="other", args={}) - >>> pre_tool_matches(payload2, [cond], ctx) - False - """ - current_result = True - for index, condition in enumerate(conditions): - if not matches(condition, context): - current_result = False - - if condition.tools and payload.name not in condition.tools: - current_result = False - if current_result: - return True - if index < len(conditions) - 1: - current_result = True - return current_result - - -def post_tool_matches(payload: ToolPostInvokePayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: - """Check for a match on post-tool hooks. - - Args: - payload: the tool post-invoke payload. - conditions: the conditions on the plugin that are required for execution. - context: the global context. - - Returns: - True if the plugin matches criteria. - - Examples: - >>> from mcpgateway.plugins.framework import PluginCondition, ToolPostInvokePayload, GlobalContext - >>> payload = ToolPostInvokePayload(name="calculator", result={"result": 8}) - >>> cond = PluginCondition(tools={"calculator"}) - >>> ctx = GlobalContext(request_id="req1") - >>> post_tool_matches(payload, [cond], ctx) - True - >>> payload2 = ToolPostInvokePayload(name="other", result={"result": 8}) - >>> post_tool_matches(payload2, [cond], ctx) - False - """ - current_result = True - for index, condition in enumerate(conditions): - if not matches(condition, context): - current_result = False - - if condition.tools and payload.name not in condition.tools: - current_result = False - if current_result: - return True - if index < len(conditions) - 1: - current_result = True - return current_result - - -def pre_resource_matches(payload: ResourcePreFetchPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: - """Check for a match on pre-resource hooks. - - Args: - payload: the resource pre-fetch payload. - conditions: the conditions on the plugin that are required for execution. - context: the global context. - - Returns: - True if the plugin matches criteria. - - Examples: - >>> from mcpgateway.plugins.framework import PluginCondition, ResourcePreFetchPayload, GlobalContext - >>> payload = ResourcePreFetchPayload(uri="file:///data.txt") - >>> cond = PluginCondition(resources={"file:///data.txt"}) - >>> ctx = GlobalContext(request_id="req1") - >>> pre_resource_matches(payload, [cond], ctx) - True - >>> payload2 = ResourcePreFetchPayload(uri="http://api/other") - >>> pre_resource_matches(payload2, [cond], ctx) - False - """ - current_result = True - for index, condition in enumerate(conditions): - if not matches(condition, context): - current_result = False - - if condition.resources and payload.uri not in condition.resources: - current_result = False - if current_result: - return True - if index < len(conditions) - 1: - current_result = True - return current_result - - -def post_resource_matches(payload: ResourcePostFetchPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: - """Check for a match on post-resource hooks. - - Args: - payload: the resource post-fetch payload. - conditions: the conditions on the plugin that are required for execution. - context: the global context. - - Returns: - True if the plugin matches criteria. - - Examples: - >>> from mcpgateway.plugins.framework import PluginCondition, ResourcePostFetchPayload, GlobalContext - >>> from mcpgateway.models import ResourceContent - >>> content = ResourceContent(type="resource", id="123", uri="file:///data.txt", text="Test") - >>> payload = ResourcePostFetchPayload(id="123",uri="file:///data.txt", content=content) - >>> cond = PluginCondition(resources={"file:///data.txt"}) - >>> ctx = GlobalContext(request_id="req1") - >>> post_resource_matches(payload, [cond], ctx) - True - >>> payload2 = ResourcePostFetchPayload(uri="http://api/other", content=content) - >>> post_resource_matches(payload2, [cond], ctx) - False - """ - current_result = True - for index, condition in enumerate(conditions): - if not matches(condition, context): - current_result = False - - if condition.resources and payload.uri not in condition.resources: - current_result = False - if current_result: - return True - if index < len(conditions) - 1: - current_result = True - return current_result +# def pre_prompt_matches(payload: PromptPrehookPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: +# """Check for a match on pre-prompt hooks. + +# Args: +# payload: the prompt prehook payload. +# conditions: the conditions on the plugin that are required for execution. +# context: the global context. + +# Returns: +# True if the plugin matches criteria. + +# Examples: +# >>> from mcpgateway.plugins.framework import PluginCondition, GlobalContext +# >>> from mcpgateway.plugins.mcp.entities import PromptPrehookPayload +# >>> payload = PromptPrehookPayload(name="greeting", args={}) +# >>> cond = PluginCondition(prompts={"greeting"}) +# >>> ctx = GlobalContext(request_id="req1") +# >>> pre_prompt_matches(payload, [cond], ctx) +# True +# >>> payload2 = PromptPrehookPayload(name="other", args={}) +# >>> pre_prompt_matches(payload2, [cond], ctx) +# False +# """ +# current_result = True +# for index, condition in enumerate(conditions): +# if not matches(condition, context): +# current_result = False + +# if condition.prompts and payload.name not in condition.prompts: +# current_result = False +# if current_result: +# return True +# if index < len(conditions) - 1: +# current_result = True +# return current_result + + +# def post_prompt_matches(payload: PromptPosthookPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: +# """Check for a match on pre-prompt hooks. + +# Args: +# payload: the prompt posthook payload. +# conditions: the conditions on the plugin that are required for execution. +# context: the global context. + +# Returns: +# True if the plugin matches criteria. +# """ +# current_result = True +# for index, condition in enumerate(conditions): +# if not matches(condition, context): +# current_result = False + +# if condition.prompts and payload.name not in condition.prompts: +# current_result = False +# if current_result: +# return True +# if index < len(conditions) - 1: +# current_result = True +# return current_result + + +# def pre_tool_matches(payload: ToolPreInvokePayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: +# """Check for a match on pre-tool hooks. + +# Args: +# payload: the tool pre-invoke payload. +# conditions: the conditions on the plugin that are required for execution. +# context: the global context. + +# Returns: +# True if the plugin matches criteria. + +# Examples: +# >>> from mcpgateway.plugins.framework import PluginCondition, GlobalContext +# >>> from mcpgateway.plugins.mcp.entities import ToolPreInvokePayload +# >>> payload = ToolPreInvokePayload(name="calculator", args={}) +# >>> cond = PluginCondition(tools={"calculator"}) +# >>> ctx = GlobalContext(request_id="req1") +# >>> pre_tool_matches(payload, [cond], ctx) +# True +# >>> payload2 = ToolPreInvokePayload(name="other", args={}) +# >>> pre_tool_matches(payload2, [cond], ctx) +# False +# """ +# current_result = True +# for index, condition in enumerate(conditions): +# if not matches(condition, context): +# current_result = False + +# if condition.tools and payload.name not in condition.tools: +# current_result = False +# if current_result: +# return True +# if index < len(conditions) - 1: +# current_result = True +# return current_result + + +# def post_tool_matches(payload: ToolPostInvokePayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: +# """Check for a match on post-tool hooks. + +# Args: +# payload: the tool post-invoke payload. +# conditions: the conditions on the plugin that are required for execution. +# context: the global context. + +# Returns: +# True if the plugin matches criteria. + +# Examples: +# >>> from mcpgateway.plugins.framework import PluginCondition, GlobalContext +# >>> from mcpgateway.plugins.mcp.entities import ToolPostInvokePayload +# >>> payload = ToolPostInvokePayload(name="calculator", result={"result": 8}) +# >>> cond = PluginCondition(tools={"calculator"}) +# >>> ctx = GlobalContext(request_id="req1") +# >>> post_tool_matches(payload, [cond], ctx) +# True +# >>> payload2 = ToolPostInvokePayload(name="other", result={"result": 8}) +# >>> post_tool_matches(payload2, [cond], ctx) +# False +# """ +# current_result = True +# for index, condition in enumerate(conditions): +# if not matches(condition, context): +# current_result = False + +# if condition.tools and payload.name not in condition.tools: +# current_result = False +# if current_result: +# return True +# if index < len(conditions) - 1: +# current_result = True +# return current_result + + +# def pre_resource_matches(payload: ResourcePreFetchPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: +# """Check for a match on pre-resource hooks. + +# Args: +# payload: the resource pre-fetch payload. +# conditions: the conditions on the plugin that are required for execution. +# context: the global context. + +# Returns: +# True if the plugin matches criteria. + +# Examples: +# >>> from mcpgateway.plugins.framework import PluginCondition, GlobalContext +# >>> from mcpgateway.plugins.mcp.entities import ResourcePreFetchPayload +# >>> payload = ResourcePreFetchPayload(uri="file:///data.txt") +# >>> cond = PluginCondition(resources={"file:///data.txt"}) +# >>> ctx = GlobalContext(request_id="req1") +# >>> pre_resource_matches(payload, [cond], ctx) +# True +# >>> payload2 = ResourcePreFetchPayload(uri="http://api/other") +# >>> pre_resource_matches(payload2, [cond], ctx) +# False +# """ +# current_result = True +# for index, condition in enumerate(conditions): +# if not matches(condition, context): +# current_result = False + +# if condition.resources and payload.uri not in condition.resources: +# current_result = False +# if current_result: +# return True +# if index < len(conditions) - 1: +# current_result = True +# return current_result + + +# def post_resource_matches(payload: ResourcePostFetchPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: +# """Check for a match on post-resource hooks. + +# Args: +# payload: the resource post-fetch payload. +# conditions: the conditions on the plugin that are required for execution. +# context: the global context. + +# Returns: +# True if the plugin matches criteria. + +# Examples: +# >>> from mcpgateway.plugins.framework import PluginCondition, GlobalContext +# >>> from mcpgateway.plugins.mcp.entities import ResourcePostFetchPayload, ResourceContent +# >>> content = ResourceContent(type="resource", uri="file:///data.txt", text="Test") +# >>> payload = ResourcePostFetchPayload(uri="file:///data.txt", content=content) +# >>> cond = PluginCondition(resources={"file:///data.txt"}) +# >>> ctx = GlobalContext(request_id="req1") +# >>> post_resource_matches(payload, [cond], ctx) +# True +# >>> payload2 = ResourcePostFetchPayload(uri="http://api/other", content=content) +# >>> post_resource_matches(payload2, [cond], ctx) +# False +# """ +# current_result = True +# for index, condition in enumerate(conditions): +# if not matches(condition, context): +# current_result = False + +# if condition.resources and payload.uri not in condition.resources: +# current_result = False +# if current_result: +# return True +# if index < len(conditions) - 1: +# current_result = True +# return current_result diff --git a/mcpgateway/plugins/mcp/__init__.py b/mcpgateway/plugins/mcp/__init__.py new file mode 100644 index 000000000..c45913753 --- /dev/null +++ b/mcpgateway/plugins/mcp/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/mcp/__init__.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +MCP Plugins Package. +""" diff --git a/mcpgateway/plugins/mcp/entities/__init__.py b/mcpgateway/plugins/mcp/entities/__init__.py new file mode 100644 index 000000000..2e93aa073 --- /dev/null +++ b/mcpgateway/plugins/mcp/entities/__init__.py @@ -0,0 +1,49 @@ +"""Location: ./mcpgateway/plugins/mcp/entities/__init__.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +MCP Plugins Entities Package. +""" + +# First-Party +from mcpgateway.plugins.mcp.entities.models import ( + HttpHeaderPayload, + HttpHeaderPayloadResult, + HookType, + PromptPosthookPayload, + PromptPosthookResult, + PromptPrehookPayload, + PromptPrehookResult, + PromptResult, + ResourcePostFetchPayload, + ResourcePostFetchResult, + ResourcePreFetchPayload, + ResourcePreFetchResult, + ToolPostInvokePayload, + ToolPostInvokeResult, + ToolPreInvokePayload, + ToolPreInvokeResult, +) + +from mcpgateway.plugins.mcp.entities.base import MCPPlugin + +__all__ = [ + "HookType", + "HttpHeaderPayload", + "HttpHeaderPayloadResult", + "MCPPlugin", + "PromptPosthookPayload", + "PromptPosthookResult", + "PromptPrehookPayload", + "PromptPrehookResult", + "PromptResult", + "ResourcePostFetchPayload", + "ResourcePostFetchResult", + "ResourcePreFetchPayload", + "ResourcePreFetchResult", + "ToolPostInvokePayload", + "ToolPostInvokeResult", + "ToolPreInvokePayload", + "ToolPreInvokeResult", +] diff --git a/mcpgateway/plugins/mcp/entities/base.py b/mcpgateway/plugins/mcp/entities/base.py new file mode 100644 index 000000000..463d63202 --- /dev/null +++ b/mcpgateway/plugins/mcp/entities/base.py @@ -0,0 +1,212 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/mcp/entities/base.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Base plugin implementation. +This module implements the base plugin object. +It supports pre and post hooks AI safety, security and business processing +for the following locations in the server: +server_pre_register / server_post_register - for virtual server verification +tool_pre_invoke / tool_post_invoke - for guardrails +prompt_pre_fetch / prompt_post_fetch - for prompt filtering +resource_pre_fetch / resource_post_fetch - for content filtering +auth_pre_check / auth_post_check - for custom auth logic +federation_pre_sync / federation_post_sync - for gateway federation +""" + +# Standard + +# First-Party +from mcpgateway.plugins.framework.base import Plugin +from mcpgateway.plugins.framework.models import PluginConfig, PluginContext +from mcpgateway.plugins.mcp.entities.models import ( + HookType, + PromptPosthookPayload, + PromptPosthookResult, + PromptPrehookPayload, + PromptPrehookResult, + ResourcePostFetchPayload, + ResourcePostFetchResult, + ResourcePreFetchPayload, + ResourcePreFetchResult, + ToolPostInvokePayload, + ToolPostInvokeResult, + ToolPreInvokePayload, + ToolPreInvokeResult, +) + + +def _register_mcp_hooks(): + """Register MCP hooks in the global registry. + + This is called lazily to avoid circular import issues. + """ + # Import here to avoid circular dependency at module load time + # First-Party + from mcpgateway.plugins.framework.hook_registry import get_hook_registry + + registry = get_hook_registry() + + # Only register if not already registered (idempotent) + if not registry.is_registered(HookType.PROMPT_PRE_FETCH): + registry.register_hook(HookType.PROMPT_PRE_FETCH, PromptPrehookPayload, PromptPrehookResult) + registry.register_hook(HookType.PROMPT_POST_FETCH, PromptPosthookPayload, PromptPosthookResult) + registry.register_hook(HookType.RESOURCE_PRE_FETCH, ResourcePreFetchPayload, ResourcePreFetchResult) + registry.register_hook(HookType.RESOURCE_POST_FETCH, ResourcePostFetchPayload, ResourcePostFetchResult) + registry.register_hook(HookType.TOOL_PRE_INVOKE, ToolPreInvokePayload, ToolPreInvokeResult) + registry.register_hook(HookType.TOOL_POST_INVOKE, ToolPostInvokePayload, ToolPostInvokeResult) + + +class MCPPlugin(Plugin): + """Base mcp plugin object for pre/post processing of inputs and outputs at various locations throughout the server. + + Examples: + >>> from mcpgateway.plugins.framework import PluginConfig, PluginMode + >>> from mcpgateway.plugins.mcp.entities import HookType + >>> config = PluginConfig( + ... name="test_plugin", + ... description="Test plugin", + ... author="test", + ... kind="mcpgateway.plugins.framework.Plugin", + ... version="1.0.0", + ... hooks=[HookType.PROMPT_PRE_FETCH], + ... tags=["test"], + ... mode=PluginMode.ENFORCE, + ... priority=50 + ... ) + >>> plugin = MCPPlugin(config) + >>> plugin.name + 'test_plugin' + >>> plugin.priority + 50 + >>> plugin.mode + + >>> HookType.PROMPT_PRE_FETCH in plugin.hooks + True + """ + + def __init__(self, config: PluginConfig) -> None: + """Initialize a plugin with a configuration and context. + + Args: + config: The plugin configuration + + Examples: + >>> from mcpgateway.plugins.framework import PluginConfig + >>> from mcpgateway.plugins.mcp.entities import HookType + >>> config = PluginConfig( + ... name="simple_plugin", + ... description="Simple test", + ... author="test", + ... kind="test.Plugin", + ... version="1.0.0", + ... hooks=[HookType.PROMPT_POST_FETCH], + ... tags=["simple"] + ... ) + >>> plugin = MCPPlugin(config) + >>> plugin._config.name + 'simple_plugin' + """ + super().__init__(config) + + async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: + """Plugin hook run before a prompt is retrieved and rendered. + + Args: + payload: The prompt payload to be analyzed. + context: contextual information about the hook call. Including why it was called. + + Raises: + NotImplementedError: needs to be implemented by sub class. + """ + raise NotImplementedError( + f"""'prompt_pre_fetch' not implemented for plugin {self._config.name} + of plugin type {type(self)} + """ + ) + + async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult: + """Plugin hook run after a prompt is rendered. + + Args: + payload: The prompt payload to be analyzed. + context: Contextual information about the hook call. + + Raises: + NotImplementedError: needs to be implemented by sub class. + """ + raise NotImplementedError( + f"""'prompt_post_fetch' not implemented for plugin {self._config.name} + of plugin type {type(self)} + """ + ) + + async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult: + """Plugin hook run before a tool is invoked. + + Args: + payload: The tool payload to be analyzed. + context: Contextual information about the hook call. + + Raises: + NotImplementedError: needs to be implemented by sub class. + """ + raise NotImplementedError( + f"""'tool_pre_invoke' not implemented for plugin {self._config.name} + of plugin type {type(self)} + """ + ) + + async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: PluginContext) -> ToolPostInvokeResult: + """Plugin hook run after a tool is invoked. + + Args: + payload: The tool result payload to be analyzed. + context: Contextual information about the hook call. + + Raises: + NotImplementedError: needs to be implemented by sub class. + """ + raise NotImplementedError( + f"""'tool_post_invoke' not implemented for plugin {self._config.name} + of plugin type {type(self)} + """ + ) + + async def resource_pre_fetch(self, payload: ResourcePreFetchPayload, context: PluginContext) -> ResourcePreFetchResult: + """Plugin hook run before a resource is fetched. + + Args: + payload: The resource payload to be analyzed. + context: Contextual information about the hook call. + + Raises: + NotImplementedError: needs to be implemented by sub class. + """ + raise NotImplementedError( + f"""'resource_pre_fetch' not implemented for plugin {self._config.name} + of plugin type {type(self)} + """ + ) + + async def resource_post_fetch(self, payload: ResourcePostFetchPayload, context: PluginContext) -> ResourcePostFetchResult: + """Plugin hook run after a resource is fetched. + + Args: + payload: The resource content payload to be analyzed. + context: Contextual information about the hook call. + + Raises: + NotImplementedError: needs to be implemented by sub class. + """ + raise NotImplementedError( + f"""'resource_post_fetch' not implemented for plugin {self._config.name} + of plugin type {type(self)} + """ + ) + + +# Register MCP hooks when this module is imported +_register_mcp_hooks() diff --git a/mcpgateway/plugins/mcp/entities/models.py b/mcpgateway/plugins/mcp/entities/models.py new file mode 100644 index 000000000..3a3e63d88 --- /dev/null +++ b/mcpgateway/plugins/mcp/entities/models.py @@ -0,0 +1,267 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/mcp/entities/models.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Pydantic models for MCP plugins. +This module implements the pydantic models associated with +the base plugin layer including configurations, and contexts. +""" + +# Standard +from enum import Enum +from typing import Any, Optional + +# Third-Party +from pydantic import Field, RootModel + +# First-Party +from mcpgateway.models import PromptResult +from mcpgateway.plugins.framework.models import PluginPayload, PluginResult + + +class HookType(str, Enum): + """MCP Forge Gateway hook points. + + Attributes: + prompt_pre_fetch: The prompt pre hook. + prompt_post_fetch: The prompt post hook. + tool_pre_invoke: The tool pre invoke hook. + tool_post_invoke: The tool post invoke hook. + resource_pre_fetch: The resource pre fetch hook. + resource_post_fetch: The resource post fetch hook. + + Examples: + >>> HookType.PROMPT_PRE_FETCH + + >>> HookType.PROMPT_PRE_FETCH.value + 'prompt_pre_fetch' + >>> HookType('prompt_post_fetch') + + >>> list(HookType) # doctest: +ELLIPSIS + [, , , , ...] + """ + + PROMPT_PRE_FETCH = "prompt_pre_fetch" + PROMPT_POST_FETCH = "prompt_post_fetch" + TOOL_PRE_INVOKE = "tool_pre_invoke" + TOOL_POST_INVOKE = "tool_post_invoke" + RESOURCE_PRE_FETCH = "resource_pre_fetch" + RESOURCE_POST_FETCH = "resource_post_fetch" + + +class PromptPrehookPayload(PluginPayload): + """A prompt payload for a prompt prehook. + + Attributes: + prompt_id (str): The ID of the prompt template. + args (dic[str,str]): The prompt template arguments. + + Examples: + >>> payload = PromptPrehookPayload(prompt_id="123", args={"user": "alice"}) + >>> payload.prompt_id + '123' + >>> payload.args + {'user': 'alice'} + >>> payload2 = PromptPrehookPayload(prompt_id="empty") + >>> payload2.args + {} + >>> p = PromptPrehookPayload(prompt_id="123", args={"name": "Bob", "time": "morning"}) + >>> p.prompt_id + '123' + >>> p.args["name"] + 'Bob' + """ + + prompt_id: str + args: Optional[dict[str, str]] = Field(default_factory=dict) + + +class PromptPosthookPayload(PluginPayload): + """A prompt payload for a prompt posthook. + + Attributes: + prompt_id (str): The prompt ID. + result (PromptResult): The prompt after its template is rendered. + + Examples: + >>> from mcpgateway.models import PromptResult, Message, TextContent + >>> msg = Message(role="user", content=TextContent(type="text", text="Hello World")) + >>> result = PromptResult(messages=[msg]) + >>> payload = PromptPosthookPayload(prompt_id="123", result=result) + >>> payload.prompt_id + '123' + >>> payload.result.messages[0].content.text + 'Hello World' + >>> from mcpgateway.models import PromptResult, Message, TextContent + >>> msg = Message(role="assistant", content=TextContent(type="text", text="Test output")) + >>> r = PromptResult(messages=[msg]) + >>> p = PromptPosthookPayload(prompt_id="123", result=r) + >>> p.prompt_id + '123' + """ + + prompt_id: str + result: PromptResult + + +PromptPrehookResult = PluginResult[PromptPrehookPayload] +PromptPosthookResult = PluginResult[PromptPosthookPayload] + + +class HttpHeaderPayload(RootModel[dict[str, str]]): + """An HTTP dictionary of headers used in the pre/post HTTP forwarding hooks.""" + + def __iter__(self): + """Custom iterator function to override root attribute. + + Returns: + A custom iterator for header dictionary. + """ + return iter(self.root) + + def __getitem__(self, item: str) -> str: + """Custom getitem function to override root attribute. + + Args: + item: The http header key. + + Returns: + A custom accesser for the header dictionary. + """ + return self.root[item] + + def __setitem__(self, key: str, value: str) -> None: + """Custom setitem function to override root attribute. + + Args: + key: The http header key. + value: The http header value to be set. + """ + self.root[key] = value + + def __len__(self): + """Custom len function to override root attribute. + + Returns: + The len of the header dictionary. + """ + return len(self.root) + + +HttpHeaderPayloadResult = PluginResult[HttpHeaderPayload] + + +class ToolPreInvokePayload(PluginPayload): + """A tool payload for a tool pre-invoke hook. + + Args: + name: The tool name. + args: The tool arguments for invocation. + headers: The http pass through headers. + + Examples: + >>> payload = ToolPreInvokePayload(name="test_tool", args={"input": "data"}) + >>> payload.name + 'test_tool' + >>> payload.args + {'input': 'data'} + >>> payload2 = ToolPreInvokePayload(name="empty") + >>> payload2.args + {} + >>> p = ToolPreInvokePayload(name="calculator", args={"operation": "add", "a": 5, "b": 3}) + >>> p.name + 'calculator' + >>> p.args["operation"] + 'add' + + """ + + name: str + args: Optional[dict[str, Any]] = Field(default_factory=dict) + headers: Optional[HttpHeaderPayload] = None + + +class ToolPostInvokePayload(PluginPayload): + """A tool payload for a tool post-invoke hook. + + Args: + name: The tool name. + result: The tool invocation result. + + Examples: + >>> payload = ToolPostInvokePayload(name="calculator", result={"result": 8, "status": "success"}) + >>> payload.name + 'calculator' + >>> payload.result + {'result': 8, 'status': 'success'} + >>> p = ToolPostInvokePayload(name="analyzer", result={"confidence": 0.95, "sentiment": "positive"}) + >>> p.name + 'analyzer' + >>> p.result["confidence"] + 0.95 + """ + + name: str + result: Any + + +ToolPreInvokeResult = PluginResult[ToolPreInvokePayload] +ToolPostInvokeResult = PluginResult[ToolPostInvokePayload] + + +class ResourcePreFetchPayload(PluginPayload): + """A resource payload for a resource pre-fetch hook. + + Attributes: + uri: The resource URI. + metadata: Optional metadata for the resource request. + + Examples: + >>> payload = ResourcePreFetchPayload(uri="file:///data.txt") + >>> payload.uri + 'file:///data.txt' + >>> payload2 = ResourcePreFetchPayload(uri="http://api/data", metadata={"Accept": "application/json"}) + >>> payload2.metadata + {'Accept': 'application/json'} + >>> p = ResourcePreFetchPayload(uri="file:///docs/readme.md", metadata={"version": "1.0"}) + >>> p.uri + 'file:///docs/readme.md' + >>> p.metadata["version"] + '1.0' + """ + + uri: str + metadata: Optional[dict[str, Any]] = Field(default_factory=dict) + + +class ResourcePostFetchPayload(PluginPayload): + """A resource payload for a resource post-fetch hook. + + Attributes: + uri: The resource URI. + content: The fetched resource content. + + Examples: + >>> from mcpgateway.models import ResourceContent + >>> content = ResourceContent(type="resource", id="res-1", uri="file:///data.txt", + ... text="Hello World") + >>> payload = ResourcePostFetchPayload(uri="file:///data.txt", content=content) + >>> payload.uri + 'file:///data.txt' + >>> payload.content.text + 'Hello World' + >>> from mcpgateway.models import ResourceContent + >>> resource_content = ResourceContent(type="resource", id="res-2", uri="test://resource", text="Test data") + >>> p = ResourcePostFetchPayload(uri="test://resource", content=resource_content) + >>> p.uri + 'test://resource' + """ + + uri: str + content: Any + + +ResourcePreFetchResult = PluginResult[ResourcePreFetchPayload] +ResourcePostFetchResult = PluginResult[ResourcePostFetchPayload] diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index b0fcf94c7..c612ec8e4 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -36,7 +36,8 @@ from mcpgateway.db import PromptMetric, server_prompt_association from mcpgateway.models import Message, PromptResult, Role, TextContent from mcpgateway.observability import create_span -from mcpgateway.plugins.framework import GlobalContext, PluginManager, PromptPosthookPayload, PromptPrehookPayload +from mcpgateway.plugins.framework import GlobalContext, PluginManager +from mcpgateway.plugins.mcp.entities import HookType, PromptPosthookPayload, PromptPrehookPayload from mcpgateway.schemas import PromptCreate, PromptRead, PromptUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService from mcpgateway.utils.metrics_common import build_top_performers @@ -690,8 +691,12 @@ async def get_prompt( if not request_id: request_id = uuid.uuid4().hex global_context = GlobalContext(request_id=request_id, user=user, server_id=server_id, tenant_id=tenant_id) - pre_result, context_table = await self._plugin_manager.prompt_pre_fetch( - payload=PromptPrehookPayload(prompt_id=str(prompt_id), args=arguments), global_context=global_context, local_contexts=None, violations_as_exceptions=True + pre_result, context_table = await self._plugin_manager.invoke_hook( + HookType.PROMPT_PRE_FETCH, + payload=PromptPrehookPayload(prompt_id=str(prompt_id), args=arguments), + global_context=global_context, + local_contexts=None, + violations_as_exceptions=True, ) # Use modified payload if provided @@ -755,8 +760,12 @@ async def get_prompt( raise PromptError(f"Failed to process prompt: {str(e)}") if self._plugin_manager: - post_result, _ = await self._plugin_manager.prompt_post_fetch( - payload=PromptPosthookPayload(prompt_id=str(prompt.id), result=result), global_context=global_context, local_contexts=context_table, violations_as_exceptions=True + post_result, _ = await self._plugin_manager.invoke_hook( + HookType.PROMPT_POST_FETCH, + payload=PromptPosthookPayload(prompt_id=str(prompt.id), result=result), + global_context=global_context, + local_contexts=context_table, + violations_as_exceptions=True, ) # Use modified payload if provided result = post_result.modified_payload.result if post_result.modified_payload else result diff --git a/mcpgateway/services/resource_service.py b/mcpgateway/services/resource_service.py index 9a31e5237..e0e926def 100644 --- a/mcpgateway/services/resource_service.py +++ b/mcpgateway/services/resource_service.py @@ -56,7 +56,8 @@ # Plugin support imports (conditional) try: # First-Party - from mcpgateway.plugins.framework import GlobalContext, PluginManager, ResourcePostFetchPayload, ResourcePreFetchPayload + from mcpgateway.plugins.framework import GlobalContext, PluginManager + from mcpgateway.plugins.mcp.entities import HookType, ResourcePostFetchPayload, ResourcePreFetchPayload PLUGINS_AVAILABLE = True except ImportError: @@ -735,7 +736,7 @@ async def read_resource(self, db: Session, resource_id: Union[int, str], request pre_payload = ResourcePreFetchPayload(uri=uri, metadata={}) # Execute pre-fetch hooks - pre_result, contexts = await self._plugin_manager.resource_pre_fetch(pre_payload, global_context, violations_as_exceptions=True) + pre_result, contexts = await self._plugin_manager.invoke_hook(HookType.RESOURCE_PRE_FETCH, pre_payload, global_context, violations_as_exceptions=True) # Use modified URI if plugin changed it if pre_result.modified_payload: uri = pre_result.modified_payload.uri @@ -765,7 +766,9 @@ async def read_resource(self, db: Session, resource_id: Union[int, str], request post_payload = ResourcePostFetchPayload(uri=original_uri, content=content) # Execute post-fetch hooks - post_result, _ = await self._plugin_manager.resource_post_fetch(post_payload, global_context, contexts, violations_as_exceptions=True) # Pass contexts from pre-fetch + post_result, _ = await self._plugin_manager.invoke_hook( + HookType.RESOURCE_POST_FETCH, post_payload, global_context, contexts, violations_as_exceptions=True + ) # Pass contexts from pre-fetch # Use modified content if plugin changed it if post_result.modified_payload: diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 66919161d..c53237e53 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -49,8 +49,9 @@ from mcpgateway.models import Tool as PydanticTool from mcpgateway.models import ToolResult from mcpgateway.observability import create_span -from mcpgateway.plugins.framework import GlobalContext, HttpHeaderPayload, PluginError, PluginManager, PluginViolationError, ToolPostInvokePayload, ToolPreInvokePayload +from mcpgateway.plugins.framework import GlobalContext, PluginError, PluginManager, PluginViolationError from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA +from mcpgateway.plugins.mcp.entities import HookType, HttpHeaderPayload, ToolPostInvokePayload, ToolPreInvokePayload from mcpgateway.schemas import ToolCreate, ToolRead, ToolUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.oauth_manager import OAuthManager @@ -1002,7 +1003,8 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r if self._plugin_manager: tool_metadata = PydanticTool.model_validate(tool) global_context.metadata[TOOL_METADATA] = tool_metadata - pre_result, context_table = await self._plugin_manager.tool_pre_invoke( + pre_result, context_table = await self._plugin_manager.invoke_hook( + HookType.TOOL_PRE_INVOKE, payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(headers)), global_context=global_context, local_contexts=None, @@ -1153,7 +1155,8 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head if tool_gateway: gateway_metadata = PydanticGateway.model_validate(tool_gateway) global_context.metadata[GATEWAY_METADATA] = gateway_metadata - pre_result, context_table = await self._plugin_manager.tool_pre_invoke( + pre_result, context_table = await self._plugin_manager.invoke_hook( + HookType.TOOL_PRE_INVOKE, payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(headers)), global_context=global_context, local_contexts=None, @@ -1182,7 +1185,8 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head # Plugin hook: tool post-invoke if self._plugin_manager: - post_result, _ = await self._plugin_manager.tool_post_invoke( + post_result, _ = await self._plugin_manager.invoke_hook( + HookType.TOOL_POST_INVOKE, payload=ToolPostInvokePayload(name=name, result=tool_result.model_dump(by_alias=True)), global_context=global_context, local_contexts=context_table, diff --git a/plugin_templates/external/{{ plugin_name.lower().replace(' ', '_').replace('-', '_') }}/plugin.py.jinja b/plugin_templates/external/{{ plugin_name.lower().replace(' ', '_').replace('-', '_') }}/plugin.py.jinja index cdd8f3e80..e3a73631b 100644 --- a/plugin_templates/external/{{ plugin_name.lower().replace(' ', '_').replace('-', '_') }}/plugin.py.jinja +++ b/plugin_templates/external/{{ plugin_name.lower().replace(' ', '_').replace('-', '_') }}/plugin.py.jinja @@ -29,7 +29,7 @@ from mcpgateway.plugins.framework import ( {% else -%} {% set class_name = class_parts|join -%} {% endif -%} -class {{ class_name }}(Plugin): +class {{ class_name }}(MCPPlugin): """{{ description }}.""" def __init__(self, config: PluginConfig): diff --git a/plugin_templates/native/plugin.py.jinja b/plugin_templates/native/plugin.py.jinja index cdd8f3e80..e3a73631b 100644 --- a/plugin_templates/native/plugin.py.jinja +++ b/plugin_templates/native/plugin.py.jinja @@ -29,7 +29,7 @@ from mcpgateway.plugins.framework import ( {% else -%} {% set class_name = class_parts|join -%} {% endif -%} -class {{ class_name }}(Plugin): +class {{ class_name }}(MCPPlugin): """{{ description }}.""" def __init__(self, config: PluginConfig): diff --git a/plugins/README.md b/plugins/README.md index e3d2cc3d5..24e981824 100644 --- a/plugins/README.md +++ b/plugins/README.md @@ -196,7 +196,7 @@ from mcpgateway.plugins.framework.models import ( PluginResult ) -class MyPlugin(Plugin): +class MyPlugin(MCPPlugin): """Custom plugin implementation.""" async def tool_pre_invoke(self, payload: ToolPreInvokePayload) -> ToolPreInvokeResult: @@ -299,7 +299,7 @@ def validate_config(self) -> None: ### Resource Management ```python -class MyPlugin(Plugin): +class MyPlugin(MCPPlugin): def __init__(self, config: PluginConfig): super().__init__(config) self._session = None diff --git a/plugins/ai_artifacts_normalizer/ai_artifacts_normalizer.py b/plugins/ai_artifacts_normalizer/ai_artifacts_normalizer.py index 215e0e4b6..42fadbdf3 100644 --- a/plugins/ai_artifacts_normalizer/ai_artifacts_normalizer.py +++ b/plugins/ai_artifacts_normalizer/ai_artifacts_normalizer.py @@ -19,9 +19,11 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, PromptPrehookPayload, PromptPrehookResult, ResourcePostFetchPayload, @@ -104,7 +106,7 @@ def _normalize_text(text: str, cfg: AINormalizerConfig) -> str: return out -class AIArtifactsNormalizerPlugin(Plugin): +class AIArtifactsNormalizerPlugin(MCPPlugin): """Plugin to normalize AI-generated text artifacts in prompts, resources, and tool results.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/altk_json_processor/json_processor.py b/plugins/altk_json_processor/json_processor.py index 4d1cb25fa..b1664b49d 100644 --- a/plugins/altk_json_processor/json_processor.py +++ b/plugins/altk_json_processor/json_processor.py @@ -23,9 +23,11 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ToolPostInvokePayload, ToolPostInvokeResult, ) @@ -36,7 +38,7 @@ logger = logging_service.get_logger(__name__) -class ALTKJsonProcessor(Plugin): +class ALTKJsonProcessor(MCPPlugin): """Uses JSON Processor from ALTK to extract data from long JSON responses.""" def __init__(self, config: PluginConfig): diff --git a/plugins/argument_normalizer/argument_normalizer.py b/plugins/argument_normalizer/argument_normalizer.py index b25732e25..8a98057c9 100644 --- a/plugins/argument_normalizer/argument_normalizer.py +++ b/plugins/argument_normalizer/argument_normalizer.py @@ -27,9 +27,11 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, PromptPrehookPayload, PromptPrehookResult, ToolPreInvokePayload, @@ -515,7 +517,7 @@ def _normalize_value(value: Any, base_cfg: ArgumentNormalizerConfig, path: str, return value -class ArgumentNormalizerPlugin(Plugin): +class ArgumentNormalizerPlugin(MCPPlugin): """Argument Normalizer plugin for prompts and tools.""" def __init__(self, config: PluginConfig): diff --git a/plugins/cached_tool_result/cached_tool_result.py b/plugins/cached_tool_result/cached_tool_result.py index d4f3961d0..6d3674e19 100644 --- a/plugins/cached_tool_result/cached_tool_result.py +++ b/plugins/cached_tool_result/cached_tool_result.py @@ -25,9 +25,11 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokePayload, @@ -86,7 +88,7 @@ def _make_key(tool: str, args: dict | None, fields: Optional[List[str]]) -> str: return hashlib.sha256(raw.encode("utf-8")).hexdigest() -class CachedToolResultPlugin(Plugin): +class CachedToolResultPlugin(MCPPlugin): """Cache idempotent tool results (write-through).""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/circuit_breaker/circuit_breaker.py b/plugins/circuit_breaker/circuit_breaker.py index 57d748d41..61def4820 100644 --- a/plugins/circuit_breaker/circuit_breaker.py +++ b/plugins/circuit_breaker/circuit_breaker.py @@ -26,10 +26,12 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokePayload, @@ -138,7 +140,7 @@ def _is_error(result: Any) -> bool: return False -class CircuitBreakerPlugin(Plugin): +class CircuitBreakerPlugin(MCPPlugin): """Circuit breaker plugin to prevent cascading failures by tripping on high error rates.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/citation_validator/citation_validator.py b/plugins/citation_validator/citation_validator.py index fc7d71f0f..65c2bf1c4 100644 --- a/plugins/citation_validator/citation_validator.py +++ b/plugins/citation_validator/citation_validator.py @@ -24,10 +24,12 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ResourcePostFetchPayload, ResourcePostFetchResult, ToolPostInvokePayload, @@ -116,7 +118,7 @@ def _extract_links(text: str, limit: int) -> List[str]: return out -class CitationValidatorPlugin(Plugin): +class CitationValidatorPlugin(MCPPlugin): """Validates citations by checking URL reachability and content.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/code_formatter/code_formatter.py b/plugins/code_formatter/code_formatter.py index fe2d51048..c62cdf2da 100644 --- a/plugins/code_formatter/code_formatter.py +++ b/plugins/code_formatter/code_formatter.py @@ -28,9 +28,11 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ResourcePostFetchPayload, ResourcePostFetchResult, ToolPostInvokePayload, @@ -145,7 +147,7 @@ def _format_by_language(result: Any, cfg: CodeFormatterConfig, language: str | N return _normalize_text(text, cfg) -class CodeFormatterPlugin(Plugin): +class CodeFormatterPlugin(MCPPlugin): """Lightweight formatter for post-invoke and resource content.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/code_safety_linter/code_safety_linter.py b/plugins/code_safety_linter/code_safety_linter.py index 7c5d80032..a886fda8c 100644 --- a/plugins/code_safety_linter/code_safety_linter.py +++ b/plugins/code_safety_linter/code_safety_linter.py @@ -21,10 +21,12 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ToolPostInvokePayload, ToolPostInvokeResult, ) @@ -48,7 +50,7 @@ class CodeSafetyConfig(BaseModel): ) -class CodeSafetyLinterPlugin(Plugin): +class CodeSafetyLinterPlugin(MCPPlugin): """Scan text outputs for dangerous code patterns.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/content_moderation/content_moderation.py b/plugins/content_moderation/content_moderation.py index 5a64eb4d7..2a3a9e75a 100644 --- a/plugins/content_moderation/content_moderation.py +++ b/plugins/content_moderation/content_moderation.py @@ -24,10 +24,12 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, PromptPrehookPayload, PromptPrehookResult, ToolPostInvokePayload, @@ -174,7 +176,7 @@ class ModerationResult(BaseModel): details: Dict[str, Any] = Field(default_factory=dict, description="Additional details") -class ContentModerationPlugin(Plugin): +class ContentModerationPlugin(MCPPlugin): """Plugin for advanced content moderation using multiple AI providers.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/deny_filter/deny.py b/plugins/deny_filter/deny.py index 7cf7e3790..0e598f921 100644 --- a/plugins/deny_filter/deny.py +++ b/plugins/deny_filter/deny.py @@ -12,7 +12,8 @@ from pydantic import BaseModel # First-Party -from mcpgateway.plugins.framework import Plugin, PluginConfig, PluginContext, PluginViolation, PromptPrehookPayload, PromptPrehookResult +from mcpgateway.plugins.framework import PluginConfig, PluginContext, PluginViolation +from mcpgateway.plugins.mcp.entities import MCPPlugin, PromptPrehookPayload, PromptPrehookResult from mcpgateway.services.logging_service import LoggingService # Initialize logging service first @@ -30,7 +31,7 @@ class DenyListConfig(BaseModel): words: list[str] -class DenyListPlugin(Plugin): +class DenyListPlugin(MCPPlugin): """Example deny list plugin.""" def __init__(self, config: PluginConfig): diff --git a/plugins/external/clamav_server/clamav_plugin.py b/plugins/external/clamav_server/clamav_plugin.py index efb1962a1..b593da62b 100644 --- a/plugins/external/clamav_server/clamav_plugin.py +++ b/plugins/external/clamav_server/clamav_plugin.py @@ -31,10 +31,12 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, PromptPosthookPayload, PromptPosthookResult, ResourcePostFetchPayload, @@ -119,7 +121,7 @@ def _clamd_instream_scan_unix(path: str, data: bytes, timeout: float) -> str: s.close() -class ClamAVRemotePlugin(Plugin): +class ClamAVRemotePlugin(MCPPlugin): """External ClamAV plugin for scanning resources and content.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/external/llmguard/llmguardplugin/plugin.py b/plugins/external/llmguard/llmguardplugin/plugin.py index bf9d2a985..a548a313a 100644 --- a/plugins/external/llmguard/llmguardplugin/plugin.py +++ b/plugins/external/llmguard/llmguardplugin/plugin.py @@ -15,12 +15,15 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginError, PluginErrorModel, PluginViolation, +) +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -37,7 +40,7 @@ logger = logging_service.get_logger(__name__) -class LLMGuardPlugin(Plugin): +class LLMGuardPlugin(MCPPlugin): """A plugin that leverages the capabilities of llmguard library to apply guardrails on input and output prompts. Attributes: diff --git a/plugins/external/opa/opapluginfilter/plugin.py b/plugins/external/opa/opapluginfilter/plugin.py index 59826a9a5..60867f8a0 100644 --- a/plugins/external/opa/opapluginfilter/plugin.py +++ b/plugins/external/opa/opapluginfilter/plugin.py @@ -19,10 +19,12 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -63,7 +65,7 @@ class OPAResponseTemplates(str, Enum): HookPayload: TypeAlias = ToolPreInvokePayload | ToolPostInvokePayload | PromptPosthookPayload | PromptPrehookPayload | ResourcePreFetchPayload | ResourcePostFetchPayload -class OPAPluginFilter(Plugin): +class OPAPluginFilter(MCPPlugin): """An OPA plugin that enforces rego policies on requests and allows/denies requests as per policies.""" def __init__(self, config: PluginConfig): diff --git a/plugins/file_type_allowlist/file_type_allowlist.py b/plugins/file_type_allowlist/file_type_allowlist.py index d344c52f3..6a38492da 100644 --- a/plugins/file_type_allowlist/file_type_allowlist.py +++ b/plugins/file_type_allowlist/file_type_allowlist.py @@ -22,10 +22,12 @@ # First-Party from mcpgateway.models import ResourceContent from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchPayload, @@ -60,7 +62,7 @@ def _ext_from_uri(uri: str) -> str: return "" -class FileTypeAllowlistPlugin(Plugin): +class FileTypeAllowlistPlugin(MCPPlugin): """Block non-allowed file types for resources.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/harmful_content_detector/harmful_content_detector.py b/plugins/harmful_content_detector/harmful_content_detector.py index 7468cb0d1..c8c3a4900 100644 --- a/plugins/harmful_content_detector/harmful_content_detector.py +++ b/plugins/harmful_content_detector/harmful_content_detector.py @@ -23,10 +23,12 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, PromptPrehookPayload, PromptPrehookResult, ToolPostInvokePayload, @@ -119,7 +121,7 @@ def walk(obj: Any, path: str): yield from walk(value, "") -class HarmfulContentDetectorPlugin(Plugin): +class HarmfulContentDetectorPlugin(MCPPlugin): """Detects harmful content in prompts and tool outputs using keyword lexicons. This plugin scans for self-harm, violence, and hate categories. diff --git a/plugins/header_injector/header_injector.py b/plugins/header_injector/header_injector.py index 59173bdc3..c60cb8724 100644 --- a/plugins/header_injector/header_injector.py +++ b/plugins/header_injector/header_injector.py @@ -22,9 +22,11 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ResourcePreFetchPayload, ResourcePreFetchResult, ) @@ -57,7 +59,7 @@ def _should_apply(uri: str, prefixes: Optional[list[str]]) -> bool: return any(uri.startswith(p) for p in prefixes) -class HeaderInjectorPlugin(Plugin): +class HeaderInjectorPlugin(MCPPlugin): """Inject custom headers for resource fetching.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/html_to_markdown/html_to_markdown.py b/plugins/html_to_markdown/html_to_markdown.py index d3b92b11f..adc3799e5 100644 --- a/plugins/html_to_markdown/html_to_markdown.py +++ b/plugins/html_to_markdown/html_to_markdown.py @@ -20,9 +20,11 @@ # First-Party from mcpgateway.models import ResourceContent from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ResourcePostFetchPayload, ResourcePostFetchResult, ) @@ -85,7 +87,7 @@ def _pre_fallback(m): return text.strip() -class HTMLToMarkdownPlugin(Plugin): +class HTMLToMarkdownPlugin(MCPPlugin): """Transform HTML ResourceContent to Markdown in `text` field.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/json_repair/json_repair.py b/plugins/json_repair/json_repair.py index 470209cc4..565a2914a 100644 --- a/plugins/json_repair/json_repair.py +++ b/plugins/json_repair/json_repair.py @@ -18,9 +18,11 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ToolPostInvokePayload, ToolPostInvokeResult, ) @@ -70,7 +72,7 @@ def _repair(s: str) -> str | None: return None -class JSONRepairPlugin(Plugin): +class JSONRepairPlugin(MCPPlugin): """Repair JSON-like string outputs, returning corrected string if fixable.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/license_header_injector/license_header_injector.py b/plugins/license_header_injector/license_header_injector.py index 563cbee56..e8c398dc7 100644 --- a/plugins/license_header_injector/license_header_injector.py +++ b/plugins/license_header_injector/license_header_injector.py @@ -22,9 +22,11 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ResourcePostFetchPayload, ResourcePostFetchResult, ToolPostInvokePayload, @@ -88,7 +90,7 @@ def _inject_header(text: str, cfg: LicenseHeaderConfig, language: str) -> str: return f"{header_block}\n{text}" -class LicenseHeaderInjectorPlugin(Plugin): +class LicenseHeaderInjectorPlugin(MCPPlugin): """Inject a license header into textual code outputs.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/markdown_cleaner/markdown_cleaner.py b/plugins/markdown_cleaner/markdown_cleaner.py index be1c8e216..61f3b31ca 100644 --- a/plugins/markdown_cleaner/markdown_cleaner.py +++ b/plugins/markdown_cleaner/markdown_cleaner.py @@ -19,9 +19,11 @@ # First-Party from mcpgateway.models import Message, PromptResult, ResourceContent, TextContent from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, PromptPosthookPayload, PromptPosthookResult, ResourcePostFetchPayload, @@ -51,7 +53,7 @@ def _clean_md(text: str) -> str: return text.strip() -class MarkdownCleanerPlugin(Plugin): +class MarkdownCleanerPlugin(MCPPlugin): """Clean Markdown in prompts and resources.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/output_length_guard/output_length_guard.py b/plugins/output_length_guard/output_length_guard.py index 7c494987d..7497cb885 100644 --- a/plugins/output_length_guard/output_length_guard.py +++ b/plugins/output_length_guard/output_length_guard.py @@ -34,10 +34,12 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ToolPostInvokePayload, ToolPostInvokeResult, ) @@ -98,7 +100,7 @@ def _truncate(value: str, max_chars: int, ellipsis: str) -> str: return value[:cut] + ell -class OutputLengthGuardPlugin(Plugin): +class OutputLengthGuardPlugin(MCPPlugin): """Guard tool outputs by length with block or truncate strategies.""" def __init__(self, config: PluginConfig): diff --git a/plugins/pii_filter/pii_filter.py b/plugins/pii_filter/pii_filter.py index 0f7215467..6ae59a5ed 100644 --- a/plugins/pii_filter/pii_filter.py +++ b/plugins/pii_filter/pii_filter.py @@ -19,10 +19,12 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -408,7 +410,7 @@ def _apply_mask(self, value: str, pii_type: PIIType, strategy: MaskingStrategy) return self.config.redaction_text -class PIIFilterPlugin(Plugin): +class PIIFilterPlugin(MCPPlugin): """PII Filter plugin for detecting and masking sensitive information.""" def __init__(self, config: PluginConfig): diff --git a/plugins/privacy_notice_injector/privacy_notice_injector.py b/plugins/privacy_notice_injector/privacy_notice_injector.py index 31e1d503e..80ad5546e 100644 --- a/plugins/privacy_notice_injector/privacy_notice_injector.py +++ b/plugins/privacy_notice_injector/privacy_notice_injector.py @@ -21,9 +21,11 @@ # First-Party from mcpgateway.models import Message, Role, TextContent from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, PromptPosthookPayload, PromptPosthookResult, ) @@ -61,7 +63,7 @@ def _inject_text(existing: str, notice: str, placement: str) -> str: return existing -class PrivacyNoticeInjectorPlugin(Plugin): +class PrivacyNoticeInjectorPlugin(MCPPlugin): """Inject a privacy notice into prompt messages.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/rate_limiter/rate_limiter.py b/plugins/rate_limiter/rate_limiter.py index 78eccafa4..67720afa9 100644 --- a/plugins/rate_limiter/rate_limiter.py +++ b/plugins/rate_limiter/rate_limiter.py @@ -22,10 +22,12 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, PromptPrehookPayload, PromptPrehookResult, ToolPreInvokePayload, @@ -114,7 +116,7 @@ def _allow(key: str, limit: Optional[str]) -> tuple[bool, dict[str, Any]]: return False, {"limited": True, "remaining": 0, "reset_in": window_seconds - (now - wnd.window_start)} -class RateLimiterPlugin(Plugin): +class RateLimiterPlugin(MCPPlugin): """Simple fixed-window rate limiter with per-user/tenant/tool buckets.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/regex_filter/search_replace.py b/plugins/regex_filter/search_replace.py index 79e4fc54f..506f1fafd 100644 --- a/plugins/regex_filter/search_replace.py +++ b/plugins/regex_filter/search_replace.py @@ -16,9 +16,11 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -52,7 +54,7 @@ class SearchReplaceConfig(BaseModel): words: list[SearchReplace] -class SearchReplacePlugin(Plugin): +class SearchReplacePlugin(MCPPlugin): """Example search replace plugin.""" def __init__(self, config: PluginConfig): diff --git a/plugins/resource_filter/resource_filter.py b/plugins/resource_filter/resource_filter.py index d5c191b35..7213e553e 100644 --- a/plugins/resource_filter/resource_filter.py +++ b/plugins/resource_filter/resource_filter.py @@ -19,11 +19,13 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginMode, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchPayload, @@ -33,7 +35,7 @@ ) -class ResourceFilterPlugin(Plugin): +class ResourceFilterPlugin(MCPPlugin): """Plugin that filters and modifies resources. This plugin demonstrates the use of resource hooks to: diff --git a/plugins/response_cache_by_prompt/response_cache_by_prompt.py b/plugins/response_cache_by_prompt/response_cache_by_prompt.py index fa7821817..6fc01533c 100644 --- a/plugins/response_cache_by_prompt/response_cache_by_prompt.py +++ b/plugins/response_cache_by_prompt/response_cache_by_prompt.py @@ -28,9 +28,11 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokePayload, @@ -123,7 +125,7 @@ class _Entry: expires_at: float -class ResponseCacheByPromptPlugin(Plugin): +class ResponseCacheByPromptPlugin(MCPPlugin): """Approximate response cache keyed by prompt similarity.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/retry_with_backoff/retry_with_backoff.py b/plugins/retry_with_backoff/retry_with_backoff.py index ef63ee87f..1cdbd9dd4 100644 --- a/plugins/retry_with_backoff/retry_with_backoff.py +++ b/plugins/retry_with_backoff/retry_with_backoff.py @@ -17,9 +17,11 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ResourcePostFetchPayload, ResourcePostFetchResult, ToolPostInvokePayload, @@ -43,7 +45,7 @@ class RetryPolicyConfig(BaseModel): retry_on_status: list[int] = Field(default_factory=lambda: [429, 500, 502, 503, 504]) -class RetryWithBackoffPlugin(Plugin): +class RetryWithBackoffPlugin(MCPPlugin): """Attach retry/backoff policy in metadata for observability/orchestration.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/robots_license_guard/robots_license_guard.py b/plugins/robots_license_guard/robots_license_guard.py index 3643688bf..820474930 100644 --- a/plugins/robots_license_guard/robots_license_guard.py +++ b/plugins/robots_license_guard/robots_license_guard.py @@ -23,10 +23,12 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchPayload, @@ -87,7 +89,7 @@ def _parse_meta(text: str) -> dict[str, str]: return found -class RobotsLicenseGuardPlugin(Plugin): +class RobotsLicenseGuardPlugin(MCPPlugin): """Honors robots/noai/license meta tags in fetched HTML content.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/safe_html_sanitizer/safe_html_sanitizer.py b/plugins/safe_html_sanitizer/safe_html_sanitizer.py index 1d4364f0f..ebf53d106 100644 --- a/plugins/safe_html_sanitizer/safe_html_sanitizer.py +++ b/plugins/safe_html_sanitizer/safe_html_sanitizer.py @@ -30,9 +30,11 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ResourcePostFetchPayload, ResourcePostFetchResult, ) @@ -276,7 +278,7 @@ def _to_text(html_str: str) -> str: return re.sub(r"\n{3,}", "\n\n", no_tags).strip() -class SafeHTMLSanitizerPlugin(Plugin): +class SafeHTMLSanitizerPlugin(MCPPlugin): """Sanitizes HTML content to remove XSS vectors and dangerous elements.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/schema_guard/schema_guard.py b/plugins/schema_guard/schema_guard.py index e8962b970..b652aa8ff 100644 --- a/plugins/schema_guard/schema_guard.py +++ b/plugins/schema_guard/schema_guard.py @@ -20,10 +20,12 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokePayload, @@ -103,7 +105,7 @@ def _validate(data: Any, schema: Dict[str, Any]) -> list[str]: return errors -class SchemaGuardPlugin(Plugin): +class SchemaGuardPlugin(MCPPlugin): """Validate tool args and results using a simple schema subset.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/secrets_detection/secrets_detection.py b/plugins/secrets_detection/secrets_detection.py index 1d2198a6a..ecdf3e8f1 100644 --- a/plugins/secrets_detection/secrets_detection.py +++ b/plugins/secrets_detection/secrets_detection.py @@ -23,10 +23,12 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, PromptPrehookPayload, PromptPrehookResult, ResourcePostFetchPayload, @@ -159,7 +161,7 @@ def _scan_container(container: Any, cfg: SecretsDetectionConfig) -> Tuple[int, A return total, container, all_findings -class SecretsDetectionPlugin(Plugin): +class SecretsDetectionPlugin(MCPPlugin): """Detect and optionally redact secrets in inputs/outputs.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/sql_sanitizer/sql_sanitizer.py b/plugins/sql_sanitizer/sql_sanitizer.py index 5ad84de02..c7b62b022 100644 --- a/plugins/sql_sanitizer/sql_sanitizer.py +++ b/plugins/sql_sanitizer/sql_sanitizer.py @@ -26,10 +26,12 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, PromptPrehookPayload, PromptPrehookResult, ToolPreInvokePayload, @@ -157,7 +159,7 @@ def _scan_args(args: dict[str, Any] | None, cfg: SQLSanitizerConfig) -> tuple[li return issues, scanned -class SQLSanitizerPlugin(Plugin): +class SQLSanitizerPlugin(MCPPlugin): """Block or sanitize risky SQL statements in inputs.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/summarizer/summarizer.py b/plugins/summarizer/summarizer.py index 9ba229a54..ea936a27d 100644 --- a/plugins/summarizer/summarizer.py +++ b/plugins/summarizer/summarizer.py @@ -23,9 +23,11 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ResourcePostFetchPayload, ResourcePostFetchResult, ToolPostInvokePayload, @@ -260,7 +262,7 @@ def _maybe_get_text_from_result(result: Any) -> Optional[str]: return result if isinstance(result, str) else None -class SummarizerPlugin(Plugin): +class SummarizerPlugin(MCPPlugin): """Plugin to summarize long text content using LLM providers.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/timezone_translator/timezone_translator.py b/plugins/timezone_translator/timezone_translator.py index af644ca7d..2951b9eb6 100644 --- a/plugins/timezone_translator/timezone_translator.py +++ b/plugins/timezone_translator/timezone_translator.py @@ -25,9 +25,11 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokePayload, @@ -131,7 +133,7 @@ def _walk_and_translate(value: Any, source: ZoneInfo, target: ZoneInfo, fields: return value -class TimezoneTranslatorPlugin(Plugin): +class TimezoneTranslatorPlugin(MCPPlugin): """Converts detected ISO timestamps between server and user timezones.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/url_reputation/url_reputation.py b/plugins/url_reputation/url_reputation.py index 35bc2e82d..50023e73a 100644 --- a/plugins/url_reputation/url_reputation.py +++ b/plugins/url_reputation/url_reputation.py @@ -20,10 +20,12 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ResourcePreFetchPayload, ResourcePreFetchResult, ) @@ -41,7 +43,7 @@ class URLReputationConfig(BaseModel): blocked_patterns: List[str] = Field(default_factory=list) -class URLReputationPlugin(Plugin): +class URLReputationPlugin(MCPPlugin): """Static allow/deny URL reputation checks.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/vault/vault_plugin.py b/plugins/vault/vault_plugin.py index 994b9eaf4..4683606d3 100644 --- a/plugins/vault/vault_plugin.py +++ b/plugins/vault/vault_plugin.py @@ -22,13 +22,15 @@ # First-Party from mcpgateway.db import get_db from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, + HttpHeaderPayload, ToolPreInvokePayload, ToolPreInvokeResult, ) -from mcpgateway.plugins.framework.models import HttpHeaderPayload from mcpgateway.services.gateway_service import GatewayService from mcpgateway.services.logging_service import LoggingService @@ -75,7 +77,7 @@ class VaultConfig(BaseModel): system_handling: SystemHandling = SystemHandling.TAG -class Vault(Plugin): +class Vault(MCPPlugin): """Vault plugin that based on OAUTH2 config that protects a tool will generate bearer token based on a vault saved token""" def __init__(self, config: PluginConfig): diff --git a/plugins/virus_total_checker/virus_total_checker.py b/plugins/virus_total_checker/virus_total_checker.py index 5b10f696f..b506916f3 100644 --- a/plugins/virus_total_checker/virus_total_checker.py +++ b/plugins/virus_total_checker/virus_total_checker.py @@ -31,10 +31,12 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, PromptPosthookPayload, PromptPosthookResult, ResourcePostFetchPayload, @@ -332,7 +334,7 @@ def _apply_overrides(url: str, host: str | None, cfg: VirusTotalConfig) -> str | return None -class VirusTotalURLCheckerPlugin(Plugin): +class VirusTotalURLCheckerPlugin(MCPPlugin): """Query VirusTotal for URL/domain/IP verdicts and block on policy breaches.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/watchdog/watchdog.py b/plugins/watchdog/watchdog.py index e61711f4d..1fcf12b2d 100644 --- a/plugins/watchdog/watchdog.py +++ b/plugins/watchdog/watchdog.py @@ -23,10 +23,12 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokePayload, @@ -48,7 +50,7 @@ class WatchdogConfig(BaseModel): tool_overrides: Dict[str, Dict[str, Any]] = {} -class WatchdogPlugin(Plugin): +class WatchdogPlugin(MCPPlugin): """Records tool execution duration and enforces maximum runtime policy.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/webhook_notification/webhook_notification.py b/plugins/webhook_notification/webhook_notification.py index f76577b0c..4c2a686c1 100644 --- a/plugins/webhook_notification/webhook_notification.py +++ b/plugins/webhook_notification/webhook_notification.py @@ -27,10 +27,12 @@ # First-Party from mcpgateway.plugins.framework import ( - Plugin, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -117,7 +119,7 @@ class WebhookNotificationConfig(BaseModel): max_payload_size: int = Field(default=1000, description="Max payload size to include in notifications") -class WebhookNotificationPlugin(Plugin): +class WebhookNotificationPlugin(MCPPlugin): """Plugin for sending webhook notifications on events and violations.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins_rust/docs/implementation-guide.md b/plugins_rust/docs/implementation-guide.md index 6cb71a431..efd520730 100644 --- a/plugins_rust/docs/implementation-guide.md +++ b/plugins_rust/docs/implementation-guide.md @@ -314,7 +314,7 @@ except ImportError: RUST_AVAILABLE = False -class PIIFilterPlugin(Plugin): +class PIIFilterPlugin(MCPPlugin): """PII Filter with automatic Rust/Python selection.""" def __init__(self, config: PluginConfig): diff --git a/tests/integration/test_resource_plugin_integration.py b/tests/integration/test_resource_plugin_integration.py index 30bfa7e79..1582f6610 100644 --- a/tests/integration/test_resource_plugin_integration.py +++ b/tests/integration/test_resource_plugin_integration.py @@ -44,9 +44,19 @@ def resource_service_with_mock_plugins(self): # Standard from unittest.mock import AsyncMock + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + mock_manager = MagicMock() mock_manager._initialized = True mock_manager.initialize = AsyncMock() + # Add default invoke_hook mock that returns success + mock_manager.invoke_hook = AsyncMock( + return_value=( + PluginResult(continue_processing=True, modified_payload=None), + None # contexts + ) + ) MockPluginManager.return_value = mock_manager service = ResourceService() service._plugin_manager = mock_manager @@ -57,20 +67,7 @@ async def test_full_resource_lifecycle_with_plugins(self, test_db, resource_serv """Test complete resource lifecycle with plugin hooks.""" service, mock_manager = resource_service_with_mock_plugins - # Configure mock plugin manager for all operations - # Standard - from unittest.mock import AsyncMock - - pre_result = MagicMock() - pre_result.continue_processing = True - pre_result.modified_payload = None - - post_result = MagicMock() - post_result.continue_processing = True - post_result.modified_payload = None - - mock_manager.resource_pre_fetch = AsyncMock(return_value=(pre_result, {"context": "data"})) - mock_manager.resource_post_fetch = AsyncMock(return_value=(post_result, None)) + # The default invoke_hook from fixture will work fine for this test # 1. Create a resource resource_data = ResourceCreate( @@ -96,8 +93,8 @@ async def test_full_resource_lifecycle_with_plugins(self, test_db, resource_serv ) assert content is not None - mock_manager.resource_pre_fetch.assert_called_once() - mock_manager.resource_post_fetch.assert_called_once() + # Verify hooks were called (pre and post fetch) + assert mock_manager.invoke_hook.call_count >= 2 # 3. List resources resources = await service.list_resources(test_db) @@ -135,7 +132,7 @@ async def test_resource_filtering_integration(self, test_db): # Use real plugin manager but mock its initialization with patch("mcpgateway.services.resource_service.PluginManager") as MockPluginManager: # First-Party - from mcpgateway.plugins.framework.models import ( + from mcpgateway.plugins.mcp.entities import ( ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchResult, @@ -153,58 +150,67 @@ async def initialize(self): def initialized(self) -> bool: return self._initialized - async def resource_pre_fetch(self, payload, global_context, violations_as_exceptions): - # Allow test:// protocol - if payload.uri.startswith("test://"): + async def invoke_hook(self, hook_type, payload, global_context, local_contexts=None, **kwargs): + # First-Party + from mcpgateway.plugins.mcp.entities import HookType + + if hook_type == HookType.RESOURCE_PRE_FETCH: + # Allow test:// protocol + if payload.uri.startswith("test://"): + return ( + ResourcePreFetchResult( + continue_processing=True, + modified_payload=payload, + ), + {"validated": True}, + ) + else: + # First-Party + from mcpgateway.plugins.framework.models import PluginViolation + + raise PluginViolationError( + message="Protocol not allowed", + violation=PluginViolation( + reason="Protocol not allowed", + description="Protocol is not in the allowed list", + code="PROTOCOL_BLOCKED", + details={"protocol": payload.uri.split(":")[0], "uri": payload.uri}, + ), + ) + elif hook_type == HookType.RESOURCE_POST_FETCH: + # Filter sensitive content + if payload.content and payload.content.text: + filtered_text = payload.content.text.replace( + "password: secret123", + "password: [REDACTED]", + ) + filtered_content = ResourceContent( + id=payload.content.id, + type=payload.content.type, + uri=payload.content.uri, + text=filtered_text, + ) + modified_payload = ResourcePostFetchPayload( + uri=payload.uri, + content=filtered_content, + ) + return ( + ResourcePostFetchResult( + continue_processing=True, + modified_payload=modified_payload, + ), + None, + ) return ( - ResourcePreFetchResult( - continue_processing=True, - modified_payload=payload, - ), - {"validated": True}, + ResourcePostFetchResult(continue_processing=True), + None, ) else: + # Other hook types - just return success # First-Party - from mcpgateway.plugins.framework.models import PluginViolation - - raise PluginViolationError( - message="Protocol not allowed", - violation=PluginViolation( - reason="Protocol not allowed", - description="Protocol is not in the allowed list", - code="PROTOCOL_BLOCKED", - details={"protocol": payload.uri.split(":")[0], "uri": payload.uri}, - ), - ) + from mcpgateway.plugins.framework.models import PluginResult - async def resource_post_fetch(self, payload, global_context, contexts, violations_as_exceptions): - # Filter sensitive content - if payload.content and payload.content.text: - filtered_text = payload.content.text.replace( - "password: secret123", - "password: [REDACTED]", - ) - filtered_content = ResourceContent( - id=payload.content.id, - type=payload.content.type, - uri=payload.content.uri, - text=filtered_text, - ) - modified_payload = ResourcePostFetchPayload( - uri=payload.uri, - content=filtered_content, - ) - return ( - ResourcePostFetchResult( - continue_processing=True, - modified_payload=modified_payload, - ), - None, - ) - return ( - ResourcePostFetchResult(continue_processing=True), - None, - ) + return (PluginResult(continue_processing=True), None) MockPluginManager.return_value = MockFilterManager("test.yaml") service = ResourceService() @@ -257,29 +263,37 @@ async def test_plugin_context_flow(self, test_db, resource_service_with_mock_plu service, mock_manager = resource_service_with_mock_plugins # Track context flow + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + from mcpgateway.plugins.mcp.entities import HookType + contexts_from_pre = {"plugin_data": "test_value", "validated": True} - async def pre_fetch_side_effect(payload, global_context, violations_as_exceptions): - # Verify global context - assert global_context.request_id == "integration-test-123" - assert global_context.user == "integration-user" - assert global_context.server_id == "server-123" - return ( - MagicMock(continue_processing=True, modified_payload=None), - contexts_from_pre, - ) - - async def post_fetch_side_effect(payload, global_context, contexts, violations_as_exceptions): - # Verify contexts from pre-fetch - assert contexts == contexts_from_pre - assert contexts["plugin_data"] == "test_value" - return ( - MagicMock(continue_processing=True), - None, - ) - - mock_manager.resource_pre_fetch.side_effect = pre_fetch_side_effect - mock_manager.resource_post_fetch.side_effect = post_fetch_side_effect + async def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): + if hook_type == HookType.RESOURCE_PRE_FETCH: + # Verify global context + assert global_context.request_id == "integration-test-123" + assert global_context.user == "integration-user" + assert global_context.server_id == "server-123" + return ( + PluginResult(continue_processing=True, modified_payload=None), + contexts_from_pre, + ) + elif hook_type == HookType.RESOURCE_POST_FETCH: + # Verify contexts from pre-fetch + assert local_contexts == contexts_from_pre + assert local_contexts["plugin_data"] == "test_value" + return ( + PluginResult(continue_processing=True), + None, + ) + else: + return (PluginResult(continue_processing=True), None) + + # Standard + from unittest.mock import AsyncMock + + mock_manager.invoke_hook = AsyncMock(side_effect=invoke_hook_side_effect) # Create and read a resource resource = ResourceCreate( @@ -297,29 +311,15 @@ async def post_fetch_side_effect(payload, global_context, contexts, violations_a server_id="server-123", ) - mock_manager.resource_pre_fetch.assert_called_once() - mock_manager.resource_post_fetch.assert_called_once() + # Verify hooks were called + assert mock_manager.invoke_hook.call_count >= 2 @pytest.mark.asyncio async def test_template_resource_with_plugins(self, test_db, resource_service_with_mock_plugins): """Test resources work with plugins using template-like content.""" service, mock_manager = resource_service_with_mock_plugins - # Configure plugin manager - # Standard - from unittest.mock import AsyncMock - - # Create proper mock results - pre_result = MagicMock() - pre_result.continue_processing = True - pre_result.modified_payload = None - - post_result = MagicMock() - post_result.continue_processing = True - post_result.modified_payload = None - - mock_manager.resource_pre_fetch = AsyncMock(return_value=(pre_result, {"context": "data"})) - mock_manager.resource_post_fetch = AsyncMock(return_value=(post_result, None)) + # The default invoke_hook from fixture will work fine # Create a regular resource with template-like content resource = ResourceCreate( @@ -332,24 +332,15 @@ async def test_template_resource_with_plugins(self, test_db, resource_service_wi content = await service.read_resource(test_db, created.id) assert content.text == "Data for ID: 123" - mock_manager.resource_pre_fetch.assert_called_once() - mock_manager.resource_post_fetch.assert_called_once() + # Verify hooks were called + assert mock_manager.invoke_hook.call_count >= 2 @pytest.mark.asyncio async def test_inactive_resource_handling(self, test_db, resource_service_with_mock_plugins): """Test that inactive resources are handled correctly with plugins.""" service, mock_manager = resource_service_with_mock_plugins - # Configure mock plugin manager - # Standard - from unittest.mock import AsyncMock - - pre_result = MagicMock() - pre_result.continue_processing = True - pre_result.modified_payload = None - - mock_manager.resource_pre_fetch = AsyncMock(return_value=(pre_result, None)) - mock_manager.resource_post_fetch = AsyncMock() + # The default invoke_hook from fixture will work fine # Create a resource resource = ResourceCreate( @@ -373,5 +364,5 @@ async def test_inactive_resource_handling(self, test_db, resource_service_with_m assert "exists but is inactive" in str(exc_info.value) # Pre-fetch is called but post-fetch should not be called for inactive resources - mock_manager.resource_pre_fetch.assert_called_once() - mock_manager.resource_post_fetch.assert_not_called() + # Only one invoke_hook call (pre-fetch) since error occurs before post-fetch + assert mock_manager.invoke_hook.call_count == 1 diff --git a/tests/unit/mcpgateway/plugins/fixtures/plugins/context.py b/tests/unit/mcpgateway/plugins/fixtures/plugins/context.py index eef673450..c5b3fc354 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/plugins/context.py +++ b/tests/unit/mcpgateway/plugins/fixtures/plugins/context.py @@ -8,9 +8,9 @@ Context plugin. """ -from mcpgateway.plugins.framework import ( - Plugin, - PluginContext, +from mcpgateway.plugins.framework import PluginContext +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -26,7 +26,7 @@ ) -class ContextPlugin(Plugin): +class ContextPlugin(MCPPlugin): """A simple Context plugin.""" async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: @@ -111,7 +111,7 @@ async def resource_pre_fetch(self, payload: ResourcePreFetchPayload, context: Pl return ResourcePreFetchResult(continue_processing=True) -class ContextPlugin2(Plugin): +class ContextPlugin2(MCPPlugin): """A simple Context plugin.""" async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: diff --git a/tests/unit/mcpgateway/plugins/fixtures/plugins/error.py b/tests/unit/mcpgateway/plugins/fixtures/plugins/error.py index d15f110c1..e0d44f874 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/plugins/error.py +++ b/tests/unit/mcpgateway/plugins/fixtures/plugins/error.py @@ -8,9 +8,9 @@ Error plugin. """ -from mcpgateway.plugins.framework import ( - Plugin, - PluginContext, +from mcpgateway.plugins.framework import PluginContext +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -26,7 +26,7 @@ ) -class ErrorPlugin(Plugin): +class ErrorPlugin(MCPPlugin): """A simple error plugin.""" async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: diff --git a/tests/unit/mcpgateway/plugins/fixtures/plugins/headers.py b/tests/unit/mcpgateway/plugins/fixtures/plugins/headers.py index 1ba97649d..0d61aadd5 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/plugins/headers.py +++ b/tests/unit/mcpgateway/plugins/fixtures/plugins/headers.py @@ -13,9 +13,11 @@ from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA from mcpgateway.plugins.framework import ( - HttpHeaderPayload, - Plugin, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, + HttpHeaderPayload, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -33,7 +35,7 @@ logger = logging.getLogger("header_plugin") -class HeadersMetaDataPlugin(Plugin): +class HeadersMetaDataPlugin(MCPPlugin): """A simple header plugin to read and modify headers.""" async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: @@ -140,7 +142,7 @@ async def resource_pre_fetch(self, payload: ResourcePreFetchPayload, context: Pl return ResourcePreFetchResult(continue_processing=True) -class HeadersPlugin(Plugin): +class HeadersPlugin(MCPPlugin): """A simple header plugin to read and modify headers.""" async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: diff --git a/tests/unit/mcpgateway/plugins/fixtures/plugins/passthrough.py b/tests/unit/mcpgateway/plugins/fixtures/plugins/passthrough.py index 8a6db5869..b858b8ea8 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/plugins/passthrough.py +++ b/tests/unit/mcpgateway/plugins/fixtures/plugins/passthrough.py @@ -8,9 +8,9 @@ """ # First-Party -from mcpgateway.plugins.framework import ( - Plugin, - PluginContext, +from mcpgateway.plugins.framework import PluginContext +from mcpgateway.plugins.mcp.entities import ( + MCPPlugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -26,7 +26,7 @@ ) -class PassThroughPlugin(Plugin): +class PassThroughPlugin(MCPPlugin): """A simple pass through plugin.""" async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: diff --git a/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py b/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py index 288275e8f..524a6b60f 100644 --- a/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py +++ b/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py @@ -18,6 +18,8 @@ from mcpgateway.plugins.framework import ( GlobalContext, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( PromptPosthookPayload, PromptPrehookPayload, ResourcePostFetchPayload, diff --git a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_config.py b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_config.py index 6c960ce51..313bf6ed9 100644 --- a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_config.py +++ b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_config.py @@ -22,6 +22,9 @@ ConfigLoader, GlobalContext, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + HookType, PromptPosthookPayload, PromptPrehookPayload, ResourcePostFetchPayload, @@ -121,35 +124,35 @@ async def test_hook_methods_empty_content(): # Test prompt_pre_fetch with empty content - should raise PluginError payload = PromptPrehookPayload(prompt_id="1", args={}) with pytest.raises(PluginError): - await plugin.prompt_pre_fetch(payload, context) + await plugin.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, context) # Test prompt_post_fetch with empty content - should raise PluginError message = Message(content=TextContent(type="text", text="test"), role=Role.USER) prompt_result = PromptResult(messages=[message]) payload = PromptPosthookPayload(prompt_id="1", result=prompt_result) with pytest.raises(PluginError): - await plugin.prompt_post_fetch(payload, context) + await plugin.invoke_hook(HookType.PROMPT_POST_FETCH, payload, context) # Test tool_pre_invoke with empty content - should raise PluginError payload = ToolPreInvokePayload(name="test", args={}) with pytest.raises(PluginError): - await plugin.tool_pre_invoke(payload, context) + await plugin.invoke_hook(HookType.TOOL_PRE_INVOKE, payload, context) # Test tool_post_invoke with empty content - should raise PluginError payload = ToolPostInvokePayload(name="test", result={}) with pytest.raises(PluginError): - await plugin.tool_post_invoke(payload, context) + await plugin.invoke_hook(HookType.TOOL_POST_INVOKE, payload, context) # Test resource_pre_fetch with empty content - should raise PluginError payload = ResourcePreFetchPayload(uri="file://test.txt") with pytest.raises(PluginError): - await plugin.resource_pre_fetch(payload, context) + await plugin.invoke_hook(HookType.RESOURCE_PRE_FETCH, payload, context) # Test resource_post_fetch with empty content - should raise PluginError resource_content = ResourceContent(type="resource", id="123",uri="file://test.txt", text="content") payload = ResourcePostFetchPayload(uri="file://test.txt", content=resource_content) with pytest.raises(PluginError): - await plugin.resource_post_fetch(payload, context) + await plugin.invoke_hook(HookType.RESOURCE_POST_FETCH, payload, context) await plugin.shutdown() diff --git a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_stdio.py b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_stdio.py index 53f5f8e2b..e7ab7100d 100644 --- a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_stdio.py +++ b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_stdio.py @@ -29,6 +29,9 @@ PluginContext, PluginLoader, PluginManager, +) +from mcpgateway.plugins.mcp.entities import ( + HookType, PromptPosthookPayload, PromptPrehookPayload, ResourcePostFetchPayload, @@ -48,7 +51,7 @@ async def test_client_load_stdio(): loader = PluginLoader() plugin = await loader.load_and_instantiate_plugin(config.plugins[0]) prompt = PromptPrehookPayload(prompt_id="test_prompt", args={"text": "That was innovative!"}) - result = await plugin.prompt_pre_fetch(prompt, PluginContext(global_context=GlobalContext(request_id="1", server_id="2"))) + result = await plugin.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, PluginContext(global_context=GlobalContext(request_id="1", server_id="2"))) assert result.violation assert result.violation.reason == "Prompt not allowed" assert result.violation.description == "A deny word was found in the prompt" @@ -72,7 +75,7 @@ async def test_client_load_stdio_overrides(): loader = PluginLoader() plugin = await loader.load_and_instantiate_plugin(config.plugins[0]) prompt = PromptPrehookPayload(prompt_id="test_prompt", args = {"text": "That was innovative!"}) - result = await plugin.prompt_pre_fetch(prompt, PluginContext(global_context=GlobalContext(request_id="1", server_id="2"))) + result = await plugin.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, PluginContext(global_context=GlobalContext(request_id="1", server_id="2"))) assert result.violation assert result.violation.reason == "Prompt not allowed" assert result.violation.description == "A deny word was found in the prompt" @@ -98,7 +101,7 @@ async def test_client_load_stdio_post_prompt(): plugin = await loader.load_and_instantiate_plugin(config.plugins[0]) prompt = PromptPrehookPayload(prompt_id="test_prompt", args = {"user": "What a crapshow!"}) context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) - result = await plugin.prompt_pre_fetch(prompt, context) + result = await plugin.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, context) assert result.modified_payload.args["user"] == "What a yikesshow!" config = plugin.config assert config.name == "ReplaceBadWordsPlugin" @@ -111,7 +114,7 @@ async def test_client_load_stdio_post_prompt(): payload_result = PromptPosthookPayload(prompt_id="test_prompt", result=prompt_result) - result = await plugin.prompt_post_fetch(payload_result, context=context) + result = await plugin.invoke_hook(HookType.PROMPT_POST_FETCH, payload_result, context=context) assert len(result.modified_payload.result.messages) == 1 assert result.modified_payload.result.messages[0].content.text == "What the yikes?" await plugin.shutdown() @@ -185,7 +188,7 @@ async def test_hooks(): await plugin_manager.initialize() payload = PromptPrehookPayload(prompt_id="test_prompt", name="test_prompt", args={"arg0": "This is a crap argument"}) global_context = GlobalContext(request_id="1") - result, _ = await plugin_manager.prompt_pre_fetch(payload, global_context) + result, _ = await plugin_manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, global_context) # Assert expected behaviors assert result.continue_processing """Test prompt post hook across all registered plugins.""" @@ -193,31 +196,31 @@ async def test_hooks(): message = Message(content=TextContent(type="text", text="prompt"), role=Role.USER) prompt_result = PromptResult(messages=[message]) payload = PromptPosthookPayload(prompt_id="test_prompt", result=prompt_result) - result, _ = await plugin_manager.prompt_post_fetch(payload, global_context) + result, _ = await plugin_manager.invoke_hook(HookType.PROMPT_POST_FETCH, payload, global_context) # Assert expected behaviors assert result.continue_processing """Test tool pre hook across all registered plugins.""" # Customize payload for testing payload = ToolPreInvokePayload(name="test_prompt", args={"arg0": "This is an argument"}) - result, _ = await plugin_manager.tool_pre_invoke(payload, global_context) + result, _ = await plugin_manager.invoke_hook(HookType.TOOL_PRE_INVOKE, payload, global_context) # Assert expected behaviors assert result.continue_processing """Test tool post hook across all registered plugins.""" # Customize payload for testing payload = ToolPostInvokePayload(name="test_tool", result={"output0": "output value"}) - result, _ = await plugin_manager.tool_post_invoke(payload, global_context) + result, _ = await plugin_manager.invoke_hook(HookType.TOOL_POST_INVOKE, payload, global_context) # Assert expected behaviors assert result.continue_processing payload = ResourcePreFetchPayload(uri="file:///data.txt") - result, _ = await plugin_manager.resource_pre_fetch(payload, global_context) + result, _ = await plugin_manager.invoke_hook(HookType.RESOURCE_PRE_FETCH, payload, global_context) # Assert expected behaviors assert result.continue_processing content = ResourceContent(type="resource", id="123", uri="file:///data.txt", text="Hello World") payload = ResourcePostFetchPayload(uri="file:///data.txt", content=content) - result, _ = await plugin_manager.resource_post_fetch(payload, global_context) + result, _ = await plugin_manager.invoke_hook(HookType.RESOURCE_POST_FETCH, payload, global_context) # Assert expected behaviors assert result.continue_processing await plugin_manager.shutdown() @@ -233,7 +236,7 @@ async def test_errors(): global_context = GlobalContext(request_id="1") escaped_regex = re.escape("ValueError('Sadly! Prompt prefetch is broken!')") with pytest.raises(PluginError, match=escaped_regex): - await plugin_manager.prompt_pre_fetch(payload, global_context) + await plugin_manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, global_context) await plugin_manager.shutdown() @@ -250,7 +253,7 @@ async def test_shared_context_across_pre_post_hooks_multi_plugins(): # Test tool pre-invoke with transformation - use correct tool name from config tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.tool_pre_invoke(tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) assert len(contexts) == 2 ctxs = [contexts[key] for key in contexts.keys()] @@ -279,7 +282,7 @@ async def test_shared_context_across_pre_post_hooks_multi_plugins(): assert result.modified_payload is None # Test tool post-invoke with transformation tool_result_payload = ToolPostInvokePayload(name="test_tool", result={"output": "Result was bad", "status": "wrong format"}) - result, contexts = await manager.tool_post_invoke(tool_result_payload, global_context=global_context, local_contexts=contexts) + result, contexts = await manager.invoke_hook(HookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) ctxs = [contexts[key] for key in contexts.keys()] assert len(ctxs) == 2 diff --git a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_streamable_http.py b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_streamable_http.py index 72fdf82f6..dd0eb8b68 100644 --- a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_streamable_http.py +++ b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_streamable_http.py @@ -18,7 +18,8 @@ # First-Party from mcpgateway.models import Message, PromptResult, Role, TextContent -from mcpgateway.plugins.framework import ConfigLoader, GlobalContext, PluginContext, PluginLoader, PromptPosthookPayload, PromptPrehookPayload +from mcpgateway.plugins.framework import ConfigLoader, GlobalContext, PluginContext, PluginLoader +from mcpgateway.plugins.mcp.entities import PromptPosthookPayload, PromptPrehookPayload @pytest.fixture(autouse=True) diff --git a/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py b/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py index 9c7f15174..114f8449b 100644 --- a/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py +++ b/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py @@ -17,7 +17,8 @@ from mcpgateway.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework.loader.config import ConfigLoader from mcpgateway.plugins.framework.loader.plugin import PluginLoader -from mcpgateway.plugins.framework.models import GlobalContext, PluginContext, PluginMode, PromptPosthookPayload, PromptPrehookPayload +from mcpgateway.plugins.framework import GlobalContext, PluginContext, PluginMode +from mcpgateway.plugins.mcp.entities import PromptPosthookPayload, PromptPrehookPayload from plugins.regex_filter.search_replace import SearchReplaceConfig, SearchReplacePlugin from unittest.mock import patch diff --git a/tests/unit/mcpgateway/plugins/framework/test_context.py b/tests/unit/mcpgateway/plugins/framework/test_context.py index f84a94fde..0f8a3e0ba 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_context.py +++ b/tests/unit/mcpgateway/plugins/framework/test_context.py @@ -11,6 +11,9 @@ from mcpgateway.plugins.framework import ( GlobalContext, PluginManager, +) +from mcpgateway.plugins.mcp.entities import ( + HookType, ToolPreInvokePayload, ToolPostInvokePayload, ) @@ -25,7 +28,7 @@ async def test_shared_context_across_pre_post_hooks(): # Test tool pre-invoke with transformation - use correct tool name from config tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.tool_pre_invoke(tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) assert len(contexts) == 1 context = next(iter(contexts.values())) @@ -42,7 +45,7 @@ async def test_shared_context_across_pre_post_hooks(): # Test tool post-invoke with transformation tool_result_payload = ToolPostInvokePayload(name="test_tool", result={"output": "Result was bad", "status": "wrong format"}) - result, contexts = await manager.tool_post_invoke(tool_result_payload, global_context=global_context, local_contexts=contexts) + result, contexts = await manager.invoke_hook(HookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) assert len(contexts) == 1 context = next(iter(contexts.values())) @@ -71,7 +74,7 @@ async def test_shared_context_across_pre_post_hooks_multi_plugins(): # Test tool pre-invoke with transformation - use correct tool name from config tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.tool_pre_invoke(tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) assert len(contexts) == 2 ctxs = [contexts[key] for key in contexts.keys()] @@ -100,7 +103,7 @@ async def test_shared_context_across_pre_post_hooks_multi_plugins(): assert result.modified_payload is None # Test tool post-invoke with transformation tool_result_payload = ToolPostInvokePayload(name="test_tool", result={"output": "Result was bad", "status": "wrong format"}) - result, contexts = await manager.tool_post_invoke(tool_result_payload, global_context=global_context, local_contexts=contexts) + result, contexts = await manager.invoke_hook(HookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) ctxs = [contexts[key] for key in contexts.keys()] assert len(ctxs) == 2 diff --git a/tests/unit/mcpgateway/plugins/framework/test_errors.py b/tests/unit/mcpgateway/plugins/framework/test_errors.py index 9dccc1706..d74be9911 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_errors.py +++ b/tests/unit/mcpgateway/plugins/framework/test_errors.py @@ -16,9 +16,10 @@ PluginError, PluginMode, PluginManager, - PromptPrehookPayload, ) +from mcpgateway.plugins.mcp.entities import HookType, PromptPrehookPayload + @pytest.mark.asyncio async def test_convert_exception_to_error(): @@ -40,7 +41,7 @@ async def test_error_plugin(): global_context = GlobalContext(request_id="1") escaped_regex = re.escape("ValueError('Sadly! Prompt prefetch is broken!')") with pytest.raises(PluginError, match=escaped_regex): - await plugin_manager.prompt_pre_fetch(payload, global_context) + await plugin_manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, global_context) await plugin_manager.shutdown() @@ -51,14 +52,14 @@ async def test_error_plugin_raise_error_false(): payload = PromptPrehookPayload(prompt_id="test_prompt", args={"arg0": "This is a crap argument"}) global_context = GlobalContext(request_id="1") with pytest.raises(PluginError): - result, _ = await plugin_manager.prompt_pre_fetch(payload, global_context) + result, _ = await plugin_manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, global_context) # assert result.continue_processing # assert not result.modified_payload await plugin_manager.shutdown() plugin_manager.config.plugins[0].mode = PluginMode.ENFORCE_IGNORE_ERROR await plugin_manager.initialize() - result, _ = await plugin_manager.prompt_pre_fetch(payload, global_context) + result, _ = await plugin_manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, global_context) assert result.continue_processing assert not result.modified_payload await plugin_manager.shutdown() diff --git a/tests/unit/mcpgateway/plugins/framework/test_manager.py b/tests/unit/mcpgateway/plugins/framework/test_manager.py index 7c58772c1..7df5b6d70 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_manager.py +++ b/tests/unit/mcpgateway/plugins/framework/test_manager.py @@ -12,7 +12,8 @@ # First-Party from mcpgateway.models import Message, PromptResult, Role, TextContent -from mcpgateway.plugins.framework import GlobalContext, HttpHeaderPayload, PluginManager, PluginViolationError, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, ToolPreInvokePayload +from mcpgateway.plugins.framework import GlobalContext, PluginManager, PluginViolationError +from mcpgateway.plugins.mcp.entities import HookType, HttpHeaderPayload, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, ToolPreInvokePayload from plugins.regex_filter.search_replace import SearchReplaceConfig @@ -34,7 +35,7 @@ async def test_manager_single_transformer_prompt_plugin(): assert srconfig.words[0].replace == "crud" prompt = PromptPrehookPayload(prompt_id="test_prompt", args={"user": "What a crapshow!"}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, contexts = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) assert len(result.modified_payload.args) == 1 assert result.modified_payload.args["user"] == "What a yikesshow!" @@ -44,7 +45,7 @@ async def test_manager_single_transformer_prompt_plugin(): payload_result = PromptPosthookPayload(prompt_id="test_prompt", result=prompt_result) - result, _ = await manager.prompt_post_fetch(payload_result, global_context=global_context, local_contexts=contexts) + result, _ = await manager.invoke_hook(HookType.PROMPT_POST_FETCH, payload_result, global_context=global_context, local_contexts=contexts) assert len(result.modified_payload.result.messages) == 1 assert result.modified_payload.result.messages[0].content.text == "What a yikesshow!" await manager.shutdown() @@ -82,7 +83,7 @@ async def test_manager_multiple_transformer_preprompt_plugin(): prompt = PromptPrehookPayload(prompt_id="test_prompt", args={"user": "It's always happy at the crapshow."}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, contexts = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) assert len(result.modified_payload.args) == 1 assert result.modified_payload.args["user"] == "It's always gleeful at the yikesshow." @@ -92,7 +93,7 @@ async def test_manager_multiple_transformer_preprompt_plugin(): payload_result = PromptPosthookPayload(prompt_id="test_prompt", result=prompt_result) - result, _ = await manager.prompt_post_fetch(payload_result, global_context=global_context, local_contexts=contexts) + result, _ = await manager.invoke_hook(HookType.PROMPT_POST_FETCH, payload_result, global_context=global_context, local_contexts=contexts) assert len(result.modified_payload.result.messages) == 1 assert result.modified_payload.result.messages[0].content.text == "It's sullen at the yikes bakery." await manager.shutdown() @@ -105,7 +106,7 @@ async def test_manager_no_plugins(): assert manager.initialized prompt = PromptPrehookPayload(prompt_id="test_prompt", args={"user": "It's always happy at the crapshow."}) global_context = GlobalContext(request_id="1", server_id="2") - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) assert result.continue_processing assert not result.modified_payload await manager.shutdown() @@ -118,12 +119,12 @@ async def test_manager_filter_plugins(): assert manager.initialized prompt = PromptPrehookPayload(prompt_id="test_prompt", args={"user": "innovative"}) global_context = GlobalContext(request_id="1", server_id="2") - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) assert not result.continue_processing assert result.violation with pytest.raises(PluginViolationError) as ve: - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context, violations_as_exceptions=True) + result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context, violations_as_exceptions=True) assert ve.value.violation assert ve.value.violation.reason == "Prompt not allowed" await manager.shutdown() @@ -136,11 +137,11 @@ async def test_manager_multi_filter_plugins(): assert manager.initialized prompt = PromptPrehookPayload(prompt_id="test_prompt", args={"user": "innovative crapshow."}) global_context = GlobalContext(request_id="1", server_id="2") - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) assert not result.continue_processing assert result.violation with pytest.raises(PluginViolationError) as ve: - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context, violations_as_exceptions=True) + result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context, violations_as_exceptions=True) assert ve.value.violation await manager.shutdown() @@ -155,7 +156,7 @@ async def test_manager_tool_hooks_empty(): # Test tool pre-invoke with no plugins tool_payload = ToolPreInvokePayload(name="calculator", args={"operation": "add", "a": 5, "b": 3}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.tool_pre_invoke(tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) # Should continue processing with no modifications assert result.continue_processing @@ -165,7 +166,7 @@ async def test_manager_tool_hooks_empty(): # Test tool post-invoke with no plugins tool_result_payload = ToolPostInvokePayload(name="calculator", result={"result": 8, "status": "success"}) - result, contexts = await manager.tool_post_invoke(tool_result_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(HookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context) # Should continue processing with no modifications assert result.continue_processing @@ -186,7 +187,7 @@ async def test_manager_tool_hooks_with_transformer_plugin(): # Test tool pre-invoke - no plugins configured for tool hooks tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is crap data"}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.tool_pre_invoke(tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) # Should continue processing with no modifications (no plugins for tool hooks) assert result.continue_processing @@ -196,7 +197,7 @@ async def test_manager_tool_hooks_with_transformer_plugin(): # Test tool post-invoke - no plugins configured for tool hooks tool_result_payload = ToolPostInvokePayload(name="test_tool", result={"output": "Result with crap in it"}) - result, _ = await manager.tool_post_invoke(tool_result_payload, global_context=global_context, local_contexts=contexts) + result, _ = await manager.invoke_hook(HookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) # Should continue processing with no modifications (no plugins for tool hooks) assert result.continue_processing @@ -216,7 +217,7 @@ async def test_manager_tool_hooks_with_actual_plugin(): # Test tool pre-invoke with transformation - use correct tool name from config tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.tool_pre_invoke(tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) # Should continue processing with transformations applied assert result.continue_processing @@ -228,7 +229,7 @@ async def test_manager_tool_hooks_with_actual_plugin(): # Test tool post-invoke with transformation tool_result_payload = ToolPostInvokePayload(name="test_tool", result={"output": "Result was bad", "status": "wrong format"}) - result, _ = await manager.tool_post_invoke(tool_result_payload, global_context=global_context, local_contexts=contexts) + result, _ = await manager.invoke_hook(HookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) # Should continue processing with transformations applied assert result.continue_processing @@ -251,7 +252,7 @@ async def test_manager_tool_hooks_with_header_mods(): # Test tool pre-invoke with transformation - use correct tool name from config tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}, headers=None) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.tool_pre_invoke(tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) # Should continue processing with transformations applied assert result.continue_processing @@ -267,7 +268,7 @@ async def test_manager_tool_hooks_with_header_mods(): # Test tool pre-invoke with transformation - use correct tool name from config tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}, headers=HttpHeaderPayload({"Content-Type": "application/json"})) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.tool_pre_invoke(tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) # Should continue processing with transformations applied assert result.continue_processing diff --git a/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py b/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py index e8e1d8968..2e6bac7f6 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py +++ b/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py @@ -17,11 +17,10 @@ # First-Party from mcpgateway.models import Message, PromptResult, Role, TextContent -from mcpgateway.plugins.framework.base import Plugin +from mcpgateway.plugins.framework.base import HookRef, Plugin from mcpgateway.plugins.framework.models import Config from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginCondition, PluginConfig, PluginContext, @@ -31,6 +30,11 @@ PluginResult, PluginViolation, PluginViolationError, +) + +from mcpgateway.plugins.mcp.entities import ( + HookType, + MCPPlugin, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, @@ -44,7 +48,7 @@ async def test_manager_timeout_handling(): """Test plugin timeout handling in both enforce and permissive modes.""" # Create a plugin that times out - class TimeoutPlugin(Plugin): + class TimeoutPlugin(MCPPlugin): async def prompt_pre_fetch(self, payload, context): await asyncio.sleep(10) # Longer than timeout return PluginResult(continue_processing=True) @@ -52,7 +56,7 @@ async def prompt_pre_fetch(self, payload, context): # Test with enforce mode manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") await manager.initialize() - manager._pre_prompt_executor.timeout = 0.01 # Set very short timeout + manager._executor.timeout = 0.01 # Set very short timeout # Mock plugin registry plugin_config = PluginConfig( @@ -60,16 +64,16 @@ async def prompt_pre_fetch(self, payload, context): ) timeout_plugin = TimeoutPlugin(plugin_config) - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(timeout_plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(timeout_plugin)) + mock_get.return_value = [hook_ref] prompt = PromptPrehookPayload(prompt_id="test", args={}) global_context = GlobalContext(request_id="1") escaped_regex = re.escape("Plugin TimeoutPlugin exceeded 0.01s timeout") with pytest.raises(PluginError, match=escaped_regex): - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should pass since fail_on_plugin_error: false # assert result.continue_processing @@ -79,11 +83,11 @@ async def prompt_pre_fetch(self, payload, context): # Test with permissive mode plugin_config.mode = PluginMode.PERMISSIVE - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(timeout_plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(timeout_plugin)) + mock_get.return_value = [hook_ref] - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should continue in permissive mode assert result.continue_processing @@ -97,7 +101,7 @@ async def test_manager_exception_handling(): """Test plugin exception handling in both enforce and permissive modes.""" # Create a plugin that raises an exception - class ErrorPlugin(Plugin): + class ErrorPlugin(MCPPlugin): async def prompt_pre_fetch(self, payload, context): raise RuntimeError("Plugin error!") @@ -110,16 +114,16 @@ async def prompt_pre_fetch(self, payload, context): error_plugin = ErrorPlugin(plugin_config) # Test with enforce mode - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(error_plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) + mock_get.return_value = [hook_ref] prompt = PromptPrehookPayload(prompt_id="test", args={}) global_context = GlobalContext(request_id="1") escaped_regex = re.escape("RuntimeError('Plugin error!')") with pytest.raises(PluginError, match=escaped_regex): - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should block in enforce mode # assert result.continue_processing @@ -129,44 +133,44 @@ async def prompt_pre_fetch(self, payload, context): # Test with permissive mode plugin_config.mode = PluginMode.PERMISSIVE - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(error_plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) + mock_get.return_value = [hook_ref] - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should continue in permissive mode assert result.continue_processing assert result.violation is None plugin_config.mode = PluginMode.ENFORCE_IGNORE_ERROR - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(error_plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) + mock_get.return_value = [hook_ref] - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should continue in enforce_ignore_error mode assert result.continue_processing assert result.violation is None plugin_config.mode = PluginMode.ENFORCE_IGNORE_ERROR - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(error_plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) + mock_get.return_value = [hook_ref] - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should continue in enforce_ignore_error mode assert result.continue_processing assert result.violation is None plugin_config.mode = PluginMode.ENFORCE_IGNORE_ERROR - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(error_plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) + mock_get.return_value = [hook_ref] - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should continue in enforce_ignore_error mode assert result.continue_processing @@ -175,68 +179,68 @@ async def prompt_pre_fetch(self, payload, context): await manager.shutdown() -@pytest.mark.asyncio -async def test_manager_condition_filtering(): - """Test that plugins are filtered based on conditions.""" +# @pytest.mark.asyncio +# async def test_manager_condition_filtering(): +# """Test that plugins are filtered based on conditions.""" - class ConditionalPlugin(Plugin): - async def prompt_pre_fetch(self, payload, context): - payload.args["modified"] = "yes" - return PluginResult(continue_processing=True, modified_payload=payload) +# class ConditionalPlugin(MCPPlugin): +# async def prompt_pre_fetch(self, payload, context): +# payload.args["modified"] = "yes" +# return PluginResult(continue_processing=True, modified_payload=payload) - manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") - await manager.initialize() +# manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") +# await manager.initialize() - # Plugin with server_id condition - plugin_config = PluginConfig( - name="ConditionalPlugin", - description="Test conditional plugin", - author="Test", - version="1.0", - tags=["test"], - kind="ConditionalPlugin", - hooks=["prompt_pre_fetch"], - config={}, - conditions=[PluginCondition(server_ids={"server1"})], - ) - plugin = ConditionalPlugin(plugin_config) +# # Plugin with server_id condition +# plugin_config = PluginConfig( +# name="ConditionalPlugin", +# description="Test conditional plugin", +# author="Test", +# version="1.0", +# tags=["test"], +# kind="ConditionalPlugin", +# hooks=["prompt_pre_fetch"], +# config={}, +# conditions=[PluginCondition(server_ids={"server1"})], +# ) +# plugin = ConditionalPlugin(plugin_config) - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(plugin) - mock_get.return_value = [plugin_ref] +# with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: +# plugin_ref = PluginRef(plugin) +# mock_get.return_value = [plugin_ref] - prompt = PromptPrehookPayload(prompt_id="test", args={}) +# prompt = PromptPrehookPayload(prompt_id="test", args={}) - # Test with matching server_id - global_context = GlobalContext(request_id="1", server_id="server1") - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) +# # Test with matching server_id +# global_context = GlobalContext(request_id="1", server_id="server1") +# result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) - # Plugin should execute - assert result.continue_processing - assert result.modified_payload is not None - assert result.modified_payload.args.get("modified") == "yes" +# # Plugin should execute +# assert result.continue_processing +# assert result.modified_payload is not None +# assert result.modified_payload.args.get("modified") == "yes" - # Test with non-matching server_id - prompt2 = PromptPrehookPayload(prompt_id="test", args={}) - global_context2 = GlobalContext(request_id="2", server_id="server2") - result2, _ = await manager.prompt_pre_fetch(prompt2, global_context=global_context2) +# # Test with non-matching server_id +# prompt2 = PromptPrehookPayload(prompt_id="test", args={}) +# global_context2 = GlobalContext(request_id="2", server_id="server2") +# result2, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt2, global_context=global_context2) - # Plugin should be skipped - assert result2.continue_processing - assert result2.modified_payload is None # No modification +# # Plugin should be skipped +# assert result2.continue_processing +# assert result2.modified_payload is None # No modification - await manager.shutdown() +# await manager.shutdown() @pytest.mark.asyncio async def test_manager_metadata_aggregation(): """Test metadata aggregation from multiple plugins.""" - class MetadataPlugin1(Plugin): + class MetadataPlugin1(MCPPlugin): async def prompt_pre_fetch(self, payload, context): return PluginResult(continue_processing=True, metadata={"plugin1": "data1", "shared": "value1"}) - class MetadataPlugin2(Plugin): + class MetadataPlugin2(MCPPlugin): async def prompt_pre_fetch(self, payload, context): return PluginResult( continue_processing=True, @@ -251,14 +255,14 @@ async def prompt_pre_fetch(self, payload, context): plugin1 = MetadataPlugin1(config1) plugin2 = MetadataPlugin2(config2) - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - refs = [PluginRef(plugin1), PluginRef(plugin2)] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + refs = [HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(plugin1)), HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(plugin2))] mock_get.return_value = refs prompt = PromptPrehookPayload(prompt_id="test", args={}) global_context = GlobalContext(request_id="1") - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should aggregate metadata assert result.continue_processing @@ -273,7 +277,7 @@ async def prompt_pre_fetch(self, payload, context): async def test_manager_local_context_persistence(): """Test that local contexts persist across hook calls.""" - class StatefulPlugin(Plugin): + class StatefulPlugin(MCPPlugin): async def prompt_pre_fetch(self, payload, context: PluginContext): context.state["counter"] = context.state.get("counter", 0) + 1 return PluginResult(continue_processing=True) @@ -292,17 +296,25 @@ async def prompt_post_fetch(self, payload, context: PluginContext): ) plugin = StatefulPlugin(config) - with patch.object(manager._registry, "get_plugins_for_hook") as mock_pre, patch.object(manager._registry, "get_plugins_for_hook") as mock_post: - plugin_ref = PluginRef(plugin) + # Create a single PluginRef to ensure the same UUID is used for both hooks + plugin_ref = PluginRef(plugin) + hook_ref_pre = HookRef(HookType.PROMPT_PRE_FETCH, plugin_ref) + hook_ref_post = HookRef(HookType.PROMPT_POST_FETCH, plugin_ref) + + def get_hook_refs_side_effect(hook_type): + if hook_type == HookType.PROMPT_PRE_FETCH: + return [hook_ref_pre] + elif hook_type == HookType.PROMPT_POST_FETCH: + return [hook_ref_post] + return [] - mock_pre.return_value = [plugin_ref] - mock_post.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook", side_effect=get_hook_refs_side_effect): # First call to pre_fetch prompt = PromptPrehookPayload(prompt_id="test", args={}) global_context = GlobalContext(request_id="1") - result_pre, contexts = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result_pre, contexts = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) assert result_pre.continue_processing # Call to post_fetch with same contexts @@ -310,7 +322,7 @@ async def prompt_post_fetch(self, payload, context: PluginContext): prompt_result = PromptResult(messages=[message]) post_payload = PromptPosthookPayload(prompt_id="test", result=prompt_result) - result_post, _ = await manager.prompt_post_fetch(post_payload, global_context=global_context, local_contexts=contexts) + result_post, _ = await manager.invoke_hook(HookType.PROMPT_POST_FETCH, post_payload, global_context=global_context, local_contexts=contexts) # Should have modified with persisted state assert result_post.continue_processing @@ -324,7 +336,7 @@ async def prompt_post_fetch(self, payload, context: PluginContext): async def test_manager_plugin_blocking(): """Test plugin blocking behavior in enforce mode.""" - class BlockingPlugin(Plugin): + class BlockingPlugin(MCPPlugin): async def prompt_pre_fetch(self, payload, context): violation = PluginViolation(reason="Content violation", description="Blocked content detected", code="CONTENT_BLOCKED", details={"content": payload.args}) return PluginResult(continue_processing=False, violation=violation) @@ -337,14 +349,14 @@ async def prompt_pre_fetch(self, payload, context): ) plugin = BlockingPlugin(config) - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(plugin)) + mock_get.return_value = [hook_ref] prompt = PromptPrehookPayload(prompt_id="test", args={"text": "bad content"}) global_context = GlobalContext(request_id="1") - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should block the request assert not result.continue_processing @@ -353,7 +365,7 @@ async def prompt_pre_fetch(self, payload, context): assert result.violation.plugin_name == "BlockingPlugin" with pytest.raises(PluginViolationError) as pve: - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context, violations_as_exceptions=True) + result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context, violations_as_exceptions=True) assert pve.value.violation assert pve.value.message assert pve.value.violation.code == "CONTENT_BLOCKED" @@ -365,7 +377,7 @@ async def prompt_pre_fetch(self, payload, context): async def test_manager_plugin_permissive_blocking(): """Test plugin behavior when blocking in permissive mode.""" - class BlockingPlugin(Plugin): + class BlockingPlugin(MCPPlugin): async def prompt_pre_fetch(self, payload, context): violation = PluginViolation(reason="Would block", description="Content would be blocked", code="WOULD_BLOCK") return PluginResult(continue_processing=False, violation=violation) @@ -387,14 +399,14 @@ async def prompt_pre_fetch(self, payload, context): plugin = BlockingPlugin(config) # Test permissive mode blocking (covers lines 194-195) - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(plugin)) + mock_get.return_value = [hook_ref] prompt = PromptPrehookPayload(prompt_id="test", args={"text": "content"}) global_context = GlobalContext(request_id="1") - result, _ = await manager.prompt_pre_fetch(prompt, global_context=global_context) + result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should continue in permissive mode - the permissive logic continues without blocking assert result.continue_processing @@ -434,10 +446,10 @@ async def test_manager_payload_size_validation(): """Test payload size validation functionality.""" # First-Party from mcpgateway.plugins.framework.manager import MAX_PAYLOAD_SIZE, PayloadSizeError, PluginExecutor - from mcpgateway.plugins.framework.models import PromptPosthookPayload, PromptPrehookPayload + from mcpgateway.plugins.mcp.entities import PromptPosthookPayload, PromptPrehookPayload # Test payload size validation directly on executor (covers lines 252, 258) - executor = PluginExecutor[PromptPrehookPayload]() + executor = PluginExecutor() # Test large args payload (covers line 252) large_data = "x" * (MAX_PAYLOAD_SIZE + 1) @@ -457,7 +469,7 @@ async def test_manager_payload_size_validation(): large_post_payload = PromptPosthookPayload(prompt_id="test", result=large_result) # Should raise PayloadSizeError for large result - executor2 = PluginExecutor[PromptPosthookPayload]() + executor2 = PluginExecutor() with pytest.raises(PayloadSizeError, match="Result size .* exceeds limit"): executor2._validate_payload_size(large_post_payload) @@ -527,72 +539,20 @@ async def test_manager_initialization_edge_cases(): await manager2.shutdown() -@pytest.mark.asyncio -async def test_manager_context_cleanup(): - """Test context cleanup functionality.""" - # Standard - import time - - # First-Party - from mcpgateway.plugins.framework.manager import CONTEXT_MAX_AGE - - manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") - await manager.initialize() - - # Add some old contexts to the store - old_time = time.time() - CONTEXT_MAX_AGE - 1 # Older than max age - manager._context_store["old_request"] = ({}, old_time) - manager._context_store["new_request"] = ({}, time.time()) - - # Force cleanup by setting last cleanup time to 0 - manager._last_cleanup = 0 - - with patch("mcpgateway.plugins.framework.manager.logger") as mock_logger: - # Run cleanup (covers lines 551, 554) - await manager._cleanup_old_contexts() - - # Should have removed old context - assert "old_request" not in manager._context_store - assert "new_request" in manager._context_store - - # Should log cleanup message - mock_logger.info.assert_called_with("Cleaned up 1 expired plugin contexts") - - await manager.shutdown() - - -@pytest.mark.asyncio -async def test_manager_constructor_context_init(): - """Test manager constructor context initialization.""" - - # Test that managers share state and context store exists (covers lines 432-433) - manager1 = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") - manager2 = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") - - # Both managers should share the same state - assert hasattr(manager1, "_context_store") - assert hasattr(manager2, "_context_store") - assert hasattr(manager1, "_last_cleanup") - assert hasattr(manager2, "_last_cleanup") - - # They should be the same instance due to shared state - assert manager1._context_store is manager2._context_store - await manager1.shutdown() - await manager2.shutdown() - - @pytest.mark.asyncio async def test_base_plugin_coverage(): """Test base plugin functionality for complete coverage.""" # First-Party from mcpgateway.models import Message, PromptResult, Role, TextContent - from mcpgateway.plugins.framework.base import Plugin, PluginRef + from mcpgateway.plugins.framework.base import PluginRef from mcpgateway.plugins.framework.models import ( GlobalContext, - HookType, PluginConfig, PluginContext, PluginMode, + ) + from mcpgateway.plugins.mcp.entities import ( + HookType, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, @@ -611,7 +571,7 @@ async def test_base_plugin_coverage(): config={}, ) - plugin = Plugin(config) + plugin = MCPPlugin(config) # Test tags property assert plugin.tags == ["test", "coverage"] @@ -690,7 +650,8 @@ async def test_plugin_loader_return_none(): """Test plugin loader return None case.""" # First-Party from mcpgateway.plugins.framework.loader.plugin import PluginLoader - from mcpgateway.plugins.framework.models import HookType, PluginConfig + from mcpgateway.plugins.framework import PluginConfig + from mcpgateway.plugins.mcp.entities import HookType loader = PluginLoader() @@ -736,7 +697,7 @@ async def test_manager_compare_function_wrapper(): # The compare function is used internally in _run_plugins # Test by using plugins with conditions - class TestPlugin(Plugin): + class TestPlugin(MCPPlugin): async def tool_pre_invoke(self, payload, context): return PluginResult(continue_processing=True) @@ -753,20 +714,20 @@ async def tool_pre_invoke(self, payload, context): ) plugin = TestPlugin(config) - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(HookType.TOOL_PRE_INVOKE, PluginRef(plugin)) + mock_get.return_value = [hook_ref] # Test with matching tool tool_payload = ToolPreInvokePayload(name="calculator", args={}) global_context = GlobalContext(request_id="1") - result, _ = await manager.tool_pre_invoke(tool_payload, global_context=global_context) + result, _ = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) assert result.continue_processing # Test with non-matching tool tool_payload2 = ToolPreInvokePayload(name="other_tool", args={}) - result2, _ = await manager.tool_pre_invoke(tool_payload2, global_context=global_context) + result2, _ = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload2, global_context=global_context) assert result2.continue_processing await manager.shutdown() @@ -778,7 +739,7 @@ async def test_manager_tool_post_invoke_coverage(): manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") await manager.initialize() - class ModifyingPlugin(Plugin): + class ModifyingPlugin(MCPPlugin): async def tool_post_invoke(self, payload, context): payload.result["modified"] = True return PluginResult(continue_processing=True, modified_payload=payload) @@ -786,14 +747,14 @@ async def tool_post_invoke(self, payload, context): config = PluginConfig(name="ModifyingPlugin", description="Test modifying plugin", author="Test", version="1.0", tags=["test"], kind="ModifyingPlugin", hooks=["tool_post_invoke"], config={}) plugin = ModifyingPlugin(config) - with patch.object(manager._registry, "get_plugins_for_hook") as mock_get: - plugin_ref = PluginRef(plugin) - mock_get.return_value = [plugin_ref] + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(HookType.TOOL_POST_INVOKE, PluginRef(plugin)) + mock_get.return_value = [hook_ref] tool_payload = ToolPostInvokePayload(name="test_tool", result={"original": "data"}) global_context = GlobalContext(request_id="1") - result, _ = await manager.tool_post_invoke(tool_payload, global_context=global_context) + result, _ = await manager.invoke_hook(HookType.TOOL_POST_INVOKE, tool_payload, global_context=global_context) assert result.continue_processing assert result.modified_payload is not None diff --git a/tests/unit/mcpgateway/plugins/framework/test_registry.py b/tests/unit/mcpgateway/plugins/framework/test_registry.py index 7f62b694f..16daa86b1 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_registry.py +++ b/tests/unit/mcpgateway/plugins/framework/test_registry.py @@ -14,10 +14,10 @@ import pytest # First-Party -from mcpgateway.plugins.framework.base import Plugin from mcpgateway.plugins.framework.loader.config import ConfigLoader from mcpgateway.plugins.framework.loader.plugin import PluginLoader -from mcpgateway.plugins.framework.models import HookType, PluginConfig +from mcpgateway.plugins.framework import PluginConfig +from mcpgateway.plugins.mcp.entities import HookType, MCPPlugin from mcpgateway.plugins.framework.registry import PluginInstanceRegistry @@ -96,21 +96,21 @@ async def test_registry_priority_sorting(): ) # Create plugin instances - low_priority_plugin = Plugin(low_priority_config) - high_priority_plugin = Plugin(high_priority_config) + low_priority_plugin = MCPPlugin(low_priority_config) + high_priority_plugin = MCPPlugin(high_priority_config) # Register plugins in reverse priority order registry.register(low_priority_plugin) registry.register(high_priority_plugin) # Get plugins for hook - should be sorted by priority (lines 131-134) - hook_plugins = registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) + hook_plugins = registry.get_hook_refs_for_hook(HookType.PROMPT_PRE_FETCH) assert len(hook_plugins) == 2 - assert hook_plugins[0].name == "HighPriority" # Lower number = higher priority - assert hook_plugins[1].name == "LowPriority" + assert hook_plugins[0].plugin_ref.name == "HighPriority" # Lower number = higher priority + assert hook_plugins[1].plugin_ref.name == "LowPriority" # Test priority cache - calling again should use cached result - cached_plugins = registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) + cached_plugins = registry.get_hook_refs_for_hook(HookType.PROMPT_PRE_FETCH) assert cached_plugins == hook_plugins # Clean up @@ -133,22 +133,22 @@ async def test_registry_hook_filtering(): name="PostFetchPlugin", description="Post-fetch plugin", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[HookType.PROMPT_POST_FETCH], config={} ) - pre_fetch_plugin = Plugin(pre_fetch_config) - post_fetch_plugin = Plugin(post_fetch_config) + pre_fetch_plugin = MCPPlugin(pre_fetch_config) + post_fetch_plugin = MCPPlugin(post_fetch_config) registry.register(pre_fetch_plugin) registry.register(post_fetch_plugin) # Test hook filtering - pre_plugins = registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) - post_plugins = registry.get_plugins_for_hook(HookType.PROMPT_POST_FETCH) - tool_plugins = registry.get_plugins_for_hook(HookType.TOOL_PRE_INVOKE) + pre_plugins = registry.get_hook_refs_for_hook(HookType.PROMPT_PRE_FETCH) + post_plugins = registry.get_hook_refs_for_hook(HookType.PROMPT_POST_FETCH) + tool_plugins = registry.get_hook_refs_for_hook(HookType.TOOL_PRE_INVOKE) assert len(pre_plugins) == 1 - assert pre_plugins[0].name == "PreFetchPlugin" + assert pre_plugins[0].plugin_ref.name == "PreFetchPlugin" assert len(post_plugins) == 1 - assert post_plugins[0].name == "PostFetchPlugin" + assert post_plugins[0].plugin_ref.name == "PostFetchPlugin" assert len(tool_plugins) == 0 # No plugins for this hook @@ -163,9 +163,9 @@ async def test_registry_shutdown(): registry = PluginInstanceRegistry() # Create mock plugins with shutdown methods - mock_plugin1 = Plugin(PluginConfig(name="Plugin1", description="Test plugin 1", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[HookType.PROMPT_PRE_FETCH], config={})) + mock_plugin1 = MCPPlugin(PluginConfig(name="Plugin1", description="Test plugin 1", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[HookType.PROMPT_PRE_FETCH], config={})) - mock_plugin2 = Plugin(PluginConfig(name="Plugin2", description="Test plugin 2", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[HookType.PROMPT_POST_FETCH], config={})) + mock_plugin2 = MCPPlugin(PluginConfig(name="Plugin2", description="Test plugin 2", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[HookType.PROMPT_POST_FETCH], config={})) # Mock the shutdown methods mock_plugin1.shutdown = AsyncMock() @@ -196,7 +196,7 @@ async def test_registry_shutdown_with_error(): registry = PluginInstanceRegistry() # Create mock plugin that fails during shutdown - failing_plugin = Plugin( + failing_plugin = MCPPlugin( PluginConfig(name="FailingPlugin", description="Plugin that fails shutdown", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[HookType.PROMPT_PRE_FETCH], config={}) ) @@ -232,7 +232,7 @@ async def test_registry_edge_cases(): assert registry.plugin_count == 0 # Test getting hooks for empty registry - empty_hooks = registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) + empty_hooks = registry.get_hook_refs_for_hook(HookType.PROMPT_PRE_FETCH) assert len(empty_hooks) == 0 # Test get_all_plugins when empty @@ -246,13 +246,13 @@ async def test_registry_cache_invalidation(): plugin_config = PluginConfig(name="TestPlugin", description="Test plugin", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[HookType.PROMPT_PRE_FETCH], config={}) - plugin = Plugin(plugin_config) + plugin = MCPPlugin(plugin_config) # Register plugin registry.register(plugin) # Get plugins for hook (populates cache) - hooks1 = registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) + hooks1 = registry.get_hook_refs_for_hook(HookType.PROMPT_PRE_FETCH) assert len(hooks1) == 1 # Cache should be populated @@ -262,5 +262,5 @@ async def test_registry_cache_invalidation(): registry.unregister("TestPlugin") # Cache should be cleared for this hook type - hooks2 = registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) + hooks2 = registry.get_hook_refs_for_hook(HookType.PROMPT_PRE_FETCH) assert len(hooks2) == 0 diff --git a/tests/unit/mcpgateway/plugins/framework/test_resource_hooks.py b/tests/unit/mcpgateway/plugins/framework/test_resource_hooks.py index 1a3dbcb67..3d95e6e5e 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_resource_hooks.py +++ b/tests/unit/mcpgateway/plugins/framework/test_resource_hooks.py @@ -15,12 +15,11 @@ # First-Party from mcpgateway.models import ResourceContent -from mcpgateway.plugins.framework.base import Plugin, PluginRef +from mcpgateway.plugins.framework.base import PluginRef # Registry is imported for mocking from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginCondition, PluginConfig, PluginContext, @@ -28,6 +27,10 @@ PluginManager, PluginMode, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + HookType, + MCPPlugin, ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchPayload, @@ -64,7 +67,7 @@ async def test_plugin_resource_pre_fetch_default(self): hooks=[HookType.RESOURCE_PRE_FETCH], tags=["test"], ) - plugin = Plugin(config) + plugin = MCPPlugin(config) payload = ResourcePreFetchPayload(uri="file:///test.txt", metadata={}) context = PluginContext(global_context=GlobalContext(request_id="test-123")) @@ -83,7 +86,7 @@ async def test_plugin_resource_post_fetch_default(self): hooks=[HookType.RESOURCE_POST_FETCH], tags=["test"], ) - plugin = Plugin(config) + plugin = MCPPlugin(config) content = ResourceContent(type="resource", id="123",uri="file:///test.txt", text="Test content") payload = ResourcePostFetchPayload(uri="file:///test.txt", content=content) context = PluginContext(global_context=GlobalContext(request_id="test-123")) @@ -95,7 +98,7 @@ async def test_plugin_resource_post_fetch_default(self): async def test_resource_hook_blocking(self): """Test resource hook that blocks processing.""" - class BlockingResourcePlugin(Plugin): + class BlockingResourcePlugin(MCPPlugin): async def resource_pre_fetch(self, payload, context): return ResourcePreFetchResult( continue_processing=False, @@ -132,7 +135,7 @@ async def resource_pre_fetch(self, payload, context): async def test_resource_content_modification(self): """Test resource post-fetch content modification.""" - class ContentFilterPlugin(Plugin): + class ContentFilterPlugin(MCPPlugin): async def resource_post_fetch(self, payload, context): # Modify content to redact sensitive data modified_text = payload.content.text.replace("password: secret123", "password: [REDACTED]") @@ -181,7 +184,7 @@ async def resource_post_fetch(self, payload, context): async def test_resource_hook_with_conditions(self): """Test resource hooks with conditions.""" - class ConditionalResourcePlugin(Plugin): + class ConditionalResourcePlugin(MCPPlugin): async def resource_pre_fetch(self, payload, context): # Only process if conditions match return ResourcePreFetchResult( @@ -273,64 +276,58 @@ async def test_manager_resource_pre_fetch(self): payload = ResourcePreFetchPayload(uri="test://resource", metadata={}) global_context = GlobalContext(request_id="test-123") - result, contexts = await manager.resource_pre_fetch(payload, global_context) + result, contexts = await manager.invoke_hook(HookType.RESOURCE_PRE_FETCH, payload, global_context) assert result.continue_processing is True - MockRegistry.return_value.get_plugins_for_hook.assert_called_with(HookType.RESOURCE_PRE_FETCH) + MockRegistry.return_value.get_hook_refs_for_hook.assert_called_with(hook_type=HookType.RESOURCE_PRE_FETCH) @pytest.mark.asyncio async def test_manager_resource_post_fetch(self): """Test plugin manager resource_post_fetch execution.""" - with patch("mcpgateway.plugins.framework.manager.PluginInstanceRegistry") as MockRegistry: - with patch("mcpgateway.plugins.framework.loader.config.ConfigLoader.load_config") as MockConfig: - # Create a proper mock plugin with all required attributes - mock_plugin_obj = MagicMock() - mock_plugin_obj.name = "test_plugin" - mock_plugin_obj.priority = 50 - mock_plugin_obj.mode = PluginMode.ENFORCE - mock_plugin_obj.conditions = [] - mock_plugin_obj.resource_post_fetch = AsyncMock( - return_value=ResourcePostFetchResult( - continue_processing=True, - modified_payload=None, - ) - ) + # First-Party + from mcpgateway.plugins.framework.base import HookRef - # Create a PluginRef-like mock - mock_ref = MagicMock() - mock_ref._plugin = mock_plugin_obj - mock_ref.plugin = mock_plugin_obj - mock_ref.name = "test_plugin" - mock_ref.priority = 50 - mock_ref.mode = PluginMode.ENFORCE - mock_ref.conditions = [] - mock_ref.uuid = "test-uuid" + class TestResourcePlugin(MCPPlugin): + async def resource_post_fetch(self, payload, context): + return ResourcePostFetchResult( + continue_processing=True, + modified_payload=None, + ) - MockRegistry.return_value.get_plugins_for_hook.return_value = [mock_ref] + config = PluginConfig( + name="test_plugin", + description="Test resource plugin", + author="test", + kind="test.Plugin", + version="1.0.0", + hooks=[HookType.RESOURCE_POST_FETCH], + tags=["test"], + mode=PluginMode.ENFORCE, + ) + plugin = TestResourcePlugin(config) + plugin_ref = PluginRef(plugin) + hook_ref = HookRef(HookType.RESOURCE_POST_FETCH, plugin_ref) - # Mock config - mock_config = MagicMock() - mock_config.plugin_settings = MagicMock() - MockConfig.return_value = mock_config + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") + await manager.initialize() - manager = PluginManager("test_config.yaml") - manager._registry = MockRegistry.return_value - manager._initialized = True + with patch.object(manager._registry, "get_hook_refs_for_hook", return_value=[hook_ref]): + content = ResourceContent(type="resource", id="123", uri="test://resource", text="Test") + payload = ResourcePostFetchPayload(uri="test://resource", content=content) + global_context = GlobalContext(request_id="test-123") - content = ResourceContent(type="resource", id="123", uri="test://resource", text="Test") - payload = ResourcePostFetchPayload(uri="test://resource", content=content) - global_context = GlobalContext(request_id="test-123") + result, contexts = await manager.invoke_hook(HookType.RESOURCE_POST_FETCH, payload, global_context, {}) - result, contexts = await manager.resource_post_fetch(payload, global_context, {}) + assert result.continue_processing is True + manager._registry.get_hook_refs_for_hook.assert_called_with(hook_type=HookType.RESOURCE_POST_FETCH) - assert result.continue_processing is True - MockRegistry.return_value.get_plugins_for_hook.assert_called_with(HookType.RESOURCE_POST_FETCH) + await manager.shutdown() @pytest.mark.asyncio async def test_resource_hook_chain_execution(self): """Test multiple resource plugins executing in priority order.""" - class FirstPlugin(Plugin): + class FirstPlugin(MCPPlugin): async def resource_pre_fetch(self, payload, context): # Add metadata payload.metadata["first"] = True @@ -339,7 +336,7 @@ async def resource_pre_fetch(self, payload, context): modified_payload=payload, ) - class SecondPlugin(Plugin): + class SecondPlugin(MCPPlugin): async def resource_pre_fetch(self, payload, context): # Check first plugin ran assert payload.metadata.get("first") is True @@ -383,8 +380,10 @@ async def resource_pre_fetch(self, payload, context): @pytest.mark.asyncio async def test_resource_hook_error_handling(self): """Test resource hook error handling.""" + # First-Party + from mcpgateway.plugins.framework.base import HookRef - class ErrorPlugin(Plugin): + class ErrorPlugin(MCPPlugin): async def resource_pre_fetch(self, payload, context): raise ValueError("Test error in plugin") @@ -399,47 +398,33 @@ async def resource_pre_fetch(self, payload, context): mode=PluginMode.PERMISSIVE, # Should continue on error ) plugin = ErrorPlugin(config) + plugin_ref = PluginRef(plugin) + hook_ref = HookRef(HookType.RESOURCE_PRE_FETCH, plugin_ref) - with patch("mcpgateway.plugins.framework.manager.PluginInstanceRegistry") as MockRegistry: - with patch("mcpgateway.plugins.framework.loader.config.ConfigLoader.load_config") as MockConfig: - # Create a proper mock ref - mock_ref = MagicMock() - mock_ref._plugin = plugin - mock_ref.plugin = plugin - mock_ref.name = "error_plugin" - mock_ref.priority = 100 - mock_ref.mode = PluginMode.PERMISSIVE - mock_ref.conditions = [] - mock_ref.uuid = "test-uuid" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") + await manager.initialize() - MockRegistry.return_value.get_plugins_for_hook.return_value = [mock_ref] - - # Mock config - mock_config = MagicMock() - mock_config.plugin_settings = MagicMock() - mock_config.plugin_settings.fail_on_plugin_error = False - MockConfig.return_value = mock_config + payload = ResourcePreFetchPayload(uri="test://resource", metadata={}) + global_context = GlobalContext(request_id="test-123") - manager = PluginManager("test_config.yaml") - manager._registry = MockRegistry.return_value - manager._initialized = True + # Test with permissive mode - should handle error gracefully + with patch.object(manager._registry, "get_hook_refs_for_hook", return_value=[hook_ref]): + result, contexts = await manager.invoke_hook(HookType.RESOURCE_PRE_FETCH, payload, global_context) + assert result.continue_processing is True # Continues despite error - payload = ResourcePreFetchPayload(uri="test://resource", metadata={}) - global_context = GlobalContext(request_id="test-123") - # Should handle error gracefully when fail_on_plugin_error = False - result, contexts = await manager.resource_pre_fetch(payload, global_context) - assert result.continue_processing is True # Continues despite error + # Test with enforce mode - should raise PluginError + config.mode = PluginMode.ENFORCE + with patch.object(manager._registry, "get_hook_refs_for_hook", return_value=[hook_ref]): + with pytest.raises(PluginError): + result, contexts = await manager.invoke_hook(HookType.RESOURCE_PRE_FETCH, payload, global_context) - mock_config.plugin_settings.fail_on_plugin_error = True - # Should throw a plugin error since fail_on_plugin_error = True - with pytest.raises(PluginError): - result, contexts = await manager.resource_pre_fetch(payload, global_context) + await manager.shutdown() @pytest.mark.asyncio async def test_resource_uri_modification(self): """Test resource URI modification in pre-fetch.""" - class URIModifierPlugin(Plugin): + class URIModifierPlugin(MCPPlugin): async def resource_pre_fetch(self, payload, context): # Modify URI to add prefix modified_payload = ResourcePreFetchPayload( @@ -474,7 +459,7 @@ async def resource_pre_fetch(self, payload, context): async def test_resource_metadata_enrichment(self): """Test resource metadata enrichment in pre-fetch.""" - class MetadataEnricherPlugin(Plugin): + class MetadataEnricherPlugin(MCPPlugin): async def resource_pre_fetch(self, payload, context): # Add metadata payload.metadata["timestamp"] = "2024-01-01T00:00:00Z" diff --git a/tests/unit/mcpgateway/plugins/framework/test_utils.py b/tests/unit/mcpgateway/plugins/framework/test_utils.py index 82b303417..126824756 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_utils.py +++ b/tests/unit/mcpgateway/plugins/framework/test_utils.py @@ -11,50 +11,51 @@ import sys # First-Party -from mcpgateway.plugins.framework.models import GlobalContext, PluginCondition, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, ToolPreInvokePayload -from mcpgateway.plugins.framework.utils import import_module, matches, parse_class_name, post_prompt_matches, post_tool_matches, pre_prompt_matches, pre_tool_matches +from mcpgateway.plugins.framework import GlobalContext, PluginCondition +from mcpgateway.plugins.framework.utils import import_module, matches, parse_class_name #, post_prompt_matches, post_tool_matches, pre_prompt_matches, pre_tool_matches +#from mcpgateway.plugins.mcp.entities import PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, ToolPreInvokePayload -def test_server_ids(): - condition1 = PluginCondition(server_ids={"1", "2"}) - context1 = GlobalContext(server_id="1", tenant_id="4", request_id="5") +# def test_server_ids(): +# condition1 = PluginCondition(server_ids={"1", "2"}) +# context1 = GlobalContext(server_id="1", tenant_id="4", request_id="5") - payload1 = PromptPrehookPayload(prompt_id="test_prompt", args={}) +# payload1 = PromptPrehookPayload(prompt_id="test_prompt", args={}) - assert matches(condition=condition1, context=context1) - assert pre_prompt_matches(payload1, [condition1], context1) +# assert matches(condition=condition1, context=context1) +# assert pre_prompt_matches(payload1, [condition1], context1) - context2 = GlobalContext(server_id="3", tenant_id="6", request_id="1") - assert not matches(condition=condition1, context=context2) - assert not pre_prompt_matches(payload1, conditions=[condition1], context=context2) +# context2 = GlobalContext(server_id="3", tenant_id="6", request_id="1") +# assert not matches(condition=condition1, context=context2) +# assert not pre_prompt_matches(payload1, conditions=[condition1], context=context2) - condition2 = PluginCondition(server_ids={"1"}, tenant_ids={"4"}) +# condition2 = PluginCondition(server_ids={"1"}, tenant_ids={"4"}) - context2 = GlobalContext(server_id="1", tenant_id="4", request_id="1") +# context2 = GlobalContext(server_id="1", tenant_id="4", request_id="1") - assert matches(condition2, context2) - assert pre_prompt_matches(payload1, conditions=[condition2], context=context2) +# assert matches(condition2, context2) +# assert pre_prompt_matches(payload1, conditions=[condition2], context=context2) - context3 = GlobalContext(server_id="1", tenant_id="5", request_id="1") +# context3 = GlobalContext(server_id="1", tenant_id="5", request_id="1") - assert not matches(condition2, context3) - assert not pre_prompt_matches(payload1, conditions=[condition2], context=context3) +# assert not matches(condition2, context3) +# assert not pre_prompt_matches(payload1, conditions=[condition2], context=context3) - condition4 = PluginCondition(user_patterns=["blah", "barker", "bobby"]) - context4 = GlobalContext(user="blah", request_id="1") +# condition4 = PluginCondition(user_patterns=["blah", "barker", "bobby"]) +# context4 = GlobalContext(user="blah", request_id="1") - assert matches(condition4, context4) - assert pre_prompt_matches(payload1, conditions=[condition4], context=context4) +# assert matches(condition4, context4) +# assert pre_prompt_matches(payload1, conditions=[condition4], context=context4) - context5 = GlobalContext(user="barney", request_id="1") - assert not matches(condition4, context5) - assert not pre_prompt_matches(payload1, conditions=[condition4], context=context5) +# context5 = GlobalContext(user="barney", request_id="1") +# assert not matches(condition4, context5) +# assert not pre_prompt_matches(payload1, conditions=[condition4], context=context5) - condition5 = PluginCondition(server_ids={"1", "2"}, prompts={"test_prompt"}) +# condition5 = PluginCondition(server_ids={"1", "2"}, prompts={"test_prompt"}) - assert pre_prompt_matches(payload1, [condition5], context1) - condition6 = PluginCondition(server_ids={"1", "2"}, prompts={"test_prompt2"}) - assert not pre_prompt_matches(payload1, [condition6], context1) +# assert pre_prompt_matches(payload1, [condition5], context1) +# condition6 = PluginCondition(server_ids={"1", "2"}, prompts={"test_prompt2"}) +# assert not pre_prompt_matches(payload1, [condition6], context1) # ============================================================================ @@ -110,61 +111,61 @@ def test_parse_class_name(): # ============================================================================ -def test_post_prompt_matches(): - """Test the post_prompt_matches function.""" - # Import required models - # First-Party - from mcpgateway.models import Message, PromptResult, TextContent +# def test_post_prompt_matches(): +# """Test the post_prompt_matches function.""" +# # Import required models +# # First-Party +# from mcpgateway.models import Message, PromptResult, TextContent - # Test basic matching - msg = Message(role="assistant", content=TextContent(type="text", text="Hello")) - result = PromptResult(messages=[msg]) - payload = PromptPosthookPayload(prompt_id="greeting", result=result) - condition = PluginCondition(prompts={"greeting"}) - context = GlobalContext(request_id="req1") +# # Test basic matching +# msg = Message(role="assistant", content=TextContent(type="text", text="Hello")) +# result = PromptResult(messages=[msg]) +# payload = PromptPosthookPayload(prompt_id="greeting", result=result) +# condition = PluginCondition(prompts={"greeting"}) +# context = GlobalContext(request_id="req1") - assert post_prompt_matches(payload, [condition], context) is True +# assert post_prompt_matches(payload, [condition], context) is True - # Test no match - payload2 = PromptPosthookPayload(prompt_id ="other", result=result) - assert post_prompt_matches(payload2, [condition], context) is False +# # Test no match +# payload2 = PromptPosthookPayload(prompt_id ="other", result=result) +# assert post_prompt_matches(payload2, [condition], context) is False - # Test with server_id condition - condition_with_server = PluginCondition(server_ids={"srv1"}, prompts={"greeting"}) - context_with_server = GlobalContext(request_id="req1", server_id="srv1") +# # Test with server_id condition +# condition_with_server = PluginCondition(server_ids={"srv1"}, prompts={"greeting"}) +# context_with_server = GlobalContext(request_id="req1", server_id="srv1") - assert post_prompt_matches(payload, [condition_with_server], context_with_server) is True +# assert post_prompt_matches(payload, [condition_with_server], context_with_server) is True - # Test with mismatched server_id - context_wrong_server = GlobalContext(request_id="req1", server_id="srv2") - assert post_prompt_matches(payload, [condition_with_server], context_wrong_server) is False +# # Test with mismatched server_id +# context_wrong_server = GlobalContext(request_id="req1", server_id="srv2") +# assert post_prompt_matches(payload, [condition_with_server], context_wrong_server) is False -def test_post_prompt_matches_multiple_conditions(): - """Test post_prompt_matches with multiple conditions (OR logic).""" - # First-Party - from mcpgateway.models import Message, PromptResult, TextContent +# def test_post_prompt_matches_multiple_conditions(): +# """Test post_prompt_matches with multiple conditions (OR logic).""" +# # First-Party +# from mcpgateway.models import Message, PromptResult, TextContent - # Create the payload - msg = Message(role="assistant", content=TextContent(type="text", text="Hello")) - result = PromptResult(messages=[msg]) - payload = PromptPosthookPayload(prompt_id="greeting", result=result) +# # Create the payload +# msg = Message(role="assistant", content=TextContent(type="text", text="Hello")) +# result = PromptResult(messages=[msg]) +# payload = PromptPosthookPayload(prompt_id="greeting", result=result) - # First condition fails, second condition succeeds - condition1 = PluginCondition(server_ids={"srv1"}, prompts={"greeting"}) - condition2 = PluginCondition(server_ids={"srv2"}, prompts={"greeting"}) - context = GlobalContext(request_id="req1", server_id="srv2") +# # First condition fails, second condition succeeds +# condition1 = PluginCondition(server_ids={"srv1"}, prompts={"greeting"}) +# condition2 = PluginCondition(server_ids={"srv2"}, prompts={"greeting"}) +# context = GlobalContext(request_id="req1", server_id="srv2") - assert post_prompt_matches(payload, [condition1, condition2], context) is True +# assert post_prompt_matches(payload, [condition1, condition2], context) is True - # Both conditions fail - context_no_match = GlobalContext(request_id="req1", server_id="srv3") - assert post_prompt_matches(payload, [condition1, condition2], context_no_match) is False +# # Both conditions fail +# context_no_match = GlobalContext(request_id="req1", server_id="srv3") +# assert post_prompt_matches(payload, [condition1, condition2], context_no_match) is False - # Test reset logic between conditions - condition3 = PluginCondition(server_ids={"srv3"}, prompts={"other"}) - condition4 = PluginCondition(prompts={"greeting"}) - assert post_prompt_matches(payload, [condition3, condition4], context_no_match) is True +# # Test reset logic between conditions +# condition3 = PluginCondition(server_ids={"srv3"}, prompts={"other"}) +# condition4 = PluginCondition(prompts={"greeting"}) +# assert post_prompt_matches(payload, [condition3, condition4], context_no_match) is True # ============================================================================ @@ -172,49 +173,49 @@ def test_post_prompt_matches_multiple_conditions(): # ============================================================================ -def test_pre_tool_matches(): - """Test the pre_tool_matches function.""" - # Test basic matching - payload = ToolPreInvokePayload(name="calculator", args={"operation": "add"}) - condition = PluginCondition(tools={"calculator"}) - context = GlobalContext(request_id="req1") +# def test_pre_tool_matches(): +# """Test the pre_tool_matches function.""" +# # Test basic matching +# payload = ToolPreInvokePayload(name="calculator", args={"operation": "add"}) +# condition = PluginCondition(tools={"calculator"}) +# context = GlobalContext(request_id="req1") - assert pre_tool_matches(payload, [condition], context) is True +# assert pre_tool_matches(payload, [condition], context) is True - # Test no match - payload2 = ToolPreInvokePayload(name="other_tool", args={}) - assert pre_tool_matches(payload2, [condition], context) is False +# # Test no match +# payload2 = ToolPreInvokePayload(name="other_tool", args={}) +# assert pre_tool_matches(payload2, [condition], context) is False - # Test with server_id condition - condition_with_server = PluginCondition(server_ids={"srv1"}, tools={"calculator"}) - context_with_server = GlobalContext(request_id="req1", server_id="srv1") +# # Test with server_id condition +# condition_with_server = PluginCondition(server_ids={"srv1"}, tools={"calculator"}) +# context_with_server = GlobalContext(request_id="req1", server_id="srv1") - assert pre_tool_matches(payload, [condition_with_server], context_with_server) is True +# assert pre_tool_matches(payload, [condition_with_server], context_with_server) is True - # Test with mismatched server_id - context_wrong_server = GlobalContext(request_id="req1", server_id="srv2") - assert pre_tool_matches(payload, [condition_with_server], context_wrong_server) is False +# # Test with mismatched server_id +# context_wrong_server = GlobalContext(request_id="req1", server_id="srv2") +# assert pre_tool_matches(payload, [condition_with_server], context_wrong_server) is False -def test_pre_tool_matches_multiple_conditions(): - """Test pre_tool_matches with multiple conditions (OR logic).""" - payload = ToolPreInvokePayload(name="calculator", args={"operation": "add"}) +# def test_pre_tool_matches_multiple_conditions(): +# """Test pre_tool_matches with multiple conditions (OR logic).""" +# payload = ToolPreInvokePayload(name="calculator", args={"operation": "add"}) - # First condition fails, second condition succeeds - condition1 = PluginCondition(server_ids={"srv1"}, tools={"calculator"}) - condition2 = PluginCondition(server_ids={"srv2"}, tools={"calculator"}) - context = GlobalContext(request_id="req1", server_id="srv2") +# # First condition fails, second condition succeeds +# condition1 = PluginCondition(server_ids={"srv1"}, tools={"calculator"}) +# condition2 = PluginCondition(server_ids={"srv2"}, tools={"calculator"}) +# context = GlobalContext(request_id="req1", server_id="srv2") - assert pre_tool_matches(payload, [condition1, condition2], context) is True +# assert pre_tool_matches(payload, [condition1, condition2], context) is True - # Both conditions fail - context_no_match = GlobalContext(request_id="req1", server_id="srv3") - assert pre_tool_matches(payload, [condition1, condition2], context_no_match) is False +# # Both conditions fail +# context_no_match = GlobalContext(request_id="req1", server_id="srv3") +# assert pre_tool_matches(payload, [condition1, condition2], context_no_match) is False - # Test reset logic between conditions - condition3 = PluginCondition(server_ids={"srv3"}, tools={"other"}) - condition4 = PluginCondition(tools={"calculator"}) - assert pre_tool_matches(payload, [condition3, condition4], context_no_match) is True +# # Test reset logic between conditions +# condition3 = PluginCondition(server_ids={"srv3"}, tools={"other"}) +# condition4 = PluginCondition(tools={"calculator"}) +# assert pre_tool_matches(payload, [condition3, condition4], context_no_match) is True # ============================================================================ @@ -222,49 +223,49 @@ def test_pre_tool_matches_multiple_conditions(): # ============================================================================ -def test_post_tool_matches(): - """Test the post_tool_matches function.""" - # Test basic matching - payload = ToolPostInvokePayload(name="calculator", result={"value": 42}) - condition = PluginCondition(tools={"calculator"}) - context = GlobalContext(request_id="req1") +# def test_post_tool_matches(): +# """Test the post_tool_matches function.""" +# # Test basic matching +# payload = ToolPostInvokePayload(name="calculator", result={"value": 42}) +# condition = PluginCondition(tools={"calculator"}) +# context = GlobalContext(request_id="req1") - assert post_tool_matches(payload, [condition], context) is True +# assert post_tool_matches(payload, [condition], context) is True - # Test no match - payload2 = ToolPostInvokePayload(name="other_tool", result={}) - assert post_tool_matches(payload2, [condition], context) is False +# # Test no match +# payload2 = ToolPostInvokePayload(name="other_tool", result={}) +# assert post_tool_matches(payload2, [condition], context) is False - # Test with server_id condition - condition_with_server = PluginCondition(server_ids={"srv1"}, tools={"calculator"}) - context_with_server = GlobalContext(request_id="req1", server_id="srv1") +# # Test with server_id condition +# condition_with_server = PluginCondition(server_ids={"srv1"}, tools={"calculator"}) +# context_with_server = GlobalContext(request_id="req1", server_id="srv1") - assert post_tool_matches(payload, [condition_with_server], context_with_server) is True +# assert post_tool_matches(payload, [condition_with_server], context_with_server) is True - # Test with mismatched server_id - context_wrong_server = GlobalContext(request_id="req1", server_id="srv2") - assert post_tool_matches(payload, [condition_with_server], context_wrong_server) is False +# # Test with mismatched server_id +# context_wrong_server = GlobalContext(request_id="req1", server_id="srv2") +# assert post_tool_matches(payload, [condition_with_server], context_wrong_server) is False -def test_post_tool_matches_multiple_conditions(): - """Test post_tool_matches with multiple conditions (OR logic).""" - payload = ToolPostInvokePayload(name="calculator", result={"value": 42}) +# def test_post_tool_matches_multiple_conditions(): +# """Test post_tool_matches with multiple conditions (OR logic).""" +# payload = ToolPostInvokePayload(name="calculator", result={"value": 42}) - # First condition fails, second condition succeeds - condition1 = PluginCondition(server_ids={"srv1"}, tools={"calculator"}) - condition2 = PluginCondition(server_ids={"srv2"}, tools={"calculator"}) - context = GlobalContext(request_id="req1", server_id="srv2") +# # First condition fails, second condition succeeds +# condition1 = PluginCondition(server_ids={"srv1"}, tools={"calculator"}) +# condition2 = PluginCondition(server_ids={"srv2"}, tools={"calculator"}) +# context = GlobalContext(request_id="req1", server_id="srv2") - assert post_tool_matches(payload, [condition1, condition2], context) is True +# assert post_tool_matches(payload, [condition1, condition2], context) is True - # Both conditions fail - context_no_match = GlobalContext(request_id="req1", server_id="srv3") - assert post_tool_matches(payload, [condition1, condition2], context_no_match) is False +# # Both conditions fail +# context_no_match = GlobalContext(request_id="req1", server_id="srv3") +# assert post_tool_matches(payload, [condition1, condition2], context_no_match) is False - # Test reset logic between conditions - condition3 = PluginCondition(server_ids={"srv3"}, tools={"other"}) - condition4 = PluginCondition(tools={"calculator"}) - assert post_tool_matches(payload, [condition3, condition4], context_no_match) is True +# # Test reset logic between conditions +# condition3 = PluginCondition(server_ids={"srv3"}, tools={"other"}) +# condition4 = PluginCondition(tools={"calculator"}) +# assert post_tool_matches(payload, [condition3, condition4], context_no_match) is True # ============================================================================ @@ -272,25 +273,25 @@ def test_post_tool_matches_multiple_conditions(): # ============================================================================ -def test_pre_prompt_matches_multiple_conditions(): - """Test pre_prompt_matches with multiple conditions to cover OR logic paths.""" - payload = PromptPrehookPayload(prompt_id="greeting", args={}) +# def test_pre_prompt_matches_multiple_conditions(): +# """Test pre_prompt_matches with multiple conditions to cover OR logic paths.""" +# payload = PromptPrehookPayload(prompt_id="greeting", args={}) - # First condition fails, second condition succeeds - condition1 = PluginCondition(server_ids={"srv1"}, prompts={"greeting"}) - condition2 = PluginCondition(server_ids={"srv2"}, prompts={"greeting"}) - context = GlobalContext(request_id="req1", server_id="srv2") +# # First condition fails, second condition succeeds +# condition1 = PluginCondition(server_ids={"srv1"}, prompts={"greeting"}) +# condition2 = PluginCondition(server_ids={"srv2"}, prompts={"greeting"}) +# context = GlobalContext(request_id="req1", server_id="srv2") - assert pre_prompt_matches(payload, [condition1, condition2], context) is True +# assert pre_prompt_matches(payload, [condition1, condition2], context) is True - # Both conditions fail - context_no_match = GlobalContext(request_id="req1", server_id="srv3") - assert pre_prompt_matches(payload, [condition1, condition2], context_no_match) is False +# # Both conditions fail +# context_no_match = GlobalContext(request_id="req1", server_id="srv3") +# assert pre_prompt_matches(payload, [condition1, condition2], context_no_match) is False - # Test reset logic between conditions (line 140) - condition3 = PluginCondition(server_ids={"srv3"}, prompts={"other"}) - condition4 = PluginCondition(prompts={"greeting"}) - assert pre_prompt_matches(payload, [condition3, condition4], context_no_match) is True +# # Test reset logic between conditions (line 140) +# condition3 = PluginCondition(server_ids={"srv3"}, prompts={"other"}) +# condition4 = PluginCondition(prompts={"greeting"}) +# assert pre_prompt_matches(payload, [condition3, condition4], context_no_match) is True # ============================================================================ diff --git a/tests/unit/mcpgateway/plugins/plugins/altk_json_processor/test_json_processor.py b/tests/unit/mcpgateway/plugins/plugins/altk_json_processor/test_json_processor.py index 8b1f0be30..7fb6fa5a3 100644 --- a/tests/unit/mcpgateway/plugins/plugins/altk_json_processor/test_json_processor.py +++ b/tests/unit/mcpgateway/plugins/plugins/altk_json_processor/test_json_processor.py @@ -14,11 +14,13 @@ import pytest # First-Party -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + HookType, ToolPostInvokePayload, ) diff --git a/tests/unit/mcpgateway/plugins/plugins/argument_normalizer/test_argument_normalizer.py b/tests/unit/mcpgateway/plugins/plugins/argument_normalizer/test_argument_normalizer.py index 8368fb5dd..022ad5dff 100644 --- a/tests/unit/mcpgateway/plugins/plugins/argument_normalizer/test_argument_normalizer.py +++ b/tests/unit/mcpgateway/plugins/plugins/argument_normalizer/test_argument_normalizer.py @@ -11,11 +11,13 @@ import pytest # First-Party -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + HookType, PromptPrehookPayload, ToolPreInvokePayload, ) diff --git a/tests/unit/mcpgateway/plugins/plugins/cached_tool_result/test_cached_tool_result.py b/tests/unit/mcpgateway/plugins/plugins/cached_tool_result/test_cached_tool_result.py index 10f2f16f7..631e3c8f2 100644 --- a/tests/unit/mcpgateway/plugins/plugins/cached_tool_result/test_cached_tool_result.py +++ b/tests/unit/mcpgateway/plugins/plugins/cached_tool_result/test_cached_tool_result.py @@ -11,9 +11,12 @@ from mcpgateway.plugins.framework.models import ( GlobalContext, - HookType, PluginConfig, PluginContext, +) + +from mcpgateway.plugins.mcp.entities import ( + HookType, ToolPreInvokePayload, ToolPostInvokePayload, ) diff --git a/tests/unit/mcpgateway/plugins/plugins/code_safety_linter/test_code_safety_linter.py b/tests/unit/mcpgateway/plugins/plugins/code_safety_linter/test_code_safety_linter.py index 1de4ff24a..be3577281 100644 --- a/tests/unit/mcpgateway/plugins/plugins/code_safety_linter/test_code_safety_linter.py +++ b/tests/unit/mcpgateway/plugins/plugins/code_safety_linter/test_code_safety_linter.py @@ -11,9 +11,11 @@ from mcpgateway.plugins.framework.models import ( GlobalContext, - HookType, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + HookType, ToolPostInvokePayload, ) from plugins.code_safety_linter.code_safety_linter import CodeSafetyLinterPlugin diff --git a/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation.py b/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation.py index e7ec89ada..70b1b58a5 100644 --- a/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation.py +++ b/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation.py @@ -11,12 +11,14 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + HookType, PromptPrehookPayload, ToolPreInvokePayload, ToolPostInvokePayload, diff --git a/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation_integration.py b/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation_integration.py index b443876bc..489fca952 100644 --- a/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation_integration.py +++ b/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation_integration.py @@ -13,11 +13,12 @@ import pytest from mcpgateway.plugins.framework.manager import PluginManager -from mcpgateway.plugins.framework.models import ( - GlobalContext, +from mcpgateway.plugins.framework import GlobalContext + +from mcpgateway.plugins.mcp.entities import ( + HookType, PromptPrehookPayload, ToolPreInvokePayload, - ToolPostInvokePayload, ) @@ -111,7 +112,7 @@ async def test_content_moderation_with_manager(): args={"query": "What is the weather like today?"} ) - result, final_context = await manager.prompt_pre_fetch(payload, context) + result, final_context = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, context) # Verify result assert result.continue_processing is True @@ -194,7 +195,7 @@ async def test_content_moderation_blocking_harmful_content(): args={"query": "I hate all those people and want them gone"} ) - result, final_context = await manager.prompt_pre_fetch(payload, context) + result, final_context = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, context) # Should be blocked due to high hate score assert result.continue_processing is False @@ -270,7 +271,7 @@ async def test_content_moderation_with_granite_fallback(): args={"query": "How to resolve conflicts peacefully"} ) - result, final_context = await manager.tool_pre_invoke(payload, context) + result, final_context = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, payload, context) # Should continue processing (fallback succeeded) assert result.continue_processing is True @@ -351,7 +352,7 @@ async def test_content_moderation_redaction(): args={"query": "This damn thing is not working"} ) - result, final_context = await manager.prompt_pre_fetch(payload, context) + result, final_context = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, context) # Should continue processing but with modified content assert result.continue_processing is True @@ -442,7 +443,7 @@ async def test_content_moderation_multiple_providers(): args={"query": "What is machine learning?"} ) - prompt_result, _ = await manager.prompt_pre_fetch(prompt_payload, context) + prompt_result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt_payload, context) assert prompt_result.continue_processing is True # Test tool (goes to Granite) @@ -451,7 +452,7 @@ async def test_content_moderation_multiple_providers(): args={"query": "How to build AI models"} ) - tool_result, _ = await manager.tool_pre_invoke(tool_payload, context) + tool_result, _ = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, context) assert tool_result.continue_processing is True # Verify both providers were called diff --git a/tests/unit/mcpgateway/plugins/plugins/external_clamav/test_clamav_remote.py b/tests/unit/mcpgateway/plugins/plugins/external_clamav/test_clamav_remote.py index f19dfe214..a3f8c571e 100644 --- a/tests/unit/mcpgateway/plugins/plugins/external_clamav/test_clamav_remote.py +++ b/tests/unit/mcpgateway/plugins/plugins/external_clamav/test_clamav_remote.py @@ -9,11 +9,13 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + HookType, ResourcePostFetchPayload, ResourcePreFetchPayload, ) @@ -77,7 +79,7 @@ async def test_non_blocking_mode_reports_metadata(tmp_path): @pytest.mark.asyncio async def test_prompt_post_fetch_blocks_on_eicar_text(): plugin = _mk_plugin(True) - from mcpgateway.plugins.framework.models import PromptPosthookPayload + from mcpgateway.plugins.mcp.entities import PromptPosthookPayload pr = __import__("mcpgateway.models").models.PromptResult( messages=[ @@ -97,7 +99,7 @@ async def test_prompt_post_fetch_blocks_on_eicar_text(): @pytest.mark.asyncio async def test_tool_post_invoke_blocks_on_eicar_string(): plugin = _mk_plugin(True) - from mcpgateway.plugins.framework.models import ToolPostInvokePayload + from mcpgateway.plugins.mcp.entities import ToolPostInvokePayload ctx = PluginContext(global_context=GlobalContext(request_id="r5")) payload = ToolPostInvokePayload(name="t", result={"text": EICAR}) @@ -118,7 +120,7 @@ async def test_health_stats_counters(): await plugin.resource_post_fetch(payload_r, ctx) # 2) prompt_post_fetch with EICAR -> attempted +1, infected +1 (total attempted=2, infected=2) - from mcpgateway.plugins.framework.models import PromptPosthookPayload + from mcpgateway.plugins.mcp.entities import PromptPosthookPayload pr = __import__("mcpgateway.models").models.PromptResult( messages=[ @@ -132,7 +134,7 @@ async def test_health_stats_counters(): await plugin.prompt_post_fetch(payload_p, ctx) # 3) tool_post_invoke with one EICAR and one clean string -> attempted +2, infected +1 - from mcpgateway.plugins.framework.models import ToolPostInvokePayload + from mcpgateway.plugins.mcp.entities import ToolPostInvokePayload payload_t = ToolPostInvokePayload(name="t", result={"a": EICAR, "b": "clean"}) await plugin.tool_post_invoke(payload_t, ctx) diff --git a/tests/unit/mcpgateway/plugins/plugins/file_type_allowlist/test_file_type_allowlist.py b/tests/unit/mcpgateway/plugins/plugins/file_type_allowlist/test_file_type_allowlist.py index e58430b9b..348af6781 100644 --- a/tests/unit/mcpgateway/plugins/plugins/file_type_allowlist/test_file_type_allowlist.py +++ b/tests/unit/mcpgateway/plugins/plugins/file_type_allowlist/test_file_type_allowlist.py @@ -11,9 +11,12 @@ from mcpgateway.plugins.framework.models import ( GlobalContext, - HookType, PluginConfig, PluginContext, +) + +from mcpgateway.plugins.mcp.entities import ( + HookType, ResourcePreFetchPayload, ResourcePostFetchPayload, ) diff --git a/tests/unit/mcpgateway/plugins/plugins/html_to_markdown/test_html_to_markdown.py b/tests/unit/mcpgateway/plugins/plugins/html_to_markdown/test_html_to_markdown.py index a25d54fd8..e830ccbbe 100644 --- a/tests/unit/mcpgateway/plugins/plugins/html_to_markdown/test_html_to_markdown.py +++ b/tests/unit/mcpgateway/plugins/plugins/html_to_markdown/test_html_to_markdown.py @@ -11,9 +11,11 @@ from mcpgateway.plugins.framework.models import ( GlobalContext, - HookType, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + HookType, ResourcePostFetchPayload, ) from mcpgateway.models import ResourceContent diff --git a/tests/unit/mcpgateway/plugins/plugins/json_repair/test_json_repair.py b/tests/unit/mcpgateway/plugins/plugins/json_repair/test_json_repair.py index d6ca40917..2be4c4213 100644 --- a/tests/unit/mcpgateway/plugins/plugins/json_repair/test_json_repair.py +++ b/tests/unit/mcpgateway/plugins/plugins/json_repair/test_json_repair.py @@ -12,9 +12,12 @@ from mcpgateway.plugins.framework.models import ( GlobalContext, - HookType, PluginConfig, PluginContext, +) + +from mcpgateway.plugins.mcp.entities import ( + HookType, ToolPostInvokePayload, ) from plugins.json_repair.json_repair import JSONRepairPlugin diff --git a/tests/unit/mcpgateway/plugins/plugins/markdown_cleaner/test_markdown_cleaner.py b/tests/unit/mcpgateway/plugins/plugins/markdown_cleaner/test_markdown_cleaner.py index e2b4c0df1..bb75e68d7 100644 --- a/tests/unit/mcpgateway/plugins/plugins/markdown_cleaner/test_markdown_cleaner.py +++ b/tests/unit/mcpgateway/plugins/plugins/markdown_cleaner/test_markdown_cleaner.py @@ -12,9 +12,11 @@ from mcpgateway.models import Message, PromptResult, TextContent from mcpgateway.plugins.framework.models import ( GlobalContext, - HookType, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + HookType, PromptPosthookPayload, ) from plugins.markdown_cleaner.markdown_cleaner import MarkdownCleanerPlugin diff --git a/tests/unit/mcpgateway/plugins/plugins/output_length_guard/test_output_length_guard.py b/tests/unit/mcpgateway/plugins/plugins/output_length_guard/test_output_length_guard.py index 621d98cc9..884da9828 100644 --- a/tests/unit/mcpgateway/plugins/plugins/output_length_guard/test_output_length_guard.py +++ b/tests/unit/mcpgateway/plugins/plugins/output_length_guard/test_output_length_guard.py @@ -10,9 +10,12 @@ # First-Party from mcpgateway.plugins.framework.models import ( GlobalContext, - HookType, PluginConfig, PluginContext, +) + +from mcpgateway.plugins.mcp.entities import ( + HookType, ToolPostInvokePayload, ) diff --git a/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py b/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py index 23440ea33..3cde9b347 100644 --- a/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py +++ b/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py @@ -12,12 +12,14 @@ # First-Party from mcpgateway.models import Message, PromptResult, Role, TextContent -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, PluginMode, +) +from mcpgateway.plugins.mcp.entities import ( + HookType, PromptPosthookPayload, PromptPrehookPayload, ) @@ -414,7 +416,7 @@ async def test_integration_with_manager(): payload = PromptPrehookPayload(prompt_id="test_prompt", args={"input": "Email: test@example.com, SSN: 123-45-6789"}) global_context = GlobalContext(request_id="test-manager") - result, contexts = await manager.prompt_pre_fetch(payload, global_context) + result, contexts = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, global_context) # Verify PII was masked assert result.modified_payload is not None diff --git a/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py b/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py index 4e1bad235..0f152bb6a 100644 --- a/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py +++ b/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py @@ -11,9 +11,11 @@ from mcpgateway.plugins.framework.models import ( GlobalContext, - HookType, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + HookType, PromptPrehookPayload, ) from plugins.rate_limiter.rate_limiter import RateLimiterPlugin diff --git a/tests/unit/mcpgateway/plugins/plugins/resource_filter/test_resource_filter.py b/tests/unit/mcpgateway/plugins/plugins/resource_filter/test_resource_filter.py index 08f12cf72..e8745c96c 100644 --- a/tests/unit/mcpgateway/plugins/plugins/resource_filter/test_resource_filter.py +++ b/tests/unit/mcpgateway/plugins/plugins/resource_filter/test_resource_filter.py @@ -14,10 +14,12 @@ from mcpgateway.models import ResourceContent from mcpgateway.plugins.framework.models import ( GlobalContext, - HookType, PluginConfig, PluginContext, PluginMode, +) +from mcpgateway.plugins.mcp.entities import ( + HookType, ResourcePostFetchPayload, ResourcePreFetchPayload, ) diff --git a/tests/unit/mcpgateway/plugins/plugins/schema_guard/test_schema_guard.py b/tests/unit/mcpgateway/plugins/plugins/schema_guard/test_schema_guard.py index 1f04cc08a..18c818e2b 100644 --- a/tests/unit/mcpgateway/plugins/plugins/schema_guard/test_schema_guard.py +++ b/tests/unit/mcpgateway/plugins/plugins/schema_guard/test_schema_guard.py @@ -11,9 +11,11 @@ from mcpgateway.plugins.framework.models import ( GlobalContext, - HookType, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + HookType, ToolPreInvokePayload, ToolPostInvokePayload, ) diff --git a/tests/unit/mcpgateway/plugins/plugins/url_reputation/test_url_reputation.py b/tests/unit/mcpgateway/plugins/plugins/url_reputation/test_url_reputation.py index 649efe5e6..be9768faf 100644 --- a/tests/unit/mcpgateway/plugins/plugins/url_reputation/test_url_reputation.py +++ b/tests/unit/mcpgateway/plugins/plugins/url_reputation/test_url_reputation.py @@ -9,11 +9,13 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - HookType, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + HookType, ResourcePreFetchPayload, ) from plugins.url_reputation.url_reputation import URLReputationPlugin diff --git a/tests/unit/mcpgateway/plugins/plugins/virus_total_checker/test_virus_total_checker.py b/tests/unit/mcpgateway/plugins/plugins/virus_total_checker/test_virus_total_checker.py index 01eddc28a..b0e942085 100644 --- a/tests/unit/mcpgateway/plugins/plugins/virus_total_checker/test_virus_total_checker.py +++ b/tests/unit/mcpgateway/plugins/plugins/virus_total_checker/test_virus_total_checker.py @@ -15,9 +15,11 @@ from mcpgateway.plugins.framework.models import ( GlobalContext, - HookType, PluginConfig, PluginContext, +) +from mcpgateway.plugins.mcp.entities import ( + HookType, ResourcePreFetchPayload, ) @@ -144,7 +146,7 @@ async def test_local_allow_and_deny_overrides(): plugin = VirusTotalURLCheckerPlugin(cfg) plugin._client_factory = lambda c, h: _StubClient(routes) # type: ignore os.environ["VT_API_KEY"] = "dummy" - from mcpgateway.plugins.framework.models import ToolPostInvokePayload + from mcpgateway.plugins.mcp.entities import ToolPostInvokePayload payload = ToolPostInvokePayload(name="writer", result=f"See {url}") ctx = PluginContext(global_context=GlobalContext(request_id="r7")) res = await plugin.tool_post_invoke(payload, ctx) @@ -190,7 +192,7 @@ async def test_override_precedence_allow_over_deny_vs_deny_over_allow(): plugin_allow = VirusTotalURLCheckerPlugin(cfg_allow) plugin_allow._client_factory = lambda c, h: _StubClient({}) # type: ignore os.environ["VT_API_KEY"] = "dummy" - from mcpgateway.plugins.framework.models import ToolPostInvokePayload + from mcpgateway.plugins.mcp.entities import ToolPostInvokePayload payload = ToolPostInvokePayload(name="writer", result=f"visit {url}") ctx = PluginContext(global_context=GlobalContext(request_id="r8")) res_allow = await plugin_allow.tool_post_invoke(payload, ctx) @@ -249,7 +251,7 @@ async def test_prompt_scan_blocks_on_url(): os.environ["VT_API_KEY"] = "dummy" pr = PromptResult(messages=[Message(role="assistant", content=TextContent(type="text", text=f"see {url}"))]) - from mcpgateway.plugins.framework.models import PromptPosthookPayload + from mcpgateway.plugins.mcp.entities import PromptPosthookPayload payload = PromptPosthookPayload(prompt_id="p", result=pr) ctx = PluginContext(global_context=GlobalContext(request_id="r5")) res = await plugin.prompt_post_fetch(payload, ctx) @@ -291,7 +293,7 @@ async def test_resource_scan_blocks_on_url(): from mcpgateway.models import ResourceContent rc = ResourceContent(type="resource", id="345",uri="test://x", mime_type="text/plain", text=f"{url} is fishy") - from mcpgateway.plugins.framework.models import ResourcePostFetchPayload + from mcpgateway.plugins.mcp.entities import ResourcePostFetchPayload payload = ResourcePostFetchPayload(uri="test://x", content=rc) ctx = PluginContext(global_context=GlobalContext(request_id="r6")) res = await plugin.resource_post_fetch(payload, ctx) @@ -433,7 +435,7 @@ async def test_tool_output_url_block_and_ratio(): plugin._client_factory = lambda c, h: _StubClient(routes) # type: ignore os.environ["VT_API_KEY"] = "dummy" - from mcpgateway.plugins.framework.models import ToolPostInvokePayload + from mcpgateway.plugins.mcp.entities import ToolPostInvokePayload payload = ToolPostInvokePayload(name="writer", result=f"See {url} for details") ctx = PluginContext(global_context=GlobalContext(request_id="r4")) diff --git a/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_integration.py b/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_integration.py index 6307f651a..9eae48c7f 100644 --- a/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_integration.py +++ b/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_integration.py @@ -14,11 +14,10 @@ import pytest from mcpgateway.plugins.framework.manager import PluginManager -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, - ToolPostInvokePayload, - PluginViolation, ) +from mcpgateway.plugins.mcp.entities import HookType, ToolPostInvokePayload @pytest.mark.asyncio @@ -81,7 +80,7 @@ async def test_webhook_plugin_with_manager(): ) # Execute tool post-invoke hook - result, final_context = await manager.tool_post_invoke(payload, context) + result, final_context = await manager.invoke_hook(HookType.TOOL_POST_INVOKE, payload, context) # Verify result assert result.continue_processing is True @@ -164,14 +163,14 @@ async def test_webhook_plugin_violation_handling(): context = GlobalContext(request_id="violation-test", user="testuser") # Create payload with forbidden word that will trigger deny filter - from mcpgateway.plugins.framework.models import PromptPrehookPayload + from mcpgateway.plugins.mcp.entities import PromptPrehookPayload payload = PromptPrehookPayload( prompt_id="test_prompt", args={"query": "this contains forbidden word"} ) # Execute - should be blocked by deny filter - result, final_context = await manager.prompt_pre_fetch(payload, context) + result, final_context = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, context) # Verify the request was blocked assert result.continue_processing is False @@ -248,7 +247,7 @@ async def test_webhook_plugin_multiple_webhooks(): ) # Execute hook - result, final_context = await manager.tool_post_invoke(payload, context) + result, final_context = await manager.invoke_hook(HookType.TOOL_POST_INVOKE, payload, context) assert result.continue_processing is True @@ -341,7 +340,7 @@ async def test_webhook_plugin_template_customization(): result={"data": "test"} ) - await manager.tool_post_invoke(payload, context) + await manager.invoke_hook(HookType.TOOL_POST_INVOKE, payload, context) # Verify webhook was called with custom template mock_client.post.assert_called_once() diff --git a/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_notification.py b/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_notification.py index 23319275a..6aceeb285 100644 --- a/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_notification.py +++ b/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_notification.py @@ -13,10 +13,12 @@ from mcpgateway.plugins.framework.models import ( GlobalContext, - HookType, PluginConfig, PluginContext, PluginViolation, +) +from mcpgateway.plugins.mcp.entities import ( + HookType, PromptPrehookPayload, ToolPostInvokePayload, ToolPreInvokePayload, @@ -463,7 +465,7 @@ async def test_prompt_pre_and_post_hooks_return_success(self): # Test post-hook with mock notification plugin._notify_webhooks = AsyncMock() - from mcpgateway.plugins.framework.models import PromptPosthookPayload, PromptResult + from mcpgateway.plugins.mcp.entities import PromptPosthookPayload, PromptResult post_payload = PromptPosthookPayload( prompt_id="test_prompt", result=PromptResult(messages=[]) diff --git a/tests/unit/mcpgateway/services/test_resource_service_plugins.py b/tests/unit/mcpgateway/services/test_resource_service_plugins.py index 05c966816..f7b9d0e68 100644 --- a/tests/unit/mcpgateway/services/test_resource_service_plugins.py +++ b/tests/unit/mcpgateway/services/test_resource_service_plugins.py @@ -39,11 +39,21 @@ def resource_service(self): @pytest.fixture def resource_service_with_plugins(self): """Create a ResourceService instance with plugins enabled.""" + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + with patch.dict(os.environ, {"PLUGINS_ENABLED": "true", "PLUGIN_CONFIG_FILE": "test_config.yaml"}): with patch("mcpgateway.services.resource_service.PluginManager") as MockPluginManager: mock_manager = MagicMock() mock_manager._initialized = False mock_manager.initialize = AsyncMock() + # Add default invoke_hook mock that returns success + mock_manager.invoke_hook = AsyncMock( + return_value=( + PluginResult(continue_processing=True, modified_payload=None), + None # contexts + ) + ) MockPluginManager.return_value = mock_manager service = ResourceService() service._plugin_manager = mock_manager @@ -70,6 +80,9 @@ async def test_read_resource_without_plugins(self, resource_service, mock_db): @pytest.mark.asyncio async def test_read_resource_with_pre_fetch_hook(self, resource_service_with_plugins, mock_db): """Test read_resource with pre-fetch hook execution.""" + # First-Party + from mcpgateway.plugins.mcp.entities import HookType + import mcpgateway.services.resource_service as resource_service_mod resource_service_mod.PLUGINS_AVAILABLE = True service = resource_service_with_plugins @@ -87,33 +100,6 @@ async def test_read_resource_with_pre_fetch_hook(self, resource_service_with_plu mock_db.execute.return_value.scalar_one_or_none.return_value = mock_resource mock_db.get.return_value = mock_resource # Ensure resource_db is not None - # Setup pre-fetch hook response - mock_manager.resource_pre_fetch = AsyncMock( - return_value=( - MagicMock( - continue_processing=True, - modified_payload=None, - violation=None, - ), - {"context": "data"}, # contexts - ) - ) - - # Setup post-fetch hook response - mock_manager.resource_post_fetch = AsyncMock( - return_value=( - MagicMock( - continue_processing=True, - modified_payload=None, - ), - None, - ) - ) - - # Explicitly call initialize if not already called - if hasattr(mock_manager.initialize, 'await_count') and mock_manager.initialize.await_count == 0: - await mock_manager.initialize() - result = await service.read_resource( mock_db, "test://resource", @@ -123,14 +109,14 @@ async def test_read_resource_with_pre_fetch_hook(self, resource_service_with_plu # Verify hooks were called mock_manager.initialize.assert_called() - mock_manager.resource_pre_fetch.assert_called_once() - mock_manager.resource_post_fetch.assert_called_once() + assert mock_manager.invoke_hook.call_count >= 2 # Pre and post fetch - # Verify context was passed correctly - call_args = mock_manager.resource_pre_fetch.call_args - assert call_args[0][0].uri == "test://resource" # payload - assert call_args[0][1].request_id == "test-123" # global_context - assert call_args[0][1].user == "testuser" + # Verify context was passed correctly - check first call (pre-fetch) + first_call = mock_manager.invoke_hook.call_args_list[0] + assert first_call[0][0] == HookType.RESOURCE_PRE_FETCH # hook_type + assert first_call[0][1].uri == "test://resource" # payload + assert first_call[0][2].request_id == "test-123" # global_context + assert first_call[0][2].user == "testuser" @pytest.mark.asyncio async def test_read_resource_blocked_by_plugin(self, resource_service_with_plugins, mock_db): @@ -152,8 +138,8 @@ async def test_read_resource_blocked_by_plugin(self, resource_service_with_plugi mock_db.execute.return_value.scalar_one_or_none.return_value = mock_resource mock_db.get.return_value = mock_resource # Ensure resource_db is not None - # Setup pre-fetch hook to block - mock_manager.resource_pre_fetch = AsyncMock( + # Setup invoke_hook to raise PluginViolationError + mock_manager.invoke_hook = AsyncMock( side_effect=PluginViolationError(message="Protocol not allowed", violation=PluginViolation( reason="Protocol not allowed", @@ -168,13 +154,15 @@ async def test_read_resource_blocked_by_plugin(self, resource_service_with_plugi await service.read_resource(mock_db, "file:///etc/passwd") assert "Protocol not allowed" in str(exc_info.value) - mock_manager.resource_pre_fetch.assert_called_once() - # Post-fetch should not be called if pre-fetch blocks - mock_manager.resource_post_fetch.assert_not_called() + mock_manager.invoke_hook.assert_called() @pytest.mark.asyncio async def test_read_resource_uri_modified_by_plugin(self, resource_service_with_plugins, mock_db): """Test read_resource with URI modification by plugin.""" + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + from mcpgateway.plugins.mcp.entities import HookType + service = resource_service_with_plugins mock_manager = service._plugin_manager @@ -193,26 +181,27 @@ async def test_read_resource_uri_modified_by_plugin(self, resource_service_with_ # Setup pre-fetch hook to modify URI modified_payload = MagicMock() modified_payload.uri = "cached://test://resource" - mock_manager.resource_pre_fetch = AsyncMock( - return_value=( - MagicMock( - continue_processing=True, - modified_payload=modified_payload, - ), - {"context": "data"}, - ) - ) - # Setup post-fetch hook - mock_manager.resource_post_fetch = AsyncMock( - return_value=( - MagicMock( + # Use side_effect to return different results based on hook type + def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): + if hook_type == HookType.RESOURCE_PRE_FETCH: + return ( + PluginResult( + continue_processing=True, + modified_payload=modified_payload, + ), + {"context": "data"}, + ) + # POST_FETCH + return ( + PluginResult( continue_processing=True, modified_payload=None, ), None, ) - ) + + mock_manager.invoke_hook = AsyncMock(side_effect=invoke_hook_side_effect) result = await service.read_resource(mock_db, "test://resource") @@ -223,6 +212,10 @@ async def test_read_resource_uri_modified_by_plugin(self, resource_service_with_ @pytest.mark.asyncio async def test_read_resource_content_filtered_by_plugin(self, resource_service_with_plugins, mock_db): """Test read_resource with content filtering by post-fetch hook.""" + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + from mcpgateway.plugins.mcp.entities import HookType + import mcpgateway.services.resource_service as resource_service_mod resource_service_mod.PLUGINS_AVAILABLE = True service = resource_service_with_plugins @@ -244,14 +237,6 @@ def scalar_one_or_none_side_effect(*args, **kwargs): mock_db.execute.return_value.scalar_one_or_none.side_effect = scalar_one_or_none_side_effect mock_db.get.return_value = mock_resource - # Setup pre-fetch hook - mock_manager.resource_pre_fetch = AsyncMock( - return_value=( - MagicMock(continue_processing=True), - {"context": "data"}, - ) - ) - # Setup post-fetch hook to filter content filtered_content = ResourceContent( type="resource", @@ -260,17 +245,26 @@ def scalar_one_or_none_side_effect(*args, **kwargs): text="password: [REDACTED]\napi_key: [REDACTED]", ) resource_id = filtered_content.id - modified_payload = MagicMock() - modified_payload.content = filtered_content - mock_manager.resource_post_fetch = AsyncMock( - return_value=( - MagicMock( + modified_post_payload = MagicMock() + modified_post_payload.content = filtered_content + + # Use side_effect to return different results based on hook type + def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): + if hook_type == HookType.RESOURCE_PRE_FETCH: + return ( + PluginResult(continue_processing=True), + {"context": "data"}, + ) + # POST_FETCH + return ( + PluginResult( continue_processing=True, - modified_payload=modified_payload, + modified_payload=modified_post_payload, ), None, ) - ) + + mock_manager.invoke_hook = AsyncMock(side_effect=invoke_hook_side_effect) result = await service.read_resource(mock_db, resource_id) @@ -303,17 +297,21 @@ async def test_read_resource_plugin_error_handling(self, resource_service_with_p mock_db.get.return_value = mock_resource # Ensure resource_db is not None # Setup pre-fetch hook to raise an error - mock_manager.resource_pre_fetch = AsyncMock(side_effect=PluginError(error=PluginErrorModel(message="Plugin error", plugin_name="mock_plugin"))) + mock_manager.invoke_hook = AsyncMock(side_effect=PluginError(error=PluginErrorModel(message="Plugin error", plugin_name="mock_plugin"))) with pytest.raises(PluginError) as exc_info: await service.read_resource(mock_db, resource_id) - mock_manager.resource_pre_fetch.assert_called_once() + mock_manager.invoke_hook.assert_called_once() @pytest.mark.asyncio async def test_read_resource_post_fetch_blocking(self, resource_service_with_plugins, mock_db): """Test read_resource blocked by post-fetch hook.""" + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + from mcpgateway.plugins.mcp.entities import HookType + import mcpgateway.services.resource_service as resource_service_mod resource_service_mod.PLUGINS_AVAILABLE = True service = resource_service_with_plugins @@ -331,30 +329,32 @@ async def test_read_resource_post_fetch_blocking(self, resource_service_with_plu mock_db.execute.return_value.scalar_one_or_none.return_value = mock_resource mock_db.get.return_value = mock_resource # Ensure resource_db is not None - # Setup pre-fetch hook - mock_manager.resource_pre_fetch = AsyncMock( - return_value=( - MagicMock(continue_processing=True), - {"context": "data"}, + # Use side_effect to allow pre-fetch but block on post-fetch + def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): + if hook_type == HookType.RESOURCE_PRE_FETCH: + return ( + PluginResult(continue_processing=True), + {"context": "data"}, + ) + # POST_FETCH - raise error + raise PluginViolationError( + message="Content contains sensitive data", + violation=PluginViolation( + reason="Content contains sensitive data", + description="The resource content was flagged as containing sensitive information", + code="SENSITIVE_CONTENT", + details={"uri": "test://resource"} + ) ) - ) - # Setup post-fetch hook to block - mock_manager.resource_post_fetch = AsyncMock( - side_effect=PluginViolationError(message="Content contains sensitive data", - violation=PluginViolation( - reason="Content contains sensitive data", - description="The resource content was flagged as containing sensitive information", - code="SENSITIVE_CONTENT", - details={"uri": "test://resource"} - )) - ) + mock_manager.invoke_hook = AsyncMock(side_effect=invoke_hook_side_effect) with pytest.raises(PluginViolationError) as exc_info: await service.read_resource(mock_db, "test://resource") assert "Content contains sensitive data" in str(exc_info.value) - mock_manager.resource_post_fetch.assert_called_once() + # Verify invoke_hook was called at least twice (pre and post) + assert mock_manager.invoke_hook.call_count == 2 @pytest.mark.asyncio async def test_read_resource_with_template(self, resource_service_with_plugins, mock_db): @@ -377,32 +377,23 @@ async def test_read_resource_with_template(self, resource_service_with_plugins, mock_db.execute.return_value.scalar_one_or_none.return_value = mock_resource mock_db.get.return_value = mock_resource # Ensure resource_db is not None - # Setup hooks - mock_manager.resource_pre_fetch = AsyncMock( - return_value=( - MagicMock(continue_processing=True), - {"context": "data"}, - ) - ) - # Create a mock result with modified_payload explicitly set to None - mock_post_result = MagicMock() - mock_post_result.continue_processing = True - mock_post_result.modified_payload = None - - mock_manager.resource_post_fetch = AsyncMock( - return_value=(mock_post_result, None) - ) + # The default invoke_hook from fixture will work fine for this test + # since it just returns success with no modifications # Use the correct resource id for lookup result = await service.read_resource(mock_db, mock_resource.uri) assert result == mock_template_content - mock_manager.resource_pre_fetch.assert_called_once() - mock_manager.resource_post_fetch.assert_called_once() + # Verify hooks were called + assert mock_manager.invoke_hook.call_count >= 2 # Pre and post fetch @pytest.mark.asyncio async def test_read_resource_context_propagation(self, resource_service_with_plugins, mock_db): """Test context propagation from pre-fetch to post-fetch.""" + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + from mcpgateway.plugins.mcp.entities import HookType + import mcpgateway.services.resource_service as resource_service_mod resource_service_mod.PLUGINS_AVAILABLE = True service = resource_service_with_plugins @@ -422,28 +413,31 @@ async def test_read_resource_context_propagation(self, resource_service_with_plu # Capture contexts from pre-fetch test_contexts = {"plugin1": {"validated": True}} - mock_manager.resource_pre_fetch = AsyncMock( - return_value=( - MagicMock(continue_processing=True), - test_contexts, - ) - ) - # Verify contexts passed to post-fetch - mock_manager.resource_post_fetch = AsyncMock( - return_value=( - MagicMock(continue_processing=True), + # Use side_effect to return contexts from pre-fetch + def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): + if hook_type == HookType.RESOURCE_PRE_FETCH: + return ( + PluginResult(continue_processing=True), + test_contexts, + ) + # POST_FETCH + return ( + PluginResult(continue_processing=True), None, ) - ) + + mock_manager.invoke_hook = AsyncMock(side_effect=invoke_hook_side_effect) # The resource id must match the lookup for plugin logic to trigger await service.read_resource(mock_db, mock_resource.content.id) # Verify contexts were passed from pre to post - post_call_args = mock_manager.resource_post_fetch.call_args - assert post_call_args is not None, "resource_post_fetch was not called" - assert post_call_args[0][2] == test_contexts # Third argument is contexts + assert mock_manager.invoke_hook.call_count == 2 + # Check second call (post-fetch) to verify contexts were passed + post_call_args = mock_manager.invoke_hook.call_args_list[1] + # The contexts dict should be passed as the 4th positional arg (local_contexts) + assert post_call_args[0][3] == test_contexts # Fourth argument is local_contexts @pytest.mark.asyncio async def test_read_resource_inactive_resource(self, resource_service, mock_db): @@ -496,19 +490,13 @@ async def test_read_resource_no_request_id(self, resource_service_with_plugins, mock_db.execute.return_value.scalar_one_or_none.return_value = mock_resource mock_db.get.return_value = mock_resource # Ensure resource_db is not None - # Setup hooks - mock_manager.resource_pre_fetch = AsyncMock( - return_value=(MagicMock(continue_processing=True), None) - ) - mock_manager.resource_post_fetch = AsyncMock( - return_value=(MagicMock(continue_processing=True), None) - ) + # The default invoke_hook from fixture will work fine await service.read_resource(mock_db, "test://resource") - # Verify request_id was generated - call_args = mock_manager.resource_pre_fetch.call_args - assert call_args is not None, "resource_pre_fetch was not called" - global_context = call_args[0][1] + # Verify request_id was generated - check first call (pre-fetch) + assert mock_manager.invoke_hook.call_count >= 1, "invoke_hook was not called" + first_call = mock_manager.invoke_hook.call_args_list[0] + global_context = first_call[0][2] # Third positional arg is global_context assert global_context.request_id is not None assert len(global_context.request_id) > 0 diff --git a/tests/unit/mcpgateway/services/test_tool_service.py b/tests/unit/mcpgateway/services/test_tool_service.py index c4f46825d..2504f7984 100644 --- a/tests/unit/mcpgateway/services/test_tool_service.py +++ b/tests/unit/mcpgateway/services/test_tool_service.py @@ -2231,6 +2231,10 @@ def mock_passthrough(req_headers, tool_headers, db_session, gateway=None): async def test_invoke_tool_with_plugin_post_invoke_success(self, tool_service, mock_tool, test_db): """Test invoking tool with successful plugin post-invoke hook.""" + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + from mcpgateway.plugins.mcp.entities import HookType + # Configure tool as REST mock_tool.integration_type = "REST" mock_tool.request_type = "POST" @@ -2248,15 +2252,21 @@ async def test_invoke_tool_with_plugin_post_invoke_success(self, tool_service, m mock_response.json = Mock(return_value={"result": "original response"}) tool_service._http_client.request.return_value = mock_response - # Mock plugin manager and post-invoke hook + # Mock plugin manager with invoke_hook mock_post_result = Mock() mock_post_result.continue_processing = True mock_post_result.violation = None mock_post_result.modified_payload = None tool_service._plugin_manager = Mock() - tool_service._plugin_manager.tool_pre_invoke = AsyncMock(return_value=(Mock(continue_processing=True, violation=None, modified_payload=None), None)) - tool_service._plugin_manager.tool_post_invoke = AsyncMock(return_value=(mock_post_result, None)) + + def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): + if hook_type == HookType.TOOL_PRE_INVOKE: + return (PluginResult(continue_processing=True, violation=None, modified_payload=None), None) + # POST_INVOKE + return (mock_post_result, None) + + tool_service._plugin_manager.invoke_hook = AsyncMock(side_effect=invoke_hook_side_effect) with ( patch("mcpgateway.services.tool_service.decode_auth", return_value={}), @@ -2264,8 +2274,8 @@ async def test_invoke_tool_with_plugin_post_invoke_success(self, tool_service, m ): result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) - # Verify plugin post-invoke was called - tool_service._plugin_manager.tool_post_invoke.assert_called_once() + # Verify plugin hooks were called + assert tool_service._plugin_manager.invoke_hook.call_count == 2 # Pre and post invoke # Verify result assert result.content[0].text == '{\n "result": "original response"\n}' @@ -2298,9 +2308,19 @@ async def test_invoke_tool_with_plugin_post_invoke_modified_payload(self, tool_s mock_post_result.violation = None mock_post_result.modified_payload = mock_modified_payload + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + from mcpgateway.plugins.mcp.entities import HookType + tool_service._plugin_manager = Mock() - tool_service._plugin_manager.tool_pre_invoke = AsyncMock(return_value=(Mock(continue_processing=True, violation=None, modified_payload=None), None)) - tool_service._plugin_manager.tool_post_invoke = AsyncMock(return_value=(mock_post_result, None)) + + def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): + if hook_type == HookType.TOOL_PRE_INVOKE: + return (PluginResult(continue_processing=True, violation=None, modified_payload=None), None) + # POST_INVOKE + return (mock_post_result, None) + + tool_service._plugin_manager.invoke_hook = AsyncMock(side_effect=invoke_hook_side_effect) with ( patch("mcpgateway.services.tool_service.decode_auth", return_value={}), @@ -2308,8 +2328,8 @@ async def test_invoke_tool_with_plugin_post_invoke_modified_payload(self, tool_s ): result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) - # Verify plugin post-invoke was called - tool_service._plugin_manager.tool_post_invoke.assert_called_once() + # Verify plugin hooks were called + assert tool_service._plugin_manager.invoke_hook.call_count == 2 # Pre and post invoke # Verify result was modified by plugin assert result.content[0].text == "Modified by plugin" @@ -2342,9 +2362,19 @@ async def test_invoke_tool_with_plugin_post_invoke_invalid_modified_payload(self mock_post_result.violation = None mock_post_result.modified_payload = mock_modified_payload + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + from mcpgateway.plugins.mcp.entities import HookType + tool_service._plugin_manager = Mock() - tool_service._plugin_manager.tool_pre_invoke = AsyncMock(return_value=(Mock(continue_processing=True, violation=None, modified_payload=None), None)) - tool_service._plugin_manager.tool_post_invoke = AsyncMock(return_value=(mock_post_result, None)) + + def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): + if hook_type == HookType.TOOL_PRE_INVOKE: + return (PluginResult(continue_processing=True, violation=None, modified_payload=None), None) + # POST_INVOKE + return (mock_post_result, None) + + tool_service._plugin_manager.invoke_hook = AsyncMock(side_effect=invoke_hook_side_effect) with ( patch("mcpgateway.services.tool_service.decode_auth", return_value={}), @@ -2352,8 +2382,8 @@ async def test_invoke_tool_with_plugin_post_invoke_invalid_modified_payload(self ): result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) - # Verify plugin post-invoke was called - tool_service._plugin_manager.tool_post_invoke.assert_called_once() + # Verify plugin hooks were called + assert tool_service._plugin_manager.invoke_hook.call_count == 2 # Pre and post invoke # Verify result was converted to string since format was invalid assert result.content[0].text == "Invalid format - not a dict" @@ -2377,10 +2407,20 @@ async def test_invoke_tool_with_plugin_post_invoke_error_fail_on_error(self, too mock_response.json = Mock(return_value={"result": "original response"}) tool_service._http_client.request.return_value = mock_response - # Mock plugin manager and post-invoke hook with error + # Mock plugin manager with invoke_hook that raises error on POST_INVOKE + # First-Party + from mcpgateway.plugins.framework.models import PluginResult + from mcpgateway.plugins.mcp.entities import HookType + tool_service._plugin_manager = Mock() - tool_service._plugin_manager.tool_pre_invoke = AsyncMock(return_value=(Mock(continue_processing=True, violation=None, modified_payload=None), None)) - tool_service._plugin_manager.tool_post_invoke = AsyncMock(side_effect=Exception("Plugin error")) + + def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): + if hook_type == HookType.TOOL_PRE_INVOKE: + return (PluginResult(continue_processing=True, violation=None, modified_payload=None), None) + # POST_INVOKE - raise error + raise Exception("Plugin error") + + tool_service._plugin_manager.invoke_hook = AsyncMock(side_effect=invoke_hook_side_effect) # Mock plugin config to fail on errors mock_plugin_settings = Mock() From 61a51323eaba300366a159f6def827a060bc6d2b Mon Sep 17 00:00:00 2001 From: Frederico Araujo Date: Thu, 30 Oct 2025 09:23:36 -0400 Subject: [PATCH 02/20] fix: pylint issues Signed-off-by: Frederico Araujo --- mcpgateway/plugins/framework/base.py | 4 ++-- .../plugins/framework/external/mcp/client.py | 5 ++++- .../framework/external/mcp/server/runtime.py | 20 +++++++++---------- mcpgateway/plugins/framework/manager.py | 3 +-- mcpgateway/plugins/framework/registry.py | 2 +- mcpgateway/plugins/mcp/entities/base.py | 2 +- 6 files changed, 19 insertions(+), 17 deletions(-) diff --git a/mcpgateway/plugins/framework/base.py b/mcpgateway/plugins/framework/base.py index a91739a44..3919d5758 100644 --- a/mcpgateway/plugins/framework/base.py +++ b/mcpgateway/plugins/framework/base.py @@ -188,7 +188,7 @@ def json_to_payload(self, hook: str, payload: Union[str | dict]) -> PluginPayloa # Fall back to global registry if not hook_payload_type: # First-Party - from mcpgateway.plugins.framework.hook_registry import get_hook_registry + from mcpgateway.plugins.framework.hook_registry import get_hook_registry # pylint: disable=import-outside-toplevel registry = get_hook_registry() hook_payload_type = registry.get_payload_type(hook) @@ -223,7 +223,7 @@ def json_to_result(self, hook: str, result: Union[str | dict]) -> PluginResult: # Fall back to global registry if not hook_result_type: # First-Party - from mcpgateway.plugins.framework.hook_registry import get_hook_registry + from mcpgateway.plugins.framework.hook_registry import get_hook_registry # pylint: disable=import-outside-toplevel registry = get_hook_registry() hook_result_type = registry.get_result_type(hook) diff --git a/mcpgateway/plugins/framework/external/mcp/client.py b/mcpgateway/plugins/framework/external/mcp/client.py index fcfb5e807..fc5905c14 100644 --- a/mcpgateway/plugins/framework/external/mcp/client.py +++ b/mcpgateway/plugins/framework/external/mcp/client.py @@ -316,9 +316,12 @@ async def shutdown(self) -> None: class ExternalHookRef(HookRef): """A Hook reference point for external plugins.""" - def __init__(self, hook: str, plugin_ref: PluginRef): + def __init__(self, hook: str, plugin_ref: PluginRef): # pylint: disable=super-init-not-called """Initialize a hook reference point for an external plugin. + Note: We intentionally don't call super().__init__() because external plugins + use invoke_hook() rather than direct method attributes. + Args: hook: name of the hook point. plugin_ref: The reference to the plugin to hook. diff --git a/mcpgateway/plugins/framework/external/mcp/server/runtime.py b/mcpgateway/plugins/framework/external/mcp/server/runtime.py index 5091fc517..5cb2241b8 100755 --- a/mcpgateway/plugins/framework/external/mcp/server/runtime.py +++ b/mcpgateway/plugins/framework/external/mcp/server/runtime.py @@ -157,12 +157,12 @@ async def _start_health_check_server(self, health_port: int) -> None: health_port: Port number for the health check server. """ # Third-Party - from starlette.applications import Starlette - from starlette.requests import Request - from starlette.responses import JSONResponse - from starlette.routing import Route + from starlette.applications import Starlette # pylint: disable=import-outside-toplevel + from starlette.requests import Request # pylint: disable=import-outside-toplevel + from starlette.responses import JSONResponse # pylint: disable=import-outside-toplevel + from starlette.routing import Route # pylint: disable=import-outside-toplevel - async def health_check(request: Request): + async def health_check(_request: Request): """Health check endpoint for container orchestration. Args: @@ -192,11 +192,11 @@ async def run_streamable_http_async(self) -> None: # Add health check endpoint to main app # Third-Party - from starlette.requests import Request - from starlette.responses import JSONResponse - from starlette.routing import Route + from starlette.requests import Request # pylint: disable=import-outside-toplevel + from starlette.responses import JSONResponse # pylint: disable=import-outside-toplevel + from starlette.routing import Route # pylint: disable=import-outside-toplevel - async def health_check(request: Request): + async def health_check(_request: Request): """Health check endpoint for container orchestration. Args: @@ -254,7 +254,7 @@ async def run(): Raises: Exception: If plugin server initialization or execution fails. """ - global SERVER + global SERVER # pylint: disable=global-statement # Initialize plugin server SERVER = ExternalPluginServer() diff --git a/mcpgateway/plugins/framework/manager.py b/mcpgateway/plugins/framework/manager.py index 8ef940717..9c312e782 100644 --- a/mcpgateway/plugins/framework/manager.py +++ b/mcpgateway/plugins/framework/manager.py @@ -612,8 +612,7 @@ async def invoke_hook_for_plugin( if isinstance(payload, (str, dict)): pydantic_payload = plugin.json_to_payload(hook_type, payload) return await self._executor.execute_plugin(hook_ref, pydantic_payload, context, violations_as_exceptions) - else: - raise ValueError(f"When payload_as_json=True, payload must be str or dict, got {type(payload)}") + raise ValueError(f"When payload_as_json=True, payload must be str or dict, got {type(payload)}") # When payload_as_json=False, payload should already be a PluginPayload if not isinstance(payload, PluginPayload): raise ValueError(f"When payload_as_json=False, payload must be a PluginPayload, got {type(payload)}") diff --git a/mcpgateway/plugins/framework/registry.py b/mcpgateway/plugins/framework/registry.py index 0268b4c0f..a6e0d59e3 100644 --- a/mcpgateway/plugins/framework/registry.py +++ b/mcpgateway/plugins/framework/registry.py @@ -98,7 +98,7 @@ def register(self, plugin: Plugin) -> None: self._priority_cache.pop(hook_type, None) self._hooks_by_name[plugin.name] = plugin_hooks - logger.info(f"Registered plugin: {plugin.name} with hooks: {[h for h in plugin.hooks]}") + logger.info(f"Registered plugin: {plugin.name} with hooks: {list(plugin.hooks)}") def unregister(self, plugin_name: str) -> None: """Unregister a plugin given its name. diff --git a/mcpgateway/plugins/mcp/entities/base.py b/mcpgateway/plugins/mcp/entities/base.py index 463d63202..ae17704a6 100644 --- a/mcpgateway/plugins/mcp/entities/base.py +++ b/mcpgateway/plugins/mcp/entities/base.py @@ -45,7 +45,7 @@ def _register_mcp_hooks(): """ # Import here to avoid circular dependency at module load time # First-Party - from mcpgateway.plugins.framework.hook_registry import get_hook_registry + from mcpgateway.plugins.framework.hook_registry import get_hook_registry # pylint: disable=import-outside-toplevel registry = get_hook_registry() From 9c6b8fc41dbd5407f3348034e5f4e12d486eeb05 Mon Sep 17 00:00:00 2001 From: Frederico Araujo Date: Thu, 30 Oct 2025 09:39:32 -0400 Subject: [PATCH 03/20] chore: uv lock Signed-off-by: Frederico Araujo --- uv.lock | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/uv.lock b/uv.lock index dad5fa09e..49a18e346 100644 --- a/uv.lock +++ b/uv.lock @@ -4649,8 +4649,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/89/3fdb5902bdab8868bbedc1c6e6023a4e08112ceac5db97fc2012060e0c9a/psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2e164359396576a3cc701ba8af4751ae68a07235d7a380c631184a611220d9a4", size = 4410955, upload-time = "2025-10-10T11:11:21.21Z" }, { url = "https://files.pythonhosted.org/packages/ce/24/e18339c407a13c72b336e0d9013fbbbde77b6fd13e853979019a1269519c/psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:d57c9c387660b8893093459738b6abddbb30a7eab058b77b0d0d1c7d521ddfd7", size = 4468007, upload-time = "2025-10-10T11:11:24.831Z" }, { url = "https://files.pythonhosted.org/packages/91/7e/b8441e831a0f16c159b5381698f9f7f7ed54b77d57bc9c5f99144cc78232/psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2c226ef95eb2250974bf6fa7a842082b31f68385c4f3268370e3f3870e7859ee", size = 4165012, upload-time = "2025-10-10T11:11:29.51Z" }, + { url = "https://files.pythonhosted.org/packages/0d/61/4aa89eeb6d751f05178a13da95516c036e27468c5d4d2509bb1e15341c81/psycopg2_binary-2.9.11-cp311-cp311-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a311f1edc9967723d3511ea7d2708e2c3592e3405677bf53d5c7246753591fbb", size = 3981881, upload-time = "2025-10-30T02:55:07.332Z" }, { url = "https://files.pythonhosted.org/packages/76/a1/2f5841cae4c635a9459fe7aca8ed771336e9383b6429e05c01267b0774cf/psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ebb415404821b6d1c47353ebe9c8645967a5235e6d88f914147e7fd411419e6f", size = 3650985, upload-time = "2025-10-10T11:11:34.975Z" }, { url = "https://files.pythonhosted.org/packages/84/74/4defcac9d002bca5709951b975173c8c2fa968e1a95dc713f61b3a8d3b6a/psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f07c9c4a5093258a03b28fab9b4f151aa376989e7f35f855088234e656ee6a94", size = 3296039, upload-time = "2025-10-10T11:11:40.432Z" }, + { url = "https://files.pythonhosted.org/packages/6d/c2/782a3c64403d8ce35b5c50e1b684412cf94f171dc18111be8c976abd2de1/psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:00ce1830d971f43b667abe4a56e42c1e2d594b32da4802e44a73bacacb25535f", size = 3043477, upload-time = "2025-10-30T02:55:11.182Z" }, { url = "https://files.pythonhosted.org/packages/c8/31/36a1d8e702aa35c38fc117c2b8be3f182613faa25d794b8aeaab948d4c03/psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:cffe9d7697ae7456649617e8bb8d7a45afb71cd13f7ab22af3e5c61f04840908", size = 3345842, upload-time = "2025-10-10T11:11:45.366Z" }, { url = "https://files.pythonhosted.org/packages/6e/b4/a5375cda5b54cb95ee9b836930fea30ae5a8f14aa97da7821722323d979b/psycopg2_binary-2.9.11-cp311-cp311-win_amd64.whl", hash = "sha256:304fd7b7f97eef30e91b8f7e720b3db75fee010b520e434ea35ed1ff22501d03", size = 2713894, upload-time = "2025-10-10T11:11:48.775Z" }, { url = "https://files.pythonhosted.org/packages/d8/91/f870a02f51be4a65987b45a7de4c2e1897dd0d01051e2b559a38fa634e3e/psycopg2_binary-2.9.11-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:be9b840ac0525a283a96b556616f5b4820e0526addb8dcf6525a0fa162730be4", size = 3756603, upload-time = "2025-10-10T11:11:52.213Z" }, @@ -4658,8 +4660,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2d/75/364847b879eb630b3ac8293798e380e441a957c53657995053c5ec39a316/psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ab8905b5dcb05bf3fb22e0cf90e10f469563486ffb6a96569e51f897c750a76a", size = 4411159, upload-time = "2025-10-10T11:12:00.49Z" }, { url = "https://files.pythonhosted.org/packages/6f/a0/567f7ea38b6e1c62aafd58375665a547c00c608a471620c0edc364733e13/psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:bf940cd7e7fec19181fdbc29d76911741153d51cab52e5c21165f3262125685e", size = 4468234, upload-time = "2025-10-10T11:12:04.892Z" }, { url = "https://files.pythonhosted.org/packages/30/da/4e42788fb811bbbfd7b7f045570c062f49e350e1d1f3df056c3fb5763353/psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fa0f693d3c68ae925966f0b14b8edda71696608039f4ed61b1fe9ffa468d16db", size = 4166236, upload-time = "2025-10-10T11:12:11.674Z" }, + { url = "https://files.pythonhosted.org/packages/3c/94/c1777c355bc560992af848d98216148be5f1be001af06e06fc49cbded578/psycopg2_binary-2.9.11-cp312-cp312-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a1cf393f1cdaf6a9b57c0a719a1068ba1069f022a59b8b1fe44b006745b59757", size = 3983083, upload-time = "2025-10-30T02:55:15.73Z" }, { url = "https://files.pythonhosted.org/packages/bd/42/c9a21edf0e3daa7825ed04a4a8588686c6c14904344344a039556d78aa58/psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ef7a6beb4beaa62f88592ccc65df20328029d721db309cb3250b0aae0fa146c3", size = 3652281, upload-time = "2025-10-10T11:12:17.713Z" }, { url = "https://files.pythonhosted.org/packages/12/22/dedfbcfa97917982301496b6b5e5e6c5531d1f35dd2b488b08d1ebc52482/psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:31b32c457a6025e74d233957cc9736742ac5a6cb196c6b68499f6bb51390bd6a", size = 3298010, upload-time = "2025-10-10T11:12:22.671Z" }, + { url = "https://files.pythonhosted.org/packages/66/ea/d3390e6696276078bd01b2ece417deac954dfdd552d2edc3d03204416c0c/psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:edcb3aeb11cb4bf13a2af3c53a15b3d612edeb6409047ea0b5d6a21a9d744b34", size = 3044641, upload-time = "2025-10-30T02:55:19.929Z" }, { url = "https://files.pythonhosted.org/packages/12/9a/0402ded6cbd321da0c0ba7d34dc12b29b14f5764c2fc10750daa38e825fc/psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:62b6d93d7c0b61a1dd6197d208ab613eb7dcfdcca0a49c42ceb082257991de9d", size = 3347940, upload-time = "2025-10-10T11:12:26.529Z" }, { url = "https://files.pythonhosted.org/packages/b1/d2/99b55e85832ccde77b211738ff3925a5d73ad183c0b37bcbbe5a8ff04978/psycopg2_binary-2.9.11-cp312-cp312-win_amd64.whl", hash = "sha256:b33fabeb1fde21180479b2d4667e994de7bbf0eec22832ba5d9b5e4cf65b6c6d", size = 2714147, upload-time = "2025-10-10T11:12:29.535Z" }, { url = "https://files.pythonhosted.org/packages/ff/a8/a2709681b3ac11b0b1786def10006b8995125ba268c9a54bea6f5ae8bd3e/psycopg2_binary-2.9.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b8fb3db325435d34235b044b199e56cdf9ff41223a4b9752e8576465170bb38c", size = 3756572, upload-time = "2025-10-10T11:12:32.873Z" }, @@ -4667,8 +4671,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/11/32/b2ffe8f3853c181e88f0a157c5fb4e383102238d73c52ac6d93a5c8bffe6/psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8c55b385daa2f92cb64b12ec4536c66954ac53654c7f15a203578da4e78105c0", size = 4411242, upload-time = "2025-10-10T11:12:42.388Z" }, { url = "https://files.pythonhosted.org/packages/10/04/6ca7477e6160ae258dc96f67c371157776564679aefd247b66f4661501a2/psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:c0377174bf1dd416993d16edc15357f6eb17ac998244cca19bc67cdc0e2e5766", size = 4468258, upload-time = "2025-10-10T11:12:48.654Z" }, { url = "https://files.pythonhosted.org/packages/3c/7e/6a1a38f86412df101435809f225d57c1a021307dd0689f7a5e7fe83588b1/psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5c6ff3335ce08c75afaed19e08699e8aacf95d4a260b495a4a8545244fe2ceb3", size = 4166295, upload-time = "2025-10-10T11:12:52.525Z" }, + { url = "https://files.pythonhosted.org/packages/f2/7d/c07374c501b45f3579a9eb761cbf2604ddef3d96ad48679112c2c5aa9c25/psycopg2_binary-2.9.11-cp313-cp313-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:84011ba3109e06ac412f95399b704d3d6950e386b7994475b231cf61eec2fc1f", size = 3983133, upload-time = "2025-10-30T02:55:24.329Z" }, { url = "https://files.pythonhosted.org/packages/82/56/993b7104cb8345ad7d4516538ccf8f0d0ac640b1ebd8c754a7b024e76878/psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ba34475ceb08cccbdd98f6b46916917ae6eeb92b5ae111df10b544c3a4621dc4", size = 3652383, upload-time = "2025-10-10T11:12:56.387Z" }, { url = "https://files.pythonhosted.org/packages/2d/ac/eaeb6029362fd8d454a27374d84c6866c82c33bfc24587b4face5a8e43ef/psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:b31e90fdd0f968c2de3b26ab014314fe814225b6c324f770952f7d38abf17e3c", size = 3298168, upload-time = "2025-10-10T11:13:00.403Z" }, + { url = "https://files.pythonhosted.org/packages/2b/39/50c3facc66bded9ada5cbc0de867499a703dc6bca6be03070b4e3b65da6c/psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:d526864e0f67f74937a8fce859bd56c979f5e2ec57ca7c627f5f1071ef7fee60", size = 3044712, upload-time = "2025-10-30T02:55:27.975Z" }, { url = "https://files.pythonhosted.org/packages/9c/8e/b7de019a1f562f72ada81081a12823d3c1590bedc48d7d2559410a2763fe/psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04195548662fa544626c8ea0f06561eb6203f1984ba5b4562764fbeb4c3d14b1", size = 3347549, upload-time = "2025-10-10T11:13:03.971Z" }, { url = "https://files.pythonhosted.org/packages/80/2d/1bb683f64737bbb1f86c82b7359db1eb2be4e2c0c13b947f80efefa7d3e5/psycopg2_binary-2.9.11-cp313-cp313-win_amd64.whl", hash = "sha256:efff12b432179443f54e230fdf60de1f6cc726b6c832db8701227d089310e8aa", size = 2714215, upload-time = "2025-10-10T11:13:07.14Z" }, ] From 3e6193c7fb85b442c54f301d51bc9c58dd5e6ba6 Mon Sep 17 00:00:00 2001 From: Teryl Taylor Date: Thu, 30 Oct 2025 11:20:17 -0600 Subject: [PATCH 04/20] refactor: created a common directory for classes used across packages. Signed-off-by: Teryl Taylor --- TESTING.md | 2 +- docs/docs/architecture/multitenancy.md | 2 +- mcpgateway/admin.py | 2 +- mcpgateway/cache/session_registry.py | 2 +- mcpgateway/common/__init__.py | 8 + mcpgateway/common/config.py | 104 ++ mcpgateway/common/models.py | 1073 +++++++++++++++ mcpgateway/common/validators.py | 1190 +++++++++++++++++ mcpgateway/db.py | 6 +- mcpgateway/federation/discovery.py | 2 +- mcpgateway/federation/forward.py | 2 +- mcpgateway/handlers/sampling.py | 6 +- mcpgateway/main.py | 5 +- .../plugins/framework/external/mcp/client.py | 2 +- mcpgateway/plugins/framework/models.py | 4 +- mcpgateway/plugins/mcp/entities/models.py | 10 +- mcpgateway/schemas.py | 12 +- mcpgateway/services/completion_service.py | 2 +- mcpgateway/services/log_storage_service.py | 8 +- mcpgateway/services/logging_service.py | 8 +- mcpgateway/services/prompt_service.py | 2 +- mcpgateway/services/resource_service.py | 4 +- mcpgateway/services/root_service.py | 6 +- mcpgateway/services/tool_service.py | 8 +- mcpgateway/utils/pagination.py | 8 +- mcpgateway/utils/passthrough_headers.py | 2 +- plugin_templates/external/tests/test_all.py | 2 +- .../llmguard/tests/test_llmguardplugin.py | 2 +- plugins/external/opa/tests/test_all.py | 2 +- .../opa/tests/test_opapluginfilter.py | 2 +- .../file_type_allowlist.py | 2 +- plugins/html_to_markdown/html_to_markdown.py | 2 +- plugins/markdown_cleaner/markdown_cleaner.py | 3 +- .../privacy_notice_injector.py | 2 +- plugins/resource_filter/resource_filter.py | 2 +- tests/integration/test_integration.py | 2 +- .../test_resource_plugin_integration.py | 2 +- tests/security/test_input_validation.py | 2 +- .../external/mcp/server/test_runtime.py | 2 +- .../external/mcp/test_client_config.py | 2 +- .../external/mcp/test_client_stdio.py | 2 +- .../mcp/test_client_streamable_http.py | 2 +- .../framework/loader/test_plugin_loader.py | 2 +- .../plugins/framework/test_manager.py | 2 +- .../framework/test_manager_extended.py | 6 +- .../plugins/framework/test_resource_hooks.py | 2 +- .../plugins/framework/test_utils.py | 4 +- .../external_clamav/test_clamav_remote.py | 15 +- .../test_file_type_allowlist.py | 2 +- .../html_to_markdown/test_html_to_markdown.py | 2 +- .../markdown_cleaner/test_markdown_cleaner.py | 2 +- .../plugins/pii_filter/test_pii_filter.py | 2 +- .../resource_filter/test_resource_filter.py | 2 +- .../test_virus_total_checker.py | 4 +- .../services/test_completion_service.py | 2 +- .../services/test_export_service.py | 4 +- .../services/test_log_storage_service.py | 2 +- .../services/test_logging_service.py | 2 +- .../test_logging_service_comprehensive.py | 2 +- .../services/test_prompt_service.py | 2 +- .../services/test_resource_service_plugins.py | 2 +- tests/unit/mcpgateway/test_discovery.py | 2 +- .../mcpgateway/test_final_coverage_push.py | 2 +- tests/unit/mcpgateway/test_main.py | 6 +- tests/unit/mcpgateway/test_models.py | 2 +- .../mcpgateway/test_rpc_tool_invocation.py | 2 +- tests/unit/mcpgateway/test_schemas.py | 2 +- .../mcpgateway/validation/test_validators.py | 2 +- .../validation/test_validators_advanced.py | 2 +- 69 files changed, 2485 insertions(+), 109 deletions(-) create mode 100644 mcpgateway/common/__init__.py create mode 100644 mcpgateway/common/config.py create mode 100644 mcpgateway/common/models.py create mode 100644 mcpgateway/common/validators.py diff --git a/TESTING.md b/TESTING.md index ccf64cfa0..bf4d0c291 100644 --- a/TESTING.md +++ b/TESTING.md @@ -291,7 +291,7 @@ class TestExampleService: def test_with_database(db_session): """Test using database session fixture.""" # db_session is automatically provided by conftest.py - from mcpgateway.models import Tool + from mcpgateway.common.models import Tool tool = Tool(name="test_tool") db_session.add(tool) db_session.commit() diff --git a/docs/docs/architecture/multitenancy.md b/docs/docs/architecture/multitenancy.md index 01389d295..f7083c266 100644 --- a/docs/docs/architecture/multitenancy.md +++ b/docs/docs/architecture/multitenancy.md @@ -652,7 +652,7 @@ For emergency password resets, you can update the database directly: python3 -c " from mcpgateway.services.argon2_service import Argon2PasswordService from mcpgateway.db import SessionLocal -from mcpgateway.models import EmailUser +from mcpgateway.common.models import EmailUser service = Argon2PasswordService() hashed = service.hash_password('new_password') diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index ba597abd9..f4e5174b3 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -49,11 +49,11 @@ from starlette.datastructures import UploadFile as StarletteUploadFile # First-Party +from mcpgateway.common.models import LogLevel from mcpgateway.config import settings from mcpgateway.db import get_db, GlobalConfig from mcpgateway.db import Tool as DbTool from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission -from mcpgateway.models import LogLevel from mcpgateway.schemas import ( A2AAgentCreate, A2AAgentRead, diff --git a/mcpgateway/cache/session_registry.py b/mcpgateway/cache/session_registry.py index 3679f4267..c04093e3e 100644 --- a/mcpgateway/cache/session_registry.py +++ b/mcpgateway/cache/session_registry.py @@ -64,9 +64,9 @@ # First-Party from mcpgateway import __version__ +from mcpgateway.common.models import Implementation, InitializeResult, ServerCapabilities from mcpgateway.config import settings from mcpgateway.db import get_db, SessionMessageRecord, SessionRecord -from mcpgateway.models import Implementation, InitializeResult, ServerCapabilities from mcpgateway.services import PromptService, ResourceService, ToolService from mcpgateway.services.logging_service import LoggingService from mcpgateway.transports import SSETransport diff --git a/mcpgateway/common/__init__.py b/mcpgateway/common/__init__.py new file mode 100644 index 000000000..2f4c65db1 --- /dev/null +++ b/mcpgateway/common/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/common/__init__.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Common ContextForge package for shared classes and functions. +""" diff --git a/mcpgateway/common/config.py b/mcpgateway/common/config.py new file mode 100644 index 000000000..5ab271fb2 --- /dev/null +++ b/mcpgateway/common/config.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/config.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti, Manav Gupta + +Common MCP Gateway Configuration settings used across subpackages. +This module defines configuration settings for the MCP Gateway using Pydantic. +It loads configuration from environment variables with sensible defaults. +""" + +# Standard +from functools import lru_cache + +# Third-Party +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings): + """Validation settings for the security validator.""" + + # Validation patterns for safe display (configurable) + validation_dangerous_html_pattern: str = ( + r"<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|" + ) + + validation_dangerous_js_pattern: str = r"(?i)(?:^|\s|[\"'`<>=])(javascript:|vbscript:|data:\s*[^,]*[;\s]*(javascript|vbscript)|\bon[a-z]+\s*=|<\s*script\b)" + + validation_allowed_url_schemes: list[str] = ["http://", "https://", "ws://", "wss://"] + + # Character validation patterns + validation_name_pattern: str = r"^[a-zA-Z0-9_.\-\s]+$" # Allow spaces for names + validation_identifier_pattern: str = r"^[a-zA-Z0-9_\-\.]+$" # No spaces for IDs + validation_safe_uri_pattern: str = r"^[a-zA-Z0-9_\-.:/?=&%]+$" + validation_unsafe_uri_pattern: str = r'[<>"\'\\]' + validation_tool_name_pattern: str = r"^[a-zA-Z][a-zA-Z0-9._-]*$" # MCP tool naming + validation_tool_method_pattern: str = r"^[a-zA-Z][a-zA-Z0-9_\./-]*$" + + # MCP-compliant size limits (configurable via env) + validation_max_name_length: int = 255 + validation_max_description_length: int = 8192 # 8KB + validation_max_template_length: int = 65536 # 64KB + validation_max_content_length: int = 1048576 # 1MB + validation_max_json_depth: int = 10 + validation_max_url_length: int = 2048 + validation_max_rpc_param_size: int = 262144 # 256KB + + validation_max_method_length: int = 128 + + # Allowed MIME types + validation_allowed_mime_types: list[str] = [ + "text/plain", + "text/html", + "text/css", + "text/markdown", + "text/javascript", + "application/json", + "application/xml", + "application/pdf", + "image/png", + "image/jpeg", + "image/gif", + "image/svg+xml", + "application/octet-stream", + ] + + # Rate limiting + validation_max_requests_per_minute: int = 60 + + # CLI settings + plugins_cli_markup_mode: str | None = None + plugins_cli_completion: bool = True + + +@lru_cache() +def get_settings() -> Settings: + """Get cached settings instance. + + Returns: + Settings: A cached instance of the Settings class. + + Examples: + >>> settings = get_settings() + >>> isinstance(settings, Settings) + True + >>> # Second call returns the same cached instance + >>> settings2 = get_settings() + >>> settings is settings2 + True + """ + # Instantiate a fresh Pydantic Settings object, + # loading from env vars or .env exactly once. + cfg = Settings() + # Validate that transport_type is correct; will + # raise if mis-configured. + # cfg.validate_transport() + # Ensure sqlite DB directories exist if needed. + # cfg.validate_database() + # Return the one-and-only Settings instance (cached). + return cfg + + +# Create settings instance +settings = get_settings() diff --git a/mcpgateway/common/models.py b/mcpgateway/common/models.py new file mode 100644 index 000000000..34ee1d9d8 --- /dev/null +++ b/mcpgateway/common/models.py @@ -0,0 +1,1073 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/common/models.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti + +MCP Protocol Type Definitions. +This module defines all core MCP protocol types according to the specification. +It includes: + - Message content types (text, image, resource) + - Tool definitions and schemas + - Resource types and templates + - Prompt structures + - Protocol initialization types + - Sampling message types + - Capability definitions + +Examples: + >>> from mcpgateway.common.models import Role, LogLevel, TextContent + >>> Role.USER.value + 'user' + >>> Role.ASSISTANT.value + 'assistant' + >>> LogLevel.ERROR.value + 'error' + >>> LogLevel.INFO.value + 'info' + >>> content = TextContent(type='text', text='Hello') + >>> content.text + 'Hello' + >>> content.type + 'text' +""" + +# Standard +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Union + +# Third-Party +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, ConfigDict, Field + + +class Role(str, Enum): + """Message role in conversations. + + Attributes: + ASSISTANT (str): Indicates the assistant's role. + USER (str): Indicates the user's role. + + Examples: + >>> Role.USER.value + 'user' + >>> Role.ASSISTANT.value + 'assistant' + >>> Role.USER == 'user' + True + >>> list(Role) + [, ] + """ + + ASSISTANT = "assistant" + USER = "user" + + +class LogLevel(str, Enum): + """Standard syslog severity levels as defined in RFC 5424. + + Attributes: + DEBUG (str): Debug level. + INFO (str): Informational level. + NOTICE (str): Notice level. + WARNING (str): Warning level. + ERROR (str): Error level. + CRITICAL (str): Critical level. + ALERT (str): Alert level. + EMERGENCY (str): Emergency level. + """ + + DEBUG = "debug" + INFO = "info" + NOTICE = "notice" + WARNING = "warning" + ERROR = "error" + CRITICAL = "critical" + ALERT = "alert" + EMERGENCY = "emergency" + + +# Base content types +class TextContent(BaseModel): + """Text content for messages. + + Attributes: + type (Literal["text"]): The fixed content type identifier for text. + text (str): The actual text message. + + Examples: + >>> content = TextContent(type='text', text='Hello World') + >>> content.text + 'Hello World' + >>> content.type + 'text' + >>> content.model_dump() + {'type': 'text', 'text': 'Hello World'} + """ + + type: Literal["text"] + text: str + + +class JSONContent(BaseModel): + """JSON content for messages. + Attributes: + type (Literal["text"]): The fixed content type identifier for text. + json (dict): The actual text message. + """ + + type: Literal["text"] + text: dict + + +class ImageContent(BaseModel): + """Image content for messages. + + Attributes: + type (Literal["image"]): The fixed content type identifier for images. + data (bytes): The binary data of the image. + mime_type (str): The MIME type (e.g. "image/png") of the image. + """ + + type: Literal["image"] + data: bytes + mime_type: str + + +class ResourceContent(BaseModel): + """Resource content that can be embedded. + + Attributes: + type (Literal["resource"]): The fixed content type identifier for resources. + id (str): The ID identifying the resource. + uri (str): The URI of the resource. + mime_type (Optional[str]): The MIME type of the resource, if known. + text (Optional[str]): A textual representation of the resource, if applicable. + blob (Optional[bytes]): Binary data of the resource, if applicable. + """ + + type: Literal["resource"] + id: str + uri: str + mime_type: Optional[str] = None + text: Optional[str] = None + blob: Optional[bytes] = None + + +ContentType = Union[TextContent, JSONContent, ImageContent, ResourceContent] + + +# Reference types - needed early for completion +class PromptReference(BaseModel): + """Reference to a prompt or prompt template. + + Attributes: + type (Literal["ref/prompt"]): The fixed reference type identifier for prompts. + name (str): The unique name of the prompt. + """ + + type: Literal["ref/prompt"] + name: str + + +class ResourceReference(BaseModel): + """Reference to a resource or resource template. + + Attributes: + type (Literal["ref/resource"]): The fixed reference type identifier for resources. + uri (str): The URI of the resource. + """ + + type: Literal["ref/resource"] + uri: str + + +# Completion types +class CompleteRequest(BaseModel): + """Request for completion suggestions. + + Attributes: + ref (Union[PromptReference, ResourceReference]): A reference to a prompt or resource. + argument (Dict[str, str]): A dictionary containing arguments for the completion. + """ + + ref: Union[PromptReference, ResourceReference] + argument: Dict[str, str] + + +class CompleteResult(BaseModel): + """Result for a completion request. + + Attributes: + completion (Dict[str, Any]): A dictionary containing the completion results. + """ + + completion: Dict[str, Any] = Field(..., description="Completion results") + + +# Implementation info +class Implementation(BaseModel): + """MCP implementation information. + + Attributes: + name (str): The name of the implementation. + version (str): The version of the implementation. + """ + + name: str + version: str + + +# Model preferences +class ModelHint(BaseModel): + """Hint for model selection. + + Attributes: + name (Optional[str]): An optional hint for the model name. + """ + + name: Optional[str] = None + + +class ModelPreferences(BaseModel): + """Server preferences for model selection. + + Attributes: + cost_priority (float): Priority for cost efficiency (0 to 1). + speed_priority (float): Priority for speed (0 to 1). + intelligence_priority (float): Priority for intelligence (0 to 1). + hints (List[ModelHint]): A list of model hints. + """ + + cost_priority: float = Field(ge=0, le=1) + speed_priority: float = Field(ge=0, le=1) + intelligence_priority: float = Field(ge=0, le=1) + hints: List[ModelHint] = [] + + +# Capability types +class ClientCapabilities(BaseModel): + """Capabilities that a client may support. + + Attributes: + roots (Optional[Dict[str, bool]]): Capabilities related to root management. + sampling (Optional[Dict[str, Any]]): Capabilities related to LLM sampling. + experimental (Optional[Dict[str, Dict[str, Any]]]): Experimental capabilities. + """ + + roots: Optional[Dict[str, bool]] = None + sampling: Optional[Dict[str, Any]] = None + experimental: Optional[Dict[str, Dict[str, Any]]] = None + + +class ServerCapabilities(BaseModel): + """Capabilities that a server may support. + + Attributes: + prompts (Optional[Dict[str, bool]]): Capability for prompt support. + resources (Optional[Dict[str, bool]]): Capability for resource support. + tools (Optional[Dict[str, bool]]): Capability for tool support. + logging (Optional[Dict[str, Any]]): Capability for logging support. + experimental (Optional[Dict[str, Dict[str, Any]]]): Experimental capabilities. + """ + + prompts: Optional[Dict[str, bool]] = None + resources: Optional[Dict[str, bool]] = None + tools: Optional[Dict[str, bool]] = None + logging: Optional[Dict[str, Any]] = None + experimental: Optional[Dict[str, Dict[str, Any]]] = None + + +# Initialization types +class InitializeRequest(BaseModel): + """Initial request sent from the client to the server. + + Attributes: + protocol_version (str): The protocol version (alias: protocolVersion). + capabilities (ClientCapabilities): The client's capabilities. + client_info (Implementation): The client's implementation information (alias: clientInfo). + + Note: + The alias settings allow backward compatibility with older Pydantic versions. + """ + + protocol_version: str = Field(..., alias="protocolVersion") + capabilities: ClientCapabilities + client_info: Implementation = Field(..., alias="clientInfo") + + model_config = ConfigDict( + populate_by_name=True, + ) + + +class InitializeResult(BaseModel): + """Server's response to the initialization request. + + Attributes: + protocol_version (str): The protocol version used. + capabilities (ServerCapabilities): The server's capabilities. + server_info (Implementation): The server's implementation information. + instructions (Optional[str]): Optional instructions for the client. + """ + + protocol_version: str = Field(..., alias="protocolVersion") + capabilities: ServerCapabilities = Field(..., alias="capabilities") + server_info: Implementation = Field(..., alias="serverInfo") + instructions: Optional[str] = Field(None, alias="instructions") + + model_config = ConfigDict( + populate_by_name=True, + ) + + +# Message types +class Message(BaseModel): + """A message in a conversation. + + Attributes: + role (Role): The role of the message sender. + content (ContentType): The content of the message. + """ + + role: Role + content: ContentType + + +class SamplingMessage(BaseModel): + """A message used in LLM sampling requests. + + Attributes: + role (Role): The role of the sender. + content (ContentType): The content of the sampling message. + """ + + role: Role + content: ContentType + + +# Sampling types for the client features +class CreateMessageResult(BaseModel): + """Result from a sampling/createMessage request. + + Attributes: + content (Union[TextContent, ImageContent]): The generated content. + model (str): The model used for generating the content. + role (Role): The role associated with the content. + stop_reason (Optional[str]): An optional reason for why sampling stopped. + """ + + content: Union[TextContent, ImageContent] + model: str + role: Role + stop_reason: Optional[str] = None + + +# Prompt types +class PromptArgument(BaseModel): + """An argument that can be passed to a prompt. + + Attributes: + name (str): The name of the argument. + description (Optional[str]): An optional description of the argument. + required (bool): Whether the argument is required. Defaults to False. + """ + + name: str + description: Optional[str] = None + required: bool = False + + +class Prompt(BaseModel): + """A prompt template offered by the server. + + Attributes: + name (str): The unique name of the prompt. + description (Optional[str]): A description of the prompt. + arguments (List[PromptArgument]): A list of expected prompt arguments. + """ + + name: str + description: Optional[str] = None + arguments: List[PromptArgument] = [] + + +class PromptResult(BaseModel): + """Result of rendering a prompt template. + + Attributes: + messages (List[Message]): The list of messages produced by rendering the prompt. + description (Optional[str]): An optional description of the rendered result. + """ + + messages: List[Message] + description: Optional[str] = None + + +class CommonAttributes(BaseModel): + """Common attributes for tools and gateways. + + Attributes: + name (str): The unique name of the tool. + url (AnyHttpUrl): The URL of the tool. + description (Optional[str]): A description of the tool. + created_at (Optional[datetime]): The time at which the tool was created. + update_at (Optional[datetime]): The time at which the tool was updated. + enabled (Optional[bool]): If the tool is enabled. + reachable (Optional[bool]): If the tool is currently reachable. + tags (Optional[list[str]]): A list of meta data tags describing the tool. + created_by (Optional[str]): The person that created the tool. + created_from_ip (Optional[str]): The client IP that created the tool. + created_via (Optional[str]): How the tool was created (e.g., ui). + created_user_agent (Optioanl[str]): The client user agent. + modified_by (Optional[str]): The person that modified the tool. + modified_from_ip (Optional[str]): The client IP that modified the tool. + modified_via (Optional[str]): How the tool was modified (e.g., ui). + modified_user_agent (Optioanl[str]): The client user agent. + import_batch_id (Optional[str]): The id of the batch file that imported the tool. + federation_source (Optional[str]): The federation source of the tool + version (Optional[int]): The version of the tool. + team_id (Optional[str]): The id of the team that created the tool. + owner_email (Optional[str]): Tool owner's email. + visibility (Optional[str]): Visibility of the tool (e.g., public, private). + """ + + name: str + url: AnyHttpUrl + description: Optional[str] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + enabled: Optional[bool] = None + reachable: Optional[bool] = None + auth_type: Optional[str] = None + tags: Optional[list[str]] = None + # Comprehensive metadata for audit tracking + created_by: Optional[str] = None + created_from_ip: Optional[str] = None + created_via: Optional[str] = None + created_user_agent: Optional[str] = None + + modified_by: Optional[str] = None + modified_from_ip: Optional[str] = None + modified_via: Optional[str] = None + modified_user_agent: Optional[str] = None + + import_batch_id: Optional[str] = None + federation_source: Optional[str] = None + version: Optional[int] = None + # Team scoping fields for resource organization + team_id: Optional[str] = None + owner_email: Optional[str] = None + visibility: Optional[str] = None + + +# Tool types +class Tool(CommonAttributes): + """A tool that can be invoked. + + Attributes: + original_name (str): The original supplied name of the tool before imported by the gateway. + integrationType (str): The integration type of the tool (e.g. MCP or REST). + requestType (str): The HTTP method used to invoke the tool (GET, POST, PUT, DELETE, SSE, STDIO). + headers (Dict[str, Any]): A JSON object representing HTTP headers. + input_schema (Dict[str, Any]): A JSON Schema for validating the tool's input. + output_schema (Optional[Dict[str, Any]]): A JSON Schema for validating the tool's output. + annotations (Optional[Dict[str, Any]]): Tool annotations for behavior hints. + auth_username (Optional[str]): The username for basic authentication. + auth_password (Optional[str]): The password for basic authentication. + auth_token (Optional[str]): The token for bearer authentication. + jsonpath_filter (Optional[str]): Filter the tool based on a JSON path expression. + custom_name (Optional[str]): Custom tool name. + custom_name_slug (Optional[str]): Alternative custom tool name. + display_name (Optional[str]): Display name. + gateway_id (Optional[str]): The gateway id on which the tool is hosted. + """ + + model_config = ConfigDict(from_attributes=True) + original_name: Optional[str] = None + integration_type: str = "MCP" + request_type: str = "SSE" + headers: Optional[Dict[str, Any]] = Field(default_factory=dict) + input_schema: Dict[str, Any] = Field(default_factory=lambda: {"type": "object", "properties": {}}) + output_schema: Optional[Dict[str, Any]] = Field(default=None, description="JSON Schema for validating the tool's output") + annotations: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Tool annotations for behavior hints") + auth_username: Optional[str] = None + auth_password: Optional[str] = None + auth_token: Optional[str] = None + jsonpath_filter: Optional[str] = None + + # custom_name,custom_name_slug, display_name + custom_name: Optional[str] = None + custom_name_slug: Optional[str] = None + display_name: Optional[str] = None + + # Federation relationship with a local gateway + gateway_id: Optional[str] = None + + +class ToolResult(BaseModel): + """Result of a tool invocation. + + Attributes: + content (List[ContentType]): A list of content items returned by the tool. + is_error (bool): Flag indicating if the tool call resulted in an error. + """ + + content: List[ContentType] + is_error: bool = False + + +# Resource types +class Resource(BaseModel): + """A resource available from the server. + + Attributes: + uri (str): The unique URI of the resource. + name (str): The human-readable name of the resource. + description (Optional[str]): A description of the resource. + mime_type (Optional[str]): The MIME type of the resource. + size (Optional[int]): The size of the resource. + """ + + uri: str + name: str + description: Optional[str] = None + mime_type: Optional[str] = None + size: Optional[int] = None + + +class ResourceTemplate(BaseModel): + """A template for constructing resource URIs. + + Attributes: + uri_template (str): The URI template string. + name (str): The unique name of the template. + description (Optional[str]): A description of the template. + mime_type (Optional[str]): The MIME type associated with the template. + """ + + uri_template: str + name: str + description: Optional[str] = None + mime_type: Optional[str] = None + + +class ListResourceTemplatesResult(BaseModel): + """The server's response to a resources/templates/list request from the client. + + Attributes: + meta (Optional[Dict[str, Any]]): Reserved property for metadata. + next_cursor (Optional[str]): Pagination cursor for the next page of results. + resource_templates (List[ResourceTemplate]): List of resource templates. + """ + + meta: Optional[Dict[str, Any]] = Field( + None, alias="_meta", description="This result property is reserved by the protocol to allow clients and servers to attach additional metadata to their responses." + ) + next_cursor: Optional[str] = Field(None, description="An opaque token representing the pagination position after the last returned result.\nIf present, there may be more results available.") + resource_templates: List[ResourceTemplate] = Field(default_factory=list, description="List of resource templates available on the server") + + model_config = ConfigDict( + populate_by_name=True, + ) + + +# Root types +class FileUrl(AnyUrl): + """A specialized URL type for local file-scheme resources. + + Key characteristics + ------------------- + * Scheme restricted - only the "file" scheme is permitted + (e.g. file:///path/to/file.txt). + * No host required - "file" URLs typically omit a network host; + therefore, the host component is not mandatory. + * String-friendly equality - developers naturally expect + FileUrl("file:///data") == "file:///data" to evaluate True. + AnyUrl (Pydantic) does not implement that, so we override + __eq__ to compare against plain strings transparently. + Hash semantics are kept consistent by delegating to the parent class. + + Examples + -------- + >>> url = FileUrl("file:///etc/hosts") + >>> url.scheme + 'file' + >>> url == "file:///etc/hosts" + True + >>> {"path": url} # hashable + {'path': FileUrl('file:///etc/hosts')} + + Notes + ----- + The override does not interfere with comparisons to other + AnyUrl/FileUrl instances; those still use the superclass + implementation. + """ + + # Restrict to the "file" scheme and omit host requirement + allowed_schemes = {"file"} + host_required = False + + def __eq__(self, other): # type: ignore[override] + """Return True when other is an equivalent URL or string. + + If other is a str it is coerced with str(self) for comparison; + otherwise defer to AnyUrl's comparison. + + Args: + other (Any): The object to compare against. May be a str, FileUrl, or AnyUrl. + + Returns: + bool: True if the other value is equal to this URL, either as a string + or as another URL object. False otherwise. + """ + if isinstance(other, str): + return str(self) == other + return super().__eq__(other) + + # Keep hashing behaviour aligned with equality + __hash__ = AnyUrl.__hash__ + + +class Root(BaseModel): + """A root directory or file. + + Attributes: + uri (Union[FileUrl, AnyUrl]): The unique identifier for the root. + name (Optional[str]): An optional human-readable name. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + uri: Union[FileUrl, AnyUrl] = Field(..., description="Unique identifier for the root") + name: Optional[str] = Field(None, description="Optional human-readable name") + + +# Progress types +class ProgressToken(BaseModel): + """Token for associating progress notifications. + + Attributes: + value (Union[str, int]): The token value. + """ + + value: Union[str, int] + + +class Progress(BaseModel): + """Progress update for long-running operations. + + Attributes: + progress_token (ProgressToken): The token associated with the progress update. + progress (float): The current progress value. + total (Optional[float]): The total progress value, if known. + """ + + progress_token: ProgressToken + progress: float + total: Optional[float] = None + + +# JSON-RPC types +class JSONRPCRequest(BaseModel): + """JSON-RPC 2.0 request. + + Attributes: + jsonrpc (Literal["2.0"]): The JSON-RPC version. + id (Optional[Union[str, int]]): The request identifier. + method (str): The method name. + params (Optional[Dict[str, Any]]): The parameters for the request. + """ + + jsonrpc: Literal["2.0"] + id: Optional[Union[str, int]] = None + method: str + params: Optional[Dict[str, Any]] = None + + +class JSONRPCResponse(BaseModel): + """JSON-RPC 2.0 response. + + Attributes: + jsonrpc (Literal["2.0"]): The JSON-RPC version. + id (Optional[Union[str, int]]): The request identifier. + result (Optional[Any]): The result of the request. + error (Optional[Dict[str, Any]]): The error object if an error occurred. + """ + + jsonrpc: Literal["2.0"] + id: Optional[Union[str, int]] = None + result: Optional[Any] = None + error: Optional[Dict[str, Any]] = None + + +class JSONRPCError(BaseModel): + """JSON-RPC 2.0 error. + + Attributes: + code (int): The error code. + message (str): A short description of the error. + data (Optional[Any]): Additional data about the error. + """ + + code: int + message: str + data: Optional[Any] = None + + +# Global configuration types +class GlobalConfig(BaseModel): + """Global server configuration. + + Attributes: + passthrough_headers (Optional[List[str]]): List of headers allowed to be passed through globally + """ + + passthrough_headers: Optional[List[str]] = Field(default=None, description="List of headers allowed to be passed through globally") + + +# Transport message types +class SSEEvent(BaseModel): + """Server-Sent Events message. + + Attributes: + id (Optional[str]): The event identifier. + event (Optional[str]): The event type. + data (str): The event data. + retry (Optional[int]): The retry timeout in milliseconds. + """ + + id: Optional[str] = None + event: Optional[str] = None + data: str + retry: Optional[int] = None + + +class WebSocketMessage(BaseModel): + """WebSocket protocol message. + + Attributes: + type (str): The type of the WebSocket message. + data (Any): The message data. + """ + + type: str + data: Any + + +# Notification types +class ResourceUpdateNotification(BaseModel): + """Notification of resource changes. + + Attributes: + method (Literal["notifications/resources/updated"]): The notification method. + uri (str): The URI of the updated resource. + """ + + method: Literal["notifications/resources/updated"] + uri: str + + +class ResourceListChangedNotification(BaseModel): + """Notification of resource list changes. + + Attributes: + method (Literal["notifications/resources/list_changed"]): The notification method. + """ + + method: Literal["notifications/resources/list_changed"] + + +class PromptListChangedNotification(BaseModel): + """Notification of prompt list changes. + + Attributes: + method (Literal["notifications/prompts/list_changed"]): The notification method. + """ + + method: Literal["notifications/prompts/list_changed"] + + +class ToolListChangedNotification(BaseModel): + """Notification of tool list changes. + + Attributes: + method (Literal["notifications/tools/list_changed"]): The notification method. + """ + + method: Literal["notifications/tools/list_changed"] + + +class CancelledNotification(BaseModel): + """Notification of request cancellation. + + Attributes: + method (Literal["notifications/cancelled"]): The notification method. + request_id (Union[str, int]): The ID of the cancelled request. + reason (Optional[str]): An optional reason for cancellation. + """ + + method: Literal["notifications/cancelled"] + request_id: Union[str, int] + reason: Optional[str] = None + + +class ProgressNotification(BaseModel): + """Notification of operation progress. + + Attributes: + method (Literal["notifications/progress"]): The notification method. + progress_token (ProgressToken): The token associated with the progress. + progress (float): The current progress value. + total (Optional[float]): The total progress value, if known. + """ + + method: Literal["notifications/progress"] + progress_token: ProgressToken + progress: float + total: Optional[float] = None + + +class LoggingNotification(BaseModel): + """Notification of log messages. + + Attributes: + method (Literal["notifications/message"]): The notification method. + level (LogLevel): The log level of the message. + logger (Optional[str]): The logger name. + data (Any): The log message data. + """ + + method: Literal["notifications/message"] + level: LogLevel + logger: Optional[str] = None + data: Any + + +# Federation types +class FederatedTool(Tool): + """A tool from a federated gateway. + + Attributes: + gateway_id (str): The identifier of the gateway. + gateway_name (str): The name of the gateway. + """ + + gateway_id: str + gateway_name: str + + +class FederatedResource(Resource): + """A resource from a federated gateway. + + Attributes: + gateway_id (str): The identifier of the gateway. + gateway_name (str): The name of the gateway. + """ + + gateway_id: str + gateway_name: str + + +class FederatedPrompt(Prompt): + """A prompt from a federated gateway. + + Attributes: + gateway_id (str): The identifier of the gateway. + gateway_name (str): The name of the gateway. + """ + + gateway_id: str + gateway_name: str + + +class Gateway(CommonAttributes): + """A federated gateway peer. + + Attributes: + id (str): The unique identifier for the gateway. + name (str): The name of the gateway. + url (AnyHttpUrl): The URL of the gateway. + capabilities (ServerCapabilities): The capabilities of the gateway. + last_seen (Optional[datetime]): Timestamp when the gateway was last seen. + """ + + model_config = ConfigDict(from_attributes=True) + id: str + capabilities: ServerCapabilities + last_seen: Optional[datetime] = None + slug: str + transport: str + last_seen: Optional[datetime] + # Header passthrough configuration + passthrough_headers: Optional[list[str]] # Store list of strings as JSON array + # Request type and authentication fields + auth_value: Optional[str | dict] + + +# ===== RBAC Models ===== + + +class RBACRole(BaseModel): + """Role model for RBAC system. + + Represents roles that can be assigned to users with specific permissions. + Supports global, team, and personal scopes with role inheritance. + + Attributes: + id: Unique role identifier + name: Human-readable role name + description: Role description and purpose + scope: Role scope ('global', 'team', 'personal') + permissions: List of permission strings + inherits_from: Parent role ID for inheritance + created_by: Email of user who created the role + is_system_role: Whether this is a system-defined role + is_active: Whether the role is currently active + created_at: Role creation timestamp + updated_at: Role last modification timestamp + + Examples: + >>> from datetime import datetime + >>> role = RBACRole( + ... id="role-123", + ... name="team_admin", + ... description="Team administrator with member management rights", + ... scope="team", + ... permissions=["teams.manage_members", "resources.create"], + ... created_by="admin@example.com", + ... created_at=datetime(2023, 1, 1), + ... updated_at=datetime(2023, 1, 1) + ... ) + >>> role.name + 'team_admin' + >>> "teams.manage_members" in role.permissions + True + """ + + id: str = Field(..., description="Unique role identifier") + name: str = Field(..., description="Human-readable role name") + description: Optional[str] = Field(None, description="Role description and purpose") + scope: str = Field(..., description="Role scope", pattern="^(global|team|personal)$") + permissions: List[str] = Field(..., description="List of permission strings") + inherits_from: Optional[str] = Field(None, description="Parent role ID for inheritance") + created_by: str = Field(..., description="Email of user who created the role") + is_system_role: bool = Field(False, description="Whether this is a system-defined role") + is_active: bool = Field(True, description="Whether the role is currently active") + created_at: datetime = Field(..., description="Role creation timestamp") + updated_at: datetime = Field(..., description="Role last modification timestamp") + + +class UserRoleAssignment(BaseModel): + """User role assignment model. + + Represents the assignment of roles to users in specific scopes (global, team, personal). + Includes metadata about who granted the role and when it expires. + + Attributes: + id: Unique assignment identifier + user_email: Email of the user assigned the role + role_id: ID of the assigned role + scope: Assignment scope ('global', 'team', 'personal') + scope_id: Team ID if team-scoped, None otherwise + granted_by: Email of user who granted this role + granted_at: Timestamp when role was granted + expires_at: Optional expiration timestamp + is_active: Whether the assignment is currently active + + Examples: + >>> from datetime import datetime + >>> user_role = UserRoleAssignment( + ... id="assignment-123", + ... user_email="user@example.com", + ... role_id="team-admin-123", + ... scope="team", + ... scope_id="team-engineering-456", + ... granted_by="admin@example.com", + ... granted_at=datetime(2023, 1, 1) + ... ) + >>> user_role.scope + 'team' + >>> user_role.is_active + True + """ + + id: str = Field(..., description="Unique assignment identifier") + user_email: str = Field(..., description="Email of the user assigned the role") + role_id: str = Field(..., description="ID of the assigned role") + scope: str = Field(..., description="Assignment scope", pattern="^(global|team|personal)$") + scope_id: Optional[str] = Field(None, description="Team ID if team-scoped, None otherwise") + granted_by: str = Field(..., description="Email of user who granted this role") + granted_at: datetime = Field(..., description="Timestamp when role was granted") + expires_at: Optional[datetime] = Field(None, description="Optional expiration timestamp") + is_active: bool = Field(True, description="Whether the assignment is currently active") + + +class PermissionAudit(BaseModel): + """Permission audit log model. + + Records all permission checks for security auditing and compliance. + Includes details about the user, permission, resource, and result. + + Attributes: + id: Unique audit log entry identifier + timestamp: When the permission check occurred + user_email: Email of user being checked + permission: Permission being checked (e.g., 'tools.create') + resource_type: Type of resource (e.g., 'tools', 'teams') + resource_id: Specific resource ID if applicable + team_id: Team context if applicable + granted: Whether permission was granted + roles_checked: JSON of roles that were checked + ip_address: IP address of the request + user_agent: User agent string + + Examples: + >>> from datetime import datetime + >>> audit_log = PermissionAudit( + ... id=1, + ... timestamp=datetime(2023, 1, 1), + ... user_email="user@example.com", + ... permission="tools.create", + ... resource_type="tools", + ... granted=True, + ... roles_checked={"roles": ["team_admin"]} + ... ) + >>> audit_log.granted + True + >>> audit_log.permission + 'tools.create' + """ + + id: int = Field(..., description="Unique audit log entry identifier") + timestamp: datetime = Field(..., description="When the permission check occurred") + user_email: Optional[str] = Field(None, description="Email of user being checked") + permission: str = Field(..., description="Permission being checked") + resource_type: Optional[str] = Field(None, description="Type of resource") + resource_id: Optional[str] = Field(None, description="Specific resource ID if applicable") + team_id: Optional[str] = Field(None, description="Team context if applicable") + granted: bool = Field(..., description="Whether permission was granted") + roles_checked: Optional[Dict] = Field(None, description="JSON of roles that were checked") + ip_address: Optional[str] = Field(None, description="IP address of the request") + user_agent: Optional[str] = Field(None, description="User agent string") + + +# Permission constants are imported from db.py to avoid duplication +# Use Permissions class from mcpgateway.db instead of duplicate SystemPermissions + + +class TransportType(str, Enum): + """ + Enumeration of supported transport mechanisms for communication between components. + + Attributes: + SSE (str): Server-Sent Events transport. + HTTP (str): Standard HTTP-based transport. + STDIO (str): Standard input/output transport. + STREAMABLEHTTP (str): HTTP transport with streaming. + """ + + SSE = "SSE" + HTTP = "HTTP" + STDIO = "STDIO" + STREAMABLEHTTP = "STREAMABLEHTTP" diff --git a/mcpgateway/common/validators.py b/mcpgateway/common/validators.py new file mode 100644 index 000000000..4e8f2fa11 --- /dev/null +++ b/mcpgateway/common/validators.py @@ -0,0 +1,1190 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/common/validators.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti, Madhav Kandukuri + +SecurityValidator for MCP Gateway +This module defines the `SecurityValidator` class, which provides centralized, configurable +validation logic for user-generated content in MCP-based applications. + +The validator enforces strict security and structural rules across common input types such as: +- Display text (e.g., names, descriptions) +- Identifiers and tool names +- URIs and URLs +- JSON object depth +- Templates (including limited HTML/Jinja2) +- MIME types + +Key Features: +- Pattern-based validation using settings-defined regex for HTML/script safety +- Configurable max lengths and depth limits +- Whitelist-based URL scheme and MIME type validation +- Safe escaping of user-visible text fields +- Reusable static/class methods for field-level and form-level validation + +Intended to be used with Pydantic or similar schema-driven systems to validate and sanitize +user input in a consistent, centralized way. + +Dependencies: +- Standard Library: re, html, logging, urllib.parse +- First-party: `settings` from `mcpgateway.config` + +Example usage: + SecurityValidator.validate_name("my_tool", field_name="Tool Name") + SecurityValidator.validate_url("https://example.com") + SecurityValidator.validate_json_depth({...}) + +Examples: + >>> from mcpgateway.common.validators import SecurityValidator + >>> SecurityValidator.sanitize_display_text('Test', 'test') + '<b>Test</b>' + >>> SecurityValidator.validate_name('valid_name-123', 'test') + 'valid_name-123' + >>> SecurityValidator.validate_identifier('my.test.id_123', 'test') + 'my.test.id_123' + >>> SecurityValidator.validate_json_depth({'a': {'b': 1}}) + >>> SecurityValidator.validate_json_depth({'a': 1}) +""" + +# Standard +import html +import logging +import re +from urllib.parse import urlparse +import uuid + +# First-Party +from mcpgateway.common.config import settings + +logger = logging.getLogger(__name__) + + +class SecurityValidator: + """Configurable validation with MCP-compliant limits""" + + # Configurable patterns (from settings) + DANGEROUS_HTML_PATTERN = ( + settings.validation_dangerous_html_pattern + ) # Default: '<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|' + DANGEROUS_JS_PATTERN = settings.validation_dangerous_js_pattern # Default: javascript:|vbscript:|on\w+\s*=|data:.*script + ALLOWED_URL_SCHEMES = settings.validation_allowed_url_schemes # Default: ["http://", "https://", "ws://", "wss://"] + + # Character type patterns + NAME_PATTERN = settings.validation_name_pattern # Default: ^[a-zA-Z0-9_\-\s]+$ + IDENTIFIER_PATTERN = settings.validation_identifier_pattern # Default: ^[a-zA-Z0-9_\-\.]+$ + VALIDATION_SAFE_URI_PATTERN = settings.validation_safe_uri_pattern # Default: ^[a-zA-Z0-9_\-.:/?=&%]+$ + VALIDATION_UNSAFE_URI_PATTERN = settings.validation_unsafe_uri_pattern # Default: [<>"\'\\] + TOOL_NAME_PATTERN = settings.validation_tool_name_pattern # Default: ^[a-zA-Z][a-zA-Z0-9_-]*$ + + # MCP-compliant limits (configurable) + MAX_NAME_LENGTH = settings.validation_max_name_length # Default: 255 + MAX_DESCRIPTION_LENGTH = settings.validation_max_description_length # Default: 8192 (8KB) + MAX_TEMPLATE_LENGTH = settings.validation_max_template_length # Default: 65536 + MAX_CONTENT_LENGTH = settings.validation_max_content_length # Default: 1048576 (1MB) + MAX_JSON_DEPTH = settings.validation_max_json_depth # Default: 10 + MAX_URL_LENGTH = settings.validation_max_url_length # Default: 2048 + + @classmethod + def sanitize_display_text(cls, value: str, field_name: str) -> str: + """Ensure text is safe for display in UI by escaping special characters + + Args: + value (str): Value to validate + field_name (str): Name of field being validated + + Returns: + str: Value if acceptable + + Raises: + ValueError: When input is not acceptable + + Examples: + Basic HTML escaping: + + >>> SecurityValidator.sanitize_display_text('Hello World', 'test') + 'Hello World' + >>> SecurityValidator.sanitize_display_text('Hello World', 'test') + 'Hello <b>World</b>' + + Empty/None handling: + + >>> SecurityValidator.sanitize_display_text('', 'test') + '' + >>> SecurityValidator.sanitize_display_text(None, 'test') #doctest: +SKIP + + Dangerous script patterns: + + >>> SecurityValidator.sanitize_display_text('alert();', 'test') + 'alert();' + >>> SecurityValidator.sanitize_display_text('javascript:alert(1)', 'test') + Traceback (most recent call last): + ... + ValueError: test contains script patterns that may cause display issues + + Polyglot attack patterns: + + >>> SecurityValidator.sanitize_display_text('"; alert()', 'test') + Traceback (most recent call last): + ... + ValueError: test contains potentially dangerous character sequences + >>> SecurityValidator.sanitize_display_text('-->test', 'test') + '-->test' + >>> SecurityValidator.sanitize_display_text('-->') + Traceback (most recent call last): + ... + ValueError: Template contains HTML tags that may interfere with proper display + >>> SecurityValidator.validate_template('Test ') + Traceback (most recent call last): + ... + ValueError: Template contains HTML tags that may interfere with proper display + >>> SecurityValidator.validate_template('
') + Traceback (most recent call last): + ... + ValueError: Template contains HTML tags that may interfere with proper display + + Event handlers blocked: + + >>> SecurityValidator.validate_template('
Test
') + Traceback (most recent call last): + ... + ValueError: Template contains event handlers that may cause display issues + >>> SecurityValidator.validate_template('onload = "alert(1)"') + Traceback (most recent call last): + ... + ValueError: Template contains event handlers that may cause display issues + + SSTI prevention patterns: + + >>> SecurityValidator.validate_template('{{ __import__ }}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + >>> SecurityValidator.validate_template('{{ config }}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + >>> SecurityValidator.validate_template('{% import os %}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + >>> SecurityValidator.validate_template('{{ 7*7 }}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + >>> SecurityValidator.validate_template('{{ 10/2 }}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + >>> SecurityValidator.validate_template('{{ 5+5 }}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + >>> SecurityValidator.validate_template('{{ 10-5 }}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + + Other template injection patterns: + + >>> SecurityValidator.validate_template('${evil}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + >>> SecurityValidator.validate_template('#{evil}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + >>> SecurityValidator.validate_template('%{evil}') + Traceback (most recent call last): + ... + ValueError: Template contains potentially dangerous expressions + + Length limit testing: + + >>> long_template = 'a' * 65537 + >>> SecurityValidator.validate_template(long_template) + Traceback (most recent call last): + ... + ValueError: Template exceeds maximum length of 65536 + """ + if not value: + return value + + if len(value) > cls.MAX_TEMPLATE_LENGTH: + raise ValueError(f"Template exceeds maximum length of {cls.MAX_TEMPLATE_LENGTH}") + + # Block dangerous tags but allow Jinja2 syntax {{ }} and {% %} + dangerous_tags = r"<(script|iframe|object|embed|link|meta|base|form)\b" + if re.search(dangerous_tags, value, re.IGNORECASE): + raise ValueError("Template contains HTML tags that may interfere with proper display") + + # Check for event handlers that could cause issues + if re.search(r"on\w+\s*=", value, re.IGNORECASE): + raise ValueError("Template contains event handlers that may cause display issues") + + # SSTI Prevention - block dangerous template expressions + ssti_patterns = [ + r"\{\{.*(__|\.|config|self|request|application|globals|builtins|import).*\}\}", # Jinja2 dangerous patterns + r"\{%.*(__|\.|config|self|request|application|globals|builtins|import).*%\}", # Jinja2 tags + r"\$\{.*\}", # ${} expressions + r"#\{.*\}", # #{} expressions + r"%\{.*\}", # %{} expressions + r"\{\{.*\*.*\}\}", # Math operations in templates (like {{7*7}}) + r"\{\{.*\/.*\}\}", # Division operations + r"\{\{.*\+.*\}\}", # Addition operations + r"\{\{.*\-.*\}\}", # Subtraction operations + ] + + for pattern in ssti_patterns: + if re.search(pattern, value, re.IGNORECASE): + raise ValueError("Template contains potentially dangerous expressions") + + return value + + @classmethod + def validate_url(cls, value: str, field_name: str = "URL") -> str: + """Validate URLs for allowed schemes and safe display + + Args: + value (str): Value to validate + field_name (str): Name of field being validated + + Returns: + str: Value if acceptable + + Raises: + ValueError: When input is not acceptable + + Examples: + Valid URLs: + + >>> SecurityValidator.validate_url('https://example.com') + 'https://example.com' + >>> SecurityValidator.validate_url('http://example.com') + 'http://example.com' + >>> SecurityValidator.validate_url('ws://example.com') + 'ws://example.com' + >>> SecurityValidator.validate_url('wss://example.com') + 'wss://example.com' + >>> SecurityValidator.validate_url('https://example.com:8080/path') + 'https://example.com:8080/path' + >>> SecurityValidator.validate_url('https://example.com/path?query=value') + 'https://example.com/path?query=value' + + Empty URL handling: + + >>> SecurityValidator.validate_url('') + Traceback (most recent call last): + ... + ValueError: URL cannot be empty + + Length validation: + + >>> long_url = 'https://example.com/' + 'a' * 2100 + >>> SecurityValidator.validate_url(long_url) + Traceback (most recent call last): + ... + ValueError: URL exceeds maximum length of 2048 + + Scheme validation: + + >>> SecurityValidator.validate_url('ftp://example.com') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + >>> SecurityValidator.validate_url('file:///etc/passwd') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + >>> SecurityValidator.validate_url('javascript:alert(1)') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + >>> SecurityValidator.validate_url('data:text/plain,hello') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + >>> SecurityValidator.validate_url('vbscript:alert(1)') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + >>> SecurityValidator.validate_url('about:blank') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + >>> SecurityValidator.validate_url('chrome://settings') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + >>> SecurityValidator.validate_url('mailto:test@example.com') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + + IPv6 URL blocking: + + >>> SecurityValidator.validate_url('https://[::1]:8080/') + Traceback (most recent call last): + ... + ValueError: URL contains IPv6 address which is not supported + >>> SecurityValidator.validate_url('https://[2001:db8::1]/') + Traceback (most recent call last): + ... + ValueError: URL contains IPv6 address which is not supported + + Protocol-relative URL blocking: + + >>> SecurityValidator.validate_url('//example.com/path') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + + Line break injection: + + >>> SecurityValidator.validate_url('https://example.com\\rHost: evil.com') + Traceback (most recent call last): + ... + ValueError: URL contains line breaks which are not allowed + >>> SecurityValidator.validate_url('https://example.com\\nHost: evil.com') + Traceback (most recent call last): + ... + ValueError: URL contains line breaks which are not allowed + + Space validation: + + >>> SecurityValidator.validate_url('https://exam ple.com') + Traceback (most recent call last): + ... + ValueError: URL contains spaces which are not allowed in URLs + >>> SecurityValidator.validate_url('https://example.com/path?query=hello world') + 'https://example.com/path?query=hello world' + + Malformed URLs: + + >>> SecurityValidator.validate_url('https://') + Traceback (most recent call last): + ... + ValueError: URL is not a valid URL + >>> SecurityValidator.validate_url('not-a-url') + Traceback (most recent call last): + ... + ValueError: URL must start with one of: http://, https://, ws://, wss:// + + Restricted IP addresses: + + >>> SecurityValidator.validate_url('https://0.0.0.0/') + Traceback (most recent call last): + ... + ValueError: URL contains invalid IP address (0.0.0.0) + >>> SecurityValidator.validate_url('https://169.254.169.254/') + Traceback (most recent call last): + ... + ValueError: URL contains restricted IP address + + Invalid port numbers: + + >>> SecurityValidator.validate_url('https://example.com:0/') + Traceback (most recent call last): + ... + ValueError: URL contains invalid port number + >>> try: + ... SecurityValidator.validate_url('https://example.com:65536/') + ... except ValueError as e: + ... 'Port out of range' in str(e) or 'invalid port' in str(e) + True + + Credentials in URL: + + >>> SecurityValidator.validate_url('https://user:pass@example.com/') + Traceback (most recent call last): + ... + ValueError: URL contains credentials which are not allowed + >>> SecurityValidator.validate_url('https://user@example.com/') + Traceback (most recent call last): + ... + ValueError: URL contains credentials which are not allowed + + XSS patterns in URLs: + + >>> SecurityValidator.validate_url('https://example.com/', 'test_field') + Traceback (most recent call last): + ... + ValueError: test_field contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'content') + Traceback (most recent call last): + ... + ValueError: content contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'data') + Traceback (most recent call last): + ... + ValueError: data contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'embed') + Traceback (most recent call last): + ... + ValueError: embed contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'style') + Traceback (most recent call last): + ... + ValueError: style contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'meta') + Traceback (most recent call last): + ... + ValueError: meta contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'base') + Traceback (most recent call last): + ... + ValueError: base contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('
', 'form') + Traceback (most recent call last): + ... + ValueError: form contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'image') + Traceback (most recent call last): + ... + ValueError: image contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'svg') + Traceback (most recent call last): + ... + ValueError: svg contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'video') + Traceback (most recent call last): + ... + ValueError: video contains HTML tags that may cause security issues + >>> SecurityValidator.validate_no_xss('', 'audio') + Traceback (most recent call last): + ... + ValueError: audio contains HTML tags that may cause security issues + """ + if not value: + return # Empty values are considered safe + # Check for dangerous HTML tags + if re.search(cls.DANGEROUS_HTML_PATTERN, value, re.IGNORECASE): + raise ValueError(f"{field_name} contains HTML tags that may cause security issues") + + @classmethod + def validate_json_depth( + cls, + obj: object, + max_depth: int | None = None, + current_depth: int = 0, + ) -> None: + """Validate that a JSON‑like structure does not exceed a depth limit. + + A *depth* is counted **only** when we enter a container (`dict` or + `list`). Primitive values (`str`, `int`, `bool`, `None`, etc.) do not + increase the depth, but an *empty* container still counts as one level. + + Args: + obj: Any Python object to inspect recursively. + max_depth: Maximum allowed depth (defaults to + :pyattr:`SecurityValidator.MAX_JSON_DEPTH`). + current_depth: Internal recursion counter. **Do not** set this + from user code. + + Raises: + ValueError: If the nesting level exceeds *max_depth*. + + Examples: + Simple flat dictionary – depth 1: :: + + >>> SecurityValidator.validate_json_depth({'name': 'Alice'}) + + Nested dict – depth 2: :: + + >>> SecurityValidator.validate_json_depth( + ... {'user': {'name': 'Alice'}} + ... ) + + Mixed dict/list – depth 3: :: + + >>> SecurityValidator.validate_json_depth( + ... {'users': [{'name': 'Alice', 'meta': {'age': 30}}]} + ... ) + + Exactly at the default limit (10) – allowed: :: + + >>> deep_10 = {'1': {'2': {'3': {'4': {'5': {'6': {'7': {'8': + ... {'9': {'10': 'end'}}}}}}}}}} + >>> SecurityValidator.validate_json_depth(deep_10) + + One level deeper – rejected: :: + + >>> deep_11 = {'1': {'2': {'3': {'4': {'5': {'6': {'7': {'8': + ... {'9': {'10': {'11': 'end'}}}}}}}}}}} + >>> SecurityValidator.validate_json_depth(deep_11) + Traceback (most recent call last): + ... + ValueError: JSON structure exceeds maximum depth of 10 + """ + if max_depth is None: + max_depth = cls.MAX_JSON_DEPTH + + # Only containers count toward depth; primitives are ignored + if not isinstance(obj, (dict, list)): + return + + next_depth = current_depth + 1 + if next_depth > max_depth: + raise ValueError(f"JSON structure exceeds maximum depth of {max_depth}") + + if isinstance(obj, dict): + for value in obj.values(): + cls.validate_json_depth(value, max_depth, next_depth) + else: # obj is a list + for item in obj: + cls.validate_json_depth(item, max_depth, next_depth) + + @classmethod + def validate_mime_type(cls, value: str) -> str: + """Validate MIME type format + + Args: + value (str): Value to validate + + Returns: + str: Value if acceptable + + Raises: + ValueError: When input is not acceptable + + Examples: + Empty/None handling: + + >>> SecurityValidator.validate_mime_type('') + '' + >>> SecurityValidator.validate_mime_type(None) #doctest: +SKIP + + Valid standard MIME types: + + >>> SecurityValidator.validate_mime_type('text/plain') + 'text/plain' + >>> SecurityValidator.validate_mime_type('application/json') + 'application/json' + >>> SecurityValidator.validate_mime_type('image/jpeg') + 'image/jpeg' + >>> SecurityValidator.validate_mime_type('text/html') + 'text/html' + >>> SecurityValidator.validate_mime_type('application/pdf') + 'application/pdf' + + Valid vendor-specific MIME types: + + >>> SecurityValidator.validate_mime_type('application/x-custom') + 'application/x-custom' + >>> SecurityValidator.validate_mime_type('text/x-log') + 'text/x-log' + + Valid MIME types with suffixes: + + >>> SecurityValidator.validate_mime_type('application/vnd.api+json') + 'application/vnd.api+json' + >>> SecurityValidator.validate_mime_type('image/svg+xml') + 'image/svg+xml' + + Invalid MIME type formats: + + >>> SecurityValidator.validate_mime_type('invalid') + Traceback (most recent call last): + ... + ValueError: Invalid MIME type format + >>> SecurityValidator.validate_mime_type('text/') + Traceback (most recent call last): + ... + ValueError: Invalid MIME type format + >>> SecurityValidator.validate_mime_type('/plain') + Traceback (most recent call last): + ... + ValueError: Invalid MIME type format + >>> SecurityValidator.validate_mime_type('text//plain') + Traceback (most recent call last): + ... + ValueError: Invalid MIME type format + >>> SecurityValidator.validate_mime_type('text/plain/extra') + Traceback (most recent call last): + ... + ValueError: Invalid MIME type format + >>> SecurityValidator.validate_mime_type('text plain') + Traceback (most recent call last): + ... + ValueError: Invalid MIME type format + >>> SecurityValidator.validate_mime_type('') + Traceback (most recent call last): + ... + ValueError: Invalid MIME type format + + Disallowed MIME types (not in whitelist - line 620): + + >>> try: + ... SecurityValidator.validate_mime_type('application/evil') + ... except ValueError as e: + ... 'not in the allowed list' in str(e) + True + >>> try: + ... SecurityValidator.validate_mime_type('text/evil') + ... except ValueError as e: + ... 'not in the allowed list' in str(e) + True + + Test MIME type with parameters (line 618): + + >>> try: + ... SecurityValidator.validate_mime_type('application/evil; charset=utf-8') + ... except ValueError as e: + ... 'Invalid MIME type format' in str(e) + True + """ + if not value: + return value + + # Basic MIME type pattern + mime_pattern = r"^[a-zA-Z0-9][a-zA-Z0-9!#$&\-\^_+\.]*\/[a-zA-Z0-9][a-zA-Z0-9!#$&\-\^_+\.]*$" + if not re.match(mime_pattern, value): + raise ValueError("Invalid MIME type format") + + # Common safe MIME types + safe_mime_types = settings.validation_allowed_mime_types + if value not in safe_mime_types: + # Allow x- vendor types and + suffixes + base_type = value.split(";")[0].strip() + if not (base_type.startswith("application/x-") or base_type.startswith("text/x-") or "+" in base_type): + raise ValueError(f"MIME type '{value}' is not in the allowed list") + + return value diff --git a/mcpgateway/db.py b/mcpgateway/db.py index 5e5e97afe..087b3936e 100644 --- a/mcpgateway/db.py +++ b/mcpgateway/db.py @@ -38,16 +38,16 @@ from sqlalchemy.pool import QueuePool # First-Party +from mcpgateway.common.validators import SecurityValidator from mcpgateway.config import settings from mcpgateway.utils.create_slug import slugify from mcpgateway.utils.db_isready import wait_for_db_ready -from mcpgateway.validators import SecurityValidator logger = logging.getLogger(__name__) if TYPE_CHECKING: # First-Party - from mcpgateway.models import ResourceContent + from mcpgateway.common.models import ResourceContent # ResourceContent will be imported locally where needed to avoid circular imports # EmailUser models moved to this file to avoid circular imports @@ -1923,7 +1923,7 @@ def content(self) -> "ResourceContent": # Local import to avoid circular import # First-Party - from mcpgateway.models import ResourceContent # pylint: disable=import-outside-toplevel + from mcpgateway.common.models import ResourceContent # pylint: disable=import-outside-toplevel if self.text_content is not None: return ResourceContent( diff --git a/mcpgateway/federation/discovery.py b/mcpgateway/federation/discovery.py index e8d5409e0..c5d2890f7 100644 --- a/mcpgateway/federation/discovery.py +++ b/mcpgateway/federation/discovery.py @@ -78,8 +78,8 @@ # First-Party from mcpgateway import __version__ +from mcpgateway.common.models import ServerCapabilities from mcpgateway.config import settings -from mcpgateway.models import ServerCapabilities from mcpgateway.services.logging_service import LoggingService # Initialize logging service first diff --git a/mcpgateway/federation/forward.py b/mcpgateway/federation/forward.py index 4609cf311..cd3b106e4 100644 --- a/mcpgateway/federation/forward.py +++ b/mcpgateway/federation/forward.py @@ -36,11 +36,11 @@ from sqlalchemy.orm import Session # First-Party +from mcpgateway.common.models import ToolResult from mcpgateway.config import settings from mcpgateway.db import Gateway as DbGateway from mcpgateway.db import ServerMetric from mcpgateway.db import Tool as DbTool -from mcpgateway.models import ToolResult from mcpgateway.services.logging_service import LoggingService from mcpgateway.utils.passthrough_headers import get_passthrough_headers diff --git a/mcpgateway/handlers/sampling.py b/mcpgateway/handlers/sampling.py index 2a6d90e59..01e461ec1 100644 --- a/mcpgateway/handlers/sampling.py +++ b/mcpgateway/handlers/sampling.py @@ -10,7 +10,7 @@ Examples: >>> import asyncio - >>> from mcpgateway.models import ModelPreferences + >>> from mcpgateway.common.models import ModelPreferences >>> handler = SamplingHandler() >>> asyncio.run(handler.initialize()) >>> @@ -48,7 +48,7 @@ from sqlalchemy.orm import Session # First-Party -from mcpgateway.models import CreateMessageResult, ModelPreferences, Role, TextContent +from mcpgateway.common.models import CreateMessageResult, ModelPreferences, Role, TextContent from mcpgateway.services.logging_service import LoggingService # Initialize logging service first @@ -247,7 +247,7 @@ def _select_model(self, preferences: ModelPreferences) -> str: SamplingError: If no suitable model found Examples: - >>> from mcpgateway.models import ModelPreferences, ModelHint + >>> from mcpgateway.common.models import ModelPreferences, ModelHint >>> handler = SamplingHandler() >>> >>> # Test intelligence priority diff --git a/mcpgateway/main.py b/mcpgateway/main.py index f69cb3d9e..632fd81c2 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -63,6 +63,7 @@ from mcpgateway.auth import get_current_user from mcpgateway.bootstrap_db import main as bootstrap_db from mcpgateway.cache import ResourceCache, SessionRegistry +from mcpgateway.common.models import InitializeResult, ListResourceTemplatesResult, LogLevel, Root from mcpgateway.config import settings from mcpgateway.db import refresh_slugs_on_startup, SessionLocal from mcpgateway.db import Tool as DbTool @@ -71,7 +72,6 @@ from mcpgateway.middleware.request_logging_middleware import RequestLoggingMiddleware from mcpgateway.middleware.security_headers import SecurityHeadersMiddleware from mcpgateway.middleware.token_scoping import token_scoping_middleware -from mcpgateway.models import InitializeResult, ListResourceTemplatesResult, LogLevel, Root from mcpgateway.observability import init_telemetry from mcpgateway.plugins.framework import PluginError, PluginManager, PluginViolationError from mcpgateway.routers.well_known import router as well_known_router @@ -2699,8 +2699,7 @@ async def read_resource(resource_id: str, request: Request, db: Session = Depend # Ensure a plain JSON-serializable structure try: # First-Party - # pylint: disable=import-outside-toplevel - from mcpgateway.models import ResourceContent, TextContent + from mcpgateway.common.models import ResourceContent, TextContent # pylint: disable=import-outside-toplevel # If already a ResourceContent, serialize directly if isinstance(content, ResourceContent): diff --git a/mcpgateway/plugins/framework/external/mcp/client.py b/mcpgateway/plugins/framework/external/mcp/client.py index fc5905c14..9ebebaa28 100644 --- a/mcpgateway/plugins/framework/external/mcp/client.py +++ b/mcpgateway/plugins/framework/external/mcp/client.py @@ -25,6 +25,7 @@ from mcp.types import TextContent # First-Party +from mcpgateway.common.models import TransportType from mcpgateway.plugins.framework.base import HookRef, Plugin, PluginRef from mcpgateway.plugins.framework.constants import ( CONTEXT, @@ -51,7 +52,6 @@ PluginPayload, PluginResult, ) -from mcpgateway.schemas import TransportType logger = logging.getLogger(__name__) diff --git a/mcpgateway/plugins/framework/models.py b/mcpgateway/plugins/framework/models.py index c9e790d15..ad0de71ef 100644 --- a/mcpgateway/plugins/framework/models.py +++ b/mcpgateway/plugins/framework/models.py @@ -27,6 +27,8 @@ ) # First-Party +from mcpgateway.common.models import TransportType +from mcpgateway.common.validators import SecurityValidator from mcpgateway.plugins.framework.constants import ( EXTERNAL_PLUGIN_TYPE, IGNORE_CONFIG_EXTERNAL, @@ -34,8 +36,6 @@ SCRIPT, URL, ) -from mcpgateway.schemas import TransportType -from mcpgateway.validators import SecurityValidator T = TypeVar("T") diff --git a/mcpgateway/plugins/mcp/entities/models.py b/mcpgateway/plugins/mcp/entities/models.py index 3a3e63d88..ad13e0473 100644 --- a/mcpgateway/plugins/mcp/entities/models.py +++ b/mcpgateway/plugins/mcp/entities/models.py @@ -17,7 +17,7 @@ from pydantic import Field, RootModel # First-Party -from mcpgateway.models import PromptResult +from mcpgateway.common.models import PromptResult from mcpgateway.plugins.framework.models import PluginPayload, PluginResult @@ -86,7 +86,7 @@ class PromptPosthookPayload(PluginPayload): result (PromptResult): The prompt after its template is rendered. Examples: - >>> from mcpgateway.models import PromptResult, Message, TextContent + >>> from mcpgateway.common.models import PromptResult, Message, TextContent >>> msg = Message(role="user", content=TextContent(type="text", text="Hello World")) >>> result = PromptResult(messages=[msg]) >>> payload = PromptPosthookPayload(prompt_id="123", result=result) @@ -94,7 +94,7 @@ class PromptPosthookPayload(PluginPayload): '123' >>> payload.result.messages[0].content.text 'Hello World' - >>> from mcpgateway.models import PromptResult, Message, TextContent + >>> from mcpgateway.common.models import PromptResult, Message, TextContent >>> msg = Message(role="assistant", content=TextContent(type="text", text="Test output")) >>> r = PromptResult(messages=[msg]) >>> p = PromptPosthookPayload(prompt_id="123", result=r) @@ -244,7 +244,7 @@ class ResourcePostFetchPayload(PluginPayload): content: The fetched resource content. Examples: - >>> from mcpgateway.models import ResourceContent + >>> from mcpgateway.common.models import ResourceContent >>> content = ResourceContent(type="resource", id="res-1", uri="file:///data.txt", ... text="Hello World") >>> payload = ResourcePostFetchPayload(uri="file:///data.txt", content=content) @@ -252,7 +252,7 @@ class ResourcePostFetchPayload(PluginPayload): 'file:///data.txt' >>> payload.content.text 'Hello World' - >>> from mcpgateway.models import ResourceContent + >>> from mcpgateway.common.models import ResourceContent >>> resource_content = ResourceContent(type="resource", id="res-2", uri="test://resource", text="Test data") >>> p = ResourcePostFetchPayload(uri="test://resource", content=resource_content) >>> p.uri diff --git a/mcpgateway/schemas.py b/mcpgateway/schemas.py index b32334287..231b6210a 100644 --- a/mcpgateway/schemas.py +++ b/mcpgateway/schemas.py @@ -33,15 +33,15 @@ from pydantic import AnyHttpUrl, BaseModel, ConfigDict, EmailStr, Field, field_serializer, field_validator, model_validator, ValidationInfo # First-Party +from mcpgateway.common.models import ImageContent +from mcpgateway.common.models import Prompt as MCPPrompt +from mcpgateway.common.models import Resource as MCPResource +from mcpgateway.common.models import ResourceContent, TextContent +from mcpgateway.common.models import Tool as MCPTool +from mcpgateway.common.validators import SecurityValidator from mcpgateway.config import settings -from mcpgateway.models import ImageContent -from mcpgateway.models import Prompt as MCPPrompt -from mcpgateway.models import Resource as MCPResource -from mcpgateway.models import ResourceContent, TextContent -from mcpgateway.models import Tool as MCPTool from mcpgateway.utils.services_auth import decode_auth, encode_auth from mcpgateway.validation.tags import validate_tags_field -from mcpgateway.validators import SecurityValidator logger = logging.getLogger(__name__) diff --git a/mcpgateway/services/completion_service.py b/mcpgateway/services/completion_service.py index bee038abd..89b99c9d9 100644 --- a/mcpgateway/services/completion_service.py +++ b/mcpgateway/services/completion_service.py @@ -25,9 +25,9 @@ from sqlalchemy.orm import Session # First-Party +from mcpgateway.common.models import CompleteResult from mcpgateway.db import Prompt as DbPrompt from mcpgateway.db import Resource as DbResource -from mcpgateway.models import CompleteResult from mcpgateway.services.logging_service import LoggingService # Initialize logging service first diff --git a/mcpgateway/services/log_storage_service.py b/mcpgateway/services/log_storage_service.py index ed4631c9d..36dca4fb1 100644 --- a/mcpgateway/services/log_storage_service.py +++ b/mcpgateway/services/log_storage_service.py @@ -18,8 +18,8 @@ import uuid # First-Party +from mcpgateway.common.models import LogLevel from mcpgateway.config import settings -from mcpgateway.models import LogLevel class LogEntryDict(TypedDict, total=False): @@ -108,7 +108,7 @@ def to_dict(self) -> LogEntryDict: Dictionary representation of the log entry Examples: - >>> from mcpgateway.models import LogLevel + >>> from mcpgateway.common.models import LogLevel >>> entry = LogEntry(LogLevel.INFO, "Test message", entity_type="tool", entity_id="123") >>> d = entry.to_dict() >>> str(d['level']) @@ -371,7 +371,7 @@ def _meets_level_threshold(self, log_level: LogLevel, min_level: LogLevel) -> bo True if log level meets or exceeds minimum Examples: - >>> from mcpgateway.models import LogLevel + >>> from mcpgateway.common.models import LogLevel >>> service = LogStorageService() >>> service._meets_level_threshold(LogLevel.ERROR, LogLevel.WARNING) True @@ -462,7 +462,7 @@ def clear(self) -> int: Number of logs cleared Examples: - >>> from mcpgateway.models import LogLevel + >>> from mcpgateway.common.models import LogLevel >>> service = LogStorageService() >>> import asyncio >>> entry = asyncio.run(service.add_log(LogLevel.INFO, "Test")) diff --git a/mcpgateway/services/logging_service.py b/mcpgateway/services/logging_service.py index e876dcdca..36ab5780b 100644 --- a/mcpgateway/services/logging_service.py +++ b/mcpgateway/services/logging_service.py @@ -22,8 +22,8 @@ from pythonjsonlogger import json as jsonlogger # You may need to install python-json-logger package # First-Party +from mcpgateway.common.models import LogLevel from mcpgateway.config import settings -from mcpgateway.models import LogLevel from mcpgateway.services.log_storage_service import LogStorageService AnyioClosedResourceError: Optional[type] # pylint: disable=invalid-name @@ -405,7 +405,7 @@ async def set_level(self, level: LogLevel) -> None: Examples: >>> from mcpgateway.services.logging_service import LoggingService - >>> from mcpgateway.models import LogLevel + >>> from mcpgateway.common.models import LogLevel >>> import asyncio >>> service = LoggingService() >>> asyncio.run(service.set_level(LogLevel.DEBUG)) @@ -445,7 +445,7 @@ async def notify( # pylint: disable=too-many-positional-arguments Examples: >>> from mcpgateway.services.logging_service import LoggingService - >>> from mcpgateway.models import LogLevel + >>> from mcpgateway.common.models import LogLevel >>> import asyncio >>> service = LoggingService() >>> asyncio.run(service.notify('test', LogLevel.INFO)) @@ -538,7 +538,7 @@ def _should_log(self, level: LogLevel) -> bool: True if should log Examples: - >>> from mcpgateway.models import LogLevel + >>> from mcpgateway.common.models import LogLevel >>> service = LoggingService() >>> service._level = LogLevel.WARNING >>> service._should_log(LogLevel.ERROR) diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index c612ec8e4..eedc6dec0 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -30,11 +30,11 @@ from sqlalchemy.orm import Session # First-Party +from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.config import settings from mcpgateway.db import EmailTeam from mcpgateway.db import Prompt as DbPrompt from mcpgateway.db import PromptMetric, server_prompt_association -from mcpgateway.models import Message, PromptResult, Role, TextContent from mcpgateway.observability import create_span from mcpgateway.plugins.framework import GlobalContext, PluginManager from mcpgateway.plugins.mcp.entities import HookType, PromptPosthookPayload, PromptPrehookPayload diff --git a/mcpgateway/services/resource_service.py b/mcpgateway/services/resource_service.py index e0e926def..664324451 100644 --- a/mcpgateway/services/resource_service.py +++ b/mcpgateway/services/resource_service.py @@ -41,12 +41,12 @@ from sqlalchemy.orm import Session # First-Party +from mcpgateway.common.models import ResourceContent, ResourceTemplate, TextContent from mcpgateway.db import EmailTeam from mcpgateway.db import Resource as DbResource from mcpgateway.db import ResourceMetric from mcpgateway.db import ResourceSubscription as DbSubscription from mcpgateway.db import server_resource_association -from mcpgateway.models import ResourceContent, ResourceTemplate, TextContent from mcpgateway.observability import create_span from mcpgateway.schemas import ResourceCreate, ResourceMetrics, ResourceRead, ResourceSubscription, ResourceUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService @@ -659,7 +659,7 @@ async def read_resource(self, db: Session, resource_id: Union[int, str], request Examples: >>> from mcpgateway.services.resource_service import ResourceService >>> from unittest.mock import MagicMock - >>> from mcpgateway.models import ResourceContent + >>> from mcpgateway.common.models import ResourceContent >>> service = ResourceService() >>> db = MagicMock() >>> uri = 'http://example.com/resource.txt' diff --git a/mcpgateway/services/root_service.py b/mcpgateway/services/root_service.py index 1e88e62e1..3f97b87c7 100644 --- a/mcpgateway/services/root_service.py +++ b/mcpgateway/services/root_service.py @@ -16,8 +16,8 @@ from urllib.parse import urlparse # First-Party +from mcpgateway.common.models import Root from mcpgateway.config import settings -from mcpgateway.models import Root from mcpgateway.services.logging_service import LoggingService # Initialize logging service first @@ -296,7 +296,7 @@ async def _notify_root_added(self, root: Root) -> None: Examples: >>> import asyncio >>> from mcpgateway.services.root_service import RootService - >>> from mcpgateway.models import Root + >>> from mcpgateway.common.models import Root >>> service = RootService() >>> queue = asyncio.Queue() >>> service._subscribers.append(queue) @@ -320,7 +320,7 @@ async def _notify_root_removed(self, root: Root) -> None: Examples: >>> import asyncio >>> from mcpgateway.services.root_service import RootService - >>> from mcpgateway.models import Root + >>> from mcpgateway.common.models import Root >>> service = RootService() >>> queue = asyncio.Queue() >>> service._subscribers.append(queue) diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index c53237e53..725983579 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -37,6 +37,10 @@ from sqlalchemy.orm import Session # First-Party +from mcpgateway.common.models import Gateway as PydanticGateway +from mcpgateway.common.models import TextContent +from mcpgateway.common.models import Tool as PydanticTool +from mcpgateway.common.models import ToolResult from mcpgateway.config import settings from mcpgateway.db import A2AAgent as DbA2AAgent from mcpgateway.db import EmailTeam @@ -44,10 +48,6 @@ from mcpgateway.db import server_tool_association from mcpgateway.db import Tool as DbTool from mcpgateway.db import ToolMetric -from mcpgateway.models import Gateway as PydanticGateway -from mcpgateway.models import TextContent -from mcpgateway.models import Tool as PydanticTool -from mcpgateway.models import ToolResult from mcpgateway.observability import create_span from mcpgateway.plugins.framework import GlobalContext, PluginError, PluginManager, PluginViolationError from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA diff --git a/mcpgateway/utils/pagination.py b/mcpgateway/utils/pagination.py index cf5891681..339691fb7 100644 --- a/mcpgateway/utils/pagination.py +++ b/mcpgateway/utils/pagination.py @@ -22,7 +22,7 @@ from mcpgateway.utils.pagination import paginate_query from sqlalchemy import select - from mcpgateway.models import Tool + from mcpgateway.common.models import Tool async def list_tools(db: Session): query = select(Tool).where(Tool.enabled == True) @@ -215,7 +215,7 @@ async def offset_paginate( from mcpgateway.utils.pagination import offset_paginate from sqlalchemy import select - from mcpgateway.models import Tool + from mcpgateway.common.models import Tool async def list_tools_offset(db: Session, page: int = 1): query = select(Tool).where(Tool.enabled == True) @@ -314,7 +314,7 @@ async def cursor_paginate( from mcpgateway.utils.pagination import cursor_paginate from sqlalchemy import select - from mcpgateway.models import Tool + from mcpgateway.common.models import Tool async def list_tools_cursor(db: Session, cursor: Optional[str] = None): query = select(Tool).order_by(Tool.created_at.desc()) @@ -436,7 +436,7 @@ async def paginate_query( from mcpgateway.utils.pagination import paginate_query from sqlalchemy import select - from mcpgateway.models import Tool + from mcpgateway.common.models import Tool async def list_tools_auto(db: Session, page: int = 1): query = select(Tool) diff --git a/mcpgateway/utils/passthrough_headers.py b/mcpgateway/utils/passthrough_headers.py index c3f7c1f91..a260dc0b0 100644 --- a/mcpgateway/utils/passthrough_headers.py +++ b/mcpgateway/utils/passthrough_headers.py @@ -350,7 +350,7 @@ async def set_global_passthrough_headers(db: Session) -> None: Config already exists (no DB write): >>> import pytest >>> from unittest.mock import Mock, patch - >>> from mcpgateway.models import GlobalConfig + >>> from mcpgateway.common.models import GlobalConfig >>> @pytest.mark.asyncio ... @patch("mcpgateway.utils.passthrough_headers.settings") ... async def test_existing_config(mock_settings): diff --git a/plugin_templates/external/tests/test_all.py b/plugin_templates/external/tests/test_all.py index 39987cbe7..b439b5136 100644 --- a/plugin_templates/external/tests/test_all.py +++ b/plugin_templates/external/tests/test_all.py @@ -8,7 +8,7 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, Role, TextContent +from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework import ( GlobalContext, PluginManager, diff --git a/plugins/external/llmguard/tests/test_llmguardplugin.py b/plugins/external/llmguard/tests/test_llmguardplugin.py index 7107e5afd..6615e08ae 100644 --- a/plugins/external/llmguard/tests/test_llmguardplugin.py +++ b/plugins/external/llmguard/tests/test_llmguardplugin.py @@ -15,7 +15,7 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, Role, TextContent +from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework import GlobalContext, PluginConfig, PluginContext, PromptPosthookPayload, PromptPrehookPayload diff --git a/plugins/external/opa/tests/test_all.py b/plugins/external/opa/tests/test_all.py index 227abaebc..3e2d872bd 100644 --- a/plugins/external/opa/tests/test_all.py +++ b/plugins/external/opa/tests/test_all.py @@ -8,7 +8,7 @@ import pytest # First-Party -from mcpgateway.models import Message, ResourceContent, Role, TextContent +from mcpgateway.common.models import Message, ResourceContent, Role, TextContent from mcpgateway.plugins.framework import ( GlobalContext, PluginManager, diff --git a/plugins/external/opa/tests/test_opapluginfilter.py b/plugins/external/opa/tests/test_opapluginfilter.py index 046b5df2e..9ba896c9b 100644 --- a/plugins/external/opa/tests/test_opapluginfilter.py +++ b/plugins/external/opa/tests/test_opapluginfilter.py @@ -16,7 +16,7 @@ import pytest # First-Party -from mcpgateway.models import Message, ResourceContent, Role, TextContent +from mcpgateway.common.models import Message, ResourceContent, Role, TextContent from mcpgateway.plugins.framework import ( GlobalContext, PluginConfig, diff --git a/plugins/file_type_allowlist/file_type_allowlist.py b/plugins/file_type_allowlist/file_type_allowlist.py index 6a38492da..5450e7524 100644 --- a/plugins/file_type_allowlist/file_type_allowlist.py +++ b/plugins/file_type_allowlist/file_type_allowlist.py @@ -20,7 +20,7 @@ from pydantic import BaseModel, Field # First-Party -from mcpgateway.models import ResourceContent +from mcpgateway.common.models import ResourceContent from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, diff --git a/plugins/html_to_markdown/html_to_markdown.py b/plugins/html_to_markdown/html_to_markdown.py index adc3799e5..f500c00e6 100644 --- a/plugins/html_to_markdown/html_to_markdown.py +++ b/plugins/html_to_markdown/html_to_markdown.py @@ -18,7 +18,7 @@ from typing import Any # First-Party -from mcpgateway.models import ResourceContent +from mcpgateway.common.models import ResourceContent from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, diff --git a/plugins/markdown_cleaner/markdown_cleaner.py b/plugins/markdown_cleaner/markdown_cleaner.py index 61f3b31ca..5b1d9cde7 100644 --- a/plugins/markdown_cleaner/markdown_cleaner.py +++ b/plugins/markdown_cleaner/markdown_cleaner.py @@ -17,7 +17,8 @@ from typing import Any # First-Party -from mcpgateway.models import Message, PromptResult, ResourceContent, TextContent +from mcpgateway.common.models import Message, PromptResult, TextContent +from mcpgateway.common.models import ResourceContent from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, diff --git a/plugins/privacy_notice_injector/privacy_notice_injector.py b/plugins/privacy_notice_injector/privacy_notice_injector.py index 80ad5546e..b37ab4055 100644 --- a/plugins/privacy_notice_injector/privacy_notice_injector.py +++ b/plugins/privacy_notice_injector/privacy_notice_injector.py @@ -19,7 +19,7 @@ from pydantic import BaseModel # First-Party -from mcpgateway.models import Message, Role, TextContent +from mcpgateway.common.models import Message, Role, TextContent from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, diff --git a/plugins/resource_filter/resource_filter.py b/plugins/resource_filter/resource_filter.py index 7213e553e..e4a481724 100644 --- a/plugins/resource_filter/resource_filter.py +++ b/plugins/resource_filter/resource_filter.py @@ -178,7 +178,7 @@ async def resource_post_fetch(self, payload: ResourcePostFetchPayload, context: if filtered_text != original_text: # Create new content object with filtered text # First-Party - from mcpgateway.models import ResourceContent + from mcpgateway.common.models import ResourceContent modified_content = ResourceContent( type=payload.content.type, diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index f7cb0f997..b9b3ec299 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -32,7 +32,7 @@ # First-Party from mcpgateway.main import app, require_auth -from mcpgateway.models import InitializeResult, ResourceContent, ServerCapabilities +from mcpgateway.common.models import InitializeResult, ResourceContent, ServerCapabilities from mcpgateway.schemas import ResourceRead, ServerRead, ToolMetrics, ToolRead # Local diff --git a/tests/integration/test_resource_plugin_integration.py b/tests/integration/test_resource_plugin_integration.py index 1582f6610..2a5ef2ab7 100644 --- a/tests/integration/test_resource_plugin_integration.py +++ b/tests/integration/test_resource_plugin_integration.py @@ -18,7 +18,7 @@ # First-Party from mcpgateway.db import Base -from mcpgateway.models import ResourceContent +from mcpgateway.common.models import ResourceContent from mcpgateway.schemas import ResourceCreate from mcpgateway.services.resource_service import ResourceService diff --git a/tests/security/test_input_validation.py b/tests/security/test_input_validation.py index 78dc36027..85c43d575 100644 --- a/tests/security/test_input_validation.py +++ b/tests/security/test_input_validation.py @@ -35,7 +35,7 @@ # First-Party from mcpgateway.schemas import AdminToolCreate, encode_datetime, GatewayCreate, PromptArgument, PromptCreate, ResourceCreate, RPCRequest, ServerCreate, to_camel_case, ToolCreate, ToolInvocation -from mcpgateway.validators import SecurityValidator +from mcpgateway.common.validators import SecurityValidator # Configure logging for better test debugging logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") diff --git a/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py b/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py index 524a6b60f..4d979f873 100644 --- a/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py +++ b/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py @@ -14,7 +14,7 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, Role, TextContent +from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework import ( GlobalContext, PluginContext, diff --git a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_config.py b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_config.py index 313bf6ed9..0f7c3bffc 100644 --- a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_config.py +++ b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_config.py @@ -17,7 +17,7 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, ResourceContent, Role, TextContent +from mcpgateway.common.models import Message, PromptResult, ResourceContent, Role, TextContent from mcpgateway.plugins.framework import ( ConfigLoader, GlobalContext, diff --git a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_stdio.py b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_stdio.py index e7ab7100d..44405c912 100644 --- a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_stdio.py +++ b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_stdio.py @@ -20,7 +20,7 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, ResourceContent, Role, TextContent +from mcpgateway.common.models import Message, PromptResult, ResourceContent, Role, TextContent from mcpgateway.plugins.framework import ( ConfigLoader, GlobalContext, diff --git a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_streamable_http.py b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_streamable_http.py index dd0eb8b68..72964d197 100644 --- a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_streamable_http.py +++ b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_streamable_http.py @@ -17,7 +17,7 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, Role, TextContent +from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework import ConfigLoader, GlobalContext, PluginContext, PluginLoader from mcpgateway.plugins.mcp.entities import PromptPosthookPayload, PromptPrehookPayload diff --git a/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py b/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py index 114f8449b..fa6b48d66 100644 --- a/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py +++ b/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py @@ -14,7 +14,7 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, Role, TextContent +from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework.loader.config import ConfigLoader from mcpgateway.plugins.framework.loader.plugin import PluginLoader from mcpgateway.plugins.framework import GlobalContext, PluginContext, PluginMode diff --git a/tests/unit/mcpgateway/plugins/framework/test_manager.py b/tests/unit/mcpgateway/plugins/framework/test_manager.py index 7df5b6d70..f077f7922 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_manager.py +++ b/tests/unit/mcpgateway/plugins/framework/test_manager.py @@ -11,7 +11,7 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, Role, TextContent +from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework import GlobalContext, PluginManager, PluginViolationError from mcpgateway.plugins.mcp.entities import HookType, HttpHeaderPayload, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, ToolPreInvokePayload from plugins.regex_filter.search_replace import SearchReplaceConfig diff --git a/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py b/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py index 2e6bac7f6..88091140b 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py +++ b/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py @@ -16,7 +16,7 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, Role, TextContent +from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework.base import HookRef, Plugin from mcpgateway.plugins.framework.models import Config from mcpgateway.plugins.framework import ( @@ -461,7 +461,7 @@ async def test_manager_payload_size_validation(): # Test large result payload (covers line 258) # First-Party - from mcpgateway.models import Message, PromptResult, Role, TextContent + from mcpgateway.common.models import Message, PromptResult, Role, TextContent large_text = "y" * (MAX_PAYLOAD_SIZE + 1) message = Message(role=Role.USER, content=TextContent(type="text", text=large_text)) @@ -543,7 +543,7 @@ async def test_manager_initialization_edge_cases(): async def test_base_plugin_coverage(): """Test base plugin functionality for complete coverage.""" # First-Party - from mcpgateway.models import Message, PromptResult, Role, TextContent + from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework.base import PluginRef from mcpgateway.plugins.framework.models import ( GlobalContext, diff --git a/tests/unit/mcpgateway/plugins/framework/test_resource_hooks.py b/tests/unit/mcpgateway/plugins/framework/test_resource_hooks.py index 3d95e6e5e..b120b0a75 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_resource_hooks.py +++ b/tests/unit/mcpgateway/plugins/framework/test_resource_hooks.py @@ -14,7 +14,7 @@ import pytest # First-Party -from mcpgateway.models import ResourceContent +from mcpgateway.common.models import ResourceContent from mcpgateway.plugins.framework.base import PluginRef # Registry is imported for mocking diff --git a/tests/unit/mcpgateway/plugins/framework/test_utils.py b/tests/unit/mcpgateway/plugins/framework/test_utils.py index 126824756..00e0e51dd 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_utils.py +++ b/tests/unit/mcpgateway/plugins/framework/test_utils.py @@ -115,7 +115,7 @@ def test_parse_class_name(): # """Test the post_prompt_matches function.""" # # Import required models # # First-Party -# from mcpgateway.models import Message, PromptResult, TextContent +# from mcpgateway.common.models import Message, PromptResult, TextContent # # Test basic matching # msg = Message(role="assistant", content=TextContent(type="text", text="Hello")) @@ -144,7 +144,7 @@ def test_parse_class_name(): # def test_post_prompt_matches_multiple_conditions(): # """Test post_prompt_matches with multiple conditions (OR logic).""" # # First-Party -# from mcpgateway.models import Message, PromptResult, TextContent +# from mcpgateway.common.models import Message, PromptResult, TextContent # # Create the payload # msg = Message(role="assistant", content=TextContent(type="text", text="Hello")) diff --git a/tests/unit/mcpgateway/plugins/plugins/external_clamav/test_clamav_remote.py b/tests/unit/mcpgateway/plugins/plugins/external_clamav/test_clamav_remote.py index a3f8c571e..2817c7dcc 100644 --- a/tests/unit/mcpgateway/plugins/plugins/external_clamav/test_clamav_remote.py +++ b/tests/unit/mcpgateway/plugins/plugins/external_clamav/test_clamav_remote.py @@ -19,7 +19,8 @@ ResourcePostFetchPayload, ResourcePreFetchPayload, ) -from mcpgateway.models import ResourceContent +from mcpgateway.common.models import ResourceContent +from mcpgateway.common.models import Message, PromptResult, Role, TextContent from plugins.external.clamav_server.clamav_plugin import ClamAVRemotePlugin @@ -81,11 +82,11 @@ async def test_prompt_post_fetch_blocks_on_eicar_text(): plugin = _mk_plugin(True) from mcpgateway.plugins.mcp.entities import PromptPosthookPayload - pr = __import__("mcpgateway.models").models.PromptResult( + pr = PromptResult( messages=[ - __import__("mcpgateway.models").models.Message( + Message( role="assistant", - content=__import__("mcpgateway.models").models.TextContent(type="text", text=EICAR), + content=TextContent(type="text", text=EICAR), ) ] ) @@ -122,11 +123,11 @@ async def test_health_stats_counters(): # 2) prompt_post_fetch with EICAR -> attempted +1, infected +1 (total attempted=2, infected=2) from mcpgateway.plugins.mcp.entities import PromptPosthookPayload - pr = __import__("mcpgateway.models").models.PromptResult( + pr = PromptResult( messages=[ - __import__("mcpgateway.models").models.Message( + Message( role="assistant", - content=__import__("mcpgateway.models").models.TextContent(type="text", text=EICAR), + content=TextContent(type="text", text=EICAR), ) ] ) diff --git a/tests/unit/mcpgateway/plugins/plugins/file_type_allowlist/test_file_type_allowlist.py b/tests/unit/mcpgateway/plugins/plugins/file_type_allowlist/test_file_type_allowlist.py index 348af6781..44b2ade84 100644 --- a/tests/unit/mcpgateway/plugins/plugins/file_type_allowlist/test_file_type_allowlist.py +++ b/tests/unit/mcpgateway/plugins/plugins/file_type_allowlist/test_file_type_allowlist.py @@ -20,7 +20,7 @@ ResourcePreFetchPayload, ResourcePostFetchPayload, ) -from mcpgateway.models import ResourceContent +from mcpgateway.common.models import ResourceContent from plugins.file_type_allowlist.file_type_allowlist import FileTypeAllowlistPlugin diff --git a/tests/unit/mcpgateway/plugins/plugins/html_to_markdown/test_html_to_markdown.py b/tests/unit/mcpgateway/plugins/plugins/html_to_markdown/test_html_to_markdown.py index e830ccbbe..33bf9fd75 100644 --- a/tests/unit/mcpgateway/plugins/plugins/html_to_markdown/test_html_to_markdown.py +++ b/tests/unit/mcpgateway/plugins/plugins/html_to_markdown/test_html_to_markdown.py @@ -18,7 +18,7 @@ HookType, ResourcePostFetchPayload, ) -from mcpgateway.models import ResourceContent +from mcpgateway.common.models import ResourceContent from plugins.html_to_markdown.html_to_markdown import HTMLToMarkdownPlugin diff --git a/tests/unit/mcpgateway/plugins/plugins/markdown_cleaner/test_markdown_cleaner.py b/tests/unit/mcpgateway/plugins/plugins/markdown_cleaner/test_markdown_cleaner.py index bb75e68d7..b4db80dfa 100644 --- a/tests/unit/mcpgateway/plugins/plugins/markdown_cleaner/test_markdown_cleaner.py +++ b/tests/unit/mcpgateway/plugins/plugins/markdown_cleaner/test_markdown_cleaner.py @@ -9,7 +9,7 @@ import pytest -from mcpgateway.models import Message, PromptResult, TextContent +from mcpgateway.common.models import Message, PromptResult, TextContent from mcpgateway.plugins.framework.models import ( GlobalContext, PluginConfig, diff --git a/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py b/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py index 3cde9b347..b0ac9890c 100644 --- a/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py +++ b/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py @@ -11,7 +11,7 @@ import pytest # First-Party -from mcpgateway.models import Message, PromptResult, Role, TextContent +from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework import ( GlobalContext, PluginConfig, diff --git a/tests/unit/mcpgateway/plugins/plugins/resource_filter/test_resource_filter.py b/tests/unit/mcpgateway/plugins/plugins/resource_filter/test_resource_filter.py index e8745c96c..a5bac8a43 100644 --- a/tests/unit/mcpgateway/plugins/plugins/resource_filter/test_resource_filter.py +++ b/tests/unit/mcpgateway/plugins/plugins/resource_filter/test_resource_filter.py @@ -11,7 +11,7 @@ import pytest # First-Party -from mcpgateway.models import ResourceContent +from mcpgateway.common.models import ResourceContent from mcpgateway.plugins.framework.models import ( GlobalContext, PluginConfig, diff --git a/tests/unit/mcpgateway/plugins/plugins/virus_total_checker/test_virus_total_checker.py b/tests/unit/mcpgateway/plugins/plugins/virus_total_checker/test_virus_total_checker.py index b0e942085..a12432057 100644 --- a/tests/unit/mcpgateway/plugins/plugins/virus_total_checker/test_virus_total_checker.py +++ b/tests/unit/mcpgateway/plugins/plugins/virus_total_checker/test_virus_total_checker.py @@ -24,7 +24,7 @@ ) from plugins.virus_total_checker.virus_total_checker import VirusTotalURLCheckerPlugin -from mcpgateway.models import Message, PromptResult, TextContent +from mcpgateway.common.models import Message, PromptResult, TextContent class _Resp: @@ -291,7 +291,7 @@ async def test_resource_scan_blocks_on_url(): plugin._client_factory = lambda c, h: _StubClient(routes) # type: ignore os.environ["VT_API_KEY"] = "dummy" - from mcpgateway.models import ResourceContent + from mcpgateway.common.models import ResourceContent rc = ResourceContent(type="resource", id="345",uri="test://x", mime_type="text/plain", text=f"{url} is fishy") from mcpgateway.plugins.mcp.entities import ResourcePostFetchPayload payload = ResourcePostFetchPayload(uri="test://x", content=rc) diff --git a/tests/unit/mcpgateway/services/test_completion_service.py b/tests/unit/mcpgateway/services/test_completion_service.py index e7fe866e2..f46a65d1a 100644 --- a/tests/unit/mcpgateway/services/test_completion_service.py +++ b/tests/unit/mcpgateway/services/test_completion_service.py @@ -9,7 +9,7 @@ import pytest # First-Party -from mcpgateway.models import ( +from mcpgateway.common.models import ( CompleteResult, ) from mcpgateway.services.completion_service import ( diff --git a/tests/unit/mcpgateway/services/test_export_service.py b/tests/unit/mcpgateway/services/test_export_service.py index 209f23e87..15a278f18 100644 --- a/tests/unit/mcpgateway/services/test_export_service.py +++ b/tests/unit/mcpgateway/services/test_export_service.py @@ -15,7 +15,7 @@ import pytest # First-Party -from mcpgateway.models import Root +from mcpgateway.common.models import Root from mcpgateway.schemas import GatewayRead, PromptMetrics, PromptRead, ResourceMetrics, ResourceRead, ServerMetrics, ServerRead, ToolMetrics, ToolRead from mcpgateway.services.export_service import ExportError, ExportService, ExportValidationError from mcpgateway.utils.services_auth import encode_auth @@ -971,7 +971,7 @@ async def test_export_selective_all_entity_types(export_service, mock_db): export_service.resource_service.list_resources.return_value = [sample_resource] # First-Party - from mcpgateway.models import Root + from mcpgateway.common.models import Root mock_roots = [Root(uri="file:///workspace", name="Workspace")] export_service.root_service.list_roots.return_value = mock_roots diff --git a/tests/unit/mcpgateway/services/test_log_storage_service.py b/tests/unit/mcpgateway/services/test_log_storage_service.py index 15c1742be..414e02ebc 100644 --- a/tests/unit/mcpgateway/services/test_log_storage_service.py +++ b/tests/unit/mcpgateway/services/test_log_storage_service.py @@ -16,7 +16,7 @@ import pytest # First-Party -from mcpgateway.models import LogLevel +from mcpgateway.common.models import LogLevel from mcpgateway.services.log_storage_service import LogEntry, LogStorageService diff --git a/tests/unit/mcpgateway/services/test_logging_service.py b/tests/unit/mcpgateway/services/test_logging_service.py index e8ae79b27..933852577 100644 --- a/tests/unit/mcpgateway/services/test_logging_service.py +++ b/tests/unit/mcpgateway/services/test_logging_service.py @@ -26,7 +26,7 @@ import pytest # First-Party -from mcpgateway.models import LogLevel +from mcpgateway.common.models import LogLevel from mcpgateway.services.logging_service import LoggingService # --------------------------------------------------------------------------- diff --git a/tests/unit/mcpgateway/services/test_logging_service_comprehensive.py b/tests/unit/mcpgateway/services/test_logging_service_comprehensive.py index e7cde8217..cbe5d0121 100644 --- a/tests/unit/mcpgateway/services/test_logging_service_comprehensive.py +++ b/tests/unit/mcpgateway/services/test_logging_service_comprehensive.py @@ -17,7 +17,7 @@ import pytest # First-Party -from mcpgateway.models import LogLevel +from mcpgateway.common.models import LogLevel from mcpgateway.services.logging_service import _get_file_handler, _get_text_handler, LoggingService # --------------------------------------------------------------------------- diff --git a/tests/unit/mcpgateway/services/test_prompt_service.py b/tests/unit/mcpgateway/services/test_prompt_service.py index 992b12777..503b98a61 100644 --- a/tests/unit/mcpgateway/services/test_prompt_service.py +++ b/tests/unit/mcpgateway/services/test_prompt_service.py @@ -29,7 +29,7 @@ # First-Party from mcpgateway.db import Prompt as DbPrompt from mcpgateway.db import PromptMetric -from mcpgateway.models import Message, PromptResult, Role, TextContent +from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.schemas import PromptCreate, PromptRead, PromptUpdate from mcpgateway.services.prompt_service import ( diff --git a/tests/unit/mcpgateway/services/test_resource_service_plugins.py b/tests/unit/mcpgateway/services/test_resource_service_plugins.py index f7b9d0e68..bb79c9af4 100644 --- a/tests/unit/mcpgateway/services/test_resource_service_plugins.py +++ b/tests/unit/mcpgateway/services/test_resource_service_plugins.py @@ -16,7 +16,7 @@ from sqlalchemy.orm import Session # First-Party -from mcpgateway.models import ResourceContent +from mcpgateway.common.models import ResourceContent from mcpgateway.services.resource_service import ResourceNotFoundError, ResourceService from mcpgateway.plugins.framework import PluginError, PluginErrorModel, PluginViolation, PluginViolationError diff --git a/tests/unit/mcpgateway/test_discovery.py b/tests/unit/mcpgateway/test_discovery.py index 188360081..398e9f7f4 100644 --- a/tests/unit/mcpgateway/test_discovery.py +++ b/tests/unit/mcpgateway/test_discovery.py @@ -37,7 +37,7 @@ async def discovery(): async def _fake_gateway_info(url: str): # noqa: D401, ANN001 # Return an *empty* capabilities object - structure is unimportant here. # First-Party - from mcpgateway.models import ServerCapabilities + from mcpgateway.common.models import ServerCapabilities return ServerCapabilities() diff --git a/tests/unit/mcpgateway/test_final_coverage_push.py b/tests/unit/mcpgateway/test_final_coverage_push.py index d8ff42ec3..2004f79d2 100644 --- a/tests/unit/mcpgateway/test_final_coverage_push.py +++ b/tests/unit/mcpgateway/test_final_coverage_push.py @@ -16,7 +16,7 @@ import pytest # First-Party -from mcpgateway.models import ImageContent, LogLevel, ResourceContent, Role, TextContent +from mcpgateway.common.models import ImageContent, LogLevel, ResourceContent, Role, TextContent from mcpgateway.schemas import BaseModelWithConfigDict diff --git a/tests/unit/mcpgateway/test_main.py b/tests/unit/mcpgateway/test_main.py index cc2ed736c..045ae1c9e 100644 --- a/tests/unit/mcpgateway/test_main.py +++ b/tests/unit/mcpgateway/test_main.py @@ -24,7 +24,7 @@ # First-Party from mcpgateway.config import settings -from mcpgateway.models import InitializeResult, ResourceContent, ServerCapabilities +from mcpgateway.common.models import InitializeResult, ResourceContent, ServerCapabilities from mcpgateway.schemas import ( PromptRead, ResourceRead, @@ -1034,7 +1034,7 @@ class TestRootEndpoints: def test_list_roots_endpoint(self, mock_list, test_client, auth_headers): """Test listing all registered roots.""" # First-Party - from mcpgateway.models import Root + from mcpgateway.common.models import Root mock_list.return_value = [Root(uri="file:///test", name="Test Root")] # valid URI response = test_client.get("/roots/", headers=auth_headers) @@ -1048,7 +1048,7 @@ def test_list_roots_endpoint(self, mock_list, test_client, auth_headers): def test_add_root_endpoint(self, mock_add, test_client, auth_headers): """Test adding a new root directory.""" # First-Party - from mcpgateway.models import Root + from mcpgateway.common.models import Root mock_add.return_value = Root(uri="file:///test", name="Test Root") # valid URI diff --git a/tests/unit/mcpgateway/test_models.py b/tests/unit/mcpgateway/test_models.py index 10681902b..7e765d1f5 100644 --- a/tests/unit/mcpgateway/test_models.py +++ b/tests/unit/mcpgateway/test_models.py @@ -18,7 +18,7 @@ import pytest # First-Party -from mcpgateway.models import ( +from mcpgateway.common.models import ( ClientCapabilities, CreateMessageResult, ImageContent, diff --git a/tests/unit/mcpgateway/test_rpc_tool_invocation.py b/tests/unit/mcpgateway/test_rpc_tool_invocation.py index 34529820e..b303ed6ae 100644 --- a/tests/unit/mcpgateway/test_rpc_tool_invocation.py +++ b/tests/unit/mcpgateway/test_rpc_tool_invocation.py @@ -17,7 +17,7 @@ # First-Party from mcpgateway.main import app -from mcpgateway.models import Tool +from mcpgateway.common.models import Tool from mcpgateway.services.tool_service import ToolService diff --git a/tests/unit/mcpgateway/test_schemas.py b/tests/unit/mcpgateway/test_schemas.py index 2aef43d7f..4cc18169f 100644 --- a/tests/unit/mcpgateway/test_schemas.py +++ b/tests/unit/mcpgateway/test_schemas.py @@ -20,7 +20,7 @@ import pytest # First-Party -from mcpgateway.models import ( +from mcpgateway.common.models import ( ClientCapabilities, CreateMessageResult, ImageContent, diff --git a/tests/unit/mcpgateway/validation/test_validators.py b/tests/unit/mcpgateway/validation/test_validators.py index ccb574db5..8e81fd39a 100644 --- a/tests/unit/mcpgateway/validation/test_validators.py +++ b/tests/unit/mcpgateway/validation/test_validators.py @@ -15,7 +15,7 @@ import pytest # First-Party -from mcpgateway.validators import SecurityValidator +from mcpgateway.common.validators import SecurityValidator class DummySettings: diff --git a/tests/unit/mcpgateway/validation/test_validators_advanced.py b/tests/unit/mcpgateway/validation/test_validators_advanced.py index 82eaf75f6..6645f522d 100644 --- a/tests/unit/mcpgateway/validation/test_validators_advanced.py +++ b/tests/unit/mcpgateway/validation/test_validators_advanced.py @@ -27,7 +27,7 @@ import pytest # First-Party -from mcpgateway.validators import SecurityValidator +from mcpgateway.common.validators import SecurityValidator class DummySettings: From 16493830899a81227618564a82c5edbf0cd58840 Mon Sep 17 00:00:00 2001 From: Teryl Taylor Date: Thu, 30 Oct 2025 16:15:24 -0600 Subject: [PATCH 05/20] feat: added agent hooks. Signed-off-by: Teryl Taylor --- mcpgateway/plugins/agent/__init__.py | 26 ++ mcpgateway/plugins/agent/base.py | 165 ++++++++ mcpgateway/plugins/agent/models.py | 123 ++++++ mcpgateway/plugins/framework/models.py | 2 +- .../unit/mcpgateway/plugins/agent/__init__.py | 8 + .../plugins/agent/test_agent_plugins.py | 365 ++++++++++++++++++ .../fixtures/configs/agent_context.yaml | 29 ++ .../fixtures/configs/agent_filter.yaml | 34 ++ .../fixtures/configs/agent_passthrough.yaml | 28 ++ .../plugins/fixtures/plugins/agent_test.py | 197 ++++++++++ 10 files changed, 976 insertions(+), 1 deletion(-) create mode 100644 mcpgateway/plugins/agent/__init__.py create mode 100644 mcpgateway/plugins/agent/base.py create mode 100644 mcpgateway/plugins/agent/models.py create mode 100644 tests/unit/mcpgateway/plugins/agent/__init__.py create mode 100644 tests/unit/mcpgateway/plugins/agent/test_agent_plugins.py create mode 100644 tests/unit/mcpgateway/plugins/fixtures/configs/agent_context.yaml create mode 100644 tests/unit/mcpgateway/plugins/fixtures/configs/agent_filter.yaml create mode 100644 tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml create mode 100644 tests/unit/mcpgateway/plugins/fixtures/plugins/agent_test.py diff --git a/mcpgateway/plugins/agent/__init__.py b/mcpgateway/plugins/agent/__init__.py new file mode 100644 index 000000000..576929642 --- /dev/null +++ b/mcpgateway/plugins/agent/__init__.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/agent/__init__.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Agent plugin framework exports. +""" + +from mcpgateway.plugins.agent.base import AgentPlugin +from mcpgateway.plugins.agent.models import ( + AgentHookType, + AgentPreInvokePayload, + AgentPreInvokeResult, + AgentPostInvokePayload, + AgentPostInvokeResult, +) + +__all__ = [ + "AgentPlugin", + "AgentHookType", + "AgentPreInvokePayload", + "AgentPreInvokeResult", + "AgentPostInvokePayload", + "AgentPostInvokeResult", +] diff --git a/mcpgateway/plugins/agent/base.py b/mcpgateway/plugins/agent/base.py new file mode 100644 index 000000000..a59145f31 --- /dev/null +++ b/mcpgateway/plugins/agent/base.py @@ -0,0 +1,165 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/agent/base.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Base plugin for agents. +This module implements the base plugin object for agent hooks. +It supports pre and post hooks for AI safety, security and business processing +for agent invocations: +- agent_pre_invoke: Before sending messages to agent +- agent_post_invoke: After receiving agent response +""" + +# First-Party +from mcpgateway.plugins.agent.models import ( + AgentHookType, + AgentPostInvokePayload, + AgentPostInvokeResult, + AgentPreInvokePayload, + AgentPreInvokeResult, +) +from mcpgateway.plugins.framework.base import Plugin +from mcpgateway.plugins.framework.models import PluginConfig, PluginContext + + +def _register_agent_hooks(): + """Register agent hooks in the global registry. + + This is called lazily to avoid circular import issues. + """ + # Import here to avoid circular dependency at module load time + # First-Party + from mcpgateway.plugins.framework.hook_registry import get_hook_registry # pylint: disable=import-outside-toplevel + + registry = get_hook_registry() + + # Only register if not already registered (idempotent) + if not registry.is_registered(AgentHookType.AGENT_PRE_INVOKE): + registry.register_hook(AgentHookType.AGENT_PRE_INVOKE, AgentPreInvokePayload, AgentPreInvokeResult) + registry.register_hook(AgentHookType.AGENT_POST_INVOKE, AgentPostInvokePayload, AgentPostInvokeResult) + + +class AgentPlugin(Plugin): + """Base agent plugin for pre/post processing of agent invocations. + + Examples: + >>> from mcpgateway.plugins.framework import PluginConfig, PluginMode + >>> from mcpgateway.plugins.agent import AgentHookType + >>> config = PluginConfig( + ... name="test_agent_plugin", + ... description="Test agent plugin", + ... author="test", + ... kind="mcpgateway.plugins.agent.AgentPlugin", + ... version="1.0.0", + ... hooks=[AgentHookType.AGENT_PRE_INVOKE], + ... tags=["test"], + ... mode=PluginMode.ENFORCE, + ... priority=50 + ... ) + >>> plugin = AgentPlugin(config) + >>> plugin.name + 'test_agent_plugin' + >>> plugin.priority + 50 + >>> plugin.mode + + >>> AgentHookType.AGENT_PRE_INVOKE in plugin.hooks + True + """ + + def __init__(self, config: PluginConfig) -> None: + """Initialize an agent plugin with configuration. + + Args: + config: The plugin configuration + + Examples: + >>> from mcpgateway.plugins.framework import PluginConfig + >>> from mcpgateway.plugins.agent import AgentHookType + >>> config = PluginConfig( + ... name="simple_agent_plugin", + ... description="Simple test", + ... author="test", + ... kind="test.AgentPlugin", + ... version="1.0.0", + ... hooks=[AgentHookType.AGENT_POST_INVOKE], + ... tags=["simple"] + ... ) + >>> plugin = AgentPlugin(config) + >>> plugin._config.name + 'simple_agent_plugin' + """ + super().__init__(config) + _register_agent_hooks() + + async def agent_pre_invoke(self, payload: AgentPreInvokePayload, context: PluginContext) -> AgentPreInvokeResult: + """Hook before agent invocation. + + Args: + payload: Agent pre-invoke payload. + context: Plugin execution context. + + Raises: + NotImplementedError: needs to be implemented by sub class. + + Examples: + >>> import asyncio + >>> from mcpgateway.plugins.framework import PluginConfig, GlobalContext, PluginContext + >>> from mcpgateway.plugins.agent import AgentHookType, AgentPreInvokePayload + >>> config = PluginConfig( + ... name="test_plugin", + ... description="Test", + ... author="test", + ... kind="test.Plugin", + ... version="1.0.0", + ... hooks=[AgentHookType.AGENT_PRE_INVOKE] + ... ) + >>> plugin = AgentPlugin(config) + >>> payload = AgentPreInvokePayload(agent_id="agent-123", messages=[]) + >>> ctx = PluginContext(global_context=GlobalContext(request_id="r1")) + >>> result = asyncio.run(plugin.agent_pre_invoke(payload, ctx)) + >>> result.continue_processing + True + """ + raise NotImplementedError( + f"""'agent_pre_invoke' not implemented for plugin {self._config.name} + of plugin type {type(self)} + """ + ) + + async def agent_post_invoke(self, payload: AgentPostInvokePayload, context: PluginContext) -> AgentPostInvokeResult: + """Hook after agent responds. + + Args: + payload: Agent post-invoke payload. + context: Plugin execution context. + + Raises: + NotImplementedError: needs to be implemented by sub class. + + Examples: + >>> import asyncio + >>> from mcpgateway.plugins.framework import PluginConfig, GlobalContext, PluginContext + >>> from mcpgateway.plugins.agent import AgentHookType, AgentPostInvokePayload + >>> config = PluginConfig( + ... name="test_plugin", + ... description="Test", + ... author="test", + ... kind="test.Plugin", + ... version="1.0.0", + ... hooks=[AgentHookType.AGENT_POST_INVOKE] + ... ) + >>> plugin = AgentPlugin(config) + >>> payload = AgentPostInvokePayload(agent_id="agent-123", messages=[]) + >>> ctx = PluginContext(global_context=GlobalContext(request_id="r1")) + >>> result = asyncio.run(plugin.agent_post_invoke(payload, ctx)) + >>> result.continue_processing + True + """ + raise NotImplementedError( + f"""'agent_post_invoke' not implemented for plugin {self._config.name} + of plugin type {type(self)} + """ + ) diff --git a/mcpgateway/plugins/agent/models.py b/mcpgateway/plugins/agent/models.py new file mode 100644 index 000000000..601de3f22 --- /dev/null +++ b/mcpgateway/plugins/agent/models.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/agent/models.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Pydantic models for agent plugins. +This module implements the pydantic models associated with +the base plugin layer including configurations, and contexts. +""" + +# Standard +from enum import Enum +from typing import Any, Dict, List, Optional + +# Third-Party +from pydantic import Field + +# First-Party +from mcpgateway.common.models import Message +from mcpgateway.plugins.framework.models import PluginPayload, PluginResult +from mcpgateway.plugins.mcp.entities.models import HttpHeaderPayload + + +class AgentHookType(str, Enum): + """Agent hook points. + + Attributes: + AGENT_PRE_INVOKE: Before agent invocation. + AGENT_POST_INVOKE: After agent responds. + + Examples: + >>> AgentHookType.AGENT_PRE_INVOKE + + >>> AgentHookType.AGENT_PRE_INVOKE.value + 'agent_pre_invoke' + >>> AgentHookType('agent_post_invoke') + + >>> list(AgentHookType) + [, ] + """ + + AGENT_PRE_INVOKE = "agent_pre_invoke" + AGENT_POST_INVOKE = "agent_post_invoke" + + +class AgentPreInvokePayload(PluginPayload): + """Agent payload for pre-invoke hook. + + Attributes: + agent_id: The agent identifier (can be modified for routing). + messages: Conversation messages (can be filtered/transformed). + tools: Optional list of tools available to agent. + headers: Optional HTTP headers. + model: Optional model override. + system_prompt: Optional system instructions. + parameters: Optional LLM parameters (temperature, max_tokens, etc.). + + Examples: + >>> payload = AgentPreInvokePayload(agent_id="agent-123", messages=[]) + >>> payload.agent_id + 'agent-123' + >>> payload.messages + [] + >>> payload.tools is None + True + >>> from mcpgateway.common.models import Message, Role, TextContent + >>> msg = Message(role=Role.USER, content=TextContent(type="text", text="Hello")) + >>> payload = AgentPreInvokePayload( + ... agent_id="agent-456", + ... messages=[msg], + ... tools=["search", "calculator"], + ... model="claude-3-5-sonnet-20241022" + ... ) + >>> payload.tools + ['search', 'calculator'] + >>> payload.model + 'claude-3-5-sonnet-20241022' + """ + + agent_id: str + messages: List[Message] + tools: Optional[List[str]] = None + headers: Optional[HttpHeaderPayload] = None + model: Optional[str] = None + system_prompt: Optional[str] = None + parameters: Optional[Dict[str, Any]] = Field(default_factory=dict) + + +class AgentPostInvokePayload(PluginPayload): + """Agent payload for post-invoke hook. + + Attributes: + agent_id: The agent identifier. + messages: Response messages from agent (can be filtered/transformed). + tool_calls: Optional tool invocations made by agent. + + Examples: + >>> payload = AgentPostInvokePayload(agent_id="agent-123", messages=[]) + >>> payload.agent_id + 'agent-123' + >>> payload.messages + [] + >>> payload.tool_calls is None + True + >>> from mcpgateway.common.models import Message, Role, TextContent + >>> msg = Message(role=Role.ASSISTANT, content=TextContent(type="text", text="Response")) + >>> payload = AgentPostInvokePayload( + ... agent_id="agent-456", + ... messages=[msg], + ... tool_calls=[{"name": "search", "arguments": {"query": "test"}}] + ... ) + >>> payload.tool_calls + [{'name': 'search', 'arguments': {'query': 'test'}}] + """ + + agent_id: str + messages: List[Message] + tool_calls: Optional[List[Dict[str, Any]]] = None + + +AgentPreInvokeResult = PluginResult[AgentPreInvokePayload] +AgentPostInvokeResult = PluginResult[AgentPostInvokePayload] diff --git a/mcpgateway/plugins/framework/models.py b/mcpgateway/plugins/framework/models.py index ad0de71ef..3e7cb1222 100644 --- a/mcpgateway/plugins/framework/models.py +++ b/mcpgateway/plugins/framework/models.py @@ -687,7 +687,7 @@ class PluginViolation(BaseModel): reason: str description: str code: str - details: dict[str, Any] + details: Optional[dict[str, Any]] = Field(default_factory=dict) _plugin_name: str = PrivateAttr(default="") @property diff --git a/tests/unit/mcpgateway/plugins/agent/__init__.py b/tests/unit/mcpgateway/plugins/agent/__init__.py new file mode 100644 index 000000000..5503bed0d --- /dev/null +++ b/tests/unit/mcpgateway/plugins/agent/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/plugins/agent/__init__.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Unit tests for agent plugin framework. +""" diff --git a/tests/unit/mcpgateway/plugins/agent/test_agent_plugins.py b/tests/unit/mcpgateway/plugins/agent/test_agent_plugins.py new file mode 100644 index 000000000..4a9c67d30 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/agent/test_agent_plugins.py @@ -0,0 +1,365 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/plugins/agent/test_agent_plugins.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Unit tests for agent plugin framework. +""" + +# Third-Party +import pytest + +# First-Party +from mcpgateway.common.models import Message, Role, TextContent +from mcpgateway.plugins.framework import GlobalContext, PluginManager, PluginViolationError +from mcpgateway.plugins.agent import ( + AgentHookType, + AgentPreInvokePayload, + AgentPostInvokePayload, +) + + +@pytest.mark.asyncio +async def test_agent_passthrough_plugin(): + """Test that passthrough agent plugin works correctly.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml") + await manager.initialize() + + # Verify plugin loaded + assert manager.config.plugins[0].name == "PassThroughAgent" + assert manager.config.plugins[0].kind == "tests.unit.mcpgateway.plugins.fixtures.plugins.agent_test.PassThroughAgentPlugin" + assert AgentHookType.AGENT_PRE_INVOKE.value in manager.config.plugins[0].hooks + assert AgentHookType.AGENT_POST_INVOKE.value in manager.config.plugins[0].hooks + + # Create test payload + messages = [ + Message(role=Role.USER, content=TextContent(type="text", text="Hello agent!")) + ] + payload = AgentPreInvokePayload( + agent_id="test-agent", + messages=messages, + tools=["search", "calculator"], + model="claude-3-5-sonnet-20241022" + ) + + # Invoke pre-hook + global_context = GlobalContext(request_id="test-req-1") + result, contexts = await manager.invoke_hook( + AgentHookType.AGENT_PRE_INVOKE, + payload, + global_context=global_context + ) + + # Verify passthrough (no modification) + assert result.continue_processing is True + assert result.modified_payload is None + assert result.violation is None + + # Create response payload + response_messages = [ + Message(role=Role.ASSISTANT, content=TextContent(type="text", text="Hello user!")) + ] + post_payload = AgentPostInvokePayload( + agent_id="test-agent", + messages=response_messages + ) + + # Invoke post-hook + result, _ = await manager.invoke_hook( + AgentHookType.AGENT_POST_INVOKE, + post_payload, + global_context=global_context, + local_contexts=contexts + ) + + # Verify passthrough (no modification) + assert result.continue_processing is True + assert result.modified_payload is None + assert result.violation is None + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_agent_filter_plugin_pre_invoke(): + """Test that filter agent plugin blocks messages with banned words in pre-invoke.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/agent_filter.yaml") + await manager.initialize() + + # Create test payload with clean message + clean_messages = [ + Message(role=Role.USER, content=TextContent(type="text", text="Hello agent!")) + ] + payload = AgentPreInvokePayload( + agent_id="test-agent", + messages=clean_messages + ) + + # Invoke pre-hook with clean message + global_context = GlobalContext(request_id="test-req-2") + result, contexts = await manager.invoke_hook( + AgentHookType.AGENT_PRE_INVOKE, + payload, + global_context=global_context + ) + + # Clean message should pass through + assert result.continue_processing is True + assert result.modified_payload is None + + # Create payload with blocked word + blocked_messages = [ + Message(role=Role.USER, content=TextContent(type="text", text="Click here for spam offers!")) + ] + payload = AgentPreInvokePayload( + agent_id="test-agent", + messages=blocked_messages + ) + + # Invoke pre-hook with blocked message - should raise violation + with pytest.raises(PluginViolationError) as exc_info: + result, contexts = await manager.invoke_hook( + AgentHookType.AGENT_PRE_INVOKE, + payload, + global_context=global_context, + violations_as_exceptions=True + ) + + assert exc_info.value.violation.code == "BLOCKED_CONTENT" + assert "blocked content" in exc_info.value.violation.reason.lower() + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_agent_filter_plugin_post_invoke(): + """Test that filter agent plugin blocks messages with banned words in post-invoke.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/agent_filter.yaml") + await manager.initialize() + + # Create test payload with clean response + clean_messages = [ + Message(role=Role.ASSISTANT, content=TextContent(type="text", text="Here is your answer.")) + ] + payload = AgentPostInvokePayload( + agent_id="test-agent", + messages=clean_messages + ) + + # Invoke post-hook with clean message + global_context = GlobalContext(request_id="test-req-3") + result, _ = await manager.invoke_hook( + AgentHookType.AGENT_POST_INVOKE, + payload, + global_context=global_context + ) + + # Clean message should pass through + assert result.continue_processing is True + assert result.modified_payload is None + + # Create payload with blocked word + blocked_messages = [ + Message(role=Role.ASSISTANT, content=TextContent(type="text", text="This looks like malware to me.")) + ] + payload = AgentPostInvokePayload( + agent_id="test-agent", + messages=blocked_messages + ) + + # Invoke post-hook with blocked message - should raise violation + with pytest.raises(PluginViolationError) as exc_info: + result, _ = await manager.invoke_hook( + AgentHookType.AGENT_POST_INVOKE, + payload, + global_context=global_context, + violations_as_exceptions=True + ) + + assert exc_info.value.violation.code == "BLOCKED_CONTENT" + assert "blocked content" in exc_info.value.violation.reason.lower() + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_agent_filter_plugin_partial_filtering(): + """Test that filter plugin removes only blocked messages, keeps others.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/agent_filter.yaml") + await manager.initialize() + + # Create payload with mixed messages + mixed_messages = [ + Message(role=Role.USER, content=TextContent(type="text", text="Hello agent!")), + Message(role=Role.USER, content=TextContent(type="text", text="Check out this spam!")), + Message(role=Role.USER, content=TextContent(type="text", text="What's the weather?")) + ] + payload = AgentPreInvokePayload( + agent_id="test-agent", + messages=mixed_messages + ) + + # Invoke pre-hook + global_context = GlobalContext(request_id="test-req-4") + result, contexts = await manager.invoke_hook( + AgentHookType.AGENT_PRE_INVOKE, + payload, + global_context=global_context + ) + + # Should have modified payload with only 2 messages + assert result.modified_payload is not None + assert len(result.modified_payload.messages) == 2 + assert result.modified_payload.messages[0].content.text == "Hello agent!" + assert result.modified_payload.messages[1].content.text == "What's the weather?" + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_agent_context_persistence(): + """Test that local context persists between pre and post hooks.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/agent_context.yaml") + await manager.initialize() + + # Create pre-invoke payload + messages = [ + Message(role=Role.USER, content=TextContent(type="text", text="Hello!")) + ] + pre_payload = AgentPreInvokePayload( + agent_id="test-agent-123", + messages=messages + ) + + # Invoke pre-hook + global_context = GlobalContext(request_id="test-req-5") + pre_result, contexts = await manager.invoke_hook( + AgentHookType.AGENT_PRE_INVOKE, + pre_payload, + global_context=global_context + ) + + assert pre_result.continue_processing is True + + # Create post-invoke payload + response_messages = [ + Message(role=Role.ASSISTANT, content=TextContent(type="text", text="Hi there!")) + ] + post_payload = AgentPostInvokePayload( + agent_id="test-agent-123", + messages=response_messages + ) + + # Invoke post-hook with same contexts + post_result, _ = await manager.invoke_hook( + AgentHookType.AGENT_POST_INVOKE, + post_payload, + global_context=global_context, + local_contexts=contexts + ) + + # Verify context was verified (metadata added by post hook) + assert post_result.continue_processing is True + # The metadata should be in the contexts, not the result + # Check that invocation_count was incremented + assert contexts is not None + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_agent_plugin_with_tools(): + """Test agent plugin with tools list.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml") + await manager.initialize() + + # Create payload with tools + messages = [ + Message(role=Role.USER, content=TextContent(type="text", text="Search for Python tutorials")) + ] + payload = AgentPreInvokePayload( + agent_id="test-agent", + messages=messages, + tools=["web_search", "code_search", "calculator"] + ) + + # Invoke pre-hook + global_context = GlobalContext(request_id="test-req-6") + result, contexts = await manager.invoke_hook( + AgentHookType.AGENT_PRE_INVOKE, + payload, + global_context=global_context + ) + + # Verify tools are preserved + assert result.continue_processing is True + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_agent_plugin_with_model_override(): + """Test agent plugin with model override.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml") + await manager.initialize() + + # Create payload with model override + messages = [ + Message(role=Role.USER, content=TextContent(type="text", text="Analyze this code")) + ] + payload = AgentPreInvokePayload( + agent_id="test-agent", + messages=messages, + model="claude-3-opus-20240229", + parameters={"temperature": 0.7, "max_tokens": 1000} + ) + + # Invoke pre-hook + global_context = GlobalContext(request_id="test-req-7") + result, contexts = await manager.invoke_hook( + AgentHookType.AGENT_PRE_INVOKE, + payload, + global_context=global_context + ) + + # Verify model and parameters are preserved + assert result.continue_processing is True + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_agent_plugin_with_tool_calls(): + """Test agent plugin with tool calls in response.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml") + await manager.initialize() + + # Create post-invoke payload with tool calls + messages = [ + Message(role=Role.ASSISTANT, content=TextContent(type="text", text="I'll search for that.")) + ] + tool_calls = [ + { + "name": "web_search", + "arguments": {"query": "Python tutorials", "num_results": 5} + } + ] + payload = AgentPostInvokePayload( + agent_id="test-agent", + messages=messages, + tool_calls=tool_calls + ) + + # Invoke post-hook + global_context = GlobalContext(request_id="test-req-8") + result, _ = await manager.invoke_hook( + AgentHookType.AGENT_POST_INVOKE, + payload, + global_context=global_context + ) + + # Verify tool calls are preserved + assert result.continue_processing is True + + await manager.shutdown() diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/agent_context.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_context.yaml new file mode 100644 index 000000000..74d4328b9 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_context.yaml @@ -0,0 +1,29 @@ +plugins: + - name: ContextTrackingAgent + kind: tests.unit.mcpgateway.plugins.fixtures.plugins.agent_test.ContextTrackingAgentPlugin + description: An agent plugin that tracks state in local context + version: "1.0.0" + author: Test Suite + hooks: + - agent_pre_invoke + - agent_post_invoke + tags: + - test + - agent + - context + mode: enforce + priority: 50 + +# Plugin directories to scan +plugin_dirs: + - "plugins/native" # Built-in plugins + - "plugins/custom" # Custom organization plugins + - "/etc/mcpgateway/plugins" # System-wide plugins + +# Global plugin settings +plugin_settings: + parallel_execution_within_band: true + plugin_timeout: 30 + fail_on_plugin_error: true + enable_plugin_api: true + plugin_health_check_interval: 60 diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/agent_filter.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_filter.yaml new file mode 100644 index 000000000..f5f927d1f --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_filter.yaml @@ -0,0 +1,34 @@ +plugins: + - name: MessageFilterAgent + kind: tests.unit.mcpgateway.plugins.fixtures.plugins.agent_test.MessageFilterAgentPlugin + description: An agent plugin that filters blocked words + version: "1.0.0" + author: Test Suite + hooks: + - agent_pre_invoke + - agent_post_invoke + tags: + - test + - agent + - filter + mode: enforce + priority: 50 + config: + blocked_words: + - spam + - malware + - phishing + +# Plugin directories to scan +plugin_dirs: + - "plugins/native" # Built-in plugins + - "plugins/custom" # Custom organization plugins + - "/etc/mcpgateway/plugins" # System-wide plugins + +# Global plugin settings +plugin_settings: + parallel_execution_within_band: true + plugin_timeout: 30 + fail_on_plugin_error: true + enable_plugin_api: true + plugin_health_check_interval: 60 diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml new file mode 100644 index 000000000..3525dc3cc --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml @@ -0,0 +1,28 @@ +plugins: + - name: PassThroughAgent + kind: tests.unit.mcpgateway.plugins.fixtures.plugins.agent_test.PassThroughAgentPlugin + description: A simple pass-through agent plugin for testing + version: "1.0.0" + author: Test Suite + hooks: + - agent_pre_invoke + - agent_post_invoke + tags: + - test + - agent + mode: enforce + priority: 50 + +# Plugin directories to scan +plugin_dirs: + - "plugins/native" # Built-in plugins + - "plugins/custom" # Custom organization plugins + - "/etc/mcpgateway/plugins" # System-wide plugins + +# Global plugin settings +plugin_settings: + parallel_execution_within_band: true + plugin_timeout: 30 + fail_on_plugin_error: true + enable_plugin_api: true + plugin_health_check_interval: 60 diff --git a/tests/unit/mcpgateway/plugins/fixtures/plugins/agent_test.py b/tests/unit/mcpgateway/plugins/fixtures/plugins/agent_test.py new file mode 100644 index 000000000..20c33bb44 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/plugins/agent_test.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/plugins/fixtures/plugins/agent_test.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Test agent plugins for unit testing. +""" + +# First-Party +from mcpgateway.common.models import Message, Role, TextContent +from mcpgateway.plugins.framework import PluginContext +from mcpgateway.plugins.agent import ( + AgentPlugin, + AgentPreInvokePayload, + AgentPreInvokeResult, + AgentPostInvokePayload, + AgentPostInvokeResult, +) + + +class PassThroughAgentPlugin(AgentPlugin): + """A simple pass-through agent plugin that doesn't modify anything.""" + + async def agent_pre_invoke( + self, payload: AgentPreInvokePayload, context: PluginContext + ) -> AgentPreInvokeResult: + """Pass through without modification. + + Args: + payload: The agent pre-invoke payload. + context: Contextual information about the hook call. + + Returns: + The result allowing processing to continue. + """ + return AgentPreInvokeResult(continue_processing=True) + + async def agent_post_invoke( + self, payload: AgentPostInvokePayload, context: PluginContext + ) -> AgentPostInvokeResult: + """Pass through without modification. + + Args: + payload: The agent post-invoke payload. + context: Contextual information about the hook call. + + Returns: + The result allowing processing to continue. + """ + return AgentPostInvokeResult(continue_processing=True) + + +class MessageFilterAgentPlugin(AgentPlugin): + """An agent plugin that filters messages containing blocked words.""" + + async def agent_pre_invoke( + self, payload: AgentPreInvokePayload, context: PluginContext + ) -> AgentPreInvokeResult: + """Filter messages containing blocked words. + + Args: + payload: The agent pre-invoke payload. + context: Contextual information about the hook call. + + Returns: + The result with filtered messages or violation. + """ + blocked_words = self.config.config.get("blocked_words", []) + + # Filter messages + filtered_messages = [] + for msg in payload.messages: + if isinstance(msg.content, TextContent): + text_lower = msg.content.text.lower() + if any(word in text_lower for word in blocked_words): + # Skip this message + continue + filtered_messages.append(msg) + + # If all messages were blocked, return violation + if not filtered_messages and payload.messages: + from mcpgateway.plugins.framework import PluginViolation + return AgentPreInvokeResult( + continue_processing=False, + violation=PluginViolation( + code="BLOCKED_CONTENT", + reason="All messages contained blocked content", + description="This is a test of content blocking" + ) + ) + + # Return modified payload if messages were filtered + if len(filtered_messages) != len(payload.messages): + modified_payload = AgentPreInvokePayload( + agent_id=payload.agent_id, + messages=filtered_messages, + tools=payload.tools, + headers=payload.headers, + model=payload.model, + system_prompt=payload.system_prompt, + parameters=payload.parameters + ) + return AgentPreInvokeResult(modified_payload=modified_payload) + + return AgentPreInvokeResult(continue_processing=True) + + async def agent_post_invoke( + self, payload: AgentPostInvokePayload, context: PluginContext + ) -> AgentPostInvokeResult: + """Filter response messages containing blocked words. + + Args: + payload: The agent post-invoke payload. + context: Contextual information about the hook call. + + Returns: + The result with filtered messages or violation. + """ + blocked_words = self.config.config.get("blocked_words", []) + + # Filter messages + filtered_messages = [] + for msg in payload.messages: + if isinstance(msg.content, TextContent): + text_lower = msg.content.text.lower() + if any(word in text_lower for word in blocked_words): + # Skip this message + continue + filtered_messages.append(msg) + + # If all messages were blocked, return violation + if not filtered_messages and payload.messages: + from mcpgateway.plugins.framework import PluginViolation + return AgentPostInvokeResult( + continue_processing=False, + violation=PluginViolation( + code="BLOCKED_CONTENT", + reason="All response messages contained blocked content", + description="This is a test of content blocking" + ) + ) + + # Return modified payload if messages were filtered + if len(filtered_messages) != len(payload.messages): + modified_payload = AgentPostInvokePayload( + agent_id=payload.agent_id, + messages=filtered_messages, + tool_calls=payload.tool_calls + ) + return AgentPostInvokeResult(modified_payload=modified_payload) + + return AgentPostInvokeResult(continue_processing=True) + + +class ContextTrackingAgentPlugin(AgentPlugin): + """An agent plugin that tracks state in local context.""" + + async def agent_pre_invoke( + self, payload: AgentPreInvokePayload, context: PluginContext + ) -> AgentPreInvokeResult: + """Track invocation count in local context. + + Args: + payload: The agent pre-invoke payload. + context: Contextual information about the hook call. + + Returns: + The result with updated local context. + """ + # Increment counter in local context + counter = context.metadata.get("invocation_count", 0) + context.metadata["invocation_count"] = counter + 1 + context.metadata["agent_id"] = payload.agent_id + + return AgentPreInvokeResult(continue_processing=True) + + async def agent_post_invoke( + self, payload: AgentPostInvokePayload, context: PluginContext + ) -> AgentPostInvokeResult: + """Verify context persists from pre-invoke. + + Args: + payload: The agent post-invoke payload. + context: Contextual information about the hook call. + + Returns: + The result after verifying context. + """ + # Verify context persisted + counter = context.metadata.get("invocation_count", 0) + agent_id = context.metadata.get("agent_id", "") + + # Add metadata about the context + context.metadata["context_verified"] = counter > 0 and agent_id == payload.agent_id + + return AgentPostInvokeResult(continue_processing=True) From dd191ac64de4322b87077c9da02a049f41870f84 Mon Sep 17 00:00:00 2001 From: Teryl Taylor Date: Fri, 31 Oct 2025 17:51:39 -0600 Subject: [PATCH 06/20] refactor: plugins to support 3 hook patterns Signed-off-by: Teryl Taylor --- .../adr/016-plugin-framework-ai-middleware.md | 2 +- docs/docs/architecture/plugins.md | 4 +- docs/docs/using/plugins/index.md | 10 +- docs/docs/using/plugins/rust-plugins.md | 4 +- llms/plugins-llms.md | 2 +- mcpgateway/plugins/agent/__init__.py | 26 - mcpgateway/plugins/agent/base.py | 165 ------- mcpgateway/plugins/framework/__init__.py | 52 +- mcpgateway/plugins/framework/base.py | 214 ++++++++- mcpgateway/plugins/framework/decorator.py | 174 +++++++ .../plugins/framework/external/mcp/client.py | 2 +- .../plugins/framework/hooks/__init__.py | 9 + .../models.py => framework/hooks/agents.py} | 22 +- mcpgateway/plugins/framework/hooks/http.py | 55 +++ mcpgateway/plugins/framework/hooks/prompts.py | 132 +++++ .../{hook_registry.py => hooks/registry.py} | 4 +- .../plugins/framework/hooks/resources.py | 113 +++++ mcpgateway/plugins/framework/hooks/tools.py | 117 +++++ mcpgateway/plugins/mcp/__init__.py | 8 - mcpgateway/plugins/mcp/entities/__init__.py | 49 -- mcpgateway/plugins/mcp/entities/base.py | 212 --------- mcpgateway/plugins/mcp/entities/models.py | 267 ----------- mcpgateway/services/prompt_service.py | 13 +- mcpgateway/services/resource_service.py | 13 +- mcpgateway/services/tool_service.py | 18 +- .../plugin.py.jinja | 2 +- plugin_templates/native/plugin.py.jinja | 2 +- plugins/README.md | 449 +++++++++++++++--- .../ai_artifacts_normalizer.py | 6 +- plugins/altk_json_processor/json_processor.py | 6 +- .../argument_normalizer.py | 6 +- .../cached_tool_result/cached_tool_result.py | 6 +- plugins/circuit_breaker/circuit_breaker.py | 6 +- .../citation_validator/citation_validator.py | 6 +- plugins/code_formatter/code_formatter.py | 6 +- .../code_safety_linter/code_safety_linter.py | 6 +- .../content_moderation/content_moderation.py | 6 +- plugins/deny_filter/deny.py | 12 +- .../external/clamav_server/clamav_plugin.py | 6 +- .../llmguard/llmguardplugin/plugin.py | 7 +- .../external/opa/opapluginfilter/plugin.py | 6 +- .../file_type_allowlist.py | 6 +- .../harmful_content_detector.py | 6 +- plugins/header_injector/header_injector.py | 6 +- plugins/html_to_markdown/html_to_markdown.py | 6 +- plugins/json_repair/json_repair.py | 6 +- .../license_header_injector.py | 6 +- plugins/markdown_cleaner/markdown_cleaner.py | 6 +- .../output_length_guard.py | 6 +- plugins/pii_filter/pii_filter.py | 6 +- .../privacy_notice_injector.py | 6 +- plugins/rate_limiter/rate_limiter.py | 6 +- plugins/regex_filter/search_replace.py | 6 +- plugins/resource_filter/resource_filter.py | 6 +- .../response_cache_by_prompt.py | 6 +- .../retry_with_backoff/retry_with_backoff.py | 6 +- .../robots_license_guard.py | 6 +- .../safe_html_sanitizer.py | 6 +- plugins/schema_guard/schema_guard.py | 6 +- .../secrets_detection/secrets_detection.py | 6 +- plugins/sql_sanitizer/sql_sanitizer.py | 6 +- plugins/summarizer/summarizer.py | 6 +- .../timezone_translator.py | 6 +- plugins/url_reputation/url_reputation.py | 6 +- plugins/vault/vault_plugin.py | 6 +- .../virus_total_checker.py | 6 +- plugins/watchdog/watchdog.py | 6 +- .../webhook_notification.py | 6 +- plugins_rust/docs/implementation-guide.md | 2 +- .../test_resource_plugin_integration.py | 14 +- .../plugins/agent/test_agent_plugins.py | 4 +- .../fixtures/configs/agent_context.yaml | 2 +- .../fixtures/configs/agent_filter.yaml | 2 +- .../fixtures/configs/agent_passthrough.yaml | 2 +- .../configs/test_hook_patterns_config.yaml | 26 + .../{agent_test.py => agent_plugins.py} | 14 +- .../plugins/fixtures/plugins/context.py | 10 +- .../plugins/fixtures/plugins/error.py | 8 +- .../plugins/fixtures/plugins/headers.py | 8 +- .../plugins/fixtures/plugins/passthrough.py | 8 +- .../plugins/fixtures/plugins/simple.py | 48 ++ .../external/mcp/server/test_runtime.py | 2 - .../external/mcp/test_client_config.py | 18 +- .../external/mcp/test_client_stdio.py | 32 +- .../mcp/test_client_streamable_http.py | 3 +- .../framework/hooks/test_hook_patterns.py | 312 ++++++++++++ .../framework/hooks/test_hook_registry.py | 137 ++++++ .../framework/loader/test_plugin_loader.py | 5 +- .../plugins/framework/test_context.py | 12 +- .../plugins/framework/test_errors.py | 10 +- .../plugins/framework/test_manager.py | 36 +- .../framework/test_manager_extended.py | 120 +++-- .../plugins/framework/test_registry.py | 50 +- .../plugins/framework/test_resource_hooks.py | 70 ++- .../test_json_processor.py | 6 +- .../test_argument_normalizer.py | 7 +- .../test_cached_tool_result.py | 9 +- .../test_code_safety_linter.py | 8 +- .../test_content_moderation.py | 7 +- .../test_content_moderation_integration.py | 17 +- .../external_clamav/test_clamav_remote.py | 14 +- .../test_file_type_allowlist.py | 9 +- .../html_to_markdown/test_html_to_markdown.py | 8 +- .../plugins/json_repair/test_json_repair.py | 9 +- .../markdown_cleaner/test_markdown_cleaner.py | 8 +- .../test_output_length_guard.py | 9 +- .../plugins/pii_filter/test_pii_filter.py | 8 +- .../plugins/rate_limiter/test_rate_limiter.py | 9 +- .../resource_filter/test_resource_filter.py | 8 +- .../plugins/schema_guard/test_schema_guard.py | 8 +- .../url_reputation/test_url_reputation.py | 6 +- .../test_virus_total_checker.py | 38 +- .../test_webhook_integration.py | 14 +- .../test_webhook_notification.py | 11 +- .../services/test_resource_service_plugins.py | 20 +- .../mcpgateway/services/test_tool_service.py | 17 +- 116 files changed, 2219 insertions(+), 1373 deletions(-) delete mode 100644 mcpgateway/plugins/agent/__init__.py delete mode 100644 mcpgateway/plugins/agent/base.py create mode 100644 mcpgateway/plugins/framework/decorator.py create mode 100644 mcpgateway/plugins/framework/hooks/__init__.py rename mcpgateway/plugins/{agent/models.py => framework/hooks/agents.py} (81%) create mode 100644 mcpgateway/plugins/framework/hooks/http.py create mode 100644 mcpgateway/plugins/framework/hooks/prompts.py rename mcpgateway/plugins/framework/{hook_registry.py => hooks/registry.py} (98%) create mode 100644 mcpgateway/plugins/framework/hooks/resources.py create mode 100644 mcpgateway/plugins/framework/hooks/tools.py delete mode 100644 mcpgateway/plugins/mcp/__init__.py delete mode 100644 mcpgateway/plugins/mcp/entities/__init__.py delete mode 100644 mcpgateway/plugins/mcp/entities/base.py delete mode 100644 mcpgateway/plugins/mcp/entities/models.py create mode 100644 tests/unit/mcpgateway/plugins/fixtures/configs/test_hook_patterns_config.yaml rename tests/unit/mcpgateway/plugins/fixtures/plugins/{agent_test.py => agent_plugins.py} (96%) create mode 100644 tests/unit/mcpgateway/plugins/fixtures/plugins/simple.py create mode 100644 tests/unit/mcpgateway/plugins/framework/hooks/test_hook_patterns.py create mode 100644 tests/unit/mcpgateway/plugins/framework/hooks/test_hook_registry.py diff --git a/docs/docs/architecture/adr/016-plugin-framework-ai-middleware.md b/docs/docs/architecture/adr/016-plugin-framework-ai-middleware.md index 5b239c9c7..b5803cd59 100644 --- a/docs/docs/architecture/adr/016-plugin-framework-ai-middleware.md +++ b/docs/docs/architecture/adr/016-plugin-framework-ai-middleware.md @@ -20,7 +20,7 @@ We implemented a comprehensive plugin framework with the following key architect ```python from mcpgateway.plugins.framework import Plugin -class MyInProcessPlugin(MCPPlugin): +class MyInProcessPlugin(Plugin): async def prompt_pre_fetch(self, payload, context): ... # in‑process logic diff --git a/docs/docs/architecture/plugins.md b/docs/docs/architecture/plugins.md index 2f27b2e86..819cbdebf 100644 --- a/docs/docs/architecture/plugins.md +++ b/docs/docs/architecture/plugins.md @@ -1330,7 +1330,7 @@ class PluginSettings(BaseModel): #### PII Filter Plugin (Native) ```python -class PIIFilterPlugin(MCPPlugin): +class PIIFilterPlugin(Plugin): """Detects and masks Personally Identifiable Information""" async def prompt_pre_fetch(self, payload: PromptPrehookPayload, @@ -1367,7 +1367,7 @@ class PIIFilterPlugin(MCPPlugin): #### Resource Filter Plugin (Security) ```python -class ResourceFilterPlugin(MCPPlugin): +class ResourceFilterPlugin(Plugin): """Validates and filters resource requests""" async def resource_pre_fetch(self, payload: ResourcePreFetchPayload, diff --git a/docs/docs/using/plugins/index.md b/docs/docs/using/plugins/index.md index 89e36b7d4..0caf87132 100644 --- a/docs/docs/using/plugins/index.md +++ b/docs/docs/using/plugins/index.md @@ -89,7 +89,7 @@ Decide between a native (in‑process) or external (MCP) plugin: ```python from mcpgateway.plugins.framework import Plugin, PluginConfig, PluginContext, PromptPrehookPayload, PromptPrehookResult -class MyPlugin(MCPPlugin): +class MyPlugin(Plugin): def __init__(self, config: PluginConfig): super().__init__(config) @@ -539,7 +539,7 @@ from mcpgateway.plugins.framework import ( ResourcePostFetchResult ) -class MyPlugin(MCPPlugin): +class MyPlugin(Plugin): """Example plugin implementation.""" def __init__(self, config: PluginConfig): @@ -813,7 +813,7 @@ Metadata for other entities such as prompts and resources will be added in futur ### External Service Plugin Example ```python -class LLMGuardPlugin(MCPPlugin): +class LLMGuardPlugin(Plugin): """Example external service integration.""" def __init__(self, config: PluginConfig): @@ -901,7 +901,7 @@ default_config: # plugins/my_plugin/plugin.py from mcpgateway.plugins.framework import Plugin -class MyPlugin(MCPPlugin): +class MyPlugin(Plugin): # Implementation here pass ``` @@ -963,7 +963,7 @@ Errors inside a plugin should be raised as exceptions. The plugin manager will - Consider async operations for I/O ```python -class CachedPlugin(MCPPlugin): +class CachedPlugin(Plugin): def __init__(self, config): super().__init__(config) self._cache = {} diff --git a/docs/docs/using/plugins/rust-plugins.md b/docs/docs/using/plugins/rust-plugins.md index a99c89735..a10dfd9ce 100644 --- a/docs/docs/using/plugins/rust-plugins.md +++ b/docs/docs/using/plugins/rust-plugins.md @@ -496,7 +496,7 @@ try: except ImportError: RUST_AVAILABLE = False -class MyPlugin(MCPPlugin): +class MyPlugin(Plugin): def __init__(self, config): if RUST_AVAILABLE: self.impl = RustMyPlugin(config) @@ -624,7 +624,7 @@ If you have an existing Python plugin you want to optimize: You don't need to convert entire plugins at once: ```python -class MyPlugin(MCPPlugin): +class MyPlugin(Plugin): def __init__(self, config): # Use Rust for expensive operations if RUST_AVAILABLE: diff --git a/llms/plugins-llms.md b/llms/plugins-llms.md index e31515872..c2a16c353 100644 --- a/llms/plugins-llms.md +++ b/llms/plugins-llms.md @@ -179,7 +179,7 @@ from mcpgateway.plugins.framework import Plugin, PluginConfig, PluginContext from mcpgateway.plugins.framework import PromptPrehookPayload, PromptPrehookResult from mcpgateway.plugins.framework import PluginViolation -class MyGuard(MCPPlugin): +class MyGuard(Plugin): async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: if payload.args and any("forbidden" in v for v in payload.args.values() if isinstance(v, str)): return PromptPrehookResult( diff --git a/mcpgateway/plugins/agent/__init__.py b/mcpgateway/plugins/agent/__init__.py deleted file mode 100644 index 576929642..000000000 --- a/mcpgateway/plugins/agent/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# -*- coding: utf-8 -*- -"""Location: ./mcpgateway/plugins/agent/__init__.py -Copyright 2025 -SPDX-License-Identifier: Apache-2.0 -Authors: Teryl Taylor - -Agent plugin framework exports. -""" - -from mcpgateway.plugins.agent.base import AgentPlugin -from mcpgateway.plugins.agent.models import ( - AgentHookType, - AgentPreInvokePayload, - AgentPreInvokeResult, - AgentPostInvokePayload, - AgentPostInvokeResult, -) - -__all__ = [ - "AgentPlugin", - "AgentHookType", - "AgentPreInvokePayload", - "AgentPreInvokeResult", - "AgentPostInvokePayload", - "AgentPostInvokeResult", -] diff --git a/mcpgateway/plugins/agent/base.py b/mcpgateway/plugins/agent/base.py deleted file mode 100644 index a59145f31..000000000 --- a/mcpgateway/plugins/agent/base.py +++ /dev/null @@ -1,165 +0,0 @@ -# -*- coding: utf-8 -*- -"""Location: ./mcpgateway/plugins/agent/base.py -Copyright 2025 -SPDX-License-Identifier: Apache-2.0 -Authors: Teryl Taylor - -Base plugin for agents. -This module implements the base plugin object for agent hooks. -It supports pre and post hooks for AI safety, security and business processing -for agent invocations: -- agent_pre_invoke: Before sending messages to agent -- agent_post_invoke: After receiving agent response -""" - -# First-Party -from mcpgateway.plugins.agent.models import ( - AgentHookType, - AgentPostInvokePayload, - AgentPostInvokeResult, - AgentPreInvokePayload, - AgentPreInvokeResult, -) -from mcpgateway.plugins.framework.base import Plugin -from mcpgateway.plugins.framework.models import PluginConfig, PluginContext - - -def _register_agent_hooks(): - """Register agent hooks in the global registry. - - This is called lazily to avoid circular import issues. - """ - # Import here to avoid circular dependency at module load time - # First-Party - from mcpgateway.plugins.framework.hook_registry import get_hook_registry # pylint: disable=import-outside-toplevel - - registry = get_hook_registry() - - # Only register if not already registered (idempotent) - if not registry.is_registered(AgentHookType.AGENT_PRE_INVOKE): - registry.register_hook(AgentHookType.AGENT_PRE_INVOKE, AgentPreInvokePayload, AgentPreInvokeResult) - registry.register_hook(AgentHookType.AGENT_POST_INVOKE, AgentPostInvokePayload, AgentPostInvokeResult) - - -class AgentPlugin(Plugin): - """Base agent plugin for pre/post processing of agent invocations. - - Examples: - >>> from mcpgateway.plugins.framework import PluginConfig, PluginMode - >>> from mcpgateway.plugins.agent import AgentHookType - >>> config = PluginConfig( - ... name="test_agent_plugin", - ... description="Test agent plugin", - ... author="test", - ... kind="mcpgateway.plugins.agent.AgentPlugin", - ... version="1.0.0", - ... hooks=[AgentHookType.AGENT_PRE_INVOKE], - ... tags=["test"], - ... mode=PluginMode.ENFORCE, - ... priority=50 - ... ) - >>> plugin = AgentPlugin(config) - >>> plugin.name - 'test_agent_plugin' - >>> plugin.priority - 50 - >>> plugin.mode - - >>> AgentHookType.AGENT_PRE_INVOKE in plugin.hooks - True - """ - - def __init__(self, config: PluginConfig) -> None: - """Initialize an agent plugin with configuration. - - Args: - config: The plugin configuration - - Examples: - >>> from mcpgateway.plugins.framework import PluginConfig - >>> from mcpgateway.plugins.agent import AgentHookType - >>> config = PluginConfig( - ... name="simple_agent_plugin", - ... description="Simple test", - ... author="test", - ... kind="test.AgentPlugin", - ... version="1.0.0", - ... hooks=[AgentHookType.AGENT_POST_INVOKE], - ... tags=["simple"] - ... ) - >>> plugin = AgentPlugin(config) - >>> plugin._config.name - 'simple_agent_plugin' - """ - super().__init__(config) - _register_agent_hooks() - - async def agent_pre_invoke(self, payload: AgentPreInvokePayload, context: PluginContext) -> AgentPreInvokeResult: - """Hook before agent invocation. - - Args: - payload: Agent pre-invoke payload. - context: Plugin execution context. - - Raises: - NotImplementedError: needs to be implemented by sub class. - - Examples: - >>> import asyncio - >>> from mcpgateway.plugins.framework import PluginConfig, GlobalContext, PluginContext - >>> from mcpgateway.plugins.agent import AgentHookType, AgentPreInvokePayload - >>> config = PluginConfig( - ... name="test_plugin", - ... description="Test", - ... author="test", - ... kind="test.Plugin", - ... version="1.0.0", - ... hooks=[AgentHookType.AGENT_PRE_INVOKE] - ... ) - >>> plugin = AgentPlugin(config) - >>> payload = AgentPreInvokePayload(agent_id="agent-123", messages=[]) - >>> ctx = PluginContext(global_context=GlobalContext(request_id="r1")) - >>> result = asyncio.run(plugin.agent_pre_invoke(payload, ctx)) - >>> result.continue_processing - True - """ - raise NotImplementedError( - f"""'agent_pre_invoke' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) - - async def agent_post_invoke(self, payload: AgentPostInvokePayload, context: PluginContext) -> AgentPostInvokeResult: - """Hook after agent responds. - - Args: - payload: Agent post-invoke payload. - context: Plugin execution context. - - Raises: - NotImplementedError: needs to be implemented by sub class. - - Examples: - >>> import asyncio - >>> from mcpgateway.plugins.framework import PluginConfig, GlobalContext, PluginContext - >>> from mcpgateway.plugins.agent import AgentHookType, AgentPostInvokePayload - >>> config = PluginConfig( - ... name="test_plugin", - ... description="Test", - ... author="test", - ... kind="test.Plugin", - ... version="1.0.0", - ... hooks=[AgentHookType.AGENT_POST_INVOKE] - ... ) - >>> plugin = AgentPlugin(config) - >>> payload = AgentPostInvokePayload(agent_id="agent-123", messages=[]) - >>> ctx = PluginContext(global_context=GlobalContext(request_id="r1")) - >>> result = asyncio.run(plugin.agent_post_invoke(payload, ctx)) - >>> result.continue_processing - True - """ - raise NotImplementedError( - f"""'agent_post_invoke' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) diff --git a/mcpgateway/plugins/framework/__init__.py b/mcpgateway/plugins/framework/__init__.py index c170aa35f..ac5e4acb6 100644 --- a/mcpgateway/plugins/framework/__init__.py +++ b/mcpgateway/plugins/framework/__init__.py @@ -17,10 +17,39 @@ from mcpgateway.plugins.framework.base import Plugin from mcpgateway.plugins.framework.errors import PluginError, PluginViolationError from mcpgateway.plugins.framework.external.mcp.server import ExternalPluginServer -from mcpgateway.plugins.framework.hook_registry import HookRegistry, get_hook_registry +from mcpgateway.plugins.framework.hooks.registry import HookRegistry, get_hook_registry from mcpgateway.plugins.framework.loader.config import ConfigLoader from mcpgateway.plugins.framework.loader.plugin import PluginLoader from mcpgateway.plugins.framework.manager import PluginManager +from mcpgateway.plugins.framework.hooks.http import HttpHeaderPayload +from mcpgateway.plugins.framework.hooks.agents import ( + AgentHookType, + AgentPostInvokePayload, + AgentPostInvokeResult, + AgentPreInvokePayload, + AgentPreInvokeResult +) +from mcpgateway.plugins.framework.hooks.resources import ( + ResourceHookType, + ResourcePostFetchPayload, + ResourcePostFetchResult, + ResourcePreFetchPayload, + ResourcePreFetchResult +) +from mcpgateway.plugins.framework.hooks.prompts import ( + PromptHookType, + PromptPosthookPayload, + PromptPosthookResult, + PromptPrehookPayload, + PromptPrehookResult, +) +from mcpgateway.plugins.framework.hooks.tools import ( + ToolHookType, + ToolPostInvokePayload, + ToolPostInvokeResult, + ToolPreInvokeResult, + ToolPreInvokePayload +) from mcpgateway.plugins.framework.models import ( GlobalContext, MCPServerConfig, @@ -35,10 +64,16 @@ ) __all__ = [ + "AgentHookType", + "AgentPostInvokePayload", + "AgentPostInvokeResult", + "AgentPreInvokePayload", + "AgentPreInvokeResult", "ConfigLoader", "ExternalPluginServer", "GlobalContext", "HookRegistry", + "HttpHeaderPayload", "get_hook_registry", "MCPServerConfig", "Plugin", @@ -54,4 +89,19 @@ "PluginResult", "PluginViolation", "PluginViolationError", + "PromptHookType", + "PromptPosthookPayload", + "PromptPosthookResult", + "PromptPrehookPayload", + "PromptPrehookResult", + "ResourceHookType", + "ResourcePostFetchPayload", + "ResourcePostFetchResult", + "ResourcePreFetchPayload", + "ResourcePreFetchResult", + "ToolHookType", + "ToolPostInvokePayload", + "ToolPostInvokeResult", + "ToolPreInvokeResult", + "ToolPreInvokePayload" ] diff --git a/mcpgateway/plugins/framework/base.py b/mcpgateway/plugins/framework/base.py index 3919d5758..759c36687 100644 --- a/mcpgateway/plugins/framework/base.py +++ b/mcpgateway/plugins/framework/base.py @@ -6,17 +6,10 @@ Base plugin implementation. This module implements the base plugin object. -It supports pre and post hooks AI safety, security and business processing -for the following locations in the server: -server_pre_register / server_post_register - for virtual server verification -tool_pre_invoke / tool_post_invoke - for guardrails -prompt_pre_fetch / prompt_post_fetch - for prompt filtering -resource_pre_fetch / resource_post_fetch - for content filtering -auth_pre_check / auth_post_check - for custom auth logic -federation_pre_sync / federation_post_sync - for gateway federation """ # Standard +from abc import ABC from typing import Awaitable, Callable, Optional, Union import uuid @@ -33,7 +26,7 @@ ) -class Plugin: +class Plugin(ABC): """Base plugin object for pre/post processing of inputs and outputs at various locations throughout the server. Examples: @@ -188,7 +181,7 @@ def json_to_payload(self, hook: str, payload: Union[str | dict]) -> PluginPayloa # Fall back to global registry if not hook_payload_type: # First-Party - from mcpgateway.plugins.framework.hook_registry import get_hook_registry # pylint: disable=import-outside-toplevel + from mcpgateway.plugins.framework.hooks.registry import get_hook_registry # pylint: disable=import-outside-toplevel registry = get_hook_registry() hook_payload_type = registry.get_payload_type(hook) @@ -223,7 +216,7 @@ def json_to_result(self, hook: str, result: Union[str | dict]) -> PluginResult: # Fall back to global registry if not hook_result_type: # First-Party - from mcpgateway.plugins.framework.hook_registry import get_hook_registry # pylint: disable=import-outside-toplevel + from mcpgateway.plugins.framework.hooks.registry import get_hook_registry # pylint: disable=import-outside-toplevel registry = get_hook_registry() hook_result_type = registry.get_result_type(hook) @@ -374,15 +367,208 @@ class HookRef: def __init__(self, hook: str, plugin_ref: PluginRef): """Initialize a hook reference point. + Discovers the hook method using either: + 1. Convention-based naming (method name matches hook type) + 2. Decorator-based (@hook decorator with matching hook_type) + Args: - hook: name of the hook point. + hook: name of the hook point (e.g., 'tool_pre_invoke'). plugin_ref: The reference to the plugin to hook. + + Raises: + PluginError: If no method is found for the specified hook. + + Examples: + >>> from mcpgateway.plugins.framework import PluginConfig + >>> config = PluginConfig(name="test", kind="test", version="1.0", author="test", hooks=["tool_pre_invoke"]) + >>> plugin = Plugin(config) + >>> plugin_ref = PluginRef(plugin) + >>> # This would work if plugin has tool_pre_invoke method or @hook("tool_pre_invoke") decorator """ + # Standard + import inspect + + # First-Party + from mcpgateway.plugins.framework.decorator import get_hook_metadata + self._plugin_ref = plugin_ref self._hook = hook - self._func: Callable[[PluginPayload, PluginContext], Awaitable[PluginResult]] = getattr(plugin_ref.plugin, hook) + + # Try convention-based lookup first (method name matches hook type) + self._func: Callable[[PluginPayload, PluginContext], Awaitable[PluginResult]] | None = getattr(plugin_ref.plugin, hook, None) + + # If not found by convention, scan for @hook decorated methods + if self._func is None: + for name, method in inspect.getmembers(plugin_ref.plugin, predicate=inspect.ismethod): + # Skip private/magic methods + if name.startswith("_"): + continue + + # Check for @hook decorator metadata + metadata = get_hook_metadata(method) + if metadata and metadata.hook_type == hook: + self._func = method + break + + # Raise error if hook method not found by either approach if not self._func: - raise PluginError(error=PluginErrorModel(message=f"Plugin: {plugin_ref.plugin.name} has no hook: {hook}", plugin_name=plugin_ref.plugin.name)) + raise PluginError( + error=PluginErrorModel( + message=f"Plugin '{plugin_ref.plugin.name}' has no hook: '{hook}'. " + f"Method must either be named '{hook}' or decorated with @hook('{hook}')", + plugin_name=plugin_ref.plugin.name, + ) + ) + + # Validate hook method signature (parameter count and async) + self._validate_hook_signature(hook, self._func, plugin_ref.plugin.name) + + def _validate_hook_signature(self, hook: str, func: Callable, plugin_name: str) -> None: + """Validate that the hook method has the correct signature. + + Checks: + 1. Method accepts correct number of parameters (self, payload, context) + 2. Method is async (returns coroutine) + + Args: + hook: The hook type being validated + func: The hook method to validate + plugin_name: Name of the plugin (for error messages) + + Raises: + PluginError: If the signature is invalid + """ + # Standard + import inspect + + sig = inspect.signature(func) + params = list(sig.parameters.values()) + + # Check parameter count (should be: payload, context) + # Note: 'self' is not included in bound method signatures + if len(params) != 2: + raise PluginError( + error=PluginErrorModel( + message=f"Plugin '{plugin_name}' hook '{hook}' has invalid signature. " + f"Expected 2 parameters (payload, context), got {len(params)}: {list(sig.parameters.keys())}. " + f"Correct signature: async def {hook}(self, payload: PayloadType, context: PluginContext) -> ResultType", + plugin_name=plugin_name, + ) + ) + + # Check that method is async + if not inspect.iscoroutinefunction(func): + raise PluginError( + error=PluginErrorModel( + message=f"Plugin '{plugin_name}' hook '{hook}' must be async. " + f"Method '{func.__name__}' is not a coroutine function. " + f"Use 'async def {func.__name__}(...)' instead of 'def {func.__name__}(...)'.", + plugin_name=plugin_name, + ) + ) + + # ========== OPTIONAL: Type Hint Validation ========== + # Uncomment to enable strict type checking of payload and return types. + # This validates that type hints match the expected types from the hook registry. + # Pros: Catches type errors at plugin load time instead of runtime + # Cons: Requires all plugins to have type hints, adds validation overhead + # + # self._validate_type_hints(hook, func, params, plugin_name) + + def _validate_type_hints(self, hook: str, func: Callable, params: list, plugin_name: str) -> None: + """Validate that type hints match expected payload and result types. + + This is an optional validation that can be enabled to enforce type safety. + + Args: + hook: The hook type being validated + func: The hook method to validate + params: List of function parameters + plugin_name: Name of the plugin (for error messages) + + Raises: + PluginError: If type hints are missing or don't match expected types + """ + # Standard + from typing import get_type_hints + + # First-Party + from mcpgateway.plugins.framework.hooks.registry import get_hook_registry + + # Get expected types from registry + registry = get_hook_registry() + expected_payload_type = registry.get_payload_type(hook) + expected_result_type = registry.get_result_type(hook) + + # If hook is not registered in global registry, we can't validate types + if not expected_payload_type or not expected_result_type: + return + + # Get type hints from the function + try: + hints = get_type_hints(func) + except Exception as e: + # Type hints might use forward references or unavailable types + # We'll skip validation rather than fail + import logging + + logger = logging.getLogger(__name__) + logger.debug("Could not extract type hints for plugin '%s' hook '%s': %s", plugin_name, hook, e) + return + + # Validate payload parameter type (first parameter, since 'self' is not in params) + payload_param_name = params[0].name + if payload_param_name not in hints: + raise PluginError( + error=PluginErrorModel( + message=f"Plugin '{plugin_name}' hook '{hook}' missing type hint for parameter '{payload_param_name}'. " + f"Expected: {payload_param_name}: {expected_payload_type.__name__}", + plugin_name=plugin_name, + ) + ) + + actual_payload_type = hints[payload_param_name] + + # Check if types match (exact match or subclass) + if actual_payload_type != expected_payload_type: + # Check for generic types or complex type hints + actual_type_str = str(actual_payload_type) + expected_type_str = expected_payload_type.__name__ + + # If the expected type name is in the string representation, it's probably OK + if expected_type_str not in actual_type_str: + raise PluginError( + error=PluginErrorModel( + message=f"Plugin '{plugin_name}' hook '{hook}' parameter '{payload_param_name}' " + f"has incorrect type hint. Expected: {expected_type_str}, Got: {actual_type_str}", + plugin_name=plugin_name, + ) + ) + + # Validate return type + if "return" not in hints: + raise PluginError( + error=PluginErrorModel( + message=f"Plugin '{plugin_name}' hook '{hook}' missing return type hint. " + f"Expected: -> {expected_result_type.__name__}", + plugin_name=plugin_name, + ) + ) + + actual_return_type = hints["return"] + return_type_str = str(actual_return_type) + expected_return_str = expected_result_type.__name__ + + # For async functions, the return type might be wrapped in Coroutine or Awaitable + # We just check if the expected type is mentioned in the return type + if expected_return_str not in return_type_str and actual_return_type != expected_result_type: + raise PluginError( + error=PluginErrorModel( + message=f"Plugin '{plugin_name}' hook '{hook}' has incorrect return type hint. " + f"Expected: {expected_return_str}, Got: {return_type_str}", + plugin_name=plugin_name, + ) + ) @property def plugin_ref(self) -> PluginRef: diff --git a/mcpgateway/plugins/framework/decorator.py b/mcpgateway/plugins/framework/decorator.py new file mode 100644 index 000000000..2bd998618 --- /dev/null +++ b/mcpgateway/plugins/framework/decorator.py @@ -0,0 +1,174 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/framework/decorator.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Hook decorator for dynamically registering plugin hooks. + +This module provides decorators for marking plugin methods as hook handlers. +Plugins can use these decorators to: +1. Override the default hook naming convention +2. Register custom hooks not in the standard framework + +Examples: + Override hook method name:: + + class MyPlugin(Plugin): + @hook(ToolHookType.TOOL_PRE_INVOKE) + def custom_name_for_tool_hook(self, payload, context): + # This gets called for tool_pre_invoke even though + # the method name doesn't match + return ToolPreInvokeResult(continue_processing=True) + + Register a completely new hook type:: + + class MyPlugin(Plugin): + @hook("custom_pre_process", CustomPayload, CustomResult) + def my_custom_hook(self, payload, context): + # This registers a new hook type dynamically + return CustomResult(continue_processing=True) + + Use default convention (no decorator needed):: + + class MyPlugin(Plugin): + def tool_pre_invoke(self, payload, context): + # Automatically recognized by naming convention + return ToolPreInvokeResult(continue_processing=True) +""" + +# Standard +from typing import Callable, Optional, Type, TypeVar + +# Third-Party +from pydantic import BaseModel + +# First-Party +from mcpgateway.plugins.framework.models import PluginPayload, PluginResult + +# Attribute name for storing hook metadata on functions +_HOOK_METADATA_ATTR = "_plugin_hook_metadata" + +# Type vars for type hints +P = TypeVar("P", bound=PluginPayload) # Payload type +R = TypeVar("R", bound=PluginResult) # Result type + + +class HookMetadata: + """Metadata stored on decorated hook methods. + + Attributes: + hook_type: The hook type identifier (e.g., 'tool_pre_invoke') + payload_type: Optional payload class for hook registration + result_type: Optional result class for hook registration + """ + + def __init__( + self, + hook_type: str, + payload_type: Optional[Type[BaseModel]] = None, + result_type: Optional[Type[BaseModel]] = None, + ): + """Initialize hook metadata. + + Args: + hook_type: The hook type identifier + payload_type: Optional payload class for registering new hooks + result_type: Optional result class for registering new hooks + """ + self.hook_type = hook_type + self.payload_type = payload_type + self.result_type = result_type + + +def hook( + hook_type: str, + payload_type: Optional[Type[P]] = None, + result_type: Optional[Type[R]] = None, +) -> Callable[[Callable], Callable]: + """Decorator to mark a method as a plugin hook handler. + + This decorator attaches metadata to a method so the Plugin class can + discover it during initialization and register it with the appropriate + hook type. + + Args: + hook_type: The hook type identifier (e.g., 'tool_pre_invoke') + payload_type: Optional payload class for registering new hook types + result_type: Optional result class for registering new hook types + + Returns: + Decorator function that marks the method with hook metadata + + Examples: + Override method name:: + + @hook(ToolHookType.TOOL_PRE_INVOKE) + def my_custom_method_name(self, payload, context): + return ToolPreInvokeResult(continue_processing=True) + + Register new hook type:: + + @hook("email_pre_send", EmailPayload, EmailResult) + def handle_email(self, payload, context): + return EmailResult(continue_processing=True) + """ + + def decorator(func: Callable) -> Callable: + """Inner decorator that attaches metadata to the function. + + Args: + func: The function to decorate + + Returns: + The same function with metadata attached + """ + # Store metadata on the function object + metadata = HookMetadata(hook_type, payload_type, result_type) + setattr(func, _HOOK_METADATA_ATTR, metadata) + return func + + return decorator + + +def get_hook_metadata(func: Callable) -> Optional[HookMetadata]: + """Get hook metadata from a decorated function. + + Args: + func: The function to check + + Returns: + HookMetadata if the function is decorated, None otherwise + + Examples: + >>> @hook("test_hook") + ... def test_func(): + ... pass + >>> metadata = get_hook_metadata(test_func) + >>> metadata.hook_type + 'test_hook' + >>> get_hook_metadata(lambda: None) is None + True + """ + return getattr(func, _HOOK_METADATA_ATTR, None) + + +def has_hook_metadata(func: Callable) -> bool: + """Check if a function has hook metadata. + + Args: + func: The function to check + + Returns: + True if the function is decorated with @hook, False otherwise + + Examples: + >>> @hook("test_hook") + ... def decorated(): + ... pass + >>> has_hook_metadata(decorated) + True + >>> has_hook_metadata(lambda: None) + False + """ + return hasattr(func, _HOOK_METADATA_ATTR) diff --git a/mcpgateway/plugins/framework/external/mcp/client.py b/mcpgateway/plugins/framework/external/mcp/client.py index 9ebebaa28..0f90b7292 100644 --- a/mcpgateway/plugins/framework/external/mcp/client.py +++ b/mcpgateway/plugins/framework/external/mcp/client.py @@ -43,7 +43,7 @@ ) from mcpgateway.plugins.framework.errors import convert_exception_to_error, PluginError from mcpgateway.plugins.framework.external.mcp.tls_utils import create_ssl_context -from mcpgateway.plugins.framework.hook_registry import get_hook_registry +from mcpgateway.plugins.framework.hooks.registry import get_hook_registry from mcpgateway.plugins.framework.models import ( MCPClientTLSConfig, PluginConfig, diff --git a/mcpgateway/plugins/framework/hooks/__init__.py b/mcpgateway/plugins/framework/hooks/__init__.py new file mode 100644 index 000000000..31153c3b7 --- /dev/null +++ b/mcpgateway/plugins/framework/hooks/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/framework/hooks/__init__.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Plugins hooks package. +Exposes predefined hooks for plugins +""" diff --git a/mcpgateway/plugins/agent/models.py b/mcpgateway/plugins/framework/hooks/agents.py similarity index 81% rename from mcpgateway/plugins/agent/models.py rename to mcpgateway/plugins/framework/hooks/agents.py index 601de3f22..c748aadea 100644 --- a/mcpgateway/plugins/agent/models.py +++ b/mcpgateway/plugins/framework/hooks/agents.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -"""Location: ./mcpgateway/plugins/agent/models.py +"""Location: ./mcpgateway/plugins/models/agents.py Copyright 2025 SPDX-License-Identifier: Apache-2.0 Authors: Teryl Taylor @@ -19,7 +19,7 @@ # First-Party from mcpgateway.common.models import Message from mcpgateway.plugins.framework.models import PluginPayload, PluginResult -from mcpgateway.plugins.mcp.entities.models import HttpHeaderPayload +from mcpgateway.plugins.framework.hooks.http import HttpHeaderPayload class AgentHookType(str, Enum): @@ -121,3 +121,21 @@ class AgentPostInvokePayload(PluginPayload): AgentPreInvokeResult = PluginResult[AgentPreInvokePayload] AgentPostInvokeResult = PluginResult[AgentPostInvokePayload] + +def _register_agent_hooks(): + """Register agent hooks in the global registry. + + This is called lazily to avoid circular import issues. + """ + # Import here to avoid circular dependency at module load time + # First-Party + from mcpgateway.plugins.framework.hooks.registry import get_hook_registry # pylint: disable=import-outside-toplevel + + registry = get_hook_registry() + + # Only register if not already registered (idempotent) + if not registry.is_registered(AgentHookType.AGENT_PRE_INVOKE): + registry.register_hook(AgentHookType.AGENT_PRE_INVOKE, AgentPreInvokePayload, AgentPreInvokeResult) + registry.register_hook(AgentHookType.AGENT_POST_INVOKE, AgentPostInvokePayload, AgentPostInvokeResult) + +_register_agent_hooks() \ No newline at end of file diff --git a/mcpgateway/plugins/framework/hooks/http.py b/mcpgateway/plugins/framework/hooks/http.py new file mode 100644 index 000000000..34513adcc --- /dev/null +++ b/mcpgateway/plugins/framework/hooks/http.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/framework/models/http.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Pydantic models for http hooks and payloads. +""" + +from pydantic import RootModel + +# First-Party +from mcpgateway.plugins.framework.models import PluginPayload, PluginResult + +class HttpHeaderPayload(RootModel[dict[str, str]], PluginPayload): + """An HTTP dictionary of headers used in the pre/post HTTP forwarding hooks.""" + + def __iter__(self): + """Custom iterator function to override root attribute. + + Returns: + A custom iterator for header dictionary. + """ + return iter(self.root) + + def __getitem__(self, item: str) -> str: + """Custom getitem function to override root attribute. + + Args: + item: The http header key. + + Returns: + A custom accesser for the header dictionary. + """ + return self.root[item] + + def __setitem__(self, key: str, value: str) -> None: + """Custom setitem function to override root attribute. + + Args: + key: The http header key. + value: The http header value to be set. + """ + self.root[key] = value + + def __len__(self): + """Custom len function to override root attribute. + + Returns: + The len of the header dictionary. + """ + return len(self.root) + + +HttpHeaderPayloadResult = PluginResult[HttpHeaderPayload] diff --git a/mcpgateway/plugins/framework/hooks/prompts.py b/mcpgateway/plugins/framework/hooks/prompts.py new file mode 100644 index 000000000..faee02c42 --- /dev/null +++ b/mcpgateway/plugins/framework/hooks/prompts.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/hooks/prompts.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Pydantic models for prompt plugins. +This module implements the pydantic models associated with +the base plugin layer including configurations, and contexts. +""" + +# Standard +from enum import Enum +from typing import Optional + +# Third-Party +from pydantic import Field + +# First-Party +from mcpgateway.common.models import PromptResult +from mcpgateway.plugins.framework.models import PluginPayload, PluginResult + + +class PromptHookType(str, Enum): + """MCP Forge Gateway hook points. + + Attributes: + prompt_pre_fetch: The prompt pre hook. + prompt_post_fetch: The prompt post hook. + tool_pre_invoke: The tool pre invoke hook. + tool_post_invoke: The tool post invoke hook. + resource_pre_fetch: The resource pre fetch hook. + resource_post_fetch: The resource post fetch hook. + + Examples: + >>> PromptHookType.PROMPT_PRE_FETCH + + >>> PromptHookType.PROMPT_PRE_FETCH.value + 'prompt_pre_fetch' + >>> PromptHookType('prompt_post_fetch') + + >>> list(PromptHookType) + [, ] + """ + + PROMPT_PRE_FETCH = "prompt_pre_fetch" + PROMPT_POST_FETCH = "prompt_post_fetch" + + +class PromptPrehookPayload(PluginPayload): + """A prompt payload for a prompt prehook. + + Attributes: + prompt_id (str): The ID of the prompt template. + args (dic[str,str]): The prompt template arguments. + + Examples: + >>> payload = PromptPrehookPayload(prompt_id="123", args={"user": "alice"}) + >>> payload.prompt_id + '123' + >>> payload.args + {'user': 'alice'} + >>> payload2 = PromptPrehookPayload(prompt_id="empty") + >>> payload2.args + {} + >>> p = PromptPrehookPayload(prompt_id="123", args={"name": "Bob", "time": "morning"}) + >>> p.prompt_id + '123' + >>> p.args["name"] + 'Bob' + """ + + prompt_id: str + args: Optional[dict[str, str]] = Field(default_factory=dict) + + +class PromptPosthookPayload(PluginPayload): + """A prompt payload for a prompt posthook. + + Attributes: + prompt_id (str): The prompt ID. + result (PromptResult): The prompt after its template is rendered. + + Examples: + >>> from mcpgateway.common.models import PromptResult, Message, TextContent + >>> msg = Message(role="user", content=TextContent(type="text", text="Hello World")) + >>> result = PromptResult(messages=[msg]) + >>> payload = PromptPosthookPayload(prompt_id="123", result=result) + >>> payload.prompt_id + '123' + >>> payload.result.messages[0].content.text + 'Hello World' + >>> from mcpgateway.common.models import PromptResult, Message, TextContent + >>> msg = Message(role="assistant", content=TextContent(type="text", text="Test output")) + >>> r = PromptResult(messages=[msg]) + >>> p = PromptPosthookPayload(prompt_id="123", result=r) + >>> p.prompt_id + '123' + """ + + prompt_id: str + result: PromptResult + + +PromptPrehookResult = PluginResult[PromptPrehookPayload] +PromptPosthookResult = PluginResult[PromptPosthookPayload] + +def _register_prompt_hooks(): + """Register prompt hooks in the global registry. + + This is called lazily to avoid circular import issues. + """ + # Import here to avoid circular dependency at module load time + # First-Party + from mcpgateway.plugins.framework.hooks.registry import get_hook_registry # pylint: disable=import-outside-toplevel + + registry = get_hook_registry() + + # Only register if not already registered (idempotent) + if not registry.is_registered(PromptHookType.PROMPT_PRE_FETCH): + registry.register_hook(PromptHookType.PROMPT_PRE_FETCH, PromptPrehookPayload, PromptPrehookResult) + registry.register_hook(PromptHookType.PROMPT_POST_FETCH, PromptPosthookPayload, PromptPosthookResult) + +_register_prompt_hooks() + + + + + + + + diff --git a/mcpgateway/plugins/framework/hook_registry.py b/mcpgateway/plugins/framework/hooks/registry.py similarity index 98% rename from mcpgateway/plugins/framework/hook_registry.py rename to mcpgateway/plugins/framework/hooks/registry.py index a10008cd7..570b9cb42 100644 --- a/mcpgateway/plugins/framework/hook_registry.py +++ b/mcpgateway/plugins/framework/hooks/registry.py @@ -115,7 +115,7 @@ def json_to_payload(self, hook_type: str, payload: Union[str, dict]) -> PluginPa Examples: >>> registry = HookRegistry() - >>> from mcpgateway.plugins.framework import PluginPayload + >>> from mcpgateway.plugins.framework import PluginPayload, PluginResult >>> registry.register_hook("test", PluginPayload, PluginResult) >>> payload = registry.json_to_payload("test", "{}") """ @@ -142,7 +142,7 @@ def json_to_result(self, hook_type: str, result: Union[str, dict]) -> PluginResu Examples: >>> registry = HookRegistry() - >>> from mcpgateway.plugins.framework import PluginResult + >>> from mcpgateway.plugins.framework import PluginPayload, PluginResult >>> registry.register_hook("test", PluginPayload, PluginResult) >>> result = registry.json_to_result("test", '{"continue_processing": true}') """ diff --git a/mcpgateway/plugins/framework/hooks/resources.py b/mcpgateway/plugins/framework/hooks/resources.py new file mode 100644 index 000000000..8d5c7058b --- /dev/null +++ b/mcpgateway/plugins/framework/hooks/resources.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/framework/hooks/resources.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Pydantic models for resource hooks. +""" + +# Standard +from enum import Enum +from typing import Any, Optional + +# Third-Party +from pydantic import Field + +# First-Party +from mcpgateway.plugins.framework.models import PluginPayload, PluginResult + + +class ResourceHookType(str, Enum): + """MCP Forge Gateway resource hook points. + + Attributes: + resource_pre_fetch: The resource pre fetch hook. + resource_post_fetch: The resource post fetch hook. + + Examples: + >>> ResourceHookType.RESOURCE_PRE_FETCH + + >>> ResourceHookType.RESOURCE_PRE_FETCH.value + 'resource_pre_fetch' + >>> ResourceHookType('resource_post_fetch') + + >>> list(ResourceHookType) + [, ] + """ + + RESOURCE_PRE_FETCH = "resource_pre_fetch" + RESOURCE_POST_FETCH = "resource_post_fetch" + +class ResourcePreFetchPayload(PluginPayload): + """A resource payload for a resource pre-fetch hook. + + Attributes: + uri: The resource URI. + metadata: Optional metadata for the resource request. + + Examples: + >>> payload = ResourcePreFetchPayload(uri="file:///data.txt") + >>> payload.uri + 'file:///data.txt' + >>> payload2 = ResourcePreFetchPayload(uri="http://api/data", metadata={"Accept": "application/json"}) + >>> payload2.metadata + {'Accept': 'application/json'} + >>> p = ResourcePreFetchPayload(uri="file:///docs/readme.md", metadata={"version": "1.0"}) + >>> p.uri + 'file:///docs/readme.md' + >>> p.metadata["version"] + '1.0' + """ + + uri: str + metadata: Optional[dict[str, Any]] = Field(default_factory=dict) + + +class ResourcePostFetchPayload(PluginPayload): + """A resource payload for a resource post-fetch hook. + + Attributes: + uri: The resource URI. + content: The fetched resource content. + + Examples: + >>> from mcpgateway.common.models import ResourceContent + >>> content = ResourceContent(type="resource", id="res-1", uri="file:///data.txt", + ... text="Hello World") + >>> payload = ResourcePostFetchPayload(uri="file:///data.txt", content=content) + >>> payload.uri + 'file:///data.txt' + >>> payload.content.text + 'Hello World' + >>> from mcpgateway.common.models import ResourceContent + >>> resource_content = ResourceContent(type="resource", id="res-2", uri="test://resource", text="Test data") + >>> p = ResourcePostFetchPayload(uri="test://resource", content=resource_content) + >>> p.uri + 'test://resource' + """ + + uri: str + content: Any + + +ResourcePreFetchResult = PluginResult[ResourcePreFetchPayload] +ResourcePostFetchResult = PluginResult[ResourcePostFetchPayload] + +def _register_resource_hooks(): + """Register resource hooks in the global registry. + + This is called lazily to avoid circular import issues. + """ + # Import here to avoid circular dependency at module load time + # First-Party + from mcpgateway.plugins.framework.hooks.registry import get_hook_registry # pylint: disable=import-outside-toplevel + + registry = get_hook_registry() + + # Only register if not already registered (idempotent) + if not registry.is_registered(ResourceHookType.RESOURCE_PRE_FETCH): + registry.register_hook(ResourceHookType.RESOURCE_PRE_FETCH, ResourcePreFetchPayload, ResourcePreFetchResult) + registry.register_hook(ResourceHookType.RESOURCE_POST_FETCH, ResourcePostFetchPayload, ResourcePostFetchResult) + +_register_resource_hooks() diff --git a/mcpgateway/plugins/framework/hooks/tools.py b/mcpgateway/plugins/framework/hooks/tools.py new file mode 100644 index 000000000..16afbae36 --- /dev/null +++ b/mcpgateway/plugins/framework/hooks/tools.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/plugins/framework/hooks/tools.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Pydantic models for tool hooks. +""" + +# Standard +from enum import Enum +from typing import Any, Optional + +# Third-Party +from pydantic import Field + +# First-Party +from mcpgateway.plugins.framework.models import PluginPayload, PluginResult +from mcpgateway.plugins.framework.hooks.http import HttpHeaderPayload + +class ToolHookType(str, Enum): + """MCP Forge Gateway hook points. + + Attributes: + tool_pre_invoke: The tool pre invoke hook. + tool_post_invoke: The tool post invoke hook. + + Examples: + >>> ToolHookType.TOOL_PRE_INVOKE + + >>> ToolHookType.TOOL_PRE_INVOKE.value + 'tool_pre_invoke' + >>> ToolHookType('tool_post_invoke') + + >>> list(ToolHookType) + [, ] + """ + + TOOL_PRE_INVOKE = "tool_pre_invoke" + TOOL_POST_INVOKE = "tool_post_invoke" + + +class ToolPreInvokePayload(PluginPayload): + """A tool payload for a tool pre-invoke hook. + + Args: + name: The tool name. + args: The tool arguments for invocation. + headers: The http pass through headers. + + Examples: + >>> payload = ToolPreInvokePayload(name="test_tool", args={"input": "data"}) + >>> payload.name + 'test_tool' + >>> payload.args + {'input': 'data'} + >>> payload2 = ToolPreInvokePayload(name="empty") + >>> payload2.args + {} + >>> p = ToolPreInvokePayload(name="calculator", args={"operation": "add", "a": 5, "b": 3}) + >>> p.name + 'calculator' + >>> p.args["operation"] + 'add' + + """ + + name: str + args: Optional[dict[str, Any]] = Field(default_factory=dict) + headers: Optional[HttpHeaderPayload] = None + + +class ToolPostInvokePayload(PluginPayload): + """A tool payload for a tool post-invoke hook. + + Args: + name: The tool name. + result: The tool invocation result. + + Examples: + >>> payload = ToolPostInvokePayload(name="calculator", result={"result": 8, "status": "success"}) + >>> payload.name + 'calculator' + >>> payload.result + {'result': 8, 'status': 'success'} + >>> p = ToolPostInvokePayload(name="analyzer", result={"confidence": 0.95, "sentiment": "positive"}) + >>> p.name + 'analyzer' + >>> p.result["confidence"] + 0.95 + """ + + name: str + result: Any + + +ToolPreInvokeResult = PluginResult[ToolPreInvokePayload] +ToolPostInvokeResult = PluginResult[ToolPostInvokePayload] + +def _register_tool_hooks(): + """Register Tool hooks in the global registry. + + This is called lazily to avoid circular import issues. + """ + # Import here to avoid circular dependency at module load time + # First-Party + from mcpgateway.plugins.framework.hooks.registry import get_hook_registry # pylint: disable=import-outside-toplevel + + registry = get_hook_registry() + + # Only register if not already registered (idempotent) + if not registry.is_registered(ToolHookType.TOOL_PRE_INVOKE): + registry.register_hook(ToolHookType.TOOL_PRE_INVOKE, ToolPreInvokePayload, ToolPreInvokeResult) + registry.register_hook(ToolHookType.TOOL_POST_INVOKE, ToolPostInvokePayload, ToolPostInvokeResult) + + +_register_tool_hooks() diff --git a/mcpgateway/plugins/mcp/__init__.py b/mcpgateway/plugins/mcp/__init__.py deleted file mode 100644 index c45913753..000000000 --- a/mcpgateway/plugins/mcp/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# -*- coding: utf-8 -*- -"""Location: ./mcpgateway/plugins/mcp/__init__.py -Copyright 2025 -SPDX-License-Identifier: Apache-2.0 -Authors: Teryl Taylor - -MCP Plugins Package. -""" diff --git a/mcpgateway/plugins/mcp/entities/__init__.py b/mcpgateway/plugins/mcp/entities/__init__.py deleted file mode 100644 index 2e93aa073..000000000 --- a/mcpgateway/plugins/mcp/entities/__init__.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Location: ./mcpgateway/plugins/mcp/entities/__init__.py -Copyright 2025 -SPDX-License-Identifier: Apache-2.0 -Authors: Teryl Taylor - -MCP Plugins Entities Package. -""" - -# First-Party -from mcpgateway.plugins.mcp.entities.models import ( - HttpHeaderPayload, - HttpHeaderPayloadResult, - HookType, - PromptPosthookPayload, - PromptPosthookResult, - PromptPrehookPayload, - PromptPrehookResult, - PromptResult, - ResourcePostFetchPayload, - ResourcePostFetchResult, - ResourcePreFetchPayload, - ResourcePreFetchResult, - ToolPostInvokePayload, - ToolPostInvokeResult, - ToolPreInvokePayload, - ToolPreInvokeResult, -) - -from mcpgateway.plugins.mcp.entities.base import MCPPlugin - -__all__ = [ - "HookType", - "HttpHeaderPayload", - "HttpHeaderPayloadResult", - "MCPPlugin", - "PromptPosthookPayload", - "PromptPosthookResult", - "PromptPrehookPayload", - "PromptPrehookResult", - "PromptResult", - "ResourcePostFetchPayload", - "ResourcePostFetchResult", - "ResourcePreFetchPayload", - "ResourcePreFetchResult", - "ToolPostInvokePayload", - "ToolPostInvokeResult", - "ToolPreInvokePayload", - "ToolPreInvokeResult", -] diff --git a/mcpgateway/plugins/mcp/entities/base.py b/mcpgateway/plugins/mcp/entities/base.py deleted file mode 100644 index ae17704a6..000000000 --- a/mcpgateway/plugins/mcp/entities/base.py +++ /dev/null @@ -1,212 +0,0 @@ -# -*- coding: utf-8 -*- -"""Location: ./mcpgateway/plugins/mcp/entities/base.py -Copyright 2025 -SPDX-License-Identifier: Apache-2.0 -Authors: Teryl Taylor - -Base plugin implementation. -This module implements the base plugin object. -It supports pre and post hooks AI safety, security and business processing -for the following locations in the server: -server_pre_register / server_post_register - for virtual server verification -tool_pre_invoke / tool_post_invoke - for guardrails -prompt_pre_fetch / prompt_post_fetch - for prompt filtering -resource_pre_fetch / resource_post_fetch - for content filtering -auth_pre_check / auth_post_check - for custom auth logic -federation_pre_sync / federation_post_sync - for gateway federation -""" - -# Standard - -# First-Party -from mcpgateway.plugins.framework.base import Plugin -from mcpgateway.plugins.framework.models import PluginConfig, PluginContext -from mcpgateway.plugins.mcp.entities.models import ( - HookType, - PromptPosthookPayload, - PromptPosthookResult, - PromptPrehookPayload, - PromptPrehookResult, - ResourcePostFetchPayload, - ResourcePostFetchResult, - ResourcePreFetchPayload, - ResourcePreFetchResult, - ToolPostInvokePayload, - ToolPostInvokeResult, - ToolPreInvokePayload, - ToolPreInvokeResult, -) - - -def _register_mcp_hooks(): - """Register MCP hooks in the global registry. - - This is called lazily to avoid circular import issues. - """ - # Import here to avoid circular dependency at module load time - # First-Party - from mcpgateway.plugins.framework.hook_registry import get_hook_registry # pylint: disable=import-outside-toplevel - - registry = get_hook_registry() - - # Only register if not already registered (idempotent) - if not registry.is_registered(HookType.PROMPT_PRE_FETCH): - registry.register_hook(HookType.PROMPT_PRE_FETCH, PromptPrehookPayload, PromptPrehookResult) - registry.register_hook(HookType.PROMPT_POST_FETCH, PromptPosthookPayload, PromptPosthookResult) - registry.register_hook(HookType.RESOURCE_PRE_FETCH, ResourcePreFetchPayload, ResourcePreFetchResult) - registry.register_hook(HookType.RESOURCE_POST_FETCH, ResourcePostFetchPayload, ResourcePostFetchResult) - registry.register_hook(HookType.TOOL_PRE_INVOKE, ToolPreInvokePayload, ToolPreInvokeResult) - registry.register_hook(HookType.TOOL_POST_INVOKE, ToolPostInvokePayload, ToolPostInvokeResult) - - -class MCPPlugin(Plugin): - """Base mcp plugin object for pre/post processing of inputs and outputs at various locations throughout the server. - - Examples: - >>> from mcpgateway.plugins.framework import PluginConfig, PluginMode - >>> from mcpgateway.plugins.mcp.entities import HookType - >>> config = PluginConfig( - ... name="test_plugin", - ... description="Test plugin", - ... author="test", - ... kind="mcpgateway.plugins.framework.Plugin", - ... version="1.0.0", - ... hooks=[HookType.PROMPT_PRE_FETCH], - ... tags=["test"], - ... mode=PluginMode.ENFORCE, - ... priority=50 - ... ) - >>> plugin = MCPPlugin(config) - >>> plugin.name - 'test_plugin' - >>> plugin.priority - 50 - >>> plugin.mode - - >>> HookType.PROMPT_PRE_FETCH in plugin.hooks - True - """ - - def __init__(self, config: PluginConfig) -> None: - """Initialize a plugin with a configuration and context. - - Args: - config: The plugin configuration - - Examples: - >>> from mcpgateway.plugins.framework import PluginConfig - >>> from mcpgateway.plugins.mcp.entities import HookType - >>> config = PluginConfig( - ... name="simple_plugin", - ... description="Simple test", - ... author="test", - ... kind="test.Plugin", - ... version="1.0.0", - ... hooks=[HookType.PROMPT_POST_FETCH], - ... tags=["simple"] - ... ) - >>> plugin = MCPPlugin(config) - >>> plugin._config.name - 'simple_plugin' - """ - super().__init__(config) - - async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: - """Plugin hook run before a prompt is retrieved and rendered. - - Args: - payload: The prompt payload to be analyzed. - context: contextual information about the hook call. Including why it was called. - - Raises: - NotImplementedError: needs to be implemented by sub class. - """ - raise NotImplementedError( - f"""'prompt_pre_fetch' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) - - async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult: - """Plugin hook run after a prompt is rendered. - - Args: - payload: The prompt payload to be analyzed. - context: Contextual information about the hook call. - - Raises: - NotImplementedError: needs to be implemented by sub class. - """ - raise NotImplementedError( - f"""'prompt_post_fetch' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) - - async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult: - """Plugin hook run before a tool is invoked. - - Args: - payload: The tool payload to be analyzed. - context: Contextual information about the hook call. - - Raises: - NotImplementedError: needs to be implemented by sub class. - """ - raise NotImplementedError( - f"""'tool_pre_invoke' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) - - async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: PluginContext) -> ToolPostInvokeResult: - """Plugin hook run after a tool is invoked. - - Args: - payload: The tool result payload to be analyzed. - context: Contextual information about the hook call. - - Raises: - NotImplementedError: needs to be implemented by sub class. - """ - raise NotImplementedError( - f"""'tool_post_invoke' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) - - async def resource_pre_fetch(self, payload: ResourcePreFetchPayload, context: PluginContext) -> ResourcePreFetchResult: - """Plugin hook run before a resource is fetched. - - Args: - payload: The resource payload to be analyzed. - context: Contextual information about the hook call. - - Raises: - NotImplementedError: needs to be implemented by sub class. - """ - raise NotImplementedError( - f"""'resource_pre_fetch' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) - - async def resource_post_fetch(self, payload: ResourcePostFetchPayload, context: PluginContext) -> ResourcePostFetchResult: - """Plugin hook run after a resource is fetched. - - Args: - payload: The resource content payload to be analyzed. - context: Contextual information about the hook call. - - Raises: - NotImplementedError: needs to be implemented by sub class. - """ - raise NotImplementedError( - f"""'resource_post_fetch' not implemented for plugin {self._config.name} - of plugin type {type(self)} - """ - ) - - -# Register MCP hooks when this module is imported -_register_mcp_hooks() diff --git a/mcpgateway/plugins/mcp/entities/models.py b/mcpgateway/plugins/mcp/entities/models.py deleted file mode 100644 index ad13e0473..000000000 --- a/mcpgateway/plugins/mcp/entities/models.py +++ /dev/null @@ -1,267 +0,0 @@ -# -*- coding: utf-8 -*- -"""Location: ./mcpgateway/plugins/mcp/entities/models.py -Copyright 2025 -SPDX-License-Identifier: Apache-2.0 -Authors: Teryl Taylor - -Pydantic models for MCP plugins. -This module implements the pydantic models associated with -the base plugin layer including configurations, and contexts. -""" - -# Standard -from enum import Enum -from typing import Any, Optional - -# Third-Party -from pydantic import Field, RootModel - -# First-Party -from mcpgateway.common.models import PromptResult -from mcpgateway.plugins.framework.models import PluginPayload, PluginResult - - -class HookType(str, Enum): - """MCP Forge Gateway hook points. - - Attributes: - prompt_pre_fetch: The prompt pre hook. - prompt_post_fetch: The prompt post hook. - tool_pre_invoke: The tool pre invoke hook. - tool_post_invoke: The tool post invoke hook. - resource_pre_fetch: The resource pre fetch hook. - resource_post_fetch: The resource post fetch hook. - - Examples: - >>> HookType.PROMPT_PRE_FETCH - - >>> HookType.PROMPT_PRE_FETCH.value - 'prompt_pre_fetch' - >>> HookType('prompt_post_fetch') - - >>> list(HookType) # doctest: +ELLIPSIS - [, , , , ...] - """ - - PROMPT_PRE_FETCH = "prompt_pre_fetch" - PROMPT_POST_FETCH = "prompt_post_fetch" - TOOL_PRE_INVOKE = "tool_pre_invoke" - TOOL_POST_INVOKE = "tool_post_invoke" - RESOURCE_PRE_FETCH = "resource_pre_fetch" - RESOURCE_POST_FETCH = "resource_post_fetch" - - -class PromptPrehookPayload(PluginPayload): - """A prompt payload for a prompt prehook. - - Attributes: - prompt_id (str): The ID of the prompt template. - args (dic[str,str]): The prompt template arguments. - - Examples: - >>> payload = PromptPrehookPayload(prompt_id="123", args={"user": "alice"}) - >>> payload.prompt_id - '123' - >>> payload.args - {'user': 'alice'} - >>> payload2 = PromptPrehookPayload(prompt_id="empty") - >>> payload2.args - {} - >>> p = PromptPrehookPayload(prompt_id="123", args={"name": "Bob", "time": "morning"}) - >>> p.prompt_id - '123' - >>> p.args["name"] - 'Bob' - """ - - prompt_id: str - args: Optional[dict[str, str]] = Field(default_factory=dict) - - -class PromptPosthookPayload(PluginPayload): - """A prompt payload for a prompt posthook. - - Attributes: - prompt_id (str): The prompt ID. - result (PromptResult): The prompt after its template is rendered. - - Examples: - >>> from mcpgateway.common.models import PromptResult, Message, TextContent - >>> msg = Message(role="user", content=TextContent(type="text", text="Hello World")) - >>> result = PromptResult(messages=[msg]) - >>> payload = PromptPosthookPayload(prompt_id="123", result=result) - >>> payload.prompt_id - '123' - >>> payload.result.messages[0].content.text - 'Hello World' - >>> from mcpgateway.common.models import PromptResult, Message, TextContent - >>> msg = Message(role="assistant", content=TextContent(type="text", text="Test output")) - >>> r = PromptResult(messages=[msg]) - >>> p = PromptPosthookPayload(prompt_id="123", result=r) - >>> p.prompt_id - '123' - """ - - prompt_id: str - result: PromptResult - - -PromptPrehookResult = PluginResult[PromptPrehookPayload] -PromptPosthookResult = PluginResult[PromptPosthookPayload] - - -class HttpHeaderPayload(RootModel[dict[str, str]]): - """An HTTP dictionary of headers used in the pre/post HTTP forwarding hooks.""" - - def __iter__(self): - """Custom iterator function to override root attribute. - - Returns: - A custom iterator for header dictionary. - """ - return iter(self.root) - - def __getitem__(self, item: str) -> str: - """Custom getitem function to override root attribute. - - Args: - item: The http header key. - - Returns: - A custom accesser for the header dictionary. - """ - return self.root[item] - - def __setitem__(self, key: str, value: str) -> None: - """Custom setitem function to override root attribute. - - Args: - key: The http header key. - value: The http header value to be set. - """ - self.root[key] = value - - def __len__(self): - """Custom len function to override root attribute. - - Returns: - The len of the header dictionary. - """ - return len(self.root) - - -HttpHeaderPayloadResult = PluginResult[HttpHeaderPayload] - - -class ToolPreInvokePayload(PluginPayload): - """A tool payload for a tool pre-invoke hook. - - Args: - name: The tool name. - args: The tool arguments for invocation. - headers: The http pass through headers. - - Examples: - >>> payload = ToolPreInvokePayload(name="test_tool", args={"input": "data"}) - >>> payload.name - 'test_tool' - >>> payload.args - {'input': 'data'} - >>> payload2 = ToolPreInvokePayload(name="empty") - >>> payload2.args - {} - >>> p = ToolPreInvokePayload(name="calculator", args={"operation": "add", "a": 5, "b": 3}) - >>> p.name - 'calculator' - >>> p.args["operation"] - 'add' - - """ - - name: str - args: Optional[dict[str, Any]] = Field(default_factory=dict) - headers: Optional[HttpHeaderPayload] = None - - -class ToolPostInvokePayload(PluginPayload): - """A tool payload for a tool post-invoke hook. - - Args: - name: The tool name. - result: The tool invocation result. - - Examples: - >>> payload = ToolPostInvokePayload(name="calculator", result={"result": 8, "status": "success"}) - >>> payload.name - 'calculator' - >>> payload.result - {'result': 8, 'status': 'success'} - >>> p = ToolPostInvokePayload(name="analyzer", result={"confidence": 0.95, "sentiment": "positive"}) - >>> p.name - 'analyzer' - >>> p.result["confidence"] - 0.95 - """ - - name: str - result: Any - - -ToolPreInvokeResult = PluginResult[ToolPreInvokePayload] -ToolPostInvokeResult = PluginResult[ToolPostInvokePayload] - - -class ResourcePreFetchPayload(PluginPayload): - """A resource payload for a resource pre-fetch hook. - - Attributes: - uri: The resource URI. - metadata: Optional metadata for the resource request. - - Examples: - >>> payload = ResourcePreFetchPayload(uri="file:///data.txt") - >>> payload.uri - 'file:///data.txt' - >>> payload2 = ResourcePreFetchPayload(uri="http://api/data", metadata={"Accept": "application/json"}) - >>> payload2.metadata - {'Accept': 'application/json'} - >>> p = ResourcePreFetchPayload(uri="file:///docs/readme.md", metadata={"version": "1.0"}) - >>> p.uri - 'file:///docs/readme.md' - >>> p.metadata["version"] - '1.0' - """ - - uri: str - metadata: Optional[dict[str, Any]] = Field(default_factory=dict) - - -class ResourcePostFetchPayload(PluginPayload): - """A resource payload for a resource post-fetch hook. - - Attributes: - uri: The resource URI. - content: The fetched resource content. - - Examples: - >>> from mcpgateway.common.models import ResourceContent - >>> content = ResourceContent(type="resource", id="res-1", uri="file:///data.txt", - ... text="Hello World") - >>> payload = ResourcePostFetchPayload(uri="file:///data.txt", content=content) - >>> payload.uri - 'file:///data.txt' - >>> payload.content.text - 'Hello World' - >>> from mcpgateway.common.models import ResourceContent - >>> resource_content = ResourceContent(type="resource", id="res-2", uri="test://resource", text="Test data") - >>> p = ResourcePostFetchPayload(uri="test://resource", content=resource_content) - >>> p.uri - 'test://resource' - """ - - uri: str - content: Any - - -ResourcePreFetchResult = PluginResult[ResourcePreFetchPayload] -ResourcePostFetchResult = PluginResult[ResourcePostFetchPayload] diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index eedc6dec0..30fd601fb 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -36,8 +36,13 @@ from mcpgateway.db import Prompt as DbPrompt from mcpgateway.db import PromptMetric, server_prompt_association from mcpgateway.observability import create_span -from mcpgateway.plugins.framework import GlobalContext, PluginManager -from mcpgateway.plugins.mcp.entities import HookType, PromptPosthookPayload, PromptPrehookPayload +from mcpgateway.plugins.framework import ( + GlobalContext, + PluginManager, + PromptHookType, + PromptPosthookPayload, + PromptPrehookPayload +) from mcpgateway.schemas import PromptCreate, PromptRead, PromptUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService from mcpgateway.utils.metrics_common import build_top_performers @@ -692,7 +697,7 @@ async def get_prompt( request_id = uuid.uuid4().hex global_context = GlobalContext(request_id=request_id, user=user, server_id=server_id, tenant_id=tenant_id) pre_result, context_table = await self._plugin_manager.invoke_hook( - HookType.PROMPT_PRE_FETCH, + PromptHookType.PROMPT_PRE_FETCH, payload=PromptPrehookPayload(prompt_id=str(prompt_id), args=arguments), global_context=global_context, local_contexts=None, @@ -761,7 +766,7 @@ async def get_prompt( if self._plugin_manager: post_result, _ = await self._plugin_manager.invoke_hook( - HookType.PROMPT_POST_FETCH, + PromptHookType.PROMPT_POST_FETCH, payload=PromptPosthookPayload(prompt_id=str(prompt.id), result=result), global_context=global_context, local_contexts=context_table, diff --git a/mcpgateway/services/resource_service.py b/mcpgateway/services/resource_service.py index 664324451..6790a156b 100644 --- a/mcpgateway/services/resource_service.py +++ b/mcpgateway/services/resource_service.py @@ -56,8 +56,13 @@ # Plugin support imports (conditional) try: # First-Party - from mcpgateway.plugins.framework import GlobalContext, PluginManager - from mcpgateway.plugins.mcp.entities import HookType, ResourcePostFetchPayload, ResourcePreFetchPayload + from mcpgateway.plugins.framework import ( + GlobalContext, + PluginManager, + ResourceHookType, + ResourcePostFetchPayload, + ResourcePreFetchPayload + ) PLUGINS_AVAILABLE = True except ImportError: @@ -736,7 +741,7 @@ async def read_resource(self, db: Session, resource_id: Union[int, str], request pre_payload = ResourcePreFetchPayload(uri=uri, metadata={}) # Execute pre-fetch hooks - pre_result, contexts = await self._plugin_manager.invoke_hook(HookType.RESOURCE_PRE_FETCH, pre_payload, global_context, violations_as_exceptions=True) + pre_result, contexts = await self._plugin_manager.invoke_hook(ResourceHookType.RESOURCE_PRE_FETCH, pre_payload, global_context, violations_as_exceptions=True) # Use modified URI if plugin changed it if pre_result.modified_payload: uri = pre_result.modified_payload.uri @@ -767,7 +772,7 @@ async def read_resource(self, db: Session, resource_id: Union[int, str], request # Execute post-fetch hooks post_result, _ = await self._plugin_manager.invoke_hook( - HookType.RESOURCE_POST_FETCH, post_payload, global_context, contexts, violations_as_exceptions=True + ResourceHookType.RESOURCE_POST_FETCH, post_payload, global_context, contexts, violations_as_exceptions=True ) # Pass contexts from pre-fetch # Use modified content if plugin changed it diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 725983579..fd992a4f7 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -49,9 +49,17 @@ from mcpgateway.db import Tool as DbTool from mcpgateway.db import ToolMetric from mcpgateway.observability import create_span -from mcpgateway.plugins.framework import GlobalContext, PluginError, PluginManager, PluginViolationError +from mcpgateway.plugins.framework import ( + GlobalContext, + PluginError, + PluginManager, + PluginViolationError, + ToolHookType, + HttpHeaderPayload, + ToolPostInvokePayload, + ToolPreInvokePayload +) from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA -from mcpgateway.plugins.mcp.entities import HookType, HttpHeaderPayload, ToolPostInvokePayload, ToolPreInvokePayload from mcpgateway.schemas import ToolCreate, ToolRead, ToolUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.oauth_manager import OAuthManager @@ -1004,7 +1012,7 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r tool_metadata = PydanticTool.model_validate(tool) global_context.metadata[TOOL_METADATA] = tool_metadata pre_result, context_table = await self._plugin_manager.invoke_hook( - HookType.TOOL_PRE_INVOKE, + ToolHookType.TOOL_PRE_INVOKE, payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(headers)), global_context=global_context, local_contexts=None, @@ -1156,7 +1164,7 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head gateway_metadata = PydanticGateway.model_validate(tool_gateway) global_context.metadata[GATEWAY_METADATA] = gateway_metadata pre_result, context_table = await self._plugin_manager.invoke_hook( - HookType.TOOL_PRE_INVOKE, + ToolHookType.TOOL_PRE_INVOKE, payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(headers)), global_context=global_context, local_contexts=None, @@ -1186,7 +1194,7 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head # Plugin hook: tool post-invoke if self._plugin_manager: post_result, _ = await self._plugin_manager.invoke_hook( - HookType.TOOL_POST_INVOKE, + ToolHookType.TOOL_POST_INVOKE, payload=ToolPostInvokePayload(name=name, result=tool_result.model_dump(by_alias=True)), global_context=global_context, local_contexts=context_table, diff --git a/plugin_templates/external/{{ plugin_name.lower().replace(' ', '_').replace('-', '_') }}/plugin.py.jinja b/plugin_templates/external/{{ plugin_name.lower().replace(' ', '_').replace('-', '_') }}/plugin.py.jinja index e3a73631b..cdd8f3e80 100644 --- a/plugin_templates/external/{{ plugin_name.lower().replace(' ', '_').replace('-', '_') }}/plugin.py.jinja +++ b/plugin_templates/external/{{ plugin_name.lower().replace(' ', '_').replace('-', '_') }}/plugin.py.jinja @@ -29,7 +29,7 @@ from mcpgateway.plugins.framework import ( {% else -%} {% set class_name = class_parts|join -%} {% endif -%} -class {{ class_name }}(MCPPlugin): +class {{ class_name }}(Plugin): """{{ description }}.""" def __init__(self, config: PluginConfig): diff --git a/plugin_templates/native/plugin.py.jinja b/plugin_templates/native/plugin.py.jinja index e3a73631b..cdd8f3e80 100644 --- a/plugin_templates/native/plugin.py.jinja +++ b/plugin_templates/native/plugin.py.jinja @@ -29,7 +29,7 @@ from mcpgateway.plugins.framework import ( {% else -%} {% set class_name = class_parts|join -%} {% endif -%} -class {{ class_name }}(MCPPlugin): +class {{ class_name }}(Plugin): """{{ description }}.""" def __init__(self, config: PluginConfig): diff --git a/plugins/README.md b/plugins/README.md index 24e981824..7dc19ba41 100644 --- a/plugins/README.md +++ b/plugins/README.md @@ -43,9 +43,11 @@ Plugins can implement hooks at these lifecycle points: | `prompt_pre_fetch` | Before prompt template retrieval | `PromptPrehookPayload` | Input validation, access control | | `prompt_post_fetch` | After prompt template retrieval | `PromptPosthookPayload` | Content filtering, transformation | | `tool_pre_invoke` | Before tool execution | `ToolPreInvokePayload` | Parameter validation, safety checks | -| `tool_post_invoke` | After tool execution | `ToolPostInvokeResult` | Result filtering, audit logging | +| `tool_post_invoke` | After tool execution | `ToolPostInvokePayload` | Result filtering, audit logging | | `resource_pre_fetch` | Before resource retrieval | `ResourcePreFetchPayload` | Protocol/domain validation | -| `resource_post_fetch` | After resource retrieval | `ResourcePostFetchResult` | Content scanning, size limits | +| `resource_post_fetch` | After resource retrieval | `ResourcePostFetchPayload` | Content scanning, size limits | +| `agent_pre_invoke` | Before agent invocation | `AgentPreInvokePayload` | Message filtering, access control | +| `agent_post_invoke` | After agent response | `AgentPostInvokePayload` | Response filtering, audit logging | Future hooks (in development): - `server_pre_register` / `server_post_register` - Virtual server verification @@ -159,80 +161,279 @@ Validate and filter resource requests: ## Writing Custom Plugins -### 1. Plugin Structure +### Understanding the Plugin Base Class -Create a new directory under `plugins/`: +The `Plugin` class is an abstract base class (ABC) that provides the foundation for all plugins. You **must** subclass it and implement at least one hook method to create a functional plugin. -``` -plugins/my_plugin/ -β”œβ”€β”€ __init__.py -β”œβ”€β”€ plugin-manifest.yaml -β”œβ”€β”€ my_plugin.py -└── README.md +```python +from abc import ABC +from mcpgateway.plugins.framework import Plugin + +class MyPlugin(Plugin): + """Your plugin must inherit from Plugin.""" + # Implement hook methods (see patterns below) ``` -### 2. Plugin Manifest (`plugin-manifest.yaml`) +### Three Hook Registration Patterns -```yaml -description: "My custom plugin" -author: "Your Name" -version: "1.0.0" -available_hooks: - - "tool_pre_invoke" - - "tool_post_invoke" -default_configs: - my_setting: true - threshold: 0.8 -``` +The plugin framework supports three flexible patterns for registering hook methods: -### 3. Plugin Implementation +#### Pattern 1: Convention-Based (Recommended for Standard Hooks) + +The simplest approach - just name your method to match the hook type: ```python -# my_plugin.py -from mcpgateway.plugins.framework.base import Plugin -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( + Plugin, + PluginContext, ToolPreInvokePayload, ToolPreInvokeResult, - PluginResult ) -class MyPlugin(MCPPlugin): - """Custom plugin implementation.""" +class MyPlugin(Plugin): + """Convention-based hook - method name matches hook type.""" + + async def tool_pre_invoke( + self, + payload: ToolPreInvokePayload, + context: PluginContext + ) -> ToolPreInvokeResult: + """This hook is automatically discovered by its name.""" + + # Your logic here + modified_args = {**payload.args, "processed": True} + + modified_payload = ToolPreInvokePayload( + name=payload.name, + args=modified_args, + headers=payload.headers + ) + + return ToolPreInvokeResult( + modified_payload=modified_payload, + metadata={"processed_by": self.name} + ) +``` + +**When to use:** Default choice for implementing standard framework hooks. - async def tool_pre_invoke(self, payload: ToolPreInvokePayload) -> ToolPreInvokeResult: - """Process tool invocation before execution.""" +#### Pattern 2: Decorator-Based (Custom Method Names) + +Use the `@hook` decorator to register a hook with a custom method name: + +```python +from mcpgateway.plugins.framework import Plugin, PluginContext +from mcpgateway.plugins.framework.decorator import hook +from mcpgateway.plugins.framework import ( + ToolHookType, + ToolPostInvokePayload, + ToolPostInvokeResult, +) - # Get plugin configuration - my_setting = self.config.get("my_setting", False) - threshold = self.config.get("threshold", 0.5) +class MyPlugin(Plugin): + """Decorator-based hook with custom method name.""" - # Implement your logic - if my_setting and self._should_block(payload): - return ToolPreInvokeResult( - result=PluginResult.BLOCK, - message="Request blocked by custom logic", - modified_payload=payload + @hook(ToolHookType.TOOL_POST_INVOKE) + async def my_custom_handler_name( + self, + payload: ToolPostInvokePayload, + context: PluginContext + ) -> ToolPostInvokeResult: + """Method name doesn't match hook type, but @hook decorator registers it.""" + + # Your logic here + return ToolPostInvokeResult(continue_processing=True) +``` + +**When to use:** When you want descriptive method names that better match your plugin's purpose. + +#### Pattern 3: Custom Hooks (Advanced) + +Register completely new hook types with custom payload and result types: + +```python +from mcpgateway.plugins.framework import Plugin, PluginContext, PluginPayload, PluginResult +from mcpgateway.plugins.framework.decorator import hook + +# Define custom payload type +class EmailPayload(PluginPayload): + recipient: str + subject: str + body: str + +# Define custom result type +class EmailResult(PluginResult[EmailPayload]): + pass + +class MyPlugin(Plugin): + """Custom hook with new hook type.""" + + @hook("email_pre_send", EmailPayload, EmailResult) + async def validate_email( + self, + payload: EmailPayload, + context: PluginContext + ) -> EmailResult: + """Completely new hook type: 'email_pre_send'""" + + # Validate email address + if "@" not in payload.recipient: + # Fix invalid email + modified_payload = EmailPayload( + recipient=f"{payload.recipient}@example.com", + subject=payload.subject, + body=payload.body + ) + return EmailResult( + modified_payload=modified_payload, + metadata={"fixed_email": True} ) - # Modify payload if needed - modified_payload = self._transform_payload(payload) + return EmailResult(continue_processing=True) +``` + +**When to use:** When extending the framework with domain-specific hook points not covered by standard hooks. + +### Hook Method Signature Requirements + +All hook methods must follow these rules: + +1. **Must be async**: All hooks are asynchronous +2. **Three parameters**: `self`, `payload`, `context` +3. **Type hints required** (for validation): Payload and result types must be properly typed +4. **Return appropriate result type**: Each hook returns a `PluginResult` typed with the hook's payload type + +```python +async def hook_name( + self, + payload: PayloadType, # Specific to the hook (e.g., ToolPreInvokePayload) + context: PluginContext # Always PluginContext +) -> PluginResult[PayloadType]: # PluginResult generic, parameterized by the payload type + """Hook implementation.""" + pass +``` + +**Understanding Result Types:** + +Each hook has a corresponding result type that is actually a type alias for `PluginResult[PayloadType]`: + +```python +# These are type aliases defined in the framework +ToolPreInvokeResult = PluginResult[ToolPreInvokePayload] +ToolPostInvokeResult = PluginResult[ToolPostInvokePayload] +PromptPrehookResult = PluginResult[PromptPrehookPayload] +# ... and so on for each hook type +``` + +This means when you return a result, you're returning a `PluginResult` instance that knows about the specific payload type: + +```python +# All of these are valid ways to construct results: +return ToolPreInvokeResult(continue_processing=True) +return ToolPreInvokeResult(modified_payload=new_payload) +return ToolPreInvokeResult( + modified_payload=new_payload, + metadata={"processed": True} +) +``` + +### Complete Plugin Example + +Here's a complete plugin showing all patterns: + +```python +# plugins/my_plugin/my_plugin.py +from mcpgateway.plugins.framework import ( + Plugin, + PluginContext, + PluginPayload, + PluginResult, + ToolPreInvokePayload, + ToolPreInvokeResult, + ToolPostInvokePayload, + ToolPostInvokeResult, + ToolHookType, +) +from mcpgateway.plugins.framework.decorator import hook + +class MyPlugin(Plugin): + """Example plugin demonstrating all three patterns.""" + + # Pattern 1: Convention-based + async def tool_pre_invoke( + self, + payload: ToolPreInvokePayload, + context: PluginContext + ) -> ToolPreInvokeResult: + """Pre-process tool invocation - found by naming convention.""" + + # Access plugin configuration + threshold = self.config.config.get("threshold", 0.5) + + # Modify payload + modified_args = {**payload.args, "plugin_processed": True} + modified_payload = ToolPreInvokePayload( + name=payload.name, + args=modified_args, + headers=payload.headers + ) return ToolPreInvokeResult( - result=PluginResult.CONTINUE, - modified_payload=modified_payload + modified_payload=modified_payload, + metadata={"threshold": threshold} ) - def _should_block(self, payload: ToolPreInvokePayload) -> bool: - """Custom blocking logic.""" - # Implement your validation logic here - return False + # Pattern 2: Decorator with custom name + @hook(ToolHookType.TOOL_POST_INVOKE) + async def process_tool_result( + self, + payload: ToolPostInvokePayload, + context: PluginContext + ) -> ToolPostInvokeResult: + """Post-process tool result - found via decorator.""" + + # Transform result + if isinstance(payload.result, dict): + modified_result = { + **payload.result, + "processed_by": self.name + } + modified_payload = ToolPostInvokePayload( + name=payload.name, + result=modified_result + ) + return ToolPostInvokeResult(modified_payload=modified_payload) + + return ToolPostInvokeResult(continue_processing=True) +``` + +### Plugin Structure + +Create a new directory under `plugins/`: + +``` +plugins/my_plugin/ +β”œβ”€β”€ __init__.py +β”œβ”€β”€ plugin-manifest.yaml +β”œβ”€β”€ my_plugin.py +└── README.md +``` + +### Plugin Manifest (`plugin-manifest.yaml`) - def _transform_payload(self, payload: ToolPreInvokePayload) -> ToolPreInvokePayload: - """Transform payload if needed.""" - return payload +```yaml +description: "My custom plugin" +author: "Your Name" +version: "1.0.0" +available_hooks: + - "tool_pre_invoke" + - "tool_post_invoke" +default_configs: + threshold: 0.8 + enable_logging: true ``` -### 4. Register Your Plugin +### Register Your Plugin Add to `plugins/config.yaml`: @@ -243,34 +444,88 @@ plugins: description: "My custom plugin description" version: "1.0.0" author: "Your Name" - hooks: ["tool_pre_invoke"] + hooks: ["tool_pre_invoke", "tool_post_invoke"] mode: "enforce" priority: 100 config: - my_setting: true threshold: 0.8 + enable_logging: true ``` ## Plugin Development Best Practices +### Hook Results and Control Flow + +Each hook returns a result object that controls execution flow: + +```python +# Allow processing to continue +return ToolPreInvokeResult(continue_processing=True) + +# Modify the payload +return ToolPreInvokeResult( + modified_payload=modified_payload, + metadata={"processed": True} +) + +# Block execution with a violation +from mcpgateway.plugins.framework import PluginViolation + +return ToolPreInvokeResult( + continue_processing=False, + violation=PluginViolation( + code="POLICY_VIOLATION", + reason="Request blocked by security policy", + description="Detected prohibited content" + ) +) +``` + ### Error Handling -Errors inside a plugin should be raised as exceptions. The plugin manager will catch the error, and its behavior depends on both the gateway's and plugin's configuration as follows: +Errors inside a plugin should be raised as exceptions. The plugin manager will catch the error, and its behavior depends on both the gateway's and plugin's configuration as follows: + +1. If `plugin_settings.fail_on_plugin_error` in the plugin `config.yaml` is set to `true`, the exception is bubbled up as a PluginError and the error is passed to the client of ContextForge regardless of the plugin mode. +2. If `plugin_settings.fail_on_plugin_error` is set to false, the error is handled based off of the plugin mode in the plugin's config as follows: + * If `mode` is `enforce`, both violations and errors are bubbled up as exceptions and the execution is blocked. + * If `mode` is `enforce_ignore_error`, violations are bubbled up as exceptions and execution is blocked, but errors are logged and execution continues. + * If `mode` is `permissive`, execution is allowed to proceed whether there are errors or violations. Both are logged. + +### Accessing Plugin Context + +The `context` parameter provides access to request-scoped and global state: -1. if `plugin_settings.fail_on_plugin_error` in the plugin `config.yaml` is set to `true` the exception is bubbled up as a PluginError and the error is passed to the client of ContextForge regardless of the plugin mode. -2. if `plugin_settings.fail_on_plugin_error` is set to false the error is handled based off of the plugin mode in the plugin's config as follows: - * if `mode` is `enforce`, both violations and errors are bubbled up as exceptions and the execution is blocked. - * if `mode` is `enforce_ignore_error`, violations are bubbled up as exceptions and execution is blocked, but errors are logged and execution continues. - * if `mode` is `permissive`, execution is allowed to proceed whether there are errors or violations. Both are logged. +```python +async def tool_pre_invoke( + self, + payload: ToolPreInvokePayload, + context: PluginContext +) -> ToolPreInvokeResult: + # Access request ID + request_id = context.global_context.request_id + + # Access user information + user = context.global_context.user + tenant_id = context.global_context.tenant_id + + # Store plugin-specific state (persists across pre/post hooks) + context.state["invocation_count"] = context.state.get("invocation_count", 0) + 1 + + # Add metadata + context.metadata["processing_time"] = 0.123 + + return ToolPreInvokeResult(continue_processing=True) +``` ### Logging and Monitoring + ```python def __init__(self, config: PluginConfig): super().__init__(config) self.logger.info(f"Initialized {self.name} v{self.version}") -async def tool_pre_invoke(self, payload: ToolPreInvokePayload) -> ToolPreInvokeResult: - self.logger.debug(f"Processing tool: {payload.tool_name}") +async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult: + self.logger.debug(f"Processing tool: {payload.name}") # ... plugin logic self.metrics.increment("requests_processed") ``` @@ -278,14 +533,19 @@ async def tool_pre_invoke(self, payload: ToolPreInvokePayload) -> ToolPreInvokeR ### Configuration Validation ```python -def validate_config(self) -> None: +def __init__(self, config: PluginConfig): + super().__init__(config) + self._validate_config() + +def _validate_config(self) -> None: """Validate plugin configuration.""" required_keys = ["threshold", "api_key"] for key in required_keys: - if key not in self.config: + if key not in self.config.config: raise ValueError(f"Missing required config key: {key}") - if not 0 <= self.config["threshold"] <= 1: + threshold = self.config.config.get("threshold") + if not 0 <= threshold <= 1: raise ValueError("threshold must be between 0 and 1") ``` @@ -299,16 +559,17 @@ def validate_config(self) -> None: ### Resource Management ```python -class MyPlugin(MCPPlugin): +class MyPlugin(Plugin): def __init__(self, config: PluginConfig): super().__init__(config) self._session = None - async def __aenter__(self): + async def initialize(self): + """Called when plugin is loaded.""" self._session = aiohttp.ClientSession() - return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def shutdown(self): + """Called when plugin manager shuts down.""" if self._session: await self._session.close() ``` @@ -316,30 +577,48 @@ class MyPlugin(MCPPlugin): ## Testing Plugins ### Unit Testing + ```python import pytest -from mcpgateway.plugins.framework.models import ToolPreInvokePayload, PluginConfig +from mcpgateway.plugins.framework import ( + PluginConfig, + PluginContext, + GlobalContext, + ToolPreInvokePayload, +) from plugins.my_plugin.my_plugin import MyPlugin @pytest.fixture def plugin(): config = PluginConfig( name="test_plugin", - config={"my_setting": True} + description="Test", + version="1.0", + author="Test", + kind="plugins.my_plugin.my_plugin.MyPlugin", + hooks=["tool_pre_invoke"], + config={"threshold": 0.8} ) return MyPlugin(config) +@pytest.mark.asyncio async def test_tool_pre_invoke(plugin): payload = ToolPreInvokePayload( - tool_name="test_tool", - arguments={"arg1": "value1"} + name="test_tool", + args={"arg1": "value1"} + ) + context = PluginContext( + global_context=GlobalContext(request_id="test-123") ) - result = await plugin.tool_pre_invoke(payload) - assert result.result == PluginResult.CONTINUE + result = await plugin.tool_pre_invoke(payload, context) + + assert result.continue_processing is True + assert result.modified_payload.args["plugin_processed"] is True ``` ### Integration Testing + ```bash # Test with live gateway make dev @@ -356,20 +635,39 @@ curl -X POST http://localhost:4444/tools/invoke \ 2. **Configuration errors**: Validate YAML syntax and required fields 3. **Performance issues**: Profile plugin execution time and optimize bottlenecks 4. **Hook not triggering**: Verify hook name matches available hooks in manifest +5. **Method signature errors**: Ensure hooks have correct parameters (self, payload, context) and are async ### Debug Mode + ```bash LOG_LEVEL=DEBUG make serve # port 4444 # Or with reloading dev server: LOG_LEVEL=DEBUG make dev # port 8000 ``` +### Testing Hook Discovery + +To verify your hooks are properly registered: + +```python +from mcpgateway.plugins.framework import PluginManager + +manager = PluginManager("path/to/config.yaml") +await manager.initialize() + +# Check loaded plugins +for plugin_config in manager.config.plugins: + print(f"Plugin: {plugin_config.name}") + print(f" Hooks: {plugin_config.hooks}") +``` + ## Documentation Links - **Plugin Usage Guide**: https://ibm.github.io/mcp-context-forge/using/plugins/ - **Plugin Lifecycle**: https://ibm.github.io/mcp-context-forge/using/plugins/lifecycle/ - **API Reference**: Generated from code docstrings - **Examples**: See `plugins/` directory for complete implementations +- **Hook Patterns Test**: `tests/unit/mcpgateway/plugins/framework/hooks/test_hook_patterns.py` ## Performance Metrics @@ -387,3 +685,4 @@ The framework supports high-performance operations: - Error isolation between plugins - Comprehensive audit logging - Plugin configuration validation +- Hook signature validation at plugin load time diff --git a/plugins/ai_artifacts_normalizer/ai_artifacts_normalizer.py b/plugins/ai_artifacts_normalizer/ai_artifacts_normalizer.py index 42fadbdf3..923bb1ce0 100644 --- a/plugins/ai_artifacts_normalizer/ai_artifacts_normalizer.py +++ b/plugins/ai_artifacts_normalizer/ai_artifacts_normalizer.py @@ -21,9 +21,7 @@ from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, PromptPrehookPayload, PromptPrehookResult, ResourcePostFetchPayload, @@ -106,7 +104,7 @@ def _normalize_text(text: str, cfg: AINormalizerConfig) -> str: return out -class AIArtifactsNormalizerPlugin(MCPPlugin): +class AIArtifactsNormalizerPlugin(Plugin): """Plugin to normalize AI-generated text artifacts in prompts, resources, and tool results.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/altk_json_processor/json_processor.py b/plugins/altk_json_processor/json_processor.py index b1664b49d..df26cedd8 100644 --- a/plugins/altk_json_processor/json_processor.py +++ b/plugins/altk_json_processor/json_processor.py @@ -25,9 +25,7 @@ from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ) @@ -38,7 +36,7 @@ logger = logging_service.get_logger(__name__) -class ALTKJsonProcessor(MCPPlugin): +class ALTKJsonProcessor(Plugin): """Uses JSON Processor from ALTK to extract data from long JSON responses.""" def __init__(self, config: PluginConfig): diff --git a/plugins/argument_normalizer/argument_normalizer.py b/plugins/argument_normalizer/argument_normalizer.py index 8a98057c9..8e847a7a4 100644 --- a/plugins/argument_normalizer/argument_normalizer.py +++ b/plugins/argument_normalizer/argument_normalizer.py @@ -29,9 +29,7 @@ from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, PromptPrehookPayload, PromptPrehookResult, ToolPreInvokePayload, @@ -517,7 +515,7 @@ def _normalize_value(value: Any, base_cfg: ArgumentNormalizerConfig, path: str, return value -class ArgumentNormalizerPlugin(MCPPlugin): +class ArgumentNormalizerPlugin(Plugin): """Argument Normalizer plugin for prompts and tools.""" def __init__(self, config: PluginConfig): diff --git a/plugins/cached_tool_result/cached_tool_result.py b/plugins/cached_tool_result/cached_tool_result.py index 6d3674e19..cce7558b4 100644 --- a/plugins/cached_tool_result/cached_tool_result.py +++ b/plugins/cached_tool_result/cached_tool_result.py @@ -27,9 +27,7 @@ from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokePayload, @@ -88,7 +86,7 @@ def _make_key(tool: str, args: dict | None, fields: Optional[List[str]]) -> str: return hashlib.sha256(raw.encode("utf-8")).hexdigest() -class CachedToolResultPlugin(MCPPlugin): +class CachedToolResultPlugin(Plugin): """Cache idempotent tool results (write-through).""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/circuit_breaker/circuit_breaker.py b/plugins/circuit_breaker/circuit_breaker.py index 61def4820..f9e5de429 100644 --- a/plugins/circuit_breaker/circuit_breaker.py +++ b/plugins/circuit_breaker/circuit_breaker.py @@ -29,9 +29,7 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokePayload, @@ -140,7 +138,7 @@ def _is_error(result: Any) -> bool: return False -class CircuitBreakerPlugin(MCPPlugin): +class CircuitBreakerPlugin(Plugin): """Circuit breaker plugin to prevent cascading failures by tripping on high error rates.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/citation_validator/citation_validator.py b/plugins/citation_validator/citation_validator.py index 65c2bf1c4..44fdd4e80 100644 --- a/plugins/citation_validator/citation_validator.py +++ b/plugins/citation_validator/citation_validator.py @@ -27,9 +27,7 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ToolPostInvokePayload, @@ -118,7 +116,7 @@ def _extract_links(text: str, limit: int) -> List[str]: return out -class CitationValidatorPlugin(MCPPlugin): +class CitationValidatorPlugin(Plugin): """Validates citations by checking URL reachability and content.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/code_formatter/code_formatter.py b/plugins/code_formatter/code_formatter.py index c62cdf2da..47d3c2d09 100644 --- a/plugins/code_formatter/code_formatter.py +++ b/plugins/code_formatter/code_formatter.py @@ -30,9 +30,7 @@ from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ToolPostInvokePayload, @@ -147,7 +145,7 @@ def _format_by_language(result: Any, cfg: CodeFormatterConfig, language: str | N return _normalize_text(text, cfg) -class CodeFormatterPlugin(MCPPlugin): +class CodeFormatterPlugin(Plugin): """Lightweight formatter for post-invoke and resource content.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/code_safety_linter/code_safety_linter.py b/plugins/code_safety_linter/code_safety_linter.py index a886fda8c..c4c17768e 100644 --- a/plugins/code_safety_linter/code_safety_linter.py +++ b/plugins/code_safety_linter/code_safety_linter.py @@ -24,9 +24,7 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ) @@ -50,7 +48,7 @@ class CodeSafetyConfig(BaseModel): ) -class CodeSafetyLinterPlugin(MCPPlugin): +class CodeSafetyLinterPlugin(Plugin): """Scan text outputs for dangerous code patterns.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/content_moderation/content_moderation.py b/plugins/content_moderation/content_moderation.py index 2a3a9e75a..50182e971 100644 --- a/plugins/content_moderation/content_moderation.py +++ b/plugins/content_moderation/content_moderation.py @@ -27,9 +27,7 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, PromptPrehookPayload, PromptPrehookResult, ToolPostInvokePayload, @@ -176,7 +174,7 @@ class ModerationResult(BaseModel): details: Dict[str, Any] = Field(default_factory=dict, description="Additional details") -class ContentModerationPlugin(MCPPlugin): +class ContentModerationPlugin(Plugin): """Plugin for advanced content moderation using multiple AI providers.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/deny_filter/deny.py b/plugins/deny_filter/deny.py index 0e598f921..1b9b1e9b4 100644 --- a/plugins/deny_filter/deny.py +++ b/plugins/deny_filter/deny.py @@ -12,8 +12,14 @@ from pydantic import BaseModel # First-Party -from mcpgateway.plugins.framework import PluginConfig, PluginContext, PluginViolation -from mcpgateway.plugins.mcp.entities import MCPPlugin, PromptPrehookPayload, PromptPrehookResult +from mcpgateway.plugins.framework import ( + PluginConfig, + PluginContext, + PluginViolation, + Plugin, + PromptPrehookPayload, + PromptPrehookResult +) from mcpgateway.services.logging_service import LoggingService # Initialize logging service first @@ -31,7 +37,7 @@ class DenyListConfig(BaseModel): words: list[str] -class DenyListPlugin(MCPPlugin): +class DenyListPlugin(Plugin): """Example deny list plugin.""" def __init__(self, config: PluginConfig): diff --git a/plugins/external/clamav_server/clamav_plugin.py b/plugins/external/clamav_server/clamav_plugin.py index b593da62b..ba11e3467 100644 --- a/plugins/external/clamav_server/clamav_plugin.py +++ b/plugins/external/clamav_server/clamav_plugin.py @@ -34,9 +34,7 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, PromptPosthookPayload, PromptPosthookResult, ResourcePostFetchPayload, @@ -121,7 +119,7 @@ def _clamd_instream_scan_unix(path: str, data: bytes, timeout: float) -> str: s.close() -class ClamAVRemotePlugin(MCPPlugin): +class ClamAVRemotePlugin(Plugin): """External ClamAV plugin for scanning resources and content.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/external/llmguard/llmguardplugin/plugin.py b/plugins/external/llmguard/llmguardplugin/plugin.py index a548a313a..afaa5a484 100644 --- a/plugins/external/llmguard/llmguardplugin/plugin.py +++ b/plugins/external/llmguard/llmguardplugin/plugin.py @@ -20,10 +20,7 @@ PluginError, PluginErrorModel, PluginViolation, -) -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -40,7 +37,7 @@ logger = logging_service.get_logger(__name__) -class LLMGuardPlugin(MCPPlugin): +class LLMGuardPlugin(Plugin): """A plugin that leverages the capabilities of llmguard library to apply guardrails on input and output prompts. Attributes: diff --git a/plugins/external/opa/opapluginfilter/plugin.py b/plugins/external/opa/opapluginfilter/plugin.py index 60867f8a0..4557d865a 100644 --- a/plugins/external/opa/opapluginfilter/plugin.py +++ b/plugins/external/opa/opapluginfilter/plugin.py @@ -22,9 +22,7 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -65,7 +63,7 @@ class OPAResponseTemplates(str, Enum): HookPayload: TypeAlias = ToolPreInvokePayload | ToolPostInvokePayload | PromptPosthookPayload | PromptPrehookPayload | ResourcePreFetchPayload | ResourcePostFetchPayload -class OPAPluginFilter(MCPPlugin): +class OPAPluginFilter(Plugin): """An OPA plugin that enforces rego policies on requests and allows/denies requests as per policies.""" def __init__(self, config: PluginConfig): diff --git a/plugins/file_type_allowlist/file_type_allowlist.py b/plugins/file_type_allowlist/file_type_allowlist.py index 5450e7524..9b2b62ab4 100644 --- a/plugins/file_type_allowlist/file_type_allowlist.py +++ b/plugins/file_type_allowlist/file_type_allowlist.py @@ -25,9 +25,7 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchPayload, @@ -62,7 +60,7 @@ def _ext_from_uri(uri: str) -> str: return "" -class FileTypeAllowlistPlugin(MCPPlugin): +class FileTypeAllowlistPlugin(Plugin): """Block non-allowed file types for resources.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/harmful_content_detector/harmful_content_detector.py b/plugins/harmful_content_detector/harmful_content_detector.py index c8c3a4900..3f9d0a48e 100644 --- a/plugins/harmful_content_detector/harmful_content_detector.py +++ b/plugins/harmful_content_detector/harmful_content_detector.py @@ -26,9 +26,7 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, PromptPrehookPayload, PromptPrehookResult, ToolPostInvokePayload, @@ -121,7 +119,7 @@ def walk(obj: Any, path: str): yield from walk(value, "") -class HarmfulContentDetectorPlugin(MCPPlugin): +class HarmfulContentDetectorPlugin(Plugin): """Detects harmful content in prompts and tool outputs using keyword lexicons. This plugin scans for self-harm, violence, and hate categories. diff --git a/plugins/header_injector/header_injector.py b/plugins/header_injector/header_injector.py index c60cb8724..daa642155 100644 --- a/plugins/header_injector/header_injector.py +++ b/plugins/header_injector/header_injector.py @@ -24,9 +24,7 @@ from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ResourcePreFetchPayload, ResourcePreFetchResult, ) @@ -59,7 +57,7 @@ def _should_apply(uri: str, prefixes: Optional[list[str]]) -> bool: return any(uri.startswith(p) for p in prefixes) -class HeaderInjectorPlugin(MCPPlugin): +class HeaderInjectorPlugin(Plugin): """Inject custom headers for resource fetching.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/html_to_markdown/html_to_markdown.py b/plugins/html_to_markdown/html_to_markdown.py index f500c00e6..025a62ce4 100644 --- a/plugins/html_to_markdown/html_to_markdown.py +++ b/plugins/html_to_markdown/html_to_markdown.py @@ -22,9 +22,7 @@ from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ) @@ -87,7 +85,7 @@ def _pre_fallback(m): return text.strip() -class HTMLToMarkdownPlugin(MCPPlugin): +class HTMLToMarkdownPlugin(Plugin): """Transform HTML ResourceContent to Markdown in `text` field.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/json_repair/json_repair.py b/plugins/json_repair/json_repair.py index 565a2914a..f246faa1c 100644 --- a/plugins/json_repair/json_repair.py +++ b/plugins/json_repair/json_repair.py @@ -20,9 +20,7 @@ from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ) @@ -72,7 +70,7 @@ def _repair(s: str) -> str | None: return None -class JSONRepairPlugin(MCPPlugin): +class JSONRepairPlugin(Plugin): """Repair JSON-like string outputs, returning corrected string if fixable.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/license_header_injector/license_header_injector.py b/plugins/license_header_injector/license_header_injector.py index e8c398dc7..5fc1e55b3 100644 --- a/plugins/license_header_injector/license_header_injector.py +++ b/plugins/license_header_injector/license_header_injector.py @@ -24,9 +24,7 @@ from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ToolPostInvokePayload, @@ -90,7 +88,7 @@ def _inject_header(text: str, cfg: LicenseHeaderConfig, language: str) -> str: return f"{header_block}\n{text}" -class LicenseHeaderInjectorPlugin(MCPPlugin): +class LicenseHeaderInjectorPlugin(Plugin): """Inject a license header into textual code outputs.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/markdown_cleaner/markdown_cleaner.py b/plugins/markdown_cleaner/markdown_cleaner.py index 5b1d9cde7..a247e6a05 100644 --- a/plugins/markdown_cleaner/markdown_cleaner.py +++ b/plugins/markdown_cleaner/markdown_cleaner.py @@ -22,9 +22,7 @@ from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, PromptPosthookPayload, PromptPosthookResult, ResourcePostFetchPayload, @@ -54,7 +52,7 @@ def _clean_md(text: str) -> str: return text.strip() -class MarkdownCleanerPlugin(MCPPlugin): +class MarkdownCleanerPlugin(Plugin): """Clean Markdown in prompts and resources.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/output_length_guard/output_length_guard.py b/plugins/output_length_guard/output_length_guard.py index 7497cb885..4d2884d57 100644 --- a/plugins/output_length_guard/output_length_guard.py +++ b/plugins/output_length_guard/output_length_guard.py @@ -37,9 +37,7 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ) @@ -100,7 +98,7 @@ def _truncate(value: str, max_chars: int, ellipsis: str) -> str: return value[:cut] + ell -class OutputLengthGuardPlugin(MCPPlugin): +class OutputLengthGuardPlugin(Plugin): """Guard tool outputs by length with block or truncate strategies.""" def __init__(self, config: PluginConfig): diff --git a/plugins/pii_filter/pii_filter.py b/plugins/pii_filter/pii_filter.py index 6ae59a5ed..4672deca8 100644 --- a/plugins/pii_filter/pii_filter.py +++ b/plugins/pii_filter/pii_filter.py @@ -22,9 +22,7 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -410,7 +408,7 @@ def _apply_mask(self, value: str, pii_type: PIIType, strategy: MaskingStrategy) return self.config.redaction_text -class PIIFilterPlugin(MCPPlugin): +class PIIFilterPlugin(Plugin): """PII Filter plugin for detecting and masking sensitive information.""" def __init__(self, config: PluginConfig): diff --git a/plugins/privacy_notice_injector/privacy_notice_injector.py b/plugins/privacy_notice_injector/privacy_notice_injector.py index b37ab4055..cd45058c3 100644 --- a/plugins/privacy_notice_injector/privacy_notice_injector.py +++ b/plugins/privacy_notice_injector/privacy_notice_injector.py @@ -23,9 +23,7 @@ from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, PromptPosthookPayload, PromptPosthookResult, ) @@ -63,7 +61,7 @@ def _inject_text(existing: str, notice: str, placement: str) -> str: return existing -class PrivacyNoticeInjectorPlugin(MCPPlugin): +class PrivacyNoticeInjectorPlugin(Plugin): """Inject a privacy notice into prompt messages.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/rate_limiter/rate_limiter.py b/plugins/rate_limiter/rate_limiter.py index 67720afa9..74ba09a9e 100644 --- a/plugins/rate_limiter/rate_limiter.py +++ b/plugins/rate_limiter/rate_limiter.py @@ -25,9 +25,7 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, PromptPrehookPayload, PromptPrehookResult, ToolPreInvokePayload, @@ -116,7 +114,7 @@ def _allow(key: str, limit: Optional[str]) -> tuple[bool, dict[str, Any]]: return False, {"limited": True, "remaining": 0, "reset_in": window_seconds - (now - wnd.window_start)} -class RateLimiterPlugin(MCPPlugin): +class RateLimiterPlugin(Plugin): """Simple fixed-window rate limiter with per-user/tenant/tool buckets.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/regex_filter/search_replace.py b/plugins/regex_filter/search_replace.py index 506f1fafd..ef6c59707 100644 --- a/plugins/regex_filter/search_replace.py +++ b/plugins/regex_filter/search_replace.py @@ -18,9 +18,7 @@ from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -54,7 +52,7 @@ class SearchReplaceConfig(BaseModel): words: list[SearchReplace] -class SearchReplacePlugin(MCPPlugin): +class SearchReplacePlugin(Plugin): """Example search replace plugin.""" def __init__(self, config: PluginConfig): diff --git a/plugins/resource_filter/resource_filter.py b/plugins/resource_filter/resource_filter.py index e4a481724..8a25aea4f 100644 --- a/plugins/resource_filter/resource_filter.py +++ b/plugins/resource_filter/resource_filter.py @@ -23,9 +23,7 @@ PluginContext, PluginMode, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchPayload, @@ -35,7 +33,7 @@ ) -class ResourceFilterPlugin(MCPPlugin): +class ResourceFilterPlugin(Plugin): """Plugin that filters and modifies resources. This plugin demonstrates the use of resource hooks to: diff --git a/plugins/response_cache_by_prompt/response_cache_by_prompt.py b/plugins/response_cache_by_prompt/response_cache_by_prompt.py index 6fc01533c..f84ff4d6c 100644 --- a/plugins/response_cache_by_prompt/response_cache_by_prompt.py +++ b/plugins/response_cache_by_prompt/response_cache_by_prompt.py @@ -30,9 +30,7 @@ from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokePayload, @@ -125,7 +123,7 @@ class _Entry: expires_at: float -class ResponseCacheByPromptPlugin(MCPPlugin): +class ResponseCacheByPromptPlugin(Plugin): """Approximate response cache keyed by prompt similarity.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/retry_with_backoff/retry_with_backoff.py b/plugins/retry_with_backoff/retry_with_backoff.py index 1cdbd9dd4..305da62a4 100644 --- a/plugins/retry_with_backoff/retry_with_backoff.py +++ b/plugins/retry_with_backoff/retry_with_backoff.py @@ -19,9 +19,7 @@ from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ToolPostInvokePayload, @@ -45,7 +43,7 @@ class RetryPolicyConfig(BaseModel): retry_on_status: list[int] = Field(default_factory=lambda: [429, 500, 502, 503, 504]) -class RetryWithBackoffPlugin(MCPPlugin): +class RetryWithBackoffPlugin(Plugin): """Attach retry/backoff policy in metadata for observability/orchestration.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/robots_license_guard/robots_license_guard.py b/plugins/robots_license_guard/robots_license_guard.py index 820474930..5b7fe3a02 100644 --- a/plugins/robots_license_guard/robots_license_guard.py +++ b/plugins/robots_license_guard/robots_license_guard.py @@ -26,9 +26,7 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchPayload, @@ -89,7 +87,7 @@ def _parse_meta(text: str) -> dict[str, str]: return found -class RobotsLicenseGuardPlugin(MCPPlugin): +class RobotsLicenseGuardPlugin(Plugin): """Honors robots/noai/license meta tags in fetched HTML content.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/safe_html_sanitizer/safe_html_sanitizer.py b/plugins/safe_html_sanitizer/safe_html_sanitizer.py index ebf53d106..a6d68cca4 100644 --- a/plugins/safe_html_sanitizer/safe_html_sanitizer.py +++ b/plugins/safe_html_sanitizer/safe_html_sanitizer.py @@ -32,9 +32,7 @@ from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ) @@ -278,7 +276,7 @@ def _to_text(html_str: str) -> str: return re.sub(r"\n{3,}", "\n\n", no_tags).strip() -class SafeHTMLSanitizerPlugin(MCPPlugin): +class SafeHTMLSanitizerPlugin(Plugin): """Sanitizes HTML content to remove XSS vectors and dangerous elements.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/schema_guard/schema_guard.py b/plugins/schema_guard/schema_guard.py index b652aa8ff..132d21bbf 100644 --- a/plugins/schema_guard/schema_guard.py +++ b/plugins/schema_guard/schema_guard.py @@ -23,9 +23,7 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokePayload, @@ -105,7 +103,7 @@ def _validate(data: Any, schema: Dict[str, Any]) -> list[str]: return errors -class SchemaGuardPlugin(MCPPlugin): +class SchemaGuardPlugin(Plugin): """Validate tool args and results using a simple schema subset.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/secrets_detection/secrets_detection.py b/plugins/secrets_detection/secrets_detection.py index ecdf3e8f1..fb76c8411 100644 --- a/plugins/secrets_detection/secrets_detection.py +++ b/plugins/secrets_detection/secrets_detection.py @@ -26,9 +26,7 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, PromptPrehookPayload, PromptPrehookResult, ResourcePostFetchPayload, @@ -161,7 +159,7 @@ def _scan_container(container: Any, cfg: SecretsDetectionConfig) -> Tuple[int, A return total, container, all_findings -class SecretsDetectionPlugin(MCPPlugin): +class SecretsDetectionPlugin(Plugin): """Detect and optionally redact secrets in inputs/outputs.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/sql_sanitizer/sql_sanitizer.py b/plugins/sql_sanitizer/sql_sanitizer.py index c7b62b022..95d39f094 100644 --- a/plugins/sql_sanitizer/sql_sanitizer.py +++ b/plugins/sql_sanitizer/sql_sanitizer.py @@ -29,9 +29,7 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, PromptPrehookPayload, PromptPrehookResult, ToolPreInvokePayload, @@ -159,7 +157,7 @@ def _scan_args(args: dict[str, Any] | None, cfg: SQLSanitizerConfig) -> tuple[li return issues, scanned -class SQLSanitizerPlugin(MCPPlugin): +class SQLSanitizerPlugin(Plugin): """Block or sanitize risky SQL statements in inputs.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/summarizer/summarizer.py b/plugins/summarizer/summarizer.py index ea936a27d..8f4a7990b 100644 --- a/plugins/summarizer/summarizer.py +++ b/plugins/summarizer/summarizer.py @@ -25,9 +25,7 @@ from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ToolPostInvokePayload, @@ -262,7 +260,7 @@ def _maybe_get_text_from_result(result: Any) -> Optional[str]: return result if isinstance(result, str) else None -class SummarizerPlugin(MCPPlugin): +class SummarizerPlugin(Plugin): """Plugin to summarize long text content using LLM providers.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/timezone_translator/timezone_translator.py b/plugins/timezone_translator/timezone_translator.py index 2951b9eb6..ce1547db3 100644 --- a/plugins/timezone_translator/timezone_translator.py +++ b/plugins/timezone_translator/timezone_translator.py @@ -27,9 +27,7 @@ from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokePayload, @@ -133,7 +131,7 @@ def _walk_and_translate(value: Any, source: ZoneInfo, target: ZoneInfo, fields: return value -class TimezoneTranslatorPlugin(MCPPlugin): +class TimezoneTranslatorPlugin(Plugin): """Converts detected ISO timestamps between server and user timezones.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/url_reputation/url_reputation.py b/plugins/url_reputation/url_reputation.py index 50023e73a..4ea78b4b0 100644 --- a/plugins/url_reputation/url_reputation.py +++ b/plugins/url_reputation/url_reputation.py @@ -23,9 +23,7 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ResourcePreFetchPayload, ResourcePreFetchResult, ) @@ -43,7 +41,7 @@ class URLReputationConfig(BaseModel): blocked_patterns: List[str] = Field(default_factory=list) -class URLReputationPlugin(MCPPlugin): +class URLReputationPlugin(Plugin): """Static allow/deny URL reputation checks.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/vault/vault_plugin.py b/plugins/vault/vault_plugin.py index 4683606d3..4f23dd83a 100644 --- a/plugins/vault/vault_plugin.py +++ b/plugins/vault/vault_plugin.py @@ -24,9 +24,7 @@ from mcpgateway.plugins.framework import ( PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, HttpHeaderPayload, ToolPreInvokePayload, ToolPreInvokeResult, @@ -77,7 +75,7 @@ class VaultConfig(BaseModel): system_handling: SystemHandling = SystemHandling.TAG -class Vault(MCPPlugin): +class Vault(Plugin): """Vault plugin that based on OAUTH2 config that protects a tool will generate bearer token based on a vault saved token""" def __init__(self, config: PluginConfig): diff --git a/plugins/virus_total_checker/virus_total_checker.py b/plugins/virus_total_checker/virus_total_checker.py index b506916f3..5f4c2ba32 100644 --- a/plugins/virus_total_checker/virus_total_checker.py +++ b/plugins/virus_total_checker/virus_total_checker.py @@ -34,9 +34,7 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, PromptPosthookPayload, PromptPosthookResult, ResourcePostFetchPayload, @@ -334,7 +332,7 @@ def _apply_overrides(url: str, host: str | None, cfg: VirusTotalConfig) -> str | return None -class VirusTotalURLCheckerPlugin(MCPPlugin): +class VirusTotalURLCheckerPlugin(Plugin): """Query VirusTotal for URL/domain/IP verdicts and block on policy breaches.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/watchdog/watchdog.py b/plugins/watchdog/watchdog.py index 1fcf12b2d..d399e5e94 100644 --- a/plugins/watchdog/watchdog.py +++ b/plugins/watchdog/watchdog.py @@ -26,9 +26,7 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokePayload, @@ -50,7 +48,7 @@ class WatchdogConfig(BaseModel): tool_overrides: Dict[str, Dict[str, Any]] = {} -class WatchdogPlugin(MCPPlugin): +class WatchdogPlugin(Plugin): """Records tool execution duration and enforces maximum runtime policy.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins/webhook_notification/webhook_notification.py b/plugins/webhook_notification/webhook_notification.py index 4c2a686c1..ae888bf2d 100644 --- a/plugins/webhook_notification/webhook_notification.py +++ b/plugins/webhook_notification/webhook_notification.py @@ -30,9 +30,7 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -119,7 +117,7 @@ class WebhookNotificationConfig(BaseModel): max_payload_size: int = Field(default=1000, description="Max payload size to include in notifications") -class WebhookNotificationPlugin(MCPPlugin): +class WebhookNotificationPlugin(Plugin): """Plugin for sending webhook notifications on events and violations.""" def __init__(self, config: PluginConfig) -> None: diff --git a/plugins_rust/docs/implementation-guide.md b/plugins_rust/docs/implementation-guide.md index efd520730..6cb71a431 100644 --- a/plugins_rust/docs/implementation-guide.md +++ b/plugins_rust/docs/implementation-guide.md @@ -314,7 +314,7 @@ except ImportError: RUST_AVAILABLE = False -class PIIFilterPlugin(MCPPlugin): +class PIIFilterPlugin(Plugin): """PII Filter with automatic Rust/Python selection.""" def __init__(self, config: PluginConfig): diff --git a/tests/integration/test_resource_plugin_integration.py b/tests/integration/test_resource_plugin_integration.py index 2a5ef2ab7..850dfc7c4 100644 --- a/tests/integration/test_resource_plugin_integration.py +++ b/tests/integration/test_resource_plugin_integration.py @@ -132,7 +132,7 @@ async def test_resource_filtering_integration(self, test_db): # Use real plugin manager but mock its initialization with patch("mcpgateway.services.resource_service.PluginManager") as MockPluginManager: # First-Party - from mcpgateway.plugins.mcp.entities import ( + from mcpgateway.plugins.framework import ( ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchResult, @@ -152,9 +152,9 @@ def initialized(self) -> bool: async def invoke_hook(self, hook_type, payload, global_context, local_contexts=None, **kwargs): # First-Party - from mcpgateway.plugins.mcp.entities import HookType + from mcpgateway.plugins.framework import ResourceHookType - if hook_type == HookType.RESOURCE_PRE_FETCH: + if hook_type == ResourceHookType.RESOURCE_PRE_FETCH: # Allow test:// protocol if payload.uri.startswith("test://"): return ( @@ -177,7 +177,7 @@ async def invoke_hook(self, hook_type, payload, global_context, local_contexts=N details={"protocol": payload.uri.split(":")[0], "uri": payload.uri}, ), ) - elif hook_type == HookType.RESOURCE_POST_FETCH: + elif hook_type == ResourceHookType.RESOURCE_POST_FETCH: # Filter sensitive content if payload.content and payload.content.text: filtered_text = payload.content.text.replace( @@ -265,12 +265,12 @@ async def test_plugin_context_flow(self, test_db, resource_service_with_mock_plu # Track context flow # First-Party from mcpgateway.plugins.framework.models import PluginResult - from mcpgateway.plugins.mcp.entities import HookType + from mcpgateway.plugins.framework import ResourceHookType contexts_from_pre = {"plugin_data": "test_value", "validated": True} async def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): - if hook_type == HookType.RESOURCE_PRE_FETCH: + if hook_type == ResourceHookType.RESOURCE_PRE_FETCH: # Verify global context assert global_context.request_id == "integration-test-123" assert global_context.user == "integration-user" @@ -279,7 +279,7 @@ async def invoke_hook_side_effect(hook_type, payload, global_context, local_cont PluginResult(continue_processing=True, modified_payload=None), contexts_from_pre, ) - elif hook_type == HookType.RESOURCE_POST_FETCH: + elif hook_type == ResourceHookType.RESOURCE_POST_FETCH: # Verify contexts from pre-fetch assert local_contexts == contexts_from_pre assert local_contexts["plugin_data"] == "test_value" diff --git a/tests/unit/mcpgateway/plugins/agent/test_agent_plugins.py b/tests/unit/mcpgateway/plugins/agent/test_agent_plugins.py index 4a9c67d30..ac7f480e2 100644 --- a/tests/unit/mcpgateway/plugins/agent/test_agent_plugins.py +++ b/tests/unit/mcpgateway/plugins/agent/test_agent_plugins.py @@ -13,7 +13,7 @@ # First-Party from mcpgateway.common.models import Message, Role, TextContent from mcpgateway.plugins.framework import GlobalContext, PluginManager, PluginViolationError -from mcpgateway.plugins.agent import ( +from mcpgateway.plugins.framework import ( AgentHookType, AgentPreInvokePayload, AgentPostInvokePayload, @@ -28,7 +28,7 @@ async def test_agent_passthrough_plugin(): # Verify plugin loaded assert manager.config.plugins[0].name == "PassThroughAgent" - assert manager.config.plugins[0].kind == "tests.unit.mcpgateway.plugins.fixtures.plugins.agent_test.PassThroughAgentPlugin" + assert manager.config.plugins[0].kind == "tests.unit.mcpgateway.plugins.fixtures.plugins.agent_plugins.PassThroughAgentPlugin" assert AgentHookType.AGENT_PRE_INVOKE.value in manager.config.plugins[0].hooks assert AgentHookType.AGENT_POST_INVOKE.value in manager.config.plugins[0].hooks diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/agent_context.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_context.yaml index 74d4328b9..68d7f400f 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/configs/agent_context.yaml +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_context.yaml @@ -1,6 +1,6 @@ plugins: - name: ContextTrackingAgent - kind: tests.unit.mcpgateway.plugins.fixtures.plugins.agent_test.ContextTrackingAgentPlugin + kind: tests.unit.mcpgateway.plugins.fixtures.plugins.agent_plugins.ContextTrackingAgentPlugin description: An agent plugin that tracks state in local context version: "1.0.0" author: Test Suite diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/agent_filter.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_filter.yaml index f5f927d1f..9d31a5061 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/configs/agent_filter.yaml +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_filter.yaml @@ -1,6 +1,6 @@ plugins: - name: MessageFilterAgent - kind: tests.unit.mcpgateway.plugins.fixtures.plugins.agent_test.MessageFilterAgentPlugin + kind: tests.unit.mcpgateway.plugins.fixtures.plugins.agent_plugins.MessageFilterAgentPlugin description: An agent plugin that filters blocked words version: "1.0.0" author: Test Suite diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml index 3525dc3cc..31793520a 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml @@ -1,6 +1,6 @@ plugins: - name: PassThroughAgent - kind: tests.unit.mcpgateway.plugins.fixtures.plugins.agent_test.PassThroughAgentPlugin + kind: tests.unit.mcpgateway.plugins.fixtures.plugins.agent_plugins.PassThroughAgentPlugin description: A simple pass-through agent plugin for testing version: "1.0.0" author: Test Suite diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/test_hook_patterns_config.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/test_hook_patterns_config.yaml new file mode 100644 index 000000000..072952ded --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/test_hook_patterns_config.yaml @@ -0,0 +1,26 @@ +plugins: + - name: DemoPlugin + kind: test_hook_patterns.DemoPlugin + description: Demonstration plugin showing all three hook patterns + version: "1.0.0" + author: Demo + hooks: + - tool_pre_invoke + - tool_post_invoke + - email_pre_send + tags: + - demo + - test + mode: enforce + priority: 50 + +# Plugin directories to scan (not needed for this demo) +plugin_dirs: [] + +# Global plugin settings +plugin_settings: + parallel_execution_within_band: true + plugin_timeout: 30 + fail_on_plugin_error: true + enable_plugin_api: true + plugin_health_check_interval: 60 diff --git a/tests/unit/mcpgateway/plugins/fixtures/plugins/agent_test.py b/tests/unit/mcpgateway/plugins/fixtures/plugins/agent_plugins.py similarity index 96% rename from tests/unit/mcpgateway/plugins/fixtures/plugins/agent_test.py rename to tests/unit/mcpgateway/plugins/fixtures/plugins/agent_plugins.py index 20c33bb44..7112a2c11 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/plugins/agent_test.py +++ b/tests/unit/mcpgateway/plugins/fixtures/plugins/agent_plugins.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -"""Location: ./tests/unit/mcpgateway/plugins/fixtures/plugins/agent_test.py +"""Location: ./tests/unit/mcpgateway/plugins/fixtures/plugins/agent_plugins.py Copyright 2025 SPDX-License-Identifier: Apache-2.0 Authors: Teryl Taylor @@ -9,9 +9,9 @@ # First-Party from mcpgateway.common.models import Message, Role, TextContent -from mcpgateway.plugins.framework import PluginContext -from mcpgateway.plugins.agent import ( - AgentPlugin, +from mcpgateway.plugins.framework import ( + Plugin, + PluginContext, AgentPreInvokePayload, AgentPreInvokeResult, AgentPostInvokePayload, @@ -19,7 +19,7 @@ ) -class PassThroughAgentPlugin(AgentPlugin): +class PassThroughAgentPlugin(Plugin): """A simple pass-through agent plugin that doesn't modify anything.""" async def agent_pre_invoke( @@ -51,7 +51,7 @@ async def agent_post_invoke( return AgentPostInvokeResult(continue_processing=True) -class MessageFilterAgentPlugin(AgentPlugin): +class MessageFilterAgentPlugin(Plugin): """An agent plugin that filters messages containing blocked words.""" async def agent_pre_invoke( @@ -153,7 +153,7 @@ async def agent_post_invoke( return AgentPostInvokeResult(continue_processing=True) -class ContextTrackingAgentPlugin(AgentPlugin): +class ContextTrackingAgentPlugin(Plugin): """An agent plugin that tracks state in local context.""" async def agent_pre_invoke( diff --git a/tests/unit/mcpgateway/plugins/fixtures/plugins/context.py b/tests/unit/mcpgateway/plugins/fixtures/plugins/context.py index c5b3fc354..e8e251ebb 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/plugins/context.py +++ b/tests/unit/mcpgateway/plugins/fixtures/plugins/context.py @@ -8,9 +8,9 @@ Context plugin. """ -from mcpgateway.plugins.framework import PluginContext -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, +from mcpgateway.plugins.framework import ( + PluginContext, + Plugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -26,7 +26,7 @@ ) -class ContextPlugin(MCPPlugin): +class ContextPlugin(Plugin): """A simple Context plugin.""" async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: @@ -111,7 +111,7 @@ async def resource_pre_fetch(self, payload: ResourcePreFetchPayload, context: Pl return ResourcePreFetchResult(continue_processing=True) -class ContextPlugin2(MCPPlugin): +class ContextPlugin2(Plugin): """A simple Context plugin.""" async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: diff --git a/tests/unit/mcpgateway/plugins/fixtures/plugins/error.py b/tests/unit/mcpgateway/plugins/fixtures/plugins/error.py index e0d44f874..32279ad2d 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/plugins/error.py +++ b/tests/unit/mcpgateway/plugins/fixtures/plugins/error.py @@ -8,9 +8,9 @@ Error plugin. """ -from mcpgateway.plugins.framework import PluginContext -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, +from mcpgateway.plugins.framework import ( + PluginContext, + Plugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -26,7 +26,7 @@ ) -class ErrorPlugin(MCPPlugin): +class ErrorPlugin(Plugin): """A simple error plugin.""" async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: diff --git a/tests/unit/mcpgateway/plugins/fixtures/plugins/headers.py b/tests/unit/mcpgateway/plugins/fixtures/plugins/headers.py index 0d61aadd5..00b95faa0 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/plugins/headers.py +++ b/tests/unit/mcpgateway/plugins/fixtures/plugins/headers.py @@ -14,9 +14,7 @@ from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA from mcpgateway.plugins.framework import ( PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, + Plugin, HttpHeaderPayload, PromptPosthookPayload, PromptPosthookResult, @@ -35,7 +33,7 @@ logger = logging.getLogger("header_plugin") -class HeadersMetaDataPlugin(MCPPlugin): +class HeadersMetaDataPlugin(Plugin): """A simple header plugin to read and modify headers.""" async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: @@ -142,7 +140,7 @@ async def resource_pre_fetch(self, payload: ResourcePreFetchPayload, context: Pl return ResourcePreFetchResult(continue_processing=True) -class HeadersPlugin(MCPPlugin): +class HeadersPlugin(Plugin): """A simple header plugin to read and modify headers.""" async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: diff --git a/tests/unit/mcpgateway/plugins/fixtures/plugins/passthrough.py b/tests/unit/mcpgateway/plugins/fixtures/plugins/passthrough.py index b858b8ea8..9f6c4b3d2 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/plugins/passthrough.py +++ b/tests/unit/mcpgateway/plugins/fixtures/plugins/passthrough.py @@ -8,9 +8,9 @@ """ # First-Party -from mcpgateway.plugins.framework import PluginContext -from mcpgateway.plugins.mcp.entities import ( - MCPPlugin, +from mcpgateway.plugins.framework import ( + PluginContext, + Plugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -26,7 +26,7 @@ ) -class PassThroughPlugin(MCPPlugin): +class PassThroughPlugin(Plugin): """A simple pass through plugin.""" async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: diff --git a/tests/unit/mcpgateway/plugins/fixtures/plugins/simple.py b/tests/unit/mcpgateway/plugins/fixtures/plugins/simple.py new file mode 100644 index 000000000..287fc3ab5 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/plugins/simple.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/plugins/fixtures/plugins/simple.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Test Suite + +Simple minimal plugins for testing the plugin framework. +These plugins provide basic passthrough implementations for testing +registration, priority sorting, hook filtering, etc. +""" + +# First-Party +from mcpgateway.plugins.framework import ( + Plugin, + PluginContext, + PromptPosthookPayload, + PromptPosthookResult, + PromptPrehookPayload, + PromptPrehookResult, + ToolPostInvokePayload, + ToolPostInvokeResult, + ToolPreInvokePayload, + ToolPreInvokeResult, +) + + +class SimplePromptPlugin(Plugin): + """Minimal plugin with prompt hooks for testing.""" + + async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: + """Passthrough prompt pre-fetch hook.""" + return PromptPrehookResult(continue_processing=True) + + async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult: + """Passthrough prompt post-fetch hook.""" + return PromptPosthookResult(continue_processing=True) + + +class SimpleToolPlugin(Plugin): + """Minimal plugin with tool hooks for testing.""" + + async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult: + """Passthrough tool pre-invoke hook.""" + return ToolPreInvokeResult(continue_processing=True) + + async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: PluginContext) -> ToolPostInvokeResult: + """Passthrough tool post-invoke hook.""" + return ToolPostInvokeResult(continue_processing=True) diff --git a/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py b/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py index 4d979f873..1d675a70f 100644 --- a/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py +++ b/tests/unit/mcpgateway/plugins/framework/external/mcp/server/test_runtime.py @@ -18,8 +18,6 @@ from mcpgateway.plugins.framework import ( GlobalContext, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( PromptPosthookPayload, PromptPrehookPayload, ResourcePostFetchPayload, diff --git a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_config.py b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_config.py index 0f7c3bffc..5c6267ebf 100644 --- a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_config.py +++ b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_config.py @@ -22,9 +22,9 @@ ConfigLoader, GlobalContext, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - HookType, + PromptHookType, + ResourceHookType, + ToolHookType, PromptPosthookPayload, PromptPrehookPayload, ResourcePostFetchPayload, @@ -124,35 +124,35 @@ async def test_hook_methods_empty_content(): # Test prompt_pre_fetch with empty content - should raise PluginError payload = PromptPrehookPayload(prompt_id="1", args={}) with pytest.raises(PluginError): - await plugin.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, context) + await plugin.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, context) # Test prompt_post_fetch with empty content - should raise PluginError message = Message(content=TextContent(type="text", text="test"), role=Role.USER) prompt_result = PromptResult(messages=[message]) payload = PromptPosthookPayload(prompt_id="1", result=prompt_result) with pytest.raises(PluginError): - await plugin.invoke_hook(HookType.PROMPT_POST_FETCH, payload, context) + await plugin.invoke_hook(PromptHookType.PROMPT_POST_FETCH, payload, context) # Test tool_pre_invoke with empty content - should raise PluginError payload = ToolPreInvokePayload(name="test", args={}) with pytest.raises(PluginError): - await plugin.invoke_hook(HookType.TOOL_PRE_INVOKE, payload, context) + await plugin.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, payload, context) # Test tool_post_invoke with empty content - should raise PluginError payload = ToolPostInvokePayload(name="test", result={}) with pytest.raises(PluginError): - await plugin.invoke_hook(HookType.TOOL_POST_INVOKE, payload, context) + await plugin.invoke_hook(ToolHookType.TOOL_POST_INVOKE, payload, context) # Test resource_pre_fetch with empty content - should raise PluginError payload = ResourcePreFetchPayload(uri="file://test.txt") with pytest.raises(PluginError): - await plugin.invoke_hook(HookType.RESOURCE_PRE_FETCH, payload, context) + await plugin.invoke_hook(ResourceHookType.RESOURCE_PRE_FETCH, payload, context) # Test resource_post_fetch with empty content - should raise PluginError resource_content = ResourceContent(type="resource", id="123",uri="file://test.txt", text="content") payload = ResourcePostFetchPayload(uri="file://test.txt", content=resource_content) with pytest.raises(PluginError): - await plugin.invoke_hook(HookType.RESOURCE_POST_FETCH, payload, context) + await plugin.invoke_hook(ResourceHookType.RESOURCE_POST_FETCH, payload, context) await plugin.shutdown() diff --git a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_stdio.py b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_stdio.py index 44405c912..5b3ea2538 100644 --- a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_stdio.py +++ b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_stdio.py @@ -29,9 +29,9 @@ PluginContext, PluginLoader, PluginManager, -) -from mcpgateway.plugins.mcp.entities import ( - HookType, + PromptHookType, + ResourceHookType, + ToolHookType, PromptPosthookPayload, PromptPrehookPayload, ResourcePostFetchPayload, @@ -51,7 +51,7 @@ async def test_client_load_stdio(): loader = PluginLoader() plugin = await loader.load_and_instantiate_plugin(config.plugins[0]) prompt = PromptPrehookPayload(prompt_id="test_prompt", args={"text": "That was innovative!"}) - result = await plugin.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, PluginContext(global_context=GlobalContext(request_id="1", server_id="2"))) + result = await plugin.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, PluginContext(global_context=GlobalContext(request_id="1", server_id="2"))) assert result.violation assert result.violation.reason == "Prompt not allowed" assert result.violation.description == "A deny word was found in the prompt" @@ -75,7 +75,7 @@ async def test_client_load_stdio_overrides(): loader = PluginLoader() plugin = await loader.load_and_instantiate_plugin(config.plugins[0]) prompt = PromptPrehookPayload(prompt_id="test_prompt", args = {"text": "That was innovative!"}) - result = await plugin.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, PluginContext(global_context=GlobalContext(request_id="1", server_id="2"))) + result = await plugin.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, PluginContext(global_context=GlobalContext(request_id="1", server_id="2"))) assert result.violation assert result.violation.reason == "Prompt not allowed" assert result.violation.description == "A deny word was found in the prompt" @@ -101,7 +101,7 @@ async def test_client_load_stdio_post_prompt(): plugin = await loader.load_and_instantiate_plugin(config.plugins[0]) prompt = PromptPrehookPayload(prompt_id="test_prompt", args = {"user": "What a crapshow!"}) context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) - result = await plugin.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, context) + result = await plugin.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, context) assert result.modified_payload.args["user"] == "What a yikesshow!" config = plugin.config assert config.name == "ReplaceBadWordsPlugin" @@ -114,7 +114,7 @@ async def test_client_load_stdio_post_prompt(): payload_result = PromptPosthookPayload(prompt_id="test_prompt", result=prompt_result) - result = await plugin.invoke_hook(HookType.PROMPT_POST_FETCH, payload_result, context=context) + result = await plugin.invoke_hook(PromptHookType.PROMPT_POST_FETCH, payload_result, context=context) assert len(result.modified_payload.result.messages) == 1 assert result.modified_payload.result.messages[0].content.text == "What the yikes?" await plugin.shutdown() @@ -188,7 +188,7 @@ async def test_hooks(): await plugin_manager.initialize() payload = PromptPrehookPayload(prompt_id="test_prompt", name="test_prompt", args={"arg0": "This is a crap argument"}) global_context = GlobalContext(request_id="1") - result, _ = await plugin_manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, global_context) + result, _ = await plugin_manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context) # Assert expected behaviors assert result.continue_processing """Test prompt post hook across all registered plugins.""" @@ -196,31 +196,31 @@ async def test_hooks(): message = Message(content=TextContent(type="text", text="prompt"), role=Role.USER) prompt_result = PromptResult(messages=[message]) payload = PromptPosthookPayload(prompt_id="test_prompt", result=prompt_result) - result, _ = await plugin_manager.invoke_hook(HookType.PROMPT_POST_FETCH, payload, global_context) + result, _ = await plugin_manager.invoke_hook(PromptHookType.PROMPT_POST_FETCH, payload, global_context) # Assert expected behaviors assert result.continue_processing """Test tool pre hook across all registered plugins.""" # Customize payload for testing payload = ToolPreInvokePayload(name="test_prompt", args={"arg0": "This is an argument"}) - result, _ = await plugin_manager.invoke_hook(HookType.TOOL_PRE_INVOKE, payload, global_context) + result, _ = await plugin_manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, payload, global_context) # Assert expected behaviors assert result.continue_processing """Test tool post hook across all registered plugins.""" # Customize payload for testing payload = ToolPostInvokePayload(name="test_tool", result={"output0": "output value"}) - result, _ = await plugin_manager.invoke_hook(HookType.TOOL_POST_INVOKE, payload, global_context) + result, _ = await plugin_manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, payload, global_context) # Assert expected behaviors assert result.continue_processing payload = ResourcePreFetchPayload(uri="file:///data.txt") - result, _ = await plugin_manager.invoke_hook(HookType.RESOURCE_PRE_FETCH, payload, global_context) + result, _ = await plugin_manager.invoke_hook(ResourceHookType.RESOURCE_PRE_FETCH, payload, global_context) # Assert expected behaviors assert result.continue_processing content = ResourceContent(type="resource", id="123", uri="file:///data.txt", text="Hello World") payload = ResourcePostFetchPayload(uri="file:///data.txt", content=content) - result, _ = await plugin_manager.invoke_hook(HookType.RESOURCE_POST_FETCH, payload, global_context) + result, _ = await plugin_manager.invoke_hook(ResourceHookType.RESOURCE_POST_FETCH, payload, global_context) # Assert expected behaviors assert result.continue_processing await plugin_manager.shutdown() @@ -236,7 +236,7 @@ async def test_errors(): global_context = GlobalContext(request_id="1") escaped_regex = re.escape("ValueError('Sadly! Prompt prefetch is broken!')") with pytest.raises(PluginError, match=escaped_regex): - await plugin_manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, global_context) + await plugin_manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context) await plugin_manager.shutdown() @@ -253,7 +253,7 @@ async def test_shared_context_across_pre_post_hooks_multi_plugins(): # Test tool pre-invoke with transformation - use correct tool name from config tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) assert len(contexts) == 2 ctxs = [contexts[key] for key in contexts.keys()] @@ -282,7 +282,7 @@ async def test_shared_context_across_pre_post_hooks_multi_plugins(): assert result.modified_payload is None # Test tool post-invoke with transformation tool_result_payload = ToolPostInvokePayload(name="test_tool", result={"output": "Result was bad", "status": "wrong format"}) - result, contexts = await manager.invoke_hook(HookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) ctxs = [contexts[key] for key in contexts.keys()] assert len(ctxs) == 2 diff --git a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_streamable_http.py b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_streamable_http.py index 72964d197..05dcbfbd4 100644 --- a/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_streamable_http.py +++ b/tests/unit/mcpgateway/plugins/framework/external/mcp/test_client_streamable_http.py @@ -18,8 +18,7 @@ # First-Party from mcpgateway.common.models import Message, PromptResult, Role, TextContent -from mcpgateway.plugins.framework import ConfigLoader, GlobalContext, PluginContext, PluginLoader -from mcpgateway.plugins.mcp.entities import PromptPosthookPayload, PromptPrehookPayload +from mcpgateway.plugins.framework import ConfigLoader, GlobalContext, PluginContext, PluginLoader, PromptPosthookPayload, PromptPrehookPayload @pytest.fixture(autouse=True) diff --git a/tests/unit/mcpgateway/plugins/framework/hooks/test_hook_patterns.py b/tests/unit/mcpgateway/plugins/framework/hooks/test_hook_patterns.py new file mode 100644 index 000000000..11291fdae --- /dev/null +++ b/tests/unit/mcpgateway/plugins/framework/hooks/test_hook_patterns.py @@ -0,0 +1,312 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/plugins/framework/hooks/test_hook_patterns.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Unit tests demonstrating three hook patterns in the plugin framework: +1. Convention-based: method name matches hook type +2. Decorator-based: @hook decorator with custom method name +3. Custom hook: @hook decorator with new hook type + payload/result types +""" + +# Third-Party +import pytest + +# First-Party +from mcpgateway.plugins.framework import ( + Plugin, + PluginContext, + GlobalContext, + PluginManager, + PluginPayload, + PluginResult, + ToolHookType, + ToolPreInvokePayload, + ToolPreInvokeResult, + ToolPostInvokePayload, + ToolPostInvokeResult, +) +from mcpgateway.plugins.framework.decorator import hook + + +# ========== Custom Hook Definition ========== +class EmailPayload(PluginPayload): + """Payload for email hook.""" + + recipient: str + subject: str + body: str + + +class EmailResult(PluginResult[EmailPayload]): + """Result for email hook.""" + + pass + + +# ========== Demo Plugin with All Three Patterns ========== +class DemoPlugin(Plugin): + """Demo plugin showing all three hook patterns.""" + + # Pattern 1: Convention-based (method name matches hook type) + async def tool_pre_invoke( + self, payload: ToolPreInvokePayload, context: PluginContext + ) -> ToolPreInvokeResult: + """Pattern 1: Convention-based hook. + + This method is found automatically because its name matches + the hook type 'tool_pre_invoke'. + """ + # Modify the payload + modified_payload = ToolPreInvokePayload( + name=payload.name, + args={**payload.args, "pattern": "convention"}, + headers=payload.headers, + ) + + return ToolPreInvokeResult( + modified_payload=modified_payload, + metadata={"pattern": "convention", "hook": "tool_pre_invoke"} + ) + + # Pattern 2: Decorator-based with custom method name + @hook(ToolHookType.TOOL_POST_INVOKE) + async def my_custom_tool_post_handler( + self, payload: ToolPostInvokePayload, context: PluginContext + ) -> ToolPostInvokeResult: + """Pattern 2: Decorator-based hook with custom method name. + + This method is found via the @hook decorator even though + the method name doesn't match the hook type. + """ + # Modify the result + modified_result = {**payload.result, "pattern": "decorator"} if isinstance(payload.result, dict) else payload.result + + modified_payload = ToolPostInvokePayload( + name=payload.name, + result=modified_result, + ) + + return ToolPostInvokeResult( + modified_payload=modified_payload, + metadata={"pattern": "decorator", "hook": "tool_post_invoke"} + ) + + # Pattern 3: Custom hook with payload and result types + @hook("email_pre_send", EmailPayload, EmailResult) + async def validate_email( + self, payload: EmailPayload, context: PluginContext + ) -> EmailResult: + """Pattern 3: Custom hook with new hook type. + + This registers a completely new hook type 'email_pre_send' + with its own payload and result types. + """ + # Validate email + if "@" not in payload.recipient: + modified_payload = EmailPayload( + recipient=f"{payload.recipient}@example.com", + subject=payload.subject, + body=payload.body, + ) + return EmailResult( + modified_payload=modified_payload, + metadata={"pattern": "custom", "hook": "email_pre_send", "fixed_email": True} + ) + + return EmailResult( + continue_processing=True, + metadata={"pattern": "custom", "hook": "email_pre_send"} + ) + + +# ========== Pytest Tests ========== +@pytest.mark.asyncio +async def test_pattern_1_convention_based_hook(): + """Test Pattern 1: Convention-based hook (method name matches hook type).""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/test_hook_patterns_config.yaml") + await manager.initialize() + + # Create payload for tool_pre_invoke + payload = ToolPreInvokePayload( + name="my_calculator", + args={"operation": "add", "a": 5, "b": 3} + ) + + global_context = GlobalContext(request_id="test-1") + + # Invoke the hook + result, contexts = await manager.invoke_hook( + ToolHookType.TOOL_PRE_INVOKE, + payload, + global_context=global_context + ) + + # Assertions + assert result is not None + assert result.continue_processing is True + assert result.modified_payload is not None + assert result.modified_payload.name == "my_calculator" + assert result.modified_payload.args["operation"] == "add" + assert result.modified_payload.args["a"] == 5 + assert result.modified_payload.args["b"] == 3 + assert result.modified_payload.args["pattern"] == "convention" # Added by hook + assert result.metadata is not None + assert result.metadata["pattern"] == "convention" + assert result.metadata["hook"] == "tool_pre_invoke" + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_pattern_2_decorator_based_hook(): + """Test Pattern 2: Decorator-based hook with custom method name.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/test_hook_patterns_config.yaml") + await manager.initialize() + + # Create payload for tool_post_invoke + payload = ToolPostInvokePayload( + name="my_calculator", + result={"sum": 8, "status": "success"} + ) + + global_context = GlobalContext(request_id="test-2") + + # Invoke the hook + result, contexts = await manager.invoke_hook( + ToolHookType.TOOL_POST_INVOKE, + payload, + global_context=global_context + ) + + # Assertions + assert result is not None + assert result.continue_processing is True + assert result.modified_payload is not None + assert result.modified_payload.name == "my_calculator" + assert result.modified_payload.result["sum"] == 8 + assert result.modified_payload.result["status"] == "success" + assert result.modified_payload.result["pattern"] == "decorator" # Added by hook + assert result.metadata is not None + assert result.metadata["pattern"] == "decorator" + assert result.metadata["hook"] == "tool_post_invoke" + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_pattern_3_custom_hook_valid_email(): + """Test Pattern 3: Custom hook with new hook type (valid email).""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/test_hook_patterns_config.yaml") + await manager.initialize() + + # Test with valid email + payload = EmailPayload( + recipient="user@example.com", + subject="Test Email", + body="This is a test." + ) + + global_context = GlobalContext(request_id="test-3a") + + result, contexts = await manager.invoke_hook( + "email_pre_send", + payload, + global_context=global_context + ) + + # Assertions + assert result is not None + assert result.continue_processing is True + assert result.modified_payload is None # No modification needed for valid email + assert result.metadata is not None + assert result.metadata["pattern"] == "custom" + assert result.metadata["hook"] == "email_pre_send" + assert "fixed_email" not in result.metadata # Email was already valid + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_pattern_3_custom_hook_invalid_email(): + """Test Pattern 3: Custom hook with new hook type (invalid email gets fixed).""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/test_hook_patterns_config.yaml") + await manager.initialize() + + # Test with invalid email (missing @) + payload = EmailPayload( + recipient="invalid-email", + subject="Test Email 2", + body="This email address needs fixing." + ) + + global_context = GlobalContext(request_id="test-3b") + + result, contexts = await manager.invoke_hook( + "email_pre_send", + payload, + global_context=global_context + ) + + # Assertions + assert result is not None + assert result.continue_processing is True + assert result.modified_payload is not None + assert result.modified_payload.recipient == "invalid-email@example.com" # Fixed by hook + assert result.modified_payload.subject == "Test Email 2" + assert result.modified_payload.body == "This email address needs fixing." + assert result.metadata is not None + assert result.metadata["pattern"] == "custom" + assert result.metadata["hook"] == "email_pre_send" + assert result.metadata["fixed_email"] is True # Hook fixed the email + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_all_three_patterns_in_sequence(): + """Test all three patterns work together in the same plugin manager.""" + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/test_hook_patterns_config.yaml") + await manager.initialize() + + global_context = GlobalContext(request_id="test-all") + + # Test Pattern 1: Convention-based + payload1 = ToolPreInvokePayload( + name="test_tool", + args={"param": "value"} + ) + result1, _ = await manager.invoke_hook( + ToolHookType.TOOL_PRE_INVOKE, + payload1, + global_context=global_context + ) + assert result1.modified_payload.args["pattern"] == "convention" + + # Test Pattern 2: Decorator-based + payload2 = ToolPostInvokePayload( + name="test_tool", + result={"data": "output"} + ) + result2, _ = await manager.invoke_hook( + ToolHookType.TOOL_POST_INVOKE, + payload2, + global_context=global_context + ) + assert result2.modified_payload.result["pattern"] == "decorator" + + # Test Pattern 3: Custom hook + payload3 = EmailPayload( + recipient="test", + subject="Test", + body="Test" + ) + result3, _ = await manager.invoke_hook( + "email_pre_send", + payload3, + global_context=global_context + ) + assert result3.modified_payload.recipient == "test@example.com" + + await manager.shutdown() diff --git a/tests/unit/mcpgateway/plugins/framework/hooks/test_hook_registry.py b/tests/unit/mcpgateway/plugins/framework/hooks/test_hook_registry.py new file mode 100644 index 000000000..c54a05770 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/framework/hooks/test_hook_registry.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 Β© IBM Corporation +SPDX-License-Identifier: Apache-2.0 + +Test suite for hook registry functionality. +""" + +# Third-Party +import pytest + +# First-Party +from mcpgateway.plugins.framework import ( + get_hook_registry, + AgentHookType, + PromptHookType, + ResourceHookType, + ToolHookType, + PromptPrehookPayload, + PromptPrehookResult, + PromptPosthookPayload, + PromptPosthookResult, + ToolPreInvokePayload, + ToolPreInvokeResult, +) + + +class TestHookRegistry: + """Test cases for the HookRegistry class.""" + + @pytest.fixture + def registry(self): + """Provide a hook registry instance.""" + return get_hook_registry() + + def test_mcp_hooks_are_registered(self, registry): + """Test that all MCP hooks are registered.""" + assert registry.is_registered(PromptHookType.PROMPT_PRE_FETCH) + assert registry.is_registered(PromptHookType.PROMPT_POST_FETCH) + assert registry.is_registered(ToolHookType.TOOL_PRE_INVOKE) + assert registry.is_registered(ToolHookType.TOOL_POST_INVOKE) + assert registry.is_registered(ResourceHookType.RESOURCE_PRE_FETCH) + assert registry.is_registered(ResourceHookType.RESOURCE_POST_FETCH) + + def test_get_payload_type(self, registry): + """Test retrieving payload types from registry.""" + payload_type = registry.get_payload_type(PromptHookType.PROMPT_PRE_FETCH) + assert payload_type == PromptPrehookPayload + + payload_type = registry.get_payload_type(PromptHookType.PROMPT_POST_FETCH) + assert payload_type == PromptPosthookPayload + + payload_type = registry.get_payload_type(ToolHookType.TOOL_PRE_INVOKE) + assert payload_type == ToolPreInvokePayload + + def test_get_result_type(self, registry): + """Test retrieving result types from registry.""" + result_type = registry.get_result_type(PromptHookType.PROMPT_PRE_FETCH) + assert result_type == PromptPrehookResult + + result_type = registry.get_result_type(PromptHookType.PROMPT_POST_FETCH) + assert result_type == PromptPosthookResult + + result_type = registry.get_result_type(ToolHookType.TOOL_PRE_INVOKE) + assert result_type == ToolPreInvokeResult + + def test_get_unregistered_hook_returns_none(self, registry): + """Test that unregistered hooks return None.""" + assert registry.get_payload_type("unknown_hook") is None + assert registry.get_result_type("unknown_hook") is None + assert not registry.is_registered("unknown_hook") + + def test_json_to_payload_with_dict(self, registry): + """Test converting dictionary to payload.""" + payload_dict = {"prompt_id": "test", "args": {"key": "value"}} + payload = registry.json_to_payload(PromptHookType.PROMPT_PRE_FETCH, payload_dict) + + assert isinstance(payload, PromptPrehookPayload) + assert payload.prompt_id == "test" + assert payload.args["key"] == "value" + + def test_json_to_payload_with_json_string(self, registry): + """Test converting JSON string to payload.""" + payload_json = '{"prompt_id": "test", "args": {"key": "value"}}' + payload = registry.json_to_payload(PromptHookType.PROMPT_PRE_FETCH, payload_json) + + assert isinstance(payload, PromptPrehookPayload) + assert payload.prompt_id == "test" + assert payload.args["key"] == "value" + + def test_json_to_result_with_dict(self, registry): + """Test converting dictionary to result.""" + result_dict = {"continue_processing": True, "modified_payload": None} + result = registry.json_to_result(PromptHookType.PROMPT_PRE_FETCH, result_dict) + + assert isinstance(result, PromptPrehookResult) + assert result.continue_processing is True + + def test_json_to_result_with_json_string(self, registry): + """Test converting JSON string to result.""" + result_json = '{"continue_processing": false, "modified_payload": null}' + result = registry.json_to_result(PromptHookType.PROMPT_PRE_FETCH, result_json) + + assert isinstance(result, PromptPrehookResult) + assert result.continue_processing is False + + def test_json_to_payload_unregistered_hook_raises_error(self, registry): + """Test that converting payload for unregistered hook raises ValueError.""" + with pytest.raises(ValueError, match="No payload type registered for hook"): + registry.json_to_payload("unknown_hook", {}) + + def test_json_to_result_unregistered_hook_raises_error(self, registry): + """Test that converting result for unregistered hook raises ValueError.""" + with pytest.raises(ValueError, match="No result type registered for hook"): + registry.json_to_result("unknown_hook", {}) + + def test_get_registered_hooks(self, registry): + """Test retrieving all registered hook types.""" + hooks = registry.get_registered_hooks() + + assert isinstance(hooks, list) + assert len(hooks) >= 8 # At least the 6 MCP hooks + assert PromptHookType.PROMPT_PRE_FETCH in hooks + assert PromptHookType.PROMPT_POST_FETCH in hooks + assert ToolHookType.TOOL_PRE_INVOKE in hooks + assert ToolHookType.TOOL_POST_INVOKE in hooks + assert ResourceHookType.RESOURCE_PRE_FETCH in hooks + assert ResourceHookType.RESOURCE_POST_FETCH in hooks + assert AgentHookType.AGENT_POST_INVOKE in hooks + assert AgentHookType.AGENT_PRE_INVOKE in hooks + + def test_registry_is_singleton(self): + """Test that get_hook_registry returns the same instance.""" + registry1 = get_hook_registry() + registry2 = get_hook_registry() + + assert registry1 is registry2 diff --git a/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py b/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py index fa6b48d66..a0d54bf40 100644 --- a/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py +++ b/tests/unit/mcpgateway/plugins/framework/loader/test_plugin_loader.py @@ -17,11 +17,8 @@ from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework.loader.config import ConfigLoader from mcpgateway.plugins.framework.loader.plugin import PluginLoader -from mcpgateway.plugins.framework import GlobalContext, PluginContext, PluginMode -from mcpgateway.plugins.mcp.entities import PromptPosthookPayload, PromptPrehookPayload +from mcpgateway.plugins.framework import GlobalContext, PluginContext, PluginMode, PromptPosthookPayload, PromptPrehookPayload from plugins.regex_filter.search_replace import SearchReplaceConfig, SearchReplacePlugin -from unittest.mock import patch - def test_config_loader_load(): """pytest for testing the config loader.""" diff --git a/tests/unit/mcpgateway/plugins/framework/test_context.py b/tests/unit/mcpgateway/plugins/framework/test_context.py index 0f8a3e0ba..74983f325 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_context.py +++ b/tests/unit/mcpgateway/plugins/framework/test_context.py @@ -11,9 +11,7 @@ from mcpgateway.plugins.framework import ( GlobalContext, PluginManager, -) -from mcpgateway.plugins.mcp.entities import ( - HookType, + ToolHookType, ToolPreInvokePayload, ToolPostInvokePayload, ) @@ -28,7 +26,7 @@ async def test_shared_context_across_pre_post_hooks(): # Test tool pre-invoke with transformation - use correct tool name from config tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) assert len(contexts) == 1 context = next(iter(contexts.values())) @@ -45,7 +43,7 @@ async def test_shared_context_across_pre_post_hooks(): # Test tool post-invoke with transformation tool_result_payload = ToolPostInvokePayload(name="test_tool", result={"output": "Result was bad", "status": "wrong format"}) - result, contexts = await manager.invoke_hook(HookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) assert len(contexts) == 1 context = next(iter(contexts.values())) @@ -74,7 +72,7 @@ async def test_shared_context_across_pre_post_hooks_multi_plugins(): # Test tool pre-invoke with transformation - use correct tool name from config tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) assert len(contexts) == 2 ctxs = [contexts[key] for key in contexts.keys()] @@ -103,7 +101,7 @@ async def test_shared_context_across_pre_post_hooks_multi_plugins(): assert result.modified_payload is None # Test tool post-invoke with transformation tool_result_payload = ToolPostInvokePayload(name="test_tool", result={"output": "Result was bad", "status": "wrong format"}) - result, contexts = await manager.invoke_hook(HookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) ctxs = [contexts[key] for key in contexts.keys()] assert len(ctxs) == 2 diff --git a/tests/unit/mcpgateway/plugins/framework/test_errors.py b/tests/unit/mcpgateway/plugins/framework/test_errors.py index d74be9911..738113453 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_errors.py +++ b/tests/unit/mcpgateway/plugins/framework/test_errors.py @@ -16,10 +16,10 @@ PluginError, PluginMode, PluginManager, + PromptHookType, + PromptPrehookPayload ) -from mcpgateway.plugins.mcp.entities import HookType, PromptPrehookPayload - @pytest.mark.asyncio async def test_convert_exception_to_error(): @@ -41,7 +41,7 @@ async def test_error_plugin(): global_context = GlobalContext(request_id="1") escaped_regex = re.escape("ValueError('Sadly! Prompt prefetch is broken!')") with pytest.raises(PluginError, match=escaped_regex): - await plugin_manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, global_context) + await plugin_manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context) await plugin_manager.shutdown() @@ -52,14 +52,14 @@ async def test_error_plugin_raise_error_false(): payload = PromptPrehookPayload(prompt_id="test_prompt", args={"arg0": "This is a crap argument"}) global_context = GlobalContext(request_id="1") with pytest.raises(PluginError): - result, _ = await plugin_manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, global_context) + result, _ = await plugin_manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context) # assert result.continue_processing # assert not result.modified_payload await plugin_manager.shutdown() plugin_manager.config.plugins[0].mode = PluginMode.ENFORCE_IGNORE_ERROR await plugin_manager.initialize() - result, _ = await plugin_manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, global_context) + result, _ = await plugin_manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context) assert result.continue_processing assert not result.modified_payload await plugin_manager.shutdown() diff --git a/tests/unit/mcpgateway/plugins/framework/test_manager.py b/tests/unit/mcpgateway/plugins/framework/test_manager.py index f077f7922..87144d266 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_manager.py +++ b/tests/unit/mcpgateway/plugins/framework/test_manager.py @@ -13,7 +13,7 @@ # First-Party from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework import GlobalContext, PluginManager, PluginViolationError -from mcpgateway.plugins.mcp.entities import HookType, HttpHeaderPayload, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, ToolPreInvokePayload +from mcpgateway.plugins.framework import PromptHookType, ToolHookType, HttpHeaderPayload, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, ToolPreInvokePayload from plugins.regex_filter.search_replace import SearchReplaceConfig @@ -35,7 +35,7 @@ async def test_manager_single_transformer_prompt_plugin(): assert srconfig.words[0].replace == "crud" prompt = PromptPrehookPayload(prompt_id="test_prompt", args={"user": "What a crapshow!"}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) + result, contexts = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) assert len(result.modified_payload.args) == 1 assert result.modified_payload.args["user"] == "What a yikesshow!" @@ -45,7 +45,7 @@ async def test_manager_single_transformer_prompt_plugin(): payload_result = PromptPosthookPayload(prompt_id="test_prompt", result=prompt_result) - result, _ = await manager.invoke_hook(HookType.PROMPT_POST_FETCH, payload_result, global_context=global_context, local_contexts=contexts) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_POST_FETCH, payload_result, global_context=global_context, local_contexts=contexts) assert len(result.modified_payload.result.messages) == 1 assert result.modified_payload.result.messages[0].content.text == "What a yikesshow!" await manager.shutdown() @@ -83,7 +83,7 @@ async def test_manager_multiple_transformer_preprompt_plugin(): prompt = PromptPrehookPayload(prompt_id="test_prompt", args={"user": "It's always happy at the crapshow."}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) + result, contexts = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) assert len(result.modified_payload.args) == 1 assert result.modified_payload.args["user"] == "It's always gleeful at the yikesshow." @@ -93,7 +93,7 @@ async def test_manager_multiple_transformer_preprompt_plugin(): payload_result = PromptPosthookPayload(prompt_id="test_prompt", result=prompt_result) - result, _ = await manager.invoke_hook(HookType.PROMPT_POST_FETCH, payload_result, global_context=global_context, local_contexts=contexts) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_POST_FETCH, payload_result, global_context=global_context, local_contexts=contexts) assert len(result.modified_payload.result.messages) == 1 assert result.modified_payload.result.messages[0].content.text == "It's sullen at the yikes bakery." await manager.shutdown() @@ -106,7 +106,7 @@ async def test_manager_no_plugins(): assert manager.initialized prompt = PromptPrehookPayload(prompt_id="test_prompt", args={"user": "It's always happy at the crapshow."}) global_context = GlobalContext(request_id="1", server_id="2") - result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) assert result.continue_processing assert not result.modified_payload await manager.shutdown() @@ -119,12 +119,12 @@ async def test_manager_filter_plugins(): assert manager.initialized prompt = PromptPrehookPayload(prompt_id="test_prompt", args={"user": "innovative"}) global_context = GlobalContext(request_id="1", server_id="2") - result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) assert not result.continue_processing assert result.violation with pytest.raises(PluginViolationError) as ve: - result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context, violations_as_exceptions=True) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context, violations_as_exceptions=True) assert ve.value.violation assert ve.value.violation.reason == "Prompt not allowed" await manager.shutdown() @@ -137,11 +137,11 @@ async def test_manager_multi_filter_plugins(): assert manager.initialized prompt = PromptPrehookPayload(prompt_id="test_prompt", args={"user": "innovative crapshow."}) global_context = GlobalContext(request_id="1", server_id="2") - result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) assert not result.continue_processing assert result.violation with pytest.raises(PluginViolationError) as ve: - result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context, violations_as_exceptions=True) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context, violations_as_exceptions=True) assert ve.value.violation await manager.shutdown() @@ -156,7 +156,7 @@ async def test_manager_tool_hooks_empty(): # Test tool pre-invoke with no plugins tool_payload = ToolPreInvokePayload(name="calculator", args={"operation": "add", "a": 5, "b": 3}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) # Should continue processing with no modifications assert result.continue_processing @@ -166,7 +166,7 @@ async def test_manager_tool_hooks_empty(): # Test tool post-invoke with no plugins tool_result_payload = ToolPostInvokePayload(name="calculator", result={"result": 8, "status": "success"}) - result, contexts = await manager.invoke_hook(HookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context) # Should continue processing with no modifications assert result.continue_processing @@ -187,7 +187,7 @@ async def test_manager_tool_hooks_with_transformer_plugin(): # Test tool pre-invoke - no plugins configured for tool hooks tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is crap data"}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) # Should continue processing with no modifications (no plugins for tool hooks) assert result.continue_processing @@ -197,7 +197,7 @@ async def test_manager_tool_hooks_with_transformer_plugin(): # Test tool post-invoke - no plugins configured for tool hooks tool_result_payload = ToolPostInvokePayload(name="test_tool", result={"output": "Result with crap in it"}) - result, _ = await manager.invoke_hook(HookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) + result, _ = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) # Should continue processing with no modifications (no plugins for tool hooks) assert result.continue_processing @@ -217,7 +217,7 @@ async def test_manager_tool_hooks_with_actual_plugin(): # Test tool pre-invoke with transformation - use correct tool name from config tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) # Should continue processing with transformations applied assert result.continue_processing @@ -229,7 +229,7 @@ async def test_manager_tool_hooks_with_actual_plugin(): # Test tool post-invoke with transformation tool_result_payload = ToolPostInvokePayload(name="test_tool", result={"output": "Result was bad", "status": "wrong format"}) - result, _ = await manager.invoke_hook(HookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) + result, _ = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, tool_result_payload, global_context=global_context, local_contexts=contexts) # Should continue processing with transformations applied assert result.continue_processing @@ -252,7 +252,7 @@ async def test_manager_tool_hooks_with_header_mods(): # Test tool pre-invoke with transformation - use correct tool name from config tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}, headers=None) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) # Should continue processing with transformations applied assert result.continue_processing @@ -268,7 +268,7 @@ async def test_manager_tool_hooks_with_header_mods(): # Test tool pre-invoke with transformation - use correct tool name from config tool_payload = ToolPreInvokePayload(name="test_tool", args={"input": "This is bad data", "quality": "wrong"}, headers=HttpHeaderPayload({"Content-Type": "application/json"})) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) + result, contexts = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) # Should continue processing with transformations applied assert result.continue_processing diff --git a/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py b/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py index 88091140b..dc037d8c8 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py +++ b/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py @@ -30,11 +30,9 @@ PluginResult, PluginViolation, PluginViolationError, -) - -from mcpgateway.plugins.mcp.entities import ( - HookType, - MCPPlugin, + PromptHookType, + ToolHookType, + Plugin, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, @@ -48,7 +46,7 @@ async def test_manager_timeout_handling(): """Test plugin timeout handling in both enforce and permissive modes.""" # Create a plugin that times out - class TimeoutPlugin(MCPPlugin): + class TimeoutPlugin(Plugin): async def prompt_pre_fetch(self, payload, context): await asyncio.sleep(10) # Longer than timeout return PluginResult(continue_processing=True) @@ -65,7 +63,7 @@ async def prompt_pre_fetch(self, payload, context): timeout_plugin = TimeoutPlugin(plugin_config) with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: - hook_ref = HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(timeout_plugin)) + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(timeout_plugin)) mock_get.return_value = [hook_ref] prompt = PromptPrehookPayload(prompt_id="test", args={}) @@ -73,7 +71,7 @@ async def prompt_pre_fetch(self, payload, context): escaped_regex = re.escape("Plugin TimeoutPlugin exceeded 0.01s timeout") with pytest.raises(PluginError, match=escaped_regex): - result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should pass since fail_on_plugin_error: false # assert result.continue_processing @@ -84,10 +82,10 @@ async def prompt_pre_fetch(self, payload, context): # Test with permissive mode plugin_config.mode = PluginMode.PERMISSIVE with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: - hook_ref = HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(timeout_plugin)) + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(timeout_plugin)) mock_get.return_value = [hook_ref] - result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should continue in permissive mode assert result.continue_processing @@ -101,7 +99,7 @@ async def test_manager_exception_handling(): """Test plugin exception handling in both enforce and permissive modes.""" # Create a plugin that raises an exception - class ErrorPlugin(MCPPlugin): + class ErrorPlugin(Plugin): async def prompt_pre_fetch(self, payload, context): raise RuntimeError("Plugin error!") @@ -115,7 +113,7 @@ async def prompt_pre_fetch(self, payload, context): # Test with enforce mode with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: - hook_ref = HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) mock_get.return_value = [hook_ref] prompt = PromptPrehookPayload(prompt_id="test", args={}) @@ -123,7 +121,7 @@ async def prompt_pre_fetch(self, payload, context): escaped_regex = re.escape("RuntimeError('Plugin error!')") with pytest.raises(PluginError, match=escaped_regex): - result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should block in enforce mode # assert result.continue_processing @@ -134,10 +132,10 @@ async def prompt_pre_fetch(self, payload, context): # Test with permissive mode plugin_config.mode = PluginMode.PERMISSIVE with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: - hook_ref = HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) mock_get.return_value = [hook_ref] - result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should continue in permissive mode assert result.continue_processing @@ -145,10 +143,10 @@ async def prompt_pre_fetch(self, payload, context): plugin_config.mode = PluginMode.ENFORCE_IGNORE_ERROR with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: - hook_ref = HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) mock_get.return_value = [hook_ref] - result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should continue in enforce_ignore_error mode assert result.continue_processing @@ -156,10 +154,10 @@ async def prompt_pre_fetch(self, payload, context): plugin_config.mode = PluginMode.ENFORCE_IGNORE_ERROR with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: - hook_ref = HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) mock_get.return_value = [hook_ref] - result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should continue in enforce_ignore_error mode assert result.continue_processing @@ -167,10 +165,10 @@ async def prompt_pre_fetch(self, payload, context): plugin_config.mode = PluginMode.ENFORCE_IGNORE_ERROR with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: - hook_ref = HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) mock_get.return_value = [hook_ref] - result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should continue in enforce_ignore_error mode assert result.continue_processing @@ -183,7 +181,7 @@ async def prompt_pre_fetch(self, payload, context): # async def test_manager_condition_filtering(): # """Test that plugins are filtered based on conditions.""" -# class ConditionalPlugin(MCPPlugin): +# class ConditionalPlugin(Plugin): # async def prompt_pre_fetch(self, payload, context): # payload.args["modified"] = "yes" # return PluginResult(continue_processing=True, modified_payload=payload) @@ -236,11 +234,11 @@ async def prompt_pre_fetch(self, payload, context): async def test_manager_metadata_aggregation(): """Test metadata aggregation from multiple plugins.""" - class MetadataPlugin1(MCPPlugin): + class MetadataPlugin1(Plugin): async def prompt_pre_fetch(self, payload, context): return PluginResult(continue_processing=True, metadata={"plugin1": "data1", "shared": "value1"}) - class MetadataPlugin2(MCPPlugin): + class MetadataPlugin2(Plugin): async def prompt_pre_fetch(self, payload, context): return PluginResult( continue_processing=True, @@ -256,13 +254,13 @@ async def prompt_pre_fetch(self, payload, context): plugin2 = MetadataPlugin2(config2) with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: - refs = [HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(plugin1)), HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(plugin2))] + refs = [HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(plugin1)), HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(plugin2))] mock_get.return_value = refs prompt = PromptPrehookPayload(prompt_id="test", args={}) global_context = GlobalContext(request_id="1") - result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should aggregate metadata assert result.continue_processing @@ -277,7 +275,7 @@ async def prompt_pre_fetch(self, payload, context): async def test_manager_local_context_persistence(): """Test that local contexts persist across hook calls.""" - class StatefulPlugin(MCPPlugin): + class StatefulPlugin(Plugin): async def prompt_pre_fetch(self, payload, context: PluginContext): context.state["counter"] = context.state.get("counter", 0) + 1 return PluginResult(continue_processing=True) @@ -298,13 +296,13 @@ async def prompt_post_fetch(self, payload, context: PluginContext): # Create a single PluginRef to ensure the same UUID is used for both hooks plugin_ref = PluginRef(plugin) - hook_ref_pre = HookRef(HookType.PROMPT_PRE_FETCH, plugin_ref) - hook_ref_post = HookRef(HookType.PROMPT_POST_FETCH, plugin_ref) + hook_ref_pre = HookRef(PromptHookType.PROMPT_PRE_FETCH, plugin_ref) + hook_ref_post = HookRef(PromptHookType.PROMPT_POST_FETCH, plugin_ref) def get_hook_refs_side_effect(hook_type): - if hook_type == HookType.PROMPT_PRE_FETCH: + if hook_type == PromptHookType.PROMPT_PRE_FETCH: return [hook_ref_pre] - elif hook_type == HookType.PROMPT_POST_FETCH: + elif hook_type == PromptHookType.PROMPT_POST_FETCH: return [hook_ref_post] return [] @@ -314,7 +312,7 @@ def get_hook_refs_side_effect(hook_type): prompt = PromptPrehookPayload(prompt_id="test", args={}) global_context = GlobalContext(request_id="1") - result_pre, contexts = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) + result_pre, contexts = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) assert result_pre.continue_processing # Call to post_fetch with same contexts @@ -322,7 +320,7 @@ def get_hook_refs_side_effect(hook_type): prompt_result = PromptResult(messages=[message]) post_payload = PromptPosthookPayload(prompt_id="test", result=prompt_result) - result_post, _ = await manager.invoke_hook(HookType.PROMPT_POST_FETCH, post_payload, global_context=global_context, local_contexts=contexts) + result_post, _ = await manager.invoke_hook(PromptHookType.PROMPT_POST_FETCH, post_payload, global_context=global_context, local_contexts=contexts) # Should have modified with persisted state assert result_post.continue_processing @@ -336,7 +334,7 @@ def get_hook_refs_side_effect(hook_type): async def test_manager_plugin_blocking(): """Test plugin blocking behavior in enforce mode.""" - class BlockingPlugin(MCPPlugin): + class BlockingPlugin(Plugin): async def prompt_pre_fetch(self, payload, context): violation = PluginViolation(reason="Content violation", description="Blocked content detected", code="CONTENT_BLOCKED", details={"content": payload.args}) return PluginResult(continue_processing=False, violation=violation) @@ -350,13 +348,13 @@ async def prompt_pre_fetch(self, payload, context): plugin = BlockingPlugin(config) with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: - hook_ref = HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(plugin)) + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(plugin)) mock_get.return_value = [hook_ref] prompt = PromptPrehookPayload(prompt_id="test", args={"text": "bad content"}) global_context = GlobalContext(request_id="1") - result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should block the request assert not result.continue_processing @@ -365,7 +363,7 @@ async def prompt_pre_fetch(self, payload, context): assert result.violation.plugin_name == "BlockingPlugin" with pytest.raises(PluginViolationError) as pve: - result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context, violations_as_exceptions=True) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context, violations_as_exceptions=True) assert pve.value.violation assert pve.value.message assert pve.value.violation.code == "CONTENT_BLOCKED" @@ -377,7 +375,7 @@ async def prompt_pre_fetch(self, payload, context): async def test_manager_plugin_permissive_blocking(): """Test plugin behavior when blocking in permissive mode.""" - class BlockingPlugin(MCPPlugin): + class BlockingPlugin(Plugin): async def prompt_pre_fetch(self, payload, context): violation = PluginViolation(reason="Would block", description="Content would be blocked", code="WOULD_BLOCK") return PluginResult(continue_processing=False, violation=violation) @@ -400,13 +398,13 @@ async def prompt_pre_fetch(self, payload, context): # Test permissive mode blocking (covers lines 194-195) with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: - hook_ref = HookRef(HookType.PROMPT_PRE_FETCH, PluginRef(plugin)) + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(plugin)) mock_get.return_value = [hook_ref] prompt = PromptPrehookPayload(prompt_id="test", args={"text": "content"}) global_context = GlobalContext(request_id="1") - result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) # Should continue in permissive mode - the permissive logic continues without blocking assert result.continue_processing @@ -446,7 +444,7 @@ async def test_manager_payload_size_validation(): """Test payload size validation functionality.""" # First-Party from mcpgateway.plugins.framework.manager import MAX_PAYLOAD_SIZE, PayloadSizeError, PluginExecutor - from mcpgateway.plugins.mcp.entities import PromptPosthookPayload, PromptPrehookPayload + from mcpgateway.plugins.framework import PromptPosthookPayload, PromptPrehookPayload # Test payload size validation directly on executor (covers lines 252, 258) executor = PluginExecutor() @@ -504,7 +502,7 @@ async def test_manager_initialization_edge_cases(): tags=["test"], kind="nonexistent.Plugin", mode=PluginMode.ENFORCE, - hooks=[HookType.PROMPT_PRE_FETCH], + hooks=[PromptHookType.PROMPT_PRE_FETCH], config={}, ) ], @@ -528,7 +526,7 @@ async def test_manager_initialization_edge_cases(): tags=["test"], kind="test.Plugin", mode=PluginMode.DISABLED, # Disabled mode - hooks=[HookType.PROMPT_PRE_FETCH], + hooks=[PromptHookType.PROMPT_PRE_FETCH], config={}, ) ], @@ -545,14 +543,13 @@ async def test_base_plugin_coverage(): # First-Party from mcpgateway.common.models import Message, PromptResult, Role, TextContent from mcpgateway.plugins.framework.base import PluginRef - from mcpgateway.plugins.framework.models import ( + from mcpgateway.plugins.framework import ( GlobalContext, PluginConfig, PluginContext, PluginMode, - ) - from mcpgateway.plugins.mcp.entities import ( - HookType, + PromptHookType, + ToolHookType, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, @@ -567,11 +564,11 @@ async def test_base_plugin_coverage(): version="1.0", tags=["test", "coverage"], # Tags to be accessed kind="test.Plugin", - hooks=[HookType.PROMPT_PRE_FETCH], + hooks=[PromptHookType.PROMPT_PRE_FETCH], config={}, ) - plugin = MCPPlugin(config) + plugin = Plugin(config) # Test tags property assert plugin.tags == ["test", "coverage"] @@ -587,7 +584,7 @@ async def test_base_plugin_coverage(): context = PluginContext(global_context=GlobalContext(request_id="test")) payload = PromptPrehookPayload(prompt_id="test", args={}) - with pytest.raises(NotImplementedError, match="'prompt_pre_fetch' not implemented"): + with pytest.raises(AttributeError, match="'Plugin' object has no attribute 'prompt_pre_fetch'"): await plugin.prompt_pre_fetch(payload, context) # Test NotImplementedError for prompt_post_fetch (covers lines 167-171) @@ -595,17 +592,17 @@ async def test_base_plugin_coverage(): result = PromptResult(messages=[message]) post_payload = PromptPosthookPayload(prompt_id="test", result=result) - with pytest.raises(NotImplementedError, match="'prompt_post_fetch' not implemented"): + with pytest.raises(AttributeError, match="'Plugin' object has no attribute 'prompt_post_fetch'"): await plugin.prompt_post_fetch(post_payload, context) # Test default tool_pre_invoke implementation (covers line 191) tool_payload = ToolPreInvokePayload(name="test_tool", args={"key": "value"}) - with pytest.raises(NotImplementedError, match="'tool_pre_invoke' not implemented"): + with pytest.raises(AttributeError, match="'Plugin' object has no attribute 'tool_pre_invoke'"): await plugin.tool_pre_invoke(tool_payload, context) # Test default tool_post_invoke implementation (covers line 211) tool_post_payload = ToolPostInvokePayload(name="test_tool", result={"result": "success"}) - with pytest.raises(NotImplementedError, match="'tool_post_invoke' not implemented"): + with pytest.raises(AttributeError, match="'Plugin' object has no attribute 'tool_post_invoke'"): await plugin.tool_post_invoke(tool_post_payload, context) @@ -651,12 +648,11 @@ async def test_plugin_loader_return_none(): # First-Party from mcpgateway.plugins.framework.loader.plugin import PluginLoader from mcpgateway.plugins.framework import PluginConfig - from mcpgateway.plugins.mcp.entities import HookType loader = PluginLoader() # Test return None when plugin_type is None (covers line 90) - config = PluginConfig(name="TestPlugin", description="Test", author="Test", version="1.0", tags=["test"], kind="test.plugin.TestPlugin", hooks=[HookType.PROMPT_PRE_FETCH], config={}) + config = PluginConfig(name="TestPlugin", description="Test", author="Test", version="1.0", tags=["test"], kind="test.plugin.TestPlugin", hooks=[PromptHookType.PROMPT_PRE_FETCH], config={}) # Mock the plugin_types dict to contain None for this kind loader._plugin_types[config.kind] = None @@ -697,7 +693,7 @@ async def test_manager_compare_function_wrapper(): # The compare function is used internally in _run_plugins # Test by using plugins with conditions - class TestPlugin(MCPPlugin): + class TestPlugin(Plugin): async def tool_pre_invoke(self, payload, context): return PluginResult(continue_processing=True) @@ -715,19 +711,19 @@ async def tool_pre_invoke(self, payload, context): plugin = TestPlugin(config) with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: - hook_ref = HookRef(HookType.TOOL_PRE_INVOKE, PluginRef(plugin)) + hook_ref = HookRef(ToolHookType.TOOL_PRE_INVOKE, PluginRef(plugin)) mock_get.return_value = [hook_ref] # Test with matching tool tool_payload = ToolPreInvokePayload(name="calculator", args={}) global_context = GlobalContext(request_id="1") - result, _ = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) + result, _ = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) assert result.continue_processing # Test with non-matching tool tool_payload2 = ToolPreInvokePayload(name="other_tool", args={}) - result2, _ = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload2, global_context=global_context) + result2, _ = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload2, global_context=global_context) assert result2.continue_processing await manager.shutdown() @@ -739,7 +735,7 @@ async def test_manager_tool_post_invoke_coverage(): manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") await manager.initialize() - class ModifyingPlugin(MCPPlugin): + class ModifyingPlugin(Plugin): async def tool_post_invoke(self, payload, context): payload.result["modified"] = True return PluginResult(continue_processing=True, modified_payload=payload) @@ -748,13 +744,13 @@ async def tool_post_invoke(self, payload, context): plugin = ModifyingPlugin(config) with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: - hook_ref = HookRef(HookType.TOOL_POST_INVOKE, PluginRef(plugin)) + hook_ref = HookRef(ToolHookType.TOOL_POST_INVOKE, PluginRef(plugin)) mock_get.return_value = [hook_ref] tool_payload = ToolPostInvokePayload(name="test_tool", result={"original": "data"}) global_context = GlobalContext(request_id="1") - result, _ = await manager.invoke_hook(HookType.TOOL_POST_INVOKE, tool_payload, global_context=global_context) + result, _ = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, tool_payload, global_context=global_context) assert result.continue_processing assert result.modified_payload is not None diff --git a/tests/unit/mcpgateway/plugins/framework/test_registry.py b/tests/unit/mcpgateway/plugins/framework/test_registry.py index 16daa86b1..64fa9e009 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_registry.py +++ b/tests/unit/mcpgateway/plugins/framework/test_registry.py @@ -16,9 +16,9 @@ # First-Party from mcpgateway.plugins.framework.loader.config import ConfigLoader from mcpgateway.plugins.framework.loader.plugin import PluginLoader -from mcpgateway.plugins.framework import PluginConfig -from mcpgateway.plugins.mcp.entities import HookType, MCPPlugin +from mcpgateway.plugins.framework import PluginConfig, Plugin, PromptHookType, ToolHookType from mcpgateway.plugins.framework.registry import PluginInstanceRegistry +from tests.unit.mcpgateway.plugins.fixtures.plugins.simple import SimplePromptPlugin @pytest.mark.asyncio @@ -78,7 +78,7 @@ async def test_registry_priority_sorting(): version="1.0", tags=["test"], kind="test.Plugin", - hooks=[HookType.PROMPT_PRE_FETCH], + hooks=[PromptHookType.PROMPT_PRE_FETCH], priority=300, # High number = low priority config={}, ) @@ -90,27 +90,27 @@ async def test_registry_priority_sorting(): version="1.0", tags=["test"], kind="test.Plugin", - hooks=[HookType.PROMPT_PRE_FETCH], + hooks=[PromptHookType.PROMPT_PRE_FETCH], priority=50, # Low number = high priority config={}, ) # Create plugin instances - low_priority_plugin = MCPPlugin(low_priority_config) - high_priority_plugin = MCPPlugin(high_priority_config) + low_priority_plugin = SimplePromptPlugin(low_priority_config) + high_priority_plugin = SimplePromptPlugin(high_priority_config) # Register plugins in reverse priority order registry.register(low_priority_plugin) registry.register(high_priority_plugin) # Get plugins for hook - should be sorted by priority (lines 131-134) - hook_plugins = registry.get_hook_refs_for_hook(HookType.PROMPT_PRE_FETCH) + hook_plugins = registry.get_hook_refs_for_hook(PromptHookType.PROMPT_PRE_FETCH) assert len(hook_plugins) == 2 assert hook_plugins[0].plugin_ref.name == "HighPriority" # Lower number = higher priority assert hook_plugins[1].plugin_ref.name == "LowPriority" # Test priority cache - calling again should use cached result - cached_plugins = registry.get_hook_refs_for_hook(HookType.PROMPT_PRE_FETCH) + cached_plugins = registry.get_hook_refs_for_hook(PromptHookType.PROMPT_PRE_FETCH) assert cached_plugins == hook_plugins # Clean up @@ -126,23 +126,23 @@ async def test_registry_hook_filtering(): # Create plugin with specific hooks pre_fetch_config = PluginConfig( - name="PreFetchPlugin", description="Pre-fetch plugin", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[HookType.PROMPT_PRE_FETCH], config={} + name="PreFetchPlugin", description="Pre-fetch plugin", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[PromptHookType.PROMPT_PRE_FETCH], config={} ) post_fetch_config = PluginConfig( - name="PostFetchPlugin", description="Post-fetch plugin", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[HookType.PROMPT_POST_FETCH], config={} + name="PostFetchPlugin", description="Post-fetch plugin", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[PromptHookType.PROMPT_POST_FETCH], config={} ) - pre_fetch_plugin = MCPPlugin(pre_fetch_config) - post_fetch_plugin = MCPPlugin(post_fetch_config) + pre_fetch_plugin = SimplePromptPlugin(pre_fetch_config) + post_fetch_plugin = SimplePromptPlugin(post_fetch_config) registry.register(pre_fetch_plugin) registry.register(post_fetch_plugin) # Test hook filtering - pre_plugins = registry.get_hook_refs_for_hook(HookType.PROMPT_PRE_FETCH) - post_plugins = registry.get_hook_refs_for_hook(HookType.PROMPT_POST_FETCH) - tool_plugins = registry.get_hook_refs_for_hook(HookType.TOOL_PRE_INVOKE) + pre_plugins = registry.get_hook_refs_for_hook(PromptHookType.PROMPT_PRE_FETCH) + post_plugins = registry.get_hook_refs_for_hook(PromptHookType.PROMPT_POST_FETCH) + tool_plugins = registry.get_hook_refs_for_hook(ToolHookType.TOOL_PRE_INVOKE) assert len(pre_plugins) == 1 assert pre_plugins[0].plugin_ref.name == "PreFetchPlugin" @@ -163,9 +163,9 @@ async def test_registry_shutdown(): registry = PluginInstanceRegistry() # Create mock plugins with shutdown methods - mock_plugin1 = MCPPlugin(PluginConfig(name="Plugin1", description="Test plugin 1", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[HookType.PROMPT_PRE_FETCH], config={})) + mock_plugin1 = SimplePromptPlugin(PluginConfig(name="Plugin1", description="Test plugin 1", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[PromptHookType.PROMPT_PRE_FETCH], config={})) - mock_plugin2 = MCPPlugin(PluginConfig(name="Plugin2", description="Test plugin 2", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[HookType.PROMPT_POST_FETCH], config={})) + mock_plugin2 = SimplePromptPlugin(PluginConfig(name="Plugin2", description="Test plugin 2", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[PromptHookType.PROMPT_POST_FETCH], config={})) # Mock the shutdown methods mock_plugin1.shutdown = AsyncMock() @@ -196,8 +196,8 @@ async def test_registry_shutdown_with_error(): registry = PluginInstanceRegistry() # Create mock plugin that fails during shutdown - failing_plugin = MCPPlugin( - PluginConfig(name="FailingPlugin", description="Plugin that fails shutdown", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[HookType.PROMPT_PRE_FETCH], config={}) + failing_plugin = SimplePromptPlugin( + PluginConfig(name="FailingPlugin", description="Plugin that fails shutdown", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[PromptHookType.PROMPT_PRE_FETCH], config={}) ) # Mock shutdown to raise an exception @@ -232,7 +232,7 @@ async def test_registry_edge_cases(): assert registry.plugin_count == 0 # Test getting hooks for empty registry - empty_hooks = registry.get_hook_refs_for_hook(HookType.PROMPT_PRE_FETCH) + empty_hooks = registry.get_hook_refs_for_hook(PromptHookType.PROMPT_PRE_FETCH) assert len(empty_hooks) == 0 # Test get_all_plugins when empty @@ -244,23 +244,23 @@ async def test_registry_cache_invalidation(): """Test that priority cache is invalidated correctly.""" registry = PluginInstanceRegistry() - plugin_config = PluginConfig(name="TestPlugin", description="Test plugin", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[HookType.PROMPT_PRE_FETCH], config={}) + plugin_config = PluginConfig(name="TestPlugin", description="Test plugin", author="Test", version="1.0", tags=["test"], kind="test.Plugin", hooks=[PromptHookType.PROMPT_PRE_FETCH], config={}) - plugin = MCPPlugin(plugin_config) + plugin = SimplePromptPlugin(plugin_config) # Register plugin registry.register(plugin) # Get plugins for hook (populates cache) - hooks1 = registry.get_hook_refs_for_hook(HookType.PROMPT_PRE_FETCH) + hooks1 = registry.get_hook_refs_for_hook(PromptHookType.PROMPT_PRE_FETCH) assert len(hooks1) == 1 # Cache should be populated - assert HookType.PROMPT_PRE_FETCH in registry._priority_cache + assert PromptHookType.PROMPT_PRE_FETCH in registry._priority_cache # Unregister plugin (should invalidate cache) registry.unregister("TestPlugin") # Cache should be cleared for this hook type - hooks2 = registry.get_hook_refs_for_hook(HookType.PROMPT_PRE_FETCH) + hooks2 = registry.get_hook_refs_for_hook(PromptHookType.PROMPT_PRE_FETCH) assert len(hooks2) == 0 diff --git a/tests/unit/mcpgateway/plugins/framework/test_resource_hooks.py b/tests/unit/mcpgateway/plugins/framework/test_resource_hooks.py index b120b0a75..b783ec45f 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_resource_hooks.py +++ b/tests/unit/mcpgateway/plugins/framework/test_resource_hooks.py @@ -27,10 +27,8 @@ PluginManager, PluginMode, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - HookType, - MCPPlugin, + ResourceHookType, + Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchPayload, @@ -64,14 +62,14 @@ async def test_plugin_resource_pre_fetch_default(self): author="test", kind="test.Plugin", version="1.0.0", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], tags=["test"], ) - plugin = MCPPlugin(config) + plugin = Plugin(config) payload = ResourcePreFetchPayload(uri="file:///test.txt", metadata={}) context = PluginContext(global_context=GlobalContext(request_id="test-123")) - with pytest.raises(NotImplementedError, match="'resource_pre_fetch' not implemented"): + with pytest.raises(AttributeError, match="'Plugin' object has no attribute 'resource_pre_fetch'"): await plugin.resource_pre_fetch(payload, context) @pytest.mark.asyncio @@ -83,22 +81,22 @@ async def test_plugin_resource_post_fetch_default(self): author="test", kind="test.Plugin", version="1.0.0", - hooks=[HookType.RESOURCE_POST_FETCH], + hooks=[ResourceHookType.RESOURCE_POST_FETCH], tags=["test"], ) - plugin = MCPPlugin(config) + plugin = Plugin(config) content = ResourceContent(type="resource", id="123",uri="file:///test.txt", text="Test content") payload = ResourcePostFetchPayload(uri="file:///test.txt", content=content) context = PluginContext(global_context=GlobalContext(request_id="test-123")) - with pytest.raises(NotImplementedError, match="'resource_post_fetch' not implemented"): + with pytest.raises(AttributeError, match="'Plugin' object has no attribute 'resource_post_fetch'"): await plugin.resource_post_fetch(payload, context) @pytest.mark.asyncio async def test_resource_hook_blocking(self): """Test resource hook that blocks processing.""" - class BlockingResourcePlugin(MCPPlugin): + class BlockingResourcePlugin(Plugin): async def resource_pre_fetch(self, payload, context): return ResourcePreFetchResult( continue_processing=False, @@ -116,7 +114,7 @@ async def resource_pre_fetch(self, payload, context): author="test", kind="test.BlockingPlugin", version="1.0.0", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], tags=["test"], mode=PluginMode.ENFORCE, ) @@ -135,7 +133,7 @@ async def resource_pre_fetch(self, payload, context): async def test_resource_content_modification(self): """Test resource post-fetch content modification.""" - class ContentFilterPlugin(MCPPlugin): + class ContentFilterPlugin(Plugin): async def resource_post_fetch(self, payload, context): # Modify content to redact sensitive data modified_text = payload.content.text.replace("password: secret123", "password: [REDACTED]") @@ -160,7 +158,7 @@ async def resource_post_fetch(self, payload, context): author="test", kind="test.FilterPlugin", version="1.0.0", - hooks=[HookType.RESOURCE_POST_FETCH], + hooks=[ResourceHookType.RESOURCE_POST_FETCH], tags=["filter"], ) plugin = ContentFilterPlugin(config) @@ -184,7 +182,7 @@ async def resource_post_fetch(self, payload, context): async def test_resource_hook_with_conditions(self): """Test resource hooks with conditions.""" - class ConditionalResourcePlugin(MCPPlugin): + class ConditionalResourcePlugin(Plugin): async def resource_pre_fetch(self, payload, context): # Only process if conditions match return ResourcePreFetchResult( @@ -201,7 +199,7 @@ async def resource_pre_fetch(self, payload, context): author="test", kind="test.ConditionalPlugin", version="1.0.0", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], tags=["conditional"], conditions=[ PluginCondition( @@ -276,10 +274,10 @@ async def test_manager_resource_pre_fetch(self): payload = ResourcePreFetchPayload(uri="test://resource", metadata={}) global_context = GlobalContext(request_id="test-123") - result, contexts = await manager.invoke_hook(HookType.RESOURCE_PRE_FETCH, payload, global_context) + result, contexts = await manager.invoke_hook(ResourceHookType.RESOURCE_PRE_FETCH, payload, global_context) assert result.continue_processing is True - MockRegistry.return_value.get_hook_refs_for_hook.assert_called_with(hook_type=HookType.RESOURCE_PRE_FETCH) + MockRegistry.return_value.get_hook_refs_for_hook.assert_called_with(hook_type=ResourceHookType.RESOURCE_PRE_FETCH) @pytest.mark.asyncio async def test_manager_resource_post_fetch(self): @@ -287,7 +285,7 @@ async def test_manager_resource_post_fetch(self): # First-Party from mcpgateway.plugins.framework.base import HookRef - class TestResourcePlugin(MCPPlugin): + class TestResourcePlugin(Plugin): async def resource_post_fetch(self, payload, context): return ResourcePostFetchResult( continue_processing=True, @@ -300,13 +298,13 @@ async def resource_post_fetch(self, payload, context): author="test", kind="test.Plugin", version="1.0.0", - hooks=[HookType.RESOURCE_POST_FETCH], + hooks=[ResourceHookType.RESOURCE_POST_FETCH], tags=["test"], mode=PluginMode.ENFORCE, ) plugin = TestResourcePlugin(config) plugin_ref = PluginRef(plugin) - hook_ref = HookRef(HookType.RESOURCE_POST_FETCH, plugin_ref) + hook_ref = HookRef(ResourceHookType.RESOURCE_POST_FETCH, plugin_ref) manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") await manager.initialize() @@ -316,10 +314,10 @@ async def resource_post_fetch(self, payload, context): payload = ResourcePostFetchPayload(uri="test://resource", content=content) global_context = GlobalContext(request_id="test-123") - result, contexts = await manager.invoke_hook(HookType.RESOURCE_POST_FETCH, payload, global_context, {}) + result, contexts = await manager.invoke_hook(ResourceHookType.RESOURCE_POST_FETCH, payload, global_context, {}) assert result.continue_processing is True - manager._registry.get_hook_refs_for_hook.assert_called_with(hook_type=HookType.RESOURCE_POST_FETCH) + manager._registry.get_hook_refs_for_hook.assert_called_with(hook_type=ResourceHookType.RESOURCE_POST_FETCH) await manager.shutdown() @@ -327,7 +325,7 @@ async def resource_post_fetch(self, payload, context): async def test_resource_hook_chain_execution(self): """Test multiple resource plugins executing in priority order.""" - class FirstPlugin(MCPPlugin): + class FirstPlugin(Plugin): async def resource_pre_fetch(self, payload, context): # Add metadata payload.metadata["first"] = True @@ -336,7 +334,7 @@ async def resource_pre_fetch(self, payload, context): modified_payload=payload, ) - class SecondPlugin(MCPPlugin): + class SecondPlugin(Plugin): async def resource_pre_fetch(self, payload, context): # Check first plugin ran assert payload.metadata.get("first") is True @@ -352,7 +350,7 @@ async def resource_pre_fetch(self, payload, context): author="test", kind="test.First", version="1.0.0", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], tags=["test"], priority=10, # Higher priority ) @@ -362,7 +360,7 @@ async def resource_pre_fetch(self, payload, context): author="test", kind="test.Second", version="1.0.0", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], tags=["test"], priority=20, # Lower priority ) @@ -383,7 +381,7 @@ async def test_resource_hook_error_handling(self): # First-Party from mcpgateway.plugins.framework.base import HookRef - class ErrorPlugin(MCPPlugin): + class ErrorPlugin(Plugin): async def resource_pre_fetch(self, payload, context): raise ValueError("Test error in plugin") @@ -393,13 +391,13 @@ async def resource_pre_fetch(self, payload, context): author="test", kind="test.ErrorPlugin", version="1.0.0", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], tags=["test"], mode=PluginMode.PERMISSIVE, # Should continue on error ) plugin = ErrorPlugin(config) plugin_ref = PluginRef(plugin) - hook_ref = HookRef(HookType.RESOURCE_PRE_FETCH, plugin_ref) + hook_ref = HookRef(ResourceHookType.RESOURCE_PRE_FETCH, plugin_ref) manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") await manager.initialize() @@ -409,14 +407,14 @@ async def resource_pre_fetch(self, payload, context): # Test with permissive mode - should handle error gracefully with patch.object(manager._registry, "get_hook_refs_for_hook", return_value=[hook_ref]): - result, contexts = await manager.invoke_hook(HookType.RESOURCE_PRE_FETCH, payload, global_context) + result, contexts = await manager.invoke_hook(ResourceHookType.RESOURCE_PRE_FETCH, payload, global_context) assert result.continue_processing is True # Continues despite error # Test with enforce mode - should raise PluginError config.mode = PluginMode.ENFORCE with patch.object(manager._registry, "get_hook_refs_for_hook", return_value=[hook_ref]): with pytest.raises(PluginError): - result, contexts = await manager.invoke_hook(HookType.RESOURCE_PRE_FETCH, payload, global_context) + result, contexts = await manager.invoke_hook(ResourceHookType.RESOURCE_PRE_FETCH, payload, global_context) await manager.shutdown() @@ -424,7 +422,7 @@ async def resource_pre_fetch(self, payload, context): async def test_resource_uri_modification(self): """Test resource URI modification in pre-fetch.""" - class URIModifierPlugin(MCPPlugin): + class URIModifierPlugin(Plugin): async def resource_pre_fetch(self, payload, context): # Modify URI to add prefix modified_payload = ResourcePreFetchPayload( @@ -442,7 +440,7 @@ async def resource_pre_fetch(self, payload, context): author="test", kind="test.URIModifier", version="1.0.0", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], tags=["modifier"], ) plugin = URIModifierPlugin(config) @@ -459,7 +457,7 @@ async def resource_pre_fetch(self, payload, context): async def test_resource_metadata_enrichment(self): """Test resource metadata enrichment in pre-fetch.""" - class MetadataEnricherPlugin(MCPPlugin): + class MetadataEnricherPlugin(Plugin): async def resource_pre_fetch(self, payload, context): # Add metadata payload.metadata["timestamp"] = "2024-01-01T00:00:00Z" @@ -476,7 +474,7 @@ async def resource_pre_fetch(self, payload, context): author="test", kind="test.Enricher", version="1.0.0", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], tags=["enricher"], ) plugin = MetadataEnricherPlugin(config) diff --git a/tests/unit/mcpgateway/plugins/plugins/altk_json_processor/test_json_processor.py b/tests/unit/mcpgateway/plugins/plugins/altk_json_processor/test_json_processor.py index 7fb6fa5a3..c230550ad 100644 --- a/tests/unit/mcpgateway/plugins/plugins/altk_json_processor/test_json_processor.py +++ b/tests/unit/mcpgateway/plugins/plugins/altk_json_processor/test_json_processor.py @@ -18,9 +18,7 @@ GlobalContext, PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - HookType, + ToolHookType, ToolPostInvokePayload, ) @@ -41,7 +39,7 @@ async def test_threshold(): plugin = ALTKJsonProcessor( # type: ignore PluginConfig( - name="jsonprocessor", kind="plugins.altk_json_processor.json_processor.ALTKJsonProcessor", hooks=[HookType.TOOL_POST_INVOKE], config={"llm_provider": "pytestmock", "length_threshold": 50} + name="jsonprocessor", kind="plugins.altk_json_processor.json_processor.ALTKJsonProcessor", hooks=[ToolHookType.TOOL_POST_INVOKE], config={"llm_provider": "pytestmock", "length_threshold": 50} ) ) ctx = PluginContext(global_context=GlobalContext(request_id="r1")) diff --git a/tests/unit/mcpgateway/plugins/plugins/argument_normalizer/test_argument_normalizer.py b/tests/unit/mcpgateway/plugins/plugins/argument_normalizer/test_argument_normalizer.py index 022ad5dff..1f9d1db6d 100644 --- a/tests/unit/mcpgateway/plugins/plugins/argument_normalizer/test_argument_normalizer.py +++ b/tests/unit/mcpgateway/plugins/plugins/argument_normalizer/test_argument_normalizer.py @@ -15,9 +15,8 @@ GlobalContext, PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - HookType, + PromptHookType, + ToolHookType, PromptPrehookPayload, ToolPreInvokePayload, ) @@ -32,7 +31,7 @@ def _mk_plugin(config: dict | None = None) -> ArgumentNormalizerPlugin: cfg = PluginConfig( name="arg_norm", kind="plugins.argument_normalizer.argument_normalizer.ArgumentNormalizerPlugin", - hooks=[HookType.PROMPT_PRE_FETCH, HookType.TOOL_PRE_INVOKE], + hooks=[PromptHookType.PROMPT_PRE_FETCH, ToolHookType.TOOL_PRE_INVOKE], priority=30, config=config or {}, ) diff --git a/tests/unit/mcpgateway/plugins/plugins/cached_tool_result/test_cached_tool_result.py b/tests/unit/mcpgateway/plugins/plugins/cached_tool_result/test_cached_tool_result.py index 631e3c8f2..6025a302b 100644 --- a/tests/unit/mcpgateway/plugins/plugins/cached_tool_result/test_cached_tool_result.py +++ b/tests/unit/mcpgateway/plugins/plugins/cached_tool_result/test_cached_tool_result.py @@ -9,14 +9,11 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, PluginConfig, PluginContext, -) - -from mcpgateway.plugins.mcp.entities import ( - HookType, + ToolHookType, ToolPreInvokePayload, ToolPostInvokePayload, ) @@ -29,7 +26,7 @@ async def test_cache_store_and_hit(): PluginConfig( name="cache", kind="plugins.cached_tool_result.cached_tool_result.CachedToolResultPlugin", - hooks=[HookType.TOOL_PRE_INVOKE, HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_PRE_INVOKE, ToolHookType.TOOL_POST_INVOKE], config={"cacheable_tools": ["echo"], "ttl": 60}, ) ) diff --git a/tests/unit/mcpgateway/plugins/plugins/code_safety_linter/test_code_safety_linter.py b/tests/unit/mcpgateway/plugins/plugins/code_safety_linter/test_code_safety_linter.py index be3577281..8429d587d 100644 --- a/tests/unit/mcpgateway/plugins/plugins/code_safety_linter/test_code_safety_linter.py +++ b/tests/unit/mcpgateway/plugins/plugins/code_safety_linter/test_code_safety_linter.py @@ -9,13 +9,11 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - HookType, + ToolHookType, ToolPostInvokePayload, ) from plugins.code_safety_linter.code_safety_linter import CodeSafetyLinterPlugin @@ -27,7 +25,7 @@ async def test_detects_eval_pattern(): PluginConfig( name="csl", kind="plugins.code_safety_linter.code_safety_linter.CodeSafetyLinterPlugin", - hooks=[HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_POST_INVOKE], ) ) ctx = PluginContext(global_context=GlobalContext(request_id="r1")) diff --git a/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation.py b/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation.py index 70b1b58a5..6cb5a349a 100644 --- a/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation.py +++ b/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation.py @@ -16,9 +16,8 @@ PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - HookType, + PromptHookType, + ToolHookType, PromptPrehookPayload, ToolPreInvokePayload, ToolPostInvokePayload, @@ -65,7 +64,7 @@ def _create_plugin(config_dict=None) -> ContentModerationPlugin: PluginConfig( name="content_moderation_test", kind="plugins.content_moderation.content_moderation.ContentModerationPlugin", - hooks=[HookType.PROMPT_PRE_FETCH, HookType.TOOL_PRE_INVOKE], + hooks=[PromptHookType.PROMPT_PRE_FETCH, ToolHookType.TOOL_PRE_INVOKE], config=default_config, ) ) diff --git a/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation_integration.py b/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation_integration.py index 489fca952..8c5202b3a 100644 --- a/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation_integration.py +++ b/tests/unit/mcpgateway/plugins/plugins/content_moderation/test_content_moderation_integration.py @@ -15,8 +15,9 @@ from mcpgateway.plugins.framework.manager import PluginManager from mcpgateway.plugins.framework import GlobalContext -from mcpgateway.plugins.mcp.entities import ( - HookType, +from mcpgateway.plugins.framework import ( + PromptHookType, + ToolHookType, PromptPrehookPayload, ToolPreInvokePayload, ) @@ -112,7 +113,7 @@ async def test_content_moderation_with_manager(): args={"query": "What is the weather like today?"} ) - result, final_context = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, context) + result, final_context = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, context) # Verify result assert result.continue_processing is True @@ -195,7 +196,7 @@ async def test_content_moderation_blocking_harmful_content(): args={"query": "I hate all those people and want them gone"} ) - result, final_context = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, context) + result, final_context = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, context) # Should be blocked due to high hate score assert result.continue_processing is False @@ -271,7 +272,7 @@ async def test_content_moderation_with_granite_fallback(): args={"query": "How to resolve conflicts peacefully"} ) - result, final_context = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, payload, context) + result, final_context = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, payload, context) # Should continue processing (fallback succeeded) assert result.continue_processing is True @@ -352,7 +353,7 @@ async def test_content_moderation_redaction(): args={"query": "This damn thing is not working"} ) - result, final_context = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, context) + result, final_context = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, context) # Should continue processing but with modified content assert result.continue_processing is True @@ -443,7 +444,7 @@ async def test_content_moderation_multiple_providers(): args={"query": "What is machine learning?"} ) - prompt_result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt_payload, context) + prompt_result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt_payload, context) assert prompt_result.continue_processing is True # Test tool (goes to Granite) @@ -452,7 +453,7 @@ async def test_content_moderation_multiple_providers(): args={"query": "How to build AI models"} ) - tool_result, _ = await manager.invoke_hook(HookType.TOOL_PRE_INVOKE, tool_payload, context) + tool_result, _ = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, context) assert tool_result.continue_processing is True # Verify both providers were called diff --git a/tests/unit/mcpgateway/plugins/plugins/external_clamav/test_clamav_remote.py b/tests/unit/mcpgateway/plugins/plugins/external_clamav/test_clamav_remote.py index 2817c7dcc..baadb334d 100644 --- a/tests/unit/mcpgateway/plugins/plugins/external_clamav/test_clamav_remote.py +++ b/tests/unit/mcpgateway/plugins/plugins/external_clamav/test_clamav_remote.py @@ -13,9 +13,7 @@ GlobalContext, PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - HookType, + ResourceHookType, ResourcePostFetchPayload, ResourcePreFetchPayload, ) @@ -32,7 +30,7 @@ def _mk_plugin(block_on_positive: bool = True) -> ClamAVRemotePlugin: cfg = PluginConfig( name="clamav", kind="plugins.external.clamav_server.clamav_plugin.ClamAVRemotePlugin", - hooks=[HookType.RESOURCE_PRE_FETCH, HookType.RESOURCE_POST_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH, ResourceHookType.RESOURCE_POST_FETCH], config={ "mode": "eicar_only", "block_on_positive": block_on_positive, @@ -80,7 +78,7 @@ async def test_non_blocking_mode_reports_metadata(tmp_path): @pytest.mark.asyncio async def test_prompt_post_fetch_blocks_on_eicar_text(): plugin = _mk_plugin(True) - from mcpgateway.plugins.mcp.entities import PromptPosthookPayload + from mcpgateway.plugins.framework import PromptPosthookPayload pr = PromptResult( messages=[ @@ -100,7 +98,7 @@ async def test_prompt_post_fetch_blocks_on_eicar_text(): @pytest.mark.asyncio async def test_tool_post_invoke_blocks_on_eicar_string(): plugin = _mk_plugin(True) - from mcpgateway.plugins.mcp.entities import ToolPostInvokePayload + from mcpgateway.plugins.framework import ToolPostInvokePayload ctx = PluginContext(global_context=GlobalContext(request_id="r5")) payload = ToolPostInvokePayload(name="t", result={"text": EICAR}) @@ -121,7 +119,7 @@ async def test_health_stats_counters(): await plugin.resource_post_fetch(payload_r, ctx) # 2) prompt_post_fetch with EICAR -> attempted +1, infected +1 (total attempted=2, infected=2) - from mcpgateway.plugins.mcp.entities import PromptPosthookPayload + from mcpgateway.plugins.framework import PromptPosthookPayload pr = PromptResult( messages=[ @@ -135,7 +133,7 @@ async def test_health_stats_counters(): await plugin.prompt_post_fetch(payload_p, ctx) # 3) tool_post_invoke with one EICAR and one clean string -> attempted +2, infected +1 - from mcpgateway.plugins.mcp.entities import ToolPostInvokePayload + from mcpgateway.plugins.framework import ToolPostInvokePayload payload_t = ToolPostInvokePayload(name="t", result={"a": EICAR, "b": "clean"}) await plugin.tool_post_invoke(payload_t, ctx) diff --git a/tests/unit/mcpgateway/plugins/plugins/file_type_allowlist/test_file_type_allowlist.py b/tests/unit/mcpgateway/plugins/plugins/file_type_allowlist/test_file_type_allowlist.py index 44b2ade84..82d809c4d 100644 --- a/tests/unit/mcpgateway/plugins/plugins/file_type_allowlist/test_file_type_allowlist.py +++ b/tests/unit/mcpgateway/plugins/plugins/file_type_allowlist/test_file_type_allowlist.py @@ -9,14 +9,11 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, PluginConfig, PluginContext, -) - -from mcpgateway.plugins.mcp.entities import ( - HookType, + ResourceHookType, ResourcePreFetchPayload, ResourcePostFetchPayload, ) @@ -30,7 +27,7 @@ async def test_blocks_disallowed_extension_and_mime(): PluginConfig( name="fta", kind="plugins.file_type_allowlist.file_type_allowlist.FileTypeAllowlistPlugin", - hooks=[HookType.RESOURCE_PRE_FETCH, HookType.RESOURCE_POST_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH, ResourceHookType.RESOURCE_POST_FETCH], config={"allowed_extensions": [".md"], "allowed_mime_types": ["text/markdown"]}, ) ) diff --git a/tests/unit/mcpgateway/plugins/plugins/html_to_markdown/test_html_to_markdown.py b/tests/unit/mcpgateway/plugins/plugins/html_to_markdown/test_html_to_markdown.py index 33bf9fd75..165ea9c67 100644 --- a/tests/unit/mcpgateway/plugins/plugins/html_to_markdown/test_html_to_markdown.py +++ b/tests/unit/mcpgateway/plugins/plugins/html_to_markdown/test_html_to_markdown.py @@ -9,13 +9,11 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - HookType, + ResourceHookType, ResourcePostFetchPayload, ) from mcpgateway.common.models import ResourceContent @@ -28,7 +26,7 @@ async def test_html_to_markdown_transforms_basic_html(): PluginConfig( name="html2md", kind="plugins.html_to_markdown.html_to_markdown.HTMLToMarkdownPlugin", - hooks=[HookType.RESOURCE_POST_FETCH], + hooks=[ResourceHookType.RESOURCE_POST_FETCH], ) ) html = "

Title

Hello link

print('x')
" diff --git a/tests/unit/mcpgateway/plugins/plugins/json_repair/test_json_repair.py b/tests/unit/mcpgateway/plugins/plugins/json_repair/test_json_repair.py index 2be4c4213..07e089d24 100644 --- a/tests/unit/mcpgateway/plugins/plugins/json_repair/test_json_repair.py +++ b/tests/unit/mcpgateway/plugins/plugins/json_repair/test_json_repair.py @@ -10,14 +10,11 @@ import json import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, PluginConfig, PluginContext, -) - -from mcpgateway.plugins.mcp.entities import ( - HookType, + ToolHookType, ToolPostInvokePayload, ) from plugins.json_repair.json_repair import JSONRepairPlugin @@ -29,7 +26,7 @@ async def test_repairs_trailing_commas_and_single_quotes(): PluginConfig( name="jsonr", kind="plugins.json_repair.json_repair.JSONRepairPlugin", - hooks=[HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_POST_INVOKE], ) ) ctx = PluginContext(global_context=GlobalContext(request_id="r1")) diff --git a/tests/unit/mcpgateway/plugins/plugins/markdown_cleaner/test_markdown_cleaner.py b/tests/unit/mcpgateway/plugins/plugins/markdown_cleaner/test_markdown_cleaner.py index b4db80dfa..9f469f0ec 100644 --- a/tests/unit/mcpgateway/plugins/plugins/markdown_cleaner/test_markdown_cleaner.py +++ b/tests/unit/mcpgateway/plugins/plugins/markdown_cleaner/test_markdown_cleaner.py @@ -10,13 +10,11 @@ import pytest from mcpgateway.common.models import Message, PromptResult, TextContent -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - HookType, + PromptHookType, PromptPosthookPayload, ) from plugins.markdown_cleaner.markdown_cleaner import MarkdownCleanerPlugin @@ -28,7 +26,7 @@ async def test_cleans_markdown_prompt(): PluginConfig( name="mdclean", kind="plugins.markdown_cleaner.markdown_cleaner.MarkdownCleanerPlugin", - hooks=[HookType.PROMPT_POST_FETCH], + hooks=[PromptHookType.PROMPT_POST_FETCH], ) ) txt = "#Heading\n\n\n* item\n\n```\n\n```\n" diff --git a/tests/unit/mcpgateway/plugins/plugins/output_length_guard/test_output_length_guard.py b/tests/unit/mcpgateway/plugins/plugins/output_length_guard/test_output_length_guard.py index 884da9828..37e0796e9 100644 --- a/tests/unit/mcpgateway/plugins/plugins/output_length_guard/test_output_length_guard.py +++ b/tests/unit/mcpgateway/plugins/plugins/output_length_guard/test_output_length_guard.py @@ -8,14 +8,11 @@ """ # First-Party -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, PluginConfig, PluginContext, -) - -from mcpgateway.plugins.mcp.entities import ( - HookType, + ToolHookType, ToolPostInvokePayload, ) @@ -30,7 +27,7 @@ def _mk_plugin(config: dict | None = None) -> OutputLengthGuardPlugin: cfg = PluginConfig( name="out_len_guard", kind="plugins.output_length_guard.output_length_guard.OutputLengthGuardPlugin", - hooks=[HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_POST_INVOKE], priority=90, config=config or {}, ) diff --git a/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py b/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py index b0ac9890c..bd4979abd 100644 --- a/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py +++ b/tests/unit/mcpgateway/plugins/plugins/pii_filter/test_pii_filter.py @@ -17,9 +17,7 @@ PluginConfig, PluginContext, PluginMode, -) -from mcpgateway.plugins.mcp.entities import ( - HookType, + PromptHookType, PromptPosthookPayload, PromptPrehookPayload, ) @@ -231,7 +229,7 @@ def plugin_config(self) -> PluginConfig: author="Test", kind="plugins.pii_filter.pii_filter.PIIFilterPlugin", version="1.0", - hooks=[HookType.PROMPT_PRE_FETCH, HookType.PROMPT_POST_FETCH], + hooks=[PromptHookType.PROMPT_PRE_FETCH, PromptHookType.PROMPT_POST_FETCH], tags=["test", "pii"], mode=PluginMode.ENFORCE, priority=10, @@ -416,7 +414,7 @@ async def test_integration_with_manager(): payload = PromptPrehookPayload(prompt_id="test_prompt", args={"input": "Email: test@example.com, SSN: 123-45-6789"}) global_context = GlobalContext(request_id="test-manager") - result, contexts = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, global_context) + result, contexts = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context) # Verify PII was masked assert result.modified_payload is not None diff --git a/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py b/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py index 0f152bb6a..2ee6d0db3 100644 --- a/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py +++ b/tests/unit/mcpgateway/plugins/plugins/rate_limiter/test_rate_limiter.py @@ -9,14 +9,13 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - HookType, + PromptHookType, PromptPrehookPayload, + ToolHookType ) from plugins.rate_limiter.rate_limiter import RateLimiterPlugin @@ -26,7 +25,7 @@ def _mk(rate: str) -> RateLimiterPlugin: PluginConfig( name="rl", kind="plugins.rate_limiter.rate_limiter.RateLimiterPlugin", - hooks=[HookType.PROMPT_PRE_FETCH, HookType.TOOL_PRE_INVOKE], + hooks=[PromptHookType.PROMPT_PRE_FETCH, ToolHookType.TOOL_PRE_INVOKE], config={"by_user": rate}, ) ) diff --git a/tests/unit/mcpgateway/plugins/plugins/resource_filter/test_resource_filter.py b/tests/unit/mcpgateway/plugins/plugins/resource_filter/test_resource_filter.py index a5bac8a43..bbe2032f2 100644 --- a/tests/unit/mcpgateway/plugins/plugins/resource_filter/test_resource_filter.py +++ b/tests/unit/mcpgateway/plugins/plugins/resource_filter/test_resource_filter.py @@ -12,14 +12,12 @@ # First-Party from mcpgateway.common.models import ResourceContent -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, PluginConfig, PluginContext, PluginMode, -) -from mcpgateway.plugins.mcp.entities import ( - HookType, + ResourceHookType, ResourcePostFetchPayload, ResourcePreFetchPayload, ) @@ -38,7 +36,7 @@ def plugin_config(self): author="test", kind="plugins.resource_filter.resource_filter.ResourceFilterPlugin", version="1.0.0", - hooks=[HookType.RESOURCE_PRE_FETCH, HookType.RESOURCE_POST_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH, ResourceHookType.RESOURCE_POST_FETCH], tags=["test", "filter"], mode=PluginMode.ENFORCE, config={ diff --git a/tests/unit/mcpgateway/plugins/plugins/schema_guard/test_schema_guard.py b/tests/unit/mcpgateway/plugins/plugins/schema_guard/test_schema_guard.py index 18c818e2b..0dd8cf008 100644 --- a/tests/unit/mcpgateway/plugins/plugins/schema_guard/test_schema_guard.py +++ b/tests/unit/mcpgateway/plugins/plugins/schema_guard/test_schema_guard.py @@ -9,13 +9,11 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - HookType, + ToolHookType, ToolPreInvokePayload, ToolPostInvokePayload, ) @@ -39,7 +37,7 @@ async def test_schema_guard_valid_and_invalid(): PluginConfig( name="sg", kind="plugins.schema_guard.schema_guard.SchemaGuardPlugin", - hooks=[HookType.TOOL_PRE_INVOKE, HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_PRE_INVOKE, ToolHookType.TOOL_POST_INVOKE], config=cfg, ) ) diff --git a/tests/unit/mcpgateway/plugins/plugins/url_reputation/test_url_reputation.py b/tests/unit/mcpgateway/plugins/plugins/url_reputation/test_url_reputation.py index be9768faf..a8eb15a83 100644 --- a/tests/unit/mcpgateway/plugins/plugins/url_reputation/test_url_reputation.py +++ b/tests/unit/mcpgateway/plugins/plugins/url_reputation/test_url_reputation.py @@ -13,9 +13,7 @@ GlobalContext, PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - HookType, + ResourceHookType, ResourcePreFetchPayload, ) from plugins.url_reputation.url_reputation import URLReputationPlugin @@ -27,7 +25,7 @@ async def test_blocks_blocklisted_domain(): PluginConfig( name="urlrep", kind="plugins.url_reputation.url_reputation.URLReputationPlugin", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], config={"blocked_domains": ["bad.example"]}, ) ) diff --git a/tests/unit/mcpgateway/plugins/plugins/virus_total_checker/test_virus_total_checker.py b/tests/unit/mcpgateway/plugins/plugins/virus_total_checker/test_virus_total_checker.py index a12432057..2e9a04395 100644 --- a/tests/unit/mcpgateway/plugins/plugins/virus_total_checker/test_virus_total_checker.py +++ b/tests/unit/mcpgateway/plugins/plugins/virus_total_checker/test_virus_total_checker.py @@ -13,13 +13,13 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, PluginConfig, PluginContext, -) -from mcpgateway.plugins.mcp.entities import ( - HookType, + PromptHookType, + ResourceHookType, + ToolHookType, ResourcePreFetchPayload, ) @@ -70,7 +70,7 @@ async def test_url_block_on_malicious(tmp_path, monkeypatch): cfg = PluginConfig( name="vt", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], config={ "enabled": True, "check_url": True, @@ -136,7 +136,7 @@ async def test_local_allow_and_deny_overrides(): cfg = PluginConfig( name="vt", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_POST_INVOKE], config={ "enabled": True, "scan_tool_outputs": True, @@ -146,7 +146,7 @@ async def test_local_allow_and_deny_overrides(): plugin = VirusTotalURLCheckerPlugin(cfg) plugin._client_factory = lambda c, h: _StubClient(routes) # type: ignore os.environ["VT_API_KEY"] = "dummy" - from mcpgateway.plugins.mcp.entities import ToolPostInvokePayload + from mcpgateway.plugins.framework import ToolPostInvokePayload payload = ToolPostInvokePayload(name="writer", result=f"See {url}") ctx = PluginContext(global_context=GlobalContext(request_id="r7")) res = await plugin.tool_post_invoke(payload, ctx) @@ -157,7 +157,7 @@ async def test_local_allow_and_deny_overrides(): cfg2 = PluginConfig( name="vt2", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_POST_INVOKE], config={ "enabled": True, "scan_tool_outputs": True, @@ -180,7 +180,7 @@ async def test_override_precedence_allow_over_deny_vs_deny_over_allow(): cfg_allow = PluginConfig( name="vt-allow", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_POST_INVOKE], config={ "enabled": True, "scan_tool_outputs": True, @@ -192,7 +192,7 @@ async def test_override_precedence_allow_over_deny_vs_deny_over_allow(): plugin_allow = VirusTotalURLCheckerPlugin(cfg_allow) plugin_allow._client_factory = lambda c, h: _StubClient({}) # type: ignore os.environ["VT_API_KEY"] = "dummy" - from mcpgateway.plugins.mcp.entities import ToolPostInvokePayload + from mcpgateway.plugins.framework import ToolPostInvokePayload payload = ToolPostInvokePayload(name="writer", result=f"visit {url}") ctx = PluginContext(global_context=GlobalContext(request_id="r8")) res_allow = await plugin_allow.tool_post_invoke(payload, ctx) @@ -202,7 +202,7 @@ async def test_override_precedence_allow_over_deny_vs_deny_over_allow(): cfg_deny = PluginConfig( name="vt-deny", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_POST_INVOKE], config={ "enabled": True, "scan_tool_outputs": True, @@ -223,7 +223,7 @@ async def test_prompt_scan_blocks_on_url(): cfg = PluginConfig( name="vt", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.PROMPT_POST_FETCH], + hooks=[PromptHookType.PROMPT_POST_FETCH], config={ "enabled": True, "scan_prompt_outputs": True, @@ -251,7 +251,7 @@ async def test_prompt_scan_blocks_on_url(): os.environ["VT_API_KEY"] = "dummy" pr = PromptResult(messages=[Message(role="assistant", content=TextContent(type="text", text=f"see {url}"))]) - from mcpgateway.plugins.mcp.entities import PromptPosthookPayload + from mcpgateway.plugins.framework import PromptPosthookPayload payload = PromptPosthookPayload(prompt_id="p", result=pr) ctx = PluginContext(global_context=GlobalContext(request_id="r5")) res = await plugin.prompt_post_fetch(payload, ctx) @@ -264,7 +264,7 @@ async def test_resource_scan_blocks_on_url(): cfg = PluginConfig( name="vt", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.RESOURCE_POST_FETCH], + hooks=[ResourceHookType.RESOURCE_POST_FETCH], config={ "enabled": True, "scan_resource_contents": True, @@ -293,7 +293,7 @@ async def test_resource_scan_blocks_on_url(): from mcpgateway.common.models import ResourceContent rc = ResourceContent(type="resource", id="345",uri="test://x", mime_type="text/plain", text=f"{url} is fishy") - from mcpgateway.plugins.mcp.entities import ResourcePostFetchPayload + from mcpgateway.plugins.framework import ResourcePostFetchPayload payload = ResourcePostFetchPayload(uri="test://x", content=rc) ctx = PluginContext(global_context=GlobalContext(request_id="r6")) res = await plugin.resource_post_fetch(payload, ctx) @@ -311,7 +311,7 @@ async def test_file_hash_lookup_blocks(tmp_path, monkeypatch): cfg = PluginConfig( name="vt", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], config={ "enabled": True, "enable_file_checks": True, @@ -355,7 +355,7 @@ async def test_unknown_file_then_upload_wait_allows_when_clean(tmp_path): cfg = PluginConfig( name="vt", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.RESOURCE_PRE_FETCH], + hooks=[ResourceHookType.RESOURCE_PRE_FETCH], config={ "enabled": True, "enable_file_checks": True, @@ -404,7 +404,7 @@ async def test_tool_output_url_block_and_ratio(): cfg = PluginConfig( name="vt", kind="plugins.virus_total_checker.virus_total_checker.VirusTotalURLCheckerPlugin", - hooks=[HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_POST_INVOKE], config={ "enabled": True, "scan_tool_outputs": True, @@ -435,7 +435,7 @@ async def test_tool_output_url_block_and_ratio(): plugin._client_factory = lambda c, h: _StubClient(routes) # type: ignore os.environ["VT_API_KEY"] = "dummy" - from mcpgateway.plugins.mcp.entities import ToolPostInvokePayload + from mcpgateway.plugins.framework import ToolPostInvokePayload payload = ToolPostInvokePayload(name="writer", result=f"See {url} for details") ctx = PluginContext(global_context=GlobalContext(request_id="r4")) diff --git a/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_integration.py b/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_integration.py index 9eae48c7f..22a353c19 100644 --- a/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_integration.py +++ b/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_integration.py @@ -16,8 +16,10 @@ from mcpgateway.plugins.framework.manager import PluginManager from mcpgateway.plugins.framework import ( GlobalContext, + PromptHookType, + ToolHookType, + ToolPostInvokePayload ) -from mcpgateway.plugins.mcp.entities import HookType, ToolPostInvokePayload @pytest.mark.asyncio @@ -80,7 +82,7 @@ async def test_webhook_plugin_with_manager(): ) # Execute tool post-invoke hook - result, final_context = await manager.invoke_hook(HookType.TOOL_POST_INVOKE, payload, context) + result, final_context = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, payload, context) # Verify result assert result.continue_processing is True @@ -163,14 +165,14 @@ async def test_webhook_plugin_violation_handling(): context = GlobalContext(request_id="violation-test", user="testuser") # Create payload with forbidden word that will trigger deny filter - from mcpgateway.plugins.mcp.entities import PromptPrehookPayload + from mcpgateway.plugins.framework import PromptPrehookPayload payload = PromptPrehookPayload( prompt_id="test_prompt", args={"query": "this contains forbidden word"} ) # Execute - should be blocked by deny filter - result, final_context = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, payload, context) + result, final_context = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, context) # Verify the request was blocked assert result.continue_processing is False @@ -247,7 +249,7 @@ async def test_webhook_plugin_multiple_webhooks(): ) # Execute hook - result, final_context = await manager.invoke_hook(HookType.TOOL_POST_INVOKE, payload, context) + result, final_context = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, payload, context) assert result.continue_processing is True @@ -340,7 +342,7 @@ async def test_webhook_plugin_template_customization(): result={"data": "test"} ) - await manager.invoke_hook(HookType.TOOL_POST_INVOKE, payload, context) + await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, payload, context) # Verify webhook was called with custom template mock_client.post.assert_called_once() diff --git a/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_notification.py b/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_notification.py index 6aceeb285..a05c41b93 100644 --- a/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_notification.py +++ b/tests/unit/mcpgateway/plugins/plugins/webhook_notification/test_webhook_notification.py @@ -11,14 +11,12 @@ import pytest -from mcpgateway.plugins.framework.models import ( +from mcpgateway.plugins.framework import ( GlobalContext, PluginConfig, PluginContext, PluginViolation, -) -from mcpgateway.plugins.mcp.entities import ( - HookType, + ToolHookType, PromptPrehookPayload, ToolPostInvokePayload, ToolPreInvokePayload, @@ -55,7 +53,7 @@ def _create_plugin(config_dict=None) -> WebhookNotificationPlugin: PluginConfig( name="webhook_test", kind="plugins.webhook_notification.webhook_notification.WebhookNotificationPlugin", - hooks=[HookType.TOOL_POST_INVOKE], + hooks=[ToolHookType.TOOL_POST_INVOKE], config=default_config, ) ) @@ -465,7 +463,8 @@ async def test_prompt_pre_and_post_hooks_return_success(self): # Test post-hook with mock notification plugin._notify_webhooks = AsyncMock() - from mcpgateway.plugins.mcp.entities import PromptPosthookPayload, PromptResult + from mcpgateway.plugins.framework import PromptPosthookPayload + from mcpgateway.common.models import PromptResult post_payload = PromptPosthookPayload( prompt_id="test_prompt", result=PromptResult(messages=[]) diff --git a/tests/unit/mcpgateway/services/test_resource_service_plugins.py b/tests/unit/mcpgateway/services/test_resource_service_plugins.py index bb79c9af4..fd3fdf513 100644 --- a/tests/unit/mcpgateway/services/test_resource_service_plugins.py +++ b/tests/unit/mcpgateway/services/test_resource_service_plugins.py @@ -81,7 +81,7 @@ async def test_read_resource_without_plugins(self, resource_service, mock_db): async def test_read_resource_with_pre_fetch_hook(self, resource_service_with_plugins, mock_db): """Test read_resource with pre-fetch hook execution.""" # First-Party - from mcpgateway.plugins.mcp.entities import HookType + from mcpgateway.plugins.framework import ResourceHookType import mcpgateway.services.resource_service as resource_service_mod resource_service_mod.PLUGINS_AVAILABLE = True @@ -113,7 +113,7 @@ async def test_read_resource_with_pre_fetch_hook(self, resource_service_with_plu # Verify context was passed correctly - check first call (pre-fetch) first_call = mock_manager.invoke_hook.call_args_list[0] - assert first_call[0][0] == HookType.RESOURCE_PRE_FETCH # hook_type + assert first_call[0][0] == ResourceHookType.RESOURCE_PRE_FETCH # hook_type assert first_call[0][1].uri == "test://resource" # payload assert first_call[0][2].request_id == "test-123" # global_context assert first_call[0][2].user == "testuser" @@ -161,7 +161,7 @@ async def test_read_resource_uri_modified_by_plugin(self, resource_service_with_ """Test read_resource with URI modification by plugin.""" # First-Party from mcpgateway.plugins.framework.models import PluginResult - from mcpgateway.plugins.mcp.entities import HookType + from mcpgateway.plugins.framework import ResourceHookType service = resource_service_with_plugins mock_manager = service._plugin_manager @@ -184,7 +184,7 @@ async def test_read_resource_uri_modified_by_plugin(self, resource_service_with_ # Use side_effect to return different results based on hook type def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): - if hook_type == HookType.RESOURCE_PRE_FETCH: + if hook_type == ResourceHookType.RESOURCE_PRE_FETCH: return ( PluginResult( continue_processing=True, @@ -214,7 +214,7 @@ async def test_read_resource_content_filtered_by_plugin(self, resource_service_w """Test read_resource with content filtering by post-fetch hook.""" # First-Party from mcpgateway.plugins.framework.models import PluginResult - from mcpgateway.plugins.mcp.entities import HookType + from mcpgateway.plugins.framework import ResourceHookType import mcpgateway.services.resource_service as resource_service_mod resource_service_mod.PLUGINS_AVAILABLE = True @@ -250,7 +250,7 @@ def scalar_one_or_none_side_effect(*args, **kwargs): # Use side_effect to return different results based on hook type def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): - if hook_type == HookType.RESOURCE_PRE_FETCH: + if hook_type == ResourceHookType.RESOURCE_PRE_FETCH: return ( PluginResult(continue_processing=True), {"context": "data"}, @@ -310,7 +310,7 @@ async def test_read_resource_post_fetch_blocking(self, resource_service_with_plu """Test read_resource blocked by post-fetch hook.""" # First-Party from mcpgateway.plugins.framework.models import PluginResult - from mcpgateway.plugins.mcp.entities import HookType + from mcpgateway.plugins.framework import ResourceHookType import mcpgateway.services.resource_service as resource_service_mod resource_service_mod.PLUGINS_AVAILABLE = True @@ -331,7 +331,7 @@ async def test_read_resource_post_fetch_blocking(self, resource_service_with_plu # Use side_effect to allow pre-fetch but block on post-fetch def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): - if hook_type == HookType.RESOURCE_PRE_FETCH: + if hook_type == ResourceHookType.RESOURCE_PRE_FETCH: return ( PluginResult(continue_processing=True), {"context": "data"}, @@ -392,7 +392,7 @@ async def test_read_resource_context_propagation(self, resource_service_with_plu """Test context propagation from pre-fetch to post-fetch.""" # First-Party from mcpgateway.plugins.framework.models import PluginResult - from mcpgateway.plugins.mcp.entities import HookType + from mcpgateway.plugins.framework import ResourceHookType import mcpgateway.services.resource_service as resource_service_mod resource_service_mod.PLUGINS_AVAILABLE = True @@ -416,7 +416,7 @@ async def test_read_resource_context_propagation(self, resource_service_with_plu # Use side_effect to return contexts from pre-fetch def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): - if hook_type == HookType.RESOURCE_PRE_FETCH: + if hook_type == ResourceHookType.RESOURCE_PRE_FETCH: return ( PluginResult(continue_processing=True), test_contexts, diff --git a/tests/unit/mcpgateway/services/test_tool_service.py b/tests/unit/mcpgateway/services/test_tool_service.py index 2504f7984..a795e1438 100644 --- a/tests/unit/mcpgateway/services/test_tool_service.py +++ b/tests/unit/mcpgateway/services/test_tool_service.py @@ -2233,7 +2233,7 @@ async def test_invoke_tool_with_plugin_post_invoke_success(self, tool_service, m """Test invoking tool with successful plugin post-invoke hook.""" # First-Party from mcpgateway.plugins.framework.models import PluginResult - from mcpgateway.plugins.mcp.entities import HookType + from mcpgateway.plugins.framework import ToolHookType # Configure tool as REST mock_tool.integration_type = "REST" @@ -2261,7 +2261,7 @@ async def test_invoke_tool_with_plugin_post_invoke_success(self, tool_service, m tool_service._plugin_manager = Mock() def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): - if hook_type == HookType.TOOL_PRE_INVOKE: + if hook_type == ToolHookType.TOOL_PRE_INVOKE: return (PluginResult(continue_processing=True, violation=None, modified_payload=None), None) # POST_INVOKE return (mock_post_result, None) @@ -2309,13 +2309,12 @@ async def test_invoke_tool_with_plugin_post_invoke_modified_payload(self, tool_s mock_post_result.modified_payload = mock_modified_payload # First-Party - from mcpgateway.plugins.framework.models import PluginResult - from mcpgateway.plugins.mcp.entities import HookType + from mcpgateway.plugins.framework import PluginResult, ToolHookType tool_service._plugin_manager = Mock() def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): - if hook_type == HookType.TOOL_PRE_INVOKE: + if hook_type == ToolHookType.TOOL_PRE_INVOKE: return (PluginResult(continue_processing=True, violation=None, modified_payload=None), None) # POST_INVOKE return (mock_post_result, None) @@ -2364,12 +2363,12 @@ async def test_invoke_tool_with_plugin_post_invoke_invalid_modified_payload(self # First-Party from mcpgateway.plugins.framework.models import PluginResult - from mcpgateway.plugins.mcp.entities import HookType + from mcpgateway.plugins.framework import ToolHookType tool_service._plugin_manager = Mock() def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): - if hook_type == HookType.TOOL_PRE_INVOKE: + if hook_type == ToolHookType.TOOL_PRE_INVOKE: return (PluginResult(continue_processing=True, violation=None, modified_payload=None), None) # POST_INVOKE return (mock_post_result, None) @@ -2410,12 +2409,12 @@ async def test_invoke_tool_with_plugin_post_invoke_error_fail_on_error(self, too # Mock plugin manager with invoke_hook that raises error on POST_INVOKE # First-Party from mcpgateway.plugins.framework.models import PluginResult - from mcpgateway.plugins.mcp.entities import HookType + from mcpgateway.plugins.framework import ToolHookType tool_service._plugin_manager = Mock() def invoke_hook_side_effect(hook_type, payload, global_context, local_contexts=None, **kwargs): - if hook_type == HookType.TOOL_PRE_INVOKE: + if hook_type == ToolHookType.TOOL_PRE_INVOKE: return (PluginResult(continue_processing=True, violation=None, modified_payload=None), None) # POST_INVOKE - raise error raise Exception("Plugin error") From 5c3b05ce4f6d8a3e929d2fbc8986c1bb8feae39c Mon Sep 17 00:00:00 2001 From: Frederico Araujo Date: Sun, 2 Nov 2025 20:37:18 -0500 Subject: [PATCH 07/20] chore: fix lint issues Signed-off-by: Frederico Araujo --- mcpgateway/plugins/framework/__init__.py | 26 +++---------------- mcpgateway/plugins/framework/base.py | 16 +++++------- mcpgateway/plugins/framework/hooks/agents.py | 6 +++-- mcpgateway/plugins/framework/hooks/http.py | 2 ++ mcpgateway/plugins/framework/hooks/prompts.py | 10 ++----- .../plugins/framework/hooks/resources.py | 3 +++ mcpgateway/plugins/framework/hooks/tools.py | 4 ++- mcpgateway/services/prompt_service.py | 8 +----- mcpgateway/services/resource_service.py | 8 +----- mcpgateway/services/tool_service.py | 11 +------- 10 files changed, 27 insertions(+), 67 deletions(-) diff --git a/mcpgateway/plugins/framework/__init__.py b/mcpgateway/plugins/framework/__init__.py index ac5e4acb6..7783d788a 100644 --- a/mcpgateway/plugins/framework/__init__.py +++ b/mcpgateway/plugins/framework/__init__.py @@ -22,20 +22,8 @@ from mcpgateway.plugins.framework.loader.plugin import PluginLoader from mcpgateway.plugins.framework.manager import PluginManager from mcpgateway.plugins.framework.hooks.http import HttpHeaderPayload -from mcpgateway.plugins.framework.hooks.agents import ( - AgentHookType, - AgentPostInvokePayload, - AgentPostInvokeResult, - AgentPreInvokePayload, - AgentPreInvokeResult -) -from mcpgateway.plugins.framework.hooks.resources import ( - ResourceHookType, - ResourcePostFetchPayload, - ResourcePostFetchResult, - ResourcePreFetchPayload, - ResourcePreFetchResult -) +from mcpgateway.plugins.framework.hooks.agents import AgentHookType, AgentPostInvokePayload, AgentPostInvokeResult, AgentPreInvokePayload, AgentPreInvokeResult +from mcpgateway.plugins.framework.hooks.resources import ResourceHookType, ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchPayload, ResourcePreFetchResult from mcpgateway.plugins.framework.hooks.prompts import ( PromptHookType, PromptPosthookPayload, @@ -43,13 +31,7 @@ PromptPrehookPayload, PromptPrehookResult, ) -from mcpgateway.plugins.framework.hooks.tools import ( - ToolHookType, - ToolPostInvokePayload, - ToolPostInvokeResult, - ToolPreInvokeResult, - ToolPreInvokePayload -) +from mcpgateway.plugins.framework.hooks.tools import ToolHookType, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokeResult, ToolPreInvokePayload from mcpgateway.plugins.framework.models import ( GlobalContext, MCPServerConfig, @@ -103,5 +85,5 @@ "ToolPostInvokePayload", "ToolPostInvokeResult", "ToolPreInvokeResult", - "ToolPreInvokePayload" + "ToolPreInvokePayload", ] diff --git a/mcpgateway/plugins/framework/base.py b/mcpgateway/plugins/framework/base.py index 759c36687..c41aac070 100644 --- a/mcpgateway/plugins/framework/base.py +++ b/mcpgateway/plugins/framework/base.py @@ -414,8 +414,7 @@ def __init__(self, hook: str, plugin_ref: PluginRef): if not self._func: raise PluginError( error=PluginErrorModel( - message=f"Plugin '{plugin_ref.plugin.name}' has no hook: '{hook}'. " - f"Method must either be named '{hook}' or decorated with @hook('{hook}')", + message=f"Plugin '{plugin_ref.plugin.name}' has no hook: '{hook}'. " f"Method must either be named '{hook}' or decorated with @hook('{hook}')", plugin_name=plugin_ref.plugin.name, ) ) @@ -510,6 +509,7 @@ def _validate_type_hints(self, hook: str, func: Callable, params: list, plugin_n except Exception as e: # Type hints might use forward references or unavailable types # We'll skip validation rather than fail + # Standard import logging logger = logging.getLogger(__name__) @@ -521,8 +521,7 @@ def _validate_type_hints(self, hook: str, func: Callable, params: list, plugin_n if payload_param_name not in hints: raise PluginError( error=PluginErrorModel( - message=f"Plugin '{plugin_name}' hook '{hook}' missing type hint for parameter '{payload_param_name}'. " - f"Expected: {payload_param_name}: {expected_payload_type.__name__}", + message=f"Plugin '{plugin_name}' hook '{hook}' missing type hint for parameter '{payload_param_name}'. " f"Expected: {payload_param_name}: {expected_payload_type.__name__}", plugin_name=plugin_name, ) ) @@ -539,8 +538,7 @@ def _validate_type_hints(self, hook: str, func: Callable, params: list, plugin_n if expected_type_str not in actual_type_str: raise PluginError( error=PluginErrorModel( - message=f"Plugin '{plugin_name}' hook '{hook}' parameter '{payload_param_name}' " - f"has incorrect type hint. Expected: {expected_type_str}, Got: {actual_type_str}", + message=f"Plugin '{plugin_name}' hook '{hook}' parameter '{payload_param_name}' " f"has incorrect type hint. Expected: {expected_type_str}, Got: {actual_type_str}", plugin_name=plugin_name, ) ) @@ -549,8 +547,7 @@ def _validate_type_hints(self, hook: str, func: Callable, params: list, plugin_n if "return" not in hints: raise PluginError( error=PluginErrorModel( - message=f"Plugin '{plugin_name}' hook '{hook}' missing return type hint. " - f"Expected: -> {expected_result_type.__name__}", + message=f"Plugin '{plugin_name}' hook '{hook}' missing return type hint. " f"Expected: -> {expected_result_type.__name__}", plugin_name=plugin_name, ) ) @@ -564,8 +561,7 @@ def _validate_type_hints(self, hook: str, func: Callable, params: list, plugin_n if expected_return_str not in return_type_str and actual_return_type != expected_result_type: raise PluginError( error=PluginErrorModel( - message=f"Plugin '{plugin_name}' hook '{hook}' has incorrect return type hint. " - f"Expected: {expected_return_str}, Got: {return_type_str}", + message=f"Plugin '{plugin_name}' hook '{hook}' has incorrect return type hint. " f"Expected: {expected_return_str}, Got: {return_type_str}", plugin_name=plugin_name, ) ) diff --git a/mcpgateway/plugins/framework/hooks/agents.py b/mcpgateway/plugins/framework/hooks/agents.py index c748aadea..db99139b3 100644 --- a/mcpgateway/plugins/framework/hooks/agents.py +++ b/mcpgateway/plugins/framework/hooks/agents.py @@ -18,8 +18,8 @@ # First-Party from mcpgateway.common.models import Message -from mcpgateway.plugins.framework.models import PluginPayload, PluginResult from mcpgateway.plugins.framework.hooks.http import HttpHeaderPayload +from mcpgateway.plugins.framework.models import PluginPayload, PluginResult class AgentHookType(str, Enum): @@ -122,6 +122,7 @@ class AgentPostInvokePayload(PluginPayload): AgentPreInvokeResult = PluginResult[AgentPreInvokePayload] AgentPostInvokeResult = PluginResult[AgentPostInvokePayload] + def _register_agent_hooks(): """Register agent hooks in the global registry. @@ -138,4 +139,5 @@ def _register_agent_hooks(): registry.register_hook(AgentHookType.AGENT_PRE_INVOKE, AgentPreInvokePayload, AgentPreInvokeResult) registry.register_hook(AgentHookType.AGENT_POST_INVOKE, AgentPostInvokePayload, AgentPostInvokeResult) -_register_agent_hooks() \ No newline at end of file + +_register_agent_hooks() diff --git a/mcpgateway/plugins/framework/hooks/http.py b/mcpgateway/plugins/framework/hooks/http.py index 34513adcc..675bc285c 100644 --- a/mcpgateway/plugins/framework/hooks/http.py +++ b/mcpgateway/plugins/framework/hooks/http.py @@ -7,11 +7,13 @@ Pydantic models for http hooks and payloads. """ +# Third-Party from pydantic import RootModel # First-Party from mcpgateway.plugins.framework.models import PluginPayload, PluginResult + class HttpHeaderPayload(RootModel[dict[str, str]], PluginPayload): """An HTTP dictionary of headers used in the pre/post HTTP forwarding hooks.""" diff --git a/mcpgateway/plugins/framework/hooks/prompts.py b/mcpgateway/plugins/framework/hooks/prompts.py index faee02c42..a2349530f 100644 --- a/mcpgateway/plugins/framework/hooks/prompts.py +++ b/mcpgateway/plugins/framework/hooks/prompts.py @@ -105,6 +105,7 @@ class PromptPosthookPayload(PluginPayload): PromptPrehookResult = PluginResult[PromptPrehookPayload] PromptPosthookResult = PluginResult[PromptPosthookPayload] + def _register_prompt_hooks(): """Register prompt hooks in the global registry. @@ -121,12 +122,5 @@ def _register_prompt_hooks(): registry.register_hook(PromptHookType.PROMPT_PRE_FETCH, PromptPrehookPayload, PromptPrehookResult) registry.register_hook(PromptHookType.PROMPT_POST_FETCH, PromptPosthookPayload, PromptPosthookResult) -_register_prompt_hooks() - - - - - - - +_register_prompt_hooks() diff --git a/mcpgateway/plugins/framework/hooks/resources.py b/mcpgateway/plugins/framework/hooks/resources.py index 8d5c7058b..cf5390bbe 100644 --- a/mcpgateway/plugins/framework/hooks/resources.py +++ b/mcpgateway/plugins/framework/hooks/resources.py @@ -39,6 +39,7 @@ class ResourceHookType(str, Enum): RESOURCE_PRE_FETCH = "resource_pre_fetch" RESOURCE_POST_FETCH = "resource_post_fetch" + class ResourcePreFetchPayload(PluginPayload): """A resource payload for a resource pre-fetch hook. @@ -94,6 +95,7 @@ class ResourcePostFetchPayload(PluginPayload): ResourcePreFetchResult = PluginResult[ResourcePreFetchPayload] ResourcePostFetchResult = PluginResult[ResourcePostFetchPayload] + def _register_resource_hooks(): """Register resource hooks in the global registry. @@ -110,4 +112,5 @@ def _register_resource_hooks(): registry.register_hook(ResourceHookType.RESOURCE_PRE_FETCH, ResourcePreFetchPayload, ResourcePreFetchResult) registry.register_hook(ResourceHookType.RESOURCE_POST_FETCH, ResourcePostFetchPayload, ResourcePostFetchResult) + _register_resource_hooks() diff --git a/mcpgateway/plugins/framework/hooks/tools.py b/mcpgateway/plugins/framework/hooks/tools.py index 16afbae36..b9d804958 100644 --- a/mcpgateway/plugins/framework/hooks/tools.py +++ b/mcpgateway/plugins/framework/hooks/tools.py @@ -15,8 +15,9 @@ from pydantic import Field # First-Party -from mcpgateway.plugins.framework.models import PluginPayload, PluginResult from mcpgateway.plugins.framework.hooks.http import HttpHeaderPayload +from mcpgateway.plugins.framework.models import PluginPayload, PluginResult + class ToolHookType(str, Enum): """MCP Forge Gateway hook points. @@ -97,6 +98,7 @@ class ToolPostInvokePayload(PluginPayload): ToolPreInvokeResult = PluginResult[ToolPreInvokePayload] ToolPostInvokeResult = PluginResult[ToolPostInvokePayload] + def _register_tool_hooks(): """Register Tool hooks in the global registry. diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index 30fd601fb..616991964 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -36,13 +36,7 @@ from mcpgateway.db import Prompt as DbPrompt from mcpgateway.db import PromptMetric, server_prompt_association from mcpgateway.observability import create_span -from mcpgateway.plugins.framework import ( - GlobalContext, - PluginManager, - PromptHookType, - PromptPosthookPayload, - PromptPrehookPayload -) +from mcpgateway.plugins.framework import GlobalContext, PluginManager, PromptHookType, PromptPosthookPayload, PromptPrehookPayload from mcpgateway.schemas import PromptCreate, PromptRead, PromptUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService from mcpgateway.utils.metrics_common import build_top_performers diff --git a/mcpgateway/services/resource_service.py b/mcpgateway/services/resource_service.py index 6790a156b..3b9fbb662 100644 --- a/mcpgateway/services/resource_service.py +++ b/mcpgateway/services/resource_service.py @@ -56,13 +56,7 @@ # Plugin support imports (conditional) try: # First-Party - from mcpgateway.plugins.framework import ( - GlobalContext, - PluginManager, - ResourceHookType, - ResourcePostFetchPayload, - ResourcePreFetchPayload - ) + from mcpgateway.plugins.framework import GlobalContext, PluginManager, ResourceHookType, ResourcePostFetchPayload, ResourcePreFetchPayload PLUGINS_AVAILABLE = True except ImportError: diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index fd992a4f7..4c32b18b1 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -49,16 +49,7 @@ from mcpgateway.db import Tool as DbTool from mcpgateway.db import ToolMetric from mcpgateway.observability import create_span -from mcpgateway.plugins.framework import ( - GlobalContext, - PluginError, - PluginManager, - PluginViolationError, - ToolHookType, - HttpHeaderPayload, - ToolPostInvokePayload, - ToolPreInvokePayload -) +from mcpgateway.plugins.framework import GlobalContext, HttpHeaderPayload, PluginError, PluginManager, PluginViolationError, ToolHookType, ToolPostInvokePayload, ToolPreInvokePayload from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA from mcpgateway.schemas import ToolCreate, ToolRead, ToolUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService From 4df81f7f6aad208b997f33f27efbf1f70ccede08 Mon Sep 17 00:00:00 2001 From: Teryl Taylor Date: Mon, 3 Nov 2025 10:32:38 -0700 Subject: [PATCH 08/20] feat: add comparison function to deal with PluginCondition Signed-off-by: Teryl Taylor --- mcpgateway/plugins/framework/__init__.py | 26 +- mcpgateway/plugins/framework/base.py | 16 +- mcpgateway/plugins/framework/hooks/agents.py | 6 +- mcpgateway/plugins/framework/hooks/http.py | 2 + mcpgateway/plugins/framework/hooks/prompts.py | 10 +- .../plugins/framework/hooks/resources.py | 3 + mcpgateway/plugins/framework/hooks/tools.py | 4 +- mcpgateway/plugins/framework/manager.py | 13 +- mcpgateway/plugins/framework/models.py | 4 +- mcpgateway/plugins/framework/utils.py | 115 +++++++ mcpgateway/services/prompt_service.py | 8 +- mcpgateway/services/resource_service.py | 8 +- mcpgateway/services/tool_service.py | 11 +- .../framework/test_manager_extended.py | 312 +++++++++++++++--- .../plugins/framework/test_utils.py | 310 +++++++++-------- 15 files changed, 567 insertions(+), 281 deletions(-) diff --git a/mcpgateway/plugins/framework/__init__.py b/mcpgateway/plugins/framework/__init__.py index ac5e4acb6..7783d788a 100644 --- a/mcpgateway/plugins/framework/__init__.py +++ b/mcpgateway/plugins/framework/__init__.py @@ -22,20 +22,8 @@ from mcpgateway.plugins.framework.loader.plugin import PluginLoader from mcpgateway.plugins.framework.manager import PluginManager from mcpgateway.plugins.framework.hooks.http import HttpHeaderPayload -from mcpgateway.plugins.framework.hooks.agents import ( - AgentHookType, - AgentPostInvokePayload, - AgentPostInvokeResult, - AgentPreInvokePayload, - AgentPreInvokeResult -) -from mcpgateway.plugins.framework.hooks.resources import ( - ResourceHookType, - ResourcePostFetchPayload, - ResourcePostFetchResult, - ResourcePreFetchPayload, - ResourcePreFetchResult -) +from mcpgateway.plugins.framework.hooks.agents import AgentHookType, AgentPostInvokePayload, AgentPostInvokeResult, AgentPreInvokePayload, AgentPreInvokeResult +from mcpgateway.plugins.framework.hooks.resources import ResourceHookType, ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchPayload, ResourcePreFetchResult from mcpgateway.plugins.framework.hooks.prompts import ( PromptHookType, PromptPosthookPayload, @@ -43,13 +31,7 @@ PromptPrehookPayload, PromptPrehookResult, ) -from mcpgateway.plugins.framework.hooks.tools import ( - ToolHookType, - ToolPostInvokePayload, - ToolPostInvokeResult, - ToolPreInvokeResult, - ToolPreInvokePayload -) +from mcpgateway.plugins.framework.hooks.tools import ToolHookType, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokeResult, ToolPreInvokePayload from mcpgateway.plugins.framework.models import ( GlobalContext, MCPServerConfig, @@ -103,5 +85,5 @@ "ToolPostInvokePayload", "ToolPostInvokeResult", "ToolPreInvokeResult", - "ToolPreInvokePayload" + "ToolPreInvokePayload", ] diff --git a/mcpgateway/plugins/framework/base.py b/mcpgateway/plugins/framework/base.py index 759c36687..c41aac070 100644 --- a/mcpgateway/plugins/framework/base.py +++ b/mcpgateway/plugins/framework/base.py @@ -414,8 +414,7 @@ def __init__(self, hook: str, plugin_ref: PluginRef): if not self._func: raise PluginError( error=PluginErrorModel( - message=f"Plugin '{plugin_ref.plugin.name}' has no hook: '{hook}'. " - f"Method must either be named '{hook}' or decorated with @hook('{hook}')", + message=f"Plugin '{plugin_ref.plugin.name}' has no hook: '{hook}'. " f"Method must either be named '{hook}' or decorated with @hook('{hook}')", plugin_name=plugin_ref.plugin.name, ) ) @@ -510,6 +509,7 @@ def _validate_type_hints(self, hook: str, func: Callable, params: list, plugin_n except Exception as e: # Type hints might use forward references or unavailable types # We'll skip validation rather than fail + # Standard import logging logger = logging.getLogger(__name__) @@ -521,8 +521,7 @@ def _validate_type_hints(self, hook: str, func: Callable, params: list, plugin_n if payload_param_name not in hints: raise PluginError( error=PluginErrorModel( - message=f"Plugin '{plugin_name}' hook '{hook}' missing type hint for parameter '{payload_param_name}'. " - f"Expected: {payload_param_name}: {expected_payload_type.__name__}", + message=f"Plugin '{plugin_name}' hook '{hook}' missing type hint for parameter '{payload_param_name}'. " f"Expected: {payload_param_name}: {expected_payload_type.__name__}", plugin_name=plugin_name, ) ) @@ -539,8 +538,7 @@ def _validate_type_hints(self, hook: str, func: Callable, params: list, plugin_n if expected_type_str not in actual_type_str: raise PluginError( error=PluginErrorModel( - message=f"Plugin '{plugin_name}' hook '{hook}' parameter '{payload_param_name}' " - f"has incorrect type hint. Expected: {expected_type_str}, Got: {actual_type_str}", + message=f"Plugin '{plugin_name}' hook '{hook}' parameter '{payload_param_name}' " f"has incorrect type hint. Expected: {expected_type_str}, Got: {actual_type_str}", plugin_name=plugin_name, ) ) @@ -549,8 +547,7 @@ def _validate_type_hints(self, hook: str, func: Callable, params: list, plugin_n if "return" not in hints: raise PluginError( error=PluginErrorModel( - message=f"Plugin '{plugin_name}' hook '{hook}' missing return type hint. " - f"Expected: -> {expected_result_type.__name__}", + message=f"Plugin '{plugin_name}' hook '{hook}' missing return type hint. " f"Expected: -> {expected_result_type.__name__}", plugin_name=plugin_name, ) ) @@ -564,8 +561,7 @@ def _validate_type_hints(self, hook: str, func: Callable, params: list, plugin_n if expected_return_str not in return_type_str and actual_return_type != expected_result_type: raise PluginError( error=PluginErrorModel( - message=f"Plugin '{plugin_name}' hook '{hook}' has incorrect return type hint. " - f"Expected: {expected_return_str}, Got: {return_type_str}", + message=f"Plugin '{plugin_name}' hook '{hook}' has incorrect return type hint. " f"Expected: {expected_return_str}, Got: {return_type_str}", plugin_name=plugin_name, ) ) diff --git a/mcpgateway/plugins/framework/hooks/agents.py b/mcpgateway/plugins/framework/hooks/agents.py index c748aadea..db99139b3 100644 --- a/mcpgateway/plugins/framework/hooks/agents.py +++ b/mcpgateway/plugins/framework/hooks/agents.py @@ -18,8 +18,8 @@ # First-Party from mcpgateway.common.models import Message -from mcpgateway.plugins.framework.models import PluginPayload, PluginResult from mcpgateway.plugins.framework.hooks.http import HttpHeaderPayload +from mcpgateway.plugins.framework.models import PluginPayload, PluginResult class AgentHookType(str, Enum): @@ -122,6 +122,7 @@ class AgentPostInvokePayload(PluginPayload): AgentPreInvokeResult = PluginResult[AgentPreInvokePayload] AgentPostInvokeResult = PluginResult[AgentPostInvokePayload] + def _register_agent_hooks(): """Register agent hooks in the global registry. @@ -138,4 +139,5 @@ def _register_agent_hooks(): registry.register_hook(AgentHookType.AGENT_PRE_INVOKE, AgentPreInvokePayload, AgentPreInvokeResult) registry.register_hook(AgentHookType.AGENT_POST_INVOKE, AgentPostInvokePayload, AgentPostInvokeResult) -_register_agent_hooks() \ No newline at end of file + +_register_agent_hooks() diff --git a/mcpgateway/plugins/framework/hooks/http.py b/mcpgateway/plugins/framework/hooks/http.py index 34513adcc..675bc285c 100644 --- a/mcpgateway/plugins/framework/hooks/http.py +++ b/mcpgateway/plugins/framework/hooks/http.py @@ -7,11 +7,13 @@ Pydantic models for http hooks and payloads. """ +# Third-Party from pydantic import RootModel # First-Party from mcpgateway.plugins.framework.models import PluginPayload, PluginResult + class HttpHeaderPayload(RootModel[dict[str, str]], PluginPayload): """An HTTP dictionary of headers used in the pre/post HTTP forwarding hooks.""" diff --git a/mcpgateway/plugins/framework/hooks/prompts.py b/mcpgateway/plugins/framework/hooks/prompts.py index faee02c42..a2349530f 100644 --- a/mcpgateway/plugins/framework/hooks/prompts.py +++ b/mcpgateway/plugins/framework/hooks/prompts.py @@ -105,6 +105,7 @@ class PromptPosthookPayload(PluginPayload): PromptPrehookResult = PluginResult[PromptPrehookPayload] PromptPosthookResult = PluginResult[PromptPosthookPayload] + def _register_prompt_hooks(): """Register prompt hooks in the global registry. @@ -121,12 +122,5 @@ def _register_prompt_hooks(): registry.register_hook(PromptHookType.PROMPT_PRE_FETCH, PromptPrehookPayload, PromptPrehookResult) registry.register_hook(PromptHookType.PROMPT_POST_FETCH, PromptPosthookPayload, PromptPosthookResult) -_register_prompt_hooks() - - - - - - - +_register_prompt_hooks() diff --git a/mcpgateway/plugins/framework/hooks/resources.py b/mcpgateway/plugins/framework/hooks/resources.py index 8d5c7058b..cf5390bbe 100644 --- a/mcpgateway/plugins/framework/hooks/resources.py +++ b/mcpgateway/plugins/framework/hooks/resources.py @@ -39,6 +39,7 @@ class ResourceHookType(str, Enum): RESOURCE_PRE_FETCH = "resource_pre_fetch" RESOURCE_POST_FETCH = "resource_post_fetch" + class ResourcePreFetchPayload(PluginPayload): """A resource payload for a resource pre-fetch hook. @@ -94,6 +95,7 @@ class ResourcePostFetchPayload(PluginPayload): ResourcePreFetchResult = PluginResult[ResourcePreFetchPayload] ResourcePostFetchResult = PluginResult[ResourcePostFetchPayload] + def _register_resource_hooks(): """Register resource hooks in the global registry. @@ -110,4 +112,5 @@ def _register_resource_hooks(): registry.register_hook(ResourceHookType.RESOURCE_PRE_FETCH, ResourcePreFetchPayload, ResourcePreFetchResult) registry.register_hook(ResourceHookType.RESOURCE_POST_FETCH, ResourcePostFetchPayload, ResourcePostFetchResult) + _register_resource_hooks() diff --git a/mcpgateway/plugins/framework/hooks/tools.py b/mcpgateway/plugins/framework/hooks/tools.py index 16afbae36..b9d804958 100644 --- a/mcpgateway/plugins/framework/hooks/tools.py +++ b/mcpgateway/plugins/framework/hooks/tools.py @@ -15,8 +15,9 @@ from pydantic import Field # First-Party -from mcpgateway.plugins.framework.models import PluginPayload, PluginResult from mcpgateway.plugins.framework.hooks.http import HttpHeaderPayload +from mcpgateway.plugins.framework.models import PluginPayload, PluginResult + class ToolHookType(str, Enum): """MCP Forge Gateway hook points. @@ -97,6 +98,7 @@ class ToolPostInvokePayload(PluginPayload): ToolPreInvokeResult = PluginResult[ToolPreInvokePayload] ToolPostInvokeResult = PluginResult[ToolPostInvokePayload] + def _register_tool_hooks(): """Register Tool hooks in the global registry. diff --git a/mcpgateway/plugins/framework/manager.py b/mcpgateway/plugins/framework/manager.py index 9c312e782..e0d5c92db 100644 --- a/mcpgateway/plugins/framework/manager.py +++ b/mcpgateway/plugins/framework/manager.py @@ -49,6 +49,7 @@ PluginResult, ) from mcpgateway.plugins.framework.registry import PluginInstanceRegistry +from mcpgateway.plugins.framework.utils import payload_matches # Use standard logging to avoid circular imports (plugins -> services -> plugins) logger = logging.getLogger(__name__) @@ -105,15 +106,17 @@ async def execute( hook_refs: list[HookRef], payload: PluginPayload, global_context: GlobalContext, + hook_type: str, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False, ) -> tuple[PluginResult, PluginContextTable | None]: """Execute plugins in priority order with timeout protection. Args: - plugins: List of plugins to execute, sorted by priority. + hook_refs: List of hook references to execute, sorted by priority. payload: The payload to be processed by plugins. global_context: Shared context for all plugins containing request metadata. + hook_type: The hook type identifier (e.g., "tool_pre_invoke"). local_contexts: Optional existing contexts from previous hook executions. violations_as_exceptions: Raise violations as exceptions rather than as returns. @@ -158,9 +161,9 @@ async def execute( continue # Check if plugin conditions match current context - # if pluginref.conditions and not compare(payload, pluginref.conditions, global_context): - # logger.debug(f"Skipping plugin {pluginref.name} - conditions not met") - # continue + if hook_ref.plugin_ref.conditions and not payload_matches(payload, hook_type, hook_ref.plugin_ref.conditions, global_context): + logger.debug("Skipping plugin %s - conditions not met", hook_ref.plugin_ref.name) + continue tmp_global_context = GlobalContext( request_id=global_context.request_id, @@ -552,7 +555,7 @@ async def invoke_hook( hook_refs = self._registry.get_hook_refs_for_hook(hook_type=hook_type) # Execute plugins - result = await self._executor.execute(hook_refs, payload, global_context, local_contexts, violations_as_exceptions) + result = await self._executor.execute(hook_refs, payload, global_context, hook_type, local_contexts, violations_as_exceptions) return result diff --git a/mcpgateway/plugins/framework/models.py b/mcpgateway/plugins/framework/models.py index 3e7cb1222..84893ffc8 100644 --- a/mcpgateway/plugins/framework/models.py +++ b/mcpgateway/plugins/framework/models.py @@ -173,6 +173,7 @@ class PluginCondition(BaseModel): tools (Optional[set[str]]): set of tool names. prompts (Optional[set[str]]): set of prompt names. resources (Optional[set[str]]): set of resource URIs. + agents (Optional[set[str]]): set of agent IDs. user_pattern (Optional[list[str]]): list of user patterns. content_types (Optional[list[str]]): list of content types. @@ -193,10 +194,11 @@ class PluginCondition(BaseModel): tools: Optional[set[str]] = None prompts: Optional[set[str]] = None resources: Optional[set[str]] = None + agents: Optional[set[str]] = None user_patterns: Optional[list[str]] = None content_types: Optional[list[str]] = None - @field_serializer("server_ids", "tenant_ids", "tools", "prompts") + @field_serializer("server_ids", "tenant_ids", "tools", "prompts", "resources", "agents") def serialize_set(self, value: set[str] | None) -> list[str] | None: """Serialize set objects in PluginCondition for MCP. diff --git a/mcpgateway/plugins/framework/utils.py b/mcpgateway/plugins/framework/utils.py index 50046277d..0d40e01ac 100644 --- a/mcpgateway/plugins/framework/utils.py +++ b/mcpgateway/plugins/framework/utils.py @@ -13,6 +13,7 @@ from functools import cache import importlib from types import ModuleType +from typing import Any, Optional # First-Party from mcpgateway.plugins.framework.models import ( @@ -114,6 +115,120 @@ def matches(condition: PluginCondition, context: GlobalContext) -> bool: return True +def get_matchable_value(payload: Any, hook_type: str) -> Optional[str]: + """Extract the matchable value from a payload based on hook type. + + This function maps hook types to their corresponding payload attributes + that should be used for conditional matching. + + Args: + payload: The payload object (e.g., ToolPreInvokePayload, AgentPreInvokePayload). + hook_type: The hook type identifier. + + Returns: + The matchable value (e.g., tool name, agent ID, resource URI) or None. + + Examples: + >>> from mcpgateway.plugins.framework import GlobalContext + >>> from mcpgateway.plugins.framework.hooks.tools import ToolPreInvokePayload + >>> payload = ToolPreInvokePayload(name="calculator", args={}) + >>> get_matchable_value(payload, "tool_pre_invoke") + 'calculator' + >>> get_matchable_value(payload, "unknown_hook") + """ + # Mapping: hook_type -> payload attribute name + field_map = { + "tool_pre_invoke": "name", + "tool_post_invoke": "name", + "prompt_pre_fetch": "prompt_id", + "prompt_post_fetch": "prompt_id", + "resource_pre_fetch": "uri", + "resource_post_fetch": "uri", + "agent_pre_invoke": "agent_id", + "agent_post_invoke": "agent_id", + } + + field_name = field_map.get(hook_type) + if field_name: + return getattr(payload, field_name, None) + return None + + +def payload_matches( + payload: Any, + hook_type: str, + conditions: list[PluginCondition], + context: GlobalContext, +) -> bool: + """Check if a payload matches any of the plugin conditions. + + This function provides generic conditional matching for all hook types. + It checks both GlobalContext conditions (via matches()) and payload-specific + conditions (tools, prompts, resources, agents). + + Args: + payload: The payload object. + hook_type: The hook type identifier. + conditions: List of conditions to check against. + context: The global context. + + Returns: + True if the payload matches any condition or if no conditions are specified. + + Examples: + >>> from mcpgateway.plugins.framework import PluginCondition, GlobalContext + >>> from mcpgateway.plugins.framework.hooks.tools import ToolPreInvokePayload + >>> payload = ToolPreInvokePayload(name="calculator", args={}) + >>> cond = PluginCondition(tools={"calculator"}) + >>> ctx = GlobalContext(request_id="req1") + >>> payload_matches(payload, "tool_pre_invoke", [cond], ctx) + True + >>> cond2 = PluginCondition(tools={"other_tool"}) + >>> payload_matches(payload, "tool_pre_invoke", [cond2], ctx) + False + >>> payload_matches(payload, "tool_pre_invoke", [], ctx) + True + """ + # Mapping: hook_type -> PluginCondition attribute name + condition_attr_map = { + "tool_pre_invoke": "tools", + "tool_post_invoke": "tools", + "prompt_pre_fetch": "prompts", + "prompt_post_fetch": "prompts", + "resource_pre_fetch": "resources", + "resource_post_fetch": "resources", + "agent_pre_invoke": "agents", + "agent_post_invoke": "agents", + } + + # If no conditions, match everything + if not conditions: + return True + + # Check each condition (OR logic between conditions) + for condition in conditions: + # First check GlobalContext conditions + if not matches(condition, context): + continue + + # Then check payload-specific conditions + condition_attr = condition_attr_map.get(hook_type) + if condition_attr: + condition_set = getattr(condition, condition_attr, None) + if condition_set: + # Extract the matchable value from the payload + payload_value = get_matchable_value(payload, hook_type) + if payload_value and payload_value not in condition_set: + # Payload value doesn't match this condition's set + continue + + # If we get here, this condition matched + return True + + # No conditions matched + return False + + # def pre_prompt_matches(payload: PromptPrehookPayload, conditions: list[PluginCondition], context: GlobalContext) -> bool: # """Check for a match on pre-prompt hooks. diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index 30fd601fb..616991964 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -36,13 +36,7 @@ from mcpgateway.db import Prompt as DbPrompt from mcpgateway.db import PromptMetric, server_prompt_association from mcpgateway.observability import create_span -from mcpgateway.plugins.framework import ( - GlobalContext, - PluginManager, - PromptHookType, - PromptPosthookPayload, - PromptPrehookPayload -) +from mcpgateway.plugins.framework import GlobalContext, PluginManager, PromptHookType, PromptPosthookPayload, PromptPrehookPayload from mcpgateway.schemas import PromptCreate, PromptRead, PromptUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService from mcpgateway.utils.metrics_common import build_top_performers diff --git a/mcpgateway/services/resource_service.py b/mcpgateway/services/resource_service.py index 6790a156b..3b9fbb662 100644 --- a/mcpgateway/services/resource_service.py +++ b/mcpgateway/services/resource_service.py @@ -56,13 +56,7 @@ # Plugin support imports (conditional) try: # First-Party - from mcpgateway.plugins.framework import ( - GlobalContext, - PluginManager, - ResourceHookType, - ResourcePostFetchPayload, - ResourcePreFetchPayload - ) + from mcpgateway.plugins.framework import GlobalContext, PluginManager, ResourceHookType, ResourcePostFetchPayload, ResourcePreFetchPayload PLUGINS_AVAILABLE = True except ImportError: diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index fd992a4f7..4c32b18b1 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -49,16 +49,7 @@ from mcpgateway.db import Tool as DbTool from mcpgateway.db import ToolMetric from mcpgateway.observability import create_span -from mcpgateway.plugins.framework import ( - GlobalContext, - PluginError, - PluginManager, - PluginViolationError, - ToolHookType, - HttpHeaderPayload, - ToolPostInvokePayload, - ToolPreInvokePayload -) +from mcpgateway.plugins.framework import GlobalContext, HttpHeaderPayload, PluginError, PluginManager, PluginViolationError, ToolHookType, ToolPostInvokePayload, ToolPreInvokePayload from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA from mcpgateway.schemas import ToolCreate, ToolRead, ToolUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService diff --git a/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py b/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py index dc037d8c8..0a7bd317f 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py +++ b/tests/unit/mcpgateway/plugins/framework/test_manager_extended.py @@ -177,57 +177,267 @@ async def prompt_pre_fetch(self, payload, context): await manager.shutdown() -# @pytest.mark.asyncio -# async def test_manager_condition_filtering(): -# """Test that plugins are filtered based on conditions.""" - -# class ConditionalPlugin(Plugin): -# async def prompt_pre_fetch(self, payload, context): -# payload.args["modified"] = "yes" -# return PluginResult(continue_processing=True, modified_payload=payload) - -# manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") -# await manager.initialize() - -# # Plugin with server_id condition -# plugin_config = PluginConfig( -# name="ConditionalPlugin", -# description="Test conditional plugin", -# author="Test", -# version="1.0", -# tags=["test"], -# kind="ConditionalPlugin", -# hooks=["prompt_pre_fetch"], -# config={}, -# conditions=[PluginCondition(server_ids={"server1"})], -# ) -# plugin = ConditionalPlugin(plugin_config) - -# with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: -# plugin_ref = PluginRef(plugin) -# mock_get.return_value = [plugin_ref] - -# prompt = PromptPrehookPayload(prompt_id="test", args={}) - -# # Test with matching server_id -# global_context = GlobalContext(request_id="1", server_id="server1") -# result, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) - -# # Plugin should execute -# assert result.continue_processing -# assert result.modified_payload is not None -# assert result.modified_payload.args.get("modified") == "yes" - -# # Test with non-matching server_id -# prompt2 = PromptPrehookPayload(prompt_id="test", args={}) -# global_context2 = GlobalContext(request_id="2", server_id="server2") -# result2, _ = await manager.invoke_hook(HookType.PROMPT_PRE_FETCH, prompt2, global_context=global_context2) - -# # Plugin should be skipped -# assert result2.continue_processing -# assert result2.modified_payload is None # No modification - -# await manager.shutdown() +@pytest.mark.asyncio +async def test_manager_condition_filtering(): + """Test that plugins are filtered based on conditions across all hook types.""" + from mcpgateway.plugins.framework import ( + ResourceHookType, + ResourcePreFetchPayload, + AgentHookType, + AgentPreInvokePayload, + ) + + manager = PluginManager("./tests/unit/mcpgateway/plugins/fixtures/configs/valid_no_plugin.yaml") + await manager.initialize() + + # ========== Test 1: Server ID condition (GlobalContext) ========== + class ConditionalPlugin(Plugin): + async def prompt_pre_fetch(self, payload, context): + payload.args["modified"] = "yes" + return PluginResult(continue_processing=True, modified_payload=payload) + + plugin_config = PluginConfig( + name="ConditionalPlugin", + description="Test conditional plugin", + author="Test", + version="1.0", + tags=["test"], + kind="ConditionalPlugin", + hooks=["prompt_pre_fetch"], + config={}, + conditions=[PluginCondition(server_ids={"server1"})], + ) + plugin = ConditionalPlugin(plugin_config) + + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + plugin_ref = PluginRef(plugin) + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, plugin_ref) + mock_get.return_value = [hook_ref] + + prompt = PromptPrehookPayload(prompt_id="test", args={}) + + # Test with matching server_id + global_context = GlobalContext(request_id="1", server_id="server1") + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) + + # Plugin should execute + assert result.continue_processing + assert result.modified_payload is not None + assert result.modified_payload.args.get("modified") == "yes" + + # Test with non-matching server_id + prompt2 = PromptPrehookPayload(prompt_id="test", args={}) + global_context2 = GlobalContext(request_id="2", server_id="server2") + result2, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt2, global_context=global_context2) + + # Plugin should be skipped + assert result2.continue_processing + assert result2.modified_payload is None # No modification + + # ========== Test 2: Prompt-specific filtering ========== + class PromptFilterPlugin(Plugin): + async def prompt_pre_fetch(self, payload, context): + payload.args["prompt_filtered"] = "yes" + return PluginResult(continue_processing=True, modified_payload=payload) + + prompt_plugin_config = PluginConfig( + name="PromptFilterPlugin", + description="Test prompt filtering", + author="Test", + version="1.0", + tags=["test"], + kind="PromptFilterPlugin", + hooks=["prompt_pre_fetch"], + config={}, + conditions=[PluginCondition(prompts={"greeting", "welcome"})], + ) + prompt_plugin = PromptFilterPlugin(prompt_plugin_config) + + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(prompt_plugin)) + mock_get.return_value = [hook_ref] + + # Test with matching prompt + prompt_match = PromptPrehookPayload(prompt_id="greeting", args={}) + global_context = GlobalContext(request_id="3") + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt_match, global_context=global_context) + + assert result.continue_processing + assert result.modified_payload is not None + assert result.modified_payload.args.get("prompt_filtered") == "yes" + + # Test with non-matching prompt + prompt_no_match = PromptPrehookPayload(prompt_id="other", args={}) + result2, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt_no_match, global_context=global_context) + + assert result2.continue_processing + assert result2.modified_payload is None # Plugin skipped + + # ========== Test 3: Tool filtering ========== + class ToolFilterPlugin(Plugin): + async def tool_pre_invoke(self, payload, context): + payload.args["tool_filtered"] = "yes" + return PluginResult(continue_processing=True, modified_payload=payload) + + tool_plugin_config = PluginConfig( + name="ToolFilterPlugin", + description="Test tool filtering", + author="Test", + version="1.0", + tags=["test"], + kind="ToolFilterPlugin", + hooks=["tool_pre_invoke"], + config={}, + conditions=[PluginCondition(tools={"calculator", "converter"})], + ) + tool_plugin = ToolFilterPlugin(tool_plugin_config) + + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(ToolHookType.TOOL_PRE_INVOKE, PluginRef(tool_plugin)) + mock_get.return_value = [hook_ref] + + # Test with matching tool + tool_match = ToolPreInvokePayload(name="calculator", args={}) + global_context = GlobalContext(request_id="4") + result, _ = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_match, global_context=global_context) + + assert result.continue_processing + assert result.modified_payload is not None + assert result.modified_payload.args.get("tool_filtered") == "yes" + + # Test with non-matching tool + tool_no_match = ToolPreInvokePayload(name="other_tool", args={}) + result2, _ = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_no_match, global_context=global_context) + + assert result2.continue_processing + assert result2.modified_payload is None # Plugin skipped + + # ========== Test 4: Resource filtering ========== + class ResourceFilterPlugin(Plugin): + async def resource_pre_fetch(self, payload, context): + payload.metadata["resource_filtered"] = "yes" + return PluginResult(continue_processing=True, modified_payload=payload) + + resource_plugin_config = PluginConfig( + name="ResourceFilterPlugin", + description="Test resource filtering", + author="Test", + version="1.0", + tags=["test"], + kind="ResourceFilterPlugin", + hooks=["resource_pre_fetch"], + config={}, + conditions=[PluginCondition(resources={"file:///data.txt", "file:///config.json"})], + ) + resource_plugin = ResourceFilterPlugin(resource_plugin_config) + + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(ResourceHookType.RESOURCE_PRE_FETCH, PluginRef(resource_plugin)) + mock_get.return_value = [hook_ref] + + # Test with matching resource + resource_match = ResourcePreFetchPayload(uri="file:///data.txt", metadata={}) + global_context = GlobalContext(request_id="5") + result, _ = await manager.invoke_hook(ResourceHookType.RESOURCE_PRE_FETCH, resource_match, global_context=global_context) + + assert result.continue_processing + assert result.modified_payload is not None + assert result.modified_payload.metadata.get("resource_filtered") == "yes" + + # Test with non-matching resource + resource_no_match = ResourcePreFetchPayload(uri="file:///other.txt", metadata={}) + result2, _ = await manager.invoke_hook(ResourceHookType.RESOURCE_PRE_FETCH, resource_no_match, global_context=global_context) + + assert result2.continue_processing + assert result2.modified_payload is None # Plugin skipped + + # ========== Test 5: Agent filtering ========== + class AgentFilterPlugin(Plugin): + async def agent_pre_invoke(self, payload, context): + payload.parameters["agent_filtered"] = "yes" + return PluginResult(continue_processing=True, modified_payload=payload) + + agent_plugin_config = PluginConfig( + name="AgentFilterPlugin", + description="Test agent filtering", + author="Test", + version="1.0", + tags=["test"], + kind="AgentFilterPlugin", + hooks=["agent_pre_invoke"], + config={}, + conditions=[PluginCondition(agents={"agent1", "agent2"})], + ) + agent_plugin = AgentFilterPlugin(agent_plugin_config) + + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(AgentHookType.AGENT_PRE_INVOKE, PluginRef(agent_plugin)) + mock_get.return_value = [hook_ref] + + # Test with matching agent + agent_match = AgentPreInvokePayload(agent_id="agent1", messages=[], parameters={}) + global_context = GlobalContext(request_id="6") + result, _ = await manager.invoke_hook(AgentHookType.AGENT_PRE_INVOKE, agent_match, global_context=global_context) + + assert result.continue_processing + assert result.modified_payload is not None + assert result.modified_payload.parameters.get("agent_filtered") == "yes" + + # Test with non-matching agent + agent_no_match = AgentPreInvokePayload(agent_id="agent3", messages=[], parameters={}) + result2, _ = await manager.invoke_hook(AgentHookType.AGENT_PRE_INVOKE, agent_no_match, global_context=global_context) + + assert result2.continue_processing + assert result2.modified_payload is None # Plugin skipped + + # ========== Test 6: Combined conditions (server_id + tool name) ========== + class CombinedFilterPlugin(Plugin): + async def tool_pre_invoke(self, payload, context): + payload.args["combined_filtered"] = "yes" + return PluginResult(continue_processing=True, modified_payload=payload) + + combined_plugin_config = PluginConfig( + name="CombinedFilterPlugin", + description="Test combined filtering", + author="Test", + version="1.0", + tags=["test"], + kind="CombinedFilterPlugin", + hooks=["tool_pre_invoke"], + config={}, + conditions=[PluginCondition(server_ids={"server1"}, tools={"calculator"})], + ) + combined_plugin = CombinedFilterPlugin(combined_plugin_config) + + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(ToolHookType.TOOL_PRE_INVOKE, PluginRef(combined_plugin)) + mock_get.return_value = [hook_ref] + + # Test with both conditions matching + tool_payload = ToolPreInvokePayload(name="calculator", args={}) + global_context = GlobalContext(request_id="7", server_id="server1") + result, _ = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context) + + assert result.continue_processing + assert result.modified_payload is not None + assert result.modified_payload.args.get("combined_filtered") == "yes" + + # Test with server_id mismatch + global_context2 = GlobalContext(request_id="8", server_id="server2") + result2, _ = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload, global_context=global_context2) + + assert result2.continue_processing + assert result2.modified_payload is None # Plugin skipped + + # Test with tool name mismatch + tool_payload2 = ToolPreInvokePayload(name="other_tool", args={}) + global_context3 = GlobalContext(request_id="9", server_id="server1") + result3, _ = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, tool_payload2, global_context=global_context3) + + assert result3.continue_processing + assert result3.modified_payload is None # Plugin skipped + + await manager.shutdown() @pytest.mark.asyncio diff --git a/tests/unit/mcpgateway/plugins/framework/test_utils.py b/tests/unit/mcpgateway/plugins/framework/test_utils.py index 00e0e51dd..7b27626b0 100644 --- a/tests/unit/mcpgateway/plugins/framework/test_utils.py +++ b/tests/unit/mcpgateway/plugins/framework/test_utils.py @@ -11,51 +11,58 @@ import sys # First-Party -from mcpgateway.plugins.framework import GlobalContext, PluginCondition -from mcpgateway.plugins.framework.utils import import_module, matches, parse_class_name #, post_prompt_matches, post_tool_matches, pre_prompt_matches, pre_tool_matches -#from mcpgateway.plugins.mcp.entities import PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, ToolPreInvokePayload +from mcpgateway.plugins.framework import ( + GlobalContext, + PluginCondition, + PromptPrehookPayload, + PromptPosthookPayload, + ToolPreInvokePayload, + ToolPostInvokePayload, +) +from mcpgateway.plugins.framework.utils import import_module, matches, parse_class_name, payload_matches -# def test_server_ids(): -# condition1 = PluginCondition(server_ids={"1", "2"}) -# context1 = GlobalContext(server_id="1", tenant_id="4", request_id="5") +def test_server_ids(): + """Test conditional matching with server IDs, tenant IDs, and user patterns.""" + condition1 = PluginCondition(server_ids={"1", "2"}) + context1 = GlobalContext(server_id="1", tenant_id="4", request_id="5") -# payload1 = PromptPrehookPayload(prompt_id="test_prompt", args={}) + payload1 = PromptPrehookPayload(prompt_id="test_prompt", args={}) -# assert matches(condition=condition1, context=context1) -# assert pre_prompt_matches(payload1, [condition1], context1) + assert matches(condition=condition1, context=context1) + assert payload_matches(payload1, "prompt_pre_fetch", [condition1], context1) -# context2 = GlobalContext(server_id="3", tenant_id="6", request_id="1") -# assert not matches(condition=condition1, context=context2) -# assert not pre_prompt_matches(payload1, conditions=[condition1], context=context2) + context2 = GlobalContext(server_id="3", tenant_id="6", request_id="1") + assert not matches(condition=condition1, context=context2) + assert not payload_matches(payload1, "prompt_pre_fetch", [condition1], context2) -# condition2 = PluginCondition(server_ids={"1"}, tenant_ids={"4"}) + condition2 = PluginCondition(server_ids={"1"}, tenant_ids={"4"}) -# context2 = GlobalContext(server_id="1", tenant_id="4", request_id="1") + context2 = GlobalContext(server_id="1", tenant_id="4", request_id="1") -# assert matches(condition2, context2) -# assert pre_prompt_matches(payload1, conditions=[condition2], context=context2) + assert matches(condition2, context2) + assert payload_matches(payload1, "prompt_pre_fetch", [condition2], context2) -# context3 = GlobalContext(server_id="1", tenant_id="5", request_id="1") + context3 = GlobalContext(server_id="1", tenant_id="5", request_id="1") -# assert not matches(condition2, context3) -# assert not pre_prompt_matches(payload1, conditions=[condition2], context=context3) + assert not matches(condition2, context3) + assert not payload_matches(payload1, "prompt_pre_fetch", [condition2], context3) -# condition4 = PluginCondition(user_patterns=["blah", "barker", "bobby"]) -# context4 = GlobalContext(user="blah", request_id="1") + condition4 = PluginCondition(user_patterns=["blah", "barker", "bobby"]) + context4 = GlobalContext(user="blah", request_id="1") -# assert matches(condition4, context4) -# assert pre_prompt_matches(payload1, conditions=[condition4], context=context4) + assert matches(condition4, context4) + assert payload_matches(payload1, "prompt_pre_fetch", [condition4], context4) -# context5 = GlobalContext(user="barney", request_id="1") -# assert not matches(condition4, context5) -# assert not pre_prompt_matches(payload1, conditions=[condition4], context=context5) + context5 = GlobalContext(user="barney", request_id="1") + assert not matches(condition4, context5) + assert not payload_matches(payload1, "prompt_pre_fetch", [condition4], context5) -# condition5 = PluginCondition(server_ids={"1", "2"}, prompts={"test_prompt"}) + condition5 = PluginCondition(server_ids={"1", "2"}, prompts={"test_prompt"}) -# assert pre_prompt_matches(payload1, [condition5], context1) -# condition6 = PluginCondition(server_ids={"1", "2"}, prompts={"test_prompt2"}) -# assert not pre_prompt_matches(payload1, [condition6], context1) + assert payload_matches(payload1, "prompt_pre_fetch", [condition5], context1) + condition6 = PluginCondition(server_ids={"1", "2"}, prompts={"test_prompt2"}) + assert not payload_matches(payload1, "prompt_pre_fetch", [condition6], context1) # ============================================================================ @@ -107,191 +114,180 @@ def test_parse_class_name(): # ============================================================================ -# Test post_prompt_matches function +# Test payload_matches for prompt hooks # ============================================================================ -# def test_post_prompt_matches(): -# """Test the post_prompt_matches function.""" -# # Import required models -# # First-Party -# from mcpgateway.common.models import Message, PromptResult, TextContent +def test_payload_matches_prompt_post_fetch(): + """Test payload_matches for prompt_post_fetch hook.""" + # Test basic matching + payload = PromptPosthookPayload(prompt_id="greeting", result={"messages": []}) + condition = PluginCondition(prompts={"greeting"}) + context = GlobalContext(request_id="req1") -# # Test basic matching -# msg = Message(role="assistant", content=TextContent(type="text", text="Hello")) -# result = PromptResult(messages=[msg]) -# payload = PromptPosthookPayload(prompt_id="greeting", result=result) -# condition = PluginCondition(prompts={"greeting"}) -# context = GlobalContext(request_id="req1") + assert payload_matches(payload, "prompt_post_fetch", [condition], context) is True -# assert post_prompt_matches(payload, [condition], context) is True + # Test no match + payload2 = PromptPosthookPayload(prompt_id="other", result={"messages": []}) + assert payload_matches(payload2, "prompt_post_fetch", [condition], context) is False -# # Test no match -# payload2 = PromptPosthookPayload(prompt_id ="other", result=result) -# assert post_prompt_matches(payload2, [condition], context) is False + # Test with server_id condition + condition_with_server = PluginCondition(server_ids={"srv1"}, prompts={"greeting"}) + context_with_server = GlobalContext(request_id="req1", server_id="srv1") -# # Test with server_id condition -# condition_with_server = PluginCondition(server_ids={"srv1"}, prompts={"greeting"}) -# context_with_server = GlobalContext(request_id="req1", server_id="srv1") + assert payload_matches(payload, "prompt_post_fetch", [condition_with_server], context_with_server) is True -# assert post_prompt_matches(payload, [condition_with_server], context_with_server) is True + # Test with mismatched server_id + context_wrong_server = GlobalContext(request_id="req1", server_id="srv2") + assert payload_matches(payload, "prompt_post_fetch", [condition_with_server], context_wrong_server) is False -# # Test with mismatched server_id -# context_wrong_server = GlobalContext(request_id="req1", server_id="srv2") -# assert post_prompt_matches(payload, [condition_with_server], context_wrong_server) is False +def test_payload_matches_prompt_multiple_conditions(): + """Test payload_matches for prompts with multiple conditions (OR logic).""" + # Create the payload + payload = PromptPosthookPayload(prompt_id="greeting", result={"messages": []}) -# def test_post_prompt_matches_multiple_conditions(): -# """Test post_prompt_matches with multiple conditions (OR logic).""" -# # First-Party -# from mcpgateway.common.models import Message, PromptResult, TextContent + # First condition fails, second condition succeeds + condition1 = PluginCondition(server_ids={"srv1"}, prompts={"greeting"}) + condition2 = PluginCondition(server_ids={"srv2"}, prompts={"greeting"}) + context = GlobalContext(request_id="req1", server_id="srv2") -# # Create the payload -# msg = Message(role="assistant", content=TextContent(type="text", text="Hello")) -# result = PromptResult(messages=[msg]) -# payload = PromptPosthookPayload(prompt_id="greeting", result=result) + assert payload_matches(payload, "prompt_post_fetch", [condition1, condition2], context) is True -# # First condition fails, second condition succeeds -# condition1 = PluginCondition(server_ids={"srv1"}, prompts={"greeting"}) -# condition2 = PluginCondition(server_ids={"srv2"}, prompts={"greeting"}) -# context = GlobalContext(request_id="req1", server_id="srv2") + # Both conditions fail + context_no_match = GlobalContext(request_id="req1", server_id="srv3") + assert payload_matches(payload, "prompt_post_fetch", [condition1, condition2], context_no_match) is False -# assert post_prompt_matches(payload, [condition1, condition2], context) is True - -# # Both conditions fail -# context_no_match = GlobalContext(request_id="req1", server_id="srv3") -# assert post_prompt_matches(payload, [condition1, condition2], context_no_match) is False - -# # Test reset logic between conditions -# condition3 = PluginCondition(server_ids={"srv3"}, prompts={"other"}) -# condition4 = PluginCondition(prompts={"greeting"}) -# assert post_prompt_matches(payload, [condition3, condition4], context_no_match) is True + # Test reset logic between conditions + condition3 = PluginCondition(server_ids={"srv3"}, prompts={"other"}) + condition4 = PluginCondition(prompts={"greeting"}) + assert payload_matches(payload, "prompt_post_fetch", [condition3, condition4], context_no_match) is True # ============================================================================ -# Test pre_tool_matches function +# Test payload_matches for tool hooks # ============================================================================ -# def test_pre_tool_matches(): -# """Test the pre_tool_matches function.""" -# # Test basic matching -# payload = ToolPreInvokePayload(name="calculator", args={"operation": "add"}) -# condition = PluginCondition(tools={"calculator"}) -# context = GlobalContext(request_id="req1") +def test_payload_matches_tool_pre_invoke(): + """Test payload_matches for tool_pre_invoke hook.""" + # Test basic matching + payload = ToolPreInvokePayload(name="calculator", args={"operation": "add"}) + condition = PluginCondition(tools={"calculator"}) + context = GlobalContext(request_id="req1") -# assert pre_tool_matches(payload, [condition], context) is True + assert payload_matches(payload, "tool_pre_invoke", [condition], context) is True -# # Test no match -# payload2 = ToolPreInvokePayload(name="other_tool", args={}) -# assert pre_tool_matches(payload2, [condition], context) is False + # Test no match + payload2 = ToolPreInvokePayload(name="other_tool", args={}) + assert payload_matches(payload2, "tool_pre_invoke", [condition], context) is False -# # Test with server_id condition -# condition_with_server = PluginCondition(server_ids={"srv1"}, tools={"calculator"}) -# context_with_server = GlobalContext(request_id="req1", server_id="srv1") + # Test with server_id condition + condition_with_server = PluginCondition(server_ids={"srv1"}, tools={"calculator"}) + context_with_server = GlobalContext(request_id="req1", server_id="srv1") -# assert pre_tool_matches(payload, [condition_with_server], context_with_server) is True + assert payload_matches(payload, "tool_pre_invoke", [condition_with_server], context_with_server) is True -# # Test with mismatched server_id -# context_wrong_server = GlobalContext(request_id="req1", server_id="srv2") -# assert pre_tool_matches(payload, [condition_with_server], context_wrong_server) is False + # Test with mismatched server_id + context_wrong_server = GlobalContext(request_id="req1", server_id="srv2") + assert payload_matches(payload, "tool_pre_invoke", [condition_with_server], context_wrong_server) is False -# def test_pre_tool_matches_multiple_conditions(): -# """Test pre_tool_matches with multiple conditions (OR logic).""" -# payload = ToolPreInvokePayload(name="calculator", args={"operation": "add"}) +def test_payload_matches_tool_pre_invoke_multiple_conditions(): + """Test payload_matches for tool_pre_invoke with multiple conditions (OR logic).""" + payload = ToolPreInvokePayload(name="calculator", args={"operation": "add"}) -# # First condition fails, second condition succeeds -# condition1 = PluginCondition(server_ids={"srv1"}, tools={"calculator"}) -# condition2 = PluginCondition(server_ids={"srv2"}, tools={"calculator"}) -# context = GlobalContext(request_id="req1", server_id="srv2") + # First condition fails, second condition succeeds + condition1 = PluginCondition(server_ids={"srv1"}, tools={"calculator"}) + condition2 = PluginCondition(server_ids={"srv2"}, tools={"calculator"}) + context = GlobalContext(request_id="req1", server_id="srv2") -# assert pre_tool_matches(payload, [condition1, condition2], context) is True + assert payload_matches(payload, "tool_pre_invoke", [condition1, condition2], context) is True -# # Both conditions fail -# context_no_match = GlobalContext(request_id="req1", server_id="srv3") -# assert pre_tool_matches(payload, [condition1, condition2], context_no_match) is False + # Both conditions fail + context_no_match = GlobalContext(request_id="req1", server_id="srv3") + assert payload_matches(payload, "tool_pre_invoke", [condition1, condition2], context_no_match) is False -# # Test reset logic between conditions -# condition3 = PluginCondition(server_ids={"srv3"}, tools={"other"}) -# condition4 = PluginCondition(tools={"calculator"}) -# assert pre_tool_matches(payload, [condition3, condition4], context_no_match) is True + # Test reset logic between conditions + condition3 = PluginCondition(server_ids={"srv3"}, tools={"other"}) + condition4 = PluginCondition(tools={"calculator"}) + assert payload_matches(payload, "tool_pre_invoke", [condition3, condition4], context_no_match) is True # ============================================================================ -# Test post_tool_matches function +# Test payload_matches for tool_post_invoke # ============================================================================ -# def test_post_tool_matches(): -# """Test the post_tool_matches function.""" -# # Test basic matching -# payload = ToolPostInvokePayload(name="calculator", result={"value": 42}) -# condition = PluginCondition(tools={"calculator"}) -# context = GlobalContext(request_id="req1") +def test_payload_matches_tool_post_invoke(): + """Test payload_matches for tool_post_invoke hook.""" + # Test basic matching + payload = ToolPostInvokePayload(name="calculator", result={"value": 42}) + condition = PluginCondition(tools={"calculator"}) + context = GlobalContext(request_id="req1") -# assert post_tool_matches(payload, [condition], context) is True + assert payload_matches(payload, "tool_post_invoke", [condition], context) is True -# # Test no match -# payload2 = ToolPostInvokePayload(name="other_tool", result={}) -# assert post_tool_matches(payload2, [condition], context) is False + # Test no match + payload2 = ToolPostInvokePayload(name="other_tool", result={}) + assert payload_matches(payload2, "tool_post_invoke", [condition], context) is False -# # Test with server_id condition -# condition_with_server = PluginCondition(server_ids={"srv1"}, tools={"calculator"}) -# context_with_server = GlobalContext(request_id="req1", server_id="srv1") + # Test with server_id condition + condition_with_server = PluginCondition(server_ids={"srv1"}, tools={"calculator"}) + context_with_server = GlobalContext(request_id="req1", server_id="srv1") -# assert post_tool_matches(payload, [condition_with_server], context_with_server) is True + assert payload_matches(payload, "tool_post_invoke", [condition_with_server], context_with_server) is True -# # Test with mismatched server_id -# context_wrong_server = GlobalContext(request_id="req1", server_id="srv2") -# assert post_tool_matches(payload, [condition_with_server], context_wrong_server) is False + # Test with mismatched server_id + context_wrong_server = GlobalContext(request_id="req1", server_id="srv2") + assert payload_matches(payload, "tool_post_invoke", [condition_with_server], context_wrong_server) is False -# def test_post_tool_matches_multiple_conditions(): -# """Test post_tool_matches with multiple conditions (OR logic).""" -# payload = ToolPostInvokePayload(name="calculator", result={"value": 42}) +def test_payload_matches_tool_post_invoke_multiple_conditions(): + """Test payload_matches for tool_post_invoke with multiple conditions (OR logic).""" + payload = ToolPostInvokePayload(name="calculator", result={"value": 42}) -# # First condition fails, second condition succeeds -# condition1 = PluginCondition(server_ids={"srv1"}, tools={"calculator"}) -# condition2 = PluginCondition(server_ids={"srv2"}, tools={"calculator"}) -# context = GlobalContext(request_id="req1", server_id="srv2") + # First condition fails, second condition succeeds + condition1 = PluginCondition(server_ids={"srv1"}, tools={"calculator"}) + condition2 = PluginCondition(server_ids={"srv2"}, tools={"calculator"}) + context = GlobalContext(request_id="req1", server_id="srv2") -# assert post_tool_matches(payload, [condition1, condition2], context) is True + assert payload_matches(payload, "tool_post_invoke", [condition1, condition2], context) is True -# # Both conditions fail -# context_no_match = GlobalContext(request_id="req1", server_id="srv3") -# assert post_tool_matches(payload, [condition1, condition2], context_no_match) is False + # Both conditions fail + context_no_match = GlobalContext(request_id="req1", server_id="srv3") + assert payload_matches(payload, "tool_post_invoke", [condition1, condition2], context_no_match) is False -# # Test reset logic between conditions -# condition3 = PluginCondition(server_ids={"srv3"}, tools={"other"}) -# condition4 = PluginCondition(tools={"calculator"}) -# assert post_tool_matches(payload, [condition3, condition4], context_no_match) is True + # Test reset logic between conditions + condition3 = PluginCondition(server_ids={"srv3"}, tools={"other"}) + condition4 = PluginCondition(tools={"calculator"}) + assert payload_matches(payload, "tool_post_invoke", [condition3, condition4], context_no_match) is True # ============================================================================ -# Test enhanced pre_prompt_matches scenarios +# Test payload_matches for prompt_pre_fetch with multiple conditions # ============================================================================ -# def test_pre_prompt_matches_multiple_conditions(): -# """Test pre_prompt_matches with multiple conditions to cover OR logic paths.""" -# payload = PromptPrehookPayload(prompt_id="greeting", args={}) +def test_payload_matches_prompt_pre_fetch_multiple_conditions(): + """Test payload_matches for prompt_pre_fetch with multiple conditions to cover OR logic paths.""" + payload = PromptPrehookPayload(prompt_id="greeting", args={}) -# # First condition fails, second condition succeeds -# condition1 = PluginCondition(server_ids={"srv1"}, prompts={"greeting"}) -# condition2 = PluginCondition(server_ids={"srv2"}, prompts={"greeting"}) -# context = GlobalContext(request_id="req1", server_id="srv2") + # First condition fails, second condition succeeds + condition1 = PluginCondition(server_ids={"srv1"}, prompts={"greeting"}) + condition2 = PluginCondition(server_ids={"srv2"}, prompts={"greeting"}) + context = GlobalContext(request_id="req1", server_id="srv2") -# assert pre_prompt_matches(payload, [condition1, condition2], context) is True + assert payload_matches(payload, "prompt_pre_fetch", [condition1, condition2], context) is True -# # Both conditions fail -# context_no_match = GlobalContext(request_id="req1", server_id="srv3") -# assert pre_prompt_matches(payload, [condition1, condition2], context_no_match) is False + # Both conditions fail + context_no_match = GlobalContext(request_id="req1", server_id="srv3") + assert payload_matches(payload, "prompt_pre_fetch", [condition1, condition2], context_no_match) is False -# # Test reset logic between conditions (line 140) -# condition3 = PluginCondition(server_ids={"srv3"}, prompts={"other"}) -# condition4 = PluginCondition(prompts={"greeting"}) -# assert pre_prompt_matches(payload, [condition3, condition4], context_no_match) is True + # Test reset logic between conditions (OR logic) + condition3 = PluginCondition(server_ids={"srv3"}, prompts={"other"}) + condition4 = PluginCondition(prompts={"greeting"}) + assert payload_matches(payload, "prompt_pre_fetch", [condition3, condition4], context_no_match) is True # ============================================================================ From 5f8bcbf573535972630877b82a51e34b8c33f844 Mon Sep 17 00:00:00 2001 From: Frederico Araujo Date: Mon, 3 Nov 2025 20:56:41 -0500 Subject: [PATCH 09/20] chore: removed unrecognized mypy option Signed-off-by: Frederico Araujo --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bd38514b0..c68e605f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -515,7 +515,6 @@ warn_unreachable = true # Warn about unreachable code warn_unused_ignores = true # Warn if a "# type: ignore" is unnecessary warn_unused_configs = true # Warn about unused config options warn_redundant_casts = true # Warn if a cast does nothing -warn_unused_coroutine = true # Warn if an unused async coroutine is defined strict_equality = true # Disallow ==/!= between incompatible types # Output formatting From d3e4ea953b8390d6f7ab93d0686614c49bd46cd8 Mon Sep 17 00:00:00 2001 From: Frederico Araujo Date: Mon, 3 Nov 2025 21:28:14 -0500 Subject: [PATCH 10/20] fix: static type check issues Signed-off-by: Frederico Araujo --- mcpgateway/config.py | 2 +- mcpgateway/plugins/framework/base.py | 2 +- mcpgateway/plugins/framework/external/mcp/client.py | 4 ++-- mcpgateway/plugins/framework/external/mcp/server/runtime.py | 2 +- mcpgateway/plugins/framework/hooks/agents.py | 2 +- mcpgateway/plugins/framework/hooks/http.py | 4 ++-- mcpgateway/plugins/framework/hooks/prompts.py | 2 +- mcpgateway/plugins/framework/hooks/resources.py | 2 +- mcpgateway/plugins/framework/hooks/tools.py | 2 +- mcpgateway/plugins/framework/manager.py | 2 +- mcpgateway/plugins/tools/cli.py | 6 +++--- 11 files changed, 15 insertions(+), 15 deletions(-) diff --git a/mcpgateway/config.py b/mcpgateway/config.py index 66b0a4650..4dd893f71 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -902,7 +902,7 @@ def parse_issuers(cls, v: Any) -> set[str]: # Plugin CLI settings plugins_cli_completion: bool = Field(default=False, description="Enable auto-completion for plugins CLI") - plugins_cli_markup_mode: str | None = Field(default=None, description="Set markup mode for plugins CLI") + plugins_cli_markup_mode: Literal["markdown", "rich", "disabled"] | None = Field(default=None, description="Set markup mode for plugins CLI") # Development dev_mode: bool = False diff --git a/mcpgateway/plugins/framework/base.py b/mcpgateway/plugins/framework/base.py index c41aac070..e104013de 100644 --- a/mcpgateway/plugins/framework/base.py +++ b/mcpgateway/plugins/framework/base.py @@ -585,7 +585,7 @@ def name(self) -> str: return self._hook @property - def hook(self) -> Callable[[PluginPayload, PluginContext], Awaitable[PluginResult]]: + def hook(self) -> Callable[[PluginPayload, PluginContext], Awaitable[PluginResult]] | None: """The hooking function that can be invoked within the reference. Returns: diff --git a/mcpgateway/plugins/framework/external/mcp/client.py b/mcpgateway/plugins/framework/external/mcp/client.py index 0f90b7292..b334f0521 100644 --- a/mcpgateway/plugins/framework/external/mcp/client.py +++ b/mcpgateway/plugins/framework/external/mcp/client.py @@ -329,6 +329,6 @@ def __init__(self, hook: str, plugin_ref: PluginRef): # pylint: disable=super-i self._plugin_ref = plugin_ref self._hook = hook if hasattr(plugin_ref.plugin, INVOKE_HOOK): - self._func: Callable[[PluginPayload, PluginContext], Awaitable[PluginResult]] = partial(plugin_ref.plugin.invoke_hook, hook) - if not self._func: + self._func: Callable[[PluginPayload, PluginContext], Awaitable[PluginResult]] = partial(plugin_ref.plugin.invoke_hook, hook) # type: ignore[attr-defined] + else: raise PluginError(error=PluginErrorModel(message=f"Plugin: {plugin_ref.plugin.name} is not an external plugin", plugin_name=plugin_ref.plugin.name)) diff --git a/mcpgateway/plugins/framework/external/mcp/server/runtime.py b/mcpgateway/plugins/framework/external/mcp/server/runtime.py index 5cb2241b8..b4e57a39e 100755 --- a/mcpgateway/plugins/framework/external/mcp/server/runtime.py +++ b/mcpgateway/plugins/framework/external/mcp/server/runtime.py @@ -235,7 +235,7 @@ async def health_check(_request: Request): await server.serve() -async def run(): +async def run() -> None: """Run the external plugin server with FastMCP. Supports both stdio and HTTP transports. Auto-detects transport based on stdin diff --git a/mcpgateway/plugins/framework/hooks/agents.py b/mcpgateway/plugins/framework/hooks/agents.py index db99139b3..eea547c9a 100644 --- a/mcpgateway/plugins/framework/hooks/agents.py +++ b/mcpgateway/plugins/framework/hooks/agents.py @@ -123,7 +123,7 @@ class AgentPostInvokePayload(PluginPayload): AgentPostInvokeResult = PluginResult[AgentPostInvokePayload] -def _register_agent_hooks(): +def _register_agent_hooks() -> None: """Register agent hooks in the global registry. This is called lazily to avoid circular import issues. diff --git a/mcpgateway/plugins/framework/hooks/http.py b/mcpgateway/plugins/framework/hooks/http.py index 675bc285c..cd8c4e120 100644 --- a/mcpgateway/plugins/framework/hooks/http.py +++ b/mcpgateway/plugins/framework/hooks/http.py @@ -17,7 +17,7 @@ class HttpHeaderPayload(RootModel[dict[str, str]], PluginPayload): """An HTTP dictionary of headers used in the pre/post HTTP forwarding hooks.""" - def __iter__(self): + def __iter__(self): # type: ignore[no-untyped-def] """Custom iterator function to override root attribute. Returns: @@ -45,7 +45,7 @@ def __setitem__(self, key: str, value: str) -> None: """ self.root[key] = value - def __len__(self): + def __len__(self) -> int: """Custom len function to override root attribute. Returns: diff --git a/mcpgateway/plugins/framework/hooks/prompts.py b/mcpgateway/plugins/framework/hooks/prompts.py index a2349530f..d57e6bf34 100644 --- a/mcpgateway/plugins/framework/hooks/prompts.py +++ b/mcpgateway/plugins/framework/hooks/prompts.py @@ -106,7 +106,7 @@ class PromptPosthookPayload(PluginPayload): PromptPosthookResult = PluginResult[PromptPosthookPayload] -def _register_prompt_hooks(): +def _register_prompt_hooks() -> None: """Register prompt hooks in the global registry. This is called lazily to avoid circular import issues. diff --git a/mcpgateway/plugins/framework/hooks/resources.py b/mcpgateway/plugins/framework/hooks/resources.py index cf5390bbe..b31439130 100644 --- a/mcpgateway/plugins/framework/hooks/resources.py +++ b/mcpgateway/plugins/framework/hooks/resources.py @@ -96,7 +96,7 @@ class ResourcePostFetchPayload(PluginPayload): ResourcePostFetchResult = PluginResult[ResourcePostFetchPayload] -def _register_resource_hooks(): +def _register_resource_hooks() -> None: """Register resource hooks in the global registry. This is called lazily to avoid circular import issues. diff --git a/mcpgateway/plugins/framework/hooks/tools.py b/mcpgateway/plugins/framework/hooks/tools.py index b9d804958..7560d05b0 100644 --- a/mcpgateway/plugins/framework/hooks/tools.py +++ b/mcpgateway/plugins/framework/hooks/tools.py @@ -99,7 +99,7 @@ class ToolPostInvokePayload(PluginPayload): ToolPostInvokeResult = PluginResult[ToolPostInvokePayload] -def _register_tool_hooks(): +def _register_tool_hooks() -> None: """Register Tool hooks in the global registry. This is called lazily to avoid circular import issues. diff --git a/mcpgateway/plugins/framework/manager.py b/mcpgateway/plugins/framework/manager.py index e0d5c92db..48b3c9d27 100644 --- a/mcpgateway/plugins/framework/manager.py +++ b/mcpgateway/plugins/framework/manager.py @@ -566,7 +566,7 @@ async def invoke_hook_for_plugin( payload: Union[PluginPayload, dict[str, Any], str], context: PluginContext, violations_as_exceptions: bool = False, - payload_as_json=False, + payload_as_json: bool = False, ) -> PluginResult: """Invoke a specific hook for a single named plugin. diff --git a/mcpgateway/plugins/tools/cli.py b/mcpgateway/plugins/tools/cli.py index 3029cf0d6..01a2b5cd0 100644 --- a/mcpgateway/plugins/tools/cli.py +++ b/mcpgateway/plugins/tools/cli.py @@ -73,7 +73,7 @@ # --------------------------------------------------------------------------- -def command_exists(command_name): +def command_exists(command_name: str) -> bool: """Check if a given command-line utility exists and is executable. Args: @@ -132,7 +132,7 @@ def bootstrap( answers_file: Optional[Annotated[typer.FileText, typer.Option("--answers_file", "-a", help="The answers file to be used for bootstrapping.")]] = None, defaults: Annotated[bool, typer.Option("--defaults", help="Bootstrap with defaults.")] = False, dry_run: Annotated[bool, typer.Option("--dry_run", help="Run but do not make any changes.")] = False, -): +) -> None: """Boostrap a new plugin project from a template. Args: @@ -161,7 +161,7 @@ def bootstrap( @app.callback() -def callback(): # pragma: no cover +def callback() -> None: # pragma: no cover """This function exists to force 'bootstrap' to be a subcommand.""" From d4a24373e5cf52fe55469994fc26b741cf02a20a Mon Sep 17 00:00:00 2001 From: Teryl Taylor Date: Wed, 5 Nov 2025 10:06:04 -0700 Subject: [PATCH 11/20] fix: updated schemas imports. Signed-off-by: Teryl Taylor --- mcpgateway/schemas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcpgateway/schemas.py b/mcpgateway/schemas.py index c9dc008c1..792e6891c 100644 --- a/mcpgateway/schemas.py +++ b/mcpgateway/schemas.py @@ -33,7 +33,7 @@ from pydantic import AnyHttpUrl, BaseModel, ConfigDict, EmailStr, Field, field_serializer, field_validator, model_validator, ValidationInfo # First-Party -from mcpgateway.common.models import ImageContent +from mcpgateway.common.models import Annotations, ImageContent from mcpgateway.common.models import Prompt as MCPPrompt from mcpgateway.common.models import Resource as MCPResource from mcpgateway.common.models import ResourceContent, TextContent From 79dfcaa971aedbfd4bbfdba5b30fb0096f6fb852 Mon Sep 17 00:00:00 2001 From: Frederico Araujo Date: Wed, 5 Nov 2025 22:12:48 -0500 Subject: [PATCH 12/20] fix: doctests Signed-off-by: Frederico Araujo --- mcpgateway/plugins/framework/base.py | 18 +++++++++--------- .../framework/external/mcp/server/server.py | 16 +++++++--------- mcpgateway/plugins/framework/hooks/registry.py | 6 +++--- mcpgateway/plugins/framework/registry.py | 8 +++++--- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/mcpgateway/plugins/framework/base.py b/mcpgateway/plugins/framework/base.py index e104013de..60c64cc18 100644 --- a/mcpgateway/plugins/framework/base.py +++ b/mcpgateway/plugins/framework/base.py @@ -31,14 +31,14 @@ class Plugin(ABC): Examples: >>> from mcpgateway.plugins.framework import PluginConfig, PluginMode - >>> from mcpgateway.plugins.mcp.entities import HookType + >>> from mcpgateway.plugins.framework.hooks.prompts import PromptHookType >>> config = PluginConfig( ... name="test_plugin", ... description="Test plugin", ... author="test", ... kind="mcpgateway.plugins.framework.Plugin", ... version="1.0.0", - ... hooks=[HookType.PROMPT_PRE_FETCH], + ... hooks=[PromptHookType.PROMPT_PRE_FETCH], ... tags=["test"], ... mode=PluginMode.ENFORCE, ... priority=50 @@ -50,7 +50,7 @@ class Plugin(ABC): 50 >>> plugin.mode - >>> HookType.PROMPT_PRE_FETCH in plugin.hooks + >>> PromptHookType.PROMPT_PRE_FETCH in plugin.hooks True """ @@ -71,14 +71,14 @@ def __init__( Examples: >>> from mcpgateway.plugins.framework import PluginConfig - >>> from mcpgateway.plugins.mcp.entities import HookType + >>> from mcpgateway.plugins.framework.hooks.prompts import PromptHookType >>> config = PluginConfig( ... name="simple_plugin", ... description="Simple test", ... author="test", ... kind="test.Plugin", ... version="1.0.0", - ... hooks=[HookType.PROMPT_POST_FETCH], + ... hooks=[PromptHookType.PROMPT_POST_FETCH], ... tags=["simple"] ... ) >>> plugin = Plugin(config) @@ -234,14 +234,14 @@ class PluginRef: Examples: >>> from mcpgateway.plugins.framework import PluginConfig, PluginMode - >>> from mcpgateway.plugins.mcp.entities import HookType + >>> from mcpgateway.plugins.framework.hooks.prompts import PromptHookType >>> config = PluginConfig( ... name="ref_test", ... description="Reference test", ... author="test", ... kind="test.Plugin", ... version="1.0.0", - ... hooks=[HookType.PROMPT_PRE_FETCH], + ... hooks=[PromptHookType.PROMPT_PRE_FETCH], ... tags=["ref", "test"], ... mode=PluginMode.PERMISSIVE, ... priority=100 @@ -268,14 +268,14 @@ def __init__(self, plugin: Plugin): Examples: >>> from mcpgateway.plugins.framework import PluginConfig - >>> from mcpgateway.plugins.mcp.entities import HookType + >>> from mcpgateway.plugins.framework.hooks.prompts import PromptHookType >>> config = PluginConfig( ... name="plugin_ref", ... description="Test", ... author="test", ... kind="test.Plugin", ... version="1.0.0", - ... hooks=[HookType.PROMPT_POST_FETCH], + ... hooks=[PromptHookType.PROMPT_POST_FETCH], ... tags=[] ... ) >>> plugin = Plugin(config) diff --git a/mcpgateway/plugins/framework/external/mcp/server/server.py b/mcpgateway/plugins/framework/external/mcp/server/server.py index 218d2a383..adf8036fe 100644 --- a/mcpgateway/plugins/framework/external/mcp/server/server.py +++ b/mcpgateway/plugins/framework/external/mcp/server/server.py @@ -41,7 +41,7 @@ def __init__(self, config_path: str | None = None) -> None: If set, this attribute overrides the value in PLUGINS_CONFIG_PATH. Examples: - >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway.plugins/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") >>> server is not None True """ @@ -57,7 +57,7 @@ async def get_plugin_configs(self) -> list[dict]: Examples: >>> import asyncio - >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway.plugins/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") >>> plugins = asyncio.run(server.get_plugin_configs()) >>> len(plugins) > 0 True @@ -79,7 +79,7 @@ async def get_plugin_config(self, name: str) -> dict | None: Examples: >>> import asyncio - >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway.plugins/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") >>> c = asyncio.run(server.get_plugin_config(name = "DenyListPlugin")) >>> c is not None True @@ -111,16 +111,14 @@ async def invoke_hook(self, hook_type: str, plugin_name: str, payload: Dict[str, >>> import asyncio >>> import os >>> os.environ["PYTHONPATH"] = "." - >>> from mcpgateway.plugins.framework import GlobalContext, PromptPrehookPayload, PluginContext, PromptPrehookResult - >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway.plugins/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") - >>> def prompt_pre_fetch_func(plugin: Plugin, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: - ... return plugin.prompt_pre_fetch(payload, context) - >>> payload = PromptPrehookPayload(name="test_prompt", args={"user": "This is so innovative"}) + >>> from mcpgateway.plugins.framework import GlobalContext, Plugin, PromptHookType, PromptPrehookPayload, PluginContext, PromptPrehookResult + >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") + >>> payload = PromptPrehookPayload(prompt_id="123", name="test_prompt", args={"user": "This is so innovative"}) >>> context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) >>> initialized = asyncio.run(server.initialize()) >>> initialized True - >>> result = asyncio.run(server.invoke_hook(PromptPrehookPayload, prompt_pre_fetch_func, "DenyListPlugin", payload.model_dump(), context.model_dump())) + >>> result = asyncio.run(server.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, "DenyListPlugin", payload.model_dump(), context.model_dump())) >>> result is not None True >>> result["result"]["continue_processing"] diff --git a/mcpgateway/plugins/framework/hooks/registry.py b/mcpgateway/plugins/framework/hooks/registry.py index 570b9cb42..94608e3a3 100644 --- a/mcpgateway/plugins/framework/hooks/registry.py +++ b/mcpgateway/plugins/framework/hooks/registry.py @@ -30,7 +30,7 @@ class HookRegistry: >>> registry = HookRegistry() >>> registry.register_hook("test_hook", PluginPayload, PluginResult) >>> registry.get_payload_type("test_hook") - + >>> registry.get_result_type("test_hook") """ @@ -115,8 +115,8 @@ def json_to_payload(self, hook_type: str, payload: Union[str, dict]) -> PluginPa Examples: >>> registry = HookRegistry() - >>> from mcpgateway.plugins.framework import PluginPayload, PluginResult - >>> registry.register_hook("test", PluginPayload, PluginResult) + >>> from mcpgateway.plugins.framework.hooks import PromptPrehookPayload, PromptPrehookResult + >>> registry.register_hook("test", PromptPrehookPayload, PromptPrehookResult) >>> payload = registry.json_to_payload("test", "{}") """ payload_class = self.get_payload_type(hook_type) diff --git a/mcpgateway/plugins/framework/registry.py b/mcpgateway/plugins/framework/registry.py index a6e0d59e3..e987e46f0 100644 --- a/mcpgateway/plugins/framework/registry.py +++ b/mcpgateway/plugins/framework/registry.py @@ -26,7 +26,7 @@ class PluginInstanceRegistry: Examples: >>> from mcpgateway.plugins.framework import Plugin, PluginConfig - >>> from mcpgateway.plugins.mcp.entities import HookType + >>> from mcpgateway.plugins.framework.hooks.prompts import PromptHookType >>> registry = PluginInstanceRegistry() >>> config = PluginConfig( ... name="test", @@ -34,14 +34,16 @@ class PluginInstanceRegistry: ... author="test", ... kind="test.Plugin", ... version="1.0", - ... hooks=[HookType.PROMPT_PRE_FETCH], + ... hooks=[PromptHookType.PROMPT_PRE_FETCH], ... tags=[] ... ) + >>> def prompt_pre_fetch(self, payload, context): ... >>> plugin = Plugin(config) + >>> plugin.prompt_pre_fetch = prompt_pre_fetch >>> registry.register(plugin) >>> registry.get_plugin("test").name 'test' - >>> len(registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH)) + >>> len(registry.get_plugins_for_hook(PromptHookType.PROMPT_PRE_FETCH)) 1 >>> registry.unregister("test") >>> registry.get_plugin("test") is None From 0ebb6495e32b089211cda0b216e2297ca49d2aea Mon Sep 17 00:00:00 2001 From: Frederico Araujo Date: Wed, 5 Nov 2025 23:41:59 -0500 Subject: [PATCH 13/20] fix: remaining doctests Signed-off-by: Frederico Araujo --- mcpgateway/plugins/framework/hooks/registry.py | 4 ++-- mcpgateway/plugins/framework/manager.py | 17 ++++++++--------- mcpgateway/plugins/framework/registry.py | 4 ++-- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/mcpgateway/plugins/framework/hooks/registry.py b/mcpgateway/plugins/framework/hooks/registry.py index 94608e3a3..177175471 100644 --- a/mcpgateway/plugins/framework/hooks/registry.py +++ b/mcpgateway/plugins/framework/hooks/registry.py @@ -115,9 +115,9 @@ def json_to_payload(self, hook_type: str, payload: Union[str, dict]) -> PluginPa Examples: >>> registry = HookRegistry() - >>> from mcpgateway.plugins.framework.hooks import PromptPrehookPayload, PromptPrehookResult + >>> from mcpgateway.plugins.framework.hooks.prompts import PromptPrehookPayload, PromptPrehookResult >>> registry.register_hook("test", PromptPrehookPayload, PromptPrehookResult) - >>> payload = registry.json_to_payload("test", "{}") + >>> payload = registry.json_to_payload("test", {"prompt_id": "123"}) """ payload_class = self.get_payload_type(hook_type) if not payload_class: diff --git a/mcpgateway/plugins/framework/manager.py b/mcpgateway/plugins/framework/manager.py index 48b3c9d27..d8eddb3fe 100644 --- a/mcpgateway/plugins/framework/manager.py +++ b/mcpgateway/plugins/framework/manager.py @@ -21,8 +21,8 @@ >>> # Create test payload and context >>> from mcpgateway.plugins.framework.models import GlobalContext - >>> from mcpgateway.plugins.mcp.entities.models import PromptPrehookPayload - >>> payload = PromptPrehookPayload(name="test", args={"user": "input"}) + >>> from mcpgateway.plugins.framework.hooks.prompts import PromptPrehookPayload + >>> payload = PromptPrehookPayload(prompt_id="123", name="test", args={"user": "input"}) >>> context = GlobalContext(request_id="123") >>> # result, contexts = await manager.prompt_pre_fetch(payload, context) # Called in async context """ @@ -79,8 +79,7 @@ class PluginExecutor: - Metadata aggregation from multiple plugins Examples: - >>> from mcpgateway.plugins.mcp.entities.models import PromptPrehookPayload - >>> executor = PluginExecutor[PromptPrehookPayload]() + >>> executor = PluginExecutor() >>> # In async context: >>> # result, contexts = await executor.execute( >>> # plugins=[plugin1, plugin2], @@ -132,14 +131,14 @@ async def execute( Examples: >>> # Execute plugins with timeout protection - >>> from mcpgateway.plugins.mcp.entities.models import HookType + >>> from mcpgateway.plugins.framework.hooks.prompts import PromptHookType >>> executor = PluginExecutor(timeout=30) >>> # Assuming you have a registry instance: - >>> # plugins = registry.get_plugins_for_hook(HookType.PROMPT_PRE_FETCH) + >>> # plugins = registry.get_plugins_for_hook(PromptHookType.PROMPT_PRE_FETCH) >>> # In async context: >>> # result, contexts = await executor.execute( >>> # plugins=plugins, - >>> # payload=PromptPrehookPayload(name="test", args={}), + >>> # payload=PromptPrehookPayload(prompt_id="123", name="test", args={}), >>> # global_context=GlobalContext(request_id="123"), >>> # plugin_run=pre_prompt_fetch, >>> # compare=pre_prompt_matches @@ -364,8 +363,8 @@ class PluginManager: >>> >>> # Execute prompt hooks >>> from mcpgateway.plugins.framework.models import GlobalContext - >>> from mcpgateway.plugins.mcp.entities.models import PromptPrehookPayload - >>> payload = PromptPrehookPayload(name="test", args={}) + >>> from mcpgateway.plugins.framework.hooks.prompts import PromptPrehookPayload + >>> payload = PromptPrehookPayload(prompt_id="123", name="test", args={}) >>> context = GlobalContext(request_id="req-123") >>> # In async context: >>> # result, contexts = await manager.prompt_pre_fetch(payload, context) diff --git a/mcpgateway/plugins/framework/registry.py b/mcpgateway/plugins/framework/registry.py index e987e46f0..28c5259dc 100644 --- a/mcpgateway/plugins/framework/registry.py +++ b/mcpgateway/plugins/framework/registry.py @@ -37,13 +37,13 @@ class PluginInstanceRegistry: ... hooks=[PromptHookType.PROMPT_PRE_FETCH], ... tags=[] ... ) - >>> def prompt_pre_fetch(self, payload, context): ... + >>> async def prompt_pre_fetch(payload, context): ... >>> plugin = Plugin(config) >>> plugin.prompt_pre_fetch = prompt_pre_fetch >>> registry.register(plugin) >>> registry.get_plugin("test").name 'test' - >>> len(registry.get_plugins_for_hook(PromptHookType.PROMPT_PRE_FETCH)) + >>> len(registry.get_hook_refs_for_hook(PromptHookType.PROMPT_PRE_FETCH)) 1 >>> registry.unregister("test") >>> registry.get_plugin("test") is None From 88739c469df84fcaaaa75accc776cbeafdb28065 Mon Sep 17 00:00:00 2001 From: Frederico Araujo Date: Thu, 6 Nov 2025 00:31:08 -0500 Subject: [PATCH 14/20] fix: lint issues Signed-off-by: Frederico Araujo --- mcp-servers/templates/go/copier.yaml | 1 - mcpgateway/services/tool_service.py | 4 +- .../ai_artifacts_normalizer.py | 2 +- plugins/altk_json_processor/json_processor.py | 2 +- .../argument_normalizer.py | 19 ++- .../cached_tool_result/cached_tool_result.py | 2 +- plugins/circuit_breaker/circuit_breaker.py | 2 +- .../citation_validator/citation_validator.py | 2 +- plugins/code_formatter/code_formatter.py | 2 +- .../code_safety_linter/code_safety_linter.py | 2 +- plugins/config.yaml | 8 +- .../content_moderation/content_moderation.py | 142 ++++++++++++++++-- plugins/deny_filter/deny.py | 9 +- .../external/clamav_server/clamav_plugin.py | 68 +++++++-- plugins/external/llmguard/docker-compose.yaml | 4 +- .../llmguard/examples/config-all-in-one.yaml | 14 +- .../examples/config-complex-policy.yaml | 2 +- .../examples/config-input-output-filter.yaml | 2 +- ...g-separate-plugins-filters-sanitizers.yaml | 2 +- .../external/llmguard/llmguardplugin/cache.py | 14 +- .../llmguard/llmguardplugin/llmguard.py | 32 +++- .../llmguard/llmguardplugin/plugin.py | 27 +++- .../llmguard/llmguardplugin/policy.py | 8 +- .../llmguard/resources/plugins/config.yaml | 2 +- .../external/opa/opapluginfilter/plugin.py | 22 +-- .../opa/resources/plugins/config.yaml | 2 +- plugins/external/opa/tests/test_all.py | 42 +++++- .../file_type_allowlist.py | 2 +- .../harmful_content_detector.py | 2 +- plugins/header_injector/header_injector.py | 2 +- plugins/html_to_markdown/html_to_markdown.py | 2 +- plugins/json_repair/json_repair.py | 2 +- .../license_header_injector.py | 2 +- plugins/markdown_cleaner/markdown_cleaner.py | 5 +- .../output_length_guard.py | 2 +- plugins/pii_filter/pii_filter.py | 10 +- plugins/pii_filter/pii_filter_rust.py | 33 ++-- .../privacy_notice_injector.py | 2 +- plugins/rate_limiter/rate_limiter.py | 2 +- plugins/regex_filter/search_replace.py | 2 +- plugins/resource_filter/resource_filter.py | 2 +- .../response_cache_by_prompt.py | 2 +- .../retry_with_backoff/retry_with_backoff.py | 2 +- .../robots_license_guard.py | 2 +- .../safe_html_sanitizer.py | 2 +- plugins/schema_guard/schema_guard.py | 2 +- .../secrets_detection/secrets_detection.py | 2 +- plugins/sql_sanitizer/sql_sanitizer.py | 2 +- plugins/summarizer/summarizer.py | 2 +- .../timezone_translator.py | 2 +- plugins/url_reputation/url_reputation.py | 2 +- plugins/vault/vault_plugin.py | 4 +- .../virus_total_checker.py | 10 +- plugins/watchdog/watchdog.py | 2 +- .../webhook_notification.py | 119 +++++++++++++-- .../fixtures/configs/agent_passthrough.yaml | 2 +- 56 files changed, 507 insertions(+), 158 deletions(-) diff --git a/mcp-servers/templates/go/copier.yaml b/mcp-servers/templates/go/copier.yaml index e6d615ea5..61d2cb223 100644 --- a/mcp-servers/templates/go/copier.yaml +++ b/mcp-servers/templates/go/copier.yaml @@ -45,4 +45,3 @@ include_container: type: bool help: Include Dockerfile for a minimal runtime image default: true - diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 6a4a5a7d8..96f98d50a 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -1179,7 +1179,7 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r global_context.metadata[TOOL_METADATA] = tool_metadata pre_result, context_table = await self._plugin_manager.invoke_hook( ToolHookType.TOOL_PRE_INVOKE, - payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(headers)), + payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)), global_context=global_context, local_contexts=None, violations_as_exceptions=True, @@ -1335,7 +1335,7 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head global_context.metadata[GATEWAY_METADATA] = gateway_metadata pre_result, context_table = await self._plugin_manager.invoke_hook( ToolHookType.TOOL_PRE_INVOKE, - payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(headers)), + payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)), global_context=global_context, local_contexts=None, violations_as_exceptions=True, diff --git a/plugins/ai_artifacts_normalizer/ai_artifacts_normalizer.py b/plugins/ai_artifacts_normalizer/ai_artifacts_normalizer.py index 923bb1ce0..215e0e4b6 100644 --- a/plugins/ai_artifacts_normalizer/ai_artifacts_normalizer.py +++ b/plugins/ai_artifacts_normalizer/ai_artifacts_normalizer.py @@ -19,9 +19,9 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, - Plugin, PromptPrehookPayload, PromptPrehookResult, ResourcePostFetchPayload, diff --git a/plugins/altk_json_processor/json_processor.py b/plugins/altk_json_processor/json_processor.py index df26cedd8..4d1cb25fa 100644 --- a/plugins/altk_json_processor/json_processor.py +++ b/plugins/altk_json_processor/json_processor.py @@ -23,9 +23,9 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, - Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ) diff --git a/plugins/argument_normalizer/argument_normalizer.py b/plugins/argument_normalizer/argument_normalizer.py index 8e847a7a4..e06eb5df8 100644 --- a/plugins/argument_normalizer/argument_normalizer.py +++ b/plugins/argument_normalizer/argument_normalizer.py @@ -27,9 +27,9 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, - Plugin, PromptPrehookPayload, PromptPrehookResult, ToolPreInvokePayload, @@ -145,7 +145,15 @@ class EffectiveCfg: def _merge_overrides(base: ArgumentNormalizerConfig, path: str) -> EffectiveCfg: - """Compute an effective configuration for a given field path.""" + """Compute an effective configuration for a given field path. + + Args: + base: Base configuration to start from. + path: Field path to compute configuration for. + + Returns: + Effective configuration for the given field path. + """ cfg = base # Start with base values eff = EffectiveCfg( @@ -444,6 +452,13 @@ def repl(m: re.Match[str]) -> str: def _normalize_text(text: str, eff: EffectiveCfg) -> str: """Normalize a text value using an effective configuration. + Args: + text: Text value to normalize. + eff: Effective configuration to use for normalization. + + Returns: + Normalized text value. + Examples: Normalize unicode and whitespace: diff --git a/plugins/cached_tool_result/cached_tool_result.py b/plugins/cached_tool_result/cached_tool_result.py index cce7558b4..d4f3961d0 100644 --- a/plugins/cached_tool_result/cached_tool_result.py +++ b/plugins/cached_tool_result/cached_tool_result.py @@ -25,9 +25,9 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, - Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokePayload, diff --git a/plugins/circuit_breaker/circuit_breaker.py b/plugins/circuit_breaker/circuit_breaker.py index f9e5de429..57d748d41 100644 --- a/plugins/circuit_breaker/circuit_breaker.py +++ b/plugins/circuit_breaker/circuit_breaker.py @@ -26,10 +26,10 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginViolation, - Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokePayload, diff --git a/plugins/citation_validator/citation_validator.py b/plugins/citation_validator/citation_validator.py index 44fdd4e80..fc7d71f0f 100644 --- a/plugins/citation_validator/citation_validator.py +++ b/plugins/citation_validator/citation_validator.py @@ -24,10 +24,10 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginViolation, - Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ToolPostInvokePayload, diff --git a/plugins/code_formatter/code_formatter.py b/plugins/code_formatter/code_formatter.py index 47d3c2d09..fe2d51048 100644 --- a/plugins/code_formatter/code_formatter.py +++ b/plugins/code_formatter/code_formatter.py @@ -28,9 +28,9 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, - Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ToolPostInvokePayload, diff --git a/plugins/code_safety_linter/code_safety_linter.py b/plugins/code_safety_linter/code_safety_linter.py index c4c17768e..7c5d80032 100644 --- a/plugins/code_safety_linter/code_safety_linter.py +++ b/plugins/code_safety_linter/code_safety_linter.py @@ -21,10 +21,10 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginViolation, - Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ) diff --git a/plugins/config.yaml b/plugins/config.yaml index bc67a5d6f..7c821daf6 100644 --- a/plugins/config.yaml +++ b/plugins/config.yaml @@ -179,11 +179,11 @@ plugins: priority: 119 conditions: [] config: - allowed_tags: ["a","p","div","span","strong","em","code","pre","ul","ol","li","h1","h2","h3","h4","h5","h6","blockquote","img","br","hr","table","thead","tbody","tr","th","td"] + allowed_tags: ["a", "p", "div", "span", "strong", "em", "code", "pre", "ul", "ol", "li", "h1", "h2", "h3", "h4", "h5", "h6", "blockquote", "img", "br", "hr", "table", "thead", "tbody", "tr", "th", "td"] allowed_attrs: - "*": ["id","class","title","alt"] - a: ["href","rel","target"] - img: ["src","width","height","alt","title"] + "*": ["id", "class", "title", "alt"] + a: ["href", "rel", "target"] + img: ["src", "width", "height", "alt", "title"] remove_comments: true drop_unknown_tags: true strip_event_handlers: true diff --git a/plugins/content_moderation/content_moderation.py b/plugins/content_moderation/content_moderation.py index 907ebde91..877654aee 100644 --- a/plugins/content_moderation/content_moderation.py +++ b/plugins/content_moderation/content_moderation.py @@ -24,10 +24,10 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginViolation, - Plugin, PromptPrehookPayload, PromptPrehookResult, ToolPostInvokePayload, @@ -189,7 +189,15 @@ def __init__(self, config: PluginConfig) -> None: self._cache: Dict[str, ModerationResult] = {} if self._cfg.enable_caching else None async def _get_cache_key(self, text: str, provider: ModerationProvider) -> str: - """Generate cache key for content.""" + """Generate cache key for content. + + Args: + text: Content text to generate key for. + provider: Moderation provider being used. + + Returns: + Cache key string. + """ # Standard import hashlib @@ -197,7 +205,15 @@ async def _get_cache_key(self, text: str, provider: ModerationProvider) -> str: return f"{provider.value}:{content_hash}" async def _get_cached_result(self, text: str, provider: ModerationProvider) -> Optional[ModerationResult]: - """Get cached moderation result.""" + """Get cached moderation result. + + Args: + text: Content text to check cache for. + provider: Moderation provider being used. + + Returns: + Cached moderation result if available, None otherwise. + """ if not self._cfg.enable_caching or not self._cache: return None @@ -205,7 +221,13 @@ async def _get_cached_result(self, text: str, provider: ModerationProvider) -> O return self._cache.get(cache_key) async def _cache_result(self, text: str, provider: ModerationProvider, result: ModerationResult) -> None: - """Cache moderation result.""" + """Cache moderation result. + + Args: + text: Content text being cached. + provider: Moderation provider being used. + result: Moderation result to cache. + """ if not self._cfg.enable_caching or not self._cache: return @@ -213,7 +235,18 @@ async def _cache_result(self, text: str, provider: ModerationProvider, result: M self._cache[cache_key] = result async def _moderate_with_ibm_watson(self, text: str) -> ModerationResult: - """Moderate content using IBM Watson Natural Language Understanding.""" + """Moderate content using IBM Watson Natural Language Understanding. + + Args: + text: Content text to moderate. + + Returns: + Moderation result from IBM Watson. + + Raises: + ValueError: If IBM Watson configuration not provided. + Exception: If API call fails. + """ if not self._cfg.ibm_watson: raise ValueError("IBM Watson configuration not provided") @@ -284,7 +317,18 @@ async def _moderate_with_ibm_watson(self, text: str) -> ModerationResult: raise async def _moderate_with_ibm_granite(self, text: str) -> ModerationResult: - """Moderate content using IBM Granite Guardian via Ollama.""" + """Moderate content using IBM Granite Guardian via Ollama. + + Args: + text: Content text to moderate. + + Returns: + Moderation result from IBM Granite. + + Raises: + ValueError: If IBM Granite configuration not provided. + Exception: If API call fails. + """ if not self._cfg.ibm_granite: raise ValueError("IBM Granite configuration not provided") @@ -351,7 +395,18 @@ async def _moderate_with_ibm_granite(self, text: str) -> ModerationResult: raise async def _moderate_with_openai(self, text: str) -> ModerationResult: - """Moderate content using OpenAI Moderation API.""" + """Moderate content using OpenAI Moderation API. + + Args: + text: Content text to moderate. + + Returns: + Moderation result from OpenAI. + + Raises: + ValueError: If OpenAI configuration not provided. + Exception: If API call fails. + """ if not self._cfg.openai: raise ValueError("OpenAI configuration not provided") @@ -413,7 +468,15 @@ async def _moderate_with_openai(self, text: str) -> ModerationResult: raise async def _apply_moderation_action(self, text: str, result: ModerationResult) -> str: - """Apply the moderation action to the text.""" + """Apply the moderation action to the text. + + Args: + text: Original content text. + result: Moderation result with action to apply. + + Returns: + Modified text based on moderation action. + """ if result.action == ModerationAction.BLOCK: return "" # Empty content elif result.action == ModerationAction.REDACT: @@ -432,7 +495,14 @@ async def _apply_moderation_action(self, text: str, result: ModerationResult) -> return text # Return original text async def _moderate_content(self, text: str) -> ModerationResult: - """Moderate content using the configured provider.""" + """Moderate content using the configured provider. + + Args: + text: Content text to moderate. + + Returns: + Moderation result from the configured provider. + """ if len(text) > self._cfg.max_text_length: text = text[: self._cfg.max_text_length] @@ -482,7 +552,14 @@ async def _moderate_content(self, text: str) -> ModerationResult: return result async def _moderate_with_patterns(self, text: str) -> ModerationResult: - """Fallback moderation using regex patterns.""" + """Fallback moderation using regex patterns. + + Args: + text: Content text to moderate. + + Returns: + Moderation result based on pattern matching. + """ categories = {} # Basic pattern matching for different categories @@ -532,7 +609,14 @@ async def _moderate_with_patterns(self, text: str) -> ModerationResult: ) async def _extract_text_content(self, payload: Any) -> List[str]: - """Extract text content from various payload types.""" + """Extract text content from various payload types. + + Args: + payload: Payload to extract text from. + + Returns: + List of extracted text strings. + """ texts = [] if hasattr(payload, "args") and payload.args: @@ -551,7 +635,15 @@ async def _extract_text_content(self, payload: Any) -> List[str]: return [text for text in texts if len(text.strip()) > 3] # Filter very short texts async def prompt_pre_fetch(self, payload: PromptPrehookPayload, _context: PluginContext) -> PromptPrehookResult: - """Moderate prompt content before fetching.""" + """Moderate prompt content before fetching. + + Args: + payload: Prompt payload to moderate. + _context: Plugin context (unused). + + Returns: + Result indicating whether to continue processing. + """ texts = await self._extract_text_content(payload) for text in texts: @@ -595,7 +687,15 @@ async def prompt_pre_fetch(self, payload: PromptPrehookPayload, _context: Plugin return PromptPrehookResult() async def tool_pre_invoke(self, payload: ToolPreInvokePayload, _context: PluginContext) -> ToolPreInvokeResult: - """Moderate tool arguments before invocation.""" + """Moderate tool arguments before invocation. + + Args: + payload: Tool invocation payload to moderate. + _context: Plugin context (unused). + + Returns: + Result indicating whether to continue processing. + """ texts = await self._extract_text_content(payload) for text in texts: @@ -634,7 +734,15 @@ async def tool_pre_invoke(self, payload: ToolPreInvokePayload, _context: PluginC return ToolPreInvokeResult(metadata={"moderation_checked": True}) async def tool_post_invoke(self, payload: ToolPostInvokePayload, _context: PluginContext) -> ToolPostInvokeResult: - """Moderate tool output after invocation.""" + """Moderate tool output after invocation. + + Args: + payload: Tool result payload to moderate. + _context: Plugin context (unused). + + Returns: + Result indicating whether to continue processing. + """ # Extract text from tool results result_text = "" if hasattr(payload.result, "content"): @@ -681,7 +789,11 @@ async def tool_post_invoke(self, payload: ToolPostInvokePayload, _context: Plugi return ToolPostInvokeResult(metadata={"output_checked": True}) async def __aenter__(self): - """Async context manager entry.""" + """Async context manager entry. + + Returns: + ContentModerationPlugin: The plugin instance. + """ return self async def __aexit__(self, _exc_type, _exc_val, _exc_tb): diff --git a/plugins/deny_filter/deny.py b/plugins/deny_filter/deny.py index 1b9b1e9b4..7cf7e3790 100644 --- a/plugins/deny_filter/deny.py +++ b/plugins/deny_filter/deny.py @@ -12,14 +12,7 @@ from pydantic import BaseModel # First-Party -from mcpgateway.plugins.framework import ( - PluginConfig, - PluginContext, - PluginViolation, - Plugin, - PromptPrehookPayload, - PromptPrehookResult -) +from mcpgateway.plugins.framework import Plugin, PluginConfig, PluginContext, PluginViolation, PromptPrehookPayload, PromptPrehookResult from mcpgateway.services.logging_service import LoggingService # Initialize logging service first diff --git a/plugins/external/clamav_server/clamav_plugin.py b/plugins/external/clamav_server/clamav_plugin.py index ba11e3467..7fdce282f 100644 --- a/plugins/external/clamav_server/clamav_plugin.py +++ b/plugins/external/clamav_server/clamav_plugin.py @@ -31,10 +31,10 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginViolation, - Plugin, PromptPosthookPayload, PromptPosthookResult, ResourcePostFetchPayload, @@ -52,8 +52,14 @@ def _has_eicar(data: bytes) -> bool: - """Has Eicar implementation.""" + """Check if data contains EICAR test virus signature. + + Args: + data: Bytes to scan for EICAR signature. + Returns: + True if EICAR signature found, False otherwise. + """ blob = data.decode("latin1", errors="ignore") return any(sig in blob for sig in EICAR_SIGNATURES) @@ -62,8 +68,11 @@ class ClamAVConfig: """ClamAVConfig implementation.""" def __init__(self, cfg: dict[str, Any] | None) -> None: - """Initialize the instance.""" + """Initialize the instance. + Args: + cfg: Configuration dictionary. + """ c = cfg or {} self.mode: str = c.get("mode", "eicar_only") # eicar_only|clamd_tcp|clamd_unix self.host: str | None = c.get("clamd_host") @@ -75,8 +84,17 @@ def __init__(self, cfg: dict[str, Any] | None) -> None: def _clamd_instream_scan_tcp(host: str, port: int, data: bytes, timeout: float) -> str: - """Clamd Instream Scan Tcp implementation.""" + """Scan data using ClamAV daemon via TCP connection. + Args: + host: ClamAV daemon host address. + port: ClamAV daemon port number. + data: Bytes to scan. + timeout: Connection timeout in seconds. + + Returns: + Scan response from ClamAV daemon. + """ # Minimal INSTREAM protocol: https://linux.die.net/man/8/clamd s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.settimeout(timeout) @@ -99,8 +117,16 @@ def _clamd_instream_scan_tcp(host: str, port: int, data: bytes, timeout: float) def _clamd_instream_scan_unix(path: str, data: bytes, timeout: float) -> str: - """Clamd Instream Scan Unix implementation.""" + """Scan data using ClamAV daemon via Unix socket connection. + + Args: + path: Unix socket path. + data: Bytes to scan. + timeout: Connection timeout in seconds. + Returns: + Scan response from ClamAV daemon. + """ s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) s.settimeout(timeout) s.connect(path) @@ -123,23 +149,35 @@ class ClamAVRemotePlugin(Plugin): """External ClamAV plugin for scanning resources and content.""" def __init__(self, config: PluginConfig) -> None: - """Initialize the instance.""" + """Initialize the instance. + Args: + config: Plugin configuration. + """ super().__init__(config) self._cfg = ClamAVConfig(config.config) self._stats: dict[str, int] = {"attempted": 0, "infected": 0, "blocked": 0, "errors": 0} def _bump(self, key: str) -> None: - """Bump implementation.""" + """Increment statistics counter. + Args: + key: Statistics key to increment. + """ try: self._stats[key] = int(self._stats.get(key, 0)) + 1 except Exception: pass def _scan_bytes(self, data: bytes) -> tuple[bool, str]: - """Scan Bytes implementation.""" + """Scan bytes for malware using configured scan method. + + Args: + data: Bytes to scan for malware. + Returns: + Tuple of (infected: bool, detail: str) indicating if malware was found and scan details. + """ if len(data) > self._cfg.max_bytes: return False, "SKIPPED: too large" @@ -284,8 +322,14 @@ async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: Plugin # Recursively scan string values in tool outputs def iter_strings(obj): - """Iter Strings implementation.""" + """Recursively iterate over all string values in an object. + Args: + obj: Object to iterate over (str, dict, list, or other). + + Yields: + String values found in the object. + """ if isinstance(obj, str): yield obj elif isinstance(obj, dict): @@ -320,7 +364,11 @@ def iter_strings(obj): return ToolPostInvokeResult(metadata={"clamav": {"error": str(exc)}}) def health(self) -> dict[str, Any]: - """Return plugin health and metrics; try clamd connectivity when configured.""" + """Return plugin health and metrics; try clamd connectivity when configured. + + Returns: + Dictionary containing plugin health status and metrics. + """ status = {"mode": self._cfg.mode, "block_on_positive": self._cfg.block_on_positive, "stats": dict(self._stats)} reachable = None try: diff --git a/plugins/external/llmguard/docker-compose.yaml b/plugins/external/llmguard/docker-compose.yaml index 9cd8afbd3..1e0399f05 100644 --- a/plugins/external/llmguard/docker-compose.yaml +++ b/plugins/external/llmguard/docker-compose.yaml @@ -17,7 +17,7 @@ services: llmguardplugin: container_name: llmguardplugin - image: mcpgateway/llmguardplugin:latest # Use the local latest image. Run `make docker-prod` to build it. + image: mcpgateway/llmguardplugin:latest # Use the local latest image. Run `make docker-prod` to build it. restart: always env_file: - .env @@ -33,7 +33,7 @@ services: llmguardplugin-testing: container_name: llmguardplugin-testing - image: mcpgateway/llmguardplugin-testing:latest # Use the local latest image. Run `make docker-prod` to build it. + image: mcpgateway/llmguardplugin-testing:latest # Use the local latest image. Run `make docker-prod` to build it. env_file: - .env ports: diff --git a/plugins/external/llmguard/examples/config-all-in-one.yaml b/plugins/external/llmguard/examples/config-all-in-one.yaml index 62679b563..d5ee7666a 100644 --- a/plugins/external/llmguard/examples/config-all-in-one.yaml +++ b/plugins/external/llmguard/examples/config-all-in-one.yaml @@ -5,7 +5,7 @@ plugins: description: "A plugin for running input and output through llmguard scanners " version: "0.1" author: "ContextForge" - hooks: ["prompt_pre_fetch","prompt_post_fetch"] + hooks: ["prompt_pre_fetch", "prompt_post_fetch"] tags: ["plugin", "guardrails", "llmguard", "pre-post", "filters", "sanitizers"] mode: "enforce" # enforce | permissive | disabled priority: 20 @@ -18,11 +18,11 @@ plugins: cache_ttl: 120 #defined in seconds input: filters: - PromptInjection: - threshold: 0.6 - use_onnx: false - policy: PromptInjection - policy_message: I'm sorry, I cannot allow this input. + PromptInjection: + threshold: 0.6 + use_onnx: false + policy: PromptInjection + policy_message: I'm sorry, I cannot allow this input. sanitizers: Anonymize: language: "en" @@ -34,7 +34,7 @@ plugins: matching_strategy: exact filters: Toxicity: - threshold: 0.5 + threshold: 0.5 policy: Toxicity policy_message: I'm sorry, I cannot allow this output. diff --git a/plugins/external/llmguard/examples/config-complex-policy.yaml b/plugins/external/llmguard/examples/config-complex-policy.yaml index 199588ec9..dad5720d1 100644 --- a/plugins/external/llmguard/examples/config-complex-policy.yaml +++ b/plugins/external/llmguard/examples/config-complex-policy.yaml @@ -61,7 +61,7 @@ plugins: output: filters: Toxicity: - threshold: 0.5 + threshold: 0.5 policy: Toxicity policy_message: I'm sorry, I cannot allow this output. diff --git a/plugins/external/llmguard/examples/config-input-output-filter.yaml b/plugins/external/llmguard/examples/config-input-output-filter.yaml index 1d5272e2f..153d843bf 100644 --- a/plugins/external/llmguard/examples/config-input-output-filter.yaml +++ b/plugins/external/llmguard/examples/config-input-output-filter.yaml @@ -42,7 +42,7 @@ plugins: output: filters: Toxicity: - threshold: 0.5 + threshold: 0.5 policy: Toxicity policy_message: I'm sorry, I cannot allow this output. diff --git a/plugins/external/llmguard/examples/config-separate-plugins-filters-sanitizers.yaml b/plugins/external/llmguard/examples/config-separate-plugins-filters-sanitizers.yaml index 1ce1222fd..785a34a08 100644 --- a/plugins/external/llmguard/examples/config-separate-plugins-filters-sanitizers.yaml +++ b/plugins/external/llmguard/examples/config-separate-plugins-filters-sanitizers.yaml @@ -87,7 +87,7 @@ plugins: output: filters: Toxicity: - threshold: 0.5 + threshold: 0.5 policy: Toxicity policy_message: I'm sorry, I cannot allow this output. diff --git a/plugins/external/llmguard/llmguardplugin/cache.py b/plugins/external/llmguard/llmguardplugin/cache.py index 454b6ee65..68dfa9ff4 100644 --- a/plugins/external/llmguard/llmguardplugin/cache.py +++ b/plugins/external/llmguard/llmguardplugin/cache.py @@ -40,7 +40,7 @@ def __init__(self, ttl: int = 0) -> None: """init block for cache. This initializes a redit client. Args: - ttl: Time to live in seconds for cache + ttl: Time to live in seconds for cache """ self.cache_ttl = ttl self.cache = redis.Redis(host=redis_host, port=redis_port) @@ -53,6 +53,9 @@ def update_cache(self, key: int = None, value: tuple = None) -> tuple[bool]: Args: key: The id of vault in string value: The tuples in the vault + + Returns: + tuple[bool]: A tuple containing (success_set, success_expiry) booleans. """ serialized_obj = pickle.dumps(value) logger.info(f"Update cache in cache: {key} {serialized_obj}") @@ -73,10 +76,9 @@ def retrieve_cache(self, key: int = None) -> tuple: Args: key: The id of vault in string - value: The tuples in the vault Returns: - retrieved_obj: Return the retrieved object from cache + tuple: The retrieved object from cache or None if not found. """ value = self.cache.get(key) if value: @@ -87,14 +89,10 @@ def retrieve_cache(self, key: int = None) -> tuple: logger.error(f"Cache retrieval unsuccessful for id: {key}") def delete_cache(self, key: int = None) -> None: - """Retrieves cache for a key value + """Deletes cache for a key value Args: key: The id of vault in string - value: The tuples in the vault - - Returns: - retrieved_obj: Return the retrieved object from cache """ logger.info(f"Deleting cache for key : {key}") deleted_count = self.cache.delete(key) diff --git a/plugins/external/llmguard/llmguardplugin/llmguard.py b/plugins/external/llmguard/llmguardplugin/llmguard.py index 9612b3abe..8219ff199 100644 --- a/plugins/external/llmguard/llmguardplugin/llmguard.py +++ b/plugins/external/llmguard/llmguardplugin/llmguard.py @@ -36,7 +36,11 @@ class LLMGuardBase: """ def __init__(self, config: Optional[dict[str, Any]]) -> None: - """Initialize the instance.""" + """Initialize the instance. + + Args: + config: Configuration for guardrails. + """ self.lgconfig = LLMGuardConfig.model_validate(config) self.scanners = {"input": {"sanitizers": [], "filters": []}, "output": {"sanitizers": [], "filters": []}} @@ -65,7 +69,11 @@ def _create_new_vault_on_expiry(self, vault) -> bool: return False def _create_vault(self) -> Vault: - """This function creates a new vault and sets it's creation time as it's attribute""" + """This function creates a new vault and sets it's creation time as it's attribute + + Returns: + Vault: A new vault object with creation time set. + """ logger.info("Vault creation") vault = Vault() vault.creation_time = datetime.now() @@ -76,7 +84,11 @@ def _retreive_vault(self, sanitizer_names: list = ["Anonymize"]) -> tuple[Vault, """This function is responsible for retrieving vault for given sanitizer names Args: - sanitizer_names: list of names for sanitizers""" + sanitizer_names: list of names for sanitizers + + Returns: + tuple[Vault, int, tuple]: A tuple containing the vault object, vault ID, and vault tuples. + """ vault_id = None vault_tuples = None length = len(self.scanners["input"]["sanitizers"]) @@ -112,7 +124,9 @@ def _update_output_sanitizers(self, config, sanitizer_names: list = ["Deanonymiz """This function is responsible for updating vault for given sanitizer names in output Args: - sanitizer_names: list of names for sanitizers""" + config: Configuration containing sanitizer settings. + sanitizer_names: list of names for sanitizers + """ length = len(self.scanners["output"]["sanitizers"]) for i in range(length): scanner_name = type(self.scanners["output"]["sanitizers"][i]).__name__ @@ -131,7 +145,7 @@ def _load_policy_scanners(self, config: dict = None) -> list: config: configuration for scanner Returns: - policy_filters: Either None or a list of scanners defined in the policy + list: Either None or a list of scanners defined in the policy. """ config_keys = get_policy_filters(config) if "policy" in config: @@ -259,9 +273,10 @@ def _apply_output_filters(self, original_input, model_response) -> dict[str, dic Args: original_input: The original input prompt for which model produced a response + model_response: The model's response to apply filters on Returns: - result: A dictionary with key as scanner_name which is the name of the scanner applied to the output and value as a dictionary with keys "sanitized_prompt" which is the actual prompt, + dict[str, dict[str, Any]]: A dictionary with key as scanner_name which is the name of the scanner applied to the output and value as a dictionary with keys "sanitized_prompt" which is the actual prompt, "is_valid" which is boolean that says if the prompt is valid or not based on a scanner applied and "risk_score" which gives the risk score assigned by the scanner to the prompt. """ result = {} @@ -280,10 +295,11 @@ def _apply_output_sanitizers(self, input_prompt, model_response) -> dict[str, di """Takes in model_response and applies sanitizers on it Args: - original_input: The original input prompt for which model produced a response + input_prompt: The original input prompt for which model produced a response + model_response: The model's response to apply sanitizers on Returns: - result: A dictionary with key as scanner_name which is the name of the scanner applied to the output and value as a dictionary with keys "sanitized_prompt" which is the actual prompt, + dict[str, dict[str, Any]]: A dictionary with key as scanner_name which is the name of the scanner applied to the output and value as a dictionary with keys "sanitized_prompt" which is the actual prompt, "is_valid" which is boolean that says if the prompt is valid or not based on a scanner applied and "risk_score" which gives the risk score assigned by the scanner to the prompt. """ result = scan_output(self.scanners["output"]["sanitizers"], input_prompt, model_response) diff --git a/plugins/external/llmguard/llmguardplugin/plugin.py b/plugins/external/llmguard/llmguardplugin/plugin.py index afaa5a484..4e52fd90a 100644 --- a/plugins/external/llmguard/llmguardplugin/plugin.py +++ b/plugins/external/llmguard/llmguardplugin/plugin.py @@ -15,12 +15,12 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginError, PluginErrorModel, PluginViolation, - Plugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -50,7 +50,10 @@ def __init__(self, config: PluginConfig) -> None: """Entry init block for plugin. Validates the configuration of plugin and initializes an instance of LLMGuardBase with the config Args: - config: the skill configuration + config: the skill configuration + + Raises: + PluginError: If the configuration is invalid for plugin initialization. """ super().__init__(config) self.lgconfig = LLMGuardConfig.model_validate(self._config.config) @@ -62,14 +65,28 @@ def __init__(self, config: PluginConfig) -> None: raise PluginError(error=PluginErrorModel(message="Invalid configuration for plugin initilialization", plugin_name=self.name)) def __verify_lgconfig(self): - """Checks if the configuration provided for plugin is valid or not. It should either have input or output key atleast""" + """Checks if the configuration provided for plugin is valid or not. It should either have input or output key atleast + + Returns: + bool: True if configuration is valid (has input or output), False otherwise. + """ return self.lgconfig.input or self.lgconfig.output def __update_context(self, context, key, value) -> dict: - """Update Context implementation.""" + """Update Context implementation. + + Args: + context: The plugin context to update. + key: The key to set in context. + value: The value to set for the key. + """ def update_context(context): - """Update Context implementation.""" + """Update Context implementation. + + Args: + context: The plugin context to update. + """ plugin_name = self.__class__.__name__ if plugin_name not in context.state[self.guardrails_context_key]: diff --git a/plugins/external/llmguard/llmguardplugin/policy.py b/plugins/external/llmguard/llmguardplugin/policy.py index db0c1fdbe..047dbee72 100644 --- a/plugins/external/llmguard/llmguardplugin/policy.py +++ b/plugins/external/llmguard/llmguardplugin/policy.py @@ -34,7 +34,10 @@ def evaluate(self, policy: str, scan_result: dict) -> Union[bool, str]: scan_result: The result of scanners applied Returns: - A union of bool (if true or false). However, if the policy expression is invalid returns string with invalid expression + Union[bool, str]: A union of bool (if true or false). However, if the policy expression is invalid returns string with invalid expression + + Raises: + ValueError: If the policy expression contains invalid operations. """ policy_variables = {key: value["is_valid"] for key, value in scan_result.items()} try: @@ -97,10 +100,9 @@ def get_policy_filters(policy_expression) -> Union[list, None]: Args: policy_expression: The expression of policy - sentence2: The second sentence Returns: - None if no policy expression is defined, else a comma separated list of filters defined in the policy + Union[list, None]: None if no policy expression is defined, else a comma separated list of filters defined in the policy """ if isinstance(policy_expression, str): pattern = r"\b(and|or|not)\b|[()]" diff --git a/plugins/external/llmguard/resources/plugins/config.yaml b/plugins/external/llmguard/resources/plugins/config.yaml index bbb7b4d64..d583c9cf7 100644 --- a/plugins/external/llmguard/resources/plugins/config.yaml +++ b/plugins/external/llmguard/resources/plugins/config.yaml @@ -5,7 +5,7 @@ plugins: description: "A plugin for running input through llmguard scanners " version: "0.1.0" author: "Shriti Priya" - hooks: ["prompt_pre_fetch","prompt_post_fetch"] + hooks: ["prompt_pre_fetch", "prompt_post_fetch"] tags: ["plugin", "guardrails", "llmguard", "pre-post"] mode: "enforce" # enforce | permissive | disabled priority: 10 diff --git a/plugins/external/opa/opapluginfilter/plugin.py b/plugins/external/opa/opapluginfilter/plugin.py index 4557d865a..408153bed 100644 --- a/plugins/external/opa/opapluginfilter/plugin.py +++ b/plugins/external/opa/opapluginfilter/plugin.py @@ -19,10 +19,10 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginViolation, - Plugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -70,8 +70,7 @@ def __init__(self, config: PluginConfig): """Entry init block for plugin. Args: - logger: logger that the skill can make use of - config: the skill configuration + config: the skill configuration """ super().__init__(config) self.opa_config = OPAConfig.model_validate(self._config.config) @@ -105,16 +104,25 @@ def _evaluate_opa_policy(self, url: str, input: OPAInput, policy_input_data_map: Args: url: The url to call opa server input: Contains the payload of input to be sent to opa server for policy evaluation. + policy_input_data_map: Mapping of policy input data keys. Returns: - True, json_response if the opa policy is allowed else false. The json response is the actual response returned by OPA server. + tuple[bool, Any]: True, json_response if the opa policy is allowed else false. The json response is the actual response returned by OPA server. If OPA server encountered any error, the return would be True (to gracefully exit) and None would be the json_response, marking an issue with the OPA server running. """ def _key(k: str, m: str) -> str: - """Key implementation.""" + """Key implementation. + + Args: + k: The key string. + m: The mapping string. + + Returns: + str: Combined key string. + """ return f"{k}.{m}" if k.split(".")[0] == "context" else k @@ -222,10 +230,6 @@ def _extract_payload_key(self, content: Any = None, key: str = None, result: dic content: The content of post hook results. key: The key for which value needs to be extracted for. result: A list of all the values for a key. - - Returns: - None - """ if isinstance(content, list): for element in content: diff --git a/plugins/external/opa/resources/plugins/config.yaml b/plugins/external/opa/resources/plugins/config.yaml index a4499b13c..c033e8498 100644 --- a/plugins/external/opa/resources/plugins/config.yaml +++ b/plugins/external/opa/resources/plugins/config.yaml @@ -4,7 +4,7 @@ plugins: description: "An OPA plugin that enforces rego policies on requests and allows/denies requests as per policies" version: "0.1.0" author: "Shriti Priya" - hooks: ["tool_pre_invoke","tool_post_invoke", "prompt_pre_fetch", "prompt_post_fetch", "resource_pre_fetch", "resource_post_fetch"] + hooks: ["tool_pre_invoke", "tool_post_invoke", "prompt_pre_fetch", "prompt_post_fetch", "resource_pre_fetch", "resource_post_fetch"] tags: ["plugin"] mode: "permissive" # enforce | permissive | disabled priority: 30 diff --git a/plugins/external/opa/tests/test_all.py b/plugins/external/opa/tests/test_all.py index 3e2d872bd..b6ec72500 100644 --- a/plugins/external/opa/tests/test_all.py +++ b/plugins/external/opa/tests/test_all.py @@ -24,7 +24,11 @@ @pytest.fixture(scope="module", autouse=True) def plugin_manager(): - """Initialize plugin manager.""" + """Initialize plugin manager. + + Yields: + PluginManager: An initialized plugin manager instance. + """ plugin_manager = PluginManager("./resources/plugins/config.yaml") asyncio.run(plugin_manager.initialize()) yield plugin_manager @@ -33,7 +37,11 @@ def plugin_manager(): @pytest.mark.asyncio async def test_prompt_pre_hook(plugin_manager: PluginManager): - """Test prompt pre hook across all registered plugins.""" + """Test prompt pre hook across all registered plugins. + + Args: + plugin_manager: The plugin manager instance. + """ # Customize payload for testing payload = PromptPrehookPayload(name="test_prompt", args={"arg0": "This is an argument"}) global_context = GlobalContext(request_id="1") @@ -44,7 +52,11 @@ async def test_prompt_pre_hook(plugin_manager: PluginManager): @pytest.mark.asyncio async def test_prompt_post_hook(plugin_manager: PluginManager): - """Test prompt post hook across all registered plugins.""" + """Test prompt post hook across all registered plugins. + + Args: + plugin_manager: The plugin manager instance. + """ # Customize payload for testing message = Message(content=TextContent(type="text", text="prompt"), role=Role.USER) prompt_result = PromptResult(messages=[message]) @@ -57,7 +69,11 @@ async def test_prompt_post_hook(plugin_manager: PluginManager): @pytest.mark.asyncio async def test_tool_pre_hook(plugin_manager: PluginManager): - """Test tool pre hook across all registered plugins.""" + """Test tool pre hook across all registered plugins. + + Args: + plugin_manager: The plugin manager instance. + """ # Customize payload for testing payload = ToolPreInvokePayload(name="test_prompt", args={"arg0": "This is an argument"}) global_context = GlobalContext(request_id="1") @@ -68,7 +84,11 @@ async def test_tool_pre_hook(plugin_manager: PluginManager): @pytest.mark.asyncio async def test_tool_post_hook(plugin_manager: PluginManager): - """Test tool post hook across all registered plugins.""" + """Test tool post hook across all registered plugins. + + Args: + plugin_manager: The plugin manager instance. + """ # Customize payload for testing payload = ToolPostInvokePayload(name="test_tool", result={"output0": "output value"}) global_context = GlobalContext(request_id="1") @@ -79,7 +99,11 @@ async def test_tool_post_hook(plugin_manager: PluginManager): @pytest.mark.asyncio async def test_resource_pre_hook(plugin_manager: PluginManager): - """Test tool post hook across all registered plugins.""" + """Test tool post hook across all registered plugins. + + Args: + plugin_manager: The plugin manager instance. + """ # Customize payload for testing payload = ResourcePreFetchPayload(uri="https://test_resource.com", metadata={}) global_context = GlobalContext(request_id="1", server_id="2") @@ -90,7 +114,11 @@ async def test_resource_pre_hook(plugin_manager: PluginManager): @pytest.mark.asyncio async def test_resource_post_hook(plugin_manager: PluginManager): - """Test tool post hook across all registered plugins.""" + """Test tool post hook across all registered plugins. + + Args: + plugin_manager: The plugin manager instance. + """ # Customize payload for testing content = ResourceContent( type="resource", diff --git a/plugins/file_type_allowlist/file_type_allowlist.py b/plugins/file_type_allowlist/file_type_allowlist.py index 9b2b62ab4..aa7b20143 100644 --- a/plugins/file_type_allowlist/file_type_allowlist.py +++ b/plugins/file_type_allowlist/file_type_allowlist.py @@ -22,10 +22,10 @@ # First-Party from mcpgateway.common.models import ResourceContent from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginViolation, - Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchPayload, diff --git a/plugins/harmful_content_detector/harmful_content_detector.py b/plugins/harmful_content_detector/harmful_content_detector.py index 3f9d0a48e..7468cb0d1 100644 --- a/plugins/harmful_content_detector/harmful_content_detector.py +++ b/plugins/harmful_content_detector/harmful_content_detector.py @@ -23,10 +23,10 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginViolation, - Plugin, PromptPrehookPayload, PromptPrehookResult, ToolPostInvokePayload, diff --git a/plugins/header_injector/header_injector.py b/plugins/header_injector/header_injector.py index daa642155..59173bdc3 100644 --- a/plugins/header_injector/header_injector.py +++ b/plugins/header_injector/header_injector.py @@ -22,9 +22,9 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, - Plugin, ResourcePreFetchPayload, ResourcePreFetchResult, ) diff --git a/plugins/html_to_markdown/html_to_markdown.py b/plugins/html_to_markdown/html_to_markdown.py index 025a62ce4..9a87dfe23 100644 --- a/plugins/html_to_markdown/html_to_markdown.py +++ b/plugins/html_to_markdown/html_to_markdown.py @@ -20,9 +20,9 @@ # First-Party from mcpgateway.common.models import ResourceContent from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, - Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ) diff --git a/plugins/json_repair/json_repair.py b/plugins/json_repair/json_repair.py index f246faa1c..470209cc4 100644 --- a/plugins/json_repair/json_repair.py +++ b/plugins/json_repair/json_repair.py @@ -18,9 +18,9 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, - Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ) diff --git a/plugins/license_header_injector/license_header_injector.py b/plugins/license_header_injector/license_header_injector.py index 5fc1e55b3..563cbee56 100644 --- a/plugins/license_header_injector/license_header_injector.py +++ b/plugins/license_header_injector/license_header_injector.py @@ -22,9 +22,9 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, - Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ToolPostInvokePayload, diff --git a/plugins/markdown_cleaner/markdown_cleaner.py b/plugins/markdown_cleaner/markdown_cleaner.py index a247e6a05..16d48c5f9 100644 --- a/plugins/markdown_cleaner/markdown_cleaner.py +++ b/plugins/markdown_cleaner/markdown_cleaner.py @@ -17,12 +17,11 @@ from typing import Any # First-Party -from mcpgateway.common.models import Message, PromptResult, TextContent -from mcpgateway.common.models import ResourceContent +from mcpgateway.common.models import Message, PromptResult, ResourceContent, TextContent from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, - Plugin, PromptPosthookPayload, PromptPosthookResult, ResourcePostFetchPayload, diff --git a/plugins/output_length_guard/output_length_guard.py b/plugins/output_length_guard/output_length_guard.py index 4d2884d57..7c494987d 100644 --- a/plugins/output_length_guard/output_length_guard.py +++ b/plugins/output_length_guard/output_length_guard.py @@ -34,10 +34,10 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginViolation, - Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ) diff --git a/plugins/pii_filter/pii_filter.py b/plugins/pii_filter/pii_filter.py index 4672deca8..69609c06e 100644 --- a/plugins/pii_filter/pii_filter.py +++ b/plugins/pii_filter/pii_filter.py @@ -19,10 +19,10 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginViolation, - Plugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -43,7 +43,10 @@ _RustPIIDetector = None try: - from .pii_filter_rust import RustPIIDetector as _RustPIIDetector, RUST_AVAILABLE as _RUST_AVAILABLE + # Local + from .pii_filter_rust import RUST_AVAILABLE as _RUST_AVAILABLE + from .pii_filter_rust import RustPIIDetector as _RustPIIDetector + if _RUST_AVAILABLE: logger.info("πŸ¦€ Rust PII filter available - using high-performance implementation (5-100x speedup)") else: @@ -805,6 +808,9 @@ def _apply_pii_masking_to_parsed_json(self, data: Any, base_path: str, all_detec data: The parsed JSON data structure base_path: The base path for this JSON data all_detections: Dictionary containing all PII detections + + Returns: + None: Modifies data in place. """ if isinstance(data, str): # Check if this path has detections diff --git a/plugins/pii_filter/pii_filter_rust.py b/plugins/pii_filter/pii_filter_rust.py index c0d9a34e2..180f3a5c7 100644 --- a/plugins/pii_filter/pii_filter_rust.py +++ b/plugins/pii_filter/pii_filter_rust.py @@ -9,11 +9,13 @@ Thin Python wrapper around the Rust implementation for seamless integration. """ -from typing import Dict, List, Any, TYPE_CHECKING +# Standard import logging +from typing import Any, Dict, List, TYPE_CHECKING # Use TYPE_CHECKING to avoid circular import at runtime if TYPE_CHECKING: + # Local from .pii_filter import PIIFilterConfig logger = logging.getLogger(__name__) @@ -21,20 +23,23 @@ # Try to import Rust implementation # Fix sys.path to prioritize site-packages over source directory try: - import sys + # Standard import os + import sys # Temporarily remove current directory from path if it contains plugins_rust source original_path = sys.path.copy() project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - plugins_rust_src = os.path.join(project_root, 'plugins_rust') + plugins_rust_src = os.path.join(project_root, "plugins_rust") # Remove source directory from path temporarily filtered_path = [p for p in sys.path if not p.startswith(plugins_rust_src)] sys.path = filtered_path try: + # First-Party from plugins_rust import PIIDetectorRust as _RustDetector + RUST_AVAILABLE = True logger.info("πŸ¦€ Rust PII filter module imported successfully") finally: @@ -69,16 +74,15 @@ def __init__(self, config: "PIIFilterConfig"): Raises: ImportError: If Rust implementation is not available + TypeError: If configuration type is invalid ValueError: If configuration is invalid """ # Import here to avoid circular dependency + # Local from .pii_filter import PIIFilterConfig # pylint: disable=import-outside-toplevel if not RUST_AVAILABLE: - raise ImportError( - "Rust implementation not available. " - "Install with: pip install mcpgateway[rust]" - ) + raise ImportError("Rust implementation not available. " "Install with: pip install mcpgateway[rust]") # Validate config type if not isinstance(config, PIIFilterConfig): @@ -114,6 +118,9 @@ def detect(self, text: str) -> Dict[str, List[Dict]]: ] } + Raises: + RuntimeError: If PII detection fails. + Example: >>> detector.detect("SSN: 123-45-6789") {'ssn': [{'value': '123-45-6789', 'start': 5, 'end': 16, 'mask_strategy': 'partial'}]} @@ -132,7 +139,10 @@ def mask(self, text: str, detections: Dict[str, List[Dict]]) -> str: detections: Detection results from detect() Returns: - Masked text with PII replaced according to strategies + str: Masked text with PII replaced according to strategies + + Raises: + RuntimeError: If PII masking fails. Example: >>> text = "SSN: 123-45-6789" @@ -157,11 +167,14 @@ def process_nested(self, data: Any, path: str = "") -> tuple[bool, Any, Dict]: path: Current path in the structure (for logging) Returns: - Tuple of (modified, new_data, detections) where: + tuple[bool, Any, Dict]: Tuple of (modified, new_data, detections) where: - modified: True if any PII was found and masked - new_data: The data structure with masked PII - detections: Dictionary of all detections found + Raises: + RuntimeError: If nested processing fails. + Example: >>> data = {"user": {"ssn": "123-45-6789", "name": "John"}} >>> modified, new_data, detections = detector.process_nested(data) @@ -176,4 +189,4 @@ def process_nested(self, data: Any, path: str = "") -> tuple[bool, Any, Dict]: # Export module-level availability flag -__all__ = ['RustPIIDetector', 'RUST_AVAILABLE'] +__all__ = ["RustPIIDetector", "RUST_AVAILABLE"] diff --git a/plugins/privacy_notice_injector/privacy_notice_injector.py b/plugins/privacy_notice_injector/privacy_notice_injector.py index cd45058c3..f619dbaad 100644 --- a/plugins/privacy_notice_injector/privacy_notice_injector.py +++ b/plugins/privacy_notice_injector/privacy_notice_injector.py @@ -21,9 +21,9 @@ # First-Party from mcpgateway.common.models import Message, Role, TextContent from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, - Plugin, PromptPosthookPayload, PromptPosthookResult, ) diff --git a/plugins/rate_limiter/rate_limiter.py b/plugins/rate_limiter/rate_limiter.py index 74ba09a9e..78eccafa4 100644 --- a/plugins/rate_limiter/rate_limiter.py +++ b/plugins/rate_limiter/rate_limiter.py @@ -22,10 +22,10 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginViolation, - Plugin, PromptPrehookPayload, PromptPrehookResult, ToolPreInvokePayload, diff --git a/plugins/regex_filter/search_replace.py b/plugins/regex_filter/search_replace.py index ef6c59707..79e4fc54f 100644 --- a/plugins/regex_filter/search_replace.py +++ b/plugins/regex_filter/search_replace.py @@ -16,9 +16,9 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, - Plugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, diff --git a/plugins/resource_filter/resource_filter.py b/plugins/resource_filter/resource_filter.py index 8a25aea4f..98121db25 100644 --- a/plugins/resource_filter/resource_filter.py +++ b/plugins/resource_filter/resource_filter.py @@ -19,11 +19,11 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginMode, PluginViolation, - Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchPayload, diff --git a/plugins/response_cache_by_prompt/response_cache_by_prompt.py b/plugins/response_cache_by_prompt/response_cache_by_prompt.py index f84ff4d6c..fa7821817 100644 --- a/plugins/response_cache_by_prompt/response_cache_by_prompt.py +++ b/plugins/response_cache_by_prompt/response_cache_by_prompt.py @@ -28,9 +28,9 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, - Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokePayload, diff --git a/plugins/retry_with_backoff/retry_with_backoff.py b/plugins/retry_with_backoff/retry_with_backoff.py index 305da62a4..ef63ee87f 100644 --- a/plugins/retry_with_backoff/retry_with_backoff.py +++ b/plugins/retry_with_backoff/retry_with_backoff.py @@ -17,9 +17,9 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, - Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ToolPostInvokePayload, diff --git a/plugins/robots_license_guard/robots_license_guard.py b/plugins/robots_license_guard/robots_license_guard.py index 5b7fe3a02..3643688bf 100644 --- a/plugins/robots_license_guard/robots_license_guard.py +++ b/plugins/robots_license_guard/robots_license_guard.py @@ -23,10 +23,10 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginViolation, - Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ResourcePreFetchPayload, diff --git a/plugins/safe_html_sanitizer/safe_html_sanitizer.py b/plugins/safe_html_sanitizer/safe_html_sanitizer.py index a6d68cca4..1d4364f0f 100644 --- a/plugins/safe_html_sanitizer/safe_html_sanitizer.py +++ b/plugins/safe_html_sanitizer/safe_html_sanitizer.py @@ -30,9 +30,9 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, - Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ) diff --git a/plugins/schema_guard/schema_guard.py b/plugins/schema_guard/schema_guard.py index 132d21bbf..e8962b970 100644 --- a/plugins/schema_guard/schema_guard.py +++ b/plugins/schema_guard/schema_guard.py @@ -20,10 +20,10 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginViolation, - Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokePayload, diff --git a/plugins/secrets_detection/secrets_detection.py b/plugins/secrets_detection/secrets_detection.py index fb76c8411..1d2198a6a 100644 --- a/plugins/secrets_detection/secrets_detection.py +++ b/plugins/secrets_detection/secrets_detection.py @@ -23,10 +23,10 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginViolation, - Plugin, PromptPrehookPayload, PromptPrehookResult, ResourcePostFetchPayload, diff --git a/plugins/sql_sanitizer/sql_sanitizer.py b/plugins/sql_sanitizer/sql_sanitizer.py index 95d39f094..5ad84de02 100644 --- a/plugins/sql_sanitizer/sql_sanitizer.py +++ b/plugins/sql_sanitizer/sql_sanitizer.py @@ -26,10 +26,10 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginViolation, - Plugin, PromptPrehookPayload, PromptPrehookResult, ToolPreInvokePayload, diff --git a/plugins/summarizer/summarizer.py b/plugins/summarizer/summarizer.py index 8f4a7990b..9ba229a54 100644 --- a/plugins/summarizer/summarizer.py +++ b/plugins/summarizer/summarizer.py @@ -23,9 +23,9 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, - Plugin, ResourcePostFetchPayload, ResourcePostFetchResult, ToolPostInvokePayload, diff --git a/plugins/timezone_translator/timezone_translator.py b/plugins/timezone_translator/timezone_translator.py index ce1547db3..af644ca7d 100644 --- a/plugins/timezone_translator/timezone_translator.py +++ b/plugins/timezone_translator/timezone_translator.py @@ -25,9 +25,9 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, - Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokePayload, diff --git a/plugins/url_reputation/url_reputation.py b/plugins/url_reputation/url_reputation.py index 4ea78b4b0..35bc2e82d 100644 --- a/plugins/url_reputation/url_reputation.py +++ b/plugins/url_reputation/url_reputation.py @@ -20,10 +20,10 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginViolation, - Plugin, ResourcePreFetchPayload, ResourcePreFetchResult, ) diff --git a/plugins/vault/vault_plugin.py b/plugins/vault/vault_plugin.py index 4f23dd83a..115168419 100644 --- a/plugins/vault/vault_plugin.py +++ b/plugins/vault/vault_plugin.py @@ -22,10 +22,10 @@ # First-Party from mcpgateway.db import get_db from mcpgateway.plugins.framework import ( + HttpHeaderPayload, + Plugin, PluginConfig, PluginContext, - Plugin, - HttpHeaderPayload, ToolPreInvokePayload, ToolPreInvokeResult, ) diff --git a/plugins/virus_total_checker/virus_total_checker.py b/plugins/virus_total_checker/virus_total_checker.py index 5f4c2ba32..1754a86c6 100644 --- a/plugins/virus_total_checker/virus_total_checker.py +++ b/plugins/virus_total_checker/virus_total_checker.py @@ -31,10 +31,10 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginViolation, - Plugin, PromptPosthookPayload, PromptPosthookResult, ResourcePostFetchPayload, @@ -313,7 +313,13 @@ def _ip_in_cidrs(ip: str, cidrs: list[str]) -> bool: def _apply_overrides(url: str, host: str | None, cfg: VirusTotalConfig) -> str | None: """Return 'deny', 'allow', or None based on local overrides and precedence. - Precedence order is controlled by cfg.override_precedence. + Args: + url: The URL to check for overrides. + host: The host to check for overrides (optional). + cfg: The VirusTotal configuration. + + Returns: + str | None: 'deny', 'allow', or None based on overrides. Precedence order is controlled by cfg.override_precedence. """ host_l = (host or "").lower() allow = _url_matches(url, cfg.allow_url_patterns) or (host_l and _domain_matches(host_l, cfg.allow_domains)) or (host_l and _ip_in_cidrs(host_l, cfg.allow_ip_cidrs)) diff --git a/plugins/watchdog/watchdog.py b/plugins/watchdog/watchdog.py index d399e5e94..e61711f4d 100644 --- a/plugins/watchdog/watchdog.py +++ b/plugins/watchdog/watchdog.py @@ -23,10 +23,10 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginViolation, - Plugin, ToolPostInvokePayload, ToolPostInvokeResult, ToolPreInvokePayload, diff --git a/plugins/webhook_notification/webhook_notification.py b/plugins/webhook_notification/webhook_notification.py index ae888bf2d..37eeb61da 100644 --- a/plugins/webhook_notification/webhook_notification.py +++ b/plugins/webhook_notification/webhook_notification.py @@ -27,10 +27,10 @@ # First-Party from mcpgateway.plugins.framework import ( + Plugin, PluginConfig, PluginContext, PluginViolation, - Plugin, PromptPosthookPayload, PromptPosthookResult, PromptPrehookPayload, @@ -131,7 +131,15 @@ def __init__(self, config: PluginConfig) -> None: self._client = httpx.AsyncClient() async def _render_template(self, template: str, context: Dict[str, Any]) -> str: - """Render a Jinja2-style template with the given context.""" + """Render a Jinja2-style template with the given context. + + Args: + template: The template string to render. + context: The context dictionary for template rendering. + + Returns: + str: The rendered template string. + """ # Simple template substitution for now - could be enhanced with Jinja2 result = template for key, value in context.items(): @@ -145,7 +153,16 @@ async def _render_template(self, template: str, context: Dict[str, Any]) -> str: return result def _create_hmac_signature(self, payload: str, secret: str, algorithm: str) -> str: - """Create HMAC signature for the payload.""" + """Create HMAC signature for the payload. + + Args: + payload: The payload to sign. + secret: The secret key for HMAC. + algorithm: The hash algorithm to use. + + Returns: + str: The HMAC signature string. + """ hash_func = getattr(hashlib, algorithm, hashlib.sha256) signature = hmac.new(secret.encode("utf-8"), payload.encode("utf-8"), hash_func).hexdigest() return f"{algorithm}={signature}" @@ -159,7 +176,16 @@ async def _send_webhook( metadata: Optional[Dict[str, Any]] = None, payload_data: Optional[Dict[str, Any]] = None, ) -> None: - """Send a webhook notification with retry logic.""" + """Send a webhook notification with retry logic. + + Args: + webhook: The webhook configuration. + event: The event type to notify. + context: The plugin context. + violation: Optional violation details. + metadata: Optional metadata dictionary. + payload_data: Optional payload data dictionary. + """ if not webhook.enabled or event not in webhook.events: return @@ -229,7 +255,15 @@ async def _send_webhook( async def _notify_webhooks( self, event: EventType, context: PluginContext, violation: Optional[PluginViolation] = None, metadata: Optional[Dict[str, Any]] = None, payload_data: Optional[Dict[str, Any]] = None ) -> None: - """Send notifications to all configured webhooks.""" + """Send notifications to all configured webhooks. + + Args: + event: The event type to notify. + context: The plugin context. + violation: Optional violation details. + metadata: Optional metadata dictionary. + payload_data: Optional payload data dictionary. + """ if not self._cfg.webhooks: return @@ -240,7 +274,14 @@ async def _notify_webhooks( await asyncio.gather(*tasks, return_exceptions=True) def _determine_event_type(self, violation: Optional[PluginViolation]) -> EventType: - """Determine event type based on violation details.""" + """Determine event type based on violation details. + + Args: + violation: Optional violation details. + + Returns: + EventType: The determined event type. + """ if not violation: return EventType.TOOL_SUCCESS @@ -254,20 +295,52 @@ def _determine_event_type(self, violation: Optional[PluginViolation]) -> EventTy return EventType.VIOLATION async def prompt_pre_fetch(self, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: - """Hook for prompt pre-fetch events.""" + """Hook for prompt pre-fetch events. + + Args: + payload: The prompt pre-hook payload. + context: The plugin context. + + Returns: + PromptPrehookResult: The result of the pre-fetch hook. + """ return PromptPrehookResult() async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: PluginContext) -> PromptPosthookResult: - """Hook for prompt post-fetch events.""" + """Hook for prompt post-fetch events. + + Args: + payload: The prompt post-hook payload. + context: The plugin context. + + Returns: + PromptPosthookResult: The result of the post-fetch hook. + """ await self._notify_webhooks(EventType.PROMPT_SUCCESS, context, metadata={"prompt_id": payload.prompt_id}) return PromptPosthookResult() async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult: - """Hook for tool pre-invoke events.""" + """Hook for tool pre-invoke events. + + Args: + payload: The tool pre-invoke payload. + context: The plugin context. + + Returns: + ToolPreInvokeResult: The result of the pre-invoke hook. + """ return ToolPreInvokeResult() async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: PluginContext) -> ToolPostInvokeResult: - """Hook for tool post-invoke events.""" + """Hook for tool post-invoke events. + + Args: + payload: The tool post-invoke payload. + context: The plugin context. + + Returns: + ToolPostInvokeResult: The result of the post-invoke hook. + """ # Check if there was an error in the result event = EventType.TOOL_SUCCESS metadata = {"tool_name": payload.name} @@ -284,16 +357,36 @@ async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: Plugin return ToolPostInvokeResult() async def resource_pre_fetch(self, payload: ResourcePreFetchPayload, context: PluginContext) -> ResourcePreFetchResult: - """Hook for resource pre-fetch events.""" + """Hook for resource pre-fetch events. + + Args: + payload: The resource pre-fetch payload. + context: The plugin context. + + Returns: + ResourcePreFetchResult: The result of the pre-fetch hook. + """ return ResourcePreFetchResult() async def resource_post_fetch(self, payload: ResourcePostFetchPayload, context: PluginContext) -> ResourcePostFetchResult: - """Hook for resource post-fetch events.""" + """Hook for resource post-fetch events. + + Args: + payload: The resource post-fetch payload. + context: The plugin context. + + Returns: + ResourcePostFetchResult: The result of the post-fetch hook. + """ await self._notify_webhooks(EventType.RESOURCE_SUCCESS, context, metadata={"resource_uri": payload.uri}) return ResourcePostFetchResult() async def __aenter__(self): - """Async context manager entry.""" + """Async context manager entry. + + Returns: + WebhookNotificationPlugin: The plugin instance. + """ return self async def __aexit__(self, _exc_type, _exc_val, _exc_tb): diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml index 31793520a..3a6ec9c16 100644 --- a/tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/agent_passthrough.yaml @@ -12,7 +12,7 @@ plugins: - agent mode: enforce priority: 50 - + # Plugin directories to scan plugin_dirs: - "plugins/native" # Built-in plugins From c3e7466a4de40cb912148f4e69a02081ebc5c868 Mon Sep 17 00:00:00 2001 From: Frederico Araujo Date: Thu, 6 Nov 2025 00:57:32 -0500 Subject: [PATCH 15/20] chore: fix pylint issues Signed-off-by: Frederico Araujo --- plugins/external/opa/tests/test_all.py | 4 ++-- plugins/external/opa/tests/test_opapluginfilter.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/plugins/external/opa/tests/test_all.py b/plugins/external/opa/tests/test_all.py index b6ec72500..71cbca5c8 100644 --- a/plugins/external/opa/tests/test_all.py +++ b/plugins/external/opa/tests/test_all.py @@ -12,9 +12,9 @@ from mcpgateway.plugins.framework import ( GlobalContext, PluginManager, + PluginResult, PromptPosthookPayload, PromptPrehookPayload, - PromptResult, ResourcePostFetchPayload, ResourcePreFetchPayload, ToolPostInvokePayload, @@ -59,7 +59,7 @@ async def test_prompt_post_hook(plugin_manager: PluginManager): """ # Customize payload for testing message = Message(content=TextContent(type="text", text="prompt"), role=Role.USER) - prompt_result = PromptResult(messages=[message]) + prompt_result = PluginResult(messages=[message]) payload = PromptPosthookPayload(name="test_prompt", result=prompt_result) global_context = GlobalContext(request_id="1") result, _ = await plugin_manager.prompt_post_fetch(payload, global_context) diff --git a/plugins/external/opa/tests/test_opapluginfilter.py b/plugins/external/opa/tests/test_opapluginfilter.py index 9ba896c9b..075d9e54f 100644 --- a/plugins/external/opa/tests/test_opapluginfilter.py +++ b/plugins/external/opa/tests/test_opapluginfilter.py @@ -21,9 +21,9 @@ GlobalContext, PluginConfig, PluginContext, + PluginResult, PromptPosthookPayload, PromptPrehookPayload, - PromptResult, ResourcePostFetchPayload, ResourcePreFetchPayload, ToolPostInvokePayload, @@ -160,7 +160,7 @@ async def test_post_prompt_fetch_opapluginfilter(): # Benign payload (allowed by OPA (rego) policy) message = Message(content=TextContent(type="text", text="abc"), role=Role.USER) - prompt_result = PromptResult(messages=[message]) + prompt_result = PluginResult(messages=[message]) payload = PromptPosthookPayload(name="test_prompt", result=prompt_result) context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) result = await plugin.prompt_post_fetch(payload, context) @@ -168,7 +168,7 @@ async def test_post_prompt_fetch_opapluginfilter(): # Malign payload (denied by OPA (rego) policy) message = Message(content=TextContent(type="text", text="abc@example.com"), role=Role.USER) - prompt_result = PromptResult(messages=[message]) + prompt_result = PluginResult(messages=[message]) payload = PromptPosthookPayload(name="test_prompt", result=prompt_result) context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) result = await plugin.prompt_post_fetch(payload, context) From c02949929ac78db4b277f14a54471e1cae3f9852 Mon Sep 17 00:00:00 2001 From: Frederico Araujo Date: Thu, 6 Nov 2025 01:37:30 -0500 Subject: [PATCH 16/20] fix: pylint issues Signed-off-by: Frederico Araujo --- mcpgateway/plugins/framework/base.py | 6 +- mcpgateway/validators.py | 1176 +------------------------- 2 files changed, 13 insertions(+), 1169 deletions(-) diff --git a/mcpgateway/plugins/framework/base.py b/mcpgateway/plugins/framework/base.py index 60c64cc18..1d3e221b9 100644 --- a/mcpgateway/plugins/framework/base.py +++ b/mcpgateway/plugins/framework/base.py @@ -25,6 +25,8 @@ PluginResult, ) +# pylint: disable=import-outside-toplevel + class Plugin(ABC): """Base plugin object for pre/post processing of inputs and outputs at various locations throughout the server. @@ -181,7 +183,7 @@ def json_to_payload(self, hook: str, payload: Union[str | dict]) -> PluginPayloa # Fall back to global registry if not hook_payload_type: # First-Party - from mcpgateway.plugins.framework.hooks.registry import get_hook_registry # pylint: disable=import-outside-toplevel + from mcpgateway.plugins.framework.hooks.registry import get_hook_registry registry = get_hook_registry() hook_payload_type = registry.get_payload_type(hook) @@ -216,7 +218,7 @@ def json_to_result(self, hook: str, result: Union[str | dict]) -> PluginResult: # Fall back to global registry if not hook_result_type: # First-Party - from mcpgateway.plugins.framework.hooks.registry import get_hook_registry # pylint: disable=import-outside-toplevel + from mcpgateway.plugins.framework.hooks.registry import get_hook_registry registry = get_hook_registry() hook_result_type = registry.get_result_type(hook) diff --git a/mcpgateway/validators.py b/mcpgateway/validators.py index 743cf489f..45b0259dc 100644 --- a/mcpgateway/validators.py +++ b/mcpgateway/validators.py @@ -5,37 +5,14 @@ Authors: Mihai Criveti, Madhav Kandukuri SecurityValidator for MCP Gateway -This module defines the `SecurityValidator` class, which provides centralized, configurable -validation logic for user-generated content in MCP-based applications. +This module re-exports the SecurityValidator class from mcpgateway.common.validators +for backward compatibility. -The validator enforces strict security and structural rules across common input types such as: -- Display text (e.g., names, descriptions) -- Identifiers and tool names -- URIs and URLs -- JSON object depth -- Templates (including limited HTML/Jinja2) -- MIME types - -Key Features: -- Pattern-based validation using settings-defined regex for HTML/script safety -- Configurable max lengths and depth limits -- Whitelist-based URL scheme and MIME type validation -- Safe escaping of user-visible text fields -- Reusable static/class methods for field-level and form-level validation - -Intended to be used with Pydantic or similar schema-driven systems to validate and sanitize -user input in a consistent, centralized way. - -Dependencies: -- Standard Library: re, html, logging, urllib.parse -- First-party: `settings` from `mcpgateway.config` +The canonical location for SecurityValidator is mcpgateway.common.validators. +This module exists to maintain backward compatibility with code that imports from +mcpgateway.validators. Example usage: - SecurityValidator.validate_name("my_tool", field_name="Tool Name") - SecurityValidator.validate_url("https://example.com") - SecurityValidator.validate_json_depth({...}) - -Examples: >>> from mcpgateway.validators import SecurityValidator >>> SecurityValidator.sanitize_display_text('Test', 'test') '<b>Test</b>' @@ -47,1144 +24,9 @@ >>> SecurityValidator.validate_json_depth({'a': 1}) """ -# Standard -import html -import logging -import re -from urllib.parse import urlparse -import uuid - # First-Party -from mcpgateway.config import settings - -logger = logging.getLogger(__name__) - - -class SecurityValidator: - """Configurable validation with MCP-compliant limits""" - - # Configurable patterns (from settings) - DANGEROUS_HTML_PATTERN = ( - settings.validation_dangerous_html_pattern - ) # Default: '<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|' - DANGEROUS_JS_PATTERN = settings.validation_dangerous_js_pattern # Default: javascript:|vbscript:|on\w+\s*=|data:.*script - ALLOWED_URL_SCHEMES = settings.validation_allowed_url_schemes # Default: ["http://", "https://", "ws://", "wss://"] - - # Character type patterns - NAME_PATTERN = settings.validation_name_pattern # Default: ^[a-zA-Z0-9_\-\s]+$ - IDENTIFIER_PATTERN = settings.validation_identifier_pattern # Default: ^[a-zA-Z0-9_\-\.]+$ - VALIDATION_SAFE_URI_PATTERN = settings.validation_safe_uri_pattern # Default: ^[a-zA-Z0-9_\-.:/?=&%]+$ - VALIDATION_UNSAFE_URI_PATTERN = settings.validation_unsafe_uri_pattern # Default: [<>"\'\\] - TOOL_NAME_PATTERN = settings.validation_tool_name_pattern # Default: ^[a-zA-Z][a-zA-Z0-9_-]*$ - - # MCP-compliant limits (configurable) - MAX_NAME_LENGTH = settings.validation_max_name_length # Default: 255 - MAX_DESCRIPTION_LENGTH = settings.validation_max_description_length # Default: 8192 (8KB) - MAX_TEMPLATE_LENGTH = settings.validation_max_template_length # Default: 65536 - MAX_CONTENT_LENGTH = settings.validation_max_content_length # Default: 1048576 (1MB) - MAX_JSON_DEPTH = settings.validation_max_json_depth # Default: 10 - MAX_URL_LENGTH = settings.validation_max_url_length # Default: 2048 - - @classmethod - def sanitize_display_text(cls, value: str, field_name: str) -> str: - """Ensure text is safe for display in UI by escaping special characters - - Args: - value (str): Value to validate - field_name (str): Name of field being validated - - Returns: - str: Value if acceptable - - Raises: - ValueError: When input is not acceptable - - Examples: - Basic HTML escaping: - - >>> SecurityValidator.sanitize_display_text('Hello World', 'test') - 'Hello World' - >>> SecurityValidator.sanitize_display_text('Hello World', 'test') - 'Hello <b>World</b>' - - Empty/None handling: - - >>> SecurityValidator.sanitize_display_text('', 'test') - '' - >>> SecurityValidator.sanitize_display_text(None, 'test') #doctest: +SKIP - - Dangerous script patterns: - - >>> SecurityValidator.sanitize_display_text('alert();', 'test') - 'alert();' - >>> SecurityValidator.sanitize_display_text('javascript:alert(1)', 'test') - Traceback (most recent call last): - ... - ValueError: test contains script patterns that may cause display issues - - Polyglot attack patterns: - - >>> SecurityValidator.sanitize_display_text('"; alert()', 'test') - Traceback (most recent call last): - ... - ValueError: test contains potentially dangerous character sequences - >>> SecurityValidator.sanitize_display_text('-->test', 'test') - '-->test' - >>> SecurityValidator.sanitize_display_text('-->') - Traceback (most recent call last): - ... - ValueError: Template contains HTML tags that may interfere with proper display - >>> SecurityValidator.validate_template('Test ') - Traceback (most recent call last): - ... - ValueError: Template contains HTML tags that may interfere with proper display - >>> SecurityValidator.validate_template('') - Traceback (most recent call last): - ... - ValueError: Template contains HTML tags that may interfere with proper display - - Event handlers blocked: - - >>> SecurityValidator.validate_template('
Test
') - Traceback (most recent call last): - ... - ValueError: Template contains event handlers that may cause display issues - >>> SecurityValidator.validate_template('onload = "alert(1)"') - Traceback (most recent call last): - ... - ValueError: Template contains event handlers that may cause display issues - - SSTI prevention patterns: - - >>> SecurityValidator.validate_template('{{ __import__ }}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - >>> SecurityValidator.validate_template('{{ config }}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - >>> SecurityValidator.validate_template('{% import os %}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - >>> SecurityValidator.validate_template('{{ 7*7 }}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - >>> SecurityValidator.validate_template('{{ 10/2 }}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - >>> SecurityValidator.validate_template('{{ 5+5 }}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - >>> SecurityValidator.validate_template('{{ 10-5 }}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - - Other template injection patterns: - - >>> SecurityValidator.validate_template('${evil}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - >>> SecurityValidator.validate_template('#{evil}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - >>> SecurityValidator.validate_template('%{evil}') - Traceback (most recent call last): - ... - ValueError: Template contains potentially dangerous expressions - - Length limit testing: - - >>> long_template = 'a' * 65537 - >>> SecurityValidator.validate_template(long_template) - Traceback (most recent call last): - ... - ValueError: Template exceeds maximum length of 65536 - """ - if not value: - return value - - if len(value) > cls.MAX_TEMPLATE_LENGTH: - raise ValueError(f"Template exceeds maximum length of {cls.MAX_TEMPLATE_LENGTH}") - - # Block dangerous tags but allow Jinja2 syntax {{ }} and {% %} - dangerous_tags = r"<(script|iframe|object|embed|link|meta|base|form)\b" - if re.search(dangerous_tags, value, re.IGNORECASE): - raise ValueError("Template contains HTML tags that may interfere with proper display") - - # Check for event handlers that could cause issues - if re.search(r"on\w+\s*=", value, re.IGNORECASE): - raise ValueError("Template contains event handlers that may cause display issues") - - # SSTI Prevention - block dangerous template expressions - ssti_patterns = [ - r"\{\{.*(__|\.|config|self|request|application|globals|builtins|import).*\}\}", # Jinja2 dangerous patterns - r"\{%.*(__|\.|config|self|request|application|globals|builtins|import).*%\}", # Jinja2 tags - r"\$\{.*\}", # ${} expressions - r"#\{.*\}", # #{} expressions - r"%\{.*\}", # %{} expressions - r"\{\{.*\*.*\}\}", # Math operations in templates (like {{7*7}}) - r"\{\{.*\/.*\}\}", # Division operations - r"\{\{.*\+.*\}\}", # Addition operations - r"\{\{.*\-.*\}\}", # Subtraction operations - ] - - for pattern in ssti_patterns: - if re.search(pattern, value, re.IGNORECASE): - raise ValueError("Template contains potentially dangerous expressions") - - return value - - @classmethod - def validate_url(cls, value: str, field_name: str = "URL") -> str: - """Validate URLs for allowed schemes and safe display - - Args: - value (str): Value to validate - field_name (str): Name of field being validated - - Returns: - str: Value if acceptable - - Raises: - ValueError: When input is not acceptable - - Examples: - Valid URLs: - - >>> SecurityValidator.validate_url('https://example.com') - 'https://example.com' - >>> SecurityValidator.validate_url('http://example.com') - 'http://example.com' - >>> SecurityValidator.validate_url('ws://example.com') - 'ws://example.com' - >>> SecurityValidator.validate_url('wss://example.com') - 'wss://example.com' - >>> SecurityValidator.validate_url('https://example.com:8080/path') - 'https://example.com:8080/path' - >>> SecurityValidator.validate_url('https://example.com/path?query=value') - 'https://example.com/path?query=value' - - Empty URL handling: - - >>> SecurityValidator.validate_url('') - Traceback (most recent call last): - ... - ValueError: URL cannot be empty - - Length validation: - - >>> long_url = 'https://example.com/' + 'a' * 2100 - >>> SecurityValidator.validate_url(long_url) - Traceback (most recent call last): - ... - ValueError: URL exceeds maximum length of 2048 - - Scheme validation: - - >>> SecurityValidator.validate_url('ftp://example.com') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - >>> SecurityValidator.validate_url('file:///etc/passwd') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - >>> SecurityValidator.validate_url('javascript:alert(1)') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - >>> SecurityValidator.validate_url('data:text/plain,hello') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - >>> SecurityValidator.validate_url('vbscript:alert(1)') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - >>> SecurityValidator.validate_url('about:blank') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - >>> SecurityValidator.validate_url('chrome://settings') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - >>> SecurityValidator.validate_url('mailto:test@example.com') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - - IPv6 URL blocking: - - >>> SecurityValidator.validate_url('https://[::1]:8080/') - Traceback (most recent call last): - ... - ValueError: URL contains IPv6 address which is not supported - >>> SecurityValidator.validate_url('https://[2001:db8::1]/') - Traceback (most recent call last): - ... - ValueError: URL contains IPv6 address which is not supported - - Protocol-relative URL blocking: - - >>> SecurityValidator.validate_url('//example.com/path') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - - Line break injection: - - >>> SecurityValidator.validate_url('https://example.com\\rHost: evil.com') - Traceback (most recent call last): - ... - ValueError: URL contains line breaks which are not allowed - >>> SecurityValidator.validate_url('https://example.com\\nHost: evil.com') - Traceback (most recent call last): - ... - ValueError: URL contains line breaks which are not allowed - - Space validation: - - >>> SecurityValidator.validate_url('https://exam ple.com') - Traceback (most recent call last): - ... - ValueError: URL contains spaces which are not allowed in URLs - >>> SecurityValidator.validate_url('https://example.com/path?query=hello world') - 'https://example.com/path?query=hello world' - - Malformed URLs: - - >>> SecurityValidator.validate_url('https://') - Traceback (most recent call last): - ... - ValueError: URL is not a valid URL - >>> SecurityValidator.validate_url('not-a-url') - Traceback (most recent call last): - ... - ValueError: URL must start with one of: http://, https://, ws://, wss:// - - Restricted IP addresses: - - >>> SecurityValidator.validate_url('https://0.0.0.0/') - Traceback (most recent call last): - ... - ValueError: URL contains invalid IP address (0.0.0.0) - >>> SecurityValidator.validate_url('https://169.254.169.254/') - Traceback (most recent call last): - ... - ValueError: URL contains restricted IP address - - Invalid port numbers: - - >>> SecurityValidator.validate_url('https://example.com:0/') - Traceback (most recent call last): - ... - ValueError: URL contains invalid port number - >>> try: - ... SecurityValidator.validate_url('https://example.com:65536/') - ... except ValueError as e: - ... 'Port out of range' in str(e) or 'invalid port' in str(e) - True - - Credentials in URL: - - >>> SecurityValidator.validate_url('https://user:pass@example.com/') - Traceback (most recent call last): - ... - ValueError: URL contains credentials which are not allowed - >>> SecurityValidator.validate_url('https://user@example.com/') - Traceback (most recent call last): - ... - ValueError: URL contains credentials which are not allowed - - XSS patterns in URLs: - - >>> SecurityValidator.validate_url('https://example.com/', 'test_field') - Traceback (most recent call last): - ... - ValueError: test_field contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'content') - Traceback (most recent call last): - ... - ValueError: content contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'data') - Traceback (most recent call last): - ... - ValueError: data contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'embed') - Traceback (most recent call last): - ... - ValueError: embed contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'style') - Traceback (most recent call last): - ... - ValueError: style contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'meta') - Traceback (most recent call last): - ... - ValueError: meta contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'base') - Traceback (most recent call last): - ... - ValueError: base contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('
', 'form') - Traceback (most recent call last): - ... - ValueError: form contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'image') - Traceback (most recent call last): - ... - ValueError: image contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'svg') - Traceback (most recent call last): - ... - ValueError: svg contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'video') - Traceback (most recent call last): - ... - ValueError: video contains HTML tags that may cause security issues - >>> SecurityValidator.validate_no_xss('', 'audio') - Traceback (most recent call last): - ... - ValueError: audio contains HTML tags that may cause security issues - """ - if not value: - return # Empty values are considered safe - # Check for dangerous HTML tags - if re.search(cls.DANGEROUS_HTML_PATTERN, value, re.IGNORECASE): - raise ValueError(f"{field_name} contains HTML tags that may cause security issues") - - @classmethod - def validate_json_depth( - cls, - obj: object, - max_depth: int | None = None, - current_depth: int = 0, - ) -> None: - """Validate that a JSON‑like structure does not exceed a depth limit. - - A *depth* is counted **only** when we enter a container (`dict` or - `list`). Primitive values (`str`, `int`, `bool`, `None`, etc.) do not - increase the depth, but an *empty* container still counts as one level. - - Args: - obj: Any Python object to inspect recursively. - max_depth: Maximum allowed depth (defaults to - :pyattr:`SecurityValidator.MAX_JSON_DEPTH`). - current_depth: Internal recursion counter. **Do not** set this - from user code. - - Raises: - ValueError: If the nesting level exceeds *max_depth*. - - Examples: - Simple flat dictionary – depth 1: :: - - >>> SecurityValidator.validate_json_depth({'name': 'Alice'}) - - Nested dict – depth 2: :: - - >>> SecurityValidator.validate_json_depth( - ... {'user': {'name': 'Alice'}} - ... ) - - Mixed dict/list – depth 3: :: - - >>> SecurityValidator.validate_json_depth( - ... {'users': [{'name': 'Alice', 'meta': {'age': 30}}]} - ... ) - - Exactly at the default limit (10) – allowed: :: - - >>> deep_10 = {'1': {'2': {'3': {'4': {'5': {'6': {'7': {'8': - ... {'9': {'10': 'end'}}}}}}}}}} - >>> SecurityValidator.validate_json_depth(deep_10) - - One level deeper – rejected: :: - - >>> deep_11 = {'1': {'2': {'3': {'4': {'5': {'6': {'7': {'8': - ... {'9': {'10': {'11': 'end'}}}}}}}}}}} - >>> SecurityValidator.validate_json_depth(deep_11) - Traceback (most recent call last): - ... - ValueError: JSON structure exceeds maximum depth of 10 - """ - if max_depth is None: - max_depth = cls.MAX_JSON_DEPTH - - # Only containers count toward depth; primitives are ignored - if not isinstance(obj, (dict, list)): - return - - next_depth = current_depth + 1 - if next_depth > max_depth: - raise ValueError(f"JSON structure exceeds maximum depth of {max_depth}") - - if isinstance(obj, dict): - for value in obj.values(): - cls.validate_json_depth(value, max_depth, next_depth) - else: # obj is a list - for item in obj: - cls.validate_json_depth(item, max_depth, next_depth) - - @classmethod - def validate_mime_type(cls, value: str) -> str: - """Validate MIME type format - - Args: - value (str): Value to validate - - Returns: - str: Value if acceptable - - Raises: - ValueError: When input is not acceptable - - Examples: - Empty/None handling: - - >>> SecurityValidator.validate_mime_type('') - '' - >>> SecurityValidator.validate_mime_type(None) #doctest: +SKIP - - Valid standard MIME types: - - >>> SecurityValidator.validate_mime_type('text/plain') - 'text/plain' - >>> SecurityValidator.validate_mime_type('application/json') - 'application/json' - >>> SecurityValidator.validate_mime_type('image/jpeg') - 'image/jpeg' - >>> SecurityValidator.validate_mime_type('text/html') - 'text/html' - >>> SecurityValidator.validate_mime_type('application/pdf') - 'application/pdf' - - Valid vendor-specific MIME types: - - >>> SecurityValidator.validate_mime_type('application/x-custom') - 'application/x-custom' - >>> SecurityValidator.validate_mime_type('text/x-log') - 'text/x-log' - - Valid MIME types with suffixes: - - >>> SecurityValidator.validate_mime_type('application/vnd.api+json') - 'application/vnd.api+json' - >>> SecurityValidator.validate_mime_type('image/svg+xml') - 'image/svg+xml' - - Invalid MIME type formats: - - >>> SecurityValidator.validate_mime_type('invalid') - Traceback (most recent call last): - ... - ValueError: Invalid MIME type format - >>> SecurityValidator.validate_mime_type('text/') - Traceback (most recent call last): - ... - ValueError: Invalid MIME type format - >>> SecurityValidator.validate_mime_type('/plain') - Traceback (most recent call last): - ... - ValueError: Invalid MIME type format - >>> SecurityValidator.validate_mime_type('text//plain') - Traceback (most recent call last): - ... - ValueError: Invalid MIME type format - >>> SecurityValidator.validate_mime_type('text/plain/extra') - Traceback (most recent call last): - ... - ValueError: Invalid MIME type format - >>> SecurityValidator.validate_mime_type('text plain') - Traceback (most recent call last): - ... - ValueError: Invalid MIME type format - >>> SecurityValidator.validate_mime_type('') - Traceback (most recent call last): - ... - ValueError: Invalid MIME type format - - Disallowed MIME types (not in whitelist - line 620): - - >>> try: - ... SecurityValidator.validate_mime_type('application/evil') - ... except ValueError as e: - ... 'not in the allowed list' in str(e) - True - >>> try: - ... SecurityValidator.validate_mime_type('text/evil') - ... except ValueError as e: - ... 'not in the allowed list' in str(e) - True - - Test MIME type with parameters (line 618): - - >>> try: - ... SecurityValidator.validate_mime_type('application/evil; charset=utf-8') - ... except ValueError as e: - ... 'Invalid MIME type format' in str(e) - True - """ - if not value: - return value - - # Basic MIME type pattern - mime_pattern = r"^[a-zA-Z0-9][a-zA-Z0-9!#$&\-\^_+\.]*\/[a-zA-Z0-9][a-zA-Z0-9!#$&\-\^_+\.]*$" - if not re.match(mime_pattern, value): - raise ValueError("Invalid MIME type format") - - # Common safe MIME types - safe_mime_types = settings.validation_allowed_mime_types - if value not in safe_mime_types: - # Allow x- vendor types and + suffixes - base_type = value.split(";")[0].strip() - if not (base_type.startswith("application/x-") or base_type.startswith("text/x-") or "+" in base_type): - raise ValueError(f"MIME type '{value}' is not in the allowed list") +# Re-export SecurityValidator from canonical location +# pylint: disable=unused-import +from mcpgateway.common.validators import SecurityValidator # noqa: F401 - return value +__all__ = ["SecurityValidator"] From 2eb8938ec317762e82f389bb646b90be63c47fea Mon Sep 17 00:00:00 2001 From: Frederico Araujo Date: Thu, 6 Nov 2025 10:06:26 -0500 Subject: [PATCH 17/20] fix: common validator tests Signed-off-by: Frederico Araujo --- tests/unit/mcpgateway/validation/test_validators.py | 2 +- tests/unit/mcpgateway/validation/test_validators_advanced.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/mcpgateway/validation/test_validators.py b/tests/unit/mcpgateway/validation/test_validators.py index 8e81fd39a..e2f930026 100644 --- a/tests/unit/mcpgateway/validation/test_validators.py +++ b/tests/unit/mcpgateway/validation/test_validators.py @@ -48,7 +48,7 @@ def logfn(*args, **kwargs): return logfn - monkeypatch.setattr("mcpgateway.validators.logger", DummyLogger()) + monkeypatch.setattr("mcpgateway.common.validators.logger", DummyLogger()) yield logs diff --git a/tests/unit/mcpgateway/validation/test_validators_advanced.py b/tests/unit/mcpgateway/validation/test_validators_advanced.py index 6645f522d..e29830c45 100644 --- a/tests/unit/mcpgateway/validation/test_validators_advanced.py +++ b/tests/unit/mcpgateway/validation/test_validators_advanced.py @@ -84,7 +84,7 @@ def logfn(*args, **kwargs): return logfn - monkeypatch.setattr("mcpgateway.validators.logger", DummyLogger()) + monkeypatch.setattr("mcpgateway.common.validators.logger", DummyLogger()) yield logs From ce6aee6aae2e383e9e95b864c8bc3561a21add68 Mon Sep 17 00:00:00 2001 From: Frederico Araujo Date: Thu, 6 Nov 2025 12:06:51 -0500 Subject: [PATCH 18/20] chore: fix flake8 issues Signed-off-by: Frederico Araujo --- .../plugins/framework/external/mcp/client.py | 3 +++ .../framework/external/mcp/server/runtime.py | 15 +++++++++------ mcpgateway/plugins/framework/manager.py | 3 ++- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/mcpgateway/plugins/framework/external/mcp/client.py b/mcpgateway/plugins/framework/external/mcp/client.py index b334f0521..465a0b81a 100644 --- a/mcpgateway/plugins/framework/external/mcp/client.py +++ b/mcpgateway/plugins/framework/external/mcp/client.py @@ -325,6 +325,9 @@ def __init__(self, hook: str, plugin_ref: PluginRef): # pylint: disable=super-i Args: hook: name of the hook point. plugin_ref: The reference to the plugin to hook. + + Raises: + PluginError: If the plugin is not an external plugin. """ self._plugin_ref = plugin_ref self._hook = hook diff --git a/mcpgateway/plugins/framework/external/mcp/server/runtime.py b/mcpgateway/plugins/framework/external/mcp/server/runtime.py index b4e57a39e..fcf1e6507 100755 --- a/mcpgateway/plugins/framework/external/mcp/server/runtime.py +++ b/mcpgateway/plugins/framework/external/mcp/server/runtime.py @@ -51,6 +51,9 @@ async def get_plugin_configs() -> list[dict]: Returns: JSON string containing list of plugin configuration dictionaries. + + Raises: + RuntimeError: If plugin server not initialized. """ if not SERVER: raise RuntimeError("Plugin server not initialized") @@ -65,6 +68,9 @@ async def get_plugin_config(name: str) -> dict: Returns: JSON string containing plugin configuration dictionary. + + Raises: + RuntimeError: If plugin server not initialized. """ if not SERVER: raise RuntimeError("Plugin server not initialized") @@ -85,6 +91,9 @@ async def invoke_hook(hook_type: str, plugin_name: str, payload: Dict[str, Any], Returns: Result dictionary with payload, context and any error information. + + Raises: + RuntimeError: If plugin server not initialized. """ if not SERVER: raise RuntimeError("Plugin server not initialized") @@ -165,9 +174,6 @@ async def _start_health_check_server(self, health_port: int) -> None: async def health_check(_request: Request): """Health check endpoint for container orchestration. - Args: - request: the http request from which the health check occurs. - Returns: JSON response with health status. """ @@ -199,9 +205,6 @@ async def run_streamable_http_async(self) -> None: async def health_check(_request: Request): """Health check endpoint for container orchestration. - Args: - request: the http request from which the health check occurs. - Returns: JSON response with health status. """ diff --git a/mcpgateway/plugins/framework/manager.py b/mcpgateway/plugins/framework/manager.py index d8eddb3fe..546ad9838 100644 --- a/mcpgateway/plugins/framework/manager.py +++ b/mcpgateway/plugins/framework/manager.py @@ -305,7 +305,6 @@ async def _execute_with_timeout(self, hook_ref: HookRef, payload: PluginPayload, Args: hook_ref: Reference to the hook and plugin to execute. - plugin_run: Function to execute the plugin. payload: Payload to process. context: Plugin execution context. @@ -529,6 +528,7 @@ async def invoke_hook( """Invoke a set of plugins configured for the hook point in priority order. Args: + hook_type: The type of hook to execute. payload: The plugin payload for which the plugins will analyze and modify. global_context: Shared context for all plugins with request metadata. local_contexts: Optional existing contexts from previous hook executions. @@ -586,6 +586,7 @@ async def invoke_hook_for_plugin( Raises: PluginError: If the plugin or hook type cannot be found in the registry. + ValueError: If payload type does not match payload_as_json setting. Examples: >>> manager = PluginManager("plugins/config.yaml") From 941a88253ec58fb5153c76692597693a3e36124d Mon Sep 17 00:00:00 2001 From: Mihai Criveti Date: Sat, 8 Nov 2025 20:23:30 +0000 Subject: [PATCH 19/20] fix: correct imports for plugin framework hooks - Fix import path from mcpgateway.plugins.mcp.entities to correct location - Use ToolHookType from mcpgateway.plugins.framework.hooks.tools - Import HttpHeaderPayload from mcpgateway.plugins.framework.hooks.http - Update HookType references to ToolHookType --- .env.example | 29 + .gitignore | 1 + MANIFEST.in | 1 + README.md | 75 + __init__ | 0 charts/mcp-stack/values.schema.json | 2 +- charts/mcp-stack/values.yaml | 19 +- docs/docs/manage/configuration.md | 35 +- docs/docs/manage/observability.md | 28 +- docs/docs/manage/observability/.pages | 1 + .../observability/internal-observability.md | 823 ++++++++ mcpgateway/admin.py | 1708 ++++++++++++++++- .../a23a08d61eb0_add_observability_tables.py | 147 ++ ...8_add_observability_performance_indexes.py | 69 + ...6f7g8h9_add_observability_saved_queries.py | 54 + mcpgateway/config.py | 28 + mcpgateway/db.py | 320 ++- mcpgateway/instrumentation/__init__.py | 19 + mcpgateway/instrumentation/sqlalchemy.py | 317 +++ mcpgateway/main.py | 30 + mcpgateway/middleware/auth_middleware.py | 105 + .../middleware/observability_middleware.py | 209 ++ mcpgateway/plugins/framework/manager.py | 55 +- mcpgateway/plugins/framework/models.py | 3 +- mcpgateway/routers/observability.py | 526 +++++ mcpgateway/schemas.py | 174 ++ mcpgateway/services/observability_service.py | 1396 ++++++++++++++ mcpgateway/services/prompt_service.py | 43 +- mcpgateway/services/resource_service.py | 44 +- mcpgateway/services/tool_service.py | 521 ++--- mcpgateway/static/flame-graph.css | 213 ++ mcpgateway/static/flame-graph.js | 340 ++++ mcpgateway/static/gantt-chart.css | 273 +++ mcpgateway/static/gantt-chart.js | 388 ++++ mcpgateway/templates/admin.html | 31 + .../templates/observability_metrics.html | 612 ++++++ .../templates/observability_partial.html | 1101 +++++++++++ .../templates/observability_prompts.html | 502 +++++ .../templates/observability_resources.html | 502 +++++ mcpgateway/templates/observability_stats.html | 19 + mcpgateway/templates/observability_tools.html | 591 ++++++ .../templates/observability_trace_detail.html | 207 ++ .../templates/observability_traces_list.html | 39 + scripts/cleanup-dev.sh | 27 + .../db/test_observability_migrations.py | 374 ++++ 45 files changed, 11727 insertions(+), 274 deletions(-) delete mode 100644 __init__ create mode 100644 docs/docs/manage/observability/internal-observability.md create mode 100644 mcpgateway/alembic/versions/a23a08d61eb0_add_observability_tables.py create mode 100644 mcpgateway/alembic/versions/i3c4d5e6f7g8_add_observability_performance_indexes.py create mode 100644 mcpgateway/alembic/versions/j4d5e6f7g8h9_add_observability_saved_queries.py create mode 100644 mcpgateway/instrumentation/__init__.py create mode 100644 mcpgateway/instrumentation/sqlalchemy.py create mode 100644 mcpgateway/middleware/auth_middleware.py create mode 100644 mcpgateway/middleware/observability_middleware.py create mode 100644 mcpgateway/routers/observability.py create mode 100644 mcpgateway/services/observability_service.py create mode 100644 mcpgateway/static/flame-graph.css create mode 100644 mcpgateway/static/flame-graph.js create mode 100644 mcpgateway/static/gantt-chart.css create mode 100644 mcpgateway/static/gantt-chart.js create mode 100644 mcpgateway/templates/observability_metrics.html create mode 100644 mcpgateway/templates/observability_partial.html create mode 100644 mcpgateway/templates/observability_prompts.html create mode 100644 mcpgateway/templates/observability_resources.html create mode 100644 mcpgateway/templates/observability_stats.html create mode 100644 mcpgateway/templates/observability_tools.html create mode 100644 mcpgateway/templates/observability_trace_detail.html create mode 100644 mcpgateway/templates/observability_traces_list.html create mode 100755 scripts/cleanup-dev.sh create mode 100644 tests/unit/mcpgateway/db/test_observability_migrations.py diff --git a/.env.example b/.env.example index a3985edcb..52e79b3c8 100644 --- a/.env.example +++ b/.env.example @@ -1150,3 +1150,32 @@ PAGINATION_INCLUDE_LINKS=true # Enable TLS for gRPC connections by default # MCPGATEWAY_GRPC_TLS_ENABLED=false + +##################################### +# Observability Settings +##################################### + +# Enable observability tracing and metrics collection +# When enabled, all HTTP requests will be traced with detailed timing, status codes, and context +# OBSERVABILITY_ENABLED=false + +# Automatically trace HTTP requests +# OBSERVABILITY_TRACE_HTTP_REQUESTS=true + +# Number of days to retain trace data +# OBSERVABILITY_TRACE_RETENTION_DAYS=7 + +# Maximum number of traces to retain (prevents unbounded growth) +# OBSERVABILITY_MAX_TRACES=100000 + +# Trace sampling rate (0.0-1.0) - 1.0 means trace everything, 0.1 means trace 10% +# OBSERVABILITY_SAMPLE_RATE=1.0 + +# Paths to exclude from tracing (comma-separated regex patterns) +# OBSERVABILITY_EXCLUDE_PATHS=/health,/healthz,/ready,/metrics,/static/.* + +# Enable metrics collection +# OBSERVABILITY_METRICS_ENABLED=true + +# Enable event logging within spans +# OBSERVABILITY_EVENTS_ENABLED=true diff --git a/.gitignore b/.gitignore index 19649b92f..8b4825775 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +*FEATURES.md spec/ stats/ .env.bak diff --git a/MANIFEST.in b/MANIFEST.in index b42cef368..c56fa211f 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -73,6 +73,7 @@ include mcpgateway/services/import_service.py # πŸ“¦ Migration scripts (v0.7.0 multitenancy migration tools) recursive-include scripts *.py +recursive-include scripts *.sh # πŸ§ͺ Testing documentation and plans recursive-include tests/manual *.py *.md diff --git a/README.md b/README.md index de5a02e9c..f4af57839 100644 --- a/README.md +++ b/README.md @@ -1670,6 +1670,81 @@ mcpgateway > > πŸ“Š **View Traces**: Phoenix UI at `http://localhost:6006`, Jaeger at `http://localhost:16686`, or your configured backend +### Internal Observability & Tracing + +The gateway includes built-in observability features for tracking HTTP requests, spans, and traces independent of OpenTelemetry. This provides database-backed trace storage and analysis directly in the Admin UI. + +| Setting | Description | Default | Options | +| ------------------------------------ | ----------------------------------------------------- | ---------------------------------------------------- | ---------------- | +| `OBSERVABILITY_ENABLED` | Enable internal observability tracing and metrics | `false` | bool | +| `OBSERVABILITY_TRACE_HTTP_REQUESTS` | Automatically trace HTTP requests | `true` | bool | +| `OBSERVABILITY_TRACE_RETENTION_DAYS` | Number of days to retain trace data | `7` | int (β‰₯ 1) | +| `OBSERVABILITY_MAX_TRACES` | Maximum number of traces to retain | `100000` | int (β‰₯ 1000) | +| `OBSERVABILITY_SAMPLE_RATE` | Trace sampling rate (0.0-1.0) | `1.0` | float (0.0-1.0) | +| `OBSERVABILITY_EXCLUDE_PATHS` | Paths to exclude from tracing (regex patterns) | `/health,/healthz,/ready,/metrics,/static/.*` | comma-separated | +| `OBSERVABILITY_METRICS_ENABLED` | Enable metrics collection | `true` | bool | +| `OBSERVABILITY_EVENTS_ENABLED` | Enable event logging within spans | `true` | bool | + +**Key Features:** +- πŸ“Š **Database-backed storage**: Traces stored in SQLite/PostgreSQL for persistence +- πŸ” **Admin UI integration**: View traces, spans, and metrics in the diagnostics tab +- 🎯 **Sampling control**: Configure sampling rate to reduce overhead in high-traffic scenarios +- πŸ• **Automatic cleanup**: Old traces automatically purged based on retention settings +- 🚫 **Path filtering**: Exclude health checks and static resources from tracing + +**Configuration Effects:** +- `OBSERVABILITY_ENABLED=false`: Completely disables internal observability (no database writes, zero overhead) +- `OBSERVABILITY_SAMPLE_RATE=0.1`: Traces 10% of requests (useful for high-volume production) +- `OBSERVABILITY_EXCLUDE_PATHS=/health,/metrics`: Prevents noisy endpoints from creating traces + +> πŸ“ **Note**: This is separate from OpenTelemetry. You can use both systems simultaneously - internal observability for Admin UI visibility and OpenTelemetry for external systems like Phoenix/Jaeger. +> +> πŸŽ›οΈ **Admin UI Access**: When enabled, traces appear in **Admin β†’ Diagnostics β†’ Observability** tab with filtering, search, and export capabilities + +### Prometheus Metrics + +The gateway exposes Prometheus-compatible metrics at `/metrics/prometheus` for monitoring and alerting. + +| Setting | Description | Default | Options | +| ---------------------------- | -------------------------------------------------------- | --------- | ---------------- | +| `ENABLE_METRICS` | Enable Prometheus metrics instrumentation | `true` | bool | +| `METRICS_EXCLUDED_HANDLERS` | Regex patterns for paths to exclude from metrics | (empty) | comma-separated | +| `METRICS_NAMESPACE` | Prometheus metrics namespace (prefix) | `default` | string | +| `METRICS_SUBSYSTEM` | Prometheus metrics subsystem (secondary prefix) | (empty) | string | +| `METRICS_CUSTOM_LABELS` | Static custom labels for app_info gauge | (empty) | `key=value,...` | + +**Key Features:** +- πŸ“Š **Standard metrics**: HTTP request duration, response codes, active requests +- 🏷️ **Custom labels**: Add static labels (environment, region, team) for filtering in Prometheus/Grafana +- 🚫 **Path exclusions**: Prevent high-cardinality issues by excluding dynamic paths +- πŸ“ˆ **Namespace isolation**: Group metrics by application or organization + +**Configuration Examples:** + +```bash +# Production deployment with custom labels +ENABLE_METRICS=true +METRICS_NAMESPACE=mycompany +METRICS_SUBSYSTEM=gateway +METRICS_CUSTOM_LABELS=environment=production,region=us-east-1,team=platform + +# Exclude high-volume endpoints from metrics +METRICS_EXCLUDED_HANDLERS=/servers/.*/sse,/static/.*,.*health.* + +# Disable metrics for development +ENABLE_METRICS=false +``` + +**Metric Names:** +- With namespace + subsystem: `mycompany_gateway_http_requests_total` +- Default (no namespace/subsystem): `default_http_requests_total` + +> ⚠️ **High-Cardinality Warning**: Never use high-cardinality values (user IDs, request IDs, timestamps) in `METRICS_CUSTOM_LABELS`. Only use low-cardinality static values (environment, region, cluster). +> +> πŸ“Š **Prometheus Endpoint**: Access metrics at `GET /metrics/prometheus` (requires authentication if `AUTH_REQUIRED=true`) +> +> 🎯 **Grafana Integration**: Import metrics into Grafana dashboards using the configured namespace as a filter + ### Transport | Setting | Description | Default | Options | diff --git a/__init__ b/__init__ deleted file mode 100644 index e69de29bb..000000000 diff --git a/charts/mcp-stack/values.schema.json b/charts/mcp-stack/values.schema.json index d986cd186..297c0a1b9 100644 --- a/charts/mcp-stack/values.schema.json +++ b/charts/mcp-stack/values.schema.json @@ -251,7 +251,7 @@ }, "additionalProperties": false }, - + "pluginConfig": { "type": "object", "description": "Plugin configuration via ConfigMap", diff --git a/charts/mcp-stack/values.yaml b/charts/mcp-stack/values.yaml index 958c7507e..c3eea4f9e 100644 --- a/charts/mcp-stack/values.yaml +++ b/charts/mcp-stack/values.yaml @@ -16,7 +16,7 @@ mcpContextForge: enabled: false plugins: | # plugin file - + replicaCount: 2 # horizontal scaling for the gateway # --- HORIZONTAL POD AUTOSCALER -------------------------------------- @@ -317,6 +317,23 @@ mcpContextForge: OTEL_BSP_MAX_EXPORT_BATCH_SIZE: "512" # max export batch size OTEL_BSP_SCHEDULE_DELAY: "5000" # schedule delay in milliseconds + # ─ Internal Observability & Tracing ─ + OBSERVABILITY_ENABLED: "false" # enable internal observability tracing and metrics + OBSERVABILITY_TRACE_HTTP_REQUESTS: "true" # automatically trace HTTP requests + OBSERVABILITY_TRACE_RETENTION_DAYS: "7" # number of days to retain trace data + OBSERVABILITY_MAX_TRACES: "100000" # maximum number of traces to retain + OBSERVABILITY_SAMPLE_RATE: "1.0" # trace sampling rate (0.0-1.0, 1.0 = trace everything) + OBSERVABILITY_EXCLUDE_PATHS: "/health,/healthz,/ready,/metrics,/static/.*" # paths to exclude from tracing + OBSERVABILITY_METRICS_ENABLED: "true" # enable metrics collection + OBSERVABILITY_EVENTS_ENABLED: "true" # enable event logging within spans + + # ─ Prometheus Metrics ─ + ENABLE_METRICS: "true" # enable Prometheus metrics instrumentation + METRICS_EXCLUDED_HANDLERS: "" # regex patterns for paths to exclude from metrics (comma-separated) + METRICS_NAMESPACE: "default" # Prometheus metrics namespace (prefix for all metric names) + METRICS_SUBSYSTEM: "" # Prometheus metrics subsystem (secondary prefix) + METRICS_CUSTOM_LABELS: "" # static custom labels for app_info gauge (key=value,key2=value2) + # ─ Header Passthrough (Security Warning) ─ ENABLE_HEADER_PASSTHROUGH: "false" # enable HTTP header passthrough (security implications) ENABLE_OVERWRITE_BASE_HEADERS: "false" # enable overwriting of base headers (advanced usage) diff --git a/docs/docs/manage/configuration.md b/docs/docs/manage/configuration.md index 326ecc645..d966b6897 100644 --- a/docs/docs/manage/configuration.md +++ b/docs/docs/manage/configuration.md @@ -563,7 +563,7 @@ REQUIRE_TOKEN_EXPIRATION=true TOKEN_EXPIRY=60 ``` -### Observability Integration +### OpenTelemetry Observability ```bash # OpenTelemetry (Phoenix, Jaeger, etc.) @@ -574,6 +574,39 @@ OTEL_EXPORTER_OTLP_PROTOCOL=grpc OTEL_SERVICE_NAME=mcp-gateway ``` +### Internal Observability System + +MCP Gateway includes a built-in observability system that stores traces and metrics in the database, providing performance analytics and error tracking through the Admin UI. + +```bash +# Enable internal observability (database-backed tracing) +OBSERVABILITY_ENABLED=false + +# Automatically trace HTTP requests +OBSERVABILITY_TRACE_HTTP_REQUESTS=true + +# Trace retention (days) +OBSERVABILITY_TRACE_RETENTION_DAYS=7 + +# Maximum traces to retain (prevents unbounded growth) +OBSERVABILITY_MAX_TRACES=100000 + +# Trace sampling rate (0.0-1.0) +# 1.0 = trace everything, 0.1 = trace 10% of requests +OBSERVABILITY_SAMPLE_RATE=1.0 + +# Paths to exclude from tracing (comma-separated regex patterns) +OBSERVABILITY_EXCLUDE_PATHS=/health,/healthz,/ready,/metrics,/static/.* + +# Enable metrics collection +OBSERVABILITY_METRICS_ENABLED=true + +# Enable event logging within spans +OBSERVABILITY_EVENTS_ENABLED=true +``` + +See the [Internal Observability Guide](observability/internal-observability.md) for detailed usage instructions including Admin UI dashboards, performance metrics, and trace analysis. + --- ## πŸ“š Related Documentation diff --git a/docs/docs/manage/observability.md b/docs/docs/manage/observability.md index c069eb84f..d59e33e89 100644 --- a/docs/docs/manage/observability.md +++ b/docs/docs/manage/observability.md @@ -1,16 +1,34 @@ ## Observability -MCP Gateway includes production-grade OpenTelemetry instrumentation for distributed tracing and Prometheus-compatible metrics exposure. +MCP Gateway provides comprehensive observability through two complementary systems: + +1. **Internal Observability** - Built-in database-backed tracing with Admin UI dashboards +2. **OpenTelemetry** - Standard distributed tracing to external backends (Phoenix, Jaeger, Tempo) ## Documentation -- **[Observability Overview](observability/observability.md)** - Complete guide to configuring and using observability +- **[OpenTelemetry Overview](observability/observability.md)** - External observability with OTLP backends +- **[Internal Observability](observability/internal-observability.md)** - Built-in tracing, metrics, and Admin UI dashboards - **[Phoenix Integration](observability/phoenix.md)** - AI/LLM-focused observability with Arize Phoenix ## Quick Start +### Internal Observability (Built-in) + +```bash +# Enable internal observability +export OBSERVABILITY_ENABLED=true + +# Run MCP Gateway +mcpgateway + +# View dashboards at http://localhost:4444/admin/observability +``` + +### OpenTelemetry (External) + ```bash -# Enable observability (enabled by default) +# Enable OpenTelemetry (enabled by default) export OTEL_ENABLE_OBSERVABILITY=true export OTEL_TRACES_EXPORTER=otlp export OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4317 @@ -20,9 +38,9 @@ docker run -p 6006:6006 -p 4317:4317 arizephoenix/phoenix:latest # Run MCP Gateway mcpgateway -``` -View traces at http://localhost:6006 +# View traces at http://localhost:6006 +``` ## Prometheus metrics (important) diff --git a/docs/docs/manage/observability/.pages b/docs/docs/manage/observability/.pages index 1de9009d1..48660d012 100644 --- a/docs/docs/manage/observability/.pages +++ b/docs/docs/manage/observability/.pages @@ -1,3 +1,4 @@ nav: - Overview: observability.md + - Internal Observability: internal-observability.md - Phoenix Integration: phoenix.md diff --git a/docs/docs/manage/observability/internal-observability.md b/docs/docs/manage/observability/internal-observability.md new file mode 100644 index 000000000..474faf025 --- /dev/null +++ b/docs/docs/manage/observability/internal-observability.md @@ -0,0 +1,823 @@ +# Internal Observability System + +MCP Gateway includes a built-in observability system that provides comprehensive performance monitoring, error tracking, and analytics without requiring external observability platforms. All trace data is stored in your database (SQLite/PostgreSQL/MariaDB) and visualized through the Admin UI. + +## Overview + +The internal observability system captures detailed performance metrics and traces for: + +- **Tools** - Invocation frequency, performance metrics, error rates +- **Prompts** - Rendering frequency, latency percentiles, error tracking +- **Resources** - Fetch frequency, performance metrics, error tracking +- **HTTP Requests** - Complete request/response tracing with timing + +Unlike OpenTelemetry (which sends traces to external systems like Phoenix or Jaeger), the internal observability system is self-contained, making it ideal for: + +- Development and testing environments +- Organizations that prefer self-hosted solutions +- Scenarios where external observability platforms are not available +- Quick performance analysis without additional infrastructure + +## Key Features + +### Performance Analytics + +- **Latency Percentiles**: p50, p90, p95, p99 metrics for detailed performance analysis +- **Duration Tracking**: Millisecond-precision timing for all operations +- **Throughput Metrics**: Request counts and rates over time +- **Comparative Analysis**: Side-by-side comparison of multiple resources + +### Error Tracking + +- **Error Rate Monitoring**: Percentage of failed operations with health indicators +- **Error-Prone Analysis**: Identify resources with highest failure rates +- **Status Code Tracking**: HTTP response codes and error messages +- **Root Cause Analysis**: Detailed traces with full context + +### Interactive Dashboards + +- **Summary Cards**: At-a-glance health status, most used, slowest, and most error-prone resources +- **Performance Charts**: Interactive visualizations using Chart.js +- **Time-Based Filtering**: Analyze performance over custom time ranges +- **Auto-Refresh**: Dashboards update every 60 seconds automatically + +### Trace Visualization + +- **Gantt Chart Timeline**: Visual representation of span execution order and timing +- **Flame Graphs**: Hierarchical view of nested operations +- **Trace Details**: Complete trace metadata, attributes, and context +- **Span Explorer**: Drill down into individual operations + +## Quick Start + +### 1. Enable Observability + +Add to your `.env` file: + +```bash +# Enable internal observability +OBSERVABILITY_ENABLED=true + +# Automatically trace HTTP requests +OBSERVABILITY_TRACE_HTTP_REQUESTS=true + +# Retention and limits +OBSERVABILITY_TRACE_RETENTION_DAYS=7 +OBSERVABILITY_MAX_TRACES=100000 + +# Trace sampling (1.0 = 100%, 0.1 = 10%) +OBSERVABILITY_SAMPLE_RATE=1.0 + +# Exclude paths (regex patterns) +OBSERVABILITY_EXCLUDE_PATHS=/health,/healthz,/ready,/metrics,/static/.* + +# Enable metrics and events +OBSERVABILITY_METRICS_ENABLED=true +OBSERVABILITY_EVENTS_ENABLED=true +``` + +### 2. Start MCP Gateway + +```bash +# With environment variables +export OBSERVABILITY_ENABLED=true +mcpgateway + +# Or start development server +make dev +``` + +### 3. Access Admin UI + +Navigate to the Observability section in the Admin UI: + +``` +http://localhost:4444/admin/observability +``` + +## Configuration Reference + +### Core Settings + +| Variable | Description | Default | Options | +|----------|-------------|---------|---------| +| `OBSERVABILITY_ENABLED` | Master switch for internal observability | `false` | `true`, `false` | +| `OBSERVABILITY_TRACE_HTTP_REQUESTS` | Auto-trace HTTP requests | `true` | `true`, `false` | + +### Retention & Limits + +| Variable | Description | Default | Range | +|----------|-------------|---------|-------| +| `OBSERVABILITY_TRACE_RETENTION_DAYS` | Days to retain trace data | `7` | 1-365 | +| `OBSERVABILITY_MAX_TRACES` | Maximum traces to store | `100000` | 1000+ | + +### Sampling & Filtering + +| Variable | Description | Default | Range | +|----------|-------------|---------|-------| +| `OBSERVABILITY_SAMPLE_RATE` | Trace sampling rate | `1.0` | 0.0-1.0 | +| `OBSERVABILITY_EXCLUDE_PATHS` | Regex patterns to exclude | `/health,/healthz,/ready,/metrics,/static/.*` | Comma-separated | + +### Feature Flags + +| Variable | Description | Default | Options | +|----------|-------------|---------|---------| +| `OBSERVABILITY_METRICS_ENABLED` | Enable metrics collection | `true` | `true`, `false` | +| `OBSERVABILITY_EVENTS_ENABLED` | Enable event logging | `true` | `true`, `false` | + +## Admin UI Dashboards + +### Tools Dashboard + +**Path**: `/admin/observability/tools` + +Provides comprehensive analytics for MCP tool invocations: + +#### Summary Cards + +- **Overall Health**: Success rate with color-coded status + - Green: <5% errors (healthy) + - Yellow: 5-20% errors (degraded) + - Red: >20% errors (unhealthy) +- **Most Used**: Top tool by invocation count +- **Slowest**: Tool with highest p99 latency +- **Most Error-Prone**: Tool with highest error rate + +#### Performance Charts + +1. **Tool Usage Chart**: Bar chart showing invocation counts +2. **Average Latency Chart**: Bar chart with millisecond precision +3. **Error Rate Chart**: Percentage visualization with color coding +4. **Top N Error-Prone Tools**: Focused view of problematic tools + +#### Detailed Metrics Table + +For each tool: + +- **Invocation Count**: Total number of calls +- **Latency Percentiles**: p50, p90, p95, p99 in milliseconds +- **Error Rate**: Percentage with color-coded status +- **Last Used**: Timestamp of most recent invocation + +#### Filtering Options + +- **Time Range**: Last 1 hour, 24 hours, 7 days, 30 days +- **Result Limit**: Top 10, 20, 50, or 100 tools +- **Auto-Refresh**: 60-second automatic updates + +### Prompts Dashboard + +**Path**: `/admin/observability/prompts` + +Analyzes MCP prompt rendering performance: + +#### Summary Cards + +- **Overall Health**: Rendering success rate +- **Most Used**: Most frequently rendered prompt +- **Slowest**: Prompt with highest p99 latency +- **Most Error-Prone**: Prompt with highest failure rate + +#### Performance Charts + +1. **Prompt Render Frequency**: Usage distribution +2. **Average Latency**: Rendering performance +3. **Error Rate**: Failure rate analysis +4. **Top N Error-Prone Prompts**: Problem identification + +#### Detailed Metrics + +- **Render Count**: Total rendering operations +- **Latency Percentiles**: p50, p90, p95, p99 metrics +- **Error Rate**: Failure percentage with status +- **Last Rendered**: Most recent usage timestamp + +### Resources Dashboard + +**Path**: `/admin/observability/resources` + +Monitors MCP resource fetch operations: + +#### Summary Cards + +- **Overall Health**: Fetch success rate +- **Most Used**: Most accessed resource +- **Slowest**: Resource with highest p99 latency +- **Most Error-Prone**: Resource with highest error rate + +#### Performance Charts + +1. **Resource Fetch Frequency**: Access patterns +2. **Average Latency**: Fetch performance +3. **Error Rate**: Failure analysis +4. **Top N Error-Prone Resources**: Issue detection + +#### Detailed Metrics + +- **Fetch Count**: Total access operations +- **Latency Percentiles**: p50, p90, p95, p99 metrics +- **Error Rate**: Failure rate with health status +- **Last Fetched**: Recent access timestamp + +## Trace Visualization + +### Trace List + +**Path**: `/admin/observability/traces` + +Browse all captured traces with: + +- **Trace ID**: Unique identifier +- **Operation Name**: Human-readable description +- **Start Time**: When the trace began +- **Duration**: Total execution time +- **Status**: Success/error indicator +- **HTTP Details**: Method, URL, status code + +### Trace Detail View + +**Path**: `/admin/observability/traces/{trace_id}` + +Comprehensive trace analysis: + +#### Trace Metadata + +- **Trace ID**: Unique identifier (W3C format) +- **Name**: Operation description +- **Status**: Overall outcome +- **Duration**: Total execution time +- **HTTP Context**: Method, URL, status code, user agent +- **User Context**: Email, IP address +- **Timestamps**: Start, end, created times + +#### Gantt Chart Timeline + +Visual representation showing: + +- **Span Execution Order**: Chronological flow +- **Nested Operations**: Parent-child relationships +- **Duration Bars**: Relative timing visualization +- **Overlap Detection**: Concurrent operations + +#### Flame Graph + +Hierarchical view displaying: + +- **Call Stack**: Nested span relationships +- **Time Distribution**: Width represents duration +- **Critical Path**: Longest execution chains +- **Bottleneck Identification**: Performance hotspots + +#### Spans Table + +Detailed span information: + +- **Span ID**: Unique identifier +- **Name**: Operation description +- **Kind**: Span type (internal, server, client) +- **Start/End Time**: Execution window +- **Duration**: Millisecond precision +- **Status**: Success/error indicator +- **Resource Info**: Type, name, ID +- **Attributes**: Custom metadata + +## Performance Metrics + +### Latency Percentiles + +The system calculates accurate percentiles using database aggregation: + +- **p50 (Median)**: 50% of requests complete faster +- **p90**: 90% of requests complete faster +- **p95**: 95% of requests complete faster +- **p99**: 99% of requests complete faster + +These metrics help identify performance outliers and establish SLAs. + +### Health Status Indicators + +Color-coded status based on error rates: + +``` +Green (<5% errors) - Healthy +Yellow (5-20% errors) - Degraded +Red (>20% errors) - Unhealthy +``` + +### Metrics Calculation + +All metrics are calculated dynamically based on your selected time range: + +- Real-time aggregation from trace database +- No pre-computation or caching delays +- Accurate percentile calculations using SQLite/PostgreSQL functions +- Efficient indexing for fast queries + +## Data Retention + +### Automatic Cleanup + +Traces older than `OBSERVABILITY_TRACE_RETENTION_DAYS` are automatically deleted: + +```bash +# Retain traces for 7 days (default) +OBSERVABILITY_TRACE_RETENTION_DAYS=7 + +# Extend retention to 30 days +OBSERVABILITY_TRACE_RETENTION_DAYS=30 +``` + +### Size Limits + +Prevent unbounded growth with `OBSERVABILITY_MAX_TRACES`: + +```bash +# Store up to 100,000 traces (default) +OBSERVABILITY_MAX_TRACES=100000 + +# Increase for high-volume environments +OBSERVABILITY_MAX_TRACES=1000000 +``` + +When the limit is reached, oldest traces are deleted first. + +### Manual Cleanup + +Use the CLI for manual trace management: + +```bash +# Delete traces older than 7 days +mcpgateway observability cleanup --days 7 + +# Delete specific trace by ID +mcpgateway observability delete-trace + +# Clear all traces (use with caution!) +mcpgateway observability clear-all +``` + +## Sampling Strategies + +### Full Sampling (Development) + +Capture all requests for complete visibility: + +```bash +OBSERVABILITY_SAMPLE_RATE=1.0 # 100% sampling +``` + +### Partial Sampling (Production) + +Reduce overhead while maintaining visibility: + +```bash +# Sample 10% of requests +OBSERVABILITY_SAMPLE_RATE=0.1 + +# Sample 1% of requests (high volume) +OBSERVABILITY_SAMPLE_RATE=0.01 +``` + +### Path Exclusion + +Exclude noisy or irrelevant paths: + +```bash +# Default exclusions +OBSERVABILITY_EXCLUDE_PATHS=/health,/healthz,/ready,/metrics,/static/.* + +# Custom exclusions (regex patterns) +OBSERVABILITY_EXCLUDE_PATHS=/health.*,/metrics.*,/static/.*,/admin/assets/.* +``` + +## Database Schema + +### ObservabilityTrace Table + +Stores complete request traces: + +```sql +CREATE TABLE observability_traces ( + trace_id VARCHAR(36) PRIMARY KEY, + name VARCHAR(255) NOT NULL, + start_time TIMESTAMP NOT NULL, + end_time TIMESTAMP, + duration_ms FLOAT, + status VARCHAR(20) DEFAULT 'unset', + status_message TEXT, + http_method VARCHAR(10), + http_url VARCHAR(767), + http_status_code INTEGER, + user_email VARCHAR(255), + user_agent TEXT, + ip_address VARCHAR(45), + attributes JSON, + resource_attributes JSON, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); +``` + +### ObservabilitySpan Table + +Stores individual operations within traces: + +```sql +CREATE TABLE observability_spans ( + span_id VARCHAR(36) PRIMARY KEY, + trace_id VARCHAR(36) NOT NULL, + parent_span_id VARCHAR(36), + name VARCHAR(255) NOT NULL, + kind VARCHAR(20) DEFAULT 'internal', + start_time TIMESTAMP NOT NULL, + end_time TIMESTAMP, + duration_ms FLOAT, + status VARCHAR(20) DEFAULT 'unset', + status_message TEXT, + attributes JSON, + resource_name VARCHAR(255), + resource_type VARCHAR(50), + resource_id VARCHAR(36), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (trace_id) REFERENCES observability_traces(trace_id), + FOREIGN KEY (parent_span_id) REFERENCES observability_spans(span_id) +); +``` + +### Performance Indexes + +Optimized for fast queries: + +```sql +-- Trace indexes +CREATE INDEX idx_observability_traces_start_time ON observability_traces(start_time); +CREATE INDEX idx_observability_traces_user_email ON observability_traces(user_email); +CREATE INDEX idx_observability_traces_status ON observability_traces(status); +CREATE INDEX idx_observability_traces_http_status_code ON observability_traces(http_status_code); + +-- Span indexes +CREATE INDEX idx_observability_spans_trace_id ON observability_spans(trace_id); +CREATE INDEX idx_observability_spans_parent_span_id ON observability_spans(parent_span_id); +CREATE INDEX idx_observability_spans_start_time ON observability_spans(start_time); +CREATE INDEX idx_observability_spans_resource_type ON observability_spans(resource_type); +CREATE INDEX idx_observability_spans_resource_name ON observability_spans(resource_name); +``` + +## REST API + +### List Traces + +```bash +GET /observability/traces +``` + +Query parameters: + +- `start_time`: Filter traces after this timestamp (ISO 8601) +- `end_time`: Filter traces before this timestamp +- `min_duration_ms`: Minimum duration in milliseconds +- `max_duration_ms`: Maximum duration in milliseconds +- `status`: Filter by status (`ok`, `error`) +- `http_status_code`: Filter by HTTP status code +- `http_method`: Filter by HTTP method +- `user_email`: Filter by user email +- `limit`: Maximum results (default: 100) +- `offset`: Result offset (default: 0) + +Example: + +```bash +curl -H "Authorization: Bearer $TOKEN" \ + "http://localhost:4444/observability/traces?limit=10&status=error" +``` + +### Get Trace Details + +```bash +GET /observability/traces/{trace_id} +``` + +Returns complete trace with all spans, events, and metrics. + +Example: + +```bash +curl -H "Authorization: Bearer $TOKEN" \ + "http://localhost:4444/observability/traces/550e8400-e29b-41d4-a716-446655440000" +``` + +### Query Tool Metrics + +```bash +GET /observability/tools/metrics +``` + +Query parameters: + +- `time_range`: Time window (`1h`, `24h`, `7d`, `30d`) +- `limit`: Number of tools to return + +Example: + +```bash +curl -H "Authorization: Bearer $TOKEN" \ + "http://localhost:4444/observability/tools/metrics?time_range=24h&limit=20" +``` + +### Query Prompt Metrics + +```bash +GET /observability/prompts/metrics +``` + +Same parameters as tool metrics. + +### Query Resource Metrics + +```bash +GET /observability/resources/metrics +``` + +Same parameters as tool metrics. + +## Comparison: Internal vs OpenTelemetry + +| Feature | Internal Observability | OpenTelemetry | +|---------|----------------------|---------------| +| **Storage** | Database (SQLite/PostgreSQL/MariaDB) | External backends (Phoenix, Jaeger, Tempo) | +| **Setup** | Built-in, zero configuration | Requires external services | +| **Cost** | Free, self-hosted | Depends on backend (free OSS or paid SaaS) | +| **Retention** | Configurable in-database | Backend-dependent | +| **UI** | Admin UI dashboards | Backend-specific UIs | +| **Performance Impact** | Minimal (database writes) | Minimal (async exports) | +| **Use Cases** | Development, testing, small deployments | Production, microservices, distributed systems | +| **Standards** | Custom implementation | OpenTelemetry standard | +| **Integration** | Self-contained | Integrates with APM ecosystem | + +### When to Use Each + +**Use Internal Observability when:** + +- You want zero external dependencies +- Database storage is acceptable +- Admin UI visualization is sufficient +- Deployment simplicity is a priority +- You're in development/testing mode + +**Use OpenTelemetry when:** + +- You need distributed tracing across multiple services +- You want vendor-agnostic standard +- You have existing observability infrastructure +- You need advanced APM features +- You're in production with high scale + +**Use Both when:** + +- You want local debugging with external production monitoring +- You need different retention policies +- You want redundancy in observability data + +## Production Considerations + +### Performance Impact + +The internal observability system is designed for minimal overhead: + +- **Database Writes**: Async, batched when possible +- **Indexing**: Optimized indexes for fast queries +- **Sampling**: Reduce load with configurable sample rates +- **Cleanup**: Automatic retention management + +### Scaling Recommendations + +For high-volume deployments: + +```bash +# Reduce sampling rate +OBSERVABILITY_SAMPLE_RATE=0.1 # 10% sampling + +# Aggressive retention +OBSERVABILITY_TRACE_RETENTION_DAYS=3 + +# Exclude high-frequency paths +OBSERVABILITY_EXCLUDE_PATHS=/health.*,/metrics.*,/static/.* + +# Disable HTTP request tracing (manual traces only) +OBSERVABILITY_TRACE_HTTP_REQUESTS=false +``` + +### Database Considerations + +#### SQLite + +Suitable for: + +- Development and testing +- Single-instance deployments +- Low to medium traffic + +Limitations: + +- Write concurrency limits +- File-based storage + +#### PostgreSQL + +Recommended for: + +- Production deployments +- High-volume environments +- Multi-instance setups + +Benefits: + +- Superior write concurrency +- Advanced indexing +- Better query performance + +#### MariaDB/MySQL + +Alternative production option: + +- Good write performance +- Wide deployment support +- Compatible with PostgreSQL features + +### Monitoring the Monitor + +Track observability system health: + +```bash +# Check trace count +SELECT COUNT(*) FROM observability_traces; + +# Check database size +SELECT pg_size_pretty(pg_total_relation_size('observability_traces')); +SELECT pg_size_pretty(pg_total_relation_size('observability_spans')); + +# Check oldest trace +SELECT MIN(start_time) FROM observability_traces; + +# Check cleanup effectiveness +SELECT COUNT(*) FROM observability_traces +WHERE start_time < NOW() - INTERVAL '7 days'; +``` + +## Troubleshooting + +### No Traces Appearing + +1. **Verify observability is enabled**: + + ```bash + echo $OBSERVABILITY_ENABLED # Should be "true" + ``` + +2. **Check sampling rate**: + + ```bash + echo $OBSERVABILITY_SAMPLE_RATE # Should be > 0.0 + ``` + +3. **Review excluded paths**: + + ```bash + echo $OBSERVABILITY_EXCLUDE_PATHS + # Ensure your test path is not excluded + ``` + +4. **Check database connection**: + + ```bash + # Verify database is accessible + mcpgateway db-check + ``` + +5. **Enable debug logging**: + + ```bash + export LOG_LEVEL=DEBUG + mcpgateway + # Look for observability-related log messages + ``` + +### High Database Size + +1. **Reduce retention period**: + + ```bash + OBSERVABILITY_TRACE_RETENTION_DAYS=3 + ``` + +2. **Lower maximum traces**: + + ```bash + OBSERVABILITY_MAX_TRACES=10000 + ``` + +3. **Increase sampling threshold**: + + ```bash + OBSERVABILITY_SAMPLE_RATE=0.1 + ``` + +4. **Manually cleanup**: + + ```bash + mcpgateway observability cleanup --days 1 + ``` + +### Slow Dashboard Loading + +1. **Reduce query time range**: + + - Use shorter time windows (1 hour instead of 30 days) + +2. **Limit result count**: + + - Query top 10 instead of top 100 + +3. **Add database indexes** (if custom deployment): + + ```sql + CREATE INDEX idx_custom ON observability_spans(resource_type, start_time); + ``` + +4. **Optimize database**: + + ```bash + # PostgreSQL + VACUUM ANALYZE observability_traces; + VACUUM ANALYZE observability_spans; + + # SQLite + VACUUM; + ``` + +### Missing Spans or Metrics + +1. **Check span creation**: + + - Verify tool/prompt/resource operations are completing + - Look for errors in application logs + +2. **Verify metrics enabled**: + + ```bash + echo $OBSERVABILITY_METRICS_ENABLED # Should be "true" + ``` + +3. **Check events enabled**: + + ```bash + echo $OBSERVABILITY_EVENTS_ENABLED # Should be "true" + ``` + +## Best Practices + +### Development + +```bash +# Full tracing, short retention +OBSERVABILITY_ENABLED=true +OBSERVABILITY_SAMPLE_RATE=1.0 +OBSERVABILITY_TRACE_RETENTION_DAYS=1 +OBSERVABILITY_MAX_TRACES=10000 +``` + +### Staging + +```bash +# Partial tracing, moderate retention +OBSERVABILITY_ENABLED=true +OBSERVABILITY_SAMPLE_RATE=0.5 +OBSERVABILITY_TRACE_RETENTION_DAYS=7 +OBSERVABILITY_MAX_TRACES=100000 +``` + +### Production + +```bash +# Sampled tracing, longer retention +OBSERVABILITY_ENABLED=true +OBSERVABILITY_SAMPLE_RATE=0.1 +OBSERVABILITY_TRACE_RETENTION_DAYS=14 +OBSERVABILITY_MAX_TRACES=1000000 +OBSERVABILITY_EXCLUDE_PATHS=/health.*,/metrics.* +``` + +## Next Steps + +- Review [Configuration Reference](../configuration.md) for all observability settings +- Explore [OpenTelemetry Integration](observability.md) for external monitoring +- Set up [Phoenix Integration](phoenix.md) for AI-specific observability +- Configure [Prometheus Metrics](../observability.md#prometheus-metrics-important) for time-series monitoring +- Implement [Custom Dashboards](#admin-ui-dashboards) based on your metrics + +## Related Documentation + +- [Configuration Reference](../configuration.md) - Environment variable configuration +- [OpenTelemetry Observability](observability.md) - External tracing backends +- [Phoenix Integration](phoenix.md) - AI/LLM observability +- [Admin UI Documentation](../ui-customization.md) - Customizing the Admin UI +- [Database Configuration](../configuration.md#database-configuration) - Database setup and tuning diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index b24bc8bac..d02fd920c 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -31,28 +31,31 @@ from pathlib import Path import tempfile import time -from typing import Any, cast, Dict, List, Optional, Union +from typing import Any +from typing import cast as typing_cast +from typing import Dict, List, Optional, Union import urllib.parse import uuid # Third-Party -from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response +from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request, Response from fastapi.encoders import jsonable_encoder from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse, StreamingResponse import httpx from pydantic import SecretStr, ValidationError from pydantic_core import ValidationError as CoreValidationError -from sqlalchemy import and_, case, desc, func, or_, select +from sqlalchemy import and_, case, cast, desc, func, or_, select, String from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session +from sqlalchemy.orm import joinedload, Session from sqlalchemy.sql.functions import coalesce from starlette.datastructures import UploadFile as StarletteUploadFile # First-Party from mcpgateway.common.models import LogLevel from mcpgateway.config import settings -from mcpgateway.db import get_db, GlobalConfig +from mcpgateway.db import get_db, GlobalConfig, ObservabilitySavedQuery, ObservabilitySpan, ObservabilityTrace from mcpgateway.db import Tool as DbTool +from mcpgateway.db import utc_now from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission from mcpgateway.schemas import ( A2AAgentCreate, @@ -1061,7 +1064,7 @@ async def admin_add_server(request: Request, db: Session = Depends(get_db), user creation_metadata = MetadataCapture.extract_creation_metadata(request, user) # Ensure default visibility is private and assign to personal team when available - team_id_cast = cast(Optional[str], team_id) + team_id_cast = typing_cast(Optional[str], team_id) await server_service.register_server( db, server, @@ -2456,6 +2459,7 @@ def _to_dict_and_filter(raw_list): "grpc_enabled": GRPC_AVAILABLE and settings.mcpgateway_grpc_enabled, "catalog_enabled": settings.mcpgateway_catalog_enabled, "llmchat_enabled": getattr(settings, "llmchat_enabled", False), + "observability_enabled": getattr(settings, "observability_enabled", False), "current_user": get_user_email(user), "email_auth_enabled": getattr(settings, "email_auth_enabled", False), "is_admin": bool(user.get("is_admin") if isinstance(user, dict) else False), @@ -6307,7 +6311,7 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use # Extract creation metadata metadata = MetadataCapture.extract_creation_metadata(request, user) - team_id_cast = cast(Optional[str], team_id) + team_id_cast = typing_cast(Optional[str], team_id) await gateway_service.register_gateway( db, gateway, @@ -6858,6 +6862,7 @@ async def admin_add_resource(request: Request, db: Session = Depends(get_db), us ... ("name", "Test Resource"), ... ("description", "A test resource"), ... ("mimeType", "text/plain"), + ... ("template", ""), ... ("content", "Sample content"), ... ]) >>> mock_request = MagicMock(spec=Request) @@ -6889,12 +6894,16 @@ async def admin_add_resource(request: Request, db: Session = Depends(get_db), us team_id = await team_service.verify_team_for_user(user_email, team_id) try: + # Handle template field: convert empty string to None for optional field + template_value = form.get("template") + template = template_value if template_value else None + resource = ResourceCreate( uri=str(form["uri"]), name=str(form["name"]), description=str(form.get("description", "")), mime_type=str(form.get("mimeType", "")), - template=cast(str | None, form.get("template")), + template=template, content=str(form["content"]), tags=tags, visibility=visibility, @@ -8549,7 +8558,7 @@ async def admin_import_tools( }, } - rd = cast(Dict[str, Any], response_data) + rd = typing_cast(Dict[str, Any], response_data) if len(errors) == 0: rd["message"] = f"Successfully imported all {len(created)} tools" else: @@ -8610,7 +8619,7 @@ async def admin_get_logs( HTTPException: If validation fails or service unavailable """ # Get log storage from logging service - storage = cast(Any, logging_service).get_storage() + storage = typing_cast(Any, logging_service).get_storage() if not storage: return {"logs": [], "total": 0, "stats": {}} @@ -8688,7 +8697,7 @@ async def admin_stream_logs( HTTPException: If log level is invalid or service unavailable """ # Get log storage from logging service - storage = cast(Any, logging_service).get_storage() + storage = typing_cast(Any, logging_service).get_storage() if not storage: raise HTTPException(503, "Log storage not available") @@ -8907,7 +8916,7 @@ async def admin_export_logs( raise HTTPException(400, f"Invalid format: {export_format}. Use 'json' or 'csv'") # Get log storage from logging service - storage = cast(Any, logging_service).get_storage() + storage = typing_cast(Any, logging_service).get_storage() if not storage: raise HTTPException(503, "Log storage not available") @@ -11253,3 +11262,1678 @@ async def admin_generate_support_bundle( except Exception as e: LOGGER.error(f"Support bundle generation failed for user {user}: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"Failed to generate support bundle: {str(e)}") + + +# ============================================================================ +# Observability Routes +# ============================================================================ + + +@admin_router.get("/observability/partial", response_class=HTMLResponse) +async def get_observability_partial(request: Request, _user=Depends(get_current_user_with_permissions)): + """Render the observability dashboard partial. + + Args: + request: FastAPI request object + _user: Authenticated user with admin permissions (required by dependency) + + Returns: + HTMLResponse: Rendered observability dashboard template + """ + root_path = request.scope.get("root_path", "") + return request.app.state.templates.TemplateResponse("observability_partial.html", {"request": request, "root_path": root_path}) + + +@admin_router.get("/observability/metrics/partial", response_class=HTMLResponse) +async def get_observability_metrics_partial(request: Request, _user=Depends(get_current_user_with_permissions)): + """Render the advanced metrics dashboard partial. + + Args: + request: FastAPI request object + _user: Authenticated user with admin permissions (required by dependency) + + Returns: + HTMLResponse: Rendered metrics dashboard template + """ + root_path = request.scope.get("root_path", "") + return request.app.state.templates.TemplateResponse("observability_metrics.html", {"request": request, "root_path": root_path}) + + +@admin_router.get("/observability/stats", response_class=HTMLResponse) +async def get_observability_stats(request: Request, hours: int = Query(24, ge=1, le=168), _user=Depends(get_current_user_with_permissions)): + """Get observability statistics for the dashboard. + + Args: + request: FastAPI request object + hours: Number of hours to look back for statistics (1-168) + _user: Authenticated user with admin permissions (required by dependency) + + Returns: + HTMLResponse: Rendered statistics template with trace counts and averages + """ + db = next(get_db()) + try: + cutoff_time = datetime.now() - timedelta(hours=hours) + + # pylint: disable=not-callable + total_traces = db.query(func.count(ObservabilityTrace.trace_id)).filter(ObservabilityTrace.start_time >= cutoff_time).scalar() or 0 + + success_count = db.query(func.count(ObservabilityTrace.trace_id)).filter(ObservabilityTrace.start_time >= cutoff_time, ObservabilityTrace.status == "ok").scalar() or 0 + + error_count = db.query(func.count(ObservabilityTrace.trace_id)).filter(ObservabilityTrace.start_time >= cutoff_time, ObservabilityTrace.status == "error").scalar() or 0 + # pylint: enable=not-callable + + avg_duration = db.query(func.avg(ObservabilityTrace.duration_ms)).filter(ObservabilityTrace.start_time >= cutoff_time, ObservabilityTrace.duration_ms.isnot(None)).scalar() or 0 + + stats = { + "total_traces": total_traces, + "success_count": success_count, + "error_count": error_count, + "avg_duration_ms": avg_duration, + } + + return request.app.state.templates.TemplateResponse("observability_stats.html", {"request": request, "stats": stats}) + finally: + db.close() + + +@admin_router.get("/observability/traces", response_class=HTMLResponse) +async def get_observability_traces( + request: Request, + time_range: str = Query("24h"), + status_filter: str = Query("all"), + limit: int = Query(50), + min_duration: Optional[float] = Query(None), + max_duration: Optional[float] = Query(None), + http_method: Optional[str] = Query(None), + user_email: Optional[str] = Query(None), + name_search: Optional[str] = Query(None), + attribute_search: Optional[str] = Query(None), + tool_name: Optional[str] = Query(None), + _user=Depends(get_current_user_with_permissions), +): + """Get list of traces for the dashboard. + + Args: + request: FastAPI request object + time_range: Time range filter (1h, 6h, 24h, 7d) + status_filter: Status filter (all, ok, error) + limit: Maximum number of traces to return + min_duration: Minimum duration in ms + max_duration: Maximum duration in ms + http_method: HTTP method filter + user_email: User email filter + name_search: Trace name search + attribute_search: Full-text attribute search + tool_name: Filter by tool name (shows traces that invoked this tool) + _user: Authenticated user with admin permissions (required by dependency) + + Returns: + HTMLResponse: Rendered traces list template + """ + db = next(get_db()) + try: + # Parse time range + time_map = {"1h": 1, "6h": 6, "24h": 24, "7d": 168} + hours = time_map.get(time_range, 24) + cutoff_time = datetime.now() - timedelta(hours=hours) + + query = db.query(ObservabilityTrace).filter(ObservabilityTrace.start_time >= cutoff_time) + + # Apply status filter + if status_filter != "all": + query = query.filter(ObservabilityTrace.status == status_filter) + + # Apply duration filters + if min_duration is not None: + query = query.filter(ObservabilityTrace.duration_ms >= min_duration) + if max_duration is not None: + query = query.filter(ObservabilityTrace.duration_ms <= max_duration) + + # Apply HTTP method filter + if http_method: + query = query.filter(ObservabilityTrace.http_method == http_method) + + # Apply user email filter + if user_email: + query = query.filter(ObservabilityTrace.user_email.ilike(f"%{user_email}%")) + + # Apply name search + if name_search: + query = query.filter(ObservabilityTrace.name.ilike(f"%{name_search}%")) + + # Apply attribute search + if attribute_search: + # Escape special characters for SQL LIKE + safe_search = attribute_search.replace("%", "\\%").replace("_", "\\_") + query = query.filter(cast(ObservabilityTrace.attributes, String).ilike(f"%{safe_search}%")) + + # Apply tool name filter (join with spans to find traces that invoked a specific tool) + if tool_name: + # Subquery to find trace_ids that have tool invocations matching the tool name + tool_trace_ids = ( + db.query(ObservabilitySpan.trace_id) + .filter( + ObservabilitySpan.name == "tool.invoke", + func.json_extract(ObservabilitySpan.attributes, '$."tool.name"').ilike(f"%{tool_name}%"), # pylint: disable=not-callable + ) + .distinct() + .subquery() + ) + query = query.filter(ObservabilityTrace.trace_id.in_(select(tool_trace_ids.c.trace_id))) + + # Get traces ordered by most recent + traces = query.order_by(ObservabilityTrace.start_time.desc()).limit(limit).all() + + root_path = request.scope.get("root_path", "") + return request.app.state.templates.TemplateResponse("observability_traces_list.html", {"request": request, "traces": traces, "root_path": root_path}) + finally: + db.close() + + +@admin_router.get("/observability/trace/{trace_id}", response_class=HTMLResponse) +async def get_observability_trace_detail(request: Request, trace_id: str, _user=Depends(get_current_user_with_permissions)): + """Get detailed trace information with spans. + + Args: + request: FastAPI request object + trace_id: UUID of the trace to retrieve + _user: Authenticated user with admin permissions (required by dependency) + + Returns: + HTMLResponse: Rendered trace detail template with waterfall view + + Raises: + HTTPException: 404 if trace not found + """ + db = next(get_db()) + try: + trace = db.query(ObservabilityTrace).filter_by(trace_id=trace_id).options(joinedload(ObservabilityTrace.spans).joinedload(ObservabilitySpan.events)).first() + + if not trace: + raise HTTPException(status_code=404, detail="Trace not found") + + root_path = request.scope.get("root_path", "") + return request.app.state.templates.TemplateResponse("observability_trace_detail.html", {"request": request, "trace": trace, "root_path": root_path}) + finally: + db.close() + + +@admin_router.post("/observability/queries", response_model=dict) +async def save_observability_query( + request: Request, # pylint: disable=unused-argument + name: str = Body(..., description="Name for the saved query"), + description: Optional[str] = Body(None, description="Optional description"), + filter_config: dict = Body(..., description="Filter configuration as JSON"), + is_shared: bool = Body(False, description="Whether query is shared with team"), + user=Depends(get_current_user_with_permissions), +): + """Save a new observability query filter configuration. + + Args: + request: FastAPI request object + name: User-given name for the query + description: Optional description + filter_config: Dictionary containing all filter values + is_shared: Whether this query is visible to other users + user: Authenticated user (required by dependency) + + Returns: + dict: Created query details with id + + Raises: + HTTPException: 400 if validation fails + """ + db = next(get_db()) + try: + # Get user email from authenticated user + user_email = user.email if hasattr(user, "email") else "unknown" + + # Create new saved query + query = ObservabilitySavedQuery(name=name, description=description, user_email=user_email, filter_config=filter_config, is_shared=is_shared) + + db.add(query) + db.commit() + db.refresh(query) + + return {"id": query.id, "name": query.name, "description": query.description, "filter_config": query.filter_config, "is_shared": query.is_shared, "created_at": query.created_at.isoformat()} + except Exception as e: + db.rollback() + LOGGER.error(f"Failed to save query: {e}") + raise HTTPException(status_code=400, detail=str(e)) + finally: + db.close() + + +@admin_router.get("/observability/queries", response_model=list) +async def list_observability_queries(request: Request, user=Depends(get_current_user_with_permissions)): # pylint: disable=unused-argument + """List saved observability queries for the current user. + + Returns user's own queries plus any shared queries. + + Args: + request: FastAPI request object + user: Authenticated user (required by dependency) + + Returns: + list: List of saved query dictionaries + """ + db = next(get_db()) + try: + user_email = user.email if hasattr(user, "email") else "unknown" + + # Get user's own queries + shared queries + queries = ( + db.query(ObservabilitySavedQuery) + .filter(or_(ObservabilitySavedQuery.user_email == user_email, ObservabilitySavedQuery.is_shared is True)) + .order_by(desc(ObservabilitySavedQuery.created_at)) + .all() + ) + + return [ + { + "id": q.id, + "name": q.name, + "description": q.description, + "filter_config": q.filter_config, + "is_shared": q.is_shared, + "user_email": q.user_email, + "created_at": q.created_at.isoformat(), + "last_used_at": q.last_used_at.isoformat() if q.last_used_at else None, + "use_count": q.use_count, + } + for q in queries + ] + finally: + db.close() + + +@admin_router.get("/observability/queries/{query_id}", response_model=dict) +async def get_observability_query(request: Request, query_id: int, user=Depends(get_current_user_with_permissions)): # pylint: disable=unused-argument + """Get a specific saved query by ID. + + Args: + request: FastAPI request object + query_id: ID of the saved query + user: Authenticated user (required by dependency) + + Returns: + dict: Query details + + Raises: + HTTPException: 404 if query not found or unauthorized + """ + db = next(get_db()) + try: + user_email = user.email if hasattr(user, "email") else "unknown" + + # Can only access own queries or shared queries + query = ( + db.query(ObservabilitySavedQuery).filter(ObservabilitySavedQuery.id == query_id, or_(ObservabilitySavedQuery.user_email == user_email, ObservabilitySavedQuery.is_shared is True)).first() + ) + + if not query: + raise HTTPException(status_code=404, detail="Query not found or unauthorized") + + return { + "id": query.id, + "name": query.name, + "description": query.description, + "filter_config": query.filter_config, + "is_shared": query.is_shared, + "user_email": query.user_email, + "created_at": query.created_at.isoformat(), + "last_used_at": query.last_used_at.isoformat() if query.last_used_at else None, + "use_count": query.use_count, + } + finally: + db.close() + + +@admin_router.put("/observability/queries/{query_id}", response_model=dict) +async def update_observability_query( + request: Request, # pylint: disable=unused-argument + query_id: int, + name: Optional[str] = Body(None), + description: Optional[str] = Body(None), + filter_config: Optional[dict] = Body(None), + is_shared: Optional[bool] = Body(None), + user=Depends(get_current_user_with_permissions), +): + """Update an existing saved query. + + Args: + request: FastAPI request object + query_id: ID of the query to update + name: New name (optional) + description: New description (optional) + filter_config: New filter configuration (optional) + is_shared: New sharing status (optional) + user: Authenticated user (required by dependency) + + Returns: + dict: Updated query details + + Raises: + HTTPException: 404 if query not found, 403 if unauthorized + """ + db = next(get_db()) + try: + user_email = user.email if hasattr(user, "email") else "unknown" + + # Can only update own queries + query = db.query(ObservabilitySavedQuery).filter(ObservabilitySavedQuery.id == query_id, ObservabilitySavedQuery.user_email == user_email).first() + + if not query: + raise HTTPException(status_code=404, detail="Query not found or unauthorized") + + # Update fields if provided + if name is not None: + query.name = name + if description is not None: + query.description = description + if filter_config is not None: + query.filter_config = filter_config + if is_shared is not None: + query.is_shared = is_shared + + db.commit() + db.refresh(query) + + return { + "id": query.id, + "name": query.name, + "description": query.description, + "filter_config": query.filter_config, + "is_shared": query.is_shared, + "updated_at": query.updated_at.isoformat(), + } + except HTTPException: + raise + except Exception as e: + db.rollback() + LOGGER.error(f"Failed to update query: {e}") + raise HTTPException(status_code=400, detail=str(e)) + finally: + db.close() + + +@admin_router.delete("/observability/queries/{query_id}", status_code=204) +async def delete_observability_query(request: Request, query_id: int, user=Depends(get_current_user_with_permissions)): # pylint: disable=unused-argument + """Delete a saved query. + + Args: + request: FastAPI request object + query_id: ID of the query to delete + user: Authenticated user (required by dependency) + + Raises: + HTTPException: 404 if query not found, 403 if unauthorized + """ + db = next(get_db()) + try: + user_email = user.email if hasattr(user, "email") else "unknown" + + # Can only delete own queries + query = db.query(ObservabilitySavedQuery).filter(ObservabilitySavedQuery.id == query_id, ObservabilitySavedQuery.user_email == user_email).first() + + if not query: + raise HTTPException(status_code=404, detail="Query not found or unauthorized") + + db.delete(query) + db.commit() + finally: + db.close() + + +@admin_router.post("/observability/queries/{query_id}/use", response_model=dict) +async def track_query_usage(request: Request, query_id: int, user=Depends(get_current_user_with_permissions)): # pylint: disable=unused-argument + """Track usage of a saved query (increments use count and updates last_used_at). + + Args: + request: FastAPI request object + query_id: ID of the query being used + user: Authenticated user (required by dependency) + + Returns: + dict: Updated query usage stats + + Raises: + HTTPException: 404 if query not found or unauthorized + """ + db = next(get_db()) + try: + user_email = user.email if hasattr(user, "email") else "unknown" + + # Can track usage for own queries or shared queries + query = ( + db.query(ObservabilitySavedQuery).filter(ObservabilitySavedQuery.id == query_id, or_(ObservabilitySavedQuery.user_email == user_email, ObservabilitySavedQuery.is_shared is True)).first() + ) + + if not query: + raise HTTPException(status_code=404, detail="Query not found or unauthorized") + + # Update usage tracking + query.use_count += 1 + query.last_used_at = utc_now() + + db.commit() + db.refresh(query) + + return {"use_count": query.use_count, "last_used_at": query.last_used_at.isoformat()} + except HTTPException: + raise + except Exception as e: + db.rollback() + LOGGER.error(f"Failed to track query usage: {e}") + raise HTTPException(status_code=400, detail=str(e)) + finally: + db.close() + + +@admin_router.get("/observability/metrics/percentiles", response_model=dict) +async def get_latency_percentiles( + request: Request, # pylint: disable=unused-argument + hours: int = Query(24, ge=1, le=168, description="Time range in hours"), + interval_minutes: int = Query(60, ge=5, le=1440, description="Aggregation interval in minutes"), + _user=Depends(get_current_user_with_permissions), +): + """Get latency percentiles (p50, p90, p95, p99) over time. + + Args: + request: FastAPI request object + hours: Number of hours to look back (1-168) + interval_minutes: Aggregation interval in minutes (5-1440) + _user: Authenticated user (required by dependency) + + Returns: + dict: Time-series data with percentiles + + Raises: + HTTPException: 500 if calculation fails + """ + db = next(get_db()) + try: + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours) + + # Query all traces with duration in time range + traces = ( + db.query(ObservabilityTrace.start_time, ObservabilityTrace.duration_ms) + .filter(ObservabilityTrace.start_time >= cutoff_time, ObservabilityTrace.duration_ms.isnot(None)) + .order_by(ObservabilityTrace.start_time) + .all() + ) + + if not traces: + return {"timestamps": [], "p50": [], "p90": [], "p95": [], "p99": []} + + # Group traces into time buckets + buckets: Dict[datetime, List[float]] = defaultdict(list) + for trace in traces: + # Round down to nearest interval + bucket_time = trace.start_time.replace(second=0, microsecond=0) + bucket_time = bucket_time - timedelta(minutes=bucket_time.minute % interval_minutes, seconds=bucket_time.second, microseconds=bucket_time.microsecond) + buckets[bucket_time].append(trace.duration_ms) + + # Calculate percentiles for each bucket + timestamps = [] + p50_values = [] + p90_values = [] + p95_values = [] + p99_values = [] + + for bucket_time in sorted(buckets.keys()): + durations = sorted(buckets[bucket_time]) + n = len(durations) + + if n > 0: + # Calculate percentile indices + p50_idx = max(0, int(n * 0.50) - 1) + p90_idx = max(0, int(n * 0.90) - 1) + p95_idx = max(0, int(n * 0.95) - 1) + p99_idx = max(0, int(n * 0.99) - 1) + + timestamps.append(bucket_time.isoformat()) + p50_values.append(round(durations[p50_idx], 2)) + p90_values.append(round(durations[p90_idx], 2)) + p95_values.append(round(durations[p95_idx], 2)) + p99_values.append(round(durations[p99_idx], 2)) + + return {"timestamps": timestamps, "p50": p50_values, "p90": p90_values, "p95": p95_values, "p99": p99_values} + except Exception as e: + LOGGER.error(f"Failed to calculate latency percentiles: {e}") + raise HTTPException(status_code=500, detail=str(e)) + finally: + db.close() + + +@admin_router.get("/observability/metrics/timeseries", response_model=dict) +async def get_timeseries_metrics( + request: Request, # pylint: disable=unused-argument + hours: int = Query(24, ge=1, le=168, description="Time range in hours"), + interval_minutes: int = Query(60, ge=5, le=1440, description="Aggregation interval in minutes"), + _user=Depends(get_current_user_with_permissions), +): + """Get time-series metrics (request rate, error rate, throughput). + + Args: + request: FastAPI request object + hours: Number of hours to look back (1-168) + interval_minutes: Aggregation interval in minutes (5-1440) + _user: Authenticated user (required by dependency) + + Returns: + dict: Time-series data with request counts, error rates, and throughput + + Raises: + HTTPException: 500 if calculation fails + """ + db = next(get_db()) + try: + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours) + + # Query traces grouped by time bucket + traces = db.query(ObservabilityTrace.start_time, ObservabilityTrace.status).filter(ObservabilityTrace.start_time >= cutoff_time).order_by(ObservabilityTrace.start_time).all() + + if not traces: + return {"timestamps": [], "request_count": [], "success_count": [], "error_count": [], "error_rate": []} + + # Group traces into time buckets + buckets: Dict[datetime, Dict[str, int]] = defaultdict(lambda: {"total": 0, "success": 0, "error": 0}) + for trace in traces: + # Round down to nearest interval + bucket_time = trace.start_time.replace(second=0, microsecond=0) + bucket_time = bucket_time - timedelta(minutes=bucket_time.minute % interval_minutes, seconds=bucket_time.second, microseconds=bucket_time.microsecond) + + buckets[bucket_time]["total"] += 1 + if trace.status == "ok": + buckets[bucket_time]["success"] += 1 + elif trace.status == "error": + buckets[bucket_time]["error"] += 1 + + # Build time-series arrays + timestamps = [] + request_counts = [] + success_counts = [] + error_counts = [] + error_rates = [] + + for bucket_time in sorted(buckets.keys()): + bucket = buckets[bucket_time] + error_rate = (bucket["error"] / bucket["total"] * 100) if bucket["total"] > 0 else 0 + + timestamps.append(bucket_time.isoformat()) + request_counts.append(bucket["total"]) + success_counts.append(bucket["success"]) + error_counts.append(bucket["error"]) + error_rates.append(round(error_rate, 2)) + + return { + "timestamps": timestamps, + "request_count": request_counts, + "success_count": success_counts, + "error_count": error_counts, + "error_rate": error_rates, + } + except Exception as e: + LOGGER.error(f"Failed to calculate timeseries metrics: {e}") + raise HTTPException(status_code=500, detail=str(e)) + finally: + db.close() + + +@admin_router.get("/observability/metrics/top-slow", response_model=dict) +async def get_top_slow_endpoints( + request: Request, # pylint: disable=unused-argument + hours: int = Query(24, ge=1, le=168, description="Time range in hours"), + limit: int = Query(10, ge=1, le=100, description="Number of results"), + _user=Depends(get_current_user_with_permissions), +): + """Get top N slowest endpoints by average duration. + + Args: + request: FastAPI request object + hours: Number of hours to look back (1-168) + limit: Number of results to return (1-100) + _user: Authenticated user (required by dependency) + + Returns: + dict: List of slowest endpoints with stats + + Raises: + HTTPException: 500 if query fails + """ + db = next(get_db()) + try: + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours) + + # Group by endpoint and calculate average duration + results = ( + db.query( + ObservabilityTrace.http_url, + ObservabilityTrace.http_method, + func.count(ObservabilityTrace.trace_id).label("count"), # pylint: disable=not-callable + func.avg(ObservabilityTrace.duration_ms).label("avg_duration"), + func.max(ObservabilityTrace.duration_ms).label("max_duration"), + ) + .filter(ObservabilityTrace.start_time >= cutoff_time, ObservabilityTrace.duration_ms.isnot(None)) + .group_by(ObservabilityTrace.http_url, ObservabilityTrace.http_method) + .order_by(desc("avg_duration")) + .limit(limit) + .all() + ) + + endpoints = [] + for row in results: + endpoints.append( + { + "endpoint": f"{row.http_method} {row.http_url}", + "method": row.http_method, + "url": row.http_url, + "count": row.count, + "avg_duration_ms": round(row.avg_duration, 2), + "max_duration_ms": round(row.max_duration, 2), + } + ) + + return {"endpoints": endpoints} + except Exception as e: + LOGGER.error(f"Failed to get top slow endpoints: {e}") + raise HTTPException(status_code=500, detail=str(e)) + finally: + db.close() + + +@admin_router.get("/observability/metrics/top-volume", response_model=dict) +async def get_top_volume_endpoints( + request: Request, # pylint: disable=unused-argument + hours: int = Query(24, ge=1, le=168, description="Time range in hours"), + limit: int = Query(10, ge=1, le=100, description="Number of results"), + _user=Depends(get_current_user_with_permissions), +): + """Get top N highest volume endpoints by request count. + + Args: + request: FastAPI request object + hours: Number of hours to look back (1-168) + limit: Number of results to return (1-100) + _user: Authenticated user (required by dependency) + + Returns: + dict: List of highest volume endpoints with stats + + Raises: + HTTPException: 500 if query fails + """ + db = next(get_db()) + try: + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours) + + # Group by endpoint and count requests + results = ( + db.query( + ObservabilityTrace.http_url, + ObservabilityTrace.http_method, + func.count(ObservabilityTrace.trace_id).label("count"), # pylint: disable=not-callable + func.avg(ObservabilityTrace.duration_ms).label("avg_duration"), + ) + .filter(ObservabilityTrace.start_time >= cutoff_time) + .group_by(ObservabilityTrace.http_url, ObservabilityTrace.http_method) + .order_by(desc("count")) + .limit(limit) + .all() + ) + + endpoints = [] + for row in results: + endpoints.append( + { + "endpoint": f"{row.http_method} {row.http_url}", + "method": row.http_method, + "url": row.http_url, + "count": row.count, + "avg_duration_ms": round(row.avg_duration, 2) if row.avg_duration else 0, + } + ) + + return {"endpoints": endpoints} + except Exception as e: + LOGGER.error(f"Failed to get top volume endpoints: {e}") + raise HTTPException(status_code=500, detail=str(e)) + finally: + db.close() + + +@admin_router.get("/observability/metrics/top-errors", response_model=dict) +async def get_top_error_endpoints( + request: Request, # pylint: disable=unused-argument + hours: int = Query(24, ge=1, le=168, description="Time range in hours"), + limit: int = Query(10, ge=1, le=100, description="Number of results"), + _user=Depends(get_current_user_with_permissions), +): + """Get top N error-prone endpoints by error count and rate. + + Args: + request: FastAPI request object + hours: Number of hours to look back (1-168) + limit: Number of results to return (1-100) + _user: Authenticated user (required by dependency) + + Returns: + dict: List of error-prone endpoints with stats + + Raises: + HTTPException: 500 if query fails + """ + db = next(get_db()) + try: + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours) + + # Group by endpoint and count errors + results = ( + db.query( + ObservabilityTrace.http_url, + ObservabilityTrace.http_method, + func.count(ObservabilityTrace.trace_id).label("total_count"), # pylint: disable=not-callable + func.sum(case((ObservabilityTrace.status == "error", 1), else_=0)).label("error_count"), + ) + .filter(ObservabilityTrace.start_time >= cutoff_time) + .group_by(ObservabilityTrace.http_url, ObservabilityTrace.http_method) + .having(func.sum(case((ObservabilityTrace.status == "error", 1), else_=0)) > 0) + .order_by(desc("error_count")) + .limit(limit) + .all() + ) + + endpoints = [] + for row in results: + error_rate = (row.error_count / row.total_count * 100) if row.total_count > 0 else 0 + endpoints.append( + { + "endpoint": f"{row.http_method} {row.http_url}", + "method": row.http_method, + "url": row.http_url, + "total_count": row.total_count, + "error_count": row.error_count, + "error_rate": round(error_rate, 2), + } + ) + + return {"endpoints": endpoints} + except Exception as e: + LOGGER.error(f"Failed to get top error endpoints: {e}") + raise HTTPException(status_code=500, detail=str(e)) + finally: + db.close() + + +@admin_router.get("/observability/metrics/heatmap", response_model=dict) +async def get_latency_heatmap( + request: Request, # pylint: disable=unused-argument + hours: int = Query(24, ge=1, le=168, description="Time range in hours"), + time_buckets: int = Query(24, ge=10, le=100, description="Number of time buckets"), + latency_buckets: int = Query(20, ge=5, le=50, description="Number of latency buckets"), + _user=Depends(get_current_user_with_permissions), +): + """Get latency distribution heatmap data. + + Args: + request: FastAPI request object + hours: Number of hours to look back (1-168) + time_buckets: Number of time buckets (10-100) + latency_buckets: Number of latency buckets (5-50) + _user: Authenticated user (required by dependency) + + Returns: + dict: Heatmap data with time and latency dimensions + + Raises: + HTTPException: 500 if calculation fails + """ + db = next(get_db()) + try: + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours) + # Remove timezone info for SQLite compatibility + cutoff_time_naive = cutoff_time.replace(tzinfo=None) + + # Query all traces with duration + traces = ( + db.query(ObservabilityTrace.start_time, ObservabilityTrace.duration_ms) + .filter(ObservabilityTrace.start_time >= cutoff_time_naive, ObservabilityTrace.duration_ms.isnot(None)) + .order_by(ObservabilityTrace.start_time) + .all() + ) + + if not traces: + return {"time_labels": [], "latency_labels": [], "data": []} + + # Calculate time bucket size + time_range = hours * 60 # minutes + time_bucket_minutes = time_range / time_buckets + + # Find latency range and create buckets + durations = [t.duration_ms for t in traces] + min_duration = min(durations) + max_duration = max(durations) + latency_range = max_duration - min_duration + latency_bucket_size = latency_range / latency_buckets if latency_range > 0 else 1 + + # Initialize heatmap matrix + heatmap = [[0 for _ in range(time_buckets)] for _ in range(latency_buckets)] + + # Populate heatmap + for trace in traces: + # Calculate time bucket index + time_diff = (trace.start_time - cutoff_time_naive).total_seconds() / 60 # minutes + time_idx = min(int(time_diff / time_bucket_minutes), time_buckets - 1) + + # Calculate latency bucket index + latency_idx = min(int((trace.duration_ms - min_duration) / latency_bucket_size), latency_buckets - 1) + + heatmap[latency_idx][time_idx] += 1 + + # Generate labels + time_labels = [] + for i in range(time_buckets): + bucket_time = cutoff_time_naive + timedelta(minutes=i * time_bucket_minutes) + time_labels.append(bucket_time.strftime("%H:%M")) + + latency_labels = [] + for i in range(latency_buckets): + bucket_min = min_duration + i * latency_bucket_size + bucket_max = bucket_min + latency_bucket_size + latency_labels.append(f"{bucket_min:.0f}-{bucket_max:.0f}ms") + + return {"time_labels": time_labels, "latency_labels": latency_labels, "data": heatmap} + except Exception as e: + LOGGER.error(f"Failed to generate latency heatmap: {e}") + raise HTTPException(status_code=500, detail=str(e)) + finally: + db.close() + + +@admin_router.get("/observability/tools/usage", response_model=dict) +async def get_tool_usage( + request: Request, # pylint: disable=unused-argument + hours: int = Query(24, ge=1, le=168, description="Time range in hours"), + limit: int = Query(20, ge=5, le=100, description="Number of tools to return"), + _user=Depends(get_current_user_with_permissions), +): + """Get tool usage frequency statistics. + + Args: + request: FastAPI request object + hours: Number of hours to look back (1-168) + limit: Maximum number of tools to return (5-100) + _user: Authenticated user (required by dependency) + + Returns: + dict: Tool usage statistics with counts and percentages + + Raises: + HTTPException: 500 if calculation fails + """ + db = next(get_db()) + try: + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours) + cutoff_time_naive = cutoff_time.replace(tzinfo=None) + + # Query tool invocations from spans + # Note: Using $."tool.name" because the JSON key contains a dot + tool_usage = ( + db.query( + func.json_extract(ObservabilitySpan.attributes, '$."tool.name"').label("tool_name"), # pylint: disable=not-callable + func.count(ObservabilitySpan.span_id).label("count"), # pylint: disable=not-callable + ) + .filter( + ObservabilitySpan.name == "tool.invoke", + ObservabilitySpan.start_time >= cutoff_time_naive, + func.json_extract(ObservabilitySpan.attributes, '$."tool.name"').isnot(None), # pylint: disable=not-callable + ) + .group_by(func.json_extract(ObservabilitySpan.attributes, '$."tool.name"')) # pylint: disable=not-callable + .order_by(func.count(ObservabilitySpan.span_id).desc()) # pylint: disable=not-callable + .limit(limit) + .all() + ) + + total_invocations = sum(row.count for row in tool_usage) + + tools = [ + { + "tool_name": row.tool_name, + "count": row.count, + "percentage": round((row.count / total_invocations * 100) if total_invocations > 0 else 0, 2), + } + for row in tool_usage + ] + + return {"tools": tools, "total_invocations": total_invocations, "time_range_hours": hours} + except Exception as e: + LOGGER.error(f"Failed to get tool usage statistics: {e}") + raise HTTPException(status_code=500, detail=str(e)) + finally: + db.close() + + +@admin_router.get("/observability/tools/performance", response_model=dict) +async def get_tool_performance( + request: Request, # pylint: disable=unused-argument + hours: int = Query(24, ge=1, le=168, description="Time range in hours"), + limit: int = Query(20, ge=5, le=100, description="Number of tools to return"), + _user=Depends(get_current_user_with_permissions), +): + """Get tool performance metrics (avg, min, max duration). + + Args: + request: FastAPI request object + hours: Number of hours to look back (1-168) + limit: Maximum number of tools to return (5-100) + _user: Authenticated user (required by dependency) + + Returns: + dict: Tool performance metrics + + Raises: + HTTPException: 500 if calculation fails + """ + db = next(get_db()) + try: + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours) + cutoff_time_naive = cutoff_time.replace(tzinfo=None) + + # First, get all tool invocations with durations + tool_spans = ( + db.query( + func.json_extract(ObservabilitySpan.attributes, '$."tool.name"').label("tool_name"), # pylint: disable=not-callable + ObservabilitySpan.duration_ms, + ) + .filter( + ObservabilitySpan.name == "tool.invoke", + ObservabilitySpan.start_time >= cutoff_time_naive, + ObservabilitySpan.duration_ms.isnot(None), + func.json_extract(ObservabilitySpan.attributes, '$."tool.name"').isnot(None), # pylint: disable=not-callable + ) + .all() + ) + + # Group by tool name and calculate percentiles + tool_durations = defaultdict(list) + for span in tool_spans: + tool_durations[span.tool_name].append(span.duration_ms) + + # Calculate metrics for each tool + tools_data = [] + for tool_name, durations in tool_durations.items(): + durations_sorted = sorted(durations) + n = len(durations_sorted) + + if n == 0: + continue + + # Calculate percentiles + def percentile(data, p): + if not data: + return 0 + k = (len(data) - 1) * p + f = int(k) + c = min(f + 1, len(data) - 1) + if f == c: + return data[f] + return data[f] * (c - k) + data[c] * (k - f) + + tools_data.append( + { + "tool_name": tool_name, + "count": n, + "avg_duration_ms": round(sum(durations) / n, 2), + "min_duration_ms": round(min(durations), 2), + "max_duration_ms": round(max(durations), 2), + "p50": round(percentile(durations_sorted, 0.50), 2), + "p90": round(percentile(durations_sorted, 0.90), 2), + "p95": round(percentile(durations_sorted, 0.95), 2), + "p99": round(percentile(durations_sorted, 0.99), 2), + } + ) + + # Sort by average duration descending and limit + tools_data.sort(key=lambda x: x["avg_duration_ms"], reverse=True) + tools = tools_data[:limit] + + return {"tools": tools, "time_range_hours": hours} + except Exception as e: + LOGGER.error(f"Failed to get tool performance metrics: {e}") + raise HTTPException(status_code=500, detail=str(e)) + finally: + db.close() + + +@admin_router.get("/observability/tools/errors", response_model=dict) +async def get_tool_errors( + request: Request, # pylint: disable=unused-argument + hours: int = Query(24, ge=1, le=168, description="Time range in hours"), + limit: int = Query(20, ge=5, le=100, description="Number of tools to return"), + _user=Depends(get_current_user_with_permissions), +): + """Get tool error rates and statistics. + + Args: + request: FastAPI request object + hours: Number of hours to look back (1-168) + limit: Maximum number of tools to return (5-100) + _user: Authenticated user (required by dependency) + + Returns: + dict: Tool error statistics + + Raises: + HTTPException: 500 if calculation fails + """ + db = next(get_db()) + try: + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours) + cutoff_time_naive = cutoff_time.replace(tzinfo=None) + + # Query tool error rates + tool_errors = ( + db.query( + func.json_extract(ObservabilitySpan.attributes, '$."tool.name"').label("tool_name"), # pylint: disable=not-callable + func.count(ObservabilitySpan.span_id).label("total_count"), # pylint: disable=not-callable + func.sum(case((ObservabilitySpan.status == "error", 1), else_=0)).label("error_count"), # pylint: disable=not-callable + ) + .filter( + ObservabilitySpan.name == "tool.invoke", + ObservabilitySpan.start_time >= cutoff_time_naive, + func.json_extract(ObservabilitySpan.attributes, '$."tool.name"').isnot(None), # pylint: disable=not-callable + ) + .group_by(func.json_extract(ObservabilitySpan.attributes, '$."tool.name"')) # pylint: disable=not-callable + .order_by(func.sum(case((ObservabilitySpan.status == "error", 1), else_=0)).desc()) # pylint: disable=not-callable + .limit(limit) + .all() + ) + + tools = [ + { + "tool_name": row.tool_name, + "total_count": row.total_count, + "error_count": row.error_count or 0, + "error_rate": round((row.error_count / row.total_count * 100) if row.total_count > 0 and row.error_count else 0, 2), + } + for row in tool_errors + ] + + return {"tools": tools, "time_range_hours": hours} + except Exception as e: + LOGGER.error(f"Failed to get tool error statistics: {e}") + raise HTTPException(status_code=500, detail=str(e)) + finally: + db.close() + + +@admin_router.get("/observability/tools/chains", response_model=dict) +async def get_tool_chains( + request: Request, # pylint: disable=unused-argument + hours: int = Query(24, ge=1, le=168, description="Time range in hours"), + limit: int = Query(20, ge=5, le=100, description="Number of chains to return"), + _user=Depends(get_current_user_with_permissions), +): + """Get tool chain analysis (which tools are invoked together in the same trace). + + Args: + request: FastAPI request object + hours: Number of hours to look back (1-168) + limit: Maximum number of chains to return (5-100) + _user: Authenticated user (required by dependency) + + Returns: + dict: Tool chain statistics showing common tool sequences + + Raises: + HTTPException: 500 if calculation fails + """ + db = next(get_db()) + try: + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours) + cutoff_time_naive = cutoff_time.replace(tzinfo=None) + + # Get all tool invocations grouped by trace_id + tool_spans = ( + db.query( + ObservabilitySpan.trace_id, + func.json_extract(ObservabilitySpan.attributes, '$."tool.name"').label("tool_name"), # pylint: disable=not-callable + ObservabilitySpan.start_time, + ) + .filter( + ObservabilitySpan.name == "tool.invoke", + ObservabilitySpan.start_time >= cutoff_time_naive, + func.json_extract(ObservabilitySpan.attributes, '$."tool.name"').isnot(None), # pylint: disable=not-callable + ) + .order_by(ObservabilitySpan.trace_id, ObservabilitySpan.start_time) + .all() + ) + + # Group tools by trace and create chains + trace_tools = {} + for span in tool_spans: + if span.trace_id not in trace_tools: + trace_tools[span.trace_id] = [] + trace_tools[span.trace_id].append(span.tool_name) + + # Count tool chain frequencies + chain_counts = {} + for tools in trace_tools.values(): + if len(tools) > 1: + # Create a chain string (sorted to treat [A,B] and [B,A] as same chain) + chain = " -> ".join(tools) + chain_counts[chain] = chain_counts.get(chain, 0) + 1 + + # Sort by frequency and take top N + sorted_chains = sorted(chain_counts.items(), key=lambda x: x[1], reverse=True)[:limit] + + chains = [{"chain": chain, "count": count} for chain, count in sorted_chains] + + return {"chains": chains, "total_traces_with_tools": len(trace_tools), "time_range_hours": hours} + except Exception as e: + LOGGER.error(f"Failed to get tool chain statistics: {e}") + raise HTTPException(status_code=500, detail=str(e)) + finally: + db.close() + + +@admin_router.get("/observability/tools/partial", response_class=HTMLResponse) +async def get_tools_partial( + request: Request, + _user=Depends(get_current_user_with_permissions), +): + """Render the tool invocation metrics dashboard HTML partial. + + Args: + request: FastAPI request object + _user: Authenticated user (required by dependency) + + Returns: + HTMLResponse: Rendered tool metrics dashboard partial + """ + root_path = request.scope.get("root_path", "") + return request.app.state.templates.TemplateResponse( + "observability_tools.html", + { + "request": request, + "root_path": root_path, + }, + ) + + +# ============================================================================== +# Prompts Observability Endpoints +# ============================================================================== + + +@admin_router.get("/observability/prompts/usage", response_model=dict) +async def get_prompt_usage( + request: Request, # pylint: disable=unused-argument + hours: int = Query(24, ge=1, le=168, description="Time range in hours"), + limit: int = Query(20, ge=5, le=100, description="Number of prompts to return"), + _user=Depends(get_current_user_with_permissions), +): + """Get prompt rendering frequency statistics. + + Args: + request: FastAPI request object + hours: Number of hours to look back (1-168) + limit: Maximum number of prompts to return (5-100) + _user: Authenticated user (required by dependency) + + Returns: + dict: Prompt usage statistics with counts and percentages + + Raises: + HTTPException: 500 if calculation fails + """ + db = next(get_db()) + try: + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours) + cutoff_time_naive = cutoff_time.replace(tzinfo=None) + + # Query prompt renders from spans (looking for prompts/get calls) + # The prompt id should be in attributes as "prompt.id" + prompt_usage = ( + db.query( + func.json_extract(ObservabilitySpan.attributes, '$."prompt.id"').label("prompt_id"), # pylint: disable=not-callable + func.count(ObservabilitySpan.span_id).label("count"), # pylint: disable=not-callable + ) + .filter( + ObservabilitySpan.name.in_(["prompt.get", "prompts.get", "prompt.render"]), + ObservabilitySpan.start_time >= cutoff_time_naive, + func.json_extract(ObservabilitySpan.attributes, '$."prompt.id"').isnot(None), # pylint: disable=not-callable + ) + .group_by(func.json_extract(ObservabilitySpan.attributes, '$."prompt.id"')) # pylint: disable=not-callable + .order_by(func.count(ObservabilitySpan.span_id).desc()) # pylint: disable=not-callable + .limit(limit) + .all() + ) + + total_renders = sum(row.count for row in prompt_usage) + + prompts = [ + { + "prompt_id": row.prompt_id, + "count": row.count, + "percentage": round((row.count / total_renders * 100) if total_renders > 0 else 0, 2), + } + for row in prompt_usage + ] + + return {"prompts": prompts, "total_renders": total_renders, "time_range_hours": hours} + except Exception as e: + LOGGER.error(f"Failed to get prompt usage statistics: {e}") + raise HTTPException(status_code=500, detail=str(e)) + finally: + db.close() + + +@admin_router.get("/observability/prompts/performance", response_model=dict) +async def get_prompt_performance( + request: Request, # pylint: disable=unused-argument + hours: int = Query(24, ge=1, le=168, description="Time range in hours"), + limit: int = Query(20, ge=5, le=100, description="Number of prompts to return"), + _user=Depends(get_current_user_with_permissions), +): + """Get prompt performance metrics (avg, min, max duration). + + Args: + request: FastAPI request object + hours: Number of hours to look back (1-168) + limit: Maximum number of prompts to return (5-100) + _user: Authenticated user (required by dependency) + + Returns: + dict: Prompt performance metrics + + Raises: + HTTPException: 500 if calculation fails + """ + db = next(get_db()) + try: + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours) + cutoff_time_naive = cutoff_time.replace(tzinfo=None) + + # First, get all prompt renders with durations + prompt_spans = ( + db.query( + func.json_extract(ObservabilitySpan.attributes, '$."prompt.id"').label("prompt_id"), # pylint: disable=not-callable + ObservabilitySpan.duration_ms, + ) + .filter( + ObservabilitySpan.name.in_(["prompt.get", "prompts.get", "prompt.render"]), + ObservabilitySpan.start_time >= cutoff_time_naive, + ObservabilitySpan.duration_ms.isnot(None), + func.json_extract(ObservabilitySpan.attributes, '$."prompt.id"').isnot(None), # pylint: disable=not-callable + ) + .all() + ) + + # Group by prompt id and calculate percentiles + prompt_durations = defaultdict(list) + for span in prompt_spans: + prompt_durations[span.prompt_id].append(span.duration_ms) + + # Calculate metrics for each prompt + prompts_data = [] + for prompt_id, durations in prompt_durations.items(): + durations_sorted = sorted(durations) + n = len(durations_sorted) + + if n == 0: + continue + + # Calculate percentiles + def percentile(data, p): + if not data: + return 0 + k = (len(data) - 1) * p + f = int(k) + c = min(f + 1, len(data) - 1) + if f == c: + return data[f] + return data[f] * (c - k) + data[c] * (k - f) + + prompts_data.append( + { + "prompt_id": prompt_id, + "count": n, + "avg_duration_ms": round(sum(durations) / n, 2), + "min_duration_ms": round(min(durations), 2), + "max_duration_ms": round(max(durations), 2), + "p50": round(percentile(durations_sorted, 0.50), 2), + "p90": round(percentile(durations_sorted, 0.90), 2), + "p95": round(percentile(durations_sorted, 0.95), 2), + "p99": round(percentile(durations_sorted, 0.99), 2), + } + ) + + # Sort by average duration descending and limit + prompts_data.sort(key=lambda x: x["avg_duration_ms"], reverse=True) + prompts = prompts_data[:limit] + + return {"prompts": prompts, "time_range_hours": hours} + except Exception as e: + LOGGER.error(f"Failed to get prompt performance metrics: {e}") + raise HTTPException(status_code=500, detail=str(e)) + finally: + db.close() + + +@admin_router.get("/observability/prompts/errors", response_model=dict) +async def get_prompts_errors( + hours: int = Query(24, description="Time range in hours"), + limit: int = Query(20, description="Maximum number of results"), + _user=Depends(get_current_user_with_permissions), +): + """Get prompt error rates. + + Args: + hours: Time range in hours to analyze + limit: Maximum number of prompts to return + _user: Authenticated user (required by dependency) + + Returns: + dict: Prompt error statistics + """ + db = next(get_db()) + try: + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours) + cutoff_time_naive = cutoff_time.replace(tzinfo=None) + + # Get all prompt spans with their status + prompt_stats = ( + db.query( + func.json_extract(ObservabilitySpan.attributes, '$."prompt.id"').label("prompt_id"), + func.count().label("total_count"), # pylint: disable=not-callable + func.sum(case((ObservabilitySpan.status == "error", 1), else_=0)).label("error_count"), + ) + .filter( + ObservabilitySpan.name == "prompt.render", + ObservabilitySpan.start_time >= cutoff_time_naive, + func.json_extract(ObservabilitySpan.attributes, '$."prompt.id"').isnot(None), + ) + .group_by(func.json_extract(ObservabilitySpan.attributes, '$."prompt.id"')) + .all() + ) + + prompts_data = [] + for stat in prompt_stats: + total = stat.total_count + errors = stat.error_count or 0 + error_rate = round((errors / total * 100), 2) if total > 0 else 0 + + prompts_data.append({"prompt_id": stat.prompt_id, "total_count": total, "error_count": errors, "error_rate": error_rate}) + + # Sort by error rate descending + prompts_data.sort(key=lambda x: x["error_rate"], reverse=True) + prompts_data = prompts_data[:limit] + + return {"prompts": prompts_data, "time_range_hours": hours} + finally: + db.close() + + +@admin_router.get("/observability/prompts/partial", response_class=HTMLResponse) +async def get_prompts_partial( + request: Request, + _user=Depends(get_current_user_with_permissions), +): + """Render the prompt rendering metrics dashboard HTML partial. + + Args: + request: FastAPI request object + _user: Authenticated user (required by dependency) + + Returns: + HTMLResponse: Rendered prompt metrics dashboard partial + """ + root_path = request.scope.get("root_path", "") + return request.app.state.templates.TemplateResponse( + "observability_prompts.html", + { + "request": request, + "root_path": root_path, + }, + ) + + +# ============================================================================== +# Resources Observability Endpoints +# ============================================================================== + + +@admin_router.get("/observability/resources/usage", response_model=dict) +async def get_resource_usage( + request: Request, # pylint: disable=unused-argument + hours: int = Query(24, ge=1, le=168, description="Time range in hours"), + limit: int = Query(20, ge=5, le=100, description="Number of resources to return"), + _user=Depends(get_current_user_with_permissions), +): + """Get resource fetch frequency statistics. + + Args: + request: FastAPI request object + hours: Number of hours to look back (1-168) + limit: Maximum number of resources to return (5-100) + _user: Authenticated user (required by dependency) + + Returns: + dict: Resource usage statistics with counts and percentages + + Raises: + HTTPException: 500 if calculation fails + """ + db = next(get_db()) + try: + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours) + cutoff_time_naive = cutoff_time.replace(tzinfo=None) + + # Query resource reads from spans (looking for resources/read calls) + # The resource URI should be in attributes + resource_usage = ( + db.query( + func.json_extract(ObservabilitySpan.attributes, '$."resource.uri"').label("resource_uri"), # pylint: disable=not-callable + func.count(ObservabilitySpan.span_id).label("count"), # pylint: disable=not-callable + ) + .filter( + ObservabilitySpan.name.in_(["resource.read", "resources.read", "resource.fetch"]), + ObservabilitySpan.start_time >= cutoff_time_naive, + func.json_extract(ObservabilitySpan.attributes, '$."resource.uri"').isnot(None), # pylint: disable=not-callable + ) + .group_by(func.json_extract(ObservabilitySpan.attributes, '$."resource.uri"')) # pylint: disable=not-callable + .order_by(func.count(ObservabilitySpan.span_id).desc()) # pylint: disable=not-callable + .limit(limit) + .all() + ) + + total_fetches = sum(row.count for row in resource_usage) + + resources = [ + { + "resource_uri": row.resource_uri, + "count": row.count, + "percentage": round((row.count / total_fetches * 100) if total_fetches > 0 else 0, 2), + } + for row in resource_usage + ] + + return {"resources": resources, "total_fetches": total_fetches, "time_range_hours": hours} + except Exception as e: + LOGGER.error(f"Failed to get resource usage statistics: {e}") + raise HTTPException(status_code=500, detail=str(e)) + finally: + db.close() + + +@admin_router.get("/observability/resources/performance", response_model=dict) +async def get_resource_performance( + request: Request, # pylint: disable=unused-argument + hours: int = Query(24, ge=1, le=168, description="Time range in hours"), + limit: int = Query(20, ge=5, le=100, description="Number of resources to return"), + _user=Depends(get_current_user_with_permissions), +): + """Get resource performance metrics (avg, min, max duration). + + Args: + request: FastAPI request object + hours: Number of hours to look back (1-168) + limit: Maximum number of resources to return (5-100) + _user: Authenticated user (required by dependency) + + Returns: + dict: Resource performance metrics + + Raises: + HTTPException: 500 if calculation fails + """ + db = next(get_db()) + try: + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours) + cutoff_time_naive = cutoff_time.replace(tzinfo=None) + + # First, get all resource reads with durations + resource_spans = ( + db.query( + func.json_extract(ObservabilitySpan.attributes, '$."resource.uri"').label("resource_uri"), # pylint: disable=not-callable + ObservabilitySpan.duration_ms, + ) + .filter( + ObservabilitySpan.name.in_(["resource.read", "resources.read", "resource.fetch"]), + ObservabilitySpan.start_time >= cutoff_time_naive, + ObservabilitySpan.duration_ms.isnot(None), + func.json_extract(ObservabilitySpan.attributes, '$."resource.uri"').isnot(None), # pylint: disable=not-callable + ) + .all() + ) + + # Group by resource URI and calculate percentiles + resource_durations = defaultdict(list) + for span in resource_spans: + resource_durations[span.resource_uri].append(span.duration_ms) + + # Calculate metrics for each resource + resources_data = [] + for resource_uri, durations in resource_durations.items(): + durations_sorted = sorted(durations) + n = len(durations_sorted) + + if n == 0: + continue + + # Calculate percentiles + def percentile(data, p): + if not data: + return 0 + k = (len(data) - 1) * p + f = int(k) + c = min(f + 1, len(data) - 1) + if f == c: + return data[f] + return data[f] * (c - k) + data[c] * (k - f) + + resources_data.append( + { + "resource_uri": resource_uri, + "count": n, + "avg_duration_ms": round(sum(durations) / n, 2), + "min_duration_ms": round(min(durations), 2), + "max_duration_ms": round(max(durations), 2), + "p50": round(percentile(durations_sorted, 0.50), 2), + "p90": round(percentile(durations_sorted, 0.90), 2), + "p95": round(percentile(durations_sorted, 0.95), 2), + "p99": round(percentile(durations_sorted, 0.99), 2), + } + ) + + # Sort by average duration descending and limit + resources_data.sort(key=lambda x: x["avg_duration_ms"], reverse=True) + resources = resources_data[:limit] + + return {"resources": resources, "time_range_hours": hours} + except Exception as e: + LOGGER.error(f"Failed to get resource performance metrics: {e}") + raise HTTPException(status_code=500, detail=str(e)) + finally: + db.close() + + +@admin_router.get("/observability/resources/errors", response_model=dict) +async def get_resources_errors( + hours: int = Query(24, description="Time range in hours"), + limit: int = Query(20, description="Maximum number of results"), + _user=Depends(get_current_user_with_permissions), +): + """Get resource error rates. + + Args: + hours: Time range in hours to analyze + limit: Maximum number of resources to return + _user: Authenticated user (required by dependency) + + Returns: + dict: Resource error statistics + """ + db = next(get_db()) + try: + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours) + cutoff_time_naive = cutoff_time.replace(tzinfo=None) + + # Get all resource spans with their status + resource_stats = ( + db.query( + func.json_extract(ObservabilitySpan.attributes, '$."resource.uri"').label("resource_uri"), + func.count().label("total_count"), # pylint: disable=not-callable + func.sum(case((ObservabilitySpan.status == "error", 1), else_=0)).label("error_count"), + ) + .filter( + ObservabilitySpan.name.in_(["resource.read", "resources.read", "resource.fetch"]), + ObservabilitySpan.start_time >= cutoff_time_naive, + func.json_extract(ObservabilitySpan.attributes, '$."resource.uri"').isnot(None), + ) + .group_by(func.json_extract(ObservabilitySpan.attributes, '$."resource.uri"')) + .all() + ) + + resources_data = [] + for stat in resource_stats: + total = stat.total_count + errors = stat.error_count or 0 + error_rate = round((errors / total * 100), 2) if total > 0 else 0 + + resources_data.append({"resource_uri": stat.resource_uri, "total_count": total, "error_count": errors, "error_rate": error_rate}) + + # Sort by error rate descending + resources_data.sort(key=lambda x: x["error_rate"], reverse=True) + resources_data = resources_data[:limit] + + return {"resources": resources_data, "time_range_hours": hours} + finally: + db.close() + + +@admin_router.get("/observability/resources/partial", response_class=HTMLResponse) +async def get_resources_partial( + request: Request, + _user=Depends(get_current_user_with_permissions), +): + """Render the resource fetch metrics dashboard HTML partial. + + Args: + request: FastAPI request object + _user: Authenticated user (required by dependency) + + Returns: + HTMLResponse: Rendered resource metrics dashboard partial + """ + root_path = request.scope.get("root_path", "") + return request.app.state.templates.TemplateResponse( + "observability_resources.html", + { + "request": request, + "root_path": root_path, + }, + ) diff --git a/mcpgateway/alembic/versions/a23a08d61eb0_add_observability_tables.py b/mcpgateway/alembic/versions/a23a08d61eb0_add_observability_tables.py new file mode 100644 index 000000000..cbf13076d --- /dev/null +++ b/mcpgateway/alembic/versions/a23a08d61eb0_add_observability_tables.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- +"""add_observability_tables + +Revision ID: a23a08d61eb0 +Revises: a706a3320c56 +Create Date: 2025-11-05 02:37:14.539024 + +""" + +# Standard +from typing import Sequence, Union + +# Third-Party +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "a23a08d61eb0" +down_revision: Union[str, Sequence[str], None] = "a706a3320c56" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema - add observability tables.""" + + # Create observability_traces table + op.create_table( + "observability_traces", + sa.Column("trace_id", sa.String(length=36), nullable=False), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("start_time", sa.DateTime(timezone=True), nullable=False), + sa.Column("end_time", sa.DateTime(timezone=True), nullable=True), + sa.Column("duration_ms", sa.Float(), nullable=True), + sa.Column("status", sa.String(length=20), nullable=False, server_default="unset"), + sa.Column("status_message", sa.Text(), nullable=True), + sa.Column("http_method", sa.String(length=10), nullable=True), + sa.Column("http_url", sa.String(length=767), nullable=True), + sa.Column("http_status_code", sa.Integer(), nullable=True), + sa.Column("user_email", sa.String(length=255), nullable=True), + sa.Column("user_agent", sa.Text(), nullable=True), + sa.Column("ip_address", sa.String(length=45), nullable=True), + sa.Column("attributes", sa.JSON(), nullable=True), + sa.Column("resource_attributes", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("trace_id"), + ) + op.create_index("idx_observability_traces_start_time", "observability_traces", ["start_time"]) + op.create_index("idx_observability_traces_user_email", "observability_traces", ["user_email"]) + op.create_index("idx_observability_traces_status", "observability_traces", ["status"]) + op.create_index("idx_observability_traces_http_status_code", "observability_traces", ["http_status_code"]) + + # Create observability_spans table + op.create_table( + "observability_spans", + sa.Column("span_id", sa.String(length=36), nullable=False), + sa.Column("trace_id", sa.String(length=36), nullable=False), + sa.Column("parent_span_id", sa.String(length=36), nullable=True), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("kind", sa.String(length=20), nullable=False, server_default="internal"), + sa.Column("start_time", sa.DateTime(timezone=True), nullable=False), + sa.Column("end_time", sa.DateTime(timezone=True), nullable=True), + sa.Column("duration_ms", sa.Float(), nullable=True), + sa.Column("status", sa.String(length=20), nullable=False, server_default="unset"), + sa.Column("status_message", sa.Text(), nullable=True), + sa.Column("attributes", sa.JSON(), nullable=True), + sa.Column("resource_name", sa.String(length=255), nullable=True), + sa.Column("resource_type", sa.String(length=50), nullable=True), + sa.Column("resource_id", sa.String(length=36), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint(["trace_id"], ["observability_traces.trace_id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["parent_span_id"], ["observability_spans.span_id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("span_id"), + ) + op.create_index("idx_observability_spans_trace_id", "observability_spans", ["trace_id"]) + op.create_index("idx_observability_spans_parent_span_id", "observability_spans", ["parent_span_id"]) + op.create_index("idx_observability_spans_start_time", "observability_spans", ["start_time"]) + op.create_index("idx_observability_spans_resource_type", "observability_spans", ["resource_type"]) + op.create_index("idx_observability_spans_resource_name", "observability_spans", ["resource_name"]) + + # Create observability_events table + op.create_table( + "observability_events", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("span_id", sa.String(length=36), nullable=False), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False), + sa.Column("attributes", sa.JSON(), nullable=True), + sa.Column("severity", sa.String(length=20), nullable=True), + sa.Column("message", sa.Text(), nullable=True), + sa.Column("exception_type", sa.String(length=255), nullable=True), + sa.Column("exception_message", sa.Text(), nullable=True), + sa.Column("exception_stacktrace", sa.Text(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint(["span_id"], ["observability_spans.span_id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("idx_observability_events_span_id", "observability_events", ["span_id"]) + op.create_index("idx_observability_events_timestamp", "observability_events", ["timestamp"]) + op.create_index("idx_observability_events_severity", "observability_events", ["severity"]) + + # Create observability_metrics table + op.create_table( + "observability_metrics", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("metric_type", sa.String(length=20), nullable=False), + sa.Column("value", sa.Float(), nullable=False), + sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False), + sa.Column("unit", sa.String(length=20), nullable=True), + sa.Column("attributes", sa.JSON(), nullable=True), + sa.Column("resource_type", sa.String(length=50), nullable=True), + sa.Column("resource_id", sa.String(length=36), nullable=True), + sa.Column("trace_id", sa.String(length=36), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint(["trace_id"], ["observability_traces.trace_id"], ondelete="SET NULL"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("idx_observability_metrics_name_timestamp", "observability_metrics", ["name", "timestamp"]) + op.create_index("idx_observability_metrics_resource_type", "observability_metrics", ["resource_type"]) + op.create_index("idx_observability_metrics_trace_id", "observability_metrics", ["trace_id"]) + + +def downgrade() -> None: + """Downgrade schema - remove observability tables.""" + op.drop_index("idx_observability_metrics_trace_id", table_name="observability_metrics") + op.drop_index("idx_observability_metrics_resource_type", table_name="observability_metrics") + op.drop_index("idx_observability_metrics_name_timestamp", table_name="observability_metrics") + op.drop_table("observability_metrics") + + op.drop_index("idx_observability_events_severity", table_name="observability_events") + op.drop_index("idx_observability_events_timestamp", table_name="observability_events") + op.drop_index("idx_observability_events_span_id", table_name="observability_events") + op.drop_table("observability_events") + + op.drop_index("idx_observability_spans_resource_name", table_name="observability_spans") + op.drop_index("idx_observability_spans_resource_type", table_name="observability_spans") + op.drop_index("idx_observability_spans_start_time", table_name="observability_spans") + op.drop_index("idx_observability_spans_parent_span_id", table_name="observability_spans") + op.drop_index("idx_observability_spans_trace_id", table_name="observability_spans") + op.drop_table("observability_spans") + + op.drop_index("idx_observability_traces_http_status_code", table_name="observability_traces") + op.drop_index("idx_observability_traces_status", table_name="observability_traces") + op.drop_index("idx_observability_traces_user_email", table_name="observability_traces") + op.drop_index("idx_observability_traces_start_time", table_name="observability_traces") + op.drop_table("observability_traces") diff --git a/mcpgateway/alembic/versions/i3c4d5e6f7g8_add_observability_performance_indexes.py b/mcpgateway/alembic/versions/i3c4d5e6f7g8_add_observability_performance_indexes.py new file mode 100644 index 000000000..2d5c9df05 --- /dev/null +++ b/mcpgateway/alembic/versions/i3c4d5e6f7g8_add_observability_performance_indexes.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +"""add observability performance indexes + +Revision ID: i3c4d5e6f7g8 +Revises: a23a08d61eb0 +Create Date: 2025-01-05 12:00:00.000000 + +""" + +# Standard +from typing import Sequence, Union + +# Third-Party +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "i3c4d5e6f7g8" +down_revision: Union[str, Sequence[str], None] = "a23a08d61eb0" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add performance indexes for observability tables. + + These composite indexes optimize common query patterns: + - Filtering by time range WITH status (composite) + - Filtering by duration (new) + - Filtering by HTTP method WITH time (composite) + - Filtering by resource type WITH time (composite) + - Filtering by span kind WITH status (composite) + + Note: Basic indexes (status, start_time, resource_type, etc.) already exist + from the initial migration. This migration adds COMPOSITE indexes only. + """ + # ObservabilityTrace composite indexes + op.create_index("ix_observability_traces_status_start_time", "observability_traces", ["status", "start_time"]) + op.create_index("ix_observability_traces_duration_ms", "observability_traces", ["duration_ms"]) + op.create_index("ix_observability_traces_http_method_start_time", "observability_traces", ["http_method", "start_time"]) + op.create_index("ix_observability_traces_name", "observability_traces", ["name"]) + + # ObservabilitySpan composite indexes + op.create_index("ix_observability_spans_trace_id_start_time", "observability_spans", ["trace_id", "start_time"]) + op.create_index("ix_observability_spans_resource_type_start_time", "observability_spans", ["resource_type", "start_time"]) + op.create_index("ix_observability_spans_kind_status", "observability_spans", ["kind", "status"]) + op.create_index("ix_observability_spans_duration_ms", "observability_spans", ["duration_ms"]) + op.create_index("ix_observability_spans_name", "observability_spans", ["name"]) + + # ObservabilityEvent composite index + op.create_index("ix_observability_events_span_id_timestamp", "observability_events", ["span_id", "timestamp"]) + + +def downgrade() -> None: + """Remove observability performance indexes.""" + # Drop ObservabilityEvent composite index + op.drop_index("ix_observability_events_span_id_timestamp", table_name="observability_events") + + # Drop ObservabilitySpan composite indexes + op.drop_index("ix_observability_spans_name", table_name="observability_spans") + op.drop_index("ix_observability_spans_duration_ms", table_name="observability_spans") + op.drop_index("ix_observability_spans_kind_status", table_name="observability_spans") + op.drop_index("ix_observability_spans_resource_type_start_time", table_name="observability_spans") + op.drop_index("ix_observability_spans_trace_id_start_time", table_name="observability_spans") + + # Drop ObservabilityTrace composite indexes + op.drop_index("ix_observability_traces_name", table_name="observability_traces") + op.drop_index("ix_observability_traces_http_method_start_time", table_name="observability_traces") + op.drop_index("ix_observability_traces_duration_ms", table_name="observability_traces") + op.drop_index("ix_observability_traces_status_start_time", table_name="observability_traces") diff --git a/mcpgateway/alembic/versions/j4d5e6f7g8h9_add_observability_saved_queries.py b/mcpgateway/alembic/versions/j4d5e6f7g8h9_add_observability_saved_queries.py new file mode 100644 index 000000000..bf68ae867 --- /dev/null +++ b/mcpgateway/alembic/versions/j4d5e6f7g8h9_add_observability_saved_queries.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +"""add observability saved queries + +Revision ID: j4d5e6f7g8h9 +Revises: i3c4d5e6f7g8 +Create Date: 2025-01-06 12:00:00.000000 + +""" + +# Standard +from typing import Sequence, Union + +# Third-Party +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "j4d5e6f7g8h9" +down_revision: Union[str, Sequence[str], None] = "i3c4d5e6f7g8" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add observability_saved_queries table for storing filter presets.""" + op.create_table( + "observability_saved_queries", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("user_email", sa.String(length=255), nullable=False), + sa.Column("filter_config", sa.JSON(), nullable=False), + sa.Column("is_shared", sa.Boolean(), nullable=False, server_default=sa.false()), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), + sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("use_count", sa.Integer(), nullable=False, server_default=sa.text("0")), + sa.PrimaryKeyConstraint("id"), + ) + + # Create indexes for performance + op.create_index("idx_observability_saved_queries_user_email", "observability_saved_queries", ["user_email"]) + op.create_index("idx_observability_saved_queries_is_shared", "observability_saved_queries", ["is_shared"]) + op.create_index("idx_observability_saved_queries_created_at", "observability_saved_queries", ["created_at"]) + op.create_index("ix_observability_saved_queries_name", "observability_saved_queries", ["name"]) + + +def downgrade() -> None: + """Remove observability_saved_queries table.""" + op.drop_index("ix_observability_saved_queries_name", table_name="observability_saved_queries") + op.drop_index("idx_observability_saved_queries_created_at", table_name="observability_saved_queries") + op.drop_index("idx_observability_saved_queries_is_shared", table_name="observability_saved_queries") + op.drop_index("idx_observability_saved_queries_user_email", table_name="observability_saved_queries") + op.drop_table("observability_saved_queries") diff --git a/mcpgateway/config.py b/mcpgateway/config.py index 24d1148b5..9856358d4 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -716,6 +716,34 @@ def _parse_allowed_origins(cls, v: Any) -> Set[str]: # Log Buffer (for in-memory storage in admin UI) log_buffer_size_mb: float = 1.0 # Size of in-memory log buffer in MB + # =================================== + # Observability Configuration + # =================================== + + # Enable observability features (traces, spans, metrics) + observability_enabled: bool = Field(default=False, description="Enable observability tracing and metrics collection") + + # Automatic HTTP request tracing + observability_trace_http_requests: bool = Field(default=True, description="Automatically trace HTTP requests") + + # Trace retention period (days) + observability_trace_retention_days: int = Field(default=7, ge=1, description="Number of days to retain trace data") + + # Maximum traces to store (prevents unbounded growth) + observability_max_traces: int = Field(default=100000, ge=1000, description="Maximum number of traces to retain") + + # Sample rate (0.0 to 1.0) - 1.0 means trace everything + observability_sample_rate: float = Field(default=1.0, ge=0.0, le=1.0, description="Trace sampling rate (0.0-1.0)") + + # Exclude paths from tracing (regex patterns) + observability_exclude_paths: List[str] = Field(default_factory=lambda: ["/health", "/healthz", "/ready", "/metrics", "/static/.*"], description="Paths to exclude from tracing (regex)") + + # Enable performance metrics + observability_metrics_enabled: bool = Field(default=True, description="Enable metrics collection") + + # Enable span events + observability_events_enabled: bool = Field(default=True, description="Enable event logging within spans") + @field_validator("log_level", mode="before") @classmethod def validate_log_level(cls, v: str) -> str: diff --git a/mcpgateway/db.py b/mcpgateway/db.py index 087b3936e..a6450a394 100644 --- a/mcpgateway/db.py +++ b/mcpgateway/db.py @@ -122,6 +122,17 @@ connect_args=connect_args, ) +# Initialize SQLAlchemy instrumentation for observability +if settings.observability_enabled: + try: + # First-Party + from mcpgateway.instrumentation import instrument_sqlalchemy + + instrument_sqlalchemy(engine) + logger.info("SQLAlchemy instrumentation enabled for observability") + except ImportError: + logger.warning("Failed to import SQLAlchemy instrumentation") + # --------------------------------------------------------------------------- # 6. Function to return UTC timestamp @@ -164,10 +175,12 @@ def set_sqlite_pragma(dbapi_conn, _connection_record): cursor = dbapi_conn.cursor() # Enable WAL mode for better concurrency cursor.execute("PRAGMA journal_mode=WAL") - # Set busy timeout to 10 seconds (10000 ms) to handle lock contention - cursor.execute("PRAGMA busy_timeout=10000") + # Set busy timeout to 30 seconds (30000 ms) to handle lock contention from observability + cursor.execute("PRAGMA busy_timeout=30000") # Synchronous=NORMAL is safe with WAL mode and improves performance cursor.execute("PRAGMA synchronous=NORMAL") + # Increase cache size for better performance (negative value = KB) + cursor.execute("PRAGMA cache_size=-64000") # 64MB cache cursor.close() @@ -1528,6 +1541,309 @@ class A2AAgentMetric(Base): a2a_agent: Mapped["A2AAgent"] = relationship("A2AAgent", back_populates="metrics") +# =================================== +# Observability Models (OpenTelemetry-style traces, spans, events) +# =================================== + + +class ObservabilityTrace(Base): + """ + ORM model for observability traces (similar to OpenTelemetry traces). + + A trace represents a complete request flow through the system. It contains + one or more spans representing individual operations. + + Attributes: + trace_id (str): Unique trace identifier (UUID or OpenTelemetry trace ID format). + name (str): Human-readable name for the trace (e.g., "POST /tools/invoke"). + start_time (datetime): When the trace started. + end_time (datetime): When the trace ended (optional, set when completed). + duration_ms (float): Total duration in milliseconds. + status (str): Trace status (success, error, timeout). + status_message (str): Optional status message or error description. + http_method (str): HTTP method for the request (GET, POST, etc.). + http_url (str): Full URL of the request. + http_status_code (int): HTTP response status code. + user_email (str): User who initiated the request (if authenticated). + user_agent (str): Client user agent string. + ip_address (str): Client IP address. + attributes (dict): Additional trace attributes (JSON). + resource_attributes (dict): Resource attributes (service name, version, etc.). + created_at (datetime): Trace creation timestamp. + """ + + __tablename__ = "observability_traces" + + # Primary key + trace_id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + + # Trace metadata + name: Mapped[str] = mapped_column(String(255), nullable=False) + start_time: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True) + end_time: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + duration_ms: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + status: Mapped[str] = mapped_column(String(20), nullable=False, default="unset") # unset, ok, error + status_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # HTTP request context + http_method: Mapped[Optional[str]] = mapped_column(String(10), nullable=True) + http_url: Mapped[Optional[str]] = mapped_column(String(767), nullable=True) + http_status_code: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + + # User context + user_email: Mapped[Optional[str]] = mapped_column(String(255), nullable=True, index=True) + user_agent: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + ip_address: Mapped[Optional[str]] = mapped_column(String(45), nullable=True) + + # Attributes (flexible key-value storage) + attributes: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True, default=dict) + resource_attributes: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True, default=dict) + + # Timestamps + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utc_now, nullable=False) + + # Relationships + spans: Mapped[List["ObservabilitySpan"]] = relationship("ObservabilitySpan", back_populates="trace", cascade="all, delete-orphan") + + # Indexes for performance + __table_args__ = ( + Index("idx_observability_traces_start_time", "start_time"), + Index("idx_observability_traces_user_email", "user_email"), + Index("idx_observability_traces_status", "status"), + Index("idx_observability_traces_http_status_code", "http_status_code"), + ) + + +class ObservabilitySpan(Base): + """ + ORM model for observability spans (similar to OpenTelemetry spans). + + A span represents a single operation within a trace. Spans can be nested + to represent hierarchical operations. + + Attributes: + span_id (str): Unique span identifier. + trace_id (str): Parent trace ID. + parent_span_id (str): Parent span ID (for nested spans). + name (str): Span name (e.g., "database_query", "tool_invocation"). + kind (str): Span kind (internal, server, client, producer, consumer). + start_time (datetime): When the span started. + end_time (datetime): When the span ended. + duration_ms (float): Span duration in milliseconds. + status (str): Span status (success, error). + status_message (str): Optional status message. + attributes (dict): Span attributes (JSON). + resource_name (str): Name of the resource being operated on. + resource_type (str): Type of resource (tool, resource, prompt, gateway, etc.). + resource_id (str): ID of the specific resource. + created_at (datetime): Span creation timestamp. + """ + + __tablename__ = "observability_spans" + + # Primary key + span_id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + + # Trace relationship + trace_id: Mapped[str] = mapped_column(String(36), ForeignKey("observability_traces.trace_id", ondelete="CASCADE"), nullable=False, index=True) + parent_span_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("observability_spans.span_id", ondelete="CASCADE"), nullable=True, index=True) + + # Span metadata + name: Mapped[str] = mapped_column(String(255), nullable=False) + kind: Mapped[str] = mapped_column(String(20), nullable=False, default="internal") # internal, server, client, producer, consumer + start_time: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True) + end_time: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + duration_ms: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + status: Mapped[str] = mapped_column(String(20), nullable=False, default="unset") + status_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # Attributes + attributes: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True, default=dict) + + # Resource context + resource_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True, index=True) + resource_type: Mapped[Optional[str]] = mapped_column(String(50), nullable=True, index=True) # tool, resource, prompt, gateway, a2a_agent + resource_id: Mapped[Optional[str]] = mapped_column(String(36), nullable=True, index=True) + + # Timestamps + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utc_now, nullable=False) + + # Relationships + trace: Mapped["ObservabilityTrace"] = relationship("ObservabilityTrace", back_populates="spans") + parent_span: Mapped[Optional["ObservabilitySpan"]] = relationship("ObservabilitySpan", remote_side=[span_id], backref="child_spans") + events: Mapped[List["ObservabilityEvent"]] = relationship("ObservabilityEvent", back_populates="span", cascade="all, delete-orphan") + + # Indexes for performance + __table_args__ = ( + Index("idx_observability_spans_trace_id", "trace_id"), + Index("idx_observability_spans_parent_span_id", "parent_span_id"), + Index("idx_observability_spans_start_time", "start_time"), + Index("idx_observability_spans_resource_type", "resource_type"), + Index("idx_observability_spans_resource_name", "resource_name"), + ) + + +class ObservabilityEvent(Base): + """ + ORM model for observability events (logs within spans). + + Events represent discrete occurrences within a span, such as log messages, + exceptions, or state changes. + + Attributes: + id (int): Auto-incrementing primary key. + span_id (str): Parent span ID. + name (str): Event name (e.g., "exception", "log", "checkpoint"). + timestamp (datetime): When the event occurred. + attributes (dict): Event attributes (JSON). + severity (str): Log severity level (debug, info, warning, error, critical). + message (str): Event message. + exception_type (str): Exception class name (if event is an exception). + exception_message (str): Exception message. + exception_stacktrace (str): Exception stacktrace. + created_at (datetime): Event creation timestamp. + """ + + __tablename__ = "observability_events" + + # Primary key + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + + # Span relationship + span_id: Mapped[str] = mapped_column(String(36), ForeignKey("observability_spans.span_id", ondelete="CASCADE"), nullable=False, index=True) + + # Event metadata + name: Mapped[str] = mapped_column(String(255), nullable=False) + timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=utc_now, index=True) + attributes: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True, default=dict) + + # Log fields + severity: Mapped[Optional[str]] = mapped_column(String(20), nullable=True, index=True) # debug, info, warning, error, critical + message: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # Exception fields + exception_type: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + exception_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + exception_stacktrace: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # Timestamps + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utc_now, nullable=False) + + # Relationships + span: Mapped["ObservabilitySpan"] = relationship("ObservabilitySpan", back_populates="events") + + # Indexes for performance + __table_args__ = ( + Index("idx_observability_events_span_id", "span_id"), + Index("idx_observability_events_timestamp", "timestamp"), + Index("idx_observability_events_severity", "severity"), + ) + + +class ObservabilityMetric(Base): + """ + ORM model for observability metrics (time-series numerical data). + + Metrics represent numerical measurements over time, such as request rates, + error rates, latencies, and custom business metrics. + + Attributes: + id (int): Auto-incrementing primary key. + name (str): Metric name (e.g., "http.request.duration", "tool.invocation.count"). + metric_type (str): Metric type (counter, gauge, histogram). + value (float): Metric value. + timestamp (datetime): When the metric was recorded. + unit (str): Metric unit (ms, count, bytes, etc.). + attributes (dict): Metric attributes/labels (JSON). + resource_type (str): Type of resource (tool, resource, prompt, etc.). + resource_id (str): ID of the specific resource. + trace_id (str): Associated trace ID (optional). + created_at (datetime): Metric creation timestamp. + """ + + __tablename__ = "observability_metrics" + + # Primary key + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + + # Metric metadata + name: Mapped[str] = mapped_column(String(255), nullable=False, index=True) + metric_type: Mapped[str] = mapped_column(String(20), nullable=False) # counter, gauge, histogram + value: Mapped[float] = mapped_column(Float, nullable=False) + timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=utc_now, index=True) + unit: Mapped[Optional[str]] = mapped_column(String(20), nullable=True) + + # Attributes/labels + attributes: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True, default=dict) + + # Resource context + resource_type: Mapped[Optional[str]] = mapped_column(String(50), nullable=True, index=True) + resource_id: Mapped[Optional[str]] = mapped_column(String(36), nullable=True, index=True) + + # Trace association (optional) + trace_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("observability_traces.trace_id", ondelete="SET NULL"), nullable=True, index=True) + + # Timestamps + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utc_now, nullable=False) + + # Indexes for performance + __table_args__ = ( + Index("idx_observability_metrics_name_timestamp", "name", "timestamp"), + Index("idx_observability_metrics_resource_type", "resource_type"), + Index("idx_observability_metrics_trace_id", "trace_id"), + ) + + +class ObservabilitySavedQuery(Base): + """ + ORM model for saved observability queries (filter presets). + + Allows users to save their filter configurations for quick access and + historical query tracking. Queries can be personal or shared with the team. + + Attributes: + id (int): Auto-incrementing primary key. + name (str): User-given name for the saved query. + description (str): Optional description of what this query finds. + user_email (str): Email of the user who created this query. + filter_config (dict): JSON containing all filter values (time_range, status_filter, etc.). + is_shared (bool): Whether this query is visible to other users. + created_at (datetime): When the query was created. + updated_at (datetime): When the query was last modified. + last_used_at (datetime): When the query was last executed. + use_count (int): How many times this query has been used. + """ + + __tablename__ = "observability_saved_queries" + + # Primary key + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + + # Query metadata + name: Mapped[str] = mapped_column(String(255), nullable=False, index=True) + description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + user_email: Mapped[str] = mapped_column(String(255), nullable=False, index=True) + + # Filter configuration (stored as JSON) + filter_config: Mapped[Dict[str, Any]] = mapped_column(JSON, nullable=False) + + # Sharing settings + is_shared: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + + # Timestamps and usage tracking + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utc_now, nullable=False) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utc_now, onupdate=utc_now, nullable=False) + last_used_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + use_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False) + + # Indexes for performance + __table_args__ = ( + Index("idx_observability_saved_queries_user_email", "user_email"), + Index("idx_observability_saved_queries_is_shared", "is_shared"), + Index("idx_observability_saved_queries_created_at", "created_at"), + ) + + class Tool(Base): """ ORM model for a registered Tool. diff --git a/mcpgateway/instrumentation/__init__.py b/mcpgateway/instrumentation/__init__.py new file mode 100644 index 000000000..f67ce406c --- /dev/null +++ b/mcpgateway/instrumentation/__init__.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/instrumentation/__init__.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti + +Automatic instrumentation for observability. + +This module provides automatic instrumentation for common libraries: +- SQLAlchemy database queries +- HTTP clients (future) +- Redis operations (future) +""" + +# pylint: disable=cyclic-import +# Cyclic import is intentional and broken by lazy imports in sqlalchemy.py +from mcpgateway.instrumentation.sqlalchemy import instrument_sqlalchemy + +__all__ = ["instrument_sqlalchemy"] diff --git a/mcpgateway/instrumentation/sqlalchemy.py b/mcpgateway/instrumentation/sqlalchemy.py new file mode 100644 index 000000000..c66c102ce --- /dev/null +++ b/mcpgateway/instrumentation/sqlalchemy.py @@ -0,0 +1,317 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/instrumentation/sqlalchemy.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti + +Automatic instrumentation for SQLAlchemy database queries. + +This module instruments SQLAlchemy to automatically capture database +queries as observability spans, providing visibility into database +performance. + +Examples: + >>> from mcpgateway.instrumentation import instrument_sqlalchemy # doctest: +SKIP + >>> instrument_sqlalchemy(engine) # doctest: +SKIP +""" + +# Standard +import logging +import queue +import threading +import time +from typing import Any, Optional + +# Third-Party +from sqlalchemy import event +from sqlalchemy.engine import Connection, Engine + +logger = logging.getLogger(__name__) + +# Thread-local storage for tracking queries in progress +_query_tracking = {} + +# Thread-local flag to prevent recursive instrumentation +_instrumentation_context = threading.local() + +# Background queue for deferred span writes to avoid database locks +_span_queue: queue.Queue = queue.Queue(maxsize=1000) +_span_writer_thread: Optional[threading.Thread] = None +_shutdown_event = threading.Event() + + +def _write_span_to_db(span_data: dict) -> None: + """Write a single span to the database. + + Args: + span_data: Dictionary containing span information + """ + try: + # Import here to avoid circular imports + # First-Party + # pylint: disable=import-outside-toplevel + from mcpgateway.db import ObservabilitySpan, SessionLocal + from mcpgateway.services.observability_service import ObservabilityService + + # pylint: enable=import-outside-toplevel + + service = ObservabilityService() + db = SessionLocal() + try: + span_id = service.start_span( + db=db, + trace_id=span_data["trace_id"], + name=span_data["name"], + kind=span_data["kind"], + resource_type=span_data["resource_type"], + resource_name=span_data["resource_name"], + attributes=span_data["start_attributes"], + ) + + # End span with measured duration in attributes + service.end_span( + db=db, + span_id=span_id, + status=span_data["status"], + attributes=span_data["end_attributes"], + ) + + # Update the span duration to match what we actually measured + span = db.query(ObservabilitySpan).filter_by(span_id=span_id).first() + if span: + span.duration_ms = span_data["duration_ms"] + db.commit() + + logger.debug(f"Created span for {span_data['resource_name']} query: " f"{span_data['duration_ms']:.2f}ms, {span_data.get('row_count')} rows") + + finally: + db.close() + + except Exception as e: # pylint: disable=broad-except + # Don't fail if span creation fails + logger.warning(f"Failed to write query span: {e}") + + +def _span_writer_worker() -> None: + """Background worker thread that writes spans to the database. + + This runs in a separate thread to avoid blocking the main request thread + and to prevent database lock contention. + """ + logger.info("Span writer worker started") + + while not _shutdown_event.is_set(): + try: + # Wait for span data with timeout to allow checking shutdown + try: + span_data = _span_queue.get(timeout=1.0) + except queue.Empty: + continue + + # Write the span to the database + _write_span_to_db(span_data) + _span_queue.task_done() + + except Exception as e: # pylint: disable=broad-except + logger.error(f"Error in span writer worker: {e}") + # Continue processing even if one span fails + + logger.info("Span writer worker stopped") + + +def instrument_sqlalchemy(engine: Engine) -> None: + """Instrument a SQLAlchemy engine to capture query spans. + + Args: + engine: SQLAlchemy engine to instrument + + Examples: + >>> from sqlalchemy import create_engine # doctest: +SKIP + >>> engine = create_engine("sqlite:///./mcp.db") # doctest: +SKIP + >>> instrument_sqlalchemy(engine) # doctest: +SKIP + """ + global _span_writer_thread # pylint: disable=global-statement + + # Register event listeners + event.listen(engine, "before_cursor_execute", _before_cursor_execute) + event.listen(engine, "after_cursor_execute", _after_cursor_execute) + + # Start background span writer thread if not already running + if _span_writer_thread is None or not _span_writer_thread.is_alive(): + _span_writer_thread = threading.Thread(target=_span_writer_worker, name="SpanWriterThread", daemon=True) + _span_writer_thread.start() + logger.info("Started background span writer thread") + + logger.info("SQLAlchemy instrumentation enabled") + + +def _before_cursor_execute( + conn: Connection, + _cursor: Any, + statement: str, + parameters: Any, + _context: Any, + executemany: bool, +) -> None: + """Event handler called before SQL query execution. + + Args: + conn: Database connection + _cursor: Database cursor (required by SQLAlchemy event API) + statement: SQL statement + parameters: Query parameters + _context: Execution context (required by SQLAlchemy event API) + executemany: Whether this is a bulk execution + """ + # Store start time for this query + conn_id = id(conn) + _query_tracking[conn_id] = { + "start_time": time.time(), + "statement": statement, + "parameters": parameters, + "executemany": executemany, + } + + +def _after_cursor_execute( + conn: Connection, + cursor: Any, + statement: str, + _parameters: Any, + _context: Any, + executemany: bool, +) -> None: + """Event handler called after SQL query execution. + + Args: + conn: Database connection + cursor: Database cursor + statement: SQL statement + _parameters: Query parameters (required by SQLAlchemy event API) + _context: Execution context (required by SQLAlchemy event API) + executemany: Whether this is a bulk execution + """ + conn_id = id(conn) + tracking = _query_tracking.pop(conn_id, None) + + if not tracking: + return + + # Skip instrumentation if we're already inside span creation (prevent recursion) + if getattr(_instrumentation_context, "inside_span_creation", False): + return + + # Skip instrumentation for observability tables to prevent recursion and lock issues + statement_upper = statement.upper() + if any(table in statement_upper for table in ["OBSERVABILITY_TRACES", "OBSERVABILITY_SPANS", "OBSERVABILITY_EVENTS", "OBSERVABILITY_METRICS"]): + logger.debug(f"Skipping instrumentation for observability table query: {statement[:100]}...") + return + + # Calculate query duration + duration_ms = (time.time() - tracking["start_time"]) * 1000 + + # Get row count if available + row_count = None + try: + if hasattr(cursor, "rowcount") and cursor.rowcount >= 0: + row_count = cursor.rowcount + except Exception: # pylint: disable=broad-except # nosec B110 - row_count is optional metadata + pass + + # Try to get trace context from connection info + trace_id = None + if hasattr(conn, "info") and "trace_id" in conn.info: + trace_id = conn.info["trace_id"] + + # If we have a trace_id, create a span + if trace_id: + _create_query_span( + trace_id=trace_id, + statement=statement, + duration_ms=duration_ms, + row_count=row_count, + executemany=executemany, + ) + else: + # Log for debugging but don't fail + logger.debug(f"Query executed without trace context: {statement[:100]}... ({duration_ms:.2f}ms)") + + +def _create_query_span( + trace_id: str, + statement: str, + duration_ms: float, + row_count: Optional[int], + executemany: bool, +) -> None: + """Create an observability span for a database query. + + This function enqueues span data to be written by a background thread, + avoiding database lock contention. + + Args: + trace_id: Parent trace ID + statement: SQL statement + duration_ms: Query duration in milliseconds + row_count: Number of rows affected/returned + executemany: Whether this is a bulk execution + """ + try: + # Extract query type (SELECT, INSERT, UPDATE, DELETE, etc.) + query_type = statement.strip().split()[0].upper() if statement else "UNKNOWN" + + # Truncate long queries for span name + span_name = f"db.query.{query_type.lower()}" + + # Prepare span data + span_data = { + "trace_id": trace_id, + "name": span_name, + "kind": "client", + "resource_type": "database", + "resource_name": query_type, + "duration_ms": duration_ms, + "status": "ok", + "start_attributes": { + "db.statement": statement[:500], # Truncate long queries + "db.operation": query_type, + "db.executemany": executemany, + "db.duration_measured_ms": duration_ms, # Store actual measured duration + }, + "end_attributes": { + "db.row_count": row_count, + }, + "row_count": row_count, + } + + # Enqueue for background processing (non-blocking) + try: + _span_queue.put_nowait(span_data) + logger.debug(f"Enqueued span for {query_type} query: {duration_ms:.2f}ms") + except queue.Full: + logger.warning("Span queue is full, dropping span data") + + except Exception as e: # pylint: disable=broad-except + # Don't fail the query if span creation fails + logger.debug(f"Failed to enqueue query span: {e}") + + +def attach_trace_to_session(session: Any, trace_id: str) -> None: + """Attach a trace ID to a database session. + + This allows the instrumentation to correlate queries with traces. + + Args: + session: SQLAlchemy session + trace_id: Trace ID to attach + + Examples: + >>> from mcpgateway.db import SessionLocal # doctest: +SKIP + >>> db = SessionLocal() # doctest: +SKIP + >>> attach_trace_to_session(db, trace_id) # doctest: +SKIP + """ + if hasattr(session, "bind") and session.bind: + # Get a connection and attach trace_id to its info dict + connection = session.connection() + if hasattr(connection, "info"): + connection.info["trace_id"] = trace_id diff --git a/mcpgateway/main.py b/mcpgateway/main.py index a84fd66df..559b6ad0a 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -1071,6 +1071,26 @@ async def _call_streamable_http(self, scope, receive, send): if settings.log_requests: app.add_middleware(RequestLoggingMiddleware, log_requests=settings.log_requests, log_level=settings.log_level, max_body_size=settings.log_max_size_mb * 1024 * 1024) # Convert MB to bytes +# Add observability middleware if enabled +# Note: Middleware runs in REVERSE order (last added runs first) +# We add ObservabilityMiddleware first so it wraps AuthContextMiddleware +# Execution order will be: AuthContext -> Observability -> Request Handler +if settings.observability_enabled: + # First-Party + from mcpgateway.middleware.observability_middleware import ObservabilityMiddleware + + app.add_middleware(ObservabilityMiddleware, enabled=True) + logger.info("πŸ” Observability middleware enabled - tracing all HTTP requests") + + # Add authentication context middleware (runs BEFORE observability in execution) + # First-Party + from mcpgateway.middleware.auth_middleware import AuthContextMiddleware + + app.add_middleware(AuthContextMiddleware) + logger.info("πŸ” Authentication context middleware enabled - extracting user info for observability") +else: + logger.info("πŸ” Observability middleware disabled") + # Set up Jinja2 templates and store in app state for later use templates = Jinja2Templates(directory=str(settings.templates_dir)) app.state.templates = templates @@ -4669,6 +4689,16 @@ async def cleanup_import_statuses(max_age_hours: int = 24, user=Depends(get_curr app.include_router(tag_router) app.include_router(export_import_router) +# Conditionally include observability router if enabled +if settings.observability_enabled: + # First-Party + from mcpgateway.routers.observability import router as observability_router + + app.include_router(observability_router) + logger.info("Observability router included - observability API endpoints enabled") +else: + logger.info("Observability router not included - observability disabled") + # Conditionally include A2A router if A2A features are enabled if settings.mcpgateway_a2a_enabled: app.include_router(a2a_router) diff --git a/mcpgateway/middleware/auth_middleware.py b/mcpgateway/middleware/auth_middleware.py new file mode 100644 index 000000000..a8868ccbe --- /dev/null +++ b/mcpgateway/middleware/auth_middleware.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/middleware/auth_middleware.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti + +Authentication Middleware for early user context extraction. + +This middleware extracts user information from JWT tokens early in the request +lifecycle and stores it in request.state.user for use by other middleware +(like ObservabilityMiddleware) and route handlers. + +Examples: + >>> from mcpgateway.middleware.auth_middleware import AuthContextMiddleware # doctest: +SKIP + >>> app.add_middleware(AuthContextMiddleware) # doctest: +SKIP +""" + +# Standard +import logging +from typing import Callable + +# Third-Party +from fastapi.security import HTTPAuthorizationCredentials +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + +# First-Party +from mcpgateway.auth import get_current_user +from mcpgateway.db import SessionLocal + +logger = logging.getLogger(__name__) + + +class AuthContextMiddleware(BaseHTTPMiddleware): + """Middleware for extracting user authentication context early in request lifecycle. + + This middleware attempts to authenticate requests using JWT tokens from cookies + or Authorization headers, and stores the user information in request.state.user + for downstream middleware and handlers to use. + + Unlike route-level authentication dependencies, this runs for ALL requests, + allowing middleware like ObservabilityMiddleware to access user context. + + Note: + Authentication failures are silent - requests continue as unauthenticated. + Route-level dependencies should still enforce authentication requirements. + """ + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + """Process request and populate user context if authenticated. + + Args: + request: Incoming HTTP request + call_next: Next middleware/handler in chain + + Returns: + HTTP response + """ + # Skip for health checks and static files + if request.url.path in ["/health", "/healthz", "/ready", "/metrics"] or request.url.path.startswith("/static/"): + return await call_next(request) + + # Try to extract token from multiple sources + token = None + + # 1. Try manual cookie reading + if request.cookies: + token = request.cookies.get("jwt_token") or request.cookies.get("access_token") + + # 2. Try Authorization header + if not token: + auth_header = request.headers.get("authorization") + if auth_header and auth_header.startswith("Bearer "): + token = auth_header.replace("Bearer ", "") + + # If no token found, continue without user context + if not token: + return await call_next(request) + + # Try to authenticate and populate user context + db = None + try: + db = SessionLocal() + credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token) + user = await get_current_user(credentials, db) + + # Store user in request state for downstream use + request.state.user = user + logger.info(f"βœ“ Authenticated user for observability: {user.email}") + + except Exception as e: + # Silently fail - let route handlers enforce auth if needed + logger.info(f"βœ— Auth context extraction failed (continuing as anonymous): {e}") + + finally: + # Always close database session + if db: + try: + db.close() + except Exception as close_error: + logger.debug(f"Failed to close database session: {close_error}") + + # Continue with request + return await call_next(request) diff --git a/mcpgateway/middleware/observability_middleware.py b/mcpgateway/middleware/observability_middleware.py new file mode 100644 index 000000000..ef51723a5 --- /dev/null +++ b/mcpgateway/middleware/observability_middleware.py @@ -0,0 +1,209 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/middleware/observability_middleware.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti + +Observability Middleware for automatic request/response tracing. + +This middleware automatically captures HTTP requests and responses as observability traces, +providing comprehensive visibility into all gateway operations. + +Examples: + >>> from mcpgateway.middleware.observability_middleware import ObservabilityMiddleware # doctest: +SKIP + >>> app.add_middleware(ObservabilityMiddleware) # doctest: +SKIP +""" + +# Standard +import logging +import time +import traceback +from typing import Callable + +# Third-Party +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + +# First-Party +from mcpgateway.config import settings +from mcpgateway.db import SessionLocal +from mcpgateway.instrumentation.sqlalchemy import attach_trace_to_session +from mcpgateway.services.observability_service import current_trace_id, ObservabilityService, parse_traceparent + +logger = logging.getLogger(__name__) + + +class ObservabilityMiddleware(BaseHTTPMiddleware): + """Middleware for automatic HTTP request/response tracing. + + Captures every HTTP request as a trace with timing, status codes, + and user context. Automatically creates spans for the request lifecycle. + + This middleware is disabled by default and can be enabled via the + MCPGATEWAY_OBSERVABILITY_ENABLED environment variable. + """ + + def __init__(self, app, enabled: bool = None): + """Initialize the observability middleware. + + Args: + app: ASGI application + enabled: Whether observability is enabled (defaults to settings) + """ + super().__init__(app) + self.enabled = enabled if enabled is not None else getattr(settings, "observability_enabled", False) + self.service = ObservabilityService() + logger.info(f"Observability middleware initialized (enabled={self.enabled})") + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + """Process request and create observability trace. + + Args: + request: Incoming HTTP request + call_next: Next middleware/handler in chain + + Returns: + HTTP response + + Raises: + Exception: Re-raises any exception from request processing after logging + """ + # Skip if observability is disabled + if not self.enabled: + return await call_next(request) + + # Skip health checks and static files to reduce noise + if request.url.path in ["/health", "/healthz", "/ready", "/metrics"] or request.url.path.startswith("/static/"): + return await call_next(request) + + # Extract request context + http_method = request.method + http_url = str(request.url) + user_email = None + ip_address = request.client.host if request.client else None + user_agent = request.headers.get("user-agent") + + # Try to extract user from request state (set by auth middleware) + if hasattr(request.state, "user") and hasattr(request.state.user, "email"): + user_email = request.state.user.email + + # Extract W3C Trace Context from headers (for distributed tracing) + external_trace_id = None + external_parent_span_id = None + traceparent_header = request.headers.get("traceparent") + if traceparent_header: + parsed = parse_traceparent(traceparent_header) + if parsed: + external_trace_id, external_parent_span_id, _flags = parsed + logger.debug(f"Extracted W3C trace context: trace_id={external_trace_id}, parent_span_id={external_parent_span_id}") + + db = None + trace_id = None + span_id = None + start_time = time.time() + + try: + # Create database session + db = SessionLocal() + + # Start trace (use external trace_id if provided for distributed tracing) + trace_id = self.service.start_trace( + db=db, + name=f"{http_method} {request.url.path}", + trace_id=external_trace_id, # Use external trace ID if provided + parent_span_id=external_parent_span_id, # Track parent span from upstream + http_method=http_method, + http_url=http_url, + user_email=user_email, + user_agent=user_agent, + ip_address=ip_address, + attributes={ + "http.route": request.url.path, + "http.query": str(request.url.query) if request.url.query else None, + }, + resource_attributes={ + "service.name": "mcp-gateway", + "service.version": getattr(settings, "version", "unknown"), + }, + ) + + # Store trace_id in request state for use in route handlers + request.state.trace_id = trace_id + + # Set trace_id in context variable for access throughout async call stack + current_trace_id.set(trace_id) + + # Attach trace_id to database session for SQL query instrumentation + attach_trace_to_session(db, trace_id) + + # Start request span + span_id = self.service.start_span(db=db, trace_id=trace_id, name="http.request", kind="server", attributes={"http.method": http_method, "http.url": http_url}) + + except Exception as e: + # If trace setup failed, log and continue without tracing + logger.warning(f"Failed to setup observability trace: {e}") + # Close db if it was created + if db: + try: + db.close() + except Exception as close_error: + logger.debug(f"Failed to close database session during cleanup: {close_error}") + # Continue without tracing + return await call_next(request) + + # Process request (trace is set up at this point) + try: + response = await call_next(request) + status_code = response.status_code + + # End span successfully + if span_id: + self.service.end_span( + db, span_id, status="ok" if status_code < 400 else "error", attributes={"http.status_code": status_code, "http.response_size": response.headers.get("content-length")} + ) + + # End trace + if trace_id: + duration_ms = (time.time() - start_time) * 1000 + self.service.end_trace(db, trace_id, status="ok" if status_code < 400 else "error", http_status_code=status_code, attributes={"response_time_ms": duration_ms}) + + return response + + except Exception as e: + # Log exception in span + if span_id: + try: + self.service.end_span(db, span_id, status="error", status_message=str(e), attributes={"exception.type": type(e).__name__, "exception.message": str(e)}) + + # Add exception event + self.service.add_event( + db, + span_id, + name="exception", + severity="error", + message=str(e), + exception_type=type(e).__name__, + exception_message=str(e), + exception_stacktrace=traceback.format_exc(), + ) + except Exception as log_error: + logger.warning(f"Failed to log exception in span: {log_error}") + + # End trace with error + if trace_id: + try: + self.service.end_trace(db, trace_id, status="error", status_message=str(e), http_status_code=500) + except Exception as trace_error: + logger.warning(f"Failed to end trace: {trace_error}") + + # Re-raise the original exception + raise + + finally: + # Always close database session + if db: + try: + db.close() + except Exception as close_error: + logger.warning(f"Failed to close database session: {close_error}") diff --git a/mcpgateway/plugins/framework/manager.py b/mcpgateway/plugins/framework/manager.py index 546ad9838..09d05dcdc 100644 --- a/mcpgateway/plugins/framework/manager.py +++ b/mcpgateway/plugins/framework/manager.py @@ -314,7 +314,60 @@ async def _execute_with_timeout(self, hook_ref: HookRef, payload: PluginPayload, Raises: asyncio.TimeoutError: If plugin exceeds timeout. """ - return await asyncio.wait_for(hook_ref.hook(payload, context), timeout=self.timeout) + # Add observability tracing for plugin execution + try: + # First-Party + # pylint: disable=import-outside-toplevel + from mcpgateway.db import SessionLocal + from mcpgateway.services.observability_service import current_trace_id, ObservabilityService + + # pylint: enable=import-outside-toplevel + + trace_id = current_trace_id.get() + if trace_id: + db = SessionLocal() + try: + service = ObservabilityService() + span_id = service.start_span( + db=db, + trace_id=trace_id, + name=f"plugin.execute.{hook_ref.plugin_ref.name}", + kind="internal", + resource_type="plugin", + resource_name=hook_ref.plugin_ref.name, + attributes={ + "plugin.name": hook_ref.plugin_ref.name, + "plugin.uuid": hook_ref.plugin_ref.uuid, + "plugin.mode": hook_ref.plugin_ref.mode.value if hasattr(hook_ref.plugin_ref.mode, "value") else str(hook_ref.plugin_ref.mode), + "plugin.priority": hook_ref.plugin_ref.priority, + "plugin.timeout": self.timeout, + }, + ) + + # Execute plugin + result = await asyncio.wait_for(hook_ref.hook(payload, context), timeout=self.timeout) + + # End span with success + service.end_span( + db=db, + span_id=span_id, + status="ok", + attributes={ + "plugin.had_violation": result.violation is not None, + "plugin.modified_payload": result.modified_payload is not None, + }, + ) + return result + finally: + db.close() + else: + # No active trace, execute without instrumentation + return await asyncio.wait_for(hook_ref.hook(payload, context), timeout=self.timeout) + + except Exception as e: + # If observability setup fails, continue without instrumentation + logger.debug(f"Plugin observability setup failed: {e}") + return await asyncio.wait_for(hook_ref.hook(payload, context), timeout=self.timeout) def _validate_payload_size(self, payload: Any) -> None: """Validate that payload doesn't exceed size limits. diff --git a/mcpgateway/plugins/framework/models.py b/mcpgateway/plugins/framework/models.py index 10ad2343f..d6644abc3 100644 --- a/mcpgateway/plugins/framework/models.py +++ b/mcpgateway/plugins/framework/models.py @@ -406,7 +406,7 @@ class MCPServerConfig(BaseModel): tls (Optional[MCPServerTLSConfig]): Server-side TLS configuration. """ - host: str = Field(default="0.0.0.0", description="Server host to bind to") + host: str = Field(default="127.0.0.1", description="Server host to bind to") port: int = Field(default=8000, description="Server port to bind to") tls: Optional[MCPServerTLSConfig] = Field(default=None, description="Server-side TLS configuration") @@ -659,7 +659,6 @@ class PluginErrorModel(BaseModel): plugin_name: str code: Optional[str] = "" details: Optional[dict[str, Any]] = Field(default_factory=dict) - plugin_name: str mcp_error_code: int = -32603 diff --git a/mcpgateway/routers/observability.py b/mcpgateway/routers/observability.py new file mode 100644 index 000000000..6e26a58f5 --- /dev/null +++ b/mcpgateway/routers/observability.py @@ -0,0 +1,526 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/routers/observability.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti + +Observability API Router. +Provides REST endpoints for querying traces, spans, events, and metrics. +""" + +# Standard +from datetime import datetime, timedelta +from typing import List, Optional + +# Third-Party +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy.orm import Session + +# First-Party +from mcpgateway.db import SessionLocal +from mcpgateway.schemas import ( + ObservabilitySpanRead, + ObservabilityTraceRead, + ObservabilityTraceWithSpans, +) +from mcpgateway.services.observability_service import ObservabilityService + +router = APIRouter(prefix="/observability", tags=["Observability"]) + + +def get_db(): + """Database session dependency. + + Yields: + Session: SQLAlchemy database session + """ + db = SessionLocal() + try: + yield db + finally: + db.close() + + +@router.get("/traces", response_model=List[ObservabilityTraceRead]) +def list_traces( + start_time: Optional[datetime] = Query(None, description="Filter traces after this time"), + end_time: Optional[datetime] = Query(None, description="Filter traces before this time"), + min_duration_ms: Optional[float] = Query(None, ge=0, description="Minimum duration in milliseconds"), + max_duration_ms: Optional[float] = Query(None, ge=0, description="Maximum duration in milliseconds"), + status: Optional[str] = Query(None, description="Filter by status (ok, error)"), + http_status_code: Optional[int] = Query(None, description="Filter by HTTP status code"), + http_method: Optional[str] = Query(None, description="Filter by HTTP method (GET, POST, etc.)"), + user_email: Optional[str] = Query(None, description="Filter by user email"), + attribute_search: Optional[str] = Query(None, description="Free-text search within trace attributes"), + limit: int = Query(100, ge=1, le=1000, description="Maximum results"), + offset: int = Query(0, ge=0, description="Result offset"), + db: Session = Depends(get_db), +): + """List traces with optional filtering. + + Query traces with various filters including time range, duration, status, HTTP method, + HTTP status code, user email, and attribute search. Results are paginated. + + Note: For structured attribute filtering (key-value pairs with AND logic), + use a JSON request body via POST endpoint or the Python SDK. + + Args: + start_time: Filter traces after this time + end_time: Filter traces before this time + min_duration_ms: Minimum duration in milliseconds + max_duration_ms: Maximum duration in milliseconds + status: Filter by status (ok, error) + http_status_code: Filter by HTTP status code + http_method: Filter by HTTP method (GET, POST, etc.) + user_email: Filter by user email + attribute_search: Free-text search across all trace attributes + limit: Maximum results + offset: Result offset + db: Database session + + Returns: + List[ObservabilityTraceRead]: List of traces matching filters + """ + service = ObservabilityService() + traces = service.query_traces( + db=db, + start_time=start_time, + end_time=end_time, + min_duration_ms=min_duration_ms, + max_duration_ms=max_duration_ms, + status=status, + http_status_code=http_status_code, + http_method=http_method, + user_email=user_email, + attribute_search=attribute_search, + limit=limit, + offset=offset, + ) + return traces + + +@router.post("/traces/query", response_model=List[ObservabilityTraceRead]) +def query_traces_advanced( + # Third-Party + request_body: dict, + db: Session = Depends(get_db), +): + """Advanced trace querying with attribute filtering. + + POST endpoint that accepts a JSON body with complex filtering criteria, + including structured attribute filters with AND logic. + + Request Body: + { + "start_time": "2025-01-01T00:00:00Z", # Optional datetime + "end_time": "2025-01-02T00:00:00Z", # Optional datetime + "min_duration_ms": 100.0, # Optional float + "max_duration_ms": 5000.0, # Optional float + "status": "error", # Optional string + "http_status_code": 500, # Optional int + "http_method": "POST", # Optional string + "user_email": "user@example.com", # Optional string + "attribute_filters": { # Optional dict (AND logic) + "http.route": "/api/tools", + "service.name": "mcp-gateway" + }, + "attribute_search": "error", # Optional string (OR logic) + "limit": 100, # Optional int + "offset": 0 # Optional int + } + + Args: + request_body: JSON request body with filter criteria + db: Database session + + Returns: + List[ObservabilityTraceRead]: List of traces matching filters + + Raises: + HTTPException: 400 error if request body is invalid + """ + # Third-Party + from pydantic import ValidationError + + try: + # Extract filters from request body + service = ObservabilityService() + + # Parse datetime strings if provided + start_time = request_body.get("start_time") + if start_time and isinstance(start_time, str): + start_time = datetime.fromisoformat(start_time.replace("Z", "+00:00")) + + end_time = request_body.get("end_time") + if end_time and isinstance(end_time, str): + end_time = datetime.fromisoformat(end_time.replace("Z", "+00:00")) + + traces = service.query_traces( + db=db, + start_time=start_time, + end_time=end_time, + min_duration_ms=request_body.get("min_duration_ms"), + max_duration_ms=request_body.get("max_duration_ms"), + status=request_body.get("status"), + status_in=request_body.get("status_in"), + status_not_in=request_body.get("status_not_in"), + http_status_code=request_body.get("http_status_code"), + http_status_code_in=request_body.get("http_status_code_in"), + http_method=request_body.get("http_method"), + http_method_in=request_body.get("http_method_in"), + user_email=request_body.get("user_email"), + user_email_in=request_body.get("user_email_in"), + attribute_filters=request_body.get("attribute_filters"), + attribute_filters_or=request_body.get("attribute_filters_or"), + attribute_search=request_body.get("attribute_search"), + name_contains=request_body.get("name_contains"), + order_by=request_body.get("order_by", "start_time_desc"), + limit=request_body.get("limit", 100), + offset=request_body.get("offset", 0), + ) + return traces + except (ValidationError, ValueError) as e: + raise HTTPException(status_code=400, detail=f"Invalid request body: {e}") + + +@router.get("/traces/{trace_id}", response_model=ObservabilityTraceWithSpans) +def get_trace(trace_id: str, db: Session = Depends(get_db)): + """Get a trace by ID with all its spans and events. + + Returns a complete trace with all nested spans and their events, + providing a full view of the request flow. + + Args: + trace_id: UUID of the trace to retrieve + db: Database session + + Returns: + ObservabilityTraceWithSpans: Complete trace with all spans and events + + Raises: + HTTPException: 404 if trace not found + """ + service = ObservabilityService() + trace = service.get_trace_with_spans(db, trace_id) + if not trace: + raise HTTPException(status_code=404, detail="Trace not found") + return trace + + +@router.get("/spans", response_model=List[ObservabilitySpanRead]) +def list_spans( + trace_id: Optional[str] = Query(None, description="Filter by trace ID"), + resource_type: Optional[str] = Query(None, description="Filter by resource type"), + resource_name: Optional[str] = Query(None, description="Filter by resource name"), + start_time: Optional[datetime] = Query(None, description="Filter spans after this time"), + end_time: Optional[datetime] = Query(None, description="Filter spans before this time"), + limit: int = Query(100, ge=1, le=1000, description="Maximum results"), + offset: int = Query(0, ge=0, description="Result offset"), + db: Session = Depends(get_db), +): + """List spans with optional filtering. + + Query spans by trace ID, resource type, resource name, or time range. + Useful for analyzing specific operations or resource performance. + + Args: + trace_id: Filter by trace ID + resource_type: Filter by resource type + resource_name: Filter by resource name + start_time: Filter spans after this time + end_time: Filter spans before this time + limit: Maximum results + offset: Result offset + db: Database session + + Returns: + List[ObservabilitySpanRead]: List of spans matching filters + """ + service = ObservabilityService() + spans = service.query_spans( + db=db, + trace_id=trace_id, + resource_type=resource_type, + resource_name=resource_name, + start_time=start_time, + end_time=end_time, + limit=limit, + offset=offset, + ) + return spans + + +@router.delete("/traces/cleanup") +def cleanup_old_traces( + days: int = Query(7, ge=1, description="Delete traces older than this many days"), + db: Session = Depends(get_db), +): + """Delete traces older than a specified number of days. + + Cleans up old trace data to manage storage. Cascading deletes will + also remove associated spans, events, and metrics. + + Args: + days: Delete traces older than this many days + db: Database session + + Returns: + dict: Number of deleted traces and cutoff time + """ + service = ObservabilityService() + cutoff_time = datetime.now() - timedelta(days=days) + deleted = service.delete_old_traces(db, cutoff_time) + return {"deleted": deleted, "cutoff_time": cutoff_time} + + +@router.get("/stats") +def get_stats( + hours: int = Query(24, ge=1, le=168, description="Time window in hours"), + db: Session = Depends(get_db), +): + """Get observability statistics. + + Returns summary statistics including: + - Total traces in time window + - Success/error counts + - Average response time + - Top slowest endpoints + + Args: + hours: Time window in hours + db: Database session + + Returns: + dict: Statistics including counts, error rate, and slowest endpoints + """ + # Third-Party + from sqlalchemy import func + + # First-Party + from mcpgateway.db import ObservabilityTrace + + ObservabilityService() + cutoff_time = datetime.now() - timedelta(hours=hours) + + # Get basic counts + total_traces = db.query(func.count(ObservabilityTrace.trace_id)).filter(ObservabilityTrace.start_time >= cutoff_time).scalar() + + success_count = db.query(func.count(ObservabilityTrace.trace_id)).filter(ObservabilityTrace.start_time >= cutoff_time, ObservabilityTrace.status == "ok").scalar() + + error_count = db.query(func.count(ObservabilityTrace.trace_id)).filter(ObservabilityTrace.start_time >= cutoff_time, ObservabilityTrace.status == "error").scalar() + + avg_duration = db.query(func.avg(ObservabilityTrace.duration_ms)).filter(ObservabilityTrace.start_time >= cutoff_time, ObservabilityTrace.duration_ms.isnot(None)).scalar() or 0 + + # Get slowest endpoints + slowest = ( + db.query(ObservabilityTrace.name, func.avg(ObservabilityTrace.duration_ms).label("avg_duration"), func.count(ObservabilityTrace.trace_id).label("count")) + .filter(ObservabilityTrace.start_time >= cutoff_time, ObservabilityTrace.duration_ms.isnot(None)) + .group_by(ObservabilityTrace.name) + .order_by(func.avg(ObservabilityTrace.duration_ms).desc()) + .limit(10) + .all() + ) + + return { + "time_window_hours": hours, + "total_traces": total_traces, + "success_count": success_count, + "error_count": error_count, + "error_rate": (error_count / total_traces * 100) if total_traces > 0 else 0, + "avg_duration_ms": round(avg_duration, 2), + "slowest_endpoints": [{"name": row[0], "avg_duration_ms": round(row[1], 2), "count": row[2]} for row in slowest], + } + + +@router.post("/traces/export") +def export_traces( + request_body: dict, + format: str = Query("json", description="Export format (json, csv, ndjson)"), + db: Session = Depends(get_db), +): + """Export traces in various formats. + + POST endpoint that accepts filter criteria (same as /traces/query) and exports + matching traces in the specified format. + + Supported formats: + - json: Standard JSON array + - csv: Comma-separated values + - ndjson: Newline-delimited JSON (streaming) + + Args: + request_body: JSON request body with filter criteria (same as /traces/query) + format: Export format (json, csv, ndjson) + db: Database session + + Returns: + StreamingResponse or JSONResponse with exported data + + Raises: + HTTPException: 400 error if format is invalid or export fails + """ + # Standard + import csv + import io + + # Third-Party + from starlette.responses import Response, StreamingResponse + + # Validate format + if format not in ["json", "csv", "ndjson"]: + raise HTTPException(status_code=400, detail="format must be one of: json, csv, ndjson") + + try: + service = ObservabilityService() + + # Parse datetime strings + start_time = request_body.get("start_time") + if start_time and isinstance(start_time, str): + start_time = datetime.fromisoformat(start_time.replace("Z", "+00:00")) + + end_time = request_body.get("end_time") + if end_time and isinstance(end_time, str): + end_time = datetime.fromisoformat(end_time.replace("Z", "+00:00")) + + # Query traces + traces = service.query_traces( + db=db, + start_time=start_time, + end_time=end_time, + min_duration_ms=request_body.get("min_duration_ms"), + max_duration_ms=request_body.get("max_duration_ms"), + status=request_body.get("status"), + status_in=request_body.get("status_in"), + http_status_code=request_body.get("http_status_code"), + http_method=request_body.get("http_method"), + user_email=request_body.get("user_email"), + order_by=request_body.get("order_by", "start_time_desc"), + limit=request_body.get("limit", 1000), # Higher limit for export + offset=request_body.get("offset", 0), + ) + + if format == "json": + # Standard JSON response + return [ + { + "trace_id": t.trace_id, + "name": t.name, + "start_time": t.start_time.isoformat() if t.start_time else None, + "end_time": t.end_time.isoformat() if t.end_time else None, + "duration_ms": t.duration_ms, + "status": t.status, + "http_method": t.http_method, + "http_url": t.http_url, + "http_status_code": t.http_status_code, + "user_email": t.user_email, + } + for t in traces + ] + + elif format == "csv": + # CSV export + output = io.StringIO() + writer = csv.writer(output) + + # Write header + writer.writerow(["trace_id", "name", "start_time", "duration_ms", "status", "http_method", "http_status_code", "user_email"]) + + # Write data + for t in traces: + writer.writerow( + [t.trace_id, t.name, t.start_time.isoformat() if t.start_time else "", t.duration_ms or "", t.status, t.http_method or "", t.http_status_code or "", t.user_email or ""] + ) + + output.seek(0) + return Response(content=output.getvalue(), media_type="text/csv", headers={"Content-Disposition": "attachment; filename=traces.csv"}) + + elif format == "ndjson": + # Newline-delimited JSON (streaming) + def generate(): + for t in traces: + # Standard + import json + + yield json.dumps( + { + "trace_id": t.trace_id, + "name": t.name, + "start_time": t.start_time.isoformat() if t.start_time else None, + "duration_ms": t.duration_ms, + "status": t.status, + "http_method": t.http_method, + "http_status_code": t.http_status_code, + "user_email": t.user_email, + } + ) + "\n" + + return StreamingResponse(generate(), media_type="application/x-ndjson", headers={"Content-Disposition": "attachment; filename=traces.ndjson"}) + + except (ValueError, Exception) as e: + raise HTTPException(status_code=400, detail=f"Export failed: {e}") + + +@router.get("/analytics/query-performance") +def get_query_performance(hours: int = Query(24, ge=1, le=168, description="Time window in hours"), db: Session = Depends(get_db)): + """Get query performance analytics. + + Returns performance metrics about trace queries including: + - Average, min, max, p50, p95, p99 durations + - Query volume over time + - Error rate trends + + Args: + hours: Time window in hours + db: Database session + + Returns: + dict: Performance analytics + """ + # Third-Party + + # First-Party + from mcpgateway.db import ObservabilityTrace + + ObservabilityService() + cutoff_time = datetime.now() - timedelta(hours=hours) + + # Get duration percentiles using SQL + traces_with_duration = db.query(ObservabilityTrace.duration_ms).filter(ObservabilityTrace.start_time >= cutoff_time, ObservabilityTrace.duration_ms.isnot(None)).all() + + durations = sorted([t[0] for t in traces_with_duration if t[0] is not None]) + + if not durations: + return { + "time_window_hours": hours, + "total_traces": 0, + "percentiles": {}, + "avg_duration_ms": 0, + "min_duration_ms": 0, + "max_duration_ms": 0, + } + + def percentile(data, p): + n = len(data) + if n == 0: + return 0 + k = (n - 1) * p + f = int(k) + c = k - f + if f + 1 < n: + return data[f] + (c * (data[f + 1] - data[f])) + return data[f] + + return { + "time_window_hours": hours, + "total_traces": len(durations), + "percentiles": { + "p50": round(percentile(durations, 0.50), 2), + "p75": round(percentile(durations, 0.75), 2), + "p90": round(percentile(durations, 0.90), 2), + "p95": round(percentile(durations, 0.95), 2), + "p99": round(percentile(durations, 0.99), 2), + }, + "avg_duration_ms": round(sum(durations) / len(durations), 2), + "min_duration_ms": round(durations[0], 2), + "max_duration_ms": round(durations[-1], 2), + } diff --git a/mcpgateway/schemas.py b/mcpgateway/schemas.py index 792e6891c..bc219511e 100644 --- a/mcpgateway/schemas.py +++ b/mcpgateway/schemas.py @@ -6497,3 +6497,177 @@ class PaginationParams(BaseModel): cursor: Optional[str] = Field(None, description="Cursor for cursor-based pagination") sort_by: Optional[str] = Field("created_at", description="Sort field") sort_order: Optional[str] = Field("desc", pattern="^(asc|desc)$", description="Sort order") + + +# ============================================================================ +# Observability Schemas (OpenTelemetry-style traces, spans, events, metrics) +# ============================================================================ + + +class ObservabilityTraceBase(BaseModel): + """Base schema for observability traces.""" + + name: str = Field(..., description="Trace name (e.g., 'POST /tools/invoke')") + start_time: datetime = Field(..., description="Trace start timestamp") + end_time: Optional[datetime] = Field(None, description="Trace end timestamp") + duration_ms: Optional[float] = Field(None, description="Total duration in milliseconds") + status: str = Field("unset", description="Trace status (unset, ok, error)") + status_message: Optional[str] = Field(None, description="Status message or error description") + http_method: Optional[str] = Field(None, description="HTTP method") + http_url: Optional[str] = Field(None, description="HTTP URL") + http_status_code: Optional[int] = Field(None, description="HTTP status code") + user_email: Optional[str] = Field(None, description="User email") + user_agent: Optional[str] = Field(None, description="User agent string") + ip_address: Optional[str] = Field(None, description="Client IP address") + attributes: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional trace attributes") + resource_attributes: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Resource attributes") + + +class ObservabilityTraceCreate(ObservabilityTraceBase): + """Schema for creating an observability trace.""" + + trace_id: Optional[str] = Field(None, description="Trace ID (generated if not provided)") + + +class ObservabilityTraceUpdate(BaseModel): + """Schema for updating an observability trace.""" + + end_time: Optional[datetime] = None + duration_ms: Optional[float] = None + status: Optional[str] = None + status_message: Optional[str] = None + http_status_code: Optional[int] = None + attributes: Optional[Dict[str, Any]] = None + + +class ObservabilityTraceRead(ObservabilityTraceBase): + """Schema for reading an observability trace.""" + + trace_id: str = Field(..., description="Trace ID") + created_at: datetime = Field(..., description="Creation timestamp") + + model_config = {"from_attributes": True} + + +class ObservabilitySpanBase(BaseModel): + """Base schema for observability spans.""" + + trace_id: str = Field(..., description="Parent trace ID") + parent_span_id: Optional[str] = Field(None, description="Parent span ID (for nested spans)") + name: str = Field(..., description="Span name (e.g., 'database_query', 'tool_invocation')") + kind: str = Field("internal", description="Span kind (internal, server, client, producer, consumer)") + start_time: datetime = Field(..., description="Span start timestamp") + end_time: Optional[datetime] = Field(None, description="Span end timestamp") + duration_ms: Optional[float] = Field(None, description="Span duration in milliseconds") + status: str = Field("unset", description="Span status (unset, ok, error)") + status_message: Optional[str] = Field(None, description="Status message") + attributes: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Span attributes") + resource_name: Optional[str] = Field(None, description="Resource name") + resource_type: Optional[str] = Field(None, description="Resource type (tool, resource, prompt, gateway, a2a_agent)") + resource_id: Optional[str] = Field(None, description="Resource ID") + + +class ObservabilitySpanCreate(ObservabilitySpanBase): + """Schema for creating an observability span.""" + + span_id: Optional[str] = Field(None, description="Span ID (generated if not provided)") + + +class ObservabilitySpanUpdate(BaseModel): + """Schema for updating an observability span.""" + + end_time: Optional[datetime] = None + duration_ms: Optional[float] = None + status: Optional[str] = None + status_message: Optional[str] = None + attributes: Optional[Dict[str, Any]] = None + + +class ObservabilitySpanRead(ObservabilitySpanBase): + """Schema for reading an observability span.""" + + span_id: str = Field(..., description="Span ID") + created_at: datetime = Field(..., description="Creation timestamp") + + model_config = {"from_attributes": True} + + +class ObservabilityEventBase(BaseModel): + """Base schema for observability events.""" + + span_id: str = Field(..., description="Parent span ID") + name: str = Field(..., description="Event name (e.g., 'exception', 'log', 'checkpoint')") + timestamp: datetime = Field(..., description="Event timestamp") + attributes: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Event attributes") + severity: Optional[str] = Field(None, description="Log severity (debug, info, warning, error, critical)") + message: Optional[str] = Field(None, description="Event message") + exception_type: Optional[str] = Field(None, description="Exception class name") + exception_message: Optional[str] = Field(None, description="Exception message") + exception_stacktrace: Optional[str] = Field(None, description="Exception stacktrace") + + +class ObservabilityEventCreate(ObservabilityEventBase): + """Schema for creating an observability event.""" + + +class ObservabilityEventRead(ObservabilityEventBase): + """Schema for reading an observability event.""" + + id: int = Field(..., description="Event ID") + created_at: datetime = Field(..., description="Creation timestamp") + + model_config = {"from_attributes": True} + + +class ObservabilityMetricBase(BaseModel): + """Base schema for observability metrics.""" + + name: str = Field(..., description="Metric name (e.g., 'http.request.duration', 'tool.invocation.count')") + metric_type: str = Field(..., description="Metric type (counter, gauge, histogram)") + value: float = Field(..., description="Metric value") + timestamp: datetime = Field(..., description="Metric timestamp") + unit: Optional[str] = Field(None, description="Metric unit (ms, count, bytes, etc.)") + attributes: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Metric attributes/labels") + resource_type: Optional[str] = Field(None, description="Resource type") + resource_id: Optional[str] = Field(None, description="Resource ID") + trace_id: Optional[str] = Field(None, description="Associated trace ID") + + +class ObservabilityMetricCreate(ObservabilityMetricBase): + """Schema for creating an observability metric.""" + + +class ObservabilityMetricRead(ObservabilityMetricBase): + """Schema for reading an observability metric.""" + + id: int = Field(..., description="Metric ID") + created_at: datetime = Field(..., description="Creation timestamp") + + model_config = {"from_attributes": True} + + +class ObservabilityTraceWithSpans(ObservabilityTraceRead): + """Schema for reading a trace with its spans.""" + + spans: List[ObservabilitySpanRead] = Field(default_factory=list, description="List of spans in this trace") + + +class ObservabilitySpanWithEvents(ObservabilitySpanRead): + """Schema for reading a span with its events.""" + + events: List[ObservabilityEventRead] = Field(default_factory=list, description="List of events in this span") + + +class ObservabilityQueryParams(BaseModel): + """Query parameters for filtering observability data.""" + + start_time: Optional[datetime] = Field(None, description="Filter traces/spans/metrics after this time") + end_time: Optional[datetime] = Field(None, description="Filter traces/spans/metrics before this time") + status: Optional[str] = Field(None, description="Filter by status (ok, error, unset)") + http_status_code: Optional[int] = Field(None, description="Filter by HTTP status code") + user_email: Optional[str] = Field(None, description="Filter by user email") + resource_type: Optional[str] = Field(None, description="Filter by resource type") + resource_name: Optional[str] = Field(None, description="Filter by resource name") + trace_id: Optional[str] = Field(None, description="Filter by trace ID") + limit: int = Field(default=100, ge=1, le=1000, description="Maximum number of results") + offset: int = Field(default=0, ge=0, description="Result offset for pagination") diff --git a/mcpgateway/services/observability_service.py b/mcpgateway/services/observability_service.py new file mode 100644 index 000000000..e8dd9b96c --- /dev/null +++ b/mcpgateway/services/observability_service.py @@ -0,0 +1,1396 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/services/observability_service.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti + +Observability Service Implementation. +This module provides OpenTelemetry-style observability for MCP Gateway, +capturing traces, spans, events, and metrics for all operations. + +It includes: +- Trace creation and management +- Span tracking with hierarchical nesting +- Event logging within spans +- Metrics collection and storage +- Query and filtering capabilities +- Integration with FastAPI middleware + +Examples: + >>> from mcpgateway.services.observability_service import ObservabilityService # doctest: +SKIP + >>> service = ObservabilityService() # doctest: +SKIP + >>> trace_id = service.start_trace(db, "GET /tools", http_method="GET", http_url="/tools") # doctest: +SKIP + >>> span_id = service.start_span(db, trace_id, "database_query", resource_type="database") # doctest: +SKIP + >>> service.end_span(db, span_id, status="ok") # doctest: +SKIP + >>> service.end_trace(db, trace_id, status="ok", http_status_code=200) # doctest: +SKIP +""" + +# Standard +from contextlib import contextmanager +from contextvars import ContextVar +from datetime import datetime, timezone +import logging +import re +import traceback +from typing import Any, Dict, List, Optional, Tuple +import uuid + +# Third-Party +from sqlalchemy import desc +from sqlalchemy.orm import joinedload, Session + +# First-Party +from mcpgateway.db import ObservabilityEvent, ObservabilityMetric, ObservabilitySpan, ObservabilityTrace + +logger = logging.getLogger(__name__) + +# Context variable for tracking the current trace_id across async calls +current_trace_id: ContextVar[Optional[str]] = ContextVar("current_trace_id", default=None) + + +def utc_now() -> datetime: + """Return current UTC time with timezone. + + Returns: + datetime: Current time in UTC with timezone info + """ + return datetime.now(timezone.utc) + + +def ensure_timezone_aware(dt: datetime) -> datetime: + """Ensure datetime is timezone-aware (UTC). + + SQLite returns naive datetimes even when stored with timezone info. + This helper ensures consistency for datetime arithmetic. + + Args: + dt: Datetime that may be naive or aware + + Returns: + Timezone-aware datetime in UTC + """ + if dt.tzinfo is None: + return dt.replace(tzinfo=timezone.utc) + return dt + + +def parse_traceparent(traceparent: str) -> Optional[Tuple[str, str, str]]: + """Parse W3C Trace Context traceparent header. + + Format: version-trace_id-parent_id-trace_flags + Example: 00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01 + + Args: + traceparent: W3C traceparent header value + + Returns: + Tuple of (trace_id, parent_span_id, trace_flags) or None if invalid + + Examples: + >>> parse_traceparent("00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01") # doctest: +SKIP + ('0af7651916cd43dd8448eb211c80319c', 'b7ad6b7169203331', '01') + """ + # W3C Trace Context format: 00-trace_id(32hex)-parent_id(16hex)-flags(2hex) + pattern = r"^([0-9a-f]{2})-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})$" + match = re.match(pattern, traceparent.lower()) + + if not match: + logger.warning(f"Invalid traceparent format: {traceparent}") + return None + + version, trace_id, parent_id, flags = match.groups() + + # Only support version 00 for now + if version != "00": + logger.warning(f"Unsupported traceparent version: {version}") + return None + + # Validate trace_id and parent_id are not all zeros + if trace_id == "0" * 32 or parent_id == "0" * 16: + logger.warning("Invalid traceparent with zero trace_id or parent_id") + return None + + return (trace_id, parent_id, flags) + + +def generate_w3c_trace_id() -> str: + """Generate a W3C compliant trace ID (32 hex characters). + + Returns: + 32-character lowercase hex string + + Examples: + >>> trace_id = generate_w3c_trace_id() # doctest: +SKIP + >>> len(trace_id) # doctest: +SKIP + 32 + """ + return uuid.uuid4().hex + uuid.uuid4().hex[:16] + + +def generate_w3c_span_id() -> str: + """Generate a W3C compliant span ID (16 hex characters). + + Returns: + 16-character lowercase hex string + + Examples: + >>> span_id = generate_w3c_span_id() # doctest: +SKIP + >>> len(span_id) # doctest: +SKIP + 16 + """ + return uuid.uuid4().hex[:16] + + +def format_traceparent(trace_id: str, span_id: str, sampled: bool = True) -> str: + """Format a W3C traceparent header value. + + Args: + trace_id: 32-character hex trace ID + span_id: 16-character hex span ID + sampled: Whether the trace is sampled (affects trace-flags) + + Returns: + W3C traceparent header value + + Examples: + >>> format_traceparent("0af7651916cd43dd8448eb211c80319c", "b7ad6b7169203331") # doctest: +SKIP + '00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01' + """ + flags = "01" if sampled else "00" + return f"00-{trace_id}-{span_id}-{flags}" + + +class ObservabilityService: + """Service for managing observability traces, spans, events, and metrics. + + This service provides comprehensive observability capabilities similar to + OpenTelemetry, allowing tracking of request flows through the system. + + Examples: + >>> service = ObservabilityService() # doctest: +SKIP + >>> trace_id = service.start_trace(db, "POST /tools/invoke") # doctest: +SKIP + >>> span_id = service.start_span(db, trace_id, "tool_execution") # doctest: +SKIP + >>> service.end_span(db, span_id, status="ok") # doctest: +SKIP + >>> service.end_trace(db, trace_id, status="ok") # doctest: +SKIP + """ + + # ============================== + # Trace Management + # ============================== + + def start_trace( + self, + db: Session, + name: str, + trace_id: Optional[str] = None, + parent_span_id: Optional[str] = None, + http_method: Optional[str] = None, + http_url: Optional[str] = None, + user_email: Optional[str] = None, + user_agent: Optional[str] = None, + ip_address: Optional[str] = None, + attributes: Optional[Dict[str, Any]] = None, + resource_attributes: Optional[Dict[str, Any]] = None, + ) -> str: + """Start a new trace. + + Args: + db: Database session + name: Trace name (e.g., "POST /tools/invoke") + trace_id: External trace ID (for distributed tracing, W3C format) + parent_span_id: Parent span ID from upstream service + http_method: HTTP method (GET, POST, etc.) + http_url: Full request URL + user_email: Authenticated user email + user_agent: Client user agent string + ip_address: Client IP address + attributes: Additional trace attributes + resource_attributes: Resource attributes (service name, version, etc.) + + Returns: + Trace ID (UUID string or W3C format) + + Examples: + >>> trace_id = service.start_trace( # doctest: +SKIP + ... db, + ... "POST /tools/invoke", + ... http_method="POST", + ... http_url="https://api.example.com/tools/invoke", + ... user_email="user@example.com" + ... ) + """ + # Use provided trace_id or generate new UUID + if not trace_id: + trace_id = str(uuid.uuid4()) + + # Add parent context to attributes if provided + attrs = attributes or {} + if parent_span_id: + attrs["parent_span_id"] = parent_span_id + + trace = ObservabilityTrace( + trace_id=trace_id, + name=name, + start_time=utc_now(), + status="unset", + http_method=http_method, + http_url=http_url, + user_email=user_email, + user_agent=user_agent, + ip_address=ip_address, + attributes=attrs, + resource_attributes=resource_attributes or {}, + created_at=utc_now(), + ) + db.add(trace) + db.commit() + logger.debug(f"Started trace {trace_id}: {name}") + return trace_id + + def end_trace( + self, + db: Session, + trace_id: str, + status: str = "ok", + status_message: Optional[str] = None, + http_status_code: Optional[int] = None, + attributes: Optional[Dict[str, Any]] = None, + ) -> None: + """End a trace. + + Args: + db: Database session + trace_id: Trace ID to end + status: Trace status (ok, error) + status_message: Optional status message + http_status_code: HTTP response status code + attributes: Additional attributes to merge + + Examples: + >>> service.end_trace( # doctest: +SKIP + ... db, + ... trace_id, + ... status="ok", + ... http_status_code=200 + ... ) + """ + trace = db.query(ObservabilityTrace).filter_by(trace_id=trace_id).first() + if not trace: + logger.warning(f"Trace {trace_id} not found") + return + + end_time = utc_now() + duration_ms = (end_time - ensure_timezone_aware(trace.start_time)).total_seconds() * 1000 + + trace.end_time = end_time + trace.duration_ms = duration_ms + trace.status = status + trace.status_message = status_message + if http_status_code is not None: + trace.http_status_code = http_status_code + if attributes: + trace.attributes = {**(trace.attributes or {}), **attributes} + + db.commit() + logger.debug(f"Ended trace {trace_id}: {status} ({duration_ms:.2f}ms)") + + def get_trace(self, db: Session, trace_id: str, include_spans: bool = False) -> Optional[ObservabilityTrace]: + """Get a trace by ID. + + Args: + db: Database session + trace_id: Trace ID + include_spans: Whether to load spans eagerly + + Returns: + Trace object or None if not found + + Examples: + >>> trace = service.get_trace(db, trace_id, include_spans=True) # doctest: +SKIP + >>> if trace: # doctest: +SKIP + ... print(f"Trace: {trace.name}, Spans: {len(trace.spans)}") # doctest: +SKIP + """ + query = db.query(ObservabilityTrace).filter_by(trace_id=trace_id) + if include_spans: + query = query.options(joinedload(ObservabilityTrace.spans)) + return query.first() + + # ============================== + # Span Management + # ============================== + + def start_span( + self, + db: Session, + trace_id: str, + name: str, + parent_span_id: Optional[str] = None, + kind: str = "internal", + resource_name: Optional[str] = None, + resource_type: Optional[str] = None, + resource_id: Optional[str] = None, + attributes: Optional[Dict[str, Any]] = None, + ) -> str: + """Start a new span within a trace. + + Args: + db: Database session + trace_id: Parent trace ID + name: Span name (e.g., "database_query", "tool_invocation") + parent_span_id: Parent span ID (for nested spans) + kind: Span kind (internal, server, client, producer, consumer) + resource_name: Resource name being operated on + resource_type: Resource type (tool, resource, prompt, etc.) + resource_id: Resource ID + attributes: Additional span attributes + + Returns: + Span ID (UUID string) + + Examples: + >>> span_id = service.start_span( # doctest: +SKIP + ... db, + ... trace_id, + ... "tool_invocation", + ... resource_type="tool", + ... resource_name="get_weather" + ... ) + """ + span_id = str(uuid.uuid4()) + span = ObservabilitySpan( + span_id=span_id, + trace_id=trace_id, + parent_span_id=parent_span_id, + name=name, + kind=kind, + start_time=utc_now(), + status="unset", + resource_name=resource_name, + resource_type=resource_type, + resource_id=resource_id, + attributes=attributes or {}, + created_at=utc_now(), + ) + db.add(span) + db.commit() + logger.debug(f"Started span {span_id}: {name} (trace={trace_id})") + return span_id + + def end_span( + self, + db: Session, + span_id: str, + status: str = "ok", + status_message: Optional[str] = None, + attributes: Optional[Dict[str, Any]] = None, + ) -> None: + """End a span. + + Args: + db: Database session + span_id: Span ID to end + status: Span status (ok, error) + status_message: Optional status message + attributes: Additional attributes to merge + + Examples: + >>> service.end_span(db, span_id, status="ok") # doctest: +SKIP + """ + span = db.query(ObservabilitySpan).filter_by(span_id=span_id).first() + if not span: + logger.warning(f"Span {span_id} not found") + return + + end_time = utc_now() + duration_ms = (end_time - ensure_timezone_aware(span.start_time)).total_seconds() * 1000 + + span.end_time = end_time + span.duration_ms = duration_ms + span.status = status + span.status_message = status_message + if attributes: + span.attributes = {**(span.attributes or {}), **attributes} + + db.commit() + logger.debug(f"Ended span {span_id}: {status} ({duration_ms:.2f}ms)") + + @contextmanager + def trace_span( + self, + db: Session, + trace_id: str, + name: str, + parent_span_id: Optional[str] = None, + resource_type: Optional[str] = None, + resource_name: Optional[str] = None, + attributes: Optional[Dict[str, Any]] = None, + ): + """Context manager for automatic span lifecycle management. + + Args: + db: Database session + trace_id: Parent trace ID + name: Span name + parent_span_id: Parent span ID (optional) + resource_type: Resource type + resource_name: Resource name + attributes: Additional attributes + + Yields: + Span ID + + Raises: + Exception: Re-raises any exception after logging it in the span + + Examples: + >>> with service.trace_span(db, trace_id, "database_query") as span_id: # doctest: +SKIP + ... results = db.query(Tool).all() # doctest: +SKIP + """ + span_id = self.start_span(db, trace_id, name, parent_span_id, resource_type=resource_type, resource_name=resource_name, attributes=attributes) + try: + yield span_id + self.end_span(db, span_id, status="ok") + except Exception as e: + self.end_span(db, span_id, status="error", status_message=str(e)) + self.add_event(db, span_id, "exception", severity="error", message=str(e), exception_type=type(e).__name__, exception_message=str(e), exception_stacktrace=traceback.format_exc()) + raise + + @contextmanager + def trace_tool_invocation( + self, + db: Session, + tool_name: str, + arguments: Dict[str, Any], + integration_type: Optional[str] = None, + ): + """Context manager for tracing MCP tool invocations. + + This automatically creates a span for tool execution, capturing timing, + arguments, results, and errors. + + Args: + db: Database session + tool_name: Name of the tool being invoked + arguments: Tool arguments (will be sanitized) + integration_type: Integration type (MCP, REST, A2A, etc.) + + Yields: + Tuple of (span_id, result_dict) - update result_dict with tool results + + Raises: + Exception: Re-raises any exception from tool invocation after logging + + Examples: + >>> with service.trace_tool_invocation(db, "weather", {"city": "NYC"}) as (span_id, result): # doctest: +SKIP + ... response = await http_client.post(...) # doctest: +SKIP + ... result["status_code"] = response.status_code # doctest: +SKIP + ... result["response_size"] = len(response.content) # doctest: +SKIP + """ + trace_id = current_trace_id.get() + if not trace_id: + # No active trace, yield a no-op + result_dict: Dict[str, Any] = {} + yield (None, result_dict) + return + + # Sanitize arguments (remove sensitive data) + safe_args = {k: ("***REDACTED***" if any(sensitive in k.lower() for sensitive in ["password", "token", "key", "secret"]) else v) for k, v in arguments.items()} + + # Start tool invocation span + span_id = self.start_span( + db=db, + trace_id=trace_id, + name=f"tool.invoke.{tool_name}", + kind="client", + resource_type="tool", + resource_name=tool_name, + attributes={ + "tool.name": tool_name, + "tool.integration_type": integration_type, + "tool.argument_count": len(arguments), + "tool.arguments": safe_args, + }, + ) + + result_dict = {} + try: + yield (span_id, result_dict) + + # End span with results + self.end_span( + db=db, + span_id=span_id, + status="ok", + attributes={ + "tool.result": result_dict, + }, + ) + except Exception as e: + # Log error in span + self.end_span(db=db, span_id=span_id, status="error", status_message=str(e)) + + self.add_event( + db=db, + span_id=span_id, + name="tool.error", + severity="error", + message=str(e), + exception_type=type(e).__name__, + exception_message=str(e), + exception_stacktrace=traceback.format_exc(), + ) + raise + + # ============================== + # Event Management + # ============================== + + def add_event( + self, + db: Session, + span_id: str, + name: str, + severity: Optional[str] = None, + message: Optional[str] = None, + exception_type: Optional[str] = None, + exception_message: Optional[str] = None, + exception_stacktrace: Optional[str] = None, + attributes: Optional[Dict[str, Any]] = None, + ) -> int: + """Add an event to a span. + + Args: + db: Database session + span_id: Parent span ID + name: Event name + severity: Log severity (debug, info, warning, error, critical) + message: Event message + exception_type: Exception class name + exception_message: Exception message + exception_stacktrace: Exception stacktrace + attributes: Additional event attributes + + Returns: + Event ID + + Examples: + >>> event_id = service.add_event( # doctest: +SKIP + ... db, # doctest: +SKIP + ... span_id, # doctest: +SKIP + ... "database_connection_error", # doctest: +SKIP + ... severity="error", # doctest: +SKIP + ... message="Failed to connect to database" # doctest: +SKIP + ... ) # doctest: +SKIP + """ + event = ObservabilityEvent( + span_id=span_id, + name=name, + timestamp=utc_now(), + severity=severity, + message=message, + exception_type=exception_type, + exception_message=exception_message, + exception_stacktrace=exception_stacktrace, + attributes=attributes or {}, + created_at=utc_now(), + ) + db.add(event) + db.commit() + db.refresh(event) + logger.debug(f"Added event to span {span_id}: {name}") + return event.id + + # ============================== + # Token Usage Tracking + # ============================== + + def record_token_usage( + self, + db: Session, + span_id: Optional[str] = None, + trace_id: Optional[str] = None, + model: Optional[str] = None, + input_tokens: int = 0, + output_tokens: int = 0, + total_tokens: Optional[int] = None, + estimated_cost_usd: Optional[float] = None, + provider: Optional[str] = None, + ) -> None: + """Record token usage for LLM calls. + + Args: + db: Database session + span_id: Span ID to attach token usage to + trace_id: Trace ID (will use current context if not provided) + model: Model name (e.g., "gpt-4", "claude-3-opus") + input_tokens: Number of input/prompt tokens + output_tokens: Number of output/completion tokens + total_tokens: Total tokens (calculated if not provided) + estimated_cost_usd: Estimated cost in USD + provider: LLM provider (openai, anthropic, etc.) + + Examples: + >>> service.record_token_usage( # doctest: +SKIP + ... db, span_id="abc123", + ... model="gpt-4", + ... input_tokens=100, + ... output_tokens=50, + ... estimated_cost_usd=0.015 + ... ) + """ + if not trace_id: + trace_id = current_trace_id.get() + + if not trace_id: + logger.warning("Cannot record token usage: no active trace") + return + + # Calculate total if not provided + if total_tokens is None: + total_tokens = input_tokens + output_tokens + + # Estimate cost if not provided and we have model info + if estimated_cost_usd is None and model: + estimated_cost_usd = self._estimate_token_cost(model, input_tokens, output_tokens) + + # Store in span attributes if span_id provided + if span_id: + span = db.query(ObservabilitySpan).filter_by(span_id=span_id).first() + if span: + attrs = span.attributes or {} + attrs.update( + { + "llm.model": model, + "llm.provider": provider, + "llm.input_tokens": input_tokens, + "llm.output_tokens": output_tokens, + "llm.total_tokens": total_tokens, + "llm.estimated_cost_usd": estimated_cost_usd, + } + ) + span.attributes = attrs + db.commit() + + # Also record as metrics for aggregation + if input_tokens > 0: + self.record_metric( + db=db, + name="llm.tokens.input", + value=float(input_tokens), + metric_type="counter", + unit="tokens", + trace_id=trace_id, + attributes={"model": model, "provider": provider}, + ) + + if output_tokens > 0: + self.record_metric( + db=db, + name="llm.tokens.output", + value=float(output_tokens), + metric_type="counter", + unit="tokens", + trace_id=trace_id, + attributes={"model": model, "provider": provider}, + ) + + if estimated_cost_usd: + self.record_metric( + db=db, + name="llm.cost", + value=estimated_cost_usd, + metric_type="counter", + unit="usd", + trace_id=trace_id, + attributes={"model": model, "provider": provider}, + ) + + logger.debug(f"Recorded token usage: {input_tokens} in, {output_tokens} out, ${estimated_cost_usd:.6f}") + + def _estimate_token_cost(self, model: str, input_tokens: int, output_tokens: int) -> float: + """Estimate cost based on model and token counts. + + Pricing as of January 2025 (prices may change). + + Args: + model: Model name + input_tokens: Input token count + output_tokens: Output token count + + Returns: + Estimated cost in USD + """ + # Pricing per 1M tokens (input, output) + pricing = { + # OpenAI + "gpt-4": (30.0, 60.0), + "gpt-4-turbo": (10.0, 30.0), + "gpt-4o": (2.5, 10.0), + "gpt-4o-mini": (0.15, 0.60), + "gpt-3.5-turbo": (0.50, 1.50), + # Anthropic + "claude-3-opus": (15.0, 75.0), + "claude-3-sonnet": (3.0, 15.0), + "claude-3-haiku": (0.25, 1.25), + "claude-3.5-sonnet": (3.0, 15.0), + "claude-3.5-haiku": (0.80, 4.0), + # Fallback for unknown models + "default": (1.0, 3.0), + } + + # Find matching pricing (case-insensitive, partial match) + model_lower = model.lower() + input_price, output_price = pricing.get("default") + + for model_key, prices in pricing.items(): + if model_key in model_lower: + input_price, output_price = prices + break + + # Calculate cost (pricing is per 1M tokens) + input_cost = (input_tokens / 1_000_000) * input_price + output_cost = (output_tokens / 1_000_000) * output_price + + return input_cost + output_cost + + # ============================== + # Agent-to-Agent (A2A) Tracing + # ============================== + + @contextmanager + def trace_a2a_request( + self, + db: Session, + agent_id: str, + agent_name: Optional[str] = None, + operation: Optional[str] = None, + request_data: Optional[Dict[str, Any]] = None, + ): + """Context manager for tracing Agent-to-Agent requests. + + This automatically creates a span for A2A communication, capturing timing, + request/response data, and errors. + + Args: + db: Database session + agent_id: Target agent ID + agent_name: Human-readable agent name + operation: Operation being performed (e.g., "query", "execute", "status") + request_data: Request payload (will be sanitized) + + Yields: + Tuple of (span_id, result_dict) - update result_dict with A2A results + + Raises: + Exception: Re-raises any exception from A2A call after logging + + Examples: + >>> with service.trace_a2a_request(db, "agent-123", "WeatherAgent", "query") as (span_id, result): # doctest: +SKIP + ... response = await http_client.post(...) # doctest: +SKIP + ... result["status_code"] = response.status_code # doctest: +SKIP + ... result["response_time_ms"] = 45.2 # doctest: +SKIP + """ + trace_id = current_trace_id.get() + if not trace_id: + # No active trace, yield a no-op + result_dict: Dict[str, Any] = {} + yield (None, result_dict) + return + + # Sanitize request data + safe_data = {} + if request_data: + safe_data = {k: ("***REDACTED***" if any(sensitive in k.lower() for sensitive in ["password", "token", "key", "secret", "auth"]) else v) for k, v in request_data.items()} + + # Start A2A span + span_id = self.start_span( + db=db, + trace_id=trace_id, + name=f"a2a.call.{agent_name or agent_id}", + kind="client", + resource_type="agent", + resource_name=agent_name or agent_id, + attributes={ + "a2a.agent_id": agent_id, + "a2a.agent_name": agent_name, + "a2a.operation": operation, + "a2a.request_data": safe_data, + }, + ) + + result_dict = {} + try: + yield (span_id, result_dict) + + # End span with results + self.end_span( + db=db, + span_id=span_id, + status="ok", + attributes={ + "a2a.result": result_dict, + }, + ) + except Exception as e: + # Log error in span + self.end_span(db=db, span_id=span_id, status="error", status_message=str(e)) + + self.add_event( + db=db, + span_id=span_id, + name="a2a.error", + severity="error", + message=str(e), + exception_type=type(e).__name__, + exception_message=str(e), + exception_stacktrace=traceback.format_exc(), + ) + raise + + # ============================== + # Transport Metrics + # ============================== + + def record_transport_activity( + self, + db: Session, + transport_type: str, + operation: str, + message_count: int = 1, + bytes_sent: Optional[int] = None, + bytes_received: Optional[int] = None, + connection_id: Optional[str] = None, + error: Optional[str] = None, + ) -> None: + """Record transport-specific activity metrics. + + Args: + db: Database session + transport_type: Transport type (sse, websocket, stdio, http) + operation: Operation type (connect, disconnect, send, receive, error) + message_count: Number of messages processed + bytes_sent: Bytes sent (if applicable) + bytes_received: Bytes received (if applicable) + connection_id: Connection/session identifier + error: Error message if operation failed + + Examples: + >>> service.record_transport_activity( # doctest: +SKIP + ... db, transport_type="sse", + ... operation="send", + ... message_count=1, + ... bytes_sent=1024 + ... ) + """ + trace_id = current_trace_id.get() + + # Record message count + if message_count > 0: + self.record_metric( + db=db, + name=f"transport.{transport_type}.messages", + value=float(message_count), + metric_type="counter", + unit="messages", + trace_id=trace_id, + attributes={ + "transport": transport_type, + "operation": operation, + "connection_id": connection_id, + }, + ) + + # Record bytes sent + if bytes_sent: + self.record_metric( + db=db, + name=f"transport.{transport_type}.bytes_sent", + value=float(bytes_sent), + metric_type="counter", + unit="bytes", + trace_id=trace_id, + attributes={ + "transport": transport_type, + "operation": operation, + "connection_id": connection_id, + }, + ) + + # Record bytes received + if bytes_received: + self.record_metric( + db=db, + name=f"transport.{transport_type}.bytes_received", + value=float(bytes_received), + metric_type="counter", + unit="bytes", + trace_id=trace_id, + attributes={ + "transport": transport_type, + "operation": operation, + "connection_id": connection_id, + }, + ) + + # Record errors + if error: + self.record_metric( + db=db, + name=f"transport.{transport_type}.errors", + value=1.0, + metric_type="counter", + unit="errors", + trace_id=trace_id, + attributes={ + "transport": transport_type, + "operation": operation, + "connection_id": connection_id, + "error": error, + }, + ) + + logger.debug(f"Recorded {transport_type} transport activity: {operation} ({message_count} messages)") + + # ============================== + # Metric Management + # ============================== + + def record_metric( + self, + db: Session, + name: str, + value: float, + metric_type: str = "gauge", + unit: Optional[str] = None, + resource_type: Optional[str] = None, + resource_id: Optional[str] = None, + trace_id: Optional[str] = None, + attributes: Optional[Dict[str, Any]] = None, + ) -> int: + """Record a metric. + + Args: + db: Database session + name: Metric name (e.g., "http.request.duration") + value: Metric value + metric_type: Metric type (counter, gauge, histogram) + unit: Metric unit (ms, count, bytes, etc.) + resource_type: Resource type + resource_id: Resource ID + trace_id: Associated trace ID + attributes: Additional metric attributes/labels + + Returns: + Metric ID + + Examples: + >>> metric_id = service.record_metric( # doctest: +SKIP + ... db, # doctest: +SKIP + ... "http.request.duration", # doctest: +SKIP + ... 123.45, # doctest: +SKIP + ... metric_type="histogram", # doctest: +SKIP + ... unit="ms", # doctest: +SKIP + ... trace_id=trace_id # doctest: +SKIP + ... ) # doctest: +SKIP + """ + metric = ObservabilityMetric( + name=name, + value=value, + metric_type=metric_type, + timestamp=utc_now(), + unit=unit, + resource_type=resource_type, + resource_id=resource_id, + trace_id=trace_id, + attributes=attributes or {}, + created_at=utc_now(), + ) + db.add(metric) + db.commit() + db.refresh(metric) + logger.debug(f"Recorded metric: {name} = {value} {unit or ''}") + return metric.id + + # ============================== + # Query Methods + # ============================== + + # pylint: disable=too-many-positional-arguments,too-many-arguments,too-many-locals + def query_traces( + self, + db: Session, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + min_duration_ms: Optional[float] = None, + max_duration_ms: Optional[float] = None, + status: Optional[str] = None, + status_in: Optional[List[str]] = None, + status_not_in: Optional[List[str]] = None, + http_status_code: Optional[int] = None, + http_status_code_in: Optional[List[int]] = None, + http_method: Optional[str] = None, + http_method_in: Optional[List[str]] = None, + user_email: Optional[str] = None, + user_email_in: Optional[List[str]] = None, + attribute_filters: Optional[Dict[str, Any]] = None, + attribute_filters_or: Optional[Dict[str, Any]] = None, + attribute_search: Optional[str] = None, + name_contains: Optional[str] = None, + order_by: str = "start_time_desc", + limit: int = 100, + offset: int = 0, + ) -> List[ObservabilityTrace]: + """Query traces with advanced filters. + + Supports both simple filters (single value) and list filters (multiple values with OR logic). + All top-level filters are combined with AND logic unless using _or suffix. + + Args: + db: Database session + start_time: Filter traces after this time + end_time: Filter traces before this time + min_duration_ms: Filter traces with duration >= this value (milliseconds) + max_duration_ms: Filter traces with duration <= this value (milliseconds) + status: Filter by single status (ok, error) + status_in: Filter by multiple statuses (OR logic) + status_not_in: Exclude these statuses (NOT logic) + http_status_code: Filter by single HTTP status code + http_status_code_in: Filter by multiple HTTP status codes (OR logic) + http_method: Filter by single HTTP method (GET, POST, etc.) + http_method_in: Filter by multiple HTTP methods (OR logic) + user_email: Filter by single user email + user_email_in: Filter by multiple user emails (OR logic) + attribute_filters: JSON attribute filters (AND logic - all must match) + attribute_filters_or: JSON attribute filters (OR logic - any must match) + attribute_search: Free-text search within JSON attributes (partial match) + name_contains: Filter traces where name contains this substring + order_by: Sort order (start_time_desc, start_time_asc, duration_desc, duration_asc) + limit: Maximum results (1-1000) + offset: Result offset + + Returns: + List of traces + + Raises: + ValueError: If invalid parameters are provided + + Examples: + >>> # Find slow errors from multiple endpoints + >>> traces = service.query_traces( # doctest: +SKIP + ... db, + ... status="error", + ... min_duration_ms=100.0, + ... http_method_in=["POST", "PUT"], + ... attribute_filters={"http.route": "/api/tools"}, + ... limit=50 + ... ) + >>> # Exclude health checks and find slow requests + >>> traces = service.query_traces( # doctest: +SKIP + ... db, + ... min_duration_ms=1000.0, + ... name_contains="api", + ... status_not_in=["ok"], + ... order_by="duration_desc" + ... ) + """ + # Third-Party + # pylint: disable=import-outside-toplevel + from sqlalchemy import cast, or_, String + + # pylint: enable=import-outside-toplevel + # Validate limit + if limit < 1 or limit > 1000: + raise ValueError("limit must be between 1 and 1000") + + # Validate order_by + valid_orders = ["start_time_desc", "start_time_asc", "duration_desc", "duration_asc"] + if order_by not in valid_orders: + raise ValueError(f"order_by must be one of: {', '.join(valid_orders)}") + + query = db.query(ObservabilityTrace) + + # Time range filters + if start_time: + query = query.filter(ObservabilityTrace.start_time >= start_time) + if end_time: + query = query.filter(ObservabilityTrace.start_time <= end_time) + + # Duration filters + if min_duration_ms is not None: + query = query.filter(ObservabilityTrace.duration_ms >= min_duration_ms) + if max_duration_ms is not None: + query = query.filter(ObservabilityTrace.duration_ms <= max_duration_ms) + + # Status filters (with OR and NOT support) + if status: + query = query.filter(ObservabilityTrace.status == status) + if status_in: + query = query.filter(ObservabilityTrace.status.in_(status_in)) + if status_not_in: + query = query.filter(~ObservabilityTrace.status.in_(status_not_in)) + + # HTTP status code filters (with OR support) + if http_status_code: + query = query.filter(ObservabilityTrace.http_status_code == http_status_code) + if http_status_code_in: + query = query.filter(ObservabilityTrace.http_status_code.in_(http_status_code_in)) + + # HTTP method filters (with OR support) + if http_method: + query = query.filter(ObservabilityTrace.http_method == http_method) + if http_method_in: + query = query.filter(ObservabilityTrace.http_method.in_(http_method_in)) + + # User email filters (with OR support) + if user_email: + query = query.filter(ObservabilityTrace.user_email == user_email) + if user_email_in: + query = query.filter(ObservabilityTrace.user_email.in_(user_email_in)) + + # Name substring filter + if name_contains: + query = query.filter(ObservabilityTrace.name.ilike(f"%{name_contains}%")) + + # Attribute-based filtering with AND logic (all filters must match) + if attribute_filters: + for key, value in attribute_filters.items(): + # Use JSON path access for filtering + # Supports both SQLite (via json_extract) and PostgreSQL (via ->>) + query = query.filter(ObservabilityTrace.attributes[key].astext == str(value)) + + # Attribute-based filtering with OR logic (any filter must match) + if attribute_filters_or: + or_conditions = [] + for key, value in attribute_filters_or.items(): + or_conditions.append(ObservabilityTrace.attributes[key].astext == str(value)) + if or_conditions: + query = query.filter(or_(*or_conditions)) + + # Free-text search across all attribute values + if attribute_search: + # Cast JSON attributes to text and search for substring + # Works with both SQLite and PostgreSQL + # Escape special characters to prevent SQL injection + safe_search = attribute_search.replace("%", "\\%").replace("_", "\\_") + query = query.filter(cast(ObservabilityTrace.attributes, String).ilike(f"%{safe_search}%")) + + # Apply ordering + if order_by == "start_time_desc": + query = query.order_by(desc(ObservabilityTrace.start_time)) + elif order_by == "start_time_asc": + query = query.order_by(ObservabilityTrace.start_time) + elif order_by == "duration_desc": + query = query.order_by(desc(ObservabilityTrace.duration_ms)) + elif order_by == "duration_asc": + query = query.order_by(ObservabilityTrace.duration_ms) + + # Apply pagination + query = query.limit(limit).offset(offset) + + return query.all() + + # pylint: disable=too-many-positional-arguments,too-many-arguments,too-many-locals + def query_spans( + self, + db: Session, + trace_id: Optional[str] = None, + trace_id_in: Optional[List[str]] = None, + resource_type: Optional[str] = None, + resource_type_in: Optional[List[str]] = None, + resource_name: Optional[str] = None, + resource_name_in: Optional[List[str]] = None, + name_contains: Optional[str] = None, + kind: Optional[str] = None, + kind_in: Optional[List[str]] = None, + status: Optional[str] = None, + status_in: Optional[List[str]] = None, + status_not_in: Optional[List[str]] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + min_duration_ms: Optional[float] = None, + max_duration_ms: Optional[float] = None, + attribute_filters: Optional[Dict[str, Any]] = None, + attribute_search: Optional[str] = None, + order_by: str = "start_time_desc", + limit: int = 100, + offset: int = 0, + ) -> List[ObservabilitySpan]: + """Query spans with advanced filters. + + Supports filtering by trace, resource, kind, status, duration, and attributes. + All top-level filters are combined with AND logic. List filters use OR logic. + + Args: + db: Database session + trace_id: Filter by single trace ID + trace_id_in: Filter by multiple trace IDs (OR logic) + resource_type: Filter by single resource type (tool, database, plugin, etc.) + resource_type_in: Filter by multiple resource types (OR logic) + resource_name: Filter by single resource name + resource_name_in: Filter by multiple resource names (OR logic) + name_contains: Filter spans where name contains this substring + kind: Filter by span kind (client, server, internal) + kind_in: Filter by multiple kinds (OR logic) + status: Filter by single status (ok, error) + status_in: Filter by multiple statuses (OR logic) + status_not_in: Exclude these statuses (NOT logic) + start_time: Filter spans after this time + end_time: Filter spans before this time + min_duration_ms: Filter spans with duration >= this value (milliseconds) + max_duration_ms: Filter spans with duration <= this value (milliseconds) + attribute_filters: JSON attribute filters (AND logic) + attribute_search: Free-text search within JSON attributes + order_by: Sort order (start_time_desc, start_time_asc, duration_desc, duration_asc) + limit: Maximum results (1-1000) + offset: Result offset + + Returns: + List of spans + + Raises: + ValueError: If invalid parameters are provided + + Examples: + >>> # Find slow database queries + >>> spans = service.query_spans( # doctest: +SKIP + ... db, + ... resource_type="database", + ... min_duration_ms=100.0, + ... order_by="duration_desc", + ... limit=50 + ... ) + >>> # Find tool invocation errors + >>> spans = service.query_spans( # doctest: +SKIP + ... db, + ... resource_type="tool", + ... status="error", + ... name_contains="invoke" + ... ) + """ + # Third-Party + # pylint: disable=import-outside-toplevel + from sqlalchemy import cast, String + + # pylint: enable=import-outside-toplevel + # Validate limit + if limit < 1 or limit > 1000: + raise ValueError("limit must be between 1 and 1000") + + # Validate order_by + valid_orders = ["start_time_desc", "start_time_asc", "duration_desc", "duration_asc"] + if order_by not in valid_orders: + raise ValueError(f"order_by must be one of: {', '.join(valid_orders)}") + + query = db.query(ObservabilitySpan) + + # Trace ID filters (with OR support) + if trace_id: + query = query.filter(ObservabilitySpan.trace_id == trace_id) + if trace_id_in: + query = query.filter(ObservabilitySpan.trace_id.in_(trace_id_in)) + + # Resource type filters (with OR support) + if resource_type: + query = query.filter(ObservabilitySpan.resource_type == resource_type) + if resource_type_in: + query = query.filter(ObservabilitySpan.resource_type.in_(resource_type_in)) + + # Resource name filters (with OR support) + if resource_name: + query = query.filter(ObservabilitySpan.resource_name == resource_name) + if resource_name_in: + query = query.filter(ObservabilitySpan.resource_name.in_(resource_name_in)) + + # Name substring filter + if name_contains: + query = query.filter(ObservabilitySpan.name.ilike(f"%{name_contains}%")) + + # Kind filters (with OR support) + if kind: + query = query.filter(ObservabilitySpan.kind == kind) + if kind_in: + query = query.filter(ObservabilitySpan.kind.in_(kind_in)) + + # Status filters (with OR and NOT support) + if status: + query = query.filter(ObservabilitySpan.status == status) + if status_in: + query = query.filter(ObservabilitySpan.status.in_(status_in)) + if status_not_in: + query = query.filter(~ObservabilitySpan.status.in_(status_not_in)) + + # Time range filters + if start_time: + query = query.filter(ObservabilitySpan.start_time >= start_time) + if end_time: + query = query.filter(ObservabilitySpan.start_time <= end_time) + + # Duration filters + if min_duration_ms is not None: + query = query.filter(ObservabilitySpan.duration_ms >= min_duration_ms) + if max_duration_ms is not None: + query = query.filter(ObservabilitySpan.duration_ms <= max_duration_ms) + + # Attribute-based filtering with AND logic + if attribute_filters: + for key, value in attribute_filters.items(): + query = query.filter(ObservabilitySpan.attributes[key].astext == str(value)) + + # Free-text search across all attribute values + if attribute_search: + safe_search = attribute_search.replace("%", "\\%").replace("_", "\\_") + query = query.filter(cast(ObservabilitySpan.attributes, String).ilike(f"%{safe_search}%")) + + # Apply ordering + if order_by == "start_time_desc": + query = query.order_by(desc(ObservabilitySpan.start_time)) + elif order_by == "start_time_asc": + query = query.order_by(ObservabilitySpan.start_time) + elif order_by == "duration_desc": + query = query.order_by(desc(ObservabilitySpan.duration_ms)) + elif order_by == "duration_asc": + query = query.order_by(ObservabilitySpan.duration_ms) + + # Apply pagination + query = query.limit(limit).offset(offset) + + return query.all() + + def get_trace_with_spans(self, db: Session, trace_id: str) -> Optional[ObservabilityTrace]: + """Get a complete trace with all spans and events. + + Args: + db: Database session + trace_id: Trace ID + + Returns: + Trace with spans and events loaded + + Examples: + >>> trace = service.get_trace_with_spans(db, trace_id) # doctest: +SKIP + >>> if trace: # doctest: +SKIP + ... for span in trace.spans: # doctest: +SKIP + ... print(f"Span: {span.name}, Events: {len(span.events)}") # doctest: +SKIP + """ + return db.query(ObservabilityTrace).filter_by(trace_id=trace_id).options(joinedload(ObservabilityTrace.spans).joinedload(ObservabilitySpan.events)).first() + + def delete_old_traces(self, db: Session, before_time: datetime) -> int: + """Delete traces older than a given time. + + Args: + db: Database session + before_time: Delete traces before this time + + Returns: + Number of traces deleted + + Examples: + >>> from datetime import timedelta # doctest: +SKIP + >>> cutoff = utc_now() - timedelta(days=30) # doctest: +SKIP + >>> deleted = service.delete_old_traces(db, cutoff) # doctest: +SKIP + >>> print(f"Deleted {deleted} old traces") # doctest: +SKIP + """ + deleted = db.query(ObservabilityTrace).filter(ObservabilityTrace.start_time < before_time).delete() + db.commit() + logger.info(f"Deleted {deleted} traces older than {before_time}") + return deleted diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index 492cafa47..8cb936b05 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -39,6 +39,7 @@ from mcpgateway.plugins.framework import GlobalContext, PluginManager, PromptHookType, PromptPosthookPayload, PromptPrehookPayload from mcpgateway.schemas import PromptCreate, PromptRead, PromptUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services.observability_service import current_trace_id, ObservabilityService from mcpgateway.utils.metrics_common import build_top_performers from mcpgateway.utils.pagination import decode_cursor, encode_cursor from mcpgateway.utils.sqlalchemy_modifier import json_contains_expr @@ -690,7 +691,33 @@ async def get_prompt( error_message = None prompt = None - # Create a trace span for prompt rendering + # Create database span for observability dashboard + trace_id = current_trace_id.get() + db_span_id = None + db_span_ended = False + observability_service = ObservabilityService() if trace_id else None + + if trace_id and observability_service: + try: + db_span_id = observability_service.start_span( + db=db, + trace_id=trace_id, + name="prompt.render", + attributes={ + "prompt.id": str(prompt_id), + "arguments_count": len(arguments) if arguments else 0, + "user": user or "anonymous", + "server_id": server_id, + "tenant_id": tenant_id, + "request_id": request_id or "none", + }, + ) + logger.debug(f"βœ“ Created prompt.render span: {db_span_id} for prompt: {prompt_id}") + except Exception as e: + logger.warning(f"Failed to start observability span for prompt rendering: {e}") + db_span_id = None + + # Create a trace span for OpenTelemetry export (Jaeger, Zipkin, etc.) with create_span( "prompt.render", { @@ -824,6 +851,20 @@ async def get_prompt( except Exception as metrics_error: logger.warning(f"Failed to record prompt metric: {metrics_error}") + # End database span for observability dashboard + if db_span_id and observability_service and not db_span_ended: + try: + observability_service.end_span( + db=db, + span_id=db_span_id, + status="ok" if success else "error", + status_message=error_message if error_message else None, + ) + db_span_ended = True + logger.debug(f"βœ“ Ended prompt.render span: {db_span_id}") + except Exception as e: + logger.warning(f"Failed to end observability span for prompt rendering: {e}") + async def update_prompt( self, db: Session, diff --git a/mcpgateway/services/resource_service.py b/mcpgateway/services/resource_service.py index 9e9baa274..97ea6e250 100644 --- a/mcpgateway/services/resource_service.py +++ b/mcpgateway/services/resource_service.py @@ -51,6 +51,7 @@ from mcpgateway.observability import create_span from mcpgateway.schemas import ResourceCreate, ResourceMetrics, ResourceRead, ResourceSubscription, ResourceUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services.observability_service import current_trace_id, ObservabilityService from mcpgateway.utils.metrics_common import build_top_performers from mcpgateway.utils.pagination import decode_cursor, encode_cursor from mcpgateway.utils.sqlalchemy_modifier import json_contains_expr @@ -722,7 +723,34 @@ async def read_resource(self, db: Session, resource_id: Union[int, str], request resource = None resource_db = db.get(DbResource, resource_id) uri = resource_db.uri if resource_db else None - # Create trace span for resource reading + + # Create database span for observability dashboard + trace_id = current_trace_id.get() + db_span_id = None + db_span_ended = False + observability_service = ObservabilityService() if trace_id else None + + if trace_id and observability_service: + try: + db_span_id = observability_service.start_span( + db=db, + trace_id=trace_id, + name="resource.read", + attributes={ + "resource.uri": str(uri) if uri else "unknown", + "user": user or "anonymous", + "server_id": server_id, + "request_id": request_id, + "http.url": uri if uri is not None and uri.startswith("http") else None, + "resource.type": "template" if (uri is not None and "{" in uri and "}" in uri) else "static", + }, + ) + logger.debug(f"βœ“ Created resource.read span: {db_span_id} for resource: {uri}") + except Exception as e: + logger.warning(f"Failed to start observability span for resource reading: {e}") + db_span_id = None + + # Create trace span for OpenTelemetry export (Jaeger, Zipkin, etc.) with create_span( "resource.read", { @@ -847,6 +875,20 @@ async def read_resource(self, db: Session, resource_id: Union[int, str], request except Exception as metrics_error: logger.warning(f"Failed to record resource metric: {metrics_error}") + # End database span for observability dashboard + if db_span_id and observability_service and not db_span_ended: + try: + observability_service.end_span( + db=db, + span_id=db_span_id, + status="ok" if success else "error", + status_message=error_message if error_message else None, + ) + db_span_ended = True + logger.debug(f"βœ“ Ended resource.read span: {db_span_id}") + except Exception as e: + logger.warning(f"Failed to end observability span for resource reading: {e}") + async def toggle_resource_status(self, db: Session, resource_id: int, activate: bool, user_email: Optional[str] = None) -> ResourceRead: """ Toggle the activation status of a resource. diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 96f98d50a..8501409bd 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -41,7 +41,6 @@ from mcpgateway.common.models import Gateway as PydanticGateway from mcpgateway.common.models import TextContent from mcpgateway.common.models import Tool as PydanticTool -from mcpgateway.common.models import ToolResult from mcpgateway.config import settings from mcpgateway.db import A2AAgent as DbA2AAgent from mcpgateway.db import EmailTeam @@ -49,12 +48,14 @@ from mcpgateway.db import server_tool_association from mcpgateway.db import Tool as DbTool from mcpgateway.db import ToolMetric -from mcpgateway.observability import create_span -from mcpgateway.plugins.framework import GlobalContext, HttpHeaderPayload, PluginError, PluginManager, PluginViolationError, ToolHookType, ToolPostInvokePayload, ToolPreInvokePayload +from mcpgateway.plugins.framework import GlobalContext, PluginError, PluginManager, PluginViolationError from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA -from mcpgateway.schemas import ToolCreate, ToolRead, ToolUpdate, TopPerformer +from mcpgateway.plugins.framework.hooks.http import HttpHeaderPayload +from mcpgateway.plugins.framework.hooks.tools import ToolHookType, ToolPostInvokePayload, ToolPreInvokePayload +from mcpgateway.schemas import ToolCreate, ToolRead, ToolResult, ToolUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.oauth_manager import OAuthManager +from mcpgateway.services.observability_service import current_trace_id, ObservabilityService from mcpgateway.services.team_management_service import TeamManagementService from mcpgateway.utils.create_slug import slugify from mcpgateway.utils.display_name import generate_display_name @@ -1139,266 +1140,292 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r success = False error_message = None - # Create a trace span for the tool invocation - with create_span( - "tool.invoke", - { - "tool.name": name, - "tool.id": str(tool.id) if tool else "unknown", - "tool.integration_type": tool.integration_type if tool else "unknown", - "tool.gateway_id": str(tool.gateway_id) if tool and tool.gateway_id else None, - "arguments_count": len(arguments) if arguments else 0, - "has_headers": bool(request_headers), - }, - ) as span: - try: - # Get combined headers for the tool including base headers, auth, and passthrough headers - # headers = self._get_combined_headers(db, tool, tool.headers or {}, request_headers) - headers = tool.headers or {} - if tool.integration_type == "REST": - # Handle OAuth authentication for REST tools - if tool.auth_type == "oauth" and hasattr(tool, "oauth_config") and tool.oauth_config: - try: - access_token = await self.oauth_manager.get_access_token(tool.oauth_config) - headers["Authorization"] = f"Bearer {access_token}" - except Exception as e: - logger.error(f"Failed to obtain OAuth access token for tool {tool.name}: {e}") - raise ToolInvocationError(f"OAuth authentication failed: {str(e)}") - else: - credentials = decode_auth(tool.auth_value) - # Filter out empty header names/values to avoid "Illegal header name" errors - filtered_credentials = {k: v for k, v in credentials.items() if k and v} - headers.update(filtered_credentials) - - # Only call get_passthrough_headers if we actually have request headers to pass through - if request_headers: - headers = get_passthrough_headers(request_headers, headers, db) - - if self._plugin_manager: - tool_metadata = PydanticTool.model_validate(tool) - global_context.metadata[TOOL_METADATA] = tool_metadata - pre_result, context_table = await self._plugin_manager.invoke_hook( - ToolHookType.TOOL_PRE_INVOKE, - payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)), - global_context=global_context, - local_contexts=None, - violations_as_exceptions=True, - ) - if pre_result.modified_payload: - payload = pre_result.modified_payload - name = payload.name - arguments = payload.args - if payload.headers is not None: - headers = payload.headers.model_dump() - - # Build the payload based on integration type - payload = arguments.copy() - - # Handle URL path parameter substitution - final_url = tool.url - if "{" in tool.url and "}" in tool.url: - # Extract path parameters from URL template and arguments - url_params = re.findall(r"\{(\w+)\}", tool.url) - url_substitutions = {} - - for param in url_params: - if param in payload: - url_substitutions[param] = payload.pop(param) # Remove from payload - final_url = final_url.replace(f"{{{param}}}", str(url_substitutions[param])) - else: - raise ToolInvocationError(f"Required URL parameter '{param}' not found in arguments") + # Create a trace span for the tool invocation using ObservabilityService + trace_id = current_trace_id.get() + span_id = None + span_ended = False + observability_service = ObservabilityService() if trace_id else None - # --- Extract query params from URL --- - parsed = urlparse(final_url) - final_url = f"{parsed.scheme}://{parsed.netloc}{parsed.path}" + logger.debug(f"Tool invocation trace_id: {trace_id}, tool: {name}") - query_params = {k: v[0] for k, v in parse_qs(parsed.query).items()} + if trace_id and observability_service: + try: + span_id = observability_service.start_span( + db=db, + trace_id=trace_id, + name="tool.invoke", + attributes={ + "tool.name": name, + "tool.id": str(tool.id) if tool else "unknown", + "tool.integration_type": tool.integration_type if tool else "unknown", + "tool.gateway_id": str(tool.gateway_id) if tool and tool.gateway_id else None, + "arguments_count": len(arguments) if arguments else 0, + "has_headers": bool(request_headers), + }, + ) + logger.info(f"βœ“ Created tool.invoke span: {span_id} for tool: {name}") + except Exception as e: + logger.warning(f"Failed to start observability span for tool invocation: {e}") + span_id = None + else: + logger.debug(f"Skipping span creation - trace_id: {trace_id}, observability_service: {observability_service}") - # Merge leftover payload + query params - payload.update(query_params) + try: + # Get combined headers for the tool including base headers, auth, and passthrough headers + # headers = self._get_combined_headers(db, tool, tool.headers or {}, request_headers) + headers = tool.headers or {} + if tool.integration_type == "REST": + # Handle OAuth authentication for REST tools + if tool.auth_type == "oauth" and hasattr(tool, "oauth_config") and tool.oauth_config: + try: + access_token = await self.oauth_manager.get_access_token(tool.oauth_config) + headers["Authorization"] = f"Bearer {access_token}" + except Exception as e: + logger.error(f"Failed to obtain OAuth access token for tool {tool.name}: {e}") + raise ToolInvocationError(f"OAuth authentication failed: {str(e)}") + else: + credentials = decode_auth(tool.auth_value) + # Filter out empty header names/values to avoid "Illegal header name" errors + filtered_credentials = {k: v for k, v in credentials.items() if k and v} + headers.update(filtered_credentials) - # Use the tool's request_type rather than defaulting to POST. - method = tool.request_type.upper() - if method == "GET": - response = await self._http_client.get(final_url, params=payload, headers=headers) - else: - response = await self._http_client.request(method, final_url, json=payload, headers=headers) - response.raise_for_status() - - # Handle 204 No Content responses that have no body - if response.status_code == 204: - tool_result = ToolResult(content=[TextContent(type="text", text="Request completed successfully (No Content)")]) - success = True - elif response.status_code not in [200, 201, 202, 206]: - result = response.json() - tool_result = ToolResult( - content=[TextContent(type="text", text=str(result["error"]) if "error" in result else "Tool error encountered")], - is_error=True, - ) - # Don't mark as successful for error responses - success remains False - else: - result = response.json() - filtered_response = extract_using_jq(result, tool.jsonpath_filter) - tool_result = ToolResult(content=[TextContent(type="text", text=json.dumps(filtered_response, indent=2))]) - success = True - - # If output schema is present, validate and attach structured content - if getattr(tool, "output_schema", None): - valid = self._extract_and_validate_structured_content(tool, tool_result, candidate=filtered_response) - success = bool(valid) - - elif tool.integration_type == "MCP": - transport = tool.request_type.lower() - gateway = db.execute(select(DbGateway).where(DbGateway.id == tool.gateway_id).where(DbGateway.enabled)).scalar_one_or_none() - - # Handle OAuth authentication for the gateway - if gateway and gateway.auth_type == "oauth" and gateway.oauth_config: - grant_type = gateway.oauth_config.get("grant_type", "client_credentials") - - if grant_type == "authorization_code": - # For Authorization Code flow, try to get stored tokens - try: - # First-Party - from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel - - token_storage = TokenStorageService(db) - - # Get user-specific OAuth token - if not app_user_email: - raise ToolInvocationError(f"User authentication required for OAuth-protected gateway '{gateway.name}'. Please ensure you are authenticated.") - - access_token = await token_storage.get_user_token(gateway.id, app_user_email) - - if access_token: - headers = {"Authorization": f"Bearer {access_token}"} - else: - # User hasn't authorized this gateway yet - raise ToolInvocationError(f"Please authorize {gateway.name} first. Visit /oauth/authorize/{gateway.id} to complete OAuth flow.") - except Exception as e: - logger.error(f"Failed to obtain stored OAuth token for gateway {gateway.name}: {e}") - raise ToolInvocationError(f"OAuth token retrieval failed for gateway: {str(e)}") + # Only call get_passthrough_headers if we actually have request headers to pass through + if request_headers: + headers = get_passthrough_headers(request_headers, headers, db) + + if self._plugin_manager: + tool_metadata = PydanticTool.model_validate(tool) + global_context.metadata[TOOL_METADATA] = tool_metadata + pre_result, context_table = await self._plugin_manager.invoke_hook( + ToolHookType.TOOL_PRE_INVOKE, + payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(headers)), + global_context=global_context, + local_contexts=None, + violations_as_exceptions=True, + ) + if pre_result.modified_payload: + payload = pre_result.modified_payload + name = payload.name + arguments = payload.args + if payload.headers is not None: + headers = payload.headers.model_dump() + + # Build the payload based on integration type + payload = arguments.copy() + + # Handle URL path parameter substitution + final_url = tool.url + if "{" in tool.url and "}" in tool.url: + # Extract path parameters from URL template and arguments + url_params = re.findall(r"\{(\w+)\}", tool.url) + url_substitutions = {} + + for param in url_params: + if param in payload: + url_substitutions[param] = payload.pop(param) # Remove from payload + final_url = final_url.replace(f"{{{param}}}", str(url_substitutions[param])) else: - # For Client Credentials flow, get token directly - try: - access_token = await self.oauth_manager.get_access_token(gateway.oauth_config) - headers = {"Authorization": f"Bearer {access_token}"} - except Exception as e: - logger.error(f"Failed to obtain OAuth access token for gateway {gateway.name}: {e}") - raise ToolInvocationError(f"OAuth authentication failed for gateway: {str(e)}") - else: - headers = decode_auth(gateway.auth_value if gateway else None) - - # Get combined headers including gateway auth and passthrough - if request_headers: - headers = get_passthrough_headers(request_headers, headers, db, gateway) - - async def connect_to_sse_server(server_url: str, headers: dict = headers): - """Connect to an MCP server running with SSE transport. - - Args: - server_url: MCP Server SSE URL - headers: HTTP headers to include in the request - - Returns: - ToolResult: Result of tool call - """ - async with sse_client(url=server_url, headers=headers) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - tool_call_result = await session.call_tool(tool.original_name, arguments) - return tool_call_result - - async def connect_to_streamablehttp_server(server_url: str, headers: dict = headers): - """Connect to an MCP server running with Streamable HTTP transport. - - Args: - server_url: MCP Server URL - headers: HTTP headers to include in the request - - Returns: - ToolResult: Result of tool call - """ - async with streamablehttp_client(url=server_url, headers=headers) as (read_stream, write_stream, _get_session_id): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - tool_call_result = await session.call_tool(tool.original_name, arguments) - return tool_call_result - - tool_gateway_id = tool.gateway_id - tool_gateway = db.execute(select(DbGateway).where(DbGateway.id == tool_gateway_id).where(DbGateway.enabled)).scalar_one_or_none() - - if self._plugin_manager: - tool_metadata = PydanticTool.model_validate(tool) - global_context.metadata[TOOL_METADATA] = tool_metadata - if tool_gateway: - gateway_metadata = PydanticGateway.model_validate(tool_gateway) - global_context.metadata[GATEWAY_METADATA] = gateway_metadata - pre_result, context_table = await self._plugin_manager.invoke_hook( - ToolHookType.TOOL_PRE_INVOKE, - payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)), - global_context=global_context, - local_contexts=None, - violations_as_exceptions=True, - ) - if pre_result.modified_payload: - payload = pre_result.modified_payload - name = payload.name - arguments = payload.args - if payload.headers is not None: - headers = payload.headers.model_dump() - - tool_call_result = ToolResult(content=[TextContent(text="", type="text")]) - if transport == "sse": - tool_call_result = await connect_to_sse_server(tool_gateway.url, headers=headers) - elif transport == "streamablehttp": - tool_call_result = await connect_to_streamablehttp_server(tool_gateway.url, headers=headers) - content = tool_call_result.model_dump(by_alias=True).get("content", []) - - filtered_response = extract_using_jq(content, tool.jsonpath_filter) - tool_result = ToolResult(content=filtered_response) + raise ToolInvocationError(f"Required URL parameter '{param}' not found in arguments") + + # --- Extract query params from URL --- + parsed = urlparse(final_url) + final_url = f"{parsed.scheme}://{parsed.netloc}{parsed.path}" + + query_params = {k: v[0] for k, v in parse_qs(parsed.query).items()} + + # Merge leftover payload + query params + payload.update(query_params) + + # Use the tool's request_type rather than defaulting to POST. + method = tool.request_type.upper() + if method == "GET": + response = await self._http_client.get(final_url, params=payload, headers=headers) + else: + response = await self._http_client.request(method, final_url, json=payload, headers=headers) + response.raise_for_status() + + # Handle 204 No Content responses that have no body + if response.status_code == 204: + tool_result = ToolResult(content=[TextContent(type="text", text="Request completed successfully (No Content)")]) + success = True + elif response.status_code not in [200, 201, 202, 206]: + result = response.json() + tool_result = ToolResult( + content=[TextContent(type="text", text=str(result["error"]) if "error" in result else "Tool error encountered")], + is_error=True, + ) + # Don't mark as successful for error responses - success remains False + else: + result = response.json() + filtered_response = extract_using_jq(result, tool.jsonpath_filter) + tool_result = ToolResult(content=[TextContent(type="text", text=json.dumps(filtered_response, indent=2))]) success = True + # If output schema is present, validate and attach structured content if getattr(tool, "output_schema", None): valid = self._extract_and_validate_structured_content(tool, tool_result, candidate=filtered_response) success = bool(valid) + + elif tool.integration_type == "MCP": + transport = tool.request_type.lower() + gateway = db.execute(select(DbGateway).where(DbGateway.id == tool.gateway_id).where(DbGateway.enabled)).scalar_one_or_none() + + # Handle OAuth authentication for the gateway + if gateway and gateway.auth_type == "oauth" and gateway.oauth_config: + grant_type = gateway.oauth_config.get("grant_type", "client_credentials") + + if grant_type == "authorization_code": + # For Authorization Code flow, try to get stored tokens + try: + # First-Party + from mcpgateway.services.token_storage_service import TokenStorageService # pylint: disable=import-outside-toplevel + + token_storage = TokenStorageService(db) + + # Get user-specific OAuth token + if not app_user_email: + raise ToolInvocationError(f"User authentication required for OAuth-protected gateway '{gateway.name}'. Please ensure you are authenticated.") + + access_token = await token_storage.get_user_token(gateway.id, app_user_email) + + if access_token: + headers = {"Authorization": f"Bearer {access_token}"} + else: + # User hasn't authorized this gateway yet + raise ToolInvocationError(f"Please authorize {gateway.name} first. Visit /oauth/authorize/{gateway.id} to complete OAuth flow.") + except Exception as e: + logger.error(f"Failed to obtain stored OAuth token for gateway {gateway.name}: {e}") + raise ToolInvocationError(f"OAuth token retrieval failed for gateway: {str(e)}") + else: + # For Client Credentials flow, get token directly + try: + access_token = await self.oauth_manager.get_access_token(gateway.oauth_config) + headers = {"Authorization": f"Bearer {access_token}"} + except Exception as e: + logger.error(f"Failed to obtain OAuth access token for gateway {gateway.name}: {e}") + raise ToolInvocationError(f"OAuth authentication failed for gateway: {str(e)}") else: - tool_result = ToolResult(content=[TextContent(type="text", text="Invalid tool type")]) + headers = decode_auth(gateway.auth_value if gateway else None) + + # Get combined headers including gateway auth and passthrough + if request_headers: + headers = get_passthrough_headers(request_headers, headers, db, gateway) + + async def connect_to_sse_server(server_url: str, headers: dict = headers): + """Connect to an MCP server running with SSE transport. + + Args: + server_url: MCP Server SSE URL + headers: HTTP headers to include in the request + + Returns: + ToolResult: Result of tool call + """ + async with sse_client(url=server_url, headers=headers) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + tool_call_result = await session.call_tool(tool.original_name, arguments) + return tool_call_result + + async def connect_to_streamablehttp_server(server_url: str, headers: dict = headers): + """Connect to an MCP server running with Streamable HTTP transport. + + Args: + server_url: MCP Server URL + headers: HTTP headers to include in the request + + Returns: + ToolResult: Result of tool call + """ + async with streamablehttp_client(url=server_url, headers=headers) as (read_stream, write_stream, _get_session_id): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + tool_call_result = await session.call_tool(tool.original_name, arguments) + return tool_call_result + + tool_gateway_id = tool.gateway_id + tool_gateway = db.execute(select(DbGateway).where(DbGateway.id == tool_gateway_id).where(DbGateway.enabled)).scalar_one_or_none() - # Plugin hook: tool post-invoke if self._plugin_manager: - post_result, _ = await self._plugin_manager.invoke_hook( - ToolHookType.TOOL_POST_INVOKE, - payload=ToolPostInvokePayload(name=name, result=tool_result.model_dump(by_alias=True)), + tool_metadata = PydanticTool.model_validate(tool) + global_context.metadata[TOOL_METADATA] = tool_metadata + if tool_gateway: + gateway_metadata = PydanticGateway.model_validate(tool_gateway) + global_context.metadata[GATEWAY_METADATA] = gateway_metadata + pre_result, context_table = await self._plugin_manager.invoke_hook( + ToolHookType.TOOL_PRE_INVOKE, + payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(headers)), global_context=global_context, - local_contexts=context_table, + local_contexts=None, violations_as_exceptions=True, ) - # Use modified payload if provided - if post_result.modified_payload: - # Reconstruct ToolResult from modified result - modified_result = post_result.modified_payload.result - if isinstance(modified_result, dict) and "content" in modified_result: - tool_result = ToolResult(content=modified_result["content"]) - else: - # If result is not in expected format, convert it to text content - tool_result = ToolResult(content=[TextContent(type="text", text=str(modified_result))]) + if pre_result.modified_payload: + payload = pre_result.modified_payload + name = payload.name + arguments = payload.args + if payload.headers is not None: + headers = payload.headers.model_dump() + + tool_call_result = ToolResult(content=[TextContent(text="", type="text")]) + if transport == "sse": + tool_call_result = await connect_to_sse_server(tool_gateway.url, headers=headers) + elif transport == "streamablehttp": + tool_call_result = await connect_to_streamablehttp_server(tool_gateway.url, headers=headers) + content = tool_call_result.model_dump(by_alias=True).get("content", []) + + filtered_response = extract_using_jq(content, tool.jsonpath_filter) + tool_result = ToolResult(content=filtered_response) + success = True + # If output schema is present, validate and attach structured content + if getattr(tool, "output_schema", None): + valid = self._extract_and_validate_structured_content(tool, tool_result, candidate=filtered_response) + success = bool(valid) + else: + tool_result = ToolResult(content=[TextContent(type="text", text="Invalid tool type")]) + + # Plugin hook: tool post-invoke + if self._plugin_manager: + post_result, _ = await self._plugin_manager.tool_post_invoke( + payload=ToolPostInvokePayload(name=name, result=tool_result.model_dump(by_alias=True)), + global_context=global_context, + local_contexts=context_table, + violations_as_exceptions=True, + ) + # Use modified payload if provided + if post_result.modified_payload: + # Reconstruct ToolResult from modified result + modified_result = post_result.modified_payload.result + if isinstance(modified_result, dict) and "content" in modified_result: + tool_result = ToolResult(content=modified_result["content"]) + else: + # If result is not in expected format, convert it to text content + tool_result = ToolResult(content=[TextContent(type="text", text=str(modified_result))]) - return tool_result - except (PluginError, PluginViolationError): - raise - except Exception as e: - error_message = str(e) - # Set span error status - if span: - span.set_attribute("error", True) - span.set_attribute("error.message", str(e)) - raise ToolInvocationError(f"Tool invocation failed: {error_message}") - finally: - # Add final span attributes - if span: - span.set_attribute("success", success) - span.set_attribute("duration.ms", (time.monotonic() - start_time) * 1000) - await self._record_tool_metric(db, tool, start_time, success, error_message) + return tool_result + except (PluginError, PluginViolationError): + raise + except Exception as e: + error_message = str(e) + raise ToolInvocationError(f"Tool invocation failed: {error_message}") + finally: + # End span with appropriate status + if span_id and observability_service and not span_ended: + try: + duration_ms = (time.monotonic() - start_time) * 1000 + observability_service.end_span( + db=db, + span_id=span_id, + status="ok" if success else "error", + attributes={ + "success": success, + "duration.ms": duration_ms, + "error.message": error_message if error_message else None, + }, + ) + span_ended = True + except Exception as span_error: + logger.warning(f"Failed to end observability span: {span_error}") + await self._record_tool_metric(db, tool, start_time, success, error_message) async def update_tool( self, diff --git a/mcpgateway/static/flame-graph.css b/mcpgateway/static/flame-graph.css new file mode 100644 index 000000000..74275a155 --- /dev/null +++ b/mcpgateway/static/flame-graph.css @@ -0,0 +1,213 @@ +/** + * Flame Graph Visualization Styles + * + * Provides styling for interactive flame graph traces showing + * execution hierarchy and performance hotspots. + */ + +/* Container */ +.flame-graph-container { + width: 100%; + background: #fff; + border: 1px solid #e5e7eb; + border-radius: 0.5rem; + overflow: hidden; +} + +/* Toolbar */ +.flame-toolbar { + display: flex; + justify-content: space-between; + align-items: center; + padding: 1rem; + background: #f9fafb; + border-bottom: 1px solid #e5e7eb; +} + +.flame-info { + display: flex; + align-items: center; + gap: 0.5rem; +} + +.flame-controls { + display: flex; + align-items: center; + gap: 0.5rem; +} + +/* Search Input */ +.flame-search { + padding: 0.5rem 0.75rem; + border: 1px solid #d1d5db; + border-radius: 0.375rem; + font-size: 0.875rem; + min-width: 200px; + transition: border-color 0.15s ease-in-out; +} + +.flame-search:focus { + outline: none; + border-color: #3b82f6; + box-shadow: 0 0 0 3px rgb(59 130 246 / 10%); +} + +/* Buttons */ +.flame-btn { + padding: 0.5rem 0.75rem; + background: #fff; + border: 1px solid #d1d5db; + border-radius: 0.375rem; + font-size: 0.875rem; + cursor: pointer; + transition: all 0.15s ease-in-out; + display: flex; + align-items: center; + gap: 0.25rem; +} + +.flame-btn:hover { + background: #f3f4f6; + border-color: #9ca3af; +} + +.flame-btn:active { + background: #e5e7eb; +} + +/* SVG Canvas */ +.flame-svg { + display: block; + width: 100%; + background: #fff; + cursor: default; +} + +/* Flame Graph Nodes */ +.flame-node { + transition: opacity 0.2s ease-in-out; +} + +.flame-node:hover { + opacity: 0.9; +} + +.flame-node rect { + transition: stroke-width 0.2s ease-in-out; +} + +.flame-node:hover rect { + stroke-width: 2px; + stroke: #000; +} + +/* Search Match Highlighting */ +.flame-node.search-match rect { + stroke: #000; + stroke-width: 2px; + filter: brightness(1.1); +} + +/* Legend */ +.flame-legend { + display: flex; + align-items: center; + gap: 1.5rem; + padding: 1rem; + background: #f9fafb; + border-top: 1px solid #e5e7eb; + flex-wrap: wrap; +} + +.legend-item { + display: flex; + align-items: center; + gap: 0.5rem; + font-size: 0.875rem; + color: #4b5563; +} + +.legend-color { + width: 1rem; + height: 1rem; + border-radius: 0.25rem; + border: 1px solid #d1d5db; +} + +/* Responsive Design */ +@media (width <= 768px) { + .flame-toolbar { + flex-direction: column; + align-items: flex-start; + gap: 1rem; + } + + .flame-controls { + width: 100%; + flex-direction: column; + } + + .flame-search { + width: 100%; + } + + .flame-legend { + gap: 0.75rem; + padding: 0.75rem; + } + + .legend-item { + font-size: 0.75rem; + } +} + +/* Print Styles */ +@media print { + .flame-toolbar, + .flame-legend { + background: #fff; + } + + .flame-btn { + display: none; + } + + .flame-search { + display: none; + } +} + +/* Accessibility */ +.flame-node:focus { + outline: 2px solid #3b82f6; + outline-offset: 2px; +} + +/* Loading State */ +.flame-graph-container.loading { + position: relative; + min-height: 400px; +} + +.flame-graph-container.loading::after { + content: "Loading flame graph..."; + position: absolute; + top: 50%; + left: 50%; + transform: translate(-50%, -50%); + color: #6b7280; + font-size: 0.875rem; +} + +/* Empty State */ +.flame-graph-container.empty { + padding: 2rem; + text-align: center; + color: #6b7280; +} + +/* Tooltip Styles (for SVG title fallback) */ +.flame-node title { + font-family: system-ui, -apple-system, sans-serif; + font-size: 12px; +} diff --git a/mcpgateway/static/flame-graph.js b/mcpgateway/static/flame-graph.js new file mode 100644 index 000000000..4620ffe67 --- /dev/null +++ b/mcpgateway/static/flame-graph.js @@ -0,0 +1,340 @@ +/** + * Interactive Flame Graph for Trace Visualization + * + * Features: + * - Stack-based visualization showing execution hierarchy + * - Click to zoom into specific spans + * - Search and highlight functionality + * - Hover tooltips with span details + * - Color-coding by span type + */ + +/* eslint-disable no-unused-vars */ +class FlameGraph { + constructor(containerId, traceData) { + this.container = document.getElementById(containerId); + this.trace = traceData; + this.spans = this.buildSpanTree(traceData.spans); + this.width = 0; + this.height = 0; + this.cellHeight = 20; + this.textPadding = 5; + this.rootNode = null; + this.currentRoot = null; + this.searchTerm = ""; + + this.init(); + } + + /** + * Build hierarchical span tree from flat span list + */ + buildSpanTree(spans) { + const spanMap = new Map(); + let rootSpan = null; + + // Create map of all spans + spans.forEach((span) => { + spanMap.set(span.span_id, { + ...span, + children: [], + }); + }); + + // Build tree structure + spans.forEach((span) => { + const node = spanMap.get(span.span_id); + if (span.parent_span_id && spanMap.has(span.parent_span_id)) { + const parent = spanMap.get(span.parent_span_id); + parent.children.push(node); + } else { + // This is a root node + if (!rootSpan || node.start_time < rootSpan.start_time) { + rootSpan = node; + } + } + }); + + // Calculate total duration for each node (self + children) + const calculateTotalDuration = (node) => { + const total = node.duration_ms || 0; + node.children.forEach((child) => { + calculateTotalDuration(child); + }); + node.totalDuration = total; + return total; + }; + + if (rootSpan) { + calculateTotalDuration(rootSpan); + } + + return rootSpan; + } + + /** + * Initialize the flame graph + */ + init() { + this.rootNode = this.spans; + this.currentRoot = this.rootNode; + this.render(); + } + + /** + * Render the complete flame graph + */ + render() { + if (!this.rootNode) { + this.container.innerHTML = ` +
+ No span data available for flame graph +
+ `; + return; + } + + // Calculate dimensions + this.width = this.container.clientWidth || 800; + const depth = this.calculateDepth(this.currentRoot); + this.height = depth * this.cellHeight + 100; + + const html = ` +
+ +
+
+ Flame Graph + + ${this.currentRoot.name} - ${this.currentRoot.duration_ms?.toFixed(2) || 0} ms + +
+
+ + +
+
+ + + + ${this.renderNode(this.currentRoot, 0, 0, this.width)} + + + +
+
+ + Client +
+
+ + Server +
+
+ + Internal +
+
+ + Error +
+
+ πŸ’‘ Click on any span to zoom in +
+
+
+ `; + + this.container.innerHTML = html; + } + + /** + * Calculate the maximum depth of the tree + */ + calculateDepth(node, currentDepth = 0) { + if (!node || !node.children || node.children.length === 0) { + return currentDepth + 1; + } + + let maxDepth = currentDepth + 1; + node.children.forEach((child) => { + const childDepth = this.calculateDepth(child, currentDepth + 1); + maxDepth = Math.max(maxDepth, childDepth); + }); + + return maxDepth; + } + + /** + * Render a single node and its children recursively + */ + renderNode(node, x, y, width, parentDuration = null) { + if (!node) { + return ""; + } + + const duration = node.duration_ms || 0; + const totalParentDuration = parentDuration || duration; + + // Calculate width based on duration percentage + const nodeWidth = (duration / totalParentDuration) * width; + + if (nodeWidth < 0.5) { + return ""; // Too small to render + } + + // Determine color based on span kind and search + const isSearchMatch = + this.searchTerm && + node.name.toLowerCase().includes(this.searchTerm.toLowerCase()); + let color = this.getSpanColor(node); + + if (isSearchMatch) { + color = "#f59e0b"; // Highlight search matches in orange + } + + // Truncate text if too long for the box + const availableWidth = nodeWidth - this.textPadding * 2; + const charWidth = 7; // Approximate character width + const maxChars = Math.floor(availableWidth / charWidth); + let displayText = node.name; + + if (displayText.length > maxChars && maxChars > 3) { + displayText = displayText.substring(0, maxChars - 3) + "..."; + } + + // Generate SVG rectangle with text + let svg = ` + + + ${node.name}\nDuration: ${duration.toFixed(2)}ms\nKind: ${node.kind}\nStatus: ${node.status} + ${ + nodeWidth > 30 + ? `${displayText} (${duration.toFixed(1)}ms)` + : "" + } + + `; + + // Render children below this node + if (node.children && node.children.length > 0) { + let childX = x; + + node.children.forEach((child) => { + const childDuration = child.duration_ms || 0; + const childWidth = (childDuration / duration) * nodeWidth; + + svg += this.renderNode( + child, + childX, + y + this.cellHeight, + childWidth, + duration, + ); + + childX += childWidth; + }); + } + + return svg; + } + + /** + * Get color for a span based on its kind and status + */ + getSpanColor(span) { + if (span.status === "error") { + return "#ef4444"; // red + } + + switch (span.kind) { + case "client": + return "#3b82f6"; // blue + case "server": + return "#10b981"; // green + case "internal": + return "#8b5cf6"; // purple + default: + return "#6b7280"; // gray + } + } + + /** + * Find node by span_id + */ + findNode(node, spanId) { + if (node.span_id === spanId) { + return node; + } + + if (node.children) { + for (const child of node.children) { + const found = this.findNode(child, spanId); + if (found) { + return found; + } + } + } + + return null; + } + + /** + * Zoom to a specific node + */ + zoomTo(spanId) { + const node = this.findNode(this.rootNode, spanId); + if (node) { + this.currentRoot = node; + this.render(); + } + } + + /** + * Reset zoom to root + */ + reset() { + this.currentRoot = this.rootNode; + this.searchTerm = ""; + this.render(); + } + + /** + * Search for spans by name + */ + search(term) { + this.searchTerm = term; + this.render(); + } +} + +// Global instance (will be initialized from template) +// eslint-disable-next-line prefer-const +let flameGraph = null; diff --git a/mcpgateway/static/gantt-chart.css b/mcpgateway/static/gantt-chart.css new file mode 100644 index 000000000..5c0b58093 --- /dev/null +++ b/mcpgateway/static/gantt-chart.css @@ -0,0 +1,273 @@ +/** + * Gantt Chart Styles + */ + +.gantt-container { + background: white; + border-radius: 0.5rem; + box-shadow: 0 1px 3px rgb(0 0 0 / 10%); + overflow: hidden; + margin: 1rem 0; +} + +/* Toolbar */ +.gantt-toolbar { + display: flex; + justify-content: space-between; + align-items: center; + padding: 1rem; + background: #f9fafb; + border-bottom: 1px solid #e5e7eb; +} + +.gantt-info { + font-size: 0.875rem; + color: #374151; +} + +.gantt-controls { + display: flex; + gap: 0.5rem; +} + +.gantt-btn { + padding: 0.5rem 0.75rem; + background: white; + border: 1px solid #d1d5db; + border-radius: 0.375rem; + cursor: pointer; + font-size: 0.875rem; + transition: all 0.2s; +} + +.gantt-btn:hover { + background: #f3f4f6; + border-color: #9ca3af; +} + +.gantt-btn:active { + transform: scale(0.95); +} + +/* Time Scale */ +.gantt-timescale { + position: relative; + height: 40px; + background: #fafafa; + border-bottom: 2px solid #e5e7eb; + margin-left: 250px; /* Offset for span names */ + margin-right: 100px; /* Offset for duration column */ +} + +.time-marker { + position: absolute; + top: 0; + height: 100%; +} + +.time-tick { + width: 1px; + height: 8px; + background: #9ca3af; + margin-left: -0.5px; +} + +.time-label { + font-size: 0.75rem; + color: #6b7280; + margin-top: 0.25rem; + margin-left: -15px; + font-family: monospace; +} + +/* Spans Container */ +.gantt-spans { + max-height: 600px; + overflow-y: auto; +} + +.span-row { + display: flex; + align-items: center; + padding: 0.5rem 0; + border-bottom: 1px solid #f3f4f6; + transition: background 0.2s; +} + +.span-row:hover { + background: #f9fafb; +} + +.span-row.critical-path-row { + background: #fef2f2; +} + +.span-row.critical-path-row:hover { + background: #fee2e2; +} + +/* Span Name Column */ +.span-name { + width: 250px; + font-size: 0.875rem; + color: #374151; + display: flex; + align-items: center; + gap: 0.25rem; + padding-right: 0.5rem; + flex-shrink: 0; +} + +.span-toggle { + background: none; + border: none; + cursor: pointer; + font-size: 0.75rem; + color: #6b7280; + padding: 0.125rem 0.25rem; + width: 20px; + text-align: center; +} + +.span-toggle:hover { + color: #374151; + background: #f3f4f6; + border-radius: 0.25rem; +} + +.span-spacer { + display: inline-block; + width: 20px; +} + +.span-label { + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +/* Timeline Column */ +.span-timeline { + flex: 1; + height: 32px; + position: relative; + background: #f9fafb; + border-radius: 0.25rem; + margin: 0 0.5rem; +} + +.span-bar { + position: absolute; + height: 100%; + border-radius: 0.25rem; + display: flex; + align-items: center; + justify-content: center; + cursor: pointer; + transition: all 0.2s; + box-shadow: 0 1px 2px rgb(0 0 0 / 10%); +} + +.span-bar:hover { + transform: scaleY(1.1); + box-shadow: 0 2px 4px rgb(0 0 0 / 20%); + z-index: 10; +} + +.span-bar.critical-path-bar { + border: 2px solid #dc2626; + box-shadow: 0 0 0 2px rgb(220 38 38 / 20%); +} + +.span-bar-label { + font-size: 0.75rem; + color: white; + font-weight: 500; + text-shadow: 0 1px 2px rgb(0 0 0 / 20%); + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + padding: 0 0.5rem; +} + +/* Duration Column */ +.span-duration { + width: 100px; + text-align: right; + font-size: 0.875rem; + color: #6b7280; + font-family: monospace; + padding-right: 1rem; + flex-shrink: 0; +} + +/* Legend */ +.gantt-legend { + display: flex; + gap: 1.5rem; + padding: 1rem; + background: #f9fafb; + border-top: 1px solid #e5e7eb; + font-size: 0.875rem; + justify-content: center; +} + +.legend-item { + display: flex; + align-items: center; + gap: 0.5rem; +} + +.legend-color { + width: 20px; + height: 12px; + border-radius: 0.25rem; +} + +.legend-color.critical-path { + background: white; + border: 2px solid #dc2626; +} + +/* Responsive */ +@media (width <= 768px) { + .gantt-toolbar { + flex-direction: column; + gap: 1rem; + align-items: stretch; + } + + .gantt-controls { + justify-content: center; + } + + .span-name { + width: 150px; + } + + .gantt-timescale { + margin-left: 150px; + } + + .span-duration { + width: 80px; + } +} + +/* Scrollbar styling for Gantt spans */ +.gantt-spans::-webkit-scrollbar { + width: 8px; +} + +.gantt-spans::-webkit-scrollbar-track { + background: #f1f5f9; +} + +.gantt-spans::-webkit-scrollbar-thumb { + background: #cbd5e1; + border-radius: 4px; +} + +.gantt-spans::-webkit-scrollbar-thumb:hover { + background: #94a3b8; +} diff --git a/mcpgateway/static/gantt-chart.js b/mcpgateway/static/gantt-chart.js new file mode 100644 index 000000000..02b8212b2 --- /dev/null +++ b/mcpgateway/static/gantt-chart.js @@ -0,0 +1,388 @@ +/** + * Interactive Gantt Chart for Trace Visualization + * + * Features: + * - Hierarchical span tree with expand/collapse + * - Interactive zoom and pan + * - Time scale with markers + * - Critical path highlighting + * - Keyboard shortcuts + * - Hover tooltips + */ + +/* eslint-disable no-unused-vars */ +class GanttChart { + constructor(containerId, traceData) { + this.container = document.getElementById(containerId); + this.trace = traceData; + this.zoomLevel = 1; + this.panOffset = 0; + this.collapsedSpans = new Set(); + this.spans = this.buildSpanTree(traceData.spans); + this.criticalPath = this.calculateCriticalPath(); + + this.init(); + } + + /** + * Build hierarchical span tree from flat span list + */ + buildSpanTree(spans) { + const spanMap = new Map(); + const roots = []; + + // Create map of all spans + spans.forEach((span) => { + spanMap.set(span.span_id, { + ...span, + children: [], + depth: 0, + }); + }); + + // Build tree structure + spans.forEach((span) => { + const node = spanMap.get(span.span_id); + if (span.parent_span_id && spanMap.has(span.parent_span_id)) { + const parent = spanMap.get(span.parent_span_id); + parent.children.push(node); + node.depth = parent.depth + 1; + } else { + roots.push(node); + } + }); + + // Flatten tree for rendering (depth-first) + const flatten = (node) => { + const result = [node]; + if (!this.collapsedSpans.has(node.span_id)) { + node.children.forEach((child) => { + result.push(...flatten(child)); + }); + } + return result; + }; + + return roots.flatMap(flatten); + } + + /** + * Calculate critical path (slowest sequential chain) + */ + calculateCriticalPath() { + const criticalSpans = new Set(); + + const findCriticalPath = (spans) => { + if (spans.length === 0) { + return []; + } + + // Find span with longest duration + children duration + let maxPath = []; + let maxDuration = 0; + + spans.forEach((span) => { + const childPath = findCriticalPath(span.children); + const totalDuration = + (span.duration_ms || 0) + + childPath.reduce((sum, s) => sum + (s.duration_ms || 0), 0); + + if (totalDuration > maxDuration) { + maxDuration = totalDuration; + maxPath = [span, ...childPath]; + } + }); + + return maxPath; + }; + + const roots = this.spans.filter((s) => !s.parent_span_id); + const path = findCriticalPath(roots); + path.forEach((span) => criticalSpans.add(span.span_id)); + + return criticalSpans; + } + + /** + * Initialize the chart + */ + init() { + this.render(); + this.attachEventListeners(); + } + + /** + * Render the complete chart + */ + render() { + const totalDuration = this.trace.duration_ms || 1; + const traceStart = new Date(this.trace.start_time); + + const html = ` +
+ +
+
+ Total Duration: ${totalDuration.toFixed(2)} ms + + ${this.spans.length} spans + +
+
+ + + + + +
+
+ + +
+ ${this.renderTimeScale(totalDuration)} +
+ + +
+ ${this.spans.map((span) => this.renderSpan(span, totalDuration, traceStart)).join("")} +
+ + +
+
+ + Client +
+
+ + Server +
+
+ + Internal +
+
+ + Error +
+
+ + Critical Path +
+
+
+ `; + + this.container.innerHTML = html; + } + + /** + * Render time scale markers + */ + renderTimeScale(totalDuration) { + const markers = []; + const step = this.calculateTimeStep(totalDuration); + + for (let t = 0; t <= totalDuration; t += step) { + const percent = (t / totalDuration) * 100; + markers.push(` +
+
+
${t.toFixed(0)}ms
+
+ `); + } + + return markers.join(""); + } + + /** + * Calculate appropriate time step for markers + */ + calculateTimeStep(totalDuration) { + if (totalDuration < 10) { + return 1; + } + if (totalDuration < 50) { + return 5; + } + if (totalDuration < 100) { + return 10; + } + if (totalDuration < 500) { + return 50; + } + if (totalDuration < 1000) { + return 100; + } + if (totalDuration < 5000) { + return 500; + } + return 1000; + } + + /** + * Render individual span row + */ + renderSpan(span, totalDuration, traceStart) { + const duration = span.duration_ms || 0; + const startMs = new Date(span.start_time) - traceStart; + const leftPercent = + (startMs / totalDuration) * 100 * this.zoomLevel + this.panOffset; + const widthPercent = (duration / totalDuration) * 100 * this.zoomLevel; + + const hasChildren = span.children && span.children.length > 0; + const isCollapsed = this.collapsedSpans.has(span.span_id); + const isCritical = this.criticalPath.has(span.span_id); + + // Determine color based on span kind + let color = "#3b82f6"; // client (blue) + if (span.kind === "server") { + color = "#10b981"; // green + } + if (span.kind === "internal") { + color = "#8b5cf6"; // purple + } + if (span.status === "error") { + color = "#ef4444"; // red + } + + const indentPx = span.depth * 20; + + return ` +
+
+ ${ + hasChildren + ? ` + + ` + : '' + } + + ${span.name} + +
+
+
+ ${widthPercent > 5 ? `${duration.toFixed(1)}ms` : ""} +
+
+
${duration.toFixed(2)} ms
+
+ `; + } + + /** + * Toggle span expand/collapse + */ + toggleSpan(spanId) { + if (this.collapsedSpans.has(spanId)) { + this.collapsedSpans.delete(spanId); + } else { + this.collapsedSpans.add(spanId); + } + this.spans = this.buildSpanTree(this.trace.spans); + this.render(); + } + + /** + * Expand all spans + */ + expandAll() { + this.collapsedSpans.clear(); + this.spans = this.buildSpanTree(this.trace.spans); + this.render(); + } + + /** + * Collapse all spans to top level + */ + collapseAll() { + this.trace.spans.forEach((span) => { + if (span.parent_span_id) { + this.collapsedSpans.add(span.parent_span_id); + } + }); + this.spans = this.buildSpanTree(this.trace.spans); + this.render(); + } + + /** + * Zoom in + */ + zoomIn() { + this.zoomLevel = Math.min(this.zoomLevel * 1.5, 10); + this.render(); + } + + /** + * Zoom out + */ + zoomOut() { + this.zoomLevel = Math.max(this.zoomLevel / 1.5, 0.1); + this.render(); + } + + /** + * Reset zoom and pan + */ + resetZoom() { + this.zoomLevel = 1; + this.panOffset = 0; + this.render(); + } + + /** + * Show detailed span information + */ + showSpanDetails(spanId) { + const span = this.trace.spans.find((s) => s.span_id === spanId); + if (!span) { + return; + } + + alert( + `Span Details:\n\nName: ${span.name}\nDuration: ${span.duration_ms}ms\nKind: ${span.kind}\nStatus: ${span.status}\n\nAttributes:\n${JSON.stringify(span.attributes, null, 2)}`, + ); + } + + /** + * Attach keyboard and mouse event listeners + */ + attachEventListeners() { + document.addEventListener("keydown", (e) => { + if (!this.container.isConnected) { + return; + } + + switch (e.key) { + case "=": + case "+": + this.zoomIn(); + e.preventDefault(); + break; + case "-": + case "_": + this.zoomOut(); + e.preventDefault(); + break; + case "0": + this.resetZoom(); + e.preventDefault(); + break; + } + }); + } +} + +// Global instance (will be initialized from template) +// eslint-disable-next-line prefer-const +let ganttChart = null; diff --git a/mcpgateway/templates/admin.html b/mcpgateway/templates/admin.html index 91b53d1d2..daf59a761 100644 --- a/mcpgateway/templates/admin.html +++ b/mcpgateway/templates/admin.html @@ -54,6 +54,8 @@ > + + + + + + +
+ +
+

+ πŸ” + Observability Dashboard +

+
+ + + + + + +
+
+ + +
+ + + + + +
+
+ + +
+ +
+

+ πŸ”§ Advanced Filters +

+
+ +
+ + +
+
+ + +
+ + +
+ + +
+ + +
+ + +
+ + +
+ + +
+ + +
+ + +
+ + +
+ + +
+
+ + +
+ +
+
+ + +
+
+
+
Loading statistics...
+
+
+
+ + +
+ + + + + + + + + + + + + + + + + +
TimestampMethodEndpointStatusDurationUserActions
+ Loading traces... +
+
+ + +
+
+ +
+
+ + +
+
+
+

πŸ’Ύ Save Query

+ +
+ +
+ +
+ + +
+ + +
+ + +
+ + +
+ + +
+ + +
+

Current Filters

+
+
Time:
+
Status:
+
Min Duration:
+
Max Duration:
+
HTTP Method:
+
User:
+
Name:
+
Attributes:
+
+
+ + +
+ + +
+
+
+
+
+ + +
+
+ Loading metrics dashboard... +
+
+ + +
+
+ Loading tool metrics dashboard... +
+
+ + +
+
+ Loading prompt metrics dashboard... +
+
+ + +
+
+ Loading resource metrics dashboard... +
+
+ diff --git a/mcpgateway/templates/observability_prompts.html b/mcpgateway/templates/observability_prompts.html new file mode 100644 index 000000000..fc4443b8b --- /dev/null +++ b/mcpgateway/templates/observability_prompts.html @@ -0,0 +1,502 @@ + + + +
+ +
+
+
+ + +
+
+ + +
+
+ +
+
+
+ Loading prompt metrics... +
+
+
+ + +
+ +
+
+
+

Overall Health

+

+
+
+
+ + +
+
+
+

Most Rendered

+

+

+
+
πŸ’¬
+
+
+ + +
+
+
+

Slowest Prompt

+

+

+
+
🐌
+
+
+ + +
+
+
+

Most Error-Prone

+

+

+
+
⚠️
+
+
+
+ + +
+
+ +
+
+ + +
+
+ +
+
+ + +
+
+ +
+
+ + +
+

Prompt Performance Metrics

+
+ + + + + + + + + + + + + + + + +
#Prompt IDCountAvgMinp50p90p95p99Max
+
+
+
diff --git a/mcpgateway/templates/observability_resources.html b/mcpgateway/templates/observability_resources.html new file mode 100644 index 000000000..75472a208 --- /dev/null +++ b/mcpgateway/templates/observability_resources.html @@ -0,0 +1,502 @@ + + + +
+ +
+
+
+ + +
+
+ + +
+
+ +
+
+
+ Loading resource metrics... +
+
+
+ + +
+ +
+
+
+

Overall Health

+

+
+
+
+ + +
+
+
+

Most Fetched

+

+

+
+
πŸ“¦
+
+
+ + +
+
+
+

Slowest Resource

+

+

+
+
🐌
+
+
+ + +
+
+
+

Most Error-Prone

+

+

+
+
⚠️
+
+
+
+ + +
+
+ +
+
+ + +
+
+ +
+
+ + +
+
+ +
+
+ + +
+

Resource Performance Metrics

+
+ + + + + + + + + + + + + + + + +
#Resource URICountAvgMinp50p90p95p99Max
+
+
+
diff --git a/mcpgateway/templates/observability_stats.html b/mcpgateway/templates/observability_stats.html new file mode 100644 index 000000000..ade553110 --- /dev/null +++ b/mcpgateway/templates/observability_stats.html @@ -0,0 +1,19 @@ + +
+
+
Total Requests
+
{{ stats.total_traces }}
+
+
+
Success Rate
+
{{ "%.1f"|format(stats.success_count / stats.total_traces * 100 if stats.total_traces > 0 else 0) }}%
+
+
+
Error Count
+
{{ stats.error_count }}
+
+
+
Avg Response Time
+
{{ "%.0f"|format(stats.avg_duration_ms) }}ms
+
+
diff --git a/mcpgateway/templates/observability_tools.html b/mcpgateway/templates/observability_tools.html new file mode 100644 index 000000000..7b937bd37 --- /dev/null +++ b/mcpgateway/templates/observability_tools.html @@ -0,0 +1,591 @@ + + + +
+ +
+
+
+ + +
+
+ + +
+
+ +
+
+
+ Loading tool metrics... +
+
+
+ + +
+ +
+
+
+

Overall Health

+

+
+
+
+ + +
+
+
+

Most Used Tool

+

+

+
+
πŸ“Š
+
+
+ + +
+
+
+

Slowest Tool

+

+

+
+
🐌
+
+
+ + +
+
+
+

Most Error-Prone

+

+

+
+
⚠️
+
+
+
+ + +
+
+ +
+
+ + +
+
+ +
+
+ + +
+
+ +
+
+ + +
+

Tool Performance Metrics

+
+ + + + + + + + + + + + + + + + +
#Tool NameCountAvgMinp50p90p95p99Max
+
+
+ + +
+

Tool Error Rates

+
+ + + + + + + + + + + +
#Tool NameTotal CountError CountError Rate
+
+
+ + +
+

Common Tool Chains

+

Tools frequently invoked together in the same trace

+
+ + + + + + + + + +
#Tool ChainFrequency
+
+
+
diff --git a/mcpgateway/templates/observability_trace_detail.html b/mcpgateway/templates/observability_trace_detail.html new file mode 100644 index 000000000..a60a4e4f8 --- /dev/null +++ b/mcpgateway/templates/observability_trace_detail.html @@ -0,0 +1,207 @@ + +
+
+

Trace Details

+

+ {{ trace.trace_id }} +

+
+ +
+ +
+ +
+
+

REQUEST

+

Method: {{ trace.http_method }}

+

URL: {{ trace.http_url }}

+

Status: + + {% if trace.status == 'ok' %}βœ“{% else %}βœ—{% endif %} {{ trace.http_status_code }} + +

+
+
+

TIMING

+

Start: {{ trace.start_time.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] }}

+

End: {{ trace.end_time.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] if trace.end_time else 'In Progress' }}

+

Duration: + + {{ "%.2f"|format(trace.duration_ms) }} ms + +

+
+
+

USER CONTEXT

+

User: {{ trace.user_email or 'anonymous' }}

+

IP: {{ trace.ip_address or 'N/A' }}

+

User Agent: {{ (trace.user_agent[:50] + '...') if trace.user_agent and trace.user_agent|length > 50 else (trace.user_agent or 'N/A') }}

+
+ {% if trace.status_message %} +
+

STATUS MESSAGE

+

{{ trace.status_message }}

+
+ {% endif %} +
+ + + {% if trace.spans %} +
+
+

Execution Timeline

+
+ + + +
+
+ + +
+ + +
+ + +
+ {% set total_duration = trace.duration_ms %} + {% set trace_start = trace.start_time %} + + {% for span in trace.spans|sort(attribute='start_time') %} +
+
+ {{ ' ' * (0 if not span.parent_span_id else 1) }}{{ span.name }} +
+
+ {% set span_duration_ms = span.duration_ms or 0 %} + {% set span_offset_ms = ((span.start_time - trace_start).total_seconds() * 1000) %} + {% set left_percent = (span_offset_ms / total_duration * 100) if total_duration > 0 else 0 %} + {% set width_percent = (span_duration_ms / total_duration * 100) if total_duration > 0 else 0 %} + +
+ {% if width_percent > 10 %}{{ span.name }}{% endif %} +
+
+
{{ "%.2f"|format(span_duration_ms) }} ms
+
+ + + {% if span.events %} + {% for event in span.events %} +
+ + {{ event.name }} + {% if event.message %}: {{ event.message }}{% endif %} + +
+ {% endfor %} + {% endif %} + {% endfor %} +
+
+ + + + + + + {% else %} +

No spans recorded for this trace

+ {% endif %} + + + {% if trace.attributes %} +

Additional Attributes

+
+
{{ trace.attributes|tojson(indent=2) }}
+
+ {% endif %} +
diff --git a/mcpgateway/templates/observability_traces_list.html b/mcpgateway/templates/observability_traces_list.html new file mode 100644 index 000000000..afc9d7a31 --- /dev/null +++ b/mcpgateway/templates/observability_traces_list.html @@ -0,0 +1,39 @@ + +{% if traces %} + {% for trace in traces %} + + {{ trace.start_time.strftime('%Y-%m-%d %H:%M:%S') }} + {{ trace.http_method or 'N/A' }} + {{ trace.name }} + + + {% if trace.status == 'ok' %}βœ“ {{ trace.http_status_code or 'OK' }}{% else %}βœ— {{ trace.http_status_code or 'ERROR' }}{% endif %} + + + + + {{ "%.2f"|format(trace.duration_ms) if trace.duration_ms else 'N/A' }} ms + + + {{ trace.user_email or 'anonymous' }} + + + + + {% endfor %} +{% else %} + + +
+
πŸ“Š
+

No traces found

+

Make some API requests to see observability data

+
+ + +{% endif %} diff --git a/scripts/cleanup-dev.sh b/scripts/cleanup-dev.sh new file mode 100755 index 000000000..d7522d2dd --- /dev/null +++ b/scripts/cleanup-dev.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# Cleanup script for development environment +# Kills all running servers and cleans up database locks + +set -e + +echo "🧹 Cleaning up development environment..." + +# Kill all running server processes +echo " Stopping all server processes..." +pkill -9 -f "uvicorn" 2>/dev/null || true +pkill -9 -f "python.*mcpgateway" 2>/dev/null || true +pkill -9 -f "make dev" 2>/dev/null || true + +# Kill processes on port 8000 +echo " Freeing port 8000..." +lsof -ti:8000 | xargs kill -9 2>/dev/null || true + +sleep 2 + +# Clean up SQLite WAL files +echo " Removing SQLite lock files..." +rm -f mcp.db-shm mcp.db-wal + +echo "βœ“ Cleanup complete!" +echo "" +echo "You can now run: make dev" diff --git a/tests/unit/mcpgateway/db/test_observability_migrations.py b/tests/unit/mcpgateway/db/test_observability_migrations.py new file mode 100644 index 000000000..4e887d199 --- /dev/null +++ b/tests/unit/mcpgateway/db/test_observability_migrations.py @@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/db/test_observability_migrations.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti + +Unit tests for observability Alembic migrations. + +Tests verify: +- Migration modules can be imported +- Upgrade and downgrade functions exist +- Migration revision IDs are correct +- Dependencies are properly defined +- No syntax errors in migration code +- Cross-database SQL compatibility +""" + +# Standard +import importlib +import inspect as pyinspect +import re + +# Third-Party +import pytest + + +# Migration module information +OBSERVABILITY_MIGRATIONS = [ + { + "module": "mcpgateway.alembic.versions.a23a08d61eb0_add_observability_tables", + "revision": "a23a08d61eb0", + "down_revision": "a706a3320c56", + "description": "add_observability_tables", + }, + { + "module": "mcpgateway.alembic.versions.i3c4d5e6f7g8_add_observability_performance_indexes", + "revision": "i3c4d5e6f7g8", + "down_revision": "a23a08d61eb0", + "description": "add observability performance indexes", + }, + { + "module": "mcpgateway.alembic.versions.j4d5e6f7g8h9_add_observability_saved_queries", + "revision": "j4d5e6f7g8h9", + "down_revision": "i3c4d5e6f7g8", + "description": "add observability saved queries", + }, +] + + +class TestObservabilityMigrationModules: + """Test that all observability migration modules are valid.""" + + @pytest.mark.parametrize("migration_info", OBSERVABILITY_MIGRATIONS) + def test_migration_module_imports(self, migration_info): + """Test that migration module can be imported.""" + module_name = migration_info["module"] + + try: + module = importlib.import_module(module_name) + assert module is not None, f"Module {module_name} imported as None" + except ImportError as e: + pytest.fail(f"Failed to import {module_name}: {e}") + + @pytest.mark.parametrize("migration_info", OBSERVABILITY_MIGRATIONS) + def test_migration_has_upgrade_function(self, migration_info): + """Test that migration has an upgrade() function.""" + module_name = migration_info["module"] + module = importlib.import_module(module_name) + + assert hasattr(module, "upgrade"), f"{module_name} missing upgrade() function" + assert callable(module.upgrade), f"{module_name}.upgrade is not callable" + + @pytest.mark.parametrize("migration_info", OBSERVABILITY_MIGRATIONS) + def test_migration_has_downgrade_function(self, migration_info): + """Test that migration has a downgrade() function.""" + module_name = migration_info["module"] + module = importlib.import_module(module_name) + + assert hasattr(module, "downgrade"), f"{module_name} missing downgrade() function" + assert callable(module.downgrade), f"{module_name}.downgrade is not callable" + + @pytest.mark.parametrize("migration_info", OBSERVABILITY_MIGRATIONS) + def test_migration_revision_id_correct(self, migration_info): + """Test that migration has correct revision ID.""" + module_name = migration_info["module"] + expected_revision = migration_info["revision"] + + module = importlib.import_module(module_name) + + assert hasattr(module, "revision"), f"{module_name} missing revision variable" + assert module.revision == expected_revision, f"{module_name} has incorrect revision: {module.revision} != {expected_revision}" + + @pytest.mark.parametrize("migration_info", OBSERVABILITY_MIGRATIONS) + def test_migration_down_revision_correct(self, migration_info): + """Test that migration has correct down_revision.""" + module_name = migration_info["module"] + expected_down_revision = migration_info["down_revision"] + + module = importlib.import_module(module_name) + + assert hasattr(module, "down_revision"), f"{module_name} missing down_revision variable" + assert module.down_revision == expected_down_revision, f"{module_name} has incorrect down_revision: {module.down_revision} != {expected_down_revision}" + + @pytest.mark.parametrize("migration_info", OBSERVABILITY_MIGRATIONS) + def test_migration_functions_have_no_parameters(self, migration_info): + """Test that upgrade() and downgrade() accept no parameters.""" + module_name = migration_info["module"] + module = importlib.import_module(module_name) + + # Check upgrade function signature + upgrade_sig = pyinspect.signature(module.upgrade) + assert len(upgrade_sig.parameters) == 0, f"{module_name}.upgrade() should have no parameters" + + # Check downgrade function signature + downgrade_sig = pyinspect.signature(module.downgrade) + assert len(downgrade_sig.parameters) == 0, f"{module_name}.downgrade() should have no parameters" + + +class TestObservabilityTablesMigration: + """Test migration a23a08d61eb0 (add observability tables).""" + + def test_creates_four_tables(self): + """Test that migration creates 4 observability tables.""" + module = importlib.import_module("mcpgateway.alembic.versions.a23a08d61eb0_add_observability_tables") + + # Get source code + source = pyinspect.getsource(module.upgrade) + + # Count create_table calls + create_table_count = source.count("op.create_table") + assert create_table_count == 4, f"Expected 4 create_table calls, found {create_table_count}" + + # Verify table names + assert "observability_traces" in source + assert "observability_spans" in source + assert "observability_events" in source + assert "observability_metrics" in source + + def test_downgrade_drops_four_tables(self): + """Test that downgrade drops all 4 tables.""" + module = importlib.import_module("mcpgateway.alembic.versions.a23a08d61eb0_add_observability_tables") + + source = pyinspect.getsource(module.downgrade) + + drop_table_count = source.count("op.drop_table") + assert drop_table_count == 4, f"Expected 4 drop_table calls, found {drop_table_count}" + + def test_uses_datetime_with_timezone(self): + """Test that DateTime columns use timezone=True.""" + module = importlib.import_module("mcpgateway.alembic.versions.a23a08d61eb0_add_observability_tables") + + source = pyinspect.getsource(module.upgrade) + + # Should use DateTime(timezone=True) + assert "DateTime(timezone=True)" in source, "Missing DateTime(timezone=True)" + + def test_uses_json_column_type(self): + """Test that JSON columns are used for attributes.""" + module = importlib.import_module("mcpgateway.alembic.versions.a23a08d61eb0_add_observability_tables") + + source = pyinspect.getsource(module.upgrade) + + # Should use sa.JSON() + assert "sa.JSON()" in source, "Missing sa.JSON() column type" + + def test_foreign_keys_have_cascade_delete(self): + """Test that foreign keys have CASCADE delete.""" + module = importlib.import_module("mcpgateway.alembic.versions.a23a08d61eb0_add_observability_tables") + + source = pyinspect.getsource(module.upgrade) + + # Should have ondelete="CASCADE" + assert 'ondelete="CASCADE"' in source, "Missing CASCADE delete on foreign keys" + + +class TestObservabilityPerformanceIndexes: + """Test migration i3c4d5e6f7g8 (add performance indexes).""" + + def test_uses_op_create_index_not_raw_sql(self): + """Test that migration uses op.create_index() instead of raw SQL.""" + module = importlib.import_module("mcpgateway.alembic.versions.i3c4d5e6f7g8_add_observability_performance_indexes") + + source = pyinspect.getsource(module.upgrade) + + # Should use op.create_index + assert "op.create_index" in source, "Missing op.create_index calls" + + # Should NOT use raw SQL with IF NOT EXISTS + assert "CREATE INDEX IF NOT EXISTS" not in source, "Should not use raw SQL with IF NOT EXISTS" + assert "op.execute" not in source, "Should not use op.execute for index creation" + + def test_uses_op_drop_index_not_raw_sql(self): + """Test that downgrade uses op.drop_index() instead of raw SQL.""" + module = importlib.import_module("mcpgateway.alembic.versions.i3c4d5e6f7g8_add_observability_performance_indexes") + + source = pyinspect.getsource(module.downgrade) + + # Should use op.drop_index + assert "op.drop_index" in source, "Missing op.drop_index calls" + + # Should NOT use raw SQL with IF EXISTS + assert "DROP INDEX IF EXISTS" not in source, "Should not use raw SQL with IF EXISTS" + assert "op.execute" not in source, "Should not use op.execute for index dropping" + + def test_creates_composite_indexes(self): + """Test that migration creates composite indexes.""" + module = importlib.import_module("mcpgateway.alembic.versions.i3c4d5e6f7g8_add_observability_performance_indexes") + + source = pyinspect.getsource(module.upgrade) + + # Check for multi-column indexes + assert '["status", "start_time"]' in source or "['status', 'start_time']" in source, "Missing composite index on status+start_time" + assert '["trace_id", "start_time"]' in source or "['trace_id', 'start_time']" in source, "Missing composite index on trace_id+start_time" + + def test_downgrade_drops_indexes_in_reverse_order(self): + """Test that downgrade drops indexes (reverse order is good practice).""" + module = importlib.import_module("mcpgateway.alembic.versions.i3c4d5e6f7g8_add_observability_performance_indexes") + + source = pyinspect.getsource(module.downgrade) + + # Count drop_index calls + drop_count = source.count("op.drop_index") + create_source = pyinspect.getsource(module.upgrade) + create_count = create_source.count("op.create_index") + + assert drop_count == create_count, f"Downgrade should drop {create_count} indexes, but drops {drop_count}" + + def test_specifies_table_name_in_drop_index(self): + """Test that op.drop_index includes table_name parameter.""" + module = importlib.import_module("mcpgateway.alembic.versions.i3c4d5e6f7g8_add_observability_performance_indexes") + + source = pyinspect.getsource(module.downgrade) + + # Should specify table_name for cross-database compatibility + assert "table_name=" in source, "op.drop_index should specify table_name parameter" + + +class TestObservabilitySavedQueries: + """Test migration j4d5e6f7g8h9 (add saved queries table).""" + + def test_boolean_uses_sa_false_not_string(self): + """Test that Boolean server_default uses sa.false() not string '0'.""" + module = importlib.import_module("mcpgateway.alembic.versions.j4d5e6f7g8h9_add_observability_saved_queries") + + source = pyinspect.getsource(module.upgrade) + + # Should use sa.false() for Boolean + assert "sa.false()" in source, "Boolean server_default should use sa.false()" + + # Should NOT use string "0" for Boolean + assert 'sa.Boolean(), nullable=False, server_default="0"' not in source, "Should not use string '0' for Boolean server_default" + + def test_integer_uses_sa_text_for_default(self): + """Test that Integer server_default uses sa.text('0').""" + module = importlib.import_module("mcpgateway.alembic.versions.j4d5e6f7g8h9_add_observability_saved_queries") + + source = pyinspect.getsource(module.upgrade) + + # Should use sa.text("0") for Integer + assert 'sa.text("0")' in source, "Integer server_default should use sa.text('0')" + + def test_no_duplicate_user_email_index(self): + """Test that there's only ONE index on user_email column.""" + module = importlib.import_module("mcpgateway.alembic.versions.j4d5e6f7g8h9_add_observability_saved_queries") + + source = pyinspect.getsource(module.upgrade) + + # Count how many times we create an index on user_email + user_email_index_count = 0 + + # Look for index creation lines containing user_email + for line in source.split("\n"): + if "op.create_index" in line and "user_email" in line: + user_email_index_count += 1 + + assert user_email_index_count == 1, f"Expected 1 user_email index, found {user_email_index_count}" + + def test_downgrade_drops_correct_number_of_indexes(self): + """Test that downgrade drops the same number of indexes as upgrade creates.""" + module = importlib.import_module("mcpgateway.alembic.versions.j4d5e6f7g8h9_add_observability_saved_queries") + + upgrade_source = pyinspect.getsource(module.upgrade) + downgrade_source = pyinspect.getsource(module.downgrade) + + create_count = upgrade_source.count("op.create_index") + drop_count = downgrade_source.count("op.drop_index") + + assert drop_count == create_count, f"Downgrade should drop {create_count} indexes, but drops {drop_count}" + + def test_uses_current_timestamp_for_datetime_defaults(self): + """Test that DateTime columns use CURRENT_TIMESTAMP for server defaults.""" + module = importlib.import_module("mcpgateway.alembic.versions.j4d5e6f7g8h9_add_observability_saved_queries") + + source = pyinspect.getsource(module.upgrade) + + # Should use sa.text("CURRENT_TIMESTAMP") for DateTime + assert 'sa.text("CURRENT_TIMESTAMP")' in source, "DateTime columns should use sa.text('CURRENT_TIMESTAMP')" + + +class TestCrossDatabaseCompatibility: + """Test cross-database compatibility concerns.""" + + def test_no_mysql_specific_if_not_exists(self): + """Test that migrations don't use MySQL < 8.0.13 incompatible IF NOT EXISTS.""" + for migration_info in OBSERVABILITY_MIGRATIONS: + module = importlib.import_module(migration_info["module"]) + upgrade_source = pyinspect.getsource(module.upgrade) + downgrade_source = pyinspect.getsource(module.downgrade) + + # Should not use raw SQL with IF NOT EXISTS / IF EXISTS + assert "IF NOT EXISTS" not in upgrade_source, f"{migration_info['module']} uses IF NOT EXISTS (MySQL < 8.0.13 incompatible)" + assert "IF EXISTS" not in downgrade_source, f"{migration_info['module']} uses IF EXISTS (MySQL < 8.0.13 incompatible)" + + def test_uses_sqlalchemy_types_not_raw_sql_types(self): + """Test that migrations use SQLAlchemy types (sa.*) not raw SQL types.""" + for migration_info in OBSERVABILITY_MIGRATIONS: + module = importlib.import_module(migration_info["module"]) + source = pyinspect.getsource(module.upgrade) + + # Should use sa.String, sa.Integer, etc. + if "create_table" in source: + assert "sa.String" in source or "sa.Text" in source or "sa.Integer" in source, f"{migration_info['module']} should use SQLAlchemy types" + + def test_datetime_columns_use_timezone_parameter(self): + """Test that DateTime columns specify timezone parameter.""" + module = importlib.import_module("mcpgateway.alembic.versions.a23a08d61eb0_add_observability_tables") + + source = pyinspect.getsource(module.upgrade) + + # All DateTime columns should specify timezone=True + datetime_matches = re.findall(r"sa\.DateTime\([^)]*\)", source) + + for match in datetime_matches: + assert "timezone=True" in match, f"DateTime column missing timezone parameter: {match}" + + +class TestMigrationChain: + """Test that migrations form a proper chain.""" + + def test_migrations_form_continuous_chain(self): + """Test that down_revision of each migration matches previous revision.""" + # Check that chain is continuous + revisions = {m["revision"]: m["down_revision"] for m in OBSERVABILITY_MIGRATIONS} + + # i3c4d5e6f7g8 should depend on a23a08d61eb0 + assert revisions["i3c4d5e6f7g8"] == "a23a08d61eb0" + + # j4d5e6f7g8h9 should depend on i3c4d5e6f7g8 + assert revisions["j4d5e6f7g8h9"] == "i3c4d5e6f7g8" + + def test_no_circular_dependencies(self): + """Test that there are no circular dependencies in migration chain.""" + revisions = {m["revision"]: m["down_revision"] for m in OBSERVABILITY_MIGRATIONS} + + # Build dependency graph and check for cycles + visited = set() + + for revision in revisions: + path = [] + current = revision + + while current and current not in visited: + if current in path: + pytest.fail(f"Circular dependency detected: {' -> '.join(path + [current])}") + path.append(current) + current = revisions.get(current) + + visited.update(path) + + def test_all_migrations_have_unique_revisions(self): + """Test that all migration revisions are unique.""" + revisions = [m["revision"] for m in OBSERVABILITY_MIGRATIONS] + + assert len(revisions) == len(set(revisions)), "Duplicate revision IDs found" From 9cf6c2de1e8c78c061bd0cde4d13c082dca030a2 Mon Sep 17 00:00:00 2001 From: Mihai Criveti Date: Sat, 8 Nov 2025 21:04:59 +0000 Subject: [PATCH 20/20] Fixes Signed-off-by: Mihai Criveti --- .../plugins/framework/external/mcp/server/runtime.py | 0 mcpgateway/services/tool_service.py | 7 ++++--- 2 files changed, 4 insertions(+), 3 deletions(-) mode change 100755 => 100644 mcpgateway/plugins/framework/external/mcp/server/runtime.py diff --git a/mcpgateway/plugins/framework/external/mcp/server/runtime.py b/mcpgateway/plugins/framework/external/mcp/server/runtime.py old mode 100755 new mode 100644 diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 8501409bd..fb9b1c1a0 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -1198,7 +1198,7 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r global_context.metadata[TOOL_METADATA] = tool_metadata pre_result, context_table = await self._plugin_manager.invoke_hook( ToolHookType.TOOL_PRE_INVOKE, - payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(headers)), + payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)), global_context=global_context, local_contexts=None, violations_as_exceptions=True, @@ -1354,7 +1354,7 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head global_context.metadata[GATEWAY_METADATA] = gateway_metadata pre_result, context_table = await self._plugin_manager.invoke_hook( ToolHookType.TOOL_PRE_INVOKE, - payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(headers)), + payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)), global_context=global_context, local_contexts=None, violations_as_exceptions=True, @@ -1385,7 +1385,8 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head # Plugin hook: tool post-invoke if self._plugin_manager: - post_result, _ = await self._plugin_manager.tool_post_invoke( + post_result, _ = await self._plugin_manager.invoke_hook( + ToolHookType.TOOL_POST_INVOKE, payload=ToolPostInvokePayload(name=name, result=tool_result.model_dump(by_alias=True)), global_context=global_context, local_contexts=context_table,