feat: sentence_embedding pipeline

This commit is contained in:
suluyan
2025-08-04 13:10:49 +08:00
parent 49d50b2126
commit 2610cbf4fd
5 changed files with 92 additions and 1 deletions

View File

@@ -15,6 +15,7 @@ from modelscope.utils.logger import get_logger
from modelscope.utils.plugins import (register_modelhub_repo,
register_plugins_repo)
from modelscope.utils.registry import Registry, build_from_cfg
from modelscope.utils.task_utils import is_embedding_task
from .base import Pipeline
from .util import is_official_hub_path
@@ -187,6 +188,16 @@ def pipeline(task: str = None,
else:
pipeline_props = {'type': pipeline_name}
if not pipeline_props and is_embedding_task(task):
try:
from modelscope.utils.hf_util import sentence_transformers_pipeline
return sentence_transformers_pipeline(model=model, **kwargs)
except Exception as e:
logger.error(
'We could not find a suitable pipeline from modelscope, so we tried to load it using the '
'sentence_transformers, but that also failed.')
raise e
if not pipeline_props and is_transformers_available():
try:
from modelscope.utils.hf_util import hf_pipeline

View File

@@ -1,3 +1,3 @@
from .auto_class import *
from .patcher import patch_context, patch_hub, unpatch_hub
from .pipeline_builder import hf_pipeline
from .pipeline_builder import hf_pipeline, sentence_transformers_pipeline

View File

@@ -3,6 +3,9 @@ from typing import Optional, Union
from modelscope.hub import snapshot_download
from modelscope.utils.hf_util.patcher import _patch_pretrained_class
from modelscope.utils.logger import get_logger
logger = get_logger()
def _get_hf_device(device):
@@ -52,3 +55,29 @@ def hf_pipeline(
device=device,
pipeline_class=pipeline_class,
**kwargs)
def sentence_transformers_pipeline(model: str, **kwargs):
    """Load a model with ``sentence_transformers`` and make it callable.

    The returned ``SentenceTransformer`` instance can be invoked in two ways:

    * ``ppl(input={'source_sentence': [...]})`` returns
      ``{'text_embedding': <encode result>}``.
    * ``ppl(sentences, prompt_name=...)`` returns the raw result of
      ``SentenceTransformer.encode``.

    Args:
        model: A local path, or a model id that is fetched with
            ``snapshot_download`` when no such path exists. Non-string values
            are handed to ``SentenceTransformer`` unchanged.
        **kwargs: Forwarded to the ``SentenceTransformer`` constructor.

    Returns:
        A ``SentenceTransformer`` instance.

    Raises:
        ImportError: If ``sentence_transformers`` is not installed.

    Note:
        ``__call__`` is assigned on the ``SentenceTransformer`` class itself,
        so every instance of that class in the current process is affected,
        not only the one returned here.
    """
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError:
        raise ImportError(
            'Could not import sentence_transformers, please upgrade to the latest version of sentence_transformers '
            "with: 'pip install -U sentence_transformers'") from None

    # Function-scope import so this block does not rely on a module-level one.
    import os

    if isinstance(model, str) and not os.path.exists(model):
        model = snapshot_download(model)

    # Annotations are quoted on purpose: ``str | list[str] | None`` would be
    # evaluated when this ``def`` executes and fails on Python < 3.10.
    def __call__(self,
                 sentences: 'Optional[Union[str, List[str]]]' = None,
                 prompt_name: 'Optional[str]' = None,
                 **kwargs):
        payload = kwargs.pop('input', None)
        encode_kwargs = dict(kwargs)
        # Pass ``prompt_name`` by keyword and only when it was given, so the
        # call does not depend on the positional order of ``encode``.
        if prompt_name is not None:
            encode_kwargs['prompt_name'] = prompt_name
        if payload is not None:
            # Dict-style input: embed ``source_sentence`` and wrap the result.
            res = self.encode(payload['source_sentence'], **encode_kwargs)
            return {'text_embedding': res}
        return self.encode(sentences, **encode_kwargs)

    SentenceTransformer.__call__ = __call__
    return SentenceTransformer(model, **kwargs)

View File

@@ -82,6 +82,14 @@ def _inverted_index(forward_index):
INVERTED_TASKS_LEVEL = _inverted_index(DEFAULT_TASKS_LEVEL)
def is_embedding_task(task: str):
    """Return True when ``task`` names an embedding task, False otherwise.

    ``None`` is never an embedding task. The only embedding task recognised
    here is ``Tasks.sentence_embedding``.
    """
    # The ``None`` test comes first so a missing task short-circuits to False.
    return task is not None and task in (Tasks.sentence_embedding, )
def get_task_by_subtask_name(group_key):
if group_key in INVERTED_TASKS_LEVEL:
return INVERTED_TASKS_LEVEL[group_key][

View File

@@ -0,0 +1,43 @@
import unittest
from modelscope.utils.test_utils import test_level
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
class LLMPipelineTest(unittest.TestCase):
    """Checks the ``sentence_embedding`` pipeline for Qwen3-Embedding-0.6B.

    Both tests build the pipeline through ``pipeline(...)`` and verify that one
    embedding row is produced per input document, plus a check on the first
    embedding component against a recorded reference value.
    """

    # Reference for ``embeddings[0][0]`` and the tolerance around it; both
    # numbers are taken from the original assertions.
    expected_first_value = -0.0471825
    tolerance = 0.01

    def setUp(self) -> None:
        self.model_id = 'Qwen/Qwen3-Embedding-0.6B'
        self.querys = [
            "What is the capital of China?",
            "Explain gravity",
        ]
        self.documents = [
            "The capital of China is Beijing.",
            "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
        ]

    def _build_pipeline(self):
        """Create the sentence-embedding pipeline shared by both tests."""
        return pipeline(
            Tasks.sentence_embedding,
            model=self.model_id,
            model_revision='master',
        )

    def test_ori_pipeline(self):
        """Dict-style input returns embeddings under ``text_embedding``."""
        ppl = self._build_pipeline()
        inputs = {"source_sentence": self.documents}
        embeddings = ppl(input=inputs)["text_embedding"]
        self.assertEqual(embeddings.shape[0], len(self.documents))
        # Two-sided check: the previous ``(x + 0.0471825) < 0.01`` accepted any
        # sufficiently negative value.
        self.assertAlmostEqual(
            float(embeddings[0][0]),
            self.expected_first_value,
            delta=self.tolerance)

    def test_sentence_embedding_input(self):
        """Positional sentences with ``prompt_name`` return raw embeddings."""
        ppl = self._build_pipeline()
        embeddings = ppl(self.documents, prompt_name="query")
        self.assertEqual(embeddings.shape[0], len(self.documents))
        # NOTE(review): this path encodes with the "query" prompt but reuses the
        # reference value of the prompt-less test, so only the original upper
        # bound is kept here. Record a reference for this path and switch to a
        # two-sided check once it is confirmed.
        self.assertLess(embeddings[0][0] - self.expected_first_value,
                        self.tolerance)
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()