55import logging
66import threading
77import time
8- from collections import defaultdict
8+ from collections import OrderedDict , defaultdict
99from typing import Any , Sequence
1010
11- import celery .events # type: ignore[import]
12- import celery .events .state # type: ignore[import]
11+ import celery # type: ignore[import]
1312
1413from .timer import RepeatTimer
1514
1615logger = logging .getLogger (__name__ )
1716
1817
1918class EventWatcher :
19+ _TASK_NAMES_CACHE_LIMIT = 100_000
20+
2021 last_received_timestamp : datetime .datetime | None
2122 last_received_timestamp_per_task_event : dict [tuple [str , str ], datetime .datetime ]
2223 num_events_per_task_count : dict [tuple [str , str ], int ]
@@ -29,10 +30,9 @@ class EventWatcher:
2930 def create_started (
3031 cls ,
3132 app : celery .Celery ,
32- state : celery .events .state .State ,
3333 buckets : Sequence [float | str ],
3434 ):
35- store = cls (state , buckets )
35+ store = cls (buckets )
3636
3737 def run () -> None :
3838 backoff = 1.0
@@ -64,10 +64,8 @@ def update_enable_event() -> None:
6464
6565 return store
6666
67- def __init__ (
68- self , state : celery .events .state .State , buckets : Sequence [float | str ]
69- ):
70- self ._state = state
67+ def __init__ (self , buckets : Sequence [float | str ]):
68+ self ._task_names_by_uuid : OrderedDict [str , str ] = OrderedDict ()
7169
7270 self .upper_bounds = [float (b ) for b in buckets ]
7371 if self .upper_bounds and self .upper_bounds [- 1 ] != float ("inf" ):
@@ -87,13 +85,18 @@ def on_event(self, event: dict[str, Any]):
8785 now = datetime .datetime .now (tz = datetime .UTC )
8886 self .last_received_timestamp = now
8987
90- self ._state .event (event )
9188 event_name : str = event ["type" ]
9289 if not event_name .startswith ("task-" ):
9390 return
9491
95- task : celery .events .Task = self ._state .get_or_create_task (event ["uuid" ])[0 ]
96- task_name = task .name or "(UNKNOWN)"
92+ uuid : str = event ["uuid" ]
93+ if "name" in event :
94+ self ._task_names_by_uuid .pop (uuid , None )
95+ self ._task_names_by_uuid [uuid ] = event ["name" ]
96+ while len (self ._task_names_by_uuid ) > self ._TASK_NAMES_CACHE_LIMIT :
97+ self ._task_names_by_uuid .popitem (last = False )
98+
99+ task_name = self ._task_names_by_uuid .get (uuid , "(UNKNOWN)" )
97100 self .task_names .add (task_name )
98101
99102 self .last_received_timestamp_per_task_event [(task_name , event_name )] = now
0 commit comments