diff --git a/modelscope/models/cv/ocr_recognition/modules/vitstr.py b/modelscope/models/cv/ocr_recognition/modules/vitstr.py index 56eaeb94..c9fa0693 100644 --- a/modelscope/models/cv/ocr_recognition/modules/vitstr.py +++ b/modelscope/models/cv/ocr_recognition/modules/vitstr.py @@ -41,7 +41,7 @@ class ViTSTR(VisionTransformer): x = self.forward_features(x) ap = x.view(x.shape[0] // 3, 3, 75, x.shape[2]) features_1d_concat = torch.ones(x.shape[0] // 3, 201, - x.shape[2]).cuda() + x.shape[2]).type_as(x) features_1d_concat[:, :69, :] = ap[:, 0, :69, :] features_1d_concat[:, 69:69 + 63, :] = ap[:, 1, 6:-6, :] features_1d_concat[:, 69 + 63:, :] = ap[:, 2, 6:, :]