This repository contains an implementation of a Deep Q-Network (DQN) agent capable of solving the classic CartPole control problem.
Built from scratch using PyTorch and Gymnasium, this project demonstrates key Reinforcement Learning concepts, including Experience Replay, Target Networks, and Epsilon-Greedy exploration. It also features a custom interactive demo mode that lets you stress-test the trained agent by applying external forces.
The Goal: Balance a pole on a cart by moving the cart left or right.

The Challenge: The environment provides no labeled data; the agent must learn solely through trial and error, associating actions with delayed rewards.
- Deep Q-Network (DQN): Replaces the traditional Q-Table with a Neural Network to handle continuous state spaces.
- Experience Replay: Uses a circular buffer to store and sample past transitions, breaking temporal correlations and stabilizing training (a sketch follows this list).
- Target Network: Implements a secondary "frozen" network to calculate stable Q-value targets, preventing oscillation.
- Robust Checkpointing: Automatically saves the best performing model (not just the last one) to avoid catastrophic forgetting.
- Interactive Demo: A `pygame`-based inference script that allows humans to "kick" the cart to test the agent's recovery reflexes.
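The `ReplayMemory` mentioned above can be as small as a wrapper around a bounded `deque`. Here is a minimal sketch following the common PyTorch DQN pattern; the `Transition` fields and method names are assumptions and may differ from the actual code in `dqn_agent.py`:

```python
import random
from collections import deque, namedtuple

# Illustrative transition record; field names are assumptions.
Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))

class ReplayMemory:
    """Circular experience buffer: once capacity is reached, the oldest
    transitions are discarded, and training batches are drawn uniformly."""

    def __init__(self, capacity: int):
        self.memory = deque(maxlen=capacity)  # deque drops oldest entries automatically

    def push(self, *args):
        self.memory.append(Transition(*args))

    def sample(self, batch_size: int):
        # Uniform random sampling breaks the temporal correlation
        # between consecutive steps of the same episode
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
```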
- Clone the repository:

```bash
git clone https://github.com/codewithbro95/deep-q-learning-cartpole.git
cd deep-q-learning-cartpole
```
- Install dependencies (using a virtual environment is recommended):

```bash
pip install "gymnasium[classic_control]" torch matplotlib pygame numpy
```
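If you want the isolated environment, one can be created and activated before installing (a sketch for Unix-like shells; on Windows the activation script is `.venv\Scripts\activate`):

```bash
python -m venv .venv          # create an isolated environment in ./.venv
source .venv/bin/activate     # activate it for the current shell session
```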
| File | Description |
|---|---|
| `main.py` | The main entry point for training. Contains the training loop and performance plotting. |
| `dqn_agent.py` | Contains the `DQNAgent` class, `ReplayMemory`, and the DQN neural network architecture. |
| `demo.py` | The inference script. Loads the trained model and runs the interactive simulation. |
| `cartpole_best.pth` | The saved weights of the best-performing model (generated after training). |
To train the model from scratch, run:

```bash
python main.py
```
- What happens: The agent will play 600 episodes.
- Visuals: A live plot will appear showing the duration (score) of each episode.
- Output: The script will save the best model weights to `cartpole_best.pth` whenever a new high score is reached (a sketch of this checkpointing follows).
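The checkpointing logic amounts to comparing each episode's score against the best seen so far and saving only on improvement. A minimal sketch; names like `policy_net` and `best_score` are illustrative, not necessarily those in `main.py`:

```python
import torch

# Inside the training loop, after an episode finishes.
# episode_duration: number of steps the pole stayed balanced this episode.
if episode_duration > best_score:
    best_score = episode_duration
    # Keep the best model, not the last one, to sidestep catastrophic forgetting
    torch.save(policy_net.state_dict(), "cartpole_best.pth")
```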
Once you have a trained model (or if you are using the provided weights), run:

```bash
python demo.py
```
Controls:
- `LEFT ARROW`: Apply a sudden force (kick) to the left.
- `RIGHT ARROW`: Apply a sudden force (kick) to the right.
- Observation: Watch how the agent frantically adjusts the cart position to recover the pole's balance after being shoved. A sketch of how such a kick can be injected follows this list.
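One way to implement such a kick is to nudge the cart's velocity directly in the environment's internal state. This is only a sketch of the idea, not necessarily how `demo.py` does it: `KICK` is an assumed impulse magnitude, `env` is assumed to be the running CartPole environment, and reaching into `env.unwrapped.state` relies on Gymnasium's CartPole internals:

```python
import pygame

KICK = 1.0  # assumed impulse magnitude (not taken from the original code)

# Inside the inference loop, after pygame has been initialized:
keys = pygame.key.get_pressed()
if keys[pygame.K_LEFT] or keys[pygame.K_RIGHT]:
    x, x_dot, theta, theta_dot = env.unwrapped.state
    x_dot += -KICK if keys[pygame.K_LEFT] else KICK  # shove the cart sideways
    env.unwrapped.state = (x, x_dot, theta, theta_dot)
```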
We use a simple Multi-Layer Perceptron (MLP) to approximate the Q-Value function:
- Input Layer: 4 Neurons (Cart Position, Cart Velocity, Pole Angle, Pole Velocity)
- Hidden Layer 1: 128 Neurons (ReLU activation)
- Hidden Layer 2: 128 Neurons (ReLU activation)
- Output Layer: 2 Neurons (Action Left, Action Right) - Linear activation (Raw Q-Values)
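In PyTorch this architecture is only a few lines. A minimal sketch; the class and attribute names are illustrative, and `dqn_agent.py` may structure it differently:

```python
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, n_observations: int = 4, n_actions: int = 2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_observations, 128),  # 4 state variables -> 128 hidden units
            nn.ReLU(),
            nn.Linear(128, 128),             # second hidden layer
            nn.ReLU(),
            nn.Linear(128, n_actions),       # raw (linear) Q-values for Left / Right
        )

    def forward(self, x):
        return self.net(x)
```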
- `BATCH_SIZE`: 128
- `GAMMA` (Discount Factor): 0.99
- `EPSILON` (Exploration): Decays from 0.9 to 0.05
- `LEARNING_RATE`: 1e-4 (AdamW Optimizer)
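These values plug into the epsilon-greedy rule: with probability epsilon the agent takes a random action, otherwise the greedy one. A sketch of a common exponential decay schedule; the decay constant `EPS_DECAY = 1000` is an assumption, and the repository may decay epsilon differently:

```python
import math
import random
import torch

EPS_START, EPS_END = 0.9, 0.05
EPS_DECAY = 1000  # assumed decay time constant, in steps

def select_action(state, policy_net, steps_done, n_actions=2):
    # state: tensor of shape (1, 4); epsilon decays from EPS_START toward EPS_END
    eps = EPS_END + (EPS_START - EPS_END) * math.exp(-steps_done / EPS_DECAY)
    if random.random() < eps:
        return random.randrange(n_actions)             # explore
    with torch.no_grad():
        return policy_net(state).argmax(dim=1).item()  # exploit: greedy Q-value
```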
- "Crash: NoneType Error": This usually happens if the training loop tries to process the "Next State" after the game has ended. The code handles this by masking out terminal states in the
optimize_modelfunction. - Model Performance Degrades: If the model performs well at episode 300 but fails at episode 600, this is "Catastrophic Forgetting." Ensure you are using
cartpole_best.pth(the checkpointed model) and not the final state of the network.
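For reference, terminal-state masking works by zeroing the bootstrap term for transitions that ended the episode, so their targets reduce to the immediate reward. A sketch in the style of the standard PyTorch DQN recipe; names like `target_net`, `batch_next_states`, `rewards`, `BATCH_SIZE`, and `GAMMA` are illustrative:

```python
import torch

# batch_next_states: list of next-state tensors, with None at terminal transitions
non_final_mask = torch.tensor([s is not None for s in batch_next_states],
                              dtype=torch.bool)
non_final_next = torch.cat([s for s in batch_next_states if s is not None])

next_q = torch.zeros(BATCH_SIZE)
with torch.no_grad():
    # The target network evaluates only the non-terminal successors
    next_q[non_final_mask] = target_net(non_final_next).max(1).values

# Terminal states contribute only their immediate reward (next_q stays 0 there)
expected_q = rewards + GAMMA * next_q
```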
This project was created for educational purposes: to understand, under the hood, the engineering practices behind Deep Reinforcement Learning.

