216 lines
8.1 KiB
Python
216 lines
8.1 KiB
Python
|
|
import copy
|
||
|
|
import math
|
||
|
|
from typing import Optional, Dict, Any, Tuple
|
||
|
|
|
||
|
|
import cv2
|
||
|
|
import numpy as np
|
||
|
|
from skimage import measure
|
||
|
|
from wired_table_rec.utils import OrtInferSession, resize_img
|
||
|
|
from wired_table_rec.utils_table_line_rec import (
|
||
|
|
get_table_line,
|
||
|
|
final_adjust_lines,
|
||
|
|
min_area_rect_box,
|
||
|
|
draw_lines,
|
||
|
|
adjust_lines,
|
||
|
|
)
|
||
|
|
from wired_table_rec.utils_table_recover import (
|
||
|
|
sorted_ocr_boxes,
|
||
|
|
box_4_2_poly_to_box_4_1,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
class TableLineRecognitionPlus:
|
||
|
|
def __init__(self, model_path: Optional[str] = None):
|
||
|
|
self.K = 1000
|
||
|
|
self.MK = 4000
|
||
|
|
self.mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
|
||
|
|
self.std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
|
||
|
|
self.inp_height = 1024
|
||
|
|
self.inp_width = 1024
|
||
|
|
|
||
|
|
self.session = OrtInferSession(model_path)
|
||
|
|
|
||
|
|
def __call__(
|
||
|
|
self, img: np.ndarray, **kwargs
|
||
|
|
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
|
||
|
|
img_info = self.preprocess(img)
|
||
|
|
pred = self.infer(img_info)
|
||
|
|
polygons, rotated_polygons = self.postprocess(img, pred, **kwargs)
|
||
|
|
if polygons.size == 0:
|
||
|
|
return None, None
|
||
|
|
polygons = polygons.reshape(polygons.shape[0], 4, 2)
|
||
|
|
polygons[:, 3, :], polygons[:, 1, :] = (
|
||
|
|
polygons[:, 1, :].copy(),
|
||
|
|
polygons[:, 3, :].copy(),
|
||
|
|
)
|
||
|
|
rotated_polygons = rotated_polygons.reshape(rotated_polygons.shape[0], 4, 2)
|
||
|
|
rotated_polygons[:, 3, :], rotated_polygons[:, 1, :] = (
|
||
|
|
rotated_polygons[:, 1, :].copy(),
|
||
|
|
rotated_polygons[:, 3, :].copy(),
|
||
|
|
)
|
||
|
|
_, idx = sorted_ocr_boxes(
|
||
|
|
[box_4_2_poly_to_box_4_1(poly_box) for poly_box in rotated_polygons],
|
||
|
|
threhold=0.4,
|
||
|
|
)
|
||
|
|
polygons = polygons[idx]
|
||
|
|
rotated_polygons = rotated_polygons[idx]
|
||
|
|
return polygons, rotated_polygons
|
||
|
|
|
||
|
|
def preprocess(self, img) -> Dict[str, Any]:
|
||
|
|
scale = (self.inp_height, self.inp_width)
|
||
|
|
img, _, _ = resize_img(img, scale, True)
|
||
|
|
img = img.copy().astype(np.float32)
|
||
|
|
assert img.dtype != np.uint8
|
||
|
|
mean = np.float64(self.mean.reshape(1, -1))
|
||
|
|
stdinv = 1 / np.float64(self.std.reshape(1, -1))
|
||
|
|
cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace
|
||
|
|
cv2.subtract(img, mean, img) # inplace
|
||
|
|
cv2.multiply(img, stdinv, img) # inplace
|
||
|
|
img = img.transpose(2, 0, 1)
|
||
|
|
images = img[None, :]
|
||
|
|
return {"img": images}
|
||
|
|
|
||
|
|
def infer(self, input):
|
||
|
|
result = self.session(input["img"][None, ...])[0][0]
|
||
|
|
result = result[0].astype(np.uint8)
|
||
|
|
return result
|
||
|
|
|
||
|
|
def postprocess(self, img, pred, **kwargs):
|
||
|
|
row = kwargs.get("row", 50) if kwargs else 50
|
||
|
|
col = kwargs.get("col", 30) if kwargs else 30
|
||
|
|
h_lines_threshold = kwargs.get("h_lines_threshold", 100) if kwargs else 100
|
||
|
|
v_lines_threshold = kwargs.get("v_lines_threshold", 15) if kwargs else 15
|
||
|
|
angle = kwargs.get("angle", 50) if kwargs else 50
|
||
|
|
enhance_box_line = kwargs.get("enhance_box_line", True) if kwargs else True
|
||
|
|
morph_close = (
|
||
|
|
kwargs.get("morph_close", enhance_box_line) if kwargs else enhance_box_line
|
||
|
|
) # 是否进行闭合运算以找到更多小的框
|
||
|
|
more_h_lines = (
|
||
|
|
kwargs.get("more_h_lines", enhance_box_line) if kwargs else enhance_box_line
|
||
|
|
) # 是否调整以找到更多的横线
|
||
|
|
more_v_lines = (
|
||
|
|
kwargs.get("more_v_lines", enhance_box_line) if kwargs else enhance_box_line
|
||
|
|
) # 是否调整以找到更多的横线
|
||
|
|
extend_line = (
|
||
|
|
kwargs.get("extend_line", enhance_box_line) if kwargs else enhance_box_line
|
||
|
|
) # 是否进行线段延长使得端点连接
|
||
|
|
# 是否进行旋转修正
|
||
|
|
rotated_fix = kwargs.get("rotated_fix") if kwargs else True
|
||
|
|
ori_shape = img.shape
|
||
|
|
pred = np.uint8(pred)
|
||
|
|
hpred = copy.deepcopy(pred) # 横线
|
||
|
|
vpred = copy.deepcopy(pred) # 竖线
|
||
|
|
whereh = np.where(hpred == 1)
|
||
|
|
wherev = np.where(vpred == 2)
|
||
|
|
hpred[wherev] = 0
|
||
|
|
vpred[whereh] = 0
|
||
|
|
|
||
|
|
hpred = cv2.resize(hpred, (ori_shape[1], ori_shape[0]))
|
||
|
|
vpred = cv2.resize(vpred, (ori_shape[1], ori_shape[0]))
|
||
|
|
|
||
|
|
h, w = pred.shape
|
||
|
|
hors_k = int(math.sqrt(w) * 1.2)
|
||
|
|
vert_k = int(math.sqrt(h) * 1.2)
|
||
|
|
hkernel = cv2.getStructuringElement(cv2.MORPH_RECT, (hors_k, 1))
|
||
|
|
vkernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, vert_k))
|
||
|
|
vpred = cv2.morphologyEx(
|
||
|
|
vpred, cv2.MORPH_CLOSE, vkernel, iterations=1
|
||
|
|
) # 先膨胀后腐蚀的过程
|
||
|
|
if morph_close:
|
||
|
|
hpred = cv2.morphologyEx(hpred, cv2.MORPH_CLOSE, hkernel, iterations=1)
|
||
|
|
colboxes = get_table_line(vpred, axis=1, lineW=col) # 竖线
|
||
|
|
rowboxes = get_table_line(hpred, axis=0, lineW=row) # 横线
|
||
|
|
rboxes_row_, rboxes_col_ = [], []
|
||
|
|
if more_h_lines:
|
||
|
|
rboxes_row_ = adjust_lines(rowboxes, alph=h_lines_threshold, angle=angle)
|
||
|
|
if more_v_lines:
|
||
|
|
rboxes_col_ = adjust_lines(colboxes, alph=v_lines_threshold, angle=angle)
|
||
|
|
rowboxes += rboxes_row_
|
||
|
|
colboxes += rboxes_col_
|
||
|
|
if extend_line:
|
||
|
|
rowboxes, colboxes = final_adjust_lines(rowboxes, colboxes)
|
||
|
|
line_img = np.zeros(img.shape[:2], dtype="uint8")
|
||
|
|
line_img = draw_lines(line_img, rowboxes + colboxes, color=255, lineW=2)
|
||
|
|
rotated_angle = self.cal_rotate_angle(line_img)
|
||
|
|
if rotated_fix and abs(rotated_angle) > 0.3:
|
||
|
|
rotated_line_img = self.rotate_image(line_img, rotated_angle)
|
||
|
|
# print("rotated_line_img:", rotated_line_img)
|
||
|
|
rotated_polygons = self.cal_region_boxes(rotated_line_img)
|
||
|
|
# print('rotated_polygons:', rotated_polygons)
|
||
|
|
polygons = self.unrotate_polygons(
|
||
|
|
rotated_polygons, rotated_angle, line_img.shape
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
polygons = self.cal_region_boxes(line_img)
|
||
|
|
rotated_polygons = polygons.copy()
|
||
|
|
return polygons, rotated_polygons
|
||
|
|
|
||
|
|
def cal_region_boxes(self, tmp):
|
||
|
|
labels = measure.label(tmp < 255, connectivity=2) # 8连通区域标记
|
||
|
|
regions = measure.regionprops(labels)
|
||
|
|
ceilboxes = min_area_rect_box(
|
||
|
|
regions,
|
||
|
|
False,
|
||
|
|
tmp.shape[1],
|
||
|
|
tmp.shape[0],
|
||
|
|
filtersmall=True,
|
||
|
|
adjust_box=False,
|
||
|
|
) # 最后一个参数改为False
|
||
|
|
return np.array(ceilboxes)
|
||
|
|
|
||
|
|
def cal_rotate_angle(self, tmp):
|
||
|
|
# 计算最外侧的旋转框
|
||
|
|
contours, _ = cv2.findContours(tmp, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||
|
|
if not contours:
|
||
|
|
return 0
|
||
|
|
largest_contour = max(contours, key=cv2.contourArea)
|
||
|
|
rect = cv2.minAreaRect(largest_contour)
|
||
|
|
# 计算旋转角度
|
||
|
|
angle = rect[2]
|
||
|
|
if angle < -45:
|
||
|
|
angle += 90
|
||
|
|
elif angle > 45:
|
||
|
|
angle -= 90
|
||
|
|
return angle
|
||
|
|
|
||
|
|
def rotate_image(self, image, angle):
|
||
|
|
# 获取图像的中心点
|
||
|
|
(h, w) = image.shape[:2]
|
||
|
|
center = (w // 2, h // 2)
|
||
|
|
|
||
|
|
# 计算旋转矩阵
|
||
|
|
M = cv2.getRotationMatrix2D(center, angle, 1.0)
|
||
|
|
|
||
|
|
# 进行旋转
|
||
|
|
rotated_image = cv2.warpAffine(
|
||
|
|
image, M, (w, h), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_REPLICATE
|
||
|
|
)
|
||
|
|
|
||
|
|
return rotated_image
|
||
|
|
|
||
|
|
def unrotate_polygons(
|
||
|
|
self, polygons: np.ndarray, angle: float, img_shape: tuple
|
||
|
|
) -> np.ndarray:
|
||
|
|
if polygons.size == 0:
|
||
|
|
return polygons
|
||
|
|
|
||
|
|
# 将多边形旋转回原始位置
|
||
|
|
(h, w) = img_shape
|
||
|
|
center = (w // 2, h // 2)
|
||
|
|
M_inv = cv2.getRotationMatrix2D(center, -angle, 1.0)
|
||
|
|
|
||
|
|
# 将 (N, 8) 转换为 (N, 4, 2)
|
||
|
|
polygons_reshaped = polygons.reshape(-1, 4, 2)
|
||
|
|
|
||
|
|
# print('polygons_reshaped:', polygons_reshaped)
|
||
|
|
|
||
|
|
# 批量逆旋转
|
||
|
|
unrotated_polygons = cv2.transform(polygons_reshaped, M_inv)
|
||
|
|
|
||
|
|
# print('unrotated_polygons:', unrotated_polygons)
|
||
|
|
|
||
|
|
# 将 (N, 4, 2) 转换回 (N, 8)
|
||
|
|
unrotated_polygons = unrotated_polygons.reshape(-1, 8)
|
||
|
|
|
||
|
|
return unrotated_polygons
|