mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
minor fix
This commit is contained in:
@@ -68,16 +68,26 @@ def sentence_transformers_pipeline(model: str, **kwargs):
|
||||
if not os.path.exists(model):
|
||||
model = snapshot_download(model)
|
||||
|
||||
def __call__(self,
|
||||
sentences: str | list[str] | None = None,
|
||||
prompt_name: str | None = None,
|
||||
**kwargs):
|
||||
input = kwargs.pop('input', None)
|
||||
if input is not None:
|
||||
sentences = input['source_sentence']
|
||||
res = self.encode(sentences, **kwargs)
|
||||
return {'text_embedding': res}
|
||||
return self.encode(sentences, prompt_name, **kwargs)
|
||||
from modelscope.pipelines import Pipeline
|
||||
class SentenceTransformerPipeline(Pipeline):
|
||||
"""A wrapper for sentence_transformers.SentenceTransformer to make it compatible
|
||||
with the modelscope pipeline conventions."""
|
||||
|
||||
SentenceTransformer.__call__ = __call__
|
||||
return SentenceTransformer(model, **kwargs)
|
||||
def __init__(self, model_path: str, **kwargs):
|
||||
from sentence_transformers import SentenceTransformer
|
||||
self.model = SentenceTransformer(model_path, **kwargs)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sentences: str | list[str] | None = None,
|
||||
prompt_name: str | None = None,
|
||||
**kwargs
|
||||
):
|
||||
input_data = kwargs.pop('input', None)
|
||||
if input_data is not None:
|
||||
sentences = input_data['source_sentence']
|
||||
res = self.model.encode(sentences, **kwargs)
|
||||
return {'text_embedding': res}
|
||||
return self.model.encode(sentences, prompt_name=prompt_name, **kwargs)
|
||||
|
||||
return SentenceTransformerPipeline(model, **kwargs)
|
||||
|
||||
@@ -83,11 +83,7 @@ INVERTED_TASKS_LEVEL = _inverted_index(DEFAULT_TASKS_LEVEL)
|
||||
|
||||
|
||||
def is_embedding_task(task: str):
|
||||
if task is None:
|
||||
return False
|
||||
if task in (Tasks.sentence_embedding, ):
|
||||
return True
|
||||
return False
|
||||
return task == Tasks.sentence_embedding
|
||||
|
||||
|
||||
def get_task_by_subtask_name(group_key):
|
||||
|
||||
@@ -5,10 +5,10 @@ from modelscope.utils.test_utils import test_level
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
class LLMPipelineTest(unittest.TestCase):
|
||||
class SentenceEmbeddingPipelineTest(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.model_id = 'Qwen/Qwen3-Embedding-0.6B'
|
||||
self.querys = [
|
||||
self.queries = [
|
||||
"What is the capital of China?",
|
||||
"Explain gravity",
|
||||
]
|
||||
@@ -26,7 +26,7 @@ class LLMPipelineTest(unittest.TestCase):
|
||||
inputs = {"source_sentence": self.documents}
|
||||
embeddings = ppl(input=inputs)["text_embedding"]
|
||||
self.assertEqual(embeddings.shape[0], len(self.documents))
|
||||
assert((embeddings[0][0]+0.0471825)<0.01) # check value
|
||||
self.assertLess((embeddings[0][0] + 0.0471825), 0.01) # check value
|
||||
|
||||
def test_sentence_embedding_input(self):
|
||||
ppl = pipeline(
|
||||
@@ -34,9 +34,9 @@ class LLMPipelineTest(unittest.TestCase):
|
||||
model=self.model_id,
|
||||
model_revision='master',
|
||||
)
|
||||
embeddings = ppl(self.documents, prompt_name="query")
|
||||
embeddings = ppl(self.queries, prompt_name="query")
|
||||
self.assertEqual(embeddings.shape[0], len(self.documents))
|
||||
assert ((embeddings[0][0] + 0.0471825) < 0.01) # check value
|
||||
self.assertLess((embeddings[0][0] + 0.050865322), 0.01) # check value
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user