Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* A new `Turn.to_inspect_messages()` method for converting turns to Inspect's message format.
* Comprehensive documentation in the [Evals guide](https://posit-dev.github.io/chatlas/misc/evals.html).

### Changes

* `Provider` implementations now require an additional `.value_tokens()` method. Previously, it was assumed that token info was logged and attached to the `Turn` as part of the `.value_turn()` method. The logging and attaching is now handled automatically. (#194)



## [0.13.2] - 2025-10-02
Expand Down
14 changes: 12 additions & 2 deletions chatlas/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from ._logging import log_tool_error
from ._mcp_manager import MCPSessionManager
from ._provider import ModelInfo, Provider, StandardModelParams, SubmitInputArgsT
from ._tokens import compute_cost, get_token_pricing
from ._tokens import compute_cost, get_token_pricing, tokens_log
from ._tools import Tool, ToolRejectError
from ._turn import Turn, user_turn
from ._typing_extensions import TypedDict, TypeGuard
Expand Down Expand Up @@ -958,7 +958,9 @@ async def solve(state: InspectTaskState, generate):
if user_prompt is None:
raise ValueError("No user prompt found in InspectAI state messages")

input_content = [inspect_content_as_chatlas(x) for x in user_prompt.content]
input_content = [
inspect_content_as_chatlas(x) for x in user_prompt.content
]

await chat_instance.chat_async(*input_content, echo="none")
last_turn = chat_instance.get_last_turn(role="assistant")
Expand Down Expand Up @@ -2560,6 +2562,10 @@ def emit(text: str | Content):
if echo == "all":
emit_other_contents(turn, emit)

if turn.tokens is None and turn.completion:
turn.tokens = self.provider.value_tokens(turn.completion)
if turn.tokens is not None:
tokens_log(self.provider, turn.tokens)
self._turns.extend([user_turn, turn])

async def _submit_turns_async(
Expand Down Expand Up @@ -2622,6 +2628,10 @@ def emit(text: str | Content):
if echo == "all":
emit_other_contents(turn, emit)

if turn.tokens is None and turn.completion:
turn.tokens = self.provider.value_tokens(turn.completion)
if turn.tokens is not None:
tokens_log(self.provider, turn.tokens)
self._turns.extend([user_turn, turn])

def _invoke_tool(self, request: ContentToolRequest):
Expand Down
6 changes: 6 additions & 0 deletions chatlas/_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,12 @@ def value_turn(
has_data_model: bool,
) -> Turn: ...

@abstractmethod
def value_tokens(
self,
completion: ChatCompletionT,
) -> tuple[int, int, int] | None: ...

@abstractmethod
def token_count(
self,
Expand Down
27 changes: 12 additions & 15 deletions chatlas/_provider_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
StandardModelParamNames,
StandardModelParams,
)
from ._tokens import get_token_pricing, tokens_log
from ._tokens import get_token_pricing
from ._tools import Tool, basemodel_to_param_schema
from ._turn import Turn, user_turn
from ._utils import split_http_client_kwargs
Expand Down Expand Up @@ -411,6 +411,17 @@ def stream_turn(self, completion, has_data_model) -> Turn:
def value_turn(self, completion, has_data_model) -> Turn:
return self._as_turn(completion, has_data_model)

def value_tokens(self, completion):
usage = completion.usage
# N.B. Currently, Anthropic doesn't cache by default and we currently do not support
# manual caching in chatlas. Note also that this only tracks reads, NOT writes, which
# have their own cost. To track that properly, we would need another caching category and per-token cost.
return (
completion.usage.input_tokens,
completion.usage.output_tokens,
usage.cache_read_input_tokens if usage.cache_read_input_tokens else 0,
)

def token_count(
self,
*args: Content | str,
Expand Down Expand Up @@ -619,23 +630,9 @@ def _as_turn(self, completion: Message, has_data_model=False) -> Turn:
)
)

usage = completion.usage
# N.B. Currently, Anthropic doesn't cache by default and we currently do not support
# manual caching in chatlas. Note also that this only tracks reads, NOT writes, which
# have their own cost. To track that properly, we would need another caching category and per-token cost.

tokens = (
completion.usage.input_tokens,
completion.usage.output_tokens,
usage.cache_read_input_tokens if usage.cache_read_input_tokens else 0,
)

tokens_log(self, tokens)

return Turn(
"assistant",
contents,
tokens=tokens,
finish_reason=completion.stop_reason,
completion=completion,
)
Expand Down
28 changes: 14 additions & 14 deletions chatlas/_provider_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ._logging import log_model_default
from ._merge import merge_dicts
from ._provider import ModelInfo, Provider, StandardModelParamNames, StandardModelParams
from ._tokens import get_token_pricing, tokens_log
from ._tokens import get_token_pricing
from ._tools import Tool
from ._turn import Turn, user_turn

Expand Down Expand Up @@ -228,6 +228,7 @@ def chat_perform(

def chat_perform(
self,
*,
Comment thread
cpsievert marked this conversation as resolved.
stream: bool,
turns: list[Turn],
tools: dict[str, Tool],
Expand Down Expand Up @@ -264,6 +265,7 @@ async def chat_perform_async(

async def chat_perform_async(
self,
*,
Comment thread
cpsievert marked this conversation as resolved.
stream: bool,
turns: list[Turn],
tools: dict[str, Tool],
Expand Down Expand Up @@ -349,6 +351,17 @@ def value_turn(self, completion, has_data_model) -> Turn:
completion = cast("GenerateContentResponseDict", completion.model_dump())
return self._as_turn(completion, has_data_model)

def value_tokens(self, completion):
usage = completion.usage_metadata
if usage is None:
return None
cached = usage.cached_content_token_count or 0
return (
(usage.prompt_token_count or 0) - cached,
usage.candidates_token_count or 0,
usage.cached_content_token_count or 0,
)

def token_count(
self,
*args: Content | str,
Expand Down Expand Up @@ -528,25 +541,12 @@ def _as_turn(
)
)

usage = message.get("usage_metadata")
tokens = (0, 0, 0)
if usage:
cached = usage.get("cached_content_token_count") or 0
tokens = (
(usage.get("prompt_token_count") or 0) - cached,
usage.get("candidates_token_count") or 0,
usage.get("cached_content_token_count") or 0,
)

tokens_log(self, tokens)

if isinstance(finish_reason, FinishReason):
finish_reason = finish_reason.name

return Turn(
"assistant",
contents,
tokens=tokens,
finish_reason=finish_reason,
completion=message,
)
Expand Down
55 changes: 27 additions & 28 deletions chatlas/_provider_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
StandardModelParamNames,
StandardModelParams,
)
from ._tokens import get_token_pricing, tokens_log
from ._tokens import get_token_pricing
from ._tools import Tool, basemodel_to_param_schema
from ._turn import Turn, user_turn
from ._utils import MISSING, MISSING_TYPE, is_testing, split_http_client_kwargs
Expand Down Expand Up @@ -381,6 +381,32 @@ def stream_turn(self, completion, has_data_model) -> Turn:
def value_turn(self, completion, has_data_model) -> Turn:
return self._as_turn(completion, has_data_model)

def value_tokens(self, completion):
usage = completion.usage
if usage is None:
# For some reason ChatGroq() includes tokens under completion.x_groq
# Groq does not support caching, so we set cached_tokens to 0
if hasattr(completion, "x_groq"):
usage = completion.x_groq["usage"] # type: ignore
return usage["prompt_tokens"], usage["completion_tokens"], 0
else:
return None

if usage.prompt_tokens_details is not None:
cached_tokens = (
usage.prompt_tokens_details.cached_tokens
if usage.prompt_tokens_details.cached_tokens
else 0
)
else:
cached_tokens = 0

return (
usage.prompt_tokens - cached_tokens,
usage.completion_tokens,
cached_tokens,
)

def token_count(
self,
*args: Content | str,
Expand Down Expand Up @@ -606,36 +632,9 @@ def _as_turn(
)
)

usage = completion.usage
if usage is None:
tokens = (0, 0, 0)
else:
if usage.prompt_tokens_details is not None:
cached_tokens = (
usage.prompt_tokens_details.cached_tokens
if usage.prompt_tokens_details.cached_tokens
else 0
)
else:
cached_tokens = 0
tokens = (
usage.prompt_tokens - cached_tokens,
usage.completion_tokens,
cached_tokens,
)

# For some reason ChatGroq() includes tokens under completion.x_groq
# Groq does not support caching, so we set cached_tokens to 0
if usage is None and hasattr(completion, "x_groq"):
usage = completion.x_groq["usage"] # type: ignore
tokens = usage["prompt_tokens"], usage["completion_tokens"], 0

tokens_log(self, tokens)

return Turn(
"assistant",
contents,
tokens=tokens,
finish_reason=completion.choices[0].finish_reason,
completion=completion,
)
Expand Down
22 changes: 12 additions & 10 deletions chatlas/_provider_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
)
from ._logging import log_model_default
from ._provider import Provider, StandardModelParamNames, StandardModelParams
from ._tokens import tokens_log
from ._tools import Tool, basemodel_to_param_schema
from ._turn import Turn
from ._utils import drop_none
Expand Down Expand Up @@ -441,6 +440,18 @@ def stream_turn(self, completion, has_data_model) -> Turn:
def value_turn(self, completion, has_data_model) -> Turn:
return self._as_turn(completion, has_data_model)

def value_tokens(self, completion):
# Snowflake does not currently appear to support caching, so we set cached tokens to 0
usage = completion.usage
if usage is None:
return None

return (
usage.prompt_tokens or 0,
usage.completion_tokens or 0,
0,
)

def token_count(
self,
*args: "Content | str",
Expand Down Expand Up @@ -543,19 +554,10 @@ def _as_turn(self, completion: "Completion", has_data_model: bool) -> Turn:
arguments=params,
)
)
# Snowflake does not currently appear to support caching, so we set cached tokens to 0
usage = completion.usage
if usage is None:
tokens = (0, 0, 0)
else:
tokens = (usage.prompt_tokens or 0, usage.completion_tokens or 0, 0)

tokens_log(self, tokens)

return Turn(
"assistant",
contents,
tokens=tokens,
# TODO: no finish_reason in Snowflake?
# finish_reason=completion.choices[0].finish_reason,
completion=completion,
Expand Down
15 changes: 6 additions & 9 deletions tests/test_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,9 @@ def test_multiple_samples_state_management(self):

chat = chat_func(
system_prompt="""
A user is going to simply add ingredients to a shopping list.
Your task is to simply report all known ingredients when one is added.
The user is building a shopping list.
Your task is to simply report all known ingredients on every response.
Be very terse, no punctuation.
"""
)

Expand All @@ -352,16 +353,12 @@ def test_multiple_samples_state_management(self):
dataset=[
Sample(
input="Add apples",
target="The shopping list should contain only apples",
target="The response should contain only apples, no other fruit",
),
Sample(
input="Add bananas",
target="The shopping list should contain only bananas",
),
Sample(
input="Add oranges",
target="The shopping list should contain only oranges",
),
target="The response should contain bananas, no other fruit",
)
],
)

Expand Down
Loading