diff --git a/.chronus/changes/external-type-python-2025-9-20-14-28-41.md b/.chronus/changes/external-type-python-2025-9-20-14-28-41.md new file mode 100644 index 00000000000..2977eb159e2 --- /dev/null +++ b/.chronus/changes/external-type-python-2025-9-20-14-28-41.md @@ -0,0 +1,7 @@ +--- +changeKind: feature +packages: + - "@typespec/http-client-python" +--- + +Support SDK users defined customized serialization/deserialization function for external models \ No newline at end of file diff --git a/packages/http-client-python/emitter/src/types.ts b/packages/http-client-python/emitter/src/types.ts index 26bee78d935..ee7a373ed8f 100644 --- a/packages/http-client-python/emitter/src/types.ts +++ b/packages/http-client-python/emitter/src/types.ts @@ -270,6 +270,12 @@ function emitModel(context: PythonSdkContext, type: SdkModelType): Record[] = []; const newValue = { type: type.kind, diff --git a/packages/http-client-python/generator/pygen/codegen/models/__init__.py b/packages/http-client-python/generator/pygen/codegen/models/__init__.py index 119a1b68efd..a1d9f9a4dbc 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/__init__.py +++ b/packages/http-client-python/generator/pygen/codegen/models/__init__.py @@ -31,6 +31,7 @@ SdkCoreType, DecimalType, MultiPartFileType, + ExternalType, ) from .enum_type import EnumType, EnumValue from .base import BaseType @@ -151,6 +152,7 @@ "credential": StringType, "sdkcore": SdkCoreType, "multipartfile": MultiPartFileType, + "external": ExternalType, } _LOGGER = logging.getLogger(__name__) diff --git a/packages/http-client-python/generator/pygen/codegen/models/code_model.py b/packages/http-client-python/generator/pygen/codegen/models/code_model.py index 6d2796d49e2..86927cac2ea 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/code_model.py +++ b/packages/http-client-python/generator/pygen/codegen/models/code_model.py @@ -10,6 +10,7 @@ from .enum_type import EnumType from .model_type import ModelType, UsageFlags from .combined_type import CombinedType +from .primitive_types import ExternalType from .client import Client from .request_builder import RequestBuilder, OverloadedRequestBuilder from .operation_group import OperationGroup @@ -101,6 +102,7 @@ def __init__( self._operations_folder_name: dict[str, str] = {} self._relative_import_path: dict[str, str] = {} self.metadata: dict[str, Any] = yaml_data.get("metadata", {}) + self.has_external_type = any(isinstance(t, ExternalType) for t in self.types_map.values()) @staticmethod def get_imported_namespace_for_client(imported_namespace: str, async_mode: bool = False) -> str: @@ -488,3 +490,8 @@ def _get_relative_generation_dir(self, root_dir: Path, namespace: str) -> Path: @property def has_operation_named_list(self) -> bool: return any(o.name.lower() == "list" for c in self.clients for og in c.operation_groups for o in og.operations) + + @property + def external_types(self) -> list[ExternalType]: + """All of the external types""" + return [t for t in self.types_map.values() if isinstance(t, ExternalType)] diff --git a/packages/http-client-python/generator/pygen/codegen/models/primitive_types.py b/packages/http-client-python/generator/pygen/codegen/models/primitive_types.py index 00b480c5537..20332fcd4a9 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/primitive_types.py +++ b/packages/http-client-python/generator/pygen/codegen/models/primitive_types.py @@ -615,6 +615,39 @@ def serialization_type(self, **kwargs: Any) -> str: return self.name +class ExternalType(PrimitiveType): + def __init__(self, yaml_data: dict[str, Any], code_model: "CodeModel") -> None: + super().__init__(yaml_data=yaml_data, code_model=code_model) + self.external_type_info = yaml_data.get("externalTypeInfo", {}) + self.identity = self.external_type_info.get("identity", "") + self.submodule = ".".join(self.identity.split(".")[:-1]) + self.min_version = self.external_type_info.get("minVersion", "") + self.package_name = self.external_type_info.get("package", "") + + def docstring_type(self, **kwargs: Any) -> str: + return f"~{self.identity}" + + def type_annotation(self, **kwargs: Any) -> str: + return self.identity + + def imports(self, **kwargs: Any) -> FileImport: + file_import = super().imports(**kwargs) + file_import.add_import(self.submodule, ImportType.THIRDPARTY, TypingSection.REGULAR) + return file_import + + @property + def instance_check_template(self) -> str: + return f"isinstance({{}}, {self.identity})" + + def serialization_type(self, **kwargs: Any) -> str: + return self.identity + + @property + def default_template_representation_declaration(self) -> str: + value = f"{self.identity}(...)" + return f'"{value}"' if self.code_model.for_test else value + + class MultiPartFileType(PrimitiveType): def __init__(self, yaml_data: dict[str, Any], code_model: "CodeModel") -> None: super().__init__(yaml_data=yaml_data, code_model=code_model) diff --git a/packages/http-client-python/generator/pygen/codegen/serializers/general_serializer.py b/packages/http-client-python/generator/pygen/codegen/serializers/general_serializer.py index cebe5d86166..fad14ccee31 100644 --- a/packages/http-client-python/generator/pygen/codegen/serializers/general_serializer.py +++ b/packages/http-client-python/generator/pygen/codegen/serializers/general_serializer.py @@ -23,7 +23,7 @@ "msrest": "0.7.1", "isodate": "0.6.1", "azure-mgmt-core": "1.6.0", - "azure-core": "1.35.0", + "azure-core": "1.36.0", "typing-extensions": "4.6.0", "corehttp": "1.0.0b6", } @@ -57,7 +57,16 @@ def _extract_min_dependency(self, s): m = re.search(r"[>=]=?([\d.]+(?:[a-z]+\d+)?)", s) return parse_version(m.group(1)) if m else parse_version("0") - def _keep_pyproject_fields(self, file_content: str) -> dict: + def _update_version_map(self, version_map: dict[str, str], dep_name: str, dep: str) -> None: + # For tracked dependencies, check if the version is higher than our default + default_version = parse_version(version_map[dep_name]) + dep_version = self._extract_min_dependency(dep) + # If the version is higher than the default, update VERSION_MAP + # with higher min dependency version + if dep_version > default_version: + version_map[dep_name] = str(dep_version) + + def _keep_pyproject_fields(self, file_content: str, additional_version_map: dict[str, str]) -> dict: # Load the pyproject.toml file if it exists and extract fields to keep. result: dict = {"KEEP_FIELDS": {}} try: @@ -80,15 +89,11 @@ def _keep_pyproject_fields(self, file_content: str) -> dict: for dep in loaded_pyproject_toml["project"]["dependencies"]: dep_name = re.split(r"[<>=\[]", dep)[0].strip() - # Check if dependency is one we track in VERSION_MAP + # Check if dependency is one we track in version map if dep_name in VERSION_MAP: - # For tracked dependencies, check if the version is higher than our default - default_version = parse_version(VERSION_MAP[dep_name]) - dep_version = self._extract_min_dependency(dep) - # If the version is higher than the default, update VERSION_MAP - # with higher min dependency version - if dep_version > default_version: - VERSION_MAP[dep_name] = str(dep_version) + self._update_version_map(VERSION_MAP, dep_name, dep) + elif dep_name in additional_version_map: + self._update_version_map(additional_version_map, dep_name, dep) else: # Keep non-default dependencies kept_deps.append(dep) @@ -107,9 +112,18 @@ def _keep_pyproject_fields(self, file_content: str) -> dict: def serialize_package_file(self, template_name: str, file_content: str, **kwargs: Any) -> str: template = self.env.get_template(template_name) + additional_version_map = {} + if self.code_model.has_external_type: + for item in self.code_model.external_types: + if item.package_name: + if item.min_version: + additional_version_map[item.package_name] = item.min_version + else: + additional_version_map[item.package_name] = "0" + # Add fields to keep from an existing pyproject.toml if template_name == "pyproject.toml.jinja2": - params = self._keep_pyproject_fields(file_content) + params = self._keep_pyproject_fields(file_content, additional_version_map) else: params = {} @@ -126,6 +140,7 @@ def serialize_package_file(self, template_name: str, file_content: str, **kwargs dev_status = "4 - Beta" else: dev_status = "5 - Production/Stable" + params |= { "code_model": self.code_model, "dev_status": dev_status, @@ -136,6 +151,7 @@ def serialize_package_file(self, template_name: str, file_content: str, **kwargs "VERSION_MAP": VERSION_MAP, "MIN_PYTHON_VERSION": MIN_PYTHON_VERSION, "MAX_PYTHON_VERSION": MAX_PYTHON_VERSION, + "ADDITIONAL_DEPENDENCIES": [f"{item[0]}>={item[1]}" for item in additional_version_map.items()], } params |= {"options": self.code_model.options} params |= kwargs diff --git a/packages/http-client-python/generator/pygen/codegen/templates/model_base.py.jinja2 b/packages/http-client-python/generator/pygen/codegen/templates/model_base.py.jinja2 index 3158f8c46e6..8d9fbf2b825 100644 --- a/packages/http-client-python/generator/pygen/codegen/templates/model_base.py.jinja2 +++ b/packages/http-client-python/generator/pygen/codegen/templates/model_base.py.jinja2 @@ -25,6 +25,9 @@ from {{ code_model.core_library }}.exceptions import DeserializationError from {{ code_model.core_library }}{{ "" if code_model.is_azure_flavor else ".utils" }} import CaseInsensitiveEnumMeta from {{ code_model.core_library }}.{{ "" if code_model.is_azure_flavor else "runtime." }}pipeline import PipelineResponse from {{ code_model.core_library }}.serialization import _Null +{% if code_model.has_external_type %} +from {{ code_model.core_library }}.serialization import TypeHandlerRegistry +{% endif %} from {{ code_model.core_library }}.rest import HttpResponse _LOGGER = logging.getLogger(__name__) @@ -34,6 +37,10 @@ __all__ = ["SdkJSONEncoder", "Model", "rest_field", "rest_discriminator"] TZ_UTC = timezone.utc _T = typing.TypeVar("_T") +{% if code_model.has_external_type %} +TYPE_HANDLER_REGISTRY = TypeHandlerRegistry() +{% endif %} + def _timedelta_as_isostr(td: timedelta) -> str: """Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 'P4DT12H30M05S' @@ -158,6 +165,11 @@ class SdkJSONEncoder(JSONEncoder): except AttributeError: # This will be raised when it hits value.total_seconds in the method above pass + {% if code_model.has_external_type %} + custom_serializer = TYPE_HANDLER_REGISTRY.get_serializer(o) + if custom_serializer: + return custom_serializer(o) + {% endif %} return super(SdkJSONEncoder, self).default(o) @@ -313,7 +325,13 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _deserialize_int_as_str if rf and rf._format: return _DESERIALIZE_MAPPING_WITHFORMAT.get(rf._format) + {% if code_model.has_external_type %} + if _DESERIALIZE_MAPPING.get(annotation): + return _DESERIALIZE_MAPPING.get(annotation) # pyright: ignore + return TYPE_HANDLER_REGISTRY.get_deserializer(annotation) # pyright: ignore + {% else %} return _DESERIALIZE_MAPPING.get(annotation) # pyright: ignore + {% endif %} def _get_type_alias_type(module_name: str, alias_name: str): @@ -507,6 +525,14 @@ def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-m except AttributeError: # This will be raised when it hits value.total_seconds in the method above pass + + {% if code_model.has_external_type %} + # Check if there's a custom serializer for the type + custom_serializer = TYPE_HANDLER_REGISTRY.get_serializer(o) + if custom_serializer: + return custom_serializer(o) + {% endif %} + return o diff --git a/packages/http-client-python/generator/pygen/codegen/templates/packaging_templates/pyproject.toml.jinja2 b/packages/http-client-python/generator/pygen/codegen/templates/packaging_templates/pyproject.toml.jinja2 index 0f52b35fd17..4b6d74796c5 100644 --- a/packages/http-client-python/generator/pygen/codegen/templates/packaging_templates/pyproject.toml.jinja2 +++ b/packages/http-client-python/generator/pygen/codegen/templates/packaging_templates/pyproject.toml.jinja2 @@ -56,6 +56,9 @@ dependencies = [ "{{ dep }}", {% endfor %} {% endif %} + {% for dep in ADDITIONAL_DEPENDENCIES %} + "{{ dep }}", + {% endfor %} ] dynamic = [ {% if options.get('package-mode') %}"version", {% endif %}"readme" diff --git a/packages/http-client-python/generator/pygen/codegen/templates/packaging_templates/setup.py.jinja2 b/packages/http-client-python/generator/pygen/codegen/templates/packaging_templates/setup.py.jinja2 index 2590ce72776..396d915e0e9 100644 --- a/packages/http-client-python/generator/pygen/codegen/templates/packaging_templates/setup.py.jinja2 +++ b/packages/http-client-python/generator/pygen/codegen/templates/packaging_templates/setup.py.jinja2 @@ -108,6 +108,9 @@ setup( "corehttp[requests]>={{ VERSION_MAP["corehttp"] }}", {% endif %} "typing-extensions>={{ VERSION_MAP['typing-extensions'] }}", + {% for dep in ADDITIONAL_DEPENDENCIES %} + {{ dep }}, + {% endfor %} ], {% if options["package-mode"] %} python_requires=">={{ MIN_PYTHON_VERSION }}", diff --git a/packages/http-client-python/generator/test/azure/requirements.txt b/packages/http-client-python/generator/test/azure/requirements.txt index 9bd7d045774..d7f33efcfdf 100644 --- a/packages/http-client-python/generator/test/azure/requirements.txt +++ b/packages/http-client-python/generator/test/azure/requirements.txt @@ -14,6 +14,7 @@ azure-mgmt-core==1.6.0 -e ./generated/azure-client-generator-core-usage -e ./generated/azure-client-generator-core-override -e ./generated/azure-client-generator-core-client-location +-e ./generated/azure-client-generator-core-alternate-type -e ./generated/azure-core-basic -e ./generated/azure-core-scalar -e ./generated/azure-core-lro-rpc