From 53ceca4df4ddffe8606f45a94411fc80606785ed Mon Sep 17 00:00:00 2001
From: suluyana <110878454+suluyana@users.noreply.github.com>
Date: Wed, 6 Aug 2025 15:43:36 +0800
Subject: [PATCH] feat: sentence_embedding pipeline (#1435)

---
 modelscope/pipelines/builder.py              | 13 +++++
 modelscope/utils/hf_util/__init__.py         |  2 +-
 modelscope/utils/hf_util/pipeline_builder.py | 38 +++++++++++++++
 modelscope/utils/task_utils.py               |  4 ++
 tests/utils/test_sentence_embedding_utils.py | 50 ++++++++++++++++++++
 5 files changed, 106 insertions(+), 1 deletion(-)
 create mode 100644 tests/utils/test_sentence_embedding_utils.py

diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 5fb66178..0ac7d2ec 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -15,6 +15,7 @@ from modelscope.utils.logger import get_logger
 from modelscope.utils.plugins import (register_modelhub_repo,
                                       register_plugins_repo)
 from modelscope.utils.registry import Registry, build_from_cfg
+from modelscope.utils.task_utils import is_embedding_task
 from .base import Pipeline
 from .util import is_official_hub_path
 
@@ -187,6 +188,18 @@ def pipeline(task: str = None,
     else:
         pipeline_props = {'type': pipeline_name}
 
+    if is_embedding_task(task) and not pipeline_props:
+        # No pipeline was resolved for this embedding task, so
+        # sentence_transformers is tried before the hf_pipeline fallback below.
+        try:
+            from modelscope.utils.hf_util import sentence_transformers_pipeline
+            return sentence_transformers_pipeline(model=model, **kwargs)
+        except Exception:
+            logger.exception(
+                'We could not find a suitable pipeline from modelscope, so we tried to load it using the '
+                'sentence_transformers, but that also failed.')
+            raise
+
     if not pipeline_props and is_transformers_available():
         try:
             from modelscope.utils.hf_util import hf_pipeline
diff --git a/modelscope/utils/hf_util/__init__.py b/modelscope/utils/hf_util/__init__.py
index ac8349c9..e2b00b93 100644
--- a/modelscope/utils/hf_util/__init__.py
+++ b/modelscope/utils/hf_util/__init__.py
@@ -1,3 +1,3 @@
 from .auto_class import *
 from .patcher import
patch_context, patch_hub, unpatch_hub
-from .pipeline_builder import hf_pipeline
+from .pipeline_builder import hf_pipeline, sentence_transformers_pipeline
diff --git a/modelscope/utils/hf_util/pipeline_builder.py b/modelscope/utils/hf_util/pipeline_builder.py
index 734ab09d..8bf6fa02 100644
--- a/modelscope/utils/hf_util/pipeline_builder.py
+++ b/modelscope/utils/hf_util/pipeline_builder.py
@@ -3,6 +3,9 @@ from typing import Optional, Union
 
 from modelscope.hub import snapshot_download
 from modelscope.utils.hf_util.patcher import _patch_pretrained_class
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
 
 
 def _get_hf_device(device):
@@ -52,3 +55,73 @@ def hf_pipeline(
         device=device,
         pipeline_class=pipeline_class,
         **kwargs)
+
+
+def sentence_transformers_pipeline(model: str, **kwargs):
+    """Build a modelscope-style pipeline backed by ``sentence_transformers``.
+
+    Args:
+        model: A local path, or a model id that is fetched with
+            ``snapshot_download`` when no such local path exists.
+        **kwargs: Forwarded unchanged to ``SentenceTransformer``.
+
+    Returns:
+        A ``SentenceTransformerPipeline`` instance wrapping the loaded model.
+
+    Raises:
+        ImportError: If ``sentence_transformers`` cannot be imported.
+    """
+    try:
+        from sentence_transformers import SentenceTransformer
+    except ImportError:
+        raise ImportError(
+            'Could not import sentence_transformers, please upgrade to the latest version of sentence_transformers '
+            "with: 'pip install -U sentence_transformers'") from None
+    # A string that is not an existing local path is resolved via the hub.
+    if isinstance(model, str) and not os.path.exists(model):
+        model = snapshot_download(model)
+
+    from modelscope.pipelines import Pipeline
+
+    class SentenceTransformerPipeline(Pipeline):
+        """A wrapper for sentence_transformers.SentenceTransformer to make it compatible
+        with the modelscope pipeline conventions."""
+
+        def __init__(self, model_path: str, **kwargs):
+            # NOTE(review): Pipeline.__init__ is not called here -- confirm
+            # that skipping the base-class setup is intentional.
+            self.model = SentenceTransformer(model_path, **kwargs)
+
+        # typing.Optional/Union are used instead of ``X | Y`` because these
+        # annotations are evaluated when the class body runs, and the ``|``
+        # form raises TypeError on interpreters older than Python 3.10.
+        def __call__(self,
+                     sentences: Optional[Union[str, list]] = None,
+                     prompt_name: Optional[str] = None,
+                     **kwargs):
+            """Encode sentences into embeddings.
+
+            Args:
+                sentences: One sentence or a list of sentences to encode.
+                prompt_name: Forwarded to ``SentenceTransformer.encode``.
+                **kwargs: Extra ``encode`` arguments. ``input``, when given,
+                    must be a dict whose ``source_sentence`` entry replaces
+                    ``sentences``.
+
+            Returns:
+                ``{'text_embedding': embeddings}`` when ``input`` is given,
+                otherwise the raw result of ``SentenceTransformer.encode``.
+            """
+            input_data = kwargs.pop('input', None)
+            if input_data is None:
+                return self.model.encode(
+                    sentences, prompt_name=prompt_name, **kwargs)
+            # prompt_name is forwarded on this path too, so it is no longer
+            # silently ignored for dict-style input.
+            res = self.model.encode(
+                input_data['source_sentence'],
+                prompt_name=prompt_name,
+                **kwargs)
+            return {'text_embedding': res}
+
+    return SentenceTransformerPipeline(model, **kwargs)
diff --git
a/modelscope/utils/task_utils.py b/modelscope/utils/task_utils.py
index 07d3838e..2c99fcf3 100644
--- a/modelscope/utils/task_utils.py
+++ b/modelscope/utils/task_utils.py
@@ -82,6 +82,11 @@ def _inverted_index(forward_index):
 INVERTED_TASKS_LEVEL = _inverted_index(DEFAULT_TASKS_LEVEL)
 
 
+def is_embedding_task(task: str) -> bool:
+    """Return True when ``task`` is ``Tasks.sentence_embedding``."""
+    return task == Tasks.sentence_embedding
+
+
 def get_task_by_subtask_name(group_key):
     if group_key in INVERTED_TASKS_LEVEL:
         return INVERTED_TASKS_LEVEL[group_key][
diff --git a/tests/utils/test_sentence_embedding_utils.py b/tests/utils/test_sentence_embedding_utils.py
new file mode 100644
index 00000000..5e4cf76e
--- /dev/null
+++ b/tests/utils/test_sentence_embedding_utils.py
@@ -0,0 +1,58 @@
+import subprocess
+import sys
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class SentenceEmbeddingPipelineTest(unittest.TestCase):
+    """Exercises the sentence_embedding task through both call styles."""
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        # Run the install once per class rather than before every test.
+        subprocess.check_call(
+            [sys.executable, '-m', 'pip', 'install', 'transformers>=4.51.3'])
+
+    def setUp(self) -> None:
+        self.model_id = 'Qwen/Qwen3-Embedding-0.6B'
+        self.queries = [
+            'What is the capital of China?',
+            'Explain gravity',
+        ]
+        self.documents = [
+            'The capital of China is Beijing.',
+            'Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and '
+            'is responsible for the movement of planets around the sun.',
+        ]
+
+    def test_ori_pipeline(self):
+        """Dict-style input returns embeddings under 'text_embedding'."""
+        ppl = pipeline(
+            Tasks.sentence_embedding,
+            model=self.model_id,
+            model_revision='master',
+        )
+        inputs = {'source_sentence': self.documents}
+        embeddings = ppl(input=inputs)['text_embedding']
+        self.assertEqual(embeddings.shape[0], len(self.documents))
+        # abs() makes this a two-sided tolerance check around the reference.
+        self.assertLess(abs(embeddings[0][0] + 0.0471825), 0.01)
+
+    def test_sentence_embedding_input(self):
+        """Positional sentences plus prompt_name return the raw embeddings."""
+        ppl = pipeline(
+            Tasks.sentence_embedding,
+            model=self.model_id,
+            model_revision='master',
+        )
+        embeddings = ppl(self.queries, prompt_name='query')
+        self.assertEqual(embeddings.shape[0], len(self.queries))
+        # abs() makes this a two-sided tolerance check around the reference.
+        self.assertLess(abs(embeddings[0][0] + 0.050865322), 0.01)
+
+
+if __name__ == '__main__':
+    unittest.main()