nwjh/LLMServe/ocr.py
2025-03-24 09:27:03 +08:00

188 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
from rapidocr_onnxruntime import RapidOCR
from rapidocr_paddle import RapidOCR as RapidOCR_paddle
from ocr_utils import split_ocr_results, save_ocr_result, is_ocr_result_exist, load_ocr_result, re_search, \
draw_ocr_results, remove_ocr_result
from wired_table_rec import WiredTableRecognition
from wired_table_rec.table_line_rec import TableLineRecognition
from wired_table_rec.utils import LoadImage
import os
current_directory = os.getcwd()
upload_dir = os.path.join(current_directory, "data\\")
def get_ocr(image_path,num_threshold = 15,is_gpu=False, json_save_path = upload_dir + "json_data"):
if is_gpu:
ocr_engine = RapidOCR_paddle(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
else:
ocr_engine = RapidOCR(det_model_path='../Models/ch_PP-OCRv4_det_infer.onnx',
rec_model_path='../Models/ch_PP-OCRv4_rec_infer.onnx', box_thresh=0.3, return_word_box=True)
ocr_result, _ = ocr_engine(image_path, box_thresh=0.3, return_word_box=True)
if ocr_result == None:
ocr_result = []
for res in ocr_result:
# 删除最后一个元素
res.pop()
# 将ocr识别结果保存为json
save_ocr_result(ocr_result, image_path, save_path=json_save_path)
# 文字识别结果切分
ocr_result = split_ocr_results(ocr_result, num_threshold)
# 识别表格
wired_engine = WiredTableRecognition()
html, elasp, table_result, logic_points, ocr_res, cell_box_det_map, not_match_orc_boxes = wired_engine(image_path, ocr_result=ocr_result)
if cell_box_det_map == None:
return [[item[1]] for item in ocr_result]
fin_list = list(cell_box_det_map.values()) + list(not_match_orc_boxes.values())
# 对每个子列表,获取第二个元素
output = [item[1] for item in fin_list]
print("ocr result:", output)
return output
# print('output:', output)
def get_ocr_list(image_path_list,num_threshold = 15, is_gpu=False, json_save_path = upload_dir + "json_data"):
"""
获取多个图片的ocr结果
:param image_path_list: 图片地址列表
:param num_threshold: 相邻数字字符之间的分割阈值
:param is_gpu: 是否使用gpu
:param json_save_path: json保存地址
:return: [ocr_result1,ocr_result2,...]
ocr_result1: [box1,box2,...]
"""
# 如果是单个图片地址
if type(image_path_list) == str:
return [get_ocr(image_path_list,num_threshold)]
output = []
for image_path in image_path_list:
output.append(get_ocr(image_path,num_threshold,is_gpu,json_save_path=json_save_path))
return output
def get_ocr_image(image_path, img_json, json_save_path, img_save_path, is_gpu = False):
"""
获取ocr标记图片地址
:param image_path: 图片地址
:param img_json: 图片json
:return: 标记图片地址
"""
# 需要标注的方框坐标
boxes = []
# 读取图片ocr结果
if is_ocr_result_exist(image_path, save_path=json_save_path):
ocr_result = load_ocr_result(image_path, save_path=json_save_path)
else:
get_ocr(image_path, is_gpu=is_gpu, json_save_path=json_save_path)
ocr_result = load_ocr_result(image_path, save_path=json_save_path)
# print("ocr_result:",ocr_result)
ocr_result_only_str = [item[1] for item in ocr_result]
# print("ocr_result_only_str:",ocr_result_only_str)
img_json = json.loads(img_json)
# 遍历json
for key, value in img_json.items():
# print("key:", key, "value:", value)
if value == "":
continue
search_result = re_search(value, ocr_result_only_str)
if search_result != []:
box = []
for item in search_result:
txt_index, start_index, end_index = item
# 获取方框坐标
# print("ocr_result_box:", box)
box.append(ocr_result[txt_index][3][start_index][0] + ocr_result[txt_index][3][end_index-1][2])
boxes.append(box)
# 绘制保存ocr结果
ocr_results_path = draw_ocr_results(image_path, boxes, save_path=img_save_path)
# 删掉ocr的json文件
remove_ocr_result(image_path, save_path=json_save_path)
return ocr_results_path
# 返回ocr标记图片地址
def get_ocr_image_list(image_path_list, img_json, json_save_path = upload_dir + "json_data", img_save_path = upload_dir + "ocr_mark_data"):
"""
获取ocr标记图片地址
:param image_path_list: 图片地址列表
:param img_json: 图片json
:param json_save_path: json保存地址
:param img_save_path: 图片保存地址
:return: 标记图片地址[ocr_img1,ocr_img2,...]
"""
# 读取图片ocr结果
ocr_result_list = []
for i in range(len(image_path_list)):
img_json = img_json
image_path = image_path_list[i]
ocr_img = get_ocr_image(image_path, img_json, json_save_path, img_save_path)
ocr_result_list.append(ocr_img)
return ocr_result_list
if __name__ == '__main__':
# print(get_ocr("images/1/ivc-591248001.jpg",is_gpu=True))
# with open("ocr_res.json", "w", encoding="utf-8") as file:
# json.dump(get_ocr_list("images/1/ivc-591248001.jpg",is_gpu=True), file, ensure_ascii=False, indent=4)
# heton_json = {
# "合同名称":"广东电网有限责任公司2024年数字服务工单项目营销系统工单开发实施框架合同之电能表参数配置等工 单委托函",
# "合同编号":"0375002024030102YG00046",
# "项目名称":"2024年数字服务工单项目",
# "项目编号":"037800HK23120061",
# "起止时间":"2024年6月至2024年11月",
# "工作内容及完成情况":"项目已于2024年11月4日完成初步验收。",
# "结论及意见":"广东电网有限责任公 司2024年数字服务工单项目营销系统工单)开发实施框架合同之(电能表参数配置等)工单委托函签订后,已完成初步验收工作。",
# "委托 函的约定情况":"依据委托函约定“本委托函暂定价款为人民币含税价¥3,922,900.00元(大写:叁佰玖拾贰万贰仟玖佰元整),根据框架合同约定,单项工单费用计算方式:工单服务费=甲方评审后的单项工单概算费用×中标费率。所以本委托函结算价=评审后的单项工单概算费用×中标费率93.6%。”本委托函评审后的单项工单概算费用=开发费3482390元+实施费413049元=3895439元故本委托函结算价=评审后的单项工单概算费用3895439元×中标费率93.6%=3646130.90元。",
# "支付方式":"2、初验款项目通过甲方组织【初步验收】合格 并取得乙方开具的相应金额的符合国家规定的正规发票后45个工作日内甲方向乙方支付至本委托函结算价[70]%的条款。委托函结算价 为¥3,646,130.90元,初验款金额=委托函结算价3646130.90元*70%-首付款1176870元=¥1,375,421.63元。该项目已达到初验款支付条件金额为¥1,375,421.63元,即人民币壹佰叁拾柒万伍仟肆佰贰拾壹元陆角叁分。",
# "是否达到支付条件":"是",
# "支付金额":"1375421.63",
# "日期":"202411.12"
# }
# get_ocr_image("images/1/ivc-591248004.jpg", heton_json)
print(get_ocr("/root/autodl-tmp/LLMServe/data/img/广东电网有限责任公司信息中心2024年客户服务平台网级95598语音平台V1page_4.png"))
# out = get_ocr("images/1/ivc-591248004.jpg")
#
# # 保存到文件
# with open("ocr_res.json", "w", encoding="utf-8") as file:
# json.dump(out, file, ensure_ascii=False, indent=4)