yuanzhi.zyz
2023-03-13 17:49:45 +08:00
committed by yuze.zyz
parent 7659b64cdc
commit 352f670bec

View File

@@ -41,7 +41,7 @@ class ViTSTR(VisionTransformer):
x = self.forward_features(x)
ap = x.view(x.shape[0] // 3, 3, 75, x.shape[2])
features_1d_concat = torch.ones(x.shape[0] // 3, 201,
x.shape[2]).cuda()
x.shape[2]).type_as(x)
features_1d_concat[:, :69, :] = ap[:, 0, :69, :]
features_1d_concat[:, 69:69 + 63, :] = ap[:, 1, 6:-6, :]
features_1d_concat[:, 69 + 63:, :] = ap[:, 2, 6:, :]