This commit is contained in:
hzwer
2021-05-13 17:02:14 +08:00
parent c1a456dd72
commit 2607a26dd1

View File

@@ -195,13 +195,10 @@ class Model:
return pred
def inference(self, img0, img1, UHD=False):
print(img0.shape, img1.shape)
imgs = torch.cat((img0, img1), 1)
scale_list = [8, 4, 2]
flow, _ = self.flownet(imgs, scale_list)
print(flow.shape, UHD)
res = self.predict(imgs, flow, training=False, UHD=False)
print(img0.shape, img1.shape, res.shape)
return self.predict(imgs, flow, training=False, UHD=False)
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):