fix image colorization model load issue

之前的评审提交(https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13669766?tab=changes)可能导致 modelscope hub 的模型无法读取,对于不同 modelscope 版本使用者产生兼容性问题,因此在上次评审提交的基础上做一定修改确保兼容性。
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13716448
* fix image colorization model load issue
This commit is contained in:
kangxiaoyang.kxy
2023-08-21 09:47:32 +08:00
committed by wenmeng.zwm
parent 7f7573f15d
commit 6c3973b9a3
2 changed files with 3 additions and 3 deletions

View File

@@ -158,7 +158,7 @@ class Encoder(nn.Module):
return hooks
def forward(self, img):
return self.arch(img)
return self.arch.forward_features(img)
class MultiScaleColorDecoder(nn.Module):

View File

@@ -116,7 +116,7 @@ class ConvNeXt(nn.Module):
self.add_module(layer_name, layer)
self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
# self.head_cls = nn.Linear(dims[-1], 4)
self.head_cls = nn.Linear(dims[-1], 4)
self.apply(self._init_weights)
# self.head_cls.weight.data.mul_(head_init_scale)
@@ -141,7 +141,7 @@ class ConvNeXt(nn.Module):
def forward(self, x):
x = self.forward_features(x)
# x = self.head_cls(x)
x = self.head_cls(x)
return x