nwjh/LLMServe/ocr_utils.py
2025-03-24 09:27:03 +08:00

365 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import re
import cv2
def split_ocr_results(ocr_results, threshold):
"""
对OCR识别结果进行切分当相邻数字字符的距离大于阈值时将文字识别框切分为两个。
:param ocr_results: OCR识别结果格式为文字识别框构成的List。
:param threshold: 距离阈值,用于判断是否切分。
:return: 切分后的OCR识别结果。
"""
def is_digit_pair(char1, char2):
"""判断两个字符是否都是数字字符。"""
return char1.isdigit() and char2.isdigit()
def calculate_distance(coord1, coord2):
"""计算两个坐标之间的欧氏距离。"""
x1, y1 = coord1
x2, y2 = coord2
return ((x2 - x1) ** 2 + (y2 - y1) ** 2) ** 0.5
def create_new_box(box_coords, char_coords, text, confidence, chars, split_index):
"""
创建一个新的文字识别框。
:param box_coords: 原始文字识别框的坐标。
:param char_coords: 单字坐标列表。
:param text: 原始文本。
:param confidence: 置信度。
:param chars: 单字字符列表。
:param split_index: 切分位置的索引。
:return: 新的文字识别框。
"""
new_box_coords = [
[box_coords[0][0], box_coords[0][1]],
[char_coords[split_index][1][0], box_coords[0][1]],
[char_coords[split_index][2][0], box_coords[2][1]],
[box_coords[3][0], box_coords[3][1]]
]
return [
new_box_coords,
text[:split_index + 1],
confidence,
char_coords[:split_index + 1],
chars[:split_index + 1]
]
def update_existing_box(box_coords, char_coords, text, confidence, chars, split_index):
"""
更新原始文字识别框。
:param box_coords: 原始文字识别框的坐标。
:param char_coords: 单字坐标列表。
:param text: 原始文本。
:param confidence: 置信度。
:param chars: 单字字符列表。
:param split_index: 切分位置的索引。
:return: 更新后的文字识别框。
"""
return [
[
[char_coords[split_index + 1][0][0], box_coords[0][1]],
[box_coords[1][0], box_coords[1][1]],
[box_coords[2][0], box_coords[2][1]],
[char_coords[split_index + 1][3][0], box_coords[3][1]]
],
text[split_index + 1:],
confidence,
char_coords[split_index + 1:],
chars[split_index + 1:]
]
# 遍历每个文字识别框
i = 0
while i < len(ocr_results):
ocr_box = ocr_results[i]
box_coords, text, confidence, char_coords_list, chars = ocr_box
# print(f"正在处理文字识别框:{text}")
# 遍历每个单字
for j in range(len(chars) - 1):
if is_digit_pair(chars[j], chars[j + 1]):
distance = calculate_distance(char_coords_list[j][1], char_coords_list[j + 1][0])
# print(f"字符 '{chars[j]}' 和 '{chars[j + 1]}' 之间的距离为:{distance:.2f}")
if distance > threshold:
# print(f"距离大于阈值 {threshold},进行切分。")
# 创建新的文字识别框
new_box = create_new_box(box_coords, char_coords_list, text, confidence, chars, j)
# 更新原始文字识别框
ocr_results[i] = update_existing_box(box_coords, char_coords_list, text, confidence, chars, j)
# 插入新的文字识别框
ocr_results.insert(i, new_box)
# print(f"切分后的文字识别框:{new_box[1]} 和 {ocr_results[i+1][1]}")
break # 切分后,重新处理新的文字识别框
i += 1
return ocr_results
# 根据输入
# 在ocr结果中搜索字符串
def search_in_ocr_results(ocr_results, search_text):
"""
在OCR识别结果中搜索指定字符串。
:param ocr_results: OCR识别结果。
:param search_text: 搜索字符串。
:return: 搜索结果。
"""
results = []
for index, ocr_result in enumerate(ocr_results):
for char_index, char in enumerate(ocr_result[4]):
if char == search_text:
results.append((index, char_index))
return results
# 判断ocr结果是否存在
def is_ocr_result_exist(image_path, save_path="json_data"):
"""
判断OCR识别结果是否存在。
:param image_path: 图片路径。
:return: 是否存在。
"""
# 获取图片名
image_name = os.path.basename(image_path)
# 构造JSON文件路径
json_path = os.path.join(save_path, f"{image_name}.json")
return os.path.exists(json_path)
# 将ocr结果根据图片路径获取图片名暂存到json文件中
def save_ocr_result(ocr_result, image_path, save_path="json_data"):
"""
将OCR识别结果保存到JSON文件中。
:param ocr_result: OCR识别结果。
:param image_path: 图片路径。
:return: JSON文件路径。
"""
# 创建保存路径
if not os.path.exists(save_path):
os.makedirs(save_path)
# 获取图片名
image_name = os.path.basename(image_path)
# 创建JSON文件路径
json_path = os.path.join(save_path, f"{image_name}.json")
# 保存OCR识别结果到JSON文件
with open(json_path, 'w', encoding='utf-8') as file:
json.dump(ocr_result, file, ensure_ascii=False, indent=4)
return json_path
# 从json文件中读取ocr结果
def load_ocr_result(image_path, save_path="json_data"):
"""
从JSON文件中加载OCR识别结果。
:param image_path: JSON文件路径。
:return: OCR识别结果。
"""
# 构造JSON文件路径
image_name = os.path.basename(image_path)
json_path = os.path.join(save_path, f"{image_name}.json")
# 读取JSON文件
with open(json_path, 'r', encoding='utf-8') as file:
ocr_result = json.load(file)
return ocr_result
# 从删除ocr结果
def remove_ocr_result(image_path, save_path="json_data"):
"""
删除OCR识别结果。
:param image_path: 图片路径。
:return: 是否删除成功。
"""
# 获取图片名
image_name = os.path.basename(image_path)
# 构造JSON文件路径
json_path = os.path.join(save_path, f"{image_name}.json")
# 删除JSON文件
if os.path.exists(json_path):
os.remove(json_path)
return True
return False
# 绘制并保存ocr结果
def draw_ocr_results(image_path, boxes_list, save_path="ocr_mark_data"):
"""
绘制OCR识别结果并保存图片。
:param image_path: 图片路径。
:param boxes: 文字识别框列表。
:param save_path: 保存路径。
:return: 保存的图片路径。
"""
if not os.path.exists(save_path):
os.makedirs(save_path)
# 读取图片
img = cv2.imread(image_path)
# 绘制方框
for index, boxes in enumerate(boxes_list):
for box in boxes:
cv2.rectangle(img, (box[0], box[1]), (box[2], box[3]), get_color(index), 2)
# cv2.rectangle(img, (box[0], box[1]), (box[2], box[3]), (0, 0, 255), 2)
# 保存图片
save_image_path = os.path.join(save_path, os.path.basename(image_path))
cv2.imwrite(save_image_path, img)
return save_image_path
# 根据索引,获取一个颜色
def get_color(index):
"""
根据索引获取一个颜色。
:param index: 索引。
:return: 颜色。
"""
colors = [
(255, 0, 0),
(0, 255, 0),
(0, 0, 255),
(255, 255, 0),
(0, 255, 255),
(255, 0, 255),
(128, 0, 0),
(0, 128, 0),
(0, 0, 128),
(128, 128, 0),
(0, 128, 128),
(128, 0, 128),
(128, 128, 128),
(255, 128, 0),
(255, 0, 128),
(128, 255, 0),
(0, 255, 128),
(128, 0, 255),
(0, 128, 255),
(255, 128, 128),
(128, 255, 128),
(128, 128, 255),
(255, 255, 128),
(255, 128, 255),
(128, 255, 255),
(255, 255, 255)
]
return colors[index % len(colors)]
def re_search(a, b):
"""
使用正则表达式在字符串列表b中搜索字符串a。
:param a:
:param b:
:return:
"""
if len(a) <= 1:
return []
results = []
a = replace_punctuation_with_space(a) # 删除标点符号
# 拷贝一份
a_tmp = a
a_str = re.escape(a).replace("\\ ", "\\s*") # 将目标字符串中的空格替换为 \s*
a_str = re.compile(a_str, re.IGNORECASE) # 忽略大小写
# print(a_str.pattern)
# a_str = remove_punctuation(a_str) # 删除标点符号
for index, string in enumerate(b):
string = replace_punctuation_with_space(string)
match = a_str.search(string)
if match:
results.append((index, match.start(), match.end()))
else:
string = replace_punctuation_with_space(string)
# 如果长度小于等于1直接跳过
if len(string) <= 1:
continue
b_str = re.escape(string).replace("\\ ", "\\s*")
b_str = re.compile(b_str, re.IGNORECASE)
match = b_str.search(a)
if match:
start_index = match.start()
end_index = match.end()
# 将a_tmp中的匹配部分替换为空格
a_tmp = replace_matched_text(a_tmp, b_str, "")
# print('匹配到的字符串:', string, '匹配到的字符串在a中的位置:', start_index, end_index)
results.append((index, 0, len(string)))
# 如果a_tmp中还有字符说明a_tmp中的字符没有在b中匹配到
if len(a_tmp) >= 2:
# 重新搜索一次a_tmp
a_tmp = re.escape(a_tmp).replace("\\ ", "\\s*")
a_tmp = re.compile(a_tmp, re.IGNORECASE)
# print('a_tmp:', a_tmp, 'a:', a)
for index, string in enumerate(b):
string = replace_punctuation_with_space(string)
match = a_tmp.search(string)
if match:
# print('匹配到a_tmp:', a_tmp, 'string:', string)
results.append((index, match.start(), match.end()))
else:
print('过短a_tmp:', a_tmp)
return results
# 删除字符串中的符号
def remove_punctuation(text):
"""
删除字符串中的标点符号。
:param text: 字符串。
:return: 删除标点符号后的字符串。
"""
return re.sub(r'[^\w\s]', '', text)
# 将字符串中的标点符号替换为空格
def replace_punctuation_with_space(text):
"""
将字符串中的标点符号替换为空格。
:param text: 字符串。
:return: 替换后的字符串。
"""
return re.sub(r'[^\w\s]', ' ', text)
def replace_matched_text(text, pattern, placeholder=" "):
"""
将被正则表达式匹配到的文本替换为指定占位符(默认为空格)。
参数:
text (str): 原始字符串。
pattern (str 或 re.Pattern): 正则表达式模式。
placeholder (str): 用于替换匹配文本的占位符,默认为空格。
返回:
str: 替换后的字符串。
"""
# 如果传入的是字符串形式的正则表达式,则编译它
if isinstance(pattern, str):
pattern = re.compile(pattern)
# 使用 re.sub 进行替换
return pattern.sub(placeholder, text)
if __name__ == "__main__":
s = "202411.12"
print(replace_punctuation_with_space(s))
print(replace_punctuation_with_space("2024.11.12"))
print(replace_punctuation_with_space("1375421.63"))