nwjh/LLMServe/server.py
2025-03-24 09:27:03 +08:00

377 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import uuid
from fastapi import FastAPI, File, UploadFile, HTTPException, Request
import ssl
from uvicorn import Config, Server
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
import requests
import uvicorn
import config
from pydantic import BaseModel
from typing import Union, List
import fitz
from PIL import Image
import json
import re
import os
from openai import OpenAI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from classify import predict_folder
from ocr import get_ocr_list, get_ocr_image_list, get_ocr
from torchvision import transforms, models
# FastAPI application serving the OCR + LLM document-extraction endpoints below.
app = FastAPI()
# NOTE(review): allow_origins=["*"] together with allow_credentials=True may let
# any website issue credentialed cross-origin requests to this service. Confirm
# against Starlette's CORSMiddleware behaviour and restrict allow_origins to the
# known front-end hosts if the service is reachable outside a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# StaticFiles checks that its directory exists when it is constructed, and the
# upload directories are only created further down in this module. Create
# "data" first so the server also starts from a fresh checkout.
os.makedirs("data", exist_ok=True)
app.mount("/data", StaticFiles(directory="data"), name="data")
# Ensure the upload directories exist under <cwd>/data. Page and upload images
# go to images_dir; uploaded PDFs go to pdfs_dir. Both keep a trailing separator
# because callers build file paths by plain string concatenation.
# NOTE(review): the separators are hard-coded Windows backslashes. On POSIX "\\"
# is an ordinary filename character, so these become single, oddly named
# directories in the working directory rather than data/img and data/content.
# Code elsewhere in this module parses these paths (e.g. pdf_2_images), so audit
# those callers before switching to os.sep / os.path.join.
current_directory = os.getcwd()
upload_dir = os.path.join(current_directory, "data\\")
images_dir, pdfs_dir = (upload_dir + _subdir for _subdir in ("img\\", "content\\"))
for _required_dir in (images_dir, pdfs_dir):
    os.makedirs(_required_dir, exist_ok=True)
def run_conv(text_prompt, ocr_text):
    """Send OCR text plus an extraction prompt to the local LLM and return its JSON payload.

    Args:
        text_prompt: Extraction instructions (one of the prompts from ``config``).
        ocr_text: OCR output for the document; converted with ``str()`` before sending.

    Returns:
        The raw text between the first ```` ```json ```` fence and the next
        ```` ``` ```` in the model's reply. It is not parsed, and it may include
        surrounding whitespace.

    Raises:
        ValueError: If the reply is empty or contains no ```` ```json ```` fenced block.
    """
    print("begin deepseek ask")
    # NOTE(review): the instructions are sent as an *assistant* turn after the
    # OCR text. Chat models usually expect instructions in a system/user
    # message; confirm this ordering is intentional for the configured prompts.
    messages = [
        {'role': 'user', 'content': str(ocr_text)},
        {'role': 'assistant', 'content': text_prompt},
    ]
    # Local OpenAI-compatible endpoint (presumably Ollama, default port 11434).
    # 'ollama' is a placeholder key for that local server.
    client = OpenAI(api_key='ollama', base_url="http://localhost:11434/v1")
    response = client.chat.completions.create(
        model="deepseek-r1",
        messages=messages,
    )
    # message.content can be None; treat that like an empty reply.
    final_response = response.choices[0].message.content or ""
    match = re.search(r"```json(.*?)```", final_response, re.DOTALL)
    if match is None:
        raise ValueError("LLM response did not contain a ```json fenced block")
    res = match.group(1)
    print("llm result: ", res)
    return res
# Invoice (发票) endpoint: OCR the uploaded images, ask the LLM to extract the
# fields described by config.invoice, and return them with the image URLs
# produced by get_ocr_image_list.
@app.post("/invoice/")
async def invoice(files: List[UploadFile] = File(...)):
    """Extract invoice fields from one or more uploaded images.

    All uploads are validated before any is written, so a rejected file does
    not leave earlier files from the same request on disk.

    Returns:
        200 JSON ``{"text": <LLM JSON text>, "image_url": ..., "class": "发票"}``,
        or 400 JSON ``{"error": ...}`` if the OCR/LLM step fails.

    Raises:
        HTTPException: 400 if a file is not an image or has an unusable name.
    """
    # content_type may be None, and the client-chosen filename must not be able
    # to escape images_dir (e.g. "../../server.py"), so keep only its last component.
    pending = []
    for file in files:
        if not (file.content_type or "").startswith('image/'):
            raise HTTPException(status_code=400, detail="Unsupported file type")
        safe_name = os.path.basename(file.filename or "")
        if safe_name in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="Invalid file name")
        pending.append((file, images_dir + safe_name))

    image_url = []
    for file, file_location in pending:
        with open(file_location, "wb+") as file_object:
            file_object.write(await file.read())
        image_url.append(file_location)
    print("image_path: ", image_url)
    # NOTE(review): get_ocr_list and run_conv are synchronous calls made inside
    # this async handler. They block the event loop, and so every other request,
    # until they finish. Consider fastapi.concurrency.run_in_threadpool once the
    # OCR backend is confirmed to be thread-safe.
    try:
        text_prompt = config.invoice
        ocr_text = get_ocr_list(image_url)  # run OCR; returns a string
        output = run_conv(text_prompt, ocr_text)  # LLM extraction
        new_url = get_ocr_image_list(image_url, output)
        return JSONResponse(content={"text": output, "image_url": new_url, "class": "发票"})
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
# Application form (申请表) endpoint: OCR the uploaded images, ask the LLM to
# extract the fields described by config.application, and return them with the
# saved image paths.
@app.post("/application/")
async def application(files: List[UploadFile] = File(...)):
    """Extract application-form fields from one or more uploaded images.

    All uploads are validated before any is written, so a rejected file does
    not leave earlier files from the same request on disk.

    Returns:
        200 JSON ``{"text": <LLM JSON text>, "image_url": [saved paths], "class": "申请单"}``,
        or 400 JSON ``{"error": ...}`` if the OCR/LLM step fails.

    Raises:
        HTTPException: 400 if a file is not an image or has an unusable name.
    """
    # content_type may be None, and the client-chosen filename must not be able
    # to escape images_dir (e.g. "../../server.py"), so keep only its last component.
    pending = []
    for file in files:
        if not (file.content_type or "").startswith('image/'):
            raise HTTPException(status_code=400, detail="Unsupported file type")
        safe_name = os.path.basename(file.filename or "")
        if safe_name in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="Invalid file name")
        pending.append((file, images_dir + safe_name))

    image_url = []
    for file, file_location in pending:
        with open(file_location, "wb+") as file_object:
            file_object.write(await file.read())
        image_url.append(file_location)
    # NOTE(review): get_ocr_list and run_conv are synchronous calls made inside
    # this async handler, so they block the event loop until they finish.
    try:
        text_prompt = config.application
        ocr_text = get_ocr_list(image_url)  # run OCR; returns a string
        output = run_conv(text_prompt, ocr_text)  # LLM extraction
        return JSONResponse(content={"text": output, "image_url": image_url, "class": "申请单"})
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
# Confirmation form (确认表) endpoint: OCR the uploaded images, ask the LLM to
# extract the fields described by config.confirmation, and return them with
# the saved image paths.
@app.post("/confirmation/")
async def confirmation(files: List[UploadFile] = File(...)):
    """Extract confirmation-form fields from one or more uploaded images.

    All uploads are validated before any is written, so a rejected file does
    not leave earlier files from the same request on disk.

    Returns:
        200 JSON ``{"text": <LLM JSON text>, "image_url": [saved paths], "class": "确认表"}``,
        or 400 JSON ``{"error": ...}`` if the OCR/LLM step fails.

    Raises:
        HTTPException: 400 if a file is not an image or has an unusable name.
    """
    # content_type may be None, and the client-chosen filename must not be able
    # to escape images_dir (e.g. "../../server.py"), so keep only its last component.
    pending = []
    for file in files:
        if not (file.content_type or "").startswith('image/'):
            raise HTTPException(status_code=400, detail="Unsupported file type")
        safe_name = os.path.basename(file.filename or "")
        if safe_name in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="Invalid file name")
        pending.append((file, images_dir + safe_name))

    image_url = []
    for file, file_location in pending:
        with open(file_location, "wb+") as file_object:
            file_object.write(await file.read())
        image_url.append(file_location)
    # NOTE(review): get_ocr_list and run_conv are synchronous calls made inside
    # this async handler, so they block the event loop until they finish.
    try:
        text_prompt = config.confirmation
        ocr_text = get_ocr_list(image_url)
        output = run_conv(text_prompt, ocr_text)
        return JSONResponse(content={"text": output, "image_url": image_url, "class": "确认表"})
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
# Contract (合同) endpoint: store the uploaded PDF under a random name, locate
# the contract's start pages, and have the LLM extract the fields described by
# config.contract.
@app.post("/contract/")
async def contract(files: UploadFile = File(...)):
    """Extract contract fields from a single uploaded PDF.

    Returns:
        200 JSON ``{"text": <LLM JSON text>, "image_url": ..., "class": "合同表"}``,
        or 400 JSON ``{"error": ...}`` if page detection, OCR or the LLM step fails.

    Raises:
        HTTPException: 400 when the upload is not declared as application/pdf.
    """
    is_pdf = files.content_type == 'application/pdf'
    if not is_pdf:
        raise HTTPException(status_code=400, detail="文件类型不符")
    # A random UUID name avoids collisions between uploads.
    file_location = f"{pdfs_dir}{uuid.uuid4()}.pdf"
    with open(file_location, "wb+") as pdf_out:
        pdf_out.write(await files.read())
    print("合同上传位置:", file_location)
    try:
        page_images, page_texts = pagehome_llm(file_location)
        prompt = config.contract
        print("contract img_url: ", page_images)
        llm_json = run_conv(prompt, page_texts)
        marked_images = get_ocr_image_list(page_images, llm_json)
        payload = {"text": llm_json, "image_url": marked_images, "class": "合同表"}
        return JSONResponse(content=payload)
    except Exception as err:
        return JSONResponse(status_code=400, content={"error": str(err)})
# Contract payment extraction (合同抽取): scan every page of an uploaded contract
# PDF for the provisional-price clause and the settlement-method clause.
@app.post("/contractPayment/")
async def contract_amount(files: UploadFile = File(...)):
    """Locate the provisional contract-price page and the settlement-method clause.

    Returns:
        200 JSON ``{"text": {"page": ..., "boolean": ...}, "pdf_url": <saved path>,
        "class": "合同抽取"}``. ``page`` is the 1-based number (as a string) of
        the last scanned page containing "合同价款暂定为人民币含税价小写", or
        "0" if none did. Scanning stops early at a "settle by method 2" clause.
        On a processing error, returns 400 JSON ``{"error": ...}``.

    Raises:
        HTTPException: 400 when the upload is not declared as application/pdf.
    """
    if files.content_type != 'application/pdf':
        raise HTTPException(status_code=400, detail="Unsupported file type")
    random_filename = str(uuid.uuid4())
    file_location = pdfs_dir + f"{random_filename}.pdf"
    with open(file_location, "wb+") as file_object:
        file_object.write(await files.read())
    print("合同上传位置:", file_location)
    try:
        result = {"page": "", "boolean": ""}
        # Open the PDF with a context manager so the file handle is released.
        with fitz.open(file_location) as pdf_document:
            for page_num in range(len(pdf_document)):
                image_filename = pdf_2_images(pdf_document, images_dir, page_num)
                page_text = str(get_ocr(image_filename))
                if "合同价款暂定为人民币含税价小写" in page_text:
                    # pdf_2_images names the image "...page_<page_num + 1>.png".
                    result["page"] = str(page_num + 1)
                if "合同价款结算按第" in page_text:
                    match = re.search(r"合同价款结算按第(\d+)_种方式", page_text)
                    # OCR may not reproduce the "_" blank exactly. A missing
                    # match must not abort the whole request.
                    if match and match.group(1) == "2":
                        # NOTE(review): this and the default below both store "",
                        # so "boolean" is always "". The intended literals look
                        # lost; confirm the expected values.
                        result["boolean"] = ""
                        break
        if result["page"] == "":
            result["page"] = "0"
        if result["boolean"] == "":
            result["boolean"] = ""
        return JSONResponse(content={"text": result, "pdf_url": file_location, "class": "合同抽取"})
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
def pdf_2_images(pdf_document, output_folder, page_num):
    """Render one PDF page to a PNG file and return the file's path.

    Args:
        pdf_document: An open PyMuPDF (fitz) document; its ``name`` is the PDF path.
        output_folder: Directory prefix. It is concatenated as-is, so it must end
            with a path separator (as images_dir does).
        page_num: 0-based page index.

    Returns:
        ``"<output_folder><pdf stem>page_<page_num + 1>.png"``.
    """
    page = pdf_document.load_page(page_num)
    pix = page.get_pixmap()
    img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    # Take the file name after the last "\" or "/". A backslash-only pattern
    # fails (AttributeError on None) for POSIX-style paths, such as those sent
    # to /findAndExtract/, and for ".PDF" extensions.
    base_name = re.split(r"[\\/]", pdf_document.name)[-1]
    if base_name.lower().endswith(".pdf"):
        contract_name = base_name[:-len(".pdf")]
    else:
        contract_name = base_name
    image_filename = "{}{}page_{}.png".format(output_folder, contract_name, page_num + 1)
    img.save(image_filename, "PNG")
    print("合同分页保存位置:", image_filename)
    return image_filename
def pagehome_llm(pdf_path):
    """Find the contract's first page and OCR it together with the next two pages.

    Pages are rendered to images_dir and OCR'd in order. The first page whose
    OCR text contains all of "甲方", "乙方", "合同编号" and "签订地点" is the
    start page. It is returned together with up to two following pages (fewer
    near the end of the document).

    Args:
        pdf_path: Path of the PDF to scan.

    Returns:
        ``(result, ocr_text)``:
        - ``result`` lists the image paths of the start page and its followers,
          or is empty if no page matched.
        - ``ocr_text`` holds the OCR output of every rendered page in page order.
          That includes the pages before the start page, or all pages when none
          matched.
    """
    keywords = ("甲方", "乙方", "合同编号", "签订地点")
    result = []
    ocr_text = []
    with fitz.open(pdf_path) as pdf_document:
        page_count = len(pdf_document)
        for page_num in range(page_count):
            image_filename = pdf_2_images(pdf_document, images_dir, page_num)
            page_ocr = get_ocr(image_filename)
            ocr_text.append(page_ocr)
            # `"a" and "b" in s` only tests the last keyword, because `and`
            # returns its last truthy operand. So test each keyword explicitly.
            page_str = str(page_ocr)
            if all(keyword in page_str for keyword in keywords):
                print("合同识别起始页:", page_num)
                result.append(image_filename)
                # Up to two following pages, without rendering past the last page.
                for next_page in range(page_num + 1, min(page_num + 3, page_count)):
                    image_filename = pdf_2_images(pdf_document, images_dir, next_page)
                    ocr_text.append(get_ocr(image_filename))
                    result.append(image_filename)
                return result, ocr_text
    return result, ocr_text
def find_payment_page(pdf_path, output_folder):
    """Return the 1-based number of the page stating the provisional contract price.

    Pages are rendered and OCR'd in order until one contains
    "合同价款暂定为人民币含税价小写".

    Args:
        pdf_path: Path of the PDF to scan.
        output_folder: Directory prefix for the rendered page images (must end
            with a path separator).

    Returns:
        The page number as a string (e.g. "3"), or None when no page matches.
    """
    with fitz.open(pdf_path) as pdf_document:
        for page_num in range(len(pdf_document)):
            image_filename = pdf_2_images(pdf_document, output_folder, page_num)
            ocr_text = get_ocr(image_filename)
            if "合同价款暂定为人民币含税价小写" in str(ocr_text):
                # pdf_2_images names the image "...page_<page_num + 1>.png".
                result = str(page_num + 1)
                print("位置为第" + result + "")
                return result
    return None
def extract_images(image_index, pdf_path, output_folder):
    """OCR the page at 1-based ``image_index`` and the page after it, then ask the LLM to extract payment data.

    Args:
        image_index: 1-based page number (int or numeric string), e.g. the value
            returned by find_payment_page.
        pdf_path: Path of the PDF.
        output_folder: Directory prefix for the rendered page images (must end
            with a path separator).

    Returns:
        JSONResponse whose content is the ``"result"`` field of the LLM's JSON reply.

    Raises:
        ValueError: If the LLM reply is not valid JSON (json.JSONDecodeError) or
            has no ``"result"`` key.
    """
    index = int(image_index)
    text_prompt = config.extract
    ocr_texts = []
    with fitz.open(pdf_path) as pdf_document:
        page_count = len(pdf_document)
        if page_count >= index:
            # 0-based pages index-1 and index. The second one does not exist when
            # image_index is the last page, so stop at the end of the document.
            for page_num in range(index - 1, min(index + 1, page_count)):
                image_filename = pdf_2_images(pdf_document, output_folder, page_num)
                ocr_texts.append(get_ocr(image_filename))
        else:
            print("超出PDF文件长度")
    output = run_conv(text_prompt, ocr_texts)
    # Strip ASCII/C1 control characters. json.loads (strict mode) rejects raw
    # control characters inside strings.
    output = re.sub(r'[\x00-\x1F\x7F-\x9F]', '', output)
    json_data = json.loads(output)
    if not isinstance(json_data, dict) or "result" not in json_data:
        raise ValueError("LLM response has no 'result' field")
    return JSONResponse(content=json_data["result"])
@app.post("/findAndExtract/")
async def findAndExtract(request: Request):
    """Find the contract-price page of a server-side PDF and extract payment data from it.

    Expects a JSON body ``{"file_path": "<path on this server>"}``.

    Returns:
        The extraction result produced by extract_images. Otherwise:
        - 400 JSON ``{"error": ...}`` for a malformed body or a missing path;
        - 404 JSON ``{"error": ...}`` when no page contains the contract-price clause.

    NOTE(review): the server opens whatever path the client sends. Restrict this
    to the upload directory if the endpoint is reachable by untrusted clients.
    """
    try:
        body = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
        return JSONResponse(content={"error": "Request body must be JSON"}, status_code=400)
    file_path = body.get("file_path") if isinstance(body, dict) else None
    # os.path.exists raises TypeError for None, so validate the type first.
    if not isinstance(file_path, str) or not os.path.exists(file_path):
        return JSONResponse(content={"error": "File path does not exist"}, status_code=400)
    index = find_payment_page(file_path, images_dir)  # 1-based page number, or None
    if index is None:
        return JSONResponse(content={"error": "Payment page not found"}, status_code=404)
    result = extract_images(index, file_path, images_dir)
    # extract_images already returns a JSONResponse. Passing that object through
    # jsonable_encoder would encode the Response's own attributes (status_code,
    # body bytes, raw_headers) instead of its content, so return it directly.
    if isinstance(result, JSONResponse):
        return result
    return JSONResponse(content=jsonable_encoder(result))
@app.post('/classify')
async def get_images(request: Request):
    """Classify every image in a server-side folder.

    Expects a JSON body ``{"file_path": "<folder on this server>"}``.

    Returns:
        Whatever predict_folder produces, passed through jsonable_encoder.
        Returns 400 JSON ``{"error": ...}`` for a malformed body or a missing path.
    """
    try:
        body = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
        return JSONResponse(content={"error": "Request body must be JSON"}, status_code=400)
    folder_path = body.get("file_path") if isinstance(body, dict) else None
    # os.path.exists raises TypeError for None, so validate the type first.
    if not isinstance(folder_path, str) or not os.path.exists(folder_path):
        return JSONResponse(content={"error": "File path does not exist"}, status_code=400)
    class_names = ['发票', '确认表', '申请表', '验收证书', '其他']
    # Trained weights. NOTE(review): the path is relative to the working
    # directory, not to this file.
    weights_path = "../Models/checkpoints/model.pth"
    # Deterministic inference preprocessing. The random flip, rotation and
    # colour jitter are training-time augmentations; applied here they could
    # classify the same image differently on each request.
    # (Presumably predict_folder applies `transform` to each image; confirm.)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    result = predict_folder(folder_path, weights_path, transform, class_names)
    return JSONResponse(content=jsonable_encoder(result))
if __name__ == "__main__":
    # Script entry point: serve the app over plain HTTP on every interface, port 8000.
    # For HTTPS, uvicorn.run accepts ssl_certfile= / ssl_keyfile=; confirm
    # against uvicorn's settings documentation before enabling.
    uvicorn.run(app, port=8000, host="0.0.0.0")