Skip to content

Commit 9be5e30

Browse files
elronbandeldafnapension
authored andcommitted
Added test and fix
Signed-off-by: elronbandel <elronbandel@gmail.com>
1 parent 0e79158 commit 9be5e30

6 files changed

Lines changed: 156 additions & 26 deletions

File tree

src/unitxt/artifact.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,14 @@ class Artifact(Dataclass):
446446
default=None, required=False, also_positional=False
447447
)
448448

449+
def __init_subclass__(cls, **kwargs):
450+
super().__init_subclass__(**kwargs)
451+
module = inspect.getmodule(cls)
452+
# standardize module name
453+
module_name = getattr(module, "__name__", None)
454+
if not is_library_module(module_name):
455+
cls.register_class()
456+
449457
@classmethod
450458
def is_possible_identifier(cls, obj):
451459
return isinstance(obj, str) or is_artifact_dict(obj)
@@ -458,18 +466,15 @@ def get_artifact_type(cls):
458466
if not is_library_module(module_name):
459467
non_library_module_warning = f"module named {module_name} is not importable. Class {cls} is thus registered into Artifact.class_register, indexed by {cls.__name__}, accessible there as long as this class_register lives."
460468
warnings.warn(non_library_module_warning, ImportWarning, stacklevel=2)
461-
cls.register_class(cls)
469+
cls.register_class()
462470
return {"module": "class_register", "name": cls.__name__}
463471
if hasattr(cls, "__qualname__") and "." in cls.__qualname__:
464472
return {"module": module_name, "name": cls.__qualname__}
465473
return {"module": module_name, "name": cls.__name__}
466474

467475
@classmethod
468-
def register_class(cls, artifact_class):
469-
Artifact._class_register[artifact_class.__name__] = artifact_class
470-
471-
def __init_subclass__(cls, **kwargs):
472-
super().__init_subclass__(**kwargs)
476+
def register_class(cls):
477+
Artifact._class_register[cls.__name__] = cls
473478

474479
@classmethod
475480
def is_artifact_file(cls, path):
@@ -603,7 +608,7 @@ def maybe_fix_type_to_ensure_instantiation_ability(self):
603608
not is_library_module(self.__type__["module"])
604609
or "<locals>" in self.__type__["name"]
605610
):
606-
self.__class__.register_class(self.__class__)
611+
self.__class__.register_class()
607612
self.__type__ = {
608613
"module": "class_register",
609614
"name": self.__class__.__name__,

src/unitxt/catalog/cards/tot/arithmetic.json

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,47 @@
11
{
2-
"__type__": "task_card",
2+
"__type__": {
3+
"module": "unitxt.card",
4+
"name": "TaskCard"
5+
},
36
"loader": {
4-
"__type__": "load_hf",
7+
"__type__": {
8+
"module": "unitxt.loaders",
9+
"name": "LoadHF"
10+
},
511
"path": "baharef/ToT",
612
"name": "tot_arithmetic"
713
},
814
"preprocess_steps": [
915
{
10-
"__type__": "replace",
16+
"__type__": {
17+
"module": "unitxt.string_operators",
18+
"name": "Replace"
19+
},
1120
"field": "label",
1221
"old": "'",
1322
"new": "\""
1423
},
1524
{
16-
"__type__": "load_json",
25+
"__type__": {
26+
"module": "unitxt.struct_data_operators",
27+
"name": "LoadJson"
28+
},
1729
"field": "label"
1830
},
1931
{
20-
"__type__": "copy",
32+
"__type__": {
33+
"module": "unitxt.operators",
34+
"name": "Copy"
35+
},
2136
"field": "label/answer",
2237
"to_field": "label"
2338
}
2439
],
2540
"task": {
26-
"__type__": "task",
41+
"__type__": {
42+
"module": "unitxt.task",
43+
"name": "Task"
44+
},
2745
"input_fields": {
2846
"question": "str"
2947
},
@@ -37,14 +55,23 @@
3755
},
3856
"templates": [
3957
{
40-
"__type__": "input_output_template",
58+
"__type__": {
59+
"module": "unitxt.templates",
60+
"name": "InputOutputTemplate"
61+
},
4162
"input_format": "{question}",
4263
"output_format": "{{\"answer\": \"{label}\"}}",
4364
"postprocessors": [
4465
{
45-
"__type__": "post_process",
66+
"__type__": {
67+
"module": "unitxt.processors",
68+
"name": "PostProcess"
69+
},
4670
"operator": {
47-
"__type__": "extract_with_regex",
71+
"__type__": {
72+
"module": "unitxt.processors",
73+
"name": "ExtractWithRegex"
74+
},
4875
"regex": "\"answer\"\\s*:\\s*\"((?:[^\"\\\\]|\\\\.)*)\""
4976
}
5077
}

src/unitxt/catalog/cards/tot/semantic.json

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
{
2-
"__type__": "task_card",
2+
"__type__": {
3+
"module": "unitxt.card",
4+
"name": "TaskCard"
5+
},
36
"loader": {
4-
"__type__": "load_hf",
7+
"__type__": {
8+
"module": "unitxt.loaders",
9+
"name": "LoadHF"
10+
},
511
"path": "baharef/ToT",
612
"name": "tot_semantic"
713
},
814
"task": {
9-
"__type__": "task",
15+
"__type__": {
16+
"module": "unitxt.task",
17+
"name": "Task"
18+
},
1019
"input_fields": {
1120
"prompt": "str",
1221
"question": "str"
@@ -21,14 +30,23 @@
2130
},
2231
"templates": [
2332
{
24-
"__type__": "input_output_template",
33+
"__type__": {
34+
"module": "unitxt.templates",
35+
"name": "InputOutputTemplate"
36+
},
2537
"input_format": "{prompt}",
2638
"output_format": "{{\"answer\": \"{label}\"}}",
2739
"postprocessors": [
2840
{
29-
"__type__": "post_process",
41+
"__type__": {
42+
"module": "unitxt.processors",
43+
"name": "PostProcess"
44+
},
3045
"operator": {
31-
"__type__": "extract_with_regex",
46+
"__type__": {
47+
"module": "unitxt.processors",
48+
"name": "ExtractWithRegex"
49+
},
3250
"regex": "\"answer\"\\s*:\\s*\"((?:[^\"\\\\]|\\\\.)*)\""
3351
}
3452
}

src/unitxt/deprecation_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def wrapper(*args, **kwargs):
8080
and issubclass(obj, Artifact)
8181
and obj is not Artifact
8282
):
83-
obj.register_class(obj)
83+
obj.register_class()
8484
elif constants.version >= version:
8585
raise DeprecationError(f"{obj.__name__} is no longer supported.{alt_text}")
8686
return obj(*args, **kwargs)

tests/library/test_artifact_recovery.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
import subprocess
2+
import sys
3+
import tempfile
4+
import textwrap
5+
from pathlib import Path
6+
17
from unitxt.artifact import (
28
Artifact,
39
MissingArtifactTypeError,
@@ -12,6 +18,80 @@
1218

1319

1420
class TestArtifactRecovery(UnitxtTestCase):
21+
def test_custom_catalog_and_project(self):
22+
with tempfile.TemporaryDirectory() as tmpdirname:
23+
project_dir = Path(tmpdirname)
24+
operator_dir = project_dir / "operators"
25+
catalog_dir = project_dir / "catalog"
26+
operator_dir.mkdir()
27+
28+
# Write the operator class
29+
operator_code = textwrap.dedent(
30+
"""
31+
from unitxt.operators import InstanceOperator
32+
33+
class MyTempOperator(InstanceOperator):
34+
def process(self, instance, stream_name=None):
35+
return instance
36+
"""
37+
)
38+
(operator_dir / "my_operator.py").write_text(operator_code)
39+
(operator_dir / "__init__.py").write_text("")
40+
41+
# Write the saving script
42+
saving_code = textwrap.dedent(
43+
f"""
44+
from operators.my_operator import MyTempOperator
45+
from unitxt import add_to_catalog, settings
46+
47+
add_to_catalog(MyTempOperator(), "operators.my_temp_operator", catalog_path="{catalog_dir}")
48+
"""
49+
)
50+
saving_script = project_dir / "save_operator.py"
51+
saving_script.write_text(saving_code)
52+
53+
# Write the loading script
54+
loading_code = textwrap.dedent(
55+
"""
56+
from unitxt import get_from_catalog
57+
from operators.my_operator import MyTempOperator
58+
59+
get_from_catalog("operators.my_temp_operator")
60+
"""
61+
)
62+
loading_script = project_dir / "load_operator.py"
63+
loading_script.write_text(loading_code)
64+
65+
# Run the saving script
66+
result_save = subprocess.run(
67+
[sys.executable, str(saving_script)],
68+
env={
69+
"UNITXT_CATALOGS": str(catalog_dir),
70+
"PYTHONPATH": str(project_dir),
71+
},
72+
capture_output=True,
73+
text=True,
74+
)
75+
if result_save.returncode != 0:
76+
logger.info(f"Saving script STDOUT:\n{result_save.stdout}")
77+
logger.info(f"Saving script STDERR:\n{result_save.stderr}")
78+
self.assertEqual(result_save.returncode, 0, "Saving script failed")
79+
80+
# Run the loading script
81+
result_load = subprocess.run(
82+
[sys.executable, str(loading_script)],
83+
env={
84+
"UNITXT_CATALOGS": str(catalog_dir),
85+
"PYTHONPATH": str(project_dir),
86+
},
87+
capture_output=True,
88+
text=True,
89+
)
90+
if result_load.returncode != 0:
91+
logger.info(f"Loading script STDOUT:\n{result_load.stdout}")
92+
logger.info(f"Loading script STDERR:\n{result_load.stderr}")
93+
self.assertEqual(result_load.returncode, 0, "Loading script failed")
94+
1595
def test_correct_artifact_recovery(self):
1696
args = {
1797
"__type__": {"module": "unitxt.standard", "name": "DatasetRecipe"},

utils/prepare_all_artifacts.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,9 @@ def main():
8585
)
8686
if final_number_of_catalog_entries <= initial_number_of_catalog_entries:
8787
error_msg = f"all the following {len(prepare_files)} prepare files fail forever: {prepare_files}. "
88-
"One potential reason is that at least one of them contains add_link_to_catalog of an ArtifactLink "
89-
"that links to an artifact that is added to the catalog only down that prepare_file. "
90-
"To fix this: swap the order: first add_to_catalog the artifact linked to, and then add_link_to_catalog."
88+
"One potential reason is a circular dependency among them, another is that at least one of them contains add_link_to_catalog "
89+
"of an ArtifactLink that links to an artifact that is added to the catalog only down that prepare_file. "
90+
"To fix: resolve dependency, or swap the order: first add_to_catalog the artifact linked to, and then add_link_to_catalog."
9191
raise RuntimeError(error_msg)
9292
prepare_files = failing_prepare_files
9393
failing_prepare_files = []

0 commit comments

Comments
 (0)