mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
feat: sentence_embedding pipeline
This commit is contained in:
@@ -15,6 +15,7 @@ from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.plugins import (register_modelhub_repo,
|
||||
register_plugins_repo)
|
||||
from modelscope.utils.registry import Registry, build_from_cfg
|
||||
from modelscope.utils.task_utils import is_embedding_task
|
||||
from .base import Pipeline
|
||||
from .util import is_official_hub_path
|
||||
|
||||
@@ -187,6 +188,16 @@ def pipeline(task: str = None,
|
||||
else:
|
||||
pipeline_props = {'type': pipeline_name}
|
||||
|
||||
if not pipeline_props and is_embedding_task(task):
|
||||
try:
|
||||
from modelscope.utils.hf_util import sentence_transformers_pipeline
|
||||
return sentence_transformers_pipeline(model=model, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'We could not find a suitable pipeline from modelscope, so we tried to load it using the '
|
||||
'sentence_transformers, but that also failed.')
|
||||
raise e
|
||||
|
||||
if not pipeline_props and is_transformers_available():
|
||||
try:
|
||||
from modelscope.utils.hf_util import hf_pipeline
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .auto_class import *
|
||||
from .patcher import patch_context, patch_hub, unpatch_hub
|
||||
from .pipeline_builder import hf_pipeline
|
||||
from .pipeline_builder import hf_pipeline, sentence_transformers_pipeline
|
||||
|
||||
@@ -3,6 +3,9 @@ from typing import Optional, Union
|
||||
|
||||
from modelscope.hub import snapshot_download
|
||||
from modelscope.utils.hf_util.patcher import _patch_pretrained_class
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def _get_hf_device(device):
|
||||
@@ -52,3 +55,29 @@ def hf_pipeline(
|
||||
device=device,
|
||||
pipeline_class=pipeline_class,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def sentence_transformers_pipeline(model: str, **kwargs):
    """Load a ``SentenceTransformer`` and make it callable like a pipeline.

    Args:
        model: A local path, or a model id that is fetched with
            ``snapshot_download`` when no such path exists. Non-string
            values are handed to ``SentenceTransformer`` unchanged.
        **kwargs: Extra keyword arguments forwarded to the
            ``SentenceTransformer`` constructor.

    Returns:
        A ``SentenceTransformer`` instance. Calling it with
        ``input={'source_sentence': ...}`` returns
        ``{'text_embedding': <encode result>}``; calling it with sentences
        directly returns the raw ``encode`` result.

    Raises:
        ImportError: If ``sentence_transformers`` cannot be imported.
    """
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError:
        raise ImportError(
            'Could not import sentence_transformers, please upgrade to the latest version of sentence_transformers '
            "with: 'pip install -U sentence_transformers'") from None

    # A string that is not an existing path is treated as a hub model id.
    if isinstance(model, str) and not os.path.exists(model):
        model = snapshot_download(model)

    # Annotations use typing.Optional/Union rather than ``X | None``: the
    # latter is evaluated when this ``def`` runs and raises TypeError on
    # Python < 3.10.
    def __call__(self,
                 sentences: Optional[Union[str, list]] = None,
                 prompt_name: Optional[str] = None,
                 **kwargs):
        pipeline_input = kwargs.pop('input', None)
        # Forward prompt_name by keyword and only when it was given, so it
        # is honoured on both call styles and never lands in another
        # positional slot of ``encode``.
        if prompt_name is not None:
            kwargs['prompt_name'] = prompt_name
        if pipeline_input is not None:
            # Pipeline-style call: sentences live under 'source_sentence'.
            sentences = pipeline_input['source_sentence']
            return {'text_embedding': self.encode(sentences, **kwargs)}
        return self.encode(sentences, **kwargs)

    # NOTE(review): this rebinds ``__call__`` on the SentenceTransformer
    # class itself, so every instance in the process is affected, not only
    # the one returned here -- confirm that is intended.
    SentenceTransformer.__call__ = __call__
    return SentenceTransformer(model, **kwargs)
|
||||
|
||||
@@ -82,6 +82,14 @@ def _inverted_index(forward_index):
|
||||
INVERTED_TASKS_LEVEL = _inverted_index(DEFAULT_TASKS_LEVEL)
|
||||
|
||||
|
||||
def is_embedding_task(task: str):
    """Return True when ``task`` is one of the embedding tasks.

    ``None`` is accepted and yields False. The only task currently
    recognised is ``Tasks.sentence_embedding``.
    """
    # The None guard comes first so Tasks is not consulted for a missing task.
    return task is not None and task in (Tasks.sentence_embedding, )
|
||||
|
||||
|
||||
def get_task_by_subtask_name(group_key):
|
||||
if group_key in INVERTED_TASKS_LEVEL:
|
||||
return INVERTED_TASKS_LEVEL[group_key][
|
||||
|
||||
43
tests/utils/test_sentence_embedding_utils.py
Normal file
43
tests/utils/test_sentence_embedding_utils.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import unittest
|
||||
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
class LLMPipelineTest(unittest.TestCase):
    """Exercise ``pipeline(Tasks.sentence_embedding, ...)`` on a Qwen3 model.

    Both tests build the pipeline for ``self.model_id`` and encode
    ``self.documents``; they differ only in how the sentences are passed.
    """

    def setUp(self) -> None:
        self.model_id = 'Qwen/Qwen3-Embedding-0.6B'
        self.querys = [
            "What is the capital of China?",
            "Explain gravity",
        ]
        self.documents = [
            "The capital of China is Beijing.",
            ("Gravity is a force that attracts two bodies towards each other. "
             "It gives weight to physical objects and is responsible for the "
             "movement of planets around the sun."),
        ]

    def test_ori_pipeline(self):
        """Dict-style call: ``input={'source_sentence': ...}``."""
        ppl = pipeline(
            Tasks.sentence_embedding,
            model=self.model_id,
            model_revision='master',
        )
        inputs = {"source_sentence": self.documents}
        embeddings = ppl(input=inputs)["text_embedding"]
        self.assertEqual(embeddings.shape[0], len(self.documents))
        # Two-sided tolerance: the previous ``(x + ref) < 0.01`` check also
        # accepted arbitrarily negative values.
        self.assertAlmostEqual(
            float(embeddings[0][0]), -0.0471825, delta=0.01)

    def test_sentence_embedding_input(self):
        """Direct call with a sentence list and ``prompt_name``."""
        ppl = pipeline(
            Tasks.sentence_embedding,
            model=self.model_id,
            model_revision='master',
        )
        embeddings = ppl(self.documents, prompt_name="query")
        self.assertEqual(embeddings.shape[0], len(self.documents))
        # NOTE(review): this bound is one-sided and reuses the reference
        # value of the no-prompt test; confirm the expected value for the
        # "query" prompt before tightening it to a two-sided check.
        self.assertLess(float(embeddings[0][0]) + 0.0471825, 0.01)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user