276 lines
11 KiB
Python
Raw Normal View History

2025-03-24 09:27:03 +08:00
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import argparse
import importlib
import logging
import time
import traceback
from pathlib import Path
from typing import List, Optional, Tuple, Union, Dict, Any
import numpy as np
import cv2
from wired_table_rec.table_line_rec import TableLineRecognition
from wired_table_rec.table_line_rec_plus import TableLineRecognitionPlus
from .table_recover import TableRecover
from .utils import InputType, LoadImage
from .utils_table_recover import (
match_ocr_cell,
plot_html_table,
box_4_2_poly_to_box_4_1,
get_rotate_crop_image,
sorted_ocr_boxes,
gather_ocr_list_by_row,
)
cur_dir = Path(__file__).resolve().parent
default_model_path = cur_dir / "models" / "cycle_center_net_v1.onnx"
default_model_path_v2 = cur_dir / "models" / "cycle_center_net_v2.onnx"
def convert_float32_to_float(data):
if isinstance(data, list):
return [convert_float32_to_float(item) for item in data]
elif isinstance(data, np.ndarray):
return data.tolist() # 将 NumPy 数组转换为 Python 列表
elif isinstance(data, np.float32):
return float(data) # 将 float32 转换为 float
else:
return data # 其他类型保持不变
class WiredTableRecognition:
def __init__(self, table_model_path: Union[str, Path] = None, version="v2"):
self.load_img = LoadImage()
if version == "v2":
model_path = table_model_path if table_model_path else default_model_path_v2
self.table_line_rec = TableLineRecognitionPlus(str(model_path))
else:
model_path = table_model_path if table_model_path else default_model_path
self.table_line_rec = TableLineRecognition(str(model_path))
self.table_recover = TableRecover()
try:
self.ocr = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
except ModuleNotFoundError:
self.ocr = None
def __call__(
self,
img: InputType,
ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
**kwargs,
) -> Tuple[str, float, Any, Any, Any, Any, Any]:
if self.ocr is None and ocr_result is None:
raise ValueError(
"One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
)
s = time.perf_counter()
rec_again = True
need_ocr = True
col_threshold = 15
row_threshold = 10
if kwargs:
rec_again = kwargs.get("rec_again", True)
need_ocr = kwargs.get("need_ocr", True)
col_threshold = kwargs.get("col_threshold", 15)
row_threshold = kwargs.get("row_threshold", 10)
img = self.load_img(img)
polygons, rotated_polygons = self.table_line_rec(img, **kwargs)
if polygons is None:
logging.warning("polygons is None.")
return "", 0.0, None, None, None, None, None
try:
table_res, logi_points = self.table_recover(
rotated_polygons, row_threshold, col_threshold
)
# 将坐标由逆时针转为顺时针方向,后续处理与无线表格对齐
polygons[:, 1, :], polygons[:, 3, :] = (
polygons[:, 3, :].copy(),
polygons[:, 1, :].copy(),
)
if not need_ocr:
sorted_polygons, idx_list = sorted_ocr_boxes(
[box_4_2_poly_to_box_4_1(box) for box in polygons]
)
return (
"",
time.perf_counter() - s,
sorted_polygons,
logi_points[idx_list],
[],
{},
[],
)
if ocr_result is None and need_ocr:
ocr_result, _ = self.ocr(img)
cell_box_det_map, not_match_orc_boxes = match_ocr_cell(ocr_result, polygons)
# 如果有识别框没有ocr结果直接进行rec补充
cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map, rec_again)
# 转换为中间格式,修正识别框坐标,将物理识别框逻辑识别框ocr识别框整合为dict方便后续处理
t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
# 将每个单元格中的ocr识别结果排序和同行合并输出的html能完整保留文字的换行格式
t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
# cell_box_map =
logi_points = [t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list]
cell_box_det_map = {
i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]]
for i, t_box_ocr in enumerate(t_rec_ocr_list)
}
# cell_box_det_map_with_cor = {
# i: [[ocr_box_and_text[0],ocr_box_and_text[1]] for ocr_box_and_text in t_box_ocr["t_ocr_res"]]
# for i, t_box_ocr in enumerate(t_rec_ocr_list)
# }
table_str = plot_html_table(logi_points, cell_box_det_map)
ocr_boxes_res = [
box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_result
]
sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res)
sorted_polygons = [box_4_2_poly_to_box_4_1(box) for box in polygons]
sorted_logi_points = logi_points
table_elapse = time.perf_counter() - s
# for i, box in enumerate(not_match_orc_boxes):
# cell_box_det_map[len(cell_box_det_map) + 1] = box[1]
cell_box_det_map_with_cor = {}
# 遍历cell_box_det_map的每个key
for i, box in enumerate(cell_box_det_map):
cell_box_det_map_with_cor[i] = convert_float32_to_float(sorted_polygons[i]), cell_box_det_map[i]
# print("not_match_orc_boxes", not_match_orc_boxes)
except Exception:
logging.warning(traceback.format_exc())
return "", 0.0, None, None, None, None, None
return (
table_str,
table_elapse,
sorted_polygons,
sorted_logi_points,
sorted_ocr_boxes_res,
cell_box_det_map_with_cor,
not_match_orc_boxes
)
def transform_res(
self,
cell_box_det_map: Dict[int, List[any]],
polygons: np.ndarray,
logi_points: List[np.ndarray],
) -> List[Dict[str, any]]:
res = []
for i in range(len(polygons)):
ocr_res_list = cell_box_det_map.get(i)
if not ocr_res_list:
continue
xmin = min([ocr_box[0][0][0] for ocr_box in ocr_res_list])
ymin = min([ocr_box[0][0][1] for ocr_box in ocr_res_list])
xmax = max([ocr_box[0][2][0] for ocr_box in ocr_res_list])
ymax = max([ocr_box[0][2][1] for ocr_box in ocr_res_list])
dict_res = {
# xmin,xmax,ymin,ymax
"t_box": [xmin, ymin, xmax, ymax],
# row_start,row_end,col_start,col_end
"t_logic_box": logi_points[i].tolist(),
# [[xmin,xmax,ymin,ymax], text]
"t_ocr_res": [
[box_4_2_poly_to_box_4_1(ocr_det[0]), ocr_det[1]]
for ocr_det in ocr_res_list
],
}
res.append(dict_res)
return res
def sort_and_gather_ocr_res(self, res):
for i, dict_res in enumerate(res):
_, sorted_idx = sorted_ocr_boxes(
[ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threhold=0.3
)
dict_res["t_ocr_res"] = [dict_res["t_ocr_res"][i] for i in sorted_idx]
dict_res["t_ocr_res"] = gather_ocr_list_by_row(
dict_res["t_ocr_res"], threhold=0.3
)
return res
def re_rec(
self,
img: np.ndarray,
sorted_polygons: np.ndarray,
cell_box_map: Dict[int, List[str]],
rec_again=True,
) -> Dict[int, List[Any]]:
"""找到poly对应为空的框尝试将直接将poly框直接送到识别中"""
for i in range(sorted_polygons.shape[0]):
if cell_box_map.get(i):
continue
if not rec_again:
box = sorted_polygons[i]
cell_box_map[i] = [[box, "", 1]]
continue
crop_img = get_rotate_crop_image(img, sorted_polygons[i])
pad_img = cv2.copyMakeBorder(
crop_img, 5, 5, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255)
)
rec_res, _ = self.ocr(pad_img, use_det=False, use_cls=True, use_rec=True)
box = sorted_polygons[i]
text = [rec[0] for rec in rec_res]
scores = [rec[1] for rec in rec_res]
cell_box_map[i] = [[box, "".join(text), min(scores)]]
return cell_box_map
def re_rec_high_precise(
self,
img: np.ndarray,
sorted_polygons: np.ndarray,
cell_box_map: Dict[int, List[str]],
) -> Dict[int, List[any]]:
"""找到poly对应为空的框尝试将直接将poly框直接送到识别中"""
#
cell_box_map = {}
for i in range(sorted_polygons.shape[0]):
if cell_box_map.get(i):
continue
crop_img = get_rotate_crop_image(img, sorted_polygons[i])
pad_img = cv2.copyMakeBorder(
crop_img, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=(255, 255, 255)
)
rec_res, _ = self.ocr(pad_img, use_det=True, use_cls=True, use_rec=True)
if not rec_res:
det_boxes = [sorted_polygons[i]]
text = [""]
scores = [1.0]
else:
det_boxes = [rec[0] for rec in rec_res]
text = [rec[1] for rec in rec_res]
scores = [rec[2] for rec in rec_res]
cell_box_map[i] = [
[box, text, score] for box, text, score in zip(det_boxes, text, scores)
]
return cell_box_map
def main():
parser = argparse.ArgumentParser()
parser.add_argument("-img", "--img_path", type=str, required=True)
args = parser.parse_args()
try:
ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime."
) from exc
table_rec = WiredTableRecognition()
ocr_result, _ = ocr_engine(args.img_path)
table_str, elapse = table_rec(args.img_path, ocr_result)
print(table_str)
print(f"cost: {elapse:.5f}")
if __name__ == "__main__":
main()