-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_yolo.py
More file actions
1761 lines (1386 loc) · 62.7 KB
/
train_yolo.py
File metadata and controls
1761 lines (1386 loc) · 62.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Implémentation corrigée et complète de YOLOv8 en PyTorch
"""
import math
from typing import List, Tuple, Optional, Dict, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from yolov8 import YOLOv8
''''
class Conv(nn.Module):
"""
Couche de convolution standard avec batch normalization et activation SiLU
"""
def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 1,
stride: int = 1, padding: Optional[int] = None, groups: int = 1,
activation: bool = True):
super().__init__()
if padding is None:
padding = kernel_size // 2
self.conv = nn.Conv2d(
in_channels, out_channels, kernel_size, stride,
padding, groups=groups, bias=False
)
self.bn = nn.BatchNorm2d(out_channels)
self.act = nn.SiLU() if activation else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.act(self.bn(self.conv(x)))
class Bottleneck(nn.Module):
"""
Bloc bottleneck résiduel
"""
def __init__(self, in_channels: int, out_channels: int, shortcut: bool = True,
groups: int = 1, expansion: float = 0.5):
super().__init__()
hidden_channels = int(out_channels * expansion)
self.conv1 = Conv(in_channels, hidden_channels, 1, 1)
self.conv2 = Conv(hidden_channels, out_channels, 3, 1, groups=groups)
self.add = shortcut and in_channels == out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.conv2(self.conv1(x)) if self.add else self.conv2(self.conv1(x))
class C2f(nn.Module):
"""
Bloc CSP (Cross Stage Partial) avec connexions multiples
Version améliorée pour YOLOv8
"""
def __init__(self, in_channels: int, out_channels: int, n: int = 1,
shortcut: bool = False, groups: int = 1, expansion: float = 0.5):
super().__init__()
self.c = int(out_channels * expansion)
self.cv1 = Conv(in_channels, 2 * self.c, 1, 1)
self.cv2 = Conv((2 + n) * self.c, out_channels, 1)
self.m = nn.ModuleList(
Bottleneck(self.c, self.c, shortcut, groups, expansion=1.0)
for _ in range(n)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = list(self.cv1(x).chunk(2, 1))
y.extend(m(y[-1]) for m in self.m)
return self.cv2(torch.cat(y, 1))
class SPPF(nn.Module):
"""
Spatial Pyramid Pooling Fast
Pooling à plusieurs échelles pour capturer le contexte
"""
def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 5):
super().__init__()
c_ = in_channels // 2
self.cv1 = Conv(in_channels, c_, 1, 1)
self.cv2 = Conv(c_ * 4, out_channels, 1, 1)
self.m = nn.MaxPool2d(kernel_size=kernel_size, stride=1,
padding=kernel_size // 2)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.cv1(x)
y1 = self.m(x)
y2 = self.m(y1)
y3 = self.m(y2)
return self.cv2(torch.cat([x, y1, y2, y3], 1))
class Backbone(nn.Module):
"""
Backbone CSPDarknet pour YOLOv8 - Version corrigée
"""
def __init__(self, base_channels: int = 64, base_depth: int = 3,
deep_mul: float = 0.33, width_mul: float = 0.25):
super().__init__()
# Facteurs d'échelle selon la version YOLOv8-n
ch = [base_channels, base_channels * 2, base_channels * 4,
base_channels * 8, base_channels * 16]
ch = [int(c * width_mul) for c in ch]
n = [max(round(base_depth * deep_mul), 1) for _ in range(4)]
# Couches initiales
self.stem = Conv(3, ch[0], 3, 2, 1)
# Blocs du backbone avec dimensions corrigées
self.dark2 = nn.Sequential(
Conv(ch[0], ch[1], 3, 2),
C2f(ch[1], ch[1], n[0], True)
)
self.dark3 = nn.Sequential(
Conv(ch[1], ch[2], 3, 2),
C2f(ch[2], ch[2], n[1], True)
)
self.dark4 = nn.Sequential(
Conv(ch[2], ch[3], 3, 2),
C2f(ch[3], ch[3], n[2], True)
)
self.dark5 = nn.Sequential(
Conv(ch[3], ch[4], 3, 2),
C2f(ch[4], ch[4], n[3], True),
SPPF(ch[4], ch[4], 5)
)
# CORRECTION: Stockage des canaux P3, P4, P5
self.channels = [ch[2], ch[3], ch[4]] # [64, 128, 256] pour YOLOv8-n
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
x = self.stem(x)
x = self.dark2(x)
x3 = self.dark3(x) # Sortie à 1/8 de la résolution
x4 = self.dark4(x3) # Sortie à 1/16 de la résolution
x5 = self.dark5(x4) # Sortie à 1/32 de la résolution
return x3, x4, x5
class Neck(nn.Module):
"""
Neck PANet pour fusionner les caractéristiques multi-échelles - Version corrigée
"""
def __init__(self, channels: List[int], depth: float = 0.33,
width: float = 0.25):
super().__init__()
# CORRECTION: On utilise les channels TELS QUELS du backbone
# Pas de multiplication par width ici !
ch = channels # [64, 128, 256] pour YOLOv8-n
n = max(round(depth * 3), 1)
# Stockage des canaux
self.channels = ch
# Convolutions pour réduire les canaux avant la fusion
# P5 (256) -> P4 (128)
self.reduce1 = Conv(ch[2], ch[1], 1, 1) # CORRECTION: 256 -> 128
# P4 (128) -> P3 (64)
self.reduce2 = Conv(ch[1], ch[0], 1, 1) # CORRECTION: 128 -> 64
# UpSampling
self.up = nn.Upsample(scale_factor=2, mode='nearest')
# Connexions top-down (de haut en bas)
# P5 réduit (128) + P4 (128) = 256 -> 128
self.n1 = C2f(ch[1] * 2, ch[1], n, False)
# P4 réduit (64) + P3 (64) = 128 -> 64
self.n2 = C2f(ch[0] * 2, ch[0], n, False)
# Connexions bottom-up (de bas en haut)
self.n3 = Conv(ch[0], ch[0], 3, 2) # Downsample P3 (64 -> 64)
# P3 down (64) + P4 (128) = 192 -> 128
self.n4 = C2f(ch[0] + ch[1], ch[1], n, False)
self.n5 = Conv(ch[1], ch[1], 3, 2) # Downsample P4 (128 -> 128)
# P4 down (128) + P5 (256) = 384 -> 256
self.n6 = C2f(ch[1] + ch[2], ch[2], n, False)
def forward(self, features: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
x3, x4, x5 = features
# Top-down pathway
# Étape 1: P5 -> P4
p5_reduced = self.reduce1(x5) # 256 -> 128
p5_up = self.up(p5_reduced)
p4 = torch.cat([p5_up, x4], 1) # 128 + 128 = 256
p4 = self.n1(p4) # 256 -> 128
# Étape 2: P4 -> P3
p4_reduced = self.reduce2(p4) # 128 -> 64
p4_up = self.up(p4_reduced)
p3 = torch.cat([p4_up, x3], 1) # 64 + 64 = 128
p3 = self.n2(p3) # 128 -> 64
# Bottom-up pathway
# Étape 3: P3 -> P4
p3_down = self.n3(p3) # 64 -> 64 (downsample)
p4_cat = torch.cat([p3_down, p4], 1) # 64 + 128 = 192
p4_out = self.n4(p4_cat) # 192 -> 128
# Étape 4: P4 -> P5
p4_down = self.n5(p4_out) # 128 -> 128 (downsample)
p5_cat = torch.cat([p4_down, x5], 1) # 128 + 256 = 384
p5_out = self.n6(p5_cat) # 384 -> 256
return p3, p4_out, p5_out
class DetectionHead(nn.Module):
"""
Tête de détection anchor-free pour YOLOv8 - Version corrigée
"""
def __init__(self, num_classes: int = 80, channels: List[int] = None,
width: float = 0.25):
super().__init__()
if channels is None:
channels = [256, 512, 1024]
# CORRECTION: On utilise les channels TELS QUELS
ch = channels # [64, 128, 256] pour YOLOv8-n
self.num_classes = num_classes
# Paramètres de configuration
self.reg_max = 16 # Utilisé pour la prédiction DFL
self.nc = num_classes
self.no = num_classes + 4 * self.reg_max + 1 # Correction: 4 * reg_max pour DFL
# Convolutions pour chaque échelle
self.cv2 = nn.ModuleList()
for i in range(3):
self.cv2.append(
nn.Sequential(
Conv(ch[i], ch[i], 3), # CORRECTION: ch[i] au lieu de int(ch[i] * width)
Conv(ch[i], ch[i], 3),
nn.Conv2d(ch[i], self.no, 1)
)
)
# Initialisation des poids
self._initialize_weights()
def _initialize_weights(self):
"""Initialisation des poids pour améliorer la convergence"""
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, features: Tuple[torch.Tensor, ...]) -> List[torch.Tensor]:
"""
Forward pass de la tête de détection
Args:
features: Tuple de 3 tenseurs (P3, P4, P5)
Returns:
List de tenseurs de prédictions pour chaque échelle
"""
outputs = []
for i, x in enumerate(features):
output = self.cv2[i](x)
outputs.append(output)
return outputs
class YOLOv8(nn.Module):
"""
Modèle complet YOLOv8 - Version corrigée
"""
def __init__(self, num_classes: int = 80, version: str = 'n',
pretrained: bool = False):
super().__init__()
# Configuration selon la version YOLOv8
configs = {
'n': {'depth': 0.33, 'width': 0.25, 'base_channels': 64},
's': {'depth': 0.33, 'width': 0.50, 'base_channels': 64},
'm': {'depth': 0.67, 'width': 0.75, 'base_channels': 64},
'l': {'depth': 1.00, 'width': 1.00, 'base_channels': 64},
'x': {'depth': 1.00, 'width': 1.25, 'base_channels': 64}
}
config = configs.get(version, configs['n'])
# Initialisation des composants avec dimensions cohérentes
self.backbone = Backbone(
base_channels=config['base_channels'],
base_depth=3,
deep_mul=config['depth'],
width_mul=config['width']
)
# Les canaux viennent directement du backbone
channels = self.backbone.channels # [64, 128, 256] pour YOLOv8-n
self.neck = Neck(
channels=channels, # Passage direct des canaux
depth=config['depth'],
width=config['width'] # Toujours passé pour la configuration interne
)
self.head = DetectionHead(
num_classes=num_classes,
channels=channels, # Mêmes canaux
width=config['width']
)
self.num_classes = num_classes
self.version = version
if pretrained:
self._load_pretrained_weights()
def _load_pretrained_weights(self):
"""Chargement des poids pré-entraînés (placeholder)"""
print("Note: Chargement des poids pré-entraînés non implémenté dans cette version")
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
"""
Forward pass complet
Args:
x: Image d'entrée [batch_size, 3, height, width]
Returns:
List de prédictions pour chaque échelle
"""
# Backbone
features = self.backbone(x)
# Neck
neck_features = self.neck(features)
# Head
predictions = self.head(neck_features)
return predictions
'''
class YOLOv8Loss(nn.Module):
    """
    Complete loss function for YOLOv8.

    Combines, per prediction scale:
      * classification — BCE-with-logits on per-class one-hot targets
        (positive cells only),
      * box regression — IoU loss on DFL-decoded boxes plus a
        Distribution Focal Loss (DFL) term,
      * objectness — BCE-with-logits over every grid cell.

    Expected head channel layout per cell: [4*reg_max DFL | 1 obj | num_classes].
    """
    def __init__(self, num_classes: int = 80, reg_max: int = 16):
        super().__init__()
        self.num_classes = num_classes
        # Number of discrete bins of each DFL distribution (one per box side).
        self.reg_max = reg_max
        self.bce = nn.BCEWithLogitsLoss(reduction='none')
        self.use_dfl = True
        # Relative weights balancing the individual loss components.
        self.box_weight = 7.5
        self.cls_weight = 0.5
        self.dfl_weight = 1.5
        self.obj_weight = 1.0
    def bbox_iou(self, box1: torch.Tensor, box2: torch.Tensor,
                 xywh: bool = True, eps: float = 1e-7) -> torch.Tensor:
        """
        Compute the IoU (Intersection over Union) of paired boxes.

        Args:
            box1, box2: tensors with 4 box coordinates in the last dimension,
                broadcast element-wise against each other.
            xywh: if True, inputs are (center-x, center-y, w, h) and are
                converted to corner (xyxy) form first; otherwise xyxy.
            eps: small constant preventing division by zero.

        Returns:
            Element-wise IoU tensor.
        """
        if xywh:
            # Convert xywh -> xyxy
            b1_x1, b1_x2 = box1[..., 0] - box1[..., 2] / 2, box1[..., 0] + box1[..., 2] / 2
            b1_y1, b1_y2 = box1[..., 1] - box1[..., 3] / 2, box1[..., 1] + box1[..., 3] / 2
            b2_x1, b2_x2 = box2[..., 0] - box2[..., 2] / 2, box2[..., 0] + box2[..., 2] / 2
            b2_y1, b2_y2 = box2[..., 1] - box2[..., 3] / 2, box2[..., 1] + box2[..., 3] / 2
        else:
            b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
            b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        # Intersection (clamped at 0 for non-overlapping boxes)
        inter_x1 = torch.max(b1_x1, b2_x1)
        inter_y1 = torch.max(b1_y1, b2_y1)
        inter_x2 = torch.min(b1_x2, b2_x2)
        inter_y2 = torch.min(b1_y2, b2_y2)
        inter_area = (inter_x2 - inter_x1).clamp(0) * (inter_y2 - inter_y1).clamp(0)
        # Union
        b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
        b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
        union_area = b1_area + b2_area - inter_area + eps
        return inter_area / union_area
    def dfl_loss(self, pred_dist: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Distribution Focal Loss (DFL) for box regression.

        Each continuous target value is represented by its two neighbouring
        integer bins, weighted by linear interpolation; the loss is the
        weighted negative log-likelihood of those two bins.

        Args:
            pred_dist: [num_positives, 4 * reg_max] raw distribution logits.
            target: [num_positives, 4] continuous values in [0, reg_max].

        Returns:
            Scalar mean DFL loss over all positive positions.
        """
        # pred_dist: [num_positives, 4 * reg_max]
        # target: [num_positives, 4] (values between 0 and reg_max)
        num_positives = pred_dist.shape[0]
        # Reshape pred_dist to [num_positives, 4, reg_max]
        pred_dist = pred_dist.view(num_positives, 4, self.reg_max)
        # Softmax over the reg_max bins
        pred_dist = F.softmax(pred_dist, dim=-1)
        # Integer bins bracketing each continuous target
        tl = target.long()  # target left
        tr = tl + 1  # target right
        # Clamp indices so gather stays inside the distribution support.
        # NOTE(review): after clamping, targets at/above reg_max - 1 give
        # tl == tr, and the weights below can go slightly negative — confirm
        # targets are always < reg_max - 1 upstream.
        tl = tl.clamp(0, self.reg_max - 1)
        tr = tr.clamp(0, self.reg_max - 1)
        # Linear-interpolation weights
        wl = tr - target  # weight left (1 - (target - tl))
        wr = target - tl  # weight right (target - tl)
        # Gather the predicted probabilities at tl and tr
        # (gather needs matching rank on the last dimension)
        tl_expanded = tl.unsqueeze(-1)  # [num_positives, 4, 1]
        tr_expanded = tr.unsqueeze(-1)  # [num_positives, 4, 1]
        pred_tl = torch.gather(pred_dist, -1, tl_expanded).squeeze(-1)  # [num_positives, 4]
        pred_tr = torch.gather(pred_dist, -1, tr_expanded).squeeze(-1)  # [num_positives, 4]
        # DFL = -[wl * log(p_tl) + wr * log(p_tr)]; 1e-8 guards log(0)
        loss_left = wl * torch.log(pred_tl + 1e-8)
        loss_right = wr * torch.log(pred_tr + 1e-8)
        dfl_loss = -(loss_left + loss_right)
        # Mean over all positive positions
        return dfl_loss.mean()
    def build_targets(self, predictions: List[torch.Tensor],
                      targets: torch.Tensor,
                      img_size: Tuple[int, int] = (640, 640)) -> Dict[str, torch.Tensor]:
        """
        Build dense per-cell training targets from ground-truth boxes.

        Each ground-truth box is assigned to the single grid cell containing
        its center, at every scale.

        Args:
            predictions: raw head outputs, one [bs, C, ny, nx] tensor per scale.
            targets: [num_targets, 6] rows of (batch_idx, class, x, y, w, h),
                box coordinates normalized to [0, 1].
            img_size: network input size, used to derive each scale's stride.

        Returns:
            Dict with targets concatenated over scales along the anchor dim:
            'cls_target' [bs, A, num_classes], 'box_target' [bs, A, 4]
            (dx, dy cell offsets + normalized w, h), 'obj_target' [bs, A, 1],
            'dfl_target' [bs, A, 4] (box_target scaled by reg_max).
        """
        device = predictions[0].device
        batch_size = predictions[0].shape[0]
        # Per-scale accumulators
        cls_targets = []
        box_targets = []
        obj_targets = []
        dfl_targets = []
        # For each scale (P3, P4, P5)
        for scale_idx, pred in enumerate(predictions):
            bs, _, ny, nx = pred.shape
            stride = img_size[0] // ny  # 8, 16, 32
            # Build the cell-coordinate grid
            grid_y, grid_x = torch.meshgrid(
                torch.arange(ny, device=device),
                torch.arange(nx, device=device),
                indexing='ij'
            )
            grid = torch.stack([grid_x, grid_y], dim=-1).float()  # [ny, nx, 2]
            grid = grid.view(1, ny, nx, 2)  # [1, ny, nx, 2]
            # Zero-initialized targets for this scale
            scale_cls_target = torch.zeros(bs, ny, nx, self.num_classes, device=device)
            scale_box_target = torch.zeros(bs, ny, nx, 4, device=device)
            scale_obj_target = torch.zeros(bs, ny, nx, 1, device=device)
            scale_dfl_target = torch.zeros(bs, ny, nx, 4, device=device)
            # For each image of the batch
            for batch_idx in range(bs):
                # Ground-truth rows belonging to this image
                img_targets = targets[targets[:, 0] == batch_idx]
                if len(img_targets) == 0:
                    continue
                # Split out class ids and normalized boxes
                target_classes = img_targets[:, 1].long()  # classes
                target_boxes = img_targets[:, 2:6]  # x, y, w, h (normalized)
                # Convert to grid coordinates (scale by the grid size)
                grid_boxes = target_boxes.clone()
                grid_boxes[:, :2] = grid_boxes[:, :2] * torch.tensor([nx, ny], device=device)  # x, y in cells
                grid_boxes[:, 2:] = grid_boxes[:, 2:] * torch.tensor([nx, ny], device=device)  # w, h in cells
                # For each target of this image
                for t_idx in range(len(img_targets)):
                    cls = target_classes[t_idx]
                    box = grid_boxes[t_idx]
                    # Box center in grid coordinates
                    gx, gy = box[0], box[1]
                    gw, gh = box[2], box[3]
                    # Cell containing the center (truncation toward zero)
                    cell_x = int(gx.item())
                    cell_y = int(gy.item())
                    # Skip targets falling outside the grid
                    if not (0 <= cell_x < nx and 0 <= cell_y < ny):
                        continue
                    # Center offset inside the cell (between 0 and 1)
                    dx = gx - cell_x
                    dy = gy - cell_y
                    # Width/height normalized back to [0, 1]
                    nw = gw / nx
                    nh = gh / ny
                    # Write the assignment (later targets overwrite earlier
                    # ones landing in the same cell)
                    scale_cls_target[batch_idx, cell_y, cell_x, cls] = 1.0
                    scale_box_target[batch_idx, cell_y, cell_x] = torch.tensor([dx, dy, nw, nh], device=device)
                    scale_obj_target[batch_idx, cell_y, cell_x] = 1.0
                    # DFL target: same values scaled to [0, reg_max]
                    dfl_values = torch.tensor([dx, dy, nw, nh], device=device) * self.reg_max
                    scale_dfl_target[batch_idx, cell_y, cell_x] = dfl_values
            # Flatten spatial dims to match the prediction layout:
            # [bs, ny, nx, ...] -> [bs, ny*nx, ...]
            scale_cls_target = scale_cls_target.view(bs, ny * nx, self.num_classes)
            scale_box_target = scale_box_target.view(bs, ny * nx, 4)
            scale_obj_target = scale_obj_target.view(bs, ny * nx, 1)
            scale_dfl_target = scale_dfl_target.view(bs, ny * nx, 4)
            cls_targets.append(scale_cls_target)
            box_targets.append(scale_box_target)
            obj_targets.append(scale_obj_target)
            dfl_targets.append(scale_dfl_target)
        # Concatenate every scale along the anchor dimension
        cls_target = torch.cat(cls_targets, dim=1)  # [bs, total_anchors, num_classes]
        box_target = torch.cat(box_targets, dim=1)  # [bs, total_anchors, 4]
        obj_target = torch.cat(obj_targets, dim=1)  # [bs, total_anchors, 1]
        dfl_target = torch.cat(dfl_targets, dim=1)  # [bs, total_anchors, 4]
        return {
            'cls_target': cls_target,
            'box_target': box_target,
            'obj_target': obj_target,
            'dfl_target': dfl_target
        }
    def forward(self, predictions: List[torch.Tensor],
                targets: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Compute the total weighted loss over all scales.

        Args:
            predictions: raw head outputs, one [bs, C, ny, nx] tensor per
                scale with C = 4*reg_max + 1 + num_classes.
            targets: [num_targets, 6] ground-truth rows (see build_targets).

        Returns:
            Dict of scalar tensors: 'box_loss', 'cls_loss', 'obj_loss',
            'dfl_loss' (each already weighted) and their sum 'total_loss'.
        """
        device = predictions[0].device
        batch_size = predictions[0].shape[0]
        # Build the dense targets for every scale
        targets_dict = self.build_targets(predictions, targets)
        # Running (already weighted) loss accumulators
        total_box_loss = torch.tensor(0.0, device=device)
        total_cls_loss = torch.tensor(0.0, device=device)
        total_obj_loss = torch.tensor(0.0, device=device)
        total_dfl_loss = torch.tensor(0.0, device=device)
        # For each scale; start_idx tracks this scale's slice of the
        # concatenated anchor dimension
        start_idx = 0
        for scale_idx, pred in enumerate(predictions):
            bs, _, ny, nx = pred.shape
            num_anchors = ny * nx
            # Flatten predictions: [bs, C, ny, nx] -> [bs, num_anchors, C]
            pred = pred.view(bs, self.reg_max * 4 + 1 + self.num_classes, num_anchors)
            pred = pred.permute(0, 2, 1).contiguous()  # [bs, num_anchors, channels]
            # Channel layout: [4*reg_max DFL | 1 objectness | num_classes]
            pred_dfl = pred[..., :self.reg_max * 4]  # DFL predictions [bs, num_anchors, 4*reg_max]
            pred_obj = pred[..., self.reg_max * 4:self.reg_max * 4 + 1]  # Objectness [bs, num_anchors, 1]
            pred_cls = pred[..., self.reg_max * 4 + 1:]  # Classification [bs, num_anchors, num_classes]
            # Slice this scale's targets out of the concatenated tensors
            end_idx = start_idx + num_anchors
            scale_cls_target = targets_dict['cls_target'][:, start_idx:end_idx]  # [bs, num_anchors, num_classes]
            scale_box_target = targets_dict['box_target'][:, start_idx:end_idx]  # [bs, num_anchors, 4]
            scale_obj_target = targets_dict['obj_target'][:, start_idx:end_idx]  # [bs, num_anchors, 1]
            scale_dfl_target = targets_dict['dfl_target'][:, start_idx:end_idx]  # [bs, num_anchors, 4]
            # Mask of cells that contain an object
            obj_mask = scale_obj_target.squeeze(-1) > 0.5  # [bs, num_anchors]
            # 1. Classification loss (positive cells only)
            if obj_mask.any():
                # Flatten positives for the loss computation
                pred_cls_flat = pred_cls[obj_mask]  # [num_positives, num_classes]
                cls_target_flat = scale_cls_target[obj_mask]  # [num_positives, num_classes]
                cls_loss = F.binary_cross_entropy_with_logits(
                    pred_cls_flat,
                    cls_target_flat,
                    reduction='mean'
                )
                total_cls_loss += cls_loss * self.cls_weight
            # 2. Objectness loss (every cell, positive or not)
            obj_loss = F.binary_cross_entropy_with_logits(
                pred_obj.squeeze(-1),  # [bs, num_anchors]
                scale_obj_target.squeeze(-1),  # [bs, num_anchors]
                reduction='mean'
            )
            total_obj_loss += obj_loss * self.obj_weight
            # 3. Box regression + DFL losses (positive cells only)
            if obj_mask.any() and self.use_dfl:
                # Predictions/targets at positive positions
                pred_dfl_pos = pred_dfl[obj_mask]  # [num_positives, 4*reg_max]
                box_target_pos = scale_box_target[obj_mask]  # [num_positives, 4]
                dfl_target_pos = scale_dfl_target[obj_mask]  # [num_positives, 4]
                # Decode the DFL distributions to box values for the IoU loss
                pred_dfl_reshaped = pred_dfl_pos.view(-1, 4, self.reg_max)  # [num_positives, 4, reg_max]
                pred_dfl_softmax = F.softmax(pred_dfl_reshaped, dim=-1)
                # Expected value of each distribution: sum(prob * bin index)
                pred_boxes = torch.zeros(pred_dfl_softmax.shape[0], 4, device=device)
                reg_range = torch.arange(self.reg_max, device=device).float()
                for i in range(4):
                    # Integration: sum(prob * value)
                    pred_boxes[:, i] = (pred_dfl_softmax[:, i] * reg_range).sum(dim=-1)
                # Normalize from [0, reg_max] back to [0, 1] to match targets
                pred_boxes = pred_boxes / self.reg_max
                # IoU loss (boxes interpreted as xywh by bbox_iou's default)
                iou = self.bbox_iou(pred_boxes, box_target_pos)
                box_loss = (1.0 - iou).mean()
                total_box_loss += box_loss * self.box_weight
                # 4. DFL loss
                dfl_loss = self.dfl_loss(pred_dfl_pos, dfl_target_pos)
                total_dfl_loss += dfl_loss * self.dfl_weight
            start_idx = end_idx
        # Total = sum of the (already weighted) components
        total_loss = total_box_loss + total_cls_loss + total_obj_loss + total_dfl_loss
        return {
            'box_loss': total_box_loss,
            'cls_loss': total_cls_loss,
            'obj_loss': total_obj_loss,
            'dfl_loss': total_dfl_loss,
            'total_loss': total_loss
        }
class PostProcessor:
    """
    Post-processing of raw YOLOv8 predictions: decoding into
    (x, y, w, h, conf, class) rows plus per-class non-maximum suppression.
    """
    def __init__(self, conf_threshold: float = 0.25,
                 iou_threshold: float = 0.45, max_det: int = 300):
        # conf_threshold: minimum confidence for a detection to be kept
        # iou_threshold: IoU above which overlapping same-class boxes are suppressed
        # max_det: maximum number of detections kept per image
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold
        self.max_det = max_det
    @staticmethod
    def decode_predictions(predictions: List[torch.Tensor],
                           reg_max: int = 16, num_classes: int = 80) -> torch.Tensor:
        """
        Simplified decoding of YOLOv8 predictions.

        Args:
            predictions: model outputs for each scale, [bs, C, ny, nx] with
                C = 4*reg_max + 1 + num_classes.
            reg_max: DFL decoding parameter.
            num_classes: number of classes.
        Returns:
            Decoded predictions [batch, num_preds, 6] (x, y, w, h, conf, class).

        NOTE: this is a placeholder decode — box centers are just the cell
        centers and sizes are a constant 0.1; a full implementation would
        integrate the DFL distributions instead.
        """
        decoded_preds = []
        for pred in predictions:
            bs, _, ny, nx = pred.shape
            # Flatten the spatial grid: [bs, C, ny, nx] -> [bs, ny*nx, C]
            pred = pred.view(bs, reg_max * 4 + 1 + num_classes, ny * nx)
            pred = pred.permute(0, 2, 1).contiguous()
            # Cell-position grid
            grid_y, grid_x = torch.meshgrid(
                torch.arange(ny, device=pred.device),
                torch.arange(nx, device=pred.device),
                indexing='ij'
            )
            grid = torch.stack((grid_x, grid_y), 2).view(1, ny * nx, 2).float()
            # Simplified decoding (no DFL integration here)
            # A real implementation would turn box_pred into distributions
            # and integrate them into box values.
            box_pred = pred[..., :reg_max * 4]
            obj_pred = pred[..., reg_max * 4:reg_max * 4 + 1].sigmoid()
            cls_pred = pred[..., reg_max * 4 + 1:].sigmoid()
            # Simplified box decoding
            decoded_box = torch.zeros(bs, ny * nx, 4, device=pred.device)
            # Box center = cell center, normalized to [0, 1]
            decoded_box[..., :2] = (grid + 0.5) / torch.tensor([nx, ny], device=pred.device)
            decoded_box[..., 2:] = 0.1  # fixed placeholder size (simplified)
            # Confidence = objectness * best per-class score
            scores = obj_pred * cls_pred.max(-1, keepdim=True)[0]
            class_ids = cls_pred.argmax(-1, keepdim=True)
            # Final row layout: (x, y, w, h, conf, class)
            final_pred = torch.cat([
                decoded_box,  # x, y, w, h
                scores,  # confidence
                class_ids.float()  # class
            ], dim=-1)
            decoded_preds.append(final_pred.view(bs, -1, 6))
        return torch.cat(decoded_preds, dim=1)
    def non_max_suppression(self, predictions: torch.Tensor) -> List[torch.Tensor]:
        """
        Per-class Non-Maximum Suppression.

        Args:
            predictions: [batch, num_preds, 6] rows of (x, y, w, h, conf,
                class), boxes in xywh form.
        Returns:
            One [n_kept, 6] tensor per image (rows kept in xywh form).
        """
        batch_size = predictions.shape[0]
        # Empty-result placeholders; each slot is reassigned (not mutated)
        # below, so sharing the initial zero tensor is harmless.
        output = [torch.zeros((0, 6), device=predictions.device)] * batch_size
        for i, pred in enumerate(predictions):
            # Confidence filtering
            mask = pred[:, 4] > self.conf_threshold
            pred = pred[mask]
            if pred.shape[0] == 0:
                continue
            # Convert xywh -> xyxy in place. Order matters: x2/y2 are
            # computed from the just-overwritten x1/y1 (x2 = x1 + w).
            box = pred[:, :4].clone()
            box[:, 0] = box[:, 0] - box[:, 2] / 2  # x1
            box[:, 1] = box[:, 1] - box[:, 3] / 2  # y1
            box[:, 2] = box[:, 0] + box[:, 2]  # x2
            box[:, 3] = box[:, 1] + box[:, 3]  # y2
            # NMS is run independently per class
            unique_classes = pred[:, 5].unique()
            for cls in unique_classes:
                cls_mask = pred[:, 5] == cls
                cls_pred = pred[cls_mask]
                cls_box = box[cls_mask]
                cls_scores = cls_pred[:, 4]
                # Sort by descending score
                sorted_indices = cls_scores.argsort(descending=True)
                cls_box = cls_box[sorted_indices]
                cls_scores = cls_scores[sorted_indices]
                cls_pred = cls_pred[sorted_indices]
                # Iterative NMS; current_indices tracks each surviving row's
                # position in the score-sorted arrays so kept rows can be
                # re-fetched afterwards.
                keep_indices = []
                current_indices = torch.arange(len(cls_box), device=cls_box.device)
                while len(cls_box) > 0:
                    # Keep the highest-scoring remaining box
                    keep_indices.append(current_indices[0].item())
                    if len(cls_box) == 1:
                        break
                    # IoU of the kept box against all remaining ones
                    ious = self._bbox_iou(cls_box[0:1], cls_box[1:])
                    # _bbox_iou may return shape [1, n-1] or [n-1]
                    if ious.dim() == 2:
                        ious = ious.squeeze(0)
                    # Keep only boxes overlapping below the threshold
                    mask = ious < self.iou_threshold
                    # Drop the head box and the suppressed ones
                    cls_box = cls_box[1:][mask]
                    cls_scores = cls_scores[1:][mask]
                    current_indices = current_indices[1:][mask]
                    cls_pred = cls_pred[1:][mask]
                # Append the kept rows to this image's output
                if keep_indices:
                    # Rebuild the score-sorted view of the original rows and
                    # index it with the recorded positions.
                    cls_pred_original = pred[cls_mask][sorted_indices]
                    kept_preds = cls_pred_original[keep_indices]
                    output[i] = torch.cat([output[i], kept_preds], dim=0)
            # Cap the number of detections per image
            if output[i].shape[0] > self.max_det:
                # Keep the top-scoring max_det rows
                scores = output[i][:, 4]
                sorted_idx = scores.argsort(descending=True)
                output[i] = output[i][sorted_idx][:self.max_det]
        return output
    @staticmethod
    def _bbox_iou(box1: torch.Tensor, box2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
        """
        IoU between boxes in xyxy format.

        Args:
            box1: [1, 4] or [n, 4]
            box2: [m, 4]
        Returns:
            IoU: [m] (broadcast case) or element-wise, depending on shapes
        """
        # Promote 1-D inputs to 2-D
        if box1.dim() == 1:
            box1 = box1.unsqueeze(0)
        if box2.dim() == 1:
            box2 = box2.unsqueeze(0)
        # Corner coordinates (chunk keeps a trailing dim of size 1)
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        # Intersection; the transpose broadcasts one box against many when
        # box1 holds a single row.
        inter_x1 = torch.max(b1_x1, b2_x1.T) if box1.shape[0] == 1 else torch.max(b1_x1, b2_x1)
        inter_y1 = torch.max(b1_y1, b2_y1.T) if box1.shape[0] == 1 else torch.max(b1_y1, b2_y1)
        inter_x2 = torch.min(b1_x2, b2_x2.T) if box1.shape[0] == 1 else torch.min(b1_x2, b2_x2)
        inter_y2 = torch.min(b1_y2, b2_y2.T) if box1.shape[0] == 1 else torch.min(b1_y2, b2_y2)
        inter_area = (inter_x2 - inter_x1).clamp(0) * (inter_y2 - inter_y1).clamp(0)
        # Union
        b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
        b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
        if box1.shape[0] == 1:
            # box1 is a single box, box2 holds several
            union_area = b1_area + b2_area.T - inter_area + eps
        else:
            # box1 and box2 hold the same number of boxes
            union_area = b1_area + b2_area - inter_area + eps
        return inter_area / union_area
class YOLOv8Trainer:
    """
    Training driver for a YOLOv8 model.

    Holds the model, its loss criterion, an optimizer and an optional LR
    scheduler, and exposes a single-batch training step plus a full
    validation pass. On CUDA devices, training runs in mixed precision
    through a GradScaler; elsewhere it runs in full precision.
    """
    def __init__(self, model: nn.Module, device: torch.device,
                 optimizer: torch.optim.Optimizer = None,
                 scheduler: torch.optim.lr_scheduler._LRScheduler = None):
        self.model = model.to(device)
        self.device = device
        self.criterion = YOLOv8Loss(num_classes=model.num_classes)
        # Fall back to AdamW with fixed hyperparameters when no optimizer is given.
        self.optimizer = optimizer or torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0005)
        self.scheduler = scheduler
        # A GradScaler is only meaningful on CUDA.
        self.scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None
    def train_step(self, images: torch.Tensor,
                   targets: torch.Tensor) -> Dict[str, float]:
        """
        Run one forward/backward/optimizer step on a single batch.

        Args:
            images: input batch tensor.
            targets: ground-truth tensor consumed by the loss criterion.
        Returns:
            The loss dict converted to plain Python floats.
        """
        self.model.train()
        images, targets = images.to(self.device), targets.to(self.device)
        if self.scaler is not None:
            # Mixed-precision path: autocast forward, scaled backward.
            with torch.cuda.amp.autocast():
                loss_dict = self.criterion(self.model(images), targets)
            self.optimizer.zero_grad()
            self.scaler.scale(loss_dict['total_loss']).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            # Full-precision path.
            loss_dict = self.criterion(self.model(images), targets)
            self.optimizer.zero_grad()
            loss_dict['total_loss'].backward()
            self.optimizer.step()
        # Advance the LR schedule once per step when one is configured.
        if self.scheduler is not None:
            self.scheduler.step()
        # Detach tensors into plain floats for logging.
        return {name: value.item() for name, value in loss_dict.items()}
    def validate(self, dataloader: torch.utils.data.DataLoader) -> Dict[str, float]:
        """
        Evaluate the model over an entire dataloader.

        Returns:
            Dict of the loss components averaged over all batches
            (zeros if the dataloader is empty).
        """
        self.model.eval()
        keys = ('box_loss', 'cls_loss', 'obj_loss', 'dfl_loss', 'total_loss')
        totals = dict.fromkeys(keys, 0.0)
        batch_count = 0
        with torch.no_grad():
            for images, targets in dataloader:
                loss_dict = self.criterion(
                    self.model(images.to(self.device)),
                    targets.to(self.device)
                )
                # Accumulate every loss component.
                for key in keys:
                    totals[key] += loss_dict[key].item()
                batch_count += 1
        # Average; guard against an empty dataloader.
        if batch_count > 0:
            for key in keys:
                totals[key] /= batch_count
        return totals
def main():
"""
Exemple complet d'utilisation de YOLOv8 avec fonction de perte réelle
"""
# Configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Utilisation du device: {device}")
# 1. Création du modèle
model = YOLOv8(num_classes=80, version='n')
print(f"\nModèle créé: YOLOv8-{model.version}")
print(f"Nombre de paramètres: {sum(p.numel() for p in model.parameters()):,}")
# Afficher l'architecture
print("\nArchitecture du modèle:")
print("Backbone channels:", model.backbone.channels)
print("Neck channels:", model.neck.channels)
# 2. Exemple de forward pass
batch_size = 2