From 07296a837a3174cf30148c55d2b4e0e06c91c960 Mon Sep 17 00:00:00 2001
From: XDUWQ <1300964705@qq.com>
Date: Wed, 12 Jul 2023 15:19:09 +0800
Subject: [PATCH] fix bugs

---
 .../custom_diffusion_trainer.py | 97 +++----------------
 1 file changed, 14 insertions(+), 83 deletions(-)

diff --git a/modelscope/trainers/multi_modal/custom_diffusion/custom_diffusion_trainer.py b/modelscope/trainers/multi_modal/custom_diffusion/custom_diffusion_trainer.py
index 63bdf338..ad51e23c 100644
--- a/modelscope/trainers/multi_modal/custom_diffusion/custom_diffusion_trainer.py
+++ b/modelscope/trainers/multi_modal/custom_diffusion/custom_diffusion_trainer.py
@@ -68,11 +68,6 @@ class CustomCheckpointProcessor(CheckpointProcessor):
 
 
 class CustomDiffusionDataset(Dataset):
-    """
-    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
-    It pre-processes the images and the tokenizes prompts.
-    """
-
     def __init__(
         self,
         concepts_list,
@@ -85,6 +80,20 @@ class CustomDiffusionDataset(Dataset):
         hflip=False,
         aug=True,
     ):
+        """A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+        It pre-processes the images and tokenizes the prompts.
+
+        Args:
+            concepts_list: a list of concept dicts, each containing instance_prompt, class_prompt, etc.
+            tokenizer: the pretrained tokenizer.
+            size: the size of the images.
+            mask_size: the mask size of the images.
+            center_crop: whether to apply center crop.
+            with_prior_preservation: flag to add the prior preservation loss.
+            hflip: whether to flip horizontally.
+            aug: whether to perform data augmentation.
+
+        """
         self.size = size
         self.mask_size = mask_size
         self.center_crop = center_crop
@@ -219,84 +228,6 @@ class CustomDiffusionDataset(Dataset):
         return example
 
 
-class ClassDataset(Dataset):
-
-    def __init__(
-        self,
-        tokenizer,
-        class_data_root=None,
-        class_prompt=None,
-        class_num_images=None,
-        size=512,
-        center_crop=False,
-    ):
-        """A dataset to prepare class images with the prompts for fine-tuning the model.
-        It pre-processes the images and the tokenizes prompts.
-
-        Args:
-            tokenizer: The tokenizer to use for tokenization.
-            class_data_root: The saved class data path.
-            class_prompt: The prompt to use for class images.
-            class_num_images: The number of class images to use.
-            size: The size to resize the images.
-            center_crop: Whether to do center crop or random crop.
-
-        """
-        self.size = size
-        self.center_crop = center_crop
-        self.tokenizer = tokenizer
-
-        if class_data_root is not None:
-            self.class_data_root = Path(class_data_root)
-            self.class_data_root.mkdir(parents=True, exist_ok=True)
-            self.class_images_path = list(self.class_data_root.iterdir())
-            if class_num_images is not None:
-                self.num_class_images = min(
-                    len(self.class_images_path), class_num_images)
-            else:
-                self.num_class_images = len(self.class_images_path)
-            self.class_prompt = class_prompt
-        else:
-            raise ValueError(
-                f"Class {self.class_data_root} class data root doesn't exists."
-            )
-
-        self.image_transforms = transforms.Compose([
-            transforms.Resize(
-                size, interpolation=transforms.InterpolationMode.BILINEAR),
-            transforms.CenterCrop(size)
-            if center_crop else transforms.RandomCrop(size),
-            transforms.ToTensor(),
-            transforms.Normalize([0.5], [0.5]),
-        ])
-
-    def __len__(self):
-        return self.num_class_images
-
-    def __getitem__(self, index):
-        example = {}
-
-        if self.class_data_root:
-            class_image = Image.open(
-                self.class_images_path[index % self.num_class_images])
-            class_image = exif_transpose(class_image)
-
-            if not class_image.mode == 'RGB':
-                class_image = class_image.convert('RGB')
-            example['pixel_values'] = self.image_transforms(class_image)
-
-            class_text_inputs = self.tokenizer(
-                self.class_prompt,
-                max_length=self.tokenizer.model_max_length,
-                truncation=True,
-                padding='max_length',
-                return_tensors='pt')
-            input_ids = torch.squeeze(class_text_inputs.input_ids)
-            example['input_ids'] = input_ids
-
-        return example
-
-
 class PromptDataset(Dataset):
 
     def __init__(self, prompt, num_samples):
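
For context, a minimal usage sketch of the CustomDiffusionDataset whose docstring
is relocated above. Only instance_prompt and class_prompt are named in that
docstring; the *_data_dir keys, the data paths, and the tokenizer checkpoint are
illustrative assumptions, not taken from this patch:

    from transformers import CLIPTokenizer

    from modelscope.trainers.multi_modal.custom_diffusion.custom_diffusion_trainer import \
        CustomDiffusionDataset

    # Hypothetical concept entry: instance_prompt/class_prompt follow the
    # docstring; the *_data_dir keys and paths are assumed for illustration.
    concepts_list = [{
        'instance_prompt': 'photo of a <new1> cat',
        'class_prompt': 'photo of a cat',
        'instance_data_dir': './data/cat',
        'class_data_dir': './data/class_cat',
    }]

    # Any compatible pretrained tokenizer; this CLIP checkpoint is one example.
    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')

    dataset = CustomDiffusionDataset(
        concepts_list=concepts_list,
        tokenizer=tokenizer,
        size=512,      # image resolution
        mask_size=64,  # mask resolution
        center_crop=False,
        with_prior_preservation=True,  # adds the prior preservation loss
        hflip=True,
        aug=True,
    )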