Writing an LLM from scratch, part 32b -- Interventions: gradient clipping
I'm still working on training the best GPT-2 small sized base model that I can with a number of FLOPs roughly equal to two days on my own machine -- my "extra credit" exercise after having worked through Sebastian Raschka's book "Build a Large Language Model (from Scratch)". In the last post I trained a baseline model -- one with the same architecture and almost the same training code as in the minimal training run in the book, just modified to run using DDP on an 8x A100 40 GiB/GPU machine in the cloud. There are a bunch of "interventions" I want to try to see if they'll make it better, as measured by the loss they get on a test set. I'll do a post for each intervention, and this is the first: gradient clipping.

In the training chart for the baseline model, you can see that there are three places where the loss suddenly spiked up, at around global steps 4,200, 13,000, and 23,000:

There are a number of things that could cause loss spikes like that:

- A "bad batch" -- that is, one batch, or even one sequence in a batch, was massively different in structure to the others that the model had seen, so it just had much worse loss. That doesn't seem likely in this case, though: the numbers on the chart are averages over 617 global steps each, and it would take a truly pathological sequence to move the needle that much.
- Something weird in the optimiser. That's not something I understand well, but according to the various LLMs I'm working with, it's a possibility.
- Exploding gradients. This is my working hypothesis, and so in this post I'll try out gradient clipping, the normal solution to that problem.

Exploding gradients are common in RNNs, and also happen in LLMs like this one. I spent a bit of time reading around to find out how they happen, and the ah-ha moment came when I came across this post from Wanshun Wong. Not only is the post itself a good intro in terms of how they affect RNNs, but in the "further reading" at the end, there's some gold:

Chapter 10.11 of [1] has a good overview of how gradient clipping works.

Now, I bought a copy of "Deep Learning" at the same time as I bought Raschka's book, but I'd only glanced through it. Now was the time to get it down from the shelf -- and, indeed, section 10.11.1 is all about clipping to handle exploding gradients.

I'll put the explanation of how exploding gradients happen into my own words, to see if I can clarify things (at least in my mind). Normally, when we learn about gradient descent, it's illustrated with nice smooth loss charts like this imaginary one for a single-parameter model:

We're told that we might start at point A. The gradient is quite high and negative, so we multiply it by our learning rate and subtract it from our parameter. That gets us to point B. This time around, the gradient is smaller as the curve is flatter there, so when we do the same -- multiply by the LR and subtract -- we take a smaller step, and wind up at C. Rinse and repeat and we'll wind up near the minimum.

The problem is, what if the loss curve actually looks like this:

We start at A, with a small gradient, move a little to the right, and now we're at B, halfway down a cliff! The gradient is massive, and when we subtract it, even scaled by the learning rate, we can zoom off somewhere to the right -- maybe not even on the chart. Indeed, you can imagine a cliff that is so steep that it has vertical portions -- negative infinite gradients in this case -- and no matter what your learning rate is, you'll wind up with an infinite parameter update and everything will break. It's hard to see how a model can continue training in a case like that.

Now, what can cause steep cliffs like that? The book says "strongly nonlinear functions, such as those computed by a recurrent neural net over many time steps". If you know about RNNs (I wrote about them if you'd like a summary), you'll remember that a single RNN might be quite shallow -- maybe three or four layers -- but when you're doing backpropagation, you run a number of inputs through, one after the other, work out the overall loss, and then "unroll" it to something similar to a "vanilla" neural net to do the backward pass.
To put that in concrete terms, a 3-layer neural network trained on a 100-element sequence would unroll to a 300-layer deep network. Every one of those layers has several operations, including (in the implementation I was looking at in my post above) a tanh. It's not surprising that there are cliffs in the loss landscape -- it's more surprising that there are any smooth bits!

Now in LLMs, we don't have that unrolling through time -- but our network is deep enough as it is. For the GPT-2 small model, disregarding the embeddings and the final output head, we have 12 Transformer layers, each of which is multiple matrix multiplications for attention, then a softmax, then another layer, and then a feed-forward... mapping precisely to an equivalent vanilla NN is hard, but I think you can treat each one as at least four layers, so we've got 48. And there are GELUs and logs and exps¹ dotted around, so again -- we should expect cliffs.

So if sometimes we'll get crazy gradients, what can we do about them? We clip them. Clipping gradients simply means that if they get larger than a particular number -- v, which we define -- we reduce them to that number. In other words, we have a cap on how big they can get.

"Deep Learning" ("DL" from now on) suggests two ways to do it. Remember that while in the example above we only had one parameter -- on the X axis -- for the GPT-2 small LLM we're training, we have 163 million of them. So the gradient, instead of being one number, will be a 163M-long vector, one element per parameter. The two ways to clip are:

- We clip element-wise: if any one of the gradients in the vector is larger than v, we reduce it to v.
- We clip based on the norm: the length of the gradient vector in -- in our case -- 163M-dimensional space. That sounds harder than it is -- it's really just an extension of the Pythagorean equation a² + b² = c² to multiple dimensions. If you want to work out the length of a vector (a, b), you can use Pythagoras to get c = √(a² + b²), and that generalises to any number of dimensions. So for our model we'd just square all 163M elements of the vector, sum them, and take the square root of the result, and that's the norm.² If the norm is greater than v, we divide every element of the gradient vector by the norm and multiply the result by v, producing a new gradient vector whose norm is exactly v.

The second feels more elegant -- we're scaling all of the elements of the gradient vector by the same amount, so it still points in the same direction. Interestingly, though, DL says that the two methods "work similarly", which I'll read as "are pretty much the same in practice".

DL then goes on to say how infinite or not-a-number gradients should be handled. With the first way, clearly doing it naively would set every element in the gradient vector to v, which would make the total size (norm) of the update very large. With the second, it would be even worse -- we'd still wind up with completely junk gradients, because the norm would be infinite, and infinity divided by infinity in Python is NaN, so we'd be applying gradients with NaNs in them at best. That would be likely to knock our model into unrecoverable territory, as any parameter that had that applied to it would be NaN forever. Their suggested solution is that if you get garbage gradients like that, you can take a random step -- that is, create a new gradient to apply that has the norm v but just points in a random direction. The idea is that this will move you away from the cliff-ridden part of the loss landscape where you've found yourself (more about that later), and things will continue nicely.

So, anyway, how to do this in practice? PyTorch has a function, torch.nn.utils.clip_grad_norm_, and that's what's referenced in almost every bit of writing I've found about how to clip gradients. So I decided to use that, assuming it would do what was described in DL's second option and that it would do the random updates they suggest for non-finite gradients. (I was half-correct -- see later.)

As to how to use it: if we had a normal training loop, with a normal optimiser, we would just add a call to clip_grad_norm_ between the backward pass and the optimiser step, passing it the model's parameters and a max_norm argument -- the value v from above. However, for our training code using Automatic Mixed Precision (AMP), it's a little more complicated -- but luckily, the AMP explainer we've been using has a section explaining what to do; both loops are sketched below.
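Here's roughly what both loops look like -- a minimal sketch rather than my actual training code, with placeholder names of my own (model, optimizer, scaler, inputs, targets, max_grad_norm), and assuming a GPT-style model whose logits can go straight into a cross-entropy loss:

```python
import torch

# Plain training loop: clipping goes between the backward pass and the optimiser step.
def plain_step(model, optimizer, inputs, targets, max_grad_norm):
    optimizer.zero_grad(set_to_none=True)
    logits = model(inputs)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), targets.flatten())
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    optimizer.step()
    return loss.detach()

# AMP training loop: unscale the gradients first, clip them, then step the
# optimiser via the scaler.
def amp_step(model, optimizer, scaler, inputs, targets, max_grad_norm):
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logits = model(inputs)
        loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), targets.flatten())
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # bring the gradients back to their true, unscaled values
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    scaler.step(optimizer)      # steps the optimiser with the already-unscaled gradients
    scaler.update()
    return loss.detach(), grad_norm  # clip_grad_norm_ also returns the total norm -- useful later
```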
Right now we have the AMP loop above, minus the unscale_ and clip_grad_norm_ lines; per that explainer, we need to add them in exactly that position -- "unscale" the gradients, then clip them, then use the scaler to step the optimiser. That looks a bit weird; you'd think that you'd need to "re-scale" the gradients after clipping them, to get back to where you started from before the optimiser step. From the help page I gather that the scaler keeps track of whether or not the gradients it has right now are currently scaled, and handles them appropriately based on that state in scaler.step.

Anyway, given that we know what the code looks like now, we need to implement it in a way that can be easily switched on for this experiment (and potentially in the future), but which also allows us to not use it if we don't want to. The best way with our setup is to make it a training option, with the maximum norm extracted at the point where we call the training function; and we can simply switch it off in the function we use to find the maximum micro-batch size for our current hardware, as all we're testing for there is memory usage -- we don't care if we're doing good updates. Here's the code delta for that, plus a bugfix to handle files that don't have the new setting in them.

But it would also be useful to be able to track when clipping "fired" -- that is, when we had to clip our gradients. Then we can see two things:

- Whether we actually did wind up clipping them and fixing those loss spikes
- Whether we were clipping at other times -- we don't want to be doing it unnecessarily.

Now, the docs for clip_grad_norm_ say that it returns the "[t]otal norm of the parameter gradients (viewed as a single vector)". They don't say whether that's before or after the clipping, but given that the return value would always be at most v if it was after, I'm going to guess that it returns the pre-clipping norm (ChatGPT agrees). So we can chart that; changes in these diffs: 1, 2, 3, 4.

So we now have code to clip gradients to a given norm size, and to chart the gradient norms so that we know what they were before clipping. The question is, what should that clipping norm be? Some googling around suggested that there is no standard way of saying "for such-and-such a kind of model, gradients should be clipped at around x". For example, one commenter on this Reddit thread says "Common values are 1, 3, 5, 8, 10"; likewise, the sample code in this tutorial uses 1, as does this one. So my initial thought was, let's just use 1.

But then I wondered, what actually are the gradient norms that we're getting in normal training? I decided to run a local short train on 3M tokens (a thousandth of the full training set, taking just less than four minutes) with very frequent checkpointing and gradient clipping set to 1, and see what happened.

You can see that the "grad max" line is almost always above the "grad clip" line -- we're almost always clipping. This doesn't sound right. It looked like the range of the grad max was generally between 1.1 and a little above 3, so I set the clipping norm to 3.5 and did another train:

Our loss is about the same, but we're no longer clipping -- and that's what we want; there was no evidence of exploding gradients for that short run, just big updates near the start, as you'd expect. I then ran the same with no gradient clipping at all, and got exactly the same shape for the loss chart as I did with gradient clipping at 3.5, and the same final loss -- a good signal that clipping does not affect the train when we stay inside the limit, which is exactly what we want.

So, it was time to train our model!
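(A quick aside before the full run: the norm-tracking bookkeeping is very little code. Roughly speaking -- this is a sketch with made-up names, not necessarily what's in the diffs linked above -- per micro-batch we record the value returned by clip_grad_norm_, and at each checkpoint we log the max and the mean over the period; those become the "grad max" and "grad avg" lines in the charts that follow.)

```python
import math

class GradNormTracker:
    """Accumulates gradient norms between checkpoints (hypothetical helper)."""

    def __init__(self):
        self.norms = []

    def record(self, total_norm):
        # total_norm is the tensor returned by clip_grad_norm_; it can be inf.
        self.norms.append(float(total_norm))

    def checkpoint_stats(self):
        # Summarise the period and reset, ready for the next one.
        if not self.norms:
            return {"grad_max": math.nan, "grad_avg": math.nan}
        stats = {
            "grad_max": max(self.norms),
            "grad_avg": sum(self.norms) / len(self.norms),
        }
        self.norms = []
        return stats
```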
I kicked off the train, and after a little while I looked at the training chart, which is updated dynamically as the model trains:

You can see that the dotted green lines, both the light one and the dark one -- that is, the "grad max" and the "grad avg" -- disappear starting just before global step 4,000, only coming back at about 5,500; that is, they were not plotted for global steps 4,319 and 4,936, even though the loss was. What was going on? I took a look at the checkpoint meta file for the first of those to see what the actual numbers were: the grad max and the grad avg were both infinite. Aha! The PyPlot code I was using could not handle infinite values, which is entirely reasonable. That was easy enough to fix, though -- I just replaced positive infinity by 1,000,000 and negative infinity by -1,000,000 -- and then (in the interest of getting a proper from-scratch run) I kicked everything off from the beginning. That training run completed with this chart:

That's a little hard to read, but if you look closely at the green lines, you can see that there are seven periods where gradients were either very large or infinite. Weirdly, out of the seven, two of them were two checkpoint periods long (that is, two periods of 617 global steps each). That felt odd, though of course we're looking at the maximum gradient norm and the average gradient norm -- so two single infinite/high-gradient steps in successive 617-step periods would produce exactly that effect. What was even stranger, though, was that the training chart for the run with no gradient clipping has only three loss spikes, rather than seven:

...though it's also very noticeable that the gradient-clipped run had only two small loss spikes, unlike the three larger ones in the unclipped run. The training loss the gradient-clipped run reported at the end was better, too, versus 3.743 at the end of the baseline train. So it was time to download it and run the sequence-completion smoke test:

Coherent enough! Next, we evaluate it against our held-back test set:

So, the loss had gone down -- but only from 3.743 to 3.678, a reduction of 0.065, or about 1.7%. That's actually not all that bad! After all, in my initial experiments on my local machine, training for a Chinchilla-optimal number of tokens from FineWeb-Edu (rather than the regular FineWeb I'm using now) got a loss of 4.167 on the same dataset (weirdly worse with the more-curated training set), and training for a further Chinchilla-optimal number of tokens only brought that down to 4.135, a difference of 0.032, or 0.7%. It's not strictly comparable due to the different training sets, but speaking very loosely, we could say that gradient clipping for this train had more effect than doubling the training time for the other one. That's pretty nifty.

But the question remained: why those long periods of high gradients, even with gradient clipping? And why were there still loss spikes -- in particular the one just before global step 12,000, which lasted for two checkpoint periods?

Remember that when I started the first run of this train, and got the chart with the missing bits, it was because the logged grad max and grad avg were infinite. What happens when clip_grad_norm_ gets an infinite gradient -- either one that has an infinity as one of its components, or one that (due to numerical overflow) winds up with a norm of infinity anyway?
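You can poke at that question with a couple of toy tensors -- nothing to do with the training code, and the exact output depends on your PyTorch version, but on recent versions it looks something like this:

```python
import torch

p = torch.nn.Parameter(torch.zeros(4))
p.grad = torch.tensor([1.0, -2.0, float("inf"), 3.0])  # one infinite component

total_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
print(total_norm)  # tensor(inf) -- the norm of the whole gradient vector
print(p.grad)      # roughly tensor([0., -0., nan, 0.]) -- the inf became a NaN, the rest were zeroed
```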
I'd been kind of assuming that it did what the authors described in "Deep Learning" -- a random update of norm v -- given that the book stated pretty confidently that you "can" do it but then appeared to consider the topic closed. But it doesn't! If you check that link to the docs, you'll see that it has a parameter, error_if_nonfinite, which is False by default. If it's set to True, clip_grad_norm_ will raise an exception if the norm is positive or negative infinity, or if it's not a number -- which catches both the infinite-component and the norm-overflow cases above. But if it's not set -- and we weren't setting it -- and the norm or the gradients are non-finite, then clip_grad_norm_ will essentially produce garbage gradients. Depending on the exact cause, elements will either be infinities of one sign or another, or NaNs. And if these are added to parameters, then those parameters will become garbage too.

Now that leads to the question: given that we know that somewhere between the checkpoint at global step 4,319 and the previous one at 3,702 there was an infinite norm at some point, how on earth did the model manage to continue training after that? Loss went up at around the same time, but it wasn't completely broken, as it would have been with NaNs or infinities in its parameters.

Obscurely enough, the answer turned out to be in the AMP explainer, in a comment in one of the bits of example code. Regarding the GradScaler class we're using, the comment notes (roughly) that scaler.step first inspects the gradients, and if they contain infs or NaNs, the underlying optimiser step is skipped entirely.

So what was happening was that the scaler -- something we introduced into our code to get a speedup by using 16-bit floats instead of 32-bit whenever PyTorch thought it would make sense -- was protecting us against infinite and NaN gradients as a side-effect. It was skipping the updates that would have polluted our weights with bad values from non-finite gradients.

If the above comes across as a little frustrated, it's because I am, a bit! From a software engineering viewpoint, this really does feel like a rather messy part of the API. There are three things that it's reasonable for a library to do with infinite/NaN gradients:

- Blindly apply them, and expect the developer to sanitise their inputs.
- Raise an error.
- Take some kind of default sane action, like skipping the update.

Now, if we look at that error_if_nonfinite parameter, we can see that the first two of those are handled there, and the developer can choose which option to follow. It's not where I'd personally put it (the step function on the optimiser seems a more natural home), and I think I'd probably make raising the default too, but I can also imagine good reasons for it being the way it is -- backward compatibility, for one. But the "skip non-finite gradients" behaviour being a (not even optional!) part of a class designed for handling mixed-precision training just seems outright bonkers. I would be surprised if there weren't people out there who've spent days trying to work out why their training runs failed catastrophically when they decided to switch from mixed-precision to "full fat" 32-bit floats, not realising that a hardly-even-documented feature of the scaler³ had been saving them from gradient issues previously.

Anyway, rant over. What does this all mean? There are three ways a gradient can explode:

- It can get very large, still be finite, and have a finite norm.
- It can get very large, still be finite, but have an infinite norm (eg. due to numerical overflow).
- It can become infinite -- that is, at least one of the parameters' gradients is infinite (which of course means an infinite norm, regardless of any numerical issues).

With both the baseline code and our new code, the GradScaler was saving us from the last two of those, by skipping the optimiser steps with non-finite gradients. However, the baseline run was not protected against the first kind -- large but finite gradients with a finite norm -- while this run was. What I'm almost certain is happening here is that in all of my training runs so far, there have been all three kinds of issues with exploding gradients.
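(Telling the three cases apart would be cheap to log, by the way -- something like this hypothetical check, with names and structure of my own choosing rather than anything from my training code:)

```python
import torch

def classify_gradients(parameters, clip_norm):
    """Bucket the current gradients into the three cases above (hypothetical helper)."""
    grads = [p.grad for p in parameters if p.grad is not None]
    if any(not torch.isfinite(g).all() for g in grads):
        return "non-finite components"            # case 3: at least one inf/NaN element
    total_norm = torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(g.detach()) for g in grads])
    )
    if not torch.isfinite(total_norm):
        return "finite gradients, overflowing norm"  # case 2: the norm itself overflows
    if total_norm > clip_norm:
        return "large but finite"                 # case 1: clipping handles this one
    return "normal"
```

In the training loop you'd call something like this just after the backward pass (and, with AMP, after scaler.unscale_), before clipping.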
The GradScaler, which, again, we introduced for faster training, happened to be saving us from the infinite gradients/norms. But we were still being bitten by the finite but excessively large ones. And that, I think, is why this training run had a positive -- not huge, but certainly worthwhile -- effect on the test set loss.

If I had more time, I think I'd do another run, logging all three of those categories of error (much as in the sketch above) to see how frequent they are, and charting the result. That might go some way towards answering the final question I had here: why is it that the renowned "Deep Learning" suggests a random update to get away from the cliff where you've found yourself, while we seem to be getting away with just skipping the update, which is much simpler? Well, the book was written in 2016, and I guess rather a lot has changed in the last 10 years :-) My guess is that their solution might have been a solid default in the age of RNNs, but might not make so much sense with the kind of models we're training these days.

I think I can see a way in which that makes sense. Think of the illustration of a loss "cliff" in a one-parameter world that we had at the start of this post: if you happen to wind up on that cliff, you're in trouble. But imagine a two-parameter model -- the line of the loss function becomes a surface. Just as in the real world you might be able to walk along the edge at the top of a cliff and find a nice easy slope down next to it, you can imagine that the cliff in the two-parameter case might be less of a problem, because you don't need to be lucky enough to jump down it -- you can walk around it. Extrapolating examples like this to higher dimensions is risky, but I think it should hold that the more dimensions you're working with, the less likely it is that a cliff is an issue -- you're more likely to be able to find a way around it. (I've heard a very similar argument made for why local minima are less of an issue with lots of parameters.) This is far from a mathematical proof, but I think it's a decent grounding for intuition.

Now think about an RNN. Although you're doing back-propagation through time over what amounts to a very deep network, there aren't actually all that many parameters, certainly compared to an LLM like this one -- each parameter is just involved in the back-propagation multiple times. So, thinking of it that way, the gradient vector for the RNNs they were dealing with was of much lower dimensionality than the ones we're dealing with, even for this tiny model. They say that the random step "will typically move away from the numerically unstable configuration". I'm probably playing fast and loose here, but I'll take that as something like: if you wound up on a cliff, you were likely in a very "cliffy" area of the loss landscape, and "teleporting" randomly to somewhere some distance away was a sensible way to handle that. In our situation, even if the area is "cliffy" in the direction that one particular batch might push us, we have so many extra dimensions that it may well not be so bad with the next one. So just skipping the problematic update -- under all of those assumptions -- seems a perfectly reasonable way to handle it.

All of this, BTW, made me think back to validation loss.
In our previous training runs, where we were measuring it just before each checkpoint, its spikes were in general correlated with, but not identical to, spikes in training loss:

Now, of course, exploding gradients don't have to be related to high training loss -- there's enough non-linearity in there that we can treat the two as being essentially uncorrelated, I think. But you would definitely expect exploding gradients, if applied, to have an effect on validation loss. Disregarding the infinite ones (which were being filtered out anyway), the very high ones that we are now clipping would, in the unclipped baseline train, seem very likely to have caused validation loss spikes. So: if I hadn't stripped that measurement out, we would likely have been able to see a clear difference in the validation loss line between clipped and unclipped. That would have been useful! I'm not going to re-introduce it, though -- best to keep the number of code changes to a minimum if I'm trying to compare like with like over the course of these intervention tests.

I think that's enough for gradient clipping. I may come back and do the experiment another time to see what the relative ratios of the different kinds of problematic gradients are. Are there parts of the train where we get lots of them as a percentage (ie. we're somewhere "cliffy" in the loss landscape)? How many infinite-gradient vs infinite-norm vs big-but-not-infinite instances do we have, relative to each other and to normal gradient updates? What would we see if we had validation loss? And so on. But for now: gradient clipping definitely helps, and goes on the positive interventions list!

I'm thinking I'll see what happens with switching off dropout next. That should at least be a bit easier... Stay tuned!

1. Oh my. ↩

2. Technically the L2 norm -- if you used cubes and a cube root it would be the L3 norm, likewise for fourth powers and L4, and so on. But the L2 norm is the one used for gradient clipping. ↩

3. Shades of Douglas Adams, really: "But the plans were on display..." "On display? I eventually had to go down to the cellar to find them." "That's the display department." "With a flashlight." "Ah, well, the lights had probably gone." "So had the stairs." "But look, you found the notice, didn't you?" "Yes," said Arthur, "yes I did. It was on display in the bottom of a locked filing cabinet stuck in a disused lavatory with a sign on the door saying 'Beware of the Leopard.'" ↩

[1] I. Goodfellow, Y. Bengio, and A. Courville. Deep Learning (2016), MIT Press.