Learning (to reproduce Pythia 2.8b) pretraining
While tracing the data provenance of open-source models, some researchers found that Pythia 2.8b deduped may not have been trained on the same data as the other deduplicated Pythias. All the Pythias were trained on the Pile, which exists in roughly four formats: deduped or standard, and for each of those, preshuffled or not. If the model was trained on the wrong dataset, that dataset is most likely one of the other Pile formats.
From my experiments, it seems most likely that pythia-2.8b-deduped was trained on the standard Pile rather than the deduped Pile. I also found that the number of GPUs and gradient accumulation steps affected reproduction quality the most: moving these closer to the original training config produced increasingly close reproductions. Other factors, such as per-device batch size, GPU type, dependency versions, and data shuffling method, made a comparatively small difference.
Goal
To validate the hypothesis, the plan was to reproduce Pythia training using the original training environment and configs, then determine whether any Pile other than the deduped preshuffled Pile (which it was supposed to be trained on) can be used to reproduce the published checkpoints of pythia-2.8b-deduped.
As a control, I’d also take models from the same family that we believe were trained correctly, for example pythia-70m-deduped, train them in the same environment on all candidate datasets, and ideally see that training on the deduped Pile gets closest to the published checkpoints.
Beyond validating the hypothesis, I also wanted to gain an intuition for how pretraining reproduction is affected by different factors. I wanted to understand how much different GPU types, PyTorch/CUDA versions, attention implementations, per-device batch size/GPU count/gradient accumulation steps at the same global batch size, data shuffling, and data distributions (i.e. standard vs deduped) changed the model’s training trajectory.
Setup
Starting point
To eliminate model initialization as a source of divergence, I wanted to resume training from the uploaded checkpoint 0 rather than reinitialize the model.
There are five generations of the 2.8b Pythia on HuggingFace. The standard (non-deduped) versions of the latest three all have some form of corruption, while the first two have no noticeable issues. However, the first two and the last three seem to come from two distinct training setups: generations 1/2 are similar to each other, and 3/4/5 are similar to each other.
| Gen | Repos | Notes |
|---|---|---|
| 1 | pythia-2.7b / pythia-2.7b-deduped | No noticeable issues. |
| 2 | pythia-2.8b-v0 / pythia-2.8b-deduped-v0 | Both models are identical to gen 1. |
| 3 | neox-ckpt-pythia-2.8b / neox-ckpt-pythia-2.8b-deduped | Both variants are corrupt: every checkpoint is a clone of gen 1/2’s final checkpoint. |
| 4 | neox-ckpt-pythia-2.8b-v1 / neox-ckpt-pythia-2.8b-deduped-v1 | Deduped is unproblematic. Standard is corrupt - every checkpoint is the same, and extremely similar but not identical to step 143k of the deduped model (>0.99 param cosine sim, vs ~0.45 param cosine sim between pythia-2.8b-v0 standard and deduped). |
| 5 | pythia-2.8b / pythia-2.8b-deduped | Deduped is unproblematic. Standard is corrupt in a slightly more complex way where model.safetensors and pytorch_model.bin have gen 4’s issue, while each checkpoint of the sharded safetensor is a clone of the deduplicated file at that same step. |
Since we’re investigating the deduped model and the deduped variants of gen 4/5 are not corrupt, I resumed training from checkpoint 0 of neox-ckpt-pythia-2.8b-deduped-v1.
Metrics
To measure how similar a trained model checkpoint is to the HuggingFace checkpoint, the most obvious check is whether the weights are bit-wise identical, but since all of my reproductions deviated somewhat, I needed other metrics that gave a more directional signal.
Parameter Cosine Similarity (P.Cos): flatten all parameter tensors from our reproduction and from the published checkpoint into single vectors, then compute cosine similarity. P.Cos = +1.0 means the parameter vectors point in exactly the same direction; +0.0 means uncorrelated.
L2 Distance: calculated between flattened parameter vectors.
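Both metrics are cheap to compute from two checkpoints’ state dicts. A minimal sketch using NumPy (in practice the tensors come from loaded PyTorch state dicts; these helpers are my own illustration, not the repro code):

```python
import numpy as np

def flatten_params(state_dict):
    """Concatenate every parameter tensor into one 1-D float64 vector."""
    return np.concatenate([np.asarray(v, dtype=np.float64).ravel()
                           for v in state_dict.values()])

def param_cosine_sim(sd_a, sd_b):
    """P.Cos: cosine similarity between the flattened parameter vectors."""
    a, b = flatten_params(sd_a), flatten_params(sd_b)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def param_l2_dist(sd_a, sd_b):
    """L2 distance between the flattened parameter vectors."""
    a, b = flatten_params(sd_a), flatten_params(sd_b)
    return float(np.linalg.norm(a - b))
```

Accumulating in float64 avoids the comparison itself adding rounding noise on billions of parameters.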
Hardware
I had access to 3 different hardware configs: 8xA40-48GB, 16xA40-48GB, and 8xA100-80GB. None of these matches the original Pythia training (2.8b used 64x A100-40GB; 70m used 32x A100-40GB), but they allowed me to test the effect of different gradient accumulation steps, per-device batch sizes, attention implementations, and numbers of all-reduce operations.
Environment
The GPT-NeoX v1.0 tag for Pythia led to two Dockerfiles in the repo, one based on PyTorch 1.13 + CUDA 11.7, the other on PyTorch 1.10 + CUDA 11.1. Stella from Eleuther confirmed that the latter was used for the original Pythia training, but I also ran the former to compare.
The ‘correct’ environment for repro:
- PyTorch 1.10.0+cu111
- CUDA 11.1
- Apex at commit a651e2c
- DeepSpeed 0.3.15 (EleutherAI’s DeeperSpeed fork)
- flash-attn 0.2.2
Experiments
Gradient accumulation steps (gas)
Of all the knobs I had, gas made the biggest and most consistent difference in reproduction quality. The original training used 64 GPUs with gas=1 and a global batch size of 1024 samples/step. With fewer GPUs I couldn’t run gas=1, but by varying the number of GPUs and per-device batch size (PDBS) while holding the global batch size fixed, I was able to run gas from 32 down to 8. Lowering gas got us closer to the HuggingFace checkpoints on both Piles, while the gap between standard and deduped remained roughly steady.
| GPUs (A40s) | PDBS | gas | Std P.Cos | Ded P.Cos | Std L2 | Ded L2 |
|---|---|---|---|---|---|---|
| 16 | 8 | 8 | 0.999994 | 0.999992 | 2.36 | 2.63 |
| 8 | 8 | 16 | 0.999992 | 0.999990 | 2.72 | 3.00 |
| 8 | 4 | 32 | 0.999984 | 0.999982 | 3.75 | 3.93 |
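All three rows satisfy the same invariant: global batch = num_GPUs × PDBS × gas. A quick sanity check (the original run’s PDBS of 16 is inferred from 64 GPUs at gas=1, not stated in the configs):

```python
GLOBAL_BATCH = 1024  # samples per optimizer step, held fixed across all runs

def required_gas(num_gpus, pdbs):
    """Gradient accumulation steps needed to keep the global batch fixed."""
    assert GLOBAL_BATCH % (num_gpus * pdbs) == 0
    return GLOBAL_BATCH // (num_gpus * pdbs)

# Original run (64 GPUs) and the three configs in the table above:
for gpus, pdbs in [(64, 16), (16, 8), (8, 8), (8, 4)]:
    print(f"{gpus} GPUs x PDBS {pdbs} -> gas {required_gas(gpus, pdbs)}")
```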
Why does gas matter so much? I didn’t know, but from Claude:
In the original training, with 64 GPUs and gas=1, each microbatch’s gradients are computed independently and then averaged across all devices in a single all-reduce before the optimizer step. With fewer GPUs and higher gas, gradients are instead accumulated locally over multiple microbatches before the all-reduce. Since floating point addition is non-associative, the order and grouping of these reductions changes the result — accumulating 16 microbatches locally before averaging across 8 GPUs produces different rounding patterns than averaging 64 microbatches across 64 GPUs in one shot. These small numerical differences compound over training steps.
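The non-associativity itself is easy to demonstrate. Here is a toy version of the two reduction orders, with exaggerated magnitudes so the grouping effect is deterministic and visible (real gradient reductions differ only in the last few bits):

```python
# Toy per-microbatch "gradients": one scalar standing in for a tensor.
# The huge-magnitude pair makes the rounding difference deterministic.
grads = [1e16] + [1.0] * 7 + [-1e16] + [1.0] * 7  # 16 microbatches

# One flat reduction (analogous to 64 GPUs, gas=1: average everything at once).
# Sequentially, each +1.0 after 1e16 is lost to rounding until -1e16 cancels it.
one_shot = sum(grads) / len(grads)

# Local accumulation then reduce (analogous to fewer GPUs, higher gas):
# two "GPUs" each sum 8 microbatches locally, then the partial sums are
# averaged, mimicking the all-reduce. Here every +1.0 is lost in both partials.
partials = [sum(grads[0:8]), sum(grads[8:16])]
grouped = sum(partials) / len(grads)

print(one_shot, grouped)  # 0.4375 vs 0.0: same data, different reduction order
```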
GPU type
I tested 8xA100 vs 8xA40 and, interestingly, found the A100 slightly closer, although the gap is smaller than between A40 runs with different gas (and below the precision shown below). I had limited time with the A100s, so I ran only the deduped Pile rather than both.
| GPU | P.Cos | L2 |
|---|---|---|
| 8xA100 | 0.999990 | 3.00 |
| 8xA40 | 0.999990 | 3.00 |
PDBS vs gas
Since the A100s had 80GB of VRAM, I had the option to double PDBS and halve gas relative to the A40s, while keeping the global batch size and number of GPUs constant. Here are the results.
| GPU | PDBS | gas | Std P.Cos | Ded P.Cos | Std L2 | Ded L2 |
|---|---|---|---|---|---|---|
| 8xA100 | 16 | 8 | 0.999993 | 0.999990 | 2.687 | 3.01 |
| 8xA40 | 8 | 16 | 0.999992 | 0.999990 | 2.72 | 3.00 |
Global vs flash attention
Pythia v1.0 was trained with flash-attn 0.2.2. For the 2.8b Pythia, whose head_dim is 80, the A40 did not have enough shared memory to run that version, so I compared flash attention against global attention on the A100 cluster.
| attention | P.Cos | L2 |
|---|---|---|
| flash | 0.999990 | 2.877 |
| global | 0.999990 | 2.873 |
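The near-identical numbers are expected: flash attention computes the same function as global attention, just streaming over keys with a running softmax, so outputs differ only by rounding. A minimal single-query sketch of the idea (my own illustration in NumPy, not the flash-attn 0.2.2 kernel):

```python
import numpy as np

def attn_naive(q, K, V):
    """Standard ("global") attention for a single query vector q."""
    s = K @ q / np.sqrt(q.shape[0])   # scores against all keys at once
    p = np.exp(s - s.max())
    return (p / p.sum()) @ V

def attn_online(q, K, V, block=4):
    """Flash-style streaming attention: process keys block by block,
    keeping a running max (m), softmax denominator (l), and output (acc)."""
    d = q.shape[0]
    m, l = -np.inf, 0.0
    acc = np.zeros(V.shape[1])
    for i in range(0, K.shape[0], block):
        s = K[i:i + block] @ q / np.sqrt(d)
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)     # rescale previously accumulated partials
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ V[i:i + block]
        m = m_new
    return acc / l
```

The two return values agree up to floating-point rounding, which is exactly the kind of tiny, compounding divergence the table reflects.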
Dependency versions
Comparing the two environments, I resumed training of neox-ckpt-pythia-2.8b-deduped-v1 from checkpoint 0 to 128 on both the standard and deduped Pile, and compared the results with the step-128 checkpoint on HuggingFace. The gap between CUDA/PyTorch versions is relatively small. Since Stella confirmed the PyTorch 1.10 + CUDA 11.1 setup was used for the original Pythia training, all other experiments used it.
| PyTorch | CUDA | Std P.Cos | Ded P.Cos | Std L2 | Ded L2 |
|---|---|---|---|---|---|
| 1.10+cu111 | 11.1 | 0.999992 | 0.999990 | 2.72 | 3.00 |
| 1.13+cu117 | 11.7 | 0.999991 | 0.999990 | 2.73 | 2.98 |
Pythia 70M control
I also resumed training of the deduped 70m Pythia from checkpoint 0 in the same environment on both Piles. Since we have no reason to suspect that model was trained on the wrong data, I wanted to see whether our chosen metrics reproduce a consistent gap in the expected direction. I trained the 70m up to 1k steps and saw that deduped is closer at every step, with the gap growing over training:
| Step | Deduped P.Cos | Standard P.Cos | Deduped L2 | Standard L2 |
|---|---|---|---|---|
| 2 | 1.000000 | 1.000000 | 0.00 | 0.00 |
| 8 | 1.000000 | 1.000000 | 0.06 | 0.06 |
| 32 | 0.999999 | 0.999998 | 0.41 | 0.44 |
| 128 | 0.999869 | 0.999778 | 3.89 | 5.06 |
| 512 | 0.989679 | 0.987252 | 38.11 | 42.33 |
| 1000 | 0.932638 | 0.922152 | 113.34 | 121.67 |
Data shuffling
GPT-NeoX has two levels of shuffling, document-level and sample-level, both of which are on by default with seed 1234. Eleuther also provides both a regular Pile and a “preshuffled” Pile; the latter is tokenized into a single continuous stream with no document boundaries (doc_count=1), so the document-level shuffle is a no-op on that Pile.
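A rough sketch of what the two levels do (a simplification with hypothetical names, not NeoX’s actual index-building code): document order is permuted first, then the order in which fixed-length samples are drawn. With doc_count=1 there is nothing for the first permutation to permute.

```python
import numpy as np

def build_order(num_docs, num_samples, seed=1234,
                doc_shuffle=True, sample_shuffle=True):
    """Simplified two-level shuffle: permute document order, then permute
    the order in which fixed-length training samples are drawn."""
    rng = np.random.RandomState(seed)
    doc_idx = np.arange(num_docs)
    if doc_shuffle:
        rng.shuffle(doc_idx)        # no-op when num_docs == 1 (preshuffled Pile)
    sample_idx = np.arange(num_samples)
    if sample_shuffle:
        rng.shuffle(sample_idx)
    return doc_idx, sample_idx
```

Because the RNG is seeded, the ordering is deterministic, which is what makes a controlled comparison of shuffle configurations possible at all.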
I tested different shuffle configurations at 70m up to step 128: the standard (non-preshuffled) Pile, the preshuffled Pile with sample-shuffling, and the preshuffled Pile without sample-shuffling:
| Variant | P.Cos | L2 |
|---|---|---|
| Preshuffled deduped, sample-shuffling | 0.999869 | 3.89 |
| Preshuffled deduped, no sample-shuffling | 0.999866 | 3.93 |
| Non-preshuffled deduped, document-shuffling + sample-shuffling | 0.999771 | 5.14 |
The preshuffled Pile with and without sample-shuffling had almost the same P.Cos, so I manually checked the first few tokens to confirm the two runs actually saw different data. Using the preshuffled Pile with default shuffling got us slightly closer to HuggingFace checkpoint 128, but the margin is insignificant.
Versus gen 1/2
Since the first two generations of the 2.8b Pythias appear unproblematic and come from a distinct training run from the latter three, I also compared late checkpoints of the suspect 2.8b deduped against both the standard and deduped versions of the earlier generations. While this is an extremely noisy signal, it too showed that the 2.8b deduped is closer to the first two generations’ standard model.
| Step | 2.8b-v0 vs 2.8b-deduped (final) | 2.8b-deduped-v0 vs 2.8b-deduped (final) |
|---|---|---|
| 1000 | 0.994286 | 0.992782 |
| 16000 | 0.499541 | 0.476072 |
| 64000 | 0.329775 | 0.306820 |
| 143000 | 0.456450 | 0.417573 |
Conclusions
While I was not able to reproduce Pythia training bit-exactly with the resources I had, in every set of experiments the 2.8b-deduped trained on the standard Pile was consistently closer to the deduped HuggingFace checkpoints than the one trained on the deduped Pile. By contrast, the deduped 70m Pythia showed the opposite behaviour (the deduped checkpoints on HuggingFace are closer to those we trained on the deduped Pile). Additionally, the 2.8b-deduped is closer at all checkpoints to the standard v0 than to the deduped v0.
I’d say this is convincing but not conclusive evidence that the 2.8b-deduped pythia was trained on the standard rather than deduped Pile.
P.S. the issues with corrupt files are now being fixed by Eleuther.
P.P.S. thanks to Lucia and Stella from Eleuther for helping with this work.
