398 lines
13 KiB
Python
398 lines
13 KiB
Python
# -*- encoding: utf-8 -*-
|
||
import math
|
||
import traceback
|
||
from io import BytesIO
|
||
from pathlib import Path
|
||
from typing import List, Union
|
||
|
||
import cv2
|
||
import numpy as np
|
||
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
|
||
from PIL import Image, UnidentifiedImageError
|
||
|
||
root_dir = Path(__file__).resolve().parent
|
||
InputType = Union[str, np.ndarray, bytes, Path]
|
||
|
||
|
||
class OrtInferSession:
    """Wrapper around an ONNX Runtime ``InferenceSession`` using the CPU provider."""

    def __init__(self, model_path: Union[str, Path], num_threads: int = -1):
        """Check the model file and build the inference session.

        Args:
            model_path: Location of the ONNX model file.
            num_threads: Value for ``intra_op_num_threads``; ``-1`` leaves the
                session option untouched.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            FileExistsError: If ``model_path`` exists but is not a regular file.
        """
        self.verify_exist(model_path)

        self.num_threads = num_threads
        self._init_sess_opt()

        providers = [
            ("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})
        ]
        try:
            self.session = InferenceSession(
                str(model_path), sess_options=self.sess_opt, providers=providers
            )
        except TypeError:
            # Compatibility path for ort 1.5.2 (presumably its constructor does
            # not accept ``providers`` -- confirm).
            self.session = InferenceSession(str(model_path), sess_options=self.sess_opt)

    def _init_sess_opt(self):
        """Create ``self.sess_opt`` and apply the logging, arena, thread and
        graph-optimisation settings."""
        options = self.sess_opt = SessionOptions()
        options.log_severity_level = 4
        options.enable_cpu_mem_arena = False

        # -1 means "do not override the runtime's own thread count".
        if self.num_threads != -1:
            options.intra_op_num_threads = self.num_threads

        options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

    def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
        """Run the model on ``input_content``.

        Inputs are paired positionally with the session's input names.

        Returns:
            Whatever ``session.run(None, ...)`` returns for all outputs.

        Raises:
            ONNXRuntimeError: If the run fails; the message carries the
                formatted traceback and the original exception is chained.
        """
        feeds = {
            name: tensor
            for name, tensor in zip(self.get_input_names(), input_content)
        }
        try:
            return self.session.run(None, feeds)
        except Exception as exc:
            raise ONNXRuntimeError(traceback.format_exc()) from exc

    def get_input_names(
        self,
    ):
        """Return the names of all model inputs, in session order."""
        names = []
        for node in self.session.get_inputs():
            names.append(node.name)
        return names

    def get_output_name(self, output_idx=0):
        """Return the name of the output at position ``output_idx``."""
        outputs = self.session.get_outputs()
        return outputs[output_idx].name

    def get_metadata(self):
        """Return the model's custom metadata map."""
        return self.session.get_modelmeta().custom_metadata_map

    @staticmethod
    def verify_exist(model_path: Union[Path, str]):
        """Ensure ``model_path`` points at an existing regular file.

        Raises:
            FileNotFoundError: If the path does not exist.
            FileExistsError: If the path exists but is not a file.
        """
        candidate = model_path if isinstance(model_path, Path) else Path(model_path)

        if not candidate.exists():
            raise FileNotFoundError(f"{candidate} does not exist!")

        if not candidate.is_file():
            raise FileExistsError(f"{candidate} must be a file")
|
||
|
||
|
||
class ONNXRuntimeError(Exception):
    """Raised when running the ONNX Runtime session fails."""
|
||
|
||
|
||
class LoadImage:
    """Load an image from a path, raw bytes or an array and convert it to a
    3-channel image via the conversions in :meth:`convert_img`."""

    def __init__(
        self,
    ):
        pass

    def __call__(self, img: InputType) -> np.ndarray:
        """Load ``img`` and convert it to three channels.

        Args:
            img: A file path (``str``/``Path``), encoded image ``bytes`` or an
                already decoded ``np.ndarray``.

        Returns:
            The converted image array.

        Raises:
            LoadImageError: If the type is unsupported, the file is missing or
                unreadable, or the array has an unsupported rank/channel count.
        """
        if not isinstance(img, InputType.__args__):
            raise LoadImageError(
                f"The img type {type(img)} does not in {InputType.__args__}"
            )

        img = self.load_img(img)
        img = self.convert_img(img)
        return img

    def load_img(self, img: InputType) -> np.ndarray:
        """Decode ``img`` into an array without any channel conversion.

        Raises:
            LoadImageError: If a path does not exist, the file cannot be
                identified as an image, or the type is unsupported.
        """
        if isinstance(img, (str, Path)):
            self.verify_exist(img)
            try:
                # Use a context manager so the underlying file handle is
                # closed as soon as the pixels have been copied into the array.
                with Image.open(img) as pil_img:
                    return np.array(pil_img)
            except UnidentifiedImageError as e:
                raise LoadImageError(f"cannot identify image file {img}") from e

        if isinstance(img, bytes):
            with Image.open(BytesIO(img)) as pil_img:
                return np.array(pil_img)

        if isinstance(img, np.ndarray):
            return img

        raise LoadImageError(f"{type(img)} is not supported!")

    def convert_img(self, img: np.ndarray):
        """Convert a 2-D or 3-D array with 1, 2, 3 or 4 channels to 3 channels.

        NOTE(review): 3-channel input is always treated as RGB and swapped to
        BGR, including arrays passed in directly by the caller -- confirm that
        callers never hand over arrays that are already BGR.

        Raises:
            LoadImageError: If ``img.ndim`` is not 2 or 3, or the channel count
                is not in ``[1, 2, 3, 4]``.
        """
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        if img.ndim == 3:
            channel = img.shape[2]
            if channel == 1:
                return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

            if channel == 2:
                return self.cvt_two_to_three(img)

            if channel == 4:
                return self.cvt_four_to_three(img)

            if channel == 3:
                return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

            raise LoadImageError(
                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
            )

        raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

    @staticmethod
    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
        """RGBA -> BGR.

        Colour is kept only where alpha is non-zero (alpha is used as a mask),
        then the inverted alpha is added to every channel; for uint8 input this
        turns fully transparent pixels white.
        """
        r, g, b, a = cv2.split(img)
        new_img = cv2.merge((b, g, r))

        not_a = cv2.bitwise_not(a)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(new_img, new_img, mask=a)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
        """gray + alpha -> BGR.

        Same masking scheme as :meth:`cvt_four_to_three`, applied to the gray
        channel (index 0) with the alpha channel (index 1).
        """
        img_gray = img[..., 0]
        img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)

        img_alpha = img[..., 1]
        not_a = cv2.bitwise_not(img_alpha)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def verify_exist(file_path: Union[str, Path]):
        """Raise :class:`LoadImageError` if ``file_path`` does not exist."""
        if not Path(file_path).exists():
            raise LoadImageError(f"{file_path} does not exist.")
|
||
|
||
|
||
class LoadImageError(Exception):
    """Raised when an image cannot be loaded or converted."""
|
||
|
||
|
||
# Pillow >= v9.1.0 exposes the resampling filters on ``Image.Resampling``;
# older releases keep them directly on the ``Image`` module. Look the filters
# up on whichever of the two is available.
if Image is not None:
    pillow_interp_codes = {
        name: getattr(getattr(Image, "Resampling", Image), name.upper())
        for name in ("nearest", "bilinear", "bicubic", "box", "lanczos", "hamming")
    }
|
||
|
||
# Interpolation name -> OpenCV flag, used by the 'cv2' resize backend.
cv2_interp_codes = dict(
    nearest=cv2.INTER_NEAREST,
    bilinear=cv2.INTER_LINEAR,
    bicubic=cv2.INTER_CUBIC,
    area=cv2.INTER_AREA,
    lanczos=cv2.INTER_LANCZOS4,
)
|
||
|
||
|
||
def resize_img(img, scale, keep_ratio=True):
    """Resize ``img`` and report the effective per-axis scale factors.

    Args:
        img (ndarray): The input image.
        scale (float | tuple[int]): With ``keep_ratio=True`` either a scaling
            factor or a size bound, as accepted by :func:`imrescale`. With
            ``keep_ratio=False`` the exact target size ``(w, h)``.
        keep_ratio (bool): Whether to preserve the aspect ratio.

    Returns:
        tuple: ``(resized_img, w_scale, h_scale)``.
    """
    if not keep_ratio:
        # imresize already returns (resized_img, w_scale, h_scale).
        img_new, w_scale, h_scale = imresize(img, scale, return_scale=True)
        return img_new, w_scale, h_scale

    # "area" interpolation is more faithful when shrinking; otherwise use
    # bicubic. A scalar scale is a plain factor, so it shrinks iff it is < 1
    # (calling min() on it, as the tuple path does, would raise TypeError).
    if isinstance(scale, (float, int)):
        shrinking = scale < 1
    else:
        shrinking = min(img.shape[:2]) > min(scale)
    interpolation = "area" if shrinking else "bicubic"

    img_new = imrescale(img, scale, interpolation=interpolation)

    # The w_scale and h_scale differ slightly because of rounding, so they are
    # recomputed from the actual output shape;
    # a real fix should be done in the mmcv.imrescale in the future.
    new_h, new_w = img_new.shape[:2]
    h, w = img.shape[:2]
    w_scale = new_w / w
    h_scale = new_h / h
    return img_new, w_scale, h_scale
|
||
|
||
|
||
def imrescale(img, scale, return_scale=False, interpolation="bilinear", backend=None):
    """Rescale an image without changing its aspect ratio.

    Args:
        img (ndarray): The input image.
        scale (float | tuple[int]): The scaling factor or maximum size.
            A number rescales the image by that factor; a tuple of 2 integers
            rescales the image as large as possible within that bound.
        return_scale (bool): Whether to also return the scaling factor.
        interpolation (str): Forwarded to :func:`imresize`.
        backend (str | None): Forwarded to :func:`imresize`.

    Returns:
        ndarray | tuple: The rescaled image, or ``(rescaled_img, scale_factor)``
        when ``return_scale`` is true.
    """
    height, width = img.shape[:2]
    target_size, factor = rescale_size((width, height), scale, return_scale=True)
    resized = imresize(img, target_size, interpolation=interpolation, backend=backend)
    return (resized, factor) if return_scale else resized
|
||
|
||
|
||
def imresize(
    img, size, return_scale=False, interpolation="bilinear", out=None, backend=None
):
    """Resize image to a given size.

    Args:
        img (ndarray): The input image.
        size (tuple[int]): Target size (w, h).
        return_scale (bool): Whether to return `w_scale` and `h_scale`.
        interpolation (str): Interpolation method. Accepted values are the keys
            of ``cv2_interp_codes`` ("nearest", "bilinear", "bicubic", "area",
            "lanczos") for the 'cv2' backend and the keys of
            ``pillow_interp_codes`` ("nearest", "bilinear", "bicubic", "box",
            "lanczos", "hamming") for the 'pillow' backend.
        out (ndarray): The output destination. Only passed on to the 'cv2'
            backend; the 'pillow' branch does not use it.
        backend (str | None): The image resize backend type. Options are `cv2`,
            `pillow`, `None`. ``None`` selects `cv2`. Default: None.

    Returns:
        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
        `resized_img`.

    Raises:
        ValueError: If ``backend`` is neither 'cv2' nor 'pillow'.
        KeyError: If ``interpolation`` is not known to the chosen backend.
    """
    h, w = img.shape[:2]
    if backend is None:
        backend = "cv2"
    if backend not in ["cv2", "pillow"]:
        # The separating space was missing, which produced "resize.Supported".
        raise ValueError(
            f"backend: {backend} is not supported for resize. "
            "Supported backends are 'cv2', 'pillow'"
        )

    if backend == "pillow":
        assert img.dtype == np.uint8, "Pillow backend only support uint8 type"
        pil_image = Image.fromarray(img)
        pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
        resized_img = np.array(pil_image)
    else:
        resized_img = cv2.resize(
            img, size, dst=out, interpolation=cv2_interp_codes[interpolation]
        )

    if not return_scale:
        return resized_img

    w_scale = size[0] / w
    h_scale = size[1] / h
    return resized_img, w_scale, h_scale
|
||
|
||
|
||
def rescale_size(old_size, scale, return_scale=False):
    """Compute the size an image should be rescaled to.

    Args:
        old_size (tuple[int]): The current size (w, h) of the image.
        scale (float | tuple[int]): The scaling factor or maximum size.
            A number is used directly as the factor; a tuple of 2 integers
            bounds the long and short edges and the largest factor that fits
            within both bounds is chosen.
        return_scale (bool): Whether to also return the scaling factor.

    Returns:
        tuple[int] | tuple: The new size, or ``(new_size, scale_factor)`` when
        ``return_scale`` is true.

    Raises:
        ValueError: If a numeric ``scale`` is not positive.
        TypeError: If ``scale`` is neither a number nor a tuple.
    """
    width, height = old_size

    if not isinstance(scale, (float, int, tuple)):
        raise TypeError(
            f"Scale must be a number or tuple of int, but got {type(scale)}"
        )

    if isinstance(scale, tuple):
        long_limit, short_limit = max(scale), min(scale)
        factor = min(
            long_limit / max(height, width), short_limit / min(height, width)
        )
    else:
        if scale <= 0:
            raise ValueError(f"Invalid scale {scale}, must be positive.")
        factor = scale

    scaled = _scale_size((width, height), factor)
    return (scaled, factor) if return_scale else scaled
|
||
|
||
|
||
def _scale_size(size, scale):
    """Multiply a size by a ratio and round to the nearest integer.

    Args:
        size (tuple[int]): (w, h).
        scale (float | tuple(float)): A single factor applied to both axes, or
            a ``(w_factor, h_factor)`` pair.

    Returns:
        tuple[int]: The scaled (w, h); each value is rounded by adding 0.5 and
        truncating.
    """
    ratios = (scale, scale) if isinstance(scale, (float, int)) else scale
    width, height = size
    new_width = int(width * float(ratios[0]) + 0.5)
    new_height = int(height * float(ratios[1]) + 0.5)
    return new_width, new_height
|
||
|
||
|
||
class ImageOrientationCorrector:
    """Correct small skew angles (within -90 to +90 degrees) of an image."""

    def __init__(self):
        self.img_loader = LoadImage()

    def __call__(self, img: InputType):
        """Estimate the skew from the first Hough line and rotate it away.

        Args:
            img: Anything accepted by ``LoadImage`` (path, bytes or ndarray).

        Returns:
            The rotated image, or the loaded image unchanged when no line is
            detected or the detected line is exactly horizontal/vertical.
        """
        img = self.img_loader(img)
        # Convert to grayscale.
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # Binarise (Otsu picks the threshold).
        gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
        # Edge detection.
        edges = cv2.Canny(gray, 100, 250, apertureSize=3)
        # Hough transform, adapted from
        # https://blog.csdn.net/feilong_csdn/article/details/81586322
        lines = cv2.HoughLines(edges, 1, np.pi / 180, 0)
        # HoughLines yields None when nothing is detected (e.g. a blank image);
        # indexing it would raise TypeError, so return the image as loaded.
        if lines is None or len(lines) == 0:
            return img

        for rho, theta in lines[0]:
            a = np.cos(theta)
            b = np.sin(theta)
            x0 = a * rho
            y0 = b * rho
            # Two points far apart on the detected line.
            x1 = int(x0 + 1000 * (-b))
            y1 = int(y0 + 1000 * (a))
            x2 = int(x0 - 1000 * (-b))
            y2 = int(y0 - 1000 * (a))
            # Exactly vertical or horizontal line: nothing to correct (this
            # also avoids a division by zero below).
            if x1 == x2 or y1 == y2:
                return img

            t = float(y2 - y1) / (x2 - x1)
            # Slope -> angle in degrees, folded into [-45, 45].
            rotate_angle = math.degrees(math.atan(t))
            if rotate_angle > 45:
                rotate_angle = -90 + rotate_angle
            elif rotate_angle < -45:
                rotate_angle = 90 + rotate_angle

            # Rotate the image around its centre, keeping the original size.
            (h, w) = img.shape[:2]
            center = (w // 2, h // 2)
            M = cv2.getRotationMatrix2D(center, rotate_angle, 1.0)
            return cv2.warpAffine(img, M, (w, h))

        # No usable (rho, theta) entry: leave the image untouched.
        return img
|