From d1ff3a236720e50c42a20ae15ac53f0ea7d2d840 Mon Sep 17 00:00:00 2001 From: Kyrylo Stepanov Date: Thu, 10 Oct 2024 14:33:01 +0300 Subject: [PATCH] Forbid direct invocations --- programs/rewards/Cargo.toml | 1 + programs/rewards/src/entrypoint.rs | 8 ++++++++ programs/rewards/src/error.rs | 4 ++++ 3 files changed, 13 insertions(+) diff --git a/programs/rewards/Cargo.toml b/programs/rewards/Cargo.toml index 4d6ce24e..2b07fb02 100644 --- a/programs/rewards/Cargo.toml +++ b/programs/rewards/Cargo.toml @@ -27,3 +27,4 @@ path = "tests/rewards/tests.rs" [features] no-entrypoint = [] +testing = [] diff --git a/programs/rewards/src/entrypoint.rs b/programs/rewards/src/entrypoint.rs index 78a11584..06bcf3a9 100644 --- a/programs/rewards/src/entrypoint.rs +++ b/programs/rewards/src/entrypoint.rs @@ -1,17 +1,25 @@ //! Program entrypoint use crate::{error::MplxRewardsError, instructions::process_instruction}; +use solana_program::instruction::get_stack_height; use solana_program::{ account_info::AccountInfo, entrypoint, entrypoint::ProgramResult, program_error::PrintProgramError, pubkey::Pubkey, }; entrypoint!(program_entrypoint); + +pub const TRANSACTION_LEVEL_STACK_HEIGHT: usize = 1; + fn program_entrypoint<'a>( program_id: &Pubkey, accounts: &'a [AccountInfo<'a>], instruction_data: &[u8], ) -> ProgramResult { if let Err(error) = process_instruction(program_id, accounts, instruction_data) { + #[cfg(not(feature = "testing"))] + if get_stack_height() == TRANSACTION_LEVEL_STACK_HEIGHT { + return Err(MplxRewardsError::ForbiddenInvocation.into()); + } // Catch the error so we can print it error.print::(); return Err(error); diff --git a/programs/rewards/src/error.rs b/programs/rewards/src/error.rs index 07f67d03..598ceea8 100644 --- a/programs/rewards/src/error.rs +++ b/programs/rewards/src/error.rs @@ -84,6 +84,10 @@ pub enum MplxRewardsError { /// 19 #[error("Account addres derivation has failed")] AccountDerivationAddresFailed, + + /// 20 + #[error("This contract is supposed to be called only from the staking contract")] + ForbiddenInvocation, } impl PrintProgramError for MplxRewardsError {