OCR pipeline shall depend on TF only when necessary (#1059)

* Make TensorFlow an optional dependency of the OCR pipeline

Co-authored-by: Yingda Chen <yingda.chen@alibaba-inc.com>
This commit is contained in:
Yingda Chen
2024-11-01 09:10:37 +08:00
committed by GitHub
parent 3de8430c03
commit fac865fd97

View File

@@ -6,7 +6,6 @@ from typing import Any, Dict
import cv2
import numpy as np
import tensorflow as tf
import torch
from modelscope.metainfo import Pipelines
@@ -19,18 +18,7 @@ from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.device import device_placement
from modelscope.utils.logger import get_logger
from .ocr_utils import (SegLinkDetector, boxes_from_bitmap, cal_width,
combine_segments_python, decode_segments_links_python,
nms_python, polygons_from_bitmap, rboxes_to_polygons)
if tf.__version__ >= '2.0':
import tf_slim as slim
else:
from tensorflow.contrib import slim
if tf.__version__ >= '2.0':
tf = tf.compat.v1
tf.compat.v1.disable_eager_execution()
from .ocr_utils import cal_width, nms_python, rboxes_to_polygons
logger = get_logger()
@@ -40,12 +28,6 @@ OFFSET_DIM = 6
WORD_POLYGON_DIM = 8
OFFSET_VARIANCE = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_float('node_threshold', 0.4,
'Confidence threshold for nodes')
tf.app.flags.DEFINE_float('link_threshold', 0.6,
'Confidence threshold for links')
@PIPELINES.register_module(
Tasks.ocr_detection, module_name=Pipelines.ocr_detection)
@@ -99,6 +81,16 @@ class OCRDetectionPipeline(Pipeline):
logger.info('loading model done')
else:
# for model seglink++
import tensorflow as tf
if tf.__version__ >= '2.0':
tf = tf.compat.v1
tf.compat.v1.disable_eager_execution()
tf.app.flags.DEFINE_float('node_threshold', 0.4,
'Confidence threshold for nodes')
tf.app.flags.DEFINE_float('link_threshold', 0.6,
'Confidence threshold for links')
tf.reset_default_graph()
model_path = osp.join(
osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER),
@@ -125,6 +117,7 @@ class OCRDetectionPipeline(Pipeline):
variable_averages = tf.train.ExponentialMovingAverage(
0.997, global_step)
from .ocr_utils import SegLinkDetector, combine_segments_python, decode_segments_links_python
# detector
detector = SegLinkDetector()
all_maps = detector.build_model(
@@ -198,6 +191,19 @@ class OCRDetectionPipeline(Pipeline):
result = self.preprocessor(input)
return result
else:
# for model seglink++
import tensorflow as tf
if tf.__version__ >= '2.0':
tf = tf.compat.v1
tf.compat.v1.disable_eager_execution()
tf.app.flags.DEFINE_float('node_threshold', 0.4,
'Confidence threshold for nodes')
tf.app.flags.DEFINE_float('link_threshold', 0.6,
'Confidence threshold for links')
img = LoadImage.convert_to_ndarray(input)
h, w, c = img.shape