mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
fix cuda bug
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11968695
This commit is contained in:
@@ -41,7 +41,7 @@ class ViTSTR(VisionTransformer):
|
||||
x = self.forward_features(x)
|
||||
ap = x.view(x.shape[0] // 3, 3, 75, x.shape[2])
|
||||
features_1d_concat = torch.ones(x.shape[0] // 3, 201,
|
||||
x.shape[2]).cuda()
|
||||
x.shape[2]).type_as(x)
|
||||
features_1d_concat[:, :69, :] = ap[:, 0, :69, :]
|
||||
features_1d_concat[:, 69:69 + 63, :] = ap[:, 1, 6:-6, :]
|
||||
features_1d_concat[:, 69 + 63:, :] = ap[:, 2, 6:, :]
|
||||
|
||||
Reference in New Issue
Block a user