mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 20:19:22 +01:00
Fix a bug where the log file records the lr as zero instead of its correct value
This bug, which was reported by a user, is caused by float rounding when key-value pairs are saved to the log files. The solution is now to remove the rounding operation for all values, rather than for the lr value only, which I think may be too specific. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10684029
This commit is contained in:
@@ -9,6 +9,7 @@ import torch
|
||||
from torch import distributed as dist
|
||||
|
||||
from modelscope.metainfo import Hooks
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.trainers.hooks.builder import HOOKS
|
||||
from modelscope.trainers.hooks.logger.base import LoggerHook
|
||||
from modelscope.utils.constant import LogKeys, ModeKeys
|
||||
@@ -30,6 +31,8 @@ class TextLoggerHook(LoggerHook):
|
||||
reset_flag (bool, optional): Whether to clear the output buffer after
|
||||
logging. Default: False.
|
||||
out_dir (str): The directory to save log. If is None, use `trainer.work_dir`
|
||||
ignore_rounding_keys (`Union[str, List]`): The keys to ignore float rounding, default 'lr'
|
||||
rounding_digits (`int`): The digits of rounding, exceeding parts will be ignored.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@@ -37,13 +40,20 @@ class TextLoggerHook(LoggerHook):
|
||||
interval=10,
|
||||
ignore_last=True,
|
||||
reset_flag=False,
|
||||
out_dir=None):
|
||||
out_dir=None,
|
||||
ignore_rounding_keys='lr',
|
||||
rounding_digits=5):
|
||||
super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
|
||||
by_epoch)
|
||||
self.by_epoch = by_epoch
|
||||
self.time_sec_tot = 0
|
||||
self.out_dir = out_dir
|
||||
self._logged_keys = [] # store the key has been logged
|
||||
if isinstance(ignore_rounding_keys,
|
||||
str) or ignore_rounding_keys is None:
|
||||
ignore_rounding_keys = [ignore_rounding_keys]
|
||||
self.ignore_rounding_keys = ignore_rounding_keys
|
||||
self.rounding_digits = rounding_digits
|
||||
|
||||
def before_run(self, trainer):
|
||||
super(TextLoggerHook, self).before_run(trainer)
|
||||
@@ -139,7 +149,9 @@ class TextLoggerHook(LoggerHook):
|
||||
# dump log in json format
|
||||
json_log = OrderedDict()
|
||||
for k, v in log_dict.items():
|
||||
json_log[k] = self._round_float(v)
|
||||
json_log[
|
||||
k] = v if k in self.ignore_rounding_keys else self._round_float(
|
||||
v, self.rounding_digits)
|
||||
|
||||
if is_master():
|
||||
with open(self.json_log_path, 'a+') as f:
|
||||
@@ -148,7 +160,7 @@ class TextLoggerHook(LoggerHook):
|
||||
|
||||
def _round_float(self, items, ndigits=5):
|
||||
if isinstance(items, list):
|
||||
return [self._round_float(item) for item in items]
|
||||
return [self._round_float(item, ndigits) for item in items]
|
||||
elif isinstance(items, float):
|
||||
return round(items, ndigits)
|
||||
else:
|
||||
|
||||
@@ -70,6 +70,7 @@ def train_func(work_dir, dist=False, log_interval=3, imgs_per_gpu=4):
|
||||
},
|
||||
{
|
||||
'type': 'TextLoggerHook',
|
||||
'ignore_rounding_keys': None,
|
||||
'interval': log_interval
|
||||
},
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user