|
|
import osimport sysimport timeimport argparseimport torchimport torch.nn as nnfrom torch.utils.data import DataLoaderimport torch.distributed as dist
sys.path.append('../')from model import Kronos, KronosTokenizer, KronosPredictor
from config_loader import CustomFinetuneConfigfrom finetune_tokenizer import train_tokenizer, set_seed, setup_logging as setup_tokenizer_loggingfrom finetune_base_model import train_model, create_dataloaders, setup_logging as setup_basemodel_logging
class SequentialTrainer:
    """Run Kronos fine-tuning in two sequential phases: tokenizer first, then base model.

    Reads torchrun-style environment variables (RANK / WORLD_SIZE / LOCAL_RANK) so the
    same script works single-process and under distributed launchers. All console
    output is restricted to rank 0; per-rank detail goes to the phase loggers.
    """

    def __init__(self, config_path: str = None):
        self.config = CustomFinetuneConfig(config_path)
        # torchrun-style env vars; the defaults reproduce single-process behavior.
        self.rank = int(os.environ.get("RANK", "0"))
        self.world_size = int(os.environ.get("WORLD_SIZE", "1"))
        self.local_rank = int(os.environ.get(
            "LOCAL_RANK",
            str(self.config.device_id if hasattr(self.config, 'device_id') else 0),
        ))
        self.device = self._setup_device()
        self.config.print_config_summary()

    def _setup_device(self):
        """Pick this process's device: cuda:<local_rank> when available and enabled, else CPU."""
        if self.config.use_cuda and torch.cuda.is_available():
            torch.cuda.set_device(self.local_rank)
            device = torch.device(f"cuda:{self.local_rank}")
        else:
            device = torch.device("cpu")
        if self.rank == 0:
            print(f"Using device: {device} (rank={self.rank}, world_size={self.world_size}, local_rank={self.local_rank})")
        return device

    def _setup_distributed(self):
        """Initialize the torch.distributed process group when running multi-process with CUDA."""
        if self.world_size > 1 and torch.cuda.is_available():
            backend = os.environ.get("DIST_BACKEND", "nccl").lower()
            if not dist.is_initialized():
                dist.init_process_group(backend=backend)
            if self.rank == 0:
                print(f"Distributed training initialized: backend={backend}, world_size={self.world_size}")
        else:
            if self.rank == 0:
                print("Distributed training not enabled, using single GPU/CPU training")

    def _check_existing_models(self):
        """Return (tokenizer_exists, basemodel_exists) for the configured best-model paths."""
        tokenizer_exists = os.path.exists(self.config.tokenizer_best_model_path)
        basemodel_exists = os.path.exists(self.config.basemodel_best_model_path)
        if self.rank == 0:
            print(f"Tokenizer model exists: {tokenizer_exists}")
            print(f"Basemodel model exists: {basemodel_exists}")
        return tokenizer_exists, basemodel_exists

    def _create_directories(self):
        """Ensure both phase output directories exist (idempotent)."""
        os.makedirs(self.config.tokenizer_save_path, exist_ok=True)
        os.makedirs(self.config.basemodel_save_path, exist_ok=True)
        if self.rank == 0:
            print(f"Created directory: {self.config.tokenizer_save_path}")
            print(f"Created directory: {self.config.basemodel_save_path}")

    def _build_tokenizer_from_arch_config(self):
        """Randomly initialize a KronosTokenizer from the architecture recorded in the
        pretrained checkpoint's config.json (weights are NOT loaded)."""
        import json
        cfg_path = os.path.join(self.config.pretrained_tokenizer_path, 'config.json')
        with open(cfg_path, 'r') as f:
            arch = json.load(f)
        return KronosTokenizer(
            d_in=arch.get('d_in', 6),
            d_model=arch.get('d_model', 256),
            n_heads=arch.get('n_heads', 4),
            ff_dim=arch.get('ff_dim', 512),
            n_enc_layers=arch.get('n_enc_layers', 4),
            n_dec_layers=arch.get('n_dec_layers', 4),
            ffn_dropout_p=arch.get('ffn_dropout_p', 0.0),
            attn_dropout_p=arch.get('attn_dropout_p', 0.0),
            resid_dropout_p=arch.get('resid_dropout_p', 0.0),
            s1_bits=arch.get('s1_bits', 10),
            s2_bits=arch.get('s2_bits', 10),
            beta=arch.get('beta', 0.05),
            gamma0=arch.get('gamma0', 1.0),
            gamma=arch.get('gamma', 1.1),
            zeta=arch.get('zeta', 0.05),
            group_size=arch.get('group_size', 4),
        )

    def _build_predictor_from_arch_config(self):
        """Randomly initialize a Kronos predictor from the architecture recorded in the
        pretrained checkpoint's config.json (weights are NOT loaded)."""
        import json
        cfg_path = os.path.join(self.config.pretrained_predictor_path, 'config.json')
        with open(cfg_path, 'r') as f:
            arch = json.load(f)
        if self.rank == 0:
            print("model_config: ", arch)
        return Kronos(
            s1_bits=arch.get('s1_bits', 10),
            s2_bits=arch.get('s2_bits', 10),
            n_layers=arch.get('n_layers', 12),
            d_model=arch.get('d_model', 832),
            n_heads=arch.get('n_heads', 16),
            ff_dim=arch.get('ff_dim', 2048),
            ffn_dropout_p=arch.get('ffn_dropout_p', 0.2),
            attn_dropout_p=arch.get('attn_dropout_p', 0.0),
            resid_dropout_p=arch.get('resid_dropout_p', 0.2),
            token_dropout_p=arch.get('token_dropout_p', 0.0),
            learn_te=arch.get('learn_te', True),
        )

    def _log_common_config(self, logger, learning_rate, epochs):
        """Log the configuration fields shared by both training phases."""
        logger.info("=== Training Configuration ===")
        logger.info(f"Data path: {self.config.data_path}")
        logger.info(f"Lookback window: {self.config.lookback_window}")
        logger.info(f"Predict window: {self.config.predict_window}")
        logger.info(f"Batch size: {self.config.batch_size}")
        logger.info(f"Learning rate: {learning_rate}")
        logger.info(f"Training epochs: {epochs}")
        logger.info(f"Device: {self.device}")

    def train_tokenizer_phase(self):
        """Fine-tune (or randomly initialize and train) the tokenizer.

        Returns True on completion, or immediately when the best model already
        exists and `skip_existing` is set.
        """
        if self.rank == 0:
            print("\n" + "=" * 60)
            print("Starting Tokenizer Fine-tuning Phase")
            print("=" * 60)
        tokenizer_exists, _ = self._check_existing_models()
        if tokenizer_exists and self.config.skip_existing:
            if self.rank == 0:
                print("Tokenizer model already exists, skipping training")
            return True

        log_dir = os.path.join(self.config.base_save_path, "logs")
        logger = setup_tokenizer_logging(self.config.exp_name, log_dir, self.rank)
        set_seed(self.config.seed)

        if getattr(self.config, 'pre_trained_tokenizer', True):
            logger.info("Loading pretrained tokenizer...")
            if self.rank == 0:
                print("Loading pretrained tokenizer...")
            tokenizer = KronosTokenizer.from_pretrained(self.config.pretrained_tokenizer_path)
        else:
            if self.rank == 0:
                print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture")
            tokenizer = self._build_tokenizer_from_arch_config()
        tokenizer = tokenizer.to(self.device)

        model_size = sum(p.numel() for p in tokenizer.parameters())
        logger.info(f"Tokenizer parameters: {model_size:,}")
        if self.rank == 0:
            print(f"Tokenizer parameters: {model_size:,}")

        self._log_common_config(logger, self.config.tokenizer_learning_rate, self.config.tokenizer_epochs)
        # Previously hard-coded to False; report the actual multi-process state.
        logger.info(f"Distributed training: {self.world_size > 1}")

        logger.info("Starting tokenizer fine-tuning training...")
        if self.rank == 0:
            print("Starting tokenizer fine-tuning training...")
        start_time = time.time()
        best_val_loss = train_tokenizer(
            tokenizer,
            self.device,
            self.config,
            self.config.tokenizer_save_path,
            logger,
        )
        training_time = time.time() - start_time

        final_msg = (
            f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\n"
            f"Training time: {training_time / 60:.2f} minutes\n"
            f"Model saved to: {self.config.tokenizer_save_path}"
        )
        logger.info(final_msg)
        if self.rank == 0:
            print(f"\n{final_msg}")
        return True

    def train_basemodel_phase(self):
        """Fine-tune the base predictor model using the (fine-tuned) tokenizer.

        Returns True on completion, or immediately when the best model already
        exists and `skip_existing` is set. Raises FileNotFoundError when a
        fine-tuned tokenizer is required but missing.
        """
        if self.rank == 0:
            print("\n" + "=" * 60)
            print("Starting Basemodel Fine-tuning Phase")
            print("=" * 60)

        if getattr(self.config, 'pre_trained_tokenizer', True):
            if not os.path.exists(self.config.finetuned_tokenizer_path):
                raise FileNotFoundError(f"Fine-tuned tokenizer does not exist: {self.config.finetuned_tokenizer_path}")

        _, basemodel_exists = self._check_existing_models()
        if basemodel_exists and self.config.skip_existing:
            if self.rank == 0:
                print("Basemodel model already exists, skipping training")
            return True

        log_dir = os.path.join(self.config.base_save_path, "logs")
        logger = setup_basemodel_logging(self.config.exp_name, log_dir, self.rank)
        set_seed(self.config.seed)

        if getattr(self.config, 'pre_trained_tokenizer', True):
            logger.info("Loading fine-tuned tokenizer...")
            if self.rank == 0:
                print("Loading fine-tuned tokenizer...")
            tokenizer = KronosTokenizer.from_pretrained(self.config.finetuned_tokenizer_path)
        else:
            if self.rank == 0:
                print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for Predictor training")
            tokenizer = self._build_tokenizer_from_arch_config()
        tokenizer = tokenizer.to(self.device)

        if getattr(self.config, 'pre_trained_predictor', True):
            logger.info("Loading pretrained predictor...")
            if self.rank == 0:
                print("Loading pretrained predictor...")
            model = Kronos.from_pretrained(self.config.pretrained_predictor_path)
        else:
            if self.rank == 0:
                print("pre_trained_predictor=False, randomly initializing Predictor architecture")
            model = self._build_predictor_from_arch_config()
        model = model.to(self.device)

        model_size = sum(p.numel() for p in model.parameters())
        logger.info(f"Model parameters: {model_size:,}")
        if self.rank == 0:
            print(f"Model parameters: {model_size:,}")

        self._log_common_config(logger, self.config.predictor_learning_rate, self.config.basemodel_epochs)
        logger.info(f"Tokenizer path: {self.config.finetuned_tokenizer_path}")
        logger.info(f"Pretrained model path: {self.config.pretrained_predictor_path}")

        logger.info("Starting fine-tuning training...")
        if self.rank == 0:
            print("Starting fine-tuning training...")
        start_time = time.time()
        best_val_loss = train_model(
            model,
            tokenizer,
            self.device,
            self.config,
            self.config.basemodel_save_path,
            logger,
        )
        training_time = time.time() - start_time

        final_msg = (
            f"Basemodel training completed! Best validation loss: {best_val_loss:.4f}\n"
            f"Training time: {training_time / 60:.2f} minutes\n"
            f"Model saved to: {self.config.basemodel_save_path}"
        )
        logger.info(final_msg)
        if self.rank == 0:
            print(f"\n{final_msg}")
        return True

    def run_training(self):
        """Orchestrate both phases; return True when everything requested succeeded."""
        if self.rank == 0:
            print("Starting Kronos model sequential fine-tuning training")
            print(f"Experiment name: {self.config.experiment_name}")
            print(f"Experiment description: {self.config.experiment_description}")

        self._setup_distributed()
        self._create_directories()
        self._check_existing_models()  # status print only; phases re-check themselves

        total_start_time = time.time()
        try:
            if self.config.train_tokenizer:
                if not self.train_tokenizer_phase():
                    if self.rank == 0:
                        print("Tokenizer training failed, terminating training")
                    return False
            elif self.rank == 0:
                print("Skipping Tokenizer training phase")

            if self.config.train_basemodel:
                if not self.train_basemodel_phase():
                    if self.rank == 0:
                        print("Basemodel training failed, terminating training")
                    return False
            elif self.rank == 0:
                print("Skipping Basemodel training phase")

            total_time = time.time() - total_start_time
            if self.rank == 0:
                print("\n" + "=" * 60)
                print("Training completed!")
                print("=" * 60)
                print(f"Total training time: {total_time / 60:.2f} minutes")
                print(f"Tokenizer model: {self.config.tokenizer_best_model_path}")
                print(f"Basemodel model: {self.config.basemodel_best_model_path}")
                print("=" * 60)
            return True
        except Exception as e:
            if self.rank == 0:
                print(f"Error occurred during training: {str(e)}")
                import traceback
                traceback.print_exc()
            return False
def main(): parser = argparse.ArgumentParser(description='Kronos Model Sequential Fine-tuning Training') parser.add_argument('--config', type=str, default='config.yaml', help='Configuration file path (default: config.yaml)') parser.add_argument('--skip-tokenizer', action='store_true', help='Skip tokenizer training phase') parser.add_argument('--skip-basemodel', action='store_true', help='Skip basemodel training phase') parser.add_argument('--skip-existing', action='store_true', help='Skip training for existing models') args = parser.parse_args() trainer = SequentialTrainer(args.config) if args.skip_tokenizer: trainer.config.train_tokenizer = False if args.skip_basemodel: trainer.config.train_basemodel = False if args.skip_existing: trainer.config.skip_existing = True success = trainer.run_training() if success: print("Training completed successfully!") if dist.is_available() and dist.is_initialized(): dist.barrier() dist.destroy_process_group() sys.exit(0) else: print("Training failed!") if dist.is_available() and dist.is_initialized(): try: dist.barrier() dist.destroy_process_group() except Exception: pass sys.exit(1)
if __name__ == "__main__": main()
|