Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/plot_skrample.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def callback(x: float, n: int, d: DeltaPoint) -> None:
composed = schedule
label: str = sched_name

if (subschedule := SUBSCHEDULES[sub]) and sub is not None:
if sub is not None and (subschedule := SUBSCHEDULES[sub]):
composed = subschedule[0](composed, **subschedule[1])
label += "_" + subschedule[0].__name__.lower()

Expand Down
29 changes: 18 additions & 11 deletions skrample/diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections import OrderedDict
from collections.abc import Hashable, Mapping, Sequence
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any

import numpy as np
import torch
Expand Down Expand Up @@ -67,6 +67,8 @@
("prediction_type", "flow"): ("skrample_predictor", FlowModel()),
("prediction_type", "sample"): ("skrample_predictor", DataModel()),
("prediction_type", "v_prediction"): ("skrample_predictor", VelocityModel()),
# backwards order, last values take priority
("use_flow_sigmas", True): ("skrample_subschedule", None),
("use_beta_sigmas", True): ("skrample_subschedule", scheduling.Beta),
("use_exponential_sigmas", True): ("skrample_subschedule", scheduling.Exponential),
("use_karras_sigmas", True): ("skrample_subschedule", scheduling.Karras),
Expand Down Expand Up @@ -113,10 +115,10 @@ def parse_diffusers_config(
if not isinstance(config, dict):
config = dict(config.config)

remapped = {DIFFUSERS_KEY_MAP[k]: v for k, v in config.items() if k in DIFFUSERS_KEY_MAP} | {
DIFFUSERS_VALUE_MAP[(k, v)][0]: DIFFUSERS_VALUE_MAP[(k, v)][1]
for k, v in config.items()
if isinstance(v, Hashable) and (k, v) in DIFFUSERS_VALUE_MAP
remapped = {key_to: config[key_from] for key_from, key_to in DIFFUSERS_KEY_MAP.items() if key_from in config} | {
key_to: value_to
for (key_from, value_from), (key_to, value_to) in DIFFUSERS_VALUE_MAP.items()
if key_from in config and config[key_from] == value_from
}

if "skrample_predictor" in remapped:
Expand Down Expand Up @@ -148,17 +150,22 @@ def parse_diffusers_config(

schedule_modifiers: list[tuple[type[ScheduleModifier], dict[str, Any]]] = []

if isinstance(model, FlowModel):
flow_keys = [f.name for f in dataclasses.fields(scheduling.FlowShift)]
schedule_modifiers.append((scheduling.FlowShift, {k: v for k, v in remapped.items() if k in flow_keys}))

if "skrample_subschedule" in remapped:
subschedule: type[SubSchedule] | None = cast("type[SubSchedule]", remapped.pop("skrample_subschedule"))
modifier_keys = [f.name for f in dataclasses.fields(subschedule)]
subschedule: type[SubSchedule] | None = remapped.pop("skrample_subschedule")
# flow sigmas is typically last prio BUT results just look bad with karras/exp which is apparently a thing now
# https://huggingface.co/nvidia/Cosmos3-Super-Text2Image/blob/main/scheduler/scheduler_config.json
# So just gonna replace that shit here until we decide to implement their weird normalized karras thing
if config.get("use_flow_sigmas", False) is True and subschedule in (scheduling.Karras, scheduling.Exponential):
subschedule = None
modifier_keys = [f.name for f in dataclasses.fields(subschedule)] if subschedule else []
subschedule_props = {k: v for k, v in remapped.items() if k in modifier_keys}
else:
subschedule, subschedule_props = None, {}

if isinstance(model, FlowModel) and not subschedule:
flow_keys = [f.name for f in dataclasses.fields(scheduling.FlowShift)]
schedule_modifiers.append((scheduling.FlowShift, {k: v for k, v in remapped.items() if k in flow_keys}))

# feels cleaner than inspect.signature().parameters
sampler_keys = [f.name for f in dataclasses.fields(sampler)]
schedule_keys = [f.name for f in dataclasses.fields(schedule)]
Expand Down
2 changes: 1 addition & 1 deletion skrample/sampling/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def sample_packed[T: Sample](
previous: Sequence[SKSamples[T]] = (),
) -> SKSamples[T]:
return SKSamples(
**( # ty: ignore # ???
**( # type: ignore # ???
dataclasses.asdict(packed)
| {
"final": self._sample_packed(
Expand Down
77 changes: 69 additions & 8 deletions tests/diffusers_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from skrample.diffusers import SkrampleWrapperScheduler
from skrample.sampling.models import DiffusionModel, FlowModel, NoiseModel, VelocityModel
from skrample.sampling.structured import DPM, Adams, Euler, UniPC
from skrample.scheduling import Beta, Exponential, FlowShift, Karras, Linear, Scaled, SubSchedule
from skrample.scheduling import Beta, Exponential, FlowShift, Karras, Linear, Scaled, ScheduleModifier, SubSchedule

EPSILON = NoiseModel()
FLOW = FlowModel()
Expand All @@ -26,6 +26,8 @@
def assert_wrapper(wrapper: SkrampleWrapperScheduler, scheduler: ConfigMixin) -> None:
a, b = wrapper, SkrampleWrapperScheduler.from_diffusers_config(scheduler)
a.fake_config = b.fake_config
assert a.sampler == b.sampler # individual asserts for complex structs first for easier debugging
assert a.schedule == b.schedule
assert a == b


Expand Down Expand Up @@ -104,13 +106,6 @@ def test_euler_flow() -> None:
)


def test_euler_beta() -> None:
assert_wrapper(
SkrampleWrapperScheduler(Euler(), FlowShift(Beta(Linear())), FLOW),
FlowMatchEulerDiscreteScheduler.from_config(FLOW_CONFIG | {"use_beta_sigmas": True}),
)


def test_ipndm() -> None:
assert_wrapper(
SkrampleWrapperScheduler(Adams(order=4), Scaled()),
Expand Down Expand Up @@ -151,3 +146,69 @@ def test_ddpm() -> None:
SkrampleWrapperScheduler(DPM(order=1, stochasticity=True), Scaled()),
DDPMScheduler.from_config(SCALED_CONFIG),
)


@pytest.mark.parametrize(
("karras", "exp", "beta", "subschedule"),
[
# https://github.com/huggingface/diffusers/blob/2d0110f8182d18834d5039b19232e5761023b5f6/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py#L441-L468
(True, True, True, Karras),
(False, True, True, Exponential),
(True, False, True, Karras),
(True, True, False, Karras),
(True, False, False, Karras),
(False, True, False, Exponential),
(False, False, True, Beta),
(False, False, False, None),
],
)
def test_subschedule_mro_vp(karras: bool, exp: bool, beta: bool, subschedule: type[SubSchedule] | None) -> None:
scheduler: DPMSolverMultistepScheduler = DPMSolverMultistepScheduler.from_config(SCALED_CONFIG)
scheduler._internal_dict = dict( # bypass validation checks
scheduler.config
| {
"use_karras_sigmas": karras,
"use_exponential_sigmas": exp,
"use_beta_sigmas": beta,
"use_flow_sigmas": False,
"flow_shift": 3,
}
)
assert_wrapper(
SkrampleWrapperScheduler(DPM(), Scaled() if subschedule is None else subschedule(Scaled())),
scheduler,
)


@pytest.mark.parametrize(
("karras", "exp", "beta", "subschedule"),
[
# Different than VP due to manual override in parse_diffusers_config
(True, True, True, FlowShift),
(False, True, True, FlowShift),
(True, False, True, FlowShift),
(True, True, False, FlowShift),
(True, False, False, FlowShift),
(False, True, False, FlowShift),
(False, False, True, Beta),
(False, False, False, FlowShift),
],
)
def test_subschedule_mro_fm(
karras: bool,
exp: bool,
beta: bool,
subschedule: type[SubSchedule | ScheduleModifier],
) -> None:
scheduler: DPMSolverMultistepScheduler = DPMSolverMultistepScheduler.from_config(FLOW_CONFIG)
scheduler._internal_dict = dict( # bypass validation checks
scheduler.config
| {
"use_karras_sigmas": karras,
"use_exponential_sigmas": exp,
"use_beta_sigmas": beta,
"use_flow_sigmas": True,
"flow_shift": 3,
}
)
assert_wrapper(SkrampleWrapperScheduler(DPM(), subschedule(Linear()), FlowModel()), scheduler)
Loading