diff --git a/solidity/contracts/ReentrancyAttacker.sol b/solidity/contracts/ReentrancyAttacker.sol new file mode 100644 index 000000000..e2885ac38 --- /dev/null +++ b/solidity/contracts/ReentrancyAttacker.sol @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.20; +import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +interface IStakingVault{function withdraw(uint256 amount)external;function stake(uint256 amount)external;} +contract ReentrancyAttacker{ + IStakingVault public vault; + uint256 public attackAmount; + uint256 public count; + constructor(address _v){vault=IStakingVault(_v);} + function approveAndStake(address token,uint256 amount)external{IERC20(token).approve(address(vault),amount);vault.stake(amount);} + function attack(uint256 amount)external{attackAmount=amount;vault.withdraw(amount);} + receive()external payable{if(count<3){count++;vault.withdraw(attackAmount);}} +} diff --git a/solidity/contracts/StakingVault.sol b/solidity/contracts/StakingVault.sol index 0518e83f4..0ddf8098d 100644 --- a/solidity/contracts/StakingVault.sol +++ b/solidity/contracts/StakingVault.sol @@ -1,26 +1,21 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.20; - import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; - -contract StakingVault { +import "@openzeppelin/contracts/utils/ReentrancyGuard.sol"; +contract StakingVault is ReentrancyGuard { IERC20 public stakingToken; uint256 public rewardRate; uint256 public totalStaked; - mapping(address => uint256) public balances; mapping(address => uint256) public rewards; mapping(address => uint256) public lastStakeTime; - event Staked(address indexed user, uint256 amount); event Withdrawn(address indexed user, uint256 amount); event RewardClaimed(address indexed user, uint256 amount); - constructor(address _stakingToken, uint256 _rewardRate) { stakingToken = IERC20(_stakingToken); rewardRate = _rewardRate; } - function stake(uint256 amount) external { require(amount > 0, "Cannot stake 0"); stakingToken.transferFrom(msg.sender, address(this), amount); @@ -30,7 +25,6 @@ contract StakingVault { lastStakeTime[msg.sender] = block.timestamp; emit Staked(msg.sender, amount); } - function _updateReward(address account) internal { if (balances[account] > 0) { uint256 timeStaked = block.timestamp - lastStakeTime[account]; @@ -38,43 +32,34 @@ contract StakingVault { } lastStakeTime[account] = block.timestamp; } - - // BUG: Reentrancy — state update after external call - function withdraw(uint256 amount) external { + // FIX: CEI pattern — state updates BEFORE external call + nonReentrant guard + function withdraw(uint256 amount) external nonReentrant { require(balances[msg.sender] >= amount, "Insufficient balance"); _updateReward(msg.sender); - - // External call before state update - (bool success, ) = payable(msg.sender).call{value: amount}(""); - require(success, "Transfer failed"); - - // State update after external call — vulnerable to reentrancy + // Effects first balances[msg.sender] -= amount; totalStaked -= amount; + // Interaction last + (bool success, ) = payable(msg.sender).call{value: amount}(""); + require(success, "Transfer failed"); emit Withdrawn(msg.sender, amount); } - - // BUG: Same reentrancy pattern in claimRewards - function claimRewards() external { + // FIX: Same CEI fix — zero rewards BEFORE external call + nonReentrant + function claimRewards() external nonReentrant { _updateReward(msg.sender); uint256 reward = rewards[msg.sender]; require(reward > 0, "No rewards"); - + // Effects first + rewards[msg.sender] = 0; + // Interaction last (bool success, ) = payable(msg.sender).call{value: reward}(""); require(success, "Transfer failed"); - - rewards[msg.sender] = 0; emit RewardClaimed(msg.sender, reward); } - - function getStakedBalance(address account) external view returns (uint256) { - return balances[account]; - } - + function getStakedBalance(address account) external view returns (uint256) { return balances[account]; } function getPendingRewards(address account) external view returns (uint256) { uint256 timeStaked = block.timestamp - lastStakeTime[account]; return rewards[account] + balances[account] * timeStaked * rewardRate / 1e18; } - receive() external payable {} } diff --git a/solidity/test/StakingVault.test.js b/solidity/test/StakingVault.test.js new file mode 100644 index 000000000..fa076d389 --- /dev/null +++ b/solidity/test/StakingVault.test.js @@ -0,0 +1,36 @@ +const{expect}=require("chai"); +const{ethers}=require("hardhat"); +describe("StakingVault - Reentrancy Fix",function(){ + let vault,mockToken,owner,attacker; + beforeEach(async()=>{ + [owner,attacker]=await ethers.getSigners(); + const MockToken=await ethers.getContractFactory("MockERC20"); + mockToken=await MockToken.deploy("Mock","MCK",ethers.parseEther("1000000")); + const StakingVault=await ethers.getContractFactory("StakingVault"); + vault=await StakingVault.deploy(await mockToken.getAddress(),ethers.parseEther("0.001")); + await owner.sendTransaction({to:await vault.getAddress(),value:ethers.parseEther("10")}); + }); + it("withdraw: state zeroed before external call (CEI)",async()=>{ + await mockToken.transfer(attacker.address,ethers.parseEther("100")); + await mockToken.connect(attacker).approve(await vault.getAddress(),ethers.parseEther("100")); + await vault.connect(attacker).stake(ethers.parseEther("100")); + await vault.connect(attacker).withdraw(ethers.parseEther("100")); + expect(await vault.balances(attacker.address)).to.equal(0n); + }); + it("claimRewards: rewards zeroed before external call",async()=>{ + await mockToken.transfer(attacker.address,ethers.parseEther("100")); + await mockToken.connect(attacker).approve(await vault.getAddress(),ethers.parseEther("100")); + await vault.connect(attacker).stake(ethers.parseEther("100")); + await ethers.provider.send("evm_increaseTime",[3600]); + await ethers.provider.send("evm_mine",[]); + await vault.connect(attacker).claimRewards(); + expect(await vault.rewards(attacker.address)).to.equal(0n); + }); + it("nonReentrant blocks recursive withdrawal attack",async()=>{ + const Attacker=await ethers.getContractFactory("ReentrancyAttacker"); + const atk=await Attacker.deploy(await vault.getAddress()); + await mockToken.transfer(await atk.getAddress(),ethers.parseEther("100")); + await atk.approveAndStake(await mockToken.getAddress(),ethers.parseEther("100")); + await expect(atk.attack(ethers.parseEther("10"))).to.be.revertedWithCustomError(vault,"ReentrancyGuardReentrantCall"); + }); +});