diff --git a/skrample/diffusers.py b/skrample/diffusers.py index aa7ab3c..22408e5 100644 --- a/skrample/diffusers.py +++ b/skrample/diffusers.py @@ -307,7 +307,7 @@ def get_step_noise[T: TensorNoiseProps | None]( # multiply by step index to spread the values and minimize clash # does not work across batch sizes but at least Flux will have something mostly deterministic seeds = [ - torch.Generator().manual_seed( + torch.Generator(torch.get_default_device()).manual_seed( int(b.reshape(b.numel())[b.numel() // 2].item() * 1e4 * (step.position() + 1)) ) for b in sample