mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 13:15:06 +02:00
fix typo
This commit is contained in:
@@ -192,11 +192,11 @@ def pipeline(task: str = None,
|
||||
try:
|
||||
from modelscope.utils.hf_util import sentence_transformers_pipeline
|
||||
return sentence_transformers_pipeline(model=model, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
except Exception:
|
||||
logger.exception(
|
||||
'We could not find a suitable pipeline from modelscope, so we tried to load it using the '
|
||||
'sentence_transformers, but that also failed.')
|
||||
raise e
|
||||
raise
|
||||
|
||||
if not pipeline_props and is_transformers_available():
|
||||
try:
|
||||
|
||||
@@ -69,25 +69,24 @@ def sentence_transformers_pipeline(model: str, **kwargs):
|
||||
model = snapshot_download(model)
|
||||
|
||||
from modelscope.pipelines import Pipeline
|
||||
|
||||
class SentenceTransformerPipeline(Pipeline):
|
||||
"""A wrapper for sentence_transformers.SentenceTransformer to make it compatible
|
||||
with the modelscope pipeline conventions."""
|
||||
|
||||
def __init__(self, model_path: str, **kwargs):
|
||||
from sentence_transformers import SentenceTransformer
|
||||
self.model = SentenceTransformer(model_path, **kwargs)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sentences: str | list[str] | None = None,
|
||||
prompt_name: str | None = None,
|
||||
**kwargs
|
||||
):
|
||||
def __call__(self,
|
||||
sentences: str | list[str] | None = None,
|
||||
prompt_name: str | None = None,
|
||||
**kwargs):
|
||||
input_data = kwargs.pop('input', None)
|
||||
if input_data is not None:
|
||||
sentences = input_data['source_sentence']
|
||||
res = self.model.encode(sentences, **kwargs)
|
||||
return {'text_embedding': res}
|
||||
return self.model.encode(sentences, prompt_name=prompt_name, **kwargs)
|
||||
return self.model.encode(
|
||||
sentences, prompt_name=prompt_name, **kwargs)
|
||||
|
||||
return SentenceTransformerPipeline(model, **kwargs)
|
||||
|
||||
@@ -1,20 +1,22 @@
|
||||
import unittest
|
||||
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class SentenceEmbeddingPipelineTest(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.model_id = 'Qwen/Qwen3-Embedding-0.6B'
|
||||
self.queries = [
|
||||
"What is the capital of China?",
|
||||
"Explain gravity",
|
||||
'What is the capital of China?',
|
||||
'Explain gravity',
|
||||
]
|
||||
self.documents = [
|
||||
"The capital of China is Beijing.",
|
||||
"Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
|
||||
'The capital of China is Beijing.',
|
||||
'Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and '
|
||||
'is responsible for the movement of planets around the sun.',
|
||||
]
|
||||
|
||||
def test_ori_pipeline(self):
|
||||
@@ -23,8 +25,8 @@ class SentenceEmbeddingPipelineTest(unittest.TestCase):
|
||||
model=self.model_id,
|
||||
model_revision='master',
|
||||
)
|
||||
inputs = {"source_sentence": self.documents}
|
||||
embeddings = ppl(input=inputs)["text_embedding"]
|
||||
inputs = {'source_sentence': self.documents}
|
||||
embeddings = ppl(input=inputs)['text_embedding']
|
||||
self.assertEqual(embeddings.shape[0], len(self.documents))
|
||||
self.assertLess((embeddings[0][0] + 0.0471825), 0.01) # check value
|
||||
|
||||
@@ -34,10 +36,10 @@ class SentenceEmbeddingPipelineTest(unittest.TestCase):
|
||||
model=self.model_id,
|
||||
model_revision='master',
|
||||
)
|
||||
embeddings = ppl(self.queries, prompt_name="query")
|
||||
embeddings = ppl(self.queries, prompt_name='query')
|
||||
self.assertEqual(embeddings.shape[0], len(self.documents))
|
||||
self.assertLess((embeddings[0][0] + 0.050865322), 0.01) # check value
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user