-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpredict.py
More file actions
110 lines (92 loc) · 2.88 KB
/
predict.py
File metadata and controls
110 lines (92 loc) · 2.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import argparse
import logging
import os
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from eval import eval_net
from normal_net import NormalNet
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from utils.dataset import SynthiaDataset
def get_args():
    """Build the prediction CLI and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``model`` (checkpoint path),
        ``dataset`` (dataset name), ``save`` (bool), ``scale`` (float) and
        ``resnet`` (bool).
    """
    cli = argparse.ArgumentParser(
        description="Predict normal maps for test images",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # (option strings, add_argument keyword arguments), in help-display order.
    option_specs = [
        (
            ("--model", "-m"),
            {
                "default": "MODEL.pth",
                "metavar": "FILE",
                "help": "Specify the file in which the model is stored",
            },
        ),
        (
            ("-d", "--dataset"),
            {
                "dest": "dataset",
                "type": str,
                "default": "synthia",
                "help": "Dataset to be used",
            },
        ),
        (
            ("--save", "-sv"),
            {
                "dest": "save",
                "action": "store_true",
                "help": "Save the results",
                "default": False,
            },
        ),
        (
            ("--scale", "-s"),
            {
                "type": float,
                "help": "Scale factor for the input images",
                "default": 0.5,
            },
        ),
        (
            ("--resnet", "-res"),
            {
                "dest": "resnet",
                "action": "store_true",
                "help": "Use pre-trained resnet encoder.",
                "default": False,
            },
        ),
    ]
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    return cli.parse_args()
def load_matching_weights(net, path, device):
    """Copy the checkpoint tensors whose names ``net`` also has into ``net``.

    The checkpoint at *path* is read with ``torch.load`` and mapped onto
    *device*. Checkpoint keys that ``net.state_dict()`` does not contain are
    dropped. Entries of ``net`` that the checkpoint lacks keep their current
    values, because the *merged* state dict is loaded, so the strict key check
    in ``load_state_dict`` does not reject them. A matching key whose tensor
    shape differs still makes ``load_state_dict`` raise.

    Args:
        net: the ``torch.nn.Module`` to update in place.
        path: checkpoint path (or file-like object) accepted by ``torch.load``.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        Sorted list of the state-dict keys taken from the checkpoint.
    """
    model_dict = net.state_dict()
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    pretrained_dict = torch.load(path, map_location=device)
    # 1. keep only the entries the current architecture knows about
    matched = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite those entries in the existing state dict and load the merged
    #    dict (loading `matched` alone fails strict loading on any missing key)
    model_dict.update(matched)
    net.load_state_dict(model_dict)
    return sorted(matched)


def main():
    """Evaluate a trained NormalNet checkpoint on the test split of the dataset."""
    args = get_args()
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    if args.dataset != "synthia":
        # Previously an unknown name fell through and crashed later with a
        # NameError on `test_loader`; fail early with a clear message instead.
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; only 'synthia' is implemented"
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device {device}")

    rgb_dir = "data/synthia/rgb/{}/"
    normals_dir = "data/synthia/normals/{}/"
    seg_masks_dir = "data/synthia/seg_masks/{}/"
    test_dataset = SynthiaDataset(
        rgb_dir.format("test"),
        normals_dir.format("test"),
        seg_masks_dir.format("test"),
        args.scale,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        # Pinned host memory only speeds up host->GPU copies; skip it on CPU.
        pin_memory=device.type == "cuda",
    )

    # NOTE(review): assumes args.model begins with a 5-character directory
    # prefix such as "ckpt/" -- TODO confirm. With the default "MODEL.pth"
    # this yields ".pth".
    ckpt = args.model[5:]
    if args.save:
        os.makedirs(f"results/{ckpt}", exist_ok=True)  # for saving results

    net = NormalNet(discriminator=False, bilinear=True, resnet=args.resnet)
    logging.info("Loading model {}".format(args.model))
    net.to(device=device)
    load_matching_weights(net, args.model, device)

    logging.info("Model loaded. Starting evaluation.")
    eval_net(net, test_loader, device, writer=None, ckpt=ckpt, save=args.save)


if __name__ == "__main__":
    main()