"""FastAPI service that OCRs uploaded documents and extracts fields via an LLM."""

import json
import os
import re
import ssl
import uuid
from typing import List, Union

import fitz
import requests
import uvicorn
from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from openai import OpenAI
from PIL import Image
from pydantic import BaseModel
from torchvision import models, transforms
from uvicorn import Config, Server

import config
from classify import predict_folder
from ocr import get_ocr, get_ocr_image_list, get_ocr_list

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; tighten allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Make sure the upload directories exist.
# NOTE(review): the backslash separators are Windows-specific. They are kept
# as-is because other code in this file parses paths built from these values.
current_directory = os.getcwd()
upload_dir = os.path.join(current_directory, "data\\")
images_dir = upload_dir + "img\\"
pdfs_dir = upload_dir + "content\\"
os.makedirs(images_dir, exist_ok=True)
os.makedirs(pdfs_dir, exist_ok=True)

# Mount only after the directories were created: StaticFiles verifies that the
# directory exists when it is constructed, so mounting first fails on a fresh
# checkout.
app.mount("/data", StaticFiles(directory="data"), name="data")

# Captures the payload of the first ```json ... ``` fence in an LLM reply.
_JSON_FENCE_RE = re.compile(r"```json(.*?)```", re.DOTALL)


def run_conv(text_prompt, ocr_text):
    """Send OCR text plus an extraction prompt to the LLM and return its JSON.

    The endpoint, API key and model default to the local Ollama setup and can
    be overridden with the LLM_BASE_URL, LLM_API_KEY and LLM_MODEL environment
    variables, so no credentials need to live in the source.

    Args:
        text_prompt: Extraction instructions for the model.
        ocr_text: OCR output; it is converted with ``str()`` before sending.

    Returns:
        The raw text found inside the first ```json fence of the reply.

    Raises:
        ValueError: If the reply contains no ```json fenced block.
    """
    print("begin deepseek ask")
    # NOTE(review): the OCR text is sent as the user turn and the prompt as an
    # assistant turn; this ordering is preserved from the original code.
    messages = [
        {'role': 'user', 'content': str(ocr_text)},
        {'role': 'assistant', 'content': text_prompt},
    ]
    client = OpenAI(
        api_key=os.getenv("LLM_API_KEY", "ollama"),
        base_url=os.getenv("LLM_BASE_URL", "http://localhost:11434/v1"),
    )
    response = client.chat.completions.create(
        model=os.getenv("LLM_MODEL", "deepseek-r1"),
        messages=messages,
    )
    final_response = response.choices[0].message.content or ""
    match = _JSON_FENCE_RE.search(final_response)
    if match is None:
        # Without this check the failure surfaced as an opaque AttributeError.
        raise ValueError("LLM response did not contain a ```json fenced block")
    res = match.group(1)
    print("llm result: ", res)
    return res


async def _save_uploaded_images(files):
    """Validate and store uploaded images in ``images_dir``; return their paths.

    Raises:
        HTTPException: 400 if an upload is not an image or has no usable name.
            Files stored before the offending one are left on disk.
    """
    saved_paths = []
    for file in files:
        # content_type may be missing, so guard before calling startswith.
        if not (file.content_type or "").startswith('image/'):
            raise HTTPException(status_code=400, detail="Unsupported file type")
        # The file name is client-controlled: keep only its last path
        # component so it cannot escape the upload directory.
        safe_name = re.split(r"[\\/]", file.filename or "")[-1]
        if safe_name in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="Invalid file name")
        file_location = images_dir + safe_name
        with open(file_location, "wb+") as file_object:
            file_object.write(await file.read())
        saved_paths.append(file_location)
    return saved_paths


# Invoice
@app.post("/invoice/")
async def invoice(files: List[UploadFile] = File(...)):
    """OCR the uploaded invoice images and extract fields with the LLM.

    Returns the LLM output, the image list produced by ``get_ocr_image_list``
    and the document class; any processing error becomes a 400 response.
    """
    image_url = await _save_uploaded_images(files)
    print("image_path: ", image_url)
    try:
        text_prompt = config.invoice
        ocr_text = get_ocr_list(image_url)  # run OCR
        output = run_conv(text_prompt, ocr_text)  # query the LLM
        new_url = get_ocr_image_list(image_url, output)
        return JSONResponse(content={"text": output, "image_url": new_url, "class": "发票"})
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})


# Application form
@app.post("/application/")
async def application(files: List[UploadFile] = File(...)):
    """OCR the uploaded application-form images and extract fields with the LLM.

    Any processing error becomes a 400 response carrying the error text.
    """
    image_url = await _save_uploaded_images(files)
    try:
        text_prompt = config.application
        ocr_text = get_ocr_list(image_url)  # run OCR
        output = run_conv(text_prompt, ocr_text)  # query the LLM
        return JSONResponse(content={"text": output, "image_url": image_url, "class": "申请单"})
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})


# Confirmation form
@app.post("/confirmation/")
async def confirmation(files: List[UploadFile] = File(...)):
    """OCR the uploaded confirmation-form images and extract fields with the LLM.

    Any processing error becomes a 400 response carrying the error text.
    """
    image_url = await _save_uploaded_images(files)
    try:
        text_prompt = config.confirmation
        ocr_text = get_ocr_list(image_url)
        output = run_conv(text_prompt, ocr_text)
        return JSONResponse(content={"text": output, "image_url": image_url, "class": "确认表"})
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})


# Contract
@app.post("/contract/")
async def contract(files: UploadFile = File(...)):
    """Extract contract fields from an uploaded PDF.

    The PDF is stored under a random name, the contract start pages are
    located and OCR'd by ``pagehome_llm``, and the OCR text is sent to the
    LLM. Any processing error becomes a 400 response with the error text.
    """
    if files.content_type != 'application/pdf':
        raise HTTPException(status_code=400, detail="文件类型不符")
    file_location = await _save_uploaded_pdf(files)
    try:
        img_url, ocr_text = pagehome_llm(file_location)
        text_prompt = config.contract
        print("contract img_url: ", img_url)
        output = run_conv(text_prompt, ocr_text)
        new_url = get_ocr_image_list(img_url, output)
        return JSONResponse(content={"text": output, "image_url": new_url, "class": "合同表"})
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})


@app.post("/contractPayment/")
async def contract_amount(files: UploadFile = File(...)):
    """Locate the price clause page and the settlement method in a contract PDF.

    Returns ``{"page": <1-based page as string, "0" if not found>,
    "boolean": "是" if settlement method 2 is selected else "否"}`` together
    with the stored PDF path. Processing errors become a 400 response.
    """
    if files.content_type != 'application/pdf':
        raise HTTPException(status_code=400, detail="Unsupported file type")
    file_location = await _save_uploaded_pdf(files)
    try:
        result = {"page": "", "boolean": ""}
        with fitz.open(file_location) as pdf_document:
            for page_num in range(len(pdf_document)):
                image_filename = pdf_2_images(pdf_document, images_dir, page_num)
                page_text = str(get_ocr(image_filename))
                if "合同价款暂定为人民币含税价小写" in page_text:
                    # Same value the page image file name carries.
                    result["page"] = str(page_num + 1)
                if "合同价款结算按第" in page_text:
                    match = re.search(r"合同价款结算按第(\d+)_种方式", page_text)
                    # OCR noise can break the pattern; treat that as "not method 2"
                    # instead of failing the whole request.
                    if match is not None and match.group(1) == "2":
                        result["boolean"] = "是"
                        break
        if result["page"] == "":
            result["page"] = "0"
        if result["boolean"] == "":
            result["boolean"] = "否"
        return JSONResponse(content={"text": result, "pdf_url": file_location, "class": "合同抽取"})
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})


async def _save_uploaded_pdf(upload):
    """Store an uploaded PDF in ``pdfs_dir`` under a random name; return its path."""
    file_location = pdfs_dir + f"{uuid.uuid4()}.pdf"
    with open(file_location, "wb+") as file_object:
        file_object.write(await upload.read())
    print("合同上传位置:", file_location)
    return file_location


def pdf_2_images(pdf_document, output_folder, page_num):
    """Render one PDF page to a PNG file and return the PNG path.

    Args:
        pdf_document: An open PyMuPDF document.
        output_folder: Path prefix for the image; it is concatenated as-is, so
            it must already end with a path separator.
        page_num: Zero-based page index.

    Returns:
        ``<output_folder><pdf stem>page_<page_num + 1>.png``.
    """
    page = pdf_document.load_page(page_num)
    pix = page.get_pixmap()
    img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
    # Accept both separators: paths may come from clients, not only from the
    # backslash-built upload directories.
    base_name = re.split(r"[\\/]", pdf_document.name)[-1]
    contract_name = os.path.splitext(base_name)[0]
    image_filename = "{}{}page_{}.png".format(output_folder, contract_name, page_num + 1)
    img.save(image_filename, "PNG")
    print("合同分页保存位置:", image_filename)
    return image_filename


# Keywords that must all be present to recognise the first contract page.
_CONTRACT_START_KEYWORDS = ("甲方", "乙方", "合同编号", "签订地点")


def pagehome_llm(pdf_path):
    """Find the contract start page and OCR it plus up to two following pages.

    Pages are rendered and OCR'd in order. The start page is the first one at
    which the OCR text accumulated so far contains every keyword in
    ``_CONTRACT_START_KEYWORDS``.

    Returns:
        ``(image_paths, ocr_text)``: the images of the start page and its
        followers (empty if no start page was found) and the OCR results of
        every page processed.
    """
    result = []
    ocr_text = []
    with fitz.open(pdf_path) as pdf_document:
        page_count = len(pdf_document)
        for page_num in range(page_count):
            image_filename = pdf_2_images(pdf_document, images_dir, page_num)
            ocr_text.append(get_ocr(image_filename))
            seen_text = str(ocr_text)
            # The former `"a" and "b" and ... in text` only tested the last keyword.
            if not all(keyword in seen_text for keyword in _CONTRACT_START_KEYWORDS):
                continue
            print("合同识别起始页:", page_num)
            result.append(image_filename)
            # Take the next two pages, but never read past the end of the PDF.
            for next_page in range(page_num + 1, min(page_num + 3, page_count)):
                image_filename = pdf_2_images(pdf_document, images_dir, next_page)
                ocr_text.append(get_ocr(image_filename))
                result.append(image_filename)
            return result, ocr_text
    return result, ocr_text


def find_payment_page(pdf_path, output_folder):
    """Return the 1-based page number (as a string) of the contract price clause.

    ``output_folder`` is accepted for interface compatibility but is not used;
    page images are written to ``images_dir``.

    Returns:
        The page number as a string, or ``None`` when no page contains the clause.
    """
    with fitz.open(pdf_path) as pdf_document:
        for page_num in range(len(pdf_document)):
            image_filename = pdf_2_images(pdf_document, images_dir, page_num)
            ocr_text = get_ocr(image_filename)
            if "合同价款暂定为人民币含税价小写" in str(ocr_text):
                result = re.search(r'(\d+)\.png', image_filename).group(1)
                print("位置为第" + result + "页")
                return result
    return None


def extract_images(image_index, pdf_path, output_folder):
    """OCR the given page and the one after it, then ask the LLM to extract data.

    Args:
        image_index: 1-based page number (int or numeric string).
        pdf_path: Path of the PDF to read.
        output_folder: Accepted for interface compatibility but not used; page
            images are written to ``images_dir``.

    Returns:
        A ``JSONResponse`` whose content is the ``"result"`` entry of the JSON
        produced by the LLM.
    """
    index = int(image_index)
    ocr_texts = []
    text_prompt = config.extract
    with fitz.open(pdf_path) as pdf_document:
        page_count = len(pdf_document)
        if 1 <= index <= page_count:
            # Two pages starting at `index`, clipped to the document length.
            for page_num in range(index - 1, min(index + 1, page_count)):
                image_filename = pdf_2_images(pdf_document, images_dir, page_num)
                ocr_texts.append(get_ocr(image_filename))
        else:
            print("超出PDF文件长度")
    output = run_conv(text_prompt, ocr_texts)
    # Strip control characters that would make json.loads reject the text.
    output = re.sub(r'[\x00-\x1F\x7F-\x9F]', '', output)
    json_data = json.loads(output)
    # Return the entry even when it is empty; previously an empty value left
    # the local unbound and raised UnboundLocalError.
    result = json_data["result"]
    return JSONResponse(content=result)


@app.post("/findAndExtract/")
async def findAndExtract(request: Request):
    """Find the price clause page of a server-side PDF and extract its data.

    Expects a JSON body with ``file_path``. Responds with 400 when the path is
    missing, does not exist, or no page contains the price clause.
    """
    body = await request.json()
    # NOTE(review): file_path is client-supplied and used as a server path;
    # consider restricting it to the upload directory.
    file_path = body.get("file_path")
    if not file_path or not os.path.exists(file_path):
        return JSONResponse(content={"error": "File path does not exist"}, status_code=400)
    index = find_payment_page(file_path, images_dir)
    if index is None:
        return JSONResponse(content={"error": "Payment page not found"}, status_code=400)
    result = extract_images(index, file_path, images_dir)
    # NOTE(review): extract_images returns a JSONResponse, so this encodes the
    # response object itself rather than its payload - presumably unintended.
    # Kept unchanged so existing clients see the same shape; confirm and fix.
    json_result = jsonable_encoder(result)
    return JSONResponse(content=json_result)


@app.post('/classify')
async def get_images(request: Request):
    """Classify the images in a server-side folder.

    Expects a JSON body with ``file_path`` pointing at the folder. Responds
    with 400 when the path is missing or does not exist.
    """
    body = await request.json()
    folder_path = body.get("file_path")
    if not folder_path or not os.path.exists(folder_path):
        return JSONResponse(content={"error": "File path does not exist"}, status_code=400)
    class_names = ['发票', '确认表', '申请表', '验收证书', '其他']
    # Path of the trained weights.
    weights_path = "../Models/checkpoints/model.pth"
    # Deterministic preprocessing only: the random flip / rotation / colour
    # jitter used for training would make predictions vary between calls.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    result = predict_folder(folder_path, weights_path, transform, class_names)
    json_result = jsonable_encoder(result)
    return JSONResponse(content=json_result)


if __name__ == "__main__":
    # Served over plain HTTP; uvicorn.run accepts ssl_certfile / ssl_keyfile
    # if TLS is required.
    uvicorn.run(app, host="0.0.0.0", port=8000)