mirror of https://github.com/modelscope/modelscope.git
[to #42362853] formalize the output of pipeline and make pipeline reusable

* format pipeline output and check it
* fix UT
* add docstring to clarify the difference between model.postprocess and pipeline.postprocess

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9051405
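In practice, this change means every pipeline call returns a plain dict whose keys are the standard output names registered for its task (see the new modelscope/pipelines/outputs.py below), and that contract is re-checked on every call. A minimal usage sketch, assuming a registered text-classification model; the model id here is a hypothetical placeholder, not from this commit:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # 'damo/bert-sentiment' is an illustrative model id
    pipe = pipeline(Tasks.text_classification, model='damo/bert-sentiment')
    result = pipe('Hello, ModelScope!')
    # output keys now follow TASK_OUTPUTS[Tasks.text_classification]
    print(result['scores'], result['labels'])

The same pipeline object can be called repeatedly, and each output is validated by the new Pipeline._check_output.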
@@ -6,7 +6,8 @@ DOCKER_FULL_NAME = $(DOCKER_REGISTRY)/$(DOCKER_ORG)/$(DOCKER_IMAGE)
 # CUDA_VERSION = 11.3
 # CUDNN_VERSION = 8
 BASE_RUNTIME = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04
-BASE_DEVEL = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04
+# BASE_DEVEL = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04
+BASE_DEVEL = pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel
 
 
 MODELSCOPE_VERSION = $(shell git describe --tags --always)
@@ -8,13 +8,29 @@
 # For reference:
 # https://docs.docker.com/develop/develop-images/build_enhancements/
 
-#ARG BASE_IMAGE=reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04
-#FROM ${BASE_IMAGE} as dev-base
+# ARG BASE_IMAGE=reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04
+# FROM ${BASE_IMAGE} as dev-base
 
-FROM reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 as dev-base
+# FROM reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 as dev-base
+FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel
+# FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-runtime
 # config pip source
 RUN mkdir /root/.pip
 COPY docker/rcfiles/pip.conf.tsinghua /root/.pip/pip.conf
+COPY docker/rcfiles/sources.list.aliyun /etc/apt/sources.list
 
+# Install essential Ubuntu packages
+RUN apt-get update &&\
+    apt-get install -y software-properties-common \
+    build-essential \
+    git \
+    wget \
+    vim \
+    curl \
+    zip \
+    zlib1g-dev \
+    unzip \
+    pkg-config
+
 # install modelscope and its python env
 WORKDIR /opt/modelscope
@@ -20,16 +20,24 @@ class Model(ABC):
         self.model_dir = model_dir
 
     def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
-        return self.post_process(self.forward(input))
+        return self.postprocess(self.forward(input))
 
     @abstractmethod
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        pass
 
-    def post_process(self, input: Dict[str, Tensor],
-                     **kwargs) -> Dict[str, Tensor]:
-        # model specific postprocess, implementation is optional
-        # will be called in Pipeline and evaluation loop(in the future)
+    def postprocess(self, input: Dict[str, Tensor],
+                    **kwargs) -> Dict[str, Tensor]:
+        """ Model specific postprocess and convert model output to
+        standard model outputs.
+
+        Args:
+            inputs: input data
+
+        Return:
+            dict of results: a dict containing outputs of model, each
+                output should have the standard output name.
+        """
         return input
 
     @classmethod
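For a concrete sense of the renamed hook, a sketch of a subclass building on the Model ABC patched in this hunk; everything except the forward/postprocess signatures (the stand-in network, the super().__init__ call) is an illustrative assumption:

    import torch
    from typing import Dict

    class ToyClassifier(Model):  # Model is the ABC shown above

        def __init__(self, model_dir: str):
            super().__init__(model_dir)  # assumed Model.__init__ signature
            self.net = torch.nn.Linear(16, 4)  # stand-in network

        def forward(self, input: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
            # raw network outputs
            return {'logits': self.net(input['features'])}

        def postprocess(self, input: Dict[str, torch.Tensor],
                        **kwargs) -> Dict[str, torch.Tensor]:
            # convert raw outputs to the standard output names
            return {'probabilities': input['logits'].softmax(dim=-1)}

    # __call__ chains forward -> postprocess, as the hunk above shows:
    # ToyClassifier('/tmp')({'features': torch.randn(2, 16)})['probabilities']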
@@ -1,5 +1,7 @@
+import os
 from typing import Any, Dict
 
+import json
 import numpy as np
 
 from modelscope.utils.constant import Tasks
@@ -34,6 +36,11 @@ class BertForSequenceClassification(Model):
                          ('token_type_ids', torch.LongTensor)],
             output_keys=['predictions', 'probabilities', 'logits'])
 
+        self.label_path = os.path.join(self.model_dir, 'label_mapping.json')
+        with open(self.label_path) as f:
+            self.label_mapping = json.load(f)
+        self.id2label = {idx: name for name, idx in self.label_mapping.items()}
+
     def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
         """return the result by the model
 
@@ -50,3 +57,13 @@ class BertForSequenceClassification(Model):
             }
         """
         return self.model.predict(input)
+
+    def postprocess(self, inputs: Dict[str, np.ndarray],
+                    **kwargs) -> Dict[str, np.ndarray]:
+        # N x num_classes
+        probs = inputs['probabilities']
+        result = {
+            'probs': probs,
+        }
+
+        return result
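The model-side postprocess now hands the pipeline a bare probability matrix under a single key; a small sketch of the shapes involved (values illustrative):

    import numpy as np

    # a batch of 2 samples over 3 classes, as the model might emit
    inputs = {'probabilities': np.array([[0.1, 0.7, 0.2],
                                         [0.5, 0.3, 0.2]])}
    result = {'probs': inputs['probabilities']}  # what postprocess returns
    assert result['probs'].shape == (2, 3)       # N x num_classes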
@@ -12,6 +12,7 @@ from modelscope.pydatasets import PyDataset
 from modelscope.utils.config import Config
 from modelscope.utils.hub import get_model_cache_dir
 from modelscope.utils.logger import get_logger
+from .outputs import TASK_OUTPUTS
 from .util import is_model_name
 
 Tensor = Union['torch.Tensor', 'tf.Tensor']
@@ -106,8 +107,25 @@ class Pipeline(ABC):
         out = self.preprocess(input)
         out = self.forward(out)
         out = self.postprocess(out, **post_kwargs)
+        self._check_output(out)
         return out
 
+    def _check_output(self, input):
+        # this attribute is dynamically attached by registry
+        # when cls is registered in registry using task name
+        task_name = self.group_key
+        if task_name not in TASK_OUTPUTS:
+            logger.warning(f'task {task_name} output keys are missing')
+            return
+        output_keys = TASK_OUTPUTS[task_name]
+        missing_keys = []
+        for k in output_keys:
+            if k not in input:
+                missing_keys.append(k)
+        if len(missing_keys) > 0:
+            raise ValueError(f'expected output keys are {output_keys}, '
+                             f'those {missing_keys} are missing')
+
     def preprocess(self, inputs: Input) -> Dict[str, Any]:
         """ Provide default implementation based on preprocess_cfg and user can reimplement it
         """
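_check_output is a straightforward key check; a standalone sketch that mirrors its logic (re-implemented here only for illustration, with a one-entry TASK_OUTPUTS excerpt):

    TASK_OUTPUTS = {'text-classification': ['scores', 'labels']}  # excerpt

    def check_output(task_name, output):
        if task_name not in TASK_OUTPUTS:
            return  # unknown task: the real method only logs a warning
        missing = [k for k in TASK_OUTPUTS[task_name] if k not in output]
        if missing:
            raise ValueError(f'expected output keys are '
                             f'{TASK_OUTPUTS[task_name]}, '
                             f'those {missing} are missing')

    check_output('text-classification', {'scores': [0.9], 'labels': ['pos']})  # passes
    # check_output('text-classification', {'scores': [0.9]})  # raises ValueError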
@@ -125,4 +143,14 @@ class Pipeline(ABC):
 
     @abstractmethod
     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """ If the current pipeline supports model reuse, common postprocess
+        code should be written here.
+
+        Args:
+            inputs: input data
+
+        Return:
+            dict of results: a dict containing outputs of model, each
+                output should have the standard output name.
+        """
         raise NotImplementedError('postprocess')
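Read together with Model.postprocess above, the two docstrings split the work: the model converts raw tensors into standard model outputs, and the pipeline converts those into the task's standard output keys so one pipeline class can be reused with any model of the task. The SequenceClassificationPipeline hunk below is the in-tree instance of this split; a small sketch under illustrative names and values:

    import numpy as np

    ID2LABEL = {0: 'negative', 1: 'positive'}  # illustrative mapping

    # model side: raw tensors -> standard model outputs
    model_out = {'probabilities': np.array([[0.1, 0.9]])}

    # pipeline side: standard model outputs -> the task's TASK_OUTPUTS keys
    probs = model_out['probabilities'][0]
    order = np.argsort(probs)
    result = {'scores': probs[order].tolist(),
              'labels': [ID2LABEL[i] for i in order]}
    # result == {'scores': [0.1, 0.9], 'labels': ['negative', 'positive']}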
@@ -41,50 +41,29 @@ class SequenceClassificationPipeline(Pipeline):
             second_sequence=None)
         super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
 
-        from easynlp.utils import io
-        self.label_path = os.path.join(sc_model.model_dir,
-                                       'label_mapping.json')
-        with io.open(self.label_path) as f:
-            self.label_mapping = json.load(f)
-        self.label_id_to_name = {
-            idx: name
-            for name, idx in self.label_mapping.items()
-        }
+        assert hasattr(self.model, 'id2label'), \
+            'id2label map should be initialized in init function.'
 
-    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+    def postprocess(self,
+                    inputs: Dict[str, Any],
+                    topk: int = 5) -> Dict[str, str]:
         """process the prediction results
 
         Args:
-            inputs (Dict[str, Any]): _description_
+            inputs (Dict[str, Any]): input data dict
+            topk (int): return topk classification result.
 
         Returns:
             Dict[str, str]: the prediction results
         """
-        probs = inputs['probabilities']
-        logits = inputs['logits']
-        predictions = np.argsort(-probs, axis=-1)
-        preds = predictions[0]
-        b = 0
-        new_result = list()
-        for pred in preds:
-            new_result.append({
-                'pred': self.label_id_to_name[pred],
-                'prob': float(probs[b][pred]),
-                'logit': float(logits[b][pred])
-            })
-        new_results = list()
-        new_results.append({
-            'id':
-            inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()),
-            'output':
-            new_result,
-            'predictions':
-            new_result[0]['pred'],
-            'probabilities':
-            ','.join([str(t) for t in inputs['probabilities'][b]]),
-            'logits':
-            ','.join([str(t) for t in inputs['logits'][b]])
-        })
+        # NxC np.ndarray
+        probs = inputs['probs'][0]
+        num_classes = probs.shape[0]
+        topk = min(topk, num_classes)
+        top_indices = np.argpartition(probs, -topk)[-topk:]
+        cls_ids = top_indices[np.argsort(probs[top_indices])]
+        probs = probs[cls_ids].tolist()
+
+        cls_names = [self.model.id2label[cid] for cid in cls_ids]
 
-        return new_results[0]
+        return {'scores': probs, 'labels': cls_names}
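The rewritten postprocess avoids a full sort: np.argpartition finds the top-k indices in O(N) on average, and only those k entries are then sorted. A standalone check of the trick (values illustrative):

    import numpy as np

    probs = np.array([0.05, 0.6, 0.1, 0.2, 0.05])
    topk = 3
    top_indices = np.argpartition(probs, -topk)[-topk:]    # top-k, unordered
    cls_ids = top_indices[np.argsort(probs[top_indices])]  # ascending by prob
    print(cls_ids.tolist())         # [2, 3, 1]
    print(probs[cls_ids].tolist())  # [0.1, 0.2, 0.6]

Note that the scores come back in ascending order, which is what the pipeline's {'scores': ..., 'labels': ...} output mirrors.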
@@ -56,4 +56,4 @@ class TextGenerationPipeline(Pipeline):
                                  '').split('[SEP]')[0].replace('[CLS]',
                                                                '').replace('[SEP]',
                                                                            '').replace('[UNK]', '')
-        return {'pred_string': pred_string}
+        return {'text': pred_string}
modelscope/pipelines/outputs.py (new file, 98 lines)
@@ -0,0 +1,98 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from modelscope.utils.constant import Tasks
+
+TASK_OUTPUTS = {
+
+    # ============ vision tasks ===================
+
+    # image classification result for single sample
+    # {
+    #   "labels": ["dog", "horse", "cow", "cat"],
+    #   "scores": [0.9, 0.1, 0.05, 0.05]
+    # }
+    Tasks.image_classification: ['scores', 'labels'],
+    Tasks.image_tagging: ['scores', 'labels'],
+
+    # object detection result for single sample
+    # {
+    #   "boxes": [
+    #     [x1, y1, x2, y2],
+    #     [x1, y1, x2, y2],
+    #     [x1, y1, x2, y2],
+    #   ],
+    #   "labels": ["dog", "horse", "cow", "cat"],
+    #   "scores": [0.9, 0.1, 0.05, 0.05]
+    # }
+    Tasks.object_detection: ['scores', 'labels', 'boxes'],
+
+    # instance segmentation result for single sample
+    # {
+    #   "masks": [
+    #     np.array in bgr channel order
+    #   ],
+    #   "labels": ["dog", "horse", "cow", "cat"],
+    #   "scores": [0.9, 0.1, 0.05, 0.05]
+    # }
+    Tasks.image_segmentation: ['scores', 'labels', 'boxes'],
+
+    # image generation/editing/matting result for single sample
+    # {
+    #   "output_png": np.array with shape(h, w, 4)
+    #                 for matting or (h, w, 3) for general purpose
+    # }
+    Tasks.image_editing: ['output_png'],
+    Tasks.image_matting: ['output_png'],
+    Tasks.image_generation: ['output_png'],
+
+    # pose estimation result for single sample
+    # {
+    #   "poses": np.array with shape [num_pose, num_keypoint, 3],
+    #            each keypoint is an array [x, y, score]
+    #   "boxes": np.array with shape [num_pose, 4], each box is
+    #            [x1, y1, x2, y2]
+    # }
+    Tasks.pose_estimation: ['poses', 'boxes'],
+
+    # ============ nlp tasks ===================
+
+    # text classification result for single sample
+    # {
+    #   "labels": ["happy", "sad", "calm", "angry"],
+    #   "scores": [0.9, 0.1, 0.05, 0.05]
+    # }
+    Tasks.text_classification: ['scores', 'labels'],
+
+    # text generation result for single sample
+    # {
+    #   "text": "this is text generated by a model."
+    # }
+    Tasks.text_generation: ['text'],
+
+    # ============ audio tasks ===================
+
+    # ============ multi-modal tasks ===================
+
+    # image caption result for single sample
+    # {
+    #   "caption": "this is an image caption text."
+    # }
+    Tasks.image_captioning: ['caption'],
+
+    # visual grounding result for single sample
+    # {
+    #   "boxes": [
+    #     [x1, y1, x2, y2],
+    #     [x1, y1, x2, y2],
+    #     [x1, y1, x2, y2],
+    #   ],
+    #   "scores": [0.9, 0.1, 0.05, 0.05]
+    # }
+    Tasks.visual_grounding: ['boxes', 'scores'],
+
+    # text_to_image result for a single sample
+    # {
+    #   "image": np.ndarray with shape [height, width, 3]
+    # }
+    Tasks.text_to_image_synthesis: ['image']
+}
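Consuming the new registry is a plain dict lookup keyed by task name; a quick sketch (assuming modelscope is importable):

    from modelscope.pipelines.outputs import TASK_OUTPUTS
    from modelscope.utils.constant import Tasks

    print(TASK_OUTPUTS[Tasks.text_classification])  # ['scores', 'labels']
    print(TASK_OUTPUTS[Tasks.text_generation])      # ['text']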
@@ -51,7 +51,7 @@ class Tasks(object):
     text_to_speech = 'text-to-speech'
     speech_signal_process = 'speech-signal-process'
 
-    # multi-media
+    # multi-modal tasks
     image_captioning = 'image-captioning'
     visual_grounding = 'visual-grounding'
     text_to_image_synthesis = 'text-to-image-synthesis'
@@ -69,6 +69,7 @@ class Registry(object):
                 f'{self._name}[{group_key}]')
 
         self._modules[group_key][module_name] = module_cls
+        module_cls.group_key = group_key
 
         if module_name in self._modules[default_group]:
            if id(self._modules[default_group][module_name]) == id(module_cls):
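The one added line is what lets Pipeline._check_output know its own task: the registry stamps each class with the group (task) key it was registered under. A simplified re-implementation of the mechanism, for illustration only:

    class Registry:

        def __init__(self):
            self._modules = {}

        def register_module(self, group_key, module_name):
            def decorator(module_cls):
                self._modules.setdefault(group_key, {})[module_name] = module_cls
                # the line this hunk adds: remember the task the class was
                # registered under; read later by Pipeline._check_output
                module_cls.group_key = group_key
                return module_cls
            return decorator

    PIPELINES = Registry()

    @PIPELINES.register_module('dummy-task', 'custom-image')
    class CustomImagePipeline:
        pass

    assert CustomImagePipeline.group_key == 'dummy-task'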
@@ -35,9 +35,10 @@ class CustomPipelineTest(unittest.TestCase):
             CustomPipeline1()
 
     def test_custom(self):
+        dummy_task = 'dummy-task'
 
         @PIPELINES.register_module(
-            group_key=Tasks.image_tagging, module_name='custom-image')
+            group_key=dummy_task, module_name='custom-image')
         class CustomImagePipeline(Pipeline):
 
             def __init__(self,
@@ -67,32 +68,29 @@ class CustomPipelineTest(unittest.TestCase):
                 outputs['filename'] = inputs['url']
                 img = inputs['img']
                 new_image = img.resize((img.width // 2, img.height // 2))
-                outputs['resize_image'] = np.array(new_image)
-                outputs['dummy_result'] = 'dummy_result'
+                outputs['output_png'] = np.array(new_image)
                 return outputs
 
             def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
                 return inputs
 
         self.assertTrue('custom-image' in PIPELINES.modules[default_group])
-        add_default_pipeline_info(Tasks.image_tagging, 'custom-image')
+        add_default_pipeline_info(dummy_task, 'custom-image', overwrite=True)
         pipe = pipeline(pipeline_name='custom-image')
-        pipe2 = pipeline(Tasks.image_tagging)
+        pipe2 = pipeline(dummy_task)
         self.assertTrue(type(pipe) is type(pipe2))
 
         img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.' \
                   'aliyuncs.com/data/test/images/image1.jpg'
         output = pipe(img_url)
         self.assertEqual(output['filename'], img_url)
-        self.assertEqual(output['resize_image'].shape, (318, 512, 3))
-        self.assertEqual(output['dummy_result'], 'dummy_result')
+        self.assertEqual(output['output_png'].shape, (318, 512, 3))
 
         outputs = pipe([img_url for i in range(4)])
         self.assertEqual(len(outputs), 4)
         for out in outputs:
             self.assertEqual(out['filename'], img_url)
-            self.assertEqual(out['resize_image'].shape, (318, 512, 3))
-            self.assertEqual(out['dummy_result'], 'dummy_result')
+            self.assertEqual(out['output_png'].shape, (318, 512, 3))
 
 
 if __name__ == '__main__':