nanochat 是一个用于学习目的的 LLM 的迷你实现。
一个大模型的训练和推理过程如下:
分词器训练 ——> 预训练 ——> 中期训练 ——> 监督微调(SFT)——> 强化学习 (RL)
——> 推理引擎推理
工程结构
整体项目结构如下
├── dev
├── nanochat # 迷你LLM的实现,包含训练和推理
│ ├── __init__.py
│ ├── __pycache__
│ ├── adamw.py ## 实现了 AdamW 优化器,与muon.py类似,是模型训练时用于更新权重的另一种选择
│ ├── checkpoint_manager.py # checkpoint 管理
│ ├── common.py # 工具类
│ ├── configurator.py ## 配置
│ ├── core_eval.py ## 核心能力评估模块,用于在标准基准测试
│ ├── dataloader.py ## 数据加载
│ ├── dataset.py ## 数据集
│ ├── engine.py ## 推理引擎
│ ├── execution.py ## 执行LLM输出的python代码的沙盒
│ ├── gpt.py ## 简易的GPT模型
│ ├── loss_eval.py ## 损失评估模块,专门用于计算模型在训练集和验证集上的损失(loss),
│ ├── muon.py ## 名为 Muon 的自定义 PyTorch 优化器
│ ├── report.py
│ ├── tokenizer.py ## 分词器
├── pyproject.toml
├── rustbpe # 用rust实现的bpe分词器
├── scripts
│ ├── __pycache__
│ ├── base_eval.py
│ ├── base_loss.py
│ ├── base_train.py
│ ├── chat_cli.py
│ ├── chat_eval.py
│ ├── chat_rl.py
│ ├── chat_sft.py
│ ├── chat_web.py
│ ├── mid_train.py
│ ├── tok_eval.py
│ └── tok_train.py
├── speedrun.sh # 开始训练的脚本
├── tasks # 训练任务
│ ├── __pycache__
│ ├── arc.py
│ ├── common.py
│ ├── gsm8k.py
│ ├── humaneval.py
│ ├── mmlu.py
│ └── smoltalk.py
├── tests
│ └── test_rustbpe.py
└── uv.lock
训练过程
speedrun.sh 中是整个训练过程调用的脚本。训练顺序为 pretrain -> midtraining -> SFT
# pretrain the d20 model
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
# evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
# evaluate the model on CORE tasks
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
# -----------------------------------------------------------------------------
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
# run midtraining and eval the model
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid
# -----------------------------------------------------------------------------
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)
# train sft and re-eval right away (should see a small bump)
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
# chat with the model over CLI! Leave out the -p to chat interactively
python -m scripts.chat_cli -p "Why is the sky blue?"