Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,8 @@ protected override ScmMethodProvider[] BuildMethods()
Return(
Static(typeof(Volatile)).Invoke(nameof(Volatile.Read), cachedClientFieldVar)
.NullCoalesce(interlockedCompareExchange.NullCoalesce(subClient._clientCachingField))),
this);
this,
ScmMethodKind.Convenience);
methods.Add(factoryMethod);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ protected override ScmMethodProvider[] BuildMethods()
private ScmMethodProvider BuildCreateRequestMethod(InputServiceMethod serviceMethod, bool isNextLinkRequest = false)
{
var options = ScmKnownParameters.RequestOptions;
var parameters = GetMethodParameters(serviceMethod, MethodType.CreateRequest);
var parameters = GetMethodParameters(serviceMethod, ScmMethodKind.CreateRequest);
if (isNextLinkRequest)
{
parameters = [ScmKnownParameters.NextPage, .. parameters];
Expand All @@ -123,6 +123,7 @@ private ScmMethodProvider BuildCreateRequestMethod(InputServiceMethod serviceMet
signature,
messageStatements,
this,
ScmMethodKind.CreateRequest,
xmlDocProvider: XmlDocProvider.Empty,
serviceMethod: serviceMethod);
}
Expand Down Expand Up @@ -757,7 +758,7 @@ public MethodProvider GetCreateNextLinkRequestMethod(InputOperation operation)
return NextMethodCache[operation];
}

internal static List<ParameterProvider> GetMethodParameters(InputServiceMethod serviceMethod, MethodType methodType)
internal static List<ParameterProvider> GetMethodParameters(InputServiceMethod serviceMethod, ScmMethodKind methodType)
{
SortedList<int, ParameterProvider> sortedParams = [];
int path = 0;
Expand All @@ -768,10 +769,10 @@ internal static List<ParameterProvider> GetMethodParameters(InputServiceMethod s

var operation = serviceMethod.Operation;
// For convenience methods, use the service method parameters
var inputParameters = methodType is MethodType.Convenience ? serviceMethod.Parameters : operation.Parameters;
var inputParameters = methodType is ScmMethodKind.Convenience ? serviceMethod.Parameters : operation.Parameters;

ModelProvider? spreadSource = null;
if (methodType == MethodType.Convenience)
if (methodType == ScmMethodKind.Convenience)
{
InputParameter? inputOperationSpreadParameter = operation.Parameters.FirstOrDefault(p => p.Scope.HasFlag(InputParameterScope.Spread));
spreadSource = inputOperationSpreadParameter != null
Expand Down Expand Up @@ -816,11 +817,11 @@ internal static List<ParameterProvider> GetMethodParameters(InputServiceMethod s
continue;
}

if (methodType is MethodType.Protocol or MethodType.CreateRequest)
if (methodType is ScmMethodKind.Protocol or ScmMethodKind.CreateRequest)
{
if (inputParam is InputBodyParameter)
{
if (methodType == MethodType.CreateRequest)
if (methodType == ScmMethodKind.CreateRequest)
{
parameter = ScmKnownParameters.RequestContent;
}
Expand All @@ -836,7 +837,7 @@ internal static List<ParameterProvider> GetMethodParameters(InputServiceMethod s
parameter.Type = parameter.Type.IsEnum ? parameter.Type.UnderlyingEnumType : parameter.Type;
}
}
else if (methodType is MethodType.Convenience &&
else if (methodType is ScmMethodKind.Convenience &&
spreadSource != null
&& inputParam is InputMethodParameter inputMethodParameter
&& inputMethodParameter.Location == InputRequestLocation.Body)
Expand Down Expand Up @@ -875,7 +876,7 @@ internal static List<ParameterProvider> GetMethodParameters(InputServiceMethod s
sortedParams.Add(bodyRequired++, ScmKnownParameters.ContentType);
}

if (methodType == MethodType.CreateRequest)
if (methodType == ScmMethodKind.CreateRequest)
{
// All the parameters should be required for the CreateRequest method
foreach (var parameter in sortedParams.Values)
Expand All @@ -897,13 +898,6 @@ internal static InputModelType GetSpreadParameterModel(InputParameter inputParam
throw new InvalidOperationException($"inputParam `{inputParam.Name}` is `Spread` but not a model type");
}

internal enum MethodType
{
CreateRequest,
Protocol,
Convenience
}

private class StatusCodesComparer : IEqualityComparer<List<int>>
{
bool IEqualityComparer<List<int>>.Equals(List<int>? x, List<int>? y)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

namespace Microsoft.TypeSpec.Generator.ClientModel.Providers
{
/// <summary>
/// Defines the different types of methods that can be generated by the SCM (Service Client Model) generator.
/// </summary>
public enum ScmMethodKind
{
/// <summary>
/// Internal method that creates HTTP request messages for protocol methods.
/// These methods are typically used internally by protocol methods to construct HTTP requests.
/// </summary>
CreateRequest,

/// <summary>
/// Protocol method that handles raw HTTP requests and responses.
/// These methods provide low-level access to HTTP operations with minimal abstraction.
/// </summary>
Protocol,

/// <summary>
/// Convenience method with strongly-typed parameters and return values.
/// These methods provide a high-level, developer-friendly API with strong typing.
/// </summary>
Convenience
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,33 @@
using Microsoft.TypeSpec.Generator.Primitives;
using Microsoft.TypeSpec.Generator.Providers;
using Microsoft.TypeSpec.Generator.Statements;
using Microsoft.TypeSpec.Generator.ClientModel.Primitives;

namespace Microsoft.TypeSpec.Generator.ClientModel.Providers
{
public class ScmMethodProvider : MethodProvider
{
public InputServiceMethod? ServiceMethod { get; }
public TypeProvider? CollectionDefinition { get; }
public bool IsProtocolMethod { get; }

/// <summary>
/// Gets the kind of method (CreateRequest, Protocol, or Convenience).
/// </summary>
public ScmMethodKind Kind { get; }

public ScmMethodProvider(
MethodSignature signature,
MethodBodyStatement bodyStatements,
TypeProvider enclosingType,
ScmMethodKind methodKind,
XmlDocProvider? xmlDocProvider = default,
TypeProvider? collectionDefinition = default,
InputServiceMethod? serviceMethod = default,
bool isProtocolMethod = false)
InputServiceMethod? serviceMethod = default)
: base(signature, bodyStatements, enclosingType, xmlDocProvider)
{
CollectionDefinition = collectionDefinition;
IsProtocolMethod = isProtocolMethod;
ServiceMethod = serviceMethod;
Kind = methodKind;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ public class ScmMethodProviderCollection : IReadOnlyList<ScmMethodProvider>
private readonly MethodProvider _createRequestMethod;
private static readonly ClientPipelineExtensionsDefinition _clientPipelineExtensionsDefinition = new();
private static readonly CancellationTokenExtensionsDefinition _cancellationTokenExtensionsDefinition = new();
private IList<ParameterProvider> ProtocolMethodParameters => _protocolMethodParameters ??= RestClientProvider.GetMethodParameters(ServiceMethod, RestClientProvider.MethodType.Protocol);
private IList<ParameterProvider> ProtocolMethodParameters => _protocolMethodParameters ??= RestClientProvider.GetMethodParameters(ServiceMethod, ScmMethodKind.Protocol);
private IList<ParameterProvider>? _protocolMethodParameters;

private IReadOnlyList<ParameterProvider> ConvenienceMethodParameters => _convenienceMethodParameters ??= RestClientProvider.GetMethodParameters(ServiceMethod, RestClientProvider.MethodType.Convenience);
private IReadOnlyList<ParameterProvider> ConvenienceMethodParameters => _convenienceMethodParameters ??= RestClientProvider.GetMethodParameters(ServiceMethod, ScmMethodKind.Convenience);
private IReadOnlyList<ParameterProvider>? _convenienceMethodParameters;
private readonly InputPagingServiceMethod? _pagingServiceMethod;
private IReadOnlyList<ScmMethodProvider>? _methods;
Expand Down Expand Up @@ -160,7 +160,7 @@ .. GetStackVariablesForReturnValueConversion(result, responseBodyType, isAsync,
];
}

var convenienceMethod = new ScmMethodProvider(methodSignature, methodBody, EnclosingType, collectionDefinition: collection, serviceMethod: ServiceMethod);
var convenienceMethod = new ScmMethodProvider(methodSignature, methodBody, EnclosingType, ScmMethodKind.Convenience, collectionDefinition: collection, serviceMethod: ServiceMethod);

if (convenienceMethod.XmlDocs != null)
{
Expand Down Expand Up @@ -653,7 +653,7 @@ private ScmMethodProvider BuildProtocolMethod(MethodProvider createRequestMethod
}

var protocolMethod =
new ScmMethodProvider(methodSignature, methodBody, EnclosingType, collectionDefinition: collection, serviceMethod: ServiceMethod, isProtocolMethod: true);
new ScmMethodProvider(methodSignature, methodBody, EnclosingType, ScmMethodKind.Protocol, collectionDefinition: collection, serviceMethod: ServiceMethod);

if (protocolMethod.XmlDocs != null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public void ValidateProperties()
[TestCaseSource(nameof(GetMethodParametersTestCases))]
public void TestGetMethodParameters(InputServiceMethod inputServiceMethod)
{
var methodParameters = RestClientProvider.GetMethodParameters(inputServiceMethod, RestClientProvider.MethodType.Convenience);
var methodParameters = RestClientProvider.GetMethodParameters(inputServiceMethod, ScmMethodKind.Convenience);

Assert.IsTrue(methodParameters.Count > 0);

Expand Down Expand Up @@ -139,7 +139,7 @@ public void TestGetMethodParameters(InputServiceMethod inputServiceMethod)
[TestCase]
public void TestGetMethodParameters_ProperOrdering()
{
var methodParameters = RestClientProvider.GetMethodParameters(ServiceMethodWithMixedParamOrdering, RestClientProvider.MethodType.Convenience);
var methodParameters = RestClientProvider.GetMethodParameters(ServiceMethodWithMixedParamOrdering, ScmMethodKind.Convenience);

Assert.AreEqual(ServiceMethodWithMixedParamOrdering.Parameters.Count, methodParameters.Count);

Expand All @@ -152,7 +152,7 @@ public void TestGetMethodParameters_ProperOrdering()
Assert.AreEqual("optionalHeader", methodParameters[5].Name);
Assert.AreEqual("optionalContentType", methodParameters[6].Name);

var orderedPathParams = RestClientProvider.GetMethodParameters(ServiceMethodWithOnlyPathParams, RestClientProvider.MethodType.Convenience);
var orderedPathParams = RestClientProvider.GetMethodParameters(ServiceMethodWithOnlyPathParams, ScmMethodKind.Convenience);
Assert.AreEqual(ServiceMethodWithOnlyPathParams.Parameters.Count, orderedPathParams.Count);
Assert.AreEqual("c", orderedPathParams[0].Name);
Assert.AreEqual("a", orderedPathParams[1].Name);
Expand Down Expand Up @@ -185,7 +185,7 @@ public void HeaderParameterOptionality(bool isRequired, bool isValueType)
"TestClient",
methods: [testServiceMethod]);
var clientProvider = new ClientProvider(client);
var parameters = RestClientProvider.GetMethodParameters(testServiceMethod, RestClientProvider.MethodType.Convenience);
var parameters = RestClientProvider.GetMethodParameters(testServiceMethod, ScmMethodKind.Convenience);
Assert.IsNotNull(parameters);

if (isRequired)
Expand Down Expand Up @@ -705,7 +705,7 @@ public void TestReadOnlyParameters_FilteredFromProtocolMethod()
InputFactory.MethodParameter("normalBody", InputPrimitiveType.Boolean, isRequired: false, location: InputRequestLocation.Body)
]);

var methodParameters = RestClientProvider.GetMethodParameters(inputServiceMethod, RestClientProvider.MethodType.Protocol);
var methodParameters = RestClientProvider.GetMethodParameters(inputServiceMethod, ScmMethodKind.Protocol);

// Verify read-only parameters are filtered out
Assert.IsFalse(methodParameters.Any(p => p.Name == "readOnlyPath"));
Expand Down Expand Up @@ -737,7 +737,7 @@ public void TestReadOnlyParameters_FilteredFromConvenienceMethod()
InputFactory.MethodParameter("normalHeader", InputPrimitiveType.Boolean, isRequired: false, location: InputRequestLocation.Header)
]);

var methodParameters = RestClientProvider.GetMethodParameters(inputServiceMethod, RestClientProvider.MethodType.Convenience);
var methodParameters = RestClientProvider.GetMethodParameters(inputServiceMethod, ScmMethodKind.Convenience);

// Verify read-only parameters are filtered out
Assert.IsFalse(methodParameters.Any(p => p.Name == "readOnlyQuery"));
Expand Down Expand Up @@ -773,7 +773,7 @@ public void TestReadOnlyParameters_WithMixedParameterTypes()
InputFactory.MethodParameter("normalBody", InputPrimitiveType.Boolean, isRequired: false, location: InputRequestLocation.Body)
]);

var methodParameters = RestClientProvider.GetMethodParameters(inputServiceMethod, RestClientProvider.MethodType.Convenience);
var methodParameters = RestClientProvider.GetMethodParameters(inputServiceMethod, ScmMethodKind.Convenience);

Assert.AreEqual(4, methodParameters.Count); // Only non-readonly parameters
Assert.IsTrue(methodParameters.Any(p => p.Name == "normalPath"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1073,5 +1073,29 @@ public static IEnumerable<TestCaseData> DefaultCSharpMethodCollectionTestCases
]));
}
}

[Test]
public void TestMethodTypeIdentification()
{
MockHelpers.LoadMockGenerator();

var inputOperation = InputFactory.Operation("TestOperation");
var inputServiceMethod = InputFactory.BasicServiceMethod("Test", inputOperation);
var inputClient = InputFactory.Client("TestClient", methods: [inputServiceMethod]);
var client = ScmCodeModelGenerator.Instance.TypeFactory.CreateClient(inputClient);
var methodCollection = new ScmMethodProviderCollection(inputServiceMethod, client!);

// Verify protocol methods
var protocolMethods = methodCollection.Where(m => ((ScmMethodProvider)m).Kind == ScmMethodKind.Protocol).ToList();
Assert.AreEqual(2, protocolMethods.Count); // sync + async

// Verify convenience methods
var convenienceMethods = methodCollection.Where(m => ((ScmMethodProvider)m).Kind == ScmMethodKind.Convenience).ToList();
Assert.AreEqual(2, convenienceMethods.Count); // sync + async

// Verify CreateRequest method
var createRequestMethod = (ScmMethodProvider)client!.RestClient.GetCreateRequestMethod(inputOperation);
Assert.AreEqual(ScmMethodKind.CreateRequest, createRequestMethod.Kind);
}
}
}
Loading