You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
359 lines
13 KiB
359 lines
13 KiB
import os
|
|
import sys
|
|
import json
|
|
import time
|
|
import random
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import DataLoader
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
from time import gmtime, strftime
|
|
import datetime
|
|
import logging
|
|
from logging.handlers import RotatingFileHandler
|
|
import torch.distributed as dist
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
|
sys.path.append("../")
|
|
from model import KronosTokenizer
|
|
from finetune_base_model import CustomKlineDataset
|
|
from config_loader import CustomFinetuneConfig
|
|
|
|
|
|
def set_seed(seed: int, rank: int = 0):
    """Seed Python, NumPy and PyTorch RNGs for reproducible runs.

    Args:
        seed: Seed applied unchanged to ``random``, ``numpy.random`` and
            ``torch`` (CPU and, when available, all CUDA devices).
        rank: Accepted for call-site compatibility but not used; every
            process receives the same seed, with no per-rank offset.
    """
    # ``rank`` is deliberately ignored: all processes share one seed.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    # Prefer reproducible cuDNN kernels over autotuned (faster) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
def get_model_size(model: torch.nn.Module) -> str:
    """Return a human-readable count of the model's trainable parameters.

    Only parameters with ``requires_grad=True`` are counted. The value is
    printed with one decimal place and a ``B`` (>= 1e9), ``M`` (>= 1e6) or
    ``K`` suffix; anything below one million lands in the ``K`` bucket,
    e.g. 110 parameters -> ``"0.1K"``.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()

    for threshold, suffix in ((1e9, "B"), (1e6, "M")):
        if trainable >= threshold:
            return f"{trainable / threshold:.1f}{suffix}"
    return f"{trainable / 1e3:.1f}K"
|
|
|
|
|
|
def format_time(seconds: float) -> str:
    """Render a duration in ``timedelta`` string form, e.g. ``"1:01:01"``.

    Fractional seconds are truncated toward zero first; durations of a day or
    more gain a ``"N day(s), "`` prefix, as ``str(timedelta)`` produces.
    """
    whole_seconds = int(seconds)  # int() truncates toward zero
    elapsed = datetime.timedelta(seconds=whole_seconds)
    return str(elapsed)
|
|
|
|
|
|
def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger:
    """Create (or reuse) the per-rank tokenizer-training logger.

    Every rank writes to ``<log_dir>/tokenizer_training_rank_<rank>.log`` through
    a rotating file handler (10 MiB per file, 5 backups, UTF-8). Rank 0 also
    gets a console ``StreamHandler``. If the logger already has handlers from
    an earlier call in this process, it is returned as-is and no banner is
    written again.

    Args:
        exp_name: Experiment name, recorded in the startup banner.
        log_dir: Directory for the log files; created if missing.
        rank: Process rank, used in the logger name and log file name.

    Returns:
        The configured ``logging.Logger`` at INFO level.
    """
    os.makedirs(log_dir, exist_ok=True)

    name = f"tokenizer_training_rank_{rank}"
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    if logger.handlers:
        # Already configured in this process; don't stack duplicate handlers.
        return logger

    shared_formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )

    handlers = [
        RotatingFileHandler(
            os.path.join(log_dir, f"{name}.log"),
            maxBytes=10 * 1024 * 1024,
            backupCount=5,
            encoding='utf-8',
        )
    ]
    if rank == 0:
        # Only the primary process echoes to the console.
        handlers.append(logging.StreamHandler())

    for handler in handlers:
        handler.setLevel(logging.INFO)
        handler.setFormatter(shared_formatter)
        logger.addHandler(handler)

    banner = (
        "=== Tokenizer Training Started ===",
        f"Experiment Name: {exp_name}",
        f"Log Directory: {log_dir}",
        f"Rank: {rank}",
        f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
    )
    for line in banner:
        logger.info(line)

    return logger
|
|
|
|
|
|
def create_dataloaders(config):
    """Build the train/validation datasets, optional DDP samplers and loaders.

    Reads from ``config``: data_path, lookback_window, predict_window, clip,
    seed, train_ratio, val_ratio, test_ratio, batch_size and num_workers. The
    validation dataset is seeded with ``config.seed + 1``.

    When ``torch.distributed`` is initialized, each split gets a
    ``DistributedSampler`` (train shuffled; val unshuffled and not dropping
    samples), and the loader's own shuffling is turned off. The train loader
    drops its last incomplete batch; the val loader keeps it.

    Returns:
        ``(train_loader, val_loader, train_dataset, val_dataset,
        train_sampler, val_sampler)``; both samplers are ``None`` without DDP.
    """
    use_ddp = dist.is_available() and dist.is_initialized()
    is_main_process = (not use_ddp) or dist.get_rank() == 0

    if is_main_process:
        print("Creating tokenizer training data loaders...")

    def build_dataset(split, split_seed):
        # Both splits share every constructor setting except split and seed.
        return CustomKlineDataset(
            data_path=config.data_path,
            data_type=split,
            lookback_window=config.lookback_window,
            predict_window=config.predict_window,
            clip=config.clip,
            seed=split_seed,
            train_ratio=config.train_ratio,
            val_ratio=config.val_ratio,
            test_ratio=config.test_ratio,
        )

    train_dataset = build_dataset("train", config.seed)
    val_dataset = build_dataset("val", config.seed + 1)

    if use_ddp:
        world, me = dist.get_world_size(), dist.get_rank()
        train_sampler = DistributedSampler(
            train_dataset, num_replicas=world, rank=me, shuffle=True
        )
        val_sampler = DistributedSampler(
            val_dataset, num_replicas=world, rank=me, shuffle=False, drop_last=False
        )
    else:
        train_sampler = val_sampler = None

    loader_kwargs = dict(
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        pin_memory=True,
    )
    # A sampler and shuffle=True are mutually exclusive in DataLoader.
    train_loader = DataLoader(
        train_dataset,
        shuffle=(train_sampler is None),
        drop_last=True,
        sampler=train_sampler,
        **loader_kwargs,
    )
    val_loader = DataLoader(
        val_dataset,
        shuffle=False,
        drop_last=False,
        sampler=val_sampler,
        **loader_kwargs,
    )

    if is_main_process:
        print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")

    return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler
|
|
|
|
|
|
def _evaluate_tokenizer(model, val_loader, device, use_ddp):
    """Return the sample-weighted mean reconstruction MSE over ``val_loader``.

    Only the final reconstruction ``z`` is scored (``z_pre`` and the
    quantizer loss are ignored). Under DDP, per-rank loss sums and sample
    counts are all-reduced, so every rank returns the same global mean.
    Returns 0.0 when no validation samples were seen.
    """
    model.eval()
    # No backward pass follows, so the unwrapped module is evaluated directly.
    net = model.module if use_ddp else model
    loss_sum = 0.0
    sample_count = 0
    with torch.no_grad():
        for ori_batch_x, _ in val_loader:
            # No squeeze(0): the val loader keeps a final partial batch, and a
            # batch of size 1 must keep its leading batch dimension.
            batch_x = ori_batch_x.to(device, non_blocking=True)
            zs, _, _, _ = net(batch_x)
            _, z = zs
            n = batch_x.size(0)
            loss_sum += F.mse_loss(z, batch_x).item() * n
            sample_count += n

    if use_ddp:
        totals = torch.tensor([loss_sum, sample_count], dtype=torch.float64, device=device)
        dist.all_reduce(totals, op=dist.ReduceOp.SUM)
        loss_sum = totals[0].item()
        sample_count = int(totals[1].item())

    return loss_sum / sample_count if sample_count > 0 else 0.0


def train_tokenizer(model, device, config, save_dir, logger):
    """Fine-tune a Kronos tokenizer and keep the checkpoint with the best val loss.

    Args:
        model: Tokenizer module already on ``device``. Its forward is unpacked
            as ``((z_pre, z), bsq_loss, _, _)``. It is wrapped in DDP here
            when ``torch.distributed`` is initialized.
        device: Device the batches are moved to.
        config: Reads tokenizer_learning_rate, adam_weight_decay,
            tokenizer_epochs, log_interval and the optional accumulation_steps
            (default 1), plus everything ``create_dataloaders`` reads.
        save_dir: The best checkpoint is written to ``<save_dir>/best_model``
            via ``save_pretrained`` (rank 0 only).
        logger: Logger for progress messages.

    Returns:
        The lowest average validation loss observed.

    Raises:
        ValueError: If the training loader yields no batches (OneCycleLR needs
            at least one step per epoch).
    """
    import contextlib

    logger.info("Starting tokenizer training...")
    use_ddp = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if use_ddp else 0

    train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config)

    steps_per_epoch = len(train_loader)
    if steps_per_epoch == 0:
        raise ValueError(
            "Training DataLoader yields no batches; with drop_last=True each "
            "process needs at least batch_size training samples."
        )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.tokenizer_learning_rate,
        weight_decay=config.adam_weight_decay,
    )

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.tokenizer_learning_rate,
        steps_per_epoch=steps_per_epoch,
        epochs=config.tokenizer_epochs,
        pct_start=0.03,
        div_factor=10,
    )

    if use_ddp:
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        # NOTE(review): with find_unused_parameters=False, every parameter must
        # take part in the forward pass; confirm for the tokenizer.
        model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)
    raw_model = model.module if use_ddp else model

    best_val_loss = float("inf")
    batch_idx_global = 0
    accumulation_steps = max(1, int(getattr(config, 'accumulation_steps', 1)))
    train_start_time = time.time()

    for epoch in range(config.tokenizer_epochs):
        epoch_start_time = time.time()
        model.train()

        train_dataset.set_epoch_seed(epoch * 10000)
        val_dataset.set_epoch_seed(0)
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        for batch_idx, (ori_batch_x, _) in enumerate(train_loader):
            # No squeeze(0): it would drop the batch dimension for a batch of size 1.
            ori_batch_x = ori_batch_x.to(device, non_blocking=True)

            # tensor_split covers every sample (unlike floor-sized slices) and
            # matches them exactly when the batch divides evenly.
            micro_batches = [
                mb for mb in ori_batch_x.tensor_split(accumulation_steps) if mb.size(0) > 0
            ]
            n_micro = len(micro_batches)

            current_batch_total_loss = 0.0
            for j, batch_x in enumerate(micro_batches):
                is_last_micro = j == n_micro - 1
                # Forward through the DDP wrapper so gradients are all-reduced.
                # Sync only on the final micro-batch to avoid redundant
                # all-reduces while accumulating.
                sync_ctx = (
                    model.no_sync() if (use_ddp and not is_last_micro)
                    else contextlib.nullcontext()
                )
                with sync_ctx:
                    zs, bsq_loss, _, _ = model(batch_x)
                    z_pre, z = zs

                    recon_loss_pre = F.mse_loss(z_pre, batch_x)
                    recon_loss_all = F.mse_loss(z, batch_x)
                    recon_loss = recon_loss_pre + recon_loss_all
                    loss = (recon_loss + bsq_loss) / 2

                    current_batch_total_loss += loss.item()
                    (loss / n_micro).backward()

            torch.nn.utils.clip_grad_norm_(raw_model.parameters(), max_norm=2.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            if (batch_idx_global + 1) % config.log_interval == 0:
                avg_loss = current_batch_total_loss / n_micro
                lr = optimizer.param_groups[0]["lr"]
                log_msg = (f"[Epoch {epoch+1}/{config.tokenizer_epochs}, Step {batch_idx+1}/{steps_per_epoch}] "
                           f"LR: {lr:.6f}, Loss: {avg_loss:.4f}")
                logger.info(log_msg)
                if rank == 0:
                    print(log_msg)

                # Component losses come from the last micro-batch only.
                detail_msg = (f"  - VQ Loss: {bsq_loss.item():.4f}\n"
                              f"  - Recon Loss Pre: {recon_loss_pre.item():.4f}\n"
                              f"  - Recon Loss All: {recon_loss_all.item():.4f}")
                logger.info(detail_msg)
                if rank == 0:
                    print(detail_msg)

            batch_idx_global += 1

        avg_val_loss = _evaluate_tokenizer(model, val_loader, device, use_ddp)

        epoch_time = time.time() - epoch_start_time
        epoch_summary = (f"\n--- Epoch {epoch+1}/{config.tokenizer_epochs} Summary ---\n"
                         f"Validation Loss: {avg_val_loss:.4f}\n"
                         f"Epoch Time: {format_time(epoch_time)}\n"
                         f"Total Training Time: {format_time(time.time() - train_start_time)}\n")
        logger.info(epoch_summary)
        if rank == 0:
            print(epoch_summary)

        # avg_val_loss is identical on all ranks (all-reduced), so every rank
        # tracks the same best value; only rank 0 writes the checkpoint.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            if rank == 0:
                model_save_path = os.path.join(save_dir, "best_model")
                os.makedirs(model_save_path, exist_ok=True)
                raw_model.save_pretrained(model_save_path)
                save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})"
                logger.info(save_msg)
                print(save_msg)

    return best_val_loss
|
|
|
|
|
|
def _build_tokenizer(config, logger):
    """Load the pretrained tokenizer, or randomly initialise one with its architecture.

    If ``config.pre_trained_tokenizer`` is truthy (default True when absent),
    weights are loaded with ``KronosTokenizer.from_pretrained``. Otherwise only
    ``<config.pretrained_tokenizer_path>/config.json`` is read for architecture
    hyper-parameters (missing keys fall back to the defaults below), and a
    freshly initialised tokenizer is returned.
    """
    if getattr(config, 'pre_trained_tokenizer', True):
        logger.info("Loading pretrained tokenizer...")
        print("Loading pretrained tokenizer...")
        return KronosTokenizer.from_pretrained(config.pretrained_tokenizer_path)

    init_msg = "pre_trained_tokenizer=False, randomly initializing Tokenizer architecture"
    logger.info(init_msg)
    print(init_msg)

    cfg_path = os.path.join(config.pretrained_tokenizer_path, 'config.json')
    with open(cfg_path, 'r', encoding='utf-8') as f:
        arch = json.load(f)

    # Fallbacks used when config.json omits a key.
    arch_defaults = {
        'd_in': 6,
        'd_model': 256,
        'n_heads': 4,
        'ff_dim': 512,
        'n_enc_layers': 4,
        'n_dec_layers': 4,
        'ffn_dropout_p': 0.0,
        'attn_dropout_p': 0.0,
        'resid_dropout_p': 0.0,
        's1_bits': 10,
        's2_bits': 10,
        'beta': 0.05,
        'gamma0': 1.0,
        'gamma': 1.1,
        'zeta': 0.05,
        'group_size': 4,
    }
    return KronosTokenizer(**{key: arch.get(key, default) for key, default in arch_defaults.items()})


def _log_training_config(logger, config, device):
    """Write the main training hyper-parameters to ``logger``."""
    logger.info("=== Training Configuration ===")
    logger.info(f"Data path: {config.data_path}")
    logger.info(f"Lookback window: {config.lookback_window}")
    logger.info(f"Predict window: {config.predict_window}")
    logger.info(f"Batch size: {config.batch_size}")
    logger.info(f"Learning rate: {config.tokenizer_learning_rate}")
    logger.info(f"Training epochs: {config.tokenizer_epochs}")
    logger.info(f"Device: {device}")
    # main() never initialises torch.distributed, so this run is single-process.
    logger.info("Distributed training: False")


def main():
    """CLI entry point: fine-tune the Kronos tokenizer in a single process.

    Parses ``--config`` (default ``config.yaml``), prepares the save and log
    directories, seeds the RNGs, builds the tokenizer and runs
    ``train_tokenizer``. The best checkpoint ends up in
    ``<tokenizer_save_path>/best_model``.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Kronos Tokenizer Fine-tuning Training')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Configuration file path (default: config.yaml)')
    args = parser.parse_args()

    config = CustomFinetuneConfig(args.config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    os.makedirs(config.tokenizer_save_path, exist_ok=True)

    log_dir = os.path.join(config.base_save_path, "logs")
    logger = setup_logging(config.exp_name, log_dir, 0)

    # Seed before building the tokenizer so random initialisation is reproducible.
    set_seed(config.seed)

    tokenizer = _build_tokenizer(config, logger).to(device)

    model_size = get_model_size(tokenizer)
    logger.info(f"Tokenizer parameters: {model_size}")
    print(f"Tokenizer parameters: {model_size}")

    _log_training_config(logger, config, device)

    logger.info("Starting tokenizer fine-tuning training...")
    print("Starting tokenizer fine-tuning training...")
    best_val_loss = train_tokenizer(tokenizer, device, config, config.tokenizer_save_path, logger)

    # train_tokenizer writes the checkpoint into the best_model sub-directory.
    best_model_dir = os.path.join(config.tokenizer_save_path, "best_model")
    final_msg = (f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\n"
                 f"Model saved to: {best_model_dir}")
    logger.info(final_msg)
    print(final_msg)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run tokenizer fine-tuning only when executed as a script, not on import.
    main()
|
|
|