377 lines
12 KiB
Python
377 lines
12 KiB
Python
import uuid
|
||
|
||
from fastapi import FastAPI, File, UploadFile, HTTPException, Request
|
||
import ssl
|
||
|
||
from uvicorn import Config, Server
|
||
from fastapi.responses import JSONResponse
|
||
from fastapi.encoders import jsonable_encoder
|
||
import requests
|
||
import uvicorn
|
||
import config
|
||
from pydantic import BaseModel
|
||
from typing import Union, List
|
||
import fitz
|
||
from PIL import Image
|
||
import json
|
||
import re
|
||
import os
|
||
from openai import OpenAI
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.staticfiles import StaticFiles
|
||
|
||
from classify import predict_folder
|
||
from ocr import get_ocr_list, get_ocr_image_list, get_ocr
|
||
from torchvision import transforms, models
|
||
|
||
|
||
app = FastAPI()

# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; restrict `allow_origins` before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Make sure the upload directories exist *before* the static mount below:
# StaticFiles validates its directory when it is constructed and raises if the
# directory is missing (e.g. on a fresh checkout).
#
# NOTE: the trailing backslashes are intentional and left untouched. These
# values are plain string prefixes that the handlers in this module concatenate
# file names onto.
current_directory = os.getcwd()
upload_dir = os.path.join(current_directory, "data\\")
images_dir = upload_dir + "img\\"
pdfs_dir = upload_dir + "content\\"
os.makedirs(images_dir, exist_ok=True)
os.makedirs(pdfs_dir, exist_ok=True)
# The static root is given relative to the working directory; make sure that
# exact directory exists as well.
os.makedirs("data", exist_ok=True)

app.mount("/data", StaticFiles(directory="data"), name="data")
|
||
|
||
|
||
def _extract_fenced_json(reply):
    """Return the text inside the first fenced block of *reply*.

    A fence tagged ``json`` is preferred; an untagged fence is accepted as a
    fallback. Raises ``ValueError`` when the reply contains no fenced block.
    """
    for pattern in (r"```json(.*?)```", r"```(.*?)```"):
        match = re.search(pattern, reply, re.DOTALL)
        if match is not None:
            return match.group(1)
    raise ValueError("LLM reply did not contain a fenced JSON block")


def run_conv(text_prompt, ocr_text):
    """Send OCR text and an extraction prompt to the LLM and return its JSON text.

    Args:
        text_prompt: Instruction text for the model (callers pass a value from
            ``config``).
        ocr_text: OCR output; it is converted with ``str()`` before sending.

    Returns:
        The raw text found inside the first fenced JSON block of the model's
        reply. The text is not parsed or stripped.

    Raises:
        ValueError: If the reply contains no fenced block at all.
    """
    print("begin deepseek ask")
    # NOTE(review): the OCR text is sent with the "user" role and the prompt
    # with the "assistant" role. This ordering is preserved as-is; confirm it
    # is intentional.
    messages = [
        {'role': 'user', 'content': str(ocr_text)},
        {'role': 'assistant', 'content': text_prompt},
    ]

    # Local OpenAI-compatible endpoint. If a hosted endpoint is ever used
    # instead, load its API key from the environment; never commit it to source.
    client = OpenAI(api_key='ollama',
                    base_url="http://localhost:11434/v1")
    response = client.chat.completions.create(
        model="deepseek-r1",
        messages=messages,
    )
    # The content field may be None; treat that as an empty reply so the
    # failure surfaces as the ValueError below rather than a TypeError.
    final_response = response.choices[0].message.content or ""
    res = _extract_fenced_json(final_response)
    print("llm result: ", res)

    return res
|
||
|
||
|
||
# Invoice
@app.post("/invoice/")
async def invoice(files: List[UploadFile] = File(...)):
    """Save uploaded invoice images, OCR them and extract fields with the LLM.

    Args:
        files: One or more uploaded image files.

    Returns:
        A JSON response with the LLM output under ``text``, the value returned
        by ``get_ocr_image_list`` under ``image_url`` and a fixed ``class``
        label. Any processing error is reported as a 400 response carrying the
        error text.

    Raises:
        HTTPException: 400 when an upload is not an image or has an unusable
            file name.
    """
    image_url = []
    for file in files:
        # content_type can be missing on a malformed multipart part.
        if not (file.content_type or "").startswith('image/'):
            raise HTTPException(status_code=400, detail="Unsupported file type")

        # The file name is client-controlled: keep only its last path
        # component so it cannot escape the image directory.
        safe_name = re.split(r"[\\/]", file.filename or "")[-1]
        if safe_name in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="Invalid file name")

        # Save the image file.
        file_location = images_dir + safe_name
        with open(file_location, "wb+") as file_object:
            file_object.write(await file.read())

        image_url.append(file_location)
    print("image_path: ", image_url)

    try:
        text_prompt = config.invoice

        ocr_text = get_ocr_list(image_url)  # run OCR over the saved images
        output = run_conv(text_prompt, ocr_text)  # query the LLM
        new_url = get_ocr_image_list(image_url, output)

        return JSONResponse(content={"text": output, "image_url": new_url, "class": "发票"})

    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
|
||
|
||
# Application form
@app.post("/application/")
async def application(files: List[UploadFile] = File(...)):
    """Save uploaded application-form images, OCR them and query the LLM.

    Args:
        files: One or more uploaded image files.

    Returns:
        A JSON response with the LLM output under ``text``, the saved file
        paths under ``image_url`` and a fixed ``class`` label. Any processing
        error is reported as a 400 response carrying the error text.

    Raises:
        HTTPException: 400 when an upload is not an image or has an unusable
            file name.
    """
    image_url = []
    for file in files:
        # content_type can be missing on a malformed multipart part.
        if not (file.content_type or "").startswith('image/'):
            raise HTTPException(status_code=400, detail="Unsupported file type")

        # The file name is client-controlled: keep only its last path
        # component so it cannot escape the image directory.
        safe_name = re.split(r"[\\/]", file.filename or "")[-1]
        if safe_name in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="Invalid file name")

        # Save the image file.
        file_location = images_dir + safe_name
        with open(file_location, "wb+") as file_object:
            file_object.write(await file.read())

        image_url.append(file_location)

    try:
        text_prompt = config.application
        ocr_text = get_ocr_list(image_url)  # run OCR over the saved images
        output = run_conv(text_prompt, ocr_text)  # query the LLM

        return JSONResponse(content={"text": output, "image_url": image_url, "class": "申请单"})
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
|
||
|
||
# Confirmation form
@app.post("/confirmation/")
async def confirmation(files: List[UploadFile] = File(...)):
    """Save uploaded confirmation-form images, OCR them and query the LLM.

    Args:
        files: One or more uploaded image files.

    Returns:
        A JSON response with the LLM output under ``text``, the saved file
        paths under ``image_url`` and a fixed ``class`` label. Any processing
        error is reported as a 400 response carrying the error text.

    Raises:
        HTTPException: 400 when an upload is not an image or has an unusable
            file name.
    """
    image_url = []

    for file in files:
        # content_type can be missing on a malformed multipart part.
        if not (file.content_type or "").startswith('image/'):
            raise HTTPException(status_code=400, detail="Unsupported file type")

        # The file name is client-controlled: keep only its last path
        # component so it cannot escape the image directory.
        safe_name = re.split(r"[\\/]", file.filename or "")[-1]
        if safe_name in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="Invalid file name")

        # Save the image file.
        file_location = images_dir + safe_name
        with open(file_location, "wb+") as file_object:
            file_object.write(await file.read())

        image_url.append(file_location)

    try:
        text_prompt = config.confirmation
        ocr_text = get_ocr_list(image_url)
        output = run_conv(text_prompt, ocr_text)

        return JSONResponse(content={"text": output, "image_url": image_url, "class": "确认表"})
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
|
||
|
||
# Contract
@app.post("/contract/")
async def contract(files: UploadFile = File(...)):
    """Store an uploaded contract PDF and extract its fields with OCR + LLM.

    The PDF is written under a random UUID-based name, handed to
    ``pagehome_llm`` to obtain page images and OCR text, and the OCR text is
    then sent to the LLM with the contract prompt. Non-PDF uploads are
    rejected with a 400 ``HTTPException``; any processing failure is returned
    as a 400 JSON response containing the error text.
    """
    is_pdf = files.content_type == "application/pdf"
    if not is_pdf:
        raise HTTPException(status_code=400, detail="文件类型不符")

    # Random name so uploads never collide.
    file_location = "{}{}.pdf".format(pdfs_dir, uuid.uuid4())

    with open(file_location, "wb+") as pdf_out:
        payload = await files.read()
        pdf_out.write(payload)
    print("合同上传位置:", file_location)

    try:
        page_images, page_texts = pagehome_llm(file_location)
        prompt = config.contract
        print("contract img_url: ", page_images)

        llm_output = run_conv(prompt, page_texts)
        annotated = get_ocr_image_list(page_images, llm_output)

        payload = {"text": llm_output, "image_url": annotated, "class": "合同表"}
        return JSONResponse(content=payload)
    except Exception as exc:
        failure = {"error": str(exc)}
        return JSONResponse(status_code=400, content=failure)
|
||
|
||
|
||
@app.post("/contractPayment/")
async def contract_amount(files: UploadFile = File(...)):
    """Locate the price clause page and the settlement method in a contract PDF.

    The uploaded PDF is stored under a random name, then rendered and OCR'd
    page by page.

    Returns:
        A JSON response whose ``text`` field holds ``page`` (1-based page
        number, as a string, of the last scanned page containing the price
        clause, or ``"0"`` when none was found) and ``boolean`` (``"是"`` when
        the settlement clause selects method 2, otherwise ``"否"``). Any
        processing error is reported as a 400 response carrying the error text.

    Raises:
        HTTPException: 400 when the upload is not a PDF.
    """
    if files.content_type != 'application/pdf':
        raise HTTPException(status_code=400, detail="Unsupported file type")

    random_filename = str(uuid.uuid4())
    file_location = pdfs_dir + f"{random_filename}.pdf"

    with open(file_location, "wb+") as file_object:
        file_object.write(await files.read())
    print("合同上传位置:", file_location)

    try:
        result = {"page": "", "boolean": ""}

        # Open the PDF; the context manager closes it even if OCR fails.
        with fitz.open(file_location) as pdf_document:
            for page_num in range(len(pdf_document)):
                image_filename = pdf_2_images(pdf_document, images_dir, page_num)
                page_text = str(get_ocr(image_filename))

                if "合同价款暂定为人民币含税价小写" in page_text:
                    # Pages are reported 1-based, matching the number that
                    # pdf_2_images puts in the image file name.
                    result["page"] = str(page_num + 1)

                if "合同价款结算按第" in page_text:
                    match = re.search(r"合同价款结算按第(\d+)_种方式", page_text)
                    # The marker text can be present without the full pattern
                    # matching (OCR noise); treat that as "not decided yet".
                    if match is not None and match.group(1) == "2":
                        result["boolean"] = "是"
                        # NOTE(review): scanning stops once method 2 is
                        # confirmed — verify this matches the intended nesting.
                        break

        if result["page"] == "":
            result["page"] = "0"
        if result["boolean"] == "":
            result["boolean"] = "否"
        return JSONResponse(content={"text": result, "pdf_url": file_location, "class": "合同抽取"})

    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
|
||
|
||
|
||
def pdf_2_images(pdf_document, output_folder, page_num):
    """Render one PDF page to a PNG file and return the file's path.

    Args:
        pdf_document: An open PyMuPDF document.
        output_folder: String prefix the image name is appended to (callers in
            this module pass a directory path ending in a separator).
        page_num: 0-based index of the page to render.

    Returns:
        The path of the written image, formed as
        ``<output_folder><pdf name>page_<page_num + 1>.png``.
    """
    # Load the page.
    page = pdf_document.load_page(page_num)

    # Rasterise the page.
    pix = page.get_pixmap()

    # Wrap the raw samples in a PIL image so it can be saved as PNG.
    img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

    # Derive the PDF's base name. Both "\" and "/" are accepted as separators
    # (and a bare file name works too), so paths supplied by API clients do not
    # crash the name lookup.
    base_name = re.split(r"[\\/]", pdf_document.name)[-1]
    contract_name = re.sub(r"\.pdf$", "", base_name, flags=re.IGNORECASE)

    image_filename = "{}{}page_{}.png".format(output_folder, contract_name, page_num + 1)
    img.save(image_filename, "PNG")
    print("合同分页保存位置:", image_filename)
    return image_filename
|
||
|
||
def pagehome_llm(pdf_path):
    """Find the contract's first page and collect it plus the next two pages.

    Pages are rendered and OCR'd in order. The start page is the first page at
    which the OCR text gathered so far contains all of the party, contract
    number and signing place markers.

    Args:
        pdf_path: Path of the contract PDF.

    Returns:
        A tuple ``(image_paths, ocr_texts)``. ``image_paths`` holds the start
        page image and up to two following pages (empty when no start page was
        found). ``ocr_texts`` holds the OCR output of every page processed.
    """
    # Every marker must be present. (Chaining string literals with `and` would
    # only test the last one.)
    keywords = ("甲方", "乙方", "合同编号", "签订地点")
    result = []
    ocr_text = []

    with fitz.open(pdf_path) as pdf_document:
        page_count = len(pdf_document)
        for page_num in range(page_count):
            image_filename = pdf_2_images(pdf_document, images_dir, page_num)
            ocr_text.append(get_ocr(image_filename))

            seen_text = str(ocr_text)
            if not all(keyword in seen_text for keyword in keywords):
                continue

            print("合同识别起始页:", page_num)
            result.append(image_filename)

            # Add the following two pages, but never read past the last page.
            for extra_page in range(page_num + 1, min(page_num + 3, page_count)):
                image_filename = pdf_2_images(pdf_document, images_dir, extra_page)
                ocr_text.append(get_ocr(image_filename))
                result.append(image_filename)

            return result, ocr_text

    return result, ocr_text
|
||
|
||
|
||
|
||
|
||
|
||
def find_payment_page(pdf_path, output_folder):
    """Return the 1-based page number of the contract price clause.

    Args:
        pdf_path: Path of the contract PDF.
        output_folder: Currently unused; page images are always written to the
            module-level ``images_dir``.

    Returns:
        The page number as a string, or ``None`` when no page contains the
        price clause marker.
    """
    # The context manager closes the document on every exit path.
    with fitz.open(pdf_path) as pdf_document:
        for page_num in range(len(pdf_document)):
            image_filename = pdf_2_images(pdf_document, images_dir, page_num)
            ocr_text = get_ocr(image_filename)
            if "合同价款暂定为人民币含税价小写" in str(ocr_text):
                # Same 1-based number that pdf_2_images embeds in the file name.
                result = str(page_num + 1)
                print("位置为第" + result + "页")
                return result
    return None
|
||
|
||
def extract_images(image_index, pdf_path, output_folder):
    """OCR the given page and the one after it, then extract data via the LLM.

    Args:
        image_index: 1-based page number (anything ``int()`` accepts).
        pdf_path: Path of the contract PDF.
        output_folder: Currently unused; page images are always written to the
            module-level ``images_dir``.

    Returns:
        A ``JSONResponse`` wrapping the ``"result"`` value of the LLM's JSON
        output, or ``None`` when the page is out of range or the LLM output
        has no non-empty ``"result"``.
    """
    index = int(image_index)
    text_prompt = config.extract
    ocr_texts = []

    with fitz.open(pdf_path) as pdf_document:
        page_count = len(pdf_document)
        if 1 <= index <= page_count:
            # Requested page plus the next one, clamped to the document end so
            # a request for the last page does not read out of range.
            for page_no in range(index - 1, min(index + 1, page_count)):
                image_filename = pdf_2_images(pdf_document, images_dir, page_no)
                ocr_texts.append(get_ocr(image_filename))
        else:
            print("超出PDF文件长度")

    if not ocr_texts:
        # Nothing to send; skip the LLM call.
        return None

    output = run_conv(text_prompt, ocr_texts)
    # Drop control characters that would make json.loads reject the text.
    output = re.sub(r'[\x00-\x1F\x7F-\x9F]', '', output)
    json_data = json.loads(output)

    result = json_data.get("result") if isinstance(json_data, dict) else None
    if result:
        return JSONResponse(content=result)
    return None
|
||
|
||
@app.post("/findAndExtract/")
async def findAndExtract(request: Request):
    """Locate the price clause page of a PDF and run extraction on it.

    Expects a JSON body with a ``file_path`` entry naming a PDF on the
    server's file system.

    Returns:
        A JSON response built from the JSON-encoded value returned by
        ``extract_images``, or a 400 response with an ``error`` message when
        the path is missing/unknown or no price clause page is found.
    """
    body = await request.json()
    file_path = body.get("file_path") if isinstance(body, dict) else None

    # NOTE(review): file_path comes straight from the client and is opened on
    # the server; consider restricting it to the upload directory.
    # Check that the path was supplied and exists.
    if not file_path or not os.path.exists(file_path):
        return JSONResponse(content={"error": "File path does not exist"}, status_code=400)

    index = find_payment_page(file_path, images_dir)  # page where the clause was found
    if index is None:
        # Without a page number extraction cannot run.
        return JSONResponse(content={"error": "Payment clause page not found"}, status_code=400)

    result = extract_images(index, file_path, images_dir)
    # NOTE(review): extract_images returns a JSONResponse object, which is
    # encoded and wrapped again here; kept unchanged for client compatibility.
    json_result = jsonable_encoder(result)
    return JSONResponse(content=json_result)
|
||
|
||
|
||
@app.post('/classify')
async def get_images(request: Request):
    """Classify the images in a server-side folder by document type.

    Expects a JSON body with a ``file_path`` entry naming the folder.

    Returns:
        A JSON response with the JSON-encoded result of ``predict_folder``, or
        a 400 response with an ``error`` message when the path is missing or
        does not exist.
    """
    body = await request.json()
    folder_path = body.get("file_path") if isinstance(body, dict) else None

    # Check that the folder path was supplied and exists.
    if not folder_path or not os.path.exists(folder_path):
        return JSONResponse(content={"error": "File path does not exist"}, status_code=400)

    class_names = ['发票', '确认表', '申请表', '验收证书', '其他']
    # Path of the trained weights.
    weights_path = "../Models/checkpoints/model.pth"

    # Inference-time preprocessing: resize, tensor conversion and
    # normalisation only. Random flip/rotation/colour jitter are training-time
    # augmentations and would make predictions vary between identical requests.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    result = predict_folder(folder_path, weights_path, transform, class_names)
    json_result = jsonable_encoder(result)
    return JSONResponse(content=json_result)
|
||
|
||
if __name__ == "__main__":
    # Script entry point: serve the API over plain HTTP on all interfaces,
    # port 8000.
    #
    # Disabled TLS variant, kept for reference:
    # ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    # ssl_context.load_cert_chain(certfile="/home/oem/llm/certificate.crt",keyfile="/home/oem/llm/private.key")
    # config = Config(app=app, host="0.0.0.0", port=8000, ssl_context=ssl_context)
    # uvicorn.run(app, host="0.0.0.0", port=8000,ssl=ssl_context)

    uvicorn.run(app, host="0.0.0.0", port=8000)

    # Disabled manual check of the folder classifier, kept for reference:
    # weights_path = "../Models/checkpoints/model_epoch_1.pth"
    #
    # # Apply transforms to the image
    # transform = transforms.Compose([
    # transforms.Resize((224, 224)),
    # transforms.RandomHorizontalFlip(),
    # transforms.RandomRotation(10),
    # transforms.ColorJitter(brightness=0.2, contrast=0.2),
    # transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # ])
    # class_names = ['发票', '确定表', '申请表']
    # print(predict_folder("data/img/分类", weights_path, transform, class_names))
    # app.run(host='0.0.0.0', port=5001)
|
||
|