From 74d97ea7e09636b3860be7067e3a4ae8a01bd803 Mon Sep 17 00:00:00 2001 From: tastelikefeet <58414341+tastelikefeet@users.noreply.github.com> Date: Sat, 14 Sep 2024 15:12:28 +0800 Subject: [PATCH] Refactor zero sized file downloading (#991) --- modelscope/hub/file_download.py | 10 ++++++++-- tests/hub/test_hub_empty_file.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 tests/hub/test_hub_empty_file.py diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py index 542c42af..f1cbce6f 100644 --- a/modelscope/hub/file_download.py +++ b/modelscope/hub/file_download.py @@ -461,17 +461,23 @@ def http_get_model_file( unit='B', unit_scale=True, unit_divisor=1024, - total=file_size, + total=file_size if file_size > 0 else 1, initial=0, desc='Downloading [' + file_name + ']', ) + if file_size == 0: + # Avoid empty file server request + with open(temp_file_path, 'w+'): + progress.update(1) + progress.close() + break partial_length = 0 if os.path.exists( temp_file_path): # download partial, continue download with open(temp_file_path, 'rb') as f: partial_length = f.seek(0, io.SEEK_END) progress.update(partial_length) - if partial_length >= file_size > 0: + if partial_length >= file_size: break # closed range[], from 0. get_headers['Range'] = 'bytes=%s-%s' % (partial_length, diff --git a/tests/hub/test_hub_empty_file.py b/tests/hub/test_hub_empty_file.py new file mode 100644 index 00000000..b73b1a66 --- /dev/null +++ b/tests/hub/test_hub_empty_file.py @@ -0,0 +1,31 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os.path
+import shutil
+import tempfile
+import unittest
+
+from modelscope import snapshot_download
+
+
+class HubEmptyFile(unittest.TestCase):
+
+    def setUp(self):
+        temporary_dir = tempfile.mkdtemp()
+        self.work_dir = temporary_dir
+
+    def tearDown(self):
+        shutil.rmtree(self.work_dir, ignore_errors=True)
+
+    def test_download_empty_file(self):
+        model_dir = snapshot_download(
+            'tastelikefeet/test_empty_download', cache_dir=self.work_dir)
+        self.assertTrue(model_dir is not None)
+        self.assertTrue(os.path.exists(os.path.join(model_dir, '1.txt')))
+        self.assertTrue(
+            os.path.exists(os.path.join(model_dir, 'configuration.json')))
+        self.assertTrue(os.path.exists(os.path.join(model_dir, 'init.py')))
+        self.assertTrue(os.path.exists(os.path.join(model_dir, 'README.md')))
+
+
+if __name__ == '__main__':
+    unittest.main()