-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathinference_merge_base_depth.py
More file actions
102 lines (89 loc) · 2.72 KB
/
inference_merge_base_depth.py
File metadata and controls
102 lines (89 loc) · 2.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import numpy as np
import torch
from PIL import Image
import os
import sys
from matplotlib import pyplot as plt
import argparse
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from merge.pipeline.merge_transformer import xTransformerModel, MERGEPixArtTransformer
from merge.pipeline.pipeline_merge import MERGEPixArtPipeline
# Colormap used to render the predicted depth map as an RGBA image.
# NOTE(review): `plt.get_cmap` is soft-deprecated in recent matplotlib;
# `matplotlib.colormaps['Spectral']` is the modern equivalent — confirm the
# pinned matplotlib version before migrating.
cmap = plt.get_cmap('Spectral')
def main(args):
    """Run the MERGE pipeline for two demos and save both to the CWD.

    First performs depth estimation on ``args.image_path`` (saved as
    ``./merge_base_depth_demo.png``), then text-to-image generation from
    ``args.prompt`` (saved as ``./merge_base_t2i_demo.png``).

    Args:
        args: Parsed CLI namespace with ``pretrained_model_path``,
            ``model_weights``, ``image_path``, ``prompt`` and
            ``denoising_step`` attributes.
    """
    dtype = torch.float32

    # Frozen base transformer from the pretrained text-to-image model.
    base = xTransformerModel.from_pretrained(
        args.pretrained_model_path,
        subfolder="transformer",
        torch_dtype=dtype,
    )
    base.requires_grad_(False)

    # Frozen depth-converter weights trained for the MERGE setup.
    converters = xTransformerModel.from_pretrained(
        args.model_weights,
        subfolder="depth_converters",
        torch_dtype=dtype,
    )
    converters.requires_grad_(False)

    merged = MERGEPixArtTransformer(base, converters)
    del base, converters  # the merged module holds them now; drop the extra refs

    pipe = MERGEPixArtPipeline.from_pretrained(
        args.pretrained_model_path,
        transformer=merged,
        torch_dtype=dtype,
        use_safetensors=True,
    ).to("cuda")

    # --- Depth estimation demo ---
    src = Image.open(args.image_path)
    w, h = src.size
    depth = pipe(
        image=src,
        prompt='',
        num_inference_steps=args.denoising_step,
        height=h,
        width=w,
        mode='merge',
    ).images
    # Collapse the channel dimension, move to CPU, colorize, and save as RGBA.
    depth = torch.mean(depth, dim=1).squeeze().cpu().numpy()
    depth = (cmap(depth) * 255).astype(np.uint8)
    Image.fromarray(depth).save("./merge_base_depth_demo.png")

    # --- Text-to-image demo ---
    generated = pipe(
        prompt=args.prompt,
        num_inference_steps=args.denoising_step,
        guidance_scale=4.5,
        mode='t2i',
    ).images[0]
    generated.save("./merge_base_t2i_demo.png")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="MERGE base-model inference demo: depth estimation plus text-to-image."
    )
    parser.add_argument(
        "--pretrained_model_path",
        type=str,
        required=True,
        help="Path to pretrained text-to-image model.",
    )
    parser.add_argument(
        "--model_weights",
        type=str,
        required=True,
        help="Path to converter weight.",
    )
    parser.add_argument(
        "--image_path",
        type=str,
        required=True,
        help="Path to input image.",
    )
    # `required=False` is implied by providing a default, so it is omitted.
    parser.add_argument(
        "--prompt",
        type=str,
        default='an apple',  # fixed grammar typo in the default ('a apple')
        help="Prompt for text-to-image.",
    )
    parser.add_argument(
        "--denoising_step",
        type=int,
        default=20,
        help="Denoising step.",
    )
    args = parser.parse_args()
    main(args)