From 58b73b761f9c8e1b2f43d253f1baa2a46711a48b Mon Sep 17 00:00:00 2001 From: suluyan Date: Mon, 4 Aug 2025 14:11:33 +0800 Subject: [PATCH] minor fix --- modelscope/utils/hf_util/pipeline_builder.py | 34 +++++++++++++------- modelscope/utils/task_utils.py | 6 +--- tests/utils/test_sentence_embedding_utils.py | 10 +++--- 3 files changed, 28 insertions(+), 22 deletions(-) diff --git a/modelscope/utils/hf_util/pipeline_builder.py b/modelscope/utils/hf_util/pipeline_builder.py index a2a2208f..65edf3af 100644 --- a/modelscope/utils/hf_util/pipeline_builder.py +++ b/modelscope/utils/hf_util/pipeline_builder.py @@ -68,16 +68,26 @@ def sentence_transformers_pipeline(model: str, **kwargs): if not os.path.exists(model): model = snapshot_download(model) - def __call__(self, - sentences: str | list[str] | None = None, - prompt_name: str | None = None, - **kwargs): - input = kwargs.pop('input', None) - if input is not None: - sentences = input['source_sentence'] - res = self.encode(sentences, **kwargs) - return {'text_embedding': res} - return self.encode(sentences, prompt_name, **kwargs) + from modelscope.pipelines import Pipeline + class SentenceTransformerPipeline(Pipeline): + """A wrapper for sentence_transformers.SentenceTransformer to make it compatible + with the modelscope pipeline conventions.""" - SentenceTransformer.__call__ = __call__ - return SentenceTransformer(model, **kwargs) + def __init__(self, model_path: str, **kwargs): + from sentence_transformers import SentenceTransformer + self.model = SentenceTransformer(model_path, **kwargs) + + def __call__( + self, + sentences: str | list[str] | None = None, + prompt_name: str | None = None, + **kwargs + ): + input_data = kwargs.pop('input', None) + if input_data is not None: + sentences = input_data['source_sentence'] + res = self.model.encode(sentences, **kwargs) + return {'text_embedding': res} + return self.model.encode(sentences, prompt_name=prompt_name, **kwargs) + + return SentenceTransformerPipeline(model, 
**kwargs) diff --git a/modelscope/utils/task_utils.py b/modelscope/utils/task_utils.py index 1040f670..2c99fcf3 100644 --- a/modelscope/utils/task_utils.py +++ b/modelscope/utils/task_utils.py @@ -83,11 +83,7 @@ INVERTED_TASKS_LEVEL = _inverted_index(DEFAULT_TASKS_LEVEL) def is_embedding_task(task: str): - if task is None: - return False - if task in (Tasks.sentence_embedding, ): - return True - return False + return task == Tasks.sentence_embedding def get_task_by_subtask_name(group_key): diff --git a/tests/utils/test_sentence_embedding_utils.py b/tests/utils/test_sentence_embedding_utils.py index fe428064..1318dfb5 100644 --- a/tests/utils/test_sentence_embedding_utils.py +++ b/tests/utils/test_sentence_embedding_utils.py @@ -5,10 +5,10 @@ from modelscope.utils.test_utils import test_level from modelscope.pipelines import pipeline from modelscope.utils.constant import Tasks -class LLMPipelineTest(unittest.TestCase): +class SentenceEmbeddingPipelineTest(unittest.TestCase): def setUp(self) -> None: self.model_id = 'Qwen/Qwen3-Embedding-0.6B' - self.querys = [ + self.queries = [ "What is the capital of China?", "Explain gravity", ] @@ -26,7 +26,7 @@ class LLMPipelineTest(unittest.TestCase): inputs = {"source_sentence": self.documents} embeddings = ppl(input=inputs)["text_embedding"] self.assertEqual(embeddings.shape[0], len(self.documents)) - assert((embeddings[0][0]+0.0471825)<0.01) # check value + self.assertLess((embeddings[0][0] + 0.0471825), 0.01) # check value def test_sentence_embedding_input(self): ppl = pipeline( @@ -34,9 +34,9 @@ class LLMPipelineTest(unittest.TestCase): model=self.model_id, model_revision='master', ) - embeddings = ppl(self.documents, prompt_name="query") + embeddings = ppl(self.queries, prompt_name="query") - self.assertEqual(embeddings.shape[0], len(self.documents)) + self.assertEqual(embeddings.shape[0], len(self.queries)) - assert ((embeddings[0][0] + 0.0471825) < 0.01) # check value + self.assertLess((embeddings[0][0] + 0.050865322), 0.01) # check value if __name__ == '__main__':