import os
import pandas as pd
import numpy as np
import json
import plotly.graph_objects as go
import plotly.utils
from flask import Flask, render_template, request, jsonify
from flask_cors import CORS
import sys
import warnings
from datetime import datetime
import baostock as bs
import re

warnings.filterwarnings('ignore')

# Add project root directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

try:
    from model import Kronos, KronosTokenizer, KronosPredictor
    MODEL_AVAILABLE = True
except ImportError:
    MODEL_AVAILABLE = False
    print("Warning: Kronos model cannot be imported, will use simulated data for demonstration")

app = Flask(__name__)
CORS(app)

# Global variables that hold the loaded model
tokenizer = None
model = None
predictor = None

# Path of the webui directory
WEBUI_DIR = os.path.dirname(os.path.abspath(__file__))
# Project root directory (the parent of webui)
BASE_DIR = os.path.dirname(WEBUI_DIR)

# Available model configurations
AVAILABLE_MODELS = {
    'kronos-mini': {
        'name': 'Kronos-mini',
        'model_id': os.path.join(BASE_DIR, 'models', 'Kronos-mini'),
        'tokenizer_id': os.path.join(BASE_DIR, 'models', 'Kronos-Tokenizer-base'),
        'context_length': 2048,
        'params': '4.1M',
        'description': 'Lightweight model, suitable for fast predictions'
    },
    'kronos-small': {
        'name': 'Kronos-small',
        'model_id': os.path.join(BASE_DIR, 'models', 'NeoQuasarKronos-small'),
        'tokenizer_id': os.path.join(BASE_DIR, 'models', 'Kronos-Tokenizer-base'),
        'context_length': 512,
        'params': '24.7M',
        'description': 'Small model, balances performance and speed'
    },
    'kronos-base': {
        'name': 'Kronos-base',
        'model_id': os.path.join(BASE_DIR, 'models', 'NeoQuasarKronos-base'),
        'tokenizer_id': os.path.join(BASE_DIR, 'models', 'Kronos-Tokenizer-base'),
        'context_length': 512,
        'params': '102.3M',
        'description': 'Base model, provides better prediction quality'
    }
}
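# To register another checkpoint, add an entry of the same shape to
# AVAILABLE_MODELS. A hypothetical sketch (the 'kronos-large' key, directory
# name and parameter count below are illustrative, not a real shipped model):
#
# AVAILABLE_MODELS['kronos-large'] = {
#     'name': 'Kronos-large',
#     'model_id': os.path.join(BASE_DIR, 'models', 'Kronos-large'),
#     'tokenizer_id': os.path.join(BASE_DIR, 'models', 'Kronos-Tokenizer-base'),
#     'context_length': 512,
#     'params': 'N/A',
#     'description': 'Hypothetical larger model entry'
# }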
def load_data_files():
    """Scan data directory and return available data files"""
    data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')
    data_files = []
    if os.path.exists(data_dir):
        for file in os.listdir(data_dir):
            if file.endswith(('.csv', '.feather')):
                file_path = os.path.join(data_dir, file)
                file_size = os.path.getsize(file_path)
                data_files.append({
                    'name': file,
                    'path': file_path,
                    'size': f"{file_size / 1024:.1f} KB" if file_size < 1024 * 1024 else f"{file_size / (1024 * 1024):.1f} MB"
                })
    return data_files
def load_data_file(file_path):
    """Load data file"""
    try:
        if file_path.endswith('.csv'):
            df = pd.read_csv(file_path)
        elif file_path.endswith('.feather'):
            df = pd.read_feather(file_path)
        else:
            return None, "Unsupported file format"

        # Check required columns
        required_cols = ['open', 'high', 'low', 'close']
        if not all(col in df.columns for col in required_cols):
            return None, f"Missing required columns: {required_cols}"

        # Process timestamp column
        if 'timestamps' in df.columns:
            df['timestamps'] = pd.to_datetime(df['timestamps'])
        elif 'timestamp' in df.columns:
            df['timestamps'] = pd.to_datetime(df['timestamp'])
        elif 'date' in df.columns:
            # If the column is named 'date', copy it into 'timestamps'
            df['timestamps'] = pd.to_datetime(df['date'])
        else:
            # If no timestamp column exists, create one
            df['timestamps'] = pd.date_range(start='2024-01-01', periods=len(df), freq='1H')

        # Ensure price columns are numeric
        for col in ['open', 'high', 'low', 'close']:
            df[col] = pd.to_numeric(df[col], errors='coerce')

        # Process volume column (optional)
        if 'volume' in df.columns:
            df['volume'] = pd.to_numeric(df['volume'], errors='coerce')

        # Process amount column (optional, but not used for prediction)
        if 'amount' in df.columns:
            df['amount'] = pd.to_numeric(df['amount'], errors='coerce')

        # Remove rows containing NaN values
        df = df.dropna()
        return df, None
    except Exception as e:
        return None, f"Failed to load file: {str(e)}"
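
# A minimal sketch of a frame that satisfies the loader's schema, handy for
# exercising the functions below without a real data file. The helper name
# and the synthetic values are ours, not part of the Kronos project.
def _make_example_frame(rows=16):
    """Build a tiny synthetic OHLCV frame matching what load_data_file expects."""
    rng = np.random.default_rng(0)
    close = 100 + rng.normal(0, 1, rows).cumsum()  # a random walk around 100
    return pd.DataFrame({
        'timestamps': pd.date_range('2024-01-01', periods=rows, freq='1H'),
        'open': close + rng.normal(0, 0.2, rows),
        'high': close + 1.0,
        'low': close - 1.0,
        'close': close,
        'volume': rng.integers(1_000, 5_000, rows).astype(float),
    })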
def save_prediction_results(file_path, prediction_type, prediction_results, actual_data, input_data, prediction_params):
    """Save prediction results to file"""
    try:
        # Create prediction results directory
        results_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prediction_results')
        os.makedirs(results_dir, exist_ok=True)

        # Generate filename
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        filename = f'prediction_{timestamp}.json'
        filepath = os.path.join(results_dir, filename)

        # Prepare data for saving
        save_data = {
            'timestamp': datetime.now().isoformat(),
            'file_path': file_path,
            'prediction_type': prediction_type,
            'prediction_params': prediction_params,
            'input_data_summary': {
                'rows': len(input_data),
                'columns': list(input_data.columns),
                'price_range': {
                    'open': {'min': float(input_data['open'].min()), 'max': float(input_data['open'].max())},
                    'high': {'min': float(input_data['high'].min()), 'max': float(input_data['high'].max())},
                    'low': {'min': float(input_data['low'].min()), 'max': float(input_data['low'].max())},
                    'close': {'min': float(input_data['close'].min()), 'max': float(input_data['close'].max())}
                },
                'last_values': {
                    'open': float(input_data['open'].iloc[-1]),
                    'high': float(input_data['high'].iloc[-1]),
                    'low': float(input_data['low'].iloc[-1]),
                    'close': float(input_data['close'].iloc[-1])
                }
            },
            'prediction_results': prediction_results,
            'actual_data': actual_data,
            'analysis': {}
        }

        # If actual data exists, perform comparison analysis
        if actual_data and len(actual_data) > 0:
            # Continuity analysis: compare the first predicted bar against the first actual bar
            if len(prediction_results) > 0 and len(actual_data) > 0:
                first_pred = prediction_results[0]  # First prediction point
                first_actual = actual_data[0]  # First actual point
                save_data['analysis']['continuity'] = {
                    'first_prediction': {
                        'open': first_pred['open'],
                        'high': first_pred['high'],
                        'low': first_pred['low'],
                        'close': first_pred['close']
                    },
                    'first_actual': {
                        'open': first_actual['open'],
                        'high': first_actual['high'],
                        'low': first_actual['low'],
                        'close': first_actual['close']
                    },
                    'gaps': {
                        'open_gap': abs(first_pred['open'] - first_actual['open']),
                        'high_gap': abs(first_pred['high'] - first_actual['high']),
                        'low_gap': abs(first_pred['low'] - first_actual['low']),
                        'close_gap': abs(first_pred['close'] - first_actual['close'])
                    },
                    'gap_percentages': {
                        'open_gap_pct': (abs(first_pred['open'] - first_actual['open']) / first_actual['open']) * 100,
                        'high_gap_pct': (abs(first_pred['high'] - first_actual['high']) / first_actual['high']) * 100,
                        'low_gap_pct': (abs(first_pred['low'] - first_actual['low']) / first_actual['low']) * 100,
                        'close_gap_pct': (abs(first_pred['close'] - first_actual['close']) / first_actual['close']) * 100
                    }
                }

        # Save to file
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(save_data, f, indent=2, ensure_ascii=False)
        print(f"Prediction results saved to: {filepath}")
        return filepath
    except Exception as e:
        print(f"Failed to save prediction results: {e}")
        return None
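
# Worked example of the continuity metric above: if the first predicted close
# is 101.0 and the first actual close is 100.0, then
# close_gap = |101.0 - 100.0| = 1.0 and close_gap_pct = 1.0 / 100.0 * 100 = 1.0,
# i.e. the forecast opens with a 1% discontinuity against the observed series.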
def create_prediction_chart(df, pred_df, lookback, pred_len, actual_df=None, historical_start_idx=0):
    """Create prediction chart"""
    print(f"🔍 Chart creation debug:")
    print(f"  Historical data: {len(df) if df is not None else 0} rows")
    print(f"  Prediction data: {len(pred_df) if pred_df is not None else 0} rows")
    print(f"  Actual data: {len(actual_df) if actual_df is not None else 0} rows")

    # Make sure the prediction data is not empty
    if pred_df is None or len(pred_df) == 0:
        print("⚠️ Warning: prediction data is empty!")
        # Return an empty chart
        fig = go.Figure()
        fig.update_layout(title='No prediction data available')
        return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)

    # Use the specified historical data start position, not always the beginning of df
    if historical_start_idx + lookback + pred_len <= len(df):
        # Display lookback historical points + pred_len prediction points starting from the specified position
        historical_df = df.iloc[historical_start_idx:historical_start_idx + lookback]
        prediction_range = range(historical_start_idx + lookback, historical_start_idx + lookback + pred_len)
    else:
        # If data is insufficient, adjust to the maximum available range
        available_lookback = min(lookback, len(df) - historical_start_idx)
        available_pred_len = min(pred_len, max(0, len(df) - historical_start_idx - available_lookback))
        historical_df = df.iloc[historical_start_idx:historical_start_idx + available_lookback]
        prediction_range = range(historical_start_idx + available_lookback,
                                 historical_start_idx + available_lookback + available_pred_len)

    # Create chart
    fig = go.Figure()

    # Add historical data (candlestick chart)
    fig.add_trace(go.Candlestick(
        x=historical_df['timestamps'].tolist() if 'timestamps' in historical_df.columns else historical_df.index.tolist(),
        open=historical_df['open'].tolist(),
        high=historical_df['high'].tolist(),
        low=historical_df['low'].tolist(),
        close=historical_df['close'].tolist(),
        name='Historical Data (400 data points)',
        increasing_line_color='#26A69A',
        decreasing_line_color='#EF5350'
    ))

    # Add prediction data (candlestick chart)
    if pred_df is not None and len(pred_df) > 0:
        # Calculate prediction timestamps - ensure continuity with the historical data
        if 'timestamps' in df.columns and len(historical_df) > 0:
            # Start from the last historical timestamp and continue with the same time interval
            last_timestamp = historical_df['timestamps'].iloc[-1]
            time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)
            pred_timestamps = pd.date_range(
                start=last_timestamp + time_diff,
                periods=len(pred_df),
                freq=time_diff
            )
        else:
            # If there are no timestamps, use index positions
            pred_timestamps = range(len(historical_df), len(historical_df) + len(pred_df))

        fig.add_trace(go.Candlestick(
            x=pred_timestamps.tolist() if hasattr(pred_timestamps, 'tolist') else list(pred_timestamps),
            open=pred_df['open'].tolist(),
            high=pred_df['high'].tolist(),
            low=pred_df['low'].tolist(),
            close=pred_df['close'].tolist(),
            name='Prediction Data (120 data points)',
            increasing_line_color='#66BB6A',
            decreasing_line_color='#FF7043'
        ))

    # Add actual data for comparison (if it exists)
    if actual_df is not None and len(actual_df) > 0:
        # Actual data covers the same time period as the prediction data
        if 'timestamps' in df.columns:
            # Use the same timestamps as the prediction data to keep the two aligned in time
            if 'pred_timestamps' in locals():
                actual_timestamps = pred_timestamps
            else:
                # If no prediction timestamps exist, extend from the last historical timestamp
                if len(historical_df) > 0:
                    last_timestamp = historical_df['timestamps'].iloc[-1]
                    time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)
                    actual_timestamps = pd.date_range(
                        start=last_timestamp + time_diff,
                        periods=len(actual_df),
                        freq=time_diff
                    )
                else:
                    actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
        else:
            actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))

        fig.add_trace(go.Candlestick(
            x=actual_timestamps.tolist() if hasattr(actual_timestamps, 'tolist') else list(actual_timestamps),
            open=actual_df['open'].tolist(),
            high=actual_df['high'].tolist(),
            low=actual_df['low'].tolist(),
            close=actual_df['close'].tolist(),
            name='Actual Data (120 data points)',
            increasing_line_color='#FF9800',
            decreasing_line_color='#F44336'
        ))

    # Update layout
    fig.update_layout(
        title='Kronos Financial Prediction Results - 400 Historical Points + 120 Prediction Points vs 120 Actual Points',
        xaxis_title='Time',
        yaxis_title='Price',
        template='plotly_white',
        height=600,
        showlegend=True
    )

    # Ensure x-axis time continuity
    if 'timestamps' in historical_df.columns:
        # Gather all timestamps and sort them
        all_timestamps = []
        if len(historical_df) > 0:
            all_timestamps.extend(historical_df['timestamps'].tolist())
        if 'pred_timestamps' in locals():
            all_timestamps.extend(pred_timestamps.tolist() if hasattr(pred_timestamps, 'tolist') else list(pred_timestamps))
        if 'actual_timestamps' in locals():
            all_timestamps.extend(actual_timestamps.tolist() if hasattr(actual_timestamps, 'tolist') else list(actual_timestamps))
        if all_timestamps:
            all_timestamps = sorted(all_timestamps)
            fig.update_xaxes(
                range=[all_timestamps[0], all_timestamps[-1]],
                rangeslider_visible=False,
                type='date'
            )

    try:
        chart_json = fig.to_json()
        print(f"✅ Chart JSON serialized, length: {len(chart_json)}")
        return chart_json
    except Exception as e:
        print(f"❌ Chart serialization failed: {e}")
        error_fig = go.Figure()
        error_fig.update_layout(title='Chart Rendering Error')
        return error_fig.to_json()
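
# Usage sketch (assuming `df` came from load_data_file and `pred_df` from
# KronosPredictor.predict): the returned string is Plotly figure JSON that the
# front end can hand straight to Plotly.newPlot.
#
# chart_json = create_prediction_chart(df, pred_df, lookback=400, pred_len=120)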
# Calculate technical indicators
def calculate_indicators(df):
    indicators = {}

    # Moving averages (MA)
    indicators['ma5'] = df['close'].rolling(window=5).mean()
    indicators['ma10'] = df['close'].rolling(window=10).mean()
    indicators['ma20'] = df['close'].rolling(window=20).mean()

    # MACD
    exp12 = df['close'].ewm(span=12, adjust=False).mean()
    exp26 = df['close'].ewm(span=26, adjust=False).mean()
    indicators['macd'] = exp12 - exp26
    indicators['signal'] = indicators['macd'].ewm(span=9, adjust=False).mean()
    indicators['macd_hist'] = indicators['macd'] - indicators['signal']

    # RSI
    delta = df['close'].diff()
    gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
    rs = gain / loss
    indicators['rsi'] = 100 - (100 / (1 + rs))

    # Bollinger Bands
    indicators['bb_mid'] = df['close'].rolling(window=20).mean()
    indicators['bb_std'] = df['close'].rolling(window=20).std()
    indicators['bb_upper'] = indicators['bb_mid'] + 2 * indicators['bb_std']
    indicators['bb_lower'] = indicators['bb_mid'] - 2 * indicators['bb_std']

    # Stochastic oscillator
    low_min = df['low'].rolling(window=14).min()
    high_max = df['high'].rolling(window=14).max()
    indicators['stoch_k'] = 100 * ((df['close'] - low_min) / (high_max - low_min))
    indicators['stoch_d'] = indicators['stoch_k'].rolling(window=3).mean()

    # Rolling window mean strategy
    indicators['rwms_window'] = 90
    indicators['rwms_mean'] = df['close'].rolling(window=90).mean()
    indicators['rwms_signal'] = (df['close'] > indicators['rwms_mean']).astype(int)

    # Triple exponential average (TRIX) strategy
    # EMA of the close price
    ema1 = df['close'].ewm(span=12, adjust=False).mean()
    # EMA of that EMA
    ema2 = ema1.ewm(span=12, adjust=False).mean()
    # EMA of the second EMA
    ema3 = ema2.ewm(span=12, adjust=False).mean()
    # TRIX: one-period rate of change of the triple EMA, in percent
    indicators['trix'] = (ema3 - ema3.shift(1)) / ema3.shift(1) * 100
    # Signal line
    indicators['trix_signal'] = indicators['trix'].ewm(span=9, adjust=False).mean()

    return indicators
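
# Quick check of the indicator math on synthetic data (a sketch; relies on the
# hypothetical _make_example_frame helper defined above):
#
# ind = calculate_indicators(_make_example_frame(60))
# ind['rsi'].dropna().between(0, 100).all()             # RSI stays within [0, 100]
# ind['macd_hist'].equals(ind['macd'] - ind['signal'])  # histogram identity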
# Create technical indicator chart
def create_technical_chart(df, pred_df, lookback, pred_len, diagram_type, actual_df=None, historical_start_idx=0):
    print(f" 🔍 Data rows: {len(df) if df is not None else 0}")
    print(f" 🔍 Diagram type: {diagram_type}")

    # Select the data range
    if historical_start_idx + lookback <= len(df):
        historical_df = df.iloc[historical_start_idx:historical_start_idx + lookback]
    else:
        available_lookback = min(lookback, len(df) - historical_start_idx)
        historical_df = df.iloc[historical_start_idx:historical_start_idx + available_lookback]

    # Calculate indicators
    historical_indicators = calculate_indicators(historical_df)

    fig = go.Figure()

    # Shared x-axis values for the historical window
    x_values = historical_df['timestamps'].tolist() if 'timestamps' in historical_df.columns else historical_df.index.tolist()

    # Volume chart
    if diagram_type == 'Volume Chart (VOL)':
        fig.add_trace(go.Bar(
            x=x_values,
            y=historical_df['volume'].tolist() if 'volume' in historical_df.columns else [],
            name='Historical Volume',
            marker_color='#42A5F5'
        ))
        if actual_df is not None and len(actual_df) > 0 and 'volume' in actual_df.columns:
            if 'timestamps' in df.columns and len(historical_df) > 0:
                last_timestamp = historical_df['timestamps'].iloc[-1]
                time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)
                actual_timestamps = pd.date_range(start=last_timestamp + time_diff, periods=len(actual_df), freq=time_diff)
            else:
                actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
            fig.add_trace(go.Bar(
                x=actual_timestamps.tolist() if hasattr(actual_timestamps, 'tolist') else list(actual_timestamps),
                y=actual_df['volume'].tolist(),
                name='Actual Volume',
                marker_color='#FF9800'
            ))
        fig.update_layout(yaxis_title='Volume')

    # Moving averages
    elif diagram_type == 'Moving Average (MA)':
        fig.add_trace(go.Scatter(x=x_values, y=historical_indicators['ma5'],
                                 name='MA5', line=dict(color='#26A69A', width=1)))
        fig.add_trace(go.Scatter(x=x_values, y=historical_indicators['ma10'],
                                 name='MA10', line=dict(color='#42A5F5', width=1)))
        fig.add_trace(go.Scatter(x=x_values, y=historical_indicators['ma20'],
                                 name='MA20', line=dict(color='#7E57C2', width=1)))
        fig.add_trace(go.Scatter(x=x_values, y=historical_df['close'],
                                 name='Close Price', line=dict(color='#212121', width=1, dash='dash')))
        fig.update_layout(yaxis_title='Price')

    # MACD indicator
    elif diagram_type == 'MACD Indicator (MACD)':
        fig.add_trace(go.Scatter(x=x_values, y=historical_indicators['macd'],
                                 name='MACD', line=dict(color='#26A69A', width=1)))
        fig.add_trace(go.Scatter(x=x_values, y=historical_indicators['signal'],
                                 name='Signal', line=dict(color='#EF5350', width=1)))
        fig.add_trace(go.Bar(x=x_values, y=historical_indicators['macd_hist'],
                             name='MACD Histogram', marker_color='#42A5F5'))
        # Zero line
        fig.add_hline(y=0, line_dash="dash", line_color="gray")
        fig.update_layout(yaxis_title='MACD')

    # RSI indicator
    elif diagram_type == 'RSI Indicator (RSI)':
        fig.add_trace(go.Scatter(x=x_values, y=historical_indicators['rsi'],
                                 name='RSI', line=dict(color='#26A69A', width=1)))
        # Overbought / oversold lines
        fig.add_hline(y=70, line_dash="dash", line_color="red", name='Overbought')
        fig.add_hline(y=30, line_dash="dash", line_color="green", name='Oversold')
        fig.update_layout(yaxis_title='RSI', yaxis_range=[0, 100])

    # Bollinger Bands
    elif diagram_type == 'Bollinger Bands (BB)':
        fig.add_trace(go.Scatter(x=x_values, y=historical_indicators['bb_upper'],
                                 name='Upper Band', line=dict(color='#EF5350', width=1)))
        fig.add_trace(go.Scatter(x=x_values, y=historical_indicators['bb_mid'],
                                 name='Middle Band (MA20)', line=dict(color='#42A5F5', width=1)))
        fig.add_trace(go.Scatter(x=x_values, y=historical_indicators['bb_lower'],
                                 name='Lower Band', line=dict(color='#26A69A', width=1)))
        fig.add_trace(go.Scatter(x=x_values, y=historical_df['close'],
                                 name='Close Price', line=dict(color='#212121', width=1)))
        fig.update_layout(yaxis_title='Price')

    # Stochastic oscillator
    elif diagram_type == 'Stochastic Oscillator (STOCH)':
        fig.add_trace(go.Scatter(x=x_values, y=historical_indicators['stoch_k'],
                                 name='%K', line=dict(color='#26A69A', width=1)))
        fig.add_trace(go.Scatter(x=x_values, y=historical_indicators['stoch_d'],
                                 name='%D', line=dict(color='#EF5350', width=1)))
        fig.add_hline(y=80, line_dash="dash", line_color="red", name='Overbought')
        fig.add_hline(y=20, line_dash="dash", line_color="green", name='Oversold')
        fig.update_layout(yaxis_title='Stochastic', yaxis_range=[0, 100])

    # Rolling window mean strategy
    elif diagram_type == 'Rolling Window Mean Strategy':
        fig.add_trace(go.Scatter(x=x_values, y=historical_df['close'],
                                 name='Close Price', line=dict(color='#212121', width=1.5)))
        fig.add_trace(go.Scatter(x=x_values, y=historical_indicators['rwms_mean'],
                                 name=f'Rolling Mean ({historical_indicators["rwms_window"]} periods)',
                                 line=dict(color='#42A5F5', width=1.5, dash='dash')))
        buy_signals = historical_df[historical_indicators['rwms_signal'] == 1]
        fig.add_trace(go.Scatter(
            x=buy_signals['timestamps'].tolist() if 'timestamps' in buy_signals.columns else buy_signals.index.tolist(),
            y=buy_signals['close'],
            mode='markers',
            name='Buy Signal',
            marker=dict(color='#26A69A', size=8, symbol='triangle-up')
        ))
        sell_signals = historical_df[historical_indicators['rwms_signal'] == 0]
        fig.add_trace(go.Scatter(
            x=sell_signals['timestamps'].tolist() if 'timestamps' in sell_signals.columns else sell_signals.index.tolist(),
            y=sell_signals['close'],
            mode='markers',
            name='Sell Signal',
            marker=dict(color='#EF5350', size=8, symbol='triangle-down')
        ))
        fig.update_layout(
            yaxis_title='Price',
            title=f'Rolling Window Mean Strategy (Window Size: {historical_indicators["rwms_window"]})'
        )

    # TRIX indicator chart
    elif diagram_type == 'TRIX Indicator (TRIX)':
        fig.add_trace(go.Scatter(x=x_values, y=historical_indicators['trix'],
                                 name='TRIX', line=dict(color='#26A69A', width=1)))
        fig.add_trace(go.Scatter(x=x_values, y=historical_indicators['trix_signal'],
                                 name='TRIX Signal', line=dict(color='#EF5350', width=1)))
        fig.add_hline(y=0, line_dash="dash", line_color="gray")
        fig.update_layout(
            yaxis_title='TRIX (%)',
            title='Triple Exponential Average (TRIX) Strategy'
        )

    # Layout settings
    fig.update_layout(
        title=f'{diagram_type} - Technical Indicator (Real Data Only)',
        xaxis_title='Time',
        template='plotly_white',
        height=400,
        showlegend=True,
        margin=dict(t=50, b=30)
    )

    if 'timestamps' in historical_df.columns:
        all_timestamps = historical_df['timestamps'].tolist()
        if actual_df is not None and len(actual_df) > 0 and 'timestamps' in df.columns:
            if 'actual_timestamps' in locals():
                all_timestamps.extend(actual_timestamps.tolist() if hasattr(actual_timestamps, 'tolist') else list(actual_timestamps))
        if all_timestamps:
            all_timestamps = sorted(all_timestamps)
            fig.update_xaxes(
                range=[all_timestamps[0], all_timestamps[-1]],
                rangeslider_visible=False,
                type='date'
            )

    try:
        chart_json = fig.to_json()
        print(f"✅ Technical indicator chart serialized, length: {len(chart_json)}")
        return chart_json
    except Exception as e:
        print(f"❌ Technical indicator chart serialization failed: {e}")
        error_fig = go.Figure()
        error_fig.update_layout(title='Chart Rendering Error')
        return error_fig.to_json()
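
# Usage sketch: diagram_type must match one of the branch labels above
# exactly, e.g. 'RSI Indicator (RSI)' or 'Bollinger Bands (BB)'.
#
# chart_json = create_technical_chart(df, None, lookback=400, pred_len=0,
#                                     diagram_type='MACD Indicator (MACD)')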
@app.route('/')
def index():
    """Home page"""
    return render_template('index.html')


@app.route('/api/data-files')
def get_data_files():
    """Get the list of available data files"""
    data_files = load_data_files()
    return jsonify(data_files)
@app.route('/api/load-data', methods=['POST'])
def load_data():
    """Load data file"""
    try:
        data = request.get_json()
        file_path = data.get('file_path')
        if not file_path:
            return jsonify({'error': 'File path cannot be empty'}), 400

        df, error = load_data_file(file_path)
        if error:
            return jsonify({'error': error}), 400

        # Detect the time frequency of the data
        def detect_timeframe(df):
            if len(df) < 2:
                return "Unknown"
            time_diffs = []
            for i in range(1, min(10, len(df))):  # Check up to the first nine time differences
                diff = df['timestamps'].iloc[i] - df['timestamps'].iloc[i - 1]
                time_diffs.append(diff)
            if not time_diffs:
                return "Unknown"
            # Calculate the average time difference
            avg_diff = sum(time_diffs, pd.Timedelta(0)) / len(time_diffs)
            # Convert to a readable format
            if avg_diff < pd.Timedelta(minutes=1):
                return f"{avg_diff.total_seconds():.0f} seconds"
            elif avg_diff < pd.Timedelta(hours=1):
                return f"{avg_diff.total_seconds() / 60:.0f} minutes"
            elif avg_diff < pd.Timedelta(days=1):
                return f"{avg_diff.total_seconds() / 3600:.0f} hours"
            else:
                return f"{avg_diff.days} days"

        # Return data information
        data_info = {
            'rows': len(df),
            'columns': list(df.columns),
            'start_date': df['timestamps'].min().isoformat() if 'timestamps' in df.columns else 'N/A',
            'end_date': df['timestamps'].max().isoformat() if 'timestamps' in df.columns else 'N/A',
            'price_range': {
                'min': float(df[['open', 'high', 'low', 'close']].min().min()),
                'max': float(df[['open', 'high', 'low', 'close']].max().max())
            },
            'prediction_columns': ['open', 'high', 'low', 'close'] + (['volume'] if 'volume' in df.columns else []),
            'timeframe': detect_timeframe(df)
        }
        return jsonify({
            'success': True,
            'data_info': data_info,
            'message': f'Successfully loaded data, total {len(df)} rows'
        })
    except Exception as e:
        return jsonify({'error': f'Failed to load data: {str(e)}'}), 500
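
# Client sketch for this endpoint (assumptions: the app listens on
# 127.0.0.1:5000, `requests` is installed, and 'data/example.csv' exists;
# none of these is guaranteed by this module itself):
#
# import requests
# r = requests.post('http://127.0.0.1:5000/api/load-data',
#                   json={'file_path': 'data/example.csv'})
# r.json()['data_info']['timeframe']   # e.g. "1 hours" for hourly bars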
  1141. @app.route('/api/predict', methods=['POST'])
  1142. def predict():
  1143. """Perform prediction"""
  1144. try:
  1145. data = request.get_json()
  1146. file_path = data.get('file_path')
  1147. lookback = int(data.get('lookback', 400))
  1148. pred_len = int(data.get('pred_len', 120))
  1149. # Get prediction quality parameters
  1150. temperature = float(data.get('temperature', 1.0))
  1151. top_p = float(data.get('top_p', 0.9))
  1152. sample_count = int(data.get('sample_count', 1))
  1153. if not file_path:
  1154. return jsonify({'error': 'File path cannot be empty'}), 400
  1155. # Load data
  1156. df, error = load_data_file(file_path)
  1157. if error:
  1158. return jsonify({'error': error}), 400
  1159. if len(df) < lookback:
  1160. return jsonify({'error': f'Insufficient data length, need at least {lookback} rows'}), 400
  1161. # Perform prediction
  1162. if MODEL_AVAILABLE and predictor is not None:
  1163. try:
  1164. # Use real Kronos model
  1165. # Only use necessary columns: OHLCV + amount
  1166. required_cols = ['open', 'high', 'low', 'close', 'volume', 'amount']
  1167. # Process time period selection
  1168. start_date = data.get('start_date')
  1169. if start_date:
  1170. # Custom time period - fix logic: use data within selected window
  1171. start_dt = pd.to_datetime(start_date)
  1172. # Find data after start time
  1173. mask = df['timestamps'] >= start_dt
  1174. time_range_df = df[mask]
  1175. # Ensure sufficient data: lookback + pred_len
  1176. if len(time_range_df) < lookback + pred_len:
  1177. return jsonify({
  1178. 'error': f'Insufficient data from start time {start_dt.strftime("%Y-%m-%d %H:%M")}, need at least {lookback + pred_len} data points, currently only {len(time_range_df)} available'}), 400
  1179. # Use first lookback data points within selected window for prediction
  1180. x_df = time_range_df.iloc[:lookback][required_cols]
  1181. x_timestamp = time_range_df.iloc[:lookback]['timestamps']
  1182. # Use last pred_len data points within selected window as actual values
  1183. y_timestamp = time_range_df.iloc[lookback:lookback + pred_len]['timestamps']
  1184. # Calculate actual time period length
  1185. start_timestamp = time_range_df['timestamps'].iloc[0]
  1186. end_timestamp = time_range_df['timestamps'].iloc[lookback + pred_len - 1]
  1187. time_span = end_timestamp - start_timestamp
  1188. prediction_type = f"Kronos model prediction (within selected window: first {lookback} data points for prediction, last {pred_len} data points for comparison, time span: {time_span})"
  1189. else:
  1190. # Use latest data
  1191. x_df = df.iloc[:lookback][required_cols]
  1192. x_timestamp = df.iloc[:lookback]['timestamps']
  1193. y_timestamp = df.iloc[lookback:lookback + pred_len]['timestamps']
  1194. prediction_type = "Kronos model prediction (latest data)"
  1195. # Debug information
  1196. print(f"🔍 传递给predictor的数据列: {x_df.columns.tolist()}")
  1197. print(f"🔍 数据形状: {x_df.shape}")
  1198. print(f"🔍 数据样例:")
  1199. print(x_df.head(2))
  1200. # Ensure timestamps are Series format, not DatetimeIndex, to avoid .dt attribute error in Kronos model
  1201. if isinstance(x_timestamp, pd.DatetimeIndex):
  1202. x_timestamp = pd.Series(x_timestamp, name='timestamps')
  1203. if isinstance(y_timestamp, pd.DatetimeIndex):
  1204. y_timestamp = pd.Series(y_timestamp, name='timestamps')
  1205. pred_df = predictor.predict(
  1206. df=x_df,
  1207. x_timestamp=x_timestamp,
  1208. y_timestamp=y_timestamp,
  1209. pred_len=pred_len,
  1210. T=temperature,
  1211. top_p=top_p,
  1212. sample_count=sample_count
  1213. )
  1214. except Exception as e:
  1215. return jsonify({'error': f'Kronos model prediction failed: {str(e)}'}), 500
  1216. else:
  1217. return jsonify({'error': 'Kronos model not loaded, please load model first'}), 400
  1218. # Prepare actual data for comparison (if exists)
  1219. actual_data = []
  1220. actual_df = None
  1221. if start_date: # Custom time period
  1222. # Fix logic: use data within selected window
  1223. # Prediction uses first 400 data points within selected window
  1224. # Actual data should be last 120 data points within selected window
  1225. start_dt = pd.to_datetime(start_date)
  1226. # Find data starting from start_date
  1227. mask = df['timestamps'] >= start_dt
  1228. time_range_df = df[mask]
  1229. if len(time_range_df) >= lookback + pred_len:
  1230. # Get last 120 data points within selected window as actual values
  1231. actual_df = time_range_df.iloc[lookback:lookback + pred_len]
  1232. for i, (_, row) in enumerate(actual_df.iterrows()):
  1233. actual_data.append({
  1234. 'timestamp': row['timestamps'].isoformat(),
  1235. 'open': float(row['open']),
  1236. 'high': float(row['high']),
  1237. 'low': float(row['low']),
  1238. 'close': float(row['close']),
  1239. 'volume': float(row['volume']) if 'volume' in row else 0,
  1240. 'amount': float(row['amount']) if 'amount' in row else 0
  1241. })
  1242. else: # Latest data
  1243. # Prediction uses first 400 data points
  1244. # Actual data should be 120 data points after first 400 data points
  1245. if len(df) >= lookback + pred_len:
  1246. actual_df = df.iloc[lookback:lookback + pred_len]
  1247. for i, (_, row) in enumerate(actual_df.iterrows()):
  1248. actual_data.append({
  1249. 'timestamp': row['timestamps'].isoformat(),
  1250. 'open': float(row['open']),
  1251. 'high': float(row['high']),
  1252. 'low': float(row['low']),
  1253. 'close': float(row['close']),
  1254. 'volume': float(row['volume']) if 'volume' in row else 0,
  1255. 'amount': float(row['amount']) if 'amount' in row else 0
  1256. })
  1257. # Create chart - pass historical data start position
  1258. if start_date:
  1259. # Custom time period: find starting position of historical data in original df
  1260. start_dt = pd.to_datetime(start_date)
  1261. mask = df['timestamps'] >= start_dt
  1262. historical_start_idx = df[mask].index[0] if len(df[mask]) > 0 else 0
  1263. else:
  1264. # Latest data: start from beginning
  1265. historical_start_idx = 0
  1266. chart_json = create_prediction_chart(df, pred_df, lookback, pred_len, actual_df, historical_start_idx)
        # Prepare prediction result data and compute the future timestamps
        if 'timestamps' in df.columns:
            if start_date:
                # Custom time period: derive timestamps from the selected window
                start_dt = pd.to_datetime(start_date)
                mask = df['timestamps'] >= start_dt
                time_range_df = df[mask]
                if len(time_range_df) >= lookback:
                    # Prediction timestamps continue from the last point of the lookback window
                    last_timestamp = time_range_df['timestamps'].iloc[lookback - 1]
                    time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
                    future_timestamps = pd.date_range(
                        start=last_timestamp + time_diff,
                        periods=pred_len,
                        freq=time_diff
                    )
                else:
                    future_timestamps = []
            else:
                # Latest data: continue from the last time point of the whole file
                last_timestamp = df['timestamps'].iloc[-1]
                time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
                future_timestamps = pd.date_range(
                    start=last_timestamp + time_diff,
                    periods=pred_len,
                    freq=time_diff
                )
        else:
            future_timestamps = range(len(df), len(df) + pred_len)
        prediction_results = []
        for i, (_, row) in enumerate(pred_df.iterrows()):
            prediction_results.append({
                'timestamp': future_timestamps[i].isoformat() if i < len(future_timestamps) else f"T{i}",
                'open': float(row['open']),
                'high': float(row['high']),
                'low': float(row['low']),
                'close': float(row['close']),
                'volume': float(row['volume']) if 'volume' in row else 0,
                'amount': float(row['amount']) if 'amount' in row else 0
            })
        # Save prediction results to file
        try:
            save_prediction_results(
                file_path=file_path,
                prediction_type=prediction_type,
                prediction_results=prediction_results,
                actual_data=actual_data,
                input_data=x_df,
                prediction_params={
                    'lookback': lookback,
                    'pred_len': pred_len,
                    'temperature': temperature,
                    'top_p': top_p,
                    'sample_count': sample_count,
                    'start_date': start_date if start_date else 'latest'
                }
            )
        except Exception as e:
            print(f"Failed to save prediction results: {e}")
        # Log a summary of the response before returning
        print("✅ Prediction finished, returning data:")
        print("  success: True")
        print(f"  prediction type: {prediction_type}")
        print(f"  chart payload length: {len(chart_json)}")
        print(f"  prediction points: {len(prediction_results)}")
        print(f"  actual data points: {len(actual_data)}")
        print(f"  has comparison data: {len(actual_data) > 0}")
        return jsonify({
            'success': True,
            'prediction_type': prediction_type,
            'chart': chart_json,
            'prediction_results': prediction_results,
            'actual_data': actual_data,
            'has_comparison': len(actual_data) > 0,
            'message': f'Prediction completed, generated {pred_len} prediction points' + (
                f', including {len(actual_data)} actual data points for comparison' if len(actual_data) > 0 else '')
        })
    except Exception as e:
        return jsonify({'error': f'Prediction failed: {str(e)}'}), 500
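
# A minimal client-side sketch of calling the prediction endpoint defined
# above (assumptions: the server runs locally on port 7070, the route is
# mounted at /api/predict, and the payload keys mirror the variables the
# handler reads; the data file name is a placeholder):
#
#   import requests
#   resp = requests.post('http://localhost:7070/api/predict', json={
#       'file_path': 'data/Stock_5min_A股.csv',
#       'lookback': 400,
#       'pred_len': 120,
#       'temperature': 1.0,
#       'top_p': 0.9,
#       'sample_count': 1,
#   })
#   result = resp.json()
#   if result.get('success'):
#       print(result['message'])
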
@app.route('/api/load-model', methods=['POST'])
def load_model():
    """Load the Kronos model and tokenizer from local directories."""
    global tokenizer, model, predictor
    try:
        if not MODEL_AVAILABLE:
            return jsonify({'error': 'Kronos model library not available'}), 400
        data = request.get_json()
        model_key = data.get('model_key', 'kronos-small')
        device = data.get('device', 'cpu')
        if model_key not in AVAILABLE_MODELS:
            return jsonify({'error': f'Unsupported model: {model_key}'}), 400
        model_config = AVAILABLE_MODELS[model_key]
        print(f"🚀 Loading model from: {model_config['model_id']}")
        model_path = model_config['model_id']
        tokenizer_path = model_config['tokenizer_id']
        # Verify the model path exists before attempting to load from it
        if not os.path.exists(model_path):
            return jsonify({'error': f'Model path does not exist: {model_path}'}), 400
        model_files = os.listdir(model_path)
        print(f"📄 Files in model directory: {model_files}")
        try:
            # Load the model weights directly from the local directory
            model = Kronos.from_pretrained(
                model_config['model_id'],
                local_files_only=True
            )
            # Prefer loading the tokenizer, with its trained weights, from the
            # local tokenizer directory
            try:
                tokenizer = KronosTokenizer.from_pretrained(
                    tokenizer_path,
                    local_files_only=True
                )
            except Exception:
                # Fallback: rebuild the tokenizer from the model's config file.
                # Note: a tokenizer constructed from the config alone starts
                # with untrained weights, so this path is a last resort.
                config_path = os.path.join(model_config['model_id'], 'config.json')
                if not os.path.exists(config_path):
                    return jsonify({'error': f'Config file not found: {config_path}'}), 400
                print(f"Reading config file: {config_path}")
                with open(config_path, 'r') as f:
                    config = json.load(f)
                for key, value in config.items():
                    print(f"  {key}: {value}")
                # Create the tokenizer with the parameters from the config
                tokenizer = KronosTokenizer(
                    d_in=6,  # open, high, low, close, volume, amount
                    d_model=config['d_model'],
                    n_heads=config['n_heads'],
                    ff_dim=config['ff_dim'],
                    n_enc_layers=config['n_layers'],
                    n_dec_layers=config['n_layers'],
                    ffn_dropout_p=config['ffn_dropout_p'],
                    attn_dropout_p=config['attn_dropout_p'],
                    resid_dropout_p=config['resid_dropout_p'],
                    s1_bits=config['s1_bits'],
                    s2_bits=config['s2_bits'],
                    beta=1.0,
                    gamma0=1.0,
                    gamma=1.0,
                    zeta=1.0,
                    group_size=1
                )
        except Exception as e:
            return jsonify({'error': f'Failed to load model: {str(e)}'}), 500
        # Create the predictor
        predictor = KronosPredictor(
            model,
            tokenizer,
            device=device,
            max_context=model_config['context_length']
        )
        return jsonify({
            'success': True,
            'message': f'Model loaded successfully: {model_config["name"]} ({model_config["params"]}) on {device}',
            'model_info': {
                **model_config,
                'model_path': model_config['model_id'],
                'abs_model_path': os.path.abspath(model_config['model_id']),
                'device': device
            }
        })
    except Exception as e:
        import traceback
        print("[load-model API error]")
        print(f"Error type: {type(e).__name__}")
        print(f"Error message: {str(e)}")
        traceback.print_exc()
        print("=" * 60)
        return jsonify({'error': f'Model loading failed: {str(e)}'}), 500
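
# A minimal client-side sketch of loading a model through the endpoint above
# (assumption: the server runs locally on port 7070; 'kronos-small' is one of
# the keys defined in AVAILABLE_MODELS):
#
#   import requests
#   resp = requests.post('http://localhost:7070/api/load-model',
#                        json={'model_key': 'kronos-small', 'device': 'cpu'})
#   print(resp.json().get('message'))
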
@app.route('/api/available-models')
def get_available_models():
    """Get available model list"""
    return jsonify({
        'models': AVAILABLE_MODELS,
        'model_available': MODEL_AVAILABLE
    })
@app.route('/api/model-status')
def get_model_status():
    """Get model status"""
    if MODEL_AVAILABLE:
        if predictor is not None:
            return jsonify({
                'available': True,
                'loaded': True,
                'message': 'Kronos model loaded and available',
                'current_model': {
                    'name': predictor.model.__class__.__name__,
                    'device': str(next(predictor.model.parameters()).device)
                }
            })
        else:
            return jsonify({
                'available': True,
                'loaded': False,
                'message': 'Kronos model available but not loaded'
            })
    else:
        return jsonify({
            'available': False,
            'loaded': False,
            'message': 'Kronos model library not available, please install related dependencies'
        })
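
# A minimal status-polling sketch (assumption: the server runs locally on
# port 7070):
#
#   import requests
#   status = requests.get('http://localhost:7070/api/model-status').json()
#   if status['available'] and not status['loaded']:
#       print('Model library is present but no model is loaded yet')
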
@app.route('/api/stock-data', methods=['POST'])
def Stock_Data():
    """Fetch A-share 5-minute bars via baostock and save them to the data directory."""
    try:
        data = request.get_json()
        stock_code = data.get('stock_code', '').strip()
        # The stock code must not be empty
        if not stock_code:
            return jsonify({
                'success': False,
                'error': 'Stock code cannot be empty'
            }), 400
        # Validate the stock code format (e.g. sh.600000)
        if not re.match(r'^[a-z]+\.\d+$', stock_code):
            return jsonify({
                'success': False,
                'error': 'The stock code you entered is invalid'
            }), 400
        # Log in to baostock
        lg = bs.login()
        if lg.error_code != '0':
            return jsonify({
                'success': False,
                'error': f'Login failed: {lg.error_msg}'
            }), 400
        end_date = datetime.now().strftime('%Y-%m-%d')
        rs = bs.query_history_k_data_plus(
            stock_code,
            "time,open,high,low,close,volume,amount",
            start_date='2024-06-01',
            end_date=end_date,
            frequency="5",
            adjustflag="3"
        )
        # Check the query result
        if rs.error_code != '0':
            bs.logout()
            return jsonify({
                'success': False,
                'error': 'Failed to retrieve data, please enter a valid stock code'
            }), 400
        # Extract the rows
        data_list = []
        while rs.next():
            data_list.append(rs.get_row_data())
        # Log out
        bs.logout()
        columns = rs.fields
        df = pd.DataFrame(data_list, columns=columns)
        # Rename the time column and convert the numeric columns
        df = df.rename(columns={'time': 'timestamps'})
        numeric_columns = ['open', 'high', 'low', 'close', 'volume', 'amount']
        for col in numeric_columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        # baostock returns the time field as a YYYYMMDDHHMMSSfff string
        df['timestamps'] = pd.to_datetime(df['timestamps'].astype(str), format='%Y%m%d%H%M%S%f')
        # Drop invalid rows
        df = df.dropna()
        # Save to the data directory
        data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')
        os.makedirs(data_dir, exist_ok=True)
        filename = "Stock_5min_A股.csv"
        file_path = os.path.join(data_dir, filename)
        df.to_csv(
            file_path,
            index=False,
            encoding='utf-8',
            mode='w'
        )
        return jsonify({
            'success': True,
            'message': f'Stock data saved successfully: {filename}',
            'file_name': filename
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'error': f'Error processing stock data: {str(e)}'
        }), 500
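
# A minimal client-side sketch of fetching stock data through the endpoint
# above (assumption: the server runs locally on port 7070; 'sh.600000' is an
# example baostock code in the exchange.ticker format the handler validates):
#
#   import requests
#   resp = requests.post('http://localhost:7070/api/stock-data',
#                        json={'stock_code': 'sh.600000'})
#   print(resp.json())
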
@app.route('/api/generate-chart', methods=['POST'])
def generate_chart():
    """Generate a technical chart (without predictions) for a slice of a data file."""
    try:
        data = request.get_json()
        # Validate required parameters
        required_fields = ['file_path', 'lookback', 'diagram_type', 'historical_start_idx']
        for field in required_fields:
            if field not in data:
                return jsonify({'success': False, 'error': f'Missing required field: {field}'}), 400
        # Parse parameters
        file_path = data['file_path']
        lookback = int(data['lookback'])
        diagram_type = data['diagram_type']
        historical_start_idx = int(data['historical_start_idx'])
        # Load the data file
        df, error = load_data_file(file_path)
        if error:
            return jsonify({'success': False, 'error': error}), 400
        if len(df) < lookback + historical_start_idx:
            return jsonify({
                'success': False,
                'error': f'Insufficient data length, need at least {lookback + historical_start_idx} rows'
            }), 400
        pred_df = None
        actual_df = None
        # Generate the chart; pred_len is 0 because there is no prediction overlay
        chart_json = create_technical_chart(
            df=df,
            pred_df=pred_df,
            lookback=lookback,
            pred_len=0,
            diagram_type=diagram_type,
            actual_df=actual_df,
            historical_start_idx=historical_start_idx
        )
        # Table data for the selected window
        table_data_start = historical_start_idx
        table_data_end = historical_start_idx + lookback
        table_df = df.iloc[table_data_start:table_data_end]
        table_data = table_df.to_dict('records')
        return jsonify({
            'success': True,
            'chart': json.loads(chart_json),
            'table_data': table_data,
            'message': 'Technical chart generated successfully'
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'error': f'Failed to generate technical chart: {str(e)}'
        }), 500
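
# A minimal client-side sketch of requesting a technical chart (assumptions:
# the server runs locally on port 7070, and the file name and diagram type
# shown here are placeholders for values the UI normally supplies):
#
#   import requests
#   resp = requests.post('http://localhost:7070/api/generate-chart', json={
#       'file_path': 'data/Stock_5min_A股.csv',
#       'lookback': 400,
#       'diagram_type': 'kline',
#       'historical_start_idx': 0,
#   })
#   print(resp.json().get('message'))
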
if __name__ == '__main__':
    print("Starting Kronos Web UI...")
    print(f"Model availability: {MODEL_AVAILABLE}")
    if MODEL_AVAILABLE:
        print("Tip: You can load the Kronos model through the /api/load-model endpoint")
    else:
        print("Tip: Will use simulated data for demonstration")
    app.run(debug=True, host='0.0.0.0', port=7070)
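
# To run the server locally (a sketch; flask, flask-cors, baostock, pandas,
# numpy and plotly must be installed; the script file name app.py is an
# assumption):
#
#   python app.py
#   # then open http://localhost:7070 in a browser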