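"""Sequential fine-tuning driver for Kronos.

Runs two phases in order:
  1. fine-tune (or randomly initialize) the KronosTokenizer;
  2. fine-tune the Kronos predictor (base model) using that tokenizer.

Both phases are configured through CustomFinetuneConfig (a YAML config file).
Distributed settings are read from the RANK / WORLD_SIZE / LOCAL_RANK
environment variables, which a torchrun-style launcher typically provides.
"""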
import os
import sys
import time
import argparse

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.distributed as dist

sys.path.append('../')
from model import Kronos, KronosTokenizer, KronosPredictor

from config_loader import CustomFinetuneConfig
from finetune_tokenizer import train_tokenizer, set_seed, setup_logging as setup_tokenizer_logging
from finetune_base_model import train_model, create_dataloaders, setup_logging as setup_basemodel_logging

class SequentialTrainer:
    """Orchestrates tokenizer fine-tuning followed by base-model (predictor) fine-tuning."""

    def __init__(self, config_path: str = None):
        self.config = CustomFinetuneConfig(config_path)
        # Distributed context; these variables are normally provided by the launcher.
        self.rank = int(os.environ.get("RANK", "0"))
        self.world_size = int(os.environ.get("WORLD_SIZE", "1"))
        self.local_rank = int(os.environ.get("LOCAL_RANK", str(self.config.device_id if hasattr(self.config, 'device_id') else 0)))
        self.device = self._setup_device()

        self.config.print_config_summary()
    def _setup_device(self):
        if self.config.use_cuda and torch.cuda.is_available():
            torch.cuda.set_device(self.local_rank)
            device = torch.device(f"cuda:{self.local_rank}")
        else:
            device = torch.device("cpu")

        if self.rank == 0:
            print(f"Using device: {device} (rank={self.rank}, world_size={self.world_size}, local_rank={self.local_rank})")
        return device
    def _setup_distributed(self):
        # Initialize the process group only for multi-process CUDA runs.
        if self.world_size > 1 and torch.cuda.is_available():
            backend = os.environ.get("DIST_BACKEND", "nccl").lower()
            if not dist.is_initialized():
                dist.init_process_group(backend=backend)
            if self.rank == 0:
                print(f"Distributed training initialized: backend={backend}, world_size={self.world_size}")
        else:
            if self.rank == 0:
                print("Distributed training not enabled, using single GPU/CPU training")
    def _check_existing_models(self):
        tokenizer_exists = os.path.exists(self.config.tokenizer_best_model_path)
        basemodel_exists = os.path.exists(self.config.basemodel_best_model_path)

        print(f"Tokenizer model exists: {tokenizer_exists}")
        print(f"Basemodel model exists: {basemodel_exists}")

        return tokenizer_exists, basemodel_exists

    def _create_directories(self):
        os.makedirs(self.config.tokenizer_save_path, exist_ok=True)
        os.makedirs(self.config.basemodel_save_path, exist_ok=True)
        print(f"Created directory: {self.config.tokenizer_save_path}")
        print(f"Created directory: {self.config.basemodel_save_path}")
    def train_tokenizer_phase(self):
        print("\n" + "=" * 60)
        print("Starting Tokenizer Fine-tuning Phase")
        print("=" * 60)

        tokenizer_exists, _ = self._check_existing_models()
        if tokenizer_exists and self.config.skip_existing:
            print("Tokenizer model already exists, skipping training")
            return True

        log_dir = os.path.join(self.config.base_save_path, "logs")
        logger = setup_tokenizer_logging(self.config.exp_name, log_dir, self.rank)

        set_seed(self.config.seed)
        if getattr(self.config, 'pre_trained_tokenizer', True):
            logger.info("Loading pretrained tokenizer...")
            if self.rank == 0:
                print("Loading pretrained tokenizer...")
            tokenizer = KronosTokenizer.from_pretrained(self.config.pretrained_tokenizer_path)
        else:
            if self.rank == 0:
                print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture")
            # Read architecture hyperparameters from the pretrained checkpoint's
            # config.json; the .get() defaults are fallbacks for missing keys.
            import json
            cfg_path = os.path.join(self.config.pretrained_tokenizer_path, 'config.json')
            with open(cfg_path, 'r') as f:
                arch = json.load(f)
            tokenizer = KronosTokenizer(
                d_in=arch.get('d_in', 6),
                d_model=arch.get('d_model', 256),
                n_heads=arch.get('n_heads', 4),
                ff_dim=arch.get('ff_dim', 512),
                n_enc_layers=arch.get('n_enc_layers', 4),
                n_dec_layers=arch.get('n_dec_layers', 4),
                ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
                attn_dropout_p=arch.get('attn_dropout_p', 0.0),
                resid_dropout_p=arch.get('resid_dropout_p', 0.0),
                s1_bits=arch.get('s1_bits', 10),
                s2_bits=arch.get('s2_bits', 10),
                beta=arch.get('beta', 0.05),
                gamma0=arch.get('gamma0', 1.0),
                gamma=arch.get('gamma', 1.1),
                zeta=arch.get('zeta', 0.05),
                group_size=arch.get('group_size', 4)
            )
        tokenizer = tokenizer.to(self.device)
        model_size = sum(p.numel() for p in tokenizer.parameters())
        logger.info(f"Tokenizer parameters: {model_size:,}")
        if self.rank == 0:
            print(f"Tokenizer parameters: {model_size:,}")

        logger.info("=== Training Configuration ===")
        logger.info(f"Data path: {self.config.data_path}")
        logger.info(f"Lookback window: {self.config.lookback_window}")
        logger.info(f"Predict window: {self.config.predict_window}")
        logger.info(f"Batch size: {self.config.batch_size}")
        logger.info(f"Learning rate: {self.config.tokenizer_learning_rate}")
        logger.info(f"Training epochs: {self.config.tokenizer_epochs}")
        logger.info(f"Device: {self.device}")
        logger.info(f"Distributed training: {self.world_size > 1}")

        logger.info("Starting tokenizer fine-tuning training...")
        if self.rank == 0:
            print("Starting tokenizer fine-tuning training...")
        start_time = time.time()
        best_val_loss = train_tokenizer(
            tokenizer,
            self.device,
            self.config,
            self.config.tokenizer_save_path,
            logger,
        )
        training_time = time.time() - start_time

        final_msg = (
            f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\n"
            f"Training time: {training_time / 60:.2f} minutes\n"
            f"Model saved to: {self.config.tokenizer_save_path}"
        )
        logger.info(final_msg)
        if self.rank == 0:
            print(f"\n{final_msg}")

        return True
    def train_basemodel_phase(self):
        print("\n" + "=" * 60)
        print("Starting Basemodel Fine-tuning Phase")
        print("=" * 60)

        if getattr(self.config, 'pre_trained_tokenizer', True):
            if not os.path.exists(self.config.finetuned_tokenizer_path):
                raise FileNotFoundError(f"Fine-tuned tokenizer does not exist: {self.config.finetuned_tokenizer_path}")

        _, basemodel_exists = self._check_existing_models()
        if basemodel_exists and self.config.skip_existing:
            print("Basemodel model already exists, skipping training")
            return True

        log_dir = os.path.join(self.config.base_save_path, "logs")
        logger = setup_basemodel_logging(self.config.exp_name, log_dir, self.rank)

        set_seed(self.config.seed)
        if getattr(self.config, 'pre_trained_tokenizer', True):
            logger.info("Loading fine-tuned tokenizer...")
            if self.rank == 0:
                print("Loading fine-tuned tokenizer...")
            tokenizer = KronosTokenizer.from_pretrained(self.config.finetuned_tokenizer_path)
        else:
            if self.rank == 0:
                print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for Predictor training")
            import json
            cfg_path = os.path.join(self.config.pretrained_tokenizer_path, 'config.json')
            with open(cfg_path, 'r') as f:
                arch = json.load(f)
            tokenizer = KronosTokenizer(
                d_in=arch.get('d_in', 6),
                d_model=arch.get('d_model', 256),
                n_heads=arch.get('n_heads', 4),
                ff_dim=arch.get('ff_dim', 512),
                n_enc_layers=arch.get('n_enc_layers', 4),
                n_dec_layers=arch.get('n_dec_layers', 4),
                ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
                attn_dropout_p=arch.get('attn_dropout_p', 0.0),
                resid_dropout_p=arch.get('resid_dropout_p', 0.0),
                s1_bits=arch.get('s1_bits', 10),
                s2_bits=arch.get('s2_bits', 10),
                beta=arch.get('beta', 0.05),
                gamma0=arch.get('gamma0', 1.0),
                gamma=arch.get('gamma', 1.1),
                zeta=arch.get('zeta', 0.05),
                group_size=arch.get('group_size', 4)
            )
        tokenizer = tokenizer.to(self.device)
        if getattr(self.config, 'pre_trained_predictor', True):
            logger.info("Loading pretrained predictor...")
            if self.rank == 0:
                print("Loading pretrained predictor...")
            model = Kronos.from_pretrained(self.config.pretrained_predictor_path)
        else:
            if self.rank == 0:
                print("pre_trained_predictor=False, randomly initializing Predictor architecture")
            import json
            cfg_path = os.path.join(self.config.pretrained_predictor_path, 'config.json')
            with open(cfg_path, 'r') as f:
                arch = json.load(f)
            print("model_config: ", arch)
            model = Kronos(
                s1_bits=arch.get('s1_bits', 10),
                s2_bits=arch.get('s2_bits', 10),
                n_layers=arch.get('n_layers', 12),
                d_model=arch.get('d_model', 832),
                n_heads=arch.get('n_heads', 16),
                ff_dim=arch.get('ff_dim', 2048),
                ffn_dropout_p=arch.get('ffn_dropout_p', 0.2),
                attn_dropout_p=arch.get('attn_dropout_p', 0.0),
                resid_dropout_p=arch.get('resid_dropout_p', 0.2),
                token_dropout_p=arch.get('token_dropout_p', 0.0),
                learn_te=arch.get('learn_te', True)
            )
        model = model.to(self.device)
        model_size = sum(p.numel() for p in model.parameters())
        logger.info(f"Model parameters: {model_size:,}")
        if self.rank == 0:
            print(f"Model parameters: {model_size:,}")

        logger.info("=== Training Configuration ===")
        logger.info(f"Data path: {self.config.data_path}")
        logger.info(f"Lookback window: {self.config.lookback_window}")
        logger.info(f"Predict window: {self.config.predict_window}")
        logger.info(f"Batch size: {self.config.batch_size}")
        logger.info(f"Learning rate: {self.config.predictor_learning_rate}")
        logger.info(f"Training epochs: {self.config.basemodel_epochs}")
        logger.info(f"Device: {self.device}")
        logger.info(f"Tokenizer path: {self.config.finetuned_tokenizer_path}")
        logger.info(f"Pretrained model path: {self.config.pretrained_predictor_path}")

        logger.info("Starting fine-tuning training...")
        if self.rank == 0:
            print("Starting fine-tuning training...")
        start_time = time.time()
        best_val_loss = train_model(
            model,
            tokenizer,
            self.device,
            self.config,
            self.config.basemodel_save_path,
            logger,
        )
        training_time = time.time() - start_time

        final_msg = (
            f"Basemodel training completed! Best validation loss: {best_val_loss:.4f}\n"
            f"Training time: {training_time / 60:.2f} minutes\n"
            f"Model saved to: {self.config.basemodel_save_path}"
        )
        logger.info(final_msg)
        if self.rank == 0:
            print(f"\n{final_msg}")

        return True
    def run_training(self):
        if self.rank == 0:
            print("Starting Kronos model sequential fine-tuning training")
            print(f"Experiment name: {self.config.experiment_name}")
            print(f"Experiment description: {self.config.experiment_description}")

        self._setup_distributed()
        self._create_directories()

        # Report which checkpoints already exist before starting either phase.
        self._check_existing_models()

        total_start_time = time.time()

        try:
            if self.config.train_tokenizer:
                success = self.train_tokenizer_phase()
                if not success:
                    print("Tokenizer training failed, terminating training")
                    return False
            else:
                print("Skipping Tokenizer training phase")

            if self.config.train_basemodel:
                success = self.train_basemodel_phase()
                if not success:
                    print("Basemodel training failed, terminating training")
                    return False
            else:
                print("Skipping Basemodel training phase")

            total_time = time.time() - total_start_time

            if self.rank == 0:
                print("\n" + "=" * 60)
                print("Training completed!")
                print("=" * 60)
                print(f"Total training time: {total_time / 60:.2f} minutes")
                print(f"Tokenizer model: {self.config.tokenizer_best_model_path}")
                print(f"Basemodel model: {self.config.basemodel_best_model_path}")
                print("=" * 60)

            return True

        except Exception as e:
            if self.rank == 0:
                print(f"Error occurred during training: {str(e)}")
                import traceback
                traceback.print_exc()
            return False

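# Example invocation (script and config names are illustrative; a torchrun-style
# launcher supplies the RANK / WORLD_SIZE / LOCAL_RANK variables read above):
#   torchrun --nproc_per_node=2 finetune_sequential.py --config config.yaml --skip-existing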
def main():
    parser = argparse.ArgumentParser(description='Kronos Model Sequential Fine-tuning Training')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Configuration file path (default: config.yaml)')
    parser.add_argument('--skip-tokenizer', action='store_true',
                        help='Skip tokenizer training phase')
    parser.add_argument('--skip-basemodel', action='store_true',
                        help='Skip basemodel training phase')
    parser.add_argument('--skip-existing', action='store_true',
                        help='Skip training for existing models')

    args = parser.parse_args()

    trainer = SequentialTrainer(args.config)

    # Command-line flags override the corresponding config values.
    if args.skip_tokenizer:
        trainer.config.train_tokenizer = False
    if args.skip_basemodel:
        trainer.config.train_basemodel = False
    if args.skip_existing:
        trainer.config.skip_existing = True

    success = trainer.run_training()

    if success:
        print("Training completed successfully!")
        if dist.is_available() and dist.is_initialized():
            dist.barrier()
            dist.destroy_process_group()
        sys.exit(0)
    else:
        print("Training failed!")
        if dist.is_available() and dist.is_initialized():
            try:
                dist.barrier()
                dist.destroy_process_group()
            except Exception:
                pass
        sys.exit(1)


if __name__ == "__main__":
    main()