-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmeasure_model.py
More file actions
120 lines (104 loc) · 3.87 KB
/
measure_model.py
File metadata and controls
120 lines (104 loc) · 3.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from network.model import get_model
from thop import profile
import torch
from torch.utils.flop_counter import FlopCounterMode
import torch.nn as nn
import click
from zeus.monitor import ZeusMonitor
import json
def get_datum(dataset, B: int = 1):
    """Return a random input batch shaped like one sample of *dataset*.

    Args:
        dataset: dataset identifier string.
        B: batch size (default 1).

    Raises:
        ValueError: if the dataset name is not recognized.
    """
    grayscale_64px = {"fetalAbdominal", "XRayMimic", "XRayMimic200"}
    rgb_96px = {"crc"}
    if dataset in grayscale_64px:
        return torch.randn(B, 1, 64, 64)
    if dataset in rgb_96px:
        return torch.randn(B, 3, 96, 96)
    raise ValueError(f"Unknown dataset {dataset}")
def get_flops(model: nn.Module, inp, train):
model.train(train)
inp = inp if isinstance(inp, torch.Tensor) else torch.randn(inp)
flop_counter = FlopCounterMode(mods=model, display=False, depth=None)
with flop_counter:
if train:
model(inp).sum().backward()
else:
with torch.no_grad():
model(inp)
total_flops = flop_counter.get_total_flops()
return {"flops": total_flops}
def get_macs_and_params(model, inp):
    """Profile multiply-accumulate ops and parameter count with thop.

    Runs in eval mode under ``no_grad``; returns a dict with keys
    ``"macs"`` and ``"params"``.
    """
    model.eval()
    with torch.no_grad():
        counts = profile(model, inputs=(inp,), verbose=False)
    macs, params = counts
    return {"macs": macs, "params": params}
def get_peak_mem_consumption(model: nn.Module, inp, train):
    """Peak CUDA memory (bytes) of one forward (and backward + optimizer step, if *train*).

    Resets the CUDA peak-memory counter first, so the number reflects only
    this call. Assumes model and input already live on the GPU.
    """
    torch.cuda.reset_peak_memory_stats()
    model.train(train)
    if not train:
        with torch.no_grad():
            model(inp)
    else:
        # Include optimizer state in the measurement, like a real training step.
        optimizer = torch.optim.Adam(model.parameters())
        model(inp).sum().backward()
        optimizer.step()
        optimizer.zero_grad()
    # Make sure all queued kernels have run before reading the counter.
    torch.cuda.synchronize()
    return {"mem": torch.cuda.max_memory_allocated()}
def get_time(model: nn.Module, inp, train):
    """Wall-clock time (ms) for 100 forward passes (plus backward + optimizer
    step each iteration, if *train*), measured with CUDA events.

    Returns:
        dict with a single key ``"time"`` (milliseconds).
    """
    model.train(train)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    # Fix: drain previously queued GPU work (e.g. the .cuda() transfers done
    # by the caller) so it is not billed to this measurement window.
    torch.cuda.synchronize()
    start.record()
    if train:
        optimizer = torch.optim.Adam(model.parameters())
        for _ in range(100):
            optimizer.zero_grad()
            model(inp).sum().backward()
            optimizer.step()
    else:
        with torch.no_grad():
            for _ in range(100):
                model(inp)
    end.record()
    # elapsed_time() is only valid once both events have actually completed.
    torch.cuda.synchronize()
    return {"time": start.elapsed_time(end)}
def get_energy_consumption(model: nn.Module, inp, train):
    """GPU energy used by 100 forward passes (plus backward + optimizer step
    each iteration, if *train*), measured with ZeusMonitor.

    Returns:
        dict with a single key ``"energy"`` (energy of the first monitored GPU).
    """
    model.train(train)
    monitor = ZeusMonitor(gpu_indices=[torch.cuda.current_device()])
    monitor.begin_window("epoch")
    if not train:
        with torch.no_grad():
            for _ in range(100):
                model(inp)
    else:
        optimizer = torch.optim.Adam(model.parameters())
        for _ in range(100):
            optimizer.zero_grad()
            loss = model(inp).sum()
            loss.backward()
            optimizer.step()
    measurement = monitor.end_window("epoch")
    return {"energy": measurement.gpu_energy[0]}
@click.command()
@click.option('--command', 'command', help='Command to run', required=True)
@click.option('--data', 'data_type', help='Data type', required=True)
@click.option('--model', 'model_type', default='mednca', help='Model type', required=True)
@click.option('--mode', 'mode', default='train', help='Mode (train/eval)', required=True)
def main(command, data_type, model_type, mode):
    """Run one measurement (macs/flops/mem/time/energy) and print the result as JSON."""
    # The script expects a fresh process; leftover allocations would skew
    # the memory measurements.
    assert torch.cuda.memory_allocated(0) == 0, f"Expected no memory allocated on GPU 0, but found {torch.cuda.memory_allocated(0)}"
    # Batch size 1 for eval, 4 for train.
    batch = 1 if mode == "eval" else 4
    inp = get_datum(data_type, B=batch).cuda()
    model = get_model(data_type, model_type).cuda()
    train = mode == "train"
    # Dispatch table replaces the if/elif chain; lambdas defer execution
    # until after the command name has been validated.
    measurements = {
        "macs": lambda: get_macs_and_params(model, inp),
        "flops": lambda: get_flops(model, inp, train),
        "mem": lambda: get_peak_mem_consumption(model, inp, train),
        "time": lambda: get_time(model, inp, train),
        "energy": lambda: get_energy_consumption(model, inp, train),
    }
    if command not in measurements:
        raise ValueError(f"Unknown command: {command}")
    print(json.dumps(measurements[command]()))
# Entry point: click parses CLI arguments and invokes main().
if __name__ == "__main__":
    main()