|
| 1 | +import subprocess |
| 2 | +import sys |
| 3 | +import tempfile |
| 4 | +import textwrap |
| 5 | +from pathlib import Path |
| 6 | + |
1 | 7 | from unitxt.artifact import ( |
2 | 8 | Artifact, |
3 | 9 | MissingArtifactTypeError, |
|
12 | 18 |
|
13 | 19 |
|
14 | 20 | class TestArtifactRecovery(UnitxtTestCase): |
| 21 | + def test_custom_catalog_and_project(self): |
| 22 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 23 | + project_dir = Path(tmpdirname) |
| 24 | + operator_dir = project_dir / "operators" |
| 25 | + catalog_dir = project_dir / "catalog" |
| 26 | + operator_dir.mkdir() |
| 27 | + |
| 28 | + # Write the operator class |
| 29 | + operator_code = textwrap.dedent( |
| 30 | + """ |
| 31 | + from unitxt.operators import InstanceOperator |
| 32 | +
|
| 33 | + class MyTempOperator(InstanceOperator): |
| 34 | + def process(self, instance, stream_name=None): |
| 35 | + return instance |
| 36 | + """ |
| 37 | + ) |
| 38 | + (operator_dir / "my_operator.py").write_text(operator_code) |
| 39 | + (operator_dir / "__init__.py").write_text("") |
| 40 | + |
| 41 | + # Write the saving script |
| 42 | + saving_code = textwrap.dedent( |
| 43 | + f""" |
| 44 | + from operators.my_operator import MyTempOperator |
| 45 | + from unitxt import add_to_catalog, settings |
| 46 | +
|
| 47 | + add_to_catalog(MyTempOperator(), "operators.my_temp_operator", catalog_path="{catalog_dir}") |
| 48 | + """ |
| 49 | + ) |
| 50 | + saving_script = project_dir / "save_operator.py" |
| 51 | + saving_script.write_text(saving_code) |
| 52 | + |
| 53 | + # Write the loading script |
| 54 | + loading_code = textwrap.dedent( |
| 55 | + """ |
| 56 | + from unitxt import get_from_catalog |
| 57 | + from operators.my_operator import MyTempOperator |
| 58 | +
|
| 59 | + get_from_catalog("operators.my_temp_operator") |
| 60 | + """ |
| 61 | + ) |
| 62 | + loading_script = project_dir / "load_operator.py" |
| 63 | + loading_script.write_text(loading_code) |
| 64 | + |
| 65 | + # Run the saving script |
| 66 | + result_save = subprocess.run( |
| 67 | + [sys.executable, str(saving_script)], |
| 68 | + env={ |
| 69 | + "UNITXT_CATALOGS": str(catalog_dir), |
| 70 | + "PYTHONPATH": str(project_dir), |
| 71 | + }, |
| 72 | + capture_output=True, |
| 73 | + text=True, |
| 74 | + ) |
| 75 | + if result_save.returncode != 0: |
| 76 | + logger.info(f"Saving script STDOUT:\n{result_save.stdout}") |
| 77 | + logger.info(f"Saving script STDERR:\n{result_save.stderr}") |
| 78 | + self.assertEqual(result_save.returncode, 0, "Saving script failed") |
| 79 | + |
| 80 | + # Run the loading script |
| 81 | + result_load = subprocess.run( |
| 82 | + [sys.executable, str(loading_script)], |
| 83 | + env={ |
| 84 | + "UNITXT_CATALOGS": str(catalog_dir), |
| 85 | + "PYTHONPATH": str(project_dir), |
| 86 | + }, |
| 87 | + capture_output=True, |
| 88 | + text=True, |
| 89 | + ) |
| 90 | + if result_load.returncode != 0: |
| 91 | + logger.info(f"Loading script STDOUT:\n{result_load.stdout}") |
| 92 | + logger.info(f"Loading script STDERR:\n{result_load.stderr}") |
| 93 | + self.assertEqual(result_load.returncode, 0, "Loading script failed") |
| 94 | + |
15 | 95 | def test_correct_artifact_recovery(self): |
16 | 96 | args = { |
17 | 97 | "__type__": {"module": "unitxt.standard", "name": "DatasetRecipe"}, |
|
0 commit comments