Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions scripts/colored_noise_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#! /usr/bin/env python

import PIL.Image
import torch
import tqdm
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

from skrample.pytorch.noise import Colored

with torch.inference_mode():
    # Render a grid of VAE-decoded colored noise: one row per exponent,
    # one column per independently drawn latent sample.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    size: int = 1024
    exponents: list[float] = [-8, -4, -2, 0, 0.5, 1, 2**0.5]

    # The latent side is computed as size // 8, so snap the pixel size to a multiple of 8.
    size = round(size / 8) * 8
    latent_side = size // 8
    cells = len(exponents)

    aekl: AutoencoderKL = AutoencoderKL.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="vae",
    ).to(device=device, dtype=dtype)  # type: ignore # ???

    # One white-noise latent per column, drawn from a fixed seed.
    seeded = torch.Generator(device).manual_seed(42)
    batches = torch.randn(
        cells,
        4,
        latent_side,
        latent_side,
        device=device,
        dtype=dtype,
        generator=seeded,
    )

    # colors[row][col]: every white latent colored with the row's exponent.
    colors = [[Colored.colorize_noise(sample, exponent) for sample in batches] for exponent in exponents]

    canvas = PIL.Image.new("RGB", (size * cells, size * cells))
    scaling = aekl.config["scaling_factor"]

    for row, samples in enumerate(tqdm.tqdm(colors, "colors")):
        for col, latent in enumerate(tqdm.tqdm(samples, "batch")):
            decoded = aekl.decode(latent.unsqueeze(0) / scaling).sample[0]  # pyright: ignore [reportAttributeAccessIssue]
            # Map the decoder output from [-1, 1] to [0, 255] and move channels last for PIL.
            pixels = ((decoded + 1) * (255 / 2)).clamp(0, 255).permute(1, 2, 0)
            tile = PIL.Image.fromarray(pixels.to(device="cpu", dtype=torch.uint8).numpy())

            canvas.paste(tile, (col * size, row * size))

    canvas.save("colored_noise_grid.png")
185 changes: 184 additions & 1 deletion skrample/pytorch/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch

from skrample.common import Step
from skrample.common import Step, divf, rescale_positive


@dataclass(frozen=True)
Expand Down Expand Up @@ -252,6 +252,189 @@ def from_inputs(
return cls(shape=shape, seed=seed, dtype=dtype, props=props)


@dataclass(frozen=True)
class ColoredProps(TensorNoiseProps):
    """Settings for `Colored`: output scale plus the exponent schedule of the noise color."""

    energy: float | None = None
    """Standard deviation the colored tensor is rescaled to, i.e. the overall noise strength.
    With `None` the output keeps the standard deviation of the uncolored noise."""

    color_start: float = 0.25
    """Power-law exponent used at the start of the schedule (`step` = None).
    Larger values give redder (lower frequency) noise,
    smaller values give bluer (higher frequency) noise."""
    color_end: float = -2
    """Power-law exponent used at the end of the schedule (`step.time_to` = 1).
    Larger values give redder (lower frequency) noise,
    smaller values give bluer (higher frequency) noise."""
    color_curve: float = 2
    """Curvature of the exponent gradient between the two endpoints, similar to FlowShift.
    Higher values bias `color_start`, lower values bias `color_end`."""


@dataclass
class Colored(TensorNoiseCommon[ColoredProps]):
    """Power-law colored noise generator with schedule-driven exponent interpolation.

    Generates noise whose power spectrum follows `f^{-exponent}` (amplitude
    `f^{-exponent/2}`) by shaping white noise in the Fourier domain.
    The `exponent` is interpolated between `color_start` and `color_end`
    as a function of the diffusion step, so the color of the noise evolves
    over the generation timeline.
    """

    @staticmethod
    def _radial_freq_grid(shape: torch.Size, device: torch.device) -> torch.Tensor:
        """Build a normalized radial-frequency tensor matching rfftn output shape.

        For `rfftn(x)` on a tensor of spatial shape `(D₁, …, Dₙ)`, the complex
        output has shape `(D₁, …, Dₙ//2+1)`: only the last axis is halved, and
        it holds `Dₙ//2+1` bins for even and odd lengths alike. The returned
        radius grid has exactly that shape, so it multiplies element-wise
        against the spectrum.

        Values are in `[0, 1]` with 0 = DC and 1 = the bin farthest from DC.

        Parameters
        ---
        shape : torch.Size
            The spatial shape of the tensor being transformed.
        device : torch.device
            Where to allocate frequency tensors.

        Returns
        ---
        torch.Tensor
            Normalized radial frequency grid.

        Notes
        ---
        Implementation by Qwen 3.6 27B
        """
        last_axis = len(shape) - 1

        # Per-axis frequency magnitudes in cycles/sample.
        # Last axis: the N//2 + 1 non-redundant rFFT bins, i.e. [0 .. N//2] / N.
        # Other axes: full FFT layout, folded with abs() so the radius is symmetric.
        freqs_per_axis: list[torch.Tensor] = [
            torch.arange(dim // 2 + 1, device=device) / dim
            if axis == last_axis
            else torch.fft.fftfreq(dim, d=1.0, device=device).abs()
            for axis, dim in enumerate(shape)
        ]

        # meshgrid → stack → radial norm
        grid = torch.stack(torch.meshgrid(*freqs_per_axis, indexing="ij"), dim=-1)
        radius = grid.norm(p=2, dim=-1)

        # Normalize to [0, 1]; an all-zero grid is returned as-is
        r_max = radius.max()
        if r_max > 0:
            radius = radius / r_max

        return radius  # shape exactly matches the rfftn output for `shape`

    @staticmethod
    def colorize_noise(white: torch.Tensor, exponent: float = 0.0, energy: float | None = None) -> torch.Tensor:
        """Colors the input white noise according to the Gaussian power-law spectrum `f^{-exponent}`.

        Takes an existing white-noise tensor and colors it in the Fourier
        domain so that its amplitude falls (or rises) with radial frequency.
        The result is normalized back to input deviation (or to `energy`, if provided).

        Single element dimensions are excluded from FFT.
        Batching is NOT accounted for. Batched tensors must be passed individually.

        Tensors with fewer than two elements have neither a spectrum to shape
        nor a defined standard deviation, so they are returned unchanged.

        Parameters
        ---
        white : torch.Tensor
            Noise tensor to color. It is not modified.
        exponent : float
            Power-law exponent of the target spectrum. `0` leaves the spectrum untouched.
        energy : float | None
            Standard deviation of the result, or `None` to keep the deviation of `white`.

        Returns
        ---
        torch.Tensor
            Colored noise with the shape and dtype of `white`.

        Examples
        ---
        >>> import torch
        >>> white = torch.randn(64, 64)
        >>>
        >>> # Pink-ish noise - richer low-frequency structure
        >>> pink = Colored.colorize_noise(white, exponent=1.0)
        >>>
        >>> # Violet noise - high-frequency detail emphasized
        >>> violet = Colored.colorize_noise(white, exponent=-2.0)

        Notes
        ---
        Initial implementation by Qwen 3.6 27B
        """
        # Degenerate input: std() would be NaN and rfftn rejects the 0-dim squeeze
        if white.numel() < 2:
            return white

        # Step 1: white noise
        wstd = white.std()

        # Fast path: exponent == 0 is plain white noise - no FFT overhead
        if exponent == 0.0:
            return white if energy is None or wstd < 1e-8 else white * (energy / wstd)

        w = white.squeeze()

        if w.dtype not in (torch.float32, torch.float64):  # half/bfloat not fully supported
            w = w.to(torch.float32)

        # Step 2: forward FFT (real → complex)
        spectrum = torch.fft.rfftn(w)

        # Step 3: normalized radial frequency grid
        freq_grid = Colored._radial_freq_grid(w.shape, w.device)

        # Step 4: power-law amplitude weights
        # PSD ∝ f^{-exponent} ⇒ amplitude weight ∝ f^{-exponent/2}.
        #
        # The weight diverges at DC (f = 0), so frequencies are clipped from
        # below at roughly half a bin spacing, based on the mean axis length.
        # This keeps the PSD slope away from DC while bounding the DC weight.
        n_eff = sum(w.shape) / len(w.shape)
        eps_clip = 0.5 / max(n_eff, 4.0)

        weights = torch.clamp(freq_grid, min=eps_clip) ** (-exponent / 2.0)

        # Steps 5 and 6: multiply in Fourier domain, then inverse FFT to spatial domain
        colored = torch.fft.irfftn(spectrum * weights, s=w.shape)

        # Step 7: renormalize to input std (variance conservation) or to `energy`
        cstd = colored.std()
        if cstd > 1e-8:
            colored *= (wstd if energy is None else energy) / cstd

        # Restore the squeezed unit dimensions and the caller's dtype
        return colored.reshape(white.shape).to(dtype=white.dtype)

    def generate(self, step: Step | None) -> torch.Tensor:
        """Draw noise via `_randn` and color it with the exponent scheduled for `step`.

        Parameters
        ---
        step : Step | None
            Current schedule step. `None` selects `color_start`.

        Returns
        ---
        torch.Tensor
            Colored noise, scaled to `props.energy` when that is set.
        """
        noise = self._randn()

        if step is None:
            exponent = self.props.color_start  # t=0 equivalent
        elif self.props.color_curve == math.inf:
            # NOTE(review): `color_curve` is documented as "higher values bias
            # `color_start`", yet +inf selects `color_end` here - confirm which
            # limit is intended.
            exponent = self.props.color_end  # infinite curve makes a flat line
        else:
            step = step.normal().clamp()  # enforce 0..=1
            t = step.time_to  # t>0
            # Negative curve to match FlowShift since step is ascending more like alpha than sigma
            shift = rescale_positive(-self.props.color_curve)
            t = shift / (shift + (divf(1, t) - 1))
            # Linear blend between the two endpoint exponents at the shifted time
            exponent = (1 - t) * self.props.color_start + t * self.props.color_end

        # will short-circuit for exponent 0, but still has energy target
        return self.colorize_noise(noise, exponent=exponent, energy=self.props.energy)

    @classmethod
    def from_inputs(
        cls,
        shape: tuple[int, ...],
        seed: torch.Generator,
        props: ColoredProps = ColoredProps(),
        dtype: torch.dtype = torch.float32,
    ) -> Self:
        """Construct the generator from a shape, a seeded generator, props and a dtype."""
        return cls(shape=shape, seed=seed, dtype=dtype, props=props)


@dataclass
class BatchTensorNoise[T: TensorNoiseProps | None](SkrampleTensorNoise):
"""Helper class for producing batches of noise while maintaining seeds across individual batch items.
Expand Down
5 changes: 1 addition & 4 deletions skrample/scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,10 +555,7 @@ class FlowShift(ScheduleModifier):
"""Amount to shift noise schedule by."""

def _modify(self, t: NPSequence) -> NPSequence:
    """Apply the flow shift `shift / (shift + (1 / t - 1))` to every positive entry of `t`.

    Entries that are not strictly positive are passed through unchanged so
    that `1 / t` is never evaluated at zero. The input array is not mutated.
    """
    shifted = t.copy()
    positive = shifted > 0
    # Only the masked entries are divided, which keeps t == 0 free of
    # divide-by-zero warnings and leaves it at exactly 0.
    shifted[positive] = self.shift / (self.shift + (1 / shifted[positive] - 1))
    return shifted


@dataclass(frozen=True)
Expand Down
Loading