377 lines
12 KiB
Python
377 lines
12 KiB
Python
|
|
import uuid
|
|||
|
|
|
|||
|
|
from fastapi import FastAPI, File, UploadFile, HTTPException, Request
|
|||
|
|
import ssl
|
|||
|
|
|
|||
|
|
from uvicorn import Config, Server
|
|||
|
|
from fastapi.responses import JSONResponse
|
|||
|
|
from fastapi.encoders import jsonable_encoder
|
|||
|
|
import requests
|
|||
|
|
import uvicorn
|
|||
|
|
import config
|
|||
|
|
from pydantic import BaseModel
|
|||
|
|
from typing import Union, List
|
|||
|
|
import fitz
|
|||
|
|
from PIL import Image
|
|||
|
|
import json
|
|||
|
|
import re
|
|||
|
|
import os
|
|||
|
|
from openai import OpenAI
|
|||
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|||
|
|
from fastapi.staticfiles import StaticFiles
|
|||
|
|
|
|||
|
|
from classify import predict_folder
|
|||
|
|
from ocr import get_ocr_list, get_ocr_image_list, get_ocr
|
|||
|
|
from torchvision import transforms, models
|
|||
|
|
|
|||
|
|
|
|||
|
|
app = FastAPI()

# NOTE(review): allow_origins=["*"] together with allow_credentials=True lets
# any web origin make credentialed requests to this service; restrict the
# origin list before exposing it beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Make sure the upload directories exist.
# They are created BEFORE the static mount below: StaticFiles verifies that its
# directory exists when it is constructed, so mounting first fails with a
# RuntimeError on a fresh checkout that has no ./data directory yet.
current_directory = os.getcwd()
# NOTE(review): these paths hard-code Windows "\\" separators and other code in
# this file joins them by plain string concatenation. On POSIX systems the
# backslashes become part of file names instead of separating directories.
# Kept unchanged for compatibility with that code.
upload_dir = os.path.join(current_directory, "data\\")
images_dir = upload_dir + "img\\"
pdfs_dir = upload_dir + "content\\"
os.makedirs(os.path.join(current_directory, "data"), exist_ok=True)
os.makedirs(images_dir, exist_ok=True)
os.makedirs(pdfs_dir, exist_ok=True)

app.mount("/data", StaticFiles(directory="data"), name="data")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def run_conv(text_prompt, ocr_text, model="deepseek-r1", client=None):
    """Ask the chat model to turn OCR text into JSON and return the JSON text.

    Args:
        text_prompt: Instruction prompt (one of the ``config`` prompts).
        ocr_text: OCR output; it is sent as ``str(ocr_text)``, so lists are
            rendered with their Python repr.
        model: Model name passed to ``chat.completions.create``.
        client: Optional OpenAI-compatible client. When omitted, a client for
            the local server at ``http://localhost:11434/v1`` is created.

    Returns:
        The text inside the first ```json ... ``` fence of the reply, after
        any ``<think>...</think>`` section has been removed.

    Raises:
        ValueError: If the reply contains no ```json fenced block.
    """
    print("begin deepseek ask")
    # NOTE(review): the instruction prompt is sent with the "assistant" role
    # and the OCR text with the "user" role. Kept as-is because the prompts may
    # be tuned for this layout, but a system/user split is the conventional one.
    messages = [
        {'role': 'user', 'content': str(ocr_text)},
        {'role': 'assistant', 'content': text_prompt}
    ]

    # SECURITY: an API key for a hosted endpoint used to be committed here in a
    # comment; treat it as leaked and rotate it. Load secrets from the
    # environment, never from source code.
    if client is None:
        client = OpenAI(api_key='ollama', base_url="http://localhost:11434/v1")
    response = client.chat.completions.create(model=model, messages=messages)

    # message.content can be None; treat that as an empty reply.
    final_response = response.choices[0].message.content or ""
    # deepseek-r1 typically prefixes its answer with a <think> reasoning
    # section that may itself contain draft ```json blocks; drop it so the
    # block extracted is the one in the final answer.
    answer = re.sub(r"<think>.*?</think>", "", final_response, flags=re.DOTALL)
    match = re.search(r"```json(.*?)```", answer, re.DOTALL)
    if match is None:
        raise ValueError(
            "LLM response contains no ```json fenced block: " + final_response[:200]
        )
    res = match.group(1)
    print("llm result: ", res)

    return res
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Invoice endpoint.
@app.post("/invoice/")
async def invoice(files: List[UploadFile] = File(...)):
    """Save uploaded invoice images, OCR them and extract fields with the LLM.

    Each upload must declare an ``image/*`` content type and carry a usable
    file name; otherwise an HTTPException(400) is raised (files saved before
    the offending upload stay on disk). Files are written to ``images_dir``
    under the base name supplied by the client.

    Returns:
        JSONResponse ``{"text": <LLM output>, "image_url": <result of
        get_ocr_image_list>, "class": "发票"}``. Any exception raised during
        OCR / LLM processing is returned as 400 ``{"error": str(e)}``.
    """
    image_url = []
    for file in files:
        # content_type may be None when the client omits it; treat that as an
        # unsupported type instead of crashing with AttributeError.
        if not (file.content_type or "").startswith('image/'):
            raise HTTPException(status_code=400, detail="Unsupported file type")
        # Keep only the last path component of the client-controlled name so a
        # crafted name such as "../../x.png" cannot write outside images_dir.
        safe_name = os.path.basename(file.filename or "")
        if safe_name in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="Invalid file name")

        # Save the image file.
        file_location = images_dir + safe_name
        with open(file_location, "wb") as file_object:
            file_object.write(await file.read())
        image_url.append(file_location)
    print("image_path: ", image_url)

    try:
        text_prompt = config.invoice
        ocr_text = get_ocr_list(image_url)  # run OCR over all saved images
        output = run_conv(text_prompt, ocr_text)  # ask the LLM
        new_url = get_ocr_image_list(image_url, output)
        return JSONResponse(content={"text": output, "image_url": new_url, "class": "发票"})
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
|
|||
|
|
|
|||
|
|
# Application-form endpoint.
@app.post("/application/")
async def application(files: List[UploadFile] = File(...)):
    """Save uploaded application-form images, OCR them and extract fields.

    Each upload must declare an ``image/*`` content type and carry a usable
    file name; otherwise an HTTPException(400) is raised (files saved before
    the offending upload stay on disk). Files are written to ``images_dir``
    under the base name supplied by the client.

    Returns:
        JSONResponse ``{"text": <LLM output>, "image_url": <saved paths>,
        "class": "申请单"}``. Any exception raised during OCR / LLM processing
        is returned as 400 ``{"error": str(e)}``.
    """
    image_url = []
    for file in files:
        # content_type may be None when the client omits it; treat that as an
        # unsupported type instead of crashing with AttributeError.
        if not (file.content_type or "").startswith('image/'):
            raise HTTPException(status_code=400, detail="Unsupported file type")
        # Keep only the last path component of the client-controlled name so a
        # crafted name such as "../../x.png" cannot write outside images_dir.
        safe_name = os.path.basename(file.filename or "")
        if safe_name in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="Invalid file name")

        # Save the image file.
        file_location = images_dir + safe_name
        with open(file_location, "wb") as file_object:
            file_object.write(await file.read())
        image_url.append(file_location)

    try:
        text_prompt = config.application
        ocr_text = get_ocr_list(image_url)  # run OCR over all saved images
        output = run_conv(text_prompt, ocr_text)  # ask the LLM
        return JSONResponse(content={"text": output, "image_url": image_url, "class": "申请单"})
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
|
|||
|
|
|
|||
|
|
# Confirmation-form endpoint.
@app.post("/confirmation/")
async def confirmation(files: List[UploadFile] = File(...)):
    """Save uploaded confirmation-form images, OCR them and extract fields.

    Each upload must declare an ``image/*`` content type and carry a usable
    file name; otherwise an HTTPException(400) is raised (files saved before
    the offending upload stay on disk). Files are written to ``images_dir``
    under the base name supplied by the client.

    Returns:
        JSONResponse ``{"text": <LLM output>, "image_url": <saved paths>,
        "class": "确认表"}``. Any exception raised during OCR / LLM processing
        is returned as 400 ``{"error": str(e)}``.
    """
    image_url = []

    for file in files:
        # content_type may be None when the client omits it; treat that as an
        # unsupported type instead of crashing with AttributeError.
        if not (file.content_type or "").startswith('image/'):
            raise HTTPException(status_code=400, detail="Unsupported file type")
        # Keep only the last path component of the client-controlled name so a
        # crafted name such as "../../x.png" cannot write outside images_dir.
        safe_name = os.path.basename(file.filename or "")
        if safe_name in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="Invalid file name")

        # Save the image file.
        file_location = images_dir + safe_name
        with open(file_location, "wb") as file_object:
            file_object.write(await file.read())
        image_url.append(file_location)

    try:
        text_prompt = config.confirmation
        ocr_text = get_ocr_list(image_url)  # run OCR over all saved images
        output = run_conv(text_prompt, ocr_text)  # ask the LLM
        return JSONResponse(content={"text": output, "image_url": image_url, "class": "确认表"})
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
|
|||
|
|
|
|||
|
|
# Contract endpoint.
@app.post("/contract/")
async def contract(files: UploadFile = File(...)):
    """Extract contract fields from an uploaded PDF.

    The PDF is stored under ``pdfs_dir`` with a random UUID file name. Pages
    are selected with ``pagehome_llm``, and their OCR text is sent to the LLM
    together with ``config.contract``.

    Raises:
        HTTPException: 400 when the upload is not ``application/pdf``.

    Returns:
        JSONResponse with the LLM output, the image list produced by
        ``get_ocr_image_list`` and the fixed class label. Errors raised while
        processing are reported as 400 ``{"error": str(e)}``.
    """
    is_pdf = files.content_type == 'application/pdf'
    if not is_pdf:
        raise HTTPException(status_code=400, detail="文件类型不符")

    saved_pdf = f"{pdfs_dir}{uuid.uuid4()}.pdf"
    with open(saved_pdf, "wb+") as sink:
        sink.write(await files.read())
    print("合同上传位置:", saved_pdf)

    try:
        page_images, page_texts = pagehome_llm(saved_pdf)
        prompt = config.contract
        print("contract img_url: ", page_images)

        llm_output = run_conv(prompt, page_texts)
        marked_images = get_ocr_image_list(page_images, llm_output)

        payload = {"text": llm_output, "image_url": marked_images, "class": "合同表"}
        return JSONResponse(content=payload)
    except Exception as err:
        return JSONResponse(status_code=400, content={"error": str(err)})
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.post("/contractPayment/")
async def contract_amount(files: UploadFile = File(...)):
    """Locate the contract-price page and the settlement method in a PDF.

    The upload must be ``application/pdf`` (otherwise HTTPException 400). It is
    saved under ``pdfs_dir`` with a random UUID name. Pages are then rendered
    with ``pdf_2_images`` and OCR'd with ``get_ocr`` one by one:

    * ``text["page"]``: 1-based number (string) of the last scanned page
      containing "合同价款暂定为人民币含税价小写", or ``"0"`` if none did.
    * ``text["boolean"]``: ``"是"`` when the settlement clause
      "合同价款结算按第<N>_种方式" names method 2, otherwise ``"否"``.

    Scanning stops at the first page containing "合同价款结算按第".
    Errors raised while processing are returned as 400 ``{"error": str(e)}``.
    """
    if files.content_type != 'application/pdf':
        raise HTTPException(status_code=400, detail="Unsupported file type")

    random_filename = str(uuid.uuid4())
    file_location = pdfs_dir + f"{random_filename}.pdf"
    with open(file_location, "wb+") as file_object:
        file_object.write(await files.read())
    print("合同上传位置:", file_location)

    try:
        # Defaults reported when the corresponding clause is not found.
        result = {"page": "0", "boolean": "否"}

        # Open the PDF; always close it so the file handle is released.
        pdf_document = fitz.open(file_location)
        try:
            for page_num in range(len(pdf_document)):
                image_filename = pdf_2_images(pdf_document, images_dir, page_num)
                page_text = str(get_ocr(image_filename))

                if "合同价款暂定为人民币含税价小写" in page_text:
                    # pdf_2_images names pages "...page_<page_num + 1>.png";
                    # use that 1-based number directly instead of re-parsing it.
                    result["page"] = str(page_num + 1)

                if "合同价款结算按第" in page_text:
                    match = re.search(r"合同价款结算按第(\d+)_种方式", page_text)
                    # OCR may not reproduce the "_" placeholder; a failed match
                    # now means "not method 2" instead of an AttributeError that
                    # aborted the whole request.
                    if match is not None and match.group(1) == "2":
                        result["boolean"] = "是"
                    # NOTE(review): the original indentation was lost; the scan
                    # is assumed to end once the settlement clause is seen.
                    break
        finally:
            pdf_document.close()

        return JSONResponse(content={"text": result, "pdf_url": file_location, "class": "合同抽取"})

    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
|
|||
|
|
|
|||
|
|
|
|||
|
|
def pdf_2_images(pdf_document, output_folder, page_num):
    """Render one PDF page to a PNG file and return the file's path.

    Args:
        pdf_document: An open PyMuPDF (``fitz``) document. Its ``name`` is used
            to derive the output file name.
        output_folder: Prefix the file name is appended to by plain string
            concatenation (callers in this module pass a directory path that
            ends with a separator).
        page_num: 0-based page index.

    Returns:
        ``"{output_folder}{stem}page_{page_num + 1}.png"``, where ``stem`` is
        the document's file name without directory or extension.
    """
    # Load the page.
    page = pdf_document.load_page(page_num)

    # Rasterize the page (get_pixmap() with its default arguments).
    pix = page.get_pixmap()

    # Save the image with PIL.
    img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

    # Split on both "\" and "/" so POSIX paths, relative names and upper-case
    # ".PDF" extensions work too; a backslash-only "\\(...)\.pdf$" match
    # returned None for those and crashed on .group().
    base_name = re.split(r"[\\/]", pdf_document.name)[-1]
    contract_name = os.path.splitext(base_name)[0]

    image_filename = "{}{}page_{}.png".format(output_folder, contract_name, page_num + 1)
    img.save(image_filename, "PNG")
    print("合同分页保存位置:", image_filename)
    return image_filename
|
|||
|
|
|
|||
|
|
def pagehome_llm(pdf_path):
    """Find the contract's opening page and collect it with the two after it.

    Pages are rendered (``pdf_2_images``) and OCR'd (``get_ocr``) in order.
    The start page is the first page at which the OCR text accumulated so far
    contains all of "甲方", "乙方", "合同编号" and "签订地点".

    Returns:
        ``(image_paths, ocr_texts)``. If a start page is found:

        * ``image_paths`` holds that page plus up to two following pages
          (fewer near the end of the document).
        * ``ocr_texts`` holds the OCR results of every page scanned up to and
          including the start page, followed by those of the extra pages.

        If no page matches, ``image_paths`` is empty and ``ocr_texts`` holds
        the OCR result of every page.
    """
    keywords = ("甲方", "乙方", "合同编号", "签订地点")
    result = []
    ocr_text = []
    pdf_document = fitz.open(pdf_path)
    try:
        page_count = len(pdf_document)
        for page_num in range(page_count):
            image_filename = pdf_2_images(pdf_document, images_dir, page_num)
            ocr_text.append(get_ocr(image_filename))

            # `"甲方" and "乙方" and ... in text` only tested the last keyword
            # (non-empty strings are truthy); require every keyword.
            # NOTE(review): like before, this searches all text OCR'd so far,
            # not only the current page.
            accumulated = str(ocr_text)
            if all(keyword in accumulated for keyword in keywords):
                print("合同识别起始页:", page_num)
                result.append(image_filename)

                # Add the next two pages, clamped so a match on one of the
                # last pages does not request a page beyond the document.
                for extra_page in range(page_num + 1, min(page_num + 3, page_count)):
                    image_filename = pdf_2_images(pdf_document, images_dir, extra_page)
                    ocr_text.append(get_ocr(image_filename))
                    result.append(image_filename)

                return result, ocr_text
    finally:
        # Release the PDF file handle on every exit path.
        pdf_document.close()

    return result, ocr_text
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
def find_payment_page(pdf_path, output_folder):
    """Return the 1-based page number (as a string) of the contract-price clause.

    Pages are rendered into ``output_folder`` with ``pdf_2_images`` and OCR'd
    one at a time. The first page whose OCR text contains
    "合同价款暂定为人民币含税价小写" wins.

    Returns:
        The page number as a string, or ``None`` when no page matches.
    """
    pdf_document = fitz.open(pdf_path)
    try:
        for page_num in range(len(pdf_document)):
            # Honour the output_folder argument; it used to be ignored in favour
            # of images_dir, which is what the caller in this file passes anyway.
            image_filename = pdf_2_images(pdf_document, output_folder, page_num)
            ocr_text = get_ocr(image_filename)
            if "合同价款暂定为人民币含税价小写" in str(ocr_text):
                # pdf_2_images numbers page images page_num + 1.
                result = str(page_num + 1)
                print("位置为第" + result + "页")
                return result
    finally:
        # Release the PDF file handle on every exit path.
        pdf_document.close()
    return None
|
|||
|
|
|
|||
|
|
def extract_images(image_index, pdf_path, output_folder):
    """OCR the page at ``image_index`` and the page after it, then query the LLM.

    Args:
        image_index: 1-based page number (int or numeric string), e.g. the
            value returned by ``find_payment_page``.
        pdf_path: Path of the PDF to read.
        output_folder: Prefix the rendered page images are written to.

    Returns:
        ``JSONResponse(content=json_data["result"])`` when the model's JSON is
        an object with a truthy ``"result"`` value, otherwise ``None``.

    Raises:
        TypeError / ValueError: if ``image_index`` cannot be converted to int.
        json.JSONDecodeError: if the model's answer is not valid JSON.
    """
    index = int(image_index)
    text_prompt = config.extract
    ocr_texts = []

    pdf_document = fitz.open(pdf_path)
    try:
        page_count = len(pdf_document)
        if 1 <= index <= page_count:
            # Two consecutive pages starting at `index`, clamped so that a
            # match on the last page does not request a non-existent page.
            for page_num in range(index - 1, min(index + 1, page_count)):
                image_filename = pdf_2_images(pdf_document, output_folder, page_num)
                ocr_texts.append(get_ocr(image_filename))
        else:
            print("超出PDF文件长度")
    finally:
        # Release the PDF file handle on every exit path.
        pdf_document.close()

    output = run_conv(text_prompt, ocr_texts)
    # Drop ASCII/C1 control characters; json.loads rejects raw control
    # characters inside strings.
    output = re.sub(r'[\x00-\x1F\x7F-\x9F]', '', output)
    json_data = json.loads(output)

    # Tolerate a missing "result" key or a non-object reply instead of
    # raising KeyError/TypeError.
    result = json_data.get("result") if isinstance(json_data, dict) else None
    if result:
        return JSONResponse(content=result)
    return None
|
|||
|
|
|
|||
|
|
@app.post("/findAndExtract/")
async def findAndExtract(request: Request):
    """Find the contract-price page of a server-side PDF and extract its fields.

    Expects a JSON body ``{"file_path": "<path of a PDF on this server>"}``.
    Responds 400 when the path is missing, not a string or does not exist, and
    when no page contains the contract-price clause.
    """
    body = await request.json()
    file_path = body.get("file_path") if isinstance(body, dict) else None

    # Check that the path exists. A missing/non-string value is rejected here
    # because os.path.exists(None) raises TypeError (previously a 500).
    if not isinstance(file_path, str) or not os.path.exists(file_path):
        return JSONResponse(content={"error": "File path does not exist"}, status_code=400)
    # NOTE(review): file_path comes straight from the client, so any file the
    # server process can read may be processed; consider restricting it to
    # pdfs_dir.

    index = find_payment_page(file_path, images_dir)  # 1-based page number or None
    if index is None:
        # Previously int(None) inside extract_images raised TypeError (a 500).
        return JSONResponse(content={"error": "Payment page not found"}, status_code=400)

    result = extract_images(index, file_path, images_dir)
    # NOTE(review): extract_images returns a JSONResponse (or None). Passing a
    # Response object to jsonable_encoder appears to serialize its attributes
    # (status_code, body, raw_headers, ...) rather than the extracted data.
    # Left unchanged because clients may already parse that shape; confirm with
    # the frontend before returning the data directly.
    json_result = jsonable_encoder(result)
    return JSONResponse(content=json_result)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.post('/classify')
async def get_images(request: Request):
    """Classify the images in a server-side folder with the trained model.

    Expects a JSON body ``{"file_path": "<folder path on this server>"}`` and
    returns the JSON-encoded result of ``predict_folder``. Responds 400 when
    the path is missing, not a string or does not exist.
    """
    body = await request.json()
    folder_path = body.get("file_path") if isinstance(body, dict) else None

    # Check that the folder path exists. A missing/non-string value is rejected
    # too, because os.path.exists(None) raises TypeError (previously a 500).
    if not isinstance(folder_path, str) or not os.path.exists(folder_path):
        return JSONResponse(content={"error": "File path does not exist"}, status_code=400)
    # NOTE(review): folder_path comes straight from the client; consider
    # restricting it to the upload directory.

    class_names = ['发票', '确认表', '申请表', '验收证书', '其他']
    # Trained weights; resolved relative to the process working directory.
    weights_path = "../Models/checkpoints/model.pth"

    # Deterministic preprocessing for inference. RandomHorizontalFlip,
    # RandomRotation and ColorJitter are training-time augmentations; passed
    # here (presumably applied by predict_folder to every image — confirm) they
    # made identical requests yield different predictions.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    result = predict_folder(folder_path, weights_path, transform, class_names)
    json_result = jsonable_encoder(result)
    return JSONResponse(content=json_result)
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Serve plain HTTP on all interfaces, port 8000.
    # To serve HTTPS, pass the certificate and key files to uvicorn, e.g.
    #   uvicorn.run(app, host="0.0.0.0", port=8000,
    #               ssl_certfile="certificate.crt", ssl_keyfile="private.key")
    # uvicorn.run has no `ssl=` parameter, so an ssl.SSLContext cannot be
    # passed to it directly.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|||
|
|
|