276 lines
11 KiB
Python
276 lines
11 KiB
Python
# -*- encoding: utf-8 -*-
|
||
# @Author: SWHL
|
||
# @Contact: liekkaskono@163.com
|
||
import argparse
|
||
import importlib
|
||
import logging
|
||
import time
|
||
import traceback
|
||
from pathlib import Path
|
||
from typing import List, Optional, Tuple, Union, Dict, Any
|
||
import numpy as np
|
||
import cv2
|
||
|
||
from wired_table_rec.table_line_rec import TableLineRecognition
|
||
from wired_table_rec.table_line_rec_plus import TableLineRecognitionPlus
|
||
from .table_recover import TableRecover
|
||
from .utils import InputType, LoadImage
|
||
from .utils_table_recover import (
|
||
match_ocr_cell,
|
||
plot_html_table,
|
||
box_4_2_poly_to_box_4_1,
|
||
get_rotate_crop_image,
|
||
sorted_ocr_boxes,
|
||
gather_ocr_list_by_row,
|
||
)
|
||
|
||
# Directory that contains this module; bundled model files are resolved against it.
cur_dir = Path(__file__).resolve().parent

# Bundled ONNX weights: v1 is used by TableLineRecognition, v2 by TableLineRecognitionPlus.
default_model_path = cur_dir.joinpath("models", "cycle_center_net_v1.onnx")
default_model_path_v2 = cur_dir.joinpath("models", "cycle_center_net_v2.onnx")
|
||
|
||
def convert_float32_to_float(data):
    """Recursively replace NumPy floats and arrays with plain Python values.

    Args:
        data: Any value. Lists and tuples are walked recursively, NumPy
            arrays are converted with ``tolist()``, and NumPy floating-point
            scalars of any width are converted to ``float``.

    Returns:
        A structure equal in value to ``data``. Lists stay lists, tuples
        become plain tuples, and every other type is returned unchanged.
    """
    if isinstance(data, list):
        return [convert_float32_to_float(item) for item in data]
    if isinstance(data, tuple):
        # Tuples used to be returned untouched, leaving NumPy scalars inside them.
        return tuple(convert_float32_to_float(item) for item in data)
    if isinstance(data, np.ndarray):
        # tolist() already yields nested lists of Python scalars.
        return data.tolist()
    if isinstance(data, np.floating):
        # Covers float16/float32/float64, not only float32.
        return float(data)
    # Everything else is passed through as-is.
    return data
|
||
|
||
|
||
class WiredTableRecognition:
    """Recognise the cell structure and text of a wired (ruled) table image.

    The instance wires together a table-line model (``TableLineRecognitionPlus``
    for ``version="v2"``, ``TableLineRecognition`` otherwise), a
    ``TableRecover`` step that derives logical cell coordinates, and an
    optional RapidOCR engine. ``self.ocr`` is ``None`` when
    ``rapidocr_onnxruntime`` is not installed.
    """

    def __init__(
        self, table_model_path: Optional[Union[str, Path]] = None, version="v2"
    ):
        """Load the table-line model and, if available, the OCR engine.

        Args:
            table_model_path: Path to the ONNX model. When falsy, the bundled
                default model for the selected ``version`` is used.
            version: ``"v2"`` selects ``TableLineRecognitionPlus``; any other
                value selects ``TableLineRecognition``.
        """
        self.load_img = LoadImage()
        if version == "v2":
            model_path = table_model_path if table_model_path else default_model_path_v2
            self.table_line_rec = TableLineRecognitionPlus(str(model_path))
        else:
            model_path = table_model_path if table_model_path else default_model_path
            self.table_line_rec = TableLineRecognition(str(model_path))

        self.table_recover = TableRecover()

        # OCR is optional: callers may pass pre-computed OCR results instead.
        try:
            self.ocr = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
        except ModuleNotFoundError:
            self.ocr = None

    def __call__(
        self,
        img: InputType,
        ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
        **kwargs,
    ) -> Tuple[str, float, Any, Any, Any, Any, Any]:
        """Run table recognition on ``img``.

        Args:
            img: Image input accepted by ``LoadImage``.
            ocr_result: Pre-computed OCR results. When ``None``, the built-in
                OCR engine is run on the image.
            **kwargs: Optional settings ``rec_again`` (default ``True``),
                ``need_ocr`` (default ``True``), ``col_threshold`` (default
                ``15``) and ``row_threshold`` (default ``10``). All keyword
                arguments are also forwarded to the table-line model.

        Returns:
            A 7-tuple ``(table_html, elapse_seconds, cell_boxes,
            logic_points, ocr_boxes, cell_map, unmatched_ocr_boxes)``.
            If no polygons are detected, or any processing step raises, the
            result is ``("", 0.0, None, None, None, None, None)``. With
            ``need_ocr=False`` the HTML is ``""`` and the last three items
            are ``[]``, ``{}`` and ``[]``.

        Raises:
            ValueError: If neither ``ocr_result`` nor an OCR engine is available.
        """
        if self.ocr is None and ocr_result is None:
            raise ValueError(
                "One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
            )

        s = time.perf_counter()
        rec_again = kwargs.get("rec_again", True)
        need_ocr = kwargs.get("need_ocr", True)
        col_threshold = kwargs.get("col_threshold", 15)
        row_threshold = kwargs.get("row_threshold", 10)

        img = self.load_img(img)
        polygons, rotated_polygons = self.table_line_rec(img, **kwargs)
        if polygons is None:
            logging.warning("polygons is None.")
            return "", 0.0, None, None, None, None, None

        # Deliberate best-effort boundary: any failure below is logged and
        # reported to the caller as an empty result instead of propagating.
        try:
            _, logi_points = self.table_recover(
                rotated_polygons, row_threshold, col_threshold
            )
            # Swap points 1 and 3 so each polygon goes from counter-clockwise
            # to clockwise order, aligning later steps with the wireless-table
            # pipeline (per the original author's note).
            polygons[:, 1, :], polygons[:, 3, :] = (
                polygons[:, 3, :].copy(),
                polygons[:, 1, :].copy(),
            )
            if not need_ocr:
                sorted_polygons, idx_list = sorted_ocr_boxes(
                    [box_4_2_poly_to_box_4_1(box) for box in polygons]
                )
                return (
                    "",
                    time.perf_counter() - s,
                    sorted_polygons,
                    logi_points[idx_list],
                    [],
                    {},
                    [],
                )
            if ocr_result is None:
                ocr_result, _ = self.ocr(img)
                if ocr_result is None:
                    # NOTE(review): the OCR engine appears to return None when
                    # it finds no text; treat that as "no OCR boxes" so empty
                    # cells can still be filled in below. Confirm against
                    # match_ocr_cell / sorted_ocr_boxes handling of [].
                    ocr_result = []
            cell_box_det_map, not_match_orc_boxes = match_ocr_cell(ocr_result, polygons)
            # Cells with no matched OCR text are recognised directly (or given
            # an empty placeholder), so every polygon index has an entry.
            cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map, rec_again)
            # Intermediate format: one dict per cell combining the physical
            # box, the logical box and the OCR boxes/texts.
            t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
            # Sort the OCR results inside each cell and merge those on the
            # same row so the HTML keeps the text's line breaks.
            t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)

            logi_points = [t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list]
            cell_box_det_map = {
                i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]]
                for i, t_box_ocr in enumerate(t_rec_ocr_list)
            }
            table_str = plot_html_table(logi_points, cell_box_det_map)
            ocr_boxes_res = [
                box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_result
            ]
            sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res)
            sorted_polygons = [box_4_2_poly_to_box_4_1(box) for box in polygons]
            sorted_logi_points = logi_points
            table_elapse = time.perf_counter() - s

            # Pair each cell's box (as plain Python floats) with its texts.
            # cell_box_det_map is keyed 0..n-1, which lines up with
            # sorted_polygons because re_rec gives every polygon an entry.
            cell_box_det_map_with_cor = {
                i: (convert_float32_to_float(sorted_polygons[i]), texts)
                for i, texts in cell_box_det_map.items()
            }
        except Exception:
            logging.warning(traceback.format_exc())
            return "", 0.0, None, None, None, None, None
        return (
            table_str,
            table_elapse,
            sorted_polygons,
            sorted_logi_points,
            sorted_ocr_boxes_res,
            cell_box_det_map_with_cor,
            not_match_orc_boxes,
        )

    def transform_res(
        self,
        cell_box_det_map: Dict[int, List[Any]],
        polygons: np.ndarray,
        logi_points: List[np.ndarray],
    ) -> List[Dict[str, Any]]:
        """Merge per-cell OCR results and logical coordinates into dicts.

        Args:
            cell_box_det_map: Maps a polygon index to its OCR entries; each
                entry starts with a 4-point box followed by the text.
            polygons: Detected cell polygons; only its length is used here.
            logi_points: Logical coordinates, indexed like ``polygons``.

        Returns:
            One dict per cell that has OCR entries, with keys ``"t_box"``,
            ``"t_logic_box"`` and ``"t_ocr_res"``. Cells without entries are
            skipped.
        """
        res = []
        for i in range(len(polygons)):
            ocr_res_list = cell_box_det_map.get(i)
            if not ocr_res_list:
                continue
            # Bounding box over the first and third corner of every OCR box.
            xmin = min(ocr_box[0][0][0] for ocr_box in ocr_res_list)
            ymin = min(ocr_box[0][0][1] for ocr_box in ocr_res_list)
            xmax = max(ocr_box[0][2][0] for ocr_box in ocr_res_list)
            ymax = max(ocr_box[0][2][1] for ocr_box in ocr_res_list)
            dict_res = {
                # xmin, ymin, xmax, ymax
                "t_box": [xmin, ymin, xmax, ymax],
                # row_start, row_end, col_start, col_end (per the original
                # author's note; the layout is produced by TableRecover)
                "t_logic_box": logi_points[i].tolist(),
                # [[flattened box], text] for each OCR entry in the cell
                "t_ocr_res": [
                    [box_4_2_poly_to_box_4_1(ocr_det[0]), ocr_det[1]]
                    for ocr_det in ocr_res_list
                ],
            }
            res.append(dict_res)
        return res

    def sort_and_gather_ocr_res(self, res):
        """Sort each cell's OCR entries and merge entries on the same row.

        Args:
            res: List of cell dicts as produced by ``transform_res``. The
                dicts are updated in place.

        Returns:
            The same list ``res``.
        """
        for dict_res in res:
            _, sorted_idx = sorted_ocr_boxes(
                [ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threhold=0.3
            )
            ordered = [dict_res["t_ocr_res"][j] for j in sorted_idx]
            dict_res["t_ocr_res"] = gather_ocr_list_by_row(ordered, threhold=0.3)
        return res

    def re_rec(
        self,
        img: np.ndarray,
        sorted_polygons: np.ndarray,
        cell_box_map: Dict[int, List[str]],
        rec_again=True,
    ) -> Dict[int, List[Any]]:
        """Fill in cells that have no matched OCR text.

        For every polygon index without an entry in ``cell_box_map``, the
        cell crop is sent straight to text recognition (detection disabled).
        An empty placeholder ``[[box, "", 1]]`` is stored instead when
        ``rec_again`` is false, when no OCR engine is available, or when
        recognition yields nothing.

        Args:
            img: Source image.
            sorted_polygons: Cell polygons; the first axis indexes cells.
            cell_box_map: Existing matches, updated in place.
            rec_again: Whether to run recognition on unmatched cells.

        Returns:
            ``cell_box_map`` with an entry for every polygon index.
        """
        # Without an engine there is nothing to re-run; fall back to
        # placeholders rather than failing the whole table.
        can_recognise = rec_again and self.ocr is not None
        for i in range(sorted_polygons.shape[0]):
            if cell_box_map.get(i):
                continue
            box = sorted_polygons[i]
            if not can_recognise:
                cell_box_map[i] = [[box, "", 1]]
                continue
            crop_img = get_rotate_crop_image(img, box)
            # White padding around the crop before recognition.
            pad_img = cv2.copyMakeBorder(
                crop_img, 5, 5, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255)
            )
            rec_res, _ = self.ocr(pad_img, use_det=False, use_cls=True, use_rec=True)
            if not rec_res:
                # Nothing recognised: min() over an empty/None result would
                # raise and abort the whole table.
                cell_box_map[i] = [[box, "", 1]]
                continue
            text = [rec[0] for rec in rec_res]
            scores = [rec[1] for rec in rec_res]
            cell_box_map[i] = [[box, "".join(text), min(scores)]]
        return cell_box_map

    def re_rec_high_precise(
        self,
        img: np.ndarray,
        sorted_polygons: np.ndarray,
        cell_box_map: Dict[int, List[str]],
    ) -> Dict[int, List[Any]]:
        """Run detection and recognition on every cell crop.

        NOTE: the incoming ``cell_box_map`` is ignored. As in the original
        implementation, a fresh mapping is built and every cell is
        re-processed.

        Args:
            img: Source image.
            sorted_polygons: Cell polygons; the first axis indexes cells.
            cell_box_map: Unused; kept for interface compatibility.

        Returns:
            A new dict mapping each polygon index to a list of
            ``[box, text, score]`` entries. A cell with no OCR output gets a
            single ``[polygon, "", 1.0]`` entry.
        """
        recognised = {}
        for i in range(sorted_polygons.shape[0]):
            crop_img = get_rotate_crop_image(img, sorted_polygons[i])
            pad_img = cv2.copyMakeBorder(
                crop_img, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=(255, 255, 255)
            )
            rec_res, _ = self.ocr(pad_img, use_det=True, use_cls=True, use_rec=True)
            if not rec_res:
                recognised[i] = [[sorted_polygons[i], "", 1.0]]
            else:
                recognised[i] = [[rec[0], rec[1], rec[2]] for rec in rec_res]
        return recognised
|
||
|
||
|
||
def main():
    """Command-line entry point: recognise one table image and print it.

    Reads the image path from ``-img`` / ``--img_path``, runs RapidOCR and
    ``WiredTableRecognition`` on it, then prints the HTML table followed by
    the elapsed time.

    Raises:
        ModuleNotFoundError: If ``rapidocr_onnxruntime`` is not installed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-img", "--img_path", type=str, required=True)
    args = parser.parse_args()

    try:
        ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
    except ModuleNotFoundError as exc:
        raise ModuleNotFoundError(
            "Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime."
        ) from exc

    table_rec = WiredTableRecognition()
    ocr_result, _ = ocr_engine(args.img_path)
    # The recogniser returns a 7-tuple; only the HTML and the timing are
    # needed here. Unpacking into exactly two names raised ValueError.
    table_str, elapse, *_ = table_rec(args.img_path, ocr_result)
    print(table_str)
    print(f"cost: {elapse:.5f}")


if __name__ == "__main__":
    main()
|