From e903baf7074f9f047f95481dff0c8a488183d519 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 7 Jun 2026 15:44:57 -0700 Subject: [PATCH 01/18] Initial Colored noise impl --- skrample/pytorch/noise.py | 199 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 199 insertions(+) diff --git a/skrample/pytorch/noise.py b/skrample/pytorch/noise.py index 43fcdf9..1c6ee47 100644 --- a/skrample/pytorch/noise.py +++ b/skrample/pytorch/noise.py @@ -252,6 +252,205 @@ def from_inputs( return cls(shape=shape, seed=seed, dtype=dtype, props=props) +@dataclass(frozen=True) +class ColoredProps(BrownianProps): + independent: bool = False + """When `True`, the initial noise is generated using a Brownian bridge to maintain + temporal consistency across any `Step` distance. Otherwise noise is purely random.""" + + energy: float | None = None + """Target standard deviation of the output tensor. + When `None`, noise is normalized back to unit variance.""" + + color_start: float = 0 + "Power-law exponent at the beginning of the schedule (`step` = None)" + color_end: float = -2 + "Power-law exponent at the end of the schedule (`step.time_to` = 1)." + + +@dataclass +class Colored(TensorNoiseCommon[ColoredProps]): + """Power-law colored noise generator with schedule-driven exponent interpolation. + + Generates noise whose power spectrum follows `f^{-exponent/2}` by shaping + white noise in the Fourier domain. The `exponent` is interpolated between + `color_start` and `color_end` as a function of the diffusion step, + so the color of the noise evolves over the generation timeline. + """ + + def __post_init__(self) -> None: + if self.props.independent: + import torchsde + + self._tree = torchsde.BrownianInterval( + t0=0, + t1=1, + size=self.shape, + entropy=self.seed.initial_seed(), + dtype=self.dtype, + device=self.seed.device, + halfway_tree=True, + tol=1 / (self.props.max_steps * 10), # 1 order of magnitude more than min step size + pool_size=2**6, # tolerance is 99% of the perf hit at this size + cache_size=round(math.log2(self.props.max_steps * 10) * 1.3), # binary for halfway + 30% + ) + else: + self._tree = None + + @staticmethod + def _radial_freq_grid(shape: torch.Size, device: torch.device) -> torch.Tensor: + """Build a normalized radial-frequency tensor matching rfftn output shape. + + For `rfftn(x)` on a tensor of spatial shape `(D₁, …, Dₙ)`, the complex + output has shape `(D₁, …, Dₙ//2+1)` when the last dim is even (full size + if odd). This function returns a radius grid of *exactly* that trailing-D + shape so it broadcasts naturally over any leading batch/channel dims. + + Values are in `[0, 1]` with 0 = DC and 1 = the farthest Nyquist bin. + + Parameters + --- + shape : torch.Size + The spatial shape of the tensor being transformed. + device : torch.device + Where to allocate frequency tensors. + + Returns + --- + torch.Tensor + Normalized radial frequency grid. + + Notes + --- + Implementation by Qwen 3.6 27B + """ + ndim = len(shape) + + # Build per-axis frequency coordinates in normalized form. + # rFFT always keeps only the non-redundant half on the last axis (N//2+1 bins). + freqs_per_axis: list[torch.Tensor] = [] + for i, dim in enumerate(shape): + if i == ndim - 1: + # Last axis: rFFT output has N//2 + 1 non-redundant frequency bins [0 .. N/2] + n_bins = dim // 2 + 1 + idx = torch.arange(n_bins, device=device) + freqs_per_axis.append(idx / dim) # normalized [0, 0.5] + else: + # Other axes: full FFT - use abs(fftfreq) for radial distance symmetry + freqs_per_axis.append(torch.fft.fftfreq(dim, d=1.0, device=device).abs()) + + # meshgrid → stack → radial norm + grid = torch.stack(torch.meshgrid(*freqs_per_axis, indexing="ij"), dim=-1) + radius = grid.norm(p=2, dim=-1) + + # Normalize to [0, 1] + r_max = radius.max() + if r_max > 0: + radius = radius / r_max + + return radius # shape exactly matches trailing ndim of rfftn output + + @staticmethod + def colorize_noise(white: torch.Tensor, exponent: float = 0.0, energy: float | None = None) -> torch.Tensor: + """Colors the input white noise according to the Gaussian power-law spectrum `f^{-exponent/2}`. + + Takes an existing white-noise tensor and colours it in the Fourier + domain so that its amplitude falls (or rises) with radial frequency. + The result is normalized back to unit standard deviation (or to `energy`, if provided). + + Examples + --- + >>> import torch + >>> white = torch.randn(4, 64, 64) + >>> + >>> # Pink-ish noise - richer low-frequency structure + >>> pink = Colored.colored_noise(white, exponent=1.0) + >>> + >>> # Blue noise - high-frequency detail emphasized + >>> blue = Colored.colored_noise(white, exponent=-2.0) + + Notes + --- + Initial implementation by Qwen 3.6 27B + """ + + # Step 1: white noise + out_shape = white.shape + wstd = white.std() + w = white.squeeze() + + # Fast path: t == 0 is plain white noise - no FFT overhead + if exponent == 0.0: + return w if energy is None or wstd < 1e-8 else w * (energy / wstd) + + # Step 2: forward FFT (real → complex) + F = torch.fft.rfftn(w, norm="forward") + + # Step 3: normalized radial frequency grid + freq_grid = Colored._radial_freq_grid(w.shape, w.device) + + # Step 4: power-law amplitude weights + # PSD ∝ f^{-t} ⇒ amplitude weight ∝ f^{-t/2}. + # + # The weight diverges at DC (f = 0). We clip at half a frequency-bin + # spacing in normalized coordinates, which is standard practice for + # FFT-based colored-noise generation. This gives the correct PSD slope + # away from DC while keeping only one bin per radial direction clamped. + N_eff = sum(w.shape) / len(w.shape) if w.shape else 1.0 + eps_clip = 0.5 / max(N_eff, 4.0) + + weights = torch.clamp(freq_grid, min=eps_clip) ** (-exponent / 2.0) + + # Step 5: multiply in Fourier domain + F_colored = F * weights + + # Step 6: inverse FFT to spatial domain + colored = torch.fft.irfftn(F_colored, s=w.shape, norm="forward") + + # Step 7: renormalize to unit std (variance conservation) + cstd = colored.std() + if cstd > 1e-8: + colored *= wstd / cstd if energy is None else energy / cstd + + return colored.view(out_shape).to(dtype=w.dtype) + + def white(self, step: Step | None) -> torch.Tensor: + """Raw white-noise tensor before coloring. + Uses BrownianInterval when `props.independent` is `True` + and a valid `step` is provided; otherwise falls back to plain randn.""" + + if self._tree is not None and step: + step = step.normal().clamp() # enforce 0..=1 + return self._tree(*step) / math.sqrt(step.distance()) # pyright: ignore[reportOperatorIssue] + else: + return self._randn() + + def generate(self, step: Step | None) -> torch.Tensor: + noise = self.white(step) + + if step is None: + exponent = self.props.color_start # t=0 equivalent + else: + step = step.normal().clamp() # enforce 0..=1 + t = step.time_to # t>0 + exponent = (1 - t) * self.props.color_start + t * self.props.color_end + + # will short-circuit for exponent 0, but still has energy target + noise = self.colorize_noise(noise, exponent=exponent, energy=self.props.energy) + + return noise + + @classmethod + def from_inputs( + cls, + shape: tuple[int, ...], + seed: torch.Generator, + props: ColoredProps = ColoredProps(), + dtype: torch.dtype = torch.float32, + ) -> Self: + return cls(shape=shape, seed=seed, dtype=dtype, props=props) + + @dataclass class BatchTensorNoise[T: TensorNoiseProps | None](SkrampleTensorNoise): """Helper class for producing batches of noise while maintaining seeds across individual batch items. From 7ff674d69c6e6e9e54043e596cf8231c4a37335d Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 7 Jun 2026 22:42:09 -0700 Subject: [PATCH 02/18] Fix some bugs and docs --- skrample/pytorch/noise.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/skrample/pytorch/noise.py b/skrample/pytorch/noise.py index 1c6ee47..251d52a 100644 --- a/skrample/pytorch/noise.py +++ b/skrample/pytorch/noise.py @@ -260,7 +260,7 @@ class ColoredProps(BrownianProps): energy: float | None = None """Target standard deviation of the output tensor. - When `None`, noise is normalized back to unit variance.""" + When `None`, noise is normalized back to uncolored variance.""" color_start: float = 0 "Power-law exponent at the beginning of the schedule (`step` = None)" @@ -352,22 +352,25 @@ def _radial_freq_grid(shape: torch.Size, device: torch.device) -> torch.Tensor: @staticmethod def colorize_noise(white: torch.Tensor, exponent: float = 0.0, energy: float | None = None) -> torch.Tensor: - """Colors the input white noise according to the Gaussian power-law spectrum `f^{-exponent/2}`. + """Colors the input white noise according to the Gaussian power-law spectrum `f^{-exponent}`. - Takes an existing white-noise tensor and colours it in the Fourier + Takes an existing white-noise tensor and colors it in the Fourier domain so that its amplitude falls (or rises) with radial frequency. The result is normalized back to unit standard deviation (or to `energy`, if provided). + Single element dimensions are excluded from FFT. + Batching is NOT accounted for. Batched tensors must be passed individually. + Examples --- >>> import torch - >>> white = torch.randn(4, 64, 64) + >>> white = torch.randn(64, 64) >>> >>> # Pink-ish noise - richer low-frequency structure - >>> pink = Colored.colored_noise(white, exponent=1.0) + >>> pink = Colored.colorize_noise(white, exponent=1.0) >>> >>> # Blue noise - high-frequency detail emphasized - >>> blue = Colored.colored_noise(white, exponent=-2.0) + >>> blue = Colored.colorize_noise(white, exponent=-2.0) Notes --- @@ -375,13 +378,16 @@ def colorize_noise(white: torch.Tensor, exponent: float = 0.0, energy: float | N """ # Step 1: white noise - out_shape = white.shape wstd = white.std() - w = white.squeeze() # Fast path: t == 0 is plain white noise - no FFT overhead if exponent == 0.0: - return w if energy is None or wstd < 1e-8 else w * (energy / wstd) + return white if energy is None or wstd < 1e-8 else white * (energy / wstd) + + w = white.squeeze() + + if w.dtype not in [torch.float32, torch.float64]: # half/bfloat not fully supported + w = w.to(torch.float32) # Step 2: forward FFT (real → complex) F = torch.fft.rfftn(w, norm="forward") @@ -407,12 +413,12 @@ def colorize_noise(white: torch.Tensor, exponent: float = 0.0, energy: float | N # Step 6: inverse FFT to spatial domain colored = torch.fft.irfftn(F_colored, s=w.shape, norm="forward") - # Step 7: renormalize to unit std (variance conservation) + # Step 7: renormalize to input std (variance conservation) cstd = colored.std() if cstd > 1e-8: colored *= wstd / cstd if energy is None else energy / cstd - return colored.view(out_shape).to(dtype=w.dtype) + return colored.view(white.shape).to(dtype=white.dtype) def white(self, step: Step | None) -> torch.Tensor: """Raw white-noise tensor before coloring. From 17d7d13134cb4f2a4859417951735653cc1ed255 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 7 Jun 2026 22:45:26 -0700 Subject: [PATCH 03/18] typo --- skrample/pytorch/noise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skrample/pytorch/noise.py b/skrample/pytorch/noise.py index 251d52a..743eadd 100644 --- a/skrample/pytorch/noise.py +++ b/skrample/pytorch/noise.py @@ -356,7 +356,7 @@ def colorize_noise(white: torch.Tensor, exponent: float = 0.0, energy: float | N Takes an existing white-noise tensor and colors it in the Fourier domain so that its amplitude falls (or rises) with radial frequency. - The result is normalized back to unit standard deviation (or to `energy`, if provided). + The result is normalized back to input deviation (or to `energy`, if provided). Single element dimensions are excluded from FFT. Batching is NOT accounted for. Batched tensors must be passed individually. From 799672fdad8fd9c46e714f3e5f77f08e32cb7d1e Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 7 Jun 2026 23:53:46 -0700 Subject: [PATCH 04/18] Simplify FlowShift --- skrample/scheduling.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/skrample/scheduling.py b/skrample/scheduling.py index d958c10..cad4eed 100644 --- a/skrample/scheduling.py +++ b/skrample/scheduling.py @@ -555,10 +555,7 @@ class FlowShift(ScheduleModifier): """Amount to shift noise schedule by.""" def _modify(self, t: NPSequence) -> NPSequence: - t = t.copy() - mask = t > 0 - t[mask] = self.shift / (self.shift + (1 / t[mask] - 1)) - return t + return self.shift / (self.shift + (1 / t - 1)) @dataclass(frozen=True) From d40aefa8ae6bd5a1f8c3799ffd2f2674ed5b7210 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Mon, 8 Jun 2026 00:02:20 -0700 Subject: [PATCH 05/18] Add ColorProps.color_curve --- skrample/pytorch/noise.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/skrample/pytorch/noise.py b/skrample/pytorch/noise.py index 743eadd..89a0902 100644 --- a/skrample/pytorch/noise.py +++ b/skrample/pytorch/noise.py @@ -5,7 +5,7 @@ import torch -from skrample.common import Step +from skrample.common import Step, rescale_positive @dataclass(frozen=True) @@ -266,6 +266,8 @@ class ColoredProps(BrownianProps): "Power-law exponent at the beginning of the schedule (`step` = None)" color_end: float = -2 "Power-law exponent at the end of the schedule (`step.time_to` = 1)." + color_curve: float = 2 + "Curvature of power-law exponent gradient, similar to FlowShift" @dataclass @@ -439,6 +441,8 @@ def generate(self, step: Step | None) -> torch.Tensor: else: step = step.normal().clamp() # enforce 0..=1 t = step.time_to # t>0 + shift = rescale_positive(self.props.color_curve) + t = shift / (shift + (1 / t - 1)) exponent = (1 - t) * self.props.color_start + t * self.props.color_end # will short-circuit for exponent 0, but still has energy target From e38b377189a5e2f28c3a26fecc22f62f008e7d42 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Mon, 8 Jun 2026 00:07:08 -0700 Subject: [PATCH 06/18] fix ci --- scripts/plot_skrample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/plot_skrample.py b/scripts/plot_skrample.py index c61cc6d..795b7aa 100755 --- a/scripts/plot_skrample.py +++ b/scripts/plot_skrample.py @@ -244,7 +244,7 @@ def callback(x: float, n: int, d: DeltaPoint) -> None: composed = schedule label: str = sched_name - if (subschedule := SUBSCHEDULES[sub]) and sub is not None: + if subschedule := SUBSCHEDULES[sub]: composed = subschedule[0](composed, **subschedule[1]) label += "_" + subschedule[0].__name__.lower() From bc1ff315d7e55bc2713d11742f5d01154d0161f3 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Mon, 8 Jun 2026 00:09:20 -0700 Subject: [PATCH 07/18] CI weirdness --- skrample/sampling/structured.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skrample/sampling/structured.py b/skrample/sampling/structured.py index 4ab3a4d..e9db2c3 100644 --- a/skrample/sampling/structured.py +++ b/skrample/sampling/structured.py @@ -110,8 +110,8 @@ def sample_packed[T: Sample]( schedule: SkrampleSchedule, previous: Sequence[SKSamples[T]] = (), ) -> SKSamples[T]: - return SKSamples( - **( # ty: ignore # ??? + return SKSamples( # ty: ignore # ??? + **( # type: ignore # ??? dataclasses.asdict(packed) | { "final": self._sample_packed( From 831bed2f593293f02e80330f0164debbecb9e468 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Mon, 8 Jun 2026 00:17:26 -0700 Subject: [PATCH 08/18] Guard some infs --- skrample/pytorch/noise.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/skrample/pytorch/noise.py b/skrample/pytorch/noise.py index 89a0902..362d1de 100644 --- a/skrample/pytorch/noise.py +++ b/skrample/pytorch/noise.py @@ -5,7 +5,7 @@ import torch -from skrample.common import Step, rescale_positive +from skrample.common import Step, divf, rescale_positive @dataclass(frozen=True) @@ -438,11 +438,13 @@ def generate(self, step: Step | None) -> torch.Tensor: if step is None: exponent = self.props.color_start # t=0 equivalent + elif self.props.color_curve == math.inf: + exponent = self.props.color_end # infinite curve makes a flat line else: step = step.normal().clamp() # enforce 0..=1 t = step.time_to # t>0 shift = rescale_positive(self.props.color_curve) - t = shift / (shift + (1 / t - 1)) + t = shift / (shift + (divf(1, t) - 1)) exponent = (1 - t) * self.props.color_start + t * self.props.color_end # will short-circuit for exponent 0, but still has energy target From 167dde1b17e7e03bb44fb9777140bea4085af2e3 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Mon, 8 Jun 2026 00:44:32 -0700 Subject: [PATCH 09/18] VFR colored noise with simple script --- scripts/colored_noise_grid.py | 36 +++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 scripts/colored_noise_grid.py diff --git a/scripts/colored_noise_grid.py b/scripts/colored_noise_grid.py new file mode 100644 index 0000000..3ebcf6f --- /dev/null +++ b/scripts/colored_noise_grid.py @@ -0,0 +1,36 @@ +import PIL.Image +import torch +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL + +from skrample.pytorch.noise import Colored + +with torch.inference_mode(): + device = torch.device("cuda") + dtype = torch.float32 + + size: int = 1024 + exponents: list[float] = [-(2**3), -(2**2), -(2**1), 0, 2**-1, 2**0, 2**0.5] + + size = round(size / 8) * 8 + + aekl: AutoencoderKL = AutoencoderKL.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + subfolder="vae", + ).to(device=device, dtype=dtype) # type: ignore # ??? + + batches = torch.randn(len(exponents), 4, size // 8, size // 8, device=device, dtype=dtype) + + colors = [[Colored.colorize_noise(t, e) for t in batches] for e in exponents] + + canvas = PIL.Image.new("RGB", (size * len(exponents), size * len(exponents))) + + for y, batches in enumerate(colors): + for x, latent in enumerate(batches): + decoded = aekl.decode(latent.unsqueeze(0) / aekl.config["scaling_factor"]).sample[0] # pyright: ignore [reportAttributeAccessIssue] + im = PIL.Image.fromarray( + ((decoded + 1) * (255 / 2)).clamp(0, 255).permute(1, 2, 0).to(device="cpu", dtype=torch.uint8).numpy() + ) + + canvas.paste(im, (x * size, y * size)) + + canvas.save("colored_noise_grid.png") From dcb0f272c017da5166237ec3e294aa16230197f7 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Mon, 8 Jun 2026 01:51:46 -0700 Subject: [PATCH 10/18] Remove ColoredProps.independent --- skrample/pytorch/noise.py | 38 ++------------------------------------ 1 file changed, 2 insertions(+), 36 deletions(-) diff --git a/skrample/pytorch/noise.py b/skrample/pytorch/noise.py index 362d1de..3d45e35 100644 --- a/skrample/pytorch/noise.py +++ b/skrample/pytorch/noise.py @@ -253,11 +253,7 @@ def from_inputs( @dataclass(frozen=True) -class ColoredProps(BrownianProps): - independent: bool = False - """When `True`, the initial noise is generated using a Brownian bridge to maintain - temporal consistency across any `Step` distance. Otherwise noise is purely random.""" - +class ColoredProps(TensorNoiseProps): energy: float | None = None """Target standard deviation of the output tensor. When `None`, noise is normalized back to uncolored variance.""" @@ -280,25 +276,6 @@ class Colored(TensorNoiseCommon[ColoredProps]): so the color of the noise evolves over the generation timeline. """ - def __post_init__(self) -> None: - if self.props.independent: - import torchsde - - self._tree = torchsde.BrownianInterval( - t0=0, - t1=1, - size=self.shape, - entropy=self.seed.initial_seed(), - dtype=self.dtype, - device=self.seed.device, - halfway_tree=True, - tol=1 / (self.props.max_steps * 10), # 1 order of magnitude more than min step size - pool_size=2**6, # tolerance is 99% of the perf hit at this size - cache_size=round(math.log2(self.props.max_steps * 10) * 1.3), # binary for halfway + 30% - ) - else: - self._tree = None - @staticmethod def _radial_freq_grid(shape: torch.Size, device: torch.device) -> torch.Tensor: """Build a normalized radial-frequency tensor matching rfftn output shape. @@ -422,19 +399,8 @@ def colorize_noise(white: torch.Tensor, exponent: float = 0.0, energy: float | N return colored.view(white.shape).to(dtype=white.dtype) - def white(self, step: Step | None) -> torch.Tensor: - """Raw white-noise tensor before coloring. - Uses BrownianInterval when `props.independent` is `True` - and a valid `step` is provided; otherwise falls back to plain randn.""" - - if self._tree is not None and step: - step = step.normal().clamp() # enforce 0..=1 - return self._tree(*step) / math.sqrt(step.distance()) # pyright: ignore[reportOperatorIssue] - else: - return self._randn() - def generate(self, step: Step | None) -> torch.Tensor: - noise = self.white(step) + noise = self._randn() if step is None: exponent = self.props.color_start # t=0 equivalent From d704b71b8a6730833258fb07d17f8fc0872c6f33 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Wed, 10 Jun 2026 02:47:51 -0700 Subject: [PATCH 11/18] fix ci --- scripts/colored_noise_grid.py | 2 ++ 1 file changed, 2 insertions(+) mode change 100644 => 100755 scripts/colored_noise_grid.py diff --git a/scripts/colored_noise_grid.py b/scripts/colored_noise_grid.py old mode 100644 new mode 100755 index 3ebcf6f..fbeaf1e --- a/scripts/colored_noise_grid.py +++ b/scripts/colored_noise_grid.py @@ -1,3 +1,5 @@ +#! /usr/bin/env python + import PIL.Image import torch from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL From 4d6a522d903fa83b6afa18ede5f3ee247f45d3d7 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Wed, 10 Jun 2026 03:04:54 -0700 Subject: [PATCH 12/18] Fix ColorProps.color_curve direction, add a tiny amount of pink --- skrample/pytorch/noise.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/skrample/pytorch/noise.py b/skrample/pytorch/noise.py index 3d45e35..0d394c3 100644 --- a/skrample/pytorch/noise.py +++ b/skrample/pytorch/noise.py @@ -258,12 +258,13 @@ class ColoredProps(TensorNoiseProps): """Target standard deviation of the output tensor. When `None`, noise is normalized back to uncolored variance.""" - color_start: float = 0 + color_start: float = 1 / 4 "Power-law exponent at the beginning of the schedule (`step` = None)" color_end: float = -2 "Power-law exponent at the end of the schedule (`step.time_to` = 1)." color_curve: float = 2 - "Curvature of power-law exponent gradient, similar to FlowShift" + """Curvature of power-law exponent gradient, similar to FlowShift. + Higher values bias `color_start`, lower values bias `color_end`.""" @dataclass @@ -409,7 +410,8 @@ def generate(self, step: Step | None) -> torch.Tensor: else: step = step.normal().clamp() # enforce 0..=1 t = step.time_to # t>0 - shift = rescale_positive(self.props.color_curve) + # Negative curve to match FlowShift since step is ascending more like alpha than sigma + shift = rescale_positive(-self.props.color_curve) t = shift / (shift + (divf(1, t) - 1)) exponent = (1 - t) * self.props.color_start + t * self.props.color_end From 155a2f78e9386f37b948ac4e96831c1d291b1c8d Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 20 Jun 2026 18:43:36 -0700 Subject: [PATCH 13/18] More helpful docs --- skrample/pytorch/noise.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/skrample/pytorch/noise.py b/skrample/pytorch/noise.py index 0d394c3..646c8db 100644 --- a/skrample/pytorch/noise.py +++ b/skrample/pytorch/noise.py @@ -255,13 +255,17 @@ def from_inputs( @dataclass(frozen=True) class ColoredProps(TensorNoiseProps): energy: float | None = None - """Target standard deviation of the output tensor. + """Target standard deviation of the output tensor, effectively the scale of noise. When `None`, noise is normalized back to uncolored variance.""" color_start: float = 1 / 4 - "Power-law exponent at the beginning of the schedule (`step` = None)" + """Power-law exponent at the beginning of the schedule (`step` = None). + Higher values produce redder (lower frequency) noise, + lower values produce bluer (higher frequency) noise.""" color_end: float = -2 - "Power-law exponent at the end of the schedule (`step.time_to` = 1)." + """Power-law exponent at the end of the schedule (`step.time_to` = 1). + Higher values produce redder (lower frequency) noise, + lower values produce bluer (higher frequency) noise.""" color_curve: float = 2 """Curvature of power-law exponent gradient, similar to FlowShift. Higher values bias `color_start`, lower values bias `color_end`.""" From f15ef73255f67703febd5b9908a7893f458cc4f4 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 20 Jun 2026 18:58:36 -0700 Subject: [PATCH 14/18] fix docstring --- skrample/pytorch/noise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skrample/pytorch/noise.py b/skrample/pytorch/noise.py index 646c8db..47b7596 100644 --- a/skrample/pytorch/noise.py +++ b/skrample/pytorch/noise.py @@ -354,7 +354,7 @@ def colorize_noise(white: torch.Tensor, exponent: float = 0.0, energy: float | N >>> pink = Colored.colorize_noise(white, exponent=1.0) >>> >>> # Blue noise - high-frequency detail emphasized - >>> blue = Colored.colorize_noise(white, exponent=-2.0) + >>> violet = Colored.colorize_noise(white, exponent=-2.0) Notes --- From 7d2bace78fbe5d177bdd291a49060c2122a91de9 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 20 Jun 2026 19:20:53 -0700 Subject: [PATCH 15/18] more robust colored noise script --- scripts/colored_noise_grid.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/scripts/colored_noise_grid.py b/scripts/colored_noise_grid.py index fbeaf1e..509fea5 100755 --- a/scripts/colored_noise_grid.py +++ b/scripts/colored_noise_grid.py @@ -2,12 +2,13 @@ import PIL.Image import torch +import tqdm from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from skrample.pytorch.noise import Colored with torch.inference_mode(): - device = torch.device("cuda") + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") dtype = torch.float32 size: int = 1024 @@ -26,8 +27,8 @@ canvas = PIL.Image.new("RGB", (size * len(exponents), size * len(exponents))) - for y, batches in enumerate(colors): - for x, latent in enumerate(batches): + for y, batches in enumerate(tqdm.tqdm(colors, "colors")): + for x, latent in enumerate(tqdm.tqdm(batches, "batch")): decoded = aekl.decode(latent.unsqueeze(0) / aekl.config["scaling_factor"]).sample[0] # pyright: ignore [reportAttributeAccessIssue] im = PIL.Image.fromarray( ((decoded + 1) * (255 / 2)).clamp(0, 255).permute(1, 2, 0).to(device="cpu", dtype=torch.uint8).numpy() From 56c4875191742d275fe54544b218e137a96fcf8d Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 20 Jun 2026 19:33:28 -0700 Subject: [PATCH 16/18] ig the norms are redundant? --- skrample/pytorch/noise.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skrample/pytorch/noise.py b/skrample/pytorch/noise.py index 47b7596..1e1b7cc 100644 --- a/skrample/pytorch/noise.py +++ b/skrample/pytorch/noise.py @@ -374,7 +374,7 @@ def colorize_noise(white: torch.Tensor, exponent: float = 0.0, energy: float | N w = w.to(torch.float32) # Step 2: forward FFT (real → complex) - F = torch.fft.rfftn(w, norm="forward") + F = torch.fft.rfftn(w) # Step 3: normalized radial frequency grid freq_grid = Colored._radial_freq_grid(w.shape, w.device) @@ -395,7 +395,7 @@ def colorize_noise(white: torch.Tensor, exponent: float = 0.0, energy: float | N F_colored = F * weights # Step 6: inverse FFT to spatial domain - colored = torch.fft.irfftn(F_colored, s=w.shape, norm="forward") + colored = torch.fft.irfftn(F_colored, s=w.shape) # Step 7: renormalize to input std (variance conservation) cstd = colored.std() From 7ccb4f0fe3aced97889a6afc0c4664a3379bc97f Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sat, 20 Jun 2026 19:42:35 -0700 Subject: [PATCH 17/18] fix seed in script --- scripts/colored_noise_grid.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/scripts/colored_noise_grid.py b/scripts/colored_noise_grid.py index 509fea5..1d67858 100755 --- a/scripts/colored_noise_grid.py +++ b/scripts/colored_noise_grid.py @@ -21,7 +21,15 @@ subfolder="vae", ).to(device=device, dtype=dtype) # type: ignore # ??? - batches = torch.randn(len(exponents), 4, size // 8, size // 8, device=device, dtype=dtype) + batches = torch.randn( + len(exponents), + 4, + size // 8, + size // 8, + device=device, + dtype=dtype, + generator=torch.Generator(device).manual_seed(42), + ) colors = [[Colored.colorize_noise(t, e) for t in batches] for e in exponents] From 73ccf17bd4401f2eb2789debc63b7bdec56c0036 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Sun, 28 Jun 2026 15:00:50 -0700 Subject: [PATCH 18/18] basic colored noise CI --- tests/self_noise.py | 103 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 tests/self_noise.py diff --git a/tests/self_noise.py b/tests/self_noise.py new file mode 100644 index 0000000..a834fc1 --- /dev/null +++ b/tests/self_noise.py @@ -0,0 +1,103 @@ +import itertools + +import numpy as np +import pytest +import scipy.fft as fft +import torch +from scipy.stats import linregress + +from skrample.common import Step +from skrample.pytorch.noise import Colored, ColoredProps + + +def measure_noise_color(data: np.ndarray) -> float: + """ + Measures the spectral exponent (beta) of an n-dimensional noise array. + Implemented by Gemini. Was intentionally given no reference to Colored or skrample. + """ + ndim = data.ndim + shape = data.shape + + # 1. Compute the n-dimensional FFT and shift the DC component to the center + fft_dims = tuple(range(ndim)) + F = fft.fftn(data, axes=fft_dims) + F_shifted = fft.fftshift(F) + psd = np.abs(F_shifted) ** 2 + + # 2. Generate frequency grids for each dimension + freqs = [fft.fftshift(fft.fftfreq(s)) for s in shape] + mesh = np.meshgrid(*freqs, indexing="ij") + + # 3. Calculate the radial frequency (Euclidean distance from center) + radial_freq = np.sqrt(sum(m**2 for m in mesh)) + + # 4. Mask out the DC component (frequency = 0) to avoid log(0) + mask = radial_freq > 0 + radial_freq_flat = radial_freq[mask] + psd_flat = psd[mask] + + # 5. Bin the radial frequencies + num_bins = min(shape) // 2 + bin_edges = np.linspace(radial_freq_flat.min(), radial_freq_flat.max(), num_bins + 1) + bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) + + bin_indices = np.digitize(radial_freq_flat, bin_edges) - 1 + bin_powers = np.zeros(num_bins) + + # Calculate the MEAN power per radial bin + for i in range(num_bins): + bin_powers[i] = np.mean(psd_flat[bin_indices == i]) + + # 6. Filter out empty or unrepresentative bins + valid = (bin_powers > 0) & (bin_centers > 0) + log_f = np.log(bin_centers[valid]) + log_p = np.log(bin_powers[valid]) + + # 7. Linear regression on log-log scale: log(P) = -beta * log(f) + C + slope, _intercept, _r_value, _p_value, _std_err = linregress(log_f, log_p) + + beta = -slope.item() # pyright: ignore + return beta + + +@pytest.mark.parametrize( + ("exponent", "shape"), + (itertools.product([-3, -1.5, 0, 1.5, 3], [(65536,), (1024, 1024), (128, 128, 128)])), +) +def test_noise_color(exponent: float, shape: tuple[int, ...]) -> None: + generator = Colored( + shape, + torch.Generator("cpu"), + torch.float32, + ColoredProps(color_curve=0, color_start=exponent, color_end=-exponent), + ) + n0 = generator.generate(None) + color0 = measure_noise_color(n0.numpy()) + assert abs(exponent - color0) < 0.1, f"{exponent=}, {color0=}" + + n1 = generator.generate(Step(0, 1)) + color1 = measure_noise_color(n1.numpy()) + assert abs(-exponent - color1) < 0.1, f"{-exponent=}, {color1=}" + + +@pytest.mark.parametrize( + ("energy", "shape"), + (itertools.product([None, -3, -1.5, 0, 1.5, 3], [(65536,), (1024, 1024), (128, 128, 128)])), +) +def test_noise_energy(energy: float | None, shape: tuple[int, ...]) -> None: + generator = Colored( + shape, + torch.Generator("cpu"), + torch.float32, + ColoredProps(energy=energy, color_start=torch.randn(1).item(), color_end=torch.randn(1).item()), + ) + + std0 = generator.generate(None).std().item() + std1 = generator.generate(Step(0, 1)).std().item() + + if energy is None: + assert abs(1 - std0) < 1e-2, f"{energy=}, {std0=}" + assert abs(1 - std1) < 1e-2, f"{energy=}, {std1=}" + else: + assert abs(abs(energy) - std0) < 1e-6, f"{energy=}, {std0=}" + assert abs(abs(energy) - std1) < 1e-6, f"{energy=}, {std1=}"