Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 34 additions & 22 deletions q2_annotate/kraken2/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,24 +175,15 @@ def _build_bracken_database(
)


def _find_latest_db(collection: str, response: requests.Response) -> str:
def _find_latest_db(s3_objects: list, collection: str) -> str:
collection_id = COLLECTIONS[collection]

if collection in S16_DBS:
pattern = rf"kraken\/16S_{collection_id}_\d{{8}}.tgz"
else:
pattern = rf"kraken\/k2_{collection_id}_\d{{8}}.tar.gz"

s3_objects = xmltodict.parse(response.content)
s3_objects = s3_objects.get("ListBucketResult")
if not s3_objects:
raise ValueError(
"No databases were found in the response returned by S3. "
"Please try again."
)
s3_objects = [
obj for obj in s3_objects["Contents"] if re.match(pattern, obj["Key"])
]
s3_objects = [obj for obj in s3_objects if re.match(pattern, obj["Key"])]
s3_objects = sorted(s3_objects, key=lambda x: x["LastModified"], reverse=True)

latest_db = s3_objects[0]["Key"]
Expand All @@ -211,21 +202,42 @@ def _fetch_db_collection(collection: str, tmp_dir: str):
"Could not connect to the server. Please check your internet "
"connection and try again. The error was: {}."
)
try:
response = requests.get(S3_COLLECTIONS_URL)
except requests.exceptions.ConnectionError as e:
raise ValueError(err_msg.format(e))
continuation_token = None
s3_objects = []

if response.status_code == 200:
latest_db = _find_latest_db(collection, response)
print(f'Found the latest "{collection}" database: {latest_db}.')
else:
while True:
params = {"list-type": "2"}
if continuation_token:
params["continuation-token"] = continuation_token

try:
response = requests.get(S3_COLLECTIONS_URL, params=params)
response.raise_for_status()
except requests.RequestException as e:
raise ValueError(f"Could not fetch the list of available databases: {e}")

data = xmltodict.parse(response.content).get("ListBucketResult", {})

# Safely extract contents (handles single-item dict vs list edge cases)
contents = data.get("Contents", [])
if isinstance(contents, dict):
contents = [contents]
s3_objects.extend(contents)

# Clean pagination check: if a token exists and it's truncated, keep going
if data.get("IsTruncated") == "true" and "NextContinuationToken" in data:
continuation_token = data["NextContinuationToken"]
else:
break

if not s3_objects:
raise ValueError(
"Could not fetch the list of available databases. "
f"Status code was: {response.status_code}. "
"Please try again later."
"No databases were found in the response returned by S3. Please try again."
)

latest_db = _find_latest_db(s3_objects, collection)
print(f'Found the latest "{collection}" database: {latest_db}.')

db_uri = f"{S3_COLLECTIONS_URL}/{latest_db}"
try:
response = requests.get(db_uri, stream=True)
Expand Down
186 changes: 162 additions & 24 deletions q2_annotate/kraken2/tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
import tempfile
import unittest
from copy import deepcopy
import requests
from requests.exceptions import ConnectionError
from subprocess import CalledProcessError
from tempfile import TemporaryDirectory
from unittest.mock import patch, ANY, call, Mock, MagicMock
from unittest.mock import patch, ANY, call, MagicMock
from parameterized import parameterized

import pandas as pd
Expand Down Expand Up @@ -345,24 +346,38 @@ def test_build_bracken_database_exception(self, p1):
)

def test_find_latest_db(self):
response = Mock(content=self.s3_response)

obs_db = _find_latest_db("viral", response)
s3_objects = [
{
"Key": "kraken/k2_viral_20201202.tar.gz",
"LastModified": "2020-12-09T01:38:22.000Z",
},
{
"Key": "kraken/k2_viral_20230314.tar.gz",
"LastModified": "2023-03-22T01:29:11.000Z",
},
]
obs_db = _find_latest_db(s3_objects, "viral")
exp_db = "kraken/k2_viral_20230314.tar.gz"
self.assertEqual(obs_db, exp_db)

def test_find_latest_16S_db(self):
response = Mock(content=self.s3_response_16S)

obs_db = _find_latest_db("greengenes", response)
s3_objects = [
{
"Key": "kraken/16S_Greengenes13.5_20200326.tgz",
"LastModified": "2020-03-26T01:38:22.000Z",
},
{
"Key": "kraken/16S_Greengenes13.5_20200326.tgz",
"LastModified": "2024-02-12T01:29:11.000Z",
},
]
obs_db = _find_latest_db(s3_objects, "greengenes")
exp_db = "kraken/16S_Greengenes13.5_20200326.tgz"
self.assertEqual(obs_db, exp_db)

def test_find_latest_db_empty(self):
response = Mock(content=b"""<ListBucketResult></ListBucketResult>""")

with self.assertRaisesRegex(ValueError, r"No databases were found.+"):
_find_latest_db("viral", response)
with self.assertRaises(IndexError):
_find_latest_db([], "viral")
Comment thread
misialq marked this conversation as resolved.

@patch("requests.get")
@patch("tarfile.open")
Expand All @@ -375,7 +390,7 @@ def test_fetch_db_collection_success(
self, mock_tqdm, mock_find, mock_tarfile_open, mock_requests_get
):
mock_requests_get.side_effect = [
MagicMock(status_code=200),
MagicMock(status_code=200, content=self.s3_response),
MagicMock(
status_code=200,
iter_content=lambda chunk_size: self.tar_chunks,
Expand All @@ -388,12 +403,12 @@ def test_fetch_db_collection_success(

mock_requests_get.assert_has_calls(
[
call(S3_COLLECTIONS_URL),
call(S3_COLLECTIONS_URL, params={"list-type": "2"}),
call(f"{S3_COLLECTIONS_URL}/kraken/k2_viral.tar.gz", stream=True),
]
)
mock_tarfile_open.assert_called_once_with("/tmp/k2_viral.tar.gz", "r:gz")
mock_find.assert_called_once_with("viral", ANY)
mock_find.assert_called_once_with(ANY, "viral")
mock_tqdm.assert_not_called()

@parameterized.expand(
Expand Down Expand Up @@ -450,7 +465,7 @@ def test_fetch_db_collection_16S_success(
"q2_annotate.kraken2.database._find_latest_db", return_value=latest_db
):
mock_requests_get.side_effect = [
MagicMock(status_code=200),
MagicMock(status_code=200, content=self.s3_response_16S),
MagicMock(
status_code=200,
iter_content=lambda chunk_size: self.tar_chunks,
Expand All @@ -465,7 +480,7 @@ def test_fetch_db_collection_16S_success(
mock_move.assert_called_once_with("/tmp")
mock_requests_get.assert_has_calls(
[
call(S3_COLLECTIONS_URL),
call(S3_COLLECTIONS_URL, params={"list-type": "2"}),
call(f"{S3_COLLECTIONS_URL}/{latest_db}", stream=True),
]
)
Expand All @@ -482,7 +497,7 @@ def test_fetch_db_collection_success_with_tqdm(
self, mock_tqdm, mock_find, mock_tarfile_open, mock_requests_get
):
mock_requests_get.side_effect = [
MagicMock(status_code=200),
MagicMock(status_code=200, content=self.s3_response),
MagicMock(
status_code=200,
iter_content=lambda chunk_size: self.tar_chunks,
Expand All @@ -495,12 +510,12 @@ def test_fetch_db_collection_success_with_tqdm(

mock_requests_get.assert_has_calls(
[
call(S3_COLLECTIONS_URL),
call(S3_COLLECTIONS_URL, params={"list-type": "2"}),
call(f"{S3_COLLECTIONS_URL}/kraken/k2_viral.tar.gz", stream=True),
]
)
mock_tarfile_open.assert_called_once_with("/tmp/k2_viral.tar.gz", "r:gz")
mock_find.assert_called_once_with("viral", ANY)
mock_find.assert_called_once_with(ANY, "viral")
mock_tqdm.assert_called_once_with(
desc='Downloading the "kraken/k2_viral.tar.gz" database',
total=1000,
Expand Down Expand Up @@ -563,7 +578,7 @@ def test_fetch_db_collection_16S_tqdm_success(
"q2_annotate.kraken2.database._find_latest_db", return_value=latest_db
):
mock_requests_get.side_effect = [
MagicMock(status_code=200),
MagicMock(status_code=200, content=self.s3_response_16S),
MagicMock(
status_code=200,
iter_content=lambda chunk_size: self.tar_chunks,
Expand All @@ -578,7 +593,7 @@ def test_fetch_db_collection_16S_tqdm_success(
mock_move.assert_called_once_with("/tmp")
mock_requests_get.assert_has_calls(
[
call(S3_COLLECTIONS_URL),
call(S3_COLLECTIONS_URL, params={"list-type": "2"}),
call(f"{S3_COLLECTIONS_URL}/{latest_db}", stream=True),
]
)
Expand All @@ -593,15 +608,138 @@ def test_fetch_db_collection_16S_tqdm_success(
@patch("requests.get")
def test_fetch_db_collection_connection_error(self, mock_get):
mock_get.side_effect = ConnectionError("Some error.")
with self.assertRaisesRegex(ValueError, r".+The error was\: Some error\."):
with self.assertRaisesRegex(
ValueError, r"Could not fetch the list of available databases:.*"
):
_fetch_db_collection("my_collection", "/tmp")

@patch("requests.get")
def test_fetch_db_collection_status_non200(self, mock_get):
mock_get.return_value = Mock(status_code=404)
with self.assertRaisesRegex(ValueError, r".+Status code was\: 404"):
response_mock = MagicMock()
response_mock.raise_for_status.side_effect = requests.HTTPError(
"404 Client Error"
)
mock_get.return_value = response_mock
with self.assertRaisesRegex(
ValueError, r"Could not fetch the list of available databases:.*"
):
_fetch_db_collection("my_collection", "/tmp")

@patch("requests.get")
@patch("tarfile.open")
@patch(
"q2_annotate.kraken2.database._find_latest_db",
return_value="kraken/k2_viral_20230314.tar.gz",
)
@patch("q2_annotate.kraken2.database.tqdm")
def test_fetch_db_collection_paginated(
self, mock_tqdm, mock_find, mock_tarfile_open, mock_requests_get
):
page1 = b"""
<ListBucketResult>
<IsTruncated>true</IsTruncated>
<NextContinuationToken>token-123</NextContinuationToken>
<Contents>
<Key>kraken/k2_viral_20201202.tar.gz</Key>
<LastModified>2020-12-09T01:38:22.000Z</LastModified>
</Contents>
</ListBucketResult>
"""
page2 = b"""
<ListBucketResult>
<IsTruncated>false</IsTruncated>
<Contents>
<Key>kraken/k2_viral_20230314.tar.gz</Key>
<LastModified>2023-03-22T01:29:11.000Z</LastModified>
</Contents>
</ListBucketResult>
"""
mock_requests_get.side_effect = [
MagicMock(status_code=200, content=page1),
MagicMock(status_code=200, content=page2),
MagicMock(
status_code=200,
iter_content=lambda chunk_size: self.tar_chunks,
headers={},
),
]
mock_tarfile_open.return_value.__enter__.return_value = MagicMock()

_fetch_db_collection("viral", "/tmp")

mock_requests_get.assert_has_calls(
[
call(S3_COLLECTIONS_URL, params={"list-type": "2"}),
call(
S3_COLLECTIONS_URL,
params={"list-type": "2", "continuation-token": "token-123"},
),
call(
f"{S3_COLLECTIONS_URL}/kraken/k2_viral_20230314.tar.gz", stream=True
),
]
)
mock_tarfile_open.assert_called_once_with(
"/tmp/k2_viral_20230314.tar.gz", "r:gz"
)

@patch("requests.get")
@patch("tarfile.open")
@patch(
"q2_annotate.kraken2.database._find_latest_db",
return_value="kraken/k2_viral_20230314.tar.gz",
)
@patch("q2_annotate.kraken2.database.tqdm")
def test_fetch_db_collection_single_content_dict(
self, mock_tqdm, mock_find, mock_tarfile_open, mock_requests_get
):
single_item_response = b"""
<ListBucketResult>
<Contents>
<Key>kraken/k2_viral_20230314.tar.gz</Key>
<LastModified>2023-03-22T01:29:11.000Z</LastModified>
</Contents>
</ListBucketResult>
"""
mock_requests_get.side_effect = [
MagicMock(status_code=200, content=single_item_response),
MagicMock(
status_code=200,
iter_content=lambda chunk_size: self.tar_chunks,
headers={},
),
]
mock_tarfile_open.return_value.__enter__.return_value = MagicMock()

_fetch_db_collection("viral", "/tmp")

mock_find.assert_called_once_with(
[
{
"Key": "kraken/k2_viral_20230314.tar.gz",
"LastModified": "2023-03-22T01:29:11.000Z",
}
],
"viral",
)

@patch("requests.get")
def test_fetch_db_collection_empty_s3_objects(self, mock_requests_get):
empty_response = b"""
<ListBucketResult>
<Name>test-bucket</Name>
</ListBucketResult>
"""
mock_requests_get.return_value = MagicMock(
status_code=200, content=empty_response
)

with self.assertRaisesRegex(
ValueError,
"No databases were found in the response returned by S3. Please try again.",
):
_fetch_db_collection("viral", "/tmp")

def test_file_move_one_level_up(self):
with TemporaryDirectory() as tmp_dir:
fake_folder = os.path.join(tmp_dir, "test_move")
Expand Down
Loading