Metadata-Version: 2.4
Name: transformer-toolkit
Version: 0.0.26
Summary: Minimal, modular transformer library for training your own LLM
Author-email: Govind Barbade <govindbarbade5@gmail.com>
License:                                  Apache License
                                   Version 2.0, January 2004
                                http://www.apache.org/licenses/
        
           TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
        
           1. Definitions.
        
              "License" shall mean the terms and conditions for use, reproduction,
              and distribution as defined by Sections 1 through 9 of this document.
        
              "Licensor" shall mean the copyright owner or entity authorized by
              the copyright owner that is granting the License.
        
              "Legal Entity" shall mean the union of the acting entity and all
              other entities that control, are controlled by, or are under common
              control with that entity. For the purposes of this definition,
              "control" means (i) the power, direct or indirect, to cause the
              direction or management of such entity, whether by contract or
              otherwise, or (ii) ownership of fifty percent (50%) or more of the
              outstanding shares, or (iii) beneficial ownership of such entity.
        
              "You" (or "Your") shall mean an individual or Legal Entity
              exercising permissions granted by this License.
        
              "Source" form shall mean the preferred form for making modifications,
              including but not limited to software source code, documentation
              source, and configuration files.
        
              "Object" form shall mean any form resulting from mechanical
              transformation or translation of a Source form, including but
              not limited to compiled object code, generated documentation,
              and conversions to other media types.
        
              "Work" shall mean the work of authorship, whether in Source or
              Object form, made available under the License, as indicated by a
              copyright notice that is included in or attached to the work
              (an example is provided in the Appendix below).
        
              "Derivative Works" shall mean any work, whether in Source or Object
              form, that is based on (or derived from) the Work and for which the
              editorial revisions, annotations, elaborations, or other modifications
              represent, as a whole, an original work of authorship. For the purposes
              of this License, Derivative Works shall not include works that remain
              separable from, or merely link (or bind by name) to the interfaces of,
              the Work and Derivative Works thereof.
        
              "Contribution" shall mean any work of authorship, including
              the original version of the Work and any modifications or additions
              to that Work or Derivative Works thereof, that is intentionally
              submitted to Licensor for inclusion in the Work by the copyright owner
              or by an individual or Legal Entity authorized to submit on behalf of
              the copyright owner. For the purposes of this definition, "submitted"
              means any form of electronic, verbal, or written communication sent
              to the Licensor or its representatives, including but not limited to
              communication on electronic mailing lists, source code control systems,
              and issue tracking systems that are managed by, or on behalf of, the
              Licensor for the purpose of discussing and improving the Work, but
              excluding communication that is conspicuously marked or otherwise
              designated in writing by the copyright owner as "Not a Contribution."
        
              "Contributor" shall mean Licensor and any individual or Legal Entity
              on behalf of whom a Contribution has been received by Licensor and
              subsequently incorporated within the Work.
        
           2. Grant of Copyright License. Subject to the terms and conditions of
              this License, each Contributor hereby grants to You a perpetual,
              worldwide, non-exclusive, no-charge, royalty-free, irrevocable
              copyright license to reproduce, prepare Derivative Works of,
              publicly display, publicly perform, sublicense, and distribute the
              Work and such Derivative Works in Source or Object form.
        
           3. Grant of Patent License. Subject to the terms and conditions of
              this License, each Contributor hereby grants to You a perpetual,
              worldwide, non-exclusive, no-charge, royalty-free, irrevocable
              (except as stated in this section) patent license to make, have made,
              use, offer to sell, sell, import, and otherwise transfer the Work,
              where such license applies only to those patent claims licensable
              by such Contributor that are necessarily infringed by their
              Contribution(s) alone or by combination of their Contribution(s)
              with the Work to which such Contribution(s) was submitted. If You
              institute patent litigation against any entity (including a
              cross-claim or counterclaim in a lawsuit) alleging that the Work
              or a Contribution incorporated within the Work constitutes direct
              or contributory patent infringement, then any patent licenses
              granted to You under this License for that Work shall terminate
              as of the date such litigation is filed.
        
           4. Redistribution. You may reproduce and distribute copies of the
              Work or Derivative Works thereof in any medium, with or without
              modifications, and in Source or Object form, provided that You
              meet the following conditions:
        
              (a) You must give any other recipients of the Work or
                  Derivative Works a copy of this License; and
        
              (b) You must cause any modified files to carry prominent notices
                  stating that You changed the files; and
        
              (c) You must retain, in the Source form of any Derivative Works
                  that You distribute, all copyright, patent, trademark, and
                  attribution notices from the Source form of the Work,
                  excluding those notices that do not pertain to any part of
                  the Derivative Works; and
        
              (d) If the Work includes a "NOTICE" text file as part of its
                  distribution, then any Derivative Works that You distribute must
                  include a readable copy of the attribution notices contained
                  within such NOTICE file, excluding those notices that do not
                  pertain to any part of the Derivative Works, in at least one
                  of the following places: within a NOTICE text file distributed
                  as part of the Derivative Works; within the Source form or
                  documentation, if provided along with the Derivative Works; or,
                  within a display generated by the Derivative Works, if and
                  wherever such third-party notices normally appear. The contents
                  of the NOTICE file are for informational purposes only and
                  do not modify the License. You may add Your own attribution
                  notices within Derivative Works that You distribute, alongside
                  or as an addendum to the NOTICE text from the Work, provided
                  that such additional attribution notices cannot be construed
                  as modifying the License.
        
              You may add Your own copyright statement to Your modifications and
              may provide additional or different license terms and conditions
              for use, reproduction, or distribution of Your modifications, or
              for any such Derivative Works as a whole, provided Your use,
              reproduction, and distribution of the Work otherwise complies with
              the conditions stated in this License.
        
           5. Submission of Contributions. Unless You explicitly state otherwise,
              any Contribution intentionally submitted for inclusion in the Work
              by You to the Licensor shall be under the terms and conditions of
              this License, without any additional terms or conditions.
              Notwithstanding the above, nothing herein shall supersede or modify
              the terms of any separate license agreement you may have executed
              with Licensor regarding such Contributions.
        
           6. Trademarks. This License does not grant permission to use the trade
              names, trademarks, service marks, or product names of the Licensor,
              except as required for reasonable and customary use in describing the
              origin of the Work and reproducing the content of the NOTICE file.
        
           7. Disclaimer of Warranty. Unless required by applicable law or
              agreed to in writing, Licensor provides the Work (and each
              Contributor provides its Contributions) on an "AS IS" BASIS,
              WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
              implied, including, without limitation, any warranties or conditions
              of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
              PARTICULAR PURPOSE. You are solely responsible for determining the
              appropriateness of using or redistributing the Work and assume any
              risks associated with Your exercise of permissions under this License.
        
           8. Limitation of Liability. In no event and under no legal theory,
              whether in tort (including negligence), contract, or otherwise,
              unless required by applicable law (such as deliberate and grossly
              negligent acts) or agreed to in writing, shall any Contributor be
              liable to You for damages, including any direct, indirect, special,
              incidental, or consequential damages of any character arising as a
              result of this License or out of the use or inability to use the
              Work (including but not limited to damages for loss of goodwill,
              work stoppage, computer failure or malfunction, or any and all
              other commercial damages or losses), even if such Contributor
              has been advised of the possibility of such damages.
        
           9. Accepting Warranty or Additional Liability. While redistributing
              the Work or Derivative Works thereof, You may choose to offer,
              and charge a fee for, acceptance of support, warranty, indemnity,
              or other liability obligations and/or rights consistent with this
              License. However, in accepting such obligations, You may act only
              on Your own behalf and on Your sole responsibility, not on behalf
              of any other Contributor, and only if You agree to indemnify,
              defend, and hold each Contributor harmless for any liability
              incurred by, or claims asserted against, such Contributor by reason
              of your accepting any such warranty or additional liability.
        
           END OF TERMS AND CONDITIONS
        
           APPENDIX: How to apply the Apache License to your work.
        
              To apply the Apache License to your work, attach the following
              boilerplate notice, with the fields enclosed by brackets "[]"
              replaced with your own identifying information. (Don't include
              the brackets!)  The text should be enclosed in the appropriate
              comment syntax for the file format. We also recommend that a
              file or class name and description of purpose be included on the
              same "printed page" as the copyright notice for easier
              identification within third-party archives.
        
           Copyright [yyyy] [name of copyright owner]
        
           Licensed under the Apache License, Version 2.0 (the "License");
           you may not use this file except in compliance with the License.
           You may obtain a copy of the License at
        
               http://www.apache.org/licenses/LICENSE-2.0
        
           Unless required by applicable law or agreed to in writing, software
           distributed under the License is distributed on an "AS IS" BASIS,
           WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
           See the License for the specific language governing permissions and
           limitations under the License.
        
Project-URL: Homepage, https://github.com/Barbade22/transformer-toolkit
Project-URL: Repository, https://github.com/Barbade22/transformer-toolkit
Keywords: transformer,llm,deep learning,nlp,pytorch
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Requires-Dist: pydantic>=2.0.0
Provides-Extra: tokenizers
Requires-Dist: tokenizers>=0.15.0; extra == "tokenizers"
Provides-Extra: hf
Requires-Dist: transformers>=4.35.0; extra == "hf"
Requires-Dist: huggingface_hub>=0.20.0; extra == "hf"
Requires-Dist: datasets>=2.14.0; extra == "hf"
Provides-Extra: all
Requires-Dist: transformer-toolkit[hf,tokenizers]; extra == "all"
Dynamic: license-file

# Transformer-Toolkit

<p align="center">
  <img src="https://raw.githubusercontent.com/Barbade22/transformer-toolkit/main/images/image.png" alt="Transformer Toolkit Logo" width="600"/>
</p>

<p align="center">
  <a href="https://pypi.org/project/transformer-toolkit/"><img src="https://img.shields.io/pypi/v/transformer-toolkit?color=cyan&style=flat-square" alt="PyPI version"/></a>
  <a href="https://pypi.org/project/transformer-toolkit/"><img src="https://static.pepy.tech/badge/transformer-toolkit" alt="Downloads"/></a>
  <a href="https://github.com/Barbade22/transformer-toolkit/blob/main/LICENSE"><img src="https://img.shields.io/github/license/Barbade22/transformer-toolkit?style=flat-square" alt="License"/></a>
</p>


A modular, from-scratch transformer library for training and experimenting with modern LLM architectures. Swap attention types, positional encodings, FFN variants, and normalization — all from a single config object.

```bash
pip install transformer-toolkit
```

---

## Table of Contents

- [Quick Start](#quick-start)
- [Model](#model)
  - [TransformerConfig](#transformerconfig)
  - [Transformer](#transformer)
  - [Weight Tying](#weight-tying)
  - [Debug Mode](#debug-mode)
- [Attention](#attention)
- [Feed-Forward Networks](#feed-forward-networks)
- [Positional Encodings](#positional-encodings)
- [Normalization](#normalization)
- [Dataloader](#dataloader)
  - [DataConfig](#dataconfig)
  - [Binary Files](#loading-from-a-binary-file)
  - [Memmap NPY](#memmap--loading-pre-split-npy-files)
  - [Text Files](#loading-from-text-files)
  - [HuggingFace](#loading-from-huggingface)
  - [Debug Samples](#dataloader-debug-mode)
- [Tokenizers](#tokenizers)
  - [ByteLevelTokenizer](#byteleveltokenizer)
  - [RustBPETokenizer](#rustbpetokenizer)
  - [HFTokenizer](#hftokenizer)
- [Trainer](#trainer)
  - [TrainConfig](#trainconfig)
  - [Training Loop](#training-loop)
- [Supervised Fine-Tuning (SFT)](#supervised-fine-tuning-sft)
  - [How it works](#how-it-works)
  - [Chat templates](#chat-templates)
  - [Data formats](#data-formats)
  - [SFT data loading](#sft-data-loading)
  - [SFT training](#sft-training)
  - [Loading and inference](#loading-and-inference)
  - [Debug output](#debug-output)
  - [Common issues](#common-issues)
- [HuggingFace Hub](#huggingface-hub)
- [Generation](#generation)
- [Full Examples](#full-examples)
  - [Small Model — Shakespeare](#small-model--shakespeare)
  - [Large Dataset — HuggingFace Streaming](#large-dataset--huggingface-streaming)
  - [MoE Model](#moe-model)
- [Architecture Reference](#architecture-reference)
- [Requirements](#requirements)

---

## Quick Start

```python
import torch
from transformer_toolkit.model import Transformer, TransformerConfig
from transformer_toolkit.c_tokenizers import RustBPETokenizer
from transformer_toolkit.dataloader import DataConfig, from_binary, save_binary
from transformer_toolkit.trainer import Trainer, TrainConfig

# tokenizer
tok = RustBPETokenizer()
tok.train(open("data.txt", encoding="utf-8").readlines(), vocab_size=8000)
tok.save("tokenizer.json")

# data
save_binary(tok.encode(open("data.txt", encoding="utf-8").read()), "data.bin")
train_dl, val_dl = from_binary("data.bin", DataConfig(seq_len=128, batch_size=32))

# model
model = Transformer(TransformerConfig(
    vocab_size  = tok.vocab_size,
    dim         = 512,
    n_layers    = 8,
    n_heads     = 8,
    pos_enc     = "rope",
    tie_weights = False,   # recommended for training from scratch
)).to("cuda")

# train
trainer = Trainer(model, train_dl, val_dl, tok.vocab_size, TrainConfig(max_steps=3000))
trainer.train()
```

---

## Model

### TransformerConfig

All architecture decisions live in one dataclass. Pass it to `Transformer()`.
```python
from transformer_toolkit.model import TransformerConfig

cfg = TransformerConfig(
    # ── core ──────────────────────────────────────────────────────────
    vocab_size = 32000,      # tokenizer vocabulary size
    dim        = 512,        # model embedding dimension
    n_layers   = 8,          # number of transformer blocks
    n_heads    = 8,          # number of attention heads
    max_seq    = 2048,       # maximum sequence length

    # ── attention ─────────────────────────────────────────────────────
    attn       = "gqa",      # "mha" | "gqa" | "mqa" | "flash" | "mla"
    n_kv_heads = 4,          # gqa only — n_heads must be divisible by n_kv_heads
    latent_dim = 64,         # mla only — latent compression dimension

    # ── feed-forward ──────────────────────────────────────────────────
    ffn        = "swiglu",   # "ffn" | "relu_ffn" | "glu" | "reglu" | "geglu"
                             # | "swiglu" | "moe" | "moe_ec" | "moe_shared"
    hidden_dim = 2048,       # FFN inner dimension (default: dim × 4)
    n_experts  = 8,          # moe / moe_ec / moe_shared — total experts
    top_k      = 2,          # moe / moe_shared — experts activated per token
    moe_aux_weight = 0.01,   # moe / moe_shared — load-balancing loss coefficient
    moe_capacity   = 1.0,    # moe_ec — capacity factor
    moe_n_shared   = 2,      # moe_shared — always-active experts
    moe_n_routed   = 6,      # moe_shared — sparse routed experts

    # ── normalization ─────────────────────────────────────────────────
    norm       = "rmsnorm",  # "rmsnorm" | "layernorm"
    eps        = 1e-6,

    # ── positional encoding ───────────────────────────────────────────
    pos_enc    = "rope",     # "rope" | "sinusoidal" | "learned" | "alibi" | "none"

    # ── regularisation ────────────────────────────────────────────────
    dropout    = 0.0,        # 0.0 recommended for SFT and inference

    # ── output ────────────────────────────────────────────────────────
    tie_weights = True,      # share embedding and output projection weights
                             # recommended — halves vocab params, better for small models
                             # see Weight Tying section before disabling
)
```

### Transformer

```python
from transformer_toolkit.model import Transformer

model = Transformer(cfg).to("cuda")

print(model.n_params())   # "30.21M"

# forward pass — returns (logits, aux_loss)
# aux_loss is non-zero only for MoE; always add it to your training loss
logits, aux_loss = model(tokens)   # tokens: [B, T]  →  logits: [B, T, vocab_size]

# generation
output = model.generate(
    tokens      = prompt_tokens,   # [B, T]
    max_new     = 200,
    temperature = 0.8,
    top_k       = 40,
)
```

### Weight Tying

Weight tying makes the embedding matrix and the output projection share the same tensor in memory. This reduces parameter count and can improve perplexity, but requires careful initialization.

> **Important:** `nn.Embedding` initializes weights with `N(0, 1)` — values around ±5. When the head shares these large weights, it produces logits of ±400 at initialization instead of the expected ±3, causing loss to start at ~346 instead of the correct ~`log(vocab_size)`. The model cannot recover from this initialization.

**Recommended approach — disable tying for training from scratch:**

```python
cfg = TransformerConfig(
    ...
    tie_weights = False,   # safe default for training from scratch
)
```

**If you want to enable tying**, scale down the embedding at initialization:

```python
model = Transformer(cfg).to("cuda")

if cfg.tie_weights:
    with torch.no_grad():
        model.embed.weight.mul_(0.02)   # bring logits into ±3 range
```

**Checkpoint save/load with tying enabled** — use the dedicated helpers to prevent the tie from breaking across save/load cycles:

```python
# saving
torch.save({"model": model.state_dict_for_save(), ...}, "checkpoint.pt")

# loading
model.load_state_dict_with_tie(ckpt["model"])
```

### Debug Mode

Pass `debug=True` to `Transformer()` to get a model summary at construction and a full forward pass trace.

```python
model = Transformer(cfg, debug=True).to("cuda")
model.debug = False   # turn off after inspecting — runs on every forward pass
```

**What it prints at construction:**

```
  🏗️  Model summary
  params             16.35M
  dim                384
  n_layers           6
  entropy check → should be > 90% of log(vocab_size) at init

  parameter breakdown:
  embed     ███░░░░░░░░░░░░░░░░░  3.07M  18.8%
  blocks    ████████████████░░░░  13.28M  81.2%
```

**What it prints per forward pass:**

```
  🔬 Forward pass debug
  tokens   [32, 128]  int64
  embed    [32, 128, 384]  float32  min=-4.84  mean=+0.00  max=+4.95
  block 0  residual update norm ratio: 0.133   ← healthy (0.01–2.0)
  logits   [32, 128, 8000]  float32  min=-2.97  mean=0.00  max=+2.98
  entropy  8.821 / max 8.99  (98.1% of uniform)   ← healthy at init
```

**Entropy at init should be above 90% of `log(vocab_size)`**. If it shows `-0.0%`, the logit scale is wrong — check the weight tying section above.

**Additional debug utilities:**

```python
# after loss.backward() — inspect gradient health per parameter
model.debug_gradients()

# any time — inspect weight statistics per parameter
model.debug_weights()
```

---

## Attention

Five attention variants, all swappable via `TransformerConfig.attn`.

| Value | Class | Used in |
|-------|-------|---------|
| `"mha"` | `MultiHeadAttention` | Original Transformer, BERT, GPT-2 |
| `"gqa"` | `GroupedQueryAttention` | LLaMA 3, Mistral |
| `"mqa"` | `MultiQueryAttention` | Falcon, early Gemini |
| `"flash"` | `FlashAttention` | Any model on PyTorch ≥ 2.0 |
| `"mla"` | `MLAttention` | DeepSeek-V2/V3 |

**RoPE** is applied inside attention to `q` and `k` after head-splitting — not to the residual stream. It is instantiated once and shared across all layers. The cos/sin cache is kept in `float32` regardless of model dtype to preserve precision.
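
For intuition, the sketch below shows the rotation RoPE applies to `q` and `k` after head-splitting (one common half-split formulation, illustrative only; it is not the toolkit's internal module):

```python
import torch

def rope_rotate(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Apply a RoPE-style rotation to x of shape [B, n_heads, T, head_dim] (head_dim even)."""
    T, D = x.shape[-2], x.shape[-1]
    half = D // 2
    # per-frequency angles, kept in float32 regardless of the model dtype
    freqs  = 1.0 / (base ** (torch.arange(half, dtype=torch.float32, device=x.device) / half))
    angles = torch.arange(T, dtype=torch.float32, device=x.device)[:, None] * freqs[None, :]
    cos, sin = angles.cos(), angles.sin()     # [T, half], broadcast over batch and heads
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).to(x.dtype)
```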

**ALiBi** bias is computed once per forward pass and passed as an additive mask to every block.

**Causal masking** is applied automatically inside each attention module. You do not need to pass a mask for standard language model training.

### Example — Flash Attention

```python
cfg = TransformerConfig(
    dim     = 512,
    n_heads = 8,
    attn    = "flash",   # uses torch.nn.functional.scaled_dot_product_attention
)
```

### Example — Grouped Query Attention (LLaMA-style)

```python
cfg = TransformerConfig(
    dim        = 512,
    n_heads    = 8,
    attn       = "gqa",
    n_kv_heads = 2,   # 4 query heads share each kv head → 4x KV cache reduction
)
```
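
The MLA variant follows the same pattern; the extra knob is `latent_dim`, the latent compression dimension listed in `TransformerConfig` above (a minimal config sketch, mirroring the examples above):

```python
cfg = TransformerConfig(
    dim        = 512,
    n_heads    = 8,
    attn       = "mla",
    latent_dim = 64,   # mla only: latent compression dimension
)
```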

---

## Feed-Forward Networks

| Value | Class | Used in |
|-------|-------|---------|
| `"ffn"` | `FFN` | Original Transformer, BERT |
| `"swiglu"` | `SwiGLU` | LLaMA, Mistral, PaLM |
| `"moe"` | `MoE` | Mixtral, GPT-4 (rumoured) |

### MoE — Mixture of Experts

When using `ffn="moe"`, the model forward pass returns an auxiliary load-balancing loss that **must** be added to the main loss. Without it, all tokens collapse onto 1–2 experts within a few hundred steps and the remaining experts never get trained.

```python
import torch.nn.functional as F

cfg = TransformerConfig(
    ffn            = "moe",
    n_experts      = 8,
    top_k          = 2,
    moe_aux_weight = 0.01,   # weight of the load-balancing term (Mixtral uses 0.02)
)

logits, aux_loss = model(tokens)
ce_loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
loss    = ce_loss + aux_loss   # aux_loss is 0.0 for non-MoE models — safe to always add
```

The `Trainer` handles `aux_loss` automatically — no changes to training code needed.

---

## Positional Encodings

| Value | Applied where | Notes |
|-------|---------------|-------|
| `"rope"` | Inside attention, on q and k | LLaMA, Mistral, Qwen — best for most use cases |
| `"sinusoidal"` | Residual stream before blocks | Original Transformer — no parameters |
| `"learned"` | Residual stream before blocks | BERT, GPT-2 — trainable |
| `"alibi"` | Additive bias on attention scores | Good for length generalization |
| `"none"` | Not applied | Bare model with no position information |

Each encoding applies exactly once in exactly one place — there is no double-application between the residual stream and attention.
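
Selecting an encoding is a single config field; for example, an ALiBi model aimed at length generalization (a minimal sketch using the values from the table above):

```python
cfg = TransformerConfig(
    dim     = 512,
    n_heads = 8,
    pos_enc = "alibi",   # additive bias on attention scores
)
```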

---

## Normalization

| Value | Class | Notes |
|-------|-------|-------|
| `"rmsnorm"` | `RMSNorm` | LLaMA, Mistral, Qwen — no mean subtraction, no bias, faster |
| `"layernorm"` | `LayerNorm` | BERT, GPT-2 — classic formulation with bias |

---

## Dataloader

### DataConfig

```python
from transformer_toolkit.dataloader import DataConfig

cfg = DataConfig(
    seq_len     = 128,    # sequence length fed to the model
    batch_size  = 32,     # samples per batch
    split       = 0.9,    # fraction of data used for training
    stride      = None,   # None = non-overlapping windows (strongly recommended)
                          # stride < seq_len = overlapping windows (more samples,
                          # but causes rapid overfitting on small datasets)
    shuffle     = True,
    num_workers = 4,
    pin_memory  = True,
    debug       = False,  # print decoded sample preview before training starts
    debug_n     = 3,      # number of samples to show when debug=True
)
```

> **stride** — the default `stride=None` (equivalent to `stride=seq_len`) produces non-overlapping windows. For a 1.86M token dataset with `seq_len=128` this gives ~14,600 clean distinct samples. Setting `stride=1` gives 1.86M heavily-overlapping samples and causes rapid overfitting on small datasets.
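
A quick back-of-the-envelope check (plain Python, no toolkit calls) shows the difference:

```python
n_tokens, seq_len = 1_860_000, 128   # rounded figures from the note above

print(n_tokens // seq_len)           # ~14.5K non-overlapping windows (stride=None)
print(n_tokens - seq_len + 1)        # ~1.86M heavily overlapping windows (stride=1)
```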

### Loading from a Binary File

```python
from transformer_toolkit.dataloader import save_binary, from_binary

# tokenize once and save to disk
save_binary(tok.encode(text), "data.bin")

# load — supports both raw uint16 binary and .npy
train_dl, val_dl = from_binary("data.bin", cfg, tokenizer=tok)

# pass train_path and val_path to save splits as memmap .npy for future runs
train_dl, val_dl = from_binary(
    "data.bin", cfg,
    train_path = "train.npy",
    val_path   = "val.npy",
    tokenizer  = tok,
)
```

### Memmap — Loading Pre-split NPY Files

On second and subsequent runs, load the pre-split `.npy` files directly. The token file stays on disk — only the pages actually accessed are loaded into RAM. Scales to datasets of 100GB+.

```python
from transformer_toolkit.dataloader import from_npy_split

train_dl, val_dl = from_npy_split("train.npy", "val.npy", cfg, tokenizer=tok)
```

### Loading from Text Files

```python
from transformer_toolkit.dataloader import from_files

train_dl, val_dl = from_files(
    paths      = ["data1.txt", "data2.txt"],
    tokenizer  = tok,
    cfg        = cfg,
    train_path = "train.npy",   # optional — saves splits for future memmap reuse
    val_path   = "val.npy",
    bos_id     = tok.bos_id,    # optional — wrap each document with BOS/EOS tokens
    eos_id     = tok.eos_id,
)
```

### Loading from HuggingFace

```python
from transformer_toolkit.dataloader import from_hf

# streaming — no full download required, works with infinite datasets
cfg_stream = DataConfig(seq_len=512, batch_size=16, streaming=True)
train_dl, val_dl = from_hf("roneneldan/TinyStories", tok, cfg_stream)

# in-memory — downloads fully, then splits and optionally saves as .npy
train_dl, val_dl = from_hf(
    dataset_name = "roneneldan/TinyStories",
    tokenizer    = tok,
    cfg          = cfg,
    text_col     = "text",
    bos_id       = 1,
    eos_id       = 2,
    train_path   = "train.npy",
    val_path     = "val.npy",
)
```

### Dataloader Debug Mode

```python
cfg = DataConfig(seq_len=128, batch_size=32, debug=True, debug_n=3)
train_dl, val_dl = from_binary("data.bin", cfg, tokenizer=tok)
```

Prints before training starts, showing decoded text and verifying x/y alignment:

```
  🔍 Debug samples (train)
  seq_len=128  stride=128  batch_size=32

  sample 1
  x ids : [23, 451, 12, 8, 1203 ...] ... +121
  y ids : [451, 12, 8, 1203, 44 ...] ... +121
  x text: 'ROMEO:\nBut soft, what light through yonder window...'
  y text: '\nBut soft, what light through yonder window breaks'
  ✓  x/y alignment correct (y = x shifted by 1)
```
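
The alignment being verified is the standard next-token setup; conceptually (a sketch of the convention, not the loader's actual code):

```python
# given a 1-D array `tokens` and a window starting at position i
x = tokens[i : i + seq_len]              # model input
y = tokens[i + 1 : i + seq_len + 1]      # target: the same window shifted one token ahead
```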

---

## Tokenizers

Three tokenizers with a unified interface.

```python
from transformer_toolkit.c_tokenizers import (
    ByteLevelTokenizer,
    RustBPETokenizer,
    HFTokenizer,
)
```

### ByteLevelTokenizer

Zero dependencies. Every byte is a token (vocab size fixed at 256). Works on any language or encoding out of the box.

```python
tok = ByteLevelTokenizer()
ids = tok.encode("Hello world")   # [72, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100]
txt = tok.decode(ids)             # "Hello world"
print(tok.vocab_size)             # 256
```

### RustBPETokenizer

BPE tokenizer backed by HuggingFace's Rust `tokenizers` library. Trains approximately 100x faster than a pure Python BPE implementation.

```bash
pip install tokenizers
```

```python
tok = RustBPETokenizer()
tok.train(open("data.txt", encoding="utf-8").readlines(), vocab_size=8000)
tok.save("tokenizer.json")

# on subsequent runs — load instead of retraining
tok.load("tokenizer.json")

ids = tok.encode("Hello world")
txt = tok.decode(ids)
print(tok.vocab_size)   # 8000
```

### HFTokenizer

Thin wrapper around any HuggingFace pretrained tokenizer.

```bash
pip install transformers
```

```python
tok = HFTokenizer("gpt2")
ids = tok.encode("Hello world")
txt = tok.decode(ids)
print(tok.vocab_size)   # 50257
```

---

## Trainer

### TrainConfig

```python
from transformer_toolkit.trainer import TrainConfig

cfg = TrainConfig(
    # ── steps ─────────────────────────────────────────────────────────
    max_steps        = 10000,   # total number of optimizer steps
    eval_every       = 500,     # run validation every N steps
    save_every       = 1000,    # save step_N.pt every N steps
    log_every        = 50,      # print loss and lr every N steps
    interruptible    = True,    # Ctrl+C saves a clean checkpoint instead of crashing

    # ── optimiser ─────────────────────────────────────────────────────
    lr               = 3e-4,   # peak learning rate after warmup
    min_lr           = 3e-5,   # floor lr at end of cosine decay (typically lr / 10)
    weight_decay     = 0.1,    # L2 penalty on 2D weights — biases and norms excluded
    beta1            = 0.9,    # AdamW beta1
    beta2            = 0.95,   # AdamW beta2
    grad_clip        = 1.0,    # max gradient norm

    # ── lr schedule ───────────────────────────────────────────────────
    warmup_steps     = 200,    # linear ramp from 0 to peak lr over this many steps

    # ── efficiency ────────────────────────────────────────────────────
    grad_accum_steps = 4,      # effective batch = batch_size × grad_accum_steps
    mixed_precision  = True,   # bf16/fp16 on CUDA, float32 on CPU automatically
    grad_checkpoint  = False,  # recompute activations during backward (~20% slower,
                               # but reduces VRAM by ~60% for large models)

    # ── checkpoints ───────────────────────────────────────────────────
    ckpt_dir         = "checkpoints",
    save_best        = True,        # save best.pt whenever val loss improves
    save_step_ckpts  = True,        # save step_N.pt every save_every steps

    # ── huggingface hub ───────────────────────────────────────────────
    hf_repo          = "username/my-model",   # None to disable
    hf_private       = True,
    hf_push_best     = True,    # push to hub whenever best val loss improves
    hf_push_every_n  = False,   # push to hub every save_every steps
    hf_push_end      = True,    # push to hub at end of training
    hf_push_on_pause = True,    # push to hub on Ctrl+C pause
)
```

### Training Loop

```python
from transformer_toolkit.trainer import Trainer

trainer = Trainer(
    model      = model,
    train_dl   = train_dl,
    val_dl     = val_dl,
    vocab_size = tok.vocab_size,
    cfg        = cfg_train,
    tokenizer  = tok,        # optional — used for HuggingFace hub uploads
)

# start training
trainer.train()

# resume from a checkpoint
trainer.train(resume_from="checkpoints/step_2000.pt")
```

Training output:

```
  ⚡ Transformer Toolkit Trainer
  steps=3000  lr=0.0003  warmup=200  accum=4
  mixed_precision=True  grad_clip=1.0

  step    100/3000  ████████░░░░░░░░░░░░░░░░  loss 3.1423  lr 1.5e-04  eta 4m
  step    200/3000  ████████████░░░░░░░░░░░░  loss 2.8901  lr 3.0e-04  eta 3m

  ● eval  step 300  val_loss 2.7130  ppl 15.07  ▼0.1823  ★ best
```

**Expected loss curve for a healthy run:**

| Step | Expected val loss | Notes |
|------|-------------------|-------|
| 0 | ~`log(vocab_size)` | Random init — ~8.99 for vocab=8000 |
| 100 | 5–7 | Model learning basic patterns |
| 300 | 3–5 | First eval — confirm learning is happening |
| 1000 | 2–3.5 | Good progress |
| 3000 | 1.5–2.5 | Healthy final loss for a small model |

If val loss is still above 8.0 at step 300, something is wrong with initialization. If it drops below 1.0 before step 1000 on a small dataset, you are overfitting.
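
The step-0 value is easy to verify:

```python
import math

vocab_size = 8000
print(math.log(vocab_size))   # 8.987..., the cross-entropy of a uniform prediction over the vocabulary
```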

---

# Supervised Fine-Tuning (SFT)

Transformer Toolkit supports full SFT training on top of a pretrained model.
The pipeline handles data formatting, loss masking, multi-turn conversations,
and inference — all with the same tokenizer used during pretraining.

---

## How it works

During pretraining the model learns language from raw text with no special structure.
SFT teaches it to follow a specific conversation format — roles, turns, and how to stop.

The key idea is the **loss mask**. Not all tokens contribute to the loss:

```
<|start_header_id|>user<|end_header_id|>        → loss=0  (model sees this as context)
What is Python?<|eot_id|>                        → loss=0
<|start_header_id|>assistant<|end_header_id|>   → loss=0  (header primes generation)

Python is a programming language.<|eot_id|>     → loss=1  (model learns this)
[EOS]                                            → loss=1  (model learns to stop)
```

The model only trains on what it needs to *generate* — assistant content, the turn-closing
token, and EOS. Everything else is context.
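
In PyTorch terms, a loss mask of this kind can be applied by computing the per-token cross-entropy and zeroing out the context positions (a generic sketch, not the `SFTTrainer` internals):

```python
import torch
import torch.nn.functional as F

def masked_lm_loss(logits: torch.Tensor, targets: torch.Tensor, loss_mask: torch.Tensor):
    # logits: [B, T, V]   targets: [B, T] token ids   loss_mask: [B, T] with 1 on assistant tokens
    per_token = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        targets.view(-1),
        reduction="none",
    ).view_as(loss_mask)
    # average only over the positions the model should learn to generate
    return (per_token * loss_mask).sum() / loss_mask.sum().clamp(min=1)
```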

---

## Tokenizer

All special tokens must be registered **before pretraining**. The vocabulary is frozen
after pretraining — tokens cannot be added at SFT time.

`RustBPETokenizer` registers all required tokens automatically at train time:

```python
from transformer_toolkit import RustBPETokenizer

tok = RustBPETokenizer()

with open("corpus.txt", encoding="utf-8") as f:
    lines = [l.strip() for l in f if l.strip()]

tok.train(texts=lines, vocab_size=32_000)
tok.save("tokenizer.json")
```

Fixed special token IDs (always at these positions regardless of vocab size):

| ID | Token | Used for |
|----|-------|----------|
| 0  | `[UNK]` | unknown token |
| 1  | `[PAD]` | padding |
| 2  | `[BOS]` | beginning of sequence |
| 3  | `[EOS]` | end of sequence |
| 4  | `[SEP]` | separator |
| 5  | `[MASK]` | masked token |
| 6  | `[CLS]` | classification |
| 7  | `<\|im_start\|>` | ChatML turn start |
| 8  | `<\|im_end\|>` | ChatML turn end |
| 9  | `<\|start_header_id\|>` | LLaMA3 header start |
| 10 | `<\|end_header_id\|>` | LLaMA3 header end |
| 11 | `<\|eot_id\|>` | LLaMA3 end of turn |
| 12 | `<start_of_turn>` | Gemma turn start |
| 13 | `<end_of_turn>` | Gemma turn end |
| 14 | `<\|tool_call\|>` | tool use |
| 15 | `<\|tool_result\|>` | tool result |
| 16 | `<\|doc_start\|>` | document boundary |
| 17 | `<\|doc_end\|>` | document boundary |
| 18 | `<\|code_start\|>` | code block |
| 19 | `<\|code_end\|>` | code block |
| 20 | `<\|system\|>` | system prompt |

---

## Chat templates

A `ChatTemplate` defines how conversations are formatted into a string.
Pick one template and use it consistently across SFT and inference.

```python
from transformer_toolkit import ChatTemplate

template = ChatTemplate("llama3")   # or "chatml", "gemma", "alpaca", "raw"
```

### Available presets

| Preset | Format | Special tokens |
|--------|--------|----------------|
| `llama3` | `<\|start_header_id\|>role<\|end_header_id\|>\n\ncontent<\|eot_id\|>` | IDs 9, 10, 11 |
| `chatml` | `<\|im_start\|>role<\|im_end\|>\ncontent<\|im_end\|>\n` | IDs 7, 8 |
| `gemma` | `<start_of_turn>role<end_of_turn>\ncontent<end_of_turn>\n` | IDs 12, 13 |
| `alpaca` | `### Instruction:\ncontent\n\n### Response:\ncontent` | none |
| `raw` | `User: content\nAssistant: content` | none |

### Custom template

```python
template = ChatTemplate(
    preset           = "chatml",          # base preset to inherit from
    assistant_header = "<|im_start|>assistant\n",   # loss=0
    assistant_closer = "<|im_end|>\n",              # loss=1
)
```

---

## Data formats

Three schemas are supported and auto-detected:

### messages (recommended for multi-turn)
```json
{
  "messages": [
    {"role": "system",    "content": "You are a helpful assistant."},
    {"role": "user",      "content": "What is Python?"},
    {"role": "assistant", "content": "Python is a programming language."},
    {"role": "user",      "content": "How do I reverse a list?"},
    {"role": "assistant", "content": "Use my_list[::-1]."}
  ]
}
```

### prompt_response (single-turn)
```json
{"prompt": "What is Python?", "response": "Python is a programming language."}
```

### instruction (Alpaca style)
```json
{"instruction": "Explain Python.", "input": "", "output": "Python is a programming language."}
```

All three can be mixed in the same dataset — schema is detected per sample.
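
For example, one in-memory list can mix all three shapes and be passed directly to `from_sft_strings` (shown in the next section); the field names below follow the schemas above:

```python
samples = [
    {"messages": [
        {"role": "user",      "content": "What is Python?"},
        {"role": "assistant", "content": "Python is a programming language."},
    ]},
    {"prompt": "Name a Python web framework.", "response": "Flask is a popular choice."},
    {"instruction": "Explain list slicing.", "input": "", "output": "my_list[1:3] returns items 1 and 2."},
]
```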

---

## SFT data loading

```python
from transformer_toolkit import RustBPETokenizer, ChatTemplate
from transformer_toolkit import SFTDataConfig, from_sft_strings

tok = RustBPETokenizer()
tok.load("tokenizer.json")

cfg = SFTDataConfig(
    tokenizer            = tok,       # auto-pulls bos_id, eos_id, pad_id
    seq_len              = 512,       # must match model max_seq
    batch_size           = 8,
    split                = 0.9,       # 90% train, 10% val
    template             = "llama3",  # must match what was used at pretrain time
    schema               = "auto",    # auto-detect per sample
    truncation_strategy  = "turn",    # drop whole turns instead of cutting mid-response
    debug                = True,      # print sample debug info on first batch
    debug_n              = 2,         # number of debug samples to show
)

# from a list of dicts in memory
train_dl, val_dl = from_sft_strings(samples, tok, cfg)

# from a local file
from transformer_toolkit import from_sft_json
train_dl, val_dl = from_sft_json("data.jsonl", tok, cfg)

# from multiple files
from transformer_toolkit import from_sft_files
train_dl, val_dl = from_sft_files(["data1.jsonl", "data2.jsonl"], tok, cfg)

# from HuggingFace
from transformer_toolkit import from_sft_hf
train_dl, val_dl = from_sft_hf("tatsu-lab/alpaca", tok, cfg)
```

### truncation_strategy

| Value | Behaviour | Use when |
|-------|-----------|----------|
| `"token"` | Hard-truncate at `seq_len` | Single-turn SFT |
| `"turn"` | Drop whole user+assistant pairs from the end | Multi-turn conversations |

`"turn"` is always safer for conversations — it never leaves a partial assistant
response with `loss=1` on incomplete text.

---

## SFT training

```python
from transformer_toolkit import Transformer, TransformerConfig
from transformer_toolkit import SFTTrainer, TrainConfig
import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load pretrained model
cfg_model = TransformerConfig(
    vocab_size  = tok.vocab_size,
    dim         = 512,
    n_layers    = 8,
    n_heads     = 8,
    n_kv_heads  = 2,          # GQA — n_heads must be divisible by n_kv_heads
    attn        = "gqa",
    ffn         = "swiglu",
    hidden_dim  = 2048,
    norm        = "rmsnorm",
    pos_enc     = "rope",
    dropout     = 0.0,        # typically 0 for SFT
    tie_weights = True,       # recommended — halves embedding params
    max_seq     = 512,        # must match SFTDataConfig seq_len
)

model = Transformer(cfg_model).to(DEVICE)

# optionally load pretrained weights before SFT
ckpt = torch.load("pretrain_checkpoints/best.pt", map_location=DEVICE)
model.load_state_dict(ckpt["model"])

cfg_train = TrainConfig(
    max_steps        = 1000,
    warmup_steps     = 50,
    eval_every       = 100,
    save_every       = 200,
    log_every        = 25,
    lr               = 1e-4,      # lower than pretraining — typically 1e-4 to 5e-5
    min_lr           = 1e-5,
    grad_accum_steps = 4,
    mixed_precision  = True,
    save_best        = True,
    save_step_ckpts  = True,
    ckpt_dir         = "sft_checkpoints",
    hf_repo          = None,      # "username/model-name" to push to HF Hub
)

trainer = SFTTrainer(
    model      = model,
    train_dl   = train_dl,
    val_dl     = val_dl,
    vocab_size = tok.vocab_size,
    cfg        = cfg_train,
    tokenizer  = tok,
)
trainer.train()
```

### SFT vs pretraining hyperparameters

| Parameter | Pretraining | SFT |
|-----------|-------------|-----|
| `lr` | `3e-4` | `1e-4` — `5e-5` |
| `dropout` | `0.1` | `0.0` |
| `warmup_steps` | `1000+` | `50` — `100` |
| `grad_accum_steps` | `8+` | `4` |

Lower learning rate for SFT — you are fine-tuning an existing model, not training from scratch.
Too high an LR causes catastrophic forgetting of pretraining knowledge.

---

## Loading and inference

```python
from transformer_toolkit import RustBPETokenizer, ChatTemplate, Transformer, TransformerConfig
import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load tokenizer
tok = RustBPETokenizer()
tok.load("tokenizer.json")

# load model — same config as training
cfg_model = TransformerConfig(...)
model     = Transformer(cfg_model).to(DEVICE)

# load SFT checkpoint — inference only, strip optimizer state
ckpt = torch.load("sft_checkpoints/best.pt", map_location=DEVICE)
model.load_state_dict(ckpt["model"])
model.eval()

# same template as training
template = ChatTemplate("llama3")


def chat(
    prompt:      str,
    system:      str   = None,
    history:     list  = None,
    max_new:     int   = 200,
    temperature: float = 0.8,
    top_k:       int   = 50,
) -> str:
    msgs = []
    if system:
        msgs.append({"role": "system", "content": system})
    if history:
        msgs.extend(history)
    msgs.append({"role": "user", "content": prompt})

    # format and append assistant header to prime generation
    full_text, _ = template.format_messages(msgs)
    primed       = full_text + template.assistant_header

    ids = tok.encode(primed)
    # truncate from left if too long — keep most recent context
    if len(ids) > cfg_model.max_seq:
        ids = ids[-cfg_model.max_seq:]

    x = torch.tensor([ids], dtype=torch.long).to(DEVICE)
    with torch.no_grad():
        out = model.generate(x, max_new=max_new, temperature=temperature, top_k=top_k)

    new_ids  = out[0][len(ids):].tolist()
    response = tok.decode(new_ids, skip_special_tokens=False)

    # strip end-of-turn marker
    closer = template.assistant_closer.strip()
    if closer and closer in response:
        response = response[:response.index(closer)]

    return response.strip()
```

### Single turn

```python
print(chat("What is the capital of France?"))
```

### With system prompt

```python
print(chat(
    prompt = "How do I reverse a string in Python?",
    system = "You are a concise coding assistant. Answer in 1-2 sentences.",
))
```

### Multi-turn conversation

```python
history = []
system  = "You are a helpful Python tutor."

while True:
    user_input = input("You: ").strip()
    if not user_input:
        break

    reply = chat(
        prompt      = user_input,
        system      = system,
        history     = history,
        temperature = 0.8,
        top_k       = 50,
    )

    # append to history for next turn
    history.append({"role": "user",      "content": user_input})
    history.append({"role": "assistant", "content": reply})

    print(f"Assistant: {reply}\n")
```

### Saving an inference-only checkpoint

The full training checkpoint includes optimizer state (~3× the model size).
For deployment, strip it:

```python
# after training
torch.save({"model": model.state_dict()}, "model_inference.pt")
# full checkpoint:  ~500 MB  (model + Adam optimizer m/v buffers)
# inference only:   ~170 MB  (model weights only)
```

---

## Debug output

Set `debug=True` in `SFTDataConfig` to inspect samples before training.
The debug view shows the exact formatted text with color coding:

```
── formatted view ──
<|start_header_id|>user<|end_header_id|>          ← cyan  (loss=0)
What is Python?<|eot_id|>                          ← cyan  (loss=0)
<|start_header_id|>assistant<|end_header_id|>     ← cyan  (loss=0)
Python is a programming language.<|eot_id|>[EOS]  ← green (loss=1)
```

Sanity checks run automatically on each sample:

- **zero response tokens** — schema or template mismatch
- **alignment** — `y = x` shifted by 1 (catches dataset bugs)
- **heavy padding** — suggests a smaller `seq_len`

---

## Common issues

**`Template tokens [...] are fragmented`**
The tokenizer was saved before the special tokens were registered.
Retrain the tokenizer and redo pretraining with it; the vocabulary cannot be changed after pretraining.

**`n_heads must be divisible by n_kv_heads`**
GQA requires `n_heads % n_kv_heads == 0`.
Example: `n_heads=6, n_kv_heads=3` ✓ — `n_heads=2, n_kv_heads=3` ✗

**`seq_len` mismatch**
`SFTDataConfig(seq_len=512)` and `TransformerConfig(max_seq=512)` must match exactly.

**High padding warning**
Your `seq_len` is much larger than your average sample length.
Use `truncation_strategy="turn"` and lower `seq_len` to match your data.

**Model not learning / loss not decreasing**
Check `mask sum` in debug output — if response tokens are very few relative to
total tokens, the model gets very little gradient signal per batch.
Increase `batch_size` or `grad_accum_steps` to compensate.


## HuggingFace Hub

### Login

```python
from transformer_toolkit.hf_hub import login

login(token="hf_your_token_here")
```

### Push to Hub

```python
from transformer_toolkit.hf_hub import push_to_hub

push_to_hub(
    repo_id   = "username/my-model",
    model     = model,
    cfg       = cfg_model,
    tokenizer = tok,
    metrics   = {"val_loss": 1.83, "perplexity": 6.23},
    step      = 3000,
    private   = True,
)
```

### Pull from Hub

```python
from transformer_toolkit.hf_hub import pull_from_hub

pull_from_hub("username/my-model", save_dir="checkpoints")
# downloads: model.pt, tokenizer.json, config.json, metrics.json
```

---

## Generation

```python
from transformer_toolkit.model import Transformer, TransformerConfig
from transformer_toolkit.c_tokenizers import RustBPETokenizer
from transformer_toolkit.trainer import load_ckpt
import torch

DEVICE = torch.device("cuda")

# load tokenizer — always load the saved file, never retrain
tok = RustBPETokenizer()
tok.load("tokenizer.json")

# model config must match the training config exactly
cfg = TransformerConfig(
    vocab_size  = tok.vocab_size,
    dim         = 384,
    n_layers    = 6,
    n_heads     = 6,
    attn        = "gqa",
    n_kv_heads  = 3,
    ffn         = "swiglu",
    hidden_dim  = 1536,
    norm        = "rmsnorm",
    pos_enc     = "rope",
    dropout     = 0.0,        # always 0.0 at inference
    tie_weights = False,
)
model = Transformer(cfg).to(DEVICE)
load_ckpt("checkpoints/best.pt", model)
model.eval()

def generate(prompt, max_new=200, temperature=0.8, top_k=40):
    ids    = tok.encode(prompt)
    tokens = torch.tensor([ids], dtype=torch.long, device=DEVICE)
    out    = model.generate(tokens, max_new=max_new,
                             temperature=temperature, top_k=top_k)
    return tok.decode(out[0].tolist())

print(generate("ROMEO:"))
print(generate("To be or not to be,"))
```

**Generation parameters:**

| Parameter | Effect | Recommended range |
|-----------|--------|-------------------|
| `temperature` | Higher = more random, lower = more repetitive | 0.7 – 1.0 |
| `top_k` | Only sample from the top-k most likely tokens | 20 – 50 |
| `max_new` | Number of new tokens to generate | 100 – 500 |

---

## Full Examples

### Small Model — Shakespeare

Suitable for any GPU. Trains in under 10 minutes on a 4GB card.

```python
import torch, os
from transformer_toolkit.model import Transformer, TransformerConfig
from transformer_toolkit.c_tokenizers import RustBPETokenizer
from transformer_toolkit.dataloader import DataConfig, from_binary, from_npy_split, save_binary
from transformer_toolkit.trainer import Trainer, TrainConfig

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# tokenizer — load if saved, train once otherwise
tok = RustBPETokenizer()
if os.path.exists("tokenizer.json"):
    tok.load("tokenizer.json")
else:
    tok.train(open("shakespeare.txt", encoding="utf-8").readlines(), vocab_size=8000)
    tok.save("tokenizer.json")

# data — tokenize once, reuse memmap splits on subsequent runs
cfg_data = DataConfig(seq_len=128, batch_size=32, split=0.9, stride=None)
if os.path.exists("train.npy") and os.path.exists("val.npy"):
    train_dl, val_dl = from_npy_split("train.npy", "val.npy", cfg_data, tokenizer=tok)
else:
    if not os.path.exists("data.bin"):
        save_binary(tok.encode(open("shakespeare.txt", encoding="utf-8").read()), "data.bin")
    train_dl, val_dl = from_binary("data.bin", cfg_data,
                                    train_path="train.npy", val_path="val.npy",
                                    tokenizer=tok)

# model
model = Transformer(TransformerConfig(
    vocab_size  = tok.vocab_size,
    dim         = 384,
    n_layers    = 6,
    n_heads     = 6,
    n_kv_heads  = 3,
    attn        = "gqa",
    ffn         = "swiglu",
    hidden_dim  = 1536,
    norm        = "rmsnorm",
    pos_enc     = "rope",
    dropout     = 0.1,
    tie_weights = False,
)).to(DEVICE)
print(f"params: {model.n_params()}")   # ~15M

trainer = Trainer(model, train_dl, val_dl, tok.vocab_size, TrainConfig(
    max_steps        = 3000,
    warmup_steps     = 200,
    eval_every       = 300,
    lr               = 3e-4,
    grad_accum_steps = 4,
    mixed_precision  = True,
    save_best        = True,
    save_step_ckpts  = True,
))
trainer.train()
```

### Large Dataset — HuggingFace Streaming

```python
from transformer_toolkit.dataloader import DataConfig, from_hf, from_npy_split
from transformer_toolkit.c_tokenizers import HFTokenizer

tok = HFTokenizer("HuggingFaceTB/SmolLM-135M")
cfg = DataConfig(seq_len=512, batch_size=16, stride=None, num_workers=4)

# first run — downloads, tokenizes, and saves as memmap .npy splits
train_dl, val_dl = from_hf(
    dataset_name = "roneneldan/TinyStories",
    tokenizer    = tok,
    cfg          = cfg,
    bos_id       = tok._tok.bos_token_id,
    eos_id       = tok._tok.eos_token_id,
    train_path   = "train.npy",
    val_path     = "val.npy",
)

# second+ runs — zero RAM overhead, loads directly from disk
train_dl, val_dl = from_npy_split("train.npy", "val.npy", cfg, tokenizer=tok)
```

### MoE Model

```python
model = Transformer(TransformerConfig(
    vocab_size     = tok.vocab_size,
    dim            = 512,
    n_layers       = 8,
    n_heads        = 8,
    attn           = "flash",
    ffn            = "moe",
    n_experts      = 8,
    top_k          = 2,
    moe_aux_weight = 0.01,
    pos_enc        = "rope",
    dropout        = 0.1,
    tie_weights    = False,
)).to("cuda")

# The Trainer adds aux_loss to ce_loss automatically — no changes needed
trainer = Trainer(model, train_dl, val_dl, tok.vocab_size, TrainConfig(
    max_steps = 5000,
    lr        = 3e-4,
))
trainer.train()
```

---

## Architecture Reference

```
Input tokens [B, T]
      │
      ▼
Embedding [B, T, dim]
      │
      ▼  SinusoidalPE or LearnedPE added here (if selected)
      │
      ▼  × n_layers
┌─────────────────────────────────────────────┐
│  RMSNorm / LayerNorm                        │
│  Attention  ← RoPE applied to q,k here      │
│             ← ALiBi bias added to scores    │
│  Residual connection                        │
│                                             │
│  RMSNorm / LayerNorm                        │
│  FFN / SwiGLU / MoE                         │
│  Residual connection                        │
└─────────────────────────────────────────────┘
      │
      ▼
Final RMSNorm / LayerNorm
      │
      ▼
Linear head [B, T, vocab_size]  →  logits
```
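
A toy, self-contained PyTorch rendering of this block ordering (stock `nn` modules only, not the toolkit's classes; the real model swaps in RMSNorm, RoPE, SwiGLU, and so on per the config):

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Pre-norm block: norm → attention → residual, then norm → FFN → residual."""
    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn  = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn   = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        h = self.norm1(x)
        causal = torch.triu(torch.ones(x.size(1), x.size(1), dtype=torch.bool, device=x.device), 1)
        attn_out, _ = self.attn(h, h, h, attn_mask=causal)   # causal masking inside attention
        x = x + attn_out
        return x + self.ffn(self.norm2(x))

class TinyLM(nn.Module):
    def __init__(self, vocab_size: int, dim: int = 128, n_layers: int = 2, n_heads: int = 4):
        super().__init__()
        self.embed  = nn.Embedding(vocab_size, dim)
        self.blocks = nn.ModuleList([Block(dim, n_heads) for _ in range(n_layers)])
        self.norm   = nn.LayerNorm(dim)
        self.head   = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, tokens):                 # tokens: [B, T]
        x = self.embed(tokens)                 # [B, T, dim]; positional info added here or inside attention
        for block in self.blocks:
            x = block(x)
        return self.head(self.norm(x))         # [B, T, vocab_size] logits
```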

---

## Requirements

| Package | Version | Purpose |
|---------|---------|---------|
| `torch` | ≥ 2.0 | Core — required |
| `numpy` | any | Memmap dataloader — required |
| `pydantic` | ≥ 2.0 | TrainConfig validation — required |
| `tokenizers` | ≥ 0.15 | `RustBPETokenizer` |
| `transformers` | ≥ 4.35 | `HFTokenizer` |
| `datasets` | ≥ 2.14 | `from_hf()` |
| `huggingface_hub` | ≥ 0.20 | Hub push/pull |
| `hf_transfer` | any | Faster hub uploads (optional) |

Install all optional dependencies at once:

```bash
pip install transformer-toolkit tokenizers transformers datasets huggingface_hub hf-transfer
```

---

## License

This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.
