diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py index 747cf242..3c26dff5 100644 --- a/modelscope/hub/api.py +++ b/modelscope/hub/api.py @@ -742,13 +742,14 @@ class HubApi: Args: repo_id (`str`): The repo id to use - filename (`str`): The queried filename + filename (`str`): The queried filename, if the file exists in a sub folder, + please pass / revision (`Optional[str]`): The repo revision Returns: The query result in bool value """ - files = self.get_model_files(repo_id, revision=revision) - files = [file['Name'] for file in files] + files = self.get_model_files(repo_id, recursive=True, revision=revision) + files = [file['Path'] for file in files] return filename in files def create_dataset(self, diff --git a/tests/hub/test_file_exists.py b/tests/hub/test_file_exists.py new file mode 100644 index 00000000..6f21a02a --- /dev/null +++ b/tests/hub/test_file_exists.py @@ -0,0 +1,41 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest +import uuid +from os.path import expanduser + +from requests import delete + +from modelscope.hub.api import HubApi +from modelscope.hub.constants import Licenses, ModelVisibility +from modelscope.hub.errors import NotExistError +from modelscope.hub.file_download import model_file_download +from modelscope.hub.git import GitCommandWrapper +from modelscope.hub.repository import Repository +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1, + TEST_MODEL_CHINESE_NAME, + TEST_MODEL_ORG, delete_credential) + +logger = get_logger() +logger.setLevel('DEBUG') +DEFAULT_GIT_PATH = 'git' +download_model_file_name = 'test.bin' + + +class FileExistsTest(unittest.TestCase): + + def test_file_exsists(self): + api = HubApi() + self.assertTrue( + api.file_exists('iic/gte_Qwen2-7B-instruct', 'added_tokens.json')) + self.assertTrue( + api.file_exists('iic/gte_Qwen2-7B-instruct', + '1_Pooling/config.json')) + + +if __name__ == '__main__': + unittest.main()