Skip to content

Commit 4d0f2ba

Browse files
lukebaumann authored and copybara-github committed
Refactor elasticity retry logic into a reusable private method.
PiperOrigin-RevId: 864734021
1 parent 4f2cead commit 4d0f2ba

2 files changed

Lines changed: 309 additions & 287 deletions

File tree

pathwaysutils/elastic/elastic.py

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Elasticity manager.
15+
16+
This module provides utilities for elastic training. It checks which slices are
active, waits for slices to become active, and decides whether a
`jax.errors.JaxRuntimeError` was caused by a slice down event.
19+
"""
20+
21+
import collections
22+
from collections.abc import Mapping, Sequence
23+
import logging
24+
import time
25+
import traceback
26+
27+
import jax
28+
import numpy as np
29+
from pathwaysutils.debug import timing
30+
31+
32+
# Module-level logger used by every helper in this file.
_logger = logging.getLogger(__name__)

# Value every element of a `_simple_execution` result must equal for the slice
# to be reported active. The probe feeds in this value minus one and adds one.
_SIMPLE_EXECUTION_TEST_VALUE = 100
# Error kinds that, when present in the message of a
# `jax.errors.JaxRuntimeError`, mark the error as due to slice down.
_ELASTIC_DOWN_ERROR_TYPES = [
    "DATA_LOSS",
]
# Error kinds that may or may not be caused by slice down.
# `is_error_due_to_slice_down` still treats them as slice down, but logs a
# warning and raises the traceback logging level to WARNING.
_ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES = [
    "DEADLINE_EXCEEDED",
    "NOT_FOUND",
    "INTERNAL",
]
43+
44+
45+
def _plus_one(x: jax.Array) -> jax.Array:
  """Increments every element of `x` by one.

  This is the computation run on each device to test if a slice is active.

  Args:
    x: The input array.

  Returns:
    An array holding the elements of `x`, each increased by one.
  """
  incremented = x + 1
  return incremented
57+
58+
59+
def _simple_execution(devices: Sequence[jax.Device]) -> jax.Array:
  """Runs a trivial computation on `devices` to probe whether they are active.

  One input element per device is set to `_SIMPLE_EXECUTION_TEST_VALUE - 1`
  and `_plus_one` is applied to it through `jax.pmap`. If any of the devices
  are not active, using the returned array fails with a `JaxRuntimeError`.

  Calling this function is not sufficient to decide that a slice is active;
  the caller must also check the value of the returned array.

  Args:
    devices: The devices to execute on.

  Returns:
    The result of the execution, one element per device.

  Raises:
    ValueError: If `devices` is empty.
  """
  if not devices:
    raise ValueError("No devices")

  # One element per device, each exactly one below the expected result.
  zeros_per_device = np.zeros(len(devices), dtype=float)
  probe_input = zeros_per_device + (_SIMPLE_EXECUTION_TEST_VALUE - 1)

  probe = jax.pmap(_plus_one, devices=devices)
  return probe(probe_input)
83+
84+
85+
def get_slice_to_devices(
    devices: Sequence[jax.Device],
) -> dict[int, Sequence[jax.Device]]:
  """Groups `devices` by the index of the slice each one belongs to.

  Args:
    devices: The devices to group. Each one is read through its `slice_index`
      attribute.

  Returns:
    A dict mapping every slice index seen to the list of its devices, in the
    order the devices were given.
  """
  grouped: dict[int, list[jax.Device]] = {}
  for device in devices:
    grouped.setdefault(device.slice_index, []).append(device)
  return grouped
93+
94+
95+
@timing.timeit
def get_active_slice_indices(
    slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None,
) -> set[int]:
  """Probes every slice and returns the indices of the active ones.

  A `_simple_execution` probe is started on every slice before any result is
  waited on; the results are then checked one slice at a time. A slice whose
  probe fails with a `jax.errors.JaxRuntimeError` that
  `is_error_due_to_slice_down` accepts is left out of the result.

  Args:
    slice_to_devices: A mapping from slice index to devices. If None,
      `get_slice_to_devices(tuple(jax.devices()))` is used.

  Returns:
    The set of active slice indices.

  Raises:
    ValueError: If a probe completes but its value is not the expected one.
    jax.errors.JaxRuntimeError: If a probe fails with an error that is not
      considered due to slice down.
  """
  if slice_to_devices is None:
    slice_to_devices = get_slice_to_devices(tuple(jax.devices()))

  # Start the probe on every slice before inspecting any of the results.
  pending = {}
  for slice_index, slice_devices in slice_to_devices.items():
    pending[slice_index] = _simple_execution(slice_devices)

  active_slice_indices = set()

  for slice_index, probe_output in pending.items():
    _logger.info("Checking slice_index=%s", slice_index)
    device_count = len(slice_to_devices[slice_index])
    expected = (
        np.zeros(device_count, dtype=float) + _SIMPLE_EXECUTION_TEST_VALUE
    )
    try:
      with timing.Timer(f"Checking {slice_index=}"):
        jax.block_until_ready(probe_output)
      if not np.allclose(probe_output, expected):
        _logger.error(
            "Error with _simple_execution for slice_index=%s. "
            "This should never happen. Expected: %s, Actual: %s",
            slice_index,
            expected,
            probe_output,
        )
        raise ValueError(
            f"Error with _simple_execution for slice_index={slice_index}."
        )
      active_slice_indices.add(slice_index)
      _logger.info("slice_index=%s active", slice_index)
    except jax.errors.JaxRuntimeError as error:
      # Only errors attributed to slice down are tolerated; anything else
      # propagates to the caller.
      if is_error_due_to_slice_down(error):
        _logger.info("slice_index=%s bad", slice_index)
      else:
        raise

  _logger.info("active_slice_indices=%s", active_slice_indices)

  return active_slice_indices
141+
142+
143+
def wait_for_slices(
    slice_count: int,
    poll_interval: float | int = 10,
    timeout: float | int | None = None,
    slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None,
) -> set[int]:
  """Waits until after at least `slice_count` slices become active.

  Args:
    slice_count: The number of slices to wait for.
    poll_interval: The minimum number of seconds to wait between availability
      checks. If the check takes longer than this, the next check will start
      immediately after the current check completes. Defaults to 10 seconds.
    timeout: The maximum number of seconds to wait. If None, there is no
      timeout.
    slice_to_devices: A mapping from slice index to devices. If None,
      `get_slice_to_devices(jax.devices())` is used.

  Returns:
    The active slice indices.

  Raises:
    TimeoutError: If the timeout is reached before the slices become
      active. This is raised as soon as the next check would start at or
      after the timeout, without sleeping first.
  """
  if slice_to_devices is None:
    slice_to_devices = get_slice_to_devices(jax.devices())

  # Use a monotonic clock so that elapsed-time arithmetic stays correct even
  # if the system wall clock is adjusted while waiting.
  start_time = time.monotonic()

  while True:
    check_start_time = time.monotonic()

    active_slice_indices = get_active_slice_indices(slice_to_devices)
    if len(active_slice_indices) >= slice_count:
      _logger.info("%s slices active.", len(active_slice_indices))
      return active_slice_indices

    _logger.info(
        "%s slices active. Wanting at least %s.",
        len(active_slice_indices),
        slice_count,
    )

    # The time spent checking counts towards the poll interval.
    check_duration = time.monotonic() - check_start_time
    time_to_sleep = max(0, poll_interval - check_duration)

    if timeout is not None:
      elapsed_time = time.monotonic() - start_time
      # Give up now if the next check could only start at or after the
      # timeout.
      if elapsed_time + time_to_sleep >= timeout:
        raise TimeoutError(
            f"Timed out waiting for {slice_count} slices. Only"
            f" {len(active_slice_indices)} active after"
            f" {elapsed_time:.2f} seconds."
            f" Next check would occur after the timeout of {timeout}"
            " seconds."
        )

    if time_to_sleep > 0:
      _logger.info("Sleeping for %.2f seconds.", time_to_sleep)

      time.sleep(time_to_sleep)
206+
207+
208+
def is_error_due_to_slice_down(error: Exception) -> bool:
  """Returns True if the error is due to slice down.

  The error types that are considered due to slice down are
  jax.errors.JaxRuntimeError with the following error kind in the message:
  - DATA_LOSS
  - DEADLINE_EXCEEDED
  - NOT_FOUND
  - INTERNAL

  The traceback of `error` is logged at DEBUG level, or at WARNING level when
  the error matched one of the kinds that may or may not be due to slice down
  (`_ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES`).

  Args:
    error: The error to check.

  Returns:
    True if `error` is a `jax.errors.JaxRuntimeError` whose message contains
    one of the error kinds above, False otherwise.
  """
  error_due_to_slice_down = False
  traceback_logging_level = logging.DEBUG

  if isinstance(error, jax.errors.JaxRuntimeError):
    # Render the message once; it is matched against both lists of kinds.
    error_message = str(error)

    if any(
        error_type in error_message for error_type in _ELASTIC_DOWN_ERROR_TYPES
    ):
      _logger.info("Caught an error due to slice down")

      error_due_to_slice_down = True

    elif any(
        error_type in error_message
        for error_type in _ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES
    ):
      _logger.warning(
          "Caught an error due that may or may not be due to slice down. This"
          " error will be treated as due to slice down."
      )
      traceback_logging_level = logging.WARNING

      error_due_to_slice_down = True

  if not error_due_to_slice_down:
    _logger.info("Caught an error not due to slice down")

  # Formatting a traceback is not free, so only do it when the record will
  # actually be emitted.
  if _logger.isEnabledFor(traceback_logging_level):
    # Every string returned by format_exception already ends in a newline, so
    # the pieces are concatenated directly; joining them with "\n" would
    # double-space the logged traceback.
    _logger.log(
        traceback_logging_level,
        "%s",
        "".join(traceback.format_exception(error)),
    )

  return error_due_to_slice_down

0 commit comments

Comments
 (0)