Skip to content

Commit

Permalink
Finalize CPI, switch to find_program_address
Browse files Browse the repository at this point in the history
  • Loading branch information
SecretSaturn committed Jul 26, 2024
1 parent 4342402 commit a9e56a6
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,10 @@ pub enum GatewayError {
PDAAlreadyInitialized,
#[msg("Only owner can call this function!")]
NotOwner
}

#[error_code]
pub enum ProgramError {
#[msg("The signer is not the Secretpath Gateway program")]
InvalidSecretPathGateway
}
181 changes: 108 additions & 73 deletions TNLS-Gateways/solana-gateway/programs/solana-gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ use anchor_lang::{
use base64::{engine::general_purpose::STANDARD, Engine};
use hex::decode;
use sha3::{Digest, Keccak256};
use std::str::FromStr;

pub mod errors;
use crate::errors::{GatewayError, TaskError};
use crate::errors::{GatewayError, ProgramError, TaskError};

declare_id!("5LWZAN7ZFE3Rmg4MdjqNTRkSbMxthyG8ouSa3cfn3R6V");

Expand All @@ -30,9 +31,18 @@ const SEED: &[u8] = b"gateway_state";
mod solana_gateway {
use super::*;

pub fn initialize(ctx: Context<Initialize>, bump: u8) -> Result<()> {
pub fn initialize(ctx: Context<Initialize>) -> Result<()> {
let gateway_state = &mut ctx.accounts.gateway_state;

// Verify that the gateway_state is a PDA with the correct seeds and bump
let (expected_gateway_state, _bump_seed) =
Pubkey::find_program_address(&[SEED.as_ref()], ctx.program_id);

require!(
gateway_state.key() == expected_gateway_state,
GatewayError::InvalidGatewayState
);

// Check if the gateway_state has already been initialized

require!(
Expand All @@ -44,21 +54,16 @@ mod solana_gateway {
gateway_state.task_id = 0;
gateway_state.tasks = Vec::new();
gateway_state.max_tasks = 10;
gateway_state.bump = bump;

Ok(())
}

pub fn increase_task_id(
ctx: Context<IncreaseTaskId>,
new_task_id: u64,
bump: u8,
) -> Result<()> {
pub fn increase_task_id(ctx: Context<IncreaseTaskId>, new_task_id: u64) -> Result<()> {
let gateway_state = &mut ctx.accounts.gateway_state;

// Verify that the gateway_state is a PDA with the correct seeds and bump
let expected_gateway_state =
Pubkey::create_program_address(&[SEED.as_ref(), &[bump]], ctx.program_id).unwrap();
let (expected_gateway_state, _bump_seed) =
Pubkey::find_program_address(&[SEED.as_ref()], ctx.program_id);

require!(
gateway_state.key() == expected_gateway_state,
Expand All @@ -79,13 +84,12 @@ mod solana_gateway {
user_address: Pubkey,
routing_info: String,
execution_info: ExecutionInfo,
bump: u8,
) -> Result<SendResponse> {
let gateway_state = &mut ctx.accounts.gateway_state;

// Verify that the gateway_state is a PDA with the correct seeds and bump
let expected_gateway_state =
Pubkey::create_program_address(&[SEED.as_ref(), &[bump]], ctx.program_id).unwrap();
let (expected_gateway_state, _bump_seed) =
Pubkey::find_program_address(&[SEED.as_ref()], ctx.program_id);

require!(
gateway_state.key() == expected_gateway_state,
Expand Down Expand Up @@ -182,18 +186,17 @@ mod solana_gateway {
})
}

pub fn post_execution(
ctx: Context<PostExecution>,
pub fn post_execution<'info>(
ctx: Context<'_, '_, '_, 'info, PostExecution<'info>>,
task_id: u64,
source_network: String,
post_execution_info: PostExecutionInfo,
bump: u8,
) -> Result<()> {
let gateway_state = &mut ctx.accounts.gateway_state;

// Verify that the gateway_state is a PDA with the correct seeds and bump
let expected_gateway_state =
Pubkey::create_program_address(&[SEED.as_ref(), &[bump]], ctx.program_id).unwrap();
let (expected_gateway_state, bump_seed) =
Pubkey::find_program_address(&[SEED.as_ref()], ctx.program_id);

require!(
gateway_state.key() == expected_gateway_state,
Expand Down Expand Up @@ -261,7 +264,7 @@ mod solana_gateway {

// // Verify that the recovered public key matches the expected public key
require!(
recovered_pubkey.to_bytes() == expected_pubkey.to_bytes(),
recovered_pubkey == expected_pubkey,
TaskError::InvalidPacketSignature
);

Expand All @@ -284,12 +287,12 @@ mod solana_gateway {
);

// Extract and concatenate the program ID and function identifier
let program_id_bytes = &post_execution_info.callback_selector[0..32];
let function_identifier = &post_execution_info.callback_selector[32..40];
let (program_id_bytes, function_identifier) =
post_execution_info.callback_selector.split_at(32);

// Concatenate the function identifier with the rest of the data
let mut callback_data = Vec::with_capacity(function_identifier.len() + borsh_data.len());
callback_data.extend_from_slice(function_identifier);
let mut callback_data = Vec::with_capacity(8 + borsh_data.len());
callback_data.extend_from_slice(&function_identifier[0..8]);
callback_data.extend_from_slice(&borsh_data);

// Concatenate all addresses that will be accessed
Expand All @@ -300,36 +303,59 @@ mod solana_gateway {
TaskError::InvalidCallbackAddresses
);

let mut callback_account_metas = Vec::new();
let mut callback_addresses = Vec::new();

// Add the PDA as the signer
// Modify the AccountInfo to set is_writable to false for gateway_state
let mut gateway_state_account_info = ctx.accounts.gateway_state.to_account_info();
gateway_state_account_info.is_writable = false;
gateway_state_account_info.is_signer = true;
callback_account_metas.push(AccountMeta::new_readonly(expected_gateway_state, true));
callback_addresses.push(gateway_state_account_info);

// Add the system_program account
callback_account_metas.push(AccountMeta::new_readonly(
*ctx.accounts.system_program.key,
false,
));
callback_addresses.push(ctx.accounts.system_program.to_account_info());

let mut found_addresses = std::collections::HashSet::new();

for chunk in callback_address_bytes.chunks(32) {
let pubkey = Pubkey::try_from(chunk).expect("Invalid callback pubkey");
if ctx.remaining_accounts.iter().find(|account| account.key == &pubkey).is_none() {
return Err(TaskError::MissingRequiredSignature.into());
match Pubkey::try_from(chunk) {
Ok(pubkey) => {
if pubkey == expected_gateway_state {
continue;
}

if ctx
.remaining_accounts
.iter()
.any(|account| account.key == &pubkey)
{
if found_addresses.insert(pubkey) {
if let Some(account) = ctx
.remaining_accounts
.iter()
.find(|account| account.key == &pubkey)
{
callback_account_metas
.push(AccountMeta::new(*account.key, account.is_signer));
callback_addresses.push(account.clone());
} else {
return Err(TaskError::InvalidCallbackAddresses.into());
}
}
} else {
return Err(TaskError::MissingRequiredSignature.into());
}
}
Err(_) => return Err(TaskError::InvalidCallbackAddresses.into()),
}
}

// Map callback_address_bytes to AccountInfo
let callback_addresses: Vec<AccountInfo> = callback_address_bytes
.chunks(32)
.map(|address| {
let pubkey = Pubkey::try_from(address).expect("Invalid callback pubkey");
ctx.remaining_accounts
.iter()
.find(|account| account.key == &pubkey)
.expect("Callback account not found")
.clone()
})
.collect();

let system_program = ctx.accounts.system_program.to_account_info();

// Collect the callback addresses into a vector
let mut callback_account_metas: Vec<AccountMeta> = callback_addresses.iter()
.map(|account| AccountMeta::new(*account.key, account.is_signer))
.collect();

// Add the system_program account to the vector
callback_account_metas.push(AccountMeta::new_readonly(*system_program.key, false));

// Execute the callback with signed seeds
let callback_result = invoke_signed(
&Instruction {
Expand All @@ -338,33 +364,48 @@ mod solana_gateway {
data: callback_data,
},
&callback_addresses,
&[&[SEED.as_ref(), &[bump]]],
&[&[SEED.as_ref(), &[bump_seed]]],
);

let task_completed = TaskCompleted {
// Emit Message that the task was completed and if it returned Ok
msg!(
"TaskCompleted: task_id: {} and callback_result: {}",
task_id,
callback_successful: callback_result.is_ok(),
};

msg!(&format!(
"TaskCompleted:{}",
STANDARD.encode(&task_completed.try_to_vec().unwrap())
));
callback_result.is_ok()
);

Ok(())
}

pub fn callback_test(ctx: Context<CallbackTest>, task_id: u64, result: String) -> Result<()> {
msg!("Callback invoked with task_id: {} and result: {}", task_id, result);
pub fn callback_test(ctx: Context<CallbackTest>, task_id: u64, result: Vec<u8>) -> Result<()> {
// Check if the callback is really coming from the secretpath_gateway and that it was signed by it
const SECRET_PATH_ADDRESS: &str = "5mf563g8JSeTE1mMY4GSqbynjayToKSh7x5WLoQ9RDEQ";
let secretpath_address_pubkey =
Pubkey::from_str(SECRET_PATH_ADDRESS).expect("Invalid public key format");

// Inline check for signature and address
if !ctx.accounts.secretpath_gateway.is_signer
|| ctx.accounts.secretpath_gateway.key() != secretpath_address_pubkey
{
msg!("Callback failed: Invalid signer or public key mismatch");
return Err(ProgramError::InvalidSecretPathGateway.into());
}

// Convert result to base64 string for test purposes
msg!(
"Callback invoked with task_id: {} and result: {}",
task_id,
STANDARD.encode(&result)
);

Ok(())
}
}

#[derive(Accounts)]
pub struct CallbackTest<'info> {
#[account(mut)]
#[account()]
pub secretpath_gateway: Signer<'info>,
pub system_program: Program<'info, System>,
}

#[derive(Accounts)]
Expand All @@ -377,7 +418,7 @@ pub struct Initialize<'info> {
bump
)]
pub gateway_state: Account<'info, GatewayState>,
#[account(mut)]
#[account(mut, signer)]
pub owner: Signer<'info>,
pub system_program: Program<'info, System>,
}
Expand All @@ -386,14 +427,15 @@ pub struct Initialize<'info> {
pub struct IncreaseTaskId<'info> {
#[account(mut, seeds = [SEED], bump)]
pub gateway_state: Account<'info, GatewayState>,
#[account(mut, signer)]
pub owner: Signer<'info>,
}

#[derive(Accounts)]
pub struct Send<'info> {
#[account(mut, seeds = [SEED], bump)]
pub gateway_state: Account<'info, GatewayState>,
#[account(mut)]
#[account(mut, signer)]
pub user: Signer<'info>,
pub system_program: Program<'info, System>,
}
Expand All @@ -402,8 +444,8 @@ pub struct Send<'info> {
pub struct PostExecution<'info> {
#[account(mut, seeds = [SEED], bump)]
pub gateway_state: Account<'info, GatewayState>,
#[account(mut)]
pub user: Signer<'info>,
#[account(mut, signer)]
pub signer: Signer<'info>,
pub system_program: Program<'info, System>,
}

Expand All @@ -413,7 +455,6 @@ pub struct GatewayState {
pub task_id: u64,
pub tasks: Vec<Task>,
pub max_tasks: u64,
pub bump: u8,
}

#[derive(Debug, Clone, AnchorSerialize, AnchorDeserialize)]
Expand Down Expand Up @@ -457,12 +498,6 @@ pub struct SendResponse {
pub request_id: u64,
}

#[event]
pub struct TaskCompleted {
pub task_id: u64,
pub callback_successful: bool,
}

#[derive(Debug, Clone, AnchorSerialize, AnchorDeserialize)]
pub struct LogNewTask {
pub task_id: u64,
Expand Down
Loading

0 comments on commit a9e56a6

Please sign in to comment.