This commit is contained in:
XDUWQ
2023-07-26 16:08:24 +08:00
parent 70da8b7809
commit 25d67a0b83

View File

@@ -301,7 +301,8 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
instance_data_dir = instance_data_name
else:
ds = MsDataset.load(instance_data_name, split='train')
instance_data_dir = os.path.dirname(next(iter(ds))['Target:FILE'])
instance_data_dir = os.path.dirname(
next(iter(ds))['Target:FILE'])
# construct concept list
if self.concepts_list is None:
@@ -320,7 +321,9 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
if not os.path.exists(concept['class_data_dir']):
os.makedirs(concept['class_data_dir'])
if not os.path.exists(concept['instance_data_dir']):
raise Exception(f"instance dataset {concept['instance_data_dir']} does not exist.")
raise Exception(
f"instance dataset {concept['instance_data_dir']} does not exist."
)
# Adding a modifier token which is optimized
self.modifier_token_id = []
@@ -524,7 +527,7 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
"""
for i, concept in enumerate(self.concepts_list):
class_images_dir = Path(concept['class_data_dir'])
print("-------class_images_dir: ", class_images_dir)
print('-------class_images_dir: ', class_images_dir)
if not class_images_dir.exists():
class_images_dir.mkdir(parents=True, exist_ok=True)