Merge remote-tracking branch 'origin/master' into ofa/finetune
@@ -1,6 +1,6 @@
 repos:
 - repo: https://gitlab.com/pycqa/flake8.git
-  rev: 3.8.3
+  rev: 4.0.0
   hooks:
   - id: flake8
     exclude: thirdparty/|examples/
@@ -1,6 +1,6 @@
 repos:
 - repo: /home/admin/pre-commit/flake8
-  rev: 3.8.3
+  rev: 4.0.0
   hooks:
   - id: flake8
     exclude: thirdparty/|examples/
@@ -390,11 +390,13 @@ class HubApi:
         return resp['Data']
 
     def list_oss_dataset_objects(self, dataset_name, namespace, max_limit,
-                                 is_recursive, is_filter_dir, revision,
-                                 cookies):
+                                 is_recursive, is_filter_dir, revision):
         url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \
               f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}'
-        cookies = requests.utils.dict_from_cookiejar(cookies)
+
+        cookies = ModelScopeConfig.get_cookies()
+        if cookies:
+            cookies = requests.utils.dict_from_cookiejar(cookies)
 
         resp = requests.get(url=url, cookies=cookies)
         resp = resp.json()
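Illustration (not part of the commit): after the change above, callers no longer pass cookies into list_oss_dataset_objects; the method loads the saved session itself via ModelScopeConfig.get_cookies(). A minimal, hedged sketch of a call site after the change; the dataset name and namespace below are placeholders:

# Hypothetical call site; 'some_namespace'/'some_dataset' are made-up values.
from modelscope.hub.api import HubApi

api = HubApi()
objects = api.list_oss_dataset_objects(
    dataset_name='some_dataset',
    namespace='some_namespace',
    max_limit=-1,
    is_recursive=True,
    is_filter_dir=True,
    revision='master')  # cookies are now resolved internally from the local config
for item in objects:
    print(item.get('Key'))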
@@ -9,7 +9,9 @@ class Models(object):
 
     Model name should only contain model info but not task info.
     """
+    # tinynas models
     tinynas_detection = 'tinynas-detection'
+    tinynas_damoyolo = 'tinynas-damoyolo'
 
     # vision models
     detection = 'detection'
@@ -7,10 +7,12 @@ from modelscope.utils.import_utils import LazyImportModule
 
 if TYPE_CHECKING:
     from .tinynas_detector import Tinynas_detector
+    from .tinynas_damoyolo import DamoYolo
 
 else:
     _import_structure = {
         'tinynas_detector': ['TinynasDetector'],
+        'tinynas_damoyolo': ['DamoYolo'],
     }
 
     import sys
@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 
+from modelscope.utils.file_utils import read_file
 from ..core.base_ops import Focus, SPPBottleneck, get_activation
 from ..core.repvgg_block import RepVggBlock
 
@@ -49,12 +50,16 @@ class ResConvK1KX(nn.Module):
                  kernel_size,
                  stride,
                  force_resproj=False,
-                 act='silu'):
+                 act='silu',
+                 reparam=False):
         super(ResConvK1KX, self).__init__()
         self.stride = stride
         self.conv1 = ConvKXBN(in_c, btn_c, 1, 1)
-        self.conv2 = RepVggBlock(
-            btn_c, out_c, kernel_size, stride, act='identity')
+        if not reparam:
+            self.conv2 = ConvKXBN(btn_c, out_c, 3, stride)
+        else:
+            self.conv2 = RepVggBlock(
+                btn_c, out_c, kernel_size, stride, act='identity')
 
         if act is None:
             self.activation_function = torch.relu
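Illustration (not part of the commit): the reparam flag above decides whether the 3x3 branch is a plain conv+BN (ConvKXBN) or a re-parameterizable RepVgg-style block. A stand-alone sketch of the same switch using generic PyTorch modules rather than the repo's ConvKXBN/RepVggBlock classes:

import torch
import torch.nn as nn

class ResBlockSketch(nn.Module):
    """Minimal illustration of the reparam switch; not the repo's ResConvK1KX."""

    def __init__(self, in_c, btn_c, out_c, stride=1, reparam=False):
        super().__init__()
        # 1x1 bottleneck conv (stands in for ConvKXBN(in_c, btn_c, 1, 1))
        self.conv1 = nn.Conv2d(in_c, btn_c, 1, 1)
        if not reparam:
            # plain 3x3 conv path (stands in for ConvKXBN(btn_c, out_c, 3, stride))
            self.conv2 = nn.Conv2d(btn_c, out_c, 3, stride, padding=1)
        else:
            # training-time block that could later be fused for inference
            # (stands in for RepVggBlock(btn_c, out_c, 3, stride, act='identity'))
            self.conv2 = nn.Sequential(
                nn.Conv2d(btn_c, out_c, 3, stride, padding=1),
                nn.BatchNorm2d(out_c))

    def forward(self, x):
        return torch.relu(self.conv2(self.conv1(x)))

x = torch.randn(1, 16, 32, 32)
print(ResBlockSketch(16, 8, 16, reparam=True)(x).shape)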
@@ -97,7 +102,8 @@ class SuperResConvK1KX(nn.Module):
                  stride,
                  num_blocks,
                  with_spp=False,
-                 act='silu'):
+                 act='silu',
+                 reparam=False):
         super(SuperResConvK1KX, self).__init__()
         if act is None:
             self.act = torch.relu
@@ -124,7 +130,8 @@ class SuperResConvK1KX(nn.Module):
                 this_kernel_size,
                 this_stride,
                 force_resproj,
-                act=act)
+                act=act,
+                reparam=reparam)
             self.block_list.append(the_block)
             if block_id == 0 and with_spp:
                 self.block_list.append(
@@ -248,7 +255,8 @@ class TinyNAS(nn.Module):
                  with_spp=False,
                  use_focus=False,
                  need_conv1=True,
-                 act='silu'):
+                 act='silu',
+                 reparam=False):
         super(TinyNAS, self).__init__()
         assert len(out_indices) == len(out_channels)
         self.out_indices = out_indices
@@ -281,7 +289,8 @@ class TinyNAS(nn.Module):
                     block_info['s'],
                     block_info['L'],
                     spp,
-                    act=act)
+                    act=act,
+                    reparam=reparam)
                 self.block_list.append(the_block)
             elif the_block_class == 'SuperResConvKXKX':
                 spp = with_spp if idx == len(structure_info) - 1 else False
@@ -325,8 +334,8 @@ class TinyNAS(nn.Module):
 def load_tinynas_net(backbone_cfg):
     # load masternet model to path
     import ast
 
-    struct_str = ''.join([x.strip() for x in backbone_cfg.net_structure_str])
+    net_structure_str = read_file(backbone_cfg.structure_file)
+    struct_str = ''.join([x.strip() for x in net_structure_str])
     struct_info = ast.literal_eval(struct_str)
     for layer in struct_info:
         if 'nbitsA' in layer:
@@ -342,6 +351,6 @@ def load_tinynas_net(backbone_cfg):
         use_focus=backbone_cfg.use_focus,
         act=backbone_cfg.act,
         need_conv1=backbone_cfg.need_conv1,
-    )
+        reparam=backbone_cfg.reparam)
 
     return model
@@ -30,7 +30,7 @@ class SingleStageDetector(TorchModel):
         """
         super().__init__(model_dir, *args, **kwargs)
 
-        config_path = osp.join(model_dir, 'airdet_s.py')
+        config_path = osp.join(model_dir, self.config_name)
         config = parse_config(config_path)
         self.cfg = config
         model_path = osp.join(model_dir, config.model.name)
@@ -41,6 +41,9 @@ class SingleStageDetector(TorchModel):
         self.conf_thre = config.model.head.nms_conf_thre
         self.nms_thre = config.model.head.nms_iou_thre
 
+        if self.cfg.model.backbone.name == 'TinyNAS':
+            self.cfg.model.backbone.structure_file = osp.join(
+                model_dir, self.cfg.model.backbone.structure_file)
         self.backbone = build_backbone(self.cfg.model.backbone)
         self.neck = build_neck(self.cfg.model.neck)
         self.head = build_head(self.cfg.model.head)
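Illustration (not part of the commit): the new branch above rewrites the backbone's relative structure_file into an absolute path under model_dir before build_backbone runs. A hedged sketch of just that resolution step; the file name and directory below are invented:

import os.path as osp

# Hypothetical values; only the 'name' and 'structure_file' keys appear in the diff above.
backbone_cfg = {'name': 'TinyNAS', 'structure_file': 'tinynas_structure.txt'}
model_dir = '/path/to/downloaded/model'

if backbone_cfg['name'] == 'TinyNAS':
    backbone_cfg['structure_file'] = osp.join(model_dir,
                                              backbone_cfg['structure_file'])
print(backbone_cfg['structure_file'])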
@@ -124,11 +124,13 @@ class GFocalHead_Tiny(nn.Module):
                  simOTA_iou_weight=3.0,
                  octbase=8,
                  simlqe=False,
+                 use_lqe=True,
                  **kwargs):
         self.simlqe = simlqe
         self.num_classes = num_classes
         self.in_channels = in_channels
         self.strides = strides
+        self.use_lqe = use_lqe
         self.feat_channels = feat_channels if isinstance(feat_channels, list) \
             else [feat_channels] * len(self.strides)
@@ -181,15 +183,20 @@ class GFocalHead_Tiny(nn.Module):
                     groups=self.conv_groups,
                     norm=self.norm,
                     act=self.act))
-        if not self.simlqe:
-            conf_vector = [nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)]
+        if self.use_lqe:
+            if not self.simlqe:
+                conf_vector = [
+                    nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)
+                ]
+            else:
+                conf_vector = [
+                    nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1)
+                ]
+            conf_vector += [self.relu]
+            conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()]
+            reg_conf = nn.Sequential(*conf_vector)
         else:
-            conf_vector = [
-                nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1)
-            ]
-        conf_vector += [self.relu]
-        conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()]
-        reg_conf = nn.Sequential(*conf_vector)
+            reg_conf = None
 
         return cls_convs, reg_convs, reg_conf
@@ -290,21 +297,27 @@ class GFocalHead_Tiny(nn.Module):
             N, C, H, W = bbox_pred.size()
             prob = F.softmax(
                 bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2)
-            if not self.simlqe:
-                prob_topk, _ = prob.topk(self.reg_topk, dim=2)
+            if self.use_lqe:
+                if not self.simlqe:
+                    prob_topk, _ = prob.topk(self.reg_topk, dim=2)
 
-                if self.add_mean:
-                    stat = torch.cat(
-                        [prob_topk, prob_topk.mean(dim=2, keepdim=True)], dim=2)
+                    if self.add_mean:
+                        stat = torch.cat(
+                            [prob_topk,
+                             prob_topk.mean(dim=2, keepdim=True)],
+                            dim=2)
+                    else:
+                        stat = prob_topk
+
+                    quality_score = reg_conf(
+                        stat.reshape(N, 4 * self.total_dim, H, W))
                 else:
-                    stat = prob_topk
+                    quality_score = reg_conf(
+                        bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W))
 
-                quality_score = reg_conf(stat.reshape(N, 4 * self.total_dim, H, W))
+                cls_score = gfl_cls(cls_feat).sigmoid() * quality_score
             else:
-                quality_score = reg_conf(
-                    bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W))
-
-            cls_score = gfl_cls(cls_feat).sigmoid() * quality_score
+                cls_score = gfl_cls(cls_feat).sigmoid()
 
             flatten_cls_score = cls_score.flatten(start_dim=2).transpose(1, 2)
             flatten_bbox_pred = bbox_pred.flatten(start_dim=2).transpose(1, 2)
@@ -14,7 +14,6 @@ class GiraffeNeckV2(nn.Module):
         self,
         depth=1.0,
         width=1.0,
-        in_features=[2, 3, 4],
         in_channels=[256, 512, 1024],
         out_channels=[256, 512, 1024],
         depthwise=False,
@@ -24,7 +23,6 @@ class GiraffeNeckV2(nn.Module):
         block_name='BasicBlock',
     ):
         super().__init__()
-        self.in_features = in_features
         self.in_channels = in_channels
         Conv = DWConv if depthwise else BaseConv
 
@@ -169,8 +167,7 @@ class GiraffeNeckV2(nn.Module):
         """
 
         # backbone
-        features = [out_features[f] for f in self.in_features]
-        [x2, x1, x0] = features
+        [x2, x1, x0] = out_features
 
         # node x3
         x13 = self.bu_conv13(x1)
modelscope/models/cv/tinynas_detection/tinynas_damoyolo.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from modelscope.metainfo import Models
+from modelscope.models.builder import MODELS
+from modelscope.utils.constant import Tasks
+from .detector import SingleStageDetector
+
+
+@MODELS.register_module(
+    Tasks.image_object_detection, module_name=Models.tinynas_damoyolo)
+class DamoYolo(SingleStageDetector):
+
+    def __init__(self, model_dir, *args, **kwargs):
+        self.config_name = 'damoyolo_s.py'
+        super(DamoYolo, self).__init__(model_dir, *args, **kwargs)
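Illustration (not part of the commit): the decorator above registers DamoYolo under Tasks.image_object_detection with module_name Models.tinynas_damoyolo, so the class is constructed automatically when a matching model is requested. A hedged usage sketch, reusing the model id that appears in the tests later in this diff:

# Sketch of consuming the registered model through the pipeline API;
# requires the corresponding model to be downloadable from the ModelScope hub.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

detector = pipeline(
    Tasks.image_object_detection,
    model='damo/cv_tinynas_object-detection_damoyolo')
result = detector('data/test/images/image_detection.jpg')
print(result)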
@@ -12,5 +12,5 @@ from .detector import SingleStageDetector
 class TinynasDetector(SingleStageDetector):
 
     def __init__(self, model_dir, *args, **kwargs):
-
+        self.config_name = 'airdet_s.py'
         super(TinynasDetector, self).__init__(model_dir, *args, **kwargs)
@@ -7,7 +7,7 @@ from typing import Any, Mapping, Optional, Sequence, Union
 from datasets.builder import DatasetBuilder
 
 from modelscope.hub.api import HubApi
-from modelscope.utils.constant import DEFAULT_DATASET_REVISION, DownloadParams
+from modelscope.utils.constant import DEFAULT_DATASET_REVISION
 from modelscope.utils.logger import get_logger
 from .dataset_builder import MsCsvDatasetBuilder, TaskSpecificDatasetBuilder
 
@@ -95,15 +95,13 @@ def list_dataset_objects(hub_api: HubApi, max_limit: int, is_recursive: bool,
         res (list): List of objects, i.e., ['train/images/001.png', 'train/images/002.png', 'val/images/001.png', ...]
     """
     res = []
-    cookies = hub_api.check_cookies_upload_data(use_cookies=True)
     objects = hub_api.list_oss_dataset_objects(
         dataset_name=dataset_name,
         namespace=namespace,
         max_limit=max_limit,
         is_recursive=is_recursive,
         is_filter_dir=True,
-        revision=version,
-        cookies=cookies)
+        revision=version)
 
     for item in objects:
         object_key = item.get('Key')
@@ -174,7 +172,7 @@ def get_dataset_files(subset_split_into: dict,
     modelscope_api = HubApi()
     objects = list_dataset_objects(
         hub_api=modelscope_api,
-        max_limit=DownloadParams.MAX_LIST_OBJECTS_NUM.value,
+        max_limit=-1,
        is_recursive=True,
         dataset_name=dataset_name,
         namespace=namespace,
@@ -12,6 +12,8 @@ from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import LoadImage
 from modelscope.utils.constant import Tasks
+from modelscope.utils.cv.image_utils import \
+    show_image_object_detection_auto_result
 from modelscope.utils.logger import get_logger
 
 logger = get_logger()
@@ -52,10 +54,18 @@ class TinynasDetectionPipeline(Pipeline):
 
         bboxes, scores, labels = self.model.postprocess(inputs['data'])
         if bboxes is None:
-            return None
-        outputs = {
-            OutputKeys.SCORES: scores,
-            OutputKeys.LABELS: labels,
-            OutputKeys.BOXES: bboxes
-        }
+            outputs = {
+                OutputKeys.SCORES: [],
+                OutputKeys.LABELS: [],
+                OutputKeys.BOXES: []
+            }
+        else:
+            outputs = {
+                OutputKeys.SCORES: scores,
+                OutputKeys.LABELS: labels,
+                OutputKeys.BOXES: bboxes
+            }
         return outputs
 
+    def show_result(self, img_path, result, save_path=None):
+        show_image_object_detection_auto_result(img_path, result, save_path)
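Illustration (not part of the commit): the postprocess change above returns empty lists instead of None when nothing is detected, so consumers can iterate the result without a special case. A small hedged sketch of such a consumer:

# Sketch only: OutputKeys.SCORES/LABELS/BOXES map to 'scores'/'labels'/'boxes'.
result = {'scores': [], 'labels': [], 'boxes': []}  # an empty detection result

for score, label, box in zip(result['scores'], result['labels'], result['boxes']):
    print(label, score, box)  # loops zero times when nothing was detected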
@@ -1,5 +1,10 @@
 import math
+import os
 import random
+import uuid
+from os.path import exists
+from tempfile import TemporaryDirectory
+from urllib.parse import urlparse
 
 import numpy as np
 import torch
@@ -9,6 +14,7 @@ import torchvision.transforms._transforms_video as transforms
 from decord import VideoReader
 from torchvision.transforms import Compose
 
+from modelscope.hub.file_download import http_get_file
 from modelscope.metainfo import Preprocessors
 from modelscope.utils.constant import Fields, ModeKeys
 from modelscope.utils.type_assert import type_assert
@@ -30,7 +36,22 @@ def ReadVideoData(cfg,
     Returns:
         data (Tensor): the normalized video clips for model inputs
     """
-    data = _decode_video(cfg, video_path, num_temporal_views_override)
+    url_parsed = urlparse(video_path)
+    if url_parsed.scheme in ('file', '') and exists(
+            url_parsed.path):  # Possibly a local file
+        data = _decode_video(cfg, video_path, num_temporal_views_override)
+    else:
+        with TemporaryDirectory() as temporary_cache_dir:
+            random_str = uuid.uuid4().hex
+            http_get_file(
+                url=video_path,
+                local_dir=temporary_cache_dir,
+                file_name=random_str,
+                cookies=None)
+            temp_file_path = os.path.join(temporary_cache_dir, random_str)
+            data = _decode_video(cfg, temp_file_path,
+                                 num_temporal_views_override)
+
     if num_spatial_crops_override is not None:
         num_spatial_crops = num_spatial_crops_override
         transform = kinetics400_tranform(cfg, num_spatial_crops_override)
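Illustration (not part of the commit): the new branch above distinguishes local paths from URLs with urlparse and downloads remote videos to a temporary file before decoding. The same local-vs-remote decision in isolation, as a hedged stand-alone sketch (urlretrieve is used here as a stand-in for modelscope's http_get_file):

import os
import uuid
from os.path import exists
from tempfile import TemporaryDirectory
from urllib.parse import urlparse
from urllib.request import urlretrieve  # stand-in for modelscope's http_get_file


def resolve_video(video_path):
    """Return a local file path for video_path, downloading it if it is a URL."""
    parsed = urlparse(video_path)
    if parsed.scheme in ('file', '') and exists(parsed.path):
        return video_path, None  # already local, nothing to clean up
    tmp_dir = TemporaryDirectory()
    local_path = os.path.join(tmp_dir.name, uuid.uuid4().hex)
    urlretrieve(video_path, local_path)
    return local_path, tmp_dir  # caller keeps tmp_dir alive while decoding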
@@ -231,13 +231,6 @@ class DownloadMode(enum.Enum):
     FORCE_REDOWNLOAD = 'force_redownload'
 
 
-class DownloadParams(enum.Enum):
-    """
-    Parameters for downloading dataset.
-    """
-    MAX_LIST_OBJECTS_NUM = 50000
-
-
 class DatasetFormations(enum.Enum):
     """ How a dataset is organized and interpreted
     """
@@ -61,8 +61,8 @@ def device_placement(framework, device_name='gpu:0'):
     if framework == Frameworks.tf:
         import tensorflow as tf
         if device_type == Devices.gpu and not tf.test.is_gpu_available():
-            logger.warning(
-                'tensorflow cuda is not available, using cpu instead.')
+            logger.debug(
+                'tensorflow: cuda is not available, using cpu instead.')
             device_type = Devices.cpu
         if device_type == Devices.cpu:
             with tf.device('/CPU:0'):
@@ -78,7 +78,8 @@ def device_placement(framework, device_name='gpu:0'):
             if torch.cuda.is_available():
                 torch.cuda.set_device(f'cuda:{device_id}')
             else:
-                logger.warning('cuda is not available, using cpu instead.')
+                logger.debug(
+                    'pytorch: cuda is not available, using cpu instead.')
             yield
         else:
             yield
@@ -96,9 +97,7 @@ def create_device(device_name):
     if device_type == Devices.gpu:
         use_cuda = True
         if not torch.cuda.is_available():
-            logger.warning(
-                'cuda is not available, create gpu device failed, using cpu instead.'
-            )
+            logger.info('cuda is not available, using cpu instead.')
             use_cuda = False
 
     if use_cuda:
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 import inspect
+import os
 from pathlib import Path
 
 
@@ -35,3 +36,10 @@ def get_default_cache_dir():
     """
     default_cache_dir = Path.home().joinpath('.cache', 'modelscope')
     return default_cache_dir
+
+
+def read_file(path):
+
+    with open(path, 'r') as f:
+        text = f.read()
+    return text
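Illustration (not part of the commit): read_file above is what load_tinynas_net now uses to pull the TinyNAS structure string from structure_file before ast.literal_eval. A minimal hedged sketch of that round trip; the structure literal here is a made-up placeholder, not a real TinyNAS spec:

import ast
import tempfile


def read_file(path):
    with open(path, 'r') as f:
        return f.read()


# Hypothetical structure literal; the real file holds the TinyNAS block list.
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
    f.write("[{'class': 'ConvKXBNRELU', 'k': 3, 's': 2}]")
    structure_file = f.name

struct_str = ''.join(x.strip() for x in read_file(structure_file))
struct_info = ast.literal_eval(struct_str)
print(struct_info[0]['class'])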
@@ -176,7 +176,7 @@ def build_from_cfg(cfg,
         raise TypeError('default_args must be a dict or None, '
                         f'but got {type(default_args)}')
 
-    # dynamic load installation reqruiements for this module
+    # dynamic load installation requirements for this module
     from modelscope.utils.import_utils import LazyImportModule
     sig = (registry.name.upper(), group_key, cfg['type'])
     LazyImportModule.import_module(sig)
@@ -193,8 +193,11 @@ def build_from_cfg(cfg,
     if isinstance(obj_type, str):
         obj_cls = registry.get(obj_type, group_key=group_key)
         if obj_cls is None:
-            raise KeyError(f'{obj_type} is not in the {registry.name}'
-                           f' registry group {group_key}')
+            raise KeyError(
+                f'{obj_type} is not in the {registry.name}'
+                f' registry group {group_key}. Please make'
+                f' sure the correct version of ModelScope library is used.'
+            )
         obj_cls.group_key = group_key
     elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
         obj_cls = obj_type
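Illustration (not part of the commit): the expanded KeyError above fires when a module name is not registered under the requested group. A toy sketch of the same lookup-and-raise pattern, not the modelscope Registry implementation:

# Toy registry sketch illustrating the lookup that produces the KeyError above.
registry = {'image-object-detection': {'tinynas-damoyolo': object}}


def get_cls(group_key, obj_type):
    obj_cls = registry.get(group_key, {}).get(obj_type)
    if obj_cls is None:
        raise KeyError(
            f'{obj_type} is not in the registry group {group_key}. Please make'
            f' sure the correct version of ModelScope library is used.')
    return obj_cls


print(get_cls('image-object-detection', 'tinynas-damoyolo'))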
@@ -4,22 +4,45 @@ import unittest
 
 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
 from modelscope.utils.test_utils import test_level
 
 
-class TinynasObjectDetectionTest(unittest.TestCase):
+class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
 
+    def setUp(self) -> None:
+        self.task = Tasks.image_object_detection
+        self.model_id = 'damo/cv_tinynas_object-detection_damoyolo'
+
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run(self):
+    def test_run_airdet(self):
         tinynas_object_detection = pipeline(
             Tasks.image_object_detection, model='damo/cv_tinynas_detection')
         result = tinynas_object_detection(
             'data/test/images/image_detection.jpg')
         print(result)
 
+    @unittest.skip('will be enabled after damoyolo officially released')
+    def test_run_damoyolo(self):
+        tinynas_object_detection = pipeline(
+            Tasks.image_object_detection,
+            model='damo/cv_tinynas_object-detection_damoyolo')
+        result = tinynas_object_detection(
+            'data/test/images/image_detection.jpg')
+        print(result)
+
+    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
+    def test_demo_compatibility(self):
+        self.test_demo()
+        self.compatibility_check()
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_image_object_detection_auto_pipeline(self):
+        test_image = 'data/test/images/image_detection.jpg'
+        tinynas_object_detection = pipeline(
+            Tasks.image_object_detection, model='damo/cv_tinynas_detection')
+        result = tinynas_object_detection(test_image)
+        tinynas_object_detection.show_result(test_image, result,
+                                             'demo_ret.jpg')
+
 
 if __name__ == '__main__':