refactor trader

This commit is contained in:
zhiyong 2025-05-01 05:33:28 +08:00
parent 5823967305
commit a407ce1f2f
7 changed files with 267 additions and 89 deletions

190
src/base_trader.py Normal file
View File

@ -0,0 +1,190 @@
import datetime as dt
from chinese_calendar import is_workday
from abc import ABC, abstractmethod
from logger_config import get_logger
# 获取日志记录器
logger = get_logger('base_trader')
class BaseTrader(ABC):
"""交易基类,定义交易接口的通用方法"""
def __init__(self):
"""初始化交易基类"""
pass
@abstractmethod
def is_logged_in(self):
"""检查交易系统是否已经登录
Returns:
bool: True表示已登录False表示未登录
"""
pass
@abstractmethod
def login(self):
"""登录交易系统
Returns:
bool: 登录是否成功
"""
pass
@abstractmethod
def logout(self):
"""登出交易系统"""
pass
@abstractmethod
def get_balance(self):
"""获取账户资金情况
Returns:
dict: 账户资金信息若失败返回None
"""
pass
@abstractmethod
def get_positions(self):
"""获取所有持仓
Returns:
list: 持仓列表若无持仓返回空列表
"""
pass
@abstractmethod
def get_position(self, stock_code):
"""查询指定股票代码的持仓信息
Args:
stock_code: 股票代码例如 "600000.SH"
Returns:
dict: 持仓详情如果未持有则返回None
"""
pass
@abstractmethod
def get_today_trades(self):
"""获取当日成交
Returns:
list: 成交列表若无成交返回空列表
"""
pass
@abstractmethod
def get_today_orders(self):
"""获取当日委托
Returns:
list: 委托列表若无委托返回空列表
"""
pass
@abstractmethod
def get_order(self, order_id):
"""查询指定订单ID的详细信息
Args:
order_id: 订单ID
Returns:
dict: 订单详情如果未找到则返回None
"""
pass
@abstractmethod
def buy(self, code, price, amount, order_type='limit'):
"""买入股票
Args:
code: 股票代码
price: 买入价格市价单时可为0
amount: 买入数量
order_type: 订单类型'limit'=限价单'market'=市价单默认为'limit'
Returns:
dict: 包含订单ID的字典
"""
pass
@abstractmethod
def sell(self, code, price, amount, order_type='limit'):
"""卖出股票
Args:
code: 股票代码
price: 卖出价格市价单时可为0
amount: 卖出数量
order_type: 订单类型'limit'=限价单'market'=市价单默认为'limit'
Returns:
dict: 包含订单ID的字典
"""
pass
@abstractmethod
def cancel(self, order_id):
"""撤销订单
Args:
order_id: 订单ID
Returns:
dict: 撤单结果
"""
pass
def is_trading_time(self):
"""判断当前是否为交易时间
Returns:
bool: True 表示当前为交易时间False 表示当前休市
"""
try:
now = dt.datetime.now()
# 先判断是否为交易日
if not self.is_trading_date():
return False
# 判断是否在交易时间段内
current_time = now.time()
morning_start = dt.time(9, 30) # 上午开市时间 9:30
morning_end = dt.time(11, 30) # 上午休市时间 11:30
afternoon_start = dt.time(13, 0) # 下午开市时间 13:00
afternoon_end = dt.time(15, 0) # 下午休市时间 15:00
# 判断是否在上午或下午的交易时段
is_morning_session = morning_start <= current_time <= morning_end
is_afternoon_session = afternoon_start <= current_time <= afternoon_end
return is_morning_session or is_afternoon_session
except Exception as e:
logger.error(f"判断交易时间发生错误: {str(e)}")
return False
def is_trading_date(self, date=None):
"""判断指定日期是否为交易日
Args:
date: 日期默认为当前日期
Returns:
bool: True 表示是交易日False 表示非交易日
"""
try:
# 如果未指定日期,使用当前日期
if date is None:
date = dt.datetime.now()
# 使用 chinese_calendar 判断是否为工作日(考虑节假日和调休)
return is_workday(date)
except Exception as e:
logger.error(f"判断交易日期发生错误: {str(e)}")
return False

View File

@ -92,7 +92,7 @@ def get_today_trades() -> dict:
response = requests.get(f"{URL}/todaytrades") response = requests.get(f"{URL}/todaytrades")
return response.json() return response.json()
def get_today_entrust() -> dict: def get_today_orders() -> dict:
"""获取今日委托记录(仅实盘模式) """获取今日委托记录(仅实盘模式)
Returns: Returns:
@ -183,7 +183,7 @@ if __name__ == "__main__":
print("今日成交:", trades) print("今日成交:", trades)
# 示例:获取今日委托记录 # 示例:获取今日委托记录
entrusts = get_today_entrust() entrusts = get_today_orders()
print("今日委托:", entrusts) print("今日委托:", entrusts)
# 示例:再次查询持仓变化 # 示例:再次查询持仓变化

View File

@ -153,7 +153,7 @@ class RealTraderManager:
# 获取最新的委托列表 # 获取最新的委托列表
try: try:
entrusts = self.trader.get_today_entrust() entrusts = self.trader.get_today_orders()
if entrusts is None: if entrusts is None:
logger.error("获取今日委托失败,跳过本次检查") logger.error("获取今日委托失败,跳过本次检查")
return return
@ -224,7 +224,7 @@ class RealTraderManager:
try: try:
# 如果没有提供委托字典,则获取当前委托 # 如果没有提供委托字典,则获取当前委托
if entrust_map is None: if entrust_map is None:
entrusts = self.trader.get_today_entrust() entrusts = self.trader.get_today_orders()
entrust_map = {str(e['order_id']): e for e in entrusts} entrust_map = {str(e['order_id']): e for e in entrusts}
# 查找对应的委托记录 # 查找对应的委托记录

View File

@ -119,7 +119,7 @@ class SimulationTrader:
self.logger.info(message) self.logger.info(message)
return [] return []
def get_today_entrust(self): def get_today_orders(self):
message = "模拟交易:查询今日委托" message = "模拟交易:查询今日委托"
self.logger.info(message) self.logger.info(message)
return [] return []

View File

@ -102,7 +102,7 @@ class StrategyPositionManager:
trader_type = StrategyPositionManager.get_trader_type(trader) trader_type = StrategyPositionManager.get_trader_type(trader)
# 获取今日委托 # 获取今日委托
today_entrusts = trader.get_today_entrust() today_entrusts = trader.get_today_orders()
# 更新委托状态 # 更新委托状态
for order_id, order_info in list(pending_orders[trader_type].items()): for order_id, order_info in list(pending_orders[trader_type].items()):

View File

@ -498,13 +498,13 @@ def get_today_trades():
@app.route("/yu/todayentrust", methods=["GET"]) @app.route("/yu/todayentrust", methods=["GET"])
def get_today_entrust(): def get_today_orders():
"""Get the today's entrust of the account.""" """Get the today's entrust of the account."""
logger.info("Received today entrust request") logger.info("Received today entrust request")
try: try:
# 直接使用实盘交易实例,不考虑模拟盘 # 直接使用实盘交易实例,不考虑模拟盘
trader = get_real_trader() trader = get_real_trader()
entrust = trader.get_today_entrust() entrust = trader.get_today_orders()
logger.info(f"今日委托: {entrust}") logger.info(f"今日委托: {entrust}")
return jsonify({"success": True, "data": entrust, "simulation": False}), 200 return jsonify({"success": True, "data": entrust, "simulation": False}), 200
@ -620,7 +620,7 @@ def get_order_status():
trader = get_real_trader() trader = get_real_trader()
try: try:
entrusts = execute_with_timeout(trader.get_today_entrust, Config.TRADE_TIMEOUT) entrusts = execute_with_timeout(trader.get_today_orders, Config.TRADE_TIMEOUT)
if entrusts is None: if entrusts is None:
logger.error("获取今日委托超时") logger.error("获取今日委托超时")
return jsonify({"success": False, "error": "获取今日委托超时", "simulation": False}), 500 return jsonify({"success": False, "error": "获取今日委托超时", "simulation": False}), 500
@ -632,7 +632,7 @@ def get_order_status():
# 模拟交易模式 # 模拟交易模式
trader = get_sim_trader() trader = get_sim_trader()
try: try:
entrusts = trader.get_today_entrust() entrusts = trader.get_today_orders()
return jsonify({"success": True, "data": entrusts, "simulation": True}), 200 return jsonify({"success": True, "data": entrusts, "simulation": True}), 200
except Exception as e: except Exception as e:
logger.error(f"获取今日委托时出错: {str(e)}") logger.error(f"获取今日委托时出错: {str(e)}")

View File

@ -1,12 +1,11 @@
import os import os
import random import random
from config import Config from config import Config
from base_trader import BaseTrader
from xtquant.xttrader import XtQuantTrader from xtquant.xttrader import XtQuantTrader
from xtquant.xttype import StockAccount from xtquant.xttype import StockAccount
from xtquant import xtconstant from xtquant import xtconstant
from xtquant.xtdata import get_instrument_detail, get_trading_time from xtquant.xtdata import get_instrument_detail
import datetime as dt
from chinese_calendar import is_workday
from logger_config import get_logger from logger_config import get_logger
# 获取日志记录器 # 获取日志记录器
@ -38,8 +37,12 @@ class MyXtQuantTraderCallback:
def on_smt_appointment_async_response(self, response): def on_smt_appointment_async_response(self, response):
logger.info(f"约券异步反馈: {response.seq}") logger.info(f"约券异步反馈: {response.seq}")
class XtTrader: class XtTrader(BaseTrader):
def __init__(self): def __init__(self):
super().__init__()
self.started = False
self.connected = False
self.subscribed = False
self._ACCOUNT = Config.XT_ACCOUNT self._ACCOUNT = Config.XT_ACCOUNT
self._PATH = Config.XT_PATH self._PATH = Config.XT_PATH
self._SESSION_ID = random.randint(100000, 99999999) self._SESSION_ID = random.randint(100000, 99999999)
@ -50,9 +53,6 @@ class XtTrader:
self.xt_trader = XtQuantTrader(self._PATH, self._SESSION_ID) self.xt_trader = XtQuantTrader(self._PATH, self._SESSION_ID)
self.account = StockAccount(self._ACCOUNT, self._account_type) self.account = StockAccount(self._ACCOUNT, self._account_type)
self.xt_trader.register_callback(self._callback) self.xt_trader.register_callback(self._callback)
self.started = False
self.connected = False
self.subscribed = False
def is_logged_in(self): def is_logged_in(self):
"""检查交易系统是否已经登录 """检查交易系统是否已经登录
@ -112,6 +112,23 @@ class XtTrader:
] ]
return [] return []
def get_position(self, stock_code):
position = self.xt_trader.query_stock_position(self.account, stock_code)
if position:
return {
"account_id": position.account_id,
"stock_code": position.stock_code,
"volume": position.volume,
"can_use_volume": position.can_use_volume,
"open_price": position.open_price,
"market_value": position.market_value,
"frozen_volume": position.frozen_volume,
"on_road_volume": position.on_road_volume,
"yesterday_volume": position.yesterday_volume,
"avg_price": position.avg_price
}
return None
def get_today_trades(self): def get_today_trades(self):
trades = self.xt_trader.query_stock_trades(self.account) trades = self.xt_trader.query_stock_trades(self.account)
if trades: if trades:
@ -119,7 +136,7 @@ class XtTrader:
{ {
"account_id": t.account_id, "account_id": t.account_id,
"stock_code": t.stock_code, "stock_code": t.stock_code,
"stock_name": get_instrument_detail(t.stock_code)["InstrumentName"] if get_instrument_detail(t.stock_code) else "", "stock_name": self.get_stock_name(t.stock_code),
"order_id": t.order_id, "order_id": t.order_id,
"traded_id": t.traded_id, "traded_id": t.traded_id,
"traded_time": t.traded_time, "traded_time": t.traded_time,
@ -131,7 +148,7 @@ class XtTrader:
] ]
return [] return []
def get_today_entrust(self): def get_today_orders(self):
orders = self.xt_trader.query_stock_orders(self.account) orders = self.xt_trader.query_stock_orders(self.account)
if orders: if orders:
return [ return [
@ -152,6 +169,25 @@ class XtTrader:
] ]
return [] return []
def get_order(self, order_id):
order = self.xt_trader.query_stock_order(self.account, int(order_id))
if order:
return {
"account_id": order.account_id,
"stock_code": order.stock_code,
"order_id": order.order_id,
"order_time": order.order_time,
"order_type": "buy" if order.order_type == xtconstant.STOCK_BUY else "sell",
"order_volume": order.order_volume,
"price_type": self._convert_price_type(order.price_type),
"price": order.price,
"traded_volume": order.traded_volume,
"traded_price": order.traded_price,
"order_status": order.order_status,
"status_msg": order.status_msg
}
return None
def _convert_price_type(self, price_type): def _convert_price_type(self, price_type):
"""Convert numeric price type to readable string""" """Convert numeric price type to readable string"""
price_type_map = { price_type_map = {
@ -165,28 +201,27 @@ class XtTrader:
} }
return price_type_map.get(price_type, f"unknown_{price_type}") return price_type_map.get(price_type, f"unknown_{price_type}")
def buy(self, code, price, amount, order_type='limit'): def get_stock_name(self, stock_code):
"""买入股票 """获取股票名称
Args: Args:
code: 股票代码 stock_code: 股票代码例如 "600000.SH"
price: 买入价格市价单时可为0
amount: 买入数量
order_type: 订单类型'limit'=限价单'market'=市价单默认为'limit'
Returns: Returns:
dict: 包含订单ID的字典 str: 股票名称如果获取失败则返回空字符串
""" """
# 确定价格类型 try:
price_type = xtconstant.FIX_PRICE # 默认为限价单 instrument_info = get_instrument_detail(stock_code)
if instrument_info and "InstrumentName" in instrument_info:
return instrument_info["InstrumentName"]
return ""
except Exception as e:
logger.error(f"获取股票名称失败: {stock_code}, {str(e)}")
return ""
if order_type == 'market': def buy(self, code, price, amount, order_type='limit'):
# 市价单,根据不同市场选择合适的市价单类型 # 确定价格类型
if code.startswith('1') or code.startswith('5'): price_type = xtconstant.MARKET_BEST if order_type == 'market' else xtconstant.FIX_PRICE
# 基金等可能需要不同的市价单类型
price_type = xtconstant.MARKET_BEST
else:
price_type = xtconstant.MARKET_BEST # 市价最优价
# 如果是市价单价格可以设为0 # 如果是市价单价格可以设为0
if price_type != xtconstant.FIX_PRICE: if price_type != xtconstant.FIX_PRICE:
@ -198,27 +233,8 @@ class XtTrader:
return {"order_id": order_id} return {"order_id": order_id}
def sell(self, code, price, amount, order_type='limit'): def sell(self, code, price, amount, order_type='limit'):
"""卖出股票
Args:
code: 股票代码
price: 卖出价格市价单时可为0
amount: 卖出数量
order_type: 订单类型'limit'=限价单'market'=市价单默认为'limit'
Returns:
dict: 包含订单ID的字典
"""
# 确定价格类型 # 确定价格类型
price_type = xtconstant.FIX_PRICE # 默认为限价单 price_type = xtconstant.MARKET_BEST if order_type == 'market' else xtconstant.FIX_PRICE
if order_type == 'market':
# 市价单,根据不同市场选择合适的市价单类型
if code.startswith('1') or code.startswith('5'):
# 基金等可能需要不同的市价单类型
price_type = xtconstant.MARKET_BEST
else:
price_type = xtconstant.MARKET_BEST # 市价最优价
# 如果是市价单价格可以设为0 # 如果是市价单价格可以设为0
if price_type != xtconstant.FIX_PRICE: if price_type != xtconstant.FIX_PRICE:
@ -229,40 +245,12 @@ class XtTrader:
) )
return {"order_id": order_id} return {"order_id": order_id}
def cancel(self, entrust_no): def cancel(self, order_id):
# 撤单接口需要订单编号 # 撤单接口需要订单编号
result = self.xt_trader.cancel_order_stock(self.account, int(entrust_no)) result = self.xt_trader.cancel_order_stock(self.account, int(order_id))
return {"cancel_result": result} return {"cancel_result": result}
def is_trading_time(self):
"""判断当前是否为交易时间
Returns:
bool: True 表示当前为交易时间False 表示当前休市
"""
try:
now = dt.datetime.now()
# 判断是否为工作日(使用 chinese_calendar 判断,会考虑节假日和调休)
if not is_workday(now):
return False
# 判断是否在交易时间段内
current_time = now.time()
morning_start = dt.time(9, 30) # 上午开市时间 9:30
morning_end = dt.time(11, 30) # 上午休市时间 11:30
afternoon_start = dt.time(13, 0) # 下午开市时间 13:00
afternoon_end = dt.time(15, 0) # 下午休市时间 15:00
# 判断是否在上午或下午的交易时段
is_morning_session = morning_start <= current_time <= morning_end
is_afternoon_session = afternoon_start <= current_time <= afternoon_end
return is_morning_session or is_afternoon_session
except Exception as e:
logger.error(f"判断交易时间发生错误: {str(e)}")
return False
if __name__ == "__main__": if __name__ == "__main__":
trader = XtTrader() trader = XtTrader()
@ -270,4 +258,4 @@ if __name__ == "__main__":
logger.info(f"账户余额: {trader.get_balance()}") logger.info(f"账户余额: {trader.get_balance()}")
logger.info(f"持仓: {trader.get_positions()}") logger.info(f"持仓: {trader.get_positions()}")
logger.info(f"当日成交: {trader.get_today_trades()}") logger.info(f"当日成交: {trader.get_today_trades()}")
logger.info(f"当日委托: {trader.get_today_entrust()}") logger.info(f"当日委托: {trader.get_today_orders()}")