Skip to content

Commit 2922851

Browse files
committed
More typing
1 parent 02eb089 commit 2922851

16 files changed

Lines changed: 94 additions & 83 deletions

File tree

.github/workflows/ci.yml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,22 @@ jobs:
1313
strategy:
1414
max-parallel: 4
1515
matrix:
16-
python-version: ["3.9", "3.10", "3.11", "3.12"]
16+
python-version: ["3.10", "3.11", "3.12", "3.13"]
1717
steps:
18-
- uses: actions/checkout@v4
18+
- uses: actions/checkout@v5
1919

2020
- name: Set up Python ${{ matrix.python-version }}
21-
uses: actions/setup-python@v5
21+
uses: actions/setup-python@v6
2222
with:
2323
python-version: ${{ matrix.python-version }}
2424
- name: Install dependencies
2525
run: |
2626
python -m pip install --upgrade pip
27-
pip install pytest pycodestyle
27+
pip install pytest pycodestyle pyright
2828
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
2929
- name: Lint with PyCodeStyle
3030
run: |
3131
find . -name \*.py -exec pycodestyle {} +
32+
- name: PyRight
33+
run: |
34+
pyright

aiogcd/connector/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
Created on: May 19, 2017
44
Author: Jeroen van der Heijden <jeroen@cesbit.com>
55
"""
6-
from .connector import GcdConnector, GcdServiceAccountConnector
7-
from .client_token import Token
8-
from .service_account_token import ServiceAccountToken
6+
from .connector import GcdConnector, GcdServiceAccountConnector # noqa: F401
7+
from .client_token import Token # noqa: F401
8+
from .service_account_token import ServiceAccountToken # noqa: F401

aiogcd/connector/client_token.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ async def get(self) -> str:
9595
9696
:return: Access token (string)
9797
"""
98+
assert self._refresh_ts is not None and self._token is not None
9899
async with self._lock:
99100
if self._refresh_ts < time.time():
100101
await self._refresh_token()
@@ -127,6 +128,7 @@ async def connect(self):
127128
logging.info('Token is valid.')
128129

129130
async def _refresh_token(self):
131+
assert self._token is not None
130132
logging.info(
131133
'Token has exceeded half of the expiration time, '
132134
'starting a token refresh.')

aiogcd/connector/connector.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import os
88
import json
99
import aiohttp
10-
from typing import Iterable, Optional, Any, Union
10+
from typing import Iterable, Any
1111
from .client_token import Token
1212
from .service_account_token import ServiceAccountToken
1313
from .entity import Entity
@@ -43,7 +43,7 @@ def __init__(
4343
client_secret: str,
4444
token_file: str,
4545
scopes: Iterable[str] = DEFAULT_SCOPES,
46-
namespace_id: Optional[str] = None):
46+
namespace_id: str | None = None):
4747

4848
self.project_id = project_id
4949
self.namespace_id = namespace_id
@@ -220,7 +220,7 @@ async def run_query(self, data) -> list[dict]:
220220
results, _ = await self._run_query(data)
221221
return results
222222

223-
async def _run_query(self, data) -> tuple[list[dict], Optional[str]]:
223+
async def _run_query(self, data) -> tuple[list[dict], str| None]:
224224
results = []
225225
cursor = None
226226

@@ -281,7 +281,7 @@ async def _run_query(self, data) -> tuple[list[dict], Optional[str]]:
281281
return results, cursor
282282

283283
async def _get_entities_cursor(self, data) -> \
284-
tuple[list[Entity], Optional[str]]:
284+
tuple[list[Entity], str | None]:
285285
results, cursor = await self._run_query(data)
286286
return [Entity(result['entity']) for result in results], cursor
287287

@@ -301,7 +301,7 @@ async def get_keys(self, data) -> list[Key]:
301301
results, _ = await self._run_query(data)
302302
return [Key(result['entity']['key']) for result in results]
303303

304-
async def get_entity(self, data) -> Optional[Entity]:
304+
async def get_entity(self, data) -> Entity | None:
305305
"""Return an entity object by given query data.
306306
307307
:param data: see the following link for the data format:
@@ -313,18 +313,18 @@ async def get_entity(self, data) -> Optional[Entity]:
313313
result = await self.get_entities(data)
314314
return result[0] if result else None
315315

316-
async def get_key(self, data) -> Optional[Key]:
316+
async def get_key(self, data) -> Key | None:
317317
data['query']['limit'] = 1
318318
result = await self.get_keys(data)
319319
return result[0] if result else None
320320

321-
async def get_entities_by_kind(self, kind: str,
322-
offset: Optional[int] = None,
323-
limit: Optional[int] = None,
324-
cursor: Optional[str] = None) -> Union[
325-
list[Entity],
326-
tuple[list[Entity], Optional[str]]
327-
]:
321+
async def get_entities_by_kind(
322+
self,
323+
kind: str,
324+
offset: int | None = None,
325+
limit: int | None = None,
326+
cursor: str | None = None
327+
) -> list[Entity] | tuple[list[Entity], str | None]:
328328
"""Returns entities by kind.
329329
330330
When a limit is set, this function returns a list and a cursor.
@@ -344,8 +344,8 @@ async def get_entities_by_kind(self, kind: str,
344344
return await self._get_entities_cursor(data)
345345

346346
async def get_entities_by_keys(self, keys: Iterable[Key],
347-
missing: Optional[list[Any]] = None,
348-
deferred: Optional[list[Key]] = None,
347+
missing: list[Any] | None = None,
348+
deferred: list[Key] | None = None,
349349
eventual: bool = False) -> list[Entity]:
350350
"""Returns entity objects for the given keys or an empty list in case
351351
no entity is found. The order of entities might not be equal to the
@@ -398,9 +398,9 @@ def data():
398398
return entities
399399

400400
async def get_entity_by_key(self, key: Key,
401-
missing: Optional[list[Any]] = None,
402-
deferred: Optional[list[Key]] = None,
403-
eventual: bool = False) -> Optional[Entity]:
401+
missing: list[Any] | None = None,
402+
deferred: list[Key] | None = None,
403+
eventual: bool = False) -> Entity | None:
404404
"""Returns an entity object for the given key or None in case no
405405
entity is found.
406406
@@ -447,9 +447,9 @@ def __init__(
447447
self,
448448
project_id: str,
449449
service_file: str,
450-
session: Optional[aiohttp.ClientSession] = None,
451-
scopes: Optional[Iterable[str]] = None,
452-
namespace_id: Optional[str] = None):
450+
session: aiohttp.ClientSession | None = None,
451+
scopes: Iterable[str] | None = None,
452+
namespace_id: str | None = None):
453453

454454
scopes = scopes or list(DEFAULT_SCOPES)
455455
self.project_id = project_id

aiogcd/connector/decoder.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ def __new__(cls, *args, ks):
2525
ks += b'=' * (4 - len(ks) % 4)
2626
ks = base64.b64decode(ks.replace(b'-', b'+').replace(b'_', b'/'))
2727
decoder.frombytes(ks)
28-
decoder._idx = 0
29-
decoder.set_end()
28+
decoder._idx = 0 # type: ignore
29+
decoder.set_end() # type: ignore
3030
return decoder
3131

3232
def set_end(self, end=None):
@@ -35,13 +35,13 @@ def set_end(self, end=None):
3535
self._end = len(self) if end is None else self._idx + end
3636

3737
def __bool__(self):
38-
return self._idx < self._end
38+
return self._idx < self._end # type: ignore
3939

4040
def _get8(self):
4141
if not self:
4242
raise BufferDecodeError('truncated')
43-
c = self[self._idx]
44-
self._idx += 1
43+
c = self[self._idx] # type: ignore
44+
self._idx += 1 # type: ignore
4545
return c
4646

4747
def get_var_int32(self):

aiogcd/connector/entity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(self, entity_res: dict):
3737
https://cloud.google.com/datastore/docs/reference/rest/v1/Entity
3838
"""
3939
self.key = Key(entity_res['key'])
40-
self._properties = set()
40+
self._properties: set[str] = set()
4141

4242
for prop, val in entity_res['properties'].items():
4343
self._properties.add(prop)

aiogcd/connector/key.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
Author: Jeroen van der Heijden <jeroen@cesbit.com>
55
"""
66
import base64
7-
from typing import Optional
87
from .buffer import Buffer
98
from .buffer import BufferDecodeError
109
from .path import Path
@@ -40,11 +39,12 @@ class Key:
4039
Key(path=Path(...), project_id="my-project-id")
4140
"""
4241
_ks = None
42+
path: Path
4343

44-
def __init__(self, *args, ks: Optional[str] = None,
45-
path: Optional[Path] = None,
46-
project_id: Optional[str] = None,
47-
namespace_id: Optional[str] = None):
44+
def __init__(self, *args, ks: str | None = None,
45+
path: Path | None = None,
46+
project_id: str | None = None,
47+
namespace_id: str | None = None):
4848
if len(args) == 1 and isinstance(args[0], dict):
4949
assert ks is None and path is None and project_id is None, \
5050
self.KEY_INIT_MSG
@@ -96,16 +96,17 @@ def encode(self):
9696
this method for generating an urlsafe key string.
9797
"""
9898
buffer = Buffer()
99-
buffer.add_var_int32(106)
99+
buffer.add_var_int32(106) # type: ignore
100100

101101
# The project id in a key string is prefixed with s~
102-
buffer.add_prefixed_string('s~{}'.format(self.project_id))
102+
buffer.add_prefixed_string( # type: ignore
103+
's~{}'.format(self.project_id))
103104

104105
self.path.encode(buffer)
105106

106107
if self.namespace_id:
107-
buffer.add_var_int32(162)
108-
buffer.add_prefixed_string(self.namespace_id)
108+
buffer.add_var_int32(162) # type: ignore
109+
buffer.add_prefixed_string(self.namespace_id) # type: ignore
109110

110111
return buffer
111112

@@ -134,54 +135,55 @@ def kind(self):
134135
return self.path[-1].kind
135136

136137
@property
137-
def id(self):
138+
def id(self) -> int | str:
138139
"""Shortcut for .path[-1].id"""
139140
return self.path[-1].id
140141

141142
@staticmethod
142-
def _extract_id_or_name(pair):
143+
def _extract_id_or_name(pair) -> int | str:
143144
"""Used on __init__."""
144145
if 'id' in pair:
145146
return int(pair['id'])
146147

147148
if 'name' in pair:
148149
return pair['name']
149150

150-
return None
151+
return ''
151152

152153
@staticmethod
153-
def _deserialize_ks(ks: str):
154+
def _deserialize_ks(ks: str) -> tuple[str | None, str | None, Path]:
154155
"""Returns a tuple with the project_id, namespace_id and Path
155156
from a key string."""
156157

157158
decoder = Decoder(ks=ks)
158-
project_id = None
159-
namespace_id = None
160-
path = None
159+
project_id: str | None = None
160+
namespace_id: str | None = None
161+
path: Path | None = None
161162

162163
while decoder:
163-
tt = decoder.get_var_int32()
164+
tt = decoder.get_var_int32() # type: ignore
164165

165166
if tt == 106:
166167
# The project id in a key string is prefixed with s~ which is
167168
# not part of the real project id.
168-
project_id = decoder.get_prefixed_string()[2:]
169+
project_id = decoder.get_prefixed_string()[2:] # type: ignore
169170
continue
170171

171172
if tt == 114:
172-
sz = decoder.get_var_int32()
173-
decoder.set_end(sz)
173+
sz = decoder.get_var_int32() # type: ignore
174+
decoder.set_end(sz) # type: ignore
174175
path = path_from_decoder(decoder)
175-
decoder.set_end()
176+
decoder.set_end() # type: ignore
176177
continue
177178

178179
if tt == 162:
179-
namespace_id = decoder.get_prefixed_string()
180+
namespace_id = decoder.get_prefixed_string() # type: ignore
180181
continue
181182

182183
if tt == 0:
183184
raise BufferDecodeError('corrupt')
184185

186+
assert path
185187
return project_id, namespace_id, path
186188

187189
def get_parent(self):

aiogcd/connector/path.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,17 @@
33
Created on: May 19, 2017
44
Author: Jeroen van der Heijden <jeroen@cesbit.com>
55
"""
6-
from typing import Iterable, Union
6+
from typing import Iterable
77
from .pathelement import PathElement
88
from .pathelement import path_element_from_decoder
99
from .buffer import BufferDecodeError
1010

1111

1212
class Path:
1313

14-
def __init__(self, pairs: Union[
15-
Iterable[PathElement],
16-
Iterable[tuple[str, Union[int, str]]]]):
17-
self._path: tuple[PathElement] = tuple(
14+
def __init__(self, pairs: Iterable[PathElement] | Iterable[
15+
tuple[str, int | str]]):
16+
self._path: tuple[PathElement, ...] = tuple(
1817
pe if isinstance(pe, PathElement) else PathElement(*pe)
1918
for pe in pairs)
2019

@@ -43,7 +42,7 @@ def byte_size(self) -> int:
4342
n += path_element.byte_size
4443
return n
4544

46-
def get_as_tuple(self) -> tuple[tuple[str, Union[str, int]], ...]:
45+
def get_as_tuple(self) -> tuple[tuple[str, str | int], ...]:
4746
"""Returns a tuple of pairs (tuples) representing the key path of an
4847
entity. Useful for composing entities with a specific ancestor."""
4948
return tuple((pe.kind, pe.id) for pe in self._path)

aiogcd/connector/service_account_token.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ async def ensure_token(self):
7272

7373
else:
7474
now = datetime.datetime.now()
75+
assert self.access_token_acquired_at is not None
76+
assert self.access_token_duration is not None
7577
delta = (now - self.access_token_acquired_at).total_seconds()
7678
if delta > self.access_token_duration / 2:
7779
self.acquiring = asyncio.ensure_future(
@@ -93,6 +95,7 @@ async def _acquire_access_token(self):
9395
return True
9496

9597
async def _acquire_token(self):
98+
assert self.service_data is not None
9699
assertion = self._generate_assertion()
97100
url = self.service_data['token_uri']
98101

@@ -115,6 +118,7 @@ async def _acquire_token(self):
115118
return json
116119

117120
def _generate_assertion(self):
121+
assert self.service_data is not None
118122
payload = self._make_gcloud_oauth_body(
119123
)
120124

@@ -127,6 +131,7 @@ def _generate_assertion(self):
127131
return jwt_token
128132

129133
def _make_gcloud_oauth_body(self):
134+
assert self.service_data is not None
130135
uri = self.service_data['token_uri']
131136
client_email = self.service_data['client_email']
132137

aiogcd/connector/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def value_to_dict(val):
3434
.decode('utf-8')}
3535
return {'stringValue': val}
3636
if isinstance(val, Key):
37-
return {'keyValue': val.get_dict()}
37+
return {'keyValue': val.get_dict()} # type: ignore
3838
if isinstance(val, list):
3939
return {'arrayValue': {'values': [value_to_dict(v) for v in val]}}
4040
if isinstance(val, TimestampValue):

0 commit comments

Comments
 (0)