import os
import sys
import argparse
import pickle
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import trange, tqdm
from matplotlib import pyplot as plt

import qlib
from qlib.config import REG_CN
from qlib.backtest import backtest, executor, CommonInfrastructure
from qlib.contrib.evaluate import risk_analysis
from qlib.contrib.strategy import TopkDropoutStrategy
from qlib.utils import flatten_dict
from qlib.utils.time import Freq

# Ensure the project root is on the Python path
sys.path.append("../")
from config import Config
from model.kronos import Kronos, KronosTokenizer, auto_regressive_inference


# =================================================================================
# 1. Data Loading and Processing for Inference
# =================================================================================
class QlibTestDataset(Dataset):
    """
    PyTorch Dataset for handling Qlib test data, specifically for inference.

    This dataset iterates through all possible sliding windows sequentially. It also
    yields metadata like symbol and timestamp, which are crucial for mapping
    predictions back to the original time series.
    """
    def __init__(self, data: dict, config: Config):
        self.data = data
        self.config = config
        self.window_size = config.lookback_window + config.predict_window
        self.symbols = list(self.data.keys())
        self.feature_list = config.feature_list
        self.time_feature_list = config.time_feature_list
        self.indices = []

        print("Preprocessing and building indices for test dataset...")
        for symbol in self.symbols:
            df = self.data[symbol].reset_index()
            # Generate time features on the fly
            df['minute'] = df['datetime'].dt.minute
            df['hour'] = df['datetime'].dt.hour
            df['weekday'] = df['datetime'].dt.weekday
            df['day'] = df['datetime'].dt.day
            df['month'] = df['datetime'].dt.month
            self.data[symbol] = df  # Store the preprocessed dataframe

            num_samples = len(df) - self.window_size + 1
            if num_samples > 0:
                for i in range(num_samples):
                    timestamp = df.iloc[i + self.config.lookback_window - 1]['datetime']
                    self.indices.append((symbol, i, timestamp))

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int):
        symbol, start_idx, timestamp = self.indices[idx]
        df = self.data[symbol]

        context_end = start_idx + self.config.lookback_window
        predict_end = context_end + self.config.predict_window
        context_df = df.iloc[start_idx:context_end]
        predict_df = df.iloc[context_end:predict_end]

        x = context_df[self.feature_list].values.astype(np.float32)
        x_stamp = context_df[self.time_feature_list].values.astype(np.float32)
        y_stamp = predict_df[self.time_feature_list].values.astype(np.float32)

        # Instance-level normalization, consistent with training
        x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
        x = (x - x_mean) / (x_std + 1e-5)
        x = np.clip(x, -self.config.clip, self.config.clip)

        return torch.from_numpy(x), torch.from_numpy(x_stamp), torch.from_numpy(y_stamp), symbol, timestamp
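
# A minimal usage sketch (hypothetical; assumes `raw_data` maps symbols to
# DataFrames indexed by datetime with the columns named in Config().feature_list):
#
#   dataset = QlibTestDataset(data=raw_data, config=Config())
#   x, x_stamp, y_stamp, symbol, ts = dataset[0]
#   # x: (lookback_window, len(feature_list)), normalized per window and clipped;
#   # ts is the datetime of the last context bar, i.e. the date the signal applies to.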
# =================================================================================
# 2. Backtesting Logic
# =================================================================================
class QlibBacktest:
    """
    A wrapper class for conducting backtesting experiments using Qlib.
    """
    def __init__(self, config: Config):
        self.config = config
        self.initialize_qlib()

    def initialize_qlib(self):
        """Initializes the Qlib environment."""
        print("Initializing Qlib for backtesting...")
        qlib.init(provider_uri=self.config.qlib_data_path, region=REG_CN)

    def run_single_backtest(self, signal_series: pd.Series) -> pd.DataFrame:
        """
        Runs a single backtest for a given prediction signal.

        Args:
            signal_series (pd.Series): A pandas Series with a MultiIndex
                (instrument, datetime) and prediction scores.

        Returns:
            pd.DataFrame: A DataFrame containing the performance report.
        """
        strategy = TopkDropoutStrategy(
            topk=self.config.backtest_n_symbol_hold,
            n_drop=self.config.backtest_n_symbol_drop,
            hold_thresh=self.config.backtest_hold_thresh,
            signal=signal_series,
        )
        executor_config = {
            "time_per_step": "day",
            "generate_portfolio_metrics": True,
            "delay_execution": True,
        }
        backtest_config = {
            "start_time": self.config.backtest_time_range[0],
            "end_time": self.config.backtest_time_range[1],
            "account": 100_000_000,
            "benchmark": self.config.backtest_benchmark,
            "exchange_kwargs": {
                "freq": "day", "limit_threshold": 0.095, "deal_price": "open",
                "open_cost": 0.001, "close_cost": 0.0015, "min_cost": 5,
            },
            "executor": executor.SimulatorExecutor(**executor_config),
        }
        portfolio_metric_dict, _ = backtest(strategy=strategy, **backtest_config)
        analysis_freq = "{0}{1}".format(*Freq.parse("day"))
        report, _ = portfolio_metric_dict.get(analysis_freq)

        # --- Analysis and Reporting ---
        analysis = {
            "excess_return_without_cost": risk_analysis(report["return"] - report["bench"], freq=analysis_freq),
            "excess_return_with_cost": risk_analysis(report["return"] - report["bench"] - report["cost"], freq=analysis_freq),
        }
        print("\n--- Backtest Analysis ---")
        print("Benchmark Return:", risk_analysis(report["bench"], freq=analysis_freq), sep='\n')
        print("\nExcess Return (w/o cost):", analysis["excess_return_without_cost"], sep='\n')
        print("\nExcess Return (w/ cost):", analysis["excess_return_with_cost"], sep='\n')

        report_df = pd.DataFrame({
            "cum_bench": report["bench"].cumsum(),
            "cum_return_w_cost": (report["return"] - report["cost"]).cumsum(),
            "cum_ex_return_w_cost": (report["return"] - report["bench"] - report["cost"]).cumsum(),
        })
        return report_df
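
    # A hypothetical illustration of the `signal_series` layout expected above
    # (tickers and values are placeholders). Each day, TopkDropoutStrategy ranks
    # instruments by this score, holding the `topk` best and rotating out
    # `n_drop` of the worst-scoring current holdings:
    #
    #   instrument  datetime
    #   SH600000    2023-01-03    0.0123
    #   SH600519    2023-01-03   -0.0045
    #   ...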
    def run_and_plot_results(self, signals: dict[str, pd.DataFrame]):
        """
        Runs backtests for multiple signals and plots the cumulative return curves.

        Args:
            signals (dict[str, pd.DataFrame]): A dictionary where keys are signal names
                and values are prediction DataFrames.
        """
        return_df, ex_return_df, bench_df = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
        for signal_name, pred_df in signals.items():
            print(f"\nBacktesting signal: {signal_name}...")
            # Convert the (datetime x instrument) DataFrame into the
            # (instrument, datetime)-indexed Series that Qlib expects
            pred_series = pred_df.stack()
            pred_series.index.names = ['datetime', 'instrument']
            pred_series = pred_series.swaplevel().sort_index()
            report_df = self.run_single_backtest(pred_series)
            return_df[signal_name] = report_df['cum_return_w_cost']
            ex_return_df[signal_name] = report_df['cum_ex_return_w_cost']
            if 'return' not in bench_df:
                bench_df['return'] = report_df['cum_bench']

        # Plotting results
        fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
        return_df.plot(ax=axes[0], title='Cumulative Return with Cost', grid=True)
        axes[0].plot(bench_df['return'], label=self.config.instrument.upper(), color='black', linestyle='--')
        axes[0].legend()
        axes[0].set_ylabel("Cumulative Return")
        ex_return_df.plot(ax=axes[1], title='Cumulative Excess Return with Cost', grid=True)
        axes[1].legend()
        axes[1].set_xlabel("Date")
        axes[1].set_ylabel("Cumulative Excess Return")
        plt.tight_layout()
        os.makedirs("../figures", exist_ok=True)  # savefig fails if the directory is missing
        plt.savefig("../figures/backtest_result_example.png", dpi=200)
        plt.show()
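
# Standalone backtest sketch (hypothetical; reuses a previously saved
# predictions pickle, mirroring step 5 of main() below):
#
#   with open("predictions.pkl", "rb") as f:
#       preds = pickle.load(f)
#   QlibBacktest(Config()).run_and_plot_results(preds)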
# =================================================================================
# 3. Inference Logic
# =================================================================================
def load_models(config: dict) -> tuple[KronosTokenizer, Kronos]:
    """Loads the fine-tuned tokenizer and predictor model."""
    device = torch.device(config['device'])
    print(f"Loading models onto device: {device}...")
    tokenizer = KronosTokenizer.from_pretrained(config['tokenizer_path']).to(device).eval()
    model = Kronos.from_pretrained(config['model_path']).to(device).eval()
    return tokenizer, model


def collate_fn_for_inference(batch):
    """
    Custom collate function to handle batches containing Tensors, strings, and Timestamps.

    Args:
        batch (list): A list of samples, where each sample is the tuple returned by
            QlibTestDataset.__getitem__.

    Returns:
        A single tuple containing the batched data.
    """
    # Unzip the list of samples into separate lists for each data type
    x, x_stamp, y_stamp, symbols, timestamps = zip(*batch)
    # Stack the tensors to create a batch
    x_batch = torch.stack(x, dim=0)
    x_stamp_batch = torch.stack(x_stamp, dim=0)
    y_stamp_batch = torch.stack(y_stamp, dim=0)
    # Return the strings and timestamps as lists
    return x_batch, x_stamp_batch, y_stamp_batch, list(symbols), list(timestamps)
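
# Note: PyTorch's default collate cannot batch pd.Timestamp metadata into
# tensors, hence the custom function above, which stacks the tensors and passes
# metadata through as plain lists. Typical wiring (mirrors generate_predictions):
#
#   loader = DataLoader(dataset, batch_size=32, shuffle=False,
#                       collate_fn=collate_fn_for_inference)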
def generate_predictions(config: dict, test_data: dict) -> dict[str, pd.DataFrame]:
    """
    Runs inference on the test dataset to generate prediction signals.

    Args:
        config (dict): A dictionary containing inference parameters.
        test_data (dict): The raw test data loaded from a pickle file.

    Returns:
        A dictionary where keys are signal types (e.g., 'mean', 'last') and
        values are DataFrames of predictions (datetime index, symbol columns).
    """
    tokenizer, model = load_models(config)
    device = torch.device(config['device'])

    # Use the Dataset and DataLoader for efficient batching and processing
    dataset = QlibTestDataset(data=test_data, config=Config())
    loader = DataLoader(
        dataset,
        # Divided by sample_count, presumably so the effective batch
        # (windows x samples) stays within config['batch_size']
        batch_size=config['batch_size'] // config['sample_count'],
        shuffle=False,
        num_workers=os.cpu_count() // 2,
        collate_fn=collate_fn_for_inference
    )

    results = defaultdict(list)
    with torch.no_grad():
        for x, x_stamp, y_stamp, symbols, timestamps in tqdm(loader, desc="Inference"):
            preds = auto_regressive_inference(
                tokenizer, model, x.to(device), x_stamp.to(device), y_stamp.to(device),
                max_context=config['max_context'], pred_len=config['pred_len'], clip=config['clip'],
                T=config['T'], top_k=config['top_k'], top_p=config['top_p'], sample_count=config['sample_count']
            )
            # Keep only the forecast horizon; comment this line out to retain
            # the historical context in the output as well
            preds = preds[:, -config['pred_len']:, :]

            # The 'close' price is at index 3 in `feature_list`
            last_day_close = x[:, -1, 3].numpy()
            signals = {
                'last': preds[:, -1, 3] - last_day_close,
                'mean': np.mean(preds[:, :, 3], axis=1) - last_day_close,
                'max': np.max(preds[:, :, 3], axis=1) - last_day_close,
                'min': np.min(preds[:, :, 3], axis=1) - last_day_close,
            }
            for i in range(len(symbols)):
                for sig_type, sig_values in signals.items():
                    results[sig_type].append((timestamps[i], symbols[i], sig_values[i]))

    print("Post-processing predictions into DataFrames...")
    prediction_dfs = {}
    for sig_type, records in results.items():
        df = pd.DataFrame(records, columns=['datetime', 'instrument', 'score'])
        pivot_df = df.pivot_table(index='datetime', columns='instrument', values='score')
        prediction_dfs[sig_type] = pivot_df.sort_index()
    return prediction_dfs
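
# Layout of each returned signal DataFrame (a hypothetical illustration; every
# score is the predicted change of the close price relative to the last
# observed close):
#
#   instrument   SH600000  SH600519  ...
#   datetime
#   2023-01-03     0.0123   -0.0045
#   2023-01-04     0.0087    0.0012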
# =================================================================================
# 4. Main Execution
# =================================================================================
def main():
    """Main function to set up config, run inference, and execute backtesting."""
    parser = argparse.ArgumentParser(description="Run Kronos Inference and Backtesting")
    parser.add_argument("--device", type=str, default="cuda:1", help="Device for inference (e.g., 'cuda:0', 'cpu')")
    args = parser.parse_args()

    # --- 1. Configuration Setup ---
    base_config = Config()
    # Create a dedicated dictionary for this run's configuration
    run_config = {
        'device': args.device,
        'data_path': base_config.dataset_path,
        'result_save_path': base_config.backtest_result_path,
        'result_name': base_config.backtest_save_folder_name,
        'tokenizer_path': base_config.finetuned_tokenizer_path,
        'model_path': base_config.finetuned_predictor_path,
        'max_context': base_config.max_context,
        'pred_len': base_config.predict_window,
        'clip': base_config.clip,
        'T': base_config.inference_T,
        'top_k': base_config.inference_top_k,
        'top_p': base_config.inference_top_p,
        'sample_count': base_config.inference_sample_count,
        'batch_size': base_config.backtest_batch_size,
    }
    print("--- Running with Configuration ---")
    for key, val in run_config.items():
        print(f"{key:>20}: {val}")
    print("-" * 35)

    # --- 2. Load Data ---
    test_data_path = os.path.join(run_config['data_path'], "test_data.pkl")
    print(f"Loading test data from {test_data_path}...")
    with open(test_data_path, 'rb') as f:
        test_data = pickle.load(f)
    print(f"Loaded test data for {len(test_data)} symbols.")

    # --- 3. Generate Predictions ---
    model_preds = generate_predictions(run_config, test_data)

    # --- 4. Save Predictions ---
    save_dir = os.path.join(run_config['result_save_path'], run_config['result_name'])
    os.makedirs(save_dir, exist_ok=True)
    predictions_file = os.path.join(save_dir, "predictions.pkl")
    print(f"Saving prediction signals to {predictions_file}...")
    with open(predictions_file, 'wb') as f:
        pickle.dump(model_preds, f)

    # --- 5. Run Backtesting ---
    # Reload from disk so the backtest can also be re-run later from the saved file
    with open(predictions_file, 'rb') as f:
        model_preds = pickle.load(f)
    backtester = QlibBacktest(base_config)
    backtester.run_and_plot_results(model_preds)


if __name__ == '__main__':
    main()
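
# Example invocation (the script filename is a placeholder for wherever this
# file lives in the repository):
#
#   python qlib_backtest.py --device cuda:0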