mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
[to #43878347] device placement support certain gpu
1. add device util to verify, create and place device
2. pipeline and trainer support update
3. fix pipeline which use tf models does not place model to the right device
usage
```python
pipe = pipeline('damo/xxx', device='cpu')
pipe = pipeline('damo/xxx', device='gpu')
pipe = pipeline('damo/xxx', device='gpu:0')
pipe = pipeline('damo/xxx', device='gpu:2')
pipe = pipeline('damo/xxx', device='cuda')
pipe = pipeline('damo/xxx', device='cuda:1')
```
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9800672
This commit is contained in:
@@ -71,6 +71,7 @@ class FRCRNModel(TorchModel):
|
||||
model_dir (str): the model path.
|
||||
"""
|
||||
super().__init__(model_dir, *args, **kwargs)
|
||||
kwargs.pop('device')
|
||||
self.model = FRCRN(*args, **kwargs)
|
||||
model_bin_file = os.path.join(model_dir,
|
||||
ModelFile.TORCH_MODEL_BIN_FILE)
|
||||
|
||||
@@ -33,6 +33,7 @@ class FSMNSeleNetV2Decorator(TorchModel):
|
||||
ModelFile.TORCH_MODEL_BIN_FILE)
|
||||
self._model = None
|
||||
if os.path.exists(model_bin_file):
|
||||
kwargs.pop('device')
|
||||
self._model = FSMNSeleNetV2(*args, **kwargs)
|
||||
checkpoint = torch.load(model_bin_file)
|
||||
self._model.load_state_dict(checkpoint, strict=False)
|
||||
|
||||
@@ -10,6 +10,7 @@ from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.models.builder import build_model
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
|
||||
from modelscope.utils.device import device_placement, verify_device
|
||||
from modelscope.utils.file_utils import func_receive_dict_inputs
|
||||
from modelscope.utils.hub import parse_label_mapping
|
||||
from modelscope.utils.logger import get_logger
|
||||
@@ -24,8 +25,7 @@ class Model(ABC):
|
||||
def __init__(self, model_dir, *args, **kwargs):
|
||||
self.model_dir = model_dir
|
||||
device_name = kwargs.get('device', 'gpu')
|
||||
assert device_name in ['gpu',
|
||||
'cpu'], 'device should be either cpu or gpu.'
|
||||
verify_device(device_name)
|
||||
self._device_name = device_name
|
||||
|
||||
def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
@@ -72,6 +72,7 @@ class Model(ABC):
|
||||
model_name_or_path: str,
|
||||
revision: Optional[str] = DEFAULT_MODEL_REVISION,
|
||||
cfg_dict: Config = None,
|
||||
device: str = None,
|
||||
*model_args,
|
||||
**kwargs):
|
||||
""" Instantiate a model from local directory or remote model repo. Note
|
||||
@@ -97,7 +98,7 @@ class Model(ABC):
|
||||
osp.join(local_model_dir, ModelFile.CONFIGURATION))
|
||||
task_name = cfg.task
|
||||
model_cfg = cfg.model
|
||||
# TODO @wenmeng.zwm may should manually initialize model after model building
|
||||
framework = cfg.framework
|
||||
|
||||
if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
|
||||
model_cfg.type = model_cfg.model_type
|
||||
@@ -105,8 +106,14 @@ class Model(ABC):
|
||||
model_cfg.model_dir = local_model_dir
|
||||
for k, v in kwargs.items():
|
||||
model_cfg[k] = v
|
||||
model = build_model(
|
||||
model_cfg, task_name=task_name, default_args=kwargs)
|
||||
if device is not None:
|
||||
model_cfg.device = device
|
||||
with device_placement(framework, device):
|
||||
model = build_model(
|
||||
model_cfg, task_name=task_name, default_args=kwargs)
|
||||
else:
|
||||
model = build_model(
|
||||
model_cfg, task_name=task_name, default_args=kwargs)
|
||||
|
||||
# dynamically add pipeline info to model for pipeline inference
|
||||
if hasattr(cfg, 'pipeline'):
|
||||
|
||||
@@ -13,8 +13,8 @@ from modelscope.utils.constant import Tasks
|
||||
Tasks.crowd_counting, module_name=Models.crowd_counting)
|
||||
class HRNetCrowdCounting(TorchModel):
|
||||
|
||||
def __init__(self, model_dir: str):
|
||||
super().__init__(model_dir)
|
||||
def __init__(self, model_dir: str, **kwargs):
|
||||
super().__init__(model_dir, **kwargs)
|
||||
|
||||
from .hrnet_aspp_relu import HighResolutionNet as HRNet_aspp_relu
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from modelscope.utils.constant import Tasks
|
||||
Tasks.image_classification, module_name=Models.classification_model)
|
||||
class ClassificationModel(TorchModel):
|
||||
|
||||
def __init__(self, model_dir: str):
|
||||
def __init__(self, model_dir: str, **kwargs):
|
||||
import mmcv
|
||||
from mmcls.models import build_classifier
|
||||
|
||||
|
||||
@@ -13,8 +13,8 @@ from modelscope.models.cv.product_retrieval_embedding.item_embedding import (
|
||||
preprocess, resnet50_embed)
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.device import create_device
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.torch_utils import create_device
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
@@ -48,9 +48,8 @@ class ProductRetrievalEmbedding(TorchModel):
|
||||
filter_param(src_params, own_state)
|
||||
model.load_state_dict(own_state)
|
||||
|
||||
cpu_flag = device == 'cpu'
|
||||
self.device = create_device(
|
||||
cpu_flag) # device.type == "cpu" or device.type == "cuda"
|
||||
device) # device.type == "cpu" or device.type == "cuda"
|
||||
self.use_gpu = self.device.type == 'cuda'
|
||||
|
||||
# config the model path
|
||||
|
||||
@@ -24,8 +24,8 @@ logger = get_logger()
|
||||
Tasks.video_multi_modal_embedding, module_name=Models.video_clip)
|
||||
class VideoCLIPForMultiModalEmbedding(TorchModel):
|
||||
|
||||
def __init__(self, model_dir, device_id=-1):
|
||||
super().__init__(model_dir=model_dir, device_id=device_id)
|
||||
def __init__(self, model_dir, **kwargs):
|
||||
super().__init__(model_dir=model_dir, **kwargs)
|
||||
# model config parameters
|
||||
with open(f'{model_dir}/{ModelFile.CONFIGURATION}', 'r') as json_file:
|
||||
model_config = json.load(json_file)
|
||||
|
||||
@@ -11,7 +11,6 @@ from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.torch_utils import create_device
|
||||
|
||||
|
||||
def audio_norm(x):
|
||||
|
||||
@@ -14,9 +14,10 @@ from modelscope.outputs import TASK_OUTPUTS
|
||||
from modelscope.preprocessors import Preprocessor
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import Frameworks, ModelFile
|
||||
from modelscope.utils.device import (create_device, device_placement,
|
||||
verify_device)
|
||||
from modelscope.utils.import_utils import is_tf_available, is_torch_available
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.torch_utils import create_device
|
||||
from .util import is_model, is_official_hub_path
|
||||
|
||||
if is_torch_available():
|
||||
@@ -41,7 +42,8 @@ class Pipeline(ABC):
|
||||
logger.info(f'initiate model from location {model}.')
|
||||
# expecting model has been prefetched to local cache beforehand
|
||||
return Model.from_pretrained(
|
||||
model, model_prefetched=True) if is_model(model) else model
|
||||
model, model_prefetched=True,
|
||||
device=self.device_name) if is_model(model) else model
|
||||
elif isinstance(model, Model):
|
||||
return model
|
||||
else:
|
||||
@@ -74,11 +76,15 @@ class Pipeline(ABC):
|
||||
config_file(str, optional): Filepath to configuration file.
|
||||
model: (list of) Model name or model object
|
||||
preprocessor: (list of) Preprocessor object
|
||||
device (str): gpu device or cpu device to use
|
||||
device (str): device str, should be either cpu, cuda, gpu, gpu:X or cuda:X
|
||||
auto_collate (bool): automatically to convert data to tensor or not.
|
||||
"""
|
||||
if config_file is not None:
|
||||
self.cfg = Config.from_file(config_file)
|
||||
|
||||
verify_device(device)
|
||||
self.device_name = device
|
||||
|
||||
if not isinstance(model, List):
|
||||
self.model = self.initiate_single_model(model)
|
||||
self.models = [self.model]
|
||||
@@ -94,15 +100,15 @@ class Pipeline(ABC):
|
||||
else:
|
||||
self.framework = None
|
||||
|
||||
assert device in ['gpu', 'cpu'], 'device should be either cpu or gpu.'
|
||||
self.device_name = device
|
||||
if self.framework == Frameworks.torch:
|
||||
self.device = create_device(self.device_name == 'cpu')
|
||||
self.device = create_device(self.device_name)
|
||||
self._model_prepare = False
|
||||
self._model_prepare_lock = Lock()
|
||||
self._auto_collate = auto_collate
|
||||
|
||||
def prepare_model(self):
|
||||
""" Place model on certain device for pytorch models before first inference
|
||||
"""
|
||||
self._model_prepare_lock.acquire(timeout=600)
|
||||
|
||||
def _prepare_single(model):
|
||||
@@ -125,39 +131,6 @@ class Pipeline(ABC):
|
||||
self._model_prepare = True
|
||||
self._model_prepare_lock.release()
|
||||
|
||||
@contextmanager
|
||||
def place_device(self):
|
||||
""" device placement function, allow user to specify which device to place pipeline
|
||||
|
||||
Returns:
|
||||
Context manager
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
# Requests for using pipeline on cuda:0 for gpu
|
||||
pipeline = pipeline(..., device='gpu')
|
||||
with pipeline.device():
|
||||
output = pipe(...)
|
||||
```
|
||||
"""
|
||||
if self.framework == Frameworks.tf:
|
||||
if self.device_name == 'cpu':
|
||||
with tf.device('/CPU:0'):
|
||||
yield
|
||||
else:
|
||||
with tf.device('/device:GPU:0'):
|
||||
yield
|
||||
|
||||
elif self.framework == Frameworks.torch:
|
||||
if self.device_name == 'gpu':
|
||||
device = create_device()
|
||||
if device.type == 'gpu':
|
||||
torch.cuda.set_device(device)
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
def _get_framework(self) -> str:
|
||||
frameworks = []
|
||||
for m in self.models:
|
||||
@@ -272,10 +245,11 @@ class Pipeline(ABC):
|
||||
postprocess_params = kwargs.get('postprocess_params')
|
||||
|
||||
out = self.preprocess(input, **preprocess_params)
|
||||
with self.place_device():
|
||||
if self.framework == Frameworks.torch and self._auto_collate:
|
||||
with device_placement(self.framework, self.device_name):
|
||||
if self.framework == Frameworks.torch:
|
||||
with torch.no_grad():
|
||||
out = self._collate_fn(out)
|
||||
if self._auto_collate:
|
||||
out = self._collate_fn(out)
|
||||
out = self.forward(out, **forward_params)
|
||||
else:
|
||||
out = self.forward(out, **forward_params)
|
||||
|
||||
@@ -16,6 +16,7 @@ from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from ...utils.device import device_placement
|
||||
|
||||
if tf.__version__ >= '2.0':
|
||||
tf = tf.compat.v1
|
||||
@@ -36,11 +37,14 @@ class ImageCartoonPipeline(Pipeline):
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
self.facer = FaceAna(self.model)
|
||||
self.sess_anime_head = self.load_sess(
|
||||
os.path.join(self.model, 'cartoon_anime_h.pb'), 'model_anime_head')
|
||||
self.sess_anime_bg = self.load_sess(
|
||||
os.path.join(self.model, 'cartoon_anime_bg.pb'), 'model_anime_bg')
|
||||
with device_placement(self.framework, self.device_name):
|
||||
self.facer = FaceAna(self.model)
|
||||
self.sess_anime_head = self.load_sess(
|
||||
os.path.join(self.model, 'cartoon_anime_h.pb'),
|
||||
'model_anime_head')
|
||||
self.sess_anime_bg = self.load_sess(
|
||||
os.path.join(self.model, 'cartoon_anime_bg.pb'),
|
||||
'model_anime_bg')
|
||||
|
||||
self.box_width = 288
|
||||
global_mask = cv2.imread(os.path.join(self.model, 'alpha.jpg'))
|
||||
|
||||
@@ -10,6 +10,7 @@ from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.device import device_placement
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
@@ -31,19 +32,20 @@ class ImageMattingPipeline(Pipeline):
|
||||
tf = tf.compat.v1
|
||||
model_path = osp.join(self.model, ModelFile.TF_GRAPH_FILE)
|
||||
|
||||
config = tf.ConfigProto(allow_soft_placement=True)
|
||||
config.gpu_options.allow_growth = True
|
||||
self._session = tf.Session(config=config)
|
||||
with self._session.as_default():
|
||||
logger.info(f'loading model from {model_path}')
|
||||
with tf.gfile.FastGFile(model_path, 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
tf.import_graph_def(graph_def, name='')
|
||||
self.output = self._session.graph.get_tensor_by_name(
|
||||
'output_png:0')
|
||||
self.input_name = 'input_image:0'
|
||||
logger.info('load model done')
|
||||
with device_placement(self.framework, self.device_name):
|
||||
config = tf.ConfigProto(allow_soft_placement=True)
|
||||
config.gpu_options.allow_growth = True
|
||||
self._session = tf.Session(config=config)
|
||||
with self._session.as_default():
|
||||
logger.info(f'loading model from {model_path}')
|
||||
with tf.gfile.FastGFile(model_path, 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
tf.import_graph_def(graph_def, name='')
|
||||
self.output = self._session.graph.get_tensor_by_name(
|
||||
'output_png:0')
|
||||
self.input_name = 'input_image:0'
|
||||
logger.info('load model done')
|
||||
|
||||
def preprocess(self, input: Input) -> Dict[str, Any]:
|
||||
img = LoadImage.convert_to_ndarray(input)
|
||||
|
||||
@@ -10,6 +10,7 @@ from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.device import device_placement
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
@@ -31,30 +32,31 @@ class ImageStyleTransferPipeline(Pipeline):
|
||||
tf = tf.compat.v1
|
||||
model_path = osp.join(self.model, ModelFile.TF_GRAPH_FILE)
|
||||
|
||||
config = tf.ConfigProto(allow_soft_placement=True)
|
||||
config.gpu_options.allow_growth = True
|
||||
self._session = tf.Session(config=config)
|
||||
self.max_length = 800
|
||||
with self._session.as_default():
|
||||
logger.info(f'loading model from {model_path}')
|
||||
with tf.gfile.FastGFile(model_path, 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
tf.import_graph_def(graph_def, name='')
|
||||
with device_placement(self.framework, self.device_name):
|
||||
config = tf.ConfigProto(allow_soft_placement=True)
|
||||
config.gpu_options.allow_growth = True
|
||||
self._session = tf.Session(config=config)
|
||||
self.max_length = 800
|
||||
with self._session.as_default():
|
||||
logger.info(f'loading model from {model_path}')
|
||||
with tf.gfile.FastGFile(model_path, 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
tf.import_graph_def(graph_def, name='')
|
||||
|
||||
self.content = tf.get_default_graph().get_tensor_by_name(
|
||||
'content:0')
|
||||
self.style = tf.get_default_graph().get_tensor_by_name(
|
||||
'style:0')
|
||||
self.output = tf.get_default_graph().get_tensor_by_name(
|
||||
'stylized_output:0')
|
||||
self.attention = tf.get_default_graph().get_tensor_by_name(
|
||||
'attention_map:0')
|
||||
self.inter_weight = tf.get_default_graph().get_tensor_by_name(
|
||||
'inter_weight:0')
|
||||
self.centroids = tf.get_default_graph().get_tensor_by_name(
|
||||
'centroids:0')
|
||||
logger.info('load model done')
|
||||
self.content = tf.get_default_graph().get_tensor_by_name(
|
||||
'content:0')
|
||||
self.style = tf.get_default_graph().get_tensor_by_name(
|
||||
'style:0')
|
||||
self.output = tf.get_default_graph().get_tensor_by_name(
|
||||
'stylized_output:0')
|
||||
self.attention = tf.get_default_graph().get_tensor_by_name(
|
||||
'attention_map:0')
|
||||
self.inter_weight = tf.get_default_graph(
|
||||
).get_tensor_by_name('inter_weight:0')
|
||||
self.centroids = tf.get_default_graph().get_tensor_by_name(
|
||||
'centroids:0')
|
||||
logger.info('load model done')
|
||||
|
||||
def _sanitize_parameters(self, **pipeline_parameters):
|
||||
return pipeline_parameters, {}, {}
|
||||
|
||||
@@ -11,6 +11,7 @@ from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.device import device_placement
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .ocr_utils import (SegLinkDetector, cal_width, combine_segments_python,
|
||||
decode_segments_links_python, nms_python,
|
||||
@@ -51,66 +52,67 @@ class OCRDetectionPipeline(Pipeline):
|
||||
osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER),
|
||||
'checkpoint-80000')
|
||||
|
||||
config = tf.ConfigProto(allow_soft_placement=True)
|
||||
config.gpu_options.allow_growth = True
|
||||
self._session = tf.Session(config=config)
|
||||
self.input_images = tf.placeholder(
|
||||
tf.float32, shape=[1, 1024, 1024, 3], name='input_images')
|
||||
self.output = {}
|
||||
with device_placement(self.framework, self.device_name):
|
||||
config = tf.ConfigProto(allow_soft_placement=True)
|
||||
config.gpu_options.allow_growth = True
|
||||
self._session = tf.Session(config=config)
|
||||
self.input_images = tf.placeholder(
|
||||
tf.float32, shape=[1, 1024, 1024, 3], name='input_images')
|
||||
self.output = {}
|
||||
|
||||
with tf.variable_scope('', reuse=tf.AUTO_REUSE):
|
||||
global_step = tf.get_variable(
|
||||
'global_step', [],
|
||||
initializer=tf.constant_initializer(0),
|
||||
dtype=tf.int64,
|
||||
trainable=False)
|
||||
variable_averages = tf.train.ExponentialMovingAverage(
|
||||
0.997, global_step)
|
||||
with tf.variable_scope('', reuse=tf.AUTO_REUSE):
|
||||
global_step = tf.get_variable(
|
||||
'global_step', [],
|
||||
initializer=tf.constant_initializer(0),
|
||||
dtype=tf.int64,
|
||||
trainable=False)
|
||||
variable_averages = tf.train.ExponentialMovingAverage(
|
||||
0.997, global_step)
|
||||
|
||||
# detector
|
||||
detector = SegLinkDetector()
|
||||
all_maps = detector.build_model(
|
||||
self.input_images, is_training=False)
|
||||
# detector
|
||||
detector = SegLinkDetector()
|
||||
all_maps = detector.build_model(
|
||||
self.input_images, is_training=False)
|
||||
|
||||
# decode local predictions
|
||||
all_nodes, all_links, all_reg = [], [], []
|
||||
for i, maps in enumerate(all_maps):
|
||||
cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2]
|
||||
reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE)
|
||||
# decode local predictions
|
||||
all_nodes, all_links, all_reg = [], [], []
|
||||
for i, maps in enumerate(all_maps):
|
||||
cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2]
|
||||
reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE)
|
||||
|
||||
cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))
|
||||
cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))
|
||||
|
||||
lnk_prob_pos = tf.nn.softmax(
|
||||
tf.reshape(lnk_maps, [-1, 4])[:, :2])
|
||||
lnk_prob_mut = tf.nn.softmax(
|
||||
tf.reshape(lnk_maps, [-1, 4])[:, 2:])
|
||||
lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1)
|
||||
lnk_prob_pos = tf.nn.softmax(
|
||||
tf.reshape(lnk_maps, [-1, 4])[:, :2])
|
||||
lnk_prob_mut = tf.nn.softmax(
|
||||
tf.reshape(lnk_maps, [-1, 4])[:, 2:])
|
||||
lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1)
|
||||
|
||||
all_nodes.append(cls_prob)
|
||||
all_links.append(lnk_prob)
|
||||
all_reg.append(reg_maps)
|
||||
all_nodes.append(cls_prob)
|
||||
all_links.append(lnk_prob)
|
||||
all_reg.append(reg_maps)
|
||||
|
||||
# decode segments and links
|
||||
image_size = tf.shape(self.input_images)[1:3]
|
||||
segments, group_indices, segment_counts, _ = decode_segments_links_python(
|
||||
image_size,
|
||||
all_nodes,
|
||||
all_links,
|
||||
all_reg,
|
||||
anchor_sizes=list(detector.anchor_sizes))
|
||||
# decode segments and links
|
||||
image_size = tf.shape(self.input_images)[1:3]
|
||||
segments, group_indices, segment_counts, _ = decode_segments_links_python(
|
||||
image_size,
|
||||
all_nodes,
|
||||
all_links,
|
||||
all_reg,
|
||||
anchor_sizes=list(detector.anchor_sizes))
|
||||
|
||||
# combine segments
|
||||
combined_rboxes, combined_counts = combine_segments_python(
|
||||
segments, group_indices, segment_counts)
|
||||
self.output['combined_rboxes'] = combined_rboxes
|
||||
self.output['combined_counts'] = combined_counts
|
||||
# combine segments
|
||||
combined_rboxes, combined_counts = combine_segments_python(
|
||||
segments, group_indices, segment_counts)
|
||||
self.output['combined_rboxes'] = combined_rboxes
|
||||
self.output['combined_counts'] = combined_counts
|
||||
|
||||
with self._session.as_default() as sess:
|
||||
logger.info(f'loading model from {model_path}')
|
||||
# load model
|
||||
model_loader = tf.train.Saver(
|
||||
variable_averages.variables_to_restore())
|
||||
model_loader.restore(sess, model_path)
|
||||
with self._session.as_default() as sess:
|
||||
logger.info(f'loading model from {model_path}')
|
||||
# load model
|
||||
model_loader = tf.train.Saver(
|
||||
variable_averages.variables_to_restore())
|
||||
model_loader.restore(sess, model_path)
|
||||
|
||||
def preprocess(self, input: Input) -> Dict[str, Any]:
|
||||
img = LoadImage.convert_to_ndarray(input)
|
||||
|
||||
@@ -23,6 +23,7 @@ from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.device import create_device, device_placement
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
if tf.__version__ >= '2.0':
|
||||
@@ -42,12 +43,9 @@ class SkinRetouchingPipeline(Pipeline):
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
super().__init__(model=model)
|
||||
super().__init__(model=model, device=device)
|
||||
|
||||
if torch.cuda.is_available() and device == 'gpu':
|
||||
device = 'cuda'
|
||||
else:
|
||||
device = 'cpu'
|
||||
device = create_device(self.device_name)
|
||||
model_path = os.path.join(self.model, ModelFile.TORCH_MODEL_FILE)
|
||||
detector_model_path = os.path.join(
|
||||
self.model, 'retinaface_resnet50_2020-07-20_old_torch.pth')
|
||||
@@ -81,16 +79,17 @@ class SkinRetouchingPipeline(Pipeline):
|
||||
|
||||
self.skin_model_path = skin_model_path
|
||||
if self.skin_model_path is not None:
|
||||
config = tf.ConfigProto(allow_soft_placement=True)
|
||||
config.gpu_options.per_process_gpu_memory_fraction = 0.3
|
||||
config.gpu_options.allow_growth = True
|
||||
self.sess = tf.Session(config=config)
|
||||
with tf.gfile.FastGFile(self.skin_model_path, 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
self.sess.graph.as_default()
|
||||
tf.import_graph_def(graph_def, name='')
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
with device_placement(self.framework, self.device_name):
|
||||
config = tf.ConfigProto(allow_soft_placement=True)
|
||||
config.gpu_options.per_process_gpu_memory_fraction = 0.3
|
||||
config.gpu_options.allow_growth = True
|
||||
self.sess = tf.Session(config=config)
|
||||
with tf.gfile.FastGFile(self.skin_model_path, 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
self.sess.graph.as_default()
|
||||
tf.import_graph_def(graph_def, name='')
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
self.image_files_transforms = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
|
||||
@@ -4,6 +4,7 @@ from modelscope.metainfo import Pipelines
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.device import device_placement
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
@@ -26,7 +27,7 @@ class VideoMultiModalEmbeddingPipeline(Pipeline):
|
||||
return input
|
||||
|
||||
def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:
|
||||
with self.place_device():
|
||||
with device_placement(self.framework, self.device_name):
|
||||
out = self.forward(input)
|
||||
|
||||
self._check_output(out)
|
||||
|
||||
@@ -31,7 +31,7 @@ class TranslationPipeline(Pipeline):
|
||||
|
||||
@param model: A Model instance.
|
||||
"""
|
||||
super().__init__(model=model)
|
||||
super().__init__(model=model, **kwargs)
|
||||
model = self.model.model_dir
|
||||
tf.reset_default_graph()
|
||||
|
||||
|
||||
@@ -36,11 +36,11 @@ from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields,
|
||||
ConfigKeys, Hubs, ModeKeys, ModelFile,
|
||||
Tasks, TrainerStages)
|
||||
from modelscope.utils.data_utils import to_device
|
||||
from modelscope.utils.device import create_device, verify_device
|
||||
from modelscope.utils.file_utils import func_receive_dict_inputs
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.registry import build_from_cfg
|
||||
from modelscope.utils.torch_utils import (create_device, get_dist_info,
|
||||
init_dist)
|
||||
from modelscope.utils.torch_utils import get_dist_info, init_dist
|
||||
from .base import BaseTrainer
|
||||
from .builder import TRAINERS
|
||||
from .default_config import DEFAULT_CONFIG
|
||||
@@ -150,9 +150,8 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
self.eval_preprocessor.mode = ModeKeys.EVAL
|
||||
|
||||
device_name = kwargs.get('device', 'gpu')
|
||||
assert device_name in ['gpu',
|
||||
'cpu'], 'device should be either cpu or gpu.'
|
||||
self.device = create_device(device_name == 'cpu')
|
||||
verify_device(device_name)
|
||||
self.device = create_device(device_name)
|
||||
|
||||
self.train_dataset = self.to_task_dataset(
|
||||
train_dataset,
|
||||
|
||||
@@ -290,3 +290,9 @@ class ColorCodes:
|
||||
GREEN = '\033[92m'
|
||||
RED = '\033[91m'
|
||||
END = '\033[0m'
|
||||
|
||||
|
||||
class Devices:
|
||||
"""device used for training and inference"""
|
||||
cpu = 'cpu'
|
||||
gpu = 'gpu'
|
||||
|
||||
110
modelscope/utils/device.py
Normal file
110
modelscope/utils/device.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from modelscope.utils.constant import Devices, Frameworks
|
||||
from modelscope.utils.import_utils import is_tf_available, is_torch_available
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
def verify_device(device_name):
|
||||
""" Verify device is valid, device should be either cpu, cuda, gpu, cuda:X or gpu:X.
|
||||
|
||||
Args:
|
||||
device (str): device str, should be either cpu, cuda, gpu, gpu:X or cuda:X
|
||||
where X is the ordinal for gpu device.
|
||||
|
||||
Return:
|
||||
device info (tuple): device_type and device_id, if device_id is not set, will use 0 as default.
|
||||
"""
|
||||
device_name = device_name.lower()
|
||||
eles = device_name.split(':')
|
||||
err_msg = 'device should be either cpu, cuda, gpu, gpu:X or cuda:X where X is the ordinal for gpu device.'
|
||||
assert len(eles) <= 2, err_msg
|
||||
assert eles[0] in ['cpu', 'cuda', 'gpu'], err_msg
|
||||
device_type = eles[0]
|
||||
device_id = None
|
||||
if len(eles) > 1:
|
||||
device_id = int(eles[1])
|
||||
if device_type == 'cuda':
|
||||
device_type = Devices.gpu
|
||||
if device_type == Devices.gpu and device_id is None:
|
||||
device_id = 0
|
||||
return device_type, device_id
|
||||
|
||||
|
||||
@contextmanager
|
||||
def device_placement(framework, device_name='gpu:0'):
|
||||
""" Device placement function, allow user to specify which device to place model or tensor
|
||||
Args:
|
||||
framework (str): tensorflow or pytorch.
|
||||
device (str): gpu or cpu to use, if you want to specify certain gpu,
|
||||
use gpu:$gpu_id or cuda:$gpu_id.
|
||||
|
||||
Returns:
|
||||
Context manager
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
# Requests for using model on cuda:0 for gpu
|
||||
with device_placement('pytorch', device='gpu:0'):
|
||||
model = Model.from_pretrained(...)
|
||||
```
|
||||
"""
|
||||
device_type, device_id = verify_device(device_name)
|
||||
|
||||
if framework == Frameworks.tf:
|
||||
if device_type == Devices.gpu and not tf.test.is_gpu_available():
|
||||
logger.warning(
|
||||
'tensorflow cuda is not available, using cpu instead.')
|
||||
device_type = Devices.cpu
|
||||
if device_type == Devices.cpu:
|
||||
with tf.device('/CPU:0'):
|
||||
yield
|
||||
else:
|
||||
if device_type == Devices.gpu:
|
||||
with tf.device(f'/device:gpu:{device_id}'):
|
||||
yield
|
||||
|
||||
elif framework == Frameworks.torch:
|
||||
if device_type == Devices.gpu:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_device(f'cuda:{device_id}')
|
||||
else:
|
||||
logger.warning('cuda is not available, using cpu instead.')
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
def create_device(device_name) -> torch.DeviceObjType:
|
||||
""" create torch device
|
||||
|
||||
Args:
|
||||
device_name (str): cpu, gpu, gpu:0, cuda:0 etc.
|
||||
"""
|
||||
device_type, device_id = verify_device(device_name)
|
||||
use_cuda = False
|
||||
if device_type == Devices.gpu:
|
||||
use_cuda = True
|
||||
if not torch.cuda.is_available():
|
||||
logger.warning(
|
||||
'cuda is not available, create gpu device failed, using cpu instead.'
|
||||
)
|
||||
use_cuda = False
|
||||
|
||||
if use_cuda:
|
||||
device = torch.device(f'cuda:{device_id}')
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
|
||||
return device
|
||||
@@ -132,17 +132,6 @@ def master_only(func: Callable) -> Callable:
|
||||
return wrapper
|
||||
|
||||
|
||||
def create_device(cpu: bool = False) -> torch.DeviceObjType:
|
||||
use_cuda = torch.cuda.is_available() and not cpu
|
||||
if use_cuda:
|
||||
local_rank = os.environ.get('LOCAL_RANK', 0)
|
||||
device = torch.device(f'cuda:{local_rank}')
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
|
||||
return device
|
||||
|
||||
|
||||
def make_tmp_dir():
|
||||
"""Make sure each rank has the same temporary directory on the distributed mode.
|
||||
"""
|
||||
|
||||
@@ -41,3 +41,7 @@ class KWSFarfieldTest(unittest.TestCase):
|
||||
result = kws(data)
|
||||
self.assertEqual(len(result['kws_list']), 5)
|
||||
print(result['kws_list'][-1])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
101
tests/utils/test_device.py
Normal file
101
tests/utils/test_device.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from modelscope.utils.constant import Frameworks
|
||||
from modelscope.utils.device import (create_device, device_placement,
|
||||
verify_device)
|
||||
|
||||
# import tensorflow must be imported after torch is imported when using tf1.15
|
||||
import tensorflow as tf # isort:skip
|
||||
|
||||
|
||||
class DeviceTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
def test_verify(self):
|
||||
device_name, device_id = verify_device('cpu')
|
||||
self.assertEqual(device_name, 'cpu')
|
||||
self.assertTrue(device_id is None)
|
||||
device_name, device_id = verify_device('CPU')
|
||||
self.assertEqual(device_name, 'cpu')
|
||||
|
||||
device_name, device_id = verify_device('gpu')
|
||||
self.assertEqual(device_name, 'gpu')
|
||||
self.assertTrue(device_id == 0)
|
||||
|
||||
device_name, device_id = verify_device('cuda')
|
||||
self.assertEqual(device_name, 'gpu')
|
||||
self.assertTrue(device_id == 0)
|
||||
|
||||
device_name, device_id = verify_device('cuda:0')
|
||||
self.assertEqual(device_name, 'gpu')
|
||||
self.assertTrue(device_id == 0)
|
||||
|
||||
device_name, device_id = verify_device('gpu:1')
|
||||
self.assertEqual(device_name, 'gpu')
|
||||
self.assertTrue(device_id == 1)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
verify_device('xgu')
|
||||
|
||||
def test_create_device_torch(self):
|
||||
if torch.cuda.is_available():
|
||||
target_device_type = 'cuda'
|
||||
target_device_index = 0
|
||||
else:
|
||||
target_device_type = 'cpu'
|
||||
target_device_index = None
|
||||
device = create_device('gpu')
|
||||
self.assertTrue(isinstance(device, torch.device))
|
||||
self.assertTrue(device.type == target_device_type)
|
||||
self.assertTrue(device.index == target_device_index)
|
||||
|
||||
device = create_device('gpu:0')
|
||||
self.assertTrue(isinstance(device, torch.device))
|
||||
self.assertTrue(device.type == target_device_type)
|
||||
self.assertTrue(device.index == target_device_index)
|
||||
|
||||
device = create_device('cuda')
|
||||
self.assertTrue(device.type == target_device_type)
|
||||
self.assertTrue(isinstance(device, torch.device))
|
||||
self.assertTrue(device.index == target_device_index)
|
||||
|
||||
device = create_device('cuda:0')
|
||||
self.assertTrue(isinstance(device, torch.device))
|
||||
self.assertTrue(device.type == target_device_type)
|
||||
self.assertTrue(device.index == target_device_index)
|
||||
|
||||
def test_device_placement_cpu(self):
|
||||
with device_placement(Frameworks.torch, 'cpu'):
|
||||
pass
|
||||
|
||||
def test_device_placement_tf_gpu(self):
|
||||
tf.debugging.set_log_device_placement(True)
|
||||
with device_placement(Frameworks.tf, 'gpu:0'):
|
||||
a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
b = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
|
||||
c = tf.matmul(a, b)
|
||||
s = tf.Session()
|
||||
s.run(c)
|
||||
tf.debugging.set_log_device_placement(False)
|
||||
|
||||
def test_device_placement_torch_gpu(self):
|
||||
with device_placement(Frameworks.torch, 'gpu:0'):
|
||||
if torch.cuda.is_available():
|
||||
self.assertEqual(torch.cuda.current_device(), 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user