import os
import sys
import json
import time
import pickle
import random
import datetime
import logging
from logging.handlers import RotatingFileHandler
from time import gmtime, strftime

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

sys.path.append('../')
from model import Kronos, KronosTokenizer, KronosPredictor
from config_loader import CustomFinetuneConfig


class CustomKlineDataset(Dataset):
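    """Sliding-window K-line dataset with a chronological train/val/test split.

    Each sample covers `lookback_window + predict_window + 1` consecutive rows;
    the price/volume features are z-scored per window and clipped to ±clip,
    and the calendar features (minute, hour, weekday, day, month) are returned
    alongside them as the time-stamp tensor.
    """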
    def __init__(self, data_path, data_type='train', lookback_window=90, predict_window=10,
                 clip=5.0, seed=100, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
        self.data_path = data_path
        self.data_type = data_type
        self.lookback_window = lookback_window
        self.predict_window = predict_window
        self.window = lookback_window + predict_window + 1
        self.clip = clip
        self.seed = seed
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio

        self.feature_list = ['open', 'high', 'low', 'close', 'volume', 'amount']
        self.time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month']

        self.py_rng = random.Random(seed)

        self._load_and_preprocess_data()
        self._split_data_by_time()

        self.n_samples = len(self.data) - self.window + 1

        print(f"[{data_type.upper()}] Data length: {len(self.data)}, Available samples: {self.n_samples}")

    def _load_and_preprocess_data(self):
        df = pd.read_csv(self.data_path)

        df['timestamps'] = pd.to_datetime(df['timestamps'])
        df = df.sort_values('timestamps').reset_index(drop=True)

        self.timestamps = df['timestamps'].copy()

        df['minute'] = df['timestamps'].dt.minute
        df['hour'] = df['timestamps'].dt.hour
        df['weekday'] = df['timestamps'].dt.weekday
        df['day'] = df['timestamps'].dt.day
        df['month'] = df['timestamps'].dt.month

        self.data = df[self.feature_list + self.time_feature_list].copy()

        if self.data.isnull().any().any():
            print("Warning: Missing values found in data, performing forward fill")
            # fillna(method='ffill') is deprecated in recent pandas; use ffill() instead.
            self.data = self.data.ffill()

        print(f"Original data time range: {self.timestamps.min()} to {self.timestamps.max()}")
        print(f"Original data total length: {len(df)} records")

    def _split_data_by_time(self):
        total_length = len(self.data)

        train_end = int(total_length * self.train_ratio)
        val_end = int(total_length * (self.train_ratio + self.val_ratio))

        if self.data_type == 'train':
            self.data = self.data.iloc[:train_end].copy()
            self.timestamps = self.timestamps.iloc[:train_end].copy()
            print(f"[{self.data_type.upper()}] Training set: first {train_end} time points ({self.train_ratio})")
            print(f"[{self.data_type.upper()}] Training set time range: {self.timestamps.min()} to {self.timestamps.max()}")
        elif self.data_type == 'val':
            self.data = self.data.iloc[train_end:val_end].copy()
            self.timestamps = self.timestamps.iloc[train_end:val_end].copy()
            print(f"[{self.data_type.upper()}] Validation set: time points {train_end+1} to {val_end} ({self.val_ratio})")
            print(f"[{self.data_type.upper()}] Validation set time range: {self.timestamps.min()} to {self.timestamps.max()}")
        elif self.data_type == 'test':
            self.data = self.data.iloc[val_end:].copy()
            self.timestamps = self.timestamps.iloc[val_end:].copy()
            print(f"[{self.data_type.upper()}] Test set: after time point {val_end+1}")
            print(f"[{self.data_type.upper()}] Test set time range: {self.timestamps.min()} to {self.timestamps.max()}")

        print(f"[{self.data_type.upper()}] Data length after split: {len(self.data)} records")

    def set_epoch_seed(self, epoch):
        epoch_seed = self.seed + epoch
        self.py_rng.seed(epoch_seed)
        self.current_epoch = epoch

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        max_start = len(self.data) - self.window
        if max_start <= 0:
            raise ValueError("Data length insufficient to create samples")

        if self.data_type == 'train':
            # Deterministic pseudo-random window start: mixing the sample index
            # and epoch with two large primes spreads windows over the series
            # while keeping sampling reproducible across runs and workers.
            epoch = getattr(self, 'current_epoch', 0)
            start_idx = (idx * 9973 + (epoch + 1) * 104729) % (max_start + 1)
        else:
            start_idx = idx % (max_start + 1)

        end_idx = start_idx + self.window

        window_data = self.data.iloc[start_idx:end_idx]

        x = window_data[self.feature_list].values.astype(np.float32)
        x_stamp = window_data[self.time_feature_list].values.astype(np.float32)

        # Per-window z-score normalization, then clip outliers.
        x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
        x = (x - x_mean) / (x_std + 1e-5)
        x = np.clip(x, -self.clip, self.clip)

        x_tensor = torch.from_numpy(x)
        x_stamp_tensor = torch.from_numpy(x_stamp)

        return x_tensor, x_stamp_tensor


def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger:
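    """Create a per-rank logger with a rotating file handler.

    Every rank writes to its own log file under `log_dir`; only rank 0 also
    logs to the console. Calling again for the same rank reuses the existing
    handlers.
    """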
    os.makedirs(log_dir, exist_ok=True)

    logger = logging.getLogger(f"basemodel_training_rank_{rank}")
    logger.setLevel(logging.INFO)

    if logger.handlers:
        return logger

    log_file = os.path.join(log_dir, f"basemodel_training_rank_{rank}.log")
    file_handler = RotatingFileHandler(
        log_file,
        maxBytes=10 * 1024 * 1024,
        backupCount=5,
        encoding='utf-8'
    )
    file_handler.setLevel(logging.INFO)

    console_handler = None
    if rank == 0:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    file_handler.setFormatter(formatter)
    if console_handler is not None:
        console_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    if console_handler is not None:
        logger.addHandler(console_handler)

    logger.info("=== Basemodel Training Started ===")
    logger.info(f"Experiment Name: {exp_name}")
    logger.info(f"Log Directory: {log_dir}")
    logger.info(f"Rank: {rank}")
    logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    return logger


def create_dataloaders(config):
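    """Build the train/val datasets and loaders described by `config`.

    When torch.distributed is initialized, each loader gets a
    DistributedSampler so every rank works on its own shard of the data;
    otherwise plain single-process loaders are returned.
    """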
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print("Creating data loaders...")

    train_dataset = CustomKlineDataset(
        data_path=config.data_path,
        data_type='train',
        lookback_window=config.lookback_window,
        predict_window=config.predict_window,
        clip=config.clip,
        seed=config.seed,
        train_ratio=config.train_ratio,
        val_ratio=config.val_ratio,
        test_ratio=config.test_ratio
    )

    val_dataset = CustomKlineDataset(
        data_path=config.data_path,
        data_type='val',
        lookback_window=config.lookback_window,
        predict_window=config.predict_window,
        clip=config.clip,
        seed=config.seed + 1,
        train_ratio=config.train_ratio,
        val_ratio=config.val_ratio,
        test_ratio=config.test_ratio
    )

    use_ddp = dist.is_available() and dist.is_initialized()
    if use_ddp:
        train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True)
        val_sampler = DistributedSampler(val_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=False)
    else:
        train_sampler = None
        val_sampler = None

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=(train_sampler is None),
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
        sampler=train_sampler
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=False,
        sampler=val_sampler
    )

    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")

    return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler


def train_model(model, tokenizer, device, config, save_dir, logger):
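    """Fine-tune the Kronos predictor on tokenized K-line windows.

    The tokenizer is only used for encoding (under torch.no_grad) and is not
    optimized here; the predictor is trained with a next-token objective on
    the two token streams, and the best checkpoint (lowest validation loss)
    is saved by rank 0.
    """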
    logger.info("Starting training...")
    use_ddp = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if use_ddp else 0
    world_size = dist.get_world_size() if use_ddp else 1

    train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.predictor_learning_rate,
        betas=(config.adam_beta1, config.adam_beta2),
        weight_decay=config.adam_weight_decay
    )

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.predictor_learning_rate,
        steps_per_epoch=len(train_loader),
        epochs=config.basemodel_epochs,
        pct_start=0.03,
        div_factor=10
    )

    if use_ddp:
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)

    # Unwrapped module for attribute access (loss head, checkpointing). The
    # training forward pass itself must go through `model` so that DDP can
    # synchronize gradients across ranks.
    raw_model = model.module if use_ddp else model

    best_val_loss = float('inf')
    batch_idx_global = 0

    for epoch in range(config.basemodel_epochs):
        epoch_start_time = time.time()
        model.train()

        train_dataset.set_epoch_seed(epoch * 10000)
        val_dataset.set_epoch_seed(0)
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        epoch_train_loss = 0.0
        train_batches = 0

        for batch_idx, (batch_x, batch_x_stamp) in enumerate(train_loader):
            batch_x = batch_x.to(device, non_blocking=True)
            batch_x_stamp = batch_x_stamp.to(device, non_blocking=True)

            # Tokenize the continuous window with the frozen tokenizer.
            with torch.no_grad():
                token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)

            # Shift by one step for next-token prediction.
            token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
            token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]

            # Forward through `model` (the DDP wrapper when DDP is active) so
            # gradient all-reduce hooks fire; calling model.module directly
            # would skip gradient synchronization.
            logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
            loss, s1_loss, s2_loss = raw_model.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(raw_model.parameters(), max_norm=3.0)
            optimizer.step()
            scheduler.step()

            epoch_train_loss += loss.item()
            train_batches += 1

            if (batch_idx_global + 1) % config.log_interval == 0:
                lr = optimizer.param_groups[0]['lr']
                log_msg = (f"[Epoch {epoch+1}/{config.basemodel_epochs}, Step {batch_idx+1}/{len(train_loader)}] "
                           f"LR: {lr:.6f}, Loss: {loss.item():.4f}")
                logger.info(log_msg)
                if rank == 0:
                    print(log_msg)

            batch_idx_global += 1

        model.eval()
        val_loss = 0.0
        val_batches = 0

        with torch.no_grad():
            for batch_x, batch_x_stamp in val_loader:
                batch_x = batch_x.to(device, non_blocking=True)
                batch_x_stamp = batch_x_stamp.to(device, non_blocking=True)

                token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True)
                token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]]
                token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]]

                logits = raw_model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :])
                loss, _, _ = raw_model.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1])

                val_loss += loss.item()
                val_batches += 1

        if use_ddp:
            # Aggregate loss sums and batch counts across ranks before averaging.
            tensor_sum = torch.tensor([epoch_train_loss, train_batches, val_loss, val_batches], dtype=torch.float64, device=device)
            dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM)
            epoch_train_loss_all = tensor_sum[0].item()
            train_batches_all = int(tensor_sum[1].item())
            val_loss_all = tensor_sum[2].item()
            val_batches_all = int(tensor_sum[3].item())
            avg_train_loss = (epoch_train_loss_all / train_batches_all) if train_batches_all > 0 else 0.0
            avg_val_loss = (val_loss_all / val_batches_all) if val_batches_all > 0 else 0.0
        else:
            avg_train_loss = epoch_train_loss / train_batches if train_batches > 0 else 0.0
            avg_val_loss = val_loss / val_batches if val_batches > 0 else 0.0

        epoch_time = time.time() - epoch_start_time
        epoch_summary = (f"\n--- Epoch {epoch+1}/{config.basemodel_epochs} Summary ---\n"
                         f"Training Loss: {avg_train_loss:.4f}\n"
                         f"Validation Loss: {avg_val_loss:.4f}\n"
                         f"Epoch Time: {epoch_time:.2f} seconds\n")
        logger.info(epoch_summary)
        if rank == 0:
            print(epoch_summary)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            if rank == 0:
                model_save_path = os.path.join(save_dir, "best_model")
                os.makedirs(model_save_path, exist_ok=True)
                raw_model.save_pretrained(model_save_path)
                save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})"
                logger.info(save_msg)
                print(save_msg)

    return best_val_loss


def main():
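    """Entry point: parse args, load the config, build the tokenizer and predictor, then fine-tune.

    Note (assumption about intended usage): the DDP branches in
    create_dataloaders/train_model only activate if torch.distributed has been
    initialized beforehand; this script never calls dist.init_process_group,
    so running it directly, e.g. `python <this_script> --config config.yaml`,
    trains in a single process.
    """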
    import argparse

    parser = argparse.ArgumentParser(description='Kronos Basemodel Fine-tuning Training')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Configuration file path (default: config.yaml)')
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    config = CustomFinetuneConfig(args.config)

    os.makedirs(config.basemodel_save_path, exist_ok=True)

    log_dir = os.path.join(config.base_save_path, "logs")
    logger = setup_logging(config.exp_name, log_dir, 0)

    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    random.seed(config.seed)

    logger.info("Loading pretrained model or initializing randomly...")
    print("Loading pretrained model or initializing randomly...")
    if getattr(config, 'pre_trained_tokenizer', True):
        tokenizer = KronosTokenizer.from_pretrained(config.finetuned_tokenizer_path)
    else:
        print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for training")
        tokenizer_cfg_dir = config.pretrained_tokenizer_path if hasattr(config, 'pretrained_tokenizer_path') else config.finetuned_tokenizer_path
        cfg_path_tok = os.path.join(tokenizer_cfg_dir, 'config.json')
        with open(cfg_path_tok, 'r') as f:
            arch_t = json.load(f)
        tokenizer = KronosTokenizer(
            d_in=arch_t.get('d_in', 6),
            d_model=arch_t.get('d_model', 256),
            n_heads=arch_t.get('n_heads', 4),
            ff_dim=arch_t.get('ff_dim', 512),
            n_enc_layers=arch_t.get('n_enc_layers', 4),
            n_dec_layers=arch_t.get('n_dec_layers', 4),
            ffn_dropout_p=arch_t.get('ffn_dropout_p', 0.0),
            attn_dropout_p=arch_t.get('attn_dropout_p', 0.0),
            resid_dropout_p=arch_t.get('resid_dropout_p', 0.0),
            s1_bits=arch_t.get('s1_bits', 10),
            s2_bits=arch_t.get('s2_bits', 10),
            beta=arch_t.get('beta', 0.05),
            gamma0=arch_t.get('gamma0', 1.0),
            gamma=arch_t.get('gamma', 1.1),
            zeta=arch_t.get('zeta', 0.05),
            group_size=arch_t.get('group_size', 4)
        )

    if getattr(config, 'pre_trained_predictor', True):
        model = Kronos.from_pretrained(config.pretrained_predictor_path)
    else:
        print("pre_trained_predictor=False, randomly initializing Predictor architecture for training")
        cfg_path = os.path.join(config.pretrained_predictor_path, 'config.json')
        with open(cfg_path, 'r') as f:
            arch = json.load(f)
        model = Kronos(
            s1_bits=arch.get('s1_bits', 10),
            s2_bits=arch.get('s2_bits', 10),
            n_layers=arch.get('n_layers', 12),
            d_model=arch.get('d_model', 832),
            n_heads=arch.get('n_heads', 16),
            ff_dim=arch.get('ff_dim', 2048),
            ffn_dropout_p=arch.get('ffn_dropout_p', 0.2),
            attn_dropout_p=arch.get('attn_dropout_p', 0.0),
            resid_dropout_p=arch.get('resid_dropout_p', 0.2),
            token_dropout_p=arch.get('token_dropout_p', 0.0),
            learn_te=arch.get('learn_te', True)
        )

    tokenizer = tokenizer.to(device)
    model = model.to(device)

    model_size = sum(p.numel() for p in model.parameters())
    logger.info(f"Model parameters: {model_size:,}")
    print(f"Model parameters: {model_size:,}")

    logger.info("=== Training Configuration ===")
    logger.info(f"Data path: {config.data_path}")
    logger.info(f"Lookback window: {config.lookback_window}")
    logger.info(f"Predict window: {config.predict_window}")
    logger.info(f"Batch size: {config.batch_size}")
    logger.info(f"Learning rate: {config.predictor_learning_rate}")
    logger.info(f"Training epochs: {config.basemodel_epochs}")
    logger.info(f"Device: {device}")
    logger.info(f"Tokenizer path: {config.finetuned_tokenizer_path}")
    logger.info(f"Pretrained model path: {config.pretrained_predictor_path}")

    logger.info("Starting fine-tuning training...")
    print("Starting fine-tuning training...")
    best_val_loss = train_model(model, tokenizer, device, config, config.basemodel_save_path, logger)

    final_msg = f"Training completed! Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.basemodel_save_path}"
    logger.info(final_msg)
    print(final_msg)


if __name__ == "__main__":
    main()