From aaa604cb16ff6e5d9fdfe6b7c978d0e5eb064d01 Mon Sep 17 00:00:00 2001 From: "wenmeng.zwm" Date: Mon, 22 Aug 2022 15:32:00 +0800 Subject: [PATCH] [to #43878347] device placement support certain gpu 1. add device util to verify, create and place device 2. pipeline and trainer support update 3. fix pipeline which use tf models does not place model to the right device usage ```python pipe = pipeline('damo/xxx', device='cpu') pipe = pipeline('damo/xxx', device='gpu') pipe = pipeline('damo/xxx', device='gpu:0') pipe = pipeline('damo/xxx', device='gpu:2') pipe = pipeline('damo/xxx', device='cuda') pipe = pipeline('damo/xxx', device='cuda:1') ``` Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9800672 --- modelscope/models/audio/ans/frcrn.py | 1 + modelscope/models/audio/kws/farfield/model.py | 1 + modelscope/models/base/base_model.py | 17 ++- .../models/cv/crowd_counting/cc_model.py | 4 +- .../cv/image_classification/mmcls_model.py | 2 +- .../product_retrieval_embedding/item_model.py | 5 +- .../mmr/models/clip_for_mm_video_embedding.py | 4 +- modelscope/pipelines/audio/ans_pipeline.py | 1 - modelscope/pipelines/base.py | 58 +++------ .../pipelines/cv/image_cartoon_pipeline.py | 14 ++- .../pipelines/cv/image_matting_pipeline.py | 28 ++--- .../cv/image_style_transfer_pipeline.py | 48 ++++---- .../pipelines/cv/ocr_detection_pipeline.py | 104 +++++++++-------- .../pipelines/cv/skin_retouching_pipeline.py | 29 +++-- .../video_multi_modal_embedding_pipeline.py | 3 +- .../pipelines/nlp/translation_pipeline.py | 2 +- modelscope/trainers/trainer.py | 9 +- modelscope/utils/constant.py | 6 + modelscope/utils/device.py | 110 ++++++++++++++++++ modelscope/utils/torch_utils.py | 11 -- .../test_key_word_spotting_farfield.py | 4 + tests/utils/test_device.py | 101 ++++++++++++++++ 22 files changed, 381 insertions(+), 181 deletions(-) create mode 100644 modelscope/utils/device.py create mode 100644 tests/utils/test_device.py diff --git a/modelscope/models/audio/ans/frcrn.py b/modelscope/models/audio/ans/frcrn.py index 38e4d720..ba78ab74 100644 --- a/modelscope/models/audio/ans/frcrn.py +++ b/modelscope/models/audio/ans/frcrn.py @@ -71,6 +71,7 @@ class FRCRNModel(TorchModel): model_dir (str): the model path. """ super().__init__(model_dir, *args, **kwargs) + kwargs.pop('device') self.model = FRCRN(*args, **kwargs) model_bin_file = os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) diff --git a/modelscope/models/audio/kws/farfield/model.py b/modelscope/models/audio/kws/farfield/model.py index 81e47350..428ec367 100644 --- a/modelscope/models/audio/kws/farfield/model.py +++ b/modelscope/models/audio/kws/farfield/model.py @@ -33,6 +33,7 @@ class FSMNSeleNetV2Decorator(TorchModel): ModelFile.TORCH_MODEL_BIN_FILE) self._model = None if os.path.exists(model_bin_file): + kwargs.pop('device') self._model = FSMNSeleNetV2(*args, **kwargs) checkpoint = torch.load(model_bin_file) self._model.load_state_dict(checkpoint, strict=False) diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py index 3b596769..279dbba2 100644 --- a/modelscope/models/base/base_model.py +++ b/modelscope/models/base/base_model.py @@ -10,6 +10,7 @@ from modelscope.hub.snapshot_download import snapshot_download from modelscope.models.builder import build_model from modelscope.utils.config import Config from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile +from modelscope.utils.device import device_placement, verify_device from modelscope.utils.file_utils import func_receive_dict_inputs from modelscope.utils.hub import parse_label_mapping from modelscope.utils.logger import get_logger @@ -24,8 +25,7 @@ class Model(ABC): def __init__(self, model_dir, *args, **kwargs): self.model_dir = model_dir device_name = kwargs.get('device', 'gpu') - assert device_name in ['gpu', - 'cpu'], 'device should be either cpu or gpu.' + verify_device(device_name) self._device_name = device_name def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: @@ -72,6 +72,7 @@ class Model(ABC): model_name_or_path: str, revision: Optional[str] = DEFAULT_MODEL_REVISION, cfg_dict: Config = None, + device: str = None, *model_args, **kwargs): """ Instantiate a model from local directory or remote model repo. Note @@ -97,7 +98,7 @@ class Model(ABC): osp.join(local_model_dir, ModelFile.CONFIGURATION)) task_name = cfg.task model_cfg = cfg.model - # TODO @wenmeng.zwm may should manually initialize model after model building + framework = cfg.framework if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): model_cfg.type = model_cfg.model_type @@ -105,8 +106,14 @@ class Model(ABC): model_cfg.model_dir = local_model_dir for k, v in kwargs.items(): model_cfg[k] = v - model = build_model( - model_cfg, task_name=task_name, default_args=kwargs) + if device is not None: + model_cfg.device = device + with device_placement(framework, device): + model = build_model( + model_cfg, task_name=task_name, default_args=kwargs) + else: + model = build_model( + model_cfg, task_name=task_name, default_args=kwargs) # dynamically add pipeline info to model for pipeline inference if hasattr(cfg, 'pipeline'): diff --git a/modelscope/models/cv/crowd_counting/cc_model.py b/modelscope/models/cv/crowd_counting/cc_model.py index 4e3d0e9f..582b26f4 100644 --- a/modelscope/models/cv/crowd_counting/cc_model.py +++ b/modelscope/models/cv/crowd_counting/cc_model.py @@ -13,8 +13,8 @@ from modelscope.utils.constant import Tasks Tasks.crowd_counting, module_name=Models.crowd_counting) class HRNetCrowdCounting(TorchModel): - def __init__(self, model_dir: str): - super().__init__(model_dir) + def __init__(self, model_dir: str, **kwargs): + super().__init__(model_dir, **kwargs) from .hrnet_aspp_relu import HighResolutionNet as HRNet_aspp_relu diff --git a/modelscope/models/cv/image_classification/mmcls_model.py b/modelscope/models/cv/image_classification/mmcls_model.py index 6a65656e..a6789d0b 100644 --- a/modelscope/models/cv/image_classification/mmcls_model.py +++ b/modelscope/models/cv/image_classification/mmcls_model.py @@ -10,7 +10,7 @@ from modelscope.utils.constant import Tasks Tasks.image_classification, module_name=Models.classification_model) class ClassificationModel(TorchModel): - def __init__(self, model_dir: str): + def __init__(self, model_dir: str, **kwargs): import mmcv from mmcls.models import build_classifier diff --git a/modelscope/models/cv/product_retrieval_embedding/item_model.py b/modelscope/models/cv/product_retrieval_embedding/item_model.py index 2a893669..85a636c0 100644 --- a/modelscope/models/cv/product_retrieval_embedding/item_model.py +++ b/modelscope/models/cv/product_retrieval_embedding/item_model.py @@ -13,8 +13,8 @@ from modelscope.models.cv.product_retrieval_embedding.item_embedding import ( preprocess, resnet50_embed) from modelscope.outputs import OutputKeys from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.device import create_device from modelscope.utils.logger import get_logger -from modelscope.utils.torch_utils import create_device logger = get_logger() @@ -48,9 +48,8 @@ class ProductRetrievalEmbedding(TorchModel): filter_param(src_params, own_state) model.load_state_dict(own_state) - cpu_flag = device == 'cpu' self.device = create_device( - cpu_flag) # device.type == "cpu" or device.type == "cuda" + device) # device.type == "cpu" or device.type == "cuda" self.use_gpu = self.device.type == 'cuda' # config the model path diff --git a/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py b/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py index 88a4ddda..4e959a17 100644 --- a/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py +++ b/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py @@ -24,8 +24,8 @@ logger = get_logger() Tasks.video_multi_modal_embedding, module_name=Models.video_clip) class VideoCLIPForMultiModalEmbedding(TorchModel): - def __init__(self, model_dir, device_id=-1): - super().__init__(model_dir=model_dir, device_id=device_id) + def __init__(self, model_dir, **kwargs): + super().__init__(model_dir=model_dir, **kwargs) # model config parameters with open(f'{model_dir}/{ModelFile.CONFIGURATION}', 'r') as json_file: model_config = json.load(json_file) diff --git a/modelscope/pipelines/audio/ans_pipeline.py b/modelscope/pipelines/audio/ans_pipeline.py index e9cb8db3..410a7cb5 100644 --- a/modelscope/pipelines/audio/ans_pipeline.py +++ b/modelscope/pipelines/audio/ans_pipeline.py @@ -11,7 +11,6 @@ from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.utils.constant import Tasks -from modelscope.utils.torch_utils import create_device def audio_norm(x): diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index 041dfb34..180ad757 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -14,9 +14,10 @@ from modelscope.outputs import TASK_OUTPUTS from modelscope.preprocessors import Preprocessor from modelscope.utils.config import Config from modelscope.utils.constant import Frameworks, ModelFile +from modelscope.utils.device import (create_device, device_placement, + verify_device) from modelscope.utils.import_utils import is_tf_available, is_torch_available from modelscope.utils.logger import get_logger -from modelscope.utils.torch_utils import create_device from .util import is_model, is_official_hub_path if is_torch_available(): @@ -41,7 +42,8 @@ class Pipeline(ABC): logger.info(f'initiate model from location {model}.') # expecting model has been prefetched to local cache beforehand return Model.from_pretrained( - model, model_prefetched=True) if is_model(model) else model + model, model_prefetched=True, + device=self.device_name) if is_model(model) else model elif isinstance(model, Model): return model else: @@ -74,11 +76,15 @@ class Pipeline(ABC): config_file(str, optional): Filepath to configuration file. model: (list of) Model name or model object preprocessor: (list of) Preprocessor object - device (str): gpu device or cpu device to use + device (str): device str, should be either cpu, cuda, gpu, gpu:X or cuda:X auto_collate (bool): automatically to convert data to tensor or not. """ if config_file is not None: self.cfg = Config.from_file(config_file) + + verify_device(device) + self.device_name = device + if not isinstance(model, List): self.model = self.initiate_single_model(model) self.models = [self.model] @@ -94,15 +100,15 @@ class Pipeline(ABC): else: self.framework = None - assert device in ['gpu', 'cpu'], 'device should be either cpu or gpu.' - self.device_name = device if self.framework == Frameworks.torch: - self.device = create_device(self.device_name == 'cpu') + self.device = create_device(self.device_name) self._model_prepare = False self._model_prepare_lock = Lock() self._auto_collate = auto_collate def prepare_model(self): + """ Place model on certain device for pytorch models before first inference + """ self._model_prepare_lock.acquire(timeout=600) def _prepare_single(model): @@ -125,39 +131,6 @@ class Pipeline(ABC): self._model_prepare = True self._model_prepare_lock.release() - @contextmanager - def place_device(self): - """ device placement function, allow user to specify which device to place pipeline - - Returns: - Context manager - - Examples: - - ```python - # Requests for using pipeline on cuda:0 for gpu - pipeline = pipeline(..., device='gpu') - with pipeline.device(): - output = pipe(...) - ``` - """ - if self.framework == Frameworks.tf: - if self.device_name == 'cpu': - with tf.device('/CPU:0'): - yield - else: - with tf.device('/device:GPU:0'): - yield - - elif self.framework == Frameworks.torch: - if self.device_name == 'gpu': - device = create_device() - if device.type == 'gpu': - torch.cuda.set_device(device) - yield - else: - yield - def _get_framework(self) -> str: frameworks = [] for m in self.models: @@ -272,10 +245,11 @@ class Pipeline(ABC): postprocess_params = kwargs.get('postprocess_params') out = self.preprocess(input, **preprocess_params) - with self.place_device(): - if self.framework == Frameworks.torch and self._auto_collate: + with device_placement(self.framework, self.device_name): + if self.framework == Frameworks.torch: with torch.no_grad(): - out = self._collate_fn(out) + if self._auto_collate: + out = self._collate_fn(out) out = self.forward(out, **forward_params) else: out = self.forward(out, **forward_params) diff --git a/modelscope/pipelines/cv/image_cartoon_pipeline.py b/modelscope/pipelines/cv/image_cartoon_pipeline.py index 9c3c418e..eb669354 100644 --- a/modelscope/pipelines/cv/image_cartoon_pipeline.py +++ b/modelscope/pipelines/cv/image_cartoon_pipeline.py @@ -16,6 +16,7 @@ from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import LoadImage from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger +from ...utils.device import device_placement if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -36,11 +37,14 @@ class ImageCartoonPipeline(Pipeline): model: model id on modelscope hub. """ super().__init__(model=model, **kwargs) - self.facer = FaceAna(self.model) - self.sess_anime_head = self.load_sess( - os.path.join(self.model, 'cartoon_anime_h.pb'), 'model_anime_head') - self.sess_anime_bg = self.load_sess( - os.path.join(self.model, 'cartoon_anime_bg.pb'), 'model_anime_bg') + with device_placement(self.framework, self.device_name): + self.facer = FaceAna(self.model) + self.sess_anime_head = self.load_sess( + os.path.join(self.model, 'cartoon_anime_h.pb'), + 'model_anime_head') + self.sess_anime_bg = self.load_sess( + os.path.join(self.model, 'cartoon_anime_bg.pb'), + 'model_anime_bg') self.box_width = 288 global_mask = cv2.imread(os.path.join(self.model, 'alpha.jpg')) diff --git a/modelscope/pipelines/cv/image_matting_pipeline.py b/modelscope/pipelines/cv/image_matting_pipeline.py index d9e81959..d7b7fc3c 100644 --- a/modelscope/pipelines/cv/image_matting_pipeline.py +++ b/modelscope/pipelines/cv/image_matting_pipeline.py @@ -10,6 +10,7 @@ from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import LoadImage from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.device import device_placement from modelscope.utils.logger import get_logger logger = get_logger() @@ -31,19 +32,20 @@ class ImageMattingPipeline(Pipeline): tf = tf.compat.v1 model_path = osp.join(self.model, ModelFile.TF_GRAPH_FILE) - config = tf.ConfigProto(allow_soft_placement=True) - config.gpu_options.allow_growth = True - self._session = tf.Session(config=config) - with self._session.as_default(): - logger.info(f'loading model from {model_path}') - with tf.gfile.FastGFile(model_path, 'rb') as f: - graph_def = tf.GraphDef() - graph_def.ParseFromString(f.read()) - tf.import_graph_def(graph_def, name='') - self.output = self._session.graph.get_tensor_by_name( - 'output_png:0') - self.input_name = 'input_image:0' - logger.info('load model done') + with device_placement(self.framework, self.device_name): + config = tf.ConfigProto(allow_soft_placement=True) + config.gpu_options.allow_growth = True + self._session = tf.Session(config=config) + with self._session.as_default(): + logger.info(f'loading model from {model_path}') + with tf.gfile.FastGFile(model_path, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def, name='') + self.output = self._session.graph.get_tensor_by_name( + 'output_png:0') + self.input_name = 'input_image:0' + logger.info('load model done') def preprocess(self, input: Input) -> Dict[str, Any]: img = LoadImage.convert_to_ndarray(input) diff --git a/modelscope/pipelines/cv/image_style_transfer_pipeline.py b/modelscope/pipelines/cv/image_style_transfer_pipeline.py index a67aaec2..827a0d44 100644 --- a/modelscope/pipelines/cv/image_style_transfer_pipeline.py +++ b/modelscope/pipelines/cv/image_style_transfer_pipeline.py @@ -10,6 +10,7 @@ from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import LoadImage from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.device import device_placement from modelscope.utils.logger import get_logger logger = get_logger() @@ -31,30 +32,31 @@ class ImageStyleTransferPipeline(Pipeline): tf = tf.compat.v1 model_path = osp.join(self.model, ModelFile.TF_GRAPH_FILE) - config = tf.ConfigProto(allow_soft_placement=True) - config.gpu_options.allow_growth = True - self._session = tf.Session(config=config) - self.max_length = 800 - with self._session.as_default(): - logger.info(f'loading model from {model_path}') - with tf.gfile.FastGFile(model_path, 'rb') as f: - graph_def = tf.GraphDef() - graph_def.ParseFromString(f.read()) - tf.import_graph_def(graph_def, name='') + with device_placement(self.framework, self.device_name): + config = tf.ConfigProto(allow_soft_placement=True) + config.gpu_options.allow_growth = True + self._session = tf.Session(config=config) + self.max_length = 800 + with self._session.as_default(): + logger.info(f'loading model from {model_path}') + with tf.gfile.FastGFile(model_path, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def, name='') - self.content = tf.get_default_graph().get_tensor_by_name( - 'content:0') - self.style = tf.get_default_graph().get_tensor_by_name( - 'style:0') - self.output = tf.get_default_graph().get_tensor_by_name( - 'stylized_output:0') - self.attention = tf.get_default_graph().get_tensor_by_name( - 'attention_map:0') - self.inter_weight = tf.get_default_graph().get_tensor_by_name( - 'inter_weight:0') - self.centroids = tf.get_default_graph().get_tensor_by_name( - 'centroids:0') - logger.info('load model done') + self.content = tf.get_default_graph().get_tensor_by_name( + 'content:0') + self.style = tf.get_default_graph().get_tensor_by_name( + 'style:0') + self.output = tf.get_default_graph().get_tensor_by_name( + 'stylized_output:0') + self.attention = tf.get_default_graph().get_tensor_by_name( + 'attention_map:0') + self.inter_weight = tf.get_default_graph( + ).get_tensor_by_name('inter_weight:0') + self.centroids = tf.get_default_graph().get_tensor_by_name( + 'centroids:0') + logger.info('load model done') def _sanitize_parameters(self, **pipeline_parameters): return pipeline_parameters, {}, {} diff --git a/modelscope/pipelines/cv/ocr_detection_pipeline.py b/modelscope/pipelines/cv/ocr_detection_pipeline.py index 32209c1e..b54ad96d 100644 --- a/modelscope/pipelines/cv/ocr_detection_pipeline.py +++ b/modelscope/pipelines/cv/ocr_detection_pipeline.py @@ -11,6 +11,7 @@ from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import LoadImage from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.device import device_placement from modelscope.utils.logger import get_logger from .ocr_utils import (SegLinkDetector, cal_width, combine_segments_python, decode_segments_links_python, nms_python, @@ -51,66 +52,67 @@ class OCRDetectionPipeline(Pipeline): osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER), 'checkpoint-80000') - config = tf.ConfigProto(allow_soft_placement=True) - config.gpu_options.allow_growth = True - self._session = tf.Session(config=config) - self.input_images = tf.placeholder( - tf.float32, shape=[1, 1024, 1024, 3], name='input_images') - self.output = {} + with device_placement(self.framework, self.device_name): + config = tf.ConfigProto(allow_soft_placement=True) + config.gpu_options.allow_growth = True + self._session = tf.Session(config=config) + self.input_images = tf.placeholder( + tf.float32, shape=[1, 1024, 1024, 3], name='input_images') + self.output = {} - with tf.variable_scope('', reuse=tf.AUTO_REUSE): - global_step = tf.get_variable( - 'global_step', [], - initializer=tf.constant_initializer(0), - dtype=tf.int64, - trainable=False) - variable_averages = tf.train.ExponentialMovingAverage( - 0.997, global_step) + with tf.variable_scope('', reuse=tf.AUTO_REUSE): + global_step = tf.get_variable( + 'global_step', [], + initializer=tf.constant_initializer(0), + dtype=tf.int64, + trainable=False) + variable_averages = tf.train.ExponentialMovingAverage( + 0.997, global_step) - # detector - detector = SegLinkDetector() - all_maps = detector.build_model( - self.input_images, is_training=False) + # detector + detector = SegLinkDetector() + all_maps = detector.build_model( + self.input_images, is_training=False) - # decode local predictions - all_nodes, all_links, all_reg = [], [], [] - for i, maps in enumerate(all_maps): - cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2] - reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE) + # decode local predictions + all_nodes, all_links, all_reg = [], [], [] + for i, maps in enumerate(all_maps): + cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2] + reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE) - cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2])) + cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2])) - lnk_prob_pos = tf.nn.softmax( - tf.reshape(lnk_maps, [-1, 4])[:, :2]) - lnk_prob_mut = tf.nn.softmax( - tf.reshape(lnk_maps, [-1, 4])[:, 2:]) - lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1) + lnk_prob_pos = tf.nn.softmax( + tf.reshape(lnk_maps, [-1, 4])[:, :2]) + lnk_prob_mut = tf.nn.softmax( + tf.reshape(lnk_maps, [-1, 4])[:, 2:]) + lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1) - all_nodes.append(cls_prob) - all_links.append(lnk_prob) - all_reg.append(reg_maps) + all_nodes.append(cls_prob) + all_links.append(lnk_prob) + all_reg.append(reg_maps) - # decode segments and links - image_size = tf.shape(self.input_images)[1:3] - segments, group_indices, segment_counts, _ = decode_segments_links_python( - image_size, - all_nodes, - all_links, - all_reg, - anchor_sizes=list(detector.anchor_sizes)) + # decode segments and links + image_size = tf.shape(self.input_images)[1:3] + segments, group_indices, segment_counts, _ = decode_segments_links_python( + image_size, + all_nodes, + all_links, + all_reg, + anchor_sizes=list(detector.anchor_sizes)) - # combine segments - combined_rboxes, combined_counts = combine_segments_python( - segments, group_indices, segment_counts) - self.output['combined_rboxes'] = combined_rboxes - self.output['combined_counts'] = combined_counts + # combine segments + combined_rboxes, combined_counts = combine_segments_python( + segments, group_indices, segment_counts) + self.output['combined_rboxes'] = combined_rboxes + self.output['combined_counts'] = combined_counts - with self._session.as_default() as sess: - logger.info(f'loading model from {model_path}') - # load model - model_loader = tf.train.Saver( - variable_averages.variables_to_restore()) - model_loader.restore(sess, model_path) + with self._session.as_default() as sess: + logger.info(f'loading model from {model_path}') + # load model + model_loader = tf.train.Saver( + variable_averages.variables_to_restore()) + model_loader.restore(sess, model_path) def preprocess(self, input: Input) -> Dict[str, Any]: img = LoadImage.convert_to_ndarray(input) diff --git a/modelscope/pipelines/cv/skin_retouching_pipeline.py b/modelscope/pipelines/cv/skin_retouching_pipeline.py index d9b49ff3..f8c9de60 100644 --- a/modelscope/pipelines/cv/skin_retouching_pipeline.py +++ b/modelscope/pipelines/cv/skin_retouching_pipeline.py @@ -23,6 +23,7 @@ from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import LoadImage from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.device import create_device, device_placement from modelscope.utils.logger import get_logger if tf.__version__ >= '2.0': @@ -42,12 +43,9 @@ class SkinRetouchingPipeline(Pipeline): Args: model: model id on modelscope hub. """ - super().__init__(model=model) + super().__init__(model=model, device=device) - if torch.cuda.is_available() and device == 'gpu': - device = 'cuda' - else: - device = 'cpu' + device = create_device(self.device_name) model_path = os.path.join(self.model, ModelFile.TORCH_MODEL_FILE) detector_model_path = os.path.join( self.model, 'retinaface_resnet50_2020-07-20_old_torch.pth') @@ -81,16 +79,17 @@ class SkinRetouchingPipeline(Pipeline): self.skin_model_path = skin_model_path if self.skin_model_path is not None: - config = tf.ConfigProto(allow_soft_placement=True) - config.gpu_options.per_process_gpu_memory_fraction = 0.3 - config.gpu_options.allow_growth = True - self.sess = tf.Session(config=config) - with tf.gfile.FastGFile(self.skin_model_path, 'rb') as f: - graph_def = tf.GraphDef() - graph_def.ParseFromString(f.read()) - self.sess.graph.as_default() - tf.import_graph_def(graph_def, name='') - self.sess.run(tf.global_variables_initializer()) + with device_placement(self.framework, self.device_name): + config = tf.ConfigProto(allow_soft_placement=True) + config.gpu_options.per_process_gpu_memory_fraction = 0.3 + config.gpu_options.allow_growth = True + self.sess = tf.Session(config=config) + with tf.gfile.FastGFile(self.skin_model_path, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + self.sess.graph.as_default() + tf.import_graph_def(graph_def, name='') + self.sess.run(tf.global_variables_initializer()) self.image_files_transforms = transforms.Compose([ transforms.ToTensor(), diff --git a/modelscope/pipelines/multi_modal/video_multi_modal_embedding_pipeline.py b/modelscope/pipelines/multi_modal/video_multi_modal_embedding_pipeline.py index 166d3f06..bc697b05 100644 --- a/modelscope/pipelines/multi_modal/video_multi_modal_embedding_pipeline.py +++ b/modelscope/pipelines/multi_modal/video_multi_modal_embedding_pipeline.py @@ -4,6 +4,7 @@ from modelscope.metainfo import Pipelines from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.utils.constant import Tasks +from modelscope.utils.device import device_placement from modelscope.utils.logger import get_logger logger = get_logger() @@ -26,7 +27,7 @@ class VideoMultiModalEmbeddingPipeline(Pipeline): return input def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]: - with self.place_device(): + with device_placement(self.framework, self.device_name): out = self.forward(input) self._check_output(out) diff --git a/modelscope/pipelines/nlp/translation_pipeline.py b/modelscope/pipelines/nlp/translation_pipeline.py index 909e3c6c..b9b74ce4 100644 --- a/modelscope/pipelines/nlp/translation_pipeline.py +++ b/modelscope/pipelines/nlp/translation_pipeline.py @@ -31,7 +31,7 @@ class TranslationPipeline(Pipeline): @param model: A Model instance. """ - super().__init__(model=model) + super().__init__(model=model, **kwargs) model = self.model.model_dir tf.reset_default_graph() diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index 0916495c..c48ab2cd 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -36,11 +36,11 @@ from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields, ConfigKeys, Hubs, ModeKeys, ModelFile, Tasks, TrainerStages) from modelscope.utils.data_utils import to_device +from modelscope.utils.device import create_device, verify_device from modelscope.utils.file_utils import func_receive_dict_inputs from modelscope.utils.logger import get_logger from modelscope.utils.registry import build_from_cfg -from modelscope.utils.torch_utils import (create_device, get_dist_info, - init_dist) +from modelscope.utils.torch_utils import get_dist_info, init_dist from .base import BaseTrainer from .builder import TRAINERS from .default_config import DEFAULT_CONFIG @@ -150,9 +150,8 @@ class EpochBasedTrainer(BaseTrainer): self.eval_preprocessor.mode = ModeKeys.EVAL device_name = kwargs.get('device', 'gpu') - assert device_name in ['gpu', - 'cpu'], 'device should be either cpu or gpu.' - self.device = create_device(device_name == 'cpu') + verify_device(device_name) + self.device = create_device(device_name) self.train_dataset = self.to_task_dataset( train_dataset, diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 1a3fb7c3..993a3e42 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -290,3 +290,9 @@ class ColorCodes: GREEN = '\033[92m' RED = '\033[91m' END = '\033[0m' + + +class Devices: + """device used for training and inference""" + cpu = 'cpu' + gpu = 'gpu' diff --git a/modelscope/utils/device.py b/modelscope/utils/device.py new file mode 100644 index 00000000..aa8fda66 --- /dev/null +++ b/modelscope/utils/device.py @@ -0,0 +1,110 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from contextlib import contextmanager + +from modelscope.utils.constant import Devices, Frameworks +from modelscope.utils.import_utils import is_tf_available, is_torch_available +from modelscope.utils.logger import get_logger + +logger = get_logger() + +if is_tf_available(): + import tensorflow as tf + +if is_torch_available(): + import torch + + +def verify_device(device_name): + """ Verify device is valid, device should be either cpu, cuda, gpu, cuda:X or gpu:X. + + Args: + device (str): device str, should be either cpu, cuda, gpu, gpu:X or cuda:X + where X is the ordinal for gpu device. + + Return: + device info (tuple): device_type and device_id, if device_id is not set, will use 0 as default. + """ + device_name = device_name.lower() + eles = device_name.split(':') + err_msg = 'device should be either cpu, cuda, gpu, gpu:X or cuda:X where X is the ordinal for gpu device.' + assert len(eles) <= 2, err_msg + assert eles[0] in ['cpu', 'cuda', 'gpu'], err_msg + device_type = eles[0] + device_id = None + if len(eles) > 1: + device_id = int(eles[1]) + if device_type == 'cuda': + device_type = Devices.gpu + if device_type == Devices.gpu and device_id is None: + device_id = 0 + return device_type, device_id + + +@contextmanager +def device_placement(framework, device_name='gpu:0'): + """ Device placement function, allow user to specify which device to place model or tensor + Args: + framework (str): tensorflow or pytorch. + device (str): gpu or cpu to use, if you want to specify certain gpu, + use gpu:$gpu_id or cuda:$gpu_id. + + Returns: + Context manager + + Examples: + + ```python + # Requests for using model on cuda:0 for gpu + with device_placement('pytorch', device='gpu:0'): + model = Model.from_pretrained(...) + ``` + """ + device_type, device_id = verify_device(device_name) + + if framework == Frameworks.tf: + if device_type == Devices.gpu and not tf.test.is_gpu_available(): + logger.warning( + 'tensorflow cuda is not available, using cpu instead.') + device_type = Devices.cpu + if device_type == Devices.cpu: + with tf.device('/CPU:0'): + yield + else: + if device_type == Devices.gpu: + with tf.device(f'/device:gpu:{device_id}'): + yield + + elif framework == Frameworks.torch: + if device_type == Devices.gpu: + if torch.cuda.is_available(): + torch.cuda.set_device(f'cuda:{device_id}') + else: + logger.warning('cuda is not available, using cpu instead.') + yield + else: + yield + + +def create_device(device_name) -> torch.DeviceObjType: + """ create torch device + + Args: + device_name (str): cpu, gpu, gpu:0, cuda:0 etc. + """ + device_type, device_id = verify_device(device_name) + use_cuda = False + if device_type == Devices.gpu: + use_cuda = True + if not torch.cuda.is_available(): + logger.warning( + 'cuda is not available, create gpu device failed, using cpu instead.' + ) + use_cuda = False + + if use_cuda: + device = torch.device(f'cuda:{device_id}') + else: + device = torch.device('cpu') + + return device diff --git a/modelscope/utils/torch_utils.py b/modelscope/utils/torch_utils.py index 1f157f9a..45e33c3e 100644 --- a/modelscope/utils/torch_utils.py +++ b/modelscope/utils/torch_utils.py @@ -132,17 +132,6 @@ def master_only(func: Callable) -> Callable: return wrapper -def create_device(cpu: bool = False) -> torch.DeviceObjType: - use_cuda = torch.cuda.is_available() and not cpu - if use_cuda: - local_rank = os.environ.get('LOCAL_RANK', 0) - device = torch.device(f'cuda:{local_rank}') - else: - device = torch.device('cpu') - - return device - - def make_tmp_dir(): """Make sure each rank has the same temporary directory on the distributed mode. """ diff --git a/tests/pipelines/test_key_word_spotting_farfield.py b/tests/pipelines/test_key_word_spotting_farfield.py index e7967edc..0b64831a 100644 --- a/tests/pipelines/test_key_word_spotting_farfield.py +++ b/tests/pipelines/test_key_word_spotting_farfield.py @@ -41,3 +41,7 @@ class KWSFarfieldTest(unittest.TestCase): result = kws(data) self.assertEqual(len(result['kws_list']), 5) print(result['kws_list'][-1]) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/utils/test_device.py b/tests/utils/test_device.py new file mode 100644 index 00000000..3135b214 --- /dev/null +++ b/tests/utils/test_device.py @@ -0,0 +1,101 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import shutil +import tempfile +import time +import unittest + +import torch + +from modelscope.utils.constant import Frameworks +from modelscope.utils.device import (create_device, device_placement, + verify_device) + +# import tensorflow must be imported after torch is imported when using tf1.15 +import tensorflow as tf # isort:skip + + +class DeviceTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + def tearDown(self): + super().tearDown() + + def test_verify(self): + device_name, device_id = verify_device('cpu') + self.assertEqual(device_name, 'cpu') + self.assertTrue(device_id is None) + device_name, device_id = verify_device('CPU') + self.assertEqual(device_name, 'cpu') + + device_name, device_id = verify_device('gpu') + self.assertEqual(device_name, 'gpu') + self.assertTrue(device_id == 0) + + device_name, device_id = verify_device('cuda') + self.assertEqual(device_name, 'gpu') + self.assertTrue(device_id == 0) + + device_name, device_id = verify_device('cuda:0') + self.assertEqual(device_name, 'gpu') + self.assertTrue(device_id == 0) + + device_name, device_id = verify_device('gpu:1') + self.assertEqual(device_name, 'gpu') + self.assertTrue(device_id == 1) + + with self.assertRaises(AssertionError): + verify_device('xgu') + + def test_create_device_torch(self): + if torch.cuda.is_available(): + target_device_type = 'cuda' + target_device_index = 0 + else: + target_device_type = 'cpu' + target_device_index = None + device = create_device('gpu') + self.assertTrue(isinstance(device, torch.device)) + self.assertTrue(device.type == target_device_type) + self.assertTrue(device.index == target_device_index) + + device = create_device('gpu:0') + self.assertTrue(isinstance(device, torch.device)) + self.assertTrue(device.type == target_device_type) + self.assertTrue(device.index == target_device_index) + + device = create_device('cuda') + self.assertTrue(device.type == target_device_type) + self.assertTrue(isinstance(device, torch.device)) + self.assertTrue(device.index == target_device_index) + + device = create_device('cuda:0') + self.assertTrue(isinstance(device, torch.device)) + self.assertTrue(device.type == target_device_type) + self.assertTrue(device.index == target_device_index) + + def test_device_placement_cpu(self): + with device_placement(Frameworks.torch, 'cpu'): + pass + + def test_device_placement_tf_gpu(self): + tf.debugging.set_log_device_placement(True) + with device_placement(Frameworks.tf, 'gpu:0'): + a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + b = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + c = tf.matmul(a, b) + s = tf.Session() + s.run(c) + tf.debugging.set_log_device_placement(False) + + def test_device_placement_torch_gpu(self): + with device_placement(Frameworks.torch, 'gpu:0'): + if torch.cuda.is_available(): + self.assertEqual(torch.cuda.current_device(), 0) + + +if __name__ == '__main__': + unittest.main()