training script

This commit is contained in:
Yuwei Guo
2023-08-20 17:02:57 +08:00
parent e559802fef
commit e816747d66
8 changed files with 744 additions and 1 deletions


@@ -63,7 +63,7 @@ Contributions are always welcome!! The <code>dev</code> branch is for community
</details>
- ## Setup for Inference
+ ## Setups for Inference
### Prepare Environment
@@ -139,6 +139,35 @@ Then run the following commands:
python -m scripts.animate --config [path to the config file]
```
## Steps for Training
### Dataset
Before training, download the video files and the `.csv` annotations of [WebVid10M](https://maxbain.com/webvid-dataset/) to the local machine.
Note that our example training script requires all the videos to be saved in a single folder. You may change this by modifying `animatediff/data/dataset.py`.
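Since the script expects every video in one folder, it can help to check the annotations against the folder before launching training. Below is a minimal sketch of such a check; the `videoid` column name and `.mp4` extension are assumptions about the annotation schema, so adjust them to match the actual WebVid10M `.csv` layout:

```python
import csv
from pathlib import Path

def find_missing_videos(csv_path, video_folder, id_column="videoid", ext=".mp4"):
    """Return ids from the .csv whose video file is absent from the folder.

    `id_column` and `ext` are hypothetical defaults; set them to match
    the columns and file naming of your downloaded WebVid10M copy.
    """
    folder = Path(video_folder)
    missing = []
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            if not (folder / f"{row[id_column]}{ext}").exists():
                missing.append(row[id_column])
    return missing
```

Running this once before training surfaces incomplete downloads early, instead of failing mid-epoch inside the data loader.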
### Configuration
After preparing the dataset, update the data paths in the config `.yaml` files in the `configs/training/` folder:
```
train_data:
csv_path: [Replace with .csv Annotation File Path]
video_folder: [Replace with Video Folder Path]
sample_size: 256
```
Other training parameters (lr, epochs, validation settings, etc.) are also included in the config files.
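A quick sanity check on the edited config can catch path typos before a run starts. The helper below is a hypothetical sketch: it takes the dict you get after loading the `.yaml` file (e.g. with PyYAML's `yaml.safe_load`) and verifies the `train_data` fields shown above:

```python
from pathlib import Path

REQUIRED_KEYS = ("csv_path", "video_folder", "sample_size")

def validate_train_data(config):
    """Return a list of problems found in the `train_data` section.

    Hypothetical helper: `config` is the parsed .yaml as a plain dict;
    only the keys shown in the README snippet are checked.
    """
    problems = []
    data = config.get("train_data", {})
    for key in REQUIRED_KEYS:
        if key not in data:
            problems.append(f"missing train_data.{key}")
    # The two path-valued fields should point at existing locations.
    for key in ("csv_path", "video_folder"):
        path = data.get(key)
        if path and not Path(path).exists():
            problems.append(f"{key} does not exist: {path}")
    return problems
```

An empty return list means the paths resolve and the required keys are present.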
### Training
To train motion modules:
```
torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/training.yaml
```
To fine-tune the UNet's image layers:
```
torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/image_finetune.yaml
```
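Both commands share the same `torchrun` shape and differ only in the config file; to scale to more GPUs you raise `--nproc_per_node`. A small sketch of building that command line programmatically (e.g. for a launcher script; the defaults mirror the single-GPU commands above):

```python
def torchrun_cmd(config, nproc_per_node=1, nnodes=1):
    """Build the torchrun argv for train.py.

    The node/process counts are placeholders to adjust for your
    hardware; defaults match the single-GPU commands above.
    """
    return [
        "torchrun",
        f"--nnodes={nnodes}",
        f"--nproc_per_node={nproc_per_node}",
        "train.py",
        "--config", config,
    ]
```

The resulting list can be passed directly to `subprocess.run`.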
## Gradio Demo
We have created a Gradio demo to make AnimateDiff easier to use. To launch the demo, please run the following commands:
```