From 98183b66e1c093ab5f7978ac40338e6661b99ad3 Mon Sep 17 00:00:00 2001 From: JerrettDavis Date: Fri, 29 May 2026 20:18:28 -0500 Subject: [PATCH] test(generators): cover feature toggles and gateway aggregation --- .../FeatureToggleSetGenerator.cs | 64 ++++-- .../GatewayAggregationGenerator.cs | 65 ++++-- .../FeatureToggleSetGeneratorTests.cs | 114 +++++++++++ .../GatewayAggregationGeneratorTests.cs | 190 ++++++++++++++---- 4 files changed, 363 insertions(+), 70 deletions(-) diff --git a/src/PatternKit.Generators/FeatureToggles/FeatureToggleSetGenerator.cs b/src/PatternKit.Generators/FeatureToggles/FeatureToggleSetGenerator.cs index cbe55d11..33cf1ffb 100644 --- a/src/PatternKit.Generators/FeatureToggles/FeatureToggleSetGenerator.cs +++ b/src/PatternKit.Generators/FeatureToggles/FeatureToggleSetGenerator.cs @@ -123,6 +123,55 @@ private static string GenerateSource( sb.AppendLine(); } + var containingTypes = GetContainingTypes(type); + var indentLevel = 0; + foreach (var containingType in containingTypes) + { + AppendTypeDeclaration(sb, containingType, indentLevel); + sb.AppendLine(); + sb.AppendLine(new string(' ', indentLevel * 4) + "{"); + indentLevel++; + } + + AppendTypeDeclaration(sb, type, indentLevel); + sb.AppendLine(); + var indent = new string(' ', indentLevel * 4); + sb.AppendLine(indent + "{"); + var memberIndent = indent + " "; + var bodyIndent = memberIndent + " "; + sb.Append(memberIndent).Append("public static global::PatternKit.Application.FeatureToggles.FeatureToggleSet<").Append(contextTypeName).Append("> ").Append(factoryName).AppendLine("()"); + sb.AppendLine(memberIndent + "{"); + sb.Append(bodyIndent).Append("return global::PatternKit.Application.FeatureToggles.FeatureToggleSet<").Append(contextTypeName).Append(">.Create(\"").Append(Escape(setName)).AppendLine("\")"); + foreach (var rule in rules) + { + sb.Append(bodyIndent).Append(" .AddRule(\"").Append(Escape(rule.Name)).Append("\", ").Append(rule.DefaultEnabled ? "true" : "false").Append(", ").Append(rule.Method.Name).AppendLine(")"); + } + + sb.Append(bodyIndent).AppendLine(" .Build();"); + sb.AppendLine(memberIndent + "}"); + sb.AppendLine(indent + "}"); + for (var i = containingTypes.Length - 1; i >= 0; i--) + { + sb.AppendLine(new string(' ', i * 4) + "}"); + } + + return sb.ToString(); + } + + private static INamedTypeSymbol[] GetContainingTypes(INamedTypeSymbol type) + { + var containingTypes = new Stack(); + for (var current = type.ContainingType; current is not null; current = current.ContainingType) + { + containingTypes.Push(current); + } + + return containingTypes.ToArray(); + } + + private static void AppendTypeDeclaration(StringBuilder sb, INamedTypeSymbol type, int indentLevel) + { + sb.Append(new string(' ', indentLevel * 4)); sb.Append(GetAccessibility(type.DeclaredAccessibility)).Append(' '); if (type.IsStatic) sb.Append("static "); @@ -130,20 +179,7 @@ private static string GenerateSource( sb.Append("abstract "); else if (type.IsSealed && type.TypeKind == TypeKind.Class) sb.Append("sealed "); - sb.Append("partial ").Append(type.TypeKind == TypeKind.Struct ? "struct" : "class").Append(' ').Append(type.Name).AppendLine(); - sb.AppendLine("{"); - sb.Append(" public static global::PatternKit.Application.FeatureToggles.FeatureToggleSet<").Append(contextTypeName).Append("> ").Append(factoryName).AppendLine("()"); - sb.AppendLine(" {"); - sb.Append(" return global::PatternKit.Application.FeatureToggles.FeatureToggleSet<").Append(contextTypeName).Append(">.Create(\"").Append(Escape(setName)).AppendLine("\")"); - foreach (var rule in rules) - { - sb.Append(" .AddRule(\"").Append(Escape(rule.Name)).Append("\", ").Append(rule.DefaultEnabled ? "true" : "false").Append(", ").Append(rule.Method.Name).AppendLine(")"); - } - - sb.AppendLine(" .Build();"); - sb.AppendLine(" }"); - sb.AppendLine("}"); - return sb.ToString(); + sb.Append("partial ").Append(type.TypeKind == TypeKind.Struct ? "struct" : "class").Append(' ').Append(type.Name); } private static string Escape(string value) => value.Replace("\\", "\\\\").Replace("\"", "\\\""); diff --git a/src/PatternKit.Generators/GatewayAggregation/GatewayAggregationGenerator.cs b/src/PatternKit.Generators/GatewayAggregation/GatewayAggregationGenerator.cs index ac66c6b7..dabea959 100644 --- a/src/PatternKit.Generators/GatewayAggregation/GatewayAggregationGenerator.cs +++ b/src/PatternKit.Generators/GatewayAggregation/GatewayAggregationGenerator.cs @@ -1,3 +1,4 @@ +using System.Collections.Generic; using System.Linq; using System.Text; using Microsoft.CodeAnalysis; @@ -158,6 +159,55 @@ private static string GenerateSource( sb.AppendLine(); } + var containingTypes = GetContainingTypes(type); + var indentLevel = 0; + foreach (var containingType in containingTypes) + { + AppendTypeDeclaration(sb, containingType, indentLevel); + sb.AppendLine(); + sb.AppendLine(new string(' ', indentLevel * 4) + "{"); + indentLevel++; + } + + AppendTypeDeclaration(sb, type, indentLevel); + sb.AppendLine(); + var indent = new string(' ', indentLevel * 4); + sb.AppendLine(indent + "{"); + var memberIndent = indent + " "; + var bodyIndent = memberIndent + " "; + sb.Append(memberIndent).Append("public static global::PatternKit.Cloud.GatewayAggregation.GatewayAggregation<") + .Append(requestTypeName).Append(", ").Append(responseTypeName).Append("> ").Append(factoryMethodName).AppendLine("()"); + sb.AppendLine(memberIndent + "{"); + sb.Append(bodyIndent).Append("return global::PatternKit.Cloud.GatewayAggregation.GatewayAggregation<") + .Append(requestTypeName).Append(", ").Append(responseTypeName).Append(">.Create(\"").Append(Escape(gatewayName)).AppendLine("\")"); + foreach (var fetch in fetches) + sb.Append(bodyIndent).Append(" .Fetch<").Append(fetch.Method.ReturnType.ToDisplayString(TypeFormat)).Append(">(\"").Append(Escape(fetch.Name)).Append("\", ").Append(fetch.Method.Name).AppendLine(")"); + sb.Append(bodyIndent).Append(" .Compose(").Append(composerName).AppendLine(")"); + sb.Append(bodyIndent).AppendLine(" .Build();"); + sb.AppendLine(memberIndent + "}"); + sb.AppendLine(indent + "}"); + for (var i = containingTypes.Length - 1; i >= 0; i--) + { + sb.AppendLine(new string(' ', i * 4) + "}"); + } + + return sb.ToString(); + } + + private static INamedTypeSymbol[] GetContainingTypes(INamedTypeSymbol type) + { + var containingTypes = new Stack(); + for (var current = type.ContainingType; current is not null; current = current.ContainingType) + { + containingTypes.Push(current); + } + + return containingTypes.ToArray(); + } + + private static void AppendTypeDeclaration(StringBuilder sb, INamedTypeSymbol type, int indentLevel) + { + sb.Append(new string(' ', indentLevel * 4)); sb.Append(GetAccessibility(type.DeclaredAccessibility)).Append(' '); if (type.IsStatic) sb.Append("static "); @@ -165,20 +215,7 @@ private static string GenerateSource( sb.Append("abstract "); else if (type.IsSealed && type.TypeKind == TypeKind.Class) sb.Append("sealed "); - sb.Append("partial ").Append(type.TypeKind == TypeKind.Struct ? "struct" : "class").Append(' ').Append(type.Name).AppendLine(); - sb.AppendLine("{"); - sb.Append(" public static global::PatternKit.Cloud.GatewayAggregation.GatewayAggregation<") - .Append(requestTypeName).Append(", ").Append(responseTypeName).Append("> ").Append(factoryMethodName).AppendLine("()"); - sb.AppendLine(" {"); - sb.Append(" return global::PatternKit.Cloud.GatewayAggregation.GatewayAggregation<") - .Append(requestTypeName).Append(", ").Append(responseTypeName).Append(">.Create(\"").Append(Escape(gatewayName)).AppendLine("\")"); - foreach (var fetch in fetches) - sb.Append(" .Fetch<").Append(fetch.Method.ReturnType.ToDisplayString(TypeFormat)).Append(">(\"").Append(Escape(fetch.Name)).Append("\", ").Append(fetch.Method.Name).AppendLine(")"); - sb.Append(" .Compose(").Append(composerName).AppendLine(")"); - sb.AppendLine(" .Build();"); - sb.AppendLine(" }"); - sb.AppendLine("}"); - return sb.ToString(); + sb.Append("partial ").Append(type.TypeKind == TypeKind.Struct ? "struct" : "class").Append(' ').Append(type.Name); } private static string? GetNamedString(AttributeData attribute, string name) diff --git a/test/PatternKit.Generators.Tests/FeatureToggleSetGeneratorTests.cs b/test/PatternKit.Generators.Tests/FeatureToggleSetGeneratorTests.cs index e5d80ca4..6a446823 100644 --- a/test/PatternKit.Generators.Tests/FeatureToggleSetGeneratorTests.cs +++ b/test/PatternKit.Generators.Tests/FeatureToggleSetGeneratorTests.cs @@ -43,6 +43,8 @@ public static partial class CheckoutToggles [Theory] [InlineData("public static class CheckoutToggles { [FeatureToggleRule(\"x\")] private static bool IsEnabled(CheckoutContext context) => true; }", "PKFT001")] [InlineData("public static partial class CheckoutToggles;", "PKFT002")] + [InlineData("public static partial class CheckoutToggles { [FeatureToggleRule(\"x\")] private static bool IsEnabled() => true; }", "PKFT003")] + [InlineData("public static partial class CheckoutToggles { [FeatureToggleRule(\"x\")] private static bool IsEnabled(string context) => true; }", "PKFT003")] [InlineData("public static partial class CheckoutToggles { [FeatureToggleRule(\"x\")] private static string IsEnabled(CheckoutContext context) => \"yes\"; }", "PKFT003")] [InlineData("public static partial class CheckoutToggles { [FeatureToggleRule(\"x\")] private bool IsEnabled(CheckoutContext context) => true; }", "PKFT003")] public Task Generator_Reports_Invalid_Feature_Toggle_Declarations(string declaration, string diagnosticId) @@ -56,6 +58,118 @@ public sealed record CheckoutContext(string Tenant, decimal Total); ScenarioExpect.Contains(result.Diagnostics, diagnostic => diagnostic.Id == diagnosticId)) .AssertPassed(); + [Scenario("Generator emits feature toggle defaults and host shapes")] + [Fact] + public Task Generator_Emits_Feature_Toggle_Defaults_And_Host_Shapes() + => Given("feature toggle declarations with default names and host shapes", () => Compile(""" + using PatternKit.Generators.FeatureToggles; + namespace Demo; + public sealed record CheckoutContext(string Tenant, decimal Total); + + [GenerateFeatureToggleSet(typeof(CheckoutContext))] + internal abstract partial class AbstractToggles + { + [FeatureToggleRule("enabled")] + private static bool IsEnabled(CheckoutContext context) => true; + } + + [GenerateFeatureToggleSet(typeof(CheckoutContext), SetName = "tenant\\\"toggles")] + public sealed partial class SealedToggles + { + [FeatureToggleRule("beta", DefaultEnabled = true)] + private static bool Beta(CheckoutContext context) => context.Tenant == "beta"; + } + + [GenerateFeatureToggleSet(typeof(CheckoutContext))] + internal partial struct StructToggles + { + [FeatureToggleRule("large-order")] + private static bool LargeOrder(CheckoutContext context) => context.Total >= 500m; + } + """)) + .Then("generated sources preserve host shape and configured defaults", result => + { + ScenarioExpect.Empty(result.Diagnostics); + ScenarioExpect.Equal(3, result.GeneratedSources.Count); + + var combined = string.Join("\n", result.GeneratedSources); + ScenarioExpect.Contains("internal abstract partial class AbstractToggles", combined); + ScenarioExpect.Contains("public sealed partial class SealedToggles", combined); + ScenarioExpect.Contains("internal partial struct StructToggles", combined); + ScenarioExpect.Contains("Create(\"feature-toggles\")", combined); + ScenarioExpect.Contains("Create(\"tenant\\\\\\\"toggles\")", combined); + ScenarioExpect.Contains(".AddRule(\"enabled\", false, IsEnabled)", combined); + ScenarioExpect.Contains(".AddRule(\"beta\", true, Beta)", combined); + ScenarioExpect.True(result.EmitSuccess, result.EmitDiagnostics); + }) + .AssertPassed(); + + [Scenario("Generator emits nested feature toggle host wrappers")] + [Fact] + public Task Generator_Emits_Nested_Feature_Toggle_Host_Wrappers() + => Given("nested feature toggle declarations", () => Compile(""" + using PatternKit.Generators.FeatureToggles; + namespace Demo; + public sealed record CheckoutContext(string Tenant, decimal Total); + + public partial class ToggleContainer + { + private partial class PrivateHost + { + [GenerateFeatureToggleSet(typeof(CheckoutContext))] + protected partial class ProtectedToggles + { + [FeatureToggleRule("protected")] + private static bool Protected(CheckoutContext context) => true; + } + + [GenerateFeatureToggleSet(typeof(CheckoutContext))] + private protected partial class PrivateProtectedToggles + { + [FeatureToggleRule("private-protected")] + private static bool PrivateProtected(CheckoutContext context) => true; + } + + [GenerateFeatureToggleSet(typeof(CheckoutContext))] + protected internal partial class ProtectedInternalToggles + { + [FeatureToggleRule("protected-internal")] + private static bool ProtectedInternal(CheckoutContext context) => true; + } + } + } + """)) + .Then("generated sources preserve containing partial type wrappers", result => + { + ScenarioExpect.Empty(result.Diagnostics); + ScenarioExpect.Equal(3, result.GeneratedSources.Count); + + var combined = string.Join("\n", result.GeneratedSources); + ScenarioExpect.Contains("public partial class ToggleContainer", combined); + ScenarioExpect.Contains("private partial class PrivateHost", combined); + ScenarioExpect.Contains("protected partial class ProtectedToggles", combined); + ScenarioExpect.Contains("private protected partial class PrivateProtectedToggles", combined); + ScenarioExpect.Contains("protected internal partial class ProtectedInternalToggles", combined); + ScenarioExpect.True(result.EmitSuccess, result.EmitDiagnostics); + }) + .AssertPassed(); + + [Scenario("Generator skips malformed feature toggle context type")] + [Fact] + public Task Generator_Skips_Malformed_Feature_Toggle_Context_Type() + => Given("a feature toggle declaration with a null context type", () => Compile(""" + using PatternKit.Generators.FeatureToggles; + [GenerateFeatureToggleSet(null!)] + public static partial class CheckoutToggles + { + [FeatureToggleRule("x")] + private static bool IsEnabled(string context) => true; + } + """)) + .Then("no source is generated", result => + ScenarioExpect.Empty(result.GeneratedSources)) + .AssertPassed(); + private static GeneratorResult Compile(string source) { var compilation = RoslynTestHelpers.CreateCompilation( diff --git a/test/PatternKit.Generators.Tests/GatewayAggregationGeneratorTests.cs b/test/PatternKit.Generators.Tests/GatewayAggregationGeneratorTests.cs index d24bf727..dcaa1252 100644 --- a/test/PatternKit.Generators.Tests/GatewayAggregationGeneratorTests.cs +++ b/test/PatternKit.Generators.Tests/GatewayAggregationGeneratorTests.cs @@ -46,56 +46,162 @@ public static partial class CustomerDashboardGateway .AssertPassed(); [Scenario("Reports diagnostics for invalid gateway aggregation declarations")] + [Theory] + [InlineData("public static class GatewayHost;", "PKGA001")] + [InlineData("public static partial class GatewayHost;", "PKGA002")] + [InlineData("public static partial class GatewayHost { [GatewayAggregationFetch(\"profile\")] private static void Profile(DashboardRequest request) { } [GatewayAggregationComposer] private static DashboardResponse Compose(GatewayAggregationContext ctx) => new(ctx.Request.CustomerId); }", "PKGA003")] + [InlineData("public static partial class GatewayHost { [GatewayAggregationFetch(\"profile\")] private DashboardRequest Profile(DashboardRequest request) => request; [GatewayAggregationComposer] private static DashboardResponse Compose(GatewayAggregationContext ctx) => new(ctx.Request.CustomerId); }", "PKGA003")] + [InlineData("public static partial class GatewayHost { [GatewayAggregationFetch(\"profile\")] private static DashboardRequest Profile(string request) => new(request); [GatewayAggregationComposer] private static DashboardResponse Compose(GatewayAggregationContext ctx) => new(ctx.Request.CustomerId); }", "PKGA003")] + [InlineData("public static partial class GatewayHost { [GatewayAggregationFetch(\"profile\")] private static DashboardRequest Profile(DashboardRequest request) => request; [GatewayAggregationComposer] private static string Compose(GatewayAggregationContext ctx) => ctx.Request.CustomerId; }", "PKGA003")] + [InlineData("public static partial class GatewayHost { [GatewayAggregationFetch(\"profile\")] private static DashboardRequest Profile(DashboardRequest request) => request; [GatewayAggregationComposer] private DashboardResponse Compose(GatewayAggregationContext ctx) => new(ctx.Request.CustomerId); }", "PKGA003")] + [InlineData("public static partial class GatewayHost { [GatewayAggregationFetch(\"profile\")] private static DashboardRequest Profile(DashboardRequest request) => request; [GatewayAggregationComposer] private static DashboardResponse Compose(string ctx) => new(ctx); }", "PKGA003")] + [InlineData("public static partial class GatewayHost { [GatewayAggregationFetch(\"profile\")] private static DashboardRequest Profile(DashboardRequest request) => request; [GatewayAggregationFetch(\"PROFILE\")] private static DashboardRequest Profile2(DashboardRequest request) => request; [GatewayAggregationComposer] private static DashboardResponse Compose(GatewayAggregationContext ctx) => new(ctx.Request.CustomerId); }", "PKGA004")] + public Task Reports_Diagnostics_For_Invalid_Gateway_Aggregation_Declarations(string declaration, string diagnosticId) + => Given("an invalid gateway aggregation declaration", () => Compile($$""" + using PatternKit.Cloud.GatewayAggregation; + using PatternKit.Generators.GatewayAggregation; + public sealed record DashboardRequest(string CustomerId); + public sealed record DashboardResponse(string CustomerId); + [GenerateGatewayAggregation(typeof(DashboardRequest), typeof(DashboardResponse))] + {{declaration}} + """)) + .Then("the expected diagnostic is reported", result => + ScenarioExpect.Contains(result.Diagnostics, diagnostic => diagnostic.Id == diagnosticId)) + .AssertPassed(); + + [Scenario("Generates gateway aggregation defaults and host shapes")] [Fact] - public Task Reports_Diagnostics_For_Invalid_Gateway_Aggregation_Declarations() - => Given("invalid gateway aggregation declarations", () => new[] + public Task Generates_Gateway_Aggregation_Defaults_And_Host_Shapes() + => Given("gateway aggregation declarations with default names and host shapes", () => Compile(""" + using PatternKit.Cloud.GatewayAggregation; + using PatternKit.Generators.GatewayAggregation; + namespace Demo; + public sealed record DashboardRequest(string CustomerId); + public sealed record DashboardResponse(string CustomerId); + + [GenerateGatewayAggregation(typeof(DashboardRequest), typeof(DashboardResponse))] + internal abstract partial class AbstractGateway + { + [GatewayAggregationFetch("profile")] + private static DashboardRequest Profile(DashboardRequest request) => request; + [GatewayAggregationComposer] + private static DashboardResponse Compose(GatewayAggregationContext ctx) => new(ctx.Request.CustomerId); + } + + [GenerateGatewayAggregation(typeof(DashboardRequest), typeof(DashboardResponse), GatewayName = "tenant\\\"gateway")] + public sealed partial class SealedGateway + { + [GatewayAggregationFetch("profile")] + private static DashboardRequest Profile(DashboardRequest request) => request; + [GatewayAggregationComposer] + private static DashboardResponse Compose(GatewayAggregationContext ctx) => new(ctx.Request.CustomerId); + } + + [GenerateGatewayAggregation(typeof(DashboardRequest), typeof(DashboardResponse))] + internal partial struct StructGateway + { + [GatewayAggregationFetch("profile")] + private static DashboardRequest Profile(DashboardRequest request) => request; + [GatewayAggregationComposer] + private static DashboardResponse Compose(GatewayAggregationContext ctx) => new(ctx.Request.CustomerId); + } + """)) + .Then("generated sources preserve host shape and configured defaults", result => { - Compile(""" - using PatternKit.Generators.GatewayAggregation; - [GenerateGatewayAggregation(typeof(string), typeof(int))] - public static class GatewayHost; - """), - Compile(""" - using PatternKit.Generators.GatewayAggregation; - [GenerateGatewayAggregation(typeof(string), typeof(int))] - public static partial class GatewayHost; - """), - Compile(""" - using PatternKit.Cloud.GatewayAggregation; - using PatternKit.Generators.GatewayAggregation; - [GenerateGatewayAggregation(typeof(string), typeof(int))] - public static partial class GatewayHost - { - [GatewayAggregationFetch("profile")] - private static void Profile(string value) { } - [GatewayAggregationComposer] - private static int Compose(GatewayAggregationContext ctx) => 1; - } - """), - Compile(""" - using PatternKit.Cloud.GatewayAggregation; - using PatternKit.Generators.GatewayAggregation; - [GenerateGatewayAggregation(typeof(string), typeof(int))] - public static partial class GatewayHost + ScenarioExpect.Empty(result.Diagnostics); + ScenarioExpect.Equal(3, result.GeneratedSources.Count); + + var combined = string.Join("\n", result.GeneratedSources); + ScenarioExpect.Contains("internal abstract partial class AbstractGateway", combined); + ScenarioExpect.Contains("public sealed partial class SealedGateway", combined); + ScenarioExpect.Contains("internal partial struct StructGateway", combined); + ScenarioExpect.Contains("Create(\"gateway-aggregation\")", combined); + ScenarioExpect.Contains("Create(\"tenant\\\\\\\"gateway\")", combined); + ScenarioExpect.True(result.EmitSuccess, string.Join(Environment.NewLine, result.EmitDiagnostics)); + }) + .AssertPassed(); + + [Scenario("Generates nested gateway aggregation host wrappers")] + [Fact] + public Task Generates_Nested_Gateway_Aggregation_Host_Wrappers() + => Given("nested gateway aggregation declarations", () => Compile(""" + using PatternKit.Cloud.GatewayAggregation; + using PatternKit.Generators.GatewayAggregation; + namespace Demo; + public sealed record DashboardRequest(string CustomerId); + public sealed record DashboardResponse(string CustomerId); + + public partial class GatewayContainer + { + private partial class PrivateHost { - [GatewayAggregationFetch("profile")] - private static string Profile(string value) => value; - [GatewayAggregationFetch("PROFILE")] - private static string Profile2(string value) => value; - [GatewayAggregationComposer] - private static int Compose(GatewayAggregationContext ctx) => 1; + [GenerateGatewayAggregation(typeof(DashboardRequest), typeof(DashboardResponse))] + protected partial class ProtectedGateway + { + [GatewayAggregationFetch("profile")] + private static DashboardRequest Profile(DashboardRequest request) => request; + [GatewayAggregationComposer] + private static DashboardResponse Compose(GatewayAggregationContext ctx) => new(ctx.Request.CustomerId); + } + + [GenerateGatewayAggregation(typeof(DashboardRequest), typeof(DashboardResponse))] + private protected partial class PrivateProtectedGateway + { + [GatewayAggregationFetch("profile")] + private static DashboardRequest Profile(DashboardRequest request) => request; + [GatewayAggregationComposer] + private static DashboardResponse Compose(GatewayAggregationContext ctx) => new(ctx.Request.CustomerId); + } + + [GenerateGatewayAggregation(typeof(DashboardRequest), typeof(DashboardResponse))] + protected internal partial class ProtectedInternalGateway + { + [GatewayAggregationFetch("profile")] + private static DashboardRequest Profile(DashboardRequest request) => request; + [GatewayAggregationComposer] + private static DashboardResponse Compose(GatewayAggregationContext ctx) => new(ctx.Request.CustomerId); + } } - """) - }) - .Then("diagnostics identify invalid declarations", results => + } + """)) + .Then("generated sources preserve containing partial type wrappers", result => { - ScenarioExpect.Contains(results[0].Diagnostics, diagnostic => diagnostic.Id == "PKGA001"); - ScenarioExpect.Contains(results[1].Diagnostics, diagnostic => diagnostic.Id == "PKGA002"); - ScenarioExpect.Contains(results[2].Diagnostics, diagnostic => diagnostic.Id == "PKGA003"); - ScenarioExpect.Contains(results[3].Diagnostics, diagnostic => diagnostic.Id == "PKGA004"); + ScenarioExpect.Empty(result.Diagnostics); + ScenarioExpect.Equal(3, result.GeneratedSources.Count); + + var combined = string.Join("\n", result.GeneratedSources); + ScenarioExpect.Contains("public partial class GatewayContainer", combined); + ScenarioExpect.Contains("private partial class PrivateHost", combined); + ScenarioExpect.Contains("protected partial class ProtectedGateway", combined); + ScenarioExpect.Contains("private protected partial class PrivateProtectedGateway", combined); + ScenarioExpect.Contains("protected internal partial class ProtectedInternalGateway", combined); + ScenarioExpect.True(result.EmitSuccess, string.Join(Environment.NewLine, result.EmitDiagnostics)); }) .AssertPassed(); + [Scenario("Skips malformed gateway aggregation type arguments")] + [Theory] + [InlineData("null!", "typeof(DashboardResponse)")] + [InlineData("typeof(DashboardRequest)", "null!")] + public Task Skips_Malformed_Gateway_Aggregation_Type_Arguments(string requestType, string responseType) + => Given("a gateway aggregation declaration with a null type argument", () => Compile($$""" + using PatternKit.Cloud.GatewayAggregation; + using PatternKit.Generators.GatewayAggregation; + public sealed record DashboardRequest(string CustomerId); + public sealed record DashboardResponse(string CustomerId); + [GenerateGatewayAggregation({{requestType}}, {{responseType}})] + public static partial class GatewayHost + { + [GatewayAggregationFetch("profile")] + private static DashboardRequest Profile(DashboardRequest request) => request; + [GatewayAggregationComposer] + private static DashboardResponse Compose(GatewayAggregationContext ctx) => new(ctx.Request.CustomerId); + } + """)) + .Then("no source is generated", result => + ScenarioExpect.Empty(result.GeneratedSources)) + .AssertPassed(); + private static GeneratorResult Compile(string source) { var compilation = RoslynTestHelpers.CreateCompilation(