From c4c64b7f782932183bf5c5fd50f331776e20b4fd Mon Sep 17 00:00:00 2001 From: gayathrivijayakumar Date: Wed, 10 Sep 2025 15:33:32 +0530 Subject: [PATCH 01/10] Support url safety checking for adapters --- src/unstract/sdk/adapters/base.py | 33 ++- .../azure_open_ai/src/azure_open_ai.py | 8 + .../adapters/embedding/ollama/src/ollama.py | 8 + .../adapters/embedding/open_ai/src/open_ai.py | 8 + .../adapters/llm/any_scale/src/anyscale.py | 9 +- .../llm/azure_open_ai/src/azure_open_ai.py | 13 +- .../sdk/adapters/llm/ollama/src/ollama.py | 12 +- .../sdk/adapters/llm/open_ai/src/open_ai.py | 8 + src/unstract/sdk/adapters/url_validator.py | 172 +++++++++++++++ .../adapters/vectordb/milvus/src/milvus.py | 12 + .../milvus/src/static/json_schema.json | 2 +- .../vectordb/postgres/src/postgres.py | 21 ++ .../adapters/vectordb/qdrant/src/qdrant.py | 12 + .../vectordb/weaviate/src/weaviate.py | 9 + src/unstract/sdk/adapters/x2text/helper.py | 8 + .../x2text/llama_parse/src/llama_parse.py | 8 + .../x2text/llm_whisperer/src/llm_whisperer.py | 8 + .../llm_whisperer_v2/src/llm_whisperer_v2.py | 8 + .../src/unstructured_community.py | 8 + .../src/unstructured_enterprise.py | 5 + tests/test_url_validator.py | 208 ++++++++++++++++++ 21 files changed, 572 insertions(+), 8 deletions(-) create mode 100644 src/unstract/sdk/adapters/url_validator.py create mode 100644 tests/test_url_validator.py diff --git a/src/unstract/sdk/adapters/base.py b/src/unstract/sdk/adapters/base.py index 4b4daf98..719b0eab 100644 --- a/src/unstract/sdk/adapters/base.py +++ b/src/unstract/sdk/adapters/base.py @@ -2,6 +2,8 @@ from abc import ABC, abstractmethod from unstract.sdk.adapters.enums import AdapterTypes +from unstract.sdk.adapters.exceptions import AdapterError +from unstract.sdk.adapters.url_validator import URLValidator logger = logging.getLogger(__name__) @@ -32,7 +34,7 @@ def get_icon() -> str: @classmethod def get_json_schema(cls) -> str: - schema_path = getattr(cls, 'SCHEMA_PATH', None) + schema_path = getattr(cls, "SCHEMA_PATH", None) if schema_path is None: raise ValueError(f"SCHEMA_PATH not defined for {cls.__name__}") with open(schema_path) as f: @@ -43,6 +45,35 @@ def get_json_schema(cls) -> str: def get_adapter_type() -> AdapterTypes: return "" + def get_configured_urls(self) -> list[str]: + """Return all URLs that this adapter will connect to. + + This method should return a list of all URLs that the adapter + uses for external connections. These URLs will be validated + for security before allowing connection attempts. + + Returns: + list[str]: List of URLs that will be accessed by this adapter + """ + return [] + + def _validate_urls(self) -> None: + """Validate all configured URLs against security rules.""" + urls = self.get_configured_urls() + + for url in urls: + if not url: # Skip empty/None URLs + continue + + is_valid, error_message = URLValidator.validate_url(url) + if not is_valid: + # Use class name as fallback when self.name isn't set yet + adapter_name = getattr(self, "name", self.__class__.__name__) + logger.error( + f"URL validation failed for adapter '{adapter_name}': {error_message}" + ) + raise AdapterError(f"URL validation failed: {error_message}") + @abstractmethod def test_connection(self) -> bool: """Override to test connection for a adapter. diff --git a/src/unstract/sdk/adapters/embedding/azure_open_ai/src/azure_open_ai.py b/src/unstract/sdk/adapters/embedding/azure_open_ai/src/azure_open_ai.py index 872cb398..5afa0ca7 100644 --- a/src/unstract/sdk/adapters/embedding/azure_open_ai/src/azure_open_ai.py +++ b/src/unstract/sdk/adapters/embedding/azure_open_ai/src/azure_open_ai.py @@ -26,6 +26,9 @@ def __init__(self, settings: dict[str, Any]): super().__init__("AzureOpenAIEmbedding") self.config = settings + # Validate URLs BEFORE any network operations + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -48,6 +51,11 @@ def get_provider() -> str: def get_icon() -> str: return "/icons/adapter-icons/AzureopenAI.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + endpoint = self.config.get("azure_endpoint") + return [endpoint] if endpoint else [] + def get_embedding_instance(self) -> BaseEmbedding: try: embedding_batch_size = EmbeddingHelper.get_embedding_batch_size( diff --git a/src/unstract/sdk/adapters/embedding/ollama/src/ollama.py b/src/unstract/sdk/adapters/embedding/ollama/src/ollama.py index 68b5b8a0..57035b73 100644 --- a/src/unstract/sdk/adapters/embedding/ollama/src/ollama.py +++ b/src/unstract/sdk/adapters/embedding/ollama/src/ollama.py @@ -19,6 +19,9 @@ def __init__(self, settings: dict[str, Any]): super().__init__("Ollama") self.config = settings + # Validate URLs BEFORE any network operations + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -41,6 +44,11 @@ def get_provider() -> str: def get_icon() -> str: return "/icons/adapter-icons/ollama.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + base_url = self.config.get("base_url") + return [base_url] if base_url else [] + def get_embedding_instance(self) -> BaseEmbedding: try: embedding_batch_size = EmbeddingHelper.get_embedding_batch_size( diff --git a/src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py b/src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py index 781e849d..a4a33e5a 100644 --- a/src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py +++ b/src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py @@ -25,6 +25,9 @@ def __init__(self, settings: dict[str, Any]): super().__init__("OpenAI") self.config = settings + # Validate URLs BEFORE any network operations + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -47,6 +50,11 @@ def get_provider() -> str: def get_icon() -> str: return "/icons/adapter-icons/OpenAI.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + api_base = self.config.get("api_base") + return [api_base] if api_base else [] + def get_embedding_instance(self) -> BaseEmbedding: try: timeout = int(self.config.get(Constants.TIMEOUT, Constants.DEFAULT_TIMEOUT)) diff --git a/src/unstract/sdk/adapters/llm/any_scale/src/anyscale.py b/src/unstract/sdk/adapters/llm/any_scale/src/anyscale.py index 3c371ddc..c0acb6e6 100644 --- a/src/unstract/sdk/adapters/llm/any_scale/src/anyscale.py +++ b/src/unstract/sdk/adapters/llm/any_scale/src/anyscale.py @@ -4,7 +4,6 @@ from llama_index.core.constants import DEFAULT_NUM_OUTPUTS from llama_index.core.llms import LLM from llama_index.llms.anyscale import Anyscale - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.llm.constants import LLMKeys from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter @@ -24,6 +23,9 @@ def __init__(self, settings: dict[str, Any]): super().__init__("AnyScale") self.config = settings + # Validate URLs BEFORE any network operations + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -46,6 +48,11 @@ def get_provider() -> str: def get_icon() -> str: return "/icons/adapter-icons/anyscale.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + api_base = self.config.get(Constants.API_BASE) + return [api_base] if api_base else [] + def get_llm_instance(self) -> LLM: try: max_tokens = int(self.config.get(Constants.MAX_TOKENS, DEFAULT_NUM_OUTPUTS)) diff --git a/src/unstract/sdk/adapters/llm/azure_open_ai/src/azure_open_ai.py b/src/unstract/sdk/adapters/llm/azure_open_ai/src/azure_open_ai.py index fe1c123c..c706e49a 100644 --- a/src/unstract/sdk/adapters/llm/azure_open_ai/src/azure_open_ai.py +++ b/src/unstract/sdk/adapters/llm/azure_open_ai/src/azure_open_ai.py @@ -4,7 +4,6 @@ from llama_index.core.llms import LLM from llama_index.llms.azure_openai import AzureOpenAI from llama_index.llms.openai.utils import O1_MODELS - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.llm.constants import LLMKeys from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter @@ -30,6 +29,9 @@ def __init__(self, settings: dict[str, Any]): super().__init__("AzureOpenAI") self.config = settings + # Validate URLs BEFORE any network operations + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -52,6 +54,11 @@ def get_provider() -> str: def get_icon() -> str: return "/icons/adapter-icons/AzureopenAI.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + endpoint = self.config.get("azure_endpoint") + return [endpoint] if endpoint else [] + def get_llm_instance(self) -> LLM: max_retries = int( self.config.get(Constants.MAX_RETRIES, LLMKeys.DEFAULT_MAX_RETRIES) @@ -74,9 +81,7 @@ def get_llm_instance(self) -> LLM: } if enable_reasoning: - llm_kwargs["reasoning_effort"] = self.config.get( - Constants.REASONING_EFFORT - ) + llm_kwargs["reasoning_effort"] = self.config.get(Constants.REASONING_EFFORT) if model not in O1_MODELS: llm_kwargs["max_completion_tokens"] = max_tokens diff --git a/src/unstract/sdk/adapters/llm/ollama/src/ollama.py b/src/unstract/sdk/adapters/llm/ollama/src/ollama.py index 49e6ff13..9ff3c404 100644 --- a/src/unstract/sdk/adapters/llm/ollama/src/ollama.py +++ b/src/unstract/sdk/adapters/llm/ollama/src/ollama.py @@ -6,7 +6,6 @@ from httpx import ConnectError, HTTPStatusError from llama_index.core.llms import LLM from llama_index.llms.ollama import Ollama - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.llm.constants import LLMKeys from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter @@ -29,6 +28,9 @@ def __init__(self, settings: dict[str, Any]): super().__init__("Ollama") self.config = settings + # Validate URLs BEFORE any network operations + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -51,6 +53,11 @@ def get_provider() -> str: def get_icon() -> str: return "/icons/adapter-icons/ollama.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + base_url = self.config.get(Constants.BASE_URL) + return [base_url] if base_url else [] + def get_llm_instance(self) -> LLM: try: llm: LLM = Ollama( @@ -77,6 +84,9 @@ def get_llm_instance(self) -> LLM: raise AdapterError(str(exc)) def test_connection(self) -> bool: + # Validate URLs first + super().test_connection() + try: llm = self.get_llm_instance() if not llm: diff --git a/src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py b/src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py index d1d1b255..95ad917b 100644 --- a/src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py +++ b/src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py @@ -29,6 +29,9 @@ def __init__(self, settings: dict[str, Any]): super().__init__("OpenAI") self.config = settings + # Validate URLs BEFORE any network operations + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -51,6 +54,11 @@ def get_provider() -> str: def get_icon() -> str: return "/icons/adapter-icons/OpenAI.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + api_base = self.config.get("api_base") + return [api_base] if api_base else [] + def get_llm_instance(self) -> LLM: try: max_tokens = self.config.get(Constants.MAX_TOKENS) diff --git a/src/unstract/sdk/adapters/url_validator.py b/src/unstract/sdk/adapters/url_validator.py new file mode 100644 index 00000000..ac4f72d2 --- /dev/null +++ b/src/unstract/sdk/adapters/url_validator.py @@ -0,0 +1,172 @@ +import ipaddress +import logging +import os +import socket +from dataclasses import dataclass +from urllib.parse import urlparse + +logger = logging.getLogger(__name__) + + +@dataclass +class WhitelistEntry: + """Represents a whitelisted endpoint with IP range and optional port.""" + + ip_network: ipaddress.IPv4Network | ipaddress.IPv6Network + port: int | None = None + + +class URLValidator: + """Validates URLs to prevent SSRF attacks by blocking private IP addresses. + + URLs are validated to block private IP addresses unless explicitly + whitelisted via ALLOWED_ADAPTER_PRIVATE_ENDPOINTS. + """ + + ENV_VAR = "ALLOWED_ADAPTER_PRIVATE_ENDPOINTS" + + # Private IP ranges that are blocked by default (RFC 1918 + others) + BLOCKED_PRIVATE_RANGES = [ + "127.0.0.0/8", # Localhost + "10.0.0.0/8", # Class A private + "172.16.0.0/12", # Class B private + "192.168.0.0/16", # Class C private + "169.254.0.0/16", # Link-local + "0.0.0.0/8", # Current network + "224.0.0.0/4", # Multicast + "240.0.0.0/4", # Reserved + # IPv6 ranges + "::1/128", # IPv6 localhost + "fc00::/7", # IPv6 unique local + "fe80::/10", # IPv6 link-local + ] + + @classmethod + def validate_url(cls, url: str) -> tuple[bool, str]: + """Validates a URL against security rules. + + Args: + url: The URL to validate + + Returns: + Tuple of (is_valid, error_message) + """ + try: + parsed = urlparse(url) + + if not parsed.hostname: + return False, f"Invalid URL: No hostname found in '{url}'" + + # Resolve hostname to IP address + try: + host_ip = socket.gethostbyname(parsed.hostname) + except socket.gaierror as e: + return ( + False, + f"DNS resolution failed for '{parsed.hostname}': {str(e)}", + ) + + # Check if IP is private + ip_obj = ipaddress.ip_address(host_ip) + if cls._is_private_ip(ip_obj): + # Private IP - check whitelist + port = parsed.port + if cls._is_whitelisted(ip_obj, port): + logger.info(f"Private IP {host_ip}:{port} allowed by whitelist") + return True, "" + else: + error_msg = ( + f"URL blocked: Private IP {host_ip}" + f"{':' + str(port) if port else ''} not in whitelist. " + f"Contact platform admin for assistance." + ) + return False, error_msg + + # Public IP - allowed by default + return True, "" + + except Exception as e: + logger.error(f"URL validation error for '{url}': {str(e)}") + return False, f"{str(e)}" + + @classmethod + def _is_private_ip(cls, ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool: + """Check if IP address is in private ranges.""" + for range_str in cls.BLOCKED_PRIVATE_RANGES: + try: + network = ipaddress.ip_network(range_str) + if ip in network: + return True + except ValueError: + continue + return False + + @classmethod + def _is_whitelisted( + cls, ip: ipaddress.IPv4Address | ipaddress.IPv6Address, port: int | None + ) -> bool: + """Check if IP:port combination is whitelisted.""" + whitelist = cls._parse_whitelist_config() + + for entry in whitelist: + if ip in entry.ip_network: + # IP matches - check port + if entry.port is None or entry.port == port: + return True + + return False + + @classmethod + def _parse_whitelist_config(cls) -> list[WhitelistEntry]: + """Parse whitelist configuration from environment variable.""" + config = os.getenv(cls.ENV_VAR, "").strip() + if not config: + return [] + + entries = [] + for item in config.split(","): + item = item.strip() + if not item: + continue + + try: + entry = cls._parse_whitelist_entry(item) + if entry: + entries.append(entry) + except Exception as e: + logger.warning(f"Invalid whitelist entry '{item}': {str(e)}") + + return entries + + @classmethod + def _parse_whitelist_entry(cls, entry: str) -> WhitelistEntry | None: + """Parse a single whitelist entry in format 'IP:PORT' or + 'IP/CIDR:PORT'.""" + port = None + ip_part = entry + + # Check if entry has port specification + if ":" in entry: + parts = entry.rsplit(":", 1) + if len(parts) == 2 and parts[1].isdigit(): + ip_part = parts[0] + port = int(parts[1]) + + # Parse IP or CIDR + try: + if "/" in ip_part: + # CIDR notation + network = ipaddress.ip_network(ip_part, strict=False) + else: + # Single IP - convert to /32 or /128 network + ip = ipaddress.ip_address(ip_part) + if isinstance(ip, ipaddress.IPv4Address): + network = ipaddress.IPv4Network(f"{ip}/32") + else: + network = ipaddress.IPv6Network(f"{ip}/128") + + return WhitelistEntry(ip_network=network, port=port) + + except ValueError as e: + logger.warning(f"Invalid IP/CIDR in whitelist entry '{ip_part}': {str(e)}") + return None diff --git a/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py b/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py index 12c1fbc7..930efa3b 100644 --- a/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py +++ b/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py @@ -21,6 +21,10 @@ def __init__(self, settings: dict[str, Any]): self._config = settings self._client: MilvusClient | None = None self._collection_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME + + # Validate URLs BEFORE any network operations + self._validate_urls() + self._vector_db_instance = self._get_vector_db_instance() super().__init__("Milvus", self._vector_db_instance) @@ -42,6 +46,11 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/Milvus.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + uri = self._config.get(Constants.URI) + return [uri] if uri else [] + def get_vector_db_instance(self) -> VectorStore: return self._vector_db_instance @@ -68,6 +77,9 @@ def _get_vector_db_instance(self) -> VectorStore: raise AdapterError(str(e)) def test_connection(self) -> bool: + # Validate URLs first + super().test_connection() + vector_db = self.get_vector_db_instance() test_result: bool = VectorDBHelper.test_vector_db_instance(vector_store=vector_db) # Delete the collection that was created for testing diff --git a/src/unstract/sdk/adapters/vectordb/milvus/src/static/json_schema.json b/src/unstract/sdk/adapters/vectordb/milvus/src/static/json_schema.json index 22965dbb..9c493a59 100644 --- a/src/unstract/sdk/adapters/vectordb/milvus/src/static/json_schema.json +++ b/src/unstract/sdk/adapters/vectordb/milvus/src/static/json_schema.json @@ -16,7 +16,7 @@ "type": "string", "title": "URI", "format": "uri", - "default": "localhost:19530", + "default": "http://localhost:19530", "description": "Provide the URI of the Milvus server. Example: `https://.api.gcp-us-west1.zillizcloud.com`" }, "token": { diff --git a/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py b/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py index 36676129..27b28e88 100644 --- a/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py +++ b/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py @@ -28,6 +28,10 @@ def __init__(self, settings: dict[str, Any]): self._client: connection | None = None self._collection_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME self._schema_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME + + # Validate URLs BEFORE any network operations + self._validate_urls() + self._vector_db_instance = self._get_vector_db_instance() super().__init__("Postgres", self._vector_db_instance) @@ -95,6 +99,9 @@ def _get_vector_db_instance(self) -> BasePydanticVectorStore: raise AdapterError(str(e)) def test_connection(self) -> bool: + # Validate URLs before attempting connection + super().test_connection() + vector_db = self.get_vector_db_instance() test_result: bool = VectorDBHelper.test_vector_db_instance(vector_store=vector_db) @@ -108,6 +115,20 @@ def test_connection(self) -> bool: return test_result + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + host = self._config.get(Constants.HOST) + port = self._config.get(Constants.PORT) + + if host: + # Construct the database URL for validation + if port: + url = f"postgresql://{host}:{port}" + else: + url = f"postgresql://{host}" + return [url] + return [] + def close(self, **kwargs: Any) -> None: if self._client: self._client.close() diff --git a/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py b/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py index 767a387d..d1a40327 100644 --- a/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py +++ b/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py @@ -24,6 +24,10 @@ def __init__(self, settings: dict[str, Any]): self._config = settings self._client: QdrantClient | None = None self._collection_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME + + # Validate URLs BEFORE any network operations + self._validate_urls() + self._vector_db_instance = self._get_vector_db_instance() super().__init__("Qdrant", self._vector_db_instance) @@ -45,6 +49,11 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/qdrant.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + url = self._config.get(Constants.URL) + return [url] if url else [] + def get_vector_db_instance(self) -> BasePydanticVectorStore: return self._vector_db_instance @@ -71,6 +80,9 @@ def _get_vector_db_instance(self) -> BasePydanticVectorStore: raise self.parse_vector_db_err(e) from e def test_connection(self) -> bool: + # Validate URLs first + super().test_connection() + try: vector_db = self.get_vector_db_instance() test_result: bool = VectorDBHelper.test_vector_db_instance( diff --git a/src/unstract/sdk/adapters/vectordb/weaviate/src/weaviate.py b/src/unstract/sdk/adapters/vectordb/weaviate/src/weaviate.py index 8e2c12aa..34bad6b3 100644 --- a/src/unstract/sdk/adapters/vectordb/weaviate/src/weaviate.py +++ b/src/unstract/sdk/adapters/vectordb/weaviate/src/weaviate.py @@ -25,6 +25,10 @@ def __init__(self, settings: dict[str, Any]): self._config = settings self._client: weaviate.Client | None = None self._collection_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME + + # Validate URLs BEFORE any network operations + self._validate_urls() + self._vector_db_instance = self._get_vector_db_instance() super().__init__("Weaviate", self._vector_db_instance) @@ -46,6 +50,11 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/Weaviate.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + url = self._config.get(Constants.URL) + return [url] if url else [] + def get_vector_db_instance(self) -> BasePydanticVectorStore: return self._vector_db_instance diff --git a/src/unstract/sdk/adapters/x2text/helper.py b/src/unstract/sdk/adapters/x2text/helper.py index 3a94872f..ad225951 100644 --- a/src/unstract/sdk/adapters/x2text/helper.py +++ b/src/unstract/sdk/adapters/x2text/helper.py @@ -5,6 +5,7 @@ from requests import Response from requests.exceptions import ConnectionError, HTTPError, Timeout from unstract.sdk.adapters.exceptions import AdapterError +from unstract.sdk.adapters.url_validator import URLValidator from unstract.sdk.adapters.utils import AdapterUtils from unstract.sdk.adapters.x2text.constants import X2TextConstants from unstract.sdk.constants import MimeType @@ -101,6 +102,13 @@ def make_request( ) -> Response: unstructured_url = unstructured_adapter_config.get(UnstructuredHelper.URL) + # Validate the unstructured URL for security + if unstructured_url: + is_valid, error_message = URLValidator.validate_url(unstructured_url) + if not is_valid: + logger.error(f"Unstructured URL validation failed: {error_message}") + raise AdapterError(f"URL validation failed: {error_message}") + x2text_service_url = unstructured_adapter_config.get(X2TextConstants.X2TEXT_HOST) x2text_service_port = unstructured_adapter_config.get(X2TextConstants.X2TEXT_PORT) platform_service_api_key = unstructured_adapter_config.get( diff --git a/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py b/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py index 237a5078..d3f7347b 100644 --- a/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py +++ b/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py @@ -19,6 +19,9 @@ def __init__(self, settings: dict[str, Any]): super().__init__("LlamaParse") self.config = settings + # Validate URLs BEFORE any network operations + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -37,6 +40,11 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/llama-parse.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + url = self.config.get("url") + return [url] if url else [] + def _call_parser( self, input_file_path: str, diff --git a/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py b/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py index 71ef4502..6a573614 100644 --- a/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py +++ b/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py @@ -55,6 +55,11 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/LLMWhisperer.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + url = self.config.get(WhispererConfig.URL) + return [url] if url else [] + def _get_request_headers(self) -> dict[str, Any]: """Obtains the request headers to authenticate with LLMWhisperer. @@ -200,6 +205,9 @@ def _get_whisper_params(self, enable_highlight: bool = False) -> dict[str, Any]: return params def test_connection(self) -> bool: + # Validate URLs first + super().test_connection() + self._make_request( request_method=HTTPMethod.GET, request_endpoint=WhispererEndpoint.TEST_CONNECTION, diff --git a/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py b/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py index 166b9f92..9bec5846 100644 --- a/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py +++ b/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py @@ -24,6 +24,9 @@ def __init__(self, settings: dict[str, Any]): super().__init__("LLMWhispererV2") self.config = settings + # Validate URLs BEFORE any network operations + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -42,6 +45,11 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/LLMWhispererV2.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + url = self.config.get("url") + return [url] if url else [] + def test_connection(self) -> bool: LLMWhispererHelper.test_connection_request( config=self.config, diff --git a/src/unstract/sdk/adapters/x2text/unstructured_community/src/unstructured_community.py b/src/unstract/sdk/adapters/x2text/unstructured_community/src/unstructured_community.py index b205ce0c..7e9365aa 100644 --- a/src/unstract/sdk/adapters/x2text/unstructured_community/src/unstructured_community.py +++ b/src/unstract/sdk/adapters/x2text/unstructured_community/src/unstructured_community.py @@ -15,6 +15,9 @@ def __init__(self, settings: dict[str, Any]): super().__init__("UnstructuredIOCommunity") self.config = settings + # Validate URLs BEFORE any network operations + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -33,6 +36,11 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/UnstructuredIO.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + url = self.config.get("url") + return [url] if url else [] + def process( self, input_file_path: str, diff --git a/src/unstract/sdk/adapters/x2text/unstructured_enterprise/src/unstructured_enterprise.py b/src/unstract/sdk/adapters/x2text/unstructured_enterprise/src/unstructured_enterprise.py index 908be6da..a713f409 100644 --- a/src/unstract/sdk/adapters/x2text/unstructured_enterprise/src/unstructured_enterprise.py +++ b/src/unstract/sdk/adapters/x2text/unstructured_enterprise/src/unstructured_enterprise.py @@ -33,6 +33,11 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/UnstructuredIO.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + url = self.config.get("url") + return [url] if url else [] + def process( self, input_file_path: str, diff --git a/tests/test_url_validator.py b/tests/test_url_validator.py new file mode 100644 index 00000000..a5ac5aa1 --- /dev/null +++ b/tests/test_url_validator.py @@ -0,0 +1,208 @@ +import os +import unittest +from unittest.mock import patch + +from unstract.sdk.adapters.url_validator import URLValidator + + +class TestURLValidator(unittest.TestCase): + """Test cases for URL validation functionality.""" + + def setUp(self): + """Set up test environment.""" + # Clear any existing environment variables + if URLValidator.ENV_VAR in os.environ: + del os.environ[URLValidator.ENV_VAR] + + def tearDown(self): + """Clean up after tests.""" + # Clear environment variables + if URLValidator.ENV_VAR in os.environ: + del os.environ[URLValidator.ENV_VAR] + + def test_public_urls_allowed(self): + """Test that public URLs are allowed by default.""" + test_cases = [ + "https://api.openai.com/v1/chat/completions", + "https://google.com", + "http://example.com", + "https://1.1.1.1:8080", # Public IP with port + ] + + for url in test_cases: + with self.subTest(url=url): + is_valid, error = URLValidator.validate_url(url) + self.assertTrue( + is_valid, f"Public URL should be valid: {url}, Error: {error}" + ) + + @patch("socket.gethostbyname") + def test_private_ips_blocked_by_default(self, mock_gethostbyname): + """Test that private IPs are blocked when not whitelisted.""" + test_cases = [ + ("https://192.168.1.100", "192.168.1.100"), # Private class C + ("https://10.0.0.5:8080", "10.0.0.5"), # Private class A with port + ("https://172.16.5.10", "172.16.5.10"), # Private class B + ("https://127.0.0.1", "127.0.0.1"), # Localhost + ("https://169.254.169.254", "169.254.169.254"), # Link-local (AWS metadata) + ] + + for url, ip in test_cases: + with self.subTest(url=url): + mock_gethostbyname.return_value = ip + is_valid, error = URLValidator.validate_url(url) + self.assertFalse(is_valid, f"Private IP should be blocked: {url}") + self.assertIn("not in", error) + self.assertIn("whitelist", error) + + @patch("socket.gethostbyname") + def test_whitelisted_private_ips_allowed(self, mock_gethostbyname): + """Test that whitelisted private IPs are allowed.""" + # Set whitelist environment variable + os.environ[URLValidator.ENV_VAR] = "192.168.1.100:8080,10.0.0.0/8" + + test_cases = [ + ("https://192.168.1.100:8080", "192.168.1.100"), # Exact IP:port match + ("https://10.0.0.5:9200", "10.0.0.5"), # CIDR range match + ("https://10.255.255.255", "10.255.255.255"), # CIDR range edge + ] + + for url, ip in test_cases: + with self.subTest(url=url): + mock_gethostbyname.return_value = ip + is_valid, error = URLValidator.validate_url(url) + self.assertTrue( + is_valid, f"Whitelisted IP should be allowed: {url}, Error: {error}" + ) + + @patch("socket.gethostbyname") + def test_port_specific_whitelist(self, mock_gethostbyname): + """Test port-specific whitelisting.""" + os.environ[URLValidator.ENV_VAR] = "192.168.1.100:8080" + mock_gethostbyname.return_value = "192.168.1.100" + + # Port match - should be allowed + is_valid, error = URLValidator.validate_url("https://192.168.1.100:8080") + self.assertTrue(is_valid, "Matching port should be allowed") + + # Port mismatch - should be blocked + is_valid, error = URLValidator.validate_url("https://192.168.1.100:9000") + self.assertFalse(is_valid, "Non-matching port should be blocked") + + @patch("socket.gethostbyname") + def test_cidr_range_matching(self, mock_gethostbyname): + """Test CIDR range matching in whitelist.""" + os.environ[URLValidator.ENV_VAR] = "192.168.1.0/24:8080" + + test_cases = [ + ("192.168.1.1", True), # In range + ("192.168.1.255", True), # In range (edge) + ("192.168.2.1", False), # Out of range + ("192.168.0.255", False), # Out of range + ] + + for ip, should_be_valid in test_cases: + with self.subTest(ip=ip): + mock_gethostbyname.return_value = ip + is_valid, error = URLValidator.validate_url(f"https://{ip}:8080") + self.assertEqual( + is_valid, should_be_valid, f"CIDR matching failed for {ip}: {error}" + ) + + def test_whitelist_parsing(self): + """Test whitelist configuration parsing.""" + # Test various whitelist formats + os.environ[URLValidator.ENV_VAR] = ( + "192.168.1.100:8080,10.0.0.0/8,172.16.5.100:3000" + ) + + entries = URLValidator._parse_whitelist_config() + + self.assertEqual(len(entries), 3) + + # Check first entry (single IP with port) + self.assertEqual(str(entries[0].ip_network), "192.168.1.100/32") + self.assertEqual(entries[0].port, 8080) + + # Check second entry (CIDR without port) + self.assertEqual(str(entries[1].ip_network), "10.0.0.0/8") + self.assertIsNone(entries[1].port) + + # Check third entry (single IP with port) + self.assertEqual(str(entries[2].ip_network), "172.16.5.100/32") + self.assertEqual(entries[2].port, 3000) + + def test_invalid_whitelist_entries_ignored(self): + """Test that invalid whitelist entries are ignored gracefully.""" + os.environ[URLValidator.ENV_VAR] = ( + "192.168.1.100:8080,invalid-ip,10.0.0.0/8,bad-cidr/35" + ) + + entries = URLValidator._parse_whitelist_config() + + # Only valid entries should be parsed + self.assertEqual(len(entries), 2) + self.assertEqual(str(entries[0].ip_network), "192.168.1.100/32") + self.assertEqual(str(entries[1].ip_network), "10.0.0.0/8") + + def test_empty_whitelist_config(self): + """Test behavior with empty whitelist configuration.""" + os.environ[URLValidator.ENV_VAR] = "" + + entries = URLValidator._parse_whitelist_config() + self.assertEqual(len(entries), 0) + + @patch("socket.gethostbyname") + def test_dns_resolution_failure(self, mock_gethostbyname): + """Test handling of DNS resolution failures.""" + mock_gethostbyname.side_effect = Exception("DNS resolution failed") + + is_valid, error = URLValidator.validate_url("https://nonexistent.example.com") + self.assertFalse(is_valid) + self.assertIn("DNS resolution failed", error) + + def test_invalid_url_handling(self): + """Test handling of invalid URLs.""" + invalid_urls = [ + "not-a-url", + "https://", # No hostname + "", # Empty URL + ] + + for url in invalid_urls: + with self.subTest(url=url): + is_valid, error = URLValidator.validate_url(url) + self.assertFalse(is_valid) + self.assertTrue(len(error) > 0) + + @patch("socket.gethostbyname") + def test_localhost_blocked_by_default(self, mock_gethostbyname): + """Test that localhost is blocked when not explicitly whitelisted.""" + # No whitelist configured - localhost should be blocked + + localhost_ips = ["127.0.0.1", "127.0.0.2", "127.255.255.255"] + + for ip in localhost_ips: + with self.subTest(ip=ip): + mock_gethostbyname.return_value = ip + is_valid, error = URLValidator.validate_url(f"https://{ip}") + self.assertFalse( + is_valid, f"Localhost IP should be blocked by default: {ip}" + ) + self.assertIn("not in", error) + + @patch("socket.gethostbyname") + def test_metadata_service_blocked(self, mock_gethostbyname): + """Test that cloud metadata services are blocked.""" + # AWS/Azure metadata service + mock_gethostbyname.return_value = "169.254.169.254" + + is_valid, error = URLValidator.validate_url( + "https://169.254.169.254/latest/meta-data" + ) + self.assertFalse(is_valid) + self.assertIn("not in", error) + + +if __name__ == "__main__": + unittest.main() From 9eb31dacc70efebda0ef072f46af20c9c534a56c Mon Sep 17 00:00:00 2001 From: Gayathri <142381512+gaya3-zipstack@users.noreply.github.com> Date: Thu, 11 Sep 2025 10:44:23 +0530 Subject: [PATCH 02/10] Update src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Gayathri <142381512+gaya3-zipstack@users.noreply.github.com> --- .../adapters/x2text/llama_parse/src/llama_parse.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py b/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py index d3f7347b..cee8d857 100644 --- a/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py +++ b/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py @@ -40,11 +40,13 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/llama-parse.png" - def get_configured_urls(self) -> list[str]: - """Return all URLs this adapter will connect to.""" - url = self.config.get("url") - return [url] if url else [] - + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" +- url = self.config.get("url") + base_url = self.config.get(LlamaParseConfig.BASE_URL) + if isinstance(base_url, str): + base_url = base_url.strip() + return [base_url] if base_url else [] def _call_parser( self, input_file_path: str, From 6c53c7c6ed9bfe85d0000d2240f3ce712992607d Mon Sep 17 00:00:00 2001 From: gayathrivijayakumar Date: Thu, 11 Sep 2025 10:50:52 +0530 Subject: [PATCH 03/10] Remove redundant base class test_connection --- .../sdk/adapters/llm/ollama/src/ollama.py | 6 +-- .../adapters/vectordb/milvus/src/milvus.py | 6 +-- .../vectordb/postgres/src/postgres.py | 10 +++-- .../adapters/vectordb/qdrant/src/qdrant.py | 6 +-- .../x2text/llm_whisperer/src/llm_whisperer.py | 41 +++++++++++++------ 5 files changed, 44 insertions(+), 25 deletions(-) diff --git a/src/unstract/sdk/adapters/llm/ollama/src/ollama.py b/src/unstract/sdk/adapters/llm/ollama/src/ollama.py index 9ff3c404..fe8ce124 100644 --- a/src/unstract/sdk/adapters/llm/ollama/src/ollama.py +++ b/src/unstract/sdk/adapters/llm/ollama/src/ollama.py @@ -67,7 +67,9 @@ def get_llm_instance(self) -> LLM: self.config.get(Constants.TIMEOUT, LLMKeys.DEFAULT_TIMEOUT) ), json_mode=False, - context_window=int(self.config.get(Constants.CONTEXT_WINDOW, 3900)), + context_window=int( + self.config.get(Constants.CONTEXT_WINDOW, 3900) + ), temperature=0.01, ) return llm @@ -84,8 +86,6 @@ def get_llm_instance(self) -> LLM: raise AdapterError(str(exc)) def test_connection(self) -> bool: - # Validate URLs first - super().test_connection() try: llm = self.get_llm_instance() diff --git a/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py b/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py index 930efa3b..e3f02564 100644 --- a/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py +++ b/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py @@ -77,11 +77,11 @@ def _get_vector_db_instance(self) -> VectorStore: raise AdapterError(str(e)) def test_connection(self) -> bool: - # Validate URLs first - super().test_connection() vector_db = self.get_vector_db_instance() - test_result: bool = VectorDBHelper.test_vector_db_instance(vector_store=vector_db) + test_result: bool = VectorDBHelper.test_vector_db_instance( + vector_store=vector_db + ) # Delete the collection that was created for testing if self._client is not None: self._client.drop_collection(self._collection_name) diff --git a/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py b/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py index 27b28e88..9722ce59 100644 --- a/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py +++ b/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py @@ -58,7 +58,9 @@ def get_vector_db_instance(self) -> BasePydanticVectorStore: def _get_vector_db_instance(self) -> BasePydanticVectorStore: try: - encoded_password = quote_plus(str(self._config.get(Constants.PASSWORD))) + encoded_password = quote_plus( + str(self._config.get(Constants.PASSWORD)) + ) dimension = self._config.get( VectorDbConstants.EMBEDDING_DIMENSION, VectorDbConstants.DEFAULT_EMBEDDING_SIZE, @@ -99,11 +101,11 @@ def _get_vector_db_instance(self) -> BasePydanticVectorStore: raise AdapterError(str(e)) def test_connection(self) -> bool: - # Validate URLs before attempting connection - super().test_connection() vector_db = self.get_vector_db_instance() - test_result: bool = VectorDBHelper.test_vector_db_instance(vector_store=vector_db) + test_result: bool = VectorDBHelper.test_vector_db_instance( + vector_store=vector_db + ) # Delete the collection that was created for testing if self._client is not None: diff --git a/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py b/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py index d1a40327..042a2678 100644 --- a/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py +++ b/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py @@ -80,8 +80,6 @@ def _get_vector_db_instance(self) -> BasePydanticVectorStore: raise self.parse_vector_db_err(e) from e def test_connection(self) -> bool: - # Validate URLs first - super().test_connection() try: vector_db = self.get_vector_db_instance() @@ -118,4 +116,6 @@ def parse_vector_db_err(e: Exception) -> VectorDBError: status_code = 503 elif "timeout" in str(e): status_code = 504 - return VectorDBError(message=str(e), actual_err=e, status_code=status_code) + return VectorDBError( + message=str(e), actual_err=e, status_code=status_code + ) diff --git a/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py b/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py index 6a573614..7bb147b5 100644 --- a/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py +++ b/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py @@ -68,7 +68,9 @@ def _get_request_headers(self) -> dict[str, Any]: """ return { "accept": MimeType.JSON, - WhispererHeader.UNSTRACT_KEY: self.config.get(WhispererConfig.UNSTRACT_KEY), + WhispererHeader.UNSTRACT_KEY: self.config.get( + WhispererConfig.UNSTRACT_KEY + ), } def _make_request( @@ -114,7 +116,9 @@ def _make_request( data=data, ) else: - raise ExtractorError(f"Unsupported request method: {request_method}") + raise ExtractorError( + f"Unsupported request method: {request_method}" + ) response.raise_for_status() except ConnectionError as e: logger.error(f"Adapter error: {e}") @@ -134,7 +138,9 @@ def _make_request( raise ExtractorError(msg) return response - def _get_whisper_params(self, enable_highlight: bool = False) -> dict[str, Any]: + def _get_whisper_params( + self, enable_highlight: bool = False + ) -> dict[str, Any]: """Gets query params meant for /whisper endpoint. The params is filled based on the configuration passed. @@ -200,13 +206,13 @@ def _get_whisper_params(self, enable_highlight: bool = False) -> dict[str, Any]: if enable_highlight: params.update( - {WhispererConfig.STORE_METADATA_FOR_HIGHLIGHTING: enable_highlight} + { + WhispererConfig.STORE_METADATA_FOR_HIGHLIGHTING: enable_highlight + } ) return params def test_connection(self) -> bool: - # Validate URLs first - super().test_connection() self._make_request( request_method=HTTPMethod.GET, @@ -251,7 +257,9 @@ def _check_status_until_ready( ) if status_response.status_code == 200: status_data = status_response.json() - status = status_data.get(WhisperStatus.STATUS, WhisperStatus.UNKNOWN) + status = status_data.get( + WhisperStatus.STATUS, WhisperStatus.UNKNOWN + ) logger.info(f"Whisper status for {whisper_hash}: {status}") if status in [WhisperStatus.PROCESSED, WhisperStatus.DELIVERED]: break @@ -264,7 +272,8 @@ def _check_status_until_ready( # Exit with error if max poll count is reached if request_count >= MAX_POLLS: raise ExtractorError( - "Unable to extract text after attempting" f" {request_count} times" + "Unable to extract text after attempting" + f" {request_count} times" ) time.sleep(POLL_INTERVAL) @@ -348,7 +357,9 @@ def _extract_text_from_response( raise ExtractorError("Couldn't extract text from file") if output_file_path: self._write_output_to_file( - output_json=output_json, output_file_path=Path(output_file_path), fs=fs + output_json=output_json, + output_file_path=Path(output_file_path), + fs=fs, ) return output_json.get("text", "") @@ -389,9 +400,13 @@ def _write_output_to_file( fs.mkdir(str(metadata_dir), create_parents=True) # Remove the "text" key from the metadata metadata = { - key: value for key, value in output_json.items() if key != "text" + key: value + for key, value in output_json.items() + if key != "text" } - metadata_json = json.dumps(metadata, ensure_ascii=False, indent=4) + metadata_json = json.dumps( + metadata, ensure_ascii=False, indent=4 + ) logger.info(f"Writing metadata to {metadata_file_path}") fs.write( @@ -401,7 +416,9 @@ def _write_output_to_file( data=metadata_json, ) except Exception as e: - logger.error(f"Error while writing metadata to {metadata_file_path}: {e}") + logger.error( + f"Error while writing metadata to {metadata_file_path}: {e}" + ) except Exception as e: logger.error(f"Error while writing {output_file_path}: {e}") From e370097af7995f17578639074573c6f8f1329cbc Mon Sep 17 00:00:00 2001 From: gayathrivijayakumar Date: Thu, 11 Sep 2025 11:02:18 +0530 Subject: [PATCH 04/10] Fix indentation --- .../x2text/llama_parse/src/llama_parse.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py b/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py index cee8d857..b3958a2b 100644 --- a/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py +++ b/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py @@ -5,9 +5,12 @@ from httpx import ConnectError from llama_parse import LlamaParse + from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.x2text.dto import TextExtractionResult -from unstract.sdk.adapters.x2text.llama_parse.src.constants import LlamaParseConfig +from unstract.sdk.adapters.x2text.llama_parse.src.constants import ( + LlamaParseConfig, +) from unstract.sdk.adapters.x2text.x2text_adapter import X2TextAdapter from unstract.sdk.file_storage import FileStorage, FileStorageProvider @@ -40,13 +43,13 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/llama-parse.png" - def get_configured_urls(self) -> list[str]: - """Return all URLs this adapter will connect to.""" -- url = self.config.get("url") + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" base_url = self.config.get(LlamaParseConfig.BASE_URL) if isinstance(base_url, str): base_url = base_url.strip() return [base_url] if base_url else [] + def _call_parser( self, input_file_path: str, @@ -91,7 +94,8 @@ def _call_parser( except ConnectError as connec_err: logger.error(f"Invalid Base URL given. : {connec_err}") raise AdapterError( - "Unable to connect to llama-parse`s service, " "please check the Base URL" + "Unable to connect to llama-parse`s service, " + "please check the Base URL" ) except Exception as exe: logger.error( @@ -109,7 +113,9 @@ def process( fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), **kwargs: dict[Any, Any], ) -> TextExtractionResult: - response_text = self._call_parser(input_file_path=input_file_path, fs=fs) + response_text = self._call_parser( + input_file_path=input_file_path, fs=fs + ) if output_file_path: fs.write( path=output_file_path, From 81c5f9b9732f5e481ccbe9d3cd079ba68b86a6cc Mon Sep 17 00:00:00 2001 From: gayathrivijayakumar Date: Mon, 15 Sep 2025 10:06:15 +0530 Subject: [PATCH 05/10] Roll SDK version --- src/unstract/sdk/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/unstract/sdk/__init__.py b/src/unstract/sdk/__init__.py index 3f9c0f15..ffcb125c 100644 --- a/src/unstract/sdk/__init__.py +++ b/src/unstract/sdk/__init__.py @@ -1,4 +1,4 @@ -__version__ = "v0.77.1" +__version__ = "v0.78.0" def get_sdk_version() -> str: From 6c3a630c5db2a3d1b94cc18ed283ef405d4ba1c8 Mon Sep 17 00:00:00 2001 From: gayathrivijayakumar Date: Mon, 15 Sep 2025 14:36:02 +0530 Subject: [PATCH 06/10] Add abstractmethod decorator --- src/unstract/sdk/adapters/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/unstract/sdk/adapters/base.py b/src/unstract/sdk/adapters/base.py index 719b0eab..b1827bf4 100644 --- a/src/unstract/sdk/adapters/base.py +++ b/src/unstract/sdk/adapters/base.py @@ -45,6 +45,7 @@ def get_json_schema(cls) -> str: def get_adapter_type() -> AdapterTypes: return "" + @abstractmethod def get_configured_urls(self) -> list[str]: """Return all URLs that this adapter will connect to. From afb0ab52906cbb91b54ac99dd1bea53bf7bb0af9 Mon Sep 17 00:00:00 2001 From: gayathrivijayakumar Date: Mon, 15 Sep 2025 21:57:43 +0530 Subject: [PATCH 07/10] Change env name. Pass validate_urls=true explicitly --- .../embedding/azure_open_ai/src/azure_open_ai.py | 5 +++-- .../sdk/adapters/embedding/ollama/src/ollama.py | 5 +++-- .../sdk/adapters/embedding/open_ai/src/open_ai.py | 5 +++-- .../sdk/adapters/llm/any_scale/src/anyscale.py | 5 +++-- .../llm/azure_open_ai/src/azure_open_ai.py | 5 +++-- .../sdk/adapters/llm/ollama/src/ollama.py | 11 ++++------- .../sdk/adapters/llm/open_ai/src/open_ai.py | 6 +++--- src/unstract/sdk/adapters/url_validator.py | 4 ++-- .../sdk/adapters/vectordb/milvus/src/milvus.py | 11 ++++------- .../adapters/vectordb/postgres/src/postgres.py | 15 +++++---------- .../sdk/adapters/vectordb/qdrant/src/qdrant.py | 11 ++++------- .../adapters/vectordb/weaviate/src/weaviate.py | 6 +++--- .../x2text/llama_parse/src/llama_parse.py | 13 +++++-------- .../llm_whisperer_v2/src/llm_whisperer_v2.py | 5 +++-- .../src/unstructured_community.py | 5 +++-- 15 files changed, 51 insertions(+), 61 deletions(-) diff --git a/src/unstract/sdk/adapters/embedding/azure_open_ai/src/azure_open_ai.py b/src/unstract/sdk/adapters/embedding/azure_open_ai/src/azure_open_ai.py index 5afa0ca7..37631a34 100644 --- a/src/unstract/sdk/adapters/embedding/azure_open_ai/src/azure_open_ai.py +++ b/src/unstract/sdk/adapters/embedding/azure_open_ai/src/azure_open_ai.py @@ -22,12 +22,13 @@ class Constants: class AzureOpenAI(EmbeddingAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("AzureOpenAIEmbedding") self.config = settings # Validate URLs BEFORE any network operations - self._validate_urls() + if validate_urls: + self._validate_urls() SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" diff --git a/src/unstract/sdk/adapters/embedding/ollama/src/ollama.py b/src/unstract/sdk/adapters/embedding/ollama/src/ollama.py index 57035b73..f0e921ec 100644 --- a/src/unstract/sdk/adapters/embedding/ollama/src/ollama.py +++ b/src/unstract/sdk/adapters/embedding/ollama/src/ollama.py @@ -15,12 +15,13 @@ class Constants: class Ollama(EmbeddingAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("Ollama") self.config = settings # Validate URLs BEFORE any network operations - self._validate_urls() + if validate_urls: + self._validate_urls() SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" diff --git a/src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py b/src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py index a4a33e5a..c51790d0 100644 --- a/src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py +++ b/src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py @@ -21,12 +21,13 @@ class Constants: class OpenAI(EmbeddingAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("OpenAI") self.config = settings # Validate URLs BEFORE any network operations - self._validate_urls() + if validate_urls: + self._validate_urls() SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" diff --git a/src/unstract/sdk/adapters/llm/any_scale/src/anyscale.py b/src/unstract/sdk/adapters/llm/any_scale/src/anyscale.py index c0acb6e6..5d53560e 100644 --- a/src/unstract/sdk/adapters/llm/any_scale/src/anyscale.py +++ b/src/unstract/sdk/adapters/llm/any_scale/src/anyscale.py @@ -19,12 +19,13 @@ class Constants: class AnyScaleLLM(LLMAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("AnyScale") self.config = settings # Validate URLs BEFORE any network operations - self._validate_urls() + if validate_urls: + self._validate_urls() SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" diff --git a/src/unstract/sdk/adapters/llm/azure_open_ai/src/azure_open_ai.py b/src/unstract/sdk/adapters/llm/azure_open_ai/src/azure_open_ai.py index c706e49a..22a62f7c 100644 --- a/src/unstract/sdk/adapters/llm/azure_open_ai/src/azure_open_ai.py +++ b/src/unstract/sdk/adapters/llm/azure_open_ai/src/azure_open_ai.py @@ -25,12 +25,13 @@ class Constants: class AzureOpenAILLM(LLMAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("AzureOpenAI") self.config = settings # Validate URLs BEFORE any network operations - self._validate_urls() + if validate_urls: + self._validate_urls() SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" diff --git a/src/unstract/sdk/adapters/llm/ollama/src/ollama.py b/src/unstract/sdk/adapters/llm/ollama/src/ollama.py index fe8ce124..0c35b67b 100644 --- a/src/unstract/sdk/adapters/llm/ollama/src/ollama.py +++ b/src/unstract/sdk/adapters/llm/ollama/src/ollama.py @@ -24,12 +24,12 @@ class Constants: class OllamaLLM(LLMAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("Ollama") self.config = settings - # Validate URLs BEFORE any network operations - self._validate_urls() + if validate_urls: + self._validate_urls() SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @@ -67,9 +67,7 @@ def get_llm_instance(self) -> LLM: self.config.get(Constants.TIMEOUT, LLMKeys.DEFAULT_TIMEOUT) ), json_mode=False, - context_window=int( - self.config.get(Constants.CONTEXT_WINDOW, 3900) - ), + context_window=int(self.config.get(Constants.CONTEXT_WINDOW, 3900)), temperature=0.01, ) return llm @@ -86,7 +84,6 @@ def get_llm_instance(self) -> LLM: raise AdapterError(str(exc)) def test_connection(self) -> bool: - try: llm = self.get_llm_instance() if not llm: diff --git a/src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py b/src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py index 95ad917b..0baa4e18 100644 --- a/src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py +++ b/src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py @@ -25,12 +25,12 @@ class Constants: class OpenAILLM(LLMAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("OpenAI") self.config = settings - # Validate URLs BEFORE any network operations - self._validate_urls() + if validate_urls: + self._validate_urls() SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" diff --git a/src/unstract/sdk/adapters/url_validator.py b/src/unstract/sdk/adapters/url_validator.py index ac4f72d2..afdeb131 100644 --- a/src/unstract/sdk/adapters/url_validator.py +++ b/src/unstract/sdk/adapters/url_validator.py @@ -20,10 +20,10 @@ class URLValidator: """Validates URLs to prevent SSRF attacks by blocking private IP addresses. URLs are validated to block private IP addresses unless explicitly - whitelisted via ALLOWED_ADAPTER_PRIVATE_ENDPOINTS. + whitelisted via WHITELISTED_ENDPOINTS. """ - ENV_VAR = "ALLOWED_ADAPTER_PRIVATE_ENDPOINTS" + ENV_VAR = "WHITELISTED_ENDPOINTS" # Private IP ranges that are blocked by default (RFC 1918 + others) BLOCKED_PRIVATE_RANGES = [ diff --git a/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py b/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py index e3f02564..4b60fffe 100644 --- a/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py +++ b/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py @@ -17,13 +17,13 @@ class Constants: class Milvus(VectorDBAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): self._config = settings self._client: MilvusClient | None = None self._collection_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME - # Validate URLs BEFORE any network operations - self._validate_urls() + if validate_urls: + self._validate_urls() self._vector_db_instance = self._get_vector_db_instance() super().__init__("Milvus", self._vector_db_instance) @@ -77,11 +77,8 @@ def _get_vector_db_instance(self) -> VectorStore: raise AdapterError(str(e)) def test_connection(self) -> bool: - vector_db = self.get_vector_db_instance() - test_result: bool = VectorDBHelper.test_vector_db_instance( - vector_store=vector_db - ) + test_result: bool = VectorDBHelper.test_vector_db_instance(vector_store=vector_db) # Delete the collection that was created for testing if self._client is not None: self._client.drop_collection(self._collection_name) diff --git a/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py b/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py index 9722ce59..d6e200a9 100644 --- a/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py +++ b/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py @@ -23,14 +23,14 @@ class Constants: class Postgres(VectorDBAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): self._config = settings self._client: connection | None = None self._collection_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME self._schema_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME - # Validate URLs BEFORE any network operations - self._validate_urls() + if validate_urls: + self._validate_urls() self._vector_db_instance = self._get_vector_db_instance() super().__init__("Postgres", self._vector_db_instance) @@ -58,9 +58,7 @@ def get_vector_db_instance(self) -> BasePydanticVectorStore: def _get_vector_db_instance(self) -> BasePydanticVectorStore: try: - encoded_password = quote_plus( - str(self._config.get(Constants.PASSWORD)) - ) + encoded_password = quote_plus(str(self._config.get(Constants.PASSWORD))) dimension = self._config.get( VectorDbConstants.EMBEDDING_DIMENSION, VectorDbConstants.DEFAULT_EMBEDDING_SIZE, @@ -101,11 +99,8 @@ def _get_vector_db_instance(self) -> BasePydanticVectorStore: raise AdapterError(str(e)) def test_connection(self) -> bool: - vector_db = self.get_vector_db_instance() - test_result: bool = VectorDBHelper.test_vector_db_instance( - vector_store=vector_db - ) + test_result: bool = VectorDBHelper.test_vector_db_instance(vector_store=vector_db) # Delete the collection that was created for testing if self._client is not None: diff --git a/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py b/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py index 042a2678..36173ee6 100644 --- a/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py +++ b/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py @@ -20,13 +20,13 @@ class Constants: class Qdrant(VectorDBAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): self._config = settings self._client: QdrantClient | None = None self._collection_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME - # Validate URLs BEFORE any network operations - self._validate_urls() + if validate_urls: + self._validate_urls() self._vector_db_instance = self._get_vector_db_instance() super().__init__("Qdrant", self._vector_db_instance) @@ -80,7 +80,6 @@ def _get_vector_db_instance(self) -> BasePydanticVectorStore: raise self.parse_vector_db_err(e) from e def test_connection(self) -> bool: - try: vector_db = self.get_vector_db_instance() test_result: bool = VectorDBHelper.test_vector_db_instance( @@ -116,6 +115,4 @@ def parse_vector_db_err(e: Exception) -> VectorDBError: status_code = 503 elif "timeout" in str(e): status_code = 504 - return VectorDBError( - message=str(e), actual_err=e, status_code=status_code - ) + return VectorDBError(message=str(e), actual_err=e, status_code=status_code) diff --git a/src/unstract/sdk/adapters/vectordb/weaviate/src/weaviate.py b/src/unstract/sdk/adapters/vectordb/weaviate/src/weaviate.py index 34bad6b3..13cd9860 100644 --- a/src/unstract/sdk/adapters/vectordb/weaviate/src/weaviate.py +++ b/src/unstract/sdk/adapters/vectordb/weaviate/src/weaviate.py @@ -21,13 +21,13 @@ class Constants: class Weaviate(VectorDBAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): self._config = settings self._client: weaviate.Client | None = None self._collection_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME - # Validate URLs BEFORE any network operations - self._validate_urls() + if validate_urls: + self._validate_urls() self._vector_db_instance = self._get_vector_db_instance() super().__init__("Weaviate", self._vector_db_instance) diff --git a/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py b/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py index b3958a2b..d871fc2f 100644 --- a/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py +++ b/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py @@ -5,7 +5,6 @@ from httpx import ConnectError from llama_parse import LlamaParse - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.x2text.dto import TextExtractionResult from unstract.sdk.adapters.x2text.llama_parse.src.constants import ( @@ -18,12 +17,13 @@ class LlamaParseAdapter(X2TextAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("LlamaParse") self.config = settings # Validate URLs BEFORE any network operations - self._validate_urls() + if validate_urls: + self._validate_urls() SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @@ -94,8 +94,7 @@ def _call_parser( except ConnectError as connec_err: logger.error(f"Invalid Base URL given. : {connec_err}") raise AdapterError( - "Unable to connect to llama-parse`s service, " - "please check the Base URL" + "Unable to connect to llama-parse`s service, " "please check the Base URL" ) except Exception as exe: logger.error( @@ -113,9 +112,7 @@ def process( fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), **kwargs: dict[Any, Any], ) -> TextExtractionResult: - response_text = self._call_parser( - input_file_path=input_file_path, fs=fs - ) + response_text = self._call_parser(input_file_path=input_file_path, fs=fs) if output_file_path: fs.write( path=output_file_path, diff --git a/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py b/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py index 9bec5846..0964894b 100644 --- a/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py +++ b/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py @@ -20,12 +20,13 @@ class LLMWhispererV2(X2TextAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("LLMWhispererV2") self.config = settings # Validate URLs BEFORE any network operations - self._validate_urls() + if validate_urls: + self._validate_urls() SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" diff --git a/src/unstract/sdk/adapters/x2text/unstructured_community/src/unstructured_community.py b/src/unstract/sdk/adapters/x2text/unstructured_community/src/unstructured_community.py index 7e9365aa..6be3289c 100644 --- a/src/unstract/sdk/adapters/x2text/unstructured_community/src/unstructured_community.py +++ b/src/unstract/sdk/adapters/x2text/unstructured_community/src/unstructured_community.py @@ -11,12 +11,13 @@ class UnstructuredCommunity(X2TextAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("UnstructuredIOCommunity") self.config = settings # Validate URLs BEFORE any network operations - self._validate_urls() + if validate_urls: + self._validate_urls() SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" From 440d847657a820b473c7a090f10aeb7b08a5513c Mon Sep 17 00:00:00 2001 From: Gayathri <142381512+gaya3-zipstack@users.noreply.github.com> Date: Thu, 18 Sep 2025 12:26:13 +0530 Subject: [PATCH 08/10] Update tests/test_url_validator.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Gayathri <142381512+gaya3-zipstack@users.noreply.github.com> --- tests/test_url_validator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_url_validator.py b/tests/test_url_validator.py index a5ac5aa1..0560afeb 100644 --- a/tests/test_url_validator.py +++ b/tests/test_url_validator.py @@ -20,7 +20,8 @@ def tearDown(self): if URLValidator.ENV_VAR in os.environ: del os.environ[URLValidator.ENV_VAR] - def test_public_urls_allowed(self): + @patch("socket.gethostbyname", return_value="1.1.1.1") + def test_public_urls_allowed(self, _): """Test that public URLs are allowed by default.""" test_cases = [ "https://api.openai.com/v1/chat/completions", @@ -35,7 +36,6 @@ def test_public_urls_allowed(self): self.assertTrue( is_valid, f"Public URL should be valid: {url}, Error: {error}" ) - @patch("socket.gethostbyname") def test_private_ips_blocked_by_default(self, mock_gethostbyname): """Test that private IPs are blocked when not whitelisted.""" From c883c9a3be3eab60ebab830cd9bf820e540865d0 Mon Sep 17 00:00:00 2001 From: gayathrivijayakumar Date: Thu, 18 Sep 2025 12:35:11 +0530 Subject: [PATCH 09/10] Ceanup whitespace and extra char '/' in url validation --- .../sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py b/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py index 7bb147b5..0bce7895 100644 --- a/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py +++ b/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py @@ -8,6 +8,7 @@ import requests from requests import Response from requests.exceptions import ConnectionError, HTTPError, Timeout + from unstract.sdk.adapters.exceptions import ExtractorError from unstract.sdk.adapters.utils import AdapterUtils from unstract.sdk.adapters.x2text.constants import X2TextConstants @@ -57,8 +58,8 @@ def get_icon() -> str: def get_configured_urls(self) -> list[str]: """Return all URLs this adapter will connect to.""" - url = self.config.get(WhispererConfig.URL) - return [url] if url else [] + url = str(self.config.get(WhispererConfig.URL) or "").strip() + return [url.rstrip("/")] if url else [] def _get_request_headers(self) -> dict[str, Any]: """Obtains the request headers to authenticate with LLMWhisperer. From 77398b2084abe461a030291ba5546ace01376a44 Mon Sep 17 00:00:00 2001 From: gayathrivijayakumar Date: Mon, 22 Sep 2025 16:42:03 +0530 Subject: [PATCH 10/10] Better error propagation for status code --- src/unstract/sdk/prompt.py | 32 +++++++++++++++++++------------- src/unstract/sdk/tool/stream.py | 5 +++-- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/unstract/sdk/prompt.py b/src/unstract/sdk/prompt.py index b21c8315..7818656e 100644 --- a/src/unstract/sdk/prompt.py +++ b/src/unstract/sdk/prompt.py @@ -6,11 +6,8 @@ import requests from deprecated import deprecated from requests import ConnectionError, RequestException, Response -from unstract.sdk.constants import ( - MimeType, - RequestHeader, - ToolEnv, -) + +from unstract.sdk.constants import MimeType, RequestHeader, ToolEnv from unstract.sdk.helper import SdkHelper from unstract.sdk.platform import PlatformHelper from unstract.sdk.tool.base import BaseTool @@ -22,7 +19,9 @@ R = TypeVar("R") -def handle_service_exceptions(context: str) -> Callable[[Callable[P, R]], Callable[P, R]]: +def handle_service_exceptions( + context: str, +) -> Callable[[Callable[P, R]], Callable[P, R]]: """Decorator to handle exceptions in PromptTool service calls. Args: @@ -39,20 +38,23 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: except ConnectionError as e: msg = f"Error while {context}. Unable to connect to prompt service." logger.error(f"{msg}\n{e}") - args[0].tool.stream_error_and_exit(msg, e) + args[0].tool.stream_error_and_exit(msg, e, None) except RequestException as e: error_message = str(e) + status_code = None response = getattr(e, "response", None) if response is not None: + status_code = response.status_code if ( - MimeType.JSON in response.headers.get("Content-Type", "").lower() + MimeType.JSON + in response.headers.get("Content-Type", "").lower() and "error" in response.json() ): error_message = response.json()["error"] elif response.text: error_message = response.text msg = f"Error while {context}. {error_message}" - args[0].tool.stream_error_and_exit(msg, e) + args[0].tool.stream_error_and_exit(msg, e, status_code) return wrapper @@ -79,7 +81,9 @@ def __init__( is_public_call (bool): Whether the call is public. Defaults to False """ self.tool = tool - self.base_url = SdkHelper.get_platform_base_url(prompt_host, prompt_port) + self.base_url = SdkHelper.get_platform_base_url( + prompt_host, prompt_port + ) self.is_public_call = is_public_call self.request_id = request_id if not is_public_call: @@ -168,7 +172,9 @@ def summarize( headers=headers, ) - def _get_headers(self, headers: dict[str, str] | None = None) -> dict[str, str]: + def _get_headers( + self, headers: dict[str, str] | None = None + ) -> dict[str, str]: """Get default headers for requests. Returns: @@ -218,13 +224,13 @@ def _call_service( response = requests.get(url=url, params=params, headers=req_headers) else: raise ValueError(f"Unsupported HTTP method: {method}") - response.raise_for_status() return response.json() @staticmethod @deprecated( - version="v0.71.0", reason="Use `PlatformHelper.get_prompt_studio_tool` instead" + version="v0.71.0", + reason="Use `PlatformHelper.get_prompt_studio_tool` instead", ) def get_exported_tool( tool: BaseTool, prompt_registry_id: str diff --git a/src/unstract/sdk/tool/stream.py b/src/unstract/sdk/tool/stream.py index 9286025a..ca9c84b3 100644 --- a/src/unstract/sdk/tool/stream.py +++ b/src/unstract/sdk/tool/stream.py @@ -115,18 +115,19 @@ def stream_log( } print(json.dumps(record)) - def stream_error_and_exit(self, message: str, err: Exception | None = None) -> None: + def stream_error_and_exit(self, message: str, err: Exception | None = None, status_code: int | None = None) -> None: """Stream error log and exit. Args: message (str): Error message err (Exception): Actual exception that occurred + status_code (int): HTTP status code to preserve """ self.stream_log(message, level=LogLevel.ERROR) if self._exec_by_tool: exit(1) else: - raise SdkError(message, actual_err=err) + raise SdkError(message, status_code=status_code, actual_err=err) def get_env_or_die(self, env_key: str) -> str: """Returns the value of an env variable.