A few weeks ago, Yongsheng invited me to give a presentation at Granica’s Tech Talk series. Like all good (read: lazy) speakers, the first thing I did was open my go-to slides for introducing and advocating for Ray, which I had made while at Anyscale.
But when glancing through the slides, I realized that a new angle ought to be added. After joining NVIDIA as an IC three months ago, I have spent most of my time getting my hands (very) dirty running and fixing real ML workloads. Being in this ML practitioner position (one who consumes ML infra software) has both allowed and forced me to think about infra problems in a more critical and impartial way.
Plus, I promised Rahul a “no-BS” talk; so here we go 😉 This post summarizes the key points that I want the audience to take home. None of them are new or original, but I think they deserve to be repeated again and again with different metaphors and different examples.
A key problem in ML infra is how to command and utilize resources of remote servers
Why? Because your code will most likely run more slowly (if at all) on a single computer that you have full control of. Of course there’s the race between increasingly powerful single computers and increasingly large models and data. But at this moment, you pretty much cannot get around needing the ability (and optionality) to harness remote servers.
RPC (remote procedure call) is the foundational concept and contract for one client to utilize one remote server
A useful starting point for thinking about the above problem is “well, how do I utilize one server?” The de facto solution there is RPC, where a client `c` invokes an arbitrary function `f(x)` on the server `s` and gets a return `y`. Pretty much every distributed application uses one variant of RPC or another whenever some unit of computation needs to happen on a remote server.
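To make the client/server contract concrete, here is a minimal sketch of the RPC idea using Python’s standard-library `xmlrpc` modules (the function `f` and the toy arithmetic are placeholders; real systems use gRPC, Thrift, etc., but the shape is the same):

```python
from threading import Thread
from xmlrpc.server import SimpleXMLRPCServer
from xmlrpc.client import ServerProxy

# The "server" side: expose an arbitrary function f(x) over RPC.
def f(x):
    return x + 1

server = SimpleXMLRPCServer(("localhost", 0), logRequests=False)
server.register_function(f, "f")
port = server.server_address[1]
Thread(target=server.serve_forever, daemon=True).start()

# The "client" c: invoke f(x) on the server s and get back y.
s = ServerProxy(f"http://localhost:{port}")
y = s.f(41)
print(y)  # 42
```

The client code reads like a local function call; the framework handles serializing `x`, shipping it over the network, and returning `y`.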
There is a spectrum of frameworks that allow a client to utilize multiple servers (“a cluster”)
So, vanilla RPC surely is a simple and elegant starting point; but what if you need to command and utilize not one, but many servers? The problem becomes more complicated and nuanced, and there is a spectrum of solutions. E.g., for training a large model, you can do it by using `torchrun` (either on Slurm or K8s), or by using Actors on Ray. For analyzing a large dataset, you can use Apache Spark or Presto. So on and so forth.
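For a flavor of the `torchrun` route, a typical multi-node launch looks roughly like this (the script name `train.py`, the node counts, and the `$HEAD_NODE_IP` variable are placeholders):

```shell
# Launch the same training script on each of 2 nodes, 8 GPUs per node;
# torchrun spawns one worker process per GPU and wires up the process group.
torchrun \
  --nnodes=2 \
  --nproc_per_node=8 \
  --node_rank=0 \
  --rdzv_backend=c10d \
  --rdzv_endpoint=$HEAD_NODE_IP:29500 \
  train.py
```

Note the SPMD shape: you run the same command on every node (varying `--node_rank`), and the framework handles the inter-server communication behind the scenes.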
Which framework should you use?
Thanks for bearing with me through the build-up (I hope it wasn’t too long-winded)! Now let’s look at the million-dollar question: I have an (ML) workload; which compute framework should I use? Thanks to the complexities of these frameworks, I myself have struggled and tripped up quite a bit in understanding and explaining the tradeoffs. What I found helpful is to keep coming back to the RPC concept, and I hope this helps you too.
So at the end of the day, a cluster is a bunch of servers, each running an RPC server. You can certainly implement whatever workload you have by directly writing the RPC code between your laptop and each server (and between servers). In fact, quoting what Andrej Karpathy recently said about PyTorch, this zero-abstraction approach might become plausible some future day. But before that day comes, you most likely need the help of a cluster computing framework to keep your code “sane” (simple, comprehensible, and not changing too much as you scale up the problem). And here’s my take on some of these frameworks (in the context of ML workloads):
K8s’s superpower is allowing you to declare the desired scale (as in the declarative-imperative tradeoff). If you want to run `f()` on 10 servers, each with a different piece of the input `x`, you just tell K8s `--replicas=10` and it will find you 10 servers and even replace failed ones.

And Ray has the superpower of making RPCs between the servers easy and high-performance. As for why Ray can do it, I highly recommend this paper by Stephanie (this post’s title pays tribute to it). The TL;DR is that if your code is like `y = f(x)` and then `z = g(y)`, Ray allows you to write this logic like normal code, and handles the RPC between the servers running `f()` and `g()` for you.

As for frameworks like Apache Spark, I won’t go into too much detail. They “do all the RPCs for you” as long as your workload can be expressed in their language.
To make all these easy to remember:
🏈 If your workload looks like American football, where a quarterback passes to multiple runners, K8s does this beautifully
⚽ If your workload looks like “proper football”, where everyone passes to everyone, consider Ray
(Editor's note: Huge thanks to Zhe for his fantastic talk and guest blog!)
November 18, 2024