refactor trader

This commit is contained in:
zhiyong 2025-05-01 05:33:28 +08:00
parent 5823967305
commit a407ce1f2f
7 changed files with 267 additions and 89 deletions

190
src/base_trader.py Normal file
View File

@ -0,0 +1,190 @@
import datetime as dt
from chinese_calendar import is_workday
from abc import ABC, abstractmethod
from logger_config import get_logger
# 获取日志记录器
logger = get_logger('base_trader')
class BaseTrader(ABC):
"""交易基类,定义交易接口的通用方法"""
def __init__(self):
"""初始化交易基类"""
pass
@abstractmethod
def is_logged_in(self):
"""检查交易系统是否已经登录
Returns:
bool: True表示已登录False表示未登录
"""
pass
@abstractmethod
def login(self):
"""登录交易系统
Returns:
bool: 登录是否成功
"""
pass
@abstractmethod
def logout(self):
"""登出交易系统"""
pass
@abstractmethod
def get_balance(self):
"""获取账户资金情况
Returns:
dict: 账户资金信息若失败返回None
"""
pass
@abstractmethod
def get_positions(self):
"""获取所有持仓
Returns:
list: 持仓列表若无持仓返回空列表
"""
pass
@abstractmethod
def get_position(self, stock_code):
"""查询指定股票代码的持仓信息
Args:
stock_code: 股票代码例如 "600000.SH"
Returns:
dict: 持仓详情如果未持有则返回None
"""
pass
@abstractmethod
def get_today_trades(self):
"""获取当日成交
Returns:
list: 成交列表若无成交返回空列表
"""
pass
@abstractmethod
def get_today_orders(self):
"""获取当日委托
Returns:
list: 委托列表若无委托返回空列表
"""
pass
@abstractmethod
def get_order(self, order_id):
"""查询指定订单ID的详细信息
Args:
order_id: 订单ID
Returns:
dict: 订单详情如果未找到则返回None
"""
pass
@abstractmethod
def buy(self, code, price, amount, order_type='limit'):
"""买入股票
Args:
code: 股票代码
price: 买入价格市价单时可为0
amount: 买入数量
order_type: 订单类型'limit'=限价单'market'=市价单默认为'limit'
Returns:
dict: 包含订单ID的字典
"""
pass
@abstractmethod
def sell(self, code, price, amount, order_type='limit'):
"""卖出股票
Args:
code: 股票代码
price: 卖出价格市价单时可为0
amount: 卖出数量
order_type: 订单类型'limit'=限价单'market'=市价单默认为'limit'
Returns:
dict: 包含订单ID的字典
"""
pass
@abstractmethod
def cancel(self, order_id):
"""撤销订单
Args:
order_id: 订单ID
Returns:
dict: 撤单结果
"""
pass
def is_trading_time(self):
"""判断当前是否为交易时间
Returns:
bool: True 表示当前为交易时间False 表示当前休市
"""
try:
now = dt.datetime.now()
# 先判断是否为交易日
if not self.is_trading_date():
return False
# 判断是否在交易时间段内
current_time = now.time()
morning_start = dt.time(9, 30) # 上午开市时间 9:30
morning_end = dt.time(11, 30) # 上午休市时间 11:30
afternoon_start = dt.time(13, 0) # 下午开市时间 13:00
afternoon_end = dt.time(15, 0) # 下午休市时间 15:00
# 判断是否在上午或下午的交易时段
is_morning_session = morning_start <= current_time <= morning_end
is_afternoon_session = afternoon_start <= current_time <= afternoon_end
return is_morning_session or is_afternoon_session
except Exception as e:
logger.error(f"判断交易时间发生错误: {str(e)}")
return False
def is_trading_date(self, date=None):
"""判断指定日期是否为交易日
Args:
date: 日期默认为当前日期
Returns:
bool: True 表示是交易日False 表示非交易日
"""
try:
# 如果未指定日期,使用当前日期
if date is None:
date = dt.datetime.now()
# 使用 chinese_calendar 判断是否为工作日(考虑节假日和调休)
return is_workday(date)
except Exception as e:
logger.error(f"判断交易日期发生错误: {str(e)}")
return False

View File

@ -92,7 +92,7 @@ def get_today_trades() -> dict:
response = requests.get(f"{URL}/todaytrades")
return response.json()
def get_today_entrust() -> dict:
def get_today_orders() -> dict:
"""获取今日委托记录(仅实盘模式)
Returns:
@ -183,7 +183,7 @@ if __name__ == "__main__":
print("今日成交:", trades)
# 示例:获取今日委托记录
entrusts = get_today_entrust()
entrusts = get_today_orders()
print("今日委托:", entrusts)
# 示例:再次查询持仓变化

View File

@ -153,7 +153,7 @@ class RealTraderManager:
# 获取最新的委托列表
try:
entrusts = self.trader.get_today_entrust()
entrusts = self.trader.get_today_orders()
if entrusts is None:
logger.error("获取今日委托失败,跳过本次检查")
return
@ -224,7 +224,7 @@ class RealTraderManager:
try:
# 如果没有提供委托字典,则获取当前委托
if entrust_map is None:
entrusts = self.trader.get_today_entrust()
entrusts = self.trader.get_today_orders()
entrust_map = {str(e['order_id']): e for e in entrusts}
# 查找对应的委托记录

View File

@ -119,7 +119,7 @@ class SimulationTrader:
self.logger.info(message)
return []
def get_today_entrust(self):
def get_today_orders(self):
message = "模拟交易:查询今日委托"
self.logger.info(message)
return []

View File

@ -102,7 +102,7 @@ class StrategyPositionManager:
trader_type = StrategyPositionManager.get_trader_type(trader)
# 获取今日委托
today_entrusts = trader.get_today_entrust()
today_entrusts = trader.get_today_orders()
# 更新委托状态
for order_id, order_info in list(pending_orders[trader_type].items()):

View File

@ -498,13 +498,13 @@ def get_today_trades():
@app.route("/yu/todayentrust", methods=["GET"])
def get_today_entrust():
def get_today_orders():
"""Get the today's entrust of the account."""
logger.info("Received today entrust request")
try:
# 直接使用实盘交易实例,不考虑模拟盘
trader = get_real_trader()
entrust = trader.get_today_entrust()
entrust = trader.get_today_orders()
logger.info(f"今日委托: {entrust}")
return jsonify({"success": True, "data": entrust, "simulation": False}), 200
@ -620,7 +620,7 @@ def get_order_status():
trader = get_real_trader()
try:
entrusts = execute_with_timeout(trader.get_today_entrust, Config.TRADE_TIMEOUT)
entrusts = execute_with_timeout(trader.get_today_orders, Config.TRADE_TIMEOUT)
if entrusts is None:
logger.error("获取今日委托超时")
return jsonify({"success": False, "error": "获取今日委托超时", "simulation": False}), 500
@ -632,7 +632,7 @@ def get_order_status():
# 模拟交易模式
trader = get_sim_trader()
try:
entrusts = trader.get_today_entrust()
entrusts = trader.get_today_orders()
return jsonify({"success": True, "data": entrusts, "simulation": True}), 200
except Exception as e:
logger.error(f"获取今日委托时出错: {str(e)}")

View File

@ -1,12 +1,11 @@
import os
import random
from config import Config
from base_trader import BaseTrader
from xtquant.xttrader import XtQuantTrader
from xtquant.xttype import StockAccount
from xtquant import xtconstant
from xtquant.xtdata import get_instrument_detail, get_trading_time
import datetime as dt
from chinese_calendar import is_workday
from xtquant.xtdata import get_instrument_detail
from logger_config import get_logger
# 获取日志记录器
@ -38,8 +37,12 @@ class MyXtQuantTraderCallback:
def on_smt_appointment_async_response(self, response):
logger.info(f"约券异步反馈: {response.seq}")
class XtTrader:
class XtTrader(BaseTrader):
def __init__(self):
super().__init__()
self.started = False
self.connected = False
self.subscribed = False
self._ACCOUNT = Config.XT_ACCOUNT
self._PATH = Config.XT_PATH
self._SESSION_ID = random.randint(100000, 99999999)
@ -50,9 +53,6 @@ class XtTrader:
self.xt_trader = XtQuantTrader(self._PATH, self._SESSION_ID)
self.account = StockAccount(self._ACCOUNT, self._account_type)
self.xt_trader.register_callback(self._callback)
self.started = False
self.connected = False
self.subscribed = False
def is_logged_in(self):
"""检查交易系统是否已经登录
@ -61,7 +61,7 @@ class XtTrader:
bool: True表示已登录False表示未登录
"""
return self.started and self.connected and self.subscribed
def login(self):
if not self.started:
self.xt_trader.start()
@ -112,6 +112,23 @@ class XtTrader:
]
return []
def get_position(self, stock_code):
position = self.xt_trader.query_stock_position(self.account, stock_code)
if position:
return {
"account_id": position.account_id,
"stock_code": position.stock_code,
"volume": position.volume,
"can_use_volume": position.can_use_volume,
"open_price": position.open_price,
"market_value": position.market_value,
"frozen_volume": position.frozen_volume,
"on_road_volume": position.on_road_volume,
"yesterday_volume": position.yesterday_volume,
"avg_price": position.avg_price
}
return None
def get_today_trades(self):
trades = self.xt_trader.query_stock_trades(self.account)
if trades:
@ -119,7 +136,7 @@ class XtTrader:
{
"account_id": t.account_id,
"stock_code": t.stock_code,
"stock_name": get_instrument_detail(t.stock_code)["InstrumentName"] if get_instrument_detail(t.stock_code) else "",
"stock_name": self.get_stock_name(t.stock_code),
"order_id": t.order_id,
"traded_id": t.traded_id,
"traded_time": t.traded_time,
@ -131,7 +148,7 @@ class XtTrader:
]
return []
def get_today_entrust(self):
def get_today_orders(self):
orders = self.xt_trader.query_stock_orders(self.account)
if orders:
return [
@ -152,6 +169,25 @@ class XtTrader:
]
return []
def get_order(self, order_id):
order = self.xt_trader.query_stock_order(self.account, int(order_id))
if order:
return {
"account_id": order.account_id,
"stock_code": order.stock_code,
"order_id": order.order_id,
"order_time": order.order_time,
"order_type": "buy" if order.order_type == xtconstant.STOCK_BUY else "sell",
"order_volume": order.order_volume,
"price_type": self._convert_price_type(order.price_type),
"price": order.price,
"traded_volume": order.traded_volume,
"traded_price": order.traded_price,
"order_status": order.order_status,
"status_msg": order.status_msg
}
return None
def _convert_price_type(self, price_type):
"""Convert numeric price type to readable string"""
price_type_map = {
@ -165,29 +201,28 @@ class XtTrader:
}
return price_type_map.get(price_type, f"unknown_{price_type}")
def buy(self, code, price, amount, order_type='limit'):
"""买入股票
def get_stock_name(self, stock_code):
"""获取股票名称
Args:
code: 股票代码
price: 买入价格市价单时可为0
amount: 买入数量
order_type: 订单类型'limit'=限价单'market'=市价单默认为'limit'
stock_code: 股票代码例如 "600000.SH"
Returns:
dict: 包含订单ID的字典
str: 股票名称如果获取失败则返回空字符串
"""
try:
instrument_info = get_instrument_detail(stock_code)
if instrument_info and "InstrumentName" in instrument_info:
return instrument_info["InstrumentName"]
return ""
except Exception as e:
logger.error(f"获取股票名称失败: {stock_code}, {str(e)}")
return ""
def buy(self, code, price, amount, order_type='limit'):
# 确定价格类型
price_type = xtconstant.FIX_PRICE # 默认为限价单
if order_type == 'market':
# 市价单,根据不同市场选择合适的市价单类型
if code.startswith('1') or code.startswith('5'):
# 基金等可能需要不同的市价单类型
price_type = xtconstant.MARKET_BEST
else:
price_type = xtconstant.MARKET_BEST # 市价最优价
price_type = xtconstant.MARKET_BEST if order_type == 'market' else xtconstant.FIX_PRICE
# 如果是市价单价格可以设为0
if price_type != xtconstant.FIX_PRICE:
price = 0
@ -198,27 +233,8 @@ class XtTrader:
return {"order_id": order_id}
def sell(self, code, price, amount, order_type='limit'):
"""卖出股票
Args:
code: 股票代码
price: 卖出价格市价单时可为0
amount: 卖出数量
order_type: 订单类型'limit'=限价单'market'=市价单默认为'limit'
Returns:
dict: 包含订单ID的字典
"""
# 确定价格类型
price_type = xtconstant.FIX_PRICE # 默认为限价单
if order_type == 'market':
# 市价单,根据不同市场选择合适的市价单类型
if code.startswith('1') or code.startswith('5'):
# 基金等可能需要不同的市价单类型
price_type = xtconstant.MARKET_BEST
else:
price_type = xtconstant.MARKET_BEST # 市价最优价
price_type = xtconstant.MARKET_BEST if order_type == 'market' else xtconstant.FIX_PRICE
# 如果是市价单价格可以设为0
if price_type != xtconstant.FIX_PRICE:
@ -229,40 +245,12 @@ class XtTrader:
)
return {"order_id": order_id}
def cancel(self, entrust_no):
def cancel(self, order_id):
# 撤单接口需要订单编号
result = self.xt_trader.cancel_order_stock(self.account, int(entrust_no))
result = self.xt_trader.cancel_order_stock(self.account, int(order_id))
return {"cancel_result": result}
def is_trading_time(self):
"""判断当前是否为交易时间
Returns:
bool: True 表示当前为交易时间False 表示当前休市
"""
try:
now = dt.datetime.now()
# 判断是否为工作日(使用 chinese_calendar 判断,会考虑节假日和调休)
if not is_workday(now):
return False
# 判断是否在交易时间段内
current_time = now.time()
morning_start = dt.time(9, 30) # 上午开市时间 9:30
morning_end = dt.time(11, 30) # 上午休市时间 11:30
afternoon_start = dt.time(13, 0) # 下午开市时间 13:00
afternoon_end = dt.time(15, 0) # 下午休市时间 15:00
# 判断是否在上午或下午的交易时段
is_morning_session = morning_start <= current_time <= morning_end
is_afternoon_session = afternoon_start <= current_time <= afternoon_end
return is_morning_session or is_afternoon_session
except Exception as e:
logger.error(f"判断交易时间发生错误: {str(e)}")
return False
if __name__ == "__main__":
trader = XtTrader()
@ -270,4 +258,4 @@ if __name__ == "__main__":
logger.info(f"账户余额: {trader.get_balance()}")
logger.info(f"持仓: {trader.get_positions()}")
logger.info(f"当日成交: {trader.get_today_trades()}")
logger.info(f"当日委托: {trader.get_today_entrust()}")
logger.info(f"当日委托: {trader.get_today_orders()}")