From 352f670bec39eb6d518511a5e8b28532b3e0d2b6 Mon Sep 17 00:00:00 2001 From: "yuanzhi.zyz" Date: Mon, 13 Mar 2023 17:49:45 +0800 Subject: [PATCH] fix cuda bug Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11968695 --- modelscope/models/cv/ocr_recognition/modules/vitstr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelscope/models/cv/ocr_recognition/modules/vitstr.py b/modelscope/models/cv/ocr_recognition/modules/vitstr.py index 56eaeb94..c9fa0693 100644 --- a/modelscope/models/cv/ocr_recognition/modules/vitstr.py +++ b/modelscope/models/cv/ocr_recognition/modules/vitstr.py @@ -41,7 +41,7 @@ class ViTSTR(VisionTransformer): x = self.forward_features(x) ap = x.view(x.shape[0] // 3, 3, 75, x.shape[2]) features_1d_concat = torch.ones(x.shape[0] // 3, 201, - x.shape[2]).cuda() + x.shape[2]).type_as(x) features_1d_concat[:, :69, :] = ap[:, 0, :69, :] features_1d_concat[:, 69:69 + 63, :] = ap[:, 1, 6:-6, :] features_1d_concat[:, 69 + 63:, :] = ap[:, 2, 6:, :]