minor fix

This commit is contained in:
suluyan
2025-08-04 14:11:33 +08:00
parent 2610cbf4fd
commit 58b73b761f
3 changed files with 28 additions and 22 deletions

View File

@@ -68,16 +68,26 @@ def sentence_transformers_pipeline(model: str, **kwargs):
if not os.path.exists(model):
model = snapshot_download(model)
def __call__(self,
sentences: str | list[str] | None = None,
prompt_name: str | None = None,
**kwargs):
input = kwargs.pop('input', None)
if input is not None:
sentences = input['source_sentence']
res = self.encode(sentences, **kwargs)
return {'text_embedding': res}
return self.encode(sentences, prompt_name, **kwargs)
from modelscope.pipelines import Pipeline
class SentenceTransformerPipeline(Pipeline):
"""A wrapper for sentence_transformers.SentenceTransformer to make it compatible
with the modelscope pipeline conventions."""
SentenceTransformer.__call__ = __call__
return SentenceTransformer(model, **kwargs)
def __init__(self, model_path: str, **kwargs):
from sentence_transformers import SentenceTransformer
self.model = SentenceTransformer(model_path, **kwargs)
def __call__(
self,
sentences: str | list[str] | None = None,
prompt_name: str | None = None,
**kwargs
):
input_data = kwargs.pop('input', None)
if input_data is not None:
sentences = input_data['source_sentence']
res = self.model.encode(sentences, **kwargs)
return {'text_embedding': res}
return self.model.encode(sentences, prompt_name=prompt_name, **kwargs)
return SentenceTransformerPipeline(model, **kwargs)

View File

@@ -83,11 +83,7 @@ INVERTED_TASKS_LEVEL = _inverted_index(DEFAULT_TASKS_LEVEL)
def is_embedding_task(task: str):
if task is None:
return False
if task in (Tasks.sentence_embedding, ):
return True
return False
return task == Tasks.sentence_embedding
def get_task_by_subtask_name(group_key):

View File

@@ -5,10 +5,10 @@ from modelscope.utils.test_utils import test_level
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
class LLMPipelineTest(unittest.TestCase):
class SentenceEmbeddingPipelineTest(unittest.TestCase):
def setUp(self) -> None:
self.model_id = 'Qwen/Qwen3-Embedding-0.6B'
self.querys = [
self.queries = [
"What is the capital of China?",
"Explain gravity",
]
@@ -26,7 +26,7 @@ class LLMPipelineTest(unittest.TestCase):
inputs = {"source_sentence": self.documents}
embeddings = ppl(input=inputs)["text_embedding"]
self.assertEqual(embeddings.shape[0], len(self.documents))
assert((embeddings[0][0]+0.0471825)<0.01) # check value
self.assertLess((embeddings[0][0] + 0.0471825), 0.01) # check value
def test_sentence_embedding_input(self):
ppl = pipeline(
@@ -34,9 +34,9 @@ class LLMPipelineTest(unittest.TestCase):
model=self.model_id,
model_revision='master',
)
embeddings = ppl(self.documents, prompt_name="query")
embeddings = ppl(self.queries, prompt_name="query")
self.assertEqual(embeddings.shape[0], len(self.documents))
assert ((embeddings[0][0] + 0.0471825) < 0.01) # check value
self.assertLess((embeddings[0][0] + 0.050865322), 0.01) # check value
if __name__ == '__main__':