276 lines
11 KiB
Python
276 lines
11 KiB
Python
|
|
# -*- encoding: utf-8 -*-
|
|||
|
|
# @Author: SWHL
|
|||
|
|
# @Contact: liekkaskono@163.com
|
|||
|
|
import argparse
|
|||
|
|
import importlib
|
|||
|
|
import logging
|
|||
|
|
import time
|
|||
|
|
import traceback
|
|||
|
|
from pathlib import Path
|
|||
|
|
from typing import List, Optional, Tuple, Union, Dict, Any
|
|||
|
|
import numpy as np
|
|||
|
|
import cv2
|
|||
|
|
|
|||
|
|
from wired_table_rec.table_line_rec import TableLineRecognition
|
|||
|
|
from wired_table_rec.table_line_rec_plus import TableLineRecognitionPlus
|
|||
|
|
from .table_recover import TableRecover
|
|||
|
|
from .utils import InputType, LoadImage
|
|||
|
|
from .utils_table_recover import (
|
|||
|
|
match_ocr_cell,
|
|||
|
|
plot_html_table,
|
|||
|
|
box_4_2_poly_to_box_4_1,
|
|||
|
|
get_rotate_crop_image,
|
|||
|
|
sorted_ocr_boxes,
|
|||
|
|
gather_ocr_list_by_row,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# Absolute directory of this module; the bundled ONNX models live in its
# "models" sub-directory.
cur_dir = Path(__file__).resolve().parent
# Line-detection model loaded by WiredTableRecognition for any version other
# than "v2" (TableLineRecognition).
default_model_path = cur_dir.joinpath("models", "cycle_center_net_v1.onnx")
# Line-detection model loaded by WiredTableRecognition for version "v2"
# (TableLineRecognitionPlus).
default_model_path_v2 = cur_dir.joinpath("models", "cycle_center_net_v2.onnx")
|
|||
|
|
|
|||
|
|
def convert_float32_to_float(data):
    """Recursively convert NumPy values into plain Python objects.

    Used to make results (e.g. cell boxes) safe for JSON serialization.

    Args:
        data: Arbitrary value. Lists and plain tuples are traversed
            recursively; NumPy arrays become nested Python lists; NumPy
            floating scalars (float16/float32/float64, ...) become ``float``.

    Returns:
        The converted value. Lists stay lists and plain tuples stay tuples.
        Any other value (including named tuples and non-NumPy objects) is
        returned unchanged.
    """
    if isinstance(data, list):
        return [convert_float32_to_float(item) for item in data]
    # Exact type check so named tuples (whose constructors take positional
    # fields) are left untouched, as before.
    if type(data) is tuple:
        return tuple(convert_float32_to_float(item) for item in data)
    if isinstance(data, np.ndarray):
        # tolist() already yields native Python scalars at every depth.
        return data.tolist()
    if isinstance(data, np.floating):
        return float(data)
    return data
|
|||
|
|
|
|||
|
|
|
|||
|
|
class WiredTableRecognition:
    """Recognize wired (ruled-line) tables and render them as HTML.

    Pipeline: detect cell polygons with a line-recognition model, recover the
    cells' logical layout with ``TableRecover``, attach OCR text to each cell
    (re-recognizing cells that received none), and build the HTML table.
    """

    def __init__(
        self, table_model_path: Optional[Union[str, Path]] = None, version="v2"
    ):
        """Load the line-recognition model and, if installed, the OCR engine.

        Args:
            table_model_path: Path to the ONNX model. When falsy, the bundled
                model for ``version`` is used.
            version: ``"v2"`` selects ``TableLineRecognitionPlus``; any other
                value selects ``TableLineRecognition``.
        """
        self.load_img = LoadImage()
        if version == "v2":
            model_path = table_model_path if table_model_path else default_model_path_v2
            self.table_line_rec = TableLineRecognitionPlus(str(model_path))
        else:
            model_path = table_model_path if table_model_path else default_model_path
            self.table_line_rec = TableLineRecognition(str(model_path))

        self.table_recover = TableRecover()

        # rapidocr_onnxruntime is optional. Without it, callers must pass
        # ``ocr_result`` themselves (or use need_ocr=False).
        try:
            self.ocr = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
        except ModuleNotFoundError:
            self.ocr = None

    def __call__(
        self,
        img: InputType,
        ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
        **kwargs,
    ) -> Tuple[str, float, Any, Any, Any, Any, Any]:
        """Recognize the table in ``img``.

        Args:
            img: Image in any form accepted by ``LoadImage``.
            ocr_result: Optional precomputed OCR output. Each item's element 0
                is a text polygon and element 1 its text. When ``None``, the
                built-in OCR engine is run on the image.
            **kwargs: ``rec_again`` (default True), ``need_ocr`` (default True),
                ``col_threshold`` (default 15) and ``row_threshold`` (default 10).
                All kwargs are also forwarded to the line-recognition model.

        Returns:
            ``(table_str, elapse, polygons, logi_points, ocr_boxes,
            cell_box_det_map_with_cor, not_match_orc_boxes)``.
            With ``need_ocr=False``, the HTML is ``""`` and the last three
            entries are ``[]``, ``{}``, ``[]``. On failure the result is
            ``("", 0.0, None, None, None, None, None)``.

        Raises:
            ValueError: OCR is required but neither ``ocr_result`` was given
                nor rapidocr_onnxruntime is installed.
        """
        rec_again = kwargs.get("rec_again", True)
        need_ocr = kwargs.get("need_ocr", True)
        col_threshold = kwargs.get("col_threshold", 15)
        row_threshold = kwargs.get("row_threshold", 10)

        # OCR is only required when text is actually requested.
        if need_ocr and self.ocr is None and ocr_result is None:
            raise ValueError(
                "One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
            )

        s = time.perf_counter()
        img = self.load_img(img)
        polygons, rotated_polygons = self.table_line_rec(img, **kwargs)
        if polygons is None:
            logging.warning("polygons is None.")
            return "", 0.0, None, None, None, None, None

        try:
            table_res, logi_points = self.table_recover(
                rotated_polygons, row_threshold, col_threshold
            )
            # Swap corners 1 and 3 to turn counter-clockwise polygons into
            # clockwise ones, matching the wireless-table pipeline downstream.
            polygons[:, 1, :], polygons[:, 3, :] = (
                polygons[:, 3, :].copy(),
                polygons[:, 1, :].copy(),
            )
            if not need_ocr:
                sorted_polygons, idx_list = sorted_ocr_boxes(
                    [box_4_2_poly_to_box_4_1(box) for box in polygons]
                )
                return (
                    "",
                    time.perf_counter() - s,
                    sorted_polygons,
                    logi_points[idx_list],
                    [],
                    {},
                    [],
                )
            if ocr_result is None:
                ocr_result, _ = self.ocr(img)
                # The engine may report "no text found" as None (assumption
                # about RapidOCR's return value). Treat it as an empty result
                # so cells are still filled by re_rec instead of failing.
                if ocr_result is None:
                    ocr_result = []
            cell_box_det_map, not_match_orc_boxes = match_ocr_cell(ocr_result, polygons)
            # Cells that received no OCR text are re-recognized directly.
            cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map, rec_again)
            # Merge each cell's physical box, logical box and OCR results into
            # one dict to simplify the steps below.
            t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
            # Sort the OCR results inside each cell and merge same-line ones so
            # the HTML keeps the text's line breaks.
            t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
            logi_points = [t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list]
            cell_box_det_map = {
                i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]]
                for i, t_box_ocr in enumerate(t_rec_ocr_list)
            }
            table_str = plot_html_table(logi_points, cell_box_det_map)
            ocr_boxes_res = [
                box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_result
            ]
            sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res)
            sorted_polygons = [box_4_2_poly_to_box_4_1(box) for box in polygons]
            sorted_logi_points = logi_points
            table_elapse = time.perf_counter() - s

            # cell_box_det_map keys are 0..n-1. After re_rec every polygon has a
            # non-empty entry, so key i lines up with polygon i.
            cell_box_det_map_with_cor = {
                i: (convert_float32_to_float(sorted_polygons[i]), cell_box_det_map[i])
                for i in range(len(cell_box_det_map))
            }
        except Exception:
            logging.warning(traceback.format_exc())
            return "", 0.0, None, None, None, None, None
        return (
            table_str,
            table_elapse,
            sorted_polygons,
            sorted_logi_points,
            sorted_ocr_boxes_res,
            cell_box_det_map_with_cor,
            not_match_orc_boxes,
        )

    def transform_res(
        self,
        cell_box_det_map: Dict[int, List[Any]],
        polygons: np.ndarray,
        logi_points: List[np.ndarray],
    ) -> List[Dict[str, Any]]:
        """Merge per-cell OCR results with the cell's logical position.

        Args:
            cell_box_det_map: Maps polygon index to a list of OCR entries whose
                element 0 is a 4-point box and element 1 the text.
            polygons: Cell polygons; only their count is used here.
            logi_points: Logical box per polygon index (``.tolist()`` is called).

        Returns:
            One dict per polygon that has OCR entries (others are skipped), with
            keys ``t_box``, ``t_logic_box`` and ``t_ocr_res``.
        """
        res = []
        for i in range(len(polygons)):
            ocr_res_list = cell_box_det_map.get(i)
            if not ocr_res_list:
                continue
            # Bounding box spanning all OCR boxes, taking point 0 as top-left
            # and point 2 as bottom-right.
            xmin = min(ocr_box[0][0][0] for ocr_box in ocr_res_list)
            ymin = min(ocr_box[0][0][1] for ocr_box in ocr_res_list)
            xmax = max(ocr_box[0][2][0] for ocr_box in ocr_res_list)
            ymax = max(ocr_box[0][2][1] for ocr_box in ocr_res_list)
            dict_res = {
                # xmin, ymin, xmax, ymax
                "t_box": [xmin, ymin, xmax, ymax],
                # presumably row_start, row_end, col_start, col_end (per the
                # original comment) - confirm against TableRecover
                "t_logic_box": logi_points[i].tolist(),
                # [box in box_4_2_poly_to_box_4_1 format, text]
                "t_ocr_res": [
                    [box_4_2_poly_to_box_4_1(ocr_det[0]), ocr_det[1]]
                    for ocr_det in ocr_res_list
                ],
            }
            res.append(dict_res)
        return res

    def sort_and_gather_ocr_res(self, res):
        """Order each cell's OCR results and merge those on the same row.

        Mutates every dict in ``res`` in place (its ``t_ocr_res`` entry) and
        returns ``res``.
        """
        for dict_res in res:
            _, sorted_idx = sorted_ocr_boxes(
                [ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threhold=0.3
            )
            dict_res["t_ocr_res"] = [dict_res["t_ocr_res"][j] for j in sorted_idx]
            dict_res["t_ocr_res"] = gather_ocr_list_by_row(
                dict_res["t_ocr_res"], threhold=0.3
            )
        return res

    def re_rec(
        self,
        img: np.ndarray,
        sorted_polygons: np.ndarray,
        cell_box_map: Dict[int, List[Any]],
        rec_again=True,
    ) -> Dict[int, List[Any]]:
        """Fill cells that have no OCR text by recognizing the cell crop directly.

        For every polygon index whose entry in ``cell_box_map`` is missing or
        empty, the polygon is cropped from ``img``, padded with white
        (5 px top/bottom, 100 px left/right), and passed to recognition only
        (no detection). If ``rec_again`` is False or no OCR engine is
        available, the cell instead gets an empty-text placeholder with
        score 1.

        ``cell_box_map`` is modified in place and returned. Each filled value
        is ``[[polygon, text, score]]``.
        """
        for i in range(sorted_polygons.shape[0]):
            if cell_box_map.get(i):
                continue
            box = sorted_polygons[i]
            # Without an engine (caller supplied ocr_result only), calling
            # self.ocr would fail and discard the whole table.
            if not rec_again or self.ocr is None:
                cell_box_map[i] = [[box, "", 1]]
                continue
            crop_img = get_rotate_crop_image(img, box)
            pad_img = cv2.copyMakeBorder(
                crop_img, 5, 5, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255)
            )
            rec_res, _ = self.ocr(pad_img, use_det=False, use_cls=True, use_rec=True)
            # Guard against an empty/None result: min() of an empty score list
            # would raise and abort the whole table.
            rec_res = rec_res or []
            text = "".join(rec[0] for rec in rec_res)
            scores = [rec[1] for rec in rec_res]
            cell_box_map[i] = [[box, text, min(scores, default=1.0)]]
        return cell_box_map

    def re_rec_high_precise(
        self,
        img: np.ndarray,
        sorted_polygons: np.ndarray,
        cell_box_map: Dict[int, List[str]],
    ) -> Dict[int, List[Any]]:
        """Re-run full OCR (detection + recognition) on every cell crop.

        Note: the incoming ``cell_box_map`` is ignored. A fresh map is built
        for all polygons (the parameter is kept for interface compatibility).
        Each crop is padded with a 10 px white border. Cells where OCR returns
        nothing get ``[[polygon, "", 1.0]]``.

        NOTE(review): detected boxes come from running OCR on the padded crop,
        so they are presumably in crop coordinates, not original-image
        coordinates - confirm before using them alongside ``sorted_polygons``.
        """
        result = {}
        for i in range(sorted_polygons.shape[0]):
            crop_img = get_rotate_crop_image(img, sorted_polygons[i])
            pad_img = cv2.copyMakeBorder(
                crop_img, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=(255, 255, 255)
            )
            rec_res, _ = self.ocr(pad_img, use_det=True, use_cls=True, use_rec=True)
            if not rec_res:
                det_boxes = [sorted_polygons[i]]
                texts = [""]
                scores = [1.0]
            else:
                det_boxes = [rec[0] for rec in rec_res]
                texts = [rec[1] for rec in rec_res]
                scores = [rec[2] for rec in rec_res]
            result[i] = [
                [det_box, txt, score]
                for det_box, txt, score in zip(det_boxes, texts, scores)
            ]
        return result
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main():
    """Command-line entry point: print the HTML of the table in ``--img_path``.

    Runs RapidOCR on the image, feeds its result to ``WiredTableRecognition``
    and prints the HTML string followed by the elapsed recognition time.

    Raises:
        ModuleNotFoundError: rapidocr_onnxruntime is not installed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-img", "--img_path", type=str, required=True)
    args = parser.parse_args()

    try:
        ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
    except ModuleNotFoundError as exc:
        raise ModuleNotFoundError(
            "Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime."
        ) from exc

    table_rec = WiredTableRecognition()
    ocr_result, _ = ocr_engine(args.img_path)
    # WiredTableRecognition.__call__ returns a 7-tuple. Unpacking it into two
    # names raised "too many values to unpack", so keep the first two only.
    table_str, elapse, *_ = table_rec(args.img_path, ocr_result)
    print(table_str)
    print(f"cost: {elapse:.5f}")


if __name__ == "__main__":
    main()
|