mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 20:19:22 +01:00
OCR pipeline shall depend on TF only when necessary (#1059)
* ocr pipeline tf as optional Co-authored-by: Yingda Chen <yingda.chen@alibaba-inc.com>
This commit is contained in:
@@ -6,7 +6,6 @@ from typing import Any, Dict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
@@ -19,18 +18,7 @@ from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.device import device_placement
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .ocr_utils import (SegLinkDetector, boxes_from_bitmap, cal_width,
|
||||
combine_segments_python, decode_segments_links_python,
|
||||
nms_python, polygons_from_bitmap, rboxes_to_polygons)
|
||||
|
||||
if tf.__version__ >= '2.0':
|
||||
import tf_slim as slim
|
||||
else:
|
||||
from tensorflow.contrib import slim
|
||||
|
||||
if tf.__version__ >= '2.0':
|
||||
tf = tf.compat.v1
|
||||
tf.compat.v1.disable_eager_execution()
|
||||
from .ocr_utils import cal_width, nms_python, rboxes_to_polygons
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
@@ -40,12 +28,6 @@ OFFSET_DIM = 6
|
||||
WORD_POLYGON_DIM = 8
|
||||
OFFSET_VARIANCE = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
|
||||
|
||||
FLAGS = tf.app.flags.FLAGS
|
||||
tf.app.flags.DEFINE_float('node_threshold', 0.4,
|
||||
'Confidence threshold for nodes')
|
||||
tf.app.flags.DEFINE_float('link_threshold', 0.6,
|
||||
'Confidence threshold for links')
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.ocr_detection, module_name=Pipelines.ocr_detection)
|
||||
@@ -99,6 +81,16 @@ class OCRDetectionPipeline(Pipeline):
|
||||
logger.info('loading model done')
|
||||
else:
|
||||
# for model seglink++
|
||||
import tensorflow as tf
|
||||
|
||||
if tf.__version__ >= '2.0':
|
||||
tf = tf.compat.v1
|
||||
tf.compat.v1.disable_eager_execution()
|
||||
|
||||
tf.app.flags.DEFINE_float('node_threshold', 0.4,
|
||||
'Confidence threshold for nodes')
|
||||
tf.app.flags.DEFINE_float('link_threshold', 0.6,
|
||||
'Confidence threshold for links')
|
||||
tf.reset_default_graph()
|
||||
model_path = osp.join(
|
||||
osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER),
|
||||
@@ -125,6 +117,7 @@ class OCRDetectionPipeline(Pipeline):
|
||||
variable_averages = tf.train.ExponentialMovingAverage(
|
||||
0.997, global_step)
|
||||
|
||||
from .ocr_utils import SegLinkDetector, combine_segments_python, decode_segments_links_python
|
||||
# detector
|
||||
detector = SegLinkDetector()
|
||||
all_maps = detector.build_model(
|
||||
@@ -198,6 +191,19 @@ class OCRDetectionPipeline(Pipeline):
|
||||
result = self.preprocessor(input)
|
||||
return result
|
||||
else:
|
||||
# for model seglink++
|
||||
import tensorflow as tf
|
||||
|
||||
if tf.__version__ >= '2.0':
|
||||
tf = tf.compat.v1
|
||||
|
||||
tf.compat.v1.disable_eager_execution()
|
||||
|
||||
tf.app.flags.DEFINE_float('node_threshold', 0.4,
|
||||
'Confidence threshold for nodes')
|
||||
tf.app.flags.DEFINE_float('link_threshold', 0.6,
|
||||
'Confidence threshold for links')
|
||||
|
||||
img = LoadImage.convert_to_ndarray(input)
|
||||
|
||||
h, w, c = img.shape
|
||||
|
||||
Reference in New Issue
Block a user