-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsimulate.py
More file actions
339 lines (289 loc) · 12 KB
/
simulate.py
File metadata and controls
339 lines (289 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
"""Real-time simulation of the trained highway driving agent using Pygame."""
# Suppress warnings
import warnings
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', module='gymnasium')
import pygame
import gymnasium as gym
import highway_env
from stable_baselines3 import DQN
import argparse
import os
import torch
import zipfile
from configs.env_config import ENV_CONFIG, TRAIN_CONFIG # Import ENV_CONFIG and TRAIN_CONFIG
# --- Display constants -------------------------------------------------------
# Window size in pixels and the frame-rate cap used by the render loop.
SCREEN_WIDTH, SCREEN_HEIGHT = 800, 600
FPS = 30

# RGB fill used to clear the window each frame (light gray).
BACKGROUND_COLOR = (220, 220, 220)

# RGB fill per vehicle role.
VEHICLE_COLORS = dict(
    ego=(46, 204, 113),      # green: the ego agent
    other=(231, 76, 60),     # red: other vehicles
    merged=(155, 89, 182),   # purple: merging vehicles
)
class HighwaySimulator:
    """Replay a trained DQN agent on ``highway-v0`` in a Pygame window.

    The simulator builds the environment from ``ENV_CONFIG``, restores the
    policy weights stored inside a model ZIP, and then draws a simplified view
    (lane lines, one rectangle per vehicle, running statistics) while the
    agent drives.
    """

    # Number of lanes assumed by the lane-line drawing and the y-axis mapping.
    LANE_COUNT = 4

    def __init__(self, model_path: str):
        """Open the window, build the environment and load the policy weights.

        Args:
            model_path: Path to the model ZIP that contains the policy weights.

        Raises:
            RuntimeError: If no policy weight file is found inside the ZIP.
            Exception: Any other loading error is printed and re-raised after
                the Pygame window and the environment have been released.
        """
        # Window and frame clock.
        pygame.init()
        self.screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
        pygame.display.set_caption("Highway Driving Simulation")
        self.clock = pygame.time.Clock()

        # Explicit registration is only called when the installed highway-env
        # exposes it. NOTE(review): newer releases appear to register their
        # environments on import - confirm against the pinned version.
        register = getattr(highway_env, 'register_highway_envs', None)
        if register is not None:
            register()

        self.env = gym.make('highway-v0', render_mode='rgb_array')
        # gym.make() returns a wrapped environment; configure() lives on the
        # underlying highway-env object. A configuration change only takes
        # effect on reset(), so reset here to make sure the spaces seen by the
        # DQN constructor below already reflect ENV_CONFIG.
        self.env.unwrapped.configure(ENV_CONFIG)
        self.env.reset()

        # Load the trained model (for evaluation only).
        print(f"\nLoading model from {model_path}")
        try:
            self.model = self._build_model()
            self.model.policy.load_state_dict(
                self._read_policy_state_dict(model_path)
            )
            # Evaluation mode: no training-time behaviour in the network.
            self.model.policy.eval()
            print("Model loaded successfully!\n")
        except Exception as e:
            print(f"\nError loading model: {e}")
            # Do not leave the window / environment open on a failed start.
            self._shutdown()
            raise

        # Font used by draw_stats().
        self.font = pygame.font.Font(None, 36)

        # Running statistics shown in the overlay.
        self.episode_reward = 0.0
        self.steps = 0
        self.collisions = 0

        # Viewport and lane-drawing geometry (pixels).
        self.viewport_width = SCREEN_WIDTH * 0.8     # 80% of the width for the road
        self.viewport_offset_x = SCREEN_WIDTH * 0.1  # 10% left margin
        self.lane_height = SCREEN_HEIGHT * 0.15      # height of one lane stripe
        self.top_margin = SCREEN_HEIGHT * 0.2        # y of the first lane line

    def _build_model(self):
        """Create a DQN shell whose policy architecture matches training."""
        return DQN(
            policy="MlpPolicy",
            env=self.env,
            learning_rate=0.0,   # Dummy, not used
            buffer_size=1,       # Dummy, not used
            learning_starts=0,   # Dummy, not used
            batch_size=1,        # Dummy, not used
            policy_kwargs=TRAIN_CONFIG['policy_kwargs'],  # Must match training
            verbose=0
        )

    @staticmethod
    def _select_policy_file(names):
        """Pick the archive member that holds the policy weights.

        Args:
            names: Member names of the model ZIP.

        Returns:
            ``'policy.pth'`` when present, otherwise the first ``.pth`` member
            whose name contains ``'policy'`` but not ``'optimizer'``.

        Raises:
            RuntimeError: If no such member exists.
        """
        if 'policy.pth' in names:
            return 'policy.pth'
        # An optimizer-state file such as "policy.optimizer.pth" also contains
        # "policy" in its name but does not hold network weights, so skip it.
        candidates = [
            n for n in names
            if n.endswith('.pth') and 'policy' in n and 'optimizer' not in n
        ]
        if not candidates:
            raise RuntimeError("No policy file found inside the model ZIP.")
        return candidates[0]

    def _read_policy_state_dict(self, model_path):
        """Read the policy state-dict straight out of the model ZIP."""
        with zipfile.ZipFile(model_path, 'r') as model_zip:
            policy_name = self._select_policy_file(model_zip.namelist())
            with model_zip.open(policy_name) as policy_file:
                # NOTE(review): torch.load unpickles the file contents; only
                # load model archives from a trusted source.
                return torch.load(policy_file, map_location='cpu')

    def _shutdown(self):
        """Release the Pygame window and the environment."""
        pygame.quit()
        self.env.close()

    def draw_vehicle(
        self,
        position: tuple,
        color: tuple,
        size: tuple = (40, 20),
        is_merging: bool = False,
        velocity: tuple = (0, 0)
    ):
        """Draw one vehicle as a filled, outlined rectangle.

        Args:
            position: (x_screen, y_screen) centre of the rectangle in pixels.
            color: (R, G, B) fill colour.
            size: (width, height) of the rectangle in pixels.
            is_merging: When true, a yellow triangle is drawn on top.
            velocity: (vx, vy); only the sign of ``vy`` is used, to choose
                whether the triangle points down (``vy > 0``) or up.
        """
        x_screen, y_screen = position
        w, h = size
        rect = pygame.Rect(x_screen - w // 2, y_screen - h // 2, w, h)
        pygame.draw.rect(self.screen, color, rect)
        pygame.draw.rect(self.screen, (0, 0, 0), rect, 1)  # 1px black outline

        if not is_merging:
            return

        # The triangle's tip sits on the bottom edge for downward motion and
        # on the top edge otherwise; its base runs through the centre line.
        half_h = h // 2
        tip_y = y_screen + half_h if velocity[1] > 0 else y_screen - half_h
        pygame.draw.polygon(
            self.screen,
            (255, 255, 0),  # Yellow arrow
            [
                (x_screen, tip_y),
                (x_screen - w // 4, y_screen),
                (x_screen + w // 4, y_screen)
            ]
        )

    def draw_stats(self):
        """Draw the current statistics (steps, reward, collisions)."""
        stats = [
            f"Steps: {self.steps}",
            f"Reward: {self.episode_reward:.1f}",
            f"Collisions: {self.collisions}"
        ]
        for i, text in enumerate(stats):
            surface = self.font.render(text, True, (0, 0, 0))
            self.screen.blit(surface, (10, 10 + i * 30))

    def convert_env_to_display(self, env_pos: tuple):
        """Map an observation-space (x, y) pair to integer screen pixels.

        The mapping is linear and treats both coordinates as fractions in
        [0, 1]:

        - ``x`` 0..1 -> ``viewport_offset_x`` .. ``viewport_offset_x + viewport_width``
        - ``y`` 0..1 -> ``top_margin`` .. ``top_margin + LANE_COUNT * lane_height``

        Pygame's y axis grows downwards, so ``y = 0`` lands on the top lane
        line and no flip is applied. Values outside [0, 1] are not clipped and
        land outside the road band.

        NOTE(review): highway-env's Kinematics observation documents a
        normalized range of [-1, 1] and, by default, positions of other
        vehicles relative to the ego vehicle. Confirm that ENV_CONFIG really
        yields absolute coordinates in [0, 1] as this mapping assumes.

        Args:
            env_pos: (x, y) as read from the observation.

        Returns:
            (x_screen, y_screen) truncated to ints.
        """
        raw_x, raw_y = env_pos
        screen_x = self.viewport_offset_x + raw_x * self.viewport_width
        road_height = self.lane_height * self.LANE_COUNT
        screen_y = self.top_margin + raw_y * road_height
        return (int(screen_x), int(screen_y))

    def get_vehicle_positions(self, obs):
        """Split an observation into the ego position and the other vehicles.

        Each row of ``obs`` is assumed to start with the five features
        ``[presence, x, y, vx, vy]``; any further features in a row are
        ignored. The first row is taken to be the ego vehicle.

        Args:
            obs: Sequence of per-vehicle feature rows.

        Returns:
            ``(ego_pos, other_vehicles)`` where ``ego_pos`` is ``(x, y)`` and
            ``other_vehicles`` is a list of dicts with keys ``'position'``
            (``(x, y)``), ``'merging'`` (bool) and ``'velocity'``
            (``(vx, vy)``). Rows whose presence is below 0.5 are skipped.
        """
        # Slice to five features so richer observations do not break the
        # unpacking.
        _, ego_x, ego_y, _, _ = obs[0][:5]
        ego_pos = (ego_x, ego_y)

        other_vehicles = []
        for vehicle in obs[1:]:
            pres, x, y, vx, vy = vehicle[:5]
            if pres < 0.5:
                # Empty vehicle slot.
                continue
            other_vehicles.append({
                'position': (x, y),
                # A noticeable lateral speed is treated as a lane change.
                'merging': abs(vy) > 0.05,
                'velocity': (vx, vy)
            })
        return ego_pos, other_vehicles

    def _reset_episode_stats(self, include_collisions: bool):
        """Zero the per-episode counters, and optionally the collision count."""
        self.episode_reward = 0.0
        self.steps = 0
        if include_collisions:
            self.collisions = 0

    def _render_frame(self, obs):
        """Draw lane lines, all vehicles and the statistics, then flip."""
        self.screen.fill(BACKGROUND_COLOR)

        # LANE_COUNT lanes are bounded by LANE_COUNT + 1 horizontal lines.
        for i in range(self.LANE_COUNT + 1):
            y_line = int(self.top_margin + i * self.lane_height)
            pygame.draw.line(
                self.screen,
                (128, 128, 128),
                (0, y_line),
                (SCREEN_WIDTH, y_line),
                2
            )

        ego_pos, other_vehicles = self.get_vehicle_positions(obs)
        self.draw_vehicle(
            self.convert_env_to_display(ego_pos),
            VEHICLE_COLORS['ego'],
            is_merging=False
        )
        for v in other_vehicles:
            color = VEHICLE_COLORS['merged'] if v['merging'] else VEHICLE_COLORS['other']
            self.draw_vehicle(
                self.convert_env_to_display(v['position']),
                color,
                is_merging=v['merging'],
                velocity=v['velocity']
            )

        self.draw_stats()
        pygame.display.flip()

    def run(self):
        """Run the interactive loop until the window is closed.

        Each frame: handle events, let the policy pick an action, step the
        environment, update the statistics and redraw. Pressing ``R`` restarts
        the episode and clears every counter. When an episode ends on its own
        the environment is reset and the step/reward counters are cleared, but
        the collision counter keeps accumulating; clearing it there would wipe
        a crash in the same frame it was recorded. The window and the
        environment are released on exit, including when the loop raises.
        """
        # Gymnasium's reset returns (obs, info).
        obs, info = self.env.reset()
        running = True
        try:
            while running:
                # 1) Event handling
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        running = False
                    elif event.type == pygame.KEYDOWN and event.key == pygame.K_r:
                        # Manual restart: reset the episode and all counters.
                        obs, info = self.env.reset()
                        self._reset_episode_stats(include_collisions=True)
                if not running:
                    # Do not step or draw another frame once quit is requested.
                    break

                # 2) Deterministic action from the trained policy
                action, _ = self.model.predict(obs, deterministic=True)

                # 3) Step environment
                obs, reward, terminated, truncated, info = self.env.step(action)

                # 4) Update stats
                self.episode_reward += float(reward)
                self.steps += 1
                # The collision flag may be reported as "crashed" or "collision".
                if info.get('crashed', False) or info.get('collision', False):
                    self.collisions += 1

                # 5) Draw the frame and cap the frame rate
                self._render_frame(obs)
                self.clock.tick(FPS)

                # 6) If the episode ended, start a new one
                if terminated or truncated:
                    obs, info = self.env.reset()
                    self._reset_episode_stats(include_collisions=False)
        finally:
            self._shutdown()
def main():
    """Parse the command line and run the simulator.

    Expects ``--model PATH`` pointing at the trained model ZIP.

    Raises:
        SystemExit: With status 1 when ``PATH`` is not an existing file, so
            that callers and shells can detect the failure. argparse itself
            exits on malformed arguments.
    """
    parser = argparse.ArgumentParser(description='Highway Driving Simulation')
    parser.add_argument('--model', type=str, required=True, help='Path to trained model zip')
    args = parser.parse_args()

    # isfile() also rejects directories, which cannot be opened as a ZIP.
    if not os.path.isfile(args.model):
        print(f"Error: Model file not found at {args.model}")
        raise SystemExit(1)

    simulator = HighwaySimulator(args.model)
    simulator.run()


if __name__ == "__main__":
    main()