diff --git a/src/simulation/simulation_trader.py b/src/simulation/simulation_trader.py index 17d8998..62ac137 100644 --- a/src/simulation/simulation_trader.py +++ b/src/simulation/simulation_trader.py @@ -108,30 +108,32 @@ class SimulationTrader(BaseTrader): # 此处简化处理,在实际情况中应该计算所有持仓的市值 self.sim_balance["total"] = self.sim_balance["cash"] - def cancel(self, order_id, strategy_name="default_strategy"): + def cancel(self, order_id): message = f"模拟撤单 - 委托号: {order_id}" self.logger.info(message) - position_manager = self.get_position_manager(strategy_name) - - if order_id in position_manager.pending_orders: - position_manager.update_order_status(order_id, 0, ORDER_STATUS_CANCELLED) - return {"order_id": "order_id", "message": message, "success": True} - else: - return {"order_id": None, "message": "订单不存在", "success": False} + position_managers = self.get_all_position_managers() + for position_manager in position_managers.values(): + if order_id in position_manager.pending_orders: + position_manager.update_order_status(order_id, 0, ORDER_STATUS_CANCELLED) + return {"order_id": "order_id", "message": message, "success": True} + else: + return {"order_id": None, "message": "订单不存在", "success": False} def get_balance(self): message = "模拟交易:查询余额" self.logger.info(message) return self.sim_balance - def get_positions(self, strategy_name="default_strategy"): + def get_positions(self): message = "模拟交易:查询持仓" self.logger.info(message) - position_manager = self.get_position_manager(strategy_name) - postions = position_manager.get_positions() + position_managers = self.get_all_position_managers() + positions = {} + for position_manager in position_managers.values(): + positions.update(position_manager.get_positions()) # convert to json list - return [position.to_dict() for position in postions.values()] + return [position.to_dict() for position in positions.values()] def get_today_trades(self): message = "模拟交易:查询今日成交" diff --git a/src/trade_server.py b/src/trade_server.py index 5435351..5f294d4 100644 --- a/src/trade_server.py +++ b/src/trade_server.py @@ -154,7 +154,7 @@ def buy(): strategy_name, code, ORDER_DIRECTION_BUY, amount, price ) else: - result = get_trader().buy(code, price, amount) + result = get_trader().buy(code, price, amount, strategy_name) if result.get("success"): logger.info(f"买入成功: {result}") @@ -211,7 +211,7 @@ def sell(): strategy_name, code, ORDER_DIRECTION_SELL, amount, price ) else: - result = get_trader().sell(code, price, amount) + result = get_trader().sell(code, price, amount, strategy_name) if result.get("success"): logger.info(f"卖出成功: {result}") @@ -268,7 +268,6 @@ def get_positions(): logger.error(f"Error processing positions request: {str(e)}") abort(500, description="Internal server error") - @app.route("/yu/todaytrades", methods=["GET"]) def get_today_trades(): """Get the today's trades of the account."""