"""
|
|
/*****************************************************************************/
|
|
|
|
BatchNorm2dSync with multi-gpu
|
|
|
|
code referenced from : https://github.com/mapillary/inplace_abn
|
|
|
|
/*****************************************************************************/
|
|
"""
|
|
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch.cuda.comm as comm
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from ._csrc import _backend


def _count_samples(x):
    count = 1
    for i, s in enumerate(x.size()):
        if i != 1:
            count *= s
    return count
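
# For an NCHW input of shape (8, 64, 32, 32), _count_samples returns
# 8 * 32 * 32 = 8192: every dimension except the channel dimension contributes,
# i.e. the number of samples behind each channel's statistics on one GPU.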


class BatchNorm2dSyncFunc(Function):
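    # Queue-based protocol used by forward/backward below: every worker replica
    # pushes its per-GPU partial sums into `master_queue`; the master gathers
    # them, reduces them to global statistics, and broadcasts the result back
    # through each worker's own `worker_queue`.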

    @staticmethod
    def forward(ctx, x, weight, bias, running_mean, running_var,
                extra, compute_stats=True, momentum=0.1, eps=1e-05):
        def _parse_extra(ctx, extra):
            ctx.is_master = extra["is_master"]
            if ctx.is_master:
                ctx.master_queue = extra["master_queue"]
                ctx.worker_queues = extra["worker_queues"]
                ctx.worker_ids = extra["worker_ids"]
            else:
                ctx.master_queue = extra["master_queue"]
                ctx.worker_queue = extra["worker_queue"]
        # Save context
        if extra is not None:
            _parse_extra(ctx, extra)
        ctx.compute_stats = compute_stats
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.affine = weight is not None and bias is not None
        if ctx.compute_stats:
            N = _count_samples(x) * (ctx.master_queue.maxsize + 1)
            assert N > 1
            # 1. compute sum(x) and sum(x^2)
            xsum, xsqsum = _backend.syncbn_sum_sqsum(x.detach())
            if ctx.is_master:
                xsums, xsqsums = [xsum], [xsqsum]
                # master : gather all sum(x) and sum(x^2) from slaves
                for _ in range(ctx.master_queue.maxsize):
                    xsum_w, xsqsum_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    xsums.append(xsum_w)
                    xsqsums.append(xsqsum_w)
                xsum = comm.reduce_add(xsums)
                xsqsum = comm.reduce_add(xsqsums)
                mean = xsum / N
                sumvar = xsqsum - xsum * mean
                var = sumvar / N
                uvar = sumvar / (N - 1)
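                # `var` is the biased estimate (divide by N) used to normalize
                # the batch; `uvar` applies Bessel's correction (divide by N - 1)
                # and is what gets folded into `running_var` below.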
                # master : broadcast global mean, variance to all slaves
                tensors = comm.broadcast_coalesced(
                    (mean, uvar, var), [mean.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                # slave : send sum(x) and sum(x^2) to master
                ctx.master_queue.put((xsum, xsqsum))
                # slave : get global mean and variance
                mean, uvar, var = ctx.worker_queue.get()
                ctx.worker_queue.task_done()

            # Update running stats
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar)
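            # The two in-place updates above are an exponential moving average:
            # running <- (1 - momentum) * running + momentum * batch_stat,
            # with the unbiased `uvar` feeding running_var.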
            ctx.N = N
            ctx.save_for_backward(x, weight, bias, mean, var)
        else:
            mean, var = running_mean, running_var

        # do batch norm forward
        z = _backend.syncbn_forward(x, weight, bias, mean, var,
                                    ctx.affine, ctx.eps)
        return z

    @staticmethod
    @once_differentiable
    def backward(ctx, dz):
        x, weight, bias, mean, var = ctx.saved_tensors
        dz = dz.contiguous()

        # 1. compute \sum(\frac{dJ}{dy_i}) and \sum(\frac{dJ}{dy_i}*\hat{x_i})
        sum_dz, sum_dz_xhat = _backend.syncbn_backward_xhat(
            dz, x, mean, var, ctx.eps)
        if ctx.is_master:
            sum_dzs, sum_dz_xhats = [sum_dz], [sum_dz_xhat]
            # master : gather from slaves
            for _ in range(ctx.master_queue.maxsize):
                sum_dz_w, sum_dz_xhat_w = ctx.master_queue.get()
                ctx.master_queue.task_done()
                sum_dzs.append(sum_dz_w)
                sum_dz_xhats.append(sum_dz_xhat_w)
            # master : compute global stats
            sum_dz = comm.reduce_add(sum_dzs)
            sum_dz_xhat = comm.reduce_add(sum_dz_xhats)
            sum_dz /= ctx.N
            sum_dz_xhat /= ctx.N
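            # After dividing by ctx.N these are the per-channel means of
            # dJ/dy and dJ/dy * x_hat, the cross-GPU reductions the backward
            # kernel uses to assemble dx.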
            # master : broadcast global stats
            tensors = comm.broadcast_coalesced(
                (sum_dz, sum_dz_xhat), [mean.get_device()] + ctx.worker_ids)
            for ts, queue in zip(tensors[1:], ctx.worker_queues):
                queue.put(ts)
        else:
            # slave : send to master
            ctx.master_queue.put((sum_dz, sum_dz_xhat))
            # slave : get global stats
            sum_dz, sum_dz_xhat = ctx.worker_queue.get()
            ctx.worker_queue.task_done()

        # do batch norm backward
        dx, dweight, dbias = _backend.syncbn_backward(
            dz, x, weight, bias, mean, var, sum_dz, sum_dz_xhat,
            ctx.affine, ctx.eps)

        return dx, dweight, dbias, \
            None, None, None, None, None, None


batchnorm2d_sync = BatchNorm2dSyncFunc.apply

__all__ = ["batchnorm2d_sync"]
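
# Usage sketch (illustrative only; the wrapper module that creates the queues
# and replicated parameters is not part of this file). Based on _parse_extra
# above, the master replica is expected to pass:
#
#     extra = {
#         "is_master": True,
#         "master_queue": master_queue,    # e.g. queue.Queue(maxsize=n_workers)
#         "worker_queues": worker_queues,  # one queue per worker replica
#         "worker_ids": worker_ids,        # CUDA device ids of the workers
#     }
#
# while each worker replica passes:
#
#     extra = {
#         "is_master": False,
#         "master_queue": master_queue,
#         "worker_queue": worker_queues[i],
#     }
#
# and every replica then calls (arguments passed positionally):
#
#     z = batchnorm2d_sync(x, weight, bias, running_mean, running_var,
#                          extra, True, 0.1, 1e-05)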