diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 97faddd..0000000 Binary files a/.DS_Store and /dev/null differ diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c60ebc3..4b978cb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,6 +5,7 @@ on: branches: - main - dev-jintao + - smart-annotation workflow_dispatch: concurrency: @@ -43,6 +44,7 @@ jobs: - name: Install requirements run: | python -m pip install --upgrade pip + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu pip install -r requirements.txt - name: Cleanup before PyInstaller @@ -56,26 +58,30 @@ jobs: shell: pwsh run: > python -m PyInstaller --noconfirm --clean --windowed --onefile - --name "SoccerNetProAnalyzer" + --name "VideoAnnotationTool" --add-data "style;style" --add-data "ui;ui" --add-data "controllers;controllers" --add-data "image;image" + --add-data "config.yaml;." + --collect-all "soccernetpro" + --collect-all "wandb" + --collect-all "torch_geometric" "main.py" - name: Zip Windows binary (manual runs only) if: github.event_name == 'workflow_dispatch' shell: pwsh run: | - Move-Item -Force dist\SoccerNetProAnalyzer.exe dist\SoccerNetProAnalyzer-win.exe - Compress-Archive -Path dist\SoccerNetProAnalyzer-win.exe -DestinationPath dist\SoccerNetProAnalyzer-win.zip -Force + Move-Item -Force dist\VideoAnnotationTool.exe dist\VideoAnnotationTool-win.exe + Compress-Archive -Path dist\VideoAnnotationTool-win.exe -DestinationPath dist\VideoAnnotationTool-win.zip -Force - name: Upload artifact (manual runs only) if: github.event_name == 'workflow_dispatch' uses: actions/upload-artifact@v4 with: - name: SoccerNetProAnalyzer-Windows - path: annotation_tool/dist/SoccerNetProAnalyzer-win.zip + name: VideoAnnotationTool-Windows + path: annotation_tool/dist/VideoAnnotationTool-win.zip retention-days: 3 build-macos: @@ -116,25 +122,29 @@ jobs: shell: bash run: > python -m PyInstaller --noconfirm --clean 
--windowed - --name "SoccerNetProAnalyzer" + --name "VideoAnnotationTool" --add-data "style:style" --add-data "ui:ui" --add-data "controllers:controllers" --add-data "image:image" + --add-data "config.yaml:." + --collect-all "soccernetpro" + --collect-all "wandb" + --collect-all "torch_geometric" "main.py" - name: Zip macOS app (manual runs only) if: github.event_name == 'workflow_dispatch' shell: bash run: | - ditto -c -k --sequesterRsrc --keepParent "dist/SoccerNetProAnalyzer.app" "dist/SoccerNetProAnalyzer-mac.zip" + ditto -c -k --sequesterRsrc --keepParent "dist/VideoAnnotationTool.app" "dist/VideoAnnotationTool-mac.zip" - name: Upload artifact (manual runs only) if: github.event_name == 'workflow_dispatch' uses: actions/upload-artifact@v4 with: - name: SoccerNetProAnalyzer-macOS - path: annotation_tool/dist/SoccerNetProAnalyzer-mac.zip + name: VideoAnnotationTool-macOS + path: annotation_tool/dist/VideoAnnotationTool-mac.zip retention-days: 3 build-linux: @@ -168,6 +178,7 @@ jobs: - name: Install requirements run: | python -m pip install --upgrade pip + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu pip install -r requirements.txt - name: Cleanup before PyInstaller @@ -180,26 +191,30 @@ jobs: shell: bash run: > python -m PyInstaller --noconfirm --clean --windowed --onefile - --name "SoccerNetProAnalyzer" + --name "VideoAnnotationTool" --add-data "style:style" --add-data "ui:ui" --add-data "controllers:controllers" --add-data "image:image" + --add-data "config.yaml:." 
+ --collect-all "soccernetpro" + --collect-all "wandb" + --collect-all "torch_geometric" "main.py" - name: Zip Linux binary (manual runs only) if: github.event_name == 'workflow_dispatch' shell: bash run: | - mv -f dist/SoccerNetProAnalyzer dist/SoccerNetProAnalyzer-linux + mv -f dist/VideoAnnotationTool dist/VideoAnnotationTool-linux cd dist - zip -r SoccerNetProAnalyzer-linux.zip SoccerNetProAnalyzer-linux + zip -r VideoAnnotationTool-linux.zip VideoAnnotationTool-linux cd .. - name: Upload artifact (manual runs only) if: github.event_name == 'workflow_dispatch' uses: actions/upload-artifact@v4 with: - name: SoccerNetProAnalyzer-Linux - path: annotation_tool/dist/SoccerNetProAnalyzer-linux.zip + name: VideoAnnotationTool-Linux + path: annotation_tool/dist/VideoAnnotationTool-linux.zip retention-days: 3 diff --git a/.github/workflows/deploy_docs.yml b/.github/workflows/deploy_docs.yml index 1a0a119..c37aaba 100644 --- a/.github/workflows/deploy_docs.yml +++ b/.github/workflows/deploy_docs.yml @@ -4,7 +4,7 @@ on: push: branches: - main - - dev-jintao + - smart-annotation workflow_dispatch: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 850b511..80942cd 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,8 +9,13 @@ on: permissions: contents: write +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: generate-release-notes: + name: Generate release notes runs-on: ubuntu-latest outputs: changelog: ${{ steps.notes.outputs.changelog }} @@ -31,10 +36,11 @@ jobs: } >> "$GITHUB_OUTPUT" create-release: + name: Create GitHub Release needs: generate-release-notes runs-on: ubuntu-latest steps: - - name: Create/Update GitHub Release (body only) + - name: Create/Update GitHub Release uses: softprops/action-gh-release@v2 with: tag_name: ${{ github.ref_name }} @@ -42,6 +48,7 @@ jobs: body: ${{ needs.generate-release-notes.outputs.changelog }} build-windows: + name: 
Build on Windows needs: create-release runs-on: windows-latest defaults: @@ -54,9 +61,22 @@ jobs: with: python-version: "3.11" + - name: Cache pip + uses: actions/cache@v4 + with: + path: | + ~\AppData\Local\pip\Cache + ~\AppData\Local\pip\cache + ~\AppData\Roaming\pip\Cache + key: ${{ runner.os }}-pip-${{ hashFiles('annotation_tool/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + - name: Install requirements + shell: pwsh run: | python -m pip install --upgrade pip + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu pip install -r requirements.txt - name: Cleanup before PyInstaller @@ -68,32 +88,37 @@ jobs: - name: Build exe shell: pwsh - run: | - python -m PyInstaller --noconfirm --clean --windowed --onefile ` - --name "SoccerNetProAnalyzer" ` - --add-data "style;style" ` - --add-data "ui;ui" ` - --add-data "controllers;controllers" ` - --add-data "image;image" ` - "main.py" + run: > + python -m PyInstaller --noconfirm --clean --windowed --onefile + --name "VideoAnnotationTool" + --add-data "style;style" + --add-data "ui;ui" + --add-data "controllers;controllers" + --add-data "image;image" + --add-data "config.yaml;." 
+ --collect-all "soccernetpro" + --collect-all "wandb" + --collect-all "torch_geometric" + "main.py" - name: Rename binary shell: pwsh run: | - Move-Item -Force dist\SoccerNetProAnalyzer.exe dist\SoccerNetProAnalyzer-win.exe + Move-Item -Force dist\VideoAnnotationTool.exe dist\VideoAnnotationTool-win.exe - name: Zip Windows binary shell: pwsh run: | - Compress-Archive -Path dist\SoccerNetProAnalyzer-win.exe -DestinationPath dist\SoccerNetProAnalyzer-win.zip -Force + Compress-Archive -Path dist\VideoAnnotationTool-win.exe -DestinationPath dist\VideoAnnotationTool-win.zip -Force - name: Upload Release Asset (Windows) uses: softprops/action-gh-release@v2 with: - files: annotation_tool/dist/SoccerNetProAnalyzer-win.zip tag_name: ${{ github.ref_name }} + files: annotation_tool/dist/VideoAnnotationTool-win.zip build-macos: + name: Build on macOS needs: create-release runs-on: macos-latest defaults: @@ -106,7 +131,18 @@ jobs: with: python-version: "3.11" + - name: Cache pip + uses: actions/cache@v4 + with: + path: | + ~/Library/Caches/pip + ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('annotation_tool/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + - name: Install requirements + shell: bash run: | python -m pip install --upgrade pip pip install -r requirements.txt @@ -121,25 +157,30 @@ jobs: shell: bash run: > python -m PyInstaller --noconfirm --clean --windowed - --name "SoccerNetProAnalyzer" + --name "VideoAnnotationTool" --add-data "style:style" --add-data "ui:ui" --add-data "controllers:controllers" --add-data "image:image" + --add-data "config.yaml:." 
+ --collect-all "soccernetpro" + --collect-all "wandb" + --collect-all "torch_geometric" "main.py" - name: Zip macOS app shell: bash run: | - ditto -c -k --sequesterRsrc --keepParent "dist/SoccerNetProAnalyzer.app" "dist/SoccerNetProAnalyzer-mac.zip" + ditto -c -k --sequesterRsrc --keepParent "dist/VideoAnnotationTool.app" "dist/VideoAnnotationTool-mac.zip" - name: Upload Release Asset (macOS) uses: softprops/action-gh-release@v2 with: - files: annotation_tool/dist/SoccerNetProAnalyzer-mac.zip tag_name: ${{ github.ref_name }} + files: annotation_tool/dist/VideoAnnotationTool-mac.zip build-linux: + name: Build on Linux needs: create-release runs-on: ubuntu-latest defaults: @@ -158,9 +199,20 @@ jobs: sudo apt-get update sudo apt-get install -y libgl1 libglib2.0-0 libxcb-cursor0 + - name: Cache pip + uses: actions/cache@v4 + with: + path: | + ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('annotation_tool/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + - name: Install requirements + shell: bash run: | python -m pip install --upgrade pip + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu pip install -r requirements.txt - name: Cleanup before PyInstaller @@ -173,27 +225,31 @@ jobs: shell: bash run: > python -m PyInstaller --noconfirm --clean --windowed --onefile - --name "SoccerNetProAnalyzer" + --name "VideoAnnotationTool" --add-data "style:style" --add-data "ui:ui" --add-data "controllers:controllers" --add-data "image:image" + --add-data "config.yaml:." 
+ --collect-all "soccernetpro" + --collect-all "wandb" + --collect-all "torch_geometric" "main.py" - name: Rename binary shell: bash run: | - mv -f dist/SoccerNetProAnalyzer dist/SoccerNetProAnalyzer-linux + mv -f dist/VideoAnnotationTool dist/VideoAnnotationTool-linux - name: Zip Linux binary shell: bash run: | cd dist - zip -r SoccerNetProAnalyzer-linux.zip SoccerNetProAnalyzer-linux + zip -r VideoAnnotationTool-linux.zip VideoAnnotationTool-linux cd .. - name: Upload Release Asset (Linux) uses: softprops/action-gh-release@v2 with: - files: annotation_tool/dist/SoccerNetProAnalyzer-linux.zip tag_name: ${{ github.ref_name }} + files: annotation_tool/dist/VideoAnnotationTool-linux.zip diff --git a/README.md b/README.md index 4c8b24a..c8ec687 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,13 @@ -# SoccerNetPro Analyzer (UI) +# Video Annotation Tool (UI) -[![Documentation Status](https://img.shields.io/badge/docs-online-brightgreen)](https://opensportslab.github.io/soccernetpro-ui/) +[![Documentation Status](https://img.shields.io/badge/docs-online-brightgreen)](https://opensportslab.github.io/VideoAnnotationTool/) -A **PyQt6-based GUI** for analyzing and annotating **SoccerNetPro / action spotting** datasets (OpenSportsLab). +A **PyQt6-based GUI** for analyzing and annotating **[OSL format](https://opensportslab.github.io/VideoAnnotationTool/OSL/)** datasets (OpenSportsLab). ---- ## Features -- Open and visualize SoccerNetPro-style data and annotations. +- Open and visualize OSL-style data and annotations. - Annotate and edit events/actions with a user-friendly GUI. - Manage labels/categories and export results for downstream tasks. - Easy to extend with additional viewers, overlays, and tools. 
@@ -24,16 +23,16 @@ We recommend using [Anaconda](https://www.anaconda.com/) or [Miniconda](https:// ### Step 0 โ€“ Clone the repository ```bash -git clone https://github.com/OpenSportsLab/soccernetpro-ui.git -cd soccernetpro-ui +git clone https://github.com/OpenSportsLab/VideoAnnotationTool.git +cd VideoAnnotationTool ``` ### Step 1 โ€“ Create a new Conda environment ```bash -conda create -n soccernetpro-ui python=3.9 -y -conda activate soccernetpro-ui +conda create -n VideoAnnotationTool python=3.9 -y +conda activate VideoAnnotationTool ``` @@ -63,14 +62,14 @@ This project provides **test datasets** for multiple tasks, including: - **Description (Video Captioning)** - **Dense Description (Dense Video Captioning)** -More details are available at: [`/test_data`](https://github.com/OpenSportsLab/soccernetpro-ui/tree/main/test_data) +More details are available at: [`/test_data`](https://github.com/OpenSportsLab/VideoAnnotationTool/tree/main/test_data) > โš ๏ธ **Important** > For all tasks, the corresponding **JSON annotation file must be placed in the same directory** > as the referenced data folders (e.g., `test/`, `germany_bundesliga/`, etc.). > Otherwise, the GUI may not load the data correctly due to relative path mismatches. -Some Hugging Face datasets (including SoccerNetPro datasets) are **restricted / gated**. Therefore you must: +Some Hugging Face datasets (including OSL datasets) are **restricted / gated**. Therefore you must: 1. Have access to the dataset on Hugging Face 2. 
Be authenticated locally using your Hugging Face account (`hf auth login`) @@ -151,6 +150,7 @@ python test_data/download_osl_hf.py \ **Data location (HuggingFace):** - [Localization Dataset (Soccer)](https://huggingface.co/datasets/OpenSportsLab/soccernetpro-localization-snas) - [Localization Dataset (Tennis)](https://huggingface.co/datasets/OpenSportsLab/soccernetpro-localization-tennis) +- [Localization Dataset (Gymnastics)](https://huggingface.co/datasets/OpenSportsLab/soccernetpro-localization-gymnastics) Each folder (e.g., `england efl/`) contains video clips for localization testing. @@ -199,7 +199,7 @@ Test_Data/Description/XFoul/ --- -## 🟧 Dense Description (Dense Video Captioning) – SoccerNetPro SNDVC +## 🟧 Dense Description (Dense Video Captioning) **Dataset (Hugging Face):** [Dense—Description Dataset](https://huggingface.co/datasets/OpenSportsLab/soccernetpro-densedescription-sndvc) @@ -266,7 +266,7 @@ The commands below assume you run them **from the repository root**. cd annotation_tool python -m PyInstaller --noconfirm --clean --windowed \ - --name "SoccerNetProAnalyzer" \ + --name "VideoAnnotationTool" \ --add-data "style:style" \ --add-data "ui:ui" \ --add-data "controllers:controllers" \ @@ -275,7 +275,7 @@ Output: -* `annotation_tool/dist/SoccerNetProAnalyzer.app` +* `annotation_tool/dist/VideoAnnotationTool.app` --- @@ -287,7 +287,7 @@ Output: cd annotation_tool python -m PyInstaller --noconfirm --clean --windowed --onefile \ - --name "SoccerNetProAnalyzer" \ + --name "VideoAnnotationTool" \ --add-data "style:style" \ --add-data "ui:ui" \ --add-data "controllers:controllers" \ @@ -296,7 +296,7 @@ Output: -* `annotation_tool/dist/SoccerNetProAnalyzer` +* `annotation_tool/dist/VideoAnnotationTool` #### Windows (PowerShell) @@ -307,7 +307,7 @@ On Windows, the `--add-data` separator is **`;`** (not `:`). 
cd annotation_tool python -m PyInstaller --noconfirm --clean --windowed --onefile ` - --name "SoccerNetProAnalyzer" ` + --name "VideoAnnotationTool" ` --add-data "style;style" ` --add-data "ui;ui" ` --add-data "controllers;controllers" ` @@ -316,7 +316,7 @@ python -m PyInstaller --noconfirm --clean --windowed --onefile ` Output: -* `annotation_tool\dist\SoccerNetProAnalyzer.exe` +* `annotation_tool\dist\VideoAnnotationTool.exe` --- @@ -356,10 +356,10 @@ There is also a standalone build workflow that can be triggered manually: ## ๐Ÿ“œ License -This Soccernet Pro project offers two licensing options to suit different needs: +This Video Annotation Tool project offers two licensing options to suit different needs: -* **AGPL-3.0 License**: This open-source license is ideal for students, researchers, and the community. It supports open collaboration and sharing. See the [`LICENSE.txt`](https://github.com/OpenSportsLab/soccernetpro-ui/blob/main/LICENSE.txt) file for full details. -* **Commercial License**: Designed for [`commercial use`](https://github.com/OpenSportsLab/soccernetpro-ui/blob/main/COMMERCIAL_LICENSE.md +* **AGPL-3.0 License**: This open-source license is ideal for students, researchers, and the community. It supports open collaboration and sharing. See the [`LICENSE.txt`](https://github.com/OpenSportsLab/VideoAnnotationTool/blob/main/LICENSE.txt) file for full details. +* **Commercial License**: Designed for [`commercial use`](https://github.com/OpenSportsLab/VideoAnnotationTool/blob/main/COMMERCIAL_LICENSE.md ), this option allows you to integrate this software into proprietary products and services without the open-source obligations of GPL-3.0. If your use case involves commercial deployment, please contact the maintainers to obtain a commercial license. **Contact:** OpenSportsLab / project maintainers. 
diff --git a/annotation_tool/.DS_Store b/annotation_tool/.DS_Store deleted file mode 100644 index 096b0f9..0000000 Binary files a/annotation_tool/.DS_Store and /dev/null differ diff --git a/annotation_tool/README.md b/annotation_tool/README.md index c577db3..871fa9c 100644 --- a/annotation_tool/README.md +++ b/annotation_tool/README.md @@ -1,6 +1,8 @@ -# SoccerNet Pro Annotation Tool +# Video Annotation Tool -This project is a professional video annotation desktop application built with **PyQt6**. It features a comprehensive **quad-mode** architecture supporting **Whole-Video Classification**, **Action Spotting (Localization)**, **Video Captioning (Description)**, and the newly integrated **Dense Video Captioning (Dense Description)**. +This project is a professional video annotation desktop application built with **PyQt6**. It features a comprehensive **quad-mode** architecture supporting **Whole-Video Classification**, **Action Spotting (Localization)**, **Video Captioning (Description)**, and the newly integrated **Dense Video Captioning (Dense Description)**. + +With the latest update, the Classification mode now features **AI-Powered Smart Annotation**, allowing users to leverage state-of-the-art `soccernetpro` models (e.g., MViT) to automatically infer actions via single or batch processing. The project follows a modular **MVC (Model-View-Controller)** design pattern to ensure strict separation of concerns. It leverages **Qt's Model/View architecture** for resource management and a unified **Media Controller** to ensure stable, high-performance video playback across all modalities. 
@@ -13,6 +15,7 @@ annotation_tool/ ├── main.py # Application entry point ├── viewer.py # Main Window controller (Orchestrator) ├── utils.py # Helper functions and constants +├── config.yaml # [NEW] Inference configuration for soccernetpro models ├── __init__.py # Package initialization │ ├── models/ # [Model Layer] Data Structures & State │ ├── controllers/ # [Controller Layer] Business Logic │ ├── router.py # Mode detection & Project lifecycle management -│ ├── history_manager.py # Universal Undo/Redo system +│ ├── history_manager.py # Universal Undo/Redo system (Supports Batch Annotations) │ ├── media_controller.py # Unified playback logic (Anti-freeze/Visual clearing) │ ├── classification/ # Logic for Classification mode +│ │ ├── class_annotation_manager.py # Manual label state management +│ │ ├── class_file_manager.py # JSON I/O for Classification tasks +│ │ ├── class_navigation_manager.py # Action tree navigation +│ │ └── inference_manager.py # [NEW] AI Smart Annotation (Single/Batch Inference) │ ├── localization/ # Logic for Action Spotting (Localization) mode │ ├── description/ # Logic for Global Captioning (Description) mode -│ └── dense_description/ # [NEW] Logic for Dense Captioning (Text-at-Timestamp) +│ └── dense_description/ # Logic for Dense Captioning (Text-at-Timestamp) │ ├── dense_manager.py # Core logic for dense annotations & UI sync -│ └── dense_file_manager.py # JSON I/O specifically for Dense tasks +│ └── dense_file_manager.py # JSON I/O specifically for Dense tasks │ ├── ui/ # [View Layer] Interface Definitions │ ├── common/ # Shared widgets (Main Window, Sidebar, Video Surface) @@ -37,9 +44,13 @@ annotation_tool/ │ │ ├── workspace.py # Unified 3-column skeleton │ │ └── dialogs.py # Project
wizards and mode selectors │ ├── classification/ # UI specific to Classification +│ │ └── event_editor/ # Dynamic Schema Editor & [NEW] Smart Annotation UI +│ │ ├── dynamic_widgets.py # Single/Multi label dynamic radio & checkbox groups +│ │ ├── editor.py # Includes NativeDonutChart & Batch Progress UI +│ │ └── controls.py # Playback control bar │ ├── localization/ # UI specific to Localization (Timeline + Tabbed Spotting) │ ├── description/ # UI specific to Global Captioning (Full-video text) -│ └── dense_description/ # [NEW] UI specific to Dense Description +│ └── dense_description/ # UI specific to Dense Description │ └── event_editor/ │ ├── __init__.py # Right panel assembler for Dense mode │ ├── desc_input_widget.py # Text input & timestamp submission @@ -47,9 +58,7 @@ annotation_tool/ │ └── style/ # Visual theme assets └── style.qss # Centralized Dark mode stylesheet - ``` - --- ## ๐Ÿ“ Detailed Module Descriptions diff --git a/annotation_tool/config.yaml b/annotation_tool/config.yaml new file mode 100644 index 0000000..d3e7fe2 --- /dev/null +++ b/annotation_tool/config.yaml @@ -0,0 +1,108 @@ +TASK: classification + +DATA: + TASK: classification + dataset_name: mvfouls + data_dir: "" + data_modality: video + view_type: multi + num_classes: 8 + + classes: + - Challenge + - Dive + - Elbowing + - High leg + - Holding + - Pushing + - Standing tackling + - Tackling + + # ⬅️ Added back the dummy train block + train: + type: annotations_train.json + video_path: "" + path: "" + dataloader: + batch_size: 1 + shuffle: true + num_workers: 0 + pin_memory: false + + # ⬅️ Added back the dummy valid block + valid: + type: annotations_valid.json + video_path: "" + path: "" + dataloader: + batch_size: 1 + num_workers: 0 + shuffle: false + + test: + type: annotations_test.json + video_path: "" + path: ./temp_workspace/temp_test.json + dataloader: +
batch_size: 1 + num_workers: 0 + shuffle: false + + num_frames: 16 + input_fps: 25 + target_fps: 17 + start_frame: 63 + end_frame: 87 + frame_size: [224, 224] + + augmentations: + random_affine: false + random_perspective: false + random_rotation: false + color_jitter: false + random_horizontal_flip: false + random_crop: false +MODEL: + TASK: classification + type: custom + backbone: + type: mvit_v2_s + neck: + type: MV_Aggregate + agr_type: max + head: + type: MV_LinearLayer + pretrained_model: mvit_v2_s + +TRAIN: + monitor: loss + mode: min + enabled: false + use_weighted_sampler: false + use_weighted_loss: false + epochs: 1 + save_dir: ./temp_workspace/checkpoints + criterion: + type: CrossEntropyLoss + optimizer: + type: AdamW + lr: 0.0001 + backbone_lr: 0.00005 + head_lr: 0.001 + betas: [0.9, 0.999] + eps: 0.0000001 + weight_decay: 0.001 + amsgrad: false + scheduler: + type: StepLR + step_size: 3 + gamma: 0.1 + + +SYSTEM: + log_dir: ./logs + use_seed: false + seed: 42 + GPU: 0 + device: cpu + gpu_id: 0 \ No newline at end of file diff --git a/annotation_tool/controllers/.DS_Store b/annotation_tool/controllers/.DS_Store deleted file mode 100644 index 57a9279..0000000 Binary files a/annotation_tool/controllers/.DS_Store and /dev/null differ diff --git a/annotation_tool/controllers/classification/class_annotation_manager.py b/annotation_tool/controllers/classification/class_annotation_manager.py index 7d9822a..db921cc 100644 --- a/annotation_tool/controllers/classification/class_annotation_manager.py +++ b/annotation_tool/controllers/classification/class_annotation_manager.py @@ -9,11 +9,130 @@ def __init__(self, main_window): self.model = main_window.model self.ui = main_window.ui - def save_manual_annotation(self): + def confirm_smart_annotation_as_manual(self): + """ + [MODIFIED] Mark current smart prediction(s) as confirmed. + Added Undo/Redo support for Smart Annotations to fix history bugs. 
+ """ + import copy + from models.app_state import CmdType # Ensure CmdType is available + right_panel = self.ui.classification_ui.right_panel + + # Check if we are confirming a batch or a single inference + if right_panel.is_batch_mode_active: + # --- BATCH CONFIRMATION LOGIC --- + batch_preds = right_panel.pending_batch_results + if not batch_preds: + self.main.show_temp_msg("Notice", "No batch predictions to confirm.") + return + + old_batch_data = {} + new_batch_data = {} + confirmed_count = 0 + + # Loop through all items in the batch + for path, pred_data in batch_preds.items(): + # Store the old state for Undo + old_batch_data[path] = copy.deepcopy(self.model.smart_annotations.get(path)) + + # --- ROBUST DATA FORMATTING --- + if isinstance(pred_data, str): + head = next(iter(self.model.label_definitions.keys()), "action") + formatted_data = {head: {"label": pred_data, "conf_dict": {pred_data: 1.0}}} + elif isinstance(pred_data, dict) and "label" in pred_data: + head = next(iter(self.model.label_definitions.keys()), "action") + formatted_data = {head: copy.deepcopy(pred_data)} + else: + formatted_data = copy.deepcopy(pred_data) + + # [NEW FIX] Ensure 'conf_dict' exists for the Donut Chart rendering! 
+ for h, h_data in formatted_data.items(): + if isinstance(h_data, dict) and "label" in h_data: + if "conf_dict" not in h_data: + # Safely extract 'confidence', fallback to 1.0 if not found + conf = h_data.get("confidence", 1.0) + h_data["conf_dict"] = {h_data["label"]: conf} + # Also calculate the remaining percentage for the pie chart + rem = 1.0 - conf + if rem > 0.001: + h_data["conf_dict"]["Other Uncertainties"] = rem + + # Mark as confirmed safely + formatted_data["_confirmed"] = True + + # Store the new state for Redo + new_batch_data[path] = copy.deepcopy(formatted_data) + + # Save to model memory + self.model.smart_annotations[path] = formatted_data + self.main.update_action_item_status(path) + confirmed_count += 1 + + # [NEW] Push the batch confirmation to the Undo stack + self.model.push_undo(CmdType.BATCH_SMART_ANNOTATION_RUN, old_data=old_batch_data, new_data=new_batch_data) + + self.model.is_data_dirty = True + self.main.show_temp_msg("Saved", f"Batch Smart Annotations confirmed for {confirmed_count} items.", 2000) + + # Reset the batch UI back to normal after confirmation + right_panel.reset_smart_inference() + + else: + # --- SINGLE CONFIRMATION LOGIC --- + path = self.main.get_current_action_path() + if not path: return + + smart_data = self.model.smart_annotations.get(path) + if not smart_data: + self.main.show_temp_msg("Notice", "No smart annotation available to confirm.") + return + + # Store the old state for Undo + old_data = copy.deepcopy(smart_data) + + # Flag it as confirmed internally within the smart memory + self.model.smart_annotations[path]["_confirmed"] = True + self.model.is_data_dirty = True + + # Store the new state for Redo + new_data = copy.deepcopy(self.model.smart_annotations[path]) + + # [NEW] Push the single confirmation to the Undo stack + self.model.push_undo(CmdType.SMART_ANNOTATION_RUN, path=path, old_data=old_data, new_data=new_data) + + self.main.update_action_item_status(path) + self.main.show_temp_msg("Saved", "Smart 
Annotation confirmed independently.", 1000) + + # --- COMMON UI UPDATES --- + self.main.update_save_export_button_state() + + # Apply filter immediately to reflect the new Smart Labelled status + self.main.nav_manager.apply_action_filter() + + # Auto-advance to the next video clip + tree = self.ui.classification_ui.left_panel.tree + curr_idx = tree.currentIndex() + if curr_idx.isValid(): + nxt_idx = tree.indexBelow(curr_idx) + if nxt_idx.isValid(): + from PyQt6.QtCore import QTimer + QTimer.singleShot(500, lambda: [tree.setCurrentIndex(nxt_idx), tree.scrollTo(nxt_idx)]) + + def save_manual_annotation(self, override_data=None): + """ + [MODIFIED] Added 'override_data' parameter. + If provided (e.g., from Smart Annotation confirm), it uses the provided dict. + Otherwise, it falls back to reading the Hand Annotation UI state. + """ path = self.main.get_current_action_path() if not path: return - raw = self.ui.classification_ui.right_panel.get_annotation() + # Use provided data if available, otherwise read from the UI + if override_data is not None: + raw = override_data + else: + raw = self.ui.classification_ui.right_panel.get_annotation() + cleaned = {k: v for k, v in raw.items() if v} if not cleaned: cleaned = None @@ -29,6 +148,7 @@ def save_manual_annotation(self): self.main.update_action_item_status(path) self.main.update_save_export_button_state() + self.main.nav_manager.apply_action_filter() # [MV Fix] Auto-advance using QTreeView API tree = self.ui.classification_ui.left_panel.tree @@ -52,10 +172,50 @@ def clear_current_manual_annotation(self): self.main.show_temp_msg("Cleared", "Selection cleared.") self.ui.classification_ui.right_panel.clear_selection() + def clear_current_smart_annotation(self): + """[NEW] Clear the smart annotation for the current video, with Undo support.""" + path = self.main.get_current_action_path() + if not path: return + + old_smart = copy.deepcopy(self.model.smart_annotations.get(path)) + if old_smart: + # Push the clearing action 
to the Undo stack using the SMART_ANNOTATION_RUN cmd + self.model.push_undo( + CmdType.SMART_ANNOTATION_RUN, + path=path, + old_data=old_smart, + new_data=None + ) + + # Remove from model memory + if path in self.model.smart_annotations: + del self.model.smart_annotations[path] + + self.model.is_data_dirty = True + self.main.show_temp_msg("Cleared", "Smart Annotation cleared.", 1000) + self.main.update_save_export_button_state() + + # Visually hide the donut chart and text without affecting the Hand Annotation UI + self.ui.classification_ui.right_panel.chart_widget.setVisible(False) + self.ui.classification_ui.right_panel.batch_result_text.setVisible(False) + def display_manual_annotation(self, path): + # 1. Restore manual annotation (This will reset the UI and hide the chart by default) data = self.model.manual_annotations.get(path, {}) self.ui.classification_ui.right_panel.set_annotation(data) + # 2. [NEW] Re-display the Smart Annotation Donut Chart if data exists + smart_data = self.model.smart_annotations.get(path, {}) + if smart_data: + # We display the chart for the first available head (typically 'action') + for head, s_data in smart_data.items(): + self.ui.classification_ui.right_panel.chart_widget.update_chart( + s_data["label"], + s_data.get("conf_dict", {}) + ) + self.ui.classification_ui.right_panel.chart_widget.setVisible(True) + break + def handle_ui_selection_change(self, head, new_val): if self.main.history_manager._is_undoing_redoing: return @@ -152,4 +312,4 @@ def remove_custom_type(self, head, lbl): group = self.ui.classification_ui.right_panel.label_groups.get(head) if isinstance(group, DynamicSingleLabelGroup): group.update_radios(defn['labels']) else: group.update_checkboxes(defn['labels']) - self.display_manual_annotation(self.main.get_current_action_path()) \ No newline at end of file + self.display_manual_annotation(self.main.get_current_action_path()) diff --git a/annotation_tool/controllers/classification/class_file_manager.py 
b/annotation_tool/controllers/classification/class_file_manager.py index 2910ce0..7684140 100644 --- a/annotation_tool/controllers/classification/class_file_manager.py +++ b/annotation_tool/controllers/classification/class_file_manager.py @@ -67,6 +67,14 @@ def load_project(self, data, file_path): clean_k = k.strip().replace(' ', '_').lower() self.model.label_definitions[clean_k] = {'type': v['type'], 'labels': sorted(list(set(v.get('labels', []))))} self.main.setup_dynamic_ui() + + # Check if it is multi view + is_multi = False + for item in data.get('data', []): + if len(item.get('inputs', [])) > 1: + is_multi = True + break + self.model.is_multi_view = is_multi # Load Data for item in data.get('data', []): @@ -108,6 +116,22 @@ def load_project(self, data, file_path): if has_l: self.model.manual_annotations[path_key] = manual + # [NEW] Load Smart Annotations from JSON + smart_lbls = item.get('smart_labels', {}) + smart = {} + for h, content in smart_lbls.items(): + ck = h.strip().replace(' ', '_').lower() + if ck in self.model.label_definitions and isinstance(content, dict): + # Reconstruct the prediction and confidence dictionary + smart[ck] = { + "label": content.get("label"), + "conf_dict": content.get("conf_dict", {content.get("label"): content.get("confidence", 1.0)}) + } + if smart: + # [MODIFIED] Mark loaded smart annotations as confirmed so the Filter recognizes them + smart["_confirmed"] = True + self.model.smart_annotations[path_key] = smart + self.model.current_json_path = file_path self.model.json_loaded = True @@ -196,6 +220,23 @@ def _write_json(self, save_path): if entry_labels: data_entry["labels"] = entry_labels + # [NEW] Write smart_labels parallel to manual labels + if path_key in self.model.smart_annotations: + smart_annots = self.model.smart_annotations[path_key] + # [MODIFIED] Only export if they were actually confirmed, and skip the internal flag + if smart_annots.get("_confirmed", False): + entry_smart_labels = {} + for head, data_dict in 
smart_annots.items(): + if head == "_confirmed": + continue # Skip the internal boolean flag to prevent TypeError + + entry_smart_labels[head] = { + "label": data_dict["label"], + "confidence": data_dict.get("conf_dict", {}).get(data_dict["label"], 1.0), + "conf_dict": data_dict.get("conf_dict", {}) + } + if entry_smart_labels: + data_entry["smart_labels"] = entry_smart_labels out["data"].append(data_entry) try: @@ -214,16 +255,28 @@ def create_new_project(self): """ Creates a blank project immediately, allowing the user to build the schema in the right-hand panel. + Now asks for SV/MV type before proceeding. """ + # Ask Single-View or Multi-View + from ui.common.dialogs import ClassificationTypeDialog + dialog = ClassificationTypeDialog(self.main) + + if not dialog.exec(): + return + # 1. Clear existing data (Full Reset) self._clear_workspace(full_reset=True) # 2. Initialize default "Blank Project" state in the Model self.model.current_task_name = "Untitled Task" self.model.modalities = ["video"] - self.model.label_definitions = {} # Empty Category (Category Editor start blank) + self.model.label_definitions = {} # Empty Category self.model.project_description = "" + # 2. Initialize default "Blank Project" state in the Model + # [MODIFIED] Changed from "Untitled Task" to "action_classification". + self.model.current_task_name = "action_classification" + # 3. 
Set flags to allow interaction self.model.json_loaded = True self.model.is_data_dirty = True @@ -246,7 +299,20 @@ def _clear_workspace(self, full_reset=False): self.model.reset(full_reset) self.main.update_save_export_button_state() + + # --- UI Resets --- self.ui.classification_ui.right_panel.manual_box.setEnabled(False) self.ui.classification_ui.center_panel.show_single_view(None) + + # [NEW] Explicitly reset the Smart Annotation UI (hide donut chart & batch results) + if hasattr(self.ui.classification_ui.right_panel, 'reset_smart_inference'): + self.ui.classification_ui.right_panel.reset_smart_inference() + + if hasattr(self.ui.classification_ui.right_panel, 'reset_train_ui'): + self.ui.classification_ui.right_panel.reset_train_ui() if full_reset: self.main.setup_dynamic_ui() + + # [NEW] Clear the Smart Annotation dropdowns when workspace is reset + if hasattr(self.main, 'sync_batch_inference_dropdowns'): + self.main.sync_batch_inference_dropdowns() \ No newline at end of file diff --git a/annotation_tool/controllers/classification/class_navigation_manager.py b/annotation_tool/controllers/classification/class_navigation_manager.py index 20021d0..6189cc8 100644 --- a/annotation_tool/controllers/classification/class_navigation_manager.py +++ b/annotation_tool/controllers/classification/class_navigation_manager.py @@ -36,7 +36,12 @@ def __init__(self, main_window): def add_items_via_dialog(self): """ Allows user to add video/image files to the project. + Smartly handles SV vs MV based on the loaded JSON flag. 
""" + from PyQt6.QtWidgets import QMessageBox, QFileDialog + import os + from collections import defaultdict + if not self.model.json_loaded: QMessageBox.warning(self.main, "Warning", "Please create or load a project first.") return @@ -51,24 +56,54 @@ def add_items_via_dialog(self): self.model.current_working_directory = os.path.dirname(files[0]) added_count = 0 - for file_path in files: - # Duplicate check - if any(d['path'] == file_path for d in self.model.action_item_data): - continue - - name = os.path.basename(file_path) - self.model.action_item_data.append({'name': name, 'path': file_path, 'source_files': [file_path]}) + + is_mv = getattr(self.model, 'is_multi_view', False) + + if is_mv: + grouped_files = defaultdict(list) - # [MV Fix] Add to Model directly - item = self.main.tree_model.add_entry(name, file_path, [file_path]) - self.model.action_item_map[file_path] = item - added_count += 1 + for file_path in files: + dir_name = os.path.dirname(file_path) + grouped_files[dir_name].append(file_path) + + for dir_path, paths in grouped_files.items(): + paths.sort() + + if len(paths) > 1: + name = os.path.basename(dir_path) + else: + name = os.path.basename(paths[0]) + + if any(d['name'] == name for d in self.model.action_item_data): + continue + + main_path = paths[0] + self.model.action_item_data.append({'name': name, 'path': main_path, 'source_files': paths}) + + item = self.main.tree_model.add_entry(name, main_path, paths) + self.model.action_item_map[main_path] = item + added_count += 1 + + else: + for file_path in files: + if any(d['path'] == file_path for d in self.model.action_item_data): + continue + + name = os.path.basename(file_path) + self.model.action_item_data.append({'name': name, 'path': file_path, 'source_files': [file_path]}) + + item = self.main.tree_model.add_entry(name, file_path, [file_path]) + self.model.action_item_map[file_path] = item + added_count += 1 if added_count > 0: self.model.is_data_dirty = True self.apply_action_filter() 
self.main.show_temp_msg("Added", f"Added {added_count} items.") + # [NEW] Force Smart Annotation dropdowns to update with the new videos + self.main.sync_batch_inference_dropdowns() + def remove_single_action_item(self, index: QModelIndex): """ Removes an item given its QModelIndex. @@ -92,6 +127,9 @@ def remove_single_action_item(self, index: QModelIndex): self.main.show_temp_msg("Removed", "Item removed.") self.main.update_save_export_button_state() + # [NEW] Force Smart Annotation dropdowns to update after deletion + self.main.sync_batch_inference_dropdowns() + def on_item_selected(self, current, previous): """ Called when the user clicks a different item in the left tree. @@ -136,24 +174,55 @@ def show_all_views(self): self.ui.classification_ui.center_panel.show_all_views([p for p in paths if p.lower().endswith(SUPPORTED_EXTENSIONS[:3])]) - def apply_action_filter(self): - """Filters the tree items based on Done/Not Done status using setRowHidden.""" - idx = self.ui.classification_ui.left_panel.filter_combo.currentIndex() - tree_view = self.ui.classification_ui.left_panel.tree + def apply_action_filter(self, index=None): + """ + Filter the tree based on 4 custom states for Classification. 
+ 0: Show All + 1: Hand Labelled (Has manual annotation) + 2: Smart Labelled (Has confirmed smart annotation) + 3: No Labelled (Neither hand nor smart confirmed) + """ + tree = self.ui.classification_ui.left_panel.tree + combo = self.ui.classification_ui.left_panel.filter_combo + + # Use the passed index from the signal, or the current combo box index + filter_idx = combo.currentIndex() if index is None else index + if filter_idx < 0: return + model = self.main.tree_model - root = model.invisibleRootItem() - for i in range(root.rowCount()): - item = root.child(i) - # We access data via the item (QStandardItem) or index + for row in range(model.rowCount()): + idx = model.index(row, 0) + item = model.itemFromIndex(idx) + if not item: continue + path = item.data(ProjectTreeModel.FilePathRole) - is_done = (path in self.model.manual_annotations and bool(self.model.manual_annotations[path])) - should_hide = False - if idx == self.main.FILTER_DONE and not is_done: should_hide = True - elif idx == self.main.FILTER_NOT_DONE and is_done: should_hide = True + # 1. Is it Hand Labelled? (Exists in manual_annotations) + is_hand_labelled = path in self.model.manual_annotations and bool(self.model.manual_annotations[path]) + + # 2. Is it Smart Labelled? (Has _confirmed flag in smart_annotations) + smart_data = self.model.smart_annotations.get(path, {}) + # [MODIFIED] Removed the mutually exclusive condition "and not is_hand_labelled". + # Now an item can be treated as both Hand Labelled and Smart Labelled simultaneously. + is_smart_labelled = smart_data.get("_confirmed", False) - tree_view.setRowHidden(i, QModelIndex(), should_hide) + # 3. No Labelled (Neither hand nor smart confirmed) + is_no_labelled = not is_hand_labelled and not is_smart_labelled + + # 4. 
Apply hiding logic based on the selected filter index + hidden = False + if filter_idx == 1 and not is_hand_labelled: + # Hide if "Hand Labelled" is selected but the item lacks hand labels + hidden = True + elif filter_idx == 2 and not is_smart_labelled: + # Hide if "Smart Labelled" is selected but the item lacks smart labels + hidden = True + elif filter_idx == 3 and not is_no_labelled: + # Hide if "No Labelled" is selected but the item has ANY label + hidden = True + + tree.setRowHidden(row, QModelIndex(), hidden) def nav_prev_action(self): self._nav_tree(step=-1, level='top') def nav_next_action(self): self._nav_tree(step=1, level='top') @@ -198,4 +267,4 @@ def _nav_tree(self, step, level): new_row = curr.row() + step if 0 <= new_row < model.rowCount(parent): nxt = model.index(new_row, 0, parent) - tree.setCurrentIndex(nxt); tree.scrollTo(nxt) \ No newline at end of file + tree.setCurrentIndex(nxt); tree.scrollTo(nxt) diff --git a/annotation_tool/controllers/classification/inference_manager.py b/annotation_tool/controllers/classification/inference_manager.py new file mode 100644 index 0000000..12eadac --- /dev/null +++ b/annotation_tool/controllers/classification/inference_manager.py @@ -0,0 +1,569 @@ +import os +import sys +import json +import glob +import ssl +import copy +import uuid +import re +import yaml +from models import CmdType +from PyQt6.QtCore import QThread, pyqtSignal, QObject +from PyQt6.QtWidgets import QMessageBox +from utils import natural_sort_key + +os.environ["WANDB_MODE"] = "disabled" +ssl._create_default_https_context = ssl._create_unverified_context + +from soccernetpro import model + + +def _run_soccernet_inference(base_config_path: str, temp_data: dict, prefix: str): + """ + [REFACTORED] A shared helper function to handle the repetitive setup, + execution, and cleanup of the soccernetpro inference process. + Used by both Single Inference and Batch Inference workers. 
+ """ + writable_dir = os.path.join(os.path.expanduser("~"), ".soccernet_workspace") + os.makedirs(writable_dir, exist_ok=True) + + writable_dir_fwd = writable_dir.replace('\\', '/') + logs_dir_fwd = os.path.join(writable_dir, "logs").replace('\\', '/') + + unique_id = uuid.uuid4().hex[:8] + temp_json_path = os.path.join(writable_dir, f"temp_{prefix}_{unique_id}.json") + temp_config_path = os.path.join(writable_dir, f"temp_config_{prefix}_{unique_id}.yaml") + + try: + # 1. Write the temporary JSON data + with open(temp_json_path, 'w', encoding='utf-8') as f: + json.dump(temp_data, f, indent=4) + + # 2. Read and modify the YAML config dynamically + with open(base_config_path, 'r', encoding='utf-8') as f: + config_text = f.read() + + config_text = config_text.replace('./temp_workspace', writable_dir_fwd) + config_text = config_text.replace('./logs', logs_dir_fwd) + + with open(temp_config_path, 'w', encoding='utf-8') as f: + f.write(config_text) + + # 3. Initialize model and run inference + myModel = model.classification(config=temp_config_path) + metrics = myModel.infer( + test_set=temp_json_path, + pretrained="jeetv/snpro-classification-mvit" + ) + + # 4. Search for the generated prediction output + checkpoint_dir = os.path.join(writable_dir, "checkpoints") + search_pattern = os.path.join(checkpoint_dir, "**", "predictions_test_epoch_*.json") + pred_files = glob.glob(search_pattern, recursive=True) + + if not pred_files: + raise FileNotFoundError("Could not find the generated prediction JSON file.") + + latest_pred_file = max(pred_files, key=os.path.getctime) + with open(latest_pred_file, 'r', encoding='utf-8') as pf: + pred_data = json.load(pf) + + return metrics if metrics else {}, pred_data + + finally: + # 5. 
Guaranteed cleanup of temporary payload files + if os.path.exists(temp_json_path): + try: os.remove(temp_json_path) + except: pass + if os.path.exists(temp_config_path): + try: os.remove(temp_config_path) + except: pass + + +class InferenceWorker(QThread): + finished_signal = pyqtSignal(str, str, dict) + error_signal = pyqtSignal(str) + + def __init__(self, config_path, base_dir, action_id, json_path, video_path, label_map): + super().__init__() + self.config_path = config_path + self.base_dir = base_dir + self.action_id = str(action_id) + self.json_path = json_path + self.video_path = video_path + + # [DYNAMIC] Assigned from config.yaml, no more hardcoding! + self.label_map = label_map + + def run(self): + try: + video_abs_path = self.video_path + if not os.path.isabs(video_abs_path): + if self.json_path and os.path.exists(self.json_path): + video_abs_path = os.path.join(os.path.dirname(self.json_path), self.video_path) + else: + video_abs_path = os.path.abspath(self.video_path) + + video_abs_path = os.path.normpath(video_abs_path).replace('\\', '/') + + if not os.path.exists(video_abs_path): + raise FileNotFoundError(f"Cannot find video file at absolute path:\n{video_abs_path}\nPlease ensure the file exists.") + + original_data = {} + target_item = None + + if self.json_path and os.path.exists(self.json_path): + with open(self.json_path, 'r', encoding='utf-8') as f: + original_data = json.load(f) + + for item in original_data.get("data", []): + if str(item.get("id")) == self.action_id: + target_item = copy.deepcopy(item) + break + + # Dynamic default fallback from the schema instead of hardcoded strings + default_label = list(self.label_map.values())[0] if self.label_map else "Unknown" + + if not target_item: + target_item = { + "id": self.action_id, + "inputs": [{"type": "video", "path": video_abs_path}], + "labels": { + "action": {"label": default_label, "confidence": 1.0} + } + } + else: + for inp in target_item.get("inputs", []): + inp["path"] = 
video_abs_path + if "type" not in inp: + inp["type"] = "video" + + if "labels" not in target_item: + target_item["labels"] = {} + if "action" not in target_item["labels"]: + target_item["labels"]["action"] = {"label": default_label} + + global_labels = original_data.get("labels", {}) + if not isinstance(global_labels, dict): + global_labels = {} + + if "action" not in global_labels: + global_labels["action"] = { + "type": "single_label", + "labels": list(self.label_map.values()) + } + + temp_data = { + "version": original_data.get("version", "2.0"), + "task": "classification", + "labels": global_labels, + "data": [target_item] + } + + # Use the shared helper function to run inference + metrics, pred_data = _run_soccernet_inference(self.config_path, temp_data, "infer") + + predicted_label_idx = None + confidence = 0.0 + raw_action_data = {} + + pred_items = pred_data.get("data", []) + + if len(pred_items) == 1: + raw_action_data = pred_items[0].get("labels", {}).get("action", {}) + if "label" in raw_action_data: + predicted_label_idx = str(raw_action_data["label"]).strip() + confidence = float(raw_action_data.get("confidence", 0.0)) + else: + clean_action_id = re.sub(r'_view\d+', '', self.action_id) + for item in pred_items: + out_id = str(item.get("id")) + if out_id == self.action_id or out_id == clean_action_id: + raw_action_data = item.get("labels", {}).get("action", {}) + if "label" in raw_action_data: + predicted_label_idx = str(raw_action_data["label"]).strip() + confidence = float(raw_action_data.get("confidence", 0.0)) + break + + if predicted_label_idx is None: + raise ValueError(f"Dataloader dropped the sample or prediction missing for ID '{self.action_id}'.") + + final_label = "Unknown" + valid_class_names = list(self.label_map.values()) + + if predicted_label_idx in valid_class_names: + final_label = predicted_label_idx + elif predicted_label_idx in self.label_map: + final_label = self.label_map[predicted_label_idx] + elif 
predicted_label_idx.endswith(".0"): + clean_idx = predicted_label_idx.replace(".0", "") + if clean_idx in self.label_map: + final_label = self.label_map[clean_idx] + + conf_dict = {} + if "confidences" in raw_action_data and isinstance(raw_action_data["confidences"], dict): + for k, v in raw_action_data["confidences"].items(): + key_name = self.label_map.get(str(k), str(k)) + conf_dict[key_name] = float(v) + else: + conf_dict[final_label] = confidence + remaining = max(0.0, 1.0 - confidence) + if remaining > 0.001: + conf_dict["Other Uncertainties"] = remaining + + self.finished_signal.emit("action", final_label, conf_dict) + + except Exception as e: + self.error_signal.emit(str(e)) + + +class BatchInferenceWorker(QThread): + finished_signal = pyqtSignal(dict, list) + error_signal = pyqtSignal(str) + + def __init__(self, config_path, base_dir, json_path, target_clips, label_map): + super().__init__() + self.config_path = config_path + self.base_dir = base_dir + self.json_path = json_path + self.target_clips = target_clips + + # [DYNAMIC] Load map from external source + self.label_map = label_map + + def _map_label(self, raw_label): + valid_class_names = list(self.label_map.values()) + if raw_label in valid_class_names: return raw_label + elif raw_label in self.label_map: return self.label_map[raw_label] + elif raw_label.endswith(".0"): + clean_idx = raw_label.replace(".0", "") + if clean_idx in self.label_map: return self.label_map[clean_idx] + return "Unknown" + + def run(self): + try: + data_items = [] + default_label = list(self.label_map.values())[0] if self.label_map else "Unknown" + + for clip in self.target_clips: + inputs = [] + for path in clip['paths']: + video_abs_path = path + if not os.path.isabs(video_abs_path): + if self.json_path and os.path.exists(self.json_path): + video_abs_path = os.path.join(os.path.dirname(self.json_path), video_abs_path) + else: + video_abs_path = os.path.abspath(video_abs_path) + video_abs_path = 
os.path.normpath(video_abs_path).replace('\\', '/') + inputs.append({"type": "video", "path": video_abs_path}) + + # Fallback to default label instead of hardcoded strings + safe_gt = clip['gt'] if clip['gt'] else default_label + + item = { + "id": clip['id'], + "inputs": inputs, + "labels": {"action": {"label": safe_gt, "confidence": 1.0}} + } + data_items.append(item) + + global_labels = { + "action": { + "type": "single_label", + "labels": list(self.label_map.values()) + } + } + + temp_data = { + "version": "2.0", + "task": "classification", + "labels": global_labels, + "data": data_items + } + + # Use the shared helper function to run inference + metrics, pred_data = _run_soccernet_inference(self.config_path, temp_data, "batch_infer") + + pred_items = pred_data.get("data", []) + out_dict = {} + for item in pred_items: + out_id = str(item.get("id")) + raw_action = item.get("labels", {}).get("action", {}) + raw_label = str(raw_action.get("label", "")).strip() + conf = float(raw_action.get("confidence", 0.0)) + out_dict[out_id] = (self._map_label(raw_label), conf) + + results = [] + for clip in self.target_clips: + aid = clip['id'] + clean_id = os.path.splitext(aid)[0] + + pred_label, conf = out_dict.get(aid, (None, 0.0)) + if pred_label is None: + pred_label, conf = out_dict.get(clean_id, ("Unknown", 0.0)) + + results.append({ + 'id': aid, + 'gt': clip['gt'], + 'pred': pred_label, + 'conf': conf, + 'original_items': clip['original_items'] + }) + + self.finished_signal.emit(metrics, results) + + except Exception as e: + self.error_signal.emit(str(e)) + + +class InferenceManager(QObject): + def __init__(self, main_window): + super().__init__() + self.main = main_window + self.ui = main_window.ui + + if hasattr(sys, '_MEIPASS'): + self.base_dir = sys._MEIPASS + else: + self.base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) + + self.config_path = os.path.join(self.base_dir, "config.yaml") + self.worker = None + self.batch_worker = None + 
+ self.ui.classification_ui.right_panel.batch_run_requested.connect(self.start_batch_inference) + self.ui.classification_ui.right_panel.batch_confirm_requested.connect(self.confirm_batch_inference) + + def _get_label_map_from_config(self) -> dict: + """ + [DYNAMIC PARSING] Reads the config.yaml on-the-fly to extract the classes list. + Prevents hardcoding so the framework scales effortlessly to new sports/models. + """ + label_map = {} + try: + with open(self.config_path, 'r', encoding='utf-8') as f: + config_data = yaml.safe_load(f) + + # Extract classes array safely from YAML structure + if config_data and 'DATA' in config_data and 'classes' in config_data['DATA']: + classes_list = config_data['DATA']['classes'] + for i, cls_name in enumerate(classes_list): + label_map[str(i)] = cls_name + except Exception as e: + print(f"Warning: Could not read classes from config.yaml dynamically: {e}") + + # Absolute failsafe if the user forgot to write `classes:` in their yaml + if not label_map: + label_map = { + '0': 'Challenge', '1': 'Dive', '2': 'Elbowing', '3': 'High leg', + '4': 'Holding', '5': 'Pushing', '6': 'Standing tackling', '7': 'Tackling' + } + + return label_map + + def start_inference(self): + if not os.path.exists(self.config_path): + QMessageBox.critical(self.main, "Error", f"config.yaml not found at:\n{self.config_path}") + return + + current_json_path = self.main.model.current_json_path + current_video_path = self.main.get_current_action_path() + if not current_video_path: + QMessageBox.warning(self.main, "Warning", "Please select an action/video from the list first.") + return + + action_id = self.main.model.action_path_to_name.get(current_video_path, os.path.basename(current_video_path)) + + self.ui.classification_ui.right_panel.show_inference_loading(True) + + # 1. Dynamically load labels from config + label_map = self._get_label_map_from_config() + + # 2. 
Pass labels to worker + self.worker = InferenceWorker(self.config_path, self.base_dir, action_id, current_json_path, current_video_path, label_map) + self.worker.finished_signal.connect(self._on_inference_success) + self.worker.error_signal.connect(self._on_inference_error) + self.worker.start() + + def _on_inference_success(self, target_head, label, conf_dict): + # Auto-create the schema (Category) if it's a completely blank/new project + if target_head not in self.main.model.label_definitions: + if self.worker: + # Use dynamically generated labels + default_labels = list(self.worker.label_map.values()) + self.main.model.label_definitions[target_head] = { + "type": "single_label", + "labels": sorted(default_labels) + } + # Force UI regeneration to display radio buttons + self.main.setup_dynamic_ui() + + # [NEW] Save raw inference result to smart_annotations memory + current_video_path = self.main.get_current_action_path() + # [NEW] Capture old state before overwriting + old_data = self.main.model.smart_annotations.get(current_video_path, {}) + new_data = { + target_head: {"label": label, "conf_dict": conf_dict} + } + + # [NEW] Push to Undo History + import copy + self.main.model.push_undo( + CmdType.SMART_ANNOTATION_RUN, + path=current_video_path, + old_data=copy.deepcopy(old_data), + new_data=copy.deepcopy(new_data) + ) + + # Save new data + if current_video_path not in self.main.model.smart_annotations: + self.main.model.smart_annotations[current_video_path] = {} + self.main.model.smart_annotations[current_video_path] = new_data + + self.main.model.is_data_dirty = True + self.ui.classification_ui.right_panel.display_inference_result(target_head, label, conf_dict) + self.worker = None + + def _on_inference_error(self, error_msg): + self.ui.classification_ui.right_panel.show_inference_loading(False) + QMessageBox.critical(self.main, "Inference Error", f"An error occurred during inference:\n\n{error_msg}") + self.worker = None + + def start_batch_inference(self, 
start_idx: int, end_idx: int): + if not os.path.exists(self.config_path): + QMessageBox.critical(self.main, "Error", f"config.yaml not found at:\n{self.config_path}") + return + + sorted_items = sorted(self.main.model.action_item_data, key=lambda x: natural_sort_key(x.get('name', ''))) + + action_groups = {} + for item in sorted_items: + base_id = re.sub(r'_view\d+', '', item['name']) + if base_id not in action_groups: + action_groups[base_id] = [] + action_groups[base_id].append(item) + + sorted_base_ids = list(action_groups.keys()) + max_idx = len(sorted_base_ids) - 1 + + if start_idx < 0 or end_idx > max_idx or start_idx > end_idx: + QMessageBox.warning(self.main, "Invalid Range", f"Please enter a valid range between 0 and {max_idx}.") + return + + target_base_ids = sorted_base_ids[start_idx : end_idx + 1] + + target_clips = [] + for base_id in target_base_ids: + items = action_groups[base_id] + paths = [it['path'] for it in items] + + # Extract current ground truth + gt_label = "" + for it in items: + ann = self.main.model.manual_annotations.get(it['path'], {}) + if 'action' in ann: + gt_label = ann['action'] + break + + target_clips.append({'id': base_id, 'paths': paths, 'gt': gt_label, 'original_items': items}) + + self.ui.classification_ui.right_panel.show_inference_loading(True) + + # 1. Dynamically load labels from config + label_map = self._get_label_map_from_config() + + # 2. 
Pass labels to batch worker + self.batch_worker = BatchInferenceWorker(self.config_path, self.base_dir, self.main.model.current_json_path, target_clips, label_map) + self.batch_worker.finished_signal.connect(self._on_batch_inference_success) + self.batch_worker.error_signal.connect(self._on_batch_inference_error) + self.batch_worker.start() + + def _on_batch_inference_success(self, metrics: dict, results_list: list): + # Auto-create the schema (Category) if it's a completely blank/new project + target_head = "action" + if target_head not in self.main.model.label_definitions: + if self.batch_worker: + default_labels = list(self.batch_worker.label_map.values()) + self.main.model.label_definitions[target_head] = { + "type": "single_label", + "labels": sorted(default_labels) + } + self.main.setup_dynamic_ui() + + # Start building the output text without the accuracy metrics + text = "BATCH INFERENCE PREDICTIONS:\n\n" + batch_predictions = {} + + old_batch_data = {} + new_batch_data = {} + import copy + + for r in results_list: + text += f"Video ID: {r['id']}\nPredicted Class: {r['pred']} (Confidence: {r['conf']*100:.1f}%)\n\n" + + for item in r['original_items']: + path = item['path'] + + # [NEW FIX 2] Store a rich dictionary instead of just a string! + # This ensures the Confidence is passed to the UI for the Donut Chart. 
+ conf_dict = {r['pred']: r['conf']} + if r['conf'] < 1.0: + conf_dict["Other Uncertainties"] = 1.0 - r['conf'] + + batch_predictions[path] = { + "label": r['pred'], + "confidence": r['conf'], + "conf_dict": conf_dict + } + + # Record old data for Undo + if path not in old_batch_data: + old_batch_data[path] = self.main.model.smart_annotations.get(path, {}) + + # Prepare new data for Redo + new_batch_data[path] = { + target_head: {"label": r['pred'], "conf_dict": conf_dict} + } + + # [NEW FIX 1] Push Batch to Undo History using CORRECT keys 'old_data' and 'new_data' + self.main.model.push_undo( + CmdType.BATCH_SMART_ANNOTATION_RUN, + old_data=copy.deepcopy(old_batch_data), + new_data=copy.deepcopy(new_batch_data) + ) + + # Apply new data to model + for path, data in new_batch_data.items(): + self.main.model.smart_annotations[path] = data + + self.main.model.is_data_dirty = True + self.ui.classification_ui.right_panel.display_batch_inference_result(text, batch_predictions) + self.batch_worker = None + + def _on_batch_inference_error(self, error_msg): + self.ui.classification_ui.right_panel.show_inference_loading(False) + QMessageBox.critical(self.main, "Batch Inference Error", f"An error occurred during batch inference:\n\n{error_msg}") + self.batch_worker = None + + def confirm_batch_inference(self, results: dict): + """ + [MODIFIED] Acknowledge batch inference without polluting Hand Annotations. + """ + applied_count = 0 + + # Smart annotations were already pushed to memory and Undo stack + # during _on_batch_inference_success. Here we just mark them as confirmed. 
+ for path, label in results.items(): + if path in self.main.model.smart_annotations: + # [NEW] Set a confirmed flag directly in smart memory + self.main.model.smart_annotations[path]["_confirmed"] = True + self.main.update_action_item_status(path) + applied_count += 1 + + # Update UI global states + if applied_count > 0: + self.main.model.is_data_dirty = True + self.main.update_save_export_button_state() + self.main.show_temp_msg("Batch Annotation", f"Confirmed {applied_count} smart annotations independently.") + else: + self.main.show_temp_msg("Batch Annotation", "No smart annotations to confirm.") diff --git a/annotation_tool/controllers/classification/train_manager.py b/annotation_tool/controllers/classification/train_manager.py new file mode 100644 index 0000000..7b69248 --- /dev/null +++ b/annotation_tool/controllers/classification/train_manager.py @@ -0,0 +1,320 @@ +import os +import sys +import json +import uuid +import yaml +import copy +import re +import io +import contextlib +from PyQt6.QtCore import QThread, pyqtSignal, QObject +from PyQt6.QtWidgets import QMessageBox + +# Assume the model invocation style is the same as inference +from soccernetpro import model + +class TrainWorker(QThread): + """ + Background training thread. + Supports printing checkpoint-save related information to the terminal + every fixed number of steps. 
+ """ + # Signal for appending plain log text to the UI console + log_signal = pyqtSignal(str) + # Signal for updating the progress bar percentage + progress_signal = pyqtSignal(int) + # Signal for updating the short training status label + status_msg_signal = pyqtSignal(str) + # Signal emitted when training ends: (success_flag, message) + finished_signal = pyqtSignal(bool, str) + + def __init__(self, config_path, train_params): + super().__init__() + # Path to the base YAML config template + self.config_path = config_path + # Training parameters collected from the UI + self.params = train_params + # Regex used to capture progress-style outputs such as "12/100 [" + self.progress_re = re.compile(r'(\d+)/(\d+)\s+\[') + + def run(self): + # Create a hidden workspace under the user's home directory + # to store temporary runtime config files + temp_workspace = os.path.join(os.path.expanduser("~"), ".soccernet_workspace") + os.makedirs(temp_workspace, exist_ok=True) + + # Generate a unique temporary config filename for this training run + unique_id = uuid.uuid4().hex[:6] + temp_config_path = os.path.join(temp_workspace, f"temp_train_config_{unique_id}.yaml") + + try: + # Load the original YAML config template + with open(self.config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + + # 1. Path setup + # Infer dataset root from the directory of the training annotation JSON + dataset_root = os.path.dirname(self.params['train_json']) + config['DATA']['data_dir'] = str(dataset_root).replace('\\', '/') + + # Create a checkpoint directory inside the dataset root + checkpoint_dir = os.path.join(dataset_root, "checkpoints") + os.makedirs(checkpoint_dir, exist_ok=True) + config['TRAIN']['save_dir'] = str(checkpoint_dir).replace('\\', '/') + + log_dir = os.path.join(dataset_root, "logs") + os.makedirs(log_dir, exist_ok=True) + + if 'SYSTEM' not in config: + config['SYSTEM'] = {} + config['SYSTEM']['log_dir'] = str(log_dir).replace('\\', '/') + + # 2. 
Structure adjustment + # Inject annotation paths into a custom annotations block + config['DATA']['annotations'] = { + 'train': str(self.params['train_json']), + 'valid': str(self.params['valid_json']) + } + + # 3. Update training hyperparameters from UI inputs + config['TRAIN']['epochs'] = int(self.params['epochs']) + config['TRAIN']['optimizer']['lr'] = float(self.params['lr']) + config['TRAIN']['save_every'] = 1 # Save by epoch + + # [NEW] Try to inject step-based checkpoint save options + # These keys depend on whether soccernetpro supports them internally + config['TRAIN']['save_step'] = 500 + config['TRAIN']['checkpoint_interval'] = 500 # Fallback compatibility key + + # Explicitly enable training mode + config['TRAIN']['enabled'] = True + # Set the device selected from the UI, e.g. "cpu" or "cuda" + config['SYSTEM']['device'] = str(self.params['device']) + + # Ensure the train block exists before modifying dataloader settings + if 'train' not in config['DATA']: + config['DATA']['train'] = {} + + # Ensure the nested dataloader block exists + config['DATA']['train']['dataloader'] = config['DATA'].get('train', {}).get('dataloader', {}) + # Update batch size and number of workers from UI + config['DATA']['train']['dataloader']['batch_size'] = int(self.params['batch']) + config['DATA']['train']['dataloader']['num_workers'] = int(self.params['workers']) + + # Write the runtime YAML config to a temporary file + with open(temp_config_path, 'w', encoding='utf-8') as f: + yaml.dump(config, f) + + # Notify UI that training initialization has started + self.log_signal.emit(f"๐Ÿš€ Initializing Training on {self.params['device']}...") + + # --- 4. Enhanced log interceptor --- + # This stream captures stdout/stderr from training, parses progress, + # forwards readable logs to the UI, and prints special checkpoint info + # directly to the VSCode terminal. 
+ class UILogStream(io.TextIOBase): + def __init__(self, outer_instance, cp_dir): + # Reference back to the outer TrainWorker instance + self.outer = outer_instance + # Directory where checkpoints are stored + self.cp_dir = cp_dir + # Best-effort current epoch string for UI status display + self.current_epoch_str = "Epoch ?/?" + # Internal line buffer for partial writes + self.line_buffer = "" + + def write(self, s): + # Accumulate incoming text because stdout/stderr may write in chunks + self.line_buffer += s + + # Process buffered content whenever a newline or carriage return appears + while '\n' in self.line_buffer or '\r' in self.line_buffer: + idx_n = self.line_buffer.find('\n') + idx_r = self.line_buffer.find('\r') + + # Split on the earliest line boundary + if idx_n != -1 and (idx_r == -1 or idx_n < idx_r): + line, self.line_buffer = self.line_buffer.split('\n', 1) + else: + line, self.line_buffer = self.line_buffer.split('\r', 1) + + # Remove leading/trailing spaces + clean_line = line.strip() + if not clean_line: + continue + + # Detect epoch lines and forward them to the UI log console + if "Epoch" in clean_line: + epoch_match = re.search(r'Epoch\s+\d+/\d+', clean_line) + if epoch_match: + self.current_epoch_str = epoch_match.group(0) + self.outer.log_signal.emit(clean_line) + + # Detect progress lines like "123/500 [" + match = self.outer.progress_re.search(clean_line) + if match: + curr_step = int(match.group(1)) + total_step = int(match.group(2)) + + # Compute percentage for the progress bar + percent = int((curr_step / total_step) * 100) + self.outer.progress_signal.emit(percent) + + # Update short status text in the UI + self.outer.status_msg_signal.emit( + f"{self.current_epoch_str} | Step: {curr_step}/{total_step}" + ) + + # [NEW] Every 500 steps, print an explicit message to the real terminal + # using sys.__stdout__ so it bypasses the redirection + if curr_step > 0 and curr_step % 500 == 0: + msg = ( + f"\n[VSCODE INFO] Iteration {curr_step} 
reached. " + f"Checkpoint auto-save triggered to: {self.cp_dir}\n" + ) + sys.__stdout__.write(msg) + sys.__stdout__.flush() + else: + # Forward non-progress, non-noisy lines to the UI log panel + if "Training:" not in clean_line and "|" not in clean_line: + self.outer.log_signal.emit(clean_line) + + return len(s) + + # 5. Start training + # Build the classification model using the temporary runtime config + myModel = model.classification(config=temp_config_path) + + # Redirect stdout and stderr from the training process into the custom UI stream + log_stream = UILogStream(self, checkpoint_dir) + with contextlib.ExitStack() as stack: + stack.enter_context(contextlib.redirect_stdout(log_stream)) + stack.enter_context(contextlib.redirect_stderr(log_stream)) + # Launch training + myModel.train() + + # Notify UI that training completed successfully + self.finished_signal.emit(True, f"Training Completed Successfully.\nCheckpoints: {checkpoint_dir}") + + except Exception as e: + # Print traceback to the real console for debugging + import traceback + traceback.print_exc() + # Notify UI that training failed + self.finished_signal.emit(False, str(e)) + finally: + # Always try to remove the temporary config file after training ends + if os.path.exists(temp_config_path): + try: + os.remove(temp_config_path) + except: + pass + +class TrainManager(QObject): + def __init__(self, main_window): + super().__init__() + # Reference to the main window + self.main = main_window + # Shortcut to the classification UI panel on the right side + self.ui = main_window.ui.classification_ui.right_panel + # Background worker thread instance + self.worker = None + + # Resolve base directory differently for bundled app vs source execution + if hasattr(sys, '_MEIPASS'): + self.base_dir = sys._MEIPASS + else: + self.base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) + + # Path to the base classification config file + self.config_path = os.path.join(self.base_dir, 
"config.yaml") + + # Connect UI buttons to start/stop handlers + self.ui.btn_start_train.clicked.connect(self.start_training) + self.ui.btn_stop_train.clicked.connect(self.stop_training) + + def start_training(self): + # Prevent launching a second training job while one is already running + if self.worker and self.worker.isRunning(): + return + + # Get the currently loaded JSON annotation file from the main model + train_json = self.main.model.current_json_path + + # Require that the currently loaded file is the training annotation file + if not train_json or "annotations_train" not in train_json: + QMessageBox.critical(self.main, "Error", "Please load 'annotations_train.json' first.") + return + + # Collect training parameters from the UI controls + params = { + "epochs": self.ui.spin_epochs.currentText(), + "lr": self.ui.edit_lr.text(), + "batch": self.ui.spin_batch.currentText(), + # Extract only the raw device token from combo text, e.g. "cuda (GPU)" -> "cuda" + "device": self.ui.combo_device.currentText().split(" ")[0], + "workers": self.ui.spin_workers.currentText(), + "train_json": train_json, + # Infer the validation annotation path by filename replacement + "valid_json": train_json.replace("annotations_train.json", "annotations_valid.json") + } + + # Prepare UI state for an active training session + self.ui.btn_start_train.setEnabled(False) + self.ui.btn_stop_train.setEnabled(True) + self.ui.train_progress.setVisible(True) + self.ui.train_progress.setValue(0) + self.ui.lbl_train_status.setVisible(True) + self.ui.lbl_train_status.setText("๐Ÿš€ Starting Training Loop...") + self.ui.train_console.clear() + + # Create and wire up the training worker thread + self.worker = TrainWorker(self.config_path, params) + self.worker.log_signal.connect(self._append_log) + self.worker.progress_signal.connect(self.ui.train_progress.setValue) + self.worker.status_msg_signal.connect(self.ui.lbl_train_status.setText) + 
self.worker.finished_signal.connect(self._on_train_finished) + + # Start background training + self.worker.start() + + def _append_log(self, text): + # Append a line of log text to the training console widget + self.ui.train_console.append(text) + + def stop_training(self): + """Force-stop the training thread.""" + if self.worker and self.worker.isRunning(): + # Ask for user confirmation before aborting training + reply = QMessageBox.question( + self.main, "Confirm Stop", + "Are you sure you want to abort training?\nUnsaved progress in the current epoch will be lost.", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No + ) + + if reply == QMessageBox.StandardButton.Yes: + # Update UI to reflect aborting state + self.ui.btn_stop_train.setEnabled(False) + self.ui.lbl_train_status.setText("๐Ÿ›‘ Aborting...") + + # Forcefully terminate the worker thread + self.worker.terminate() + self.worker.wait() + + # Reuse the finish handler with a manual-abort message + self._on_train_finished(False, "Training was manually aborted by user.") + + def _on_train_finished(self, success, message): + # Restore UI controls after training ends + self.ui.btn_start_train.setEnabled(True) + self.ui.btn_stop_train.setEnabled(False) + self.ui.train_progress.setVisible(False) + self.ui.lbl_train_status.setVisible(False) + + # Show result feedback and append a final log line + if success: + QMessageBox.information(self.main, "Success", message) + self._append_log(f"\nโœ… [SUCCESS] {message}") + else: + QMessageBox.critical(self.main, "Train Error", message) + self._append_log(f"\nโŒ [ERROR] {message}") diff --git a/annotation_tool/controllers/description/.DS_Store b/annotation_tool/controllers/description/.DS_Store deleted file mode 100644 index 5008ddf..0000000 Binary files a/annotation_tool/controllers/description/.DS_Store and /dev/null differ diff --git a/annotation_tool/controllers/history_manager.py b/annotation_tool/controllers/history_manager.py index 2a126d3..ab684d1 
100644 --- a/annotation_tool/controllers/history_manager.py +++ b/annotation_tool/controllers/history_manager.py @@ -1,6 +1,8 @@ import copy from models import CmdType from ui.classification.event_editor import DynamicSingleLabelGroup, DynamicMultiLabelGroup +import copy +from models.app_state import CmdType class HistoryManager: """ @@ -80,9 +82,7 @@ def _refresh_active_view(self): def _apply_state_change(self, cmd, is_undo): ctype = cmd['type'] - # ========================================================= # 1. Classification Specific - # ========================================================= if ctype == CmdType.ANNOTATION_CONFIRM: path = cmd['path'] data = cmd['old_data'] if is_undo else cmd['new_data'] @@ -90,6 +90,56 @@ def _apply_state_change(self, cmd, is_undo): if path in self.model.manual_annotations: del self.model.manual_annotations[path] else: self.model.manual_annotations[path] = copy.deepcopy(data) self.main.refresh_ui_after_undo_redo(path) + + # [NEW] Handle batch annotation confirm + elif ctype == CmdType.BATCH_ANNOTATION_CONFIRM: + batch_changes = cmd['batch_changes'] # Retrieve the packed dictionary + + # Loop through every video that was modified in this batch + for path, changes in batch_changes.items(): + data = changes['old_data'] if is_undo else changes['new_data'] + + # Apply the data + if data: + self.model.manual_annotations[path] = copy.deepcopy(data) + else: + if path in self.model.manual_annotations: + del self.model.manual_annotations[path] + + # Update the checkmark status in the Tree UI for this video + self.main.update_action_item_status(path) + + # Refresh the right panel if the currently selected item was affected + self._refresh_active_view() + + # [NEW] Handle single smart annotation run (Donut Chart) + elif ctype == CmdType.SMART_ANNOTATION_RUN: + path = cmd['path'] + data = cmd['old_data'] if is_undo else cmd['new_data'] + + if data: + self.model.smart_annotations[path] = copy.deepcopy(data) + else: + if path in 
self.model.smart_annotations: + del self.model.smart_annotations[path] + + # Refresh the UI to immediately show or hide the Donut Chart + self._refresh_active_view() + + # [NEW] Handle batch smart annotation run + elif ctype == CmdType.BATCH_SMART_ANNOTATION_RUN: + batch_data = cmd['old_data'] if is_undo else cmd['new_data'] + + for path, data in batch_data.items(): + if data: + self.model.smart_annotations[path] = copy.deepcopy(data) + else: + if path in self.model.smart_annotations: + del self.model.smart_annotations[path] + + # Refresh the UI to reflect batch smart annotations + self._refresh_active_view() + elif ctype == CmdType.UI_CHANGE: path = cmd['path'] @@ -100,6 +150,8 @@ def _apply_state_change(self, cmd, is_undo): if isinstance(grp, DynamicSingleLabelGroup): grp.set_checked_label(val) else: grp.set_checked_labels(val) + + # ========================================================= # 2. Localization Specific (Events) # ========================================================= @@ -378,4 +430,4 @@ def _apply_state_change(self, cmd, is_undo): if evt.get('head') == head and evt.get('label') == src: evt['label'] = dst - self._refresh_active_view() \ No newline at end of file + self._refresh_active_view() diff --git a/annotation_tool/controllers/localization/loc_file_manager.py b/annotation_tool/controllers/localization/loc_file_manager.py index 952c71f..ff28441 100644 --- a/annotation_tool/controllers/localization/loc_file_manager.py +++ b/annotation_tool/controllers/localization/loc_file_manager.py @@ -299,6 +299,30 @@ def _write_json(self, path): except Exception as e: QMessageBox.critical(self.main, "Error", f"Save failed: {e}") return False + + for video_path in sorted(self.model.localization_events.keys()): + # ่Žทๅ–่ฏฅ่ง†้ข‘ๆ‰€ๅฑž็š„ๅŽŸๅง‹ item ๅฎšไน‰๏ผˆๅŒ…ๅซ inputs ่ง†้ข‘ๆบไฟกๆฏ๏ผ‰ + base_item = next((item for item in self.model.action_item_data if item["path"] == video_path), None) + if not base_item: continue + + # 1. 
่Žทๅ–ๆ‰‹ๅทฅ๏ผˆๆˆ–ๅทฒ็กฎ่ฎค็š„๏ผ‰ๆ ‡ๆณจ + manual_events = self.model.localization_events.get(video_path, []) + + # 2. ่Žทๅ–ๆœช็กฎ่ฎค็š„ๆ™บ่ƒฝๆ ‡ๆณจ + smart_events = self.model.smart_localization_events.get(video_path, []) + + # ๆž„ๅปบ็ฌฆๅˆ OSL ๆ ‡ๅ‡†่ง„่Œƒ็š„ๅ•ๆกๆ•ฐๆฎ็ป“ๆž„ + out_item = { + "id": base_item.get("id", ""), + "inputs": [{"path": f, "type": "video"} for f in base_item.get("source_files", [video_path])], + "events": manual_events + } + + # ้ตๅพชๅŽŸๅง‹็ป“ๆž„ๆทปๅŠ  smart_events ๅญ—ๆฎต๏ผˆๅฆ‚ๆžœๆœ‰็š„่ฏ๏ผ‰ + if smart_events: + out_item["smart_events"] = smart_events + + items.append(out_item) def _clear_workspace(self, full_reset=False): """ diff --git a/annotation_tool/controllers/localization/loc_inference.py b/annotation_tool/controllers/localization/loc_inference.py new file mode 100644 index 0000000..6ef4a1f --- /dev/null +++ b/annotation_tool/controllers/localization/loc_inference.py @@ -0,0 +1,160 @@ +import os +import json +import tempfile +import yaml +import glob +from PyQt6.QtCore import QObject, QThread, pyqtSignal + +class LocInferenceWorker(QThread): + """ + Background worker for running OpenSportsLib Localization inference. + Dynamically patches config for CPU usage (Mac M1/M2 compatibility). + """ + finished_signal = pyqtSignal(list) + error_signal = pyqtSignal(str) + + def __init__(self, video_path, start_ms, end_ms, config_path): + super().__init__() + self.video_path = os.path.abspath(video_path) + self.start_ms = start_ms + self.end_ms = end_ms + self.config_path = config_path + + def run(self): + try: + # Import library inside thread to avoid blocking main thread at startup + from opensportslib import model + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_input_json = os.path.join(tmp_dir, "temp_test.json") + tmp_config_yaml = os.path.join(tmp_dir, "temp_config.yaml") + tmp_output_json = os.path.join(tmp_dir, "predictions.json") + + # --- 1. 
Load and dynamically patch the YAML config --- + with open(self.config_path, 'r', encoding='utf-8') as f: + config_dict = yaml.safe_load(f) + + classes = config_dict.get("DATA", {}).get("classes", []) + + # ๐Ÿš€ [MAC CPU ADAPTATION & PATH FIXES] ๐Ÿš€ + # Force CPU mode and disable Multi-GPU dynamically + if "SYSTEM" not in config_dict: config_dict["SYSTEM"] = {} + config_dict["SYSTEM"]["work_dir"] = tmp_dir + config_dict["SYSTEM"]["device"] = "cpu" + config_dict["SYSTEM"]["GPU"] = 0 + config_dict["SYSTEM"]["gpu_id"] = 0 + + if "MODEL" not in config_dict: config_dict["MODEL"] = {} + config_dict["MODEL"]["multi_gpu"] = False + + # Override dataloader paths for test + if "DATA" in config_dict and "test" in config_dict["DATA"]: + config_dict["DATA"]["test"]["video_path"] = os.path.dirname(self.video_path) + config_dict["DATA"]["test"]["path"] = tmp_input_json + config_dict["DATA"]["test"]["results"] = "predictions" + + with open(tmp_config_yaml, 'w', encoding='utf-8') as f: + yaml.dump(config_dict, f) + + # --- 2. Create temporary JSON for the single video --- + test_data = { + "version": "2.0", + "task": "action_spotting", + "labels": {"ball_action": {"type": "single_label", "labels": classes}}, + "data": [{ + "id": "inf_vid", + "inputs": [{"path": self.video_path, "type": "video", "fps": 25.0}], + # ๅฟ…้กปๆ”พไธ€ไธช Dummy event ้ช—่ฟ‡ DataLoader + "events": [{"head": "ball_action", "label": classes[0] if classes else "Unknown", "position_ms": 0}] + }] + } + with open(tmp_input_json, 'w', encoding='utf-8') as f: + json.dump(test_data, f) + + # --- 3. 
Execute model inference --- + loc_model = model.localization(config=tmp_config_yaml) + + try: + # ่ฟ่กŒๆŽจ็†ใ€‚่ฟ™้‡Œไธ€ๅฎšไผšๆŠ›ๅ‡บ FileNotFoundError๏ผŒๅ› ไธบๆก†ๆžถๅบ•ๅฑ‚็š„่ฏ„ไผฐๅ™จๆ‰พไธๅˆฐๆ–‡ไปถ + loc_model.infer( + test_set=tmp_input_json, + pretrained="jeetv/snpro-snbas-2024" + ) + except FileNotFoundError: + # [ๅ…ณ้”ฎไฟฎๅค 4]๏ผš้œธๆฐ”ๅฟฝ็•ฅ๏ผ + # ๅ› ไธบๆŠฅ้”™ๅ‘็”ŸๅœจๆŽจ็†ๅฎŒๆˆไน‹ๅŽ็š„โ€œ่ฏ„ไผฐ้˜ถๆฎตโ€๏ผŒๆ‰€ไปฅๆˆ‘ไปฌ็›ดๆŽฅ catch ๆމ่ฟ™ไธช้”™่ฏฏ๏ผŒ + # ๅ‡่ฃ…ๆ— ไบ‹ๅ‘็”Ÿ๏ผŒ็›ดๆŽฅ่ฟ›ๅ…ฅไธ‹ไธ€ๆญฅๅŽปๆทฑๅฑ‚ๆ–‡ไปถๅคน้‡Œๆž็”Ÿๆˆ็š„ JSONใ€‚ + pass + + # --- 4. Parse result JSON --- + # ้€’ๅฝ’ๆœ็ดขไธดๆ—ถๆ–‡ไปถๅคนไธ‹็š„ๆ‰€ๆœ‰ .json ๆ–‡ไปถ๏ผˆๅฎŒ็พŽ็ฉฟ้€ checkpoints/xxx ๅตŒๅฅ—ๆ–‡ไปถๅคน๏ผ‰ + search_pattern = os.path.join(tmp_dir, "**", "*.json") + all_jsons = glob.glob(search_pattern, recursive=True) + + valid_preds = [] + for f in all_jsons: + filename = os.path.basename(f) + # ๆŽ’้™คๆމๆˆ‘ไปฌ่‡ชๅทฑ็”Ÿๆˆ็š„่พ“ๅ…ฅๆ•ฐๆฎๅ’Œ้…็ฝฎๆ–‡ไปถ + if "temp_test" not in filename and "temp_config" not in filename: + valid_preds.append(f) + + if valid_preds: + # ๆ‰พๅˆฐๆœ€ๆ–ฐ็”Ÿๆˆ็š„้‚ฃไธ€ไธช๏ผˆ้˜ฒๆญขๆœ‰ๅคšไธชๆ—งๆ–‡ไปถๅนฒๆ‰ฐ๏ผ‰ + actual_output_json = max(valid_preds, key=os.path.getctime) + else: + raise FileNotFoundError(f"Could not find any generated prediction JSON in {tmp_dir}/checkpoints/") + + predicted_events = [] + if os.path.exists(actual_output_json): + with open(actual_output_json, 'r', encoding='utf-8') as f: + output_data = json.load(f) + + raw_evts = output_data.get("data", [{}])[0].get("events", []) + for evt in raw_evts: + p_ms = int(evt.get("position_ms", 0)) + + if p_ms == 0 and evt.get("label") == (classes[0] if classes else "Unknown"): + continue + + if p_ms >= self.start_ms and (self.end_ms == 0 or p_ms <= self.end_ms): + predicted_events.append({ + "head": evt.get("head", "ball_action"), + "label": evt.get("label", "Unknown"), + "position_ms": p_ms + }) + + self.finished_signal.emit(predicted_events) + + except Exception as e: + import traceback + 
traceback.print_exc() + self.error_signal.emit(str(e)) + + +class LocalizationInferenceManager(QObject): + """ + High-level controller that manages the inference thread lifecycle. + """ + inference_finished = pyqtSignal(list) + inference_error = pyqtSignal(str) + + def __init__(self, main_window): + super().__init__(main_window) + self.main = main_window + self.worker = None + + def start_inference(self, video_path: str, start_ms: int, end_ms: int): + if self.worker and self.worker.isRunning(): return + config_path = os.path.join(os.getcwd(), "loc_config.yaml") + self.worker = LocInferenceWorker(video_path, start_ms, end_ms, config_path) + self.worker.finished_signal.connect(self._on_finished) + self.worker.error_signal.connect(self._on_error) + self.worker.start() + + def _on_finished(self, events): + self.inference_finished.emit(events) + self.worker = None + + def _on_error(self, err_msg): + self.inference_error.emit(err_msg) + self.worker = None \ No newline at end of file diff --git a/annotation_tool/controllers/localization/localization_manager.py b/annotation_tool/controllers/localization/localization_manager.py index 9ef7461..ef15725 100644 --- a/annotation_tool/controllers/localization/localization_manager.py +++ b/annotation_tool/controllers/localization/localization_manager.py @@ -9,6 +9,7 @@ from models import CmdType # [NEW] Import the unified MediaController from controllers.media_controller import MediaController +from .loc_inference import LocalizationInferenceManager class LocalizationManager: """ @@ -24,6 +25,10 @@ def __init__(self, main_window): self.left_panel = self.ui_root.left_panel self.center_panel = self.ui_root.center_panel self.right_panel = self.ui_root.right_panel + + self.inference_manager = LocalizationInferenceManager(self.main) + self.inference_manager.inference_finished.connect(self._on_inference_success) + self.inference_manager.inference_error.connect(self._on_inference_error) # [NEW] Initialize Media Controller # We access the 
underlying QMediaPlayer from the UI wrapper @@ -62,7 +67,7 @@ def setup_connections(self): media.durationChanged.connect(timeline.set_duration) timeline.seekRequested.connect(media.set_position) - # [CHANGED] Use MediaController for playback control + # Use MediaController for playback control pb.stopRequested.connect(self.media_controller.stop) pb.playPauseRequested.connect(self.media_controller.toggle_play_pause) @@ -73,6 +78,18 @@ def setup_connections(self): pb.nextPrevAnnotRequested.connect(self._navigate_annotation) # --- Right Panel --- + #Smart Annotation UI + if hasattr(self.right_panel, 'smart_widget'): + smart_ui = self.right_panel.smart_widget + smart_ui.setTimeRequested.connect(self._on_smart_set_time) + smart_ui.runInferenceRequested.connect(self._run_localization_inference) + smart_ui.confirmSmartRequested.connect(self._confirm_smart_events) + smart_ui.clearSmartRequested.connect(self._clear_smart_events) + + # Tab switch to toggle timeline markers + self.right_panel.tabs.currentChanged.connect(self._on_tab_switched) + + tabs = self.right_panel.annot_mgmt.tabs table = self.right_panel.table @@ -90,11 +107,32 @@ def setup_connections(self): table.annotationDeleted.connect(self._on_delete_single_annotation) table.annotationModified.connect(self._on_annotation_modified) + table.updateTimeForSelectedRequested.connect(self._on_update_time_for_selected) + def _on_media_position_changed(self, ms): self.center_panel.timeline.set_position(ms) time_str = self._fmt_ms_full(ms) self.right_panel.annot_mgmt.tabs.update_current_time(time_str) + def _on_update_time_for_selected(self, old_event): + """ + Handles the logic when the user clicks the + 'Set to Current Video Time' button. + """ + if not self.current_video_path: + return + + # 1. Get the current playback position in milliseconds + current_ms = self.center_panel.media_preview.player.position() + + # 2. 
Copy the old event and update its timestamp + new_event = old_event.copy() + new_event['position_ms'] = current_ms + + # 3. Reuse the existing modification logic + self._on_annotation_modified(old_event, new_event) + + # --- Video Loading Logic (Strict Classification Style via Controller) --- def on_clip_selected(self, current_idx, previous_idx): if not current_idx.isValid(): @@ -291,6 +329,8 @@ def _on_spotting_triggered(self, head, label): self.main.show_temp_msg("Event Created", f"{head}: {label}") self.main.update_save_export_button_state() + self._reselect_event(new_event) + # --- Table Modification --- def _on_annotation_modified(self, old_event, new_event): events = self.model.localization_events.get(self.current_video_path, []) @@ -318,6 +358,8 @@ def _on_annotation_modified(self, old_event, new_event): self.main.show_temp_msg("Event Updated", "Modified") self.main.update_save_export_button_state() + self._reselect_event(new_event) + def _on_delete_single_annotation(self, item_data): events = self.model.localization_events.get(self.current_video_path, []) if item_data not in events: return @@ -482,8 +524,144 @@ def _select_row_by_time(self, time_ms): self.right_panel.table.table.scrollTo(idx) break + def _reselect_event(self, target_event): + model = self.right_panel.table.model + table_view = self.right_panel.table.table + + table_view.selectionModel().blockSignals(True) + + for row in range(model.rowCount()): + item = model.get_annotation_at(row) + if not item: continue + + if (item.get('position_ms') == target_event.get('position_ms') and + item.get('head') == target_event.get('head') and + item.get('label') == target_event.get('label')): + + idx = model.index(row, 0) + + table_view.selectRow(row) + table_view.scrollTo(idx) + + if hasattr(self.right_panel.table, 'btn_set_time'): + self.right_panel.table.btn_set_time.setEnabled(True) + + break + + table_view.selectionModel().blockSignals(False) + def _fmt_ms_full(self, ms): s = ms // 1000 m = s // 60 h = 
m // 60 return f"{h:02}:{m%60:02}:{s%60:02}.{ms%1000:03}" + + + # ========================================== + # --- Smart Annotation Control Logic --- + # ========================================== + + def _on_smart_set_time(self, target: str): + """ + Triggered when 'Set to Current' is clicked in Smart Spotting UI. + Gets current player position and updates the smart UI. + """ + player = self.center_panel.media_preview.player + current_ms = player.position() + time_str = self._fmt_ms_full(current_ms) + + # Update the UI display and internal state in the Smart Widget + self.right_panel.smart_widget.update_time_display(target, time_str, current_ms) + + + def _run_localization_inference(self, start_ms: int, end_ms: int): + if not self.current_video_path: + return + if start_ms >= end_ms and end_ms != 0: + from PyQt6.QtWidgets import QMessageBox + QMessageBox.warning(self.main, "Invalid Range", "End time must be greater than Start time.") + return + + self.main.show_temp_msg("Smart Inference", "Running OpenSportsLib Localization Model...") + self.right_panel.smart_widget.btn_run_infer.setEnabled(False) + self.inference_manager.start_inference(self.current_video_path, start_ms, end_ms) + + + def _on_inference_success(self, predicted_events: list): + self.right_panel.smart_widget.btn_run_infer.setEnabled(True) + if not self.current_video_path: + return + + self.model.smart_localization_events[self.current_video_path] = predicted_events + self.main.show_temp_msg("Smart Inference", f"Success: Found {len(predicted_events)} events.") + + if self.right_panel.tabs.currentIndex() == 1: + self._display_smart_events(self.current_video_path) + + def _on_inference_error(self, error_msg: str): + self.right_panel.smart_widget.btn_run_infer.setEnabled(True) + from PyQt6.QtWidgets import QMessageBox + QMessageBox.critical(self.main, "Inference Error", f"Failed to run model:\n{error_msg}") + + def _confirm_smart_events(self): + """ๅฐ†ๆ™บ่ƒฝ้ข„ๆต‹ๅˆๅนถๅˆฐๆ‰‹ๅทฅๆ ‡ๆณจไธญ""" + if not 
self.current_video_path: + return + + smart_events = self.model.smart_localization_events.get(self.current_video_path, []) + if not smart_events: + return + + # ๅˆๅง‹ๅŒ–ๅฝ“ๅ‰่ง†้ข‘็š„ๆ‰‹ๅทฅไบ‹ไปถๅˆ—่กจ๏ผˆๅฆ‚ๆžœๆฒกๆœ‰๏ผ‰ + if self.current_video_path not in self.model.localization_events: + self.model.localization_events[self.current_video_path] = [] + + # ๅˆๅนถไบ‹ไปถ (ๆญคๅค„ๆš‚ไธๅค„็† undo/redo) + self.model.localization_events[self.current_video_path].extend(smart_events) + + # ๆŒ‰็…งๆ—ถ้—ดๆŽ’ๅบ + self.model.localization_events[self.current_video_path].sort(key=lambda x: x.get('position_ms', 0)) + + # ๆธ…็ฉบๅฝ“ๅ‰็š„ Smart Events + self.model.smart_localization_events[self.current_video_path] = [] + self._display_smart_events(self.current_video_path) # ๅˆทๆ–ฐไธบ็ฉบ่กจ + + # ๆ็คบ็”จๆˆท + self.main.show_temp_msg("Smart Spotting", "Predictions confirmed and merged into Hand Annotations.") + self.model.is_data_dirty = True + self.main.update_save_export_button_state() + + def _clear_smart_events(self): + if not self.current_video_path: + return + self.model.smart_localization_events[self.current_video_path] = [] + self._display_smart_events(self.current_video_path) + self.main.show_temp_msg("Smart Spotting", "Cleared smart predictions.") + + def _display_smart_events(self, video_path: str): + """Dedicated method to display ONLY smart events in the smart table and timeline.""" + events = self.model.smart_localization_events.get(video_path, []) + # ๆ›ดๆ–ฐ Smart Table + self.right_panel.smart_widget.smart_table.set_data(events) + # ๆ›ดๆ–ฐ Timeline + markers = [] + for evt in events: + # Smart events ไนŸๅฏไปฅไฝฟ็”จไธๅŒ็š„้ขœ่‰ฒ๏ผŒๆฏ”ๅฆ‚่“่‰ฒ๏ผŒ็”จๆฅๅ’Œๆ‰‹ๅทฅๆ ‡ๆณจ๏ผˆ็บข่‰ฒ๏ผ‰ๅŒบๅˆ† + from PyQt6.QtGui import QColor + markers.append({ + 'start_ms': evt.get('position_ms', 0), + 'color': QColor('deepskyblue') + }) + self.center_panel.timeline.set_markers(markers) + + def _on_tab_switched(self, index: int): + """ๅˆ‡ๆข Tab ๆ—ถ้š”็ฆป่ง†่ง‰็Šถๆ€""" + if not 
self.current_video_path: + return + + if index == 0: + # ๅ›žๅˆฐๆ‰‹ๅทฅๆ ‡ๆณจ๏ผŒๅŠ ่ฝฝๅŽŸๅง‹็š„ๆ‰‹ๅทฅไบ‹ไปถ + self._display_events_for_item(self.current_video_path) + elif index == 1: + # ๅŽปๅˆฐๆ™บ่ƒฝๆ ‡ๆณจ๏ผŒๅŠ ่ฝฝๆ™บ่ƒฝไบ‹ไปถ + self._display_smart_events(self.current_video_path) diff --git a/annotation_tool/controllers/media_controller.py b/annotation_tool/controllers/media_controller.py index f3a3907..131775a 100644 --- a/annotation_tool/controllers/media_controller.py +++ b/annotation_tool/controllers/media_controller.py @@ -5,73 +5,122 @@ class MediaController(QObject): """ A unified controller for managing video playback logic across all modes. - Now handles: - 1. Robust Playback State (Stop -> Clear -> Load -> Delay -> Play) - 2. Timer Cancellation (Prevents race conditions on rapid switching) - 3. Visual Clearing (Forces VideoWidget to refresh) + Now includes a 'Watchdog' mechanism to catch silent hardware decoder failures + (e.g., AV1 video fails, but Audio keeps playing causing a zombie black screen). """ def __init__(self, player: QMediaPlayer, video_widget: QWidget = None): super().__init__() self.player = player self.video_widget = video_widget - # [CRITICAL FIX] Use an instance timer so we can cancel it! - # This prevents the "Ghost Timer" bug where a video starts playing - # *after* the user has closed the project or switched modes. + # 1. Connect standard error signals + self.player.errorOccurred.connect(self._handle_media_error) + + # 2. Setup Play Timer self.play_timer = QTimer() self.play_timer.setSingleShot(True) - self.play_timer.setInterval(150) # 150ms delay + self.play_timer.setInterval(150) self.play_timer.timeout.connect(self._execute_play) - def load_and_play(self, file_path: str, auto_play: bool = True): + # 3. 
[NEW] Setup Watchdog Timer to catch silent Black Screens + self.watchdog_timer = QTimer() + self.watchdog_timer.setSingleShot(True) + self.watchdog_timer.setInterval(1500) # Check 1.5 seconds after play starts + self.watchdog_timer.timeout.connect(self._check_for_black_screen) + + # 4. [NEW] Monitor actual frames being drawn to the screen + self._frame_received = False + if self.video_widget and hasattr(self.video_widget, 'videoSink'): + sink = self.video_widget.videoSink() + if sink: + # Every time a pixel frame is actually rendered, this triggers + sink.videoFrameChanged.connect(self._on_frame_rendered) + + def _on_frame_rendered(self, *args): + """Marks that the GPU successfully decoded and drew at least one frame.""" + self._frame_received = True + + def _trigger_error_dialog(self, error_details: str): + """Stops playback immediately and blocks the UI with an error dialog.""" + self.stop() # Force kill playback + + try: + from ui.common.dialogs import MediaErrorDialog + error_dialog = MediaErrorDialog(error_details, parent=self.video_widget) + error_dialog.exec() # Block UI thread + except ImportError as e: + print(f"Failed to import MediaErrorDialog: {e}") + + def _check_for_black_screen(self): + """ + The Ultimate Catch: Watchdog timer triggered. + """ + is_playing = self.player.playbackState() == QMediaPlayer.PlaybackState.PlayingState + is_loaded = self.player.mediaStatus() in [QMediaPlayer.MediaStatus.LoadedMedia, QMediaPlayer.MediaStatus.BufferedMedia] + + if is_playing and is_loaded and self.player.hasVideo() and not self._frame_received: + # Pass a concise technical reason instead of a long paragraph + self._trigger_error_dialog("Watchdog Timeout: The hardware video decoder crashed silently and failed to render any frames within 1.5 seconds.") + + def _handle_media_status(self, status: QMediaPlayer.MediaStatus): + """ + Catches silent failures from MediaStatus. 
+ """ + if status == QMediaPlayer.MediaStatus.InvalidMedia: + self._trigger_error_dialog("Status Error: Invalid Media or completely unsupported file format.") + + elif status == QMediaPlayer.MediaStatus.LoadedMedia: + if not self.player.hasVideo(): + self._trigger_error_dialog("Status Error: The file has no decodable video stream (e.g., missing AV1 hardware decoder).") + + def _handle_media_error(self, error: QMediaPlayer.Error, error_string: str): """ - Standardized sequence to load and play a video. + Catches standard MediaFoundation/AVFoundation load errors. """ - # 1. Force Stop & Cancel any pending play requests + if error != QMediaPlayer.Error.NoError: + print(f"[Media Error] Code: {error}, Message: {error_string}") + self._trigger_error_dialog(f"Player Error Code {error}: {error_string}") + + def load_and_play(self, file_path: str, auto_play: bool = True): self.stop() if not file_path: return - # 2. Load Source self.player.setSource(QUrl.fromLocalFile(file_path)) - # 3. Auto-play with safety delay if auto_play: self.play_timer.start() def _execute_play(self): - """Actual slot called by timer to start playback.""" + """Starts playback and launches the Watchdog.""" + self._frame_received = False # Reset the frame flag self.player.play() + self.watchdog_timer.start() # Unleash the watchdog def toggle_play_pause(self): - """Toggle between Play and Pause.""" if self.player.playbackState() == QMediaPlayer.PlaybackState.PlayingState: self.player.pause() else: + self._frame_received = False self.player.play() + self.watchdog_timer.start() def stop(self): - """ - Stops playback, clears source, cancels timers, and forces UI refresh. - """ - # A. Cancel pending auto-play if user clicked away quickly + """Stops playback and cancels all timers.""" if self.play_timer.isActive(): self.play_timer.stop() + if self.watchdog_timer.isActive(): + self.watchdog_timer.stop() - # B. Stop Player logic self.player.stop() self.player.setSource(QUrl()) - # C. 
[Visual Fix] Force the video widget to repaint/update - # This helps clear the "stuck frame" from the GPU buffer if self.video_widget: self.video_widget.update() self.video_widget.repaint() - def set_looping(self, enable: bool): - """Helper to set looping.""" if enable: self.player.setLoops(QMediaPlayer.Loops.Infinite) else: diff --git a/annotation_tool/main.py b/annotation_tool/main.py index cdabad4..4d7a0e9 100644 --- a/annotation_tool/main.py +++ b/annotation_tool/main.py @@ -1,8 +1,15 @@ +import os import sys +import multiprocessing + +os.environ["PYTORCH_JIT"] = "0" + from PyQt6.QtWidgets import QApplication from viewer import ActionClassifierApp if __name__ == '__main__': + multiprocessing.freeze_support() + app = QApplication(sys.argv) window = ActionClassifierApp() window.show() diff --git a/annotation_tool/models/app_state.py b/annotation_tool/models/app_state.py index 284e79a..4ad396e 100644 --- a/annotation_tool/models/app_state.py +++ b/annotation_tool/models/app_state.py @@ -8,8 +8,13 @@ class CmdType(Enum): # --- Classification commands --- ANNOTATION_CONFIRM = auto() # Persist a user-confirmed annotation to the model + BATCH_ANNOTATION_CONFIRM = auto() # [NEW] Persist a batch of annotations as a single action UI_CHANGE = auto() # Fine-grained UI toggle (radio/checkbox changes) + # [NEW] Smart Annotation commands for Undo/Redo + SMART_ANNOTATION_RUN = auto() + BATCH_SMART_ANNOTATION_RUN = auto() + # --- Shared schema commands (used by both modes) --- SCHEMA_ADD_CAT = auto() # Add a category/head SCHEMA_DEL_CAT = auto() # Delete a category/head @@ -49,6 +54,8 @@ def __init__(self): self.current_task_name = "Untitled Task" self.modalities = ["video"] + self.is_multi_view = False + # --- Schema / labels --- # Format: { head_name: { "type": "single|multi", "labels": [..] 
} } self.label_definitions = {} @@ -57,6 +64,10 @@ def __init__(self): # Format: { video_path: { "Head": "Label", "Head2": ["L1", "L2"] } } self.manual_annotations = {} + # [NEW] Store AI inference results to persist the Donut Chart state + # Format: { video_path: { "action": { "label": "Dive", "conf_dict": {...} } } } + self.smart_annotations = {} + # Classification import metadata (kept for backward compatibility) self.imported_input_metadata = {} # key: (action_id, filename) self.imported_action_metadata = {} # key: action_id @@ -65,6 +76,10 @@ def __init__(self): # Format: { video_path: [ { "head": ..., "label": ..., "position_ms": ... }, ... ] } self.localization_events = {} + + # localization-smart annotation + self.smart_localization_events = {} + # --- Common clip list --- # Each item: { "name": "...", "path": "...", "source_files": [...] } # This is the shared source of truth for the Project Tree @@ -86,8 +101,13 @@ def reset(self, full_reset: bool = False): self.json_loaded = False self.is_data_dirty = False + self.is_multi_view = False + self.manual_annotations = {} + # [NEW] Clear smart annotations on reset + self.smart_annotations = {} self.localization_events = {} + self.smart_localization_events = {} self.imported_input_metadata = {} self.imported_action_metadata = {} @@ -792,4 +812,4 @@ def _fmt(title, lst): if warn_duplicates: warnings.append(_fmt("Duplicate dense captions found", warn_duplicates)) - return True, "", "\n\n".join(warnings) \ No newline at end of file + return True, "", "\n\n".join(warnings) diff --git a/annotation_tool/requirements.txt b/annotation_tool/requirements.txt index 69352a3..dd51e6e 100644 --- a/annotation_tool/requirements.txt +++ b/annotation_tool/requirements.txt @@ -1,2 +1,5 @@ PyQt6 -pyinstaller \ No newline at end of file +pyinstaller +torch-geometric==2.7.0 +soccernetpro==0.0.1.dev11 +wandb diff --git a/annotation_tool/style/style.qss b/annotation_tool/style/style.qss index 7eb18dd..955854c 100644 --- 
a/annotation_tool/style/style.qss +++ b/annotation_tool/style/style.qss @@ -184,7 +184,7 @@ QSlider::sub-page:horizontal { QPushButton[class="project_control_btn"] { border-radius: 6px; padding: 5px; - background-color: #444; + background-color: #5a5a5a; color: #EEE; border: 1px solid #555; font-weight: bold; @@ -270,7 +270,7 @@ QPushButton[class="welcome_secondary_btn"] { font-size: 14px; font-weight: bold; background-color: transparent; - border: 1px solid #84ff00; + border: 1px solid #84ff00; color: #84ff00; border-radius: 6px; } @@ -348,24 +348,40 @@ QPushButton[class="editor_save_btn"]:hover { /* Target: ui/classification/event_editor/dynamic_widgets.py */ /* Common Header Style */ -QLabel[class="group_head_lbl"] { +QLabel.group_head_lbl { font-weight: bold; - font-size: 13px; + font-size: 20px; + padding-top: 4px; + padding-bottom: 2px; + border-bottom: 1px solid #444; + margin-bottom: 4px; +} + +QLabel.group_head_single { + color: #00BFFF; } -/* Specific Colors for Single vs Multi Label Headers */ -QLabel[class="group_head_single"] { - color: #00BFFF; /* Cyan */ +QLabel.group_head_multi { + color: #00BFFF; +} + +/* --- 2. Label Items (Challenge, Dive, etc.) 
--- */ +QRadioButton.label_item, +QCheckBox.label_item { + font-size: 14px; + color: #DDD; + padding: 2px; } -QLabel[class="group_head_multi"] { - color: #32CD32; /* Lime Green */ +QRadioButton.label_item:hover, +QCheckBox.label_item:hover { + color: #FFF; } /* Small 'X' Remove Buttons (Replaces utils.get_square_remove_btn_style) */ QPushButton[class="icon_remove_btn"] { background-color: transparent; - color: #888; + color: #adadad; border: none; font-weight: bold; font-size: 16px; @@ -388,6 +404,28 @@ QLabel[class="player_time_lbl"] { } +/* --- Smart Annotation Tabs Styling --- */ +QTabBar::tab { + background: #2c2c2d; + color: #969696; + padding: 8px 15px; + border: 1px solid #1E1E1E; + border-bottom: none; + border-top-left-radius: 4px; + border-top-right-radius: 4px; +} + +QTabBar::tab:selected { + background: #585757; + color: #00BFFF; + border-bottom: 2px solid #00BFFF; + font-weight: bold; +} + +QTabBar::tab:hover:!selected { + background: #3E3E3E; + color: #DCDCDC; +} /* ======================================================= Localization Mode Styles diff --git a/annotation_tool/ui/.DS_Store b/annotation_tool/ui/.DS_Store deleted file mode 100644 index 738e20d..0000000 Binary files a/annotation_tool/ui/.DS_Store and /dev/null differ diff --git a/annotation_tool/ui/classification/.DS_Store b/annotation_tool/ui/classification/.DS_Store deleted file mode 100644 index ffe21a5..0000000 Binary files a/annotation_tool/ui/classification/.DS_Store and /dev/null differ diff --git a/annotation_tool/ui/classification/event_editor/dynamic_widgets.py b/annotation_tool/ui/classification/event_editor/dynamic_widgets.py index 32a8561..cd4db1a 100644 --- a/annotation_tool/ui/classification/event_editor/dynamic_widgets.py +++ b/annotation_tool/ui/classification/event_editor/dynamic_widgets.py @@ -15,7 +15,8 @@ def __init__(self, head_name, definition, parent=None): self.definition = definition self.layout = QVBoxLayout(self) - self.layout.setContentsMargins(0, 5, 0, 15) + 
self.layout.setContentsMargins(0, 0, 0, 2) + self.layout.setSpacing(2) # Header header_layout = QHBoxLayout() @@ -37,7 +38,7 @@ def __init__(self, head_name, definition, parent=None): self.radio_group.setExclusive(True) self.radio_container = QWidget() self.radio_layout = QVBoxLayout(self.radio_container) - self.radio_layout.setContentsMargins(10, 0, 0, 0) + self.radio_layout.setContentsMargins(5, 0, 0, 0) self.layout.addWidget(self.radio_container) # Input for new label @@ -68,10 +69,11 @@ def update_radios(self, labels): for i, lbl_text in enumerate(labels): row_widget = QWidget() row_layout = QHBoxLayout(row_widget) - row_layout.setContentsMargins(0, 2, 0, 2) + row_layout.setContentsMargins(0, 0, 0, 0) rb = QRadioButton(lbl_text) self.radio_group.addButton(rb, i) + rb.setProperty("class", "label_item") del_label_btn = QPushButton("ร—") del_label_btn.setCursor(Qt.CursorShape.PointingHandCursor) @@ -113,11 +115,11 @@ def __init__(self, head_name, definition, parent=None): self.definition = definition self.layout = QVBoxLayout(self) - self.layout.setContentsMargins(0, 5, 0, 15) + self.layout.setContentsMargins(0, 2, 0, 5) # Header header_layout = QHBoxLayout() - self.lbl_head = QLabel(head_name + " (Multi)") + self.lbl_head = QLabel(head_name) self.lbl_head.setProperty("class", "group_head_lbl group_head_multi") self.btn_del_cat = QPushButton("ร—") @@ -131,7 +133,7 @@ def __init__(self, head_name, definition, parent=None): self.checkbox_container = QWidget() self.checkbox_layout = QVBoxLayout(self.checkbox_container) - self.checkbox_layout.setContentsMargins(10, 0, 0, 0) + self.checkbox_layout.setContentsMargins(5, 0, 0, 0) self.layout.addWidget(self.checkbox_container) # Input @@ -158,11 +160,12 @@ def update_checkboxes(self, new_types): for type_name in sorted(list(set(new_types))): row_widget = QWidget() row_layout = QHBoxLayout(row_widget) - row_layout.setContentsMargins(0, 2, 0, 2) + row_layout.setContentsMargins(0, 0, 0, 0) cb = QCheckBox(type_name) 
cb.clicked.connect(self._on_box_clicked) self.checkboxes[type_name] = cb + cb.setProperty("class", "label_item") del_label_btn = QPushButton("ร—") del_label_btn.setCursor(Qt.CursorShape.PointingHandCursor) @@ -185,4 +188,4 @@ def set_checked_labels(self, label_list): if not label_list: label_list = [] for text, cb in self.checkboxes.items(): cb.setChecked(text in label_list) - self.blockSignals(False) + self.blockSignals(False) \ No newline at end of file diff --git a/annotation_tool/ui/classification/event_editor/editor.py b/annotation_tool/ui/classification/event_editor/editor.py index 02c3c5d..08350fe 100644 --- a/annotation_tool/ui/classification/event_editor/editor.py +++ b/annotation_tool/ui/classification/event_editor/editor.py @@ -1,26 +1,157 @@ +import math from PyQt6.QtWidgets import ( QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, - QGroupBox, QLineEdit, QScrollArea, QFrame + QGroupBox, QLineEdit, QScrollArea, QFrame, QProgressBar, QToolTip, QTextEdit, QTabWidget, QComboBox ) -from PyQt6.QtCore import Qt, pyqtSignal +from PyQt6.QtCore import Qt, pyqtSignal, QRectF, QPointF +from PyQt6.QtGui import QPainter, QColor, QPen, QFont, QCursor +import sys from .dynamic_widgets import DynamicSingleLabelGroup, DynamicMultiLabelGroup +class NativeDonutChart(QWidget): + def __init__(self, parent=None): + super().__init__(parent) + self.setMinimumSize(160, 160) + self.setMouseTracking(True) + + self.data_dict = {} + self.top_label = "" + self.slices_info = [] + self.setVisible(False) + + def update_chart(self, top_label, conf_dict): + self.top_label = top_label + + sorted_data = {top_label: conf_dict.get(top_label, 0.0)} + for k, v in conf_dict.items(): + if k != top_label: + sorted_data[k] = v + + self.data_dict = sorted_data + self.repaint() + self.setVisible(True) + + def paintEvent(self, event): + if not self.data_dict: + return + + painter = QPainter(self) + painter.setRenderHint(QPainter.RenderHint.Antialiasing) + + margin = 30 + rect = QRectF(margin, 
margin, self.width() - margin * 2, self.height() - margin * 2) + pen_width = 35 + + start_angle_qt = 90 * 16 + self.slices_info.clear() + + color_top = QColor("#4CAF50") + colors_other = [QColor("#607D8B"), QColor("#78909C"), QColor("#546E7A"), QColor("#455A64")] + color_idx = 0 + + current_angle_deg = 0.0 + + for label, prob in self.data_dict.items(): + span_deg = prob * 360 + span_angle_qt = int(round(-span_deg * 16)) + + if span_angle_qt == 0: + continue + + color = color_top if label == self.top_label else colors_other[color_idx % len(colors_other)] + if label != self.top_label: + color_idx += 1 + + pen = QPen(color) + pen.setWidth(pen_width) + pen.setCapStyle(Qt.PenCapStyle.FlatCap) + painter.setPen(pen) + + painter.drawArc(rect, start_angle_qt, span_angle_qt) + + self.slices_info.append({ + "label": label, + "prob": prob, + "start_deg": current_angle_deg, + "end_deg": current_angle_deg + span_deg + }) + + start_angle_qt += span_angle_qt + current_angle_deg += span_deg + + painter.setPen(QColor("white")) + font = QFont("Arial", 12, QFont.Weight.Bold) + painter.setFont(font) + top_prob = self.data_dict.get(self.top_label, 0.0) + + text_rect = QRectF(0, 0, self.width(), self.height()) + painter.drawText(text_rect, Qt.AlignmentFlag.AlignCenter, f"{self.top_label}\n{top_prob*100:.1f}%") + + def mouseMoveEvent(self, event): + if not self.data_dict: + return + + pos = event.position() + center_x = self.width() / 2 + center_y = self.height() / 2 + dx = pos.x() - center_x + dy = pos.y() - center_y + + distance = math.sqrt(dx**2 + dy**2) + radius = (self.width() - 60) / 2 + pen_width = 35 + + if distance < (radius - pen_width/2) or distance > (radius + pen_width/2): + QToolTip.hideText() + self.setCursor(Qt.CursorShape.ArrowCursor) + return + + angle_rad = math.atan2(dy, dx) + angle_deg = math.degrees(angle_rad) + 90 + if angle_deg < 0: + angle_deg += 360 + + hovered_text = None + for slice_info in self.slices_info: + if slice_info["start_deg"] <= angle_deg <= 
slice_info["end_deg"]: + hovered_text = f"{slice_info['label']}: {slice_info['prob']*100:.1f}%" + break + + if hovered_text: + self.setCursor(Qt.CursorShape.PointingHandCursor) + QToolTip.showText(event.globalPosition().toPoint(), hovered_text, self) + else: + self.setCursor(Qt.CursorShape.ArrowCursor) + QToolTip.hideText() + + class ClassificationEventEditor(QWidget): - """ - Right Panel for Classification Mode. - Renamed from ClassRightPanel to ClassificationEventEditor for consistency with folder name. - """ - add_head_clicked = pyqtSignal(str) remove_head_clicked = pyqtSignal(str) style_mode_changed = pyqtSignal(str) + + smart_infer_requested = pyqtSignal() + confirm_infer_requested = pyqtSignal(dict) + + batch_confirm_requested = pyqtSignal(dict) + + annotation_saved = pyqtSignal(dict) + smart_confirm_requested = pyqtSignal() # [NEW] Signal emitted when confirming from the Smart Tab + batch_run_requested = pyqtSignal(int, int) + + # [NEW] Signals for tab-aware clearing + hand_clear_requested = pyqtSignal() + smart_clear_requested = pyqtSignal() def __init__(self, parent=None): super().__init__(parent) - self.setFixedWidth(350) + self.setFixedWidth(320) layout = QVBoxLayout(self) + self.is_batch_mode_active = False + self.pending_batch_results = {} + # 1. Undo/Redo Controls h_undo = QHBoxLayout() self.undo_btn = QPushButton("Undo") @@ -50,8 +181,39 @@ def __init__(self, parent=None): schema_layout.addWidget(self.add_head_btn) layout.addWidget(schema_box) - # 4. Dynamic Annotation Area - self.manual_box = QGroupBox("Annotations") + # [NEW] Create QTabWidget to hold both annotation modes + self.tabs = QTabWidget() + + # 1. ใ€ๆ ธๅฟƒใ€‘ๅฝปๅบ•็ฆ็”จ็œ็•ฅๆจกๅผ๏ผŒ้˜ฒๆญขๆ–‡ๅญ—ๅ˜ๆˆ "..." + self.tabs.setElideMode(Qt.TextElideMode.ElideNone) + + # 2. ๅผบๅˆถๆ ‡็ญพๆ ไธ่‡ชๅŠจๆ‰ฉๅฑ•๏ผŒไฝฟๅ…ถไป…ๅ ๆฎๆ–‡ๅญ—ๆ‰€้œ€็š„็ฉบ้—ด + self.tabs.tabBar().setExpanding(False) + + # 3. 
ไผ˜ๅŒ–ๆ ทๅผ่กจ๏ผš็งป้™ค min-width ้™ๅˆถ๏ผŒๅนถ่ฎพ็ฝฎๆž็ช„ Padding + self.tabs.setStyleSheet(""" + QTabBar::tab { + /* ่ฎพ็ฝฎ่พƒๅฐ็š„ๅทฆๅณ่พน่ท๏ผŒ็กฎไฟๆ–‡ๅญ—็ดงๅ‡‘ไธ”ๅฏ่ง */ + padding-left: 3px; + padding-right: 3px; + padding-top: 5px; + padding-bottom: 5px; + + /* ไฟๆŒๅญ—ไฝ“ๅคงๅฐ้€‚ไธญ */ + font-size: 13px; + + /* ็กฎไฟๆฒกๆœ‰ๆœ€ๅฐๅฎฝๅบฆๅ’Œๆœ€ๅคงๅฎฝๅบฆ็š„็กฌๆ€ง้™ๅˆถ */ + min-width: 0px; + max-width: 1000px; + } + """) + + self.tabs.setObjectName("annotation_tabs") + layout.addWidget(self.tabs, 1) # Add tabs to main layout with stretch factor 1 + + # --- 4. Hand Annotation Tab --- + # Changed from QGroupBox to QWidget to fit seamlessly inside the Tab + self.manual_box = QWidget() self.manual_box.setEnabled(False) manual_layout = QVBoxLayout(self.manual_box) @@ -66,20 +228,345 @@ def __init__(self, parent=None): scroll.setWidget(self.label_container) manual_layout.addWidget(scroll) + # Add the manual widget as the first tab + self.tabs.addTab(self.manual_box, "Hand Annotation") + + # --- 5. 
Smart Annotation Tab --- + # Changed from QGroupBox to QWidget to fit seamlessly inside the Tab + self.smart_box = QWidget() + smart_layout = QVBoxLayout(self.smart_box) + + # [NEW] Force all items in the smart tab to align to the top + # This prevents the inference buttons from jumping around + smart_layout.setAlignment(Qt.AlignmentFlag.AlignTop) + + # Two Buttons for Inference + btn_h_layout = QHBoxLayout() + self.btn_smart_infer = QPushButton("Single Inference") + self.btn_smart_infer.setCursor(Qt.CursorShape.PointingHandCursor) + self.btn_smart_infer.clicked.connect(self.smart_infer_requested.emit) + + self.btn_batch_infer = QPushButton("Batch Inference") + self.btn_batch_infer.setCursor(Qt.CursorShape.PointingHandCursor) + self.btn_batch_infer.clicked.connect(lambda: self.batch_input_widget.setVisible(not self.batch_input_widget.isVisible())) + + btn_h_layout.addWidget(self.btn_smart_infer) + btn_h_layout.addWidget(self.btn_batch_infer) + smart_layout.addLayout(btn_h_layout) + + # Input Box for Batch Inference + self.batch_input_widget = QWidget() + h_batch = QHBoxLayout(self.batch_input_widget) + h_batch.setContentsMargins(0, 5, 0, 5) + # [NEW] Add descriptive labels for the Start and End comboboxes + self.lbl_start = QLabel("Start:") + self.spin_start = QComboBox() + + self.lbl_end = QLabel("End:") + self.spin_end = QComboBox() + + self.btn_run_batch = QPushButton("Run") + self.btn_run_batch.setCursor(Qt.CursorShape.PointingHandCursor) + self.btn_run_batch.clicked.connect(self._on_run_batch_clicked) + + # [MODIFIED] Add the labels and comboboxes to the horizontal layout in order + h_batch.addWidget(self.lbl_start) + h_batch.addWidget(self.spin_start) + h_batch.addWidget(self.lbl_end) + h_batch.addWidget(self.spin_end) + h_batch.addWidget(self.btn_run_batch) + + self.batch_input_widget.setVisible(False) + + + # [NEW] Connect validation signals to enforce i <= j rule + self.spin_start.currentIndexChanged.connect(self._validate_batch_range) + 
#self.spin_end.currentIndexChanged.connect(self._validate_batch_range) + + self.infer_progress = QProgressBar() + self.infer_progress.setRange(0, 0) + self.infer_progress.setVisible(False) + + self.chart_widget = NativeDonutChart() + + self.batch_result_text = QTextEdit() + self.batch_result_text.setReadOnly(True) + self.batch_result_text.setVisible(False) + self.batch_result_text.setMinimumHeight(120) + + smart_layout.addWidget(self.batch_input_widget) + smart_layout.addWidget(self.infer_progress) + smart_layout.addWidget(self.chart_widget, alignment=Qt.AlignmentFlag.AlignCenter) + smart_layout.addWidget(self.batch_result_text) + + # Add the smart widget as the second tab + self.tabs.addTab(self.smart_box, "Smart Annotation") + + # --- 7. Train Tab [RE-DESIGNED] --- + self.train_box = QWidget() + train_main_layout = QVBoxLayout(self.train_box) + train_main_layout.setContentsMargins(5, 5, 5, 5) + train_main_layout.setSpacing(10) + + # ไฝฟ็”จๆปšๅŠจๅŒบๅŸŸ๏ผŒ้˜ฒๆญขๅ‚ๆ•ฐ่ฟ‡ๅคšๆ—ถๆ˜พ็คบไธๅ…จ + train_scroll = QScrollArea() + train_scroll.setWidgetResizable(True) + train_scroll.setFrameShape(QFrame.Shape.NoFrame) + train_scroll_content = QWidget() + train_layout = QVBoxLayout(train_scroll_content) + train_layout.setAlignment(Qt.AlignmentFlag.AlignTop) + + # A. 
่ฎญ็ปƒ่ถ…ๅ‚ๆ•ฐ็ป„ (Hyperparameters) + hyper_group = QGroupBox("Hyperparameters") + hyper_form = QVBoxLayout(hyper_group) # ไฝฟ็”จๅž‚็›ดๅธƒๅฑ€ๅŒ…่ฃ…่กจๅ•่กŒ + + # ๅฐ่ฃ…ไธ€ไธช็ฎ€ๅ•็š„่กจๅ•่กŒๅ‡ฝๆ•ฐ + def add_form_row(label_text, widget): + row = QHBoxLayout() + lbl = QLabel(label_text) + lbl.setFixedWidth(80) + row.addWidget(lbl) + row.addWidget(widget) + return row + + self.spin_epochs = QComboBox() + self.spin_epochs.addItems(["1", "5", "10", "20", "50", "100"]) + self.spin_epochs.setEditable(True) + hyper_form.addLayout(add_form_row("Epochs:", self.spin_epochs)) + + self.edit_lr = QLineEdit("0.0001") + hyper_form.addLayout(add_form_row("LR:", self.edit_lr)) + + self.spin_batch = QComboBox() + self.spin_batch.addItems(["1", "2", "4", "8", "16"]) + self.spin_batch.setEditable(True) + hyper_form.addLayout(add_form_row("Batch:", self.spin_batch)) + + train_layout.addWidget(hyper_group) + + # B. ็กฌไปถ่ฎพ็ฝฎ็ป„ (Hardware - ้’ˆๅฏน Mac M1 ไผ˜ๅŒ–) + device_group = QGroupBox("Execution") + device_form = QVBoxLayout(device_group) + + self.combo_device = QComboBox() + # ้’ˆๅฏน M1 ๅขžๅŠ  mps ้€‰้กน + self.combo_device.addItems(["cpu", "mps (Metal)", "cuda"]) + if sys.platform == "darwin": + self.combo_device.setCurrentText("mps (Metal)") + device_form.addLayout(add_form_row("Device:", self.combo_device)) + + self.spin_workers = QComboBox() + self.spin_workers.addItems(["0", "2", "4"]) + device_form.addLayout(add_form_row("Workers:", self.spin_workers)) + + train_layout.addWidget(device_group) + + # C. ่ฎญ็ปƒๆ“ไฝœไธŽ็›‘ๆŽง (Action & Monitor) + h_train_btns = QHBoxLayout() # ๅˆ›ๅปบๆจชๅ‘ๅธƒๅฑ€ + + # 1. 
Start Training ๆŒ‰้’ฎ + self.btn_start_train = QPushButton("Start Training") + self.btn_start_train.setMinimumHeight(40) + self.btn_start_train.setCursor(Qt.CursorShape.PointingHandCursor) + self.btn_start_train.setStyleSheet(""" + QPushButton { + background-color: #007bff; + color: white; + font-weight: bold; + border-radius: 4px; + } + QPushButton:hover { background-color: #0069d9; } + QPushButton:disabled { background-color: #cccccc; color: #666666; } + """) + + # 2. Stop Training ๆŒ‰้’ฎ [NEW] + self.btn_stop_train = QPushButton("Stop Training") + self.btn_stop_train.setMinimumHeight(40) + self.btn_stop_train.setCursor(Qt.CursorShape.PointingHandCursor) + self.btn_stop_train.setEnabled(False) # ๅˆๅง‹ไธๅฏ็‚นๅ‡ป + # ๆ ทๅผไธŽ Clear Selection ไธ€่‡ด๏ผˆๆ ‡ๅ‡†ๆŒ‰้’ฎๆ ทๅผ๏ผ‰ + self.btn_stop_train.setProperty("class", "editor_control_btn") + + h_train_btns.addWidget(self.btn_start_train, 2) # Start ๅ ๆ›ดๅคš็ฉบ้—ด + h_train_btns.addWidget(self.btn_stop_train, 1) + + # ๅŽ้ข่ทŸ็€็Šถๆ€ๆ ‡็ญพๅ’Œ่ฟ›ๅบฆๆก + self.lbl_train_status = QLabel("Ready to train") + + + self.lbl_train_status = QLabel("Ready to train") + self.lbl_train_status.setStyleSheet("color: #4A90E2; font-weight: bold; margin-top: 5px;") + self.lbl_train_status.setVisible(False) + + + self.train_progress = QProgressBar() + self.train_progress.setRange(0, 100) + self.train_progress.setValue(0) + self.train_progress.setVisible(False) + + self.train_console = QTextEdit() + self.train_console.setReadOnly(True) + self.train_console.setPlaceholderText("Training logs will appear here...") + self.train_console.setMinimumHeight(150) + self.train_console.setStyleSheet("background-color: #1e1e1e; color: #d4d4d4; font-family: 'Courier New'; font-size: 11px;") + + train_layout.addLayout(h_train_btns) + train_layout.addWidget(self.lbl_train_status) + train_layout.addWidget(self.train_progress) + train_layout.addWidget(self.train_console) + + train_scroll.setWidget(train_scroll_content) + 
train_main_layout.addWidget(train_scroll) + + self.tabs.addTab(self.train_box, "Train") + + # --- 6. Bottom Confirm Buttons (Fixed Outside Tabs) --- btn_row = QHBoxLayout() - self.confirm_btn = QPushButton("Save Annotation") + self.confirm_btn = QPushButton("Confirm Annotation") self.clear_sel_btn = QPushButton("Clear Selection") self.confirm_btn.setProperty("class", "editor_save_btn") self.confirm_btn.setCursor(Qt.CursorShape.PointingHandCursor) self.clear_sel_btn.setCursor(Qt.CursorShape.PointingHandCursor) + self.confirm_btn.clicked.connect(self.on_confirm_clicked) + # [NEW] Route the clear button internally + self.clear_sel_btn.clicked.connect(self.on_clear_clicked) + btn_row.addWidget(self.confirm_btn) btn_row.addWidget(self.clear_sel_btn) - manual_layout.addLayout(btn_row) + layout.addLayout(btn_row) # Add strictly to the main vertical layout, remaining at the bottom + + self.label_groups = {} + + def _on_run_batch_clicked(self): + try: + start_idx = int(self.spin_start.text().strip()) + end_idx = int(self.spin_end.text().strip()) + self.batch_run_requested.emit(start_idx, end_idx) + except ValueError: + pass + + + def on_confirm_clicked(self): + """[MODIFIED] Route confirm action based on the active tab.""" + active_tab_idx = self.tabs.currentIndex() + + if active_tab_idx == 0: + # --- Hand Annotation Confirmation --- + data = {} + for head, group in self.label_groups.items(): + if hasattr(group, 'get_checked_label'): + val = group.get_checked_label() + if val: data[head] = val + elif hasattr(group, 'get_checked_labels'): + val = group.get_checked_labels() + if val: data[head] = val + self.annotation_saved.emit(data) + + elif active_tab_idx == 1: + # --- Smart Annotation Confirmation --- + self.smart_confirm_requested.emit() + + def on_clear_clicked(self): + """[NEW] Route clear action based on the active tab.""" + active_tab_idx = self.tabs.currentIndex() + if active_tab_idx == 0: + self.hand_clear_requested.emit() + elif active_tab_idx == 1: + 
self.smart_clear_requested.emit() + + # [MODIFIED] Hide the batch input box upon confirmation or action switch + def reset_smart_inference(self): + self.is_batch_mode_active = False + self.chart_widget.setVisible(False) + self.batch_result_text.setVisible(False) + self.btn_smart_infer.setEnabled(True) + self.btn_batch_infer.setEnabled(True) + self.infer_progress.setVisible(False) + + # Ensures Run Batch dropdowns disappear after Confirm or switching videos + self.batch_input_widget.setVisible(False) + + def reset_train_ui(self): + self.train_progress.setValue(0) + self.train_progress.setVisible(False) + + self.lbl_train_status.setText("Ready to train") + self.lbl_train_status.setVisible(False) + + self.train_console.clear() + + self.btn_start_train.setEnabled(True) + + # [MODIFIED] Save the full list and initialize the dropdowns + def update_action_list(self, action_names: list): + self.full_action_names = action_names - layout.addWidget(self.manual_box, 1) + self.spin_start.blockSignals(True) + self.spin_end.blockSignals(True) - self.label_groups = {} + self.spin_start.clear() + self.spin_end.clear() + + self.spin_start.addItems(self.full_action_names) + self.spin_end.addItems(self.full_action_names) + + self.spin_start.blockSignals(False) + self.spin_end.blockSignals(False) + + # [MODIFIED] Dynamically update the second dropdown to only show items from index i onwards + def _validate_batch_range(self): + start_idx = self.spin_start.currentIndex() + if start_idx < 0: return + + current_end_text = self.spin_end.currentText() + + self.spin_end.blockSignals(True) + self.spin_end.clear() + + # Only add items starting from the selected 'start_idx' + valid_end_items = self.full_action_names[start_idx:] + self.spin_end.addItems(valid_end_items) + + # Attempt to restore the previous selection if it's still in the valid range + if current_end_text in valid_end_items: + self.spin_end.setCurrentText(current_end_text) + else: + self.spin_end.setCurrentIndex(0) + + 
self.spin_end.blockSignals(False) + + # [MODIFIED] Calculate absolute end index based on dynamic relative index + def _on_run_batch_clicked(self): + start_idx = self.spin_start.currentIndex() + + # Since spin_end only contains items from start_idx onwards, + # its absolute index is its relative index + start_idx + end_idx = start_idx + self.spin_end.currentIndex() + + if start_idx >= 0 and end_idx >= start_idx: + self.batch_run_requested.emit(start_idx, end_idx) + + def show_inference_loading(self, is_loading: bool): + self.btn_smart_infer.setEnabled(not is_loading) + self.btn_batch_infer.setEnabled(not is_loading) + self.infer_progress.setVisible(is_loading) + if is_loading: + self.chart_widget.setVisible(False) + self.batch_result_text.setVisible(False) + + def display_inference_result(self, target_head: str, predicted_label: str, conf_dict: dict): + self.show_inference_loading(False) + self.is_batch_mode_active = False + self.chart_widget.update_chart(predicted_label, conf_dict) + + def display_batch_inference_result(self, result_text: str, batch_predictions: dict): + self.show_inference_loading(False) + self.is_batch_mode_active = True + self.pending_batch_results = batch_predictions + self.chart_widget.setVisible(False) + self.batch_result_text.setText(result_text) + self.batch_result_text.setVisible(True) def setup_dynamic_labels(self, label_definitions): while self.label_container_layout.count(): @@ -101,6 +588,8 @@ def setup_dynamic_labels(self, label_definitions): self.label_container_layout.addStretch() def set_annotation(self, data): + self.reset_smart_inference() + if not data: data = {} for head, group in self.label_groups.items(): val = data.get(head) @@ -121,6 +610,8 @@ def get_annotation(self): return result def clear_selection(self): + # [MODIFIED] Keep the Donut Chart visible even if the user clears hand annotations. 
+ # self.reset_smart_inference() for group in self.label_groups.values(): if hasattr(group, 'set_checked_label'): group.set_checked_label(None) diff --git a/annotation_tool/ui/common/.DS_Store b/annotation_tool/ui/common/.DS_Store deleted file mode 100644 index 5008ddf..0000000 Binary files a/annotation_tool/ui/common/.DS_Store and /dev/null differ diff --git a/annotation_tool/ui/common/clip_explorer.py b/annotation_tool/ui/common/clip_explorer.py index e57db6d..bcdf664 100644 --- a/annotation_tool/ui/common/clip_explorer.py +++ b/annotation_tool/ui/common/clip_explorer.py @@ -25,7 +25,7 @@ def __init__(self, enable_context_menu=True, parent=None): super().__init__(parent) - self.setFixedWidth(300) + self.setFixedWidth(250) # Main Layout layout = QVBoxLayout(self) diff --git a/annotation_tool/ui/common/dialogs.py b/annotation_tool/ui/common/dialogs.py index dbe70f6..b91964b 100644 --- a/annotation_tool/ui/common/dialogs.py +++ b/annotation_tool/ui/common/dialogs.py @@ -75,6 +75,48 @@ def finalize_selection(self, mode: str): self.selected_mode = mode self.accept() +class ClassificationTypeDialog(QDialog): + """ + [NEW] Dialog to ask the user if the new Classification project + is Single-View or Multi-View. 
+ """ + def __init__(self, parent=None) -> None: + super().__init__(parent) + self.setWindowTitle("Classification Project Type") + self.resize(450, 180) + self.is_multi_view = False # Default to Single-View + + layout = QVBoxLayout(self) + layout.setSpacing(15) + layout.setContentsMargins(30, 30, 30, 30) + + lbl = QLabel("Is this a Single-View or Multi-View project?") + lbl.setProperty("class", "dialog_instruction_lbl") + lbl.setAlignment(Qt.AlignmentFlag.AlignCenter) + layout.addWidget(lbl) + + btn_layout = QHBoxLayout() + btn_layout.setSpacing(20) + + self.btn_sv = QPushButton("Single-View\n(Individual Videos)") + self.btn_sv.setMinimumSize(QSize(0, 70)) + self.btn_sv.setCursor(Qt.CursorShape.PointingHandCursor) + + self.btn_mv = QPushButton("Multi-View\n(Grouped by Folder)") + self.btn_mv.setMinimumSize(QSize(0, 70)) + self.btn_mv.setCursor(Qt.CursorShape.PointingHandCursor) + + btn_layout.addWidget(self.btn_sv) + btn_layout.addWidget(self.btn_mv) + layout.addLayout(btn_layout) + + # Connect signals + self.btn_sv.clicked.connect(lambda: self.finalize_selection(False)) + self.btn_mv.clicked.connect(lambda: self.finalize_selection(True)) + + def finalize_selection(self, is_multi: bool): + self.is_multi_view = is_multi + self.accept() class FolderPickerDialog(QDialog): """ @@ -120,4 +162,35 @@ def get_selected_folders(self) -> list[str]: """Returns a list of absolute paths for the selected folders.""" indexes = self.tree.selectionModel().selectedRows() paths = [self.model.filePath(idx) for idx in indexes] - return paths \ No newline at end of file + return paths + +class MediaErrorDialog(QMessageBox): + """ + [NEW] A standardized error dialog for media playback failures. + Provides a concise explanation and an FFmpeg command to fix the codec issue. + Technical logs are hidden in the details section to keep the UI clean. 
+ """ + def __init__(self, error_string: str, parent=None) -> None: + super().__init__(parent) + + self.setIcon(QMessageBox.Icon.Critical) + + # Main short title + self.setWindowTitle("Video Decoding Error") + self.setText("Unsupported Video Codec Detected") + + # Concise explanation with the FFmpeg terminal command + info_text = ( + "Your system cannot decode this video's format (e.g., AV1, DivX, or Xvid). " + "The audio might play, but the video hardware decoder has failed.\n\n" + "To fix this, please transcode your file to a standard H.264 MP4 format. " + "Run the following command in your terminal:\n\n" + "ffmpeg -i input.mp4 -vcodec libx264 -acodec aac output.mp4" + ) + self.setInformativeText(info_text) + + # Hide the long, ugly technical error logs inside a collapsible "Show Details..." button + if error_string: + self.setDetailedText(f"System Diagnostic Logs:\n{error_string}") + + self.setStandardButtons(QMessageBox.StandardButton.Ok) \ No newline at end of file diff --git a/annotation_tool/ui/common/video_surface.py b/annotation_tool/ui/common/video_surface.py index 4c160c7..9235cc4 100644 --- a/annotation_tool/ui/common/video_surface.py +++ b/annotation_tool/ui/common/video_surface.py @@ -35,6 +35,8 @@ def __init__(self, parent=None): # 3. Add video widget to layout self.layout.addWidget(self.video_widget) + + def load_source(self, path): """ Loads the video source. 
diff --git a/annotation_tool/ui/common/welcome_widget.py b/annotation_tool/ui/common/welcome_widget.py index 93eab20..5805ebf 100644 --- a/annotation_tool/ui/common/welcome_widget.py +++ b/annotation_tool/ui/common/welcome_widget.py @@ -24,7 +24,7 @@ def __init__(self, parent=None): title_layout.setAlignment(Qt.AlignmentFlag.AlignHCenter) title_layout.setSpacing(15) - title = QLabel("SoccerNetPro Annotation Tool") + title = QLabel("Video Annotation Tool") title.setObjectName("welcome_title_lbl") self.logo_lbl = QLabel() @@ -70,7 +70,7 @@ def __init__(self, parent=None): self.tutorial_btn.setFixedSize(160, 40) self.tutorial_btn.setProperty("class", "welcome_secondary_btn") self.tutorial_btn.setCursor(Qt.CursorShape.PointingHandCursor) - self.tutorial_btn.clicked.connect(lambda: QDesktopServices.openUrl(QUrl("https://drive.google.com/file/d/1EgQXGMQya06vNMuX_7-OlAUjF_Je-ye_/view?usp=sharing"))) + self.tutorial_btn.clicked.connect(lambda: QDesktopServices.openUrl(QUrl("https://www.youtube.com/"))) self.github_btn = QPushButton("๐Ÿ™ GitHub Repo") self.github_btn.setFixedSize(160, 40) @@ -81,4 +81,4 @@ def __init__(self, parent=None): links_layout.addWidget(self.tutorial_btn) links_layout.addWidget(self.github_btn) - layout.addLayout(links_layout) + layout.addLayout(links_layout) \ No newline at end of file diff --git a/annotation_tool/ui/common/workspace.py b/annotation_tool/ui/common/workspace.py index 8e58fea..386d8f1 100644 --- a/annotation_tool/ui/common/workspace.py +++ b/annotation_tool/ui/common/workspace.py @@ -32,7 +32,7 @@ def __init__(self, # 1. Setup Layout layout = QHBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) - layout.setSpacing(5) + layout.setSpacing(2) # 2. 
Instantiate Left Panel (Common) # Default to Localization-style naming if not provided diff --git a/annotation_tool/ui/description/.DS_Store b/annotation_tool/ui/description/.DS_Store deleted file mode 100644 index 4dcda3f..0000000 Binary files a/annotation_tool/ui/description/.DS_Store and /dev/null differ diff --git a/annotation_tool/ui/description/event_editor/.DS_Store b/annotation_tool/ui/description/event_editor/.DS_Store deleted file mode 100644 index 5008ddf..0000000 Binary files a/annotation_tool/ui/description/event_editor/.DS_Store and /dev/null differ diff --git a/annotation_tool/ui/description/media_player/.DS_Store b/annotation_tool/ui/description/media_player/.DS_Store deleted file mode 100644 index 5008ddf..0000000 Binary files a/annotation_tool/ui/description/media_player/.DS_Store and /dev/null differ diff --git a/annotation_tool/ui/localization/.DS_Store b/annotation_tool/ui/localization/.DS_Store deleted file mode 100644 index 8a5b1d5..0000000 Binary files a/annotation_tool/ui/localization/.DS_Store and /dev/null differ diff --git a/annotation_tool/ui/localization/event_editor/.DS_Store b/annotation_tool/ui/localization/event_editor/.DS_Store deleted file mode 100644 index 5008ddf..0000000 Binary files a/annotation_tool/ui/localization/event_editor/.DS_Store and /dev/null differ diff --git a/annotation_tool/ui/localization/event_editor/__init__.py b/annotation_tool/ui/localization/event_editor/__init__.py index d19042f..4185890 100644 --- a/annotation_tool/ui/localization/event_editor/__init__.py +++ b/annotation_tool/ui/localization/event_editor/__init__.py @@ -1,24 +1,30 @@ +# __init__.py (in ui/localization/) + from PyQt6.QtWidgets import ( - QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel + QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QTabWidget ) -from PyQt6.QtCore import Qt +from PyQt6.QtCore import Qt, pyqtSignal -# Import the separated components from the same package from .spotting_controls import 
AnnotationManagementWidget from .annotation_table import AnnotationTableWidget +from .smart_spotting import SmartSpottingWidget -# --- [Assembled] Localization Right Panel --- class LocRightPanel(QWidget): """ Right Panel for Localization Mode. - Contains: Undo/Redo Buttons, Annotation Tabs (Top), and Events Table (Bottom). + Contains: Undo/Redo Buttons (Global), and a TabWidget separating + Hand Annotation and Smart Annotation interfaces. """ + # Signal emitted when the user switches between Hand and Smart tabs + # The Controller should catch this to swap the Timeline markers + tabSwitched = pyqtSignal(int) + def __init__(self, parent=None): super().__init__(parent) self.setFixedWidth(400) layout = QVBoxLayout(self) - # --- Undo/Redo Button Header --- + # --- 1. Global Undo/Redo Button Header --- header_layout = QHBoxLayout() header_layout.setContentsMargins(0, 0, 0, 5) @@ -28,7 +34,6 @@ def __init__(self, parent=None): self.undo_btn = QPushButton("Undo") self.redo_btn = QPushButton("Redo") - # Button Styling btn_style = """ QPushButton { background-color: #444; color: #DDD; @@ -49,15 +54,32 @@ def __init__(self, parent=None): header_layout.addStretch() header_layout.addWidget(self.undo_btn) header_layout.addWidget(self.redo_btn) - layout.addLayout(header_layout) - # ----------------------------------- - # 1. Top: Multi Head Management (Tabs) - self.annot_mgmt = AnnotationManagementWidget() + # --- 2. Main Tabs --- + self.tabs = QTabWidget() + self.tabs.setObjectName("localization_tabs") + layout.addWidget(self.tabs) + + # ========== TAB 0: Hand Annotation ========== + self.hand_widget = QWidget() + hand_layout = QVBoxLayout(self.hand_widget) + hand_layout.setContentsMargins(0, 5, 0, 0) - # 2. 
Bottom: Labelled Event List (Table) + # Top: Multi Head Management (Tabs for categories) + self.annot_mgmt = AnnotationManagementWidget() + # Bottom: Labelled Event List (Table for hand annotations) self.table = AnnotationTableWidget() - layout.addWidget(self.annot_mgmt, 3) - layout.addWidget(self.table, 2) \ No newline at end of file + hand_layout.addWidget(self.annot_mgmt, 2) + hand_layout.addWidget(self.table, 3) + + self.tabs.addTab(self.hand_widget, "Hand Annotation") + + # ========== TAB 1: Smart Annotation ========== + # Loads the newly created SmartSpottingWidget + self.smart_widget = SmartSpottingWidget() + self.tabs.addTab(self.smart_widget, "Smart Annotation") + + # Connect tab change signal + self.tabs.currentChanged.connect(self.tabSwitched.emit) \ No newline at end of file diff --git a/annotation_tool/ui/localization/event_editor/annotation_table.py b/annotation_tool/ui/localization/event_editor/annotation_table.py index 3f1e471..df345ab 100644 --- a/annotation_tool/ui/localization/event_editor/annotation_table.py +++ b/annotation_tool/ui/localization/event_editor/annotation_table.py @@ -1,6 +1,6 @@ from PyQt6.QtWidgets import ( QWidget, QVBoxLayout, QLabel, QTableView, QHeaderView, QMenu, - QAbstractItemView + QAbstractItemView, QPushButton ) from PyQt6.QtCore import pyqtSignal, Qt, QAbstractTableModel @@ -131,14 +131,27 @@ class AnnotationTableWidget(QWidget): annotationSelected = pyqtSignal(int) annotationModified = pyqtSignal(dict, dict) # old_event, new_event annotationDeleted = pyqtSignal(dict) + updateTimeForSelectedRequested = pyqtSignal(dict) def __init__(self, parent=None): super().__init__(parent) layout = QVBoxLayout(self) + + # [NEW] 1. 
Edit Annotation + self.edit_lbl = QLabel("Edit Annotation") + self.edit_lbl.setProperty("class", "panel_header_lbl") + layout.addWidget(self.edit_lbl) - lbl = QLabel("Events List") - lbl.setProperty("class", "panel_header_lbl") - layout.addWidget(lbl) + self.btn_set_time = QPushButton("Set to Current Video Time") + self.btn_set_time.setCursor(Qt.CursorShape.PointingHandCursor) + self.btn_set_time.setEnabled(False) + self.btn_set_time.clicked.connect(self._on_set_time_clicked) + layout.addWidget(self.btn_set_time) + + # 2. Events List + self.list_lbl = QLabel("Events List") + self.list_lbl.setProperty("class", "panel_header_lbl") + layout.addWidget(self.list_lbl) self.table = QTableView() self.table.setProperty("class", "annotation_table") @@ -166,13 +179,25 @@ def set_data(self, annotations): def set_schema(self, schema): self.current_schema = schema + def _on_selection_changed(self, selected, deselected): indexes = selected.indexes() if indexes: + self.btn_set_time.setEnabled(True) row = indexes[0].row() item = self.model.get_annotation_at(row) if item: self.annotationSelected.emit(item.get('position_ms', 0)) + else: + self.btn_set_time.setEnabled(False) + + def _on_set_time_clicked(self): + indexes = self.table.selectionModel().selectedRows() + if indexes: + row = indexes[0].row() + item = self.model.get_annotation_at(row) + if item: + self.updateTimeForSelectedRequested.emit(item) def _show_context_menu(self, pos): index = self.table.indexAt(pos) @@ -187,4 +212,4 @@ def _show_context_menu(self, pos): selected_action = menu.exec(self.table.mapToGlobal(pos)) if selected_action == act_delete: - self.annotationDeleted.emit(item) \ No newline at end of file + self.annotationDeleted.emit(item) diff --git a/annotation_tool/ui/localization/event_editor/smart_spotting.py b/annotation_tool/ui/localization/event_editor/smart_spotting.py new file mode 100644 index 0000000..e461f84 --- /dev/null +++ b/annotation_tool/ui/localization/event_editor/smart_spotting.py @@ -0,0 
+1,217 @@ +# smart_spotting.py + +from PyQt6.QtWidgets import ( + QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, + QGroupBox, QLineEdit +) +from PyQt6.QtCore import pyqtSignal, Qt + +# Reuse the existing table widget for displaying smart predictions +from .annotation_table import AnnotationTableWidget + +class TimeLineEdit(QLineEdit): + """ + Custom QLineEdit tailored for time input in the format MM:SS.mmm. + Supports free typing and using Up/Down arrow keys to increment/decrement time. + """ + timeChanged = pyqtSignal(int) # Emits the new time in milliseconds + + def __init__(self, parent=None): + super().__init__(parent) + self._ms = 0 + self.setText("00:00.000") + self.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.setStyleSheet("font-family: monospace; font-weight: bold; font-size: 13px; padding: 2px;") + self.setFixedWidth(100) + + # When user finishes typing and loses focus or hits Enter + self.editingFinished.connect(self._on_edit_finished) + + def set_time_ms(self, ms: int): + """Programmatically set the time in milliseconds.""" + self._ms = max(0, ms) + self.setText(self._fmt_ms(self._ms)) + self.timeChanged.emit(self._ms) + + def get_time_ms(self) -> int: + """Get the current time in milliseconds.""" + return self._ms + + def _fmt_ms(self, ms: int) -> str: + """Format milliseconds to MM:SS.mmm""" + s = ms // 1000 + m = s // 60 + return f"{m:02}:{s%60:02}.{ms%1000:03}" + + def _parse_time(self, text: str) -> int: + """Parse MM:SS.mmm string back to milliseconds.""" + try: + parts = text.split(':') + if len(parts) >= 2: + m = int(parts[0]) + s_parts = parts[1].split('.') + s = int(s_parts[0]) + ms = int(s_parts[1]) if len(s_parts) > 1 else 0 + return (m * 60 + s) * 1000 + ms + except Exception: + pass + return self._ms # Return the last valid time if parsing fails + + def _on_edit_finished(self): + """Validate and apply manually typed time.""" + parsed_ms = self._parse_time(self.text()) + self.set_time_ms(parsed_ms) + + def keyPressEvent(self, 
event): + """Intercept Up/Down arrows to adjust time dynamically.""" + if event.key() == Qt.Key.Key_Up: + self._adjust_time(1) + event.accept() + elif event.key() == Qt.Key.Key_Down: + self._adjust_time(-1) + event.accept() + else: + super().keyPressEvent(event) + + def _adjust_time(self, direction: int): + """Adjust time based on cursor position.""" + cursor = self.cursorPosition() + ms = self._ms + + # Cursor positions for MM:SS.mmm: + # <= 2: Minutes + # 3 to 5: Seconds + # >= 6: Milliseconds + if cursor <= 2: + ms += direction * 60000 # +/- 1 minute + elif cursor <= 5: + ms += direction * 1000 # +/- 1 second + else: + ms += direction * 100 # +/- 100 milliseconds for smoother scrolling + + self.set_time_ms(max(0, ms)) + self.setCursorPosition(cursor) # Restore cursor position so user can keep pressing + + +class SmartSpottingWidget(QWidget): + """ + UI for Smart Annotation in Localization mode. + Allows users to select a time range, run inference, and review predicted events + in a separate table before confirming them. + """ + # Signals to be connected to the LocalizationManager + setTimeRequested = pyqtSignal(str) # 'start' or 'end' + runInferenceRequested = pyqtSignal(int, int) # start_ms, end_ms + confirmSmartRequested = pyqtSignal() # Merge smart events to hand events + clearSmartRequested = pyqtSignal() # Clear current smart predictions + + def __init__(self, parent=None): + super().__init__(parent) + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + + # Internal state + self.start_ms = 0 + self.end_ms = 0 + + # --- 1. 
Time Range Selection Box --- + self.time_box = QGroupBox("Smart Inference Range") + self.time_box.setProperty("class", "smart_inference_box") + time_layout = QVBoxLayout(self.time_box) + + # Start Time Row + start_row = QHBoxLayout() + self.lbl_start = QLabel("Start Time:") + self.val_start = TimeLineEdit() + self.btn_set_start = QPushButton("Set to Current") + self.btn_set_start.setCursor(Qt.CursorShape.PointingHandCursor) + self.btn_set_start.clicked.connect(lambda: self.setTimeRequested.emit("start")) + self.val_start.timeChanged.connect(self._on_start_changed) + + start_row.addWidget(self.lbl_start) + start_row.addWidget(self.val_start) + start_row.addStretch() + start_row.addWidget(self.btn_set_start) + + # End Time Row + end_row = QHBoxLayout() + self.lbl_end = QLabel("End Time:") + self.val_end = TimeLineEdit() + self.btn_set_end = QPushButton("Set to Current") + self.btn_set_end.setCursor(Qt.CursorShape.PointingHandCursor) + self.btn_set_end.clicked.connect(lambda: self.setTimeRequested.emit("end")) + self.val_end.timeChanged.connect(self._on_end_changed) + + end_row.addWidget(self.lbl_end) + end_row.addWidget(self.val_end) + end_row.addStretch() + end_row.addWidget(self.btn_set_end) + + time_layout.addLayout(start_row) + time_layout.addLayout(end_row) + + # Run Button + self.btn_run_infer = QPushButton("Run Smart Inference") + self.btn_run_infer.setCursor(Qt.CursorShape.PointingHandCursor) + self.btn_run_infer.setProperty("class", "run_inference_btn") + self.btn_run_infer.clicked.connect(self._on_run_clicked) + + time_layout.addWidget(self.btn_run_infer) + layout.addWidget(self.time_box, 0) # 0 stretch means it stays at top + + # --- 2. 
Smart Events List (Separated from Hand Annotations) --- + self.smart_table = AnnotationTableWidget() + self.smart_table.edit_lbl.hide() + self.smart_table.btn_set_time.hide() + self.smart_table.list_lbl.setText("Predicted Events List") + + layout.addWidget(self.smart_table, 1) # 1 stretch means it fills remaining space + + # --- 3. Bottom Controls --- + bottom_row = QHBoxLayout() + self.btn_confirm = QPushButton("Confirm Predictions") + self.btn_confirm.setProperty("class", "editor_save_btn") + self.btn_confirm.setCursor(Qt.CursorShape.PointingHandCursor) + self.btn_confirm.clicked.connect(self.confirmSmartRequested.emit) + + self.btn_clear = QPushButton("Clear Predictions") + self.btn_clear.setCursor(Qt.CursorShape.PointingHandCursor) + self.btn_clear.clicked.connect(self.clearSmartRequested.emit) + + bottom_row.addWidget(self.btn_confirm) + bottom_row.addWidget(self.btn_clear) + layout.addLayout(bottom_row) + + # ==================== Logic & Validation ==================== + + def _on_start_changed(self, ms: int): + """Ensure Start Time does not exceed End Time""" + self.start_ms = ms + # If End Time is set (not 0) and Start > End, push End Time forward + if self.end_ms > 0 and self.start_ms > self.end_ms: + self.val_end.blockSignals(True) + self.val_end.set_time_ms(self.start_ms) + self.end_ms = self.start_ms + self.val_end.blockSignals(False) + + def _on_end_changed(self, ms: int): + """Ensure End Time does not drop below Start Time""" + self.end_ms = ms + # If End Time drops below Start Time, push Start Time backward + if self.end_ms > 0 and self.end_ms < self.start_ms: + self.val_start.blockSignals(True) + self.val_start.set_time_ms(self.end_ms) + self.start_ms = self.end_ms + self.val_start.blockSignals(False) + + def update_time_display(self, target: str, time_str: str, time_ms: int): + """Called by controller to update the UI with the player's current time""" + # We ignore the time_str since TimeLineEdit formats it internally + if target == "start": + 
self.val_start.set_time_ms(time_ms) + elif target == "end": + self.val_end.set_time_ms(time_ms) + + def _on_run_clicked(self): + """Emit the run signal with validated boundaries""" + self.runInferenceRequested.emit(self.start_ms, self.end_ms) \ No newline at end of file diff --git a/annotation_tool/ui/localization/event_editor/spotting_controls.py b/annotation_tool/ui/localization/event_editor/spotting_controls.py index 380ab99..be1a761 100644 --- a/annotation_tool/ui/localization/event_editor/spotting_controls.py +++ b/annotation_tool/ui/localization/event_editor/spotting_controls.py @@ -5,7 +5,6 @@ from PyQt6.QtCore import pyqtSignal, Qt # ==================== Custom Widgets ==================== - class LabelButton(QPushButton): """ Custom Label Button that supports Right-Click signal. @@ -16,9 +15,10 @@ class LabelButton(QPushButton): def __init__(self, text, parent=None): super().__init__(text, parent) - self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred) + self.setSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Fixed) self.setCursor(Qt.CursorShape.PointingHandCursor) - self.setMinimumHeight(40) + self.setMinimumHeight(28) + self.setStyleSheet("padding: 2px 10px;") self.setProperty("class", "spotting_label_btn") def mousePressEvent(self, event): @@ -50,29 +50,22 @@ def __init__(self, head_name, labels, parent=None): self.labels = labels layout = QVBoxLayout(self) - layout.setContentsMargins(5, 5, 5, 5) - layout.setSpacing(10) + layout.setContentsMargins(2, 2, 2, 2) + layout.setSpacing(5) # Time display self.time_label = QLabel("Current Time: 00:00.000") self.time_label.setProperty("class", "spotting_time_lbl") self.time_label.setAlignment(Qt.AlignmentFlag.AlignCenter) layout.addWidget(self.time_label) - - # Scroll area for buttons - scroll = QScrollArea() - scroll.setWidgetResizable(True) - scroll.setFrameShape(QScrollArea.Shape.NoFrame) - scroll.setProperty("class", "spotting_scroll_area") - self.grid_container = QWidget() 
- self.grid_layout = QGridLayout(self.grid_container) - self.grid_layout.setSpacing(8) - self.grid_layout.setContentsMargins(0,0,0,0) - self.grid_layout.setAlignment(Qt.AlignmentFlag.AlignTop) + # Scroll area for buttons + self.scroll = QScrollArea() + self.scroll.setWidgetResizable(True) + self.scroll.setFrameShape(QScrollArea.Shape.NoFrame) + self.scroll.setProperty("class", "spotting_scroll_area") - scroll.setWidget(self.grid_container) - layout.addWidget(scroll) + layout.addWidget(self.scroll) self._populate_grid() @@ -84,39 +77,96 @@ def refresh_labels(self, new_labels): self._populate_grid() def _populate_grid(self): - # Clear existing items - while self.grid_layout.count(): - item = self.grid_layout.takeAt(0) - if item.widget(): - item.widget().deleteLater() - - cols = 2 - row, col = 0, 0 + old_widget = self.scroll.takeWidget() + if old_widget: + old_widget.deleteLater() + + self.grid_container = QWidget() + self.grid_layout = QVBoxLayout(self.grid_container) + self.grid_layout.setSpacing(6) + self.grid_layout.setContentsMargins(0, 0, 0, 0) + self.grid_layout.setAlignment(Qt.AlignmentFlag.AlignTop) + + max_width = 360 + + #(Bin Packing) - # Add label buttons + buttons_info = [] for lbl in self.labels: display_text = lbl.replace('_', ' ') btn = LabelButton(display_text) btn.clicked.connect(lambda _, l=lbl: self.labelClicked.emit(l)) btn.rightClicked.connect(lambda l=lbl: self._show_context_menu(l)) btn.doubleClicked.connect(lambda l=lbl: self.renameLabelRequested.emit(l)) - self.grid_layout.addWidget(btn, row, col) - col += 1 - if col >= cols: - col = 0 - row += 1 + + btn.adjustSize() + btn_w = btn.sizeHint().width() + buttons_info.append((btn, btn_w)) + + buttons_info.sort(key=lambda x: x[1], reverse=True) + + rows = [] + + for btn, btn_w in buttons_info: + placed = False + for row in rows: + if row['width'] + btn_w + 6 <= max_width: + row['layout'].addWidget(btn) + row['width'] += btn_w + 6 + placed = True + break + + if not placed: + new_layout = 
QHBoxLayout() + new_layout.setSpacing(6) + new_layout.setAlignment(Qt.AlignmentFlag.AlignLeft) + self.grid_layout.addLayout(new_layout) + + new_layout.addWidget(btn) + rows.append({'layout': new_layout, 'width': btn_w}) - # Add "Add Label" button at the bottom - add_btn = QPushButton("Add new label at current time") + add_btn = QPushButton("+ Add Label to Current Time") add_btn.setCursor(Qt.CursorShape.PointingHandCursor) - add_btn.setMinimumHeight(45) - add_btn.setProperty("class", "spotting_add_btn") + add_btn.setMinimumHeight(28) + add_btn.setSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Fixed) + add_btn.setStyleSheet(""" + QPushButton { + padding: 2px 10px; + color: #FFFFFF; + font-weight: bold; + background-color: #007BFF; + border: none; + border-radius: 4px; + } + QPushButton:hover { + background-color: #0056b3; + } + QPushButton:pressed { + background-color: #004085; + } + """) + add_btn.clicked.connect(self.addLabelRequested.emit) + add_btn.adjustSize() + add_btn_w = add_btn.sizeHint().width() - if col != 0: - row += 1 - self.grid_layout.addWidget(add_btn, row, 0, 1, 2) + placed_add = False + for row in rows: + if row['width'] + add_btn_w + 6 <= max_width: + row['layout'].addWidget(add_btn) + row['width'] += add_btn_w + 6 + placed_add = True + break + + if not placed_add: + new_layout = QHBoxLayout() + new_layout.setSpacing(6) + new_layout.setAlignment(Qt.AlignmentFlag.AlignLeft) + self.grid_layout.addLayout(new_layout) + new_layout.addWidget(add_btn) + + self.scroll.setWidget(self.grid_container) def _show_context_menu(self, label): display_label = label.replace('_', ' ') diff --git a/annotation_tool/ui/localization/media_player/.DS_Store b/annotation_tool/ui/localization/media_player/.DS_Store deleted file mode 100644 index 5008ddf..0000000 Binary files a/annotation_tool/ui/localization/media_player/.DS_Store and /dev/null differ diff --git a/annotation_tool/viewer.py b/annotation_tool/viewer.py index 087f497..68b2b4e 100644 --- 
a/annotation_tool/viewer.py +++ b/annotation_tool/viewer.py @@ -7,6 +7,8 @@ from controllers.classification.class_annotation_manager import AnnotationManager from controllers.classification.class_navigation_manager import NavigationManager +from controllers.classification.inference_manager import InferenceManager +from controllers.classification.train_manager import TrainManager # [NEW] Import TrainManager from controllers.history_manager import HistoryManager from controllers.localization.localization_manager import LocalizationManager # Import Description Managers @@ -23,6 +25,7 @@ from utils import create_checkmark_icon, natural_sort_key, resource_path + class ActionClassifierApp(QMainWindow): """Main application window for annotation + localization + description + dense workflows.""" @@ -33,7 +36,7 @@ class ActionClassifierApp(QMainWindow): def __init__(self) -> None: super().__init__() - self.setWindowTitle("SoccerNet Pro Analysis Tool") + self.setWindowTitle("Video Annotation Tool") self.setGeometry(100, 100, 600, 400) # --- MVC wiring --- @@ -64,6 +67,8 @@ def __init__(self) -> None: # [NEW] Dense Description Controller self.dense_manager = DenseManager(self) + self.inference_manager = InferenceManager(self) + self.train_manager = TrainManager(self) # --- Local UI state (icons, etc.) 
--- bright_blue = QColor("#00BFFF") @@ -85,6 +90,7 @@ def __init__(self) -> None: self.ui.show_welcome_view() self._adjust_window_size(0) + # --------------------------------------------------------------------- # Global Media Control to Prevent Freezing/Ghost Frames # --------------------------------------------------------------------- @@ -134,9 +140,11 @@ def _adjust_window_size(self, index: int) -> None: self.resize(600, 400) else: - self.setMinimumSize(1000, 700) + #self.setMinimumSize(1000, 700) - self.resize(1400, 900) + #self.resize(1400, 900) + self.setMinimumSize(600, 400) + self.resize(1200, 800) @@ -160,6 +168,12 @@ def connect_signals(self) -> None: # --- Classification - Left panel --- cls_left = self.ui.classification_ui.left_panel + # [NEW] Customize the filter combo box exclusively for Classification mode + # Blocking signals prevents triggering filter logic before UI is fully built + cls_left.filter_combo.blockSignals(True) + cls_left.filter_combo.clear() + cls_left.filter_combo.addItems(["Show All", "Hand Labelled", "Smart Labelled", "No Labelled"]) + cls_left.filter_combo.blockSignals(False) cls_controls = cls_left.project_controls cls_controls.createRequested.connect(self._safe_create_project) @@ -185,11 +199,24 @@ def connect_signals(self) -> None: # --- Classification - Right panel --- cls_right = self.ui.classification_ui.right_panel - cls_right.confirm_btn.clicked.connect(self.annot_manager.save_manual_annotation) - cls_right.clear_sel_btn.clicked.connect(self.annot_manager.clear_current_manual_annotation) + + # [MODIFIED] Disconnect the direct button click and use our new Tab-aware signals + # cls_right.confirm_btn.clicked.connect(self.annot_manager.save_manual_annotation) # <-- delete or comment out this old line of code + + # [NEW] Connect the tab-aware confirm signals to their respective manager functions + 
cls_right.annotation_saved.connect(lambda data: self.annot_manager.save_manual_annotation()) + cls_right.smart_confirm_requested.connect(self.annot_manager.confirm_smart_annotation_as_manual) + + # [MODIFIED] Connect tab-aware clear signals + cls_right.hand_clear_requested.connect(self.annot_manager.clear_current_manual_annotation) + cls_right.smart_clear_requested.connect(self.annot_manager.clear_current_smart_annotation) + cls_right.add_head_clicked.connect(self.annot_manager.handle_add_label_head) cls_right.remove_head_clicked.connect(self.annot_manager.handle_remove_label_head) + cls_right.smart_infer_requested.connect(self.inference_manager.start_inference) + cls_right.confirm_infer_requested.connect(lambda result_dict: self.annot_manager.save_manual_annotation()) + # Undo/redo for Class/Loc cls_right.undo_btn.clicked.connect(self.history_manager.perform_undo) cls_right.redo_btn.clicked.connect(self.history_manager.perform_redo) @@ -469,7 +496,12 @@ def closeEvent(self, event) -> None: # [NEW] Check dense data has_data = bool(self.model.dense_description_events) else: - has_data = bool(self.model.manual_annotations) + has_manual = bool(self.model.manual_annotations) + has_smart_confirmed = any( + data.get("_confirmed", False) + for data in self.model.smart_annotations.values() + ) + has_data = has_manual or has_smart_confirmed can_export = self.model.json_loaded and has_data @@ -528,7 +560,13 @@ def update_save_export_button_state(self) -> None: # [NEW] has_data = bool(self.model.dense_description_events) else: - has_data = bool(self.model.manual_annotations) + # [FIXED] Check the hand and smart annotation + has_manual = bool(self.model.manual_annotations) + has_smart_confirmed = any( + data.get("_confirmed", False) + for data in self.model.smart_annotations.values() + ) + has_data = has_manual or has_smart_confirmed can_export = self.model.json_loaded and has_data can_save = can_export and (self.model.current_json_path is not None) and 
self.model.is_data_dirty @@ -576,12 +614,31 @@ def get_current_action_path(self): return idx.parent().data(ProjectTreeModel.FilePathRole) return idx.data(ProjectTreeModel.FilePathRole) + + def sync_batch_inference_dropdowns(self) -> None: + """[NEW] Sync the Action List names from the model to the Batch Inference dropdowns.""" + right_panel = self.ui.classification_ui.right_panel + # Ensure the UI component exists and supports updating + if not hasattr(right_panel, 'update_action_list'): + return + + # Sort the data using natural sort to exactly match the left tree + sorted_list = sorted(self.model.action_item_data, key=lambda d: natural_sort_key(d.get("name", ""))) + action_names = [d["name"] for d in sorted_list] + + # Push the updated list to the dropdowns + right_panel.update_action_list(action_names) + def populate_action_tree(self) -> None: """Rebuild the action tree from model data using the new ProjectTreeModel.""" self.tree_model.clear() self.model.action_item_map.clear() sorted_list = sorted(self.model.action_item_data, key=lambda d: natural_sort_key(d.get("name", ""))) + + # [NEW] Extract sorted names and sync them to the Batch Inference dropdowns + action_names = [d["name"] for d in sorted_list] + self.ui.classification_ui.right_panel.update_action_list(action_names) for data in sorted_list: item = self.tree_model.add_entry( @@ -590,9 +647,10 @@ def populate_action_tree(self) -> None: source_files=data.get("source_files") ) self.model.action_item_map[data["path"]] = item + self.update_action_item_status(data["path"]) - for path in self.model.action_item_map.keys(): - self.update_action_item_status(path) + # [MODIFIED] Use the centralized sync method to update Smart Annotation dropdowns + self.sync_batch_inference_dropdowns() # Decide which manager handles the navigation logic if self._is_loc_mode(): @@ -638,8 +696,11 @@ def update_action_item_status(self, action_path: str) -> None: elif self._is_dense_mode(): is_done = action_path in 
self.model.dense_description_events and bool(self.model.dense_description_events[action_path]) else: - # Classification mode logic - is_done = action_path in self.model.manual_annotations and bool(self.model.manual_annotations[action_path]) + #is_done = action_path in self.model.manual_annotations and bool(self.model.manual_annotations[action_path]) + # [MODIFIED] Classification mode logic: Done if manually annotated OR smart confirmed + is_manual_done = action_path in self.model.manual_annotations and bool(self.model.manual_annotations[action_path]) + is_smart_done = self.model.smart_annotations.get(action_path, {}).get("_confirmed", False) + is_done = is_manual_done or is_smart_done item.setIcon(self.done_icon if is_done else self.empty_icon) @@ -665,12 +726,22 @@ def refresh_ui_after_undo_redo(self, action_path: str) -> None: Refreshes the UI after an Undo/Redo operation. Updates the tree icon, selection, and the active editor content. """ + # [MODIFIED] Batch operations might pass action_path as None. + # We must still refresh the filter and button states even if path is None. if not action_path: + if not self._is_loc_mode() and not self._is_desc_mode() and not self._is_dense_mode(): + self.nav_manager.apply_action_filter() + self.update_save_export_button_state() return # 1. Update the tree icon status self.update_action_item_status(action_path) + # [NEW] 1.5 Refresh the tree filter to immediately show/hide items! + # This fixes the bug where Undo/Redo doesn't visually update the list. + if not self._is_loc_mode() and not self._is_desc_mode() and not self._is_dense_mode(): + self.nav_manager.apply_action_filter() + # 2. 
Ensure the item is selected in the active tree active_tree = None if self._is_loc_mode(): @@ -694,9 +765,8 @@ def refresh_ui_after_undo_redo(self, action_path: str) -> None: elif self._is_desc_mode(): self.desc_nav_manager.on_item_selected(item.index(), None) elif self._is_dense_mode(): - # [NEW] Refresh Dense events display self.dense_manager._display_events_for_item(action_path) else: self.annot_manager.display_manual_annotation(action_path) - self.update_save_export_button_state() + self.update_save_export_button_state() \ No newline at end of file diff --git a/docs/OSL.md b/docs/OSL.md new file mode 100644 index 0000000..6cb951d --- /dev/null +++ b/docs/OSL.md @@ -0,0 +1,215 @@ +# OSL JSON Format + +The OSL JSON format is a unified, extensible data structure designed to handle multi-task video understanding datasets (e.g., action classification, action spotting, and various forms of video captioning) within a single file. + +By unifying dataset annotations, the OSL format makes it easy to load complex, multi-modal, and multi-task datasets without writing custom parsers for every new task. + +Below is a detailed breakdown of the format, followed by a comprehensive example. + +--- + +## 1. Top-Level Structure + +The root of the OSL JSON document contains metadata about the dataset, the shared taxonomy for labels, and the actual data items. + +| Field | Type | Description | Required | +| :--- | :--- | :--- | :---: | +| `version` | String | The version of the OSL format used (e.g., `"1.0"`). | Yes | +| `date` | String | The ISO-8601 formatted date when this split/file was produced (e.g., `"2025-10-20"`). | Yes | +| `dataset_name` | String | The name of the dataset and the specific split (e.g., `"OSL-Football-UNIFIED (train)"`). | Yes | +| `metadata` | Object | Global, file-level metadata (e.g., `source`, `license`, `created_by`, `notes`). 
| No | +| `tasks` | Array[String] | An advisory list of task families included in this file (e.g., `["action_classification", "action_spotting", ...]`). | No | +| `labels` | Object | The shared global taxonomy defining the available classes and their properties. | Yes* | +| `data` | Array[Object] | The list of data items (video clips) and their associated annotations. | Yes | + +*\* Required if the dataset involves classification or spotting tasks.* + +--- + +## 2. Shared Taxonomy (`labels`) + +The top-level `labels` object defines the taxonomy used across all data items for tasks like action classification and action spotting. It supports multi-head outputs (e.g., predicting an "action" and "attributes" simultaneously). + +Each key in the `labels` object represents a specific "head" and defines: +* `type`: Either `"single_label"` (exactly one class per item/event) or `"multi_label"` (zero or more classes). +* `labels`: An array of strings representing the valid class names. + +**Example:** +```json +"labels": { + "action": { + "type": "single_label", + "labels": ["Pass", "Shot", "Header", "Foul"] + }, + "attributes": { + "type": "multi_label", + "labels": ["Aerial", "SetPiece"] + } +} + +``` + +--- + +## 3. Data Items (`data`) + +The `data` array contains individual objects, each representing a specific data instance (typically a video clip) and all its multi-task annotations. + +### Item Properties + +| Field | Type | Description | Required | +| --- | --- | --- | --- | +| `id` | String | A unique identifier for this data item. All task targets below apply to this ID. | Yes | +| `metadata` | Object | Item-level metadata (e.g., `competition`, `stage`, `home_team`). | No | +| `inputs` | Array[Object] | A list of typed inputs associated with this item (e.g., raw video, extracted features, poses). | Yes | + +### 3.1 Inputs + +The `inputs` array defines the multi-modal data sources for the item. Different input types require different fields. 
Time references in annotations (like spotting or dense captioning) are relative to the start of the primary video file specified here. + +**Common Input Types:** + +* **Video:** `{ "type": "video", "path": "path/to/vid.mp4", "fps": 25 }` +* **Features:** `{ "type": "features", "name": "I3D", "path": "...", "dim": 1024, "hop_ms": 160 }` +* **Poses:** `{ "type": "poses", "format": "COCO", "path": "..." }` + +*(Note: If referencing an untrimmed video, you can specify `start_ms` and `end_ms` within the video input object to define a specific segment.)* + +### 3.2 Task Annotations + +An item can contain annotations for multiple tasks simultaneously. Only the fields relevant to the tasks present in the dataset need to be included. + +#### Action Classification (`labels`) + +Assigns classes to the entire video clip based on the shared taxonomy defined at the top level. + +* For `"single_label"` heads, use the `"label"` key (String). +* For `"multi_label"` heads, use the `"labels"` key (Array of Strings). + +```json +"labels": { + "action": { "label": "Header" }, + "attributes": { "labels": ["Aerial"] } +} + +``` + +#### Action Spotting (`events`) + +Defines instantaneous events occurring at specific timestamps within the clip. + +* `head`: The taxonomy head to use (from top-level `labels`). +* `label`: The class name. +* `position_ms`: The timestamp of the event in milliseconds (relative to the start of the clip). + +```json +"events": [ + { "head": "action", "label": "Header", "position_ms": 2100 } +] + +``` + +#### Video Captioning (`captions`) + +Provides text descriptions for the entire video clip. Multiple languages are supported. + +* `lang`: Language code (e.g., `"en"`, `"fr"`). +* `text`: The caption string. + +```json +"captions": [ + { "lang": "en", "text": "A precise cross finds the striker..." } +] + +``` + +#### Dense Video Captioning (`dense_captions`) + +Provides text descriptions for specific temporal segments within the video clip. 
+ +* `start_ms`: Start time of the segment in milliseconds. +* `end_ms`: End time of the segment in milliseconds. +* `lang`: Language code. +* `text`: The caption string for that segment. + +```json +"dense_captions": [ + { "start_ms": 1200, "end_ms": 2500, "lang": "en", "text": "The winger accelerates..." } +] + +``` + +--- + +## 4. Full Example + +Below is a complete example of an OSL JSON file demonstrating a single data item with multiple inputs and multi-task annotations. + +```json +{ + "version": "1.0", + "date": "2025-10-20", + "dataset_name": "OSL-Football-UNIFIED (train)", + + "metadata": { + "source": "World Cup Finals", + "license": "CC-BY-NC-4.0", + "created_by": "OSL", + "notes": "Single item demonstrates multi-task targets on the same ID." + }, + + "tasks": ["action_classification", "action_spotting", "video_captioning", "dense_video_captioning"], + + "labels": { + "action": { + "type": "single_label", + "labels": ["Pass", "Shot", "Header", "Foul"] + }, + "attributes": { + "type": "multi_label", + "labels": ["Aerial", "SetPiece"] + } + }, + + "data": [ + { + "id": "M64_multi_000", + + "metadata": { + "competition": "FIFA WC", + "stage": "Final", + "home_team": "Germany", + "away_team": "Argentina" + }, + + "inputs": [ + { "type": "video", "path": "FWC2014/224p/M64_multi_000.mp4", "fps": 25 }, + { "type": "features", "name": "I3D", "path": "features/I3D/M64_multi_000.npy", "dim": 1024, "hop_ms": 160 }, + { "type": "poses", "format": "COCO", "path": "poses/M64_multi_000.json" }, + { "type": "gamestate", "path": "gamestate/M64_multi_000.json" } + ], + + "labels": { + "action": { "label": "Header" }, + "attributes": { "labels": ["Aerial"] } + }, + + "events": [ + { "head": "action", "label": "Header", "position_ms": 2100 }, + { "head": "action", "label": "Pass", "position_ms": 3850 } + ], + + "captions": [ + { "lang": "en", "text": "A precise cross finds the striker, who directs a powerful header on target." 
}, + { "lang": "fr", "text": "Un centre prรฉcis trouve lโ€™attaquant, qui place une tรชte puissante cadrรฉe." } + ], + + "dense_captions": [ + { "start_ms": 1200, "end_ms": 2500, "lang": "en", "text": "The winger accelerates down the flank and delivers a looping cross." }, + { "start_ms": 2600, "end_ms": 4200, "lang": "en", "text": "The striker rises above the defense and heads the ball toward goal." } + ] + } + ] +} + +``` diff --git a/docs/about.md b/docs/about.md index 502be00..ecf7237 100644 --- a/docs/about.md +++ b/docs/about.md @@ -1,10 +1,10 @@ # About -The Soccernet Pro Tool is developed by OpenSportsLab to help researchers and practitioners efficiently annotate sports video datasets. +The Video Annotation Tool is developed by OpenSportsLab to help researchers and practitioners efficiently annotate sports video datasets. - **Project Lead:** Silvio Giancola - **Front End Developer:** Jintao Ma -- **GitHub:** [OpenSportsLab/soccernetpro-ui](https://github.com/OpenSportsLab/soccernetpro-ui) +- **GitHub:** [OpenSportsLab/soccernetpro-ui](https://github.com/OpenSportsLab/VideoAnnotationTool) - **License:** Dual-licensed (GPL-3.0 / Commercial) We welcome feedback and contributions from the community. diff --git a/docs/index.md b/docs/index.md index 83a52af..d73741e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,6 +1,6 @@ -# Soccernet Pro Tool +# Video AnnotationTool -Welcome to the Soccernet Pro Annotation Tool documentation! +Welcome to the Video Annotation Tool Annotation Tool documentation! This tool helps you annotate action spotting datasets in sports video. Use the navigation to find installation instructions, user guides, and more. @@ -11,7 +11,7 @@ This tool helps you annotate action spotting datasets in sports video. 
Use the n - Intuitive graphical interface for annotating actions in sports videos - Fast video navigation and frame-accurate annotation - Easily edit timestamps and action labels -- Supports OSL JSON annotation format for seamless integration with [OSL-ActionSpotting](https://github.com/OpenSportsLab/OSL-ActionSpotting) +- Supports OSL JSON annotation format for seamless integration with [OSL-ActionSpotting](https://github.com/OpenSportsLab/OSL-ActionSpotting) - Save and load annotation files - Keyboard shortcuts for power users @@ -31,6 +31,7 @@ This tool helps you annotate action spotting datasets in sports video. Use the n - [Installation](installation.md) - [User Guide](gui_overview.md) - [FAQ](faq.md) +- [OSL JSON format](OSL.md) --- @@ -40,7 +41,7 @@ This project offers two licensing options to suit different needs: - **GPL-3.0 License**: This open-source license is intended for students, researchers, and the community. It supports open collaboration and sharing under the terms of the GNU General Public License v3.0. - See the [`LICENSE.txt`](https://github.com/OpenSportsLab/soccernetpro-ui/blob/main/LICENSE.txt) file for full details. + See the [`LICENSE.txt`](https://github.com/OpenSportsLab/VideoAnnotationTool/blob/main/LICENSE.txt) file for full details. - **Commercial License**: Designed for commercial use, this option allows integration of the software into proprietary products and services without the open-source obligations of GPL-3.0. 
diff --git a/mkdocs.yml b/mkdocs.yml index afc2e32..08775ac 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,8 +1,8 @@ -site_name: SoccerNetPro Analyzer -site_description: A PyQt6 GUI tool for analyzing and annotating SoccerNetPro datasets (OpenSportsLab) +site_name: Video Annotation Tool +site_description: A PyQt6 GUI tool for analyzing and annotating OSL datasets (OpenSportsLab) site_author: OpenSportsLab -repo_url: https://github.com/OpenSportsLab/soccernetpro-ui -repo_name: OpenSportsLab/soccernetpro-ui +repo_url: https://github.com/OpenSportsLab/VideoAnnotationTool +repo_name: OpenSportsLab/VideoAnnotationTool theme: name: material