Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions evals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .api import CompletionResult as CompletionResult
from .api import DummyCompletionFn as DummyCompletionFn
from .api import record_and_check_match as record_and_check_match
from .completion_fns.openai import UserChatCompletionFn as UserChatCompletionFn
from .completion_fns.openai import OpenAIChatCompletionFn as OpenAIChatCompletionFn
from .completion_fns.openai import OpenAICompletionFn as OpenAICompletionFn
from .completion_fns.openai import OpenAICompletionResult as OpenAICompletionResult
Expand Down
44 changes: 43 additions & 1 deletion evals/cli/oaieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import shlex
import sys
import ast
from typing import Any, Mapping, Optional, Union, cast

import evals
Expand Down Expand Up @@ -37,6 +38,26 @@ def get_parser() -> argparse.ArgumentParser:
default="",
help="Specify additional parameters to modify the behavior of the completion_fn during its creation. Parameters should be passed as a comma-separated list of key-value pairs (e.g., 'key1=value1,key2=value2'). This option allows for the dynamic modification of completion_fn settings, including the ability to override default arguments where necessary.",
)
parser.add_argument(
"--extra_eval_cls",
type=str,
default="",
help="实际制定的评估类, 对应yaml文件中的cls, 优先级高于yaml文件配置"
)
parser.add_argument(
"--extra_eval_metrics",
type=list[str],
default=["accuracy", "f1_score"],
help="在评估类中需要使用的评估指标,如果不指定则使用评估类中的默认指标"
)
parser.add_argument("--api_base", type=str, default="", help="直接通过ip:port发送请求时,该参数表示发送请求的url")
parser.add_argument(
"--payload",
type=str,
default="{'max_tokens': 100, 'temperature': 0.0, 'stream': False, 'ignore_eos': False}",
help="和api_base结合使用, 表示采用ip:port发请求时的payload, 需要注意的是, 仅在api_base非空时, 该参数才会起作用"
)
parser.add_argument("--enable_pc_offload", type=bool, default=True, help="是否启用pc offload,仅在api_base非空时才生效")
parser.add_argument("--max_samples", type=int, default=None)
parser.add_argument("--cache", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument("--visible", action=argparse.BooleanOptionalAction, default=None)
Expand Down Expand Up @@ -97,6 +118,12 @@ class OaiEvalArguments(argparse.Namespace):
completion_fn: str
eval: str
extra_eval_params: str
completion_args: str
extra_eval_cls: str
extra_eval_metrics: list[str]
api_base: str
payload: str
enable_pc_offload: bool
max_samples: Optional[int]
cache: bool
visible: Optional[bool]
Expand Down Expand Up @@ -166,10 +193,20 @@ def to_number(x: str) -> Union[int, float, str]:
additional_completion_args = {k: v for k, v in (kv.split("=") for kv in completion_args if kv)}

completion_fns = args.completion_fn.split(",")
# 当api_base非空时,首先通过api_base这个url去访问对应引擎
if args.api_base is not None:
additional_completion_args["api_base"] = args.api_base
additional_completion_args["payload"] = ast.literal_eval(args.payload)
additional_completion_args["enable_pc_offload"] = args.enable_pc_offload

completion_fn_instances = [
registry.make_completion_fn(url, **additional_completion_args) for url in completion_fns
]

# 如果extra_eval_cls为空,使用yaml文件中指定的类作为评估类,否则使用extra_eval_cls对其更新,这里需要在run_config前, 因为record中会记录所采用的评估参数
if args.extra_eval_cls:
eval_spec.cls = args.extra_eval_cls

run_config = {
"completion_fns": completion_fns,
"eval_spec": eval_spec,
Expand All @@ -194,8 +231,9 @@ def to_number(x: str) -> Union[int, float, str]:
created_by=args.user,
)

# 日志路径
record_path = (
f"/tmp/evallogs/{run_spec.run_id}_{args.completion_fn}_{args.eval}.jsonl"
f"./evallogs/{run_spec.run_id}_{args.completion_fn}_{args.eval}.jsonl"
if args.record_path is None
else args.record_path
)
Expand All @@ -214,6 +252,7 @@ def to_number(x: str) -> Union[int, float, str]:
run_url = f"{run_spec.run_id}"
logger.info(_purple(f"Run started: {run_url}"))

# 实例化具体匹配方式(比如match)等,这里会将samples_jsonl的路径传入匹配类
eval_class = registry.get_class(eval_spec)
eval: Eval = eval_class(
completion_fns=completion_fn_instances,
Expand All @@ -223,6 +262,9 @@ def to_number(x: str) -> Union[int, float, str]:
registry=registry,
**extra_eval_params,
)

# TODO:添加自选评估指标,这里为了不要对所有的评估类进行修改,不应该在run函数中新增参数,看看能把不能在recorder中添加参数
# 调用match等匹配方法的run函数,进行数据集的处理和实际的请求发送
result = eval.run(recorder)
try:
add_token_usage_to_result(result, recorder)
Expand Down
143 changes: 140 additions & 3 deletions evals/completion_fns/openai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import logging
from typing import Any, Optional, Union

import requests
import json
import copy
import asyncio
import openai
import aiohttp
from openai import OpenAI

from evals.api import CompletionFn, CompletionResult
Expand All @@ -23,6 +27,12 @@
openai.InternalServerError,
)

logger = logging.getLogger(__name__)

# 通过ip:port发请求时设置的超时时间
TIMEOUT = 3 * 3600
# 发送请求的HEADERS
HEADERS = {"User-Agent": "Benchmark Client", "Content-Type": "application/json"}

def openai_completion_create_retrying(client: OpenAI, *args, **kwargs):
"""
Expand All @@ -33,7 +43,7 @@ def openai_completion_create_retrying(client: OpenAI, *args, **kwargs):
client.completions.create, retry_exceptions=OPENAI_TIMEOUT_EXCEPTIONS, *args, **kwargs
)
if "error" in result:
logging.warning(result)
logger.warning(result)
raise openai.APIError(result["error"])
return result

Expand All @@ -47,7 +57,7 @@ def openai_chat_completion_create_retrying(client: OpenAI, *args, **kwargs):
client.chat.completions.create, retry_exceptions=OPENAI_TIMEOUT_EXCEPTIONS, *args, **kwargs
)
if "error" in result:
logging.warning(result)
logger.warning(result)
raise openai.APIError(result["error"])
return result

Expand Down Expand Up @@ -179,3 +189,130 @@ def __call__(
usage=result.raw_data.usage,
)
return result

class UserChatCompletionResult(OpenAIBaseCompletionResult):
def get_completions(self) -> list[str]:
completions = []
if self.raw_data:
for choice in self.raw_data.get("choices", []):
if choice.get("message", {}).get("content", '') is not None:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Repeated calls to .get("message", {}) can be extracted into a local variable.

content = choice.get("message", {}).get("content", '')
elif choice.get("message", {}).get("reasoning_content", '') is not None:
reasoning_content = choice.get("message", {}).get("reasoning_content", '')
output = content if content != "" else reasoning_content
logger.info(f"output: {output}")
request_id = choice.get("message", {}).get("id", "request_id not found")
if output == "":
logger.error(f"Request_id为{request_id}的请求返回信息为空,请检查具体原因!!!")
raise Exception(f"Request_id为{request_id}的请求返回信息为空,请检查具体原因!!!")
completions.append(output)
return completions

class UserChatCompletionFn(CompletionFnSpec):
def __init__(
self,
model: Optional[str] = None,
api_base: str = "",
payload: dict = {},
enable_pc_offload: bool = False,
extra_options: Optional[dict] = {},
):
self.model = model
self.api_base = api_base
assert self.api_base, "api_base is required"
self.url = f"http://{self.api_base}/v1/chat/completions"
self.payload = payload
# 如果payload为str类型,将其转为dict
if isinstance(self.payload, str):
self.payload = json.loads(self.payload)
assert self.payload, "payload is required"
self.enable_pc_offload = enable_pc_offload
self.extra_options = extra_options

def __call__(
self,
prompt: Union[str, OpenAICreateChatPrompt],
**kwargs,
) -> UserChatCompletionResult:
if not isinstance(prompt, Prompt):
assert (
isinstance(prompt, str)
or (isinstance(prompt, list) and all(isinstance(token, int) for token in prompt))
or (isinstance(prompt, list) and all(isinstance(token, str) for token in prompt))
or (isinstance(prompt, list) and all(isinstance(msg, dict) for msg in prompt))
), f"Got type {type(prompt)}, with val {type(prompt[0])} for prompt, expected str or list[int] or list[str] or list[dict[str, str]]"

prompt = ChatCompletionPrompt(
raw_prompt=prompt,
)

# 这里多并发情况下每个每个请求需要有自己独立的payload
per_request_payload = copy.deepcopy(self.payload)
user_create_prompt: OpenAICreateChatPrompt = prompt.to_formatted_prompt()
per_request_payload = self._update_payload(per_request_payload, user_create_prompt, self.model, self.enable_pc_offload)
# result = asyncio.run(self._async_do_request(per_request_payload))
# 这里由于eval.py中的eval_all_samples通过线性池进行调用,这里采用同步发送请求的方式
result = self._do_request(per_request_payload)

result = UserChatCompletionResult(raw_data=result, prompt=user_create_prompt)
record_sampling(
prompt=result.prompt,
sampled=result.get_completions(),
model=result.raw_data.get("model"),
usage=result.raw_data.get("usage"),
)
return result

def _update_payload(self, payload: dict, prompt: str, model: str, enable_pc_offload: bool) -> None:
"""
对payload进行更新, 主要需要对prompt、model名, 以及enable_pc_ooffload进行更新
"""
payload.update({"messages": prompt})
payload.update({"model": model})
payload.update({"enable_pc_offload": enable_pc_offload})
return payload

def _do_request(self, payload: dict):
"""
同步发请求, eval.py中的eval_all_samples通过线性池调用时使用该函数
"""
try:
response = requests.post(self.url, headers=HEADERS, json=payload)
response.raise_for_status()
return json.loads(response.text)
except Exception as err:
self._handle_request_error(err)

async def _async_do_request(self, payload: dict):
"""
异步发请求, 当采用eval中的async_eval_all_samples运行实际请求时使用该函数
"""
try:
timeout = aiohttp.ClientTimeout(total=TIMEOUT)
async with aiohttp.ClientSession(timeout=timeout, connector=aiohttp.TCPConnector(ssl=False)) as session:
async with session.post(self.url, headers=HEADERS, json=payload) as response:
response.raise_for_status()
return await response.json
except Exception as err:
self._handle_request_error(err)

def _handle_request_error(self, err: Exception) -> None:
"""
用于对_do_request及_async_do_request的异常进行处理, 使用户能够更清楚出现异常的原因
"""
if isinstance(err, (requests.exceptions.ConnectionError, aiohttp.ClientConnectionError)):
logger.error(f"无法连接到服务器{self.api_base}, 检查网络是否可达")
raise ConnectionError(f"无法连接到服务器{self.api_base}, 检查网络是否可达")
elif isinstance(err, (requests.exceptions.Timeout, aiohttp.ServerTimeoutError)):
logger.error("请求超时,检查服务端状态")
raise TimeoutError("请求超时,检查服务端状态")
elif isinstance(err, (requests.exceptions.HTTPError, aiohttp.ClientResponseError)):
status_code = err.response.status_code if hasattr(err, "response") else err.status
if status_code == 404:
logger.error(f"请求资源不存在或是model名称错误")
else:
logger.error(f"HTTP错误, 状态码: {status_code}")
raise Exception(f"HTTP错误, 状态码: {status_code}")
else:
logger.error(f"其他未知错误: {err}")
raise Exception(f"其他未知错误: {err}")
8 changes: 6 additions & 2 deletions evals/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
import yaml
from openai import OpenAI

from evals import OpenAIChatCompletionFn, OpenAICompletionFn
from evals import OpenAIChatCompletionFn, OpenAICompletionFn, UserChatCompletionFn
from evals.api import CompletionFn, DummyCompletionFn
from evals.base import BaseEvalSpec, CompletionFnSpec, EvalSetSpec, EvalSpec
from evals.elsuite.modelgraded.base import ModelGradedSpec
from evals.utils.misc import make_object

client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
client = OpenAI(api_key="***************")

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -129,6 +130,9 @@ def make_completion_fn(
"""
if name == "dummy":
return DummyCompletionFn()

if kwargs.get("api_base") is not None:
return UserChatCompletionFn(model=name, **kwargs)

n_ctx = n_ctx_from_model_name(name)

Expand Down