Support for Dreambooth Weights #59

@anotherjesse

I've been experimenting with supporting weights from dreambooth in this model:

diff --git a/predict.py b/predict.py
index 5630646..2d23a87 100644
--- a/predict.py
+++ b/predict.py
@@ -10,6 +10,7 @@ from diffusers import (
     StableDiffusionPipeline,
 )
 
+USE_WEIGHTS = os.path.exists("weights")
 MODEL_ID = "stabilityai/stable-diffusion-2-1"
 MODEL_CACHE = "diffusers-cache"
 
@@ -18,11 +19,18 @@ class Predictor(BasePredictor):
     def setup(self):
         """Load the model into memory to make running multiple predictions efficient"""
         print("Loading pipeline...")
-        self.pipe = StableDiffusionPipeline.from_pretrained(
-            MODEL_ID,
-            cache_dir=MODEL_CACHE,
-            local_files_only=True,
-        ).to("cuda")
+        if USE_WEIGHTS:
+            self.pipe = StableDiffusionPipeline.from_pretrained(
+                "weights",
+                safety_checker=None,
+                torch_dtype=torch.float16,
+            ).to("cuda")
+        else:
+            self.pipe = StableDiffusionPipeline.from_pretrained(
+                MODEL_ID,
+                cache_dir=MODEL_CACHE,
+                local_files_only=True,
+            ).to("cuda")
 
     @torch.inference_mode()
     def predict(

The only other major difference between this and dreambooth-template is that dreambooth-template has a hardcoded scheduler:

    scheduler = DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
    )

The default scheduler seems to work, although I don't know whether the "magic numbers" in dreambooth-template's DDIMScheduler are there to maximize the quality of the dreambooth generations.
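To make those numbers less magic, here is a rough sketch of what `beta_schedule="scaled_linear"` does with `beta_start` and `beta_end`: it spaces the square roots of the betas evenly and then squares them. As far as I know these are the same values Stable Diffusion itself was trained with, so they are probably not dreambooth-specific tuning, but treat that as an assumption. The helper name `scaled_linear_betas` is made up for illustration, and `num_train_timesteps=1000` is assumed to be the scheduler default.

```python
import math


def scaled_linear_betas(beta_start=0.00085, beta_end=0.012, num_train_timesteps=1000):
    """Hypothetical helper: betas spaced linearly in sqrt-space, then squared."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    step = (hi - lo) / (num_train_timesteps - 1)
    return [(lo + i * step) ** 2 for i in range(num_train_timesteps)]


betas = scaled_linear_betas()

# Cumulative product of (1 - beta): the fraction of signal left at each timestep.
alphas_cumprod = []
running = 1.0
for beta in betas:
    running *= 1.0 - beta
    alphas_cumprod.append(running)

print(f"first beta: {betas[0]:.5f}, last beta: {betas[-1]:.5f}")
print(f"signal remaining at final timestep: {alphas_cumprod[-1]:.5f}")
```

The schedule runs from `beta_start` at the first timestep to `beta_end` at the last, and the cumulative product shows how little of the original image is left by the end of the noising process.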

With the above patch, all you have to do is unzip the weights generated by this API https://replicate.com/replicate/dreambooth into cog-stable-diffusion and run cog build.
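If you would rather script the unzip step, something like the sketch below should do it. It assumes the archive holds the pipeline files at its top level, so they need to land in a `weights` directory for the `os.path.exists("weights")` check in the patch to pick them up. If the archive already contains a `weights/` folder, extract into the repo root instead. The function name `unpack_weights` and the file name `output.zip` are hypothetical.

```python
import os
import zipfile


def unpack_weights(zip_path, repo_dir="."):
    """Hypothetical helper: extract a weights archive into <repo_dir>/weights."""
    target = os.path.join(repo_dir, "weights")
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(target)
    return target


if __name__ == "__main__":
    # "output.zip" is a placeholder for wherever the downloaded archive lives.
    if os.path.exists("output.zip"):
        print("extracted to", unpack_weights("output.zip"))
```

After this, running cog build from the cog-stable-diffusion directory should pick up the `weights` directory.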
