diff --git a/q2_annotate/kraken2/database.py b/q2_annotate/kraken2/database.py
index 30073d30..27596357 100644
--- a/q2_annotate/kraken2/database.py
+++ b/q2_annotate/kraken2/database.py
@@ -175,7 +175,7 @@ def _build_bracken_database(
)
-def _find_latest_db(collection: str, response: requests.Response) -> str:
+def _find_latest_db(s3_objects: list, collection: str) -> str:
collection_id = COLLECTIONS[collection]
if collection in S16_DBS:
@@ -183,16 +183,7 @@ def _find_latest_db(collection: str, response: requests.Response) -> str:
else:
pattern = rf"kraken\/k2_{collection_id}_\d{{8}}.tar.gz"
- s3_objects = xmltodict.parse(response.content)
- s3_objects = s3_objects.get("ListBucketResult")
- if not s3_objects:
- raise ValueError(
- "No databases were found in the response returned by S3. "
- "Please try again."
- )
- s3_objects = [
- obj for obj in s3_objects["Contents"] if re.match(pattern, obj["Key"])
- ]
+ s3_objects = [obj for obj in s3_objects if re.match(pattern, obj["Key"])]
s3_objects = sorted(s3_objects, key=lambda x: x["LastModified"], reverse=True)
latest_db = s3_objects[0]["Key"]
@@ -211,21 +202,42 @@ def _fetch_db_collection(collection: str, tmp_dir: str):
"Could not connect to the server. Please check your internet "
"connection and try again. The error was: {}."
)
- try:
- response = requests.get(S3_COLLECTIONS_URL)
- except requests.exceptions.ConnectionError as e:
- raise ValueError(err_msg.format(e))
+ continuation_token = None
+ s3_objects = []
- if response.status_code == 200:
- latest_db = _find_latest_db(collection, response)
- print(f'Found the latest "{collection}" database: {latest_db}.')
- else:
+ while True:
+ params = {"list-type": "2"}
+ if continuation_token:
+ params["continuation-token"] = continuation_token
+
+ try:
+ response = requests.get(S3_COLLECTIONS_URL, params=params)
+ response.raise_for_status()
+ except requests.RequestException as e:
+ raise ValueError(f"Could not fetch the list of available databases: {e}")
+
+ data = xmltodict.parse(response.content).get("ListBucketResult", {})
+
+ # Safely extract contents (handles single-item dict vs list edge cases)
+ contents = data.get("Contents", [])
+ if isinstance(contents, dict):
+ contents = [contents]
+ s3_objects.extend(contents)
+
+ # Clean pagination check: if a token exists and it's truncated, keep going
+ if data.get("IsTruncated") == "true" and "NextContinuationToken" in data:
+ continuation_token = data["NextContinuationToken"]
+ else:
+ break
+
+ if not s3_objects:
raise ValueError(
- "Could not fetch the list of available databases. "
- f"Status code was: {response.status_code}. "
- "Please try again later."
+ "No databases were found in the response returned by S3. Please try again."
)
+ latest_db = _find_latest_db(s3_objects, collection)
+ print(f'Found the latest "{collection}" database: {latest_db}.')
+
db_uri = f"{S3_COLLECTIONS_URL}/{latest_db}"
try:
response = requests.get(db_uri, stream=True)
diff --git a/q2_annotate/kraken2/tests/test_database.py b/q2_annotate/kraken2/tests/test_database.py
index 7dc38a40..4a42c30f 100644
--- a/q2_annotate/kraken2/tests/test_database.py
+++ b/q2_annotate/kraken2/tests/test_database.py
@@ -12,10 +12,11 @@
import tempfile
import unittest
from copy import deepcopy
+import requests
from requests.exceptions import ConnectionError
from subprocess import CalledProcessError
from tempfile import TemporaryDirectory
-from unittest.mock import patch, ANY, call, Mock, MagicMock
+from unittest.mock import patch, ANY, call, MagicMock
from parameterized import parameterized
import pandas as pd
@@ -345,24 +346,38 @@ def test_build_bracken_database_exception(self, p1):
)
def test_find_latest_db(self):
- response = Mock(content=self.s3_response)
-
- obs_db = _find_latest_db("viral", response)
+ s3_objects = [
+ {
+ "Key": "kraken/k2_viral_20201202.tar.gz",
+ "LastModified": "2020-12-09T01:38:22.000Z",
+ },
+ {
+ "Key": "kraken/k2_viral_20230314.tar.gz",
+ "LastModified": "2023-03-22T01:29:11.000Z",
+ },
+ ]
+ obs_db = _find_latest_db(s3_objects, "viral")
exp_db = "kraken/k2_viral_20230314.tar.gz"
self.assertEqual(obs_db, exp_db)
def test_find_latest_16S_db(self):
- response = Mock(content=self.s3_response_16S)
-
- obs_db = _find_latest_db("greengenes", response)
+ s3_objects = [
+ {
+ "Key": "kraken/16S_Greengenes13.5_20200326.tgz",
+ "LastModified": "2020-03-26T01:38:22.000Z",
+ },
+ {
+ "Key": "kraken/16S_Greengenes13.5_20200326.tgz",
+ "LastModified": "2024-02-12T01:29:11.000Z",
+ },
+ ]
+ obs_db = _find_latest_db(s3_objects, "greengenes")
exp_db = "kraken/16S_Greengenes13.5_20200326.tgz"
self.assertEqual(obs_db, exp_db)
def test_find_latest_db_empty(self):
- response = Mock(content=b"""""")
-
- with self.assertRaisesRegex(ValueError, r"No databases were found.+"):
- _find_latest_db("viral", response)
+ with self.assertRaises(IndexError):
+ _find_latest_db([], "viral")
@patch("requests.get")
@patch("tarfile.open")
@@ -375,7 +390,7 @@ def test_fetch_db_collection_success(
self, mock_tqdm, mock_find, mock_tarfile_open, mock_requests_get
):
mock_requests_get.side_effect = [
- MagicMock(status_code=200),
+ MagicMock(status_code=200, content=self.s3_response),
MagicMock(
status_code=200,
iter_content=lambda chunk_size: self.tar_chunks,
@@ -388,12 +403,12 @@ def test_fetch_db_collection_success(
mock_requests_get.assert_has_calls(
[
- call(S3_COLLECTIONS_URL),
+ call(S3_COLLECTIONS_URL, params={"list-type": "2"}),
call(f"{S3_COLLECTIONS_URL}/kraken/k2_viral.tar.gz", stream=True),
]
)
mock_tarfile_open.assert_called_once_with("/tmp/k2_viral.tar.gz", "r:gz")
- mock_find.assert_called_once_with("viral", ANY)
+ mock_find.assert_called_once_with(ANY, "viral")
mock_tqdm.assert_not_called()
@parameterized.expand(
@@ -450,7 +465,7 @@ def test_fetch_db_collection_16S_success(
"q2_annotate.kraken2.database._find_latest_db", return_value=latest_db
):
mock_requests_get.side_effect = [
- MagicMock(status_code=200),
+ MagicMock(status_code=200, content=self.s3_response_16S),
MagicMock(
status_code=200,
iter_content=lambda chunk_size: self.tar_chunks,
@@ -465,7 +480,7 @@ def test_fetch_db_collection_16S_success(
mock_move.assert_called_once_with("/tmp")
mock_requests_get.assert_has_calls(
[
- call(S3_COLLECTIONS_URL),
+ call(S3_COLLECTIONS_URL, params={"list-type": "2"}),
call(f"{S3_COLLECTIONS_URL}/{latest_db}", stream=True),
]
)
@@ -482,7 +497,7 @@ def test_fetch_db_collection_success_with_tqdm(
self, mock_tqdm, mock_find, mock_tarfile_open, mock_requests_get
):
mock_requests_get.side_effect = [
- MagicMock(status_code=200),
+ MagicMock(status_code=200, content=self.s3_response),
MagicMock(
status_code=200,
iter_content=lambda chunk_size: self.tar_chunks,
@@ -495,12 +510,12 @@ def test_fetch_db_collection_success_with_tqdm(
mock_requests_get.assert_has_calls(
[
- call(S3_COLLECTIONS_URL),
+ call(S3_COLLECTIONS_URL, params={"list-type": "2"}),
call(f"{S3_COLLECTIONS_URL}/kraken/k2_viral.tar.gz", stream=True),
]
)
mock_tarfile_open.assert_called_once_with("/tmp/k2_viral.tar.gz", "r:gz")
- mock_find.assert_called_once_with("viral", ANY)
+ mock_find.assert_called_once_with(ANY, "viral")
mock_tqdm.assert_called_once_with(
desc='Downloading the "kraken/k2_viral.tar.gz" database',
total=1000,
@@ -563,7 +578,7 @@ def test_fetch_db_collection_16S_tqdm_success(
"q2_annotate.kraken2.database._find_latest_db", return_value=latest_db
):
mock_requests_get.side_effect = [
- MagicMock(status_code=200),
+ MagicMock(status_code=200, content=self.s3_response_16S),
MagicMock(
status_code=200,
iter_content=lambda chunk_size: self.tar_chunks,
@@ -578,7 +593,7 @@ def test_fetch_db_collection_16S_tqdm_success(
mock_move.assert_called_once_with("/tmp")
mock_requests_get.assert_has_calls(
[
- call(S3_COLLECTIONS_URL),
+ call(S3_COLLECTIONS_URL, params={"list-type": "2"}),
call(f"{S3_COLLECTIONS_URL}/{latest_db}", stream=True),
]
)
@@ -593,15 +608,138 @@ def test_fetch_db_collection_16S_tqdm_success(
@patch("requests.get")
def test_fetch_db_collection_connection_error(self, mock_get):
mock_get.side_effect = ConnectionError("Some error.")
- with self.assertRaisesRegex(ValueError, r".+The error was\: Some error\."):
+ with self.assertRaisesRegex(
+ ValueError, r"Could not fetch the list of available databases:.*"
+ ):
_fetch_db_collection("my_collection", "/tmp")
@patch("requests.get")
def test_fetch_db_collection_status_non200(self, mock_get):
- mock_get.return_value = Mock(status_code=404)
- with self.assertRaisesRegex(ValueError, r".+Status code was\: 404"):
+ response_mock = MagicMock()
+ response_mock.raise_for_status.side_effect = requests.HTTPError(
+ "404 Client Error"
+ )
+ mock_get.return_value = response_mock
+ with self.assertRaisesRegex(
+ ValueError, r"Could not fetch the list of available databases:.*"
+ ):
_fetch_db_collection("my_collection", "/tmp")
+ @patch("requests.get")
+ @patch("tarfile.open")
+ @patch(
+ "q2_annotate.kraken2.database._find_latest_db",
+ return_value="kraken/k2_viral_20230314.tar.gz",
+ )
+ @patch("q2_annotate.kraken2.database.tqdm")
+ def test_fetch_db_collection_paginated(
+ self, mock_tqdm, mock_find, mock_tarfile_open, mock_requests_get
+ ):
+ page1 = b"""
+
+ true
+ token-123
+
+ kraken/k2_viral_20201202.tar.gz
+ 2020-12-09T01:38:22.000Z
+
+
+ """
+ page2 = b"""
+
+ false
+
+ kraken/k2_viral_20230314.tar.gz
+ 2023-03-22T01:29:11.000Z
+
+
+ """
+ mock_requests_get.side_effect = [
+ MagicMock(status_code=200, content=page1),
+ MagicMock(status_code=200, content=page2),
+ MagicMock(
+ status_code=200,
+ iter_content=lambda chunk_size: self.tar_chunks,
+ headers={},
+ ),
+ ]
+ mock_tarfile_open.return_value.__enter__.return_value = MagicMock()
+
+ _fetch_db_collection("viral", "/tmp")
+
+ mock_requests_get.assert_has_calls(
+ [
+ call(S3_COLLECTIONS_URL, params={"list-type": "2"}),
+ call(
+ S3_COLLECTIONS_URL,
+ params={"list-type": "2", "continuation-token": "token-123"},
+ ),
+ call(
+ f"{S3_COLLECTIONS_URL}/kraken/k2_viral_20230314.tar.gz", stream=True
+ ),
+ ]
+ )
+ mock_tarfile_open.assert_called_once_with(
+ "/tmp/k2_viral_20230314.tar.gz", "r:gz"
+ )
+
+ @patch("requests.get")
+ @patch("tarfile.open")
+ @patch(
+ "q2_annotate.kraken2.database._find_latest_db",
+ return_value="kraken/k2_viral_20230314.tar.gz",
+ )
+ @patch("q2_annotate.kraken2.database.tqdm")
+ def test_fetch_db_collection_single_content_dict(
+ self, mock_tqdm, mock_find, mock_tarfile_open, mock_requests_get
+ ):
+ single_item_response = b"""
+
+
+ kraken/k2_viral_20230314.tar.gz
+ 2023-03-22T01:29:11.000Z
+
+
+ """
+ mock_requests_get.side_effect = [
+ MagicMock(status_code=200, content=single_item_response),
+ MagicMock(
+ status_code=200,
+ iter_content=lambda chunk_size: self.tar_chunks,
+ headers={},
+ ),
+ ]
+ mock_tarfile_open.return_value.__enter__.return_value = MagicMock()
+
+ _fetch_db_collection("viral", "/tmp")
+
+ mock_find.assert_called_once_with(
+ [
+ {
+ "Key": "kraken/k2_viral_20230314.tar.gz",
+ "LastModified": "2023-03-22T01:29:11.000Z",
+ }
+ ],
+ "viral",
+ )
+
+ @patch("requests.get")
+ def test_fetch_db_collection_empty_s3_objects(self, mock_requests_get):
+ empty_response = b"""
+
+ test-bucket
+
+ """
+ mock_requests_get.return_value = MagicMock(
+ status_code=200, content=empty_response
+ )
+
+ with self.assertRaisesRegex(
+ ValueError,
+ "No databases were found in the response returned by S3. Please try again.",
+ ):
+ _fetch_db_collection("viral", "/tmp")
+
def test_file_move_one_level_up(self):
with TemporaryDirectory() as tmp_dir:
fake_folder = os.path.join(tmp_dir, "test_move")