Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/generators/proxy.md
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ Main attribute for marking interfaces or abstract classes for proxy generation.
|----------|------|---------|-------------|
| `ProxyTypeName` | `string?` | `{ContractName}Proxy` | Name of the generated proxy class |
| `InterceptorMode` | `ProxyInterceptorMode` | `Single` | Interceptor support mode |
| `GenerateAsync` | `bool?` | Auto-detected | Generate async interceptor methods |
| `GenerateAsync` | `bool` | Auto-detected when omitted | Generate async interceptor methods |
| `ForceAsync` | `bool` | `false` | Force async even if no async members detected |
| `Exceptions` | `ProxyExceptionPolicy` | `Rethrow` | Exception handling policy |

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public sealed class GenerateProxyAttribute : Attribute
/// If not specified, async support is inferred from the contract
/// (enabled if any member returns Task/ValueTask or has a CancellationToken parameter).
/// </summary>
public bool? GenerateAsync { get; set; }
public bool GenerateAsync { get; set; }

/// <summary>
/// Gets or sets whether to force async interceptor hooks even if no async members are detected.
Expand Down
93 changes: 27 additions & 66 deletions src/PatternKit.Generators/ProxyGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,7 @@ private void GenerateInterceptedDelegation(StringBuilder sb, MemberInfo member,
var interceptorCheck = config.InterceptorMode == ProxyInterceptorMode.Single
? "_interceptor is null"
: "_interceptors is null || _interceptors.Count == 0";
var useAsync = contractInfo.HasAsyncMembers && member.IsAsync;

sb.AppendLine($" if ({interceptorCheck})");
sb.AppendLine(" {");
Expand All @@ -996,13 +997,13 @@ private void GenerateInterceptedDelegation(StringBuilder sb, MemberInfo member,
else
{
var refModifier = member.ReturnsByRef || member.ReturnsByRefReadonly ? "ref " : "";
if (member.IsAsync && !member.IsGenericAsyncReturnType)
if (useAsync && !member.IsGenericAsyncReturnType)
{
sb.Append($" await _inner.{member.Name}(");
}
else
{
var awaitModifier = member.IsAsync ? "await " : "";
var awaitModifier = useAsync ? "await " : "";
sb.Append($" return {awaitModifier}{refModifier}_inner.{member.Name}(");
}
sb.Append(string.Join(", ", member.Parameters.Select(p =>
Expand All @@ -1017,7 +1018,7 @@ private void GenerateInterceptedDelegation(StringBuilder sb, MemberInfo member,
return $"{refKind}{p.Name}";
})));
sb.AppendLine(");");
if (member.IsAsync && !member.IsGenericAsyncReturnType)
if (useAsync && !member.IsGenericAsyncReturnType)
{
sb.AppendLine(" return;");
}
Expand All @@ -1035,8 +1036,6 @@ private void GenerateInterceptedDelegation(StringBuilder sb, MemberInfo member,
sb.AppendLine();

// Use async or sync based on detection and configuration
bool useAsync = contractInfo.HasAsyncMembers && member.IsAsync;

if (useAsync)
{
GenerateAsyncInterceptedCall(sb, member, config, contextTypeName);
Expand Down Expand Up @@ -1172,70 +1171,32 @@ private void GenerateAsyncInterceptedCall(StringBuilder sb, MemberInfo member, P
sb.AppendLine(" }");
}

// Actual method call
if (member.IsVoid)
{
sb.Append($" _inner.{member.Name}(");
sb.Append(string.Join(", ", member.Parameters.Select(p =>
{
var refKind = p.RefKind switch
{
RefKind.Ref => "ref ",
RefKind.Out => "out ",
RefKind.In => "in ",
_ => ""
};
return $"{refKind}{p.Name}";
})));
sb.AppendLine(");");
}
else if (member.IsAsync)
// For async methods, get the task and await it.
sb.Append($" var __task = _inner.{member.Name}(");
sb.Append(string.Join(", ", member.Parameters.Select(p =>
{
// For async methods, get the task and await it
sb.Append($" var __task = _inner.{member.Name}(");
sb.Append(string.Join(", ", member.Parameters.Select(p =>
var refKind = p.RefKind switch
{
var refKind = p.RefKind switch
{
RefKind.Ref => "ref ",
RefKind.Out => "out ",
RefKind.In => "in ",
_ => ""
};
return $"{refKind}{p.Name}";
})));
sb.AppendLine(");");
sb.AppendLine(" __context.SetResult(__task);");
RefKind.Ref => "ref ",
RefKind.Out => "out ",
RefKind.In => "in ",
_ => ""
};
return $"{refKind}{p.Name}";
})));
sb.AppendLine(");");
sb.AppendLine(" __context.SetResult(__task);");

// Check if the async method returns a value (Task<T> or ValueTask<T> vs Task or ValueTask)
if (member.IsGenericAsyncReturnType)
{
// Await and store result for later return
sb.AppendLine(" var __result = await __task.ConfigureAwait(false);");
}
else
{
// Task or ValueTask with no result - just await
sb.AppendLine(" await __task.ConfigureAwait(false);");
}
// Check if the async method returns a value (Task<T> or ValueTask<T> vs Task or ValueTask)
if (member.IsGenericAsyncReturnType)
{
// Await and store result for later return
sb.AppendLine(" var __result = await __task.ConfigureAwait(false);");
}
else
{
var refModifier = member.ReturnsByRef || member.ReturnsByRefReadonly ? "ref " : "";
sb.Append($" var __result = {refModifier}_inner.{member.Name}(");
sb.Append(string.Join(", ", member.Parameters.Select(p =>
{
var refKind = p.RefKind switch
{
RefKind.Ref => "ref ",
RefKind.Out => "out ",
RefKind.In => "in ",
_ => ""
};
return $"{refKind}{p.Name}";
})));
sb.AppendLine(");");
sb.AppendLine(" __context.SetResult(__result);");
// Task or ValueTask with no result - just await
sb.AppendLine(" await __task.ConfigureAwait(false);");
}
sb.AppendLine();

Expand All @@ -1252,8 +1213,8 @@ private void GenerateAsyncInterceptedCall(StringBuilder sb, MemberInfo member, P
sb.AppendLine(" }");
}

// Return statement (only for non-void and for async methods with generic Task<T>/ValueTask<T>)
if (!member.IsVoid && (!member.IsAsync || member.IsGenericAsyncReturnType))
// Return statement for async methods with generic Task<T>/ValueTask<T>.
if (member.IsGenericAsyncReturnType)
{
sb.AppendLine(" return __result;");
}
Expand Down Expand Up @@ -1282,7 +1243,7 @@ private void GenerateAsyncInterceptedCall(StringBuilder sb, MemberInfo member, P
}
else // Swallow
{
if (!member.IsVoid && (!member.IsAsync || member.IsGenericAsyncReturnType))
if (member.IsGenericAsyncReturnType)
{
sb.AppendLine(" return default!;");
}
Expand Down
157 changes: 157 additions & 0 deletions test/PatternKit.Generators.Tests/ProxyGeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1531,4 +1531,161 @@ string Format(
var emit = updated.Emit(Stream.Null);
ScenarioExpect.True(emit.Success, string.Join("\n", emit.Diagnostics));
}

[Scenario("GenerateProxy GenerateAsyncFalse ReportsWarningsAndSuppressesAsyncHooks")]
[Fact]
public void GenerateProxy_GenerateAsyncFalse_ReportsWarningsAndSuppressesAsyncHooks()
{
const string source = """
using PatternKit.Generators.Proxy;
using System.Threading;
using System.Threading.Tasks;

namespace TestNamespace;

[GenerateProxy(GenerateAsync = false)]
public partial interface IExplicitSyncProxy
{
Task SaveAsync(CancellationToken cancellationToken = default);
string Poll(CancellationToken cancellationToken = default);
}
""";

var comp = RoslynTestHelpers.CreateCompilation(source, nameof(GenerateProxy_GenerateAsyncFalse_ReportsWarningsAndSuppressesAsyncHooks));
var gen = new ProxyGenerator();
_ = RoslynTestHelpers.Run(comp, gen, out var result, out var updated);

var diagnostics = result.Results.SelectMany(r => r.Diagnostics).ToArray();
ScenarioExpect.Equal(2, diagnostics.Count(d => d.Id == "PKPRX005"));
ScenarioExpect.Contains(diagnostics, d => d.Id == "PKPRX005" && d.GetMessage().Contains("SaveAsync"));
ScenarioExpect.Contains(diagnostics, d => d.Id == "PKPRX005" && d.GetMessage().Contains("Poll"));

var interceptorSource = result.Results
.SelectMany(r => r.GeneratedSources)
.Single(gs => gs.HintName == "TestNamespace_IExplicitSyncProxy.Proxy.Interceptor.g.cs")
.SourceText.ToString();

ScenarioExpect.DoesNotContain("BeforeAsync", interceptorSource);

var emit = updated.Emit(Stream.Null);
ScenarioExpect.True(emit.Success, string.Join("\n", emit.Diagnostics));
}

[Scenario("GenerateProxy Defaults CoverDynamicNullParamsAndReservedContextNames")]
[Fact]
public void GenerateProxy_Defaults_CoverDynamicNullParamsAndReservedContextNames()
{
const string source = """
using PatternKit.Generators.Proxy;

namespace TestNamespace;

[GenerateProxy]
public partial interface IContextNameProxy
{
string Format(
dynamic payload = null,
string methodName = "calculate",
string arguments = "items",
string result = "ok",
params string[] tags);
}
""";

var comp = RoslynTestHelpers.CreateCompilation(
source,
nameof(GenerateProxy_Defaults_CoverDynamicNullParamsAndReservedContextNames),
extra:
[
MetadataReference.CreateFromFile(typeof(System.Runtime.CompilerServices.DynamicAttribute).Assembly.Location),
MetadataReference.CreateFromFile(typeof(Microsoft.CSharp.RuntimeBinder.CSharpArgumentInfo).Assembly.Location)
]);
var gen = new ProxyGenerator();
_ = RoslynTestHelpers.Run(comp, gen, out var result, out var updated);

ScenarioExpect.All(result.Results, r => ScenarioExpect.Empty(r.Diagnostics));

var proxySource = result.Results
.SelectMany(r => r.GeneratedSources)
.Single(gs => gs.HintName == "TestNamespace_IContextNameProxy.Proxy.g.cs")
.SourceText.ToString();
var interceptorSource = result.Results
.SelectMany(r => r.GeneratedSources)
.Single(gs => gs.HintName == "TestNamespace_IContextNameProxy.Proxy.Interceptor.g.cs")
.SourceText.ToString();

ScenarioExpect.Contains("dynamic payload = null", proxySource);
ScenarioExpect.Contains("params string[] tags", proxySource);
ScenarioExpect.Contains("Arg_MethodName", interceptorSource);
ScenarioExpect.Contains("Arg_Arguments", interceptorSource);
ScenarioExpect.Contains("Arg_Result", interceptorSource);

var emit = updated.Emit(Stream.Null);
ScenarioExpect.True(emit.Success, string.Join("\n", emit.Diagnostics));
}

[Scenario("GenerateProxy NoInterceptorVoidMethodWithPlainParameter Delegates")]
[Fact]
public void GenerateProxy_NoInterceptorVoidMethodWithPlainParameter_Delegates()
{
const string source = """
using PatternKit.Generators.Proxy;

namespace TestNamespace;

[GenerateProxy(InterceptorMode = ProxyInterceptorMode.None)]
public partial interface IVoidDelegateProxy
{
void Track(string message);
}
""";

var comp = RoslynTestHelpers.CreateCompilation(source, nameof(GenerateProxy_NoInterceptorVoidMethodWithPlainParameter_Delegates));
var gen = new ProxyGenerator();
_ = RoslynTestHelpers.Run(comp, gen, out var result, out var updated);

ScenarioExpect.All(result.Results, r => ScenarioExpect.Empty(r.Diagnostics));

var proxySource = result.Results
.SelectMany(r => r.GeneratedSources)
.Single(gs => gs.HintName == "TestNamespace_IVoidDelegateProxy.Proxy.g.cs")
.SourceText.ToString();

ScenarioExpect.Contains("_inner.Track(message);", proxySource);
ScenarioExpect.DoesNotContain("return _inner.Track", proxySource);

var emit = updated.Emit(Stream.Null);
ScenarioExpect.True(emit.Success, string.Join("\n", emit.Diagnostics));
}

[Scenario("GenerateProxy AsyncInterceptedRefOutInArguments AreForwarded")]
[Fact]
public void GenerateProxy_AsyncInterceptedRefOutInArguments_AreForwarded()
{
const string source = """
using PatternKit.Generators.Proxy;
using System.Threading.Tasks;

namespace TestNamespace;

[GenerateProxy]
public partial interface IAsyncRefArgumentProxy
{
Task<int> CountAsync(ref int source, out int destination, in bool enabled);
}
""";

var comp = RoslynTestHelpers.CreateCompilation(source, nameof(GenerateProxy_AsyncInterceptedRefOutInArguments_AreForwarded));
var gen = new ProxyGenerator();
_ = RoslynTestHelpers.Run(comp, gen, out var result, out _);

ScenarioExpect.All(result.Results, r => ScenarioExpect.Empty(r.Diagnostics));

var proxySource = result.Results
.SelectMany(r => r.GeneratedSources)
.Single(gs => gs.HintName == "TestNamespace_IAsyncRefArgumentProxy.Proxy.g.cs")
.SourceText.ToString();

ScenarioExpect.Contains("_inner.CountAsync(ref source, out destination, in enabled);", proxySource);
}
}
Loading