|
1 | 1 | from collections import deque |
| 2 | +from collections.abc import Callable |
2 | 3 | from dataclasses import dataclass, field |
3 | 4 | from datetime import datetime, timedelta |
| 5 | +from functools import partial |
4 | 6 | from typing import cast |
5 | 7 |
|
6 | 8 | from apscheduler.triggers.base import BaseTrigger |
@@ -33,6 +35,11 @@ def _limit_dep_wrapper(limit: int | _DependentCallable[int]) -> _DependentCallab |
33 | 35 | limit_dep = limit |
34 | 36 | return limit_dep |
35 | 37 |
|
| 38 | +def inject_increaser(state: T_State, func: Callable): |
| 39 | + executors = state.setdefault("plugin_limiter:increaser", []) |
| 40 | + assert isinstance(executors, list) |
| 41 | + executors.append(func) |
| 42 | + |
36 | 43 | # region: FixWindow |
37 | 44 | @dataclass |
38 | 45 | class FixWindowUsage: |
@@ -139,21 +146,27 @@ async def _limiter_dependency( |
139 | 146 | bucket[entity_id] = FixWindowUsage(now, limit) |
140 | 147 | usage = bucket[entity_id] |
141 | 148 |
|
142 | | - if usage.available > 0: |
| 149 | + def _increase_action(reset: bool = True): |
| 150 | + if reset: |
| 151 | + usage.start_time = now |
| 152 | + usage.available = limit |
143 | 153 | usage.available -= 1 |
| 154 | + |
| 155 | + if usage.available > 0: |
| 156 | + if set_increaser: |
| 157 | + inject_increaser(state, partial(_increase_action, False)) |
| 158 | + else: |
| 159 | + _increase_action(False) |
144 | 160 | return |
145 | 161 |
|
146 | 162 | # Calculate reset time based on when the limitation was set |
147 | 163 | reset_time = trigger.get_next_fire_time(usage.start_time, now) |
148 | 164 | assert reset_time is not None, "reset_time should not be None" |
149 | 165 |
|
150 | | - def _increase_action(): |
151 | | - usage.start_time = now |
152 | | - usage.available = limit - 1 |
153 | | - |
| 166 | + # Reset |
154 | 167 | if now >= reset_time: |
155 | 168 | if set_increaser: |
156 | | - state["plugin_limiter:increaser"] = _increase_action |
| 169 | + inject_increaser(state, _increase_action) |
157 | 170 | else: |
158 | 171 | _increase_action() |
159 | 172 | return # Didn't exceed |
@@ -273,7 +286,7 @@ def _increase_action(): |
273 | 286 |
|
274 | 287 | if len(usage.timestamps) < limit: |
275 | 288 | if set_increaser: |
276 | | - state["plugin_limiter:increaser"] = _increase_action |
| 289 | + inject_increaser(state, _increase_action) |
277 | 290 | else: |
278 | 291 | _increase_action() |
279 | 292 | return # Didn't exceed |
|
0 commit comments