From 2bff4e81ab51b7017ab3b6c8d8e09118eb392432 Mon Sep 17 00:00:00 2001 From: Carson Date: Wed, 29 Oct 2025 10:00:13 -0500 Subject: [PATCH 1/7] Introduce .value_tokens() provider method --- chatlas/_chat.py | 19 +++++++---- chatlas/_provider.py | 6 ++++ chatlas/_provider_anthropic.py | 27 +++++++--------- chatlas/_provider_google.py | 28 ++++++++-------- chatlas/_provider_openai.py | 59 ++++++++++++++++++---------------- chatlas/_provider_snowflake.py | 27 ++++++++++------ 6 files changed, 93 insertions(+), 73 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 2d853e7a..71cc2070 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -47,7 +47,7 @@ from ._logging import log_tool_error from ._mcp_manager import MCPSessionManager from ._provider import ModelInfo, Provider, StandardModelParams, SubmitInputArgsT -from ._tokens import compute_cost, get_token_pricing +from ._tokens import compute_cost, get_token_pricing, tokens_log from ._tools import Tool, ToolRejectError from ._turn import Turn, user_turn from ._typing_extensions import TypedDict, TypeGuard @@ -2537,12 +2537,11 @@ def emit(text: str | Content): result, has_data_model=data_model is not None, ) - if echo == "all": emit_other_contents(turn, emit) else: - response = self.provider.chat_perform( + result = self.provider.chat_perform( stream=False, turns=[*self._turns, user_turn], tools=self._tools, @@ -2551,7 +2550,7 @@ def emit(text: str | Content): ) turn = self.provider.value_turn( - response, has_data_model=data_model is not None + result, has_data_model=data_model is not None ) if turn.text: emit(turn.text) @@ -2560,6 +2559,10 @@ def emit(text: str | Content): if echo == "all": emit_other_contents(turn, emit) + if turn.tokens is None: + turn.tokens = self.provider.value_tokens(result) + if turn.tokens is not None: + tokens_log(self.provider, turn.tokens) self._turns.extend([user_turn, turn]) async def _submit_turns_async( @@ -2604,7 +2607,7 @@ def emit(text: str | Content): emit_other_contents(turn, emit) else: - response = await self.provider.chat_perform_async( + result = await self.provider.chat_perform_async( stream=False, turns=[*self._turns, user_turn], tools=self._tools, @@ -2613,7 +2616,7 @@ def emit(text: str | Content): ) turn = self.provider.value_turn( - response, has_data_model=data_model is not None + result, has_data_model=data_model is not None ) if turn.text: emit(turn.text) @@ -2622,6 +2625,10 @@ def emit(text: str | Content): if echo == "all": emit_other_contents(turn, emit) + if turn.tokens is None: + turn.tokens = self.provider.value_tokens(result) + if turn.tokens is not None: + tokens_log(self.provider, turn.tokens) self._turns.extend([user_turn, turn]) def _invoke_tool(self, request: ContentToolRequest): diff --git a/chatlas/_provider.py b/chatlas/_provider.py index 7e216ec2..4c8a0d68 100644 --- a/chatlas/_provider.py +++ b/chatlas/_provider.py @@ -249,6 +249,12 @@ def value_turn( has_data_model: bool, ) -> Turn: ... + @abstractmethod + def value_tokens( + self, + completion: ChatCompletionDictT, + ) -> tuple[int, int, int] | None: ... + @abstractmethod def token_count( self, diff --git a/chatlas/_provider_anthropic.py b/chatlas/_provider_anthropic.py index 70f01666..4f7cbb09 100644 --- a/chatlas/_provider_anthropic.py +++ b/chatlas/_provider_anthropic.py @@ -30,7 +30,7 @@ StandardModelParamNames, StandardModelParams, ) -from ._tokens import get_token_pricing, tokens_log +from ._tokens import get_token_pricing from ._tools import Tool, basemodel_to_param_schema from ._turn import Turn, user_turn from ._utils import split_http_client_kwargs @@ -411,6 +411,17 @@ def stream_turn(self, completion, has_data_model) -> Turn: def value_turn(self, completion, has_data_model) -> Turn: return self._as_turn(completion, has_data_model) + def value_tokens(self, completion): + usage = completion.usage + # N.B. Currently, Anthropic doesn't cache by default and we currently do not support + # manual caching in chatlas. Note also that this only tracks reads, NOT writes, which + # have their own cost. To track that properly, we would need another caching category and per-token cost. + return ( + completion.usage.input_tokens, + completion.usage.output_tokens, + usage.cache_read_input_tokens if usage.cache_read_input_tokens else 0, + ) + def token_count( self, *args: Content | str, @@ -619,23 +630,9 @@ def _as_turn(self, completion: Message, has_data_model=False) -> Turn: ) ) - usage = completion.usage - # N.B. Currently, Anthropic doesn't cache by default and we currently do not support - # manual caching in chatlas. Note also that this only tracks reads, NOT writes, which - # have their own cost. To track that properly, we would need another caching category and per-token cost. - - tokens = ( - completion.usage.input_tokens, - completion.usage.output_tokens, - usage.cache_read_input_tokens if usage.cache_read_input_tokens else 0, - ) - - tokens_log(self, tokens) - return Turn( "assistant", contents, - tokens=tokens, finish_reason=completion.stop_reason, completion=completion, ) diff --git a/chatlas/_provider_google.py b/chatlas/_provider_google.py index 9cf7f752..1c371848 100644 --- a/chatlas/_provider_google.py +++ b/chatlas/_provider_google.py @@ -22,7 +22,7 @@ from ._logging import log_model_default from ._merge import merge_dicts from ._provider import ModelInfo, Provider, StandardModelParamNames, StandardModelParams -from ._tokens import get_token_pricing, tokens_log +from ._tokens import get_token_pricing from ._tools import Tool from ._turn import Turn, user_turn @@ -228,6 +228,7 @@ def chat_perform( def chat_perform( self, + *, stream: bool, turns: list[Turn], tools: dict[str, Tool], @@ -264,6 +265,7 @@ async def chat_perform_async( async def chat_perform_async( self, + *, stream: bool, turns: list[Turn], tools: dict[str, Tool], @@ -349,6 +351,17 @@ def value_turn(self, completion, has_data_model) -> Turn: completion = cast("GenerateContentResponseDict", completion.model_dump()) return self._as_turn(completion, has_data_model) + def value_tokens(self, completion): + usage = completion.get("usage_metadata") + if usage is None: + return None + cached = usage.get("cached_content_token_count") or 0 + return ( + (usage.get("prompt_token_count") or 0) - cached, + usage.get("candidates_token_count") or 0, + usage.get("cached_content_token_count") or 0, + ) + def token_count( self, *args: Content | str, @@ -528,25 +541,12 @@ def _as_turn( ) ) - usage = message.get("usage_metadata") - tokens = (0, 0, 0) - if usage: - cached = usage.get("cached_content_token_count") or 0 - tokens = ( - (usage.get("prompt_token_count") or 0) - cached, - usage.get("candidates_token_count") or 0, - usage.get("cached_content_token_count") or 0, - ) - - tokens_log(self, tokens) - if isinstance(finish_reason, FinishReason): finish_reason = finish_reason.name return Turn( "assistant", contents, - tokens=tokens, finish_reason=finish_reason, completion=message, ) diff --git a/chatlas/_provider_openai.py b/chatlas/_provider_openai.py index 87415175..bbbfef7e 100644 --- a/chatlas/_provider_openai.py +++ b/chatlas/_provider_openai.py @@ -13,6 +13,7 @@ from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI from openai.types.batch import Batch from openai.types.chat import ChatCompletion, ChatCompletionChunk +from openai.types.completion_usage import CompletionUsage from pydantic import BaseModel from ._chat import Chat @@ -38,7 +39,7 @@ StandardModelParamNames, StandardModelParams, ) -from ._tokens import get_token_pricing, tokens_log +from ._tokens import get_token_pricing from ._tools import Tool, basemodel_to_param_schema from ._turn import Turn, user_turn from ._utils import MISSING, MISSING_TYPE, is_testing, split_http_client_kwargs @@ -381,6 +382,35 @@ def stream_turn(self, completion, has_data_model) -> Turn: def value_turn(self, completion, has_data_model) -> Turn: return self._as_turn(completion, has_data_model) + def value_tokens(self, completion): + usage = completion.get("usage", None) + if usage is None: + return None + usage = CompletionUsage.construct(**usage) + + if usage.prompt_tokens_details is not None: + cached_tokens = ( + usage.prompt_tokens_details.cached_tokens + if usage.prompt_tokens_details.cached_tokens + else 0 + ) + else: + cached_tokens = 0 + + tokens = ( + usage.prompt_tokens - cached_tokens, + usage.completion_tokens, + cached_tokens, + ) + + # For some reason ChatGroq() includes tokens under completion.x_groq + # Groq does not support caching, so we set cached_tokens to 0 + if usage is None and hasattr(completion, "x_groq"): + usage = completion.x_groq["usage"] # type: ignore + tokens = usage["prompt_tokens"], usage["completion_tokens"], 0 + + return tokens + def token_count( self, *args: Content | str, @@ -606,36 +636,9 @@ def _as_turn( ) ) - usage = completion.usage - if usage is None: - tokens = (0, 0, 0) - else: - if usage.prompt_tokens_details is not None: - cached_tokens = ( - usage.prompt_tokens_details.cached_tokens - if usage.prompt_tokens_details.cached_tokens - else 0 - ) - else: - cached_tokens = 0 - tokens = ( - usage.prompt_tokens - cached_tokens, - usage.completion_tokens, - cached_tokens, - ) - - # For some reason ChatGroq() includes tokens under completion.x_groq - # Groq does not support caching, so we set cached_tokens to 0 - if usage is None and hasattr(completion, "x_groq"): - usage = completion.x_groq["usage"] # type: ignore - tokens = usage["prompt_tokens"], usage["completion_tokens"], 0 - - tokens_log(self, tokens) - return Turn( "assistant", contents, - tokens=tokens, finish_reason=completion.choices[0].finish_reason, completion=completion, ) diff --git a/chatlas/_provider_snowflake.py b/chatlas/_provider_snowflake.py index 1e430555..616e0a8c 100644 --- a/chatlas/_provider_snowflake.py +++ b/chatlas/_provider_snowflake.py @@ -21,7 +21,6 @@ ) from ._logging import log_model_default from ._provider import Provider, StandardModelParamNames, StandardModelParams -from ._tokens import tokens_log from ._tools import Tool, basemodel_to_param_schema from ._turn import Turn from ._utils import drop_none @@ -441,6 +440,23 @@ def stream_turn(self, completion, has_data_model) -> Turn: def value_turn(self, completion, has_data_model) -> Turn: return self._as_turn(completion, has_data_model) + def value_tokens(self, completion): + import snowflake.core.cortex.inference_service._generated.models as models + + completion_dict = completion.model_dump() + usage = models.NonStreamingCompleteResponse.model_construct( + **completion_dict + ).usage + # Snowflake does not currently appear to support caching, so we set cached tokens to 0 + if usage is None: + return None + + return ( + usage.prompt_tokens or 0, + usage.completion_tokens or 0, + 0, + ) + def token_count( self, *args: "Content | str", @@ -543,19 +559,10 @@ def _as_turn(self, completion: "Completion", has_data_model: bool) -> Turn: arguments=params, ) ) - # Snowflake does not currently appear to support caching, so we set cached tokens to 0 - usage = completion.usage - if usage is None: - tokens = (0, 0, 0) - else: - tokens = (usage.prompt_tokens or 0, usage.completion_tokens or 0, 0) - - tokens_log(self, tokens) return Turn( "assistant", contents, - tokens=tokens, # TODO: no finish_reason in Snowflake? # finish_reason=completion.choices[0].finish_reason, completion=completion, From e085a13ed15cf48d28f29d5ca60ba18955163818 Mon Sep 17 00:00:00 2001 From: Carson Date: Wed, 29 Oct 2025 10:17:04 -0500 Subject: [PATCH 2/7] Update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c844799..085ec6b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * A new `Turn.to_inspect_messages()` method for converting turns to Inspect's message format. * Comprehensive documentation in the [Evals guide](https://posit-dev.github.io/chatlas/misc/evals.html). +### Changes + +* `Provider` implementations now require an additional `.value_tokens()` method. Previously, it was assumed that token info was logged and attached to the `Turn` as part of the `.value_turn()` method. The logging and attaching is now handled automatically. (#194) + ## [0.13.2] - 2025-10-02 From d8e3ca99ce4e147a2438be248ffe19a7cde6fd4b Mon Sep 17 00:00:00 2001 From: Carson Date: Wed, 29 Oct 2025 10:44:21 -0500 Subject: [PATCH 3/7] Always deal with completion object --- chatlas/_chat.py | 21 ++++++++++++--------- chatlas/_provider.py | 2 +- chatlas/_provider_google.py | 10 +++++----- chatlas/_provider_openai.py | 4 +--- chatlas/_provider_snowflake.py | 7 +------ 5 files changed, 20 insertions(+), 24 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 71cc2070..2fab6531 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -958,7 +958,9 @@ async def solve(state: InspectTaskState, generate): if user_prompt is None: raise ValueError("No user prompt found in InspectAI state messages") - input_content = [inspect_content_as_chatlas(x) for x in user_prompt.content] + input_content = [ + inspect_content_as_chatlas(x) for x in user_prompt.content + ] await chat_instance.chat_async(*input_content, echo="none") last_turn = chat_instance.get_last_turn(role="assistant") @@ -2537,11 +2539,12 @@ def emit(text: str | Content): result, has_data_model=data_model is not None, ) + if echo == "all": emit_other_contents(turn, emit) else: - result = self.provider.chat_perform( + response = self.provider.chat_perform( stream=False, turns=[*self._turns, user_turn], tools=self._tools, @@ -2550,7 +2553,7 @@ def emit(text: str | Content): ) turn = self.provider.value_turn( - result, has_data_model=data_model is not None + response, has_data_model=data_model is not None ) if turn.text: emit(turn.text) @@ -2559,8 +2562,8 @@ def emit(text: str | Content): if echo == "all": emit_other_contents(turn, emit) - if turn.tokens is None: - turn.tokens = self.provider.value_tokens(result) + if turn.tokens is None and turn.completion: + turn.tokens = self.provider.value_tokens(turn.completion) if turn.tokens is not None: tokens_log(self.provider, turn.tokens) self._turns.extend([user_turn, turn]) @@ -2607,7 +2610,7 @@ def emit(text: str | Content): emit_other_contents(turn, emit) else: - result = await self.provider.chat_perform_async( + response = await self.provider.chat_perform_async( stream=False, turns=[*self._turns, user_turn], tools=self._tools, @@ -2616,7 +2619,7 @@ def emit(text: str | Content): ) turn = self.provider.value_turn( - result, has_data_model=data_model is not None + response, has_data_model=data_model is not None ) if turn.text: emit(turn.text) @@ -2625,8 +2628,8 @@ def emit(text: str | Content): if echo == "all": emit_other_contents(turn, emit) - if turn.tokens is None: - turn.tokens = self.provider.value_tokens(result) + if turn.tokens is None and turn.completion: + turn.tokens = self.provider.value_tokens(turn.completion) if turn.tokens is not None: tokens_log(self.provider, turn.tokens) self._turns.extend([user_turn, turn]) diff --git a/chatlas/_provider.py b/chatlas/_provider.py index 4c8a0d68..98347eb4 100644 --- a/chatlas/_provider.py +++ b/chatlas/_provider.py @@ -252,7 +252,7 @@ def value_turn( @abstractmethod def value_tokens( self, - completion: ChatCompletionDictT, + completion: ChatCompletionT, ) -> tuple[int, int, int] | None: ... @abstractmethod diff --git a/chatlas/_provider_google.py b/chatlas/_provider_google.py index 1c371848..7217de0b 100644 --- a/chatlas/_provider_google.py +++ b/chatlas/_provider_google.py @@ -352,14 +352,14 @@ def value_turn(self, completion, has_data_model) -> Turn: return self._as_turn(completion, has_data_model) def value_tokens(self, completion): - usage = completion.get("usage_metadata") + usage = completion.usage_metadata if usage is None: return None - cached = usage.get("cached_content_token_count") or 0 + cached = usage.cached_content_token_count or 0 return ( - (usage.get("prompt_token_count") or 0) - cached, - usage.get("candidates_token_count") or 0, - usage.get("cached_content_token_count") or 0, + (usage.prompt_token_count or 0) - cached, + usage.candidates_token_count or 0, + usage.cached_content_token_count or 0, ) def token_count( diff --git a/chatlas/_provider_openai.py b/chatlas/_provider_openai.py index bbbfef7e..2963bb3e 100644 --- a/chatlas/_provider_openai.py +++ b/chatlas/_provider_openai.py @@ -13,7 +13,6 @@ from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI from openai.types.batch import Batch from openai.types.chat import ChatCompletion, ChatCompletionChunk -from openai.types.completion_usage import CompletionUsage from pydantic import BaseModel from ._chat import Chat @@ -383,10 +382,9 @@ def value_turn(self, completion, has_data_model) -> Turn: return self._as_turn(completion, has_data_model) def value_tokens(self, completion): - usage = completion.get("usage", None) + usage = completion.usage if usage is None: return None - usage = CompletionUsage.construct(**usage) if usage.prompt_tokens_details is not None: cached_tokens = ( diff --git a/chatlas/_provider_snowflake.py b/chatlas/_provider_snowflake.py index 616e0a8c..3eb98fac 100644 --- a/chatlas/_provider_snowflake.py +++ b/chatlas/_provider_snowflake.py @@ -441,13 +441,8 @@ def value_turn(self, completion, has_data_model) -> Turn: return self._as_turn(completion, has_data_model) def value_tokens(self, completion): - import snowflake.core.cortex.inference_service._generated.models as models - - completion_dict = completion.model_dump() - usage = models.NonStreamingCompleteResponse.model_construct( - **completion_dict - ).usage # Snowflake does not currently appear to support caching, so we set cached tokens to 0 + usage = completion.usage if usage is None: return None From cc9357129477e83e510113397d0d6bc24bbeebcd Mon Sep 17 00:00:00 2001 From: Carson Date: Wed, 29 Oct 2025 10:52:53 -0500 Subject: [PATCH 4/7] try disabling inspect test --- tests/test_inspect.py | 66 +++++++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/tests/test_inspect.py b/tests/test_inspect.py index 24650f78..fdfd42bc 100644 --- a/tests/test_inspect.py +++ b/tests/test_inspect.py @@ -337,39 +337,39 @@ def get_current_date(): accuracy = results.scores[0].metrics["accuracy"].value assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" - def test_multiple_samples_state_management(self): - """Test that solver has independent state across multiple samples.""" - - chat = chat_func( - system_prompt=""" - A user is going to simply add ingredients to a shopping list. - Your task is to simply report all known ingredients when one is added. - """ - ) - - task = create_task( - chat, - dataset=[ - Sample( - input="Add apples", - target="The shopping list should contain only apples", - ), - Sample( - input="Add bananas", - target="The shopping list should contain only bananas", - ), - Sample( - input="Add oranges", - target="The shopping list should contain only oranges", - ), - ], - ) - - results = inspect_eval(task)[0].results - - assert results is not None - accuracy = results.scores[0].metrics["accuracy"].value - assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" +# def test_multiple_samples_state_management(self): +# """Test that solver has independent state across multiple samples.""" +# +# chat = chat_func( +# system_prompt=""" +# A user is going to simply add ingredients to a shopping list. +# Your task is to simply report all known ingredients when one is added. +# """ +# ) +# +# task = create_task( +# chat, +# dataset=[ +# Sample( +# input="Add apples", +# target="The shopping list should contain only apples", +# ), +# Sample( +# input="Add bananas", +# target="The shopping list should contain only bananas", +# ), +# Sample( +# input="Add oranges", +# target="The shopping list should contain only oranges", +# ), +# ], +# ) +# +# results = inspect_eval(task)[0].results +# +# assert results is not None +# accuracy = results.scores[0].metrics["accuracy"].value +# assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" class TestContentTranslation: From 976d454334e2afa4145924174aa9d434e5bdab88 Mon Sep 17 00:00:00 2001 From: Carson Date: Wed, 29 Oct 2025 11:03:00 -0500 Subject: [PATCH 5/7] Disable all inspect integration tests --- tests/test_inspect.py | 170 +++++++++++++++++++++--------------------- 1 file changed, 85 insertions(+), 85 deletions(-) diff --git a/tests/test_inspect.py b/tests/test_inspect.py index fdfd42bc..4be0d837 100644 --- a/tests/test_inspect.py +++ b/tests/test_inspect.py @@ -252,91 +252,91 @@ def test_export_eval_custom_turns(self, tmp_path): assert "Second question" in str(samples[0].input) -class TestInspectIntegration: - def test_basic_eval(self): - chat = chat_func(system_prompt=SYSTEM_DEFAULT) - - task = create_task( - chat, - dataset=[Sample(input="What is 2+2?", target="4")], - ) - - results = inspect_eval(task)[0].results - - assert results is not None - accuracy = results.scores[0].metrics["accuracy"].value - assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" - - def test_system_prompt_override(self): - chat = chat_func(system_prompt="You are Chuck Norris.") - - task = create_task( - chat, - dataset=[ - Sample( - input="Tell me a short story.", - target="The answer can be any story, but should be in the style of Chuck Norris.", - ) - ], - ) - - results = inspect_eval(task)[0].results - - assert results is not None - accuracy = results.scores[0].metrics["accuracy"].value - assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" - - def test_existing_turns(self): - chat = chat_func(system_prompt=SYSTEM_DEFAULT) - - chat.set_turns( - [ - Turn("user", "My name is Gregg."), - Turn("assistant", "Hello Gregg! How can I assist you today?"), - ] - ) - - task = create_task( - chat, - dataset=[ - Sample( - input="What is my name?", - target="The answer should include 'Gregg'", - ) - ], - ) - - results = inspect_eval(task)[0].results - - assert results is not None - accuracy = results.scores[0].metrics["accuracy"].value - assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" - - def test_tool_calling(self): - chat = chat_func(system_prompt=SYSTEM_DEFAULT) - - def get_current_date(): - """Get the current date in YYYY-MM-DD format.""" - return datetime.datetime.now().strftime("%Y-%m-%d") - - chat.register_tool(get_current_date) - - task = create_task( - chat, - dataset=[ - Sample( - input="What is today's date?", - target="A valid date should be provided and be some time on or after Oct 23rd 2025.", - ) - ], - ) - - results = inspect_eval(task)[0].results - - assert results is not None - accuracy = results.scores[0].metrics["accuracy"].value - assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" - +#class TestInspectIntegration: +# def test_basic_eval(self): +# chat = chat_func(system_prompt=SYSTEM_DEFAULT) +# +# task = create_task( +# chat, +# dataset=[Sample(input="What is 2+2?", target="4")], +# ) +# +# results = inspect_eval(task)[0].results +# +# assert results is not None +# accuracy = results.scores[0].metrics["accuracy"].value +# assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" +# +# def test_system_prompt_override(self): +# chat = chat_func(system_prompt="You are Chuck Norris.") +# +# task = create_task( +# chat, +# dataset=[ +# Sample( +# input="Tell me a short story.", +# target="The answer can be any story, but should be in the style of Chuck Norris.", +# ) +# ], +# ) +# +# results = inspect_eval(task)[0].results +# +# assert results is not None +# accuracy = results.scores[0].metrics["accuracy"].value +# assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" +# +# def test_existing_turns(self): +# chat = chat_func(system_prompt=SYSTEM_DEFAULT) +# +# chat.set_turns( +# [ +# Turn("user", "My name is Gregg."), +# Turn("assistant", "Hello Gregg! How can I assist you today?"), +# ] +# ) +# +# task = create_task( +# chat, +# dataset=[ +# Sample( +# input="What is my name?", +# target="The answer should include 'Gregg'", +# ) +# ], +# ) +# +# results = inspect_eval(task)[0].results +# +# assert results is not None +# accuracy = results.scores[0].metrics["accuracy"].value +# assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" +# +# def test_tool_calling(self): +# chat = chat_func(system_prompt=SYSTEM_DEFAULT) +# +# def get_current_date(): +# """Get the current date in YYYY-MM-DD format.""" +# return datetime.datetime.now().strftime("%Y-%m-%d") +# +# chat.register_tool(get_current_date) +# +# task = create_task( +# chat, +# dataset=[ +# Sample( +# input="What is today's date?", +# target="A valid date should be provided and be some time on or after Oct 23rd 2025.", +# ) +# ], +# ) +# +# results = inspect_eval(task)[0].results +# +# assert results is not None +# accuracy = results.scores[0].metrics["accuracy"].value +# assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" +# # def test_multiple_samples_state_management(self): # """Test that solver has independent state across multiple samples.""" # From 131ab531af83d859d404616800d93ffaa660c670 Mon Sep 17 00:00:00 2001 From: Carson Date: Wed, 29 Oct 2025 11:12:34 -0500 Subject: [PATCH 6/7] Cleanup OpenAI token logic --- chatlas/_provider_openai.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/chatlas/_provider_openai.py b/chatlas/_provider_openai.py index 2963bb3e..6271d5e8 100644 --- a/chatlas/_provider_openai.py +++ b/chatlas/_provider_openai.py @@ -384,7 +384,13 @@ def value_turn(self, completion, has_data_model) -> Turn: def value_tokens(self, completion): usage = completion.usage if usage is None: - return None + # For some reason ChatGroq() includes tokens under completion.x_groq + # Groq does not support caching, so we set cached_tokens to 0 + if hasattr(completion, "x_groq"): + usage = completion.x_groq["usage"] # type: ignore + return usage["prompt_tokens"], usage["completion_tokens"], 0 + else: + return None if usage.prompt_tokens_details is not None: cached_tokens = ( @@ -395,20 +401,12 @@ def value_tokens(self, completion): else: cached_tokens = 0 - tokens = ( + return ( usage.prompt_tokens - cached_tokens, usage.completion_tokens, cached_tokens, ) - # For some reason ChatGroq() includes tokens under completion.x_groq - # Groq does not support caching, so we set cached_tokens to 0 - if usage is None and hasattr(completion, "x_groq"): - usage = completion.x_groq["usage"] # type: ignore - tokens = usage["prompt_tokens"], usage["completion_tokens"], 0 - - return tokens - def token_count( self, *args: Content | str, From ec4eab9b57693433b973603c5f4c84c39414926f Mon Sep 17 00:00:00 2001 From: Carson Date: Wed, 29 Oct 2025 11:35:39 -0500 Subject: [PATCH 7/7] Bring back inspect tests; make one of them simpler/easier --- tests/test_inspect.py | 233 +++++++++++++++++++++--------------------- 1 file changed, 115 insertions(+), 118 deletions(-) diff --git a/tests/test_inspect.py b/tests/test_inspect.py index 4be0d837..e8212011 100644 --- a/tests/test_inspect.py +++ b/tests/test_inspect.py @@ -252,124 +252,121 @@ def test_export_eval_custom_turns(self, tmp_path): assert "Second question" in str(samples[0].input) -#class TestInspectIntegration: -# def test_basic_eval(self): -# chat = chat_func(system_prompt=SYSTEM_DEFAULT) -# -# task = create_task( -# chat, -# dataset=[Sample(input="What is 2+2?", target="4")], -# ) -# -# results = inspect_eval(task)[0].results -# -# assert results is not None -# accuracy = results.scores[0].metrics["accuracy"].value -# assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" -# -# def test_system_prompt_override(self): -# chat = chat_func(system_prompt="You are Chuck Norris.") -# -# task = create_task( -# chat, -# dataset=[ -# Sample( -# input="Tell me a short story.", -# target="The answer can be any story, but should be in the style of Chuck Norris.", -# ) -# ], -# ) -# -# results = inspect_eval(task)[0].results -# -# assert results is not None -# accuracy = results.scores[0].metrics["accuracy"].value -# assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" -# -# def test_existing_turns(self): -# chat = chat_func(system_prompt=SYSTEM_DEFAULT) -# -# chat.set_turns( -# [ -# Turn("user", "My name is Gregg."), -# Turn("assistant", "Hello Gregg! How can I assist you today?"), -# ] -# ) -# -# task = create_task( -# chat, -# dataset=[ -# Sample( -# input="What is my name?", -# target="The answer should include 'Gregg'", -# ) -# ], -# ) -# -# results = inspect_eval(task)[0].results -# -# assert results is not None -# accuracy = results.scores[0].metrics["accuracy"].value -# assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" -# -# def test_tool_calling(self): -# chat = chat_func(system_prompt=SYSTEM_DEFAULT) -# -# def get_current_date(): -# """Get the current date in YYYY-MM-DD format.""" -# return datetime.datetime.now().strftime("%Y-%m-%d") -# -# chat.register_tool(get_current_date) -# -# task = create_task( -# chat, -# dataset=[ -# Sample( -# input="What is today's date?", -# target="A valid date should be provided and be some time on or after Oct 23rd 2025.", -# ) -# ], -# ) -# -# results = inspect_eval(task)[0].results -# -# assert results is not None -# accuracy = results.scores[0].metrics["accuracy"].value -# assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" -# -# def test_multiple_samples_state_management(self): -# """Test that solver has independent state across multiple samples.""" -# -# chat = chat_func( -# system_prompt=""" -# A user is going to simply add ingredients to a shopping list. -# Your task is to simply report all known ingredients when one is added. -# """ -# ) -# -# task = create_task( -# chat, -# dataset=[ -# Sample( -# input="Add apples", -# target="The shopping list should contain only apples", -# ), -# Sample( -# input="Add bananas", -# target="The shopping list should contain only bananas", -# ), -# Sample( -# input="Add oranges", -# target="The shopping list should contain only oranges", -# ), -# ], -# ) -# -# results = inspect_eval(task)[0].results -# -# assert results is not None -# accuracy = results.scores[0].metrics["accuracy"].value -# assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" +class TestInspectIntegration: + def test_basic_eval(self): + chat = chat_func(system_prompt=SYSTEM_DEFAULT) + + task = create_task( + chat, + dataset=[Sample(input="What is 2+2?", target="4")], + ) + + results = inspect_eval(task)[0].results + + assert results is not None + accuracy = results.scores[0].metrics["accuracy"].value + assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" + + def test_system_prompt_override(self): + chat = chat_func(system_prompt="You are Chuck Norris.") + + task = create_task( + chat, + dataset=[ + Sample( + input="Tell me a short story.", + target="The answer can be any story, but should be in the style of Chuck Norris.", + ) + ], + ) + + results = inspect_eval(task)[0].results + + assert results is not None + accuracy = results.scores[0].metrics["accuracy"].value + assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" + + def test_existing_turns(self): + chat = chat_func(system_prompt=SYSTEM_DEFAULT) + + chat.set_turns( + [ + Turn("user", "My name is Gregg."), + Turn("assistant", "Hello Gregg! How can I assist you today?"), + ] + ) + + task = create_task( + chat, + dataset=[ + Sample( + input="What is my name?", + target="The answer should include 'Gregg'", + ) + ], + ) + + results = inspect_eval(task)[0].results + + assert results is not None + accuracy = results.scores[0].metrics["accuracy"].value + assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" + + def test_tool_calling(self): + chat = chat_func(system_prompt=SYSTEM_DEFAULT) + + def get_current_date(): + """Get the current date in YYYY-MM-DD format.""" + return datetime.datetime.now().strftime("%Y-%m-%d") + + chat.register_tool(get_current_date) + + task = create_task( + chat, + dataset=[ + Sample( + input="What is today's date?", + target="A valid date should be provided and be some time on or after Oct 23rd 2025.", + ) + ], + ) + + results = inspect_eval(task)[0].results + + assert results is not None + accuracy = results.scores[0].metrics["accuracy"].value + assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" + + def test_multiple_samples_state_management(self): + """Test that solver has independent state across multiple samples.""" + + chat = chat_func( + system_prompt=""" + The user is building a shopping list. + Your task is to simply report all known ingredients on every response. + Be very terse, no punctuation. + """ + ) + + task = create_task( + chat, + dataset=[ + Sample( + input="Add apples", + target="The response should contain only apples, no other fruit", + ), + Sample( + input="Add bananas", + target="The response should contain bananas, no other fruit", + ) + ], + ) + + results = inspect_eval(task)[0].results + + assert results is not None + accuracy = results.scores[0].metrics["accuracy"].value + assert accuracy == 1, f"Expected accuracy of 1, but got {accuracy}" class TestContentTranslation: