Skip to content

Commit 9bdab58

Browse files
Chen, VivienChen, Vivien
authored andcommitted
Fix Redis MI auth: call super().on_connect() BEFORE AUTH command, add debug logging
1 parent ba339eb commit 9bdab58

1 file changed

Lines changed: 48 additions & 28 deletions

File tree

application/single_app/app.py

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -136,22 +136,32 @@ def __init__(self, *args, **kwargs):
136136
super().__init__(*args, **kwargs)
137137

138138
def on_connect(self):
139-
super().on_connect()
140139
if hasattr(ManagedIdentityConnection, '_credential') and ManagedIdentityConnection._credential:
141140
# Get fresh token and extract username
142-
token = ManagedIdentityConnection._credential.get_token(ManagedIdentityConnection._token_scope)
143-
import base64, json
144-
token_parts = token.token.split('.')
145-
if len(token_parts) > 1:
146-
payload = token_parts[1] + '=' * (4 - len(token_parts[1]) % 4)
147-
token_data = json.loads(base64.urlsafe_b64decode(payload))
148-
username = token_data.get('oid', '')
149-
else:
150-
username = ''
151-
# Authenticate with fresh token
152-
if username and token.token:
153-
self.send_command('AUTH', username, token.token)
154-
self.read_response()
141+
try:
142+
token = ManagedIdentityConnection._credential.get_token(ManagedIdentityConnection._token_scope)
143+
import base64, json
144+
token_parts = token.token.split('.')
145+
if len(token_parts) > 1:
146+
payload = token_parts[1] + '=' * (4 - len(token_parts[1]) % 4)
147+
token_data = json.loads(base64.urlsafe_b64decode(payload))
148+
username = token_data.get('oid', '')
149+
else:
150+
username = ''
151+
152+
# Call parent to establish connection
153+
super().on_connect()
154+
155+
# Authenticate with fresh token after connection established
156+
if username and token.token:
157+
self.send_command('AUTH', username, token.token)
158+
auth_response = self.read_response()
159+
print(f"Redis AUTH response (startup): {auth_response}")
160+
except Exception as e:
161+
print(f"Redis MI auth error in on_connect (startup): {e}")
162+
raise
163+
else:
164+
super().on_connect()
155165

156166
# Create connection pool with custom connection class
157167
pool = ConnectionPool(
@@ -253,22 +263,32 @@ def __init__(self, *args, **kwargs):
253263
super().__init__(*args, **kwargs)
254264

255265
def on_connect(self):
256-
super().on_connect()
257266
if hasattr(ManagedIdentityConnection, '_credential') and ManagedIdentityConnection._credential:
258267
# Get fresh token and extract username
259-
token = ManagedIdentityConnection._credential.get_token(ManagedIdentityConnection._token_scope)
260-
import base64, json
261-
token_parts = token.token.split('.')
262-
if len(token_parts) > 1:
263-
payload = token_parts[1] + '=' * (4 - len(token_parts[1]) % 4)
264-
token_data = json.loads(base64.urlsafe_b64decode(payload))
265-
username = token_data.get('oid', '')
266-
else:
267-
username = ''
268-
# Authenticate with fresh token
269-
if username and token.token:
270-
self.send_command('AUTH', username, token.token)
271-
self.read_response()
268+
try:
269+
token = ManagedIdentityConnection._credential.get_token(ManagedIdentityConnection._token_scope)
270+
import base64, json
271+
token_parts = token.token.split('.')
272+
if len(token_parts) > 1:
273+
payload = token_parts[1] + '=' * (4 - len(token_parts[1]) % 4)
274+
token_data = json.loads(base64.urlsafe_b64decode(payload))
275+
username = token_data.get('oid', '')
276+
else:
277+
username = ''
278+
279+
# Call parent to establish connection
280+
super().on_connect()
281+
282+
# Authenticate with fresh token after connection established
283+
if username and token.token:
284+
self.send_command('AUTH', username, token.token)
285+
auth_response = self.read_response()
286+
print(f"Redis AUTH response (before-first-request): {auth_response}")
287+
except Exception as e:
288+
print(f"Redis MI auth error in on_connect (before-first-request): {e}")
289+
raise
290+
else:
291+
super().on_connect()
272292

273293
# Create connection pool with custom connection class
274294
pool = ConnectionPool(

0 commit comments

Comments
 (0)