[to #44847108] add sparsity hook (pst algorithm)

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10198228

    * [to #44847108] add sparsity hook (pst algorithm)
This commit is contained in:
laiyin.lyc
2022-10-11 16:05:20 +08:00
parent 333c11c0a6
commit 09d2296f36
7 changed files with 482 additions and 1 deletions

View File

@@ -404,6 +404,9 @@ class Hooks(object):
IterTimerHook = 'IterTimerHook'
EvaluationHook = 'EvaluationHook'
# Compression
SparsityHook = 'SparsityHook'
class LR_Schedulers(object):
"""learning rate scheduler is defined here

View File

@@ -6,10 +6,11 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .builder import HOOKS, build_hook
from .checkpoint_hook import BestCkptSaverHook, CheckpointHook
from .compression import SparsityHook
from .evaluation_hook import EvaluationHook
from .hook import Hook
from .iter_timer_hook import IterTimerHook
from .logger import TextLoggerHook, TensorboardHook
from .logger import TensorboardHook, TextLoggerHook
from .lr_scheduler_hook import LrSchedulerHook
from .optimizer import (ApexAMPOptimizerHook, NoneOptimizerHook,
OptimizerHook, TorchAMPOptimizerHook)
@@ -19,6 +20,7 @@ else:
_import_structure = {
'builder': ['HOOKS', 'build_hook'],
'checkpoint_hook': ['BestCkptSaverHook', 'CheckpointHook'],
'compression': ['SparsityHook'],
'evaluation_hook': ['EvaluationHook'],
'hook': ['Hook'],
'iter_timer_hook': ['IterTimerHook'],

View File

@@ -0,0 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .sparsity_hook import SparsityHook
from .utils import SparseLinear, convert_sparse_network
else:
_import_structure = {
'sparsity_hook': ['SparsityHook'],
'utils': ['convert_sparse_network', 'SparseLinear'],
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -0,0 +1,131 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from modelscope import __version__
from modelscope.metainfo import Hooks
from modelscope.trainers.hooks.builder import HOOKS
from modelscope.trainers.hooks.hook import Hook
from modelscope.trainers.hooks.priority import Priority
from modelscope.utils.checkpoint import save_checkpoint
from modelscope.utils.torch_utils import is_master
@HOOKS.register_module(module_name=Hooks.SparsityHook)
class SparsityHook(Hook):
PRIORITY = Priority.HIGHEST
def __init__(self, pruning_method, config={}, save_dir=None):
self.pruning_method = pruning_method
self.save_dir = save_dir
self.compress_module = config.get('compress_module', [])
self.weight_rank = config.get('weight_rank', 8)
self.weight_beta = config.get('weight_beta', 1)
self.mask_rank = config.get('mask_rank', 8)
self.mask_alpha1 = config.get('mask_alpha1', 1)
self.mask_alpha2 = config.get('mask_alpha2', 1)
self.step = 0
self.total_step = 0
self.frequency = config.get('frequency', 1)
self.initial_warmup = config.get('initial_warmup', 0.1)
self.final_warmup = config.get('final_warmup', 0.3)
self.initial_sparsity = config.get('initial_sparsity', 0.0)
self.final_sparsity = config.get('final_sparsity', 0.0)
def before_run(self, trainer):
import torch
from .utils import SparseLinear, convert_sparse_network
if self.save_dir is None:
self.save_dir = trainer.work_dir
if len(self.compress_module) == 0:
convert_sparse_network(
trainer.model,
pruning_method=self.pruning_method,
weight_rank=self.weight_rank,
weight_beta=self.weight_beta,
mask_rank=self.mask_rank,
mask_alpha1=self.mask_alpha1,
mask_alpha2=self.mask_alpha2,
logger=trainer.logger,
)
else:
for cm in self.compress_module:
for name, module in trainer.model.named_modules():
if name != cm:
continue
convert_sparse_network(
module,
pruning_method=self.pruning_method,
weight_rank=self.weight_rank,
weight_beta=self.weight_beta,
mask_rank=self.mask_rank,
mask_alpha1=self.mask_alpha1,
mask_alpha2=self.mask_alpha2,
logger=trainer.logger,
)
for i in range(len(trainer.optimizer.param_groups)):
new_train_params = []
for param in trainer.optimizer.param_groups[i]['params']:
is_find = False
for name, module in trainer.model.named_modules():
if isinstance(module, SparseLinear):
if torch.equal(param.half(),
module.weight.data.half()):
is_find = True
break
if not is_find:
new_train_params.append(param)
trainer.optimizer.param_groups[i]['params'] = new_train_params
new_params = []
for name, module in trainer.model.named_modules():
if isinstance(module, SparseLinear):
new_params.extend(
[p for p in module.parameters() if p.requires_grad])
trainer.optimizer.add_param_group({'params': new_params})
self.total_step = trainer.iters_per_epoch * trainer._max_epochs
def before_train_iter(self, trainer):
from .utils import schedule_sparsity_ratio, update_network_sparsity
cur_sparsity = schedule_sparsity_ratio(
self.step,
self.total_step,
self.frequency,
self.initial_warmup,
self.final_warmup,
self.initial_sparsity,
self.final_sparsity,
)
update_network_sparsity(trainer.model, cur_sparsity)
if is_master():
trainer.logger.info(
f'Step[{self.step}/{self.total_step}] current sparsity ratio = {cur_sparsity}'
)
self.step += 1
def after_run(self, trainer):
from .utils import generate_sparse_model
generate_sparse_model(trainer.model, logger=trainer.logger)
self._save_checkpoint(trainer)
def _save_checkpoint(self, trainer):
if is_master():
trainer.logger.info('Saving checkpoint at final compress')
cur_save_name = os.path.join(self.save_dir, 'compress_model.pth')
save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)

View File

@@ -0,0 +1,208 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
from modelscope.utils.torch_utils import is_master
class SparseBinarizer(torch.autograd.Function):
@staticmethod
def forward(ctx, mask_scores, sparsity):
num_prune = int(mask_scores.numel() * sparsity)
prune_indices = torch.argsort(mask_scores.reshape(-1))[:num_prune]
mask = mask_scores.clone().fill_(1)
mask.reshape(-1)[prune_indices] = 0.0
return mask
@staticmethod
def backward(ctx, gradOutput):
return gradOutput, None
class SparseLinear(nn.Module):
"""
Fully Connected layer with on the fly adaptive mask.
"""
def __init__(
self,
module,
pruning_method='pst',
weight_rank=8,
weight_beta=1.0,
mask_rank=8,
mask_alpha1=1.0,
mask_alpha2=1.0,
):
super(SparseLinear, self).__init__()
self.module = module
out_features = self.module.weight.shape[0]
in_features = self.module.weight.shape[1]
self.weight = self.module.weight
self.module.weight = None
self.module._parameters.pop('weight')
self.pruning_method = pruning_method
self.cur_sparsity = 0.0
if self.pruning_method == 'pst':
self.weight_rank = weight_rank
self.weight_beta = weight_beta
self.mask_rank = mask_rank
self.mask_alpha1 = mask_alpha1
self.mask_alpha2 = mask_alpha2
# create trainable params
self.weight_U = nn.Parameter(
torch.randn(out_features, self.weight_rank).to(
device=self.weight.device, dtype=self.weight.dtype))
self.weight_V = nn.Parameter(
torch.zeros(self.weight_rank, in_features).to(
device=self.weight.device, dtype=self.weight.dtype))
self.mask_scores_A = nn.Parameter(
torch.randn(out_features, self.mask_rank).to(
device=self.weight.device, dtype=self.weight.dtype))
self.mask_scores_B = nn.Parameter(
torch.zeros(self.mask_rank, in_features).to(
device=self.weight.device, dtype=self.weight.dtype))
self.mask_scores_R = nn.Parameter(
torch.zeros(out_features).to(
device=self.weight.device, dtype=self.weight.dtype))
self.mask_scores_C = nn.Parameter(
torch.zeros(in_features).to(
device=self.weight.device, dtype=self.weight.dtype))
self.weight.requires_grad = False
if self.module.bias is not None:
self.module.bias.requires_grad = False
def forward(self, *inputs):
if self.pruning_method == 'pst':
weight = self.weight + self.weight_beta * self.weight_U @ self.weight_V
mask_scores = (
weight.abs()
+ self.mask_alpha1 * self.mask_scores_A @ self.mask_scores_B
+ self.mask_alpha2 * (self.mask_scores_R.unsqueeze(1)
+ self.mask_scores_C.unsqueeze(0)))
mask = SparseBinarizer.apply(mask_scores, self.cur_sparsity)
masked_weight = mask * weight
self.module.weight = masked_weight
return self.module(*inputs)
else:
return self.module(*inputs)
def convert(self):
if self.pruning_method == 'pst':
weight = self.weight + self.weight_beta * self.weight_U @ self.weight_V
mask_scores = (
weight.abs()
+ self.mask_alpha1 * self.mask_scores_A @ self.mask_scores_B
+ self.mask_alpha2 * (self.mask_scores_R.unsqueeze(1)
+ self.mask_scores_C.unsqueeze(0)))
mask = SparseBinarizer.apply(mask_scores, self.cur_sparsity)
masked_weight = mask * weight
self.module.weight = nn.Parameter(masked_weight.data)
def _setattr(model, name, module):
name_list = name.split('.')
for name in name_list[:-1]:
model = getattr(model, name)
setattr(model, name_list[-1], module)
def convert_sparse_network(
model,
pruning_method,
weight_rank,
weight_beta,
mask_rank,
mask_alpha1,
mask_alpha2,
logger=None,
):
compress_module = [nn.Linear]
try:
from megatron import mpu
compress_module.extend(
[mpu.RowParallelLinear, mpu.ColumnParallelLinear])
except ImportError:
pass
for name, module in model.named_modules():
if type(module) in compress_module:
new_module = SparseLinear(
module,
pruning_method,
weight_rank,
weight_beta,
mask_rank,
mask_alpha1,
mask_alpha2,
)
# replace original module by new sparse module
_setattr(model, name, new_module)
if is_master():
if logger:
logger.info(f'convert {name} to sparse module.')
else:
print(f'convert {name} to sparse module.')
def update_network_sparsity(model, sparsity):
for name, module in model.named_modules():
if isinstance(module, SparseLinear):
module.cur_sparsity = sparsity
def schedule_sparsity_ratio(
step,
total_step,
frequency,
initial_warmup,
final_warmup,
initial_sparsity,
final_sparsity,
):
if step <= initial_warmup * total_step:
sparsity = initial_sparsity
elif step > (total_step - final_warmup * total_step):
sparsity = final_sparsity
else:
spars_warmup_steps = initial_warmup * total_step
spars_schedu_steps = (final_warmup + initial_warmup) * total_step
step = (step - spars_warmup_steps) // frequency * frequency
mul_coeff = 1 - step / (total_step - spars_schedu_steps)
sparsity = final_sparsity + (initial_sparsity - final_sparsity) * (
mul_coeff**3)
return sparsity
def generate_sparse_model(model, logger=None):
# generate sparse weight for saving
for name, module in model.named_modules():
if isinstance(module, SparseLinear):
module.convert()
_setattr(model, name, module.module)
if is_master():
if logger:
logger.info(f'convert {name} weight to sparse weight, \
sparsity ratio={torch.mean(1.0*(module.module.weight==0)).item()}.'
)
else:
print(f'convert {name} weight to sparse, \
sparsity ratio={torch.mean(1.0*(module.module.weight==0)).item()}.'
)

View File

@@ -0,0 +1,113 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
import json
import numpy as np
import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR
from modelscope.metainfo import Trainers
from modelscope.models.base import Model
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile, TrainerStages
from modelscope.utils.test_utils import create_dummy_test_dataset
dummy_dataset = create_dummy_test_dataset(
np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10)
class DummyModel(nn.Module, Model):
def __init__(self):
super().__init__()
self.linear = nn.Linear(5, 10)
self.bn = nn.BatchNorm1d(10)
def forward(self, feat, labels):
x = self.linear(feat)
x = self.bn(x)
loss = torch.sum(x)
return dict(logits=x, loss=loss)
class SparsityHookTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)
def tearDown(self):
super().tearDown()
shutil.rmtree(self.tmp_dir)
def test_sparsity_hook(self):
json_cfg = {
'task': 'image_classification',
'train': {
'work_dir':
self.tmp_dir,
'dataloader': {
'batch_size_per_gpu': 2,
'workers_per_gpu': 1
},
'hooks': [{
'type': 'SparsityHook',
'pruning_method': 'pst',
'config': {
'weight_rank': 1,
'mask_rank': 1,
'final_sparsity': 0.9,
'frequency': 1,
},
}],
},
}
config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
with open(config_path, 'w') as f:
json.dump(json_cfg, f)
model = DummyModel()
optimizer = SGD(model.parameters(), lr=0.01)
lr_scheduler = MultiStepLR(optimizer, milestones=[2, 4])
trainer_name = Trainers.default
kwargs = dict(
cfg_file=config_path,
model=model,
train_dataset=dummy_dataset,
optimizers=(optimizer, lr_scheduler),
max_epochs=5,
device='cpu',
)
trainer = build_trainer(trainer_name, kwargs)
train_dataloader = trainer._build_dataloader_with_dataset(
trainer.train_dataset, **trainer.cfg.train.get('dataloader', {}))
trainer.register_optimizers_hook()
trainer.register_hook_from_cfg(trainer.cfg.train.hooks)
trainer.train_dataloader = train_dataloader
trainer.data_loader = train_dataloader
trainer.invoke_hook(TrainerStages.before_run)
for i in range(trainer._epoch, trainer._max_epochs):
trainer.invoke_hook(TrainerStages.before_train_epoch)
for _, data_batch in enumerate(train_dataloader):
trainer.invoke_hook(TrainerStages.before_train_iter)
trainer.train_step(trainer.model, data_batch)
trainer.invoke_hook(TrainerStages.after_train_iter)
trainer.invoke_hook(TrainerStages.after_train_epoch)
trainer.invoke_hook(TrainerStages.after_run)
self.assertEqual(
torch.mean(1.0 * (trainer.model.linear.weight == 0)), 0.9)
if __name__ == '__main__':
unittest.main()