diff --git a/modelscope/trainers/cv/image_classifition_trainer.py b/modelscope/trainers/cv/image_classifition_trainer.py index f15fd5e3..0ecfe173 100644 --- a/modelscope/trainers/cv/image_classifition_trainer.py +++ b/modelscope/trainers/cv/image_classifition_trainer.py @@ -498,6 +498,11 @@ class ImageClassifitionTrainer(BaseTrainer): metric_options = self.cfg.evaluation.get('metric_options', {}) if 'topk' in metric_options.keys(): metric_options['topk'] = tuple(metric_options['topk']) + # mmcls will set the default value of topk to (1, 5) which + # will cause error when number of classes less then 5. + # set topk as (1,) if len(CLASSES) < 5: + elif len(CLASSES) < 5: + metric_options['topk'] = (1, ) if self.cfg.evaluation.metrics: eval_results = dataset.evaluate( results=outputs,