alessioarcara/SoccerAI

Overview

In this work, we benchmark several graph-neural-network (GNN) architectures to estimate the probability that a given action will culminate in a shot, thereby quantifying how dangerous each action is. Once the shot-likelihood model is trained, we shift our attention to explainability: identifying and interpreting the key factors that drive the network's predictions.

Figure 1 – An action culminating in a shot

Dataset

2022 FIFA World Cup

  • Training set: 48 group-stage matches
  • Validation set: 16 knockout-stage matches

Two available data streams:

| Stream | Granularity | Contents | Usage |
|---|---|---|---|
| Event data | Sparse | Labeled events (pass, shot, tackle, foul, ...) | Primary source for shot prediction |
| Tracking data | Dense, 30 Hz | Player & ball positions | 60 frames (≈2 s) before each event, used to derive momentum (player speed & direction) |

We fused event data with short bursts of tracking data, capturing not only what happened but also how each player was moving at that moment. We then enriched these positional data with key player statistics scraped from Transfermarkt and FBref.
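
The momentum features mentioned above can be illustrated with a small sketch. It assumes positions are (x, y) pairs sampled at 30 Hz and that velocity is the net displacement over the 60-frame window divided by the window duration; the repository's `player_velocity.py` may compute it differently, and the function name `estimate_velocity` is hypothetical.

```python
import math

FPS = 30       # tracking frame rate stated in the dataset table
WINDOW = 60    # frames taken before each event (about 2 s)


def estimate_velocity(positions, fps=FPS):
    """Return (speed, direction) from a list of (x, y) positions.

    speed is in position units per second; direction is the angle of
    the net displacement in radians, measured from the +x axis.
    This is an illustrative assumption, not the repository's method.
    """
    if len(positions) < 2:
        return 0.0, 0.0
    (x0, y0), (x1, y1) = positions[0], positions[-1]
    elapsed = (len(positions) - 1) / fps   # seconds spanned by the window
    dx, dy = x1 - x0, y1 - y0
    speed = math.hypot(dx, dy) / elapsed
    direction = math.atan2(dy, dx)
    return speed, direction


# Synthetic track: a player moving 0.1 units per frame along +x.
track = [(0.1 * i, 0.0) for i in range(WINDOW)]
speed, direction = estimate_velocity(track)
```

Using only the first and last frame keeps the sketch short; averaging per-frame differences or smoothing the track would be reasonable alternatives when the tracking data is noisy.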

Representation

  • Graph structure – Each match frame is a graph whose

    • Nodes are the 22 players on the pitch.
    • Edges encode pairwise spatial relationships (e.g., Euclidean distance).
  • Node features combine

    • Positional statistics (location, velocity, etc.).
    • Player-specific statistics (market value, age, and other attributes scraped from Transfermarkt and FBref).
Figure 2 – Player positions (left) mapped to a bipartite interaction graph (right)
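
As a rough illustration of the representation above, the sketch below builds the edges of one frame in plain Python. It assumes that "bipartite" means each player is connected only to opponents and that the edge attribute is the Euclidean distance; both are assumptions, and `build_bipartite_edges` is a hypothetical name rather than a function from `converters.py`.

```python
import math


def build_bipartite_edges(positions, teams):
    """Connect every player to every opponent.

    positions: list of (x, y) tuples, one per player (node).
    teams:     list of team ids (e.g. 0 or 1), same length as positions.
    Returns (edge_index, edge_attr): edge_index is a pair of
    source/target index lists (COO layout) and edge_attr holds the
    Euclidean distance of each directed edge.
    """
    src, dst, dist = [], [], []
    for i, (xi, yi) in enumerate(positions):
        for j, (xj, yj) in enumerate(positions):
            if teams[i] != teams[j]:   # keep only cross-team pairs
                src.append(i)
                dst.append(j)
                dist.append(math.hypot(xi - xj, yi - yj))
    return [src, dst], dist


# Toy frame with 2 players per team instead of 11.
positions = [(0.0, 0.0), (3.0, 4.0), (0.0, 4.0), (3.0, 0.0)]
teams = [0, 0, 1, 1]
edge_index, edge_attr = build_bipartite_edges(positions, teams)
```

With 11 players per team this yields 2 × 11 × 11 = 242 directed edges. The two index lists follow the COO layout that PyTorch Geometric expects for `edge_index`, so they could be converted to tensors and passed to a `torch_geometric.data.Data` object together with the node features.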

Model architecture:

The architecture is fully modular: you can use different backbones to capture spatial features, choose a temporal neck that works on graph or node embeddings, and fine-tune each component through its own configuration file.

Figure 3 – Model architecture

Available backbones:

Available necks:

  • Readout → Temporal over graph embeddings [paper]
  • Temporal over node embeddings → Readout [paper]
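
The difference between the two neck orderings can be shown with a toy example. This sketch uses a mean readout and an element-wise max over time as a stand-in for the learned temporal module; both choices and all function names are illustrative assumptions, not the repository's implementation.

```python
def mean_readout(node_embeddings):
    """Average node vectors into one graph vector (a common readout)."""
    n = len(node_embeddings)
    dim = len(node_embeddings[0])
    return [sum(node[d] for node in node_embeddings) / n for d in range(dim)]


def temporal_max(sequence):
    """Stand-in for a learned temporal module: element-wise max over time."""
    dim = len(sequence[0])
    return [max(step[d] for step in sequence) for d in range(dim)]


def readout_then_temporal(frames):
    # frames[t][i] is the embedding of node i at time step t.
    graph_sequence = [mean_readout(frame) for frame in frames]
    return temporal_max(graph_sequence)


def temporal_then_readout(frames):
    num_nodes = len(frames[0])
    # Build one sequence per node, summarise it over time, then pool nodes.
    per_node = [temporal_max([frame[i] for frame in frames])
                for i in range(num_nodes)]
    return mean_readout(per_node)


# Two time steps, two nodes, one feature per node.
frames = [
    [[1.0], [0.0]],   # t = 0
    [[0.0], [1.0]],   # t = 1
]
a = readout_then_temporal(frames)   # [0.5]
b = temporal_then_readout(frames)   # [1.0]
```

The first ordering sees only the pooled graph vector at each step, so it loses track of which node changed. The second follows each node through time before pooling, which preserves per-player dynamics at the cost of running the temporal module once per node.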

Available heads:

  • Graph Classification

End-to-End alternatives:

Tip

Consult the configuration files for additional, specific parameters available for each backbone, neck, and head.

Installation

Before running the code, you need to install PyTorch and its dependencies. You can choose either the GPU or CPU build depending on your setup. The code has been tested with:

  • PyTorch 2.7.1
  • CUDA 12.8
  • Optional PyTorch Geometric libraries

1. Install PyTorch

| Build | Command |
|---|---|
| GPU (CUDA 12.8) | `pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu128` |
| CPU-only | `pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu` |

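Note: The CUDA tag in the index URL (cu128) must correspond to a CUDA version that your NVIDIA driver supports, and the same tag must be used for the PyTorch Geometric wheels in step 2.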


2. PyTorch Geometric stack

Install PyTorch Geometric companion wheels after PyTorch:

| Build | Command |
|---|---|
| GPU (CUDA 12.8) | `pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.7.1+cu128.html` |
| CPU-only | `pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.7.1+cpu.html` |

3. Install project dependencies

pip install .

That’s it—you’re ready to run the code!
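
To confirm that the packages from the steps above are visible to your interpreter, a small standard-library check such as the following may help. It is a sketch that is not part of the repository; `check_stack` is a hypothetical helper, and it only tests whether each package can be located, not whether its CUDA build matches.

```python
import importlib.util

# Import names of the packages mentioned in the installation steps.
REQUIRED = ["torch"]
OPTIONAL = ["pyg_lib", "torch_scatter", "torch_sparse",
            "torch_cluster", "torch_spline_conv"]


def check_stack(names):
    """Map each top-level package name to True if it can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_stack(REQUIRED + OPTIONAL).items():
        print(f"{name:20s} {'found' if found else 'missing'}")
```

If `torch` is found, `torch.__version__` and `torch.cuda.is_available()` report the installed build and whether a GPU is usable.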

Usage

Training a model

  1. Set the model name

    Open config/base.yaml and fill in the run_name field with your chosen model identifier.

  2. Launch training

    python ./scripts/train.py
    • Add --reload only if you have changed any dataset-related entries in the YAML file; this forces the dataset to be rebuilt so the changes take effect.

Evaluating a trained model

python ./scripts/eval.py --name <model_name>
  • The script automatically picks the best checkpoint from ./checkpoints/<model_name>/.
  • If you want to evaluate a specific checkpoint, move or delete any other checkpoints in that directory before running the command.
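
Moving the other checkpoints aside can be scripted. The sketch below assumes checkpoints are plain files stored directly in the model directory; the checkpoint file names and the backup location are hypothetical, and `isolate_checkpoint` is not part of the repository.

```python
import shutil
from pathlib import Path


def isolate_checkpoint(model_dir, keep_name, backup_dir):
    """Move every file in model_dir except keep_name into backup_dir.

    Returns the list of moved file names. Nothing is deleted.
    """
    model_dir, backup_dir = Path(model_dir), Path(backup_dir)
    if not (model_dir / keep_name).is_file():
        raise FileNotFoundError(f"{keep_name} not found in {model_dir}")
    backup_dir.mkdir(parents=True, exist_ok=True)
    moved = []
    for path in sorted(model_dir.iterdir()):
        if path.is_file() and path.name != keep_name:
            shutil.move(str(path), str(backup_dir / path.name))
            moved.append(path.name)
    return moved
```

Call it with the model's checkpoint directory, the name of the file to keep, and a backup directory outside `./checkpoints/<model_name>/`, then run `eval.py` as usual. Move the files back afterwards to restore the original state.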

Repository Structure

configs/                         # Default and per-model configs
scripts/ 
├── preload_video_frames.py      # Pre-downloads video frames needed for labelling to avoid repeated I/O
├── train.py                     # Trains a selected model
└── eval.py                      # Selects the best checkpoint of a model type and computes accuracy/F1/AP
notebooks/
└── data_collection.ipynb        # Used to manually filter unwanted chains and build the Shot-Prediction dataset
soccerai/
├── data/
│   ├── converters.py            # Turns tabular data into sparse PyG graphs (bipartite / FC)
│   ├── data.py                  # Loads World Cup 2022 data and exports Parquet
│   ├── dataset.py               # PyG-style dataset class; handles preprocessing, imputing, normalization & splits
│   ├── transformers.py          # scikit-learn transformers for feature engineering & normalization
│   ├── visualize.py             # Pitch frame visualizer (players, ball, side video)
│   ├── temporal_dataset.py      # Torch dataset that groups all frames of each chain into a sequence and pads/collates them so multiple chains can be batched together
│   ├── enrichers/
│   │   ├── player_velocity.py   # Adds direction & velocity from the last 60 tracking data frames
│   │   └── rosters.py           # Scrapes FBref & Transfermarkt to build player-stat CSV for the World Cup
│   └── label.py                 # Builds positive/negative chains and includes a visual function to filter low-quality ones
├── models/                      # Modular architecture that lets you specify a configurable backbone, neck & head
│   ├── backbones.py 
│   ├── diffpool.py
│   ├── heads.py
│   ├── layers.py
│   ├── models.py
│   ├── necks.py
│   ├── typings.py
│   └── utils.py       
└── training/                    # Modular training loop with callbacks, metrics & augmentations
    ├── callbacks.py
    ├── metrics.py
    ├── trainer.py
    ├── trainer_config.py        # Schema for configs
    ├── transforms.py
    └── utils.py      
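
The pad-and-collate step described for `temporal_dataset.py` can be sketched in plain Python. This is a simplified illustration that treats each frame as a flat feature vector, whereas the repository works with graphs per frame; `pad_and_collate` is a hypothetical name.

```python
def pad_and_collate(chains, pad_value=0.0):
    """Pad variable-length chains to a common length.

    chains: list of sequences; each sequence is a list of per-frame
            feature vectors (lists of floats) of equal dimension.
    Returns (batch, mask): batch[b][t] is a feature vector and
    mask[b][t] is True for real frames and False for padding.
    """
    max_len = max(len(chain) for chain in chains)
    dim = len(chains[0][0])
    batch, mask = [], []
    for chain in chains:
        pad = max_len - len(chain)
        frames = [list(f) for f in chain]
        frames += [[pad_value] * dim for _ in range(pad)]
        batch.append(frames)
        mask.append([True] * len(chain) + [False] * pad)
    return batch, mask


chains = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],   # chain with 3 frames
    [[7.0, 8.0]],                            # chain with 1 frame
]
batch, mask = pad_and_collate(chains)
```

The mask lets a temporal module ignore padded steps. For tensor inputs, `torch.nn.utils.rnn.pad_sequence` provides comparable padding.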

Acknowledgments

This project uses the transfermarkt-api repository by Felipe Almeida (MIT License, https://github.com/felipeall/transfermarkt-api) to obtain player profile data from Transfermarkt.

About

Benchmarking Multiple GNNs for Predicting Whether a Soccer Action Ends in a Shot
