Skip to content

Commit

Permalink
Feat/wallet refactor (#95)
Browse files Browse the repository at this point in the history
* refactor wallet with mode enum
  • Loading branch information
jordy25519 authored Dec 6, 2024
1 parent 4941da7 commit 5aa8c6a
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 43 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "drift-gateway"
version = "1.2.1"
version = "1.2.2"
edition = "2021"

[dependencies]
Expand Down
72 changes: 32 additions & 40 deletions src/controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ use thiserror::Error;
use crate::{
types::{
get_market_decimals, AllMarketsResponse, CancelAndPlaceRequest, CancelOrdersRequest,
GetOrdersRequest, GetOrdersResponse, GetPositionsRequest, GetPositionsResponse, Market,
MarketInfoResponse, ModifyOrdersRequest, Order, PerpPosition, PerpPositionExtended,
PlaceOrdersRequest, SolBalanceResponse, SpotPosition, TxEventsResponse, TxResponse,
UserCollateralResponse, UserLeverageResponse, UserMarginResponse, PRICE_DECIMALS,
GatewayWallet, GetOrdersRequest, GetOrdersResponse, GetPositionsRequest,
GetPositionsResponse, Market, MarketInfoResponse, ModifyOrdersRequest, Order, PerpPosition,
PerpPositionExtended, PlaceOrdersRequest, SolBalanceResponse, SpotPosition,
TxEventsResponse, TxResponse, UserCollateralResponse, UserLeverageResponse,
UserMarginResponse, WalletMode, PRICE_DECIMALS,
},
websocket::map_drift_event_for_account,
Context, LOG_TARGET,
Expand All @@ -67,9 +68,7 @@ pub enum ControllerError {

#[derive(Clone)]
pub struct AppState {
pub wallet: Wallet,
/// true if gateway is using delegated signing
delegated: bool,
pub wallet: Arc<GatewayWallet>,
pub client: Arc<DriftClient>,
/// Solana tx commitment level for preflight confirmation
tx_commitment: CommitmentConfig,
Expand All @@ -85,17 +84,18 @@ pub struct AppState {
impl AppState {
/// Configured drift authority address
pub fn authority(&self) -> &Pubkey {
self.wallet.authority()
self.wallet.inner().authority()
}
/// Configured drift signing address
pub fn signer(&self) -> Pubkey {
self.wallet.signer()
self.wallet.inner().signer()
}
pub fn default_sub_account(&self) -> Pubkey {
self.wallet.sub_account(self.default_subaccount_id)
self.wallet.inner().sub_account(self.default_subaccount_id)
}
pub fn resolve_sub_account(&self, sub_account_id: Option<u16>) -> Pubkey {
self.wallet
.inner()
.sub_account(sub_account_id.unwrap_or(self.default_subaccount_id))
}

Expand All @@ -111,7 +111,7 @@ impl AppState {
pub async fn new(
endpoint: &str,
devnet: bool,
wallet: Wallet,
wallet: GatewayWallet,
commitment: Option<(CommitmentConfig, CommitmentConfig)>,
default_subaccount_id: Option<u16>,
skip_tx_preflight: bool,
Expand All @@ -126,11 +126,13 @@ impl AppState {
};

let rpc_client = RpcClient::new_with_commitment(endpoint.into(), state_commitment);
let client = DriftClient::new(context, rpc_client, wallet.clone())
let client = DriftClient::new(context, rpc_client, wallet.inner().clone())
.await
.expect("ok");

let default_subaccount = wallet.sub_account(default_subaccount_id.unwrap_or(0));
let default_subaccount = wallet
.inner()
.sub_account(default_subaccount_id.unwrap_or(0));
if let Err(err) = client.subscribe_account(&default_subaccount).await {
log::error!(target: LOG_TARGET, "couldn't subscribe to user updates: {err:?}");
} else {
Expand All @@ -150,24 +152,23 @@ impl AppState {
},
);

let priority_fee_subscriber = if !wallet.is_emulating() {
let priority_fee_subscriber = if wallet.is_emulating() {
Arc::new(priority_fee_subscriber)
} else {
client
.subscribe_blockhashes()
.await
.expect("blockhashes subscribed");
priority_fee_subscriber.subscribe()
} else {
Arc::new(priority_fee_subscriber)
};

Self {
client: Arc::new(client),
delegated: wallet.is_delegated(),
tx_commitment,
default_subaccount_id: default_subaccount_id.unwrap_or(0),
skip_tx_preflight,
priority_fee_subscriber,
wallet,
wallet: Arc::new(wallet),
extra_rpcs: extra_rpcs
.into_iter()
.map(|u| Arc::new(RpcClient::new(get_http_url(u).expect("valid RPC url"))))
Expand Down Expand Up @@ -218,7 +219,7 @@ impl AppState {
let balance = self
.client
.inner()
.get_balance(&self.wallet.signer())
.get_balance(&self.wallet.inner().signer())
.await
.map_err(|err| ControllerError::Sdk(err.into()))?;
Ok(SolBalanceResponse {
Expand Down Expand Up @@ -248,7 +249,7 @@ impl AppState {
self.client.program_data(),
sub_account,
Cow::Owned(account_data),
self.delegated,
self.wallet.is_delegated(),
)
.with_priority_fee(priority_fee, ctx.cu_limit);
let tx = build_cancel_ix(builder, req)?.build();
Expand Down Expand Up @@ -472,7 +473,7 @@ impl AppState {
self.client.program_data(),
sub_account,
Cow::Owned(account_data),
self.delegated,
self.wallet.is_delegated(),
)
.with_priority_fee(ctx.cu_price.unwrap_or(pf), ctx.cu_limit);

Expand Down Expand Up @@ -507,7 +508,7 @@ impl AppState {
self.client.program_data(),
sub_account,
Cow::Owned(account_data),
self.delegated,
self.wallet.is_delegated(),
)
.with_priority_fee(priority_fee, ctx.cu_limit)
.place_orders(orders)
Expand All @@ -528,7 +529,7 @@ impl AppState {
self.client.program_data(),
sub_account,
Cow::Owned(account_data),
self.delegated,
self.wallet.is_delegated(),
)
.with_priority_fee(ctx.cu_price.unwrap_or(pf), ctx.cu_limit);
let tx = build_modify_ix(builder, req, self.client.program_data())?.build();
Expand Down Expand Up @@ -612,7 +613,7 @@ impl AppState {
ttl: Option<u16>,
) -> GatewayResult<TxResponse> {
let recent_block_hash = self.client.get_latest_blockhash().await?;
let tx = self.wallet.sign_tx(tx, recent_block_hash)?;
let tx = self.wallet.inner().sign_tx(tx, recent_block_hash)?;
let tx_config = RpcSendTransactionConfig {
max_retries: Some(0),
preflight_commitment: Some(self.tx_commitment.commitment),
Expand Down Expand Up @@ -780,31 +781,22 @@ pub fn create_wallet(
secret_key: Option<String>,
emulate: Option<Pubkey>,
delegate: Option<Pubkey>,
) -> Wallet {
) -> GatewayWallet {
match (&secret_key, emulate, delegate) {
(Some(secret_key), _, delegate) => {
let mut wallet = Wallet::try_from_str(secret_key).expect("valid key");
if let Some(authority) = delegate {
wallet.to_delegated(authority);
GatewayWallet::new(wallet, WalletMode::Delegated)
} else {
GatewayWallet::new(wallet, WalletMode::Normal)
}
wallet
}
(None, Some(emulate), None) => Wallet::read_only(emulate),
(None, Some(emulate), None) => {
GatewayWallet::new(Wallet::read_only(emulate), WalletMode::Normal)
}
_ => {
panic!("expected 'DRIFT_GATEWAY_KEY' or --emulate <pubkey>");
}
}
}

/// Wallet extension traits
trait WalletExt {
/// True if the wallet is running in emulation mode (unable to sign txs)
fn is_emulating(&self) -> bool;
}

impl WalletExt for Wallet {
/// True if the wallet is running in emulation mode (unable to sign txs)
fn is_emulating(&self) -> bool {
self.authority() != &self.signer() && !self.is_delegated()
}
}
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ async fn main() -> std::io::Result<()> {
websocket::start_ws_server(
format!("{}:{}", &config.host, config.ws_port).as_str(),
config.rpc_host.replace("http", "ws"),
state.wallet.clone(),
state.wallet.inner().clone(),
client.program_data(),
)
.await;
Expand Down
31 changes: 31 additions & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use drift_rs::{
MarketPrecision, MarketType, ModifyOrderParams, OrderParams, PositionDirection,
PostOnlyParam,
},
Wallet,
};
use rust_decimal::Decimal;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
Expand Down Expand Up @@ -571,6 +572,36 @@ impl From<CollateralInfo> for UserCollateralResponse {
}
}

#[derive(PartialEq)]
pub enum WalletMode {
Normal,
Delegated,
Emulated,
}

/// Wallet extension
pub struct GatewayWallet {
wallet: Wallet,
mode: WalletMode,
}

impl GatewayWallet {
pub fn new(wallet: Wallet, mode: WalletMode) -> Self {
Self { wallet, mode }
}
pub fn inner(&self) -> &Wallet {
&self.wallet
}
/// True if the wallet is using delegated signing
pub fn is_delegated(&self) -> bool {
self.mode == WalletMode::Delegated
}
/// True if the wallet is running in emulation mode (unable to sign txs)
pub fn is_emulating(&self) -> bool {
self.mode == WalletMode::Emulated
}
}

#[cfg(test)]
mod tests {
use std::str::FromStr;
Expand Down

0 comments on commit 5aa8c6a

Please sign in to comment.