Writing an LLM from scratch, part 32f -- Interventions: weight decay
I'm still working on improving the test loss for a from-scratch GPT-2 small base model, trained on code based on Sebastian Raschka's book "Build a Large Language Model (from Scratch)". In my training code, there is a single line that creates the optimiser. In my last post I looked into the learning rate, the lr parameter in that code, and found a better value for it, plus some extra code to schedule it -- that is, to vary it over time -- which gave better training results. This time I want to dig into the weight decay: what is it, what is it for, and is 0.1 really the best value?

I was a little concerned going into this that, in order to understand this hyperparameter, I'd need a good understanding of how the optimiser works. I've been building what I think is a solid mental model of optimisers, but I don't understand them well enough to explain them yet, and I've been hoping to defer them to a separate blog post series after this one. The good news is that while weight decay is an important aspect of how optimisers work -- the "W" in AdamW, the thing that makes it different to the older Adam optimiser, is a nod to its different treatment of weight decay -- you don't need to know how the optimiser itself works to understand what weight decay is. Instead, you just need to consider an older and more fundamental aspect of building ML systems: regularisation.

In order to dig into that, let's start with overfitting. Let's imagine a simple classification task: we want to build a model that can -- for any point on a chart -- predict whether a cross or a circle should go there, training it using the sample data points that we already have. Let's say that we train a powerful model on this dataset, and it comes up with a wiggly boundary that perfectly separates the training points. Now, ab initio we don't know whether that's a good result or not; we need to use our validation set to evaluate it. Plotting the validation points on the same chart, we can see that it looks like our powerful model has overfit.
The training set is all nicely split by the boundary, but the validation points are not. A common solution to that kind of issue that you might see in introductory ML courses is to try a less powerful model. A less powerful model in this case might come up with a less "wiggly" line to separate the two categories, perhaps because it doesn't have enough parameters to make it wiggle so much, and so produce a smoother, simpler boundary. So: we use our validation set to detect overfitting, and we can adjust the complexity of our model to try to avoid it.

Now, this is all very well, but it does require manual intervention. We had to do a training run, identify that we were overfitting, and then decide on parameters for the new, simpler model (how many parameters should it have?). We could, perhaps, have gone too far the other way, wound up with a boundary too simple to capture the real pattern -- and underfit. There's no way of knowing, when we start out, what the right number of parameters is, so we need to try various values and then work out the optimum balance.

Regularisation techniques are designed to automate this -- to prevent overfitting without all that tedious mucking about with the model. We've already looked at Dropout, which is one of the standard ways to do that. Although my own mental model of what it does goes some way beyond just helping to prevent overfitting, I may well be wrong -- and given that our LLM train never sees the same training data twice, being a single-epoch run, removing it turned out to improve our model. Another technique is just stopping the training run when you see the validation loss start to rise, also known as "early stopping". That's such an obvious thing to do that I came up with it independently back when I was doing my early experiments with fine-tuning.
Now, we don't have a separate validation set for these training runs, but because we're doing a single epoch, the training data the model sees is just as "new to it" as a held-back validation set would be. So we could use a similar trick and treat "train loss starts rising" as a reason to stop the train early, in place of the validation loss rising. It's not exactly the same thing, but perhaps it would be close enough. In all of the trains in this series, though, that's never happened -- while sometimes the train loss blips up for a bit, in the longer term it keeps going down.

But there are other techniques that rely on a neat trick. Let's think back to the manual, boring way of finding how many parameters are appropriate for a modelling task. We tried one number and found that it overfit; we might then try a lower one and find that it underfit, then try something in the middle and find that it's better but still not perfect one way or the other, and rinse and repeat until we find something we're happy with. This kind of searching through a solution space to find an optimum is exactly what we're doing when training a model, so it would be really nice to automate it in the same way. The trick is this: if we want to minimise the complexity of our model so that it doesn't overfit, we can add a measure of the model's complexity to the loss function -- and then our normal process of gradient descent will try to minimise that, just like it tries to minimise the loss from the training results themselves.

And that brings us on to weight decay. Regularisation by weight decay starts off with the hypothesis that the "size" of all of the model's weights, taken together, is a measure of the model's complexity: if the model's weights are small, then it's a simpler model than if they're large.¹ The "size" in this sense is the square of the L2 norm -- something we came across in gradient clipping.
The L2 norm is basically all of the weights squared, added together, with the resulting sum square-rooted:

N = √(w₀² + w₁² + w₂² + ...)

You can think of it as the length of the vector that the weights represent -- that is, for our 163M-parameter model, it would be the length of the model's weights considered as a vector in 163-million-dimensional space.² And by using its square, we get something that penalises larger values more (and we also save the time spent calculating a square root).

To me, it's not intuitively obvious that that measure really does express the complexity of the model in any clear sense. After all, you'd think that doubling all parameters would leave the model no more complex than it was before, but it would double the L2 norm.³ But I imagine there is solid maths behind it to say that it does work in a more general way, so in the interests of not disappearing down a mathematical rabbit hole at this stage, I'll take it as given.

So: we're using the squared L2 norm as a measure of model complexity, and we're going to add it on to the training loss as a way to try to minimise both. The next question is: how do we balance the two -- the training loss and the model complexity penalty? This is, in a somewhat hand-wavy way, similar to the decision of how much of the current loss function's gradient to use when adjusting the weights. For that, we use η, the learning rate, to scale the gradients before applying them:

w ← w − η · ∂ℒ/∂w

The balance between the "real" loss and the model complexity penalty is handled in a similar way: we have a number, the weight decay, normally represented by a lower-case lambda, λ, and we multiply the squared L2 norm by it, something like this:

ℒ′ = ℒ + λ · N²

...where I'm using ℒ for the normal loss on the training inputs vs the targets, N² for the squared L2 norm of the weights, and ℒ′ for the combined loss. And ℒ′ is what we -- in theory -- actually try to minimise using our optimiser.
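To make that concrete, here's a minimal PyTorch sketch of computing ℒ′ = ℒ + λ · N² by hand. This isn't the actual training code from this series -- the model, variable names, and helper function are mine, purely for illustration:

```python
import torch

model = torch.nn.Linear(4, 1)  # tiny stand-in for the real 163M-parameter model
lam = 0.1  # the weight decay, lambda


def squared_l2_norm(model):
    # N^2: every weight squared and summed -- the squared length of the
    # model's weights considered as one long vector.
    return sum(p.pow(2).sum() for p in model.parameters())


# The "real" loss on inputs vs targets...
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(inputs), targets)

# ...plus the complexity penalty: L' = L + lambda * N^2.
combined_loss = loss + lam * squared_l2_norm(model)
```

Calling `combined_loss.backward()` on this would then push the gradients of both terms through the weights -- though as we'll see in a moment, in practice the optimiser can apply the penalty much more cheaply than this.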
But there's actually a neat simplification that we can apply to make this even easier. Firstly, let's make one small change to the equation above: we'll halve the squared L2 norm before multiplying it by λ:

ℒ′ = ℒ + λ · N²/2

That obviously doesn't change the underlying maths; it just means that we'd need to use larger values for λ to get the same effect. You'll see why it's useful in a bit.

Now let's think about normal gradient descent. Again, we work out the gradient of the loss function for each weight, and subtract that times the learning rate η from the weight's value to update it:

w ← w − η · ∂ℒ/∂w

The gradient of the loss function for the weight is its partial derivative with respect to that weight, so we can write the above like this for the version of the loss function including weight decay, ℒ′:

w ← w − η · ∂ℒ′/∂w

Now, we defined ℒ′ above as ℒ + λ · N²/2, so we can substitute that in:

w ← w − η · ∂(ℒ + λ · N²/2)/∂w

Next, let's think about that L2 norm, N. It's the square root of the sum of all of the weights squared, or equivalently we can square it (like we do in the formula above) and say:

N² = w₀² + w₁² + w₂² + ...

Let's drop that in:

w ← w − η · ∂(ℒ + λ · (w₀² + w₁² + ...)/2)/∂w

Now, the derivative of a bunch of things added together is just each of them differentiated separately and then added together. Let's apply that to the two terms in the brackets:

w ← w − η · (∂ℒ/∂w + ∂(λ · (w₀² + w₁² + ...)/2)/∂w)

...and now pull the constant λ and the 2 out of the second partial derivative:

w ← w − η · (∂ℒ/∂w + (λ/2) · ∂(w₀² + w₁² + ...)/∂w)

Then we apply the rule for the derivative of a bunch of things added together again:

w ← w − η · (∂ℒ/∂w + (λ/2) · (∂w₀²/∂w + ∂w₁²/∂w + ...))

Now, we're taking a partial derivative with respect to one specific weight, w, which is one of the w₀, w₁, and so on in there. From that perspective, all of the other weights are constants -- which means that their derivatives with respect to w are zero. So we can get rid of all of them apart from the one that actually is w, and we wind up with this:

w ← w − η · (∂ℒ/∂w + (λ/2) · ∂w²/∂w)

The derivative of w² with respect to w is just 2w.
Thanks to that crafty halving of the N² earlier, that means we can go to this:

w ← w − η · (∂ℒ/∂w + λ · w)

Multiplying the −η across the bracketed terms, we get:

w ← w − η · ∂ℒ/∂w − η · λ · w

That's exactly the same as the normal gradient descent update using the unmodified loss function without weight decay -- except that we're additionally subtracting the weight's original value, scaled down by both the learning rate η and the weight decay value λ. Much simpler :-)

(As an aside: the description above is correct for "traditional" simple gradient descent and -- loosely -- for Adam, but AdamW's trick is to do things somewhat differently. That's something I'll go into in more detail when I get round to writing my post on optimisers.)

So: weight decay is a regularisation technique that tries to prevent our model from getting any more complex than it needs to be. We have one number, λ, which determines how much to weight complexity against the normal training loss. And, as we can see from the optimiser-creation code, right now we're setting λ to 0.1.

Is that the right value? As usual, the GPT-2 paper is light on the details of the hyperparameters they used, but nostalgebraist wrote a really nice post on Tumblr where they dug into what the number might have been. As they say:

It does say it follows the first GPT paper in most respects, and that paper used weight decay of 0.01.

Their link for that paper appears to be mistaken, as it points at a different (albeit very interesting) paper from 2020, a year after the GPT-2 one, but I believe this is the paper normally called the GPT-1 one. They do indeed use 0.01 there:

We also employed a modified version of L2 regularization proposed in [37], with w = 0.01 on all non bias or gain weights.
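Going back to the update rule derived above for a moment: in code, for a single weight, it's just this (a pure-Python sketch; the function name and the grad argument are mine, with grad standing for ∂ℒ/∂w):

```python
def weight_update(w, grad, eta=0.0004, lam=0.1):
    # w <- w - eta * dL/dw - eta * lam * w:
    # the ordinary gradient-descent step, plus an extra term that shrinks
    # the weight towards zero a little on every update -- hence "decay".
    return w - eta * grad - eta * lam * w
```

Note that with a gradient of zero, the weight is still multiplied by (1 − η·λ) each step, which is where the name "weight decay" comes from.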
The link to the GPT-3 paper looks right, though, and as they say, it uses a weight decay of 0.1:

All models use weight decay of 0.1 to provide a small amount of regularization

They then do a bit of maths to work out whether the GPT-2 weights are likely to have been regularised by something like weight decay, and come to the conclusion that they probably used 0.01, just like the GPT-1 paper. It seems plausible, though of course not certain. But: tentatively, GPT-2 used 0.01, while we're using 0.1, perhaps because the GPT-3 paper does.

What other data points do we have? The Hugging Face "Smol training playbook" has some interesting stuff (including not using weight decay on embeddings, which they say they found helped), but the value they use is 0.1, which they call "a very vanilla setting". And:

Interestingly, over the last few years the AdamW hyperparameters have barely moved: The same triplet is reused in Llama 1, 2, and 3 and DeepSeek-V1, V2, and V3-671B, with no changes.

Anyway, assuming they're right about the weight decay values for the models they mention (and I assume they've done the research -- I had the link to the DeepSeek paper to hand, and that one certainly says 0.1), it looks like 0.1 is pretty much standard these days. As a quick double-check of what a typical value would be, I asked ChatGPT, Claude, Gemini and Grok; they all recommend 0.1 as a solid, sensible default with AdamW (though they all also say that values between 0.01 and 0.1 are reasonable). So on that basis, I think we can say that 0.1 is a reasonable default, and has pretty much become the standard, but it might be worth trying 0.01 just to see if it helps with tiny models like ours.

Are there any dissenting voices to the 0.1 orthodoxy? I came across a paper from a team at Cerebras Systems, "Power Lines: Scaling Laws for Weight Decay and Batch Size in LLM Pre-training".
It's essentially a Chinchilla-like attempt to derive scaling laws, but rather than looking just at optimal tokens per parameter in order to work out what you should scale up when adding more compute, they're trying to find optimal batch sizes and values for weight decay. That's certainly relevant to our interests :-) However, it is very dense and in-depth, and fully understanding it at this stage would need quite a lot of work -- very much a side quest. It's definitely something to come back to later, but for now I'll just try to extract the stuff we need.

Let's start off with the optimal batch size, as they have it right there on the first page. We're not going to use it, but it will be interesting to compare it with what we're using, and with what the DeepSeek paper that I looked at in the last post suggested. They fit a formula for B_opt as a function of D, the total number of tokens you're training on. That's quite different to the formula in the DeepSeek paper, which was a function of C, the number of FLOPs.⁴ C scales up linearly with the number of tokens D, but also with the number of parameters in the model N, so you can see the DeepSeek formula as a function of N and D -- as your model gets bigger, so does B_opt -- whereas this Cerebras paper is saying that it's just a function of D, unaffected by model size. They did train over a number of different sizes (from 111M parameters up to 1.7B) and their formula seems to hold, so it's not just that they didn't treat model size as relevant.

Well, let's see what their formula comes up with. We have 3,260,252,160 tokens in our train, and plugging that into their formula for B_opt gives a result much closer to the 97-or-so sequences that appeared to be optimal when I did some rough-and-ready curve-fitting than the 373 that the DeepSeek formula gave for our setup :-)

OK, so what about the weight decay? They don't give a direct formula for that, but they do give a formula for the optimal τ, the AdamW timescale.
Without going into exactly what that means right now (that's one for my optimisers post later), they relate it to other numbers that we do know, with a formula along these lines:

τ = B / (η · λ · D)

...where B is the batch size, D is the amount of data, and of course λ and η are the weight decay and learning rate respectively. So if we know the optimal τ, we can work out the optimal λ for our training run; solving for λ, we get:

λ_opt = B / (η · τ_opt · D)

So let's work out τ_opt. Their fitted formula for it is in terms of TPP, tokens-per-parameter, so for us, with our Chinchilla-optimal TPP of 20, it gives a specific value we can plug in. Now, we're using a batch size B of 96, and (as before) D is 3,260,252,160. Our learning rate η is 0.0004 for this train -- remember, although in the last post we found that a scheduled learning rate with a peak at 0.0014 was better, in this post we're testing changing the weight decay in isolation.⁵

Before we do the calculation: having a batch size and a number of tokens in the same formula feels like a unit mismatch. In particular, as part of the explanation of that formula, they tie it back to a value S, the total number of optimisation steps, which they define as D/B. For that to work, either both need to be in terms of tokens, or both need to be in terms of sequences. They clearly say that "B is reported in units of sequences". I'm not sure how to explain this, except by saying that perhaps D is also meant to be in terms of sequences, even though I'm pretty sure it's meant to be in terms of tokens in the equation for the batch size.⁶ Well, let's assume that is the case, and plug in numbers for sequences. We have 3,260,252,160 training tokens split into 1,024-token sequences, which is 3,183,840 sequences. (Note that we'd get the same answer if we plugged in numbers for tokens in both places, as that would just multiply the top and the bottom by 1,024.) The λ_opt comes out as 0.33724. Wow!
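The calculation above can be sketched in a few lines of Python. Two caveats: the formula λ_opt = B/(η · τ_opt · D) is my reading of the relationships described in this post, and the τ_opt value of roughly 0.224 is backed out from the numbers here rather than taken from the paper's fitted formula:

```python
def lambda_opt(batch_size, eta, tau_opt, data_size):
    # lambda_opt = B / (eta * tau_opt * D), with B and D both in sequences.
    return batch_size / (eta * tau_opt * data_size)


# Our run: B = 96 sequences, D = 3,183,840 sequences, eta = 0.0004,
# and tau_opt roughly 0.224 (an assumption, inferred from the result below).
lam = lambda_opt(96, 0.0004, 0.224, 3_183_840)  # roughly 0.337
```

One nice property of writing it out like this: you can see immediately that λ_opt scales inversely with η, which will matter later in the post when the learning rate changes.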
That's even higher than the "traditional" 0.1, never mind the 0.01 that is our best guess for GPT-2. Even if I'm missing something here (I certainly can't say I've read the paper in as much detail as it deserves), it gives us a nice number to try out as an experiment. We already have a loss on our test set for a model trained with a weight decay of 0.1, as that was what we used in our baseline train. So it looks like it's worth doing two more runs: one with the GPT-2 estimate of 0.01, and one with this Cerebras-inspired 0.33724, neatly bracketing the baseline. Let's give them a go!

Firstly, the training run with λ = 0.01. It was a nice smooth train -- one small loss spike near the start, but it quickly recovered -- and it ended with a not-bad final train loss (which does tend to indicate a good model). Let's look at the evals. Firstly, the smoke test -- how would it complete "Every effort moves you"? Passably coherent. And the loss it gets on our test set? Not bad at all! Time to upload it to Hugging Face and add it to the table so that we can compare it to the other interventions we've tried so far. So: it's better than gradient clipping and the QKV bias, but slightly worse than removing dropout, and much worse than scheduling (and increasing) the learning rate.

Now, that suggests to me that the much-higher Cerebras-inspired weight decay will be worse. My logic is this: if both decreasing and increasing the weight decay improved the loss, that would suggest that we have an inverted-U loss curve for weight decay, with 0.1 sitting at the peak. And it seems vanishingly unlikely that the downward trends on either side of such a curve would continue forever, so that you could get arbitrarily low loss by increasing or decreasing the weight decay even more.
So the curve would perhaps look more like a W shape, with a minimum on either side of 0.1. My intuition is that having multiple minima -- especially ones that just happen to sit on either side of the "standard" value for weight decay -- seems less likely than the alternative: that the higher number will be worse because we're actually on a simple U-shaped curve. Of course, my intuition could be completely off on this, and it's definitely still worth doing the test!

Looking at the loss chart for the λ = 0.33724 run, you can see right away that it was a much choppier train, with quite a few loss spikes, some quite late on. The output at the end reflected this, with a significantly worse train loss. Still, we should do the evals. The smoke test was not too bad, but the loss test is the important one -- and it was terrible! This is our first test set loss for an intervention that is actually worse than the baseline. Much worse.

However, at this point I started wondering. When I was looking at the learning rate, the number I selected based on the DeepSeek paper worked well with learning rate scheduling, but failed to converge without it. The weight decay number is multiplied by the current learning rate before it's used to reduce the weights' values, so it will be affected by both scheduling and η. It seemed likely that Cerebras used a learning rate schedule, and double-checking the paper:

We present results with a single (standard) learning rate schedule ... For a given TPP, all models have the exact same warmup phase: a linear warmup of the learning rate from 0 to the maximum value. ... We use the µP-tuned and adjusted peak η, for 111M models. The learning rate increases linearly to the peak for the first 10% of steps, then decreases from the peak to 0 for the remainder of steps.

Seems pretty certain.
Now, I've been following a fairly strict rule of testing interventions in isolation; however, the learning rate and the weight decay parameters are so intertwined that perhaps that's just not reasonable here. I decided to do two more trains, both with learning rate scheduling. I'd use the same schedule as in the last blog post -- a warmup from pretty much zero to the peak over the first 10% of the run, followed by a cosine decay to 10% of the peak. In the first, I'd use the same learning rate as our baseline model, 0.0004; in the second, I'd use the one we got from the DeepSeek paper, which did really well when scheduled: 0.0014.

The first was less choppy, at least -- the scheduling calmed down the later parts of the run, as you'd expect given that the learning rate was dropping -- but the training loss at the end was still on the high side. The smoke test was not too bad, but the test set loss was unfortunately still worse than the baseline of 3.692, albeit better than the run without learning rate scheduling. I'm not going to add it to the table, as this was more in the way of an exploratory training run.

Let's see how we do with the larger, DeepSeek-suggested learning rate. For this one, I kept the weight decay at 0.33724. (This was an error, as I realised later -- more on that shortly.) Ouch: a super-choppy loss curve, and a terrible training loss at the end of the run. The smoke test gave something not too bad, but the test set loss was still pretty terrible (though a tad better than the run without the learning rate scheduling). Another one to throw away, I think.
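For reference, the warmup-plus-cosine-decay schedule described above can be sketched like this -- a minimal stand-alone version, not the actual training code from this series, with the function name and parameters being my own:

```python
import math


def scheduled_lr(step, total_steps, peak_lr=0.0014,
                 warmup_frac=0.1, floor_frac=0.1):
    # Linear warmup from (nearly) zero up to peak_lr over the first 10% of
    # steps, then a cosine decay down to 10% of the peak.
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    floor = peak_lr * floor_frac
    return floor + (peak_lr - floor) * 0.5 * (1 + math.cos(math.pi * progress))
```

In real training code you'd more likely wire this up via something like PyTorch's LambdaLR scheduler rather than calling it by hand each step.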
But then something occurred to me: the formula to go from the optimal AdamW time horizon τ_opt to the optimal weight decay λ_opt -- λ_opt = B / (η · τ_opt · D) -- has the learning rate η in it. I even made a footnote saying that I was going to have to remember to recalculate the weight decay value when that changed :-S Luckily, though, running the real numbers through it with η = 0.0014 gives a value almost exactly the same as the 0.1 that we've been using for all of our other experiments. So that actually suggests that the Cerebras equations come up with a reasonably usable number for weight decay if you use the DeepSeek-optimal level for the learning rate and schedule it in a normal warmup/cosine-decay manner. But it's still not as good -- for this model -- as using the GPT-2 number.⁷

With that, I think it's time to wrap this intervention up! Let's look at our results table again. We've found that reducing the weight decay from the now-standard 0.1 to a GPT-2-inspired 0.01 improves the loss our model gets on the test set; it's the third-best intervention so far, after getting rid of dropout and updating our learning rate -- and the difference between it and the dropout intervention is pretty small. It did surprise me that the Cerebras-inspired number did so badly, though.

I think that for now I should not head any further down this rabbit hole, and should just take the win: we have a weight decay parameter that works better than the one we had, and that's something that can go into our set of working interventions. I can revisit the Cerebras paper later, when I've spent more time studying optimisers. As to why this old-fashioned GPT-2 value might work better than the current default of 0.1: I think that could plausibly be down to scale. The 0.1 value appears to come from the GPT-3 paper, which was essentially an experiment in scaling up GPT-2. Perhaps larger models need larger weight decays? The model we're working with here is really small, at 163M parameters.
So, that's weight decay done! Of the list of planned interventions I wanted to try, only training in full-fat 32 bits (rather than AMP) and weight-tying remain. I think I'll look into the second of those next. Stay tuned! Here's a link to the next post in this series.

1. More precisely, from Deep Learning: "Minimizing J(w) results in a choice of weights that make a tradeoff between fitting the training data and being small. This gives us solutions that have a smaller slope, or that put weight on fewer of the features." ...where J(w) is the loss function we're trying to minimise in our training run, combining the "real" loss and a measure of the model's size. ↩
2. I can't decide whether that makes it easier or harder to understand ;-) ↩
3. Wild speculation: how about something using the Shannon entropy of the weights...? ↩
4. Specifically the non-embedding training FLOPs. ↩
5. Note to self: don't forget to adjust it if we do decide to combine this with the learning rate update. Also: I'm pretty sure from reading the paper that the η they're using in these formulae is the peak -- they certainly are using learning rate scheduling, albeit with a decay-to-zero rather than the decay-to-10% we used. ↩
6. Plugging the number of sequences into the batch size formula gives us an optimal value of 9.47, which definitely doesn't look right based on the trains I've done. ↩
7. Assuming that the GPT-2 value for weight decay "stacks up" well with the learning rate update and the scheduling from the last post. There may be some useful tests to do when we try to put this all together. ↩

The "triplet" of AdamW hyperparameters quoted from the Smol training playbook:

β1 = 0.9, β2 = 0.95
Grad norm clipping = 1.0
Weight decay = 0.1 (Llama 3 405B drops this to 0.01)

And a quick recap of the Cerebras-inspired λ = 0.33724 runs:

With our too-low learning rate of 0.0004, it performed terribly.
When we added scheduling, it was a bit better but still not great.
When we used a DeepSeek-optimal learning rate (and actually did the right calculations to get the real value for weight decay based on that), we got a number very close to our baseline train, which seems very unlikely on the face of it to have produced a significantly different test set loss.