@ -9,6 +9,9 @@ from flask_cors import CORS
import sys
import warnings
import datetime
import baostock as bs
import re
warnings . filterwarnings ( ' ignore ' )
# Add project root directory to path
@ -107,6 +110,7 @@ def load_data_files():
return data_files
def load_data_file ( file_path ) :
""" Load data file """
try :
@ -154,6 +158,7 @@ def load_data_file(file_path):
except Exception as e :
return None , f " Failed to load file: {str(e)} "
def save_prediction_results ( file_path , prediction_type , prediction_results , actual_data , input_data , prediction_params ) :
""" Save prediction results to file """
try :
@ -391,7 +396,6 @@ def save_prediction_results(file_path, prediction_type, prediction_results, actu
def create_prediction_chart ( df , pred_df , lookback , pred_len , actual_df = None , historical_start_idx = 0 ) :
""" Create prediction chart """
print ( f " 🔍 创建图表调试: " )
print ( f " 历史数据: {len(df) if df is not None else 0} 行 " )
print ( f " 预测数据: {len(pred_df) if pred_df is not None else 0} 行 " )
@ -538,17 +542,359 @@ def create_prediction_chart(df, pred_df, lookback, pred_len, actual_df=None, his
error_fig . update_layout ( title = ' Chart Rendering Error ' )
return error_fig . to_json ( )
# 计算指标
def calculate_indicators ( df ) :
indicators = { }
# 计算移动平均线 (MA)
indicators [ ' ma5 ' ] = df [ ' close ' ] . rolling ( window = 5 ) . mean ( )
indicators [ ' ma10 ' ] = df [ ' close ' ] . rolling ( window = 10 ) . mean ( )
indicators [ ' ma20 ' ] = df [ ' close ' ] . rolling ( window = 20 ) . mean ( )
# 计算MACD
exp12 = df [ ' close ' ] . ewm ( span = 12 , adjust = False ) . mean ( )
exp26 = df [ ' close ' ] . ewm ( span = 26 , adjust = False ) . mean ( )
indicators [ ' macd ' ] = exp12 - exp26
indicators [ ' signal ' ] = indicators [ ' macd ' ] . ewm ( span = 9 , adjust = False ) . mean ( )
indicators [ ' macd_hist ' ] = indicators [ ' macd ' ] - indicators [ ' signal ' ]
# 计算RSI
delta = df [ ' close ' ] . diff ( )
gain = ( delta . where ( delta > 0 , 0 ) ) . rolling ( window = 14 ) . mean ( )
loss = ( - delta . where ( delta < 0 , 0 ) ) . rolling ( window = 14 ) . mean ( )
rs = gain / loss
indicators [ ' rsi ' ] = 100 - ( 100 / ( 1 + rs ) )
# 计算布林带
indicators [ ' bb_mid ' ] = df [ ' close ' ] . rolling ( window = 20 ) . mean ( )
indicators [ ' bb_std ' ] = df [ ' close ' ] . rolling ( window = 20 ) . std ( )
indicators [ ' bb_upper ' ] = indicators [ ' bb_mid ' ] + 2 * indicators [ ' bb_std ' ]
indicators [ ' bb_lower ' ] = indicators [ ' bb_mid ' ] - 2 * indicators [ ' bb_std ' ]
# 计算随机震荡指标
low_min = df [ ' low ' ] . rolling ( window = 14 ) . min ( )
high_max = df [ ' high ' ] . rolling ( window = 14 ) . max ( )
indicators [ ' stoch_k ' ] = 100 * ( ( df [ ' close ' ] - low_min ) / ( high_max - low_min ) )
indicators [ ' stoch_d ' ] = indicators [ ' stoch_k ' ] . rolling ( window = 3 ) . mean ( )
# 滚动窗口均值策略
indicators [ ' rwms_window ' ] = 90
indicators [ ' rwms_mean ' ] = df [ ' close ' ] . rolling ( window = 90 ) . mean ( )
indicators [ ' rwms_signal ' ] = ( df [ ' close ' ] > indicators [ ' rwms_mean ' ] ) . astype ( int )
# 三重指数平均(TRIX)策略
# 计算收盘价的EMA
ema1 = df [ ' close ' ] . ewm ( span = 12 , adjust = False ) . mean ( )
# 计算EMA的EMA
ema2 = ema1 . ewm ( span = 12 , adjust = False ) . mean ( )
# 计算EMA的EMA的EMA
ema3 = ema2 . ewm ( span = 12 , adjust = False ) . mean ( )
# 计算TRIX
indicators [ ' trix ' ] = ( ema3 - ema3 . shift ( 1 ) ) / ema3 . shift ( 1 ) * 100
# 计算信号线
indicators [ ' trix_signal ' ] = indicators [ ' trix ' ] . ewm ( span = 9 , adjust = False ) . mean ( )
return indicators
# 创建图表
def create_technical_chart ( df , pred_df , lookback , pred_len , diagram_type , actual_df = None , historical_start_idx = 0 ) :
print ( f " 🔍 数据内容: {len(df) if df is not None else 0} 行 " )
print ( f " 🔍 图表类型: {diagram_type} " )
# 数据范围
if historical_start_idx + lookback < = len ( df ) :
historical_df = df . iloc [ historical_start_idx : historical_start_idx + lookback ]
else :
available_lookback = min ( lookback , len ( df ) - historical_start_idx )
historical_df = df . iloc [ historical_start_idx : historical_start_idx + available_lookback ]
# 计算指标
historical_indicators = calculate_indicators ( historical_df )
fig = go . Figure ( )
# 成交量图表
if diagram_type == ' Volume Chart (VOL) ' :
fig . add_trace ( go . Bar (
x = historical_df [ ' timestamps ' ] . tolist ( ) if ' timestamps ' in historical_df . columns else historical_df . index . tolist ( ) ,
y = historical_df [ ' volume ' ] . tolist ( ) if ' volume ' in historical_df . columns else [ ] ,
name = ' Historical Volume ' ,
marker_color = ' #42A5F5 '
) )
if actual_df is not None and len ( actual_df ) > 0 and ' volume ' in actual_df . columns :
if ' timestamps ' in df . columns and len ( historical_df ) > 0 :
last_timestamp = historical_df [ ' timestamps ' ] . iloc [ - 1 ]
time_diff = df [ ' timestamps ' ] . iloc [ 1 ] - df [ ' timestamps ' ] . iloc [ 0 ] if len ( df ) > 1 else pd . Timedelta (
hours = 1 )
actual_timestamps = pd . date_range ( start = last_timestamp + time_diff , periods = len ( actual_df ) , freq = time_diff )
else :
actual_timestamps = range ( len ( historical_df ) , len ( historical_df ) + len ( actual_df ) )
fig . add_trace ( go . Bar (
x = actual_timestamps . tolist ( ) if hasattr ( actual_timestamps , ' tolist ' ) else list ( actual_timestamps ) ,
y = actual_df [ ' volume ' ] . tolist ( ) ,
name = ' Actual Volume ' ,
marker_color = ' #FF9800 '
) )
fig . update_layout ( yaxis_title = ' Volume ' )
# 移动平均线
elif diagram_type == ' Moving Average (MA) ' :
fig . add_trace ( go . Scatter (
x = historical_df [ ' timestamps ' ] . tolist ( ) if ' timestamps ' in historical_df . columns else historical_df . index . tolist ( ) ,
y = historical_indicators [ ' ma5 ' ] ,
name = ' MA5 ' ,
line = dict ( color = ' #26A69A ' , width = 1 )
) )
fig . add_trace ( go . Scatter (
x = historical_df [
' timestamps ' ] . tolist ( ) if ' timestamps ' in historical_df . columns else historical_df . index . tolist ( ) ,
y = historical_indicators [ ' ma10 ' ] ,
name = ' MA10 ' ,
line = dict ( color = ' #42A5F5 ' , width = 1 )
) )
fig . add_trace ( go . Scatter (
x = historical_df [
' timestamps ' ] . tolist ( ) if ' timestamps ' in historical_df . columns else historical_df . index . tolist ( ) ,
y = historical_indicators [ ' ma20 ' ] ,
name = ' MA20 ' ,
line = dict ( color = ' #7E57C2 ' , width = 1 )
) )
fig . add_trace ( go . Scatter (
x = historical_df [
' timestamps ' ] . tolist ( ) if ' timestamps ' in historical_df . columns else historical_df . index . tolist ( ) ,
y = historical_df [ ' close ' ] ,
name = ' Close Price ' ,
line = dict ( color = ' #212121 ' , width = 1 , dash = ' dash ' )
) )
fig . update_layout ( yaxis_title = ' Price ' )
# MACD指标
elif diagram_type == ' MACD Indicator (MACD) ' :
fig . add_trace ( go . Scatter (
x = historical_df [
' timestamps ' ] . tolist ( ) if ' timestamps ' in historical_df . columns else historical_df . index . tolist ( ) ,
y = historical_indicators [ ' macd ' ] ,
name = ' MACD ' ,
line = dict ( color = ' #26A69A ' , width = 1 )
) )
fig . add_trace ( go . Scatter (
x = historical_df [
' timestamps ' ] . tolist ( ) if ' timestamps ' in historical_df . columns else historical_df . index . tolist ( ) ,
y = historical_indicators [ ' signal ' ] ,
name = ' Signal ' ,
line = dict ( color = ' #EF5350 ' , width = 1 )
) )
fig . add_trace ( go . Bar (
x = historical_df [
' timestamps ' ] . tolist ( ) if ' timestamps ' in historical_df . columns else historical_df . index . tolist ( ) ,
y = historical_indicators [ ' macd_hist ' ] ,
name = ' MACD Histogram ' ,
marker_color = ' #42A5F5 '
) )
# 零轴线
fig . add_hline ( y = 0 , line_dash = " dash " , line_color = " gray " )
fig . update_layout ( yaxis_title = ' MACD ' )
# RSI指标
elif diagram_type == ' RSI Indicator (RSI) ' :
fig . add_trace ( go . Scatter (
x = historical_df [
' timestamps ' ] . tolist ( ) if ' timestamps ' in historical_df . columns else historical_df . index . tolist ( ) ,
y = historical_indicators [ ' rsi ' ] ,
name = ' RSI ' ,
line = dict ( color = ' #26A69A ' , width = 1 )
) )
# 超买超卖线
fig . add_hline ( y = 70 , line_dash = " dash " , line_color = " red " , name = ' Overbought ' )
fig . add_hline ( y = 30 , line_dash = " dash " , line_color = " green " , name = ' Oversold ' )
fig . update_layout ( yaxis_title = ' RSI ' , yaxis_range = [ 0 , 100 ] )
# 布林带
elif diagram_type == ' Bollinger Bands (BB) ' :
fig . add_trace ( go . Scatter (
x = historical_df [
' timestamps ' ] . tolist ( ) if ' timestamps ' in historical_df . columns else historical_df . index . tolist ( ) ,
y = historical_indicators [ ' bb_upper ' ] ,
name = ' Upper Band ' ,
line = dict ( color = ' #EF5350 ' , width = 1 )
) )
fig . add_trace ( go . Scatter (
x = historical_df [
' timestamps ' ] . tolist ( ) if ' timestamps ' in historical_df . columns else historical_df . index . tolist ( ) ,
y = historical_indicators [ ' bb_mid ' ] ,
name = ' Middle Band (MA20) ' ,
line = dict ( color = ' #42A5F5 ' , width = 1 )
) )
fig . add_trace ( go . Scatter (
x = historical_df [
' timestamps ' ] . tolist ( ) if ' timestamps ' in historical_df . columns else historical_df . index . tolist ( ) ,
y = historical_indicators [ ' bb_lower ' ] ,
name = ' Lower Band ' ,
line = dict ( color = ' #26A69A ' , width = 1 )
) )
fig . add_trace ( go . Scatter (
x = historical_df [
' timestamps ' ] . tolist ( ) if ' timestamps ' in historical_df . columns else historical_df . index . tolist ( ) ,
y = historical_df [ ' close ' ] ,
name = ' Close Price ' ,
line = dict ( color = ' #212121 ' , width = 1 )
) )
fig . update_layout ( yaxis_title = ' Price ' )
# 随机震荡指标
elif diagram_type == ' Stochastic Oscillator (STOCH) ' :
fig . add_trace ( go . Scatter (
x = historical_df [
' timestamps ' ] . tolist ( ) if ' timestamps ' in historical_df . columns else historical_df . index . tolist ( ) ,
y = historical_indicators [ ' stoch_k ' ] ,
name = ' % K ' ,
line = dict ( color = ' #26A69A ' , width = 1 )
) )
fig . add_trace ( go . Scatter (
x = historical_df [
' timestamps ' ] . tolist ( ) if ' timestamps ' in historical_df . columns else historical_df . index . tolist ( ) ,
y = historical_indicators [ ' stoch_d ' ] ,
name = ' % D ' ,
line = dict ( color = ' #EF5350 ' , width = 1 )
) )
fig . add_hline ( y = 80 , line_dash = " dash " , line_color = " red " , name = ' Overbought ' )
fig . add_hline ( y = 20 , line_dash = " dash " , line_color = " green " , name = ' Oversold ' )
fig . update_layout ( yaxis_title = ' Stochastic ' , yaxis_range = [ 0 , 100 ] )
# 滚动窗口均值策略
elif diagram_type == ' Rolling Window Mean Strategy ' :
fig . add_trace ( go . Scatter (
x = historical_df [
' timestamps ' ] . tolist ( ) if ' timestamps ' in historical_df . columns else historical_df . index . tolist ( ) ,
y = historical_df [ ' close ' ] ,
name = ' Close Price ' ,
line = dict ( color = ' #212121 ' , width = 1.5 )
) )
fig . add_trace ( go . Scatter (
x = historical_df [
' timestamps ' ] . tolist ( ) if ' timestamps ' in historical_df . columns else historical_df . index . tolist ( ) ,
y = historical_indicators [ ' rwms_mean ' ] ,
name = f ' Rolling Mean ({historical_indicators[ " rwms_window " ]} periods) ' ,
line = dict ( color = ' #42A5F5 ' , width = 1.5 , dash = ' dash ' )
) )
buy_signals = historical_df [ historical_indicators [ ' rwms_signal ' ] == 1 ]
fig . add_trace ( go . Scatter (
x = buy_signals [ ' timestamps ' ] . tolist ( ) if ' timestamps ' in buy_signals . columns else buy_signals . index . tolist ( ) ,
y = buy_signals [ ' close ' ] ,
mode = ' markers ' ,
name = ' Buy Signal ' ,
marker = dict ( color = ' #26A69A ' , size = 8 , symbol = ' triangle-up ' )
) )
sell_signals = historical_df [ historical_indicators [ ' rwms_signal ' ] == 0 ]
fig . add_trace ( go . Scatter (
x = sell_signals [
' timestamps ' ] . tolist ( ) if ' timestamps ' in sell_signals . columns else sell_signals . index . tolist ( ) ,
y = sell_signals [ ' close ' ] ,
mode = ' markers ' ,
name = ' Sell Signal ' ,
marker = dict ( color = ' #EF5350 ' , size = 8 , symbol = ' triangle-down ' )
) )
fig . update_layout (
yaxis_title = ' Price ' ,
title = f ' Rolling Window Mean Strategy (Window Size: {historical_indicators[ " rwms_window " ]}) '
)
# TRIX指标图表
elif diagram_type == ' TRIX Indicator (TRIX) ' :
fig . add_trace ( go . Scatter (
x = historical_df [
' timestamps ' ] . tolist ( ) if ' timestamps ' in historical_df . columns else historical_df . index . tolist ( ) ,
y = historical_indicators [ ' trix ' ] ,
name = ' TRIX ' ,
line = dict ( color = ' #26A69A ' , width = 1 )
) )
fig . add_trace ( go . Scatter (
x = historical_df [
' timestamps ' ] . tolist ( ) if ' timestamps ' in historical_df . columns else historical_df . index . tolist ( ) ,
y = historical_indicators [ ' trix_signal ' ] ,
name = ' TRIX Signal ' ,
line = dict ( color = ' #EF5350 ' , width = 1 )
) )
fig . add_hline ( y = 0 , line_dash = " dash " , line_color = " gray " )
fig . update_layout (
yaxis_title = ' TRIX ( % ) ' ,
title = ' Triple Exponential Average (TRIX) Strategy '
)
# 布局设置
fig . update_layout (
title = f ' {diagram_type} - Technical Indicator (Real Data Only) ' ,
xaxis_title = ' Time ' ,
template = ' plotly_white ' ,
height = 400 ,
showlegend = True ,
margin = dict ( t = 50 , b = 30 )
)
if ' timestamps ' in historical_df . columns :
all_timestamps = historical_df [ ' timestamps ' ] . tolist ( )
if actual_df is not None and len ( actual_df ) > 0 and ' timestamps ' in df . columns :
if ' actual_timestamps ' in locals ( ) :
all_timestamps . extend ( actual_timestamps . tolist ( ) )
if all_timestamps :
all_timestamps = sorted ( all_timestamps )
fig . update_xaxes (
range = [ all_timestamps [ 0 ] , all_timestamps [ - 1 ] ] ,
rangeslider_visible = False ,
type = ' date '
)
try :
chart_json = fig . to_json ( )
print ( f " ✅ 技术指标图表序列化完成,长度: {len(chart_json)} " )
return chart_json
except Exception as e :
print ( f " ❌ 技术指标图表序列化失败: {e} " )
error_fig = go . Figure ( )
error_fig . update_layout ( title = ' Chart Rendering Error ' )
return error_fig . to_json ( )
@app.route ( ' / ' )
def index ( ) :
""" Home page """
return render_template ( ' index.html ' )
@app.route ( ' /api/data-files ' )
def get_data_files ( ) :
""" Get available data file list """
data_files = load_data_files ( )
return jsonify ( data_files )
@app.route ( ' /api/load-data ' , methods = [ ' POST ' ] )
def load_data ( ) :
""" Load data file """
@ -907,6 +1253,7 @@ def load_data():
# except Exception as e:
# return jsonify({'error': f'Prediction failed: {str(e)}'}), 500
@app.route ( ' /api/predict ' , methods = [ ' POST ' ] )
def predict ( ) :
""" Perform prediction """
@ -1198,6 +1545,7 @@ def predict():
# except Exception as e:
# return jsonify({'error': f'Model loading failed: {str(e)}'}), 500
@app.route ( ' /api/load-model ' , methods = [ ' POST ' ] )
def load_model ( ) :
global tokenizer , model , predictor
@ -1278,6 +1626,7 @@ def load_model():
except Exception as e :
return jsonify ( { ' error ' : f ' Model loading failed: {str(e)} ' } ) , 500
@app.route ( ' /api/available-models ' )
def get_available_models ( ) :
""" Get available model list """
@ -1286,6 +1635,7 @@ def get_available_models():
' model_available ' : MODEL_AVAILABLE
} )
@app.route ( ' /api/model-status ' )
def get_model_status ( ) :
""" Get model status """
@ -1313,6 +1663,167 @@ def get_model_status():
' message ' : ' Kronos model library not available, please install related dependencies '
} )
@app.route ( ' /api/stock-data ' , methods = [ ' POST ' ] )
def Stock_Data ( ) :
try :
data = request . get_json ( )
stock_code = data . get ( ' stock_code ' , ' ' ) . strip ( )
# 股票代码不能为空
if not stock_code :
return jsonify ( {
' success ' : False ,
' error ' : f ' Stock code cannot be empty '
} ) , 400
# 股票代码格式验证
if not re . match ( r ' ^[a-z]+ \ . \ d+$ ' , stock_code ) :
return jsonify ( {
' success ' : False ,
' error ' : f ' The stock code you entered is invalid '
} ) , 400
# 登录 baostock
lg = bs . login ( )
if lg . error_code != ' 0 ' :
return jsonify ( {
' success ' : False ,
' error ' : f ' Login failed: {lg.error_msg} '
} ) , 400
rs = bs . query_history_k_data_plus (
stock_code ,
" time,open,high,low,close,volume,amount " ,
start_date = ' 2024-06-01 ' ,
end_date = ' 2024-10-31 ' ,
frequency = " 5 " ,
adjustflag = " 3 "
)
# 检查获取结果
if rs . error_code != ' 0 ' :
bs . logout ( )
return jsonify ( {
' success ' : False ,
' error ' : f ' Failed to retrieve data, please enter a valid stock code '
} ) , 400
# 提取数据
data_list = [ ]
while rs . next ( ) :
data_list . append ( rs . get_row_data ( ) )
# 登出系统
bs . logout ( )
columns = rs . fields
df = pd . DataFrame ( data_list , columns = columns )
# 数值列转换
df = df . rename ( columns = { ' time ' : ' timestamps ' } )
numeric_columns = [ ' timestamps ' , ' open ' , ' high ' , ' low ' , ' close ' , ' volume ' , ' amount ' ]
for col in numeric_columns :
df [ col ] = pd . to_numeric ( df [ col ] , errors = ' coerce ' )
df [ ' timestamps ' ] = pd . to_datetime ( df [ ' timestamps ' ] . astype ( str ) , format = ' % Y % m %d % H % M % S %f ' )
# 去除无效数据
df = df . dropna ( )
# 保存
data_dir = os . path . join ( os . path . dirname ( os . path . dirname ( os . path . abspath ( __file__ ) ) ) , ' data ' )
os . makedirs ( data_dir , exist_ok = True )
filename = f " Stock_5min_A股.csv "
file_path = os . path . join ( data_dir , filename )
df . to_csv (
file_path ,
index = False ,
encoding = ' utf-8 ' ,
mode = ' w '
)
data_files = load_data_files ( )
return jsonify ( {
' success ' : True ,
' message ' : f ' Stock data saved successfully: {filename} ' ,
' file_name ' : filename
} )
except Exception as e :
return jsonify ( {
' success ' : False ,
' error ' : f ' Error processing stock data: {str(e)} '
} ) , 500
@app.route ( ' /api/generate-chart ' , methods = [ ' POST ' ] )
def generate_chart ( ) :
try :
data = request . get_json ( )
# 验证参数
required_fields = [ ' file_path ' , ' lookback ' , ' diagram_type ' , ' historical_start_idx ' ]
for field in required_fields :
if field not in data :
return jsonify ( { ' success ' : False , ' error ' : f ' Missing required field: {field} ' } ) , 400
# 解析参数
file_path = data [ ' file_path ' ]
lookback = int ( data [ ' lookback ' ] )
diagram_type = data [ ' diagram_type ' ]
historical_start_idx = int ( data [ ' historical_start_idx ' ] )
# 加载数据
df , error = load_data_file ( file_path )
if error :
return jsonify ( { ' success ' : False , ' error ' : error } ) , 400
if len ( df ) < lookback + historical_start_idx :
return jsonify ( {
' success ' : False ,
' error ' : f ' Insufficient data length, need at least {lookback + historical_start_idx} rows '
} ) , 400
pred_df = None
actual_df = None
# 生成图表
chart_json = create_technical_chart (
df = df ,
pred_df = pred_df ,
lookback = lookback ,
pred_len = 0 ,
diagram_type = diagram_type ,
actual_df = actual_df ,
historical_start_idx = historical_start_idx
)
# 表格数据
table_data_start = historical_start_idx
table_data_end = historical_start_idx + lookback
table_df = df . iloc [ table_data_start : table_data_end ]
table_data = table_df . to_dict ( ' records ' )
return jsonify ( {
' success ' : True ,
' chart ' : json . loads ( chart_json ) ,
' table_data ' : table_data ,
' message ' : ' Technical chart generated successfully '
} )
except Exception as e :
return jsonify ( {
' success ' : False ,
' error ' : f ' Failed to generate technical chart: {str(e)} '
} ) , 500
if __name__ == ' __main__ ' :
print ( " Starting Kronos Web UI... " )
print ( f " Model availability: {MODEL_AVAILABLE} " )