diff --git a/programs/rewards/Cargo.toml b/programs/rewards/Cargo.toml index 4d6ce24..2b07fb0 100644 --- a/programs/rewards/Cargo.toml +++ b/programs/rewards/Cargo.toml @@ -27,3 +27,4 @@ path = "tests/rewards/tests.rs" [features] no-entrypoint = [] +testing = [] diff --git a/programs/rewards/src/entrypoint.rs b/programs/rewards/src/entrypoint.rs index 78a1158..06bcf3a 100644 --- a/programs/rewards/src/entrypoint.rs +++ b/programs/rewards/src/entrypoint.rs @@ -1,17 +1,25 @@ //! Program entrypoint use crate::{error::MplxRewardsError, instructions::process_instruction}; +use solana_program::instruction::get_stack_height; use solana_program::{ account_info::AccountInfo, entrypoint, entrypoint::ProgramResult, program_error::PrintProgramError, pubkey::Pubkey, }; entrypoint!(program_entrypoint); + +pub const TRANSACTION_LEVEL_STACK_HEIGHT: usize = 1; + fn program_entrypoint<'a>( program_id: &Pubkey, accounts: &'a [AccountInfo<'a>], instruction_data: &[u8], ) -> ProgramResult { if let Err(error) = process_instruction(program_id, accounts, instruction_data) { + #[cfg(not(feature = "testing"))] + if get_stack_height() == TRANSACTION_LEVEL_STACK_HEIGHT { + return Err(MplxRewardsError::ForbiddenInvocation.into()); + } // Catch the error so we can print it error.print::(); return Err(error); diff --git a/programs/rewards/src/error.rs b/programs/rewards/src/error.rs index 07f67d0..598ceea 100644 --- a/programs/rewards/src/error.rs +++ b/programs/rewards/src/error.rs @@ -84,6 +84,10 @@ pub enum MplxRewardsError { /// 19 #[error("Account addres derivation has failed")] AccountDerivationAddresFailed, + + /// 20 + #[error("This contract is supposed to be called only from the staking contract")] + ForbiddenInvocation, } impl PrintProgramError for MplxRewardsError {