-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbatch_analyse_responses.py
More file actions
116 lines (97 loc) · 3.88 KB
/
batch_analyse_responses.py
File metadata and controls
116 lines (97 loc) · 3.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import json
import polars as pl
from dotenv import load_dotenv
from openai import OpenAI
import multicultural_alignment.data
import multicultural_alignment.fileio as fileio
import multicultural_alignment.openai_batch as openai_batch
from multicultural_alignment.constants import OUTPUT_DIR
from multicultural_alignment.models import MODEL_FAMILIES
from multicultural_alignment.schemas import ProcessOpinions
from multicultural_alignment.structured import OPINION_SYSTEM_MSG
# Fail fast at import time if the output directory is missing. An explicit
# raise is used instead of ``assert`` because asserts are stripped under ``python -O``.
if not OUTPUT_DIR.exists():
    raise FileNotFoundError(f"Data directory {OUTPUT_DIR} does not exist")

# ``override=True``: values from the .env file replace variables already set in the environment.
load_dotenv(override=True)

# Choices offered for --model-families: every known family except "Baseline".
# Filtering (rather than list.remove) does not crash if "Baseline" is absent.
MODEL_FAMILY_NAMES = [name for name in MODEL_FAMILIES if name != "Baseline"]
def create_messages(response_row: dict, system_msg: str = OPINION_SYSTEM_MSG) -> list[dict]:
    """Wrap one response row as a two-message chat conversation.

    The system message carries ``system_msg``; the user message carries
    ``response_row`` serialised as JSON with non-ASCII characters kept verbatim.

    Args:
        response_row: JSON-serialisable mapping describing a single response.
        system_msg: Instruction text for the system role.

    Returns:
        ``[system_message, user_message]``, each a ``{"role", "content"}`` dict.
    """
    system_entry = {"role": "system", "content": system_msg}
    # ensure_ascii=False keeps non-English text readable rather than \u-escaped.
    user_payload = json.dumps(response_row, ensure_ascii=False)
    user_entry = {"role": "user", "content": user_payload}
    return [system_entry, user_entry]
def create_batch(families: list[str]):
    """Submit an opinion-analysis batch for the given model families.

    Loads the combined response data for ``families`` and passes it to
    ``run_batch``. The output file name and the ``"model_families"`` metadata
    entry are both derived from the family string, which lets
    ``download_batch`` locate this batch later.
    """
    data_module = multicultural_alignment.data
    responses_df = data_module.get_family_data(families=families)
    openai_client = OpenAI()
    families_label = data_module.get_family_string(families)
    batch_filename = f"{families_label}_responses_batch.jsonl"
    run_batch(
        openai_client,
        responses_df,
        output_name=batch_filename,
        additional_metadata={"model_families": families_label},
    )
def run_batch(
    client,
    combined_data: pl.DataFrame,
    output_name: str = "all_responses_batch.jsonl",
    additional_metadata: dict | None = None,
    model: str = "gpt-4.1",
):
    """Build opinion-analysis requests from ``combined_data`` and submit them as one batch.

    Each row's ``response`` is joined to the stance labels on ``question_key``.
    The join is inner, so responses whose key has no stance label are dropped.
    The key column is then removed, and the remaining fields of each row are
    JSON-encoded into the user message of a system + user chat request
    (see ``create_messages``).

    Args:
        client: OpenAI client forwarded to ``openai_batch.batch_from_messages``.
        combined_data: Frame with at least ``question_key`` and ``response`` columns.
        output_name: File name joined onto ``OUTPUT_DIR`` and passed as ``output_path``.
        additional_metadata: Extra batch metadata. It is merged after the defaults,
            so its keys override ``description``/``task``.
        model: Model used for every request. Defaults to ``"gpt-4.1"``, the
            previously hard-coded value.

    Returns:
        The value returned by ``openai_batch.batch_from_messages``.

    Raises:
        ValueError: If no rows remain after joining with the stance labels.
    """
    stances = multicultural_alignment.data.get_stance_labels()
    responses = (
        combined_data.select(["question_key", "response"])
        .join(stances, on="question_key")
        # The key was only needed for the join; drop it so it is not sent to the model.
        .drop("question_key")
        .to_dicts()
    )
    if not responses:
        # Don't submit an empty batch (e.g. when no question_key matched a stance label).
        raise ValueError("No responses to analyse after joining with stance labels")

    all_messages = [create_messages(response) for response in responses]
    all_requests = openai_batch.create_requests_format(
        all_messages, tool=ProcessOpinions, id_prefix="analyze-opinions", model=model
    )
    metadata = {
        "description": "Analyze opinions in responses",
        "task": "analyze",
        **(additional_metadata or {}),
    }
    return openai_batch.batch_from_messages(
        client,
        all_requests,
        output_path=OUTPUT_DIR / output_name,
        metadata=metadata,
    )
def download_batch(model_families: list[str]):
    """Download the results of the opinion-analysis batch for ``model_families``.

    Finds a batch whose metadata has ``task == "analyze"`` and a ``model_families``
    entry equal to the family string (the metadata set by ``run_batch`` /
    ``create_batch``). The downloaded records are written to
    ``OUTPUT_DIR / "<family string>-raw-responses.jsonl"``. The first match in the
    API's listing order is used (presumably the most recent batch; confirm).

    Raises:
        ValueError: If no matching batch exists, or the matching batch has not completed.
    """
    model_families_str = multicultural_alignment.data.get_family_string(model_families)
    client = OpenAI()

    def is_analysis_batch(batch) -> bool:
        # ``metadata`` is optional on batch objects and may be None.
        metadata = batch.metadata or {}
        return metadata.get("task") == "analyze" and metadata.get("model_families") == model_families_str

    # Iterate the page object itself rather than ``.data``: the SDK then fetches further
    # pages on demand, so batches beyond the first page are still considered.
    analysis_batch = next((batch for batch in client.batches.list() if is_analysis_batch(batch)), None)
    if analysis_batch is None:
        raise ValueError(f"No analysis batch found for model families {model_families_str!r}")
    if analysis_batch.status != "completed":
        raise ValueError(
            f"Batch is not completed: Status is {analysis_batch.status}. Current progress is: {analysis_batch.request_counts}"  # noqa: E501
        )
    downloaded = openai_batch.download_batch(client, batch=analysis_batch)
    fileio.write_jsonl(downloaded, OUTPUT_DIR / f"{model_families_str}-raw-responses.jsonl")
def main(args: argparse.Namespace):
if args.mode == "run":
create_batch(families=args.model_families)
elif args.mode == "download":
download_batch(model_families=args.model_families)
else:
raise ValueError("Invalid mode")
if __name__ == "__main__":
    # Command-line entry point: parse the mode and model families, then dispatch via main().
    parser = argparse.ArgumentParser(description="Batch process responses for analysis")
    # NOTE(review): this option is not required, so args.model_families is None when it is
    # omitted; confirm get_family_data / get_family_string accept None.
    parser.add_argument(
        "--model-families",
        "-m",
        type=str,
        nargs="+",
        help="Model families for which to run batch processing",
        choices=MODEL_FAMILY_NAMES,
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="run",
        choices=["run", "download"],
        help="Whether to run the batch or download the results",
    )
    args = parser.parse_args()
    main(args=args)