diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c844799..085ec6b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * A new `Turn.to_inspect_messages()` method for converting turns to Inspect's message format. * Comprehensive documentation in the [Evals guide](https://posit-dev.github.io/chatlas/misc/evals.html). +### Changes + +* `Provider` implementations now require an additional `.value_tokens()` method. Previously, it was assumed that token info was logged and attached to the `Turn` as part of the `.value_turn()` method. The logging and attaching is now handled automatically. (#194) + ## [0.13.2] - 2025-10-02 diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 2d853e7a..2fab6531 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -47,7 +47,7 @@ from ._logging import log_tool_error from ._mcp_manager import MCPSessionManager from ._provider import ModelInfo, Provider, StandardModelParams, SubmitInputArgsT -from ._tokens import compute_cost, get_token_pricing +from ._tokens import compute_cost, get_token_pricing, tokens_log from ._tools import Tool, ToolRejectError from ._turn import Turn, user_turn from ._typing_extensions import TypedDict, TypeGuard @@ -958,7 +958,9 @@ async def solve(state: InspectTaskState, generate): if user_prompt is None: raise ValueError("No user prompt found in InspectAI state messages") - input_content = [inspect_content_as_chatlas(x) for x in user_prompt.content] + input_content = [ + inspect_content_as_chatlas(x) for x in user_prompt.content + ] await chat_instance.chat_async(*input_content, echo="none") last_turn = chat_instance.get_last_turn(role="assistant") @@ -2560,6 +2562,10 @@ def emit(text: str | Content): if echo == "all": emit_other_contents(turn, emit) + if turn.tokens is None and turn.completion: + turn.tokens = self.provider.value_tokens(turn.completion) + if turn.tokens is not None: + tokens_log(self.provider, turn.tokens) self._turns.extend([user_turn, turn]) async def _submit_turns_async( @@ -2622,6 +2628,10 @@ def emit(text: str | Content): if echo == "all": emit_other_contents(turn, emit) + if turn.tokens is None and turn.completion: + turn.tokens = self.provider.value_tokens(turn.completion) + if turn.tokens is not None: + tokens_log(self.provider, turn.tokens) self._turns.extend([user_turn, turn]) def _invoke_tool(self, request: ContentToolRequest): diff --git a/chatlas/_provider.py b/chatlas/_provider.py index 7e216ec2..98347eb4 100644 --- a/chatlas/_provider.py +++ b/chatlas/_provider.py @@ -249,6 +249,12 @@ def value_turn( has_data_model: bool, ) -> Turn: ... + @abstractmethod + def value_tokens( + self, + completion: ChatCompletionT, + ) -> tuple[int, int, int] | None: ... + @abstractmethod def token_count( self, diff --git a/chatlas/_provider_anthropic.py b/chatlas/_provider_anthropic.py index 70f01666..4f7cbb09 100644 --- a/chatlas/_provider_anthropic.py +++ b/chatlas/_provider_anthropic.py @@ -30,7 +30,7 @@ StandardModelParamNames, StandardModelParams, ) -from ._tokens import get_token_pricing, tokens_log +from ._tokens import get_token_pricing from ._tools import Tool, basemodel_to_param_schema from ._turn import Turn, user_turn from ._utils import split_http_client_kwargs @@ -411,6 +411,17 @@ def stream_turn(self, completion, has_data_model) -> Turn: def value_turn(self, completion, has_data_model) -> Turn: return self._as_turn(completion, has_data_model) + def value_tokens(self, completion): + usage = completion.usage + # N.B. Currently, Anthropic doesn't cache by default and we currently do not support + # manual caching in chatlas. Note also that this only tracks reads, NOT writes, which + # have their own cost. To track that properly, we would need another caching category and per-token cost. + return ( + completion.usage.input_tokens, + completion.usage.output_tokens, + usage.cache_read_input_tokens if usage.cache_read_input_tokens else 0, + ) + def token_count( self, *args: Content | str, @@ -619,23 +630,9 @@ def _as_turn(self, completion: Message, has_data_model=False) -> Turn: ) ) - usage = completion.usage - # N.B. Currently, Anthropic doesn't cache by default and we currently do not support - # manual caching in chatlas. Note also that this only tracks reads, NOT writes, which - # have their own cost. To track that properly, we would need another caching category and per-token cost. - - tokens = ( - completion.usage.input_tokens, - completion.usage.output_tokens, - usage.cache_read_input_tokens if usage.cache_read_input_tokens else 0, - ) - - tokens_log(self, tokens) - return Turn( "assistant", contents, - tokens=tokens, finish_reason=completion.stop_reason, completion=completion, ) diff --git a/chatlas/_provider_google.py b/chatlas/_provider_google.py index 9cf7f752..7217de0b 100644 --- a/chatlas/_provider_google.py +++ b/chatlas/_provider_google.py @@ -22,7 +22,7 @@ from ._logging import log_model_default from ._merge import merge_dicts from ._provider import ModelInfo, Provider, StandardModelParamNames, StandardModelParams -from ._tokens import get_token_pricing, tokens_log +from ._tokens import get_token_pricing from ._tools import Tool from ._turn import Turn, user_turn @@ -228,6 +228,7 @@ def chat_perform( def chat_perform( self, + *, stream: bool, turns: list[Turn], tools: dict[str, Tool], @@ -264,6 +265,7 @@ async def chat_perform_async( async def chat_perform_async( self, + *, stream: bool, turns: list[Turn], tools: dict[str, Tool], @@ -349,6 +351,17 @@ def value_turn(self, completion, has_data_model) -> Turn: completion = cast("GenerateContentResponseDict", completion.model_dump()) return self._as_turn(completion, has_data_model) + def value_tokens(self, completion): + usage = completion.usage_metadata + if usage is None: + return None + cached = usage.cached_content_token_count or 0 + return ( + (usage.prompt_token_count or 0) - cached, + usage.candidates_token_count or 0, + usage.cached_content_token_count or 0, + ) + def token_count( self, *args: Content | str, @@ -528,25 +541,12 @@ def _as_turn( ) ) - usage = message.get("usage_metadata") - tokens = (0, 0, 0) - if usage: - cached = usage.get("cached_content_token_count") or 0 - tokens = ( - (usage.get("prompt_token_count") or 0) - cached, - usage.get("candidates_token_count") or 0, - usage.get("cached_content_token_count") or 0, - ) - - tokens_log(self, tokens) - if isinstance(finish_reason, FinishReason): finish_reason = finish_reason.name return Turn( "assistant", contents, - tokens=tokens, finish_reason=finish_reason, completion=message, ) diff --git a/chatlas/_provider_openai.py b/chatlas/_provider_openai.py index 87415175..6271d5e8 100644 --- a/chatlas/_provider_openai.py +++ b/chatlas/_provider_openai.py @@ -38,7 +38,7 @@ StandardModelParamNames, StandardModelParams, ) -from ._tokens import get_token_pricing, tokens_log +from ._tokens import get_token_pricing from ._tools import Tool, basemodel_to_param_schema from ._turn import Turn, user_turn from ._utils import MISSING, MISSING_TYPE, is_testing, split_http_client_kwargs @@ -381,6 +381,32 @@ def stream_turn(self, completion, has_data_model) -> Turn: def value_turn(self, completion, has_data_model) -> Turn: return self._as_turn(completion, has_data_model) + def value_tokens(self, completion): + usage = completion.usage + if usage is None: + # For some reason ChatGroq() includes tokens under completion.x_groq + # Groq does not support caching, so we set cached_tokens to 0 + if hasattr(completion, "x_groq"): + usage = completion.x_groq["usage"] # type: ignore + return usage["prompt_tokens"], usage["completion_tokens"], 0 + else: + return None + + if usage.prompt_tokens_details is not None: + cached_tokens = ( + usage.prompt_tokens_details.cached_tokens + if usage.prompt_tokens_details.cached_tokens + else 0 + ) + else: + cached_tokens = 0 + + return ( + usage.prompt_tokens - cached_tokens, + usage.completion_tokens, + cached_tokens, + ) + def token_count( self, *args: Content | str, @@ -606,36 +632,9 @@ def _as_turn( ) ) - usage = completion.usage - if usage is None: - tokens = (0, 0, 0) - else: - if usage.prompt_tokens_details is not None: - cached_tokens = ( - usage.prompt_tokens_details.cached_tokens - if usage.prompt_tokens_details.cached_tokens - else 0 - ) - else: - cached_tokens = 0 - tokens = ( - usage.prompt_tokens - cached_tokens, - usage.completion_tokens, - cached_tokens, - ) - - # For some reason ChatGroq() includes tokens under completion.x_groq - # Groq does not support caching, so we set cached_tokens to 0 - if usage is None and hasattr(completion, "x_groq"): - usage = completion.x_groq["usage"] # type: ignore - tokens = usage["prompt_tokens"], usage["completion_tokens"], 0 - - tokens_log(self, tokens) - return Turn( "assistant", contents, - tokens=tokens, finish_reason=completion.choices[0].finish_reason, completion=completion, ) diff --git a/chatlas/_provider_snowflake.py b/chatlas/_provider_snowflake.py index 1e430555..3eb98fac 100644 --- a/chatlas/_provider_snowflake.py +++ b/chatlas/_provider_snowflake.py @@ -21,7 +21,6 @@ ) from ._logging import log_model_default from ._provider import Provider, StandardModelParamNames, StandardModelParams -from ._tokens import tokens_log from ._tools import Tool, basemodel_to_param_schema from ._turn import Turn from ._utils import drop_none @@ -441,6 +440,18 @@ def stream_turn(self, completion, has_data_model) -> Turn: def value_turn(self, completion, has_data_model) -> Turn: return self._as_turn(completion, has_data_model) + def value_tokens(self, completion): + # Snowflake does not currently appear to support caching, so we set cached tokens to 0 + usage = completion.usage + if usage is None: + return None + + return ( + usage.prompt_tokens or 0, + usage.completion_tokens or 0, + 0, + ) + def token_count( self, *args: "Content | str", @@ -543,19 +554,10 @@ def _as_turn(self, completion: "Completion", has_data_model: bool) -> Turn: arguments=params, ) ) - # Snowflake does not currently appear to support caching, so we set cached tokens to 0 - usage = completion.usage - if usage is None: - tokens = (0, 0, 0) - else: - tokens = (usage.prompt_tokens or 0, usage.completion_tokens or 0, 0) - - tokens_log(self, tokens) return Turn( "assistant", contents, - tokens=tokens, # TODO: no finish_reason in Snowflake? # finish_reason=completion.choices[0].finish_reason, completion=completion, diff --git a/tests/test_inspect.py b/tests/test_inspect.py index 24650f78..e8212011 100644 --- a/tests/test_inspect.py +++ b/tests/test_inspect.py @@ -342,8 +342,9 @@ def test_multiple_samples_state_management(self): chat = chat_func( system_prompt=""" - A user is going to simply add ingredients to a shopping list. - Your task is to simply report all known ingredients when one is added. + The user is building a shopping list. + Your task is to simply report all known ingredients on every response. + Be very terse, no punctuation. """ ) @@ -352,16 +353,12 @@ def test_multiple_samples_state_management(self): dataset=[ Sample( input="Add apples", - target="The shopping list should contain only apples", + target="The response should contain only apples, no other fruit", ), Sample( input="Add bananas", - target="The shopping list should contain only bananas", - ), - Sample( - input="Add oranges", - target="The shopping list should contain only oranges", - ), + target="The response should contain bananas, no other fruit", + ) ], )