mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 20:19:22 +01:00
fix bugs
This commit is contained in:
@@ -68,11 +68,6 @@ class CustomCheckpointProcessor(CheckpointProcessor):
|
||||
|
||||
|
||||
class CustomDiffusionDataset(Dataset):
|
||||
"""
|
||||
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
|
||||
It pre-processes the images and the tokenizes prompts.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
concepts_list,
|
||||
@@ -85,6 +80,20 @@ class CustomDiffusionDataset(Dataset):
|
||||
hflip=False,
|
||||
aug=True,
|
||||
):
|
||||
"""A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
|
||||
It pre-processes the images and the tokenizes prompts.
|
||||
|
||||
Args:
|
||||
concepts_list: contain multiple concepts, instance_prompt, class_prompt, etc.
|
||||
tokenizer: pretrained tokenizer.
|
||||
size: the size of images.
|
||||
mask_size: the mask size of images.
|
||||
center_crop: execute center crop or not.
|
||||
with_prior_preservation: flag to add prior preservation loss.
|
||||
hflip: whether to flip horizontally.
|
||||
aug: perform data augmentation.
|
||||
|
||||
"""
|
||||
self.size = size
|
||||
self.mask_size = mask_size
|
||||
self.center_crop = center_crop
|
||||
@@ -219,84 +228,6 @@ class CustomDiffusionDataset(Dataset):
|
||||
return example
|
||||
|
||||
|
||||
class ClassDataset(Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
class_data_root=None,
|
||||
class_prompt=None,
|
||||
class_num_images=None,
|
||||
size=512,
|
||||
center_crop=False,
|
||||
):
|
||||
"""A dataset to prepare class images with the prompts for fine-tuning the model.
|
||||
It pre-processes the images and the tokenizes prompts.
|
||||
|
||||
Args:
|
||||
tokenizer: The tokenizer to use for tokenization.
|
||||
class_data_root: The saved class data path.
|
||||
class_prompt: The prompt to use for class images.
|
||||
class_num_images: The number of class images to use.
|
||||
size: The size to resize the images.
|
||||
center_crop: Whether to do center crop or random crop.
|
||||
|
||||
"""
|
||||
self.size = size
|
||||
self.center_crop = center_crop
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
if class_data_root is not None:
|
||||
self.class_data_root = Path(class_data_root)
|
||||
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
||||
self.class_images_path = list(self.class_data_root.iterdir())
|
||||
if class_num_images is not None:
|
||||
self.num_class_images = min(
|
||||
len(self.class_images_path), class_num_images)
|
||||
else:
|
||||
self.num_class_images = len(self.class_images_path)
|
||||
self.class_prompt = class_prompt
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Class {self.class_data_root} class data root doesn't exists."
|
||||
)
|
||||
|
||||
self.image_transforms = transforms.Compose([
|
||||
transforms.Resize(
|
||||
size, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(size)
|
||||
if center_crop else transforms.RandomCrop(size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
])
|
||||
|
||||
def __len__(self):
|
||||
return self.num_class_images
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = {}
|
||||
|
||||
if self.class_data_root:
|
||||
class_image = Image.open(
|
||||
self.class_images_path[index % self.num_class_images])
|
||||
class_image = exif_transpose(class_image)
|
||||
|
||||
if not class_image.mode == 'RGB':
|
||||
class_image = class_image.convert('RGB')
|
||||
example['pixel_values'] = self.image_transforms(class_image)
|
||||
|
||||
class_text_inputs = self.tokenizer(
|
||||
self.class_prompt,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
padding='max_length',
|
||||
return_tensors='pt')
|
||||
input_ids = torch.squeeze(class_text_inputs.input_ids)
|
||||
example['input_ids'] = input_ids
|
||||
|
||||
return example
|
||||
|
||||
|
||||
class PromptDataset(Dataset):
|
||||
|
||||
def __init__(self, prompt, num_samples):
|
||||
|
||||
Reference in New Issue
Block a user