-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_matchcut.py
More file actions
72 lines (58 loc) · 2.61 KB
/
generate_matchcut.py
File metadata and controls
72 lines (58 loc) · 2.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import argparse
from pathlib import Path
import torch
from diffusers import CogVideoXPipeline
from samplers.samplers_cog import sample_match_cog
from utils import add_args, save_metadata, save_video_mc
import gc # garbage collection
# Parse args
parser = argparse.ArgumentParser()
parser = add_args(parser)
args = parser.parse_args()

# One output directory per (run name, model size) pair.
save_dir = Path(args.save_dir) / f"{args.name}_{args.model_size}"
save_dir.mkdir(exist_ok=True, parents=True)

# Fail fast on an unsupported size: the previous if/elif chain had no
# else branch, so anything other than '2b'/'5b' left `cog_pipeline`
# undefined and surfaced later as a confusing NameError.
if args.model_size not in ("2b", "5b"):
    raise ValueError(
        f"Unsupported model_size {args.model_size!r}; expected '2b' or '5b'"
    )

model_id = f"THUDM/CogVideoX-{args.model_size}"
print(f"Using model {model_id}")
cog_pipeline = CogVideoXPipeline.from_pretrained(
    model_id, torch_dtype=torch.float16
).to("cuda")
# Encode every prompt once. Each encode_prompt call yields a
# (positive, negative) pair of embedding tensors; accumulate the two
# halves separately and batch each along dim 0.
positive_embeds = []
negative_embeds = []
for prompt in args.prompts:
    pos, neg = cog_pipeline.encode_prompt(
        prompt,
        device="cuda",
        do_classifier_free_guidance=True,
        num_videos_per_prompt=1,
    )
    positive_embeds.append(pos)
    negative_embeds.append(neg)
prompt_embeds = torch.cat(positive_embeds)
negative_prompt_embeds = torch.cat(negative_embeds)  # These are just null embeds
# Save run configuration next to the generated samples for reproducibility.
save_metadata(args, save_dir)

for i in range(args.num_samples):
    # One deterministic seed per sample so any individual sample can be
    # regenerated independently of the others.
    generator = torch.manual_seed(args.seed + i)
    sample_dir = save_dir / f'{args.seed + i:04}'
    sample_dir.mkdir(exist_ok=True, parents=True)
    videos = sample_match_cog(cog_pipeline,
                              prompt_embeds,
                              negative_prompt_embeds,
                              num_inference_steps=args.num_inference_steps,
                              num_joint_steps=args.num_joint_steps,
                              guidance_scale=args.guidance_scale,
                              generator=generator,
                              initial_lambda_1_1=args.initial_lambda_1_1,
                              final_lambda_1_1=args.final_lambda_1_1,
                              initial_lambda_2_2=args.initial_lambda_2_2,
                              final_lambda_2_2=args.final_lambda_2_2,
                              lambda_schedule=args.lambda_schedule,
                              )
    save_video_mc(videos, sample_dir)

    # Free per-sample GPU memory so long runs don't accumulate and OOM.
    # Drop the Python reference and run gc.collect() BEFORE empty_cache():
    # empty_cache() can only return cached blocks whose tensors are no
    # longer referenced, so collecting first (the original order was
    # reversed) lets it actually release cycle-held allocations.
    del videos
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()