# Origin: nwjh/LLMServe/wired_table_rec/table_recover.py (Python, 218 lines, 8.8 KiB; revision dated 2025-03-24 09:27:03 +08:00)
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from typing import Dict, List, Tuple
import numpy as np
class TableRecover:
    """Recover the logical grid (row/column spans) of a table from cell polygons.

    Every polygon is indexed as ``polygon[point, xy]``.  The code reads
    point 0 as the top-left corner, point 1 as the bottom-left corner,
    point 2 as the bottom-right corner and point 3 as the top-right corner
    (counter-clockwise order).  Row grouping compares the top-left ``y`` of
    *consecutive* polygons, so cells are presumably expected in reading
    order (top-to-bottom, left-to-right) -- confirm against the caller.
    """

    def __init__(
        self,
    ):
        pass

    def __call__(
        self, polygons: np.ndarray, rows_thresh=10, col_thresh=15, merge_thresh=10
    ) -> Tuple[Dict[int, Dict[int, List[int]]], np.ndarray]:
        """Compute the merge spans and logical coordinates of every cell.

        Args:
            polygons: array indexed as ``[cell, point, xy]``.
            rows_thresh: jump in top-left ``y`` between consecutive cells
                above which a new row starts.
            col_thresh: tolerance used when matching an ``x`` coordinate
                against an already known column boundary.
            merge_thresh: tolerance used when matching a cell's width/height
                against a cumulative sum of column widths/row heights.

        Returns:
            ``(table_res, logic_points)`` where ``table_res`` maps
            ``row position -> {cell index: [col_span, row_span]}`` and
            ``logic_points`` is an ``int32`` array with one
            ``[row_start, row_end, col_start, col_end]`` entry per cell.
            Empty input yields ``({}, empty (0, 4) array)``.
        """
        # Nothing to recover: the helpers below index into the first/last
        # cell and would raise on an empty input.
        if len(polygons) == 0:
            return {}, np.zeros((0, 4), dtype=np.int32)

        rows = self.get_rows(polygons, rows_thresh)
        longest_col, each_col_widths, col_nums = self.get_benchmark_cols(
            rows, polygons, col_thresh
        )
        each_row_heights, row_nums = self.get_benchmark_rows(rows, polygons)
        table_res, logic_points_dict = self.get_merge_cells(
            polygons,
            rows,
            row_nums,
            col_nums,
            longest_col,
            each_col_widths,
            each_row_heights,
            merge_thresh=merge_thresh,
        )
        logic_points = np.array(
            [logic_points_dict[i] for i in range(len(polygons))]
        ).astype(np.int32)
        return table_res, logic_points

    @staticmethod
    def get_rows(polygons: np.ndarray, rows_thresh=10) -> Dict[int, List[int]]:
        """Group consecutive cell indices into rows.

        A new row starts wherever the top-left ``y`` of two consecutive
        polygons differs by more than ``rows_thresh``.

        Returns:
            ``{row number: [cell indices]}``; ``{}`` for empty input.
        """
        if len(polygons) == 0:
            return {}

        y_axis = polygons[:, 0, 1]
        if y_axis.size == 1:
            return {0: [0]}

        # Difference between each cell's top-left y and its predecessor's.
        y_steps = y_axis[1:] - y_axis[:-1]
        # reshape(-1) always yields a 1-D index array, including for zero or
        # one split positions.
        split_idxs = np.argwhere(np.abs(y_steps) > rows_thresh).reshape(-1)

        # No jump anywhere: every cell belongs to the same row.
        if split_idxs.size == 0:
            return {0: list(range(len(y_axis)))}

        # Each split index is the last cell of a row.  Split indices are at
        # most len(y_steps) - 1, so the final cell index (== len(y_steps))
        # always has to be added to close the last row.
        row_ends = split_idxs.tolist() + [len(y_steps)]

        result = {}
        start_idx = 0
        for row_num, end_idx in enumerate(row_ends):
            result[row_num] = list(range(start_idx, end_idx + 1))
            start_idx = end_idx + 1
        # NOTE(review): the original carried a comment about merging adjacent
        # cells of a row whose IoU exceeds 0.2; no such step is implemented.
        return result

    def get_benchmark_cols(
        self, rows: Dict[int, List], polygons: np.ndarray, col_thresh=15
    ) -> Tuple[np.ndarray, List[float], int]:
        """Derive the column boundaries of the table.

        Starts from the left ``x`` of every cell in the row holding the most
        cells, then refines that list with the start and end ``x`` of every
        cell in the table.

        Returns:
            ``(col_x_starts, each_col_widths, col_nums)``: the left ``x`` of
            each column, the width of each column, and the column count.
        """
        # Despite the historical "longest_col" naming, this is the row with
        # the most cells.
        longest_row = max(rows.values(), key=len)
        longest_row_points = polygons[longest_row]
        col_x_starts = list(longest_row_points[:, 0, 0])
        min_x = col_x_starts[0]
        # Right edge of the table: bottom-right x of the last cell in that row.
        max_x = longest_row_points[-1, 2, 0]

        # Update the column boundaries with one more x coordinate.
        def update_longest_col(col_x_list, cur_v, min_x_, max_x_):
            for i, v in enumerate(col_x_list):
                # Already represented by a boundary within tolerance.
                if cur_v - col_thresh <= v <= cur_v + col_thresh:
                    break
                # Still to the right of this boundary: keep scanning.
                # NOTE(review): a value larger than every boundary falls out
                # of the loop here and is never appended.
                if cur_v > v:
                    continue
                if cur_v < min_x_:
                    min_x_ = cur_v
                    break
                if cur_v > max_x_:
                    max_x_ = cur_v
                    break
                # Strictly between two known boundaries: new column start.
                if cur_v < v:
                    col_x_list.insert(i, cur_v)
                    break
            return min_x_, max_x_

        for row_value in rows.values():
            row_points = polygons[row_value]
            for cur_v_start, cur_v_end in zip(
                row_points[:, 0, 0], row_points[:, 2, 0]
            ):
                min_x, max_x = update_longest_col(
                    col_x_starts, cur_v_start, min_x, max_x
                )
                min_x, max_x = update_longest_col(
                    col_x_starts, cur_v_end, min_x, max_x
                )

        col_x_starts = np.array(col_x_starts)
        each_col_widths = (col_x_starts[1:] - col_x_starts[:-1]).tolist()
        # The last column runs from its start to the right edge of the table.
        each_col_widths.append(max_x - col_x_starts[-1])
        col_nums = col_x_starts.shape[0]
        return col_x_starts, each_col_widths, col_nums

    def get_benchmark_rows(
        self, rows: Dict[int, List], polygons: np.ndarray
    ) -> Tuple[List[float], int]:
        """Derive the height of every row.

        Row heights are the differences between the top-left ``y`` of the
        first cell of consecutive rows; the last row uses the tallest cell
        it contains.

        Returns:
            ``(each_row_heights, row_nums)``.
        """
        leftmost_cell_idxs = [v[0] for v in rows.values()]
        row_top_ys = polygons[leftmost_cell_idxs][:, 0, 1]
        each_row_heights = (row_top_ys[1:] - row_top_ys[:-1]).tolist()

        # Use the tallest cell of the last row as the last row's height.
        bottommost_idxs = list(rows.values())[-1]
        bottommost_boxes = polygons[bottommost_idxs]
        # Points are counter-clockwise: v[0] top-left, v[1] bottom-left,
        # v[3] top-right, so the height is |v[1] - v[0]| (not |v[3] - v[0]|).
        max_height = max(
            [self.compute_L2(v[1, :], v[0, :]) for v in bottommost_boxes]
        )
        each_row_heights.append(max_height)

        row_nums = row_top_ys.shape[0]
        return each_row_heights, row_nums

    @staticmethod
    def compute_L2(a1: np.ndarray, a2: np.ndarray) -> float:
        """Return the Euclidean distance between points ``a1`` and ``a2``."""
        return np.linalg.norm(a2 - a1)

    @staticmethod
    def _merged_span(start, total, sizes, extent, merge_thresh):
        """Count how many consecutive grid slots, from ``start``, ``extent`` covers.

        ``sizes`` holds the size of every slot (column widths or row
        heights), ``total`` the number of slots and ``extent`` the measured
        width/height of the cell.
        """
        for i in range(start, total):
            cum_sum = sum(sizes[start : i + 1])
            # The very first slot is already larger than the cell: the cell
            # occupies exactly that slot.
            if i == start and cum_sum > extent:
                return 1
            if abs(cum_sum - extent) <= merge_thresh:
                return i + 1 - start
            # Overshot without a match inside the tolerance: pick whichever
            # of "including slot i" / "excluding slot i" is closer, otherwise
            # subsequent cells would be shifted.
            if cum_sum > extent:
                idx = (
                    i
                    if abs(cum_sum - extent) < abs(cum_sum - sizes[i] - extent)
                    else i - 1
                )
                return idx + 1 - start
        # Never reached the extent: span everything that is left.
        return total - start

    def get_merge_cells(
        self,
        polygons: np.ndarray,
        rows: Dict,
        row_nums: int,
        col_nums: int,
        longest_col: np.ndarray,
        each_col_widths: List[float],
        each_row_heights: List[float],
        merge_thresh=10,
    ) -> Tuple[Dict[int, Dict[int, List[int]]], Dict[int, np.ndarray]]:
        """Compute the column/row span and logical position of every cell.

        Args:
            polygons: array indexed as ``[cell, point, xy]``.
            rows: ``{row number: [cell indices]}`` as built by ``get_rows``.
            row_nums: number of rows in the grid.
            col_nums: number of columns in the grid.
            longest_col: left ``x`` of every column.
            each_col_widths: width of every column.
            each_row_heights: height of every row.
            merge_thresh: tolerance when matching a cell's width/height
                against a cumulative sum of column widths/row heights.

        Returns:
            ``(res, logic_points)`` where ``res`` maps
            ``row position -> {cell index: [col_span, row_span]}`` and
            ``logic_points`` maps
            ``cell index -> array([row_start, row_end, col_start, col_end])``.
        """
        col_res_merge, row_res_merge = {}, {}
        logic_points = {}
        for cur_row, col_list in rows.items():
            one_col_result, one_row_result = {}, {}
            for one_col in col_list:
                box = polygons[one_col]

                box_width = self.compute_L2(box[3, :], box[0, :])
                # The start column is not necessarily 0: combine the spans
                # already assigned in this row with the cell's x position.
                loc_col_idx = np.argmin(np.abs(longest_col - box[0, 0]))
                col_start = max(sum(one_col_result.values()), loc_col_idx)
                # How many columns does this cell merge?
                one_col_result[one_col] = self._merged_span(
                    col_start, col_nums, each_col_widths, box_width, merge_thresh
                )
                col_end = one_col_result[one_col] + col_start - 1

                box_height = self.compute_L2(box[1, :], box[0, :])
                row_start = cur_row
                # The cell may cover an unknown number of rows, so the
                # cumulative heights are tried one by one for the closest fit.
                one_row_result[one_col] = self._merged_span(
                    row_start, row_nums, each_row_heights, box_height, merge_thresh
                )
                row_end = one_row_result[one_col] + row_start - 1

                logic_points[one_col] = np.array(
                    [row_start, row_end, col_start, col_end]
                )
            col_res_merge[cur_row] = one_col_result
            row_res_merge[cur_row] = one_row_result

        res = {}
        for i, (c, r) in enumerate(
            zip(col_res_merge.values(), row_res_merge.values())
        ):
            res[i] = {k: [cc, r[k]] for k, cc in c.items()}
        return res, logic_points