fix buy sell for simulation trader

This commit is contained in:
zhiyong 2025-05-10 22:11:22 +08:00
parent 5e732bf97a
commit 6ad14c689e
2 changed files with 16 additions and 15 deletions

View File

@ -108,30 +108,32 @@ class SimulationTrader(BaseTrader):
# 此处简化处理,在实际情况中应该计算所有持仓的市值 # 此处简化处理,在实际情况中应该计算所有持仓的市值
self.sim_balance["total"] = self.sim_balance["cash"] self.sim_balance["total"] = self.sim_balance["cash"]
def cancel(self, order_id, strategy_name="default_strategy"): def cancel(self, order_id):
message = f"模拟撤单 - 委托号: {order_id}" message = f"模拟撤单 - 委托号: {order_id}"
self.logger.info(message) self.logger.info(message)
position_manager = self.get_position_manager(strategy_name) position_managers = self.get_all_position_managers()
for position_manager in position_managers.values():
if order_id in position_manager.pending_orders: if order_id in position_manager.pending_orders:
position_manager.update_order_status(order_id, 0, ORDER_STATUS_CANCELLED) position_manager.update_order_status(order_id, 0, ORDER_STATUS_CANCELLED)
return {"order_id": "order_id", "message": message, "success": True} return {"order_id": "order_id", "message": message, "success": True}
else: else:
return {"order_id": None, "message": "订单不存在", "success": False} return {"order_id": None, "message": "订单不存在", "success": False}
def get_balance(self): def get_balance(self):
message = "模拟交易:查询余额" message = "模拟交易:查询余额"
self.logger.info(message) self.logger.info(message)
return self.sim_balance return self.sim_balance
def get_positions(self, strategy_name="default_strategy"): def get_positions(self):
message = "模拟交易:查询持仓" message = "模拟交易:查询持仓"
self.logger.info(message) self.logger.info(message)
position_manager = self.get_position_manager(strategy_name) position_managers = self.get_all_position_managers()
postions = position_manager.get_positions() positions = {}
for position_manager in position_managers.values():
positions.update(position_manager.get_positions())
# convert to json list # convert to json list
return [position.to_dict() for position in postions.values()] return [position.to_dict() for position in positions.values()]
def get_today_trades(self): def get_today_trades(self):
message = "模拟交易:查询今日成交" message = "模拟交易:查询今日成交"

View File

@ -154,7 +154,7 @@ def buy():
strategy_name, code, ORDER_DIRECTION_BUY, amount, price strategy_name, code, ORDER_DIRECTION_BUY, amount, price
) )
else: else:
result = get_trader().buy(code, price, amount) result = get_trader().buy(code, price, amount, strategy_name)
if result.get("success"): if result.get("success"):
logger.info(f"买入成功: {result}") logger.info(f"买入成功: {result}")
@ -211,7 +211,7 @@ def sell():
strategy_name, code, ORDER_DIRECTION_SELL, amount, price strategy_name, code, ORDER_DIRECTION_SELL, amount, price
) )
else: else:
result = get_trader().sell(code, price, amount) result = get_trader().sell(code, price, amount, strategy_name)
if result.get("success"): if result.get("success"):
logger.info(f"卖出成功: {result}") logger.info(f"卖出成功: {result}")
@ -268,7 +268,6 @@ def get_positions():
logger.error(f"Error processing positions request: {str(e)}") logger.error(f"Error processing positions request: {str(e)}")
abort(500, description="Internal server error") abort(500, description="Internal server error")
@app.route("/yu/todaytrades", methods=["GET"]) @app.route("/yu/todaytrades", methods=["GET"])
def get_today_trades(): def get_today_trades():
"""Get the today's trades of the account.""" """Get the today's trades of the account."""