-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfilters.py
More file actions
92 lines (86 loc) · 2.91 KB
/
filters.py
File metadata and controls
92 lines (86 loc) · 2.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
This package contains methods to filter the dataset(s) for valid
intervention indices. Typically the filters will take in a
pandas row and operate on the values in the row.
filter: python_function(df_row, info)
a python function that takes a dataframe row and an info
dict and returns filtered indices (indices that we would
like to sample from)
"""
default_info = {
"pad_token": "<PAD>",
"bos_token": "<BOS>",
"eos_token": "<EOS>",
"pad_token_id": 0,
"bos_token_id": 1,
"eos_token_id": 2,
}
def default_filter(df_row, info=None, excl_varb_vals=None):
    """
    Return True if the row's input token should be kept for sampling.

    A row is kept only when BOTH of the following hold:

    * its ``inpt_token_id`` is not a special token (pad, bos, eos) and not
      one of the trigger tokens, and
    * for every variable in ``excl_varb_vals`` that is present in the row,
      the row's value (cast with ``int``) is not in that variable's
      exclusion set.

    Args:
        df_row: A row from a pandas DataFrame (e.g. a ``pandas.Series``).
            Must expose ``inpt_token_id`` as an attribute and support
            ``key in df_row`` / ``df_row[key]`` for the exclusion variables.
        info: A dictionary containing token information (optional). Reads
            ``pad_token_id`` (default 0), ``bos_token_id`` (default 1),
            ``eos_token_id`` (default 2) and ``trig_token_ids`` (default
            ``[7]``). Falls back to ``default_info`` when None.
        excl_varb_vals: A dictionary of variable values to exclude (optional).
            keys: variable names
            values: set of values to exclude for that variable
            example: {"count": {2, 4, 9, 14, 17}}
            Variables missing from the row are ignored. None disables
            variable-based exclusion.

    Returns:
        bool: True to keep the row, False to filter it out.
    """
    if info is None:
        info = default_info
    bad_token_ids = {
        info.get("pad_token_id", 0),
        info.get("bos_token_id", 1),
        info.get("eos_token_id", 2),
        *info.get("trig_token_ids", [7]),
    }
    # Reject special/trigger tokens first so no variable value is cast below.
    if df_row.inpt_token_id in bad_token_ids:
        return False
    if excl_varb_vals is not None:
        for key, excluded in excl_varb_vals.items():
            # int() lets float-typed values such as 4.0 match integer sets.
            if key in df_row and int(df_row[key]) in excluded:
                return False
    return True
# Default exclusion sets: used by excl_varb_vals_filter and, as its default
# keep set, by keep_varb_vals_filter.
default_excl_vals = {
    "count": {2, 4, 9, 14, 17},
}


def excl_varb_vals_filter(
    df_row,
    excl_varb_vals=default_excl_vals,
    info=None,
):
    """
    Return True if the row is a non-special token with no excluded values.

    Applies the same keep/reject rule as ``default_filter``, but the
    exclusion sets are the second positional argument and default to
    ``default_excl_vals``.

    Args:
        df_row: A row from a pandas DataFrame (e.g. a ``pandas.Series``);
            must expose ``inpt_token_id`` as an attribute.
        excl_varb_vals: Mapping from variable name to a set of values to
            exclude. Variables missing from the row are ignored. None
            disables variable-based exclusion (as in ``default_filter``).
        info: Token-id configuration dict; falls back to ``default_info``
            when None. Reads ``pad_token_id``/``bos_token_id``/
            ``eos_token_id`` (defaults 0/1/2) and ``trig_token_ids``
            (default ``[7]``).

    Returns:
        bool: True to keep the row, False to filter it out.
    """
    if info is None:
        info = default_info
    bad_token_ids = {
        info.get("pad_token_id", 0),
        info.get("bos_token_id", 1),
        info.get("eos_token_id", 2),
        *info.get("trig_token_ids", [7]),
    }
    if df_row.inpt_token_id in bad_token_ids:
        return False
    if excl_varb_vals is None:
        return True
    for key, excluded in excl_varb_vals.items():
        # int() lets float-typed values such as 4.0 match integer sets.
        if key in df_row and int(df_row[key]) in excluded:
            return False
    return True
def keep_varb_vals_filter(
    df_row,
    keep_varb_vals=default_excl_vals,
    info=None,
):
    """
    Return True if the row is a non-special token whose values are allowed.

    A row is kept only when its ``inpt_token_id`` is not a special or
    trigger token AND, for every variable in ``keep_varb_vals`` that is
    present in the row, the row's value (cast with ``int``) is in that
    variable's allowed set.

    Note: the default keep set is ``default_excl_vals``. For the default
    single ``count`` variable, a non-special row that has a ``count`` value
    is kept here exactly when ``excl_varb_vals_filter`` rejects it.
    NOTE(review): presumably intended to select the held-out complement;
    confirm.

    Args:
        df_row: A row from a pandas DataFrame (e.g. a ``pandas.Series``);
            must expose ``inpt_token_id`` as an attribute.
        keep_varb_vals: Mapping from variable name to a set of allowed
            values. Variables missing from the row are ignored. None
            disables the variable restriction (only special tokens are
            rejected).
        info: Token-id configuration dict; falls back to ``default_info``
            when None. Reads ``pad_token_id``/``bos_token_id``/
            ``eos_token_id`` (defaults 0/1/2) and ``trig_token_ids``
            (default ``[7]``).

    Returns:
        bool: True to keep the row, False to filter it out.
    """
    if info is None:
        info = default_info
    bad_token_ids = {
        info.get("pad_token_id", 0),
        info.get("bos_token_id", 1),
        info.get("eos_token_id", 2),
        *info.get("trig_token_ids", [7]),
    }
    if df_row.inpt_token_id in bad_token_ids:
        return False
    if keep_varb_vals is None:
        return True
    for key, allowed in keep_varb_vals.items():
        # int() lets float-typed values such as 4.0 match integer sets.
        if key in df_row and int(df_row[key]) not in allowed:
            return False
    return True