import os
import sys
import re
import json
import warnings
from datetime import datetime

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.utils
import baostock as bs
from flask import Flask, render_template, request, jsonify
from flask_cors import CORS
warnings.filterwarnings('ignore')
# Add project root directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
try:
    from model import Kronos, KronosTokenizer, KronosPredictor
    MODEL_AVAILABLE = True
except ImportError:
    MODEL_AVAILABLE = False
    print("Warning: Kronos model cannot be imported, will use simulated data for demonstration")
app = Flask(__name__)
CORS(app)
# Global variables to store the loaded model
tokenizer = None
model = None
predictor = None
# Path of the webui directory
WEBUI_DIR = os.path.dirname(os.path.abspath(__file__))
# Project root directory (parent of the webui directory)
BASE_DIR = os.path.dirname(WEBUI_DIR)
# Available model configurations
AVAILABLE_MODELS = {
    'kronos-mini': {
        'name': 'Kronos-mini',
        'model_id': os.path.join(BASE_DIR, 'models', 'Kronos-mini'),
        'tokenizer_id': os.path.join(BASE_DIR, 'models', 'Kronos-Tokenizer-base'),
        'context_length': 2048,
        'params': '4.1M',
        'description': 'Lightweight model, suitable for fast predictions'
    },
    'kronos-small': {
        'name': 'Kronos-small',
        'model_id': os.path.join(BASE_DIR, 'models', 'NeoQuasarKronos-small'),
        'tokenizer_id': os.path.join(BASE_DIR, 'models', 'Kronos-Tokenizer-base'),
        'context_length': 512,
        'params': '24.7M',
        'description': 'Small model, balances performance and speed'
    },
    'kronos-base': {
        'name': 'Kronos-base',
        'model_id': os.path.join(BASE_DIR, 'models', 'NeoQuasarKronos-base'),
        'tokenizer_id': os.path.join(BASE_DIR, 'models', 'Kronos-Tokenizer-base'),
        'context_length': 512,
        'params': '102.3M',
        'description': 'Base model, provides better prediction quality'
    }
}
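# The entries above point at local checkouts; the assumed on-disk layout
# (inferred from the model_id/tokenizer_id paths, not verified here) is:
#
#   <BASE_DIR>/models/Kronos-mini/
#   <BASE_DIR>/models/NeoQuasarKronos-small/
#   <BASE_DIR>/models/NeoQuasarKronos-base/
#   <BASE_DIR>/models/Kronos-Tokenizer-base/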
def load_data_files():
    """Scan the data directory and return the available data files"""
    data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')
    data_files = []

    if os.path.exists(data_dir):
        for file in os.listdir(data_dir):
            if file.endswith(('.csv', '.feather')):
                file_path = os.path.join(data_dir, file)
                file_size = os.path.getsize(file_path)
                data_files.append({
                    'name': file,
                    'path': file_path,
                    'size': (f"{file_size / 1024:.1f} KB" if file_size < 1024 * 1024
                             else f"{file_size / (1024 * 1024):.1f} MB")
                })

    return data_files

def load_data_file(file_path):
    """Load a data file"""
    try:
        if file_path.endswith('.csv'):
            df = pd.read_csv(file_path)
        elif file_path.endswith('.feather'):
            df = pd.read_feather(file_path)
        else:
            return None, "Unsupported file format"

        # Check required columns
        required_cols = ['open', 'high', 'low', 'close']
        if not all(col in df.columns for col in required_cols):
            return None, f"Missing required columns: {required_cols}"

        # Process the timestamp column
        if 'timestamps' in df.columns:
            df['timestamps'] = pd.to_datetime(df['timestamps'])
        elif 'timestamp' in df.columns:
            df['timestamps'] = pd.to_datetime(df['timestamp'])
        elif 'date' in df.columns:
            # If the column is named 'date', derive 'timestamps' from it
            df['timestamps'] = pd.to_datetime(df['date'])
        else:
            # If no timestamp column exists, create one
            df['timestamps'] = pd.date_range(start='2024-01-01', periods=len(df), freq='1H')

        # Ensure numeric columns are numeric type
        for col in ['open', 'high', 'low', 'close']:
            df[col] = pd.to_numeric(df[col], errors='coerce')

        # Process the volume column (optional)
        if 'volume' in df.columns:
            df['volume'] = pd.to_numeric(df['volume'], errors='coerce')

        # Process the amount column (optional, not used for prediction)
        if 'amount' in df.columns:
            df['amount'] = pd.to_numeric(df['amount'], errors='coerce')

        # Remove rows containing NaN values and reset the index,
        # so positional and label-based indexing agree downstream
        df = df.dropna().reset_index(drop=True)

        return df, None
    except Exception as e:
        return None, f"Failed to load file: {str(e)}"

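# A minimal sketch of the CSV layout load_data_file() accepts (column names are
# taken from the checks above; 'volume' and 'amount' are optional, and a
# 'timestamp' or 'date' column is accepted in place of 'timestamps'):
#
#   timestamps,open,high,low,close,volume,amount
#   2024-01-01 09:30:00,10.00,10.20,9.95,10.15,120000,1218000.0
#   2024-01-01 10:30:00,10.15,10.40,10.10,10.35,98000,1009300.0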
def save_prediction_results(file_path, prediction_type, prediction_results, actual_data, input_data, prediction_params):
    """Save prediction results to a JSON file"""
    try:
        # Create the prediction results directory
        results_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prediction_results')
        os.makedirs(results_dir, exist_ok=True)

        # Generate the filename
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        filename = f'prediction_{timestamp}.json'
        filepath = os.path.join(results_dir, filename)

        # Prepare the data for saving
        save_data = {
            'timestamp': datetime.now().isoformat(),
            'file_path': file_path,
            'prediction_type': prediction_type,
            'prediction_params': prediction_params,
            'input_data_summary': {
                'rows': len(input_data),
                'columns': list(input_data.columns),
                'price_range': {
                    'open': {'min': float(input_data['open'].min()), 'max': float(input_data['open'].max())},
                    'high': {'min': float(input_data['high'].min()), 'max': float(input_data['high'].max())},
                    'low': {'min': float(input_data['low'].min()), 'max': float(input_data['low'].max())},
                    'close': {'min': float(input_data['close'].min()), 'max': float(input_data['close'].max())}
                },
                'last_values': {
                    'open': float(input_data['open'].iloc[-1]),
                    'high': float(input_data['high'].iloc[-1]),
                    'low': float(input_data['low'].iloc[-1]),
                    'close': float(input_data['close'].iloc[-1])
                }
            },
            'prediction_results': prediction_results,
            'actual_data': actual_data,
            'analysis': {}
        }

        # If actual data exists, analyze the gap between prediction and reality
        if actual_data and len(actual_data) > 0:
            # Continuity analysis: compare the first prediction point with the first actual point
            if len(prediction_results) > 0 and len(actual_data) > 0:
                first_pred = prediction_results[0]  # First prediction point
                first_actual = actual_data[0]  # First actual point
                save_data['analysis']['continuity'] = {
                    'first_prediction': {
                        'open': first_pred['open'],
                        'high': first_pred['high'],
                        'low': first_pred['low'],
                        'close': first_pred['close']
                    },
                    'first_actual': {
                        'open': first_actual['open'],
                        'high': first_actual['high'],
                        'low': first_actual['low'],
                        'close': first_actual['close']
                    },
                    'gaps': {
                        'open_gap': abs(first_pred['open'] - first_actual['open']),
                        'high_gap': abs(first_pred['high'] - first_actual['high']),
                        'low_gap': abs(first_pred['low'] - first_actual['low']),
                        'close_gap': abs(first_pred['close'] - first_actual['close'])
                    },
                    'gap_percentages': {
                        'open_gap_pct': (abs(first_pred['open'] - first_actual['open']) / first_actual['open']) * 100,
                        'high_gap_pct': (abs(first_pred['high'] - first_actual['high']) / first_actual['high']) * 100,
                        'low_gap_pct': (abs(first_pred['low'] - first_actual['low']) / first_actual['low']) * 100,
                        'close_gap_pct': (abs(first_pred['close'] - first_actual['close']) / first_actual['close']) * 100
                    }
                }

        # Save to file
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(save_data, f, indent=2, ensure_ascii=False)

        print(f"Prediction results saved to: {filepath}")
        return filepath

    except Exception as e:
        print(f"Failed to save prediction results: {e}")
        return None

def create_prediction_chart(df, pred_df, lookback, pred_len, actual_df=None, historical_start_idx=0):
    """Create the prediction chart"""
    print("🔍 Chart creation debug:")
    print(f"  Historical data: {len(df) if df is not None else 0} rows")
    print(f"  Prediction data: {len(pred_df) if pred_df is not None else 0} rows")
    print(f"  Actual data: {len(actual_df) if actual_df is not None else 0} rows")
    # Make sure the prediction data is not empty
    if pred_df is None or len(pred_df) == 0:
        print("⚠️ Warning: prediction data is empty!")
        # Return an empty chart
        fig = go.Figure()
        fig.update_layout(title='No prediction data available')
        return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
    # Use the specified historical start position, not always the beginning of df
    if historical_start_idx + lookback + pred_len <= len(df):
        # Display lookback historical points + pred_len prediction points from the specified position
        historical_df = df.iloc[historical_start_idx:historical_start_idx + lookback]
        prediction_range = range(historical_start_idx + lookback,
                                 historical_start_idx + lookback + pred_len)
    else:
        # If data is insufficient, fall back to the maximum available range
        available_lookback = min(lookback, len(df) - historical_start_idx)
        available_pred_len = min(pred_len, max(0, len(df) - historical_start_idx - available_lookback))
        historical_df = df.iloc[historical_start_idx:historical_start_idx + available_lookback]
        prediction_range = range(historical_start_idx + available_lookback,
                                 historical_start_idx + available_lookback + available_pred_len)
    # Create the chart
    fig = go.Figure()
    # Add historical data (candlestick chart)
    fig.add_trace(go.Candlestick(
        x=(historical_df['timestamps'].tolist() if 'timestamps' in historical_df.columns
           else historical_df.index.tolist()),
        open=historical_df['open'].tolist(),
        high=historical_df['high'].tolist(),
        low=historical_df['low'].tolist(),
        close=historical_df['close'].tolist(),
        name='Historical Data (400 data points)',
        increasing_line_color='#26A69A',
        decreasing_line_color='#EF5350'
    ))
    # Add prediction data (candlestick chart)
    if pred_df is not None and len(pred_df) > 0:
        # Derive prediction timestamps that continue seamlessly from the historical data
        if 'timestamps' in df.columns and len(historical_df) > 0:
            # Start after the last historical timestamp, reusing the data's time interval
            last_timestamp = historical_df['timestamps'].iloc[-1]
            time_diff = (df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
                         if len(df) > 1 else pd.Timedelta(hours=1))
            pred_timestamps = pd.date_range(
                start=last_timestamp + time_diff,
                periods=len(pred_df),
                freq=time_diff
            )
        else:
            # Without timestamps, fall back to positional indices
            pred_timestamps = range(len(historical_df), len(historical_df) + len(pred_df))
        fig.add_trace(go.Candlestick(
            x=pred_timestamps.tolist() if hasattr(pred_timestamps, 'tolist') else list(pred_timestamps),
            open=pred_df['open'].tolist(),
            high=pred_df['high'].tolist(),
            low=pred_df['low'].tolist(),
            close=pred_df['close'].tolist(),
            name='Prediction Data (120 data points)',
            increasing_line_color='#66BB6A',
            decreasing_line_color='#FF7043'
        ))
    # Add actual data for comparison (if present)
    if actual_df is not None and len(actual_df) > 0:
        # Actual data covers the same time period as the prediction data
        if 'timestamps' in df.columns:
            # Reuse the prediction timestamps to keep the two series time-aligned
            if 'pred_timestamps' in locals():
                actual_timestamps = pred_timestamps
            else:
                # Otherwise derive timestamps from the last historical timestamp
                if len(historical_df) > 0:
                    last_timestamp = historical_df['timestamps'].iloc[-1]
                    time_diff = (df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
                                 if len(df) > 1 else pd.Timedelta(hours=1))
                    actual_timestamps = pd.date_range(
                        start=last_timestamp + time_diff,
                        periods=len(actual_df),
                        freq=time_diff
                    )
                else:
                    actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
        else:
            actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
        fig.add_trace(go.Candlestick(
            x=actual_timestamps.tolist() if hasattr(actual_timestamps, 'tolist') else list(actual_timestamps),
            open=actual_df['open'].tolist(),
            high=actual_df['high'].tolist(),
            low=actual_df['low'].tolist(),
            close=actual_df['close'].tolist(),
            name='Actual Data (120 data points)',
            increasing_line_color='#FF9800',
            decreasing_line_color='#F44336'
        ))
    # Update layout
    fig.update_layout(
        title='Kronos Financial Prediction Results - 400 Historical Points + 120 Prediction Points vs 120 Actual Points',
        xaxis_title='Time',
        yaxis_title='Price',
        template='plotly_white',
        height=600,
        showlegend=True
    )
    # Ensure x-axis time continuity
    if 'timestamps' in historical_df.columns:
        # Collect and sort all timestamps
        all_timestamps = []
        if len(historical_df) > 0:
            all_timestamps.extend(historical_df['timestamps'].tolist())
        if 'pred_timestamps' in locals():
            all_timestamps.extend(
                pred_timestamps.tolist() if hasattr(pred_timestamps, 'tolist') else list(pred_timestamps))
        if 'actual_timestamps' in locals():
            all_timestamps.extend(
                actual_timestamps.tolist() if hasattr(actual_timestamps, 'tolist') else list(actual_timestamps))

        if all_timestamps:
            all_timestamps = sorted(all_timestamps)
            fig.update_xaxes(
                range=[all_timestamps[0], all_timestamps[-1]],
                rangeslider_visible=False,
                type='date'
            )
    try:
        chart_json = fig.to_json()
        print(f"✅ Chart JSON serialized, length: {len(chart_json)}")
        return chart_json
    except Exception as e:
        print(f"❌ Chart serialization failed: {e}")
        # Return a simple error chart instead
        error_fig = go.Figure()
        error_fig.update_layout(title='Chart Rendering Error')
        return error_fig.to_json()

# Compute technical indicators
def calculate_indicators(df):
    indicators = {}

    # Moving averages (MA)
    indicators['ma5'] = df['close'].rolling(window=5).mean()
    indicators['ma10'] = df['close'].rolling(window=10).mean()
    indicators['ma20'] = df['close'].rolling(window=20).mean()

    # MACD: difference of the 12- and 26-period EMAs, with a 9-period signal line
    exp12 = df['close'].ewm(span=12, adjust=False).mean()
    exp26 = df['close'].ewm(span=26, adjust=False).mean()
    indicators['macd'] = exp12 - exp26
    indicators['signal'] = indicators['macd'].ewm(span=9, adjust=False).mean()
    indicators['macd_hist'] = indicators['macd'] - indicators['signal']

    # RSI: 100 - 100 / (1 + RS), where RS = average gain / average loss over 14 periods
    delta = df['close'].diff()
    gain = delta.where(delta > 0, 0).rolling(window=14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
    rs = gain / loss
    indicators['rsi'] = 100 - (100 / (1 + rs))

    # Bollinger Bands: 20-period mean +/- 2 standard deviations
    indicators['bb_mid'] = df['close'].rolling(window=20).mean()
    indicators['bb_std'] = df['close'].rolling(window=20).std()
    indicators['bb_upper'] = indicators['bb_mid'] + 2 * indicators['bb_std']
    indicators['bb_lower'] = indicators['bb_mid'] - 2 * indicators['bb_std']

    # Stochastic oscillator
    low_min = df['low'].rolling(window=14).min()
    high_max = df['high'].rolling(window=14).max()
    indicators['stoch_k'] = 100 * ((df['close'] - low_min) / (high_max - low_min))
    indicators['stoch_d'] = indicators['stoch_k'].rolling(window=3).mean()

    # Rolling window mean strategy
    indicators['rwms_window'] = 90
    indicators['rwms_mean'] = df['close'].rolling(window=indicators['rwms_window']).mean()
    indicators['rwms_signal'] = (df['close'] > indicators['rwms_mean']).astype(int)

    # Triple exponential average (TRIX) strategy
    ema1 = df['close'].ewm(span=12, adjust=False).mean()
    ema2 = ema1.ewm(span=12, adjust=False).mean()
    ema3 = ema2.ewm(span=12, adjust=False).mean()
    indicators['trix'] = (ema3 - ema3.shift(1)) / ema3.shift(1) * 100
    indicators['trix_signal'] = indicators['trix'].ewm(span=9, adjust=False).mean()

    return indicators

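# A minimal usage sketch (assumes a DataFrame with 'open'/'high'/'low'/'close'
# columns, e.g. one returned by load_data_file(); meant for interactive use,
# not executed at import time):
#
#   sample_df, err = load_data_file('data/example.csv')  # hypothetical file
#   indicators = calculate_indicators(sample_df)
#   print(indicators['rsi'].dropna().tail())             # last RSI readings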
# Draw technical indicator charts
def create_technical_chart(df, pred_df, lookback, pred_len, diagram_type, actual_df=None, historical_start_idx=0):
    print(f"  🔍 Data content: {len(df) if df is not None else 0} rows")
    print(f"  🔍 Chart type: {diagram_type}")
    # Select the data range
    if historical_start_idx + lookback <= len(df):
        historical_df = df.iloc[historical_start_idx:historical_start_idx + lookback]
    else:
        available_lookback = min(lookback, len(df) - historical_start_idx)
        historical_df = df.iloc[historical_start_idx:historical_start_idx + available_lookback]

    # Compute the indicators
    historical_indicators = calculate_indicators(historical_df)

    # Shared x-axis values for the historical traces
    x_values = (historical_df['timestamps'].tolist() if 'timestamps' in historical_df.columns
                else historical_df.index.tolist())

    fig = go.Figure()
    # Volume chart
    if diagram_type == 'Volume Chart (VOL)':
        fig.add_trace(go.Bar(
            x=x_values,
            y=historical_df['volume'].tolist() if 'volume' in historical_df.columns else [],
            name='Historical Volume',
            marker_color='#42A5F5'
        ))

        if actual_df is not None and len(actual_df) > 0 and 'volume' in actual_df.columns:
            if 'timestamps' in df.columns and len(historical_df) > 0:
                last_timestamp = historical_df['timestamps'].iloc[-1]
                time_diff = (df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
                             if len(df) > 1 else pd.Timedelta(hours=1))
                actual_timestamps = pd.date_range(start=last_timestamp + time_diff,
                                                  periods=len(actual_df), freq=time_diff)
            else:
                actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))

            fig.add_trace(go.Bar(
                x=actual_timestamps.tolist() if hasattr(actual_timestamps, 'tolist') else list(actual_timestamps),
                y=actual_df['volume'].tolist(),
                name='Actual Volume',
                marker_color='#FF9800'
            ))

        fig.update_layout(yaxis_title='Volume')
    # Moving averages
    elif diagram_type == 'Moving Average (MA)':
        fig.add_trace(go.Scatter(
            x=x_values,
            y=historical_indicators['ma5'].tolist(),
            name='MA5',
            line=dict(color='#26A69A', width=1)
        ))
        fig.add_trace(go.Scatter(
            x=x_values,
            y=historical_indicators['ma10'].tolist(),
            name='MA10',
            line=dict(color='#42A5F5', width=1)
        ))
        fig.add_trace(go.Scatter(
            x=x_values,
            y=historical_indicators['ma20'].tolist(),
            name='MA20',
            line=dict(color='#7E57C2', width=1)
        ))

        fig.add_trace(go.Scatter(
            x=x_values,
            y=historical_df['close'].tolist(),
            name='Close Price',
            line=dict(color='#212121', width=1, dash='dash')
        ))

        fig.update_layout(yaxis_title='Price')
    # MACD indicator
    elif diagram_type == 'MACD Indicator (MACD)':
        fig.add_trace(go.Scatter(
            x=x_values,
            y=historical_indicators['macd'].tolist(),
            name='MACD',
            line=dict(color='#26A69A', width=1)
        ))
        fig.add_trace(go.Scatter(
            x=x_values,
            y=historical_indicators['signal'].tolist(),
            name='Signal',
            line=dict(color='#EF5350', width=1)
        ))
        fig.add_trace(go.Bar(
            x=x_values,
            y=historical_indicators['macd_hist'].tolist(),
            name='MACD Histogram',
            marker_color='#42A5F5'
        ))

        fig.add_hline(y=0, line_dash="dash", line_color="gray")
        fig.update_layout(yaxis_title='MACD')
    # RSI indicator
    elif diagram_type == 'RSI Indicator (RSI)':
        fig.add_trace(go.Scatter(
            x=x_values,
            y=historical_indicators['rsi'].tolist(),
            name='RSI',
            line=dict(color='#26A69A', width=1)
        ))

        fig.add_hline(y=70, line_dash="dash", line_color="red", name='Overbought')
        fig.add_hline(y=30, line_dash="dash", line_color="green", name='Oversold')
        fig.update_layout(yaxis_title='RSI', yaxis_range=[0, 100])
    # Bollinger Bands
    elif diagram_type == 'Bollinger Bands (BB)':
        fig.add_trace(go.Scatter(
            x=x_values,
            y=historical_indicators['bb_upper'].tolist(),
            name='Upper Band',
            line=dict(color='#EF5350', width=1)
        ))
        fig.add_trace(go.Scatter(
            x=x_values,
            y=historical_indicators['bb_mid'].tolist(),
            name='Middle Band (MA20)',
            line=dict(color='#42A5F5', width=1)
        ))
        fig.add_trace(go.Scatter(
            x=x_values,
            y=historical_indicators['bb_lower'].tolist(),
            name='Lower Band',
            line=dict(color='#26A69A', width=1)
        ))

        fig.add_trace(go.Scatter(
            x=x_values,
            y=historical_df['close'].tolist(),
            name='Close Price',
            line=dict(color='#212121', width=1)
        ))

        fig.update_layout(yaxis_title='Price')
    # Stochastic oscillator
    elif diagram_type == 'Stochastic Oscillator (STOCH)':
        fig.add_trace(go.Scatter(
            x=x_values,
            y=historical_indicators['stoch_k'].tolist(),
            name='%K',
            line=dict(color='#26A69A', width=1)
        ))
        fig.add_trace(go.Scatter(
            x=x_values,
            y=historical_indicators['stoch_d'].tolist(),
            name='%D',
            line=dict(color='#EF5350', width=1)
        ))

        fig.add_hline(y=80, line_dash="dash", line_color="red", name='Overbought')
        fig.add_hline(y=20, line_dash="dash", line_color="green", name='Oversold')
        fig.update_layout(yaxis_title='Stochastic', yaxis_range=[0, 100])
    # Rolling window mean strategy
    elif diagram_type == 'Rolling Window Mean Strategy':
        fig.add_trace(go.Scatter(
            x=x_values,
            y=historical_df['close'].tolist(),
            name='Close Price',
            line=dict(color='#212121', width=1.5)
        ))

        fig.add_trace(go.Scatter(
            x=x_values,
            y=historical_indicators['rwms_mean'].tolist(),
            name=f'Rolling Mean ({historical_indicators["rwms_window"]} periods)',
            line=dict(color='#42A5F5', width=1.5, dash='dash')
        ))

        # Mark periods where the close is above the rolling mean as buy signals
        buy_signals = historical_df[historical_indicators['rwms_signal'] == 1]
        fig.add_trace(go.Scatter(
            x=(buy_signals['timestamps'].tolist() if 'timestamps' in buy_signals.columns
               else buy_signals.index.tolist()),
            y=buy_signals['close'].tolist(),
            mode='markers',
            name='Buy Signal',
            marker=dict(color='#26A69A', size=8, symbol='triangle-up')
        ))

        # Periods at or below the rolling mean count as sell signals
        sell_signals = historical_df[historical_indicators['rwms_signal'] == 0]
        fig.add_trace(go.Scatter(
            x=(sell_signals['timestamps'].tolist() if 'timestamps' in sell_signals.columns
               else sell_signals.index.tolist()),
            y=sell_signals['close'].tolist(),
            mode='markers',
            name='Sell Signal',
            marker=dict(color='#EF5350', size=8, symbol='triangle-down')
        ))

        fig.update_layout(
            yaxis_title='Price',
            title=f'Rolling Window Mean Strategy (Window Size: {historical_indicators["rwms_window"]})'
        )
    # TRIX indicator chart
    elif diagram_type == 'TRIX Indicator (TRIX)':
        fig.add_trace(go.Scatter(
            x=x_values,
            y=historical_indicators['trix'].tolist(),
            name='TRIX',
            line=dict(color='#26A69A', width=1)
        ))
        fig.add_trace(go.Scatter(
            x=x_values,
            y=historical_indicators['trix_signal'].tolist(),
            name='TRIX Signal',
            line=dict(color='#EF5350', width=1)
        ))

        fig.add_hline(y=0, line_dash="dash", line_color="gray")

        fig.update_layout(
            yaxis_title='TRIX (%)',
            title='Triple Exponential Average (TRIX) Strategy'
        )
    # Layout settings
    fig.update_layout(
        title=f'{diagram_type} - Technical Indicator',
        xaxis_title='Time',
        template='plotly_white',
        height=400,
        showlegend=True,
        margin=dict(t=50, b=30)
    )

    if 'timestamps' in historical_df.columns:
        all_timestamps = historical_df['timestamps'].tolist()

        if actual_df is not None and len(actual_df) > 0 and 'timestamps' in df.columns:
            if 'actual_timestamps' in locals():
                all_timestamps.extend(
                    actual_timestamps.tolist() if hasattr(actual_timestamps, 'tolist') else list(actual_timestamps))

        if all_timestamps:
            all_timestamps = sorted(all_timestamps)
            fig.update_xaxes(
                range=[all_timestamps[0], all_timestamps[-1]],
                rangeslider_visible=False,
                type='date'
            )

    try:
        chart_json = fig.to_json()
        print(f"✅ Technical chart JSON serialized, length: {len(chart_json)}")
        return chart_json
    except Exception as e:
        print(f"❌ Technical chart serialization failed: {e}")
        error_fig = go.Figure()
        error_fig.update_layout(title='Chart Rendering Error')
        return error_fig.to_json()

@app.route('/')
def index():
    """Home page"""
    return render_template('index.html')

@app.route('/api/data-files')
def get_data_files():
    """Get the list of available data files"""
    data_files = load_data_files()
    return jsonify(data_files)

@app.route('/api/load-data', methods=['POST'])
def load_data():
    """Load a data file"""
    try:
        data = request.get_json()
        file_path = data.get('file_path')

        if not file_path:
            return jsonify({'error': 'File path cannot be empty'}), 400

        df, error = load_data_file(file_path)
        if error:
            return jsonify({'error': error}), 400

        # Detect the data's time frequency
        def detect_timeframe(df):
            if len(df) < 2:
                return "Unknown"

            time_diffs = []
            for i in range(1, min(10, len(df))):  # Check the first 10 time differences
                diff = df['timestamps'].iloc[i] - df['timestamps'].iloc[i - 1]
                time_diffs.append(diff)

            if not time_diffs:
                return "Unknown"

            # Calculate the average time difference
            avg_diff = sum(time_diffs, pd.Timedelta(0)) / len(time_diffs)

            # Convert to a readable format
            if avg_diff < pd.Timedelta(minutes=1):
                return f"{avg_diff.total_seconds():.0f} seconds"
            elif avg_diff < pd.Timedelta(hours=1):
                return f"{avg_diff.total_seconds() / 60:.0f} minutes"
            elif avg_diff < pd.Timedelta(days=1):
                return f"{avg_diff.total_seconds() / 3600:.0f} hours"
            else:
                return f"{avg_diff.days} days"

        # Return data information
        data_info = {
            'rows': len(df),
            'columns': list(df.columns),
            'start_date': df['timestamps'].min().isoformat() if 'timestamps' in df.columns else 'N/A',
            'end_date': df['timestamps'].max().isoformat() if 'timestamps' in df.columns else 'N/A',
            'price_range': {
                'min': float(df[['open', 'high', 'low', 'close']].min().min()),
                'max': float(df[['open', 'high', 'low', 'close']].max().max())
            },
            'prediction_columns': ['open', 'high', 'low', 'close'] + (['volume'] if 'volume' in df.columns else []),
            'timeframe': detect_timeframe(df)
        }

        return jsonify({
            'success': True,
            'data_info': data_info,
            'message': f'Successfully loaded data, total {len(df)} rows'
        })

    except Exception as e:
        return jsonify({'error': f'Failed to load data: {str(e)}'}), 500

@app.route('/api/predict', methods=['POST'])
def predict():
    """Perform prediction"""
    try:
        data = request.get_json()
        file_path = data.get('file_path')
        lookback = int(data.get('lookback', 400))
        pred_len = int(data.get('pred_len', 120))

        # Get prediction quality parameters
        temperature = float(data.get('temperature', 1.0))
        top_p = float(data.get('top_p', 0.9))
        sample_count = int(data.get('sample_count', 1))

        if not file_path:
            return jsonify({'error': 'File path cannot be empty'}), 400

        # Load data
        df, error = load_data_file(file_path)
        if error:
            return jsonify({'error': error}), 400

        if len(df) < lookback:
            return jsonify({'error': f'Insufficient data length, need at least {lookback} rows'}), 400
        # Perform prediction
        if MODEL_AVAILABLE and predictor is not None:
            try:
                # Use the real Kronos model with all six input features: OHLCV + amount
                required_cols = ['open', 'high', 'low', 'close', 'volume', 'amount']
                missing_cols = [col for col in required_cols if col not in df.columns]
                if missing_cols:
                    return jsonify({'error': f'Data file is missing required columns: {missing_cols}'}), 400
                # Process time period selection
                start_date = data.get('start_date')

                if start_date:
                    # Custom time period: use data within the selected window
                    start_dt = pd.to_datetime(start_date)

                    # Find data after the start time
                    mask = df['timestamps'] >= start_dt
                    time_range_df = df[mask]

                    # Ensure sufficient data: lookback + pred_len
                    if len(time_range_df) < lookback + pred_len:
                        return jsonify({
                            'error': (f'Insufficient data from start time {start_dt.strftime("%Y-%m-%d %H:%M")}, '
                                      f'need at least {lookback + pred_len} data points, '
                                      f'currently only {len(time_range_df)} available')}), 400

                    # Use the first lookback data points within the selected window for prediction
                    x_df = time_range_df.iloc[:lookback][required_cols]
                    x_timestamp = time_range_df.iloc[:lookback]['timestamps']

                    # Use the last pred_len data points within the selected window as actual values
                    y_timestamp = time_range_df.iloc[lookback:lookback + pred_len]['timestamps']

                    # Calculate the time span actually covered
                    start_timestamp = time_range_df['timestamps'].iloc[0]
                    end_timestamp = time_range_df['timestamps'].iloc[lookback + pred_len - 1]
                    time_span = end_timestamp - start_timestamp

                    prediction_type = (f"Kronos model prediction (within selected window: "
                                       f"first {lookback} data points for prediction, "
                                       f"last {pred_len} data points for comparison, "
                                       f"time span: {time_span})")
                else:
                    # Use the latest data
                    x_df = df.iloc[:lookback][required_cols]
                    x_timestamp = df.iloc[:lookback]['timestamps']
                    y_timestamp = df.iloc[lookback:lookback + pred_len]['timestamps']
                    prediction_type = "Kronos model prediction (latest data)"
                # Debug information
                print(f"🔍 Columns passed to the predictor: {x_df.columns.tolist()}")
                print(f"🔍 Data shape: {x_df.shape}")
                print("🔍 Data sample:")
                print(x_df.head(2))
                # Ensure timestamps are Series, not DatetimeIndex, to avoid a .dt attribute error in the Kronos model
                if isinstance(x_timestamp, pd.DatetimeIndex):
                    x_timestamp = pd.Series(x_timestamp, name='timestamps')
                if isinstance(y_timestamp, pd.DatetimeIndex):
                    y_timestamp = pd.Series(y_timestamp, name='timestamps')

                pred_df = predictor.predict(
                    df=x_df,
                    x_timestamp=x_timestamp,
                    y_timestamp=y_timestamp,
                    pred_len=pred_len,
                    T=temperature,
                    top_p=top_p,
                    sample_count=sample_count
                )
            except Exception as e:
                return jsonify({'error': f'Kronos model prediction failed: {str(e)}'}), 500
        else:
            return jsonify({'error': 'Kronos model not loaded, please load model first'}), 400
        # Prepare actual data for comparison (if present)
        actual_data = []
        actual_df = None

        if start_date:  # Custom time period
            # Prediction used the first lookback points of the selected window,
            # so the actual data is the following pred_len points of that window
            start_dt = pd.to_datetime(start_date)

            # Find data starting from start_date
            mask = df['timestamps'] >= start_dt
            time_range_df = df[mask]

            if len(time_range_df) >= lookback + pred_len:
                # Take the pred_len points after the lookback window as actual values
                actual_df = time_range_df.iloc[lookback:lookback + pred_len]

                for i, (_, row) in enumerate(actual_df.iterrows()):
                    actual_data.append({
                        'timestamp': row['timestamps'].isoformat(),
                        'open': float(row['open']),
                        'high': float(row['high']),
                        'low': float(row['low']),
                        'close': float(row['close']),
                        'volume': float(row['volume']) if 'volume' in row else 0,
                        'amount': float(row['amount']) if 'amount' in row else 0
                    })
        else:  # Latest data
            # Prediction used the first lookback points,
            # so the actual data is the pred_len points that follow them
            if len(df) >= lookback + pred_len:
                actual_df = df.iloc[lookback:lookback + pred_len]
                for i, (_, row) in enumerate(actual_df.iterrows()):
                    actual_data.append({
                        'timestamp': row['timestamps'].isoformat(),
                        'open': float(row['open']),
                        'high': float(row['high']),
                        'low': float(row['low']),
                        'close': float(row['close']),
                        'volume': float(row['volume']) if 'volume' in row else 0,
                        'amount': float(row['amount']) if 'amount' in row else 0
                    })
        # Create the chart, passing the historical data start position
        if start_date:
            # Custom time period: find the starting position of the historical data in the original df
            start_dt = pd.to_datetime(start_date)
            mask = df['timestamps'] >= start_dt
            historical_start_idx = df[mask].index[0] if len(df[mask]) > 0 else 0
        else:
            # Latest data: start from the beginning
            historical_start_idx = 0

        chart_json = create_prediction_chart(df, pred_df, lookback, pred_len, actual_df, historical_start_idx)
        # Prepare the prediction result data
        if 'timestamps' in df.columns:
            if start_date:
                # Custom time period: derive timestamps from the selected window
                start_dt = pd.to_datetime(start_date)
                mask = df['timestamps'] >= start_dt
                time_range_df = df[mask]

                if len(time_range_df) >= lookback:
                    # Prediction timestamps start after the last point of the lookback window
                    last_timestamp = time_range_df['timestamps'].iloc[lookback - 1]
                    time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
                    future_timestamps = pd.date_range(
                        start=last_timestamp + time_diff,
                        periods=pred_len,
                        freq=time_diff
                    )
                else:
                    future_timestamps = []
            else:
                # Latest data: continue from the last time point of the entire data file
                last_timestamp = df['timestamps'].iloc[-1]
                time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
                future_timestamps = pd.date_range(
                    start=last_timestamp + time_diff,
                    periods=pred_len,
                    freq=time_diff
                )
        else:
            future_timestamps = range(len(df), len(df) + pred_len)

        prediction_results = []
        for i, (_, row) in enumerate(pred_df.iterrows()):
            prediction_results.append({
                'timestamp': future_timestamps[i].isoformat() if i < len(future_timestamps) else f"T{i}",
                'open': float(row['open']),
                'high': float(row['high']),
                'low': float(row['low']),
                'close': float(row['close']),
                'volume': float(row['volume']) if 'volume' in row else 0,
                'amount': float(row['amount']) if 'amount' in row else 0
            })
        # Save the prediction results to a file
        try:
            save_prediction_results(
                file_path=file_path,
                prediction_type=prediction_type,
                prediction_results=prediction_results,
                actual_data=actual_data,
                input_data=x_df,
                prediction_params={
                    'lookback': lookback,
                    'pred_len': pred_len,
                    'temperature': temperature,
                    'top_p': top_p,
                    'sample_count': sample_count,
                    'start_date': start_date if start_date else 'latest'
                }
            )
        except Exception as e:
            print(f"Failed to save prediction results: {e}")
        # Debug output before returning
        print("✅ Prediction complete, returning data:")
        print(f"  Success: {True}")
        print(f"  Prediction type: {prediction_type}")
        print(f"  Chart data length: {len(chart_json)}")
        print(f"  Number of prediction results: {len(prediction_results)}")
        print(f"  Number of actual data points: {len(actual_data)}")
        print(f"  Has comparison data: {len(actual_data) > 0}")
        return jsonify({
            'success': True,
            'prediction_type': prediction_type,
            'chart': chart_json,
            'prediction_results': prediction_results,
            'actual_data': actual_data,
            'has_comparison': len(actual_data) > 0,
            'message': f'Prediction completed, generated {pred_len} prediction points' + (
                f', including {len(actual_data)} actual data points for comparison' if len(actual_data) > 0 else '')
        })
    except Exception as e:
        return jsonify({'error': f'Prediction failed: {str(e)}'}), 500

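# Example request against /api/predict (a sketch with hypothetical values; the
# file_path must refer to a file reported by /api/data-files, and a model must
# already be loaded via /api/load-model):
#
#   curl -X POST http://localhost:7070/api/predict \
#        -H 'Content-Type: application/json' \
#        -d '{"file_path": "data/example.csv", "lookback": 400, "pred_len": 120,
#             "temperature": 1.0, "top_p": 0.9, "sample_count": 1}'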
@app.route('/api/load-model', methods=['POST'])
def load_model():
    """Load a Kronos model from local files"""
    global tokenizer, model, predictor

    try:
        if not MODEL_AVAILABLE:
            return jsonify({'error': 'Kronos model library not available'}), 400

        data = request.get_json()
        model_key = data.get('model_key', 'kronos-small')
        device = data.get('device', 'cpu')

        if model_key not in AVAILABLE_MODELS:
            return jsonify({'error': f'Unsupported model: {model_key}'}), 400

        model_config = AVAILABLE_MODELS[model_key]

        print(f"🚀 Loading model from: {model_config['model_id']}")

        model_path = model_config['model_id']
        tokenizer_path = model_config['tokenizer_id']  # kept for reference; the tokenizer is rebuilt from config below

        # Verify that the model path exists before inspecting it
        if not os.path.exists(model_path):
            return jsonify({'error': f'Model path does not exist: {model_path}'}), 400

        model_files = os.listdir(model_path)
        print(f"📄 Files in model directory: {model_files}")
        try:
            # Load the model directly from the local path
            model = Kronos.from_pretrained(
                model_config['model_id'],
                local_files_only=True
            )

            # Read the model config file to get the correct tokenizer parameters
            config_path = os.path.join(model_config['model_id'], 'config.json')
            if os.path.exists(config_path):
                print(f"Reading config file: {config_path}")
                with open(config_path, 'r') as f:
                    config = json.load(f)

                for key, value in config.items():
                    print(f"  {key}: {value}")

                # Build the tokenizer from the parameters in the config
                tokenizer = KronosTokenizer(
                    d_in=6,  # Six input features: OHLCV + amount
                    d_model=config['d_model'],
                    n_heads=config['n_heads'],
                    ff_dim=config['ff_dim'],
                    n_enc_layers=config['n_layers'],
                    n_dec_layers=config['n_layers'],
                    ffn_dropout_p=config['ffn_dropout_p'],
                    attn_dropout_p=config['attn_dropout_p'],
                    resid_dropout_p=config['resid_dropout_p'],
                    s1_bits=config['s1_bits'],
                    s2_bits=config['s2_bits'],
                    beta=1.0,
                    gamma0=1.0,
                    gamma=1.0,
                    zeta=1.0,
                    group_size=1
                )
            else:
                return jsonify({'error': f'Config file not found: {config_path}'}), 400

        except Exception as e:
            return jsonify({'error': f'Failed to load model: {str(e)}'}), 500
        # Create the predictor
        predictor = KronosPredictor(
            model,
            tokenizer,
            device=device,
            max_context=model_config['context_length']
        )

        return jsonify({
            'success': True,
            'message': f'Model loaded successfully: {model_config["name"]} ({model_config["params"]}) on {device}',
            'model_info': {
                **model_config,
                'model_path': model_config['model_id'],
                'abs_model_path': os.path.abspath(model_config['model_id']),
                'device': device
            }
        })

    except Exception as e:
        import traceback
        print("[API endpoint error]")
        print(f"Error type: {type(e).__name__}")
        print(f"Error message: {str(e)}")
        traceback.print_exc()
        print("=" * 60)
        return jsonify({'error': f'Model loading failed: {str(e)}'}), 500

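# Example request against /api/load-model (hypothetical host/values; model_key
# must be one of the AVAILABLE_MODELS keys defined above):
#
#   curl -X POST http://localhost:7070/api/load-model \
#        -H 'Content-Type: application/json' \
#        -d '{"model_key": "kronos-small", "device": "cpu"}'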
@app.route('/api/available-models')
def get_available_models():
    """Get the list of available models"""
    return jsonify({
        'models': AVAILABLE_MODELS,
        'model_available': MODEL_AVAILABLE
    })

@app.route('/api/model-status')
def get_model_status():
    """Get the model status"""
    if MODEL_AVAILABLE:
        if predictor is not None:
            return jsonify({
                'available': True,
                'loaded': True,
                'message': 'Kronos model loaded and available',
                'current_model': {
                    'name': predictor.model.__class__.__name__,
                    'device': str(next(predictor.model.parameters()).device)
                }
            })
        else:
            return jsonify({
                'available': True,
                'loaded': False,
                'message': 'Kronos model available but not loaded'
            })
    else:
        return jsonify({
            'available': False,
            'loaded': False,
            'message': 'Kronos model library not available, please install related dependencies'
        })

# Stock data retrieval endpoint
@app.route('/api/stock-data', methods=['POST'])
def stock_data():
    try:
        data = request.get_json()
        stock_code = data.get('stock_code', '').strip()

        if not stock_code:
            return jsonify({
                'success': False,
                'error': 'Stock code cannot be empty'
            }), 400

        if not re.match(r'^[a-z]+\.\d+$', stock_code):
            return jsonify({
                'success': False,
                'error': 'The stock code you entered is invalid'
            }), 400

        # Log in to baostock
        lg = bs.login()

        if lg.error_code != '0':
            return jsonify({
                'success': False,
                'error': f'Login failed: {lg.error_msg}'
            }), 400

        end_date = datetime.now().strftime('%Y-%m-%d')
        rs = bs.query_history_k_data_plus(
            stock_code,
            "time,open,high,low,close,volume,amount",
            start_date='2024-06-01',
            end_date=end_date,
            frequency="5",
            adjustflag="3"
        )

        if rs.error_code != '0':
            bs.logout()
            return jsonify({
                'success': False,
                'error': 'Failed to retrieve data, please enter a valid stock code'
            }), 400
        data_list = []
        while rs.next():
            data_list.append(rs.get_row_data())

        # Log out of baostock
        bs.logout()

        columns = rs.fields
        df = pd.DataFrame(data_list, columns=columns)

        df = df.rename(columns={'time': 'timestamps'})

        # Convert the price/volume columns to numeric values
        numeric_columns = ['open', 'high', 'low', 'close', 'volume', 'amount']
        for col in numeric_columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')

        # Baostock returns times such as '20240601093500000'; parse them directly as strings
        df['timestamps'] = pd.to_datetime(df['timestamps'].astype(str), format='%Y%m%d%H%M%S%f', errors='coerce')

        df = df.dropna()
        data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')
        os.makedirs(data_dir, exist_ok=True)

        filename = "Stock_5min_A股.csv"
        file_path = os.path.join(data_dir, filename)

        df.to_csv(file_path, index=False, encoding='utf-8', mode='w')
        return jsonify({
            'success': True,
            'message': f'Stock data saved successfully: {filename}',
            'file_name': filename
        })

    except Exception as e:
        return jsonify({
            'success': False,
            'error': f'Error processing stock data: {str(e)}'
        }), 500

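# Example request against /api/stock-data (hypothetical code; the endpoint
# expects baostock-style codes of the form '<exchange>.<number>'):
#
#   curl -X POST http://localhost:7070/api/stock-data \
#        -H 'Content-Type: application/json' \
#        -d '{"stock_code": "sh.600000"}'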
# Technical indicator chart endpoint
@app.route('/api/generate-chart', methods=['POST'])
def generate_chart():
    try:
        data = request.get_json()

        # Validate parameters
        required_fields = ['file_path', 'lookback', 'diagram_type', 'historical_start_idx']
        for field in required_fields:
            if field not in data:
                return jsonify({'success': False, 'error': f'Missing required field: {field}'}), 400

        file_path = data['file_path']
        lookback = int(data['lookback'])
        diagram_type = data['diagram_type']
        historical_start_idx = int(data['historical_start_idx'])

        # Load data
        df, error = load_data_file(file_path)
        if error:
            return jsonify({'success': False, 'error': error}), 400

        if len(df) < lookback + historical_start_idx:
            return jsonify({
                'success': False,
                'error': f'Insufficient data length, need at least {lookback + historical_start_idx} rows'
            }), 400

        pred_df = None
        actual_df = None
        # Generate the chart
        chart_json = create_technical_chart(
            df=df,
            pred_df=pred_df,
            lookback=lookback,
            pred_len=0,
            diagram_type=diagram_type,
            actual_df=actual_df,
            historical_start_idx=historical_start_idx
        )

        # Table data for the displayed window
        table_data_start = historical_start_idx
        table_data_end = historical_start_idx + lookback
        table_df = df.iloc[table_data_start:table_data_end]
        table_data = table_df.to_dict('records')

        return jsonify({
            'success': True,
            'chart': json.loads(chart_json),
            'table_data': table_data,
            'message': 'Technical chart generated successfully'
        })

    except Exception as e:
        return jsonify({
            'success': False,
            'error': f'Failed to generate technical chart: {str(e)}'
        }), 500

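# Example request against /api/generate-chart (hypothetical values; the
# diagram_type string must match one of the branches in create_technical_chart):
#
#   curl -X POST http://localhost:7070/api/generate-chart \
#        -H 'Content-Type: application/json' \
#        -d '{"file_path": "data/example.csv", "lookback": 200,
#             "diagram_type": "Moving Average (MA)", "historical_start_idx": 0}'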
if __name__ == '__main__':
    print("Starting Kronos Web UI...")
    print(f"Model availability: {MODEL_AVAILABLE}")
    if MODEL_AVAILABLE:
        print("Tip: You can load a Kronos model through the /api/load-model endpoint")
    else:
        print("Tip: Will use simulated data for demonstration")

    app.run(debug=True, host='0.0.0.0', port=7070)