diff --git a/programs/drift/src/error.rs b/programs/drift/src/error.rs index ea7cc04d8..f62258f27 100644 --- a/programs/drift/src/error.rs +++ b/programs/drift/src/error.rs @@ -696,6 +696,10 @@ pub enum ErrorCode { MarketIndexNotFoundAmmCache, #[msg("Invalid Isolated Perp Market")] InvalidIsolatedPerpMarket, + #[msg("Invalid scale order count - must be between 2 and 10")] + InvalidOrderScaleOrderCount, + #[msg("Invalid scale order price range")] + InvalidOrderScalePriceRange, } #[macro_export] diff --git a/programs/drift/src/instructions/user.rs b/programs/drift/src/instructions/user.rs index 893fe5fad..d0ae756f6 100644 --- a/programs/drift/src/instructions/user.rs +++ b/programs/drift/src/instructions/user.rs @@ -81,6 +81,7 @@ use crate::state::order_params::{ PlaceOrderOptions, PostOnlyParam, }; use crate::state::paused_operations::{PerpOperation, SpotOperation}; +use crate::state::scale_order_params::ScaleOrderParams; use crate::state::perp_market::MarketStatus; use crate::state::perp_market_map::{get_writable_perp_market_set, MarketSet}; use crate::state::protected_maker_mode_config::ProtectedMakerModeConfig; @@ -2606,6 +2607,31 @@ pub fn handle_modify_order_by_user_order_id<'c: 'info, 'info>( pub fn handle_place_orders<'c: 'info, 'info>( ctx: Context<'_, '_, 'c, 'info, PlaceOrder>, params: Vec, +) -> Result<()> { + place_orders(&ctx, PlaceOrdersInput::Orders(params)) +} + +#[access_control( + exchange_not_paused(&ctx.accounts.state) +)] +pub fn handle_place_scale_orders<'c: 'info, 'info>( + ctx: Context<'_, '_, 'c, 'info, PlaceOrder>, + params: ScaleOrderParams, +) -> Result<()> { + place_orders(&ctx, PlaceOrdersInput::ScaleOrders(params)) +} + +/// Input for place_orders - either direct OrderParams or ScaleOrderParams to expand +enum PlaceOrdersInput { + Orders(Vec), + ScaleOrders(ScaleOrderParams), +} + +/// Internal implementation for placing multiple orders. +/// Used by both handle_place_orders and handle_place_scale_orders. +fn place_orders<'c: 'info, 'info>( + ctx: &Context<'_, '_, 'c, 'info, PlaceOrder>, + input: PlaceOrdersInput, ) -> Result<()> { let clock = &Clock::get()?; let state = &ctx.accounts.state; @@ -2625,8 +2651,31 @@ pub fn handle_place_orders<'c: 'info, 'info>( let high_leverage_mode_config = get_high_leverage_mode_config(&mut remaining_accounts)?; + // Convert input to order params, expanding scale orders if needed + let order_params = match input { + PlaceOrdersInput::Orders(params) => params, + PlaceOrdersInput::ScaleOrders(scale_params) => { + let order_step_size = match scale_params.market_type { + MarketType::Perp => { + let market = perp_market_map.get_ref(&scale_params.market_index)?; + market.amm.order_step_size + } + MarketType::Spot => { + let market = spot_market_map.get_ref(&scale_params.market_index)?; + market.order_step_size + } + }; + + scale_params.expand_to_order_params(order_step_size) + .map_err(|e| { + msg!("Failed to expand scale order params: {:?}", e); + ErrorCode::InvalidOrder + })? + } + }; + validate!( - params.len() <= 32, + order_params.len() <= 32, ErrorCode::DefaultError, "max 32 order params" )?; @@ -2634,8 +2683,8 @@ pub fn handle_place_orders<'c: 'info, 'info>( let user_key = ctx.accounts.user.key(); let mut user = load_mut!(ctx.accounts.user)?; - let num_orders = params.len(); - for (i, params) in params.iter().enumerate() { + let num_orders = order_params.len(); + for (i, params) in order_params.iter().enumerate() { validate!( !params.is_immediate_or_cancel(), ErrorCode::InvalidOrderIOC, @@ -2654,7 +2703,7 @@ pub fn handle_place_orders<'c: 'info, 'info>( if params.market_type == MarketType::Perp { controller::orders::place_perp_order( - &ctx.accounts.state, + state, &mut user, user_key, &perp_market_map, @@ -2668,7 +2717,7 @@ pub fn handle_place_orders<'c: 'info, 'info>( )?; } else { controller::orders::place_spot_order( - &ctx.accounts.state, + state, &mut user, user_key, &perp_market_map, diff --git a/programs/drift/src/lib.rs b/programs/drift/src/lib.rs index fe0b403e4..3ecd0ef29 100644 --- a/programs/drift/src/lib.rs +++ b/programs/drift/src/lib.rs @@ -13,6 +13,7 @@ use crate::controller::position::PositionDirection; use crate::state::if_rebalance_config::IfRebalanceConfigParams; use crate::state::oracle::PrelaunchOracleParams; use crate::state::order_params::{ModifyOrderParams, OrderParams}; +use crate::state::scale_order_params::ScaleOrderParams; use crate::state::perp_market::{ContractTier, MarketStatus}; use crate::state::settle_pnl_mode::SettlePnlMode; use crate::state::spot_market::AssetTier; @@ -367,6 +368,13 @@ pub mod drift { handle_place_orders(ctx, params) } + pub fn place_scale_orders<'c: 'info, 'info>( + ctx: Context<'_, '_, 'c, 'info, PlaceOrder>, + params: ScaleOrderParams, + ) -> Result<()> { + handle_place_scale_orders(ctx, params) + } + pub fn begin_swap<'c: 'info, 'info>( ctx: Context<'_, '_, 'c, 'info, Swap<'info>>, in_market_index: u16, diff --git a/programs/drift/src/state/mod.rs b/programs/drift/src/state/mod.rs index 73b57392f..aecc0d168 100644 --- a/programs/drift/src/state/mod.rs +++ b/programs/drift/src/state/mod.rs @@ -15,6 +15,7 @@ pub mod oracle; pub mod oracle_map; pub mod order_params; pub mod paused_operations; +pub mod scale_order_params; pub mod perp_market; pub mod perp_market_map; pub mod protected_maker_mode_config; diff --git a/programs/drift/src/state/order_params.rs b/programs/drift/src/state/order_params.rs index a5b81b8bb..97e261923 100644 --- a/programs/drift/src/state/order_params.rs +++ b/programs/drift/src/state/order_params.rs @@ -1027,3 +1027,4 @@ pub fn parse_optional_params(optional_params: Option) -> (u8, u8) { None => (0, 100), } } + diff --git a/programs/drift/src/state/scale_order_params.rs b/programs/drift/src/state/scale_order_params.rs new file mode 100644 index 000000000..5b06df86c --- /dev/null +++ b/programs/drift/src/state/scale_order_params.rs @@ -0,0 +1,270 @@ +use crate::controller::position::PositionDirection; +use crate::error::{DriftResult, ErrorCode}; +use crate::math::constants::MAX_OPEN_ORDERS; +use crate::math::safe_math::SafeMath; +use crate::state::order_params::{OrderParams, PostOnlyParam}; +use crate::state::user::{MarketType, OrderTriggerCondition, OrderType}; +use crate::validate; +use anchor_lang::prelude::*; + +#[cfg(test)] +mod tests; + +/// Minimum number of orders required for a scale order +pub const MIN_SCALE_ORDER_COUNT: u8 = 2; +/// Maximum number of orders allowed in a single scale order instruction +pub const MAX_SCALE_ORDER_COUNT: u8 = MAX_OPEN_ORDERS; + +/// How to distribute order sizes across scale orders +#[derive(AnchorSerialize, AnchorDeserialize, Clone, Copy, Default, Eq, PartialEq, Debug)] +pub enum SizeDistribution { + /// Equal size for all orders + #[default] + Flat, + /// Smallest orders at start price, largest at end price + Ascending, + /// Largest orders at start price, smallest at end price + Descending, +} + +/// Parameters for placing scale orders - multiple limit orders distributed across a price range +#[derive(AnchorSerialize, AnchorDeserialize, Clone, Default, Eq, PartialEq, Debug)] +pub struct ScaleOrderParams { + pub market_type: MarketType, + pub direction: PositionDirection, + pub market_index: u16, + /// Total base asset amount to distribute across all orders + pub total_base_asset_amount: u64, + /// Starting price for the scale (in PRICE_PRECISION) + pub start_price: u64, + /// Ending price for the scale (in PRICE_PRECISION) + pub end_price: u64, + /// Number of orders to place (min 2, max 32) + pub order_count: u8, + /// How to distribute sizes across orders + pub size_distribution: SizeDistribution, + /// Whether orders should be reduce-only + pub reduce_only: bool, + /// Post-only setting for all orders + pub post_only: PostOnlyParam, + /// Bit flags (e.g., for high leverage mode) + pub bit_flags: u8, + /// Maximum timestamp for orders to be valid + pub max_ts: Option, +} + +impl ScaleOrderParams { + /// Validates the scale order parameters + pub fn validate(&self, order_step_size: u64) -> DriftResult<()> { + validate!( + self.order_count >= MIN_SCALE_ORDER_COUNT, + ErrorCode::InvalidOrderScaleOrderCount, + "order_count must be at least {}", + MIN_SCALE_ORDER_COUNT + )?; + + validate!( + self.order_count <= MAX_SCALE_ORDER_COUNT, + ErrorCode::InvalidOrderScaleOrderCount, + "order_count must be at most {}", + MAX_SCALE_ORDER_COUNT + )?; + + validate!( + self.start_price != self.end_price, + ErrorCode::InvalidOrderScalePriceRange, + "start_price and end_price cannot be equal" + )?; + + // For long orders, start price is higher (first buy) and end price is lower (DCA down) + // For short orders, start price is lower (first sell) and end price is higher (scale out up) + match self.direction { + PositionDirection::Long => { + validate!( + self.start_price > self.end_price, + ErrorCode::InvalidOrderScalePriceRange, + "for long scale orders, start_price must be greater than end_price (scaling down)" + )?; + } + PositionDirection::Short => { + validate!( + self.start_price < self.end_price, + ErrorCode::InvalidOrderScalePriceRange, + "for short scale orders, start_price must be less than end_price (scaling up)" + )?; + } + } + + // Validate that total size can be distributed among all orders meeting minimum step size + let min_total_size = order_step_size.safe_mul(self.order_count as u64)?; + validate!( + self.total_base_asset_amount >= min_total_size, + ErrorCode::OrderAmountTooSmall, + "total_base_asset_amount must be at least {} (order_step_size * order_count)", + min_total_size + )?; + + Ok(()) + } + + /// Calculate evenly distributed prices between start and end price + pub fn calculate_price_distribution(&self) -> DriftResult> { + let order_count = self.order_count as usize; + + if order_count == 1 { + return Ok(vec![self.start_price]); + } + + if order_count == 2 { + return Ok(vec![self.start_price, self.end_price]); + } + + let (min_price, max_price) = if self.start_price < self.end_price { + (self.start_price, self.end_price) + } else { + (self.end_price, self.start_price) + }; + + let price_range = max_price.safe_sub(min_price)?; + let num_steps = (order_count - 1) as u64; + let price_step = price_range.safe_div(num_steps)?; + + let mut prices = Vec::with_capacity(order_count); + for i in 0..order_count { + // Use exact end_price for the last order to avoid rounding errors + let price = if i == order_count - 1 { + self.end_price + } else if self.start_price < self.end_price { + self.start_price.safe_add(price_step.safe_mul(i as u64)?)? + } else { + self.start_price.safe_sub(price_step.safe_mul(i as u64)?)? + }; + prices.push(price); + } + + Ok(prices) + } + + /// Calculate order sizes based on size distribution strategy + pub fn calculate_size_distribution(&self, order_step_size: u64) -> DriftResult> { + match self.size_distribution { + SizeDistribution::Flat => self.calculate_flat_sizes(order_step_size), + SizeDistribution::Ascending => self.calculate_scaled_sizes(order_step_size, false), + SizeDistribution::Descending => self.calculate_scaled_sizes(order_step_size, true), + } + } + + /// Calculate flat (equal) distribution of sizes + fn calculate_flat_sizes(&self, order_step_size: u64) -> DriftResult> { + let order_count = self.order_count as u64; + let base_size = self.total_base_asset_amount.safe_div(order_count)?; + // Round down to step size + let rounded_size = base_size + .safe_div(order_step_size)? + .safe_mul(order_step_size)?; + + let mut sizes = vec![rounded_size; self.order_count as usize]; + + // Add remainder to the last order + let total_distributed: u64 = sizes.iter().sum(); + let remainder = self.total_base_asset_amount.safe_sub(total_distributed)?; + if remainder > 0 { + if let Some(last) = sizes.last_mut() { + *last = last.safe_add(remainder)?; + } + } + + Ok(sizes) + } + + /// Calculate scaled (ascending/descending) distribution of sizes + /// Uses multipliers: 1x, 1.5x, 2x, 2.5x, ... for ascending + fn calculate_scaled_sizes( + &self, + order_step_size: u64, + descending: bool, + ) -> DriftResult> { + let order_count = self.order_count as usize; + + // Calculate multipliers: 1.0, 1.5, 2.0, 2.5, ... (using 0.5 step) + // Sum of multipliers = n/2 * (first + last) = n/2 * (1 + (1 + 0.5*(n-1))) + // For precision, multiply everything by 2: multipliers become 2, 3, 4, 5, ... + // Sum = n/2 * (2 + (2 + (n-1))) = n/2 * (3 + n) = n*(n+3)/2 + let multiplier_sum = (order_count * (order_count + 3)) / 2; + + // Base unit size (multiplied by 2 for precision) + let base_unit = self + .total_base_asset_amount + .safe_mul(2)? + .safe_div(multiplier_sum as u64)?; + + let mut sizes = Vec::with_capacity(order_count); + let mut total = 0u64; + + for i in 0..order_count { + // Multiplier for position i is (2 + i) when using 0.5 step scaled by 2 + let multiplier = (2 + i) as u64; + let raw_size = base_unit.safe_mul(multiplier)?.safe_div(2)?; + // Round to step size + let rounded_size = raw_size + .safe_div(order_step_size)? + .safe_mul(order_step_size)? + .max(order_step_size); // Ensure at least step size + sizes.push(rounded_size); + total = total.safe_add(rounded_size)?; + } + + // Adjust last order to account for rounding + if total != self.total_base_asset_amount { + if let Some(last) = sizes.last_mut() { + if total > self.total_base_asset_amount { + let diff = total.safe_sub(self.total_base_asset_amount)?; + *last = last.saturating_sub(diff).max(order_step_size); + } else { + let diff = self.total_base_asset_amount.safe_sub(total)?; + *last = last.safe_add(diff)?; + } + } + } + + if descending { + sizes.reverse(); + } + + Ok(sizes) + } + + /// Expand scale order params into individual OrderParams + pub fn expand_to_order_params(&self, order_step_size: u64) -> DriftResult> { + self.validate(order_step_size)?; + + let prices = self.calculate_price_distribution()?; + let sizes = self.calculate_size_distribution(order_step_size)?; + + let mut order_params = Vec::with_capacity(self.order_count as usize); + + for (i, (price, size)) in prices.iter().zip(sizes.iter()).enumerate() { + order_params.push(OrderParams { + order_type: OrderType::Limit, + market_type: self.market_type, + direction: self.direction, + user_order_id: 0, + base_asset_amount: *size, + price: *price, + market_index: self.market_index, + reduce_only: self.reduce_only, + post_only: self.post_only, + bit_flags: if i == 0 { self.bit_flags } else { 0 }, + max_ts: self.max_ts, + trigger_price: None, + trigger_condition: OrderTriggerCondition::Above, + oracle_price_offset: None, + auction_duration: None, + auction_start_price: None, + auction_end_price: None, + }); + } + + Ok(order_params) + } +} diff --git a/programs/drift/src/state/scale_order_params/tests.rs b/programs/drift/src/state/scale_order_params/tests.rs new file mode 100644 index 000000000..4877d3d11 --- /dev/null +++ b/programs/drift/src/state/scale_order_params/tests.rs @@ -0,0 +1,568 @@ +use crate::state::order_params::PostOnlyParam; +use crate::state::scale_order_params::{ScaleOrderParams, SizeDistribution}; +use crate::state::user::MarketType; +use crate::{PositionDirection, BASE_PRECISION_U64, PRICE_PRECISION_U64}; + +#[test] +fn test_validate_order_count_bounds() { + let step_size = BASE_PRECISION_U64 / 1000; // 0.001 + + // Test minimum order count + // Long: start high, end low (DCA down) + let params = ScaleOrderParams { + market_type: MarketType::Perp, + direction: PositionDirection::Long, + market_index: 0, + total_base_asset_amount: BASE_PRECISION_U64, + start_price: 110 * PRICE_PRECISION_U64, + end_price: 100 * PRICE_PRECISION_U64, + order_count: 1, // Below minimum + size_distribution: SizeDistribution::Flat, + reduce_only: false, + post_only: PostOnlyParam::None, + bit_flags: 0, + max_ts: None, + }; + assert!(params.validate(step_size).is_err()); + + // Test maximum order count + let params = ScaleOrderParams { + order_count: 33, // Above maximum (MAX_OPEN_ORDERS = 32) + ..params + }; + assert!(params.validate(step_size).is_err()); + + // Test valid order count + let params = ScaleOrderParams { + order_count: 5, + ..params + }; + assert!(params.validate(step_size).is_ok()); +} + +#[test] +fn test_validate_price_range() { + let step_size = BASE_PRECISION_U64 / 1000; + + // Long orders: start_price must be > end_price (scaling down) + let params = ScaleOrderParams { + market_type: MarketType::Perp, + direction: PositionDirection::Long, + market_index: 0, + total_base_asset_amount: BASE_PRECISION_U64, + start_price: 100 * PRICE_PRECISION_U64, // Wrong: lower than end + end_price: 110 * PRICE_PRECISION_U64, + order_count: 5, + size_distribution: SizeDistribution::Flat, + reduce_only: false, + post_only: PostOnlyParam::None, + bit_flags: 0, + max_ts: None, + }; + assert!(params.validate(step_size).is_err()); + + // Short orders: start_price must be < end_price (scaling up) + let params = ScaleOrderParams { + direction: PositionDirection::Short, + start_price: 110 * PRICE_PRECISION_U64, // Wrong: higher than end + end_price: 100 * PRICE_PRECISION_U64, + ..params + }; + assert!(params.validate(step_size).is_err()); + + // Valid long order (start high, end low - DCA down) + let params = ScaleOrderParams { + direction: PositionDirection::Long, + start_price: 110 * PRICE_PRECISION_U64, + end_price: 100 * PRICE_PRECISION_U64, + ..params + }; + assert!(params.validate(step_size).is_ok()); + + // Valid short order (start low, end high - scale out up) + let params = ScaleOrderParams { + direction: PositionDirection::Short, + start_price: 100 * PRICE_PRECISION_U64, + end_price: 110 * PRICE_PRECISION_U64, + ..params + }; + assert!(params.validate(step_size).is_ok()); +} + +#[test] +fn test_price_distribution_long() { + // Long: start high, end low (DCA down) + let params = ScaleOrderParams { + market_type: MarketType::Perp, + direction: PositionDirection::Long, + market_index: 0, + total_base_asset_amount: BASE_PRECISION_U64, + start_price: 110 * PRICE_PRECISION_U64, + end_price: 100 * PRICE_PRECISION_U64, + order_count: 5, + size_distribution: SizeDistribution::Flat, + reduce_only: false, + post_only: PostOnlyParam::None, + bit_flags: 0, + max_ts: None, + }; + + let prices = params.calculate_price_distribution().unwrap(); + assert_eq!(prices.len(), 5); + assert_eq!(prices[0], 110 * PRICE_PRECISION_U64); + assert_eq!(prices[1], 107500000); // 107.5 + assert_eq!(prices[2], 105 * PRICE_PRECISION_U64); + assert_eq!(prices[3], 102500000); // 102.5 + assert_eq!(prices[4], 100 * PRICE_PRECISION_U64); +} + +#[test] +fn test_price_distribution_short() { + // Short: start low, end high (scale out up) + let params = ScaleOrderParams { + market_type: MarketType::Perp, + direction: PositionDirection::Short, + market_index: 0, + total_base_asset_amount: BASE_PRECISION_U64, + start_price: 100 * PRICE_PRECISION_U64, + end_price: 110 * PRICE_PRECISION_U64, + order_count: 5, + size_distribution: SizeDistribution::Flat, + reduce_only: false, + post_only: PostOnlyParam::None, + bit_flags: 0, + max_ts: None, + }; + + let prices = params.calculate_price_distribution().unwrap(); + assert_eq!(prices.len(), 5); + assert_eq!(prices[0], 100 * PRICE_PRECISION_U64); + assert_eq!(prices[1], 102500000); // 102.5 + assert_eq!(prices[2], 105 * PRICE_PRECISION_U64); + assert_eq!(prices[3], 107500000); // 107.5 + assert_eq!(prices[4], 110 * PRICE_PRECISION_U64); +} + +#[test] +fn test_flat_size_distribution() { + let step_size = BASE_PRECISION_U64 / 1000; // 0.001 + + // Long: start high, end low (DCA down) + let params = ScaleOrderParams { + market_type: MarketType::Perp, + direction: PositionDirection::Long, + market_index: 0, + total_base_asset_amount: BASE_PRECISION_U64, // 1.0 + start_price: 110 * PRICE_PRECISION_U64, + end_price: 100 * PRICE_PRECISION_U64, + order_count: 5, + size_distribution: SizeDistribution::Flat, + reduce_only: false, + post_only: PostOnlyParam::None, + bit_flags: 0, + max_ts: None, + }; + + let sizes = params.calculate_size_distribution(step_size).unwrap(); + assert_eq!(sizes.len(), 5); + + // Total must equal the requested amount + let total: u64 = sizes.iter().sum(); + assert_eq!(total, BASE_PRECISION_U64); + + // Flat distribution: each order should be 1/5 = 20% of total + // Expected: 200_000_000 each (0.2 BASE) + // First 4 orders are exactly 0.2, last order gets any remainder + assert_eq!(sizes[0], 200_000_000); // 20% + assert_eq!(sizes[1], 200_000_000); // 20% + assert_eq!(sizes[2], 200_000_000); // 20% + assert_eq!(sizes[3], 200_000_000); // 20% + assert_eq!(sizes[4], 200_000_000); // 20% (remainder goes here if any) + + // Verify each order is exactly 20% of total + for size in &sizes { + let pct = (*size as f64) / (BASE_PRECISION_U64 as f64) * 100.0; + assert!((pct - 20.0).abs() < 0.1, "Expected ~20%, got {}%", pct); + } +} + +#[test] +fn test_ascending_size_distribution() { + let step_size = BASE_PRECISION_U64 / 1000; // 0.001 + + // Long: start high, end low (DCA down) + let params = ScaleOrderParams { + market_type: MarketType::Perp, + direction: PositionDirection::Long, + market_index: 0, + total_base_asset_amount: BASE_PRECISION_U64, // 1.0 + start_price: 110 * PRICE_PRECISION_U64, + end_price: 100 * PRICE_PRECISION_U64, + order_count: 5, + size_distribution: SizeDistribution::Ascending, + reduce_only: false, + post_only: PostOnlyParam::None, + bit_flags: 0, + max_ts: None, + }; + + let sizes = params.calculate_size_distribution(step_size).unwrap(); + assert_eq!(sizes.len(), 5); + + // Total must equal the requested amount + let total: u64 = sizes.iter().sum(); + assert_eq!(total, BASE_PRECISION_U64); + + // Ascending distribution uses multipliers: 1x, 1.5x, 2x, 2.5x, 3x + // Scaled by 2 for precision: 2, 3, 4, 5, 6 (sum = 20) + // Expected proportions: 10%, 15%, 20%, 25%, 30% + // For 1_000_000_000 total: 100M, 150M, 200M, 250M, 300M + assert_eq!(sizes[0], 100_000_000); // 10% - smallest + assert_eq!(sizes[1], 150_000_000); // 15% + assert_eq!(sizes[2], 200_000_000); // 20% + assert_eq!(sizes[3], 250_000_000); // 25% + assert_eq!(sizes[4], 300_000_000); // 30% - largest + + // Verify ascending order: each subsequent order is larger + assert!(sizes[0] < sizes[1]); + assert!(sizes[1] < sizes[2]); + assert!(sizes[2] < sizes[3]); + assert!(sizes[3] < sizes[4]); + + // Verify the proportions are correct (within 1% tolerance for rounding) + let expected_pcts = [10.0, 15.0, 20.0, 25.0, 30.0]; + for (i, (size, expected_pct)) in sizes.iter().zip(expected_pcts.iter()).enumerate() { + let actual_pct = (*size as f64) / (BASE_PRECISION_U64 as f64) * 100.0; + assert!( + (actual_pct - expected_pct).abs() < 1.0, + "Order {}: expected ~{}%, got {}%", + i, + expected_pct, + actual_pct + ); + } +} + +#[test] +fn test_descending_size_distribution() { + let step_size = BASE_PRECISION_U64 / 1000; // 0.001 + + // Long: start high, end low (DCA down) + let params = ScaleOrderParams { + market_type: MarketType::Perp, + direction: PositionDirection::Long, + market_index: 0, + total_base_asset_amount: BASE_PRECISION_U64, // 1.0 + start_price: 110 * PRICE_PRECISION_U64, + end_price: 100 * PRICE_PRECISION_U64, + order_count: 5, + size_distribution: SizeDistribution::Descending, + reduce_only: false, + post_only: PostOnlyParam::None, + bit_flags: 0, + max_ts: None, + }; + + let sizes = params.calculate_size_distribution(step_size).unwrap(); + assert_eq!(sizes.len(), 5); + + // Total must equal the requested amount + let total: u64 = sizes.iter().sum(); + assert_eq!(total, BASE_PRECISION_U64); + + // Descending distribution is reverse of ascending + // Multipliers (reversed): 3x, 2.5x, 2x, 1.5x, 1x + // Expected proportions: 30%, 25%, 20%, 15%, 10% + // For 1_000_000_000 total: 300M, 250M, 200M, 150M, 100M + assert_eq!(sizes[0], 300_000_000); // 30% - largest + assert_eq!(sizes[1], 250_000_000); // 25% + assert_eq!(sizes[2], 200_000_000); // 20% + assert_eq!(sizes[3], 150_000_000); // 15% + assert_eq!(sizes[4], 100_000_000); // 10% - smallest + + // Verify descending order: each subsequent order is smaller + assert!(sizes[0] > sizes[1]); + assert!(sizes[1] > sizes[2]); + assert!(sizes[2] > sizes[3]); + assert!(sizes[3] > sizes[4]); + + // Verify the proportions are correct (within 1% tolerance for rounding) + let expected_pcts = [30.0, 25.0, 20.0, 15.0, 10.0]; + for (i, (size, expected_pct)) in sizes.iter().zip(expected_pcts.iter()).enumerate() { + let actual_pct = (*size as f64) / (BASE_PRECISION_U64 as f64) * 100.0; + assert!( + (actual_pct - expected_pct).abs() < 1.0, + "Order {}: expected ~{}%, got {}%", + i, + expected_pct, + actual_pct + ); + } +} + +#[test] +fn test_ascending_size_distribution_3_orders() { + // Test with different order count to verify formula works correctly + let step_size = BASE_PRECISION_U64 / 1000; // 0.001 + + let params = ScaleOrderParams { + market_type: MarketType::Perp, + direction: PositionDirection::Long, + market_index: 0, + total_base_asset_amount: BASE_PRECISION_U64, // 1.0 + start_price: 110 * PRICE_PRECISION_U64, + end_price: 100 * PRICE_PRECISION_U64, + order_count: 3, + size_distribution: SizeDistribution::Ascending, + reduce_only: false, + post_only: PostOnlyParam::None, + bit_flags: 0, + max_ts: None, + }; + + let sizes = params.calculate_size_distribution(step_size).unwrap(); + assert_eq!(sizes.len(), 3); + + // Total must equal the requested amount + let total: u64 = sizes.iter().sum(); + assert_eq!(total, BASE_PRECISION_U64); + + // For 3 orders: multiplier_sum = n*(n+3)/2 = 3*6/2 = 9 + // Multipliers (scaled by 2): 2, 3, 4 + // Expected proportions: 2/9 ≈ 22.2%, 3/9 ≈ 33.3%, 4/9 ≈ 44.4% + let expected_pcts = [22.22, 33.33, 44.44]; + for (i, (size, expected_pct)) in sizes.iter().zip(expected_pcts.iter()).enumerate() { + let actual_pct = (*size as f64) / (BASE_PRECISION_U64 as f64) * 100.0; + assert!( + (actual_pct - expected_pct).abs() < 1.0, + "Order {}: expected ~{}%, got {}%", + i, + expected_pct, + actual_pct + ); + } + + // Verify ascending order + assert!(sizes[0] < sizes[1]); + assert!(sizes[1] < sizes[2]); +} + +#[test] +fn test_flat_distribution_with_remainder() { + // Test flat distribution where total doesn't divide evenly + let step_size = BASE_PRECISION_U64 / 1000; // 0.001 + + let params = ScaleOrderParams { + market_type: MarketType::Perp, + direction: PositionDirection::Long, + market_index: 0, + total_base_asset_amount: BASE_PRECISION_U64, // 1.0 + start_price: 110 * PRICE_PRECISION_U64, + end_price: 100 * PRICE_PRECISION_U64, + order_count: 3, // 1.0 / 3 doesn't divide evenly + size_distribution: SizeDistribution::Flat, + reduce_only: false, + post_only: PostOnlyParam::None, + bit_flags: 0, + max_ts: None, + }; + + let sizes = params.calculate_size_distribution(step_size).unwrap(); + assert_eq!(sizes.len(), 3); + + // Total must still equal exactly the requested amount + let total: u64 = sizes.iter().sum(); + assert_eq!(total, BASE_PRECISION_U64); + + // Each order should be ~33.3%, with remainder going to last order + // step_size = 1_000_000 (0.001) + // base_size = 1_000_000_000 / 3 = 333_333_333 + // rounded_size = (333_333_333 / 1_000_000) * 1_000_000 = 333_000_000 + // First two orders: 333_000_000 each + // Last order: 1_000_000_000 - 2*333_000_000 = 334_000_000 + assert_eq!(sizes[0], 333_000_000); + assert_eq!(sizes[1], 333_000_000); + assert_eq!(sizes[2], 334_000_000); // Gets the remainder +} + +#[test] +fn test_expand_to_order_params_perp() { + let step_size = BASE_PRECISION_U64 / 1000; // 0.001 + + // Long: start high, end low (DCA down) + let params = ScaleOrderParams { + market_type: MarketType::Perp, + direction: PositionDirection::Long, + market_index: 1, + total_base_asset_amount: BASE_PRECISION_U64, // 1.0 + start_price: 110 * PRICE_PRECISION_U64, + end_price: 100 * PRICE_PRECISION_U64, + order_count: 3, + size_distribution: SizeDistribution::Flat, + reduce_only: true, + post_only: PostOnlyParam::MustPostOnly, + bit_flags: 2, // High leverage mode + max_ts: Some(12345), + }; + + let order_params = params.expand_to_order_params(step_size).unwrap(); + assert_eq!(order_params.len(), 3); + + // Check first order has bit flags + assert_eq!(order_params[0].bit_flags, 2); + // Other orders should have 0 bit flags + assert_eq!(order_params[1].bit_flags, 0); + assert_eq!(order_params[2].bit_flags, 0); + + // Check common properties + for op in &order_params { + assert_eq!(op.market_type, MarketType::Perp); + assert_eq!(op.market_index, 1); + assert_eq!(op.reduce_only, true); + assert_eq!(op.post_only, PostOnlyParam::MustPostOnly); + assert_eq!(op.max_ts, Some(12345)); + assert!(matches!(op.direction, PositionDirection::Long)); + } + + // Check prices are distributed (high to low for long) + assert_eq!(order_params[0].price, 110 * PRICE_PRECISION_U64); + assert_eq!(order_params[1].price, 105 * PRICE_PRECISION_U64); + assert_eq!(order_params[2].price, 100 * PRICE_PRECISION_U64); + + // Check total size + let total: u64 = order_params.iter().map(|op| op.base_asset_amount).sum(); + assert_eq!(total, BASE_PRECISION_U64); +} + +#[test] +fn test_expand_to_order_params_spot() { + let step_size = BASE_PRECISION_U64 / 1000; // 0.001 + + // Spot Long: start high, end low (DCA down) + let params = ScaleOrderParams { + market_type: MarketType::Spot, + direction: PositionDirection::Long, + market_index: 1, // SOL spot market + total_base_asset_amount: BASE_PRECISION_U64, // 1.0 + start_price: 110 * PRICE_PRECISION_U64, + end_price: 100 * PRICE_PRECISION_U64, + order_count: 3, + size_distribution: SizeDistribution::Flat, + reduce_only: false, + post_only: PostOnlyParam::None, + bit_flags: 0, + max_ts: None, + }; + + let order_params = params.expand_to_order_params(step_size).unwrap(); + assert_eq!(order_params.len(), 3); + + // Check all orders are Spot market type + for op in &order_params { + assert_eq!(op.market_type, MarketType::Spot); + assert_eq!(op.market_index, 1); + assert!(matches!(op.direction, PositionDirection::Long)); + } + + // Check prices are distributed (high to low for long) + assert_eq!(order_params[0].price, 110 * PRICE_PRECISION_U64); + assert_eq!(order_params[1].price, 105 * PRICE_PRECISION_U64); + assert_eq!(order_params[2].price, 100 * PRICE_PRECISION_U64); + + // Check total size + let total: u64 = order_params.iter().map(|op| op.base_asset_amount).sum(); + assert_eq!(total, BASE_PRECISION_U64); +} + +#[test] +fn test_spot_short_scale_orders() { + let step_size = BASE_PRECISION_U64 / 1000; // 0.001 + + // Spot Short: start low, end high (scale out up) + let params = ScaleOrderParams { + market_type: MarketType::Spot, + direction: PositionDirection::Short, + market_index: 1, // SOL spot market + total_base_asset_amount: BASE_PRECISION_U64, // 1.0 + start_price: 100 * PRICE_PRECISION_U64, + end_price: 110 * PRICE_PRECISION_U64, + order_count: 4, + size_distribution: SizeDistribution::Ascending, + reduce_only: false, + post_only: PostOnlyParam::MustPostOnly, + bit_flags: 0, + max_ts: Some(99999), + }; + + let order_params = params.expand_to_order_params(step_size).unwrap(); + assert_eq!(order_params.len(), 4); + + // Check all orders are Spot market type and Short direction + for op in &order_params { + assert_eq!(op.market_type, MarketType::Spot); + assert_eq!(op.market_index, 1); + assert!(matches!(op.direction, PositionDirection::Short)); + assert_eq!(op.post_only, PostOnlyParam::MustPostOnly); + assert_eq!(op.max_ts, Some(99999)); + } + + // Check prices are distributed (low to high for short) + assert_eq!(order_params[0].price, 100 * PRICE_PRECISION_U64); + // Middle prices + assert_eq!(order_params[3].price, 110 * PRICE_PRECISION_U64); + + // Ascending: sizes should increase + assert!(order_params[0].base_asset_amount < order_params[3].base_asset_amount); + + // Check total size + let total: u64 = order_params.iter().map(|op| op.base_asset_amount).sum(); + assert_eq!(total, BASE_PRECISION_U64); +} + +#[test] +fn test_two_orders_price_distribution() { + // Long: start high, end low (DCA down) + let params = ScaleOrderParams { + market_type: MarketType::Perp, + direction: PositionDirection::Long, + market_index: 0, + total_base_asset_amount: BASE_PRECISION_U64, + start_price: 110 * PRICE_PRECISION_U64, + end_price: 100 * PRICE_PRECISION_U64, + order_count: 2, + size_distribution: SizeDistribution::Flat, + reduce_only: false, + post_only: PostOnlyParam::None, + bit_flags: 0, + max_ts: None, + }; + + let prices = params.calculate_price_distribution().unwrap(); + assert_eq!(prices.len(), 2); + assert_eq!(prices[0], 110 * PRICE_PRECISION_U64); + assert_eq!(prices[1], 100 * PRICE_PRECISION_U64); +} + +#[test] +fn test_validate_min_total_size() { + let step_size = BASE_PRECISION_U64 / 10; // 0.1 + + // Total size is too small for 5 orders with this step size + // Long: start high, end low (DCA down) + let params = ScaleOrderParams { + market_type: MarketType::Perp, + direction: PositionDirection::Long, + market_index: 0, + total_base_asset_amount: BASE_PRECISION_U64 / 20, // 0.05 - not enough + start_price: 110 * PRICE_PRECISION_U64, + end_price: 100 * PRICE_PRECISION_U64, + order_count: 5, + size_distribution: SizeDistribution::Flat, + reduce_only: false, + post_only: PostOnlyParam::None, + bit_flags: 0, + max_ts: None, + }; + + assert!(params.validate(step_size).is_err()); +} diff --git a/sdk/src/driftClient.ts b/sdk/src/driftClient.ts index 132a50727..08c372664 100644 --- a/sdk/src/driftClient.ts +++ b/sdk/src/driftClient.ts @@ -48,6 +48,7 @@ import { PositionDirection, ReferrerInfo, ReferrerNameAccount, + ScaleOrderParams, SerumV3FulfillmentConfigAccount, SettlePnlMode, SignedTxData, @@ -5602,6 +5603,96 @@ export class DriftClient { return [placeOrdersIxs, setPositionMaxLevIxs]; } + /** + * Place scale orders - multiple limit orders distributed across a price range + * @param params Scale order parameters + * @param txParams Optional transaction parameters + * @param subAccountId Optional sub account ID + * @returns Transaction signature + */ + public async placeScaleOrders( + params: ScaleOrderParams, + txParams?: TxParams, + subAccountId?: number + ): Promise { + const { txSig } = await this.sendTransaction( + (await this.preparePlaceScaleOrdersTx(params, txParams, subAccountId)) + .placeScaleOrdersTx, + [], + this.opts, + false + ); + return txSig; + } + + public async preparePlaceScaleOrdersTx( + params: ScaleOrderParams, + txParams?: TxParams, + subAccountId?: number + ) { + const lookupTableAccounts = await this.fetchAllLookupTableAccounts(); + + const tx = await this.buildTransaction( + await this.getPlaceScaleOrdersIx(params, subAccountId), + txParams, + undefined, + lookupTableAccounts + ); + + return { + placeScaleOrdersTx: tx, + }; + } + + public async getPlaceScaleOrdersIx( + params: ScaleOrderParams, + subAccountId?: number + ): Promise { + const user = await this.getUserAccountPublicKey(subAccountId); + + const isPerp = isVariant(params.marketType, 'perp'); + + const remainingAccounts = this.getRemainingAccounts({ + userAccounts: [this.getUserAccount(subAccountId)], + readablePerpMarketIndex: isPerp ? [params.marketIndex] : [], + readableSpotMarketIndexes: isPerp ? [] : [params.marketIndex], + useMarketLastSlotCache: true, + }); + + if (isUpdateHighLeverageMode(params.bitFlags)) { + remainingAccounts.push({ + pubkey: getHighLeverageModeConfigPublicKey(this.program.programId), + isWritable: true, + isSigner: false, + }); + } + + const formattedParams = { + marketType: params.marketType, + direction: params.direction, + marketIndex: params.marketIndex, + totalBaseAssetAmount: params.totalBaseAssetAmount, + startPrice: params.startPrice, + endPrice: params.endPrice, + orderCount: params.orderCount, + sizeDistribution: params.sizeDistribution, + reduceOnly: params.reduceOnly, + postOnly: params.postOnly, + bitFlags: params.bitFlags, + maxTs: params.maxTs, + }; + + return await this.program.instruction.placeScaleOrders(formattedParams, { + accounts: { + state: await this.getStatePublicKey(), + user, + userStats: this.getUserStatsAccountPublicKey(), + authority: this.wallet.publicKey, + }, + remainingAccounts, + }); + } + public async fillPerpOrder( userAccountPublicKey: PublicKey, user: UserAccount, diff --git a/sdk/src/idl/drift.json b/sdk/src/idl/drift.json index 92afad4e6..43a3d9169 100644 --- a/sdk/src/idl/drift.json +++ b/sdk/src/idl/drift.json @@ -1382,6 +1382,34 @@ } ] }, + { + "name": "placeScaleOrders", + "accounts": [ + { + "name": "state", + "isMut": false, + "isSigner": false + }, + { + "name": "user", + "isMut": true, + "isSigner": false + }, + { + "name": "authority", + "isMut": false, + "isSigner": true + } + ], + "args": [ + { + "name": "params", + "type": { + "defined": "ScaleOrderParams" + } + } + ] + }, { "name": "beginSwap", "accounts": [ @@ -13260,6 +13288,102 @@ ] } }, + { + "name": "ScaleOrderParams", + "docs": [ + "Parameters for placing scale orders - multiple limit orders distributed across a price range" + ], + "type": { + "kind": "struct", + "fields": [ + { + "name": "marketType", + "type": { + "defined": "MarketType" + } + }, + { + "name": "direction", + "type": { + "defined": "PositionDirection" + } + }, + { + "name": "marketIndex", + "type": "u16" + }, + { + "name": "totalBaseAssetAmount", + "docs": [ + "Total base asset amount to distribute across all orders" + ], + "type": "u64" + }, + { + "name": "startPrice", + "docs": [ + "Starting price for the scale (in PRICE_PRECISION)" + ], + "type": "u64" + }, + { + "name": "endPrice", + "docs": [ + "Ending price for the scale (in PRICE_PRECISION)" + ], + "type": "u64" + }, + { + "name": "orderCount", + "docs": [ + "Number of orders to place (min 2, max 32)" + ], + "type": "u8" + }, + { + "name": "sizeDistribution", + "docs": [ + "How to distribute sizes across orders" + ], + "type": { + "defined": "SizeDistribution" + } + }, + { + "name": "reduceOnly", + "docs": [ + "Whether orders should be reduce-only" + ], + "type": "bool" + }, + { + "name": "postOnly", + "docs": [ + "Post-only setting for all orders" + ], + "type": { + "defined": "PostOnlyParam" + } + }, + { + "name": "bitFlags", + "docs": [ + "Bit flags (e.g., for high leverage mode)" + ], + "type": "u8" + }, + { + "name": "maxTs", + "docs": [ + "Maximum timestamp for orders to be valid" + ], + "type": { + "option": "i64" + } + } + ] + } + }, { "name": "InsuranceClaim", "type": { @@ -15607,6 +15731,26 @@ ] } }, + { + "name": "SizeDistribution", + "docs": [ + "How to distribute order sizes across scale orders" + ], + "type": { + "kind": "enum", + "variants": [ + { + "name": "Flat" + }, + { + "name": "Ascending" + }, + { + "name": "Descending" + } + ] + } + }, { "name": "PerpOperation", "type": { @@ -19824,9 +19968,16 @@ "code": 6345, "name": "InvalidIsolatedPerpMarket", "msg": "Invalid Isolated Perp Market" + }, + { + "code": 6346, + "name": "InvalidOrderScaleOrderCount", + "msg": "Invalid scale order count - must be between 2 and 10" + }, + { + "code": 6347, + "name": "InvalidOrderScalePriceRange", + "msg": "Invalid scale order price range" } - ], - "metadata": { - "address": "dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH" - } + ] } \ No newline at end of file diff --git a/sdk/src/types.ts b/sdk/src/types.ts index f98221090..feda6c790 100644 --- a/sdk/src/types.ts +++ b/sdk/src/types.ts @@ -1295,6 +1295,42 @@ export class PostOnlyParams { static readonly SLIDE = { slide: {} }; // Modify price to be post only if can't be post only } +/** + * How to distribute order sizes across scale orders + */ +export class SizeDistribution { + static readonly FLAT = { flat: {} }; // Equal size for all orders + static readonly ASCENDING = { ascending: {} }; // Smallest at start price, largest at end price + static readonly DESCENDING = { descending: {} }; // Largest at start price, smallest at end price +} + +/** + * Parameters for placing scale orders - multiple limit orders distributed across a price range + */ +export type ScaleOrderParams = { + marketType: MarketType; + direction: PositionDirection; + marketIndex: number; + /** Total base asset amount to distribute across all orders */ + totalBaseAssetAmount: BN; + /** Starting price for the scale (in PRICE_PRECISION) */ + startPrice: BN; + /** Ending price for the scale (in PRICE_PRECISION) */ + endPrice: BN; + /** Number of orders to place (min 2, max 32). User cannot exceed 32 total open orders. */ + orderCount: number; + /** How to distribute sizes across orders */ + sizeDistribution: SizeDistribution; + /** Whether orders should be reduce-only */ + reduceOnly: boolean; + /** Post-only setting for all orders */ + postOnly: PostOnlyParams; + /** Bit flags (e.g., for high leverage mode) */ + bitFlags: number; + /** Maximum timestamp for orders to be valid */ + maxTs: BN | null; +}; + export class OrderParamsBitFlag { static readonly ImmediateOrCancel = 1; static readonly UpdateHighLeverageMode = 2; diff --git a/test-scripts/run-anchor-tests.sh b/test-scripts/run-anchor-tests.sh index 7b2fceedf..8068b1d9c 100644 --- a/test-scripts/run-anchor-tests.sh +++ b/test-scripts/run-anchor-tests.sh @@ -24,6 +24,7 @@ test_files=( # TODO BROKEN ^^ builderCodes.ts decodeUser.ts + scaleOrders.ts # fuel.ts # fuelSweep.ts admin.ts diff --git a/test-scripts/single-anchor-test.sh b/test-scripts/single-anchor-test.sh index f3fb15708..ccc088712 100755 --- a/test-scripts/single-anchor-test.sh +++ b/test-scripts/single-anchor-test.sh @@ -1,3 +1,5 @@ +#!/bin/bash + if [ "$1" != "--skip-build" ] then anchor build -- --features anchor-test && anchor test --skip-build && @@ -7,8 +9,8 @@ fi export ANCHOR_WALLET=~/.config/solana/id.json test_files=( - lpPool.ts - lpPoolSwap.ts + scaleOrders.ts + order.ts ) for test_file in ${test_files[@]}; do diff --git a/tests/scaleOrders.ts b/tests/scaleOrders.ts new file mode 100644 index 000000000..1f3f4bd08 --- /dev/null +++ b/tests/scaleOrders.ts @@ -0,0 +1,856 @@ +import * as anchor from '@coral-xyz/anchor'; +import { assert } from 'chai'; + +import { Program } from '@coral-xyz/anchor'; + +import { PublicKey, Transaction } from '@solana/web3.js'; + +import { + TestClient, + BN, + PRICE_PRECISION, + PositionDirection, + User, + EventSubscriber, + PostOnlyParams, + SizeDistribution, + BASE_PRECISION, + isVariant, + MarketType, + MARGIN_PRECISION, + getUserAccountPublicKeySync, +} from '../sdk/src'; + +import { + mockOracleNoProgram, + mockUserUSDCAccount, + mockUSDCMint, + initializeQuoteSpotMarket, + initializeSolSpotMarket, + sleep, +} from './testHelpers'; +import { OracleSource, ZERO } from '../sdk'; +import { startAnchor } from 'solana-bankrun'; +import { TestBulkAccountLoader } from '../sdk/src/accounts/testBulkAccountLoader'; +import { BankrunContextWrapper } from '../sdk/src/bankrun/bankrunConnection'; + +describe('scale orders', () => { + const chProgram = anchor.workspace.Drift as Program; + + let driftClient: TestClient; + let driftClientUser: User; + let eventSubscriber: EventSubscriber; + + let bulkAccountLoader: TestBulkAccountLoader; + + let bankrunContextWrapper: BankrunContextWrapper; + + let _userAccountPublicKey: PublicKey; + + let usdcMint; + let userUSDCAccount; + + // ammInvariant == k == x * y + const mantissaSqrtScale = new BN(Math.sqrt(PRICE_PRECISION.toNumber())); + const ammInitialQuoteAssetReserve = new anchor.BN(5 * 10 ** 11).mul( + mantissaSqrtScale + ); + const ammInitialBaseAssetReserve = new anchor.BN(5 * 10 ** 11).mul( + mantissaSqrtScale + ); + + const usdcAmount = new BN(100000 * 10 ** 6); // $100k + + const perpMarketIndex = 0; + const spotMarketIndex = 1; // SOL spot market (USDC is 0) + + let solUsd; + + before(async () => { + const context = await startAnchor('', [], []); + + bankrunContextWrapper = new BankrunContextWrapper(context); + + bulkAccountLoader = new TestBulkAccountLoader( + bankrunContextWrapper.connection, + 'processed', + 1 + ); + + eventSubscriber = new EventSubscriber( + bankrunContextWrapper.connection.toConnection(), + chProgram + ); + + await eventSubscriber.subscribe(); + + usdcMint = await mockUSDCMint(bankrunContextWrapper); + userUSDCAccount = await mockUserUSDCAccount( + usdcMint, + usdcAmount, + bankrunContextWrapper + ); + + solUsd = await mockOracleNoProgram(bankrunContextWrapper, 100); + + const marketIndexes = [perpMarketIndex]; + const bankIndexes = [0, 1]; // USDC and SOL spot markets + const oracleInfos = [ + { publicKey: PublicKey.default, source: OracleSource.QUOTE_ASSET }, + { publicKey: solUsd, source: OracleSource.PYTH }, + ]; + + driftClient = new TestClient({ + connection: bankrunContextWrapper.connection.toConnection(), + wallet: bankrunContextWrapper.provider.wallet, + programID: chProgram.programId, + opts: { + commitment: 'confirmed', + }, + activeSubAccountId: 0, + perpMarketIndexes: marketIndexes, + spotMarketIndexes: bankIndexes, + subAccountIds: [], + oracleInfos, + accountSubscription: { + type: 'polling', + accountLoader: bulkAccountLoader, + }, + }); + await driftClient.initialize(usdcMint.publicKey, true); + await driftClient.subscribe(); + await initializeQuoteSpotMarket(driftClient, usdcMint.publicKey); + await driftClient.updatePerpAuctionDuration(new BN(0)); + + let oraclesLoaded = false; + while (!oraclesLoaded) { + await driftClient.accountSubscriber.setSpotOracleMap(); + const found = + !!driftClient.accountSubscriber.getOraclePriceDataAndSlotForSpotMarket( + 0 + ); + if (found) { + oraclesLoaded = true; + } + await sleep(1000); + } + + const periodicity = new BN(60 * 60); // 1 HOUR + + await driftClient.initializePerpMarket( + 0, + solUsd, + ammInitialBaseAssetReserve, + ammInitialQuoteAssetReserve, + periodicity + ); + + // Set step size to 0.001 (1e6 in base precision) + await driftClient.updatePerpMarketStepSizeAndTickSize( + 0, + new BN(1000000), // 0.001 in BASE_PRECISION + new BN(1) + ); + + // Initialize SOL spot market + await initializeSolSpotMarket(driftClient, solUsd); + + // Set step size for spot market + await driftClient.updateSpotMarketStepSizeAndTickSize( + spotMarketIndex, + new BN(1000000), // 0.001 in token precision + new BN(1) + ); + + // Enable margin trading on spot market (required for short orders) + await driftClient.updateSpotMarketMarginWeights( + spotMarketIndex, + MARGIN_PRECISION.toNumber() * 0.75, // initial asset weight + MARGIN_PRECISION.toNumber() * 0.8, // maintenance asset weight + MARGIN_PRECISION.toNumber() * 1.25, // initial liability weight + MARGIN_PRECISION.toNumber() * 1.2 // maintenance liability weight + ); + + // Get initialization instructions + const { ixs: initIxs, userAccountPublicKey } = + await driftClient.createInitializeUserAccountAndDepositCollateralIxs( + usdcAmount, + userUSDCAccount.publicKey + ); + _userAccountPublicKey = userAccountPublicKey; + + // Get margin trading enabled instruction (manually construct since user doesn't exist yet) + const marginTradingIx = + await driftClient.program.instruction.updateUserMarginTradingEnabled( + 0, // subAccountId + true, // marginTradingEnabled + { + accounts: { + user: getUserAccountPublicKeySync( + driftClient.program.programId, + bankrunContextWrapper.provider.wallet.publicKey, + 0 + ), + authority: bankrunContextWrapper.provider.wallet.publicKey, + }, + remainingAccounts: [], + } + ); + + // Bundle and send all instructions together + const allIxs = [...initIxs, marginTradingIx]; + const tx = await driftClient.buildTransaction(allIxs); + await driftClient.sendTransaction(tx as Transaction, [], driftClient.opts); + + // Add user to client + await driftClient.addUser(0); + + driftClientUser = new User({ + driftClient, + userAccountPublicKey: await driftClient.getUserAccountPublicKey(), + accountSubscription: { + type: 'polling', + accountLoader: bulkAccountLoader, + }, + }); + await driftClientUser.subscribe(); + }); + + after(async () => { + await driftClient.unsubscribe(); + await driftClientUser.unsubscribe(); + await eventSubscriber.unsubscribe(); + }); + + beforeEach(async () => { + // Clean up any orders from previous tests + await driftClient.fetchAccounts(); + await driftClientUser.fetchAccounts(); + const userAccount = driftClientUser.getUserAccount(); + const hasOpenOrders = userAccount.orders.some((order) => + isVariant(order.status, 'open') + ); + if (hasOpenOrders) { + await driftClient.cancelOrders(); + await driftClient.fetchAccounts(); + await driftClientUser.fetchAccounts(); + } + }); + + // ==================== PERP MARKET TESTS ==================== + + it('place perp scale orders - flat distribution', async () => { + const totalBaseAmount = BASE_PRECISION; // 1 SOL + const orderCount = 5; + + // Long: start high, end low (DCA down) + const txSig = await driftClient.placeScaleOrders({ + marketType: MarketType.PERP, + direction: PositionDirection.LONG, + marketIndex: perpMarketIndex, + totalBaseAssetAmount: totalBaseAmount, + startPrice: new BN(100).mul(PRICE_PRECISION), // $100 (start high) + endPrice: new BN(95).mul(PRICE_PRECISION), // $95 (end low) + orderCount: orderCount, + sizeDistribution: SizeDistribution.FLAT, + reduceOnly: false, + postOnly: PostOnlyParams.NONE, + bitFlags: 0, + maxTs: null, + }); + + bankrunContextWrapper.printTxLogs(txSig); + + await driftClient.fetchAccounts(); + await driftClientUser.fetchAccounts(); + + const userAccount = driftClientUser.getUserAccount(); + const orders = userAccount.orders.filter((order) => + isVariant(order.status, 'open') + ); + + assert.equal(orders.length, orderCount, 'Should have 5 open orders'); + + // All orders should be perp market type + for (const order of orders) { + assert.ok(isVariant(order.marketType, 'perp'), 'Order should be perp'); + } + + // Check orders are distributed across prices (sorted low to high) + const prices = orders.map((o) => o.price.toNumber()).sort((a, b) => a - b); + assert.equal( + prices[0], + 95 * PRICE_PRECISION.toNumber(), + 'Lowest price should be $95' + ); + assert.equal( + prices[4], + 100 * PRICE_PRECISION.toNumber(), + 'Highest price should be $100' + ); + + // Check total base amount sums correctly + const totalBase = orders.reduce( + (sum, o) => sum.add(o.baseAssetAmount), + ZERO + ); + assert.ok( + totalBase.eq(totalBaseAmount), + 'Total base amount should equal input' + ); + + // Cancel all orders for next test + await driftClient.cancelOrders(); + }); + + it('place perp scale orders - ascending distribution (long)', async () => { + const totalBaseAmount = BASE_PRECISION; // 1 SOL + const orderCount = 3; + + // Long: start high, end low (DCA down) + const txSig = await driftClient.placeScaleOrders({ + marketType: MarketType.PERP, + direction: PositionDirection.LONG, + marketIndex: perpMarketIndex, + totalBaseAssetAmount: totalBaseAmount, + startPrice: new BN(100).mul(PRICE_PRECISION), // $100 (start high) + endPrice: new BN(90).mul(PRICE_PRECISION), // $90 (end low) + orderCount: orderCount, + sizeDistribution: SizeDistribution.ASCENDING, + reduceOnly: false, + postOnly: PostOnlyParams.NONE, + bitFlags: 0, + maxTs: null, + }); + + bankrunContextWrapper.printTxLogs(txSig); + + await driftClient.fetchAccounts(); + await driftClientUser.fetchAccounts(); + + const userAccount = driftClientUser.getUserAccount(); + const orders = userAccount.orders + .filter((order) => isVariant(order.status, 'open')) + .sort((a, b) => a.price.toNumber() - b.price.toNumber()); + + assert.equal(orders.length, orderCount, 'Should have 3 open orders'); + + // For ascending distribution, sizes increase from first to last order + // First order (at start price $100) is smallest, last order (at end price $90) is largest + console.log( + 'Order sizes (ascending):', + orders.map((o) => ({ + price: o.price.toString(), + size: o.baseAssetAmount.toString(), + })) + ); + + // Verify sizes - lowest price should have largest size (ascending from start to end) + assert.ok( + orders[0].baseAssetAmount.gt(orders[2].baseAssetAmount), + 'Order at lowest price ($90) should have largest size (ascending distribution ends there)' + ); + + // Check total base amount sums correctly + const totalBase = orders.reduce( + (sum, o) => sum.add(o.baseAssetAmount), + ZERO + ); + assert.ok( + totalBase.eq(totalBaseAmount), + 'Total base amount should equal input' + ); + + // Cancel all orders for next test + await driftClient.cancelOrders(); + }); + + it('place perp scale orders - short direction', async () => { + const totalBaseAmount = BASE_PRECISION.div(new BN(2)); // 0.5 SOL + const orderCount = 4; + + // Short: start low, end high (scale out up) + const txSig = await driftClient.placeScaleOrders({ + marketType: MarketType.PERP, + direction: PositionDirection.SHORT, + marketIndex: perpMarketIndex, + totalBaseAssetAmount: totalBaseAmount, + startPrice: new BN(105).mul(PRICE_PRECISION), // $105 (start low) + endPrice: new BN(110).mul(PRICE_PRECISION), // $110 (end high) + orderCount: orderCount, + sizeDistribution: SizeDistribution.FLAT, + reduceOnly: false, + postOnly: PostOnlyParams.MUST_POST_ONLY, + bitFlags: 0, + maxTs: null, + }); + + bankrunContextWrapper.printTxLogs(txSig); + + await driftClient.fetchAccounts(); + await driftClientUser.fetchAccounts(); + + const userAccount = driftClientUser.getUserAccount(); + const orders = userAccount.orders.filter((order) => + isVariant(order.status, 'open') + ); + + assert.equal(orders.length, orderCount, 'Should have 4 open orders'); + + // All orders should be short direction + for (const order of orders) { + assert.deepEqual( + order.direction, + PositionDirection.SHORT, + 'All orders should be SHORT' + ); + } + + // Check prices are distributed from 105 to 110 + const prices = orders.map((o) => o.price.toNumber()).sort((a, b) => a - b); + // Allow small rounding tolerance + const expectedStartPrice = 105 * PRICE_PRECISION.toNumber(); + assert.ok( + Math.abs(prices[0] - expectedStartPrice) <= 10, + `Lowest price should be ~$105 (got ${prices[0]}, expected ${expectedStartPrice})` + ); + const expectedEndPrice = 110 * PRICE_PRECISION.toNumber(); + assert.ok( + Math.abs(prices[3] - expectedEndPrice) <= 10, + `Highest price should be ~$110 (got ${prices[3]}, expected ${expectedEndPrice})` + ); + + // Check total base amount sums correctly + const totalBase = orders.reduce( + (sum, o) => sum.add(o.baseAssetAmount), + ZERO + ); + assert.ok( + totalBase.eq(totalBaseAmount), + 'Total base amount should equal input' + ); + + // Cancel all orders for next test + await driftClient.cancelOrders(); + }); + + it('place perp scale orders - descending distribution', async () => { + const totalBaseAmount = BASE_PRECISION; // 1 SOL + const orderCount = 3; + + // Long: start high, end low (DCA down) + const txSig = await driftClient.placeScaleOrders({ + marketType: MarketType.PERP, + direction: PositionDirection.LONG, + marketIndex: perpMarketIndex, + totalBaseAssetAmount: totalBaseAmount, + startPrice: new BN(100).mul(PRICE_PRECISION), // $100 (start high) + endPrice: new BN(90).mul(PRICE_PRECISION), // $90 (end low) + orderCount: orderCount, + sizeDistribution: SizeDistribution.DESCENDING, + reduceOnly: false, + postOnly: PostOnlyParams.NONE, + bitFlags: 0, + maxTs: null, + }); + + bankrunContextWrapper.printTxLogs(txSig); + + await driftClient.fetchAccounts(); + await driftClientUser.fetchAccounts(); + + const userAccount = driftClientUser.getUserAccount(); + const orders = userAccount.orders + .filter((order) => isVariant(order.status, 'open')) + .sort((a, b) => a.price.toNumber() - b.price.toNumber()); + + assert.equal(orders.length, orderCount, 'Should have 3 open orders'); + + // For descending distribution, sizes decrease from first order to last + // First order (at start price $100) gets largest size + // Last order (at end price $90) gets smallest size + console.log( + 'Order sizes (descending):', + orders.map((o) => ({ + price: o.price.toString(), + size: o.baseAssetAmount.toString(), + })) + ); + + // Verify sizes - highest price (start) has largest size, lowest price (end) has smallest + assert.ok( + orders[2].baseAssetAmount.gt(orders[0].baseAssetAmount), + 'Order at highest price ($100) should have largest size, lowest price ($90) smallest' + ); + + // Check total base amount sums correctly + const totalBase = orders.reduce( + (sum, o) => sum.add(o.baseAssetAmount), + ZERO + ); + assert.ok( + totalBase.eq(totalBaseAmount), + 'Total base amount should equal input' + ); + + // Cancel all orders + await driftClient.cancelOrders(); + }); + + it('place perp scale orders - with reduce only flag', async () => { + // Test that reduce-only flag is properly set on scale orders + // Note: We don't need an actual position to test the flag is set correctly + const totalBaseAmount = BASE_PRECISION.div(new BN(2)); // 0.5 SOL + + // Long: start high, end low (DCA down) + const txSig = await driftClient.placeScaleOrders({ + marketType: MarketType.PERP, + direction: PositionDirection.LONG, + marketIndex: perpMarketIndex, + totalBaseAssetAmount: totalBaseAmount, + startPrice: new BN(100).mul(PRICE_PRECISION), // $100 (start high) + endPrice: new BN(95).mul(PRICE_PRECISION), // $95 (end low) + orderCount: 2, + sizeDistribution: SizeDistribution.FLAT, + reduceOnly: true, // Test reduce only flag + postOnly: PostOnlyParams.NONE, + bitFlags: 0, + maxTs: null, + }); + + bankrunContextWrapper.printTxLogs(txSig); + + await driftClient.fetchAccounts(); + await driftClientUser.fetchAccounts(); + + const userAccount = driftClientUser.getUserAccount(); + const orders = userAccount.orders.filter((order) => + isVariant(order.status, 'open') + ); + + assert.equal(orders.length, 2, 'Should have 2 open orders'); + + // All orders should have reduce only flag set + for (const order of orders) { + assert.equal(order.reduceOnly, true, 'Order should be reduce only'); + } + + // Cancel all orders + await driftClient.cancelOrders(); + }); + + it('place perp scale orders - minimum 2 orders', async () => { + const totalBaseAmount = BASE_PRECISION; + const orderCount = 2; // Minimum allowed + + // Long: start high, end low (DCA down) + const txSig = await driftClient.placeScaleOrders({ + marketType: MarketType.PERP, + direction: PositionDirection.LONG, + marketIndex: perpMarketIndex, + totalBaseAssetAmount: totalBaseAmount, + startPrice: new BN(100).mul(PRICE_PRECISION), // $100 (start high) + endPrice: new BN(95).mul(PRICE_PRECISION), // $95 (end low) + orderCount: orderCount, + sizeDistribution: SizeDistribution.FLAT, + reduceOnly: false, + postOnly: PostOnlyParams.NONE, + bitFlags: 0, + maxTs: null, + }); + + bankrunContextWrapper.printTxLogs(txSig); + + await driftClient.fetchAccounts(); + await driftClientUser.fetchAccounts(); + + const userAccount = driftClientUser.getUserAccount(); + const orders = userAccount.orders.filter((order) => + isVariant(order.status, 'open') + ); + + assert.equal(orders.length, 2, 'Should have exactly 2 orders'); + + const prices = orders.map((o) => o.price.toNumber()).sort((a, b) => a - b); + assert.equal( + prices[0], + 95 * PRICE_PRECISION.toNumber(), + 'Lowest price should be $95' + ); + assert.equal( + prices[1], + 100 * PRICE_PRECISION.toNumber(), + 'Highest price should be $100' + ); + + // Cancel all orders + await driftClient.cancelOrders(); + }); + + // ==================== SPOT MARKET TESTS ==================== + + it('place spot scale orders - flat distribution (long)', async () => { + const totalBaseAmount = BASE_PRECISION; // 1 SOL + const orderCount = 3; + + // Long: start high, end low (DCA down) + const txSig = await driftClient.placeScaleOrders({ + marketType: MarketType.SPOT, + direction: PositionDirection.LONG, + marketIndex: spotMarketIndex, + totalBaseAssetAmount: totalBaseAmount, + startPrice: new BN(100).mul(PRICE_PRECISION), // $100 (start high) + endPrice: new BN(95).mul(PRICE_PRECISION), // $95 (end low) + orderCount: orderCount, + sizeDistribution: SizeDistribution.FLAT, + reduceOnly: false, + postOnly: PostOnlyParams.NONE, + bitFlags: 0, + maxTs: null, + }); + + bankrunContextWrapper.printTxLogs(txSig); + + await driftClient.fetchAccounts(); + await driftClientUser.fetchAccounts(); + + const userAccount = driftClientUser.getUserAccount(); + const orders = userAccount.orders.filter((order) => + isVariant(order.status, 'open') + ); + + assert.equal(orders.length, orderCount, 'Should have 3 open orders'); + + // All orders should be spot market type + for (const order of orders) { + assert.ok(isVariant(order.marketType, 'spot'), 'Order should be spot'); + assert.equal( + order.marketIndex, + spotMarketIndex, + 'Market index should match' + ); + } + + // Check orders are distributed across prices + const prices = orders.map((o) => o.price.toNumber()).sort((a, b) => a - b); + assert.equal( + prices[0], + 95 * PRICE_PRECISION.toNumber(), + 'Lowest price should be $95' + ); + assert.equal( + prices[2], + 100 * PRICE_PRECISION.toNumber(), + 'Highest price should be $100' + ); + + // Check total base amount sums correctly + const totalBase = orders.reduce( + (sum, o) => sum.add(o.baseAssetAmount), + ZERO + ); + assert.ok( + totalBase.eq(totalBaseAmount), + 'Total base amount should equal input' + ); + + // Cancel all orders + await driftClient.cancelOrders(); + }); + + it('place spot scale orders - short direction', async () => { + const totalBaseAmount = BASE_PRECISION.div(new BN(2)); // 0.5 SOL + const orderCount = 4; + + // Short: start low, end high (scale out up) + const txSig = await driftClient.placeScaleOrders({ + marketType: MarketType.SPOT, + direction: PositionDirection.SHORT, + marketIndex: spotMarketIndex, + totalBaseAssetAmount: totalBaseAmount, + startPrice: new BN(105).mul(PRICE_PRECISION), // $105 (start low) + endPrice: new BN(110).mul(PRICE_PRECISION), // $110 (end high) + orderCount: orderCount, + sizeDistribution: SizeDistribution.FLAT, + reduceOnly: false, + postOnly: PostOnlyParams.MUST_POST_ONLY, + bitFlags: 0, + maxTs: null, + }); + + bankrunContextWrapper.printTxLogs(txSig); + + await driftClient.fetchAccounts(); + await driftClientUser.fetchAccounts(); + + const userAccount = driftClientUser.getUserAccount(); + const orders = userAccount.orders.filter((order) => + isVariant(order.status, 'open') + ); + + assert.equal(orders.length, orderCount, 'Should have 4 open orders'); + + // All orders should be spot market type and short direction + for (const order of orders) { + assert.ok(isVariant(order.marketType, 'spot'), 'Order should be spot'); + assert.deepEqual( + order.direction, + PositionDirection.SHORT, + 'All orders should be SHORT' + ); + } + + // Check prices are distributed from 105 to 110 + const prices = orders.map((o) => o.price.toNumber()).sort((a, b) => a - b); + const expectedStartPrice = 105 * PRICE_PRECISION.toNumber(); + assert.ok( + Math.abs(prices[0] - expectedStartPrice) <= 10, + `Lowest price should be ~$105 (got ${prices[0]})` + ); + const expectedEndPrice = 110 * PRICE_PRECISION.toNumber(); + assert.ok( + Math.abs(prices[3] - expectedEndPrice) <= 10, + `Highest price should be ~$110 (got ${prices[3]})` + ); + + // Check total base amount sums correctly + const totalBase = orders.reduce( + (sum, o) => sum.add(o.baseAssetAmount), + ZERO + ); + assert.ok( + totalBase.eq(totalBaseAmount), + 'Total base amount should equal input' + ); + + // Cancel all orders + await driftClient.cancelOrders(); + }); + + it('place spot scale orders - ascending distribution', async () => { + const totalBaseAmount = BASE_PRECISION; // 1 SOL + const orderCount = 3; + + // Long: start high, end low (DCA down) + const txSig = await driftClient.placeScaleOrders({ + marketType: MarketType.SPOT, + direction: PositionDirection.LONG, + marketIndex: spotMarketIndex, + totalBaseAssetAmount: totalBaseAmount, + startPrice: new BN(100).mul(PRICE_PRECISION), // $100 (start high) + endPrice: new BN(90).mul(PRICE_PRECISION), // $90 (end low) + orderCount: orderCount, + sizeDistribution: SizeDistribution.ASCENDING, + reduceOnly: false, + postOnly: PostOnlyParams.NONE, + bitFlags: 0, + maxTs: null, + }); + + bankrunContextWrapper.printTxLogs(txSig); + + await driftClient.fetchAccounts(); + await driftClientUser.fetchAccounts(); + + const userAccount = driftClientUser.getUserAccount(); + const orders = userAccount.orders + .filter((order) => isVariant(order.status, 'open')) + .sort((a, b) => a.price.toNumber() - b.price.toNumber()); + + assert.equal(orders.length, orderCount, 'Should have 3 open orders'); + + // All orders should be spot market type + for (const order of orders) { + assert.ok(isVariant(order.marketType, 'spot'), 'Order should be spot'); + } + + // For ascending distribution, sizes increase from start to end + // Order at lowest price ($90 - end) should have largest size + console.log( + 'Spot order sizes (ascending):', + orders.map((o) => ({ + price: o.price.toString(), + size: o.baseAssetAmount.toString(), + })) + ); + + assert.ok( + orders[0].baseAssetAmount.gt(orders[2].baseAssetAmount), + 'Order at lowest price ($90) should have largest size' + ); + + // Check total base amount sums correctly + const totalBase = orders.reduce( + (sum, o) => sum.add(o.baseAssetAmount), + ZERO + ); + assert.ok( + totalBase.eq(totalBaseAmount), + 'Total base amount should equal input' + ); + + // Cancel all orders + await driftClient.cancelOrders(); + }); + + it('place spot scale orders - descending distribution', async () => { + const totalBaseAmount = BASE_PRECISION; // 1 SOL + const orderCount = 3; + + // Long: start high, end low (DCA down) + const txSig = await driftClient.placeScaleOrders({ + marketType: MarketType.SPOT, + direction: PositionDirection.LONG, + marketIndex: spotMarketIndex, + totalBaseAssetAmount: totalBaseAmount, + startPrice: new BN(100).mul(PRICE_PRECISION), // $100 (start high) + endPrice: new BN(90).mul(PRICE_PRECISION), // $90 (end low) + orderCount: orderCount, + sizeDistribution: SizeDistribution.DESCENDING, + reduceOnly: false, + postOnly: PostOnlyParams.NONE, + bitFlags: 0, + maxTs: null, + }); + + bankrunContextWrapper.printTxLogs(txSig); + + await driftClient.fetchAccounts(); + await driftClientUser.fetchAccounts(); + + const userAccount = driftClientUser.getUserAccount(); + const orders = userAccount.orders + .filter((order) => isVariant(order.status, 'open')) + .sort((a, b) => a.price.toNumber() - b.price.toNumber()); + + assert.equal(orders.length, orderCount, 'Should have 3 open orders'); + + // For descending distribution, sizes decrease from start to end + // Order at highest price ($100 - start) should have largest size + console.log( + 'Spot order sizes (descending):', + orders.map((o) => ({ + price: o.price.toString(), + size: o.baseAssetAmount.toString(), + })) + ); + + assert.ok( + orders[2].baseAssetAmount.gt(orders[0].baseAssetAmount), + 'Order at highest price ($100) should have largest size' + ); + + // Check total base amount sums correctly + const totalBase = orders.reduce( + (sum, o) => sum.add(o.baseAssetAmount), + ZERO + ); + assert.ok( + totalBase.eq(totalBaseAmount), + 'Total base amount should equal input' + ); + + // Cancel all orders + await driftClient.cancelOrders(); + }); +});