diff --git a/q2_annotate/kraken2/database.py b/q2_annotate/kraken2/database.py index 30073d30..27596357 100644 --- a/q2_annotate/kraken2/database.py +++ b/q2_annotate/kraken2/database.py @@ -175,7 +175,7 @@ def _build_bracken_database( ) -def _find_latest_db(collection: str, response: requests.Response) -> str: +def _find_latest_db(s3_objects: list, collection: str) -> str: collection_id = COLLECTIONS[collection] if collection in S16_DBS: @@ -183,16 +183,7 @@ def _find_latest_db(collection: str, response: requests.Response) -> str: else: pattern = rf"kraken\/k2_{collection_id}_\d{{8}}.tar.gz" - s3_objects = xmltodict.parse(response.content) - s3_objects = s3_objects.get("ListBucketResult") - if not s3_objects: - raise ValueError( - "No databases were found in the response returned by S3. " - "Please try again." - ) - s3_objects = [ - obj for obj in s3_objects["Contents"] if re.match(pattern, obj["Key"]) - ] + s3_objects = [obj for obj in s3_objects if re.match(pattern, obj["Key"])] s3_objects = sorted(s3_objects, key=lambda x: x["LastModified"], reverse=True) latest_db = s3_objects[0]["Key"] @@ -211,21 +202,42 @@ def _fetch_db_collection(collection: str, tmp_dir: str): "Could not connect to the server. Please check your internet " "connection and try again. The error was: {}." ) - try: - response = requests.get(S3_COLLECTIONS_URL) - except requests.exceptions.ConnectionError as e: - raise ValueError(err_msg.format(e)) + continuation_token = None + s3_objects = [] - if response.status_code == 200: - latest_db = _find_latest_db(collection, response) - print(f'Found the latest "{collection}" database: {latest_db}.') - else: + while True: + params = {"list-type": "2"} + if continuation_token: + params["continuation-token"] = continuation_token + + try: + response = requests.get(S3_COLLECTIONS_URL, params=params) + response.raise_for_status() + except requests.RequestException as e: + raise ValueError(f"Could not fetch the list of available databases: {e}") + + data = xmltodict.parse(response.content).get("ListBucketResult", {}) + + # Safely extract contents (handles single-item dict vs list edge cases) + contents = data.get("Contents", []) + if isinstance(contents, dict): + contents = [contents] + s3_objects.extend(contents) + + # Clean pagination check: if a token exists and it's truncated, keep going + if data.get("IsTruncated") == "true" and "NextContinuationToken" in data: + continuation_token = data["NextContinuationToken"] + else: + break + + if not s3_objects: raise ValueError( - "Could not fetch the list of available databases. " - f"Status code was: {response.status_code}. " - "Please try again later." + "No databases were found in the response returned by S3. Please try again." ) + latest_db = _find_latest_db(s3_objects, collection) + print(f'Found the latest "{collection}" database: {latest_db}.') + db_uri = f"{S3_COLLECTIONS_URL}/{latest_db}" try: response = requests.get(db_uri, stream=True) diff --git a/q2_annotate/kraken2/tests/test_database.py b/q2_annotate/kraken2/tests/test_database.py index 7dc38a40..4a42c30f 100644 --- a/q2_annotate/kraken2/tests/test_database.py +++ b/q2_annotate/kraken2/tests/test_database.py @@ -12,10 +12,11 @@ import tempfile import unittest from copy import deepcopy +import requests from requests.exceptions import ConnectionError from subprocess import CalledProcessError from tempfile import TemporaryDirectory -from unittest.mock import patch, ANY, call, Mock, MagicMock +from unittest.mock import patch, ANY, call, MagicMock from parameterized import parameterized import pandas as pd @@ -345,24 +346,38 @@ def test_build_bracken_database_exception(self, p1): ) def test_find_latest_db(self): - response = Mock(content=self.s3_response) - - obs_db = _find_latest_db("viral", response) + s3_objects = [ + { + "Key": "kraken/k2_viral_20201202.tar.gz", + "LastModified": "2020-12-09T01:38:22.000Z", + }, + { + "Key": "kraken/k2_viral_20230314.tar.gz", + "LastModified": "2023-03-22T01:29:11.000Z", + }, + ] + obs_db = _find_latest_db(s3_objects, "viral") exp_db = "kraken/k2_viral_20230314.tar.gz" self.assertEqual(obs_db, exp_db) def test_find_latest_16S_db(self): - response = Mock(content=self.s3_response_16S) - - obs_db = _find_latest_db("greengenes", response) + s3_objects = [ + { + "Key": "kraken/16S_Greengenes13.5_20200326.tgz", + "LastModified": "2020-03-26T01:38:22.000Z", + }, + { + "Key": "kraken/16S_Greengenes13.5_20200326.tgz", + "LastModified": "2024-02-12T01:29:11.000Z", + }, + ] + obs_db = _find_latest_db(s3_objects, "greengenes") exp_db = "kraken/16S_Greengenes13.5_20200326.tgz" self.assertEqual(obs_db, exp_db) def test_find_latest_db_empty(self): - response = Mock(content=b"""""") - - with self.assertRaisesRegex(ValueError, r"No databases were found.+"): - _find_latest_db("viral", response) + with self.assertRaises(IndexError): + _find_latest_db([], "viral") @patch("requests.get") @patch("tarfile.open") @@ -375,7 +390,7 @@ def test_fetch_db_collection_success( self, mock_tqdm, mock_find, mock_tarfile_open, mock_requests_get ): mock_requests_get.side_effect = [ - MagicMock(status_code=200), + MagicMock(status_code=200, content=self.s3_response), MagicMock( status_code=200, iter_content=lambda chunk_size: self.tar_chunks, @@ -388,12 +403,12 @@ def test_fetch_db_collection_success( mock_requests_get.assert_has_calls( [ - call(S3_COLLECTIONS_URL), + call(S3_COLLECTIONS_URL, params={"list-type": "2"}), call(f"{S3_COLLECTIONS_URL}/kraken/k2_viral.tar.gz", stream=True), ] ) mock_tarfile_open.assert_called_once_with("/tmp/k2_viral.tar.gz", "r:gz") - mock_find.assert_called_once_with("viral", ANY) + mock_find.assert_called_once_with(ANY, "viral") mock_tqdm.assert_not_called() @parameterized.expand( @@ -450,7 +465,7 @@ def test_fetch_db_collection_16S_success( "q2_annotate.kraken2.database._find_latest_db", return_value=latest_db ): mock_requests_get.side_effect = [ - MagicMock(status_code=200), + MagicMock(status_code=200, content=self.s3_response_16S), MagicMock( status_code=200, iter_content=lambda chunk_size: self.tar_chunks, @@ -465,7 +480,7 @@ def test_fetch_db_collection_16S_success( mock_move.assert_called_once_with("/tmp") mock_requests_get.assert_has_calls( [ - call(S3_COLLECTIONS_URL), + call(S3_COLLECTIONS_URL, params={"list-type": "2"}), call(f"{S3_COLLECTIONS_URL}/{latest_db}", stream=True), ] ) @@ -482,7 +497,7 @@ def test_fetch_db_collection_success_with_tqdm( self, mock_tqdm, mock_find, mock_tarfile_open, mock_requests_get ): mock_requests_get.side_effect = [ - MagicMock(status_code=200), + MagicMock(status_code=200, content=self.s3_response), MagicMock( status_code=200, iter_content=lambda chunk_size: self.tar_chunks, @@ -495,12 +510,12 @@ def test_fetch_db_collection_success_with_tqdm( mock_requests_get.assert_has_calls( [ - call(S3_COLLECTIONS_URL), + call(S3_COLLECTIONS_URL, params={"list-type": "2"}), call(f"{S3_COLLECTIONS_URL}/kraken/k2_viral.tar.gz", stream=True), ] ) mock_tarfile_open.assert_called_once_with("/tmp/k2_viral.tar.gz", "r:gz") - mock_find.assert_called_once_with("viral", ANY) + mock_find.assert_called_once_with(ANY, "viral") mock_tqdm.assert_called_once_with( desc='Downloading the "kraken/k2_viral.tar.gz" database', total=1000, @@ -563,7 +578,7 @@ def test_fetch_db_collection_16S_tqdm_success( "q2_annotate.kraken2.database._find_latest_db", return_value=latest_db ): mock_requests_get.side_effect = [ - MagicMock(status_code=200), + MagicMock(status_code=200, content=self.s3_response_16S), MagicMock( status_code=200, iter_content=lambda chunk_size: self.tar_chunks, @@ -578,7 +593,7 @@ def test_fetch_db_collection_16S_tqdm_success( mock_move.assert_called_once_with("/tmp") mock_requests_get.assert_has_calls( [ - call(S3_COLLECTIONS_URL), + call(S3_COLLECTIONS_URL, params={"list-type": "2"}), call(f"{S3_COLLECTIONS_URL}/{latest_db}", stream=True), ] ) @@ -593,15 +608,138 @@ def test_fetch_db_collection_16S_tqdm_success( @patch("requests.get") def test_fetch_db_collection_connection_error(self, mock_get): mock_get.side_effect = ConnectionError("Some error.") - with self.assertRaisesRegex(ValueError, r".+The error was\: Some error\."): + with self.assertRaisesRegex( + ValueError, r"Could not fetch the list of available databases:.*" + ): _fetch_db_collection("my_collection", "/tmp") @patch("requests.get") def test_fetch_db_collection_status_non200(self, mock_get): - mock_get.return_value = Mock(status_code=404) - with self.assertRaisesRegex(ValueError, r".+Status code was\: 404"): + response_mock = MagicMock() + response_mock.raise_for_status.side_effect = requests.HTTPError( + "404 Client Error" + ) + mock_get.return_value = response_mock + with self.assertRaisesRegex( + ValueError, r"Could not fetch the list of available databases:.*" + ): _fetch_db_collection("my_collection", "/tmp") + @patch("requests.get") + @patch("tarfile.open") + @patch( + "q2_annotate.kraken2.database._find_latest_db", + return_value="kraken/k2_viral_20230314.tar.gz", + ) + @patch("q2_annotate.kraken2.database.tqdm") + def test_fetch_db_collection_paginated( + self, mock_tqdm, mock_find, mock_tarfile_open, mock_requests_get + ): + page1 = b""" + + true + token-123 + + kraken/k2_viral_20201202.tar.gz + 2020-12-09T01:38:22.000Z + + + """ + page2 = b""" + + false + + kraken/k2_viral_20230314.tar.gz + 2023-03-22T01:29:11.000Z + + + """ + mock_requests_get.side_effect = [ + MagicMock(status_code=200, content=page1), + MagicMock(status_code=200, content=page2), + MagicMock( + status_code=200, + iter_content=lambda chunk_size: self.tar_chunks, + headers={}, + ), + ] + mock_tarfile_open.return_value.__enter__.return_value = MagicMock() + + _fetch_db_collection("viral", "/tmp") + + mock_requests_get.assert_has_calls( + [ + call(S3_COLLECTIONS_URL, params={"list-type": "2"}), + call( + S3_COLLECTIONS_URL, + params={"list-type": "2", "continuation-token": "token-123"}, + ), + call( + f"{S3_COLLECTIONS_URL}/kraken/k2_viral_20230314.tar.gz", stream=True + ), + ] + ) + mock_tarfile_open.assert_called_once_with( + "/tmp/k2_viral_20230314.tar.gz", "r:gz" + ) + + @patch("requests.get") + @patch("tarfile.open") + @patch( + "q2_annotate.kraken2.database._find_latest_db", + return_value="kraken/k2_viral_20230314.tar.gz", + ) + @patch("q2_annotate.kraken2.database.tqdm") + def test_fetch_db_collection_single_content_dict( + self, mock_tqdm, mock_find, mock_tarfile_open, mock_requests_get + ): + single_item_response = b""" + + + kraken/k2_viral_20230314.tar.gz + 2023-03-22T01:29:11.000Z + + + """ + mock_requests_get.side_effect = [ + MagicMock(status_code=200, content=single_item_response), + MagicMock( + status_code=200, + iter_content=lambda chunk_size: self.tar_chunks, + headers={}, + ), + ] + mock_tarfile_open.return_value.__enter__.return_value = MagicMock() + + _fetch_db_collection("viral", "/tmp") + + mock_find.assert_called_once_with( + [ + { + "Key": "kraken/k2_viral_20230314.tar.gz", + "LastModified": "2023-03-22T01:29:11.000Z", + } + ], + "viral", + ) + + @patch("requests.get") + def test_fetch_db_collection_empty_s3_objects(self, mock_requests_get): + empty_response = b""" + + test-bucket + + """ + mock_requests_get.return_value = MagicMock( + status_code=200, content=empty_response + ) + + with self.assertRaisesRegex( + ValueError, + "No databases were found in the response returned by S3. Please try again.", + ): + _fetch_db_collection("viral", "/tmp") + def test_file_move_one_level_up(self): with TemporaryDirectory() as tmp_dir: fake_folder = os.path.join(tmp_dir, "test_move")