[to #47017903] add file io with lock

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11402082
This commit is contained in:
zhangzhicheng.zzc
2023-01-12 13:22:47 +08:00
committed by wenmeng.zwm
parent 06296c1819
commit 075eb1eddb
3 changed files with 192 additions and 3 deletions

View File

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import contextlib
import fcntl
import os
import tempfile
from abc import ABCMeta, abstractmethod
@@ -9,6 +10,10 @@ from typing import Generator, Union
import requests
from modelscope.utils.logger import get_logger
logger = get_logger()
class Storage(metaclass=ABCMeta):
"""Abstract class of storage.
@@ -119,6 +124,153 @@ class LocalStorage(Storage):
yield filepath
class LocalStorageWithLock(Storage):
    """Local hard disk storage guarded by an advisory ``fcntl.flock`` lock.

    Each read/write call opens its own file handle, takes an exclusive lock
    on it, performs the I/O while the lock is held, and then unlocks and
    closes the handle. Because the handle is local to the call, one storage
    instance can be used from several threads without one call unlocking or
    closing the file of another.
    """

    # Retained for backward compatibility with code that assigns ``handle``
    # and calls ``acquire()`` / ``release()`` without arguments. The I/O
    # methods of this class no longer store their handles here.
    handle = None

    def acquire(self, handle=None):
        """Take an exclusive lock on a file handle.

        Args:
            handle: Open file object (or descriptor) to lock. Defaults to
                ``self.handle``.
        """
        target = self.handle if handle is None else handle
        fcntl.flock(target, fcntl.LOCK_EX)

    def release(self, handle=None):
        """Unlock a file handle and close it.

        Args:
            handle: Open file object to unlock and close. Defaults to
                ``self.handle``.
        """
        target = self.handle if handle is None else handle
        try:
            fcntl.flock(target, fcntl.LOCK_UN)
        finally:
            # Close even when unlocking fails so the descriptor is not leaked.
            target.close()

    def __del__(self):
        # ``handle`` stays ``None`` unless a caller assigned it, so guard
        # against calling ``close`` on ``None`` during garbage collection.
        handle = self.handle
        if handle is not None:
            handle.close()

    @contextlib.contextmanager
    def _locked(self, handle, action, filepath):
        """Hold the exclusive lock on ``handle`` for the ``with`` body.

        The handle is always unlocked and closed on exit, whether or not the
        body (or the lock acquisition itself) raised.
        """
        try:
            self.acquire(handle)
            logger.debug(
                f'acquire the lock for {action} function on {filepath}')
            yield handle
        finally:
            self.release(handle)
            logger.debug(
                f'release the lock for {action} function on {filepath}')

    @staticmethod
    def _open_without_truncate(path, flags):
        """``opener`` for ``open`` that drops ``O_TRUNC`` from the flags.

        Opening with 'w'/'wb' normally truncates immediately, i.e. before
        the lock is held; the write methods truncate explicitly once they
        own the lock instead.
        """
        return os.open(path, flags & ~os.O_TRUNC, 0o666)

    @staticmethod
    def _ensure_parent_dir(filepath):
        """Create the parent directory of ``filepath`` if it is missing."""
        dirname = os.path.dirname(filepath)
        if dirname and not os.path.exists(dirname):
            try:
                os.makedirs(dirname)
            except FileExistsError as err:
                # Another thread/process created it between the check and
                # ``makedirs``; that is fine.
                logger.warning(
                    f'File created by other thread during creation with err: {err}'
                )

    def read(self, filepath: Union[str, Path]) -> bytes:
        """Read data from a given ``filepath`` with 'rb' mode.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes: Expected bytes object.
        """
        with self._locked(open(filepath, 'rb'), 'read', filepath) as handle:
            return handle.read()

    def read_text(self,
                  filepath: Union[str, Path],
                  encoding: str = 'utf-8') -> str:
        """Read data from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.

        Returns:
            str: Expected text reading from ``filepath``.
        """
        opened = open(filepath, 'r', encoding=encoding)
        with self._locked(opened, 'read_text', filepath) as handle:
            return handle.read()

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write data to a given ``filepath`` with 'wb' mode.

        Note:
            ``write`` will create a directory if the directory of ``filepath``
            does not exist.

        Args:
            obj (bytes): Data to be written.
            filepath (str or Path): Path to write data.
        """
        self._ensure_parent_dir(filepath)
        opened = open(filepath, 'wb', opener=self._open_without_truncate)
        with self._locked(opened, 'write', filepath) as handle:
            # Discard the old content only now that the lock is held.
            handle.truncate(0)
            handle.write(obj)
            # Push buffered data to the file before the lock is released.
            handle.flush()

    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Write data to a given ``filepath`` with 'w' mode.

        Note:
            ``write_text`` will create a directory if the directory of
            ``filepath`` does not exist.

        Args:
            obj (str): Data to be written.
            filepath (str or Path): Path to write data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.
        """
        self._ensure_parent_dir(filepath)
        opened = open(
            filepath,
            'w',
            encoding=encoding,
            opener=self._open_without_truncate)
        with self._locked(opened, 'write_text', filepath) as handle:
            # Discard the old content only now that the lock is held.
            handle.truncate(0)
            handle.write(obj)
            # Push buffered data to the file before the lock is released.
            handle.flush()

    @contextlib.contextmanager
    def as_local_path(
            self,
            filepath: Union[str,
                            Path]) -> Generator[Union[str, Path], None, None]:
        """Only for unified API and do nothing."""
        yield filepath
class HTTPStorage(Storage):
"""HTTP and HTTPS storage."""

View File

@@ -16,7 +16,7 @@ import gast
import json
from modelscope import __version__
from modelscope.fileio.file import LocalStorage
from modelscope.fileio.file import LocalStorage, LocalStorageWithLock
from modelscope.metainfo import (Datasets, Heads, Hooks, LR_Schedulers,
Metrics, Models, Optimizers, Pipelines,
Preprocessors, TaskModels, Trainers)
@@ -26,7 +26,7 @@ from modelscope.utils.logger import get_logger
from modelscope.utils.registry import default_group
logger = get_logger()
storage = LocalStorage()
storage = LocalStorageWithLock()
p = Path(__file__)
# get the path of package 'modelscope'
@@ -624,6 +624,7 @@ def _save_index(index, file_path, file_list=None, with_template=False):
if with_template:
json_index = json_index.replace(MODELSCOPE_PATH.as_posix(),
TEMPLATE_PATH)
storage.write(json_index.encode(), file_path)
index[INDEX_KEY] = {
ast.literal_eval(k): v

View File

@@ -5,7 +5,8 @@ import unittest
from requests import HTTPError
from modelscope.fileio.file import File, HTTPStorage, LocalStorage
from modelscope.fileio.file import (File, HTTPStorage, LocalStorage,
LocalStorageWithLock)
class FileTest(unittest.TestCase):
@@ -24,6 +25,41 @@ class FileTest(unittest.TestCase):
os.remove(temp_name)
def test_local_storage_with_lock_in_single_case(self):
    """Round-trip bytes and then text through ``LocalStorageWithLock``."""
    storage = LocalStorageWithLock()
    # A private temporary directory replaces the private
    # ``tempfile._get_candidate_names`` helper and is removed even when an
    # assertion below fails.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_name = os.path.join(temp_dir, 'test')

        binary_content = b'12345'
        storage.write(binary_content, temp_name)
        self.assertEqual(binary_content, storage.read(temp_name))

        content = '12345'
        storage.write_text(content, temp_name)
        self.assertEqual(content, storage.read_text(temp_name))
def test_local_storage_with_lock_in_multi_case(self):
    """Run the write/read round trip in five concurrent threads."""
    import threading

    # Exceptions raised inside a worker thread do not reach the test runner
    # on their own, so they are collected here and checked after joining.
    failures = []

    def local_test():
        try:
            storage = LocalStorageWithLock()
            with tempfile.TemporaryDirectory() as temp_dir:
                # The 'nested' directory does not exist yet, so the write
                # call has to create it.
                temp_name = os.path.join(temp_dir, 'nested', 'test')

                binary_content = b'12345'
                storage.write(binary_content, temp_name)
                self.assertEqual(binary_content, storage.read(temp_name))

                content = '12345'
                storage.write_text(content, temp_name)
                self.assertEqual(content, storage.read_text(temp_name))
        except Exception as err:
            failures.append(err)

    workers = [threading.Thread(target=local_test) for _ in range(5)]
    for worker in workers:
        worker.start()
    # Wait for every worker; otherwise the test returns before they finish.
    for worker in workers:
        worker.join()

    self.assertEqual([], failures)
def test_http_storage(self):
storage = HTTPStorage()
url = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/texts/data.txt'