commit 07296a837a (parent d6368b2617)
Author: XDUWQ
Date: 2023-07-12 15:19:09 +08:00

@@ -68,11 +68,6 @@ class CustomCheckpointProcessor(CheckpointProcessor):
class CustomDiffusionDataset(Dataset):
"""
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
It pre-processes the images and tokenizes the prompts.
"""
def __init__(
self,
concepts_list,
@@ -85,6 +80,20 @@ class CustomDiffusionDataset(Dataset):
hflip=False,
aug=True,
):
"""A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
It pre-processes the images and the tokenizes prompts.
Args:
concepts_list: contain multiple concepts, instance_prompt, class_prompt, etc.
tokenizer: pretrained tokenizer.
size: the size of images.
mask_size: the mask size of images.
center_crop: execute center crop or not.
with_prior_preservation: flag to add prior preservation loss.
hflip: whether to flip horizontally.
aug: perform data augmentation.
"""
self.size = size
self.mask_size = mask_size
self.center_crop = center_crop
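For readers unfamiliar with the expected shape of concepts_list, here is a minimal sketch of constructing the dataset. Only instance_prompt and class_prompt are named in the docstring above; the *_data_dir keys and the keyword arguments elided from this hunk are assumptions for illustration, not confirmed by the diff:

from transformers import CLIPTokenizer

# Hypothetical two-concept configuration; the *_data_dir keys are assumed.
concepts_list = [
    {
        'instance_prompt': 'photo of a <new1> cat',
        'class_prompt': 'photo of a cat',
        'instance_data_dir': './data/cat',     # assumed key
        'class_data_dir': './class_data/cat',  # assumed key
    },
]
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
dataset = CustomDiffusionDataset(
    concepts_list=concepts_list,
    tokenizer=tokenizer,
    size=512,
    mask_size=64,
    center_crop=False,
    with_prior_preservation=True,
    hflip=True,
    aug=True,
)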
@@ -219,84 +228,6 @@ class CustomDiffusionDataset(Dataset):
return example


class ClassDataset(Dataset):
def __init__(
self,
tokenizer,
class_data_root=None,
class_prompt=None,
class_num_images=None,
size=512,
center_crop=False,
):
"""A dataset to prepare class images with the prompts for fine-tuning the model.
It pre-processes the images and tokenizes the prompts.
Args:
tokenizer: The tokenizer to use for tokenization.
class_data_root: The saved class data path.
class_prompt: The prompt to use for class images.
class_num_images: The number of class images to use.
size: The size to resize the images.
center_crop: Whether to do center crop or random crop.
"""
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
if class_data_root is not None:
self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
self.class_images_path = list(self.class_data_root.iterdir())
if class_num_images is not None:
self.num_class_images = min(
len(self.class_images_path), class_num_images)
else:
self.num_class_images = len(self.class_images_path)
self.class_prompt = class_prompt
else:
# class_data_root is None on this branch, so the message must not
# reference self.class_data_root, which was never set.
raise ValueError(
    'class_data_root must be specified for ClassDataset.')
self.image_transforms = transforms.Compose([
transforms.Resize(
size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size)
if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
def __len__(self):
return self.num_class_images
def __getitem__(self, index):
example = {}
if self.class_data_root:
class_image = Image.open(
self.class_images_path[index % self.num_class_images])
class_image = exif_transpose(class_image)
if class_image.mode != 'RGB':
class_image = class_image.convert('RGB')
example['pixel_values'] = self.image_transforms(class_image)
class_text_inputs = self.tokenizer(
self.class_prompt,
max_length=self.tokenizer.model_max_length,
truncation=True,
padding='max_length',
return_tensors='pt')
input_ids = torch.squeeze(class_text_inputs.input_ids)
example['input_ids'] = input_ids
return example
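
As a usage sketch (not part of this commit), the ClassDataset above plugs into a standard PyTorch DataLoader. The directory, prompt, and tokenizer below are assumptions for illustration:

from torch.utils.data import DataLoader
from transformers import CLIPTokenizer

# Hypothetical inputs: the directory must already hold pre-generated
# class images, since ClassDataset only lists and loads them.
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
class_dataset = ClassDataset(
    tokenizer=tokenizer,
    class_data_root='./class_data/cat',  # assumed path
    class_prompt='photo of a cat',       # assumed prompt
    class_num_images=200,
    size=512,
    center_crop=False,
)
loader = DataLoader(class_dataset, batch_size=4, shuffle=True)
batch = next(iter(loader))
# pixel_values: (4, 3, 512, 512), normalized to [-1, 1];
# input_ids: (4, model_max_length), e.g. 77 for CLIP tokenizers.
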
class PromptDataset(Dataset):
def __init__(self, prompt, num_samples):