-
Notifications
You must be signed in to change notification settings - Fork 54
Expand file tree
/
Copy pathhydra_utils.py
More file actions
44 lines (36 loc) · 1.21 KB
/
hydra_utils.py
File metadata and controls
44 lines (36 loc) · 1.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import datasets
import trl
import torch
import numpy
import transformers
import subprocess
def spawn_bash_command(bash_command):
print(f"Executing {bash_command}")
try:
subprocess.run(bash_command, shell=True, check=True)
except subprocess.CalledProcessError as err:
print("error during execution:", err)
except Exception as err:
print("unexpected error:", err)
def clean_after_training_completion(output_dir):
bash_command = f"./scripts/clean_past_steps.sh {output_dir} all"
spawn_bash_command(bash_command)
def fix_pad_token(tokenizer, model_name, unsafe=False):
if tokenizer.pad_token is None:
if "Llama" in model_name:
tokenizer.pad_token = "<|reserved_special_token_5|>"
elif "Qwen" in model_name or 'Bespoke' in model_name:
tokenizer.pad_token = "<|fim_pad|>"
else:
raise NotImplementedError
else:
if not unsafe:
assert tokenizer.pad_token_id != tokenizer.eos_token_id
return tokenizer
def wrap_as_list(*args, **kwargs):
to_return = []
for element in args:
to_return.append(element)
for element in kwargs.values():
to_return.append(element)
return to_return