2025-03-24 09:27:03 +08:00

398 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- encoding: utf-8 -*-
import math
import traceback
from io import BytesIO
from pathlib import Path
from typing import List, Union
import cv2
import numpy as np
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
from PIL import Image, UnidentifiedImageError
root_dir = Path(__file__).resolve().parent
InputType = Union[str, np.ndarray, bytes, Path]
class OrtInferSession:
    """CPU-only wrapper around an onnxruntime ``InferenceSession``.

    Args:
        model_path: Path to the ``.onnx`` model file; must exist and be a file.
        num_threads: Intra-op thread count; -1 keeps onnxruntime's default.
    """

    def __init__(self, model_path: Union[str, Path], num_threads: int = -1):
        self.verify_exist(model_path)
        self.num_threads = num_threads
        self._init_sess_opt()
        cpu_ep = "CPUExecutionProvider"
        cpu_provider_options = {
            "arena_extend_strategy": "kSameAsRequested",
        }
        EP_list = [(cpu_ep, cpu_provider_options)]
        try:
            self.session = InferenceSession(
                str(model_path), sess_options=self.sess_opt, providers=EP_list
            )
        except TypeError:
            # Compatibility with ort 1.5.2, which does not accept the
            # ``providers`` keyword argument.
            self.session = InferenceSession(str(model_path), sess_options=self.sess_opt)

    def _init_sess_opt(self):
        """Build SessionOptions: quiet logging, no CPU arena, full graph optimization."""
        self.sess_opt = SessionOptions()
        self.sess_opt.log_severity_level = 4
        self.sess_opt.enable_cpu_mem_arena = False

        if self.num_threads != -1:
            self.sess_opt.intra_op_num_threads = self.num_threads

        self.sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

    def __call__(self, input_content: List[np.ndarray]) -> List[np.ndarray]:
        """Run inference; inputs are zipped positionally with the model's input names.

        Raises:
            ONNXRuntimeError: wraps any failure from ``session.run`` with a
                full traceback string.
        """
        input_dict = dict(zip(self.get_input_names(), input_content))
        try:
            return self.session.run(None, input_dict)
        except Exception as e:
            error_info = traceback.format_exc()
            raise ONNXRuntimeError(error_info) from e

    def get_input_names(
        self,
    ):
        # Names of all model inputs, in declaration order.
        return [v.name for v in self.session.get_inputs()]

    def get_output_name(self, output_idx: int = 0):
        # Name of the output at ``output_idx`` (default: first output).
        return self.session.get_outputs()[output_idx].name

    def get_metadata(self):
        # Custom key/value metadata embedded in the model file.
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        return meta_dict

    @staticmethod
    def verify_exist(model_path: Union[Path, str]):
        """Raise if ``model_path`` is missing or is not a regular file."""
        if not isinstance(model_path, Path):
            model_path = Path(model_path)

        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exist!")

        if not model_path.is_file():
            raise FileExistsError(f"{model_path} must be a file")
class ONNXRuntimeError(Exception):
    """Raised when an onnxruntime inference call fails; carries the traceback text."""

    pass
class LoadImage:
    """Normalize heterogeneous image inputs (path / bytes / ndarray) to a
    3-channel BGR ndarray."""

    def __init__(
        self,
    ):
        pass

    def __call__(self, img: InputType) -> np.ndarray:
        """Load ``img`` and convert it to a 3-channel BGR ndarray.

        Raises:
            LoadImageError: on unsupported input type, unreadable file,
                or unexpected channel count / dimensionality.
        """
        if not isinstance(img, InputType.__args__):
            raise LoadImageError(
                f"The img type {type(img)} does not in {InputType.__args__}"
            )

        img = self.load_img(img)
        img = self.convert_img(img)
        return img

    def load_img(self, img: InputType) -> np.ndarray:
        """Decode ``img`` into an ndarray; paths and bytes are decoded via PIL."""
        if isinstance(img, (str, Path)):
            self.verify_exist(img)
            try:
                img = np.array(Image.open(img))
            except UnidentifiedImageError as e:
                raise LoadImageError(f"cannot identify image file {img}") from e
            return img

        if isinstance(img, bytes):
            img = np.array(Image.open(BytesIO(img)))
            return img

        if isinstance(img, np.ndarray):
            # Returned as-is. NOTE(review): convert_img treats a 3-channel
            # array as RGB, so a BGR array (e.g. from cv2.imread) would get
            # its channels swapped — confirm callers pass RGB here.
            return img

        raise LoadImageError(f"{type(img)} is not supported!")

    def convert_img(self, img: np.ndarray) -> np.ndarray:
        """Map 2-D grayscale or 1/2/3/4-channel input to a 3-channel BGR image."""
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        if img.ndim == 3:
            channel = img.shape[2]
            if channel == 1:
                return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

            if channel == 2:
                return self.cvt_two_to_three(img)

            if channel == 4:
                return self.cvt_four_to_three(img)

            if channel == 3:
                # PIL decodes to RGB; flip to BGR for the cv2-based pipeline.
                return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

            raise LoadImageError(
                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
            )

        raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

    @staticmethod
    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
        """RGBA -> BGR: keep pixels with non-zero alpha, fill zero-alpha
        regions using the inverted alpha (fully transparent becomes white)."""
        r, g, b, a = cv2.split(img)
        new_img = cv2.merge((b, g, r))

        not_a = cv2.bitwise_not(a)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        # Keep color where alpha is non-zero, then add the inverted alpha so
        # transparent areas brighten toward white.
        new_img = cv2.bitwise_and(new_img, new_img, mask=a)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
        """Gray + alpha -> BGR, same inverted-alpha fill as cvt_four_to_three."""
        img_gray = img[..., 0]
        img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)

        img_alpha = img[..., 1]
        not_a = cv2.bitwise_not(img_alpha)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def verify_exist(file_path: Union[str, Path]):
        """Raise LoadImageError if ``file_path`` does not exist."""
        if not Path(file_path).exists():
            raise LoadImageError(f"{file_path} does not exist.")
class LoadImageError(Exception):
    """Raised by LoadImage for unsupported, unreadable, or malformed image input."""

    pass
# Pillow >= v9.1.0 moved the resampling filters into the ``Image.Resampling``
# enum; pick the interpolation-code table matching the installed version.
if Image is not None:
    if hasattr(Image, "Resampling"):
        pillow_interp_codes = {
            "nearest": Image.Resampling.NEAREST,
            "bilinear": Image.Resampling.BILINEAR,
            "bicubic": Image.Resampling.BICUBIC,
            "box": Image.Resampling.BOX,
            "lanczos": Image.Resampling.LANCZOS,
            "hamming": Image.Resampling.HAMMING,
        }
    else:
        # Pre-9.1 flat attribute names.
        pillow_interp_codes = {
            "nearest": Image.NEAREST,
            "bilinear": Image.BILINEAR,
            "bicubic": Image.BICUBIC,
            "box": Image.BOX,
            "lanczos": Image.LANCZOS,
            "hamming": Image.HAMMING,
        }

# Interpolation-name -> OpenCV flag mapping used by the cv2 resize backend.
cv2_interp_codes = {
    "nearest": cv2.INTER_NEAREST,
    "bilinear": cv2.INTER_LINEAR,
    "bicubic": cv2.INTER_CUBIC,
    "area": cv2.INTER_AREA,
    "lanczos": cv2.INTER_LANCZOS4,
}
def resize_img(img, scale, keep_ratio=True):
    """Resize ``img`` to ``scale``.

    Args:
        img (ndarray): The input image.
        scale (float | tuple[int]): Target (w, h) size, or a bare scale
            factor (as also accepted by :func:`imrescale`).
        keep_ratio (bool): If True, rescale as large as possible within
            ``scale`` while keeping the aspect ratio; otherwise stretch to
            exactly ``scale``.

    Returns:
        tuple: (resized_img, w_scale, h_scale).
    """
    if keep_ratio:
        # "area" interpolation is more faithful when shrinking; use bicubic
        # when enlarging.
        if isinstance(scale, (tuple, list)):
            shrinking = min(img.shape[:2]) > min(scale)
        else:
            # Robustness fix: ``scale`` may be a bare number (imrescale
            # accepts one); the original ``min(scale)`` raised TypeError.
            # A factor < 1 means the image is being shrunk.
            shrinking = scale < 1
        interpolation = "area" if shrinking else "bicubic"

        img_new, scale_factor = imrescale(
            img, scale, return_scale=True, interpolation=interpolation
        )
        # w_scale and h_scale can differ slightly from scale_factor because
        # the output size is rounded; recompute from the actual shapes.
        new_h, new_w = img_new.shape[:2]
        h, w = img.shape[:2]
        w_scale = new_w / w
        h_scale = new_h / h
    else:
        img_new, w_scale, h_scale = imresize(img, scale, return_scale=True)
    return img_new, w_scale, h_scale
def imrescale(img, scale, return_scale=False, interpolation="bilinear", backend=None):
    """Resize an image, preserving its aspect ratio.

    Args:
        img (ndarray): Input image.
        scale (float | tuple[int]): A scale factor, or a maximum (w, h)
            bound; with a tuple the image is made as large as possible
            while fitting inside it.
        return_scale (bool): Also return the applied scale factor.
        interpolation (str): Same as :func:`imresize`.
        backend (str | None): Same as :func:`imresize`.

    Returns:
        ndarray or tuple: the rescaled image, optionally with its scale factor.
    """
    height, width = img.shape[:2]
    new_size, scale_factor = rescale_size((width, height), scale, return_scale=True)
    rescaled = imresize(img, new_size, interpolation=interpolation, backend=backend)
    return (rescaled, scale_factor) if return_scale else rescaled
def imresize(
    img, size, return_scale=False, interpolation="bilinear", out=None, backend=None
):
    """Resize ``img`` to an exact target ``size``.

    Args:
        img (ndarray): Input image.
        size (tuple[int]): Target size as (w, h).
        return_scale (bool): Also return ``w_scale`` and ``h_scale``.
        interpolation (str): "nearest", "bilinear", "bicubic", "area" or
            "lanczos" for the 'cv2' backend; the names supported by
            ``pillow_interp_codes`` for the 'pillow' backend.
        out (ndarray): Optional destination array (cv2 backend only).
        backend (str | None): "cv2" or "pillow"; None selects "cv2".

    Returns:
        ndarray or tuple: ``resized_img`` or
        ``(resized_img, w_scale, h_scale)``.
    """
    h, w = img.shape[:2]
    if backend is None:
        backend = "cv2"
    if backend not in ("cv2", "pillow"):
        raise ValueError(
            f"backend: {backend} is not supported for resize."
            f"Supported backends are 'cv2', 'pillow'"
        )

    if backend == "pillow":
        # Pillow cannot round-trip arbitrary dtypes through fromarray.
        assert img.dtype == np.uint8, "Pillow backend only support uint8 type"
        pil_img = Image.fromarray(img)
        resized_img = np.array(pil_img.resize(size, pillow_interp_codes[interpolation]))
    else:
        resized_img = cv2.resize(
            img, size, dst=out, interpolation=cv2_interp_codes[interpolation]
        )

    if return_scale:
        return resized_img, size[0] / w, size[1] / h
    return resized_img
def rescale_size(old_size, scale, return_scale=False):
    """Compute the size an image of ``old_size`` would be rescaled to.

    Args:
        old_size (tuple[int]): The old size (w, h) of the image.
        scale (float | tuple[int]): A positive scale factor, or a maximum
            (w, h) bound the image must fit inside while staying as large
            as possible.
        return_scale (bool): Also return the scale factor that was applied.

    Returns:
        tuple[int] or tuple: the new (w, h), optionally with the factor.
    """
    w, h = old_size
    if isinstance(scale, (float, int)):
        if scale <= 0:
            raise ValueError(f"Invalid scale {scale}, must be positive.")
        scale_factor = scale
    elif isinstance(scale, tuple):
        # Fit inside the bound: limited by both the long and the short edge.
        long_edge, short_edge = max(scale), min(scale)
        scale_factor = min(long_edge / max(h, w), short_edge / min(h, w))
    else:
        raise TypeError(
            f"Scale must be a number or tuple of int, but got {type(scale)}"
        )

    new_size = _scale_size((w, h), scale_factor)
    return (new_size, scale_factor) if return_scale else new_size
def _scale_size(size, scale):
"""Rescale a size by a ratio.
Args:
size (tuple[int]): (w, h).
scale (float | tuple(float)): Scaling factor.
Returns:
tuple[int]: scaled size.
"""
if isinstance(scale, (float, int)):
scale = (scale, scale)
w, h = size
return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
class ImageOrientationCorrector:
    """Correct small skew angles of an image (folded into roughly -45..+45 degrees).

    The strongest line direction is estimated via a Hough transform and the
    image is rotated around its center to compensate.
    """

    def __init__(self):
        self.img_loader = LoadImage()

    def __call__(self, img: InputType):
        """Return ``img`` rotated so the strongest detected line is axis-aligned.

        If no line is detected, or the strongest line is already exactly
        horizontal/vertical, the image is returned unchanged.
        """
        img = self.img_loader(img)
        # Convert to grayscale.
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # Binarize with Otsu's threshold.
        gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
        # Edge detection.
        edges = cv2.Canny(gray, 100, 250, apertureSize=3)
        # Hough transform, adapted from
        # https://blog.csdn.net/feilong_csdn/article/details/81586322
        lines = cv2.HoughLines(edges, 1, np.pi / 180, 0)
        if lines is None:
            # Robustness fix: HoughLines returns None when no line is found;
            # the original code would crash on ``lines[0]``.
            return img

        for rho, theta in lines[0]:
            a = np.cos(theta)
            b = np.sin(theta)
            x0 = a * rho
            y0 = b * rho
            # Two points far apart on the detected line.
            x1 = int(x0 + 1000 * (-b))
            y1 = int(y0 + 1000 * (a))
            x2 = int(x0 - 1000 * (-b))
            y2 = int(y0 - 1000 * (a))
            if x1 == x2 or y1 == y2:
                # Already perfectly vertical/horizontal: nothing to correct.
                return img

            slope = float(y2 - y1) / (x2 - x1)
            rotate_angle = math.degrees(math.atan(slope))
            # Fold the angle into (-45, 45] so we rotate by the smallest amount.
            if rotate_angle > 45:
                rotate_angle = -90 + rotate_angle
            elif rotate_angle < -45:
                rotate_angle = 90 + rotate_angle

            # Rotate around the image center; output keeps the original size.
            (h, w) = img.shape[:2]
            center = (w // 2, h // 2)
            M = cv2.getRotationMatrix2D(center, rotate_angle, 1.0)
            return cv2.warpAffine(img, M, (w, h))