diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index 715d9946..3995fcfa 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -975,6 +975,8 @@ class EpochBasedTrainer(BaseTrainer): dataset, num_replicas=world_size, rank=rank, shuffle=shuffle) else: sampler = None + if not isinstance(dataset, torch.utils.data.IterableDataset): + kwargs['shuffle'] = shuffle batch_sampler = None