Multimodal Input: from_webdataset() #560
Conversation
pdames (reviewer):
Hmm... it seems like this PR includes a lot of superfluous/untouched files in the diff that are creating unnecessary conflicts. Maybe the changes can be squashed and rebased on top of the latest from https://github.com/ray-project/deltacat/tree/2.0 then resubmitted to clean this up?
```python
from transformers import AutoImageProcessor, AutoModelForImageClassification

#tar_path = "deltacat/tests/test_utils/resources/imagenet1k-train-0000.tar"
```
Reviewer:
Remove any unused code
```python
# Create a list of dictionaries combining filename and predicted species
rows_to_write = [
    {
        "filename": fname,
        "bird_species": bird_labels[idx]
    }
    for idx, fname in enumerate(filenames)
]
```
Reviewer:
This is a weird way to do this; zip()-ing the two would be cleaner.
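For illustration, a minimal sketch of the zip()-based version the comment suggests, using the names from the snippet above (filenames and bird_labels are assumed to be parallel lists, as in the enumerate version):

```python
# Equivalent to the enumerate/index version above, but pairs the two
# sequences directly instead of indexing bird_labels by position.
rows_to_write = [
    {"filename": fname, "bird_species": species}
    for fname, species in zip(filenames, bird_labels)
]
```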
deltacat/storage/rivulet/dataset.py (outdated)
```python
with tarfile.open(file_uri, "r") as tar:
    tar_members = tar.getmembers()
    current_batch = None
    reading_frame_size = batch_size  # TODO: Use batch size 1 for now.
    total_batches = math.ceil(len(tar_members) / reading_frame_size)

    for i in range(total_batches):
        reading_frame_start = i * reading_frame_size
        reading_frame_end = reading_frame_start + reading_frame_size
        for member in tar_members[reading_frame_start:reading_frame_end]:
            # Ignore hidden files if the imported tar isn't cleaned.
            if member.name.startswith("._"):
                continue
            if member.isfile() and member.name.endswith(".json"):
                f = tar.extractfile(member)
                if f:
                    try:
                        merge_key = merge_keys

                        pyarrow_table = pyarrow.json.read_json(f)
                        image_filename = pyarrow_table[merge_key][0].as_py()

                        # truncated_filename = normalize_filename(image_filename[image_filename.index('/') + 1:])
                        truncated_filename = normalize_filename(os.path.basename(image_filename))
                        if truncated_filename in [normalize_filename(t.name) for t in tar_members]:
                            image_member = next((t for t in tar_members if t.name == truncated_filename), None)
                            if image_member:
                                fi = tar.extractfile(image_member)
                                if fi:
                                    media_binary = fi.read()
                                    media_binaries.extend([media_binary])

                        if current_batch is None:
                            current_batch = pyarrow_table
                        else:
                            current_batch = pa.concat_tables([current_batch, pyarrow_table])
                    except Exception as e:
                        print(f"Error with {member.name}:", e)

    if current_batch is not None:
        try:
            dataset_schema.merge(Schema.from_pyarrow(current_batch.schema, merge_keys=merge_keys))
        except Exception as e:
            print(f"Error merging schema: {e}")

    if current_batch is not None and media_binaries:
        if len(media_binaries) == current_batch.num_rows:
            try:
                image_column = pyarrow.array(media_binaries, type=pyarrow.binary())
                current_batch = current_batch.add_column(
                    len(current_batch.schema),
                    'media_binary',
                    image_column
                )
                # Edit dataset_schema to have media_binaries as a field object
                dataset_schema.add_field(Field('media_binary', Datatype.binary(image_filename[image_filename.index('.') + 1:].lower())))
            except Exception as e:
                print(f"Mismatch between media binaries and batch rows: {e}")
```
Reviewer:
This is quite heavily nested. Not worth it now, but I'd extract this into a WebDatasetReader sort of class that manages all this in a non-nested way if we do more development.
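For illustration, a rough sketch of the kind of extraction the comment suggests. The class name comes from the comment; the method names and structure are hypothetical, not code from this PR:

```python
import tarfile


class WebDatasetReader:
    """Hypothetical sketch: walk a WebDataset tar and yield its JSON members,
    keeping each concern in a small, flat method instead of nested ifs."""

    def __init__(self, file_uri: str):
        self.file_uri = file_uri

    def json_members(self, tar: tarfile.TarFile):
        # Flat filtering with `continue` replaces the nested if blocks above.
        for member in tar.getmembers():
            if member.name.startswith("._"):  # skip macOS resource-fork files
                continue
            if not (member.isfile() and member.name.endswith(".json")):
                continue
            extracted = tar.extractfile(member)
            if extracted is not None:
                yield member, extracted

    def read(self):
        # Yield (member name, raw JSON bytes) pairs for downstream parsing.
        with tarfile.open(self.file_uri, "r") as tar:
            for member, fileobj in self.json_members(tar):
                yield member.name, fileobj.read()
```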
deltacat/storage/rivulet/dataset.py (outdated)
```python
pyarrow_table = pyarrow.json.read_json(f)
image_filename = pyarrow_table[merge_key][0].as_py()

# truncated_filename = normalize_filename(image_filename[image_filename.index('/') + 1:])
```
Reviewer:
This blows up if you have multiple merge keys. Need to fix.
Authors:
We have currently handled the following cases:
- If there are multiple merge keys, raise a ValueError
- If the merge key input is a list with one merge key, proceed with the one merge key.
We can adapt this to handle multiple merge keys instead of raising an error, if desired.
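A minimal sketch of the validation described above. The helper name is hypothetical; it assumes merge_keys may arrive as a string or an iterable, as in the snippets in this thread:

```python
def resolve_single_merge_key(merge_keys):
    """Hypothetical helper: accept exactly one merge key, whether passed as
    a string or a one-element list, and raise ValueError otherwise."""
    if isinstance(merge_keys, str):
        return merge_keys
    keys = list(merge_keys)
    if len(keys) == 1:
        return keys[0]
    raise ValueError(
        f"from_webdataset() currently supports exactly one merge key, got {keys}"
    )
```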
```python
@dataclass(frozen=True)
class Field:
    name: str
    datatype: Datatype
    is_merge_key: bool = False


class Schema(MutableMapping[str, Field]):
    def __init__(
        self,
        fields: Iterable[Tuple[str, Datatype] | Field] = None,
        merge_keys: Optional[Iterable[str]] = None,
    ):
        self._fields: Dict[str, Field] = {}
```
Reviewer:
Not understanding why a class called Field is in a file called schema_test?
Authors:
This file was just for our own purposes of understanding the classes and files, not for the PR, so we have removed it completely.
The Field and Schema classes were already defined in schema.py; we copied them into this file purely for our own understanding, haha.
We've moved the process_tar() function into test_wds.py for now, since it could be a helpful test util, but we're not sure that's the best place for it (or whether we even want to keep it).
| """Test that from_webdataset correctly identifies all fields in the tar file.""" | ||
| tar_path = "../../../test_utils/resources/test_wds.tar" | ||
| dataset = Dataset.from_webdataset( | ||
| name="test_webdataset", |
Reviewer:
Generally I don't like static, pre-generated test objects like this. The class should generate the wds file prior to the test using a standard wds function, then test on the created file, then delete the file at the end. This ensures that if the wds library/standard changes in a way that impacts serialization, the tests actually fail. Right now, there's no confidence that the .tar is actually a real wds file, and no clear understanding of what's in that file or how it was generated.
Authors:
We looked into dynamically generating a webdataset, but it's a bit tricky in this case since typical webdatasets include media files like .jpg, which are not straightforward to create dynamically in a lightweight way. We could use .txt files, but that wouldn't fully test the expected use case.
Also, we noticed that other parts of the codebase (like CSV and Parquet) have some static test files, so we planned to follow that pattern by adding a minimal .tar file for testing. Happy to revisit this if there's a preferred way to generate valid webdataset test data inline.
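One possible middle ground, sketched here with only the standard library (the make_test_tar helper is hypothetical, not code from this PR): build a tar in WebDataset layout, with a .json metadata member and a small .txt payload member per sample key.

```python
import io
import json
import tarfile


def make_test_tar(path: str, n: int = 3) -> None:
    """Hypothetical sketch: write a minimal WebDataset-style tar where each
    sample key has a .json metadata member and a .txt payload member."""
    with tarfile.open(path, "w") as tar:
        for i in range(n):
            key = f"sample{i:04d}"
            members = {
                f"{key}.json": json.dumps({"filename": f"{key}.txt"}).encode(),
                f"{key}.txt": f"payload {i}".encode(),
            }
            for name, data in members.items():
                info = tarfile.TarInfo(name=name)
                info.size = len(data)
                tar.addfile(info, io.BytesIO(data))
```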
```python
def test_metadata_directory_creation(tmp_path):
    """Test that metadata directory is properly initialized."""
    tar_path = "../../../test_utils/resources/test_wds.tar"
    dataset = Dataset.from_webdataset(
        name="test_meta",
        file_uri=tar_path,
        metadata_uri=tmp_path,
        merge_keys="filename"
    )
    assert hasattr(dataset, "_metadata_path")
    assert dataset._metadata_path is not None
```
Reviewer:
Don't test internal attributes like this. Remove this test and test whatever it is _metadata_path is attempting to create (i.e., fetch some useful metadata that would require the metadata path dir to exist).
Authors:
We created a test, test_dataset_persistence_and_reloading, which successfully creates, saves, and scans the dataset. We take this to mean that the metadata was successfully written and the path exists, but we can also explore more comprehensive tests.
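A hedged sketch of what such a behavior-based test might look like. The body is illustrative, not the actual test from this PR; make_test_tar is the hypothetical helper sketched earlier in the thread, and the scan call is assumed from the "creates, saves, and scans" description above:

```python
def test_dataset_persistence_and_reloading(tmp_path):
    # Illustrative only: exercise the metadata path through behavior
    # (create, then scan) instead of asserting on private attributes.
    tar_path = str(tmp_path / "data.tar")
    make_test_tar(tar_path)  # hypothetical helper, see earlier sketch
    dataset = Dataset.from_webdataset(
        name="test_persist",
        file_uri=tar_path,
        metadata_uri=str(tmp_path),
        merge_keys="filename",
    )
    # Scanning succeeds only if metadata was actually written and reloadable.
    records = list(dataset.scan())  # scan API assumed from the reply above
    assert len(records) > 0
```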
foo.py (outdated)
```python
import csv
import pyarrow as pa
import pyarrow.compute as pc

animal = pa.array(["sheep", "cows", "horses", "foxes", "sheep"], type=pa.string())
count = pa.array([12, 5, 2, 1, 10], type=pa.int8())
year = pa.array([2022, 2022, 2022, 2022, 2021], type=pa.int16())

# Creating a table from arrays
table = pa.Table.from_arrays([animal, count, year], names=['animal', 'count', 'year'])
```
Reviewer:
File doesn't belong in commit?
Authors:
We've removed this file.
foowds.py (outdated)
```python
import os
import json
import tarfile
import io
import numpy as np
from PIL import Image

# Create mock data directory
if not os.path.exists('mock_data'):
    os.makedirs('mock_data')

# Sample IDs for medical papers (similar to the example)
sample_ids = [
    "",
    "PMC4129566_00003",
    "PMC4872614_00002",
```
Reviewer:
File doesn't belong in commit? Or this needs to move to tests/utils or some equivalent so it can be used by the wds test class.
Authors:
We've removed this file.
```python
import itertools
import pytest
import pyarrow as pa
import json
import tarfile
from deltacat.storage.rivulet import Dataset, Schema, Field, Datatype


def test_schema_field_types():
    """Test that Schema correctly stores Field objects with their types."""
    fields = [
```
Reviewer:
Seems to be missing an actual data read/write; all tests are about schema.
Authors:
All our tests now use pytest fixtures to create tar files at the beginning of each test. They each run from_webdataset(), so that should verify that reading from the tar file worked. Each test then checks the fields, types, and some values in the Dataset produced by from_webdataset(), which should verify that writing worked.
We put all test cases in a class called TestFromWebDataset to ensure proper setup and teardown, leaving no directories lying around. This meant we only used one temp directory from pytest's tmp_path, so we named datasets, tar files, and JSON files uniquely per test. A fixture sketch follows below.
The one thing left over is testing with image files (instead of .txt files), for which we left a TODO and commented-out code, and can address once we discuss the specifics further!
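For context, a fixture along these lines might look like the following. This is illustrative, not the exact fixture from the PR; make_test_tar is the hypothetical helper sketched earlier in the thread:

```python
import pytest


@pytest.fixture
def wds_tar(tmp_path):
    """Illustrative fixture: build a fresh WebDataset-style tar per test in
    pytest's tmp_path so nothing is left behind after teardown."""
    tar_path = str(tmp_path / "test_wds.tar")
    make_test_tar(tar_path)  # hypothetical helper, see earlier sketch
    return tar_path
```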
To your comment @pdames ("Hmm... it seems like this PR includes a lot of superfluous/untouched files in the diff that are creating unnecessary conflicts. Maybe the changes can be squashed and rebased on top of the latest from https://github.com/ray-project/deltacat/tree/2.0 then resubmitted to clean this up?"): It looks like we accidentally changed the executable mode of ~400 files in one commit (most likely by running …). We have merged the latest 2.0 into the PR (which unfortunately did not resolve the executable mode issue) and changed the executable modes back for all relevant files.
Overall changes:
- …ogic and fixed media_binary logic in from_webdataset
- …d record count check in inconsistent schema test for the WDS reader
- …tion handling for WebDatasetReader
Summary
Implemented the function from_webdataset() in dataset.py, which converts a WebDataset .tar file into a Rivulet Dataset instance; created a series of unit tests for validation.
Rationale
from_webdataset() will allow Rivulet databases to hold WebDataset data alongside other data formats such as JSON and CSV.
Changes
Testing
The test suite deltacat/tests/storage/rivulet/schema/test_wds.py contains unit tests covering the functionality of from_webdataset(). All test cases pass, and from_webdataset() does not break existing Rivulet or DeltaCAT functions.
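For reference, a hedged usage sketch based on the call signature shown in the test snippets above (the dataset name and paths are illustrative):

```python
from deltacat.storage.rivulet import Dataset

# Convert a WebDataset shard into a Rivulet Dataset; arguments mirror the
# test snippets in this thread, with illustrative paths.
dataset = Dataset.from_webdataset(
    name="birds",
    file_uri="path/to/shard-0000.tar",
    metadata_uri="/tmp/rivulet-metadata",
    merge_keys="filename",
)
```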