Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-25 04:29:22 +01:00)
[to #42322933] solve memory error for translation finetune

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10713843

committed by: yingda.chen
parent: 4e4faa9a30
commit: d6ea41fb70
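
Why this fixes the memory error: the old train loop called input_fn on every training step, and in TF1 each call adds a fresh tf.data pipeline (dataset, iterator, and lookup ops) to the same default graph, so the graph grows with the step count until memory runs out. The new loop builds the pipeline once per epoch and drains it until tf.errors.OutOfRangeError. A minimal TF1-style sketch of the fixed pattern (illustrative only; make_dataset is a hypothetical stand-in for the trainer's input_fn):

    import tensorflow as tf

    tf.compat.v1.disable_eager_execution()

    def make_dataset():
        # Hypothetical stand-in for input_fn: builds dataset + iterator ops.
        ds = tf.data.Dataset.from_tensor_slices(tf.range(10))
        return tf.compat.v1.data.make_initializable_iterator(ds)

    # Old (leaky) shape: make_dataset() sat inside the per-step loop, so
    # new graph nodes piled up on every iteration. New shape: build once
    # per epoch, then drain the iterator until tf.errors.OutOfRangeError
    # marks the epoch boundary.
    with tf.compat.v1.Session() as sess:
        for epoch in range(1, 3):
            iterator = make_dataset()  # once per epoch, not per step
            next_batch = iterator.get_next()
            sess.run(iterator.initializer)
            try:
                while True:
                    batch = sess.run(next_batch)  # one training step per batch
            except tf.errors.OutOfRangeError:
                print('epoch %d end!' % epoch)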
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 import os.path as osp
+import time
 from typing import Dict, Optional
 
 import tensorflow as tf
@@ -122,8 +123,7 @@ class CsanmtTranslationTrainer(BaseTrainer):
         self.params['scale_l1'] = self.cfg['train']['scale_l1']
         self.params['scale_l2'] = self.cfg['train']['scale_l2']
         self.params['train_max_len'] = self.cfg['train']['train_max_len']
-        self.params['max_training_steps'] = self.cfg['train'][
-            'max_training_steps']
+        self.params['num_of_epochs'] = self.cfg['train']['num_of_epochs']
         self.params['save_checkpoints_steps'] = self.cfg['train'][
             'save_checkpoints_steps']
         self.params['num_of_samples'] = self.cfg['train']['num_of_samples']
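
For orientation, the cfg['train'] section consumed above reads the keys below (a hedged sketch; the values are illustrative placeholders, not the shipped defaults):

    # Hypothetical excerpt of cfg['train']; the keys match the lookups in
    # this diff, the values are placeholders.
    train_cfg = {
        'scale_l1': 0.0,
        'scale_l2': 0.0,
        'train_max_len': 100,
        'num_of_epochs': 10,             # replaces 'max_training_steps'
        'save_checkpoints_steps': 1000,  # checkpoint every N iterations
        'num_of_samples': 1000,
        'train_batch_size_words': 1024,  # used by input_fn below
        'num_gpus': 1,
    }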
@@ -144,14 +144,15 @@ class CsanmtTranslationTrainer(BaseTrainer):
         vocab_src = osp.join(self.model_dir, self.params['vocab_src'])
         vocab_trg = osp.join(self.model_dir, self.params['vocab_trg'])
 
+        epoch = 0
         iteration = 0
 
         with self._session.as_default() as tf_session:
             while True:
-                iteration += 1
-                if iteration >= self.params['max_training_steps']:
+                epoch += 1
+                if epoch >= self.params['num_of_epochs']:
                     break
-
+                tf.logging.info('%s: Epoch %i' % (__name__, epoch))
                 train_input_fn = input_fn(
                     train_src,
                     train_trg,
@@ -160,36 +161,44 @@ class CsanmtTranslationTrainer(BaseTrainer):
                     batch_size_words=self.params['train_batch_size_words'],
                     max_len=self.params['train_max_len'],
                     num_gpus=self.params['num_gpus']
-                    if self.params['num_gpus'] > 0 else 1,
+                    if self.params['num_gpus'] > 1 else 1,
                     is_train=True,
                     session=tf_session,
-                    iteration=iteration)
+                    epoch=epoch)
 
                 features, labels = train_input_fn
 
-                features_batch, labels_batch = tf_session.run(
-                    [features, labels])
-                feed_dict = {
-                    self.source_wids: features_batch,
-                    self.target_wids: labels_batch
-                }
-                sess_outputs = self._session.run(
-                    self.output, feed_dict=feed_dict)
-                loss_step = sess_outputs['loss']
-                logger.info('Iteration: {}, step loss: {:.6f}'.format(
-                    iteration, loss_step))
-
-                if iteration % self.params['save_checkpoints_steps'] == 0:
-                    tf.logging.info('%s: Saving model on step: %d.' %
-                                    (__name__, iteration))
-                    ck_path = self.model_dir + 'model.ckpt'
-                    self.model_saver.save(
-                        tf_session,
-                        ck_path,
-                        global_step=tf.train.get_global_step())
+                try:
+                    while True:
+                        features_batch, labels_batch = tf_session.run(
+                            [features, labels])
+                        iteration += 1
+                        feed_dict = {
+                            self.source_wids: features_batch,
+                            self.target_wids: labels_batch
+                        }
+                        sess_outputs = self._session.run(
+                            self.output, feed_dict=feed_dict)
+                        loss_step = sess_outputs['loss']
+                        logger.info('Iteration: {}, step loss: {:.6f}'.format(
+                            iteration, loss_step))
+                        if iteration % self.params[
+                                'save_checkpoints_steps'] == 0:
+                            tf.logging.info('%s: Saving model on step: %d.' %
+                                            (__name__, iteration))
+                            ck_path = self.model_dir + 'model.ckpt'
+                            self.model_saver.save(
+                                tf_session,
+                                ck_path,
+                                global_step=tf.train.get_global_step())
+                except tf.errors.OutOfRangeError:
+                    tf.logging.info('epoch %d end!' % (epoch))
 
-        tf.logging.info('%s: NMT training completed at time: %s.')
+        tf.logging.info(
+            '%s: NMT training completed at time: %s.' %
+            (__name__, time.asctime(time.localtime(time.time()))))
 
     def evaluate(self,
                  checkpoint_path: Optional[str] = None,
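
Note that checkpointing stays keyed off the global iteration counter, which now advances per batch inside the epoch loop, so save_checkpoints_steps is measured in optimizer steps and is unaffected by epoch boundaries. (The hunk also fixes the final log line, which previously had % placeholders but no arguments; hence the new import time.) A self-contained sketch of the save-every-N-steps pattern (the directory, N, and the toy graph are placeholders, not the trainer's values):

    import os

    import tensorflow as tf

    tf.compat.v1.disable_eager_execution()

    model_dir = '/tmp/ckpt_demo'
    os.makedirs(model_dir, exist_ok=True)
    N = 100  # stands in for params['save_checkpoints_steps']

    x = tf.compat.v1.get_variable(
        'x', shape=[], initializer=tf.compat.v1.zeros_initializer())
    global_step = tf.compat.v1.train.get_or_create_global_step()
    train_op = tf.group(x.assign_add(1.0), global_step.assign_add(1))
    saver = tf.compat.v1.train.Saver()

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        for iteration in range(1, 301):
            sess.run(train_op)  # stands in for one optimizer step
            if iteration % N == 0:
                # Mirrors model_saver.save(..., global_step=get_global_step())
                saver.save(sess, os.path.join(model_dir, 'model.ckpt'),
                           global_step=global_step)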
@@ -222,7 +231,7 @@ def input_fn(src_file,
              num_gpus=1,
              is_train=True,
              session=None,
-             iteration=None):
+             epoch=None):
     src_vocab = tf.lookup.StaticVocabularyTable(
         tf.lookup.TextFileInitializer(
             src_vocab_file,
@@ -291,7 +300,7 @@ def input_fn(src_file,
 
     if is_train:
         session.run(iterator.initializer)
-        if iteration == 1:
+        if epoch == 1:
             session.run(tf.tables_initializer())
     return features, labels
 
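
A detail worth calling out: iterator.initializer still runs on every call (each epoch rewinds the dataset), while tf.tables_initializer() runs only on the first pass, since the vocabulary lookup tables persist for the life of the session; the guard just tracks the renamed parameter, firing on epoch == 1 instead of iteration == 1. A self-contained sketch of that table lifecycle, modeled on the src_vocab construction visible above (the index arguments and vocab.txt, written on the spot with one token per line, are assumptions):

    import tensorflow as tf

    tf.compat.v1.disable_eager_execution()

    # Tiny placeholder vocab: one token per line.
    with open('vocab.txt', 'w') as f:
        f.write('hello\nworld\n')

    # Token -> line number, with out-of-vocabulary tokens hashed into an
    # extra bucket.
    table = tf.lookup.StaticVocabularyTable(
        tf.lookup.TextFileInitializer('vocab.txt', tf.string,
                                      tf.lookup.TextFileIndex.WHOLE_LINE,
                                      tf.int64,
                                      tf.lookup.TextFileIndex.LINE_NUMBER),
        num_oov_buckets=1)

    ids = table.lookup(tf.constant(['hello', 'unknown']))
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.tables_initializer())  # once per session
        print(sess.run(ids))  # [0 2] -- 'unknown' lands in the OOV bucket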
@@ -6,11 +6,17 @@ from modelscope.utils.test_utils import test_level
 
 
 class TranslationTest(unittest.TestCase):
-    model_id = 'damo/nlp_csanmt_translation_zh2en'
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run_with_model_name(self):
-        trainer = CsanmtTranslationTrainer(model=self.model_id)
+    def test_run_with_model_name_for_en2zh(self):
+        model_id = 'damo/nlp_csanmt_translation_en2zh'
+        trainer = CsanmtTranslationTrainer(model=model_id)
         trainer.train()
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name_for_en2fr(self):
+        model_id = 'damo/nlp_csanmt_translation_en2fr'
+        trainer = CsanmtTranslationTrainer(model=model_id)
+        trainer.train()
 
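
The tests double as minimal usage examples; stripped of the harness, a finetuning run reduces to the sketch below (the import path is an assumption based on this test module, and pulling 'damo/nlp_csanmt_translation_en2zh' requires access to the ModelScope model hub):

    from modelscope.trainers.nlp import CsanmtTranslationTrainer

    trainer = CsanmtTranslationTrainer(
        model='damo/nlp_csanmt_translation_en2zh')
    trainer.train()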