| | |
| | """ |
| | Prototype LM for geometric simplex structures. |
| | |
| | Requires the geometricvocab's SimplexFactory for valid simplex representations, or the simplex behavior will not learn. |
| | |
| | try: |
| | !pip uninstall -qy geometricvocab |
| | except: |
| | pass |
| | |
| | !pip install -q git+https://github.com/AbstractEyes/lattice_vocabulary.git |
| | |
| | License: MIT |
| | """ |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.data import DataLoader, Dataset |
| | from torch.utils.tensorboard import SummaryWriter |
| | import math |
| | from itertools import combinations |
| | import time |
| | import os |
| | import json |
| | from tqdm.auto import tqdm |
| | from pathlib import Path |
| |
|
# Select the compute device once at import time; used throughout the script.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")
| |
|
| | from geovocab2.shapes.factory.simplex_factory import SimplexFactory |
| | from huggingface_hub import HfApi, create_repo, upload_folder |
| | import tiktoken |
| |
|
| | |
| | |
| | |
| |
|
| | HF_REPO = "AbstractPhil/ksimplex-llm-prototype" |
| | RUN_NAME = f"run_{int(time.time())}" |
| | CHECKPOINT_DIR = Path(f"./checkpoints/{RUN_NAME}") |
| | TENSORBOARD_DIR = Path(f"./runs/{RUN_NAME}") |
| |
|
| | CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True) |
| | TENSORBOARD_DIR.mkdir(parents=True, exist_ok=True) |
| |
|
| | |
| | |
| | |
| |
|
class CMValidator(nn.Module):
    """Pairwise squared distances and Cayley-Menger squared volume for a
    batch of k-simplex vertex sets.

    forward() takes verts of shape (..., k+1, edim) and returns
    (d2_pairs, vol2): the upper-triangular pairwise squared distances of
    shape (..., npairs) and the signed squared k-volume of shape (...).
    A negative vol2 indicates a numerically degenerate simplex.
    """

    def __init__(self, k):
        super().__init__()
        self._k = k
        self._nv = k + 1

        # Enumerate the upper-triangular vertex pairs (same order as
        # itertools.combinations over range(k+1)).
        idx_i, idx_j = [], []
        for a in range(self._nv):
            for b in range(a + 1, self._nv):
                idx_i.append(a)
                idx_j.append(b)
        self._npairs = len(idx_i)
        self.register_buffer('_pi', torch.as_tensor(idx_i, dtype=torch.long))
        self.register_buffer('_pj', torch.as_tensor(idx_j, dtype=torch.long))

        # Cayley-Menger prefactor: (-1)^(k+1) / (2^k * (k!)^2).
        self._prefactor = ((-1.0) ** (k + 1)) / ((2.0 ** k) * math.factorial(k) ** 2)

    def forward(self, verts):
        """Compute (pairwise squared distances, squared volume) for verts."""
        gram = torch.einsum('...ve,...we->...vw', verts, verts)
        sq_norms = torch.diagonal(gram, dim1=-2, dim2=-1)
        # ||vi - vj||^2 = ||vi||^2 + ||vj||^2 - 2 <vi, vj>; clamp tiny
        # negative values produced by floating-point cancellation.
        d2_mat = (sq_norms.unsqueeze(-1) + sq_norms.unsqueeze(-2) - 2 * gram).clamp_min(0.0)

        d2_pairs = d2_mat[..., self._pi, self._pj]

        # Assemble the bordered Cayley-Menger matrix:
        # [[0, 1...1], [1...1 | D]] where D is the squared-distance matrix.
        batch_shape, nv = d2_mat.shape[:-2], d2_mat.shape[-1]
        cm = d2_mat.new_zeros(*batch_shape, nv + 1, nv + 1)
        cm[..., 0, 1:] = 1.0
        cm[..., 1:, 0] = 1.0
        cm[..., 1:, 1:] = d2_mat

        vol2 = self._prefactor * torch.linalg.det(cm)
        return d2_pairs, vol2
| |
|
| |
|
| | |
| | |
| | |
| |
|
class KSimplexChannel(nn.Module):
    """Encodes an input vector as a deformed k-simplex plus per-vertex features.

    The input is mapped to (a) small coordinate deformations around a regular
    simplex template and (b) per-vertex feature vectors. The simplex geometry
    (pairwise squared distances + CM squared volume) gates the mean-pooled
    vertex features, and degenerate (negative-volume) simplices are suppressed.
    Output is the concatenation [gated features | geometry descriptor].
    """

    BASE_DEFORM = 0.05  # scale of learned deformation around the template

    def __init__(self, k, in_dim, edim, feat_dim):
        super().__init__()
        self._k = k
        self._nv = k + 1
        self._edim = edim
        self._feat_dim = feat_dim

        self._cm = CMValidator(k)
        # One slot per vertex pair plus one for the squared volume.
        self._geo_dim = self._cm._npairs + 1

        template = SimplexFactory(k=k, embed_dim=edim, method="regular", scale=1.0).build_torch(dtype=torch.float32)
        self.register_buffer('_template', template)

        self._to_coords = nn.Linear(in_dim, self._nv * edim)
        self._to_feats = nn.Linear(in_dim, self._nv * feat_dim)

        self._geo_gate = nn.Sequential(nn.Linear(self._geo_dim, feat_dim), nn.Sigmoid())

        self._out_dim = feat_dim + self._geo_dim

    @property
    def out_dim(self):
        """Width of the concatenated [features | geometry] output vector."""
        return self._out_dim

    def forward(self, x):
        """x: (..., in_dim) -> (out, vol2, mean pairwise squared distance)."""
        deltas = self._to_coords(x).unflatten(-1, (self._nv, self._edim))
        verts = self._template + self.BASE_DEFORM * deltas

        feats = self._to_feats(x).unflatten(-1, (self._nv, self._feat_dim))

        d2, vol2 = self._cm(verts)
        geo = torch.cat([d2, vol2.unsqueeze(-1)], dim=-1)

        # Very steep sigmoid acts as a soft step: ~0 when vol2 < 0
        # (degenerate simplex), ~1 when vol2 > 0.
        validity = torch.sigmoid(vol2 * 1e6).unsqueeze(-1)
        gate = self._geo_gate(geo)

        pooled = feats.mean(dim=-2) * gate * validity
        out = torch.cat([pooled, geo], dim=-1)
        return out, vol2, d2.mean(dim=-1)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class TokenToKChannels(nn.Module):
    """Projects token embeddings into `depth` parallel k-simplex channels
    (k = 1..depth), zero-padding every channel output to a common width so
    the channels can be stacked into one tensor.
    """

    def __init__(self, embed_dim, depth, edim, feat_dim, hidden=256):
        super().__init__()
        self._depth = depth

        # Shared two-layer MLP trunk feeding all k-channels.
        self._proj = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
        )

        # One simplex encoder per k in 1..depth.
        self._k_encoders = nn.ModuleList(
            KSimplexChannel(k=i, in_dim=hidden, edim=edim, feat_dim=feat_dim)
            for i in range(1, depth + 1)
        )

        self._k_out_dims = [enc.out_dim for enc in self._k_encoders]
        self._max_out_dim = max(self._k_out_dims)

    def forward(self, x):
        """x: (..., embed_dim) -> (stacked channels, vol2 per k, mean d2 per k)."""
        h = self._proj(x)

        outs, vols, dists = [], [], []
        for enc in self._k_encoders:
            out, vol2, d2_mean = enc(h)

            # Right-pad narrower channel outputs up to the widest channel.
            deficit = self._max_out_dim - out.shape[-1]
            if deficit > 0:
                out = F.pad(out, (0, deficit))

            outs.append(out)
            vols.append(vol2)
            dists.append(d2_mean)

        return (
            torch.stack(outs, dim=-2),   # (..., depth, max_out_dim)
            torch.stack(vols, dim=-1),   # (..., depth)
            torch.stack(dists, dim=-1),  # (..., depth)
        )
| |
|
| |
|
| | |
| | |
| | |
| |
|
class KChannelCrossAttention(nn.Module):
    """Multi-head self-attention across the K simplex channels at each
    (batch, time) position, with a residual connection.

    Args:
        depth: number of k-channels (stored for configuration symmetry; the
            channel count is read from the input at forward time).
        feat_dim: per-channel feature width; must be divisible by num_heads.
        num_heads: attention heads.
        dropout: applied to attention weights and the output projection.
    """

    def __init__(self, depth, feat_dim, num_heads=4, dropout=0.1):
        super().__init__()
        self._depth = depth
        self._feat_dim = feat_dim
        self._num_heads = num_heads
        self._head_dim = feat_dim // num_heads

        self._norm_q = nn.LayerNorm(feat_dim)
        self._norm_kv = nn.LayerNorm(feat_dim)

        self._to_q = nn.Linear(feat_dim, feat_dim)
        self._to_k = nn.Linear(feat_dim, feat_dim)
        self._to_v = nn.Linear(feat_dim, feat_dim)
        self._out = nn.Linear(feat_dim, feat_dim)
        self._drop = nn.Dropout(dropout)

        self._scale = self._head_dim ** -0.5

    def forward(self, x):
        """x: (B, T, K, C) -> same shape; attention mixes over K only."""
        # Fix: the last dim was previously unpacked as `F`, shadowing
        # torch.nn.functional (imported as F at module level).
        B, T, K, C = x.shape

        # Fold batch and time together so attention runs per position.
        x_flat = x.view(B * T, K, C)

        q = self._to_q(self._norm_q(x_flat))
        k = self._to_k(self._norm_kv(x_flat))
        v = self._to_v(self._norm_kv(x_flat))

        # Split heads: (B*T, K, C) -> (B*T, heads, K, head_dim).
        q = q.view(-1, K, self._num_heads, self._head_dim).transpose(1, 2)
        k = k.view(-1, K, self._num_heads, self._head_dim).transpose(1, 2)
        v = v.view(-1, K, self._num_heads, self._head_dim).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self._scale
        attn = attn.softmax(dim=-1)
        attn = self._drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B * T, K, C)
        out = self._out(out)
        out = self._drop(out)

        # Residual connection.
        return x + out.view(B, T, K, C)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class CausalSequenceAttention(nn.Module):
    """Causal multi-head self-attention over the time axis, applied to the
    per-position state with all K channels flattened into one vector
    (depth * feat_dim), with a residual connection.

    Args:
        depth, feat_dim: depth * feat_dim must be divisible by num_heads.
        max_seq_len: length of the precomputed lower-triangular mask.
    """

    def __init__(self, depth, feat_dim, num_heads=4, dropout=0.1, max_seq_len=2048):
        super().__init__()
        self._num_heads = num_heads

        total_dim = depth * feat_dim
        self._head_dim = total_dim // num_heads

        self._norm = nn.LayerNorm(total_dim)
        self._to_qkv = nn.Linear(total_dim, 3 * total_dim)
        self._out = nn.Linear(total_dim, total_dim)
        self._drop = nn.Dropout(dropout)

        self._scale = self._head_dim ** -0.5

        # Lower-triangular (causal) mask precomputed up to max_seq_len.
        self.register_buffer(
            '_causal_mask',
            torch.tril(torch.ones(max_seq_len, max_seq_len)).bool()
        )

    def forward(self, x):
        """x: (B, T, K, C) -> same shape; attention mixes over T only."""
        # Fix: the last dim was previously unpacked as `F`, shadowing
        # torch.nn.functional (imported as F at module level).
        B, T, K, C = x.shape

        x_flat = x.view(B, T, K * C)
        x_norm = self._norm(x_flat)

        qkv = self._to_qkv(x_norm).chunk(3, dim=-1)
        q, k, v = [t.view(B, T, self._num_heads, self._head_dim).transpose(1, 2) for t in qkv]

        attn = (q @ k.transpose(-2, -1)) * self._scale

        # Forbid attention to future positions.
        mask = self._causal_mask[:T, :T]
        attn = attn.masked_fill(~mask, float('-inf'))
        attn = attn.softmax(dim=-1)
        attn = self._drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, T, K * C)
        out = self._out(out)
        out = self._drop(out)

        # Residual connection.
        return x + out.view(B, T, K, C)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class GeoBlock(nn.Module):
    """One transformer-style layer: cross-channel attention, then causal
    sequence attention, then a residual MLP over the flattened
    (depth * feat_dim) per-position state.
    """

    def __init__(self, depth, feat_dim, num_heads, mlp_ratio=4.0, dropout=0.1, max_seq_len=2048):
        super().__init__()

        self._k_attn = KChannelCrossAttention(depth, feat_dim, num_heads, dropout)
        self._seq_attn = CausalSequenceAttention(depth, feat_dim, num_heads, dropout, max_seq_len)

        total_dim = depth * feat_dim
        self._norm = nn.LayerNorm(total_dim)
        self._mlp = nn.Sequential(
            nn.Linear(total_dim, int(total_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(total_dim * mlp_ratio), total_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """x: (B, T, K, C) -> same shape."""
        # Fix: the last dim was previously unpacked as `F`, shadowing
        # torch.nn.functional (imported as F at module level).
        B, T, K, C = x.shape

        x = self._k_attn(x)
        x = self._seq_attn(x)

        # Residual MLP over the channels flattened into one vector.
        x_flat = x.view(B, T, K * C)
        x_flat = x_flat + self._mlp(self._norm(x_flat))
        return x_flat.view(B, T, K, C)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class GeometricLM(nn.Module):
    """Causal language model whose token states live in `depth` parallel
    k-simplex channels mixed by GeoBlocks.

    Pipeline: token + position embeddings -> TokenToKChannels ->
    linear projection to a uniform per-channel width -> num_blocks x GeoBlock
    -> flatten channels -> LayerNorm -> tied-free LM head.
    """

    def __init__(
        self,
        vocab_size,
        max_seq_len=512,
        embed_dim=256,
        depth=4,
        edim=16,
        feat_dim=64,
        hidden=256,
        num_heads=8,
        num_blocks=8,
        dropout=0.1,
    ):
        super().__init__()

        self._vocab_size = vocab_size
        self._max_seq_len = max_seq_len
        self._depth = depth
        self._feat_dim = feat_dim

        # Learned token and absolute-position embeddings.
        self._tok_embed = nn.Embedding(vocab_size, embed_dim)
        self._pos_embed = nn.Embedding(max_seq_len, embed_dim)

        self._tok_to_k = TokenToKChannels(embed_dim, depth, edim, feat_dim, hidden)
        self._max_out_dim = self._tok_to_k._max_out_dim

        # Channel outputs are zero-padded to _max_out_dim; project down to
        # a uniform feat_dim before the attention blocks.
        self._proj = nn.Linear(self._max_out_dim, feat_dim)

        self._blocks = nn.ModuleList([
            GeoBlock(depth, feat_dim, num_heads, dropout=dropout, max_seq_len=max_seq_len)
            for _ in range(num_blocks)
        ])

        total_dim = depth * feat_dim
        self._norm = nn.LayerNorm(total_dim)
        self._lm_head = nn.Linear(total_dim, vocab_size, bias=False)

        # Snapshot of constructor arguments for checkpointing/logging.
        self._config = {
            'vocab_size': vocab_size,
            'max_seq_len': max_seq_len,
            'embed_dim': embed_dim,
            'depth': depth,
            'edim': edim,
            'feat_dim': feat_dim,
            'hidden': hidden,
            'num_heads': num_heads,
            'num_blocks': num_blocks,
            'dropout': dropout,
            'total_dim': total_dim,
        }

    def forward(self, tokens):
        """tokens: (B, T) int64 -> (logits (B, T, vocab_size), info dict
        with per-channel 'vol2' and 'd2_mean' geometry diagnostics)."""
        B, T = tokens.shape

        pos = torch.arange(T, device=tokens.device)
        x = self._tok_embed(tokens) + self._pos_embed(pos)

        k_channels, vol2, d2_mean = self._tok_to_k(x)
        k_channels = self._proj(k_channels)

        for blk in self._blocks:
            k_channels = blk(k_channels)

        # (B, T, K, feat_dim) -> (B, T, K * feat_dim) for the LM head.
        out = k_channels.flatten(-2)
        logits = self._lm_head(self._norm(out))

        return logits, {'vol2': vol2, 'd2_mean': d2_mean}

    @torch.no_grad()
    def generate(self, prompt_tokens, max_new_tokens=100, temperature=1.0, top_k=50):
        """Autoregressive sampling with temperature and top-k filtering.

        NOTE(review): switches the model to eval() and does not restore the
        previous mode -- callers that resume training must call .train()
        afterwards (generate_samples does this explicitly).
        """
        self.eval()
        tokens = prompt_tokens.clone()

        for _ in range(max_new_tokens):
            # Crop context to the positional-embedding capacity.
            ctx = tokens[:, -self._max_seq_len:]
            logits, _ = self(ctx)
            logits = logits[:, -1, :] / temperature

            if top_k > 0:
                # Mask every logit below the k-th largest.
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat([tokens, next_tok], dim=1)

        return tokens
| |
|
| |
|
| | |
| | |
| | |
| |
|
class TokenizedDataset(Dataset):
    """Strided sliding-window next-token dataset over a flat token sequence.

    Each item is (x, y) where y is x shifted one position ahead, both of
    length seq_len. Windows start every `stride` tokens (default: 50%
    overlap, seq_len // 2).
    """

    def __init__(self, tokens, seq_len, stride=None):
        self._tokens = tokens
        self._seq_len = seq_len
        # Default stride gives 50% window overlap. NOTE: stride=0 also
        # falls back to the default (falsy check), which avoids a zero step.
        self._stride = stride if stride else seq_len // 2

    def __len__(self):
        # A window starting at s needs seq_len + 1 tokens, so valid starts
        # satisfy s <= len(tokens) - seq_len - 1. Fix: the previous formula
        # `(len - seq_len - 1) // stride` dropped the final valid window
        # (and returned 0 for a sequence of exactly seq_len + 1 tokens).
        last_start = len(self._tokens) - self._seq_len - 1
        if last_start < 0:
            return 0
        return last_start // self._stride + 1

    def __getitem__(self, idx):
        start = idx * self._stride
        chunk = self._tokens[start:start + self._seq_len + 1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y
| |
|
| |
|
| | |
| | |
| | |
| |
|
def lm_loss(logits, targets, info, ce_weight=1.0, validity_weight=0.1):
    """Cross-entropy LM loss plus a hinge penalty on negative simplex volumes.

    Returns (total, ce, validity) with
    total = ce_weight * ce + validity_weight * validity.
    """
    vocab = logits.shape[-1]
    ce = F.cross_entropy(logits.reshape(-1, vocab), targets.reshape(-1))
    # Penalize only negative squared volumes (degenerate simplices).
    validity = (-info['vol2']).clamp_min(0.0).mean()
    total = ce_weight * ce + validity_weight * validity
    return total, ce, validity
| |
|
| |
|
@torch.no_grad()
def compute_metrics(info, depth):
    """Summarize per-k geometry stats: validity rate, mean vol^2, mean d^2.

    Keys follow the k{n}_* naming used by the TensorBoard logging in train().
    """
    vol2 = info['vol2']
    d2_mean = info['d2_mean']

    metrics = {'valid_rate': (vol2 > 0).float().mean().item()}
    for idx in range(1, depth + 1):
        vol_slice = vol2[..., idx - 1]
        d2_slice = d2_mean[..., idx - 1]
        metrics[f'k{idx}_valid'] = (vol_slice > 0).float().mean().item()
        metrics[f'k{idx}_vol2'] = vol_slice.mean().item()
        metrics[f'k{idx}_d2'] = d2_slice.mean().item()
    return metrics
| |
|
| |
|
| | |
| | |
| | |
| |
|
@torch.no_grad()
def sanity_check(model, enc, device):
    """Verify no information leak.

    Runs three pre-training checks: (1) an untrained model on random input
    should score near-uniform cross-entropy; (2) the causal mask must stop
    future tokens from changing earlier logits; (3) the dataset must pair
    inputs with targets shifted by exactly one token. Prints a report,
    returns True iff all three pass, and leaves the model in train() mode.
    """
    print("\n" + "=" * 60)
    print("SANITY CHECK")
    print("=" * 60)

    model.eval()

    # Test 1: with random weights and random targets, CE should sit near
    # log(vocab); a much lower value would indicate target leakage.
    random_tokens = torch.randint(0, 1000, (4, 256), device=device)
    logits, _ = model(random_tokens)
    random_targets = torch.randint(0, enc.n_vocab, (4, 256), device=device)
    ce = F.cross_entropy(logits.view(-1, enc.n_vocab), random_targets.view(-1))

    expected_ce = math.log(enc.n_vocab)
    print(f"Test 1 - Random input:")
    print(f" CE: {ce.item():.2f} (expected ~{expected_ce:.2f})")
    print(f" PPL: {math.exp(min(ce.item(), 20)):.0f} (expected ~{enc.n_vocab})")

    # Threshold 8.0 leaves slack below log(50257) ~= 10.8.
    test1_pass = ce.item() > 8.0
    print(f" Status: {'✓ PASS' if test1_pass else '✗ FAIL'}")

    # Test 2: perturb only the second half of the sequence; logits for the
    # first half must be unchanged under a correct causal mask.
    tokens1 = torch.zeros(1, 256, dtype=torch.long, device=device)
    tokens2 = torch.zeros(1, 256, dtype=torch.long, device=device)
    tokens2[0, 128:] = 999

    logits1, _ = model(tokens1)
    logits2, _ = model(tokens2)

    diff_early = (logits1[0, :128] - logits2[0, :128]).abs().max().item()
    diff_late = (logits1[0, 128:] - logits2[0, 128:]).abs().max().item()

    print(f"\nTest 2 - Causal mask:")
    print(f" Early positions diff: {diff_early:.6f} (should be ~0)")
    print(f" Late positions diff: {diff_late:.6f} (should be >0)")

    test2_pass = diff_early < 1e-5 and diff_late > 1e-3
    print(f" Status: {'✓ PASS' if test2_pass else '✗ FAIL'}")

    # Test 3: on a ramp input 0..99 the dataset must return y[i] == x[i] + 1,
    # i.e. next-token targets offset by exactly one position.
    print(f"\nTest 3 - Dataset offset:")
    test_tokens = list(range(100))
    ds = TokenizedDataset(test_tokens, seq_len=10)
    x, y = ds[0]
    offset_correct = all(x[i] + 1 == y[i] for i in range(len(x)))
    print(f" x: {x[:5].tolist()}...")
    print(f" y: {y[:5].tolist()}...")
    print(f" Offset correct: {'✓ PASS' if offset_correct else '✗ FAIL'}")

    print("=" * 60)

    all_pass = test1_pass and test2_pass and offset_correct
    if not all_pass:
        print("⚠️ WARNING: Some sanity checks failed!")
    else:
        print("✓ All sanity checks passed!")

    print("=" * 60 + "\n")

    # Restore training mode before returning to the caller.
    model.train()
    return all_pass
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Fixed prompts used for periodic generation sampling during training
# (Shakespeare-flavored to match the tiny-shakespeare training corpus).
PROMPTS = [
    "ROMEO: ",
    "JULIET: ",
    "To be or not to be",
    "The king ",
    "Once upon a time",
    "First Citizen:\n",
    "What light through yonder",
    "Friends, Romans, countrymen",
    "Now is the winter of",
    "All the world's a stage",
]
| |
|
@torch.no_grad()
def generate_samples(model, enc, device, epoch, writer=None):
    """Generate samples from all prompts.

    Samples 100 tokens per PROMPTS entry (temperature 0.8, top-k 50),
    prints a truncated preview, optionally logs the samples to TensorBoard,
    and returns a list of {'prompt', 'generated'} dicts. Leaves the model
    in train() mode.
    """
    model.eval()

    samples = []
    print(f"\n{'='*60}")
    print(f"GENERATION SAMPLES - Epoch {epoch}")
    print(f"{'='*60}")

    for i, prompt in enumerate(PROMPTS):
        prompt_tokens = torch.tensor([enc.encode(prompt)], device=device)

        out_tokens = model.generate(
            prompt_tokens,
            max_new_tokens=100,
            temperature=0.8,
            top_k=50
        )

        # generate() returns prompt + continuation, so the decode includes
        # the prompt text.
        generated = enc.decode(out_tokens[0].tolist())
        samples.append({'prompt': prompt, 'generated': generated})

        print(f"\n--- Prompt {i+1}: '{prompt.strip()}' ---")
        print(generated[:300])
        if len(generated) > 300:
            print("...")

    print(f"{'='*60}\n")

    # Mirror the samples into TensorBoard as a single markdown text entry.
    if writer:
        sample_text = "\n\n".join([
            f"**Prompt:** {s['prompt']}\n**Generated:**\n{s['generated'][:500]}"
            for s in samples
        ])
        writer.add_text("samples/generated", sample_text, epoch)

    # Restore training mode before returning to the caller.
    model.train()
    return samples
| |
|
| |
|
| | |
| | |
| | |
| |
|
def save_checkpoint(model, optimizer, scheduler, epoch, config, metrics, checkpoint_dir):
    """Persist model/optimizer/scheduler state plus config and metrics.

    Writes an epoch-numbered checkpoint, refreshes checkpoint_latest.pt,
    mirrors the config to config.json, and returns the epoch-numbered path.
    """
    # torch.compile wraps the real module and exposes it as _orig_mod;
    # save the unwrapped state dict so checkpoints load into plain models.
    raw_model = getattr(model, '_orig_mod', model)

    payload = {
        'epoch': epoch,
        'model_state_dict': raw_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'config': config,
        'metrics': metrics,
    }

    epoch_path = checkpoint_dir / f"checkpoint_epoch_{epoch:03d}.pt"
    # Same payload twice: the numbered snapshot and a rolling "latest".
    for target in (epoch_path, checkpoint_dir / "checkpoint_latest.pt"):
        torch.save(payload, target)

    (checkpoint_dir / "config.json").write_text(json.dumps(config, indent=2))

    print(f"Saved checkpoint: {epoch_path}")
    return epoch_path
| |
|
| |
|
def upload_to_hf(checkpoint_dir, repo_id, epoch):
    """Upload checkpoint directory to HuggingFace.

    Best-effort: returns True on success, False on any failure. Broad
    exception handling is deliberate -- a failed upload must not crash the
    training run that produced the checkpoint.
    """
    try:
        api = HfApi()

        # exist_ok=True makes re-creation a no-op; the inner try still
        # guards against permission/network errors during creation.
        try:
            create_repo(repo_id, exist_ok=True, repo_type="model")
        except Exception as e:
            print(f"Repo creation note: {e}")

        # Push the whole run directory (checkpoints, config, samples).
        api.upload_folder(
            folder_path=str(checkpoint_dir),
            repo_id=repo_id,
            commit_message=f"Epoch {epoch} checkpoint",
        )

        print(f"Uploaded to HuggingFace: {repo_id}")
        return True
    except Exception as e:
        print(f"HuggingFace upload failed: {e}")
        return False
| |
|
| |
|
| | |
| | |
| | |
| |
|
def train():
    """End-to-end training driver.

    Downloads tiny-shakespeare on first run, tokenizes with the GPT-2 BPE,
    builds a GeometricLM, trains with AdamW + cosine LR schedule, logs to
    TensorBoard, checkpoints locally, generates periodic samples, and
    uploads the final checkpoint to HuggingFace.

    Returns:
        (model, enc): the trained model and the tiktoken encoder.
    """
    import urllib.request

    # --- Logging / output locations ------------------------------------
    writer = SummaryWriter(log_dir=str(TENSORBOARD_DIR))
    print(f"TensorBoard logs: {TENSORBOARD_DIR}")
    print(f"Checkpoints: {CHECKPOINT_DIR}")
    print(f"HuggingFace repo: {HF_REPO}")

    # --- Data: tiny-shakespeare, downloaded once and cached locally -----
    data_path = './data/shakespeare.txt'
    if not os.path.exists(data_path):
        os.makedirs('./data', exist_ok=True)
        url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
        print("Downloading Shakespeare...")
        urllib.request.urlretrieve(url, data_path)

    with open(data_path, 'r') as f:
        text = f.read()

    print(f"Text length: {len(text):,} chars")

    # GPT-2 BPE tokenizer (n_vocab = 50257).
    print("Loading tokenizer...")
    enc = tiktoken.get_encoding("gpt2")

    print("Tokenizing...")
    tokens = enc.encode(text)
    print(f"Token count: {len(tokens):,}")
    print(f"Vocab size: {enc.n_vocab:,}")
    print(f"Compression ratio: {len(text) / len(tokens):.2f}x")

    # Contiguous 90/10 train/val split of the token stream.
    seq_len = 256
    split_idx = int(len(tokens) * 0.9)
    train_tokens = tokens[:split_idx]
    val_tokens = tokens[split_idx:]

    train_ds = TokenizedDataset(train_tokens, seq_len)
    val_ds = TokenizedDataset(val_tokens, seq_len)

    batch_size = 12
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True, persistent_workers=True)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True)

    print(f"Train sequences: {len(train_ds):,} ({len(train_dl)} batches)")
    print(f"Val sequences: {len(val_ds):,} ({len(val_dl)} batches)")

    # Model hyperparameters, passed straight to GeometricLM(**model_config).
    model_config = {
        'vocab_size': enc.n_vocab,
        'max_seq_len': seq_len,
        'embed_dim': 384,
        'depth': 4,
        'edim': 16,
        'feat_dim': 96,
        'hidden': 384,
        'num_heads': 8,
        'num_blocks': 8,
        'dropout': 0.1,
    }

    # Optimization hyperparameters.
    train_config = {
        'batch_size': batch_size,
        'seq_len': seq_len,
        'lr': 3e-4,
        'weight_decay': 0.1,
        'num_epochs': 14,
        'grad_clip': 1.0,
        'ce_weight': 1.0,
        'validity_weight': 0.1,
    }

    full_config = {
        'model': model_config,
        'training': train_config,
        'data': {
            'train_tokens': len(train_tokens),
            'val_tokens': len(val_tokens),
            'vocab_size': enc.n_vocab,
        },
        'run_name': RUN_NAME,
    }

    # Persist the full run configuration alongside the checkpoints.
    with open(CHECKPOINT_DIR / "config.json", 'w') as f:
        json.dump(full_config, f, indent=2)

    print("\nBuilding model...")
    model = GeometricLM(**model_config).to(device)

    print(f"\nConfig:")
    for k, v in model._config.items():
        print(f" {k}: {v}")

    params = sum(p.numel() for p in model.parameters())
    print(f" params: {params:,}")
    full_config['model']['params'] = params

    # Leak / causality / dataset-offset checks before spending compute.
    sanity_check(model, enc, device)

    print("\nCompiling...")
    # NOTE(review): no torch.compile call actually follows -- the message is
    # informational only (compilation appears to have been removed/disabled;
    # save_checkpoint still handles the compiled case via _orig_mod).

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=train_config['lr'],
        weight_decay=train_config['weight_decay']
    )
    # Cosine decay over the full run; stepped once per epoch below.
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=train_config['num_epochs'])

    best_val = float('inf')
    best_ppl = float('inf')
    global_step = 0

    print("\nTraining...")
    print("=" * 120)

    epoch_pbar = tqdm(range(train_config['num_epochs']), desc="Epochs", position=0)

    for ep in epoch_pbar:
        epoch_start = time.time()

        # ---- Train phase ------------------------------------------------
        model.train()
        ce_sum, val_sum, n = 0, 0, 0

        train_pbar = tqdm(train_dl, desc=f"Train {ep+1}", leave=False, position=1)
        for batch_idx, (x, y) in enumerate(train_pbar):
            x, y = x.to(device), y.to(device)

            opt.zero_grad()
            logits, info = model(x)
            loss, ce, val = lm_loss(
                logits, y, info,
                ce_weight=train_config['ce_weight'],
                validity_weight=train_config['validity_weight']
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), train_config['grad_clip'])
            opt.step()

            # Sample-weighted running sums for epoch averages.
            ce_sum += ce.item() * x.size(0)
            val_sum += val.item() * x.size(0)
            n += x.size(0)

            # Batch-level scalars every 100 optimizer steps.
            if global_step % 100 == 0:
                writer.add_scalar("train/ce_batch", ce.item(), global_step)
                writer.add_scalar("train/ppl_batch", math.exp(min(ce.item(), 10)), global_step)
                writer.add_scalar("train/validity_batch", val.item(), global_step)
                writer.add_scalar("train/lr", sched.get_last_lr()[0], global_step)

            global_step += 1

            train_pbar.set_postfix({
                'CE': f'{ce.item():.3f}',
                'PPL': f'{math.exp(min(ce.item(), 10)):.1f}'
            })

        tr_ce = ce_sum / n
        # CE clamped to 10 before exp so PPL cannot overflow early in training.
        tr_ppl = math.exp(min(tr_ce, 10))
        tr_val = val_sum / n

        # ---- Validation phase -------------------------------------------
        model.eval()
        ce_sum, n = 0, 0
        metrics_agg = []

        val_pbar = tqdm(val_dl, desc=f"Val {ep+1}", leave=False, position=1)
        with torch.no_grad():
            for x, y in val_pbar:
                x, y = x.to(device), y.to(device)
                logits, info = model(x)
                _, ce, _ = lm_loss(logits, y, info)
                ce_sum += ce.item() * x.size(0)
                n += x.size(0)
                metrics_agg.append(compute_metrics(info, model._config['depth']))

                val_pbar.set_postfix({
                    'CE': f'{ce.item():.3f}',
                    'PPL': f'{math.exp(min(ce.item(), 10)):.1f}'
                })

        va_ce = ce_sum / n
        va_ppl = math.exp(min(va_ce, 10))

        sched.step()

        if va_ce < best_val:
            best_val = va_ce
            best_ppl = va_ppl

        # Average the per-batch geometry metrics over the whole val set.
        m = {k: sum(d[k] for d in metrics_agg) / len(metrics_agg) for k in metrics_agg[0]}

        epoch_time = time.time() - epoch_start

        # ---- Epoch-level TensorBoard scalars ----------------------------
        writer.add_scalar("epoch/train_ce", tr_ce, ep)
        writer.add_scalar("epoch/train_ppl", tr_ppl, ep)
        writer.add_scalar("epoch/val_ce", va_ce, ep)
        writer.add_scalar("epoch/val_ppl", va_ppl, ep)
        writer.add_scalar("epoch/best_ppl", best_ppl, ep)
        writer.add_scalar("epoch/validity_loss", tr_val, ep)
        writer.add_scalar("epoch/time", epoch_time, ep)

        for k in range(model._config['depth']):
            writer.add_scalar(f"geometry/k{k+1}_valid", m[f'k{k+1}_valid'], ep)
            writer.add_scalar(f"geometry/k{k+1}_vol2", m[f'k{k+1}_vol2'], ep)
            writer.add_scalar(f"geometry/k{k+1}_d2", m[f'k{k+1}_d2'], ep)

        writer.add_scalar("geometry/valid_rate", m['valid_rate'], ep)

        epoch_pbar.set_postfix({
            'TrPPL': f'{tr_ppl:.1f}',
            'VaPPL': f'{va_ppl:.1f}',
            'Best': f'{best_ppl:.1f}',
            'Valid': f"{m['valid_rate']:.0%}"
        })

        tqdm.write(
            f"\nEp {ep+1:3d} | TrCE {tr_ce:.4f} | VaCE {va_ce:.4f} | "
            f"TrPPL {tr_ppl:7.2f} | VaPPL {va_ppl:7.2f} | BestPPL {best_ppl:.2f} | "
            f"Time {epoch_time:.1f}s"
        )
        # NOTE(review): the line below hard-codes channels k1..k4 and will
        # KeyError for depth != 4 (the configured depth above is 4).
        tqdm.write(
            f" | k1 {m['k1_valid']:5.1%} vol²={m['k1_vol2']:.2e} | "
            f"k2 {m['k2_valid']:5.1%} vol²={m['k2_vol2']:.2e} | "
            f"k3 {m['k3_valid']:5.1%} vol²={m['k3_vol2']:.2e} | "
            f"k4 {m['k4_valid']:5.1%} vol²={m['k4_vol2']:.2e}"
        )

        # Generate samples on the first epoch, every 25th, and the last
        # (with num_epochs=14 this fires only at epochs 1 and 14).
        if ep % 25 == 0 or ep == train_config['num_epochs'] - 1:
            samples = generate_samples(model, enc, device, ep + 1, writer)

            with open(CHECKPOINT_DIR / f"samples_epoch_{ep+1:03d}.json", 'w') as f:
                json.dump(samples, f, indent=2)

        metrics = {
            'epoch': ep + 1,
            'train_ce': tr_ce,
            'train_ppl': tr_ppl,
            'val_ce': va_ce,
            'val_ppl': va_ppl,
            'best_ppl': best_ppl,
            'geometry': m,
        }

        # Checkpoint every other epoch and always on the final epoch.
        if ep % 2 == 0 or ep == train_config['num_epochs'] - 1:
            save_checkpoint(model, opt, sched, ep + 1, full_config, metrics, CHECKPOINT_DIR)

        # One upload, after the final epoch only.
        if train_config['num_epochs'] - 1 == ep:
            upload_to_hf(CHECKPOINT_DIR, HF_REPO, ep + 1)

    writer.close()

    print("\n" + "=" * 120)
    print(f"Training complete!")
    print(f"Best val CE: {best_val:.4f}, PPL: {best_ppl:.2f}")
    print(f"Checkpoints: {CHECKPOINT_DIR}")
    print(f"TensorBoard: {TENSORBOARD_DIR}")
    print(f"HuggingFace: https://huggingface.co/{HF_REPO}")
    print("=" * 120)

    return model, enc
| |
|
| |
|
| | if __name__ == "__main__": |
| | model, tokenizer = train() |