@@ -653,22 +653,23 @@ def convert_batch_to_graph_nodes(
653653 batch_cell_positions = batch_cell_positions + patch_cell_positions
654654
655655 if self .detection_cell_postprocessor .classifier is not None :
656- batch_cell_tokens_pt = torch .stack (batch_cell_tokens )
657- updated_preds = self .detection_cell_postprocessor .classifier (
658- batch_cell_tokens_pt
659- )
660- updated_preds = F .softmax (updated_preds , dim = 1 )
661- updated_classes = torch .argmax (updated_preds , dim = 1 )
662- updated_class_preds = updated_preds [
663- torch .arange (updated_classes .shape [0 ]), updated_classes
664- ]
665-
666- for f , z in zip (batch_complete , updated_classes ):
667- f ["type" ] = int (z )
668- for f , z in zip (batch_complete , updated_class_preds ):
669- f ["type_prob" ] = int (z )
670- for f , z in zip (batch_detection , updated_classes ):
671- f ["type" ] = int (z )
656+ if len (batch_cell_tokens ) > 0 :
657+ batch_cell_tokens_pt = torch .stack (batch_cell_tokens )
658+ updated_preds = self .detection_cell_postprocessor .classifier (
659+ batch_cell_tokens_pt
660+ )
661+ updated_preds = F .softmax (updated_preds , dim = 1 )
662+ updated_classes = torch .argmax (updated_preds , dim = 1 )
663+ updated_class_preds = updated_preds [
664+ torch .arange (updated_classes .shape [0 ]), updated_classes
665+ ]
666+
667+ for f , z in zip (batch_complete , updated_classes ):
668+ f ["type" ] = int (z )
669+ for f , z in zip (batch_complete , updated_class_preds ):
670+ f ["type_prob" ] = int (z )
671+ for f , z in zip (batch_detection , updated_classes ):
672+ f ["type" ] = int (z )
672673 if self .detection_cell_postprocessor .binary :
673674 for f in batch_complete :
674675 f ["type" ] = 1
0 commit comments