[to #42322933]修复nano模型初始化/增加文件copyright信息

修复nano模型初始化/增加文件copyright信息
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10247456
This commit is contained in:
leyuan.hjy
2022-09-26 15:52:03 +08:00
committed by Yingda Chen
parent 69497c25f5
commit bf0ae653e7
3 changed files with 9 additions and 2 deletions

View File

@@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import logging as logger
import os
@@ -48,6 +49,7 @@ class RealtimeDetector(TorchModel):
self.nmsthre = self.exp.nmsthre
self.test_size = self.exp.test_size
self.preproc = ValTransform(legacy=False)
self.label_mapping = self.config['labels']
def inference(self, img):
with torch.no_grad():
@@ -81,5 +83,8 @@ class RealtimeDetector(TorchModel):
bboxes = outputs[0][:, 0:4].cpu().numpy() / self.ratio
scores = outputs[0][:, 5].cpu().numpy()
labels = outputs[0][:, 6].cpu().int().numpy()
pred_label_names = []
for lab in labels:
pred_label_names.append(self.label_mapping[lab])
return bboxes, scores, labels
return bboxes, scores, pred_label_names

View File

@@ -42,5 +42,6 @@ class YoloXNanoExp(YoloXExp):
act=self.act,
depthwise=True)
self.model = YOLOX(backbone, head)
self.model.apply(init_yolo)
self.model.head.initialize_biases(1e-2)
return self.model

View File

@@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Any, Dict, List, Union