Mirror of https://github.com/modelscope/modelscope.git, synced 2025-12-25 12:39:25 +01:00

Commit: update
@@ -129,8 +129,7 @@ class OfaForAllTasks(TorchModel):
        result_l = list()
        for cap in caption:
            result_l.append(cap.translate(self.transtab).strip())
        input[OutputKeys.CAPTION] = caption
        input[OutputKeys.CAPTION] = result_l
        return input

    def _text_gen_inference(self, input):
@@ -182,6 +181,8 @@ class OfaForAllTasks(TorchModel):
            encoder_input[key] = input['net_input'][key]
        encoder_out = self.model.encoder(**encoder_input)
        valid_result = []
        import pdb
        pdb.set_trace()
        for val_ans, val_masks in zip(self.val_ans_l, self.val_masks_l):
            valid_size = len(val_ans)
            valid_tgt_items = [
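In the first hunk, result_l already holds the generated captions with punctuation stripped via str.translate; the change replaces the assignment of the untranslated caption list with result_l, so the cleaned captions are what the task returns. A minimal sketch of the stripping step (building the translation table this way is an assumption; self.transtab itself is not shown in this diff):

# Hypothetical sketch of the caption clean-up, not part of this commit.
import string

transtab = str.maketrans({key: None for key in string.punctuation})
caption = ['a dog sitting on a bench.']
result_l = [cap.translate(transtab).strip() for cap in caption]
print(result_l)  # ['a dog sitting on a bench']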
@@ -1,133 +0,0 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import os
import pickle

import torch


class OFAFileDataset:

    def __init__(self,
                 file_path,
                 selected_col_ids=None,
                 dtypes=None,
                 separator='\t',
                 cached_index=False):
        self.file_path = file_path
        assert os.path.exists(
            self.file_path), 'Error: The local datafile {} not exists!'.format(
                self.file_path)

        self.separator = separator
        if selected_col_ids is None:
            # default to all fields
            self.selected_col_ids = list(
                range(
                    len(
                        open(self.file_path).readline().rstrip('\n').split(
                            self.separator))))
        else:
            self.selected_col_ids = [
                int(col_id) for col_id in selected_col_ids.split(',')
            ]
        if dtypes is None:
            # default to str
            self.dtypes = [str for col_id in self.selected_col_ids]
        else:
            self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(',')]
        assert len(self.dtypes) == len(self.selected_col_ids)

        self.data_cnt = 0
        try:
            self.slice_id = torch.distributed.get_rank()
            self.slice_count = torch.distributed.get_world_size()
        except Exception:
            self.slice_id = 0
            self.slice_count = 1
        self.cached_index = cached_index
        self._init_seek_index()
        self._reader = self._get_reader()
        print('file {} slice_id {} row count {} total row count {}'.format(
            self.file_path, self.slice_id, self.row_count,
            self.total_row_count))

    def _init_seek_index(self):
        if self.cached_index:
            cache_path = '{}.index'.format(self.file_path)
            assert os.path.exists(
                cache_path), 'cache file {} not exists!'.format(cache_path)
            self.total_row_count, self.lineid_to_offset = pickle.load(
                open(cache_path, 'rb'))
            print(
                'local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping'
                .format(self.file_path, self.slice_id))
        else:
            # make an iteration over the file to get row_count and line_idx-to-offset mapping
            fp = open(self.file_path, 'r')
            print(
                'local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping'
                .format(self.file_path, self.slice_id))
            self.total_row_count = 0
            offset = 0
            self.lineid_to_offset = []
            for line in fp:
                self.lineid_to_offset.append(offset)
                self.total_row_count += 1
                offset += len(line.encode('utf-8'))
            pickle.dump(self.lineid_to_offset,
                        open('{}.index'.format(self.file_path), 'wb'))
        self._compute_start_pos_and_row_count()
        print(
            'local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping'
            .format(self.file_path, self.slice_id))

    def _compute_start_pos_and_row_count(self):
        self.row_count = self.total_row_count // self.slice_count
        if self.slice_id < self.total_row_count - self.row_count * self.slice_count:
            self.row_count += 1
            self.start_pos = self.row_count * self.slice_id
        else:
            self.start_pos = self.row_count * self.slice_id + (
                self.total_row_count - self.row_count * self.slice_count)

    def _get_reader(self):
        fp = open(self.file_path, 'r')
        fp.seek(self.lineid_to_offset[self.start_pos])
        return fp

    def _seek(self, offset=0):
        try:
            print('slice_id {} seek offset {}'.format(self.slice_id,
                                                      self.start_pos + offset))
            self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
            self.data_cnt = offset
        except Exception:
            print('slice_id {} seek offset {}'.format(self.slice_id, offset))
            self._reader.seek(self.lineid_to_offset[offset])
            self.data_cnt = offset

    def __del__(self):
        self._reader.close()

    def __len__(self):
        return self.row_count

    def get_total_row_count(self):
        return self.total_row_count

    def __getitem__(self, index):
        if self.data_cnt == self.row_count:
            print('reach the end of datafile, start a new reader')
            self.data_cnt = 0
            self._reader = self._get_reader()
        column_l = self._reader.readline().rstrip('\n').split(self.separator)
        self.data_cnt += 1
        column_l = [
            dtype(column_l[col_id])
            for col_id, dtype in zip(self.selected_col_ids, self.dtypes)
        ]
        return column_l
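The removed OFAFileDataset above is a seekable TSV reader that shards rows across distributed ranks by byte offset. A hedged usage sketch (the sample.tsv path, column layout and dtypes string are hypothetical, not taken from the repository):

# Hypothetical usage of the removed class, assuming rows like "<id>\t<text>\t<label>".
dataset = OFAFileDataset(
    file_path='sample.tsv',          # hypothetical local TSV file
    selected_col_ids='0,1,2',        # keep columns 0, 1 and 2
    dtypes='str,str,int',            # one dtype per selected column (eval'd)
    separator='\t',
    cached_index=False)              # build the {file}.index offsets on the fly

print(len(dataset))                  # rows assigned to this rank
print(dataset.get_total_row_count())
row = dataset[0]                     # reads sequentially; the index argument is ignored

# Rank slicing from _compute_start_pos_and_row_count: with 10 total rows and
# 4 ranks, ranks 0-1 read 3 rows each and ranks 2-3 read 2 rows each, i.e.
# the first (total % world_size) ranks absorb the remainder.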
@@ -65,7 +65,7 @@ class OFATrainer(EpochBasedTrainer):
            kwargs['launcher'] = cfg.train.launcher
        if 'use_fp16' not in kwargs and cfg.train.get('use_fp16', False):
            kwargs['use_fp16'] = cfg.train.use_fp16

        kwargs['to_tensor'] = False
        super().__init__(
            cfg_file=cfg_file,
            model=model,
@@ -167,19 +167,20 @@ class EpochBasedTrainer(BaseTrainer):
                device_name = f'cuda:{local_rank}'

        self.device = create_device(device_name)

        self.train_dataset = self.to_task_dataset(
            train_dataset,
            mode=ModeKeys.TRAIN,
            task_data_config=self.cfg.dataset.get('train', None) if hasattr(
                self.cfg, 'dataset') else None,
            preprocessor=self.train_preprocessor)
            preprocessor=self.train_preprocessor,
            **kwargs)
        self.eval_dataset = self.to_task_dataset(
            eval_dataset,
            mode=ModeKeys.EVAL,
            task_data_config=self.cfg.dataset.get('val', None) if hasattr(
                self.cfg, 'dataset') else None,
            preprocessor=self.eval_preprocessor)
            preprocessor=self.eval_preprocessor,
            **kwargs)

        self.train_data_collator, self.eval_default_collate = None, None
        if isinstance(data_collator, Mapping):
@@ -305,13 +306,15 @@ class EpochBasedTrainer(BaseTrainer):
                        datasets: Union[Dataset, List[Dataset]],
                        mode: str,
                        task_data_config: Config = None,
                        preprocessor: Optional[Preprocessor] = None):
                        preprocessor: Optional[Preprocessor] = None,
                        **kwargs):
        """Build the task specific dataset processor for this trainer.

        Returns: The task dataset processor for the task. If no result for the very model-type and task,
            the default TaskDataset will be returned.
        """
        try:
            to_tensor = kwargs.get('to_tensor', True)
            if not datasets:
                return datasets
            if isinstance(datasets, TorchTaskDataset):
@@ -327,7 +330,8 @@ class EpochBasedTrainer(BaseTrainer):
                return datasets.to_torch_dataset(
                    task_data_config=task_data_config,
                    task_name=self.cfg.task,
                    preprocessors=preprocessor)
                    preprocessors=preprocessor,
                    to_tensor=to_tensor)
            elif isinstance(datasets, List) and isinstance(
                    datasets[0], MsDataset):
                if task_data_config is None:
@@ -341,7 +345,8 @@ class EpochBasedTrainer(BaseTrainer):
                    d.to_torch_dataset(
                        task_data_config=task_data_config,
                        task_name=self.cfg.task,
                        preprocessors=preprocessor) for d in datasets
                        preprocessors=preprocessor,
                        to_tensor=to_tensor) for d in datasets
                ]
                cfg = ConfigDict(
                    type=self.cfg.task, mode=mode, datasets=datasets)
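Read together, the trainer hunks thread one new flag through the stack: OFATrainer forces kwargs['to_tensor'] = False, EpochBasedTrainer now forwards **kwargs into to_task_dataset, and to_task_dataset reads kwargs.get('to_tensor', True) and hands it to to_torch_dataset. A stripped-down sketch of that plumbing (class names and bodies below are simplified stand-ins, not the real signatures):

# Hypothetical, minimal model of how the to_tensor flag travels; the real
# classes carry many more arguments.
class FakeMsDataset:

    def to_torch_dataset(self, to_tensor=True):
        return 'tensor batches' if to_tensor else 'raw features'


class EpochBasedTrainerSketch:

    def __init__(self, dataset, **kwargs):
        # **kwargs (possibly carrying to_tensor) is forwarded unchanged.
        self.train_dataset = self.to_task_dataset(dataset, **kwargs)

    def to_task_dataset(self, dataset, **kwargs):
        to_tensor = kwargs.get('to_tensor', True)  # default stays True
        return dataset.to_torch_dataset(to_tensor=to_tensor)


class OFATrainerSketch(EpochBasedTrainerSketch):

    def __init__(self, dataset, **kwargs):
        kwargs['to_tensor'] = False  # OFA opts out of the default conversion
        super().__init__(dataset, **kwargs)


print(EpochBasedTrainerSketch(FakeMsDataset()).train_dataset)  # tensor batches
print(OFATrainerSketch(FakeMsDataset()).train_dataset)         # raw features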
@@ -94,8 +94,11 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_text_classification_with_model(self):
        # model = Model.from_pretrained(
        #     'damo/ofa_text-classification_mnli_large_en')
        model = Model.from_pretrained(
            'damo/ofa_text-classification_mnli_large_en')
            '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_text-classification_mnli_large_en'
        )
        ofa_pipe = pipeline(Tasks.text_classification, model=model)
        text = 'One of our number will carry out your instructions minutely.'
        text2 = 'A member of my team will execute your orders with immense precision.'
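The test body above ends with the premise/hypothesis pair; a hedged sketch of the call that would typically follow (the sentence-pair tuple input format is an assumption, not shown in this diff):

# Hypothetical continuation of the test; the tuple input format is an assumption.
result = ofa_pipe((text, text2))
print(result)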
@@ -12,11 +12,10 @@ class TestOfaTrainer(unittest.TestCase):
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer(self):
        model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/maas_mnli_pretrain_ckpt'
        model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_text-classification_mnli_large_en'
        self.trainer = OFATrainer(model_id)
        self.trainer = OFATrainer(model_id, launcher='pytorch')
        self.trainer.train()
        if os.path.exists(self.trainer.work_dir):
            shutil.rmtree(self.trainer.work_dir)
        pass


if __name__ == '__main__':
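Because the trainer test now passes launcher='pytorch', it expects the usual torch.distributed environment variables at start-up; a hedged single-process setup follows (values are illustrative, and in practice torchrun or torch.distributed.launch exports them before the test runs):

# Hypothetical single-process environment for launcher='pytorch'; normally
# torchrun exports these before the test process starts.
import os

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
os.environ.setdefault('RANK', '0')
os.environ.setdefault('LOCAL_RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')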