update position manager

This commit is contained in:
zhiyong 2025-05-09 18:31:41 +08:00
parent 978834772b
commit a9f654d359
11 changed files with 11473 additions and 253 deletions

View File

@ -1,6 +0,0 @@
def main():
print("Hello from real-trader!")
if __name__ == "__main__":
main()

11160
resources/grouped_etf.json Normal file

File diff suppressed because it is too large Load Diff

16
src/local_order.py Normal file
View File

@ -0,0 +1,16 @@
from trade_constants import ORDER_STATUS_PENDING
from datetime import datetime
class LocalOrder:
def __init__(self, order_id, code, price, amount, direction, order_type='limit'):
self.order_id = order_id
self.code = code
self.price = price
self.amount = amount
self.filled = 0
self.direction = direction
self.order_type = order_type
self.status = ORDER_STATUS_PENDING
self.created_time = datetime.now()

5
src/local_position.py Normal file
View File

@ -0,0 +1,5 @@
class LocalPosition:
def __init__(self, code, total_amount, closeable_amount):
self.code = code
self.total_amount = total_amount
self.closeable_amount = closeable_amount

View File

@ -1,7 +1,10 @@
import os
import json
from logger_config import get_logger
from trade_constants import ORDER_DIRECTION_BUY
from trade_constants import ORDER_DIRECTION_BUY, ORDER_TYPE_LIMIT, ORDER_TYPE_MARKET
from local_position import LocalPosition
from local_order import LocalOrder
from t0_stocks import is_t0
# 获取日志记录器
logger = get_logger('position_manager')
@ -9,129 +12,76 @@ logger = get_logger('position_manager')
class PositionManager():
"""实盘策略持仓管理器,负责管理不同策略在实盘环境下的持仓情况"""
def __init__(self, trade_type):
def __init__(self, strategy_name="default_strategy"):
"""初始化实盘持仓管理器"""
super().__init__()
self.strategy_name = strategy_name
# 策略持仓信息
self.positions = {} # 策略名 -> {股票代码 -> {total_amount, closeable_amount}}
self.positions = {} # {股票代码 -> LocalPosition}
# 待处理订单信息
self.pending_orders = {} # order_id -> 订单信息
self.data_path = trade_type + '_strategy_positions.json'
self.pending_orders = {} # {order_id -> LocalOrder}
self.data_path = self.strategy_name + '_positions.json'
self.load_data()
def update_position(self, strategy_name, code, direction, amount):
"""更新策略持仓
Args:
strategy_name: 策略名称
code: 股票代码
direction: 'buy''sell'
amount: 交易数量
"""
if not strategy_name:
return
# 确保策略在字典中
if strategy_name not in self.positions:
self.positions[strategy_name] = {}
def update_position(self, code, direction, amount):
# 如果股票代码在持仓字典中不存在,初始化它
if code not in self.positions[strategy_name]:
self.positions[strategy_name][code] = {
'total_amount': 0,
'closeable_amount': 0
}
if code not in self.positions:
self.positions[code] = LocalPosition(code, 0, 0)
# 根据方向更新持仓
position = self.positions[code]
is_t0_stock = is_t0(code)
if direction == ORDER_DIRECTION_BUY:
self.positions[strategy_name][code]['total_amount'] += amount
self.positions[strategy_name][code]['closeable_amount'] += amount
position.total_amount += amount
if is_t0_stock:
position.closeable_amount += amount
else: # sell
self.positions[strategy_name][code]['total_amount'] -= amount
self.positions[strategy_name][code]['closeable_amount'] -= amount
position.total_amount -= amount
position.closeable_amount -= amount
logger.info(f"更新策略持仓 - 策略: {strategy_name}, 代码: {code}, 方向: {direction}, 数量: {amount}, "
f"更新后总量: {self.positions[strategy_name][code]['total_amount']}, "
f"可用: {self.positions[strategy_name][code]['closeable_amount']}")
logger.info(f"更新策略持仓 - 策略: {self.strategy_name}, 代码: {code}, 方向: {direction}, 数量: {amount}, "
f"更新后总量: {position.total_amount}, "
f"可用: {position.closeable_amount}")
# 移除total_amount为0的持仓
if code in self.positions[strategy_name] and self.positions[strategy_name][code]['total_amount'] <= 0:
del self.positions[strategy_name][code]
logger.info(f"移除空持仓 - 策略: {strategy_name}, 代码: {code}")
if code in self.positions and self.positions[code].total_amount <= 0:
del self.positions[code]
logger.info(f"移除空持仓 - 策略: {self.strategy_name}, 代码: {code}")
def add_pending_order(self, order_id, strategy_name, code, price, amount, direction, order_type='limit'):
"""添加未完成委托
def add_pending_order(self, order_id, code, price, amount, direction, order_type=ORDER_TYPE_LIMIT):
if not self.strategy_name:
return
Args:
order_id: 订单ID
strategy_name: 策略名称
code: 股票代码
price: 委托价格
amount: 委托数量
direction: 交易方向
order_type: 订单类型
"""
# 添加未处理订单
self.pending_orders[order_id] = {
'strategy_name': strategy_name,
'code': code,
'price': price,
'target_amount': amount,
'direction': direction,
'order_type': order_type,
'status': 'pending',
'created_time': self._get_current_time(),
'retry_count': 0
}
order = LocalOrder(order_id, code, price, amount, direction, order_type)
self.pending_orders[order_id] = order
logger.info(f"添加未完成委托 - ID: {order_id}, 策略: {strategy_name}, 代码: {code}, 方向: {direction}, "
logger.info(f"添加订单 - ID: {order_id}, 策略: {self.strategy_name}, 代码: {code}, 方向: {direction}, "
f"数量: {amount}, 价格: {price}, 类型: {order_type}")
def update_order_status(self, order_id, new_status):
"""更新订单状态
Args:
order_id: 订单ID
new_status: 新状态
Returns:
bool: 是否成功更新
"""
def update_order_status(self, order_id, filled,new_status):
if order_id in self.pending_orders:
_order = self.pending_orders[order_id]
# 记录之前的状态用于日志
previous_status = self.pending_orders[order_id].get('status')
previous_status = _order.status
# 更新状态和最后检查时间
self.pending_orders[order_id]['status'] = new_status
# 更新状态
_order.status = new_status
_order.filled = filled
# 记录状态变化日志
if previous_status != new_status:
code = self.pending_orders[order_id].get('code')
code = self.pending_orders[order_id].code
logger.info(f"订单状态变化: ID={order_id}, 代码={code}, 旧状态={previous_status}, 新状态={new_status}")
# 如果订单已完成,移除它
if new_status in ['completed', 'cancelled', 'failed']:
# 保留订单信息以供参考,但标记为已完成
self.remove_pending_order(order_id)
del self.pending_orders[order_id]
logger.info(f"订单已删除 - ID: {order_id}, 状态: {new_status}")
return True
return False
def remove_pending_order(self, order_id):
"""移除未完成委托
Args:
order_id: 订单ID
Returns:
bool: 是否成功移除
"""
if order_id in self.pending_orders:
del self.pending_orders[order_id]
return True
return False
def get_pending_order(self, order_id):
"""获取未完成委托信息
@ -151,7 +101,7 @@ class PositionManager():
"""
return self.pending_orders
def get_positions(self, strategy_name=None):
def get_positions(self):
"""获取策略持仓
Args:
@ -160,19 +110,38 @@ class PositionManager():
Returns:
dict: 策略持仓信息
"""
if strategy_name:
if strategy_name not in self.positions:
return {}
return self.positions[strategy_name]
return self.positions
def save_data(self):
"""保存策略数据"""
try:
# 将对象转换为可序列化的字典
positions_dict = {}
for code, pos in self.positions.items():
positions_dict[code] = {
'code': pos.code,
'total_amount': pos.total_amount,
'closeable_amount': pos.closeable_amount
}
pending_orders_dict = {}
for order_id, order in self.pending_orders.items():
pending_orders_dict[order_id] = {
'order_id': order.order_id,
'code': order.code,
'price': order.price,
'amount': order.amount,
'filled': order.filled,
'direction': order.direction,
'order_type': order.order_type,
'status': order.status,
'created_time': order.created_time.isoformat() if hasattr(order, 'created_time') else None
}
with open(self.data_path, 'w') as f:
json.dump({
'positions': self.positions,
'pending_orders': self.pending_orders
'positions': positions_dict,
'pending_orders': pending_orders_dict
}, f)
logger.info("成功保存实盘策略数据")
except Exception as e:
@ -182,10 +151,42 @@ class PositionManager():
"""加载策略数据"""
try:
if os.path.exists(self.data_path):
from datetime import datetime
with open(self.data_path, 'r') as f:
data = json.load(f)
self.positions = data.get('positions', {})
self.pending_orders = data.get('pending_orders', {})
# 还原positions对象
self.positions = {}
positions_dict = data.get('positions', {})
for code, pos_data in positions_dict.items():
self.positions[code] = LocalPosition(
pos_data['code'],
pos_data['total_amount'],
pos_data['closeable_amount']
)
# 还原pending_orders对象
self.pending_orders = {}
pending_orders_dict = data.get('pending_orders', {})
for order_id, order_data in pending_orders_dict.items():
order = LocalOrder(
order_data['order_id'],
order_data['code'],
order_data['price'],
order_data['amount'],
order_data['direction'],
order_data['order_type']
)
order.filled = order_data['filled']
order.status = order_data['status']
if order_data.get('created_time'):
try:
order.created_time = datetime.fromisoformat(order_data['created_time'])
except (ValueError, TypeError):
order.created_time = datetime.now()
self.pending_orders[order_id] = order
logger.info("已加载实盘策略数据")
logger.info(f"策略数: {len(self.positions)}")
@ -199,60 +200,8 @@ class PositionManager():
self.positions = {}
self.pending_orders = {}
def _get_current_time(self):
"""获取当前时间戳"""
import time
return time.time()
def clean_timeout_orders(self):
"""清理超时未完成订单"""
timeout_limit = 24 * 60 * 60 # 24小时
current_time = self._get_current_time()
timeout_orders = []
for order_id, order_info in list(self.pending_orders.items()):
# 检查是否超时
if current_time - order_info['created_time'] > timeout_limit:
timeout_orders.append(order_id)
# 更新状态
self.update_order_status(order_id, 'failed')
if timeout_orders:
logger.warn(f"清理超时订单完成,共 {len(timeout_orders)} 个: {', '.join(timeout_orders)}")
def clear_strategy(self, strategy_name):
"""清除指定策略的持仓管理数据
Args:
strategy_name: 策略名称
Returns:
tuple: (success, message)
success: 是否成功清除
message: 提示信息
"""
if not strategy_name:
return False, "缺少策略名称参数"
# 检查策略是否存在
if strategy_name in self.positions:
# 从策略持仓字典中删除该策略
del self.positions[strategy_name]
# 清除该策略的交易记录
if strategy_name in self.trades:
del self.trades[strategy_name]
# 清除与该策略相关的未完成委托
for order_id, order_info in list(self.pending_orders.items()):
if order_info.get('strategy_name') == strategy_name:
del self.pending_orders[order_id]
# 保存更新后的策略数据
self.save_data()
logger.info(f"成功清除策略持仓数据: {strategy_name}")
return True, f"成功清除策略 '{strategy_name}' 的持仓数据"
else:
logger.info(f"策略不存在或没有持仓数据: {strategy_name}")
return True, f"策略 '{strategy_name}' 不存在或没有持仓数据"
def clear(self):
"""清除所有持仓管理数据"""
self.positions = {}
self.pending_orders = {}
self.save_data()

4
src/settlement_type.py Normal file
View File

@ -0,0 +1,4 @@
from enum import Enum
class SettlementType(Enum):
T0 = 0
T1 = 1

View File

@ -1,13 +1,13 @@
from logger_config import get_logger
from trade_constants import TRADE_TYPE_SIMULATION, ORDER_DIRECTION_BUY, ORDER_DIRECTION_SELL
from position_manager import PositionManager
class SimulationTrader:
def __init__(self, logger=None):
self.logger = logger or get_logger('simulation_trader')
# 添加模拟持仓字典,用于追踪模拟交易的持仓
self.sim_positions = {}
# 模拟资金账户信息
self.sim_balance = {"cash": 1000000.00, "frozen": 0.00, "total": 1000000.00}
self.position_manager = PositionManager(TRADE_TYPE_SIMULATION)
def is_logged_in(self):
"""检查交易系统是否已经登录
@ -24,66 +24,59 @@ class SimulationTrader:
self.logger.info("模拟交易:登出成功")
return True
def buy(self, code, price, amount):
message = f"模拟买入 - 代码: {code}, 价格: {price}, 数量: {amount}"
def buy(self, code, price, amount, strategy_name = "default_strategy"):
message = f"模拟买入 - 代码: {code}, 价格: {price}, 数量: {amount}, 策略: {strategy_name}"
self.logger.info(message)
# 更新模拟持仓
if code not in self.sim_positions:
self.sim_positions[code] = {
"stock_code": code,
"volume": 0,
"can_use_volume": 0,
"frozen_volume": 0,
"avg_price": 0.0,
"market_value": 0.0
}
# 计算交易成本
cost = price * amount
# 计算新的平均成本
current_cost = self.sim_positions[code]["avg_price"] * self.sim_positions[code]["volume"]
new_cost = price * amount
total_volume = self.sim_positions[code]["volume"] + amount
# 更新持仓信息
self.sim_positions[code]["volume"] += amount
self.sim_positions[code]["can_use_volume"] += amount
self.sim_positions[code]["avg_price"] = (current_cost + new_cost) / total_volume if total_volume > 0 else 0
self.sim_positions[code]["market_value"] = self.sim_positions[code]["volume"] * price
# 检查余额是否足够
if self.sim_balance["cash"] < cost:
message = f"模拟买入失败 - 代码: {code}, 资金不足"
self.logger.warning(message)
return {"order_id": None, "message": message, "success": False}
# 更新资金
self.sim_balance["cash"] -= price * amount
self.sim_balance["total"] = self.sim_balance["cash"] + sum(pos["market_value"] for pos in self.sim_positions.values())
self.sim_balance["cash"] -= cost
return {"order_id": "simulation", "message": message}
# 更新持仓管理器
self.position_manager.update_position(strategy_name, code, ORDER_DIRECTION_BUY, amount)
def sell(self, code, price, amount):
message = f"模拟卖出 - 代码: {code}, 价格: {price}, 数量: {amount}"
# 更新总资产
self._update_total_assets()
return {"order_id": "simulation", "message": message, "success": True}
def sell(self, code, price, amount, strategy_name = "default_strategy"):
message = f"模拟卖出 - 代码: {code}, 价格: {price}, 数量: {amount}, 策略: {strategy_name}"
self.logger.info(message)
# 更新模拟持仓
if code in self.sim_positions:
# 确保可用数量足够
if self.sim_positions[code]["can_use_volume"] >= amount:
# 更新持仓信息
self.sim_positions[code]["volume"] -= amount
self.sim_positions[code]["can_use_volume"] -= amount
self.sim_positions[code]["market_value"] = self.sim_positions[code]["volume"] * price
# 获取策略持仓
strategy_positions = self.position_manager.get_positions(strategy_name)
# 如果持仓为0删除该股票
if self.sim_positions[code]["volume"] <= 0:
del self.sim_positions[code]
# 更新资金
self.sim_balance["cash"] += price * amount
self.sim_balance["total"] = self.sim_balance["cash"] + sum(pos["market_value"] for pos in self.sim_positions.values())
else:
message = f"模拟卖出失败 - 代码: {code}, 可用数量不足"
self.logger.warning(message)
else:
message = f"模拟卖出失败 - 代码: {code}, 无持仓"
# 检查持仓是否足够
if code not in strategy_positions or strategy_positions[code]['closeable_amount'] < amount:
message = f"模拟卖出失败 - 代码: {code}, 可用数量不足"
self.logger.warning(message)
return {"order_id": None, "message": message, "success": False}
return {"order_id": "simulation", "message": message}
# 更新资金
proceeds = price * amount
self.sim_balance["cash"] += proceeds
# 更新持仓管理器
self.position_manager.update_position(strategy_name, code, ORDER_DIRECTION_SELL, amount)
# 更新总资产
self._update_total_assets()
return {"order_id": "simulation", "message": message, "success": True}
def _update_total_assets(self):
"""更新总资产"""
# 此处简化处理,在实际情况中应该计算所有持仓的市值
self.sim_balance["total"] = self.sim_balance["cash"]
def cancel(self, entrust_no):
message = f"模拟撤单 - 委托号: {entrust_no}"
@ -98,20 +91,25 @@ class SimulationTrader:
def get_positions(self):
message = "模拟交易:查询持仓"
self.logger.info(message)
# 返回与XtTrader格式一致的持仓数据
return [
{
"account_id": "simulation",
"stock_code": code,
"volume": pos["volume"],
"can_use_volume": pos["can_use_volume"],
"open_price": pos["avg_price"],
"avg_price": pos["avg_price"],
"market_value": pos["market_value"],
"frozen_volume": pos["frozen_volume"],
"on_road_volume": 0
} for code, pos in self.sim_positions.items()
]
# 从持仓管理器获取所有策略的持仓
all_positions = []
for strategy_name, positions in self.position_manager.get_positions().items():
for code, position_info in positions.items():
all_positions.append({
"account_id": "simulation",
"stock_code": code,
"strategy_name": strategy_name,
"volume": position_info["total_amount"],
"can_use_volume": position_info["closeable_amount"],
"open_price": 0.0, # 持仓管理器中没有记录价格信息
"avg_price": 0.0, # 持仓管理器中没有记录价格信息
"market_value": 0.0, # 持仓管理器中没有记录价格信息
"frozen_volume": 0,
"on_road_volume": 0
})
return all_positions
def get_today_trades(self):
message = "模拟交易:查询今日成交"

117
src/t0_stocks.py Normal file
View File

@ -0,0 +1,117 @@
# 读取所有ETF文件 /resources/grouped_etf.json
import json
import os
from typing import List
def get_all_t0() -> List[str]:
"""
读取/resources/grouped_etf.json文件获取所有T+0交易的ETF代码
除了"其他ETF"分类外其余都是T+0
Returns:
List[str]: 所有T+0交易的ETF代码列表
"""
# 获取当前文件所在目录的路径
current_dir = os.path.dirname(os.path.abspath(__file__))
# 构建json文件的绝对路径
json_path = os.path.join(os.path.dirname(current_dir), 'resources', 'grouped_etf.json')
# 读取json文件
with open(json_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# 初始化结果列表
t0_stocks = []
# 遍历所有分类
for category, stocks in data.items():
# 跳过"其他ETF"分类
if category == "其他ETF":
continue
# 将当前分类下的所有股票代码添加到结果列表中
for stock in stocks:
t0_stocks.append(stock["code"])
return t0_stocks
def normalize_stock_code(stock: str) -> str:
"""
标准化股票代码格式
Args:
stock (str): 可能是"123456.XSHE""123456.SH""123456"格式的股票代码
Returns:
str: 标准化后的股票代码格式为"123456.XSHE""123456.XSHG"
"""
if '.' not in stock:
# 如果没有后缀,则根据第一位数字判断交易所
code = stock.strip()
if code[0] in ['0', '3']: # 深交所
return f"{code}.XSHE"
else: # 上交所
return f"{code}.XSHG"
else:
# 已经有后缀,判断是否需要转换
code, exchange = stock.split('.')
if exchange.upper() in ['SH', 'XSHG']:
return f"{code}.XSHG"
elif exchange.upper() in ['SZ', 'XSHE']:
return f"{code}.XSHE"
else:
# 已经是标准格式或其他格式,直接返回
return stock
def is_t0(stock: str) -> bool:
"""
判断给定的股票代码是否属于T+0交易的ETF
Args:
stock (str): 股票代码可能是"123456.XSHE""123456.SH""123456"格式
Returns:
bool: 如果是T+0交易的ETF则返回True否则返回False
"""
# 获取所有T+0股票列表
t0_list = get_all_t0()
# 标准化输入的股票代码
normalized_stock = normalize_stock_code(stock)
# 判断标准化后的代码是否在T+0列表中
return normalized_stock in t0_list
def get_hk_t0() -> List[str]:
"""
获取所有T+0交易的香港ETF代码
Returns:
List[str]: 所有T+0交易的香港ETF代码列表
"""
# 获取当前文件所在目录的路径
current_dir = os.path.dirname(os.path.abspath(__file__))
# 构建json文件的绝对路径
json_path = os.path.join(os.path.dirname(current_dir), 'resources', 'grouped_etf.json')
# 读取json文件
with open(json_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# 初始化结果列表
hk_t0_stocks = []
# 检查"港股ETF"分类是否存在
if "港股ETF" in data:
# 获取所有港股ETF
for stock in data["港股ETF"]:
hk_t0_stocks.append(stock["code"])
return hk_t0_stocks
if __name__ == "__main__":
print(get_hk_t0())

1
src/trade_tools.py Normal file
View File

@ -0,0 +1 @@

View File

@ -1,24 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""测试导入新模块结构"""
import sys
import os
# 添加src目录到Python导入路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
try:
from simulation.simulation_trader import SimulationTrader
print("导入 SimulationTrader 成功!")
except Exception as e:
print(f"导入 SimulationTrader 失败: {e}")
try:
from real.xt_trader import XtTrader
print("导入 XtTrader 成功!")
except Exception as e:
print(f"导入 XtTrader 失败: {e}")
print("测试完成")