diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..7f85032
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+__pycache__/
+.vscode/
+docs/
+
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..922f7df
--- /dev/null
+++ b/README.md
@@ -0,0 +1,11 @@
+# Track-Anything
+
+## Demo
+***
+This is Demo
+## Get Started.
+***
+This is Get Started.
+## Acknowledgement
+***
+The project is based on [XMem](https://github.com/facebookresearch/segment-anything) and [Segment Anything](https://github.com/hkchengrex/XMem). Thanks for the authors for their efforts.
diff --git a/dataset/__pycache__/__init__.cpython-39.pyc b/dataset/__pycache__/__init__.cpython-39.pyc
deleted file mode 100644
index 7087dca..0000000
Binary files a/dataset/__pycache__/__init__.cpython-39.pyc and /dev/null differ
diff --git a/dataset/__pycache__/range_transform.cpython-39.pyc b/dataset/__pycache__/range_transform.cpython-39.pyc
deleted file mode 100644
index 2f1270f..0000000
Binary files a/dataset/__pycache__/range_transform.cpython-39.pyc and /dev/null differ
diff --git a/dataset/__pycache__/reseed.cpython-39.pyc b/dataset/__pycache__/reseed.cpython-39.pyc
deleted file mode 100644
index 5d302f2..0000000
Binary files a/dataset/__pycache__/reseed.cpython-39.pyc and /dev/null differ
diff --git a/dataset/__pycache__/static_dataset.cpython-39.pyc b/dataset/__pycache__/static_dataset.cpython-39.pyc
deleted file mode 100644
index d5b5e19..0000000
Binary files a/dataset/__pycache__/static_dataset.cpython-39.pyc and /dev/null differ
diff --git a/dataset/__pycache__/tps.cpython-39.pyc b/dataset/__pycache__/tps.cpython-39.pyc
deleted file mode 100644
index aba1082..0000000
Binary files a/dataset/__pycache__/tps.cpython-39.pyc and /dev/null differ
diff --git a/dataset/__pycache__/vos_dataset.cpython-39.pyc b/dataset/__pycache__/vos_dataset.cpython-39.pyc
deleted file mode 100644
index d7e8557..0000000
Binary files a/dataset/__pycache__/vos_dataset.cpython-39.pyc and /dev/null differ
diff --git a/docs/DEMO.md b/docs/DEMO.md
deleted file mode 100644
index cd1d283..0000000
--- a/docs/DEMO.md
+++ /dev/null
@@ -1,58 +0,0 @@
-# Interactive GUI for Demo
-
-First, set up the required packages following [GETTING STARTED.md](./GETTING_STARTED.md). You can ignore the dataset part as you wouldn't be needing them for this demo. Download the pretrained models following [INFERENCE.md](./INFERENCE.md).
-
-You will need some additional packages and pretrained models for the GUI. For the packages,
-
-```bash
-pip install -r requirements_demo.txt
-```
-
-The interactive GUI is modified from [MiVOS](https://github.com/hkchengrex/MiVOS). Specifically, we keep the "interaction-to-mask" module and the propagation module is replaced with XMem. The fusion module is discarded because I don't want to train it.
-For interactions, we use [f-BRS](https://github.com/saic-vul/fbrs_interactive_segmentation) and [S2M](https://github.com/hkchengrex/Scribble-to-Mask). You will need their pretrained models. Use `./scripts/download_models_demo.sh` or download them manually into `./saves`.
-
-The entry point is `interactive_demo.py`. The command line arguments should be self-explanatory.
-
-
-
-## Try it for yourself
-
-https://user-images.githubusercontent.com/7107196/177661140-f690156b-1775-4cd7-acd7-1738a5c92f30.mp4
-
-Right-click download this video (source: https://www.youtube.com/watch?v=FTcjzaqL0pE). Then run
-
-```bash
-python interactive_demo.py --video [path to the video] --num_objects 4
-```
-
-## Features
-
-* Low CPU memory cost. Unlike the implementation in MiVOS, we do not load all the images as the program starts up. We load them on-the-fly with an LRU buffer.
-* Low GPU memory cost. This is provided by XMem. See the paper.
-* Faster than MiVOS-STCN, especially for long videos. ^
-* You can continue from interrupted runs. We save the resultant masks on-the-fly in the workspace directory from which annotation can be resumed. The memory bank is not saved and cannot be resumed.
-
-## Controls
-
-* Use the slider to change the current frame. "Play Video" automatically progresses the video.
-* Select interaction type: "scribble", "click", or "free". Both scribble and "free" (free-hand drawing) modify an existing mask. Using "click" on an existing object mask (i.e., a mask from propagation or other interaction methods) will reset the mask. This is because f-BRS does not take an existing mask as input.
-* Select the target object using the number keys. "1" corresponds to the first object, etc. You need to specify the maximum number of objects when you start the program through the command line.
-* Use propagate forward/backward to let XMem do the job. Pause when correction is needed. It will only automatically stops when it hits the end of the video.
-* Make sure all objects are correctly labeled before propagating. The program doesn't care which object you have interacted with -- it treats everything as user-provided inputs. Not labelling an object implicitly means that it is part of the background.
-* The memory bank might be "polluted" by bad memory frames. Feel free to hit clear memory to erase that. Propagation runs faster with a small memory bank.
-* All output masks are automatically saved in the workspace directory, which is printed when the program starts.
-* You can load an external mask for the current frame using "Import mask".
-* For "layered insertion" (e.g., the breakdance demo), use the "layered" overlay mode. You can load a custom layer using "Import layer". The layer should be an RGBA png file. RGB image files are also accepted -- the alpha channel will be filled with ones.
-* The "save overlay during propagation" checkbox does exactly that. It does not save the overlay when the user is just scrubbing the timeline.
-* For "popup" and "layered", the visualizations during propagation (and the saved overlays) have higher quality then when the user is scrubbing the timeline. This is because we have access to the soft probability mask during propagation.
-* Both "popup" and "layered" need a binary mask. By default, the first object mask is used. You can change the target (or make the target a union of objects) using the middle mouse key.
-
-## FAQ
-
-1. Why cannot I label object 2 after pressing the number '2'?
- - Make sure you specified `--num_objects`. We ignore object IDs that exceed `num_objects`.
-2. The GUI feels slow!
- - The GUI needs to read/write images and masks on-the-go. Ideally this can be implemented with multiple threads with look-ahead but I didn't. The overheads will be smaller if you place the `workspace` on a SSD. You can also use a ram disk. `eval.py` will almost certainly be faster.
- - It takes more time to process more objects. This depends on `num_objects`, but not the actual number of objects that the user has annotated. *This does not mean that running time is directly proportional to the number of objects. There is significant shared computation.*
-3. Can I run this on a remote server?
- - X11 forwarding should be possible. I have not tried this and would love to know if it works for you.
diff --git a/docs/ECCV-logo.png b/docs/ECCV-logo.png
deleted file mode 100644
index 081a1f9..0000000
Binary files a/docs/ECCV-logo.png and /dev/null differ
diff --git a/docs/FAILURE_CASES.md b/docs/FAILURE_CASES.md
deleted file mode 100644
index 108566b..0000000
--- a/docs/FAILURE_CASES.md
+++ /dev/null
@@ -1,25 +0,0 @@
-# Failure Cases
-
-Like all methods, XMem can fail. Here, we try to show some illustrative and frankly consistent failure modes that we noticed. We slowed down all videos for visualization.
-
-## Fast motion, similar objects
-
-The first one is fast motion with similarly-looking objects that do not provide sufficient appearance clues for XMem to track. Below is an example from the YouTubeVOS validation set (0e8a6b63bb):
-
-https://user-images.githubusercontent.com/7107196/179459162-80b65a6c-439d-4239-819f-68804d9412e9.mp4
-
-And the source video:
-
-https://user-images.githubusercontent.com/7107196/181700094-356284bc-e8a4-4757-ab84-1e9009fddd4b.mp4
-
-Technically it can be solved by using more positional and motion clues. XMem is not sufficiently proficient at those.
-
-## Shot changes; saliency shift
-
-Ever wondered why I did not include the final scene of Chika Dance when the roach flies off? Because it failed there.
-
-XMem seems to be attracted to any new salient object in the scene when the (true) target object is missing. By new I mean an object that did not appear (or had a different appearance) earlier in the video -- as XMem could not have a memory representation for that object. This happens a lot if the camera shot changes.
-
-https://user-images.githubusercontent.com/7107196/179459190-d736937a-6925-4472-b46e-dcf94e1cafc0.mp4
-
-Note that the first shot change is not as problematic.
diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md
deleted file mode 100644
index aaa4295..0000000
--- a/docs/GETTING_STARTED.md
+++ /dev/null
@@ -1,64 +0,0 @@
-# Getting Started
-
-Our code is tested on Ubuntu. I have briefly tested the GUI on Windows (with a PyQt5 fix in the heading of interactive_demo.py).
-
-## Requirements
-
-* Python 3.8+
-* PyTorch 1.11+ (See [PyTorch](https://pytorch.org/) for installation instructions)
-* `torchvision` corresponding to the PyTorch version
-* OpenCV (try `pip install opencv-python`)
-* Others: `pip install -r requirements.txt`
-
-## Dataset
-
-I recommend either softlinking (`ln -s`) existing data or use the provided `scripts/download_datasets.py` to structure the datasets as our format.
-
-`python -m scripts.download_dataset`
-
-The structure is the same as the one in STCN -- you can place XMem in the same folder as STCN and it will work.
-The script uses Google Drive and sometimes fails when certain files are blocked from automatic download. You would have to do some manual work in that case.
-It does not download BL30K because it is huge and we don't want to crash your harddisks.
-
-```bash
-├── XMem
-├── BL30K
-├── DAVIS
-│ ├── 2016
-│ │ ├── Annotations
-│ │ └── ...
-│ └── 2017
-│ ├── test-dev
-│ │ ├── Annotations
-│ │ └── ...
-│ └── trainval
-│ ├── Annotations
-│ └── ...
-├── static
-│ ├── BIG_small
-│ └── ...
-├── long_video_set
-│ ├── long_video
-│ ├── long_video_x3
-│ ├── long_video_davis
-│ └── ...
-├── YouTube
-│ ├── all_frames
-│ │ └── valid_all_frames
-│ ├── train
-│ ├── train_480p
-│ └── valid
-└── YouTube2018
- ├── all_frames
- │ └── valid_all_frames
- └── valid
-```
-
-## Long-Time Video
-
-It comes from [AFB-URR](https://github.com/xmlyqing00/AFB-URR). Please following their license when using this data. We release our extended version (X3) and corresponding `_davis` versions such that the DAVIS evaluation can be used directly. They can be downloaded [[here]](TODO). The script above would also attempt to download it.
-
-### BL30K
-
-You can either use the automatic script `download_bl30k.py` or download it manually from [MiVOS](https://github.com/hkchengrex/MiVOS/#bl30k). Note that each segment is about 115GB in size -- 700GB in total. You are going to need ~1TB of free disk space to run the script (including extraction buffer).
-The script uses Google Drive and sometimes fails when certain files are blocked from automatic download. You would have to do some manual work in that case.
diff --git a/docs/INFERENCE.md b/docs/INFERENCE.md
deleted file mode 100644
index 86c34fb..0000000
--- a/docs/INFERENCE.md
+++ /dev/null
@@ -1,110 +0,0 @@
-# Inference
-
-What is palette? Why is the output a "colored image"? How do I make those input masks that look like color images? See [PALETTE.md](./PALETTE.md).
-
-1. Set up the datasets following [GETTING_STARTED.md](./GETTING_STARTED.md).
-2. Download the pretrained models either using `./scripts/download_models.sh`, or manually and put them in `./saves` (create the folder if it doesn't exist). You can download them from [[GitHub]](https://github.com/hkchengrex/XMem/releases/tag/v1.0) or [[Google Drive]](https://drive.google.com/drive/folders/1QYsog7zNzcxGXTGBzEhMUg8QVJwZB6D1?usp=sharing).
-
-All command-line inference are accessed with `eval.py`. See [RESULTS.md](./RESULTS.md) for an explanation of FPS and the differences between different models.
-
-## Usage
-
-```
-python eval.py --model [path to model file] --output [where to save the output] --dataset [which dataset to evaluate on] --split [val for validation or test for test-dev]
-```
-
-See the code for a complete list of available command-line arguments.
-
-Examples:
-(``--model`` defaults to `./saves/XMem.pth`)
-
-DAVIS 2017 validation:
-
-```
-python eval.py --output ../output/d17 --dataset D17
-```
-
-DAVIS 2016 validation:
-
-```
-python eval.py --output ../output/d16 --dataset D16
-```
-
-DAVIS 2017 test-dev:
-
-```
-python eval.py --output ../output/d17-td --dataset D17 --split test
-```
-
-YouTubeVOS 2018 validation:
-
-```
-python eval.py --output ../output/y18 --dataset Y18
-```
-
-Long-Time Video (3X) (note that `mem_every`, aka `r`, is set differently):
-
-```
-python eval.py --output ../output/lv3 --dataset LV3 --mem_every 10
-```
-
-## Getting quantitative results
-
-We do not provide any tools for getting quantitative results here. We used the followings to get the results reported in the paper:
-
-- DAVIS 2017 validation: [davis2017-evaluation](https://github.com/davisvideochallenge/davis2017-evaluation)
-- DAVIS 2016 validation: [davis2016-evaluation](https://github.com/hkchengrex/davis2016-evaluation) (Unofficial)
-- DAVIS 2017 test-dev: [CodaLab](https://competitions.codalab.org/competitions/20516#participate)
-- YouTubeVOS 2018 validation: [CodaLab](https://competitions.codalab.org/competitions/19544#results)
-- YouTubeVOS 2019 validation: [CodaLab](https://competitions.codalab.org/competitions/20127#participate-submit_results)
-- Long-Time Video: [davis2017-evaluation](https://github.com/davisvideochallenge/davis2017-evaluation)
-
-(For the Long-Time Video dataset, point `--davis_path` to either `long_video_davis` or `long_video_davis_x3`)
-
-## On custom data
-
-Structure your custom data like this:
-
-```bash
-├── custom_data_root
-│ ├── JPEGImages
-│ │ ├── video1
-│ │ │ ├── 00001.jpg
-│ │ │ ├── 00002.jpg
-│ │ │ ├── ...
-│ │ └── ...
-│ ├── Annotations
-│ │ ├── video1
-│ │ │ ├── 00001.png
-│ │ │ ├── ...
-│ │ └── ...
-```
-
-We use `sort` to determine frame order. The annotations do not have have to be complete (e.g., first-frame only is fine). We use PIL to read the annotations and `np.unique` to determine objects. PNG palette will be used automatically if exists.
-
-Then, point `--generic_path` to `custom_data_root` and specify `--dataset` as `G` (for generic).
-
-## Multi-scale evaluation
-
-Multi-scale evaluation is done in two steps. We first compute and save the object probabilities maps for different settings independently on hard-disks as `hkl` (hickle) files. Then, these maps are merged together with `merge_multi_score.py`.
-
-Example for DAVIS 2017 validation MS:
-
-Step 1 (can be done in parallel with multiple GPUs):
-
-```bash
-python eval.py --output ../output/d17_ms/720p --mem_every 3 --dataset D17 --save_scores --size 720
-python eval.py --output ../output/d17_ms/720p_flip --mem_every 3 --dataset D17 --save_scores --size 720 --flip
-```
-
-Step 2:
-
-```bash
-python merge_multi_scale.py --dataset D --list ../output/d17_ms/720p ../output/d17_ms/720p_flip --output ../output/d17_ms_merged
-```
-
-Instead of `--list`, you can also use `--pattern` to specify a glob pattern. It also depends on your shell (e.g., `zsh` or `bash`).
-
-## Advanced usage
-
-To develop your own evaluation interface, see `./inference/` -- most importantly, `inference_core.py`.
diff --git a/docs/PALETTE.md b/docs/PALETTE.md
deleted file mode 100644
index acd5333..0000000
--- a/docs/PALETTE.md
+++ /dev/null
@@ -1,13 +0,0 @@
-# Palette
-
-> Some image formats, such as GIF or PNG, can use a palette, which is a table of (usually) 256 colors to allow for better compression. Basically, instead of representing each pixel with its full color triplet, which takes 24bits (plus eventual 8 more for transparency), they use a 8 bit index that represent the position inside the palette, and thus the color.
--- https://docs.geoserver.org/2.22.x/en/user/tutorials/palettedimage/palettedimage.html
-
-So those mask files that look like color images are single-channel, `uint8` arrays under the hood. When `PIL` reads them, it (correctly) gives you a two-dimensional array (`opencv` does not work AFAIK). If what you get is instead of three-dimensional, `H*W*3` array, then your mask is not actually a paletted mask, but just a colored image. Reading and saving a paletted mask through `opencv` or MS Paint would destroy the palette.
-
-Our code, when asked to generate multi-object segmentation (e.g., DAVIS 2017/YouTubeVOS), always reads and writes single-channel mask. If there is a palette in the input, we will use it in the output. The code does not care whether a palette is actually used -- we can read grayscale images just fine.
-
-Importantly, we use `np.unique` to determine the number of objects in the mask. This would fail if:
-
-1. Colored images, instead of paletted masks are used.
-2. The masks have "smooth" edges, produced by feathering/downsizing/compression. For example, when you draw the mask in a painting software, make sure you set the brush hardness to maximum.
diff --git a/docs/RESULTS.md b/docs/RESULTS.md
deleted file mode 100644
index abef0c5..0000000
--- a/docs/RESULTS.md
+++ /dev/null
@@ -1,104 +0,0 @@
-# Results
-
-## Preamble
-
-Our code, by default, uses automatic mixed precision (AMP). Its effect on the output is negligible.
-All speeds reported in the paper are recorded with AMP turned off (`--benchmark`).
-Due to refactoring, there might be slight differences between the outputs produced by this code base with the precomputed results/results reported in the paper. This difference rarely leads to a change of the least significant figure (i.e., 0.1).
-
-**For most complete results, please see the paper (and the appendix)!**
-
-All available precomputed results can be found [[here]](https://drive.google.com/drive/folders/1UxHPXJbQLHjF5zYVn3XZCXfi_NYL81Bf?usp=sharing).
-
-## Pretrained models
-
-We provide four pretrained models for download:
-
-1. XMem.pth (Default)
-2. XMem-s012.pth (Trained with BL30K)
-3. XMem-s2.pth (No pretraining on static images)
-4. XMem-no-sensory (No sensory memory)
-
-The model without pretraining is for reference. The model without sensory memory might be more suitable for tasks without spatial continuity, like mask tracking in a multi-camera 3D reconstruction setting, though I would encourage you to try the base model as well.
-
-Download them from [[GitHub]](https://github.com/hkchengrex/XMem/releases/tag/v1.0) or [[Google Drive]](https://drive.google.com/drive/folders/1QYsog7zNzcxGXTGBzEhMUg8QVJwZB6D1?usp=sharing).
-
-## Long-Time Video
-
-[[Precomputed Results]](https://drive.google.com/drive/folders/1NADcetigH6d83mUvyb2rH4VVjwFA76Lh?usp=sharing)
-
-### Long-Time Video (1X)
-
-| Model | J&F | J | F |
-| --- | :--:|:--:|:---:|
-| XMem | 89.8±0.2 | 88.0±0.2 | 91.6±0.2 |
-
-### Long-Time Video (3X)
-
-| Model | J&F | J | F |
-| --- | :--:|:--:|:---:|
-| XMem | 90.0±0.4 | 88.2±0.3 | 91.8±0.4 |
-
-## DAVIS
-
-[[Precomputed Results]](https://drive.google.com/drive/folders/1XTOGevTedRSjHnFVsZyTdxJG-iHjO0Re?usp=sharing)
-
-### DAVIS 2016
-
-| Model | J&F | J | F | FPS | FPS (AMP) |
-| --- | :--:|:--:|:---:|:---:|:---:|
-| XMem | 91.5 | 90.4 | 92.7 | 29.6 | 40.3 |
-| XMem-s012 | 92.0 | 90.7 | 93.2 | 29.6 | 40.3 |
-| XMem-s2 | 90.8 | 89.6 | 91.9 | 29.6 | 40.3 |
-
-### DAVIS 2017 validation
-
-| Model | J&F | J | F | FPS | FPS (AMP) |
-| --- | :--:|:--:|:---:|:---:|:---:|
-| XMem | 86.2 | 82.9 | 89.5 | 22.6 | 33.9 |
-| XMem-s012 | 87.7 | 84.0 | 91.4 | 22.6 | 33.9 |
-| XMem-s2 | 84.5 | 81.4 | 87.6 | 22.6 | 33.9 |
-| XMem-no-sensory | 85.1 | - | - | 23.1 | - |
-
-### DAVIS 2017 test-dev
-
-| Model | J&F | J | F |
-| --- | :--:|:--:|:---:|
-| XMem | 81.0 | 77.4 | 84.5 |
-| XMem-s012 | 81.2 | 77.6 | 84.7 |
-| XMem-s2 | 79.8 | 61.4 | 68.1 |
-| XMem-s012 (600p) | 82.5 | 79.1 | 85.8 |
-
-## YouTubeVOS
-
-We use all available frames in YouTubeVOS by default.
-See [INFERENCE.md](./INFERENCE.md) if you want to evaluate with sparse frames for some reason.
-
-[[Precomputed Results]](https://drive.google.com/drive/folders/1P_BmOdcG6OP5mWGqWzCZrhQJ7AaLME4E?usp=sharing)
-
-[[Precomputed Results (sparse)]](https://drive.google.com/drive/folders/1IRV1fHepufUXM45EEbtl9D4pkoh9POSZ?usp=sharing)
-
-### YouTubeVOS 2018 validation
-
-| Model | G | J-Seen | F-Seen | J-Unseen | F-Unseen | FPS | FPS (AMP) |
-| --- | :--:|:--:|:---:|:---:|:---:|:---:|:---:|
-| XMem | 85.7 | 84.6 | 89.3 | 80.2 | 88.7 | 22.6 | 31.7 |
-| XMem-s012 | 86.1 | 85.1 | 89.8 | 80.3 | 89.2 | 22.6 | 31.7 |
-| XMem-s2 | 84.3 | 83.9 | 88.8 | 77.7 | 86.7 | 22.6 | 31.7 |
-| XMem-no-sensory | 84.4 | - | - | - | - | 23.1 | - |
-
-### YouTubeVOS 2019 validation
-
-| Model | G | J-Seen | F-Seen | J-Unseen | F-Unseen |
-| --- | :--:|:--:|:---:|:---:|:---:|
-| XMem | 85.5 | 84.3 | 88.6 | 80.3 | 88.6 |
-| XMem-s012 | 85.8 | 84.8 | 89.2 | 80.3 | 88.8 |
-| XMem-s2 | 84.2 | 83.8 | 88.3 | 78.1 | 86.7 |
-
-## Multi-scale evaluation
-
-Please see the appendix for quantitative results.
-
-[[DAVIS-MS Precomputed Results]](https://drive.google.com/drive/folders/1H3VHKDO09izp6KR3sE-LzWbjyM-jpftn?usp=sharing)
-
-[[YouTubeVOS-MS Precomputed Results]](https://drive.google.com/drive/folders/1ww5HVRbMKXraLd2dy1rtk6kLjEawW9Kn?usp=sharing)
diff --git a/docs/TRAINING.md b/docs/TRAINING.md
deleted file mode 100644
index 526dfb7..0000000
--- a/docs/TRAINING.md
+++ /dev/null
@@ -1,50 +0,0 @@
-# Training
-
-First, set up the datasets following [GETTING STARTED.md](./GETTING_STARTED.md).
-
-The model is trained progressively with different stages (0: static images; 1: BL30K; 2: longer main training; 3: shorter main training). After each stage finishes, we start the next stage by loading the latest trained weight.
-For example, the base model is pretrained with static images followed by the shorter main training (s03).
-
-To train the base model on two GPUs, you can use:
-
-```bash
-python -m torch.distributed.run --master_port 25763 --nproc_per_node=2 train.py --exp_id retrain --stage 03
-```
-(**NOTE**: Unexplained accuracy decrease might occur if you are not using two GPUs to train. See https://github.com/hkchengrex/XMem/issues/71.)
-
-`master_port` needs to point to an unused port.
-`nproc_per_node` refers to the number of GPUs to be used (specify `CUDA_VISIBLE_DEVICES` to select which GPUs to use).
-`exp_id` is an identifier you give to this training job.
-
-See other available command line arguments in `util/configuration.py`.
-**Unlike the training code of STCN, batch sizes are effective. You don't have to adjust the batch size when you use more/fewer GPUs.**
-
-We implemented automatic staging in this code base. You don't have to train different stages by yourself like in STCN (but that is still supported).
-`stage` is a string that we split to determine the training stages. Examples include `0` (static images only), `03` (base training), `012` (with BL30K), `2` (main training only).
-
-You can use `tensorboard` to visualize the training process.
-
-## Outputs
-
-The model files and checkpoints will be saved in `./saves/[name containing datetime and exp_id]`.
-
-`.pth` files with `_checkpoint` store the network weights, optimizer states, etc. and can be used to resume training (with `--load_checkpoint`).
-
-Other `.pth` files store the network weights only and can be used for inference. We note that there are variations in performance across different training runs and across the last few saved models. For the base model, we most often note that main training at 107K iterations leads to the best result (full training is 110K).
-
-We measure the median and std scores across five training runs of the base model:
-
-| Dataset | median | std |
-| --- | :--:|:--:|
-| DAVIS J&F | 86.2 | 0.23 |
-| YouTubeVOS 2018 G | 85.6 | 0.21
-
-## Pretrained models
-
-You can start training from scratch, or use any of our pretrained models for fine-tuning. For example, you can load our stage 0 model to skip main training:
-
-```bash
-python -m torch.distributed.launch --master_port 25763 --nproc_per_node=2 train.py --exp_id retrain_stage3_only --stage 3 --load_network saves/XMem-s0.pth
-```
-
-Download them from [[GitHub]](https://github.com/hkchengrex/XMem/releases/tag/v1.0) or [[Google Drive]](https://drive.google.com/drive/folders/1QYsog7zNzcxGXTGBzEhMUg8QVJwZB6D1?usp=sharing).
diff --git a/docs/icon.png b/docs/icon.png
deleted file mode 100644
index 45309b2..0000000
Binary files a/docs/icon.png and /dev/null differ
diff --git a/docs/index.html b/docs/index.html
deleted file mode 100644
index 6179ee4..0000000
--- a/docs/index.html
+++ /dev/null
@@ -1,174 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
- XMem
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- XMem: Long-Term Video Object Segmentation with an Atkinson-Shiffrin Memory Model
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
Interactive GUI demo available [here] !
-
-
-
-
-
-
-
-
-
-
- We present XMem, a video object segmentation architecture for long videos with unified feature memory stores inspired by the Atkinson-Shiffrin memory model.
- Prior work on video object segmentation typically only uses one type of feature memory. For videos longer than a minute, a single feature memory model tightly links memory consumption and accuracy.
- In contrast, following the Atkinson-Shiffrin model, we develop an architecture that incorporates multiple independent yet deeply-connected feature memory stores: a rapidly updated sensory memory, a high-resolution working memory, and a compact thus sustained long-term memory.
- Crucially, we develop a memory potentiation algorithm that routinely consolidates actively used working memory elements into the long-term memory, which avoids memory explosion and minimizes performance decay for long-term prediction.
- Combined with a new memory reading mechanism, XMem greatly exceeds state-of-the-art performance on long-video datasets while being on par with state-of-the-art methods (that do not work on long videos) on short-video datasets.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- Handling long-term occlusion
-
-
-
-
- VIDEO
-
-
-
-
-
-
-
-
- Very-long video; masked layer insertion
-
-
-
-
- VIDEO
- Source: https://www.youtube.com/watch?v=q5Xr0F4a0iU
-
-
-
-
-
-
-
-
- Out-of-domain case
-
-
-
-
-
-
- Your browser does not support the video tag.
-
- Source: かぐや様は告らせたい ~天才たちの恋愛頭脳戦~ Ep.3; A1 Pictures
-
-
-
-
-
- Contact: Ho Kei (Rex) Cheng hkchengrex@gmail.com
-
-
-
-
-
-
-
-
-
diff --git a/docs/style.css b/docs/style.css
deleted file mode 100644
index 38f27ec..0000000
--- a/docs/style.css
+++ /dev/null
@@ -1,59 +0,0 @@
-body {
- font-family: 'Roboto', sans-serif;
- font-size:18px;
- margin-left: auto;
- margin-right: auto;
- font-weight: 300;
- height: 100%;
- max-width: 1000px;
-}
-
-.light {
- font-weight: 100;
-}
-
-.heavy {
- font-weight: 400;
-}
-
-.column {
- float: left;
-}
-
-.metric_table {
- border-collapse: collapse;
- margin-left: 15px;
- margin-right: auto;
-}
-
-.metric_table th{
- border-bottom: 1px solid #555;
- padding-left: 15px;
- padding-right: 15px;
-}
-
-.metric_table td{
- padding-left: 15px;
- padding-right: 15px;
-}
-
-.metric_table .left_align{
- text-align: left;
-}
-
-a:link,a:visited
-{
- color: #05538f;
- text-decoration: none;
-}
-
-a:hover {
- color: #63cbdd;
-}
-
-hr
-{
- border: 0;
- height: 1px;
- background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0));
-}
diff --git a/inference/interact/fbrs/LICENSE b/inference/interact/fbrs/LICENSE
deleted file mode 100644
index fa0086a..0000000
--- a/inference/interact/fbrs/LICENSE
+++ /dev/null
@@ -1,373 +0,0 @@
-Mozilla Public License Version 2.0
-==================================
-
-1. Definitions
---------------
-
-1.1. "Contributor"
- means each individual or legal entity that creates, contributes to
- the creation of, or owns Covered Software.
-
-1.2. "Contributor Version"
- means the combination of the Contributions of others (if any) used
- by a Contributor and that particular Contributor's Contribution.
-
-1.3. "Contribution"
- means Covered Software of a particular Contributor.
-
-1.4. "Covered Software"
- means Source Code Form to which the initial Contributor has attached
- the notice in Exhibit A, the Executable Form of such Source Code
- Form, and Modifications of such Source Code Form, in each case
- including portions thereof.
-
-1.5. "Incompatible With Secondary Licenses"
- means
-
- (a) that the initial Contributor has attached the notice described
- in Exhibit B to the Covered Software; or
-
- (b) that the Covered Software was made available under the terms of
- version 1.1 or earlier of the License, but not also under the
- terms of a Secondary License.
-
-1.6. "Executable Form"
- means any form of the work other than Source Code Form.
-
-1.7. "Larger Work"
- means a work that combines Covered Software with other material, in
- a separate file or files, that is not Covered Software.
-
-1.8. "License"
- means this document.
-
-1.9. "Licensable"
- means having the right to grant, to the maximum extent possible,
- whether at the time of the initial grant or subsequently, any and
- all of the rights conveyed by this License.
-
-1.10. "Modifications"
- means any of the following:
-
- (a) any file in Source Code Form that results from an addition to,
- deletion from, or modification of the contents of Covered
- Software; or
-
- (b) any new file in Source Code Form that contains any Covered
- Software.
-
-1.11. "Patent Claims" of a Contributor
- means any patent claim(s), including without limitation, method,
- process, and apparatus claims, in any patent Licensable by such
- Contributor that would be infringed, but for the grant of the
- License, by the making, using, selling, offering for sale, having
- made, import, or transfer of either its Contributions or its
- Contributor Version.
-
-1.12. "Secondary License"
- means either the GNU General Public License, Version 2.0, the GNU
- Lesser General Public License, Version 2.1, the GNU Affero General
- Public License, Version 3.0, or any later versions of those
- licenses.
-
-1.13. "Source Code Form"
- means the form of the work preferred for making modifications.
-
-1.14. "You" (or "Your")
- means an individual or a legal entity exercising rights under this
- License. For legal entities, "You" includes any entity that
- controls, is controlled by, or is under common control with You. For
- purposes of this definition, "control" means (a) the power, direct
- or indirect, to cause the direction or management of such entity,
- whether by contract or otherwise, or (b) ownership of more than
- fifty percent (50%) of the outstanding shares or beneficial
- ownership of such entity.
-
-2. License Grants and Conditions
---------------------------------
-
-2.1. Grants
-
-Each Contributor hereby grants You a world-wide, royalty-free,
-non-exclusive license:
-
-(a) under intellectual property rights (other than patent or trademark)
- Licensable by such Contributor to use, reproduce, make available,
- modify, display, perform, distribute, and otherwise exploit its
- Contributions, either on an unmodified basis, with Modifications, or
- as part of a Larger Work; and
-
-(b) under Patent Claims of such Contributor to make, use, sell, offer
- for sale, have made, import, and otherwise transfer either its
- Contributions or its Contributor Version.
-
-2.2. Effective Date
-
-The licenses granted in Section 2.1 with respect to any Contribution
-become effective for each Contribution on the date the Contributor first
-distributes such Contribution.
-
-2.3. Limitations on Grant Scope
-
-The licenses granted in this Section 2 are the only rights granted under
-this License. No additional rights or licenses will be implied from the
-distribution or licensing of Covered Software under this License.
-Notwithstanding Section 2.1(b) above, no patent license is granted by a
-Contributor:
-
-(a) for any code that a Contributor has removed from Covered Software;
- or
-
-(b) for infringements caused by: (i) Your and any other third party's
- modifications of Covered Software, or (ii) the combination of its
- Contributions with other software (except as part of its Contributor
- Version); or
-
-(c) under Patent Claims infringed by Covered Software in the absence of
- its Contributions.
-
-This License does not grant any rights in the trademarks, service marks,
-or logos of any Contributor (except as may be necessary to comply with
-the notice requirements in Section 3.4).
-
-2.4. Subsequent Licenses
-
-No Contributor makes additional grants as a result of Your choice to
-distribute the Covered Software under a subsequent version of this
-License (see Section 10.2) or under the terms of a Secondary License (if
-permitted under the terms of Section 3.3).
-
-2.5. Representation
-
-Each Contributor represents that the Contributor believes its
-Contributions are its original creation(s) or it has sufficient rights
-to grant the rights to its Contributions conveyed by this License.
-
-2.6. Fair Use
-
-This License is not intended to limit any rights You have under
-applicable copyright doctrines of fair use, fair dealing, or other
-equivalents.
-
-2.7. Conditions
-
-Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
-in Section 2.1.
-
-3. Responsibilities
--------------------
-
-3.1. Distribution of Source Form
-
-All distribution of Covered Software in Source Code Form, including any
-Modifications that You create or to which You contribute, must be under
-the terms of this License. You must inform recipients that the Source
-Code Form of the Covered Software is governed by the terms of this
-License, and how they can obtain a copy of this License. You may not
-attempt to alter or restrict the recipients' rights in the Source Code
-Form.
-
-3.2. Distribution of Executable Form
-
-If You distribute Covered Software in Executable Form then:
-
-(a) such Covered Software must also be made available in Source Code
- Form, as described in Section 3.1, and You must inform recipients of
- the Executable Form how they can obtain a copy of such Source Code
- Form by reasonable means in a timely manner, at a charge no more
- than the cost of distribution to the recipient; and
-
-(b) You may distribute such Executable Form under the terms of this
- License, or sublicense it under different terms, provided that the
- license for the Executable Form does not attempt to limit or alter
- the recipients' rights in the Source Code Form under this License.
-
-3.3. Distribution of a Larger Work
-
-You may create and distribute a Larger Work under terms of Your choice,
-provided that You also comply with the requirements of this License for
-the Covered Software. If the Larger Work is a combination of Covered
-Software with a work governed by one or more Secondary Licenses, and the
-Covered Software is not Incompatible With Secondary Licenses, this
-License permits You to additionally distribute such Covered Software
-under the terms of such Secondary License(s), so that the recipient of
-the Larger Work may, at their option, further distribute the Covered
-Software under the terms of either this License or such Secondary
-License(s).
-
-3.4. Notices
-
-You may not remove or alter the substance of any license notices
-(including copyright notices, patent notices, disclaimers of warranty,
-or limitations of liability) contained within the Source Code Form of
-the Covered Software, except that You may alter any license notices to
-the extent required to remedy known factual inaccuracies.
-
-3.5. Application of Additional Terms
-
-You may choose to offer, and to charge a fee for, warranty, support,
-indemnity or liability obligations to one or more recipients of Covered
-Software. However, You may do so only on Your own behalf, and not on
-behalf of any Contributor. You must make it absolutely clear that any
-such warranty, support, indemnity, or liability obligation is offered by
-You alone, and You hereby agree to indemnify every Contributor for any
-liability incurred by such Contributor as a result of warranty, support,
-indemnity or liability terms You offer. You may include additional
-disclaimers of warranty and limitations of liability specific to any
-jurisdiction.
-
-4. Inability to Comply Due to Statute or Regulation
----------------------------------------------------
-
-If it is impossible for You to comply with any of the terms of this
-License with respect to some or all of the Covered Software due to
-statute, judicial order, or regulation then You must: (a) comply with
-the terms of this License to the maximum extent possible; and (b)
-describe the limitations and the code they affect. Such description must
-be placed in a text file included with all distributions of the Covered
-Software under this License. Except to the extent prohibited by statute
-or regulation, such description must be sufficiently detailed for a
-recipient of ordinary skill to be able to understand it.
-
-5. Termination
---------------
-
-5.1. The rights granted under this License will terminate automatically
-if You fail to comply with any of its terms. However, if You become
-compliant, then the rights granted under this License from a particular
-Contributor are reinstated (a) provisionally, unless and until such
-Contributor explicitly and finally terminates Your grants, and (b) on an
-ongoing basis, if such Contributor fails to notify You of the
-non-compliance by some reasonable means prior to 60 days after You have
-come back into compliance. Moreover, Your grants from a particular
-Contributor are reinstated on an ongoing basis if such Contributor
-notifies You of the non-compliance by some reasonable means, this is the
-first time You have received notice of non-compliance with this License
-from such Contributor, and You become compliant prior to 30 days after
-Your receipt of the notice.
-
-5.2. If You initiate litigation against any entity by asserting a patent
-infringement claim (excluding declaratory judgment actions,
-counter-claims, and cross-claims) alleging that a Contributor Version
-directly or indirectly infringes any patent, then the rights granted to
-You by any and all Contributors for the Covered Software under Section
-2.1 of this License shall terminate.
-
-5.3. In the event of termination under Sections 5.1 or 5.2 above, all
-end user license agreements (excluding distributors and resellers) which
-have been validly granted by You or Your distributors under this License
-prior to termination shall survive termination.
-
-************************************************************************
-* *
-* 6. Disclaimer of Warranty *
-* ------------------------- *
-* *
-* Covered Software is provided under this License on an "as is" *
-* basis, without warranty of any kind, either expressed, implied, or *
-* statutory, including, without limitation, warranties that the *
-* Covered Software is free of defects, merchantable, fit for a *
-* particular purpose or non-infringing. The entire risk as to the *
-* quality and performance of the Covered Software is with You. *
-* Should any Covered Software prove defective in any respect, You *
-* (not any Contributor) assume the cost of any necessary servicing, *
-* repair, or correction. This disclaimer of warranty constitutes an *
-* essential part of this License. No use of any Covered Software is *
-* authorized under this License except under this disclaimer. *
-* *
-************************************************************************
-
-************************************************************************
-* *
-* 7. Limitation of Liability *
-* -------------------------- *
-* *
-* Under no circumstances and under no legal theory, whether tort *
-* (including negligence), contract, or otherwise, shall any *
-* Contributor, or anyone who distributes Covered Software as *
-* permitted above, be liable to You for any direct, indirect, *
-* special, incidental, or consequential damages of any character *
-* including, without limitation, damages for lost profits, loss of *
-* goodwill, work stoppage, computer failure or malfunction, or any *
-* and all other commercial damages or losses, even if such party *
-* shall have been informed of the possibility of such damages. This *
-* limitation of liability shall not apply to liability for death or *
-* personal injury resulting from such party's negligence to the *
-* extent applicable law prohibits such limitation. Some *
-* jurisdictions do not allow the exclusion or limitation of *
-* incidental or consequential damages, so this exclusion and *
-* limitation may not apply to You. *
-* *
-************************************************************************
-
-8. Litigation
--------------
-
-Any litigation relating to this License may be brought only in the
-courts of a jurisdiction where the defendant maintains its principal
-place of business and such litigation shall be governed by laws of that
-jurisdiction, without reference to its conflict-of-law provisions.
-Nothing in this Section shall prevent a party's ability to bring
-cross-claims or counter-claims.
-
-9. Miscellaneous
-----------------
-
-This License represents the complete agreement concerning the subject
-matter hereof. If any provision of this License is held to be
-unenforceable, such provision shall be reformed only to the extent
-necessary to make it enforceable. Any law or regulation which provides
-that the language of a contract shall be construed against the drafter
-shall not be used to construe this License against a Contributor.
-
-10. Versions of the License
----------------------------
-
-10.1. New Versions
-
-Mozilla Foundation is the license steward. Except as provided in Section
-10.3, no one other than the license steward has the right to modify or
-publish new versions of this License. Each version will be given a
-distinguishing version number.
-
-10.2. Effect of New Versions
-
-You may distribute the Covered Software under the terms of the version
-of the License under which You originally received the Covered Software,
-or under the terms of any subsequent version published by the license
-steward.
-
-10.3. Modified Versions
-
-If you create software not governed by this License, and you want to
-create a new license for such software, you may create and use a
-modified version of this License if you rename the license and remove
-any references to the name of the license steward (except to note that
-such modified license differs from this License).
-
-10.4. Distributing Source Code Form that is Incompatible With Secondary
-Licenses
-
-If You choose to distribute Source Code Form that is Incompatible With
-Secondary Licenses under the terms of this version of the License, the
-notice described in Exhibit B of this License must be attached.
-
-Exhibit A - Source Code Form License Notice
--------------------------------------------
-
- This Source Code Form is subject to the terms of the Mozilla Public
- License, v. 2.0. If a copy of the MPL was not distributed with this
- file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-If it is not possible or desirable to put the notice in a particular
-file, then You may include the notice in a location (such as a LICENSE
-file in a relevant directory) where a recipient would be likely to look
-for such a notice.
-
-You may add additional accurate notices of copyright ownership.
-
-Exhibit B - "Incompatible With Secondary Licenses" Notice
----------------------------------------------------------
-
- This Source Code Form is "Incompatible With Secondary Licenses", as
- defined by the Mozilla Public License, v. 2.0.
\ No newline at end of file
diff --git a/inference/interact/fbrs/controller.py b/inference/interact/fbrs/controller.py
deleted file mode 100644
index 57a0a9b..0000000
--- a/inference/interact/fbrs/controller.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import torch
-
-from ..fbrs.inference import clicker
-from ..fbrs.inference.predictors import get_predictor
-
-
-class InteractiveController:
- def __init__(self, net, device, predictor_params, prob_thresh=0.5):
- self.net = net.to(device)
- self.prob_thresh = prob_thresh
- self.clicker = clicker.Clicker()
- self.states = []
- self.probs_history = []
- self.object_count = 0
- self._result_mask = None
-
- self.image = None
- self.predictor = None
- self.device = device
- self.predictor_params = predictor_params
- self.reset_predictor()
-
- def set_image(self, image):
- self.image = image
- self._result_mask = torch.zeros(image.shape[-2:], dtype=torch.uint8)
- self.object_count = 0
- self.reset_last_object()
-
- def add_click(self, x, y, is_positive):
- self.states.append({
- 'clicker': self.clicker.get_state(),
- 'predictor': self.predictor.get_states()
- })
-
- click = clicker.Click(is_positive=is_positive, coords=(y, x))
- self.clicker.add_click(click)
- pred = self.predictor.get_prediction(self.clicker)
- torch.cuda.empty_cache()
-
- if self.probs_history:
- self.probs_history.append((self.probs_history[-1][0], pred))
- else:
- self.probs_history.append((torch.zeros_like(pred), pred))
-
- def undo_click(self):
- if not self.states:
- return
-
- prev_state = self.states.pop()
- self.clicker.set_state(prev_state['clicker'])
- self.predictor.set_states(prev_state['predictor'])
- self.probs_history.pop()
-
- def partially_finish_object(self):
- object_prob = self.current_object_prob
- if object_prob is None:
- return
-
- self.probs_history.append((object_prob, torch.zeros_like(object_prob)))
- self.states.append(self.states[-1])
-
- self.clicker.reset_clicks()
- self.reset_predictor()
-
- def finish_object(self):
- object_prob = self.current_object_prob
- if object_prob is None:
- return
-
- self.object_count += 1
- object_mask = object_prob > self.prob_thresh
- self._result_mask[object_mask] = self.object_count
- self.reset_last_object()
-
- def reset_last_object(self):
- self.states = []
- self.probs_history = []
- self.clicker.reset_clicks()
- self.reset_predictor()
-
- def reset_predictor(self, predictor_params=None):
- if predictor_params is not None:
- self.predictor_params = predictor_params
- self.predictor = get_predictor(self.net, device=self.device,
- **self.predictor_params)
- if self.image is not None:
- self.predictor.set_input_image(self.image)
-
- @property
- def current_object_prob(self):
- if self.probs_history:
- current_prob_total, current_prob_additive = self.probs_history[-1]
- return torch.maximum(current_prob_total, current_prob_additive)
- else:
- return None
-
- @property
- def is_incomplete_mask(self):
- return len(self.probs_history) > 0
-
- @property
- def result_mask(self):
- return self._result_mask.clone()
diff --git a/inference/interact/fbrs/inference/__init__.py b/inference/interact/fbrs/inference/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/inference/interact/fbrs/inference/clicker.py b/inference/interact/fbrs/inference/clicker.py
deleted file mode 100644
index e1ea9cf..0000000
--- a/inference/interact/fbrs/inference/clicker.py
+++ /dev/null
@@ -1,103 +0,0 @@
-from collections import namedtuple
-
-import numpy as np
-from copy import deepcopy
-from scipy.ndimage import distance_transform_edt
-
-Click = namedtuple('Click', ['is_positive', 'coords'])
-
-
-class Clicker(object):
- def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1):
- if gt_mask is not None:
- self.gt_mask = gt_mask == 1
- self.not_ignore_mask = gt_mask != ignore_label
- else:
- self.gt_mask = None
-
- self.reset_clicks()
-
- if init_clicks is not None:
- for click in init_clicks:
- self.add_click(click)
-
- def make_next_click(self, pred_mask):
- assert self.gt_mask is not None
- click = self._get_click(pred_mask)
- self.add_click(click)
-
- def get_clicks(self, clicks_limit=None):
- return self.clicks_list[:clicks_limit]
-
- def _get_click(self, pred_mask, padding=True):
- fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask)
- fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask)
-
- if padding:
- fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant')
- fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant')
-
- fn_mask_dt = distance_transform_edt(fn_mask)
- fp_mask_dt = distance_transform_edt(fp_mask)
-
- if padding:
- fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
- fp_mask_dt = fp_mask_dt[1:-1, 1:-1]
-
- fn_mask_dt = fn_mask_dt * self.not_clicked_map
- fp_mask_dt = fp_mask_dt * self.not_clicked_map
-
- fn_max_dist = np.max(fn_mask_dt)
- fp_max_dist = np.max(fp_mask_dt)
-
- is_positive = fn_max_dist > fp_max_dist
- if is_positive:
- coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist) # coords is [y, x]
- else:
- coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist) # coords is [y, x]
-
- return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0]))
-
- def add_click(self, click):
- coords = click.coords
-
- if click.is_positive:
- self.num_pos_clicks += 1
- else:
- self.num_neg_clicks += 1
-
- self.clicks_list.append(click)
- if self.gt_mask is not None:
- self.not_clicked_map[coords[0], coords[1]] = False
-
- def _remove_last_click(self):
- click = self.clicks_list.pop()
- coords = click.coords
-
- if click.is_positive:
- self.num_pos_clicks -= 1
- else:
- self.num_neg_clicks -= 1
-
- if self.gt_mask is not None:
- self.not_clicked_map[coords[0], coords[1]] = True
-
- def reset_clicks(self):
- if self.gt_mask is not None:
- self.not_clicked_map = np.ones_like(self.gt_mask, dtype=np.bool)
-
- self.num_pos_clicks = 0
- self.num_neg_clicks = 0
-
- self.clicks_list = []
-
- def get_state(self):
- return deepcopy(self.clicks_list)
-
- def set_state(self, state):
- self.reset_clicks()
- for click in state:
- self.add_click(click)
-
- def __len__(self):
- return len(self.clicks_list)
diff --git a/inference/interact/fbrs/inference/evaluation.py b/inference/interact/fbrs/inference/evaluation.py
deleted file mode 100644
index 6be3ed8..0000000
--- a/inference/interact/fbrs/inference/evaluation.py
+++ /dev/null
@@ -1,56 +0,0 @@
-from time import time
-
-import numpy as np
-import torch
-
-from ..inference import utils
-from ..inference.clicker import Clicker
-
-try:
- get_ipython()
- from tqdm import tqdm_notebook as tqdm
-except NameError:
- from tqdm import tqdm
-
-
-def evaluate_dataset(dataset, predictor, oracle_eval=False, **kwargs):
- all_ious = []
-
- start_time = time()
- for index in tqdm(range(len(dataset)), leave=False):
- sample = dataset.get_sample(index)
- item = dataset[index]
-
- if oracle_eval:
- gt_mask = torch.tensor(sample['instances_mask'], dtype=torch.float32)
- gt_mask = gt_mask.unsqueeze(0).unsqueeze(0)
- predictor.opt_functor.mask_loss.set_gt_mask(gt_mask)
- _, sample_ious, _ = evaluate_sample(item['images'], sample['instances_mask'], predictor, **kwargs)
- all_ious.append(sample_ious)
- end_time = time()
- elapsed_time = end_time - start_time
-
- return all_ious, elapsed_time
-
-
-def evaluate_sample(image_nd, instances_mask, predictor, max_iou_thr,
- pred_thr=0.49, max_clicks=20):
- clicker = Clicker(gt_mask=instances_mask)
- pred_mask = np.zeros_like(instances_mask)
- ious_list = []
-
- with torch.no_grad():
- predictor.set_input_image(image_nd)
-
- for click_number in range(max_clicks):
- clicker.make_next_click(pred_mask)
- pred_probs = predictor.get_prediction(clicker)
- pred_mask = pred_probs > pred_thr
-
- iou = utils.get_iou(instances_mask, pred_mask)
- ious_list.append(iou)
-
- if iou >= max_iou_thr:
- break
-
- return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs
diff --git a/inference/interact/fbrs/inference/predictors/__init__.py b/inference/interact/fbrs/inference/predictors/__init__.py
deleted file mode 100644
index 04b8b86..0000000
--- a/inference/interact/fbrs/inference/predictors/__init__.py
+++ /dev/null
@@ -1,95 +0,0 @@
-from .base import BasePredictor
-from .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor
-from .brs_functors import InputOptimizer, ScaleBiasOptimizer
-from ..transforms import ZoomIn
-from ...model.is_hrnet_model import DistMapsHRNetModel
-
-
-def get_predictor(net, brs_mode, device,
- prob_thresh=0.49,
- with_flip=True,
- zoom_in_params=dict(),
- predictor_params=None,
- brs_opt_func_params=None,
- lbfgs_params=None):
- lbfgs_params_ = {
- 'm': 20,
- 'factr': 0,
- 'pgtol': 1e-8,
- 'maxfun': 20,
- }
-
- predictor_params_ = {
- 'optimize_after_n_clicks': 1
- }
-
- if zoom_in_params is not None:
- zoom_in = ZoomIn(**zoom_in_params)
- else:
- zoom_in = None
-
- if lbfgs_params is not None:
- lbfgs_params_.update(lbfgs_params)
- lbfgs_params_['maxiter'] = 2 * lbfgs_params_['maxfun']
-
- if brs_opt_func_params is None:
- brs_opt_func_params = dict()
-
- if brs_mode == 'NoBRS':
- if predictor_params is not None:
- predictor_params_.update(predictor_params)
- predictor = BasePredictor(net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_)
- elif brs_mode.startswith('f-BRS'):
- predictor_params_.update({
- 'net_clicks_limit': 8,
- })
- if predictor_params is not None:
- predictor_params_.update(predictor_params)
-
- insertion_mode = {
- 'f-BRS-A': 'after_c4',
- 'f-BRS-B': 'after_aspp',
- 'f-BRS-C': 'after_deeplab'
- }[brs_mode]
-
- opt_functor = ScaleBiasOptimizer(prob_thresh=prob_thresh,
- with_flip=with_flip,
- optimizer_params=lbfgs_params_,
- **brs_opt_func_params)
-
- if isinstance(net, DistMapsHRNetModel):
- FeaturePredictor = HRNetFeatureBRSPredictor
- insertion_mode = {'after_c4': 'A', 'after_aspp': 'A', 'after_deeplab': 'C'}[insertion_mode]
- else:
- FeaturePredictor = FeatureBRSPredictor
-
- predictor = FeaturePredictor(net, device,
- opt_functor=opt_functor,
- with_flip=with_flip,
- insertion_mode=insertion_mode,
- zoom_in=zoom_in,
- **predictor_params_)
- elif brs_mode == 'RGB-BRS' or brs_mode == 'DistMap-BRS':
- use_dmaps = brs_mode == 'DistMap-BRS'
-
- predictor_params_.update({
- 'net_clicks_limit': 5,
- })
- if predictor_params is not None:
- predictor_params_.update(predictor_params)
-
- opt_functor = InputOptimizer(prob_thresh=prob_thresh,
- with_flip=with_flip,
- optimizer_params=lbfgs_params_,
- **brs_opt_func_params)
-
- predictor = InputBRSPredictor(net, device,
- optimize_target='dmaps' if use_dmaps else 'rgb',
- opt_functor=opt_functor,
- with_flip=with_flip,
- zoom_in=zoom_in,
- **predictor_params_)
- else:
- raise NotImplementedError
-
- return predictor
diff --git a/inference/interact/fbrs/inference/predictors/base.py b/inference/interact/fbrs/inference/predictors/base.py
deleted file mode 100644
index 3776506..0000000
--- a/inference/interact/fbrs/inference/predictors/base.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import torch
-import torch.nn.functional as F
-
-from ..transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide
-
-
-class BasePredictor(object):
- def __init__(self, net, device,
- net_clicks_limit=None,
- with_flip=False,
- zoom_in=None,
- max_size=None,
- **kwargs):
- self.net = net
- self.with_flip = with_flip
- self.net_clicks_limit = net_clicks_limit
- self.original_image = None
- self.device = device
- self.zoom_in = zoom_in
-
- self.transforms = [zoom_in] if zoom_in is not None else []
- if max_size is not None:
- self.transforms.append(LimitLongestSide(max_size=max_size))
- self.transforms.append(SigmoidForPred())
- if with_flip:
- self.transforms.append(AddHorizontalFlip())
-
- def set_input_image(self, image_nd):
- for transform in self.transforms:
- transform.reset()
- self.original_image = image_nd.to(self.device)
- if len(self.original_image.shape) == 3:
- self.original_image = self.original_image.unsqueeze(0)
-
- def get_prediction(self, clicker):
- clicks_list = clicker.get_clicks()
-
- image_nd, clicks_lists, is_image_changed = self.apply_transforms(
- self.original_image, [clicks_list]
- )
-
- pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed)
- prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True,
- size=image_nd.size()[2:])
-
- for t in reversed(self.transforms):
- prediction = t.inv_transform(prediction)
-
- if self.zoom_in is not None and self.zoom_in.check_possible_recalculation():
- print('zooming')
- return self.get_prediction(clicker)
-
- # return prediction.cpu().numpy()[0, 0]
- return prediction
-
- def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
- points_nd = self.get_points_nd(clicks_lists)
- return self.net(image_nd, points_nd)['instances']
-
- def _get_transform_states(self):
- return [x.get_state() for x in self.transforms]
-
- def _set_transform_states(self, states):
- assert len(states) == len(self.transforms)
- for state, transform in zip(states, self.transforms):
- transform.set_state(state)
-
- def apply_transforms(self, image_nd, clicks_lists):
- is_image_changed = False
- for t in self.transforms:
- image_nd, clicks_lists = t.transform(image_nd, clicks_lists)
- is_image_changed |= t.image_changed
-
- return image_nd, clicks_lists, is_image_changed
-
- def get_points_nd(self, clicks_lists):
- total_clicks = []
- num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
- num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
- num_max_points = max(num_pos_clicks + num_neg_clicks)
- if self.net_clicks_limit is not None:
- num_max_points = min(self.net_clicks_limit, num_max_points)
- num_max_points = max(1, num_max_points)
-
- for clicks_list in clicks_lists:
- clicks_list = clicks_list[:self.net_clicks_limit]
- pos_clicks = [click.coords for click in clicks_list if click.is_positive]
- pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1)]
-
- neg_clicks = [click.coords for click in clicks_list if not click.is_positive]
- neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1)]
- total_clicks.append(pos_clicks + neg_clicks)
-
- return torch.tensor(total_clicks, device=self.device)
-
- def get_states(self):
- return {'transform_states': self._get_transform_states()}
-
- def set_states(self, states):
- self._set_transform_states(states['transform_states'])
diff --git a/inference/interact/fbrs/inference/predictors/brs.py b/inference/interact/fbrs/inference/predictors/brs.py
deleted file mode 100644
index bfc7296..0000000
--- a/inference/interact/fbrs/inference/predictors/brs.py
+++ /dev/null
@@ -1,280 +0,0 @@
-import torch
-import torch.nn.functional as F
-import numpy as np
-from scipy.optimize import fmin_l_bfgs_b
-
-from .base import BasePredictor
-from ...model.is_hrnet_model import DistMapsHRNetModel
-
-
-class BRSBasePredictor(BasePredictor):
- def __init__(self, model, device, opt_functor, optimize_after_n_clicks=1, **kwargs):
- super().__init__(model, device, **kwargs)
- self.optimize_after_n_clicks = optimize_after_n_clicks
- self.opt_functor = opt_functor
-
- self.opt_data = None
- self.input_data = None
-
- def set_input_image(self, image_nd):
- super().set_input_image(image_nd)
- self.opt_data = None
- self.input_data = None
-
- def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1):
- pos_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32)
- neg_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32)
-
- for list_indx, clicks_list in enumerate(clicks_lists):
- for click in clicks_list:
- y, x = click.coords
- y, x = int(round(y)), int(round(x))
- y1, x1 = y - radius, x - radius
- y2, x2 = y + radius + 1, x + radius + 1
-
- if click.is_positive:
- pos_clicks_map[list_indx, 0, y1:y2, x1:x2] = True
- else:
- neg_clicks_map[list_indx, 0, y1:y2, x1:x2] = True
-
- with torch.no_grad():
- pos_clicks_map = torch.from_numpy(pos_clicks_map).to(self.device)
- neg_clicks_map = torch.from_numpy(neg_clicks_map).to(self.device)
-
- return pos_clicks_map, neg_clicks_map
-
- def get_states(self):
- return {'transform_states': self._get_transform_states(), 'opt_data': self.opt_data}
-
- def set_states(self, states):
- self._set_transform_states(states['transform_states'])
- self.opt_data = states['opt_data']
-
-
-class FeatureBRSPredictor(BRSBasePredictor):
- def __init__(self, model, device, opt_functor, insertion_mode='after_deeplab', **kwargs):
- super().__init__(model, device, opt_functor=opt_functor, **kwargs)
- self.insertion_mode = insertion_mode
- self._c1_features = None
-
- if self.insertion_mode == 'after_deeplab':
- self.num_channels = model.feature_extractor.ch
- elif self.insertion_mode == 'after_c4':
- self.num_channels = model.feature_extractor.aspp_in_channels
- elif self.insertion_mode == 'after_aspp':
- self.num_channels = model.feature_extractor.ch + 32
- else:
- raise NotImplementedError
-
- def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
- points_nd = self.get_points_nd(clicks_lists)
- pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
-
- num_clicks = len(clicks_lists[0])
- bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
-
- if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
- self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)
-
- if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
- self.input_data = self._get_head_input(image_nd, points_nd)
-
- def get_prediction_logits(scale, bias):
- scale = scale.view(bs, -1, 1, 1)
- bias = bias.view(bs, -1, 1, 1)
- if self.with_flip:
- scale = scale.repeat(2, 1, 1, 1)
- bias = bias.repeat(2, 1, 1, 1)
-
- scaled_backbone_features = self.input_data * scale
- scaled_backbone_features = scaled_backbone_features + bias
- if self.insertion_mode == 'after_c4':
- x = self.net.feature_extractor.aspp(scaled_backbone_features)
- x = F.interpolate(x, mode='bilinear', size=self._c1_features.size()[2:],
- align_corners=True)
- x = torch.cat((x, self._c1_features), dim=1)
- scaled_backbone_features = self.net.feature_extractor.head(x)
- elif self.insertion_mode == 'after_aspp':
- scaled_backbone_features = self.net.feature_extractor.head(scaled_backbone_features)
-
- pred_logits = self.net.head(scaled_backbone_features)
- pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear',
- align_corners=True)
- return pred_logits
-
- self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
- if num_clicks > self.optimize_after_n_clicks:
- opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data,
- **self.opt_functor.optimizer_params)
- self.opt_data = opt_result[0]
-
- with torch.no_grad():
- if self.opt_functor.best_prediction is not None:
- opt_pred_logits = self.opt_functor.best_prediction
- else:
- opt_data_nd = torch.from_numpy(self.opt_data).to(self.device)
- opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
- opt_pred_logits = get_prediction_logits(*opt_vars)
-
- return opt_pred_logits
-
- def _get_head_input(self, image_nd, points):
- with torch.no_grad():
- coord_features = self.net.dist_maps(image_nd, points)
- x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
- if self.insertion_mode == 'after_c4' or self.insertion_mode == 'after_aspp':
- c1, _, c3, c4 = self.net.feature_extractor.backbone(x)
- c1 = self.net.feature_extractor.skip_project(c1)
-
- if self.insertion_mode == 'after_aspp':
- x = self.net.feature_extractor.aspp(c4)
- x = F.interpolate(x, size=c1.size()[2:], mode='bilinear', align_corners=True)
- x = torch.cat((x, c1), dim=1)
- backbone_features = x
- else:
- backbone_features = c4
- self._c1_features = c1
- else:
- backbone_features = self.net.feature_extractor(x)[0]
-
- return backbone_features
-
-
-class HRNetFeatureBRSPredictor(BRSBasePredictor):
- def __init__(self, model, device, opt_functor, insertion_mode='A', **kwargs):
- super().__init__(model, device, opt_functor=opt_functor, **kwargs)
- self.insertion_mode = insertion_mode
- self._c1_features = None
-
- if self.insertion_mode == 'A':
- self.num_channels = sum(k * model.feature_extractor.width for k in [1, 2, 4, 8])
- elif self.insertion_mode == 'C':
- self.num_channels = 2 * model.feature_extractor.ocr_width
- else:
- raise NotImplementedError
-
- def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
- points_nd = self.get_points_nd(clicks_lists)
- pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
- num_clicks = len(clicks_lists[0])
- bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
-
- if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
- self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)
-
- if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
- self.input_data = self._get_head_input(image_nd, points_nd)
-
- def get_prediction_logits(scale, bias):
- scale = scale.view(bs, -1, 1, 1)
- bias = bias.view(bs, -1, 1, 1)
- if self.with_flip:
- scale = scale.repeat(2, 1, 1, 1)
- bias = bias.repeat(2, 1, 1, 1)
-
- scaled_backbone_features = self.input_data * scale
- scaled_backbone_features = scaled_backbone_features + bias
- if self.insertion_mode == 'A':
- out_aux = self.net.feature_extractor.aux_head(scaled_backbone_features)
- feats = self.net.feature_extractor.conv3x3_ocr(scaled_backbone_features)
-
- context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
- feats = self.net.feature_extractor.ocr_distri_head(feats, context)
- pred_logits = self.net.feature_extractor.cls_head(feats)
- elif self.insertion_mode == 'C':
- pred_logits = self.net.feature_extractor.cls_head(scaled_backbone_features)
- else:
- raise NotImplementedError
-
- pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear',
- align_corners=True)
- return pred_logits
-
- self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
- if num_clicks > self.optimize_after_n_clicks:
- opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data,
- **self.opt_functor.optimizer_params)
- self.opt_data = opt_result[0]
-
- with torch.no_grad():
- if self.opt_functor.best_prediction is not None:
- opt_pred_logits = self.opt_functor.best_prediction
- else:
- opt_data_nd = torch.from_numpy(self.opt_data).to(self.device)
- opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
- opt_pred_logits = get_prediction_logits(*opt_vars)
-
- return opt_pred_logits
-
- def _get_head_input(self, image_nd, points):
- with torch.no_grad():
- coord_features = self.net.dist_maps(image_nd, points)
- x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
- feats = self.net.feature_extractor.compute_hrnet_feats(x)
- if self.insertion_mode == 'A':
- backbone_features = feats
- elif self.insertion_mode == 'C':
- out_aux = self.net.feature_extractor.aux_head(feats)
- feats = self.net.feature_extractor.conv3x3_ocr(feats)
-
- context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
- backbone_features = self.net.feature_extractor.ocr_distri_head(feats, context)
- else:
- raise NotImplementedError
-
- return backbone_features
-
-
-class InputBRSPredictor(BRSBasePredictor):
- def __init__(self, model, device, opt_functor, optimize_target='rgb', **kwargs):
- super().__init__(model, device, opt_functor=opt_functor, **kwargs)
- self.optimize_target = optimize_target
-
- def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
- points_nd = self.get_points_nd(clicks_lists)
- pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
- num_clicks = len(clicks_lists[0])
-
- if self.opt_data is None or is_image_changed:
- opt_channels = 2 if self.optimize_target == 'dmaps' else 3
- bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
- self.opt_data = torch.zeros((bs, opt_channels, image_nd.shape[2], image_nd.shape[3]),
- device=self.device, dtype=torch.float32)
-
- def get_prediction_logits(opt_bias):
- input_image = image_nd
- if self.optimize_target == 'rgb':
- input_image = input_image + opt_bias
- dmaps = self.net.dist_maps(input_image, points_nd)
- if self.optimize_target == 'dmaps':
- dmaps = dmaps + opt_bias
-
- x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1))
- if self.optimize_target == 'all':
- x = x + opt_bias
-
- if isinstance(self.net, DistMapsHRNetModel):
- pred_logits = self.net.feature_extractor(x)[0]
- else:
- backbone_features = self.net.feature_extractor(x)
- pred_logits = self.net.head(backbone_features[0])
- pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', align_corners=True)
-
- return pred_logits
-
- self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device,
- shape=self.opt_data.shape)
- if num_clicks > self.optimize_after_n_clicks:
- opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data.cpu().numpy().ravel(),
- **self.opt_functor.optimizer_params)
-
- self.opt_data = torch.from_numpy(opt_result[0]).view(self.opt_data.shape).to(self.device)
-
- with torch.no_grad():
- if self.opt_functor.best_prediction is not None:
- opt_pred_logits = self.opt_functor.best_prediction
- else:
- opt_vars, _ = self.opt_functor.unpack_opt_params(self.opt_data)
- opt_pred_logits = get_prediction_logits(*opt_vars)
-
- return opt_pred_logits
diff --git a/inference/interact/fbrs/inference/predictors/brs_functors.py b/inference/interact/fbrs/inference/predictors/brs_functors.py
deleted file mode 100644
index 0e6eb90..0000000
--- a/inference/interact/fbrs/inference/predictors/brs_functors.py
+++ /dev/null
@@ -1,109 +0,0 @@
-import torch
-import numpy as np
-
-from ...model.metrics import _compute_iou
-from .brs_losses import BRSMaskLoss
-
-
-class BaseOptimizer:
- def __init__(self, optimizer_params,
- prob_thresh=0.49,
- reg_weight=1e-3,
- min_iou_diff=0.01,
- brs_loss=BRSMaskLoss(),
- with_flip=False,
- flip_average=False,
- **kwargs):
- self.brs_loss = brs_loss
- self.optimizer_params = optimizer_params
- self.prob_thresh = prob_thresh
- self.reg_weight = reg_weight
- self.min_iou_diff = min_iou_diff
- self.with_flip = with_flip
- self.flip_average = flip_average
-
- self.best_prediction = None
- self._get_prediction_logits = None
- self._opt_shape = None
- self._best_loss = None
- self._click_masks = None
- self._last_mask = None
- self.device = None
-
- def init_click(self, get_prediction_logits, pos_mask, neg_mask, device, shape=None):
- self.best_prediction = None
- self._get_prediction_logits = get_prediction_logits
- self._click_masks = (pos_mask, neg_mask)
- self._opt_shape = shape
- self._last_mask = None
- self.device = device
-
- def __call__(self, x):
- opt_params = torch.from_numpy(x).float().to(self.device)
- opt_params.requires_grad_(True)
-
- with torch.enable_grad():
- opt_vars, reg_loss = self.unpack_opt_params(opt_params)
- result_before_sigmoid = self._get_prediction_logits(*opt_vars)
- result = torch.sigmoid(result_before_sigmoid)
-
- pos_mask, neg_mask = self._click_masks
- if self.with_flip and self.flip_average:
- result, result_flipped = torch.chunk(result, 2, dim=0)
- result = 0.5 * (result + torch.flip(result_flipped, dims=[3]))
- pos_mask, neg_mask = pos_mask[:result.shape[0]], neg_mask[:result.shape[0]]
-
- loss, f_max_pos, f_max_neg = self.brs_loss(result, pos_mask, neg_mask)
- loss = loss + reg_loss
-
- f_val = loss.detach().cpu().numpy()
- if self.best_prediction is None or f_val < self._best_loss:
- self.best_prediction = result_before_sigmoid.detach()
- self._best_loss = f_val
-
- if f_max_pos < (1 - self.prob_thresh) and f_max_neg < self.prob_thresh:
- return [f_val, np.zeros_like(x)]
-
- current_mask = result > self.prob_thresh
- if self._last_mask is not None and self.min_iou_diff > 0:
- diff_iou = _compute_iou(current_mask, self._last_mask)
- if len(diff_iou) > 0 and diff_iou.mean() > 1 - self.min_iou_diff:
- return [f_val, np.zeros_like(x)]
- self._last_mask = current_mask
-
- loss.backward()
- f_grad = opt_params.grad.cpu().numpy().ravel().astype(np.float32)
-
- return [f_val, f_grad]
-
- def unpack_opt_params(self, opt_params):
- raise NotImplementedError
-
-
-class InputOptimizer(BaseOptimizer):
- def unpack_opt_params(self, opt_params):
- opt_params = opt_params.view(self._opt_shape)
- if self.with_flip:
- opt_params_flipped = torch.flip(opt_params, dims=[3])
- opt_params = torch.cat([opt_params, opt_params_flipped], dim=0)
- reg_loss = self.reg_weight * torch.sum(opt_params**2)
-
- return (opt_params,), reg_loss
-
-
-class ScaleBiasOptimizer(BaseOptimizer):
- def __init__(self, *args, scale_act=None, reg_bias_weight=10.0, **kwargs):
- super().__init__(*args, **kwargs)
- self.scale_act = scale_act
- self.reg_bias_weight = reg_bias_weight
-
- def unpack_opt_params(self, opt_params):
- scale, bias = torch.chunk(opt_params, 2, dim=0)
- reg_loss = self.reg_weight * (torch.sum(scale**2) + self.reg_bias_weight * torch.sum(bias**2))
-
- if self.scale_act == 'tanh':
- scale = torch.tanh(scale)
- elif self.scale_act == 'sin':
- scale = torch.sin(scale)
-
- return (1 + scale, bias), reg_loss
diff --git a/inference/interact/fbrs/inference/predictors/brs_losses.py b/inference/interact/fbrs/inference/predictors/brs_losses.py
deleted file mode 100644
index 0d9998a..0000000
--- a/inference/interact/fbrs/inference/predictors/brs_losses.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import torch
-
-from ...model.losses import SigmoidBinaryCrossEntropyLoss
-
-
-class BRSMaskLoss(torch.nn.Module):
- def __init__(self, eps=1e-5):
- super().__init__()
- self._eps = eps
-
- def forward(self, result, pos_mask, neg_mask):
- pos_diff = (1 - result) * pos_mask
- pos_target = torch.sum(pos_diff ** 2)
- pos_target = pos_target / (torch.sum(pos_mask) + self._eps)
-
- neg_diff = result * neg_mask
- neg_target = torch.sum(neg_diff ** 2)
- neg_target = neg_target / (torch.sum(neg_mask) + self._eps)
-
- loss = pos_target + neg_target
-
- with torch.no_grad():
- f_max_pos = torch.max(torch.abs(pos_diff)).item()
- f_max_neg = torch.max(torch.abs(neg_diff)).item()
-
- return loss, f_max_pos, f_max_neg
-
-
-class OracleMaskLoss(torch.nn.Module):
- def __init__(self):
- super().__init__()
- self.gt_mask = None
- self.loss = SigmoidBinaryCrossEntropyLoss(from_sigmoid=True)
- self.predictor = None
- self.history = []
-
- def set_gt_mask(self, gt_mask):
- self.gt_mask = gt_mask
- self.history = []
-
- def forward(self, result, pos_mask, neg_mask):
- gt_mask = self.gt_mask.to(result.device)
- if self.predictor.object_roi is not None:
- r1, r2, c1, c2 = self.predictor.object_roi[:4]
- gt_mask = gt_mask[:, :, r1:r2 + 1, c1:c2 + 1]
- gt_mask = torch.nn.functional.interpolate(gt_mask, result.size()[2:], mode='bilinear', align_corners=True)
-
- if result.shape[0] == 2:
- gt_mask_flipped = torch.flip(gt_mask, dims=[3])
- gt_mask = torch.cat([gt_mask, gt_mask_flipped], dim=0)
-
- loss = self.loss(result, gt_mask)
- self.history.append(loss.detach().cpu().numpy()[0])
-
- if len(self.history) > 5 and abs(self.history[-5] - self.history[-1]) < 1e-5:
- return 0, 0, 0
-
- return loss, 1.0, 1.0
diff --git a/inference/interact/fbrs/inference/transforms/__init__.py b/inference/interact/fbrs/inference/transforms/__init__.py
deleted file mode 100644
index cbd54e3..0000000
--- a/inference/interact/fbrs/inference/transforms/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .base import SigmoidForPred
-from .flip import AddHorizontalFlip
-from .zoom_in import ZoomIn
-from .limit_longest_side import LimitLongestSide
-from .crops import Crops
diff --git a/inference/interact/fbrs/inference/transforms/base.py b/inference/interact/fbrs/inference/transforms/base.py
deleted file mode 100644
index eb5a2de..0000000
--- a/inference/interact/fbrs/inference/transforms/base.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import torch
-
-
-class BaseTransform(object):
- def __init__(self):
- self.image_changed = False
-
- def transform(self, image_nd, clicks_lists):
- raise NotImplementedError
-
- def inv_transform(self, prob_map):
- raise NotImplementedError
-
- def reset(self):
- raise NotImplementedError
-
- def get_state(self):
- raise NotImplementedError
-
- def set_state(self, state):
- raise NotImplementedError
-
-
-class SigmoidForPred(BaseTransform):
- def transform(self, image_nd, clicks_lists):
- return image_nd, clicks_lists
-
- def inv_transform(self, prob_map):
- return torch.sigmoid(prob_map)
-
- def reset(self):
- pass
-
- def get_state(self):
- return None
-
- def set_state(self, state):
- pass
diff --git a/inference/interact/fbrs/inference/transforms/crops.py b/inference/interact/fbrs/inference/transforms/crops.py
deleted file mode 100644
index 0910a28..0000000
--- a/inference/interact/fbrs/inference/transforms/crops.py
+++ /dev/null
@@ -1,97 +0,0 @@
-import math
-
-import torch
-import numpy as np
-
-from ...inference.clicker import Click
-from .base import BaseTransform
-
-
-class Crops(BaseTransform):
- def __init__(self, crop_size=(320, 480), min_overlap=0.2):
- super().__init__()
- self.crop_height, self.crop_width = crop_size
- self.min_overlap = min_overlap
-
- self.x_offsets = None
- self.y_offsets = None
- self._counts = None
-
- def transform(self, image_nd, clicks_lists):
- assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
- image_height, image_width = image_nd.shape[2:4]
- self._counts = None
-
- if image_height < self.crop_height or image_width < self.crop_width:
- return image_nd, clicks_lists
-
- self.x_offsets = get_offsets(image_width, self.crop_width, self.min_overlap)
- self.y_offsets = get_offsets(image_height, self.crop_height, self.min_overlap)
- self._counts = np.zeros((image_height, image_width))
-
- image_crops = []
- for dy in self.y_offsets:
- for dx in self.x_offsets:
- self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1
- image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width]
- image_crops.append(image_crop)
- image_crops = torch.cat(image_crops, dim=0)
- self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32)
-
- clicks_list = clicks_lists[0]
- clicks_lists = []
- for dy in self.y_offsets:
- for dx in self.x_offsets:
- crop_clicks = [Click(is_positive=x.is_positive, coords=(x.coords[0] - dy, x.coords[1] - dx))
- for x in clicks_list]
- clicks_lists.append(crop_clicks)
-
- return image_crops, clicks_lists
-
- def inv_transform(self, prob_map):
- if self._counts is None:
- return prob_map
-
- new_prob_map = torch.zeros((1, 1, *self._counts.shape),
- dtype=prob_map.dtype, device=prob_map.device)
-
- crop_indx = 0
- for dy in self.y_offsets:
- for dx in self.x_offsets:
- new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0]
- crop_indx += 1
- new_prob_map = torch.div(new_prob_map, self._counts)
-
- return new_prob_map
-
- def get_state(self):
- return self.x_offsets, self.y_offsets, self._counts
-
- def set_state(self, state):
- self.x_offsets, self.y_offsets, self._counts = state
-
- def reset(self):
- self.x_offsets = None
- self.y_offsets = None
- self._counts = None
-
-
-def get_offsets(length, crop_size, min_overlap_ratio=0.2):
- if length == crop_size:
- return [0]
-
- N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio)
- N = math.ceil(N)
-
- overlap_ratio = (N - length / crop_size) / (N - 1)
- overlap_width = int(crop_size * overlap_ratio)
-
- offsets = [0]
- for i in range(1, N):
- new_offset = offsets[-1] + crop_size - overlap_width
- if new_offset + crop_size > length:
- new_offset = length - crop_size
-
- offsets.append(new_offset)
-
- return offsets
diff --git a/inference/interact/fbrs/inference/transforms/flip.py b/inference/interact/fbrs/inference/transforms/flip.py
deleted file mode 100644
index c1543cb..0000000
--- a/inference/interact/fbrs/inference/transforms/flip.py
+++ /dev/null
@@ -1,37 +0,0 @@
-import torch
-
-from ..clicker import Click
-from .base import BaseTransform
-
-
-class AddHorizontalFlip(BaseTransform):
- def transform(self, image_nd, clicks_lists):
- assert len(image_nd.shape) == 4
- image_nd = torch.cat([image_nd, torch.flip(image_nd, dims=[3])], dim=0)
-
- image_width = image_nd.shape[3]
- clicks_lists_flipped = []
- for clicks_list in clicks_lists:
- clicks_list_flipped = [Click(is_positive=click.is_positive,
- coords=(click.coords[0], image_width - click.coords[1] - 1))
- for click in clicks_list]
- clicks_lists_flipped.append(clicks_list_flipped)
- clicks_lists = clicks_lists + clicks_lists_flipped
-
- return image_nd, clicks_lists
-
- def inv_transform(self, prob_map):
- assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0
- num_maps = prob_map.shape[0] // 2
- prob_map, prob_map_flipped = prob_map[:num_maps], prob_map[num_maps:]
-
- return 0.5 * (prob_map + torch.flip(prob_map_flipped, dims=[3]))
-
- def get_state(self):
- return None
-
- def set_state(self, state):
- pass
-
- def reset(self):
- pass
diff --git a/inference/interact/fbrs/inference/transforms/limit_longest_side.py b/inference/interact/fbrs/inference/transforms/limit_longest_side.py
deleted file mode 100644
index 50c5a53..0000000
--- a/inference/interact/fbrs/inference/transforms/limit_longest_side.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from .zoom_in import ZoomIn, get_roi_image_nd
-
-
-class LimitLongestSide(ZoomIn):
- def __init__(self, max_size=800):
- super().__init__(target_size=max_size, skip_clicks=0)
-
- def transform(self, image_nd, clicks_lists):
- assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
- image_max_size = max(image_nd.shape[2:4])
- self.image_changed = False
-
- if image_max_size <= self.target_size:
- return image_nd, clicks_lists
- self._input_image = image_nd
-
- self._object_roi = (0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1)
- self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size)
- self.image_changed = True
-
- tclicks_lists = [self._transform_clicks(clicks_lists[0])]
- return self._roi_image, tclicks_lists
diff --git a/inference/interact/fbrs/inference/transforms/zoom_in.py b/inference/interact/fbrs/inference/transforms/zoom_in.py
deleted file mode 100644
index 6c11ecc..0000000
--- a/inference/interact/fbrs/inference/transforms/zoom_in.py
+++ /dev/null
@@ -1,171 +0,0 @@
-import torch
-
-from ..clicker import Click
-from ...utils.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp_bbox
-from .base import BaseTransform
-
-
-class ZoomIn(BaseTransform):
- def __init__(self,
- target_size=400,
- skip_clicks=1,
- expansion_ratio=1.4,
- min_crop_size=200,
- recompute_thresh_iou=0.5,
- prob_thresh=0.50):
- super().__init__()
- self.target_size = target_size
- self.min_crop_size = min_crop_size
- self.skip_clicks = skip_clicks
- self.expansion_ratio = expansion_ratio
- self.recompute_thresh_iou = recompute_thresh_iou
- self.prob_thresh = prob_thresh
-
- self._input_image_shape = None
- self._prev_probs = None
- self._object_roi = None
- self._roi_image = None
-
- def transform(self, image_nd, clicks_lists):
- assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
- self.image_changed = False
-
- clicks_list = clicks_lists[0]
- if len(clicks_list) <= self.skip_clicks:
- return image_nd, clicks_lists
-
- self._input_image_shape = image_nd.shape
-
- current_object_roi = None
- if self._prev_probs is not None:
- current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
- if current_pred_mask.sum() > 0:
- current_object_roi = get_object_roi(current_pred_mask, clicks_list,
- self.expansion_ratio, self.min_crop_size)
-
- if current_object_roi is None:
- return image_nd, clicks_lists
-
- update_object_roi = False
- if self._object_roi is None:
- update_object_roi = True
- elif not check_object_roi(self._object_roi, clicks_list):
- update_object_roi = True
- elif get_bbox_iou(current_object_roi, self._object_roi) < self.recompute_thresh_iou:
- update_object_roi = True
-
- if update_object_roi:
- self._object_roi = current_object_roi
- self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size)
- self.image_changed = True
-
- tclicks_lists = [self._transform_clicks(clicks_list)]
- return self._roi_image.to(image_nd.device), tclicks_lists
-
- def inv_transform(self, prob_map):
- if self._object_roi is None:
- self._prev_probs = prob_map.cpu().numpy()
- return prob_map
-
- assert prob_map.shape[0] == 1
- rmin, rmax, cmin, cmax = self._object_roi
- prob_map = torch.nn.functional.interpolate(prob_map, size=(rmax - rmin + 1, cmax - cmin + 1),
- mode='bilinear', align_corners=True)
-
- if self._prev_probs is not None:
- new_prob_map = torch.zeros(*self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype)
- new_prob_map[:, :, rmin:rmax + 1, cmin:cmax + 1] = prob_map
- else:
- new_prob_map = prob_map
-
- self._prev_probs = new_prob_map.cpu().numpy()
-
- return new_prob_map
-
- def check_possible_recalculation(self):
- if self._prev_probs is None or self._object_roi is not None or self.skip_clicks > 0:
- return False
-
- pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
- if pred_mask.sum() > 0:
- possible_object_roi = get_object_roi(pred_mask, [],
- self.expansion_ratio, self.min_crop_size)
- image_roi = (0, self._input_image_shape[2] - 1, 0, self._input_image_shape[3] - 1)
- if get_bbox_iou(possible_object_roi, image_roi) < 0.50:
- return True
- return False
-
- def get_state(self):
- roi_image = self._roi_image.cpu() if self._roi_image is not None else None
- return self._input_image_shape, self._object_roi, self._prev_probs, roi_image, self.image_changed
-
- def set_state(self, state):
- self._input_image_shape, self._object_roi, self._prev_probs, self._roi_image, self.image_changed = state
-
- def reset(self):
- self._input_image_shape = None
- self._object_roi = None
- self._prev_probs = None
- self._roi_image = None
- self.image_changed = False
-
- def _transform_clicks(self, clicks_list):
- if self._object_roi is None:
- return clicks_list
-
- rmin, rmax, cmin, cmax = self._object_roi
- crop_height, crop_width = self._roi_image.shape[2:]
-
- transformed_clicks = []
- for click in clicks_list:
- new_r = crop_height * (click.coords[0] - rmin) / (rmax - rmin + 1)
- new_c = crop_width * (click.coords[1] - cmin) / (cmax - cmin + 1)
- transformed_clicks.append(Click(is_positive=click.is_positive, coords=(new_r, new_c)))
- return transformed_clicks
-
-
-def get_object_roi(pred_mask, clicks_list, expansion_ratio, min_crop_size):
- pred_mask = pred_mask.copy()
-
- for click in clicks_list:
- if click.is_positive:
- pred_mask[int(click.coords[0]), int(click.coords[1])] = 1
-
- bbox = get_bbox_from_mask(pred_mask)
- bbox = expand_bbox(bbox, expansion_ratio, min_crop_size)
- h, w = pred_mask.shape[0], pred_mask.shape[1]
- bbox = clamp_bbox(bbox, 0, h - 1, 0, w - 1)
-
- return bbox
-
-
-def get_roi_image_nd(image_nd, object_roi, target_size):
- rmin, rmax, cmin, cmax = object_roi
-
- height = rmax - rmin + 1
- width = cmax - cmin + 1
-
- if isinstance(target_size, tuple):
- new_height, new_width = target_size
- else:
- scale = target_size / max(height, width)
- new_height = int(round(height * scale))
- new_width = int(round(width * scale))
-
- with torch.no_grad():
- roi_image_nd = image_nd[:, :, rmin:rmax + 1, cmin:cmax + 1]
- roi_image_nd = torch.nn.functional.interpolate(roi_image_nd, size=(new_height, new_width),
- mode='bilinear', align_corners=True)
-
- return roi_image_nd
-
-
-def check_object_roi(object_roi, clicks_list):
- for click in clicks_list:
- if click.is_positive:
- if click.coords[0] < object_roi[0] or click.coords[0] >= object_roi[1]:
- return False
- if click.coords[1] < object_roi[2] or click.coords[1] >= object_roi[3]:
- return False
-
- return True
diff --git a/inference/interact/fbrs/inference/utils.py b/inference/interact/fbrs/inference/utils.py
deleted file mode 100644
index d1bec96..0000000
--- a/inference/interact/fbrs/inference/utils.py
+++ /dev/null
@@ -1,177 +0,0 @@
-from datetime import timedelta
-from pathlib import Path
-
-import torch
-import numpy as np
-
-from ..model.is_deeplab_model import get_deeplab_model
-from ..model.is_hrnet_model import get_hrnet_model
-
-
-def get_time_metrics(all_ious, elapsed_time):
- n_images = len(all_ious)
- n_clicks = sum(map(len, all_ious))
-
- mean_spc = elapsed_time / n_clicks
- mean_spi = elapsed_time / n_images
-
- return mean_spc, mean_spi
-
-
-def load_is_model(checkpoint, device, backbone='auto', **kwargs):
- if isinstance(checkpoint, (str, Path)):
- state_dict = torch.load(checkpoint, map_location='cpu')
- else:
- state_dict = checkpoint
-
- if backbone == 'auto':
- for k in state_dict.keys():
- if 'feature_extractor.stage2.0.branches' in k:
- return load_hrnet_is_model(state_dict, device, backbone, **kwargs)
- return load_deeplab_is_model(state_dict, device, backbone, **kwargs)
- elif 'resnet' in backbone:
- return load_deeplab_is_model(state_dict, device, backbone, **kwargs)
- elif 'hrnet' in backbone:
- return load_hrnet_is_model(state_dict, device, backbone, **kwargs)
- else:
- raise NotImplementedError('Unknown backbone')
-
-
-def load_hrnet_is_model(state_dict, device, backbone='auto', width=48, ocr_width=256,
- small=False, cpu_dist_maps=False, norm_radius=260):
- if backbone == 'auto':
- num_fe_weights = len([x for x in state_dict.keys() if 'feature_extractor.' in x])
- small = num_fe_weights < 1800
-
- ocr_f_down = [v for k, v in state_dict.items() if 'object_context_block.f_down.1.0.bias' in k]
- assert len(ocr_f_down) == 1
- ocr_width = ocr_f_down[0].shape[0]
-
- s2_conv1_w = [v for k, v in state_dict.items() if 'stage2.0.branches.0.0.conv1.weight' in k]
- assert len(s2_conv1_w) == 1
- width = s2_conv1_w[0].shape[0]
-
- model = get_hrnet_model(width=width, ocr_width=ocr_width, small=small,
- with_aux_output=False, cpu_dist_maps=cpu_dist_maps,
- norm_radius=norm_radius)
-
- model.load_state_dict(state_dict, strict=False)
- for param in model.parameters():
- param.requires_grad = False
- model.to(device)
- model.eval()
-
- return model
-
-
-def load_deeplab_is_model(state_dict, device, backbone='auto', deeplab_ch=128, aspp_dropout=0.2,
- cpu_dist_maps=False, norm_radius=260):
- if backbone == 'auto':
- num_backbone_params = len([x for x in state_dict.keys()
- if 'feature_extractor.backbone' in x and not('num_batches_tracked' in x)])
-
- if num_backbone_params <= 181:
- backbone = 'resnet34'
- elif num_backbone_params <= 276:
- backbone = 'resnet50'
- elif num_backbone_params <= 531:
- backbone = 'resnet101'
- else:
- raise NotImplementedError('Unknown backbone')
-
- if 'aspp_dropout' in state_dict:
- aspp_dropout = float(state_dict['aspp_dropout'].cpu().numpy())
- else:
- aspp_project_weight = [v for k, v in state_dict.items() if 'aspp.project.0.weight' in k][0]
- deeplab_ch = aspp_project_weight.size(0)
- if deeplab_ch == 256:
- aspp_dropout = 0.5
-
- model = get_deeplab_model(backbone=backbone, deeplab_ch=deeplab_ch,
- aspp_dropout=aspp_dropout, cpu_dist_maps=cpu_dist_maps,
- norm_radius=norm_radius)
-
- model.load_state_dict(state_dict, strict=False)
- for param in model.parameters():
- param.requires_grad = False
- model.to(device)
- model.eval()
-
- return model
-
-
-def get_iou(gt_mask, pred_mask, ignore_label=-1):
- ignore_gt_mask_inv = gt_mask != ignore_label
- obj_gt_mask = gt_mask == 1
-
- intersection = np.logical_and(np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()
- union = np.logical_and(np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()
-
- return intersection / union
-
-
-def compute_noc_metric(all_ious, iou_thrs, max_clicks=20):
- def _get_noc(iou_arr, iou_thr):
- vals = iou_arr >= iou_thr
- return np.argmax(vals) + 1 if np.any(vals) else max_clicks
-
- noc_list = []
- over_max_list = []
- for iou_thr in iou_thrs:
- scores_arr = np.array([_get_noc(iou_arr, iou_thr)
- for iou_arr in all_ious], dtype=np.int32)
-
- score = scores_arr.mean()
- over_max = (scores_arr == max_clicks).sum()
-
- noc_list.append(score)
- over_max_list.append(over_max)
-
- return noc_list, over_max_list
-
-
-def find_checkpoint(weights_folder, checkpoint_name):
- weights_folder = Path(weights_folder)
- if ':' in checkpoint_name:
- model_name, checkpoint_name = checkpoint_name.split(':')
- models_candidates = [x for x in weights_folder.glob(f'{model_name}*') if x.is_dir()]
- assert len(models_candidates) == 1
- model_folder = models_candidates[0]
- else:
- model_folder = weights_folder
-
- if checkpoint_name.endswith('.pth'):
- if Path(checkpoint_name).exists():
- checkpoint_path = checkpoint_name
- else:
- checkpoint_path = weights_folder / checkpoint_name
- else:
- model_checkpoints = list(model_folder.rglob(f'{checkpoint_name}*.pth'))
- assert len(model_checkpoints) == 1
- checkpoint_path = model_checkpoints[0]
-
- return str(checkpoint_path)
-
-
-def get_results_table(noc_list, over_max_list, brs_type, dataset_name, mean_spc, elapsed_time,
- n_clicks=20, model_name=None):
- table_header = (f'|{"BRS Type":^13}|{"Dataset":^11}|'
- f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|'
- f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|'
- f'{"SPC,s":^7}|{"Time":^9}|')
- row_width = len(table_header)
-
- header = f'Eval results for model: {model_name}\n' if model_name is not None else ''
- header += '-' * row_width + '\n'
- header += table_header + '\n' + '-' * row_width
-
- eval_time = str(timedelta(seconds=int(elapsed_time)))
- table_row = f'|{brs_type:^13}|{dataset_name:^11}|'
- table_row += f'{noc_list[0]:^9.2f}|'
- table_row += f'{noc_list[1]:^9.2f}|' if len(noc_list) > 1 else f'{"?":^9}|'
- table_row += f'{noc_list[2]:^9.2f}|' if len(noc_list) > 2 else f'{"?":^9}|'
- table_row += f'{over_max_list[1]:^9}|' if len(noc_list) > 1 else f'{"?":^9}|'
- table_row += f'{over_max_list[2]:^9}|' if len(noc_list) > 2 else f'{"?":^9}|'
- table_row += f'{mean_spc:^7.3f}|{eval_time:^9}|'
-
- return header, table_row
\ No newline at end of file
diff --git a/inference/interact/fbrs/model/__init__.py b/inference/interact/fbrs/model/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/inference/interact/fbrs/model/initializer.py b/inference/interact/fbrs/model/initializer.py
deleted file mode 100644
index 470c7df..0000000
--- a/inference/interact/fbrs/model/initializer.py
+++ /dev/null
@@ -1,105 +0,0 @@
-import torch
-import torch.nn as nn
-import numpy as np
-
-
-class Initializer(object):
- def __init__(self, local_init=True, gamma=None):
- self.local_init = local_init
- self.gamma = gamma
-
- def __call__(self, m):
- if getattr(m, '__initialized', False):
- return
-
- if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
- nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
- nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__:
- if m.weight is not None:
- self._init_gamma(m.weight.data)
- if m.bias is not None:
- self._init_beta(m.bias.data)
- else:
- if getattr(m, 'weight', None) is not None:
- self._init_weight(m.weight.data)
- if getattr(m, 'bias', None) is not None:
- self._init_bias(m.bias.data)
-
- if self.local_init:
- object.__setattr__(m, '__initialized', True)
-
- def _init_weight(self, data):
- nn.init.uniform_(data, -0.07, 0.07)
-
- def _init_bias(self, data):
- nn.init.constant_(data, 0)
-
- def _init_gamma(self, data):
- if self.gamma is None:
- nn.init.constant_(data, 1.0)
- else:
- nn.init.normal_(data, 1.0, self.gamma)
-
- def _init_beta(self, data):
- nn.init.constant_(data, 0)
-
-
-class Bilinear(Initializer):
- def __init__(self, scale, groups, in_channels, **kwargs):
- super().__init__(**kwargs)
- self.scale = scale
- self.groups = groups
- self.in_channels = in_channels
-
- def _init_weight(self, data):
- """Reset the weight and bias."""
- bilinear_kernel = self.get_bilinear_kernel(self.scale)
- weight = torch.zeros_like(data)
- for i in range(self.in_channels):
- if self.groups == 1:
- j = i
- else:
- j = 0
- weight[i, j] = bilinear_kernel
- data[:] = weight
-
- @staticmethod
- def get_bilinear_kernel(scale):
- """Generate a bilinear upsampling kernel."""
- kernel_size = 2 * scale - scale % 2
- scale = (kernel_size + 1) // 2
- center = scale - 0.5 * (1 + kernel_size % 2)
-
- og = np.ogrid[:kernel_size, :kernel_size]
- kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale)
-
- return torch.tensor(kernel, dtype=torch.float32)
-
-
-class XavierGluon(Initializer):
- def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs):
- super().__init__(**kwargs)
-
- self.rnd_type = rnd_type
- self.factor_type = factor_type
- self.magnitude = float(magnitude)
-
- def _init_weight(self, arr):
- fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr)
-
- if self.factor_type == 'avg':
- factor = (fan_in + fan_out) / 2.0
- elif self.factor_type == 'in':
- factor = fan_in
- elif self.factor_type == 'out':
- factor = fan_out
- else:
- raise ValueError('Incorrect factor type')
- scale = np.sqrt(self.magnitude / factor)
-
- if self.rnd_type == 'uniform':
- nn.init.uniform_(arr, -scale, scale)
- elif self.rnd_type == 'gaussian':
- nn.init.normal_(arr, 0, scale)
- else:
- raise ValueError('Unknown random type')
diff --git a/inference/interact/fbrs/model/is_deeplab_model.py b/inference/interact/fbrs/model/is_deeplab_model.py
deleted file mode 100644
index c9a75cc..0000000
--- a/inference/interact/fbrs/model/is_deeplab_model.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import torch
-import torch.nn as nn
-
-from .ops import DistMaps
-from .modeling.deeplab_v3 import DeepLabV3Plus
-from .modeling.basic_blocks import SepConvHead
-
-
-def get_deeplab_model(backbone='resnet50', deeplab_ch=256, aspp_dropout=0.5,
- norm_layer=nn.BatchNorm2d, backbone_norm_layer=None,
- use_rgb_conv=True, cpu_dist_maps=False,
- norm_radius=260):
- model = DistMapsModel(
- feature_extractor=DeepLabV3Plus(backbone=backbone,
- ch=deeplab_ch,
- project_dropout=aspp_dropout,
- norm_layer=norm_layer,
- backbone_norm_layer=backbone_norm_layer),
- head=SepConvHead(1, in_channels=deeplab_ch, mid_channels=deeplab_ch // 2,
- num_layers=2, norm_layer=norm_layer),
- use_rgb_conv=use_rgb_conv,
- norm_layer=norm_layer,
- norm_radius=norm_radius,
- cpu_dist_maps=cpu_dist_maps
- )
-
- return model
-
-
-class DistMapsModel(nn.Module):
- def __init__(self, feature_extractor, head, norm_layer=nn.BatchNorm2d, use_rgb_conv=True,
- cpu_dist_maps=False, norm_radius=260):
- super(DistMapsModel, self).__init__()
-
- if use_rgb_conv:
- self.rgb_conv = nn.Sequential(
- nn.Conv2d(in_channels=5, out_channels=8, kernel_size=1),
- nn.LeakyReLU(negative_slope=0.2),
- norm_layer(8),
- nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1),
- )
- else:
- self.rgb_conv = None
-
- self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0,
- cpu_mode=cpu_dist_maps)
- self.feature_extractor = feature_extractor
- self.head = head
-
- def forward(self, image, points):
- coord_features = self.dist_maps(image, points)
-
- if self.rgb_conv is not None:
- x = self.rgb_conv(torch.cat((image, coord_features), dim=1))
- else:
- c1, c2 = torch.chunk(coord_features, 2, dim=1)
- c3 = torch.ones_like(c1)
- coord_features = torch.cat((c1, c2, c3), dim=1)
- x = 0.8 * image * coord_features + 0.2 * image
-
- backbone_features = self.feature_extractor(x)
- instance_out = self.head(backbone_features[0])
- instance_out = nn.functional.interpolate(instance_out, size=image.size()[2:],
- mode='bilinear', align_corners=True)
-
- return {'instances': instance_out}
-
- def load_weights(self, path_to_weights):
- current_state_dict = self.state_dict()
- new_state_dict = torch.load(path_to_weights, map_location='cpu')
- current_state_dict.update(new_state_dict)
- self.load_state_dict(current_state_dict)
-
- def get_trainable_params(self):
- backbone_params = nn.ParameterList()
- other_params = nn.ParameterList()
-
- for name, param in self.named_parameters():
- if param.requires_grad:
- if 'backbone' in name:
- backbone_params.append(param)
- else:
- other_params.append(param)
- return backbone_params, other_params
-
-
diff --git a/inference/interact/fbrs/model/is_hrnet_model.py b/inference/interact/fbrs/model/is_hrnet_model.py
deleted file mode 100644
index ced540a..0000000
--- a/inference/interact/fbrs/model/is_hrnet_model.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import torch
-import torch.nn as nn
-
-from .ops import DistMaps
-from .modeling.hrnet_ocr import HighResolutionNet
-
-
-def get_hrnet_model(width=48, ocr_width=256, small=False, norm_radius=260,
- use_rgb_conv=True, with_aux_output=False, cpu_dist_maps=False,
- norm_layer=nn.BatchNorm2d):
- model = DistMapsHRNetModel(
- feature_extractor=HighResolutionNet(width=width, ocr_width=ocr_width, small=small,
- num_classes=1, norm_layer=norm_layer),
- use_rgb_conv=use_rgb_conv,
- with_aux_output=with_aux_output,
- norm_layer=norm_layer,
- norm_radius=norm_radius,
- cpu_dist_maps=cpu_dist_maps
- )
-
- return model
-
-
-class DistMapsHRNetModel(nn.Module):
- def __init__(self, feature_extractor, use_rgb_conv=True, with_aux_output=False,
- norm_layer=nn.BatchNorm2d, norm_radius=260, cpu_dist_maps=False):
- super(DistMapsHRNetModel, self).__init__()
- self.with_aux_output = with_aux_output
-
- if use_rgb_conv:
- self.rgb_conv = nn.Sequential(
- nn.Conv2d(in_channels=5, out_channels=8, kernel_size=1),
- nn.LeakyReLU(negative_slope=0.2),
- norm_layer(8),
- nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1),
- )
- else:
- self.rgb_conv = None
-
- self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0, cpu_mode=cpu_dist_maps)
- self.feature_extractor = feature_extractor
-
- def forward(self, image, points):
- coord_features = self.dist_maps(image, points)
-
- if self.rgb_conv is not None:
- x = self.rgb_conv(torch.cat((image, coord_features), dim=1))
- else:
- c1, c2 = torch.chunk(coord_features, 2, dim=1)
- c3 = torch.ones_like(c1)
- coord_features = torch.cat((c1, c2, c3), dim=1)
- x = 0.8 * image * coord_features + 0.2 * image
-
- feature_extractor_out = self.feature_extractor(x)
- instance_out = feature_extractor_out[0]
- instance_out = nn.functional.interpolate(instance_out, size=image.size()[2:],
- mode='bilinear', align_corners=True)
- outputs = {'instances': instance_out}
- if self.with_aux_output:
- instance_aux_out = feature_extractor_out[1]
- instance_aux_out = nn.functional.interpolate(instance_aux_out, size=image.size()[2:],
- mode='bilinear', align_corners=True)
- outputs['instances_aux'] = instance_aux_out
-
- return outputs
-
- def load_weights(self, path_to_weights):
- current_state_dict = self.state_dict()
- new_state_dict = torch.load(path_to_weights)
- current_state_dict.update(new_state_dict)
- self.load_state_dict(current_state_dict)
-
- def get_trainable_params(self):
- backbone_params = nn.ParameterList()
- other_params = nn.ParameterList()
- other_params_keys = []
- nonbackbone_keywords = ['rgb_conv', 'aux_head', 'cls_head', 'conv3x3_ocr', 'ocr_distri_head']
-
- for name, param in self.named_parameters():
- if param.requires_grad:
- if any(x in name for x in nonbackbone_keywords):
- other_params.append(param)
- other_params_keys.append(name)
- else:
- backbone_params.append(param)
- print('Nonbackbone params:', sorted(other_params_keys))
- return backbone_params, other_params
diff --git a/inference/interact/fbrs/model/losses.py b/inference/interact/fbrs/model/losses.py
deleted file mode 100644
index fd89bf0..0000000
--- a/inference/interact/fbrs/model/losses.py
+++ /dev/null
@@ -1,134 +0,0 @@
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from ..utils import misc
-
-
-class NormalizedFocalLossSigmoid(nn.Module):
- def __init__(self, axis=-1, alpha=0.25, gamma=2,
- from_logits=False, batch_axis=0,
- weight=None, size_average=True, detach_delimeter=True,
- eps=1e-12, scale=1.0,
- ignore_label=-1):
- super(NormalizedFocalLossSigmoid, self).__init__()
- self._axis = axis
- self._alpha = alpha
- self._gamma = gamma
- self._ignore_label = ignore_label
- self._weight = weight if weight is not None else 1.0
- self._batch_axis = batch_axis
-
- self._scale = scale
- self._from_logits = from_logits
- self._eps = eps
- self._size_average = size_average
- self._detach_delimeter = detach_delimeter
- self._k_sum = 0
-
- def forward(self, pred, label, sample_weight=None):
- one_hot = label > 0
- sample_weight = label != self._ignore_label
-
- if not self._from_logits:
- pred = torch.sigmoid(pred)
-
- alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight)
- pt = torch.where(one_hot, pred, 1 - pred)
- pt = torch.where(sample_weight, pt, torch.ones_like(pt))
-
- beta = (1 - pt) ** self._gamma
-
- sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True)
- beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True)
- mult = sw_sum / (beta_sum + self._eps)
- if self._detach_delimeter:
- mult = mult.detach()
- beta = beta * mult
-
- ignore_area = torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))).cpu().numpy()
- sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy()
- if np.any(ignore_area == 0):
- self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean()
-
- loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
- loss = self._weight * (loss * sample_weight)
-
- if self._size_average:
- bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis))
- loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps)
- else:
- loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
-
- return self._scale * loss
-
- def log_states(self, sw, name, global_step):
- sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step)
-
-
-class FocalLoss(nn.Module):
- def __init__(self, axis=-1, alpha=0.25, gamma=2,
- from_logits=False, batch_axis=0,
- weight=None, num_class=None,
- eps=1e-9, size_average=True, scale=1.0):
- super(FocalLoss, self).__init__()
- self._axis = axis
- self._alpha = alpha
- self._gamma = gamma
- self._weight = weight if weight is not None else 1.0
- self._batch_axis = batch_axis
-
- self._scale = scale
- self._num_class = num_class
- self._from_logits = from_logits
- self._eps = eps
- self._size_average = size_average
-
- def forward(self, pred, label, sample_weight=None):
- if not self._from_logits:
- pred = F.sigmoid(pred)
-
- one_hot = label > 0
- pt = torch.where(one_hot, pred, 1 - pred)
-
- t = label != -1
- alpha = torch.where(one_hot, self._alpha * t, (1 - self._alpha) * t)
- beta = (1 - pt) ** self._gamma
-
- loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
- sample_weight = label != -1
-
- loss = self._weight * (loss * sample_weight)
-
- if self._size_average:
- tsum = torch.sum(label == 1, dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis))
- loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (tsum + self._eps)
- else:
- loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
-
- return self._scale * loss
-
-
-class SigmoidBinaryCrossEntropyLoss(nn.Module):
- def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1):
- super(SigmoidBinaryCrossEntropyLoss, self).__init__()
- self._from_sigmoid = from_sigmoid
- self._ignore_label = ignore_label
- self._weight = weight if weight is not None else 1.0
- self._batch_axis = batch_axis
-
- def forward(self, pred, label):
- label = label.view(pred.size())
- sample_weight = label != self._ignore_label
- label = torch.where(sample_weight, label, torch.zeros_like(label))
-
- if not self._from_sigmoid:
- loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred))
- else:
- eps = 1e-12
- loss = -(torch.log(pred + eps) * label
- + torch.log(1. - pred + eps) * (1. - label))
-
- loss = self._weight * (loss * sample_weight)
- return torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
diff --git a/inference/interact/fbrs/model/metrics.py b/inference/interact/fbrs/model/metrics.py
deleted file mode 100644
index 9944feb..0000000
--- a/inference/interact/fbrs/model/metrics.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import torch
-import numpy as np
-
-from ..utils import misc
-
-
-class TrainMetric(object):
- def __init__(self, pred_outputs, gt_outputs):
- self.pred_outputs = pred_outputs
- self.gt_outputs = gt_outputs
-
- def update(self, *args, **kwargs):
- raise NotImplementedError
-
- def get_epoch_value(self):
- raise NotImplementedError
-
- def reset_epoch_stats(self):
- raise NotImplementedError
-
- def log_states(self, sw, tag_prefix, global_step):
- pass
-
- @property
- def name(self):
- return type(self).__name__
-
-
-class AdaptiveIoU(TrainMetric):
- def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9,
- ignore_label=-1, from_logits=True,
- pred_output='instances', gt_output='instances'):
- super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,))
- self._ignore_label = ignore_label
- self._from_logits = from_logits
- self._iou_thresh = init_thresh
- self._thresh_step = thresh_step
- self._thresh_beta = thresh_beta
- self._iou_beta = iou_beta
- self._ema_iou = 0.0
- self._epoch_iou_sum = 0.0
- self._epoch_batch_count = 0
-
- def update(self, pred, gt):
- gt_mask = gt > 0
- if self._from_logits:
- pred = torch.sigmoid(pred)
-
- gt_mask_area = torch.sum(gt_mask, dim=(1, 2)).detach().cpu().numpy()
- if np.all(gt_mask_area == 0):
- return
-
- ignore_mask = gt == self._ignore_label
- max_iou = _compute_iou(pred > self._iou_thresh, gt_mask, ignore_mask).mean()
- best_thresh = self._iou_thresh
- for t in [best_thresh - self._thresh_step, best_thresh + self._thresh_step]:
- temp_iou = _compute_iou(pred > t, gt_mask, ignore_mask).mean()
- if temp_iou > max_iou:
- max_iou = temp_iou
- best_thresh = t
-
- self._iou_thresh = self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh
- self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou
- self._epoch_iou_sum += max_iou
- self._epoch_batch_count += 1
-
- def get_epoch_value(self):
- if self._epoch_batch_count > 0:
- return self._epoch_iou_sum / self._epoch_batch_count
- else:
- return 0.0
-
- def reset_epoch_stats(self):
- self._epoch_iou_sum = 0.0
- self._epoch_batch_count = 0
-
- def log_states(self, sw, tag_prefix, global_step):
- sw.add_scalar(tag=tag_prefix + '_ema_iou', value=self._ema_iou, global_step=global_step)
- sw.add_scalar(tag=tag_prefix + '_iou_thresh', value=self._iou_thresh, global_step=global_step)
-
- @property
- def iou_thresh(self):
- return self._iou_thresh
-
-
-def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False):
- if ignore_mask is not None:
- pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask)
-
- reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0)
- union = torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims).detach().cpu().numpy()
- intersection = torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims).detach().cpu().numpy()
- nonzero = union > 0
-
- iou = intersection[nonzero] / union[nonzero]
- if not keep_ignore:
- return iou
- else:
- result = np.full_like(intersection, -1)
- result[nonzero] = iou
- return result
diff --git a/inference/interact/fbrs/model/modeling/__init__.py b/inference/interact/fbrs/model/modeling/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/inference/interact/fbrs/model/modeling/basic_blocks.py b/inference/interact/fbrs/model/modeling/basic_blocks.py
deleted file mode 100644
index 35946e8..0000000
--- a/inference/interact/fbrs/model/modeling/basic_blocks.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import torch.nn as nn
-
-from ...model import ops
-
-
-class ConvHead(nn.Module):
- def __init__(self, out_channels, in_channels=32, num_layers=1,
- kernel_size=3, padding=1,
- norm_layer=nn.BatchNorm2d):
- super(ConvHead, self).__init__()
- convhead = []
-
- for i in range(num_layers):
- convhead.extend([
- nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding),
- nn.ReLU(),
- norm_layer(in_channels) if norm_layer is not None else nn.Identity()
- ])
- convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0))
-
- self.convhead = nn.Sequential(*convhead)
-
- def forward(self, *inputs):
- return self.convhead(inputs[0])
-
-
-class SepConvHead(nn.Module):
- def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1,
- kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0,
- norm_layer=nn.BatchNorm2d):
- super(SepConvHead, self).__init__()
-
- sepconvhead = []
-
- for i in range(num_layers):
- sepconvhead.append(
- SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels,
- out_channels=mid_channels,
- dw_kernel=kernel_size, dw_padding=padding,
- norm_layer=norm_layer, activation='relu')
- )
- if dropout_ratio > 0 and dropout_indx == i:
- sepconvhead.append(nn.Dropout(dropout_ratio))
-
- sepconvhead.append(
- nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0)
- )
-
- self.layers = nn.Sequential(*sepconvhead)
-
- def forward(self, *inputs):
- x = inputs[0]
-
- return self.layers(x)
-
-
-class SeparableConv2d(nn.Module):
- def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1,
- activation=None, use_bias=False, norm_layer=None):
- super(SeparableConv2d, self).__init__()
- _activation = ops.select_activation_function(activation)
- self.body = nn.Sequential(
- nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride,
- padding=dw_padding, bias=use_bias, groups=in_channels),
- nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias),
- norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
- _activation()
- )
-
- def forward(self, x):
- return self.body(x)
diff --git a/inference/interact/fbrs/model/modeling/deeplab_v3.py b/inference/interact/fbrs/model/modeling/deeplab_v3.py
deleted file mode 100644
index 8e86386..0000000
--- a/inference/interact/fbrs/model/modeling/deeplab_v3.py
+++ /dev/null
@@ -1,176 +0,0 @@
-from contextlib import ExitStack
-
-import torch
-from torch import nn
-import torch.nn.functional as F
-
-from .basic_blocks import SeparableConv2d
-from .resnet import ResNetBackbone
-from ...model import ops
-
-
-class DeepLabV3Plus(nn.Module):
- def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d,
- backbone_norm_layer=None,
- ch=256,
- project_dropout=0.5,
- inference_mode=False,
- **kwargs):
- super(DeepLabV3Plus, self).__init__()
- if backbone_norm_layer is None:
- backbone_norm_layer = norm_layer
-
- self.backbone_name = backbone
- self.norm_layer = norm_layer
- self.backbone_norm_layer = backbone_norm_layer
- self.inference_mode = False
- self.ch = ch
- self.aspp_in_channels = 2048
- self.skip_project_in_channels = 256 # layer 1 out_channels
-
- self._kwargs = kwargs
- if backbone == 'resnet34':
- self.aspp_in_channels = 512
- self.skip_project_in_channels = 64
-
- self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False,
- norm_layer=self.backbone_norm_layer, **kwargs)
-
- self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=ch,
- norm_layer=self.norm_layer)
- self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer)
- self.aspp = _ASPP(in_channels=self.aspp_in_channels,
- atrous_rates=[12, 24, 36],
- out_channels=ch,
- project_dropout=project_dropout,
- norm_layer=self.norm_layer)
-
- if inference_mode:
- self.set_prediction_mode()
-
- def load_pretrained_weights(self):
- pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True,
- norm_layer=self.backbone_norm_layer, **self._kwargs)
- backbone_state_dict = self.backbone.state_dict()
- pretrained_state_dict = pretrained.state_dict()
-
- backbone_state_dict.update(pretrained_state_dict)
- self.backbone.load_state_dict(backbone_state_dict)
-
- if self.inference_mode:
- for param in self.backbone.parameters():
- param.requires_grad = False
-
- def set_prediction_mode(self):
- self.inference_mode = True
- self.eval()
-
- def forward(self, x):
- with ExitStack() as stack:
- if self.inference_mode:
- stack.enter_context(torch.no_grad())
-
- c1, _, c3, c4 = self.backbone(x)
- c1 = self.skip_project(c1)
-
- x = self.aspp(c4)
- x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True)
- x = torch.cat((x, c1), dim=1)
- x = self.head(x)
-
- return x,
-
-
-class _SkipProject(nn.Module):
- def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d):
- super(_SkipProject, self).__init__()
- _activation = ops.select_activation_function("relu")
-
- self.skip_project = nn.Sequential(
- nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
- norm_layer(out_channels),
- _activation()
- )
-
- def forward(self, x):
- return self.skip_project(x)
-
-
-class _DeepLabHead(nn.Module):
- def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d):
- super(_DeepLabHead, self).__init__()
-
- self.block = nn.Sequential(
- SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3,
- dw_padding=1, activation='relu', norm_layer=norm_layer),
- SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3,
- dw_padding=1, activation='relu', norm_layer=norm_layer),
- nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1)
- )
-
- def forward(self, x):
- return self.block(x)
-
-
-class _ASPP(nn.Module):
- def __init__(self, in_channels, atrous_rates, out_channels=256,
- project_dropout=0.5, norm_layer=nn.BatchNorm2d):
- super(_ASPP, self).__init__()
-
- b0 = nn.Sequential(
- nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False),
- norm_layer(out_channels),
- nn.ReLU()
- )
-
- rate1, rate2, rate3 = tuple(atrous_rates)
- b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
- b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
- b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)
- b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)
-
- self.concurent = nn.ModuleList([b0, b1, b2, b3, b4])
-
- project = [
- nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels,
- kernel_size=1, bias=False),
- norm_layer(out_channels),
- nn.ReLU()
- ]
- if project_dropout > 0:
- project.append(nn.Dropout(project_dropout))
- self.project = nn.Sequential(*project)
-
- def forward(self, x):
- x = torch.cat([block(x) for block in self.concurent], dim=1)
-
- return self.project(x)
-
-
-class _AsppPooling(nn.Module):
- def __init__(self, in_channels, out_channels, norm_layer):
- super(_AsppPooling, self).__init__()
-
- self.gap = nn.Sequential(
- nn.AdaptiveAvgPool2d((1, 1)),
- nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
- kernel_size=1, bias=False),
- norm_layer(out_channels),
- nn.ReLU()
- )
-
- def forward(self, x):
- pool = self.gap(x)
- return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True)
-
-
-def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer):
- block = nn.Sequential(
- nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
- kernel_size=3, padding=atrous_rate,
- dilation=atrous_rate, bias=False),
- norm_layer(out_channels),
- nn.ReLU()
- )
-
- return block
diff --git a/inference/interact/fbrs/model/modeling/hrnet_ocr.py b/inference/interact/fbrs/model/modeling/hrnet_ocr.py
deleted file mode 100644
index e5f8eff..0000000
--- a/inference/interact/fbrs/model/modeling/hrnet_ocr.py
+++ /dev/null
@@ -1,399 +0,0 @@
-import os
-import numpy as np
-import torch
-import torch.nn as nn
-import torch._utils
-import torch.nn.functional as F
-from .ocr import SpatialOCR_Module, SpatialGather_Module
-from .resnetv1b import BasicBlockV1b, BottleneckV1b
-
-relu_inplace = True
-
-
-class HighResolutionModule(nn.Module):
- def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
- num_channels, fuse_method,multi_scale_output=True,
- norm_layer=nn.BatchNorm2d, align_corners=True):
- super(HighResolutionModule, self).__init__()
- self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)
-
- self.num_inchannels = num_inchannels
- self.fuse_method = fuse_method
- self.num_branches = num_branches
- self.norm_layer = norm_layer
- self.align_corners = align_corners
-
- self.multi_scale_output = multi_scale_output
-
- self.branches = self._make_branches(
- num_branches, blocks, num_blocks, num_channels)
- self.fuse_layers = self._make_fuse_layers()
- self.relu = nn.ReLU(inplace=relu_inplace)
-
- def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
- if num_branches != len(num_blocks):
- error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
- num_branches, len(num_blocks))
- raise ValueError(error_msg)
-
- if num_branches != len(num_channels):
- error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
- num_branches, len(num_channels))
- raise ValueError(error_msg)
-
- if num_branches != len(num_inchannels):
- error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
- num_branches, len(num_inchannels))
- raise ValueError(error_msg)
-
- def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
- stride=1):
- downsample = None
- if stride != 1 or \
- self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
- downsample = nn.Sequential(
- nn.Conv2d(self.num_inchannels[branch_index],
- num_channels[branch_index] * block.expansion,
- kernel_size=1, stride=stride, bias=False),
- self.norm_layer(num_channels[branch_index] * block.expansion),
- )
-
- layers = []
- layers.append(block(self.num_inchannels[branch_index],
- num_channels[branch_index], stride,
- downsample=downsample, norm_layer=self.norm_layer))
- self.num_inchannels[branch_index] = \
- num_channels[branch_index] * block.expansion
- for i in range(1, num_blocks[branch_index]):
- layers.append(block(self.num_inchannels[branch_index],
- num_channels[branch_index],
- norm_layer=self.norm_layer))
-
- return nn.Sequential(*layers)
-
- def _make_branches(self, num_branches, block, num_blocks, num_channels):
- branches = []
-
- for i in range(num_branches):
- branches.append(
- self._make_one_branch(i, block, num_blocks, num_channels))
-
- return nn.ModuleList(branches)
-
- def _make_fuse_layers(self):
- if self.num_branches == 1:
- return None
-
- num_branches = self.num_branches
- num_inchannels = self.num_inchannels
- fuse_layers = []
- for i in range(num_branches if self.multi_scale_output else 1):
- fuse_layer = []
- for j in range(num_branches):
- if j > i:
- fuse_layer.append(nn.Sequential(
- nn.Conv2d(in_channels=num_inchannels[j],
- out_channels=num_inchannels[i],
- kernel_size=1,
- bias=False),
- self.norm_layer(num_inchannels[i])))
- elif j == i:
- fuse_layer.append(None)
- else:
- conv3x3s = []
- for k in range(i - j):
- if k == i - j - 1:
- num_outchannels_conv3x3 = num_inchannels[i]
- conv3x3s.append(nn.Sequential(
- nn.Conv2d(num_inchannels[j],
- num_outchannels_conv3x3,
- kernel_size=3, stride=2, padding=1, bias=False),
- self.norm_layer(num_outchannels_conv3x3)))
- else:
- num_outchannels_conv3x3 = num_inchannels[j]
- conv3x3s.append(nn.Sequential(
- nn.Conv2d(num_inchannels[j],
- num_outchannels_conv3x3,
- kernel_size=3, stride=2, padding=1, bias=False),
- self.norm_layer(num_outchannels_conv3x3),
- nn.ReLU(inplace=relu_inplace)))
- fuse_layer.append(nn.Sequential(*conv3x3s))
- fuse_layers.append(nn.ModuleList(fuse_layer))
-
- return nn.ModuleList(fuse_layers)
-
- def get_num_inchannels(self):
- return self.num_inchannels
-
- def forward(self, x):
- if self.num_branches == 1:
- return [self.branches[0](x[0])]
-
- for i in range(self.num_branches):
- x[i] = self.branches[i](x[i])
-
- x_fuse = []
- for i in range(len(self.fuse_layers)):
- y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
- for j in range(1, self.num_branches):
- if i == j:
- y = y + x[j]
- elif j > i:
- width_output = x[i].shape[-1]
- height_output = x[i].shape[-2]
- y = y + F.interpolate(
- self.fuse_layers[i][j](x[j]),
- size=[height_output, width_output],
- mode='bilinear', align_corners=self.align_corners)
- else:
- y = y + self.fuse_layers[i][j](x[j])
- x_fuse.append(self.relu(y))
-
- return x_fuse
-
-
-class HighResolutionNet(nn.Module):
- def __init__(self, width, num_classes, ocr_width=256, small=False,
- norm_layer=nn.BatchNorm2d, align_corners=True):
- super(HighResolutionNet, self).__init__()
- self.norm_layer = norm_layer
- self.width = width
- self.ocr_width = ocr_width
- self.align_corners = align_corners
-
- self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
- self.bn1 = norm_layer(64)
- self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
- self.bn2 = norm_layer(64)
- self.relu = nn.ReLU(inplace=relu_inplace)
-
- num_blocks = 2 if small else 4
-
- stage1_num_channels = 64
- self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks)
- stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels
-
- self.stage2_num_branches = 2
- num_channels = [width, 2 * width]
- num_inchannels = [
- num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
- self.transition1 = self._make_transition_layer(
- [stage1_out_channel], num_inchannels)
- self.stage2, pre_stage_channels = self._make_stage(
- BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches,
- num_blocks=2 * [num_blocks], num_channels=num_channels)
-
- self.stage3_num_branches = 3
- num_channels = [width, 2 * width, 4 * width]
- num_inchannels = [
- num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
- self.transition2 = self._make_transition_layer(
- pre_stage_channels, num_inchannels)
- self.stage3, pre_stage_channels = self._make_stage(
- BasicBlockV1b, num_inchannels=num_inchannels,
- num_modules=3 if small else 4, num_branches=self.stage3_num_branches,
- num_blocks=3 * [num_blocks], num_channels=num_channels)
-
- self.stage4_num_branches = 4
- num_channels = [width, 2 * width, 4 * width, 8 * width]
- num_inchannels = [
- num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
- self.transition3 = self._make_transition_layer(
- pre_stage_channels, num_inchannels)
- self.stage4, pre_stage_channels = self._make_stage(
- BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3,
- num_branches=self.stage4_num_branches,
- num_blocks=4 * [num_blocks], num_channels=num_channels)
-
- last_inp_channels = np.int32(np.sum(pre_stage_channels))
- ocr_mid_channels = 2 * ocr_width
- ocr_key_channels = ocr_width
-
- self.conv3x3_ocr = nn.Sequential(
- nn.Conv2d(last_inp_channels, ocr_mid_channels,
- kernel_size=3, stride=1, padding=1),
- norm_layer(ocr_mid_channels),
- nn.ReLU(inplace=relu_inplace),
- )
- self.ocr_gather_head = SpatialGather_Module(num_classes)
-
- self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
- key_channels=ocr_key_channels,
- out_channels=ocr_mid_channels,
- scale=1,
- dropout=0.05,
- norm_layer=norm_layer,
- align_corners=align_corners)
- self.cls_head = nn.Conv2d(
- ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
-
- self.aux_head = nn.Sequential(
- nn.Conv2d(last_inp_channels, last_inp_channels,
- kernel_size=1, stride=1, padding=0),
- norm_layer(last_inp_channels),
- nn.ReLU(inplace=relu_inplace),
- nn.Conv2d(last_inp_channels, num_classes,
- kernel_size=1, stride=1, padding=0, bias=True)
- )
-
- def _make_transition_layer(
- self, num_channels_pre_layer, num_channels_cur_layer):
- num_branches_cur = len(num_channels_cur_layer)
- num_branches_pre = len(num_channels_pre_layer)
-
- transition_layers = []
- for i in range(num_branches_cur):
- if i < num_branches_pre:
- if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
- transition_layers.append(nn.Sequential(
- nn.Conv2d(num_channels_pre_layer[i],
- num_channels_cur_layer[i],
- kernel_size=3,
- stride=1,
- padding=1,
- bias=False),
- self.norm_layer(num_channels_cur_layer[i]),
- nn.ReLU(inplace=relu_inplace)))
- else:
- transition_layers.append(None)
- else:
- conv3x3s = []
- for j in range(i + 1 - num_branches_pre):
- inchannels = num_channels_pre_layer[-1]
- outchannels = num_channels_cur_layer[i] \
- if j == i - num_branches_pre else inchannels
- conv3x3s.append(nn.Sequential(
- nn.Conv2d(inchannels, outchannels,
- kernel_size=3, stride=2, padding=1, bias=False),
- self.norm_layer(outchannels),
- nn.ReLU(inplace=relu_inplace)))
- transition_layers.append(nn.Sequential(*conv3x3s))
-
- return nn.ModuleList(transition_layers)
-
- def _make_layer(self, block, inplanes, planes, blocks, stride=1):
- downsample = None
- if stride != 1 or inplanes != planes * block.expansion:
- downsample = nn.Sequential(
- nn.Conv2d(inplanes, planes * block.expansion,
- kernel_size=1, stride=stride, bias=False),
- self.norm_layer(planes * block.expansion),
- )
-
- layers = []
- layers.append(block(inplanes, planes, stride,
- downsample=downsample, norm_layer=self.norm_layer))
- inplanes = planes * block.expansion
- for i in range(1, blocks):
- layers.append(block(inplanes, planes, norm_layer=self.norm_layer))
-
- return nn.Sequential(*layers)
-
- def _make_stage(self, block, num_inchannels,
- num_modules, num_branches, num_blocks, num_channels,
- fuse_method='SUM',
- multi_scale_output=True):
- modules = []
- for i in range(num_modules):
- # multi_scale_output is only used last module
- if not multi_scale_output and i == num_modules - 1:
- reset_multi_scale_output = False
- else:
- reset_multi_scale_output = True
- modules.append(
- HighResolutionModule(num_branches,
- block,
- num_blocks,
- num_inchannels,
- num_channels,
- fuse_method,
- reset_multi_scale_output,
- norm_layer=self.norm_layer,
- align_corners=self.align_corners)
- )
- num_inchannels = modules[-1].get_num_inchannels()
-
- return nn.Sequential(*modules), num_inchannels
-
- def forward(self, x):
- feats = self.compute_hrnet_feats(x)
- out_aux = self.aux_head(feats)
- feats = self.conv3x3_ocr(feats)
-
- context = self.ocr_gather_head(feats, out_aux)
- feats = self.ocr_distri_head(feats, context)
- out = self.cls_head(feats)
-
- return [out, out_aux]
-
- def compute_hrnet_feats(self, x):
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.relu(x)
- x = self.conv2(x)
- x = self.bn2(x)
- x = self.relu(x)
- x = self.layer1(x)
-
- x_list = []
- for i in range(self.stage2_num_branches):
- if self.transition1[i] is not None:
- x_list.append(self.transition1[i](x))
- else:
- x_list.append(x)
- y_list = self.stage2(x_list)
-
- x_list = []
- for i in range(self.stage3_num_branches):
- if self.transition2[i] is not None:
- if i < self.stage2_num_branches:
- x_list.append(self.transition2[i](y_list[i]))
- else:
- x_list.append(self.transition2[i](y_list[-1]))
- else:
- x_list.append(y_list[i])
- y_list = self.stage3(x_list)
-
- x_list = []
- for i in range(self.stage4_num_branches):
- if self.transition3[i] is not None:
- if i < self.stage3_num_branches:
- x_list.append(self.transition3[i](y_list[i]))
- else:
- x_list.append(self.transition3[i](y_list[-1]))
- else:
- x_list.append(y_list[i])
- x = self.stage4(x_list)
-
- # Upsampling
- x0_h, x0_w = x[0].size(2), x[0].size(3)
- x1 = F.interpolate(x[1], size=(x0_h, x0_w),
- mode='bilinear', align_corners=self.align_corners)
- x2 = F.interpolate(x[2], size=(x0_h, x0_w),
- mode='bilinear', align_corners=self.align_corners)
- x3 = F.interpolate(x[3], size=(x0_h, x0_w),
- mode='bilinear', align_corners=self.align_corners)
-
- return torch.cat([x[0], x1, x2, x3], 1)
-
- def load_pretrained_weights(self, pretrained_path=''):
- model_dict = self.state_dict()
-
- if not os.path.exists(pretrained_path):
- print(f'\nFile "{pretrained_path}" does not exist.')
- print('You need to specify the correct path to the pre-trained weights.\n'
- 'You can download the weights for HRNet from the repository:\n'
- 'https://github.com/HRNet/HRNet-Image-Classification')
- exit(1)
- pretrained_dict = torch.load(pretrained_path, map_location={'cuda:0': 'cpu'})
- pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in
- pretrained_dict.items()}
-
- print('model_dict-pretrained_dict:', sorted(list(set(model_dict) - set(pretrained_dict))))
- print('pretrained_dict-model_dict:', sorted(list(set(pretrained_dict) - set(model_dict))))
-
- pretrained_dict = {k: v for k, v in pretrained_dict.items()
- if k in model_dict.keys()}
-
- model_dict.update(pretrained_dict)
- self.load_state_dict(model_dict)
diff --git a/inference/interact/fbrs/model/modeling/ocr.py b/inference/interact/fbrs/model/modeling/ocr.py
deleted file mode 100644
index df3b4f6..0000000
--- a/inference/interact/fbrs/model/modeling/ocr.py
+++ /dev/null
@@ -1,141 +0,0 @@
-import torch
-import torch.nn as nn
-import torch._utils
-import torch.nn.functional as F
-
-
-class SpatialGather_Module(nn.Module):
- """
- Aggregate the context features according to the initial
- predicted probability distribution.
- Employ the soft-weighted method to aggregate the context.
- """
-
- def __init__(self, cls_num=0, scale=1):
- super(SpatialGather_Module, self).__init__()
- self.cls_num = cls_num
- self.scale = scale
-
- def forward(self, feats, probs):
- batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
- probs = probs.view(batch_size, c, -1)
- feats = feats.view(batch_size, feats.size(1), -1)
- feats = feats.permute(0, 2, 1) # batch x hw x c
- probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw
- ocr_context = torch.matmul(probs, feats) \
- .permute(0, 2, 1).unsqueeze(3) # batch x k x c
- return ocr_context
-
-
-class SpatialOCR_Module(nn.Module):
- """
- Implementation of the OCR module:
- We aggregate the global object representation to update the representation for each pixel.
- """
-
- def __init__(self,
- in_channels,
- key_channels,
- out_channels,
- scale=1,
- dropout=0.1,
- norm_layer=nn.BatchNorm2d,
- align_corners=True):
- super(SpatialOCR_Module, self).__init__()
- self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale,
- norm_layer, align_corners)
- _in_channels = 2 * in_channels
-
- self.conv_bn_dropout = nn.Sequential(
- nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
- nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)),
- nn.Dropout2d(dropout)
- )
-
- def forward(self, feats, proxy_feats):
- context = self.object_context_block(feats, proxy_feats)
-
- output = self.conv_bn_dropout(torch.cat([context, feats], 1))
-
- return output
-
-
-class ObjectAttentionBlock2D(nn.Module):
- '''
- The basic implementation for object context block
- Input:
- N X C X H X W
- Parameters:
- in_channels : the dimension of the input feature map
- key_channels : the dimension after the key/query transform
- scale : choose the scale to downsample the input feature maps (save memory cost)
- bn_type : specify the bn type
- Return:
- N X C X H X W
- '''
-
- def __init__(self,
- in_channels,
- key_channels,
- scale=1,
- norm_layer=nn.BatchNorm2d,
- align_corners=True):
- super(ObjectAttentionBlock2D, self).__init__()
- self.scale = scale
- self.in_channels = in_channels
- self.key_channels = key_channels
- self.align_corners = align_corners
-
- self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
- self.f_pixel = nn.Sequential(
- nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
- kernel_size=1, stride=1, padding=0, bias=False),
- nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
- nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
- kernel_size=1, stride=1, padding=0, bias=False),
- nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
- )
- self.f_object = nn.Sequential(
- nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
- kernel_size=1, stride=1, padding=0, bias=False),
- nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
- nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
- kernel_size=1, stride=1, padding=0, bias=False),
- nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
- )
- self.f_down = nn.Sequential(
- nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
- kernel_size=1, stride=1, padding=0, bias=False),
- nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
- )
- self.f_up = nn.Sequential(
- nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
- kernel_size=1, stride=1, padding=0, bias=False),
- nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True))
- )
-
- def forward(self, x, proxy):
- batch_size, h, w = x.size(0), x.size(2), x.size(3)
- if self.scale > 1:
- x = self.pool(x)
-
- query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
- query = query.permute(0, 2, 1)
- key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
- value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
- value = value.permute(0, 2, 1)
-
- sim_map = torch.matmul(query, key)
- sim_map = (self.key_channels ** -.5) * sim_map
- sim_map = F.softmax(sim_map, dim=-1)
-
- # add bg context ...
- context = torch.matmul(sim_map, value)
- context = context.permute(0, 2, 1).contiguous()
- context = context.view(batch_size, self.key_channels, *x.size()[2:])
- context = self.f_up(context)
- if self.scale > 1:
- context = F.interpolate(input=context, size=(h, w),
- mode='bilinear', align_corners=self.align_corners)
-
- return context
diff --git a/inference/interact/fbrs/model/modeling/resnet.py b/inference/interact/fbrs/model/modeling/resnet.py
deleted file mode 100644
index 349ea1c..0000000
--- a/inference/interact/fbrs/model/modeling/resnet.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import torch
-from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s
-
-
-class ResNetBackbone(torch.nn.Module):
- def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True, **kwargs):
- super(ResNetBackbone, self).__init__()
-
- if backbone == 'resnet34':
- pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs)
- elif backbone == 'resnet50':
- pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
- elif backbone == 'resnet101':
- pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
- elif backbone == 'resnet152':
- pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
- else:
- raise RuntimeError(f'unknown backbone: {backbone}')
-
- self.conv1 = pretrained.conv1
- self.bn1 = pretrained.bn1
- self.relu = pretrained.relu
- self.maxpool = pretrained.maxpool
- self.layer1 = pretrained.layer1
- self.layer2 = pretrained.layer2
- self.layer3 = pretrained.layer3
- self.layer4 = pretrained.layer4
-
- def forward(self, x):
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.relu(x)
- x = self.maxpool(x)
- c1 = self.layer1(x)
- c2 = self.layer2(c1)
- c3 = self.layer3(c2)
- c4 = self.layer4(c3)
-
- return c1, c2, c3, c4
diff --git a/inference/interact/fbrs/model/modeling/resnetv1b.py b/inference/interact/fbrs/model/modeling/resnetv1b.py
deleted file mode 100644
index 4ad24ce..0000000
--- a/inference/interact/fbrs/model/modeling/resnetv1b.py
+++ /dev/null
@@ -1,276 +0,0 @@
-import torch
-import torch.nn as nn
-GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet'
-
-
-class BasicBlockV1b(nn.Module):
- expansion = 1
-
- def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
- previous_dilation=1, norm_layer=nn.BatchNorm2d):
- super(BasicBlockV1b, self).__init__()
- self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
- padding=dilation, dilation=dilation, bias=False)
- self.bn1 = norm_layer(planes)
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
- padding=previous_dilation, dilation=previous_dilation, bias=False)
- self.bn2 = norm_layer(planes)
-
- self.relu = nn.ReLU(inplace=True)
- self.downsample = downsample
- self.stride = stride
-
- def forward(self, x):
- residual = x
-
- out = self.conv1(x)
- out = self.bn1(out)
- out = self.relu(out)
-
- out = self.conv2(out)
- out = self.bn2(out)
-
- if self.downsample is not None:
- residual = self.downsample(x)
-
- out = out + residual
- out = self.relu(out)
-
- return out
-
-
-class BottleneckV1b(nn.Module):
- expansion = 4
-
- def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
- previous_dilation=1, norm_layer=nn.BatchNorm2d):
- super(BottleneckV1b, self).__init__()
- self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
- self.bn1 = norm_layer(planes)
-
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
- padding=dilation, dilation=dilation, bias=False)
- self.bn2 = norm_layer(planes)
-
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
- self.bn3 = norm_layer(planes * self.expansion)
-
- self.relu = nn.ReLU(inplace=True)
- self.downsample = downsample
- self.stride = stride
-
- def forward(self, x):
- residual = x
-
- out = self.conv1(x)
- out = self.bn1(out)
- out = self.relu(out)
-
- out = self.conv2(out)
- out = self.bn2(out)
- out = self.relu(out)
-
- out = self.conv3(out)
- out = self.bn3(out)
-
- if self.downsample is not None:
- residual = self.downsample(x)
-
- out = out + residual
- out = self.relu(out)
-
- return out
-
-
-class ResNetV1b(nn.Module):
- """ Pre-trained ResNetV1b Model, which produces the strides of 8 featuremaps at conv5.
-
- Parameters
- ----------
- block : Block
- Class for the residual block. Options are BasicBlockV1, BottleneckV1.
- layers : list of int
- Numbers of layers in each block
- classes : int, default 1000
- Number of classification classes.
- dilated : bool, default False
- Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
- typically used in Semantic Segmentation.
- norm_layer : object
- Normalization layer used (default: :class:`nn.BatchNorm2d`)
- deep_stem : bool, default False
- Whether to replace the 7x7 conv1 with 3 3x3 convolution layers.
- avg_down : bool, default False
- Whether to use average pooling for projection skip connection between stages/downsample.
- final_drop : float, default 0.0
- Dropout ratio before the final classification layer.
-
- Reference:
- - He, Kaiming, et al. "Deep residual learning for image recognition."
- Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
-
- - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
- """
- def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32,
- avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d):
- self.inplanes = stem_width*2 if deep_stem else 64
- super(ResNetV1b, self).__init__()
- if not deep_stem:
- self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
- else:
- self.conv1 = nn.Sequential(
- nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
- norm_layer(stem_width),
- nn.ReLU(True),
- nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
- norm_layer(stem_width),
- nn.ReLU(True),
- nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False)
- )
- self.bn1 = norm_layer(self.inplanes)
- self.relu = nn.ReLU(True)
- self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
- self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down,
- norm_layer=norm_layer)
- self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down,
- norm_layer=norm_layer)
- if dilated:
- self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2,
- avg_down=avg_down, norm_layer=norm_layer)
- self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4,
- avg_down=avg_down, norm_layer=norm_layer)
- else:
- self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
- avg_down=avg_down, norm_layer=norm_layer)
- self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
- avg_down=avg_down, norm_layer=norm_layer)
- self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
- self.drop = None
- if final_drop > 0.0:
- self.drop = nn.Dropout(final_drop)
- self.fc = nn.Linear(512 * block.expansion, classes)
-
- def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
- avg_down=False, norm_layer=nn.BatchNorm2d):
- downsample = None
- if stride != 1 or self.inplanes != planes * block.expansion:
- downsample = []
- if avg_down:
- if dilation == 1:
- downsample.append(
- nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False)
- )
- else:
- downsample.append(
- nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False)
- )
- downsample.extend([
- nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
- kernel_size=1, stride=1, bias=False),
- norm_layer(planes * block.expansion)
- ])
- downsample = nn.Sequential(*downsample)
- else:
- downsample = nn.Sequential(
- nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
- kernel_size=1, stride=stride, bias=False),
- norm_layer(planes * block.expansion)
- )
-
- layers = []
- if dilation in (1, 2):
- layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample,
- previous_dilation=dilation, norm_layer=norm_layer))
- elif dilation == 4:
- layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample,
- previous_dilation=dilation, norm_layer=norm_layer))
- else:
- raise RuntimeError("=> unknown dilation size: {}".format(dilation))
-
- self.inplanes = planes * block.expansion
- for _ in range(1, blocks):
- layers.append(block(self.inplanes, planes, dilation=dilation,
- previous_dilation=dilation, norm_layer=norm_layer))
-
- return nn.Sequential(*layers)
-
- def forward(self, x):
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.relu(x)
- x = self.maxpool(x)
-
- x = self.layer1(x)
- x = self.layer2(x)
- x = self.layer3(x)
- x = self.layer4(x)
-
- x = self.avgpool(x)
- x = x.view(x.size(0), -1)
- if self.drop is not None:
- x = self.drop(x)
- x = self.fc(x)
-
- return x
-
-
-def _safe_state_dict_filtering(orig_dict, model_dict_keys):
- filtered_orig_dict = {}
- for k, v in orig_dict.items():
- if k in model_dict_keys:
- filtered_orig_dict[k] = v
- else:
- print(f"[ERROR] Failed to load <{k}> in backbone")
- return filtered_orig_dict
-
-
-def resnet34_v1b(pretrained=False, **kwargs):
- model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs)
- if pretrained:
- model_dict = model.state_dict()
- filtered_orig_dict = _safe_state_dict_filtering(
- torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(),
- model_dict.keys()
- )
- model_dict.update(filtered_orig_dict)
- model.load_state_dict(model_dict)
- return model
-
-
-def resnet50_v1s(pretrained=False, **kwargs):
- model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs)
- if pretrained:
- model_dict = model.state_dict()
- filtered_orig_dict = _safe_state_dict_filtering(
- torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(),
- model_dict.keys()
- )
- model_dict.update(filtered_orig_dict)
- model.load_state_dict(model_dict)
- return model
-
-
-def resnet101_v1s(pretrained=False, **kwargs):
- model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs)
- if pretrained:
- model_dict = model.state_dict()
- filtered_orig_dict = _safe_state_dict_filtering(
- torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(),
- model_dict.keys()
- )
- model_dict.update(filtered_orig_dict)
- model.load_state_dict(model_dict)
- return model
-
-
-def resnet152_v1s(pretrained=False, **kwargs):
- model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs)
- if pretrained:
- model_dict = model.state_dict()
- filtered_orig_dict = _safe_state_dict_filtering(
- torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(),
- model_dict.keys()
- )
- model_dict.update(filtered_orig_dict)
- model.load_state_dict(model_dict)
- return model
diff --git a/inference/interact/fbrs/model/ops.py b/inference/interact/fbrs/model/ops.py
deleted file mode 100644
index f46ae39..0000000
--- a/inference/interact/fbrs/model/ops.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import torch
-from torch import nn as nn
-import numpy as np
-
-from . import initializer as initializer
-from ..utils.cython import get_dist_maps
-
-
-def select_activation_function(activation):
- if isinstance(activation, str):
- if activation.lower() == 'relu':
- return nn.ReLU
- elif activation.lower() == 'softplus':
- return nn.Softplus
- else:
- raise ValueError(f"Unknown activation type {activation}")
- elif isinstance(activation, nn.Module):
- return activation
- else:
- raise ValueError(f"Unknown activation type {activation}")
-
-
-class BilinearConvTranspose2d(nn.ConvTranspose2d):
- def __init__(self, in_channels, out_channels, scale, groups=1):
- kernel_size = 2 * scale - scale % 2
- self.scale = scale
-
- super().__init__(
- in_channels, out_channels,
- kernel_size=kernel_size,
- stride=scale,
- padding=1,
- groups=groups,
- bias=False)
-
- self.apply(initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups))
-
-
-class DistMaps(nn.Module):
- def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False):
- super(DistMaps, self).__init__()
- self.spatial_scale = spatial_scale
- self.norm_radius = norm_radius
- self.cpu_mode = cpu_mode
-
- def get_coord_features(self, points, batchsize, rows, cols):
- if self.cpu_mode:
- coords = []
- for i in range(batchsize):
- norm_delimeter = self.spatial_scale * self.norm_radius
- coords.append(get_dist_maps(points[i].cpu().float().numpy(), rows, cols,
- norm_delimeter))
- coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float()
- else:
- num_points = points.shape[1] // 2
- points = points.view(-1, 2)
- invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0
- row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device)
- col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device)
-
- coord_rows, coord_cols = torch.meshgrid(row_array, col_array)
- coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1)
-
- add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1)
- coords.add_(-add_xy)
- coords.div_(self.norm_radius * self.spatial_scale)
- coords.mul_(coords)
-
- coords[:, 0] += coords[:, 1]
- coords = coords[:, :1]
-
- coords[invalid_points, :, :, :] = 1e6
-
- coords = coords.view(-1, num_points, 1, rows, cols)
- coords = coords.min(dim=1)[0] # -> (bs * num_masks * 2) x 1 x h x w
- coords = coords.view(-1, 2, rows, cols)
-
- coords.sqrt_().mul_(2).tanh_()
-
- return coords
-
- def forward(self, x, coords):
- return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3])
diff --git a/inference/interact/fbrs/model/syncbn/LICENSE b/inference/interact/fbrs/model/syncbn/LICENSE
deleted file mode 100644
index fec5469..0000000
--- a/inference/interact/fbrs/model/syncbn/LICENSE
+++ /dev/null
@@ -1,21 +0,0 @@
-MIT License
-
-Copyright (c) 2018 Tamaki Kojima
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
diff --git a/inference/interact/fbrs/model/syncbn/README.md b/inference/interact/fbrs/model/syncbn/README.md
deleted file mode 100755
index d9a9ea2..0000000
--- a/inference/interact/fbrs/model/syncbn/README.md
+++ /dev/null
@@ -1,127 +0,0 @@
-# pytorch-syncbn
-
-Tamaki Kojima(tamakoji@gmail.com)
-
-## Announcement
-
-**Pytorch 1.0 support**
-
-## Overview
-This is alternative implementation of "Synchronized Multi-GPU Batch Normalization" which computes global stats across gpus instead of locally computed. SyncBN are getting important for those input image is large, and must use multi-gpu to increase the minibatch-size for the training.
-
-The code was inspired by [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding) and [Inplace-ABN](https://github.com/mapillary/inplace_abn)
-
-## Remarks
-- Unlike [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding), you don't need custom `nn.DataParallel`
-- Unlike [Inplace-ABN](https://github.com/mapillary/inplace_abn), you can just replace your `nn.BatchNorm2d` to this module implementation, since it will not mark for inplace operation
-- You can plug into arbitrary module written in PyTorch to enable Synchronized BatchNorm
-- Backward computation is rewritten and tested against behavior of `nn.BatchNorm2d`
-
-## Requirements
-For PyTorch, please refer to https://pytorch.org/
-
-NOTE : The code is tested only with PyTorch v1.0.0, CUDA10/CuDNN7.4.2 on ubuntu18.04
-
-It utilize Pytorch JIT mechanism to compile seamlessly, using ninja. Please install ninja-build before use.
-
-```
-sudo apt-get install ninja-build
-```
-
-Also install all dependencies for python. For pip, run:
-
-
-```
-pip install -U -r requirements.txt
-```
-
-## Build
-
-There is no need to build. just run and JIT will take care.
-JIT and cpp extensions are supported after PyTorch0.4, however it is highly recommended to use PyTorch > 1.0 due to huge design changes.
-
-## Usage
-
-Please refer to [`test.py`](./test.py) for testing the difference between `nn.BatchNorm2d` and `modules.nn.BatchNorm2d`
-
-```
-import torch
-from modules import nn as NN
-num_gpu = torch.cuda.device_count()
-model = nn.Sequential(
- nn.Conv2d(3, 3, 1, 1, bias=False),
- NN.BatchNorm2d(3),
- nn.ReLU(inplace=True),
- nn.Conv2d(3, 3, 1, 1, bias=False),
- NN.BatchNorm2d(3),
-).cuda()
-model = nn.DataParallel(model, device_ids=range(num_gpu))
-x = torch.rand(num_gpu, 3, 2, 2).cuda()
-z = model(x)
-```
-
-## Math
-
-### Forward
-1. compute in each gpu
-2. gather all from workers to master and compute where
-
-
-
- and
-
-
-
- and then above global stats to be shared to all gpus, update running_mean and running_var by moving average using global stats.
-
-3. forward batchnorm using global stats by
-
-
-
- and then
-
-
-
- where is weight parameter and is bias parameter.
-
-4. save for backward
-
-### Backward
-
-1. Restore saved
-
-2. Compute below sums on each gpu
-
-
-
- and
-
-
-
- where
-
- then gather them at master node to sum up global, and normalize with N where N is total number of elements for each channels. Global sums are then shared among all gpus.
-
-3. compute gradients using global stats
-
-
-
- where
-
-
-
- and
-
-
-
- and finally,
-
-
-
-
-
-
-
- Note that in the implementation, normalization with N is performed at step (2) and above equation and implementation is not exactly the same, but mathematically is same.
-
- You can go deeper on above explanation at [Kevin Zakka's Blog](https://kevinzakka.github.io/2016/09/14/batch_normalization/)
\ No newline at end of file
diff --git a/inference/interact/fbrs/model/syncbn/__init__.py b/inference/interact/fbrs/model/syncbn/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/inference/interact/fbrs/model/syncbn/modules/__init__.py b/inference/interact/fbrs/model/syncbn/modules/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/inference/interact/fbrs/model/syncbn/modules/functional/__init__.py b/inference/interact/fbrs/model/syncbn/modules/functional/__init__.py
deleted file mode 100644
index a8eb83a..0000000
--- a/inference/interact/fbrs/model/syncbn/modules/functional/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .syncbn import batchnorm2d_sync
diff --git a/inference/interact/fbrs/model/syncbn/modules/functional/_csrc.py b/inference/interact/fbrs/model/syncbn/modules/functional/_csrc.py
deleted file mode 100644
index d0c1409..0000000
--- a/inference/interact/fbrs/model/syncbn/modules/functional/_csrc.py
+++ /dev/null
@@ -1,54 +0,0 @@
-"""
-/*****************************************************************************/
-
-Extension module loader
-
-code referenced from : https://github.com/facebookresearch/maskrcnn-benchmark
-
-/*****************************************************************************/
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import glob
-import os.path
-
-import torch
-
-try:
- from torch.utils.cpp_extension import load
- from torch.utils.cpp_extension import CUDA_HOME
-except ImportError:
- raise ImportError(
- "The cpp layer extensions requires PyTorch 0.4 or higher")
-
-
-def _load_C_extensions():
- this_dir = os.path.dirname(os.path.abspath(__file__))
- this_dir = os.path.join(this_dir, "csrc")
-
- main_file = glob.glob(os.path.join(this_dir, "*.cpp"))
- sources_cpu = glob.glob(os.path.join(this_dir, "cpu", "*.cpp"))
- sources_cuda = glob.glob(os.path.join(this_dir, "cuda", "*.cu"))
-
- sources = main_file + sources_cpu
-
- extra_cflags = []
- extra_cuda_cflags = []
- if torch.cuda.is_available() and CUDA_HOME is not None:
- sources.extend(sources_cuda)
- extra_cflags = ["-O3", "-DWITH_CUDA"]
- extra_cuda_cflags = ["--expt-extended-lambda"]
- sources = [os.path.join(this_dir, s) for s in sources]
- extra_include_paths = [this_dir]
- return load(
- name="ext_lib",
- sources=sources,
- extra_cflags=extra_cflags,
- extra_include_paths=extra_include_paths,
- extra_cuda_cflags=extra_cuda_cflags,
- )
-
-
-_backend = _load_C_extensions()
diff --git a/inference/interact/fbrs/model/syncbn/modules/functional/csrc/bn.h b/inference/interact/fbrs/model/syncbn/modules/functional/csrc/bn.h
deleted file mode 100644
index 52567a4..0000000
--- a/inference/interact/fbrs/model/syncbn/modules/functional/csrc/bn.h
+++ /dev/null
@@ -1,70 +0,0 @@
-/*****************************************************************************
-
-SyncBN
-
-*****************************************************************************/
-#pragma once
-
-#ifdef WITH_CUDA
-#include "cuda/ext_lib.h"
-#endif
-
-/// SyncBN
-
-std::vector syncbn_sum_sqsum(const at::Tensor& x) {
- if (x.is_cuda()) {
-#ifdef WITH_CUDA
- return syncbn_sum_sqsum_cuda(x);
-#else
- AT_ERROR("Not compiled with GPU support");
-#endif
- } else {
- AT_ERROR("CPU implementation not supported");
- }
-}
-
-at::Tensor syncbn_forward(const at::Tensor& x, const at::Tensor& weight,
- const at::Tensor& bias, const at::Tensor& mean,
- const at::Tensor& var, bool affine, float eps) {
- if (x.is_cuda()) {
-#ifdef WITH_CUDA
- return syncbn_forward_cuda(x, weight, bias, mean, var, affine, eps);
-#else
- AT_ERROR("Not compiled with GPU support");
-#endif
- } else {
- AT_ERROR("CPU implementation not supported");
- }
-}
-
-std::vector syncbn_backward_xhat(const at::Tensor& dz,
- const at::Tensor& x,
- const at::Tensor& mean,
- const at::Tensor& var, float eps) {
- if (dz.is_cuda()) {
-#ifdef WITH_CUDA
- return syncbn_backward_xhat_cuda(dz, x, mean, var, eps);
-#else
- AT_ERROR("Not compiled with GPU support");
-#endif
- } else {
- AT_ERROR("CPU implementation not supported");
- }
-}
-
-std::vector syncbn_backward(
- const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight,
- const at::Tensor& bias, const at::Tensor& mean, const at::Tensor& var,
- const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine,
- float eps) {
- if (dz.is_cuda()) {
-#ifdef WITH_CUDA
- return syncbn_backward_cuda(dz, x, weight, bias, mean, var, sum_dz,
- sum_dz_xhat, affine, eps);
-#else
- AT_ERROR("Not compiled with GPU support");
-#endif
- } else {
- AT_ERROR("CPU implementation not supported");
- }
-}
diff --git a/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/bn_cuda.cu b/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/bn_cuda.cu
deleted file mode 100644
index 9458eba..0000000
--- a/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/bn_cuda.cu
+++ /dev/null
@@ -1,280 +0,0 @@
-/*****************************************************************************
-
-CUDA SyncBN code
-
-code referenced from : https://github.com/mapillary/inplace_abn
-
-*****************************************************************************/
-#include
-#include
-#include
-#include
-#include "cuda/common.h"
-
-// Utilities
-void get_dims(at::Tensor x, int64_t &num, int64_t &chn, int64_t &sp) {
- num = x.size(0);
- chn = x.size(1);
- sp = 1;
- for (int64_t i = 2; i < x.ndimension(); ++i) sp *= x.size(i);
-}
-
-/// SyncBN
-
-template
-struct SqSumOp {
- __device__ SqSumOp(const T *t, int c, int s) : tensor(t), chn(c), sp(s) {}
- __device__ __forceinline__ Pair operator()(int batch, int plane, int n) {
- T x = tensor[(batch * chn + plane) * sp + n];
- return Pair(x, x * x); // x, x^2
- }
- const T *tensor;
- const int chn;
- const int sp;
-};
-
-template
-__global__ void syncbn_sum_sqsum_kernel(const T *x, T *sum, T *sqsum,
- int num, int chn, int sp) {
- int plane = blockIdx.x;
- Pair res =
- reduce, SqSumOp>(SqSumOp(x, chn, sp), plane, num, chn, sp);
- __syncthreads();
- if (threadIdx.x == 0) {
- sum[plane] = res.v1;
- sqsum[plane] = res.v2;
- }
-}
-
-std::vector syncbn_sum_sqsum_cuda(const at::Tensor &x) {
- CHECK_INPUT(x);
-
- // Extract dimensions
- int64_t num, chn, sp;
- get_dims(x, num, chn, sp);
-
- // Prepare output tensors
- auto sum = at::empty({chn}, x.options());
- auto sqsum = at::empty({chn}, x.options());
-
- // Run kernel
- dim3 blocks(chn);
- dim3 threads(getNumThreads(sp));
- AT_DISPATCH_FLOATING_TYPES(
- x.type(), "syncbn_sum_sqsum_cuda", ([&] {
- syncbn_sum_sqsum_kernel<<>>(
- x.data(), sum.data(),
- sqsum.data(), num, chn, sp);
- }));
- return {sum, sqsum};
-}
-
-template
-__global__ void syncbn_forward_kernel(T *z, const T *x, const T *weight,
- const T *bias, const T *mean,
- const T *var, bool affine, float eps,
- int num, int chn, int sp) {
- int plane = blockIdx.x;
- T _mean = mean[plane];
- T _var = var[plane];
- T _weight = affine ? weight[plane] : T(1);
- T _bias = affine ? bias[plane] : T(0);
- float _invstd = T(0);
- if (_var || eps) {
- _invstd = rsqrt(_var + eps);
- }
- for (int batch = 0; batch < num; ++batch) {
- for (int n = threadIdx.x; n < sp; n += blockDim.x) {
- T _x = x[(batch * chn + plane) * sp + n];
- T _xhat = (_x - _mean) * _invstd;
- T _z = _xhat * _weight + _bias;
- z[(batch * chn + plane) * sp + n] = _z;
- }
- }
-}
-
-at::Tensor syncbn_forward_cuda(const at::Tensor &x, const at::Tensor &weight,
- const at::Tensor &bias, const at::Tensor &mean,
- const at::Tensor &var, bool affine, float eps) {
- CHECK_INPUT(x);
- CHECK_INPUT(weight);
- CHECK_INPUT(bias);
- CHECK_INPUT(mean);
- CHECK_INPUT(var);
-
- // Extract dimensions
- int64_t num, chn, sp;
- get_dims(x, num, chn, sp);
-
- auto z = at::zeros_like(x);
-
- // Run kernel
- dim3 blocks(chn);
- dim3 threads(getNumThreads(sp));
- AT_DISPATCH_FLOATING_TYPES(
- x.type(), "syncbn_forward_cuda", ([&] {
- syncbn_forward_kernel<<>>(
- z.data(), x.data(),
- weight.data(), bias.data(),
- mean.data(), var.data(),
- affine, eps, num, chn, sp);
- }));
- return z;
-}
-
-template
-struct XHatOp {
- __device__ XHatOp(T _weight, T _bias, const T *_dz, const T *_x, int c, int s)
- : weight(_weight), bias(_bias), x(_x), dz(_dz), chn(c), sp(s) {}
- __device__ __forceinline__ Pair operator()(int batch, int plane, int n) {
- // xhat = (x - bias) * weight
- T _xhat = (x[(batch * chn + plane) * sp + n] - bias) * weight;
- // dxhat * x_hat
- T _dz = dz[(batch * chn + plane) * sp + n];
- return Pair(_dz, _dz * _xhat);
- }
- const T weight;
- const T bias;
- const T *dz;
- const T *x;
- const int chn;
- const int sp;
-};
-
-template
-__global__ void syncbn_backward_xhat_kernel(const T *dz, const T *x,
- const T *mean, const T *var,
- T *sum_dz, T *sum_dz_xhat,
- float eps, int num, int chn,
- int sp) {
- int plane = blockIdx.x;
- T _mean = mean[plane];
- T _var = var[plane];
- T _invstd = T(0);
- if (_var || eps) {
- _invstd = rsqrt(_var + eps);
- }
- Pair res = reduce, XHatOp>(
- XHatOp(_invstd, _mean, dz, x, chn, sp), plane, num, chn, sp);
- __syncthreads();
- if (threadIdx.x == 0) {
- // \sum(\frac{dJ}{dy_i})
- sum_dz[plane] = res.v1;
- // \sum(\frac{dJ}{dy_i}*\hat{x_i})
- sum_dz_xhat[plane] = res.v2;
- }
-}
-
-std::vector syncbn_backward_xhat_cuda(const at::Tensor &dz,
- const at::Tensor &x,
- const at::Tensor &mean,
- const at::Tensor &var,
- float eps) {
- CHECK_INPUT(dz);
- CHECK_INPUT(x);
- CHECK_INPUT(mean);
- CHECK_INPUT(var);
- // Extract dimensions
- int64_t num, chn, sp;
- get_dims(x, num, chn, sp);
- // Prepare output tensors
- auto sum_dz = at::empty({chn}, x.options());
- auto sum_dz_xhat = at::empty({chn}, x.options());
- // Run kernel
- dim3 blocks(chn);
- dim3 threads(getNumThreads(sp));
- AT_DISPATCH_FLOATING_TYPES(
- x.type(), "syncbn_backward_xhat_cuda", ([&] {
- syncbn_backward_xhat_kernel<<>>(
- dz.data(), x.data(), mean.data(),
- var.data(), sum_dz.data(),
- sum_dz_xhat.data(), eps, num, chn, sp);
- }));
- return {sum_dz, sum_dz_xhat};
-}
-
-template
-__global__ void syncbn_backward_kernel(const T *dz, const T *x, const T *weight,
- const T *bias, const T *mean,
- const T *var, const T *sum_dz,
- const T *sum_dz_xhat, T *dx, T *dweight,
- T *dbias, bool affine, float eps,
- int num, int chn, int sp) {
- int plane = blockIdx.x;
- T _mean = mean[plane];
- T _var = var[plane];
- T _weight = affine ? weight[plane] : T(1);
- T _sum_dz = sum_dz[plane];
- T _sum_dz_xhat = sum_dz_xhat[plane];
- T _invstd = T(0);
- if (_var || eps) {
- _invstd = rsqrt(_var + eps);
- }
- /*
- \frac{dJ}{dx_i} = \frac{1}{N\sqrt{(\sigma^2+\epsilon)}} (
- N\frac{dJ}{d\hat{x_i}} -
- \sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}) -
- \hat{x_i}\sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}\hat{x_j})
- )
- Note : N is omitted here since it will be accumulated and
- _sum_dz and _sum_dz_xhat expected to be already normalized
- before the call.
- */
- if (dx) {
- T _mul = _weight * _invstd;
- for (int batch = 0; batch < num; ++batch) {
- for (int n = threadIdx.x; n < sp; n += blockDim.x) {
- T _dz = dz[(batch * chn + plane) * sp + n];
- T _xhat = (x[(batch * chn + plane) * sp + n] - _mean) * _invstd;
- T _dx = (_dz - _sum_dz - _xhat * _sum_dz_xhat) * _mul;
- dx[(batch * chn + plane) * sp + n] = _dx;
- }
- }
- }
- __syncthreads();
- if (threadIdx.x == 0) {
- if (affine) {
- T _norm = num * sp;
- dweight[plane] += _sum_dz_xhat * _norm;
- dbias[plane] += _sum_dz * _norm;
- }
- }
-}
-
-std::vector syncbn_backward_cuda(
- const at::Tensor &dz, const at::Tensor &x, const at::Tensor &weight,
- const at::Tensor &bias, const at::Tensor &mean, const at::Tensor &var,
- const at::Tensor &sum_dz, const at::Tensor &sum_dz_xhat, bool affine,
- float eps) {
- CHECK_INPUT(dz);
- CHECK_INPUT(x);
- CHECK_INPUT(weight);
- CHECK_INPUT(bias);
- CHECK_INPUT(mean);
- CHECK_INPUT(var);
- CHECK_INPUT(sum_dz);
- CHECK_INPUT(sum_dz_xhat);
-
- // Extract dimensions
- int64_t num, chn, sp;
- get_dims(x, num, chn, sp);
-
- // Prepare output tensors
- auto dx = at::zeros_like(dz);
- auto dweight = at::zeros_like(weight);
- auto dbias = at::zeros_like(bias);
-
- // Run kernel
- dim3 blocks(chn);
- dim3 threads(getNumThreads(sp));
- AT_DISPATCH_FLOATING_TYPES(
- x.type(), "syncbn_backward_cuda", ([&] {
- syncbn_backward_kernel<<>>(
- dz.data(), x.data(), weight.data(),
- bias.data(), mean.data(), var.data(),
- sum_dz.data(), sum_dz_xhat.data(),
- dx.data(), dweight.data(),
- dbias.data(), affine, eps, num, chn, sp);
- }));
- return {dx, dweight, dbias};
-}
\ No newline at end of file
diff --git a/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/common.h b/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/common.h
deleted file mode 100644
index a6cb2de..0000000
--- a/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/common.h
+++ /dev/null
@@ -1,124 +0,0 @@
-/*****************************************************************************
-
-CUDA utility funcs
-
-code referenced from : https://github.com/mapillary/inplace_abn
-
-*****************************************************************************/
-#pragma once
-
-#include
-
-// Checks
-#ifndef AT_CHECK
- #define AT_CHECK AT_ASSERT
-#endif
-#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
-
-/*
- * General settings
- */
-const int WARP_SIZE = 32;
-const int MAX_BLOCK_SIZE = 512;
-
-template
-struct Pair {
- T v1, v2;
- __device__ Pair() {}
- __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
- __device__ Pair(T v) : v1(v), v2(v) {}
- __device__ Pair(int v) : v1(v), v2(v) {}
- __device__ Pair &operator+=(const Pair &a) {
- v1 += a.v1;
- v2 += a.v2;
- return *this;
- }
-};
-
-/*
- * Utility functions
- */
-template
-__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask,
- int width = warpSize,
- unsigned int mask = 0xffffffff) {
-#if CUDART_VERSION >= 9000
- return __shfl_xor_sync(mask, value, laneMask, width);
-#else
- return __shfl_xor(value, laneMask, width);
-#endif
-}
-
-__device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
-
-static int getNumThreads(int nElem) {
- int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
- for (int i = 0; i != 5; ++i) {
- if (nElem <= threadSizes[i]) {
- return threadSizes[i];
- }
- }
- return MAX_BLOCK_SIZE;
-}
-
-template
-static __device__ __forceinline__ T warpSum(T val) {
-#if __CUDA_ARCH__ >= 300
- for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
- val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
- }
-#else
- __shared__ T values[MAX_BLOCK_SIZE];
- values[threadIdx.x] = val;
- __threadfence_block();
- const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
- for (int i = 1; i < WARP_SIZE; i++) {
- val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
- }
-#endif
- return val;
-}
-
-template
-static __device__ __forceinline__ Pair warpSum(Pair value) {
- value.v1 = warpSum(value.v1);
- value.v2 = warpSum(value.v2);
- return value;
-}
-
-template
-__device__ T reduce(Op op, int plane, int N, int C, int S) {
- T sum = (T)0;
- for (int batch = 0; batch < N; ++batch) {
- for (int x = threadIdx.x; x < S; x += blockDim.x) {
- sum += op(batch, plane, x);
- }
- }
-
- // sum over NumThreads within a warp
- sum = warpSum(sum);
-
- // 'transpose', and reduce within warp again
- __shared__ T shared[32];
- __syncthreads();
- if (threadIdx.x % WARP_SIZE == 0) {
- shared[threadIdx.x / WARP_SIZE] = sum;
- }
- if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
- // zero out the other entries in shared
- shared[threadIdx.x] = (T)0;
- }
- __syncthreads();
- if (threadIdx.x / WARP_SIZE == 0) {
- sum = warpSum(shared[threadIdx.x]);
- if (threadIdx.x == 0) {
- shared[0] = sum;
- }
- }
- __syncthreads();
-
- // Everyone picks it up, should be broadcast into the whole gradInput
- return shared[0];
-}
\ No newline at end of file
diff --git a/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/ext_lib.h b/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/ext_lib.h
deleted file mode 100644
index 1d70761..0000000
--- a/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/ext_lib.h
+++ /dev/null
@@ -1,24 +0,0 @@
-/*****************************************************************************
-
-CUDA SyncBN code
-
-*****************************************************************************/
-#pragma once
-#include
-#include
-
-/// Sync-BN
-std::vector syncbn_sum_sqsum_cuda(const at::Tensor& x);
-at::Tensor syncbn_forward_cuda(const at::Tensor& x, const at::Tensor& weight,
- const at::Tensor& bias, const at::Tensor& mean,
- const at::Tensor& var, bool affine, float eps);
-std::vector syncbn_backward_xhat_cuda(const at::Tensor& dz,
- const at::Tensor& x,
- const at::Tensor& mean,
- const at::Tensor& var,
- float eps);
-std::vector syncbn_backward_cuda(
- const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight,
- const at::Tensor& bias, const at::Tensor& mean, const at::Tensor& var,
- const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine,
- float eps);
diff --git a/inference/interact/fbrs/model/syncbn/modules/functional/csrc/ext_lib.cpp b/inference/interact/fbrs/model/syncbn/modules/functional/csrc/ext_lib.cpp
deleted file mode 100644
index 9c2ecf1..0000000
--- a/inference/interact/fbrs/model/syncbn/modules/functional/csrc/ext_lib.cpp
+++ /dev/null
@@ -1,10 +0,0 @@
-#include "bn.h"
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def("syncbn_sum_sqsum", &syncbn_sum_sqsum, "Sum and Sum^2 computation");
- m.def("syncbn_forward", &syncbn_forward, "SyncBN forward computation");
- m.def("syncbn_backward_xhat", &syncbn_backward_xhat,
- "First part of SyncBN backward computation");
- m.def("syncbn_backward", &syncbn_backward,
- "Second part of SyncBN backward computation");
-}
\ No newline at end of file
diff --git a/inference/interact/fbrs/model/syncbn/modules/functional/syncbn.py b/inference/interact/fbrs/model/syncbn/modules/functional/syncbn.py
deleted file mode 100644
index 867a432..0000000
--- a/inference/interact/fbrs/model/syncbn/modules/functional/syncbn.py
+++ /dev/null
@@ -1,137 +0,0 @@
-"""
-/*****************************************************************************/
-
-BatchNorm2dSync with multi-gpu
-
-code referenced from : https://github.com/mapillary/inplace_abn
-
-/*****************************************************************************/
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import torch.cuda.comm as comm
-from torch.autograd import Function
-from torch.autograd.function import once_differentiable
-from ._csrc import _backend
-
-
-def _count_samples(x):
- count = 1
- for i, s in enumerate(x.size()):
- if i != 1:
- count *= s
- return count
-
-
-class BatchNorm2dSyncFunc(Function):
-
- @staticmethod
- def forward(ctx, x, weight, bias, running_mean, running_var,
- extra, compute_stats=True, momentum=0.1, eps=1e-05):
- def _parse_extra(ctx, extra):
- ctx.is_master = extra["is_master"]
- if ctx.is_master:
- ctx.master_queue = extra["master_queue"]
- ctx.worker_queues = extra["worker_queues"]
- ctx.worker_ids = extra["worker_ids"]
- else:
- ctx.master_queue = extra["master_queue"]
- ctx.worker_queue = extra["worker_queue"]
- # Save context
- if extra is not None:
- _parse_extra(ctx, extra)
- ctx.compute_stats = compute_stats
- ctx.momentum = momentum
- ctx.eps = eps
- ctx.affine = weight is not None and bias is not None
- if ctx.compute_stats:
- N = _count_samples(x) * (ctx.master_queue.maxsize + 1)
- assert N > 1
- # 1. compute sum(x) and sum(x^2)
- xsum, xsqsum = _backend.syncbn_sum_sqsum(x.detach())
- if ctx.is_master:
- xsums, xsqsums = [xsum], [xsqsum]
- # master : gatther all sum(x) and sum(x^2) from slaves
- for _ in range(ctx.master_queue.maxsize):
- xsum_w, xsqsum_w = ctx.master_queue.get()
- ctx.master_queue.task_done()
- xsums.append(xsum_w)
- xsqsums.append(xsqsum_w)
- xsum = comm.reduce_add(xsums)
- xsqsum = comm.reduce_add(xsqsums)
- mean = xsum / N
- sumvar = xsqsum - xsum * mean
- var = sumvar / N
- uvar = sumvar / (N - 1)
- # master : broadcast global mean, variance to all slaves
- tensors = comm.broadcast_coalesced(
- (mean, uvar, var), [mean.get_device()] + ctx.worker_ids)
- for ts, queue in zip(tensors[1:], ctx.worker_queues):
- queue.put(ts)
- else:
- # slave : send sum(x) and sum(x^2) to master
- ctx.master_queue.put((xsum, xsqsum))
- # slave : get global mean and variance
- mean, uvar, var = ctx.worker_queue.get()
- ctx.worker_queue.task_done()
-
- # Update running stats
- running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
- running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar)
- ctx.N = N
- ctx.save_for_backward(x, weight, bias, mean, var)
- else:
- mean, var = running_mean, running_var
-
- # do batch norm forward
- z = _backend.syncbn_forward(x, weight, bias, mean, var,
- ctx.affine, ctx.eps)
- return z
-
- @staticmethod
- @once_differentiable
- def backward(ctx, dz):
- x, weight, bias, mean, var = ctx.saved_tensors
- dz = dz.contiguous()
-
- # 1. compute \sum(\frac{dJ}{dy_i}) and \sum(\frac{dJ}{dy_i}*\hat{x_i})
- sum_dz, sum_dz_xhat = _backend.syncbn_backward_xhat(
- dz, x, mean, var, ctx.eps)
- if ctx.is_master:
- sum_dzs, sum_dz_xhats = [sum_dz], [sum_dz_xhat]
- # master : gatther from slaves
- for _ in range(ctx.master_queue.maxsize):
- sum_dz_w, sum_dz_xhat_w = ctx.master_queue.get()
- ctx.master_queue.task_done()
- sum_dzs.append(sum_dz_w)
- sum_dz_xhats.append(sum_dz_xhat_w)
- # master : compute global stats
- sum_dz = comm.reduce_add(sum_dzs)
- sum_dz_xhat = comm.reduce_add(sum_dz_xhats)
- sum_dz /= ctx.N
- sum_dz_xhat /= ctx.N
- # master : broadcast global stats
- tensors = comm.broadcast_coalesced(
- (sum_dz, sum_dz_xhat), [mean.get_device()] + ctx.worker_ids)
- for ts, queue in zip(tensors[1:], ctx.worker_queues):
- queue.put(ts)
- else:
- # slave : send to master
- ctx.master_queue.put((sum_dz, sum_dz_xhat))
- # slave : get global stats
- sum_dz, sum_dz_xhat = ctx.worker_queue.get()
- ctx.worker_queue.task_done()
-
- # do batch norm backward
- dx, dweight, dbias = _backend.syncbn_backward(
- dz, x, weight, bias, mean, var, sum_dz, sum_dz_xhat,
- ctx.affine, ctx.eps)
-
- return dx, dweight, dbias, \
- None, None, None, None, None, None
-
-batchnorm2d_sync = BatchNorm2dSyncFunc.apply
-
-__all__ = ["batchnorm2d_sync"]
diff --git a/inference/interact/fbrs/model/syncbn/modules/nn/__init__.py b/inference/interact/fbrs/model/syncbn/modules/nn/__init__.py
deleted file mode 100644
index 5c5aca9..0000000
--- a/inference/interact/fbrs/model/syncbn/modules/nn/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .syncbn import *
diff --git a/inference/interact/fbrs/model/syncbn/modules/nn/syncbn.py b/inference/interact/fbrs/model/syncbn/modules/nn/syncbn.py
deleted file mode 100644
index b118c9d..0000000
--- a/inference/interact/fbrs/model/syncbn/modules/nn/syncbn.py
+++ /dev/null
@@ -1,148 +0,0 @@
-"""
-/*****************************************************************************/
-
-BatchNorm2dSync with multi-gpu
-
-/*****************************************************************************/
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-try:
- # python 3
- from queue import Queue
-except ImportError:
- # python 2
- from Queue import Queue
-
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from torch.nn.parameter import Parameter
-from isegm.model.syncbn.modules.functional import batchnorm2d_sync
-
-
-class _BatchNorm(nn.Module):
- """
- Customized BatchNorm from nn.BatchNorm
- >> added freeze attribute to enable bn freeze.
- """
-
- def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
- track_running_stats=True):
- super(_BatchNorm, self).__init__()
- self.num_features = num_features
- self.eps = eps
- self.momentum = momentum
- self.affine = affine
- self.track_running_stats = track_running_stats
- self.freezed = False
- if self.affine:
- self.weight = Parameter(torch.Tensor(num_features))
- self.bias = Parameter(torch.Tensor(num_features))
- else:
- self.register_parameter('weight', None)
- self.register_parameter('bias', None)
- if self.track_running_stats:
- self.register_buffer('running_mean', torch.zeros(num_features))
- self.register_buffer('running_var', torch.ones(num_features))
- else:
- self.register_parameter('running_mean', None)
- self.register_parameter('running_var', None)
- self.reset_parameters()
-
- def reset_parameters(self):
- if self.track_running_stats:
- self.running_mean.zero_()
- self.running_var.fill_(1)
- if self.affine:
- self.weight.data.uniform_()
- self.bias.data.zero_()
-
- def _check_input_dim(self, input):
- return NotImplemented
-
- def forward(self, input):
- self._check_input_dim(input)
-
- compute_stats = not self.freezed and \
- self.training and self.track_running_stats
-
- ret = F.batch_norm(input, self.running_mean, self.running_var,
- self.weight, self.bias, compute_stats,
- self.momentum, self.eps)
- return ret
-
- def extra_repr(self):
- return '{num_features}, eps={eps}, momentum={momentum}, '\
- 'affine={affine}, ' \
- 'track_running_stats={track_running_stats}'.format(
- **self.__dict__)
-
-
-class BatchNorm2dNoSync(_BatchNorm):
- """
- Equivalent to nn.BatchNorm2d
- """
-
- def _check_input_dim(self, input):
- if input.dim() != 4:
- raise ValueError('expected 4D input (got {}D input)'
- .format(input.dim()))
-
-
-class BatchNorm2dSync(BatchNorm2dNoSync):
- """
- BatchNorm2d with automatic multi-GPU Sync
- """
-
- def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
- track_running_stats=True):
- super(BatchNorm2dSync, self).__init__(
- num_features, eps=eps, momentum=momentum, affine=affine,
- track_running_stats=track_running_stats)
- self.sync_enabled = True
- self.devices = list(range(torch.cuda.device_count()))
- if len(self.devices) > 1:
- # Initialize queues
- self.worker_ids = self.devices[1:]
- self.master_queue = Queue(len(self.worker_ids))
- self.worker_queues = [Queue(1) for _ in self.worker_ids]
-
- def forward(self, x):
- compute_stats = not self.freezed and \
- self.training and self.track_running_stats
- if self.sync_enabled and compute_stats and len(self.devices) > 1:
- if x.get_device() == self.devices[0]:
- # Master mode
- extra = {
- "is_master": True,
- "master_queue": self.master_queue,
- "worker_queues": self.worker_queues,
- "worker_ids": self.worker_ids
- }
- else:
- # Worker mode
- extra = {
- "is_master": False,
- "master_queue": self.master_queue,
- "worker_queue": self.worker_queues[
- self.worker_ids.index(x.get_device())]
- }
- return batchnorm2d_sync(x, self.weight, self.bias,
- self.running_mean, self.running_var,
- extra, compute_stats, self.momentum,
- self.eps)
- return super(BatchNorm2dSync, self).forward(x)
-
- def __repr__(self):
- """repr"""
- rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
- 'affine={affine}, ' \
- 'track_running_stats={track_running_stats},' \
- 'devices={devices})'
- return rep.format(name=self.__class__.__name__, **self.__dict__)
-
-#BatchNorm2d = BatchNorm2dNoSync
-BatchNorm2d = BatchNorm2dSync
diff --git a/inference/interact/fbrs/utils/__init__.py b/inference/interact/fbrs/utils/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/inference/interact/fbrs/utils/cython/__init__.py b/inference/interact/fbrs/utils/cython/__init__.py
deleted file mode 100644
index eb66bdb..0000000
--- a/inference/interact/fbrs/utils/cython/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# noinspection PyUnresolvedReferences
-from .dist_maps import get_dist_maps
\ No newline at end of file
diff --git a/inference/interact/fbrs/utils/cython/_get_dist_maps.pyx b/inference/interact/fbrs/utils/cython/_get_dist_maps.pyx
deleted file mode 100644
index 779a7f0..0000000
--- a/inference/interact/fbrs/utils/cython/_get_dist_maps.pyx
+++ /dev/null
@@ -1,63 +0,0 @@
-import numpy as np
-cimport cython
-cimport numpy as np
-from libc.stdlib cimport malloc, free
-
-ctypedef struct qnode:
- int row
- int col
- int layer
- int orig_row
- int orig_col
-
-@cython.infer_types(True)
-@cython.boundscheck(False)
-@cython.wraparound(False)
-@cython.nonecheck(False)
-def get_dist_maps(np.ndarray[np.float32_t, ndim=2, mode="c"] points,
- int height, int width, float norm_delimeter):
- cdef np.ndarray[np.float32_t, ndim=3, mode="c"] dist_maps = \
- np.full((2, height, width), 1e6, dtype=np.float32, order="C")
-
- cdef int *dxy = [-1, 0, 0, -1, 0, 1, 1, 0]
- cdef int i, j, x, y, dx, dy
- cdef qnode v
- cdef qnode *q = malloc((4 * height * width + 1) * sizeof(qnode))
- cdef int qhead = 0, qtail = -1
- cdef float ndist
-
- for i in range(points.shape[0]):
- x, y = round(points[i, 0]), round(points[i, 1])
- if x >= 0:
- qtail += 1
- q[qtail].row = x
- q[qtail].col = y
- q[qtail].orig_row = x
- q[qtail].orig_col = y
- if i >= points.shape[0] / 2:
- q[qtail].layer = 1
- else:
- q[qtail].layer = 0
- dist_maps[q[qtail].layer, x, y] = 0
-
- while qtail - qhead + 1 > 0:
- v = q[qhead]
- qhead += 1
-
- for k in range(4):
- x = v.row + dxy[2 * k]
- y = v.col + dxy[2 * k + 1]
-
- ndist = ((x - v.orig_row)/norm_delimeter) ** 2 + ((y - v.orig_col)/norm_delimeter) ** 2
- if (x >= 0 and y >= 0 and x < height and y < width and
- dist_maps[v.layer, x, y] > ndist):
- qtail += 1
- q[qtail].orig_col = v.orig_col
- q[qtail].orig_row = v.orig_row
- q[qtail].layer = v.layer
- q[qtail].row = x
- q[qtail].col = y
- dist_maps[v.layer, x, y] = ndist
-
- free(q)
- return dist_maps
diff --git a/inference/interact/fbrs/utils/cython/_get_dist_maps.pyxbld b/inference/interact/fbrs/utils/cython/_get_dist_maps.pyxbld
deleted file mode 100644
index bd44517..0000000
--- a/inference/interact/fbrs/utils/cython/_get_dist_maps.pyxbld
+++ /dev/null
@@ -1,7 +0,0 @@
-import numpy
-
-def make_ext(modname, pyxfilename):
- from distutils.extension import Extension
- return Extension(modname, [pyxfilename],
- include_dirs=[numpy.get_include()],
- extra_compile_args=['-O3'], language='c++')
diff --git a/inference/interact/fbrs/utils/cython/dist_maps.py b/inference/interact/fbrs/utils/cython/dist_maps.py
deleted file mode 100644
index 8ffa1e3..0000000
--- a/inference/interact/fbrs/utils/cython/dist_maps.py
+++ /dev/null
@@ -1,3 +0,0 @@
-import pyximport; pyximport.install(pyximport=True, language_level=3)
-# noinspection PyUnresolvedReferences
-from ._get_dist_maps import get_dist_maps
\ No newline at end of file
diff --git a/inference/interact/fbrs/utils/misc.py b/inference/interact/fbrs/utils/misc.py
deleted file mode 100644
index 65ce96d..0000000
--- a/inference/interact/fbrs/utils/misc.py
+++ /dev/null
@@ -1,62 +0,0 @@
-from functools import partial
-
-import torch
-import numpy as np
-
-
-def get_dims_with_exclusion(dim, exclude=None):
- dims = list(range(dim))
- if exclude is not None:
- dims.remove(exclude)
-
- return dims
-
-
-def get_unique_labels(mask):
- return np.nonzero(np.bincount(mask.flatten() + 1))[0] - 1
-
-
-def get_bbox_from_mask(mask):
- rows = np.any(mask, axis=1)
- cols = np.any(mask, axis=0)
- rmin, rmax = np.where(rows)[0][[0, -1]]
- cmin, cmax = np.where(cols)[0][[0, -1]]
-
- return rmin, rmax, cmin, cmax
-
-
-def expand_bbox(bbox, expand_ratio, min_crop_size=None):
- rmin, rmax, cmin, cmax = bbox
- rcenter = 0.5 * (rmin + rmax)
- ccenter = 0.5 * (cmin + cmax)
- height = expand_ratio * (rmax - rmin + 1)
- width = expand_ratio * (cmax - cmin + 1)
- if min_crop_size is not None:
- height = max(height, min_crop_size)
- width = max(width, min_crop_size)
-
- rmin = int(round(rcenter - 0.5 * height))
- rmax = int(round(rcenter + 0.5 * height))
- cmin = int(round(ccenter - 0.5 * width))
- cmax = int(round(ccenter + 0.5 * width))
-
- return rmin, rmax, cmin, cmax
-
-
-def clamp_bbox(bbox, rmin, rmax, cmin, cmax):
- return (max(rmin, bbox[0]), min(rmax, bbox[1]),
- max(cmin, bbox[2]), min(cmax, bbox[3]))
-
-
-def get_bbox_iou(b1, b2):
- h_iou = get_segments_iou(b1[:2], b2[:2])
- w_iou = get_segments_iou(b1[2:4], b2[2:4])
- return h_iou * w_iou
-
-
-def get_segments_iou(s1, s2):
- a, b = s1
- c, d = s2
- intersection = max(0, min(b, d) - max(a, c) + 1)
- union = max(1e-6, max(b, d) - min(a, c) + 1)
- return intersection / union
diff --git a/inference/interact/fbrs/utils/vis.py b/inference/interact/fbrs/utils/vis.py
deleted file mode 100644
index 4c1a291..0000000
--- a/inference/interact/fbrs/utils/vis.py
+++ /dev/null
@@ -1,129 +0,0 @@
-from functools import lru_cache
-
-import cv2
-import numpy as np
-
-
-def visualize_instances(imask, bg_color=255,
- boundaries_color=None, boundaries_width=1, boundaries_alpha=0.8):
- num_objects = imask.max() + 1
- palette = get_palette(num_objects)
- if bg_color is not None:
- palette[0] = bg_color
-
- result = palette[imask].astype(np.uint8)
- if boundaries_color is not None:
- boundaries_mask = get_boundaries(imask, boundaries_width=boundaries_width)
- tresult = result.astype(np.float32)
- tresult[boundaries_mask] = boundaries_color
- tresult = tresult * boundaries_alpha + (1 - boundaries_alpha) * result
- result = tresult.astype(np.uint8)
-
- return result
-
-
-@lru_cache(maxsize=16)
-def get_palette(num_cls):
- palette = np.zeros(3 * num_cls, dtype=np.int32)
-
- for j in range(0, num_cls):
- lab = j
- i = 0
-
- while lab > 0:
- palette[j*3 + 0] |= (((lab >> 0) & 1) << (7-i))
- palette[j*3 + 1] |= (((lab >> 1) & 1) << (7-i))
- palette[j*3 + 2] |= (((lab >> 2) & 1) << (7-i))
- i = i + 1
- lab >>= 3
-
- return palette.reshape((-1, 3))
-
-
-def visualize_mask(mask, num_cls):
- palette = get_palette(num_cls)
- mask[mask == -1] = 0
-
- return palette[mask].astype(np.uint8)
-
-
-def visualize_proposals(proposals_info, point_color=(255, 0, 0), point_radius=1):
- proposal_map, colors, candidates = proposals_info
-
- proposal_map = draw_probmap(proposal_map)
- for x, y in candidates:
- proposal_map = cv2.circle(proposal_map, (y, x), point_radius, point_color, -1)
-
- return proposal_map
-
-
-def draw_probmap(x):
- return cv2.applyColorMap((x * 255).astype(np.uint8), cv2.COLORMAP_HOT)
-
-
-def draw_points(image, points, color, radius=3):
- image = image.copy()
- for p in points:
- image = cv2.circle(image, (int(p[1]), int(p[0])), radius, color, -1)
-
- return image
-
-
-def draw_instance_map(x, palette=None):
- num_colors = x.max() + 1
- if palette is None:
- palette = get_palette(num_colors)
-
- return palette[x].astype(np.uint8)
-
-
-def blend_mask(image, mask, alpha=0.6):
- if mask.min() == -1:
- mask = mask.copy() + 1
-
- imap = draw_instance_map(mask)
- result = (image * (1 - alpha) + alpha * imap).astype(np.uint8)
- return result
-
-
-def get_boundaries(instances_masks, boundaries_width=1):
- boundaries = np.zeros((instances_masks.shape[0], instances_masks.shape[1]), dtype=np.bool)
-
- for obj_id in np.unique(instances_masks.flatten()):
- if obj_id == 0:
- continue
-
- obj_mask = instances_masks == obj_id
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
- inner_mask = cv2.erode(obj_mask.astype(np.uint8), kernel, iterations=boundaries_width).astype(np.bool)
-
- obj_boundary = np.logical_xor(obj_mask, np.logical_and(inner_mask, obj_mask))
- boundaries = np.logical_or(boundaries, obj_boundary)
- return boundaries
-
-
-def draw_with_blend_and_clicks(img, mask=None, alpha=0.6, clicks_list=None, pos_color=(0, 255, 0),
- neg_color=(255, 0, 0), radius=4):
- result = img.copy()
-
- if mask is not None:
- palette = get_palette(np.max(mask) + 1)
- rgb_mask = palette[mask.astype(np.uint8)]
-
- mask_region = (mask > 0).astype(np.uint8)
- result = result * (1 - mask_region[:, :, np.newaxis]) + \
- (1 - alpha) * mask_region[:, :, np.newaxis] * result + \
- alpha * rgb_mask
- result = result.astype(np.uint8)
-
- # result = (result * (1 - alpha) + alpha * rgb_mask).astype(np.uint8)
-
- if clicks_list is not None and len(clicks_list) > 0:
- pos_points = [click.coords for click in clicks_list if click.is_positive]
- neg_points = [click.coords for click in clicks_list if not click.is_positive]
-
- result = draw_points(result, pos_points, pos_color, radius=radius)
- result = draw_points(result, neg_points, neg_color, radius=radius)
-
- return result
-
diff --git a/inference/interact/fbrs_controller.py b/inference/interact/fbrs_controller.py
deleted file mode 100644
index 1fe9ca4..0000000
--- a/inference/interact/fbrs_controller.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import torch
-from .fbrs.controller import InteractiveController
-from .fbrs.inference import utils
-
-
-class FBRSController:
- def __init__(self, checkpoint_path, device='cuda:0', max_size=800):
- model = utils.load_is_model(checkpoint_path, device, cpu_dist_maps=True, norm_radius=260)
-
- # Predictor params
- zoomin_params = {
- 'skip_clicks': 1,
- 'target_size': 480,
- 'expansion_ratio': 1.4,
- }
-
- predictor_params = {
- 'brs_mode': 'f-BRS-B',
- 'prob_thresh': 0.5,
- 'zoom_in_params': zoomin_params,
- 'predictor_params': {
- 'net_clicks_limit': 8,
- 'max_size': 800,
- },
- 'brs_opt_func_params': {'min_iou_diff': 1e-3},
- 'lbfgs_params': {'maxfun': 20}
- }
-
- self.controller = InteractiveController(model, device, predictor_params)
- self.anchored = False
- self.device = device
-
- def unanchor(self):
- self.anchored = False
-
- def interact(self, image, x, y, is_positive):
- image = image.to(self.device, non_blocking=True)
- if not self.anchored:
- self.controller.set_image(image)
- self.controller.reset_predictor()
- self.anchored = True
-
- self.controller.add_click(x, y, is_positive)
- # return self.controller.result_mask
- # return self.controller.probs_history[-1][1]
- return (self.controller.probs_history[-1][1]>0.5).float()
-
- def undo(self):
- self.controller.undo_click()
- if len(self.controller.probs_history) == 0:
- return None
- else:
- return (self.controller.probs_history[-1][1]>0.5).float()
\ No newline at end of file
diff --git a/inference/interact/gui.py b/inference/interact/gui.py
deleted file mode 100644
index 039a382..0000000
--- a/inference/interact/gui.py
+++ /dev/null
@@ -1,933 +0,0 @@
-"""
-Based on https://github.com/hkchengrex/MiVOS/tree/MiVOS-STCN
-(which is based on https://github.com/seoungwugoh/ivs-demo)
-
-This version is much simplified.
-In this repo, we don't have
-- local control
-- fusion module
-- undo
-- timers
-
-but with XMem as the backbone and is more memory (for both CPU and GPU) friendly
-"""
-
-import functools
-
-import os
-import cv2
-# fix conflicts between qt5 and cv2
-os.environ.pop("QT_QPA_PLATFORM_PLUGIN_PATH")
-
-import numpy as np
-import torch
-
-from PyQt5.QtWidgets import (QWidget, QApplication, QComboBox, QCheckBox,
- QHBoxLayout, QLabel, QPushButton, QTextEdit, QSpinBox, QFileDialog,
- QPlainTextEdit, QVBoxLayout, QSizePolicy, QButtonGroup, QSlider, QShortcut, QRadioButton)
-
-from PyQt5.QtGui import QPixmap, QKeySequence, QImage, QTextCursor, QIcon
-from PyQt5.QtCore import Qt, QTimer
-
-from model.network import XMem
-
-from inference.inference_core import InferenceCore
-from .s2m_controller import S2MController
-from .fbrs_controller import FBRSController
-
-from .interactive_utils import *
-from .interaction import *
-from .resource_manager import ResourceManager
-from .gui_utils import *
-
-
-class App(QWidget):
- def __init__(self, net: XMem,
- resource_manager: ResourceManager,
- s2m_ctrl:S2MController,
- fbrs_ctrl:FBRSController, config):
- super().__init__()
-
- self.initialized = False
- self.num_objects = config['num_objects']
- self.s2m_controller = s2m_ctrl
- self.fbrs_controller = fbrs_ctrl
- self.config = config
- self.processor = InferenceCore(net, config)
- self.processor.set_all_labels(list(range(1, self.num_objects+1)))
- self.res_man = resource_manager
-
- self.num_frames = len(self.res_man)
- self.height, self.width = self.res_man.h, self.res_man.w
-
- # set window
- self.setWindowTitle('XMem Demo')
- self.setGeometry(100, 100, self.width, self.height+100)
- self.setWindowIcon(QIcon('docs/icon.png'))
-
- # some buttons
- self.play_button = QPushButton('Play Video')
- self.play_button.clicked.connect(self.on_play_video)
- self.commit_button = QPushButton('Commit')
- self.commit_button.clicked.connect(self.on_commit)
-
- self.forward_run_button = QPushButton('Forward Propagate')
- self.forward_run_button.clicked.connect(self.on_forward_propagation)
- self.forward_run_button.setMinimumWidth(200)
-
- self.backward_run_button = QPushButton('Backward Propagate')
- self.backward_run_button.clicked.connect(self.on_backward_propagation)
- self.backward_run_button.setMinimumWidth(200)
-
- self.reset_button = QPushButton('Reset Frame')
- self.reset_button.clicked.connect(self.on_reset_mask)
-
- # LCD
- self.lcd = QTextEdit()
- self.lcd.setReadOnly(True)
- self.lcd.setMaximumHeight(28)
- self.lcd.setMaximumWidth(120)
- self.lcd.setText('{: 4d} / {: 4d}'.format(0, self.num_frames-1))
-
- # timeline slider
- self.tl_slider = QSlider(Qt.Horizontal)
- self.tl_slider.valueChanged.connect(self.tl_slide)
- self.tl_slider.setMinimum(0)
- self.tl_slider.setMaximum(self.num_frames-1)
- self.tl_slider.setValue(0)
- self.tl_slider.setTickPosition(QSlider.TicksBelow)
- self.tl_slider.setTickInterval(1)
-
- # brush size slider
- self.brush_label = QLabel()
- self.brush_label.setAlignment(Qt.AlignCenter)
- self.brush_label.setMinimumWidth(100)
-
- self.brush_slider = QSlider(Qt.Horizontal)
- self.brush_slider.valueChanged.connect(self.brush_slide)
- self.brush_slider.setMinimum(1)
- self.brush_slider.setMaximum(100)
- self.brush_slider.setValue(3)
- self.brush_slider.setTickPosition(QSlider.TicksBelow)
- self.brush_slider.setTickInterval(2)
- self.brush_slider.setMinimumWidth(300)
-
- # combobox
- self.combo = QComboBox(self)
- self.combo.addItem("davis")
- self.combo.addItem("fade")
- self.combo.addItem("light")
- self.combo.addItem("popup")
- self.combo.addItem("layered")
- self.combo.currentTextChanged.connect(self.set_viz_mode)
-
- self.save_visualization_checkbox = QCheckBox(self)
- self.save_visualization_checkbox.toggled.connect(self.on_save_visualization_toggle)
- self.save_visualization_checkbox.setChecked(False)
- self.save_visualization = False
-
- # Radio buttons for type of interactions
- self.curr_interaction = 'Click'
- self.interaction_group = QButtonGroup()
- self.radio_fbrs = QRadioButton('Click')
- self.radio_s2m = QRadioButton('Scribble')
- self.radio_free = QRadioButton('Free')
- self.interaction_group.addButton(self.radio_fbrs)
- self.interaction_group.addButton(self.radio_s2m)
- self.interaction_group.addButton(self.radio_free)
- self.radio_fbrs.toggled.connect(self.interaction_radio_clicked)
- self.radio_s2m.toggled.connect(self.interaction_radio_clicked)
- self.radio_free.toggled.connect(self.interaction_radio_clicked)
- self.radio_fbrs.toggle()
-
- # Main canvas -> QLabel
- self.main_canvas = QLabel()
- self.main_canvas.setSizePolicy(QSizePolicy.Expanding,
- QSizePolicy.Expanding)
- self.main_canvas.setAlignment(Qt.AlignCenter)
- self.main_canvas.setMinimumSize(100, 100)
-
- self.main_canvas.mousePressEvent = self.on_mouse_press
- self.main_canvas.mouseMoveEvent = self.on_mouse_motion
- self.main_canvas.setMouseTracking(True) # Required for all-time tracking
- self.main_canvas.mouseReleaseEvent = self.on_mouse_release
-
- # Minimap -> Also a QLbal
- self.minimap = QLabel()
- self.minimap.setSizePolicy(QSizePolicy.Expanding,
- QSizePolicy.Expanding)
- self.minimap.setAlignment(Qt.AlignTop)
- self.minimap.setMinimumSize(100, 100)
-
- # Zoom-in buttons
- self.zoom_p_button = QPushButton('Zoom +')
- self.zoom_p_button.clicked.connect(self.on_zoom_plus)
- self.zoom_m_button = QPushButton('Zoom -')
- self.zoom_m_button.clicked.connect(self.on_zoom_minus)
-
- # Parameters setting
- self.clear_mem_button = QPushButton('Clear memory')
- self.clear_mem_button.clicked.connect(self.on_clear_memory)
-
- self.work_mem_gauge, self.work_mem_gauge_layout = create_gauge('Working memory size')
- self.long_mem_gauge, self.long_mem_gauge_layout = create_gauge('Long-term memory size')
- self.gpu_mem_gauge, self.gpu_mem_gauge_layout = create_gauge('GPU mem. (all processes, w/ caching)')
- self.torch_mem_gauge, self.torch_mem_gauge_layout = create_gauge('GPU mem. (used by torch, w/o caching)')
-
- self.update_memory_size()
- self.update_gpu_usage()
-
- self.work_mem_min, self.work_mem_min_layout = create_parameter_box(1, 100, 'Min. working memory frames',
- callback=self.on_work_min_change)
- self.work_mem_max, self.work_mem_max_layout = create_parameter_box(2, 100, 'Max. working memory frames',
- callback=self.on_work_max_change)
- self.long_mem_max, self.long_mem_max_layout = create_parameter_box(1000, 100000,
- 'Max. long-term memory size', step=1000, callback=self.update_config)
- self.num_prototypes_box, self.num_prototypes_box_layout = create_parameter_box(32, 1280,
- 'Number of prototypes', step=32, callback=self.update_config)
- self.mem_every_box, self.mem_every_box_layout = create_parameter_box(1, 100, 'Memory frame every (r)',
- callback=self.update_config)
-
- self.work_mem_min.setValue(self.processor.memory.min_mt_frames)
- self.work_mem_max.setValue(self.processor.memory.max_mt_frames)
- self.long_mem_max.setValue(self.processor.memory.max_long_elements)
- self.num_prototypes_box.setValue(self.processor.memory.num_prototypes)
- self.mem_every_box.setValue(self.processor.mem_every)
-
- # import mask/layer
- self.import_mask_button = QPushButton('Import mask')
- self.import_mask_button.clicked.connect(self.on_import_mask)
- self.import_layer_button = QPushButton('Import layer')
- self.import_layer_button.clicked.connect(self.on_import_layer)
-
- # Console on the GUI
- self.console = QPlainTextEdit()
- self.console.setReadOnly(True)
- self.console.setMinimumHeight(100)
- self.console.setMaximumHeight(100)
-
- # navigator
- navi = QHBoxLayout()
- navi.addWidget(self.lcd)
- navi.addWidget(self.play_button)
-
- interact_subbox = QVBoxLayout()
- interact_topbox = QHBoxLayout()
- interact_botbox = QHBoxLayout()
- interact_topbox.setAlignment(Qt.AlignCenter)
- interact_topbox.addWidget(self.radio_s2m)
- interact_topbox.addWidget(self.radio_fbrs)
- interact_topbox.addWidget(self.radio_free)
- interact_topbox.addWidget(self.brush_label)
- interact_botbox.addWidget(self.brush_slider)
- interact_subbox.addLayout(interact_topbox)
- interact_subbox.addLayout(interact_botbox)
- navi.addLayout(interact_subbox)
-
- navi.addStretch(1)
- navi.addWidget(self.reset_button)
-
- navi.addStretch(1)
- navi.addWidget(QLabel('Overlay Mode'))
- navi.addWidget(self.combo)
- navi.addWidget(QLabel('Save overlay during propagation'))
- navi.addWidget(self.save_visualization_checkbox)
- navi.addStretch(1)
- navi.addWidget(self.commit_button)
- navi.addWidget(self.forward_run_button)
- navi.addWidget(self.backward_run_button)
-
- # Drawing area, main canvas and minimap
- draw_area = QHBoxLayout()
- draw_area.addWidget(self.main_canvas, 4)
-
- # Minimap area
- minimap_area = QVBoxLayout()
- minimap_area.setAlignment(Qt.AlignTop)
- mini_label = QLabel('Minimap')
- mini_label.setAlignment(Qt.AlignTop)
- minimap_area.addWidget(mini_label)
-
- # Minimap zooming
- minimap_ctrl = QHBoxLayout()
- minimap_ctrl.setAlignment(Qt.AlignTop)
- minimap_ctrl.addWidget(self.zoom_p_button)
- minimap_ctrl.addWidget(self.zoom_m_button)
- minimap_area.addLayout(minimap_ctrl)
- minimap_area.addWidget(self.minimap)
-
- # Parameters
- minimap_area.addLayout(self.work_mem_gauge_layout)
- minimap_area.addLayout(self.long_mem_gauge_layout)
- minimap_area.addLayout(self.gpu_mem_gauge_layout)
- minimap_area.addLayout(self.torch_mem_gauge_layout)
- minimap_area.addWidget(self.clear_mem_button)
- minimap_area.addLayout(self.work_mem_min_layout)
- minimap_area.addLayout(self.work_mem_max_layout)
- minimap_area.addLayout(self.long_mem_max_layout)
- minimap_area.addLayout(self.num_prototypes_box_layout)
- minimap_area.addLayout(self.mem_every_box_layout)
-
- # import mask/layer
- import_area = QHBoxLayout()
- import_area.setAlignment(Qt.AlignTop)
- import_area.addWidget(self.import_mask_button)
- import_area.addWidget(self.import_layer_button)
- minimap_area.addLayout(import_area)
-
- # console
- minimap_area.addWidget(self.console)
-
- draw_area.addLayout(minimap_area, 1)
-
- layout = QVBoxLayout()
- layout.addLayout(draw_area)
- layout.addWidget(self.tl_slider)
- layout.addLayout(navi)
- self.setLayout(layout)
-
- # timer to play video
- self.timer = QTimer()
- self.timer.setSingleShot(False)
-
- # timer to update GPU usage
- self.gpu_timer = QTimer()
- self.gpu_timer.setSingleShot(False)
- self.gpu_timer.timeout.connect(self.on_gpu_timer)
- self.gpu_timer.setInterval(2000)
- self.gpu_timer.start()
-
- # current frame info
- self.curr_frame_dirty = False
- self.current_image = np.zeros((self.height, self.width, 3), dtype=np.uint8)
- self.current_image_torch = None
- self.current_mask = np.zeros((self.height, self.width), dtype=np.uint8)
- self.current_prob = torch.zeros((self.num_objects, self.height, self.width), dtype=torch.float).cuda()
-
- # initialize visualization
- self.viz_mode = 'davis'
- self.vis_map = np.zeros((self.height, self.width, 3), dtype=np.uint8)
- self.vis_alpha = np.zeros((self.height, self.width, 1), dtype=np.float32)
- self.brush_vis_map = np.zeros((self.height, self.width, 3), dtype=np.uint8)
- self.brush_vis_alpha = np.zeros((self.height, self.width, 1), dtype=np.float32)
- self.cursur = 0
- self.on_showing = None
-
- # Zoom parameters
- self.zoom_pixels = 150
-
- # initialize action
- self.interaction = None
- self.pressed = False
- self.right_click = False
- self.current_object = 1
- self.last_ex = self.last_ey = 0
-
- self.propagating = False
-
- # Objects shortcuts
- for i in range(1, self.num_objects+1):
- QShortcut(QKeySequence(str(i)), self).activated.connect(functools.partial(self.hit_number_key, i))
-
- # <- and -> shortcuts
- QShortcut(QKeySequence(Qt.Key_Left), self).activated.connect(self.on_prev_frame)
- QShortcut(QKeySequence(Qt.Key_Right), self).activated.connect(self.on_next_frame)
-
- self.interacted_prob = None
- self.overlay_layer = None
- self.overlay_layer_torch = None
-
- # the object id used for popup/layered overlay
- self.vis_target_objects = [1]
- # try to load the default overlay
- self._try_load_layer('./docs/ECCV-logo.png')
-
- self.load_current_image_mask()
- self.show_current_frame()
- self.show()
-
- self.console_push_text('Initialized.')
- self.initialized = True
-
- def resizeEvent(self, event):
- self.show_current_frame()
-
- def console_push_text(self, text):
- self.console.moveCursor(QTextCursor.End)
- self.console.insertPlainText(text+'\n')
-
- def interaction_radio_clicked(self, event):
- self.last_interaction = self.curr_interaction
- if self.radio_s2m.isChecked():
- self.curr_interaction = 'Scribble'
- self.brush_size = 3
- self.brush_slider.setDisabled(True)
- elif self.radio_fbrs.isChecked():
- self.curr_interaction = 'Click'
- self.brush_size = 3
- self.brush_slider.setDisabled(True)
- elif self.radio_free.isChecked():
- self.brush_slider.setDisabled(False)
- self.brush_slide()
- self.curr_interaction = 'Free'
- if self.curr_interaction == 'Scribble':
- self.commit_button.setEnabled(True)
- else:
- self.commit_button.setEnabled(False)
-
- def load_current_image_mask(self, no_mask=False):
- self.current_image = self.res_man.get_image(self.cursur)
- self.current_image_torch = None
-
- if not no_mask:
- loaded_mask = self.res_man.get_mask(self.cursur)
- if loaded_mask is None:
- self.current_mask.fill(0)
- else:
- self.current_mask = loaded_mask.copy()
- self.current_prob = None
-
- def load_current_torch_image_mask(self, no_mask=False):
- if self.current_image_torch is None:
- self.current_image_torch, self.current_image_torch_no_norm = image_to_torch(self.current_image)
-
- if self.current_prob is None and not no_mask:
- self.current_prob = index_numpy_to_one_hot_torch(self.current_mask, self.num_objects+1).cuda()
-
- def compose_current_im(self):
- self.viz = get_visualization(self.viz_mode, self.current_image, self.current_mask,
- self.overlay_layer, self.vis_target_objects)
-
- def update_interact_vis(self):
- # Update the interactions without re-computing the overlay
- height, width, channel = self.viz.shape
- bytesPerLine = 3 * width
-
- vis_map = self.vis_map
- vis_alpha = self.vis_alpha
- brush_vis_map = self.brush_vis_map
- brush_vis_alpha = self.brush_vis_alpha
-
- self.viz_with_stroke = self.viz*(1-vis_alpha) + vis_map*vis_alpha
- self.viz_with_stroke = self.viz_with_stroke*(1-brush_vis_alpha) + brush_vis_map*brush_vis_alpha
- self.viz_with_stroke = self.viz_with_stroke.astype(np.uint8)
-
- qImg = QImage(self.viz_with_stroke.data, width, height, bytesPerLine, QImage.Format_RGB888)
- self.main_canvas.setPixmap(QPixmap(qImg.scaled(self.main_canvas.size(),
- Qt.KeepAspectRatio, Qt.FastTransformation)))
-
- self.main_canvas_size = self.main_canvas.size()
- self.image_size = qImg.size()
-
- def update_minimap(self):
- ex, ey = self.last_ex, self.last_ey
- r = self.zoom_pixels//2
- ex = int(round(max(r, min(self.width-r, ex))))
- ey = int(round(max(r, min(self.height-r, ey))))
-
- patch = self.viz_with_stroke[ey-r:ey+r, ex-r:ex+r, :].astype(np.uint8)
-
- height, width, channel = patch.shape
- bytesPerLine = 3 * width
- qImg = QImage(patch.data, width, height, bytesPerLine, QImage.Format_RGB888)
- self.minimap.setPixmap(QPixmap(qImg.scaled(self.minimap.size(),
- Qt.KeepAspectRatio, Qt.FastTransformation)))
-
- def update_current_image_fast(self):
- # fast path, uses gpu. Changes the image in-place to avoid copying
- self.viz = get_visualization_torch(self.viz_mode, self.current_image_torch_no_norm,
- self.current_prob, self.overlay_layer_torch, self.vis_target_objects)
- if self.save_visualization:
- self.res_man.save_visualization(self.cursur, self.viz)
-
- height, width, channel = self.viz.shape
- bytesPerLine = 3 * width
-
- qImg = QImage(self.viz.data, width, height, bytesPerLine, QImage.Format_RGB888)
- self.main_canvas.setPixmap(QPixmap(qImg.scaled(self.main_canvas.size(),
- Qt.KeepAspectRatio, Qt.FastTransformation)))
-
- def show_current_frame(self, fast=False):
- # Re-compute overlay and show the image
- if fast:
- self.update_current_image_fast()
- else:
- self.compose_current_im()
- self.update_interact_vis()
- self.update_minimap()
-
- self.lcd.setText('{: 3d} / {: 3d}'.format(self.cursur, self.num_frames-1))
- self.tl_slider.setValue(self.cursur)
-
- def pixel_pos_to_image_pos(self, x, y):
- # Un-scale and un-pad the label coordinates into image coordinates
- oh, ow = self.image_size.height(), self.image_size.width()
- nh, nw = self.main_canvas_size.height(), self.main_canvas_size.width()
-
- h_ratio = nh/oh
- w_ratio = nw/ow
- dominate_ratio = min(h_ratio, w_ratio)
-
- # Solve scale
- x /= dominate_ratio
- y /= dominate_ratio
-
- # Solve padding
- fh, fw = nh/dominate_ratio, nw/dominate_ratio
- x -= (fw-ow)/2
- y -= (fh-oh)/2
-
- return x, y
-
- def is_pos_out_of_bound(self, x, y):
- x, y = self.pixel_pos_to_image_pos(x, y)
-
- out_of_bound = (
- (x < 0) or
- (y < 0) or
- (x > self.width-1) or
- (y > self.height-1)
- )
-
- return out_of_bound
-
- def get_scaled_pos(self, x, y):
- x, y = self.pixel_pos_to_image_pos(x, y)
-
- x = max(0, min(self.width-1, x))
- y = max(0, min(self.height-1, y))
-
- return x, y
-
- def clear_visualization(self):
- self.vis_map.fill(0)
- self.vis_alpha.fill(0)
-
- def reset_this_interaction(self):
- self.complete_interaction()
- self.clear_visualization()
- self.interaction = None
- if self.fbrs_controller is not None:
- self.fbrs_controller.unanchor()
-
- def set_viz_mode(self):
- self.viz_mode = self.combo.currentText()
- self.show_current_frame()
-
- def save_current_mask(self):
- # save mask to hard disk
- self.res_man.save_mask(self.cursur, self.current_mask)
-
- def tl_slide(self):
- # if we are propagating, the on_run function will take care of everything
- # don't do duplicate work here
- if not self.propagating:
- if self.curr_frame_dirty:
- self.save_current_mask()
- self.curr_frame_dirty = False
-
- self.reset_this_interaction()
- self.cursur = self.tl_slider.value()
- self.load_current_image_mask()
- self.show_current_frame()
-
- def brush_slide(self):
- self.brush_size = self.brush_slider.value()
- self.brush_label.setText('Brush size: %d' % self.brush_size)
- try:
- if type(self.interaction) == FreeInteraction:
- self.interaction.set_size(self.brush_size)
- except AttributeError:
- # Initialization, forget about it
- pass
-
- def on_forward_propagation(self):
- if self.propagating:
- # acts as a pause button
- self.propagating = False
- else:
- self.propagate_fn = self.on_next_frame
- self.backward_run_button.setEnabled(False)
- self.forward_run_button.setText('Pause Propagation')
- self.on_propagation()
-
- def on_backward_propagation(self):
- if self.propagating:
- # acts as a pause button
- self.propagating = False
- else:
- self.propagate_fn = self.on_prev_frame
- self.forward_run_button.setEnabled(False)
- self.backward_run_button.setText('Pause Propagation')
- self.on_propagation()
-
- def on_pause(self):
- self.propagating = False
- self.forward_run_button.setEnabled(True)
- self.backward_run_button.setEnabled(True)
- self.clear_mem_button.setEnabled(True)
- self.forward_run_button.setText('Forward Propagate')
- self.backward_run_button.setText('Backward Propagate')
- self.console_push_text('Propagation stopped.')
-
- def on_propagation(self):
- # start to propagate
- self.load_current_torch_image_mask()
- self.show_current_frame(fast=True)
-
- self.console_push_text('Propagation started.')
- self.current_prob = self.processor.step(self.current_image_torch, self.current_prob[1:])
- self.current_mask = torch_prob_to_numpy_mask(self.current_prob)
- # clear
- self.interacted_prob = None
- self.reset_this_interaction()
-
- self.propagating = True
- self.clear_mem_button.setEnabled(False)
- # propagate till the end
- while self.propagating:
- self.propagate_fn()
-
- self.load_current_image_mask(no_mask=True)
- self.load_current_torch_image_mask(no_mask=True)
-
- self.current_prob = self.processor.step(self.current_image_torch)
- self.current_mask = torch_prob_to_numpy_mask(self.current_prob)
-
- self.save_current_mask()
- self.show_current_frame(fast=True)
-
- self.update_memory_size()
- QApplication.processEvents()
-
- if self.cursur == 0 or self.cursur == self.num_frames-1:
- break
-
- self.propagating = False
- self.curr_frame_dirty = False
- self.on_pause()
- self.tl_slide()
- QApplication.processEvents()
-
- def pause_propagation(self):
- self.propagating = False
-
- def on_commit(self):
- self.complete_interaction()
- self.update_interacted_mask()
-
- def on_prev_frame(self):
- # self.tl_slide will trigger on setValue
- self.cursur = max(0, self.cursur-1)
- self.tl_slider.setValue(self.cursur)
-
- def on_next_frame(self):
- # self.tl_slide will trigger on setValue
- self.cursur = min(self.cursur+1, self.num_frames-1)
- self.tl_slider.setValue(self.cursur)
-
- def on_play_video_timer(self):
- self.cursur += 1
- if self.cursur > self.num_frames-1:
- self.cursur = 0
- self.tl_slider.setValue(self.cursur)
-
- def on_play_video(self):
- if self.timer.isActive():
- self.timer.stop()
- self.play_button.setText('Play Video')
- else:
- self.timer.start(1000 / 30)
- self.play_button.setText('Stop Video')
-
- def on_reset_mask(self):
- self.current_mask.fill(0)
- if self.current_prob is not None:
- self.current_prob.fill_(0)
- self.curr_frame_dirty = True
- self.save_current_mask()
- self.reset_this_interaction()
- self.show_current_frame()
-
- def on_zoom_plus(self):
- self.zoom_pixels -= 25
- self.zoom_pixels = max(50, self.zoom_pixels)
- self.update_minimap()
-
- def on_zoom_minus(self):
- self.zoom_pixels += 25
- self.zoom_pixels = min(self.zoom_pixels, 300)
- self.update_minimap()
-
- def set_navi_enable(self, boolean):
- self.zoom_p_button.setEnabled(boolean)
- self.zoom_m_button.setEnabled(boolean)
- self.run_button.setEnabled(boolean)
- self.tl_slider.setEnabled(boolean)
- self.play_button.setEnabled(boolean)
- self.lcd.setEnabled(boolean)
-
- def hit_number_key(self, number):
- if number == self.current_object:
- return
- self.current_object = number
- if self.fbrs_controller is not None:
- self.fbrs_controller.unanchor()
- self.console_push_text(f'Current object changed to {number}.')
- self.clear_brush()
- self.vis_brush(self.last_ex, self.last_ey)
- self.update_interact_vis()
- self.show_current_frame()
-
- def clear_brush(self):
- self.brush_vis_map.fill(0)
- self.brush_vis_alpha.fill(0)
-
- def vis_brush(self, ex, ey):
- self.brush_vis_map = cv2.circle(self.brush_vis_map,
- (int(round(ex)), int(round(ey))), self.brush_size//2+1, color_map[self.current_object], thickness=-1)
- self.brush_vis_alpha = cv2.circle(self.brush_vis_alpha,
- (int(round(ex)), int(round(ey))), self.brush_size//2+1, 0.5, thickness=-1)
-
- def on_mouse_press(self, event):
- if self.is_pos_out_of_bound(event.x(), event.y()):
- return
-
- # mid-click
- if (event.button() == Qt.MidButton):
- ex, ey = self.get_scaled_pos(event.x(), event.y())
- target_object = self.current_mask[int(ey),int(ex)]
- if target_object in self.vis_target_objects:
- self.vis_target_objects.remove(target_object)
- else:
- self.vis_target_objects.append(target_object)
- self.console_push_text(f'Target objects for visualization changed to {self.vis_target_objects}')
- self.show_current_frame()
- return
-
- self.right_click = (event.button() == Qt.RightButton)
- self.pressed = True
-
- h, w = self.height, self.width
-
- self.load_current_torch_image_mask()
- image = self.current_image_torch
-
- last_interaction = self.interaction
- new_interaction = None
- if self.curr_interaction == 'Scribble':
- if last_interaction is None or type(last_interaction) != ScribbleInteraction:
- self.complete_interaction()
- new_interaction = ScribbleInteraction(image, torch.from_numpy(self.current_mask).float().cuda(),
- (h, w), self.s2m_controller, self.num_objects)
- elif self.curr_interaction == 'Free':
- if last_interaction is None or type(last_interaction) != FreeInteraction:
- self.complete_interaction()
- new_interaction = FreeInteraction(image, self.current_mask, (h, w),
- self.num_objects)
- new_interaction.set_size(self.brush_size)
- elif self.curr_interaction == 'Click':
- if (last_interaction is None or type(last_interaction) != ClickInteraction
- or last_interaction.tar_obj != self.current_object):
- self.complete_interaction()
- self.fbrs_controller.unanchor()
- new_interaction = ClickInteraction(image, self.current_prob, (h, w),
- self.fbrs_controller, self.current_object)
-
- if new_interaction is not None:
- self.interaction = new_interaction
-
- # Just motion it as the first step
- self.on_mouse_motion(event)
-
- def on_mouse_motion(self, event):
- ex, ey = self.get_scaled_pos(event.x(), event.y())
- self.last_ex, self.last_ey = ex, ey
- self.clear_brush()
- # Visualize
- self.vis_brush(ex, ey)
- if self.pressed:
- if self.curr_interaction == 'Scribble' or self.curr_interaction == 'Free':
- obj = 0 if self.right_click else self.current_object
- self.vis_map, self.vis_alpha = self.interaction.push_point(
- ex, ey, obj, (self.vis_map, self.vis_alpha)
- )
- self.update_interact_vis()
- self.update_minimap()
-
- def update_interacted_mask(self):
- self.current_prob = self.interacted_prob
- self.current_mask = torch_prob_to_numpy_mask(self.interacted_prob)
- self.show_current_frame()
- self.save_current_mask()
- self.curr_frame_dirty = False
-
- def complete_interaction(self):
- if self.interaction is not None:
- self.clear_visualization()
- self.interaction = None
-
- def on_mouse_release(self, event):
- if not self.pressed:
- # this can happen when the initial press is out-of-bound
- return
-
- ex, ey = self.get_scaled_pos(event.x(), event.y())
-
- self.console_push_text('%s interaction at frame %d.' % (self.curr_interaction, self.cursur))
- interaction = self.interaction
-
- if self.curr_interaction == 'Scribble' or self.curr_interaction == 'Free':
- self.on_mouse_motion(event)
- interaction.end_path()
- if self.curr_interaction == 'Free':
- self.clear_visualization()
- elif self.curr_interaction == 'Click':
- ex, ey = self.get_scaled_pos(event.x(), event.y())
- self.vis_map, self.vis_alpha = interaction.push_point(ex, ey,
- self.right_click, (self.vis_map, self.vis_alpha))
-
- self.interacted_prob = interaction.predict()
- self.update_interacted_mask()
- self.update_gpu_usage()
-
- self.pressed = self.right_click = False
-
- def wheelEvent(self, event):
- ex, ey = self.get_scaled_pos(event.x(), event.y())
- if self.curr_interaction == 'Free':
- self.brush_slider.setValue(self.brush_slider.value() + event.angleDelta().y()//30)
- self.clear_brush()
- self.vis_brush(ex, ey)
- self.update_interact_vis()
- self.update_minimap()
-
- def update_gpu_usage(self):
- info = torch.cuda.mem_get_info()
- global_free, global_total = info
- global_free /= (2**30)
- global_total /= (2**30)
- global_used = global_total - global_free
-
- self.gpu_mem_gauge.setFormat(f'{global_used:.01f} GB / {global_total:.01f} GB')
- self.gpu_mem_gauge.setValue(round(global_used/global_total*100))
-
- used_by_torch = torch.cuda.max_memory_allocated() / (2**20)
- self.torch_mem_gauge.setFormat(f'{used_by_torch:.0f} MB / {global_total:.01f} GB')
- self.torch_mem_gauge.setValue(round(used_by_torch/global_total*100/1024))
-
- def on_gpu_timer(self):
- self.update_gpu_usage()
-
- def update_memory_size(self):
- try:
- max_work_elements = self.processor.memory.max_work_elements
- max_long_elements = self.processor.memory.max_long_elements
-
- curr_work_elements = self.processor.memory.work_mem.size
- curr_long_elements = self.processor.memory.long_mem.size
-
- self.work_mem_gauge.setFormat(f'{curr_work_elements} / {max_work_elements}')
- self.work_mem_gauge.setValue(round(curr_work_elements/max_work_elements*100))
-
- self.long_mem_gauge.setFormat(f'{curr_long_elements} / {max_long_elements}')
- self.long_mem_gauge.setValue(round(curr_long_elements/max_long_elements*100))
-
- except AttributeError:
- self.work_mem_gauge.setFormat('Unknown')
- self.long_mem_gauge.setFormat('Unknown')
- self.work_mem_gauge.setValue(0)
- self.long_mem_gauge.setValue(0)
-
- def on_work_min_change(self):
- if self.initialized:
- self.work_mem_min.setValue(min(self.work_mem_min.value(), self.work_mem_max.value()-1))
- self.update_config()
-
- def on_work_max_change(self):
- if self.initialized:
- self.work_mem_max.setValue(max(self.work_mem_max.value(), self.work_mem_min.value()+1))
- self.update_config()
-
- def update_config(self):
- if self.initialized:
- self.config['min_mid_term_frames'] = self.work_mem_min.value()
- self.config['max_mid_term_frames'] = self.work_mem_max.value()
- self.config['max_long_term_elements'] = self.long_mem_max.value()
- self.config['num_prototypes'] = self.num_prototypes_box.value()
- self.config['mem_every'] = self.mem_every_box.value()
-
- self.processor.update_config(self.config)
-
- def on_clear_memory(self):
- self.processor.clear_memory()
- torch.cuda.empty_cache()
- self.update_gpu_usage()
- self.update_memory_size()
-
- def _open_file(self, prompt):
- options = QFileDialog.Options()
- file_name, _ = QFileDialog.getOpenFileName(self, prompt, "", "Image files (*)", options=options)
- return file_name
-
- def on_import_mask(self):
- file_name = self._open_file('Mask')
- if len(file_name) == 0:
- return
-
- mask = self.res_man.read_external_image(file_name, size=(self.height, self.width))
-
- shape_condition = (
- (len(mask.shape) == 2) and
- (mask.shape[-1] == self.width) and
- (mask.shape[-2] == self.height)
- )
-
- object_condition = (
- mask.max() <= self.num_objects
- )
-
- if not shape_condition:
- self.console_push_text(f'Expected ({self.height}, {self.width}). Got {mask.shape} instead.')
- elif not object_condition:
- self.console_push_text(f'Expected {self.num_objects} objects. Got {mask.max()} objects instead.')
- else:
- self.console_push_text(f'Mask file {file_name} loaded.')
- self.current_image_torch = self.current_prob = None
- self.current_mask = mask
- self.show_current_frame()
- self.save_current_mask()
-
- def on_import_layer(self):
- file_name = self._open_file('Layer')
- if len(file_name) == 0:
- return
-
- self._try_load_layer(file_name)
-
- def _try_load_layer(self, file_name):
- try:
- layer = self.res_man.read_external_image(file_name, size=(self.height, self.width))
-
- if layer.shape[-1] == 3:
- layer = np.concatenate([layer, np.ones_like(layer[:,:,0:1])*255], axis=-1)
-
- condition = (
- (len(layer.shape) == 3) and
- (layer.shape[-1] == 4) and
- (layer.shape[-2] == self.width) and
- (layer.shape[-3] == self.height)
- )
-
- if not condition:
- self.console_push_text(f'Expected ({self.height}, {self.width}, 4). Got {layer.shape}.')
- else:
- self.console_push_text(f'Layer file {file_name} loaded.')
- self.overlay_layer = layer
- self.overlay_layer_torch = torch.from_numpy(layer).float().cuda()/255
- self.show_current_frame()
- except FileNotFoundError:
- self.console_push_text(f'{file_name} not found.')
-
- def on_save_visualization_toggle(self):
- self.save_visualization = self.save_visualization_checkbox.isChecked()
diff --git a/inference/interact/gui_utils.py b/inference/interact/gui_utils.py
deleted file mode 100644
index daf852b..0000000
--- a/inference/interact/gui_utils.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from PyQt5.QtCore import Qt
-from PyQt5.QtWidgets import (QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar)
-
-
-def create_parameter_box(min_val, max_val, text, step=1, callback=None):
- layout = QHBoxLayout()
-
- dial = QSpinBox()
- dial.setMaximumHeight(28)
- dial.setMaximumWidth(150)
- dial.setMinimum(min_val)
- dial.setMaximum(max_val)
- dial.setAlignment(Qt.AlignRight)
- dial.setSingleStep(step)
- dial.valueChanged.connect(callback)
-
- label = QLabel(text)
- label.setAlignment(Qt.AlignRight)
-
- layout.addWidget(label)
- layout.addWidget(dial)
-
- return dial, layout
-
-
-def create_gauge(text):
- layout = QHBoxLayout()
-
- gauge = QProgressBar()
- gauge.setMaximumHeight(28)
- gauge.setMaximumWidth(200)
- gauge.setAlignment(Qt.AlignCenter)
-
- label = QLabel(text)
- label.setAlignment(Qt.AlignRight)
-
- layout.addWidget(label)
- layout.addWidget(gauge)
-
- return gauge, layout
diff --git a/inference/interact/interaction.py b/inference/interact/interaction.py
deleted file mode 100644
index 19f83f9..0000000
--- a/inference/interact/interaction.py
+++ /dev/null
@@ -1,252 +0,0 @@
-"""
-Contains all the types of interaction related to the GUI
-Not related to automatic evaluation in the DAVIS dataset
-
-You can inherit the Interaction class to create new interaction types
-undo is (sometimes partially) supported
-"""
-
-
-import torch
-import torch.nn.functional as F
-import numpy as np
-import cv2
-import time
-from .interactive_utils import color_map, index_numpy_to_one_hot_torch
-
-
-def aggregate_sbg(prob, keep_bg=False, hard=False):
- device = prob.device
- k, h, w = prob.shape
- ex_prob = torch.zeros((k+1, h, w), device=device)
- ex_prob[0] = 0.5
- ex_prob[1:] = prob
- ex_prob = torch.clamp(ex_prob, 1e-7, 1-1e-7)
- logits = torch.log((ex_prob /(1-ex_prob)))
-
- if hard:
- # Very low temperature o((⊙﹏⊙))o 🥶
- logits *= 1000
-
- if keep_bg:
- return F.softmax(logits, dim=0)
- else:
- return F.softmax(logits, dim=0)[1:]
-
-def aggregate_wbg(prob, keep_bg=False, hard=False):
- k, h, w = prob.shape
- new_prob = torch.cat([
- torch.prod(1-prob, dim=0, keepdim=True),
- prob
- ], 0).clamp(1e-7, 1-1e-7)
- logits = torch.log((new_prob /(1-new_prob)))
-
- if hard:
- # Very low temperature o((⊙﹏⊙))o 🥶
- logits *= 1000
-
- if keep_bg:
- return F.softmax(logits, dim=0)
- else:
- return F.softmax(logits, dim=0)[1:]
-
-class Interaction:
- def __init__(self, image, prev_mask, true_size, controller):
- self.image = image
- self.prev_mask = prev_mask
- self.controller = controller
- self.start_time = time.time()
-
- self.h, self.w = true_size
-
- self.out_prob = None
- self.out_mask = None
-
- def predict(self):
- pass
-
-
-class FreeInteraction(Interaction):
- def __init__(self, image, prev_mask, true_size, num_objects):
- """
- prev_mask should be index format numpy array
- """
- super().__init__(image, prev_mask, true_size, None)
-
- self.K = num_objects
-
- self.drawn_map = self.prev_mask.copy()
- self.curr_path = [[] for _ in range(self.K + 1)]
-
- self.size = None
-
- def set_size(self, size):
- self.size = size
-
- """
- k - object id
- vis - a tuple (visualization map, pass through alpha). None if not needed.
- """
- def push_point(self, x, y, k, vis=None):
- if vis is not None:
- vis_map, vis_alpha = vis
- selected = self.curr_path[k]
- selected.append((x, y))
- if len(selected) >= 2:
- cv2.line(self.drawn_map,
- (int(round(selected[-2][0])), int(round(selected[-2][1]))),
- (int(round(selected[-1][0])), int(round(selected[-1][1]))),
- k, thickness=self.size)
-
- # Plot visualization
- if vis is not None:
- # Visualization for drawing
- if k == 0:
- vis_map = cv2.line(vis_map,
- (int(round(selected[-2][0])), int(round(selected[-2][1]))),
- (int(round(selected[-1][0])), int(round(selected[-1][1]))),
- color_map[k], thickness=self.size)
- else:
- vis_map = cv2.line(vis_map,
- (int(round(selected[-2][0])), int(round(selected[-2][1]))),
- (int(round(selected[-1][0])), int(round(selected[-1][1]))),
- color_map[k], thickness=self.size)
- # Visualization on/off boolean filter
- vis_alpha = cv2.line(vis_alpha,
- (int(round(selected[-2][0])), int(round(selected[-2][1]))),
- (int(round(selected[-1][0])), int(round(selected[-1][1]))),
- 0.75, thickness=self.size)
-
- if vis is not None:
- return vis_map, vis_alpha
-
- def end_path(self):
- # Complete the drawing
- self.curr_path = [[] for _ in range(self.K + 1)]
-
- def predict(self):
- self.out_prob = index_numpy_to_one_hot_torch(self.drawn_map, self.K+1).cuda()
- # self.out_prob = torch.from_numpy(self.drawn_map).float().cuda()
- # self.out_prob, _ = pad_divide_by(self.out_prob, 16, self.out_prob.shape[-2:])
- # self.out_prob = aggregate_sbg(self.out_prob, keep_bg=True)
- return self.out_prob
-
-class ScribbleInteraction(Interaction):
- def __init__(self, image, prev_mask, true_size, controller, num_objects):
- """
- prev_mask should be in an indexed form
- """
- super().__init__(image, prev_mask, true_size, controller)
-
- self.K = num_objects
-
- self.drawn_map = np.empty((self.h, self.w), dtype=np.uint8)
- self.drawn_map.fill(255)
- # background + k
- self.curr_path = [[] for _ in range(self.K + 1)]
- self.size = 3
-
- """
- k - object id
- vis - a tuple (visualization map, pass through alpha). None if not needed.
- """
- def push_point(self, x, y, k, vis=None):
- if vis is not None:
- vis_map, vis_alpha = vis
- selected = self.curr_path[k]
- selected.append((x, y))
- if len(selected) >= 2:
- self.drawn_map = cv2.line(self.drawn_map,
- (int(round(selected[-2][0])), int(round(selected[-2][1]))),
- (int(round(selected[-1][0])), int(round(selected[-1][1]))),
- k, thickness=self.size)
-
- # Plot visualization
- if vis is not None:
- # Visualization for drawing
- if k == 0:
- vis_map = cv2.line(vis_map,
- (int(round(selected[-2][0])), int(round(selected[-2][1]))),
- (int(round(selected[-1][0])), int(round(selected[-1][1]))),
- color_map[k], thickness=self.size)
- else:
- vis_map = cv2.line(vis_map,
- (int(round(selected[-2][0])), int(round(selected[-2][1]))),
- (int(round(selected[-1][0])), int(round(selected[-1][1]))),
- color_map[k], thickness=self.size)
- # Visualization on/off boolean filter
- vis_alpha = cv2.line(vis_alpha,
- (int(round(selected[-2][0])), int(round(selected[-2][1]))),
- (int(round(selected[-1][0])), int(round(selected[-1][1]))),
- 0.75, thickness=self.size)
-
- # Optional vis return
- if vis is not None:
- return vis_map, vis_alpha
-
- def end_path(self):
- # Complete the drawing
- self.curr_path = [[] for _ in range(self.K + 1)]
-
- def predict(self):
- self.out_prob = self.controller.interact(self.image.unsqueeze(0), self.prev_mask, self.drawn_map)
- self.out_prob = aggregate_wbg(self.out_prob, keep_bg=True, hard=True)
- return self.out_prob
-
-
-class ClickInteraction(Interaction):
- def __init__(self, image, prev_mask, true_size, controller, tar_obj):
- """
- prev_mask in a prob. form
- """
- super().__init__(image, prev_mask, true_size, controller)
- self.tar_obj = tar_obj
-
- # negative/positive for each object
- self.pos_clicks = []
- self.neg_clicks = []
-
- self.out_prob = self.prev_mask.clone()
-
- """
- neg - Negative interaction or not
- vis - a tuple (visualization map, pass through alpha). None if not needed.
- """
- def push_point(self, x, y, neg, vis=None):
- # Clicks
- if neg:
- self.neg_clicks.append((x, y))
- else:
- self.pos_clicks.append((x, y))
-
- # Do the prediction
- self.obj_mask = self.controller.interact(self.image.unsqueeze(0), x, y, not neg)
-
- # Plot visualization
- if vis is not None:
- vis_map, vis_alpha = vis
- # Visualization for clicks
- if neg:
- vis_map = cv2.circle(vis_map,
- (int(round(x)), int(round(y))),
- 2, color_map[0], thickness=-1)
- else:
- vis_map = cv2.circle(vis_map,
- (int(round(x)), int(round(y))),
- 2, color_map[self.tar_obj], thickness=-1)
-
- vis_alpha = cv2.circle(vis_alpha,
- (int(round(x)), int(round(y))),
- 2, 1, thickness=-1)
-
- # Optional vis return
- return vis_map, vis_alpha
-
- def predict(self):
- self.out_prob = self.prev_mask.clone()
- # a small hack to allow the interacting object to overwrite existing masks
- # without remembering all the object probabilities
- self.out_prob = torch.clamp(self.out_prob, max=0.9)
- self.out_prob[self.tar_obj] = self.obj_mask
- self.out_prob = aggregate_wbg(self.out_prob[1:], keep_bg=True, hard=True)
- return self.out_prob
diff --git a/inference/interact/interactive_utils.py b/inference/interact/interactive_utils.py
deleted file mode 100644
index 9961f63..0000000
--- a/inference/interact/interactive_utils.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# Modifed from https://github.com/seoungwugoh/ivs-demo
-
-import numpy as np
-
-import torch
-import torch.nn.functional as F
-from util.palette import davis_palette
-from dataset.range_transform import im_normalization
-
-def image_to_torch(frame: np.ndarray, device='cuda'):
- # frame: H*W*3 numpy array
- frame = frame.transpose(2, 0, 1)
- frame = torch.from_numpy(frame).float().to(device)/255
- frame_norm = im_normalization(frame)
- return frame_norm, frame
-
-def torch_prob_to_numpy_mask(prob):
- mask = torch.argmax(prob, dim=0)
- mask = mask.cpu().numpy().astype(np.uint8)
- return mask
-
-def index_numpy_to_one_hot_torch(mask, num_classes):
- mask = torch.from_numpy(mask).long()
- return F.one_hot(mask, num_classes=num_classes).permute(2, 0, 1).float()
-
-"""
-Some constants fro visualization
-"""
-color_map_np = np.frombuffer(davis_palette, dtype=np.uint8).reshape(-1, 3).copy()
-# scales for better visualization
-color_map_np = (color_map_np.astype(np.float32)*1.5).clip(0, 255).astype(np.uint8)
-color_map = color_map_np.tolist()
-if torch.cuda.is_available():
- color_map_torch = torch.from_numpy(color_map_np).cuda() / 255
-
-grayscale_weights = np.array([[0.3,0.59,0.11]]).astype(np.float32)
-if torch.cuda.is_available():
- grayscale_weights_torch = torch.from_numpy(grayscale_weights).cuda().unsqueeze(0)
-
-def get_visualization(mode, image, mask, layer, target_object):
- if mode == 'fade':
- return overlay_davis(image, mask, fade=True)
- elif mode == 'davis':
- return overlay_davis(image, mask)
- elif mode == 'light':
- return overlay_davis(image, mask, 0.9)
- elif mode == 'popup':
- return overlay_popup(image, mask, target_object)
- elif mode == 'layered':
- if layer is None:
- print('Layer file not given. Defaulting to DAVIS.')
- return overlay_davis(image, mask)
- else:
- return overlay_layer(image, mask, layer, target_object)
- else:
- raise NotImplementedError
-
-def get_visualization_torch(mode, image, prob, layer, target_object):
- if mode == 'fade':
- return overlay_davis_torch(image, prob, fade=True)
- elif mode == 'davis':
- return overlay_davis_torch(image, prob)
- elif mode == 'light':
- return overlay_davis_torch(image, prob, 0.9)
- elif mode == 'popup':
- return overlay_popup_torch(image, prob, target_object)
- elif mode == 'layered':
- if layer is None:
- print('Layer file not given. Defaulting to DAVIS.')
- return overlay_davis_torch(image, prob)
- else:
- return overlay_layer_torch(image, prob, layer, target_object)
- else:
- raise NotImplementedError
-
-def overlay_davis(image, mask, alpha=0.5, fade=False):
- """ Overlay segmentation on top of RGB image. from davis official"""
- im_overlay = image.copy()
-
- colored_mask = color_map_np[mask]
- foreground = image*alpha + (1-alpha)*colored_mask
- binary_mask = (mask > 0)
- # Compose image
- im_overlay[binary_mask] = foreground[binary_mask]
- if fade:
- im_overlay[~binary_mask] = im_overlay[~binary_mask] * 0.6
- return im_overlay.astype(image.dtype)
-
-def overlay_popup(image, mask, target_object):
- # Keep foreground colored. Convert background to grayscale.
- im_overlay = image.copy()
-
- binary_mask = ~(np.isin(mask, target_object))
- colored_region = (im_overlay[binary_mask]*grayscale_weights).sum(-1, keepdims=-1)
- im_overlay[binary_mask] = colored_region
- return im_overlay.astype(image.dtype)
-
-def overlay_layer(image, mask, layer, target_object):
- # insert a layer between foreground and background
- # The CPU version is less accurate because we are using the hard mask
- # The GPU version has softer edges as it uses soft probabilities
- obj_mask = (np.isin(mask, target_object)).astype(np.float32)
- layer_alpha = layer[:, :, 3].astype(np.float32) / 255
- layer_rgb = layer[:, :, :3]
- background_alpha = np.maximum(obj_mask, layer_alpha)[:,:,np.newaxis]
- obj_mask = obj_mask[:,:,np.newaxis]
- im_overlay = (image*(1-background_alpha) + layer_rgb*(1-obj_mask) + image*obj_mask).clip(0, 255)
- return im_overlay.astype(image.dtype)
-
-def overlay_davis_torch(image, mask, alpha=0.5, fade=False):
- """ Overlay segmentation on top of RGB image. from davis official"""
- # Changes the image in-place to avoid copying
- image = image.permute(1, 2, 0)
- im_overlay = image
- mask = torch.argmax(mask, dim=0)
-
- colored_mask = color_map_torch[mask]
- foreground = image*alpha + (1-alpha)*colored_mask
- binary_mask = (mask > 0)
- # Compose image
- im_overlay[binary_mask] = foreground[binary_mask]
- if fade:
- im_overlay[~binary_mask] = im_overlay[~binary_mask] * 0.6
-
- im_overlay = (im_overlay*255).cpu().numpy()
- im_overlay = im_overlay.astype(np.uint8)
-
- return im_overlay
-
-def overlay_popup_torch(image, mask, target_object):
- # Keep foreground colored. Convert background to grayscale.
- image = image.permute(1, 2, 0)
-
- if len(target_object) == 0:
- obj_mask = torch.zeros_like(mask[0]).unsqueeze(2)
- else:
- # I should not need to convert this to numpy.
- # uUsing list works most of the time but consistently fails
- # if I include first object -> exclude it -> include it again.
- # I check everywhere and it makes absolutely no sense.
- # I am blaming this on PyTorch and calling it a day
- obj_mask = mask[np.array(target_object,dtype=np.int32)].sum(0).unsqueeze(2)
- gray_image = (image*grayscale_weights_torch).sum(-1, keepdim=True)
- im_overlay = obj_mask*image + (1-obj_mask)*gray_image
-
- im_overlay = (im_overlay*255).cpu().numpy()
- im_overlay = im_overlay.astype(np.uint8)
-
- return im_overlay
-
-def overlay_layer_torch(image, mask, layer, target_object):
- # insert a layer between foreground and background
- # The CPU version is less accurate because we are using the hard mask
- # The GPU version has softer edges as it uses soft probabilities
- image = image.permute(1, 2, 0)
-
- if len(target_object) == 0:
- obj_mask = torch.zeros_like(mask[0])
- else:
- # I should not need to convert this to numpy.
- # uUsing list works most of the time but consistently fails
- # if I include first object -> exclude it -> include it again.
- # I check everywhere and it makes absolutely no sense.
- # I am blaming this on PyTorch and calling it a day
- obj_mask = mask[np.array(target_object,dtype=np.int32)].sum(0)
- layer_alpha = layer[:, :, 3]
- layer_rgb = layer[:, :, :3]
- background_alpha = torch.maximum(obj_mask, layer_alpha).unsqueeze(2)
- obj_mask = obj_mask.unsqueeze(2)
- im_overlay = (image*(1-background_alpha) + layer_rgb*(1-obj_mask) + image*obj_mask).clip(0, 1)
-
- im_overlay = (im_overlay*255).cpu().numpy()
- im_overlay = im_overlay.astype(np.uint8)
-
- return im_overlay
diff --git a/inference/interact/resource_manager.py b/inference/interact/resource_manager.py
deleted file mode 100644
index b0f28af..0000000
--- a/inference/interact/resource_manager.py
+++ /dev/null
@@ -1,206 +0,0 @@
-import os
-from os import path
-import shutil
-import collections
-
-import cv2
-from PIL import Image
-if not hasattr(Image, 'Resampling'): # Pillow<9.0
- Image.Resampling = Image
-import numpy as np
-
-from util.palette import davis_palette
-import progressbar
-
-
-# https://bugs.python.org/issue28178
-# ah python ah why
-class LRU:
- def __init__(self, func, maxsize=128):
- self.cache = collections.OrderedDict()
- self.func = func
- self.maxsize = maxsize
-
- def __call__(self, *args):
- cache = self.cache
- if args in cache:
- cache.move_to_end(args)
- return cache[args]
- result = self.func(*args)
- cache[args] = result
- if len(cache) > self.maxsize:
- cache.popitem(last=False)
- return result
-
- def invalidate(self, key):
- self.cache.pop(key, None)
-
-
-class ResourceManager:
- def __init__(self, config):
- # determine inputs
- images = config['images']
- video = config['video']
- self.workspace = config['workspace']
- self.size = config['size']
- self.palette = davis_palette
-
- # create temporary workspace if not specified
- if self.workspace is None:
- if images is not None:
- basename = path.basename(images)
- elif video is not None:
- basename = path.basename(video)[:-4]
- else:
- raise NotImplementedError(
- 'Either images, video, or workspace has to be specified')
-
- self.workspace = path.join('./workspace', basename)
-
- print(f'Workspace is in: {self.workspace}')
-
- # determine the location of input images
- need_decoding = False
- need_resizing = False
- if path.exists(path.join(self.workspace, 'images')):
- pass
- elif images is not None:
- need_resizing = True
- elif video is not None:
- # will decode video into frames later
- need_decoding = True
-
- # create workspace subdirectories
- self.image_dir = path.join(self.workspace, 'images')
- self.mask_dir = path.join(self.workspace, 'masks')
- os.makedirs(self.image_dir, exist_ok=True)
- os.makedirs(self.mask_dir, exist_ok=True)
-
- # convert read functions to be buffered
- self.get_image = LRU(self._get_image_unbuffered, maxsize=config['buffer_size'])
- self.get_mask = LRU(self._get_mask_unbuffered, maxsize=config['buffer_size'])
-
- # extract frames from video
- if need_decoding:
- self._extract_frames(video)
-
- # copy/resize existing images to the workspace
- if need_resizing:
- self._copy_resize_frames(images)
-
- # read all frame names
- self.names = sorted(os.listdir(self.image_dir))
- self.names = [f[:-4] for f in self.names] # remove extensions
- self.length = len(self.names)
-
- assert self.length > 0, f'No images found! Check {self.workspace}/images. Remove folder if necessary.'
-
- print(f'{self.length} images found.')
-
- self.height, self.width = self.get_image(0).shape[:2]
- self.visualization_init = False
-
- def _extract_frames(self, video):
- cap = cv2.VideoCapture(video)
- frame_index = 0
- print(f'Extracting frames from {video} into {self.image_dir}...')
- bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength)
- while(cap.isOpened()):
- _, frame = cap.read()
- if frame is None:
- break
- if self.size > 0:
- h, w = frame.shape[:2]
- new_w = (w*self.size//min(w, h))
- new_h = (h*self.size//min(w, h))
- if new_w != w or new_h != h:
- frame = cv2.resize(frame,dsize=(new_w,new_h),interpolation=cv2.INTER_AREA)
- cv2.imwrite(path.join(self.image_dir, f'{frame_index:07d}.jpg'), frame)
- frame_index += 1
- bar.update(frame_index)
- bar.finish()
- print('Done!')
-
- def _copy_resize_frames(self, images):
- image_list = os.listdir(images)
- print(f'Copying/resizing frames into {self.image_dir}...')
- for image_name in progressbar.progressbar(image_list):
- if self.size < 0:
- # just copy
- shutil.copy2(path.join(images, image_name), self.image_dir)
- else:
- frame = cv2.imread(path.join(images, image_name))
- h, w = frame.shape[:2]
- new_w = (w*self.size//min(w, h))
- new_h = (h*self.size//min(w, h))
- if new_w != w or new_h != h:
- frame = cv2.resize(frame,dsize=(new_w,new_h),interpolation=cv2.INTER_AREA)
- cv2.imwrite(path.join(self.image_dir, image_name), frame)
- print('Done!')
-
- def save_mask(self, ti, mask):
- # mask should be uint8 H*W without channels
- assert 0 <= ti < self.length
- assert isinstance(mask, np.ndarray)
-
- mask = Image.fromarray(mask)
- mask.putpalette(self.palette)
- mask.save(path.join(self.mask_dir, self.names[ti]+'.png'))
- self.invalidate(ti)
-
- def save_visualization(self, ti, image):
- # image should be uint8 3*H*W
- assert 0 <= ti < self.length
- assert isinstance(image, np.ndarray)
- if not self.visualization_init:
- self.visualization_dir = path.join(self.workspace, 'visualization')
- os.makedirs(self.visualization_dir, exist_ok=True)
- self.visualization_init = True
-
- image = Image.fromarray(image)
- image.save(path.join(self.visualization_dir, self.names[ti]+'.jpg'))
-
- def _get_image_unbuffered(self, ti):
- # returns H*W*3 uint8 array
- assert 0 <= ti < self.length
-
- image = Image.open(path.join(self.image_dir, self.names[ti]+'.jpg'))
- image = np.array(image)
- return image
-
- def _get_mask_unbuffered(self, ti):
- # returns H*W uint8 array
- assert 0 <= ti < self.length
-
- mask_path = path.join(self.mask_dir, self.names[ti]+'.png')
- if path.exists(mask_path):
- mask = Image.open(mask_path)
- mask = np.array(mask)
- return mask
- else:
- return None
-
- def read_external_image(self, file_name, size=None):
- image = Image.open(file_name)
- is_mask = image.mode in ['L', 'P']
- if size is not None:
- # PIL uses (width, height)
- image = image.resize((size[1], size[0]),
- resample=Image.Resampling.NEAREST if is_mask else Image.Resampling.BICUBIC)
- image = np.array(image)
- return image
-
- def invalidate(self, ti):
- # the image buffer is never invalidated
- self.get_mask.invalidate((ti,))
-
- def __len__(self):
- return self.length
-
- @property
- def h(self):
- return self.height
-
- @property
- def w(self):
- return self.width
diff --git a/inference/interact/s2m/__init__.py b/inference/interact/s2m/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/inference/interact/s2m/_deeplab.py b/inference/interact/s2m/_deeplab.py
deleted file mode 100644
index e663007..0000000
--- a/inference/interact/s2m/_deeplab.py
+++ /dev/null
@@ -1,180 +0,0 @@
-# Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from .utils import _SimpleSegmentationModel
-
-
-__all__ = ["DeepLabV3"]
-
-
-class DeepLabV3(_SimpleSegmentationModel):
- """
- Implements DeepLabV3 model from
- `"Rethinking Atrous Convolution for Semantic Image Segmentation"
- `_.
-
- Arguments:
- backbone (nn.Module): the network used to compute the features for the model.
- The backbone should return an OrderedDict[Tensor], with the key being
- "out" for the last feature map used, and "aux" if an auxiliary classifier
- is used.
- classifier (nn.Module): module that takes the "out" element returned from
- the backbone and returns a dense prediction.
- aux_classifier (nn.Module, optional): auxiliary classifier used during training
- """
- pass
-
-class DeepLabHeadV3Plus(nn.Module):
- def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]):
- super(DeepLabHeadV3Plus, self).__init__()
- self.project = nn.Sequential(
- nn.Conv2d(low_level_channels, 48, 1, bias=False),
- nn.BatchNorm2d(48),
- nn.ReLU(inplace=True),
- )
-
- self.aspp = ASPP(in_channels, aspp_dilate)
-
- self.classifier = nn.Sequential(
- nn.Conv2d(304, 256, 3, padding=1, bias=False),
- nn.BatchNorm2d(256),
- nn.ReLU(inplace=True),
- nn.Conv2d(256, num_classes, 1)
- )
- self._init_weight()
-
- def forward(self, feature):
- low_level_feature = self.project( feature['low_level'] )
- output_feature = self.aspp(feature['out'])
- output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False)
- return self.classifier( torch.cat( [ low_level_feature, output_feature ], dim=1 ) )
-
- def _init_weight(self):
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight)
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
-
-class DeepLabHead(nn.Module):
- def __init__(self, in_channels, num_classes, aspp_dilate=[12, 24, 36]):
- super(DeepLabHead, self).__init__()
-
- self.classifier = nn.Sequential(
- ASPP(in_channels, aspp_dilate),
- nn.Conv2d(256, 256, 3, padding=1, bias=False),
- nn.BatchNorm2d(256),
- nn.ReLU(inplace=True),
- nn.Conv2d(256, num_classes, 1)
- )
- self._init_weight()
-
- def forward(self, feature):
- return self.classifier( feature['out'] )
-
- def _init_weight(self):
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight)
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
-
-class AtrousSeparableConvolution(nn.Module):
- """ Atrous Separable Convolution
- """
- def __init__(self, in_channels, out_channels, kernel_size,
- stride=1, padding=0, dilation=1, bias=True):
- super(AtrousSeparableConvolution, self).__init__()
- self.body = nn.Sequential(
- # Separable Conv
- nn.Conv2d( in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, groups=in_channels ),
- # PointWise Conv
- nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
- )
-
- self._init_weight()
-
- def forward(self, x):
- return self.body(x)
-
- def _init_weight(self):
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight)
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
-
-class ASPPConv(nn.Sequential):
- def __init__(self, in_channels, out_channels, dilation):
- modules = [
- nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
- nn.BatchNorm2d(out_channels),
- nn.ReLU(inplace=True)
- ]
- super(ASPPConv, self).__init__(*modules)
-
-class ASPPPooling(nn.Sequential):
- def __init__(self, in_channels, out_channels):
- super(ASPPPooling, self).__init__(
- nn.AdaptiveAvgPool2d(1),
- nn.Conv2d(in_channels, out_channels, 1, bias=False),
- nn.BatchNorm2d(out_channels),
- nn.ReLU(inplace=True))
-
- def forward(self, x):
- size = x.shape[-2:]
- x = super(ASPPPooling, self).forward(x)
- return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
-
-class ASPP(nn.Module):
- def __init__(self, in_channels, atrous_rates):
- super(ASPP, self).__init__()
- out_channels = 256
- modules = []
- modules.append(nn.Sequential(
- nn.Conv2d(in_channels, out_channels, 1, bias=False),
- nn.BatchNorm2d(out_channels),
- nn.ReLU(inplace=True)))
-
- rate1, rate2, rate3 = tuple(atrous_rates)
- modules.append(ASPPConv(in_channels, out_channels, rate1))
- modules.append(ASPPConv(in_channels, out_channels, rate2))
- modules.append(ASPPConv(in_channels, out_channels, rate3))
- modules.append(ASPPPooling(in_channels, out_channels))
-
- self.convs = nn.ModuleList(modules)
-
- self.project = nn.Sequential(
- nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
- nn.BatchNorm2d(out_channels),
- nn.ReLU(inplace=True),
- nn.Dropout(0.1),)
-
- def forward(self, x):
- res = []
- for conv in self.convs:
- res.append(conv(x))
- res = torch.cat(res, dim=1)
- return self.project(res)
-
-
-
-def convert_to_separable_conv(module):
- new_module = module
- if isinstance(module, nn.Conv2d) and module.kernel_size[0]>1:
- new_module = AtrousSeparableConvolution(module.in_channels,
- module.out_channels,
- module.kernel_size,
- module.stride,
- module.padding,
- module.dilation,
- module.bias)
- for name, child in module.named_children():
- new_module.add_module(name, convert_to_separable_conv(child))
- return new_module
\ No newline at end of file
diff --git a/inference/interact/s2m/s2m_network.py b/inference/interact/s2m/s2m_network.py
deleted file mode 100644
index e4f9a3f..0000000
--- a/inference/interact/s2m/s2m_network.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch
-
-from .utils import IntermediateLayerGetter
-from ._deeplab import DeepLabHead, DeepLabHeadV3Plus, DeepLabV3
-from . import s2m_resnet
-
-def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
-
- if output_stride==8:
- replace_stride_with_dilation=[False, True, True]
- aspp_dilate = [12, 24, 36]
- else:
- replace_stride_with_dilation=[False, False, True]
- aspp_dilate = [6, 12, 18]
-
- backbone = s2m_resnet.__dict__[backbone_name](
- pretrained=pretrained_backbone,
- replace_stride_with_dilation=replace_stride_with_dilation)
-
- inplanes = 2048
- low_level_planes = 256
-
- if name=='deeplabv3plus':
- return_layers = {'layer4': 'out', 'layer1': 'low_level'}
- classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
- elif name=='deeplabv3':
- return_layers = {'layer4': 'out'}
- classifier = DeepLabHead(inplanes , num_classes, aspp_dilate)
- backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
-
- model = DeepLabV3(backbone, classifier)
- return model
-
-def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone):
-
- if backbone.startswith('resnet'):
- model = _segm_resnet(arch_type, backbone, num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
- else:
- raise NotImplementedError
- return model
-
-
-# Deeplab v3
-def deeplabv3_resnet50(num_classes=1, output_stride=16, pretrained_backbone=False):
- """Constructs a DeepLabV3 model with a ResNet-50 backbone.
-
- Args:
- num_classes (int): number of classes.
- output_stride (int): output stride for deeplab.
- pretrained_backbone (bool): If True, use the pretrained backbone.
- """
- return _load_model('deeplabv3', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
-
-
-# Deeplab v3+
-def deeplabv3plus_resnet50(num_classes=1, output_stride=16, pretrained_backbone=False):
- """Constructs a DeepLabV3 model with a ResNet-50 backbone.
-
- Args:
- num_classes (int): number of classes.
- output_stride (int): output stride for deeplab.
- pretrained_backbone (bool): If True, use the pretrained backbone.
- """
- return _load_model('deeplabv3plus', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
-
diff --git a/inference/interact/s2m/s2m_resnet.py b/inference/interact/s2m/s2m_resnet.py
deleted file mode 100644
index 89f1ce0..0000000
--- a/inference/interact/s2m/s2m_resnet.py
+++ /dev/null
@@ -1,182 +0,0 @@
-import torch
-import torch.nn as nn
-try:
- from torchvision.models.utils import load_state_dict_from_url
-except ModuleNotFoundError:
- from torch.utils.model_zoo import load_url as load_state_dict_from_url
-
-
-__all__ = ['ResNet', 'resnet50']
-
-
-model_urls = {
- 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
-}
-
-
-def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
- """3x3 convolution with padding"""
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
- padding=dilation, groups=groups, bias=False, dilation=dilation)
-
-
-def conv1x1(in_planes, out_planes, stride=1):
- """1x1 convolution"""
- return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
-
-
-class Bottleneck(nn.Module):
- expansion = 4
-
- def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
- base_width=64, dilation=1, norm_layer=None):
- super(Bottleneck, self).__init__()
- if norm_layer is None:
- norm_layer = nn.BatchNorm2d
- width = int(planes * (base_width / 64.)) * groups
- # Both self.conv2 and self.downsample layers downsample the input when stride != 1
- self.conv1 = conv1x1(inplanes, width)
- self.bn1 = norm_layer(width)
- self.conv2 = conv3x3(width, width, stride, groups, dilation)
- self.bn2 = norm_layer(width)
- self.conv3 = conv1x1(width, planes * self.expansion)
- self.bn3 = norm_layer(planes * self.expansion)
- self.relu = nn.ReLU(inplace=True)
- self.downsample = downsample
- self.stride = stride
-
- def forward(self, x):
- identity = x
-
- out = self.conv1(x)
- out = self.bn1(out)
- out = self.relu(out)
-
- out = self.conv2(out)
- out = self.bn2(out)
- out = self.relu(out)
-
- out = self.conv3(out)
- out = self.bn3(out)
-
- if self.downsample is not None:
- identity = self.downsample(x)
-
- out += identity
- out = self.relu(out)
-
- return out
-
-
-class ResNet(nn.Module):
-
- def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
- groups=1, width_per_group=64, replace_stride_with_dilation=None,
- norm_layer=None):
- super(ResNet, self).__init__()
- if norm_layer is None:
- norm_layer = nn.BatchNorm2d
- self._norm_layer = norm_layer
-
- self.inplanes = 64
- self.dilation = 1
- if replace_stride_with_dilation is None:
- # each element in the tuple indicates if we should replace
- # the 2x2 stride with a dilated convolution instead
- replace_stride_with_dilation = [False, False, False]
- if len(replace_stride_with_dilation) != 3:
- raise ValueError("replace_stride_with_dilation should be None "
- "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
- self.groups = groups
- self.base_width = width_per_group
- self.conv1 = nn.Conv2d(6, self.inplanes, kernel_size=7, stride=2, padding=3,
- bias=False)
- self.bn1 = norm_layer(self.inplanes)
- self.relu = nn.ReLU(inplace=True)
- self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
- self.layer1 = self._make_layer(block, 64, layers[0])
- self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
- dilate=replace_stride_with_dilation[0])
- self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
- dilate=replace_stride_with_dilation[1])
- self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
- dilate=replace_stride_with_dilation[2])
- self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
- self.fc = nn.Linear(512 * block.expansion, num_classes)
-
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
-
- # Zero-initialize the last BN in each residual branch,
- # so that the residual branch starts with zeros, and each residual block behaves like an identity.
- # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
- if zero_init_residual:
- for m in self.modules():
- if isinstance(m, Bottleneck):
- nn.init.constant_(m.bn3.weight, 0)
-
- def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
- norm_layer = self._norm_layer
- downsample = None
- previous_dilation = self.dilation
- if dilate:
- self.dilation *= stride
- stride = 1
- if stride != 1 or self.inplanes != planes * block.expansion:
- downsample = nn.Sequential(
- conv1x1(self.inplanes, planes * block.expansion, stride),
- norm_layer(planes * block.expansion),
- )
-
- layers = []
- layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
- self.base_width, previous_dilation, norm_layer))
- self.inplanes = planes * block.expansion
- for _ in range(1, blocks):
- layers.append(block(self.inplanes, planes, groups=self.groups,
- base_width=self.base_width, dilation=self.dilation,
- norm_layer=norm_layer))
-
- return nn.Sequential(*layers)
-
- def forward(self, x):
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.relu(x)
- x = self.maxpool(x)
-
- x = self.layer1(x)
- x = self.layer2(x)
- x = self.layer3(x)
- x = self.layer4(x)
-
- x = self.avgpool(x)
- x = torch.flatten(x, 1)
- x = self.fc(x)
-
- return x
-
-
-def _resnet(arch, block, layers, pretrained, progress, **kwargs):
- model = ResNet(block, layers, **kwargs)
- if pretrained:
- state_dict = load_state_dict_from_url(model_urls[arch],
- progress=progress)
- model.load_state_dict(state_dict)
- return model
-
-
-def resnet50(pretrained=False, progress=True, **kwargs):
- r"""ResNet-50 model from
- `"Deep Residual Learning for Image Recognition" `_
-
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- progress (bool): If True, displays a progress bar of the download to stderr
- """
- return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
- **kwargs)
diff --git a/inference/interact/s2m/utils.py b/inference/interact/s2m/utils.py
deleted file mode 100644
index c2adecf..0000000
--- a/inference/interact/s2m/utils.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch
-
-import torch
-import torch.nn as nn
-import numpy as np
-import torch.nn.functional as F
-from collections import OrderedDict
-
-class _SimpleSegmentationModel(nn.Module):
- def __init__(self, backbone, classifier):
- super(_SimpleSegmentationModel, self).__init__()
- self.backbone = backbone
- self.classifier = classifier
-
- def forward(self, x):
- input_shape = x.shape[-2:]
- features = self.backbone(x)
- x = self.classifier(features)
- x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
- return x
-
-
-class IntermediateLayerGetter(nn.ModuleDict):
- """
- Module wrapper that returns intermediate layers from a model
-
- It has a strong assumption that the modules have been registered
- into the model in the same order as they are used.
- This means that one should **not** reuse the same nn.Module
- twice in the forward if you want this to work.
-
- Additionally, it is only able to query submodules that are directly
- assigned to the model. So if `model` is passed, `model.feature1` can
- be returned, but not `model.feature1.layer2`.
-
- Arguments:
- model (nn.Module): model on which we will extract the features
- return_layers (Dict[name, new_name]): a dict containing the names
- of the modules for which the activations will be returned as
- the key of the dict, and the value of the dict is the name
- of the returned activation (which the user can specify).
-
- Examples::
-
- >>> m = torchvision.models.resnet18(pretrained=True)
- >>> # extract layer1 and layer3, giving as names `feat1` and feat2`
- >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
- >>> {'layer1': 'feat1', 'layer3': 'feat2'})
- >>> out = new_m(torch.rand(1, 3, 224, 224))
- >>> print([(k, v.shape) for k, v in out.items()])
- >>> [('feat1', torch.Size([1, 64, 56, 56])),
- >>> ('feat2', torch.Size([1, 256, 14, 14]))]
- """
- def __init__(self, model, return_layers):
- if not set(return_layers).issubset([name for name, _ in model.named_children()]):
- raise ValueError("return_layers are not present in model")
-
- orig_return_layers = return_layers
- return_layers = {k: v for k, v in return_layers.items()}
- layers = OrderedDict()
- for name, module in model.named_children():
- layers[name] = module
- if name in return_layers:
- del return_layers[name]
- if not return_layers:
- break
-
- super(IntermediateLayerGetter, self).__init__(layers)
- self.return_layers = orig_return_layers
-
- def forward(self, x):
- out = OrderedDict()
- for name, module in self.named_children():
- x = module(x)
- if name in self.return_layers:
- out_name = self.return_layers[name]
- out[out_name] = x
- return out
diff --git a/inference/interact/s2m_controller.py b/inference/interact/s2m_controller.py
deleted file mode 100644
index e222259..0000000
--- a/inference/interact/s2m_controller.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import torch
-import numpy as np
-from ..interact.s2m.s2m_network import deeplabv3plus_resnet50 as S2M
-
-from util.tensor_util import pad_divide_by, unpad
-
-
-class S2MController:
- """
- A controller for Scribble-to-Mask (for user interaction, not for DAVIS)
- Takes the image, previous mask, and scribbles to produce a new mask
- ignore_class is usually 255
- 0 is NOT the ignore class -- it is the label for the background
- """
- def __init__(self, s2m_net:S2M, num_objects, ignore_class, device='cuda:0'):
- self.s2m_net = s2m_net
- self.num_objects = num_objects
- self.ignore_class = ignore_class
- self.device = device
-
- def interact(self, image, prev_mask, scr_mask):
- image = image.to(self.device, non_blocking=True)
- prev_mask = prev_mask.unsqueeze(0)
-
- h, w = image.shape[-2:]
- unaggre_mask = torch.zeros((self.num_objects, h, w), dtype=torch.float32, device=image.device)
-
- for ki in range(1, self.num_objects+1):
- p_srb = (scr_mask==ki).astype(np.uint8)
- n_srb = ((scr_mask!=ki) * (scr_mask!=self.ignore_class)).astype(np.uint8)
-
- Rs = torch.from_numpy(np.stack([p_srb, n_srb], 0)).unsqueeze(0).float().to(image.device)
-
- inputs = torch.cat([image, (prev_mask==ki).float().unsqueeze(0), Rs], 1)
- inputs, pads = pad_divide_by(inputs, 16)
-
- unaggre_mask[ki-1] = unpad(torch.sigmoid(self.s2m_net(inputs)), pads)
-
- return unaggre_mask
\ No newline at end of file
diff --git a/inference/interact/timer.py b/inference/interact/timer.py
deleted file mode 100644
index d134aa1..0000000
--- a/inference/interact/timer.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import time
-
-class Timer:
- def __init__(self):
- self._acc_time = 0
- self._paused = True
-
- def start(self):
- if self._paused:
- self.last_time = time.time()
- self._paused = False
- return self
-
- def pause(self):
- self.count()
- self._paused = True
- return self
-
- def count(self):
- if self._paused:
- return self._acc_time
- t = time.time()
- self._acc_time += t - self.last_time
- self.last_time = t
- return self._acc_time
-
- def format(self):
- # count = int(self.count()*100)
- # return '%02d:%02d:%02d' % (count//6000, (count//100)%60, count%100)
- return '%03.2f' % self.count()
-
- def __str__(self):
- return self.format()
\ No newline at end of file
diff --git a/interactive_demo.py b/interactive_demo.py
deleted file mode 100644
index fef2140..0000000
--- a/interactive_demo.py
+++ /dev/null
@@ -1,95 +0,0 @@
-"""
-A simple user interface for XMem
-"""
-
-import os
-# fix for Windows
-if 'QT_QPA_PLATFORM_PLUGIN_PATH' not in os.environ:
- os.environ['QT_QPA_PLATFORM_PLUGIN_PATH'] = ''
-
-import sys
-from argparse import ArgumentParser
-
-import torch
-
-from model.network import XMem
-from inference.interact.s2m_controller import S2MController
-from inference.interact.fbrs_controller import FBRSController
-from inference.interact.s2m.s2m_network import deeplabv3plus_resnet50 as S2M
-
-from PyQt5.QtWidgets import QApplication
-from inference.interact.gui import App
-from inference.interact.resource_manager import ResourceManager
-
-torch.set_grad_enabled(False)
-
-
-if __name__ == '__main__':
-
- # Arguments parsing
- parser = ArgumentParser()
- parser.add_argument('--model', default='./saves/XMem.pth')
- parser.add_argument('--s2m_model', default='saves/s2m.pth')
- parser.add_argument('--fbrs_model', default='saves/fbrs.pth')
-
- """
- Priority 1: If a "images" folder exists in the workspace, we will read from that directory
- Priority 2: If --images is specified, we will copy/resize those images to the workspace
- Priority 3: If --video is specified, we will extract the frames to the workspace (in an "images" folder) and read from there
-
- In any case, if a "masks" folder exists in the workspace, we will use that to initialize the mask
- That way, you can continue annotation from an interrupted run as long as the same workspace is used.
- """
- parser.add_argument('--images', help='Folders containing input images.', default=None)
- parser.add_argument('--video', help='Video file readable by OpenCV.', default=None)
- parser.add_argument('--workspace', help='directory for storing buffered images (if needed) and output masks', default=None)
-
- parser.add_argument('--buffer_size', help='Correlate with CPU memory consumption', type=int, default=100)
-
- parser.add_argument('--num_objects', type=int, default=1)
-
- # Long-memory options
- # Defaults. Some can be changed in the GUI.
- parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10)
- parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5)
- parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time',
- type=int, default=10000)
- parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128)
-
- parser.add_argument('--top_k', type=int, default=30)
- parser.add_argument('--mem_every', type=int, default=10)
- parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1)
- parser.add_argument('--no_amp', help='Turn off AMP', action='store_true')
- parser.add_argument('--size', default=480, type=int,
- help='Resize the shorter side to this size. -1 to use original resolution. ')
- args = parser.parse_args()
-
- config = vars(args)
- config['enable_long_term'] = True
- config['enable_long_term_count_usage'] = True
-
- with torch.cuda.amp.autocast(enabled=not args.no_amp):
-
- # Load our checkpoint
- network = XMem(config, args.model).cuda().eval()
-
- # Loads the S2M model
- if args.s2m_model is not None:
- s2m_saved = torch.load(args.s2m_model)
- s2m_model = S2M().cuda().eval()
- s2m_model.load_state_dict(s2m_saved)
- else:
- s2m_model = None
-
- s2m_controller = S2MController(s2m_model, args.num_objects, ignore_class=255)
- if args.fbrs_model is not None:
- fbrs_controller = FBRSController(args.fbrs_model)
- else:
- fbrs_controller = None
-
- # Manages most IO
- resource_manager = ResourceManager(config)
-
- app = QApplication(sys.argv)
- ex = App(network, resource_manager, s2m_controller, fbrs_controller, config)
- sys.exit(app.exec_())
diff --git a/merge_multi_scale.py b/merge_multi_scale.py
deleted file mode 100644
index 43ef93f..0000000
--- a/merge_multi_scale.py
+++ /dev/null
@@ -1,135 +0,0 @@
-import os
-from os import path
-from argparse import ArgumentParser
-import glob
-from collections import defaultdict
-
-import numpy as np
-import hickle as hkl
-from PIL import Image, ImagePalette
-
-from progressbar import progressbar
-from multiprocessing import Pool
-from util import palette
-
-from util.palette import davis_palette, youtube_palette
-import shutil
-
-
-def search_options(options, name):
- for option in options:
- if path.exists(path.join(option, name)):
- return path.join(option, name)
- else:
- return None
-
-def process_vid(vid):
- vid_path = search_options(all_options, vid)
- if vid_path is not None:
- backward_mapping = hkl.load(path.join(vid_path, 'backward.hkl'))
- else:
- backward_mapping = None
-
- frames = os.listdir(path.join(all_options[0], vid))
- frames = [f for f in frames if 'backward' not in f]
-
- print(vid)
- if 'Y' in args.dataset:
- this_out_path = path.join(out_path, 'Annotations', vid)
- else:
- this_out_path = path.join(out_path, vid)
- os.makedirs(this_out_path, exist_ok=True)
-
- for f in frames:
- result_sum = None
-
- for option in all_options:
- if not path.exists(path.join(option, vid, f)):
- continue
-
- result = hkl.load(path.join(option, vid, f))
- if result_sum is None:
- result_sum = result.astype(np.float32)
- else:
- result_sum += result
-
- # argmax and to idx
- result_sum = np.argmax(result_sum, axis=0)
-
- # Remap the indices to the original domain
- if backward_mapping is not None:
- idx_mask = np.zeros_like(result_sum, dtype=np.uint8)
- for l, i in backward_mapping.items():
- idx_mask[result_sum==i] = l
- else:
- idx_mask = result_sum.astype(np.uint8)
-
- # Save the results
- img_E = Image.fromarray(idx_mask)
- img_E.putpalette(palette)
- img_E.save(path.join(this_out_path, f[:-4]+'.png'))
-
-
-if __name__ == '__main__':
- """
- Arguments loading
- """
- parser = ArgumentParser()
- parser.add_argument('--dataset', default='Y', help='D/Y, D for DAVIS; Y for YouTubeVOS')
- parser.add_argument('--list', nargs="+")
- parser.add_argument('--pattern', default=None, help='Glob patten. Can be used in place of list.')
- parser.add_argument('--output')
- parser.add_argument('--num_proc', default=4, type=int)
- args = parser.parse_args()
-
- out_path = args.output
-
- # Find the input candidates
- if args.pattern is None:
- all_options = args.list
- else:
- assert args.list is None, 'cannot specify both list and pattern'
- all_options = glob.glob(args.pattern)
-
- # Get the correct palette
- if 'D' in args.dataset:
- palette = ImagePalette.ImagePalette(mode='P', palette=davis_palette)
- elif 'Y' in args.dataset:
- palette = ImagePalette.ImagePalette(mode='P', palette=youtube_palette)
- else:
- raise NotImplementedError
-
- # Count of the number of videos in each candidate
- all_options = [path.join(o, 'Scores') for o in all_options]
- vid_count = defaultdict(int)
- for option in all_options:
- vid_in_here = sorted(os.listdir(option))
- for vid in vid_in_here:
- vid_count[vid] += 1
-
- all_vid = []
- count_to_vid = defaultdict(int)
- for k, v in vid_count.items():
- count_to_vid[v] += 1
- all_vid.append(k)
-
- for k, v in count_to_vid.items():
- print('Videos with count %d: %d' % (k, v))
-
- all_vid = sorted(all_vid)
- print('Total number of videos: ', len(all_vid))
-
- pool = Pool(processes=args.num_proc)
- for _ in progressbar(pool.imap_unordered(process_vid, all_vid), max_value=len(all_vid)):
- pass
-
- pool.close()
- pool.join()
-
- if 'D' in args.dataset:
- print('Making zip for DAVIS test-dev...')
- shutil.make_archive(args.output, 'zip', args.output)
-
- if 'Y' in args.dataset:
- print('Making zip for YouTubeVOS...')
- shutil.make_archive(path.join(args.output, path.basename(args.output)), 'zip', args.output, 'Annotations')
diff --git a/model/__init__.py b/model/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/model/__pycache__/__init__.cpython-39.pyc b/model/__pycache__/__init__.cpython-39.pyc
deleted file mode 100644
index d437260..0000000
Binary files a/model/__pycache__/__init__.cpython-39.pyc and /dev/null differ
diff --git a/model/__pycache__/aggregate.cpython-39.pyc b/model/__pycache__/aggregate.cpython-39.pyc
deleted file mode 100644
index 07b4d70..0000000
Binary files a/model/__pycache__/aggregate.cpython-39.pyc and /dev/null differ
diff --git a/model/__pycache__/cbam.cpython-39.pyc b/model/__pycache__/cbam.cpython-39.pyc
deleted file mode 100644
index 829e4e6..0000000
Binary files a/model/__pycache__/cbam.cpython-39.pyc and /dev/null differ
diff --git a/model/__pycache__/group_modules.cpython-39.pyc b/model/__pycache__/group_modules.cpython-39.pyc
deleted file mode 100644
index 3583085..0000000
Binary files a/model/__pycache__/group_modules.cpython-39.pyc and /dev/null differ
diff --git a/model/__pycache__/losses.cpython-39.pyc b/model/__pycache__/losses.cpython-39.pyc
deleted file mode 100644
index 4e1240f..0000000
Binary files a/model/__pycache__/losses.cpython-39.pyc and /dev/null differ
diff --git a/model/__pycache__/memory_util.cpython-39.pyc b/model/__pycache__/memory_util.cpython-39.pyc
deleted file mode 100644
index 8935461..0000000
Binary files a/model/__pycache__/memory_util.cpython-39.pyc and /dev/null differ
diff --git a/model/__pycache__/modules.cpython-39.pyc b/model/__pycache__/modules.cpython-39.pyc
deleted file mode 100644
index 1b606dc..0000000
Binary files a/model/__pycache__/modules.cpython-39.pyc and /dev/null differ
diff --git a/model/__pycache__/network.cpython-39.pyc b/model/__pycache__/network.cpython-39.pyc
deleted file mode 100644
index 9d79a20..0000000
Binary files a/model/__pycache__/network.cpython-39.pyc and /dev/null differ
diff --git a/model/__pycache__/resnet.cpython-39.pyc b/model/__pycache__/resnet.cpython-39.pyc
deleted file mode 100644
index 86fa7de..0000000
Binary files a/model/__pycache__/resnet.cpython-39.pyc and /dev/null differ
diff --git a/model/__pycache__/trainer.cpython-39.pyc b/model/__pycache__/trainer.cpython-39.pyc
deleted file mode 100644
index 1424203..0000000
Binary files a/model/__pycache__/trainer.cpython-39.pyc and /dev/null differ
diff --git a/requirements.txt b/requirements.txt
index 115c19a..4c1f7aa 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,5 @@ gitpython
git+https://github.com/cheind/py-thin-plate-spline
hickle
tensorboard
-numpy
\ No newline at end of file
+numpy
+git+https://github.com/facebookresearch/segment-anything.git
diff --git a/requirements_demo.txt b/requirements_demo.txt
deleted file mode 100644
index ad3d04c..0000000
--- a/requirements_demo.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-PyQt5
-Cython
-scipy
\ No newline at end of file
diff --git a/scripts/__init__.py b/scripts/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/scripts/download_bl30k.py b/scripts/download_bl30k.py
deleted file mode 100644
index 501a7e2..0000000
--- a/scripts/download_bl30k.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import os
-import gdown
-import tarfile
-
-
-LICENSE = """
-This dataset is a derivative of ShapeNet.
-Please read and respect their licenses and terms before use.
-Textures and skybox image are obtained from Google image search with the "non-commercial reuse" flag.
-Do not use this dataset for commercial purposes.
-You should cite both ShapeNet and our paper if you use this dataset.
-"""
-
-print(LICENSE)
-print('Datasets will be downloaded and extracted to ../BL30K')
-print('The script will download and extract the segment one by one')
-print('You are going to need ~1TB of free disk space')
-reply = input('[y] to confirm, others to exit: ')
-if reply != 'y':
- exit()
-
-links = [
- 'https://drive.google.com/uc?id=1z9V5zxLOJLNt1Uj7RFqaP2FZWKzyXvVc',
- 'https://drive.google.com/uc?id=11-IzgNwEAPxgagb67FSrBdzZR7OKAEdJ',
- 'https://drive.google.com/uc?id=1ZfIv6GTo-OGpXpoKen1fUvDQ0A_WoQ-Q',
- 'https://drive.google.com/uc?id=1G4eXgYS2kL7_Cc0x3N1g1x7Zl8D_aU_-',
- 'https://drive.google.com/uc?id=1Y8q0V_oBwJIY27W_6-8CD1dRqV2gNTdE',
- 'https://drive.google.com/uc?id=1nawBAazf_unMv46qGBHhWcQ4JXZ5883r',
-]
-
-names = [
- 'BL30K_a.tar',
- 'BL30K_b.tar',
- 'BL30K_c.tar',
- 'BL30K_d.tar',
- 'BL30K_e.tar',
- 'BL30K_f.tar',
-]
-
-for i, link in enumerate(links):
- print('Downloading segment %d/%d ...' % (i, len(links)))
- gdown.download(link, output='../%s' % names[i], quiet=False)
- print('Extracting...')
- with tarfile.open('../%s' % names[i], 'r') as tar_file:
- tar_file.extractall('../%s' % names[i])
- print('Cleaning up...')
- os.remove('../%s' % names[i])
-
-
-print('Done.')
\ No newline at end of file
diff --git a/scripts/download_datasets.py b/scripts/download_datasets.py
deleted file mode 100644
index 7537aea..0000000
--- a/scripts/download_datasets.py
+++ /dev/null
@@ -1,149 +0,0 @@
-import os
-import gdown
-import zipfile
-from scripts import resize_youtube
-
-
-LICENSE = """
-These are either re-distribution of the original datasets or derivatives (through simple processing) of the original datasets.
-Please read and respect their licenses and terms before use.
-You should cite the original papers if you use any of the datasets.
-
-For BL30K, see download_bl30k.py
-
-Links:
-DUTS: http://saliencydetection.net/duts
-HRSOD: https://github.com/yi94code/HRSOD
-FSS: https://github.com/HKUSTCV/FSS-1000
-ECSSD: https://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/dataset.html
-BIG: https://github.com/hkchengrex/CascadePSP
-
-YouTubeVOS: https://youtube-vos.org
-DAVIS: https://davischallenge.org/
-BL30K: https://github.com/hkchengrex/MiVOS
-Long-Time Video: https://github.com/xmlyqing00/AFB-URR
-"""
-
-print(LICENSE)
-print('Datasets will be downloaded and extracted to ../YouTube, ../YouTube2018, ../static, ../DAVIS, ../long_video_set')
-reply = input('[y] to confirm, others to exit: ')
-if reply != 'y':
- exit()
-
-
-"""
-Static image data
-"""
-os.makedirs('../static', exist_ok=True)
-print('Downloading static datasets...')
-gdown.download('https://drive.google.com/uc?id=1wUJq3HcLdN-z1t4CsUhjeZ9BVDb9YKLd', output='../static/static_data.zip', quiet=False)
-print('Extracting static datasets...')
-with zipfile.ZipFile('../static/static_data.zip', 'r') as zip_file:
- zip_file.extractall('../static/')
-print('Cleaning up static datasets...')
-os.remove('../static/static_data.zip')
-
-
-"""
-DAVIS dataset
-"""
-# Google drive mirror: https://drive.google.com/drive/folders/1hEczGHw7qcMScbCJukZsoOW4Q9byx16A?usp=sharing
-os.makedirs('../DAVIS/2017', exist_ok=True)
-
-print('Downloading DAVIS 2016...')
-gdown.download('https://drive.google.com/uc?id=198aRlh5CpAoFz0hfRgYbiNenn_K8DxWD', output='../DAVIS/DAVIS-data.zip', quiet=False)
-
-print('Downloading DAVIS 2017 trainval...')
-gdown.download('https://drive.google.com/uc?id=1kiaxrX_4GuW6NmiVuKGSGVoKGWjOdp6d', output='../DAVIS/2017/DAVIS-2017-trainval-480p.zip', quiet=False)
-
-print('Downloading DAVIS 2017 testdev...')
-gdown.download('https://drive.google.com/uc?id=1fmkxU2v9cQwyb62Tj1xFDdh2p4kDsUzD', output='../DAVIS/2017/DAVIS-2017-test-dev-480p.zip', quiet=False)
-
-print('Downloading DAVIS 2017 scribbles...')
-gdown.download('https://drive.google.com/uc?id=1JzIQSu36h7dVM8q0VoE4oZJwBXvrZlkl', output='../DAVIS/2017/DAVIS-2017-scribbles-trainval.zip', quiet=False)
-
-print('Extracting DAVIS datasets...')
-with zipfile.ZipFile('../DAVIS/DAVIS-data.zip', 'r') as zip_file:
- zip_file.extractall('../DAVIS/')
-os.rename('../DAVIS/DAVIS', '../DAVIS/2016')
-
-with zipfile.ZipFile('../DAVIS/2017/DAVIS-2017-trainval-480p.zip', 'r') as zip_file:
- zip_file.extractall('../DAVIS/2017/')
-with zipfile.ZipFile('../DAVIS/2017/DAVIS-2017-scribbles-trainval.zip', 'r') as zip_file:
- zip_file.extractall('../DAVIS/2017/')
-os.rename('../DAVIS/2017/DAVIS', '../DAVIS/2017/trainval')
-
-with zipfile.ZipFile('../DAVIS/2017/DAVIS-2017-test-dev-480p.zip', 'r') as zip_file:
- zip_file.extractall('../DAVIS/2017/')
-os.rename('../DAVIS/2017/DAVIS', '../DAVIS/2017/test-dev')
-
-print('Cleaning up DAVIS datasets...')
-os.remove('../DAVIS/2017/DAVIS-2017-trainval-480p.zip')
-os.remove('../DAVIS/2017/DAVIS-2017-test-dev-480p.zip')
-os.remove('../DAVIS/2017/DAVIS-2017-scribbles-trainval.zip')
-os.remove('../DAVIS/DAVIS-data.zip')
-
-
-"""
-YouTubeVOS dataset
-"""
-os.makedirs('../YouTube', exist_ok=True)
-os.makedirs('../YouTube/all_frames', exist_ok=True)
-
-print('Downloading YouTubeVOS train...')
-gdown.download('https://drive.google.com/uc?id=13Eqw0gVK-AO5B-cqvJ203mZ2vzWck9s4', output='../YouTube/train.zip', quiet=False)
-print('Downloading YouTubeVOS val...')
-gdown.download('https://drive.google.com/uc?id=1o586Wjya-f2ohxYf9C1RlRH-gkrzGS8t', output='../YouTube/valid.zip', quiet=False)
-print('Downloading YouTubeVOS all frames valid...')
-gdown.download('https://drive.google.com/uc?id=1rWQzZcMskgpEQOZdJPJ7eTmLCBEIIpEN', output='../YouTube/all_frames/valid.zip', quiet=False)
-
-print('Extracting YouTube datasets...')
-with zipfile.ZipFile('../YouTube/train.zip', 'r') as zip_file:
- zip_file.extractall('../YouTube/')
-with zipfile.ZipFile('../YouTube/valid.zip', 'r') as zip_file:
- zip_file.extractall('../YouTube/')
-with zipfile.ZipFile('../YouTube/all_frames/valid.zip', 'r') as zip_file:
- zip_file.extractall('../YouTube/all_frames')
-
-print('Cleaning up YouTubeVOS datasets...')
-os.remove('../YouTube/train.zip')
-os.remove('../YouTube/valid.zip')
-os.remove('../YouTube/all_frames/valid.zip')
-
-print('Resizing YouTubeVOS to 480p...')
-resize_youtube.resize_all('../YouTube/train', '../YouTube/train_480p')
-
-# YouTubeVOS 2018
-os.makedirs('../YouTube2018', exist_ok=True)
-os.makedirs('../YouTube2018/all_frames', exist_ok=True)
-
-print('Downloading YouTubeVOS2018 val...')
-gdown.download('https://drive.google.com/uc?id=1-QrceIl5sUNTKz7Iq0UsWC6NLZq7girr', output='../YouTube2018/valid.zip', quiet=False)
-print('Downloading YouTubeVOS2018 all frames valid...')
-gdown.download('https://drive.google.com/uc?id=1yVoHM6zgdcL348cFpolFcEl4IC1gorbV', output='../YouTube2018/all_frames/valid.zip', quiet=False)
-
-print('Extracting YouTube2018 datasets...')
-with zipfile.ZipFile('../YouTube2018/valid.zip', 'r') as zip_file:
- zip_file.extractall('../YouTube2018/')
-with zipfile.ZipFile('../YouTube2018/all_frames/valid.zip', 'r') as zip_file:
- zip_file.extractall('../YouTube2018/all_frames')
-
-print('Cleaning up YouTubeVOS2018 datasets...')
-os.remove('../YouTube2018/valid.zip')
-os.remove('../YouTube2018/all_frames/valid.zip')
-
-
-"""
-Long-Time Video dataset
-"""
-os.makedirs('../long_video_set', exist_ok=True)
-print('Downloading long video dataset...')
-gdown.download('https://drive.google.com/uc?id=100MxAuV0_UL20ca5c-5CNpqQ5QYPDSoz', output='../long_video_set/LongTimeVideo.zip', quiet=False)
-print('Extracting long video dataset...')
-with zipfile.ZipFile('../long_video_set/LongTimeVideo.zip', 'r') as zip_file:
- zip_file.extractall('../long_video_set/')
-print('Cleaning up long video dataset...')
-os.remove('../long_video_set/LongTimeVideo.zip')
-
-
-print('Done.')
\ No newline at end of file
diff --git a/scripts/download_models.sh b/scripts/download_models.sh
deleted file mode 100755
index ba669b5..0000000
--- a/scripts/download_models.sh
+++ /dev/null
@@ -1,2 +0,0 @@
-wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem.pth
-wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth
\ No newline at end of file
diff --git a/scripts/download_models_demo.sh b/scripts/download_models_demo.sh
deleted file mode 100755
index a63f700..0000000
--- a/scripts/download_models_demo.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem.pth
-wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/fbrs.pth
-wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/s2m.pth
\ No newline at end of file
diff --git a/scripts/expand_long_vid.py b/scripts/expand_long_vid.py
deleted file mode 100644
index ae237bc..0000000
--- a/scripts/expand_long_vid.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import sys
-import os
-from os import path
-from shutil import copy2
-
-input_path = sys.argv[1]
-output_path = sys.argv[2]
-multiplier = int(sys.argv[3])
-image_path = path.join(input_path, 'JPEGImages')
-gt_path = path.join(input_path, 'Annotations')
-
-videos = sorted(os.listdir(image_path))
-
-for vid in videos:
- os.makedirs(path.join(output_path, 'JPEGImages', vid), exist_ok=True)
- os.makedirs(path.join(output_path, 'Annotations', vid), exist_ok=True)
- frames = sorted(os.listdir(path.join(image_path, vid)))
-
- num_frames = len(frames)
- counter = 0
- output_counter = 0
- direction = 1
- for _ in range(multiplier):
- for _ in range(num_frames):
- copy2(path.join(image_path, vid, frames[counter]),
- path.join(output_path, 'JPEGImages', vid, f'{output_counter:05d}.jpg'))
-
- mask_path = path.join(gt_path, vid, frames[counter].replace('.jpg', '.png'))
- if path.exists(mask_path):
- copy2(mask_path,
- path.join(output_path, 'Annotations', vid, f'{output_counter:05d}.png'))
-
- counter += direction
- output_counter += 1
- if counter == 0 or counter == len(frames) - 1:
- direction *= -1
diff --git a/scripts/resize_youtube.py b/scripts/resize_youtube.py
deleted file mode 100644
index 501e9f5..0000000
--- a/scripts/resize_youtube.py
+++ /dev/null
@@ -1,77 +0,0 @@
-import sys
-import os
-from os import path
-
-from PIL import Image
-import numpy as np
-from progressbar import progressbar
-from multiprocessing import Pool
-
-new_min_size = 480
-
-def resize_vid_jpeg(inputs):
- vid_name, folder_path, out_path = inputs
-
- vid_path = path.join(folder_path, vid_name)
- vid_out_path = path.join(out_path, 'JPEGImages', vid_name)
- os.makedirs(vid_out_path, exist_ok=True)
-
- for im_name in os.listdir(vid_path):
- hr_im = Image.open(path.join(vid_path, im_name))
- w, h = hr_im.size
-
- ratio = new_min_size / min(w, h)
-
- lr_im = hr_im.resize((int(w*ratio), int(h*ratio)), Image.BICUBIC)
- lr_im.save(path.join(vid_out_path, im_name))
-
-def resize_vid_anno(inputs):
- vid_name, folder_path, out_path = inputs
-
- vid_path = path.join(folder_path, vid_name)
- vid_out_path = path.join(out_path, 'Annotations', vid_name)
- os.makedirs(vid_out_path, exist_ok=True)
-
- for im_name in os.listdir(vid_path):
- hr_im = Image.open(path.join(vid_path, im_name)).convert('P')
- w, h = hr_im.size
-
- ratio = new_min_size / min(w, h)
-
- lr_im = hr_im.resize((int(w*ratio), int(h*ratio)), Image.NEAREST)
- lr_im.save(path.join(vid_out_path, im_name))
-
-
-def resize_all(in_path, out_path):
- for folder in os.listdir(in_path):
-
- if folder not in ['JPEGImages', 'Annotations']:
- continue
- folder_path = path.join(in_path, folder)
- videos = os.listdir(folder_path)
-
- videos = [(v, folder_path, out_path) for v in videos]
-
- if folder == 'JPEGImages':
- print('Processing images')
- os.makedirs(path.join(out_path, 'JPEGImages'), exist_ok=True)
-
- pool = Pool(processes=8)
- for _ in progressbar(pool.imap_unordered(resize_vid_jpeg, videos), max_value=len(videos)):
- pass
- else:
- print('Processing annotations')
- os.makedirs(path.join(out_path, 'Annotations'), exist_ok=True)
-
- pool = Pool(processes=8)
- for _ in progressbar(pool.imap_unordered(resize_vid_anno, videos), max_value=len(videos)):
- pass
-
-
-if __name__ == '__main__':
- in_path = sys.argv[1]
- out_path = sys.argv[2]
-
- resize_all(in_path, out_path)
-
- print('Done.')
\ No newline at end of file
diff --git a/dataset/__init__.py b/tracker/dataset/__init__.py
similarity index 100%
rename from dataset/__init__.py
rename to tracker/dataset/__init__.py
diff --git a/dataset/range_transform.py b/tracker/dataset/range_transform.py
similarity index 100%
rename from dataset/range_transform.py
rename to tracker/dataset/range_transform.py
diff --git a/dataset/reseed.py b/tracker/dataset/reseed.py
similarity index 100%
rename from dataset/reseed.py
rename to tracker/dataset/reseed.py
diff --git a/dataset/static_dataset.py b/tracker/dataset/static_dataset.py
similarity index 100%
rename from dataset/static_dataset.py
rename to tracker/dataset/static_dataset.py
diff --git a/dataset/tps.py b/tracker/dataset/tps.py
similarity index 100%
rename from dataset/tps.py
rename to tracker/dataset/tps.py
diff --git a/dataset/util.py b/tracker/dataset/util.py
similarity index 100%
rename from dataset/util.py
rename to tracker/dataset/util.py
diff --git a/dataset/vos_dataset.py b/tracker/dataset/vos_dataset.py
similarity index 100%
rename from dataset/vos_dataset.py
rename to tracker/dataset/vos_dataset.py
diff --git a/eval.py b/tracker/eval.py
similarity index 97%
rename from eval.py
rename to tracker/eval.py
index 3bb1da9..d1ddb32 100644
--- a/eval.py
+++ b/tracker/eval.py
@@ -26,12 +26,12 @@ except ImportError:
Arguments loading
"""
parser = ArgumentParser()
-parser.add_argument('--model', default='./saves/XMem.pth')
+parser.add_argument('--model', default='/ssd1/gaomingqi/checkpoints/XMem-s012.pth')
# Data options
parser.add_argument('--d16_path', default='../DAVIS/2016')
parser.add_argument('--d17_path', default='../DAVIS/2017')
-parser.add_argument('--y18_path', default='../YouTube2018')
+parser.add_argument('--y18_path', default='/ssd1/gaomingqi/datasets/youtube-vos/2018')
parser.add_argument('--y19_path', default='../YouTube')
parser.add_argument('--lv_path', default='../long_video_set')
# For generic (G) evaluation, point to a folder that contains "JPEGImages" and "Annotations"
@@ -205,6 +205,9 @@ for vid_reader in progressbar(meta_loader, max_value=len(meta_dataset), redirect
# Run the model on this frame
prob = processor.step(rgb, msk, labels, end=(ti==vid_length-1))
+ # consider prob as prompt to refine segment results
+
+
# Upsample to original size if needed
if need_resize:
prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:,0]
diff --git a/inference/__init__.py b/tracker/inference/__init__.py
similarity index 100%
rename from inference/__init__.py
rename to tracker/inference/__init__.py
diff --git a/inference/data/__init__.py b/tracker/inference/data/__init__.py
similarity index 100%
rename from inference/data/__init__.py
rename to tracker/inference/data/__init__.py
diff --git a/inference/data/mask_mapper.py b/tracker/inference/data/mask_mapper.py
similarity index 100%
rename from inference/data/mask_mapper.py
rename to tracker/inference/data/mask_mapper.py
diff --git a/inference/data/test_datasets.py b/tracker/inference/data/test_datasets.py
similarity index 100%
rename from inference/data/test_datasets.py
rename to tracker/inference/data/test_datasets.py
diff --git a/inference/data/video_reader.py b/tracker/inference/data/video_reader.py
similarity index 100%
rename from inference/data/video_reader.py
rename to tracker/inference/data/video_reader.py
diff --git a/inference/inference_core.py b/tracker/inference/inference_core.py
similarity index 99%
rename from inference/inference_core.py
rename to tracker/inference/inference_core.py
index f5459df..c4cbffd 100644
--- a/inference/inference_core.py
+++ b/tracker/inference/inference_core.py
@@ -62,6 +62,9 @@ class InferenceCore:
# segment the current frame is needed
if need_segment:
memory_readout = self.memory.match_memory(key, selection).unsqueeze(0)
+
+
+
hidden, _, pred_prob_with_bg = self.network.segment(multi_scale_features, memory_readout,
self.memory.get_hidden(), h_out=is_normal_update, strip_bg=False)
# remove batch dim
diff --git a/inference/kv_memory_store.py b/tracker/inference/kv_memory_store.py
similarity index 100%
rename from inference/kv_memory_store.py
rename to tracker/inference/kv_memory_store.py
diff --git a/inference/memory_manager.py b/tracker/inference/memory_manager.py
similarity index 100%
rename from inference/memory_manager.py
rename to tracker/inference/memory_manager.py
diff --git a/inference/interact/__init__.py b/tracker/model/__init__.py
similarity index 100%
rename from inference/interact/__init__.py
rename to tracker/model/__init__.py
diff --git a/model/aggregate.py b/tracker/model/aggregate.py
similarity index 100%
rename from model/aggregate.py
rename to tracker/model/aggregate.py
diff --git a/model/cbam.py b/tracker/model/cbam.py
similarity index 100%
rename from model/cbam.py
rename to tracker/model/cbam.py
diff --git a/model/group_modules.py b/tracker/model/group_modules.py
similarity index 100%
rename from model/group_modules.py
rename to tracker/model/group_modules.py
diff --git a/model/losses.py b/tracker/model/losses.py
similarity index 100%
rename from model/losses.py
rename to tracker/model/losses.py
diff --git a/model/memory_util.py b/tracker/model/memory_util.py
similarity index 100%
rename from model/memory_util.py
rename to tracker/model/memory_util.py
diff --git a/model/modules.py b/tracker/model/modules.py
similarity index 100%
rename from model/modules.py
rename to tracker/model/modules.py
diff --git a/model/network.py b/tracker/model/network.py
similarity index 100%
rename from model/network.py
rename to tracker/model/network.py
diff --git a/model/resnet.py b/tracker/model/resnet.py
similarity index 100%
rename from model/resnet.py
rename to tracker/model/resnet.py
diff --git a/model/trainer.py b/tracker/model/trainer.py
similarity index 100%
rename from model/trainer.py
rename to tracker/model/trainer.py
diff --git a/inference/interact/fbrs/__init__.py b/tracker/util/__init__.py
similarity index 100%
rename from inference/interact/fbrs/__init__.py
rename to tracker/util/__init__.py
diff --git a/util/configuration.py b/tracker/util/configuration.py
similarity index 100%
rename from util/configuration.py
rename to tracker/util/configuration.py
diff --git a/util/davis_subset.txt b/tracker/util/davis_subset.txt
similarity index 100%
rename from util/davis_subset.txt
rename to tracker/util/davis_subset.txt
diff --git a/util/image_saver.py b/tracker/util/image_saver.py
similarity index 100%
rename from util/image_saver.py
rename to tracker/util/image_saver.py
diff --git a/util/load_subset.py b/tracker/util/load_subset.py
similarity index 100%
rename from util/load_subset.py
rename to tracker/util/load_subset.py
diff --git a/util/log_integrator.py b/tracker/util/log_integrator.py
similarity index 100%
rename from util/log_integrator.py
rename to tracker/util/log_integrator.py
diff --git a/util/logger.py b/tracker/util/logger.py
similarity index 100%
rename from util/logger.py
rename to tracker/util/logger.py
diff --git a/util/palette.py b/tracker/util/palette.py
similarity index 100%
rename from util/palette.py
rename to tracker/util/palette.py
diff --git a/util/tensor_util.py b/tracker/util/tensor_util.py
similarity index 100%
rename from util/tensor_util.py
rename to tracker/util/tensor_util.py
diff --git a/util/yv_subset.txt b/tracker/util/yv_subset.txt
similarity index 100%
rename from util/yv_subset.txt
rename to tracker/util/yv_subset.txt
diff --git a/train.py b/train.py
deleted file mode 100644
index e5418c7..0000000
--- a/train.py
+++ /dev/null
@@ -1,253 +0,0 @@
-import datetime
-from os import path
-import math
-import git
-
-import random
-import numpy as np
-import torch
-from torch.utils.data import DataLoader, ConcatDataset
-import torch.distributed as distributed
-
-from model.trainer import XMemTrainer
-from dataset.static_dataset import StaticTransformDataset
-from dataset.vos_dataset import VOSDataset
-
-from util.logger import TensorboardLogger
-from util.configuration import Configuration
-from util.load_subset import load_sub_davis, load_sub_yv
-
-
-"""
-Initial setup
-"""
-# Init distributed environment
-distributed.init_process_group(backend="nccl")
-print(f'CUDA Device count: {torch.cuda.device_count()}')
-
-# Parse command line arguments
-raw_config = Configuration()
-raw_config.parse()
-
-if raw_config['benchmark']:
- torch.backends.cudnn.benchmark = True
-
-# Get current git info
-# repo = git.Repo(".")
-# git_info = str(repo.active_branch)+' '+str(repo.head.commit.hexsha)
-
-local_rank = torch.distributed.get_rank()
-world_size = torch.distributed.get_world_size()
-torch.cuda.set_device(local_rank)
-
-print(f'I am rank {local_rank} in this world of size {world_size}!')
-
-network_in_memory = None
-stages = raw_config['stages']
-stages_to_perform = list(stages)
-for si, stage in enumerate(stages_to_perform):
-
- # Set seed to ensure the same initialization
- torch.manual_seed(14159265)
- np.random.seed(14159265)
- random.seed(14159265)
-
- # Pick stage specific hyperparameters out
- stage_config = raw_config.get_stage_parameters(stage)
- config = dict(**raw_config.args, **stage_config)
- if config['exp_id'] != 'NULL':
- config['exp_id'] = config['exp_id']+'_s%s'%stages[:si+1]
-
- config['single_object'] = (stage == '0')
-
- config['num_gpus'] = world_size
- if config['batch_size']//config['num_gpus']*config['num_gpus'] != config['batch_size']:
- raise ValueError('Batch size must be divisible by the number of GPUs.')
- config['batch_size'] //= config['num_gpus']
- config['num_workers'] //= config['num_gpus']
- print(f'We are assuming {config["num_gpus"]} GPUs.')
-
- print(f'We are now starting stage {stage}')
-
- """
- Model related
- """
- if local_rank == 0:
- # Logging
- if config['exp_id'].lower() != 'null':
- print('I will take the role of logging!')
- long_id = '%s_%s' % (datetime.datetime.now().strftime('%b%d_%H.%M.%S'), config['exp_id'])
- else:
- long_id = None
- logger = TensorboardLogger(config['exp_id'], long_id, long_id)
- logger.log_string('hyperpara', str(config))
-
- # Construct the rank 0 model
- model = XMemTrainer(config, logger=logger,
- save_path=path.join(config['save_path'], long_id, long_id) if long_id is not None else None,
- local_rank=local_rank, world_size=world_size).train()
- else:
- # Construct model for other ranks
- model = XMemTrainer(config, local_rank=local_rank, world_size=world_size).train()
-
- # Load pertrained model if needed
- if raw_config['load_checkpoint'] is not None:
- total_iter = model.load_checkpoint(raw_config['load_checkpoint'])
- raw_config['load_checkpoint'] = None
- print('Previously trained model loaded!')
- else:
- total_iter = 0
-
- if network_in_memory is not None:
- print('I am loading network from the previous stage')
- model.load_network_in_memory(network_in_memory)
- network_in_memory = None
- elif raw_config['load_network'] is not None:
- print('I am loading network from a disk, as listed in configuration')
- model.load_network(raw_config['load_network'])
- raw_config['load_network'] = None
-
- """
- Dataloader related
- """
- # To re-seed the randomness everytime we start a worker
- def worker_init_fn(worker_id):
- worker_seed = torch.initial_seed()%(2**31) + worker_id + local_rank*100
- np.random.seed(worker_seed)
- random.seed(worker_seed)
-
- def construct_loader(dataset):
- train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, rank=local_rank, shuffle=True)
- train_loader = DataLoader(dataset, config['batch_size'], sampler=train_sampler, num_workers=config['num_workers'],
- worker_init_fn=worker_init_fn, drop_last=True)
- return train_sampler, train_loader
-
- def renew_vos_loader(max_skip, finetune=False):
- # //5 because we only have annotation for every five frames
- yv_dataset = VOSDataset(path.join(yv_root, 'JPEGImages'),
- path.join(yv_root, 'Annotations'), max_skip//5, is_bl=False, subset=load_sub_yv(), num_frames=config['num_frames'], finetune=finetune)
- davis_dataset = VOSDataset(path.join(davis_root, 'JPEGImages', '480p'),
- path.join(davis_root, 'Annotations', '480p'), max_skip, is_bl=False, subset=load_sub_davis(), num_frames=config['num_frames'], finetune=finetune)
- train_dataset = ConcatDataset([davis_dataset]*5 + [yv_dataset])
-
- print(f'YouTube dataset size: {len(yv_dataset)}')
- print(f'DAVIS dataset size: {len(davis_dataset)}')
- print(f'Concat dataset size: {len(train_dataset)}')
- print(f'Renewed with {max_skip=}')
-
- return construct_loader(train_dataset)
-
- def renew_bl_loader(max_skip, finetune=False):
- train_dataset = VOSDataset(path.join(bl_root, 'JPEGImages'),
- path.join(bl_root, 'Annotations'), max_skip, is_bl=True, num_frames=config['num_frames'], finetune=finetune)
-
- print(f'Blender dataset size: {len(train_dataset)}')
- print(f'Renewed with {max_skip=}')
-
- return construct_loader(train_dataset)
-
- """
- Dataset related
- """
-
- """
- These define the training schedule of the distance between frames
- We will switch to max_skip_values[i] once we pass the percentage specified by increase_skip_fraction[i]
- Not effective for stage 0 training
- The initial value is not listed here but in renew_vos_loader(X)
- """
- max_skip_values = [10, 15, 5, 5]
-
- if stage == '0':
- static_root = path.expanduser(config['static_root'])
- # format: path, method (style of storing images), mutliplier
- train_dataset = StaticTransformDataset(
- [
- (path.join(static_root, 'fss'), 0, 1),
- (path.join(static_root, 'DUTS-TR'), 1, 1),
- (path.join(static_root, 'DUTS-TE'), 1, 1),
- (path.join(static_root, 'ecssd'), 1, 1),
- (path.join(static_root, 'BIG_small'), 1, 5),
- (path.join(static_root, 'HRSOD_small'), 1, 5),
- ], num_frames=config['num_frames'])
- train_sampler, train_loader = construct_loader(train_dataset)
-
- print(f'Static dataset size: {len(train_dataset)}')
- elif stage == '1':
- increase_skip_fraction = [0.1, 0.3, 0.8, 100]
- bl_root = path.join(path.expanduser(config['bl_root']))
-
- train_sampler, train_loader = renew_bl_loader(5)
- renew_loader = renew_bl_loader
- else:
- # stage 2 or 3
- increase_skip_fraction = [0.1, 0.3, 0.9, 100]
- # VOS dataset, 480p is used for both datasets
- yv_root = path.join(path.expanduser(config['yv_root']), 'train')
- davis_root = path.join(path.expanduser(config['davis_root'])) #, '2017', 'trainval')
-
- train_sampler, train_loader = renew_vos_loader(5)
- renew_loader = renew_vos_loader
-
-
- """
- Determine max epoch
- """
- total_epoch = math.ceil(config['iterations']/len(train_loader))
- current_epoch = total_iter // len(train_loader)
- print(f'We approximately use {total_epoch} epochs.')
- if stage != '0':
- change_skip_iter = [round(config['iterations']*f) for f in increase_skip_fraction]
- # Skip will only change after an epoch, not in the middle
- print(f'The skip value will change approximately at the following iterations: {change_skip_iter[:-1]}')
-
- """
- Starts training
- """
- finetuning = False
- # Need this to select random bases in different workers
- np.random.seed(np.random.randint(2**30-1) + local_rank*100)
-
- max_iter = config['iterations'] + config['finetune']
-
- try:
- while total_iter < config['iterations'] + config['finetune']:
-
- # Crucial for randomness!
- train_sampler.set_epoch(current_epoch)
- current_epoch += 1
- print(f'Current epoch: {current_epoch}')
-
- # Train loop
- model.train()
- for data in train_loader:
- # Update skip if needed
- if stage!='0' and total_iter >= change_skip_iter[0]:
- while total_iter >= change_skip_iter[0]:
- cur_skip = max_skip_values[0]
- max_skip_values = max_skip_values[1:]
- change_skip_iter = change_skip_iter[1:]
- print(f'Changing skip to {cur_skip=}')
- train_sampler, train_loader = renew_loader(cur_skip)
- break
-
- # fine-tune means fewer augmentations to train the sensory memory
- if config['finetune'] > 0 and not finetuning and total_iter >= config['iterations']:
- train_sampler, train_loader = renew_loader(cur_skip, finetune=True)
- finetuning = True
- model.save_network_interval = 1000
- break
-
- model.do_pass(data, max_iter, total_iter)
- total_iter += 1
-
- if total_iter >= config['iterations'] + config['finetune']:
- break
- finally:
- if not config['debug'] and model.logger is not None and total_iter>5000:
- model.save_network(total_iter)
- model.save_checkpoint(total_iter)
-
- network_in_memory = model.XMem.module.state_dict()
-
-distributed.destroy_process_group()
diff --git a/util/__init__.py b/util/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/util/__pycache__/__init__.cpython-39.pyc b/util/__pycache__/__init__.cpython-39.pyc
deleted file mode 100644
index 53e8c89..0000000
Binary files a/util/__pycache__/__init__.cpython-39.pyc and /dev/null differ
diff --git a/util/__pycache__/configuration.cpython-39.pyc b/util/__pycache__/configuration.cpython-39.pyc
deleted file mode 100644
index 59c5602..0000000
Binary files a/util/__pycache__/configuration.cpython-39.pyc and /dev/null differ
diff --git a/util/__pycache__/image_saver.cpython-39.pyc b/util/__pycache__/image_saver.cpython-39.pyc
deleted file mode 100644
index b1d7c1b..0000000
Binary files a/util/__pycache__/image_saver.cpython-39.pyc and /dev/null differ
diff --git a/util/__pycache__/load_subset.cpython-39.pyc b/util/__pycache__/load_subset.cpython-39.pyc
deleted file mode 100644
index 1e43c88..0000000
Binary files a/util/__pycache__/load_subset.cpython-39.pyc and /dev/null differ
diff --git a/util/__pycache__/log_integrator.cpython-39.pyc b/util/__pycache__/log_integrator.cpython-39.pyc
deleted file mode 100644
index 990cca9..0000000
Binary files a/util/__pycache__/log_integrator.cpython-39.pyc and /dev/null differ
diff --git a/util/__pycache__/logger.cpython-39.pyc b/util/__pycache__/logger.cpython-39.pyc
deleted file mode 100644
index f3d14e4..0000000
Binary files a/util/__pycache__/logger.cpython-39.pyc and /dev/null differ