diff --git a/README.md b/README.md index bfa38a41..faaca28c 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,6 @@ [![DOI](https://zenodo.org/badge/703686617.svg)](https://zenodo.org/doi/10.5281/zenodo.11406462) ![Docker pulls](https://img.shields.io/docker/pulls/michaelf34/infinity) - Infinity is a high-throughput, low-latency REST API for serving text-embeddings, reranking models and clip. Infinity is developed under [MIT License](https://github.com/michaelfeil/infinity/blob/main/LICENSE). ## Why Infinity diff --git a/docs/docs/cli_v2.md b/docs/docs/cli_v2.md index 934c8791..865e3b31 100644 --- a/docs/docs/cli_v2.md +++ b/docs/docs/cli_v2.md @@ -75,6 +75,6 @@ $ infinity_emb v2 --help │ `INFINITY_LENGTHS_VIA_TOKENIZ… │ │ [default: │ │ lengths-via-tokenize] │ -│ --dtype [float32|float16|bfloat16|int dtype for the model weights. │ -│ 8|fp8|auto] [env var: `INFINITY_DTYPE`] │ +│ --dtype [float32|float16|bfloat16|int dtype for the model weights. │ +│ 8|fp8|autoquant|auto] [env var: `INFINITY_DTYPE`] │ │ [default: auto] │ diff --git a/libs/infinity_emb/Makefile b/libs/infinity_emb/Makefile index b9e6faa4..0ad55e54 100644 --- a/libs/infinity_emb/Makefile +++ b/libs/infinity_emb/Makefile @@ -22,7 +22,7 @@ test tests: poetry run pytest openapi: - ./../../docs/assets/create_openapi_with_server_hook.sh + poetry run ./../../docs/assets/create_openapi_with_server_hook.sh ###################### # LINTING AND FORMATTING @@ -60,7 +60,7 @@ benchmark_embed: tests/data/benchmark/benchmark_embed.json # Generate CLI v2 documentation cli_v2_docs: - ./../../docs/assets/create_cli_v2_docs.sh + poetry run ./../../docs/assets/create_cli_v2_docs.sh ###################### # HELP diff --git a/libs/infinity_emb/infinity_emb/primitives.py b/libs/infinity_emb/infinity_emb/primitives.py index f19ab48c..93ad6936 100644 --- a/libs/infinity_emb/infinity_emb/primitives.py +++ b/libs/infinity_emb/infinity_emb/primitives.py @@ -128,6 +128,7 @@ class Dtype(EnumType): bfloat16: str = "bfloat16" int8: str =
"int8" fp8: str = "fp8" + autoquant: str = "autoquant" auto: str = "auto" @staticmethod diff --git a/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py b/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py index b86ae5a1..03d03a97 100644 --- a/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py +++ b/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py @@ -25,7 +25,6 @@ if TYPE_CHECKING: from torch import Tensor - if CHECK_SENTENCE_TRANSFORMERS.is_available: from sentence_transformers import SentenceTransformer, util # type: ignore else: @@ -88,7 +87,7 @@ def __init__(self, *, engine_args=EngineArgs): ]: fm.auto_model.to(torch.bfloat16) - if engine_args.dtype in (Dtype.int8, Dtype.fp8): + if engine_args.dtype in (Dtype.int8, Dtype.fp8, Dtype.autoquant): fm.auto_model = quant_interface( fm.auto_model, engine_args.dtype, device=Device[self.device.type] ) diff --git a/libs/infinity_emb/infinity_emb/transformer/quantization/interface.py b/libs/infinity_emb/infinity_emb/transformer/quantization/interface.py index 7b257c50..d38a90fe 100644 --- a/libs/infinity_emb/infinity_emb/transformer/quantization/interface.py +++ b/libs/infinity_emb/infinity_emb/transformer/quantization/interface.py @@ -34,7 +34,12 @@ def quant_interface(model: Any, dtype: Dtype = Dtype.int8, device: Device = Devi Defaults to Device.cpu.
""" device_orig = model.device - if device == Device.cpu and dtype in [Dtype.int8, Dtype.auto]: + if dtype == Dtype.autoquant: + import torchao # type: ignore + + model = torchao.autoquant(model) + logger.info("using dtype=autoquant") + elif device == Device.cpu and dtype in [Dtype.int8, Dtype.auto]: logger.info("using torch.quantization.quantize_dynamic()") # TODO: verify if cpu requires quantization with torch.quantization.quantize_dynamic() model = torch.quantization.quantize_dynamic( diff --git a/libs/infinity_emb/infinity_emb/transformer/quantization/quant.py b/libs/infinity_emb/infinity_emb/transformer/quantization/quant.py index adee2f07..a440b8ae 100644 --- a/libs/infinity_emb/infinity_emb/transformer/quantization/quant.py +++ b/libs/infinity_emb/infinity_emb/transformer/quantization/quant.py @@ -506,7 +506,10 @@ def create_quantized_state_dict(self): cur_state_dict = self.mod.state_dict() for fqn, mod in self.mod.named_modules(): if isinstance(mod, torch.nn.Linear): - assert not mod.bias + if mod.bias is not None: + raise ValueError( + "int4 quantization requires all layers to have bias=False. This model is not compatible."
+ ) out_features = mod.out_features in_features = mod.in_features assert out_features % 8 == 0, "require out_features % 8 == 0" @@ -710,7 +713,10 @@ def quantize( quantized_state_dict = quant_handler.create_quantized_state_dict() new_base_name = base_name.replace(".pth", f"{label}int8.pth") + elif mode == "autoquant": + import torchao + model = torchao.autoquant(torch.compile(model)) elif mode == "int4": logger.info( "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization" diff --git a/libs/infinity_emb/poetry.lock b/libs/infinity_emb/poetry.lock index 69f67c16..cbab7892 100644 --- a/libs/infinity_emb/poetry.lock +++ b/libs/infinity_emb/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -3354,6 +3354,11 @@ files = [ {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, + {file = 
"scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, + {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, @@ -3863,6 +3868,26 @@ typing-extensions = ">=4.8.0" opt-einsum = ["opt-einsum (>=3.3)"] optree = ["optree (>=0.11.0)"] +[[package]] +name = "torchao" +version = "0.5.0" +description = "Package for applying ao techniques to GPU models" +optional = true +python-versions = "*" +files = [ + {file = "torchao-0.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6daff53790532d48e6b023bdd34030e8f87075f75f4206f3dd6577e6d99d7132"}, + {file = "torchao-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30a4c5c6ef7e3f5fa9a8dae3e2b9bb82c34d7c61a55f008e120303e22dd82cb6"}, + {file = "torchao-0.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3b8be5c3dcb641688397501d55cab4804688c3d27bdb9f8e0abcba1f2810678e"}, + {file = "torchao-0.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2aede6d89481ccda6bfd81f1666707765244de97697cb42b57c1001d9f928492"}, + {file = "torchao-0.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d1dacb8b899d76ea97f166d421c16140016cebed2090f04b17eee7e15b69969a"}, + {file = "torchao-0.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a14c6d5c6e8a1b03eded529bec9306f271ce59de43fa2e4699fd83f464bb5cd"}, + 
{file = "torchao-0.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f3ed3c5a55609c0051d0e1b12a609f416d896afd2d6a5a1ac73ee14c0230801c"}, + {file = "torchao-0.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:575974f5aea245cdb9ecc30b6d7385de043c3ac874f5b7701ceb5362e521b7f1"}, +] + +[package.extras] +dev = ["bitsandbytes", "expecttest", "fire", "hypothesis", "matplotlib", "ninja", "packaging", "pandas", "parameterized", "pre-commit", "pytest (==7.4.0)", "ruff", "sentencepiece", "tabulate", "transformers", "unittest-xml-reporting"] + [[package]] name = "torchvision" version = "0.19.1" @@ -4023,11 +4048,6 @@ files = [ {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, - {file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"}, - {file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"}, - {file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"}, - {file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"}, - {file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"}, ] 
[package.dependencies] @@ -4800,17 +4820,17 @@ type = ["pytest-mypy"] all = ["ctranslate2", "diskcache", "einops", "fastapi", "optimum", "orjson", "pillow", "posthog", "prometheus-fastapi-instrumentator", "pydantic", "rich", "sentence-transformers", "soundfile", "timm", "torch", "typer", "uvicorn"] audio = ["soundfile"] cache = ["diskcache"] -ct2 = ["ctranslate2", "sentence-transformers", "torch", "transformers"] +ct2 = ["ctranslate2", "sentence-transformers", "torch", "torchao", "transformers"] einops = ["einops"] logging = ["rich"] onnxruntime-gpu = ["onnxruntime-gpu"] optimum = ["optimum"] server = ["fastapi", "orjson", "posthog", "prometheus-fastapi-instrumentator", "pydantic", "rich", "typer", "uvicorn"] tensorrt = ["tensorrt"] -torch = ["sentence-transformers", "torch"] +torch = ["sentence-transformers", "torch", "torchao"] vision = ["pillow", "timm"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4" -content-hash = "0c763ca18d0acb2628b5f25516cf7702214d22425fc8c6761ff7054c6285e9dd" +content-hash = "08a9ba12f7feeb6a5a1a3fbdaba5a8ddfd95e1c61dc22cf74aa1ecc26ab0cc29" diff --git a/libs/infinity_emb/pyproject.toml b/libs/infinity_emb/pyproject.toml index 4e83ebca..1abae92b 100644 --- a/libs/infinity_emb/pyproject.toml +++ b/libs/infinity_emb/pyproject.toml @@ -42,6 +42,7 @@ diskcache = {version = "*", optional=true} onnxruntime-gpu = {version = "*", optional=true} tensorrt = {version = "^8.6.1", optional=true} soundfile = {version="^0.12.1", optional=true} +torchao = {version="^0.5.0", optional=true} [tool.poetry.scripts] infinity_emb = "infinity_emb.infinity_server:cli" @@ -82,9 +83,9 @@ types-chardet = "^5.0.4.6" mypy-protobuf = "^3.0.0" [tool.poetry.extras] -ct2=["ctranslate2","sentence-transformers","torch","transformers"] +ct2=["ctranslate2","sentence-transformers","torch","torchao","transformers"] optimum=["optimum"] -torch=["sentence-transformers","torch"] +torch=["sentence-transformers","torch","torchao"] einops=["einops"] logging=["rich"] 
cache=["diskcache"] diff --git a/libs/infinity_emb/tests/unit_test/transformer/quantization/test_interface.py b/libs/infinity_emb/tests/unit_test/transformer/quantization/test_interface.py index 4321ed0d..a5cb67ab 100644 --- a/libs/infinity_emb/tests/unit_test/transformer/quantization/test_interface.py +++ b/libs/infinity_emb/tests/unit_test/transformer/quantization/test_interface.py @@ -4,7 +4,11 @@ import torch from transformers import AutoTokenizer, BertModel # type: ignore +from infinity_emb.args import EngineArgs from infinity_emb.primitives import Device, Dtype +from infinity_emb.transformer.embedder.sentence_transformer import ( + SentenceTransformerPatched, +) from infinity_emb.transformer.quantization.interface import quant_interface devices = [Device.cpu] @@ -49,3 +53,45 @@ def test_quantize_bert(device: Device, dtype: Dtype): out_quant = model.forward(**tokens_encoded)["last_hidden_state"].mean(dim=1) assert torch.cosine_similarity(out_default, out_quant) > 0.95 + + +def test_autoquant_quantization(): + model_st = SentenceTransformerPatched( + engine_args=EngineArgs( + model_name_or_path="michaelfeil/bge-small-en-v1.5", + dtype="autoquant", + engine="torch", + bettertransformer=False, + ) + ) + model_default = SentenceTransformerPatched( + engine_args=EngineArgs( + model_name_or_path="michaelfeil/bge-small-en-v1.5", + dtype="float32", + engine="torch", + bettertransformer=False, + ) + ) + sentence = "This is a test sentence." + for sentence in [ + "This is a test sentence.", + "This is another sentence, that should be embedded. 
" * 10, + "1", + ]: + embedding_st = model_st.encode_post( + model_st.encode_core(model_st.encode_pre([sentence])) + ) + embedding_default = model_default.encode_post( + model_default.encode_core(model_default.encode_pre([sentence])) + ) + assert embedding_st.shape == embedding_default.shape + + # cosine similarity + sim = torch.nn.functional.cosine_similarity( + torch.tensor(embedding_st), torch.tensor(embedding_default) + ) + assert sim > 0.95 + + +if __name__ == "__main__": + test_autoquant_quantization()