Update jobapi pt example #4112
Greptile Overview
Greptile Summary
This PR updates the PyTorch Job API examples to fix CIFAR10 download issues and migrate from ModelLearner to the Client API approach. Key improvements include:
All previously reported issues have been addressed, including FileLock additions, task name alignment (validate task), and accuracy calculation precision.
Confidence Score: 5/5
Important Files Changed
Sequence Diagram
sequenceDiagram
    participant Server
    participant Client1 as Client (site-1)
    participant Client2 as Client (site-2)
    participant CIFAR10 as CIFAR10 Dataset
    Note over Server,Client2: Initialization Phase
    Server->>Client1: Initialize with global model
    Server->>Client2: Initialize with global model
    Client1->>CIFAR10: Download dataset (with FileLock)
    Client2->>CIFAR10: Download dataset (with FileLock)
    Note over Client1,Client2: FileLock prevents race conditions
    loop Training Rounds (num_rounds)
        Note over Server,Client2: Training Phase
        Server->>Client1: Send global model (train task)
        Server->>Client2: Send global model (train task)
        Client1->>Client1: Load partitioned data subset
        Client2->>Client2: Load partitioned data subset
        Client1->>Client1: Train local model (local_epochs)
        Client1->>Client1: Evaluate local model
        Client1->>Client1: Evaluate global model for selection
        Client1->>Server: Send updated weights + accuracy metric
        Client2->>Client2: Train local model (local_epochs)
        Client2->>Client2: Evaluate local model
        Client2->>Client2: Evaluate global model for selection
        Client2->>Server: Send updated weights + accuracy metric
        Server->>Server: Aggregate weights (FedAvg)
        Server->>Server: Select best model (IntimeModelSelector)
        opt Cross-Site Validation
            Note over Server,Client2: Cross-Site Evaluation Phase
            Server->>Client1: Request best local model (submit_model task)
            Server->>Client2: Request best local model (submit_model task)
            Client1->>Server: Send best local model
            Client2->>Server: Send best local model
            Server->>Client1: Send Client2's model (validate task)
            Server->>Client2: Send Client1's model (validate task)
            Client1->>Client1: Evaluate Client2's model
            Client2->>Client2: Evaluate Client1's model
            Client1->>Server: Send cross-site accuracy
            Client2->>Server: Send cross-site accuracy
            Server->>Server: Generate validation JSON results
        end
    end
    Note over Server,Client2: Training Complete
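The "Aggregate weights (FedAvg)" step in the diagram can be illustrated with a minimal sketch. This is not NVFlare's aggregator: the `fedavg` function, the plain-dict weight format, and the per-site sample counts are all assumptions made for illustration. It only shows the idea of averaging client updates weighted by how many samples each client trained on.

```python
from typing import Dict, List, Tuple

# Hypothetical weight format: parameter name -> flat list of floats.
Weights = Dict[str, List[float]]


def fedavg(updates: List[Tuple[Weights, int]]) -> Weights:
    """Average client weights, weighting each client by its sample count."""
    total = sum(num_samples for _, num_samples in updates)
    if total == 0:
        raise ValueError("no training samples reported by any client")
    first_weights = updates[0][0]
    return {
        name: [
            sum(weights[name][i] * n for weights, n in updates) / total
            for i in range(len(values))
        ]
        for name, values in first_weights.items()
    }


# Two illustrative sites with different amounts of local data.
site1 = ({"fc.weight": [1.0, 2.0]}, 100)
site2 = ({"fc.weight": [3.0, 6.0]}, 300)
global_weights = fedavg([site1, site2])
print(global_weights)
```

Because site-2 holds three times as many samples in this toy setup, the averaged weights land closer to its update than to site-1's.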
In Python 3, / already does floating‑point division. To force float calculation explicitly, the float literal 100.0 * correct / total is the simplest and clearest.
@greptileai, it will not lose precision; test it yourself: python -c "correct=3; total=8; print(100*correct/total, 100.0*correct/total, 100*float(correct)/total)"
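The point of this exchange can be checked with a short sketch. The helper name `accuracy_percent` and the choice to return `0.0` for an empty test loader are assumptions for illustration; the PR only states that an explicit empty-loader check was added.

```python
def accuracy_percent(correct: int, total: int) -> float:
    """Return accuracy as a percentage, guarding against an empty test set."""
    if total == 0:
        # Assumption: report 0.0 rather than raising ZeroDivisionError.
        return 0.0
    return 100.0 * correct / total


correct, total = 3, 8

# In Python 3, "/" is true division, so all three spellings agree.
variants = (
    100 * correct / total,
    100.0 * correct / total,
    100 * float(correct) / total,
)
print(variants)  # (37.5, 37.5, 37.5)

# Only floor division "//" truncates the result.
print(100 * correct // total)  # 37
```

The float literal makes the intent explicit to readers, but it does not change the numeric result in Python 3.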
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
/build
Pull request overview
This PR updates the CIFAR-10 PyTorch Job API examples to use a consistent dataset location with safe concurrent downloads, and replaces the older model-learner-based FedAvg cross-site validation example with a Client API/script-runner–based example that supports heterogeneous data partitioning.
Changes:
- Ensure all CIFAR-10 example scripts use `/tmp/nvflare/data/cifar10` and add `filelock`-based locking around dataset downloads to avoid race conditions across multiple sites.
- Refine client-side evaluation logic to use float accuracies with explicit checks for empty test loaders, and improve logging for train/evaluate/submit tasks.
- Replace the `fedavg_model_learner_xsite_val_cifar10.py` example with a new `fedavg_script_runner_xsite_val_cifar10.py` + `cifar10_fl_partitioned.py` pipeline that partitions CIFAR-10 non-iid across sites and integrates Cross-Site Evaluation (CSE); align swarm and CSE script runners, requirements, and README with these changes.
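The locking pattern from the first bullet might look roughly like the sketch below. The function name `prepare_data_once`, the lock and marker file names, and the `download_fn` callback are hypothetical; the PR's scripts call the torchvision CIFAR-10 download inside the lock rather than a callback. `FileLock` comes from the third-party `filelock` package and is imported lazily here so a different lock factory can be swapped in.

```python
import os
from typing import Callable, Optional


def prepare_data_once(
    data_root: str,
    download_fn: Callable[[str], None],
    lock_factory: Optional[Callable[[str], object]] = None,
) -> bool:
    """Run download_fn at most once per data_root, even with concurrent callers.

    Returns True if this call performed the download, False if it was skipped.
    """
    if lock_factory is None:
        from filelock import FileLock  # third-party: pip install filelock

        lock_factory = FileLock

    os.makedirs(data_root, exist_ok=True)
    lock_path = os.path.join(data_root, "download.lock")    # hypothetical name
    marker = os.path.join(data_root, ".download_complete")  # hypothetical name

    # Only one process at a time can hold the lock; the others block here.
    with lock_factory(lock_path):
        if os.path.exists(marker):
            return False  # another site already finished the download
        download_fn(data_root)
        with open(marker, "w") as f:
            f.write("ok")
        return True
```

Each simulated site would call something like `prepare_data_once("/tmp/nvflare/data/cifar10", my_download)`, where `my_download` is a stand-in for the real dataset download. The first caller to acquire the lock downloads; later callers see the marker and skip.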
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| examples/advanced/job_api/pt/swarm_script_runner_cifar10.py | Points the swarm job to the new `cifar10_fl_train_eval_submit.py` training script to keep the swarm example aligned with the updated Client API workflow. |
| examples/advanced/job_api/pt/src/cifar10_lightning_fl.py | Uses `/tmp/nvflare/data/cifar10` and a `FileLock`-guarded `prepare_data` to prevent concurrent CIFAR-10 downloads across Lightning clients. |
| examples/advanced/job_api/pt/src/cifar10_fl_train_eval_submit.py | Adds locked dataset download, converts accuracy computation to safe float division with a guard for empty test loaders, and improves task-specific logging. |
| examples/advanced/job_api/pt/src/cifar10_fl_partitioned.py | New Client API training script that loads optional per-site index splits, partitions CIFAR-10 accordingly, and supports train/evaluate/submit_model with best-model saving and robust accuracy computation. |
| examples/advanced/job_api/pt/src/cifar10_fl.py | Aligns dataset root with the new CIFAR-10 path, wraps downloads in a `FileLock`, and updates evaluation to use float accuracies with an explicit empty-loader check. |
| examples/advanced/job_api/pt/requirements.txt | Adds `filelock>=3.12.0` to support the new locking behavior in all CIFAR-10 examples that download data. |
| examples/advanced/job_api/pt/fedavg_script_runner_xsite_val_cifar10.py | New FedAvg Job API script that partitions CIFAR-10 with a Dirichlet sampler, creates and saves per-site splits, configures Scatter-and-Gather training plus CrossSiteModelEval, and wires ScriptRunner clients with the partitioned training script. |
| examples/advanced/job_api/pt/fedavg_model_learner_xsite_val_cifar10.py | Removes the older model-learner–based FedAvg x-site validation example in favor of the newer Client API/script-runner–based example. |
| examples/advanced/job_api/pt/cse_script_runner_cifar10.py | Switches to the `cifar10_fl_train_eval_submit.py` script, uses `DataKind.WEIGHTS` for the aggregator, aligns CSE `validation_task_name="evaluate"`, and normalizes client IDs to `site-{i}`. |
| examples/advanced/job_api/pt/README.md | Updates the fifth example to reference the new FedAvg script runner with cross-site validation and heterogeneous data partitioning, adjusting the command and description accordingly. |
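As a rough illustration of the Dirichlet-based partitioning mentioned for `fedavg_script_runner_xsite_val_cifar10.py`, the sketch below splits sample indices across sites class by class. It is a standard-library stand-in, not the PR's code: the function name `dirichlet_partition`, its parameters, and the toy labels are all assumptions. It relies on the fact that normalized Gamma(alpha, 1) draws form a Dirichlet(alpha) sample; a smaller `alpha` gives more skewed (more non-iid) splits.

```python
import random
from typing import Dict, List


def dirichlet_partition(
    labels: List[int], num_sites: int, alpha: float, seed: int = 0
) -> Dict[str, List[int]]:
    """Assign each sample index to one site, with per-class Dirichlet proportions."""
    rng = random.Random(seed)

    # Group sample indices by class label.
    by_class: Dict[int, List[int]] = {}
    for idx, label in enumerate(labels):
        by_class.setdefault(label, []).append(idx)

    splits: Dict[str, List[int]] = {f"site-{i + 1}": [] for i in range(num_sites)}
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        # Normalized Gamma(alpha, 1) draws are one Dirichlet(alpha) sample.
        draws = [rng.gammavariate(alpha, 1.0) for _ in range(num_sites)]
        total = sum(draws) or 1.0
        start, cumulative = 0, 0.0
        for site_number, draw in enumerate(draws, start=1):
            cumulative += draw / total
            # The last site takes whatever remains so no index is dropped.
            if site_number == num_sites:
                end = len(indices)
            else:
                end = round(cumulative * len(indices))
            splits[f"site-{site_number}"].extend(indices[start:end])
            start = end
    return splits


# Toy labels standing in for the 10 CIFAR-10 classes.
toy_labels = [i % 10 for i in range(1000)]
site_splits = dirichlet_partition(toy_labels, num_sites=2, alpha=0.5, seed=1)
print({site: len(idxs) for site, idxs in site_splits.items()})
```

Each site would then train only on its own index list, which is how per-site splits give heterogeneous local datasets.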
Fixes # .

### Description
cherry pick #4112

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Quick tests passed locally by running `./runtest.sh`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated.

---------
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Fixes # .
Description
Correct the CIFAR10 download issue and update the example from ModelLearner to the Client API.
Types of changes
./runtest.sh.