Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
99 changes: 99 additions & 0 deletions tests/configuration/test_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch

import yaml

sys.path.append(str(Path(__file__).parent.parent.parent))

from sygra.configuration.loader import ConfigLoader


class TestConfigLoaderLoad(unittest.TestCase):
    """Tests for ``ConfigLoader.load`` with dict, string-path and ``Path`` inputs."""

    def _write_yaml(self, data):
        """Dump *data* as YAML into a temporary file and return its ``Path``.

        The file is created inside a per-test temporary directory whose
        removal is registered with ``addCleanup``, so nothing is left behind
        on disk regardless of whether the test passes or fails.
        """
        tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(tmpdir.cleanup)
        config_path = Path(tmpdir.name) / "config.yaml"
        with config_path.open("w") as f:
            yaml.dump(data, f)
        return config_path

    def test_load_returns_dict_unchanged_when_given_dict(self):
        loader = ConfigLoader()
        config = {"task_name": "test", "nodes": {}}
        result = loader.load(config)
        self.assertEqual(result, config)

    def test_load_reads_yaml_file(self):
        config_data = {"task_name": "my_task", "nodes": {"n1": {"node_type": "llm"}}}
        # Pass the path as a plain string to exercise the str code path.
        tmp_path = str(self._write_yaml(config_data))

        loader = ConfigLoader()
        result = loader.load(tmp_path)
        self.assertEqual(result["task_name"], "my_task")

    def test_load_raises_file_not_found_for_missing_file(self):
        loader = ConfigLoader()
        with self.assertRaises(FileNotFoundError):
            loader.load("/nonexistent/path/config.yaml")

    def test_load_raises_type_error_for_non_dict_yaml(self):
        # A top-level YAML list is valid YAML but not a valid config mapping.
        tmp_path = str(self._write_yaml(["item1", "item2"]))

        loader = ConfigLoader()
        with self.assertRaises(TypeError):
            loader.load(tmp_path)

    def test_load_accepts_path_object(self):
        config_data = {"task_name": "path_test"}
        # Keep the ``Path`` object as-is to exercise the pathlib code path.
        tmp_path = self._write_yaml(config_data)

        loader = ConfigLoader()
        result = loader.load(tmp_path)
        self.assertEqual(result["task_name"], "path_test")


class TestConfigLoaderLoadAndCreate(unittest.TestCase):
    """Tests for ``ConfigLoader.load_and_create`` (feature flags and naming)."""

    def _write_yaml(self, data):
        """Dump *data* as YAML into a temporary file and return its path string.

        The file is created inside a per-test temporary directory whose
        removal is registered with ``addCleanup``, so nothing is left behind
        on disk regardless of whether the test passes or fails.
        """
        tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(tmpdir.cleanup)
        config_path = Path(tmpdir.name) / "config.yaml"
        with config_path.open("w") as f:
            yaml.dump(data, f)
        return str(config_path)

    def test_load_and_create_returns_workflow_with_correct_flags(self):
        config_data = {"task_name": "my_workflow", "nodes": {}}
        tmp_path = self._write_yaml(config_data)

        loader = ConfigLoader()
        workflow = loader.load_and_create(tmp_path)

        self.assertTrue(workflow._supports_subgraphs)
        self.assertTrue(workflow._supports_multimodal)
        self.assertTrue(workflow._supports_resumable)
        self.assertTrue(workflow._supports_quality)
        self.assertTrue(workflow._supports_oasst)

    def test_load_and_create_sets_name_from_parent_directory(self):
        config_data = {"task_name": "my_workflow"}
        with tempfile.TemporaryDirectory() as tmpdir:
            # The directory name, not ``task_name``, is the expected workflow name.
            task_dir = Path(tmpdir) / "my_task_name"
            task_dir.mkdir()
            config_file = task_dir / "graph_config.yaml"
            config_file.write_text(yaml.dump(config_data))

            loader = ConfigLoader()
            workflow = loader.load_and_create(str(config_file))

        self.assertEqual(workflow.name, "my_task_name")

    def test_load_and_create_with_dict_sets_name_from_task_name(self):
        config = {"task_name": "dict_task", "nodes": {}}
        loader = ConfigLoader()
        workflow = loader.load_and_create(config)
        self.assertEqual(workflow.name, "dict_task")

    def test_load_and_create_with_dict_without_task_name_uses_default(self):
        config = {"nodes": {}}
        loader = ConfigLoader()
        workflow = loader.load_and_create(config)
        self.assertEqual(workflow.name, "loaded_workflow")


# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
139 changes: 139 additions & 0 deletions tests/core/dataset/test_dataset_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import sys
import unittest
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

sys.path.append(str(Path(__file__).parent.parent.parent.parent))


def _make_processor(**kwargs):
    """Build a ``DatasetProcessor`` for tests with its heavy collaborators patched.

    Keyword arguments override the default constructor arguments below.
    Whatever object ends up as ``graph_config`` gets a fixed ``config`` dict
    and ``oasst_mapper = None`` assigned before construction. ``tqdm`` and
    ``ResumableExecutionManager`` are patched in the processor module only
    for the duration of the constructor call.
    """
    ctor_args = {
        "input_dataset": [{"id": "1"}],
        "graph": MagicMock(),
        "graph_config": MagicMock(),
        "output_file": "/tmp/tasks/my_task/output.jsonl",
        "num_records_total": 10,
        "batch_size": 10,
        "checkpoint_interval": 10,
    }
    ctor_args.update(kwargs)

    cfg = ctor_args["graph_config"]
    cfg.config = {"task_name": "test_task"}
    cfg.oasst_mapper = None

    with patch("sygra.core.dataset.dataset_processor.tqdm") as tqdm_mock:
        with patch("sygra.core.dataset.dataset_processor.ResumableExecutionManager"):
            tqdm_mock.tqdm.return_value = MagicMock()
            # Imported here, inside the patched context, as the tests do elsewhere.
            from sygra.core.dataset.dataset_processor import DatasetProcessor

            return DatasetProcessor(**ctor_args)


class TestDatasetProcessorInit(unittest.TestCase):
    """Constructor validation of the batch size / checkpoint interval pair."""

    def test_raises_when_checkpoint_not_multiple_of_batch(self):
        # 100 is not a multiple of 30, so construction is expected to fail.
        # Built through the shared helper so the patching and default
        # arguments stay in one place, matching the success-path test below.
        with self.assertRaises(AssertionError):
            _make_processor(batch_size=30, checkpoint_interval=100)

    def test_valid_checkpoint_multiple_of_batch(self):
        proc = _make_processor(batch_size=10, checkpoint_interval=100)
        self.assertEqual(proc.batch_size, 10)
        self.assertEqual(proc.checkpoint_interval, 100)


class TestDetermineDatasetType(unittest.TestCase):
    """Checks the label ``DatasetProcessor._determine_dataset_type`` returns."""

    def setUp(self):
        # Imported per test rather than at module level, as in the rest of this file.
        from sygra.core.dataset.dataset_processor import DatasetProcessor

        self.DatasetProcessor = DatasetProcessor

    def test_list_returns_in_memory(self):
        self.assertEqual(
            self.DatasetProcessor._determine_dataset_type([{"a": 1}]),
            "in_memory",
        )

    def test_streaming_attribute_true_returns_streaming(self):
        streaming_ds = MagicMock()
        streaming_ds.is_streaming = True
        self.assertEqual(
            self.DatasetProcessor._determine_dataset_type(streaming_ds),
            "streaming",
        )

    def test_iterable_dataset_returns_streaming(self):
        import datasets

        # ``spec`` makes the mock pass ``isinstance`` checks for IterableDataset.
        iterable_ds = MagicMock(spec=datasets.IterableDataset)
        self.assertEqual(
            self.DatasetProcessor._determine_dataset_type(iterable_ds),
            "streaming",
        )

    def test_default_returns_in_memory(self):
        plain_ds = MagicMock(spec=object)
        # Ensure the mock exposes no ``is_streaming`` attribute at all.
        del plain_ds.is_streaming
        self.assertEqual(
            self.DatasetProcessor._determine_dataset_type(plain_ds),
            "in_memory",
        )


class TestExtractTaskName(unittest.TestCase):
    """Checks ``_extract_task_name`` against different output file paths."""

    def test_extracts_from_tasks_path(self):
        processor = _make_processor(output_file="/data/tasks/my_task/output.jsonl")
        self.assertEqual(processor._extract_task_name(), "my_task")

    def test_fallback_when_no_tasks_segment(self):
        # No "tasks" directory in the path: only the "task_" marker is checked.
        processor = _make_processor(output_file="/data/output/result.jsonl")
        self.assertIn("task_", processor._extract_task_name())


class TestIsErrorCodeInOutput(unittest.TestCase):
    """Checks ``DatasetProcessor.is_error_code_in_output`` on sample output dicts."""

    def setUp(self):
        # Imported per test rather than at module level, as in the rest of this file.
        from sygra.core.dataset.dataset_processor import DatasetProcessor

        self.DatasetProcessor = DatasetProcessor

    def test_returns_true_when_error_prefix_found(self):
        failing = {"key": "###SERVER_ERROR### something bad happened"}
        found = self.DatasetProcessor.is_error_code_in_output(failing)
        self.assertTrue(found)

    def test_returns_false_when_no_error_prefix(self):
        clean = {"key": "all good", "count": 5}
        found = self.DatasetProcessor.is_error_code_in_output(clean)
        self.assertFalse(found)

    def test_returns_false_for_non_string_values(self):
        non_strings = {"count": 42, "data": [1, 2, 3]}
        found = self.DatasetProcessor.is_error_code_in_output(non_strings)
        self.assertFalse(found)

    def test_returns_false_for_empty_output(self):
        found = self.DatasetProcessor.is_error_code_in_output({})
        self.assertFalse(found)


class TestGetRecord(unittest.TestCase):
    """Checks ``_get_record`` id handling and exhaustion with resumability off."""

    def test_assigns_uuid_when_record_has_no_id(self):
        processor = _make_processor(input_dataset=[{"value": "hello"}])
        processor.resumable = False
        fetched = processor._get_record()
        self.assertIn("id", fetched)
        # The generated id must be non-empty.
        self.assertGreater(len(fetched["id"]), 0)

    def test_keeps_existing_id(self):
        processor = _make_processor(
            input_dataset=[{"id": "existing-id", "value": "hello"}]
        )
        processor.resumable = False
        fetched = processor._get_record()
        self.assertEqual(fetched["id"], "existing-id")

    def test_raises_stop_iteration_when_exhausted(self):
        processor = _make_processor(input_dataset=[])
        processor.resumable = False
        with self.assertRaises(StopIteration):
            processor._get_record()


# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
Loading
Loading