[to #42322933]修复shop segmentation CPU Inference错误

修复CPU Inference错误,支持CPU inference
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10177721
This commit is contained in:
xingguang.zxg
2022-09-20 15:53:38 +08:00
committed by yingda.chen
parent 12b8f5d04b
commit 1eedbd65bc
2 changed files with 9 additions and 9 deletions

View File

@@ -552,7 +552,7 @@ class CLIPVisionTransformer(nn.Module):
nn.GroupNorm(1, embed_dim),
nn.ConvTranspose2d(
embed_dim, embed_dim, kernel_size=2, stride=2),
nn.SyncBatchNorm(embed_dim),
nn.BatchNorm2d(embed_dim),
nn.GELU(),
nn.ConvTranspose2d(
embed_dim, embed_dim, kernel_size=2, stride=2),

View File

@@ -33,18 +33,18 @@ class ShopSegmentation(TorchModel):
model_dir=model_dir, device_id=device_id, *args, **kwargs)
self.model = SHOPSEG(model_dir=model_dir)
pretrained_params = torch.load('{}/{}'.format(
model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
pretrained_params = torch.load(
'{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
map_location='cpu')
self.model.load_state_dict(pretrained_params)
self.model.eval()
self.device_id = device_id
if self.device_id >= 0 and torch.cuda.is_available():
self.model.to('cuda:{}'.format(self.device_id))
logger.info('Use GPU: {}'.format(self.device_id))
if device_id >= 0 and torch.cuda.is_available():
self.model.to('cuda:{}'.format(device_id))
logger.info('Use GPU: {}'.format(device_id))
else:
self.device_id = -1
device_id = -1
logger.info('Use CPU for inference')
self.device_id = device_id
def preprocess(self, img, size=1024):
mean = [0.48145466, 0.4578275, 0.40821073]