diff --git a/evals/__init__.py b/evals/__init__.py index dfd580f825..24fd9ce390 100644 --- a/evals/__init__.py +++ b/evals/__init__.py @@ -2,6 +2,7 @@ from .api import CompletionResult as CompletionResult from .api import DummyCompletionFn as DummyCompletionFn from .api import record_and_check_match as record_and_check_match +from .completion_fns.openai import UserChatCompletionFn as UserChatCompletionFn from .completion_fns.openai import OpenAIChatCompletionFn as OpenAIChatCompletionFn from .completion_fns.openai import OpenAICompletionFn as OpenAICompletionFn from .completion_fns.openai import OpenAICompletionResult as OpenAICompletionResult diff --git a/evals/cli/oaieval.py b/evals/cli/oaieval.py index e48a09ac19..8aac80b395 100644 --- a/evals/cli/oaieval.py +++ b/evals/cli/oaieval.py @@ -5,6 +5,7 @@ import logging import shlex import sys +import ast from typing import Any, Mapping, Optional, Union, cast import evals @@ -37,6 +38,26 @@ def get_parser() -> argparse.ArgumentParser: default="", help="Specify additional parameters to modify the behavior of the completion_fn during its creation. Parameters should be passed as a comma-separated list of key-value pairs (e.g., 'key1=value1,key2=value2'). 
This option allows for the dynamic modification of completion_fn settings, including the ability to override default arguments where necessary.", ) + parser.add_argument( + "--extra_eval_cls", + type=str, + default="", + help="实际制定的评估类, 对应yaml文件中的cls, 优先级高于yaml文件配置" + ) + parser.add_argument( + "--extra_eval_metrics", + type=list[str], + default=["accuracy", "f1_score"], + help="在评估类中需要使用的评估指标,如果不指定则使用评估类中的默认指标" + ) + parser.add_argument("--api_base", type=str, default="", help="直接通过ip:port发送请求时,该参数表示发送请求的url") + parser.add_argument( + "--payload", + type=str, + default="{'max_tokens': 100, 'temperature': 0.0, 'stream': False, 'ignore_eos': False}", + help="和api_base结合使用, 表示采用ip:port发请求时的payload, 需要注意的是, 仅在api_base非空时, 该参数才会起作用" + ) + parser.add_argument("--enable_pc_offload", type=bool, default=True, help="是否启用pc offload,仅在api_base非空时才生效") parser.add_argument("--max_samples", type=int, default=None) parser.add_argument("--cache", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--visible", action=argparse.BooleanOptionalAction, default=None) @@ -97,6 +118,12 @@ class OaiEvalArguments(argparse.Namespace): completion_fn: str eval: str extra_eval_params: str + completion_args: str + extra_eval_cls: str + extra_eval_metrics: list[str] + api_base: str + payload: str + enable_pc_offload: bool max_samples: Optional[int] cache: bool visible: Optional[bool] @@ -166,10 +193,20 @@ def to_number(x: str) -> Union[int, float, str]: additional_completion_args = {k: v for k, v in (kv.split("=") for kv in completion_args if kv)} completion_fns = args.completion_fn.split(",") + # 当api_base非空时,首先通过api_base这个url去访问对应引擎 + if args.api_base is not None: + additional_completion_args["api_base"] = args.api_base + additional_completion_args["payload"] = ast.literal_eval(args.payload) + additional_completion_args["enable_pc_offload"] = args.enable_pc_offload + completion_fn_instances = [ registry.make_completion_fn(url, **additional_completion_args) for url in 
completion_fns ] + # 如果extra_eval_cls为空,使用yaml文件中指定的类作为评估类,否则使用extra_eval_cls对其更新,这里需要在run_config前, 因为record中会记录所采用的评估参数 + if args.extra_eval_cls: + eval_spec.cls = args.extra_eval_cls + run_config = { "completion_fns": completion_fns, "eval_spec": eval_spec, @@ -194,8 +231,9 @@ def to_number(x: str) -> Union[int, float, str]: created_by=args.user, ) + # 日志路径 record_path = ( - f"/tmp/evallogs/{run_spec.run_id}_{args.completion_fn}_{args.eval}.jsonl" + f"./evallogs/{run_spec.run_id}_{args.completion_fn}_{args.eval}.jsonl" if args.record_path is None else args.record_path ) @@ -214,6 +252,7 @@ def to_number(x: str) -> Union[int, float, str]: run_url = f"{run_spec.run_id}" logger.info(_purple(f"Run started: {run_url}")) + # 实例化具体匹配方式(比如match)等,这里会将samples_jsonl的路径传入匹配类 eval_class = registry.get_class(eval_spec) eval: Eval = eval_class( completion_fns=completion_fn_instances, @@ -223,6 +262,9 @@ def to_number(x: str) -> Union[int, float, str]: registry=registry, **extra_eval_params, ) + + # TODO:添加自选评估指标,这里为了不要对所有的评估类进行修改,不应该在run函数中新增参数,看看能把不能在recorder中添加参数 + # 调用match等匹配方法的run函数,进行数据集的处理和实际的请求发送 result = eval.run(recorder) try: add_token_usage_to_result(result, recorder) diff --git a/evals/completion_fns/openai.py b/evals/completion_fns/openai.py index 21524bfc1a..967de71bbe 100644 --- a/evals/completion_fns/openai.py +++ b/evals/completion_fns/openai.py @@ -1,7 +1,11 @@ import logging from typing import Any, Optional, Union - +import requests +import json +import copy +import asyncio import openai +import aiohttp from openai import OpenAI from evals.api import CompletionFn, CompletionResult @@ -23,6 +27,12 @@ openai.InternalServerError, ) +logger = logging.getLogger(__name__) + +# 通过ip:port发请求时设置的超时时间 +TIMEOUT = 3 * 3600 +# 发送请求的HEADERS +HEADERS = {"User-Agent": "Benchmark Client", "Content-Type": "application/json"} def openai_completion_create_retrying(client: OpenAI, *args, **kwargs): """ @@ -33,7 +43,7 @@ def openai_completion_create_retrying(client: OpenAI, 
class UserChatCompletionResult(OpenAIBaseCompletionResult):
    """Completion result wrapping the JSON dict returned by a chat-completions endpoint."""

    def get_completions(self) -> list[str]:
        """Return the generated text of every choice in ``self.raw_data``.

        For each choice the message's ``content`` is used; when it is missing,
        ``None`` or empty, the message's ``reasoning_content`` is used instead.

        Returns:
            One string per choice; an empty list when ``raw_data`` is falsy.

        Raises:
            ValueError: if a choice has no non-empty text in either field.
        """
        completions: list[str] = []
        if not self.raw_data:
            return completions

        for choice in self.raw_data.get("choices", []):
            message = choice.get("message") or {}
            # Both fields are looked up unconditionally so that the fallback
            # works whether `content` is absent, None or an empty string.
            output = message.get("content") or message.get("reasoning_content") or ""
            logger.info(f"output: {output}")
            if output == "":
                request_id = message.get("id", "request_id not found")
                logger.error(f"Request_id为{request_id}的请求返回信息为空,请检查具体原因!!!")
                raise ValueError(f"Request_id为{request_id}的请求返回信息为空,请检查具体原因!!!")
            completions.append(output)
        return completions


class UserChatCompletionFn(CompletionFnSpec):
    """Completion function that POSTs chat requests to ``http://<api_base>/v1/chat/completions``.

    Every call sends a deep copy of the base ``payload`` with the ``messages``,
    ``model`` and ``enable_pc_offload`` keys overwritten for that request.
    """

    def __init__(
        self,
        model: Optional[str] = None,
        api_base: str = "",
        payload: Union[dict, str, None] = None,
        enable_pc_offload: bool = False,
        extra_options: Optional[dict] = None,
    ):
        """Store the endpoint and the base request payload.

        Args:
            model: Value sent as the ``model`` key of every request.
            api_base: ``host:port`` of the server; required.
            payload: Base request body, either a dict or a string holding a
                JSON object or a Python dict literal; required and non-empty.
            enable_pc_offload: Value sent as the ``enable_pc_offload`` key of
                every request.
            extra_options: Stored on the instance; not used by this class.

        Raises:
            ValueError: if ``api_base`` or ``payload`` is empty.
        """
        self.model = model
        self.api_base = api_base
        # Explicit raises instead of `assert`, which is stripped under `python -O`.
        if not self.api_base:
            raise ValueError("api_base is required")
        self.url = f"http://{self.api_base}/v1/chat/completions"

        if isinstance(payload, str):
            try:
                payload = json.loads(payload)
            except json.JSONDecodeError:
                # Also accept a Python dict literal (single quotes, True/False),
                # which strict JSON parsing rejects.
                import ast

                payload = ast.literal_eval(payload)
        if not payload:
            raise ValueError("payload is required")
        self.payload = payload

        self.enable_pc_offload = enable_pc_offload
        # A fresh dict per instance instead of a shared mutable default.
        self.extra_options = {} if extra_options is None else extra_options

    def __call__(
        self,
        prompt: Union[str, OpenAICreateChatPrompt],
        **kwargs,
    ) -> UserChatCompletionResult:
        """Send ``prompt`` to the server and return the wrapped response.

        Args:
            prompt: A ``Prompt`` instance, a string, or a list of ints, strings
                or message dicts; anything that is not a ``Prompt`` is wrapped
                in ``ChatCompletionPrompt``.
            **kwargs: Accepted but currently not used.

        Returns:
            The result object; the sampling is also recorded via ``record_sampling``.
        """
        if not isinstance(prompt, Prompt):
            assert (
                isinstance(prompt, str)
                or (isinstance(prompt, list) and all(isinstance(token, int) for token in prompt))
                or (isinstance(prompt, list) and all(isinstance(token, str) for token in prompt))
                or (isinstance(prompt, list) and all(isinstance(msg, dict) for msg in prompt))
            ), f"Got type {type(prompt)}, with val {type(prompt[0])} for prompt, expected str or list[int] or list[str] or list[dict[str, str]]"

            prompt = ChatCompletionPrompt(
                raw_prompt=prompt,
            )

        # Under concurrency every request needs its own independent payload.
        per_request_payload = copy.deepcopy(self.payload)
        user_create_prompt: OpenAICreateChatPrompt = prompt.to_formatted_prompt()
        per_request_payload = self._update_payload(
            per_request_payload, user_create_prompt, self.model, self.enable_pc_offload
        )
        # The request is sent synchronously (see _do_request).
        raw_response = self._do_request(per_request_payload)

        result = UserChatCompletionResult(raw_data=raw_response, prompt=user_create_prompt)
        record_sampling(
            prompt=result.prompt,
            sampled=result.get_completions(),
            model=result.raw_data.get("model"),
            usage=result.raw_data.get("usage"),
        )
        return result

    def _update_payload(
        self,
        payload: dict,
        prompt: OpenAICreateChatPrompt,
        model: Optional[str],
        enable_pc_offload: bool,
    ) -> dict:
        """Set the per-request ``messages``, ``model`` and ``enable_pc_offload`` keys.

        ``payload`` is modified in place and also returned.
        """
        payload.update({"messages": prompt})
        payload.update({"model": model})
        payload.update({"enable_pc_offload": enable_pc_offload})
        return payload

    def _do_request(self, payload: dict):
        """POST ``payload`` synchronously and return the decoded JSON body.

        Any failure is passed to ``_handle_request_error``, which always raises.
        """
        try:
            # Without an explicit timeout `requests` can block indefinitely.
            response = requests.post(self.url, headers=HEADERS, json=payload, timeout=TIMEOUT)
            response.raise_for_status()
            return json.loads(response.text)
        except Exception as err:
            self._handle_request_error(err)

    async def _async_do_request(self, payload: dict):
        """POST ``payload`` with aiohttp and return the decoded JSON body.

        Any failure is passed to ``_handle_request_error``, which always raises.
        """
        try:
            timeout = aiohttp.ClientTimeout(total=TIMEOUT)
            async with aiohttp.ClientSession(timeout=timeout, connector=aiohttp.TCPConnector(ssl=False)) as session:
                async with session.post(self.url, headers=HEADERS, json=payload) as response:
                    response.raise_for_status()
                    # `json` is a coroutine method: it must be called, then awaited.
                    return await response.json()
        except Exception as err:
            self._handle_request_error(err)

    def _handle_request_error(self, err: Exception) -> None:
        """Log ``err`` and re-raise it as a clearer exception; never returns normally.

        Raises:
            TimeoutError: for requests/aiohttp/asyncio timeouts.
            ConnectionError: when the server cannot be reached.
            Exception: for HTTP error statuses and any other failure.
        """
        # Timeouts are tested first: requests' ConnectTimeout and aiohttp's
        # ServerTimeoutError also derive from the connection-error classes and
        # would otherwise be reported as connection failures.
        if isinstance(err, (requests.exceptions.Timeout, aiohttp.ServerTimeoutError, asyncio.TimeoutError)):
            logger.error("请求超时,检查服务端状态")
            raise TimeoutError("请求超时,检查服务端状态") from err
        elif isinstance(err, (requests.exceptions.ConnectionError, aiohttp.ClientConnectionError)):
            logger.error(f"无法连接到服务器{self.api_base}, 检查网络是否可达")
            raise ConnectionError(f"无法连接到服务器{self.api_base}, 检查网络是否可达") from err
        elif isinstance(err, (requests.exceptions.HTTPError, aiohttp.ClientResponseError)):
            # requests keeps the status on `err.response`; aiohttp on `err.status`.
            response = getattr(err, "response", None)
            status_code = response.status_code if response is not None else getattr(err, "status", None)
            if status_code == 404:
                logger.error("请求资源不存在或是model名称错误")
            else:
                logger.error(f"HTTP错误, 状态码: {status_code}")
            raise Exception(f"HTTP错误, 状态码: {status_code}") from err
        else:
            logger.error(f"其他未知错误: {err}")
            raise Exception(f"其他未知错误: {err}") from err
evals.elsuite.modelgraded.base import ModelGradedSpec from evals.utils.misc import make_object -client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) +# client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) +client = OpenAI(api_key="***************") logger = logging.getLogger(__name__) @@ -129,6 +130,9 @@ def make_completion_fn( """ if name == "dummy": return DummyCompletionFn() + + if kwargs.get("api_base") is not None: + return UserChatCompletionFn(model=name, **kwargs) n_ctx = n_ctx_from_model_name(name)