diff --git a/openml/__init__.py b/openml/__init__.py
index ae5db261f..21dda24ad 100644
--- a/openml/__init__.py
+++ b/openml/__init__.py
@@ -33,6 +33,7 @@
     utils,
 )
 from .__version__ import __version__
+from ._api import _backend
 from .datasets import OpenMLDataFeature, OpenMLDataset
 from .evaluations import OpenMLEvaluation
 from .flows import OpenMLFlow
@@ -109,6 +110,7 @@ def populate_cache(
     "OpenMLTask",
     "__version__",
     "_api_calls",
+    "_backend",
     "config",
     "datasets",
     "evaluations",
diff --git a/openml/_api/__init__.py b/openml/_api/__init__.py
new file mode 100644
index 000000000..926fee3d4
--- /dev/null
+++ b/openml/_api/__init__.py
@@ -0,0 +1,93 @@
+from .clients import (
+    HTTPCache,
+    HTTPClient,
+    MinIOClient,
+)
+from .resources import (
+    API_REGISTRY,
+    DatasetAPI,
+    DatasetV1API,
+    DatasetV2API,
+    EstimationProcedureAPI,
+    EstimationProcedureV1API,
+    EstimationProcedureV2API,
+    EvaluationAPI,
+    EvaluationMeasureAPI,
+    EvaluationMeasureV1API,
+    EvaluationMeasureV2API,
+    EvaluationV1API,
+    EvaluationV2API,
+    FallbackProxy,
+    FlowAPI,
+    FlowV1API,
+    FlowV2API,
+    ResourceAPI,
+    ResourceV1API,
+    ResourceV2API,
+    RunAPI,
+    RunV1API,
+    RunV2API,
+    SetupAPI,
+    SetupV1API,
+    SetupV2API,
+    StudyAPI,
+    StudyV1API,
+    StudyV2API,
+    TaskAPI,
+    TaskV1API,
+    TaskV2API,
+)
+from .setup import (
+    APIBackend,
+    APIBackendBuilder,
+    APIConfig,
+    CacheConfig,
+    Config,
+    ConnectionConfig,
+    _backend,
+)
+
+__all__ = [
+    "API_REGISTRY",
+    "APIBackend",
+    "APIBackendBuilder",
+    "APIConfig",
+    "CacheConfig",
+    "Config",
+    "ConnectionConfig",
+    "DatasetAPI",
+    "DatasetV1API",
+    "DatasetV2API",
+    "EstimationProcedureAPI",
+    "EstimationProcedureV1API",
+    "EstimationProcedureV2API",
+    "EvaluationAPI",
+    "EvaluationMeasureAPI",
+    "EvaluationMeasureV1API",
+    "EvaluationMeasureV2API",
+    "EvaluationV1API",
+    "EvaluationV2API",
+    "FallbackProxy",
+    "FlowAPI",
+    "FlowV1API",
+    "FlowV2API",
+    "HTTPCache",
+    "HTTPClient",
+    "MinIOClient",
+    "ResourceAPI",
+    "ResourceV1API",
+    "ResourceV2API",
+    "RunAPI",
+    "RunV1API",
+    "RunV2API",
+    "SetupAPI",
+    "SetupV1API",
+    "SetupV2API",
+    "StudyAPI",
+    "StudyV1API",
+    "StudyV2API",
+    "TaskAPI",
+    "TaskV1API",
+    "TaskV2API",
+    "_backend",
+]
diff --git a/openml/_api/clients/__init__.py b/openml/_api/clients/__init__.py
new file mode 100644
index 000000000..42f11fbcf
--- /dev/null
+++ b/openml/_api/clients/__init__.py
@@ -0,0 +1,8 @@
+from .http import HTTPCache, HTTPClient
+from .minio import MinIOClient
+
+__all__ = [
+    "HTTPCache",
+    "HTTPClient",
+    "MinIOClient",
+]
diff --git a/openml/_api/clients/http.py b/openml/_api/clients/http.py
new file mode 100644
index 000000000..2c15515f3
--- /dev/null
+++ b/openml/_api/clients/http.py
@@ -0,0 +1,481 @@
+from __future__ import annotations
+
+import hashlib
+import json
+import logging
+import math
+import random
+import time
+import xml
+from collections.abc import Callable, Mapping
+from pathlib import Path
+from typing import Any
+from urllib.parse import urlencode, urljoin, urlparse
+
+import requests
+import xmltodict
+from requests import Response
+
+from openml.__version__ import __version__
+from openml.enums import RetryPolicy
+from openml.exceptions import (
+    OpenMLCacheRequiredError,
+    OpenMLHashException,
+    OpenMLNotAuthorizedError,
+    OpenMLServerError,
+    OpenMLServerException,
+    OpenMLServerNoResult,
+)
+
+
+class HTTPCache:
+    def __init__(self, *, path: Path, ttl: int) -> None:
+        self.path = path
+        self.ttl = ttl
+
+    def get_key(self, url: str, params: dict[str, Any]) -> str:
+        parsed_url = urlparse(url)
+        netloc_parts = parsed_url.netloc.split(".")[::-1]
+        path_parts = parsed_url.path.strip("/").split("/")
+
+        filtered_params = {k: v for k, v in params.items() if k != "api_key"}
+        params_part = [urlencode(filtered_params)] if filtered_params else []
+
+        return str(Path(*netloc_parts, *path_parts, *params_part))
+
+    def _key_to_path(self, key: str) -> Path:
+        return self.path.joinpath(key)
+
+    def 
load(self, key: str) -> Response: + path = self._key_to_path(key) + + if not path.exists(): + raise FileNotFoundError(f"Cache directory not found: {path}") + + meta_path = path / "meta.json" + headers_path = path / "headers.json" + body_path = path / "body.bin" + + if not (meta_path.exists() and headers_path.exists() and body_path.exists()): + raise FileNotFoundError(f"Incomplete cache at {path}") + + with meta_path.open("r", encoding="utf-8") as f: + meta = json.load(f) + + created_at = meta.get("created_at") + if created_at is None: + raise ValueError("Cache metadata missing 'created_at'") + + if time.time() - created_at > self.ttl: + raise TimeoutError(f"Cache expired for {path}") + + with headers_path.open("r", encoding="utf-8") as f: + headers = json.load(f) + + body = body_path.read_bytes() + + response = Response() + response.status_code = meta["status_code"] + response.url = meta["url"] + response.reason = meta["reason"] + response.headers = headers + response._content = body + response.encoding = meta["encoding"] + + return response + + def save(self, key: str, response: Response) -> None: + path = self._key_to_path(key) + path.mkdir(parents=True, exist_ok=True) + + (path / "body.bin").write_bytes(response.content) + + with (path / "headers.json").open("w", encoding="utf-8") as f: + json.dump(dict(response.headers), f) + + meta = { + "status_code": response.status_code, + "url": response.url, + "reason": response.reason, + "encoding": response.encoding, + "elapsed": response.elapsed.total_seconds(), + "created_at": time.time(), + "request": { + "method": response.request.method if response.request else None, + "url": response.request.url if response.request else None, + "headers": dict(response.request.headers) if response.request else None, + "body": response.request.body if response.request else None, + }, + } + + with (path / "meta.json").open("w", encoding="utf-8") as f: + json.dump(meta, f) + + +class HTTPClient: + def __init__( # noqa: PLR0913 + 
self, + *, + server: str, + base_url: str, + api_key: str, + retries: int, + retry_policy: RetryPolicy, + cache: HTTPCache | None = None, + ) -> None: + self.server = server + self.base_url = base_url + self.api_key = api_key + self.retries = retries + self.retry_policy = retry_policy + self.cache = cache + + self.retry_func = ( + self._human_delay if retry_policy == RetryPolicy.HUMAN else self._robot_delay + ) + self.headers: dict[str, str] = {"user-agent": f"openml-python/{__version__}"} + + def _robot_delay(self, n: int) -> float: + wait = (1 / (1 + math.exp(-(n * 0.5 - 4)))) * 60 + variation = random.gauss(0, wait / 10) + return max(1.0, wait + variation) + + def _human_delay(self, n: int) -> float: + return max(1.0, n) + + def _parse_exception_response( + self, + response: Response, + ) -> tuple[int | None, str]: + content_type = response.headers.get("Content-Type", "").lower() + + if "json" in content_type: + server_exception = response.json() + server_error = server_exception["detail"] + code = server_error.get("code") + message = server_error.get("message") + additional_information = server_error.get("additional_information") + else: + server_exception = xmltodict.parse(response.text) + server_error = server_exception["oml:error"] + code = server_error.get("oml:code") + message = server_error.get("oml:message") + additional_information = server_error.get("oml:additional_information") + + if code is not None: + code = int(code) + + if message and additional_information: + full_message = f"{message} - {additional_information}" + elif message: + full_message = message + elif additional_information: + full_message = additional_information + else: + full_message = "" + + return code, full_message + + def _raise_code_specific_error( + self, + code: int, + message: str, + url: str, + files: Mapping[str, Any] | None, + ) -> None: + if code in [111, 372, 512, 500, 482, 542, 674]: + # 512 for runs, 372 for datasets, 500 for flows + # 482 for tasks, 542 for 
evaluations, 674 for setups + # 111 for dataset descriptions + raise OpenMLServerNoResult(code=code, message=message, url=url) + + # 163: failure to validate flow XML (https://www.openml.org/api_docs#!/flow/post_flow) + if code in [163] and files is not None and "description" in files: + # file_elements['description'] is the XML file description of the flow + message = f"\n{files['description']}\n{message}" + + if code in [ + 102, # flow/exists post + 137, # dataset post + 350, # dataset/42 delete + 310, # flow/ post + 320, # flow/42 delete + 400, # run/42 delete + 460, # task/42 delete + ]: + raise OpenMLNotAuthorizedError( + message=( + f"The API call {url} requires authentication via an API key.\nPlease configure " + "OpenML-Python to use your API as described in this example:" + "\nhttps://openml.github.io/openml-python/latest/examples/Basics/introduction_tutorial/#authentication" + ) + ) + + # Propagate all server errors to the calling functions, except + # for 107 which represents a database connection error. + # These are typically caused by high server load, + # which means trying again might resolve the issue. + # DATABASE_CONNECTION_ERRCODE + if code != 107: + raise OpenMLServerException(code=code, message=message, url=url) + + def _validate_response( + self, + method: str, + url: str, + files: Mapping[str, Any] | None, + response: Response, + ) -> Exception | None: + if ( + "Content-Encoding" not in response.headers + or response.headers["Content-Encoding"] != "gzip" + ): + logging.warning(f"Received uncompressed content from OpenML for {url}.") + + if response.status_code == 200: + return None + + if response.status_code == requests.codes.URI_TOO_LONG: + raise OpenMLServerError(f"URI too long! 
({url})") + + retry_raise_e: Exception | None = None + code: int | None = None + message: str = "" + + try: + code, message = self._parse_exception_response(response) + + except (requests.exceptions.JSONDecodeError, xml.parsers.expat.ExpatError) as e: + if method != "GET": + extra = f"Status code: {response.status_code}\n{response.text}" + raise OpenMLServerError( + f"Unexpected server error when calling {url}. Please contact the " + f"developers!\n{extra}" + ) from e + + retry_raise_e = e + + except Exception as e: + # If we failed to parse it out, + # then something has gone wrong in the body we have sent back + # from the server and there is little extra information we can capture. + raise OpenMLServerError( + f"Unexpected server error when calling {url}. Please contact the developers!\n" + f"Status code: {response.status_code}\n{response.text}", + ) from e + + if code is not None: + self._raise_code_specific_error( + code=code, + message=message, + url=url, + files=files, + ) + + if retry_raise_e is None: + retry_raise_e = OpenMLServerException(code=code, message=message, url=url) + + return retry_raise_e + + def _request( # noqa: PLR0913 + self, + session: requests.Session, + method: str, + url: str, + params: Mapping[str, Any], + data: Mapping[str, Any], + headers: Mapping[str, str], + files: Mapping[str, Any] | None, + **request_kwargs: Any, + ) -> tuple[Response | None, Exception | None]: + retry_raise_e: Exception | None = None + response: Response | None = None + + try: + response = session.request( + method=method, + url=url, + params=params, + data=data, + headers=headers, + files=files, + **request_kwargs, + ) + except ( + requests.exceptions.ChunkedEncodingError, + requests.exceptions.ConnectionError, + requests.exceptions.SSLError, + ) as e: + retry_raise_e = e + + if response is not None: + retry_raise_e = self._validate_response( + method=method, + url=url, + files=files, + response=response, + ) + + return response, retry_raise_e + + def request( 
# noqa: PLR0913, C901 + self, + method: str, + path: str, + *, + use_cache: bool = False, + reset_cache: bool = False, + use_api_key: bool = False, + md5_checksum: str | None = None, + **request_kwargs: Any, + ) -> Response: + url = urljoin(self.server, urljoin(self.base_url, path)) + retries = max(1, self.retries) + + params = request_kwargs.pop("params", {}).copy() + data = request_kwargs.pop("data", {}).copy() + + if use_api_key: + params["api_key"] = self.api_key + + if method.upper() in {"POST", "PUT", "PATCH"}: + data = {**params, **data} + params = {} + + # prepare headers + headers = request_kwargs.pop("headers", {}).copy() + headers.update(self.headers) + + files = request_kwargs.pop("files", None) + + if use_cache and not reset_cache and self.cache is not None: + cache_key = self.cache.get_key(url, params) + try: + return self.cache.load(cache_key) + except (FileNotFoundError, TimeoutError): + pass # cache miss or expired, continue + except Exception: + raise # propagate unexpected cache errors + + session = requests.Session() + for retry_counter in range(1, retries + 1): + response, retry_raise_e = self._request( + session=session, + method=method, + url=url, + params=params, + data=data, + headers=headers, + files=files, + **request_kwargs, + ) + + # executed successfully + if retry_raise_e is None: + break + # tries completed + if retry_counter >= retries: + raise retry_raise_e + + delay = self.retry_func(retry_counter) + time.sleep(delay) + + session.close() + + assert response is not None + + if use_cache and self.cache is not None: + cache_key = self.cache.get_key(url, params) + self.cache.save(cache_key, response) + + if md5_checksum is not None: + self._verify_checksum(response, md5_checksum) + + return response + + def _verify_checksum(self, response: Response, md5_checksum: str) -> None: + # ruff sees hashlib.md5 as insecure + actual = hashlib.md5(response.content).hexdigest() # noqa: S324 + if actual != md5_checksum: + raise 
OpenMLHashException( + f"Checksum of downloaded file is unequal to the expected checksum {md5_checksum} " + f"when downloading {response.url}.", + ) + + def get( + self, + path: str, + *, + use_cache: bool = False, + reset_cache: bool = False, + use_api_key: bool = False, + md5_checksum: str | None = None, + **request_kwargs: Any, + ) -> Response: + return self.request( + method="GET", + path=path, + use_cache=use_cache, + reset_cache=reset_cache, + use_api_key=use_api_key, + md5_checksum=md5_checksum, + **request_kwargs, + ) + + def post( + self, + path: str, + *, + use_api_key: bool = True, + **request_kwargs: Any, + ) -> Response: + return self.request( + method="POST", + path=path, + use_cache=False, + use_api_key=use_api_key, + **request_kwargs, + ) + + def delete( + self, + path: str, + **request_kwargs: Any, + ) -> Response: + return self.request( + method="DELETE", + path=path, + use_cache=False, + use_api_key=True, + **request_kwargs, + ) + + def download( + self, + url: str, + handler: Callable[[Response, Path, str], Path] | None = None, + encoding: str = "utf-8", + file_name: str = "response.txt", + md5_checksum: str | None = None, + ) -> Path: + if self.cache is None: + raise OpenMLCacheRequiredError( + "A cache object is required for download, but none was provided in the HTTPClient." 
+            )
+        base = self.cache.path
+        file_path = base / "downloads" / urlparse(url).path.lstrip("/") / file_name
+        file_path = file_path.expanduser()
+        file_path.parent.mkdir(parents=True, exist_ok=True)
+        if file_path.exists():
+            return file_path
+
+        response = self.get(url, md5_checksum=md5_checksum)
+        if handler is not None:
+            return handler(response, file_path, encoding)
+
+        return self._text_handler(response, file_path, encoding)
+
+    def _text_handler(self, response: Response, path: Path, encoding: str) -> Path:
+        with path.open("w", encoding=encoding) as f:
+            f.write(response.text)
+        return path
diff --git a/openml/_api/clients/minio.py b/openml/_api/clients/minio.py
new file mode 100644
index 000000000..2edc8269b
--- /dev/null
+++ b/openml/_api/clients/minio.py
@@ -0,0 +1,11 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+from openml.__version__ import __version__
+
+
+class MinIOClient:
+    def __init__(self, path: Path | None = None) -> None:
+        self.path = path
+        self.headers: dict[str, str] = {"user-agent": f"openml-python/{__version__}"}
diff --git a/openml/_api/resources/__init__.py b/openml/_api/resources/__init__.py
new file mode 100644
index 000000000..1f0b2caa1
--- /dev/null
+++ b/openml/_api/resources/__init__.py
@@ -0,0 +1,63 @@
+from ._registry import API_REGISTRY
+from .base import (
+    DatasetAPI,
+    EstimationProcedureAPI,
+    EvaluationAPI,
+    EvaluationMeasureAPI,
+    FallbackProxy,
+    FlowAPI,
+    ResourceAPI,
+    ResourceV1API,
+    ResourceV2API,
+    RunAPI,
+    SetupAPI,
+    StudyAPI,
+    TaskAPI,
+)
+from .dataset import DatasetV1API, DatasetV2API
+from .estimation_procedure import (
+    EstimationProcedureV1API,
+    EstimationProcedureV2API,
+)
+from .evaluation import EvaluationV1API, EvaluationV2API
+from .evaluation_measure import EvaluationMeasureV1API, EvaluationMeasureV2API
+from .flow import FlowV1API, FlowV2API
+from .run import RunV1API, RunV2API
+from .setup import SetupV1API, SetupV2API
+from .study import StudyV1API, StudyV2API
+from .task import TaskV1API, TaskV2API
+
+__all__ = [
+    "API_REGISTRY",
+    "DatasetAPI",
+    "DatasetV1API",
+    "DatasetV2API",
+    "EstimationProcedureAPI",
+    "EstimationProcedureV1API",
+    "EstimationProcedureV2API",
+    "EvaluationAPI",
+    "EvaluationMeasureAPI",
+    "EvaluationMeasureV1API",
+    "EvaluationMeasureV2API",
+    "EvaluationV1API",
+    "EvaluationV2API",
+    "FallbackProxy",
+    "FlowAPI",
+    "FlowV1API",
+    "FlowV2API",
+    "ResourceAPI",
+    "ResourceV1API",
+    "ResourceV2API",
+    "RunAPI",
+    "RunV1API",
+    "RunV2API",
+    "SetupAPI",
+    "SetupV1API",
+    "SetupV2API",
+    "StudyAPI",
+    "StudyV1API",
+    "StudyV2API",
+    "TaskAPI",
+    "TaskV1API",
+    "TaskV2API",
+]
diff --git a/openml/_api/resources/_registry.py b/openml/_api/resources/_registry.py
new file mode 100644
index 000000000..66d7ec428
--- /dev/null
+++ b/openml/_api/resources/_registry.py
@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from openml.enums import APIVersion, ResourceType
+
+from .dataset import DatasetV1API, DatasetV2API
+from .estimation_procedure import (
+    EstimationProcedureV1API,
+    EstimationProcedureV2API,
+)
+from .evaluation import EvaluationV1API, EvaluationV2API
+from .evaluation_measure import EvaluationMeasureV1API, EvaluationMeasureV2API
+from .flow import FlowV1API, FlowV2API
+from .run import RunV1API, RunV2API
+from .setup import SetupV1API, SetupV2API
+from .study import StudyV1API, StudyV2API
+from .task import TaskV1API, TaskV2API
+
+if TYPE_CHECKING:
+    from .base import ResourceAPI
+
+API_REGISTRY: dict[
+    APIVersion,
+    dict[ResourceType, type[ResourceAPI]],
+] = {
+    APIVersion.V1: {
+        ResourceType.DATASET: DatasetV1API,
+        ResourceType.TASK: TaskV1API,
+        ResourceType.EVALUATION_MEASURE: EvaluationMeasureV1API,
+        ResourceType.ESTIMATION_PROCEDURE: EstimationProcedureV1API,
+        ResourceType.EVALUATION: EvaluationV1API,
+        ResourceType.FLOW: FlowV1API,
+        ResourceType.STUDY: StudyV1API,
+        ResourceType.RUN: RunV1API,
+
ResourceType.SETUP: SetupV1API, + }, + APIVersion.V2: { + ResourceType.DATASET: DatasetV2API, + ResourceType.TASK: TaskV2API, + ResourceType.EVALUATION_MEASURE: EvaluationMeasureV2API, + ResourceType.ESTIMATION_PROCEDURE: EstimationProcedureV2API, + ResourceType.EVALUATION: EvaluationV2API, + ResourceType.FLOW: FlowV2API, + ResourceType.STUDY: StudyV2API, + ResourceType.RUN: RunV2API, + ResourceType.SETUP: SetupV2API, + }, +} diff --git a/openml/_api/resources/base/__init__.py b/openml/_api/resources/base/__init__.py new file mode 100644 index 000000000..ed6dc26f7 --- /dev/null +++ b/openml/_api/resources/base/__init__.py @@ -0,0 +1,30 @@ +from .base import ResourceAPI +from .fallback import FallbackProxy +from .resources import ( + DatasetAPI, + EstimationProcedureAPI, + EvaluationAPI, + EvaluationMeasureAPI, + FlowAPI, + RunAPI, + SetupAPI, + StudyAPI, + TaskAPI, +) +from .versions import ResourceV1API, ResourceV2API + +__all__ = [ + "DatasetAPI", + "EstimationProcedureAPI", + "EvaluationAPI", + "EvaluationMeasureAPI", + "FallbackProxy", + "FlowAPI", + "ResourceAPI", + "ResourceV1API", + "ResourceV2API", + "RunAPI", + "SetupAPI", + "StudyAPI", + "TaskAPI", +] diff --git a/openml/_api/resources/base/base.py b/openml/_api/resources/base/base.py new file mode 100644 index 000000000..5eadc4932 --- /dev/null +++ b/openml/_api/resources/base/base.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, NoReturn + +from openml.exceptions import OpenMLNotSupportedError + +if TYPE_CHECKING: + from collections.abc import Mapping + from typing import Any + + from openml._api.clients import HTTPClient, MinIOClient + from openml.enums import APIVersion, ResourceType + + +class ResourceAPI(ABC): + api_version: APIVersion + resource_type: ResourceType + + def __init__(self, http: HTTPClient, minio: MinIOClient | None = None): + self._http = http + self._minio = minio + + @abstractmethod + def delete(self, 
resource_id: int) -> bool: ... + + @abstractmethod + def publish(self, path: str, files: Mapping[str, Any] | None) -> int: ... + + @abstractmethod + def tag(self, resource_id: int, tag: str) -> list[str]: ... + + @abstractmethod + def untag(self, resource_id: int, tag: str) -> list[str]: ... + + def _not_supported(self, *, method: str) -> NoReturn: + version = getattr(self.api_version, "value", "unknown") + resource = getattr(self.resource_type, "value", "unknown") + + raise OpenMLNotSupportedError( + f"{self.__class__.__name__}: " + f"{version} API does not support `{method}` " + f"for resource `{resource}`" + ) diff --git a/openml/_api/resources/base/fallback.py b/openml/_api/resources/base/fallback.py new file mode 100644 index 000000000..3919c36a9 --- /dev/null +++ b/openml/_api/resources/base/fallback.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from openml.exceptions import OpenMLNotSupportedError + + +class FallbackProxy: + def __init__(self, *api_versions: Any): + if not api_versions: + raise ValueError("At least one API version must be provided") + self._apis = api_versions + + def __getattr__(self, name: str) -> Any: + api, attr = self._find_attr(name) + if callable(attr): + return self._wrap_callable(name, api, attr) + return attr + + def _find_attr(self, name: str) -> tuple[Any, Any]: + for api in self._apis: + attr = getattr(api, name, None) + if attr is not None: + return api, attr + raise AttributeError(f"{self.__class__.__name__} has no attribute {name}") + + def _wrap_callable( + self, + name: str, + primary_api: Any, + primary_attr: Callable[..., Any], + ) -> Callable[..., Any]: + def wrapper(*args: Any, **kwargs: Any) -> Any: + try: + return primary_attr(*args, **kwargs) + except OpenMLNotSupportedError: + return self._call_fallbacks(name, primary_api, *args, **kwargs) + + return wrapper + + def _call_fallbacks( + self, + name: str, + skip_api: Any, + *args: Any, + 
**kwargs: Any, + ) -> Any: + for api in self._apis: + if api is skip_api: + continue + attr = getattr(api, name, None) + if callable(attr): + try: + return attr(*args, **kwargs) + except OpenMLNotSupportedError: + continue + raise OpenMLNotSupportedError(f"Could not fallback to any API for method: {name}") diff --git a/openml/_api/resources/base/resources.py b/openml/_api/resources/base/resources.py new file mode 100644 index 000000000..18c290e9f --- /dev/null +++ b/openml/_api/resources/base/resources.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from abc import abstractmethod +from collections.abc import Callable +from pathlib import Path +from typing import TYPE_CHECKING +from urllib.parse import urlparse + +from openml.enums import ResourceType +from openml.exceptions import OpenMLCacheRequiredError + +from .base import ResourceAPI + +if TYPE_CHECKING: + import pandas as pd + from requests import Response + from traitlets import Any + + from openml.tasks.task import OpenMLTask, TaskType + + +class DatasetAPI(ResourceAPI): + resource_type: ResourceType = ResourceType.DATASET + + +class TaskAPI(ResourceAPI): + resource_type: ResourceType = ResourceType.TASK + + @abstractmethod + def get( + self, + task_id: int, + ) -> OpenMLTask: + """ + API v1: + GET /task/{task_id} + + API v2: + GET /tasks/{task_id} + """ + ... + + # Task listing (V1 only) + @abstractmethod + def list( + self, + limit: int, + offset: int, + task_type: TaskType | int | None = None, + **kwargs: Any, + ) -> pd.DataFrame: + """ + List tasks with filters. + + API v1: + GET /task/list + + API v2: + Not available. + + Returns + ------- + pandas.DataFrame + """ + ... 
+ + def download( + self, + url: str, + handler: Callable[[Response, Path, str], Path] | None = None, + encoding: str = "utf-8", + file_name: str = "response.txt", + md5_checksum: str | None = None, + ) -> Path: + if self._http.cache is None: + raise OpenMLCacheRequiredError( + "A cache object is required for download, but none was provided in the HTTPClient." + ) + base = self._http.cache.path + file_path = base / "downloads" / urlparse(url).path.lstrip("/") / file_name + file_path = file_path.expanduser() + file_path.parent.mkdir(parents=True, exist_ok=True) + if file_path.exists(): + return file_path + + response = self._http.get(url, md5_checksum=md5_checksum) + if handler is not None: + return handler(response, file_path, encoding) + + return self._text_handler(response, file_path, encoding) + + def _text_handler(self, response: Response, path: Path, encoding: str) -> Path: + with path.open("w", encoding=encoding) as f: + f.write(response.text) + return path + + +class EvaluationMeasureAPI(ResourceAPI): + resource_type: ResourceType = ResourceType.EVALUATION_MEASURE + + +class EstimationProcedureAPI(ResourceAPI): + resource_type: ResourceType = ResourceType.ESTIMATION_PROCEDURE + + +class EvaluationAPI(ResourceAPI): + resource_type: ResourceType = ResourceType.EVALUATION + + +class FlowAPI(ResourceAPI): + resource_type: ResourceType = ResourceType.FLOW + + +class StudyAPI(ResourceAPI): + resource_type: ResourceType = ResourceType.STUDY + + +class RunAPI(ResourceAPI): + resource_type: ResourceType = ResourceType.RUN + + +class SetupAPI(ResourceAPI): + resource_type: ResourceType = ResourceType.SETUP diff --git a/openml/_api/resources/base/versions.py b/openml/_api/resources/base/versions.py new file mode 100644 index 000000000..b86272377 --- /dev/null +++ b/openml/_api/resources/base/versions.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, cast + +import xmltodict + +from openml.enums 
import APIVersion, ResourceType +from openml.exceptions import ( + OpenMLNotAuthorizedError, + OpenMLServerError, + OpenMLServerException, +) + +from .base import ResourceAPI + + +class ResourceV1API(ResourceAPI): + api_version: APIVersion = APIVersion.V1 + + def publish(self, path: str, files: Mapping[str, Any] | None) -> int: + response = self._http.post(path, files=files) + parsed_response = xmltodict.parse(response.content) + return self._extract_id_from_upload(parsed_response) + + def delete(self, resource_id: int) -> bool: + resource_type = self._get_endpoint_name() + + legal_resources = {"data", "flow", "task", "run", "study", "user"} + if resource_type not in legal_resources: + raise ValueError(f"Can't delete a {resource_type}") + + path = f"{resource_type}/{resource_id}" + try: + response = self._http.delete(path) + result = xmltodict.parse(response.content) + return f"oml:{resource_type}_delete" in result + except OpenMLServerException as e: + self._handle_delete_exception(resource_type, e) + raise + + def tag(self, resource_id: int, tag: str) -> list[str]: + resource_type = self._get_endpoint_name() + + legal_resources = {"data", "task", "flow", "setup", "run"} + if resource_type not in legal_resources: + raise ValueError(f"Can't tag a {resource_type}") + + path = f"{resource_type}/tag" + data = {f"{resource_type}_id": resource_id, "tag": tag} + response = self._http.post(path, data=data) + + main_tag = f"oml:{resource_type}_tag" + parsed_response = xmltodict.parse(response.content, force_list={"oml:tag"}) + result = parsed_response[main_tag] + tags: list[str] = result.get("oml:tag", []) + + return tags + + def untag(self, resource_id: int, tag: str) -> list[str]: + resource_type = self._get_endpoint_name() + + legal_resources = {"data", "task", "flow", "setup", "run"} + if resource_type not in legal_resources: + raise ValueError(f"Can't tag a {resource_type}") + + path = f"{resource_type}/untag" + data = {f"{resource_type}_id": resource_id, "tag": tag} 
+ response = self._http.post(path, data=data) + + main_tag = f"oml:{resource_type}_untag" + parsed_response = xmltodict.parse(response.content, force_list={"oml:tag"}) + result = parsed_response[main_tag] + tags: list[str] = result.get("oml:tag", []) + + return tags + + def _get_endpoint_name(self) -> str: + if self.resource_type == ResourceType.DATASET: + return "data" + return cast("str", self.resource_type.value) + + def _handle_delete_exception( + self, resource_type: str, exception: OpenMLServerException + ) -> None: + # https://github.com/openml/OpenML/blob/21f6188d08ac24fcd2df06ab94cf421c946971b0/openml_OS/views/pages/api_new/v1/xml/pre.php + # Most exceptions are descriptive enough to be raised as their standard + # OpenMLServerException, however there are two cases where we add information: + # - a generic "failed" message, we direct them to the right issue board + # - when the user successfully authenticates with the server, + # but user is not allowed to take the requested action, + # in which case we specify a OpenMLNotAuthorizedError. + by_other_user = [323, 353, 393, 453, 594] + has_dependent_entities = [324, 326, 327, 328, 354, 454, 464, 595] + unknown_reason = [325, 355, 394, 455, 593] + if exception.code in by_other_user: + raise OpenMLNotAuthorizedError( + message=( + f"The {resource_type} can not be deleted because it was not uploaded by you." 
+ ), + ) from exception + if exception.code in has_dependent_entities: + raise OpenMLNotAuthorizedError( + message=( + f"The {resource_type} can not be deleted because " + f"it still has associated entities: {exception.message}" + ), + ) from exception + if exception.code in unknown_reason: + raise OpenMLServerError( + message=( + f"The {resource_type} can not be deleted for unknown reason," + " please open an issue at: https://github.com/openml/openml/issues/new" + ), + ) from exception + raise exception + + def _extract_id_from_upload(self, parsed: Mapping[str, Any]) -> int: + # reads id from upload response + # actual parsed dict: {"oml:upload_flow": {"@xmlns:oml": "...", "oml:id": "42"}} + + # xmltodict always gives exactly one root key + ((_, root_value),) = parsed.items() + + if not isinstance(root_value, Mapping): + raise ValueError("Unexpected XML structure") + + # Look for oml:id directly in the root value + if "oml:id" in root_value: + id_value = root_value["oml:id"] + if isinstance(id_value, (str, int)): + return int(id_value) + + # Fallback: check all values for numeric/string IDs + for v in root_value.values(): + if isinstance(v, (str, int)): + return int(v) + + raise ValueError("No ID found in upload response") + + +class ResourceV2API(ResourceAPI): + api_version: APIVersion = APIVersion.V2 + + def publish(self, path: str, files: Mapping[str, Any] | None) -> int: # noqa: ARG002 + self._not_supported(method="publish") + + def delete(self, resource_id: int) -> bool: # noqa: ARG002 + self._not_supported(method="delete") + + def tag(self, resource_id: int, tag: str) -> list[str]: # noqa: ARG002 + self._not_supported(method="tag") + + def untag(self, resource_id: int, tag: str) -> list[str]: # noqa: ARG002 + self._not_supported(method="untag") diff --git a/openml/_api/resources/dataset.py b/openml/_api/resources/dataset.py new file mode 100644 index 000000000..51688a2fd --- /dev/null +++ b/openml/_api/resources/dataset.py @@ -0,0 +1,11 @@ +from __future__ 
import annotations + +from .base import DatasetAPI, ResourceV1API, ResourceV2API + + +class DatasetV1API(ResourceV1API, DatasetAPI): + pass + + +class DatasetV2API(ResourceV2API, DatasetAPI): + pass diff --git a/openml/_api/resources/estimation_procedure.py b/openml/_api/resources/estimation_procedure.py new file mode 100644 index 000000000..b8ea7d2c3 --- /dev/null +++ b/openml/_api/resources/estimation_procedure.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from .base import EstimationProcedureAPI, ResourceV1API, ResourceV2API + + +class EstimationProcedureV1API(ResourceV1API, EstimationProcedureAPI): + pass + + +class EstimationProcedureV2API(ResourceV2API, EstimationProcedureAPI): + pass diff --git a/openml/_api/resources/evaluation.py b/openml/_api/resources/evaluation.py new file mode 100644 index 000000000..07877e14e --- /dev/null +++ b/openml/_api/resources/evaluation.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from .base import EvaluationAPI, ResourceV1API, ResourceV2API + + +class EvaluationV1API(ResourceV1API, EvaluationAPI): + pass + + +class EvaluationV2API(ResourceV2API, EvaluationAPI): + pass diff --git a/openml/_api/resources/evaluation_measure.py b/openml/_api/resources/evaluation_measure.py new file mode 100644 index 000000000..63cf16c77 --- /dev/null +++ b/openml/_api/resources/evaluation_measure.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from .base import EvaluationMeasureAPI, ResourceV1API, ResourceV2API + + +class EvaluationMeasureV1API(ResourceV1API, EvaluationMeasureAPI): + pass + + +class EvaluationMeasureV2API(ResourceV2API, EvaluationMeasureAPI): + pass diff --git a/openml/_api/resources/flow.py b/openml/_api/resources/flow.py new file mode 100644 index 000000000..ad2e05bd9 --- /dev/null +++ b/openml/_api/resources/flow.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from .base import FlowAPI, ResourceV1API, ResourceV2API + + +class FlowV1API(ResourceV1API, FlowAPI): + pass + + +class 
FlowV2API(ResourceV2API, FlowAPI): + pass diff --git a/openml/_api/resources/run.py b/openml/_api/resources/run.py new file mode 100644 index 000000000..151c69e35 --- /dev/null +++ b/openml/_api/resources/run.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from .base import ResourceV1API, ResourceV2API, RunAPI + + +class RunV1API(ResourceV1API, RunAPI): + pass + + +class RunV2API(ResourceV2API, RunAPI): + pass diff --git a/openml/_api/resources/setup.py b/openml/_api/resources/setup.py new file mode 100644 index 000000000..78a36cecc --- /dev/null +++ b/openml/_api/resources/setup.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from .base import ResourceV1API, ResourceV2API, SetupAPI + + +class SetupV1API(ResourceV1API, SetupAPI): + pass + + +class SetupV2API(ResourceV2API, SetupAPI): + pass diff --git a/openml/_api/resources/study.py b/openml/_api/resources/study.py new file mode 100644 index 000000000..cefd55004 --- /dev/null +++ b/openml/_api/resources/study.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from .base import ResourceV1API, ResourceV2API, StudyAPI + + +class StudyV1API(ResourceV1API, StudyAPI): + pass + + +class StudyV2API(ResourceV2API, StudyAPI): + pass diff --git a/openml/_api/resources/task.py b/openml/_api/resources/task.py new file mode 100644 index 000000000..239dbe2e0 --- /dev/null +++ b/openml/_api/resources/task.py @@ -0,0 +1,383 @@ +from __future__ import annotations + +import builtins +import warnings +from typing import Any + +import pandas as pd +import xmltodict + +from openml.tasks.task import ( + OpenMLClassificationTask, + OpenMLClusteringTask, + OpenMLLearningCurveTask, + OpenMLRegressionTask, + OpenMLTask, + TaskType, +) + +from .base import ResourceV1API, ResourceV2API, TaskAPI + + +class TaskV1API(ResourceV1API, TaskAPI): + def get(self, task_id: int) -> OpenMLTask: + """Download OpenML task for a given task ID. + + Downloads the task representation. 
+
+        Parameters
+        ----------
+        task_id : int
+            The OpenML task id of the task to download.
+
+        .. note::
+            Only the task description is fetched and parsed; the
+            associated dataset is not downloaded by this call.
+
+        Returns
+        -------
+        task: OpenMLTask
+        """
+        if not isinstance(task_id, int):
+            raise TypeError(f"Task id should be integer, is {type(task_id)}")
+
+        response = self._http.get(f"task/{task_id}", use_cache=True)
+        return self._create_task_from_xml(response.text)
+
+    def _create_task_from_xml(self, xml: str) -> OpenMLTask:
+        """Create a task given a xml string.
+
+        Parameters
+        ----------
+        xml : string
+            Task xml representation.
+
+        Returns
+        -------
+        OpenMLTask
+        """
+        dic = xmltodict.parse(xml)["oml:task"]
+        estimation_parameters = {}
+        inputs = {}
+        # Due to the unordered structure we obtain, we first have to extract
+        # the possible keys of oml:input; dic["oml:input"] is a list of
+        # OrderedDicts
+
+        # Check if there is a list of inputs
+        if isinstance(dic["oml:input"], list):
+            for input_ in dic["oml:input"]:
+                name = input_["@name"]
+                inputs[name] = input_
+        # Single input case
+        elif isinstance(dic["oml:input"], dict):
+            name = dic["oml:input"]["@name"]
+            inputs[name] = dic["oml:input"]
+
+        evaluation_measures = None
+        if "evaluation_measures" in inputs:
+            evaluation_measures = inputs["evaluation_measures"]["oml:evaluation_measures"][
+                "oml:evaluation_measure"
+            ]
+
+        task_type = TaskType(int(dic["oml:task_type_id"]))
+        common_kwargs = {
+            "task_id": dic["oml:task_id"],
+            "task_type": dic["oml:task_type"],
+            "task_type_id": task_type,
+            "data_set_id": inputs["source_data"]["oml:data_set"]["oml:data_set_id"],
+            "evaluation_measure": evaluation_measures,
+        }
+        # TODO: add OpenMLClusteringTask?
+ if task_type in ( + TaskType.SUPERVISED_CLASSIFICATION, + TaskType.SUPERVISED_REGRESSION, + TaskType.LEARNING_CURVE, + ): + # Convert some more parameters + for parameter in inputs["estimation_procedure"]["oml:estimation_procedure"][ + "oml:parameter" + ]: + name = parameter["@name"] + text = parameter.get("#text", "") + estimation_parameters[name] = text + + common_kwargs["estimation_procedure_type"] = inputs["estimation_procedure"][ + "oml:estimation_procedure" + ]["oml:type"] + common_kwargs["estimation_procedure_id"] = int( + inputs["estimation_procedure"]["oml:estimation_procedure"]["oml:id"] + ) + + common_kwargs["estimation_parameters"] = estimation_parameters + common_kwargs["target_name"] = inputs["source_data"]["oml:data_set"][ + "oml:target_feature" + ] + common_kwargs["data_splits_url"] = inputs["estimation_procedure"][ + "oml:estimation_procedure" + ]["oml:data_splits_url"] + + cls = { + TaskType.SUPERVISED_CLASSIFICATION: OpenMLClassificationTask, + TaskType.SUPERVISED_REGRESSION: OpenMLRegressionTask, + TaskType.CLUSTERING: OpenMLClusteringTask, + TaskType.LEARNING_CURVE: OpenMLLearningCurveTask, + }.get(task_type) + if cls is None: + raise NotImplementedError(f"Task type {common_kwargs['task_type']} not supported.") + return cls(**common_kwargs) # type: ignore + + def list( + self, + limit: int, + offset: int, + task_type: TaskType | int | None = None, + **kwargs: Any, + ) -> pd.DataFrame: + """ + Perform the api call to return a number of tasks having the given filters. + + Parameters + ---------- + Filter task_type is separated from the other filters because + it is used as task_type in the task description, but it is named + type when used as a filter in list tasks call. + limit: int + offset: int + task_type : TaskType, optional + Refers to the type of task. 
+ kwargs: dict, optional + Legal filter operators: tag, task_id (list), data_tag, status, limit, + offset, data_id, data_name, number_instances, number_features, + number_classes, number_missing_values. + + Returns + ------- + dataframe + """ + api_call = "task/list" + if limit is not None: + api_call += f"/limit/{limit}" + if offset is not None: + api_call += f"/offset/{offset}" + if task_type is not None: + tvalue = task_type.value if isinstance(task_type, TaskType) else task_type + api_call += f"/type/{tvalue}" + if kwargs is not None: + for operator, value in kwargs.items(): + if value is not None: + if operator == "task_id": + value = ",".join([str(int(i)) for i in value]) # noqa: PLW2901 + api_call += f"/{operator}/{value}" + + return self._fetch_tasks_df(api_call=api_call) + + def _fetch_tasks_df(self, api_call: str) -> pd.DataFrame: # noqa: C901, PLR0912 + """Returns a Pandas DataFrame with information about OpenML tasks. + + Parameters + ---------- + api_call : str + The API call specifying which tasks to return. + + Returns + ------- + A Pandas DataFrame with information about OpenML tasks. + + Raises + ------ + ValueError + If the XML returned by the OpenML API does not contain 'oml:tasks', '@xmlns:oml', + or has an incorrect value for '@xmlns:oml'. + KeyError + If an invalid key is found in the XML for a task. 
+ """ + xml_string = self._http.get(api_call).text + + tasks_dict = xmltodict.parse(xml_string, force_list=("oml:task", "oml:input")) + # Minimalistic check if the XML is useful + if "oml:tasks" not in tasks_dict: + raise ValueError(f'Error in return XML, does not contain "oml:runs": {tasks_dict}') + + if "@xmlns:oml" not in tasks_dict["oml:tasks"]: + raise ValueError( + f'Error in return XML, does not contain "oml:runs"/@xmlns:oml: {tasks_dict}' + ) + + if tasks_dict["oml:tasks"]["@xmlns:oml"] != "http://openml.org/openml": + raise ValueError( + "Error in return XML, value of " + '"oml:runs"/@xmlns:oml is not ' + f'"http://openml.org/openml": {tasks_dict!s}', + ) + + assert isinstance(tasks_dict["oml:tasks"]["oml:task"], list), type(tasks_dict["oml:tasks"]) + + tasks = {} + procs = self._get_estimation_procedure_list() + proc_dict = {x["id"]: x for x in procs} + + for task_ in tasks_dict["oml:tasks"]["oml:task"]: + tid = None + try: + tid = int(task_["oml:task_id"]) + task_type_int = int(task_["oml:task_type_id"]) + try: + task_type_id = TaskType(task_type_int) + except ValueError as e: + warnings.warn( + f"Could not create task type id for {task_type_int} due to error {e}", + RuntimeWarning, + stacklevel=2, + ) + continue + + task = { + "tid": tid, + "ttid": task_type_id, + "did": int(task_["oml:did"]), + "name": task_["oml:name"], + "task_type": task_["oml:task_type"], + "status": task_["oml:status"], + } + + # Other task inputs + for _input in task_.get("oml:input", []): + if _input["@name"] == "estimation_procedure": + task[_input["@name"]] = proc_dict[int(_input["#text"])]["name"] + else: + value = _input.get("#text") + task[_input["@name"]] = value + + # The number of qualities can range from 0 to infinity + for quality in task_.get("oml:quality", []): + if "#text" not in quality: + quality_value = 0.0 + else: + quality["#text"] = float(quality["#text"]) + if abs(int(quality["#text"]) - quality["#text"]) < 0.0000001: + quality["#text"] = 
int(quality["#text"]) + quality_value = quality["#text"] + task[quality["@name"]] = quality_value + tasks[tid] = task + except KeyError as e: + if tid is not None: + warnings.warn( + f"Invalid xml for task {tid}: {e}\nFrom {task_}", + RuntimeWarning, + stacklevel=2, + ) + else: + warnings.warn( + f"Could not find key {e} in {task_}!", RuntimeWarning, stacklevel=2 + ) + + return pd.DataFrame.from_dict(tasks, orient="index") + + def _get_estimation_procedure_list(self) -> builtins.list[dict[str, Any]]: + """Return a list of all estimation procedures which are on OpenML. + + Returns + ------- + procedures : list + A list of all estimation procedures. Every procedure is represented by + a dictionary containing the following information: id, task type id, + name, type, repeats, folds, stratified. + """ + url_suffix = "estimationprocedure/list" + xml_string = self._http.get(url_suffix).text + + procs_dict = xmltodict.parse(xml_string) + # Minimalistic check if the XML is useful + if "oml:estimationprocedures" not in procs_dict: + raise ValueError("Error in return XML, does not contain tag oml:estimationprocedures.") + + if "@xmlns:oml" not in procs_dict["oml:estimationprocedures"]: + raise ValueError( + "Error in return XML, does not contain tag " + "@xmlns:oml as a child of oml:estimationprocedures.", + ) + + if procs_dict["oml:estimationprocedures"]["@xmlns:oml"] != "http://openml.org/openml": + raise ValueError( + "Error in return XML, value of " + "oml:estimationprocedures/@xmlns:oml is not " + "http://openml.org/openml, but {}".format( + str(procs_dict["oml:estimationprocedures"]["@xmlns:oml"]) + ), + ) + + procs: list[dict[str, Any]] = [] + for proc_ in procs_dict["oml:estimationprocedures"]["oml:estimationprocedure"]: + task_type_int = int(proc_["oml:ttid"]) + try: + task_type_id = TaskType(task_type_int) + procs.append( + { + "id": int(proc_["oml:id"]), + "task_type_id": task_type_id, + "name": proc_["oml:name"], + "type": proc_["oml:type"], + }, + ) + except 
ValueError as e: + warnings.warn( + f"Could not create task type id for {task_type_int} due to error {e}", + RuntimeWarning, + stacklevel=2, + ) + + return procs + + +class TaskV2API(ResourceV2API, TaskAPI): + def get(self, task_id: int) -> OpenMLTask: + response = self._http.get(f"tasks/{task_id}", use_cache=True) + return self._create_task_from_json(response.json()) + + def _create_task_from_json(self, task_json: dict) -> OpenMLTask: + task_type_id = TaskType(int(task_json["task_type_id"])) + + inputs = {i["name"]: i for i in task_json.get("input", [])} + + source = inputs["source_data"]["data_set"] + + common_kwargs = { + "task_id": int(task_json["id"]), + "task_type": task_json["task_type"], + "task_type_id": task_type_id, + "data_set_id": int(source["data_set_id"]), + "evaluation_measure": None, + } + + if task_type_id in ( + TaskType.SUPERVISED_CLASSIFICATION, + TaskType.SUPERVISED_REGRESSION, + TaskType.LEARNING_CURVE, + ): + est = inputs.get("estimation_procedure", {}).get("estimation_procedure") + + if est: + common_kwargs["estimation_procedure_id"] = int(est["id"]) + common_kwargs["estimation_procedure_type"] = est["type"] + common_kwargs["estimation_parameters"] = { + p["name"]: p.get("value") for p in est.get("parameter", []) + } + + common_kwargs["target_name"] = source.get("target_feature") + + cls = { + TaskType.SUPERVISED_CLASSIFICATION: OpenMLClassificationTask, + TaskType.SUPERVISED_REGRESSION: OpenMLRegressionTask, + TaskType.CLUSTERING: OpenMLClusteringTask, + TaskType.LEARNING_CURVE: OpenMLLearningCurveTask, + }[task_type_id] + + return cls(**common_kwargs) # type: ignore + + def list( + self, + limit: int, # noqa: ARG002 + offset: int, # noqa: ARG002 + task_type: TaskType | int | None = None, # noqa: ARG002 + **kwargs: Any, # noqa: ARG002 + ) -> pd.DataFrame: + raise self._not_supported(method="list") diff --git a/openml/_api/setup/__init__.py b/openml/_api/setup/__init__.py new file mode 100644 index 000000000..1c28cfa9e --- /dev/null +++ 
b/openml/_api/setup/__init__.py @@ -0,0 +1,14 @@ +from ._instance import _backend +from .backend import APIBackend +from .builder import APIBackendBuilder +from .config import APIConfig, CacheConfig, Config, ConnectionConfig + +__all__ = [ + "APIBackend", + "APIBackendBuilder", + "APIConfig", + "CacheConfig", + "Config", + "ConnectionConfig", + "_backend", +] diff --git a/openml/_api/setup/_instance.py b/openml/_api/setup/_instance.py new file mode 100644 index 000000000..c98ccaf57 --- /dev/null +++ b/openml/_api/setup/_instance.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from .backend import APIBackend + +_backend = APIBackend.get_instance() diff --git a/openml/_api/setup/_utils.py b/openml/_api/setup/_utils.py new file mode 100644 index 000000000..ddcf5b41c --- /dev/null +++ b/openml/_api/setup/_utils.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import logging +import os +import platform +from pathlib import Path + +openml_logger = logging.getLogger("openml") + +# Default values (see also https://github.com/openml/OpenML/wiki/Client-API-Standards) +_user_path = Path("~").expanduser().absolute() + + +def _resolve_default_cache_dir() -> Path: + user_defined_cache_dir = os.environ.get("OPENML_CACHE_DIR") + if user_defined_cache_dir is not None: + return Path(user_defined_cache_dir) + + if platform.system().lower() != "linux": + return _user_path / ".openml" + + xdg_cache_home = os.environ.get("XDG_CACHE_HOME") + if xdg_cache_home is None: + return Path("~", ".cache", "openml") + + # This is the proper XDG_CACHE_HOME directory, but + # we unfortunately had a problem where we used XDG_CACHE_HOME/org, + # we check heuristically if this old directory still exists and issue + # a warning if it does. There's too much data to move to do this for the user. 
+ + # The new cache directory exists + cache_dir = Path(xdg_cache_home) / "openml" + if cache_dir.exists(): + return cache_dir + + # The old cache directory *does not* exist + heuristic_dir_for_backwards_compat = Path(xdg_cache_home) / "org" / "openml" + if not heuristic_dir_for_backwards_compat.exists(): + return cache_dir + + root_dir_to_delete = Path(xdg_cache_home) / "org" + openml_logger.warning( + "An old cache directory was found at '%s'. This directory is no longer used by " + "OpenML-Python. To silence this warning you would need to delete the old cache " + "directory. The cached files will then be located in '%s'.", + root_dir_to_delete, + cache_dir, + ) + return Path(xdg_cache_home) diff --git a/openml/_api/setup/backend.py b/openml/_api/setup/backend.py new file mode 100644 index 000000000..c29d1dbad --- /dev/null +++ b/openml/_api/setup/backend.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import TYPE_CHECKING, Any, cast + +from .builder import APIBackendBuilder +from .config import Config + +if TYPE_CHECKING: + from openml._api.resources import ( + DatasetAPI, + EstimationProcedureAPI, + EvaluationAPI, + EvaluationMeasureAPI, + FlowAPI, + RunAPI, + SetupAPI, + StudyAPI, + TaskAPI, + ) + + +class APIBackend: + _instance: APIBackend | None = None + + def __init__(self, config: Config | None = None): + self._config: Config = config or Config() + self._backend = APIBackendBuilder.build(self._config) + + @property + def dataset(self) -> DatasetAPI: + return cast("DatasetAPI", self._backend.dataset) + + @property + def task(self) -> TaskAPI: + return cast("TaskAPI", self._backend.task) + + @property + def evaluation_measure(self) -> EvaluationMeasureAPI: + return cast("EvaluationMeasureAPI", self._backend.evaluation_measure) + + @property + def estimation_procedure(self) -> EstimationProcedureAPI: + return cast("EstimationProcedureAPI", self._backend.estimation_procedure) + + @property + def 
evaluation(self) -> EvaluationAPI: + return cast("EvaluationAPI", self._backend.evaluation) + + @property + def flow(self) -> FlowAPI: + return cast("FlowAPI", self._backend.flow) + + @property + def study(self) -> StudyAPI: + return cast("StudyAPI", self._backend.study) + + @property + def run(self) -> RunAPI: + return cast("RunAPI", self._backend.run) + + @property + def setup(self) -> SetupAPI: + return cast("SetupAPI", self._backend.setup) + + @classmethod + def get_instance(cls) -> APIBackend: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def get_config(cls) -> Config: + return deepcopy(cls.get_instance()._config) + + @classmethod + def set_config(cls, config: Config) -> None: + instance = cls.get_instance() + instance._config = config + instance._backend = APIBackendBuilder.build(config) + + @classmethod + def get_config_value(cls, key: str) -> Any: + keys = key.split(".") + config_value = cls.get_instance()._config + for k in keys: + if isinstance(config_value, dict): + config_value = config_value[k] + else: + config_value = getattr(config_value, k) + return deepcopy(config_value) + + @classmethod + def set_config_value(cls, key: str, value: Any) -> None: + keys = key.split(".") + config = cls.get_instance()._config + parent = config + for k in keys[:-1]: + parent = parent[k] if isinstance(parent, dict) else getattr(parent, k) + if isinstance(parent, dict): + parent[keys[-1]] = value + else: + setattr(parent, keys[-1], value) + cls.set_config(config) + + @classmethod + def get_config_values(cls, keys: list[str]) -> list[Any]: + values = [] + for key in keys: + value = cls.get_config_value(key) + values.append(value) + return values + + @classmethod + def set_config_values(cls, config_dict: dict[str, Any]) -> None: + config = cls.get_instance()._config + + for key, value in config_dict.items(): + keys = key.split(".") + parent = config + for k in keys[:-1]: + parent = parent[k] if isinstance(parent, dict) else 
getattr(parent, k) + if isinstance(parent, dict): + parent[keys[-1]] = value + else: + setattr(parent, keys[-1], value) + + cls.set_config(config) diff --git a/openml/_api/setup/builder.py b/openml/_api/setup/builder.py new file mode 100644 index 000000000..f801fe525 --- /dev/null +++ b/openml/_api/setup/builder.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from collections.abc import Mapping +from pathlib import Path +from typing import TYPE_CHECKING + +from openml._api.clients import HTTPCache, HTTPClient, MinIOClient +from openml._api.resources import API_REGISTRY, FallbackProxy, ResourceAPI +from openml.enums import ResourceType + +if TYPE_CHECKING: + from .config import Config + + +class APIBackendBuilder: + def __init__( + self, + resource_apis: Mapping[ResourceType, ResourceAPI | FallbackProxy], + ): + self.dataset = resource_apis[ResourceType.DATASET] + self.task = resource_apis[ResourceType.TASK] + self.evaluation_measure = resource_apis[ResourceType.EVALUATION_MEASURE] + self.estimation_procedure = resource_apis[ResourceType.ESTIMATION_PROCEDURE] + self.evaluation = resource_apis[ResourceType.EVALUATION] + self.flow = resource_apis[ResourceType.FLOW] + self.study = resource_apis[ResourceType.STUDY] + self.run = resource_apis[ResourceType.RUN] + self.setup = resource_apis[ResourceType.SETUP] + + @classmethod + def build(cls, config: Config) -> APIBackendBuilder: + cache_dir = Path(config.cache.dir).expanduser() + + http_cache = HTTPCache(path=cache_dir, ttl=config.cache.ttl) + minio_client = MinIOClient(path=cache_dir) + + primary_api_config = config.api_configs[config.api_version] + primary_http_client = HTTPClient( + server=primary_api_config.server, + base_url=primary_api_config.base_url, + api_key=primary_api_config.api_key, + retries=config.connection.retries, + retry_policy=config.connection.retry_policy, + cache=http_cache, + ) + + resource_apis: dict[ResourceType, ResourceAPI] = {} + for resource_type, resource_api_cls in 
API_REGISTRY[config.api_version].items(): + resource_apis[resource_type] = resource_api_cls(primary_http_client, minio_client) + + if config.fallback_api_version is None: + return cls(resource_apis) + + fallback_api_config = config.api_configs[config.fallback_api_version] + fallback_http_client = HTTPClient( + server=fallback_api_config.server, + base_url=fallback_api_config.base_url, + api_key=fallback_api_config.api_key, + retries=config.connection.retries, + retry_policy=config.connection.retry_policy, + cache=http_cache, + ) + + fallback_resource_apis: dict[ResourceType, ResourceAPI] = {} + for resource_type, resource_api_cls in API_REGISTRY[config.fallback_api_version].items(): + fallback_resource_apis[resource_type] = resource_api_cls( + fallback_http_client, minio_client + ) + + merged: dict[ResourceType, FallbackProxy] = { + name: FallbackProxy(resource_apis[name], fallback_resource_apis[name]) + for name in resource_apis + } + + return cls(merged) diff --git a/openml/_api/setup/config.py b/openml/_api/setup/config.py new file mode 100644 index 000000000..4108227aa --- /dev/null +++ b/openml/_api/setup/config.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import timedelta + +from openml.enums import APIVersion, RetryPolicy + +from ._utils import _resolve_default_cache_dir + + +@dataclass +class APIConfig: + server: str + base_url: str + api_key: str + + +@dataclass +class ConnectionConfig: + retries: int + retry_policy: RetryPolicy + + +@dataclass +class CacheConfig: + dir: str + ttl: int + + +@dataclass +class Config: + api_version: APIVersion = APIVersion.V1 + fallback_api_version: APIVersion | None = None + + api_configs: dict[APIVersion, APIConfig] = field( + default_factory=lambda: { + APIVersion.V1: APIConfig( + server="https://www.openml.org/", + base_url="api/v1/xml/", + api_key="", + ), + APIVersion.V2: APIConfig( + server="http://localhost:8002/", + base_url="", + api_key="", + ), 
+ } + ) + + connection: ConnectionConfig = field( + default_factory=lambda: ConnectionConfig( + retries=5, + retry_policy=RetryPolicy.HUMAN, + ) + ) + + cache: CacheConfig = field( + default_factory=lambda: CacheConfig( + dir=str(_resolve_default_cache_dir()), + ttl=int(timedelta(weeks=1).total_seconds()), + ) + ) diff --git a/openml/config.py b/openml/config.py index e6104fd7f..c266ae9d9 100644 --- a/openml/config.py +++ b/openml/config.py @@ -18,6 +18,8 @@ from typing_extensions import TypedDict from urllib.parse import urlparse +from openml.enums import RetryPolicy + logger = logging.getLogger(__name__) openml_logger = logging.getLogger("openml") console_handler: logging.StreamHandler | None = None @@ -206,6 +208,8 @@ def set_retry_policy(value: Literal["human", "robot"], n_retries: int | None = N retry_policy = value connection_n_retries = default_retries_by_policy[value] if n_retries is None else n_retries + _sync_api_config() + class ConfigurationForExamples: """Allows easy switching to and from a test configuration, used for examples.""" @@ -244,6 +248,8 @@ def start_using_configuration_for_example(cls) -> None: stacklevel=2, ) + _sync_api_config() + @classmethod def stop_using_configuration_for_example(cls) -> None: """Return to configuration as it was before `start_use_example_configuration`.""" @@ -262,6 +268,8 @@ def stop_using_configuration_for_example(cls) -> None: apikey = cast("str", cls._last_used_key) cls._start_last_called = False + _sync_api_config() + def _handle_xdg_config_home_backwards_compatibility( xdg_home: str, @@ -374,6 +382,8 @@ def _setup(config: _Config | None = None) -> None: short_cache_dir = Path(config["cachedir"]) _root_cache_directory = short_cache_dir.expanduser().resolve() + _sync_api_config() + try: cache_exists = _root_cache_directory.exists() # create the cache subdirectory @@ -408,6 +418,8 @@ def set_field_in_config_file(field: str, value: Any) -> None: if value is not None: fh.write(f"{f} = {value}\n") + 
_sync_api_config() + def _parse_config(config_file: str | Path) -> _Config: """Parse the config file, set up defaults.""" @@ -495,6 +507,8 @@ def set_root_cache_directory(root_cache_directory: str | Path) -> None: global _root_cache_directory # noqa: PLW0603 _root_cache_directory = Path(root_cache_directory) + _sync_api_config() + start_using_configuration_for_example = ( ConfigurationForExamples.start_using_configuration_for_example @@ -514,6 +528,28 @@ def overwrite_config_context(config: dict[str, Any]) -> Iterator[_Config]: _setup(existing_config) +def _sync_api_config() -> None: + """Sync the new API config with the legacy config in this file.""" + from ._api import APIBackend + + p = urlparse(server) + v1_server = f"{p.scheme}://{p.netloc}/" + v1_base_url = p.path.lstrip("/") + connection_retry_policy = RetryPolicy.HUMAN if retry_policy == "human" else RetryPolicy.ROBOT + cache_dir = str(_root_cache_directory) + + APIBackend.set_config_values( + { + "api_configs.v1.server": v1_server, + "api_configs.v1.base_url": v1_base_url, + "api_configs.v1.api_key": apikey, + "cache.dir": cache_dir, + "connection.retry_policy": connection_retry_policy, + "connection.retries": connection_n_retries, + } + ) + + __all__ = [ "get_cache_directory", "get_config_as_dict", diff --git a/openml/enums.py b/openml/enums.py new file mode 100644 index 000000000..f5a4381b7 --- /dev/null +++ b/openml/enums.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from enum import Enum + + +class APIVersion(str, Enum): + """Supported OpenML API versions.""" + + V1 = "v1" + V2 = "v2" + + +class ResourceType(str, Enum): + """Canonical resource types exposed by the OpenML API.""" + + DATASET = "dataset" + TASK = "task" + TASK_TYPE = "task_type" + EVALUATION_MEASURE = "evaluation_measure" + ESTIMATION_PROCEDURE = "estimation_procedure" + EVALUATION = "evaluation" + FLOW = "flow" + STUDY = "study" + RUN = "run" + SETUP = "setup" + USER = "user" + + +class RetryPolicy(str, Enum): + """Retry 
behavior for failed API requests.""" + + HUMAN = "human" + ROBOT = "robot" diff --git a/openml/exceptions.py b/openml/exceptions.py index fe63b8a58..10f693648 100644 --- a/openml/exceptions.py +++ b/openml/exceptions.py @@ -65,3 +65,11 @@ class OpenMLNotAuthorizedError(OpenMLServerError): class ObjectNotPublishedError(PyOpenMLError): """Indicates an object has not been published yet.""" + + +class OpenMLNotSupportedError(PyOpenMLError): + """Raised when an API operation is not supported for a resource/version.""" + + +class OpenMLCacheRequiredError(PyOpenMLError): + """Raised when a cache object is required but not provided.""" diff --git a/openml/tasks/functions.py b/openml/tasks/functions.py index 3df2861c0..ee0dd00c4 100644 --- a/openml/tasks/functions.py +++ b/openml/tasks/functions.py @@ -1,19 +1,14 @@ # License: BSD 3-Clause from __future__ import annotations -import os -import re import warnings from functools import partial -from typing import Any +from typing import TYPE_CHECKING, Any import pandas as pd -import xmltodict -import openml._api_calls import openml.utils from openml.datasets import get_dataset -from openml.exceptions import OpenMLCacheException from .task import ( OpenMLClassificationTask, @@ -21,109 +16,13 @@ OpenMLLearningCurveTask, OpenMLRegressionTask, OpenMLSupervisedTask, - OpenMLTask, TaskType, ) -TASKS_CACHE_DIR_NAME = "tasks" - - -def _get_cached_tasks() -> dict[int, OpenMLTask]: - """Return a dict of all the tasks which are cached locally. - - Returns - ------- - tasks : OrderedDict - A dict of all the cached tasks. Each task is an instance of - OpenMLTask. 
- """ - task_cache_dir = openml.utils._create_cache_directory(TASKS_CACHE_DIR_NAME) - directory_content = os.listdir(task_cache_dir) # noqa: PTH208 - directory_content.sort() - - # Find all dataset ids for which we have downloaded the dataset - # description - tids = (int(did) for did in directory_content if re.match(r"[0-9]*", did)) - return {tid: _get_cached_task(tid) for tid in tids} - - -def _get_cached_task(tid: int) -> OpenMLTask: - """Return a cached task based on the given id. - - Parameters - ---------- - tid : int - Id of the task. - - Returns - ------- - OpenMLTask - """ - tid_cache_dir = openml.utils._create_cache_directory_for_id(TASKS_CACHE_DIR_NAME, tid) - - task_xml_path = tid_cache_dir / "task.xml" - try: - with task_xml_path.open(encoding="utf8") as fh: - return _create_task_from_xml(fh.read()) - except OSError as e: - openml.utils._remove_cache_dir_for_id(TASKS_CACHE_DIR_NAME, tid_cache_dir) - raise OpenMLCacheException(f"Task file for tid {tid} not cached") from e - - -def _get_estimation_procedure_list() -> list[dict[str, Any]]: - """Return a list of all estimation procedures which are on OpenML. - - Returns - ------- - procedures : list - A list of all estimation procedures. Every procedure is represented by - a dictionary containing the following information: id, task type id, - name, type, repeats, folds, stratified. 
- """ - url_suffix = "estimationprocedure/list" - xml_string = openml._api_calls._perform_api_call(url_suffix, "get") - - procs_dict = xmltodict.parse(xml_string) - # Minimalistic check if the XML is useful - if "oml:estimationprocedures" not in procs_dict: - raise ValueError("Error in return XML, does not contain tag oml:estimationprocedures.") - - if "@xmlns:oml" not in procs_dict["oml:estimationprocedures"]: - raise ValueError( - "Error in return XML, does not contain tag " - "@xmlns:oml as a child of oml:estimationprocedures.", - ) - - if procs_dict["oml:estimationprocedures"]["@xmlns:oml"] != "http://openml.org/openml": - raise ValueError( - "Error in return XML, value of " - "oml:estimationprocedures/@xmlns:oml is not " - "http://openml.org/openml, but {}".format( - str(procs_dict["oml:estimationprocedures"]["@xmlns:oml"]) - ), - ) - - procs: list[dict[str, Any]] = [] - for proc_ in procs_dict["oml:estimationprocedures"]["oml:estimationprocedure"]: - task_type_int = int(proc_["oml:ttid"]) - try: - task_type_id = TaskType(task_type_int) - procs.append( - { - "id": int(proc_["oml:id"]), - "task_type_id": task_type_id, - "name": proc_["oml:name"], - "type": proc_["oml:type"], - }, - ) - except ValueError as e: - warnings.warn( - f"Could not create task type id for {task_type_int} due to error {e}", - RuntimeWarning, - stacklevel=2, - ) - - return procs +if TYPE_CHECKING: + from .task import ( + OpenMLTask, + ) def list_tasks( # noqa: PLR0913 @@ -175,7 +74,7 @@ def list_tasks( # noqa: PLR0913 calculated for the associated dataset, some of these are also returned. 
""" listing_call = partial( - _list_tasks, + openml._backend.task.list, task_type=task_type, tag=tag, data_tag=data_tag, @@ -194,151 +93,6 @@ def list_tasks( # noqa: PLR0913 return pd.concat(batches) -def _list_tasks( - limit: int, - offset: int, - task_type: TaskType | int | None = None, - **kwargs: Any, -) -> pd.DataFrame: - """ - Perform the api call to return a number of tasks having the given filters. - - Parameters - ---------- - Filter task_type is separated from the other filters because - it is used as task_type in the task description, but it is named - type when used as a filter in list tasks call. - limit: int - offset: int - task_type : TaskType, optional - Refers to the type of task. - kwargs: dict, optional - Legal filter operators: tag, task_id (list), data_tag, status, limit, - offset, data_id, data_name, number_instances, number_features, - number_classes, number_missing_values. - - Returns - ------- - dataframe - """ - api_call = "task/list" - if limit is not None: - api_call += f"/limit/{limit}" - if offset is not None: - api_call += f"/offset/{offset}" - if task_type is not None: - tvalue = task_type.value if isinstance(task_type, TaskType) else task_type - api_call += f"/type/{tvalue}" - if kwargs is not None: - for operator, value in kwargs.items(): - if value is not None: - if operator == "task_id": - value = ",".join([str(int(i)) for i in value]) # noqa: PLW2901 - api_call += f"/{operator}/{value}" - - return __list_tasks(api_call=api_call) - - -def __list_tasks(api_call: str) -> pd.DataFrame: # noqa: C901, PLR0912 - """Returns a Pandas DataFrame with information about OpenML tasks. - - Parameters - ---------- - api_call : str - The API call specifying which tasks to return. - - Returns - ------- - A Pandas DataFrame with information about OpenML tasks. - - Raises - ------ - ValueError - If the XML returned by the OpenML API does not contain 'oml:tasks', '@xmlns:oml', - or has an incorrect value for '@xmlns:oml'. 
- KeyError - If an invalid key is found in the XML for a task. - """ - xml_string = openml._api_calls._perform_api_call(api_call, "get") - tasks_dict = xmltodict.parse(xml_string, force_list=("oml:task", "oml:input")) - # Minimalistic check if the XML is useful - if "oml:tasks" not in tasks_dict: - raise ValueError(f'Error in return XML, does not contain "oml:runs": {tasks_dict}') - - if "@xmlns:oml" not in tasks_dict["oml:tasks"]: - raise ValueError( - f'Error in return XML, does not contain "oml:runs"/@xmlns:oml: {tasks_dict}' - ) - - if tasks_dict["oml:tasks"]["@xmlns:oml"] != "http://openml.org/openml": - raise ValueError( - "Error in return XML, value of " - '"oml:runs"/@xmlns:oml is not ' - f'"http://openml.org/openml": {tasks_dict!s}', - ) - - assert isinstance(tasks_dict["oml:tasks"]["oml:task"], list), type(tasks_dict["oml:tasks"]) - - tasks = {} - procs = _get_estimation_procedure_list() - proc_dict = {x["id"]: x for x in procs} - - for task_ in tasks_dict["oml:tasks"]["oml:task"]: - tid = None - try: - tid = int(task_["oml:task_id"]) - task_type_int = int(task_["oml:task_type_id"]) - try: - task_type_id = TaskType(task_type_int) - except ValueError as e: - warnings.warn( - f"Could not create task type id for {task_type_int} due to error {e}", - RuntimeWarning, - stacklevel=2, - ) - continue - - task = { - "tid": tid, - "ttid": task_type_id, - "did": int(task_["oml:did"]), - "name": task_["oml:name"], - "task_type": task_["oml:task_type"], - "status": task_["oml:status"], - } - - # Other task inputs - for _input in task_.get("oml:input", []): - if _input["@name"] == "estimation_procedure": - task[_input["@name"]] = proc_dict[int(_input["#text"])]["name"] - else: - value = _input.get("#text") - task[_input["@name"]] = value - - # The number of qualities can range from 0 to infinity - for quality in task_.get("oml:quality", []): - if "#text" not in quality: - quality_value = 0.0 - else: - quality["#text"] = float(quality["#text"]) - if 
abs(int(quality["#text"]) - quality["#text"]) < 0.0000001: - quality["#text"] = int(quality["#text"]) - quality_value = quality["#text"] - task[quality["@name"]] = quality_value - tasks[tid] = task - except KeyError as e: - if tid is not None: - warnings.warn( - f"Invalid xml for task {tid}: {e}\nFrom {task_}", - RuntimeWarning, - stacklevel=2, - ) - else: - warnings.warn(f"Could not find key {e} in {task_}!", RuntimeWarning, stacklevel=2) - - return pd.DataFrame.from_dict(tasks, orient="index") - - def get_tasks( task_ids: list[int], download_data: bool | None = None, @@ -346,7 +100,7 @@ def get_tasks( ) -> list[OpenMLTask]: """Download tasks. - This function iterates :meth:`openml.tasks.get_task`. + This function iterates :meth:`openml.task.get`. Parameters ---------- @@ -412,132 +166,32 @@ def get_task( ------- task: OpenMLTask """ + from openml._api.resources.task import TaskV1API, TaskV2API + if not isinstance(task_id, int): raise TypeError(f"Task id should be integer, is {type(task_id)}") - cache_key_dir = openml.utils._create_cache_directory_for_id(TASKS_CACHE_DIR_NAME, task_id) - tid_cache_dir = cache_key_dir / str(task_id) - tid_cache_dir_existed = tid_cache_dir.exists() - try: - task = _get_task_description(task_id) - dataset = get_dataset(task.dataset_id, **get_dataset_kwargs) - # List of class labels available in dataset description - # Including class labels as part of task meta data handles - # the case where data download was initially disabled - if isinstance(task, (OpenMLClassificationTask, OpenMLLearningCurveTask)): - task.class_labels = dataset.retrieve_class_labels(task.target_name) - # Clustering tasks do not have class labels - # and do not offer download_split - if download_splits and isinstance(task, OpenMLSupervisedTask): - task.download_split() - except Exception as e: - if not tid_cache_dir_existed: - openml.utils._remove_cache_dir_for_id(TASKS_CACHE_DIR_NAME, tid_cache_dir) - raise e - - return task - - -def 
_get_task_description(task_id: int) -> OpenMLTask: - try: - return _get_cached_task(task_id) - except OpenMLCacheException: - _cache_dir = openml.utils._create_cache_directory_for_id(TASKS_CACHE_DIR_NAME, task_id) - xml_file = _cache_dir / "task.xml" - task_xml = openml._api_calls._perform_api_call(f"task/{task_id}", "get") - - with xml_file.open("w", encoding="utf8") as fh: - fh.write(task_xml) - return _create_task_from_xml(task_xml) - - -def _create_task_from_xml(xml: str) -> OpenMLTask: - """Create a task given a xml string. + task = openml._backend.task.get(task_id) + dataset = get_dataset(task.dataset_id, **get_dataset_kwargs) - Parameters - ---------- - xml : string - Task xml representation. + if isinstance(task, (OpenMLClassificationTask, OpenMLLearningCurveTask)): + task.class_labels = dataset.retrieve_class_labels(task.target_name) - Returns - ------- - OpenMLTask - """ - dic = xmltodict.parse(xml)["oml:task"] - estimation_parameters = {} - inputs = {} - # Due to the unordered structure we obtain, we first have to extract - # the possible keys of oml:input; dic["oml:input"] is a list of - # OrderedDicts - - # Check if there is a list of inputs - if isinstance(dic["oml:input"], list): - for input_ in dic["oml:input"]: - name = input_["@name"] - inputs[name] = input_ - # Single input case - elif isinstance(dic["oml:input"], dict): - name = dic["oml:input"]["@name"] - inputs[name] = dic["oml:input"] - - evaluation_measures = None - if "evaluation_measures" in inputs: - evaluation_measures = inputs["evaluation_measures"]["oml:evaluation_measures"][ - "oml:evaluation_measure" - ] - - task_type = TaskType(int(dic["oml:task_type_id"])) - common_kwargs = { - "task_id": dic["oml:task_id"], - "task_type": dic["oml:task_type"], - "task_type_id": task_type, - "data_set_id": inputs["source_data"]["oml:data_set"]["oml:data_set_id"], - "evaluation_measure": evaluation_measures, - } - # TODO: add OpenMLClusteringTask? 
- if task_type in ( - TaskType.SUPERVISED_CLASSIFICATION, - TaskType.SUPERVISED_REGRESSION, - TaskType.LEARNING_CURVE, + if ( + download_splits + and isinstance(task, OpenMLSupervisedTask) + and isinstance(openml._backend.task, TaskV1API) ): - # Convert some more parameters - for parameter in inputs["estimation_procedure"]["oml:estimation_procedure"][ - "oml:parameter" - ]: - name = parameter["@name"] - text = parameter.get("#text", "") - estimation_parameters[name] = text - - common_kwargs["estimation_procedure_type"] = inputs["estimation_procedure"][ - "oml:estimation_procedure" - ]["oml:type"] - common_kwargs["estimation_procedure_id"] = int( - inputs["estimation_procedure"]["oml:estimation_procedure"]["oml:id"] + task.download_split() + elif download_splits and isinstance(openml._backend.task, TaskV2API): + warnings.warn( + "`download_splits` is not yet supported in the v2 API and will be ignored.", + stacklevel=2, ) - common_kwargs["estimation_parameters"] = estimation_parameters - common_kwargs["target_name"] = inputs["source_data"]["oml:data_set"]["oml:target_feature"] - common_kwargs["data_splits_url"] = inputs["estimation_procedure"][ - "oml:estimation_procedure" - ]["oml:data_splits_url"] - - cls = { - TaskType.SUPERVISED_CLASSIFICATION: OpenMLClassificationTask, - TaskType.SUPERVISED_REGRESSION: OpenMLRegressionTask, - TaskType.CLUSTERING: OpenMLClusteringTask, - TaskType.LEARNING_CURVE: OpenMLLearningCurveTask, - }.get(task_type) - if cls is None: - raise NotImplementedError( - f"Task type '{common_kwargs['task_type']}' is not supported. " - f"Supported task types: SUPERVISED_CLASSIFICATION," - f"SUPERVISED_REGRESSION, CLUSTERING, LEARNING_CURVE." - f"Please check the OpenML documentation for available task types." 
- ) - return cls(**common_kwargs) # type: ignore + return task -# TODO(eddiebergman): overload on `task_type` def create_task( task_type: TaskType, dataset_id: int, @@ -624,4 +278,4 @@ def delete_task(task_id: int) -> bool: bool True if the deletion was successful. False otherwise. """ - return openml.utils._delete_entity("task", task_id) + return openml._backend.task.delete(task_id) diff --git a/openml/tasks/task.py b/openml/tasks/task.py index b297a105c..1dbbe7595 100644 --- a/openml/tasks/task.py +++ b/openml/tasks/task.py @@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Any from typing_extensions import TypedDict -import openml._api_calls import openml.config from openml import datasets from openml.base import OpenMLBase @@ -172,10 +171,7 @@ def _download_split(self, cache_file: Path) -> None: pass except OSError: split_url = self.estimation_procedure["data_splits_url"] - openml._api_calls._download_text_file( - source=str(split_url), - output_path=str(cache_file), - ) + openml._backend.task.download(url=str(split_url), file_name="datasplits.arff") def download_split(self) -> OpenMLSplit: """Download the OpenML split for a given task.""" @@ -222,6 +218,46 @@ def _parse_publish_response(self, xml_response: dict) -> None: """Parse the id from the xml_response and assign it to self.""" self.task_id = int(xml_response["oml:upload_task"]["oml:id"]) + def publish(self) -> OpenMLTask: + """Publish this task to OpenML server. + + Returns + ------- + self : OpenMLTask + """ + file_elements = self._get_file_elements() + if "description" not in file_elements: + file_elements["description"] = self._to_xml() + task_id = openml._backend.task.publish(path="task", files=file_elements) + self.task_id = task_id + return self + + def push_tag(self, tag: str) -> None: + """Annotates this task with a tag on the server. + + Parameters + ---------- + tag : str + Tag to attach to the task. + """ + if self.task_id is None: + raise ValueError("Task does not have an ID. 
Please publish the task before tagging.") + openml._backend.task.tag(self.task_id, tag) + + def remove_tag(self, tag: str) -> None: + """Removes a tag from this task on the server. + + Parameters + ---------- + tag : str + Tag to remove from the task. + """ + if self.task_id is None: + raise ValueError( + "Task does not have an ID. Please publish the task before untagging." + ) + openml._backend.task.untag(self.task_id, tag) + class OpenMLSupervisedTask(OpenMLTask, ABC): """OpenML Supervised Classification object. diff --git a/openml/testing.py b/openml/testing.py index 8d3bbbd5b..d73e15a2d 100644 --- a/openml/testing.py +++ b/openml/testing.py @@ -15,6 +15,8 @@ import requests import openml +from openml._api import HTTPCache, HTTPClient, MinIOClient +from openml.enums import APIVersion, RetryPolicy from openml.exceptions import OpenMLServerException from openml.tasks import TaskType @@ -107,6 +109,7 @@ def setUp(self, n_levels: int = 1, tmpdir_suffix: str = "") -> None: self.retry_policy = openml.config.retry_policy self.connection_n_retries = openml.config.connection_n_retries openml.config.set_retry_policy("robot", n_retries=20) + openml.config._sync_api_config() def use_production_server(self) -> None: """ @@ -116,6 +119,7 @@ def use_production_server(self) -> None: """ openml.config.server = self.production_server openml.config.apikey = "" + openml.config._sync_api_config() def tearDown(self) -> None: """Tear down the test""" @@ -129,6 +133,7 @@ def tearDown(self) -> None: openml.config.connection_n_retries = self.connection_n_retries openml.config.retry_policy = self.retry_policy + openml.config._sync_api_config() @classmethod def _mark_entity_for_removal( @@ -276,6 +281,56 @@ def _check_fold_timing_evaluations( # noqa: PLR0913 assert evaluation <= max_val +class TestAPIBase(unittest.TestCase): + retries: int + retry_policy: RetryPolicy + ttl: int + cache_dir: Path + cache: HTTPCache + http_clients: dict[APIVersion, HTTPClient] + minio_client: 
MinIOClient + current_api_version: APIVersion | None + + def setUp(self) -> None: + config = openml._backend.get_config() + + self.retries = config.connection.retries + self.retry_policy = config.connection.retry_policy + self.ttl = config.cache.ttl + self.current_api_version = None + + abspath_this_file = Path(inspect.getfile(self.__class__)).absolute() + self.cache_dir = abspath_this_file.parent.parent / "files" + if not self.cache_dir.is_dir(): + raise ValueError( + f"Cannot find test cache dir, expected it to be {self.cache_dir}!", + ) + + self.cache = HTTPCache( + path=self.cache_dir, + ttl=self.ttl, + ) + self.http_clients = { + APIVersion.V1: HTTPClient( + server="https://test.openml.org/", + base_url="api/v1/xml/", + api_key="normaluser", + retries=self.retries, + retry_policy=self.retry_policy, + cache=self.cache, + ), + APIVersion.V2: HTTPClient( + server="http://localhost:8002/", + base_url="", + api_key="", + retries=self.retries, + retry_policy=self.retry_policy, + cache=self.cache, + ), + } + self.minio_client = MinIOClient(path=self.cache_dir) + + def check_task_existence( task_type: TaskType, dataset_id: int, diff --git a/tests/conftest.py b/tests/conftest.py index bd974f3f3..bcf93bd72 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -99,6 +99,7 @@ def delete_remote_files(tracker, flow_names) -> None: """ openml.config.server = TestBase.test_server openml.config.apikey = TestBase.user_key + openml.config._sync_api_config() # reordering to delete sub flows at the end of flows # sub-flows have shorter names, hence, sorting by descending order of flow name length @@ -275,10 +276,12 @@ def with_server(request): if "production" in request.keywords: openml.config.server = "https://www.openml.org/api/v1/xml" openml.config.apikey = None + openml.config._sync_api_config() yield return openml.config.server = "https://test.openml.org/api/v1/xml" openml.config.apikey = TestBase.user_key + openml.config._sync_api_config() yield diff --git 
a/tests/test_api/__init__.py b/tests/test_api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_api/test_http.py b/tests/test_api/test_http.py new file mode 100644 index 000000000..3c35ea5e1 --- /dev/null +++ b/tests/test_api/test_http.py @@ -0,0 +1,173 @@ +from requests import Response, Request +import time +import xmltodict +import pytest +from openml.testing import TestAPIBase +import os +from urllib.parse import urljoin +from openml.enums import APIVersion + + +class TestHTTPClient(TestAPIBase): + def setUp(self): + super().setUp() + self.http_client = self.http_clients[APIVersion.V1] + + def _prepare_url(self, path: str | None = None) -> str: + server = self.http_client.server + base_url = self.http_client.base_url + return urljoin(server, urljoin(base_url, path)) + + def test_cache(self): + url = self._prepare_url(path="task/31") + params = {"param1": "value1", "param2": "value2"} + + key = self.cache.get_key(url, params) + expected_key = os.path.join( + "org", + "openml", + "test", + "api", + "v1", + "xml", + "task", + "31", + "param1=value1¶m2=value2", + ) + + # validate key + self.assertEqual(key, expected_key) + + # create fake response + req = Request("GET", url).prepare() + response = Response() + response.status_code = 200 + response.url = url + response.reason = "OK" + response._content = b"test" + response.headers = {"Content-Type": "text/xml"} + response.encoding = "utf-8" + response.request = req + response.elapsed = type("Elapsed", (), {"total_seconds": lambda self: 0.1})() + + # save to cache + self.cache.save(key, response) + + # load from cache + cached_response = self.cache.load(key) + + # validate loaded response + self.assertEqual(cached_response.status_code, 200) + self.assertEqual(cached_response.url, url) + self.assertEqual(cached_response.content, b"test") + self.assertEqual( + cached_response.headers["Content-Type"], "text/xml" + ) + + @pytest.mark.uses_test_server() + def test_get(self): + response = 
self.http_client.get("task/1") + + self.assertEqual(response.status_code, 200) + self.assertIn(b" new request + self.assertNotEqual(response1_cache_time_stamp, response2_cache_time_stamp) + self.assertEqual(response2.status_code, 200) + self.assertEqual(response1.content, response2.content) + + @pytest.mark.uses_test_server() + def test_get_reset_cache(self): + path = "task/1" + + url = self._prepare_url(path=path) + key = self.cache.get_key(url, {}) + cache_path = self.cache._key_to_path(key) / "meta.json" + + response1 = self.http_client.get(path, use_cache=True) + response1_cache_time_stamp = cache_path.stat().st_ctime + + response2 = self.http_client.get(path, use_cache=True, reset_cache=True) + response2_cache_time_stamp = cache_path.stat().st_ctime + + self.assertNotEqual(response1_cache_time_stamp, response2_cache_time_stamp) + self.assertEqual(response2.status_code, 200) + self.assertEqual(response1.content, response2.content) + + @pytest.mark.uses_test_server() + def test_post_and_delete(self): + task_xml = """ + + 5 + 193 + 17 + + """ + + task_id = None + try: + # POST the task + post_response = self.http_client.post( + "task", + files={"description": task_xml}, + ) + self.assertEqual(post_response.status_code, 200) + xml_resp = xmltodict.parse(post_response.content) + task_id = int(xml_resp["oml:upload_task"]["oml:id"]) + + # GET the task to verify it exists + get_response = self.http_client.get(f"task/{task_id}") + self.assertEqual(get_response.status_code, 200) + + finally: + # DELETE the task if it was created + if task_id is not None: + del_response = self.http_client.delete(f"task/{task_id}") + self.assertEqual(del_response.status_code, 200) diff --git a/tests/test_api/test_tasks.py b/tests/test_api/test_tasks.py new file mode 100644 index 000000000..aad4644da --- /dev/null +++ b/tests/test_api/test_tasks.py @@ -0,0 +1,53 @@ +# License: BSD 3-Clause +from __future__ import annotations + +import pytest +import pandas as pd +from 
openml._api.resources.task import TaskV1API, TaskV2API +from openml.testing import TestAPIBase +from openml.tasks.task import TaskType +from openml.enums import APIVersion + +class TestTasksV1(TestAPIBase): + def setUp(self): + super().setUp() + self.resource = TaskV1API(self.http_client) + + @pytest.mark.uses_test_server() + def test_list_tasks(self): + """Verify V1 list endpoint returns a populated DataFrame.""" + tasks_df = self.resource.list(limit=5, offset=0) + assert isinstance(tasks_df, pd.DataFrame) + assert not tasks_df.empty + assert "tid" in tasks_df.columns + + @pytest.mark.uses_test_server() + def test_estimation_procedure_list(self): + """Verify that estimation procedure list endpoint works.""" + procs = self.resource._get_estimation_procedure_list() + assert isinstance(procs, list) + assert len(procs) > 0 + assert "id" in procs[0] + + +class TestTasksCombined(TestAPIBase): + def setUp(self): + super().setUp() + self.v1_resource = TaskV1API(self.http_clients[APIVersion.V1]) + self.v2_resource = TaskV2API(self.http_clients[APIVersion.V2]) + + def _get_first_tid(self, task_type: TaskType) -> int: + """Helper to find an existing task ID for a given type using the V1 resource.""" + tasks = self.v1_resource.list(limit=1, offset=0, task_type=task_type) + if tasks.empty: + pytest.skip(f"No tasks of type {task_type} found on test server.") + return int(tasks.iloc[0]["tid"]) + + @pytest.mark.uses_test_server() + def test_v2_get_task(self): + """Verify that we can get a task from V2 API using a task ID found via V1.""" + tid = self._get_first_tid(TaskType.SUPERVISED_CLASSIFICATION) + task_v1 = self.v1_resource.get(tid) + task_v2 = self.v2_resource.get(tid) + assert int(task_v1.task_id) == tid + assert int(task_v2.task_id) == tid \ No newline at end of file diff --git a/tests/test_api/test_versions.py b/tests/test_api/test_versions.py new file mode 100644 index 000000000..5fa9d624d --- /dev/null +++ b/tests/test_api/test_versions.py @@ -0,0 +1,88 @@ +from time 
import time +import pytest +from openml.testing import TestAPIBase +from openml._api import ResourceV1API, ResourceV2API, FallbackProxy +from openml.enums import ResourceType, APIVersion +from openml.exceptions import OpenMLNotSupportedError + + +@pytest.mark.uses_test_server() +class TestResourceAPIBase(TestAPIBase): + def _publish_and_delete(self): + task_xml = """ + + 5 + 193 + 17 + + """ + + task_id = self.resource.publish( + "task", + files={"description": task_xml}, + ) + self.assertIsNotNone(task_id) + + success = self.resource.delete(task_id) + self.assertTrue(success) + + def _tag_and_untag(self): + resource_id = 1 + unique_indicator = str(time()).replace(".", "") + tag = f"{self.__class__.__name__}_test_tag_and_untag_{unique_indicator}" + + tags = self.resource.tag(resource_id, tag) + self.assertIn(tag, tags) + + tags = self.resource.untag(resource_id, tag) + self.assertNotIn(tag, tags) + + +class TestResourceV1API(TestResourceAPIBase): + def setUp(self): + super().setUp() + http_client = self.http_clients[APIVersion.V1] + self.resource = ResourceV1API(http_client) + self.resource.resource_type = ResourceType.TASK + + def test_publish_and_delete(self): + self._publish_and_delete() + + def test_tag_and_untag(self): + self._tag_and_untag() + + +class TestResourceV2API(TestResourceAPIBase): + def setUp(self): + super().setUp() + http_client = self.http_clients[APIVersion.V2] + self.resource = ResourceV2API(http_client) + self.resource.resource_type = ResourceType.TASK + + def test_publish_and_delete(self): + with pytest.raises(OpenMLNotSupportedError): + self._publish_and_delete() + + def test_tag_and_untag(self): + with pytest.raises(OpenMLNotSupportedError): + self._tag_and_untag() + + +class TestResourceFallbackAPI(TestResourceAPIBase): + def setUp(self): + super().setUp() + http_client_v1 = self.http_clients[APIVersion.V1] + resource_v1 = ResourceV1API(http_client_v1) + resource_v1.resource_type = ResourceType.TASK + + http_client_v2 = 
self.http_clients[APIVersion.V2] + resource_v2 = ResourceV2API(http_client_v2) + resource_v2.resource_type = ResourceType.TASK + + self.resource = FallbackProxy(resource_v2, resource_v1) + + def test_publish_and_delete(self): + self._publish_and_delete() + + def test_tag_and_untag(self): + self._tag_and_untag() diff --git a/tests/test_datasets/test_dataset_functions.py b/tests/test_datasets/test_dataset_functions.py index c41664ba7..39a6c9cae 100644 --- a/tests/test_datasets/test_dataset_functions.py +++ b/tests/test_datasets/test_dataset_functions.py @@ -158,6 +158,7 @@ def test_check_datasets_active(self): [79], ) openml.config.server = self.test_server + openml.config._sync_api_config() @pytest.mark.uses_test_server() def test_illegal_character_tag(self): @@ -186,6 +187,7 @@ def test__name_to_id_with_deactivated(self): # /d/1 was deactivated assert openml.datasets.functions._name_to_id("anneal") == 2 openml.config.server = self.test_server + openml.config._sync_api_config() @pytest.mark.production() def test__name_to_id_with_multiple_active(self): @@ -438,6 +440,7 @@ def test__getarff_md5_issue(self): } n = openml.config.connection_n_retries openml.config.connection_n_retries = 1 + openml.config._sync_api_config() self.assertRaisesRegex( OpenMLHashException, @@ -448,6 +451,7 @@ def test__getarff_md5_issue(self): ) openml.config.connection_n_retries = n + openml.config._sync_api_config() @pytest.mark.uses_test_server() def test__get_dataset_features(self): @@ -617,6 +621,7 @@ def test_data_status(self): # admin key for test server (only admins can activate datasets. 
# all users can deactivate their own datasets) openml.config.apikey = TestBase.admin_key + openml.config._sync_api_config() openml.datasets.status_update(did, "active") self._assert_status_of_dataset(did=did, status="active") @@ -1555,6 +1560,7 @@ def test_list_datasets_with_high_size_parameter(self): # Reverting to test server openml.config.server = self.test_server + openml.config._sync_api_config() assert len(datasets_a) == len(datasets_b) diff --git a/tests/test_tasks/test_task_functions.py b/tests/test_tasks/test_task_functions.py index d44717177..b9ecb7310 100644 --- a/tests/test_tasks/test_task_functions.py +++ b/tests/test_tasks/test_task_functions.py @@ -3,16 +3,13 @@ import os import unittest -from typing import cast from unittest import mock -import pandas as pd import pytest -import requests import openml -from openml import OpenMLSplit, OpenMLTask -from openml.exceptions import OpenMLCacheException, OpenMLNotAuthorizedError, OpenMLServerException +from openml import OpenMLTask +from openml.exceptions import OpenMLNotAuthorizedError, OpenMLServerException from openml.tasks import TaskType from openml.testing import TestBase, create_request_response @@ -26,36 +23,6 @@ def setUp(self): def tearDown(self): super().tearDown() - @pytest.mark.uses_test_server() - def test__get_cached_tasks(self): - openml.config.set_root_cache_directory(self.static_cache_dir) - tasks = openml.tasks.functions._get_cached_tasks() - assert isinstance(tasks, dict) - assert len(tasks) == 3 - assert isinstance(next(iter(tasks.values())), OpenMLTask) - - @pytest.mark.uses_test_server() - def test__get_cached_task(self): - openml.config.set_root_cache_directory(self.static_cache_dir) - task = openml.tasks.functions._get_cached_task(1) - assert isinstance(task, OpenMLTask) - - def test__get_cached_task_not_cached(self): - openml.config.set_root_cache_directory(self.static_cache_dir) - self.assertRaisesRegex( - OpenMLCacheException, - "Task file for tid 2 not cached", - 
openml.tasks.functions._get_cached_task, - 2, - ) - - @pytest.mark.uses_test_server() - def test__get_estimation_procedure_list(self): - estimation_procedures = openml.tasks.functions._get_estimation_procedure_list() - assert isinstance(estimation_procedures, list) - assert isinstance(estimation_procedures[0], dict) - assert estimation_procedures[0]["task_type_id"] == TaskType.SUPERVISED_CLASSIFICATION - @pytest.mark.production() @pytest.mark.xfail(reason="failures_issue_1544", strict=False) def test_list_clustering_task(self): @@ -136,11 +103,6 @@ def test_list_tasks_per_type_paginate(self): assert j == task["ttid"] self._check_task(task) - @pytest.mark.uses_test_server() - def test__get_task(self): - openml.config.set_root_cache_directory(self.static_cache_dir) - openml.tasks.get_task(1882) - @unittest.skip( "Please await outcome of discussion: https://github.com/openml/OpenML/issues/776", ) @@ -151,20 +113,6 @@ def test__get_task_live(self): # https://github.com/openml/openml-python/issues/378 openml.tasks.get_task(34536) - @pytest.mark.uses_test_server() - def test_get_task(self): - task = openml.tasks.get_task(1, download_data=True) # anneal; crossvalidation - assert isinstance(task, OpenMLTask) - assert os.path.exists( - os.path.join(self.workdir, "org", "openml", "test", "tasks", "1", "task.xml") - ) - assert not os.path.exists( - os.path.join(self.workdir, "org", "openml", "test", "tasks", "1", "datasplits.arff") - ) - assert os.path.exists( - os.path.join(self.workdir, "org", "openml", "test", "datasets", "1", "dataset.arff") - ) - @pytest.mark.uses_test_server() def test_get_task_lazy(self): task = openml.tasks.get_task(2, download_data=False) # anneal; crossvalidation @@ -187,104 +135,37 @@ def test_get_task_lazy(self): os.path.join(self.workdir, "org", "openml", "test", "tasks", "2", "datasplits.arff") ) - @mock.patch("openml.tasks.functions.get_dataset") - @pytest.mark.uses_test_server() - def test_removal_upon_download_failure(self, get_dataset): - 
class WeirdException(Exception): - pass - - def assert_and_raise(*args, **kwargs): - # Make sure that the file was created! - assert os.path.join(os.getcwd(), "tasks", "1", "tasks.xml") - raise WeirdException() - - get_dataset.side_effect = assert_and_raise - try: - openml.tasks.get_task(1) # anneal; crossvalidation - except WeirdException: - pass - # Now the file should no longer exist - assert not os.path.exists(os.path.join(os.getcwd(), "tasks", "1", "tasks.xml")) - - @pytest.mark.uses_test_server() - def test_get_task_with_cache(self): - openml.config.set_root_cache_directory(self.static_cache_dir) - task = openml.tasks.get_task(1) - assert isinstance(task, OpenMLTask) - - @pytest.mark.production() - def test_get_task_different_types(self): - self.use_production_server() - # Regression task - openml.tasks.functions.get_task(5001) - # Learning curve - openml.tasks.functions.get_task(64) - # Issue 538, get_task failing with clustering task. - openml.tasks.functions.get_task(126033) - - @pytest.mark.uses_test_server() - def test_download_split(self): - task = openml.tasks.get_task(1) # anneal; crossvalidation - split = task.download_split() - assert type(split) == OpenMLSplit - assert os.path.exists( - os.path.join(self.workdir, "org", "openml", "test", "tasks", "1", "datasplits.arff") - ) - - def test_deletion_of_cache_dir(self): - # Simple removal - tid_cache_dir = openml.utils._create_cache_directory_for_id( - "tasks", - 1, - ) - assert os.path.exists(tid_cache_dir) - openml.utils._remove_cache_dir_for_id("tasks", tid_cache_dir) - assert not os.path.exists(tid_cache_dir) - - -@mock.patch.object(requests.Session, "delete") -def test_delete_task_not_owned(mock_delete, test_files_directory, test_api_key): +@mock.patch("openml._api.clients.http.HTTPClient.delete") +def test_delete_task_not_owned(mock_delete): openml.config.start_using_configuration_for_example() - content_file = test_files_directory / "mock_responses" / "tasks" / "task_delete_not_owned.xml" - 
mock_delete.return_value = create_request_response( - status_code=412, - content_filepath=content_file, + mock_delete.side_effect = OpenMLNotAuthorizedError( + "The task can not be deleted because it was not uploaded by you." ) - with pytest.raises( OpenMLNotAuthorizedError, match="The task can not be deleted because it was not uploaded by you.", ): openml.tasks.delete_task(1) - task_url = "https://test.openml.org/api/v1/xml/task/1" + task_url = "task/1" assert task_url == mock_delete.call_args.args[0] - assert test_api_key == mock_delete.call_args.kwargs.get("params", {}).get("api_key") - -@mock.patch.object(requests.Session, "delete") -def test_delete_task_with_run(mock_delete, test_files_directory, test_api_key): +@mock.patch("openml._api.clients.http.HTTPClient.delete") +def test_delete_task_with_run(mock_delete): openml.config.start_using_configuration_for_example() - content_file = test_files_directory / "mock_responses" / "tasks" / "task_delete_has_runs.xml" - mock_delete.return_value = create_request_response( - status_code=412, - content_filepath=content_file, - ) + mock_delete.side_effect = OpenMLServerException("Task does not exist") with pytest.raises( - OpenMLNotAuthorizedError, - match="The task can not be deleted because it still has associated entities:", + OpenMLServerException, + match="Task does not exist", ): openml.tasks.delete_task(3496) - task_url = "https://test.openml.org/api/v1/xml/task/3496" + task_url = "task/3496" assert task_url == mock_delete.call_args.args[0] - assert test_api_key == mock_delete.call_args.kwargs.get("params", {}).get("api_key") - -@mock.patch.object(requests.Session, "delete") -def test_delete_success(mock_delete, test_files_directory, test_api_key): - openml.config.start_using_configuration_for_example() +@mock.patch("openml._api.clients.http.HTTPClient.delete") +def test_delete_success(mock_delete, test_files_directory): content_file = test_files_directory / "mock_responses" / "tasks" / "task_delete_successful.xml" 
mock_delete.return_value = create_request_response( status_code=200, @@ -294,26 +175,14 @@ def test_delete_success(mock_delete, test_files_directory, test_api_key): success = openml.tasks.delete_task(361323) assert success - task_url = "https://test.openml.org/api/v1/xml/task/361323" + task_url = "task/361323" assert task_url == mock_delete.call_args.args[0] - assert test_api_key == mock_delete.call_args.kwargs.get("params", {}).get("api_key") - -@mock.patch.object(requests.Session, "delete") -def test_delete_unknown_task(mock_delete, test_files_directory, test_api_key): - openml.config.start_using_configuration_for_example() - content_file = test_files_directory / "mock_responses" / "tasks" / "task_delete_not_exist.xml" - mock_delete.return_value = create_request_response( - status_code=412, - content_filepath=content_file, - ) - - with pytest.raises( - OpenMLServerException, - match="Task does not exist", - ): +@mock.patch("openml._api.clients.http.HTTPClient.delete") +def test_delete_unknown_task(mock_delete): + mock_delete.side_effect = OpenMLServerException("Task does not exist") + with pytest.raises(OpenMLServerException, match="Task does not exist"): openml.tasks.delete_task(9_999_999) - task_url = "https://test.openml.org/api/v1/xml/task/9999999" - assert task_url == mock_delete.call_args.args[0] - assert test_api_key == mock_delete.call_args.kwargs.get("params", {}).get("api_key") + task_url = "task/9999999" + assert task_url == mock_delete.call_args.args[0] \ No newline at end of file