real_trader/src/base_trader.py

251 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import datetime as dt
from chinese_calendar import is_workday
from abc import ABC, abstractmethod
from logger_config import get_logger
from position_manager import PositionManager
from typing import Dict
from config import Config
import os
# 获取日志记录器
class BaseTrader(ABC):
"""交易基类,定义交易接口的通用方法"""
def __init__(self, logger):
"""初始化交易基类"""
self.position_managers: Dict[str, PositionManager] = {}
self.logger = logger
self._load_all_position_managers_from_data_dir() # 新增:自动加载所有持仓管理器
pass
@abstractmethod
def is_logged_in(self):
"""检查交易系统是否已经登录
Returns:
bool: True表示已登录False表示未登录
"""
pass
@abstractmethod
def login(self):
"""登录交易系统
Returns:
bool: 登录是否成功
"""
pass
@abstractmethod
def logout(self):
"""登出交易系统"""
pass
@abstractmethod
def get_balance(self):
"""获取账户资金情况
Returns:
dict: 账户资金信息若失败返回None
"""
pass
@abstractmethod
def get_positions(self):
"""获取所有持仓
Returns:
list: 持仓列表,若无持仓返回空列表
"""
pass
@abstractmethod
def get_position(self, stock_code):
"""查询指定股票代码的持仓信息
Args:
stock_code: 股票代码,例如 "600000.SH"
Returns:
dict: 持仓详情如果未持有则返回None
"""
pass
@abstractmethod
def get_today_trades(self):
"""获取当日成交
Returns:
list: 成交列表,若无成交返回空列表
"""
pass
@abstractmethod
def get_today_orders(self):
"""获取当日委托
Returns:
list: 委托列表,若无委托返回空列表
"""
pass
@abstractmethod
def get_order(self, order_id):
"""查询指定订单ID的详细信息
Args:
order_id: 订单ID
Returns:
dict: 订单详情如果未找到则返回None
"""
pass
@abstractmethod
def buy(self, code, price, amount, order_type='limit'):
"""买入股票
Args:
code: 股票代码
price: 买入价格市价单时可为0
amount: 买入数量
order_type: 订单类型,'limit'=限价单,'market'=市价单,默认为'limit'
Returns:
dict: 包含订单ID的字典
"""
pass
@abstractmethod
def sell(self, code, price, amount, order_type='limit'):
"""卖出股票
Args:
code: 股票代码
price: 卖出价格市价单时可为0
amount: 卖出数量
order_type: 订单类型,'limit'=限价单,'market'=市价单,默认为'limit'
Returns:
dict: 包含订单ID的字典
"""
pass
@abstractmethod
def cancel(self, order_id):
"""撤销订单
Args:
order_id: 订单ID
Returns:
dict: 撤单结果
"""
pass
@staticmethod
def is_trading_time():
"""判断当前是否为交易时间
Returns:
bool: True 表示当前为交易时间False 表示当前休市
"""
try:
now = dt.datetime.now()
# 先判断是否为交易日
if not BaseTrader.is_trading_date():
return False
# 判断是否在交易时间段内
current_time = now.time()
morning_start = dt.time(9, 30) # 上午开市时间 9:30
morning_end = dt.time(11, 30) # 上午休市时间 11:30
afternoon_start = dt.time(13, 0) # 下午开市时间 13:00
afternoon_end = dt.time(15, 0) # 下午休市时间 15:00
# 判断是否在上午或下午的交易时段
is_morning_session = morning_start <= current_time <= morning_end
is_afternoon_session = afternoon_start <= current_time <= afternoon_end
return is_morning_session or is_afternoon_session
except Exception as e:
logger = get_logger("BaseTrader")
logger.error(f"判断交易时间发生错误: {str(e)}")
return False
@staticmethod
def is_trading_date(date=None):
"""判断指定日期是否为交易日
Args:
date: 日期,默认为当前日期
Returns:
bool: True 表示是交易日False 表示非交易日
"""
try:
# 如果未指定日期,使用当前日期
if date is None:
date = dt.datetime.now()
# 使用 chinese_calendar 判断是否为工作日(考虑节假日和调休)
return is_workday(date)
except Exception as e:
logger = get_logger("BaseTrader")
logger.error(f"判断交易日期发生错误: {str(e)}")
return False
def get_position_manager(self, strategy_name) -> PositionManager:
"""获取指定策略的持仓管理器
Args:
strategy_name: 策略名称
Returns:
PositionManager: 指定策略的持仓管理器
"""
if strategy_name not in self.position_managers:
self.position_managers[strategy_name] = PositionManager(strategy_name)
return self.position_managers[strategy_name]
def get_all_position_managers(self) -> Dict[str, PositionManager]:
"""获取所有持仓管理器"""
return self.position_managers
def is_today(self, datetime: dt.datetime) -> bool:
"""判断指定日期是否为当前日期
Args:
datetime: 日期时间
Returns:
bool: True 表示是当前日期False 表示不是当前日期
"""
return datetime.date() == dt.datetime.now().date()
def clear_position_manager(self, strategy_name):
"""清除指定策略的持仓管理器"""
if strategy_name in self.position_managers:
self.position_managers[strategy_name].clear()
return True
return False
def _load_all_position_managers_from_data_dir(self):
"""从Config.DATA_DIR目录下的持仓文件自动加载所有PositionManager"""
data_dir = Config.DATA_DIR
if not os.path.exists(data_dir):
self.logger.info(f"持仓数据目录不存在: {data_dir}")
return
for fname in os.listdir(data_dir):
if fname.endswith('_positions.json'):
strategy_name = fname[:-len('_positions.json')]
try:
self.position_managers[strategy_name] = PositionManager(strategy_name)
self.logger.info(f"已自动加载策略持仓: {strategy_name}")
except Exception as e:
self.logger.error(f"加载策略持仓失败: {strategy_name}, 错误: {e}")