remove redundant code

This commit is contained in:
gaomingqi
2023-04-12 13:21:43 +08:00
parent 9f30e59c45
commit caf539d3ca
150 changed files with 25 additions and 8484 deletions

.gitignore (vendored, new file)

@@ -0,0 +1,4 @@
__pycache__/
.vscode/
docs/

README.md (new file)

@@ -0,0 +1,11 @@
# Track-Anything
## Demo
***
This is Demo
## Get Started
***
This is Get Started.
## Acknowledgement
***
The project is based on [XMem](https://github.com/hkchengrex/XMem) and [Segment Anything](https://github.com/facebookresearch/segment-anything). Thanks to the authors for their efforts.


@@ -1,58 +0,0 @@
# Interactive GUI for Demo
First, set up the required packages following [GETTING STARTED.md](./GETTING_STARTED.md). You can ignore the dataset part as you won't need it for this demo. Download the pretrained models following [INFERENCE.md](./INFERENCE.md).
You will need some additional packages and pretrained models for the GUI. For the packages,
```bash
pip install -r requirements_demo.txt
```
The interactive GUI is modified from [MiVOS](https://github.com/hkchengrex/MiVOS). Specifically, we keep the "interaction-to-mask" module and replace the propagation module with XMem. The fusion module is discarded because I don't want to train it.
For interactions, we use [f-BRS](https://github.com/saic-vul/fbrs_interactive_segmentation) and [S2M](https://github.com/hkchengrex/Scribble-to-Mask). You will need their pretrained models. Use `./scripts/download_models_demo.sh` or download them manually into `./saves`.
The entry point is `interactive_demo.py`. The command line arguments should be self-explanatory.
![gui](https://imgur.com/uAImD80.jpg)
## Try it for yourself
https://user-images.githubusercontent.com/7107196/177661140-f690156b-1775-4cd7-acd7-1738a5c92f30.mp4
Right-click download this video (source: https://www.youtube.com/watch?v=FTcjzaqL0pE). Then run
```bash
python interactive_demo.py --video [path to the video] --num_objects 4
```
## Features
* Low CPU memory cost. Unlike the implementation in MiVOS, we do not load all the images as the program starts up. We load them on-the-fly with an LRU buffer (see the sketch after this list).
* Low GPU memory cost. This is provided by XMem. See the paper.
* Faster than MiVOS-STCN, especially for long videos.
* You can continue from interrupted runs. We save the resultant masks on-the-fly in the workspace directory from which annotation can be resumed. The memory bank is not saved and cannot be resumed.
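For illustration, the on-the-fly loading could look roughly like this (a minimal sketch, not the actual GUI code; the reader and cache size are placeholders):
```python
# Minimal sketch of on-the-fly frame loading with an LRU buffer.
# Illustrative only -- the path layout and cache size are placeholders.
from functools import lru_cache

import cv2  # pip install opencv-python


@lru_cache(maxsize=128)  # keep at most 128 decoded frames in CPU memory
def load_frame(frame_path: str):
    # Frames are decoded lazily on first access; scrubbing back and forth
    # over nearby frames hits the cache instead of re-reading the disk.
    image = cv2.imread(frame_path, cv2.IMREAD_COLOR)
    if image is None:
        raise FileNotFoundError(frame_path)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
```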
## Controls
* Use the slider to change the current frame. "Play Video" automatically progresses the video.
* Select interaction type: "scribble", "click", or "free". Both scribble and "free" (free-hand drawing) modify an existing mask. Using "click" on an existing object mask (i.e., a mask from propagation or other interaction methods) will reset the mask. This is because f-BRS does not take an existing mask as input.
* Select the target object using the number keys. "1" corresponds to the first object, etc. You need to specify the maximum number of objects when you start the program through the command line.
* Use propagate forward/backward to let XMem do the job. Pause when correction is needed. It only stops automatically when it hits the end of the video.
* Make sure all objects are correctly labeled before propagating. The program doesn't care which object you have interacted with -- it treats everything as user-provided inputs. Not labelling an object implicitly means that it is part of the background.
* The memory bank might be "polluted" by bad memory frames. Feel free to hit clear memory to erase that. Propagation runs faster with a small memory bank.
* All output masks are automatically saved in the workspace directory, which is printed when the program starts.
* You can load an external mask for the current frame using "Import mask".
* For "layered insertion" (e.g., the breakdance demo), use the "layered" overlay mode. You can load a custom layer using "Import layer". The layer should be an RGBA png file. RGB image files are also accepted -- the alpha channel will be filled with ones.
* The "save overlay during propagation" checkbox does exactly that. It does not save the overlay when the user is just scrubbing the timeline.
* For "popup" and "layered", the visualizations during propagation (and the saved overlays) have higher quality than when the user is scrubbing the timeline. This is because we have access to the soft probability mask during propagation.
* Both "popup" and "layered" need a binary mask. By default, the first object mask is used. You can change the target (or make the target a union of objects) using the middle mouse key.
## FAQ
1. Why can't I label object 2 after pressing the number '2'?
- Make sure you specified `--num_objects`. We ignore object IDs that exceed `num_objects`.
2. The GUI feels slow!
- The GUI needs to read/write images and masks on the go. Ideally this could be implemented with multiple look-ahead threads, but I didn't (a rough sketch follows at the end of this FAQ). The overhead will be smaller if you place the `workspace` on an SSD. You can also use a RAM disk. `eval.py` will almost certainly be faster.
- It takes more time to process more objects. This depends on `num_objects`, but not the actual number of objects that the user has annotated. *This does not mean that running time is directly proportional to the number of objects. There is significant shared computation.*
3. Can I run this on a remote server?
- X11 forwarding should be possible. I have not tried this and would love to know if it works for you.
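For reference, a look-ahead reader could be sketched with a background thread and a bounded queue (illustrative only; not something the GUI actually does):
```python
# Sketch of a look-ahead frame prefetcher: a background thread decodes the
# next few frames into a bounded queue while the consumer (e.g. the GUI)
# pulls them out. Purely illustrative.
import queue
import threading

import cv2  # pip install opencv-python


def prefetch_frames(frame_paths, buffer_size=8):
    buf = queue.Queue(maxsize=buffer_size)

    def worker():
        for path in frame_paths:
            buf.put(cv2.imread(path, cv2.IMREAD_COLOR))  # blocks when the buffer is full
        buf.put(None)  # sentinel: no more frames

    threading.Thread(target=worker, daemon=True).start()
    while (frame := buf.get()) is not None:
        yield frame
```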

Binary file not shown (deleted image, 27 KiB).


@@ -1,25 +0,0 @@
# Failure Cases
Like all methods, XMem can fail. Here, we try to show some illustrative and frankly consistent failure modes that we noticed. We slowed down all videos for visualization.
## Fast motion, similar objects
The first one is fast motion with similar-looking objects that do not provide sufficient appearance clues for XMem to track. Below is an example from the YouTubeVOS validation set (0e8a6b63bb):
https://user-images.githubusercontent.com/7107196/179459162-80b65a6c-439d-4239-819f-68804d9412e9.mp4
And the source video:
https://user-images.githubusercontent.com/7107196/181700094-356284bc-e8a4-4757-ab84-1e9009fddd4b.mp4
Technically it can be solved by using more positional and motion clues. XMem is not sufficiently proficient at those.
## Shot changes; saliency shift
Ever wondered why I did not include the final scene of Chika Dance when the roach flies off? Because it failed there.
XMem seems to be attracted to any new salient object in the scene when the (true) target object is missing. By new I mean an object that did not appear (or had a different appearance) earlier in the video -- as XMem could not have a memory representation for that object. This happens a lot if the camera shot changes.
https://user-images.githubusercontent.com/7107196/179459190-d736937a-6925-4472-b46e-dcf94e1cafc0.mp4
Note that the first shot change is not as problematic.


@@ -1,64 +0,0 @@
# Getting Started
Our code is tested on Ubuntu. I have briefly tested the GUI on Windows (with a PyQt5 fix at the top of interactive_demo.py).
## Requirements
* Python 3.8+
* PyTorch 1.11+ (See [PyTorch](https://pytorch.org/) for installation instructions)
* `torchvision` corresponding to the PyTorch version
* OpenCV (try `pip install opencv-python`)
* Others: `pip install -r requirements.txt`
## Dataset
I recommend either softlinking (`ln -s`) existing data or using the provided `scripts/download_datasets.py` to structure the datasets in our format.
`python -m scripts.download_datasets`
The structure is the same as the one in STCN -- you can place XMem in the same folder as STCN and it will work.
The script uses Google Drive and sometimes fails when certain files are blocked from automatic download. You would have to do some manual work in that case.
It does not download BL30K because it is huge and we don't want to crash your hard disks.
```bash
├── XMem
├── BL30K
├── DAVIS
│ ├── 2016
│ │ ├── Annotations
│ │ └── ...
│ └── 2017
│ ├── test-dev
│ │ ├── Annotations
│ │ └── ...
│ └── trainval
│ ├── Annotations
│ └── ...
├── static
│ ├── BIG_small
│ └── ...
├── long_video_set
│ ├── long_video
│ ├── long_video_x3
│ ├── long_video_davis
│ └── ...
├── YouTube
│ ├── all_frames
│ │ └── valid_all_frames
│ ├── train
│ ├── train_480p
│ └── valid
└── YouTube2018
├── all_frames
│ └── valid_all_frames
└── valid
```
## Long-Time Video
It comes from [AFB-URR](https://github.com/xmlyqing00/AFB-URR). Please follow their license when using this data. We release our extended version (X3) and corresponding `_davis` versions such that the DAVIS evaluation can be used directly. They can be downloaded [[here]](TODO). The script above would also attempt to download it.
### BL30K
You can either use the automatic script `download_bl30k.py` or download it manually from [MiVOS](https://github.com/hkchengrex/MiVOS/#bl30k). Note that each segment is about 115GB in size -- 700GB in total. You are going to need ~1TB of free disk space to run the script (including extraction buffer).
The script uses Google Drive and sometimes fails when certain files are blocked from automatic download. You would have to do some manual work in that case.


@@ -1,110 +0,0 @@
# Inference
What is a palette? Why is the output a "colored image"? How do I make those input masks that look like color images? See [PALETTE.md](./PALETTE.md).
1. Set up the datasets following [GETTING_STARTED.md](./GETTING_STARTED.md).
2. Download the pretrained models either using `./scripts/download_models.sh`, or manually and put them in `./saves` (create the folder if it doesn't exist). You can download them from [[GitHub]](https://github.com/hkchengrex/XMem/releases/tag/v1.0) or [[Google Drive]](https://drive.google.com/drive/folders/1QYsog7zNzcxGXTGBzEhMUg8QVJwZB6D1?usp=sharing).
All command-line inference is accessed through `eval.py`. See [RESULTS.md](./RESULTS.md) for an explanation of FPS and the differences between different models.
## Usage
```
python eval.py --model [path to model file] --output [where to save the output] --dataset [which dataset to evaluate on] --split [val for validation or test for test-dev]
```
See the code for a complete list of available command-line arguments.
Examples:
(``--model`` defaults to `./saves/XMem.pth`)
DAVIS 2017 validation:
```
python eval.py --output ../output/d17 --dataset D17
```
DAVIS 2016 validation:
```
python eval.py --output ../output/d16 --dataset D16
```
DAVIS 2017 test-dev:
```
python eval.py --output ../output/d17-td --dataset D17 --split test
```
YouTubeVOS 2018 validation:
```
python eval.py --output ../output/y18 --dataset Y18
```
Long-Time Video (3X) (note that `mem_every`, aka `r`, is set differently):
```
python eval.py --output ../output/lv3 --dataset LV3 --mem_every 10
```
## Getting quantitative results
We do not provide any tools for getting quantitative results here. We used the following to get the results reported in the paper:
- DAVIS 2017 validation: [davis2017-evaluation](https://github.com/davisvideochallenge/davis2017-evaluation)
- DAVIS 2016 validation: [davis2016-evaluation](https://github.com/hkchengrex/davis2016-evaluation) (Unofficial)
- DAVIS 2017 test-dev: [CodaLab](https://competitions.codalab.org/competitions/20516#participate)
- YouTubeVOS 2018 validation: [CodaLab](https://competitions.codalab.org/competitions/19544#results)
- YouTubeVOS 2019 validation: [CodaLab](https://competitions.codalab.org/competitions/20127#participate-submit_results)
- Long-Time Video: [davis2017-evaluation](https://github.com/davisvideochallenge/davis2017-evaluation)
(For the Long-Time Video dataset, point `--davis_path` to either `long_video_davis` or `long_video_davis_x3`)
## On custom data
Structure your custom data like this:
```bash
├── custom_data_root
│ ├── JPEGImages
│ │ ├── video1
│ │ │ ├── 00001.jpg
│ │ │ ├── 00002.jpg
│ │ │ ├── ...
│ │ └── ...
│ ├── Annotations
│ │ ├── video1
│ │ │ ├── 00001.png
│ │ │ ├── ...
│ │ └── ...
```
We use `sort` to determine frame order. The annotations do not have to be complete (e.g., first-frame only is fine). We use PIL to read the annotations and `np.unique` to determine the objects. The PNG palette will be used automatically if it exists.
Then, point `--generic_path` to `custom_data_root` and specify `--dataset` as `G` (for generic).
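For illustration, a minimal sketch of how such a folder could be enumerated (hypothetical paths; this is not the project's actual data loader):
```python
# Sketch: enumerate a "generic" video folder structured as above.
# `custom_data_root` / `video1` are placeholder names.
import os

import numpy as np
from PIL import Image

root = 'custom_data_root'
video = 'video1'

# Frame order is determined by sorting the file names.
frames = sorted(os.listdir(os.path.join(root, 'JPEGImages', video)))
masks = sorted(os.listdir(os.path.join(root, 'Annotations', video)))

# Annotations may be sparse (e.g., first-frame only).
first_mask = np.array(Image.open(os.path.join(root, 'Annotations', video, masks[0])))
object_ids = np.unique(first_mask)
object_ids = object_ids[object_ids != 0]  # 0 is background
print(f'{len(frames)} frames; objects in the first mask: {object_ids.tolist()}')
```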
## Multi-scale evaluation
Multi-scale evaluation is done in two steps. We first compute and save the object probability maps for each setting independently to hard disk as `hkl` (hickle) files. Then, these maps are merged together with `merge_multi_scale.py`.
Example for DAVIS 2017 validation MS:
Step 1 (can be done in parallel with multiple GPUs):
```bash
python eval.py --output ../output/d17_ms/720p --mem_every 3 --dataset D17 --save_scores --size 720
python eval.py --output ../output/d17_ms/720p_flip --mem_every 3 --dataset D17 --save_scores --size 720 --flip
```
Step 2:
```bash
python merge_multi_scale.py --dataset D --list ../output/d17_ms/720p ../output/d17_ms/720p_flip --output ../output/d17_ms_merged
```
Instead of `--list`, you can also use `--pattern` to specify a glob pattern; how the pattern is expanded depends on your shell (e.g., `zsh` or `bash`).
## Advanced usage
To develop your own evaluation interface, see `./inference/` -- most importantly, `inference_core.py`.


@@ -1,13 +0,0 @@
# Palette
> Some image formats, such as GIF or PNG, can use a palette, which is a table of (usually) 256 colors to allow for better compression. Basically, instead of representing each pixel with its full color triplet, which takes 24bits (plus eventual 8 more for transparency), they use a 8 bit index that represent the position inside the palette, and thus the color.
-- https://docs.geoserver.org/2.22.x/en/user/tutorials/palettedimage/palettedimage.html
So those mask files that look like color images are single-channel, `uint8` arrays under the hood. When `PIL` reads them, it (correctly) gives you a two-dimensional array (`opencv` does not work AFAIK). If what you get is instead a three-dimensional `H*W*3` array, then your mask is not actually a paletted mask, but just a colored image. Reading and saving a paletted mask through `opencv` or MS Paint would destroy the palette.
Our code, when asked to generate multi-object segmentation (e.g., DAVIS 2017/YouTubeVOS), always reads and writes single-channel masks. If there is a palette in the input, we will use it in the output. The code does not care whether a palette is actually used -- we can read grayscale images just fine.
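A quick sanity check (a sketch; `mask.png` is a placeholder path):
```python
# Sketch: verify that a mask file is a paletted, single-channel image and
# re-save it without destroying the palette. `mask.png` is a placeholder.
import numpy as np
from PIL import Image

mask = Image.open('mask.png')
print(mask.mode)             # 'P' for a paletted mask; 'RGB' means a plain colored image
arr = np.array(mask)
print(arr.shape, arr.dtype)  # expect (H, W) and uint8, not (H, W, 3)

if mask.mode == 'P':
    out = Image.fromarray(arr)           # single-channel array of object IDs
    out.putpalette(mask.getpalette())    # carry the palette over
    out.save('mask_copy.png')
```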
Importantly, we use `np.unique` to determine the number of objects in the mask. This would fail if:
1. Colored images, instead of paletted masks are used.
2. The masks have "smooth" edges, produced by feathering/downsizing/compression. For example, when you draw the mask in a painting software, make sure you set the brush hardness to maximum.


@@ -1,104 +0,0 @@
# Results
## Preamble
Our code, by default, uses automatic mixed precision (AMP). Its effect on the output is negligible.
All speeds reported in the paper are recorded with AMP turned off (`--benchmark`).
Due to refactoring, there might be slight differences between the outputs produced by this code base with the precomputed results/results reported in the paper. This difference rarely leads to a change of the least significant figure (i.e., 0.1).
**For most complete results, please see the paper (and the appendix)!**
All available precomputed results can be found [[here]](https://drive.google.com/drive/folders/1UxHPXJbQLHjF5zYVn3XZCXfi_NYL81Bf?usp=sharing).
## Pretrained models
We provide four pretrained models for download:
1. XMem.pth (Default)
2. XMem-s012.pth (Trained with BL30K)
3. XMem-s2.pth (No pretraining on static images)
4. XMem-no-sensory (No sensory memory)
The model without pretraining is for reference. The model without sensory memory might be more suitable for tasks without spatial continuity, like mask tracking in a multi-camera 3D reconstruction setting, though I would encourage you to try the base model as well.
Download them from [[GitHub]](https://github.com/hkchengrex/XMem/releases/tag/v1.0) or [[Google Drive]](https://drive.google.com/drive/folders/1QYsog7zNzcxGXTGBzEhMUg8QVJwZB6D1?usp=sharing).
## Long-Time Video
[[Precomputed Results]](https://drive.google.com/drive/folders/1NADcetigH6d83mUvyb2rH4VVjwFA76Lh?usp=sharing)
### Long-Time Video (1X)
| Model | J&F | J | F |
| --- | :--:|:--:|:---:|
| XMem | 89.8±0.2 | 88.0±0.2 | 91.6±0.2 |
### Long-Time Video (3X)
| Model | J&F | J | F |
| --- | :--:|:--:|:---:|
| XMem | 90.0±0.4 | 88.2±0.3 | 91.8±0.4 |
## DAVIS
[[Precomputed Results]](https://drive.google.com/drive/folders/1XTOGevTedRSjHnFVsZyTdxJG-iHjO0Re?usp=sharing)
### DAVIS 2016
| Model | J&F | J | F | FPS | FPS (AMP) |
| --- | :--:|:--:|:---:|:---:|:---:|
| XMem | 91.5 | 90.4 | 92.7 | 29.6 | 40.3 |
| XMem-s012 | 92.0 | 90.7 | 93.2 | 29.6 | 40.3 |
| XMem-s2 | 90.8 | 89.6 | 91.9 | 29.6 | 40.3 |
### DAVIS 2017 validation
| Model | J&F | J | F | FPS | FPS (AMP) |
| --- | :--:|:--:|:---:|:---:|:---:|
| XMem | 86.2 | 82.9 | 89.5 | 22.6 | 33.9 |
| XMem-s012 | 87.7 | 84.0 | 91.4 | 22.6 | 33.9 |
| XMem-s2 | 84.5 | 81.4 | 87.6 | 22.6 | 33.9 |
| XMem-no-sensory | 85.1 | - | - | 23.1 | - |
### DAVIS 2017 test-dev
| Model | J&F | J | F |
| --- | :--:|:--:|:---:|
| XMem | 81.0 | 77.4 | 84.5 |
| XMem-s012 | 81.2 | 77.6 | 84.7 |
| XMem-s2 | 79.8 | 61.4 | 68.1 |
| XMem-s012 (600p) | 82.5 | 79.1 | 85.8 |
## YouTubeVOS
We use all available frames in YouTubeVOS by default.
See [INFERENCE.md](./INFERENCE.md) if you want to evaluate with sparse frames for some reason.
[[Precomputed Results]](https://drive.google.com/drive/folders/1P_BmOdcG6OP5mWGqWzCZrhQJ7AaLME4E?usp=sharing)
[[Precomputed Results (sparse)]](https://drive.google.com/drive/folders/1IRV1fHepufUXM45EEbtl9D4pkoh9POSZ?usp=sharing)
### YouTubeVOS 2018 validation
| Model | G | J-Seen | F-Seen | J-Unseen | F-Unseen | FPS | FPS (AMP) |
| --- | :--:|:--:|:---:|:---:|:---:|:---:|:---:|
| XMem | 85.7 | 84.6 | 89.3 | 80.2 | 88.7 | 22.6 | 31.7 |
| XMem-s012 | 86.1 | 85.1 | 89.8 | 80.3 | 89.2 | 22.6 | 31.7 |
| XMem-s2 | 84.3 | 83.9 | 88.8 | 77.7 | 86.7 | 22.6 | 31.7 |
| XMem-no-sensory | 84.4 | - | - | - | - | 23.1 | - |
### YouTubeVOS 2019 validation
| Model | G | J-Seen | F-Seen | J-Unseen | F-Unseen |
| --- | :--:|:--:|:---:|:---:|:---:|
| XMem | 85.5 | 84.3 | 88.6 | 80.3 | 88.6 |
| XMem-s012 | 85.8 | 84.8 | 89.2 | 80.3 | 88.8 |
| XMem-s2 | 84.2 | 83.8 | 88.3 | 78.1 | 86.7 |
## Multi-scale evaluation
Please see the appendix for quantitative results.
[[DAVIS-MS Precomputed Results]](https://drive.google.com/drive/folders/1H3VHKDO09izp6KR3sE-LzWbjyM-jpftn?usp=sharing)
[[YouTubeVOS-MS Precomputed Results]](https://drive.google.com/drive/folders/1ww5HVRbMKXraLd2dy1rtk6kLjEawW9Kn?usp=sharing)


@@ -1,50 +0,0 @@
# Training
First, set up the datasets following [GETTING STARTED.md](./GETTING_STARTED.md).
The model is trained progressively with different stages (0: static images; 1: BL30K; 2: longer main training; 3: shorter main training). After each stage finishes, we start the next stage by loading the latest trained weight.
For example, the base model is pretrained with static images followed by the shorter main training (s03).
To train the base model on two GPUs, you can use:
```bash
python -m torch.distributed.run --master_port 25763 --nproc_per_node=2 train.py --exp_id retrain --stage 03
```
(**NOTE**: Unexplained accuracy decrease might occur if you are not using two GPUs to train. See https://github.com/hkchengrex/XMem/issues/71.)
`master_port` needs to point to an unused port.
`nproc_per_node` refers to the number of GPUs to be used (specify `CUDA_VISIBLE_DEVICES` to select which GPUs to use).
`exp_id` is an identifier you give to this training job.
See other available command line arguments in `util/configuration.py`.
**Unlike the training code of STCN, batch sizes are effective. You don't have to adjust the batch size when you use more/fewer GPUs.**
We implemented automatic staging in this code base. You don't have to train different stages by yourself like in STCN (but that is still supported).
`stage` is a string that we split to determine the training stages. Examples include `0` (static images only), `03` (base training), `012` (with BL30K), `2` (main training only).
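Conceptually, the automatic staging amounts to something like the following sketch (`train_one_stage` is a stand-in, not a function in this code base):
```python
# Sketch of automatic staging: split the stage string and run the stages in
# sequence, feeding the weights of one stage into the next.
# `train_one_stage` is a hypothetical stand-in for the real per-stage loop.
def train_one_stage(stage: int, load_network=None) -> str:
    print(f'training stage {stage}, starting from {load_network}')
    return f'saves/stage{stage}_final.pth'  # weights produced by this stage


def train_all_stages(stage_string: str, initial_weights=None) -> str:
    weights = initial_weights
    for stage in (int(s) for s in stage_string):  # e.g. '03' -> stage 0, then stage 3
        weights = train_one_stage(stage, load_network=weights)
    return weights


print(train_all_stages('03'))
```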
You can use `tensorboard` to visualize the training process.
## Outputs
The model files and checkpoints will be saved in `./saves/[name containing datetime and exp_id]`.
`.pth` files with `_checkpoint` store the network weights, optimizer states, etc. and can be used to resume training (with `--load_checkpoint`).
Other `.pth` files store the network weights only and can be used for inference. We note that there are variations in performance across different training runs and across the last few saved models. For the base model, we most often note that main training at 107K iterations leads to the best result (full training is 110K).
We measure the median and std scores across five training runs of the base model:
| Dataset | median | std |
| --- | :--:|:--:|
| DAVIS J&F | 86.2 | 0.23 |
| YouTubeVOS 2018 G | 85.6 | 0.21 |
## Pretrained models
You can start training from scratch, or use any of our pretrained models for fine-tuning. For example, you can load our stage 0 model to skip the static-image pretraining:
```bash
python -m torch.distributed.launch --master_port 25763 --nproc_per_node=2 train.py --exp_id retrain_stage3_only --stage 3 --load_network saves/XMem-s0.pth
```
Download them from [[GitHub]](https://github.com/hkchengrex/XMem/releases/tag/v1.0) or [[Google Drive]](https://drive.google.com/drive/folders/1QYsog7zNzcxGXTGBzEhMUg8QVJwZB6D1?usp=sharing).

Binary file not shown (deleted image, 815 B).


@@ -1,174 +0,0 @@
<!DOCTYPE HTML>
<html>
<head>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-E4PHBZXG5S"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag(){dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', 'G-E4PHBZXG5S');
</script>
<link rel="preconnect" href="https://fonts.gstatic.com">
<link href="https://fonts.googleapis.com/css2?family=Roboto:wght@100;300;400&display=swap" rel="stylesheet">
<title>XMem</title>
<meta name="viewport" content="width=device-width, initial-scale=1">
<!-- CSS only -->
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.1/dist/css/bootstrap.min.css" rel="stylesheet" integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
<link href="style.css" type="text/css" rel="stylesheet" media="screen,projection"/>
</head>
<body>
<br><br><br><br>
<div class="container">
<div class="row text-center" style="font-size:38px">
<div class="col">
XMem: Long-Term Video Object Segmentation with an Atkinson-Shiffrin Memory Model
</div>
</div>
<br>
<div class="row text-center" style="font-size:28px">
<div class="col">
ECCV 2022
</div>
</div>
<br>
<div class="h-100 row text-center heavy justify-content-md-center" style="font-size:24px;">
<div class="col-sm-3">
<a href="https://hkchengrex.github.io/">Ho Kei Cheng</a>
</div>
<div class="col-sm-3">
<a href="https://www.alexander-schwing.de/">Alexander Schwing</a>
</div>
</div>
<br>
<div class="h-100 row text-center justify-content-md-center" style="font-size:20px;">
<div class="col-sm-2">
<a href="https://arxiv.org/abs/2207.07115">[arXiv]</a>
</div>
<div class="col-sm-2">
<a href="https://arxiv.org/pdf/2207.07115.pdf">[Paper]</a>
</div>
<div class="col-sm-2">
<a href="https://github.com/hkchengrex/XMem">[Code]</a>
</div>
</div>
<br>
<div class="h-100 row text-center justify-content-md-center">
<i>Interactive GUI demo available <a href="https://github.com/hkchengrex/XMem/blob/main/docs/DEMO.md">[here]</a>! </i>
<div class="col">
<a href="https://github.com/hkchengrex/XMem/blob/main/docs/DEMO.md">
<img width="60%" src="https://imgur.com/uAImD80.jpg" alt="framework">
</a>
</div>
</div>
<hr>
<div class="row" style="font-size:32px">
<div class="col">
Abstract
</div>
</div>
<br>
<div class="row">
<div class="col">
<p style="text-align: justify;">
We present XMem, a video object segmentation architecture for long videos with unified feature memory stores inspired by the Atkinson-Shiffrin memory model.
Prior work on video object segmentation typically only uses one type of feature memory. For videos longer than a minute, a single feature memory model tightly links memory consumption and accuracy.
In contrast, following the Atkinson-Shiffrin model, we develop an architecture that incorporates multiple independent yet deeply-connected feature memory stores: a rapidly updated sensory memory, a high-resolution working memory, and a compact thus sustained long-term memory.
Crucially, we develop a memory potentiation algorithm that routinely consolidates actively used working memory elements into the long-term memory, which avoids memory explosion and minimizes performance decay for long-term prediction.
Combined with a new memory reading mechanism, XMem greatly exceeds state-of-the-art performance on long-video datasets while being on par with state-of-the-art methods (that do not work on long videos) on short-video datasets.
</p>
</div>
</div>
<br>
<div class="h-100 row text-center justify-content-md-center">
<div class="col">
<img width="80%" src="https://imgur.com/ToE2frx.jpg" alt="framework">
</div>
</div>
<br>
<hr>
<br>
<div class="row" style="font-size:32px">
<div class="col">
Handling long-term occlusion
</div>
</div>
<br>
<center>
<iframe style="width:100%; aspect-ratio: 1.78;"
src="https://www.youtube.com/embed/mwOP8l3zVNw"
title="YouTube video player" frameborder="0"
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
allowfullscreen>
</iframe>
</center>
<br>
<hr>
<br>
<div class="row" style="font-size:32px">
<div class="col">
Very-long video; masked layer insertion
</div>
</div>
<br>
<center>
<iframe style="width:100%; aspect-ratio: 1.78;"
src="https://www.youtube.com/embed/9OtFvF8FiEg"
title="YouTube video player" frameborder="0"
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
allowfullscreen>
</iframe>
Source: https://www.youtube.com/watch?v=q5Xr0F4a0iU
</center>
<br>
<hr>
<br>
<div class="row" style="font-size:32px">
<div class="col">
Out-of-domain case
</div>
</div>
<br>
<center>
<video style="width: 100%" controls>
<source src="https://user-images.githubusercontent.com/7107196/177920383-161f1da1-33f9-48b3-b8b2-09e450432e2b.mp4" type="video/mp4">
Your browser does not support the video tag.
</video>
Source: かぐや様は告らせたい ~天才たちの恋愛頭脳戦~ Ep.3; A1 Pictures
</center>
<br><br>
<div style="font-size: 14px;">
Contact: Ho Kei (Rex) Cheng hkchengrex@gmail.com
<br>
</div>
<br><br>
</div>
</body>
</html>


@@ -1,59 +0,0 @@
body {
font-family: 'Roboto', sans-serif;
font-size:18px;
margin-left: auto;
margin-right: auto;
font-weight: 300;
height: 100%;
max-width: 1000px;
}
.light {
font-weight: 100;
}
.heavy {
font-weight: 400;
}
.column {
float: left;
}
.metric_table {
border-collapse: collapse;
margin-left: 15px;
margin-right: auto;
}
.metric_table th{
border-bottom: 1px solid #555;
padding-left: 15px;
padding-right: 15px;
}
.metric_table td{
padding-left: 15px;
padding-right: 15px;
}
.metric_table .left_align{
text-align: left;
}
a:link,a:visited
{
color: #05538f;
text-decoration: none;
}
a:hover {
color: #63cbdd;
}
hr
{
border: 0;
height: 1px;
background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0));
}


@@ -1,373 +0,0 @@
Mozilla Public License Version 2.0
==================================
1. Definitions
--------------
1.1. "Contributor"
means each individual or legal entity that creates, contributes to
the creation of, or owns Covered Software.
1.2. "Contributor Version"
means the combination of the Contributions of others (if any) used
by a Contributor and that particular Contributor's Contribution.
1.3. "Contribution"
means Covered Software of a particular Contributor.
1.4. "Covered Software"
means Source Code Form to which the initial Contributor has attached
the notice in Exhibit A, the Executable Form of such Source Code
Form, and Modifications of such Source Code Form, in each case
including portions thereof.
1.5. "Incompatible With Secondary Licenses"
means
(a) that the initial Contributor has attached the notice described
in Exhibit B to the Covered Software; or
(b) that the Covered Software was made available under the terms of
version 1.1 or earlier of the License, but not also under the
terms of a Secondary License.
1.6. "Executable Form"
means any form of the work other than Source Code Form.
1.7. "Larger Work"
means a work that combines Covered Software with other material, in
a separate file or files, that is not Covered Software.
1.8. "License"
means this document.
1.9. "Licensable"
means having the right to grant, to the maximum extent possible,
whether at the time of the initial grant or subsequently, any and
all of the rights conveyed by this License.
1.10. "Modifications"
means any of the following:
(a) any file in Source Code Form that results from an addition to,
deletion from, or modification of the contents of Covered
Software; or
(b) any new file in Source Code Form that contains any Covered
Software.
1.11. "Patent Claims" of a Contributor
means any patent claim(s), including without limitation, method,
process, and apparatus claims, in any patent Licensable by such
Contributor that would be infringed, but for the grant of the
License, by the making, using, selling, offering for sale, having
made, import, or transfer of either its Contributions or its
Contributor Version.
1.12. "Secondary License"
means either the GNU General Public License, Version 2.0, the GNU
Lesser General Public License, Version 2.1, the GNU Affero General
Public License, Version 3.0, or any later versions of those
licenses.
1.13. "Source Code Form"
means the form of the work preferred for making modifications.
1.14. "You" (or "Your")
means an individual or a legal entity exercising rights under this
License. For legal entities, "You" includes any entity that
controls, is controlled by, or is under common control with You. For
purposes of this definition, "control" means (a) the power, direct
or indirect, to cause the direction or management of such entity,
whether by contract or otherwise, or (b) ownership of more than
fifty percent (50%) of the outstanding shares or beneficial
ownership of such entity.
2. License Grants and Conditions
--------------------------------
2.1. Grants
Each Contributor hereby grants You a world-wide, royalty-free,
non-exclusive license:
(a) under intellectual property rights (other than patent or trademark)
Licensable by such Contributor to use, reproduce, make available,
modify, display, perform, distribute, and otherwise exploit its
Contributions, either on an unmodified basis, with Modifications, or
as part of a Larger Work; and
(b) under Patent Claims of such Contributor to make, use, sell, offer
for sale, have made, import, and otherwise transfer either its
Contributions or its Contributor Version.
2.2. Effective Date
The licenses granted in Section 2.1 with respect to any Contribution
become effective for each Contribution on the date the Contributor first
distributes such Contribution.
2.3. Limitations on Grant Scope
The licenses granted in this Section 2 are the only rights granted under
this License. No additional rights or licenses will be implied from the
distribution or licensing of Covered Software under this License.
Notwithstanding Section 2.1(b) above, no patent license is granted by a
Contributor:
(a) for any code that a Contributor has removed from Covered Software;
or
(b) for infringements caused by: (i) Your and any other third party's
modifications of Covered Software, or (ii) the combination of its
Contributions with other software (except as part of its Contributor
Version); or
(c) under Patent Claims infringed by Covered Software in the absence of
its Contributions.
This License does not grant any rights in the trademarks, service marks,
or logos of any Contributor (except as may be necessary to comply with
the notice requirements in Section 3.4).
2.4. Subsequent Licenses
No Contributor makes additional grants as a result of Your choice to
distribute the Covered Software under a subsequent version of this
License (see Section 10.2) or under the terms of a Secondary License (if
permitted under the terms of Section 3.3).
2.5. Representation
Each Contributor represents that the Contributor believes its
Contributions are its original creation(s) or it has sufficient rights
to grant the rights to its Contributions conveyed by this License.
2.6. Fair Use
This License is not intended to limit any rights You have under
applicable copyright doctrines of fair use, fair dealing, or other
equivalents.
2.7. Conditions
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
in Section 2.1.
3. Responsibilities
-------------------
3.1. Distribution of Source Form
All distribution of Covered Software in Source Code Form, including any
Modifications that You create or to which You contribute, must be under
the terms of this License. You must inform recipients that the Source
Code Form of the Covered Software is governed by the terms of this
License, and how they can obtain a copy of this License. You may not
attempt to alter or restrict the recipients' rights in the Source Code
Form.
3.2. Distribution of Executable Form
If You distribute Covered Software in Executable Form then:
(a) such Covered Software must also be made available in Source Code
Form, as described in Section 3.1, and You must inform recipients of
the Executable Form how they can obtain a copy of such Source Code
Form by reasonable means in a timely manner, at a charge no more
than the cost of distribution to the recipient; and
(b) You may distribute such Executable Form under the terms of this
License, or sublicense it under different terms, provided that the
license for the Executable Form does not attempt to limit or alter
the recipients' rights in the Source Code Form under this License.
3.3. Distribution of a Larger Work
You may create and distribute a Larger Work under terms of Your choice,
provided that You also comply with the requirements of this License for
the Covered Software. If the Larger Work is a combination of Covered
Software with a work governed by one or more Secondary Licenses, and the
Covered Software is not Incompatible With Secondary Licenses, this
License permits You to additionally distribute such Covered Software
under the terms of such Secondary License(s), so that the recipient of
the Larger Work may, at their option, further distribute the Covered
Software under the terms of either this License or such Secondary
License(s).
3.4. Notices
You may not remove or alter the substance of any license notices
(including copyright notices, patent notices, disclaimers of warranty,
or limitations of liability) contained within the Source Code Form of
the Covered Software, except that You may alter any license notices to
the extent required to remedy known factual inaccuracies.
3.5. Application of Additional Terms
You may choose to offer, and to charge a fee for, warranty, support,
indemnity or liability obligations to one or more recipients of Covered
Software. However, You may do so only on Your own behalf, and not on
behalf of any Contributor. You must make it absolutely clear that any
such warranty, support, indemnity, or liability obligation is offered by
You alone, and You hereby agree to indemnify every Contributor for any
liability incurred by such Contributor as a result of warranty, support,
indemnity or liability terms You offer. You may include additional
disclaimers of warranty and limitations of liability specific to any
jurisdiction.
4. Inability to Comply Due to Statute or Regulation
---------------------------------------------------
If it is impossible for You to comply with any of the terms of this
License with respect to some or all of the Covered Software due to
statute, judicial order, or regulation then You must: (a) comply with
the terms of this License to the maximum extent possible; and (b)
describe the limitations and the code they affect. Such description must
be placed in a text file included with all distributions of the Covered
Software under this License. Except to the extent prohibited by statute
or regulation, such description must be sufficiently detailed for a
recipient of ordinary skill to be able to understand it.
5. Termination
--------------
5.1. The rights granted under this License will terminate automatically
if You fail to comply with any of its terms. However, if You become
compliant, then the rights granted under this License from a particular
Contributor are reinstated (a) provisionally, unless and until such
Contributor explicitly and finally terminates Your grants, and (b) on an
ongoing basis, if such Contributor fails to notify You of the
non-compliance by some reasonable means prior to 60 days after You have
come back into compliance. Moreover, Your grants from a particular
Contributor are reinstated on an ongoing basis if such Contributor
notifies You of the non-compliance by some reasonable means, this is the
first time You have received notice of non-compliance with this License
from such Contributor, and You become compliant prior to 30 days after
Your receipt of the notice.
5.2. If You initiate litigation against any entity by asserting a patent
infringement claim (excluding declaratory judgment actions,
counter-claims, and cross-claims) alleging that a Contributor Version
directly or indirectly infringes any patent, then the rights granted to
You by any and all Contributors for the Covered Software under Section
2.1 of this License shall terminate.
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
end user license agreements (excluding distributors and resellers) which
have been validly granted by You or Your distributors under this License
prior to termination shall survive termination.
************************************************************************
* *
* 6. Disclaimer of Warranty *
* ------------------------- *
* *
* Covered Software is provided under this License on an "as is" *
* basis, without warranty of any kind, either expressed, implied, or *
* statutory, including, without limitation, warranties that the *
* Covered Software is free of defects, merchantable, fit for a *
* particular purpose or non-infringing. The entire risk as to the *
* quality and performance of the Covered Software is with You. *
* Should any Covered Software prove defective in any respect, You *
* (not any Contributor) assume the cost of any necessary servicing, *
* repair, or correction. This disclaimer of warranty constitutes an *
* essential part of this License. No use of any Covered Software is *
* authorized under this License except under this disclaimer. *
* *
************************************************************************
************************************************************************
* *
* 7. Limitation of Liability *
* -------------------------- *
* *
* Under no circumstances and under no legal theory, whether tort *
* (including negligence), contract, or otherwise, shall any *
* Contributor, or anyone who distributes Covered Software as *
* permitted above, be liable to You for any direct, indirect, *
* special, incidental, or consequential damages of any character *
* including, without limitation, damages for lost profits, loss of *
* goodwill, work stoppage, computer failure or malfunction, or any *
* and all other commercial damages or losses, even if such party *
* shall have been informed of the possibility of such damages. This *
* limitation of liability shall not apply to liability for death or *
* personal injury resulting from such party's negligence to the *
* extent applicable law prohibits such limitation. Some *
* jurisdictions do not allow the exclusion or limitation of *
* incidental or consequential damages, so this exclusion and *
* limitation may not apply to You. *
* *
************************************************************************
8. Litigation
-------------
Any litigation relating to this License may be brought only in the
courts of a jurisdiction where the defendant maintains its principal
place of business and such litigation shall be governed by laws of that
jurisdiction, without reference to its conflict-of-law provisions.
Nothing in this Section shall prevent a party's ability to bring
cross-claims or counter-claims.
9. Miscellaneous
----------------
This License represents the complete agreement concerning the subject
matter hereof. If any provision of this License is held to be
unenforceable, such provision shall be reformed only to the extent
necessary to make it enforceable. Any law or regulation which provides
that the language of a contract shall be construed against the drafter
shall not be used to construe this License against a Contributor.
10. Versions of the License
---------------------------
10.1. New Versions
Mozilla Foundation is the license steward. Except as provided in Section
10.3, no one other than the license steward has the right to modify or
publish new versions of this License. Each version will be given a
distinguishing version number.
10.2. Effect of New Versions
You may distribute the Covered Software under the terms of the version
of the License under which You originally received the Covered Software,
or under the terms of any subsequent version published by the license
steward.
10.3. Modified Versions
If you create software not governed by this License, and you want to
create a new license for such software, you may create and use a
modified version of this License if you rename the license and remove
any references to the name of the license steward (except to note that
such modified license differs from this License).
10.4. Distributing Source Code Form that is Incompatible With Secondary
Licenses
If You choose to distribute Source Code Form that is Incompatible With
Secondary Licenses under the terms of this version of the License, the
notice described in Exhibit B of this License must be attached.
Exhibit A - Source Code Form License Notice
-------------------------------------------
This Source Code Form is subject to the terms of the Mozilla Public
License, v. 2.0. If a copy of the MPL was not distributed with this
file, You can obtain one at http://mozilla.org/MPL/2.0/.
If it is not possible or desirable to put the notice in a particular
file, then You may include the notice in a location (such as a LICENSE
file in a relevant directory) where a recipient would be likely to look
for such a notice.
You may add additional accurate notices of copyright ownership.
Exhibit B - "Incompatible With Secondary Licenses" Notice
---------------------------------------------------------
This Source Code Form is "Incompatible With Secondary Licenses", as
defined by the Mozilla Public License, v. 2.0.


@@ -1,103 +0,0 @@
import torch
from ..fbrs.inference import clicker
from ..fbrs.inference.predictors import get_predictor
class InteractiveController:
def __init__(self, net, device, predictor_params, prob_thresh=0.5):
self.net = net.to(device)
self.prob_thresh = prob_thresh
self.clicker = clicker.Clicker()
self.states = []
self.probs_history = []
self.object_count = 0
self._result_mask = None
self.image = None
self.predictor = None
self.device = device
self.predictor_params = predictor_params
self.reset_predictor()
def set_image(self, image):
self.image = image
self._result_mask = torch.zeros(image.shape[-2:], dtype=torch.uint8)
self.object_count = 0
self.reset_last_object()
def add_click(self, x, y, is_positive):
self.states.append({
'clicker': self.clicker.get_state(),
'predictor': self.predictor.get_states()
})
click = clicker.Click(is_positive=is_positive, coords=(y, x))
self.clicker.add_click(click)
pred = self.predictor.get_prediction(self.clicker)
torch.cuda.empty_cache()
if self.probs_history:
self.probs_history.append((self.probs_history[-1][0], pred))
else:
self.probs_history.append((torch.zeros_like(pred), pred))
def undo_click(self):
if not self.states:
return
prev_state = self.states.pop()
self.clicker.set_state(prev_state['clicker'])
self.predictor.set_states(prev_state['predictor'])
self.probs_history.pop()
def partially_finish_object(self):
object_prob = self.current_object_prob
if object_prob is None:
return
self.probs_history.append((object_prob, torch.zeros_like(object_prob)))
self.states.append(self.states[-1])
self.clicker.reset_clicks()
self.reset_predictor()
def finish_object(self):
object_prob = self.current_object_prob
if object_prob is None:
return
self.object_count += 1
object_mask = object_prob > self.prob_thresh
self._result_mask[object_mask] = self.object_count
self.reset_last_object()
def reset_last_object(self):
self.states = []
self.probs_history = []
self.clicker.reset_clicks()
self.reset_predictor()
def reset_predictor(self, predictor_params=None):
if predictor_params is not None:
self.predictor_params = predictor_params
self.predictor = get_predictor(self.net, device=self.device,
**self.predictor_params)
if self.image is not None:
self.predictor.set_input_image(self.image)
@property
def current_object_prob(self):
if self.probs_history:
current_prob_total, current_prob_additive = self.probs_history[-1]
return torch.maximum(current_prob_total, current_prob_additive)
else:
return None
@property
def is_incomplete_mask(self):
return len(self.probs_history) > 0
@property
def result_mask(self):
return self._result_mask.clone()


@@ -1,103 +0,0 @@
from collections import namedtuple
import numpy as np
from copy import deepcopy
from scipy.ndimage import distance_transform_edt
Click = namedtuple('Click', ['is_positive', 'coords'])
class Clicker(object):
def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1):
if gt_mask is not None:
self.gt_mask = gt_mask == 1
self.not_ignore_mask = gt_mask != ignore_label
else:
self.gt_mask = None
self.reset_clicks()
if init_clicks is not None:
for click in init_clicks:
self.add_click(click)
def make_next_click(self, pred_mask):
assert self.gt_mask is not None
click = self._get_click(pred_mask)
self.add_click(click)
def get_clicks(self, clicks_limit=None):
return self.clicks_list[:clicks_limit]
def _get_click(self, pred_mask, padding=True):
fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask)
fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask)
if padding:
fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant')
fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant')
fn_mask_dt = distance_transform_edt(fn_mask)
fp_mask_dt = distance_transform_edt(fp_mask)
if padding:
fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
fp_mask_dt = fp_mask_dt[1:-1, 1:-1]
fn_mask_dt = fn_mask_dt * self.not_clicked_map
fp_mask_dt = fp_mask_dt * self.not_clicked_map
fn_max_dist = np.max(fn_mask_dt)
fp_max_dist = np.max(fp_mask_dt)
is_positive = fn_max_dist > fp_max_dist
if is_positive:
coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist) # coords is [y, x]
else:
coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist) # coords is [y, x]
return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0]))
def add_click(self, click):
coords = click.coords
if click.is_positive:
self.num_pos_clicks += 1
else:
self.num_neg_clicks += 1
self.clicks_list.append(click)
if self.gt_mask is not None:
self.not_clicked_map[coords[0], coords[1]] = False
def _remove_last_click(self):
click = self.clicks_list.pop()
coords = click.coords
if click.is_positive:
self.num_pos_clicks -= 1
else:
self.num_neg_clicks -= 1
if self.gt_mask is not None:
self.not_clicked_map[coords[0], coords[1]] = True
def reset_clicks(self):
if self.gt_mask is not None:
self.not_clicked_map = np.ones_like(self.gt_mask, dtype=bool)  # np.bool is deprecated/removed in newer NumPy
self.num_pos_clicks = 0
self.num_neg_clicks = 0
self.clicks_list = []
def get_state(self):
return deepcopy(self.clicks_list)
def set_state(self, state):
self.reset_clicks()
for click in state:
self.add_click(click)
def __len__(self):
return len(self.clicks_list)


@@ -1,56 +0,0 @@
from time import time
import numpy as np
import torch
from ..inference import utils
from ..inference.clicker import Clicker
try:
get_ipython()
from tqdm import tqdm_notebook as tqdm
except NameError:
from tqdm import tqdm
def evaluate_dataset(dataset, predictor, oracle_eval=False, **kwargs):
all_ious = []
start_time = time()
for index in tqdm(range(len(dataset)), leave=False):
sample = dataset.get_sample(index)
item = dataset[index]
if oracle_eval:
gt_mask = torch.tensor(sample['instances_mask'], dtype=torch.float32)
gt_mask = gt_mask.unsqueeze(0).unsqueeze(0)
predictor.opt_functor.mask_loss.set_gt_mask(gt_mask)
_, sample_ious, _ = evaluate_sample(item['images'], sample['instances_mask'], predictor, **kwargs)
all_ious.append(sample_ious)
end_time = time()
elapsed_time = end_time - start_time
return all_ious, elapsed_time
def evaluate_sample(image_nd, instances_mask, predictor, max_iou_thr,
pred_thr=0.49, max_clicks=20):
clicker = Clicker(gt_mask=instances_mask)
pred_mask = np.zeros_like(instances_mask)
ious_list = []
with torch.no_grad():
predictor.set_input_image(image_nd)
for click_number in range(max_clicks):
clicker.make_next_click(pred_mask)
pred_probs = predictor.get_prediction(clicker)
pred_mask = pred_probs > pred_thr
iou = utils.get_iou(instances_mask, pred_mask)
ious_list.append(iou)
if iou >= max_iou_thr:
break
return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs


@@ -1,95 +0,0 @@
from .base import BasePredictor
from .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor
from .brs_functors import InputOptimizer, ScaleBiasOptimizer
from ..transforms import ZoomIn
from ...model.is_hrnet_model import DistMapsHRNetModel
def get_predictor(net, brs_mode, device,
prob_thresh=0.49,
with_flip=True,
zoom_in_params=dict(),
predictor_params=None,
brs_opt_func_params=None,
lbfgs_params=None):
lbfgs_params_ = {
'm': 20,
'factr': 0,
'pgtol': 1e-8,
'maxfun': 20,
}
predictor_params_ = {
'optimize_after_n_clicks': 1
}
if zoom_in_params is not None:
zoom_in = ZoomIn(**zoom_in_params)
else:
zoom_in = None
if lbfgs_params is not None:
lbfgs_params_.update(lbfgs_params)
lbfgs_params_['maxiter'] = 2 * lbfgs_params_['maxfun']
if brs_opt_func_params is None:
brs_opt_func_params = dict()
if brs_mode == 'NoBRS':
if predictor_params is not None:
predictor_params_.update(predictor_params)
predictor = BasePredictor(net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_)
elif brs_mode.startswith('f-BRS'):
predictor_params_.update({
'net_clicks_limit': 8,
})
if predictor_params is not None:
predictor_params_.update(predictor_params)
insertion_mode = {
'f-BRS-A': 'after_c4',
'f-BRS-B': 'after_aspp',
'f-BRS-C': 'after_deeplab'
}[brs_mode]
opt_functor = ScaleBiasOptimizer(prob_thresh=prob_thresh,
with_flip=with_flip,
optimizer_params=lbfgs_params_,
**brs_opt_func_params)
if isinstance(net, DistMapsHRNetModel):
FeaturePredictor = HRNetFeatureBRSPredictor
insertion_mode = {'after_c4': 'A', 'after_aspp': 'A', 'after_deeplab': 'C'}[insertion_mode]
else:
FeaturePredictor = FeatureBRSPredictor
predictor = FeaturePredictor(net, device,
opt_functor=opt_functor,
with_flip=with_flip,
insertion_mode=insertion_mode,
zoom_in=zoom_in,
**predictor_params_)
elif brs_mode == 'RGB-BRS' or brs_mode == 'DistMap-BRS':
use_dmaps = brs_mode == 'DistMap-BRS'
predictor_params_.update({
'net_clicks_limit': 5,
})
if predictor_params is not None:
predictor_params_.update(predictor_params)
opt_functor = InputOptimizer(prob_thresh=prob_thresh,
with_flip=with_flip,
optimizer_params=lbfgs_params_,
**brs_opt_func_params)
predictor = InputBRSPredictor(net, device,
optimize_target='dmaps' if use_dmaps else 'rgb',
opt_functor=opt_functor,
with_flip=with_flip,
zoom_in=zoom_in,
**predictor_params_)
else:
raise NotImplementedError
return predictor


@@ -1,100 +0,0 @@
import torch
import torch.nn.functional as F
from ..transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide
class BasePredictor(object):
def __init__(self, net, device,
net_clicks_limit=None,
with_flip=False,
zoom_in=None,
max_size=None,
**kwargs):
self.net = net
self.with_flip = with_flip
self.net_clicks_limit = net_clicks_limit
self.original_image = None
self.device = device
self.zoom_in = zoom_in
self.transforms = [zoom_in] if zoom_in is not None else []
if max_size is not None:
self.transforms.append(LimitLongestSide(max_size=max_size))
self.transforms.append(SigmoidForPred())
if with_flip:
self.transforms.append(AddHorizontalFlip())
def set_input_image(self, image_nd):
for transform in self.transforms:
transform.reset()
self.original_image = image_nd.to(self.device)
if len(self.original_image.shape) == 3:
self.original_image = self.original_image.unsqueeze(0)
def get_prediction(self, clicker):
clicks_list = clicker.get_clicks()
image_nd, clicks_lists, is_image_changed = self.apply_transforms(
self.original_image, [clicks_list]
)
pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed)
prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True,
size=image_nd.size()[2:])
for t in reversed(self.transforms):
prediction = t.inv_transform(prediction)
if self.zoom_in is not None and self.zoom_in.check_possible_recalculation():
print('zooming')
return self.get_prediction(clicker)
# return prediction.cpu().numpy()[0, 0]
return prediction
def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
points_nd = self.get_points_nd(clicks_lists)
return self.net(image_nd, points_nd)['instances']
def _get_transform_states(self):
return [x.get_state() for x in self.transforms]
def _set_transform_states(self, states):
assert len(states) == len(self.transforms)
for state, transform in zip(states, self.transforms):
transform.set_state(state)
def apply_transforms(self, image_nd, clicks_lists):
is_image_changed = False
for t in self.transforms:
image_nd, clicks_lists = t.transform(image_nd, clicks_lists)
is_image_changed |= t.image_changed
return image_nd, clicks_lists, is_image_changed
def get_points_nd(self, clicks_lists):
total_clicks = []
num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
num_max_points = max(num_pos_clicks + num_neg_clicks)
if self.net_clicks_limit is not None:
num_max_points = min(self.net_clicks_limit, num_max_points)
num_max_points = max(1, num_max_points)
for clicks_list in clicks_lists:
clicks_list = clicks_list[:self.net_clicks_limit]
pos_clicks = [click.coords for click in clicks_list if click.is_positive]
pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1)]
neg_clicks = [click.coords for click in clicks_list if not click.is_positive]
neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1)]
total_clicks.append(pos_clicks + neg_clicks)
return torch.tensor(total_clicks, device=self.device)
def get_states(self):
return {'transform_states': self._get_transform_states()}
def set_states(self, states):
self._set_transform_states(states['transform_states'])


@@ -1,280 +0,0 @@
import torch
import torch.nn.functional as F
import numpy as np
from scipy.optimize import fmin_l_bfgs_b
from .base import BasePredictor
from ...model.is_hrnet_model import DistMapsHRNetModel
class BRSBasePredictor(BasePredictor):
def __init__(self, model, device, opt_functor, optimize_after_n_clicks=1, **kwargs):
super().__init__(model, device, **kwargs)
self.optimize_after_n_clicks = optimize_after_n_clicks
self.opt_functor = opt_functor
self.opt_data = None
self.input_data = None
def set_input_image(self, image_nd):
super().set_input_image(image_nd)
self.opt_data = None
self.input_data = None
def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1):
pos_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32)
neg_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32)
for list_indx, clicks_list in enumerate(clicks_lists):
for click in clicks_list:
y, x = click.coords
y, x = int(round(y)), int(round(x))
y1, x1 = y - radius, x - radius
y2, x2 = y + radius + 1, x + radius + 1
if click.is_positive:
pos_clicks_map[list_indx, 0, y1:y2, x1:x2] = True
else:
neg_clicks_map[list_indx, 0, y1:y2, x1:x2] = True
with torch.no_grad():
pos_clicks_map = torch.from_numpy(pos_clicks_map).to(self.device)
neg_clicks_map = torch.from_numpy(neg_clicks_map).to(self.device)
return pos_clicks_map, neg_clicks_map
def get_states(self):
return {'transform_states': self._get_transform_states(), 'opt_data': self.opt_data}
def set_states(self, states):
self._set_transform_states(states['transform_states'])
self.opt_data = states['opt_data']
class FeatureBRSPredictor(BRSBasePredictor):
def __init__(self, model, device, opt_functor, insertion_mode='after_deeplab', **kwargs):
super().__init__(model, device, opt_functor=opt_functor, **kwargs)
self.insertion_mode = insertion_mode
self._c1_features = None
if self.insertion_mode == 'after_deeplab':
self.num_channels = model.feature_extractor.ch
elif self.insertion_mode == 'after_c4':
self.num_channels = model.feature_extractor.aspp_in_channels
elif self.insertion_mode == 'after_aspp':
self.num_channels = model.feature_extractor.ch + 32
else:
raise NotImplementedError
def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
points_nd = self.get_points_nd(clicks_lists)
pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
num_clicks = len(clicks_lists[0])
bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)
if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
self.input_data = self._get_head_input(image_nd, points_nd)
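        # Closure handed to the BRS optimizer: applies the per-channel scale and
        # bias being optimized to the cached backbone features and re-runs the head.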
def get_prediction_logits(scale, bias):
scale = scale.view(bs, -1, 1, 1)
bias = bias.view(bs, -1, 1, 1)
if self.with_flip:
scale = scale.repeat(2, 1, 1, 1)
bias = bias.repeat(2, 1, 1, 1)
scaled_backbone_features = self.input_data * scale
scaled_backbone_features = scaled_backbone_features + bias
if self.insertion_mode == 'after_c4':
x = self.net.feature_extractor.aspp(scaled_backbone_features)
x = F.interpolate(x, mode='bilinear', size=self._c1_features.size()[2:],
align_corners=True)
x = torch.cat((x, self._c1_features), dim=1)
scaled_backbone_features = self.net.feature_extractor.head(x)
elif self.insertion_mode == 'after_aspp':
scaled_backbone_features = self.net.feature_extractor.head(scaled_backbone_features)
pred_logits = self.net.head(scaled_backbone_features)
pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear',
align_corners=True)
return pred_logits
self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
if num_clicks > self.optimize_after_n_clicks:
opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data,
**self.opt_functor.optimizer_params)
self.opt_data = opt_result[0]
with torch.no_grad():
if self.opt_functor.best_prediction is not None:
opt_pred_logits = self.opt_functor.best_prediction
else:
opt_data_nd = torch.from_numpy(self.opt_data).to(self.device)
opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
opt_pred_logits = get_prediction_logits(*opt_vars)
return opt_pred_logits
def _get_head_input(self, image_nd, points):
with torch.no_grad():
coord_features = self.net.dist_maps(image_nd, points)
x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
if self.insertion_mode == 'after_c4' or self.insertion_mode == 'after_aspp':
c1, _, c3, c4 = self.net.feature_extractor.backbone(x)
c1 = self.net.feature_extractor.skip_project(c1)
if self.insertion_mode == 'after_aspp':
x = self.net.feature_extractor.aspp(c4)
x = F.interpolate(x, size=c1.size()[2:], mode='bilinear', align_corners=True)
x = torch.cat((x, c1), dim=1)
backbone_features = x
else:
backbone_features = c4
self._c1_features = c1
else:
backbone_features = self.net.feature_extractor(x)[0]
return backbone_features
class HRNetFeatureBRSPredictor(BRSBasePredictor):
def __init__(self, model, device, opt_functor, insertion_mode='A', **kwargs):
super().__init__(model, device, opt_functor=opt_functor, **kwargs)
self.insertion_mode = insertion_mode
self._c1_features = None
if self.insertion_mode == 'A':
self.num_channels = sum(k * model.feature_extractor.width for k in [1, 2, 4, 8])
elif self.insertion_mode == 'C':
self.num_channels = 2 * model.feature_extractor.ocr_width
else:
raise NotImplementedError
def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
points_nd = self.get_points_nd(clicks_lists)
pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
num_clicks = len(clicks_lists[0])
bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)
if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
self.input_data = self._get_head_input(image_nd, points_nd)
def get_prediction_logits(scale, bias):
scale = scale.view(bs, -1, 1, 1)
bias = bias.view(bs, -1, 1, 1)
if self.with_flip:
scale = scale.repeat(2, 1, 1, 1)
bias = bias.repeat(2, 1, 1, 1)
scaled_backbone_features = self.input_data * scale
scaled_backbone_features = scaled_backbone_features + bias
if self.insertion_mode == 'A':
out_aux = self.net.feature_extractor.aux_head(scaled_backbone_features)
feats = self.net.feature_extractor.conv3x3_ocr(scaled_backbone_features)
context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
feats = self.net.feature_extractor.ocr_distri_head(feats, context)
pred_logits = self.net.feature_extractor.cls_head(feats)
elif self.insertion_mode == 'C':
pred_logits = self.net.feature_extractor.cls_head(scaled_backbone_features)
else:
raise NotImplementedError
pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear',
align_corners=True)
return pred_logits
self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
if num_clicks > self.optimize_after_n_clicks:
opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data,
**self.opt_functor.optimizer_params)
self.opt_data = opt_result[0]
with torch.no_grad():
if self.opt_functor.best_prediction is not None:
opt_pred_logits = self.opt_functor.best_prediction
else:
opt_data_nd = torch.from_numpy(self.opt_data).to(self.device)
opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
opt_pred_logits = get_prediction_logits(*opt_vars)
return opt_pred_logits
def _get_head_input(self, image_nd, points):
with torch.no_grad():
coord_features = self.net.dist_maps(image_nd, points)
x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
feats = self.net.feature_extractor.compute_hrnet_feats(x)
if self.insertion_mode == 'A':
backbone_features = feats
elif self.insertion_mode == 'C':
out_aux = self.net.feature_extractor.aux_head(feats)
feats = self.net.feature_extractor.conv3x3_ocr(feats)
context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
backbone_features = self.net.feature_extractor.ocr_distri_head(feats, context)
else:
raise NotImplementedError
return backbone_features
class InputBRSPredictor(BRSBasePredictor):
def __init__(self, model, device, opt_functor, optimize_target='rgb', **kwargs):
super().__init__(model, device, opt_functor=opt_functor, **kwargs)
self.optimize_target = optimize_target
def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
points_nd = self.get_points_nd(clicks_lists)
pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
num_clicks = len(clicks_lists[0])
if self.opt_data is None or is_image_changed:
opt_channels = 2 if self.optimize_target == 'dmaps' else 3
bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
self.opt_data = torch.zeros((bs, opt_channels, image_nd.shape[2], image_nd.shape[3]),
device=self.device, dtype=torch.float32)
def get_prediction_logits(opt_bias):
input_image = image_nd
if self.optimize_target == 'rgb':
input_image = input_image + opt_bias
dmaps = self.net.dist_maps(input_image, points_nd)
if self.optimize_target == 'dmaps':
dmaps = dmaps + opt_bias
x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1))
if self.optimize_target == 'all':
x = x + opt_bias
if isinstance(self.net, DistMapsHRNetModel):
pred_logits = self.net.feature_extractor(x)[0]
else:
backbone_features = self.net.feature_extractor(x)
pred_logits = self.net.head(backbone_features[0])
pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', align_corners=True)
return pred_logits
self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device,
shape=self.opt_data.shape)
if num_clicks > self.optimize_after_n_clicks:
opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data.cpu().numpy().ravel(),
**self.opt_functor.optimizer_params)
self.opt_data = torch.from_numpy(opt_result[0]).view(self.opt_data.shape).to(self.device)
with torch.no_grad():
if self.opt_functor.best_prediction is not None:
opt_pred_logits = self.opt_functor.best_prediction
else:
opt_vars, _ = self.opt_functor.unpack_opt_params(self.opt_data)
opt_pred_logits = get_prediction_logits(*opt_vars)
return opt_pred_logits


@@ -1,109 +0,0 @@
import torch
import numpy as np
from ...model.metrics import _compute_iou
from .brs_losses import BRSMaskLoss
class BaseOptimizer:
def __init__(self, optimizer_params,
prob_thresh=0.49,
reg_weight=1e-3,
min_iou_diff=0.01,
brs_loss=BRSMaskLoss(),
with_flip=False,
flip_average=False,
**kwargs):
self.brs_loss = brs_loss
self.optimizer_params = optimizer_params
self.prob_thresh = prob_thresh
self.reg_weight = reg_weight
self.min_iou_diff = min_iou_diff
self.with_flip = with_flip
self.flip_average = flip_average
self.best_prediction = None
self._get_prediction_logits = None
self._opt_shape = None
self._best_loss = None
self._click_masks = None
self._last_mask = None
self.device = None
def init_click(self, get_prediction_logits, pos_mask, neg_mask, device, shape=None):
self.best_prediction = None
self._get_prediction_logits = get_prediction_logits
self._click_masks = (pos_mask, neg_mask)
self._opt_shape = shape
self._last_mask = None
self.device = device
def __call__(self, x):
opt_params = torch.from_numpy(x).float().to(self.device)
opt_params.requires_grad_(True)
with torch.enable_grad():
opt_vars, reg_loss = self.unpack_opt_params(opt_params)
result_before_sigmoid = self._get_prediction_logits(*opt_vars)
result = torch.sigmoid(result_before_sigmoid)
pos_mask, neg_mask = self._click_masks
if self.with_flip and self.flip_average:
result, result_flipped = torch.chunk(result, 2, dim=0)
result = 0.5 * (result + torch.flip(result_flipped, dims=[3]))
pos_mask, neg_mask = pos_mask[:result.shape[0]], neg_mask[:result.shape[0]]
loss, f_max_pos, f_max_neg = self.brs_loss(result, pos_mask, neg_mask)
loss = loss + reg_loss
f_val = loss.detach().cpu().numpy()
if self.best_prediction is None or f_val < self._best_loss:
self.best_prediction = result_before_sigmoid.detach()
self._best_loss = f_val
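        # Return a zero gradient (stop optimizing) once every click is satisfied
        # or the binarized mask has barely changed since the previous evaluation.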
if f_max_pos < (1 - self.prob_thresh) and f_max_neg < self.prob_thresh:
return [f_val, np.zeros_like(x)]
current_mask = result > self.prob_thresh
if self._last_mask is not None and self.min_iou_diff > 0:
diff_iou = _compute_iou(current_mask, self._last_mask)
if len(diff_iou) > 0 and diff_iou.mean() > 1 - self.min_iou_diff:
return [f_val, np.zeros_like(x)]
self._last_mask = current_mask
loss.backward()
f_grad = opt_params.grad.cpu().numpy().ravel().astype(np.float32)
return [f_val, f_grad]
def unpack_opt_params(self, opt_params):
raise NotImplementedError
class InputOptimizer(BaseOptimizer):
def unpack_opt_params(self, opt_params):
opt_params = opt_params.view(self._opt_shape)
if self.with_flip:
opt_params_flipped = torch.flip(opt_params, dims=[3])
opt_params = torch.cat([opt_params, opt_params_flipped], dim=0)
reg_loss = self.reg_weight * torch.sum(opt_params**2)
return (opt_params,), reg_loss
class ScaleBiasOptimizer(BaseOptimizer):
def __init__(self, *args, scale_act=None, reg_bias_weight=10.0, **kwargs):
super().__init__(*args, **kwargs)
self.scale_act = scale_act
self.reg_bias_weight = reg_bias_weight
def unpack_opt_params(self, opt_params):
scale, bias = torch.chunk(opt_params, 2, dim=0)
reg_loss = self.reg_weight * (torch.sum(scale**2) + self.reg_bias_weight * torch.sum(bias**2))
if self.scale_act == 'tanh':
scale = torch.tanh(scale)
elif self.scale_act == 'sin':
scale = torch.sin(scale)
return (1 + scale, bias), reg_loss


@@ -1,58 +0,0 @@
import torch
from ...model.losses import SigmoidBinaryCrossEntropyLoss
class BRSMaskLoss(torch.nn.Module):
def __init__(self, eps=1e-5):
super().__init__()
self._eps = eps
def forward(self, result, pos_mask, neg_mask):
pos_diff = (1 - result) * pos_mask
pos_target = torch.sum(pos_diff ** 2)
pos_target = pos_target / (torch.sum(pos_mask) + self._eps)
neg_diff = result * neg_mask
neg_target = torch.sum(neg_diff ** 2)
neg_target = neg_target / (torch.sum(neg_mask) + self._eps)
loss = pos_target + neg_target
with torch.no_grad():
f_max_pos = torch.max(torch.abs(pos_diff)).item()
f_max_neg = torch.max(torch.abs(neg_diff)).item()
return loss, f_max_pos, f_max_neg
class OracleMaskLoss(torch.nn.Module):
def __init__(self):
super().__init__()
self.gt_mask = None
self.loss = SigmoidBinaryCrossEntropyLoss(from_sigmoid=True)
self.predictor = None
self.history = []
def set_gt_mask(self, gt_mask):
self.gt_mask = gt_mask
self.history = []
def forward(self, result, pos_mask, neg_mask):
gt_mask = self.gt_mask.to(result.device)
if self.predictor.object_roi is not None:
r1, r2, c1, c2 = self.predictor.object_roi[:4]
gt_mask = gt_mask[:, :, r1:r2 + 1, c1:c2 + 1]
gt_mask = torch.nn.functional.interpolate(gt_mask, result.size()[2:], mode='bilinear', align_corners=True)
if result.shape[0] == 2:
gt_mask_flipped = torch.flip(gt_mask, dims=[3])
gt_mask = torch.cat([gt_mask, gt_mask_flipped], dim=0)
loss = self.loss(result, gt_mask)
self.history.append(loss.detach().cpu().numpy()[0])
if len(self.history) > 5 and abs(self.history[-5] - self.history[-1]) < 1e-5:
return 0, 0, 0
return loss, 1.0, 1.0


@@ -1,5 +0,0 @@
from .base import SigmoidForPred
from .flip import AddHorizontalFlip
from .zoom_in import ZoomIn
from .limit_longest_side import LimitLongestSide
from .crops import Crops


@@ -1,38 +0,0 @@
import torch
class BaseTransform(object):
def __init__(self):
self.image_changed = False
def transform(self, image_nd, clicks_lists):
raise NotImplementedError
def inv_transform(self, prob_map):
raise NotImplementedError
def reset(self):
raise NotImplementedError
def get_state(self):
raise NotImplementedError
def set_state(self, state):
raise NotImplementedError
class SigmoidForPred(BaseTransform):
def transform(self, image_nd, clicks_lists):
return image_nd, clicks_lists
def inv_transform(self, prob_map):
return torch.sigmoid(prob_map)
def reset(self):
pass
def get_state(self):
return None
def set_state(self, state):
pass


@@ -1,97 +0,0 @@
import math
import torch
import numpy as np
from ...inference.clicker import Click
from .base import BaseTransform
class Crops(BaseTransform):
def __init__(self, crop_size=(320, 480), min_overlap=0.2):
super().__init__()
self.crop_height, self.crop_width = crop_size
self.min_overlap = min_overlap
self.x_offsets = None
self.y_offsets = None
self._counts = None
def transform(self, image_nd, clicks_lists):
assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
image_height, image_width = image_nd.shape[2:4]
self._counts = None
if image_height < self.crop_height or image_width < self.crop_width:
return image_nd, clicks_lists
self.x_offsets = get_offsets(image_width, self.crop_width, self.min_overlap)
self.y_offsets = get_offsets(image_height, self.crop_height, self.min_overlap)
self._counts = np.zeros((image_height, image_width))
image_crops = []
for dy in self.y_offsets:
for dx in self.x_offsets:
self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1
image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width]
image_crops.append(image_crop)
image_crops = torch.cat(image_crops, dim=0)
self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32)
clicks_list = clicks_lists[0]
clicks_lists = []
for dy in self.y_offsets:
for dx in self.x_offsets:
crop_clicks = [Click(is_positive=x.is_positive, coords=(x.coords[0] - dy, x.coords[1] - dx))
for x in clicks_list]
clicks_lists.append(crop_clicks)
return image_crops, clicks_lists
def inv_transform(self, prob_map):
if self._counts is None:
return prob_map
new_prob_map = torch.zeros((1, 1, *self._counts.shape),
dtype=prob_map.dtype, device=prob_map.device)
crop_indx = 0
for dy in self.y_offsets:
for dx in self.x_offsets:
new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0]
crop_indx += 1
new_prob_map = torch.div(new_prob_map, self._counts)
return new_prob_map
def get_state(self):
return self.x_offsets, self.y_offsets, self._counts
def set_state(self, state):
self.x_offsets, self.y_offsets, self._counts = state
def reset(self):
self.x_offsets = None
self.y_offsets = None
self._counts = None
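# Top-left offsets of sliding crops along one axis: adjacent crops overlap by at
# least min_overlap_ratio and the last crop is shifted to stay inside the image.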
def get_offsets(length, crop_size, min_overlap_ratio=0.2):
if length == crop_size:
return [0]
N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio)
N = math.ceil(N)
overlap_ratio = (N - length / crop_size) / (N - 1)
overlap_width = int(crop_size * overlap_ratio)
offsets = [0]
for i in range(1, N):
new_offset = offsets[-1] + crop_size - overlap_width
if new_offset + crop_size > length:
new_offset = length - crop_size
offsets.append(new_offset)
return offsets


@@ -1,37 +0,0 @@
import torch
from ..clicker import Click
from .base import BaseTransform
class AddHorizontalFlip(BaseTransform):
def transform(self, image_nd, clicks_lists):
assert len(image_nd.shape) == 4
image_nd = torch.cat([image_nd, torch.flip(image_nd, dims=[3])], dim=0)
image_width = image_nd.shape[3]
clicks_lists_flipped = []
for clicks_list in clicks_lists:
clicks_list_flipped = [Click(is_positive=click.is_positive,
coords=(click.coords[0], image_width - click.coords[1] - 1))
for click in clicks_list]
clicks_lists_flipped.append(clicks_list_flipped)
clicks_lists = clicks_lists + clicks_lists_flipped
return image_nd, clicks_lists
def inv_transform(self, prob_map):
assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0
num_maps = prob_map.shape[0] // 2
prob_map, prob_map_flipped = prob_map[:num_maps], prob_map[num_maps:]
return 0.5 * (prob_map + torch.flip(prob_map_flipped, dims=[3]))
def get_state(self):
return None
def set_state(self, state):
pass
def reset(self):
pass


@@ -1,22 +0,0 @@
from .zoom_in import ZoomIn, get_roi_image_nd
class LimitLongestSide(ZoomIn):
def __init__(self, max_size=800):
super().__init__(target_size=max_size, skip_clicks=0)
def transform(self, image_nd, clicks_lists):
assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
image_max_size = max(image_nd.shape[2:4])
self.image_changed = False
if image_max_size <= self.target_size:
return image_nd, clicks_lists
self._input_image = image_nd
self._object_roi = (0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1)
self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size)
self.image_changed = True
tclicks_lists = [self._transform_clicks(clicks_lists[0])]
return self._roi_image, tclicks_lists


@@ -1,171 +0,0 @@
import torch
from ..clicker import Click
from ...utils.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp_bbox
from .base import BaseTransform
class ZoomIn(BaseTransform):
def __init__(self,
target_size=400,
skip_clicks=1,
expansion_ratio=1.4,
min_crop_size=200,
recompute_thresh_iou=0.5,
prob_thresh=0.50):
super().__init__()
self.target_size = target_size
self.min_crop_size = min_crop_size
self.skip_clicks = skip_clicks
self.expansion_ratio = expansion_ratio
self.recompute_thresh_iou = recompute_thresh_iou
self.prob_thresh = prob_thresh
self._input_image_shape = None
self._prev_probs = None
self._object_roi = None
self._roi_image = None
def transform(self, image_nd, clicks_lists):
assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
self.image_changed = False
clicks_list = clicks_lists[0]
if len(clicks_list) <= self.skip_clicks:
return image_nd, clicks_lists
self._input_image_shape = image_nd.shape
current_object_roi = None
if self._prev_probs is not None:
current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
if current_pred_mask.sum() > 0:
current_object_roi = get_object_roi(current_pred_mask, clicks_list,
self.expansion_ratio, self.min_crop_size)
if current_object_roi is None:
return image_nd, clicks_lists
update_object_roi = False
if self._object_roi is None:
update_object_roi = True
elif not check_object_roi(self._object_roi, clicks_list):
update_object_roi = True
elif get_bbox_iou(current_object_roi, self._object_roi) < self.recompute_thresh_iou:
update_object_roi = True
if update_object_roi:
self._object_roi = current_object_roi
self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size)
self.image_changed = True
tclicks_lists = [self._transform_clicks(clicks_list)]
return self._roi_image.to(image_nd.device), tclicks_lists
def inv_transform(self, prob_map):
if self._object_roi is None:
self._prev_probs = prob_map.cpu().numpy()
return prob_map
assert prob_map.shape[0] == 1
rmin, rmax, cmin, cmax = self._object_roi
prob_map = torch.nn.functional.interpolate(prob_map, size=(rmax - rmin + 1, cmax - cmin + 1),
mode='bilinear', align_corners=True)
if self._prev_probs is not None:
new_prob_map = torch.zeros(*self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype)
new_prob_map[:, :, rmin:rmax + 1, cmin:cmax + 1] = prob_map
else:
new_prob_map = prob_map
self._prev_probs = new_prob_map.cpu().numpy()
return new_prob_map
def check_possible_recalculation(self):
if self._prev_probs is None or self._object_roi is not None or self.skip_clicks > 0:
return False
pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
if pred_mask.sum() > 0:
possible_object_roi = get_object_roi(pred_mask, [],
self.expansion_ratio, self.min_crop_size)
image_roi = (0, self._input_image_shape[2] - 1, 0, self._input_image_shape[3] - 1)
if get_bbox_iou(possible_object_roi, image_roi) < 0.50:
return True
return False
def get_state(self):
roi_image = self._roi_image.cpu() if self._roi_image is not None else None
return self._input_image_shape, self._object_roi, self._prev_probs, roi_image, self.image_changed
def set_state(self, state):
self._input_image_shape, self._object_roi, self._prev_probs, self._roi_image, self.image_changed = state
def reset(self):
self._input_image_shape = None
self._object_roi = None
self._prev_probs = None
self._roi_image = None
self.image_changed = False
def _transform_clicks(self, clicks_list):
if self._object_roi is None:
return clicks_list
rmin, rmax, cmin, cmax = self._object_roi
crop_height, crop_width = self._roi_image.shape[2:]
transformed_clicks = []
for click in clicks_list:
new_r = crop_height * (click.coords[0] - rmin) / (rmax - rmin + 1)
new_c = crop_width * (click.coords[1] - cmin) / (cmax - cmin + 1)
transformed_clicks.append(Click(is_positive=click.is_positive, coords=(new_r, new_c)))
return transformed_clicks
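# Bounding box around the predicted mask (with positive clicks forced inside),
# expanded by expansion_ratio and clamped to the image bounds.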
def get_object_roi(pred_mask, clicks_list, expansion_ratio, min_crop_size):
pred_mask = pred_mask.copy()
for click in clicks_list:
if click.is_positive:
pred_mask[int(click.coords[0]), int(click.coords[1])] = 1
bbox = get_bbox_from_mask(pred_mask)
bbox = expand_bbox(bbox, expansion_ratio, min_crop_size)
h, w = pred_mask.shape[0], pred_mask.shape[1]
bbox = clamp_bbox(bbox, 0, h - 1, 0, w - 1)
return bbox
def get_roi_image_nd(image_nd, object_roi, target_size):
rmin, rmax, cmin, cmax = object_roi
height = rmax - rmin + 1
width = cmax - cmin + 1
if isinstance(target_size, tuple):
new_height, new_width = target_size
else:
scale = target_size / max(height, width)
new_height = int(round(height * scale))
new_width = int(round(width * scale))
with torch.no_grad():
roi_image_nd = image_nd[:, :, rmin:rmax + 1, cmin:cmax + 1]
roi_image_nd = torch.nn.functional.interpolate(roi_image_nd, size=(new_height, new_width),
mode='bilinear', align_corners=True)
return roi_image_nd
def check_object_roi(object_roi, clicks_list):
for click in clicks_list:
if click.is_positive:
if click.coords[0] < object_roi[0] or click.coords[0] >= object_roi[1]:
return False
if click.coords[1] < object_roi[2] or click.coords[1] >= object_roi[3]:
return False
return True


@@ -1,177 +0,0 @@
from datetime import timedelta
from pathlib import Path
import torch
import numpy as np
from ..model.is_deeplab_model import get_deeplab_model
from ..model.is_hrnet_model import get_hrnet_model
def get_time_metrics(all_ious, elapsed_time):
n_images = len(all_ious)
n_clicks = sum(map(len, all_ious))
mean_spc = elapsed_time / n_clicks
mean_spi = elapsed_time / n_images
return mean_spc, mean_spi
def load_is_model(checkpoint, device, backbone='auto', **kwargs):
if isinstance(checkpoint, (str, Path)):
state_dict = torch.load(checkpoint, map_location='cpu')
else:
state_dict = checkpoint
if backbone == 'auto':
for k in state_dict.keys():
if 'feature_extractor.stage2.0.branches' in k:
return load_hrnet_is_model(state_dict, device, backbone, **kwargs)
return load_deeplab_is_model(state_dict, device, backbone, **kwargs)
elif 'resnet' in backbone:
return load_deeplab_is_model(state_dict, device, backbone, **kwargs)
elif 'hrnet' in backbone:
return load_hrnet_is_model(state_dict, device, backbone, **kwargs)
else:
raise NotImplementedError('Unknown backbone')
def load_hrnet_is_model(state_dict, device, backbone='auto', width=48, ocr_width=256,
small=False, cpu_dist_maps=False, norm_radius=260):
if backbone == 'auto':
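        # Infer the HRNet configuration (small/large, OCR width, stage-2 width)
        # from the shapes of the checkpoint weights.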
num_fe_weights = len([x for x in state_dict.keys() if 'feature_extractor.' in x])
small = num_fe_weights < 1800
ocr_f_down = [v for k, v in state_dict.items() if 'object_context_block.f_down.1.0.bias' in k]
assert len(ocr_f_down) == 1
ocr_width = ocr_f_down[0].shape[0]
s2_conv1_w = [v for k, v in state_dict.items() if 'stage2.0.branches.0.0.conv1.weight' in k]
assert len(s2_conv1_w) == 1
width = s2_conv1_w[0].shape[0]
model = get_hrnet_model(width=width, ocr_width=ocr_width, small=small,
with_aux_output=False, cpu_dist_maps=cpu_dist_maps,
norm_radius=norm_radius)
model.load_state_dict(state_dict, strict=False)
for param in model.parameters():
param.requires_grad = False
model.to(device)
model.eval()
return model
def load_deeplab_is_model(state_dict, device, backbone='auto', deeplab_ch=128, aspp_dropout=0.2,
cpu_dist_maps=False, norm_radius=260):
if backbone == 'auto':
num_backbone_params = len([x for x in state_dict.keys()
if 'feature_extractor.backbone' in x and not('num_batches_tracked' in x)])
if num_backbone_params <= 181:
backbone = 'resnet34'
elif num_backbone_params <= 276:
backbone = 'resnet50'
elif num_backbone_params <= 531:
backbone = 'resnet101'
else:
raise NotImplementedError('Unknown backbone')
if 'aspp_dropout' in state_dict:
aspp_dropout = float(state_dict['aspp_dropout'].cpu().numpy())
else:
aspp_project_weight = [v for k, v in state_dict.items() if 'aspp.project.0.weight' in k][0]
deeplab_ch = aspp_project_weight.size(0)
if deeplab_ch == 256:
aspp_dropout = 0.5
model = get_deeplab_model(backbone=backbone, deeplab_ch=deeplab_ch,
aspp_dropout=aspp_dropout, cpu_dist_maps=cpu_dist_maps,
norm_radius=norm_radius)
model.load_state_dict(state_dict, strict=False)
for param in model.parameters():
param.requires_grad = False
model.to(device)
model.eval()
return model
def get_iou(gt_mask, pred_mask, ignore_label=-1):
ignore_gt_mask_inv = gt_mask != ignore_label
obj_gt_mask = gt_mask == 1
intersection = np.logical_and(np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()
union = np.logical_and(np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()
return intersection / union
def compute_noc_metric(all_ious, iou_thrs, max_clicks=20):
def _get_noc(iou_arr, iou_thr):
vals = iou_arr >= iou_thr
return np.argmax(vals) + 1 if np.any(vals) else max_clicks
noc_list = []
over_max_list = []
for iou_thr in iou_thrs:
scores_arr = np.array([_get_noc(iou_arr, iou_thr)
for iou_arr in all_ious], dtype=np.int32)
score = scores_arr.mean()
over_max = (scores_arr == max_clicks).sum()
noc_list.append(score)
over_max_list.append(over_max)
return noc_list, over_max_list
def find_checkpoint(weights_folder, checkpoint_name):
weights_folder = Path(weights_folder)
if ':' in checkpoint_name:
model_name, checkpoint_name = checkpoint_name.split(':')
models_candidates = [x for x in weights_folder.glob(f'{model_name}*') if x.is_dir()]
assert len(models_candidates) == 1
model_folder = models_candidates[0]
else:
model_folder = weights_folder
if checkpoint_name.endswith('.pth'):
if Path(checkpoint_name).exists():
checkpoint_path = checkpoint_name
else:
checkpoint_path = weights_folder / checkpoint_name
else:
model_checkpoints = list(model_folder.rglob(f'{checkpoint_name}*.pth'))
assert len(model_checkpoints) == 1
checkpoint_path = model_checkpoints[0]
return str(checkpoint_path)
def get_results_table(noc_list, over_max_list, brs_type, dataset_name, mean_spc, elapsed_time,
n_clicks=20, model_name=None):
table_header = (f'|{"BRS Type":^13}|{"Dataset":^11}|'
f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|'
f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|'
f'{"SPC,s":^7}|{"Time":^9}|')
row_width = len(table_header)
header = f'Eval results for model: {model_name}\n' if model_name is not None else ''
header += '-' * row_width + '\n'
header += table_header + '\n' + '-' * row_width
eval_time = str(timedelta(seconds=int(elapsed_time)))
table_row = f'|{brs_type:^13}|{dataset_name:^11}|'
table_row += f'{noc_list[0]:^9.2f}|'
table_row += f'{noc_list[1]:^9.2f}|' if len(noc_list) > 1 else f'{"?":^9}|'
table_row += f'{noc_list[2]:^9.2f}|' if len(noc_list) > 2 else f'{"?":^9}|'
table_row += f'{over_max_list[1]:^9}|' if len(noc_list) > 1 else f'{"?":^9}|'
table_row += f'{over_max_list[2]:^9}|' if len(noc_list) > 2 else f'{"?":^9}|'
table_row += f'{mean_spc:^7.3f}|{eval_time:^9}|'
return header, table_row


@@ -1,105 +0,0 @@
import torch
import torch.nn as nn
import numpy as np
class Initializer(object):
def __init__(self, local_init=True, gamma=None):
self.local_init = local_init
self.gamma = gamma
def __call__(self, m):
if getattr(m, '__initialized', False):
return
if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__:
if m.weight is not None:
self._init_gamma(m.weight.data)
if m.bias is not None:
self._init_beta(m.bias.data)
else:
if getattr(m, 'weight', None) is not None:
self._init_weight(m.weight.data)
if getattr(m, 'bias', None) is not None:
self._init_bias(m.bias.data)
if self.local_init:
object.__setattr__(m, '__initialized', True)
def _init_weight(self, data):
nn.init.uniform_(data, -0.07, 0.07)
def _init_bias(self, data):
nn.init.constant_(data, 0)
def _init_gamma(self, data):
if self.gamma is None:
nn.init.constant_(data, 1.0)
else:
nn.init.normal_(data, 1.0, self.gamma)
def _init_beta(self, data):
nn.init.constant_(data, 0)
class Bilinear(Initializer):
def __init__(self, scale, groups, in_channels, **kwargs):
super().__init__(**kwargs)
self.scale = scale
self.groups = groups
self.in_channels = in_channels
def _init_weight(self, data):
"""Reset the weight and bias."""
bilinear_kernel = self.get_bilinear_kernel(self.scale)
weight = torch.zeros_like(data)
for i in range(self.in_channels):
if self.groups == 1:
j = i
else:
j = 0
weight[i, j] = bilinear_kernel
data[:] = weight
@staticmethod
def get_bilinear_kernel(scale):
"""Generate a bilinear upsampling kernel."""
kernel_size = 2 * scale - scale % 2
scale = (kernel_size + 1) // 2
center = scale - 0.5 * (1 + kernel_size % 2)
og = np.ogrid[:kernel_size, :kernel_size]
kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale)
return torch.tensor(kernel, dtype=torch.float32)
class XavierGluon(Initializer):
def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs):
super().__init__(**kwargs)
self.rnd_type = rnd_type
self.factor_type = factor_type
self.magnitude = float(magnitude)
def _init_weight(self, arr):
fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr)
if self.factor_type == 'avg':
factor = (fan_in + fan_out) / 2.0
elif self.factor_type == 'in':
factor = fan_in
elif self.factor_type == 'out':
factor = fan_out
else:
raise ValueError('Incorrect factor type')
scale = np.sqrt(self.magnitude / factor)
if self.rnd_type == 'uniform':
nn.init.uniform_(arr, -scale, scale)
elif self.rnd_type == 'gaussian':
nn.init.normal_(arr, 0, scale)
else:
raise ValueError('Unknown random type')


@@ -1,86 +0,0 @@
import torch
import torch.nn as nn
from .ops import DistMaps
from .modeling.deeplab_v3 import DeepLabV3Plus
from .modeling.basic_blocks import SepConvHead
def get_deeplab_model(backbone='resnet50', deeplab_ch=256, aspp_dropout=0.5,
norm_layer=nn.BatchNorm2d, backbone_norm_layer=None,
use_rgb_conv=True, cpu_dist_maps=False,
norm_radius=260):
model = DistMapsModel(
feature_extractor=DeepLabV3Plus(backbone=backbone,
ch=deeplab_ch,
project_dropout=aspp_dropout,
norm_layer=norm_layer,
backbone_norm_layer=backbone_norm_layer),
head=SepConvHead(1, in_channels=deeplab_ch, mid_channels=deeplab_ch // 2,
num_layers=2, norm_layer=norm_layer),
use_rgb_conv=use_rgb_conv,
norm_layer=norm_layer,
norm_radius=norm_radius,
cpu_dist_maps=cpu_dist_maps
)
return model
class DistMapsModel(nn.Module):
def __init__(self, feature_extractor, head, norm_layer=nn.BatchNorm2d, use_rgb_conv=True,
cpu_dist_maps=False, norm_radius=260):
super(DistMapsModel, self).__init__()
if use_rgb_conv:
self.rgb_conv = nn.Sequential(
nn.Conv2d(in_channels=5, out_channels=8, kernel_size=1),
nn.LeakyReLU(negative_slope=0.2),
norm_layer(8),
nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1),
)
else:
self.rgb_conv = None
self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0,
cpu_mode=cpu_dist_maps)
self.feature_extractor = feature_extractor
self.head = head
def forward(self, image, points):
coord_features = self.dist_maps(image, points)
if self.rgb_conv is not None:
x = self.rgb_conv(torch.cat((image, coord_features), dim=1))
else:
c1, c2 = torch.chunk(coord_features, 2, dim=1)
c3 = torch.ones_like(c1)
coord_features = torch.cat((c1, c2, c3), dim=1)
x = 0.8 * image * coord_features + 0.2 * image
backbone_features = self.feature_extractor(x)
instance_out = self.head(backbone_features[0])
instance_out = nn.functional.interpolate(instance_out, size=image.size()[2:],
mode='bilinear', align_corners=True)
return {'instances': instance_out}
def load_weights(self, path_to_weights):
current_state_dict = self.state_dict()
new_state_dict = torch.load(path_to_weights, map_location='cpu')
current_state_dict.update(new_state_dict)
self.load_state_dict(current_state_dict)
def get_trainable_params(self):
backbone_params = nn.ParameterList()
other_params = nn.ParameterList()
for name, param in self.named_parameters():
if param.requires_grad:
if 'backbone' in name:
backbone_params.append(param)
else:
other_params.append(param)
return backbone_params, other_params


@@ -1,87 +0,0 @@
import torch
import torch.nn as nn
from .ops import DistMaps
from .modeling.hrnet_ocr import HighResolutionNet
def get_hrnet_model(width=48, ocr_width=256, small=False, norm_radius=260,
use_rgb_conv=True, with_aux_output=False, cpu_dist_maps=False,
norm_layer=nn.BatchNorm2d):
model = DistMapsHRNetModel(
feature_extractor=HighResolutionNet(width=width, ocr_width=ocr_width, small=small,
num_classes=1, norm_layer=norm_layer),
use_rgb_conv=use_rgb_conv,
with_aux_output=with_aux_output,
norm_layer=norm_layer,
norm_radius=norm_radius,
cpu_dist_maps=cpu_dist_maps
)
return model
class DistMapsHRNetModel(nn.Module):
def __init__(self, feature_extractor, use_rgb_conv=True, with_aux_output=False,
norm_layer=nn.BatchNorm2d, norm_radius=260, cpu_dist_maps=False):
super(DistMapsHRNetModel, self).__init__()
self.with_aux_output = with_aux_output
if use_rgb_conv:
self.rgb_conv = nn.Sequential(
nn.Conv2d(in_channels=5, out_channels=8, kernel_size=1),
nn.LeakyReLU(negative_slope=0.2),
norm_layer(8),
nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1),
)
else:
self.rgb_conv = None
self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0, cpu_mode=cpu_dist_maps)
self.feature_extractor = feature_extractor
def forward(self, image, points):
coord_features = self.dist_maps(image, points)
if self.rgb_conv is not None:
x = self.rgb_conv(torch.cat((image, coord_features), dim=1))
else:
c1, c2 = torch.chunk(coord_features, 2, dim=1)
c3 = torch.ones_like(c1)
coord_features = torch.cat((c1, c2, c3), dim=1)
x = 0.8 * image * coord_features + 0.2 * image
feature_extractor_out = self.feature_extractor(x)
instance_out = feature_extractor_out[0]
instance_out = nn.functional.interpolate(instance_out, size=image.size()[2:],
mode='bilinear', align_corners=True)
outputs = {'instances': instance_out}
if self.with_aux_output:
instance_aux_out = feature_extractor_out[1]
instance_aux_out = nn.functional.interpolate(instance_aux_out, size=image.size()[2:],
mode='bilinear', align_corners=True)
outputs['instances_aux'] = instance_aux_out
return outputs
def load_weights(self, path_to_weights):
current_state_dict = self.state_dict()
        new_state_dict = torch.load(path_to_weights, map_location='cpu')
current_state_dict.update(new_state_dict)
self.load_state_dict(current_state_dict)
def get_trainable_params(self):
backbone_params = nn.ParameterList()
other_params = nn.ParameterList()
other_params_keys = []
nonbackbone_keywords = ['rgb_conv', 'aux_head', 'cls_head', 'conv3x3_ocr', 'ocr_distri_head']
for name, param in self.named_parameters():
if param.requires_grad:
if any(x in name for x in nonbackbone_keywords):
other_params.append(param)
other_params_keys.append(name)
else:
backbone_params.append(param)
print('Nonbackbone params:', sorted(other_params_keys))
return backbone_params, other_params


@@ -1,134 +0,0 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import misc
class NormalizedFocalLossSigmoid(nn.Module):
def __init__(self, axis=-1, alpha=0.25, gamma=2,
from_logits=False, batch_axis=0,
weight=None, size_average=True, detach_delimeter=True,
eps=1e-12, scale=1.0,
ignore_label=-1):
super(NormalizedFocalLossSigmoid, self).__init__()
self._axis = axis
self._alpha = alpha
self._gamma = gamma
self._ignore_label = ignore_label
self._weight = weight if weight is not None else 1.0
self._batch_axis = batch_axis
self._scale = scale
self._from_logits = from_logits
self._eps = eps
self._size_average = size_average
self._detach_delimeter = detach_delimeter
self._k_sum = 0
def forward(self, pred, label, sample_weight=None):
one_hot = label > 0
sample_weight = label != self._ignore_label
if not self._from_logits:
pred = torch.sigmoid(pred)
alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight)
pt = torch.where(one_hot, pred, 1 - pred)
pt = torch.where(sample_weight, pt, torch.ones_like(pt))
beta = (1 - pt) ** self._gamma
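        # Normalize the focal weights so that they sum to the number of
        # valid (non-ignored) pixels in each sample.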
sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True)
beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True)
mult = sw_sum / (beta_sum + self._eps)
if self._detach_delimeter:
mult = mult.detach()
beta = beta * mult
ignore_area = torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))).cpu().numpy()
sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy()
if np.any(ignore_area == 0):
self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean()
loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
loss = self._weight * (loss * sample_weight)
if self._size_average:
bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis))
loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps)
else:
loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
return self._scale * loss
def log_states(self, sw, name, global_step):
sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step)
class FocalLoss(nn.Module):
def __init__(self, axis=-1, alpha=0.25, gamma=2,
from_logits=False, batch_axis=0,
weight=None, num_class=None,
eps=1e-9, size_average=True, scale=1.0):
super(FocalLoss, self).__init__()
self._axis = axis
self._alpha = alpha
self._gamma = gamma
self._weight = weight if weight is not None else 1.0
self._batch_axis = batch_axis
self._scale = scale
self._num_class = num_class
self._from_logits = from_logits
self._eps = eps
self._size_average = size_average
def forward(self, pred, label, sample_weight=None):
if not self._from_logits:
            pred = torch.sigmoid(pred)
one_hot = label > 0
pt = torch.where(one_hot, pred, 1 - pred)
t = label != -1
alpha = torch.where(one_hot, self._alpha * t, (1 - self._alpha) * t)
beta = (1 - pt) ** self._gamma
loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
sample_weight = label != -1
loss = self._weight * (loss * sample_weight)
if self._size_average:
tsum = torch.sum(label == 1, dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis))
loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (tsum + self._eps)
else:
loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
return self._scale * loss
class SigmoidBinaryCrossEntropyLoss(nn.Module):
def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1):
super(SigmoidBinaryCrossEntropyLoss, self).__init__()
self._from_sigmoid = from_sigmoid
self._ignore_label = ignore_label
self._weight = weight if weight is not None else 1.0
self._batch_axis = batch_axis
def forward(self, pred, label):
label = label.view(pred.size())
sample_weight = label != self._ignore_label
label = torch.where(sample_weight, label, torch.zeros_like(label))
if not self._from_sigmoid:
loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred))
else:
eps = 1e-12
loss = -(torch.log(pred + eps) * label
+ torch.log(1. - pred + eps) * (1. - label))
loss = self._weight * (loss * sample_weight)
return torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))


@@ -1,101 +0,0 @@
import torch
import numpy as np
from ..utils import misc
class TrainMetric(object):
def __init__(self, pred_outputs, gt_outputs):
self.pred_outputs = pred_outputs
self.gt_outputs = gt_outputs
def update(self, *args, **kwargs):
raise NotImplementedError
def get_epoch_value(self):
raise NotImplementedError
def reset_epoch_stats(self):
raise NotImplementedError
def log_states(self, sw, tag_prefix, global_step):
pass
@property
def name(self):
return type(self).__name__
class AdaptiveIoU(TrainMetric):
def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9,
ignore_label=-1, from_logits=True,
pred_output='instances', gt_output='instances'):
super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,))
self._ignore_label = ignore_label
self._from_logits = from_logits
self._iou_thresh = init_thresh
self._thresh_step = thresh_step
self._thresh_beta = thresh_beta
self._iou_beta = iou_beta
self._ema_iou = 0.0
self._epoch_iou_sum = 0.0
self._epoch_batch_count = 0
def update(self, pred, gt):
gt_mask = gt > 0
if self._from_logits:
pred = torch.sigmoid(pred)
gt_mask_area = torch.sum(gt_mask, dim=(1, 2)).detach().cpu().numpy()
if np.all(gt_mask_area == 0):
return
ignore_mask = gt == self._ignore_label
max_iou = _compute_iou(pred > self._iou_thresh, gt_mask, ignore_mask).mean()
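        # Probe neighbouring thresholds and track the best one (and its IoU)
        # with exponential moving averages.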
best_thresh = self._iou_thresh
for t in [best_thresh - self._thresh_step, best_thresh + self._thresh_step]:
temp_iou = _compute_iou(pred > t, gt_mask, ignore_mask).mean()
if temp_iou > max_iou:
max_iou = temp_iou
best_thresh = t
self._iou_thresh = self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh
self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou
self._epoch_iou_sum += max_iou
self._epoch_batch_count += 1
def get_epoch_value(self):
if self._epoch_batch_count > 0:
return self._epoch_iou_sum / self._epoch_batch_count
else:
return 0.0
def reset_epoch_stats(self):
self._epoch_iou_sum = 0.0
self._epoch_batch_count = 0
def log_states(self, sw, tag_prefix, global_step):
sw.add_scalar(tag=tag_prefix + '_ema_iou', value=self._ema_iou, global_step=global_step)
sw.add_scalar(tag=tag_prefix + '_iou_thresh', value=self._iou_thresh, global_step=global_step)
@property
def iou_thresh(self):
return self._iou_thresh
def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False):
if ignore_mask is not None:
pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask)
reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0)
union = torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims).detach().cpu().numpy()
intersection = torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims).detach().cpu().numpy()
nonzero = union > 0
iou = intersection[nonzero] / union[nonzero]
if not keep_ignore:
return iou
else:
result = np.full_like(intersection, -1)
result[nonzero] = iou
return result


@@ -1,71 +0,0 @@
import torch.nn as nn
from ...model import ops
class ConvHead(nn.Module):
def __init__(self, out_channels, in_channels=32, num_layers=1,
kernel_size=3, padding=1,
norm_layer=nn.BatchNorm2d):
super(ConvHead, self).__init__()
convhead = []
for i in range(num_layers):
convhead.extend([
nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding),
nn.ReLU(),
norm_layer(in_channels) if norm_layer is not None else nn.Identity()
])
convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0))
self.convhead = nn.Sequential(*convhead)
def forward(self, *inputs):
return self.convhead(inputs[0])
class SepConvHead(nn.Module):
def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1,
kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0,
norm_layer=nn.BatchNorm2d):
super(SepConvHead, self).__init__()
sepconvhead = []
for i in range(num_layers):
sepconvhead.append(
SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels,
out_channels=mid_channels,
dw_kernel=kernel_size, dw_padding=padding,
norm_layer=norm_layer, activation='relu')
)
if dropout_ratio > 0 and dropout_indx == i:
sepconvhead.append(nn.Dropout(dropout_ratio))
sepconvhead.append(
nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0)
)
self.layers = nn.Sequential(*sepconvhead)
def forward(self, *inputs):
x = inputs[0]
return self.layers(x)
class SeparableConv2d(nn.Module):
def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1,
activation=None, use_bias=False, norm_layer=None):
super(SeparableConv2d, self).__init__()
_activation = ops.select_activation_function(activation)
self.body = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride,
padding=dw_padding, bias=use_bias, groups=in_channels),
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias),
norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
_activation()
)
def forward(self, x):
return self.body(x)


@@ -1,176 +0,0 @@
from contextlib import ExitStack
import torch
from torch import nn
import torch.nn.functional as F
from .basic_blocks import SeparableConv2d
from .resnet import ResNetBackbone
from ...model import ops
class DeepLabV3Plus(nn.Module):
def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d,
backbone_norm_layer=None,
ch=256,
project_dropout=0.5,
inference_mode=False,
**kwargs):
super(DeepLabV3Plus, self).__init__()
if backbone_norm_layer is None:
backbone_norm_layer = norm_layer
self.backbone_name = backbone
self.norm_layer = norm_layer
self.backbone_norm_layer = backbone_norm_layer
self.inference_mode = False
self.ch = ch
self.aspp_in_channels = 2048
self.skip_project_in_channels = 256 # layer 1 out_channels
self._kwargs = kwargs
if backbone == 'resnet34':
self.aspp_in_channels = 512
self.skip_project_in_channels = 64
self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False,
norm_layer=self.backbone_norm_layer, **kwargs)
self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=ch,
norm_layer=self.norm_layer)
self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer)
self.aspp = _ASPP(in_channels=self.aspp_in_channels,
atrous_rates=[12, 24, 36],
out_channels=ch,
project_dropout=project_dropout,
norm_layer=self.norm_layer)
if inference_mode:
self.set_prediction_mode()
def load_pretrained_weights(self):
pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True,
norm_layer=self.backbone_norm_layer, **self._kwargs)
backbone_state_dict = self.backbone.state_dict()
pretrained_state_dict = pretrained.state_dict()
backbone_state_dict.update(pretrained_state_dict)
self.backbone.load_state_dict(backbone_state_dict)
if self.inference_mode:
for param in self.backbone.parameters():
param.requires_grad = False
def set_prediction_mode(self):
self.inference_mode = True
self.eval()
def forward(self, x):
with ExitStack() as stack:
if self.inference_mode:
stack.enter_context(torch.no_grad())
c1, _, c3, c4 = self.backbone(x)
c1 = self.skip_project(c1)
x = self.aspp(c4)
x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True)
x = torch.cat((x, c1), dim=1)
x = self.head(x)
return x,
class _SkipProject(nn.Module):
def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d):
super(_SkipProject, self).__init__()
_activation = ops.select_activation_function("relu")
self.skip_project = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
norm_layer(out_channels),
_activation()
)
def forward(self, x):
return self.skip_project(x)
class _DeepLabHead(nn.Module):
def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d):
super(_DeepLabHead, self).__init__()
self.block = nn.Sequential(
SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3,
dw_padding=1, activation='relu', norm_layer=norm_layer),
SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3,
dw_padding=1, activation='relu', norm_layer=norm_layer),
nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1)
)
def forward(self, x):
return self.block(x)
class _ASPP(nn.Module):
def __init__(self, in_channels, atrous_rates, out_channels=256,
project_dropout=0.5, norm_layer=nn.BatchNorm2d):
super(_ASPP, self).__init__()
b0 = nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False),
norm_layer(out_channels),
nn.ReLU()
)
rate1, rate2, rate3 = tuple(atrous_rates)
b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)
b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)
self.concurent = nn.ModuleList([b0, b1, b2, b3, b4])
project = [
nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels,
kernel_size=1, bias=False),
norm_layer(out_channels),
nn.ReLU()
]
if project_dropout > 0:
project.append(nn.Dropout(project_dropout))
self.project = nn.Sequential(*project)
def forward(self, x):
x = torch.cat([block(x) for block in self.concurent], dim=1)
return self.project(x)
class _AsppPooling(nn.Module):
def __init__(self, in_channels, out_channels, norm_layer):
super(_AsppPooling, self).__init__()
self.gap = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
kernel_size=1, bias=False),
norm_layer(out_channels),
nn.ReLU()
)
def forward(self, x):
pool = self.gap(x)
return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True)
def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer):
block = nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
kernel_size=3, padding=atrous_rate,
dilation=atrous_rate, bias=False),
norm_layer(out_channels),
nn.ReLU()
)
return block


@@ -1,399 +0,0 @@
import os
import numpy as np
import torch
import torch.nn as nn
import torch._utils
import torch.nn.functional as F
from .ocr import SpatialOCR_Module, SpatialGather_Module
from .resnetv1b import BasicBlockV1b, BottleneckV1b
relu_inplace = True
class HighResolutionModule(nn.Module):
def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True,
norm_layer=nn.BatchNorm2d, align_corners=True):
super(HighResolutionModule, self).__init__()
self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)
self.num_inchannels = num_inchannels
self.fuse_method = fuse_method
self.num_branches = num_branches
self.norm_layer = norm_layer
self.align_corners = align_corners
self.multi_scale_output = multi_scale_output
self.branches = self._make_branches(
num_branches, blocks, num_blocks, num_channels)
self.fuse_layers = self._make_fuse_layers()
self.relu = nn.ReLU(inplace=relu_inplace)
def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
if num_branches != len(num_blocks):
error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
num_branches, len(num_blocks))
raise ValueError(error_msg)
if num_branches != len(num_channels):
error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
num_branches, len(num_channels))
raise ValueError(error_msg)
if num_branches != len(num_inchannels):
error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
num_branches, len(num_inchannels))
raise ValueError(error_msg)
def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
stride=1):
downsample = None
if stride != 1 or \
self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.num_inchannels[branch_index],
num_channels[branch_index] * block.expansion,
kernel_size=1, stride=stride, bias=False),
self.norm_layer(num_channels[branch_index] * block.expansion),
)
layers = []
layers.append(block(self.num_inchannels[branch_index],
num_channels[branch_index], stride,
downsample=downsample, norm_layer=self.norm_layer))
self.num_inchannels[branch_index] = \
num_channels[branch_index] * block.expansion
for i in range(1, num_blocks[branch_index]):
layers.append(block(self.num_inchannels[branch_index],
num_channels[branch_index],
norm_layer=self.norm_layer))
return nn.Sequential(*layers)
def _make_branches(self, num_branches, block, num_blocks, num_channels):
branches = []
for i in range(num_branches):
branches.append(
self._make_one_branch(i, block, num_blocks, num_channels))
return nn.ModuleList(branches)
def _make_fuse_layers(self):
if self.num_branches == 1:
return None
num_branches = self.num_branches
num_inchannels = self.num_inchannels
fuse_layers = []
for i in range(num_branches if self.multi_scale_output else 1):
fuse_layer = []
for j in range(num_branches):
if j > i:
fuse_layer.append(nn.Sequential(
nn.Conv2d(in_channels=num_inchannels[j],
out_channels=num_inchannels[i],
kernel_size=1,
bias=False),
self.norm_layer(num_inchannels[i])))
elif j == i:
fuse_layer.append(None)
else:
conv3x3s = []
for k in range(i - j):
if k == i - j - 1:
num_outchannels_conv3x3 = num_inchannels[i]
conv3x3s.append(nn.Sequential(
nn.Conv2d(num_inchannels[j],
num_outchannels_conv3x3,
kernel_size=3, stride=2, padding=1, bias=False),
self.norm_layer(num_outchannels_conv3x3)))
else:
num_outchannels_conv3x3 = num_inchannels[j]
conv3x3s.append(nn.Sequential(
nn.Conv2d(num_inchannels[j],
num_outchannels_conv3x3,
kernel_size=3, stride=2, padding=1, bias=False),
self.norm_layer(num_outchannels_conv3x3),
nn.ReLU(inplace=relu_inplace)))
fuse_layer.append(nn.Sequential(*conv3x3s))
fuse_layers.append(nn.ModuleList(fuse_layer))
return nn.ModuleList(fuse_layers)
def get_num_inchannels(self):
return self.num_inchannels
def forward(self, x):
if self.num_branches == 1:
return [self.branches[0](x[0])]
for i in range(self.num_branches):
x[i] = self.branches[i](x[i])
x_fuse = []
for i in range(len(self.fuse_layers)):
y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
for j in range(1, self.num_branches):
if i == j:
y = y + x[j]
elif j > i:
width_output = x[i].shape[-1]
height_output = x[i].shape[-2]
y = y + F.interpolate(
self.fuse_layers[i][j](x[j]),
size=[height_output, width_output],
mode='bilinear', align_corners=self.align_corners)
else:
y = y + self.fuse_layers[i][j](x[j])
x_fuse.append(self.relu(y))
return x_fuse
class HighResolutionNet(nn.Module):
def __init__(self, width, num_classes, ocr_width=256, small=False,
norm_layer=nn.BatchNorm2d, align_corners=True):
super(HighResolutionNet, self).__init__()
self.norm_layer = norm_layer
self.width = width
self.ocr_width = ocr_width
self.align_corners = align_corners
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = norm_layer(64)
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
self.bn2 = norm_layer(64)
self.relu = nn.ReLU(inplace=relu_inplace)
num_blocks = 2 if small else 4
stage1_num_channels = 64
self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks)
stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels
self.stage2_num_branches = 2
num_channels = [width, 2 * width]
num_inchannels = [
num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
self.transition1 = self._make_transition_layer(
[stage1_out_channel], num_inchannels)
self.stage2, pre_stage_channels = self._make_stage(
BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches,
num_blocks=2 * [num_blocks], num_channels=num_channels)
self.stage3_num_branches = 3
num_channels = [width, 2 * width, 4 * width]
num_inchannels = [
num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
self.transition2 = self._make_transition_layer(
pre_stage_channels, num_inchannels)
self.stage3, pre_stage_channels = self._make_stage(
BasicBlockV1b, num_inchannels=num_inchannels,
num_modules=3 if small else 4, num_branches=self.stage3_num_branches,
num_blocks=3 * [num_blocks], num_channels=num_channels)
self.stage4_num_branches = 4
num_channels = [width, 2 * width, 4 * width, 8 * width]
num_inchannels = [
num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
self.transition3 = self._make_transition_layer(
pre_stage_channels, num_inchannels)
self.stage4, pre_stage_channels = self._make_stage(
BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3,
num_branches=self.stage4_num_branches,
num_blocks=4 * [num_blocks], num_channels=num_channels)
last_inp_channels = np.int32(np.sum(pre_stage_channels))
ocr_mid_channels = 2 * ocr_width
ocr_key_channels = ocr_width
self.conv3x3_ocr = nn.Sequential(
nn.Conv2d(last_inp_channels, ocr_mid_channels,
kernel_size=3, stride=1, padding=1),
norm_layer(ocr_mid_channels),
nn.ReLU(inplace=relu_inplace),
)
self.ocr_gather_head = SpatialGather_Module(num_classes)
self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
key_channels=ocr_key_channels,
out_channels=ocr_mid_channels,
scale=1,
dropout=0.05,
norm_layer=norm_layer,
align_corners=align_corners)
self.cls_head = nn.Conv2d(
ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
self.aux_head = nn.Sequential(
nn.Conv2d(last_inp_channels, last_inp_channels,
kernel_size=1, stride=1, padding=0),
norm_layer(last_inp_channels),
nn.ReLU(inplace=relu_inplace),
nn.Conv2d(last_inp_channels, num_classes,
kernel_size=1, stride=1, padding=0, bias=True)
)
def _make_transition_layer(
self, num_channels_pre_layer, num_channels_cur_layer):
num_branches_cur = len(num_channels_cur_layer)
num_branches_pre = len(num_channels_pre_layer)
transition_layers = []
for i in range(num_branches_cur):
if i < num_branches_pre:
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
transition_layers.append(nn.Sequential(
nn.Conv2d(num_channels_pre_layer[i],
num_channels_cur_layer[i],
kernel_size=3,
stride=1,
padding=1,
bias=False),
self.norm_layer(num_channels_cur_layer[i]),
nn.ReLU(inplace=relu_inplace)))
else:
transition_layers.append(None)
else:
conv3x3s = []
for j in range(i + 1 - num_branches_pre):
inchannels = num_channels_pre_layer[-1]
outchannels = num_channels_cur_layer[i] \
if j == i - num_branches_pre else inchannels
conv3x3s.append(nn.Sequential(
nn.Conv2d(inchannels, outchannels,
kernel_size=3, stride=2, padding=1, bias=False),
self.norm_layer(outchannels),
nn.ReLU(inplace=relu_inplace)))
transition_layers.append(nn.Sequential(*conv3x3s))
return nn.ModuleList(transition_layers)
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
self.norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(inplanes, planes, stride,
downsample=downsample, norm_layer=self.norm_layer))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(inplanes, planes, norm_layer=self.norm_layer))
return nn.Sequential(*layers)
def _make_stage(self, block, num_inchannels,
num_modules, num_branches, num_blocks, num_channels,
fuse_method='SUM',
multi_scale_output=True):
modules = []
for i in range(num_modules):
            # multi_scale_output is only used by the last module
if not multi_scale_output and i == num_modules - 1:
reset_multi_scale_output = False
else:
reset_multi_scale_output = True
modules.append(
HighResolutionModule(num_branches,
block,
num_blocks,
num_inchannels,
num_channels,
fuse_method,
reset_multi_scale_output,
norm_layer=self.norm_layer,
align_corners=self.align_corners)
)
num_inchannels = modules[-1].get_num_inchannels()
return nn.Sequential(*modules), num_inchannels
def forward(self, x):
feats = self.compute_hrnet_feats(x)
out_aux = self.aux_head(feats)
feats = self.conv3x3_ocr(feats)
context = self.ocr_gather_head(feats, out_aux)
feats = self.ocr_distri_head(feats, context)
out = self.cls_head(feats)
return [out, out_aux]
def compute_hrnet_feats(self, x):
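        # Stem (two stride-2 convs) -> layer1 -> transition/stage 2-4, then upsample
        # every branch to the highest resolution and concatenate along channels.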
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.layer1(x)
x_list = []
for i in range(self.stage2_num_branches):
if self.transition1[i] is not None:
x_list.append(self.transition1[i](x))
else:
x_list.append(x)
y_list = self.stage2(x_list)
x_list = []
for i in range(self.stage3_num_branches):
if self.transition2[i] is not None:
if i < self.stage2_num_branches:
x_list.append(self.transition2[i](y_list[i]))
else:
x_list.append(self.transition2[i](y_list[-1]))
else:
x_list.append(y_list[i])
y_list = self.stage3(x_list)
x_list = []
for i in range(self.stage4_num_branches):
if self.transition3[i] is not None:
if i < self.stage3_num_branches:
x_list.append(self.transition3[i](y_list[i]))
else:
x_list.append(self.transition3[i](y_list[-1]))
else:
x_list.append(y_list[i])
x = self.stage4(x_list)
# Upsampling
x0_h, x0_w = x[0].size(2), x[0].size(3)
x1 = F.interpolate(x[1], size=(x0_h, x0_w),
mode='bilinear', align_corners=self.align_corners)
x2 = F.interpolate(x[2], size=(x0_h, x0_w),
mode='bilinear', align_corners=self.align_corners)
x3 = F.interpolate(x[3], size=(x0_h, x0_w),
mode='bilinear', align_corners=self.align_corners)
return torch.cat([x[0], x1, x2, x3], 1)
def load_pretrained_weights(self, pretrained_path=''):
model_dict = self.state_dict()
if not os.path.exists(pretrained_path):
print(f'\nFile "{pretrained_path}" does not exist.')
print('You need to specify the correct path to the pre-trained weights.\n'
'You can download the weights for HRNet from the repository:\n'
'https://github.com/HRNet/HRNet-Image-Classification')
exit(1)
pretrained_dict = torch.load(pretrained_path, map_location={'cuda:0': 'cpu'})
pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in
pretrained_dict.items()}
print('model_dict-pretrained_dict:', sorted(list(set(model_dict) - set(pretrained_dict))))
print('pretrained_dict-model_dict:', sorted(list(set(pretrained_dict) - set(model_dict))))
pretrained_dict = {k: v for k, v in pretrained_dict.items()
if k in model_dict.keys()}
model_dict.update(pretrained_dict)
self.load_state_dict(model_dict)



@@ -1,141 +0,0 @@
import torch
import torch.nn as nn
import torch._utils
import torch.nn.functional as F
class SpatialGather_Module(nn.Module):
"""
Aggregate the context features according to the initial
predicted probability distribution.
Employ the soft-weighted method to aggregate the context.
"""
def __init__(self, cls_num=0, scale=1):
super(SpatialGather_Module, self).__init__()
self.cls_num = cls_num
self.scale = scale
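    # Shape reference: feats is B x C x H x W and probs is B x K x H x W;
    # the returned object context is B x C x K x 1 (one soft-pooled feature per class).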
def forward(self, feats, probs):
batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
probs = probs.view(batch_size, c, -1)
feats = feats.view(batch_size, feats.size(1), -1)
feats = feats.permute(0, 2, 1) # batch x hw x c
probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw
ocr_context = torch.matmul(probs, feats) \
            .permute(0, 2, 1).unsqueeze(3)  # batch x c x k x 1
return ocr_context
class SpatialOCR_Module(nn.Module):
"""
Implementation of the OCR module:
We aggregate the global object representation to update the representation for each pixel.
"""
def __init__(self,
in_channels,
key_channels,
out_channels,
scale=1,
dropout=0.1,
norm_layer=nn.BatchNorm2d,
align_corners=True):
super(SpatialOCR_Module, self).__init__()
self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale,
norm_layer, align_corners)
_in_channels = 2 * in_channels
self.conv_bn_dropout = nn.Sequential(
nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)),
nn.Dropout2d(dropout)
)
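    # forward(): pixel features are augmented with attended object-region context from
    # ObjectAttentionBlock2D, concatenated, and fused back to out_channels by a 1x1 conv.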
def forward(self, feats, proxy_feats):
context = self.object_context_block(feats, proxy_feats)
output = self.conv_bn_dropout(torch.cat([context, feats], 1))
return output
class ObjectAttentionBlock2D(nn.Module):
'''
The basic implementation for object context block
Input:
N X C X H X W
Parameters:
in_channels : the dimension of the input feature map
key_channels : the dimension after the key/query transform
scale : choose the scale to downsample the input feature maps (save memory cost)
        norm_layer : the normalization layer to use (default nn.BatchNorm2d)
Return:
N X C X H X W
'''
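    # Attention: sim = softmax(query(x) @ key(proxy) / sqrt(key_channels)),
    # context = sim @ value(proxy), projected back to in_channels by f_up.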
def __init__(self,
in_channels,
key_channels,
scale=1,
norm_layer=nn.BatchNorm2d,
align_corners=True):
super(ObjectAttentionBlock2D, self).__init__()
self.scale = scale
self.in_channels = in_channels
self.key_channels = key_channels
self.align_corners = align_corners
self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
self.f_pixel = nn.Sequential(
nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
kernel_size=1, stride=1, padding=0, bias=False),
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
kernel_size=1, stride=1, padding=0, bias=False),
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
)
self.f_object = nn.Sequential(
nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
kernel_size=1, stride=1, padding=0, bias=False),
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
kernel_size=1, stride=1, padding=0, bias=False),
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
)
self.f_down = nn.Sequential(
nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
kernel_size=1, stride=1, padding=0, bias=False),
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
)
self.f_up = nn.Sequential(
nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0, bias=False),
nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True))
)
def forward(self, x, proxy):
batch_size, h, w = x.size(0), x.size(2), x.size(3)
if self.scale > 1:
x = self.pool(x)
query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
query = query.permute(0, 2, 1)
key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
value = value.permute(0, 2, 1)
sim_map = torch.matmul(query, key)
sim_map = (self.key_channels ** -.5) * sim_map
sim_map = F.softmax(sim_map, dim=-1)
# add bg context ...
context = torch.matmul(sim_map, value)
context = context.permute(0, 2, 1).contiguous()
context = context.view(batch_size, self.key_channels, *x.size()[2:])
context = self.f_up(context)
if self.scale > 1:
context = F.interpolate(input=context, size=(h, w),
mode='bilinear', align_corners=self.align_corners)
return context


@@ -1,39 +0,0 @@
import torch
from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s
class ResNetBackbone(torch.nn.Module):
def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True, **kwargs):
super(ResNetBackbone, self).__init__()
if backbone == 'resnet34':
pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs)
elif backbone == 'resnet50':
pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
elif backbone == 'resnet101':
pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
elif backbone == 'resnet152':
pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
else:
raise RuntimeError(f'unknown backbone: {backbone}')
self.conv1 = pretrained.conv1
self.bn1 = pretrained.bn1
self.relu = pretrained.relu
self.maxpool = pretrained.maxpool
self.layer1 = pretrained.layer1
self.layer2 = pretrained.layer2
self.layer3 = pretrained.layer3
self.layer4 = pretrained.layer4
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
c1 = self.layer1(x)
c2 = self.layer2(c1)
c3 = self.layer3(c2)
c4 = self.layer4(c3)
return c1, c2, c3, c4


@@ -1,276 +0,0 @@
import torch
import torch.nn as nn
GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet'
class BasicBlockV1b(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
previous_dilation=1, norm_layer=nn.BatchNorm2d):
super(BasicBlockV1b, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation, bias=False)
self.bn1 = norm_layer(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
padding=previous_dilation, dilation=previous_dilation, bias=False)
self.bn2 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out = out + residual
out = self.relu(out)
return out
class BottleneckV1b(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
previous_dilation=1, norm_layer=nn.BatchNorm2d):
super(BottleneckV1b, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = norm_layer(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation, bias=False)
self.bn2 = norm_layer(planes)
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out = out + residual
out = self.relu(out)
return out
class ResNetV1b(nn.Module):
""" Pre-trained ResNetV1b Model, which produces the strides of 8 featuremaps at conv5.
Parameters
----------
block : Block
Class for the residual block. Options are BasicBlockV1, BottleneckV1.
layers : list of int
Numbers of layers in each block
classes : int, default 1000
Number of classification classes.
dilated : bool, default False
        Whether to apply the dilation strategy to the pretrained ResNet, yielding a
        stride-8 model typically used in semantic segmentation.
norm_layer : object
Normalization layer used (default: :class:`nn.BatchNorm2d`)
deep_stem : bool, default False
Whether to replace the 7x7 conv1 with 3 3x3 convolution layers.
avg_down : bool, default False
        Whether to use average pooling in the projection shortcut (downsample) between stages.
final_drop : float, default 0.0
Dropout ratio before the final classification layer.
Reference:
- He, Kaiming, et al. "Deep residual learning for image recognition."
Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
- Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
"""
def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32,
avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d):
self.inplanes = stem_width*2 if deep_stem else 64
super(ResNetV1b, self).__init__()
if not deep_stem:
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
else:
self.conv1 = nn.Sequential(
nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
norm_layer(stem_width),
nn.ReLU(True),
nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
norm_layer(stem_width),
nn.ReLU(True),
nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False)
)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(True)
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down,
norm_layer=norm_layer)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down,
norm_layer=norm_layer)
if dilated:
self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2,
avg_down=avg_down, norm_layer=norm_layer)
self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4,
avg_down=avg_down, norm_layer=norm_layer)
else:
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
avg_down=avg_down, norm_layer=norm_layer)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
avg_down=avg_down, norm_layer=norm_layer)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.drop = None
if final_drop > 0.0:
self.drop = nn.Dropout(final_drop)
self.fc = nn.Linear(512 * block.expansion, classes)
def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
avg_down=False, norm_layer=nn.BatchNorm2d):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = []
if avg_down:
if dilation == 1:
downsample.append(
nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False)
)
else:
downsample.append(
nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False)
)
downsample.extend([
nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
kernel_size=1, stride=1, bias=False),
norm_layer(planes * block.expansion)
])
downsample = nn.Sequential(*downsample)
else:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
norm_layer(planes * block.expansion)
)
layers = []
if dilation in (1, 2):
layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample,
previous_dilation=dilation, norm_layer=norm_layer))
elif dilation == 4:
layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample,
previous_dilation=dilation, norm_layer=norm_layer))
else:
raise RuntimeError("=> unknown dilation size: {}".format(dilation))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, dilation=dilation,
previous_dilation=dilation, norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
if self.drop is not None:
x = self.drop(x)
x = self.fc(x)
return x
def _safe_state_dict_filtering(orig_dict, model_dict_keys):
filtered_orig_dict = {}
for k, v in orig_dict.items():
if k in model_dict_keys:
filtered_orig_dict[k] = v
else:
print(f"[ERROR] Failed to load <{k}> in backbone")
return filtered_orig_dict
def resnet34_v1b(pretrained=False, **kwargs):
model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs)
if pretrained:
model_dict = model.state_dict()
filtered_orig_dict = _safe_state_dict_filtering(
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(),
model_dict.keys()
)
model_dict.update(filtered_orig_dict)
model.load_state_dict(model_dict)
return model
def resnet50_v1s(pretrained=False, **kwargs):
model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs)
if pretrained:
model_dict = model.state_dict()
filtered_orig_dict = _safe_state_dict_filtering(
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(),
model_dict.keys()
)
model_dict.update(filtered_orig_dict)
model.load_state_dict(model_dict)
return model
def resnet101_v1s(pretrained=False, **kwargs):
model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs)
if pretrained:
model_dict = model.state_dict()
filtered_orig_dict = _safe_state_dict_filtering(
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(),
model_dict.keys()
)
model_dict.update(filtered_orig_dict)
model.load_state_dict(model_dict)
return model
def resnet152_v1s(pretrained=False, **kwargs):
model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs)
if pretrained:
model_dict = model.state_dict()
filtered_orig_dict = _safe_state_dict_filtering(
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(),
model_dict.keys()
)
model_dict.update(filtered_orig_dict)
model.load_state_dict(model_dict)
return model


@@ -1,83 +0,0 @@
import torch
from torch import nn as nn
import numpy as np
from . import initializer as initializer
from ..utils.cython import get_dist_maps
def select_activation_function(activation):
if isinstance(activation, str):
if activation.lower() == 'relu':
return nn.ReLU
elif activation.lower() == 'softplus':
return nn.Softplus
else:
raise ValueError(f"Unknown activation type {activation}")
elif isinstance(activation, nn.Module):
return activation
else:
raise ValueError(f"Unknown activation type {activation}")
class BilinearConvTranspose2d(nn.ConvTranspose2d):
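    # Transposed convolution whose weights are initialized (via initializer.Bilinear)
    # to perform bilinear upsampling by the given integer scale factor.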
def __init__(self, in_channels, out_channels, scale, groups=1):
kernel_size = 2 * scale - scale % 2
self.scale = scale
super().__init__(
in_channels, out_channels,
kernel_size=kernel_size,
stride=scale,
padding=1,
groups=groups,
bias=False)
self.apply(initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups))
class DistMaps(nn.Module):
def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False):
super(DistMaps, self).__init__()
self.spatial_scale = spatial_scale
self.norm_radius = norm_radius
self.cpu_mode = cpu_mode
def get_coord_features(self, points, batchsize, rows, cols):
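        # Builds two distance maps per sample (positive and negative clicks): squared
        # distance to the nearest click, normalized by norm_radius, then squashed
        # with tanh(2 * sqrt(d)) so that far-away pixels saturate near 1.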
if self.cpu_mode:
coords = []
for i in range(batchsize):
norm_delimeter = self.spatial_scale * self.norm_radius
coords.append(get_dist_maps(points[i].cpu().float().numpy(), rows, cols,
norm_delimeter))
coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float()
else:
num_points = points.shape[1] // 2
points = points.view(-1, 2)
invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0
row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device)
col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device)
coord_rows, coord_cols = torch.meshgrid(row_array, col_array)
coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1)
add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1)
coords.add_(-add_xy)
coords.div_(self.norm_radius * self.spatial_scale)
coords.mul_(coords)
coords[:, 0] += coords[:, 1]
coords = coords[:, :1]
coords[invalid_points, :, :, :] = 1e6
coords = coords.view(-1, num_points, 1, rows, cols)
coords = coords.min(dim=1)[0] # -> (bs * num_masks * 2) x 1 x h x w
coords = coords.view(-1, 2, rows, cols)
coords.sqrt_().mul_(2).tanh_()
return coords
def forward(self, x, coords):
return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3])


@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2018 Tamaki Kojima
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -1,127 +0,0 @@
# pytorch-syncbn
Tamaki Kojima(tamakoji@gmail.com)
## Announcement
**PyTorch 1.0 support**
## Overview
This is an alternative implementation of "Synchronized Multi-GPU Batch Normalization" which computes global statistics across GPUs instead of locally on each one. SyncBN becomes important when the input images are large and multiple GPUs must be used to increase the effective minibatch size for training.
The code was inspired by [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding) and [Inplace-ABN](https://github.com/mapillary/inplace_abn)
## Remarks
- Unlike [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding), you don't need a custom `nn.DataParallel`
- Unlike [Inplace-ABN](https://github.com/mapillary/inplace_abn), you can simply replace your `nn.BatchNorm2d` with this module, since it does not mark tensors for in-place operation
- You can plug it into any module written in PyTorch to enable synchronized BatchNorm
- Backward computation is rewritten and tested against the behavior of `nn.BatchNorm2d`
## Requirements
For PyTorch, please refer to https://pytorch.org/
NOTE: The code is tested only with PyTorch v1.0.0, CUDA 10 / cuDNN 7.4.2 on Ubuntu 18.04.
It utilizes the PyTorch JIT mechanism to compile the extension seamlessly using ninja. Please install ninja-build before use.
```
sudo apt-get install ninja-build
```
Also install all Python dependencies. For pip, run:
```
pip install -U -r requirements.txt
```
## Build
There is no need to build anything manually; just run, and the JIT will take care of compilation.
JIT and C++ extensions are supported from PyTorch 0.4 onward; however, PyTorch >= 1.0 is highly recommended due to major design changes.
## Usage
Please refer to [`test.py`](./test.py) for testing the difference between `nn.BatchNorm2d` and `modules.nn.BatchNorm2d`
```
import torch
import torch.nn as nn
from modules import nn as NN
num_gpu = torch.cuda.device_count()
model = nn.Sequential(
nn.Conv2d(3, 3, 1, 1, bias=False),
NN.BatchNorm2d(3),
nn.ReLU(inplace=True),
nn.Conv2d(3, 3, 1, 1, bias=False),
NN.BatchNorm2d(3),
).cuda()
model = nn.DataParallel(model, device_ids=range(num_gpu))
x = torch.rand(num_gpu, 3, 2, 2).cuda()
z = model(x)
```
## Math
### Forward
1. Compute $\sum{x_i}$ and $\sum{x_i^2}$ on each GPU.
2. Gather all $\sum{x_i}$ and $\sum{x_i^2}$ from the workers to the master and compute $\mu=\frac{\sum{x_i}}{N}$ and $\sigma^2=\frac{\sum{x_i^2}-\mu\sum{x_i}}{N}$; the global stats are then shared with all GPUs, and `running_mean` and `running_var` are updated by a moving average using the global stats (a minimal single-process sketch of steps 1-3 is given after this list).
3. Forward batchnorm using the global stats: $\hat{x_i}=\frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}}$ and then $y_i=\gamma\cdot\hat{x_i}+\beta$, where $\gamma$ is the weight parameter and $\beta$ is the bias parameter.
4. Save $x$, $\gamma$, $\beta$, $\mu$, $\sigma^2$ for the backward pass.
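
As a sanity check on steps (1)-(3), here is a minimal single-process sketch (plain PyTorch, no master/worker queues or CUDA extension) showing that the per-GPU sums are enough to recover the same statistics ordinary batch normalization computes on the full batch:
```
import torch
import torch.nn.functional as F

# Pretend each chunk lives on a different GPU.
chunks = torch.randn(4, 3, 8, 8).chunk(2, dim=0)

# Step 1: per-worker sum(x) and sum(x^2), reduced over N, H, W for each channel.
sums = [c.sum(dim=(0, 2, 3)) for c in chunks]
sqsums = [(c * c).sum(dim=(0, 2, 3)) for c in chunks]

# Step 2: the "master" combines them into global mean and (biased) variance.
N = sum(c.numel() // c.size(1) for c in chunks)
xsum, xsqsum = sum(sums), sum(sqsums)
mean = xsum / N
var = (xsqsum - xsum * mean) / N

# Step 3: normalize with the global statistics (gamma = 1, beta = 0 here).
x = torch.cat(chunks, dim=0)
eps = 1e-5
xhat = (x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + eps)

# Should match BatchNorm statistics computed on the full batch in one go.
ref = F.batch_norm(x, None, None, training=True, eps=eps)
print(torch.allclose(xhat, ref, atol=1e-4))
```
The extension performs the same reduction, but across devices (via the master/worker queues) and inside the CUDA kernels.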
### Backward
1. Restore the saved $x$, $\gamma$, $\beta$, $\mu$, $\sigma^2$.
2. Compute the per-GPU sums $\sum_{i=1}^{N_j}\frac{dJ}{dy_i}$ and $\sum_{i=1}^{N_j}\left(\frac{dJ}{dy_i}\cdot\hat{x_i}\right)$ for each GPU $j$, then gather them at the master node, sum them globally, and normalize by $N$, the total number of elements per channel. The global sums are then shared among all GPUs.
3. Compute the gradients $\frac{dJ}{dx_i}$, $\frac{dJ}{d\gamma}$, $\frac{dJ}{d\beta}$ using the global stats, where
$$\frac{dJ}{d\gamma}=\sum_{i=1}^{N}\left(\frac{dJ}{dy_i}\cdot\hat{x_i}\right), \qquad \frac{dJ}{d\beta}=\sum_{i=1}^{N}\frac{dJ}{dy_i},$$
and finally
$$\frac{dJ}{dx_i}=\frac{dJ}{d\hat{x_i}}\frac{d\hat{x_i}}{dx_i}+\frac{dJ}{d\mu}\frac{d\mu}{dx_i}+\frac{dJ}{d\sigma^2}\frac{d\sigma^2}{dx_i}
=\frac{1}{N\sqrt{\sigma^2+\epsilon}}\left(N\frac{dJ}{d\hat{x_i}}-\sum_{j=1}^{N}\frac{dJ}{d\hat{x_j}}-\hat{x_i}\sum_{j=1}^{N}\frac{dJ}{d\hat{x_j}}\hat{x_j}\right)
=\frac{\gamma}{N\sqrt{\sigma^2+\epsilon}}\left(N\frac{dJ}{dy_i}-\sum_{j=1}^{N}\frac{dJ}{dy_j}-\hat{x_i}\sum_{j=1}^{N}\frac{dJ}{dy_j}\hat{x_j}\right).$$
Note that in the implementation the normalization by $N$ is performed at step (2), so the code and the equation above are not literally identical, but they are mathematically equivalent.
You can read a deeper explanation of the above at [Kevin Zakka's Blog](https://kevinzakka.github.io/2016/09/14/batch_normalization/)


@@ -1 +0,0 @@
from .syncbn import batchnorm2d_sync


@@ -1,54 +0,0 @@
"""
/*****************************************************************************/
Extension module loader
code referenced from : https://github.com/facebookresearch/maskrcnn-benchmark
/*****************************************************************************/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import os.path
import torch
try:
from torch.utils.cpp_extension import load
from torch.utils.cpp_extension import CUDA_HOME
except ImportError:
raise ImportError(
"The cpp layer extensions requires PyTorch 0.4 or higher")
def _load_C_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
this_dir = os.path.join(this_dir, "csrc")
main_file = glob.glob(os.path.join(this_dir, "*.cpp"))
sources_cpu = glob.glob(os.path.join(this_dir, "cpu", "*.cpp"))
sources_cuda = glob.glob(os.path.join(this_dir, "cuda", "*.cu"))
sources = main_file + sources_cpu
extra_cflags = []
extra_cuda_cflags = []
if torch.cuda.is_available() and CUDA_HOME is not None:
sources.extend(sources_cuda)
extra_cflags = ["-O3", "-DWITH_CUDA"]
extra_cuda_cflags = ["--expt-extended-lambda"]
sources = [os.path.join(this_dir, s) for s in sources]
extra_include_paths = [this_dir]
return load(
name="ext_lib",
sources=sources,
extra_cflags=extra_cflags,
extra_include_paths=extra_include_paths,
extra_cuda_cflags=extra_cuda_cflags,
)
_backend = _load_C_extensions()


@@ -1,70 +0,0 @@
/*****************************************************************************
SyncBN
*****************************************************************************/
#pragma once
#ifdef WITH_CUDA
#include "cuda/ext_lib.h"
#endif
/// SyncBN
std::vector<at::Tensor> syncbn_sum_sqsum(const at::Tensor& x) {
if (x.is_cuda()) {
#ifdef WITH_CUDA
return syncbn_sum_sqsum_cuda(x);
#else
AT_ERROR("Not compiled with GPU support");
#endif
} else {
AT_ERROR("CPU implementation not supported");
}
}
at::Tensor syncbn_forward(const at::Tensor& x, const at::Tensor& weight,
const at::Tensor& bias, const at::Tensor& mean,
const at::Tensor& var, bool affine, float eps) {
if (x.is_cuda()) {
#ifdef WITH_CUDA
return syncbn_forward_cuda(x, weight, bias, mean, var, affine, eps);
#else
AT_ERROR("Not compiled with GPU support");
#endif
} else {
AT_ERROR("CPU implementation not supported");
}
}
std::vector<at::Tensor> syncbn_backward_xhat(const at::Tensor& dz,
const at::Tensor& x,
const at::Tensor& mean,
const at::Tensor& var, float eps) {
if (dz.is_cuda()) {
#ifdef WITH_CUDA
return syncbn_backward_xhat_cuda(dz, x, mean, var, eps);
#else
AT_ERROR("Not compiled with GPU support");
#endif
} else {
AT_ERROR("CPU implementation not supported");
}
}
std::vector<at::Tensor> syncbn_backward(
const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight,
const at::Tensor& bias, const at::Tensor& mean, const at::Tensor& var,
const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine,
float eps) {
if (dz.is_cuda()) {
#ifdef WITH_CUDA
return syncbn_backward_cuda(dz, x, weight, bias, mean, var, sum_dz,
sum_dz_xhat, affine, eps);
#else
AT_ERROR("Not compiled with GPU support");
#endif
} else {
AT_ERROR("CPU implementation not supported");
}
}


@@ -1,280 +0,0 @@
/*****************************************************************************
CUDA SyncBN code
code referenced from : https://github.com/mapillary/inplace_abn
*****************************************************************************/
#include <ATen/ATen.h>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
#include <vector>
#include "cuda/common.h"
// Utilities
void get_dims(at::Tensor x, int64_t &num, int64_t &chn, int64_t &sp) {
num = x.size(0);
chn = x.size(1);
sp = 1;
for (int64_t i = 2; i < x.ndimension(); ++i) sp *= x.size(i);
}
/// SyncBN
template <typename T>
struct SqSumOp {
__device__ SqSumOp(const T *t, int c, int s) : tensor(t), chn(c), sp(s) {}
__device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
T x = tensor[(batch * chn + plane) * sp + n];
return Pair<T>(x, x * x); // x, x^2
}
const T *tensor;
const int chn;
const int sp;
};
template <typename T>
__global__ void syncbn_sum_sqsum_kernel(const T *x, T *sum, T *sqsum,
int num, int chn, int sp) {
int plane = blockIdx.x;
Pair<T> res =
reduce<Pair<T>, SqSumOp<T>>(SqSumOp<T>(x, chn, sp), plane, num, chn, sp);
__syncthreads();
if (threadIdx.x == 0) {
sum[plane] = res.v1;
sqsum[plane] = res.v2;
}
}
std::vector<at::Tensor> syncbn_sum_sqsum_cuda(const at::Tensor &x) {
CHECK_INPUT(x);
// Extract dimensions
int64_t num, chn, sp;
get_dims(x, num, chn, sp);
// Prepare output tensors
auto sum = at::empty({chn}, x.options());
auto sqsum = at::empty({chn}, x.options());
// Run kernel
dim3 blocks(chn);
dim3 threads(getNumThreads(sp));
AT_DISPATCH_FLOATING_TYPES(
x.type(), "syncbn_sum_sqsum_cuda", ([&] {
syncbn_sum_sqsum_kernel<scalar_t><<<blocks, threads>>>(
x.data<scalar_t>(), sum.data<scalar_t>(),
sqsum.data<scalar_t>(), num, chn, sp);
}));
return {sum, sqsum};
}
template <typename T>
__global__ void syncbn_forward_kernel(T *z, const T *x, const T *weight,
const T *bias, const T *mean,
const T *var, bool affine, float eps,
int num, int chn, int sp) {
int plane = blockIdx.x;
T _mean = mean[plane];
T _var = var[plane];
T _weight = affine ? weight[plane] : T(1);
T _bias = affine ? bias[plane] : T(0);
float _invstd = T(0);
if (_var || eps) {
_invstd = rsqrt(_var + eps);
}
for (int batch = 0; batch < num; ++batch) {
for (int n = threadIdx.x; n < sp; n += blockDim.x) {
T _x = x[(batch * chn + plane) * sp + n];
T _xhat = (_x - _mean) * _invstd;
T _z = _xhat * _weight + _bias;
z[(batch * chn + plane) * sp + n] = _z;
}
}
}
at::Tensor syncbn_forward_cuda(const at::Tensor &x, const at::Tensor &weight,
const at::Tensor &bias, const at::Tensor &mean,
const at::Tensor &var, bool affine, float eps) {
CHECK_INPUT(x);
CHECK_INPUT(weight);
CHECK_INPUT(bias);
CHECK_INPUT(mean);
CHECK_INPUT(var);
// Extract dimensions
int64_t num, chn, sp;
get_dims(x, num, chn, sp);
auto z = at::zeros_like(x);
// Run kernel
dim3 blocks(chn);
dim3 threads(getNumThreads(sp));
AT_DISPATCH_FLOATING_TYPES(
x.type(), "syncbn_forward_cuda", ([&] {
syncbn_forward_kernel<scalar_t><<<blocks, threads>>>(
z.data<scalar_t>(), x.data<scalar_t>(),
weight.data<scalar_t>(), bias.data<scalar_t>(),
mean.data<scalar_t>(), var.data<scalar_t>(),
affine, eps, num, chn, sp);
}));
return z;
}
template <typename T>
struct XHatOp {
__device__ XHatOp(T _weight, T _bias, const T *_dz, const T *_x, int c, int s)
: weight(_weight), bias(_bias), x(_x), dz(_dz), chn(c), sp(s) {}
__device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
// xhat = (x - bias) * weight
T _xhat = (x[(batch * chn + plane) * sp + n] - bias) * weight;
// dxhat * x_hat
T _dz = dz[(batch * chn + plane) * sp + n];
return Pair<T>(_dz, _dz * _xhat);
}
const T weight;
const T bias;
const T *dz;
const T *x;
const int chn;
const int sp;
};
template <typename T>
__global__ void syncbn_backward_xhat_kernel(const T *dz, const T *x,
const T *mean, const T *var,
T *sum_dz, T *sum_dz_xhat,
float eps, int num, int chn,
int sp) {
int plane = blockIdx.x;
T _mean = mean[plane];
T _var = var[plane];
T _invstd = T(0);
if (_var || eps) {
_invstd = rsqrt(_var + eps);
}
Pair<T> res = reduce<Pair<T>, XHatOp<T>>(
XHatOp<T>(_invstd, _mean, dz, x, chn, sp), plane, num, chn, sp);
__syncthreads();
if (threadIdx.x == 0) {
// \sum(\frac{dJ}{dy_i})
sum_dz[plane] = res.v1;
// \sum(\frac{dJ}{dy_i}*\hat{x_i})
sum_dz_xhat[plane] = res.v2;
}
}
std::vector<at::Tensor> syncbn_backward_xhat_cuda(const at::Tensor &dz,
const at::Tensor &x,
const at::Tensor &mean,
const at::Tensor &var,
float eps) {
CHECK_INPUT(dz);
CHECK_INPUT(x);
CHECK_INPUT(mean);
CHECK_INPUT(var);
// Extract dimensions
int64_t num, chn, sp;
get_dims(x, num, chn, sp);
// Prepare output tensors
auto sum_dz = at::empty({chn}, x.options());
auto sum_dz_xhat = at::empty({chn}, x.options());
// Run kernel
dim3 blocks(chn);
dim3 threads(getNumThreads(sp));
AT_DISPATCH_FLOATING_TYPES(
x.type(), "syncbn_backward_xhat_cuda", ([&] {
syncbn_backward_xhat_kernel<scalar_t><<<blocks, threads>>>(
dz.data<scalar_t>(), x.data<scalar_t>(), mean.data<scalar_t>(),
var.data<scalar_t>(), sum_dz.data<scalar_t>(),
sum_dz_xhat.data<scalar_t>(), eps, num, chn, sp);
}));
return {sum_dz, sum_dz_xhat};
}
template <typename T>
__global__ void syncbn_backward_kernel(const T *dz, const T *x, const T *weight,
const T *bias, const T *mean,
const T *var, const T *sum_dz,
const T *sum_dz_xhat, T *dx, T *dweight,
T *dbias, bool affine, float eps,
int num, int chn, int sp) {
int plane = blockIdx.x;
T _mean = mean[plane];
T _var = var[plane];
T _weight = affine ? weight[plane] : T(1);
T _sum_dz = sum_dz[plane];
T _sum_dz_xhat = sum_dz_xhat[plane];
T _invstd = T(0);
if (_var || eps) {
_invstd = rsqrt(_var + eps);
}
/*
\frac{dJ}{dx_i} = \frac{1}{N\sqrt{(\sigma^2+\epsilon)}} (
N\frac{dJ}{d\hat{x_i}} -
\sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}) -
\hat{x_i}\sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}\hat{x_j})
)
Note : N is omitted here since it will be accumulated and
_sum_dz and _sum_dz_xhat expected to be already normalized
before the call.
*/
if (dx) {
T _mul = _weight * _invstd;
for (int batch = 0; batch < num; ++batch) {
for (int n = threadIdx.x; n < sp; n += blockDim.x) {
T _dz = dz[(batch * chn + plane) * sp + n];
T _xhat = (x[(batch * chn + plane) * sp + n] - _mean) * _invstd;
T _dx = (_dz - _sum_dz - _xhat * _sum_dz_xhat) * _mul;
dx[(batch * chn + plane) * sp + n] = _dx;
}
}
}
__syncthreads();
if (threadIdx.x == 0) {
if (affine) {
T _norm = num * sp;
dweight[plane] += _sum_dz_xhat * _norm;
dbias[plane] += _sum_dz * _norm;
}
}
}
std::vector<at::Tensor> syncbn_backward_cuda(
const at::Tensor &dz, const at::Tensor &x, const at::Tensor &weight,
const at::Tensor &bias, const at::Tensor &mean, const at::Tensor &var,
const at::Tensor &sum_dz, const at::Tensor &sum_dz_xhat, bool affine,
float eps) {
CHECK_INPUT(dz);
CHECK_INPUT(x);
CHECK_INPUT(weight);
CHECK_INPUT(bias);
CHECK_INPUT(mean);
CHECK_INPUT(var);
CHECK_INPUT(sum_dz);
CHECK_INPUT(sum_dz_xhat);
// Extract dimensions
int64_t num, chn, sp;
get_dims(x, num, chn, sp);
// Prepare output tensors
auto dx = at::zeros_like(dz);
auto dweight = at::zeros_like(weight);
auto dbias = at::zeros_like(bias);
// Run kernel
dim3 blocks(chn);
dim3 threads(getNumThreads(sp));
AT_DISPATCH_FLOATING_TYPES(
x.type(), "syncbn_backward_cuda", ([&] {
syncbn_backward_kernel<scalar_t><<<blocks, threads>>>(
dz.data<scalar_t>(), x.data<scalar_t>(), weight.data<scalar_t>(),
bias.data<scalar_t>(), mean.data<scalar_t>(), var.data<scalar_t>(),
sum_dz.data<scalar_t>(), sum_dz_xhat.data<scalar_t>(),
dx.data<scalar_t>(), dweight.data<scalar_t>(),
dbias.data<scalar_t>(), affine, eps, num, chn, sp);
}));
return {dx, dweight, dbias};
}


@@ -1,124 +0,0 @@
/*****************************************************************************
CUDA utility funcs
code referenced from : https://github.com/mapillary/inplace_abn
*****************************************************************************/
#pragma once
#include <cuda_runtime_api.h>
// Checks
#ifndef AT_CHECK
#define AT_CHECK AT_ASSERT
#endif
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
/*
* General settings
*/
const int WARP_SIZE = 32;
const int MAX_BLOCK_SIZE = 512;
template <typename T>
struct Pair {
T v1, v2;
__device__ Pair() {}
__device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
__device__ Pair(T v) : v1(v), v2(v) {}
__device__ Pair(int v) : v1(v), v2(v) {}
__device__ Pair &operator+=(const Pair<T> &a) {
v1 += a.v1;
v2 += a.v2;
return *this;
}
};
/*
* Utility functions
*/
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask,
int width = warpSize,
unsigned int mask = 0xffffffff) {
#if CUDART_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
__device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
static int getNumThreads(int nElem) {
int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
for (int i = 0; i != 5; ++i) {
if (nElem <= threadSizes[i]) {
return threadSizes[i];
}
}
return MAX_BLOCK_SIZE;
}
template <typename T>
static __device__ __forceinline__ T warpSum(T val) {
#if __CUDA_ARCH__ >= 300
for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
}
#else
__shared__ T values[MAX_BLOCK_SIZE];
values[threadIdx.x] = val;
__threadfence_block();
const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
for (int i = 1; i < WARP_SIZE; i++) {
val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
}
#endif
return val;
}
template <typename T>
static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
value.v1 = warpSum(value.v1);
value.v2 = warpSum(value.v2);
return value;
}
template <typename T, typename Op>
__device__ T reduce(Op op, int plane, int N, int C, int S) {
T sum = (T)0;
for (int batch = 0; batch < N; ++batch) {
for (int x = threadIdx.x; x < S; x += blockDim.x) {
sum += op(batch, plane, x);
}
}
// sum over NumThreads within a warp
sum = warpSum(sum);
// 'transpose', and reduce within warp again
__shared__ T shared[32];
__syncthreads();
if (threadIdx.x % WARP_SIZE == 0) {
shared[threadIdx.x / WARP_SIZE] = sum;
}
if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
// zero out the other entries in shared
shared[threadIdx.x] = (T)0;
}
__syncthreads();
if (threadIdx.x / WARP_SIZE == 0) {
sum = warpSum(shared[threadIdx.x]);
if (threadIdx.x == 0) {
shared[0] = sum;
}
}
__syncthreads();
// Everyone picks it up, should be broadcast into the whole gradInput
return shared[0];
}


@@ -1,24 +0,0 @@
/*****************************************************************************
CUDA SyncBN code
*****************************************************************************/
#pragma once
#include <torch/extension.h>
#include <vector>
/// Sync-BN
std::vector<at::Tensor> syncbn_sum_sqsum_cuda(const at::Tensor& x);
at::Tensor syncbn_forward_cuda(const at::Tensor& x, const at::Tensor& weight,
const at::Tensor& bias, const at::Tensor& mean,
const at::Tensor& var, bool affine, float eps);
std::vector<at::Tensor> syncbn_backward_xhat_cuda(const at::Tensor& dz,
const at::Tensor& x,
const at::Tensor& mean,
const at::Tensor& var,
float eps);
std::vector<at::Tensor> syncbn_backward_cuda(
const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight,
const at::Tensor& bias, const at::Tensor& mean, const at::Tensor& var,
const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine,
float eps);


@@ -1,10 +0,0 @@
#include "bn.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("syncbn_sum_sqsum", &syncbn_sum_sqsum, "Sum and Sum^2 computation");
m.def("syncbn_forward", &syncbn_forward, "SyncBN forward computation");
m.def("syncbn_backward_xhat", &syncbn_backward_xhat,
"First part of SyncBN backward computation");
m.def("syncbn_backward", &syncbn_backward,
"Second part of SyncBN backward computation");
}


@@ -1,137 +0,0 @@
"""
/*****************************************************************************/
BatchNorm2dSync with multi-gpu
code referenced from : https://github.com/mapillary/inplace_abn
/*****************************************************************************/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch.cuda.comm as comm
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from ._csrc import _backend
def _count_samples(x):
count = 1
for i, s in enumerate(x.size()):
if i != 1:
count *= s
return count
class BatchNorm2dSyncFunc(Function):
@staticmethod
def forward(ctx, x, weight, bias, running_mean, running_var,
extra, compute_stats=True, momentum=0.1, eps=1e-05):
def _parse_extra(ctx, extra):
ctx.is_master = extra["is_master"]
if ctx.is_master:
ctx.master_queue = extra["master_queue"]
ctx.worker_queues = extra["worker_queues"]
ctx.worker_ids = extra["worker_ids"]
else:
ctx.master_queue = extra["master_queue"]
ctx.worker_queue = extra["worker_queue"]
# Save context
if extra is not None:
_parse_extra(ctx, extra)
ctx.compute_stats = compute_stats
ctx.momentum = momentum
ctx.eps = eps
ctx.affine = weight is not None and bias is not None
if ctx.compute_stats:
N = _count_samples(x) * (ctx.master_queue.maxsize + 1)
assert N > 1
# 1. compute sum(x) and sum(x^2)
xsum, xsqsum = _backend.syncbn_sum_sqsum(x.detach())
if ctx.is_master:
xsums, xsqsums = [xsum], [xsqsum]
                # master: gather all sum(x) and sum(x^2) from the workers
for _ in range(ctx.master_queue.maxsize):
xsum_w, xsqsum_w = ctx.master_queue.get()
ctx.master_queue.task_done()
xsums.append(xsum_w)
xsqsums.append(xsqsum_w)
xsum = comm.reduce_add(xsums)
xsqsum = comm.reduce_add(xsqsums)
mean = xsum / N
sumvar = xsqsum - xsum * mean
var = sumvar / N
uvar = sumvar / (N - 1)
# master : broadcast global mean, variance to all slaves
tensors = comm.broadcast_coalesced(
(mean, uvar, var), [mean.get_device()] + ctx.worker_ids)
for ts, queue in zip(tensors[1:], ctx.worker_queues):
queue.put(ts)
else:
# slave : send sum(x) and sum(x^2) to master
ctx.master_queue.put((xsum, xsqsum))
# slave : get global mean and variance
mean, uvar, var = ctx.worker_queue.get()
ctx.worker_queue.task_done()
# Update running stats
running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar)
ctx.N = N
ctx.save_for_backward(x, weight, bias, mean, var)
else:
mean, var = running_mean, running_var
# do batch norm forward
z = _backend.syncbn_forward(x, weight, bias, mean, var,
ctx.affine, ctx.eps)
return z
@staticmethod
@once_differentiable
def backward(ctx, dz):
x, weight, bias, mean, var = ctx.saved_tensors
dz = dz.contiguous()
# 1. compute \sum(\frac{dJ}{dy_i}) and \sum(\frac{dJ}{dy_i}*\hat{x_i})
sum_dz, sum_dz_xhat = _backend.syncbn_backward_xhat(
dz, x, mean, var, ctx.eps)
if ctx.is_master:
sum_dzs, sum_dz_xhats = [sum_dz], [sum_dz_xhat]
            # master: gather from the workers
for _ in range(ctx.master_queue.maxsize):
sum_dz_w, sum_dz_xhat_w = ctx.master_queue.get()
ctx.master_queue.task_done()
sum_dzs.append(sum_dz_w)
sum_dz_xhats.append(sum_dz_xhat_w)
# master : compute global stats
sum_dz = comm.reduce_add(sum_dzs)
sum_dz_xhat = comm.reduce_add(sum_dz_xhats)
sum_dz /= ctx.N
sum_dz_xhat /= ctx.N
# master : broadcast global stats
tensors = comm.broadcast_coalesced(
(sum_dz, sum_dz_xhat), [mean.get_device()] + ctx.worker_ids)
for ts, queue in zip(tensors[1:], ctx.worker_queues):
queue.put(ts)
else:
# slave : send to master
ctx.master_queue.put((sum_dz, sum_dz_xhat))
# slave : get global stats
sum_dz, sum_dz_xhat = ctx.worker_queue.get()
ctx.worker_queue.task_done()
# do batch norm backward
dx, dweight, dbias = _backend.syncbn_backward(
dz, x, weight, bias, mean, var, sum_dz, sum_dz_xhat,
ctx.affine, ctx.eps)
return dx, dweight, dbias, \
None, None, None, None, None, None
batchnorm2d_sync = BatchNorm2dSyncFunc.apply
__all__ = ["batchnorm2d_sync"]


@@ -1 +0,0 @@
from .syncbn import *


@@ -1,148 +0,0 @@
"""
/*****************************************************************************/
BatchNorm2dSync with multi-gpu
/*****************************************************************************/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
try:
# python 3
from queue import Queue
except ImportError:
# python 2
from Queue import Queue
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from isegm.model.syncbn.modules.functional import batchnorm2d_sync
class _BatchNorm(nn.Module):
"""
    Customized BatchNorm based on nn.BatchNorm
    >> adds a freeze attribute (`freezed`) to allow freezing the BN statistics.
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
track_running_stats=True):
super(_BatchNorm, self).__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.affine = affine
self.track_running_stats = track_running_stats
self.freezed = False
if self.affine:
self.weight = Parameter(torch.Tensor(num_features))
self.bias = Parameter(torch.Tensor(num_features))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
if self.track_running_stats:
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
else:
self.register_parameter('running_mean', None)
self.register_parameter('running_var', None)
self.reset_parameters()
def reset_parameters(self):
if self.track_running_stats:
self.running_mean.zero_()
self.running_var.fill_(1)
if self.affine:
self.weight.data.uniform_()
self.bias.data.zero_()
def _check_input_dim(self, input):
        raise NotImplementedError
def forward(self, input):
self._check_input_dim(input)
compute_stats = not self.freezed and \
self.training and self.track_running_stats
ret = F.batch_norm(input, self.running_mean, self.running_var,
self.weight, self.bias, compute_stats,
self.momentum, self.eps)
return ret
def extra_repr(self):
return '{num_features}, eps={eps}, momentum={momentum}, '\
'affine={affine}, ' \
'track_running_stats={track_running_stats}'.format(
**self.__dict__)
class BatchNorm2dNoSync(_BatchNorm):
"""
Equivalent to nn.BatchNorm2d
"""
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
.format(input.dim()))
class BatchNorm2dSync(BatchNorm2dNoSync):
"""
BatchNorm2d with automatic multi-GPU Sync
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
track_running_stats=True):
super(BatchNorm2dSync, self).__init__(
num_features, eps=eps, momentum=momentum, affine=affine,
track_running_stats=track_running_stats)
self.sync_enabled = True
self.devices = list(range(torch.cuda.device_count()))
if len(self.devices) > 1:
# Initialize queues
self.worker_ids = self.devices[1:]
self.master_queue = Queue(len(self.worker_ids))
self.worker_queues = [Queue(1) for _ in self.worker_ids]
def forward(self, x):
compute_stats = not self.freezed and \
self.training and self.track_running_stats
if self.sync_enabled and compute_stats and len(self.devices) > 1:
if x.get_device() == self.devices[0]:
# Master mode
extra = {
"is_master": True,
"master_queue": self.master_queue,
"worker_queues": self.worker_queues,
"worker_ids": self.worker_ids
}
else:
# Worker mode
extra = {
"is_master": False,
"master_queue": self.master_queue,
"worker_queue": self.worker_queues[
self.worker_ids.index(x.get_device())]
}
return batchnorm2d_sync(x, self.weight, self.bias,
self.running_mean, self.running_var,
extra, compute_stats, self.momentum,
self.eps)
return super(BatchNorm2dSync, self).forward(x)
def __repr__(self):
"""repr"""
rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
'affine={affine}, ' \
'track_running_stats={track_running_stats},' \
'devices={devices})'
return rep.format(name=self.__class__.__name__, **self.__dict__)
#BatchNorm2d = BatchNorm2dNoSync
BatchNorm2d = BatchNorm2dSync


@@ -1,2 +0,0 @@
# noinspection PyUnresolvedReferences
from .dist_maps import get_dist_maps


@@ -1,63 +0,0 @@
import numpy as np
cimport cython
cimport numpy as np
from libc.stdlib cimport malloc, free
ctypedef struct qnode:
int row
int col
int layer
int orig_row
int orig_col
@cython.infer_types(True)
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def get_dist_maps(np.ndarray[np.float32_t, ndim=2, mode="c"] points,
int height, int width, float norm_delimeter):
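    # Breadth-first flood fill from every click over the 4-connected grid, keeping the
    # smallest normalized squared distance per cell; layer 0 holds positive clicks,
    # layer 1 negative clicks (the second half of `points`).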
cdef np.ndarray[np.float32_t, ndim=3, mode="c"] dist_maps = \
np.full((2, height, width), 1e6, dtype=np.float32, order="C")
cdef int *dxy = [-1, 0, 0, -1, 0, 1, 1, 0]
cdef int i, j, x, y, dx, dy
cdef qnode v
cdef qnode *q = <qnode *> malloc((4 * height * width + 1) * sizeof(qnode))
cdef int qhead = 0, qtail = -1
cdef float ndist
for i in range(points.shape[0]):
x, y = round(points[i, 0]), round(points[i, 1])
if x >= 0:
qtail += 1
q[qtail].row = x
q[qtail].col = y
q[qtail].orig_row = x
q[qtail].orig_col = y
if i >= points.shape[0] / 2:
q[qtail].layer = 1
else:
q[qtail].layer = 0
dist_maps[q[qtail].layer, x, y] = 0
while qtail - qhead + 1 > 0:
v = q[qhead]
qhead += 1
for k in range(4):
x = v.row + dxy[2 * k]
y = v.col + dxy[2 * k + 1]
ndist = ((x - v.orig_row)/norm_delimeter) ** 2 + ((y - v.orig_col)/norm_delimeter) ** 2
if (x >= 0 and y >= 0 and x < height and y < width and
dist_maps[v.layer, x, y] > ndist):
qtail += 1
q[qtail].orig_col = v.orig_col
q[qtail].orig_row = v.orig_row
q[qtail].layer = v.layer
q[qtail].row = x
q[qtail].col = y
dist_maps[v.layer, x, y] = ndist
free(q)
return dist_maps


@@ -1,7 +0,0 @@
import numpy
def make_ext(modname, pyxfilename):
from distutils.extension import Extension
return Extension(modname, [pyxfilename],
include_dirs=[numpy.get_include()],
extra_compile_args=['-O3'], language='c++')


@@ -1,3 +0,0 @@
import pyximport; pyximport.install(pyximport=True, language_level=3)
# noinspection PyUnresolvedReferences
from ._get_dist_maps import get_dist_maps


@@ -1,62 +0,0 @@
from functools import partial
import torch
import numpy as np
def get_dims_with_exclusion(dim, exclude=None):
dims = list(range(dim))
if exclude is not None:
dims.remove(exclude)
return dims
def get_unique_labels(mask):
return np.nonzero(np.bincount(mask.flatten() + 1))[0] - 1
def get_bbox_from_mask(mask):
rows = np.any(mask, axis=1)
cols = np.any(mask, axis=0)
rmin, rmax = np.where(rows)[0][[0, -1]]
cmin, cmax = np.where(cols)[0][[0, -1]]
return rmin, rmax, cmin, cmax
def expand_bbox(bbox, expand_ratio, min_crop_size=None):
rmin, rmax, cmin, cmax = bbox
rcenter = 0.5 * (rmin + rmax)
ccenter = 0.5 * (cmin + cmax)
height = expand_ratio * (rmax - rmin + 1)
width = expand_ratio * (cmax - cmin + 1)
if min_crop_size is not None:
height = max(height, min_crop_size)
width = max(width, min_crop_size)
rmin = int(round(rcenter - 0.5 * height))
rmax = int(round(rcenter + 0.5 * height))
cmin = int(round(ccenter - 0.5 * width))
cmax = int(round(ccenter + 0.5 * width))
return rmin, rmax, cmin, cmax
def clamp_bbox(bbox, rmin, rmax, cmin, cmax):
return (max(rmin, bbox[0]), min(rmax, bbox[1]),
max(cmin, bbox[2]), min(cmax, bbox[3]))
def get_bbox_iou(b1, b2):
h_iou = get_segments_iou(b1[:2], b2[:2])
w_iou = get_segments_iou(b1[2:4], b2[2:4])
return h_iou * w_iou
def get_segments_iou(s1, s2):
a, b = s1
c, d = s2
intersection = max(0, min(b, d) - max(a, c) + 1)
union = max(1e-6, max(b, d) - min(a, c) + 1)
return intersection / union
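
These helpers are pure NumPy, so a tiny sketch (toy values, not from the original code) shows how they compose when cropping around an object:

```python
import numpy as np

mask = np.zeros((10, 10), dtype=np.uint8)
mask[2:5, 3:8] = 1

bbox = get_bbox_from_mask(mask)                        # (rmin, rmax, cmin, cmax) = (2, 4, 3, 7)
expanded = expand_bbox(bbox, expand_ratio=1.4, min_crop_size=6)
clamped = clamp_bbox(expanded, 0, mask.shape[0] - 1,
                     0, mask.shape[1] - 1)             # keep the crop inside the image
iou = get_bbox_iou(bbox, clamped)                      # overlap of the tight vs. expanded box
```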

View File

@@ -1,129 +0,0 @@
from functools import lru_cache
import cv2
import numpy as np
def visualize_instances(imask, bg_color=255,
boundaries_color=None, boundaries_width=1, boundaries_alpha=0.8):
num_objects = imask.max() + 1
palette = get_palette(num_objects)
if bg_color is not None:
palette[0] = bg_color
result = palette[imask].astype(np.uint8)
if boundaries_color is not None:
boundaries_mask = get_boundaries(imask, boundaries_width=boundaries_width)
tresult = result.astype(np.float32)
tresult[boundaries_mask] = boundaries_color
tresult = tresult * boundaries_alpha + (1 - boundaries_alpha) * result
result = tresult.astype(np.uint8)
return result
@lru_cache(maxsize=16)
def get_palette(num_cls):
palette = np.zeros(3 * num_cls, dtype=np.int32)
for j in range(0, num_cls):
lab = j
i = 0
while lab > 0:
palette[j*3 + 0] |= (((lab >> 0) & 1) << (7-i))
palette[j*3 + 1] |= (((lab >> 1) & 1) << (7-i))
palette[j*3 + 2] |= (((lab >> 2) & 1) << (7-i))
i = i + 1
lab >>= 3
return palette.reshape((-1, 3))
def visualize_mask(mask, num_cls):
palette = get_palette(num_cls)
mask[mask == -1] = 0
return palette[mask].astype(np.uint8)
def visualize_proposals(proposals_info, point_color=(255, 0, 0), point_radius=1):
proposal_map, colors, candidates = proposals_info
proposal_map = draw_probmap(proposal_map)
for x, y in candidates:
proposal_map = cv2.circle(proposal_map, (y, x), point_radius, point_color, -1)
return proposal_map
def draw_probmap(x):
return cv2.applyColorMap((x * 255).astype(np.uint8), cv2.COLORMAP_HOT)
def draw_points(image, points, color, radius=3):
image = image.copy()
for p in points:
image = cv2.circle(image, (int(p[1]), int(p[0])), radius, color, -1)
return image
def draw_instance_map(x, palette=None):
num_colors = x.max() + 1
if palette is None:
palette = get_palette(num_colors)
return palette[x].astype(np.uint8)
def blend_mask(image, mask, alpha=0.6):
if mask.min() == -1:
mask = mask.copy() + 1
imap = draw_instance_map(mask)
result = (image * (1 - alpha) + alpha * imap).astype(np.uint8)
return result
def get_boundaries(instances_masks, boundaries_width=1):
boundaries = np.zeros((instances_masks.shape[0], instances_masks.shape[1]), dtype=bool)  # np.bool was removed in NumPy 1.24; use the builtin
for obj_id in np.unique(instances_masks.flatten()):
if obj_id == 0:
continue
obj_mask = instances_masks == obj_id
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
inner_mask = cv2.erode(obj_mask.astype(np.uint8), kernel, iterations=boundaries_width).astype(bool)
obj_boundary = np.logical_xor(obj_mask, np.logical_and(inner_mask, obj_mask))
boundaries = np.logical_or(boundaries, obj_boundary)
return boundaries
def draw_with_blend_and_clicks(img, mask=None, alpha=0.6, clicks_list=None, pos_color=(0, 255, 0),
neg_color=(255, 0, 0), radius=4):
result = img.copy()
if mask is not None:
palette = get_palette(np.max(mask) + 1)
rgb_mask = palette[mask.astype(np.uint8)]
mask_region = (mask > 0).astype(np.uint8)
result = result * (1 - mask_region[:, :, np.newaxis]) + \
(1 - alpha) * mask_region[:, :, np.newaxis] * result + \
alpha * rgb_mask
result = result.astype(np.uint8)
# result = (result * (1 - alpha) + alpha * rgb_mask).astype(np.uint8)
if clicks_list is not None and len(clicks_list) > 0:
pos_points = [click.coords for click in clicks_list if click.is_positive]
neg_points = [click.coords for click in clicks_list if not click.is_positive]
result = draw_points(result, pos_points, pos_color, radius=radius)
result = draw_points(result, neg_points, neg_color, radius=radius)
return result
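
A short sketch of the palette helpers above on a toy instance map (illustrative only; it assumes this module is importable):

```python
import numpy as np

imask = np.zeros((4, 4), dtype=np.int32)
imask[1:3, 1:3] = 1          # object 1
imask[0, 3] = 2              # object 2

palette = get_palette(3)     # (3, 3) PASCAL-style palette; index 0 is background
colored = visualize_mask(imask, 3)                                   # (4, 4, 3) uint8 image
blended = visualize_instances(imask, bg_color=255, boundaries_color=(255, 0, 0))
```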

View File

@@ -1,53 +0,0 @@
import torch
from .fbrs.controller import InteractiveController
from .fbrs.inference import utils
class FBRSController:
def __init__(self, checkpoint_path, device='cuda:0', max_size=800):
model = utils.load_is_model(checkpoint_path, device, cpu_dist_maps=True, norm_radius=260)
# Predictor params
zoomin_params = {
'skip_clicks': 1,
'target_size': 480,
'expansion_ratio': 1.4,
}
predictor_params = {
'brs_mode': 'f-BRS-B',
'prob_thresh': 0.5,
'zoom_in_params': zoomin_params,
'predictor_params': {
'net_clicks_limit': 8,
'max_size': 800,
},
'brs_opt_func_params': {'min_iou_diff': 1e-3},
'lbfgs_params': {'maxfun': 20}
}
self.controller = InteractiveController(model, device, predictor_params)
self.anchored = False
self.device = device
def unanchor(self):
self.anchored = False
def interact(self, image, x, y, is_positive):
image = image.to(self.device, non_blocking=True)
if not self.anchored:
self.controller.set_image(image)
self.controller.reset_predictor()
self.anchored = True
self.controller.add_click(x, y, is_positive)
# return self.controller.result_mask
# return self.controller.probs_history[-1][1]
return (self.controller.probs_history[-1][1]>0.5).float()
def undo(self):
self.controller.undo_click()
if len(self.controller.probs_history) == 0:
return None
else:
return (self.controller.probs_history[-1][1]>0.5).float()
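
A hedged usage sketch of the controller (the checkpoint path is a placeholder and the input is a dummy tensor; in the GUI the image comes from `image_to_torch` and is unsqueezed to 1*3*H*W before the call):

```python
import torch

# Placeholder checkpoint path; a real f-BRS checkpoint is required.
fbrs = FBRSController('path/to/fbrs_checkpoint.pth', device='cuda:0')

image = torch.zeros(1, 3, 480, 854)     # dummy normalized frame, 1*3*H*W
mask = fbrs.interact(image, x=400, y=240, is_positive=True)  # thresholded probability map
mask = fbrs.undo()                      # returns None once all clicks are undone
fbrs.unanchor()                         # drop the anchored image before a new frame/object
```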

View File

@@ -1,933 +0,0 @@
"""
Based on https://github.com/hkchengrex/MiVOS/tree/MiVOS-STCN
(which is based on https://github.com/seoungwugoh/ivs-demo)
This version is much simplified.
In this repo, we don't have
- local control
- fusion module
- undo
- timers
but it uses XMem as the backbone and is more memory-friendly (for both CPU and GPU)
"""
import functools
import os
import cv2
# fix conflicts between qt5 and cv2
os.environ.pop("QT_QPA_PLATFORM_PLUGIN_PATH", None)
import numpy as np
import torch
from PyQt5.QtWidgets import (QWidget, QApplication, QComboBox, QCheckBox,
QHBoxLayout, QLabel, QPushButton, QTextEdit, QSpinBox, QFileDialog,
QPlainTextEdit, QVBoxLayout, QSizePolicy, QButtonGroup, QSlider, QShortcut, QRadioButton)
from PyQt5.QtGui import QPixmap, QKeySequence, QImage, QTextCursor, QIcon
from PyQt5.QtCore import Qt, QTimer
from model.network import XMem
from inference.inference_core import InferenceCore
from .s2m_controller import S2MController
from .fbrs_controller import FBRSController
from .interactive_utils import *
from .interaction import *
from .resource_manager import ResourceManager
from .gui_utils import *
class App(QWidget):
def __init__(self, net: XMem,
resource_manager: ResourceManager,
s2m_ctrl:S2MController,
fbrs_ctrl:FBRSController, config):
super().__init__()
self.initialized = False
self.num_objects = config['num_objects']
self.s2m_controller = s2m_ctrl
self.fbrs_controller = fbrs_ctrl
self.config = config
self.processor = InferenceCore(net, config)
self.processor.set_all_labels(list(range(1, self.num_objects+1)))
self.res_man = resource_manager
self.num_frames = len(self.res_man)
self.height, self.width = self.res_man.h, self.res_man.w
# set window
self.setWindowTitle('XMem Demo')
self.setGeometry(100, 100, self.width, self.height+100)
self.setWindowIcon(QIcon('docs/icon.png'))
# some buttons
self.play_button = QPushButton('Play Video')
self.play_button.clicked.connect(self.on_play_video)
self.commit_button = QPushButton('Commit')
self.commit_button.clicked.connect(self.on_commit)
self.forward_run_button = QPushButton('Forward Propagate')
self.forward_run_button.clicked.connect(self.on_forward_propagation)
self.forward_run_button.setMinimumWidth(200)
self.backward_run_button = QPushButton('Backward Propagate')
self.backward_run_button.clicked.connect(self.on_backward_propagation)
self.backward_run_button.setMinimumWidth(200)
self.reset_button = QPushButton('Reset Frame')
self.reset_button.clicked.connect(self.on_reset_mask)
# LCD
self.lcd = QTextEdit()
self.lcd.setReadOnly(True)
self.lcd.setMaximumHeight(28)
self.lcd.setMaximumWidth(120)
self.lcd.setText('{: 4d} / {: 4d}'.format(0, self.num_frames-1))
# timeline slider
self.tl_slider = QSlider(Qt.Horizontal)
self.tl_slider.valueChanged.connect(self.tl_slide)
self.tl_slider.setMinimum(0)
self.tl_slider.setMaximum(self.num_frames-1)
self.tl_slider.setValue(0)
self.tl_slider.setTickPosition(QSlider.TicksBelow)
self.tl_slider.setTickInterval(1)
# brush size slider
self.brush_label = QLabel()
self.brush_label.setAlignment(Qt.AlignCenter)
self.brush_label.setMinimumWidth(100)
self.brush_slider = QSlider(Qt.Horizontal)
self.brush_slider.valueChanged.connect(self.brush_slide)
self.brush_slider.setMinimum(1)
self.brush_slider.setMaximum(100)
self.brush_slider.setValue(3)
self.brush_slider.setTickPosition(QSlider.TicksBelow)
self.brush_slider.setTickInterval(2)
self.brush_slider.setMinimumWidth(300)
# combobox
self.combo = QComboBox(self)
self.combo.addItem("davis")
self.combo.addItem("fade")
self.combo.addItem("light")
self.combo.addItem("popup")
self.combo.addItem("layered")
self.combo.currentTextChanged.connect(self.set_viz_mode)
self.save_visualization_checkbox = QCheckBox(self)
self.save_visualization_checkbox.toggled.connect(self.on_save_visualization_toggle)
self.save_visualization_checkbox.setChecked(False)
self.save_visualization = False
# Radio buttons for type of interactions
self.curr_interaction = 'Click'
self.interaction_group = QButtonGroup()
self.radio_fbrs = QRadioButton('Click')
self.radio_s2m = QRadioButton('Scribble')
self.radio_free = QRadioButton('Free')
self.interaction_group.addButton(self.radio_fbrs)
self.interaction_group.addButton(self.radio_s2m)
self.interaction_group.addButton(self.radio_free)
self.radio_fbrs.toggled.connect(self.interaction_radio_clicked)
self.radio_s2m.toggled.connect(self.interaction_radio_clicked)
self.radio_free.toggled.connect(self.interaction_radio_clicked)
self.radio_fbrs.toggle()
# Main canvas -> QLabel
self.main_canvas = QLabel()
self.main_canvas.setSizePolicy(QSizePolicy.Expanding,
QSizePolicy.Expanding)
self.main_canvas.setAlignment(Qt.AlignCenter)
self.main_canvas.setMinimumSize(100, 100)
self.main_canvas.mousePressEvent = self.on_mouse_press
self.main_canvas.mouseMoveEvent = self.on_mouse_motion
self.main_canvas.setMouseTracking(True) # Required for all-time tracking
self.main_canvas.mouseReleaseEvent = self.on_mouse_release
# Minimap -> also a QLabel
self.minimap = QLabel()
self.minimap.setSizePolicy(QSizePolicy.Expanding,
QSizePolicy.Expanding)
self.minimap.setAlignment(Qt.AlignTop)
self.minimap.setMinimumSize(100, 100)
# Zoom-in buttons
self.zoom_p_button = QPushButton('Zoom +')
self.zoom_p_button.clicked.connect(self.on_zoom_plus)
self.zoom_m_button = QPushButton('Zoom -')
self.zoom_m_button.clicked.connect(self.on_zoom_minus)
# Parameters setting
self.clear_mem_button = QPushButton('Clear memory')
self.clear_mem_button.clicked.connect(self.on_clear_memory)
self.work_mem_gauge, self.work_mem_gauge_layout = create_gauge('Working memory size')
self.long_mem_gauge, self.long_mem_gauge_layout = create_gauge('Long-term memory size')
self.gpu_mem_gauge, self.gpu_mem_gauge_layout = create_gauge('GPU mem. (all processes, w/ caching)')
self.torch_mem_gauge, self.torch_mem_gauge_layout = create_gauge('GPU mem. (used by torch, w/o caching)')
self.update_memory_size()
self.update_gpu_usage()
self.work_mem_min, self.work_mem_min_layout = create_parameter_box(1, 100, 'Min. working memory frames',
callback=self.on_work_min_change)
self.work_mem_max, self.work_mem_max_layout = create_parameter_box(2, 100, 'Max. working memory frames',
callback=self.on_work_max_change)
self.long_mem_max, self.long_mem_max_layout = create_parameter_box(1000, 100000,
'Max. long-term memory size', step=1000, callback=self.update_config)
self.num_prototypes_box, self.num_prototypes_box_layout = create_parameter_box(32, 1280,
'Number of prototypes', step=32, callback=self.update_config)
self.mem_every_box, self.mem_every_box_layout = create_parameter_box(1, 100, 'Memory frame every (r)',
callback=self.update_config)
self.work_mem_min.setValue(self.processor.memory.min_mt_frames)
self.work_mem_max.setValue(self.processor.memory.max_mt_frames)
self.long_mem_max.setValue(self.processor.memory.max_long_elements)
self.num_prototypes_box.setValue(self.processor.memory.num_prototypes)
self.mem_every_box.setValue(self.processor.mem_every)
# import mask/layer
self.import_mask_button = QPushButton('Import mask')
self.import_mask_button.clicked.connect(self.on_import_mask)
self.import_layer_button = QPushButton('Import layer')
self.import_layer_button.clicked.connect(self.on_import_layer)
# Console on the GUI
self.console = QPlainTextEdit()
self.console.setReadOnly(True)
self.console.setMinimumHeight(100)
self.console.setMaximumHeight(100)
# navigator
navi = QHBoxLayout()
navi.addWidget(self.lcd)
navi.addWidget(self.play_button)
interact_subbox = QVBoxLayout()
interact_topbox = QHBoxLayout()
interact_botbox = QHBoxLayout()
interact_topbox.setAlignment(Qt.AlignCenter)
interact_topbox.addWidget(self.radio_s2m)
interact_topbox.addWidget(self.radio_fbrs)
interact_topbox.addWidget(self.radio_free)
interact_topbox.addWidget(self.brush_label)
interact_botbox.addWidget(self.brush_slider)
interact_subbox.addLayout(interact_topbox)
interact_subbox.addLayout(interact_botbox)
navi.addLayout(interact_subbox)
navi.addStretch(1)
navi.addWidget(self.reset_button)
navi.addStretch(1)
navi.addWidget(QLabel('Overlay Mode'))
navi.addWidget(self.combo)
navi.addWidget(QLabel('Save overlay during propagation'))
navi.addWidget(self.save_visualization_checkbox)
navi.addStretch(1)
navi.addWidget(self.commit_button)
navi.addWidget(self.forward_run_button)
navi.addWidget(self.backward_run_button)
# Drawing area, main canvas and minimap
draw_area = QHBoxLayout()
draw_area.addWidget(self.main_canvas, 4)
# Minimap area
minimap_area = QVBoxLayout()
minimap_area.setAlignment(Qt.AlignTop)
mini_label = QLabel('Minimap')
mini_label.setAlignment(Qt.AlignTop)
minimap_area.addWidget(mini_label)
# Minimap zooming
minimap_ctrl = QHBoxLayout()
minimap_ctrl.setAlignment(Qt.AlignTop)
minimap_ctrl.addWidget(self.zoom_p_button)
minimap_ctrl.addWidget(self.zoom_m_button)
minimap_area.addLayout(minimap_ctrl)
minimap_area.addWidget(self.minimap)
# Parameters
minimap_area.addLayout(self.work_mem_gauge_layout)
minimap_area.addLayout(self.long_mem_gauge_layout)
minimap_area.addLayout(self.gpu_mem_gauge_layout)
minimap_area.addLayout(self.torch_mem_gauge_layout)
minimap_area.addWidget(self.clear_mem_button)
minimap_area.addLayout(self.work_mem_min_layout)
minimap_area.addLayout(self.work_mem_max_layout)
minimap_area.addLayout(self.long_mem_max_layout)
minimap_area.addLayout(self.num_prototypes_box_layout)
minimap_area.addLayout(self.mem_every_box_layout)
# import mask/layer
import_area = QHBoxLayout()
import_area.setAlignment(Qt.AlignTop)
import_area.addWidget(self.import_mask_button)
import_area.addWidget(self.import_layer_button)
minimap_area.addLayout(import_area)
# console
minimap_area.addWidget(self.console)
draw_area.addLayout(minimap_area, 1)
layout = QVBoxLayout()
layout.addLayout(draw_area)
layout.addWidget(self.tl_slider)
layout.addLayout(navi)
self.setLayout(layout)
# timer to play video
self.timer = QTimer()
self.timer.setSingleShot(False)
# timer to update GPU usage
self.gpu_timer = QTimer()
self.gpu_timer.setSingleShot(False)
self.gpu_timer.timeout.connect(self.on_gpu_timer)
self.gpu_timer.setInterval(2000)
self.gpu_timer.start()
# current frame info
self.curr_frame_dirty = False
self.current_image = np.zeros((self.height, self.width, 3), dtype=np.uint8)
self.current_image_torch = None
self.current_mask = np.zeros((self.height, self.width), dtype=np.uint8)
self.current_prob = torch.zeros((self.num_objects, self.height, self.width), dtype=torch.float).cuda()
# initialize visualization
self.viz_mode = 'davis'
self.vis_map = np.zeros((self.height, self.width, 3), dtype=np.uint8)
self.vis_alpha = np.zeros((self.height, self.width, 1), dtype=np.float32)
self.brush_vis_map = np.zeros((self.height, self.width, 3), dtype=np.uint8)
self.brush_vis_alpha = np.zeros((self.height, self.width, 1), dtype=np.float32)
self.cursur = 0
self.on_showing = None
# Zoom parameters
self.zoom_pixels = 150
# initialize action
self.interaction = None
self.pressed = False
self.right_click = False
self.current_object = 1
self.last_ex = self.last_ey = 0
self.propagating = False
# Objects shortcuts
for i in range(1, self.num_objects+1):
QShortcut(QKeySequence(str(i)), self).activated.connect(functools.partial(self.hit_number_key, i))
# <- and -> shortcuts
QShortcut(QKeySequence(Qt.Key_Left), self).activated.connect(self.on_prev_frame)
QShortcut(QKeySequence(Qt.Key_Right), self).activated.connect(self.on_next_frame)
self.interacted_prob = None
self.overlay_layer = None
self.overlay_layer_torch = None
# the object id used for popup/layered overlay
self.vis_target_objects = [1]
# try to load the default overlay
self._try_load_layer('./docs/ECCV-logo.png')
self.load_current_image_mask()
self.show_current_frame()
self.show()
self.console_push_text('Initialized.')
self.initialized = True
def resizeEvent(self, event):
self.show_current_frame()
def console_push_text(self, text):
self.console.moveCursor(QTextCursor.End)
self.console.insertPlainText(text+'\n')
def interaction_radio_clicked(self, event):
self.last_interaction = self.curr_interaction
if self.radio_s2m.isChecked():
self.curr_interaction = 'Scribble'
self.brush_size = 3
self.brush_slider.setDisabled(True)
elif self.radio_fbrs.isChecked():
self.curr_interaction = 'Click'
self.brush_size = 3
self.brush_slider.setDisabled(True)
elif self.radio_free.isChecked():
self.brush_slider.setDisabled(False)
self.brush_slide()
self.curr_interaction = 'Free'
if self.curr_interaction == 'Scribble':
self.commit_button.setEnabled(True)
else:
self.commit_button.setEnabled(False)
def load_current_image_mask(self, no_mask=False):
self.current_image = self.res_man.get_image(self.cursur)
self.current_image_torch = None
if not no_mask:
loaded_mask = self.res_man.get_mask(self.cursur)
if loaded_mask is None:
self.current_mask.fill(0)
else:
self.current_mask = loaded_mask.copy()
self.current_prob = None
def load_current_torch_image_mask(self, no_mask=False):
if self.current_image_torch is None:
self.current_image_torch, self.current_image_torch_no_norm = image_to_torch(self.current_image)
if self.current_prob is None and not no_mask:
self.current_prob = index_numpy_to_one_hot_torch(self.current_mask, self.num_objects+1).cuda()
def compose_current_im(self):
self.viz = get_visualization(self.viz_mode, self.current_image, self.current_mask,
self.overlay_layer, self.vis_target_objects)
def update_interact_vis(self):
# Update the interactions without re-computing the overlay
height, width, channel = self.viz.shape
bytesPerLine = 3 * width
vis_map = self.vis_map
vis_alpha = self.vis_alpha
brush_vis_map = self.brush_vis_map
brush_vis_alpha = self.brush_vis_alpha
self.viz_with_stroke = self.viz*(1-vis_alpha) + vis_map*vis_alpha
self.viz_with_stroke = self.viz_with_stroke*(1-brush_vis_alpha) + brush_vis_map*brush_vis_alpha
self.viz_with_stroke = self.viz_with_stroke.astype(np.uint8)
qImg = QImage(self.viz_with_stroke.data, width, height, bytesPerLine, QImage.Format_RGB888)
self.main_canvas.setPixmap(QPixmap(qImg.scaled(self.main_canvas.size(),
Qt.KeepAspectRatio, Qt.FastTransformation)))
self.main_canvas_size = self.main_canvas.size()
self.image_size = qImg.size()
def update_minimap(self):
ex, ey = self.last_ex, self.last_ey
r = self.zoom_pixels//2
ex = int(round(max(r, min(self.width-r, ex))))
ey = int(round(max(r, min(self.height-r, ey))))
patch = self.viz_with_stroke[ey-r:ey+r, ex-r:ex+r, :].astype(np.uint8)
height, width, channel = patch.shape
bytesPerLine = 3 * width
qImg = QImage(patch.data, width, height, bytesPerLine, QImage.Format_RGB888)
self.minimap.setPixmap(QPixmap(qImg.scaled(self.minimap.size(),
Qt.KeepAspectRatio, Qt.FastTransformation)))
def update_current_image_fast(self):
# fast path, uses gpu. Changes the image in-place to avoid copying
self.viz = get_visualization_torch(self.viz_mode, self.current_image_torch_no_norm,
self.current_prob, self.overlay_layer_torch, self.vis_target_objects)
if self.save_visualization:
self.res_man.save_visualization(self.cursur, self.viz)
height, width, channel = self.viz.shape
bytesPerLine = 3 * width
qImg = QImage(self.viz.data, width, height, bytesPerLine, QImage.Format_RGB888)
self.main_canvas.setPixmap(QPixmap(qImg.scaled(self.main_canvas.size(),
Qt.KeepAspectRatio, Qt.FastTransformation)))
def show_current_frame(self, fast=False):
# Re-compute overlay and show the image
if fast:
self.update_current_image_fast()
else:
self.compose_current_im()
self.update_interact_vis()
self.update_minimap()
self.lcd.setText('{: 3d} / {: 3d}'.format(self.cursur, self.num_frames-1))
self.tl_slider.setValue(self.cursur)
def pixel_pos_to_image_pos(self, x, y):
# Un-scale and un-pad the label coordinates into image coordinates
oh, ow = self.image_size.height(), self.image_size.width()
nh, nw = self.main_canvas_size.height(), self.main_canvas_size.width()
h_ratio = nh/oh
w_ratio = nw/ow
dominate_ratio = min(h_ratio, w_ratio)
# Solve scale
x /= dominate_ratio
y /= dominate_ratio
# Solve padding
fh, fw = nh/dominate_ratio, nw/dominate_ratio
x -= (fw-ow)/2
y -= (fh-oh)/2
return x, y
def is_pos_out_of_bound(self, x, y):
x, y = self.pixel_pos_to_image_pos(x, y)
out_of_bound = (
(x < 0) or
(y < 0) or
(x > self.width-1) or
(y > self.height-1)
)
return out_of_bound
def get_scaled_pos(self, x, y):
x, y = self.pixel_pos_to_image_pos(x, y)
x = max(0, min(self.width-1, x))
y = max(0, min(self.height-1, y))
return x, y
def clear_visualization(self):
self.vis_map.fill(0)
self.vis_alpha.fill(0)
def reset_this_interaction(self):
self.complete_interaction()
self.clear_visualization()
self.interaction = None
if self.fbrs_controller is not None:
self.fbrs_controller.unanchor()
def set_viz_mode(self):
self.viz_mode = self.combo.currentText()
self.show_current_frame()
def save_current_mask(self):
# save mask to hard disk
self.res_man.save_mask(self.cursur, self.current_mask)
def tl_slide(self):
# if we are propagating, the on_run function will take care of everything
# don't do duplicate work here
if not self.propagating:
if self.curr_frame_dirty:
self.save_current_mask()
self.curr_frame_dirty = False
self.reset_this_interaction()
self.cursur = self.tl_slider.value()
self.load_current_image_mask()
self.show_current_frame()
def brush_slide(self):
self.brush_size = self.brush_slider.value()
self.brush_label.setText('Brush size: %d' % self.brush_size)
try:
if type(self.interaction) == FreeInteraction:
self.interaction.set_size(self.brush_size)
except AttributeError:
# Initialization, forget about it
pass
def on_forward_propagation(self):
if self.propagating:
# acts as a pause button
self.propagating = False
else:
self.propagate_fn = self.on_next_frame
self.backward_run_button.setEnabled(False)
self.forward_run_button.setText('Pause Propagation')
self.on_propagation()
def on_backward_propagation(self):
if self.propagating:
# acts as a pause button
self.propagating = False
else:
self.propagate_fn = self.on_prev_frame
self.forward_run_button.setEnabled(False)
self.backward_run_button.setText('Pause Propagation')
self.on_propagation()
def on_pause(self):
self.propagating = False
self.forward_run_button.setEnabled(True)
self.backward_run_button.setEnabled(True)
self.clear_mem_button.setEnabled(True)
self.forward_run_button.setText('Forward Propagate')
self.backward_run_button.setText('Backward Propagate')
self.console_push_text('Propagation stopped.')
def on_propagation(self):
# start to propagate
self.load_current_torch_image_mask()
self.show_current_frame(fast=True)
self.console_push_text('Propagation started.')
self.current_prob = self.processor.step(self.current_image_torch, self.current_prob[1:])
self.current_mask = torch_prob_to_numpy_mask(self.current_prob)
# clear
self.interacted_prob = None
self.reset_this_interaction()
self.propagating = True
self.clear_mem_button.setEnabled(False)
# propagate till the end
while self.propagating:
self.propagate_fn()
self.load_current_image_mask(no_mask=True)
self.load_current_torch_image_mask(no_mask=True)
self.current_prob = self.processor.step(self.current_image_torch)
self.current_mask = torch_prob_to_numpy_mask(self.current_prob)
self.save_current_mask()
self.show_current_frame(fast=True)
self.update_memory_size()
QApplication.processEvents()
if self.cursur == 0 or self.cursur == self.num_frames-1:
break
self.propagating = False
self.curr_frame_dirty = False
self.on_pause()
self.tl_slide()
QApplication.processEvents()
def pause_propagation(self):
self.propagating = False
def on_commit(self):
self.complete_interaction()
self.update_interacted_mask()
def on_prev_frame(self):
# self.tl_slide will trigger on setValue
self.cursur = max(0, self.cursur-1)
self.tl_slider.setValue(self.cursur)
def on_next_frame(self):
# self.tl_slide will trigger on setValue
self.cursur = min(self.cursur+1, self.num_frames-1)
self.tl_slider.setValue(self.cursur)
def on_play_video_timer(self):
self.cursur += 1
if self.cursur > self.num_frames-1:
self.cursur = 0
self.tl_slider.setValue(self.cursur)
def on_play_video(self):
if self.timer.isActive():
self.timer.stop()
self.play_button.setText('Play Video')
else:
self.timer.start(1000 // 30)  # ~30 fps; QTimer.start expects an integer interval in ms
self.play_button.setText('Stop Video')
def on_reset_mask(self):
self.current_mask.fill(0)
if self.current_prob is not None:
self.current_prob.fill_(0)
self.curr_frame_dirty = True
self.save_current_mask()
self.reset_this_interaction()
self.show_current_frame()
def on_zoom_plus(self):
self.zoom_pixels -= 25
self.zoom_pixels = max(50, self.zoom_pixels)
self.update_minimap()
def on_zoom_minus(self):
self.zoom_pixels += 25
self.zoom_pixels = min(self.zoom_pixels, 300)
self.update_minimap()
def set_navi_enable(self, boolean):
self.zoom_p_button.setEnabled(boolean)
self.zoom_m_button.setEnabled(boolean)
self.forward_run_button.setEnabled(boolean)
self.backward_run_button.setEnabled(boolean)
self.tl_slider.setEnabled(boolean)
self.play_button.setEnabled(boolean)
self.lcd.setEnabled(boolean)
def hit_number_key(self, number):
if number == self.current_object:
return
self.current_object = number
if self.fbrs_controller is not None:
self.fbrs_controller.unanchor()
self.console_push_text(f'Current object changed to {number}.')
self.clear_brush()
self.vis_brush(self.last_ex, self.last_ey)
self.update_interact_vis()
self.show_current_frame()
def clear_brush(self):
self.brush_vis_map.fill(0)
self.brush_vis_alpha.fill(0)
def vis_brush(self, ex, ey):
self.brush_vis_map = cv2.circle(self.brush_vis_map,
(int(round(ex)), int(round(ey))), self.brush_size//2+1, color_map[self.current_object], thickness=-1)
self.brush_vis_alpha = cv2.circle(self.brush_vis_alpha,
(int(round(ex)), int(round(ey))), self.brush_size//2+1, 0.5, thickness=-1)
def on_mouse_press(self, event):
if self.is_pos_out_of_bound(event.x(), event.y()):
return
# mid-click
if (event.button() == Qt.MidButton):
ex, ey = self.get_scaled_pos(event.x(), event.y())
target_object = self.current_mask[int(ey),int(ex)]
if target_object in self.vis_target_objects:
self.vis_target_objects.remove(target_object)
else:
self.vis_target_objects.append(target_object)
self.console_push_text(f'Target objects for visualization changed to {self.vis_target_objects}')
self.show_current_frame()
return
self.right_click = (event.button() == Qt.RightButton)
self.pressed = True
h, w = self.height, self.width
self.load_current_torch_image_mask()
image = self.current_image_torch
last_interaction = self.interaction
new_interaction = None
if self.curr_interaction == 'Scribble':
if last_interaction is None or type(last_interaction) != ScribbleInteraction:
self.complete_interaction()
new_interaction = ScribbleInteraction(image, torch.from_numpy(self.current_mask).float().cuda(),
(h, w), self.s2m_controller, self.num_objects)
elif self.curr_interaction == 'Free':
if last_interaction is None or type(last_interaction) != FreeInteraction:
self.complete_interaction()
new_interaction = FreeInteraction(image, self.current_mask, (h, w),
self.num_objects)
new_interaction.set_size(self.brush_size)
elif self.curr_interaction == 'Click':
if (last_interaction is None or type(last_interaction) != ClickInteraction
or last_interaction.tar_obj != self.current_object):
self.complete_interaction()
self.fbrs_controller.unanchor()
new_interaction = ClickInteraction(image, self.current_prob, (h, w),
self.fbrs_controller, self.current_object)
if new_interaction is not None:
self.interaction = new_interaction
# Just motion it as the first step
self.on_mouse_motion(event)
def on_mouse_motion(self, event):
ex, ey = self.get_scaled_pos(event.x(), event.y())
self.last_ex, self.last_ey = ex, ey
self.clear_brush()
# Visualize
self.vis_brush(ex, ey)
if self.pressed:
if self.curr_interaction == 'Scribble' or self.curr_interaction == 'Free':
obj = 0 if self.right_click else self.current_object
self.vis_map, self.vis_alpha = self.interaction.push_point(
ex, ey, obj, (self.vis_map, self.vis_alpha)
)
self.update_interact_vis()
self.update_minimap()
def update_interacted_mask(self):
self.current_prob = self.interacted_prob
self.current_mask = torch_prob_to_numpy_mask(self.interacted_prob)
self.show_current_frame()
self.save_current_mask()
self.curr_frame_dirty = False
def complete_interaction(self):
if self.interaction is not None:
self.clear_visualization()
self.interaction = None
def on_mouse_release(self, event):
if not self.pressed:
# this can happen when the initial press is out-of-bound
return
ex, ey = self.get_scaled_pos(event.x(), event.y())
self.console_push_text('%s interaction at frame %d.' % (self.curr_interaction, self.cursur))
interaction = self.interaction
if self.curr_interaction == 'Scribble' or self.curr_interaction == 'Free':
self.on_mouse_motion(event)
interaction.end_path()
if self.curr_interaction == 'Free':
self.clear_visualization()
elif self.curr_interaction == 'Click':
ex, ey = self.get_scaled_pos(event.x(), event.y())
self.vis_map, self.vis_alpha = interaction.push_point(ex, ey,
self.right_click, (self.vis_map, self.vis_alpha))
self.interacted_prob = interaction.predict()
self.update_interacted_mask()
self.update_gpu_usage()
self.pressed = self.right_click = False
def wheelEvent(self, event):
ex, ey = self.get_scaled_pos(event.x(), event.y())
if self.curr_interaction == 'Free':
self.brush_slider.setValue(self.brush_slider.value() + event.angleDelta().y()//30)
self.clear_brush()
self.vis_brush(ex, ey)
self.update_interact_vis()
self.update_minimap()
def update_gpu_usage(self):
info = torch.cuda.mem_get_info()
global_free, global_total = info
global_free /= (2**30)
global_total /= (2**30)
global_used = global_total - global_free
self.gpu_mem_gauge.setFormat(f'{global_used:.01f} GB / {global_total:.01f} GB')
self.gpu_mem_gauge.setValue(round(global_used/global_total*100))
used_by_torch = torch.cuda.max_memory_allocated() / (2**20)
self.torch_mem_gauge.setFormat(f'{used_by_torch:.0f} MB / {global_total:.01f} GB')
self.torch_mem_gauge.setValue(round(used_by_torch/global_total*100/1024))
def on_gpu_timer(self):
self.update_gpu_usage()
def update_memory_size(self):
try:
max_work_elements = self.processor.memory.max_work_elements
max_long_elements = self.processor.memory.max_long_elements
curr_work_elements = self.processor.memory.work_mem.size
curr_long_elements = self.processor.memory.long_mem.size
self.work_mem_gauge.setFormat(f'{curr_work_elements} / {max_work_elements}')
self.work_mem_gauge.setValue(round(curr_work_elements/max_work_elements*100))
self.long_mem_gauge.setFormat(f'{curr_long_elements} / {max_long_elements}')
self.long_mem_gauge.setValue(round(curr_long_elements/max_long_elements*100))
except AttributeError:
self.work_mem_gauge.setFormat('Unknown')
self.long_mem_gauge.setFormat('Unknown')
self.work_mem_gauge.setValue(0)
self.long_mem_gauge.setValue(0)
def on_work_min_change(self):
if self.initialized:
self.work_mem_min.setValue(min(self.work_mem_min.value(), self.work_mem_max.value()-1))
self.update_config()
def on_work_max_change(self):
if self.initialized:
self.work_mem_max.setValue(max(self.work_mem_max.value(), self.work_mem_min.value()+1))
self.update_config()
def update_config(self):
if self.initialized:
self.config['min_mid_term_frames'] = self.work_mem_min.value()
self.config['max_mid_term_frames'] = self.work_mem_max.value()
self.config['max_long_term_elements'] = self.long_mem_max.value()
self.config['num_prototypes'] = self.num_prototypes_box.value()
self.config['mem_every'] = self.mem_every_box.value()
self.processor.update_config(self.config)
def on_clear_memory(self):
self.processor.clear_memory()
torch.cuda.empty_cache()
self.update_gpu_usage()
self.update_memory_size()
def _open_file(self, prompt):
options = QFileDialog.Options()
file_name, _ = QFileDialog.getOpenFileName(self, prompt, "", "Image files (*)", options=options)
return file_name
def on_import_mask(self):
file_name = self._open_file('Mask')
if len(file_name) == 0:
return
mask = self.res_man.read_external_image(file_name, size=(self.height, self.width))
shape_condition = (
(len(mask.shape) == 2) and
(mask.shape[-1] == self.width) and
(mask.shape[-2] == self.height)
)
object_condition = (
mask.max() <= self.num_objects
)
if not shape_condition:
self.console_push_text(f'Expected ({self.height}, {self.width}). Got {mask.shape} instead.')
elif not object_condition:
self.console_push_text(f'Expected {self.num_objects} objects. Got {mask.max()} objects instead.')
else:
self.console_push_text(f'Mask file {file_name} loaded.')
self.current_image_torch = self.current_prob = None
self.current_mask = mask
self.show_current_frame()
self.save_current_mask()
def on_import_layer(self):
file_name = self._open_file('Layer')
if len(file_name) == 0:
return
self._try_load_layer(file_name)
def _try_load_layer(self, file_name):
try:
layer = self.res_man.read_external_image(file_name, size=(self.height, self.width))
if layer.shape[-1] == 3:
layer = np.concatenate([layer, np.ones_like(layer[:,:,0:1])*255], axis=-1)
condition = (
(len(layer.shape) == 3) and
(layer.shape[-1] == 4) and
(layer.shape[-2] == self.width) and
(layer.shape[-3] == self.height)
)
if not condition:
self.console_push_text(f'Expected ({self.height}, {self.width}, 4). Got {layer.shape}.')
else:
self.console_push_text(f'Layer file {file_name} loaded.')
self.overlay_layer = layer
self.overlay_layer_torch = torch.from_numpy(layer).float().cuda()/255
self.show_current_frame()
except FileNotFoundError:
self.console_push_text(f'{file_name} not found.')
def on_save_visualization_toggle(self):
self.save_visualization = self.save_visualization_checkbox.isChecked()

View File

@@ -1,40 +0,0 @@
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import (QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar)
def create_parameter_box(min_val, max_val, text, step=1, callback=None):
layout = QHBoxLayout()
dial = QSpinBox()
dial.setMaximumHeight(28)
dial.setMaximumWidth(150)
dial.setMinimum(min_val)
dial.setMaximum(max_val)
dial.setAlignment(Qt.AlignRight)
dial.setSingleStep(step)
dial.valueChanged.connect(callback)
label = QLabel(text)
label.setAlignment(Qt.AlignRight)
layout.addWidget(label)
layout.addWidget(dial)
return dial, layout
def create_gauge(text):
layout = QHBoxLayout()
gauge = QProgressBar()
gauge.setMaximumHeight(28)
gauge.setMaximumWidth(200)
gauge.setAlignment(Qt.AlignCenter)
label = QLabel(text)
label.setAlignment(Qt.AlignRight)
layout.addWidget(label)
layout.addWidget(gauge)
return gauge, layout

View File

@@ -1,252 +0,0 @@
"""
Contains all the types of interaction related to the GUI
Not related to automatic evaluation in the DAVIS dataset
You can inherit the Interaction class to create new interaction types
undo is (sometimes partially) supported
"""
import torch
import torch.nn.functional as F
import numpy as np
import cv2
import time
from .interactive_utils import color_map, index_numpy_to_one_hot_torch
def aggregate_sbg(prob, keep_bg=False, hard=False):
device = prob.device
k, h, w = prob.shape
ex_prob = torch.zeros((k+1, h, w), device=device)
ex_prob[0] = 0.5
ex_prob[1:] = prob
ex_prob = torch.clamp(ex_prob, 1e-7, 1-1e-7)
logits = torch.log((ex_prob /(1-ex_prob)))
if hard:
# Very low temperature o((⊙﹏⊙))o 🥶
logits *= 1000
if keep_bg:
return F.softmax(logits, dim=0)
else:
return F.softmax(logits, dim=0)[1:]
def aggregate_wbg(prob, keep_bg=False, hard=False):
k, h, w = prob.shape
new_prob = torch.cat([
torch.prod(1-prob, dim=0, keepdim=True),
prob
], 0).clamp(1e-7, 1-1e-7)
logits = torch.log((new_prob /(1-new_prob)))
if hard:
# Very low temperature o((⊙﹏⊙))o 🥶
logits *= 1000
if keep_bg:
return F.softmax(logits, dim=0)
else:
return F.softmax(logits, dim=0)[1:]
class Interaction:
def __init__(self, image, prev_mask, true_size, controller):
self.image = image
self.prev_mask = prev_mask
self.controller = controller
self.start_time = time.time()
self.h, self.w = true_size
self.out_prob = None
self.out_mask = None
def predict(self):
pass
class FreeInteraction(Interaction):
def __init__(self, image, prev_mask, true_size, num_objects):
"""
prev_mask should be index format numpy array
"""
super().__init__(image, prev_mask, true_size, None)
self.K = num_objects
self.drawn_map = self.prev_mask.copy()
self.curr_path = [[] for _ in range(self.K + 1)]
self.size = None
def set_size(self, size):
self.size = size
"""
k - object id
vis - a tuple (visualization map, pass through alpha). None if not needed.
"""
def push_point(self, x, y, k, vis=None):
if vis is not None:
vis_map, vis_alpha = vis
selected = self.curr_path[k]
selected.append((x, y))
if len(selected) >= 2:
cv2.line(self.drawn_map,
(int(round(selected[-2][0])), int(round(selected[-2][1]))),
(int(round(selected[-1][0])), int(round(selected[-1][1]))),
k, thickness=self.size)
# Plot visualization
if vis is not None:
# Visualization for drawing
if k == 0:
vis_map = cv2.line(vis_map,
(int(round(selected[-2][0])), int(round(selected[-2][1]))),
(int(round(selected[-1][0])), int(round(selected[-1][1]))),
color_map[k], thickness=self.size)
else:
vis_map = cv2.line(vis_map,
(int(round(selected[-2][0])), int(round(selected[-2][1]))),
(int(round(selected[-1][0])), int(round(selected[-1][1]))),
color_map[k], thickness=self.size)
# Visualization on/off boolean filter
vis_alpha = cv2.line(vis_alpha,
(int(round(selected[-2][0])), int(round(selected[-2][1]))),
(int(round(selected[-1][0])), int(round(selected[-1][1]))),
0.75, thickness=self.size)
if vis is not None:
return vis_map, vis_alpha
def end_path(self):
# Complete the drawing
self.curr_path = [[] for _ in range(self.K + 1)]
def predict(self):
self.out_prob = index_numpy_to_one_hot_torch(self.drawn_map, self.K+1).cuda()
# self.out_prob = torch.from_numpy(self.drawn_map).float().cuda()
# self.out_prob, _ = pad_divide_by(self.out_prob, 16, self.out_prob.shape[-2:])
# self.out_prob = aggregate_sbg(self.out_prob, keep_bg=True)
return self.out_prob
class ScribbleInteraction(Interaction):
def __init__(self, image, prev_mask, true_size, controller, num_objects):
"""
prev_mask should be in an indexed form
"""
super().__init__(image, prev_mask, true_size, controller)
self.K = num_objects
self.drawn_map = np.empty((self.h, self.w), dtype=np.uint8)
self.drawn_map.fill(255)
# background + k
self.curr_path = [[] for _ in range(self.K + 1)]
self.size = 3
"""
k - object id
vis - a tuple (visualization map, pass through alpha). None if not needed.
"""
def push_point(self, x, y, k, vis=None):
if vis is not None:
vis_map, vis_alpha = vis
selected = self.curr_path[k]
selected.append((x, y))
if len(selected) >= 2:
self.drawn_map = cv2.line(self.drawn_map,
(int(round(selected[-2][0])), int(round(selected[-2][1]))),
(int(round(selected[-1][0])), int(round(selected[-1][1]))),
k, thickness=self.size)
# Plot visualization
if vis is not None:
# Visualization for drawing
if k == 0:
vis_map = cv2.line(vis_map,
(int(round(selected[-2][0])), int(round(selected[-2][1]))),
(int(round(selected[-1][0])), int(round(selected[-1][1]))),
color_map[k], thickness=self.size)
else:
vis_map = cv2.line(vis_map,
(int(round(selected[-2][0])), int(round(selected[-2][1]))),
(int(round(selected[-1][0])), int(round(selected[-1][1]))),
color_map[k], thickness=self.size)
# Visualization on/off boolean filter
vis_alpha = cv2.line(vis_alpha,
(int(round(selected[-2][0])), int(round(selected[-2][1]))),
(int(round(selected[-1][0])), int(round(selected[-1][1]))),
0.75, thickness=self.size)
# Optional vis return
if vis is not None:
return vis_map, vis_alpha
def end_path(self):
# Complete the drawing
self.curr_path = [[] for _ in range(self.K + 1)]
def predict(self):
self.out_prob = self.controller.interact(self.image.unsqueeze(0), self.prev_mask, self.drawn_map)
self.out_prob = aggregate_wbg(self.out_prob, keep_bg=True, hard=True)
return self.out_prob
class ClickInteraction(Interaction):
def __init__(self, image, prev_mask, true_size, controller, tar_obj):
"""
prev_mask in a prob. form
"""
super().__init__(image, prev_mask, true_size, controller)
self.tar_obj = tar_obj
# positive/negative clicks for the target object
self.pos_clicks = []
self.neg_clicks = []
self.out_prob = self.prev_mask.clone()
"""
neg - Negative interaction or not
vis - a tuple (visualization map, pass through alpha). None if not needed.
"""
def push_point(self, x, y, neg, vis=None):
# Clicks
if neg:
self.neg_clicks.append((x, y))
else:
self.pos_clicks.append((x, y))
# Do the prediction
self.obj_mask = self.controller.interact(self.image.unsqueeze(0), x, y, not neg)
# Plot visualization
if vis is not None:
vis_map, vis_alpha = vis
# Visualization for clicks
if neg:
vis_map = cv2.circle(vis_map,
(int(round(x)), int(round(y))),
2, color_map[0], thickness=-1)
else:
vis_map = cv2.circle(vis_map,
(int(round(x)), int(round(y))),
2, color_map[self.tar_obj], thickness=-1)
vis_alpha = cv2.circle(vis_alpha,
(int(round(x)), int(round(y))),
2, 1, thickness=-1)
# Optional vis return
return vis_map, vis_alpha
def predict(self):
self.out_prob = self.prev_mask.clone()
# a small hack to allow the interacting object to overwrite existing masks
# without remembering all the object probabilities
self.out_prob = torch.clamp(self.out_prob, max=0.9)
self.out_prob[self.tar_obj] = self.obj_mask
self.out_prob = aggregate_wbg(self.out_prob[1:], keep_bg=True, hard=True)
return self.out_prob
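
A tiny sanity sketch of the aggregation helper above (toy numbers only): the background channel is built as prod_k(1 - p_k) and everything is renormalized through a softmax over the logits.

```python
import torch

prob = torch.tensor([[[0.9]], [[0.2]]])              # 2 objects on a 1x1 "image"
soft = aggregate_wbg(prob, keep_bg=True)             # shape (3, 1, 1)
print(soft.sum(dim=0))                               # ~1.0: a valid distribution per pixel
hard = aggregate_wbg(prob, keep_bg=True, hard=True)  # near one-hot (low temperature)
```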

View File

@@ -1,175 +0,0 @@
# Modified from https://github.com/seoungwugoh/ivs-demo
import numpy as np
import torch
import torch.nn.functional as F
from util.palette import davis_palette
from dataset.range_transform import im_normalization
def image_to_torch(frame: np.ndarray, device='cuda'):
# frame: H*W*3 numpy array
frame = frame.transpose(2, 0, 1)
frame = torch.from_numpy(frame).float().to(device)/255
frame_norm = im_normalization(frame)
return frame_norm, frame
def torch_prob_to_numpy_mask(prob):
mask = torch.argmax(prob, dim=0)
mask = mask.cpu().numpy().astype(np.uint8)
return mask
def index_numpy_to_one_hot_torch(mask, num_classes):
mask = torch.from_numpy(mask).long()
return F.one_hot(mask, num_classes=num_classes).permute(2, 0, 1).float()
"""
Some constants for visualization
"""
color_map_np = np.frombuffer(davis_palette, dtype=np.uint8).reshape(-1, 3).copy()
# scales for better visualization
color_map_np = (color_map_np.astype(np.float32)*1.5).clip(0, 255).astype(np.uint8)
color_map = color_map_np.tolist()
if torch.cuda.is_available():
color_map_torch = torch.from_numpy(color_map_np).cuda() / 255
grayscale_weights = np.array([[0.3,0.59,0.11]]).astype(np.float32)
if torch.cuda.is_available():
grayscale_weights_torch = torch.from_numpy(grayscale_weights).cuda().unsqueeze(0)
def get_visualization(mode, image, mask, layer, target_object):
if mode == 'fade':
return overlay_davis(image, mask, fade=True)
elif mode == 'davis':
return overlay_davis(image, mask)
elif mode == 'light':
return overlay_davis(image, mask, 0.9)
elif mode == 'popup':
return overlay_popup(image, mask, target_object)
elif mode == 'layered':
if layer is None:
print('Layer file not given. Defaulting to DAVIS.')
return overlay_davis(image, mask)
else:
return overlay_layer(image, mask, layer, target_object)
else:
raise NotImplementedError
def get_visualization_torch(mode, image, prob, layer, target_object):
if mode == 'fade':
return overlay_davis_torch(image, prob, fade=True)
elif mode == 'davis':
return overlay_davis_torch(image, prob)
elif mode == 'light':
return overlay_davis_torch(image, prob, 0.9)
elif mode == 'popup':
return overlay_popup_torch(image, prob, target_object)
elif mode == 'layered':
if layer is None:
print('Layer file not given. Defaulting to DAVIS.')
return overlay_davis_torch(image, prob)
else:
return overlay_layer_torch(image, prob, layer, target_object)
else:
raise NotImplementedError
def overlay_davis(image, mask, alpha=0.5, fade=False):
""" Overlay segmentation on top of RGB image. from davis official"""
im_overlay = image.copy()
colored_mask = color_map_np[mask]
foreground = image*alpha + (1-alpha)*colored_mask
binary_mask = (mask > 0)
# Compose image
im_overlay[binary_mask] = foreground[binary_mask]
if fade:
im_overlay[~binary_mask] = im_overlay[~binary_mask] * 0.6
return im_overlay.astype(image.dtype)
def overlay_popup(image, mask, target_object):
# Keep foreground colored. Convert background to grayscale.
im_overlay = image.copy()
binary_mask = ~(np.isin(mask, target_object))
gray_region = (im_overlay[binary_mask]*grayscale_weights).sum(-1, keepdims=True)
im_overlay[binary_mask] = gray_region
return im_overlay.astype(image.dtype)
def overlay_layer(image, mask, layer, target_object):
# insert a layer between foreground and background
# The CPU version is less accurate because we are using the hard mask
# The GPU version has softer edges as it uses soft probabilities
obj_mask = (np.isin(mask, target_object)).astype(np.float32)
layer_alpha = layer[:, :, 3].astype(np.float32) / 255
layer_rgb = layer[:, :, :3]
background_alpha = np.maximum(obj_mask, layer_alpha)[:,:,np.newaxis]
obj_mask = obj_mask[:,:,np.newaxis]
im_overlay = (image*(1-background_alpha) + layer_rgb*(1-obj_mask) + image*obj_mask).clip(0, 255)
return im_overlay.astype(image.dtype)
def overlay_davis_torch(image, mask, alpha=0.5, fade=False):
""" Overlay segmentation on top of RGB image. from davis official"""
# Changes the image in-place to avoid copying
image = image.permute(1, 2, 0)
im_overlay = image
mask = torch.argmax(mask, dim=0)
colored_mask = color_map_torch[mask]
foreground = image*alpha + (1-alpha)*colored_mask
binary_mask = (mask > 0)
# Compose image
im_overlay[binary_mask] = foreground[binary_mask]
if fade:
im_overlay[~binary_mask] = im_overlay[~binary_mask] * 0.6
im_overlay = (im_overlay*255).cpu().numpy()
im_overlay = im_overlay.astype(np.uint8)
return im_overlay
def overlay_popup_torch(image, mask, target_object):
# Keep foreground colored. Convert background to grayscale.
image = image.permute(1, 2, 0)
if len(target_object) == 0:
obj_mask = torch.zeros_like(mask[0]).unsqueeze(2)
else:
# I should not need to convert this to numpy.
# Using a list works most of the time but consistently fails
# if I include first object -> exclude it -> include it again.
# I check everywhere and it makes absolutely no sense.
# I am blaming this on PyTorch and calling it a day
obj_mask = mask[np.array(target_object,dtype=np.int32)].sum(0).unsqueeze(2)
gray_image = (image*grayscale_weights_torch).sum(-1, keepdim=True)
im_overlay = obj_mask*image + (1-obj_mask)*gray_image
im_overlay = (im_overlay*255).cpu().numpy()
im_overlay = im_overlay.astype(np.uint8)
return im_overlay
def overlay_layer_torch(image, mask, layer, target_object):
# insert a layer between foreground and background
# The CPU version is less accurate because we are using the hard mask
# The GPU version has softer edges as it uses soft probabilities
image = image.permute(1, 2, 0)
if len(target_object) == 0:
obj_mask = torch.zeros_like(mask[0])
else:
# I should not need to convert this to numpy.
# Using a list works most of the time but consistently fails
# if I include first object -> exclude it -> include it again.
# I check everywhere and it makes absolutely no sense.
# I am blaming this on PyTorch and calling it a day
obj_mask = mask[np.array(target_object,dtype=np.int32)].sum(0)
layer_alpha = layer[:, :, 3]
layer_rgb = layer[:, :, :3]
background_alpha = torch.maximum(obj_mask, layer_alpha).unsqueeze(2)
obj_mask = obj_mask.unsqueeze(2)
im_overlay = (image*(1-background_alpha) + layer_rgb*(1-obj_mask) + image*obj_mask).clip(0, 1)
im_overlay = (im_overlay*255).cpu().numpy()
im_overlay = im_overlay.astype(np.uint8)
return im_overlay
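
For reference, a minimal sketch of the CPU overlay path (toy inputs; it assumes this module and its `util.palette` dependency are importable):

```python
import numpy as np

frame = np.zeros((240, 320, 3), dtype=np.uint8)      # dummy RGB frame
mask = np.zeros((240, 320), dtype=np.uint8)
mask[60:120, 80:160] = 1                             # one object

overlay = overlay_davis(frame, mask)                 # object tinted with color_map_np[1]
faded = overlay_davis(frame, mask, fade=True)        # background additionally dimmed
popup = overlay_popup(frame, mask, target_object=[1])  # background turned grayscale
```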

View File

@@ -1,206 +0,0 @@
import os
from os import path
import shutil
import collections
import cv2
from PIL import Image
if not hasattr(Image, 'Resampling'): # Pillow<9.0
Image.Resampling = Image
import numpy as np
from util.palette import davis_palette
import progressbar
# https://bugs.python.org/issue28178
# ah python ah why
class LRU:
def __init__(self, func, maxsize=128):
self.cache = collections.OrderedDict()
self.func = func
self.maxsize = maxsize
def __call__(self, *args):
cache = self.cache
if args in cache:
cache.move_to_end(args)
return cache[args]
result = self.func(*args)
cache[args] = result
if len(cache) > self.maxsize:
cache.popitem(last=False)
return result
def invalidate(self, key):
self.cache.pop(key, None)
class ResourceManager:
def __init__(self, config):
# determine inputs
images = config['images']
video = config['video']
self.workspace = config['workspace']
self.size = config['size']
self.palette = davis_palette
# create temporary workspace if not specified
if self.workspace is None:
if images is not None:
basename = path.basename(images)
elif video is not None:
basename = path.basename(video)[:-4]
else:
raise NotImplementedError(
'Either images, video, or workspace has to be specified')
self.workspace = path.join('./workspace', basename)
print(f'Workspace is in: {self.workspace}')
# determine the location of input images
need_decoding = False
need_resizing = False
if path.exists(path.join(self.workspace, 'images')):
pass
elif images is not None:
need_resizing = True
elif video is not None:
# will decode video into frames later
need_decoding = True
# create workspace subdirectories
self.image_dir = path.join(self.workspace, 'images')
self.mask_dir = path.join(self.workspace, 'masks')
os.makedirs(self.image_dir, exist_ok=True)
os.makedirs(self.mask_dir, exist_ok=True)
# convert read functions to be buffered
self.get_image = LRU(self._get_image_unbuffered, maxsize=config['buffer_size'])
self.get_mask = LRU(self._get_mask_unbuffered, maxsize=config['buffer_size'])
# extract frames from video
if need_decoding:
self._extract_frames(video)
# copy/resize existing images to the workspace
if need_resizing:
self._copy_resize_frames(images)
# read all frame names
self.names = sorted(os.listdir(self.image_dir))
self.names = [f[:-4] for f in self.names] # remove extensions
self.length = len(self.names)
assert self.length > 0, f'No images found! Check {self.workspace}/images. Remove folder if necessary.'
print(f'{self.length} images found.')
self.height, self.width = self.get_image(0).shape[:2]
self.visualization_init = False
def _extract_frames(self, video):
cap = cv2.VideoCapture(video)
frame_index = 0
print(f'Extracting frames from {video} into {self.image_dir}...')
bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength)
while(cap.isOpened()):
_, frame = cap.read()
if frame is None:
break
if self.size > 0:
h, w = frame.shape[:2]
new_w = (w*self.size//min(w, h))
new_h = (h*self.size//min(w, h))
if new_w != w or new_h != h:
frame = cv2.resize(frame,dsize=(new_w,new_h),interpolation=cv2.INTER_AREA)
cv2.imwrite(path.join(self.image_dir, f'{frame_index:07d}.jpg'), frame)
frame_index += 1
bar.update(frame_index)
bar.finish()
print('Done!')
def _copy_resize_frames(self, images):
image_list = os.listdir(images)
print(f'Copying/resizing frames into {self.image_dir}...')
for image_name in progressbar.progressbar(image_list):
if self.size < 0:
# just copy
shutil.copy2(path.join(images, image_name), self.image_dir)
else:
frame = cv2.imread(path.join(images, image_name))
h, w = frame.shape[:2]
new_w = (w*self.size//min(w, h))
new_h = (h*self.size//min(w, h))
if new_w != w or new_h != h:
frame = cv2.resize(frame,dsize=(new_w,new_h),interpolation=cv2.INTER_AREA)
cv2.imwrite(path.join(self.image_dir, image_name), frame)
print('Done!')
def save_mask(self, ti, mask):
# mask should be uint8 H*W without channels
assert 0 <= ti < self.length
assert isinstance(mask, np.ndarray)
mask = Image.fromarray(mask)
mask.putpalette(self.palette)
mask.save(path.join(self.mask_dir, self.names[ti]+'.png'))
self.invalidate(ti)
def save_visualization(self, ti, image):
# image should be uint8 H*W*3
assert 0 <= ti < self.length
assert isinstance(image, np.ndarray)
if not self.visualization_init:
self.visualization_dir = path.join(self.workspace, 'visualization')
os.makedirs(self.visualization_dir, exist_ok=True)
self.visualization_init = True
image = Image.fromarray(image)
image.save(path.join(self.visualization_dir, self.names[ti]+'.jpg'))
def _get_image_unbuffered(self, ti):
# returns H*W*3 uint8 array
assert 0 <= ti < self.length
image = Image.open(path.join(self.image_dir, self.names[ti]+'.jpg'))
image = np.array(image)
return image
def _get_mask_unbuffered(self, ti):
# returns H*W uint8 array
assert 0 <= ti < self.length
mask_path = path.join(self.mask_dir, self.names[ti]+'.png')
if path.exists(mask_path):
mask = Image.open(mask_path)
mask = np.array(mask)
return mask
else:
return None
def read_external_image(self, file_name, size=None):
image = Image.open(file_name)
is_mask = image.mode in ['L', 'P']
if size is not None:
# PIL uses (width, height)
image = image.resize((size[1], size[0]),
resample=Image.Resampling.NEAREST if is_mask else Image.Resampling.BICUBIC)
image = np.array(image)
return image
def invalidate(self, ti):
# the image buffer is never invalidated
self.get_mask.invalidate((ti,))
def __len__(self):
return self.length
@property
def h(self):
return self.height
@property
def w(self):
return self.width
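
The `LRU` wrapper above behaves like `functools.lru_cache` but allows a single key to be invalidated, which is what `save_mask` relies on. A toy sketch:

```python
def load_mask(index):
    print(f'reading mask {index} from disk')
    return index * 10          # stand-in for an actual image read

cached = LRU(load_mask, maxsize=2)
cached(0)                      # triggers the print
cached(0)                      # served from the cache, no print
cached.invalidate((0,))        # note: the key is the full *args tuple
cached(0)                      # reads from "disk" again
```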

View File

@@ -1,180 +0,0 @@
# Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch
import torch
from torch import nn
from torch.nn import functional as F
from .utils import _SimpleSegmentationModel
__all__ = ["DeepLabV3"]
class DeepLabV3(_SimpleSegmentationModel):
"""
Implements DeepLabV3 model from
`"Rethinking Atrous Convolution for Semantic Image Segmentation"
<https://arxiv.org/abs/1706.05587>`_.
Arguments:
backbone (nn.Module): the network used to compute the features for the model.
The backbone should return an OrderedDict[Tensor], with the key being
"out" for the last feature map used, and "aux" if an auxiliary classifier
is used.
classifier (nn.Module): module that takes the "out" element returned from
the backbone and returns a dense prediction.
aux_classifier (nn.Module, optional): auxiliary classifier used during training
"""
pass
class DeepLabHeadV3Plus(nn.Module):
def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]):
super(DeepLabHeadV3Plus, self).__init__()
self.project = nn.Sequential(
nn.Conv2d(low_level_channels, 48, 1, bias=False),
nn.BatchNorm2d(48),
nn.ReLU(inplace=True),
)
self.aspp = ASPP(in_channels, aspp_dilate)
self.classifier = nn.Sequential(
nn.Conv2d(304, 256, 3, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, num_classes, 1)
)
self._init_weight()
def forward(self, feature):
low_level_feature = self.project( feature['low_level'] )
output_feature = self.aspp(feature['out'])
output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False)
return self.classifier( torch.cat( [ low_level_feature, output_feature ], dim=1 ) )
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
class DeepLabHead(nn.Module):
def __init__(self, in_channels, num_classes, aspp_dilate=[12, 24, 36]):
super(DeepLabHead, self).__init__()
self.classifier = nn.Sequential(
ASPP(in_channels, aspp_dilate),
nn.Conv2d(256, 256, 3, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, num_classes, 1)
)
self._init_weight()
def forward(self, feature):
return self.classifier( feature['out'] )
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
class AtrousSeparableConvolution(nn.Module):
""" Atrous Separable Convolution
"""
def __init__(self, in_channels, out_channels, kernel_size,
stride=1, padding=0, dilation=1, bias=True):
super(AtrousSeparableConvolution, self).__init__()
self.body = nn.Sequential(
# Separable Conv
nn.Conv2d( in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, groups=in_channels ),
# PointWise Conv
nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
)
self._init_weight()
def forward(self, x):
return self.body(x)
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
class ASPPConv(nn.Sequential):
def __init__(self, in_channels, out_channels, dilation):
modules = [
nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
]
super(ASPPConv, self).__init__(*modules)
class ASPPPooling(nn.Sequential):
def __init__(self, in_channels, out_channels):
super(ASPPPooling, self).__init__(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True))
def forward(self, x):
size = x.shape[-2:]
x = super(ASPPPooling, self).forward(x)
return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
class ASPP(nn.Module):
def __init__(self, in_channels, atrous_rates):
super(ASPP, self).__init__()
out_channels = 256
modules = []
modules.append(nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)))
rate1, rate2, rate3 = tuple(atrous_rates)
modules.append(ASPPConv(in_channels, out_channels, rate1))
modules.append(ASPPConv(in_channels, out_channels, rate2))
modules.append(ASPPConv(in_channels, out_channels, rate3))
modules.append(ASPPPooling(in_channels, out_channels))
self.convs = nn.ModuleList(modules)
self.project = nn.Sequential(
nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Dropout(0.1),)
def forward(self, x):
res = []
for conv in self.convs:
res.append(conv(x))
res = torch.cat(res, dim=1)
return self.project(res)
def convert_to_separable_conv(module):
new_module = module
if isinstance(module, nn.Conv2d) and module.kernel_size[0]>1:
new_module = AtrousSeparableConvolution(module.in_channels,
module.out_channels,
module.kernel_size,
module.stride,
module.padding,
module.dilation,
                                      module.bias is not None)
for name, child in module.named_children():
new_module.add_module(name, convert_to_separable_conv(child))
return new_module
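A hedged usage sketch for `convert_to_separable_conv`, assuming the classes above are importable as-is: every Conv2d with kernel size greater than 1 inside the head is swapped for an AtrousSeparableConvolution, and the output shape is unchanged. `eval()` is used so the batch-size-1 dummy input does not trip BatchNorm's training-mode check in the ASPP pooling branch.

```python
# Sketch only: assumes DeepLabHead and convert_to_separable_conv above are importable.
import torch

head = DeepLabHead(in_channels=2048, num_classes=1)
sep_head = convert_to_separable_conv(head)   # swaps every k>1 Conv2d for a separable pair
sep_head.eval()                              # BatchNorm uses running stats

with torch.no_grad():
    out = sep_head({'out': torch.randn(1, 2048, 32, 32)})
print(out.shape)                             # torch.Size([1, 1, 32, 32])
```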

View File

@@ -1,65 +0,0 @@
# Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch
from .utils import IntermediateLayerGetter
from ._deeplab import DeepLabHead, DeepLabHeadV3Plus, DeepLabV3
from . import s2m_resnet
def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
if output_stride==8:
replace_stride_with_dilation=[False, True, True]
aspp_dilate = [12, 24, 36]
else:
replace_stride_with_dilation=[False, False, True]
aspp_dilate = [6, 12, 18]
backbone = s2m_resnet.__dict__[backbone_name](
pretrained=pretrained_backbone,
replace_stride_with_dilation=replace_stride_with_dilation)
inplanes = 2048
low_level_planes = 256
if name=='deeplabv3plus':
return_layers = {'layer4': 'out', 'layer1': 'low_level'}
classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
elif name=='deeplabv3':
return_layers = {'layer4': 'out'}
        classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
model = DeepLabV3(backbone, classifier)
return model
def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone):
if backbone.startswith('resnet'):
model = _segm_resnet(arch_type, backbone, num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
else:
raise NotImplementedError
return model
# Deeplab v3
def deeplabv3_resnet50(num_classes=1, output_stride=16, pretrained_backbone=False):
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
Args:
num_classes (int): number of classes.
output_stride (int): output stride for deeplab.
pretrained_backbone (bool): If True, use the pretrained backbone.
"""
return _load_model('deeplabv3', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
# Deeplab v3+
def deeplabv3plus_resnet50(num_classes=1, output_stride=16, pretrained_backbone=False):
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
Args:
num_classes (int): number of classes.
output_stride (int): output stride for deeplab.
pretrained_backbone (bool): If True, use the pretrained backbone.
"""
return _load_model('deeplabv3plus', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
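A hedged sanity check for the builders above, assuming they are importable: the S2M backbone's stem takes 6 input channels (RGB, previous mask, and two scribble maps, per the controller further down), so the dummy tensor has 6 channels, and `eval()` avoids batch-size-1 BatchNorm errors.

```python
# Sketch only: forward pass through the S2M network builder defined above.
import torch

net = deeplabv3plus_resnet50(num_classes=1, output_stride=16, pretrained_backbone=False)
net.eval()

dummy = torch.randn(1, 6, 384, 384)   # RGB + previous mask + positive/negative scribbles
with torch.no_grad():
    logits = net(dummy)
print(logits.shape)                   # torch.Size([1, 1, 384, 384]), upsampled to input size
```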

View File

@@ -1,182 +0,0 @@
import torch
import torch.nn as nn
try:
from torchvision.models.utils import load_state_dict_from_url
except ModuleNotFoundError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url
__all__ = ['ResNet', 'resnet50']
model_urls = {
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
}
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(6, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
model = ResNet(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict)
return model
def resnet50(pretrained=False, progress=True, **kwargs):
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)
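Note that `conv1` here takes 6 input channels rather than the usual 3, so ImageNet checkpoints would not load for that layer; `pretrained` defaults to False accordingly. A hedged sketch of the dilated feature extraction that `_segm_resnet` relies on (output stride 16):

```python
# Sketch only: the dilated 6-channel backbone as _segm_resnet builds it for output_stride=16.
import torch

backbone = resnet50(pretrained=False, replace_stride_with_dilation=[False, False, True])
backbone.eval()

x = torch.randn(1, 6, 224, 224)   # RGB + previous mask + two scribble channels
with torch.no_grad():
    h = backbone.maxpool(backbone.relu(backbone.bn1(backbone.conv1(x))))
    h = backbone.layer1(h)        # stride 4,  256 channels
    h = backbone.layer2(h)        # stride 8,  512 channels
    h = backbone.layer3(h)        # stride 16, 1024 channels
    h = backbone.layer4(h)        # still stride 16: stride replaced by dilation
print(h.shape)                    # torch.Size([1, 2048, 14, 14])
```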

View File

@@ -1,78 +0,0 @@
# Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from collections import OrderedDict
class _SimpleSegmentationModel(nn.Module):
def __init__(self, backbone, classifier):
super(_SimpleSegmentationModel, self).__init__()
self.backbone = backbone
self.classifier = classifier
def forward(self, x):
input_shape = x.shape[-2:]
features = self.backbone(x)
x = self.classifier(features)
x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
return x
class IntermediateLayerGetter(nn.ModuleDict):
"""
Module wrapper that returns intermediate layers from a model
It has a strong assumption that the modules have been registered
into the model in the same order as they are used.
This means that one should **not** reuse the same nn.Module
twice in the forward if you want this to work.
Additionally, it is only able to query submodules that are directly
assigned to the model. So if `model` is passed, `model.feature1` can
be returned, but not `model.feature1.layer2`.
Arguments:
model (nn.Module): model on which we will extract the features
return_layers (Dict[name, new_name]): a dict containing the names
of the modules for which the activations will be returned as
the key of the dict, and the value of the dict is the name
of the returned activation (which the user can specify).
Examples::
>>> m = torchvision.models.resnet18(pretrained=True)
    >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
>>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
>>> {'layer1': 'feat1', 'layer3': 'feat2'})
>>> out = new_m(torch.rand(1, 3, 224, 224))
>>> print([(k, v.shape) for k, v in out.items()])
>>> [('feat1', torch.Size([1, 64, 56, 56])),
>>> ('feat2', torch.Size([1, 256, 14, 14]))]
"""
def __init__(self, model, return_layers):
if not set(return_layers).issubset([name for name, _ in model.named_children()]):
raise ValueError("return_layers are not present in model")
orig_return_layers = return_layers
return_layers = {k: v for k, v in return_layers.items()}
layers = OrderedDict()
for name, module in model.named_children():
layers[name] = module
if name in return_layers:
del return_layers[name]
if not return_layers:
break
super(IntermediateLayerGetter, self).__init__(layers)
self.return_layers = orig_return_layers
def forward(self, x):
out = OrderedDict()
for name, module in self.named_children():
x = module(x)
if name in self.return_layers:
out_name = self.return_layers[name]
out[out_name] = x
return out
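A hedged toy composition of the two helpers above (the toy backbone, head, and sizes are placeholders, not part of the original code): IntermediateLayerGetter collects named feature maps in registration order, and _SimpleSegmentationModel upsamples the classifier output back to the input resolution.

```python
# Sketch only: toy composition of IntermediateLayerGetter and _SimpleSegmentationModel.
import torch
from torch import nn

toy_backbone = nn.Sequential()     # children are registered (and run) in this order
toy_backbone.add_module('stem', nn.Conv2d(3, 16, 3, stride=2, padding=1))
toy_backbone.add_module('body', nn.Conv2d(16, 32, 3, stride=2, padding=1))
getter = IntermediateLayerGetter(toy_backbone, return_layers={'body': 'out'})

class ToyHead(nn.Module):
    """Adapter so a plain conv can read the dict the getter returns."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(32, 2, 1)   # two toy "classes"
    def forward(self, features):
        return self.conv(features['out'])

model = _SimpleSegmentationModel(getter, ToyHead())
out = model(torch.randn(1, 3, 64, 64))
print(out.shape)                   # torch.Size([1, 2, 64, 64]) -- back at input resolution
```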

View File

@@ -1,39 +0,0 @@
import torch
import numpy as np
from ..interact.s2m.s2m_network import deeplabv3plus_resnet50 as S2M
from util.tensor_util import pad_divide_by, unpad
class S2MController:
"""
A controller for Scribble-to-Mask (for user interaction, not for DAVIS)
Takes the image, previous mask, and scribbles to produce a new mask
ignore_class is usually 255
0 is NOT the ignore class -- it is the label for the background
"""
def __init__(self, s2m_net:S2M, num_objects, ignore_class, device='cuda:0'):
self.s2m_net = s2m_net
self.num_objects = num_objects
self.ignore_class = ignore_class
self.device = device
def interact(self, image, prev_mask, scr_mask):
image = image.to(self.device, non_blocking=True)
prev_mask = prev_mask.unsqueeze(0)
h, w = image.shape[-2:]
unaggre_mask = torch.zeros((self.num_objects, h, w), dtype=torch.float32, device=image.device)
for ki in range(1, self.num_objects+1):
p_srb = (scr_mask==ki).astype(np.uint8)
n_srb = ((scr_mask!=ki) * (scr_mask!=self.ignore_class)).astype(np.uint8)
Rs = torch.from_numpy(np.stack([p_srb, n_srb], 0)).unsqueeze(0).float().to(image.device)
inputs = torch.cat([image, (prev_mask==ki).float().unsqueeze(0), Rs], 1)
inputs, pads = pad_divide_by(inputs, 16)
unaggre_mask[ki-1] = unpad(torch.sigmoid(self.s2m_net(inputs)), pads)
return unaggre_mask
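A hedged sketch of the per-object input assembly inside `interact` above, without the network call (the scribble map is a placeholder): positive and negative scribble channels are stacked with the RGB image and that object's slice of the previous mask, giving the 6-channel tensor the S2M backbone expects; `interact` then pads everything to a multiple of 16 with `pad_divide_by` before the forward pass.

```python
# Sketch only: how interact assembles the 6-channel S2M input for one object.
import numpy as np
import torch

ignore_class = 255
h, w = 480, 854
image = torch.randn(1, 3, h, w)                  # normalized RGB frame
prev_mask = torch.zeros(1, h, w)                 # per-pixel object IDs from propagation
scr_mask = np.full((h, w), ignore_class, np.uint8)
scr_mask[100:110, 200:300] = 1                   # a positive scribble for object 1

ki = 1
p_srb = (scr_mask == ki).astype(np.uint8)
n_srb = ((scr_mask != ki) * (scr_mask != ignore_class)).astype(np.uint8)
Rs = torch.from_numpy(np.stack([p_srb, n_srb], 0)).unsqueeze(0).float()
inputs = torch.cat([image, (prev_mask == ki).float().unsqueeze(0), Rs], 1)
print(inputs.shape)   # torch.Size([1, 6, 480, 854]); padded to a multiple of 16 before the net
```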

View File

@@ -1,33 +0,0 @@
import time
class Timer:
def __init__(self):
self._acc_time = 0
self._paused = True
def start(self):
if self._paused:
self.last_time = time.time()
self._paused = False
return self
def pause(self):
self.count()
self._paused = True
return self
def count(self):
if self._paused:
return self._acc_time
t = time.time()
self._acc_time += t - self.last_time
self.last_time = t
return self._acc_time
def format(self):
# count = int(self.count()*100)
# return '%02d:%02d:%02d' % (count//6000, (count//100)%60, count%100)
return '%03.2f' % self.count()
def __str__(self):
return self.format()
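A hedged usage sketch for the accumulating Timer above: only the intervals between `start()` and `pause()` count, and printing the object reports the accumulated seconds with two decimals.

```python
# Sketch only: the Timer accumulates wall time across start()/pause() intervals.
import time

timer = Timer().start()
time.sleep(0.2)        # counted
timer.pause()
time.sleep(0.5)        # not counted while paused
timer.start()
time.sleep(0.1)        # counted
print(timer)           # roughly "0.30"
```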

View File

@@ -1,95 +0,0 @@
"""
A simple user interface for XMem
"""
import os
# fix for Windows
if 'QT_QPA_PLATFORM_PLUGIN_PATH' not in os.environ:
os.environ['QT_QPA_PLATFORM_PLUGIN_PATH'] = ''
import sys
from argparse import ArgumentParser
import torch
from model.network import XMem
from inference.interact.s2m_controller import S2MController
from inference.interact.fbrs_controller import FBRSController
from inference.interact.s2m.s2m_network import deeplabv3plus_resnet50 as S2M
from PyQt5.QtWidgets import QApplication
from inference.interact.gui import App
from inference.interact.resource_manager import ResourceManager
torch.set_grad_enabled(False)
if __name__ == '__main__':
# Arguments parsing
parser = ArgumentParser()
parser.add_argument('--model', default='./saves/XMem.pth')
parser.add_argument('--s2m_model', default='saves/s2m.pth')
parser.add_argument('--fbrs_model', default='saves/fbrs.pth')
"""
    Priority 1: If an "images" folder exists in the workspace, we will read from that directory
Priority 2: If --images is specified, we will copy/resize those images to the workspace
Priority 3: If --video is specified, we will extract the frames to the workspace (in an "images" folder) and read from there
In any case, if a "masks" folder exists in the workspace, we will use that to initialize the mask
That way, you can continue annotation from an interrupted run as long as the same workspace is used.
"""
    parser.add_argument('--images', help='Folder containing input images.', default=None)
parser.add_argument('--video', help='Video file readable by OpenCV.', default=None)
parser.add_argument('--workspace', help='directory for storing buffered images (if needed) and output masks', default=None)
parser.add_argument('--buffer_size', help='Correlate with CPU memory consumption', type=int, default=100)
parser.add_argument('--num_objects', type=int, default=1)
# Long-memory options
# Defaults. Some can be changed in the GUI.
parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10)
parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5)
parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time',
type=int, default=10000)
parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128)
parser.add_argument('--top_k', type=int, default=30)
parser.add_argument('--mem_every', type=int, default=10)
parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1)
parser.add_argument('--no_amp', help='Turn off AMP', action='store_true')
parser.add_argument('--size', default=480, type=int,
help='Resize the shorter side to this size. -1 to use original resolution. ')
args = parser.parse_args()
config = vars(args)
config['enable_long_term'] = True
config['enable_long_term_count_usage'] = True
with torch.cuda.amp.autocast(enabled=not args.no_amp):
# Load our checkpoint
network = XMem(config, args.model).cuda().eval()
# Loads the S2M model
if args.s2m_model is not None:
s2m_saved = torch.load(args.s2m_model)
s2m_model = S2M().cuda().eval()
s2m_model.load_state_dict(s2m_saved)
else:
s2m_model = None
s2m_controller = S2MController(s2m_model, args.num_objects, ignore_class=255)
if args.fbrs_model is not None:
fbrs_controller = FBRSController(args.fbrs_model)
else:
fbrs_controller = None
# Manages most IO
resource_manager = ResourceManager(config)
app = QApplication(sys.argv)
ex = App(network, resource_manager, s2m_controller, fbrs_controller, config)
sys.exit(app.exec_())
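The comment block above spells out how the image source is resolved; a hedged, stand-alone sketch of that priority order follows (placeholder logic, not the actual ResourceManager implementation).

```python
# Sketch only: the workspace/image/video priority described in the comment block above.
import os

def resolve_image_source(workspace, images_arg, video_arg):
    ws_images = os.path.join(workspace, 'images')
    if os.path.isdir(ws_images):      # Priority 1: reuse frames already in the workspace
        return 'workspace', ws_images
    if images_arg is not None:        # Priority 2: copy/resize the given image folder
        return 'images', images_arg
    if video_arg is not None:         # Priority 3: extract frames from the video
        return 'video', video_arg
    raise ValueError('a fresh workspace needs either --images or --video')

# resolve_image_source('./workspace/example', None, 'example.mp4') -> ('video', 'example.mp4')
```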

View File

@@ -1,135 +0,0 @@
import os
from os import path
from argparse import ArgumentParser
import glob
from collections import defaultdict
import numpy as np
import hickle as hkl
from PIL import Image, ImagePalette
from progressbar import progressbar
from multiprocessing import Pool
from util import palette
from util.palette import davis_palette, youtube_palette
import shutil
def search_options(options, name):
for option in options:
if path.exists(path.join(option, name)):
return path.join(option, name)
    else:
        # for-else: reached only when no candidate directory contains this video
        return None
def process_vid(vid):
vid_path = search_options(all_options, vid)
if vid_path is not None:
backward_mapping = hkl.load(path.join(vid_path, 'backward.hkl'))
else:
backward_mapping = None
frames = os.listdir(path.join(all_options[0], vid))
frames = [f for f in frames if 'backward' not in f]
print(vid)
if 'Y' in args.dataset:
this_out_path = path.join(out_path, 'Annotations', vid)
else:
this_out_path = path.join(out_path, vid)
os.makedirs(this_out_path, exist_ok=True)
for f in frames:
result_sum = None
for option in all_options:
if not path.exists(path.join(option, vid, f)):
continue
result = hkl.load(path.join(option, vid, f))
if result_sum is None:
result_sum = result.astype(np.float32)
else:
result_sum += result
        # argmax over the summed scores to get an index mask
result_sum = np.argmax(result_sum, axis=0)
# Remap the indices to the original domain
if backward_mapping is not None:
idx_mask = np.zeros_like(result_sum, dtype=np.uint8)
for l, i in backward_mapping.items():
idx_mask[result_sum==i] = l
else:
idx_mask = result_sum.astype(np.uint8)
# Save the results
img_E = Image.fromarray(idx_mask)
img_E.putpalette(palette)
img_E.save(path.join(this_out_path, f[:-4]+'.png'))
if __name__ == '__main__':
"""
Arguments loading
"""
parser = ArgumentParser()
parser.add_argument('--dataset', default='Y', help='D/Y, D for DAVIS; Y for YouTubeVOS')
parser.add_argument('--list', nargs="+")
    parser.add_argument('--pattern', default=None, help='Glob pattern. Can be used in place of list.')
parser.add_argument('--output')
parser.add_argument('--num_proc', default=4, type=int)
args = parser.parse_args()
out_path = args.output
# Find the input candidates
if args.pattern is None:
all_options = args.list
else:
assert args.list is None, 'cannot specify both list and pattern'
all_options = glob.glob(args.pattern)
# Get the correct palette
if 'D' in args.dataset:
palette = ImagePalette.ImagePalette(mode='P', palette=davis_palette)
elif 'Y' in args.dataset:
palette = ImagePalette.ImagePalette(mode='P', palette=youtube_palette)
else:
raise NotImplementedError
# Count of the number of videos in each candidate
all_options = [path.join(o, 'Scores') for o in all_options]
vid_count = defaultdict(int)
for option in all_options:
vid_in_here = sorted(os.listdir(option))
for vid in vid_in_here:
vid_count[vid] += 1
all_vid = []
count_to_vid = defaultdict(int)
for k, v in vid_count.items():
count_to_vid[v] += 1
all_vid.append(k)
for k, v in count_to_vid.items():
print('Videos with count %d: %d' % (k, v))
all_vid = sorted(all_vid)
print('Total number of videos: ', len(all_vid))
pool = Pool(processes=args.num_proc)
for _ in progressbar(pool.imap_unordered(process_vid, all_vid), max_value=len(all_vid)):
pass
pool.close()
pool.join()
if 'D' in args.dataset:
print('Making zip for DAVIS test-dev...')
shutil.make_archive(args.output, 'zip', args.output)
if 'Y' in args.dataset:
print('Making zip for YouTubeVOS...')
shutil.make_archive(path.join(args.output, path.basename(args.output)), 'zip', args.output, 'Annotations')
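A hedged toy example of the per-frame merge in `process_vid` above (arrays and the mapping are placeholders): score maps from each scale are summed, argmax picks the label index, and the backward mapping, when present, converts indices back to the original object IDs.

```python
# Sketch only: the sum -> argmax -> backward-remap merge used per frame above.
import numpy as np

scores_scale_a = np.random.rand(3, 4, 4).astype(np.float32)   # C x H x W score maps
scores_scale_b = np.random.rand(3, 4, 4).astype(np.float32)

result_sum = scores_scale_a + scores_scale_b
label_idx = np.argmax(result_sum, axis=0)                     # H x W indices in 0..C-1

backward_mapping = {5: 1, 9: 2}                               # original label -> index
idx_mask = np.zeros_like(label_idx, dtype=np.uint8)
for original_label, i in backward_mapping.items():
    idx_mask[label_idx == i] = original_label
print(np.unique(idx_mask))                                    # subset of {0, 5, 9}
```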

View File

@@ -5,3 +5,4 @@ git+https://github.com/cheind/py-thin-plate-spline
hickle
tensorboard
numpy
git+https://github.com/facebookresearch/segment-anything.git

View File

@@ -1,3 +0,0 @@
PyQt5
Cython
scipy

Some files were not shown because too many files have changed in this diff.