diff --git a/src/PatternKit.Generators/Messaging/ClaimCheckGenerator.cs b/src/PatternKit.Generators/Messaging/ClaimCheckGenerator.cs index 2ce7091d..09891e02 100644 --- a/src/PatternKit.Generators/Messaging/ClaimCheckGenerator.cs +++ b/src/PatternKit.Generators/Messaging/ClaimCheckGenerator.cs @@ -1,3 +1,4 @@ +using System.Collections.Generic; using System.Linq; using System.Text; using Microsoft.CodeAnalysis; @@ -112,6 +113,56 @@ private static string GenerateSource( sb.AppendLine(); } + var containingTypes = GetContainingTypes(type); + var indentLevel = 0; + foreach (var containingType in containingTypes) + { + AppendTypeDeclaration(sb, containingType, indentLevel); + sb.AppendLine(); + sb.AppendLine(new string(' ', indentLevel * 4) + "{"); + indentLevel++; + } + + AppendTypeDeclaration(sb, type, indentLevel); + sb.AppendLine(); + var indent = new string(' ', indentLevel * 4); + sb.AppendLine(indent + "{"); + var memberIndent = indent + " "; + var bodyIndent = memberIndent + " "; + sb.Append(memberIndent).Append("public static global::PatternKit.Messaging.Transformation.ClaimCheck<") + .Append(payloadName).Append("> ").Append(factoryName).AppendLine("()"); + sb.Append(bodyIndent).Append("=> global::PatternKit.Messaging.Transformation.ClaimCheck<") + .Append(payloadName).Append(">.Create(\"").Append(Escape(claimCheckName)).AppendLine("\")"); + sb.Append(bodyIndent).Append(" .InStore(\"").Append(Escape(storeName)).AppendLine("\")"); + sb.Append(bodyIndent).Append(" .UseStore(").Append(storeFactory).AppendLine("())"); + sb.Append(bodyIndent).Append(" .UseClaimIds(static (message, _) => \"") + .Append(Escape(claimIdPrefix)) + .Append(":\" + (message.Headers.MessageId ?? global::System.Guid.NewGuid().ToString(\"N\")))") + .AppendLine(); + sb.Append(bodyIndent).AppendLine(" .Build();"); + sb.AppendLine(indent + "}"); + for (var i = containingTypes.Length - 1; i >= 0; i--) + { + sb.AppendLine(new string(' ', i * 4) + "}"); + } + + return sb.ToString(); + } + + private static INamedTypeSymbol[] GetContainingTypes(INamedTypeSymbol type) + { + var containingTypes = new Stack(); + for (var current = type.ContainingType; current is not null; current = current.ContainingType) + { + containingTypes.Push(current); + } + + return containingTypes.ToArray(); + } + + private static void AppendTypeDeclaration(StringBuilder sb, INamedTypeSymbol type, int indentLevel) + { + sb.Append(new string(' ', indentLevel * 4)); sb.Append(GetAccessibility(type.DeclaredAccessibility)).Append(' '); if (type.IsStatic) sb.Append("static "); @@ -119,21 +170,7 @@ private static string GenerateSource( sb.Append("abstract "); else if (type.IsSealed && type.TypeKind == TypeKind.Class) sb.Append("sealed "); - sb.Append("partial ").Append(type.TypeKind == TypeKind.Struct ? "struct" : "class").Append(' ').Append(type.Name).AppendLine(); - sb.AppendLine("{"); - sb.Append(" public static global::PatternKit.Messaging.Transformation.ClaimCheck<") - .Append(payloadName).Append("> ").Append(factoryName).AppendLine("()"); - sb.Append(" => global::PatternKit.Messaging.Transformation.ClaimCheck<") - .Append(payloadName).Append(">.Create(\"").Append(Escape(claimCheckName)).AppendLine("\")"); - sb.Append(" .InStore(\"").Append(Escape(storeName)).AppendLine("\")"); - sb.Append(" .UseStore(").Append(storeFactory).AppendLine("())"); - sb.Append(" .UseClaimIds(static (message, _) => \"") - .Append(Escape(claimIdPrefix)) - .Append(":\" + (message.Headers.MessageId ?? global::System.Guid.NewGuid().ToString(\"N\")))") - .AppendLine(); - sb.AppendLine(" .Build();"); - sb.AppendLine("}"); - return sb.ToString(); + sb.Append("partial ").Append(type.TypeKind == TypeKind.Struct ? "struct" : "class").Append(' ').Append(type.Name); } private static bool IsStoreFactory(IMethodSymbol method, ITypeSymbol payloadType) diff --git a/src/PatternKit.Generators/TransactionScript/TransactionScriptGenerator.cs b/src/PatternKit.Generators/TransactionScript/TransactionScriptGenerator.cs index 5f8f39b9..17a5434c 100644 --- a/src/PatternKit.Generators/TransactionScript/TransactionScriptGenerator.cs +++ b/src/PatternKit.Generators/TransactionScript/TransactionScriptGenerator.cs @@ -1,3 +1,4 @@ +using System.Collections.Generic; using System.Linq; using System.Text; using Microsoft.CodeAnalysis; @@ -119,6 +120,53 @@ private static string GenerateSource( sb.AppendLine(); } + var containingTypes = GetContainingTypes(type); + var indentLevel = 0; + foreach (var containingType in containingTypes) + { + AppendTypeDeclaration(sb, containingType, indentLevel); + sb.AppendLine(); + sb.AppendLine(new string(' ', indentLevel * 4) + "{"); + indentLevel++; + } + + AppendTypeDeclaration(sb, type, indentLevel); + sb.AppendLine(); + var indent = new string(' ', indentLevel * 4); + sb.AppendLine(indent + "{"); + var memberIndent = indent + " "; + var bodyIndent = memberIndent + " "; + sb.Append(memberIndent).Append("public static global::PatternKit.Application.TransactionScript.TransactionScript<") + .Append(requestName).Append(", ").Append(responseName).Append("> ").Append(factoryName).AppendLine("()"); + sb.Append(bodyIndent).Append("=> global::PatternKit.Application.TransactionScript.TransactionScript<") + .Append(requestName).Append(", ").Append(responseName).Append(">.Create(\"").Append(Escape(scriptName)).Append("\")"); + if (validatorName is not null) + sb.AppendLine().Append(bodyIndent).Append(" .Validate(").Append(validatorName).Append(')'); + sb.AppendLine().Append(bodyIndent).Append(" .Execute(").Append(handlerName).AppendLine(")"); + sb.Append(bodyIndent).AppendLine(" .Build();"); + sb.AppendLine(indent + "}"); + for (var i = containingTypes.Length - 1; i >= 0; i--) + { + sb.AppendLine(new string(' ', i * 4) + "}"); + } + + return sb.ToString(); + } + + private static INamedTypeSymbol[] GetContainingTypes(INamedTypeSymbol type) + { + var containingTypes = new Stack(); + for (var current = type.ContainingType; current is not null; current = current.ContainingType) + { + containingTypes.Push(current); + } + + return containingTypes.ToArray(); + } + + private static void AppendTypeDeclaration(StringBuilder sb, INamedTypeSymbol type, int indentLevel) + { + sb.Append(new string(' ', indentLevel * 4)); sb.Append(GetAccessibility(type.DeclaredAccessibility)).Append(' '); if (type.IsStatic) sb.Append("static "); @@ -126,18 +174,7 @@ private static string GenerateSource( sb.Append("abstract "); else if (type.IsSealed && type.TypeKind == TypeKind.Class) sb.Append("sealed "); - sb.Append("partial ").Append(type.TypeKind == TypeKind.Struct ? "struct" : "class").Append(' ').Append(type.Name).AppendLine(); - sb.AppendLine("{"); - sb.Append(" public static global::PatternKit.Application.TransactionScript.TransactionScript<") - .Append(requestName).Append(", ").Append(responseName).Append("> ").Append(factoryName).AppendLine("()"); - sb.Append(" => global::PatternKit.Application.TransactionScript.TransactionScript<") - .Append(requestName).Append(", ").Append(responseName).Append(">.Create(\"").Append(Escape(scriptName)).Append("\")"); - if (validatorName is not null) - sb.AppendLine().Append(" .Validate(").Append(validatorName).Append(')'); - sb.AppendLine().Append(" .Execute(").Append(handlerName).AppendLine(")"); - sb.AppendLine(" .Build();"); - sb.AppendLine("}"); - return sb.ToString(); + sb.Append("partial ").Append(type.TypeKind == TypeKind.Struct ? "struct" : "class").Append(' ').Append(type.Name); } private static bool IsHandler(IMethodSymbol method, INamedTypeSymbol requestType, INamedTypeSymbol responseType) diff --git a/test/PatternKit.Generators.Tests/ClaimCheckGeneratorTests.cs b/test/PatternKit.Generators.Tests/ClaimCheckGeneratorTests.cs index 1a3a7e11..1ab8a999 100644 --- a/test/PatternKit.Generators.Tests/ClaimCheckGeneratorTests.cs +++ b/test/PatternKit.Generators.Tests/ClaimCheckGeneratorTests.cs @@ -3,16 +3,18 @@ using PatternKit.Generators.Messaging; using PatternKit.Messaging.Transformation; using TinyBDD; +using TinyBDD.Xunit; +using Xunit.Abstractions; namespace PatternKit.Generators.Tests; -public sealed class ClaimCheckGeneratorTests +[Feature("Claim Check generator")] +public sealed partial class ClaimCheckGeneratorTests(ITestOutputHelper output) : TinyBddXunitBase(output) { [Scenario("Generates claim check factory")] [Fact] - public void GeneratesClaimCheckFactory() - { - var source = """ + public Task Generates_Claim_Check_Factory() + => Given("a valid claim check declaration", () => Compile(""" using PatternKit.Generators.Messaging; using PatternKit.Messaging.Transformation; @@ -26,80 +28,180 @@ public static partial class DocumentClaimCheck [ClaimCheckStoreFactory] private static IClaimCheckStore CreateStore() => new InMemoryClaimCheckStore(); } - """; - - var comp = CreateCompilation(source, nameof(GeneratesClaimCheckFactory)); - var gen = new ClaimCheckGenerator(); - _ = RoslynTestHelpers.Run(comp, gen, out var run, out var updated); - - ScenarioExpect.All(run.Results, result => ScenarioExpect.Empty(result.Diagnostics)); - var generated = ScenarioExpect.Single(run.Results.SelectMany(result => result.GeneratedSources)); - var text = generated.SourceText.ToString(); - ScenarioExpect.Equal("DocumentClaimCheck.ClaimCheck.g.cs", generated.HintName); - ScenarioExpect.Contains("Build()", text); - ScenarioExpect.Contains("ClaimCheck.Create(\"documents\")", text); - ScenarioExpect.Contains(".InStore(\"blob-store\")", text); - ScenarioExpect.Contains(".UseStore(CreateStore())", text); - ScenarioExpect.Contains("\"doc:\" + (message.Headers.MessageId", text); - ScenarioExpect.True(updated.Emit(Stream.Null).Success); - } - - [Scenario("Reports diagnostic for non-partial claim check host")] - [Fact] - public void ReportsDiagnosticForNonPartialClaimCheckHost() - { - var source = """ + """)) + .Then("the generated source creates the configured claim check", result => + { + ScenarioExpect.Empty(result.Diagnostics); + var generated = ScenarioExpect.Single(result.GeneratedSources); + ScenarioExpect.Equal("DocumentClaimCheck.ClaimCheck.g.cs", generated.HintName); + ScenarioExpect.Contains("Build()", generated.Source); + ScenarioExpect.Contains("ClaimCheck.Create(\"documents\")", generated.Source); + ScenarioExpect.Contains(".InStore(\"blob-store\")", generated.Source); + ScenarioExpect.Contains(".UseStore(CreateStore())", generated.Source); + ScenarioExpect.Contains("\"doc:\" + (message.Headers.MessageId", generated.Source); + ScenarioExpect.True(result.EmitSuccess, string.Join(Environment.NewLine, result.EmitDiagnostics)); + }) + .AssertPassed(); + + [Scenario("Reports diagnostics for invalid claim check declarations")] + [Theory] + [InlineData("public static class Host { [ClaimCheckStoreFactory] private static IClaimCheckStore CreateStore() => new InMemoryClaimCheckStore(); }", "PKCC001")] + [InlineData("public static partial class Host;", "PKCC002")] + [InlineData("public static partial class Host { [ClaimCheckStoreFactory] private static IClaimCheckStore One() => new InMemoryClaimCheckStore(); [ClaimCheckStoreFactory] private static IClaimCheckStore Two() => new InMemoryClaimCheckStore(); }", "PKCC002")] + [InlineData("public partial class Host { [ClaimCheckStoreFactory] private IClaimCheckStore CreateStore() => new InMemoryClaimCheckStore(); }", "PKCC003")] + [InlineData("public static partial class Host { [ClaimCheckStoreFactory] private static IClaimCheckStore CreateStore() => new InMemoryClaimCheckStore(); }", "PKCC003")] + [InlineData("public static partial class Host { [ClaimCheckStoreFactory] private static IClaimCheckStore CreateStore(string name) => new InMemoryClaimCheckStore(); }", "PKCC003")] + [InlineData("public static partial class Host { [ClaimCheckStoreFactory] private static string CreateStore() => string.Empty; }", "PKCC003")] + [InlineData("public static partial class Host { [ClaimCheckStoreFactory] private static IClaimCheckStore CreateStore() => new InMemoryClaimCheckStore(); }", "PKCC003")] + public Task Reports_Diagnostics_For_Invalid_Claim_Check_Declarations(string declaration, string diagnosticId) + => Given("an invalid claim check declaration", () => Compile($$""" using PatternKit.Generators.Messaging; + using PatternKit.Messaging.Transformation; namespace Demo; [GenerateClaimCheck(typeof(string))] - public static class Host; - """; + {{declaration}} + """)) + .Then("the expected diagnostic is reported", result => + ScenarioExpect.Contains(result.Diagnostics, diagnostic => diagnostic.Id == diagnosticId)) + .AssertPassed(); - var diagnostic = RunAndGetSingleDiagnostic(source, nameof(ReportsDiagnosticForNonPartialClaimCheckHost)); - - ScenarioExpect.Equal("PKCC001", diagnostic.Id); - } - - [Scenario("Reports diagnostic for missing claim check store factory")] + [Scenario("Generates claim check defaults and host shapes")] [Fact] - public void ReportsDiagnosticForMissingClaimCheckStoreFactory() - { - var source = """ + public Task Generates_Claim_Check_Defaults_And_Host_Shapes() + => Given("claim check declarations with default names and different host shapes", () => Compile(""" using PatternKit.Generators.Messaging; + using PatternKit.Messaging.Transformation; namespace Demo; - [GenerateClaimCheck(typeof(string))] - public static partial class Host; - """; + public sealed record LargeDocument(string Id, string Content); - var diagnostic = RunAndGetSingleDiagnostic(source, nameof(ReportsDiagnosticForMissingClaimCheckStoreFactory)); + [GenerateClaimCheck(typeof(LargeDocument))] + internal abstract partial class AbstractClaimCheck + { + [ClaimCheckStoreFactory] + private static IClaimCheckStore CreateStore() => new InMemoryClaimCheckStore(); + } - ScenarioExpect.Equal("PKCC002", diagnostic.Id); - } + [GenerateClaimCheck(typeof(LargeDocument), ClaimCheckName = "tenant\\\"claim", StoreName = "tenant\\\"store", ClaimIdPrefix = "tenant\\\"prefix")] + public sealed partial class SealedClaimCheck + { + [ClaimCheckStoreFactory] + private static IClaimCheckStore CreateStore() => new InMemoryClaimCheckStore(); + } - [Scenario("Reports diagnostic for invalid claim check store factory")] + [GenerateClaimCheck(typeof(LargeDocument))] + internal partial struct StructClaimCheck + { + [ClaimCheckStoreFactory] + private static IClaimCheckStore CreateStore() => new InMemoryClaimCheckStore(); + } + """)) + .Then("generated sources preserve host shape and configured names", result => + { + ScenarioExpect.Empty(result.Diagnostics); + ScenarioExpect.Equal(3, result.GeneratedSources.Count); + + var combined = string.Join("\n", result.GeneratedSources.Select(static source => source.Source)); + ScenarioExpect.Contains("internal abstract partial class AbstractClaimCheck", combined); + ScenarioExpect.Contains("public sealed partial class SealedClaimCheck", combined); + ScenarioExpect.Contains("internal partial struct StructClaimCheck", combined); + ScenarioExpect.Contains("Create(\"claim-check\")", combined); + ScenarioExpect.Contains(".InStore(\"claim-store\")", combined); + ScenarioExpect.Contains("\"claim:\" + (message.Headers.MessageId", combined); + ScenarioExpect.Contains("Create(\"tenant\\\\\\\"claim\")", combined); + ScenarioExpect.Contains(".InStore(\"tenant\\\\\\\"store\")", combined); + ScenarioExpect.Contains("\"tenant\\\\\\\"prefix:\" + (message.Headers.MessageId", combined); + ScenarioExpect.True(result.EmitSuccess, string.Join(Environment.NewLine, result.EmitDiagnostics)); + }) + .AssertPassed(); + + [Scenario("Generates nested claim check host wrappers")] [Fact] - public void ReportsDiagnosticForInvalidClaimCheckStoreFactory() - { - var source = """ + public Task Generates_Nested_Claim_Check_Host_Wrappers() + => Given("nested claim check declarations", () => Compile(""" using PatternKit.Generators.Messaging; + using PatternKit.Messaging.Transformation; namespace Demo; - [GenerateClaimCheck(typeof(string))] - public static partial class Host + public sealed record LargeDocument(string Id, string Content); + + public partial class ClaimCheckContainer { - [ClaimCheckStoreFactory] - private static string CreateStore() => ""; + private partial class PrivateHost + { + [GenerateClaimCheck(typeof(LargeDocument))] + protected partial class ProtectedClaimCheck + { + [ClaimCheckStoreFactory] + private static IClaimCheckStore CreateStore() => new InMemoryClaimCheckStore(); + } + + [GenerateClaimCheck(typeof(LargeDocument))] + private protected partial class PrivateProtectedClaimCheck + { + [ClaimCheckStoreFactory] + private static IClaimCheckStore CreateStore() => new InMemoryClaimCheckStore(); + } + + [GenerateClaimCheck(typeof(LargeDocument))] + protected internal partial class ProtectedInternalClaimCheck + { + [ClaimCheckStoreFactory] + private static IClaimCheckStore CreateStore() => new InMemoryClaimCheckStore(); + } + } } - """; + """)) + .Then("generated sources preserve containing partial type wrappers", result => + { + ScenarioExpect.Empty(result.Diagnostics); + ScenarioExpect.Equal(3, result.GeneratedSources.Count); + + var combined = string.Join("\n", result.GeneratedSources.Select(static source => source.Source)); + ScenarioExpect.Contains("public partial class ClaimCheckContainer", combined); + ScenarioExpect.Contains("private partial class PrivateHost", combined); + ScenarioExpect.Contains("protected partial class ProtectedClaimCheck", combined); + ScenarioExpect.Contains("private protected partial class PrivateProtectedClaimCheck", combined); + ScenarioExpect.Contains("protected internal partial class ProtectedInternalClaimCheck", combined); + ScenarioExpect.True(result.EmitSuccess, string.Join(Environment.NewLine, result.EmitDiagnostics)); + }) + .AssertPassed(); + + [Scenario("Skips malformed claim check type arguments")] + [Fact] + public Task Skips_Malformed_Claim_Check_Type_Arguments() + => Given("a claim check declaration with a null type argument", () => Compile(""" + using PatternKit.Generators.Messaging; + using PatternKit.Messaging.Transformation; - var diagnostic = RunAndGetSingleDiagnostic(source, nameof(ReportsDiagnosticForInvalidClaimCheckStoreFactory)); + [GenerateClaimCheck(null!)] + public static partial class DocumentClaimCheck + { + [ClaimCheckStoreFactory] + private static IClaimCheckStore CreateStore() => new InMemoryClaimCheckStore(); + } + """)) + .Then("no source is generated", result => + ScenarioExpect.Empty(result.GeneratedSources)) + .AssertPassed(); - ScenarioExpect.Equal("PKCC003", diagnostic.Id); + private static GeneratorResult Compile(string source) + { + var compilation = CreateCompilation(source, "ClaimCheckGeneratorTests"); + _ = RoslynTestHelpers.Run(compilation, new ClaimCheckGenerator(), out var run, out var updated); + var result = run.Results.Single(); + var emit = updated.Emit(Stream.Null); + return new GeneratorResult( + result.Diagnostics.ToArray(), + result.GeneratedSources + .Select(static source => new GeneratedSource(source.HintName, source.SourceText.ToString())) + .ToArray(), + emit.Success, + emit.Diagnostics.Select(static diagnostic => diagnostic.ToString()).ToArray()); } private static CSharpCompilation CreateCompilation(string source, string assemblyName) @@ -117,11 +219,11 @@ private static string GetAbstractionsAssemblyPath() Path.GetDirectoryName(typeof(ClaimCheckGenerator).Assembly.Location)!, "PatternKit.Generators.Abstractions.dll"); - private static Diagnostic RunAndGetSingleDiagnostic(string source, string assemblyName) - { - var comp = CreateCompilation(source, assemblyName); - var gen = new ClaimCheckGenerator(); - _ = RoslynTestHelpers.Run(comp, gen, out var run, out _); - return ScenarioExpect.Single(run.Results.SelectMany(result => result.Diagnostics)); - } + private sealed record GeneratorResult( + IReadOnlyList Diagnostics, + IReadOnlyList GeneratedSources, + bool EmitSuccess, + IReadOnlyList EmitDiagnostics); + + private sealed record GeneratedSource(string HintName, string Source); } diff --git a/test/PatternKit.Generators.Tests/TransactionScriptGeneratorTests.cs b/test/PatternKit.Generators.Tests/TransactionScriptGeneratorTests.cs index 4506dbe1..e4037f98 100644 --- a/test/PatternKit.Generators.Tests/TransactionScriptGeneratorTests.cs +++ b/test/PatternKit.Generators.Tests/TransactionScriptGeneratorTests.cs @@ -39,6 +39,7 @@ public static partial class SubmitOrderScript ScenarioExpect.Contains("Create(\"submit-order\")", source); ScenarioExpect.Contains(".Validate(Validate)", source); ScenarioExpect.Contains(".Execute(Handle)", source); + ScenarioExpect.True(result.EmitSuccess, string.Join(Environment.NewLine, result.EmitDiagnostics)); }) .AssertPassed(); @@ -47,7 +48,15 @@ public static partial class SubmitOrderScript [InlineData("public static class SubmitOrderScript { [TransactionScriptHandler] private static ValueTask Handle(SubmitOrder request, CancellationToken cancellationToken) => new(new OrderReceipt(request.OrderId)); }", "PKTS001")] [InlineData("public static partial class SubmitOrderScript;", "PKTS002")] [InlineData("public static partial class SubmitOrderScript { [TransactionScriptHandler] private static ValueTask One(SubmitOrder request, CancellationToken cancellationToken) => new(new OrderReceipt(request.OrderId)); [TransactionScriptHandler] private static ValueTask Two(SubmitOrder request, CancellationToken cancellationToken) => new(new OrderReceipt(request.OrderId)); }", "PKTS002")] - [InlineData("public static partial class SubmitOrderScript { [TransactionScriptHandler] private static OrderReceipt Handle(SubmitOrder request) => new(request.OrderId); }", "PKTS003")] + [InlineData("public partial class SubmitOrderScript { [TransactionScriptHandler] private ValueTask Handle(SubmitOrder request, CancellationToken cancellationToken) => new(new OrderReceipt(request.OrderId)); }", "PKTS003")] + [InlineData("public static partial class SubmitOrderScript { [TransactionScriptHandler] private static ValueTask Handle(SubmitOrder request, CancellationToken cancellationToken) => new(new OrderReceipt(request.OrderId)); }", "PKTS003")] + [InlineData("public static partial class SubmitOrderScript { [TransactionScriptHandler] private static OrderReceipt Handle(SubmitOrder request, CancellationToken cancellationToken) => new(request.OrderId); }", "PKTS003")] + [InlineData("public static partial class SubmitOrderScript { [TransactionScriptHandler] private static ValueTask Handle() => new(new OrderReceipt(string.Empty)); }", "PKTS003")] + [InlineData("public static partial class SubmitOrderScript { [TransactionScriptHandler] private static ValueTask Handle(string request, CancellationToken cancellationToken) => new(new OrderReceipt(request)); }", "PKTS003")] + [InlineData("public static partial class SubmitOrderScript { [TransactionScriptHandler] private static ValueTask Handle(SubmitOrder request, string cancellationToken) => new(new OrderReceipt(request.OrderId)); }", "PKTS003")] + [InlineData("public static partial class SubmitOrderScript { [TransactionScriptValidator] private static IEnumerable One(SubmitOrder request) => []; [TransactionScriptValidator] private static IEnumerable Two(SubmitOrder request) => []; [TransactionScriptHandler] private static ValueTask Handle(SubmitOrder request, CancellationToken cancellationToken) => new(new OrderReceipt(request.OrderId)); }", "PKTS004")] + [InlineData("public static partial class SubmitOrderScript { [TransactionScriptValidator] private static IEnumerable Validate(SubmitOrder request) => []; [TransactionScriptHandler] private static ValueTask Handle(SubmitOrder request, CancellationToken cancellationToken) => new(new OrderReceipt(request.OrderId)); }", "PKTS004")] + [InlineData("public static partial class SubmitOrderScript { [TransactionScriptValidator] private static IEnumerable Validate(string request) => []; [TransactionScriptHandler] private static ValueTask Handle(SubmitOrder request, CancellationToken cancellationToken) => new(new OrderReceipt(request.OrderId)); }", "PKTS004")] [InlineData("public static partial class SubmitOrderScript { [TransactionScriptValidator] private static string Validate(SubmitOrder request) => \"invalid\"; [TransactionScriptHandler] private static ValueTask Handle(SubmitOrder request, CancellationToken cancellationToken) => new(new OrderReceipt(request.OrderId)); }", "PKTS004")] public Task Generator_Reports_Invalid_Transaction_Script_Declarations(string declaration, string diagnosticId) => Given("an invalid transaction script declaration", () => Compile($$""" @@ -65,18 +74,147 @@ public sealed record OrderReceipt(string OrderId); ScenarioExpect.Contains(result.Diagnostics, diagnostic => diagnostic.Id == diagnosticId)) .AssertPassed(); + [Scenario("Generator emits transaction script defaults and host shapes")] + [Fact] + public Task Generator_Emits_Transaction_Script_Defaults_And_Host_Shapes() + => Given("transaction script declarations with default names and different host shapes", () => Compile(""" + using System.Threading; + using System.Threading.Tasks; + using PatternKit.Generators.TransactionScript; + namespace Demo; + public sealed record SubmitOrder(string OrderId); + public sealed record OrderReceipt(string OrderId); + + [GenerateTransactionScript(typeof(SubmitOrder), typeof(OrderReceipt))] + internal abstract partial class AbstractScript + { + [TransactionScriptHandler] + private static ValueTask Handle(SubmitOrder request, CancellationToken cancellationToken) => new(new OrderReceipt(request.OrderId)); + } + + [GenerateTransactionScript(typeof(SubmitOrder), typeof(OrderReceipt), ScriptName = "tenant\\\"script")] + public sealed partial class SealedScript + { + [TransactionScriptHandler] + private static ValueTask Handle(SubmitOrder request, CancellationToken cancellationToken) => new(new OrderReceipt(request.OrderId)); + } + + [GenerateTransactionScript(typeof(SubmitOrder), typeof(OrderReceipt))] + internal partial struct StructScript + { + [TransactionScriptHandler] + private static ValueTask Handle(SubmitOrder request, CancellationToken cancellationToken) => new(new OrderReceipt(request.OrderId)); + } + """)) + .Then("generated sources preserve host shape and configured names", result => + { + ScenarioExpect.Empty(result.Diagnostics); + ScenarioExpect.Equal(3, result.GeneratedSources.Count); + + var combined = string.Join("\n", result.GeneratedSources); + ScenarioExpect.Contains("internal abstract partial class AbstractScript", combined); + ScenarioExpect.Contains("public sealed partial class SealedScript", combined); + ScenarioExpect.Contains("internal partial struct StructScript", combined); + ScenarioExpect.Contains("Create(\"AbstractScript\")", combined); + ScenarioExpect.Contains("Create(\"tenant\\\\\\\"script\")", combined); + ScenarioExpect.True(result.EmitSuccess, string.Join(Environment.NewLine, result.EmitDiagnostics)); + }) + .AssertPassed(); + + [Scenario("Generator emits nested transaction script host wrappers")] + [Fact] + public Task Generator_Emits_Nested_Transaction_Script_Host_Wrappers() + => Given("nested transaction script declarations", () => Compile(""" + using System.Threading; + using System.Threading.Tasks; + using PatternKit.Generators.TransactionScript; + namespace Demo; + public sealed record SubmitOrder(string OrderId); + public sealed record OrderReceipt(string OrderId); + + public partial class ScriptContainer + { + private partial class PrivateHost + { + [GenerateTransactionScript(typeof(SubmitOrder), typeof(OrderReceipt))] + protected partial class ProtectedScript + { + [TransactionScriptHandler] + private static ValueTask Handle(SubmitOrder request, CancellationToken cancellationToken) => new(new OrderReceipt(request.OrderId)); + } + + [GenerateTransactionScript(typeof(SubmitOrder), typeof(OrderReceipt))] + private protected partial class PrivateProtectedScript + { + [TransactionScriptHandler] + private static ValueTask Handle(SubmitOrder request, CancellationToken cancellationToken) => new(new OrderReceipt(request.OrderId)); + } + + [GenerateTransactionScript(typeof(SubmitOrder), typeof(OrderReceipt))] + protected internal partial class ProtectedInternalScript + { + [TransactionScriptHandler] + private static ValueTask Handle(SubmitOrder request, CancellationToken cancellationToken) => new(new OrderReceipt(request.OrderId)); + } + } + } + """)) + .Then("generated sources preserve containing partial type wrappers", result => + { + ScenarioExpect.Empty(result.Diagnostics); + ScenarioExpect.Equal(3, result.GeneratedSources.Count); + + var combined = string.Join("\n", result.GeneratedSources); + ScenarioExpect.Contains("public partial class ScriptContainer", combined); + ScenarioExpect.Contains("private partial class PrivateHost", combined); + ScenarioExpect.Contains("protected partial class ProtectedScript", combined); + ScenarioExpect.Contains("private protected partial class PrivateProtectedScript", combined); + ScenarioExpect.Contains("protected internal partial class ProtectedInternalScript", combined); + ScenarioExpect.True(result.EmitSuccess, string.Join(Environment.NewLine, result.EmitDiagnostics)); + }) + .AssertPassed(); + + [Scenario("Generator skips malformed transaction script type arguments")] + [Theory] + [InlineData("null!", "typeof(OrderReceipt)")] + [InlineData("typeof(SubmitOrder)", "null!")] + public Task Generator_Skips_Malformed_Transaction_Script_Type_Arguments(string requestType, string responseType) + => Given("a transaction script declaration with a null type argument", () => Compile($$""" + using System.Threading; + using System.Threading.Tasks; + using PatternKit.Generators.TransactionScript; + public sealed record SubmitOrder(string OrderId); + public sealed record OrderReceipt(string OrderId); + [GenerateTransactionScript({{requestType}}, {{responseType}})] + public static partial class SubmitOrderScript + { + [TransactionScriptHandler] + private static ValueTask Handle(SubmitOrder request, CancellationToken cancellationToken) => new(new OrderReceipt(request.OrderId)); + } + """)) + .Then("no source is generated", result => + ScenarioExpect.Empty(result.GeneratedSources)) + .AssertPassed(); + private static GeneratorResult Compile(string source) { var compilation = RoslynTestHelpers.CreateCompilation( source, "TransactionScriptGeneratorTests", extra: MetadataReference.CreateFromFile(typeof(TransactionScript<,>).Assembly.Location)); - _ = RoslynTestHelpers.Run(compilation, new TransactionScriptGenerator(), out var run, out _); + _ = RoslynTestHelpers.Run(compilation, new TransactionScriptGenerator(), out var run, out var updated); var result = run.Results.Single(); + var emit = updated.Emit(Stream.Null); return new GeneratorResult( result.Diagnostics.ToArray(), - result.GeneratedSources.Select(static source => source.SourceText.ToString()).ToArray()); + result.GeneratedSources.Select(static source => source.SourceText.ToString()).ToArray(), + emit.Success, + emit.Diagnostics.Select(static diagnostic => diagnostic.ToString()).ToArray()); } - private sealed record GeneratorResult(IReadOnlyList Diagnostics, IReadOnlyList GeneratedSources); + private sealed record GeneratorResult( + IReadOnlyList Diagnostics, + IReadOnlyList GeneratedSources, + bool EmitSuccess, + IReadOnlyList EmitDiagnostics); }