nwjh/LLMServe/wired_table_rec/table_line_rec.py
2025-03-24 09:27:03 +08:00

126 lines
3.8 KiB
Python

# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from typing import Any, Dict, Optional, Tuple
import cv2
import numpy as np
from .utils import OrtInferSession
from .utils_table_line_rec import (
bbox_decode,
bbox_post_process,
gbox_decode,
gbox_post_process,
get_affine_transform,
group_bbox_by_gbox,
nms,
)
from .utils_table_recover import (
merge_adjacent_polys,
sorted_ocr_boxes,
box_4_2_poly_to_box_4_1,
filter_duplicated_box,
)
class TableLineRecognition:
    """Detect wired-table cell polygons with an ONNX line-recognition model.

    Pipeline: ``preprocess`` (affine warp to a fixed input size and
    normalisation) -> ``infer`` (ONNX session) -> ``postprocess`` (decode,
    NMS, grouping, score filtering), followed in ``__call__`` by duplicate
    removal, reading-order sorting and merging of adjacent polygons.
    """

    def __init__(self, model_path: Optional[str] = None):
        # Top-k limits forwarded to bbox_decode (K) and gbox_decode (MK).
        self.K = 1000
        self.MK = 4000
        # Per-channel normalisation constants shaped to broadcast over an
        # (H, W, 3) image.
        self.mean = np.array([0.408, 0.447, 0.470], dtype=np.float32).reshape(1, 1, 3)
        self.std = np.array([0.289, 0.274, 0.278], dtype=np.float32).reshape(1, 1, 3)
        # Fixed network input resolution.
        self.inp_height = 1024
        self.inp_width = 1024
        self.session = OrtInferSession(model_path)

    def __call__(
        self, img: np.ndarray, **kwargs
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Run the full detection pipeline on ``img``.

        Args:
            img: Input image. 2-D (grayscale), single-channel, 3-channel and
                4-channel arrays are accepted. The 4th channel is dropped.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``(None, None)`` when no polygon survives post-processing.
            Otherwise the result of ``merge_adjacent_polys`` applied to the
            de-duplicated, sorted ``(N, 4, 2)`` polygons, returned twice.
            NOTE(review): both tuple slots hold the same object; confirm
            what callers expect in the second slot.
        """
        img_info = self.preprocess(img)
        pred = self.infer(img_info)
        polygons = self.postprocess(pred)
        if polygons.size == 0:
            return None, None
        # Each row holds 8 values -> 4 points of (x, y).
        polygons = polygons.reshape(polygons.shape[0], 4, 2)
        del_idxs = filter_duplicated_box(
            [box_4_2_poly_to_box_4_1(box) for box in polygons]
        )
        polygons = np.delete(polygons, list(del_idxs), axis=0)
        # `threhold` is the keyword name used by sorted_ocr_boxes.
        _, idx = sorted_ocr_boxes(
            [box_4_2_poly_to_box_4_1(box) for box in polygons], threhold=0.4
        )
        polygons = polygons[idx]
        polygons = merge_adjacent_polys(polygons)
        return polygons, polygons

    @staticmethod
    def _ensure_three_channels(img: np.ndarray) -> np.ndarray:
        """Return ``img`` as an (H, W, 3) array.

        (H, W) and (H, W, 1) inputs are replicated across three channels.
        (H, W, 4) inputs keep only their first three channels (the 4th is
        assumed to be alpha). Any other array is returned unchanged.
        """
        if img.ndim == 2:
            return np.repeat(img[:, :, None], 3, axis=2)
        if img.ndim == 3 and img.shape[2] == 1:
            return np.repeat(img, 3, axis=2)
        if img.ndim == 3 and img.shape[2] == 4:
            # Slicing yields a strided view; hand OpenCV a contiguous copy.
            return np.ascontiguousarray(img[:, :, :3])
        return img

    def preprocess(self, img) -> Dict[str, Any]:
        """Warp and normalise ``img`` into the network input tensor.

        Returns:
            ``{"img": images, "meta": meta}``. ``images`` is a float32 array
            of shape (1, 3, inp_height, inp_width). ``meta`` holds the warp
            centre ``c``, the scale ``s`` (longest image side) and the input
            and output sizes needed to map detections back to the image.
        """
        img = self._ensure_three_channels(img)
        height, width = img.shape[:2]
        # Map the image centre / longest side onto the fixed-size input.
        c = np.array([width / 2.0, height / 2.0], dtype=np.float32)
        s = max(height, width) * 1.0
        trans_input = get_affine_transform(c, s, 0, [self.inp_width, self.inp_height])
        inp_image = cv2.warpAffine(
            img,
            trans_input,
            (self.inp_width, self.inp_height),
            flags=cv2.INTER_LINEAR,
        )
        inp_image = ((inp_image / 255.0 - self.mean) / self.std).astype(np.float32)
        # HWC -> CHW, then add the batch axis.
        images = inp_image.transpose(2, 0, 1).reshape(
            1, 3, self.inp_height, self.inp_width
        )
        meta = {
            "c": c,
            "s": s,
            "input_height": self.inp_height,
            "input_width": self.inp_width,
            # Output maps are presumably 4x down-sampled; confirm against the model.
            "out_height": self.inp_height // 4,
            "out_width": self.inp_width // 4,
        }
        return {"img": images, "meta": meta}

    def infer(self, input):
        """Run the ONNX session and name its four outputs.

        Returns:
            ``{"results": [{"hm", "v2c", "c2v", "reg"}], "meta": meta}``, with
            the model outputs assigned to those keys by position.
        """
        # NOTE(review): `[None, ...]` adds a leading axis to the
        # (1, 3, H, W) tensor. Presumably OrtInferSession iterates over its
        # argument to pair inputs with model input names, so this acts as a
        # one-element input list. Confirm against OrtInferSession.
        ort_outs = self.session(input["img"][None, ...])
        pred = [
            {
                "hm": ort_outs[0],
                "v2c": ort_outs[1],
                "c2v": ort_outs[2],
                "reg": ort_outs[3],
            }
        ]
        return {"results": pred, "meta": input["meta"]}

    def postprocess(self, inputs: Dict[str, Any]) -> np.ndarray:
        """Decode raw model outputs into scored cell polygons.

        Returns:
            A (N, 8) array of polygon coordinates for grouped boxes whose
            score (index 8) exceeds 0.3. An empty array (size 0) is returned
            when none qualify.
        """
        output = inputs["results"][0]
        meta = inputs["meta"]
        hm = self.sigmoid(output["hm"])
        v2c = output["v2c"]
        c2v = output["c2v"]
        reg = output["reg"]
        # Heatmap channel 0 feeds the cell-box decoder, channel 1 the
        # gbox decoder.
        bbox, _ = bbox_decode(hm[:, 0:1, :, :], c2v, reg=reg, K=self.K)
        gbox, _ = gbox_decode(hm[:, 1:2, :, :], v2c, reg=reg, K=self.MK)
        bbox = nms(bbox, 0.3)
        # The post-process helpers take per-batch lists of centres/scales.
        c, s, h, w = [meta["c"]], [meta["s"]], meta["out_height"], meta["out_width"]
        bbox = bbox_post_process(bbox.copy(), c, s, h, w)
        gbox = gbox_post_process(gbox.copy(), c, s, h, w)
        bbox = group_bbox_by_gbox(bbox[0], gbox[0])
        polygons = [box[:8] for box in bbox if box[8] > 0.3]
        return np.array(polygons)

    @staticmethod
    def sigmoid(data: np.ndarray):
        """Element-wise logistic function.

        Overflow in ``exp`` for strongly negative inputs is silenced. The
        result there is correctly 0.0, so the warning is noise.
        """
        with np.errstate(over="ignore"):
            return 1 / (1 + np.exp(-data))