Files
modelscope/modelscope/utils/data_collators.py
Jintao ba4b9fc43f Added full parameter sft to llm (#402)
* Optimized code

* update parse_args

* fix get_logger bug

* update parse_args

* Added full parameter fine-tuning

* Add support_bf16 warning

* Modify the code format and fix bugs
2023-07-24 15:52:09 +08:00

77 lines
3.1 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from huggingface/transformers.
from collections import OrderedDict
from collections.abc import Mapping
from typing import Any, List, Optional, Tuple
from .logger import get_logger
# Module-level logger; RemoveColumnsCollator uses it to report ignored columns.
logger = get_logger()
class RemoveColumnsCollator:
    """Split each feature into collator inputs and side-channel attributes.

    Wraps an inner data collator. For every mapping-like feature of a
    mini-batch, only the keys listed in ``columns_to_remove`` are forwarded to
    the inner collator; every other key is stripped from the feature and
    attached to the collated batch as an attribute holding the per-sample
    values, so such a column is read as ``batch.id`` instead of
    ``batch['id']``. Features that are not mappings are forwarded untouched.

    NOTE(review): despite its name, ``columns_to_remove`` is treated by the
    code as the list of columns to *keep* (the log message calls them "Legal
    columns"). The name is preserved for backward compatibility; confirm the
    intended semantics against the callers.

    Args:
        data_collator: An inner data collator to collate the mini-batch.
        columns_to_remove(`List[str]`): The column names that are passed on to
            the inner collator; all other columns become batch attributes.
        model_name(`Optional[str]`): An optional model name to print into log.
            When it is falsy, nothing is logged.
        description(`Optional[str]`): An optional description of the dataset
            to print into log.
    """

    def __init__(
        self,
        data_collator,
        columns_to_remove: List[str],
        model_name: Optional[str] = None,
        description: Optional[str] = None,
    ):
        self.data_collator = data_collator
        self.columns_to_remove = columns_to_remove
        self.description = description
        self.model_name = model_name
        # The "ignored columns" notice is emitted at most once per instance.
        self.message_logged = False

    def _format_ignored_message(self, ignored_columns: List[str]) -> str:
        """Build the one-off log line listing the columns that were ignored."""
        dset_description = (
            '' if self.description is None else
            f' in the {self.description} set')
        ignored = ', '.join(ignored_columns)
        legal = ', '.join(self.columns_to_remove)
        return (
            f"The following columns{dset_description} don't have a "
            f'corresponding argument in `{self.model_name}.forward` and have '
            f'been ignored: {ignored}. Legal columns: {legal}. '
            f'If {ignored} are not expected by `{self.model_name}.forward`, '
            'you can safely ignore this message.')

    def _remove_columns(self, feature: Mapping) -> Tuple[Mapping, Any]:
        """Partition one feature into (forwarded columns, stripped columns).

        Returns:
            A tuple ``(feature_clean, feature_unused)``. For a non-mapping
            feature this is ``(feature, None)``.
        """
        if not isinstance(feature, Mapping):
            return feature, None

        if not self.message_logged and self.model_name:
            # Follow the feature's own key order so the log line is
            # deterministic (a set difference would shuffle the names).
            ignored_columns = [
                key for key in feature if key not in self.columns_to_remove
            ]
            if ignored_columns:
                logger.info(self._format_ignored_message(ignored_columns))
            self.message_logged = True

        feature_clean = {}
        feature_unused = {}
        for key, value in feature.items():
            if key in self.columns_to_remove:
                feature_clean[key] = value
            else:
                feature_unused[key] = value
        return feature_clean, feature_unused

    def __call__(self, features: List[Mapping]):
        """Collate ``features`` and expose the stripped columns as attributes.

        Returns:
            An ``OrderedDict`` built from the inner collator's output. For
            each stripped column it carries an attribute of the same name
            holding the list of per-sample values.
        """
        features_clean = []
        features_unused = []
        for feature in features:
            kept, unused = self._remove_columns(feature)
            features_clean.append(kept)
            features_unused.append(unused)

        data = OrderedDict(self.data_collator(features_clean))
        # The first sample decides which attributes exist; the emptiness
        # guard keeps an empty batch from raising IndexError.
        if features_unused and features_unused[0] is not None:
            for key in features_unused[0]:
                setattr(data, key, [unused[key] for unused in features_unused])
        return data