{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "355UKMUQJxFd"
},
"source": [
"# Scalable Diffusion Models with Transformer (DiT)\n",
"\n",
"This notebook samples from pre-trained DiT models. DiTs are class-conditional latent diffusion models trained on ImageNet that use transformers in place of U-Nets as the DDPM backbone. DiT outperforms all prior diffusion models on the class-conditional ImageNet generation benchmarks.\n",
"\n",
"[Project Page](https://www.wpeebles.com/DiT) | [HuggingFace Space](https://huggingface.co/spaces/wpeebles/DiT) | [Paper](http://arxiv.org/abs/2212.09748) | [GitHub](https://github.com/facebookresearch/DiT)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zJlgLkSaKn7u"
},
"source": [
"# 1. Setup\n",
"\n",
"We recommend using a GPU (Runtime > Change runtime type > Hardware accelerator > GPU). Run this cell to clone the DiT GitHub repo and set up the Python environment. You only have to run it once."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!git clone https://github.com/facebookresearch/DiT.git\n",
"import DiT, os\n",
"os.chdir('DiT')\n",
"os.environ['PYTHONPATH'] = '/env/python:/content/DiT'\n",
"!pip install diffusers timm --upgrade"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecutionIndicator": {
"show": false
},
"execution": {
"iopub.execute_input": "2024-02-21T02:55:56.417045Z",
"iopub.status.busy": "2024-02-21T02:55:56.416754Z",
"iopub.status.idle": "2024-02-21T02:56:06.911052Z",
"shell.execute_reply": "2024-02-21T02:56:06.910591Z",
"shell.execute_reply.started": "2024-02-21T02:55:56.417025Z"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cloning into 'DiT'...\n",
"remote: Enumerating objects: 102, done.\u001b[K\n",
"remote: Counting objects: 100% (78/78), done.\u001b[K\n",
"remote: Compressing objects: 100% (43/43), done.\u001b[K\n",
"remote: Total 102 (delta 55), reused 35 (delta 35), pack-reused 24\u001b[K\n",
"Receiving objects: 100% (102/102), 6.37 MiB | 4.06 MiB/s, done.\n",
"Resolving deltas: 100% (56/56), done.\n",
"Looking in indexes: https://mirrors.aliyun.com/pypi/simple\n",
"Requirement already satisfied: diffusers in /opt/conda/lib/python3.10/site-packages (0.26.3)\n",
"Requirement already satisfied: timm in /opt/conda/lib/python3.10/site-packages (0.9.16)\n",
"Requirement already satisfied: importlib-metadata in /opt/conda/lib/python3.10/site-packages (from diffusers) (7.0.1)\n",
"Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from diffusers) (3.13.1)\n",
"Requirement already satisfied: huggingface-hub>=0.20.2 in /opt/conda/lib/python3.10/site-packages (from diffusers) (0.20.3)\n",
"Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from diffusers) (1.26.3)\n",
"Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.10/site-packages (from diffusers) (2023.12.25)\n",
"Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from diffusers) (2.31.0)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /opt/conda/lib/python3.10/site-packages (from diffusers) (0.4.1)\n",
"Requirement already satisfied: Pillow in /opt/conda/lib/python3.10/site-packages (from diffusers) (10.2.0)\n",
"Requirement already satisfied: torch in /opt/conda/lib/python3.10/site-packages (from timm) (2.1.2+cu121)\n",
"Requirement already satisfied: torchvision in /opt/conda/lib/python3.10/site-packages (from timm) (0.16.2+cu121)\n",
"Requirement already satisfied: pyyaml in /opt/conda/lib/python3.10/site-packages (from timm) (6.0.1)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.20.2->diffusers) (2023.10.0)\n",
"Requirement already satisfied: tqdm>=4.42.1 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.20.2->diffusers) (4.65.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.20.2->diffusers) (4.9.0)\n",
"Requirement already satisfied: packaging>=20.9 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.20.2->diffusers) (23.1)\n",
"Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.10/site-packages (from importlib-metadata->diffusers) (3.17.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->diffusers) (2.0.4)\n",
"Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->diffusers) (3.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->diffusers) (1.26.16)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->diffusers) (2023.11.17)\n",
"Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch->timm) (1.12)\n",
"Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch->timm) (2.8.4)\n",
"Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch->timm) (3.1.2)\n",
"Requirement already satisfied: triton==2.1.0 in /opt/conda/lib/python3.10/site-packages (from torch->timm) (2.1.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch->timm) (2.1.3)\n",
"Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch->timm) (1.3.0)\n",
"\u001b[33mDEPRECATION: pytorch-lightning 1.7.7 has a non-standard dependency specifier torch>=1.9.*. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pytorch-lightning or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"2024-02-21 10:56:06,878 - modelscope - INFO - PyTorch version 2.1.2+cu121 Found.\n",
"2024-02-21 10:56:06,880 - modelscope - INFO - TensorFlow version 2.14.0 Found.\n",
"2024-02-21 10:56:06,881 - modelscope - INFO - Loading ast index from /mnt/workspace/.cache/modelscope/ast_indexer\n",
"2024-02-21 10:56:06,907 - modelscope - INFO - Loading done! Current index file version is 1.12.0, with md5 509123dba36c5e70a95f6780df348471 and a total number of 964 components indexed\n"
]
}
],
"source": [
"# DiT imports:\n",
"import torch\n",
"from torchvision.utils import save_image\n",
"from diffusion import create_diffusion\n",
"from diffusers.models import AutoencoderKL\n",
"from download import find_model\n",
"from models import DiT_XL_2\n",
"from PIL import Image\n",
"from IPython.display import display\n",
"from modelscope import snapshot_download\n",
"torch.set_grad_enabled(False)\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"if device == \"cpu\":\n",
" print(\"GPU not found. Using CPU instead.\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AXpziRkoOvV9"
},
"source": [
"# Download DiT-XL/2 Models\n",
"\n",
"You can choose between a 512x512 model and a 256x256 model. You can swap out the LDM VAE, too."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecutionIndicator": {
"show": true
},
"execution": {
"iopub.execute_input": "2024-02-21T02:58:20.338677Z",
"iopub.status.busy": "2024-02-21T02:58:20.338356Z",
"iopub.status.idle": "2024-02-21T02:58:31.246188Z",
"shell.execute_reply": "2024-02-21T02:58:31.245600Z",
"shell.execute_reply.started": "2024-02-21T02:58:20.338656Z"
},
"id": "EWG-WNimO59K",
"tags": []
},
"outputs": [],
"source": [
"image_size = 256 #@param [256, 512]\n",
"vae_model = snapshot_download(\"AI-ModelScope/sd-vae-ft-ema\") #@param [\"stabilityai/sd-vae-ft-mse\", \"stabilityai/sd-vae-ft-ema\"]\n",
"latent_size = int(image_size) // 8\n",
"# Load model:\n",
"model = DiT_XL_2(input_size=latent_size).to(device)\n",
"DiT_model = snapshot_download(f\"AI-ModelScope/DiT-XL-2-{image_size}x{image_size}\")\n",
"state_dict = find_model(f\"{DiT_model}/DiT-XL-2-{image_size}x{image_size}.pt\")\n",
"model.load_state_dict(state_dict)\n",
"model.eval() # important!\n",
"vae = AutoencoderKL.from_pretrained(vae_model).to(device)"
]
},
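{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (this cell is not part of the original notebook), you can count the loaded model's parameters; DiT-XL/2 has roughly 675M:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Count the parameters of the loaded DiT model (sanity check):\n",
"param_count = sum(p.numel() for p in model.parameters())\n",
"print(f\"DiT-XL/2 parameters: {param_count / 1e6:.0f}M\")"
]
},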
{
"cell_type": "markdown",
"metadata": {
"id": "5JTNyzNZKb9E"
},
"source": [
"# 2. Sample from Pre-trained DiT Models\n",
"\n",
"You can customize several sampling options. For the full list of ImageNet classes, [check out this gist](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a)."
]
},
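{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sampling cell below duplicates the latent batch so the model predicts noise both for the real class labels and for the null class (label 1000). As a toy illustration (this cell is not part of the original notebook, and the tensor shapes are arbitrary), `forward_with_cfg` mixes the two predictions roughly as `eps = eps_uncond + cfg_scale * (eps_cond - eps_uncond)`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy sketch of classifier-free guidance mixing (illustration only;\n",
"# the shapes here are arbitrary, not the actual DiT latent shapes):\n",
"eps_cond = torch.randn(2, 4, 32, 32)    # noise prediction for the class labels\n",
"eps_uncond = torch.randn(2, 4, 32, 32)  # noise prediction for the null class\n",
"guided = eps_uncond + 4.0 * (eps_cond - eps_uncond)\n",
"assert guided.shape == eps_cond.shape"
]
},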
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-21T02:58:36.546161Z",
"iopub.status.busy": "2024-02-21T02:58:36.545823Z",
"iopub.status.idle": "2024-02-21T03:00:26.517853Z",
"shell.execute_reply": "2024-02-21T03:00:26.517365Z",
"shell.execute_reply.started": "2024-02-21T02:58:36.546137Z"
},
"id": "-Hw7B5h4Kk4p",
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 250/250 [01:49<00:00, 2.29it/s]\n"
]
},
{
"data": {
"image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAIGBAoDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDgPC/hnSNR0O2ubu18yV925vMcZw5HY46CuhbwX4ZRf+QcT6Hz5P8A4qq3gvP/AAidngcHf06/6xq30DMzBR+dZNu+hySnJSaOfXwXobOf+Jd8uOMTyf8AxVSSeBtCB3JZgjP3TM4/rXQmTylGCMdOKikUbQ2cr+tXbTcXtJHNt4O0PLBbAA/9dZOP/HqRvBuipGGNhuPp5zj/ANmropGZlHl8Y5OR1oG9mHGCRjntUK5SnIwIfB+hso83S9vPJ86T/wCKqZfA+gsTtss98ebJx/49W4zMGVVyR0p6YSQknO3ptqtR873MNPAmguoxY/N0OZZP/iqd/wAIV4YU4axJI7+dJg/+PV0MMEsxLQROw74Bq3ZeHtRvcmGxlYdPmUimoyHzS2RyB8EeH3y6WAVR2M0n/wAVVRvB2iqwAsBtz182Tn/x6vZdG+H80kOb9vs6H+BTlv8A61bsfw/0NNpeOWRh3L0WZVpdDwpfA3h5otxsgp9TNJz/AOPVWbwXoqS/8g/cp6fvpB/7NXvl18P9Kmi2wNLA3Zg279DWRc/DyRY2ZLhZCo+UbcE/WnYLTPJk8DeHDH81hg+vnyf/ABVJL4E8PHJSw2jPTzpDj/x6uyvdHmsGljmUoB/eH+eKpI9uN+6XBA2496l3QterOTHgnQMEGxBI9JZM/wDoVOHgbQZEytjt+ssmf/Qq6QSbiTGyk4wGNIEmdAfMRT6g9KV2TeXc5yTwP4fCfu7QFuh/fSf/ABVRHwLo3QWfQdfNf/4quqEQjjLn5mB6jrT0eJ2yQQ23v/hS5mO8u5yS+BdHPSzyen+sf/4qmnwPpCAl7EcH/nq//wAVXTGfypiCwwRwRTllzGAMYzxk01PTUXNI5UeDdCI2izyT38x//iqH8C6SUJS15A/56P8A4103nL5hVwgwOAOmagluN3KON3p61XMLnkYI8EaMvLWeR3/ev/8AFUh8F6NuOLDK9v3z/wCNdBkupYtgA8ipZJCQrxsG/vZ4xTuNOT6nPDwXoYfD6eRxn/XSf/FU5fB/hsnH2HJPbzpOP/Hq3BdRu77uXxjINOh8u4cuNoAHJ6E0rheXcwz4M8PKvOmnI/6bSc/+PUieDvDbc/2f+HnSf/FVrT3IMhUcBeq560faoxGNoyQe45oFz
S6GU3g3w6SQumlcf9NpP/iqhbwboWT/AKAB9ZpP/iq25LyVzgIFOOBTd+4bZD+NJsXNLuY3hPwb4f1PxBr1teWe+C0+z+SvnONu5WLdDk8gda7A/DbwUo/5Bf8A5My//F1yWiyFfFev+WSFP2fof9g11kEybvnMgP1q0dMfhQ0/DrwV20r/AMmpf/iqenwy8ISD5dJB/wC3mX/4qte3S2lH+sJ+prRjto41BWT/AMeoKOcX4WeEP4tK/wDJmX/4qmy/DDweo+TSAT73Uw/9mrqROQcYVh9abNdhFywRfxpgckPhr4WzzoKfhdzf/F08fDTwieuh4/7epv8A4uujTUoX4WZS3pUoaZzlWXFAHKSfDHwptOzSAP8At5l/+KqhJ8M/Du7C2GP+20n/AMVXdmNz9+YL+NMaO3QZe5BP1pAcMPhj4eJ5tMD/AK6yf/FVMnwy8L9DZEn/AK7yf/FV0kskRbCXu2nxwQsMtf5/KgDlpfhj4awQlnt9zNIf/ZqZH8MPDuPmtc/9tZB/7NXapbWarlroH/gVRs1jnAmBH+8KYHHt8NvCycG0z/23k/8AiqX/AIVz4TA5sT/3/k/+KrrjDaSDKupHqWqRdNideCCPUGgDjf8AhXng/p9j5/67yf8AxVNf4deFCPksv/I8n/xVdmdHiPTcKibRoxn53H1oA4hvhz4fydtmuPeeT/4qmL8OtCP/AC5qT/13k/8Aiq7VtIXGA5P41H/ZRU53n86AOPPw20b+GwB/7byf/FU5fhrop+9p/wCU0n/xVdslmwH+sIH1oNmxPFwQPrQByafDHw8w5scf9tpP/iqm/wCFXeGFX5rUZ/67yf8AxVdN9jx966bHtTjbwBcmeR/xoA5GX4Z+G1GUsc/9t5P/AIqqzfD3w2Dg2GPfzpP/AIquwlmtYRja59cGqLalbDJS1kJ9SaAOfHgHwqpG/TmP/beQD/0KrafDvwvIPk0Mt7/apf8A4qr/APaUqtmO2B+tOOr6qwPlrHGO1ICknw28Nc79BA/7e5f/AIqnN8OvCCjLaK/4XEv/AMVWhDqmpZAm2N9K0RqhdPugEeuRmmBzqfDzwa440h8/9fEv/wAVSt8N/CeMppOfrcyj/wBmrYl1K9/5Z2wAPQk0Ce/kHzhUHtQBwereBtHj8W+HtOs9LIjvftPmxJNIxfZGGH8WRjk8VtH4W6SrHOhTgAd3l/8Aiq1NMMg+LvgsyPu/4/sZH/TA17mCD0rOUbvexcZWWx84v8O/D8eQ+lOpXggyy/8AxVRf8IB4dHXTf/I0n/xVfSR2MCGAI9CKrT6XYXP+utIXPqUGfzrPkfSQ+ddj53HgDw5303/yPJ/8VSHwD4cxxpv/AJHk/wDiq93ufB+k3GSsTRH1Rv8AGub1LwVeWyl7VxcIOcAYb8qzlTqrZlKUWeU/8IJ4cB503/yPJ/8AFUf8IJ4c/wCgd/5Hk/8Aiq6uaFonZHUqwOCCORUJXFcznUXVlcqOY/4Qbw5/0Dv/ACPJ/wDFUf8ACDeHP+gd/wCR5P8A4quo8h2G5RkdqYYzkAgg0+efdj5PI5hvA/h0HjT/APyNJ/8AFVA/grQAeLD/AMjSf/FV1s0LodrDnrUDodp4wafPPuX7NdjlP+EN0H/nw/8AIz//ABVJ/wAIboX/AD44/wC2z/8AxVdIY8MeKbs5queXcOSPY53/AIQ3Qv8Anx/8jP8A/FU4eC9CP/Lgf+/r/wDxVdAFB471rWVkLhVwnzgc+jUOpJdSo0lJ2sc1Y+APDsswEthlT/02k/8Aiqvz/DnwvAy50wY7/wCkS/8AxVdc0dvbQgMCrbcEVlXN4zO0bHcpGFzXJOvO+kmenQw1O2sV9xzs/gXwmhBXTMKR/wA/EvH/AI9WXd+EPDkLfLp/ftPJ/wDFVrSXuGaMnpn8Kqvc+dCcnkZrJ1qv8z+8644Wivsr7kYz+GfD6xE/YBu/67Sd/wDgVQS+G9DEW9LIAk/89X/xq4ZfkD9+R
Ufmhrfbke1Cq1f5n95r9Ww6+wvuRRfw7oyhf9EAPVv3r8D86j/sHRmyRZbR/wBdX/xq3PKAcDucfWmuCIBk8tycn9
"image/png": "iVBORw0KGgoAAAANSUhEUgAABAoAAAIGCAIAAABS1Po6AAEAAElEQVR4Aez917JlTZIn9m2tj0j1idLV1dV6tOgREDM2IOaCRpjRaASN4C0fgcZn4AVfgLc0GPEWoJEGKgMIYjAYYS2mq0t+MjPPOVtL/v4e+2RVD4dGMbe1Tubea68V4eHh4eHh4eHh0en8+vo1BX5NgV9T4NcU+DUFfk2BX1Pg1xT4NQV+TYFfU+DXFPg1BX5NgV9T4NcU+DUFfk2BX1Pg1xT4NQV+TYFfU+BXKdBtP/4n/8v/+al/nExH3Uuve+rMxqPZZHLYHObj8WAx2R9Og/7g3D1tVqvdbtft90/n836z7Xe6/V63P+ifTpdT19VbL1eX46nf7fR63d6wN55NR+PRcXPYbQ953usM+8PL6bxarc4SjEZAdbudy/Fy6fYnyuidL6dL53KZ3i5Gs+npeNpv1qfzqdfv7TvnzXpzOpyG3f5oMOoPfA8undN6e+hcOrP5/Lw7dE+H0RAQaF0uvd6l00+yTrdzBqDnqz8aDaZj1dlt1rvVcn88jccTVTgcDpvt/nA8D4cD6SFxOB/AHwDQHUzGs/liAe2nzep8Oe7PByh2Lufj8bjfHM6ny6jXmy2mL1++WK/Xh/1pOMjfZDIeTyco9n75NJnOX378evW03G933ePhdDyut7vdqXPqdPrd3nTUm6lNvwsBMC+X8/lwOF+6i/vXu+P5dDycj/vTYXu+XPbHY6/Xn01uBoPR8Xi4nI79cX93OqLnHhrjab/f2e/2Q416Pu/Us9Ppno6zwVApwPSHyD/o9nqn8+VwOhwPBz9HI804UFzPpSn2kNir+Hg8Rj8pJUT54+E4GU+lOu72p/2hO+jsU/vdZHYzGqJh53iSYHDunCeTSb/f3+3xyUAarYoml8tpNp0jf683OO73OOTUkVabXCQeTcbdQe8AxOXcv3SG3W6/q81Auyy3+/F0NhzDsrPd7C/d7mjYG2rjzmW72lw6yH9RCbgPBkMXzoPkbrPpXpAOpZHyqE37w/7xhH3OnisVw6JJ59LFA53D5T/7X/9v0Ow/+0//08PhiCUUhwIaGVDoXZ6v8Hh6TD5gAn4XW+FXjHg61RvvfOdV+NidB0mTx/5jxAbCk2Rz5eMCeRAGikuJOBiaZ2/O53Ne+4J5P2zsZ/HyZa9mea4vnpP+gmswyGm/26n1duf+AKaf58t5vz/IKGll7wXh49FzkNFZibosPEc6wEAb9VJ9zDkYAK53qICSEFD6zWa72WwwKiYfDUfyyiJZ+hqSFjngiW4+qo4XtQK8nqA5upzDahov3RSd0TKX+uYlsve6ylJoEExHUw+8rNnqggnyns9gJnuRR8MBW2AalS4dr6u1CiMv85c2UHcIuEutNVAuhRa+cL5e7Q3hlFYgg8DrdlFHLbBZaO5Vmki90NP9GZqe+Qw1XGnMAYhohWQek015Wv8rB2nbnc2m48kUTIkV4a1Kyhq5pVtCt7EO4qpZ4CqQ2Ow2YvgMimm/VAUayZXGrAfBJ+RrxMCDkVwndAhrobN8UFeEz//Zf/KfgPC//V/9Ly7HHeICBW6IfTlPg2E/8mm313aw1AQyQerS1fdPIBV3oSPUdIcUoSzv0Q9eafFij/yERz6LfkoPBv5L001SfxpX4vzvaSpkhp/kXoYr6tNb/JNXyvbYz5CrqJDkBScfLQtYQUha7SR16HYKBcDwUwMnuxtiCUXJvUBLT8kPpQSmW4RJcweNYmDlqSFCaRxCUuvCEv7IrAnypqGL0On7kAhu6banwzj8FCqQdJN+ZzEevLiZvBgObnDI4WxA1EtJY5Qdj4coEnYaDsPqBPWW4N8j9GDQWT88kmbgPL19vxiPusfz+mmt3Saz6Xq7vZz2JfiHvUH36f3D7f3i88++fHF/+/WXb6eTWQQLXI/nw+5kEMRgZHtYjYzud7fLjXaD+Wg2O+wOmGD5sFYNEnJnADEQTCbvH5
907v3hPJ+O+uPBbn+G2mQ+fXjaPD6sx+PRZDLqQGK7v72druBzpFeMsNJiYiTvGkFfvrgzID69W6eCRqNuZ3ozMRzooZhsvT0S9cPB6eM3t8MuUlzuXt1uD/v1et/vDddP22E3xNnutov5aK+X9S698+n+/v7nP/98MhrOJ6ObxfRpucVwH3/zzZefvz0dLpPpYL4Ydy+nze44VbWLip+3T+s/frj8x//5UqMtfuN/fDwMScEQp/gVxxZvhUXT3/LvemnT3OE9nFGJwxxhDP+xSjFIJajbMGnjmZa/eDbMFdbC2y1vUvh/LaiAJOFzxoBNP9d/AyWs663UwbYVXtkLTHq95Om5qgH+FUxSYn/Pr098PRcQHOHiTT5SweubdCK/rleKyqv0iNy071DqUtzvlfv8rPf5atgREh/gKDQPdeN2UzDbR8sX4EGjYZfKXpEp0Kl3o0x7mtJaISmzYHoRGF7UV7uv33mRbL/Es4BU8gCqf9cCq/hr3nyltknTchQQvwvRwHalPh+Kzs/wSL35ZbktQUk+eZ/fSnwtqYrIfRV1Zb8SmeGHolhxRuPPFNjSFI0arfTx5I7MDQukmKo1ZoZvnhgccUfv3D0YXDr77Y/+dwocNEw3vct+d+mPOvQrqhkFYToczLrDIX4lVPTXCx2aPksn6K6p1+vtfneezcbj2fBMT+0MJosb/bl3ugxNL6h+0ZwyUB0k3KzJju65X0MOZeO4NMfodWeGw/Am1U0BRuXLbrddL5fD8fTY700iYbt0mNFkMBiPld4dnXuZpJBWVJGBt8vtdr15enlD1t3uNtvjftujMufq9MYdcnMyIeoAJ2fTMy5uR+Pd7v3muN2TUkiyI2Ki1B7QzVDR7W72O2rWSS1gOJnNFhOyPhqSSg5GO9pkCH2G7/nY2e0O4xG1fHLqDN6+ezSijAaDmbmERutSXBTaH81uDK4ZWqU+7mBI5IFBBb30ATk+Pa1Wl910vhiNh11zrWOPjjwkgEeT9WFr4mRaYyyTadSjCg/mk7vhoL8/bfdHIp0ueNrCSh+nOg/6GcsOex1vNJqG6qcNzaPXoaZQvA1sITkZelDnjAoRSZ0ocmYInZO2OVBZRiQJRb03GEbQ4sDBxOg/GM6xookgbbYzuBicDUmQpExqvPFg2ukcgTAMnrqD4WS4PxsSdmraG8ymk6GxZLM5zCa9/Z4mR1T15jcz7KUdYGsc3B4OFP9JlE68o0ZRQfBqVHUDVKczHA0imc6n7e6Iv3vq1e0MMTaq9Xphz8tp0OsOpwPqi6p0hz2jqSlVlJVOd2v6dclUE2d5AmkUMW24jK49Mg2nu7gM7M9KWCRwqh36A5KXlBicA0oQM6k7euGudaU0vYZvP6VO99Mb0wk99DuZKLWSV7E0jqZW7g+7p9WTyVhSSuEPGtcLx4c7KakuH2gSUAbTQR8CJkE+otZo3UIR+ukGNUVOq496OobZ4RUxNRqO6HPSmgdKbGB2+Ukx1R9hqhyPG240TPNZeuVqtTZOd6epCNSi1CaxhChtqoCimii6f0MdmehBhUYIipgIIeN+v8vvVsPnwdI7FUjXqRvIILJL5U74A19nBnnc46TDAZ9LoB4KwMmomrl2iFxUDr2vDedttdaZIJMYBQa6fFgiVYCRD5RMxtD8ekHPpTyQojjWlCkNksetglHsNG5apKZpEpuMjXtjRPAE4DTuMRBiv/AkM66aM/QxOYIN1RaeKD80dTYxU9bloq3xJ8xIHl3Ak5YmyNELUaGjy2QiAWsCAQ5BKaiEP4JVGDIXfEyfw//Yp+gQtbRzrGao+a0ZXGwe+mSNKYFCAvXgNOx3qG7HiOrhmdhCf6LL2OGB8oAcmBiYtabJtY9C85/g8JbUwq/p3uFZpIhs9Qk8RFLLkE51IkiRN89jzQjnBLnn2vjhAQOP+lWm5AggfQEEYEm0dL3UP7JBoZmUSp/ZoEEixCBFouiFkn7Kdy0jiAAVnT+PKgfqRVAWDkG8Y3
oUOKm5r9SwviUwNEYSaJZChjgHJBKWcA1OmlKS3gWzBpB8qVeooN9iS+n1YJ8do9uBpKeZXwbMH5q1250ORhpqudu
"text/plain": [
"<PIL.PngImagePlugin.PngImageFile image mode=RGB size=1034x518>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Set user inputs:\n",
"seed = 0 #@param {type:\"number\"}\n",
"torch.manual_seed(seed)\n",
"num_sampling_steps = 250 #@param {type:\"slider\", min:0, max:1000, step:1}\n",
"cfg_scale = 4 #@param {type:\"slider\", min:1, max:10, step:0.1}\n",
"class_labels = 207, 360, 387, 974, 88, 979, 417, 279 #@param {type:\"raw\"}\n",
"samples_per_row = 4 #@param {type:\"number\"}\n",
"\n",
"# Create diffusion object:\n",
"diffusion = create_diffusion(str(num_sampling_steps))\n",
"\n",
"# Create sampling noise:\n",
"n = len(class_labels)\n",
"z = torch.randn(n, 4, latent_size, latent_size, device=device)\n",
"y = torch.tensor(class_labels, device=device)\n",
"\n",
"# Setup classifier-free guidance:\n",
"z = torch.cat([z, z], 0)\n",
"y_null = torch.tensor([1000] * n, device=device)\n",
"y = torch.cat([y, y_null], 0)\n",
"model_kwargs = dict(y=y, cfg_scale=cfg_scale)\n",
"\n",
"# Sample images:\n",
"samples = diffusion.p_sample_loop(\n",
" model.forward_with_cfg, z.shape, z, clip_denoised=False, \n",
" model_kwargs=model_kwargs, progress=True, device=device\n",
")\n",
"samples, _ = samples.chunk(2, dim=0) # Remove null class samples\n",
"samples = vae.decode(samples / 0.18215).sample  # 0.18215 is the SD VAE latent scaling factor\n",
"\n",
"# Save and display images:\n",
"save_image(samples, \"sample.png\", nrow=int(samples_per_row), \n",
" normalize=True, value_range=(-1, 1))\n",
"samples = Image.open(\"sample.png\")\n",
"display(samples)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
},
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}