diff --git a/modules/apps/29-fee/ibc_middleware_test.go b/modules/apps/29-fee/ibc_middleware_test.go index d1a857a0ca7..1515b1e3327 100644 --- a/modules/apps/29-fee/ibc_middleware_test.go +++ b/modules/apps/29-fee/ibc_middleware_test.go @@ -25,7 +25,6 @@ var ( defaultRecvFee = sdk.Coins{sdk.Coin{Denom: sdk.DefaultBondDenom, Amount: sdkmath.NewInt(100)}} defaultAckFee = sdk.Coins{sdk.Coin{Denom: sdk.DefaultBondDenom, Amount: sdkmath.NewInt(200)}} defaultTimeoutFee = sdk.Coins{sdk.Coin{Denom: sdk.DefaultBondDenom, Amount: sdkmath.NewInt(300)}} - smallAmount = sdk.Coins{sdk.Coin{Denom: sdk.DefaultBondDenom, Amount: sdkmath.NewInt(50)}} ) // Tests OnChanOpenInit on ChainA @@ -605,6 +604,8 @@ func (suite *FeeTestSuite) TestOnAcknowledgementPacket() { packetFee types.PacketFee refundAddr sdk.AccAddress relayerAddr sdk.AccAddress + escrowAmount sdk.Coins + initialRefundAccBal sdk.Coins expRefundAccBalance sdk.Coins expPayeeAccBalance sdk.Coins ) @@ -621,10 +622,31 @@ func (suite *FeeTestSuite) TestOnAcknowledgementPacket() { // retrieve the relayer acc balance and add the expected recv and ack fees relayerAccBalance := sdk.NewCoins(suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), relayerAddr, sdk.DefaultBondDenom)) expPayeeAccBalance = relayerAccBalance.Add(packetFee.Fee.RecvFee...).Add(packetFee.Fee.AckFee...) + }, + true, + func() { + // assert that the packet fees have been distributed + found := suite.chainA.GetSimApp().IBCFeeKeeper.HasFeesInEscrow(suite.chainA.GetContext(), packetID) + suite.Require().False(found) - // retrieve the refund acc balance and add the expected timeout fees - refundAccBalance := sdk.NewCoins(suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAddr, sdk.DefaultBondDenom)) - expRefundAccBalance = refundAccBalance.Add(packetFee.Fee.TimeoutFee...) + relayerAccBalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), relayerAddr, sdk.DefaultBondDenom) + suite.Require().Equal(expPayeeAccBalance, sdk.NewCoins(relayerAccBalance)) + + refundAccBalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAddr, sdk.DefaultBondDenom) + suite.Require().Equal(initialRefundAccBal, sdk.NewCoins(refundAccBalance)) + }, + }, + { + "success: some refunds", + func() { + // set timeout_fee > recv_fee + ack_fee + packetFee.Fee.TimeoutFee = packetFee.Fee.Total().Add(sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(100)))...) + + escrowAmount = packetFee.Fee.Total() + + // retrieve the relayer acc balance and add the expected recv and ack fees + relayerAccBalance := sdk.NewCoins(suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), relayerAddr, sdk.DefaultBondDenom)) + expPayeeAccBalance = relayerAccBalance.Add(packetFee.Fee.RecvFee...).Add(packetFee.Fee.AckFee...) }, true, func() { @@ -635,6 +657,9 @@ func (suite *FeeTestSuite) TestOnAcknowledgementPacket() { relayerAccBalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), relayerAddr, sdk.DefaultBondDenom) suite.Require().Equal(expPayeeAccBalance, sdk.NewCoins(relayerAccBalance)) + // expect the correct refunds + refundCoins := packetFee.Fee.Total().Sub(packetFee.Fee.RecvFee...).Sub(packetFee.Fee.AckFee...) + expRefundAccBalance = initialRefundAccBal.Add(refundCoins...) refundAccBalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAddr, sdk.DefaultBondDenom) suite.Require().Equal(expRefundAccBalance, sdk.NewCoins(refundAccBalance)) }, @@ -656,10 +681,6 @@ func (suite *FeeTestSuite) TestOnAcknowledgementPacket() { // retrieve the payee acc balance and add the expected recv and ack fees payeeAccBalance := sdk.NewCoins(suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), payeeAddr, sdk.DefaultBondDenom)) expPayeeAccBalance = payeeAccBalance.Add(packetFee.Fee.RecvFee...).Add(packetFee.Fee.AckFee...) - - // retrieve the refund acc balance and add the expected timeout fees - refundAccBalance := sdk.NewCoins(suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAddr, sdk.DefaultBondDenom)) - expRefundAccBalance = refundAccBalance.Add(packetFee.Fee.TimeoutFee...) }, true, func() { @@ -671,8 +692,9 @@ func (suite *FeeTestSuite) TestOnAcknowledgementPacket() { payeeAccBalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), payeeAddr, sdk.DefaultBondDenom) suite.Require().Equal(expPayeeAccBalance, sdk.NewCoins(payeeAccBalance)) + // expect zero refunds refundAccBalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAddr, sdk.DefaultBondDenom) - suite.Require().Equal(expRefundAccBalance, sdk.NewCoins(refundAccBalance)) + suite.Require().Equal(initialRefundAccBal, sdk.NewCoins(refundAccBalance)) }, }, { @@ -721,10 +743,6 @@ func (suite *FeeTestSuite) TestOnAcknowledgementPacket() { // retrieve the relayer acc balance and add the expected ack fees relayerAccBalance := sdk.NewCoins(suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), relayerAddr, sdk.DefaultBondDenom)) expPayeeAccBalance = relayerAccBalance.Add(packetFee.Fee.AckFee...) - - // retrieve the refund acc balance and add the expected recv fees and timeout fees - refundAccBalance := sdk.NewCoins(suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAddr, sdk.DefaultBondDenom)) - expRefundAccBalance = refundAccBalance.Add(packetFee.Fee.RecvFee...).Add(packetFee.Fee.TimeoutFee...) }, true, func() { @@ -735,6 +753,8 @@ func (suite *FeeTestSuite) TestOnAcknowledgementPacket() { relayerAccBalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), relayerAddr, sdk.DefaultBondDenom) suite.Require().Equal(expPayeeAccBalance, sdk.NewCoins(relayerAccBalance)) + // expect only recv fee to be refunded + expRefundAccBalance = initialRefundAccBal.Add(packetFee.Fee.RecvFee...) refundAccBalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAddr, sdk.DefaultBondDenom) suite.Require().Equal(expRefundAccBalance, sdk.NewCoins(refundAccBalance)) }, @@ -742,8 +762,7 @@ func (suite *FeeTestSuite) TestOnAcknowledgementPacket() { { "fail: fee distribution fails and fee module is locked when escrow account does not have sufficient funds", func() { - err := suite.chainA.GetSimApp().BankKeeper.SendCoinsFromModuleToAccount(suite.chainA.GetContext(), types.ModuleName, suite.chainA.SenderAccount.GetAddress(), smallAmount) - suite.Require().NoError(err) + escrowAmount = sdk.NewCoins() }, true, func() { @@ -796,15 +815,18 @@ func (suite *FeeTestSuite) TestOnAcknowledgementPacket() { packet := suite.CreateMockPacket() packetID = channeltypes.NewPacketID(packet.GetSourcePort(), packet.GetSourceChannel(), packet.GetSequence()) packetFee = types.NewPacketFee(types.NewFee(defaultRecvFee, defaultAckFee, defaultTimeoutFee), refundAddr.String(), nil) + escrowAmount = packetFee.Fee.Total() + + ack = types.NewIncentivizedAcknowledgement(relayerAddr.String(), ibcmock.MockAcknowledgement.Acknowledgement(), true).Acknowledgement() + + tc.malleate() // malleate mutates test data suite.chainA.GetSimApp().IBCFeeKeeper.SetFeesInEscrow(suite.chainA.GetContext(), packetID, types.NewPacketFees([]types.PacketFee{packetFee})) - err := suite.chainA.GetSimApp().BankKeeper.SendCoinsFromAccountToModule(suite.chainA.GetContext(), refundAddr, types.ModuleName, packetFee.Fee.Total()) + err := suite.chainA.GetSimApp().BankKeeper.SendCoinsFromAccountToModule(suite.chainA.GetContext(), refundAddr, types.ModuleName, escrowAmount) suite.Require().NoError(err) - ack = types.NewIncentivizedAcknowledgement(relayerAddr.String(), ibcmock.MockAcknowledgement.Acknowledgement(), true).Acknowledgement() - - tc.malleate() // malleate mutates test data + initialRefundAccBal = sdk.NewCoins(suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAddr, sdk.DefaultBondDenom)) // retrieve module callbacks module, _, err := suite.chainA.App.GetIBCKeeper().PortKeeper.LookupModuleByPort(suite.chainA.GetContext(), ibctesting.MockFeePort) @@ -828,12 +850,14 @@ func (suite *FeeTestSuite) TestOnAcknowledgementPacket() { func (suite *FeeTestSuite) TestOnTimeoutPacket() { var ( - packetID channeltypes.PacketId - packetFee types.PacketFee - refundAddr sdk.AccAddress - relayerAddr sdk.AccAddress - expRefundAccBalance sdk.Coins - expPayeeAccBalance sdk.Coins + packetID channeltypes.PacketId + packetFee types.PacketFee + refundAddr sdk.AccAddress + relayerAddr sdk.AccAddress + escrowAmount sdk.Coins + initialRelayerAccBal sdk.Coins + expRefundAccBalance sdk.Coins + expPayeeAccBalance sdk.Coins ) testCases := []struct { @@ -843,15 +867,38 @@ func (suite *FeeTestSuite) TestOnTimeoutPacket() { expResult func() }{ { - "success", + "success: no refund", func() { - // retrieve the relayer acc balance and add the expected timeout fees - relayerAccBalance := sdk.NewCoins(suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), relayerAddr, sdk.DefaultBondDenom)) - expPayeeAccBalance = relayerAccBalance.Add(packetFee.Fee.TimeoutFee...) + // expect zero refunds + refundAccBalance := sdk.NewCoins(suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAddr, sdk.DefaultBondDenom)) + expRefundAccBalance = refundAccBalance + }, + true, + func() { + // assert that the packet fees have been distributed + found := suite.chainA.GetSimApp().IBCFeeKeeper.HasFeesInEscrow(suite.chainA.GetContext(), packetID) + suite.Require().False(found) + + expPayeeAccBalance = initialRelayerAccBal.Add(packetFee.Fee.TimeoutFee...) + relayerAccBalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), relayerAddr, sdk.DefaultBondDenom) + suite.Require().Equal(expPayeeAccBalance, sdk.NewCoins(relayerAccBalance)) + + refundAccBalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAddr, sdk.DefaultBondDenom) + suite.Require().Equal(expRefundAccBalance, sdk.NewCoins(refundAccBalance)) + }, + }, + { + "success: refund (recv_fee + ack_fee) - timeout_fee", + func() { + // set recv_fee + ack_fee > timeout_fee + packetFee.Fee.RecvFee = packetFee.Fee.Total().Add(sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(100)))...) + + escrowAmount = packetFee.Fee.Total() // retrieve the refund acc balance and add the expected recv and ack fees + refundCoins := packetFee.Fee.Total().Sub(packetFee.Fee.TimeoutFee...) refundAccBalance := sdk.NewCoins(suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAddr, sdk.DefaultBondDenom)) - expRefundAccBalance = refundAccBalance.Add(packetFee.Fee.RecvFee...).Add(packetFee.Fee.AckFee...) + expRefundAccBalance = refundAccBalance.Add(refundCoins...) }, true, func() { @@ -859,6 +906,7 @@ func (suite *FeeTestSuite) TestOnTimeoutPacket() { found := suite.chainA.GetSimApp().IBCFeeKeeper.HasFeesInEscrow(suite.chainA.GetContext(), packetID) suite.Require().False(found) + expPayeeAccBalance = initialRelayerAccBal.Add(packetFee.Fee.TimeoutFee...) relayerAccBalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), relayerAddr, sdk.DefaultBondDenom) suite.Require().Equal(expPayeeAccBalance, sdk.NewCoins(relayerAccBalance)) @@ -881,9 +929,9 @@ func (suite *FeeTestSuite) TestOnTimeoutPacket() { payeeAccBalance := sdk.NewCoins(suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), payeeAddr, sdk.DefaultBondDenom)) expPayeeAccBalance = payeeAccBalance.Add(packetFee.Fee.TimeoutFee...) - // retrieve the refund acc balance and add the expected recv and ack fees + // expect zero refunds refundAccBalance := sdk.NewCoins(suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAddr, sdk.DefaultBondDenom)) - expRefundAccBalance = refundAccBalance.Add(packetFee.Fee.RecvFee...).Add(packetFee.Fee.AckFee...) + expRefundAccBalance = refundAccBalance }, true, func() { @@ -936,8 +984,7 @@ func (suite *FeeTestSuite) TestOnTimeoutPacket() { { "fee distribution fails and fee module is locked when escrow account does not have sufficient funds", func() { - err := suite.chainA.GetSimApp().BankKeeper.SendCoinsFromModuleToAccount(suite.chainA.GetContext(), types.ModuleName, suite.chainA.SenderAccount.GetAddress(), smallAmount) - suite.Require().NoError(err) + escrowAmount = sdk.NewCoins() }, true, func() { @@ -982,12 +1029,15 @@ func (suite *FeeTestSuite) TestOnTimeoutPacket() { packet := suite.CreateMockPacket() packetID = channeltypes.NewPacketID(packet.GetSourcePort(), packet.GetSourceChannel(), packet.GetSequence()) packetFee = types.NewPacketFee(types.NewFee(defaultRecvFee, defaultAckFee, defaultTimeoutFee), refundAddr.String(), nil) + escrowAmount = packetFee.Fee.Total() + + tc.malleate() // malleate mutates test data suite.chainA.GetSimApp().IBCFeeKeeper.SetFeesInEscrow(suite.chainA.GetContext(), packetID, types.NewPacketFees([]types.PacketFee{packetFee})) - err := suite.chainA.GetSimApp().BankKeeper.SendCoinsFromAccountToModule(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), types.ModuleName, packetFee.Fee.Total()) + err := suite.chainA.GetSimApp().BankKeeper.SendCoinsFromAccountToModule(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), types.ModuleName, escrowAmount) suite.Require().NoError(err) - tc.malleate() // malleate mutates test data + initialRelayerAccBal = sdk.NewCoins(suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), relayerAddr, sdk.DefaultBondDenom)) // retrieve module callbacks module, _, err := suite.chainA.App.GetIBCKeeper().PortKeeper.LookupModuleByPort(suite.chainA.GetContext(), ibctesting.MockFeePort) diff --git a/modules/apps/29-fee/keeper/escrow.go b/modules/apps/29-fee/keeper/escrow.go index e21dc85fb99..4501788d749 100644 --- a/modules/apps/29-fee/keeper/escrow.go +++ b/modules/apps/29-fee/keeper/escrow.go @@ -97,8 +97,9 @@ func (k Keeper) distributePacketFeeOnAcknowledgement(ctx sdk.Context, refundAddr // distribute fee for reverse relaying k.distributeFee(ctx, reverseRelayer, refundAddr, packetFee.Fee.AckFee) - // refund timeout fee for unused timeout - k.distributeFee(ctx, refundAddr, refundAddr, packetFee.Fee.TimeoutFee) + // refund unused amount from the escrowed fee + refundCoins := packetFee.Fee.Total().Sub(packetFee.Fee.RecvFee...).Sub(packetFee.Fee.AckFee...) + k.distributeFee(ctx, refundAddr, refundAddr, refundCoins) } // DistributePacketsFeesOnTimeout pays all the timeout fees for a given packetID while refunding the acknowledgement & receive fees to the refund account. @@ -137,14 +138,12 @@ func (k Keeper) DistributePacketFeesOnTimeout(ctx sdk.Context, timeoutRelayer sd // distributePacketFeeOnTimeout pays the timeout fee to the timeout relayer and refunds the acknowledgement & receive fee. func (k Keeper) distributePacketFeeOnTimeout(ctx sdk.Context, refundAddr, timeoutRelayer sdk.AccAddress, packetFee types.PacketFee) { - // refund receive fee for unused forward relaying - k.distributeFee(ctx, refundAddr, refundAddr, packetFee.Fee.RecvFee) - - // refund ack fee for unused reverse relaying - k.distributeFee(ctx, refundAddr, refundAddr, packetFee.Fee.AckFee) - // distribute fee for timeout relaying k.distributeFee(ctx, timeoutRelayer, refundAddr, packetFee.Fee.TimeoutFee) + + // refund unused amount from the escrowed fee + refundCoins := packetFee.Fee.Total().Sub(packetFee.Fee.TimeoutFee...) + k.distributeFee(ctx, refundAddr, refundAddr, refundCoins) } // distributeFee will attempt to distribute the escrowed fee to the receiver address. diff --git a/modules/apps/29-fee/keeper/escrow_test.go b/modules/apps/29-fee/keeper/escrow_test.go index 928238163e0..ad218113a92 100644 --- a/modules/apps/29-fee/keeper/escrow_test.go +++ b/modules/apps/29-fee/keeper/escrow_test.go @@ -55,8 +55,46 @@ func (suite *KeeperTestSuite) TestDistributeFee() { balance = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), forward, sdk.DefaultBondDenom) suite.Require().Equal(expectedForwardAccBal, balance) - // check if the refund acc has been refunded the timeoutFee - expectedRefundAccBal := refundAccBal.Add(defaultTimeoutFee[0].Add(defaultTimeoutFee[0])) + // check if the refund amount is zero + expectedRefundAccBal := refundAccBal + balance = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAcc, sdk.DefaultBondDenom) + suite.Require().Equal(expectedRefundAccBal, balance) + + // check the module acc wallet is now empty + balance = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.GetSimApp().IBCFeeKeeper.GetFeeModuleAddress(), sdk.DefaultBondDenom) + suite.Require().Equal(sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(0)), balance) + }, + }, + { + "success: refund timeout_fee - (recv_fee + ack_fee)", + func() { + // set the timeout fee to be greater than recv + ack fee so that the refund amount is non-zero + fee.TimeoutFee = fee.Total().Add(sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(100))) + + packetFee = types.NewPacketFee(fee, refundAcc.String(), []string{}) + packetFees = []types.PacketFee{packetFee, packetFee} + }, + func() { + // check if fees has been deleted + packetID := channeltypes.NewPacketID(suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, 1) + suite.Require().False(suite.chainA.GetSimApp().IBCFeeKeeper.HasFeesInEscrow(suite.chainA.GetContext(), packetID)) + + // check if the reverse relayer is paid + expectedReverseAccBal := reverseRelayerBal.Add(defaultAckFee[0]).Add(defaultAckFee[0]) + balance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), reverseRelayer, sdk.DefaultBondDenom) + suite.Require().Equal(expectedReverseAccBal, balance) + + // check if the forward relayer is paid + forward, err := sdk.AccAddressFromBech32(forwardRelayer) + suite.Require().NoError(err) + + expectedForwardAccBal := forwardRelayerBal.Add(defaultRecvFee[0]).Add(defaultRecvFee[0]) + balance = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), forward, sdk.DefaultBondDenom) + suite.Require().Equal(expectedForwardAccBal, balance) + + // check if the refund amount is correct + refundCoins := fee.Total().Sub(defaultRecvFee[0]).Sub(defaultAckFee[0]).MulInt(sdkmath.NewInt(2)) + expectedRefundAccBal := refundAccBal.Add(refundCoins[0]) balance = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAcc, sdk.DefaultBondDenom) suite.Require().Equal(expectedRefundAccBal, balance) @@ -68,6 +106,9 @@ func (suite *KeeperTestSuite) TestDistributeFee() { { "success: refund account is module account", func() { + // set the timeout fee to be greater than recv + ack fee so that the refund amount is non-zero + fee.TimeoutFee = fee.Total().Add(sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(100))) + refundAcc = suite.chainA.GetSimApp().AccountKeeper.GetModuleAddress(mock.ModuleName) packetFee = types.NewPacketFee(fee, refundAcc.String(), []string{}) @@ -78,8 +119,9 @@ func (suite *KeeperTestSuite) TestDistributeFee() { suite.Require().NoError(err) }, func() { - // check if the refund acc has been refunded the timeoutFee - expectedRefundAccBal := refundAccBal.Add(defaultTimeoutFee[0]).Add(defaultTimeoutFee[0]) + // check if the refund acc has been refunded the correct amount + refundCoins := fee.Total().Sub(defaultRecvFee[0]).Sub(defaultAckFee[0]).MulInt(sdkmath.NewInt(2)) + expectedRefundAccBal := refundAccBal.Add(refundCoins[0]) balance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAcc, sdk.DefaultBondDenom) suite.Require().Equal(expectedRefundAccBal, balance) }, @@ -113,8 +155,8 @@ func (suite *KeeperTestSuite) TestDistributeFee() { forwardRelayer = "invalid address" }, func() { - // check if the refund acc has been refunded the timeoutFee & recvFee - expectedRefundAccBal := refundAccBal.Add(defaultTimeoutFee[0]).Add(defaultRecvFee[0]).Add(defaultTimeoutFee[0]).Add(defaultRecvFee[0]) + // check if the refund acc has been refunded the recvFee + expectedRefundAccBal := refundAccBal.Add(defaultRecvFee[0]).Add(defaultRecvFee[0]) balance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAcc, sdk.DefaultBondDenom) suite.Require().Equal(expectedRefundAccBal, balance) }, @@ -129,7 +171,7 @@ func (suite *KeeperTestSuite) TestDistributeFee() { }, func() { // check if the refund acc has been refunded the timeoutFee & recvFee - expectedRefundAccBal := refundAccBal.Add(defaultTimeoutFee[0]).Add(defaultRecvFee[0]).Add(defaultTimeoutFee[0]).Add(defaultRecvFee[0]) + expectedRefundAccBal := refundAccBal.Add(defaultRecvFee[0]).Add(defaultRecvFee[0]) balance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAcc, sdk.DefaultBondDenom) suite.Require().Equal(expectedRefundAccBal, balance) }, @@ -143,15 +185,18 @@ func (suite *KeeperTestSuite) TestDistributeFee() { reverseRelayer = suite.chainA.GetSimApp().AccountKeeper.GetModuleAccount(suite.chainA.GetContext(), transfertypes.ModuleName).GetAddress() }, func() { - // check if the refund acc has been refunded the timeoutFee & ackFee - expectedRefundAccBal := refundAccBal.Add(defaultTimeoutFee[0]).Add(defaultAckFee[0]).Add(defaultTimeoutFee[0]).Add(defaultAckFee[0]) + // check if the refund acc has been refunded the ackFee + expectedRefundAccBal := refundAccBal.Add(defaultAckFee[0]).Add(defaultAckFee[0]) balance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAcc, sdk.DefaultBondDenom) suite.Require().Equal(expectedRefundAccBal, balance) }, }, { - "invalid refund address: no-op, timeout fee remains in escrow", + "invalid refund address: no-op, timeout_fee - (recv_fee + ack_fee) remains in escrow", func() { + // set the timeout fee to be greater than recv + ack fee so that the refund amount is non-zero + fee.TimeoutFee = fee.Total().Add(sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(100))) + packetFee = types.NewPacketFee(fee, refundAcc.String(), []string{}) packetFees = []types.PacketFee{packetFee, packetFee} @@ -160,7 +205,8 @@ func (suite *KeeperTestSuite) TestDistributeFee() { }, func() { // check if the module acc contains the timeoutFee - expectedModuleAccBal := sdk.NewCoin(sdk.DefaultBondDenom, defaultTimeoutFee.Add(defaultTimeoutFee...).AmountOf(sdk.DefaultBondDenom)) + refundCoins := fee.Total().Sub(defaultRecvFee[0]).Sub(defaultAckFee[0]).MulInt(sdkmath.NewInt(2)) + expectedModuleAccBal := sdk.NewCoin(sdk.DefaultBondDenom, refundCoins.AmountOf(sdk.DefaultBondDenom)) balance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.GetSimApp().IBCFeeKeeper.GetFeeModuleAddress(), sdk.DefaultBondDenom) suite.Require().Equal(expectedModuleAccBal, balance) }, @@ -207,6 +253,7 @@ func (suite *KeeperTestSuite) TestDistributePacketFeesOnTimeout() { timeoutRelayerBal sdk.Coin refundAcc sdk.AccAddress refundAccBal sdk.Coin + fee types.Fee packetFee types.PacketFee packetFees []types.PacketFee ) @@ -217,7 +264,7 @@ func (suite *KeeperTestSuite) TestDistributePacketFeesOnTimeout() { expResult func() }{ { - "success", + "success: no refund", func() {}, func() { // check if the timeout relayer is paid @@ -225,8 +272,33 @@ func (suite *KeeperTestSuite) TestDistributePacketFeesOnTimeout() { balance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), timeoutRelayer, sdk.DefaultBondDenom) suite.Require().Equal(expectedTimeoutAccBal, balance) - // check if the refund acc has been refunded the recv/ack fees - expectedRefundAccBal := refundAccBal.Add(defaultAckFee[0]).Add(defaultAckFee[0]).Add(defaultRecvFee[0]).Add(defaultRecvFee[0]) + // check if the refund amount is zero + expectedRefundAccBal := refundAccBal + balance = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAcc, sdk.DefaultBondDenom) + suite.Require().Equal(expectedRefundAccBal, balance) + + // check the module acc wallet is now empty + balance = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.GetSimApp().IBCFeeKeeper.GetFeeModuleAddress(), sdk.DefaultBondDenom) + suite.Require().Equal(sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(0)), balance) + }, + }, + { + "success: refund (recv_fee + ack_fee) - timeout_fee", + func() { + // set the recv + ack fee to be greater than timeout fee so that the refund amount is non-zero + fee.RecvFee = fee.RecvFee.Add(sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(100))) + packetFee = types.NewPacketFee(fee, refundAcc.String(), []string{}) + packetFees = []types.PacketFee{packetFee, packetFee} + }, + func() { + // check if the timeout relayer is paid + expectedTimeoutAccBal := timeoutRelayerBal.Add(defaultTimeoutFee[0]).Add(defaultTimeoutFee[0]) + balance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), timeoutRelayer, sdk.DefaultBondDenom) + suite.Require().Equal(expectedTimeoutAccBal, balance) + + // check if the refund amount is correct + refundCoins := fee.Total().Sub(defaultTimeoutFee[0]).MulInt(sdkmath.NewInt(2)) + expectedRefundAccBal := refundAccBal.Add(refundCoins[0]) balance = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAcc, sdk.DefaultBondDenom) suite.Require().Equal(expectedRefundAccBal, balance) @@ -265,14 +337,21 @@ func (suite *KeeperTestSuite) TestDistributePacketFeesOnTimeout() { }, }, { - "invalid refund address: no-op, recv and ack fees remain in escrow", + "invalid refund address: no-op, (recv_fee + ack_fee) - timeout_fee remain in escrow", func() { + // set the recv + ack fee to be greater than timeout fee so that the refund amount is non-zero + fee.RecvFee = fee.RecvFee.Add(sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(100))) + packetFee = types.NewPacketFee(fee, refundAcc.String(), []string{}) + packetFees = []types.PacketFee{packetFee, packetFee} + packetFees[0].RefundAddress = suite.chainA.GetSimApp().AccountKeeper.GetModuleAccount(suite.chainA.GetContext(), transfertypes.ModuleName).GetAddress().String() packetFees[1].RefundAddress = suite.chainA.GetSimApp().AccountKeeper.GetModuleAccount(suite.chainA.GetContext(), transfertypes.ModuleName).GetAddress().String() }, func() { - // check if the module acc contains the timeoutFee - expectedModuleAccBal := sdk.NewCoin(sdk.DefaultBondDenom, defaultRecvFee.Add(defaultRecvFee[0]).Add(defaultAckFee[0]).Add(defaultAckFee[0]).AmountOf(sdk.DefaultBondDenom)) + // check if the module acc contains the correct amount of fees + refundCoins := fee.Total().Sub(defaultTimeoutFee[0]).MulInt(sdkmath.NewInt(2)) + + expectedModuleAccBal := refundCoins[0] balance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.GetSimApp().IBCFeeKeeper.GetFeeModuleAddress(), sdk.DefaultBondDenom) suite.Require().Equal(expectedModuleAccBal, balance) }, @@ -291,18 +370,18 @@ func (suite *KeeperTestSuite) TestDistributePacketFeesOnTimeout() { refundAcc = suite.chainA.SenderAccount.GetAddress() packetID := channeltypes.NewPacketID(suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, 1) - fee := types.NewFee(defaultRecvFee, defaultAckFee, defaultTimeoutFee) + fee = types.NewFee(defaultRecvFee, defaultAckFee, defaultTimeoutFee) // escrow the packet fees & store the fees in state packetFee = types.NewPacketFee(fee, refundAcc.String(), []string{}) packetFees = []types.PacketFee{packetFee, packetFee} + tc.malleate() + suite.chainA.GetSimApp().IBCFeeKeeper.SetFeesInEscrow(suite.chainA.GetContext(), packetID, types.NewPacketFees(packetFees)) - err := suite.chainA.GetSimApp().BankKeeper.SendCoinsFromAccountToModule(suite.chainA.GetContext(), refundAcc, types.ModuleName, packetFee.Fee.Total().Add(packetFee.Fee.Total()...)) + err := suite.chainA.GetSimApp().BankKeeper.SendCoinsFromAccountToModule(suite.chainA.GetContext(), refundAcc, types.ModuleName, fee.Total().Add(fee.Total()...)) suite.Require().NoError(err) - tc.malleate() - // fetch the account balances before fee distribution (forward, reverse, refund) timeoutRelayerBal = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), timeoutRelayer, sdk.DefaultBondDenom) refundAccBal = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAcc, sdk.DefaultBondDenom) diff --git a/modules/apps/29-fee/keeper/events_test.go b/modules/apps/29-fee/keeper/events_test.go index a3b5a87e414..c33315c790b 100644 --- a/modules/apps/29-fee/keeper/events_test.go +++ b/modules/apps/29-fee/keeper/events_test.go @@ -161,7 +161,7 @@ func (suite *KeeperTestSuite) TestDistributeFeeEvent() { sdk.NewEvent( types.EventTypeDistributeFee, sdk.NewAttribute(types.AttributeKeyReceiver, suite.chainA.SenderAccount.GetAddress().String()), - sdk.NewAttribute(types.AttributeKeyFee, defaultTimeoutFee.String()), + sdk.NewAttribute(types.AttributeKeyFee, sdk.NewCoins().String()), ), }.ToABCIEvents() diff --git a/modules/apps/29-fee/keeper/export_test.go b/modules/apps/29-fee/keeper/export_test.go index 382dde7a805..e7fb25fb6c8 100644 --- a/modules/apps/29-fee/keeper/export_test.go +++ b/modules/apps/29-fee/keeper/export_test.go @@ -1,12 +1,18 @@ package keeper -/* - This file is to allow for unexported functions and fields to be accessible to the testing package. -*/ +import ( + sdk "github.com/cosmos/cosmos-sdk/types" -import porttypes "github.com/cosmos/ibc-go/v8/modules/core/05-port/types" + "github.com/cosmos/ibc-go/v8/modules/apps/29-fee/types" + porttypes "github.com/cosmos/ibc-go/v8/modules/core/05-port/types" +) // GetICS4Wrapper is a getter for the keeper's ICS4Wrapper. func (k *Keeper) GetICS4Wrapper() porttypes.ICS4Wrapper { return k.ics4Wrapper } + +// LegacyTotal is a wrapper for the legacyTotal function for testing. +func LegacyTotal(f types.Fee) sdk.Coins { + return legacyTotal(f) +} diff --git a/modules/apps/29-fee/keeper/migrations.go b/modules/apps/29-fee/keeper/migrations.go new file mode 100644 index 00000000000..3eae34f7e1e --- /dev/null +++ b/modules/apps/29-fee/keeper/migrations.go @@ -0,0 +1,52 @@ +package keeper + +import ( + storetypes "cosmossdk.io/store/types" + + sdk "github.com/cosmos/cosmos-sdk/types" + + "github.com/cosmos/ibc-go/v8/modules/apps/29-fee/types" +) + +// Migrator is a struct for handling in-place store migrations. +type Migrator struct { + keeper Keeper +} + +// NewMigrator returns a new Migrator. +func NewMigrator(keeper Keeper) Migrator { + return Migrator{ + keeper: keeper, + } +} + +// Migrate1to2 migrates ibc-fee module from ConsensusVersion 1 to 2 +// by refunding leftover fees to the refund address. +func (m Migrator) Migrate1to2(ctx sdk.Context) error { + store := ctx.KVStore(m.keeper.storeKey) + iterator := storetypes.KVStorePrefixIterator(store, []byte(types.FeesInEscrowPrefix)) + defer sdk.LogDeferred(ctx.Logger(), func() error { return iterator.Close() }) + + for ; iterator.Valid(); iterator.Next() { + feesInEscrow := m.keeper.MustUnmarshalFees(iterator.Value()) + + for _, packetFee := range feesInEscrow.PacketFees { + refundCoins := legacyTotal(packetFee.Fee).Sub(packetFee.Fee.Total()...) + + refundAddr, err := sdk.AccAddressFromBech32(packetFee.RefundAddress) + if err != nil { + return err + } + + m.keeper.distributeFee(ctx, refundAddr, refundAddr, refundCoins) + } + } + + return nil +} + +// legacyTotal returns the legacy total amount for a given Fee +// The total amount is the RecvFee + AckFee + TimeoutFee +func legacyTotal(f types.Fee) sdk.Coins { + return f.RecvFee.Add(f.AckFee...).Add(f.TimeoutFee...) +} diff --git a/modules/apps/29-fee/keeper/migrations_test.go b/modules/apps/29-fee/keeper/migrations_test.go new file mode 100644 index 00000000000..fc4fb68802b --- /dev/null +++ b/modules/apps/29-fee/keeper/migrations_test.go @@ -0,0 +1,167 @@ +package keeper_test + +import ( + sdkmath "cosmossdk.io/math" + + sdk "github.com/cosmos/cosmos-sdk/types" + minttypes "github.com/cosmos/cosmos-sdk/x/mint/types" + + "github.com/cosmos/ibc-go/v8/modules/apps/29-fee/keeper" + "github.com/cosmos/ibc-go/v8/modules/apps/29-fee/types" + channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types" +) + +func (suite *KeeperTestSuite) TestLegacyTotal() { + fee := types.NewFee(defaultRecvFee, defaultAckFee, defaultTimeoutFee) + expLegacyTotal := sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(600))) + + suite.Require().Equal(expLegacyTotal, keeper.LegacyTotal(fee)) +} + +func (suite *KeeperTestSuite) TestMigrate1to2() { + var ( + packetID channeltypes.PacketId + moduleAcc sdk.AccAddress + refundAcc sdk.AccAddress + initRefundAccBal sdk.Coins + initModuleAccBal sdk.Coins + packetFees []types.PacketFee + ) + + testCases := []struct { + name string + malleate func() + assert func(error) + }{ + { + "success: no fees in escrow", + func() {}, + func(err error) { + suite.Require().NoError(err) + suite.Require().Empty(suite.chainA.GetSimApp().IBCFeeKeeper.GetAllIdentifiedPacketFees(suite.chainA.GetContext())) + + // refund account balance should not change + refundAccBal := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAcc, sdk.DefaultBondDenom) + suite.Require().Equal(initRefundAccBal[0], refundAccBal) + + // module account balance should not change + moduleAccBal := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), moduleAcc, sdk.DefaultBondDenom) + suite.Require().True(moduleAccBal.IsZero()) + }, + }, + { + "success: one fee in escrow", + func() { + fee := types.NewFee(defaultRecvFee, defaultAckFee, defaultTimeoutFee) + packetFee := types.NewPacketFee(fee, refundAcc.String(), []string(nil)) + packetFees = []types.PacketFee{packetFee} + }, + func(err error) { + suite.Require().NoError(err) + + // ensure that the packet fees are unmodified + expPacketFees := []types.IdentifiedPacketFees{ + types.NewIdentifiedPacketFees(packetID, packetFees), + } + suite.Require().Equal(expPacketFees, suite.chainA.GetSimApp().IBCFeeKeeper.GetAllIdentifiedPacketFees(suite.chainA.GetContext())) + + unusedFee := sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(300)) + // refund account balance should increase + refundAccBal := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), refundAcc, sdk.DefaultBondDenom) + suite.Require().Equal(initRefundAccBal.Add(unusedFee)[0], refundAccBal) + + // module account balance should decrease + moduleAccBal := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), moduleAcc, sdk.DefaultBondDenom) + suite.Require().Equal(initModuleAccBal.Sub(unusedFee)[0], moduleAccBal) + }, + }, + { + "success: many fees with multiple denoms in escrow", + func() { + fee1 := types.NewFee(defaultRecvFee, defaultAckFee, defaultTimeoutFee) + packetFee1 := types.NewPacketFee(fee1, refundAcc.String(), []string(nil)) + + // mint some tokens to the refund account + denom2 := "denom" + err := suite.chainA.GetSimApp().MintKeeper.MintCoins(suite.chainA.GetContext(), sdk.NewCoins(sdk.NewCoin(denom2, sdkmath.NewInt(1000)))) + suite.Require().NoError(err) + err = suite.chainA.GetSimApp().BankKeeper.SendCoinsFromModuleToAccount(suite.chainA.GetContext(), minttypes.ModuleName, refundAcc, sdk.NewCoins(sdk.NewCoin(denom2, sdkmath.NewInt(1000)))) + suite.Require().NoError(err) + + defaultFee2 := sdk.NewCoins(sdk.NewCoin(denom2, sdkmath.NewInt(100))) + fee2 := types.NewFee(defaultFee2, defaultFee2, defaultFee2) + packetFee2 := types.NewPacketFee(fee2, refundAcc.String(), []string(nil)) + + packetFees = []types.PacketFee{packetFee1, packetFee2, packetFee1} + }, + func(err error) { + denom2 := "denom" + + suite.Require().NoError(err) + + // ensure that the packet fees are unmodified + expPacketFees := []types.IdentifiedPacketFees{ + types.NewIdentifiedPacketFees(packetID, packetFees), + } + suite.Require().Equal(expPacketFees, suite.chainA.GetSimApp().IBCFeeKeeper.GetAllIdentifiedPacketFees(suite.chainA.GetContext())) + + unusedFee := sdk.NewCoins( + sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(600)), + sdk.NewCoin(denom2, sdkmath.NewInt(100)), + ) + // refund account balance should increase + refundAccBal := suite.chainA.GetSimApp().BankKeeper.GetAllBalances(suite.chainA.GetContext(), refundAcc) + suite.Require().Equal(initRefundAccBal.Add(unusedFee...), refundAccBal) + + // module account balance should decrease + moduleAccBal := suite.chainA.GetSimApp().BankKeeper.GetAllBalances(suite.chainA.GetContext(), moduleAcc) + suite.Require().Equal(initModuleAccBal.Sub(unusedFee...).Sort(), moduleAccBal) + }, + }, + { + "failure: invalid refund address", + func() { + fee := types.NewFee(defaultRecvFee, defaultAckFee, defaultTimeoutFee) + packetFee := types.NewPacketFee(fee, "invalid", []string{}) + packetFees = []types.PacketFee{packetFee} + }, + func(err error) { + suite.Require().Error(err) + }, + }, + } + + for _, tc := range testCases { + tc := tc + + suite.SetupTest() + suite.coordinator.Setup(suite.path) + + refundAcc = suite.chainA.SenderAccount.GetAddress() + moduleAcc = suite.chainA.GetSimApp().AccountKeeper.GetModuleAddress(types.ModuleName) + packetID = channeltypes.NewPacketID(suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, 1) + packetFees = nil + + tc.malleate() + + feesToModule := sdk.NewCoins() + for _, packetFee := range packetFees { + feesToModule = feesToModule.Add(keeper.LegacyTotal(packetFee.Fee)...) + } + + if !feesToModule.IsZero() { + // escrow the packet fees & store the fees in state + suite.chainA.GetSimApp().IBCFeeKeeper.SetFeesInEscrow(suite.chainA.GetContext(), packetID, types.NewPacketFees(packetFees)) + err := suite.chainA.GetSimApp().BankKeeper.SendCoinsFromAccountToModule(suite.chainA.GetContext(), refundAcc, types.ModuleName, feesToModule) + suite.Require().NoError(err) + } + + initRefundAccBal = suite.chainA.GetSimApp().BankKeeper.GetAllBalances(suite.chainA.GetContext(), refundAcc) + initModuleAccBal = suite.chainA.GetSimApp().BankKeeper.GetAllBalances(suite.chainA.GetContext(), moduleAcc) + + migrator := keeper.NewMigrator(suite.chainA.GetSimApp().IBCFeeKeeper) + err := migrator.Migrate1to2(suite.chainA.GetContext()) + + tc.assert(err) + } +} diff --git a/modules/apps/29-fee/module.go b/modules/apps/29-fee/module.go index c627d82ec61..327522f5339 100644 --- a/modules/apps/29-fee/module.go +++ b/modules/apps/29-fee/module.go @@ -108,6 +108,11 @@ func NewAppModule(k keeper.Keeper) AppModule { func (am AppModule) RegisterServices(cfg module.Configurator) { types.RegisterMsgServer(cfg.MsgServer(), am.keeper) types.RegisterQueryServer(cfg.QueryServer(), am.keeper) + + m := keeper.NewMigrator(am.keeper) + if err := cfg.RegisterMigration(types.ModuleName, 1, m.Migrate1to2); err != nil { + panic(fmt.Errorf("failed to migrate ibc-fee module from version 1 to 2 (refund leftover fees): %v", err)) + } } // InitGenesis performs genesis initialization for the ibc-29-fee module. It returns @@ -126,7 +131,7 @@ func (am AppModule) ExportGenesis(ctx sdk.Context, cdc codec.JSONCodec) json.Raw } // ConsensusVersion implements AppModule/ConsensusVersion. -func (AppModule) ConsensusVersion() uint64 { return 1 } +func (AppModule) ConsensusVersion() uint64 { return 2 } // AppModuleSimulation functions diff --git a/modules/apps/29-fee/transfer_test.go b/modules/apps/29-fee/transfer_test.go index d283cdc1713..d5163f68657 100644 --- a/modules/apps/29-fee/transfer_test.go +++ b/modules/apps/29-fee/transfer_test.go @@ -65,7 +65,7 @@ func (suite *FeeTestSuite) TestFeeTransfer() { ) suite.Require().Equal( - fee.AckFee.Add(fee.TimeoutFee...), // ack fee paid, timeout fee refunded + fee.AckFee, // ack fee paid, no refund needed since timeout_fee = recv_fee + ack_fee sdk.NewCoins(suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), ibctesting.TestCoin.Denom)).Sub(originalChainASenderAccountBalance[0])) } diff --git a/modules/apps/29-fee/types/fee.go b/modules/apps/29-fee/types/fee.go index a34715db08e..6d8521bc1e7 100644 --- a/modules/apps/29-fee/types/fee.go +++ b/modules/apps/29-fee/types/fee.go @@ -59,9 +59,12 @@ func NewFee(recvFee, ackFee, timeoutFee sdk.Coins) Fee { } } -// Total returns the total amount for a given Fee +// Total returns the total amount for a given Fee. +// The total amount is the Max(RecvFee + AckFee, TimeoutFee), +// This is because either the packet is received and acknowledged or it timeouts func (f Fee) Total() sdk.Coins { - return f.RecvFee.Add(f.AckFee...).Add(f.TimeoutFee...) + // maximum returns the denomwise maximum of two sets of coins + return f.RecvFee.Add(f.AckFee...).Max(f.TimeoutFee) } // Validate asserts that each Fee is valid and all three Fees are not empty or zero diff --git a/modules/apps/29-fee/types/fee_test.go b/modules/apps/29-fee/types/fee_test.go index fe4e2b24d95..ab44e645d48 100644 --- a/modules/apps/29-fee/types/fee_test.go +++ b/modules/apps/29-fee/types/fee_test.go @@ -34,10 +34,95 @@ var ( const invalidAddress = "invalid-address" func TestFeeTotal(t *testing.T) { - fee := types.NewFee(defaultRecvFee, defaultAckFee, defaultTimeoutFee) + var fee types.Fee - total := fee.Total() - require.Equal(t, sdkmath.NewInt(600), total.AmountOf(sdk.DefaultBondDenom)) + testCases := []struct { + name string + malleate func() + expTotal sdk.Coins + }{ + { + "success", + func() {}, + sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(300))), + }, + { + "success: empty fees", + func() { + fee = types.NewFee(sdk.NewCoins(), sdk.NewCoins(), sdk.NewCoins()) + }, + sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(0))), + }, + { + "success: multiple denoms", + func() { + fee = types.NewFee( + sdk.NewCoins( + defaultRecvFee[0], + sdk.NewCoin("denom", sdkmath.NewInt(300)), + ), + sdk.NewCoins( + defaultAckFee[0], + sdk.NewCoin("denom", sdkmath.NewInt(200)), + ), + sdk.NewCoins( + defaultTimeoutFee[0], + sdk.NewCoin("denom", sdkmath.NewInt(100)), + ), + ) + }, + sdk.NewCoins( + sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(300)), + sdk.NewCoin("denom", sdkmath.NewInt(500)), + ), + }, + { + "success: many denoms", + func() { + fee = types.NewFee( + sdk.NewCoins( + defaultRecvFee[0], + sdk.NewCoin("denom", sdkmath.NewInt(200)), + sdk.NewCoin("denom4", sdkmath.NewInt(100)), + sdk.NewCoin("denom5", sdkmath.NewInt(300)), + ), + sdk.NewCoins( + defaultAckFee[0], + sdk.NewCoin("denom", sdkmath.NewInt(200)), + sdk.NewCoin("denom2", sdkmath.NewInt(100)), + sdk.NewCoin("denom3", sdkmath.NewInt(300)), + sdk.NewCoin("denom4", sdkmath.NewInt(100)), + ), + sdk.NewCoins( + defaultTimeoutFee[0], + sdk.NewCoin("denom", sdkmath.NewInt(100)), + sdk.NewCoin("denom2", sdkmath.NewInt(200)), + sdk.NewCoin("denom5", sdkmath.NewInt(300)), + ), + ) + }, + sdk.NewCoins( + sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(300)), + sdk.NewCoin("denom", sdkmath.NewInt(400)), + sdk.NewCoin("denom2", sdkmath.NewInt(200)), + sdk.NewCoin("denom3", sdkmath.NewInt(300)), + sdk.NewCoin("denom4", sdkmath.NewInt(200)), + sdk.NewCoin("denom5", sdkmath.NewInt(300)), + ), + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + fee = types.NewFee(defaultRecvFee, defaultAckFee, defaultTimeoutFee) + + tc.malleate() // malleate mutates test data + + require.Equal(t, tc.expTotal, fee.Total()) + }) + } } func TestPacketFeeValidation(t *testing.T) { diff --git a/modules/apps/callbacks/callbacks_test.go b/modules/apps/callbacks/callbacks_test.go index 9796913cef6..7f5d5185aa8 100644 --- a/modules/apps/callbacks/callbacks_test.go +++ b/modules/apps/callbacks/callbacks_test.go @@ -264,8 +264,9 @@ func (s *CallbacksTestSuite) AssertHasExecutedExpectedCallbackWithFee( sdk.NewCoins(GetSimApp(s.chainA).BankKeeper.GetBalance(s.chainA.GetContext(), s.chainB.SenderAccount.GetAddress(), ibctesting.TestCoin.Denom)), ) + refundCoins := fee.Total().Sub(fee.RecvFee...).Sub(fee.AckFee...) s.Require().Equal( - fee.AckFee.Add(fee.TimeoutFee...), // ack fee paid, timeout fee refunded + fee.AckFee.Add(refundCoins...), // ack fee paid, and refund processed sdk.NewCoins( GetSimApp(s.chainA).BankKeeper.GetBalance( s.chainA.GetContext(), s.chainA.SenderAccount.GetAddress(),