diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 5fb66178..5cc60592 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -15,6 +15,7 @@ from modelscope.utils.logger import get_logger
 from modelscope.utils.plugins import (register_modelhub_repo,
                                      register_plugins_repo)
 from modelscope.utils.registry import Registry, build_from_cfg
+from modelscope.utils.task_utils import is_embedding_task
 from .base import Pipeline
 from .util import is_official_hub_path
 
@@ -187,6 +188,18 @@ def pipeline(task: str = None,
         else:
             pipeline_props = {'type': pipeline_name}
 
+    if not pipeline_props and is_embedding_task(task):
+        # No registered ModelScope pipeline for this embedding task:
+        # fall back to a sentence-transformers based pipeline.
+        try:
+            from modelscope.utils.hf_util import sentence_transformers_pipeline
+            return sentence_transformers_pipeline(model=model, **kwargs)
+        except Exception:
+            logger.error(
+                'We could not find a suitable pipeline from modelscope, so we tried to load it using the '
+                'sentence_transformers, but that also failed.')
+            raise  # bare raise keeps the original traceback
+
     if not pipeline_props and is_transformers_available():
         try:
             from modelscope.utils.hf_util import hf_pipeline
diff --git a/modelscope/utils/hf_util/__init__.py b/modelscope/utils/hf_util/__init__.py
index ac8349c9..e2b00b93 100644
--- a/modelscope/utils/hf_util/__init__.py
+++ b/modelscope/utils/hf_util/__init__.py
@@ -1,3 +1,3 @@
 from .auto_class import *
 from .patcher import patch_context, patch_hub, unpatch_hub
-from .pipeline_builder import hf_pipeline
+from .pipeline_builder import hf_pipeline, sentence_transformers_pipeline
diff --git a/modelscope/utils/hf_util/pipeline_builder.py b/modelscope/utils/hf_util/pipeline_builder.py
index 734ab09d..a2a2208f 100644
--- a/modelscope/utils/hf_util/pipeline_builder.py
+++ b/modelscope/utils/hf_util/pipeline_builder.py
@@ -3,6 +3,8 @@
+import os
+
 from typing import Optional, Union
 
 from modelscope.hub import snapshot_download
 from modelscope.utils.hf_util.patcher import _patch_pretrained_class
 
 
@@ -52,3 +54,43 @@ def hf_pipeline(
         device=device,
         pipeline_class=pipeline_class,
         **kwargs)
+
+
+def sentence_transformers_pipeline(model: str, **kwargs):
+    """Build a sentence-transformers based embedding pipeline.
+
+    Args:
+        model: A local model directory or a ModelScope model id; ids are
+            resolved to a local snapshot via ``snapshot_download``.
+        **kwargs: Extra arguments forwarded to ``SentenceTransformer``.
+
+    Raises:
+        ImportError: If sentence_transformers is not installed.
+    """
+    try:
+        from sentence_transformers import SentenceTransformer
+    except ImportError:
+        raise ImportError(
+            'Could not import sentence_transformers, please upgrade to the latest version of sentence_transformers '
+            "with: 'pip install -U sentence_transformers'") from None
+    if isinstance(model, str):
+        if not os.path.exists(model):
+            model = snapshot_download(model)
+
+    # Subclass instead of monkey-patching SentenceTransformer.__call__
+    # globally, so unrelated users of the class keep stock behavior.
+    class _SentenceTransformerPipeline(SentenceTransformer):
+
+        def __call__(self,
+                     sentences: Optional[Union[str, list]] = None,
+                     prompt_name: Optional[str] = None,
+                     **kwargs):
+            input = kwargs.pop('input', None)
+            if input is not None:
+                # ModelScope-style dict input: {'source_sentence': [...]}
+                sentences = input['source_sentence']
+                res = self.encode(sentences, **kwargs)
+                return {'text_embedding': res}
+            return self.encode(sentences, prompt_name, **kwargs)
+
+    return _SentenceTransformerPipeline(model, **kwargs)
diff --git a/modelscope/utils/task_utils.py b/modelscope/utils/task_utils.py
index 07d3838e..1040f670 100644
--- a/modelscope/utils/task_utils.py
+++ b/modelscope/utils/task_utils.py
@@ -82,6 +82,16 @@ def _inverted_index(forward_index):
 INVERTED_TASKS_LEVEL = _inverted_index(DEFAULT_TASKS_LEVEL)
 
 
+def is_embedding_task(task: str):
+    """Return True if ``task`` is a text-embedding style task.
+
+    Kept as a membership test so more embedding tasks can be added later.
+    """
+    if task is None:
+        return False
+    return task in (Tasks.sentence_embedding, )
+
+
 def get_task_by_subtask_name(group_key):
     if group_key in INVERTED_TASKS_LEVEL:
         return INVERTED_TASKS_LEVEL[group_key][
diff --git a/tests/utils/test_sentence_embedding_utils.py b/tests/utils/test_sentence_embedding_utils.py
new file mode 100644
index 00000000..fe428064
--- /dev/null
+++ b/tests/utils/test_sentence_embedding_utils.py
@@ -0,0 +1,50 @@
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class SentenceEmbeddingPipelineTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.model_id = 'Qwen/Qwen3-Embedding-0.6B'
+        self.querys = [
+            'What is the capital of China?',
+            'Explain gravity',
+        ]
+        self.documents = [
+            'The capital of China is Beijing.',
+            'Gravity is a force that attracts two bodies towards each '
+            'other. It gives weight to physical objects and is '
+            'responsible for the movement of planets around the sun.',
+        ]
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_ori_pipeline(self):
+        ppl = pipeline(
+            Tasks.sentence_embedding,
+            model=self.model_id,
+            model_revision='master',
+        )
+        inputs = {'source_sentence': self.documents}
+        embeddings = ppl(input=inputs)['text_embedding']
+        self.assertEqual(embeddings.shape[0], len(self.documents))
+        # abs(): the original one-sided check also passed for any large
+        # negative deviation from the reference value.
+        self.assertLess(abs(embeddings[0][0] + 0.0471825), 0.01)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_sentence_embedding_input(self):
+        ppl = pipeline(
+            Tasks.sentence_embedding,
+            model=self.model_id,
+            model_revision='master',
+        )
+        embeddings = ppl(self.documents, prompt_name='query')
+        self.assertEqual(embeddings.shape[0], len(self.documents))
+        self.assertLess(abs(embeddings[0][0] + 0.0471825), 0.01)
+
+
+if __name__ == '__main__':
+    unittest.main()