diff --git a/protocol/x/affiliates/keeper/keeper.go b/protocol/x/affiliates/keeper/keeper.go index f942e46418..d08d32fb8d 100644 --- a/protocol/x/affiliates/keeper/keeper.go +++ b/protocol/x/affiliates/keeper/keeper.go @@ -339,34 +339,42 @@ func (k Keeper) GetAffiliateWhitelist(ctx sdk.Context) (types.AffiliateWhitelist } return affiliateWhitelist, nil } + func (k Keeper) AggregateAffiliateReferredVolumeForFills( ctx sdk.Context, ) error { blockStats := k.statsKeeper.GetBlockStats(ctx) referredByCache := make(map[string]string) + for _, fill := range blockStats.Fills { - // Add taker's referred volume to the cache - if _, ok := referredByCache[fill.Taker]; !ok { - referredByAddrTaker, found := k.GetReferredBy(ctx, fill.Taker) - if !found { - continue + // Process taker's referred volume + referredByAddrTaker, cached := referredByCache[fill.Taker] + if !cached { + var found bool + referredByAddrTaker, found = k.GetReferredBy(ctx, fill.Taker) + if found { + referredByCache[fill.Taker] = referredByAddrTaker } - referredByCache[fill.Taker] = referredByAddrTaker } - if err := k.AddReferredVolume(ctx, referredByCache[fill.Taker], lib.BigU(fill.Notional)); err != nil { - return err + if referredByAddrTaker != "" { + if err := k.AddReferredVolume(ctx, referredByAddrTaker, lib.BigU(fill.Notional)); err != nil { + return err + } } - // Add maker's referred volume to the cache - if _, ok := referredByCache[fill.Maker]; !ok { - referredByAddrMaker, found := k.GetReferredBy(ctx, fill.Maker) - if !found { - continue + // Process maker's referred volume + referredByAddrMaker, cached := referredByCache[fill.Maker] + if !cached { + var found bool + referredByAddrMaker, found = k.GetReferredBy(ctx, fill.Maker) + if found { + referredByCache[fill.Maker] = referredByAddrMaker } - referredByCache[fill.Maker] = referredByAddrMaker } - if err := k.AddReferredVolume(ctx, referredByCache[fill.Maker], lib.BigU(fill.Notional)); err != nil { - return err + if referredByAddrMaker != "" { + if err := k.AddReferredVolume(ctx, referredByAddrMaker, lib.BigU(fill.Notional)); err != nil { + return err + } } } return nil diff --git a/protocol/x/affiliates/keeper/keeper_test.go b/protocol/x/affiliates/keeper/keeper_test.go index 96b25b83e5..b6bac29ba6 100644 --- a/protocol/x/affiliates/keeper/keeper_test.go +++ b/protocol/x/affiliates/keeper/keeper_test.go @@ -751,6 +751,29 @@ func TestAggregateAffiliateReferredVolumeForFills(t *testing.T) { }) }, }, + { + name: "2 referrals, takers not referred, maker referred", + referrals: 2, + expectedVolume: big.NewInt(300_000_000_000), + setup: func(t *testing.T, ctx sdk.Context, k *keeper.Keeper, statsKeeper *statskeeper.Keeper) { + err := k.RegisterAffiliate(ctx, maker, affiliate) + require.NoError(t, err) + statsKeeper.SetBlockStats(ctx, &statstypes.BlockStats{ + Fills: []*statstypes.BlockStats_Fill{ + { + Taker: referee1, + Maker: maker, + Notional: 100_000_000_000, + }, + { + Taker: referee2, + Maker: maker, + Notional: 200_000_000_000, + }, + }, + }) + }, + }, } for _, tc := range testCases {