feat: sentence_embedding pipeline (#1435)

This commit is contained in:
suluyana
2025-08-06 15:43:36 +08:00
committed by GitHub
parent ee05e12d75
commit 53ceca4df4
5 changed files with 104 additions and 1 deletions

View File

@@ -15,6 +15,7 @@ from modelscope.utils.logger import get_logger
from modelscope.utils.plugins import (register_modelhub_repo,
register_plugins_repo)
from modelscope.utils.registry import Registry, build_from_cfg
from modelscope.utils.task_utils import is_embedding_task
from .base import Pipeline
from .util import is_official_hub_path
@@ -187,6 +188,16 @@ def pipeline(task: str = None,
else:
pipeline_props = {'type': pipeline_name}
if not pipeline_props and is_embedding_task(task):
try:
from modelscope.utils.hf_util import sentence_transformers_pipeline
return sentence_transformers_pipeline(model=model, **kwargs)
except Exception:
logger.exception(
'We could not find a suitable pipeline from modelscope, so we tried to load it using the '
'sentence_transformers, but that also failed.')
raise
if not pipeline_props and is_transformers_available():
try:
from modelscope.utils.hf_util import hf_pipeline

View File

@@ -1,3 +1,3 @@
from .auto_class import *
from .patcher import patch_context, patch_hub, unpatch_hub
from .pipeline_builder import hf_pipeline
from .pipeline_builder import hf_pipeline, sentence_transformers_pipeline

View File

@@ -3,6 +3,9 @@ from typing import Optional, Union
from modelscope.hub import snapshot_download
from modelscope.utils.hf_util.patcher import _patch_pretrained_class
from modelscope.utils.logger import get_logger
logger = get_logger()
def _get_hf_device(device):
@@ -52,3 +55,38 @@ def hf_pipeline(
device=device,
pipeline_class=pipeline_class,
**kwargs)
def sentence_transformers_pipeline(model: str, **kwargs):
try:
from sentence_transformers import SentenceTransformer
except ImportError:
raise ImportError(
'Could not import sentence_transformers, please upgrade to the latest version of sentence_transformers '
"with: 'pip install -U sentence_transformers'") from None
if isinstance(model, str):
if not os.path.exists(model):
model = snapshot_download(model)
from modelscope.pipelines import Pipeline
class SentenceTransformerPipeline(Pipeline):
"""A wrapper for sentence_transformers.SentenceTransformer to make it compatible
with the modelscope pipeline conventions."""
def __init__(self, model_path: str, **kwargs):
self.model = SentenceTransformer(model_path, **kwargs)
def __call__(self,
sentences: str | list[str] | None = None,
prompt_name: str | None = None,
**kwargs):
input_data = kwargs.pop('input', None)
if input_data is not None:
sentences = input_data['source_sentence']
res = self.model.encode(sentences, **kwargs)
return {'text_embedding': res}
return self.model.encode(
sentences, prompt_name=prompt_name, **kwargs)
return SentenceTransformerPipeline(model, **kwargs)

View File

@@ -82,6 +82,10 @@ def _inverted_index(forward_index):
INVERTED_TASKS_LEVEL = _inverted_index(DEFAULT_TASKS_LEVEL)
def is_embedding_task(task: str):
return task == Tasks.sentence_embedding
def get_task_by_subtask_name(group_key):
if group_key in INVERTED_TASKS_LEVEL:
return INVERTED_TASKS_LEVEL[group_key][

View File

@@ -0,0 +1,50 @@
import subprocess
import sys
import unittest
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class SentenceEmbeddingPipelineTest(unittest.TestCase):
def setUp(self) -> None:
subprocess.check_call(
[sys.executable, '-m', 'pip', 'install', 'transformers>=4.51.3'])
self.model_id = 'Qwen/Qwen3-Embedding-0.6B'
self.queries = [
'What is the capital of China?',
'Explain gravity',
]
self.documents = [
'The capital of China is Beijing.',
'Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and '
'is responsible for the movement of planets around the sun.',
]
def test_ori_pipeline(self):
ppl = pipeline(
Tasks.sentence_embedding,
model=self.model_id,
model_revision='master',
)
inputs = {'source_sentence': self.documents}
embeddings = ppl(input=inputs)['text_embedding']
self.assertEqual(embeddings.shape[0], len(self.documents))
self.assertLess((embeddings[0][0] + 0.0471825), 0.01) # check value
def test_sentence_embedding_input(self):
ppl = pipeline(
Tasks.sentence_embedding,
model=self.model_id,
model_revision='master',
)
embeddings = ppl(self.queries, prompt_name='query')
self.assertEqual(embeddings.shape[0], len(self.queries))
self.assertLess((embeddings[0][0] + 0.050865322), 0.01) # check value
if __name__ == '__main__':
unittest.main()