diff --git a/docker/build_image.py b/docker/build_image.py index 3e4fda0e..49ffd16d 100644 --- a/docker/build_image.py +++ b/docker/build_image.py @@ -341,6 +341,10 @@ class LLMImageBuilder(Builder): class SwiftImageBuilder(LLMImageBuilder): def init_args(self, args) -> Any: + if not args.base_image: + args.base_image = 'nvidia/cuda:12.6.3-devel-ubuntu22.04' + if not args.cuda_version: + args.cuda_version = '12.6.3' if not args.torch_version: args.torch_version = '2.7.1' args.torchaudio_version = '2.7.1'