From 5aa8c6a91e79a272fe58bebc7b4ea2600d0661f8 Mon Sep 17 00:00:00 2001 From: jordy25519 Date: Sat, 7 Dec 2024 01:32:43 +0800 Subject: [PATCH] Feat/wallet refactor (#95) * refactor wallet with mode enum --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/controller.rs | 72 +++++++++++++++++++++-------------------------- src/main.rs | 2 +- src/types.rs | 31 ++++++++++++++++++++ 5 files changed, 66 insertions(+), 43 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3ec26b0..c8d8603 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1762,7 +1762,7 @@ dependencies = [ [[package]] name = "drift-gateway" -version = "1.2.1" +version = "1.2.2" dependencies = [ "actix-web", "argh", diff --git a/Cargo.toml b/Cargo.toml index 6d56745..040b06b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "drift-gateway" -version = "1.2.1" +version = "1.2.2" edition = "2021" [dependencies] diff --git a/src/controller.rs b/src/controller.rs index 08b1472..6106bb9 100644 --- a/src/controller.rs +++ b/src/controller.rs @@ -37,10 +37,11 @@ use thiserror::Error; use crate::{ types::{ get_market_decimals, AllMarketsResponse, CancelAndPlaceRequest, CancelOrdersRequest, - GetOrdersRequest, GetOrdersResponse, GetPositionsRequest, GetPositionsResponse, Market, - MarketInfoResponse, ModifyOrdersRequest, Order, PerpPosition, PerpPositionExtended, - PlaceOrdersRequest, SolBalanceResponse, SpotPosition, TxEventsResponse, TxResponse, - UserCollateralResponse, UserLeverageResponse, UserMarginResponse, PRICE_DECIMALS, + GatewayWallet, GetOrdersRequest, GetOrdersResponse, GetPositionsRequest, + GetPositionsResponse, Market, MarketInfoResponse, ModifyOrdersRequest, Order, PerpPosition, + PerpPositionExtended, PlaceOrdersRequest, SolBalanceResponse, SpotPosition, + TxEventsResponse, TxResponse, UserCollateralResponse, UserLeverageResponse, + UserMarginResponse, WalletMode, PRICE_DECIMALS, }, websocket::map_drift_event_for_account, Context, LOG_TARGET, @@ -67,9 +68,7 @@ pub enum ControllerError { #[derive(Clone)] pub struct AppState { - pub wallet: Wallet, - /// true if gateway is using delegated signing - delegated: bool, + pub wallet: Arc, pub client: Arc, /// Solana tx commitment level for preflight confirmation tx_commitment: CommitmentConfig, @@ -85,17 +84,18 @@ pub struct AppState { impl AppState { /// Configured drift authority address pub fn authority(&self) -> &Pubkey { - self.wallet.authority() + self.wallet.inner().authority() } /// Configured drift signing address pub fn signer(&self) -> Pubkey { - self.wallet.signer() + self.wallet.inner().signer() } pub fn default_sub_account(&self) -> Pubkey { - self.wallet.sub_account(self.default_subaccount_id) + self.wallet.inner().sub_account(self.default_subaccount_id) } pub fn resolve_sub_account(&self, sub_account_id: Option) -> Pubkey { self.wallet + .inner() .sub_account(sub_account_id.unwrap_or(self.default_subaccount_id)) } @@ -111,7 +111,7 @@ impl AppState { pub async fn new( endpoint: &str, devnet: bool, - wallet: Wallet, + wallet: GatewayWallet, commitment: Option<(CommitmentConfig, CommitmentConfig)>, default_subaccount_id: Option, skip_tx_preflight: bool, @@ -126,11 +126,13 @@ impl AppState { }; let rpc_client = RpcClient::new_with_commitment(endpoint.into(), state_commitment); - let client = DriftClient::new(context, rpc_client, wallet.clone()) + let client = DriftClient::new(context, rpc_client, wallet.inner().clone()) .await .expect("ok"); - let default_subaccount = wallet.sub_account(default_subaccount_id.unwrap_or(0)); + let default_subaccount = wallet + .inner() + .sub_account(default_subaccount_id.unwrap_or(0)); if let Err(err) = client.subscribe_account(&default_subaccount).await { log::error!(target: LOG_TARGET, "couldn't subscribe to user updates: {err:?}"); } else { @@ -150,24 +152,23 @@ impl AppState { }, ); - let priority_fee_subscriber = if !wallet.is_emulating() { + let priority_fee_subscriber = if wallet.is_emulating() { + Arc::new(priority_fee_subscriber) + } else { client .subscribe_blockhashes() .await .expect("blockhashes subscribed"); priority_fee_subscriber.subscribe() - } else { - Arc::new(priority_fee_subscriber) }; Self { client: Arc::new(client), - delegated: wallet.is_delegated(), tx_commitment, default_subaccount_id: default_subaccount_id.unwrap_or(0), skip_tx_preflight, priority_fee_subscriber, - wallet, + wallet: Arc::new(wallet), extra_rpcs: extra_rpcs .into_iter() .map(|u| Arc::new(RpcClient::new(get_http_url(u).expect("valid RPC url")))) @@ -218,7 +219,7 @@ impl AppState { let balance = self .client .inner() - .get_balance(&self.wallet.signer()) + .get_balance(&self.wallet.inner().signer()) .await .map_err(|err| ControllerError::Sdk(err.into()))?; Ok(SolBalanceResponse { @@ -248,7 +249,7 @@ impl AppState { self.client.program_data(), sub_account, Cow::Owned(account_data), - self.delegated, + self.wallet.is_delegated(), ) .with_priority_fee(priority_fee, ctx.cu_limit); let tx = build_cancel_ix(builder, req)?.build(); @@ -472,7 +473,7 @@ impl AppState { self.client.program_data(), sub_account, Cow::Owned(account_data), - self.delegated, + self.wallet.is_delegated(), ) .with_priority_fee(ctx.cu_price.unwrap_or(pf), ctx.cu_limit); @@ -507,7 +508,7 @@ impl AppState { self.client.program_data(), sub_account, Cow::Owned(account_data), - self.delegated, + self.wallet.is_delegated(), ) .with_priority_fee(priority_fee, ctx.cu_limit) .place_orders(orders) @@ -528,7 +529,7 @@ impl AppState { self.client.program_data(), sub_account, Cow::Owned(account_data), - self.delegated, + self.wallet.is_delegated(), ) .with_priority_fee(ctx.cu_price.unwrap_or(pf), ctx.cu_limit); let tx = build_modify_ix(builder, req, self.client.program_data())?.build(); @@ -612,7 +613,7 @@ impl AppState { ttl: Option, ) -> GatewayResult { let recent_block_hash = self.client.get_latest_blockhash().await?; - let tx = self.wallet.sign_tx(tx, recent_block_hash)?; + let tx = self.wallet.inner().sign_tx(tx, recent_block_hash)?; let tx_config = RpcSendTransactionConfig { max_retries: Some(0), preflight_commitment: Some(self.tx_commitment.commitment), @@ -780,31 +781,22 @@ pub fn create_wallet( secret_key: Option, emulate: Option, delegate: Option, -) -> Wallet { +) -> GatewayWallet { match (&secret_key, emulate, delegate) { (Some(secret_key), _, delegate) => { let mut wallet = Wallet::try_from_str(secret_key).expect("valid key"); if let Some(authority) = delegate { wallet.to_delegated(authority); + GatewayWallet::new(wallet, WalletMode::Delegated) + } else { + GatewayWallet::new(wallet, WalletMode::Normal) } - wallet } - (None, Some(emulate), None) => Wallet::read_only(emulate), + (None, Some(emulate), None) => { + GatewayWallet::new(Wallet::read_only(emulate), WalletMode::Normal) + } _ => { panic!("expected 'DRIFT_GATEWAY_KEY' or --emulate "); } } } - -/// Wallet extension traits -trait WalletExt { - /// True if the wallet is running in emulation mode (unable to sign txs) - fn is_emulating(&self) -> bool; -} - -impl WalletExt for Wallet { - /// True if the wallet is running in emulation mode (unable to sign txs) - fn is_emulating(&self) -> bool { - self.authority() != &self.signer() && !self.is_delegated() - } -} diff --git a/src/main.rs b/src/main.rs index 67ed313..b4477de 100644 --- a/src/main.rs +++ b/src/main.rs @@ -286,7 +286,7 @@ async fn main() -> std::io::Result<()> { websocket::start_ws_server( format!("{}:{}", &config.host, config.ws_port).as_str(), config.rpc_host.replace("http", "ws"), - state.wallet.clone(), + state.wallet.inner().clone(), client.program_data(), ) .await; diff --git a/src/types.rs b/src/types.rs index 7887926..3c71661 100644 --- a/src/types.rs +++ b/src/types.rs @@ -14,6 +14,7 @@ use drift_rs::{ MarketPrecision, MarketType, ModifyOrderParams, OrderParams, PositionDirection, PostOnlyParam, }, + Wallet, }; use rust_decimal::Decimal; use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -571,6 +572,36 @@ impl From for UserCollateralResponse { } } +#[derive(PartialEq)] +pub enum WalletMode { + Normal, + Delegated, + Emulated, +} + +/// Wallet extension +pub struct GatewayWallet { + wallet: Wallet, + mode: WalletMode, +} + +impl GatewayWallet { + pub fn new(wallet: Wallet, mode: WalletMode) -> Self { + Self { wallet, mode } + } + pub fn inner(&self) -> &Wallet { + &self.wallet + } + /// True if the wallet is using delegated signing + pub fn is_delegated(&self) -> bool { + self.mode == WalletMode::Delegated + } + /// True if the wallet is running in emulation mode (unable to sign txs) + pub fn is_emulating(&self) -> bool { + self.mode == WalletMode::Emulated + } +} + #[cfg(test)] mod tests { use std::str::FromStr;