feat(llmobs): add async task support #16311
base: main
Changes from all commits: 271c9c9, b4867ac, befe628, 960811d, 5248919, fd0ce01, 7dcbfda, 6e00b73
@@ -1,5 +1,6 @@
from abc import ABC
from abc import abstractmethod
import asyncio
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from dataclasses import dataclass

@@ -350,10 +351,10 @@ class EvaluationResult(TypedDict):

class _ExperimentRunInfo:
    def __init__(self, run_interation: int):
    def __init__(self, run_iteration: int):
        self._id = uuid.uuid4()
        # always increment the representation of iteration by 1 for readability
        self._run_iteration = run_interation + 1
        self._run_iteration = run_iteration + 1


class ExperimentRowResult(TypedDict):
@@ -675,12 +676,54 @@ def run( | |||||||||
| ) -> ExperimentResult: | ||||||||||
| """Run the experiment by executing the task on all dataset records and evaluating the results. | ||||||||||
|
|
||||||||||
| For async tasks, use arun() instead if you are already in an async context. | ||||||||||
|
|
||||||||||
| :param jobs: Maximum number of concurrent task and evaluator executions (default: 1) | ||||||||||
| :param raise_errors: Whether to raise exceptions on task or evaluator errors (default: False) | ||||||||||
| :param sample_size: Optional number of dataset records to sample for testing | ||||||||||
| (default: None, uses full dataset) | ||||||||||
| :return: ExperimentResult containing evaluation results and metadata | ||||||||||
| """ | ||||||||||
| self._setup_experiment(jobs) | ||||||||||
| run_results = [] | ||||||||||
| for run_iteration in range(self._runs): | ||||||||||
| run = _ExperimentRunInfo(run_iteration) | ||||||||||
| task_results = self._run_task(jobs, run, raise_errors, sample_size) | ||||||||||
| run_result = self._process_run_results(task_results, run, raise_errors, jobs) | ||||||||||
| run_results.append(run_result) | ||||||||||
| return self._build_experiment_result(run_results) | ||||||||||
|
|
||||||||||
| async def arun( | ||||||||||
| self, | ||||||||||
| jobs: int = 1, | ||||||||||
| raise_errors: bool = False, | ||||||||||
| sample_size: Optional[int] = None, | ||||||||||
| ) -> ExperimentResult: | ||||||||||
| """Async version of run() for use with async task functions. | ||||||||||
|
|
||||||||||
| Use this method when calling from an async context. | ||||||||||
|
|
||||||||||
| :param jobs: Maximum number of concurrent task executions (default: 1) | ||||||||||
| :param raise_errors: Whether to raise exceptions on task or evaluator errors (default: False) | ||||||||||
| :param sample_size: Optional number of dataset records to sample for testing | ||||||||||
| (default: None, uses full dataset) | ||||||||||
| :return: ExperimentResult containing evaluation results and metadata | ||||||||||
| """ | ||||||||||
| self._setup_experiment(jobs) | ||||||||||
|
|
||||||||||
| semaphore = asyncio.Semaphore(jobs) | ||||||||||
|
|
||||||||||
| async def _run_single_iteration(run_iteration: int) -> ExperimentRun: | ||||||||||
| run = _ExperimentRunInfo(run_iteration) | ||||||||||
| task_results = await self._run_task_async(semaphore, run, raise_errors, sample_size) | ||||||||||
| return self._process_run_results(task_results, run, raise_errors, jobs) | ||||||||||
|
|
||||||||||
| run_results = await asyncio.gather(*[_run_single_iteration(i) for i in range(self._runs)]) | ||||||||||
|
|
||||||||||
| return self._build_experiment_result(list(run_results)) | ||||||||||
|
|
||||||||||
| def _setup_experiment(self, jobs: int) -> None: | ||||||||||
| """Validate inputs and set up experiment state.""" | ||||||||||
| if jobs < 1: | ||||||||||
| raise ValueError("jobs must be at least 1") | ||||||||||
|
|
||||||||||
|
|
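As a rough usage sketch of the new entry point (the experiment construction and return shape below are assumptions based on this diff, not the documented API): arun() is awaited from an async context, and jobs caps how many task executions run concurrently.

import asyncio

class _FakeExperiment:
    # Stand-in for an Experiment built with an async task via the experiments API;
    # illustrative only, not the real constructor.
    async def arun(self, jobs=1, raise_errors=False, sample_size=None):
        await asyncio.sleep(0)  # pretend to run the dataset records concurrently
        return {"summary_evaluations": {}, "rows": [], "runs": []}

async def main() -> None:
    experiment = _FakeExperiment()
    result = await experiment.arun(jobs=4, raise_errors=True)
    print(len(result["runs"]))  # one entry per run iteration

asyncio.run(main())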
@@ -710,29 +753,41 @@ def run( | |||||||||
| self._id = experiment_id | ||||||||||
| self._tags["experiment_id"] = str(experiment_id) | ||||||||||
| self._run_name = experiment_run_name | ||||||||||
| run_results = [] | ||||||||||
| # for backwards compatibility | ||||||||||
| for run_iteration in range(self._runs): | ||||||||||
| run = _ExperimentRunInfo(run_iteration) | ||||||||||
| self._tags["run_id"] = str(run._id) | ||||||||||
| self._tags["run_iteration"] = str(run._run_iteration) | ||||||||||
| task_results = self._run_task(jobs, run, raise_errors, sample_size) | ||||||||||
| evaluations = self._run_evaluators(task_results, raise_errors=raise_errors, jobs=jobs) | ||||||||||
| summary_evals = self._run_summary_evaluators(task_results, evaluations, raise_errors, jobs=jobs) | ||||||||||
| run_result = self._merge_results(run, task_results, evaluations, summary_evals) | ||||||||||
| experiment_evals = self._generate_metrics_from_exp_results(run_result) | ||||||||||
|
|
||||||||||
| def _process_run_results( | ||||||||||
| self, | ||||||||||
| task_results: List[TaskResult], | ||||||||||
| run: _ExperimentRunInfo, | ||||||||||
| raise_errors: bool, | ||||||||||
| jobs: int, | ||||||||||
| ) -> ExperimentRun: | ||||||||||
| """Run evaluators, merge results, and post metrics.""" | ||||||||||
| evaluations = self._run_evaluators(task_results, raise_errors=raise_errors, jobs=jobs) | ||||||||||
| summary_evals = self._run_summary_evaluators(task_results, evaluations, raise_errors, jobs=jobs) | ||||||||||
| run_result = self._merge_results(run, task_results, evaluations, summary_evals) | ||||||||||
| experiment_evals = self._generate_metrics_from_exp_results(run_result) | ||||||||||
| if self._llmobs_instance and self._id is not None: | ||||||||||
| self._llmobs_instance._dne_client.experiment_eval_post( | ||||||||||
| self._id, experiment_evals, convert_tags_dict_to_list(self._tags) | ||||||||||
| self._id, experiment_evals, convert_tags_dict_to_list(self._get_run_tags(run)) | ||||||||||
| ) | ||||||||||
| run_results.append(run_result) | ||||||||||
| return run_result | ||||||||||
|
|
||||||||||
| experiment_result: ExperimentResult = { | ||||||||||
| def _build_experiment_result(self, run_results: List[ExperimentRun]) -> ExperimentResult: | ||||||||||
| """Build the final experiment result from run results.""" | ||||||||||
| return { | ||||||||||
| # for backwards compatibility, the first result fills the old fields of rows and summary evals | ||||||||||
| "summary_evaluations": run_results[0].summary_evaluations if len(run_results) > 0 else {}, | ||||||||||
| "rows": run_results[0].rows if len(run_results) > 0 else [], | ||||||||||
| "runs": run_results, | ||||||||||
| } | ||||||||||
| return experiment_result | ||||||||||
|
|
||||||||||
| def _get_run_tags(self, run: _ExperimentRunInfo) -> Dict[str, str]: | ||||||||||
| """Get tags for a specific run, merging experiment-level tags with run-specific values.""" | ||||||||||
| return { | ||||||||||
| **self._tags, | ||||||||||
| "run_id": str(run._id), | ||||||||||
| "run_iteration": str(run._run_iteration), | ||||||||||
| } | ||||||||||
|
|
||||||||||
| @property | ||||||||||
| def url(self) -> str: | ||||||||||
|
|
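A quick illustration (hypothetical values, not from this diff) of why _get_run_tags merges in that order: dict unpacking applies keys left to right, so the run-specific run_id and run_iteration always override anything already present in the experiment-level tags.

# Later keys win on collision, so each run gets its own identifiers.
experiment_tags = {"experiment_id": "exp-123", "run_id": "stale-value"}
run_tags = {**experiment_tags, "run_id": "run-456", "run_iteration": "1"}
assert run_tags == {"experiment_id": "exp-123", "run_id": "run-456", "run_iteration": "1"}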
@@ -762,7 +817,7 @@ def _process_record(self, idx_record: Tuple[int, DatasetRecord], run: _Experimen | |||||||||
| input_data = record["input_data"] | ||||||||||
| record_id = record.get("record_id", "") | ||||||||||
| tags = { | ||||||||||
| **self._tags, | ||||||||||
| **self._get_run_tags(run), | ||||||||||
| "dataset_id": str(self._dataset._id), | ||||||||||
| "dataset_record_id": str(record_id), | ||||||||||
| "experiment_id": str(self._id), | ||||||||||
|
|
@@ -796,19 +851,72 @@ def _process_record(self, idx_record: Tuple[int, DatasetRecord], run: _Experimen | |||||||||
| }, | ||||||||||
| } | ||||||||||
|
|
||||||||||
| def _run_task( | ||||||||||
| self, | ||||||||||
| jobs: int, | ||||||||||
| run: _ExperimentRunInfo, | ||||||||||
| raise_errors: bool = False, | ||||||||||
| sample_size: Optional[int] = None, | ||||||||||
| ) -> List[TaskResult]: | ||||||||||
| async def _process_record_async( | ||||||||||
| self, idx_record: Tuple[int, DatasetRecord], run: _ExperimentRunInfo | ||||||||||
| ) -> Optional[TaskResult]: | ||||||||||
| """Async version of _process_record that awaits async tasks.""" | ||||||||||
| if not self._llmobs_instance or not self._llmobs_instance.enabled: | ||||||||||
| return [] | ||||||||||
| return None | ||||||||||
| idx, record = idx_record | ||||||||||
| with self._llmobs_instance._experiment( | ||||||||||
| name=self._task.__name__, | ||||||||||
| experiment_id=self._id, | ||||||||||
| run_id=str(run._id), | ||||||||||
| run_iteration=run._run_iteration, | ||||||||||
| dataset_name=self._dataset.name, | ||||||||||
| project_name=self._project_name, | ||||||||||
| project_id=self._project_id, | ||||||||||
| experiment_name=self.name, | ||||||||||
| ) as span: | ||||||||||
| span_context = self._llmobs_instance.export_span(span=span) | ||||||||||
| if span_context: | ||||||||||
| span_id = span_context.get("span_id", "") | ||||||||||
| trace_id = span_context.get("trace_id", "") | ||||||||||
| else: | ||||||||||
| span_id, trace_id = "", "" | ||||||||||
| input_data = record["input_data"] | ||||||||||
| record_id = record.get("record_id", "") | ||||||||||
| tags = { | ||||||||||
| **self._get_run_tags(run), | ||||||||||
| "dataset_id": str(self._dataset._id), | ||||||||||
| "dataset_record_id": str(record_id), | ||||||||||
| "experiment_id": str(self._id), | ||||||||||
| } | ||||||||||
| output_data = None | ||||||||||
| try: | ||||||||||
| output_data = await self._task(input_data, self._config) # type: ignore[misc] | ||||||||||
|
Contributor: This doesn't look correct to me. I was wondering how this wasn't showing up in the tests' type checking, but the fact they receive …

I think this absolutely needs to be changed.

Contributor: Yes, I very much agree; I just realized this last night. The type hint for task does not fit the async case, and we have a directive for mypy to ignore that line in our code, but I think this should be handled differently. I was trying to get mypy to show errors in a test script I have where I use a local editable install of ddtrace -- it seems mypy is actually surprisingly permissive about allowing an async function with a similar signature where the type is Callable, but I'd rather have the correct type.

Contributor: I was thinking of either having a separate async_task argument to the Experiment init, so the idea is the user would pass only one of task and async_task. Another thought is that the task arg could accept either a sync or async task, using inspect.iscoroutinefunction() to determine if it's async. In either case, the user gets a friendly error if they call run() / arun() when it doesn't match the task type.

Contributor: Do you know what other APIs (e.g. Braintrust) do? Maybe there's a good pattern to re-use there. To be honest, I don't have a strong preference, and it's an issue I've encountered in the past (without a clear-cut solution).

I remember seeing someone use that at some point, and although it works, it gets somewhat hard to properly type the code when relying on this pattern.

I think this is OK, but again, typing this properly is somewhat complicated; I guess you'd have to use …
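Neither option is implemented in this diff; as a rough sketch of the second idea (names and signatures here are assumptions), the detection and the friendly mismatch error could look like this:

import inspect
from typing import Any, Awaitable, Callable, Union

SyncTask = Callable[[Any, Any], Any]
AsyncTask = Callable[[Any, Any], Awaitable[Any]]

def check_task_matches_entrypoint(task: Union[SyncTask, AsyncTask], calling_arun: bool) -> None:
    """Raise a friendly error when the task type does not match run()/arun()."""
    is_async = inspect.iscoroutinefunction(task)
    if is_async and not calling_arun:
        raise TypeError("task is an async function; call arun() instead of run()")
    if not is_async and calling_arun:
        raise TypeError("task is a sync function; call run() instead of arun()")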

            except Exception:
                span.set_exc_info(*sys.exc_info())
            self._llmobs_instance.annotate(span, input_data=input_data, output_data=output_data, tags=tags)

            span._set_ctx_item(EXPERIMENT_EXPECTED_OUTPUT, record["expected_output"])
            if "metadata" in record:
                span._set_ctx_item(EXPERIMENT_RECORD_METADATA, record["metadata"])

            return {
                "idx": idx,
                "span_id": span_id,
                "trace_id": trace_id,
                "timestamp": span.start_ns,
                "output": output_data,
                "metadata": {
                    "dataset_record_index": idx,
                    "experiment_name": self.name,
                    "dataset_name": self._dataset.name,
                },
                "error": {
                    "message": span.get_tag(ERROR_MSG),
                    "stack": span.get_tag(ERROR_STACK),
                    "type": span.get_tag(ERROR_TYPE),
                },
            }

    def _get_subset_dataset(self, sample_size: Optional[int]) -> Dataset:
        """Get dataset or a subset for sampling."""
        if sample_size is not None and sample_size < len(self._dataset):
            subset_records = [deepcopy(record) for record in self._dataset._records[:sample_size]]
            subset_name = "[Test subset of {} records] {}".format(sample_size, self._dataset.name)
            subset_dataset = Dataset(
            return Dataset(
                name=subset_name,
                project=self._dataset.project,
                dataset_id=self._dataset._id,

@@ -818,8 +926,28 @@ def _run_task(
                version=self._dataset._version,
                _dne_client=self._dataset._dne_client,
            )
        else:
            subset_dataset = self._dataset
        return self._dataset

    def _check_task_error(self, result: TaskResult, raise_errors: bool) -> None:
        """Check for task errors and raise if configured."""
        err_dict = result.get("error") or {}
Contributor: I'd recommend a type hint here.
        if isinstance(err_dict, dict):
            err_msg = err_dict.get("message")
            err_stack = err_dict.get("stack")
            err_type = err_dict.get("type")
            if raise_errors and err_msg:
                raise RuntimeError("Error on record {}: {}\n{}\n{}".format(result["idx"], err_msg, err_type, err_stack))
Contributor: Any reason not to use an f-string here?
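One possible follow-up combining both suggestions above (a type hint on the error dict and an f-string); this is a sketch of how it could look as a free function, not necessarily what the PR will adopt:

from typing import Any, Dict

def check_task_error(result: Dict[str, Any], raise_errors: bool) -> None:
    # `result` stands in for a TaskResult dict; the real method lives on the class.
    err_dict: Dict[str, Any] = result.get("error") or {}
    if not isinstance(err_dict, dict):
        return
    err_msg = err_dict.get("message")
    if raise_errors and err_msg:
        raise RuntimeError(
            f"Error on record {result['idx']}: {err_msg}\n"
            f"{err_dict.get('type')}\n{err_dict.get('stack')}"
        )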

    def _run_task(
        self,
        jobs: int,
        run: _ExperimentRunInfo,
        raise_errors: bool = False,
        sample_size: Optional[int] = None,
    ) -> List[TaskResult]:
        if not self._llmobs_instance or not self._llmobs_instance.enabled:
            return []
        subset_dataset = self._get_subset_dataset(sample_size)
        task_results = []
        with ThreadPoolExecutor(max_workers=jobs) as executor:
            for result in executor.map(

@@ -830,18 +958,40 @@ def _run_task(
                if not result:
                    continue
                task_results.append(result)
                err_dict = result.get("error") or {}
                if isinstance(err_dict, dict):
                    err_msg = err_dict.get("message")
                    err_stack = err_dict.get("stack")
                    err_type = err_dict.get("type")
                    if raise_errors and err_msg:
                        raise RuntimeError(
                            "Error on record {}: {}\n{}\n{}".format(result["idx"], err_msg, err_type, err_stack)
                        )
                self._check_task_error(result, raise_errors)
        self._llmobs_instance.flush()  # Ensure spans get submitted in serverless environments
        return task_results

    async def _run_task_async(
        self,
        semaphore: asyncio.Semaphore,
        run: _ExperimentRunInfo,
        raise_errors: bool = False,
        sample_size: Optional[int] = None,
    ) -> List[TaskResult]:
        """Async version of _run_task for async task functions."""
        if not self._llmobs_instance or not self._llmobs_instance.enabled:
            return []
        subset_dataset = self._get_subset_dataset(sample_size)

        async def process_with_limit(idx: int, record: DatasetRecord) -> Optional[TaskResult]:
            async with semaphore:
                return await self._process_record_async((idx, record), run)

        tasks = [process_with_limit(idx, record) for idx, record in enumerate(subset_dataset)]
        results = await asyncio.gather(*tasks)
Comment on lines +981 to +982

Contributor (suggested change): Are those really Tasks? I think they're just coroutines.

Contributor: Good call. Claude used the name tasks, but in the past I would usually have used the name coros; I think I've seen that in a lot of examples of async code.
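For reference, a minimal self-contained sketch of the pattern under discussion (illustrative names only): the list passed to asyncio.gather() holds bare coroutine objects, commonly named coros, and the semaphore bounds concurrency much like the jobs argument above.

import asyncio

async def process(idx: int, semaphore: asyncio.Semaphore) -> int:
    async with semaphore:
        await asyncio.sleep(0.01)  # stand-in for the real async task call
        return idx

async def main() -> None:
    semaphore = asyncio.Semaphore(4)  # roughly the `jobs` limit
    coros = [process(i, semaphore) for i in range(10)]  # coroutines, not asyncio.Task objects
    results = await asyncio.gather(*coros)  # gather wraps them in Tasks internally
    print(results)

asyncio.run(main())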

        task_results: List[TaskResult] = []
        for result in results:
            if not result:
                continue
            task_results.append(result)
            self._check_task_error(result, raise_errors)

        if self._llmobs_instance:
            self._llmobs_instance.flush()
        return task_results

    def _run_evaluators(
        self, task_results: List[TaskResult], raise_errors: bool = False, jobs: int = 1
    ) -> List[EvaluationResult]:

@@ -1025,7 +1175,9 @@ def _merge_results(
        experiment_results = []
        for idx, task_result in enumerate(task_results):
            output_data = task_result["output"]
            metadata: Dict[str, JSONType] = {"tags": cast(List[JSONType], convert_tags_dict_to_list(self._tags))}
            metadata: Dict[str, JSONType] = {
                "tags": cast(List[JSONType], convert_tags_dict_to_list(self._get_run_tags(run)))
            }
            metadata.update(task_result.get("metadata") or {})
            record: DatasetRecord = self._dataset[idx]
            evals = evaluations[idx]["evaluations"]
@@ -0,0 +1,5 @@
---
features:
  - |
    LLM Observability: Adds async task support for experiments. Use the ``arun()`` method to run experiments
    with async task functions in an async context.
Can you add the type hint maybe?
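Per the release note above, an async task is simply an async function with the same (input_data, config) call shape that the diff awaits via ``await self._task(input_data, self._config)``; the exact expected signature and return type in this sketch are assumptions.

import asyncio
from typing import Any, Dict, Optional

async def summarize_task(input_data: Dict[str, Any], config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    # Replace the sleep with a real awaitable call (e.g. an async LLM client request).
    await asyncio.sleep(0.01)
    return {"summary": str(input_data)[:100]}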