DumPy: NumPy except it’s OK if you’re dum
What I want from an array language is:

- Don’t make me think.
- Run fast on GPUs.
- Really, do not make me think.

I say NumPy misses on three of these. So I’d like to propose a “fix” that—I claim—eliminates 90% of unnecessary thinking, with no loss of power. It would also fix all the things based on NumPy, for example every machine learning library.

I know that sounds grandiose. Quite possibly you’re thinking that good-old dynomight has finally lost it. So I warn you now: My solution is utterly non-clever. If anything is clever here, it’s my single-minded rejection of cleverness.

To motivate the fix, let me give my story for how NumPy went wrong. It started as a nice little library for array operations and linear algebra. When everything has two or fewer dimensions, it’s great. But at some point, someone showed up with some higher-dimensional arrays. If loops were fast in Python, NumPy would have said, “Hello person with ≥3 dimensions, please call my ≤2 dimensional functions in a loop so I can stay nice and simple, xox, NumPy.”

But since loops are slow, NumPy instead took all the complexity that would usually be addressed with loops and pushed it down into individual functions. I think this was a disaster, because now, every time you see a function call, you have to think:

- OK, what shapes do all those arrays have?
- And what does this function do when it sees those shapes?

Different functions have different rules. Sometimes they’re bewildering. This means constantly thinking and constantly moving dimensions around to appease the whims of particular functions. It’s the functions that should be appeasing your whims!

Even simple-looking operations do quite different things depending on the starting shapes. And those starting shapes are often themselves the output of previous functions, so the complexity spirals.

Worst of all, if you write a new ≤2 dimensional function, then high-dimensional arrays are your problem. You need to decide what rules to obey, and then you need to re-write your function in a much more complex way to—

Voice from the back: Python sucks! If you used a real language, loops would be fast!
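To make the “your problem” point concrete, here’s a toy sketch of my own (the `normalize` function is made up for illustration): an innocent function written for one vector silently does the wrong thing the moment someone hands it a batch.

```python
import numpy as np

def normalize(v):
    # Written with a single 1D vector in mind.
    return v / np.linalg.norm(v)

v = np.array([3.0, 4.0])
print(normalize(v))  # [0.6 0.8] -- fine

# Now someone shows up with a batch of vectors, one per row.
batch = np.stack([v, 10 * v])

# No error! But this divides by the Frobenius norm of the whole
# matrix, not by each row's norm, so the rows are no longer unit
# vectors -- a silent bug, not a crash.
out = normalize(batch)
print(np.linalg.norm(out, axis=1))  # neither entry is 1.0
```

The fix in NumPy-land is to rewrite `normalize` with explicit `axis=` and `keepdims=` arguments, which is exactly the extra complexity the paragraph above is complaining about.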
This problem is stupid!

That was a strong argument, ten years ago. But now everything is GPU, and GPUs hate loops. Today, array packages are cheerful interfaces that look like Python (or whatever) but are actually embedded languages that secretly compile everything into special GPU instructions that run on whole arrays in parallel. With big arrays, you need GPUs. So I think the speed of the host language doesn’t matter so much anymore.

Python’s slowness may have paradoxically turned out to be an advantage, since it forced everything to be designed to work without loops even before GPUs took over. Still, thinking is bad, and NumPy makes me think, so I don’t like NumPy.

Here’s my extremely non-clever idea: Let’s just admit that loops were better. In high dimensions, no one has yet come up with a notation that beats loops and indices. So, let’s do this:

- Bring back the syntax of loops and indices.
- But don’t actually execute the loops. Just take the syntax and secretly compile it into vectorized operations.
- Also, let’s get rid of all the insanity that’s been added to NumPy because loops were slow.

That’s basically the whole idea. If you take those three bullet-points, you could probably re-derive everything I do below. I told you this wasn’t clever.

Suppose you have two 2D arrays and one 4D array, and you want a 2D output where each entry combines the corresponding rows of the two 2D arrays with the corresponding 2D slice of the 4D array. If you could write loops, this would be easy. That’s not pretty. It’s not short or fast. But it is easy!

Meanwhile, how do you do this efficiently in NumPy? If you’re not a NumPy otaku, the answer may look like outsider art. Rest assured, it looks like that to me too, and I just wrote it. Why is it so confusing? At a high level, it’s because indexing, transposing, and multiplication have complicated rules and weren’t designed to work together to solve this particular problem nicely. That would be impossible, because there are an infinite number of problems. So you need to mash the arrays around a lot to make those functions happy.

Without further ado, here’s how you solve this problem with DumPy (ostensibly Dynomight NumPy): write the loop-and-index syntax directly. Yes! If you prefer, you can also use an equivalent syntax. Those are both fully vectorized. No loops are executed behind the scenes.
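To make that problem concrete, here is one instance of the pattern with assumed names and shapes (A and B 2D, C 4D, and the goal E[i,j] = A[i,:] @ C[i,j,:,:] @ B[j,:]; the exact shapes in the original example may differ), first with loops and then vectorized with `np.einsum`:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 5))          # 2D
B = rng.normal(size=(6, 7))          # 2D
C = rng.normal(size=(4, 6, 5, 7))    # 4D

# The easy-but-slow version: plain loops.
E_loops = np.zeros((4, 6))
for i in range(4):
    for j in range(6):
        E_loops[i, j] = A[i, :] @ C[i, j, :, :] @ B[j, :]

# The vectorized version: one einsum call, no Python loops.
# 'ik,ijkl,jl->ij' says: sum over k and l, keep i and j.
E_fast = np.einsum('ik,ijkl,jl->ij', A, C, B)

print(np.allclose(E_loops, E_fast))  # True
```

Even here, einsum is the friendly case: it only works because all three inputs happen to fit one contraction. The general NumPy solution involves reshapes and transposes instead.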
They’ll run on a GPU if you have one.

While it looks magical, the way this actually works is fairly simple:

- If you index a DumPy array with a string (or a Range object), it creates a special “mapped” array that pretends to have fewer dimensions.
- When a DumPy function is called, it checks if any of the arguments have mapped dimensions. If so, it automatically vectorizes the computation, matching up mapped dimensions that share labels.
- When you assign an array with mapped dimensions to a slot, it “unmaps” them into the positions you specify.

No evil meta-programming abstract syntax tree macro bytecode interception is needed. When you run code like the example above, that translation is all that happens behind the scenes.

It might seem like I’ve skipped the hard part. How does DumPy know how to vectorize over any combination of input dimensions? Don’t I need to do that for every single function that DumPy includes? Isn’t that hard? It is hard, but vmap did it already. This takes a function defined using (JAX’s version of) NumPy and vectorizes it over any set of input dimensions. DumPy relies on this to do all the actual vectorization. (If you prefer your vmap janky and broken, I heartily recommend PyTorch’s.)

But hold on. If vmap already exists, then why do we need DumPy? Because solving the same problem with raw vmap is basically what DumPy does behind the scenes, and it still involves a lot of thinking! I think vmap is one of the best parts of the NumPy ecosystem, and the vmap code seems genuinely better than the base NumPy version. But why does one argument go in the inner vmap and another in the outer one? Why are the axis arguments what they are, given which dimensions you need to vectorize over? There are answers, but they require thinking. Loops and indices are better.

OK, I did do one thing that’s a little clever. Say you want to create a Hilbert matrix, where entry (i, j) is 1/(1+i+j) with zero-based indices. In base NumPy you’d have to broadcast index arrays against each other. In DumPy, you can just write the formula with Range indices. Yes! That works! It works because a Range acts both like a string and like an array mapped along that string.
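The real machinery is `jax.vmap(f, in_axes=...)`, but its semantics are easy to emulate. Here is a toy, loop-based stand-in (illustrative only; real vmap compiles to batched primitives rather than looping) showing how nesting two maps vectorizes a one-sample function over two batch dimensions:

```python
import numpy as np

def toy_vmap(f, in_axes):
    # Map f over one batch axis of each argument (None = don't map),
    # stacking results along a new leading axis. Unlike real vmap,
    # this literally loops in Python.
    def mapped(*args):
        n = next(a.shape[ax] for a, ax in zip(args, in_axes) if ax is not None)
        outs = []
        for k in range(n):
            call = [a if ax is None else np.take(a, k, axis=ax)
                    for a, ax in zip(args, in_axes)]
            outs.append(f(*call))
        return np.stack(outs)
    return mapped

# A function written for single vectors/matrices...
def bilinear(a, c, b):
    return a @ c @ b

# ...vectorized over two batch dimensions by nesting maps:
# the outer map runs over i, the inner map over j.
f = toy_vmap(toy_vmap(bilinear, in_axes=(None, 0, 0)), in_axes=(0, 0, None))

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 5))
B = rng.normal(size=(6, 7))
C = rng.normal(size=(4, 6, 5, 7))
E = f(A, C, B)   # E[i, j] == A[i] @ C[i, j] @ B[j]
```

Note the thinking this demands: you must decide which argument is mapped at which nesting level and along which axis. That bookkeeping is exactly what DumPy’s labeled indices are meant to hide.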
So the above code is roughly equivalent to indexing with explicit label strings. In reality, the Ranges choose random strings. (The Range class maintains a stack of active ranges to prevent collisions.)

To test if DumPy is actually better in practice, I took six problems of increasing complexity and implemented each of them using loops, NumPy, JAX (with vmap), and DumPy. Note that in these examples, I always assume the input arrays are already in the array class of the system being used. If you try running them, you’ll need to add some conversions between array types. (Pretend that step doesn’t exist.)

For example, one problem’s goal is to create an output array from the inputs according to an explicit index formula. Another is: given a list of vectors and a list of Gaussian parameters, and arrays mapping each vector to a list of parameters, evaluate each corresponding vector/parameter combination; formally, given several 2D arrays and a 3D array, create an output defined entry-by-entry. See also the discussion in the previous post.

I gave each implementation a subjective “goodness” score on a 1-10 scale. I always gave the best implementation for each problem 10 points, and then took off points from the others based on how much thinking they required. According to this dubious methodology and these made-up numbers, DumPy is 96.93877% as good as loops! Knowledge is power!

But seriously, while subjective, I don’t think my scores should be too controversial. The most debatable one is probably JAX’s attention score.

The only thing DumPy adds to NumPy is some nice notation for indices. That’s it. What I think makes DumPy good is that it also removes a lot of stuff. Roughly speaking, I’ve tried to remove anything that is confusing and exists because NumPy doesn’t have loops. I’m not sure that I’ve drawn the line in exactly the right place, but I do feel confident that I’m on the right track with removing stuff.

In NumPy, multiplying two arrays works if both are scalar, or if both have the same shape, but also in a zoo of other “broadcastable” cases and not in others. For example, you can multiply a 5×2 array by a length-2 array, but not by a length-5 array. Huh?
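That asymmetry is easy to verify in base NumPy (using a 5×2 array as the running example; broadcasting aligns shapes from the trailing dimension, so (5,2) against (2,) matches but (5,2) against (5,) compares 2 with 5):

```python
import numpy as np

A = np.ones((5, 2))

# Same shape: fine.
print((A * np.ones((5, 2))).shape)   # (5, 2)

# Scalar: fine.
print((A * 3.0).shape)               # (5, 2)

# Length-2 array: NumPy happily broadcasts it across the rows...
print((A * np.ones(2)).shape)        # (5, 2)

# ...but a length-5 array raises, because broadcasting aligns
# *trailing* dimensions, so 2 is compared against 5.
try:
    A * np.ones(5)
except ValueError as err:
    print("ValueError:", err)
```

So whether a multiplication works, and what it computes, depends on a rule you have to replay in your head every time you read it.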
In truth, the broadcasting rules aren’t that complicated for scalar operations like multiplication. But still, I don’t like them, because every time you see a multiplication, you have to worry about what shapes the operands have and what the computation might be doing. So, I removed it. In DumPy you can only multiply two arrays if one of them is scalar or both have exactly the same shape. That’s it; anything else raises an error. Instead, use indices, so it’s clear what you’re doing: instead of relying on implicit broadcasting, write the index formula out.

Indexing in NumPy is absurdly complicated. A single indexing expression can do many different things depending on what all the shapes are. I considered going cold-turkey and only allowing scalar indices in DumPy. That wouldn’t have been so bad, since you can still do advanced stuff using loops. But it’s quite annoying to not be able to index an array with a couple of simple 1D index arrays. So I’ve tentatively decided to be more pragmatic.

In DumPy, you can index with integers, or slices, or (possibly mapped) Ranges. But only one index can be non-scalar. I settled on this because it’s the most general syntax that doesn’t require thinking.

Let me show you what I mean. With a single non-scalar index, it’s “obvious” what the output shape will be: the dimensions before the index, then the shape of the index itself, then the dimensions after. Simple enough. But as soon as you have two multidimensional array indices, suddenly all hell breaks loose. You need to think about broadcasting between the two index arrays, orthogonal vs. pointwise indexing, slices behaving differently than arrays, and quirks for where the output dimensions go. So DumPy forbids this. Instead, you need to write one of several explicit index-based forms. They all do exactly what they look like they do.

Oh, and one more thing! In DumPy, you must index all dimensions. In NumPy, if an array has three dimensions, then indexing it with two indices implicitly appends a full slice for the third. This is sometimes nice, but it means that every time you see an indexing expression, you have to worry about how many dimensions the array has. In DumPy, every time you index an array or assign to a slot, it checks that all indices have been included.
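The “all hell breaks loose” claim is not hyperbole. Here are NumPy’s actual rules in action: one non-scalar index is predictable, but two advanced indices get broadcast pointwise, and separating them with a slice teleports the result dimensions to the front.

```python
import numpy as np

A = np.arange(24).reshape(2, 3, 4)

# One non-scalar index: the index's shape simply replaces that axis.
i = np.array([[0, 1], [1, 0]])
print(A[0, i, 2].shape)      # (2, 2)

# Two non-scalar indices: NumPy broadcasts them *pointwise*,
# picking out pairs (0,2) and (1,3) -- not a 2x2 grid.
i2 = np.array([0, 1])
j2 = np.array([2, 3])
print(A[0, i2, j2].shape)    # (2,)

# An advanced index surrounded by slices stays in place...
print(A[:, i2, :].shape)     # (2, 2, 4)

# ...but advanced indices *separated* by a slice jump to the front!
print(A[i2, :, j2].shape)    # (2, 3)
```

Every one of those behaviors is documented and deterministic; the problem is that reading the expression is not enough to know which rule fires.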
So when you see option (4) above, the most general of those explicit forms, you know exactly how many dimensions everything involved has: the array being indexed, each index, and the output. Always, always, always. No cases, no thinking.

Again, many NumPy functions have complex conventions for vectorization. A function like np.linalg.solve sort of says, “If the inputs have ≤2 dimensions, do the obvious thing. Otherwise, do some extremely confusing broadcasting stuff.” DumPy removes the confusing broadcasting stuff. When you see a solve call in DumPy, you know the arguments have no more than two dimensions, so nothing tricky is happening.

Similarly, in NumPy, A @ B is equivalent to matmul(A, B). When both inputs have two or fewer dimensions, this does the “obvious thing”: either an inner product or some kind of matrix/vector multiplication. Otherwise, it broadcasts or vectorizes or something? I can never remember. In DumPy you don’t have that problem, because @ is restricted to arrays with one or two dimensions only. If you need more dimensions, no problem: Use indices.

It might seem annoying to remove features, but I’m telling you: Just try it. If you program this way, a wonderful feeling of calmness comes over you, as class after class of possible errors disappear.

Put another way, why remove all the fancy stuff, instead of leaving it optional? Because optional implies thinking! I want to program in a simple way. I don’t want to worry that I’m accidentally triggering some confusing broadcasting insanity, because that would be a mistake. I want the computer to help me catch mistakes, not silently do something weird that I didn’t intend.

In principle, it would be OK if there was a method that preserves all the confusing batching stuff. If you really want that, you can make it yourself with a small wrapper, and you can use that same wrapper to convert any JAX NumPy function to work with DumPy.

Think about math: In two or fewer dimensions, coordinate-free linear algebra notation is wonderful. But for higher dimensional tensors, there are just too many cases, so most physicists just use coordinates. So this solution seems pretty obvious to me. Honestly, I’m a little confused why it isn’t already standard.
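For the record, here is what `@` actually does in NumPy across dimensionalities; the ≤2D cases are the obvious ones, and everything above that is “stacks of matrices with broadcasting”:

```python
import numpy as np

a = np.ones((3,))
M = np.ones((4, 3))
B = np.ones((5, 4, 3))

print(a @ a)              # 3.0 -- 1D @ 1D is an inner product
print((M @ a).shape)      # (4,)   -- 2D @ 1D is matrix-vector
print((M @ M.T).shape)    # (4, 4) -- 2D @ 2D is matrix-matrix

# >=3D inputs are treated as *stacks* of matrices:
print((B @ a).shape)      # (5, 4) -- a is promoted to a column,
                          #           multiplied, then demoted again
print((B @ np.ones((3, 2))).shape)  # (5, 4, 2) -- the (3,2) matrix
                                    #              is broadcast over the stack
```

Five input combinations, four distinct behaviors. Restricting `@` to one or two dimensions collapses this table to the cases you can actually hold in your head.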
Am I missing something?

When I complain about NumPy, people often suggest looking into APL-type languages, like A, J, K, or Q. (All single-letter languages are APL-like, except C, D, F, R, T, X, and many others. Convenient, right?) The obvious disadvantages of these are that:

- They’re unfamiliar.
- The code looks like gibberish.
- They don’t usually provide autodiff or GPU execution.

None of those bother me. If the languages are better, we should learn to use them and make them do autodiff on GPUs. But I’m not convinced they are better. When you actually learn these languages, what you figure out is that the symbol gibberish basically amounts to doing the same kind of dimension mashing that we saw earlier in NumPy. The reason is that, just like NumPy and vmap, these languages choose to align dimensions by position, rather than by name. If I have to mash dimensions, I want to use the best tool. But I’d prefer not to mash dimensions at all.

People also often suggest “NumPy with named dimensions” as in xarray. (PyTorch also has a half-hearted implementation.) Of course, DumPy also uses named dimensions, but there’s a critical difference: In xarray, they’re part of the arrays themselves, while in DumPy, they live outside the arrays.

In some cases, permanent named dimensions are very nice. But for linear algebra, they’re confusing. For example, suppose A is 2D with named dimensions, say rows and cols. Now, what dimensions should the product of A with its transpose have? (cols twice?) Or say you take a singular value decomposition like A = U S Vᵀ. What name should the inner dimensions have? Does the user have to specify it? I haven’t seen a nice solution. xarray doesn’t focus on linear algebra, so it’s not much of an issue there. A theoretical “DumPy with permanent names” might be very nice, but I’m not sure how it should work. This is worth thinking about more.

I like Julia! Loops are fast in Julia! But again, I don’t think fast loops matter that much, because I want to move all the loops to the GPU. So even if I was using Julia, I think I’d want to use a DumPy-type solution.
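The position-vs-name distinction fits in one snippet. Plain operators line arrays up by trailing axis position, so you must mash one operand into the other’s layout first; `np.einsum` lines them up by the letters you assign to each axis, so the alignment is spelled out:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(2, 3))
B = rng.normal(size=(3, 2))

# By position: to combine these elementwise you must first
# transpose B so its axes land in A's positions.
P = A * B.T

# By name: label each axis and the alignment is explicit --
# B's first axis 'j' matches A's second axis 'j'.
N = np.einsum('ij,ji->ij', A, B)

print(np.allclose(P, N))  # True
```

DumPy’s labeled indices are the same idea as einsum’s letters, except they work for arbitrary code rather than only for one contraction expression.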
I think Julia might well be a better host language than Python, but it wouldn’t be because of fast loops; it would be because Julia offers much more powerful meta-programming capabilities. I built DumPy on top of JAX just because JAX is very mature and good at calling the GPU, but I’d love to see the same idea used in Julia (“Dulia”?) or other languages.

OK, I promised a link to my prototype, so here it is. It’s just a single file with around 700 lines. I’m leaving it as a single file because I want to stress that this is just something I hacked together in the service of this rant. I wanted to show that I’m not totally out of my mind, and that doing all this is actually pretty easy.

I stress that I don’t really intend to update or improve this. (Unless someone gives me a lot of money?) So please do not attempt to use it for “real work”, and do not make fun of my code.

PS. DumPy works out of the box with both jit and grad. For gradients, you need to either cast the output to a JAX scalar or use the wrapper.

PPS. If you like this, you may also like einx or torchdim.

Update: Due to many requests, I have turned this into a “real” package, available on PyPI. You can install it with pip or, if you use uv (you should), you can play around with DumPy by typing a one-liner in your terminal.