diff --git a/inference_img.py b/inference_img.py index 9ca6fac..d9b78cb 100644 --- a/inference_img.py +++ b/inference_img.py @@ -42,4 +42,4 @@ for i in range(args.times): if not os.path.exists('output'): os.mkdir('output') for i in range(len(img_list)): - cv2.imwrite('output/img{}.png'.format(i), img_list[i][0].numpy().transpose(1, 2, 0)[:h, :w] * 255) + cv2.imwrite('output/img{}.png'.format(i), img_list[i][0].cpu().numpy().transpose(1, 2, 0)[:h, :w] * 255)