Skip to content

Commit

Permalink
fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jordy25519 committed Dec 15, 2023
1 parent b5f6661 commit 350a9e9
Showing 1 changed file with 43 additions and 7 deletions.
50 changes: 43 additions & 7 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use drift_sdk::{
use rust_decimal::Decimal;
use serde::{Deserialize, Deserializer, Serialize, Serializer};

#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, Debug)]
pub struct Order {
#[serde(serialize_with = "order_type_ser", deserialize_with = "order_type_de")]
order_type: sdk_types::OrderType,
Expand Down Expand Up @@ -155,7 +155,7 @@ pub struct ModifyOrder {
price: Option<Decimal>,
pub user_order_id: Option<u8>,
pub order_id: Option<u32>,
pub reduce_only: Option<bool>,
reduce_only: Option<bool>,
}

impl ModifyOrder {
Expand Down Expand Up @@ -410,10 +410,10 @@ pub struct GetOrderbookRequest {

#[cfg(test)]
mod tests {
use drift_sdk::types::{Context, MarketType};
use drift_sdk::types::{Context, MarketType, PositionDirection};
use std::str::FromStr;

use crate::types::{Market, Order};
use crate::types::{Market, ModifyOrder, Order};

use super::{Decimal, PlaceOrder};

Expand Down Expand Up @@ -449,7 +449,8 @@ mod tests {
("0.1234", 123_400u64, 0_u16),
("123", 123_000_000_000, 1),
("1.23", 1_230_000_000, 1),
("5.123456789", 512_345_678, 4),
("-1.23", 1_230_000_000, 1),
("5.123456789", 512_345_678, 4), // truncates extra decimals
];
for (input, expected, market_index) in cases {
let p = PlaceOrder {
Expand All @@ -458,17 +459,27 @@ mod tests {
market: Market::spot(market_index),
..Default::default()
};
let is_short = p.amount.is_sign_negative();
let order_params = p.to_order_params(Context::MainNet);
assert_eq!(order_params.base_asset_amount, expected);
if is_short {
assert_eq!(order_params.direction, PositionDirection::Short);
} else {
assert_eq!(order_params.direction, PositionDirection::Long);
}
}
}

#[test]
fn order_from_sdk_order() {
let cases = [
(123_4000u64, Decimal::from_str("1.23400").unwrap(), 0_u16),
(
1_230_400_000_u64,
Decimal::from_str("1.2304").unwrap(),
0_u16,
),
(123_000_000_000, Decimal::from_str("123.0").unwrap(), 1),
(512_345_678, Decimal::from_str("5.12345678").unwrap(), 4),
(5_123_456_789, Decimal::from_str("5.123456789").unwrap(), 4),
];
for (input, expected, market_index) in cases {
let o = drift_sdk::types::Order {
Expand All @@ -482,4 +493,29 @@ mod tests {
assert_eq!(gateway_order.amount, expected);
}
}

#[test]
fn modify_order_to_order_params() {
let m = ModifyOrder {
amount: Decimal::from_str("-0.5").ok(),
price: Decimal::from_str("11.1").ok(),
..Default::default()
};
let order_params = m.to_order_params(1, MarketType::Spot, Context::MainNet);

assert_eq!(order_params.direction, Some(PositionDirection::Short));
assert_eq!(order_params.base_asset_amount, Some(500_000_000));
assert_eq!(order_params.price, Some(11_100_000));

let m = ModifyOrder {
amount: Decimal::from_str("12").ok(),
price: Decimal::from_str("1.02").ok(),
..Default::default()
};
let order_params = m.to_order_params(1, MarketType::Spot, Context::MainNet);

assert_eq!(order_params.direction, Some(PositionDirection::Long));
assert_eq!(order_params.base_asset_amount, Some(12_000_000_000));
assert_eq!(order_params.price, Some(1_020_000));
}
}

0 comments on commit 350a9e9

Please sign in to comment.