pin optimum 2.0.0 for gpu/llm image

This commit is contained in:
hjh0119
2026-01-13 18:18:15 +08:00
parent 754e630e69
commit e6ef4af1a1

View File

@@ -202,6 +202,11 @@ class CPUImageBuilder(Builder):
class GPUImageBuilder(Builder):
def init_args(self, args) -> Any:
if not args.optimum_version:
args.optimum_version = '2.0.0'
return super().init_args(args)
def generate_dockerfile(self) -> str:
meta_file = './docker/install.sh'
# pushd ~ popd is to solve the tf cannot use gpu problem.
@@ -220,7 +225,7 @@ RUN pushd $(dirname $(python -c 'print(__import__("tensorflow").__file__)')) &&
version_args = (
f'{self.args.torch_version} {self.args.torchvision_version} {self.args.torchaudio_version} '
f'{self.args.vllm_version} {self.args.lmdeploy_version} {self.args.autogptq_version} '
f'{self.args.flashattn_version}')
f'{self.args.flashattn_version} {self.args.optimum_version}')
base_image = (
f'{docker_registry}:ubuntu{self.args.ubuntu_version}-cuda{self.args.cuda_version}-{self.args.python_tag}-'
f'torch{self.args.torch_version}-tf{self.args.tf_version}-base')
@@ -288,6 +293,8 @@ class LLMImageBuilder(Builder):
args.autogptq_version = '0.7.1'
if not args.flashattn_version:
args.flashattn_version = '2.7.4.post1'
if not args.optimum_version:
args.optimum_version = '2.0.0'
return args
def generate_dockerfile(self) -> str:
@@ -299,7 +306,7 @@ class LLMImageBuilder(Builder):
version_args = (
f'{self.args.torch_version} {self.args.torchvision_version} {self.args.torchaudio_version} '
f'{self.args.vllm_version} {self.args.lmdeploy_version} {self.args.autogptq_version} '
f'{self.args.flashattn_version}')
f'{self.args.flashattn_version} {self.args.optimum_version}')
with open('docker/Dockerfile.ubuntu', 'r') as f:
content = f.read()
content = content.replace('{base_image}', self.args.base_image)
@@ -464,6 +471,7 @@ parser.add_argument('--vllm_version', type=str, default=None)
parser.add_argument('--lmdeploy_version', type=str, default=None)
parser.add_argument('--flashattn_version', type=str, default=None)
parser.add_argument('--autogptq_version', type=str, default=None)
parser.add_argument('--optimum_version', type=str, default=None)
parser.add_argument('--modelscope_branch', type=str, default='master')
parser.add_argument('--modelscope_version', type=str, default='9.99.0')
parser.add_argument('--swift_branch', type=str, default='main')