You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

596 lines
24 KiB

1 month ago
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Download stock data with akshare and forecast it with the Kronos model.

Usage (see main() at the bottom of this file):
    python <script> --test           # predict preset codes
    python <script> CODE [CODE ...]  # predict the given six-digit codes
    python <script>                  # interactive mode
"""
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import akshare as ak
import os
import sys
from datetime import datetime, timedelta
import warnings
import holidays
# NOTE(review): this silences *all* warnings process-wide (including pandas /
# torch deprecation warnings), which can hide real problems.
warnings.filterwarnings('ignore')
# Put the parent of this script's directory first on sys.path so the
# `from model import ...` below can find the project's local `model` package.
# NOTE(review): duplicate of the `import os` above; harmless.
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)
from model import Kronos, KronosTokenizer, KronosPredictor
# Font fallback list so Chinese chart labels render (first installed font wins).
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS', 'DejaVu Sans']
# Draw the minus sign as an ASCII hyphen; CJK fonts such as SimHei often lack
# the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False
  25. class InteractiveStockPredictor:
  26. """交互式股票预测器"""
  27. def __init__(self):
  28. """初始化预测器"""
  29. self.predictor = None
  30. self.model = None
  31. self.tokenizer = None
  32. self.device = "cuda:0" if self._check_cuda() else "cpu"
  33. print(f"使用设备: {self.device}")
  34. # 初始化中国节假日
  35. self.cn_holidays = holidays.China()
  36. def _check_cuda(self):
  37. """检查CUDA是否可用"""
  38. try:
  39. import torch
  40. return torch.cuda.is_available()
  41. except ImportError:
  42. return False
  43. def is_trading_day(self, date):
  44. """判断是否为交易日(排除周末和节假日)"""
  45. # 排除周末
  46. if date.weekday() >= 5: # 5=周六, 6=周日
  47. return False
  48. # 排除节假日
  49. if date in self.cn_holidays:
  50. return False
  51. return True
  52. def generate_trading_days(self, start_date, num_days):
  53. """生成指定数量的交易日"""
  54. trading_days = []
  55. current_date = start_date
  56. while len(trading_days) < num_days:
  57. if self.is_trading_day(current_date):
  58. trading_days.append(current_date)
  59. current_date += timedelta(days=1)
  60. return trading_days
  61. def load_models(self):
  62. """加载Kronos模型和分词器"""
  63. try:
  64. print("正在加载Kronos模型...")
  65. self.tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
  66. self.model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
  67. self.predictor = KronosPredictor(
  68. model=self.model,
  69. tokenizer=self.tokenizer,
  70. device=self.device,
  71. max_context=512
  72. )
  73. print("✅ 模型加载成功!")
  74. return True
  75. except Exception as e:
  76. print(f"❌ 模型加载失败: {str(e)}")
  77. return False
  78. def get_stock_codes(self):
  79. """获取股票代码列表"""
  80. print("\n" + "="*60)
  81. print("股票代码输入方式")
  82. print("="*60)
  83. print("1. 手动输入股票代码")
  84. print("2. 从TXT文件读取股票代码列表")
  85. print()
  86. while True:
  87. choice = input("请选择输入方式 (1/2): ").strip()
  88. if choice == '1':
  89. return self._get_manual_codes()
  90. elif choice == '2':
  91. return self._get_codes_from_file()
  92. else:
  93. print("❌ 请输入 1 或 2")
  94. def _get_manual_codes(self):
  95. """手动输入股票代码"""
  96. print("\n手动输入股票代码")
  97. print("-" * 30)
  98. print("格式说明:")
  99. print("- 多个股票代码用逗号分隔")
  100. print("- 支持A股代码格式: 600030, 002261, 688326, 300364")
  101. print("- 示例: 600030,002261")
  102. print()
  103. while True:
  104. stock_input = input("请输入股票代码: ").strip()
  105. if not stock_input:
  106. print("❌ 请输入有效的股票代码")
  107. continue
  108. # 解析股票代码
  109. stock_codes = [code.strip() for code in stock_input.split(',')]
  110. stock_codes = [code for code in stock_codes if code]
  111. if not stock_codes:
  112. print("❌ 请输入有效的股票代码")
  113. continue
  114. # 验证股票代码格式
  115. valid_codes = []
  116. for code in stock_codes:
  117. if code.isdigit() and len(code) == 6:
  118. valid_codes.append(code)
  119. else:
  120. print(f"⚠️ 股票代码 {code} 格式不正确,已跳过")
  121. if not valid_codes:
  122. print("❌ 没有有效的股票代码")
  123. continue
  124. return valid_codes
  125. def _get_codes_from_file(self):
  126. """从TXT文件读取股票代码"""
  127. print("\n从TXT文件读取股票代码")
  128. print("-" * 30)
  129. print("文件格式说明:")
  130. print("- 每行一个股票代码")
  131. print("- 支持注释行(以#开头)")
  132. print("- 示例文件内容:")
  133. print(" # 这是注释行")
  134. print(" 600030")
  135. print(" 002261")
  136. print()
  137. while True:
  138. filename = input("请输入TXT文件名 (例如: stock_codes.txt): ").strip()
  139. if not filename:
  140. print("❌ 请输入文件名")
  141. continue
  142. # 如果用户没有输入扩展名,自动添加.txt
  143. if not filename.endswith('.txt'):
  144. filename += '.txt'
  145. try:
  146. with open(filename, 'r', encoding='utf-8') as f:
  147. lines = f.readlines()
  148. stock_codes = []
  149. for line in lines:
  150. line = line.strip()
  151. if line and not line.startswith('#'):
  152. stock_codes.append(line)
  153. if not stock_codes:
  154. print("❌ 文件中没有有效的股票代码")
  155. continue
  156. # 验证股票代码格式
  157. valid_codes = []
  158. for code in stock_codes:
  159. if code.isdigit() and len(code) == 6:
  160. valid_codes.append(code)
  161. else:
  162. print(f"⚠️ 股票代码 {code} 格式不正确,已跳过")
  163. if not valid_codes:
  164. print("❌ 文件中没有有效的股票代码")
  165. continue
  166. return valid_codes
  167. except FileNotFoundError:
  168. print(f"❌ 文件 {filename} 不存在")
  169. continue
  170. except Exception as e:
  171. print(f"❌ 读取文件失败: {e}")
  172. continue
  173. def download_stock_data(self, stock_code, days=100, max_retries=5):
  174. """下载股票数据"""
  175. import time
  176. import requests
  177. from requests.adapters import HTTPAdapter
  178. from urllib3.util.retry import Retry
  179. # 配置重试策略
  180. session = requests.Session()
  181. retry_strategy = Retry(
  182. total=2,
  183. backoff_factor=2,
  184. status_forcelist=[429, 500, 502, 503, 504],
  185. allowed_methods=["HEAD", "GET", "OPTIONS"]
  186. )
  187. adapter = HTTPAdapter(max_retries=retry_strategy)
  188. session.mount("http://", adapter)
  189. session.mount("https://", adapter)
  190. for attempt in range(max_retries):
  191. try:
  192. if attempt > 0:
  193. print(f"正在重试下载股票 {stock_code} 的数据... (第 {attempt + 1} 次)")
  194. # 递增等待时间,并添加随机抖动
  195. import random
  196. wait_time = 8 * attempt + random.uniform(1, 3)
  197. print(f"⏳ 等待 {wait_time:.1f} 秒后重试...")
  198. time.sleep(wait_time)
  199. else:
  200. print(f"正在下载股票 {stock_code} 的数据...")
  201. # 计算日期范围(最近100个交易日)
  202. end_date = datetime.now()
  203. start_date = end_date - timedelta(days=days*2) # 多取一些天数确保有足够的交易日
  204. print(f" 请求日期范围: {start_date.strftime('%Y-%m-%d')} 至 {end_date.strftime('%Y-%m-%d')}")
  205. # 添加请求前的短暂延迟,避免请求过于频繁
  206. if attempt > 0:
  207. time.sleep(2)
  208. # 使用akshare下载数据
  209. data = ak.stock_zh_a_hist(
  210. symbol=stock_code,
  211. period="daily",
  212. start_date=start_date.strftime('%Y%m%d'),
  213. end_date=end_date.strftime('%Y%m%d'),
  214. adjust="qfq" # 前复权
  215. )
  216. if data.empty:
  217. print(f"❌ 股票 {stock_code}: 未找到数据")
  218. return None
  219. # 重命名列以匹配Kronos格式
  220. data = data.rename(columns={
  221. '日期': 'timestamps',
  222. '开盘': 'open',
  223. '收盘': 'close',
  224. '最高': 'high',
  225. '最低': 'low',
  226. '成交量': 'volume',
  227. '成交额': 'amount'
  228. })
  229. # 设置日期为索引
  230. data['timestamps'] = pd.to_datetime(data['timestamps'])
  231. data = data.set_index('timestamps')
  232. # 只取最近100个交易日
  233. if len(data) > days:
  234. data = data.tail(days)
  235. print(f"✅ 股票 {stock_code}: 成功下载 {len(data)} 条记录")
  236. print(f" 数据范围: {data.index[0].strftime('%Y-%m-%d')} 至 {data.index[-1].strftime('%Y-%m-%d')}")
  237. return data
  238. except Exception as e:
  239. error_msg = str(e)
  240. print(f"❌ 股票 {stock_code}: 下载失败 (第 {attempt + 1} 次) - {error_msg}")
  241. # 分析错误类型
  242. if "Connection reset by peer" in error_msg:
  243. print(" 🔍 分析: 连接被服务器重置,可能是请求过于频繁")
  244. elif "timeout" in error_msg.lower():
  245. print(" 🔍 分析: 请求超时,网络可能较慢")
  246. elif "Connection aborted" in error_msg:
  247. print(" 🔍 分析: 连接被中断,可能是网络不稳定")
  248. if attempt == max_retries - 1:
  249. print(f"❌ 股票 {stock_code}: 经过 {max_retries} 次尝试后仍然失败")
  250. print("💡 建议:")
  251. print(" 1. 检查网络连接是否稳定")
  252. print(" 2. 稍后重试(服务器可能负载较高)")
  253. print(" 3. 确认股票代码是否正确")
  254. print(" 4. 尝试使用其他网络环境")
  255. return None
  256. else:
  257. # 更长的等待时间
  258. wait_time = 8 * (attempt + 1)
  259. print(f"⏳ 等待 {wait_time} 秒后重试...")
  260. time.sleep(wait_time)
  261. return None
  262. def prepare_prediction_data(self, data, lookback_days=100, pred_days=30):
  263. """准备预测数据"""
  264. try:
  265. # 确保数据长度足够
  266. if len(data) < lookback_days:
  267. print(f"⚠️ 数据长度不足,需要 {lookback_days} 天,实际只有 {len(data)} 天")
  268. lookback_days = len(data)
  269. # 准备历史数据
  270. x_df = data.tail(lookback_days)[['open', 'high', 'low', 'close', 'volume', 'amount']].copy()
  271. x_timestamp = data.tail(lookback_days).index
  272. # 生成未来预测时间戳(交易日,排除周末和节假日)
  273. last_date = x_timestamp[-1]
  274. future_trading_days = self.generate_trading_days(last_date + timedelta(days=1), pred_days)
  275. y_timestamp = pd.Series(future_trading_days)
  276. # 确保时间戳是Series格式
  277. x_timestamp = pd.Series(x_timestamp)
  278. print(f"📅 预测期间: {future_trading_days[0].strftime('%Y-%m-%d')} 至 {future_trading_days[-1].strftime('%Y-%m-%d')}")
  279. print(f"📅 预测天数: {len(future_trading_days)} 个交易日")
  280. return x_df, x_timestamp, y_timestamp
  281. except Exception as e:
  282. print(f"❌ 数据准备失败: {str(e)}")
  283. return None, None, None
  284. def make_prediction(self, x_df, x_timestamp, y_timestamp, pred_len=30):
  285. """进行预测"""
  286. try:
  287. print("正在进行预测...")
  288. pred_df = self.predictor.predict(
  289. df=x_df,
  290. x_timestamp=x_timestamp,
  291. y_timestamp=y_timestamp,
  292. pred_len=pred_len,
  293. T=1.0,
  294. top_p=0.9,
  295. sample_count=1,
  296. verbose=True
  297. )
  298. print("✅ 预测完成!")
  299. return pred_df
  300. except Exception as e:
  301. print(f"❌ 预测失败: {str(e)}")
  302. return None
  303. def plot_prediction(self, stock_code, historical_data, pred_data, x_timestamp, y_timestamp):
  304. """绘制预测结果"""
  305. try:
  306. fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 10), sharex=True)
  307. # 创建连续的时间轴索引
  308. x_indices = range(len(x_timestamp))
  309. y_indices = range(len(x_timestamp), len(x_timestamp) + len(y_timestamp))
  310. # 绘制价格图
  311. ax1.plot(x_indices, historical_data['close'], label='历史价格', color='blue', linewidth=2)
  312. ax1.plot(y_indices, pred_data['close'], label='预测价格', color='red', linewidth=2, linestyle='--')
  313. ax1.set_ylabel('收盘价 (元)', fontsize=12)
  314. ax1.set_title(f'股票 {stock_code} 价格预测 (排除节假日)', fontsize=14, fontweight='bold')
  315. ax1.legend(fontsize=11)
  316. ax1.grid(True, alpha=0.3)
  317. # 绘制成交量图
  318. ax2.plot(x_indices, historical_data['volume'], label='历史成交量', color='blue', linewidth=2)
  319. ax2.plot(y_indices, pred_data['volume'], label='预测成交量', color='red', linewidth=2, linestyle='--')
  320. ax2.set_ylabel('成交量', fontsize=12)
  321. ax2.set_xlabel('交易日', fontsize=12)
  322. ax2.legend(fontsize=11)
  323. ax2.grid(True, alpha=0.3)
  324. # 设置x轴刻度
  325. total_days = len(x_timestamp) + len(y_timestamp)
  326. step = max(1, total_days // 12) # 显示约12个标签
  327. tick_positions = list(range(0, total_days, step))
  328. # 创建标签:历史数据用实际日期,预测数据用预测日期
  329. tick_labels = []
  330. for pos in tick_positions:
  331. if pos < len(x_timestamp):
  332. # 历史数据标签
  333. tick_labels.append(x_timestamp.iloc[pos].strftime('%m-%d'))
  334. else:
  335. # 预测数据标签
  336. pred_pos = pos - len(x_timestamp)
  337. if pred_pos < len(y_timestamp):
  338. tick_labels.append(y_timestamp.iloc[pred_pos].strftime('%m-%d'))
  339. else:
  340. tick_labels.append('')
  341. ax2.set_xticks(tick_positions)
  342. ax2.set_xticklabels(tick_labels, rotation=45, ha='right')
  343. # 添加分隔线区分历史和预测数据
  344. split_point = len(x_timestamp) - 0.5
  345. ax1.axvline(x=split_point, color='gray', linestyle=':', alpha=0.7, linewidth=2)
  346. ax2.axvline(x=split_point, color='gray', linestyle=':', alpha=0.7, linewidth=2)
  347. # 添加文本标注
  348. ax1.text(0.02, 0.98, f'历史数据: {len(x_timestamp)} 个交易日',
  349. transform=ax1.transAxes, verticalalignment='top',
  350. bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
  351. ax1.text(0.02, 0.88, f'预测数据: {len(y_timestamp)} 个交易日',
  352. transform=ax1.transAxes, verticalalignment='top',
  353. bbox=dict(boxstyle='round', facecolor='lightcoral', alpha=0.8))
  354. # 添加日期范围标注
  355. ax1.text(0.98, 0.02, f'历史: {x_timestamp.iloc[0].strftime("%Y-%m-%d")} 至 {x_timestamp.iloc[-1].strftime("%Y-%m-%d")}',
  356. transform=ax1.transAxes, verticalalignment='bottom', horizontalalignment='right',
  357. bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8), fontsize=9)
  358. ax1.text(0.98, 0.12, f'预测: {y_timestamp.iloc[0].strftime("%Y-%m-%d")} 至 {y_timestamp.iloc[-1].strftime("%Y-%m-%d")}',
  359. transform=ax1.transAxes, verticalalignment='bottom', horizontalalignment='right',
  360. bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8), fontsize=9)
  361. plt.tight_layout()
  362. # 保存图片
  363. output_dir = "prediction_results"
  364. if not os.path.exists(output_dir):
  365. os.makedirs(output_dir)
  366. timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
  367. filename = f"{output_dir}/prediction_{stock_code}_{timestamp_str}.png"
  368. plt.savefig(filename, dpi=300, bbox_inches='tight')
  369. print(f"📊 预测图表已保存: {filename}")
  370. plt.show()
  371. except Exception as e:
  372. print(f"❌ 绘图失败: {str(e)}")
  373. import traceback
  374. traceback.print_exc()
  375. def save_prediction_results(self, stock_code, pred_data, y_timestamp):
  376. """保存预测结果"""
  377. try:
  378. output_dir = "prediction_results"
  379. if not os.path.exists(output_dir):
  380. os.makedirs(output_dir)
  381. # 保存为CSV
  382. timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
  383. csv_filename = f"{output_dir}/prediction_{stock_code}_{timestamp_str}.csv"
  384. pred_data.to_csv(csv_filename, encoding='utf-8-sig')
  385. print(f"💾 预测结果已保存: {csv_filename}")
  386. # 保存为JSON
  387. json_filename = f"{output_dir}/prediction_{stock_code}_{timestamp_str}.json"
  388. pred_data.to_json(json_filename, orient='index', date_format='iso')
  389. print(f"💾 预测结果(JSON)已保存: {json_filename}")
  390. except Exception as e:
  391. print(f"❌ 保存结果失败: {str(e)}")
  392. def print_prediction_summary(self, stock_code, pred_data):
  393. """打印预测摘要"""
  394. print(f"\n📈 股票 {stock_code} 预测摘要")
  395. print("="*50)
  396. print(f"预测期间: {pred_data.index[0].strftime('%Y-%m-%d')} 至 {pred_data.index[-1].strftime('%Y-%m-%d')}")
  397. print(f"预测天数: {len(pred_data)} 个交易日")
  398. print()
  399. # 价格统计
  400. print("价格预测:")
  401. print(f" 起始价格: {pred_data['close'].iloc[0]:.2f}")
  402. print(f" 结束价格: {pred_data['close'].iloc[-1]:.2f}")
  403. print(f" 最高价格: {pred_data['high'].max():.2f}")
  404. print(f" 最低价格: {pred_data['low'].min():.2f}")
  405. print(f" 价格变化: {((pred_data['close'].iloc[-1] / pred_data['close'].iloc[0]) - 1) * 100:.2f}%")
  406. print()
  407. # 成交量统计
  408. print("成交量预测:")
  409. print(f" 平均成交量: {pred_data['volume'].mean():.0f}")
  410. print(f" 最大成交量: {pred_data['volume'].max():.0f}")
  411. print(f" 最小成交量: {pred_data['volume'].min():.0f}")
  412. print()
  413. # 显示前5天和后5天的预测
  414. print("预测详情 (前5天):")
  415. print(pred_data.head().round(2))
  416. print()
  417. print("预测详情 (后5天):")
  418. print(pred_data.tail().round(2))
  419. def run(self, test_mode=False, test_stock_codes=None):
  420. """运行主程序"""
  421. print("🚀 交互式股票预测程序")
  422. print("="*60)
  423. print("本程序使用Kronos模型预测股票未来走势")
  424. print("支持A股市场,预测未来30个交易日的价格和成交量")
  425. print()
  426. # 加载模型
  427. if not self.load_models():
  428. return
  429. # 获取股票代码
  430. if test_mode and test_stock_codes:
  431. stock_codes = test_stock_codes
  432. print(f"🧪 测试模式: 使用预设股票代码 {stock_codes}")
  433. else:
  434. stock_codes = self.get_stock_codes()
  435. if not stock_codes:
  436. print("❌ 未获取到有效的股票代码")
  437. return
  438. print(f"\n📊 将预测以下股票: {', '.join(stock_codes)}")
  439. # 对每只股票进行预测
  440. for i, stock_code in enumerate(stock_codes, 1):
  441. print(f"\n{'='*60}")
  442. print(f"正在处理股票 {i}/{len(stock_codes)}: {stock_code}")
  443. print('='*60)
  444. # 下载数据
  445. data = self.download_stock_data(stock_code, days=100)
  446. if data is None:
  447. continue
  448. # 准备预测数据
  449. x_df, x_timestamp, y_timestamp = self.prepare_prediction_data(data, lookback_days=100, pred_days=30)
  450. if x_df is None:
  451. continue
  452. # 进行预测
  453. pred_data = self.make_prediction(x_df, x_timestamp, y_timestamp, pred_len=30)
  454. if pred_data is None:
  455. continue
  456. # 打印预测摘要
  457. self.print_prediction_summary(stock_code, pred_data)
  458. # 绘制预测图
  459. self.plot_prediction(stock_code, x_df, pred_data, x_timestamp, y_timestamp)
  460. # 保存预测结果
  461. self.save_prediction_results(stock_code, pred_data, y_timestamp)
  462. print(f"✅ 股票 {stock_code} 预测完成!")
  463. print(f"\n🎉 所有股票预测完成!")
  464. print("预测结果已保存到 prediction_results 目录")
  465. def main():
  466. """主函数"""
  467. predictor = InteractiveStockPredictor()
  468. # 检查是否为测试模式
  469. import sys
  470. if len(sys.argv) > 1 and sys.argv[1] == '--test':
  471. # 测试模式:使用示例股票代码
  472. test_codes = ['600036', '000001'] # 招商银行、平安银行
  473. predictor.run(test_mode=True, test_stock_codes=test_codes)
  474. elif len(sys.argv) > 1:
  475. # 命令行模式:直接指定股票代码
  476. stock_codes = sys.argv[1:]
  477. # 验证股票代码格式
  478. valid_codes = []
  479. for code in stock_codes:
  480. if code.isdigit() and len(code) == 6:
  481. valid_codes.append(code)
  482. else:
  483. print(f"⚠️ 股票代码 {code} 格式不正确,已跳过")
  484. if valid_codes:
  485. print(f"📊 将预测以下股票: {', '.join(valid_codes)}")
  486. predictor.run(test_mode=True, test_stock_codes=valid_codes)
  487. else:
  488. print("❌ 没有有效的股票代码")
  489. else:
  490. # 正常交互模式
  491. predictor.run()
  492. if __name__ == "__main__":
  493. main()