Skip to content

Commit

Permalink
refactor(solana-contracts): add Fees trait to unify fee handling for …
Browse files Browse the repository at this point in the history
…propeller-variant ixs
  • Loading branch information
swimricky committed Oct 5, 2022
1 parent f702f2b commit a230619
Show file tree
Hide file tree
Showing 18 changed files with 503 additions and 361 deletions.
12 changes: 12 additions & 0 deletions packages/solana-contracts/programs/propeller/src/fees.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use {crate::TOKEN_COUNT, anchor_lang::prelude::*, two_pool::BorshDecimal};

pub trait Fees {
fn calculate_fees_in_lamports(&self) -> Result<u64>;
fn convert_fees_to_swim_usd_atomic(
&self,
fee_in_lamports: u64,
marginal_prices: [BorshDecimal; TOKEN_COUNT],
max_staleness: i64,
) -> Result<u64>;
fn track_and_transfer_fees(&mut self, fees_in_swim_usd: u64) -> Result<()>;
}

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ use {
std::convert::TryInto,
two_pool::state::TwoPool,
};
use {
crate::{convert_fees_to_swim_usd_atomic_2, get_lamports_intermediate_token_price, Fees},
two_pool::BorshDecimal,
};

pub const SWIM_USD_TO_TOKEN_NUMBER: u16 = 0;

Expand Down Expand Up @@ -473,7 +477,27 @@ impl<'info> PropellerProcessSwimPayload<'info> {
/// Calculates, transfer and tracks fees
/// returns fees_in_token_bridge_mint
fn handle_fees(&mut self) -> Result<u64> {
let fees_in_swim_usd_atomic = self.calculate_fees()?;
let fees_in_lamports = self.calculate_fees_in_lamports()?;
let marginal_prices = two_pool::cpi::marginal_prices(CpiContext::new(
self.process_swim_payload.two_pool_program.to_account_info(),
two_pool::cpi::accounts::MarginalPrices {
pool: self.marginal_price_pool.to_account_info(),
pool_token_account_0: self.marginal_price_pool_token_0_account.to_account_info(),
pool_token_account_1: self.marginal_price_pool_token_1_account.to_account_info(),
lp_mint: self.marginal_price_pool_lp_mint.to_account_info(),
},
))?;
let fees_in_swim_usd_atomic = convert_fees_to_swim_usd_atomic_2(
fees_in_lamports,
&self.process_swim_payload.propeller,
&self.marginal_price_pool_lp_mint,
marginal_prices.get(),
&self.marginal_price_pool,
&self.aggregator,
i64::MAX,
)?;
// let fees_in_swim_usd_atomic =
// let fees_in_swim_usd_atomic = self.calculate_fees()?;
let propeller = &self.process_swim_payload.propeller;
let token_program = &self.process_swim_payload.token_program;
msg!("fees_in_swim_usd_atomic: {:?}", fees_in_swim_usd_atomic);
Expand All @@ -496,7 +520,7 @@ impl<'info> PropellerProcessSwimPayload<'info> {
Ok(fees_in_swim_usd_atomic)
}

fn calculate_fees(&self) -> Result<u64> {
fn calculate_fees_in_lamports(&self) -> Result<u64> {
//TODO: this is in lamports/SOL. need in swimUSD.
// for (secp + verify) & postVAA, need to implement a fee tracking mechanism since there's no way to
// credit the payer during that step. must be some type of "deferred" fees
Expand All @@ -506,12 +530,6 @@ impl<'info> PropellerProcessSwimPayload<'info> {
let swim_payload_message = &self.process_swim_payload.swim_payload_message;
let propeller_process_swim_payload_fees = propeller.process_swim_payload_fee;

let two_pool_program = &self.process_swim_payload.two_pool_program;
let marginal_price_pool = &self.marginal_price_pool;
let marginal_price_pool_token_0_account = &self.marginal_price_pool_token_0_account;
let marginal_price_pool_token_1_account = &self.marginal_price_pool_token_1_account;
let marginal_price_pool_lp_mint = &self.marginal_price_pool_lp_mint;

let swim_claim_rent_exempt_fees = rent.minimum_balance(8 + SwimClaim::LEN);
let gas_kickstart_amount = if swim_payload_message.gas_kickstart { propeller.gas_kickstart_amount } else { 0 };
let fee_in_lamports = swim_claim_rent_exempt_fees
Expand All @@ -531,23 +549,28 @@ impl<'info> PropellerProcessSwimPayload<'info> {
gas_kickstart_amount,
fee_in_lamports
);
Ok(fee_in_lamports)
}

let cpi_ctx = CpiContext::new(
two_pool_program.to_account_info(),
fn calculate_fees(&self) -> Result<u64> {
let fee_in_lamports = self.calculate_fees_in_lamports()?;

let marginal_prices = two_pool::cpi::marginal_prices(CpiContext::new(
self.process_swim_payload.two_pool_program.to_account_info(),
two_pool::cpi::accounts::MarginalPrices {
pool: marginal_price_pool.to_account_info(),
pool_token_account_0: marginal_price_pool_token_0_account.to_account_info(),
pool_token_account_1: marginal_price_pool_token_1_account.to_account_info(),
lp_mint: marginal_price_pool_lp_mint.to_account_info(),
pool: self.marginal_price_pool.to_account_info(),
pool_token_account_0: self.marginal_price_pool_token_0_account.to_account_info(),
pool_token_account_1: self.marginal_price_pool_token_1_account.to_account_info(),
lp_mint: self.marginal_price_pool_lp_mint.to_account_info(),
},
);
let fees_in_swim_usd_atomic = convert_fees_to_swim_usd_atomic(
))?;

let fees_in_swim_usd_atomic = convert_fees_to_swim_usd_atomic_2(
fee_in_lamports,
&propeller,
&marginal_price_pool_lp_mint,
// ctx.accounts.into_marginal_prices(),
cpi_ctx,
&marginal_price_pool,
&self.process_swim_payload.propeller,
&self.marginal_price_pool_lp_mint,
marginal_prices.get(),
&self.marginal_price_pool,
&self.aggregator,
i64::MAX,
)?;
Expand Down Expand Up @@ -802,17 +825,14 @@ pub struct PropellerProcessSwimPayloadFallback<'info> {
impl<'info> PropellerProcessSwimPayloadFallback<'info> {
pub fn accounts(ctx: &Context<PropellerProcessSwimPayloadFallback>) -> Result<()> {
require_keys_eq!(ctx.accounts.owner.key(), ctx.accounts.swim_payload_message.owner);
let (expected_token_id_map_address, _bump) = Pubkey::find_program_address(
validate_marginal_prices_pool_accounts(
&ctx.accounts.propeller,
&ctx.accounts.marginal_price_pool.key(),
&[
b"propeller".as_ref(),
b"token_id".as_ref(),
ctx.accounts.propeller.key().as_ref(),
ctx.accounts.swim_payload_message.target_token_id.to_le_bytes().as_ref(),
ctx.accounts.marginal_price_pool_token_0_account.mint,
ctx.accounts.marginal_price_pool_token_1_account.mint,
],
ctx.program_id,
);
//Note: the address should at least be valid even though it doesn't exist.
require_keys_eq!(expected_token_id_map_address, ctx.accounts.token_id_map.key());
)?;
msg!("Passed PropellerProcessSwimPayloadFallback::accounts() check");
Ok(())
}
Expand Down Expand Up @@ -1022,6 +1042,116 @@ impl<'info> PropellerProcessSwimPayloadFallback<'info> {
}
}

/*
impl Fees for PropellerProcessSwimPayloadFallback {
fn calculate_fees_in_lamports(&self) -> Result<u64> {
let rent = Rent::get()?;
let propeller = &self.propeller;
let swim_payload_message = &self.swim_payload_message;
let propeller_process_swim_payload_fees = propeller.process_swim_payload_fee;
let two_pool_program = &self.two_pool_program;
let marginal_price_pool = &self.marginal_price_pool;
let marginal_price_pool_token_0_account = &self.marginal_price_pool_token_0_account;
let marginal_price_pool_token_1_account = &self.marginal_price_pool_token_1_account;
let marginal_price_pool_lp_mint = &self.marginal_price_pool_lp_mint;
let swim_claim_rent_exempt_fees = rent.minimum_balance(8 + SwimClaim::LEN);
let gas_kickstart_amount = if swim_payload_message.gas_kickstart { propeller.gas_kickstart_amount } else { 0 };
let fee_in_lamports = swim_claim_rent_exempt_fees
.checked_add(propeller_process_swim_payload_fees)
.and_then(|x| x.checked_add(gas_kickstart_amount))
.ok_or(PropellerError::IntegerOverflow)?;
msg!(
"
{}(swim_claim_rent_exempt_fees) +
{}(propeller_process_swim_payload_fees) +
{}(gas_kickstart_amount)
= {}(fee_in_lamports)
",
swim_claim_rent_exempt_fees,
propeller_process_swim_payload_fees,
gas_kickstart_amount,
fee_in_lamports
);
Ok(fee_in_lamports)
}
fn convert_fees_to_swim_usd_atomic(
&self,
fee_in_lamports: u64,
marginal_prices: [BorshDecimal; TOKEN_COUNT],
max_staleness: i64,
) -> Result<u64> {
msg!("fee_in_lamports: {:?}", fee_in_lamports);
let marginal_price_pool_lp_mint = &self.marginal_price_pool_lp_mint;
let swim_usd_mint_key = self.propeller.swim_usd_mint;
// let marginal_prices = get_marginal_prices(cpi_ctx)?;
let intermediate_token_price_decimal: Decimal = get_marginal_price_decimal(
&self.marginal_price_pool,
&marginal_prices,
&self.propeller,
&marginal_price_pool_lp_mint.key(),
)?;
msg!("intermediate_token_price_decimal: {:?}", intermediate_token_price_decimal);
let fee_in_lamports_decimal = Decimal::from_u64(fee_in_lamports).ok_or(PropellerError::ConversionError)?;
msg!("fee_in_lamports(u64): {:?} fee_in_lamports_decimal: {:?}", fee_in_lamports, fee_in_lamports_decimal);
let mut res = 0u64;
let lamports_intermediate_token_price = get_lamports_intermediate_token_price(&aggregator, max_staleness)?;
let fee_in_swim_usd_decimal = lamports_intermediate_token_price
.checked_mul(fee_in_lamports_decimal)
.and_then(|x| x.checked_div(intermediate_token_price_decimal))
.ok_or(PropellerError::IntegerOverflow)?;
let swim_usd_decimals =
get_swim_usd_mint_decimals(&swim_usd_mint_key, &marginal_price_pool, &marginal_price_pool_lp_mint)?;
msg!("swim_usd_decimals: {:?}", swim_usd_decimals);
let ten_pow_decimals =
Decimal::from_u64(10u64.pow(swim_usd_decimals as u32)).ok_or(PropellerError::IntegerOverflow)?;
let fee_in_swim_usd_atomic = fee_in_swim_usd_decimal
.checked_mul(ten_pow_decimals)
.and_then(|v| v.to_u64())
.ok_or(PropellerError::ConversionError)?;
msg!(
"fee_in_swim_usd_decimal: {:?} fee_in_swim_usd_atomic: {:?}",
fee_in_swim_usd_decimal,
fee_in_swim_usd_atomic
);
res = fee_in_swim_usd_atomic;
Ok(res)
}
fn track_and_transfer_fees(&mut self, fees_in_swim_usd: u64) -> Result<()> {
let fee_tracker = &mut self.fee_tracker;
fee_tracker.fees_owed =
fee_tracker.fees_owed.checked_add(fees_in_swim_usd).ok_or(PropellerError::IntegerOverflow)?;
token::transfer(
CpiContext::new_with_signer(
self.token_program.to_account_info(),
Transfer {
from: self.redeemer_escrow.to_account_info(),
to: self.fee_vault.to_account_info(),
authority: self.redeemer.to_account_info(),
},
&[&[&b"redeemer".as_ref(), &[self.propeller.redeemer_bump]]],
),
fees_in_swim_usd,
)
}
}
*/

pub fn handle_propeller_process_swim_payload_fallback(
ctx: Context<PropellerProcessSwimPayloadFallback>,
) -> Result<u64> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ pub struct TokenIdMap {

impl TokenIdMap {
pub const LEN: usize = 2 + 32 + 1 + 32 + 1 + 1 + 1;

pub fn assert_is_invalid(token_id_map: &AccountInfo) -> Result<()> {
if let Ok(_) = TokenIdMap::try_deserialize(&mut &**token_id_map.try_borrow_mut_data()?) {
return err!(PropellerError::TokenIdMapExists);
}
Ok(())
}
}

#[derive(AnchorSerialize, AnchorDeserialize, Copy, Clone, Debug)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub fn convert_fees_to_swim_usd_atomic<'info>(
fee_in_lamports: u64,
propeller: &Propeller,
marginal_price_pool_lp_mint: &Account<'info, Mint>,
//TODO: just take marginal_prices as input here.
cpi_ctx: CpiContext<'_, '_, '_, 'info, two_pool::cpi::accounts::MarginalPrices<'info>>,
marginal_price_pool: &TwoPool,
aggregator: &AccountLoader<AggregatorAccountData>,
Expand Down Expand Up @@ -68,6 +69,59 @@ pub fn convert_fees_to_swim_usd_atomic<'info>(
Ok(res)
}

pub fn convert_fees_to_swim_usd_atomic_2<'info>(
fee_in_lamports: u64,
propeller: &Propeller,
marginal_price_pool_lp_mint: &Account<'info, Mint>,
marginal_prices: [BorshDecimal; TOKEN_COUNT],
marginal_price_pool: &TwoPool,
aggregator: &AccountLoader<AggregatorAccountData>,
max_staleness: i64,
) -> Result<u64> {
// let propeller = &self.propeller;

msg!("fee_in_lamports: {:?}", fee_in_lamports);
let marginal_price_pool_lp_mint = &marginal_price_pool_lp_mint;

let swim_usd_mint_key = propeller.swim_usd_mint;
// let marginal_prices = get_marginal_prices(cpi_ctx)?;

let intermediate_token_price_decimal: Decimal = get_marginal_price_decimal(
&marginal_price_pool,
&marginal_prices,
&propeller,
&marginal_price_pool_lp_mint.key(),
)?;

msg!("intermediate_token_price_decimal: {:?}", intermediate_token_price_decimal);

let fee_in_lamports_decimal = Decimal::from_u64(fee_in_lamports).ok_or(PropellerError::ConversionError)?;
msg!("fee_in_lamports(u64): {:?} fee_in_lamports_decimal: {:?}", fee_in_lamports, fee_in_lamports_decimal);

let mut res = 0u64;

let lamports_intermediate_token_price = get_lamports_intermediate_token_price(&aggregator, max_staleness)?;
let fee_in_swim_usd_decimal = lamports_intermediate_token_price
.checked_mul(fee_in_lamports_decimal)
.and_then(|x| x.checked_div(intermediate_token_price_decimal))
.ok_or(PropellerError::IntegerOverflow)?;

let swim_usd_decimals =
get_swim_usd_mint_decimals(&swim_usd_mint_key, &marginal_price_pool, &marginal_price_pool_lp_mint)?;
msg!("swim_usd_decimals: {:?}", swim_usd_decimals);

let ten_pow_decimals =
Decimal::from_u64(10u64.pow(swim_usd_decimals as u32)).ok_or(PropellerError::IntegerOverflow)?;
let fee_in_swim_usd_atomic = fee_in_swim_usd_decimal
.checked_mul(ten_pow_decimals)
.and_then(|v| v.to_u64())
.ok_or(PropellerError::ConversionError)?;

msg!("fee_in_swim_usd_decimal: {:?} fee_in_swim_usd_atomic: {:?}", fee_in_swim_usd_decimal, fee_in_swim_usd_atomic);
res = fee_in_swim_usd_atomic;
Ok(res)
}

pub fn get_swim_usd_mint_decimals(
swim_usd_mint: &Pubkey,
marginal_price_pool: &TwoPool,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use {
},
byteorder::{BigEndian, ReadBytesExt, WriteBytesExt},
num_traits::{FromPrimitive, ToPrimitive},
primitive_types::U256,
// primitive_types::U256,
rust_decimal::Decimal,
solana_program::program::invoke,
switchboard_v2::{AggregatorAccountData, SwitchboardDecimal, SWITCHBOARD_PROGRAM_ID},
Expand Down Expand Up @@ -306,8 +306,6 @@ pub fn handle_complete_native_with_payload(ctx: Context<CompleteNativeWithPayloa
let swim_payload = &transfer_with_payload.payload;
msg!("swim_payload: {:?}", swim_payload);

// ugly. re-doing the same calculation that WH does in `complete_transfer_payload` but
// should not be a huge issue.
let mut transfer_amount = transfer_with_payload.amount.as_u64();
if ctx.accounts.swim_usd_mint.decimals > 8 {
transfer_amount *= 10u64.pow(ctx.accounts.swim_usd_mint.decimals as u32);
Expand All @@ -316,12 +314,6 @@ pub fn handle_complete_native_with_payload(ctx: Context<CompleteNativeWithPayloa
let bump = *ctx.bumps.get("swim_payload_message").unwrap();
ctx.accounts.write_swim_payload_message(bump, &message_data, transfer_amount, swim_payload)?;

// let memo = swim_payload.memo;
// // get target_token_id -> (pool, pool_token_index)
// // need to know when to do remove_exact_burn & when to do swap_exact_input
// let memo_ix = spl_memo::build_memo(memo.as_slice(), &[]);
// invoke(&memo_ix, &[ctx.accounts.memo.to_account_info()])?;

Ok(())
}

Expand All @@ -347,7 +339,7 @@ pub struct PropellerCompleteNativeWithPayload<'info> {
*aggregator.to_account_info().owner == SWITCHBOARD_PROGRAM_ID @ PropellerError::InvalidSwitchboardAccount
)]
pub aggregator: AccountLoader<'info, AggregatorAccountData>,
// pub two_pool_program: Program<'info, two_pool::program::TwoPool>,

#[account(
mut,
seeds = [
Expand All @@ -360,8 +352,17 @@ pub struct PropellerCompleteNativeWithPayload<'info> {
seeds::program = two_pool_program.key()
)]
pub marginal_price_pool: Box<Account<'info, TwoPool>>,
#[account(
address = marginal_price_pool.token_keys[0],
)]
pub marginal_price_pool_token_0_account: Box<Account<'info, TokenAccount>>,
#[account(
address = marginal_price_pool.token_keys[1],
)]
pub marginal_price_pool_token_1_account: Box<Account<'info, TokenAccount>>,
#[account(
address = marginal_price_pool.lp_mint_key,
)]
pub marginal_price_pool_lp_mint: Box<Account<'info, Mint>>,
pub two_pool_program: Program<'info, two_pool::program::TwoPool>,
#[account(executable, address = spl_memo::id())]
Expand Down
Loading

0 comments on commit a230619

Please sign in to comment.