test finetune
@@ -164,6 +164,7 @@ class Trainers(object):

    # multi-modal trainers
    clip_multi_modal_embedding = 'clip-multi-modal-embedding'
    ofa_tasks = 'ofa-tasks-trainer'

    # cv trainers
    image_instance_segmentation = 'image-instance-segmentation'

@@ -398,10 +398,27 @@ class SequenceGenerator(nn.Module):
            if self.should_set_src_lengths:
                self.search.set_src_lengths(src_lengths)

            if self.repeat_ngram_blocker is not None and step > prefix_tokens.size(
                    1):
                lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz,
                                                   beam_size, step)
            if self.repeat_ngram_blocker is not None:
                # process prefix_tokens
                p_toks_len = prefix_tokens.ne(self.pad).sum(
                    dim=1) if prefix_tokens is not None else None
                if p_toks_len is not None:
                    p_toks_len_beam = p_toks_len.unsqueeze(-1).repeat(
                        1, beam_size).view(-1)
                    no_repeat_ngram_size = self.repeat_ngram_blocker.no_repeat_ngram_size
                    out_prefix = p_toks_len_beam < (
                        step + no_repeat_ngram_size - 1)
                else:
                    out_prefix = [True] * bsz * beam_size
                ngram_blocker_tokens = tokens[out_prefix]
                ngram_blocker_lprobs = lprobs[out_prefix]
                ngram_blocker_bsz = out_prefix.sum() // beam_size
                lprobs[out_prefix] = self.repeat_ngram_blocker(
                    tokens=ngram_blocker_tokens,
                    lprobs=ngram_blocker_lprobs,
                    bsz=ngram_blocker_bsz,
                    beam_size=beam_size,
                    step=step)

            # Shape: (batch, cand_size)
            cand_scores, cand_indices, cand_beams = self.search.step(

@@ -19,6 +19,7 @@ from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
from packaging import version
from torch import Tensor, nn
from torch.nn import functional as F
from transformers.activations import ACT2FN
@@ -40,6 +41,8 @@ logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = 'ofa-base'
_CONFIG_FOR_DOC = 'OFAConfig'
_TOKENIZER_FOR_DOC = 'OFATokenizer'
TORCH_VERSION = version.parse(torch.__version__)
TORCH_MESH_GRID_WARNING_VERSION = version.parse('1.9.1')

DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
@@ -114,8 +117,11 @@ def make_image_bucket_position(bucket_size, num_relative_distance):
    """
    coords_h = torch.arange(bucket_size)
    coords_w = torch.arange(bucket_size)
    coords = torch.stack(torch.meshgrid([coords_h, coords_w],
                                        indexing='ij'))  # 2, Wh, Ww
    if TORCH_VERSION > TORCH_MESH_GRID_WARNING_VERSION:
        coords = torch.stack(
            torch.meshgrid([coords_h, coords_w], indexing='ij'))  # 2, Wh, Ww
    else:
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - \
        coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww

@@ -11,7 +11,7 @@ from modelscope.metainfo import Preprocessors
from modelscope.pipelines.base import Input
from modelscope.preprocessors.image import load_image
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModelFile, Tasks
from modelscope.utils.constant import Fields, ModeKeys, ModelFile, Tasks
from .base import Preprocessor
from .builder import PREPROCESSORS
from .ofa import *  # noqa
@@ -27,11 +27,16 @@ __all__ = [
    Fields.multi_modal, module_name=Preprocessors.ofa_tasks_preprocessor)
class OfaPreprocessor(Preprocessor):

    def __init__(self, model_dir: str, *args, **kwargs):
    def __init__(self,
                 model_dir: str,
                 mode=ModeKeys.INFERENCE,
                 *args,
                 **kwargs):
        """preprocess the data

        Args:
            model_dir (str): model path
            mode: preprocessor mode (model mode)
        """
        super().__init__(*args, **kwargs)
        preprocess_mapping = {
@@ -59,8 +64,8 @@ class OfaPreprocessor(Preprocessor):
            model_dir)
        self.cfg = Config.from_file(
            osp.join(model_dir, ModelFile.CONFIGURATION))
        self.preprocess = preprocess_mapping[self.cfg.task](self.cfg,
                                                            model_dir)
        self.preprocess = preprocess_mapping[self.cfg.task](
            cfg=self.cfg, model_dir=model_dir, mode=mode)
        self.keys = input_key_mapping[self.cfg.task]
        self.tokenizer = self.preprocess.tokenizer

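For reference, a minimal sketch of constructing the preprocessor under the new mode argument; the model directory below is a placeholder, and the keys accepted at call time still come from the task-specific input_key_mapping:

from modelscope.preprocessors.multi_modal import OfaPreprocessor
from modelscope.utils.constant import ModeKeys

# 'path/to/ofa_model' is a placeholder for a downloaded OFA model directory.
train_pre = OfaPreprocessor(model_dir='path/to/ofa_model', mode=ModeKeys.TRAIN)
infer_pre = OfaPreprocessor(model_dir='path/to/ofa_model')  # mode defaults to ModeKeys.INFERENCE
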
@@ -13,7 +13,7 @@ from .utils.random_help import set_torch_seed


class OfaBasePreprocessor:

    def __init__(self, cfg, model_dir, split, *args, **kwargs):
    def __init__(self, cfg, model_dir, mode, *args, **kwargs):
        """preprocess the data via the vocab.txt from the `model_dir` path

        Args:
@@ -21,6 +21,7 @@ class OfaBasePreprocessor:
            model_dir (str): model path
        """
        self.cfg = cfg
        self.mode = mode
        self.language = self.cfg.model.get('language', 'en')
        if self.language == 'en':
            tokenizer = OFATokenizer.from_pretrained(model_dir)

@@ -12,16 +12,21 @@ from .base import OfaBasePreprocessor


class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):

    def __init__(self, cfg, model_dir, split, *args, **kwargs):
    def __init__(self,
                 cfg,
                 model_dir,
                 mode=ModeKeys.INFERENCE,
                 *args,
                 **kwargs):
        """preprocess the data

        Args:
            cfg(modelscope.utils.config.ConfigDict) : model config
            model_dir (str): model path,
            split: data phase
            mode: preprocessor mode (model mode)
        """
        super(OfaImageCaptioningPreprocessor,
              self).__init__(cfg, model_dir, split, *args, **kwargs)
              self).__init__(cfg, model_dir, mode, *args, **kwargs)
        # Initialize transform
        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert('RGB'),

@@ -6,21 +6,27 @@ from PIL import Image
from torchvision import transforms

from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor


class OfaImageClassificationPreprocessor(OfaBasePreprocessor):

    def __init__(self, cfg, model_dir, split, *args, **kwargs):
    def __init__(self,
                 cfg,
                 model_dir,
                 mode=ModeKeys.INFERENCE,
                 *args,
                 **kwargs):
        """preprocess the data

        Args:
            cfg(modelscope.utils.config.ConfigDict) : model config
            model_dir (str): model path,
            split: data phase
            mode: preprocessor mode (model mode)
        """
        super(OfaImageClassificationPreprocessor,
              self).__init__(cfg, model_dir, split, *args, **kwargs)
              self).__init__(cfg, model_dir, mode, *args, **kwargs)
        # Initialize transform
        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert('RGB'),

@@ -1,21 +1,27 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor


class OfaSummarizationPreprocessor(OfaBasePreprocessor):

    def __init__(self, cfg, model_dir, split, *args, **kwargs):
    def __init__(self,
                 cfg,
                 model_dir,
                 mode=ModeKeys.INFERENCE,
                 *args,
                 **kwargs):
        """preprocess the data

        Args:
            cfg(modelscope.utils.config.ConfigDict) : model config
            model_dir (str): model path,
            split: data phase
            mode: preprocessor mode (model mode)
        """
        super(OfaSummarizationPreprocessor,
              self).__init__(cfg, model_dir, split, *args, **kwargs)
              self).__init__(cfg, model_dir, mode, *args, **kwargs)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        source = super().pre_caption(

@@ -1,21 +1,27 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor


class OfaTextClassificationPreprocessor(OfaBasePreprocessor):

    def __init__(self, cfg, model_dir, split, *args, **kwargs):
    def __init__(self,
                 cfg,
                 model_dir,
                 mode=ModeKeys.INFERENCE,
                 *args,
                 **kwargs):
        """preprocess the data

        Args:
            cfg(modelscope.utils.config.ConfigDict) : model config
            model_dir (str): model path,
            split: data phase
            mode: preprocessor mode (model mode)
        """
        super(OfaTextClassificationPreprocessor,
              self).__init__(cfg, model_dir, split, *args, **kwargs)
              self).__init__(cfg, model_dir, mode, *args, **kwargs)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        text1 = ' '.join(

@@ -3,21 +3,27 @@ from typing import Any, Dict

import torch

from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor


class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor):

    def __init__(self, cfg, model_dir, split, *args, **kwargs):
    def __init__(self,
                 cfg,
                 model_dir,
                 mode=ModeKeys.INFERENCE,
                 *args,
                 **kwargs):
        """preprocess the data

        Args:
            cfg(modelscope.utils.config.ConfigDict) : model config
            model_dir (str): model path,
            split: data phase
            mode: preprocessor mode (model mode)
        """
        super(OfaTextToImageSynthesisPreprocessor,
              self).__init__(cfg, model_dir, split, *args, **kwargs)
              self).__init__(cfg, model_dir, mode, *args, **kwargs)
        self.max_src_length = 64

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:

@@ -6,21 +6,27 @@ from PIL import Image
from torchvision import transforms

from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor


class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):

    def __init__(self, cfg, model_dir, split, *args, **kwargs):
    def __init__(self,
                 cfg,
                 model_dir,
                 mode=ModeKeys.INFERENCE,
                 *args,
                 **kwargs):
        """preprocess the data

        Args:
            cfg(modelscope.utils.config.ConfigDict) : model config
            model_dir (str): model path,
            split: data phase
            mode: preprocessor mode (model mode)
        """
        super(OfaVisualEntailmentPreprocessor,
              self).__init__(cfg, model_dir, split, *args, **kwargs)
              self).__init__(cfg, model_dir, mode, *args, **kwargs)
        # Initialize transform
        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert('RGB'),

@@ -6,21 +6,27 @@ from PIL import Image
from torchvision import transforms

from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor


class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):

    def __init__(self, cfg, model_dir, split, *args, **kwargs):
    def __init__(self,
                 cfg,
                 model_dir,
                 mode=ModeKeys.INFERENCE,
                 *args,
                 **kwargs):
        """preprocess the data

        Args:
            cfg(modelscope.utils.config.ConfigDict) : model config
            model_dir (str): model path,
            split: data phase
            mode: preprocessor mode (model mode)
        """
        super(OfaVisualGroundingPreprocessor,
              self).__init__(cfg, model_dir, split, *args, **kwargs)
              self).__init__(cfg, model_dir, mode, *args, **kwargs)
        # Initialize transform
        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert('RGB'),

@@ -6,21 +6,27 @@ from PIL import Image
from torchvision import transforms

from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor


class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):

    def __init__(self, cfg, model_dir, split, *args, **kwargs):
    def __init__(self,
                 cfg,
                 model_dir,
                 mode=ModeKeys.INFERENCE,
                 *args,
                 **kwargs):
        """preprocess the data

        Args:
            cfg(modelscope.utils.config.ConfigDict) : model config
            model_dir (str): model path,
            split: data phase
            mode: preprocessor mode (model mode)
        """
        super(OfaVisualQuestionAnsweringPreprocessor,
              self).__init__(cfg, model_dir, split, *args, **kwargs)
              self).__init__(cfg, model_dir, mode, *args, **kwargs)
        # Initialize transform
        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert('RGB'),

@@ -0,0 +1 @@
from .ofa_trainer import OFATrainer

@@ -78,6 +78,8 @@ class OFAFileDataset:
                self.lineid_to_offset.append(offset)
                self.total_row_count += 1
                offset += len(line.encode('utf-8'))
            pickle.dump(self.lineid_to_offset,
                        open('{}.index'.format(self.file_path), 'wb'))
            self._compute_start_pos_and_row_count()
            print(
                'local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping'

@@ -0,0 +1,120 @@
import os
from os import path as osp
from typing import Dict, Optional

import torch
import torch.distributed as dist
import transformers
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from modelscope.metainfo import Trainers
from modelscope.models.base import Model
from modelscope.preprocessors.multi_modal import OfaPreprocessor
from modelscope.preprocessors.ofa.utils.collate import collate_fn
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.constant import ModeKeys, ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import init_dist
from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
                                OFADataset, get_schedule)

logger = get_logger()


@TRAINERS.register_module(module_name=Trainers.ofa_tasks)
class OFATrainer(BaseTrainer):

    def __init__(self, model: str, *args, **kwargs):
        model = Model.from_pretrained(model)
        super().__init__(osp.join(model.model_dir, ModelFile.CONFIGURATION))
        self.model_dir = model.model_dir
        self.model = model.model
        self.device_id = 0
        self.total_epoch = self.cfg.train.epoch
        self.train_batch_size = self.cfg.train.batch_size
        self.val_batch_size = self.cfg.evaluation.batch_size
        self.save_dir = self.cfg.train.save_dir
        init_dist(launcher='pytorch')
        self.train_dataset = OFADataset(
            file_path=self.cfg.dataset.train_set,
            selected_id_keys=self.cfg.dataset.selected_id_keys,
            preprocessor=OfaPreprocessor(
                model_dir=self.model_dir, mode=ModeKeys.TRAIN),
        )
        self.val_dataset = OFADataset(
            file_path=self.cfg.dataset.valid_set,
            selected_id_keys=self.cfg.dataset.selected_id_keys,
            preprocessor=OfaPreprocessor(
                model_dir=self.model_dir, mode=ModeKeys.EVAL),
        )
        epoch_steps = len(
            self.train_dataset) // self.cfg.train.gradient_accumulation_steps
        self.cfg.train.num_train_steps = epoch_steps * self.cfg.train.epoch
        self.criterion = AdjustLabelSmoothedCrossEntropyCriterion(
            self.cfg.train.criterion)

    def train(self, *args, **kwargs):
        assert dist.is_initialized()

        self.model.train()
        self.model.to(self.device_id)
        ddp_model = torch.nn.parallel.DistributedDataParallel(
            self.model, device_ids=[
                self.device_id,
            ])

        optimizer = transformers.AdamW(
            self.model.parameters(),
            lr=self.cfg.train.lr,
            weight_decay=self.cfg.train.weight_decay,
            correct_bias=False,
        )
        scheduler_class, scheduler_args = get_schedule(self.cfg.train)
        if scheduler_class is not None:
            lr_scheduler = scheduler_class(**{'optimizer': optimizer},
                                           **scheduler_args)
        else:
            lr_scheduler = None
        for epoch in range(self.total_epoch):
            train_sampler = DistributedSampler(
                dataset=self.train_dataset, shuffle=True)
            train_sampler.set_epoch(epoch)

            train_params = {
                'pin_memory': True,
                'collate_fn': collate_fn,
                'batch_size': self.train_batch_size,
                'shuffle': False,
                'drop_last': True,
                'sampler': train_sampler,
                'num_workers': 2,
            }

            train_loader = DataLoader(self.train_dataset, **train_params)

            for idx, batch in enumerate(train_loader, start=1):
                model_outputs = ddp_model(**batch)
                loss, sample_size, logging_output = self.criterion(
                    model_outputs, batch)
                loss.backward()
                optimizer.step()
                if lr_scheduler is not None:
                    lr_scheduler.step()
                optimizer.zero_grad()
                if idx % 10 == 0:
                    logger.info(
                        'epoch: {}, train batch {}/{}, loss={:.5f}'.format(
                            epoch, idx, len(train_loader), loss.item()))
            if dist.get_rank() == 0:
                os.makedirs(self.save_dir, exist_ok=True)
                torch.save(ddp_model.module.state_dict(),
                           f'{self.save_dir}/epoch{epoch}.bin')

    def evaluate(self,
                 checkpoint_path: Optional[str] = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        pass

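For reference, a minimal sketch of driving the new trainer; the import path and model id are illustrative rather than taken from this commit, and because __init__ calls init_dist(launcher='pytorch'), the script is expected to run under a torch.distributed launcher such as torchrun:

# train_ofa.py -- launch with: torchrun --nproc_per_node=1 train_ofa.py
from modelscope.trainers.multi_modal.ofa.ofa_trainer import OFATrainer  # hypothetical path

# A ModelScope model id or local model directory; this id is a placeholder.
trainer = OFATrainer('damo/ofa_image-caption_coco_large_en')
trainer.train()
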
@@ -2,36 +2,36 @@
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the Apache 2.0 license
|
||||
# found in the LICENSE file in the root directory.
|
||||
from os import path as osp
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from torch.nn.modules.loss import _Loss
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.preprocessors.multi_modal import OfaPreprocessor
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import Fields, ModeKeys, ModelFile, Tasks
|
||||
from .ofa_file_dataset import OFAFileDataset
|
||||
|
||||
|
||||
class OFADataset(Dataset):
|
||||
|
||||
def __init__(self,
|
||||
model_dir,
|
||||
file_path,
|
||||
file_path: str,
|
||||
preprocessor: OfaPreprocessor,
|
||||
selected_id_keys: str,
|
||||
dtypes=None,
|
||||
separator='\t',
|
||||
cached_index=False,
|
||||
split=ModeKeys.TRAIN,
|
||||
**kwargs):
|
||||
self.cfg = Config.from_file(
|
||||
osp.join(model_dir, ModelFile.CONFIGURATION))
|
||||
selected_col_ids = self.cfg.dataset.selected_col_ids
|
||||
selected_col_keys = self.cfg.dataset.selected_col_keys
|
||||
|
||||
assert selected_col_ids is not None
|
||||
assert selected_col_keys is not None
|
||||
self.selected_col_key_l = selected_col_keys.split(',')
|
||||
assert len(self.selected_col_key_l) == len(selected_col_ids.split(','))
|
||||
assert selected_id_keys is not None
|
||||
selected_col_ids = list()
|
||||
selected_col_keys = list()
|
||||
for id_key in selected_id_keys.split(','):
|
||||
id, key = id_key.split(':')
|
||||
selected_col_ids.append(id)
|
||||
selected_col_keys.append(key)
|
||||
|
||||
self.dataset = OFAFileDataset(
|
||||
file_path=file_path,
|
||||
@@ -39,14 +39,278 @@ class OFADataset(Dataset):
|
||||
dtypes=dtypes,
|
||||
separator=separator,
|
||||
cached_index=cached_index)
|
||||
self.preprocessor = OfaPreprocessor(model_dir, split)
|
||||
self.preprocessor = preprocessor
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, index):
|
||||
value_l = self.dataset[index]
|
||||
values = self.dataset[index]
|
||||
data = dict()
|
||||
for key, value in zip(self.selected_col_key_l, value_l):
|
||||
for key, value in zip(self.selected_col_keys, values):
|
||||
data[key] = value
|
||||
return self.preprocessor(data)
|
||||
|
||||
|
||||
def construct_rdrop_sample(x):
|
||||
if isinstance(x, dict):
|
||||
for key in x:
|
||||
x[key] = construct_rdrop_sample(x[key])
|
||||
return x
|
||||
elif isinstance(x, torch.Tensor):
|
||||
return x.repeat(2, *([1] * (x.dim() - 1)))
|
||||
elif isinstance(x, int):
|
||||
return x * 2
|
||||
elif isinstance(x, np.ndarray):
|
||||
return x.repeat(2)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def kl_loss(p, q):
|
||||
p_loss = F.kl_div(p, torch.exp(q), reduction='sum')
|
||||
q_loss = F.kl_div(q, torch.exp(p), reduction='sum')
|
||||
loss = (p_loss + q_loss) / 2
|
||||
return loss
|
||||
|
||||
|
||||
def label_smoothed_nll_loss(lprobs,
|
||||
target,
|
||||
epsilon,
|
||||
update_num,
|
||||
reduce=True,
|
||||
drop_worst_ratio=0.0,
|
||||
drop_worst_after=0,
|
||||
use_rdrop=False,
|
||||
reg_alpha=1.0,
|
||||
constraint_masks=None,
|
||||
constraint_start=None,
|
||||
constraint_end=None):
|
||||
if target.dim() == lprobs.dim() - 1:
|
||||
target = target.unsqueeze(-1)
|
||||
nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
|
||||
if constraint_masks is not None:
|
||||
smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum(
|
||||
dim=-1, keepdim=True).squeeze(-1)
|
||||
eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6)
|
||||
elif constraint_start is not None and constraint_end is not None:
|
||||
constraint_range = [0, 1, 2, 3] + list(
|
||||
range(constraint_start, constraint_end))
|
||||
smooth_loss = -lprobs[:, constraint_range].sum(
|
||||
dim=-1, keepdim=True).squeeze(-1)
|
||||
eps_i = epsilon / (len(constraint_range) - 1 + 1e-6)
|
||||
else:
|
||||
smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1)
|
||||
eps_i = epsilon / (lprobs.size(-1) - 1)
|
||||
loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
|
||||
if drop_worst_ratio > 0 and update_num > drop_worst_after:
|
||||
if use_rdrop:
|
||||
true_batch_size = loss.size(0) // 2
|
||||
_, indices = torch.topk(
|
||||
loss[:true_batch_size],
|
||||
k=int(true_batch_size * (1 - drop_worst_ratio)),
|
||||
largest=False)
|
||||
loss = torch.cat([loss[indices], loss[indices + true_batch_size]])
|
||||
nll_loss = torch.cat(
|
||||
[nll_loss[indices], nll_loss[indices + true_batch_size]])
|
||||
lprobs = torch.cat(
|
||||
[lprobs[indices], lprobs[indices + true_batch_size]])
|
||||
else:
|
||||
loss, indices = torch.topk(
|
||||
loss,
|
||||
k=int(loss.shape[0] * (1 - drop_worst_ratio)),
|
||||
largest=False)
|
||||
nll_loss = nll_loss[indices]
|
||||
lprobs = lprobs[indices]
|
||||
|
||||
ntokens = loss.numel()
|
||||
nll_loss = nll_loss.sum()
|
||||
loss = loss.sum()
|
||||
if use_rdrop:
|
||||
true_batch_size = lprobs.size(0) // 2
|
||||
p = lprobs[:true_batch_size]
|
||||
q = lprobs[true_batch_size:]
|
||||
if constraint_start is not None and constraint_end is not None:
|
||||
constraint_range = [0, 1, 2, 3] + list(
|
||||
range(constraint_start, constraint_end))
|
||||
p = p[:, constraint_range]
|
||||
q = q[:, constraint_range]
|
||||
loss += kl_loss(p, q) * reg_alpha
|
||||
|
||||
return loss, nll_loss, ntokens
|
||||
|
||||
|
||||
class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.sentence_avg = args.sentence_avg
|
||||
self.eps = args.label_smoothing
|
||||
self.ignore_prefix_size = args.ignore_prefix_size
|
||||
self.ignore_eos = args.ignore_eos
|
||||
self.report_accuracy = args.report_accuracy
|
||||
self.drop_worst_ratio = args.drop_worst_ratio
|
||||
self.drop_worst_after = args.drop_worst_after
|
||||
self.use_rdrop = args.use_rdrop
|
||||
self.reg_alpha = args.reg_alpha
|
||||
self.sample_patch_num = args.sample_patch_num
|
||||
|
||||
self.constraint_start = None
|
||||
self.constraint_end = None
|
||||
if args.constraint_range is not None:
|
||||
constraint_start, constraint_end = args.constraint_range.split(',')
|
||||
self.constraint_start = int(constraint_start)
|
||||
self.constraint_end = int(constraint_end)
|
||||
self.padding_idx = args.tokenizer.pad_token_id
|
||||
self.args = args
|
||||
|
||||
def forward(self, output, sample, update_num=0, reduce=True):
|
||||
"""Compute the loss for the given sample.
|
||||
|
||||
Returns a tuple with three elements:
|
||||
1) the loss
|
||||
2) the sample size, which is used as the denominator for the gradient
|
||||
3) logging outputs to display while training
|
||||
"""
|
||||
if isinstance(sample, list):
|
||||
if self.sample_patch_num > 0:
|
||||
sample[0]['net_input'][
|
||||
'sample_patch_num'] = self.sample_patch_num
|
||||
loss_v1, sample_size_v1, logging_output_v1 = self.forward(
|
||||
output[0], sample[0], update_num, reduce)
|
||||
loss_v2, sample_size_v2, logging_output_v2 = self.forward(
|
||||
output[1], sample[1], update_num, reduce)
|
||||
loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2
|
||||
sample_size = 1
|
||||
logging_output = {
|
||||
'loss':
|
||||
loss.data,
|
||||
'loss_v1':
|
||||
loss_v1.data,
|
||||
'loss_v2':
|
||||
loss_v2.data,
|
||||
'nll_loss':
|
||||
logging_output_v1['nll_loss'].data / sample_size_v1
|
||||
+ logging_output_v2['nll_loss'].data / sample_size_v2,
|
||||
'ntokens':
|
||||
logging_output_v1['ntokens'] + logging_output_v2['ntokens'],
|
||||
'nsentences':
|
||||
logging_output_v1['nsentences']
|
||||
+ logging_output_v2['nsentences'],
|
||||
'sample_size':
|
||||
1,
|
||||
'sample_size_v1':
|
||||
sample_size_v1,
|
||||
'sample_size_v2':
|
||||
sample_size_v2,
|
||||
}
|
||||
return loss, sample_size, logging_output
|
||||
|
||||
if self.use_rdrop:
|
||||
construct_rdrop_sample(sample)
|
||||
|
||||
net_output = output
|
||||
# model(**sample["net_input"])
|
||||
loss, nll_loss, ntokens = self.compute_loss(
|
||||
net_output, sample, update_num, reduce=reduce)
|
||||
sample_size = (
|
||||
sample['target'].size(0) if self.sentence_avg else ntokens)
|
||||
logging_output = {
|
||||
'loss': loss.data,
|
||||
'nll_loss': nll_loss.data,
|
||||
'ntokens': sample['ntokens'],
|
||||
'nsentences': sample['nsentences'],
|
||||
'sample_size': sample_size,
|
||||
}
|
||||
return loss, sample_size, logging_output
|
||||
|
||||
def get_lprobs_and_target(self, net_output, sample):
|
||||
conf = sample['conf'][:, None, None] if 'conf' in sample and sample[
|
||||
'conf'] is not None else 1
|
||||
constraint_masks = None
|
||||
if 'constraint_masks' in sample and sample[
|
||||
'constraint_masks'] is not None:
|
||||
constraint_masks = sample['constraint_masks']
|
||||
net_output[0].masked_fill_(~constraint_masks, -math.inf)
|
||||
if self.constraint_start is not None and self.constraint_end is not None:
|
||||
net_output[0][:, :, 4:self.constraint_start] = -math.inf
|
||||
net_output[0][:, :, self.constraint_end:] = -math.inf
|
||||
lprobs = F.log_softmax(
|
||||
net_output[0], dim=-1, dtype=torch.float32) * conf
|
||||
target = sample['target']
|
||||
if self.ignore_prefix_size > 0:
|
||||
lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
|
||||
target = target[:, self.ignore_prefix_size:].contiguous()
|
||||
if constraint_masks is not None:
|
||||
constraint_masks = constraint_masks[:, self.ignore_prefix_size:, :].contiguous() # yapf: disable
|
||||
if self.ignore_eos:
|
||||
bsz, seq_len, embed_dim = lprobs.size()
|
||||
eos_indices = target.eq(self.task.tgt_dict.eos())
|
||||
lprobs = lprobs[~eos_indices].reshape(bsz, seq_len - 1, embed_dim)
|
||||
target = target[~eos_indices].reshape(bsz, seq_len - 1)
|
||||
if constraint_masks is not None:
|
||||
constraint_masks = constraint_masks[~eos_indices].reshape(
|
||||
bsz, seq_len - 1, embed_dim)
|
||||
if constraint_masks is not None:
|
||||
constraint_masks = constraint_masks.view(-1,
|
||||
constraint_masks.size(-1))
|
||||
return lprobs.view(-1,
|
||||
lprobs.size(-1)), target.view(-1), constraint_masks
|
||||
|
||||
def compute_loss(self, net_output, sample, update_num, reduce=True):
|
||||
lprobs, target, constraint_masks = self.get_lprobs_and_target(
|
||||
net_output, sample)
|
||||
if constraint_masks is not None:
|
||||
constraint_masks = constraint_masks[target != self.padding_idx]
|
||||
lprobs = lprobs[target != self.padding_idx]
|
||||
target = target[target != self.padding_idx]
|
||||
loss, nll_loss, ntokens = label_smoothed_nll_loss(
|
||||
lprobs,
|
||||
target,
|
||||
self.eps,
|
||||
update_num,
|
||||
reduce=reduce,
|
||||
drop_worst_ratio=self.drop_worst_ratio,
|
||||
drop_worst_after=self.drop_worst_after,
|
||||
use_rdrop=self.use_rdrop,
|
||||
reg_alpha=self.reg_alpha,
|
||||
constraint_masks=constraint_masks,
|
||||
constraint_start=self.constraint_start,
|
||||
constraint_end=self.constraint_end)
|
||||
return loss, nll_loss, ntokens
|
||||
|
||||
|
||||
def get_schedule(args):
|
||||
|
||||
if args.schedule == 'const':
|
||||
scheduler_class = transformers.get_constant_schedule_with_warmup
|
||||
scheduler_args = {
|
||||
'num_warmup_steps':
|
||||
int(args.warmup_proportion * args.num_train_steps)
|
||||
}
|
||||
elif args.schedule == 'linear':
|
||||
scheduler_class = transformers.get_linear_schedule_with_warmup
|
||||
scheduler_args = {
|
||||
'num_warmup_steps':
|
||||
int(args.warmup_proportion * args.num_train_steps),
|
||||
'num_training_steps': args.num_train_steps
|
||||
}
|
||||
elif args.schedule == 'cosine':
|
||||
scheduler_class = transformers.get_cosine_schedule_with_warmup
|
||||
scheduler_args = {
|
||||
'num_warmup_steps':
|
||||
int(args.warmup_proportion * args.num_train_steps),
|
||||
'num_training_steps': args.num_train_steps
|
||||
}
|
||||
elif args.schedule == 'polynomial_decay':
|
||||
scheduler_class = transformers.get_polynomial_decay_schedule_with_warmup
|
||||
scheduler_args = {
|
||||
'num_warmup_steps':
|
||||
int(args.warmup_proportion * args.num_train_steps),
|
||||
'num_training_steps': args.num_train_steps,
|
||||
'lr_end': args.lr_end
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return scheduler_class, scheduler_args
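
For reference, a minimal sketch of how get_schedule is consumed (mirroring the call in OFATrainer.train), using a hypothetical namespace in place of cfg.train and assuming the definitions above are available in the same module:

from types import SimpleNamespace

import torch

# Hypothetical stand-in for cfg.train; only the fields read by get_schedule are set.
train_cfg = SimpleNamespace(
    schedule='linear', warmup_proportion=0.01, num_train_steps=10000)

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=5e-5)

scheduler_class, scheduler_args = get_schedule(train_cfg)
# Equivalent to transformers.get_linear_schedule_with_warmup(optimizer, 100, 10000).
lr_scheduler = scheduler_class(**{'optimizer': optimizer}, **scheduler_args)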
14  modelscope/utils/multi_modal/fp16/__init__.py  Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .fp16 import FP16_Module, FP16_Optimizer
|
||||
655  modelscope/utils/multi_modal/fp16/fp16.py  Executable file
@@ -0,0 +1,655 @@
|
||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Stable version of apex FP16 Optimizer"""
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .fp16util import (master_params_to_model_params,
|
||||
model_grads_to_master_grads)
|
||||
from .loss_scaler import DynamicLossScaler, LossScaler
|
||||
|
||||
FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
|
||||
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
|
||||
|
||||
|
||||
def conversion_helper(val, conversion):
|
||||
"""Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
|
||||
if not isinstance(val, (tuple, list)):
|
||||
return conversion(val)
|
||||
rtn = [conversion_helper(v, conversion) for v in val]
|
||||
if isinstance(val, tuple):
|
||||
rtn = tuple(rtn)
|
||||
return rtn
|
||||
|
||||
|
||||
def fp32_to_fp16(val):
|
||||
"""Convert fp32 `val` to fp16"""
|
||||
|
||||
def half_conversion(val):
|
||||
val_typecheck = val
|
||||
if isinstance(val_typecheck, (Parameter, Variable)):
|
||||
val_typecheck = val.data
|
||||
if isinstance(val_typecheck, FLOAT_TYPES):
|
||||
val = val.half()
|
||||
return val
|
||||
|
||||
return conversion_helper(val, half_conversion)
|
||||
|
||||
|
||||
def fp16_to_fp32(val):
|
||||
"""Convert fp16 `val` to fp32"""
|
||||
|
||||
def float_conversion(val):
|
||||
val_typecheck = val
|
||||
if isinstance(val_typecheck, (Parameter, Variable)):
|
||||
val_typecheck = val.data
|
||||
if isinstance(val_typecheck, HALF_TYPES):
|
||||
val = val.float()
|
||||
return val
|
||||
|
||||
return conversion_helper(val, float_conversion)
|
||||
|
||||
|
||||
class FP16_Module(nn.Module):
|
||||
|
||||
def __init__(self, module):
|
||||
super(FP16_Module, self).__init__()
|
||||
self.add_module('module', module.half())
|
||||
|
||||
def forward(self, *inputs, **kwargs):
|
||||
return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
||||
return self.module.state_dict(destination, prefix, keep_vars)
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True):
|
||||
self.module.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
|
||||
class FP16_Optimizer(object):
|
||||
"""
|
||||
:class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer,
|
||||
and manage static or dynamic loss scaling and master weights in a manner transparent to the user.
|
||||
For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance,
|
||||
and changing the call to ``backward``.
|
||||
|
||||
Example::
|
||||
|
||||
model = torch.nn.Linear(D_in, D_out).cuda().half()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
|
||||
# Name the FP16_Optimizer instance to replace the existing optimizer
|
||||
# (recommended but not required):
|
||||
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
|
||||
...
|
||||
# loss.backward() becomes:
|
||||
optimizer.backward(loss)
|
||||
...
|
||||
|
||||
Example with dynamic loss scaling::
|
||||
|
||||
...
|
||||
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
|
||||
# optional arg to control dynamic loss scaling behavior
|
||||
# dynamic_loss_args={'scale_window' : 500})
|
||||
# Usually, dynamic_loss_args is not necessary.
|
||||
|
||||
Args:
|
||||
init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`. # noqa
|
||||
static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate. # noqa
|
||||
dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option. # noqa
|
||||
dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used. # noqa
|
||||
verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling. # noqa
|
||||
|
||||
``init_optimizer`` is expected to have been constructed in the ordinary way.
|
||||
It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be
|
||||
named to replace ``init_optimizer``, for two reasons:
|
||||
First, it means that references to the same name
|
||||
later in the file will not have to change.
|
||||
Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to
|
||||
modify ``init_optimizer``. If you do choose a unique name for the new
|
||||
:class:`FP16_Optimizer` instance, you should only work with this new instance,
|
||||
because the preexisting optimizer might no longer behave as expected.
|
||||
|
||||
``init_optimizer`` may be any Pytorch optimizer.
|
||||
It may contain a mixture of fp16 and fp32 parameters organized into any number of
|
||||
``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will
|
||||
ingest these ``param_groups`` and remember them.
|
||||
|
||||
Calls to ::
|
||||
|
||||
loss.backward()
|
||||
|
||||
must be replaced with ::
|
||||
|
||||
optimizer.backward(loss)
|
||||
|
||||
because :class:`FP16_Optimizer` requires ownership of the backward pass to implement
|
||||
loss scaling and copies to master gradients.
|
||||
|
||||
.. note::
|
||||
Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients
|
||||
are downscaled before being applied. This means that adjusting the loss scale, or using
|
||||
dynamic loss scaling, should not require retuning the learning rate or any other
|
||||
hyperparameters.
|
||||
|
||||
|
||||
**Advanced options**
|
||||
|
||||
**Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure.
|
||||
See docstring for :attr:`step`.
|
||||
|
||||
**Gradient clipping**: Use :attr:`clip_master_grads`.
|
||||
|
||||
**Multiple losses**: If your model accumulates gradients from multiple losses,
|
||||
this can be made more efficient by supplying ``update_master_grads=False``
|
||||
to :attr:`backward`. See docstring for :attr:`backward`.
|
||||
|
||||
**Manually adjusting loss scale**: The current loss scale can be retrieved or set via ::
|
||||
|
||||
print(optimizer.loss_scale)
|
||||
optimizer.loss_scale = new_loss_scale
|
||||
|
||||
For static loss scaling, manually adjusting the loss scale over time is a reasonable
|
||||
thing to do. During later epochs, gradients may become smaller, and a
|
||||
higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss
|
||||
scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting
|
||||
the loss scale is not recommended.
|
||||
|
||||
**Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in
|
||||
Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer`
|
||||
should still work as intended.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
init_optimizer,
|
||||
static_loss_scale=1.0,
|
||||
dynamic_loss_scale=False,
|
||||
dynamic_loss_args=None,
|
||||
verbose=False):
|
||||
if not torch.cuda.is_available():
|
||||
raise SystemError('Cannot use fp16 without CUDA.')
|
||||
|
||||
self.verbose = verbose
|
||||
|
||||
self.optimizer = init_optimizer
|
||||
# init_state_dict sets up an alternative way to cast per-param state tensors.
|
||||
# Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary.
|
||||
# init_state_dict = init_optimizer.state_dict()
|
||||
|
||||
self.fp16_groups = []
|
||||
self.fp32_from_fp16_groups = []
|
||||
self.fp32_from_fp32_groups = []
|
||||
for i, param_group in enumerate(self.optimizer.param_groups):
|
||||
self.maybe_print(
|
||||
'FP16_Optimizer processing param group {}:'.format(i))
|
||||
fp16_params_this_group = []
|
||||
fp32_params_this_group = []
|
||||
fp32_from_fp16_params_this_group = []
|
||||
for i, param in enumerate(param_group['params']):
|
||||
if param.requires_grad:
|
||||
if param.type() == 'torch.cuda.HalfTensor':
|
||||
self.maybe_print(
|
||||
'FP16_Optimizer received torch.cuda.HalfTensor with {}'
|
||||
.format(param.size()))
|
||||
fp16_params_this_group.append(param)
|
||||
master_param = param.detach().clone().float()
|
||||
master_param.requires_grad = True
|
||||
# Copy the model parallel flag.
|
||||
master_param.model_parallel = param.model_parallel
|
||||
param_group['params'][i] = master_param
|
||||
fp32_from_fp16_params_this_group.append(master_param)
|
||||
# Reset existing state dict key to the new master param.
|
||||
# We still need to recast per-param state tensors, if any, to FP32.
|
||||
if param in self.optimizer.state:
|
||||
self.optimizer.state[
|
||||
master_param] = self.optimizer.state.pop(param)
|
||||
elif param.type() == 'torch.cuda.FloatTensor':
|
||||
self.maybe_print(
|
||||
'FP16_Optimizer received torch.cuda.FloatTensor with {}'
|
||||
.format(param.size()))
|
||||
fp32_params_this_group.append(param)
|
||||
param_group['params'][i] = param
|
||||
else:
|
||||
raise TypeError(
|
||||
'Wrapped parameters must be either '
|
||||
'torch.cuda.FloatTensor or torch.cuda.HalfTensor. '
|
||||
'Received {}'.format(param.type()))
|
||||
|
||||
self.fp16_groups.append(fp16_params_this_group)
|
||||
self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
|
||||
self.fp32_from_fp32_groups.append(fp32_params_this_group)
|
||||
|
||||
# Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
|
||||
self.optimizer.load_state_dict(self.optimizer.state_dict())
|
||||
# alternative way to cast per-param state tensors:
|
||||
# self.optimizer.load_state_dict(init_state_dict)
|
||||
|
||||
if dynamic_loss_scale:
|
||||
self.dynamic_loss_scale = True
|
||||
if dynamic_loss_args is not None:
|
||||
self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
|
||||
else:
|
||||
self.loss_scaler = DynamicLossScaler()
|
||||
else:
|
||||
self.dynamic_loss_scale = False
|
||||
self.loss_scaler = LossScaler(static_loss_scale)
|
||||
|
||||
self.overflow = False
|
||||
self.first_closure_call_this_step = True
|
||||
|
||||
self.clip_grad_norm = nn.utils.clip_grad.clip_grad_norm_
|
||||
|
||||
def maybe_print(self, msg):
|
||||
if self.verbose:
|
||||
print(msg)
|
||||
|
||||
def __getstate__(self):
|
||||
raise RuntimeError(
|
||||
'FP16_Optimizer should be serialized using state_dict().')
|
||||
|
||||
def __setstate__(self, state):
|
||||
raise RuntimeError(
|
||||
'FP16_Optimizer should be deserialized using load_state_dict().')
|
||||
|
||||
def zero_grad(self, set_grads_to_None=False):
|
||||
"""
|
||||
Zero fp32 and fp16 parameter grads.
|
||||
"""
|
||||
# In principle, only the .grad attributes of the model params need to be zeroed,
|
||||
# because gradients are copied into the FP32 master params. However, we zero
|
||||
# all gradients owned by the optimizer, just to be safe:
|
||||
for group in self.optimizer.param_groups:
|
||||
for p in group['params']:
|
||||
if set_grads_to_None:
|
||||
p.grad = None
|
||||
else:
|
||||
if p.grad is not None:
|
||||
p.grad.detach_()
|
||||
p.grad.zero_()
|
||||
|
||||
# Zero fp16 gradients owned by the model:
|
||||
for fp16_group in self.fp16_groups:
|
||||
for param in fp16_group:
|
||||
if set_grads_to_None:
|
||||
param.grad = None
|
||||
else:
|
||||
if param.grad is not None:
|
||||
param.grad.detach_(
|
||||
) # as in torch.optim.optimizer.zero_grad()
|
||||
param.grad.zero_()
|
||||
|
||||
def _check_overflow(self):
|
||||
params = []
|
||||
for group in self.fp16_groups:
|
||||
for param in group:
|
||||
params.append(param)
|
||||
for group in self.fp32_from_fp32_groups:
|
||||
for param in group:
|
||||
params.append(param)
|
||||
self.overflow = self.loss_scaler.has_overflow(params)
|
||||
|
||||
def _update_scale(self, has_overflow=False):
|
||||
self.loss_scaler.update_scale(has_overflow)
|
||||
|
||||
def _master_params_to_model_params(self):
|
||||
for fp16_group, fp32_from_fp16_group in zip(
|
||||
self.fp16_groups, self.fp32_from_fp16_groups):
|
||||
master_params_to_model_params(fp16_group, fp32_from_fp16_group)
|
||||
|
||||
def _model_params_to_master_params(self):
|
||||
for fp16_group, fp32_from_fp16_group in zip(
|
||||
self.fp16_groups, self.fp32_from_fp16_groups):
|
||||
master_params_to_model_params(fp32_from_fp16_group, fp16_group)
|
||||
|
||||
# To consider: Integrate distributed with this wrapper by registering a hook on each variable
|
||||
# that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.
|
||||
def _model_grads_to_master_grads(self):
|
||||
for fp16_group, fp32_from_fp16_group in zip(
|
||||
self.fp16_groups, self.fp32_from_fp16_groups):
|
||||
model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)
|
||||
|
||||
def _downscale_master(self):
|
||||
if self.loss_scale != 1.0:
|
||||
for group in self.optimizer.param_groups:
|
||||
for param in group['params']:
|
||||
if param.grad is not None:
|
||||
param.grad.data.mul_(1. / self.loss_scale)
|
||||
|
||||
def clip_master_grads(self, max_norm, norm_type=2):
|
||||
"""
|
||||
Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``.
|
||||
|
||||
Args:
|
||||
max_norm (float or int): max norm of the gradients
|
||||
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
|
||||
infinity norm.
|
||||
|
||||
Returns:
|
||||
Total norm of the current fp32 gradients (viewed as a single vector).
|
||||
|
||||
.. warning::
|
||||
Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``). # noqa
|
||||
"""
|
||||
if not self.overflow:
|
||||
fp32_params = []
|
||||
for param_group in self.optimizer.param_groups:
|
||||
for param in param_group['params']:
|
||||
fp32_params.append(param)
|
||||
return self.clip_grad_norm(fp32_params, max_norm, norm_type)
|
||||
else:
|
||||
return -1
|
||||
|
||||
def state_dict(self):
|
||||
"""
|
||||
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
|
||||
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
|
||||
of the contained Pytorch optimizer.
|
||||
Example::
|
||||
|
||||
checkpoint = {}
|
||||
checkpoint['model'] = model.state_dict()
|
||||
checkpoint['optimizer'] = optimizer.state_dict()
|
||||
torch.save(checkpoint, "saved.pth")
|
||||
"""
|
||||
state_dict = {}
|
||||
state_dict['loss_scaler'] = self.loss_scaler
|
||||
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
|
||||
state_dict['overflow'] = self.overflow
|
||||
state_dict[
|
||||
'first_closure_call_this_step'] = self.first_closure_call_this_step
|
||||
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
|
||||
state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""
|
||||
Loads a state_dict created by an earlier call to state_dict().
|
||||
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
|
||||
whose parameters in turn came from ``model``, it is expected that the user
|
||||
will call ``model.load_state_dict()`` before
|
||||
``fp16_optimizer_instance.load_state_dict()`` is called.
|
||||
|
||||
Example::
|
||||
|
||||
model = torch.nn.Linear(D_in, D_out).cuda().half()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
|
||||
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
|
||||
...
|
||||
checkpoint = torch.load("saved.pth")
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
"""
|
||||
# I think it should actually be ok to reload the optimizer before the model.
|
||||
self.loss_scaler = state_dict['loss_scaler']
|
||||
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
|
||||
self.overflow = state_dict['overflow']
|
||||
self.first_closure_call_this_step = state_dict[
|
||||
'first_closure_call_this_step']
|
||||
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
|
||||
# At this point, the optimizer's references to the model's fp32 parameters are up to date.
|
||||
# The optimizer's hyperparameters and internal buffers are also up to date.
|
||||
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
|
||||
# out of date. There are two options.
|
||||
# 1: Refresh the master params from the model's fp16 params.
|
||||
# This requires less storage but incurs precision loss.
|
||||
# 2: Save and restore the fp32 master copies separately.
|
||||
# We choose option 2.
|
||||
#
|
||||
# Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
|
||||
# of their associated parameters, because it's possible those buffers might not exist yet in
|
||||
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
|
||||
# constructed in the same way as the one whose state_dict we are loading, the same master params
|
||||
# are guaranteed to exist, so we can just copy_() from the saved master params.
|
||||
for current_group, saved_group in zip(self.fp32_from_fp16_groups,
|
||||
state_dict['fp32_from_fp16']):
|
||||
for current, saved in zip(current_group, saved_group):
|
||||
current.data.copy_(saved.data)
|
||||
|
||||
def step(self, closure=None): # could add clip option.
|
||||
"""
|
||||
If no closure is supplied, :attr:`step` should be called after
|
||||
``fp16_optimizer_obj.backward(loss)``.
|
||||
:attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to
|
||||
:class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params
|
||||
originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run
|
||||
another forward pass using their model.
|
||||
|
||||
If a closure is supplied, :attr:`step` may be called without a prior call to
|
||||
:attr:`backward(loss)`.
|
||||
This control flow is identical to `ordinary Pytorch optimizer use`_ with closures.
|
||||
However, the user should take care that any ``loss.backward()`` call within the closure
|
||||
has been replaced by ``fp16_optimizer_obj.backward(loss)``.
|
||||
|
||||
Args:
|
||||
closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss. # noqa
|
||||
|
||||
Example with closure::
|
||||
|
||||
# optimizer is assumed to be an FP16_Optimizer object, previously constructed from an
|
||||
# existing pytorch optimizer.
|
||||
for input, target in dataset:
|
||||
def closure():
|
||||
optimizer.zero_grad()
|
||||
output = model(input)
|
||||
loss = loss_fn(output, target)
|
||||
# loss.backward() becomes:
|
||||
optimizer.backward(loss)
|
||||
return loss
|
||||
optimizer.step(closure)
|
||||
|
||||
.. warning::
|
||||
Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling.
|
||||
|
||||
.. _`ordinary Pytorch optimizer use`:
|
||||
http://pytorch.org/docs/master/optim.html#optimizer-step-closure
|
||||
"""
|
||||
|
||||
scale = self.loss_scaler.loss_scale
|
||||
self._update_scale(self.overflow)
|
||||
|
||||
if self.overflow:
|
||||
self.maybe_print(
|
||||
'OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}'
|
||||
.format(scale, self.loss_scale))
|
||||
return
|
||||
|
||||
if closure is not None:
|
||||
retval = self._step_with_closure(closure)
|
||||
else:
|
||||
retval = self.optimizer.step()
|
||||
|
||||
self._master_params_to_model_params()
|
||||
|
||||
return retval
|
||||
|
||||
def _step_with_closure(self, closure):
|
||||
|
||||
def wrapped_closure():
|
||||
# helpful for debugging
|
||||
# print("Calling wrapped_closure, first_closure_call_this_step = {}"
|
||||
# .format(self.first_closure_call_this_step))
|
||||
if self.first_closure_call_this_step:
|
||||
# We expect that the fp16 params are initially fresh on entering self.step(),
|
||||
# so _master_params_to_model_params() is unnecessary the first time wrapped_closure()
|
||||
# is called within self.optimizer.step().
|
||||
self.first_closure_call_this_step = False
|
||||
else:
|
||||
# If self.optimizer.step() internally calls wrapped_closure more than once,
|
||||
# it may update the fp32 params after each call. However, self.optimizer
|
||||
# doesn't know about the fp16 params at all. If the fp32 params get updated,
|
||||
# we can't rely on self.optimizer to refresh the fp16 params. We need
|
||||
# to handle that manually:
|
||||
self._master_params_to_model_params()
|
||||
# Our API expects the user to give us ownership of the backward() call by
|
||||
# replacing all calls to loss.backward() with optimizer.backward(loss).
|
||||
# This requirement holds whether or not the call to backward() is made within a closure.
|
||||
# If the user is properly calling optimizer.backward(loss) within "closure,"
|
||||
# calling closure() here will give the fp32 master params fresh gradients
|
||||
# for the optimizer to play with, so all wrapped_closure needs to do is call
|
||||
# closure() and return the loss.
|
||||
temp_loss = closure()
|
||||
while (self.overflow):
|
||||
scale = self.loss_scaler.loss_scale
|
||||
self._update_scale(self.overflow)
|
||||
self.maybe_print(
|
||||
'OVERFLOW within closure! Skipping step. Attempted loss scale: {}, '
|
||||
'reducing to {}'.format(scale, self.loss_scale))
|
||||
temp_loss = closure()
|
||||
return temp_loss
|
||||
|
||||
retval = self.optimizer.step(wrapped_closure)
|
||||
|
||||
self.first_closure_call_this_step = True
|
||||
|
||||
return retval
|
||||
|
||||
    def backward(self, loss, update_master_grads=True, retain_graph=False):
        """
        :attr:`backward` performs the following conceptual steps:

        1. fp32_loss = loss.float() (see first Note below)
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined). # noqa
        4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32. # noqa
        5. Finally, master grads are divided by loss_scale.

        In this way, after :attr:`backward`, the master params have fresh gradients,
        and :attr:`step` may be called.

        .. note::
            :attr:`backward` internally converts the loss to fp32 before applying the loss scale.
            This provides some additional safety against overflow if the user has supplied an
            fp16 loss value.
            However, for maximum overflow safety, the user should
            compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to
            :attr:`backward`.

        .. warning::
            The gradients found in a model's leaves after the call to
            :attr:`backward` should not be regarded as valid in general,
            because it's possible
            they have been scaled (and in the case of dynamic loss scaling,
            the scale factor may change over time).
            If the user wants to inspect gradients after a call to :attr:`backward`,
            only the master gradients should be regarded as valid. These can be retrieved via
            :attr:`inspect_master_grad_data()`.

        Args:
            loss: The loss output by the user's model. loss may be either float or half (but see first Note above).
            update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`. # noqa
            retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below). # noqa

        Example::

            # Ordinary operation:
            optimizer.backward(loss)

            # Naive operation with multiple losses (technically valid, but less efficient):
            # fp32 grads will be correct after the second call, but
            # the first call incurs an unnecessary fp16->fp32 grad copy.
            optimizer.backward(loss1)
            optimizer.backward(loss2)

            # More efficient way to handle multiple losses:
            # The fp16->fp32 grad copy is delayed until fp16 grads from all
            # losses have been accumulated.
            optimizer.backward(loss1, update_master_grads=False)
            optimizer.backward(loss2, update_master_grads=False)
            optimizer.update_master_grads()
        """
        # To consider: try multiple backward passes using retain_graph=True to find
        # a loss scale that works. After you find a loss scale that works, do a final dummy
        # backward pass with retain_graph=False to tear down the graph. Doing this would avoid
        # discarding the iteration, but probably wouldn't improve overall efficiency.
        self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
        if update_master_grads:
            self.update_master_grads()

    def update_master_grads(self):
        """
        Copy the ``.grad`` attribute from stored references to fp16 parameters to
        the ``.grad`` attribute of the fp32 master parameters that are directly
        updated by the optimizer. :attr:`update_master_grads` only needs to be called if
        ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``.
        """
        if self.dynamic_loss_scale:
            self._check_overflow()
            if self.overflow:
                return
        self._model_grads_to_master_grads()
        self._downscale_master()

    def inspect_master_grad_data(self):
        """
        When running with :class:`FP16_Optimizer`,
        ``.grad`` attributes of a model's fp16 leaves should not be
        regarded as truthful, because they might be scaled.
        After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered,
        the fp32 master params' ``.grad``
        attributes will contain valid gradients properly divided by the loss scale. However,
        because :class:`FP16_Optimizer` flattens some parameters, accessing them may be
        nonintuitive. :attr:`inspect_master_grad_data`
        allows those gradients to be viewed with shapes corresponding to their associated model leaves.

        Returns:
            List of lists (one list for each parameter group). The list for each parameter group
            is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group.
        """
        if self.overflow:
            print(
                'Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. '
                'Gradients are currently invalid (may be inf, nan, or stale). Returning None.'
            )
            return None
        else:
            # The optimizer owns only references to master params.
            master_grads_data = []
            for param_group in self.optimizer.param_groups:
                master_grads_this_group = []
                for param in param_group['params']:
                    if param.grad is not None:
                        master_grads_this_group.append(param.grad.data)
                    else:
                        master_grads_this_group.append(None)
                master_grads_data.append(master_grads_this_group)
            return master_grads_data
    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
    def _get_loss_scale(self):
        return self.loss_scaler.loss_scale

    def _set_loss_scale(self, value):
        self.loss_scaler.cur_scale = value

    loss_scale = property(_get_loss_scale, _set_loss_scale)

    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)
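##############################################################
# Example usage sketch -- not part of the original file.
# A minimal, hedged illustration of how FP16_Optimizer is meant to be driven,
# based only on the docstrings above. The model, criterion, data loader,
# zero_grad() call and constructor arguments (e.g. dynamic_loss_scale=True,
# documented in loss_scaler.py below) are assumptions for illustration.
##############################################################
"""
import torch

model = MyFP16Model().cuda().half()          # hypothetical fp16 model
base_optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(base_optimizer, dynamic_loss_scale=True)

for inputs, targets in loader:               # hypothetical data loader
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    # Replace loss.backward() with optimizer.backward(loss) so the loss scale
    # is applied and fp16 grads are copied to the fp32 master params.
    optimizer.backward(loss)
    # step() skips the update and reduces the scale if an overflow was detected.
    optimizer.step()

    # Only the master gradients are valid after backward():
    master_grads = optimizer.inspect_master_grad_data()

# Closure-based optimizers (e.g. LBFGS) must also call optimizer.backward(loss)
# inside the closure, then pass the closure to optimizer.step(closure).
"""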
modelscope/utils/multi_modal/fp16/fp16util.py (new file, 216 lines)
@@ -0,0 +1,216 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.autograd import Variable


class tofp16(nn.Module):
    """
    Utility module that implements::

        def forward(self, input):
            return input.half()
    """

    def __init__(self):
        super(tofp16, self).__init__()

    def forward(self, input):
        return input.half()


def BN_convert_float(module):
    """
    Utility function for network_to_half().

    Retained for legacy purposes.
    """
    if isinstance(
            module,
            torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
        module.float()
    for child in module.children():
        BN_convert_float(child)
    return module


def network_to_half(network):
    """
    Convert model to half precision in a batchnorm-safe way.

    Retained for legacy purposes. It is recommended to use FP16Model.
    """
    return nn.Sequential(tofp16(), BN_convert_float(network.half()))


def convert_module(module, dtype):
    """
    Converts a module's immediate parameters and buffers to dtype.
    """
    for param in module.parameters(recurse=False):
        if param is not None:
            if param.data.dtype.is_floating_point:
                param.data = param.data.to(dtype=dtype)
            if param._grad is not None and param._grad.data.dtype.is_floating_point:
                param._grad.data = param._grad.data.to(dtype=dtype)

    for buf in module.buffers(recurse=False):
        if buf is not None and buf.data.dtype.is_floating_point:
            buf.data = buf.data.to(dtype=dtype)


def convert_network(network, dtype):
    """
    Converts a network's parameters and buffers to dtype.
    """
    for module in network.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm
                      ) and module.affine is True:
            continue
        convert_module(module, dtype)
    return network


class FP16Model(nn.Module):
    """
    Convert model to half precision in a batchnorm-safe way.
    """

    def __init__(self, network):
        super(FP16Model, self).__init__()
        self.network = convert_network(network, dtype=torch.half)

    def forward(self, *inputs):
        inputs = tuple(t.half() for t in inputs)
        return self.network(*inputs)

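##############################################################
# Example usage sketch -- not part of the original file.
# A minimal, hedged illustration of the batchnorm-safe half conversion
# provided by network_to_half / FP16Model above; torchvision's resnet18 is
# used purely as a stand-in model.
##############################################################
"""
import torch
import torchvision

net = torchvision.models.resnet18()

# Legacy path: prepend a tofp16() input cast and keep affine BatchNorm layers in fp32.
half_net = network_to_half(net)

# Preferred path: FP16Model converts parameters/buffers (except BatchNorm) to
# half and casts incoming tensors to half in forward().
fp16_net = FP16Model(torchvision.models.resnet18())
out = fp16_net(torch.randn(1, 3, 224, 224))   # forward pass runs in fp16
"""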
def backwards_debug_hook(grad):
    raise RuntimeError(
        'master_params received a gradient in the backward pass!')


def prep_param_lists(model, flat_master=False):
    """
    Creates a list of FP32 master parameters for a given model, as in
    `Training Neural Networks with Mixed Precision: Real Examples`_.

    Args:
        model (torch.nn.Module): Existing Pytorch model
        flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. # noqa
    Returns:
        A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master parameters. If ``flat_master=True``, ``master_params`` will be a list with one element. # noqa

    Example::

        model_params, master_params = prep_param_lists(model)

    .. warning::
        Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. # noqa

    .. _`Training Neural Networks with Mixed Precision: Real Examples`:
        http://on-demand.gputechconf.com/gtc/2018/video/S81012/
    """
    model_params = [
        param for param in model.parameters() if param.requires_grad
    ]

    if flat_master:
        # Give the user some more useful error messages
        try:
            # flatten_dense_tensors returns a contiguous flat array.
            # http://pytorch.org/docs/master/_modules/torch/_utils.html
            master_params = _flatten_dense_tensors(
                [param.data for param in model_params]).float()
        except Exception:
            print(
                'Error in prep_param_lists: model may contain a mixture of parameters '
                'of different types. Use flat_master=False, or use FP16_Optimizer.'
            )
            raise
        master_params = torch.nn.Parameter(master_params)
        master_params.requires_grad = True
        # master_params.register_hook(backwards_debug_hook)
        if master_params.grad is None:
            master_params.grad = master_params.new(*master_params.size())
        return model_params, [master_params]
    else:
        master_params = [
            param.clone().float().detach() for param in model_params
        ]
        for param in master_params:
            param.requires_grad = True
        return model_params, master_params


def model_grads_to_master_grads(model_params,
                                master_params,
                                flat_master=False):
    """
    Copy model gradients to master gradients.

    Args:
        model_params: List of model parameters created by :func:`prep_param_lists`.
        master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. # noqa
    """
    if flat_master:
        # The flattening may incur one more deep copy than is necessary.
        master_params[0].grad.data.copy_(
            _flatten_dense_tensors([p.grad.data for p in model_params]))
    else:
        for model, master in zip(model_params, master_params):
            if model.grad is not None:
                if master.grad is None:
                    master.grad = Variable(
                        master.data.new(*master.data.size()))
                master.grad.data.copy_(model.grad.data)
            else:
                master.grad = None


def master_params_to_model_params(model_params,
                                  master_params,
                                  flat_master=False):
    """
    Copy master parameters to model parameters.

    Args:
        model_params: List of model parameters created by :func:`prep_param_lists`.
        master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. # noqa
    """
    if flat_master:
        for model, master in zip(
                model_params,
                _unflatten_dense_tensors(master_params[0].data, model_params)):
            model.data.copy_(master)
    else:
        for model, master in zip(model_params, master_params):
            model.data.copy_(master.data)

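##############################################################
# Example usage sketch -- not part of the original file.
# A minimal, hedged illustration of how the three helpers above are typically
# combined into a manual fp16 training step (static loss scale, no overflow
# handling). The model, criterion, data loader and loss_scale value are
# assumptions for illustration.
##############################################################
"""
import torch

model = MyModel().cuda().half()              # hypothetical fp16 model
model_params, master_params = prep_param_lists(model)
optimizer = torch.optim.SGD(master_params, lr=1e-3)  # optimizer sees fp32 masters
loss_scale = 128.0

for inputs, targets in loader:               # hypothetical data loader
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    (loss.float() * loss_scale).backward()   # scaled backward on the fp16 model

    # Copy fp16 grads to the fp32 masters, unscale, and step in fp32.
    model_grads_to_master_grads(model_params, master_params)
    for master in master_params:
        if master.grad is not None:
            master.grad.data.mul_(1.0 / loss_scale)
    optimizer.step()

    # Copy the updated fp32 masters back into the fp16 model.
    master_params_to_model_params(model_params, master_params)
"""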
# Backward compatibility fixes


def to_python_float(t):
    if hasattr(t, 'item'):
        return t.item()
    else:
        return t[0]


TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
modelscope/utils/multi_modal/fp16/loss_scaler.py (new executable file, 237 lines)
@@ -0,0 +1,237 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
    if hasattr(t, 'item'):
        return t.item()
    else:
        return t[0]


class LossScaler:
    """
    Class that manages a static loss scale. This class is intended to interact with
    :class:`FP16_Optimizer`, and should not be directly manipulated by the user.

    Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
    :class:`FP16_Optimizer`'s constructor.

    Args:
        scale (float, optional, default=1.0): The loss scale.
    """

    def __init__(self, scale=1):
        self.cur_scale = scale

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        return False

    # `x` is a torch.Tensor
    def _has_inf_or_nan(x):
        return False

    def update_scale(self, overflow):
        pass

    @property
    def loss_scale(self):
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        return tuple(self.loss_scale * g for g in grad_in)

    def backward(self, loss, retain_graph=False):
        scaled_loss = loss * self.loss_scale
        scaled_loss.backward(retain_graph=retain_graph)

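##############################################################
# Example usage sketch -- not part of the original file.
# Per the docstring above, LossScaler is not used directly; this is a hedged
# sketch of how a static scale would be requested through FP16_Optimizer's
# constructor (base_optimizer and compute_loss are stand-ins).
##############################################################
"""
optimizer = FP16_Optimizer(base_optimizer, static_loss_scale=128.0)
loss = compute_loss()        # hypothetical loss computation
optimizer.backward(loss)     # the loss is multiplied by the fixed scale of 128
optimizer.step()             # master grads are divided by 128 before the update
"""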
class DynamicLossScaler:
    """
    Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
    indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
    :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler`
    operates, because the default options can be changed using the
    ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.

    Loss scaling is designed to combat the problem of underflowing gradients encountered late
    in the training of fp16 networks. Dynamic loss scaling begins by attempting a very high loss
    scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are
    encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has
    occurred.
    :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
    and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
    If a certain number of iterations occur without overflowing gradients detected,
    :class:`DynamicLossScaler` increases the loss scale once more.
    In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
    always using the highest loss scale possible without incurring overflow.

    Args:
        init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.`
        scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. # noqa
        scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. # noqa
    """

    def __init__(self,
                 init_scale=2**32,
                 scale_factor=2.,
                 scale_window=1000,
                 min_scale=1,
                 delayed_shift=1,
                 consecutive_hysteresis=False):
        self.cur_scale = init_scale
        self.cur_iter = 0
        self.last_overflow_iter = -1
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.min_scale = min_scale
        self.delayed_shift = delayed_shift
        self.cur_hysteresis = delayed_shift
        self.consecutive_hysteresis = consecutive_hysteresis

    # `params` is a list / generator of torch.Variable
    def has_overflow_serial(self, params):
        for p in params:
            if p.grad is not None and DynamicLossScaler._has_inf_or_nan(
                    p.grad.data):
                return True

        return False

    def has_overflow(self, params):
        overflow = self.has_overflow_serial(params)
        overflow_gpu = torch.cuda.ByteTensor([overflow])
        overflow = overflow_gpu[0].item()
        return bool(overflow)

    # `x` is a torch.Tensor
    def _has_inf_or_nan(x):
        try:
            # if x is half, the .float() incurs an additional deep copy, but it's necessary if
            # Pytorch's .sum() creates a one-element tensor of the same type as x
            # (which is true for some recent version of pytorch).
            cpu_sum = float(x.float().sum())
            # More efficient version that can be used if .sum() returns a Python scalar
            # cpu_sum = float(x.sum())
        except RuntimeError as instance:
            # We want to check if instance is actually an overflow exception.
            # RuntimeError could come from a different error.
            # If so, we still want the exception to propagate.
            if 'value cannot be converted' not in instance.args[0]:
                raise
            return True
        else:
            if cpu_sum == float(
                    'inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
                return True
            return False

    # `overflow` is boolean indicating whether the gradient overflowed
    def update_scale(self, overflow):

        if not hasattr(self, 'min_scale'):
            self.min_scale = 1
        if not hasattr(self, 'delayed_shift'):
            self.delayed_shift = 1
        if not hasattr(self, 'cur_hysteresis'):
            self.cur_hysteresis = 1
        if not hasattr(self, 'consecutive_hysteresis'):
            self.consecutive_hysteresis = True
        if overflow:
            # self.cur_scale /= self.scale_factor
            if self.delayed_shift == 1 or self.cur_hysteresis == 1:
                self.cur_scale = max(self.cur_scale / self.scale_factor,
                                     self.min_scale)
            else:
                self.cur_hysteresis -= 1
            self.last_overflow_iter = self.cur_iter
        else:
            if self.consecutive_hysteresis:
                self.cur_hysteresis = self.delayed_shift
            if (self.cur_iter
                    - self.last_overflow_iter) % self.scale_window == 0:
                if not self.consecutive_hysteresis:
                    self.cur_hysteresis = self.delayed_shift
                self.cur_scale *= self.scale_factor
        self.cur_iter += 1

    @property
    def loss_scale(self):
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        return tuple(self.loss_scale * g for g in grad_in)

    def backward(self, loss, retain_graph=False):
        scaled_loss = loss * self.loss_scale
        scaled_loss.backward(retain_graph=retain_graph)

##############################################################
# Example usage below here -- assuming it's in a separate file
##############################################################
"""
TO-DO separate out into an example.
if __name__ == "__main__":
    import torch
    from torch.autograd import Variable
    from dynamic_loss_scaler import DynamicLossScaler

    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Create random Tensors to hold inputs and outputs, and wrap them in Variables.
    x = Variable(torch.randn(N, D_in), requires_grad=False)
    y = Variable(torch.randn(N, D_out), requires_grad=False)

    w1 = Variable(torch.randn(D_in, H), requires_grad=True)
    w2 = Variable(torch.randn(H, D_out), requires_grad=True)
    parameters = [w1, w2]

    learning_rate = 1e-6
    optimizer = torch.optim.SGD(parameters, lr=learning_rate)
    loss_scaler = DynamicLossScaler()

    for t in range(500):
        y_pred = x.mm(w1).clamp(min=0).mm(w2)
        loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
        print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
        print('Iter {} scaled loss: {}'.format(t, loss.data[0]))
        print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale))

        # Run backprop
        optimizer.zero_grad()
        loss.backward()

        # Check for overflow
        has_overflow = DynamicLossScaler.has_overflow(parameters)

        # If no overflow, unscale grad and update as usual
        if not has_overflow:
            for param in parameters:
                param.grad.data.mul_(1. / loss_scaler.loss_scale)
            optimizer.step()
        # Otherwise, don't do anything -- ie, skip iteration
        else:
            print('OVERFLOW!')

        # Update loss scale for next iteration
        loss_scaler.update_scale(has_overflow)

"""
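##############################################################
# Additional sketch -- not part of the original file.
# A hedged illustration of how update_scale() above evolves cur_scale with the
# default arguments (scale_factor=2, scale_window=1000, delayed_shift=1):
# each overflow halves the scale (bounded below by min_scale), and every 1000
# consecutive clean iterations double it.
##############################################################
"""
scaler = DynamicLossScaler(init_scale=2**16, scale_window=1000)

scaler.update_scale(overflow=True)    # scale: 65536 -> 32768
scaler.update_scale(overflow=True)    # scale: 32768 -> 16384

for _ in range(1000):                 # 1000 clean iterations ...
    scaler.update_scale(overflow=False)
print(scaler.loss_scale)              # ... the scale grows back to 32768
"""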
@@ -172,6 +172,7 @@ class OfaTasksTest(unittest.TestCase):
        ofa_pipe = pipeline(Tasks.visual_grounding, model=model)
        image = 'data/test/images/visual_grounding.png'
        text = '一个圆头的蓝色宝可梦'  # 'a round-headed blue Pokemon'
        text = '火'  # 'fire'; overrides the query above
        input = {'image': image, 'text': text}
        result = ofa_pipe(input)
        print(result)
tests/trainers/test_ofa_trainer.py (new file, 20 lines)
@@ -0,0 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import shutil
import unittest

from modelscope.trainers.multi_modal.ofa import OFATrainer
from modelscope.utils.test_utils import test_level


class TestOfaTrainer(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer(self):
        model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_text-classification_mnli_large_en'
        self.trainer = OFATrainer(model_id)
        self.trainer.train()
        shutil.rmtree(self.trainer.save_dir)


if __name__ == '__main__':
    unittest.main()