Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions solidity/contracts/ReentrancyAttacker.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.20;
import "@openzeppelin/contracts/token/ERC20/IERC20.sol";
interface IStakingVault{function withdraw(uint256 amount)external;function stake(uint256 amount)external;}
contract ReentrancyAttacker{
IStakingVault public vault;
uint256 public attackAmount;
uint256 public count;
constructor(address _v){vault=IStakingVault(_v);}
function approveAndStake(address token,uint256 amount)external{IERC20(token).approve(address(vault),amount);vault.stake(amount);}
function attack(uint256 amount)external{attackAmount=amount;vault.withdraw(amount);}
receive()external payable{if(count<3){count++;vault.withdraw(attackAmount);}}
}
43 changes: 14 additions & 29 deletions solidity/contracts/StakingVault.sol
Original file line number Diff line number Diff line change
@@ -1,26 +1,21 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.20;

import "@openzeppelin/contracts/token/ERC20/IERC20.sol";

contract StakingVault {
import "@openzeppelin/contracts/utils/ReentrancyGuard.sol";
contract StakingVault is ReentrancyGuard {
IERC20 public stakingToken;
uint256 public rewardRate;
uint256 public totalStaked;

mapping(address => uint256) public balances;
mapping(address => uint256) public rewards;
mapping(address => uint256) public lastStakeTime;

event Staked(address indexed user, uint256 amount);
event Withdrawn(address indexed user, uint256 amount);
event RewardClaimed(address indexed user, uint256 amount);

constructor(address _stakingToken, uint256 _rewardRate) {
stakingToken = IERC20(_stakingToken);
rewardRate = _rewardRate;
}

function stake(uint256 amount) external {
require(amount > 0, "Cannot stake 0");
stakingToken.transferFrom(msg.sender, address(this), amount);
Expand All @@ -30,51 +25,41 @@ contract StakingVault {
lastStakeTime[msg.sender] = block.timestamp;
emit Staked(msg.sender, amount);
}

function _updateReward(address account) internal {
if (balances[account] > 0) {
uint256 timeStaked = block.timestamp - lastStakeTime[account];
rewards[account] += balances[account] * timeStaked * rewardRate / 1e18;
}
lastStakeTime[account] = block.timestamp;
}

// BUG: Reentrancy — state update after external call
function withdraw(uint256 amount) external {
// FIX: CEI pattern — state updates BEFORE external call + nonReentrant guard
function withdraw(uint256 amount) external nonReentrant {
require(balances[msg.sender] >= amount, "Insufficient balance");
_updateReward(msg.sender);

// External call before state update
(bool success, ) = payable(msg.sender).call{value: amount}("");
require(success, "Transfer failed");

// State update after external call — vulnerable to reentrancy
// Effects first
balances[msg.sender] -= amount;
totalStaked -= amount;
// Interaction last
(bool success, ) = payable(msg.sender).call{value: amount}("");
require(success, "Transfer failed");
emit Withdrawn(msg.sender, amount);
}

// BUG: Same reentrancy pattern in claimRewards
function claimRewards() external {
// FIX: Same CEI fix — zero rewards BEFORE external call + nonReentrant
function claimRewards() external nonReentrant {
_updateReward(msg.sender);
uint256 reward = rewards[msg.sender];
require(reward > 0, "No rewards");

// Effects first
rewards[msg.sender] = 0;
// Interaction last
(bool success, ) = payable(msg.sender).call{value: reward}("");
require(success, "Transfer failed");

rewards[msg.sender] = 0;
emit RewardClaimed(msg.sender, reward);
}

function getStakedBalance(address account) external view returns (uint256) {
return balances[account];
}

function getStakedBalance(address account) external view returns (uint256) { return balances[account]; }
function getPendingRewards(address account) external view returns (uint256) {
uint256 timeStaked = block.timestamp - lastStakeTime[account];
return rewards[account] + balances[account] * timeStaked * rewardRate / 1e18;
}

receive() external payable {}
}
36 changes: 36 additions & 0 deletions solidity/test/StakingVault.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
const{expect}=require("chai");
const{ethers}=require("hardhat");
describe("StakingVault - Reentrancy Fix",function(){
let vault,mockToken,owner,attacker;
beforeEach(async()=>{
[owner,attacker]=await ethers.getSigners();
const MockToken=await ethers.getContractFactory("MockERC20");
mockToken=await MockToken.deploy("Mock","MCK",ethers.parseEther("1000000"));
const StakingVault=await ethers.getContractFactory("StakingVault");
vault=await StakingVault.deploy(await mockToken.getAddress(),ethers.parseEther("0.001"));
await owner.sendTransaction({to:await vault.getAddress(),value:ethers.parseEther("10")});
});
it("withdraw: state zeroed before external call (CEI)",async()=>{
await mockToken.transfer(attacker.address,ethers.parseEther("100"));
await mockToken.connect(attacker).approve(await vault.getAddress(),ethers.parseEther("100"));
await vault.connect(attacker).stake(ethers.parseEther("100"));
await vault.connect(attacker).withdraw(ethers.parseEther("100"));
expect(await vault.balances(attacker.address)).to.equal(0n);
});
it("claimRewards: rewards zeroed before external call",async()=>{
await mockToken.transfer(attacker.address,ethers.parseEther("100"));
await mockToken.connect(attacker).approve(await vault.getAddress(),ethers.parseEther("100"));
await vault.connect(attacker).stake(ethers.parseEther("100"));
await ethers.provider.send("evm_increaseTime",[3600]);
await ethers.provider.send("evm_mine",[]);
await vault.connect(attacker).claimRewards();
expect(await vault.rewards(attacker.address)).to.equal(0n);
});
it("nonReentrant blocks recursive withdrawal attack",async()=>{
const Attacker=await ethers.getContractFactory("ReentrancyAttacker");
const atk=await Attacker.deploy(await vault.getAddress());
await mockToken.transfer(await atk.getAddress(),ethers.parseEther("100"));
await atk.approveAndStake(await mockToken.getAddress(),ethers.parseEther("100"));
await expect(atk.attack(ethers.parseEther("10"))).to.be.revertedWithCustomError(vault,"ReentrancyGuardReentrantCall");
});
});
Loading