398 lines
13 KiB
Python
398 lines
13 KiB
Python
|
|
# -*- encoding: utf-8 -*-
|
|||
|
|
import math
|
|||
|
|
import traceback
|
|||
|
|
from io import BytesIO
|
|||
|
|
from pathlib import Path
|
|||
|
|
from typing import List, Union
|
|||
|
|
|
|||
|
|
import cv2
|
|||
|
|
import numpy as np
|
|||
|
|
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
|
|||
|
|
from PIL import Image, UnidentifiedImageError
|
|||
|
|
|
|||
|
|
# Directory that contains this module; resolved once at import time.
root_dir = Path(__file__).resolve().parent

# Union of input types accepted by LoadImage.__call__ (path, raw ndarray or
# encoded image bytes).
InputType = Union[str, np.ndarray, bytes, Path]
|
|||
|
|
|
|||
|
|
|
|||
|
|
class OrtInferSession:
    """Thin wrapper around an ONNX Runtime CPU inference session.

    Args:
        model_path: path to the ``.onnx`` model file.
        num_threads: intra-op thread count; ``-1`` keeps the runtime default.
    """

    def __init__(self, model_path: Union[str, Path], num_threads: int = -1):
        self.verify_exist(model_path)

        self.num_threads = num_threads
        self._init_sess_opt()

        providers = [
            ("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})
        ]
        try:
            self.session = InferenceSession(
                str(model_path), sess_options=self.sess_opt, providers=providers
            )
        except TypeError:
            # Compatibility with onnxruntime 1.5.2, which does not accept
            # the ``providers`` keyword argument.
            self.session = InferenceSession(str(model_path), sess_options=self.sess_opt)

    def _init_sess_opt(self):
        """Build the SessionOptions used when creating the session."""
        self.sess_opt = SessionOptions()
        self.sess_opt.log_severity_level = 4
        self.sess_opt.enable_cpu_mem_arena = False

        if self.num_threads != -1:
            self.sess_opt.intra_op_num_threads = self.num_threads

        self.sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

    def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
        """Run the model on ``input_content`` and return every output.

        Raises:
            ONNXRuntimeError: if ``session.run`` fails for any reason; the
                original traceback text is carried in the message.
        """
        feed = dict(zip(self.get_input_names(), input_content))
        try:
            return self.session.run(None, feed)
        except Exception as e:
            raise ONNXRuntimeError(traceback.format_exc()) from e

    def get_input_names(
        self,
    ):
        """Return the model input names, in declaration order."""
        return [node.name for node in self.session.get_inputs()]

    def get_output_name(self, output_idx=0):
        """Return the name of the ``output_idx``-th model output."""
        return self.session.get_outputs()[output_idx].name

    def get_metadata(self):
        """Return the custom metadata map embedded in the model file."""
        return self.session.get_modelmeta().custom_metadata_map

    @staticmethod
    def verify_exist(model_path: Union[Path, str]):
        """Raise unless ``model_path`` names an existing regular file."""
        path = Path(model_path)

        if not path.exists():
            raise FileNotFoundError(f"{path} does not exist!")

        if not path.is_file():
            raise FileExistsError(f"{path} must be a file")
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ONNXRuntimeError(Exception):
    """Wraps any exception raised by onnxruntime during a model run."""

    pass
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LoadImage:
    """Normalize a str/Path/bytes/ndarray image input into a 3-channel image.

    Paths and bytes are decoded with PIL; the result is converted to BGR
    channel order via ``convert_img``.
    """

    def __init__(
        self,
    ):
        pass

    def __call__(self, img: InputType) -> np.ndarray:
        """Load and convert ``img``; raises LoadImageError on bad input."""
        if not isinstance(img, InputType.__args__):
            raise LoadImageError(
                f"The img type {type(img)} does not in {InputType.__args__}"
            )

        return self.convert_img(self.load_img(img))

    def load_img(self, img: InputType) -> np.ndarray:
        """Decode ``img`` into an ndarray (PIL-decoded data is RGB order)."""
        if isinstance(img, np.ndarray):
            return img

        if isinstance(img, bytes):
            return np.array(Image.open(BytesIO(img)))

        if isinstance(img, (str, Path)):
            self.verify_exist(img)
            try:
                return np.array(Image.open(img))
            except UnidentifiedImageError as e:
                raise LoadImageError(f"cannot identify image file {img}") from e

        raise LoadImageError(f"{type(img)} is not supported!")

    def convert_img(self, img: np.ndarray):
        """Map 1/2/3/4-channel inputs onto a 3-channel BGR image."""
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        if img.ndim == 3:
            channel = img.shape[2]
            handlers = {
                1: lambda a: cv2.cvtColor(a, cv2.COLOR_GRAY2BGR),
                2: self.cvt_two_to_three,
                3: lambda a: cv2.cvtColor(a, cv2.COLOR_RGB2BGR),
                4: self.cvt_four_to_three,
            }
            handler = handlers.get(channel)
            if handler is None:
                raise LoadImageError(
                    f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
                )
            return handler(img)

        raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

    @staticmethod
    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
        """RGBA → BGR: mask by alpha and add the inverted alpha, so fully
        transparent pixels come out white."""
        r, g, b, a = cv2.split(img)
        bgr = cv2.merge((b, g, r))

        inverted_alpha = cv2.cvtColor(cv2.bitwise_not(a), cv2.COLOR_GRAY2BGR)

        masked = cv2.bitwise_and(bgr, bgr, mask=a)
        return cv2.add(masked, inverted_alpha)

    @staticmethod
    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
        """gray + alpha → BGR: same alpha handling as the 4-channel case."""
        gray, alpha = img[..., 0], img[..., 1]
        bgr = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)

        inverted_alpha = cv2.cvtColor(cv2.bitwise_not(alpha), cv2.COLOR_GRAY2BGR)

        masked = cv2.bitwise_and(bgr, bgr, mask=alpha)
        return cv2.add(masked, inverted_alpha)

    @staticmethod
    def verify_exist(file_path: Union[str, Path]):
        """Raise LoadImageError unless ``file_path`` exists."""
        if Path(file_path).exists():
            return
        raise LoadImageError(f"{file_path} does not exist.")
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LoadImageError(Exception):
    """Raised by LoadImage when an input cannot be read or converted."""

    pass
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Pillow >=v9.1.0 use a slightly different naming scheme for filters.
# Set pillow_interp_codes according to the naming scheme used.
if Image is not None:
    # The enum members carry identical names in both schemes, only their
    # namespace differs (Image.Resampling vs. Image itself).
    _pil_filters = Image.Resampling if hasattr(Image, "Resampling") else Image
    pillow_interp_codes = {
        name: getattr(_pil_filters, name.upper())
        for name in ("nearest", "bilinear", "bicubic", "box", "lanczos", "hamming")
    }
|
|||
|
|
|
|||
|
|
# Interpolation name -> OpenCV flag, for the cv2 resize backend.
cv2_interp_codes = dict(
    nearest=cv2.INTER_NEAREST,
    bilinear=cv2.INTER_LINEAR,
    bicubic=cv2.INTER_CUBIC,
    area=cv2.INTER_AREA,
    lanczos=cv2.INTER_LANCZOS4,
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def resize_img(img, scale, keep_ratio=True):
    """Resize ``img`` to ``scale``.

    Args:
        img (ndarray): input image.
        scale: target size passed through to imrescale/imresize.
        keep_ratio (bool): preserve the aspect ratio when True.

    Returns:
        tuple: (resized_img, w_scale, h_scale).
    """
    if not keep_ratio:
        resized, w_scale, h_scale = imresize(img, scale, return_scale=True)
        return resized, w_scale, h_scale

    # Shrinking is more faithful with area interpolation.
    shrinking = min(img.shape[:2]) > min(scale)
    interpolation = "area" if shrinking else "bicubic"  # bilinear
    resized, _ = imrescale(img, scale, return_scale=True, interpolation=interpolation)
    # The w_scale and h_scale have a minor difference after rounding;
    # recompute from the actual output size
    # (a real fix should be done in the mmcv.imrescale in the future).
    src_h, src_w = img.shape[:2]
    dst_h, dst_w = resized.shape[:2]
    return resized, dst_w / src_w, dst_h / src_h
|
|||
|
|
|
|||
|
|
|
|||
|
|
def imrescale(img, scale, return_scale=False, interpolation="bilinear", backend=None):
    """Resize an image while preserving its aspect ratio.

    Args:
        img (ndarray): the input image.
        scale (float | tuple[int]): scaling factor, or a (max, max) size the
            rescaled image must fit inside.
        return_scale (bool): also return the scale factor when True.
        interpolation (str): same as :func:`imresize`.
        backend (str | None): same as :func:`imresize`.

    Returns:
        ndarray: the rescaled image (plus the scale factor if requested).
    """
    h, w = img.shape[:2]
    new_size, scale_factor = rescale_size((w, h), scale, return_scale=True)
    rescaled = imresize(img, new_size, interpolation=interpolation, backend=backend)
    if return_scale:
        return rescaled, scale_factor
    return rescaled
|
|||
|
|
|
|||
|
|
|
|||
|
|
def imresize(
    img, size, return_scale=False, interpolation="bilinear", out=None, backend=None
):
    """Resize an image to an exact target size.

    Args:
        img (ndarray): the input image.
        size (tuple[int]): target size as (w, h).
        return_scale (bool): also return ``w_scale`` and ``h_scale``.
        interpolation (str): one of "nearest", "bilinear", "bicubic", "area",
            "lanczos" for the 'cv2' backend; "nearest", "bilinear" for the
            'pillow' backend.
        out (ndarray): optional destination buffer (cv2 backend only).
        backend (str | None): 'cv2', 'pillow', or None for the default 'cv2'.

    Returns:
        tuple | ndarray: (resized_img, w_scale, h_scale) or resized_img.
    """
    h, w = img.shape[:2]
    if backend is None:
        backend = "cv2"
    if backend not in ("cv2", "pillow"):
        raise ValueError(
            f"backend: {backend} is not supported for resize."
            f"Supported backends are 'cv2', 'pillow'"
        )

    if backend == "pillow":
        assert img.dtype == np.uint8, "Pillow backend only support uint8 type"
        pil_img = Image.fromarray(img)
        resized = np.array(pil_img.resize(size, pillow_interp_codes[interpolation]))
    else:
        resized = cv2.resize(
            img, size, dst=out, interpolation=cv2_interp_codes[interpolation]
        )

    if return_scale:
        return resized, size[0] / w, size[1] / h
    return resized
|
|||
|
|
|
|||
|
|
|
|||
|
|
def rescale_size(old_size, scale, return_scale=False):
    """Compute the new size after rescaling ``old_size``.

    Args:
        old_size (tuple[int]): the old image size as (w, h).
        scale (float | tuple[int]): a positive scale factor, or a pair of
            maximum edge lengths the result must fit inside.
        return_scale (bool): also return the effective scale factor.

    Returns:
        tuple[int]: the rescaled (w, h), optionally with the scale factor.

    Raises:
        ValueError: if a numeric ``scale`` is not positive.
        TypeError: if ``scale`` is neither a number nor a tuple.
    """
    w, h = old_size

    if isinstance(scale, (float, int)):
        if scale <= 0:
            raise ValueError(f"Invalid scale {scale}, must be positive.")
        scale_factor = scale
    elif isinstance(scale, tuple):
        # Fit within the box: the long edge limits max(h, w), the short
        # edge limits min(h, w); take whichever constraint is tighter.
        long_edge, short_edge = max(scale), min(scale)
        scale_factor = min(long_edge / max(h, w), short_edge / min(h, w))
    else:
        raise TypeError(
            f"Scale must be a number or tuple of int, but got {type(scale)}"
        )

    new_size = _scale_size((w, h), scale_factor)

    return (new_size, scale_factor) if return_scale else new_size
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _scale_size(size, scale):
|
|||
|
|
"""Rescale a size by a ratio.
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
size (tuple[int]): (w, h).
|
|||
|
|
scale (float | tuple(float)): Scaling factor.
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
tuple[int]: scaled size.
|
|||
|
|
"""
|
|||
|
|
if isinstance(scale, (float, int)):
|
|||
|
|
scale = (scale, scale)
|
|||
|
|
w, h = size
|
|||
|
|
return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ImageOrientationCorrector:
    """Correct a small skew angle (between -90 and +90 degrees) in an image."""

    def __init__(self):
        self.img_loader = LoadImage()

    def __call__(self, img: InputType):
        """Estimate the dominant line angle and rotate the image upright.

        Args:
            img: anything ``LoadImage`` accepts (path, bytes or ndarray).

        Returns:
            ndarray: the BGR image, rotated by the detected skew angle, or
            the unrotated image when no usable line is found.
        """
        img = self.img_loader(img)
        # Grayscale, then Otsu binarization.
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
        # Edge detection.
        edges = cv2.Canny(gray, 100, 250, apertureSize=3)
        # Hough transform, adapted from
        # https://blog.csdn.net/feilong_csdn/article/details/81586322
        lines = cv2.HoughLines(edges, 1, np.pi / 180, 0)
        if lines is None or len(lines) == 0:
            # Fix: HoughLines returns None when no line reaches the vote
            # threshold (e.g. a blank image); the original code crashed on
            # lines[0] in that case. Nothing to correct, return as-is.
            return img
        for rho, theta in lines[0]:
            a = np.cos(theta)
            b = np.sin(theta)
            x0 = a * rho
            y0 = b * rho
            # Two distant points on the detected line.
            x1 = int(x0 + 1000 * (-b))
            y1 = int(y0 + 1000 * (a))
            x2 = int(x0 - 1000 * (-b))
            y2 = int(y0 - 1000 * (a))
            if x1 == x2 or y1 == y2:
                # The line is already axis-aligned: no rotation needed.
                return img
            else:
                t = float(y2 - y1) / (x2 - x1)
            # Convert the slope to an angle in degrees.
            rotate_angle = math.degrees(math.atan(t))
            # Fold angles outside [-45, 45] back so we rotate the short way.
            if rotate_angle > 45:
                rotate_angle = -90 + rotate_angle
            elif rotate_angle < -45:
                rotate_angle = 90 + rotate_angle
            # Rotate the image around its center.
            (h, w) = img.shape[:2]
            center = (w // 2, h // 2)
            M = cv2.getRotationMatrix2D(center, rotate_angle, 1.0)
            return cv2.warpAffine(img, M, (w, h))
        # Defensive fallback (first detected line always returns above).
        return img
|