diff --git a/docker/build_image.py b/docker/build_image.py index 22fc21bd..a3455a6d 100644 --- a/docker/build_image.py +++ b/docker/build_image.py @@ -202,6 +202,11 @@ class CPUImageBuilder(Builder): class GPUImageBuilder(Builder): + def init_args(self, args) -> Any: + if not args.optimum_version: + args.optimum_version = '2.0.0' + return super().init_args(args) + def generate_dockerfile(self) -> str: meta_file = './docker/install.sh' # pushd ~ popd is to solve the tf cannot use gpu problem. @@ -220,7 +225,7 @@ RUN pushd $(dirname $(python -c 'print(__import__("tensorflow").__file__)')) && version_args = ( f'{self.args.torch_version} {self.args.torchvision_version} {self.args.torchaudio_version} ' f'{self.args.vllm_version} {self.args.lmdeploy_version} {self.args.autogptq_version} ' - f'{self.args.flashattn_version}') + f'{self.args.flashattn_version} {self.args.optimum_version}') base_image = ( f'{docker_registry}:ubuntu{self.args.ubuntu_version}-cuda{self.args.cuda_version}-{self.args.python_tag}-' f'torch{self.args.torch_version}-tf{self.args.tf_version}-base') @@ -288,6 +293,8 @@ class LLMImageBuilder(Builder): args.autogptq_version = '0.7.1' if not args.flashattn_version: args.flashattn_version = '2.7.4.post1' + if not args.optimum_version: + args.optimum_version = '2.0.0' return args def generate_dockerfile(self) -> str: @@ -299,7 +306,7 @@ class LLMImageBuilder(Builder): version_args = ( f'{self.args.torch_version} {self.args.torchvision_version} {self.args.torchaudio_version} ' f'{self.args.vllm_version} {self.args.lmdeploy_version} {self.args.autogptq_version} ' - f'{self.args.flashattn_version}') + f'{self.args.flashattn_version} {self.args.optimum_version}') with open('docker/Dockerfile.ubuntu', 'r') as f: content = f.read() content = content.replace('{base_image}', self.args.base_image) @@ -464,6 +471,7 @@ parser.add_argument('--vllm_version', type=str, default=None) parser.add_argument('--lmdeploy_version', type=str, default=None) parser.add_argument('--flashattn_version', type=str, default=None) parser.add_argument('--autogptq_version', type=str, default=None) +parser.add_argument('--optimum_version', type=str, default=None) parser.add_argument('--modelscope_branch', type=str, default='master') parser.add_argument('--modelscope_version', type=str, default='9.99.0') parser.add_argument('--swift_branch', type=str, default='main')