[to #44847108] add sparsity hook (pst algorithm)
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10198228
* [to #44847108] add sparsity hook (pst algorithm)
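The commit registers a new SparsityHook that applies the PST pruning method during training. The hook is enabled through the trainer configuration; below is a minimal sketch of such a config, modelled on the unit test added in this commit (the numeric values are illustrative, not defaults):

    {
        "task": "image_classification",
        "train": {
            "hooks": [{
                "type": "SparsityHook",
                "pruning_method": "pst",
                "config": {
                    "weight_rank": 8,
                    "mask_rank": 8,
                    "initial_sparsity": 0.0,
                    "final_sparsity": 0.9,
                    "frequency": 1
                }
            }]
        }
    }

During training the hook wraps eligible linear layers in sparse wrappers, anneals the sparsity ratio from initial_sparsity to final_sparsity, and saves the pruned model as compress_model.pth at the end of the run.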
modelscope/metainfo.py
@@ -404,6 +404,9 @@ class Hooks(object):
    IterTimerHook = 'IterTimerHook'
    EvaluationHook = 'EvaluationHook'

    # Compression
    SparsityHook = 'SparsityHook'


class LR_Schedulers(object):
    """learning rate scheduler is defined here
modelscope/trainers/hooks/__init__.py
@@ -6,10 +6,11 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
    from .builder import HOOKS, build_hook
    from .checkpoint_hook import BestCkptSaverHook, CheckpointHook
    from .compression import SparsityHook
    from .evaluation_hook import EvaluationHook
    from .hook import Hook
    from .iter_timer_hook import IterTimerHook
    from .logger import TextLoggerHook, TensorboardHook
    from .logger import TensorboardHook, TextLoggerHook
    from .lr_scheduler_hook import LrSchedulerHook
    from .optimizer import (ApexAMPOptimizerHook, NoneOptimizerHook,
                            OptimizerHook, TorchAMPOptimizerHook)
@@ -19,6 +20,7 @@ else:
    _import_structure = {
        'builder': ['HOOKS', 'build_hook'],
        'checkpoint_hook': ['BestCkptSaverHook', 'CheckpointHook'],
        'compression': ['SparsityHook'],
        'evaluation_hook': ['EvaluationHook'],
        'hook': ['Hook'],
        'iter_timer_hook': ['IterTimerHook'],
24  modelscope/trainers/hooks/compression/__init__.py  Normal file
@@ -0,0 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .sparsity_hook import SparsityHook
    from .utils import SparseLinear, convert_sparse_network

else:
    _import_structure = {
        'sparsity_hook': ['SparsityHook'],
        'utils': ['convert_sparse_network', 'SparseLinear'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
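The lazy-import structure above is intended to defer loading sparsity_hook.py and utils.py until one of the names listed in _import_structure is first accessed, e.g. (illustrative):

    # Resolved lazily on first attribute access; nothing from sparsity_hook.py
    # or utils.py is imported until this line actually runs.
    from modelscope.trainers.hooks.compression import SparsityHook, SparseLinear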
131  modelscope/trainers/hooks/compression/sparsity_hook.py  Normal file
@@ -0,0 +1,131 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os

from modelscope import __version__
from modelscope.metainfo import Hooks
from modelscope.trainers.hooks.builder import HOOKS
from modelscope.trainers.hooks.hook import Hook
from modelscope.trainers.hooks.priority import Priority
from modelscope.utils.checkpoint import save_checkpoint
from modelscope.utils.torch_utils import is_master


@HOOKS.register_module(module_name=Hooks.SparsityHook)
class SparsityHook(Hook):

    PRIORITY = Priority.HIGHEST

    def __init__(self, pruning_method, config={}, save_dir=None):
        self.pruning_method = pruning_method
        self.save_dir = save_dir

        self.compress_module = config.get('compress_module', [])
        self.weight_rank = config.get('weight_rank', 8)
        self.weight_beta = config.get('weight_beta', 1)
        self.mask_rank = config.get('mask_rank', 8)
        self.mask_alpha1 = config.get('mask_alpha1', 1)
        self.mask_alpha2 = config.get('mask_alpha2', 1)

        self.step = 0
        self.total_step = 0
        self.frequency = config.get('frequency', 1)
        self.initial_warmup = config.get('initial_warmup', 0.1)
        self.final_warmup = config.get('final_warmup', 0.3)
        self.initial_sparsity = config.get('initial_sparsity', 0.0)
        self.final_sparsity = config.get('final_sparsity', 0.0)

    def before_run(self, trainer):
        import torch

        from .utils import SparseLinear, convert_sparse_network

        if self.save_dir is None:
            self.save_dir = trainer.work_dir

        if len(self.compress_module) == 0:
            convert_sparse_network(
                trainer.model,
                pruning_method=self.pruning_method,
                weight_rank=self.weight_rank,
                weight_beta=self.weight_beta,
                mask_rank=self.mask_rank,
                mask_alpha1=self.mask_alpha1,
                mask_alpha2=self.mask_alpha2,
                logger=trainer.logger,
            )
        else:
            for cm in self.compress_module:
                for name, module in trainer.model.named_modules():
                    if name != cm:
                        continue
                    convert_sparse_network(
                        module,
                        pruning_method=self.pruning_method,
                        weight_rank=self.weight_rank,
                        weight_beta=self.weight_beta,
                        mask_rank=self.mask_rank,
                        mask_alpha1=self.mask_alpha1,
                        mask_alpha2=self.mask_alpha2,
                        logger=trainer.logger,
                    )

        for i in range(len(trainer.optimizer.param_groups)):
            new_train_params = []
            for param in trainer.optimizer.param_groups[i]['params']:
                is_find = False
                for name, module in trainer.model.named_modules():
                    if isinstance(module, SparseLinear):
                        if torch.equal(param.half(),
                                       module.weight.data.half()):
                            is_find = True
                            break

                if not is_find:
                    new_train_params.append(param)

            trainer.optimizer.param_groups[i]['params'] = new_train_params

        new_params = []
        for name, module in trainer.model.named_modules():
            if isinstance(module, SparseLinear):
                new_params.extend(
                    [p for p in module.parameters() if p.requires_grad])

        trainer.optimizer.add_param_group({'params': new_params})

        self.total_step = trainer.iters_per_epoch * trainer._max_epochs

    def before_train_iter(self, trainer):
        from .utils import schedule_sparsity_ratio, update_network_sparsity

        cur_sparsity = schedule_sparsity_ratio(
            self.step,
            self.total_step,
            self.frequency,
            self.initial_warmup,
            self.final_warmup,
            self.initial_sparsity,
            self.final_sparsity,
        )

        update_network_sparsity(trainer.model, cur_sparsity)

        if is_master():
            trainer.logger.info(
                f'Step[{self.step}/{self.total_step}] current sparsity ratio = {cur_sparsity}'
            )

        self.step += 1

    def after_run(self, trainer):
        from .utils import generate_sparse_model

        generate_sparse_model(trainer.model, logger=trainer.logger)

        self._save_checkpoint(trainer)

    def _save_checkpoint(self, trainer):
        if is_master():
            trainer.logger.info('Saving checkpoint at final compress')
            cur_save_name = os.path.join(self.save_dir, 'compress_model.pth')
            save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
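Each iteration, before_train_iter recomputes the target sparsity with schedule_sparsity_ratio (defined in utils.py below) and pushes it into every SparseLinear. A small hedged trace of how that cubic schedule behaves, using a made-up 100-step run (only initial_warmup=0.1 and final_warmup=0.3 match the hook's defaults; the other values are illustrative):

    # Illustrative only: trace the cubic sparsity schedule the hook applies each step.
    from modelscope.trainers.hooks.compression.utils import schedule_sparsity_ratio

    total_step = 100
    for step in (0, 10, 40, 55, 70, 90):
        s = schedule_sparsity_ratio(
            step, total_step, frequency=1,
            initial_warmup=0.1, final_warmup=0.3,
            initial_sparsity=0.0, final_sparsity=0.9)
        print(step, round(s, 4))
    # Sparsity stays at 0.0 for the first 10% of steps, then rises along
    # 0.9 * (1 - (1 - t)**3), where t is the progress through the ramp,
    # and is held at 0.9 for the final 30% of steps.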
208  modelscope/trainers/hooks/compression/utils.py  Normal file
@@ -0,0 +1,208 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
import torch.nn as nn

from modelscope.utils.torch_utils import is_master


class SparseBinarizer(torch.autograd.Function):

    @staticmethod
    def forward(ctx, mask_scores, sparsity):
        num_prune = int(mask_scores.numel() * sparsity)
        prune_indices = torch.argsort(mask_scores.reshape(-1))[:num_prune]
        mask = mask_scores.clone().fill_(1)
        mask.reshape(-1)[prune_indices] = 0.0
        return mask

    @staticmethod
    def backward(ctx, gradOutput):
        return gradOutput, None


class SparseLinear(nn.Module):
    """
    Fully Connected layer with on the fly adaptive mask.
    """

    def __init__(
        self,
        module,
        pruning_method='pst',
        weight_rank=8,
        weight_beta=1.0,
        mask_rank=8,
        mask_alpha1=1.0,
        mask_alpha2=1.0,
    ):
        super(SparseLinear, self).__init__()
        self.module = module
        out_features = self.module.weight.shape[0]
        in_features = self.module.weight.shape[1]

        self.weight = self.module.weight
        self.module.weight = None
        self.module._parameters.pop('weight')

        self.pruning_method = pruning_method

        self.cur_sparsity = 0.0

        if self.pruning_method == 'pst':
            self.weight_rank = weight_rank
            self.weight_beta = weight_beta
            self.mask_rank = mask_rank
            self.mask_alpha1 = mask_alpha1
            self.mask_alpha2 = mask_alpha2

            # create trainable params
            self.weight_U = nn.Parameter(
                torch.randn(out_features, self.weight_rank).to(
                    device=self.weight.device, dtype=self.weight.dtype))
            self.weight_V = nn.Parameter(
                torch.zeros(self.weight_rank, in_features).to(
                    device=self.weight.device, dtype=self.weight.dtype))

            self.mask_scores_A = nn.Parameter(
                torch.randn(out_features, self.mask_rank).to(
                    device=self.weight.device, dtype=self.weight.dtype))
            self.mask_scores_B = nn.Parameter(
                torch.zeros(self.mask_rank, in_features).to(
                    device=self.weight.device, dtype=self.weight.dtype))
            self.mask_scores_R = nn.Parameter(
                torch.zeros(out_features).to(
                    device=self.weight.device, dtype=self.weight.dtype))
            self.mask_scores_C = nn.Parameter(
                torch.zeros(in_features).to(
                    device=self.weight.device, dtype=self.weight.dtype))

        self.weight.requires_grad = False
        if self.module.bias is not None:
            self.module.bias.requires_grad = False

    def forward(self, *inputs):
        if self.pruning_method == 'pst':
            weight = self.weight + self.weight_beta * self.weight_U @ self.weight_V
            mask_scores = (
                weight.abs()
                + self.mask_alpha1 * self.mask_scores_A @ self.mask_scores_B
                + self.mask_alpha2 * (self.mask_scores_R.unsqueeze(1)
                                      + self.mask_scores_C.unsqueeze(0)))

            mask = SparseBinarizer.apply(mask_scores, self.cur_sparsity)
            masked_weight = mask * weight

            self.module.weight = masked_weight
            return self.module(*inputs)
        else:
            return self.module(*inputs)

    def convert(self):
        if self.pruning_method == 'pst':
            weight = self.weight + self.weight_beta * self.weight_U @ self.weight_V
            mask_scores = (
                weight.abs()
                + self.mask_alpha1 * self.mask_scores_A @ self.mask_scores_B
                + self.mask_alpha2 * (self.mask_scores_R.unsqueeze(1)
                                      + self.mask_scores_C.unsqueeze(0)))

            mask = SparseBinarizer.apply(mask_scores, self.cur_sparsity)

            masked_weight = mask * weight
            self.module.weight = nn.Parameter(masked_weight.data)


def _setattr(model, name, module):
    name_list = name.split('.')
    for name in name_list[:-1]:
        model = getattr(model, name)
    setattr(model, name_list[-1], module)


def convert_sparse_network(
    model,
    pruning_method,
    weight_rank,
    weight_beta,
    mask_rank,
    mask_alpha1,
    mask_alpha2,
    logger=None,
):
    compress_module = [nn.Linear]
    try:
        from megatron import mpu
        compress_module.extend(
            [mpu.RowParallelLinear, mpu.ColumnParallelLinear])
    except ImportError:
        pass

    for name, module in model.named_modules():
        if type(module) in compress_module:
            new_module = SparseLinear(
                module,
                pruning_method,
                weight_rank,
                weight_beta,
                mask_rank,
                mask_alpha1,
                mask_alpha2,
            )

            # replace original module by new sparse module
            _setattr(model, name, new_module)

            if is_master():
                if logger:
                    logger.info(f'convert {name} to sparse module.')
                else:
                    print(f'convert {name} to sparse module.')


def update_network_sparsity(model, sparsity):
    for name, module in model.named_modules():
        if isinstance(module, SparseLinear):
            module.cur_sparsity = sparsity


def schedule_sparsity_ratio(
    step,
    total_step,
    frequency,
    initial_warmup,
    final_warmup,
    initial_sparsity,
    final_sparsity,
):
    if step <= initial_warmup * total_step:
        sparsity = initial_sparsity
    elif step > (total_step - final_warmup * total_step):
        sparsity = final_sparsity
    else:
        spars_warmup_steps = initial_warmup * total_step
        spars_schedu_steps = (final_warmup + initial_warmup) * total_step
        step = (step - spars_warmup_steps) // frequency * frequency
        mul_coeff = 1 - step / (total_step - spars_schedu_steps)
        sparsity = final_sparsity + (initial_sparsity - final_sparsity) * (
            mul_coeff**3)
    return sparsity


def generate_sparse_model(model, logger=None):
    # generate sparse weight for saving
    for name, module in model.named_modules():
        if isinstance(module, SparseLinear):
            module.convert()

            _setattr(model, name, module.module)

            if is_master():
                if logger:
                    logger.info(f'convert {name} weight to sparse weight, \
                        sparsity ratio={torch.mean(1.0*(module.module.weight==0)).item()}.'
                                )
                else:
                    print(f'convert {name} weight to sparse, \
                        sparsity ratio={torch.mean(1.0*(module.module.weight==0)).item()}.'
                          )
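The utilities above can also be exercised outside the trainer. A minimal standalone sketch (illustrative only; it assumes modelscope is importable and no distributed environment is initialised, and the toy nn.Sequential model is not part of the commit):

    import torch
    from torch import nn
    from modelscope.trainers.hooks.compression.utils import (
        convert_sparse_network, generate_sparse_model, update_network_sparsity)

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

    # Wrap every nn.Linear in a SparseLinear with low-rank weight/mask factors.
    convert_sparse_network(
        model, pruning_method='pst',
        weight_rank=8, weight_beta=1.0,
        mask_rank=8, mask_alpha1=1.0, mask_alpha2=1.0)

    # Raise the mask sparsity and run a forward pass through the masked weights.
    update_network_sparsity(model, 0.5)
    _ = model(torch.randn(2, 16))

    # Bake the binary masks back into plain nn.Linear weights for saving.
    generate_sparse_model(model)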
0  tests/trainers/hooks/compression/__init__.py  Normal file
113  tests/trainers/hooks/compression/test_sparsity_hook.py  Normal file
@@ -0,0 +1,113 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest

import json
import numpy as np
import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR

from modelscope.metainfo import Trainers
from modelscope.models.base import Model
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile, TrainerStages
from modelscope.utils.test_utils import create_dummy_test_dataset

dummy_dataset = create_dummy_test_dataset(
    np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10)


class DummyModel(nn.Module, Model):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 10)
        self.bn = nn.BatchNorm1d(10)

    def forward(self, feat, labels):
        x = self.linear(feat)

        x = self.bn(x)
        loss = torch.sum(x)
        return dict(logits=x, loss=loss)


class SparsityHookTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir)

    def test_sparsity_hook(self):
        json_cfg = {
            'task': 'image_classification',
            'train': {
                'work_dir':
                self.tmp_dir,
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1
                },
                'hooks': [{
                    'type': 'SparsityHook',
                    'pruning_method': 'pst',
                    'config': {
                        'weight_rank': 1,
                        'mask_rank': 1,
                        'final_sparsity': 0.9,
                        'frequency': 1,
                    },
                }],
            },
        }

        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
        with open(config_path, 'w') as f:
            json.dump(json_cfg, f)

        model = DummyModel()
        optimizer = SGD(model.parameters(), lr=0.01)
        lr_scheduler = MultiStepLR(optimizer, milestones=[2, 4])
        trainer_name = Trainers.default
        kwargs = dict(
            cfg_file=config_path,
            model=model,
            train_dataset=dummy_dataset,
            optimizers=(optimizer, lr_scheduler),
            max_epochs=5,
            device='cpu',
        )

        trainer = build_trainer(trainer_name, kwargs)
        train_dataloader = trainer._build_dataloader_with_dataset(
            trainer.train_dataset, **trainer.cfg.train.get('dataloader', {}))
        trainer.register_optimizers_hook()
        trainer.register_hook_from_cfg(trainer.cfg.train.hooks)
        trainer.train_dataloader = train_dataloader
        trainer.data_loader = train_dataloader
        trainer.invoke_hook(TrainerStages.before_run)
        for i in range(trainer._epoch, trainer._max_epochs):
            trainer.invoke_hook(TrainerStages.before_train_epoch)
            for _, data_batch in enumerate(train_dataloader):
                trainer.invoke_hook(TrainerStages.before_train_iter)
                trainer.train_step(trainer.model, data_batch)
                trainer.invoke_hook(TrainerStages.after_train_iter)
            trainer.invoke_hook(TrainerStages.after_train_epoch)
        trainer.invoke_hook(TrainerStages.after_run)

        self.assertEqual(
            torch.mean(1.0 * (trainer.model.linear.weight == 0)), 0.9)


if __name__ == '__main__':
    unittest.main()
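The test drives the trainer stages by hand on a CPU-only dummy model and asserts that 90% of the linear layer's weights end up exactly zero after the final conversion. It can be run directly from a repository checkout, assuming modelscope is importable (illustrative invocation):

    python tests/trainers/hooks/compression/test_sparsity_hook.py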