nwjh/LLMServe/classify.py

235 lines
8.5 KiB
Python
Raw Permalink Normal View History

2025-03-24 09:27:03 +08:00
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
from transformers import BertTokenizer, BertModel
import os
import cv2
import numpy as np
from paddleocr import PaddleOCR
class MedicalImageClassifier(nn.Module):
    """Classifier that fuses ResNet-18 image features with BERT text features.

    Image branch: a ResNet-18 loaded from a local weights file, with its final
    ``fc`` layer replaced by an identity so it returns features instead of
    ImageNet logits.
    Text branch: a BERT model loaded from ``models/bert``. Its sequence output
    goes through multi-head self-attention and is then averaged over dim 1.
    The two feature vectors are concatenated and passed to a small MLP head.

    NOTE: the submodule attribute names (``resnet``, ``bert``, ``attention``,
    ``classifier``) and the layer order inside ``classifier`` define the
    ``state_dict`` keys. Renaming or reordering them breaks loading of saved
    checkpoints.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Lightweight ResNet backbone; weights are read from a local file.
        backbone = models.resnet18(pretrained=False)
        backbone_state = torch.load('models/resnet/resnet18-5c106cde.pth', weights_only=False)
        backbone.load_state_dict(backbone_state)
        # Drop the classification layer so the backbone emits feature vectors.
        backbone.fc = nn.Identity()
        self.resnet = backbone

        # Smaller BERT model loaded from a local directory.
        self.bert = BertModel.from_pretrained('models/bert')

        # Self-attention over the BERT sequence output.
        # NOTE(review): built without batch_first=True. Per the PyTorch docs,
        # dim 0 is therefore treated as the sequence axis. If the BERT output
        # is (batch, seq, hidden), as assumed (TODO confirm), attention mixes
        # across the batch rather than across tokens. Left unchanged because
        # the saved checkpoint was trained with this configuration.
        self.attention = nn.MultiheadAttention(embed_dim=768, num_heads=8)

        # Input width = image feature width + text feature width
        # (the two vectors concatenated in forward()).
        fused_dim = 512 + 768
        head_layers = [
            nn.Linear(fused_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, image, encoded_text):
        """Return classifier logits for images and their tokenized text.

        Args:
            image: image tensor accepted by the ResNet backbone.
            encoded_text: mapping of keyword arguments for the BERT model,
                e.g. ``input_ids`` / ``attention_mask`` from the tokenizer.

        Returns:
            Output of the MLP head (presumably one row of logits per sample).
        """
        visual = self.resnet(image)
        # Element 0 of the BERT output is the per-token sequence output.
        token_states = self.bert(**encoded_text)[0]
        attended = self.attention(token_states, token_states, token_states)[0]
        # Average over dim 1 (the token axis if input is (batch, seq, hidden)).
        pooled = attended.mean(dim=1)
        fused = torch.cat([visual, pooled], dim=1)
        return self.classifier(fused)
def _collect_ocr_text_until_keyword(ocr_result, keywords):
    """Concatenate PaddleOCR-recognised text, stopping once a keyword appears.

    Each recognised item adds "<text> ". An item without a recognition result,
    or an empty page, adds "无识别文本" instead. Extraction stops as soon as
    the accumulated text contains any of ``keywords``. A missing or empty OCR
    result yields "无识别文本".
    """
    text = ""
    for page in ocr_result or []:
        if not page:
            text += "无识别文本"
            continue
        for item in page:
            if item[1]:
                line_text, _ = item[1]  # (text, confidence); confidence is ignored
                text += line_text + " "
            else:
                text += "无识别文本"
            if any(keyword in text for keyword in keywords):
                print(f"找到关键字停止提取。找到的关键字1'{text}'")
                break
        if any(keyword in text for keyword in keywords):
            print(f"找到关键字停止提取。找到的关键字2'{text}'")
            break
    return text or "无识别文本"


def predict_single2(image_path, weights_path, transform, class_names):
    """Predict the class of a single image using trained weights.

    Text is extracted from the image with PaddleOCR, tokenized with the local
    BERT tokenizer, and fed to MedicalImageClassifier together with the
    transformed image.

    Args:
        image_path: path of the image to classify.
        weights_path: path of the trained model state_dict.
        transform: transform applied to the PIL image before the model.
        class_names: list of class names. The predicted index is looked up in
            it, and the names double as OCR stop keywords.

    Returns:
        The predicted class name (an element of ``class_names``).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = MedicalImageClassifier(num_classes=len(class_names))
    # Load weights into the bare model *before* any DataParallel wrapping.
    # The wrapper prefixes every state_dict key with "module.", so loading a
    # plain checkpoint (as predict_folder does) into a wrapped model fails.
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
    if torch.cuda.device_count() > 1:
        print(f"Let's use {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)  # run the model in parallel across GPUs
    model = model.to(device)
    model.eval()

    # Close the underlying file once the RGB copy has been made.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    # Extract text with PaddleOCR. The extracted text is now actually used;
    # previously it was overwritten by str() of the raw OCR result structures.
    img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
    ocrreader = PaddleOCR(use_angle_cls=True, use_gpu=True, lang='ch')
    text = _collect_ocr_text_until_keyword(ocrreader.ocr(img_cv), class_names)

    tokenizer = BertTokenizer.from_pretrained('models/bert')
    encoded_text = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    image_tensor = transform(image).unsqueeze(0).to(device)  # add batch dimension
    encoded_text = encoded_text.to(device)

    with torch.no_grad():
        outputs = model(image_tensor, encoded_text)
    _, predicted = torch.max(outputs, 1)
    return class_names[predicted.item()]
def _collect_ocr_text(ocr_result):
    """Join all text recognised by PaddleOCR into a single string.

    Each recognised item adds "<text> ". An item without a recognition result,
    or an empty page, adds "无识别文本" instead. A missing or empty OCR result
    yields "".
    """
    parts = []
    for page in ocr_result or []:
        if not page:
            parts.append("无识别文本")
            continue
        for item in page:
            if item[1]:
                line_text, _ = item[1]  # (text, confidence); confidence is ignored
                parts.append(line_text + " ")
            else:
                parts.append("无识别文本")
    return "".join(parts)


def predict_folder(images_dir, weights_path, transform, class_names, result_path='result/predict_results.csv'):
    """Predict the class of every .png/.jpg/.jpeg image in a folder.

    Args:
        images_dir: folder containing the images to classify (not recursive).
        weights_path: path of the trained model state_dict.
        transform: transform applied to each PIL image before the model.
        class_names: list of class names; the predicted index is looked up in it.
        result_path: currently unused; kept for interface compatibility.

    Returns:
        A list of ``{'图像地址': <image path>, 'class': <predicted class>}``
        dicts, one per image, in sorted file-name order.

    Raises:
        ValueError: if ``images_dir`` is not an existing directory.
    """
    # Validate the input first so a bad path fails before the expensive
    # model and OCR initialisation.
    if not os.path.isdir(images_dir):
        raise ValueError(f"提供的路径不是一个有效的目录:{images_dir}")

    # Device placement is loop-invariant: do it once.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MedicalImageClassifier(num_classes=len(class_names))
    model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
    model = model.to(device)
    model.eval()

    ocrreader = PaddleOCR(use_angle_cls=True, use_gpu=True, lang='ch', det_model_dir='../Models/ch_PP-OCRv4_det_infer',rec_model_dir='../Models/ch_PP-OCRv4_rec_infer', cls_model_dir='../Models/ch_ppocr_mobile_v2.0_cls_infer')
    tokenizer = BertTokenizer.from_pretrained('models/bert')
    max_length = 512

    images_info = []
    for img_filename in sorted(os.listdir(images_dir)):
        img_path = os.path.join(images_dir, img_filename)
        # Only regular files with an image extension are classified.
        if not (os.path.isfile(img_path) and img_path.lower().endswith(('.png', '.jpg', '.jpeg'))):
            continue

        with Image.open(img_path) as img:
            image = img.convert('RGB')
        img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)

        # OCR text is built fresh for each image. It used to accumulate across
        # images, so with the 50-char cut below every later image was
        # classified using the first image's text.
        text = _collect_ocr_text(ocrreader.ocr(img_cv))
        # Keep the first 50 characters (presumably matching how the model was
        # trained; TODO confirm).
        text = text[:50]

        encoded_text = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        image_tensor = transform(image).unsqueeze(0).to(device)  # add batch dimension
        encoded_text = {key: val.to(device) for key, val in encoded_text.items()}

        with torch.no_grad():
            outputs = model(image_tensor, encoded_text)
        _, predicted = torch.max(outputs, 1)
        predicted_class = class_names[predicted.item()]

        images_info.append({
            '图像地址': img_path,
            'class': predicted_class,
        })
        print(f"图片:{img_filename} 预测类别:{predicted_class}")
    return images_info
if __name__ == '__main__':
    weights_path = "../Models/checkpoints/model_epoch_8.pth"
    # Inference-time preprocessing: deterministic resize + normalisation only.
    # The random training augmentations (RandomHorizontalFlip, RandomRotation,
    # ColorJitter) were removed; they made predictions differ from run to run.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    class_names = ['发票', '其他']
    print(predict_folder("images/1", weights_path, transform, class_names))