Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dfm/src/automodel/_diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .auto_diffusion_pipeline import NeMoAutoDiffusionPipeline
from .auto_diffusion_pipeline import NeMoAutoDiffusionPipeline, PipelineSpec


__all__ = [
"NeMoAutoDiffusionPipeline",
"PipelineSpec",
]
367 changes: 289 additions & 78 deletions dfm/src/automodel/_diffusers/auto_diffusion_pipeline.py

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions dfm/src/automodel/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
build_mock_dataloader,
mock_collate_fn,
)
from dfm.src.automodel.datasets.multiresolutionDataloader import (
build_video_multiresolution_dataloader,
)


__all__ = [
Expand All @@ -35,4 +38,6 @@
"MockWanDataset",
"build_mock_dataloader",
"mock_collate_fn",
# Multiresolution video dataloader
"build_video_multiresolution_dataloader",
]
21 changes: 21 additions & 0 deletions dfm/src/automodel/datasets/multiresolutionDataloader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_dataset import BaseMultiresolutionDataset
from .dataloader import (
SequentialBucketSampler,
build_multiresolution_dataloader,
collate_fn_production,
)
from .flux_collate import (
build_flux_multiresolution_dataloader,
collate_fn_flux,
)
from .multi_tier_bucketing import MultiTierBucketCalculator
from .text_to_image_dataset import TextToImageDataset
from .text_to_video_dataset import TextToVideoDataset
from .video_collate import (
build_video_multiresolution_dataloader,
collate_fn_video,
)


__all__ = [
# Base class
"BaseMultiresolutionDataset",
# Dataset classes
"TextToImageDataset",
"TextToVideoDataset",
# Utilities
"MultiTierBucketCalculator",
"SequentialBucketSampler",
"build_multiresolution_dataloader",
"collate_fn_production",
# Flux-specific
"build_flux_multiresolution_dataloader",
"collate_fn_flux",
# Video-specific
"build_video_multiresolution_dataloader",
"collate_fn_video",
]
162 changes: 162 additions & 0 deletions dfm/src/automodel/datasets/multiresolutionDataloader/base_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Abstract base class for multiresolution datasets.

Provides shared functionality for TextToImageDataset and TextToVideoDataset:
- Metadata loading from sharded JSON files
- Bucket grouping by aspect ratio and resolution
- Bucket info utilities
"""

import json
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List

import torch
from torch.utils.data import Dataset

from dfm.src.automodel.datasets.multiresolutionDataloader.multi_tier_bucketing import MultiTierBucketCalculator


logger = logging.getLogger(__name__)


class BaseMultiresolutionDataset(Dataset, ABC):
    """
    Abstract base class for multiresolution datasets.

    Provides shared functionality for loading preprocessed cache data
    and organizing samples into buckets by aspect ratio and resolution.

    Subclasses must implement:
    - __getitem__: Load a single sample with media-type-specific fields
    """

    def __init__(self, cache_dir: str, quantization: int = 64):
        """
        Initialize the dataset.

        Args:
            cache_dir: Directory containing preprocessed cache (metadata.json and cache files)
            quantization: Bucket calculator quantization (64 for images, 8 for videos)

        Raises:
            FileNotFoundError: If metadata.json (or a shard it references) is missing.
            ValueError: If metadata.json does not have the expected structure.
            KeyError: If a sample carries neither a bucket nor a crop resolution.
        """
        self.cache_dir = Path(cache_dir)

        # Load per-sample metadata from metadata.json + its shard files.
        self.metadata = self._load_metadata()

        # Organize sample indices into (aspect_name, resolution) buckets.
        self._group_by_bucket()

        # Bucket calculator used downstream for dynamic batch sizes.
        self.calculator = MultiTierBucketCalculator(quantization=quantization)

    def _load_metadata(self) -> List[Dict]:
        """
        Load metadata from cache directory.

        Expects metadata.json with "shards" key referencing shard files.

        Returns:
            Flat list of per-sample metadata dicts concatenated from all shards.

        Raises:
            FileNotFoundError: If metadata.json or a referenced shard file is missing.
            ValueError: If metadata.json is not a dict with a "shards" key.
        """
        metadata_file = self.cache_dir / "metadata.json"

        if not metadata_file.exists():
            raise FileNotFoundError(f"No metadata.json found in {self.cache_dir}")

        with open(metadata_file, "r") as f:
            data = json.load(f)

        if not isinstance(data, dict) or "shards" not in data:
            raise ValueError(f"Invalid metadata format in {metadata_file}. Expected dict with 'shards' key.")

        # Load all shard files; fail with a clear message naming the missing
        # shard instead of a bare open() error with no provenance.
        metadata = []
        for shard_name in data["shards"]:
            shard_path = self.cache_dir / shard_name
            if not shard_path.exists():
                raise FileNotFoundError(f"Shard file {shard_path} referenced by {metadata_file} is missing")
            with open(shard_path, "r") as f:
                shard_data = json.load(f)
            metadata.extend(shard_data)

        return metadata

    def _aspect_ratio_to_name(self, aspect_ratio: float) -> str:
        """Convert aspect ratio to a descriptive name (tall / square / wide)."""
        # Thresholds: < 0.85 is tall (portrait), > 1.18 is wide (landscape),
        # everything in between is treated as square.
        if aspect_ratio < 0.85:
            return "tall"
        elif aspect_ratio > 1.18:
            return "wide"
        else:
            return "square"

    def _group_by_bucket(self):
        """Group samples by bucket (aspect_ratio + resolution).

        Populates ``self.bucket_groups`` (bucket_key -> group dict) and
        ``self.sorted_bucket_keys`` (keys ordered by pixel count, ascending).

        Raises:
            KeyError: If a sample has neither 'bucket_resolution' nor
                'crop_resolution' in its metadata.
        """
        self.bucket_groups = {}

        for idx, item in enumerate(self.metadata):
            # Bucket key: aspect_name/resolution
            aspect_ratio = item.get("aspect_ratio", 1.0)
            aspect_name = self._aspect_ratio_to_name(aspect_ratio)
            # Support both old "crop_resolution" and new "bucket_resolution"
            # keys for backward compatibility; fail loudly (not with an opaque
            # TypeError from tuple(None)) when both are absent.
            raw_resolution = item.get("bucket_resolution", item.get("crop_resolution"))
            if raw_resolution is None:
                raise KeyError(
                    f"Sample {idx} has neither 'bucket_resolution' nor 'crop_resolution' in its metadata"
                )
            resolution = tuple(raw_resolution)
            bucket_key = (aspect_name, resolution)

            if bucket_key not in self.bucket_groups:
                self.bucket_groups[bucket_key] = {
                    "indices": [],
                    "aspect_name": aspect_name,
                    "aspect_ratio": aspect_ratio,
                    "resolution": resolution,
                    "pixels": resolution[0] * resolution[1],
                }

            self.bucket_groups[bucket_key]["indices"].append(idx)

        # Sort buckets by resolution (low to high for optimal memory usage)
        self.sorted_bucket_keys = sorted(self.bucket_groups.keys(), key=lambda k: self.bucket_groups[k]["pixels"])

        # Lazy %-style args keep formatting cost off disabled log levels.
        logger.info("\nDataset organized into %d buckets:", len(self.bucket_groups))
        for key in self.sorted_bucket_keys:
            bucket = self.bucket_groups[key]
            aspect_name, resolution = key
            logger.info(
                "  %-6s %4dx%4d: %5d samples", aspect_name, resolution[0], resolution[1], len(bucket["indices"])
            )

    def get_bucket_info(self) -> Dict:
        """Get bucket organization information.

        Returns:
            Dict with "total_buckets" (int) and "buckets" mapping
            "aspect/WxH" labels to sample counts.
        """
        return {
            "total_buckets": len(self.bucket_groups),
            "buckets": {f"{k[0]}/{k[1][0]}x{k[1][1]}": len(v["indices"]) for k, v in self.bucket_groups.items()},
        }

    def __len__(self) -> int:
        # Total number of samples across all buckets.
        return len(self.metadata)

    @abstractmethod
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Load a single sample.

        Subclasses must implement this method to load media-type-specific data.

        Args:
            idx: Sample index

        Returns:
            Dict containing sample data with keys specific to the media type
        """
        pass
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def __init__(
logger.info(
f" Base batch size: {base_batch_size}" + (f" @ {base_resolution}" if dynamic_batch_size else " (fixed)")
)
logger.info(f" DDP: rank {self.rank} of {self.num_replicas}")

def _get_batch_size(self, resolution: Tuple[int, int]) -> int:
"""Get batch size for resolution (dynamic or fixed based on setting)."""
Expand Down Expand Up @@ -229,12 +228,12 @@ def get_batch_info(self, batch_idx: int) -> Dict:
def collate_fn_production(batch: List[Dict]) -> Dict:
"""Production collate function with verification."""
# Verify all samples have same resolution
resolutions = [tuple(item["crop_resolution"].tolist()) for item in batch]
resolutions = [tuple(item["bucket_resolution"].tolist()) for item in batch]
assert len(set(resolutions)) == 1, f"Mixed resolutions in batch: {set(resolutions)}"

# Stack tensors
latents = torch.stack([item["latent"] for item in batch])
crop_resolutions = torch.stack([item["crop_resolution"] for item in batch])
bucket_resolutions = torch.stack([item["bucket_resolution"] for item in batch])
original_resolutions = torch.stack([item["original_resolution"] for item in batch])
crop_offsets = torch.stack([item["crop_offset"] for item in batch])

Expand All @@ -246,7 +245,7 @@ def collate_fn_production(batch: List[Dict]) -> Dict:

output = {
"latent": latents,
"crop_resolution": crop_resolutions,
"bucket_resolution": bucket_resolutions,
"original_resolution": original_resolutions,
"crop_offset": crop_offsets,
"prompt": prompts,
Expand All @@ -258,8 +257,8 @@ def collate_fn_production(batch: List[Dict]) -> Dict:
# Handle text encodings
if "clip_hidden" in batch[0]:
output["clip_hidden"] = torch.stack([item["clip_hidden"] for item in batch])
output["clip_pooled"] = torch.stack([item["clip_pooled"] for item in batch])
output["t5_hidden"] = torch.stack([item["t5_hidden"] for item in batch])
output["pooled_prompt_embeds"] = torch.stack([item["pooled_prompt_embeds"] for item in batch])
output["prompt_embeds"] = torch.stack([item["prompt_embeds"] for item in batch])
else:
output["clip_tokens"] = torch.stack([item["clip_tokens"] for item in batch])
output["t5_tokens"] = torch.stack([item["t5_tokens"] for item in batch])
Expand Down
Loading