126 lines
3.8 KiB
Python
126 lines
3.8 KiB
Python
# -*- encoding: utf-8 -*-
|
|
# @Author: SWHL
|
|
# @Contact: liekkaskono@163.com
|
|
from typing import Any, Dict, Optional, Tuple
|
|
|
|
import cv2
|
|
import numpy as np
|
|
|
|
from .utils import OrtInferSession
|
|
from .utils_table_line_rec import (
|
|
bbox_decode,
|
|
bbox_post_process,
|
|
gbox_decode,
|
|
gbox_post_process,
|
|
get_affine_transform,
|
|
group_bbox_by_gbox,
|
|
nms,
|
|
)
|
|
from .utils_table_recover import (
|
|
merge_adjacent_polys,
|
|
sorted_ocr_boxes,
|
|
box_4_2_poly_to_box_4_1,
|
|
filter_duplicated_box,
|
|
)
|
|
|
|
|
|
class TableLineRecognition:
    """Run an ONNX table-line model on an image and return cell polygons.

    The pipeline is ``preprocess`` -> ``infer`` -> ``postprocess``, followed
    by de-duplication, ordering and merging of the detected polygons in
    ``__call__``.
    """

    def __init__(self, model_path: Optional[str] = None):
        """Set up decoding limits, normalisation constants and the session.

        Args:
            model_path: Forwarded unchanged to ``OrtInferSession``.
        """
        # Passed as ``K`` to ``bbox_decode`` and ``gbox_decode`` respectively.
        self.K = 1000
        self.MK = 4000

        # Per-channel statistics applied after the image is scaled by 1/255.
        self.mean = np.array([0.408, 0.447, 0.470], dtype=np.float32).reshape(1, 1, 3)
        self.std = np.array([0.289, 0.274, 0.278], dtype=np.float32).reshape(1, 1, 3)

        # Fixed spatial size of the network input.
        self.inp_height = 1024
        self.inp_width = 1024

        self.session = OrtInferSession(model_path)

    def __call__(
        self, img: np.ndarray, **kwargs
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Detect table cell polygons in ``img``.

        Args:
            img: Image array; must have three dimensions with the channel
                axis last (required by the transpose in ``preprocess``).
            **kwargs: Accepted for call compatibility; not used here.

        Returns:
            ``(None, None)`` when nothing is detected, otherwise the merged
            polygon array returned twice (both tuple slots hold the same
            object).
        """
        img_info = self.preprocess(img)
        pred = self.infer(img_info)
        polygons = self.postprocess(pred)
        if polygons.size == 0:
            return None, None

        # Each detection is 8 values; view them as four (x, y) points.
        polygons = polygons.reshape(polygons.shape[0], 4, 2)

        # Drop the entries that filter_duplicated_box reports as duplicates.
        del_idxs = filter_duplicated_box(
            [box_4_2_poly_to_box_4_1(box) for box in polygons]
        )
        polygons = np.delete(polygons, list(del_idxs), axis=0)

        # Reorder by the index sequence produced by sorted_ocr_boxes.
        # NOTE: ``threhold`` is the helper's own keyword spelling.
        _, idx = sorted_ocr_boxes(
            [box_4_2_poly_to_box_4_1(box) for box in polygons], threhold=0.4
        )
        polygons = polygons[idx]

        polygons = merge_adjacent_polys(polygons)
        return polygons, polygons

    def preprocess(self, img) -> Dict[str, Any]:
        """Warp ``img`` to the network input size and normalise it.

        Args:
            img: Image array with shape ``(height, width, 3)``.

        Returns:
            A dict with ``"img"`` (float32 array of shape
            ``(1, 3, inp_height, inp_width)``) and ``"meta"`` (the centre,
            scale and size values that ``postprocess`` reads back).
        """
        height, width = img.shape[:2]

        # The original code first called cv2.resize(img, (width, height)),
        # i.e. a resize to the image's own size, which only produced a copy.
        # warpAffine does not modify its source, so the copy is unnecessary.
        c = np.array([width / 2.0, height / 2.0], dtype=np.float32)
        s = float(max(height, width))
        # Third argument is 0; presumably a rotation angle - TODO confirm
        # against get_affine_transform.
        trans_input = get_affine_transform(c, s, 0, [self.inp_width, self.inp_height])

        inp_image = cv2.warpAffine(
            img,
            trans_input,
            (self.inp_width, self.inp_height),
            flags=cv2.INTER_LINEAR,
        )

        # Scale to [0, 1], then apply the per-channel mean / std.
        inp_image = ((inp_image / 255.0 - self.mean) / self.std).astype(np.float32)
        # HWC -> CHW, plus a leading batch axis of size 1.
        images = inp_image.transpose(2, 0, 1).reshape(
            1, 3, self.inp_height, self.inp_width
        )
        meta = {
            "c": c,
            "s": s,
            "input_height": self.inp_height,
            "input_width": self.inp_width,
            # Output maps are taken to be a quarter of the input size.
            "out_height": self.inp_height // 4,
            "out_width": self.inp_width // 4,
        }
        return {"img": images, "meta": meta}

    def infer(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Run the ONNX session on a preprocessed image.

        Args:
            input: The dict produced by ``preprocess``.

        Returns:
            A dict with ``"results"`` (a one-element list mapping the first
            four session outputs to ``hm``, ``v2c``, ``c2v`` and ``reg``) and
            ``"meta"`` (passed through unchanged).
        """
        # The extra leading axis is what the session wrapper receives;
        # presumably it iterates over it as a list of inputs - TODO confirm
        # against OrtInferSession.__call__.
        ort_outs = self.session(input["img"][None, ...])
        pred = [
            {
                "hm": ort_outs[0],
                "v2c": ort_outs[1],
                "c2v": ort_outs[2],
                "reg": ort_outs[3],
            }
        ]
        return {"results": pred, "meta": input["meta"]}

    def postprocess(self, inputs: Dict[str, Any]) -> np.ndarray:
        """Decode raw network outputs into polygons.

        Args:
            inputs: The dict produced by ``infer``.

        Returns:
            An array with one row of 8 values per kept detection; an empty
            array when no detection scores above 0.3.
        """
        output = inputs["results"][0]
        meta = inputs["meta"]

        hm = self.sigmoid(output["hm"])
        v2c = output["v2c"]
        c2v = output["c2v"]
        reg = output["reg"]

        # Heatmap channel 0 feeds bbox_decode, channel 1 feeds gbox_decode.
        bbox, _ = bbox_decode(hm[:, 0:1, :, :], c2v, reg=reg, K=self.K)
        gbox, _ = gbox_decode(hm[:, 1:2, :, :], v2c, reg=reg, K=self.MK)

        bbox = nms(bbox, 0.3)
        c, s, h, w = [meta["c"]], [meta["s"]], meta["out_height"], meta["out_width"]
        # Copies are passed so the decoded arrays themselves are not handed
        # to the post-processing helpers.
        bbox = bbox_post_process(bbox.copy(), c, s, h, w)
        gbox = gbox_post_process(gbox.copy(), c, s, h, w)

        bbox = group_bbox_by_gbox(bbox[0], gbox[0])
        # Index 8 is compared against 0.3 (presumably a confidence score);
        # the first 8 values are kept as the polygon.
        polygons = [box[:8] for box in bbox if box[8] > 0.3]
        return np.array(polygons)

    @staticmethod
    def sigmoid(data: np.ndarray) -> np.ndarray:
        """Return the element-wise logistic function ``1 / (1 + exp(-data))``."""
        # exp(-x) overflows to inf for large-magnitude negative x; the
        # quotient is then 0, which is the value we want, so the overflow
        # warning is suppressed rather than reported.
        with np.errstate(over="ignore"):
            return 1 / (1 + np.exp(-data))
|