[to #42322933] solve memory error for translation finetune

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10713843

This commit is contained in:
xiangpeng.wxp
2022-11-14 20:31:29 +08:00
committed by yingda.chen
parent 4e4faa9a30
commit d6ea41fb70
2 changed files with 47 additions and 32 deletions
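
Judging from the diff below, the memory error stems from the old training loop calling input_fn() on every training step: each iteration appended a fresh tf.data pipeline (dataset ops, vocabulary tables, iterator) to the TF1 graph, so memory grew until the process was killed. The patch builds the pipeline only once per epoch and drains it until tf.errors.OutOfRangeError, which now doubles as the epoch boundary, while iteration keeps counting steps for logging and checkpointing. A minimal sketch of the new loop shape (TF1-style session API; build_pipeline and run_step are hypothetical stand-ins for the trainer's input_fn and session.run calls):

import tensorflow as tf  # TF1.x-style graph/session API assumed


def train_by_epoch(session, build_pipeline, run_step, num_of_epochs):
    """Build the input pipeline once per epoch, then drain it.

    build_pipeline/run_step are hypothetical stand-ins for the trainer's
    input_fn(...) and session.run(self.output, feed_dict=...) calls.
    """
    iteration = 0
    for epoch in range(1, num_of_epochs + 1):
        # One pipeline per epoch instead of one per step: the graph no
        # longer grows with the number of training iterations.
        features, labels = build_pipeline(epoch=epoch)
        try:
            while True:  # drain the iterator until the dataset is exhausted
                features_batch, labels_batch = session.run([features, labels])
                iteration += 1
                run_step(features_batch, labels_batch, iteration)
        except tf.errors.OutOfRangeError:
            tf.logging.info('epoch %d end!' % epoch)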


@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 import os.path as osp
+import time
 from typing import Dict, Optional
 
 import tensorflow as tf
@@ -122,8 +123,7 @@ class CsanmtTranslationTrainer(BaseTrainer):
         self.params['scale_l1'] = self.cfg['train']['scale_l1']
         self.params['scale_l2'] = self.cfg['train']['scale_l2']
         self.params['train_max_len'] = self.cfg['train']['train_max_len']
-        self.params['max_training_steps'] = self.cfg['train'][
-            'max_training_steps']
+        self.params['num_of_epochs'] = self.cfg['train']['num_of_epochs']
         self.params['save_checkpoints_steps'] = self.cfg['train'][
             'save_checkpoints_steps']
         self.params['num_of_samples'] = self.cfg['train']['num_of_samples']
@@ -144,14 +144,15 @@ class CsanmtTranslationTrainer(BaseTrainer):
         vocab_src = osp.join(self.model_dir, self.params['vocab_src'])
         vocab_trg = osp.join(self.model_dir, self.params['vocab_trg'])
 
+        epoch = 0
         iteration = 0
         with self._session.as_default() as tf_session:
             while True:
-                iteration += 1
-                if iteration >= self.params['max_training_steps']:
+                epoch += 1
+                if epoch >= self.params['num_of_epochs']:
                     break
+                tf.logging.info('%s: Epoch %i' % (__name__, epoch))
                 train_input_fn = input_fn(
                     train_src,
                     train_trg,
@@ -160,36 +161,44 @@ class CsanmtTranslationTrainer(BaseTrainer):
                     batch_size_words=self.params['train_batch_size_words'],
                     max_len=self.params['train_max_len'],
                     num_gpus=self.params['num_gpus']
-                    if self.params['num_gpus'] > 0 else 1,
+                    if self.params['num_gpus'] > 1 else 1,
                     is_train=True,
                     session=tf_session,
-                    iteration=iteration)
+                    epoch=epoch)
                 features, labels = train_input_fn
-                features_batch, labels_batch = tf_session.run(
-                    [features, labels])
-                feed_dict = {
-                    self.source_wids: features_batch,
-                    self.target_wids: labels_batch
-                }
-                sess_outputs = self._session.run(
-                    self.output, feed_dict=feed_dict)
-                loss_step = sess_outputs['loss']
-                logger.info('Iteration: {}, step loss: {:.6f}'.format(
-                    iteration, loss_step))
-                if iteration % self.params['save_checkpoints_steps'] == 0:
-                    tf.logging.info('%s: Saving model on step: %d.' %
-                                    (__name__, iteration))
-                    ck_path = self.model_dir + 'model.ckpt'
-                    self.model_saver.save(
-                        tf_session,
-                        ck_path,
-                        global_step=tf.train.get_global_step())
-        tf.logging.info('%s: NMT training completed at time: %s.')
+                try:
+                    while True:
+                        features_batch, labels_batch = tf_session.run(
+                            [features, labels])
+                        iteration += 1
+                        feed_dict = {
+                            self.source_wids: features_batch,
+                            self.target_wids: labels_batch
+                        }
+                        sess_outputs = self._session.run(
+                            self.output, feed_dict=feed_dict)
+                        loss_step = sess_outputs['loss']
+                        logger.info('Iteration: {}, step loss: {:.6f}'.format(
+                            iteration, loss_step))
+                        if iteration % self.params[
+                                'save_checkpoints_steps'] == 0:
+                            tf.logging.info('%s: Saving model on step: %d.' %
+                                            (__name__, iteration))
+                            ck_path = self.model_dir + 'model.ckpt'
+                            self.model_saver.save(
+                                tf_session,
+                                ck_path,
+                                global_step=tf.train.get_global_step())
+                except tf.errors.OutOfRangeError:
+                    tf.logging.info('epoch %d end!' % (epoch))
+        tf.logging.info(
+            '%s: NMT training completed at time: %s.' %
+            (__name__, time.asctime(time.localtime(time.time()))))
 
     def evaluate(self,
                  checkpoint_path: Optional[str] = None,
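
The checkpoint branch inside the drained loop above saves every save_checkpoints_steps iterations through the trainer's tf.train.Saver. A self-contained sketch of that periodic-save pattern (the toy variable, /tmp path, and step counts are all illustrative, not from the patch):

import os
import os.path as osp

import tensorflow as tf  # TF1.x-style graph/session API assumed

# Toy graph: one variable plus a global step, mirroring the trainer's
# save-every-N-steps branch.
w = tf.Variable(0.0, name='w')
step = tf.train.get_or_create_global_step()
train_op = tf.group(tf.assign_add(w, 1.0), tf.assign_add(step, 1))
saver = tf.train.Saver(max_to_keep=3)

save_checkpoints_steps = 2  # stand-in for self.params['save_checkpoints_steps']
model_dir = '/tmp/csanmt_ckpts'  # stand-in for self.model_dir
os.makedirs(model_dir, exist_ok=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for iteration in range(1, 7):
        sess.run(train_op)
        if iteration % save_checkpoints_steps == 0:
            ck_path = osp.join(model_dir, 'model.ckpt')
            # Appends the evaluated global step to the filename,
            # e.g. model.ckpt-2, model.ckpt-4, ...
            saver.save(sess, ck_path, global_step=tf.train.get_global_step())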
@@ -222,7 +231,7 @@ def input_fn(src_file,
              num_gpus=1,
              is_train=True,
              session=None,
-             iteration=None):
+             epoch=None):
     src_vocab = tf.lookup.StaticVocabularyTable(
         tf.lookup.TextFileInitializer(
             src_vocab_file,
@@ -291,7 +300,7 @@ def input_fn(src_file,
     if is_train:
         session.run(iterator.initializer)
-        if iteration == 1:
+        if epoch == 1:
             session.run(tf.tables_initializer())
     return features, labels
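
The last hunk in this file splits the two initializers: iterator.initializer must be re-run for every freshly built pipeline, while tf.tables_initializer() only needs to run once per session, so it is now keyed to the first epoch rather than the first iteration. A small standalone illustration of that split (tiny throwaway vocab under /tmp, two epochs; names are illustrative, not the trainer's):

import tensorflow as tf  # TF1.x-style graph/session API assumed

# Tiny stand-in vocab so the sketch runs end to end; the real trainer
# reads vocab_src/vocab_trg from the model directory.
with open('/tmp/vocab.txt', 'w') as f:
    f.write('hello\nworld\n')

src_vocab = tf.lookup.StaticVocabularyTable(
    tf.lookup.TextFileInitializer(
        '/tmp/vocab.txt', tf.string, tf.lookup.TextFileIndex.WHOLE_LINE,
        tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER),
    num_oov_buckets=1)

dataset = tf.data.Dataset.from_tensor_slices(['hello', 'world', 'oov'])
dataset = dataset.map(src_vocab.lookup)
iterator = tf.data.make_initializable_iterator(dataset)
next_id = iterator.get_next()

with tf.Session() as sess:
    for epoch in range(1, 3):
        sess.run(iterator.initializer)  # every epoch: rewind the data
        if epoch == 1:
            sess.run(tf.tables_initializer())  # once: tables outlive the epoch
        try:
            while True:
                print('epoch', epoch, 'id', sess.run(next_id))
        except tf.errors.OutOfRangeError:
            pass  # epoch exhausted, mirroring the trainer's loop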


@@ -6,11 +6,17 @@ from modelscope.utils.test_utils import test_level
 
 class TranslationTest(unittest.TestCase):
     model_id = 'damo/nlp_csanmt_translation_zh2en'
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run_with_model_name(self):
-        trainer = CsanmtTranslationTrainer(model=self.model_id)
+    def test_run_with_model_name_for_en2zh(self):
+        model_id = 'damo/nlp_csanmt_translation_en2zh'
+        trainer = CsanmtTranslationTrainer(model=model_id)
         trainer.train()
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name_for_en2fr(self):
+        model_id = 'damo/nlp_csanmt_translation_en2fr'
+        trainer = CsanmtTranslationTrainer(model=model_id)
+        trainer.train()
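
For completeness, the new tests exercise the fixed training loop end to end; an equivalent standalone invocation is just the following (the import path is assumed from the ModelScope package layout, not shown in this diff):

from modelscope.trainers.nlp import CsanmtTranslationTrainer  # import path assumed

trainer = CsanmtTranslationTrainer(model='damo/nlp_csanmt_translation_en2zh')
trainer.train()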