This commit is contained in:
suluyan
2025-08-04 14:16:15 +08:00
parent 58b73b761f
commit bfb2d5a34c
3 changed files with 22 additions and 21 deletions

View File

@@ -192,11 +192,11 @@ def pipeline(task: str = None,
try:
from modelscope.utils.hf_util import sentence_transformers_pipeline
return sentence_transformers_pipeline(model=model, **kwargs)
except Exception as e:
logger.error(
except Exception:
logger.exception(
'We could not find a suitable pipeline from modelscope, so we tried to load it using the '
'sentence_transformers, but that also failed.')
raise e
raise
if not pipeline_props and is_transformers_available():
try:

View File

@@ -69,25 +69,24 @@ def sentence_transformers_pipeline(model: str, **kwargs):
model = snapshot_download(model)
from modelscope.pipelines import Pipeline
class SentenceTransformerPipeline(Pipeline):
"""A wrapper for sentence_transformers.SentenceTransformer to make it compatible
with the modelscope pipeline conventions."""
def __init__(self, model_path: str, **kwargs):
from sentence_transformers import SentenceTransformer
self.model = SentenceTransformer(model_path, **kwargs)
def __call__(
self,
sentences: str | list[str] | None = None,
prompt_name: str | None = None,
**kwargs
):
def __call__(self,
sentences: str | list[str] | None = None,
prompt_name: str | None = None,
**kwargs):
input_data = kwargs.pop('input', None)
if input_data is not None:
sentences = input_data['source_sentence']
res = self.model.encode(sentences, **kwargs)
return {'text_embedding': res}
return self.model.encode(sentences, prompt_name=prompt_name, **kwargs)
return self.model.encode(
sentences, prompt_name=prompt_name, **kwargs)
return SentenceTransformerPipeline(model, **kwargs)

View File

@@ -1,20 +1,22 @@
import unittest
from modelscope.utils.test_utils import test_level
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class SentenceEmbeddingPipelineTest(unittest.TestCase):
def setUp(self) -> None:
self.model_id = 'Qwen/Qwen3-Embedding-0.6B'
self.queries = [
"What is the capital of China?",
"Explain gravity",
'What is the capital of China?',
'Explain gravity',
]
self.documents = [
"The capital of China is Beijing.",
"Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
'The capital of China is Beijing.',
'Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and '
'is responsible for the movement of planets around the sun.',
]
def test_ori_pipeline(self):
@@ -23,8 +25,8 @@ class SentenceEmbeddingPipelineTest(unittest.TestCase):
model=self.model_id,
model_revision='master',
)
inputs = {"source_sentence": self.documents}
embeddings = ppl(input=inputs)["text_embedding"]
inputs = {'source_sentence': self.documents}
embeddings = ppl(input=inputs)['text_embedding']
self.assertEqual(embeddings.shape[0], len(self.documents))
self.assertLess((embeddings[0][0] + 0.0471825), 0.01) # check value
@@ -34,10 +36,10 @@ class SentenceEmbeddingPipelineTest(unittest.TestCase):
model=self.model_id,
model_revision='master',
)
embeddings = ppl(self.queries, prompt_name="query")
embeddings = ppl(self.queries, prompt_name='query')
self.assertEqual(embeddings.shape[0], len(self.documents))
self.assertLess((embeddings[0][0] + 0.050865322), 0.01) # check value
if __name__ == '__main__':
unittest.main()
unittest.main()