mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 12:09:22 +01:00
fix device mis match
This commit is contained in:
@@ -187,13 +187,14 @@ class OfaForAllTasks(TorchModel):
|
||||
valid_size = len(val_ans)
|
||||
valid_tgt_items = [
|
||||
torch.cat([
|
||||
torch.tensor(decoder_prompt[1:]), valid_answer,
|
||||
torch.tensor(decoder_prompt[1:]).to('cpu'), valid_answer,
|
||||
self.eos_item
|
||||
]) for decoder_prompt in input['decoder_prompts']
|
||||
for valid_answer in val_ans
|
||||
]
|
||||
valid_prev_items = [
|
||||
torch.cat([torch.tensor(decoder_prompt), valid_answer])
|
||||
torch.cat(
|
||||
[torch.tensor(decoder_prompt).to('cpu'), valid_answer])
|
||||
for decoder_prompt in input['decoder_prompts']
|
||||
for valid_answer in val_ans
|
||||
]
|
||||
|
||||
@@ -37,19 +37,6 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
|
||||
result = img_captioning({'image': image})
|
||||
print(result[OutputKeys.CAPTION])
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_run_with_image_captioning_zh_with_model(self):
|
||||
model = Model.from_pretrained(
|
||||
'/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_image-caption_coco_base_zh'
|
||||
)
|
||||
img_captioning = pipeline(
|
||||
task=Tasks.image_captioning,
|
||||
model=model,
|
||||
)
|
||||
image = 'data/test/images/image_captioning.png'
|
||||
result = img_captioning({'image': image})
|
||||
print(result[OutputKeys.CAPTION])
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_with_image_captioning_with_name(self):
|
||||
img_captioning = pipeline(
|
||||
|
||||
Reference in New Issue
Block a user