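"""Fine-tune the Kronos base model (predictor) on custom K-line (OHLCV) data.

The script loads a CSV with a 'timestamps' column plus open/high/low/close/volume/amount,
derives calendar features, splits the series chronologically into train/val/test,
samples per-window z-scored slices, encodes them with a (frozen) KronosTokenizer,
and trains the Kronos predictor with a next-token objective. The DistributedDataParallel
branches are used only when a torch.distributed process group is active.
"""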
import os
import sys
import json
import time
import pickle
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from time import gmtime, strftime
import logging
from logging.handlers import RotatingFileHandler
import datetime
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

sys.path.append('../')
from model import Kronos, KronosTokenizer, KronosPredictor
from config_loader import CustomFinetuneConfig


class CustomKlineDataset(Dataset):
    """Sliding-window dataset over a single K-line CSV, split chronologically into train/val/test."""

    def __init__(self, data_path, data_type='train', lookback_window=90, predict_window=10,
                 clip=5.0, seed=100, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
        self.data_path = data_path
        self.data_type = data_type
        self.lookback_window = lookback_window
        self.predict_window = predict_window
        # The +1 leaves room for the one-position input/target shift applied in train_model().
        self.window = lookback_window + predict_window + 1
        self.clip = clip
        self.seed = seed
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio
        self.feature_list = ['open', 'high', 'low', 'close', 'volume', 'amount']
        self.time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month']
        self.py_rng = random.Random(seed)
        self._load_and_preprocess_data()
        self._split_data_by_time()
        self.n_samples = len(self.data) - self.window + 1
        print(f"[{data_type.upper()}] Data length: {len(self.data)}, Available samples: {self.n_samples}")

    def _load_and_preprocess_data(self):
        # Expects a 'timestamps' column plus the price/volume features listed above.
        df = pd.read_csv(self.data_path)
        df['timestamps'] = pd.to_datetime(df['timestamps'])
        df = df.sort_values('timestamps').reset_index(drop=True)
        self.timestamps = df['timestamps'].copy()
        # Calendar features fed to the model alongside the price series.
        df['minute'] = df['timestamps'].dt.minute
        df['hour'] = df['timestamps'].dt.hour
        df['weekday'] = df['timestamps'].dt.weekday
        df['day'] = df['timestamps'].dt.day
        df['month'] = df['timestamps'].dt.month
        self.data = df[self.feature_list + self.time_feature_list].copy()
        if self.data.isnull().any().any():
            print("Warning: Missing values found in data, performing forward fill")
            self.data = self.data.ffill()
        print(f"Original data time range: {self.timestamps.min()} to {self.timestamps.max()}")
        print(f"Original data total length: {len(df)} records")

    def _split_data_by_time(self):
        # Chronological split: first train_ratio of rows -> train, next val_ratio -> val, remainder -> test.
        total_length = len(self.data)
        train_end = int(total_length * self.train_ratio)
        val_end = int(total_length * (self.train_ratio + self.val_ratio))
        if self.data_type == 'train':
            self.data = self.data.iloc[:train_end].copy()
            self.timestamps = self.timestamps.iloc[:train_end].copy()
            print(f"[{self.data_type.upper()}] Training set: first {train_end} time points ({self.train_ratio})")
            print(f"[{self.data_type.upper()}] Training set time range: {self.timestamps.min()} to {self.timestamps.max()}")
        elif self.data_type == 'val':
            self.data = self.data.iloc[train_end:val_end].copy()
            self.timestamps = self.timestamps.iloc[train_end:val_end].copy()
            print(f"[{self.data_type.upper()}] Validation set: time points {train_end+1} to {val_end} ({self.val_ratio})")
            print(f"[{self.data_type.upper()}] Validation set time range: {self.timestamps.min()} to {self.timestamps.max()}")
        elif self.data_type == 'test':
            self.data = self.data.iloc[val_end:].copy()
            self.timestamps = self.timestamps.iloc[val_end:].copy()
            print(f"[{self.data_type.upper()}] Test set: after time point {val_end+1}")
            print(f"[{self.data_type.upper()}] Test set time range: {self.timestamps.min()} to {self.timestamps.max()}")
        print(f"[{self.data_type.upper()}] Data length after split: {len(self.data)} records")

    def set_epoch_seed(self, epoch):
        epoch_seed = self.seed + epoch
        self.py_rng.seed(epoch_seed)
        self.current_epoch = epoch

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        max_start = len(self.data) - self.window
        if max_start <= 0:
            raise ValueError("Data length insufficient to create samples")
        if self.data_type == 'train':
            # Deterministic pseudo-shuffle: map (idx, epoch) to a window start via two large
            # primes so each epoch sees different, yet reproducible, windows.
            epoch = getattr(self, 'current_epoch', 0)
            start_idx = (idx * 9973 + (epoch + 1) * 104729) % (max_start + 1)
        else:
            start_idx = idx % (max_start + 1)
        end_idx = start_idx + self.window
        window_data = self.data.iloc[start_idx:end_idx]
        x = window_data[self.feature_list].values.astype(np.float32)
        x_stamp = window_data[self.time_feature_list].values.astype(np.float32)
        # Per-window z-score normalization, then outlier clipping.
        x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
        x = (x - x_mean) / (x_std + 1e-5)
        x = np.clip(x, -self.clip, self.clip)
        x_tensor = torch.from_numpy(x)
        x_stamp_tensor = torch.from_numpy(x_stamp)
        return x_tensor, x_stamp_tensor
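

# Shapes returned by CustomKlineDataset.__getitem__ with the constructor defaults
# (lookback_window=90, predict_window=10, so window = 101):
#   x_tensor:       (window, 6) -> z-scored, clipped open/high/low/close/volume/amount
#   x_stamp_tensor: (window, 5) -> raw minute/hour/weekday/day/month
# The extra step exists so that the one-position input/target shift in train_model()
# still covers the full lookback + predict horizon.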


def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger:
    os.makedirs(log_dir, exist_ok=True)
    logger = logging.getLogger(f"basemodel_training_rank_{rank}")
    logger.setLevel(logging.INFO)
    if logger.handlers:
        return logger
    log_file = os.path.join(log_dir, f"basemodel_training_rank_{rank}.log")
    file_handler = RotatingFileHandler(
        log_file,
        maxBytes=10 * 1024 * 1024,
        backupCount=5,
        encoding='utf-8'
    )
    file_handler.setLevel(logging.INFO)
    console_handler = None
    if rank == 0:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    file_handler.setFormatter(formatter)
    if console_handler is not None:
        console_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    if console_handler is not None:
        logger.addHandler(console_handler)
    logger.info("=== Basemodel Training Started ===")
    logger.info(f"Experiment Name: {exp_name}")
    logger.info(f"Log Directory: {log_dir}")
    logger.info(f"Rank: {rank}")
    logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    return logger


def create_dataloaders(config):
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print("Creating data loaders...")
    train_dataset = CustomKlineDataset(
        data_path=config.data_path,
        data_type='train',
        lookback_window=config.lookback_window,
        predict_window=config.predict_window,
        clip=config.clip,
        seed=config.seed,
        train_ratio=config.train_ratio,
        val_ratio=config.val_ratio,
        test_ratio=config.test_ratio
    )
    val_dataset = CustomKlineDataset(
        data_path=config.data_path,
        data_type='val',
        lookback_window=config.lookback_window,
        predict_window=config.predict_window,
        clip=config.clip,
        seed=config.seed + 1,
        train_ratio=config.train_ratio,
        val_ratio=config.val_ratio,
        test_ratio=config.test_ratio
    )
    use_ddp = dist.is_available() and dist.is_initialized()
    train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(),
                                       rank=dist.get_rank(), shuffle=True) if use_ddp else None
    val_sampler = DistributedSampler(val_dataset, num_replicas=dist.get_world_size(),
                                     rank=dist.get_rank(), shuffle=False, drop_last=False) if use_ddp else None
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=(train_sampler is None),
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
        sampler=train_sampler
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=False,
        sampler=val_sampler
    )
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")
    return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler


def train_model(model, tokenizer, device, config, save_dir, logger):
    logger.info("Starting training...")
    use_ddp = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if use_ddp else 0
    world_size = dist.get_world_size() if use_ddp else 1
    train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.predictor_learning_rate,
        betas=(config.adam_beta1, config.adam_beta2),
        weight_decay=config.adam_weight_decay
    )
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.predictor_learning_rate,
        steps_per_epoch=len(train_loader),
        epochs=config.basemodel_epochs,
        pct_start=0.03,
        div_factor=10
    )
    if use_ddp:
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)
    # DDP does not proxy attribute access, so keep a handle to the underlying module for
    # everything except the training forward pass; the training forward must go through
    # `model` itself so that DDP synchronizes gradients during backward.
    raw_model = model.module if use_ddp else model
    best_val_loss = float('inf')
    batch_idx_global = 0
    for epoch in range(config.basemodel_epochs):
        epoch_start_time = time.time()
        model.train()
        train_dataset.set_epoch_seed(epoch * 10000)
        val_dataset.set_epoch_seed(0)
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)
        epoch_train_loss = 0.0
        train_batches = 0
        for batch_idx, (batch_x, batch_x_stamp) in enumerate(train_loader):
            batch_x = batch_x.to(device, non_blocking=True)
            batch_x_stamp = batch_x_stamp.to(device, non_blocking=True)
            # The tokenizer stays frozen: encode without tracking gradients.
            with torch.no_grad():
                token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
            # Next-token objective: inputs are all tokens but the last, targets all but the first.
            token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
            token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]
            logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
            loss, s1_loss, s2_loss = raw_model.head.compute_loss(
                logits[0], logits[1], token_out[0], token_out[1])
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0)
            optimizer.step()
            scheduler.step()
            epoch_train_loss += loss.item()
            train_batches += 1
            if (batch_idx_global + 1) % config.log_interval == 0:
                lr = optimizer.param_groups[0]['lr']
                log_msg = (f"[Epoch {epoch+1}/{config.basemodel_epochs}, Step {batch_idx+1}/{len(train_loader)}] "
                           f"LR: {lr:.6f}, Loss: {loss.item():.4f}")
                logger.info(log_msg)
                if rank == 0:
                    print(log_msg)
            batch_idx_global += 1
        # Validation pass.
        model.eval()
        val_loss = 0.0
        val_batches = 0
        with torch.no_grad():
            for batch_x, batch_x_stamp in val_loader:
                batch_x = batch_x.to(device, non_blocking=True)
                batch_x_stamp = batch_x_stamp.to(device, non_blocking=True)
                token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
                token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
                token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]
                logits = raw_model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
                loss, _, _ = raw_model.head.compute_loss(
                    logits[0], logits[1], token_out[0], token_out[1])
                val_loss += loss.item()
                val_batches += 1
        if use_ddp:
            # Sum loss totals and batch counts across ranks, then average globally.
            tensor_sum = torch.tensor([epoch_train_loss, train_batches, val_loss, val_batches],
                                      dtype=torch.float64, device=device)
            dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM)
            epoch_train_loss_all = tensor_sum[0].item()
            train_batches_all = int(tensor_sum[1].item())
            val_loss_all = tensor_sum[2].item()
            val_batches_all = int(tensor_sum[3].item())
            avg_train_loss = (epoch_train_loss_all / train_batches_all) if train_batches_all > 0 else 0.0
            avg_val_loss = (val_loss_all / val_batches_all) if val_batches_all > 0 else 0.0
        else:
            avg_train_loss = epoch_train_loss / train_batches if train_batches > 0 else 0
            avg_val_loss = val_loss / val_batches if val_batches > 0 else 0
        epoch_time = time.time() - epoch_start_time
        epoch_summary = (f"\n--- Epoch {epoch+1}/{config.basemodel_epochs} Summary ---\n"
                         f"Training Loss: {avg_train_loss:.4f}\n"
                         f"Validation Loss: {avg_val_loss:.4f}\n"
                         f"Epoch Time: {epoch_time:.2f} seconds\n")
        logger.info(epoch_summary)
        if rank == 0:
            print(epoch_summary)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            if rank == 0:
                model_save_path = os.path.join(save_dir, "best_model")
                os.makedirs(model_save_path, exist_ok=True)
                raw_model.save_pretrained(model_save_path)
                save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})"
                logger.info(save_msg)
                print(save_msg)
    return best_val_loss
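

# Checkpointing note: only rank 0 writes checkpoints. Whenever the globally averaged
# validation loss improves, the underlying model is saved via save_pretrained() to
# <save_dir>/best_model. A minimal reload sketch, mirroring the from_pretrained()
# calls used in main():
#   best_model = Kronos.from_pretrained(os.path.join(config.basemodel_save_path, "best_model"))
#   tokenizer = KronosTokenizer.from_pretrained(config.finetuned_tokenizer_path)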


def main():
    import argparse
    parser = argparse.ArgumentParser(description='Kronos Basemodel Fine-tuning Training')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Configuration file path (default: config.yaml)')
    args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    config = CustomFinetuneConfig(args.config)
    os.makedirs(config.basemodel_save_path, exist_ok=True)
    log_dir = os.path.join(config.base_save_path, "logs")
    logger = setup_logging(config.exp_name, log_dir, 0)
    # Seed every RNG the data pipeline or model initialization may touch.
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    random.seed(config.seed)
    logger.info("Loading pretrained model or random initialization...")
    print("Loading pretrained model or random initialization...")
    if getattr(config, 'pre_trained_tokenizer', True):
        tokenizer = KronosTokenizer.from_pretrained(config.finetuned_tokenizer_path)
    else:
        # Build a randomly initialized tokenizer with the architecture described in config.json.
        print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for training")
        cfg_path_tok = os.path.join(
            config.pretrained_tokenizer_path if hasattr(config, 'pretrained_tokenizer_path')
            else config.finetuned_tokenizer_path,
            'config.json'
        )
        with open(cfg_path_tok, 'r') as f:
            arch_t = json.load(f)
        tokenizer = KronosTokenizer(
            d_in=arch_t.get('d_in', 6),
            d_model=arch_t.get('d_model', 256),
            n_heads=arch_t.get('n_heads', 4),
            ff_dim=arch_t.get('ff_dim', 512),
            n_enc_layers=arch_t.get('n_enc_layers', 4),
            n_dec_layers=arch_t.get('n_dec_layers', 4),
            ffn_dropout_p=arch_t.get('ffn_dropout_p', 0.0),
            attn_dropout_p=arch_t.get('attn_dropout_p', 0.0),
            resid_dropout_p=arch_t.get('resid_dropout_p', 0.0),
            s1_bits=arch_t.get('s1_bits', 10),
            s2_bits=arch_t.get('s2_bits', 10),
            beta=arch_t.get('beta', 0.05),
            gamma0=arch_t.get('gamma0', 1.0),
            gamma=arch_t.get('gamma', 1.1),
            zeta=arch_t.get('zeta', 0.05),
            group_size=arch_t.get('group_size', 4)
        )
    if getattr(config, 'pre_trained_predictor', True):
        model = Kronos.from_pretrained(config.pretrained_predictor_path)
    else:
        # Build a randomly initialized predictor with the architecture described in config.json.
        print("pre_trained_predictor=False, randomly initializing Predictor architecture for training")
        cfg_path = os.path.join(config.pretrained_predictor_path, 'config.json')
        with open(cfg_path, 'r') as f:
            arch = json.load(f)
        model = Kronos(
            s1_bits=arch.get('s1_bits', 10),
            s2_bits=arch.get('s2_bits', 10),
            n_layers=arch.get('n_layers', 12),
            d_model=arch.get('d_model', 832),
            n_heads=arch.get('n_heads', 16),
            ff_dim=arch.get('ff_dim', 2048),
            ffn_dropout_p=arch.get('ffn_dropout_p', 0.2),
            attn_dropout_p=arch.get('attn_dropout_p', 0.0),
            resid_dropout_p=arch.get('resid_dropout_p', 0.2),
            token_dropout_p=arch.get('token_dropout_p', 0.0),
            learn_te=arch.get('learn_te', True)
        )
    tokenizer = tokenizer.to(device)
    model = model.to(device)
    model_size = sum(p.numel() for p in model.parameters())
    logger.info(f"Model parameters: {model_size:,}")
    print(f"Model parameters: {model_size:,}")
    logger.info("=== Training Configuration ===")
    logger.info(f"Data path: {config.data_path}")
    logger.info(f"Lookback window: {config.lookback_window}")
    logger.info(f"Predict window: {config.predict_window}")
    logger.info(f"Batch size: {config.batch_size}")
    logger.info(f"Learning rate: {config.predictor_learning_rate}")
    logger.info(f"Training epochs: {config.basemodel_epochs}")
    logger.info(f"Device: {device}")
    logger.info(f"Tokenizer path: {config.finetuned_tokenizer_path}")
    logger.info(f"Pretrained model path: {config.pretrained_predictor_path}")
    logger.info("Starting fine-tuning training...")
    print("Starting fine-tuning training...")
    best_val_loss = train_model(model, tokenizer, device, config, config.basemodel_save_path, logger)
    final_msg = (f"Training completed! Best validation loss: {best_val_loss:.4f}\n"
                 f"Model saved to: {config.basemodel_save_path}")
    logger.info(final_msg)
    print(final_msg)


if __name__ == "__main__":
    main()
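

# Illustrative launch commands (the filename is a placeholder; adjust to this repo's layout):
#   Single process (CPU or one GPU):
#       python this_script.py --config config.yaml
#   Multi-GPU via torchrun (note: the DDP code paths above only take effect if a
#   torch.distributed process group has been initialized before train_model() runs;
#   this script does not call dist.init_process_group itself):
#       torchrun --nproc_per_node=4 this_script.py --config config.yaml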