Merge branch ofa/task_finetune into master

Title: [to #42322933] add finetune & merge master

Add finetuning support for additional OFA tasks.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10656541
yingda.chen
2022-11-08 13:01:22 +08:00
12 changed files with 395 additions and 43 deletions
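For orientation, the finetuning path added here is driven through the standard ModelScope trainer API. Below is a minimal sketch of how it could be invoked; the trainer registry key (Trainers.ofa), the dataset name, and the exact keyword arguments accepted by the trainer are assumptions, while the model id comes from the test change at the bottom of this diff.

# Sketch only: finetune an OFA model via the ModelScope trainer.
# 'some_task_dataset' is a hypothetical dataset name; Trainers.ofa is assumed
# to be the registry key under which OFATrainer is registered.
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer

train_dataset = MsDataset.load('some_task_dataset', split='train')  # hypothetical

args = dict(
    model='damo/ofa_ocr-recognition_scene_base_zh',  # model id from the test below
    work_dir='./ofa_finetune_work_dir',
    train_dataset=train_dataset,
)
trainer = build_trainer(name=Trainers.ofa, default_args=args)
trainer.train()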

View File

@@ -402,6 +402,7 @@ class Metrics(object):
# accuracy
accuracy = 'accuracy'
multi_average_precision = 'mAP'
audio_noise_metric = 'audio-noise-metric'
# text gen

View File

@@ -24,6 +24,7 @@ class MetricKeys(object):
ROUGE_1 = 'rouge-1'
ROUGE_L = 'rouge-l'
NED = 'ned' # ocr metric
mAP = 'mAP'
BatchAcc = 'inbatch_t2i_recall_at_1'

View File

@@ -0,0 +1,67 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict
import numpy as np
from modelscope.metainfo import Metrics
from modelscope.outputs import OutputKeys
from modelscope.utils.registry import default_group
from .base import Metric
from .builder import METRICS, MetricKeys
@METRICS.register_module(
    group_key=default_group, module_name=Metrics.multi_average_precision)
class AveragePrecisionMetric(Metric):
    """The metric computation class for multi average precision.

    This metric class calculates multi average precision over all input batches.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.preds = []
        self.labels = []
        self.thresh = kwargs.get('threshold', 0.5)

    def add(self, outputs: Dict, inputs: Dict):
        label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
        ground_truths = inputs[label_name]
        eval_results = outputs[label_name]
        for key in [
                OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
                OutputKeys.LABELS, OutputKeys.SCORES
        ]:
            if key in outputs and outputs[key] is not None:
                eval_results = outputs[key]
                break
        assert type(ground_truths) == type(eval_results)
        for truth in ground_truths:
            self.labels.append(truth)
        for result in eval_results:
            if isinstance(truth, str):
                self.preds.append(result.strip().replace(' ', ''))
            else:
                self.preds.append(result)

    def evaluate(self):
        assert len(self.preds) == len(self.labels)
        scores = self._calculate_ap_score(self.preds, self.labels, self.thresh)
        return {MetricKeys.mAP: scores.mean().item()}

    def _calculate_ap_score(self, preds, labels, thresh=0.5):
        hyps = np.array(preds)
        refs = np.array(labels)
        a = np.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2])
        b = np.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:])
        interacts = np.concatenate([a, b], axis=1)
        area_predictions = (hyps[:, 2] - hyps[:, 0]) * (
            hyps[:, 3] - hyps[:, 1])
        area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
        interacts_w = interacts[:, 2] - interacts[:, 0]
        interacts_h = interacts[:, 3] - interacts[:, 1]
        area_interacts = interacts_w * interacts_h
        ious = area_interacts / (
            area_predictions + area_targets - area_interacts + 1e-6)
        return (ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)
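Illustration only (not part of the commit): the box-overlap gate used by _calculate_ap_score can be exercised on toy (x0, y0, x1, y1) boxes. A prediction counts as a hit when its IoU with the reference box reaches the threshold and the intersection has positive width and height.

# Self-contained toy check of the IoU gate above.
import numpy as np

preds = np.array([[10., 10., 60., 60.], [0., 0., 5., 5.]])
refs = np.array([[15., 15., 65., 65.], [50., 50., 80., 80.]])

lt = np.maximum(preds[:, :2], refs[:, :2])   # top-left of the intersection
rb = np.minimum(preds[:, 2:], refs[:, 2:])   # bottom-right of the intersection
wh = rb - lt
inter = wh[:, 0] * wh[:, 1]
area_p = (preds[:, 2] - preds[:, 0]) * (preds[:, 3] - preds[:, 1])
area_r = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
ious = inter / (area_p + area_r - inter + 1e-6)
hits = (ious >= 0.5) & (wh[:, 0] > 0) & (wh[:, 1] > 0)
# For disjoint boxes wh goes negative, so the width/height gate is what rejects them.
print(hits.mean())  # 0.5: the first pair overlaps enough, the second does not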

View File

@@ -43,7 +43,7 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
target = data[self.column_map['text']]
target = sample['label']
target = target.translate(self.transtab).strip()
target_token_list = target.strip().split()
target = ' '.join(target_token_list[:self.max_tgt_length])

View File

@@ -1,13 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import functools
from typing import Any, Dict
import torch
from PIL import Image
from PIL import Image, ImageFile
from timm.data import create_transform
from torchvision import transforms
from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor
from .utils.vision_helper import RandomAugment
ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None
class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
@@ -28,18 +35,77 @@ class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
super(OfaImageClassificationPreprocessor,
self).__init__(cfg, model_dir, mode, *args, **kwargs)
# Initialize transform
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),
transforms.Resize(
(self.patch_image_size, self.patch_image_size),
interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=self.mean, std=self.std),
])
if self.mode != ModeKeys.TRAIN:
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),
transforms.Resize(
(self.patch_image_size, self.patch_image_size),
interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=self.mean, std=self.std),
])
else:
self.patch_resize_transform = create_transform(
input_size=self.patch_image_size,
is_training=True,
color_jitter=0.4,
auto_augment='rand-m9-mstd0.5-inc1',
interpolation='bicubic',
re_prob=0.25,
re_mode='pixel',
re_count=1,
mean=self.mean,
std=self.std)
self.patch_resize_transform = transforms.Compose(
functools.reduce(lambda x, y: x + y, [
[
lambda image: image.convert('RGB'),
],
self.patch_resize_transform.transforms[:2],
[self.patch_resize_transform.transforms[2]],
[
RandomAugment(
2,
7,
isPIL=True,
augs=[
'Identity', 'AutoContrast', 'Equalize',
'Brightness', 'Sharpness', 'ShearX', 'ShearY',
'TranslateX', 'TranslateY', 'Rotate'
]),
],
self.patch_resize_transform.transforms[3:],
]))
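To see what the slicing above ([:2], [2], [3:]) picks out of timm's training pipeline before RandomAugment is spliced in between the auto-augment step and ToTensor/Normalize/RandomErasing, the components can be inspected directly. Illustration only; mean/std are placeholder values and the exact class names depend on the installed timm version.

# Inspect the transforms timm builds for training.
from timm.data import create_transform

t = create_transform(
    input_size=224, is_training=True, color_jitter=0.4,
    auto_augment='rand-m9-mstd0.5-inc1', interpolation='bicubic',
    re_prob=0.25, re_mode='pixel', re_count=1,
    mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
for step in t.transforms:
    print(type(step).__name__)
# Typically: RandomResizedCropAndInterpolation, RandomHorizontalFlip,
# RandAugment, ToTensor, Normalize, RandomErasing.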
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = data['image'] if isinstance(
data['image'], Image.Image) else load_image(data['image'])
if self.mode == ModeKeys.TRAIN:
return self._build_train_sample(data)
else:
return self._build_infer_sample(data)
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
target = ' {}'.format(sample['label'])
sample['ref_dict'] = {sample['label']: 1.0}
sample['target'] = self.tokenize_text(target, add_bos=False)
sample['prev_output_tokens'] = torch.cat(
[self.bos_item, sample['target'][:-1]])
if self.constraint_trie is not None:
constraint_mask = torch.zeros((len(sample['prev_output_tokens']),
len(self.tgt_dict))).bool()
for i in range(len(sample['prev_output_tokens'])):
constraint_prefix_token = sample[
'prev_output_tokens'][:i + 1].tolist()
constraint_nodes = self.constraint_trie.get_next_layer(
constraint_prefix_token)
constraint_mask[i][constraint_nodes] = True
sample['constraint_mask'] = constraint_mask
return sample
def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = self.get_img_pil(data[self.column_map['image']])
patch_image = self.patch_resize_transform(image)
prompt = self.cfg.model.get('prompt', ' what does the image describe?')
inputs = self.tokenize_text(prompt)
@@ -48,4 +114,6 @@ class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
'patch_image': patch_image,
'patch_mask': torch.tensor([True])
}
if 'text' in self.column_map and self.column_map['text'] in data:
sample['label'] = data[self.column_map['text']]
return sample
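The training sample built in _build_train_sample above pairs the tokenized label ('target', typically ending in EOS) with 'prev_output_tokens', the same sequence shifted right by one position with BOS prepended, i.e. the usual teacher-forcing setup. A toy illustration with made-up token ids:

# Illustration only: the teacher-forcing shift used above (0 = BOS, 2 = EOS).
import torch

bos_item = torch.tensor([0])
target = torch.tensor([7, 8, 9, 2])               # tokenized label + EOS
prev_output_tokens = torch.cat([bos_item, target[:-1]])
print(target.tolist())                            # [7, 8, 9, 2]
print(prev_output_tokens.tolist())                # [0, 7, 8, 9]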

View File

@@ -11,9 +11,6 @@ from zhconv import convert
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
def ocr_resize(img, patch_image_size, is_document=False):
img = img.convert('RGB')
@@ -112,6 +109,6 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
}
if 'text' in self.column_map and self.column_map['text'] in data:
target = data[self.column_map['text']]
target = unicodedata2.normalize('NFKC', convert(target, 'zh-hans'))
sample['label'] = target
sample['label'] = unicodedata2.normalize(
'NFKC', convert(target, 'zh-hans'))
return sample

View File

@@ -1,6 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict
import torch
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor
@@ -24,9 +26,26 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor):
self).__init__(cfg, model_dir, mode, *args, **kwargs)
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
if self.mode == ModeKeys.TRAIN:
return self._build_train_sample(data)
else:
return self._build_infer_sample(data)
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
target_str = sample['label'].lower()
target = super().pre_caption(target_str, max_words=self.max_tgt_length)
target = target.replace('[unk]', 'unk').replace('<unk>', 'unk')
sample['target'] = self.tokenize_text(target, add_bos=False)
noise_target_item = self.add_noise_to_tgt(
sample['target'][:-1].clone())
sample['prev_output_tokens'] = torch.cat(
[self.bos_item, noise_target_item])
return sample
def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
source = super().pre_caption(
data['text'], max_words=self.max_src_length)
source = source.strip()[:self.max_src_length]
data[self.column_map['text']], max_words=self.max_src_length)
source = source.replace('[unk]', 'unk').replace('<unk>', 'unk')
prompt = self.cfg.model.get(
'prompt', ' " {} " Summarize the article with a title: ')
@@ -42,4 +61,17 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor):
'source': inputs,
'decoder_prompt': decoder_prompt,
}
if 'summary' in self.column_map and self.column_map['summary'] in data:
sample['label'] = data[self.column_map['summary']]
return sample
def add_noise_to_tgt(self, target):
noise_indices = torch.FloatTensor(
target.size(0)).uniform_() < self.cfg.model.get(
'noise_ratio', 0.0)
target[noise_indices] = torch.randint(
4,
len(self.src_dict) - self.cfg.model.get('num_codes', 8192)
- self.cfg.model.get('num_bins', 1000),
size=(noise_indices.sum(), ))
return target
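add_noise_to_tgt above replaces each target position, with probability noise_ratio, by a random id drawn from the plain-text part of the vocabulary (skipping the special tokens at the front and the code/bin tokens at the end). A toy version with made-up sizes and a fixed 30% noise ratio:

# Illustration only: token-level noising as in add_noise_to_tgt.
import torch

torch.manual_seed(0)
target = torch.tensor([11, 12, 13, 14, 15])
noise_ratio = 0.3
vocab_size, num_codes, num_bins = 50000, 8192, 1000

noise_indices = torch.FloatTensor(target.size(0)).uniform_() < noise_ratio
target[noise_indices] = torch.randint(
    4, vocab_size - num_codes - num_bins, size=(int(noise_indices.sum()),))
print(noise_indices.tolist(), target.tolist())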

View File

@@ -38,18 +38,64 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
])
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = data['image'] if isinstance(
data['image'], Image.Image) else load_image(data['image'])
if self.mode == ModeKeys.TRAIN:
return self._build_train_sample(data)
else:
return self._build_infer_sample(data)
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
target = ' {}'.format(sample['label'])
sample['ref_dict'] = {sample['label']: 1.0}
tgt_item = self.tokenize_text(target, add_bos=False, add_eos=False)
if self.prompt_type == 'none':
prev_output_item = torch.cat([self.bos_item, tgt_item])
target_item = torch.cat([prev_output_item[1:], self.eos_item])
elif self.prompt_type == 'src':
prev_output_item = torch.cat([sample['source'], tgt_item])
target_item = torch.cat([prev_output_item[1:], self.eos_item])
elif self.prompt_type == 'prev_output':
prev_output_item = torch.cat([sample['source'][:-1], tgt_item])
target_item = torch.cat([prev_output_item[1:], self.eos_item])
else:
raise NotImplementedError
target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id
sample['target'] = target_item
sample['prev_output_tokens'] = prev_output_item
if self.constraint_trie is not None:
constraint_mask = torch.zeros(
(len(target_item), len(self.tgt_dict))).bool()
start_idx = len(target_item) - len(tgt_item) - 1
for i in range(
len(target_item) - len(tgt_item) - 1, len(target_item)):
constraint_prefix_token = [
self.tgt_dict.bos()
] + target_item[start_idx:i].tolist()
constraint_nodes = self.constraint_trie.get_next_layer(
constraint_prefix_token)
constraint_mask[i][constraint_nodes] = True
sample['constraint_mask'] = constraint_mask
return sample
def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = self.get_img_pil(data[self.column_map['image']])
patch_image = self.patch_resize_transform(image)
if 'text2' not in data:
hypothesis = self.pre_caption(data['text'], self.max_src_length)
hypothesis = self.pre_caption(data[self.column_map['text']],
self.max_src_length)
prompt = self.cfg.model.get('prompt',
' does the image describe " {} "?')
text = prompt.format(hypothesis)
else:
assert 'text' in data, f'text must be in the input {data.keys()}'
caption = self.pre_caption(data['text2'], self.max_src_length)
hypothesis = self.pre_caption(data['text'], self.max_src_length)
caption = self.pre_caption(data[self.column_map['text2']],
self.max_src_length)
hypothesis = self.pre_caption(data[self.column_map['text']],
self.max_src_length)
prompt = self.cfg.model.get(
'prompt', ' can image and text1 " {} " imply text2 " {} "?')
text = prompt.format(caption, hypothesis)
@@ -68,4 +114,7 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
'patch_mask': torch.tensor([True]),
'decoder_prompt': decoder_prompt,
}
if 'relation' in self.column_map and self.column_map[
'relation'] in data:
sample['label'] = data[self.column_map['relation']]
return sample
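The constraint mask built above restricts decoding to a closed set of candidate labels: at each target position, only the tokens the trie accepts as a continuation of the current prefix stay unmasked. A self-contained toy version of the idea; the real constraint_trie comes from the base preprocessor, and only its get_next_layer call is mirrored here.

# Illustration only: a minimal prefix trie and the mask it yields for "yes </s>".
import torch

class TinyTrie:
    def __init__(self, sequences):
        self.root = {}
        for seq in sequences:
            node = self.root
            for tok in seq:
                node = node.setdefault(tok, {})

    def get_next_layer(self, prefix):
        node = self.root
        for tok in prefix:
            node = node.get(tok, {})
        return list(node.keys())

BOS, EOS, YES, NO, MAYBE = 0, 2, 10, 11, 12
vocab_size = 16
trie = TinyTrie([[BOS, YES, EOS], [BOS, NO, EOS], [BOS, MAYBE, EOS]])

target = [YES, EOS]
constraint_mask = torch.zeros(len(target), vocab_size).bool()
for i in range(len(target)):
    prefix = [BOS] + target[:i]
    constraint_mask[i][trie.get_next_layer(prefix)] = True
print(constraint_mask[0].nonzero().flatten().tolist())  # [10, 11, 12]: any answer may start
print(constraint_mask[1].nonzero().flatten().tolist())  # [2]: after "yes" only EOS is allowed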

View File

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
@@ -8,6 +9,7 @@ from torchvision import transforms
from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor
from .utils import transforms as T
class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
@@ -27,24 +29,98 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
"""
super(OfaVisualGroundingPreprocessor,
self).__init__(cfg, model_dir, mode, *args, **kwargs)
# Initialize transform
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),
transforms.Resize(
(self.patch_image_size, self.patch_image_size),
interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=self.mean, std=self.std),
])
self.num_bins = self.cfg.model.get('num_bins', 1000)
if self.mode == ModeKeys.TRAIN:
# for positioning
self.positioning_transform = T.Compose([
T.RandomResize([self.patch_image_size],
max_size=self.patch_image_size),
T.ToTensor(),
T.Normalize(
mean=self.mean,
std=self.std,
max_image_size=self.max_image_size)
])
else:
# Initialize transform
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),
transforms.Resize(
(self.patch_image_size, self.patch_image_size),
interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=self.mean, std=self.std),
])
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = data['image'] if isinstance(
data['image'], Image.Image) else load_image(data['image'])
if self.mode == ModeKeys.TRAIN:
return self._build_train_sample(data)
else:
return self._build_infer_sample(data)
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = self.get_img_pil(data[self.column_map['image']])
w, h = image.size
boxes_target = {
'boxes': [],
'labels': [],
'area': [],
'size': torch.tensor([h, w])
}
x0, y0, x1, y1 = data[self.column_map['region_coord']].strip().split(
',')
region = torch.tensor([float(x0), float(y0), float(x1), float(y1)])
boxes_target['boxes'] = torch.tensor(
[[float(x0), float(y0), float(x1),
float(y1)]])
boxes_target['labels'] = np.array([0])
area = [(float(x1) - float(x0)) * (float(y1) - float(y0))]
boxes_target['area'] = torch.tensor(area)
patch_image, patch_boxes = self.positioning_transform(
image, boxes_target)
resize_h, resize_w = patch_boxes['size'][0], patch_boxes['size'][1]
quant_x0 = '<bin_{}>'.format(
int((patch_boxes['boxes'][0][0] * (self.num_bins - 1)).round()))
quant_y0 = '<bin_{}>'.format(
int((patch_boxes['boxes'][0][1] * (self.num_bins - 1)).round()))
quant_x1 = '<bin_{}>'.format(
int((patch_boxes['boxes'][0][2] * (self.num_bins - 1)).round()))
quant_y1 = '<bin_{}>'.format(
int((patch_boxes['boxes'][0][3] * (self.num_bins - 1)).round()))
region_coord = '{} {} {} {}'.format(quant_x0, quant_y0, quant_x1,
quant_y1)
src_caption = self.pre_caption(data[self.column_map['text']],
self.max_src_length)
prompt = self.cfg.model.get(
'prompt', ' which region does the text " {} " describe?')
text = prompt.format(src_caption)
src_item = self.tokenize_text(text)
target_item = self.tokenize_text(
region_coord, add_bos=False) # !!! use_bpe=False
prev_output_item = torch.cat([self.bos_item, target_item[:-1]])
sample = {
'source': src_item,
'patch_image': patch_image,
'patch_mask': torch.tensor([True]),
'target': target_item,
'prev_output_tokens': prev_output_item,
'w_resize_ratio': resize_w / w,
'h_resize_ratio': resize_h / h,
'region_coord': region
}
return sample
def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = self.get_img_pil(data[self.column_map['image']])
w, h = image.size
patch_image = self.patch_resize_transform(image)
w_resize_ratio = torch.tensor(self.patch_image_size / w)
h_resize_ratio = torch.tensor(self.patch_image_size / h)
src_caption = self.pre_caption(data['text'], self.max_src_length)
src_caption = self.pre_caption(data[self.column_map['text']],
self.max_src_length)
prompt = self.cfg.model.get(
'prompt', ' which region does the text " {} " describe?')
text = prompt.format(src_caption)
@@ -56,4 +132,10 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
'w_resize_ratio': w_resize_ratio,
'h_resize_ratio': h_resize_ratio,
}
if 'region_coord' in self.column_map and self.column_map[
'region_coord'] in data:
x0, y0, x1, y1 = data[
self.column_map['region_coord']].strip().split(',')
sample['label'] = [float(x0), float(y0), float(x1), float(y1)]
return sample
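In _build_train_sample above, the ground-truth box is turned into four discrete location tokens by scaling each normalized coordinate into num_bins buckets and rounding. A toy illustration, assuming coordinates already lie in [0, 1] after the positioning transform:

# Illustration only: quantizing normalized box coordinates into <bin_k> tokens.
num_bins = 1000
box = [0.1234, 0.2000, 0.8765, 0.9999]  # x0, y0, x1, y1

tokens = ['<bin_{}>'.format(int(round(c * (num_bins - 1)))) for c in box]
region_coord = ' '.join(tokens)
print(region_coord)  # <bin_123> <bin_200> <bin_876> <bin_999>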

View File

@@ -38,10 +38,52 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
])
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = data['image'] if isinstance(
data['image'], Image.Image) else load_image(data['image'])
if self.mode == ModeKeys.TRAIN:
return self._build_train_sample(data)
else:
return self._build_infer_sample(data)
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
tgt_item = self.tokenize_text(
' {}'.format(sample['label']), add_bos=False, add_eos=False)
if self.prompt_type == 'none':
prev_output_item = torch.cat([self.bos_item, tgt_item])
target_item = torch.cat([prev_output_item[1:], self.eos_item])
elif self.prompt_type == 'src':
prev_output_item = torch.cat([sample['source'], tgt_item])
target_item = torch.cat([prev_output_item[1:], self.eos_item])
elif self.prompt_type == 'prev_output':
prev_output_item = torch.cat([sample['source'][:-1], tgt_item])
target_item = torch.cat([prev_output_item[1:], self.eos_item])
else:
raise NotImplementedError
target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id
sample['prev_output_tokens'] = prev_output_item
sample['target'] = target_item
if self.constraint_trie is not None:
constraint_mask = torch.zeros(
(len(target_item), len(self.tgt_dict))).bool()
start_idx = len(target_item) - len(tgt_item) - 1
for i in range(
len(target_item) - len(tgt_item) - 1, len(target_item)):
constraint_prefix_token = [
self.tgt_dict.bos()
] + target_item[start_idx:i].tolist()
constraint_nodes = self.constraint_trie.get_next_layer(
constraint_prefix_token)
constraint_mask[i][constraint_nodes] = True
sample['constraint_mask'] = constraint_mask
return sample
def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = self.get_img_pil(data[self.column_map['image']])
patch_image = self.patch_resize_transform(image)
text = ' {}'.format(data['text'])
text = ' {}'.format(data[self.column_map['text']])
inputs = self.tokenize_text(text)
if self.prompt_type == 'none':
decoder_prompt = self.bos_item
@@ -57,4 +99,6 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
'patch_mask': torch.tensor([True]),
'decoder_prompt': decoder_prompt,
}
if 'answer' in self.column_map and self.column_map['answer'] in data:
sample['label'] = data[self.column_map['answer']]
return sample

View File

@@ -34,6 +34,7 @@ class OFATrainer(EpochBasedTrainer):
self,
model: Optional[Union[TorchModel, nn.Module, str]] = None,
cfg_file: Optional[str] = None,
cfg_modify_fn: Optional[Callable] = None,
arg_parse_fn: Optional[Callable] = None,
data_collator: Optional[Union[Callable, Dict[str,
Callable]]] = None,
@@ -49,7 +50,8 @@ class OFATrainer(EpochBasedTrainer):
**kwargs):
model = Model.from_pretrained(model, revision=model_revision)
model_dir = model.model_dir
cfg = Config.from_file(cfg_file)
self.cfg_modify_fn = cfg_modify_fn
cfg = self.rebuild_config(Config.from_file(cfg_file))
if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0:
work_dir = cfg.train.work_dir
else:
@@ -57,10 +59,12 @@ class OFATrainer(EpochBasedTrainer):
tokenizer_files = {
'zh': [
'tokenizer.json', 'tokenizer_config.json', 'vocab.txt',
'config.json'
'config.json', 'ans2label.json'
],
'en': [
'tokenizer.json', 'vocab.json', 'merges.txt', 'config.json',
'ans2label.json'
],
'en':
['tokenizer.json', 'vocab.json', 'merges.txt', 'config.json'],
}
for filename in tokenizer_files[cfg.model.get('language', 'en')]:
finetune_file = os.path.join(work_dir, filename)
@@ -127,6 +131,11 @@ class OFATrainer(EpochBasedTrainer):
**kwargs,
)
def rebuild_config(self, cfg: Config):
if self.cfg_modify_fn is not None:
cfg = self.cfg_modify_fn(cfg)
return cfg
def train_step(self, model, inputs):
model.train()
loss, sample_size, logging_output = self.criterion(model, inputs)
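The new cfg_modify_fn argument gives callers a hook to patch the configuration loaded from cfg_file before training starts (see rebuild_config above). A sketch of such a hook; the config keys below are assumptions and depend on the task's configuration.json. It would be passed as cfg_modify_fn=... alongside the trainer arguments shown in the sketch near the top of this change.

# Sketch only: a cfg_modify_fn that patches the loaded configuration.
def cfg_modify_fn(cfg):
    cfg.train.max_epochs = 3                      # assumed key
    cfg.train.dataloader.batch_size_per_gpu = 4   # assumed key
    return cfg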

View File

@@ -9,6 +9,7 @@ from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode, ModelFile
from modelscope.utils.hub import read_config
from modelscope.utils.test_utils import test_level
@@ -78,6 +79,7 @@ class TestOfaTrainer(unittest.TestCase):
json.dump(self.finetune_cfg, writer)
pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh'
args = dict(
model=pretrained_model,
work_dir=WORKSPACE,