-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTest.py
More file actions
63 lines (49 loc) · 2.35 KB
/
Test.py
File metadata and controls
63 lines (49 loc) · 2.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import torch.nn.functional as F
import torch.utils.data as data
import os, argparse
from lib.FINet import FINet
from utils.dataloader import test_dataset
import cv2
def main():
    """Run FINet on a test split and write one prediction image per sample.

    Command-line options select the dataset location, the checkpoint to load
    and the output directory.  For every sample yielded by the loader, the
    first of the five model outputs is resized to the ground-truth
    resolution, passed through a sigmoid, min-max normalised, scaled by 255
    and saved with ``cv2.imwrite`` under the sample's original file name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_size', type=int, default=224, help='testing size')
    parser.add_argument('--dataset_root', type=str,
                        default='/opt/data/private/datasets/ultrasound',
                        help='root directory of the datasets')
    parser.add_argument('--dataset_name', type=str,
                        default='breast/BUSI1',
                        help='dataset')
    parser.add_argument('--pth_path', type=str,
                        default='./model_pth/breast/BUSI1/model.pth')
    parser.add_argument('--test_save', type=str,
                        default='./result_map/breast/BUSI1/FINet')
    parser.add_argument('--batch_size', type=int,
                        default=1,
                        help='testing batch size (inference always uses 1)')
    parser.add_argument('--num_workers', type=int,
                        default=8, help='num_workers')
    parser.add_argument('--num_classes', default=1, type=int,
                        help='number of classes')
    config = vars(parser.parse_args())

    # build models
    torch.cuda.set_device(0)  # set your gpu device
    model = FINet().cuda()
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load(config['pth_path'], map_location='cpu'))
    model.eval()

    test_path = os.path.join(config['dataset_root'], config['dataset_name'], 'test')
    dataset = test_dataset(test_path, config['image_size'])
    # batch_size stays at 1: the loop below squeezes the prediction to a
    # single map and builds one file name per iteration.
    test_loader = data.DataLoader(dataset=dataset,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=config['num_workers'])

    # Create the output directory once instead of checking on every sample.
    os.makedirs(config['test_save'], exist_ok=True)

    # Inference only: disable autograd so no graph is kept per forward pass.
    with torch.no_grad():
        for image, gt, name in test_loader:
            image = image.cuda()
            # Only the shape of gt is needed, so it is left on the CPU.
            target_size = (gt.shape[-2], gt.shape[-1])

            P, _, _, _, _ = model(image)
            res = F.interpolate(P, size=target_size, mode='bilinear', align_corners=False)
            res = res.sigmoid().cpu().numpy().squeeze()
            # Min-max normalise; the epsilon guards against a constant map.
            res = (res - res.min()) / (res.max() - res.min() + 1e-8)

            # name arrives as a one-element batch of strings; keep the base name.
            file_name = ''.join(name).split('/')[-1]
            cv2.imwrite(os.path.join(config['test_save'], file_name), res * 255)

    print('Finish!')


if __name__ == '__main__':
    main()