From a9e56a60b6cb7894cdce392f3bcc2b1fa2e0fd6a Mon Sep 17 00:00:00 2001 From: h3rt <94856309+SecretSaturn@users.noreply.github.com> Date: Sat, 27 Jul 2024 00:09:10 +0200 Subject: [PATCH] Finalize CPI, switch to find_program_address --- .../programs/solana-gateway/src/errors.rs | 6 + .../programs/solana-gateway/src/lib.rs | 181 +++++++++++------- TNLS-Relayers/sol_interface.py | 50 +++-- 3 files changed, 147 insertions(+), 90 deletions(-) diff --git a/TNLS-Gateways/solana-gateway/programs/solana-gateway/src/errors.rs b/TNLS-Gateways/solana-gateway/programs/solana-gateway/src/errors.rs index 0017073..67db7cd 100644 --- a/TNLS-Gateways/solana-gateway/programs/solana-gateway/src/errors.rs +++ b/TNLS-Gateways/solana-gateway/programs/solana-gateway/src/errors.rs @@ -44,4 +44,10 @@ pub enum GatewayError { PDAAlreadyInitialized, #[msg("Only owner can call this function!")] NotOwner +} + +#[error_code] +pub enum ProgramError { + #[msg("The signer is not the Secretpath Gateway program")] + InvalidSecretPathGateway } \ No newline at end of file diff --git a/TNLS-Gateways/solana-gateway/programs/solana-gateway/src/lib.rs b/TNLS-Gateways/solana-gateway/programs/solana-gateway/src/lib.rs index 776214a..f9bc54e 100644 --- a/TNLS-Gateways/solana-gateway/programs/solana-gateway/src/lib.rs +++ b/TNLS-Gateways/solana-gateway/programs/solana-gateway/src/lib.rs @@ -11,9 +11,10 @@ use anchor_lang::{ use base64::{engine::general_purpose::STANDARD, Engine}; use hex::decode; use sha3::{Digest, Keccak256}; +use std::str::FromStr; pub mod errors; -use crate::errors::{GatewayError, TaskError}; +use crate::errors::{GatewayError, ProgramError, TaskError}; declare_id!("5LWZAN7ZFE3Rmg4MdjqNTRkSbMxthyG8ouSa3cfn3R6V"); @@ -30,9 +31,18 @@ const SEED: &[u8] = b"gateway_state"; mod solana_gateway { use super::*; - pub fn initialize(ctx: Context, bump: u8) -> Result<()> { + pub fn initialize(ctx: Context) -> Result<()> { let gateway_state = &mut ctx.accounts.gateway_state; + // Verify that the gateway_state is a PDA with the correct seeds and bump + let (expected_gateway_state, _bump_seed) = + Pubkey::find_program_address(&[SEED.as_ref()], ctx.program_id); + + require!( + gateway_state.key() == expected_gateway_state, + GatewayError::InvalidGatewayState + ); + // Check if the gateway_state has already been initialized require!( @@ -44,21 +54,16 @@ mod solana_gateway { gateway_state.task_id = 0; gateway_state.tasks = Vec::new(); gateway_state.max_tasks = 10; - gateway_state.bump = bump; Ok(()) } - pub fn increase_task_id( - ctx: Context, - new_task_id: u64, - bump: u8, - ) -> Result<()> { + pub fn increase_task_id(ctx: Context, new_task_id: u64) -> Result<()> { let gateway_state = &mut ctx.accounts.gateway_state; // Verify that the gateway_state is a PDA with the correct seeds and bump - let expected_gateway_state = - Pubkey::create_program_address(&[SEED.as_ref(), &[bump]], ctx.program_id).unwrap(); + let (expected_gateway_state, _bump_seed) = + Pubkey::find_program_address(&[SEED.as_ref()], ctx.program_id); require!( gateway_state.key() == expected_gateway_state, @@ -79,13 +84,12 @@ mod solana_gateway { user_address: Pubkey, routing_info: String, execution_info: ExecutionInfo, - bump: u8, ) -> Result { let gateway_state = &mut ctx.accounts.gateway_state; // Verify that the gateway_state is a PDA with the correct seeds and bump - let expected_gateway_state = - Pubkey::create_program_address(&[SEED.as_ref(), &[bump]], ctx.program_id).unwrap(); + let (expected_gateway_state, _bump_seed) = + Pubkey::find_program_address(&[SEED.as_ref()], ctx.program_id); require!( gateway_state.key() == expected_gateway_state, @@ -182,18 +186,17 @@ mod solana_gateway { }) } - pub fn post_execution( - ctx: Context, + pub fn post_execution<'info>( + ctx: Context<'_, '_, '_, 'info, PostExecution<'info>>, task_id: u64, source_network: String, post_execution_info: PostExecutionInfo, - bump: u8, ) -> Result<()> { let gateway_state = &mut ctx.accounts.gateway_state; // Verify that the gateway_state is a PDA with the correct seeds and bump - let expected_gateway_state = - Pubkey::create_program_address(&[SEED.as_ref(), &[bump]], ctx.program_id).unwrap(); + let (expected_gateway_state, bump_seed) = + Pubkey::find_program_address(&[SEED.as_ref()], ctx.program_id); require!( gateway_state.key() == expected_gateway_state, @@ -261,7 +264,7 @@ mod solana_gateway { // // Verify that the recovered public key matches the expected public key require!( - recovered_pubkey.to_bytes() == expected_pubkey.to_bytes(), + recovered_pubkey == expected_pubkey, TaskError::InvalidPacketSignature ); @@ -284,12 +287,12 @@ mod solana_gateway { ); // Extract and concatenate the program ID and function identifier - let program_id_bytes = &post_execution_info.callback_selector[0..32]; - let function_identifier = &post_execution_info.callback_selector[32..40]; + let (program_id_bytes, function_identifier) = + post_execution_info.callback_selector.split_at(32); // Concatenate the function identifier with the rest of the data - let mut callback_data = Vec::with_capacity(function_identifier.len() + borsh_data.len()); - callback_data.extend_from_slice(function_identifier); + let mut callback_data = Vec::with_capacity(8 + borsh_data.len()); + callback_data.extend_from_slice(&function_identifier[0..8]); callback_data.extend_from_slice(&borsh_data); // Concatenate all addresses that will be accessed @@ -300,36 +303,59 @@ mod solana_gateway { TaskError::InvalidCallbackAddresses ); + let mut callback_account_metas = Vec::new(); + let mut callback_addresses = Vec::new(); + + // Add the PDA as the signer + // Modify the AccountInfo to set is_writable to false for gateway_state + let mut gateway_state_account_info = ctx.accounts.gateway_state.to_account_info(); + gateway_state_account_info.is_writable = false; + gateway_state_account_info.is_signer = true; + callback_account_metas.push(AccountMeta::new_readonly(expected_gateway_state, true)); + callback_addresses.push(gateway_state_account_info); + + // Add the system_program account + callback_account_metas.push(AccountMeta::new_readonly( + *ctx.accounts.system_program.key, + false, + )); + callback_addresses.push(ctx.accounts.system_program.to_account_info()); + + let mut found_addresses = std::collections::HashSet::new(); + for chunk in callback_address_bytes.chunks(32) { - let pubkey = Pubkey::try_from(chunk).expect("Invalid callback pubkey"); - if ctx.remaining_accounts.iter().find(|account| account.key == &pubkey).is_none() { - return Err(TaskError::MissingRequiredSignature.into()); + match Pubkey::try_from(chunk) { + Ok(pubkey) => { + if pubkey == expected_gateway_state { + continue; + } + + if ctx + .remaining_accounts + .iter() + .any(|account| account.key == &pubkey) + { + if found_addresses.insert(pubkey) { + if let Some(account) = ctx + .remaining_accounts + .iter() + .find(|account| account.key == &pubkey) + { + callback_account_metas + .push(AccountMeta::new(*account.key, account.is_signer)); + callback_addresses.push(account.clone()); + } else { + return Err(TaskError::InvalidCallbackAddresses.into()); + } + } + } else { + return Err(TaskError::MissingRequiredSignature.into()); + } + } + Err(_) => return Err(TaskError::InvalidCallbackAddresses.into()), } } - // Map callback_address_bytes to AccountInfo - let callback_addresses: Vec = callback_address_bytes - .chunks(32) - .map(|address| { - let pubkey = Pubkey::try_from(address).expect("Invalid callback pubkey"); - ctx.remaining_accounts - .iter() - .find(|account| account.key == &pubkey) - .expect("Callback account not found") - .clone() - }) - .collect(); - - let system_program = ctx.accounts.system_program.to_account_info(); - - // Collect the callback addresses into a vector - let mut callback_account_metas: Vec = callback_addresses.iter() - .map(|account| AccountMeta::new(*account.key, account.is_signer)) - .collect(); - - // Add the system_program account to the vector - callback_account_metas.push(AccountMeta::new_readonly(*system_program.key, false)); - // Execute the callback with signed seeds let callback_result = invoke_signed( &Instruction { @@ -338,33 +364,48 @@ mod solana_gateway { data: callback_data, }, &callback_addresses, - &[&[SEED.as_ref(), &[bump]]], + &[&[SEED.as_ref(), &[bump_seed]]], ); - let task_completed = TaskCompleted { + // Emit Message that the task was completed and if it returned Ok + msg!( + "TaskCompleted: task_id: {} and callback_result: {}", task_id, - callback_successful: callback_result.is_ok(), - }; - - msg!(&format!( - "TaskCompleted:{}", - STANDARD.encode(&task_completed.try_to_vec().unwrap()) - )); + callback_result.is_ok() + ); Ok(()) } - pub fn callback_test(ctx: Context, task_id: u64, result: String) -> Result<()> { - msg!("Callback invoked with task_id: {} and result: {}", task_id, result); + pub fn callback_test(ctx: Context, task_id: u64, result: Vec) -> Result<()> { + // Check if the callback is really coming from the secretpath_gateway and that it was signed by it + const SECRET_PATH_ADDRESS: &str = "5mf563g8JSeTE1mMY4GSqbynjayToKSh7x5WLoQ9RDEQ"; + let secretpath_address_pubkey = + Pubkey::from_str(SECRET_PATH_ADDRESS).expect("Invalid public key format"); + + // Inline check for signature and address + if !ctx.accounts.secretpath_gateway.is_signer + || ctx.accounts.secretpath_gateway.key() != secretpath_address_pubkey + { + msg!("Callback failed: Invalid signer or public key mismatch"); + return Err(ProgramError::InvalidSecretPathGateway.into()); + } + + // Convert result to base64 string for test purposes + msg!( + "Callback invoked with task_id: {} and result: {}", + task_id, + STANDARD.encode(&result) + ); + Ok(()) } } #[derive(Accounts)] pub struct CallbackTest<'info> { - #[account(mut)] + #[account()] pub secretpath_gateway: Signer<'info>, - pub system_program: Program<'info, System>, } #[derive(Accounts)] @@ -377,7 +418,7 @@ pub struct Initialize<'info> { bump )] pub gateway_state: Account<'info, GatewayState>, - #[account(mut)] + #[account(mut, signer)] pub owner: Signer<'info>, pub system_program: Program<'info, System>, } @@ -386,6 +427,7 @@ pub struct Initialize<'info> { pub struct IncreaseTaskId<'info> { #[account(mut, seeds = [SEED], bump)] pub gateway_state: Account<'info, GatewayState>, + #[account(mut, signer)] pub owner: Signer<'info>, } @@ -393,7 +435,7 @@ pub struct IncreaseTaskId<'info> { pub struct Send<'info> { #[account(mut, seeds = [SEED], bump)] pub gateway_state: Account<'info, GatewayState>, - #[account(mut)] + #[account(mut, signer)] pub user: Signer<'info>, pub system_program: Program<'info, System>, } @@ -402,8 +444,8 @@ pub struct Send<'info> { pub struct PostExecution<'info> { #[account(mut, seeds = [SEED], bump)] pub gateway_state: Account<'info, GatewayState>, - #[account(mut)] - pub user: Signer<'info>, + #[account(mut, signer)] + pub signer: Signer<'info>, pub system_program: Program<'info, System>, } @@ -413,7 +455,6 @@ pub struct GatewayState { pub task_id: u64, pub tasks: Vec, pub max_tasks: u64, - pub bump: u8, } #[derive(Debug, Clone, AnchorSerialize, AnchorDeserialize)] @@ -457,12 +498,6 @@ pub struct SendResponse { pub request_id: u64, } -#[event] -pub struct TaskCompleted { - pub task_id: u64, - pub callback_successful: bool, -} - #[derive(Debug, Clone, AnchorSerialize, AnchorDeserialize)] pub struct LogNewTask { pub task_id: u64, diff --git a/TNLS-Relayers/sol_interface.py b/TNLS-Relayers/sol_interface.py index dd194a6..a882ed8 100644 --- a/TNLS-Relayers/sol_interface.py +++ b/TNLS-Relayers/sol_interface.py @@ -1,15 +1,16 @@ import json from solana.rpc.api import Client +from solders.compute_budget import set_compute_unit_limit from solders.keypair import Keypair from solders.pubkey import Pubkey from threading import Lock from solana.transaction import Transaction from concurrent.futures import ThreadPoolExecutor, as_completed from logging import getLogger, basicConfig, INFO, StreamHandler -from borsh_construct import CStruct, U64, String, Vec, U8, U32, Bytes +from borsh_construct import CStruct, U64, String, U8, U32, Bytes from solders.system_program import ID as SYS_PROGRAM_ID from solders.instruction import Instruction, AccountMeta -from solana.rpc.commitment import Confirmed, Finalized +from solana.rpc.commitment import Confirmed from typing import List from solana.rpc.types import TxOpts import base64 @@ -48,8 +49,7 @@ class PostExecution: "callback_gas_limit" / Bytes, "packet_signature" / U8[65], "result" / Bytes, - ), - "bump" / U8, + ) ) # Base class for interaction with Solana @@ -79,15 +79,12 @@ def sign_and_send_transaction(self, txn): """ Sign and send a transaction to the Solana network synchronously. """ - # Create the transaction - transaction = Transaction() - transaction.add(txn) # Sign the transaction - transaction.sign(self.account) + txn.sign(self.account) # Send the transaction - response = self.provider.send_transaction(transaction, self.account, + response = self.provider.send_transaction(txn, self.account, opts=TxOpts(skip_confirmation=False, preflight_commitment=Confirmed)) # Confirm the transaction @@ -108,9 +105,11 @@ def get_transactions(self, contract_interface, height): """ Get transactions for a given address. """ + #jump = 0 jump = 10 if height % jump != 0: return [] + filtered_transactions = [] try: response = self.provider.get_signatures_for_address(account=contract_interface.address, limit=10, @@ -184,13 +183,15 @@ def call_function(self, function_name, *args): """ Create a transaction with the given instructions and signers. """ - # Create context + + # Create AccountMetas accounts: list[AccountMeta] = [ AccountMeta(pubkey=self.address, is_signer=False, is_writable=True), AccountMeta(pubkey=self.interface.address, is_signer=True, is_writable=True), AccountMeta(pubkey=SYS_PROGRAM_ID, is_signer=False, is_writable=False), ] + # Parse the JSON if len(args) == 1: args = json.loads(args[0]) @@ -199,14 +200,23 @@ def call_function(self, function_name, *args): if len(callback_address_bytes) % 32 != 0: raise ValueError("callback_address_bytes length is not a multiple of 32") - callback_addresses: List[AccountMeta] = [ + callback_accounts: List[AccountMeta] = [ AccountMeta(pubkey=Pubkey(callback_address_bytes[i:i + 32]), is_signer=False, is_writable=True) for i in range(0, len(callback_address_bytes), 32) ] + # Add the callback_accounts to the accounts + if callback_accounts is not None: + accounts += callback_accounts + + # Extract the program_id from the callback_selector + callback_selector_bytes = bytes.fromhex(args[2][3][2:]) + if len(callback_selector_bytes) < 32: + raise ValueError("callback_selector does not contain enough bytes for a program_id") + program_id_bytes = callback_selector_bytes[:32] + program_id_pubkey = Pubkey(program_id_bytes) - print(callback_addresses) - if callback_addresses is not None: - accounts += callback_addresses + # Add the extracted program_id as an AccountMeta + accounts.append(AccountMeta(pubkey=program_id_pubkey, is_signer=False, is_writable=False)) # The Identifier of the post execution function identifier = bytes([52, 46, 67, 194, 153, 197, 69, 168]) @@ -222,14 +232,20 @@ def call_function(self, function_name, *args): "callback_gas_limit": bytes.fromhex(args[2][4][2:]), "packet_signature": bytes.fromhex(args[2][5][2:]), "result": bytes.fromhex(args[2][6][2:]), - }, - "bump": self.bump + } } ) + data = identifier + encoded_args tx = Instruction(program_id=self.program_id, data=data, accounts=accounts) + callback_gas_limit = int.from_bytes(bytes.fromhex(args[2][4][2:]), byteorder='big') + compute_budget_ix = set_compute_unit_limit(callback_gas_limit) + + # Create the transaction + transaction = Transaction(fee_payer=self.interface.address) + transaction.add(compute_budget_ix, tx) - submitted_txn = self.interface.sign_and_send_transaction(tx) + submitted_txn = self.interface.sign_and_send_transaction(transaction) return submitted_txn def parse_event_from_txn(self, event_name, txn) -> List[Task]: