import os
import argparse

import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

from src.modeling_openpeer import OpenPeerLLM
from src.configuration_openpeer import OpenPeerConfig
from src.tokenization_openpeer import OpenPeerTokenizer

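# Wraps a list of raw text lines; tokenizes lazily in __getitem__ and builds
# next-token labels so the model can be trained with a causal LM objective.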
class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.texts = texts
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        encoded = self.tokenizer(text,
                                 truncation=True,
                                 max_length=self.max_length)

        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

        # Shift left by one for next-token prediction and append EOS, so the
        # labels align position-for-position with input_ids.
        labels = input_ids[1:] + [self.tokenizer.eos_token_id]

        return {
            "input_ids": torch.tensor(input_ids),
            "attention_mask": torch.tensor(attention_mask),
            "labels": torch.tensor(labels),
        }

def collate_fn(batch):
    input_ids = [item["input_ids"] for item in batch]
    attention_mask = [item["attention_mask"] for item in batch]
    labels = [item["labels"] for item in batch]

    # Pad every sequence to the longest one in the batch. Labels are padded
    # with -100, which cross-entropy ignores; input_ids assume pad id 0.
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
    labels = pad_sequence(labels, batch_first=True, padding_value=-100)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }

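# Core loop: forward pass, backward pass with gradient clipping, per-step LR
# scheduling, a "best loss so far" snapshot, and a full checkpoint per epoch.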
def train(
    model,
    train_dataloader,
    optimizer,
    scheduler,
    num_epochs,
    device,
    save_path,
    start_epoch=0,
):
    model.train()
    best_loss = float("inf")

    for epoch in range(start_epoch, num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        progress_bar = tqdm(train_dataloader, desc="Training")
        epoch_loss = 0.0

        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )

            loss = outputs["loss"]
            epoch_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})

            # Snapshot whenever the per-batch loss hits a new minimum. Note
            # that a single batch's loss is a noisy signal for "best model".
            if loss.item() < best_loss:
                best_loss = loss.item()
                torch.save({
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": best_loss,
                }, os.path.join(save_path, "best_model.pt"))

        avg_epoch_loss = epoch_loss / len(train_dataloader)
        print(f"Epoch {epoch + 1} average loss: {avg_epoch_loss:.4f}")

        # Full checkpoint at the end of every epoch so training can resume.
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": avg_epoch_loss,
        }
        torch.save(checkpoint, os.path.join(save_path, f"checkpoint_epoch_{epoch + 1}.pt"))

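# CLI entry point: builds the model, tokenizer, and dataloader, optionally
# resumes from a checkpoint, then hands everything to train().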
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_data", type=str, required=True, help="Path to training data file")
    parser.add_argument("--save_path", type=str, required=True, help="Directory to save model checkpoints")
    parser.add_argument("--load_checkpoint", type=str, help="Path to model checkpoint to continue training")
    parser.add_argument("--num_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=8, help="Training batch size")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--max_length", type=int, default=1024, help="Maximum sequence length")
    args = parser.parse_args()

    os.makedirs(args.save_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    config = OpenPeerConfig()
    model = OpenPeerLLM(config).to(device)
    tokenizer = OpenPeerTokenizer()

    # Load the checkpoint once; the optimizer state from the same dict is
    # restored further down, after the optimizer has been constructed.
    checkpoint = None
    start_epoch = 0
    if args.load_checkpoint and os.path.exists(args.load_checkpoint):
        print(f"Loading checkpoint: {args.load_checkpoint}")
        checkpoint = torch.load(args.load_checkpoint, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        print(f"Resuming from epoch {start_epoch}")

    print("Loading training data...")
    with open(args.train_data, "r", encoding="utf-8") as f:
        texts = [line.strip() for line in f if line.strip()]

    print("Creating dataset...")
    dataset = TextDataset(texts, tokenizer, max_length=args.max_length)
    train_dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=4,
    )

    optimizer = AdamW(model.parameters(), lr=args.learning_rate)
    scheduler = CosineAnnealingLR(optimizer, T_max=len(train_dataloader) * args.num_epochs)

    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    print("Starting training...")
    train(
        model=model,
        train_dataloader=train_dataloader,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=args.num_epochs,
        device=device,
        save_path=args.save_path,
        start_epoch=start_epoch,
    )

if __name__ == "__main__":
    main()
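# Example invocation (assuming this script is saved as train.py and the
# training data is a UTF-8 text file with one sample per line; the paths
# below are placeholders):
#
#   python train.py --train_data data/train.txt --save_path checkpoints \
#       --num_epochs 3 --batch_size 8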