From 350a9e94abdb0ab1b9e22173320fabbae9cebf41 Mon Sep 17 00:00:00 2001 From: holygits Date: Fri, 15 Dec 2023 16:01:16 +0800 Subject: [PATCH] fix unit tests --- src/types.rs | 50 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 7 deletions(-) diff --git a/src/types.rs b/src/types.rs index 6a69fcb..8bcd818 100644 --- a/src/types.rs +++ b/src/types.rs @@ -15,7 +15,7 @@ use drift_sdk::{ use rust_decimal::Decimal; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] pub struct Order { #[serde(serialize_with = "order_type_ser", deserialize_with = "order_type_de")] order_type: sdk_types::OrderType, @@ -155,7 +155,7 @@ pub struct ModifyOrder { price: Option, pub user_order_id: Option, pub order_id: Option, - pub reduce_only: Option, + reduce_only: Option, } impl ModifyOrder { @@ -410,10 +410,10 @@ pub struct GetOrderbookRequest { #[cfg(test)] mod tests { - use drift_sdk::types::{Context, MarketType}; + use drift_sdk::types::{Context, MarketType, PositionDirection}; use std::str::FromStr; - use crate::types::{Market, Order}; + use crate::types::{Market, ModifyOrder, Order}; use super::{Decimal, PlaceOrder}; @@ -449,7 +449,8 @@ mod tests { ("0.1234", 123_400u64, 0_u16), ("123", 123_000_000_000, 1), ("1.23", 1_230_000_000, 1), - ("5.123456789", 512_345_678, 4), + ("-1.23", 1_230_000_000, 1), + ("5.123456789", 512_345_678, 4), // truncates extra decimals ]; for (input, expected, market_index) in cases { let p = PlaceOrder { @@ -458,17 +459,27 @@ mod tests { market: Market::spot(market_index), ..Default::default() }; + let is_short = p.amount.is_sign_negative(); let order_params = p.to_order_params(Context::MainNet); assert_eq!(order_params.base_asset_amount, expected); + if is_short { + assert_eq!(order_params.direction, PositionDirection::Short); + } else { + assert_eq!(order_params.direction, PositionDirection::Long); + } } } #[test] fn order_from_sdk_order() { let cases = [ - (123_4000u64, Decimal::from_str("1.23400").unwrap(), 0_u16), + ( + 1_230_400_000_u64, + Decimal::from_str("1.2304").unwrap(), + 0_u16, + ), (123_000_000_000, Decimal::from_str("123.0").unwrap(), 1), - (512_345_678, Decimal::from_str("5.12345678").unwrap(), 4), + (5_123_456_789, Decimal::from_str("5.123456789").unwrap(), 4), ]; for (input, expected, market_index) in cases { let o = drift_sdk::types::Order { @@ -482,4 +493,29 @@ mod tests { assert_eq!(gateway_order.amount, expected); } } + + #[test] + fn modify_order_to_order_params() { + let m = ModifyOrder { + amount: Decimal::from_str("-0.5").ok(), + price: Decimal::from_str("11.1").ok(), + ..Default::default() + }; + let order_params = m.to_order_params(1, MarketType::Spot, Context::MainNet); + + assert_eq!(order_params.direction, Some(PositionDirection::Short)); + assert_eq!(order_params.base_asset_amount, Some(500_000_000)); + assert_eq!(order_params.price, Some(11_100_000)); + + let m = ModifyOrder { + amount: Decimal::from_str("12").ok(), + price: Decimal::from_str("1.02").ok(), + ..Default::default() + }; + let order_params = m.to_order_params(1, MarketType::Spot, Context::MainNet); + + assert_eq!(order_params.direction, Some(PositionDirection::Long)); + assert_eq!(order_params.base_asset_amount, Some(12_000_000_000)); + assert_eq!(order_params.price, Some(1_020_000)); + } }