diff --git a/modelscope/models/audio/ans/frcrn.py b/modelscope/models/audio/ans/frcrn.py
index 38e4d720..ba78ab74 100644
--- a/modelscope/models/audio/ans/frcrn.py
+++ b/modelscope/models/audio/ans/frcrn.py
@@ -71,6 +71,7 @@ class FRCRNModel(TorchModel):
             model_dir (str): the model path.
         """
         super().__init__(model_dir, *args, **kwargs)
+        kwargs.pop('device', None)
         self.model = FRCRN(*args, **kwargs)
         model_bin_file = os.path.join(model_dir,
                                       ModelFile.TORCH_MODEL_BIN_FILE)
diff --git a/modelscope/models/audio/kws/farfield/model.py b/modelscope/models/audio/kws/farfield/model.py
index 81e47350..428ec367 100644
--- a/modelscope/models/audio/kws/farfield/model.py
+++ b/modelscope/models/audio/kws/farfield/model.py
@@ -33,6 +33,7 @@ class FSMNSeleNetV2Decorator(TorchModel):
                                       ModelFile.TORCH_MODEL_BIN_FILE)
         self._model = None
         if os.path.exists(model_bin_file):
+            kwargs.pop('device', None)
            self._model = FSMNSeleNetV2(*args, **kwargs)
            checkpoint = torch.load(model_bin_file)
            self._model.load_state_dict(checkpoint, strict=False)
diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py
index 3b596769..279dbba2 100644
--- a/modelscope/models/base/base_model.py
+++ b/modelscope/models/base/base_model.py
@@ -10,6 +10,7 @@ from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models.builder import build_model
 from modelscope.utils.config import Config
 from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
+from modelscope.utils.device import device_placement, verify_device
 from modelscope.utils.file_utils import func_receive_dict_inputs
 from modelscope.utils.hub import parse_label_mapping
 from modelscope.utils.logger import get_logger
@@ -24,8 +25,7 @@ class Model(ABC):
     def __init__(self, model_dir, *args, **kwargs):
         self.model_dir = model_dir
         device_name = kwargs.get('device', 'gpu')
-        assert device_name in ['gpu',
-                               'cpu'], 'device should be either cpu or gpu.'
+        verify_device(device_name)
         self._device_name = device_name

     def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
@@ -72,6 +72,7 @@ class Model(ABC):
                         model_name_or_path: str,
                         revision: Optional[str] = DEFAULT_MODEL_REVISION,
                         cfg_dict: Config = None,
+                        device: str = None,
                         *model_args,
                         **kwargs):
         """ Instantiate a model from local directory or remote model repo. Note
@@ -97,7 +98,7 @@ class Model(ABC):
             osp.join(local_model_dir, ModelFile.CONFIGURATION))
         task_name = cfg.task
         model_cfg = cfg.model
-        # TODO @wenmeng.zwm may should manually initialize model after model building
+        framework = cfg.framework
         if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
             model_cfg.type = model_cfg.model_type
@@ -105,8 +106,14 @@ class Model(ABC):
         model_cfg.model_dir = local_model_dir
         for k, v in kwargs.items():
             model_cfg[k] = v
-        model = build_model(
-            model_cfg, task_name=task_name, default_args=kwargs)
+        if device is not None:
+            model_cfg.device = device
+            with device_placement(framework, device):
+                model = build_model(
+                    model_cfg, task_name=task_name, default_args=kwargs)
+        else:
+            model = build_model(
+                model_cfg, task_name=task_name, default_args=kwargs)

         # dynamically add pipeline info to model for pipeline inference
         if hasattr(cfg, 'pipeline'):
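With the hunks above, `Model.from_pretrained` now accepts an explicit `device` argument and, when it is given, builds the model under `device_placement`. A minimal sketch of the new call, assuming a valid hub model id (the id below is only illustrative):

```python
from modelscope.models import Model

# The model id is illustrative; any ModelScope hub model id works.
# The device string is validated by verify_device, so 'cpu', 'gpu',
# 'cuda', 'gpu:1' and 'cuda:1' are all accepted.
model = Model.from_pretrained(
    'damo/speech_frcrn_ans_cirm_16k', device='gpu:0')
```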
diff --git a/modelscope/models/cv/crowd_counting/cc_model.py b/modelscope/models/cv/crowd_counting/cc_model.py
index 4e3d0e9f..582b26f4 100644
--- a/modelscope/models/cv/crowd_counting/cc_model.py
+++ b/modelscope/models/cv/crowd_counting/cc_model.py
@@ -13,8 +13,8 @@ from modelscope.utils.constant import Tasks
     Tasks.crowd_counting, module_name=Models.crowd_counting)
 class HRNetCrowdCounting(TorchModel):

-    def __init__(self, model_dir: str):
-        super().__init__(model_dir)
+    def __init__(self, model_dir: str, **kwargs):
+        super().__init__(model_dir, **kwargs)

         from .hrnet_aspp_relu import HighResolutionNet as HRNet_aspp_relu
diff --git a/modelscope/models/cv/image_classification/mmcls_model.py b/modelscope/models/cv/image_classification/mmcls_model.py
index 6a65656e..a6789d0b 100644
--- a/modelscope/models/cv/image_classification/mmcls_model.py
+++ b/modelscope/models/cv/image_classification/mmcls_model.py
@@ -10,7 +10,7 @@ from modelscope.utils.constant import Tasks
     Tasks.image_classification, module_name=Models.classification_model)
 class ClassificationModel(TorchModel):

-    def __init__(self, model_dir: str):
+    def __init__(self, model_dir: str, **kwargs):
         import mmcv
         from mmcls.models import build_classifier
diff --git a/modelscope/models/cv/product_retrieval_embedding/item_model.py b/modelscope/models/cv/product_retrieval_embedding/item_model.py
index 2a893669..85a636c0 100644
--- a/modelscope/models/cv/product_retrieval_embedding/item_model.py
+++ b/modelscope/models/cv/product_retrieval_embedding/item_model.py
@@ -13,8 +13,8 @@ from modelscope.models.cv.product_retrieval_embedding.item_embedding import (
     preprocess, resnet50_embed)
 from modelscope.outputs import OutputKeys
 from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.device import create_device
 from modelscope.utils.logger import get_logger
-from modelscope.utils.torch_utils import create_device

 logger = get_logger()
@@ -48,9 +48,8 @@ class ProductRetrievalEmbedding(TorchModel):
             filter_param(src_params, own_state)
             model.load_state_dict(own_state)

-        cpu_flag = device == 'cpu'
         self.device = create_device(
-            cpu_flag)  # device.type == "cpu" or device.type == "cuda"
+            device)  # device.type == "cpu" or device.type == "cuda"
         self.use_gpu = self.device.type == 'cuda'

         # config the model path
diff --git a/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py b/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py
index 88a4ddda..4e959a17 100644
--- a/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py
+++ b/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py
@@ -24,8 +24,8 @@ logger = get_logger()
     Tasks.video_multi_modal_embedding, module_name=Models.video_clip)
 class VideoCLIPForMultiModalEmbedding(TorchModel):

-    def __init__(self, model_dir, device_id=-1):
-        super().__init__(model_dir=model_dir, device_id=device_id)
+    def __init__(self, model_dir, **kwargs):
+        super().__init__(model_dir=model_dir, **kwargs)
         # model config parameters
         with open(f'{model_dir}/{ModelFile.CONFIGURATION}', 'r') as json_file:
             model_config = json.load(json_file)
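The four model diffs above all apply the same fix: `__init__` must accept and forward `**kwargs`, otherwise the `device` option injected by `from_pretrained` raises a `TypeError`. A sketch of the pattern (the subclass name is hypothetical):

```python
from modelscope.models.base import TorchModel


class MyModel(TorchModel):  # hypothetical subclass for illustration

    def __init__(self, model_dir: str, **kwargs):
        # Forwarding **kwargs lets pipeline-supplied options such as
        # device='gpu:0' reach Model.__init__, where verify_device runs.
        super().__init__(model_dir, **kwargs)
```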
diff --git a/modelscope/pipelines/audio/ans_pipeline.py b/modelscope/pipelines/audio/ans_pipeline.py
index e9cb8db3..410a7cb5 100644
--- a/modelscope/pipelines/audio/ans_pipeline.py
+++ b/modelscope/pipelines/audio/ans_pipeline.py
@@ -11,7 +11,6 @@ from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.utils.constant import Tasks
-from modelscope.utils.torch_utils import create_device


 def audio_norm(x):
diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py
index 041dfb34..180ad757 100644
--- a/modelscope/pipelines/base.py
+++ b/modelscope/pipelines/base.py
@@ -14,9 +14,10 @@ from modelscope.outputs import TASK_OUTPUTS
 from modelscope.preprocessors import Preprocessor
 from modelscope.utils.config import Config
 from modelscope.utils.constant import Frameworks, ModelFile
+from modelscope.utils.device import (create_device, device_placement,
+                                     verify_device)
 from modelscope.utils.import_utils import is_tf_available, is_torch_available
 from modelscope.utils.logger import get_logger
-from modelscope.utils.torch_utils import create_device
 from .util import is_model, is_official_hub_path

 if is_torch_available():
@@ -41,7 +42,8 @@ class Pipeline(ABC):
             logger.info(f'initiate model from location {model}.')
             # expecting model has been prefetched to local cache beforehand
             return Model.from_pretrained(
-                model, model_prefetched=True) if is_model(model) else model
+                model, model_prefetched=True,
+                device=self.device_name) if is_model(model) else model
         elif isinstance(model, Model):
             return model
         else:
@@ -74,11 +76,15 @@
             config_file(str, optional): Filepath to configuration file.
             model: (list of) Model name or model object
             preprocessor: (list of) Preprocessor object
-            device (str): gpu device or cpu device to use
+            device (str): device str, should be either cpu, cuda, gpu, gpu:X or cuda:X
             auto_collate (bool): automatically to convert data to tensor or not.
         """
         if config_file is not None:
             self.cfg = Config.from_file(config_file)
+
+        verify_device(device)
+        self.device_name = device
+
         if not isinstance(model, List):
             self.model = self.initiate_single_model(model)
             self.models = [self.model]
@@ -94,15 +100,15 @@
         else:
             self.framework = None

-        assert device in ['gpu', 'cpu'], 'device should be either cpu or gpu.'
-        self.device_name = device
         if self.framework == Frameworks.torch:
-            self.device = create_device(self.device_name == 'cpu')
+            self.device = create_device(self.device_name)
         self._model_prepare = False
         self._model_prepare_lock = Lock()
         self._auto_collate = auto_collate

     def prepare_model(self):
+        """ Place the model on the requested device for pytorch models before
+        the first inference.
+        """
         self._model_prepare_lock.acquire(timeout=600)

         def _prepare_single(model):
@@ -125,39 +131,6 @@
         self._model_prepare = True
         self._model_prepare_lock.release()

-    @contextmanager
-    def place_device(self):
-        """ device placement function, allow user to specify which device to place pipeline
-
-        Returns:
-            Context manager
-
-        Examples:
-
-        ```python
-        # Requests for using pipeline on cuda:0 for gpu
-        pipeline = pipeline(..., device='gpu')
-        with pipeline.device():
-            output = pipe(...)
-        ```
-        """
-        if self.framework == Frameworks.tf:
-            if self.device_name == 'cpu':
-                with tf.device('/CPU:0'):
-                    yield
-            else:
-                with tf.device('/device:GPU:0'):
-                    yield
-
-        elif self.framework == Frameworks.torch:
-            if self.device_name == 'gpu':
-                device = create_device()
-                if device.type == 'gpu':
-                    torch.cuda.set_device(device)
-                yield
-            else:
-                yield
-
     def _get_framework(self) -> str:
         frameworks = []
         for m in self.models:
@@ -272,10 +245,11 @@
         postprocess_params = kwargs.get('postprocess_params')

         out = self.preprocess(input, **preprocess_params)
-        with self.place_device():
-            if self.framework == Frameworks.torch and self._auto_collate:
+        with device_placement(self.framework, self.device_name):
+            if self.framework == Frameworks.torch:
                 with torch.no_grad():
-                    out = self._collate_fn(out)
+                    if self._auto_collate:
+                        out = self._collate_fn(out)
                     out = self.forward(out, **forward_params)
             else:
                 out = self.forward(out, **forward_params)
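On the pipeline side, `device` is now validated up front, recorded as `self.device_name`, and forwarded to `Model.from_pretrained`; for torch pipelines a `torch.device` is created from the same string. Usage stays a one-liner; a sketch with an illustrative task and model id:

```python
from modelscope.pipelines import pipeline

# 'cpu', 'gpu', 'cuda', 'gpu:X' and 'cuda:X' are all accepted now,
# instead of only 'cpu' and 'gpu'. Task and model id are illustrative.
pipe = pipeline(
    'portrait-matting', model='damo/cv_unet_image-matting', device='cuda:0')
```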
diff --git a/modelscope/pipelines/cv/image_cartoon_pipeline.py b/modelscope/pipelines/cv/image_cartoon_pipeline.py
index 9c3c418e..eb669354 100644
--- a/modelscope/pipelines/cv/image_cartoon_pipeline.py
+++ b/modelscope/pipelines/cv/image_cartoon_pipeline.py
@@ -16,6 +16,7 @@ from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import LoadImage
 from modelscope.utils.constant import Tasks
 from modelscope.utils.logger import get_logger
+from modelscope.utils.device import device_placement

 if tf.__version__ >= '2.0':
     tf = tf.compat.v1
@@ -36,11 +37,14 @@ class ImageCartoonPipeline(Pipeline):
             model: model id on modelscope hub.
         """
         super().__init__(model=model, **kwargs)
-        self.facer = FaceAna(self.model)
-        self.sess_anime_head = self.load_sess(
-            os.path.join(self.model, 'cartoon_anime_h.pb'), 'model_anime_head')
-        self.sess_anime_bg = self.load_sess(
-            os.path.join(self.model, 'cartoon_anime_bg.pb'), 'model_anime_bg')
+        with device_placement(self.framework, self.device_name):
+            self.facer = FaceAna(self.model)
+            self.sess_anime_head = self.load_sess(
+                os.path.join(self.model, 'cartoon_anime_h.pb'),
+                'model_anime_head')
+            self.sess_anime_bg = self.load_sess(
+                os.path.join(self.model, 'cartoon_anime_bg.pb'),
+                'model_anime_bg')

         self.box_width = 288
         global_mask = cv2.imread(os.path.join(self.model, 'alpha.jpg'))
""" super().__init__(model=model, **kwargs) - self.facer = FaceAna(self.model) - self.sess_anime_head = self.load_sess( - os.path.join(self.model, 'cartoon_anime_h.pb'), 'model_anime_head') - self.sess_anime_bg = self.load_sess( - os.path.join(self.model, 'cartoon_anime_bg.pb'), 'model_anime_bg') + with device_placement(self.framework, self.device_name): + self.facer = FaceAna(self.model) + self.sess_anime_head = self.load_sess( + os.path.join(self.model, 'cartoon_anime_h.pb'), + 'model_anime_head') + self.sess_anime_bg = self.load_sess( + os.path.join(self.model, 'cartoon_anime_bg.pb'), + 'model_anime_bg') self.box_width = 288 global_mask = cv2.imread(os.path.join(self.model, 'alpha.jpg')) diff --git a/modelscope/pipelines/cv/image_matting_pipeline.py b/modelscope/pipelines/cv/image_matting_pipeline.py index d9e81959..d7b7fc3c 100644 --- a/modelscope/pipelines/cv/image_matting_pipeline.py +++ b/modelscope/pipelines/cv/image_matting_pipeline.py @@ -10,6 +10,7 @@ from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import LoadImage from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.device import device_placement from modelscope.utils.logger import get_logger logger = get_logger() @@ -31,19 +32,20 @@ class ImageMattingPipeline(Pipeline): tf = tf.compat.v1 model_path = osp.join(self.model, ModelFile.TF_GRAPH_FILE) - config = tf.ConfigProto(allow_soft_placement=True) - config.gpu_options.allow_growth = True - self._session = tf.Session(config=config) - with self._session.as_default(): - logger.info(f'loading model from {model_path}') - with tf.gfile.FastGFile(model_path, 'rb') as f: - graph_def = tf.GraphDef() - graph_def.ParseFromString(f.read()) - tf.import_graph_def(graph_def, name='') - self.output = self._session.graph.get_tensor_by_name( - 'output_png:0') - self.input_name = 'input_image:0' - logger.info('load model done') + with device_placement(self.framework, self.device_name): + config = tf.ConfigProto(allow_soft_placement=True) + config.gpu_options.allow_growth = True + self._session = tf.Session(config=config) + with self._session.as_default(): + logger.info(f'loading model from {model_path}') + with tf.gfile.FastGFile(model_path, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def, name='') + self.output = self._session.graph.get_tensor_by_name( + 'output_png:0') + self.input_name = 'input_image:0' + logger.info('load model done') def preprocess(self, input: Input) -> Dict[str, Any]: img = LoadImage.convert_to_ndarray(input) diff --git a/modelscope/pipelines/cv/image_style_transfer_pipeline.py b/modelscope/pipelines/cv/image_style_transfer_pipeline.py index a67aaec2..827a0d44 100644 --- a/modelscope/pipelines/cv/image_style_transfer_pipeline.py +++ b/modelscope/pipelines/cv/image_style_transfer_pipeline.py @@ -10,6 +10,7 @@ from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import LoadImage from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.device import device_placement from modelscope.utils.logger import get_logger logger = get_logger() @@ -31,30 +32,31 @@ class ImageStyleTransferPipeline(Pipeline): tf = tf.compat.v1 model_path = osp.join(self.model, ModelFile.TF_GRAPH_FILE) - config = tf.ConfigProto(allow_soft_placement=True) - config.gpu_options.allow_growth = True - self._session = 
diff --git a/modelscope/pipelines/cv/ocr_detection_pipeline.py b/modelscope/pipelines/cv/ocr_detection_pipeline.py
index 32209c1e..b54ad96d 100644
--- a/modelscope/pipelines/cv/ocr_detection_pipeline.py
+++ b/modelscope/pipelines/cv/ocr_detection_pipeline.py
@@ -11,6 +11,7 @@ from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import LoadImage
 from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.device import device_placement
 from modelscope.utils.logger import get_logger
 from .ocr_utils import (SegLinkDetector, cal_width, combine_segments_python,
                         decode_segments_links_python, nms_python,
@@ -51,66 +52,67 @@ class OCRDetectionPipeline(Pipeline):
             osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER),
             'checkpoint-80000')

-        config = tf.ConfigProto(allow_soft_placement=True)
-        config.gpu_options.allow_growth = True
-        self._session = tf.Session(config=config)
-        self.input_images = tf.placeholder(
-            tf.float32, shape=[1, 1024, 1024, 3], name='input_images')
-        self.output = {}
+        with device_placement(self.framework, self.device_name):
+            config = tf.ConfigProto(allow_soft_placement=True)
+            config.gpu_options.allow_growth = True
+            self._session = tf.Session(config=config)
+            self.input_images = tf.placeholder(
+                tf.float32, shape=[1, 1024, 1024, 3], name='input_images')
+            self.output = {}

-        with tf.variable_scope('', reuse=tf.AUTO_REUSE):
-            global_step = tf.get_variable(
-                'global_step', [],
-                initializer=tf.constant_initializer(0),
-                dtype=tf.int64,
-                trainable=False)
-            variable_averages = tf.train.ExponentialMovingAverage(
-                0.997, global_step)
+            with tf.variable_scope('', reuse=tf.AUTO_REUSE):
+                global_step = tf.get_variable(
+                    'global_step', [],
+                    initializer=tf.constant_initializer(0),
+                    dtype=tf.int64,
+                    trainable=False)
+                variable_averages = tf.train.ExponentialMovingAverage(
+                    0.997, global_step)

-            # detector
-            detector = SegLinkDetector()
-            all_maps = detector.build_model(
-                self.input_images, is_training=False)
+                # detector
+                detector = SegLinkDetector()
+                all_maps = detector.build_model(
+                    self.input_images, is_training=False)

-            # decode local predictions
-            all_nodes, all_links, all_reg = [], [], []
-            for i, maps in enumerate(all_maps):
-                cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2]
-                reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE)
+                # decode local predictions
+                all_nodes, all_links, all_reg = [], [], []
+                for i, maps in enumerate(all_maps):
+                    cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2]
+                    reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE)

-                cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))
+                    cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))

-                lnk_prob_pos = tf.nn.softmax(
-                    tf.reshape(lnk_maps, [-1, 4])[:, :2])
-                lnk_prob_mut = tf.nn.softmax(
-                    tf.reshape(lnk_maps, [-1, 4])[:, 2:])
-                lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1)
+                    lnk_prob_pos = tf.nn.softmax(
+                        tf.reshape(lnk_maps, [-1, 4])[:, :2])
+                    lnk_prob_mut = tf.nn.softmax(
+                        tf.reshape(lnk_maps, [-1, 4])[:, 2:])
+                    lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1)

-                all_nodes.append(cls_prob)
-                all_links.append(lnk_prob)
-                all_reg.append(reg_maps)
+                    all_nodes.append(cls_prob)
+                    all_links.append(lnk_prob)
+                    all_reg.append(reg_maps)

-            # decode segments and links
-            image_size = tf.shape(self.input_images)[1:3]
-            segments, group_indices, segment_counts, _ = decode_segments_links_python(
-                image_size,
-                all_nodes,
-                all_links,
-                all_reg,
-                anchor_sizes=list(detector.anchor_sizes))
+                # decode segments and links
+                image_size = tf.shape(self.input_images)[1:3]
+                segments, group_indices, segment_counts, _ = decode_segments_links_python(
+                    image_size,
+                    all_nodes,
+                    all_links,
+                    all_reg,
+                    anchor_sizes=list(detector.anchor_sizes))

-            # combine segments
-            combined_rboxes, combined_counts = combine_segments_python(
-                segments, group_indices, segment_counts)
-            self.output['combined_rboxes'] = combined_rboxes
-            self.output['combined_counts'] = combined_counts
+                # combine segments
+                combined_rboxes, combined_counts = combine_segments_python(
+                    segments, group_indices, segment_counts)
+                self.output['combined_rboxes'] = combined_rboxes
+                self.output['combined_counts'] = combined_counts

-        with self._session.as_default() as sess:
-            logger.info(f'loading model from {model_path}')
-            # load model
-            model_loader = tf.train.Saver(
-                variable_averages.variables_to_restore())
-            model_loader.restore(sess, model_path)
+            with self._session.as_default() as sess:
+                logger.info(f'loading model from {model_path}')
+                # load model
+                model_loader = tf.train.Saver(
+                    variable_averages.variables_to_restore())
+                model_loader.restore(sess, model_path)

     def preprocess(self, input: Input) -> Dict[str, Any]:
         img = LoadImage.convert_to_ndarray(input)
diff --git a/modelscope/pipelines/cv/skin_retouching_pipeline.py b/modelscope/pipelines/cv/skin_retouching_pipeline.py
index d9b49ff3..f8c9de60 100644
--- a/modelscope/pipelines/cv/skin_retouching_pipeline.py
+++ b/modelscope/pipelines/cv/skin_retouching_pipeline.py
@@ -23,6 +23,7 @@ from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import LoadImage
 from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.device import create_device, device_placement
 from modelscope.utils.logger import get_logger

 if tf.__version__ >= '2.0':
@@ -42,12 +43,9 @@ class SkinRetouchingPipeline(Pipeline):
         Args:
             model: model id on modelscope hub.
         """
-        super().__init__(model=model)
+        super().__init__(model=model, device=device)

-        if torch.cuda.is_available() and device == 'gpu':
-            device = 'cuda'
-        else:
-            device = 'cpu'
+        device = create_device(self.device_name)
         model_path = os.path.join(self.model, ModelFile.TORCH_MODEL_FILE)
         detector_model_path = os.path.join(
             self.model, 'retinaface_resnet50_2020-07-20_old_torch.pth')
@@ -81,16 +79,17 @@
         self.skin_model_path = skin_model_path
         if self.skin_model_path is not None:
-            config = tf.ConfigProto(allow_soft_placement=True)
-            config.gpu_options.per_process_gpu_memory_fraction = 0.3
-            config.gpu_options.allow_growth = True
-            self.sess = tf.Session(config=config)
-            with tf.gfile.FastGFile(self.skin_model_path, 'rb') as f:
-                graph_def = tf.GraphDef()
-                graph_def.ParseFromString(f.read())
-                self.sess.graph.as_default()
-                tf.import_graph_def(graph_def, name='')
-                self.sess.run(tf.global_variables_initializer())
+            with device_placement(self.framework, self.device_name):
+                config = tf.ConfigProto(allow_soft_placement=True)
+                config.gpu_options.per_process_gpu_memory_fraction = 0.3
+                config.gpu_options.allow_growth = True
+                self.sess = tf.Session(config=config)
+                with tf.gfile.FastGFile(self.skin_model_path, 'rb') as f:
+                    graph_def = tf.GraphDef()
+                    graph_def.ParseFromString(f.read())
+                    self.sess.graph.as_default()
+                    tf.import_graph_def(graph_def, name='')
+                    self.sess.run(tf.global_variables_initializer())

         self.image_files_transforms = transforms.Compose([
             transforms.ToTensor(),
diff --git a/modelscope/pipelines/multi_modal/video_multi_modal_embedding_pipeline.py b/modelscope/pipelines/multi_modal/video_multi_modal_embedding_pipeline.py
index 166d3f06..bc697b05 100644
--- a/modelscope/pipelines/multi_modal/video_multi_modal_embedding_pipeline.py
+++ b/modelscope/pipelines/multi_modal/video_multi_modal_embedding_pipeline.py
@@ -4,6 +4,7 @@ from modelscope.metainfo import Pipelines
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.utils.constant import Tasks
+from modelscope.utils.device import device_placement
 from modelscope.utils.logger import get_logger

 logger = get_logger()
@@ -26,7 +27,7 @@ class VideoMultiModalEmbeddingPipeline(Pipeline):
         return input

     def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:
-        with self.place_device():
+        with device_placement(self.framework, self.device_name):
             out = self.forward(input)

         self._check_output(out)
diff --git a/modelscope/pipelines/nlp/translation_pipeline.py b/modelscope/pipelines/nlp/translation_pipeline.py
index 909e3c6c..b9b74ce4 100644
--- a/modelscope/pipelines/nlp/translation_pipeline.py
+++ b/modelscope/pipelines/nlp/translation_pipeline.py
@@ -31,7 +31,7 @@ class TranslationPipeline(Pipeline):

         @param model: A Model instance.
         """
-        super().__init__(model=model)
+        super().__init__(model=model, **kwargs)
         model = self.model.model_dir
         tf.reset_default_graph()
""" - super().__init__(model=model) + super().__init__(model=model, **kwargs) model = self.model.model_dir tf.reset_default_graph() diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index 0916495c..c48ab2cd 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -36,11 +36,11 @@ from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields, ConfigKeys, Hubs, ModeKeys, ModelFile, Tasks, TrainerStages) from modelscope.utils.data_utils import to_device +from modelscope.utils.device import create_device, verify_device from modelscope.utils.file_utils import func_receive_dict_inputs from modelscope.utils.logger import get_logger from modelscope.utils.registry import build_from_cfg -from modelscope.utils.torch_utils import (create_device, get_dist_info, - init_dist) +from modelscope.utils.torch_utils import get_dist_info, init_dist from .base import BaseTrainer from .builder import TRAINERS from .default_config import DEFAULT_CONFIG @@ -150,9 +150,8 @@ class EpochBasedTrainer(BaseTrainer): self.eval_preprocessor.mode = ModeKeys.EVAL device_name = kwargs.get('device', 'gpu') - assert device_name in ['gpu', - 'cpu'], 'device should be either cpu or gpu.' - self.device = create_device(device_name == 'cpu') + verify_device(device_name) + self.device = create_device(device_name) self.train_dataset = self.to_task_dataset( train_dataset, diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 1a3fb7c3..993a3e42 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -290,3 +290,9 @@ class ColorCodes: GREEN = '\033[92m' RED = '\033[91m' END = '\033[0m' + + +class Devices: + """device used for training and inference""" + cpu = 'cpu' + gpu = 'gpu' diff --git a/modelscope/utils/device.py b/modelscope/utils/device.py new file mode 100644 index 00000000..aa8fda66 --- /dev/null +++ b/modelscope/utils/device.py @@ -0,0 +1,110 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from contextlib import contextmanager + +from modelscope.utils.constant import Devices, Frameworks +from modelscope.utils.import_utils import is_tf_available, is_torch_available +from modelscope.utils.logger import get_logger + +logger = get_logger() + +if is_tf_available(): + import tensorflow as tf + +if is_torch_available(): + import torch + + +def verify_device(device_name): + """ Verify device is valid, device should be either cpu, cuda, gpu, cuda:X or gpu:X. + + Args: + device (str): device str, should be either cpu, cuda, gpu, gpu:X or cuda:X + where X is the ordinal for gpu device. + + Return: + device info (tuple): device_type and device_id, if device_id is not set, will use 0 as default. + """ + device_name = device_name.lower() + eles = device_name.split(':') + err_msg = 'device should be either cpu, cuda, gpu, gpu:X or cuda:X where X is the ordinal for gpu device.' + assert len(eles) <= 2, err_msg + assert eles[0] in ['cpu', 'cuda', 'gpu'], err_msg + device_type = eles[0] + device_id = None + if len(eles) > 1: + device_id = int(eles[1]) + if device_type == 'cuda': + device_type = Devices.gpu + if device_type == Devices.gpu and device_id is None: + device_id = 0 + return device_type, device_id + + +@contextmanager +def device_placement(framework, device_name='gpu:0'): + """ Device placement function, allow user to specify which device to place model or tensor + Args: + framework (str): tensorflow or pytorch. 
diff --git a/modelscope/utils/torch_utils.py b/modelscope/utils/torch_utils.py
index 1f157f9a..45e33c3e 100644
--- a/modelscope/utils/torch_utils.py
+++ b/modelscope/utils/torch_utils.py
@@ -132,17 +132,6 @@ def master_only(func: Callable) -> Callable:
     return wrapper


-def create_device(cpu: bool = False) -> torch.DeviceObjType:
-    use_cuda = torch.cuda.is_available() and not cpu
-    if use_cuda:
-        local_rank = os.environ.get('LOCAL_RANK', 0)
-        device = torch.device(f'cuda:{local_rank}')
-    else:
-        device = torch.device('cpu')
-
-    return device
-
-
 def make_tmp_dir():
     """Make sure each rank has the same temporary directory on the distributed
     mode.
     """
diff --git a/tests/pipelines/test_key_word_spotting_farfield.py b/tests/pipelines/test_key_word_spotting_farfield.py
index e7967edc..0b64831a 100644
--- a/tests/pipelines/test_key_word_spotting_farfield.py
+++ b/tests/pipelines/test_key_word_spotting_farfield.py
@@ -41,3 +41,7 @@ class KWSFarfieldTest(unittest.TestCase):
         result = kws(data)
         self.assertEqual(len(result['kws_list']), 5)
         print(result['kws_list'][-1])
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/utils/test_device.py b/tests/utils/test_device.py
new file mode 100644
index 00000000..3135b214
--- /dev/null
+++ b/tests/utils/test_device.py
@@ -0,0 +1,101 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+import shutil
+import tempfile
+import time
+import unittest
+
+import torch
+
+from modelscope.utils.constant import Frameworks
+from modelscope.utils.device import (create_device, device_placement,
+                                     verify_device)
+
+# tensorflow must be imported after torch when using tf1.15
+import tensorflow as tf  # isort:skip
+
+
+class DeviceTest(unittest.TestCase):
+
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+
+    def tearDown(self):
+        super().tearDown()
+
+    def test_verify(self):
+        device_name, device_id = verify_device('cpu')
+        self.assertEqual(device_name, 'cpu')
+        self.assertTrue(device_id is None)
+        device_name, device_id = verify_device('CPU')
+        self.assertEqual(device_name, 'cpu')
+
+        device_name, device_id = verify_device('gpu')
+        self.assertEqual(device_name, 'gpu')
+        self.assertTrue(device_id == 0)
+
+        device_name, device_id = verify_device('cuda')
+        self.assertEqual(device_name, 'gpu')
+        self.assertTrue(device_id == 0)
+
+        device_name, device_id = verify_device('cuda:0')
+        self.assertEqual(device_name, 'gpu')
+        self.assertTrue(device_id == 0)
+
+        device_name, device_id = verify_device('gpu:1')
+        self.assertEqual(device_name, 'gpu')
+        self.assertTrue(device_id == 1)
+
+        with self.assertRaises(AssertionError):
+            verify_device('xgu')
+
+    def test_create_device_torch(self):
+        if torch.cuda.is_available():
+            target_device_type = 'cuda'
+            target_device_index = 0
+        else:
+            target_device_type = 'cpu'
+            target_device_index = None
+        device = create_device('gpu')
+        self.assertTrue(isinstance(device, torch.device))
+        self.assertTrue(device.type == target_device_type)
+        self.assertTrue(device.index == target_device_index)
+
+        device = create_device('gpu:0')
+        self.assertTrue(isinstance(device, torch.device))
+        self.assertTrue(device.type == target_device_type)
+        self.assertTrue(device.index == target_device_index)
+
+        device = create_device('cuda')
+        self.assertTrue(device.type == target_device_type)
+        self.assertTrue(isinstance(device, torch.device))
+        self.assertTrue(device.index == target_device_index)
+
+        device = create_device('cuda:0')
+        self.assertTrue(isinstance(device, torch.device))
+        self.assertTrue(device.type == target_device_type)
+        self.assertTrue(device.index == target_device_index)
+
+    def test_device_placement_cpu(self):
+        with device_placement(Frameworks.torch, 'cpu'):
+            pass
+
+    def test_device_placement_tf_gpu(self):
+        tf.debugging.set_log_device_placement(True)
+        with device_placement(Frameworks.tf, 'gpu:0'):
+            a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+            b = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
+            c = tf.matmul(a, b)
+            s = tf.Session()
+            s.run(c)
+        tf.debugging.set_log_device_placement(False)
+
+    def test_device_placement_torch_gpu(self):
+        with device_placement(Frameworks.torch, 'gpu:0'):
+            if torch.cuda.is_available():
+                self.assertEqual(torch.cuda.current_device(), 0)
+
+
+if __name__ == '__main__':
+    unittest.main()