revise format

This commit is contained in:
rujiao.lrj
2022-11-17 15:24:17 +08:00
parent 80fba922a8
commit 4cf627ec56
9 changed files with 1629 additions and 1382 deletions

View File

@@ -2,140 +2,141 @@
class Models(object):
""" Names for different models.
"""Names for different models.
Holds the standard model name to use for identifying different model.
This should be used to register models.
Model name should only contain model info but not task info.
"""
# tinynas models
tinynas_detection = 'tinynas-detection'
tinynas_damoyolo = 'tinynas-damoyolo'
tinynas_detection = "tinynas-detection"
tinynas_damoyolo = "tinynas-damoyolo"
# vision models
detection = 'detection'
realtime_object_detection = 'realtime-object-detection'
realtime_video_object_detection = 'realtime-video-object-detection'
scrfd = 'scrfd'
classification_model = 'ClassificationModel'
nafnet = 'nafnet'
csrnet = 'csrnet'
cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin'
gpen = 'gpen'
product_retrieval_embedding = 'product-retrieval-embedding'
body_2d_keypoints = 'body-2d-keypoints'
body_3d_keypoints = 'body-3d-keypoints'
crowd_counting = 'HRNetCrowdCounting'
face_2d_keypoints = 'face-2d-keypoints'
panoptic_segmentation = 'swinL-panoptic-segmentation'
image_reid_person = 'passvitb'
image_inpainting = 'FFTInpainting'
video_summarization = 'pgl-video-summarization'
swinL_semantic_segmentation = 'swinL-semantic-segmentation'
vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation'
text_driven_segmentation = 'text-driven-segmentation'
resnet50_bert = 'resnet50-bert'
referring_video_object_segmentation = 'swinT-referring-video-object-segmentation'
fer = 'fer'
retinaface = 'retinaface'
shop_segmentation = 'shop-segmentation'
mogface = 'mogface'
mtcnn = 'mtcnn'
ulfd = 'ulfd'
video_inpainting = 'video-inpainting'
human_wholebody_keypoint = 'human-wholebody-keypoint'
hand_static = 'hand-static'
face_human_hand_detection = 'face-human-hand-detection'
face_emotion = 'face-emotion'
product_segmentation = 'product-segmentation'
image_body_reshaping = 'image-body-reshaping'
detection = "detection"
realtime_object_detection = "realtime-object-detection"
realtime_video_object_detection = "realtime-video-object-detection"
scrfd = "scrfd"
classification_model = "ClassificationModel"
nafnet = "nafnet"
csrnet = "csrnet"
cascade_mask_rcnn_swin = "cascade_mask_rcnn_swin"
gpen = "gpen"
product_retrieval_embedding = "product-retrieval-embedding"
body_2d_keypoints = "body-2d-keypoints"
body_3d_keypoints = "body-3d-keypoints"
crowd_counting = "HRNetCrowdCounting"
face_2d_keypoints = "face-2d-keypoints"
panoptic_segmentation = "swinL-panoptic-segmentation"
image_reid_person = "passvitb"
image_inpainting = "FFTInpainting"
video_summarization = "pgl-video-summarization"
swinL_semantic_segmentation = "swinL-semantic-segmentation"
vitadapter_semantic_segmentation = "vitadapter-semantic-segmentation"
text_driven_segmentation = "text-driven-segmentation"
resnet50_bert = "resnet50-bert"
referring_video_object_segmentation = "swinT-referring-video-object-segmentation"
fer = "fer"
retinaface = "retinaface"
shop_segmentation = "shop-segmentation"
mogface = "mogface"
mtcnn = "mtcnn"
ulfd = "ulfd"
video_inpainting = "video-inpainting"
human_wholebody_keypoint = "human-wholebody-keypoint"
hand_static = "hand-static"
face_human_hand_detection = "face-human-hand-detection"
face_emotion = "face-emotion"
product_segmentation = "product-segmentation"
image_body_reshaping = "image-body-reshaping"
# EasyCV models
yolox = 'YOLOX'
segformer = 'Segformer'
hand_2d_keypoints = 'HRNet-Hand2D-Keypoints'
image_object_detection_auto = 'image-object-detection-auto'
yolox = "YOLOX"
segformer = "Segformer"
hand_2d_keypoints = "HRNet-Hand2D-Keypoints"
image_object_detection_auto = "image-object-detection-auto"
# nlp models
bert = 'bert'
palm = 'palm-v2'
structbert = 'structbert'
deberta_v2 = 'deberta_v2'
veco = 'veco'
translation = 'csanmt-translation'
space_dst = 'space-dst'
space_intent = 'space-intent'
space_modeling = 'space-modeling'
space_T_en = 'space-T-en'
space_T_cn = 'space-T-cn'
tcrf = 'transformer-crf'
tcrf_wseg = 'transformer-crf-for-word-segmentation'
transformer_softmax = 'transformer-softmax'
lcrf = 'lstm-crf'
lcrf_wseg = 'lstm-crf-for-word-segmentation'
gcnncrf = 'gcnn-crf'
bart = 'bart'
gpt3 = 'gpt3'
gpt_neo = 'gpt-neo'
plug = 'plug'
bert_for_ds = 'bert-for-document-segmentation'
ponet = 'ponet'
T5 = 'T5'
mglm = 'mglm'
bloom = 'bloom'
bert = "bert"
palm = "palm-v2"
structbert = "structbert"
deberta_v2 = "deberta_v2"
veco = "veco"
translation = "csanmt-translation"
space_dst = "space-dst"
space_intent = "space-intent"
space_modeling = "space-modeling"
space_T_en = "space-T-en"
space_T_cn = "space-T-cn"
tcrf = "transformer-crf"
tcrf_wseg = "transformer-crf-for-word-segmentation"
transformer_softmax = "transformer-softmax"
lcrf = "lstm-crf"
lcrf_wseg = "lstm-crf-for-word-segmentation"
gcnncrf = "gcnn-crf"
bart = "bart"
gpt3 = "gpt3"
gpt_neo = "gpt-neo"
plug = "plug"
bert_for_ds = "bert-for-document-segmentation"
ponet = "ponet"
T5 = "T5"
mglm = "mglm"
bloom = "bloom"
# audio models
sambert_hifigan = 'sambert-hifigan'
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
kws_kwsbp = 'kws-kwsbp'
generic_asr = 'generic-asr'
sambert_hifigan = "sambert-hifigan"
speech_frcrn_ans_cirm_16k = "speech_frcrn_ans_cirm_16k"
speech_dfsmn_kws_char_farfield = "speech_dfsmn_kws_char_farfield"
kws_kwsbp = "kws-kwsbp"
generic_asr = "generic-asr"
# multi-modal models
ofa = 'ofa'
clip = 'clip-multi-modal-embedding'
gemm = 'gemm-generative-multi-modal'
mplug = 'mplug'
diffusion = 'diffusion-text-to-image-synthesis'
multi_stage_diffusion = 'multi-stage-diffusion-text-to-image-synthesis'
team = 'team-multi-modal-similarity'
video_clip = 'video-clip-multi-modal-embedding'
ofa = "ofa"
clip = "clip-multi-modal-embedding"
gemm = "gemm-generative-multi-modal"
mplug = "mplug"
diffusion = "diffusion-text-to-image-synthesis"
multi_stage_diffusion = "multi-stage-diffusion-text-to-image-synthesis"
team = "team-multi-modal-similarity"
video_clip = "video-clip-multi-modal-embedding"
# science models
unifold = 'unifold'
unifold_symmetry = 'unifold-symmetry'
unifold = "unifold"
unifold_symmetry = "unifold-symmetry"
class TaskModels(object):
# nlp task
text_classification = 'text-classification'
token_classification = 'token-classification'
information_extraction = 'information-extraction'
fill_mask = 'fill-mask'
feature_extraction = 'feature-extraction'
text_generation = 'text-generation'
text_classification = "text-classification"
token_classification = "token-classification"
information_extraction = "information-extraction"
fill_mask = "fill-mask"
feature_extraction = "feature-extraction"
text_generation = "text-generation"
class Heads(object):
# nlp heads
# text cls
text_classification = 'text-classification'
text_classification = "text-classification"
# fill mask
fill_mask = 'fill-mask'
bert_mlm = 'bert-mlm'
roberta_mlm = 'roberta-mlm'
fill_mask = "fill-mask"
bert_mlm = "bert-mlm"
roberta_mlm = "roberta-mlm"
# token cls
token_classification = 'token-classification'
token_classification = "token-classification"
# extraction
information_extraction = 'information-extraction'
information_extraction = "information-extraction"
# text gen
text_generation = 'text-generation'
text_generation = "text-generation"
class Pipelines(object):
""" Names for different pipelines.
"""Names for different pipelines.
Holds the standard pipline name to use for identifying different pipeline.
This should be used to register pipelines.
@@ -144,148 +145,151 @@ class Pipelines(object):
should use task name for this pipeline.
For pipeline which suuport only one model, we should use ${Model}-${Task} as its name.
"""
# vision tasks
portrait_matting = 'unet-image-matting'
image_denoise = 'nafnet-image-denoise'
person_image_cartoon = 'unet-person-image-cartoon'
ocr_detection = 'resnet18-ocr-detection'
table_recognition = 'dla34-table-recognition'
action_recognition = 'TAdaConv_action-recognition'
animal_recognition = 'resnet101-animal-recognition'
general_recognition = 'resnet101-general-recognition'
cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
hicossl_video_embedding = 'hicossl-s3dg-video_embedding'
body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image'
body_3d_keypoints = 'canonical_body-3d-keypoints_video'
hand_2d_keypoints = 'hrnetv2w18_hand-2d-keypoints_image'
human_detection = 'resnet18-human-detection'
object_detection = 'vit-object-detection'
easycv_detection = 'easycv-detection'
easycv_segmentation = 'easycv-segmentation'
face_2d_keypoints = 'mobilenet_face-2d-keypoints_alignment'
salient_detection = 'u2net-salient-detection'
image_classification = 'image-classification'
face_detection = 'resnet-face-detection-scrfd10gkps'
card_detection = 'resnet-card-detection-scrfd34gkps'
ulfd_face_detection = 'manual-face-detection-ulfd'
facial_expression_recognition = 'vgg19-facial-expression-recognition-fer'
retina_face_detection = 'resnet50-face-detection-retinaface'
mog_face_detection = 'resnet101-face-detection-cvpr22papermogface'
mtcnn_face_detection = 'manual-face-detection-mtcnn'
live_category = 'live-category'
general_image_classification = 'vit-base_image-classification_ImageNet-labels'
daily_image_classification = 'vit-base_image-classification_Dailylife-labels'
image_color_enhance = 'csrnet-image-color-enhance'
virtual_try_on = 'virtual-try-on'
image_colorization = 'unet-image-colorization'
image_style_transfer = 'AAMS-style-transfer'
image_super_resolution = 'rrdb-image-super-resolution'
face_image_generation = 'gan-face-image-generation'
product_retrieval_embedding = 'resnet50-product-retrieval-embedding'
realtime_object_detection = 'cspnet_realtime-object-detection_yolox'
realtime_video_object_detection = 'cspnet_realtime-video-object-detection_streamyolo'
face_recognition = 'ir101-face-recognition-cfglint'
image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
image2image_translation = 'image-to-image-translation'
live_category = 'live-category'
video_category = 'video-category'
ocr_recognition = 'convnextTiny-ocr-recognition'
image_portrait_enhancement = 'gpen-image-portrait-enhancement'
image_to_image_generation = 'image-to-image-generation'
image_object_detection_auto = 'yolox_image-object-detection-auto'
skin_retouching = 'unet-skin-retouching'
tinynas_classification = 'tinynas-classification'
tinynas_detection = 'tinynas-detection'
crowd_counting = 'hrnet-crowd-counting'
action_detection = 'ResNetC3D-action-detection'
video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking'
image_panoptic_segmentation = 'image-panoptic-segmentation'
video_summarization = 'googlenet_pgl_video_summarization'
image_semantic_segmentation = 'image-semantic-segmentation'
image_reid_person = 'passvitb-image-reid-person'
image_inpainting = 'fft-inpainting'
text_driven_segmentation = 'text-driven-segmentation'
movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation'
shop_segmentation = 'shop-segmentation'
video_inpainting = 'video-inpainting'
human_wholebody_keypoint = 'hrnetw48_human-wholebody-keypoint_image'
pst_action_recognition = 'patchshift-action-recognition'
hand_static = 'hand-static'
face_human_hand_detection = 'face-human-hand-detection'
face_emotion = 'face-emotion'
product_segmentation = 'product-segmentation'
image_body_reshaping = 'flow-based-body-reshaping'
referring_video_object_segmentation = 'referring-video-object-segmentation'
portrait_matting = "unet-image-matting"
image_denoise = "nafnet-image-denoise"
person_image_cartoon = "unet-person-image-cartoon"
ocr_detection = "resnet18-ocr-detection"
table_recognition = "dla34-table-recognition"
action_recognition = "TAdaConv_action-recognition"
animal_recognition = "resnet101-animal-recognition"
general_recognition = "resnet101-general-recognition"
cmdssl_video_embedding = "cmdssl-r2p1d_video_embedding"
hicossl_video_embedding = "hicossl-s3dg-video_embedding"
body_2d_keypoints = "hrnetv2w32_body-2d-keypoints_image"
body_3d_keypoints = "canonical_body-3d-keypoints_video"
hand_2d_keypoints = "hrnetv2w18_hand-2d-keypoints_image"
human_detection = "resnet18-human-detection"
object_detection = "vit-object-detection"
easycv_detection = "easycv-detection"
easycv_segmentation = "easycv-segmentation"
face_2d_keypoints = "mobilenet_face-2d-keypoints_alignment"
salient_detection = "u2net-salient-detection"
image_classification = "image-classification"
face_detection = "resnet-face-detection-scrfd10gkps"
card_detection = "resnet-card-detection-scrfd34gkps"
ulfd_face_detection = "manual-face-detection-ulfd"
facial_expression_recognition = "vgg19-facial-expression-recognition-fer"
retina_face_detection = "resnet50-face-detection-retinaface"
mog_face_detection = "resnet101-face-detection-cvpr22papermogface"
mtcnn_face_detection = "manual-face-detection-mtcnn"
live_category = "live-category"
general_image_classification = "vit-base_image-classification_ImageNet-labels"
daily_image_classification = "vit-base_image-classification_Dailylife-labels"
image_color_enhance = "csrnet-image-color-enhance"
virtual_try_on = "virtual-try-on"
image_colorization = "unet-image-colorization"
image_style_transfer = "AAMS-style-transfer"
image_super_resolution = "rrdb-image-super-resolution"
face_image_generation = "gan-face-image-generation"
product_retrieval_embedding = "resnet50-product-retrieval-embedding"
realtime_object_detection = "cspnet_realtime-object-detection_yolox"
realtime_video_object_detection = (
"cspnet_realtime-video-object-detection_streamyolo"
)
face_recognition = "ir101-face-recognition-cfglint"
image_instance_segmentation = "cascade-mask-rcnn-swin-image-instance-segmentation"
image2image_translation = "image-to-image-translation"
live_category = "live-category"
video_category = "video-category"
ocr_recognition = "convnextTiny-ocr-recognition"
image_portrait_enhancement = "gpen-image-portrait-enhancement"
image_to_image_generation = "image-to-image-generation"
image_object_detection_auto = "yolox_image-object-detection-auto"
skin_retouching = "unet-skin-retouching"
tinynas_classification = "tinynas-classification"
tinynas_detection = "tinynas-detection"
crowd_counting = "hrnet-crowd-counting"
action_detection = "ResNetC3D-action-detection"
video_single_object_tracking = "ostrack-vitb-video-single-object-tracking"
image_panoptic_segmentation = "image-panoptic-segmentation"
video_summarization = "googlenet_pgl_video_summarization"
image_semantic_segmentation = "image-semantic-segmentation"
image_reid_person = "passvitb-image-reid-person"
image_inpainting = "fft-inpainting"
text_driven_segmentation = "text-driven-segmentation"
movie_scene_segmentation = "resnet50-bert-movie-scene-segmentation"
shop_segmentation = "shop-segmentation"
video_inpainting = "video-inpainting"
human_wholebody_keypoint = "hrnetw48_human-wholebody-keypoint_image"
pst_action_recognition = "patchshift-action-recognition"
hand_static = "hand-static"
face_human_hand_detection = "face-human-hand-detection"
face_emotion = "face-emotion"
product_segmentation = "product-segmentation"
image_body_reshaping = "flow-based-body-reshaping"
referring_video_object_segmentation = "referring-video-object-segmentation"
# nlp tasks
automatic_post_editing = 'automatic-post-editing'
translation_quality_estimation = 'translation-quality-estimation'
domain_classification = 'domain-classification'
sentence_similarity = 'sentence-similarity'
word_segmentation = 'word-segmentation'
multilingual_word_segmentation = 'multilingual-word-segmentation'
word_segmentation_thai = 'word-segmentation-thai'
part_of_speech = 'part-of-speech'
named_entity_recognition = 'named-entity-recognition'
named_entity_recognition_thai = 'named-entity-recognition-thai'
named_entity_recognition_viet = 'named-entity-recognition-viet'
text_generation = 'text-generation'
text2text_generation = 'text2text-generation'
sentiment_analysis = 'sentiment-analysis'
sentiment_classification = 'sentiment-classification'
text_classification = 'text-classification'
fill_mask = 'fill-mask'
fill_mask_ponet = 'fill-mask-ponet'
csanmt_translation = 'csanmt-translation'
nli = 'nli'
dialog_intent_prediction = 'dialog-intent-prediction'
dialog_modeling = 'dialog-modeling'
dialog_state_tracking = 'dialog-state-tracking'
zero_shot_classification = 'zero-shot-classification'
text_error_correction = 'text-error-correction'
plug_generation = 'plug-generation'
gpt3_generation = 'gpt3-generation'
faq_question_answering = 'faq-question-answering'
conversational_text_to_sql = 'conversational-text-to-sql'
table_question_answering_pipeline = 'table-question-answering-pipeline'
sentence_embedding = 'sentence-embedding'
text_ranking = 'text-ranking'
relation_extraction = 'relation-extraction'
document_segmentation = 'document-segmentation'
feature_extraction = 'feature-extraction'
mglm_text_summarization = 'mglm-text-summarization'
translation_en_to_de = 'translation_en_to_de' # keep it underscore
translation_en_to_ro = 'translation_en_to_ro' # keep it underscore
translation_en_to_fr = 'translation_en_to_fr' # keep it underscore
token_classification = 'token-classification'
automatic_post_editing = "automatic-post-editing"
translation_quality_estimation = "translation-quality-estimation"
domain_classification = "domain-classification"
sentence_similarity = "sentence-similarity"
word_segmentation = "word-segmentation"
multilingual_word_segmentation = "multilingual-word-segmentation"
word_segmentation_thai = "word-segmentation-thai"
part_of_speech = "part-of-speech"
named_entity_recognition = "named-entity-recognition"
named_entity_recognition_thai = "named-entity-recognition-thai"
named_entity_recognition_viet = "named-entity-recognition-viet"
text_generation = "text-generation"
text2text_generation = "text2text-generation"
sentiment_analysis = "sentiment-analysis"
sentiment_classification = "sentiment-classification"
text_classification = "text-classification"
fill_mask = "fill-mask"
fill_mask_ponet = "fill-mask-ponet"
csanmt_translation = "csanmt-translation"
nli = "nli"
dialog_intent_prediction = "dialog-intent-prediction"
dialog_modeling = "dialog-modeling"
dialog_state_tracking = "dialog-state-tracking"
zero_shot_classification = "zero-shot-classification"
text_error_correction = "text-error-correction"
plug_generation = "plug-generation"
gpt3_generation = "gpt3-generation"
faq_question_answering = "faq-question-answering"
conversational_text_to_sql = "conversational-text-to-sql"
table_question_answering_pipeline = "table-question-answering-pipeline"
sentence_embedding = "sentence-embedding"
text_ranking = "text-ranking"
relation_extraction = "relation-extraction"
document_segmentation = "document-segmentation"
feature_extraction = "feature-extraction"
mglm_text_summarization = "mglm-text-summarization"
translation_en_to_de = "translation_en_to_de" # keep it underscore
translation_en_to_ro = "translation_en_to_ro" # keep it underscore
translation_en_to_fr = "translation_en_to_fr" # keep it underscore
token_classification = "token-classification"
# audio tasks
sambert_hifigan_tts = 'sambert-hifigan-tts'
speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
kws_kwsbp = 'kws-kwsbp'
asr_inference = 'asr-inference'
sambert_hifigan_tts = "sambert-hifigan-tts"
speech_dfsmn_aec_psm_16k = "speech-dfsmn-aec-psm-16k"
speech_frcrn_ans_cirm_16k = "speech_frcrn_ans_cirm_16k"
speech_dfsmn_kws_char_farfield = "speech_dfsmn_kws_char_farfield"
kws_kwsbp = "kws-kwsbp"
asr_inference = "asr-inference"
# multi-modal tasks
image_captioning = 'image-captioning'
multi_modal_embedding = 'multi-modal-embedding'
generative_multi_modal_embedding = 'generative-multi-modal-embedding'
visual_question_answering = 'visual-question-answering'
visual_grounding = 'visual-grounding'
visual_entailment = 'visual-entailment'
multi_modal_similarity = 'multi-modal-similarity'
text_to_image_synthesis = 'text-to-image-synthesis'
video_multi_modal_embedding = 'video-multi-modal-embedding'
image_text_retrieval = 'image-text-retrieval'
ofa_ocr_recognition = 'ofa-ocr-recognition'
image_captioning = "image-captioning"
multi_modal_embedding = "multi-modal-embedding"
generative_multi_modal_embedding = "generative-multi-modal-embedding"
visual_question_answering = "visual-question-answering"
visual_grounding = "visual-grounding"
visual_entailment = "visual-entailment"
multi_modal_similarity = "multi-modal-similarity"
text_to_image_synthesis = "text-to-image-synthesis"
video_multi_modal_embedding = "video-multi-modal-embedding"
image_text_retrieval = "image-text-retrieval"
ofa_ocr_recognition = "ofa-ocr-recognition"
# science tasks
protein_structure = 'unifold-protein-structure'
protein_structure = "unifold-protein-structure"
class Trainers(object):
""" Names for different trainer.
"""Names for different trainer.
Holds the standard trainer name to use for identifying different trainer.
This should be used to register trainers.
@@ -294,41 +298,41 @@ class Trainers(object):
For a model specific Trainer, you can use ${ModelName}-${Task}-trainer.
"""
default = 'trainer'
easycv = 'easycv'
default = "trainer"
easycv = "easycv"
# multi-modal trainers
clip_multi_modal_embedding = 'clip-multi-modal-embedding'
ofa = 'ofa'
mplug = 'mplug'
clip_multi_modal_embedding = "clip-multi-modal-embedding"
ofa = "ofa"
mplug = "mplug"
# cv trainers
image_instance_segmentation = 'image-instance-segmentation'
image_portrait_enhancement = 'image-portrait-enhancement'
video_summarization = 'video-summarization'
movie_scene_segmentation = 'movie-scene-segmentation'
face_detection_scrfd = 'face-detection-scrfd'
card_detection_scrfd = 'card-detection-scrfd'
image_inpainting = 'image-inpainting'
referring_video_object_segmentation = 'referring-video-object-segmentation'
image_classification_team = 'image-classification-team'
image_instance_segmentation = "image-instance-segmentation"
image_portrait_enhancement = "image-portrait-enhancement"
video_summarization = "video-summarization"
movie_scene_segmentation = "movie-scene-segmentation"
face_detection_scrfd = "face-detection-scrfd"
card_detection_scrfd = "card-detection-scrfd"
image_inpainting = "image-inpainting"
referring_video_object_segmentation = "referring-video-object-segmentation"
image_classification_team = "image-classification-team"
# nlp trainers
bert_sentiment_analysis = 'bert-sentiment-analysis'
dialog_modeling_trainer = 'dialog-modeling-trainer'
dialog_intent_trainer = 'dialog-intent-trainer'
nlp_base_trainer = 'nlp-base-trainer'
nlp_veco_trainer = 'nlp-veco-trainer'
nlp_text_ranking_trainer = 'nlp-text-ranking-trainer'
text_generation_trainer = 'text-generation-trainer'
bert_sentiment_analysis = "bert-sentiment-analysis"
dialog_modeling_trainer = "dialog-modeling-trainer"
dialog_intent_trainer = "dialog-intent-trainer"
nlp_base_trainer = "nlp-base-trainer"
nlp_veco_trainer = "nlp-veco-trainer"
nlp_text_ranking_trainer = "nlp-text-ranking-trainer"
text_generation_trainer = "text-generation-trainer"
# audio trainers
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
speech_frcrn_ans_cirm_16k = "speech_frcrn_ans_cirm_16k"
speech_dfsmn_kws_char_farfield = "speech_dfsmn_kws_char_farfield"
class Preprocessors(object):
""" Names for different preprocessor.
"""Names for different preprocessor.
Holds the standard preprocessor name to use for identifying different preprocessor.
This should be used to register preprocessors.
@@ -339,168 +343,171 @@ class Preprocessors(object):
"""
# cv preprocessor
load_image = 'load-image'
image_denoie_preprocessor = 'image-denoise-preprocessor'
image_color_enhance_preprocessor = 'image-color-enhance-preprocessor'
image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor'
image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor'
video_summarization_preprocessor = 'video-summarization-preprocessor'
movie_scene_segmentation_preprocessor = 'movie-scene-segmentation-preprocessor'
load_image = "load-image"
image_denoie_preprocessor = "image-denoise-preprocessor"
image_color_enhance_preprocessor = "image-color-enhance-preprocessor"
image_instance_segmentation_preprocessor = (
"image-instance-segmentation-preprocessor"
)
image_portrait_enhancement_preprocessor = "image-portrait-enhancement-preprocessor"
video_summarization_preprocessor = "video-summarization-preprocessor"
movie_scene_segmentation_preprocessor = "movie-scene-segmentation-preprocessor"
# nlp preprocessor
sen_sim_tokenizer = 'sen-sim-tokenizer'
cross_encoder_tokenizer = 'cross-encoder-tokenizer'
bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
text_gen_tokenizer = 'text-gen-tokenizer'
text2text_gen_preprocessor = 'text2text-gen-preprocessor'
text_gen_jieba_tokenizer = 'text-gen-jieba-tokenizer'
text2text_translate_preprocessor = 'text2text-translate-preprocessor'
token_cls_tokenizer = 'token-cls-tokenizer'
ner_tokenizer = 'ner-tokenizer'
thai_ner_tokenizer = 'thai-ner-tokenizer'
viet_ner_tokenizer = 'viet-ner-tokenizer'
nli_tokenizer = 'nli-tokenizer'
sen_cls_tokenizer = 'sen-cls-tokenizer'
dialog_intent_preprocessor = 'dialog-intent-preprocessor'
dialog_modeling_preprocessor = 'dialog-modeling-preprocessor'
dialog_state_tracking_preprocessor = 'dialog-state-tracking-preprocessor'
sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer'
text_error_correction = 'text-error-correction'
sentence_embedding = 'sentence-embedding'
text_ranking = 'text-ranking'
sequence_labeling_tokenizer = 'sequence-labeling-tokenizer'
word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor'
thai_wseg_tokenizer = 'thai-wseg-tokenizer'
fill_mask = 'fill-mask'
fill_mask_ponet = 'fill-mask-ponet'
faq_question_answering_preprocessor = 'faq-question-answering-preprocessor'
conversational_text_to_sql = 'conversational-text-to-sql'
table_question_answering_preprocessor = 'table-question-answering-preprocessor'
re_tokenizer = 're-tokenizer'
document_segmentation = 'document-segmentation'
feature_extraction = 'feature-extraction'
mglm_summarization = 'mglm-summarization'
sentence_piece = 'sentence-piece'
sen_sim_tokenizer = "sen-sim-tokenizer"
cross_encoder_tokenizer = "cross-encoder-tokenizer"
bert_seq_cls_tokenizer = "bert-seq-cls-tokenizer"
text_gen_tokenizer = "text-gen-tokenizer"
text2text_gen_preprocessor = "text2text-gen-preprocessor"
text_gen_jieba_tokenizer = "text-gen-jieba-tokenizer"
text2text_translate_preprocessor = "text2text-translate-preprocessor"
token_cls_tokenizer = "token-cls-tokenizer"
ner_tokenizer = "ner-tokenizer"
thai_ner_tokenizer = "thai-ner-tokenizer"
viet_ner_tokenizer = "viet-ner-tokenizer"
nli_tokenizer = "nli-tokenizer"
sen_cls_tokenizer = "sen-cls-tokenizer"
dialog_intent_preprocessor = "dialog-intent-preprocessor"
dialog_modeling_preprocessor = "dialog-modeling-preprocessor"
dialog_state_tracking_preprocessor = "dialog-state-tracking-preprocessor"
sbert_token_cls_tokenizer = "sbert-token-cls-tokenizer"
zero_shot_cls_tokenizer = "zero-shot-cls-tokenizer"
text_error_correction = "text-error-correction"
sentence_embedding = "sentence-embedding"
text_ranking = "text-ranking"
sequence_labeling_tokenizer = "sequence-labeling-tokenizer"
word_segment_text_to_label_preprocessor = "word-segment-text-to-label-preprocessor"
thai_wseg_tokenizer = "thai-wseg-tokenizer"
fill_mask = "fill-mask"
fill_mask_ponet = "fill-mask-ponet"
faq_question_answering_preprocessor = "faq-question-answering-preprocessor"
conversational_text_to_sql = "conversational-text-to-sql"
table_question_answering_preprocessor = "table-question-answering-preprocessor"
re_tokenizer = "re-tokenizer"
document_segmentation = "document-segmentation"
feature_extraction = "feature-extraction"
mglm_summarization = "mglm-summarization"
sentence_piece = "sentence-piece"
# audio preprocessor
linear_aec_fbank = 'linear-aec-fbank'
text_to_tacotron_symbols = 'text-to-tacotron-symbols'
wav_to_lists = 'wav-to-lists'
wav_to_scp = 'wav-to-scp'
linear_aec_fbank = "linear-aec-fbank"
text_to_tacotron_symbols = "text-to-tacotron-symbols"
wav_to_lists = "wav-to-lists"
wav_to_scp = "wav-to-scp"
# multi-modal preprocessor
ofa_tasks_preprocessor = 'ofa-tasks-preprocessor'
clip_preprocessor = 'clip-preprocessor'
mplug_tasks_preprocessor = 'mplug-tasks-preprocessor'
ofa_tasks_preprocessor = "ofa-tasks-preprocessor"
clip_preprocessor = "clip-preprocessor"
mplug_tasks_preprocessor = "mplug-tasks-preprocessor"
# science preprocessor
unifold_preprocessor = 'unifold-preprocessor'
unifold_preprocessor = "unifold-preprocessor"
class Metrics(object):
""" Names for different metrics.
"""
"""Names for different metrics."""
# accuracy
accuracy = 'accuracy'
multi_average_precision = 'mAP'
audio_noise_metric = 'audio-noise-metric'
accuracy = "accuracy"
multi_average_precision = "mAP"
audio_noise_metric = "audio-noise-metric"
# text gen
BLEU = 'bleu'
BLEU = "bleu"
# metrics for image denoise task
image_denoise_metric = 'image-denoise-metric'
image_denoise_metric = "image-denoise-metric"
# metric for image instance segmentation task
image_ins_seg_coco_metric = 'image-ins-seg-coco-metric'
image_ins_seg_coco_metric = "image-ins-seg-coco-metric"
# metrics for sequence classification task
seq_cls_metric = 'seq-cls-metric'
seq_cls_metric = "seq-cls-metric"
# metrics for token-classification task
token_cls_metric = 'token-cls-metric'
token_cls_metric = "token-cls-metric"
# metrics for text-generation task
text_gen_metric = 'text-gen-metric'
text_gen_metric = "text-gen-metric"
# metrics for image-color-enhance task
image_color_enhance_metric = 'image-color-enhance-metric'
image_color_enhance_metric = "image-color-enhance-metric"
# metrics for image-portrait-enhancement task
image_portrait_enhancement_metric = 'image-portrait-enhancement-metric'
video_summarization_metric = 'video-summarization-metric'
image_portrait_enhancement_metric = "image-portrait-enhancement-metric"
video_summarization_metric = "video-summarization-metric"
# metric for movie-scene-segmentation task
movie_scene_segmentation_metric = 'movie-scene-segmentation-metric'
movie_scene_segmentation_metric = "movie-scene-segmentation-metric"
# metric for inpainting task
image_inpainting_metric = 'image-inpainting-metric'
image_inpainting_metric = "image-inpainting-metric"
# metric for ocr
NED = 'ned'
NED = "ned"
# metric for cross-modal retrieval
inbatch_recall = 'inbatch_recall'
inbatch_recall = "inbatch_recall"
# metric for referring-video-object-segmentation task
referring_video_object_segmentation_metric = 'referring-video-object-segmentation-metric'
referring_video_object_segmentation_metric = (
"referring-video-object-segmentation-metric"
)
class Optimizers(object):
""" Names for different OPTIMIZER.
"""Names for different OPTIMIZER.
Holds the standard optimizer name to use for identifying different optimizer.
This should be used to register optimizer.
Holds the standard optimizer name to use for identifying different optimizer.
This should be used to register optimizer.
"""
default = 'optimizer'
default = "optimizer"
SGD = 'SGD'
SGD = "SGD"
class Hooks(object):
""" Names for different hooks.
"""Names for different hooks.
All kinds of hooks are defined here
All kinds of hooks are defined here
"""
# lr
LrSchedulerHook = 'LrSchedulerHook'
PlateauLrSchedulerHook = 'PlateauLrSchedulerHook'
NoneLrSchedulerHook = 'NoneLrSchedulerHook'
LrSchedulerHook = "LrSchedulerHook"
PlateauLrSchedulerHook = "PlateauLrSchedulerHook"
NoneLrSchedulerHook = "NoneLrSchedulerHook"
# optimizer
OptimizerHook = 'OptimizerHook'
TorchAMPOptimizerHook = 'TorchAMPOptimizerHook'
ApexAMPOptimizerHook = 'ApexAMPOptimizerHook'
NoneOptimizerHook = 'NoneOptimizerHook'
OptimizerHook = "OptimizerHook"
TorchAMPOptimizerHook = "TorchAMPOptimizerHook"
ApexAMPOptimizerHook = "ApexAMPOptimizerHook"
NoneOptimizerHook = "NoneOptimizerHook"
# checkpoint
CheckpointHook = 'CheckpointHook'
BestCkptSaverHook = 'BestCkptSaverHook'
CheckpointHook = "CheckpointHook"
BestCkptSaverHook = "BestCkptSaverHook"
# logger
TextLoggerHook = 'TextLoggerHook'
TensorboardHook = 'TensorboardHook'
TextLoggerHook = "TextLoggerHook"
TensorboardHook = "TensorboardHook"
IterTimerHook = 'IterTimerHook'
EvaluationHook = 'EvaluationHook'
IterTimerHook = "IterTimerHook"
EvaluationHook = "EvaluationHook"
# Compression
SparsityHook = 'SparsityHook'
SparsityHook = "SparsityHook"
# CLIP logit_scale clamp
ClipClampLogitScaleHook = 'ClipClampLogitScaleHook'
ClipClampLogitScaleHook = "ClipClampLogitScaleHook"
class LR_Schedulers(object):
"""learning rate scheduler is defined here
"""learning rate scheduler is defined here"""
"""
LinearWarmup = 'LinearWarmup'
ConstantWarmup = 'ConstantWarmup'
ExponentialWarmup = 'ExponentialWarmup'
LinearWarmup = "LinearWarmup"
ConstantWarmup = "ConstantWarmup"
ExponentialWarmup = "ExponentialWarmup"
class Datasets(object):
""" Names for different datasets.
"""
ClsDataset = 'ClsDataset'
Face2dKeypointsDataset = 'FaceKeypointDataset'
HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset'
HumanWholeBodyKeypointDataset = 'WholeBodyCocoTopDownDataset'
SegDataset = 'SegDataset'
DetDataset = 'DetDataset'
DetImagesMixDataset = 'DetImagesMixDataset'
PairedDataset = 'PairedDataset'
"""Names for different datasets."""
ClsDataset = "ClsDataset"
Face2dKeypointsDataset = "FaceKeypointDataset"
HandCocoWholeBodyDataset = "HandCocoWholeBodyDataset"
HumanWholeBodyKeypointDataset = "WholeBodyCocoTopDownDataset"
SegDataset = "SegDataset"
DetDataset = "DetDataset"
DetImagesMixDataset = "DetImagesMixDataset"
PairedDataset = "PairedDataset"

View File

@@ -6,53 +6,51 @@ from modelscope.utils.constant import Tasks
class OutputKeys(object):
LOSS = 'loss'
LOGITS = 'logits'
SCORES = 'scores'
SCORE = 'score'
LABEL = 'label'
LABELS = 'labels'
INPUT_IDS = 'input_ids'
LABEL_POS = 'label_pos'
POSES = 'poses'
CAPTION = 'caption'
BOXES = 'boxes'
KEYPOINTS = 'keypoints'
MASKS = 'masks'
TEXT = 'text'
POLYGONS = 'polygons'
OUTPUT = 'output'
OUTPUT_IMG = 'output_img'
OUTPUT_VIDEO = 'output_video'
OUTPUT_PCM = 'output_pcm'
IMG_EMBEDDING = 'img_embedding'
SPO_LIST = 'spo_list'
TEXT_EMBEDDING = 'text_embedding'
TRANSLATION = 'translation'
RESPONSE = 'response'
PREDICTION = 'prediction'
PREDICTIONS = 'predictions'
PROBABILITIES = 'probabilities'
DIALOG_STATES = 'dialog_states'
VIDEO_EMBEDDING = 'video_embedding'
UUID = 'uuid'
WORD = 'word'
KWS_LIST = 'kws_list'
SQL_STRING = 'sql_string'
SQL_QUERY = 'sql_query'
HISTORY = 'history'
QUERT_RESULT = 'query_result'
TIMESTAMPS = 'timestamps'
SHOT_NUM = 'shot_num'
SCENE_NUM = 'scene_num'
SCENE_META_LIST = 'scene_meta_list'
SHOT_META_LIST = 'shot_meta_list'
LOSS = "loss"
LOGITS = "logits"
SCORES = "scores"
SCORE = "score"
LABEL = "label"
LABELS = "labels"
INPUT_IDS = "input_ids"
LABEL_POS = "label_pos"
POSES = "poses"
CAPTION = "caption"
BOXES = "boxes"
KEYPOINTS = "keypoints"
MASKS = "masks"
TEXT = "text"
POLYGONS = "polygons"
OUTPUT = "output"
OUTPUT_IMG = "output_img"
OUTPUT_VIDEO = "output_video"
OUTPUT_PCM = "output_pcm"
IMG_EMBEDDING = "img_embedding"
SPO_LIST = "spo_list"
TEXT_EMBEDDING = "text_embedding"
TRANSLATION = "translation"
RESPONSE = "response"
PREDICTION = "prediction"
PREDICTIONS = "predictions"
PROBABILITIES = "probabilities"
DIALOG_STATES = "dialog_states"
VIDEO_EMBEDDING = "video_embedding"
UUID = "uuid"
WORD = "word"
KWS_LIST = "kws_list"
SQL_STRING = "sql_string"
SQL_QUERY = "sql_query"
HISTORY = "history"
QUERT_RESULT = "query_result"
TIMESTAMPS = "timestamps"
SHOT_NUM = "shot_num"
SCENE_NUM = "scene_num"
SCENE_META_LIST = "scene_meta_list"
SHOT_META_LIST = "shot_meta_list"
TASK_OUTPUTS = {
# ============ vision tasks ===================
# ocr detection result for single sample
# {
# "polygons": np.array with shape [num_text, 8], each polygon is
@@ -60,13 +58,11 @@ TASK_OUTPUTS = {
# }
Tasks.ocr_detection: [OutputKeys.POLYGONS],
Tasks.table_recognition: [OutputKeys.POLYGONS],
# ocr recognition result for single sample
# {
# "text": "电子元器件提供BOM配单"
# }
Tasks.ocr_recognition: [OutputKeys.TEXT],
# face 2d keypoint result for single sample
# {
# "keypoints": [
@@ -85,9 +81,7 @@ TASK_OUTPUTS = {
# [x1, y1, x2, y2],
# ]
# }
Tasks.face_2d_keypoints:
[OutputKeys.KEYPOINTS, OutputKeys.POSES, OutputKeys.BOXES],
Tasks.face_2d_keypoints: [OutputKeys.KEYPOINTS, OutputKeys.POSES, OutputKeys.BOXES],
# face detection result for single sample
# {
# "scores": [0.9, 0.1, 0.05, 0.05]
@@ -104,9 +98,7 @@ TASK_OUTPUTS = {
# [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5],
# ],
# }
Tasks.face_detection:
[OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS],
Tasks.face_detection: [OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS],
# card detection result for single sample
# {
# "scores": [0.9, 0.1, 0.05, 0.05]
@@ -123,23 +115,18 @@ TASK_OUTPUTS = {
# [x1, y1, x2, y2, x3, y3, x4, y4],
# ],
# }
Tasks.card_detection:
[OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS],
Tasks.card_detection: [OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS],
# facial expression recognition result for single sample
# {
# "scores": [0.9, 0.1, 0.02, 0.02, 0.02, 0.02, 0.02],
# "labels": ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']
# }
Tasks.facial_expression_recognition:
[OutputKeys.SCORES, OutputKeys.LABELS],
Tasks.facial_expression_recognition: [OutputKeys.SCORES, OutputKeys.LABELS],
# face recognition result for single sample
# {
# "img_embedding": np.array with shape [1, D],
# }
Tasks.face_recognition: [OutputKeys.IMG_EMBEDDING],
# human detection result for single sample
# {
# "scores": [0.9, 0.1, 0.05, 0.05]
@@ -151,22 +138,18 @@ TASK_OUTPUTS = {
# ],
# }
#
Tasks.human_detection:
[OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.BOXES],
Tasks.human_detection: [OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.BOXES],
# face generation result for single sample
# {
# "output_img": np.array with shape(h, w, 3)
# }
Tasks.face_image_generation: [OutputKeys.OUTPUT_IMG],
# image classification result for single sample
# {
# "scores": [0.9, 0.1, 0.05, 0.05]
# "labels": ["dog", "horse", "cow", "cat"],
# }
Tasks.image_classification: [OutputKeys.SCORES, OutputKeys.LABELS],
# object detection result for single sample
# {
# "scores": [0.9, 0.1, 0.05, 0.05]
@@ -177,12 +160,13 @@ TASK_OUTPUTS = {
# [x1, y1, x2, y2],
# ],
# }
Tasks.image_object_detection:
[OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.BOXES],
Tasks.image_object_detection: [
OutputKeys.SCORES,
OutputKeys.LABELS,
OutputKeys.BOXES,
],
# video object detection result for single sample
# {
# "scores": [[0.8, 0.25, 0.05, 0.05], [0.9, 0.1, 0.05, 0.05]]
# "labels": [["person", "traffic light", "car", "bus"],
# ["person", "traffic light", "car", "bus"]]
@@ -201,11 +185,12 @@ TASK_OUTPUTS = {
# [x1, y1, x2, y2],
# ]
# ],
# }
Tasks.video_object_detection:
[OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.BOXES],
Tasks.video_object_detection: [
OutputKeys.SCORES,
OutputKeys.LABELS,
OutputKeys.BOXES,
],
# instance segmentation result for single sample
# {
# "scores": [0.9, 0.1, 0.05, 0.05],
@@ -214,15 +199,12 @@ TASK_OUTPUTS = {
# np.array # 2D array containing only 0, 1
# ]
# }
Tasks.image_segmentation:
[OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.MASKS],
Tasks.image_segmentation: [OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.MASKS],
# semantic segmentation result for single sample
# {
# "masks": [np.array # 2D array with shape [height, width]]
# }
Tasks.semantic_segmentation: [OutputKeys.MASKS],
# image matting result for single sample
# {
# "output_img": np.array with shape(h, w, 4)
@@ -230,7 +212,6 @@ TASK_OUTPUTS = {
# , shape(h, w) for crowd counting
# }
Tasks.portrait_matting: [OutputKeys.OUTPUT_IMG],
# image editing task result for a single image
# {"output_img": np.array with shape (h, w, 3)}
Tasks.skin_retouching: [OutputKeys.OUTPUT_IMG],
@@ -241,7 +222,6 @@ TASK_OUTPUTS = {
Tasks.image_portrait_enhancement: [OutputKeys.OUTPUT_IMG],
Tasks.crowd_counting: [OutputKeys.SCORES, OutputKeys.OUTPUT_IMG],
Tasks.image_inpainting: [OutputKeys.OUTPUT_IMG],
# image generation task result for a single image
# {"output_img": np.array with shape (h, w, 3)}
Tasks.image_to_image_generation: [OutputKeys.OUTPUT_IMG],
@@ -249,20 +229,17 @@ TASK_OUTPUTS = {
Tasks.image_style_transfer: [OutputKeys.OUTPUT_IMG],
Tasks.image_portrait_stylization: [OutputKeys.OUTPUT_IMG],
Tasks.image_body_reshaping: [OutputKeys.OUTPUT_IMG],
# live category recognition result for single video
# {
# "scores": [0.885272, 0.014790631, 0.014558001]
# "labels": ['女装/女士精品>>棉衣/棉服', '女装/女士精品>>牛仔裤', '女装/女士精品>>裤子>>休闲裤'],
# }
Tasks.live_category: [OutputKeys.SCORES, OutputKeys.LABELS],
# action recognition result for single video
# {
# "output_label": "abseiling"
# }
Tasks.action_recognition: [OutputKeys.LABELS],
# human body keypoints detection result for single sample
# {
# "keypoints": [
@@ -281,9 +258,11 @@ TASK_OUTPUTS = {
# [x1, y1, x2, y2],
# ]
# }
Tasks.body_2d_keypoints:
[OutputKeys.KEYPOINTS, OutputKeys.SCORES, OutputKeys.BOXES],
Tasks.body_2d_keypoints: [
OutputKeys.KEYPOINTS,
OutputKeys.SCORES,
OutputKeys.BOXES,
],
# 3D human body keypoints detection result for single sample
# {
# "keypoints": [ # 3d pose coordinate in camera coordinate
@@ -299,9 +278,11 @@ TASK_OUTPUTS = {
# "output_video": "path_to_rendered_video" , this is optional
# and is only avaialbe when the "render" option is enabled.
# }
Tasks.body_3d_keypoints:
[OutputKeys.KEYPOINTS, OutputKeys.TIMESTAMPS, OutputKeys.OUTPUT_VIDEO],
Tasks.body_3d_keypoints: [
OutputKeys.KEYPOINTS,
OutputKeys.TIMESTAMPS,
OutputKeys.OUTPUT_VIDEO,
],
# 2D hand keypoints result for single sample
# {
# "keypoints": [
@@ -316,7 +297,6 @@ TASK_OUTPUTS = {
# ]
# }
Tasks.hand_2d_keypoints: [OutputKeys.KEYPOINTS, OutputKeys.BOXES],
# video single object tracking result for single video
# {
# "boxes": [
@@ -326,35 +306,29 @@ TASK_OUTPUTS = {
# ],
# "timestamps": ["hh:mm:ss", "hh:mm:ss", "hh:mm:ss"]
# }
Tasks.video_single_object_tracking:
[OutputKeys.BOXES, OutputKeys.TIMESTAMPS],
Tasks.video_single_object_tracking: [OutputKeys.BOXES, OutputKeys.TIMESTAMPS],
# live category recognition result for single video
# {
# "scores": [0.885272, 0.014790631, 0.014558001],
# 'labels': ['修身型棉衣', '高腰牛仔裤', '休闲连体裤']
# }
Tasks.live_category: [OutputKeys.SCORES, OutputKeys.LABELS],
# video category recognition result for single video
# {
# "scores": [0.7716429233551025],
# "labels": ['生活>>好物推荐']
# }
Tasks.video_category: [OutputKeys.SCORES, OutputKeys.LABELS],
# image embedding result for a single image
# {
# "image_bedding": np.array with shape [D]
# }
Tasks.product_retrieval_embedding: [OutputKeys.IMG_EMBEDDING],
# video embedding result for single video
# {
# "video_embedding": np.array with shape [D],
# }
Tasks.video_embedding: [OutputKeys.VIDEO_EMBEDDING],
# virtual_try_on result for a single sample
# {
# "output_img": np.ndarray with shape [height, width, 3]
@@ -397,10 +371,11 @@ TASK_OUTPUTS = {
#
# }
Tasks.movie_scene_segmentation: [
OutputKeys.SHOT_NUM, OutputKeys.SHOT_META_LIST, OutputKeys.SCENE_NUM,
OutputKeys.SCENE_META_LIST
OutputKeys.SHOT_NUM,
OutputKeys.SHOT_META_LIST,
OutputKeys.SCENE_NUM,
OutputKeys.SCENE_META_LIST,
],
# human whole body keypoints detection result for single sample
# {
# "keypoints": [
@@ -415,7 +390,6 @@ TASK_OUTPUTS = {
# ]
# }
Tasks.human_wholebody_keypoint: [OutputKeys.KEYPOINTS, OutputKeys.BOXES],
# video summarization result for a single video
# {
# "output":
@@ -431,50 +405,42 @@ TASK_OUTPUTS = {
# ]
# }
Tasks.video_summarization: [OutputKeys.OUTPUT],
# referring video object segmentation result for a single video
# {
# "masks": [np.array # 2D array with shape [height, width]]
# }
Tasks.referring_video_object_segmentation: [OutputKeys.MASKS],
# ============ nlp tasks ===================
# text classification result for single sample
# {
# "scores": [0.9, 0.1, 0.05, 0.05]
# "labels": ["happy", "sad", "calm", "angry"],
# }
Tasks.text_classification: [OutputKeys.SCORES, OutputKeys.LABELS],
# sentence similarity result for single sample
# {
# "scores": 0.9
# "labels": "1",
# }
Tasks.sentence_similarity: [OutputKeys.SCORES, OutputKeys.LABELS],
# nli result for single sample
# {
# "labels": ["happy", "sad", "calm", "angry"],
# "scores": [0.9, 0.1, 0.05, 0.05]
# }
Tasks.nli: [OutputKeys.SCORES, OutputKeys.LABELS],
# sentiment classification result for single sample
# {
# 'scores': [0.07183828949928284, 0.9281617403030396],
# 'labels': ['1', '0']
# }
Tasks.sentiment_classification: [OutputKeys.SCORES, OutputKeys.LABELS],
# zero-shot classification result for single sample
# {
# "scores": [0.9, 0.1, 0.05, 0.05]
# "labels": ["happy", "sad", "calm", "angry"],
# }
Tasks.zero_shot_classification: [OutputKeys.SCORES, OutputKeys.LABELS],
# relation extraction result for a single sample
# {
# "uuid": "人生信息-1",
@@ -482,19 +448,16 @@ TASK_OUTPUTS = {
# "spo_list": [{"subject": "石顺义", "predicate": "国籍", "object": "中国"}]
# }
Tasks.relation_extraction: [OutputKeys.SPO_LIST],
# translation result for a source sentence
# {
# "translation": “北京是中国的首都”
# }
Tasks.translation: [OutputKeys.TRANSLATION],
# word segmentation result for single sample
# {
# "output": "今天 天气 不错 适合 出去 游玩"
# }
Tasks.word_segmentation: [OutputKeys.OUTPUT],
# TODO @wenmeng.zwm support list of result check
# named entity recognition result for single sample
# {
@@ -505,7 +468,6 @@ TASK_OUTPUTS = {
# }
Tasks.named_entity_recognition: [OutputKeys.OUTPUT],
Tasks.part_of_speech: [OutputKeys.OUTPUT],
# text_error_correction result for a single sample
# {
# "output": "我想吃苹果"
@@ -513,31 +475,26 @@ TASK_OUTPUTS = {
Tasks.text_error_correction: [OutputKeys.OUTPUT],
Tasks.sentence_embedding: [OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES],
Tasks.text_ranking: [OutputKeys.SCORES],
# text generation result for single sample
# {
# "text": "this is the text generated by a model."
# }
Tasks.text_generation: [OutputKeys.TEXT],
# summarization result for single sample
# {
# "text": "this is the text generated by a model."
# }
Tasks.text_summarization: [OutputKeys.TEXT],
# text generation result for single sample
# {
# "text": "北京"
# }
Tasks.text2text_generation: [OutputKeys.TEXT],
# fill mask result for single sample
# {
# "text": "this is the text which masks filled by model."
# }
Tasks.fill_mask: [OutputKeys.TEXT],
# feature extraction result for single sample
# {
# "text_embedding": [[
@@ -553,7 +510,6 @@ TASK_OUTPUTS = {
# ]
# }
Tasks.feature_extraction: [OutputKeys.TEXT_EMBEDDING],
# (Deprecated) dialog intent prediction result for single sample
# {'output': {'prediction': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05,
# 1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04,
@@ -575,10 +531,8 @@ TASK_OUTPUTS = {
# 4.31488479e-05, 4.94503947e-05, 4.30105974e-05, 1.00963116e-04,
# 2.82062047e-05, 1.15582036e-04, 4.48261271e-05, 3.99339879e-05,
# 7.27692823e-05], dtype=float32), 'label_pos': array([11]), 'label': 'lost_or_stolen_card'}}
# (Deprecated) dialog modeling prediction result for single sample
# {'output' : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!']}
# (Deprecated) dialog state tracking result for single sample
# {
# "output":{
@@ -617,19 +571,16 @@ TASK_OUTPUTS = {
# }
# }
Tasks.task_oriented_conversation: [OutputKeys.OUTPUT],
# table-question-answering result for single sample
# {
# "sql": "SELECT shop.Name FROM shop."
# "sql_history": {sel: 0, agg: 0, conds: [[0, 0, 'val']]}
# }
Tasks.table_question_answering: [OutputKeys.OUTPUT],
# ============ audio tasks ===================
# asr result for single sample
# { "text": "每一天都要快乐喔"}
Tasks.auto_speech_recognition: [OutputKeys.TEXT],
# audio processed for single file in PCM format
# {
# "output_pcm": pcm encoded audio bytes
@@ -637,13 +588,11 @@ TASK_OUTPUTS = {
Tasks.speech_signal_process: [OutputKeys.OUTPUT_PCM],
Tasks.acoustic_echo_cancellation: [OutputKeys.OUTPUT_PCM],
Tasks.acoustic_noise_suppression: [OutputKeys.OUTPUT_PCM],
# text_to_speech result for a single sample
# {
# "output_pcm": {"input_label" : np.ndarray with shape [D]}
# }
Tasks.text_to_speech: [OutputKeys.OUTPUT_PCM],
# {
# "kws_list": [
# {
@@ -656,16 +605,13 @@ TASK_OUTPUTS = {
# ]
# }
Tasks.keyword_spotting: [OutputKeys.KWS_LIST],
# ============ multi-modal tasks ===================
# image caption result for single sample
# {
# "caption": "this is an image caption text."
# }
Tasks.image_captioning: [OutputKeys.CAPTION],
Tasks.ocr_recognition: [OutputKeys.TEXT],
# visual grounding result for single sample
# {
# "boxes": [
@@ -676,27 +622,22 @@ TASK_OUTPUTS = {
# "scores": [0.9, 0.1, 0.05, 0.05]
# }
Tasks.visual_grounding: [OutputKeys.BOXES, OutputKeys.SCORES],
# text_to_image result for a single sample
# {
# "output_img": np.ndarray with shape [height, width, 3]
# }
Tasks.text_to_image_synthesis: [OutputKeys.OUTPUT_IMG],
# text_to_speech result for a single sample
# {
# "output_pcm": {"input_label" : np.ndarray with shape [D]}
# }
Tasks.text_to_speech: [OutputKeys.OUTPUT_PCM],
# multi-modal embedding result for single sample
# {
# "img_embedding": np.array with shape [1, D],
# "text_embedding": np.array with shape [1, D]
# }
Tasks.multi_modal_embedding:
[OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING],
Tasks.multi_modal_embedding: [OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING],
# generative multi-modal embedding result for single sample
# {
# "img_embedding": np.array with shape [1, D],
@@ -704,9 +645,10 @@ TASK_OUTPUTS = {
# "caption": "this is an image caption text."
# }
Tasks.generative_multi_modal_embedding: [
OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.CAPTION
OutputKeys.IMG_EMBEDDING,
OutputKeys.TEXT_EMBEDDING,
OutputKeys.CAPTION,
],
# multi-modal similarity result for single sample
# {
# "img_embedding": np.array with shape [1, D],
@@ -714,25 +656,23 @@ TASK_OUTPUTS = {
# "similarity": float
# }
Tasks.multi_modal_similarity: [
OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES
OutputKeys.IMG_EMBEDDING,
OutputKeys.TEXT_EMBEDDING,
OutputKeys.SCORES,
],
# VQA result for a sample
# {"text": "this is a text answser. "}
Tasks.visual_question_answering: [OutputKeys.TEXT],
# auto_speech_recognition result for a single sample
# {
# "text": "每天都要快乐喔"
# }
Tasks.auto_speech_recognition: [OutputKeys.TEXT],
# {
# "scores": [0.9, 0.1, 0.1],
# "labels": ["entailment", "contradiction", "neutral"]
# }
Tasks.visual_entailment: [OutputKeys.SCORES, OutputKeys.LABELS],
# {
# 'labels': ['吸烟', '打电话', '吸烟'],
# 'scores': [0.7527753114700317, 0.753358006477356, 0.6880350708961487],
@@ -745,7 +685,6 @@ TASK_OUTPUTS = {
OutputKeys.SCORES,
OutputKeys.BOXES,
],
# {
# 'output': [
# [{'label': '6527856', 'score': 0.9942756295204163}, {'label': '1000012000', 'score': 0.0379515215754509},
@@ -756,36 +695,32 @@ TASK_OUTPUTS = {
# {'label': '13421097', 'score': 2.75914817393641e-06}]]
# }
Tasks.faq_question_answering: [OutputKeys.OUTPUT],
# image person reid result for single sample
# {
# "img_embedding": np.array with shape [1, D],
# }
Tasks.image_reid_person: [OutputKeys.IMG_EMBEDDING],
# {
# 'output': ['Done' / 'Decode_Error']
# }
Tasks.video_inpainting: [OutputKeys.OUTPUT],
# {
# 'output': ['bixin']
# }
Tasks.hand_static: [OutputKeys.OUTPUT],
# { 'labels': [2, 1, 0],
# 'boxes':[[[78, 282, 240, 504], [127, 87, 332, 370], [0, 0, 367, 639]]
# 'scores':[0.8202137351036072, 0.8987470269203186, 0.9679114818572998]
# }
Tasks.face_human_hand_detection: [
OutputKeys.LABELS, OutputKeys.BOXES, OutputKeys.SCORES
OutputKeys.LABELS,
OutputKeys.BOXES,
OutputKeys.SCORES,
],
# {
# {'output': 'Happiness', 'boxes': (203, 104, 663, 564)}
# }
Tasks.face_emotion: [OutputKeys.OUTPUT, OutputKeys.BOXES],
# {
# "masks": [
# np.array # 2D array containing only 0, 255
@@ -796,7 +731,6 @@ TASK_OUTPUTS = {
class ModelOutputBase(list):
def __post_init__(self):
self.reconstruct()
self.post_init = True
@@ -813,7 +747,7 @@ class ModelOutputBase(list):
return getattr(self, item)
elif isinstance(item, (int, slice)):
return super().__getitem__(item)
raise IndexError(f'No Index {item} found in the dataclass.')
raise IndexError(f"No Index {item} found in the dataclass.")
def __setitem__(self, key, value):
if isinstance(key, str):
@@ -832,15 +766,13 @@ class ModelOutputBase(list):
super().__setattr__(key_name, value)
def __setattr__(self, key, value):
if getattr(self, 'post_init', False):
if getattr(self, "post_init", False):
return self.__setitem__(key, value)
else:
return super().__setattr__(key, value)
def keys(self):
return [
f.name for f in fields(self) if getattr(self, f.name) is not None
]
return [f.name for f in fields(self) if getattr(self, f.name) is not None]
def items(self):
return self.to_dict().items()

View File

@@ -13,206 +13,297 @@ from modelscope.utils.registry import Registry, build_from_cfg
from .base import Pipeline
from .util import is_official_hub_path
PIPELINES = Registry('pipelines')
PIPELINES = Registry("pipelines")
DEFAULT_MODEL_FOR_PIPELINE = {
# TaskName: (pipeline_module_name, model_repo)
Tasks.sentence_embedding:
(Pipelines.sentence_embedding,
'damo/nlp_corom_sentence-embedding_english-base'),
Tasks.text_ranking: (Pipelines.text_ranking,
'damo/nlp_corom_passage-ranking_english-base'),
Tasks.word_segmentation:
(Pipelines.word_segmentation,
'damo/nlp_structbert_word-segmentation_chinese-base'),
Tasks.part_of_speech: (Pipelines.part_of_speech,
'damo/nlp_structbert_part-of-speech_chinese-base'),
Tasks.token_classification:
(Pipelines.part_of_speech,
'damo/nlp_structbert_part-of-speech_chinese-base'),
Tasks.named_entity_recognition:
(Pipelines.named_entity_recognition,
'damo/nlp_raner_named-entity-recognition_chinese-base-news'),
Tasks.relation_extraction:
(Pipelines.relation_extraction,
'damo/nlp_bert_relation-extraction_chinese-base'),
Tasks.information_extraction:
(Pipelines.relation_extraction,
'damo/nlp_bert_relation-extraction_chinese-base'),
Tasks.sentence_similarity:
(Pipelines.sentence_similarity,
'damo/nlp_structbert_sentence-similarity_chinese-base'),
Tasks.translation: (Pipelines.csanmt_translation,
'damo/nlp_csanmt_translation_zh2en'),
Tasks.nli: (Pipelines.nli, 'damo/nlp_structbert_nli_chinese-base'),
Tasks.sentiment_classification:
(Pipelines.sentiment_classification,
'damo/nlp_structbert_sentiment-classification_chinese-base'
), # TODO: revise back after passing the pr
Tasks.portrait_matting: (Pipelines.portrait_matting,
'damo/cv_unet_image-matting'),
Tasks.human_detection: (Pipelines.human_detection,
'damo/cv_resnet18_human-detection'),
Tasks.image_object_detection: (Pipelines.object_detection,
'damo/cv_vit_object-detection_coco'),
Tasks.image_denoising: (Pipelines.image_denoise,
'damo/cv_nafnet_image-denoise_sidd'),
Tasks.text_classification:
(Pipelines.sentiment_classification,
'damo/nlp_structbert_sentiment-classification_chinese-base'),
Tasks.text_generation: (Pipelines.text_generation,
'damo/nlp_palm2.0_text-generation_chinese-base'),
Tasks.zero_shot_classification:
(Pipelines.zero_shot_classification,
'damo/nlp_structbert_zero-shot-classification_chinese-base'),
Tasks.task_oriented_conversation: (Pipelines.dialog_modeling,
'damo/nlp_space_dialog-modeling'),
Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking,
'damo/nlp_space_dialog-state-tracking'),
Tasks.table_question_answering:
(Pipelines.table_question_answering_pipeline,
'damo/nlp-convai-text2sql-pretrain-cn'),
Tasks.text_error_correction:
(Pipelines.text_error_correction,
'damo/nlp_bart_text-error-correction_chinese'),
Tasks.image_captioning: (Pipelines.image_captioning,
'damo/ofa_image-caption_coco_large_en'),
Tasks.image_portrait_stylization:
(Pipelines.person_image_cartoon,
'damo/cv_unet_person-image-cartoon_compound-models'),
Tasks.ocr_detection: (Pipelines.ocr_detection,
'damo/cv_resnet18_ocr-detection-line-level_damo'),
Tasks.table_recognition: (Pipelines.table_recognition,
'damo/cv_dla34_table-structure-recognition_cycle-centernet'),
Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
Tasks.feature_extraction: (Pipelines.feature_extraction,
'damo/pert_feature-extraction_base-test'),
Tasks.action_recognition: (Pipelines.action_recognition,
'damo/cv_TAdaConv_action-recognition'),
Tasks.action_detection: (Pipelines.action_detection,
'damo/cv_ResNetC3D_action-detection_detection2d'),
Tasks.live_category: (Pipelines.live_category,
'damo/cv_resnet50_live-category'),
Tasks.video_category: (Pipelines.video_category,
'damo/cv_resnet50_video-category'),
Tasks.multi_modal_embedding: (Pipelines.multi_modal_embedding,
'damo/multi-modal_clip-vit-base-patch16_zh'),
Tasks.generative_multi_modal_embedding:
(Pipelines.generative_multi_modal_embedding,
'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding'
),
Tasks.multi_modal_similarity:
(Pipelines.multi_modal_similarity,
'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity'),
Tasks.visual_question_answering:
(Pipelines.visual_question_answering,
'damo/mplug_visual-question-answering_coco_large_en'),
Tasks.video_embedding: (Pipelines.cmdssl_video_embedding,
'damo/cv_r2p1d_video_embedding'),
Tasks.text_to_image_synthesis:
(Pipelines.text_to_image_synthesis,
'damo/cv_diffusion_text-to-image-synthesis_tiny'),
Tasks.body_2d_keypoints: (Pipelines.body_2d_keypoints,
'damo/cv_hrnetv2w32_body-2d-keypoints_image'),
Tasks.body_3d_keypoints: (Pipelines.body_3d_keypoints,
'damo/cv_canonical_body-3d-keypoints_video'),
Tasks.hand_2d_keypoints:
(Pipelines.hand_2d_keypoints,
'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody'),
Tasks.face_detection: (Pipelines.face_detection,
'damo/cv_resnet_facedetection_scrfd10gkps'),
Tasks.card_detection: (Pipelines.card_detection,
'damo/cv_resnet_carddetection_scrfd34gkps'),
Tasks.face_detection:
(Pipelines.face_detection,
'damo/cv_resnet101_face-detection_cvpr22papermogface'),
Tasks.face_recognition: (Pipelines.face_recognition,
'damo/cv_ir101_facerecognition_cfglint'),
Tasks.facial_expression_recognition:
(Pipelines.facial_expression_recognition,
'damo/cv_vgg19_facial-expression-recognition_fer'),
Tasks.face_2d_keypoints: (Pipelines.face_2d_keypoints,
'damo/cv_mobilenet_face-2d-keypoints_alignment'),
Tasks.video_multi_modal_embedding:
(Pipelines.video_multi_modal_embedding,
'damo/multi_modal_clip_vtretrival_msrvtt_53'),
Tasks.image_color_enhancement:
(Pipelines.image_color_enhance,
'damo/cv_csrnet_image-color-enhance-models'),
Tasks.virtual_try_on: (Pipelines.virtual_try_on,
'damo/cv_daflow_virtual-try-on_base'),
Tasks.image_colorization: (Pipelines.image_colorization,
'damo/cv_unet_image-colorization'),
Tasks.image_segmentation:
(Pipelines.image_instance_segmentation,
'damo/cv_swin-b_image-instance-segmentation_coco'),
Tasks.image_style_transfer: (Pipelines.image_style_transfer,
'damo/cv_aams_style-transfer_damo'),
Tasks.face_image_generation: (Pipelines.face_image_generation,
'damo/cv_gan_face-image-generation'),
Tasks.image_super_resolution: (Pipelines.image_super_resolution,
'damo/cv_rrdb_image-super-resolution'),
Tasks.image_portrait_enhancement:
(Pipelines.image_portrait_enhancement,
'damo/cv_gpen_image-portrait-enhancement'),
Tasks.product_retrieval_embedding:
(Pipelines.product_retrieval_embedding,
'damo/cv_resnet50_product-bag-embedding-models'),
Tasks.image_to_image_generation:
(Pipelines.image_to_image_generation,
'damo/cv_latent_diffusion_image2image_generate'),
Tasks.image_classification:
(Pipelines.daily_image_classification,
'damo/cv_vit-base_image-classification_Dailylife-labels'),
Tasks.image_object_detection:
(Pipelines.image_object_detection_auto,
'damo/cv_yolox_image-object-detection-auto'),
Tasks.ocr_recognition:
(Pipelines.ocr_recognition,
'damo/cv_convnextTiny_ocr-recognition-general_damo'),
Tasks.skin_retouching: (Pipelines.skin_retouching,
'damo/cv_unet_skin-retouching'),
Tasks.faq_question_answering:
(Pipelines.faq_question_answering,
'damo/nlp_structbert_faq-question-answering_chinese-base'),
Tasks.crowd_counting: (Pipelines.crowd_counting,
'damo/cv_hrnet_crowd-counting_dcanet'),
Tasks.video_single_object_tracking:
(Pipelines.video_single_object_tracking,
'damo/cv_vitb_video-single-object-tracking_ostrack'),
Tasks.image_reid_person: (Pipelines.image_reid_person,
'damo/cv_passvitb_image-reid-person_market'),
Tasks.text_driven_segmentation:
(Pipelines.text_driven_segmentation,
'damo/cv_vitl16_segmentation_text-driven-seg'),
Tasks.movie_scene_segmentation:
(Pipelines.movie_scene_segmentation,
'damo/cv_resnet50-bert_video-scene-segmentation_movienet'),
Tasks.shop_segmentation: (Pipelines.shop_segmentation,
'damo/cv_vitb16_segmentation_shop-seg'),
Tasks.image_inpainting: (Pipelines.image_inpainting,
'damo/cv_fft_inpainting_lama'),
Tasks.video_inpainting: (Pipelines.video_inpainting,
'damo/cv_video-inpainting'),
Tasks.human_wholebody_keypoint:
(Pipelines.human_wholebody_keypoint,
'damo/cv_hrnetw48_human-wholebody-keypoint_image'),
Tasks.hand_static: (Pipelines.hand_static,
'damo/cv_mobileface_hand-static'),
Tasks.face_human_hand_detection:
(Pipelines.face_human_hand_detection,
'damo/cv_nanodet_face-human-hand-detection'),
Tasks.face_emotion: (Pipelines.face_emotion, 'damo/cv_face-emotion'),
Tasks.product_segmentation: (Pipelines.product_segmentation,
'damo/cv_F3Net_product-segmentation'),
Tasks.referring_video_object_segmentation:
(Pipelines.referring_video_object_segmentation,
'damo/cv_swin-t_referring_video-object-segmentation'),
Tasks.sentence_embedding: (
Pipelines.sentence_embedding,
"damo/nlp_corom_sentence-embedding_english-base",
),
Tasks.text_ranking: (
Pipelines.text_ranking,
"damo/nlp_corom_passage-ranking_english-base",
),
Tasks.word_segmentation: (
Pipelines.word_segmentation,
"damo/nlp_structbert_word-segmentation_chinese-base",
),
Tasks.part_of_speech: (
Pipelines.part_of_speech,
"damo/nlp_structbert_part-of-speech_chinese-base",
),
Tasks.token_classification: (
Pipelines.part_of_speech,
"damo/nlp_structbert_part-of-speech_chinese-base",
),
Tasks.named_entity_recognition: (
Pipelines.named_entity_recognition,
"damo/nlp_raner_named-entity-recognition_chinese-base-news",
),
Tasks.relation_extraction: (
Pipelines.relation_extraction,
"damo/nlp_bert_relation-extraction_chinese-base",
),
Tasks.information_extraction: (
Pipelines.relation_extraction,
"damo/nlp_bert_relation-extraction_chinese-base",
),
Tasks.sentence_similarity: (
Pipelines.sentence_similarity,
"damo/nlp_structbert_sentence-similarity_chinese-base",
),
Tasks.translation: (
Pipelines.csanmt_translation,
"damo/nlp_csanmt_translation_zh2en",
),
Tasks.nli: (Pipelines.nli, "damo/nlp_structbert_nli_chinese-base"),
Tasks.sentiment_classification: (
Pipelines.sentiment_classification,
"damo/nlp_structbert_sentiment-classification_chinese-base",
), # TODO: revise back after passing the pr
Tasks.portrait_matting: (Pipelines.portrait_matting, "damo/cv_unet_image-matting"),
Tasks.human_detection: (
Pipelines.human_detection,
"damo/cv_resnet18_human-detection",
),
Tasks.image_object_detection: (
Pipelines.object_detection,
"damo/cv_vit_object-detection_coco",
),
Tasks.image_denoising: (
Pipelines.image_denoise,
"damo/cv_nafnet_image-denoise_sidd",
),
Tasks.text_classification: (
Pipelines.sentiment_classification,
"damo/nlp_structbert_sentiment-classification_chinese-base",
),
Tasks.text_generation: (
Pipelines.text_generation,
"damo/nlp_palm2.0_text-generation_chinese-base",
),
Tasks.zero_shot_classification: (
Pipelines.zero_shot_classification,
"damo/nlp_structbert_zero-shot-classification_chinese-base",
),
Tasks.task_oriented_conversation: (
Pipelines.dialog_modeling,
"damo/nlp_space_dialog-modeling",
),
Tasks.dialog_state_tracking: (
Pipelines.dialog_state_tracking,
"damo/nlp_space_dialog-state-tracking",
),
Tasks.table_question_answering: (
Pipelines.table_question_answering_pipeline,
"damo/nlp-convai-text2sql-pretrain-cn",
),
Tasks.text_error_correction: (
Pipelines.text_error_correction,
"damo/nlp_bart_text-error-correction_chinese",
),
Tasks.image_captioning: (
Pipelines.image_captioning,
"damo/ofa_image-caption_coco_large_en",
),
Tasks.image_portrait_stylization: (
Pipelines.person_image_cartoon,
"damo/cv_unet_person-image-cartoon_compound-models",
),
Tasks.ocr_detection: (
Pipelines.ocr_detection,
"damo/cv_resnet18_ocr-detection-line-level_damo",
),
Tasks.table_recognition: (
Pipelines.table_recognition,
"damo/cv_dla34_table-structure-recognition_cycle-centernet",
),
Tasks.fill_mask: (Pipelines.fill_mask, "damo/nlp_veco_fill-mask-large"),
Tasks.feature_extraction: (
Pipelines.feature_extraction,
"damo/pert_feature-extraction_base-test",
),
Tasks.action_recognition: (
Pipelines.action_recognition,
"damo/cv_TAdaConv_action-recognition",
),
Tasks.action_detection: (
Pipelines.action_detection,
"damo/cv_ResNetC3D_action-detection_detection2d",
),
Tasks.live_category: (Pipelines.live_category, "damo/cv_resnet50_live-category"),
Tasks.video_category: (Pipelines.video_category, "damo/cv_resnet50_video-category"),
Tasks.multi_modal_embedding: (
Pipelines.multi_modal_embedding,
"damo/multi-modal_clip-vit-base-patch16_zh",
),
Tasks.generative_multi_modal_embedding: (
Pipelines.generative_multi_modal_embedding,
"damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding",
),
Tasks.multi_modal_similarity: (
Pipelines.multi_modal_similarity,
"damo/multi-modal_team-vit-large-patch14_multi-modal-similarity",
),
Tasks.visual_question_answering: (
Pipelines.visual_question_answering,
"damo/mplug_visual-question-answering_coco_large_en",
),
Tasks.video_embedding: (
Pipelines.cmdssl_video_embedding,
"damo/cv_r2p1d_video_embedding",
),
Tasks.text_to_image_synthesis: (
Pipelines.text_to_image_synthesis,
"damo/cv_diffusion_text-to-image-synthesis_tiny",
),
Tasks.body_2d_keypoints: (
Pipelines.body_2d_keypoints,
"damo/cv_hrnetv2w32_body-2d-keypoints_image",
),
Tasks.body_3d_keypoints: (
Pipelines.body_3d_keypoints,
"damo/cv_canonical_body-3d-keypoints_video",
),
Tasks.hand_2d_keypoints: (
Pipelines.hand_2d_keypoints,
"damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody",
),
Tasks.face_detection: (
Pipelines.face_detection,
"damo/cv_resnet_facedetection_scrfd10gkps",
),
Tasks.card_detection: (
Pipelines.card_detection,
"damo/cv_resnet_carddetection_scrfd34gkps",
),
Tasks.face_detection: (
Pipelines.face_detection,
"damo/cv_resnet101_face-detection_cvpr22papermogface",
),
Tasks.face_recognition: (
Pipelines.face_recognition,
"damo/cv_ir101_facerecognition_cfglint",
),
Tasks.facial_expression_recognition: (
Pipelines.facial_expression_recognition,
"damo/cv_vgg19_facial-expression-recognition_fer",
),
Tasks.face_2d_keypoints: (
Pipelines.face_2d_keypoints,
"damo/cv_mobilenet_face-2d-keypoints_alignment",
),
Tasks.video_multi_modal_embedding: (
Pipelines.video_multi_modal_embedding,
"damo/multi_modal_clip_vtretrival_msrvtt_53",
),
Tasks.image_color_enhancement: (
Pipelines.image_color_enhance,
"damo/cv_csrnet_image-color-enhance-models",
),
Tasks.virtual_try_on: (
Pipelines.virtual_try_on,
"damo/cv_daflow_virtual-try-on_base",
),
Tasks.image_colorization: (
Pipelines.image_colorization,
"damo/cv_unet_image-colorization",
),
Tasks.image_segmentation: (
Pipelines.image_instance_segmentation,
"damo/cv_swin-b_image-instance-segmentation_coco",
),
Tasks.image_style_transfer: (
Pipelines.image_style_transfer,
"damo/cv_aams_style-transfer_damo",
),
Tasks.face_image_generation: (
Pipelines.face_image_generation,
"damo/cv_gan_face-image-generation",
),
Tasks.image_super_resolution: (
Pipelines.image_super_resolution,
"damo/cv_rrdb_image-super-resolution",
),
Tasks.image_portrait_enhancement: (
Pipelines.image_portrait_enhancement,
"damo/cv_gpen_image-portrait-enhancement",
),
Tasks.product_retrieval_embedding: (
Pipelines.product_retrieval_embedding,
"damo/cv_resnet50_product-bag-embedding-models",
),
Tasks.image_to_image_generation: (
Pipelines.image_to_image_generation,
"damo/cv_latent_diffusion_image2image_generate",
),
Tasks.image_classification: (
Pipelines.daily_image_classification,
"damo/cv_vit-base_image-classification_Dailylife-labels",
),
Tasks.image_object_detection: (
Pipelines.image_object_detection_auto,
"damo/cv_yolox_image-object-detection-auto",
),
Tasks.ocr_recognition: (
Pipelines.ocr_recognition,
"damo/cv_convnextTiny_ocr-recognition-general_damo",
),
Tasks.skin_retouching: (Pipelines.skin_retouching, "damo/cv_unet_skin-retouching"),
Tasks.faq_question_answering: (
Pipelines.faq_question_answering,
"damo/nlp_structbert_faq-question-answering_chinese-base",
),
Tasks.crowd_counting: (
Pipelines.crowd_counting,
"damo/cv_hrnet_crowd-counting_dcanet",
),
Tasks.video_single_object_tracking: (
Pipelines.video_single_object_tracking,
"damo/cv_vitb_video-single-object-tracking_ostrack",
),
Tasks.image_reid_person: (
Pipelines.image_reid_person,
"damo/cv_passvitb_image-reid-person_market",
),
Tasks.text_driven_segmentation: (
Pipelines.text_driven_segmentation,
"damo/cv_vitl16_segmentation_text-driven-seg",
),
Tasks.movie_scene_segmentation: (
Pipelines.movie_scene_segmentation,
"damo/cv_resnet50-bert_video-scene-segmentation_movienet",
),
Tasks.shop_segmentation: (
Pipelines.shop_segmentation,
"damo/cv_vitb16_segmentation_shop-seg",
),
Tasks.image_inpainting: (Pipelines.image_inpainting, "damo/cv_fft_inpainting_lama"),
Tasks.video_inpainting: (Pipelines.video_inpainting, "damo/cv_video-inpainting"),
Tasks.human_wholebody_keypoint: (
Pipelines.human_wholebody_keypoint,
"damo/cv_hrnetw48_human-wholebody-keypoint_image",
),
Tasks.hand_static: (Pipelines.hand_static, "damo/cv_mobileface_hand-static"),
Tasks.face_human_hand_detection: (
Pipelines.face_human_hand_detection,
"damo/cv_nanodet_face-human-hand-detection",
),
Tasks.face_emotion: (Pipelines.face_emotion, "damo/cv_face-emotion"),
Tasks.product_segmentation: (
Pipelines.product_segmentation,
"damo/cv_F3Net_product-segmentation",
),
Tasks.referring_video_object_segmentation: (
Pipelines.referring_video_object_segmentation,
"damo/cv_swin-t_referring_video-object-segmentation",
),
}
def normalize_model_input(model, model_revision):
""" normalize the input model, to ensure that a model str is a valid local path: in other words,
"""normalize the input model, to ensure that a model str is a valid local path: in other words,
for model represented by a model id, the model shall be downloaded locally
"""
if isinstance(model, str) and is_official_hub_path(model, model_revision):
@@ -222,18 +313,15 @@ def normalize_model_input(model, model_revision):
model = snapshot_download(model, revision=model_revision)
elif isinstance(model, list) and isinstance(model[0], str):
for idx in range(len(model)):
if is_official_hub_path(
model[idx],
model_revision) and not os.path.exists(model[idx]):
model[idx] = snapshot_download(
model[idx], revision=model_revision)
if is_official_hub_path(model[idx], model_revision) and not os.path.exists(
model[idx]
):
model[idx] = snapshot_download(model[idx], revision=model_revision)
return model
def build_pipeline(cfg: ConfigDict,
task_name: str = None,
default_args: dict = None):
""" build pipeline given model config dict.
def build_pipeline(cfg: ConfigDict, task_name: str = None, default_args: dict = None):
"""build pipeline given model config dict.
Args:
cfg (:obj:`ConfigDict`): config dict for model object.
@@ -242,19 +330,22 @@ def build_pipeline(cfg: ConfigDict,
default_args (dict, optional): Default initialization arguments.
"""
return build_from_cfg(
cfg, PIPELINES, group_key=task_name, default_args=default_args)
cfg, PIPELINES, group_key=task_name, default_args=default_args
)
def pipeline(task: str = None,
model: Union[str, List[str], Model, List[Model]] = None,
preprocessor=None,
config_file: str = None,
pipeline_name: str = None,
framework: str = None,
device: str = 'gpu',
model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
**kwargs) -> Pipeline:
""" Factory method to build an obj:`Pipeline`.
def pipeline(
task: str = None,
model: Union[str, List[str], Model, List[Model]] = None,
preprocessor=None,
config_file: str = None,
pipeline_name: str = None,
framework: str = None,
device: str = "gpu",
model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
**kwargs,
) -> Pipeline:
"""Factory method to build an obj:`Pipeline`.
Args:
@@ -284,19 +375,21 @@ def pipeline(task: str = None,
>>> p = pipeline('audio-kws', model=['damo/audio-tts', 'damo/auto-tts2'])
"""
if task is None and pipeline_name is None:
raise ValueError('task or pipeline_name is required')
raise ValueError("task or pipeline_name is required")
model = normalize_model_input(model, model_revision)
if pipeline_name is None:
# get default pipeline for this task
if isinstance(model, str) \
or (isinstance(model, list) and isinstance(model[0], str)):
if isinstance(model, str) or (
isinstance(model, list) and isinstance(model[0], str)
):
if is_official_hub_path(model, revision=model_revision):
# read config file from hub and parse
cfg = read_config(
model, revision=model_revision) if isinstance(
model, str) else read_config(
model[0], revision=model_revision)
cfg = (
read_config(model, revision=model_revision)
if isinstance(model, str)
else read_config(model[0], revision=model_revision)
)
check_config(cfg)
pipeline_name = cfg.pipeline.type
else:
@@ -305,7 +398,7 @@ def pipeline(task: str = None,
elif model is not None:
# get pipeline info from Model object
first_model = model[0] if isinstance(model, list) else model
if not hasattr(first_model, 'pipeline'):
if not hasattr(first_model, "pipeline"):
# model is instantiated by user, we should parse config again
cfg = read_config(first_model.model_dir)
check_config(cfg)
@@ -327,11 +420,10 @@ def pipeline(task: str = None,
return build_pipeline(cfg, task_name=task)
def add_default_pipeline_info(task: str,
model_name: str,
modelhub_name: str = None,
overwrite: bool = False):
""" Add default model for a task.
def add_default_pipeline_info(
task: str, model_name: str, modelhub_name: str = None, overwrite: bool = False
):
"""Add default model for a task.
Args:
task (str): task name.
@@ -340,14 +432,15 @@ def add_default_pipeline_info(task: str,
overwrite (bool): overwrite default info.
"""
if not overwrite:
assert task not in DEFAULT_MODEL_FOR_PIPELINE, \
f'task {task} already has default model.'
assert (
task not in DEFAULT_MODEL_FOR_PIPELINE
), f"task {task} already has default model."
DEFAULT_MODEL_FOR_PIPELINE[task] = (model_name, modelhub_name)
def get_default_pipeline_info(task):
""" Get default info for certain task.
"""Get default info for certain task.
Args:
task (str): task name.
@@ -367,7 +460,7 @@ def get_default_pipeline_info(task):
def get_pipeline_by_model_name(task: str, model: Union[str, List[str]]):
""" Get pipeline name by task name and model name
"""Get pipeline name by task name and model name
Args:
task (str): task name.
@@ -376,7 +469,8 @@ def get_pipeline_by_model_name(task: str, model: Union[str, List[str]]):
if isinstance(model, str):
model_key = model
else:
model_key = '_'.join(model)
assert model_key in PIPELINES.modules[task], \
f'pipeline for task {task} model {model_key} not found.'
model_key = "_".join(model)
assert (
model_key in PIPELINES.modules[task]
), f"pipeline for task {task} model {model_key} not found."
return model_key

View File

@@ -47,92 +47,91 @@ if TYPE_CHECKING:
from .video_category_pipeline import VideoCategoryPipeline
from .virtual_try_on_pipeline import VirtualTryonPipeline
from .shop_segmentation_pipleline import ShopSegmentationPipeline
from .easycv_pipelines import (EasyCVDetectionPipeline,
EasyCVSegmentationPipeline,
Face2DKeypointsPipeline,
HumanWholebodyKeypointsPipeline)
from .easycv_pipelines import (
EasyCVDetectionPipeline,
EasyCVSegmentationPipeline,
Face2DKeypointsPipeline,
HumanWholebodyKeypointsPipeline,
)
from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipeline
from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline
from .mog_face_detection_pipeline import MogFaceDetectionPipeline
from .ulfd_face_detection_pipeline import UlfdFaceDetectionPipeline
from .retina_face_detection_pipeline import RetinaFaceDetectionPipeline
from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline
from .facial_expression_recognition_pipeline import (
FacialExpressionRecognitionPipeline,
)
from .mtcnn_face_detection_pipeline import MtcnnFaceDetectionPipelin
from .hand_static_pipeline import HandStaticPipeline
from .referring_video_object_segmentation_pipeline import ReferringVideoObjectSegmentationPipeline
from .referring_video_object_segmentation_pipeline import (
ReferringVideoObjectSegmentationPipeline,
)
else:
_import_structure = {
'action_recognition_pipeline': ['ActionRecognitionPipeline'],
'action_detection_pipeline': ['ActionDetectionPipeline'],
'animal_recognition_pipeline': ['AnimalRecognitionPipeline'],
'body_2d_keypoints_pipeline': ['Body2DKeypointsPipeline'],
'body_3d_keypoints_pipeline': ['Body3DKeypointsPipeline'],
'hand_2d_keypoints_pipeline': ['Hand2DKeypointsPipeline'],
'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'],
'hicossl_video_embedding_pipeline': ['HICOSSLVideoEmbeddingPipeline'],
'crowd_counting_pipeline': ['CrowdCountingPipeline'],
'image_detection_pipeline': ['ImageDetectionPipeline'],
'image_salient_detection_pipeline': ['ImageSalientDetectionPipeline'],
'face_detection_pipeline': ['FaceDetectionPipeline'],
'face_image_generation_pipeline': ['FaceImageGenerationPipeline'],
'face_recognition_pipeline': ['FaceRecognitionPipeline'],
'general_recognition_pipeline': ['GeneralRecognitionPipeline'],
'image_classification_pipeline':
['GeneralImageClassificationPipeline', 'ImageClassificationPipeline'],
'image_cartoon_pipeline': ['ImageCartoonPipeline'],
'image_denoise_pipeline': ['ImageDenoisePipeline'],
'image_color_enhance_pipeline': ['ImageColorEnhancePipeline'],
'image_colorization_pipeline': ['ImageColorizationPipeline'],
'image_instance_segmentation_pipeline':
['ImageInstanceSegmentationPipeline'],
'image_matting_pipeline': ['ImageMattingPipeline'],
'image_panoptic_segmentation_pipeline':
['ImagePanopticSegmentationPipeline'],
'image_portrait_enhancement_pipeline':
['ImagePortraitEnhancementPipeline'],
'image_reid_person_pipeline': ['ImageReidPersonPipeline'],
'image_semantic_segmentation_pipeline':
['ImageSemanticSegmentationPipeline'],
'image_style_transfer_pipeline': ['ImageStyleTransferPipeline'],
'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'],
'image_to_image_translation_pipeline':
['Image2ImageTranslationPipeline'],
'product_retrieval_embedding_pipeline':
['ProductRetrievalEmbeddingPipeline'],
'realtime_object_detection_pipeline':
['RealtimeObjectDetectionPipeline'],
'live_category_pipeline': ['LiveCategoryPipeline'],
'image_to_image_generation_pipeline':
['Image2ImageGenerationPipeline'],
'image_inpainting_pipeline': ['ImageInpaintingPipeline'],
'ocr_detection_pipeline': ['OCRDetectionPipeline'],
'ocr_recognition_pipeline': ['OCRRecognitionPipeline'],
'table_recognition_pipeline': ['TableRecognitionPipeline'],
'skin_retouching_pipeline': ['SkinRetouchingPipeline'],
'tinynas_classification_pipeline': ['TinynasClassificationPipeline'],
'video_category_pipeline': ['VideoCategoryPipeline'],
'virtual_try_on_pipeline': ['VirtualTryonPipeline'],
'shop_segmentation_pipleline': ['ShopSegmentationPipeline'],
'easycv_pipeline': [
'EasyCVDetectionPipeline',
'EasyCVSegmentationPipeline',
'Face2DKeypointsPipeline',
'HumanWholebodyKeypointsPipeline',
"action_recognition_pipeline": ["ActionRecognitionPipeline"],
"action_detection_pipeline": ["ActionDetectionPipeline"],
"animal_recognition_pipeline": ["AnimalRecognitionPipeline"],
"body_2d_keypoints_pipeline": ["Body2DKeypointsPipeline"],
"body_3d_keypoints_pipeline": ["Body3DKeypointsPipeline"],
"hand_2d_keypoints_pipeline": ["Hand2DKeypointsPipeline"],
"cmdssl_video_embedding_pipeline": ["CMDSSLVideoEmbeddingPipeline"],
"hicossl_video_embedding_pipeline": ["HICOSSLVideoEmbeddingPipeline"],
"crowd_counting_pipeline": ["CrowdCountingPipeline"],
"image_detection_pipeline": ["ImageDetectionPipeline"],
"image_salient_detection_pipeline": ["ImageSalientDetectionPipeline"],
"face_detection_pipeline": ["FaceDetectionPipeline"],
"face_image_generation_pipeline": ["FaceImageGenerationPipeline"],
"face_recognition_pipeline": ["FaceRecognitionPipeline"],
"general_recognition_pipeline": ["GeneralRecognitionPipeline"],
"image_classification_pipeline": [
"GeneralImageClassificationPipeline",
"ImageClassificationPipeline",
],
'text_driven_segmentation_pipeline':
['TextDrivenSegmentationPipeline'],
'movie_scene_segmentation_pipeline':
['MovieSceneSegmentationPipeline'],
'mog_face_detection_pipeline': ['MogFaceDetectionPipeline'],
'ulfd_face_detection_pipeline': ['UlfdFaceDetectionPipeline'],
'retina_face_detection_pipeline': ['RetinaFaceDetectionPipeline'],
'facial_expression_recognition_pipelin':
['FacialExpressionRecognitionPipeline'],
'mtcnn_face_detection_pipeline': ['MtcnnFaceDetectionPipeline'],
'hand_static_pipeline': ['HandStaticPipeline'],
'referring_video_object_segmentation_pipeline': [
'ReferringVideoObjectSegmentationPipeline'
"image_cartoon_pipeline": ["ImageCartoonPipeline"],
"image_denoise_pipeline": ["ImageDenoisePipeline"],
"image_color_enhance_pipeline": ["ImageColorEnhancePipeline"],
"image_colorization_pipeline": ["ImageColorizationPipeline"],
"image_instance_segmentation_pipeline": ["ImageInstanceSegmentationPipeline"],
"image_matting_pipeline": ["ImageMattingPipeline"],
"image_panoptic_segmentation_pipeline": ["ImagePanopticSegmentationPipeline"],
"image_portrait_enhancement_pipeline": ["ImagePortraitEnhancementPipeline"],
"image_reid_person_pipeline": ["ImageReidPersonPipeline"],
"image_semantic_segmentation_pipeline": ["ImageSemanticSegmentationPipeline"],
"image_style_transfer_pipeline": ["ImageStyleTransferPipeline"],
"image_super_resolution_pipeline": ["ImageSuperResolutionPipeline"],
"image_to_image_translation_pipeline": ["Image2ImageTranslationPipeline"],
"product_retrieval_embedding_pipeline": ["ProductRetrievalEmbeddingPipeline"],
"realtime_object_detection_pipeline": ["RealtimeObjectDetectionPipeline"],
"live_category_pipeline": ["LiveCategoryPipeline"],
"image_to_image_generation_pipeline": ["Image2ImageGenerationPipeline"],
"image_inpainting_pipeline": ["ImageInpaintingPipeline"],
"ocr_detection_pipeline": ["OCRDetectionPipeline"],
"ocr_recognition_pipeline": ["OCRRecognitionPipeline"],
"table_recognition_pipeline": ["TableRecognitionPipeline"],
"skin_retouching_pipeline": ["SkinRetouchingPipeline"],
"tinynas_classification_pipeline": ["TinynasClassificationPipeline"],
"video_category_pipeline": ["VideoCategoryPipeline"],
"virtual_try_on_pipeline": ["VirtualTryonPipeline"],
"shop_segmentation_pipleline": ["ShopSegmentationPipeline"],
"easycv_pipeline": [
"EasyCVDetectionPipeline",
"EasyCVSegmentationPipeline",
"Face2DKeypointsPipeline",
"HumanWholebodyKeypointsPipeline",
],
"text_driven_segmentation_pipeline": ["TextDrivenSegmentationPipeline"],
"movie_scene_segmentation_pipeline": ["MovieSceneSegmentationPipeline"],
"mog_face_detection_pipeline": ["MogFaceDetectionPipeline"],
"ulfd_face_detection_pipeline": ["UlfdFaceDetectionPipeline"],
"retina_face_detection_pipeline": ["RetinaFaceDetectionPipeline"],
"facial_expression_recognition_pipelin": [
"FacialExpressionRecognitionPipeline"
],
"mtcnn_face_detection_pipeline": ["MtcnnFaceDetectionPipeline"],
"hand_static_pipeline": ["HandStaticPipeline"],
"referring_video_object_segmentation_pipeline": [
"ReferringVideoObjectSegmentationPipeline"
],
}
@@ -140,7 +139,7 @@ else:
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
globals()["__file__"],
_import_structure,
module_spec=__spec__,
extra_objects={},

View File

@@ -15,20 +15,34 @@ import numpy as np
BatchNorm = nn.BatchNorm2d
def get_model_url(data='imagenet', name='dla34', hash='ba72cf86'):
return join('http://dl.yf.io/dla/models', data, '{}-{}.pth'.format(name, hash))
def get_model_url(data="imagenet", name="dla34", hash="ba72cf86"):
return join("http://dl.yf.io/dla/models", data, "{}-{}.pth".format(name, hash))
class BasicBlock(nn.Module):
def __init__(self, inplanes, planes, stride=1, dilation=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3,
stride=stride, padding=dilation,
bias=False, dilation=dilation)
self.conv1 = nn.Conv2d(
inplanes,
planes,
kernel_size=3,
stride=stride,
padding=dilation,
bias=False,
dilation=dilation,
)
self.bn1 = BatchNorm(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
stride=1, padding=dilation,
bias=False, dilation=dilation)
self.conv2 = nn.Conv2d(
planes,
planes,
kernel_size=3,
stride=1,
padding=dilation,
bias=False,
dilation=dilation,
)
self.bn2 = BatchNorm(planes)
self.stride = stride
@@ -56,15 +70,19 @@ class Bottleneck(nn.Module):
super(Bottleneck, self).__init__()
expansion = Bottleneck.expansion
bottle_planes = planes // expansion
self.conv1 = nn.Conv2d(inplanes, bottle_planes,
kernel_size=1, bias=False)
self.conv1 = nn.Conv2d(inplanes, bottle_planes, kernel_size=1, bias=False)
self.bn1 = BatchNorm(bottle_planes)
self.conv2 = nn.Conv2d(bottle_planes, bottle_planes, kernel_size=3,
stride=stride, padding=dilation,
bias=False, dilation=dilation)
self.conv2 = nn.Conv2d(
bottle_planes,
bottle_planes,
kernel_size=3,
stride=stride,
padding=dilation,
bias=False,
dilation=dilation,
)
self.bn2 = BatchNorm(bottle_planes)
self.conv3 = nn.Conv2d(bottle_planes, planes,
kernel_size=1, bias=False)
self.conv3 = nn.Conv2d(bottle_planes, planes, kernel_size=1, bias=False)
self.bn3 = BatchNorm(planes)
self.relu = nn.ReLU(inplace=True)
self.stride = stride
@@ -100,15 +118,20 @@ class BottleneckX(nn.Module):
# dim = int(math.floor(planes * (BottleneckV5.expansion / 64.0)))
# bottle_planes = dim * cardinality
bottle_planes = planes * cardinality // 32
self.conv1 = nn.Conv2d(inplanes, bottle_planes,
kernel_size=1, bias=False)
self.conv1 = nn.Conv2d(inplanes, bottle_planes, kernel_size=1, bias=False)
self.bn1 = BatchNorm(bottle_planes)
self.conv2 = nn.Conv2d(bottle_planes, bottle_planes, kernel_size=3,
stride=stride, padding=dilation, bias=False,
dilation=dilation, groups=cardinality)
self.conv2 = nn.Conv2d(
bottle_planes,
bottle_planes,
kernel_size=3,
stride=stride,
padding=dilation,
bias=False,
dilation=dilation,
groups=cardinality,
)
self.bn2 = BatchNorm(bottle_planes)
self.conv3 = nn.Conv2d(bottle_planes, planes,
kernel_size=1, bias=False)
self.conv3 = nn.Conv2d(bottle_planes, planes, kernel_size=1, bias=False)
self.bn3 = BatchNorm(planes)
self.relu = nn.ReLU(inplace=True)
self.stride = stride
@@ -138,8 +161,13 @@ class Root(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, residual):
super(Root, self).__init__()
self.conv = nn.Conv2d(
in_channels, out_channels, 1,
stride=1, bias=False, padding=(kernel_size - 1) // 2)
in_channels,
out_channels,
1,
stride=1,
bias=False,
padding=(kernel_size - 1) // 2,
)
self.bn = BatchNorm(out_channels)
self.relu = nn.ReLU(inplace=True)
self.residual = residual
@@ -156,31 +184,51 @@ class Root(nn.Module):
class Tree(nn.Module):
def __init__(self, levels, block, in_channels, out_channels, stride=1,
level_root=False, root_dim=0, root_kernel_size=1,
dilation=1, root_residual=False):
def __init__(
self,
levels,
block,
in_channels,
out_channels,
stride=1,
level_root=False,
root_dim=0,
root_kernel_size=1,
dilation=1,
root_residual=False,
):
super(Tree, self).__init__()
if root_dim == 0:
root_dim = 2 * out_channels
if level_root:
root_dim += in_channels
if levels == 1:
self.tree1 = block(in_channels, out_channels, stride,
dilation=dilation)
self.tree2 = block(out_channels, out_channels, 1,
dilation=dilation)
self.tree1 = block(in_channels, out_channels, stride, dilation=dilation)
self.tree2 = block(out_channels, out_channels, 1, dilation=dilation)
else:
self.tree1 = Tree(levels - 1, block, in_channels, out_channels,
stride, root_dim=0,
root_kernel_size=root_kernel_size,
dilation=dilation, root_residual=root_residual)
self.tree2 = Tree(levels - 1, block, out_channels, out_channels,
root_dim=root_dim + out_channels,
root_kernel_size=root_kernel_size,
dilation=dilation, root_residual=root_residual)
self.tree1 = Tree(
levels - 1,
block,
in_channels,
out_channels,
stride,
root_dim=0,
root_kernel_size=root_kernel_size,
dilation=dilation,
root_residual=root_residual,
)
self.tree2 = Tree(
levels - 1,
block,
out_channels,
out_channels,
root_dim=root_dim + out_channels,
root_kernel_size=root_kernel_size,
dilation=dilation,
root_residual=root_residual,
)
if levels == 1:
self.root = Root(root_dim, out_channels, root_kernel_size,
root_residual)
self.root = Root(root_dim, out_channels, root_kernel_size, root_residual)
self.level_root = level_root
self.root_dim = root_dim
self.downsample = None
@@ -190,9 +238,10 @@ class Tree(nn.Module):
self.downsample = nn.MaxPool2d(stride, stride=stride)
if in_channels != out_channels:
self.project = nn.Sequential(
nn.Conv2d(in_channels, out_channels,
kernel_size=1, stride=1, bias=False),
BatchNorm(out_channels)
nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, bias=False
),
BatchNorm(out_channels),
)
def forward(self, x, residual=None, children=None):
@@ -212,40 +261,76 @@ class Tree(nn.Module):
class DLA(nn.Module):
def __init__(self, levels, channels, num_classes=1000,
block=BasicBlock, residual_root=False, return_levels=False,
pool_size=7, linear_root=False):
def __init__(
self,
levels,
channels,
num_classes=1000,
block=BasicBlock,
residual_root=False,
return_levels=False,
pool_size=7,
linear_root=False,
):
super(DLA, self).__init__()
self.channels = channels
self.return_levels = return_levels
self.num_classes = num_classes
self.base_layer = nn.Sequential(
nn.Conv2d(3, channels[0], kernel_size=7, stride=1,
padding=3, bias=False),
nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, bias=False),
BatchNorm(channels[0]),
nn.ReLU(inplace=True))
self.level0 = self._make_conv_level(
channels[0], channels[0], levels[0])
nn.ReLU(inplace=True),
)
self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
self.level1 = self._make_conv_level(
channels[0], channels[1], levels[1], stride=2)
self.level2 = Tree(levels[2], block, channels[1], channels[2], 2,
level_root=False,
root_residual=residual_root)
self.level3 = Tree(levels[3], block, channels[2], channels[3], 2,
level_root=True, root_residual=residual_root)
self.level4 = Tree(levels[4], block, channels[3], channels[4], 2,
level_root=True, root_residual=residual_root)
self.level5 = Tree(levels[5], block, channels[4], channels[5], 2,
level_root=True, root_residual=residual_root)
channels[0], channels[1], levels[1], stride=2
)
self.level2 = Tree(
levels[2],
block,
channels[1],
channels[2],
2,
level_root=False,
root_residual=residual_root,
)
self.level3 = Tree(
levels[3],
block,
channels[2],
channels[3],
2,
level_root=True,
root_residual=residual_root,
)
self.level4 = Tree(
levels[4],
block,
channels[3],
channels[4],
2,
level_root=True,
root_residual=residual_root,
)
self.level5 = Tree(
levels[5],
block,
channels[4],
channels[5],
2,
level_root=True,
root_residual=residual_root,
)
self.avgpool = nn.AvgPool2d(pool_size)
self.fc = nn.Conv2d(channels[-1], num_classes, kernel_size=1,
stride=1, padding=0, bias=True)
self.fc = nn.Conv2d(
channels[-1], num_classes, kernel_size=1, stride=1, padding=0, bias=True
)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
m.weight.data.normal_(0, math.sqrt(2.0 / n))
elif isinstance(m, BatchNorm):
m.weight.data.fill_(1)
m.bias.data.zero_()
@@ -255,8 +340,7 @@ class DLA(nn.Module):
if stride != 1 or inplanes != planes:
downsample = nn.Sequential(
nn.MaxPool2d(stride, stride=stride),
nn.Conv2d(inplanes, planes,
kernel_size=1, stride=1, bias=False),
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False),
BatchNorm(planes),
)
@@ -270,12 +354,21 @@ class DLA(nn.Module):
def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1):
modules = []
for i in range(convs):
modules.extend([
nn.Conv2d(inplanes, planes, kernel_size=3,
stride=stride if i == 0 else 1,
padding=dilation, bias=False, dilation=dilation),
BatchNorm(planes),
nn.ReLU(inplace=True)])
modules.extend(
[
nn.Conv2d(
inplanes,
planes,
kernel_size=3,
stride=stride if i == 0 else 1,
padding=dilation,
bias=False,
dilation=dilation,
),
BatchNorm(planes),
nn.ReLU(inplace=True),
]
)
inplanes = planes
return nn.Sequential(*modules)
@@ -283,7 +376,7 @@ class DLA(nn.Module):
y = []
x = self.base_layer(x)
for i in range(6):
x = getattr(self, 'level{}'.format(i))(x)
x = getattr(self, "level{}".format(i))(x)
y.append(x)
if self.return_levels:
return y
@@ -294,113 +387,138 @@ class DLA(nn.Module):
return x
def load_pretrained_model(self, data='imagenet', name='dla34', hash='ba72cf86'):
def load_pretrained_model(self, data="imagenet", name="dla34", hash="ba72cf86"):
fc = self.fc
if name.endswith('.pth'):
if name.endswith(".pth"):
model_weights = torch.load(data + name)
else:
model_url = get_model_url(data, name, hash)
model_weights = model_zoo.load_url(model_url)
num_classes = len(model_weights[list(model_weights.keys())[-1]])
self.fc = nn.Conv2d(
self.channels[-1], num_classes,
kernel_size=1, stride=1, padding=0, bias=True)
self.channels[-1],
num_classes,
kernel_size=1,
stride=1,
padding=0,
bias=True,
)
self.load_state_dict(model_weights)
self.fc = fc
def dla34(pretrained, **kwargs): # DLA-34
model = DLA([1, 1, 1, 2, 2, 1],
[16, 32, 64, 128, 256, 512],
block=BasicBlock, **kwargs)
model = DLA(
[1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=BasicBlock, **kwargs
)
if pretrained:
model.load_pretrained_model(data='imagenet', name='dla34', hash='ba72cf86')
model.load_pretrained_model(data="imagenet", name="dla34", hash="ba72cf86")
return model
def dla46_c(pretrained=None, **kwargs): # DLA-46-C
Bottleneck.expansion = 2
model = DLA([1, 1, 1, 2, 2, 1],
[16, 32, 64, 64, 128, 256],
block=Bottleneck, **kwargs)
model = DLA(
[1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256], block=Bottleneck, **kwargs
)
if pretrained is not None:
model.load_pretrained_model(pretrained, 'dla46_c')
model.load_pretrained_model(pretrained, "dla46_c")
return model
def dla46x_c(pretrained=None, **kwargs): # DLA-X-46-C
BottleneckX.expansion = 2
model = DLA([1, 1, 1, 2, 2, 1],
[16, 32, 64, 64, 128, 256],
block=BottleneckX, **kwargs)
model = DLA(
[1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256], block=BottleneckX, **kwargs
)
if pretrained is not None:
model.load_pretrained_model(pretrained, 'dla46x_c')
model.load_pretrained_model(pretrained, "dla46x_c")
return model
def dla60x_c(pretrained, **kwargs): # DLA-X-60-C
BottleneckX.expansion = 2
model = DLA([1, 1, 1, 2, 3, 1],
[16, 32, 64, 64, 128, 256],
block=BottleneckX, **kwargs)
model = DLA(
[1, 1, 1, 2, 3, 1], [16, 32, 64, 64, 128, 256], block=BottleneckX, **kwargs
)
if pretrained:
model.load_pretrained_model(data='imagenet', name='dla60x_c', hash='b870c45c')
model.load_pretrained_model(data="imagenet", name="dla60x_c", hash="b870c45c")
return model
def dla60(pretrained=None, **kwargs): # DLA-60
Bottleneck.expansion = 2
model = DLA([1, 1, 1, 2, 3, 1],
[16, 32, 128, 256, 512, 1024],
block=Bottleneck, **kwargs)
model = DLA(
[1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], block=Bottleneck, **kwargs
)
if pretrained is not None:
model.load_pretrained_model(pretrained, 'dla60')
model.load_pretrained_model(pretrained, "dla60")
return model
def dla60x(pretrained=None, **kwargs): # DLA-X-60
BottleneckX.expansion = 2
model = DLA([1, 1, 1, 2, 3, 1],
[16, 32, 128, 256, 512, 1024],
block=BottleneckX, **kwargs)
model = DLA(
[1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], block=BottleneckX, **kwargs
)
if pretrained is not None:
model.load_pretrained_model(pretrained, 'dla60x')
model.load_pretrained_model(pretrained, "dla60x")
return model
def dla102(pretrained=None, **kwargs): # DLA-102
Bottleneck.expansion = 2
model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024],
block=Bottleneck, residual_root=True, **kwargs)
model = DLA(
[1, 1, 1, 3, 4, 1],
[16, 32, 128, 256, 512, 1024],
block=Bottleneck,
residual_root=True,
**kwargs
)
if pretrained is not None:
model.load_pretrained_model(pretrained, 'dla102')
model.load_pretrained_model(pretrained, "dla102")
return model
def dla102x(pretrained=None, **kwargs): # DLA-X-102
BottleneckX.expansion = 2
model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024],
block=BottleneckX, residual_root=True, **kwargs)
model = DLA(
[1, 1, 1, 3, 4, 1],
[16, 32, 128, 256, 512, 1024],
block=BottleneckX,
residual_root=True,
**kwargs
)
if pretrained is not None:
model.load_pretrained_model(pretrained, 'dla102x')
model.load_pretrained_model(pretrained, "dla102x")
return model
def dla102x2(pretrained=None, **kwargs): # DLA-X-102 64
BottleneckX.cardinality = 64
model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024],
block=BottleneckX, residual_root=True, **kwargs)
model = DLA(
[1, 1, 1, 3, 4, 1],
[16, 32, 128, 256, 512, 1024],
block=BottleneckX,
residual_root=True,
**kwargs
)
if pretrained is not None:
model.load_pretrained_model(pretrained, 'dla102x2')
model.load_pretrained_model(pretrained, "dla102x2")
return model
def dla169(pretrained=None, **kwargs): # DLA-169
Bottleneck.expansion = 2
model = DLA([1, 1, 2, 3, 5, 1], [16, 32, 128, 256, 512, 1024],
block=Bottleneck, residual_root=True, **kwargs)
model = DLA(
[1, 1, 2, 3, 5, 1],
[16, 32, 128, 256, 512, 1024],
block=Bottleneck,
residual_root=True,
**kwargs
)
if pretrained is not None:
model.load_pretrained_model(pretrained, 'dla169')
model.load_pretrained_model(pretrained, "dla169")
return model
@@ -421,11 +539,10 @@ class Identity(nn.Module):
def fill_up_weights(up):
w = up.weight.data
f = math.ceil(w.size(2) / 2)
c = (2 * f - 1 - f % 2) / (2. * f)
c = (2 * f - 1 - f % 2) / (2.0 * f)
for i in range(w.size(2)):
for j in range(w.size(3)):
w[0, 0, i, j] = \
(1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
w[0, 0, i, j] = (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
for c in range(1, w.size(0)):
w[c, 0, :, :] = w[0, 0, :, :]
@@ -440,50 +557,64 @@ class IDAUp(nn.Module):
proj = Identity()
else:
proj = nn.Sequential(
nn.Conv2d(c, out_dim,
kernel_size=1, stride=1, bias=False),
nn.Conv2d(c, out_dim, kernel_size=1, stride=1, bias=False),
BatchNorm(out_dim),
nn.ReLU(inplace=True))
nn.ReLU(inplace=True),
)
f = int(up_factors[i])
if f == 1:
up = Identity()
else:
up = nn.ConvTranspose2d(
out_dim, out_dim, f * 2, stride=f, padding=f // 2,
output_padding=0, groups=out_dim, bias=False)
out_dim,
out_dim,
f * 2,
stride=f,
padding=f // 2,
output_padding=0,
groups=out_dim,
bias=False,
)
fill_up_weights(up)
setattr(self, 'proj_' + str(i), proj)
setattr(self, 'up_' + str(i), up)
setattr(self, "proj_" + str(i), proj)
setattr(self, "up_" + str(i), up)
for i in range(1, len(channels)):
node = nn.Sequential(
nn.Conv2d(out_dim * 2, out_dim,
kernel_size=node_kernel, stride=1,
padding=node_kernel // 2, bias=False),
nn.Conv2d(
out_dim * 2,
out_dim,
kernel_size=node_kernel,
stride=1,
padding=node_kernel // 2,
bias=False,
),
BatchNorm(out_dim),
nn.ReLU(inplace=True))
setattr(self, 'node_' + str(i), node)
nn.ReLU(inplace=True),
)
setattr(self, "node_" + str(i), node)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
m.weight.data.normal_(0, math.sqrt(2.0 / n))
elif isinstance(m, BatchNorm):
m.weight.data.fill_(1)
m.bias.data.zero_()
def forward(self, layers):
assert len(self.channels) == len(layers), \
'{} vs {} layers'.format(len(self.channels), len(layers))
assert len(self.channels) == len(layers), "{} vs {} layers".format(
len(self.channels), len(layers)
)
layers = list(layers)
for i, l in enumerate(layers):
upsample = getattr(self, 'up_' + str(i))
project = getattr(self, 'proj_' + str(i))
upsample = getattr(self, "up_" + str(i))
project = getattr(self, "proj_" + str(i))
layers[i] = upsample(project(l))
x = layers[0]
y = []
for i in range(1, len(layers)):
node = getattr(self, 'node_' + str(i))
node = getattr(self, "node_" + str(i))
x = node(torch.cat([x, layers[i]], 1))
y.append(x)
return x, y
@@ -499,21 +630,24 @@ class DLAUp(nn.Module):
scales = np.array(scales, dtype=int)
for i in range(len(channels) - 1):
j = -i - 2
setattr(self, 'ida_{}'.format(i),
IDAUp(3, channels[j], in_channels[j:],
scales[j:] // scales[j]))
scales[j + 1:] = scales[j]
in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]]
setattr(
self,
"ida_{}".format(i),
IDAUp(3, channels[j], in_channels[j:], scales[j:] // scales[j]),
)
scales[j + 1 :] = scales[j]
in_channels[j + 1 :] = [channels[j] for _ in channels[j + 1 :]]
def forward(self, layers):
layers = list(layers)
assert len(layers) > 1
for i in range(len(layers) - 1):
ida = getattr(self, 'ida_{}'.format(i))
x, y = ida(layers[-i - 2:])
layers[-i - 1:] = y
ida = getattr(self, "ida_{}".format(i))
x, y = ida(layers[-i - 2 :])
layers[-i - 1 :] = y
return x
def fill_fc_weights(layers):
for m in layers.modules():
if isinstance(m, nn.Conv2d):
@@ -523,38 +657,55 @@ def fill_fc_weights(layers):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
class DLASeg(nn.Module):
def __init__(self, base_name='dla34',
pretrained=False, down_ratio=4, head_conv=256):
def __init__(
self, base_name="dla34", pretrained=False, down_ratio=4, head_conv=256
):
super(DLASeg, self).__init__()
assert down_ratio in [2, 4, 8, 16]
self.heads = {'hm': 2,'v2c':8, 'c2v': 8, 'reg': 2}
self.heads = {"hm": 2, "v2c": 8, "c2v": 8, "reg": 2}
self.first_level = int(np.log2(down_ratio))
self.base = globals()[base_name](
pretrained=pretrained, return_levels=True)
self.base = globals()[base_name](pretrained=pretrained, return_levels=True)
channels = self.base.channels
scales = [2 ** i for i in range(len(channels[self.first_level:]))]
self.dla_up = DLAUp(channels[self.first_level:], scales=scales)
scales = [2**i for i in range(len(channels[self.first_level :]))]
self.dla_up = DLAUp(channels[self.first_level :], scales=scales)
for head in self.heads:
classes = self.heads[head]
if head_conv > 0:
fc = nn.Sequential(
nn.Conv2d(channels[self.first_level], head_conv,
kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(head_conv, classes,
kernel_size=1, stride=1,
padding=0, bias=True))
if 'hm' in head:
nn.Conv2d(
channels[self.first_level],
head_conv,
kernel_size=3,
padding=1,
bias=True,
),
nn.ReLU(inplace=True),
nn.Conv2d(
head_conv,
classes,
kernel_size=1,
stride=1,
padding=0,
bias=True,
),
)
if "hm" in head:
fc[-1].bias.data.fill_(-2.19)
else:
fill_fc_weights(fc)
else:
fc = nn.Conv2d(channels[self.first_level], classes,
kernel_size=1, stride=1,
padding=0, bias=True)
if 'hm' in head:
fc = nn.Conv2d(
channels[self.first_level],
classes,
kernel_size=1,
stride=1,
padding=0,
bias=True,
)
if "hm" in head:
fc.bias.data.fill_(-2.19)
else:
fill_fc_weights(fc)
@@ -562,7 +713,7 @@ class DLASeg(nn.Module):
def forward(self, x):
x = self.base(x)
x = self.dla_up(x[self.first_level:])
x = self.dla_up(x[self.first_level :])
ret = {}
for head in self.heads:
ret[head] = self.__getattr__(head)(x)
@@ -570,5 +721,5 @@ class DLASeg(nn.Module):
def TableRecModel():
model = DLASeg()
return model
model = DLASeg()
return model

View File

@@ -1,11 +1,12 @@
import numpy as np
import cv2
import cv2
import copy
import math
import random
import torch
import torch.nn as nn
def transform_preds(coords, center, scale, output_size, rot=0):
target_coords = np.zeros(coords.shape)
trans = get_affine_transform(center, scale, rot, output_size, inv=1)
@@ -13,12 +14,10 @@ def transform_preds(coords, center, scale, output_size, rot=0):
target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
return target_coords
def get_affine_transform(center,
scale,
rot,
output_size,
shift=np.array([0, 0], dtype=np.float32),
inv=0):
def get_affine_transform(
center, scale, rot, output_size, shift=np.array([0, 0], dtype=np.float32), inv=0
):
if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
scale = np.array([scale, scale], dtype=np.float32)
@@ -27,7 +26,7 @@ def get_affine_transform(center,
dst_w = output_size[0]
dst_h = output_size[1]
rot_rad = np.pi * rot / 180
rot_rad = np.pi * rot / 180
src_dir = get_dir([0, src_w * -0.5], rot_rad)
dst_dir = np.array([0, dst_w * -0.5], np.float32)
@@ -38,8 +37,8 @@ def get_affine_transform(center,
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir
src[2:, :] = get_3rd_point(src[0, :], src[1, :])
dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
src[2:, :] = get_3rd_point(src[0, :], src[1, :])
dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
@@ -48,11 +47,13 @@ def get_affine_transform(center,
return trans
def affine_transform(pt, t):
new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32).T
def affine_transform(pt, t):
new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32).T
new_pt = np.dot(t, new_pt)
return new_pt[:2]
def get_dir(src_point, rot_rad):
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
@@ -62,17 +63,20 @@ def get_dir(src_point, rot_rad):
return src_result
def get_3rd_point(a, b):
direct = a - b
return b + np.array([-direct[1], direct[0]], dtype=np.float32)
def _sigmoid(x):
y = torch.clamp(x.sigmoid_(), min=1e-4, max=1-1e-4)
return y
y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4)
return y
def _gather_feat(feat, ind, mask=None):
dim = feat.size(2)
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
dim = feat.size(2)
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
feat = feat.gather(1, ind)
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
@@ -80,19 +84,21 @@ def _gather_feat(feat, ind, mask=None):
feat = feat.view(-1, dim)
return feat
def _tranpose_and_gather_feat(feat, ind):
feat = feat.permute(0, 2, 3, 1).contiguous()
feat = feat.view(feat.size(0), -1, feat.size(3))
feat = _gather_feat(feat, ind)
return feat
def _nms(heat, kernel=3):
pad = (kernel - 1) // 2
hmax = nn.functional.max_pool2d(
heat, (kernel, kernel), stride=1, padding=pad)
hmax = nn.functional.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
keep = (hmax == heat).float()
return heat * keep,keep
return heat * keep, keep
def _topk(scores, K=40):
batch, cat, height, width = scores.size()
@@ -100,13 +106,12 @@ def _topk(scores, K=40):
topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)
topk_inds = topk_inds % (height * width)
topk_ys = (topk_inds / width).int().float()
topk_xs = (topk_inds % width).int().float()
topk_ys = (topk_inds / width).int().float()
topk_xs = (topk_inds % width).int().float()
topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
topk_clses = (topk_ind / K).int()
topk_inds = _gather_feat(
topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
topk_inds = _gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)
@@ -118,37 +123,43 @@ def bbox_decode(heat, wh, reg=None, K=100):
# heat = torch.sigmoid(heat)
# perform nms on heatmaps
heat,keep = _nms(heat)
heat, keep = _nms(heat)
scores, inds, clses, ys, xs = _topk(heat, K=K)
if reg is not None:
reg = _tranpose_and_gather_feat(reg, inds)
reg = reg.view(batch, K, 2)
xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
reg = _tranpose_and_gather_feat(reg, inds)
reg = reg.view(batch, K, 2)
xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
else:
xs = xs.view(batch, K, 1) + 0.5
ys = ys.view(batch, K, 1) + 0.5
xs = xs.view(batch, K, 1) + 0.5
ys = ys.view(batch, K, 1) + 0.5
wh = _tranpose_and_gather_feat(wh, inds)
wh = wh.view(batch, K, 8)
clses = clses.view(batch, K, 1).float()
clses = clses.view(batch, K, 1).float()
scores = scores.view(batch, K, 1)
bboxes = torch.cat([xs - wh[..., 0:1],
ys - wh[..., 1:2],
xs - wh[..., 2:3],
ys - wh[..., 3:4],
xs - wh[..., 4:5],
ys - wh[..., 5:6],
xs - wh[..., 6:7],
ys - wh[..., 7:8]], dim=2)
bboxes = torch.cat(
[
xs - wh[..., 0:1],
ys - wh[..., 1:2],
xs - wh[..., 2:3],
ys - wh[..., 3:4],
xs - wh[..., 4:5],
ys - wh[..., 5:6],
xs - wh[..., 6:7],
ys - wh[..., 7:8],
],
dim=2,
)
detections = torch.cat([bboxes, scores, clses], dim=2)
return detections,keep
return detections, keep
def gbox_decode(mk, st_reg, reg=None, K=400):
batch, cat, height, width = mk.size()
mk,keep = _nms(mk)
mk, keep = _nms(mk)
scores, inds, clses, ys, xs = _topk(mk, K=K)
if reg is not None:
reg = _tranpose_and_gather_feat(reg, inds)
@@ -159,17 +170,23 @@ def gbox_decode(mk, st_reg, reg=None, K=400):
xs = xs.view(batch, K, 1) + 0.5
ys = ys.view(batch, K, 1) + 0.5
scores = scores.view(batch, K, 1)
clses = clses.view(batch, K, 1).float()
clses = clses.view(batch, K, 1).float()
st_Reg = _tranpose_and_gather_feat(st_reg, inds)
bboxes = torch.cat([xs - st_Reg[..., 0:1],
ys - st_Reg[..., 1:2],
xs - st_Reg[..., 2:3],
ys - st_Reg[..., 3:4],
xs - st_Reg[..., 4:5],
ys - st_Reg[..., 5:6],
xs - st_Reg[..., 6:7],
ys - st_Reg[..., 7:8]], dim=2)
return torch.cat([xs,ys,bboxes,scores,clses], dim=2), keep
bboxes = torch.cat(
[
xs - st_Reg[..., 0:1],
ys - st_Reg[..., 1:2],
xs - st_Reg[..., 2:3],
ys - st_Reg[..., 3:4],
xs - st_Reg[..., 4:5],
ys - st_Reg[..., 5:6],
xs - st_Reg[..., 6:7],
ys - st_Reg[..., 7:8],
],
dim=2,
)
return torch.cat([xs, ys, bboxes, scores, clses], dim=2), keep
def bbox_post_process(bbox, c, s, h, w):
# dets: batch x max_dets x dim
@@ -179,102 +196,116 @@ def bbox_post_process(bbox, c, s, h, w):
bbox[i, :, 2:4] = transform_preds(bbox[i, :, 2:4], c[i], s[i], (w, h))
bbox[i, :, 4:6] = transform_preds(bbox[i, :, 4:6], c[i], s[i], (w, h))
bbox[i, :, 6:8] = transform_preds(bbox[i, :, 6:8], c[i], s[i], (w, h))
return bbox
return bbox
def gbox_post_process(gbox, c, s, h, w):
for i in range(gbox.shape[0]):
gbox[i, :, 0:2] = transform_preds(gbox[i, :, 0:2], c[i], s[i], (w, h))
gbox[i, :, 2:4] = transform_preds(gbox[i, :, 2:4], c[i], s[i], (w, h))
gbox[i, :, 4:6] = transform_preds(gbox[i, :, 4:6], c[i], s[i], (w, h))
gbox[i, :, 6:8] = transform_preds(gbox[i, :, 6:8], c[i], s[i], (w, h))
gbox[i, :, 8:10] = transform_preds(gbox[i, :, 8:10], c[i], s[i], (w, h))
gbox[i, :, 0:2] = transform_preds(gbox[i, :, 0:2], c[i], s[i], (w, h))
gbox[i, :, 2:4] = transform_preds(gbox[i, :, 2:4], c[i], s[i], (w, h))
gbox[i, :, 4:6] = transform_preds(gbox[i, :, 4:6], c[i], s[i], (w, h))
gbox[i, :, 6:8] = transform_preds(gbox[i, :, 6:8], c[i], s[i], (w, h))
gbox[i, :, 8:10] = transform_preds(gbox[i, :, 8:10], c[i], s[i], (w, h))
return gbox
def nms(dets,thresh):
if len(dets)<2:
def nms(dets, thresh):
if len(dets) < 2:
return dets
scores = dets[:,8]
scores = dets[:, 8]
index_keep = []
keep = []
for i in range(len(dets)):
box = dets[i]
if box[-1]<thresh:
if box[-1] < thresh:
break
max_score_index = -1
ctx = (dets[i][0] + dets[i][2] + dets[i][4] + dets[i][6])/4
cty = (dets[i][1] + dets[i][3] + dets[i][5] + dets[i][7])/4
ctx = (dets[i][0] + dets[i][2] + dets[i][4] + dets[i][6]) / 4
cty = (dets[i][1] + dets[i][3] + dets[i][5] + dets[i][7]) / 4
for j in range(len(dets)):
if i==j or dets[j][-1]<thresh:
if i == j or dets[j][-1] < thresh:
break
x1,y1 = dets[j][0],dets[j][1]
x2,y2 = dets[j][2],dets[j][3]
x3,y3 = dets[j][4],dets[j][5]
x4,y4 = dets[j][6],dets[j][7]
a = (x2 - x1)*(cty - y1) - (y2 - y1)*(ctx - x1)
b = (x3 - x2)*(cty - y2) - (y3 - y2)*(ctx - x2)
c = (x4 - x3)*(cty - y3) - (y4 - y3)*(ctx - x3)
d = (x1 - x4)*(cty - y4) - (y1 - y4)*(ctx - x4)
if ((a > 0 and b > 0 and c > 0 and d > 0) or (a < 0 and b < 0 and c < 0 and d < 0)):
x1, y1 = dets[j][0], dets[j][1]
x2, y2 = dets[j][2], dets[j][3]
x3, y3 = dets[j][4], dets[j][5]
x4, y4 = dets[j][6], dets[j][7]
a = (x2 - x1) * (cty - y1) - (y2 - y1) * (ctx - x1)
b = (x3 - x2) * (cty - y2) - (y3 - y2) * (ctx - x2)
c = (x4 - x3) * (cty - y3) - (y4 - y3) * (ctx - x3)
d = (x1 - x4) * (cty - y4) - (y1 - y4) * (ctx - x4)
if (a > 0 and b > 0 and c > 0 and d > 0) or (
a < 0 and b < 0 and c < 0 and d < 0
):
if dets[i][8] > dets[j][8] and max_score_index < 0:
max_score_index = i
elif dets[i][8] < dets[j][8]:
max_score_index = i
elif dets[i][8] < dets[j][8]:
max_score_index = -2
break
if max_score_index > -1:
if max_score_index > -1:
index_keep.append(max_score_index)
elif max_score_index==-1:
elif max_score_index == -1:
index_keep.append(i)
for i in range(0,len(index_keep)):
for i in range(0, len(index_keep)):
keep.append(dets[index_keep[i]])
return np.array(keep)
def group_bbox_by_gbox(bboxes,gboxes,score_thred=0.3, v2c_dist_thred=2, c2v_dist_thred=0.5):
def point_in_box(box,point):
x1,y1,x2,y2 = box[0],box[1],box[2],box[3]
x3,y3,x4,y4 = box[4],box[5],box[6],box[7]
ctx,cty = point[0],point[1]
a = (x2 - x1)*(cty - y1) - (y2 - y1)*(ctx - x1)
b = (x3 - x2)*(cty - y2) - (y3 - y2)*(ctx - x2)
c = (x4 - x3)*(cty - y3) - (y4 - y3)*(ctx - x3)
d = (x1 - x4)*(cty - y4) - (y1 - y4)*(ctx - x4)
if ((a > 0 and b > 0 and c > 0 and d > 0) or (a < 0 and b < 0 and c < 0 and d < 0)):
def group_bbox_by_gbox(
bboxes, gboxes, score_thred=0.3, v2c_dist_thred=2, c2v_dist_thred=0.5
):
def point_in_box(box, point):
x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
x3, y3, x4, y4 = box[4], box[5], box[6], box[7]
ctx, cty = point[0], point[1]
a = (x2 - x1) * (cty - y1) - (y2 - y1) * (ctx - x1)
b = (x3 - x2) * (cty - y2) - (y3 - y2) * (ctx - x2)
c = (x4 - x3) * (cty - y3) - (y4 - y3) * (ctx - x3)
d = (x1 - x4) * (cty - y4) - (y1 - y4) * (ctx - x4)
if (a > 0 and b > 0 and c > 0 and d > 0) or (
a < 0 and b < 0 and c < 0 and d < 0
):
return True
else :
else:
return False
def get_distance(pt1,pt2):
return math.sqrt((pt1[0]-pt2[0])*(pt1[0]-pt2[0]) + (pt1[1]-pt2[1])*(pt1[1]-pt2[1]))
def get_distance(pt1, pt2):
return math.sqrt(
(pt1[0] - pt2[0]) * (pt1[0] - pt2[0])
+ (pt1[1] - pt2[1]) * (pt1[1] - pt2[1])
)
dets = copy.deepcopy(bboxes)
sign = np.zeros((len(dets),4))
sign = np.zeros((len(dets), 4))
for idx,gbox in enumerate(gboxes): #vertex x,y, gbox, score
for idx, gbox in enumerate(gboxes): # vertex x,y, gbox, score
if gbox[10] < score_thred:
break
vertex = [gbox[0],gbox[1]]
for i in range(0,4):
center = [gbox[2*i+2],gbox[2*i+3]]
if get_distance(vertex,center) < v2c_dist_thred:
vertex = [gbox[0], gbox[1]]
for i in range(0, 4):
center = [gbox[2 * i + 2], gbox[2 * i + 3]]
if get_distance(vertex, center) < v2c_dist_thred:
continue
for k,bbox in enumerate(dets):
for k, bbox in enumerate(dets):
if bbox[8] < score_thred:
break
if sum(sign[k])==4:
if sum(sign[k]) == 4:
continue
w = (abs(bbox[6] - bbox[0]) + abs(bbox[4] - bbox[2])) / 2
h = (abs(bbox[3] - bbox[1]) + abs(bbox[5] - bbox[7])) / 2
m = max(w,h)
if point_in_box(bbox,center):
min_dist,min_id = 1e4,-1
for j in range(0,4):
dist = get_distance(vertex,[bbox[2*j],bbox[2*j+1]])
w = (abs(bbox[6] - bbox[0]) + abs(bbox[4] - bbox[2])) / 2
h = (abs(bbox[3] - bbox[1]) + abs(bbox[5] - bbox[7])) / 2
m = max(w, h)
if point_in_box(bbox, center):
min_dist, min_id = 1e4, -1
for j in range(0, 4):
dist = get_distance(vertex, [bbox[2 * j], bbox[2 * j + 1]])
if dist < min_dist:
min_dist = dist
min_id = j
if min_id>-1 and min_dist<c2v_dist_thred*m and sign[k][min_id]==0:
bboxes[k][2*min_id] = vertex[0]
bboxes[k][2*min_id+1] = vertex[1]
if (
min_id > -1
and min_dist < c2v_dist_thred * m
and sign[k][min_id] == 0
):
bboxes[k][2 * min_id] = vertex[0]
bboxes[k][2 * min_id + 1] = vertex[1]
sign[k][min_id] = 1
return bboxes

View File

@@ -16,15 +16,25 @@ from modelscope.pipelines.cv.ocr_utils.model_dla34 import TableRecModel
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from modelscope.pipelines.cv.ocr_utils.table_process import get_affine_transform,bbox_decode,gbox_decode,nms
from modelscope.pipelines.cv.ocr_utils.table_process import bbox_post_process,gbox_post_process,group_bbox_by_gbox
from modelscope.pipelines.cv.ocr_utils.table_process import (
get_affine_transform,
bbox_decode,
gbox_decode,
nms,
)
from modelscope.pipelines.cv.ocr_utils.table_process import (
bbox_post_process,
gbox_post_process,
group_bbox_by_gbox,
)
logger = get_logger()
@PIPELINES.register_module(
Tasks.table_recognition, module_name=Pipelines.table_recognition)
class TableRecognitionPipeline(Pipeline):
@PIPELINES.register_module(
Tasks.table_recognition, module_name=Pipelines.table_recognition
)
class TableRecognitionPipeline(Pipeline):
def __init__(self, model: str, **kwargs):
"""
Args:
@@ -32,17 +42,16 @@ class TableRecognitionPipeline(Pipeline):
"""
super().__init__(model=model, **kwargs)
model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
logger.info(f'loading model from {model_path}')
logger.info(f"loading model from {model_path}")
self.K = 1000
self.MK = 4000
self.device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.infer_model = TableRecModel().to(self.device)
self.infer_model.eval()
checkpoint = torch.load(model_path, map_location=self.device)
if 'state_dict' in checkpoint:
self.infer_model.load_state_dict(checkpoint['state_dict'])
if "state_dict" in checkpoint:
self.infer_model.load_state_dict(checkpoint["state_dict"])
else:
self.infer_model.load_state_dict(checkpoint)
@@ -55,60 +64,77 @@ class TableRecognitionPipeline(Pipeline):
if len(input.shape) == 3:
img = input
else:
raise TypeError(f'input should be either str, PIL.Image,'
f' np.array, but got {type(input)}')
raise TypeError(
f"input should be either str, PIL.Image,"
f" np.array, but got {type(input)}"
)
mean = np.array([0.408, 0.447, 0.470], dtype=np.float32).reshape(1, 1, 3)
std = np.array([0.289, 0.274, 0.278], dtype=np.float32).reshape(1, 1, 3)
height, width = img.shape[0:2]
inp_height, inp_width = 1024, 1024
c = np.array([width / 2., height / 2.], dtype=np.float32)
s = max(height, width) * 1.0
c = np.array([width / 2.0, height / 2.0], dtype=np.float32)
s = max(height, width) * 1.0
trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height])
resized_image = cv2.resize(img, (width, height))
inp_image = cv2.warpAffine(
resized_image, trans_input, (inp_width, inp_height),
flags=cv2.INTER_LINEAR)
inp_image = ((inp_image / 255. - mean) / std).astype(np.float32)
resized_image, trans_input, (inp_width, inp_height), flags=cv2.INTER_LINEAR
)
inp_image = ((inp_image / 255.0 - mean) / std).astype(np.float32)
images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height, inp_width)
images = torch.from_numpy(images).to(self.device)
meta = {'c': c, 's': s,
'input_height':inp_height,
'input_width':inp_width,
'out_height': inp_height // 4,
'out_width': inp_width // 4}
meta = {
"c": c,
"s": s,
"input_height": inp_height,
"input_width": inp_width,
"out_height": inp_height // 4,
"out_width": inp_width // 4,
}
result = {'img': images, 'meta': meta}
result = {"img": images, "meta": meta}
return result
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
pred = self.infer_model(input['img'])
return {'results': pred, 'meta': input['meta']}
pred = self.infer_model(input["img"])
return {"results": pred, "meta": input["meta"]}
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
output = inputs['results'][0]
meta = inputs['meta']
hm = output['hm'].sigmoid_()
v2c = output['v2c']
c2v = output['c2v']
reg = output['reg']
bbox, _ = bbox_decode(hm[:,0:1,:,:], c2v, reg=reg, K=self.K)
gbox, _ = gbox_decode(hm[:,1:2,:,:], v2c, reg=reg, K=self.MK)
output = inputs["results"][0]
meta = inputs["meta"]
hm = output["hm"].sigmoid_()
v2c = output["v2c"]
c2v = output["c2v"]
reg = output["reg"]
bbox, _ = bbox_decode(hm[:, 0:1, :, :], c2v, reg=reg, K=self.K)
gbox, _ = gbox_decode(hm[:, 1:2, :, :], v2c, reg=reg, K=self.MK)
bbox = bbox.detach().cpu().numpy()
gbox = gbox.detach().cpu().numpy()
bbox = nms(bbox,0.3)
bbox = bbox_post_process(bbox.copy(),[meta['c'].cpu().numpy()],[meta['s']],meta['out_height'],meta['out_width'])
gbox = gbox_post_process(gbox.copy(),[meta['c'].cpu().numpy()],[meta['s']],meta['out_height'],meta['out_width'])
bbox = group_bbox_by_gbox(bbox[0],gbox[0])
bbox = nms(bbox, 0.3)
bbox = bbox_post_process(
bbox.copy(),
[meta["c"].cpu().numpy()],
[meta["s"]],
meta["out_height"],
meta["out_width"],
)
gbox = gbox_post_process(
gbox.copy(),
[meta["c"].cpu().numpy()],
[meta["s"]],
meta["out_height"],
meta["out_width"],
)
bbox = group_bbox_by_gbox(bbox[0], gbox[0])
res = []
for box in bbox:
if box[8] > 0.3:
res.append(box[0:8])
result = {OutputKeys.POLYGONS: np.array(res)}
return result
return result

View File

@@ -3,182 +3,183 @@ import enum
class Fields(object):
""" Names for different application fields
"""
cv = 'cv'
nlp = 'nlp'
audio = 'audio'
multi_modal = 'multi-modal'
science = 'science'
"""Names for different application fields"""
cv = "cv"
nlp = "nlp"
audio = "audio"
multi_modal = "multi-modal"
science = "science"
class CVTasks(object):
# ocr
ocr_detection = 'ocr-detection'
ocr_recognition = 'ocr-recognition'
table_recognition = 'table-recognition'
ocr_detection = "ocr-detection"
ocr_recognition = "ocr-recognition"
table_recognition = "table-recognition"
# human face body related
animal_recognition = 'animal-recognition'
face_detection = 'face-detection'
card_detection = 'card-detection'
face_recognition = 'face-recognition'
facial_expression_recognition = 'facial-expression-recognition'
face_2d_keypoints = 'face-2d-keypoints'
human_detection = 'human-detection'
human_object_interaction = 'human-object-interaction'
face_image_generation = 'face-image-generation'
body_2d_keypoints = 'body-2d-keypoints'
body_3d_keypoints = 'body-3d-keypoints'
hand_2d_keypoints = 'hand-2d-keypoints'
general_recognition = 'general-recognition'
human_wholebody_keypoint = 'human-wholebody-keypoint'
animal_recognition = "animal-recognition"
face_detection = "face-detection"
card_detection = "card-detection"
face_recognition = "face-recognition"
facial_expression_recognition = "facial-expression-recognition"
face_2d_keypoints = "face-2d-keypoints"
human_detection = "human-detection"
human_object_interaction = "human-object-interaction"
face_image_generation = "face-image-generation"
body_2d_keypoints = "body-2d-keypoints"
body_3d_keypoints = "body-3d-keypoints"
hand_2d_keypoints = "hand-2d-keypoints"
general_recognition = "general-recognition"
human_wholebody_keypoint = "human-wholebody-keypoint"
image_classification = 'image-classification'
image_multilabel_classification = 'image-multilabel-classification'
image_classification_imagenet = 'image-classification-imagenet'
image_classification_dailylife = 'image-classification-dailylife'
image_classification = "image-classification"
image_multilabel_classification = "image-multilabel-classification"
image_classification_imagenet = "image-classification-imagenet"
image_classification_dailylife = "image-classification-dailylife"
image_object_detection = 'image-object-detection'
video_object_detection = 'video-object-detection'
image_object_detection = "image-object-detection"
video_object_detection = "video-object-detection"
image_segmentation = 'image-segmentation'
semantic_segmentation = 'semantic-segmentation'
portrait_matting = 'portrait-matting'
text_driven_segmentation = 'text-driven-segmentation'
shop_segmentation = 'shop-segmentation'
hand_static = 'hand-static'
face_human_hand_detection = 'face-human-hand-detection'
face_emotion = 'face-emotion'
product_segmentation = 'product-segmentation'
image_segmentation = "image-segmentation"
semantic_segmentation = "semantic-segmentation"
portrait_matting = "portrait-matting"
text_driven_segmentation = "text-driven-segmentation"
shop_segmentation = "shop-segmentation"
hand_static = "hand-static"
face_human_hand_detection = "face-human-hand-detection"
face_emotion = "face-emotion"
product_segmentation = "product-segmentation"
crowd_counting = 'crowd-counting'
crowd_counting = "crowd-counting"
# image editing
skin_retouching = 'skin-retouching'
image_super_resolution = 'image-super-resolution'
image_colorization = 'image-colorization'
image_color_enhancement = 'image-color-enhancement'
image_denoising = 'image-denoising'
image_portrait_enhancement = 'image-portrait-enhancement'
image_inpainting = 'image-inpainting'
skin_retouching = "skin-retouching"
image_super_resolution = "image-super-resolution"
image_colorization = "image-colorization"
image_color_enhancement = "image-color-enhancement"
image_denoising = "image-denoising"
image_portrait_enhancement = "image-portrait-enhancement"
image_inpainting = "image-inpainting"
# image generation
image_to_image_translation = 'image-to-image-translation'
image_to_image_generation = 'image-to-image-generation'
image_style_transfer = 'image-style-transfer'
image_portrait_stylization = 'image-portrait-stylization'
image_body_reshaping = 'image-body-reshaping'
image_embedding = 'image-embedding'
image_to_image_translation = "image-to-image-translation"
image_to_image_generation = "image-to-image-generation"
image_style_transfer = "image-style-transfer"
image_portrait_stylization = "image-portrait-stylization"
image_body_reshaping = "image-body-reshaping"
image_embedding = "image-embedding"
product_retrieval_embedding = 'product-retrieval-embedding'
product_retrieval_embedding = "product-retrieval-embedding"
# video recognition
live_category = 'live-category'
action_recognition = 'action-recognition'
action_detection = 'action-detection'
video_category = 'video-category'
video_embedding = 'video-embedding'
virtual_try_on = 'virtual-try-on'
movie_scene_segmentation = 'movie-scene-segmentation'
live_category = "live-category"
action_recognition = "action-recognition"
action_detection = "action-detection"
video_category = "video-category"
video_embedding = "video-embedding"
virtual_try_on = "virtual-try-on"
movie_scene_segmentation = "movie-scene-segmentation"
# video segmentation
referring_video_object_segmentation = 'referring-video-object-segmentation'
referring_video_object_segmentation = "referring-video-object-segmentation"
# video editing
video_inpainting = 'video-inpainting'
video_inpainting = "video-inpainting"
# reid and tracking
video_single_object_tracking = 'video-single-object-tracking'
video_summarization = 'video-summarization'
image_reid_person = 'image-reid-person'
video_single_object_tracking = "video-single-object-tracking"
video_summarization = "video-summarization"
image_reid_person = "image-reid-person"
class NLPTasks(object):
# nlp tasks
word_segmentation = 'word-segmentation'
part_of_speech = 'part-of-speech'
named_entity_recognition = 'named-entity-recognition'
nli = 'nli'
sentiment_classification = 'sentiment-classification'
sentiment_analysis = 'sentiment-analysis'
sentence_similarity = 'sentence-similarity'
text_classification = 'text-classification'
sentence_embedding = 'sentence-embedding'
text_ranking = 'text-ranking'
relation_extraction = 'relation-extraction'
zero_shot = 'zero-shot'
translation = 'translation'
token_classification = 'token-classification'
conversational = 'conversational'
text_generation = 'text-generation'
text2text_generation = 'text2text-generation'
task_oriented_conversation = 'task-oriented-conversation'
dialog_intent_prediction = 'dialog-intent-prediction'
dialog_state_tracking = 'dialog-state-tracking'
table_question_answering = 'table-question-answering'
fill_mask = 'fill-mask'
text_summarization = 'text-summarization'
question_answering = 'question-answering'
zero_shot_classification = 'zero-shot-classification'
backbone = 'backbone'
text_error_correction = 'text-error-correction'
faq_question_answering = 'faq-question-answering'
information_extraction = 'information-extraction'
document_segmentation = 'document-segmentation'
feature_extraction = 'feature-extraction'
word_segmentation = "word-segmentation"
part_of_speech = "part-of-speech"
named_entity_recognition = "named-entity-recognition"
nli = "nli"
sentiment_classification = "sentiment-classification"
sentiment_analysis = "sentiment-analysis"
sentence_similarity = "sentence-similarity"
text_classification = "text-classification"
sentence_embedding = "sentence-embedding"
text_ranking = "text-ranking"
relation_extraction = "relation-extraction"
zero_shot = "zero-shot"
translation = "translation"
token_classification = "token-classification"
conversational = "conversational"
text_generation = "text-generation"
text2text_generation = "text2text-generation"
task_oriented_conversation = "task-oriented-conversation"
dialog_intent_prediction = "dialog-intent-prediction"
dialog_state_tracking = "dialog-state-tracking"
table_question_answering = "table-question-answering"
fill_mask = "fill-mask"
text_summarization = "text-summarization"
question_answering = "question-answering"
zero_shot_classification = "zero-shot-classification"
backbone = "backbone"
text_error_correction = "text-error-correction"
faq_question_answering = "faq-question-answering"
information_extraction = "information-extraction"
document_segmentation = "document-segmentation"
feature_extraction = "feature-extraction"
class AudioTasks(object):
# audio tasks
auto_speech_recognition = 'auto-speech-recognition'
text_to_speech = 'text-to-speech'
speech_signal_process = 'speech-signal-process'
acoustic_echo_cancellation = 'acoustic-echo-cancellation'
acoustic_noise_suppression = 'acoustic-noise-suppression'
keyword_spotting = 'keyword-spotting'
auto_speech_recognition = "auto-speech-recognition"
text_to_speech = "text-to-speech"
speech_signal_process = "speech-signal-process"
acoustic_echo_cancellation = "acoustic-echo-cancellation"
acoustic_noise_suppression = "acoustic-noise-suppression"
keyword_spotting = "keyword-spotting"
class MultiModalTasks(object):
# multi-modal tasks
image_captioning = 'image-captioning'
visual_grounding = 'visual-grounding'
text_to_image_synthesis = 'text-to-image-synthesis'
multi_modal_embedding = 'multi-modal-embedding'
generative_multi_modal_embedding = 'generative-multi-modal-embedding'
multi_modal_similarity = 'multi-modal-similarity'
visual_question_answering = 'visual-question-answering'
visual_entailment = 'visual-entailment'
video_multi_modal_embedding = 'video-multi-modal-embedding'
image_text_retrieval = 'image-text-retrieval'
image_captioning = "image-captioning"
visual_grounding = "visual-grounding"
text_to_image_synthesis = "text-to-image-synthesis"
multi_modal_embedding = "multi-modal-embedding"
generative_multi_modal_embedding = "generative-multi-modal-embedding"
multi_modal_similarity = "multi-modal-similarity"
visual_question_answering = "visual-question-answering"
visual_entailment = "visual-entailment"
video_multi_modal_embedding = "video-multi-modal-embedding"
image_text_retrieval = "image-text-retrieval"
class ScienceTasks(object):
protein_structure = 'protein-structure'
protein_structure = "protein-structure"
class TasksIODescriptions(object):
image_to_image = 'image_to_image',
images_to_image = 'images_to_image',
image_to_text = 'image_to_text',
seed_to_image = 'seed_to_image',
text_to_speech = 'text_to_speech',
text_to_text = 'text_to_text',
speech_to_text = 'speech_to_text',
speech_to_speech = 'speech_to_speech'
speeches_to_speech = 'speeches_to_speech',
visual_grounding = 'visual_grounding',
visual_question_answering = 'visual_question_answering',
visual_entailment = 'visual_entailment',
generative_multi_modal_embedding = 'generative_multi_modal_embedding'
image_to_image = ("image_to_image",)
images_to_image = ("images_to_image",)
image_to_text = ("image_to_text",)
seed_to_image = ("seed_to_image",)
text_to_speech = ("text_to_speech",)
text_to_text = ("text_to_text",)
speech_to_text = ("speech_to_text",)
speech_to_speech = "speech_to_speech"
speeches_to_speech = ("speeches_to_speech",)
visual_grounding = ("visual_grounding",)
visual_question_answering = ("visual_question_answering",)
visual_entailment = ("visual_entailment",)
generative_multi_modal_embedding = "generative_multi_modal_embedding"
class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks, ScienceTasks):
""" Names for tasks supported by modelscope.
"""Names for tasks supported by modelscope.
Holds the standard task name to use for identifying different tasks.
This should be used to register models, pipelines, trainers.
"""
reverse_field_index = {}
@staticmethod
@@ -187,78 +188,83 @@ class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks, ScienceTasks):
# Lazy init, not thread safe
field_dict = {
Fields.cv: [
getattr(Tasks, attr) for attr in dir(CVTasks)
if not attr.startswith('__')
getattr(Tasks, attr)
for attr in dir(CVTasks)
if not attr.startswith("__")
],
Fields.nlp: [
getattr(Tasks, attr) for attr in dir(NLPTasks)
if not attr.startswith('__')
getattr(Tasks, attr)
for attr in dir(NLPTasks)
if not attr.startswith("__")
],
Fields.audio: [
getattr(Tasks, attr) for attr in dir(AudioTasks)
if not attr.startswith('__')
getattr(Tasks, attr)
for attr in dir(AudioTasks)
if not attr.startswith("__")
],
Fields.multi_modal: [
getattr(Tasks, attr) for attr in dir(MultiModalTasks)
if not attr.startswith('__')
getattr(Tasks, attr)
for attr in dir(MultiModalTasks)
if not attr.startswith("__")
],
Fields.science: [
getattr(Tasks, attr) for attr in dir(ScienceTasks)
if not attr.startswith('__')
getattr(Tasks, attr)
for attr in dir(ScienceTasks)
if not attr.startswith("__")
],
}
for field, tasks in field_dict.items():
for task in tasks:
if task in Tasks.reverse_field_index:
raise ValueError(f'Duplicate task: {task}')
raise ValueError(f"Duplicate task: {task}")
Tasks.reverse_field_index[task] = field
return Tasks.reverse_field_index.get(task_name)
class InputFields(object):
""" Names for input data fields in the input data for pipelines
"""
img = 'img'
text = 'text'
audio = 'audio'
"""Names for input data fields in the input data for pipelines"""
img = "img"
text = "text"
audio = "audio"
class Hubs(enum.Enum):
""" Source from which an entity (such as a Dataset or Model) is stored
"""
modelscope = 'modelscope'
huggingface = 'huggingface'
"""Source from which an entity (such as a Dataset or Model) is stored"""
modelscope = "modelscope"
huggingface = "huggingface"
class DownloadMode(enum.Enum):
""" How to treat existing datasets
"""
REUSE_DATASET_IF_EXISTS = 'reuse_dataset_if_exists'
FORCE_REDOWNLOAD = 'force_redownload'
"""How to treat existing datasets"""
REUSE_DATASET_IF_EXISTS = "reuse_dataset_if_exists"
FORCE_REDOWNLOAD = "force_redownload"
class DownloadChannel(enum.Enum):
""" Channels of datasets downloading for uv/pv counting.
"""
LOCAL = 'local'
DSW = 'dsw'
EAIS = 'eais'
"""Channels of datasets downloading for uv/pv counting."""
LOCAL = "local"
DSW = "dsw"
EAIS = "eais"
class UploadMode(enum.Enum):
""" How to upload object to remote.
"""
"""How to upload object to remote."""
# Upload all objects from local, existing remote objects may be overwritten. (Default)
OVERWRITE = 'overwrite'
OVERWRITE = "overwrite"
# Upload local objects in append mode, skipping all existing remote objects.
APPEND = 'append'
APPEND = "append"
class DatasetFormations(enum.Enum):
""" How a dataset is organized and interpreted
"""
"""How a dataset is organized and interpreted"""
# formation that is compatible with official huggingface dataset, which
# organizes whole dataset into one single (zip) file.
hf_compatible = 1
@@ -268,114 +274,116 @@ class DatasetFormations(enum.Enum):
DatasetMetaFormats = {
DatasetFormations.native: ['.json'],
DatasetFormations.hf_compatible: ['.py'],
DatasetFormations.native: [".json"],
DatasetFormations.hf_compatible: [".py"],
}
class ModelFile(object):
CONFIGURATION = 'configuration.json'
README = 'README.md'
TF_SAVED_MODEL_FILE = 'saved_model.pb'
TF_GRAPH_FILE = 'tf_graph.pb'
TF_CHECKPOINT_FOLDER = 'tf_ckpts'
TF_CKPT_PREFIX = 'ckpt-'
TORCH_MODEL_FILE = 'pytorch_model.pt'
TORCH_MODEL_BIN_FILE = 'pytorch_model.bin'
VOCAB_FILE = 'vocab.txt'
ONNX_MODEL_FILE = 'model.onnx'
LABEL_MAPPING = 'label_mapping.json'
TRAIN_OUTPUT_DIR = 'output'
TS_MODEL_FILE = 'model.ts'
CONFIGURATION = "configuration.json"
README = "README.md"
TF_SAVED_MODEL_FILE = "saved_model.pb"
TF_GRAPH_FILE = "tf_graph.pb"
TF_CHECKPOINT_FOLDER = "tf_ckpts"
TF_CKPT_PREFIX = "ckpt-"
TORCH_MODEL_FILE = "pytorch_model.pt"
TORCH_MODEL_BIN_FILE = "pytorch_model.bin"
VOCAB_FILE = "vocab.txt"
ONNX_MODEL_FILE = "model.onnx"
LABEL_MAPPING = "label_mapping.json"
TRAIN_OUTPUT_DIR = "output"
TS_MODEL_FILE = "model.ts"
class ConfigFields(object):
""" First level keyword in configuration file
"""
framework = 'framework'
task = 'task'
pipeline = 'pipeline'
model = 'model'
dataset = 'dataset'
preprocessor = 'preprocessor'
train = 'train'
evaluation = 'evaluation'
postprocessor = 'postprocessor'
"""First level keyword in configuration file"""
framework = "framework"
task = "task"
pipeline = "pipeline"
model = "model"
dataset = "dataset"
preprocessor = "preprocessor"
train = "train"
evaluation = "evaluation"
postprocessor = "postprocessor"
class ConfigKeys(object):
"""Fixed keywords in configuration file"""
train = 'train'
val = 'val'
test = 'test'
train = "train"
val = "val"
test = "test"
class Requirements(object):
"""Requirement names for each module
"""
protobuf = 'protobuf'
sentencepiece = 'sentencepiece'
sklearn = 'sklearn'
scipy = 'scipy'
timm = 'timm'
tokenizers = 'tokenizers'
tf = 'tf'
torch = 'torch'
"""Requirement names for each module"""
protobuf = "protobuf"
sentencepiece = "sentencepiece"
sklearn = "sklearn"
scipy = "scipy"
timm = "timm"
tokenizers = "tokenizers"
tf = "tf"
torch = "torch"
class Frameworks(object):
tf = 'tensorflow'
torch = 'pytorch'
kaldi = 'kaldi'
tf = "tensorflow"
torch = "pytorch"
kaldi = "kaldi"
DEFAULT_MODEL_REVISION = None
MASTER_MODEL_BRANCH = 'master'
DEFAULT_REPOSITORY_REVISION = 'master'
DEFAULT_DATASET_REVISION = 'master'
DEFAULT_DATASET_NAMESPACE = 'modelscope'
MASTER_MODEL_BRANCH = "master"
DEFAULT_REPOSITORY_REVISION = "master"
DEFAULT_DATASET_REVISION = "master"
DEFAULT_DATASET_NAMESPACE = "modelscope"
class ModeKeys:
TRAIN = 'train'
EVAL = 'eval'
INFERENCE = 'inference'
TRAIN = "train"
EVAL = "eval"
INFERENCE = "inference"
class LogKeys:
ITER = 'iter'
ITER_TIME = 'iter_time'
EPOCH = 'epoch'
LR = 'lr' # learning rate
MODE = 'mode'
DATA_LOAD_TIME = 'data_load_time'
ETA = 'eta' # estimated time of arrival
MEMORY = 'memory'
LOSS = 'loss'
ITER = "iter"
ITER_TIME = "iter_time"
EPOCH = "epoch"
LR = "lr" # learning rate
MODE = "mode"
DATA_LOAD_TIME = "data_load_time"
ETA = "eta" # estimated time of arrival
MEMORY = "memory"
LOSS = "loss"
class TrainerStages:
before_run = 'before_run'
before_train_epoch = 'before_train_epoch'
before_train_iter = 'before_train_iter'
after_train_iter = 'after_train_iter'
after_train_epoch = 'after_train_epoch'
before_val_epoch = 'before_val_epoch'
before_val_iter = 'before_val_iter'
after_val_iter = 'after_val_iter'
after_val_epoch = 'after_val_epoch'
after_run = 'after_run'
before_run = "before_run"
before_train_epoch = "before_train_epoch"
before_train_iter = "before_train_iter"
after_train_iter = "after_train_iter"
after_train_epoch = "after_train_epoch"
before_val_epoch = "before_val_epoch"
before_val_iter = "before_val_iter"
after_val_iter = "after_val_iter"
after_val_epoch = "after_val_epoch"
after_run = "after_run"
class ColorCodes:
MAGENTA = '\033[95m'
YELLOW = '\033[93m'
GREEN = '\033[92m'
RED = '\033[91m'
END = '\033[0m'
MAGENTA = "\033[95m"
YELLOW = "\033[93m"
GREEN = "\033[92m"
RED = "\033[91m"
END = "\033[0m"
class Devices:
"""device used for training and inference"""
cpu = 'cpu'
gpu = 'gpu'
cpu = "cpu"
gpu = "gpu"

View File

@@ -9,31 +9,30 @@ from modelscope.utils.test_utils import test_level
class TableRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
def setUp(self) -> None:
self.model_id = 'damo/cv_dla34_table-structure-recognition_cycle-centernet'
self.test_image = 'data/test/images/table_recognition.jpg'
self.model_id = "damo/cv_dla34_table-structure-recognition_cycle-centernet"
self.test_image = "data/test/images/table_recognition.jpg"
self.task = Tasks.table_recognition
def pipeline_inference(self, pipeline: Pipeline, input_location: str):
result = pipeline(input_location)
print('table recognition results: ')
print("table recognition results: ")
print(result)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, "skip test in current test level")
def test_run_with_model_from_modelhub(self):
table_recognition = pipeline(Tasks.table_recognition, model=self.model_id)
self.pipeline_inference(table_recognition, self.test_image)
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 2, "skip test in current test level")
def test_run_modelhub_default_model(self):
table_recognition = pipeline(Tasks.table_recognition)
self.pipeline_inference(table_recognition, self.test_image)
@unittest.skip('demo compatibility test is only enabled on a needed-basis')
@unittest.skip("demo compatibility test is only enabled on a needed-basis")
def test_demo_compatibility(self):
self.compatibility_check()
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()