Of course we can't achieve that, but we can try to minimize the intrusion we have to introduce into algorithm code. This reminds me of the monkey-patching in the `gevent` library: it patches (primarily) the `socket` module, replacing it with `gevent.socket`, which can switch to other greenlets when IO would block, much like a goroutine (actually, `gevent` is older than Golang!).
Since we were only using HuggingFace libs (`transformers`, `diffusers`) to load models at the time, the target became clear: we only introduce a monkey-patch call, the rest of the code remains unchanged, and `XXXPipeline.from_pretrained(...)` should be much faster.
## Some Facts, Obvious Decisions and Assumptions

**Overmind is a caching library: it caches model loading call results into system memory and later reconstructs them fast.**

We skip discussing how monkey-patching is implemented; that's a not-so-interesting detail. All we need to know is that it redirects every `XXXPipeline.from_pretrained(...)` call to `overmind.api.load(XXXPipeline.from_pretrained, ...)`.
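In spirit, the patch is just a thin wrapper around the original classmethod (a minimal sketch, not the real implementation; `patch_pipeline` is an illustrative name):

```python
import functools

import overmind.api

def patch_pipeline(cls):
    original = cls.from_pretrained  # bound classmethod of e.g. a diffusers pipeline class
    @functools.wraps(original)
    def patched(*args, **kwargs):
        # Route the call through overmind's caching entry point.
        return overmind.api.load(original, *args, **kwargs)
    cls.from_pretrained = patched
```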
We use `pickle` to serialize our cache results since... we don't really have a choice, and `torch.save` itself uses `pickle`; it would be weird not to use it.

We use a client/server architecture since we don't want to invalidate our cache when a process terminates. There are many subprocess calls that could benefit from it.

We assume the `XXXPipeline.from_pretrained` parameters are simple hashable things (`str` and the like) or other models loaded by overmind (explained later).

The name `overmind` is borrowed from StarCraft, as you may have guessed.
## Reconstruct it fast!

We can't naively keep the pickled bytes in memory, `pickle.loads` them back, and call it a day. After all, in a warmed-up scenario the Linux page cache has already done its job caching the on-disk models, and we still see loading times measured in tens of seconds.

The inefficiency comes from memory copying. In Python, even creating millions of objects costs no more than a few hundred milliseconds; a single 10GiB memory copy, however, costs about half a second. We must avoid memory copies as much as possible.
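You can see the copy cost for yourself (a standalone snippet; at this size it needs roughly 20GiB of free RAM, so scale it down if necessary):

```python
import time

size = 10 * 1024**3                # 10 GiB
blob = bytearray(size)             # stands in for the model weights

t0 = time.perf_counter()
copy = bytes(blob)                 # one full memory copy
elapsed = time.perf_counter() - t0
print(f"copied {size / 2**30:.0f} GiB in {elapsed:.2f}s")
```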
Fortunately, most of the big memory chunks are Torch tensors, so we can safely address only them and ignore the rest.

Actually, I learned the internal structure of a Torch tensor from the reduction code while researching the tensor sharing mechanism:

```python
# Copied from torch.multiprocessing.reductions, most of the code is removed
def reduce_tensor(tensor):
    storage = tensor._typed_storage()
    metadata = (
        tensor.storage_offset(),
        tensor.size(),
        tensor.stride(),
        tensor.requires_grad,
    )
    return (rebuild_tensor, (type(tensor), storage, metadata))


def rebuild_tensor(cls, storage, metadata):
    storage_offset, size, stride, requires_grad = metadata
    t = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
    t.requires_grad = requires_grad
    return t
```

Quite simple: a tensor is its type, its metadata and its underlying storage. Here `storage` is of type `TypedStorage`, but `TypedStorage` is just a thin wrapper around `UntypedStorage`; `UntypedStorage` is the class that actually holds all the tensor data.

Our task becomes more specific now: how do we avoid copying `UntypedStorage`? Can we manage this tensor memory ourselves and construct `UntypedStorage`s that point to the memory we manage?

The answer is yes!

Skimming through the C++ code where `UntypedStorage` is constructed, we can easily find the constructor we need.
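Simplified from `c10/core/StorageImpl.h` and `c10/core/Allocator.h` (a sketch, not verbatim; exact signatures vary across PyTorch versions), the relevant pieces look like this:

```cpp
// A StorageImpl is built from an at::DataPtr -- any pointer we hand it.
StorageImpl(
    use_byte_size_t /*use_byte_size*/,
    size_t size_bytes,
    at::DataPtr data_ptr,      // raw pointer + context + deleter + device
    at::Allocator* allocator,  // may be null when the memory is externally owned
    bool resizable);

// DataPtr bundles the pointer with a deleter that runs when the storage dies.
DataPtr(void* data, void* ctx, DeleterFnPtr ctx_deleter, Device device);
```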
Not only can we use a pointer, but the `at::DataPtr` class can also handle destruction, making lifetime management much simpler.

On the Python side, a pointer to a memory region is represented by a `memoryview` object; these objects support the buffer protocol. We can get a `memoryview` from many other things; `bytes` and `mmap` are the two major ones supporting it, and also the ones we care about.

Finally, we know what we should do: create a function that accepts a `memoryview` object and turns it into an `UntypedStorage` without copying. With the ability to reconstruct an `UntypedStorage` from a `memoryview`, the actual tensor data doesn't have to be in the pickle stream, which greatly reduces the amount of data we have to copy around.
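As a Python-level illustration of the idea (not the helper overmind actually uses, which is described later; here `torch.frombuffer` plays the zero-copy role, and `untyped_storage()` needs torch >= 2.0):

```python
import torch

buf = bytearray(1024 * 1024)        # pretend this came from shared memory
view = memoryview(buf)

# torch.frombuffer shares memory with the buffer -- no copy is made.
t = torch.frombuffer(view, dtype=torch.uint8)
storage = t.untyped_storage()       # an UntypedStorage backed by `buf`

t[0] = 42
assert buf[0] == 42                 # same underlying memory
```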
There's already a tensor sharing mechanism in PyTorch, but it doesn't fit our needs. More on this later.

When we see 'share' and 'memory' come together, we all have an urge to reach for `shmget` and its friends. It is "designed" to be a memory sharing mechanism, right, so why not? But it has 2 major flaws:

- POSIX shm is a scarce resource; what you can use is determined by how the sysadmin configures the system. A most extreme but ubiquitous
3. **Shared Memory (`shmem.py`)**: Manages memory arenas using `memfd_create` (Linux) or named shared memory (Windows). Fragments are content-addressed by hash, enabling deduplication.
## The Engineering Tricks

### Trick #1: Shared Memory Without POSIX shm Limits

PyTorch's built-in tensor sharing uses POSIX shared memory, but this hits practical limits quickly:

- Docker defaults to 64MB of `/dev/shm`
- Each `UntypedStorage` gets its own shm segment, even for tiny buffers
- Each segment consumes a file descriptor
- Reference counting prevents pickle reuse

Our solution: use `memfd_create` to create anonymous memory-backed file descriptors, then pass them across processes via `/proc/{pid}/fd/{fd}`.
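A minimal sketch of the mechanism (names are illustrative; in the real system the pid and fd numbers travel over the daemon's IPC channel rather than living in one script):

```python
import mmap
import os

ARENA_SIZE = 1 << 30                      # 1 GiB, for illustration

# Daemon side: an anonymous, memory-backed fd -- no /dev/shm entry involved.
fd = os.memfd_create("overmind-arena")
os.ftruncate(fd, ARENA_SIZE)
daemon_view = mmap.mmap(fd, ARENA_SIZE)

# Client side: open the daemon's fd through procfs and map the same pages.
path = f"/proc/{os.getpid()}/fd/{fd}"     # normally another process's pid
client_fd = os.open(path, os.O_RDWR)
client_view = mmap.mmap(client_fd, ARENA_SIZE)

daemon_view[:1] = b"\x42"
assert client_view[:1] == b"\x42"         # both views see the same memory
```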
This bypasses the `SCM_RIGHTS` API entirely—no wrestling with ancillary messages on Unix sockets. The fd remains valid as long as the daemon is alive.

We allocate exponentially-growing arenas (starting at 8GB, doubling each time) and pack tensor data into them with content-based deduplication. A 64-bit MetroHash identifies fragments:
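conceptually, it behaves like this (a toy sketch; a cryptographic hash stands in for MetroHash and a `bytearray` stands in for the memfd-backed arena):

```python
import hashlib

class Arena:
    def __init__(self, size: int):
        self.buf = bytearray(size)   # stand-in for the mmap'd arena
        self.top = 0
        self.fragments = {}          # 64-bit digest -> (offset, length)

    def put(self, data: bytes):
        digest = hashlib.blake2b(data, digest_size=8).digest()  # 64-bit id
        if digest in self.fragments:          # same content -> same fragment
            return self.fragments[digest]
        start = self.top
        self.buf[start:start + len(data)] = data
        self.top += len(data)
        self.fragments[digest] = (start, len(data))
        return self.fragments[digest]
```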
Same tensor content? Same fragment. No redundant storage.

### Trick #2: Custom Pickle Reducers for Tensors

Standard pickle serializes tensors as bytes: 12GB for an SDXL pipeline, taking 14s to dump and 6s to load. We register custom reducers that store tensor data in shared memory and only pickle a small `Fragment` reference:
```python
def _reduce_storage(storage):
    if storage.size() == 0:
        return (rebuild_storage_empty, (type(storage),))
    else:
        device = storage.device
        storage = storage.cpu()
        frag = hoarder.put(storage)  # Store in shared memory
        # ...and return a rebuild function that carries only the small
        # Fragment reference (plus the original device), not the raw bytes.
```
### Trick #3: Zero-Copy Construction from C++

A small C++ extension builds Torch objects directly from Python buffers. The binding that loads a TorchScript module, for example, reads it from an in-memory stream instead of a file:

```cpp
    imemstream in((char*)info.ptr, info.size); // No copy!
    return import_ir_module(std::move(cu), in, ...);
});
```
Similarly, `_make_untyped_storage` creates a `torch.UntypedStorage` that directly wraps a `memoryview`, with the buffer's lifetime tied to the Python object via a custom destructor.
### Trick #4: Surviving the HuggingFace Ecosystem

Real-world usage exposed edge cases in diffusers, transformers, accelerate, and bitsandbytes:

**diffusers dynamic modules**: Model repos can include Python files that get imported at runtime into a `diffusers_modules` namespace. The client doesn't have these in `sys.path`, breaking unpickling. Fix: pre-import the module on the client:
```python
def diffusers_dyn_module_workaround():
    from diffusers.utils.constants import HF_MODULES_CACHE
    import sys

    # Make the runtime-generated `diffusers_modules` package importable on
    # the client too, so unpickling can resolve classes defined inside it.
    if HF_MODULES_CACHE not in sys.path:
        sys.path.append(HF_MODULES_CACHE)
```
**The `vae=vae` pattern**: Users often load a VAE separately and pass it to a pipeline. If we naively pickle this, we lose the caching benefit. Solution: attach an `_overmind_ref` marker to loaded models and resolve it server-side:
```python
def replace_ref(obj):
    if (ref := getattr(obj, '_overmind_ref', None)):
        return False, ref
    return True, obj
```
**accelerate hooks**: Quantized models via bitsandbytes come with accelerate's `AlignDevicesHook` hooks attached, and those don't pickle. We strip them:
```python
from accelerate.hooks import remove_hook_from_module

# Strip the device-alignment hooks from the whole module tree before pickling.
remove_hook_from_module(model, recurse=True)
```
**CUDA tensors**: Quantization happens on GPU, but we can't keep CUDA tensors in the daemon (it would tie up the GPU). We move them to CPU, pickle, then restore them to the original device on the client.
### The Trade-offs

- **Single GPU only**: We normalize all `device_map` configurations to `cuda:0`. Multi-GPU would require tracking device placement.
- **Cold load overhead**: Using `dill` for closures adds ~14s to cold loads (pure Python serialization). This is a one-time cost.
- **No training**: We force `requires_grad=False`. Overmind is for inference.
## Performance & Results

| Scenario | Time |
|----------|------|
| Vanilla `from_pretrained` | ~15s |
| Overmind cold load | ~14s (dill overhead) |
| Overmind warm load | **0.15s** |

Real workload (Image3D pipeline):

| Configuration | Total Runtime |
|---------------|---------------|
| Without Overmind | 123.5s |
| Overmind cold | 137.8s |
| Overmind warm | **109.0s** |

The warm case saves 14.5s per run—multiply that by hundreds of daily iterations during research, and you get hours back.
### Bonus: CPU Memory Savings

Since all processes share the same tensor backing store, you're not duplicating multi-GB models across workers. A single SDXL pipeline in memory serves all clients.

## Summary & Getting Started

Overmind makes model loading boring—in the best way. One import, one function call, and your 15-second loads become 0.2-second cache hits.
```python
import overmind.api

overmind.api.monkey_patch_all()

# That's it. Your from_pretrained calls are now cached.
```