Set default topk when num_classes < 5 (#269)

* set default topk when num_classes < 6

* add comments

* format changes with yapf

---------

Co-authored-by: Xu Wenqing <xwq391974@alibaba-inc.com>
This commit is contained in:
Xu Wenqing
2023-04-18 20:30:13 +08:00
committed by GitHub
parent e841261557
commit e8bca5a11e

View File

@@ -498,6 +498,11 @@ class ImageClassifitionTrainer(BaseTrainer):
metric_options = self.cfg.evaluation.get('metric_options', {})
if 'topk' in metric_options.keys():
metric_options['topk'] = tuple(metric_options['topk'])
# mmcls will set the default value of topk to (1, 5) which
# will cause error when number of classes less then 5.
# set topk as (1,) if len(CLASSES) < 5:
elif len(CLASSES) < 5:
metric_options['topk'] = (1, )
if self.cfg.evaluation.metrics:
eval_results = dataset.evaluate(
results=outputs,