revise format

This commit is contained in:
rujiao.lrj
2022-11-17 15:57:55 +08:00
parent 1cb215422d
commit b5a576b860

View File

@@ -643,8 +643,8 @@ class DLAUp(nn.Module):
assert len(layers) > 1
for i in range(len(layers) - 1):
ida = getattr(self, "ida_{}".format(i))
x, y = ida(layers[-i - 2 :])
layers[-i - 1 :] = y
x, y = ida(layers[-i - 2:])
layers[-i - 1:] = y
return x
@@ -668,8 +668,8 @@ class DLASeg(nn.Module):
self.first_level = int(np.log2(down_ratio))
self.base = globals()[base_name](pretrained=pretrained, return_levels=True)
channels = self.base.channels
scales = [2**i for i in range(len(channels[self.first_level :]))]
self.dla_up = DLAUp(channels[self.first_level :], scales=scales)
scales = [2**i for i in range(len(channels[self.first_level:]))]
self.dla_up = DLAUp(channels[self.first_level:], scales=scales)
for head in self.heads:
classes = self.heads[head]
@@ -713,7 +713,7 @@ class DLASeg(nn.Module):
def forward(self, x):
x = self.base(x)
x = self.dla_up(x[self.first_level :])
x = self.dla_up(x[self.first_level:])
ret = {}
for head in self.heads:
ret[head] = self.__getattr__(head)(x)