diff --git a/src/WrappedMToken.sol b/src/WrappedMToken.sol index e76347b..8702c13 100644 --- a/src/WrappedMToken.sol +++ b/src/WrappedMToken.sol @@ -40,7 +40,7 @@ contract WrappedMToken is IWrappedMToken, Migratable, ERC20Extended { mapping(address account => BalanceInfo balance) internal _balances; modifier onlyWhenEarning() { - if (IMTokenLike(mToken).isEarning(address(this))) revert NotInEarningState(); + if (!IMTokenLike(mToken).isEarning(address(this))) revert NotInEarningState(); _; } @@ -56,7 +56,7 @@ contract WrappedMToken is IWrappedMToken, Migratable, ERC20Extended { /* ============ Interactive Functions ============ */ - function wrap(address recipient_, uint256 amount_) external onlyWhenEarning { + function wrap(address recipient_, uint256 amount_) external { _mint(recipient_, UIntMath.safe240(amount_)); IMTokenLike(mToken).transferFrom(msg.sender, address(this), amount_); @@ -78,7 +78,19 @@ contract WrappedMToken is IWrappedMToken, Migratable, ERC20Extended { IMTokenLike(mToken).transfer(vault, yield_); } - function startEarningFor(address account_) external { + function startEarningM() external { + if (mIndexWhenEarningStopped != 0) revert OnlyEarningOnce(); + + IMTokenLike(mToken).startEarning(); + } + + function stopEarningM() external { + mIndexWhenEarningStopped = currentIndex(); + + IMTokenLike(mToken).stopEarning(); + } + + function startEarningFor(address account_) external onlyWhenEarning { if (!_isApprovedEarner(account_)) revert NotApprovedEarner(); (bool isEarning_, , , uint240 balance_) = _getBalanceInfo(account_); @@ -118,18 +130,6 @@ contract WrappedMToken is IWrappedMToken, Migratable, ERC20Extended { } } - function startEarningM() external { - if (mIndexWhenEarningStopped != 0) revert OnlyEarningOnce(); - - IMTokenLike(mToken).startEarning(); - } - - function stopEarningM() external { - mIndexWhenEarningStopped = currentIndex(); - - IMTokenLike(mToken).stopEarning(); - } - /* ============ View/Pure Functions ============ */ function accruedYieldOf(address account_) external view returns (uint240 yield_) { diff --git a/test/Test.t.sol b/test/Test.t.sol index 166ecf7..a04108c 100644 --- a/test/Test.t.sol +++ b/test/Test.t.sol @@ -72,6 +72,8 @@ contract Tests is Test { _registrar.setListContains(_EARNERS_LIST, _alice, true); _registrar.setListContains(_EARNERS_LIST, _bob, true); + _wrappedMToken.startEarningM(); + _wrappedMToken.startEarningFor(_alice); _wrappedMToken.startEarningFor(_bob); diff --git a/test/WrappedMToken.t.sol b/test/WrappedMToken.t.sol index 56eec8f..7f7a7b7 100644 --- a/test/WrappedMToken.t.sol +++ b/test/WrappedMToken.t.sol @@ -54,6 +54,8 @@ contract WrappedMTokenTests is Test { _wrappedMToken = WrappedMTokenHarness(address(new Proxy(address(_implementation)))); _mToken.setCurrentIndex(_currentIndex = 1_100000068703); + + _wrappedMToken.startEarningM(); } /* ============ constructor ============ */ @@ -430,6 +432,31 @@ contract WrappedMTokenTests is Test { assertEq(_wrappedMToken.totalEarningSupply(), 1); // TODO: Fix? } + /* ============ startEarningM ============ */ + function test_startEarningM_onlyEarningOnce() external { + assertEq(_mToken.isEarning(address(_wrappedMToken)), true); + + _wrappedMToken.stopEarningM(); + + vm.expectRevert(IWrappedMToken.OnlyEarningOnce.selector); + _wrappedMToken.startEarningM(); + } + + /* ============ stopEarningM ============ */ + function test_stopEarningM() external { + assertEq(_mToken.isEarning(address(_wrappedMToken)), true); + assertEq(_wrappedMToken.mIndexWhenEarningStopped(), 0); + + _wrappedMToken.stopEarningM(); + + assertEq(_mToken.isEarning(address(_wrappedMToken)), false); + assertEq(_wrappedMToken.mIndexWhenEarningStopped(), _mToken.currentIndex()); + + _mToken.setCurrentIndex(_currentIndex = _EXP_SCALED_ONE); + + assertEq(_wrappedMToken.currentIndex(), _wrappedMToken.mIndexWhenEarningStopped()); + } + /* ============ balanceOf ============ */ function test_balanceOf_nonEarner() external { _wrappedMToken.setBalanceOf(_alice, 500); diff --git a/test/utils/Mocks.sol b/test/utils/Mocks.sol index 87dda4d..b3c7833 100644 --- a/test/utils/Mocks.sol +++ b/test/utils/Mocks.sol @@ -8,12 +8,13 @@ contract MockM { uint128 public currentIndex; mapping(address account => uint256 balance) public balanceOf; + mapping(address account => bool isEarning) public isEarning; - function transfer(address, uint256) external returns (bool success_) { + function transfer(address, uint256) external pure returns (bool success_) { return true; } - function transferFrom(address, address, uint256) external returns (bool success_) { + function transferFrom(address, address, uint256) external pure returns (bool success_) { return true; } @@ -28,6 +29,14 @@ contract MockM { function setTtgRegistrar(address ttgRegistrar_) external { ttgRegistrar = ttgRegistrar_; } + + function startEarning() external { + isEarning[msg.sender] = true; + } + + function stopEarning() external { + isEarning[msg.sender] = false; + } } contract MockRegistrar {