import os
import yaml
from typing import Dict, Any


class ConfigLoader:
    """Load a YAML configuration file and resolve experiment-dependent paths."""

    def __init__(self, config_path: str):
        self.config_path = config_path
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        """Read the YAML file and fill in dynamic path templates."""
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Config file not found: {self.config_path}")

        with open(self.config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        config = self._resolve_dynamic_paths(config)
        return config

    def _resolve_dynamic_paths(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Derive experiment-specific paths from 'base_path' and 'exp_name'."""
        exp_name = config.get('model_paths', {}).get('exp_name', '')
        if not exp_name:
            return config

        base_path = config.get('model_paths', {}).get('base_path', '')
        path_templates = {
            'base_save_path': f"{base_path}/{exp_name}",
            'finetuned_tokenizer': f"{base_path}/{exp_name}/tokenizer/best_model"
        }

        if 'model_paths' in config:
            for key, template in path_templates.items():
                if key in config['model_paths']:
                    current_value = config['model_paths'][key]
                    if current_value == "" or current_value is None:
                        # Empty value: fall back to the generated template.
                        config['model_paths'][key] = template
                    else:
                        # Non-empty value: substitute the {exp_name} placeholder if present.
                        if isinstance(current_value, str) and '{exp_name}' in current_value:
                            config['model_paths'][key] = current_value.format(exp_name=exp_name)

        return config
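
    # Illustrative sketch of how path resolution behaves. The key names come from
    # the code above; the concrete values are assumptions, not the shipped config.yaml:
    #
    #   model_paths:
    #     exp_name: my_finetune
    #     base_path: ./outputs
    #     base_save_path: ""                             # -> ./outputs/my_finetune
    #     finetuned_tokenizer: "./ckpts/{exp_name}/tok"  # -> ./ckpts/my_finetune/tok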

    def get(self, key: str, default=None):
        """Look up a value by dot-separated key, e.g. 'training.batch_size'."""
        keys = key.split('.')
        value = self.config
        try:
            for k in keys:
                value = value[k]
            return value
        except (KeyError, TypeError):
            return default
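
    # Minimal usage sketch (the file name and key path are assumptions about
    # the accompanying config.yaml):
    #   loader = ConfigLoader('config.yaml')
    #   batch_size = loader.get('training.batch_size', 32)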

    def get_data_config(self) -> Dict[str, Any]:
        return self.config.get('data', {})

    def get_training_config(self) -> Dict[str, Any]:
        return self.config.get('training', {})

    def get_model_paths(self) -> Dict[str, str]:
        return self.config.get('model_paths', {})

    def get_experiment_config(self) -> Dict[str, Any]:
        return self.config.get('experiment', {})

    def get_device_config(self) -> Dict[str, Any]:
        return self.config.get('device', {})

    def get_distributed_config(self) -> Dict[str, Any]:
        return self.config.get('distributed', {})

    def update_config(self, updates: Dict[str, Any]):
        """Recursively merge `updates` into the loaded configuration."""

        def update_nested_dict(d, u):
            for k, v in u.items():
                if isinstance(v, dict):
                    d[k] = update_nested_dict(d.get(k, {}), v)
                else:
                    d[k] = v
            return d

        self.config = update_nested_dict(self.config, updates)

    def save_config(self, save_path: str = None):
        if save_path is None:
            save_path = self.config_path

        with open(save_path, 'w', encoding='utf-8') as f:
            yaml.dump(self.config, f, default_flow_style=False, allow_unicode=True, indent=2)

    def print_config(self):
        print("=" * 50)
        print("Current configuration:")
        print("=" * 50)
        # yaml.dump returns a string when no stream is given; print it explicitly.
        print(yaml.dump(self.config, default_flow_style=False, allow_unicode=True, indent=2))
        print("=" * 50)
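
# A small, hedged usage sketch for ConfigLoader on its own. The file names are
# assumptions; any YAML file with the expected sections would work:
#   loader = ConfigLoader('config.yaml')
#   loader.update_config({'training': {'batch_size': 64}})
#   loader.save_config('config_override.yaml')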


class CustomFinetuneConfig:
    """Flattened view of the YAML configuration for Kronos custom finetuning."""

    def __init__(self, config_path: str = None):
        if config_path is None:
            config_path = os.path.join(os.path.dirname(__file__), 'config.yaml')

        self.loader = ConfigLoader(config_path)
        self._load_all_configs()

    def _load_all_configs(self):
        # data configuration
        data_config = self.loader.get_data_config()
        self.data_path = data_config.get('data_path')
        self.lookback_window = data_config.get('lookback_window', 512)
        self.predict_window = data_config.get('predict_window', 48)
        self.max_context = data_config.get('max_context', 512)
        self.clip = data_config.get('clip', 5.0)
        self.train_ratio = data_config.get('train_ratio', 0.9)
        self.val_ratio = data_config.get('val_ratio', 0.1)
        self.test_ratio = data_config.get('test_ratio', 0.0)

        # training configuration
        training_config = self.loader.get_training_config()
        # tokenizer and basemodel epochs can be configured separately
        self.tokenizer_epochs = training_config.get('tokenizer_epochs', 30)
        self.basemodel_epochs = training_config.get('basemodel_epochs', 30)

        # backward compatibility: a single 'epochs' key applies to both stages
        if 'epochs' in training_config and 'tokenizer_epochs' not in training_config:
            self.tokenizer_epochs = training_config.get('epochs', 30)
        if 'epochs' in training_config and 'basemodel_epochs' not in training_config:
            self.basemodel_epochs = training_config.get('epochs', 30)

        self.batch_size = training_config.get('batch_size', 160)
        self.log_interval = training_config.get('log_interval', 50)
        self.num_workers = training_config.get('num_workers', 6)
        self.seed = training_config.get('seed', 100)
        self.tokenizer_learning_rate = training_config.get('tokenizer_learning_rate', 2e-4)
        self.predictor_learning_rate = training_config.get('predictor_learning_rate', 4e-5)
        self.adam_beta1 = training_config.get('adam_beta1', 0.9)
        self.adam_beta2 = training_config.get('adam_beta2', 0.95)
        self.adam_weight_decay = training_config.get('adam_weight_decay', 0.1)
        self.accumulation_steps = training_config.get('accumulation_steps', 1)

        # model paths
        model_paths = self.loader.get_model_paths()
        self.exp_name = model_paths.get('exp_name', 'default_experiment')
        self.pretrained_tokenizer_path = model_paths.get('pretrained_tokenizer')
        self.pretrained_predictor_path = model_paths.get('pretrained_predictor')
        self.base_save_path = model_paths.get('base_save_path')
        self.tokenizer_save_name = model_paths.get('tokenizer_save_name', 'tokenizer')
        self.basemodel_save_name = model_paths.get('basemodel_save_name', 'basemodel')
        self.finetuned_tokenizer_path = model_paths.get('finetuned_tokenizer')

        # experiment configuration
        experiment_config = self.loader.get_experiment_config()
        self.experiment_name = experiment_config.get('name', 'kronos_custom_finetune')
        self.experiment_description = experiment_config.get('description', '')
        self.use_comet = experiment_config.get('use_comet', False)
        self.train_tokenizer = experiment_config.get('train_tokenizer', True)
        self.train_basemodel = experiment_config.get('train_basemodel', True)
        self.skip_existing = experiment_config.get('skip_existing', False)

        # a single 'pre_trained' switch sets both flags unless they are given explicitly
        unified_pretrained = experiment_config.get('pre_trained', None)
        self.pre_trained_tokenizer = experiment_config.get(
            'pre_trained_tokenizer', unified_pretrained if unified_pretrained is not None else True)
        self.pre_trained_predictor = experiment_config.get(
            'pre_trained_predictor', unified_pretrained if unified_pretrained is not None else True)

        # device configuration
        device_config = self.loader.get_device_config()
        self.use_cuda = device_config.get('use_cuda', True)
        self.device_id = device_config.get('device_id', 0)

        # distributed configuration
        distributed_config = self.loader.get_distributed_config()
        self.use_ddp = distributed_config.get('use_ddp', False)
        self.ddp_backend = distributed_config.get('backend', 'nccl')

        self._compute_full_paths()

    def _compute_full_paths(self):
        self.tokenizer_save_path = os.path.join(self.base_save_path, self.tokenizer_save_name)
        self.tokenizer_best_model_path = os.path.join(self.tokenizer_save_path, 'best_model')

        self.basemodel_save_path = os.path.join(self.base_save_path, self.basemodel_save_name)
        self.basemodel_best_model_path = os.path.join(self.basemodel_save_path, 'best_model')

    def get_tokenizer_config(self):
        """Collect the settings used for tokenizer finetuning."""
        return {
            'data_path': self.data_path,
            'lookback_window': self.lookback_window,
            'predict_window': self.predict_window,
            'max_context': self.max_context,
            'clip': self.clip,
            'train_ratio': self.train_ratio,
            'val_ratio': self.val_ratio,
            'test_ratio': self.test_ratio,
            'epochs': self.tokenizer_epochs,
            'batch_size': self.batch_size,
            'log_interval': self.log_interval,
            'num_workers': self.num_workers,
            'seed': self.seed,
            'learning_rate': self.tokenizer_learning_rate,
            'adam_beta1': self.adam_beta1,
            'adam_beta2': self.adam_beta2,
            'adam_weight_decay': self.adam_weight_decay,
            'accumulation_steps': self.accumulation_steps,
            'pretrained_model_path': self.pretrained_tokenizer_path,
            'save_path': self.tokenizer_save_path,
            'use_comet': self.use_comet
        }

    def get_basemodel_config(self):
        """Collect the settings used for basemodel (predictor) finetuning."""
        return {
            'data_path': self.data_path,
            'lookback_window': self.lookback_window,
            'predict_window': self.predict_window,
            'max_context': self.max_context,
            'clip': self.clip,
            'train_ratio': self.train_ratio,
            'val_ratio': self.val_ratio,
            'test_ratio': self.test_ratio,
            'epochs': self.basemodel_epochs,
            'batch_size': self.batch_size,
            'log_interval': self.log_interval,
            'num_workers': self.num_workers,
            'seed': self.seed,
            'predictor_learning_rate': self.predictor_learning_rate,
            'tokenizer_learning_rate': self.tokenizer_learning_rate,
            'adam_beta1': self.adam_beta1,
            'adam_beta2': self.adam_beta2,
            'adam_weight_decay': self.adam_weight_decay,
            'pretrained_tokenizer_path': self.finetuned_tokenizer_path,
            'pretrained_predictor_path': self.pretrained_predictor_path,
            'save_path': self.basemodel_save_path,
            'use_comet': self.use_comet
        }

    def print_config_summary(self):
        print("=" * 60)
        print("Kronos finetuning configuration summary")
        print("=" * 60)
        print(f"Experiment name: {self.exp_name}")
        print(f"Data path: {self.data_path}")
        print(f"Lookback window: {self.lookback_window}")
        print(f"Predict window: {self.predict_window}")
        print(f"Tokenizer training epochs: {self.tokenizer_epochs}")
        print(f"Basemodel training epochs: {self.basemodel_epochs}")
        print(f"Batch size: {self.batch_size}")
        print(f"Tokenizer learning rate: {self.tokenizer_learning_rate}")
        print(f"Predictor learning rate: {self.predictor_learning_rate}")
        print(f"Train tokenizer: {self.train_tokenizer}")
        print(f"Train basemodel: {self.train_basemodel}")
        print(f"Skip existing: {self.skip_existing}")
        print(f"Use pre-trained tokenizer: {self.pre_trained_tokenizer}")
        print(f"Use pre-trained predictor: {self.pre_trained_predictor}")
        print(f"Base save path: {self.base_save_path}")
        print(f"Tokenizer save path: {self.tokenizer_save_path}")
        print(f"Basemodel save path: {self.basemodel_save_path}")
        print("=" * 60)
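

# A minimal, hedged usage sketch: load the default config.yaml next to this file
# and print the flattened summary. The presence and contents of config.yaml are
# assumptions; point CustomFinetuneConfig at your own experiment configuration.
if __name__ == '__main__':
    cfg = CustomFinetuneConfig()
    cfg.print_config_summary()
    # The per-stage dicts below would typically be handed to the training scripts.
    tokenizer_cfg = cfg.get_tokenizer_config()
    basemodel_cfg = cfg.get_basemodel_config()
    print(f"Tokenizer epochs: {tokenizer_cfg['epochs']}, basemodel epochs: {basemodel_cfg['epochs']}")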