[to #42322933] Fix shop segmentation CPU inference error
Fix the CPU inference error so that the model also supports inference on CPU.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10177721
Committed by: yingda.chen
Parent: 12b8f5d04b
Commit: 1eedbd65bc
@@ -552,7 +552,7 @@ class CLIPVisionTransformer(nn.Module):
             nn.GroupNorm(1, embed_dim),
             nn.ConvTranspose2d(
                 embed_dim, embed_dim, kernel_size=2, stride=2),
-            nn.SyncBatchNorm(embed_dim),
+            nn.BatchNorm2d(embed_dim),
             nn.GELU(),
             nn.ConvTranspose2d(
                 embed_dim, embed_dim, kernel_size=2, stride=2),
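The hunk above swaps nn.SyncBatchNorm for nn.BatchNorm2d. SyncBatchNorm only works inside an initialized torch.distributed process group (typically multi-GPU training) and fails when the module is run in a single CPU process, whereas BatchNorm2d behaves the same at inference time on a single device. A minimal sketch of the idea, assuming an illustrative embed_dim and input size rather than the model's real configuration:

import torch
import torch.nn as nn

embed_dim = 768  # assumed width for illustration; the real value comes from the CLIP config

# SyncBatchNorm requires torch.distributed to be initialized, so a plain
# single-process CPU forward pass through it fails. BatchNorm2d is a
# drop-in replacement for single-device inference.
neck = nn.Sequential(
    nn.GroupNorm(1, embed_dim),
    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
    nn.BatchNorm2d(embed_dim),
    nn.GELU(),
    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)

neck.eval()
with torch.no_grad():
    out = neck(torch.randn(1, embed_dim, 16, 16))  # runs on CPU
print(out.shape)  # torch.Size([1, 768, 64, 64]): two stride-2 deconvs upsample 16 -> 64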
@@ -33,18 +33,18 @@ class ShopSegmentation(TorchModel):
             model_dir=model_dir, device_id=device_id, *args, **kwargs)
 
         self.model = SHOPSEG(model_dir=model_dir)
-        pretrained_params = torch.load('{}/{}'.format(
-            model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
+        pretrained_params = torch.load(
+            '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
+            map_location='cpu')
         self.model.load_state_dict(pretrained_params)
         self.model.eval()
-        self.device_id = device_id
-        if self.device_id >= 0 and torch.cuda.is_available():
-            self.model.to('cuda:{}'.format(self.device_id))
-            logger.info('Use GPU: {}'.format(self.device_id))
+        if device_id >= 0 and torch.cuda.is_available():
+            self.model.to('cuda:{}'.format(device_id))
+            logger.info('Use GPU: {}'.format(device_id))
         else:
-            self.device_id = -1
+            device_id = -1
             logger.info('Use CPU for inference')
+        self.device_id = device_id
 
     def preprocess(self, img, size=1024):
         mean = [0.48145466, 0.4578275, 0.40821073]
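This second hunk fixes two CPU failure modes. First, torch.load by default deserializes tensors onto the device they were saved from, so a checkpoint saved on GPU raises an error on a CPU-only machine unless map_location='cpu' is passed. Second, the device bookkeeping now falls back to device_id = -1 whenever CUDA is unavailable, and self.device_id is set once at the end. A hedged sketch of the same pattern, using a generic load_model helper written for illustration rather than the repository's actual ShopSegmentation code:

import torch

def load_model(model, checkpoint_path, device_id=0):
    # Always deserialize onto CPU first; this works whether or not the
    # checkpoint was saved from CUDA tensors and whether or not this
    # machine has a GPU.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()

    # Only move to GPU when one was requested *and* is actually available;
    # otherwise record -1 so downstream code keeps inputs on CPU.
    if device_id >= 0 and torch.cuda.is_available():
        model.to('cuda:{}'.format(device_id))
    else:
        device_id = -1
    return model, device_id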