From fac865fd97494822fa38b1cfd7542266d8da80e8 Mon Sep 17 00:00:00 2001 From: Yingda Chen Date: Fri, 1 Nov 2024 09:10:37 +0800 Subject: [PATCH] OCR pipeline shall depend on TF only when necessary (#1059) * ocr pipeline tf as optional Co-authored-by: Yingda Chen --- .../pipelines/cv/ocr_detection_pipeline.py | 44 +++++++++++-------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/modelscope/pipelines/cv/ocr_detection_pipeline.py b/modelscope/pipelines/cv/ocr_detection_pipeline.py index cb7522c0..c23f6e6e 100644 --- a/modelscope/pipelines/cv/ocr_detection_pipeline.py +++ b/modelscope/pipelines/cv/ocr_detection_pipeline.py @@ -6,7 +6,6 @@ from typing import Any, Dict import cv2 import numpy as np -import tensorflow as tf import torch from modelscope.metainfo import Pipelines @@ -19,18 +18,7 @@ from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.device import device_placement from modelscope.utils.logger import get_logger -from .ocr_utils import (SegLinkDetector, boxes_from_bitmap, cal_width, - combine_segments_python, decode_segments_links_python, - nms_python, polygons_from_bitmap, rboxes_to_polygons) - -if tf.__version__ >= '2.0': - import tf_slim as slim -else: - from tensorflow.contrib import slim - -if tf.__version__ >= '2.0': - tf = tf.compat.v1 -tf.compat.v1.disable_eager_execution() +from .ocr_utils import cal_width, nms_python, rboxes_to_polygons logger = get_logger() @@ -40,12 +28,6 @@ OFFSET_DIM = 6 WORD_POLYGON_DIM = 8 OFFSET_VARIANCE = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1] -FLAGS = tf.app.flags.FLAGS -tf.app.flags.DEFINE_float('node_threshold', 0.4, - 'Confidence threshold for nodes') -tf.app.flags.DEFINE_float('link_threshold', 0.6, - 'Confidence threshold for links') - @PIPELINES.register_module( Tasks.ocr_detection, module_name=Pipelines.ocr_detection) @@ -99,6 +81,16 @@ class OCRDetectionPipeline(Pipeline): logger.info('loading model done') else: # for model seglink++ + import tensorflow as tf + + if tf.__version__ >= '2.0': + tf = tf.compat.v1 + tf.compat.v1.disable_eager_execution() + + tf.app.flags.DEFINE_float('node_threshold', 0.4, + 'Confidence threshold for nodes') + tf.app.flags.DEFINE_float('link_threshold', 0.6, + 'Confidence threshold for links') tf.reset_default_graph() model_path = osp.join( osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER), @@ -125,6 +117,7 @@ class OCRDetectionPipeline(Pipeline): variable_averages = tf.train.ExponentialMovingAverage( 0.997, global_step) + from .ocr_utils import SegLinkDetector, combine_segments_python, decode_segments_links_python # detector detector = SegLinkDetector() all_maps = detector.build_model( @@ -198,6 +191,19 @@ class OCRDetectionPipeline(Pipeline): result = self.preprocessor(input) return result else: + # for model seglink++ + import tensorflow as tf + + if tf.__version__ >= '2.0': + tf = tf.compat.v1 + + tf.compat.v1.disable_eager_execution() + + tf.app.flags.DEFINE_float('node_threshold', 0.4, + 'Confidence threshold for nodes') + tf.app.flags.DEFINE_float('link_threshold', 0.6, + 'Confidence threshold for links') + img = LoadImage.convert_to_ndarray(input) h, w, c = img.shape