235 lines
8.5 KiB
Python
235 lines
8.5 KiB
Python
|
|
import torch
|
|||
|
|
import torch.nn as nn
|
|||
|
|
from torchvision import transforms, models
|
|||
|
|
from PIL import Image
|
|||
|
|
from transformers import BertTokenizer, BertModel
|
|||
|
|
import os
|
|||
|
|
import cv2
|
|||
|
|
import numpy as np
|
|||
|
|
from paddleocr import PaddleOCR
|
|||
|
|
|
|||
|
|
class MedicalImageClassifier(nn.Module):
    """Classifier that fuses ResNet-18 image features with BERT text features.

    The image branch is a ResNet-18 whose final fully connected layer is
    replaced by ``nn.Identity``. The text branch runs BERT, applies
    multi-head attention (embed_dim=768, 8 heads) to its sequence output and
    averages over dimension 1. Both feature vectors are concatenated
    (512 + 768 values) and fed to a two-layer classification head.

    Args:
        num_classes: Number of output logits produced by the head.
        resnet_weights_path: Path to the ResNet-18 state dict loaded into the
            image branch. Defaults to the path originally hard-coded here.
        bert_path: Directory (or identifier) passed to
            ``BertModel.from_pretrained``. Defaults to the original value.
    """

    def __init__(self, num_classes,
                 resnet_weights_path='models/resnet/resnet18-5c106cde.pth',
                 bert_path='models/bert'):
        super().__init__()

        # Lightweight ResNet backbone. `weights=None` builds the bare
        # architecture (it replaces the deprecated `pretrained=False`).
        self.resnet = models.resnet18(weights=None)
        # NOTE(security): weights_only=False allows arbitrary unpickling, so
        # only load checkpoint files from a trusted source.
        resnet_state = torch.load(
            resnet_weights_path,
            map_location='cpu',
            weights_only=False,
        )
        self.resnet.load_state_dict(resnet_state)
        # Drop the ImageNet classification layer; keep the pooled features.
        self.resnet.fc = nn.Identity()

        # Text encoder loaded from a local BERT directory.
        self.bert = BertModel.from_pretrained(bert_path)

        # Attention layer applied on top of the BERT sequence output.
        self.attention = nn.MultiheadAttention(embed_dim=768, num_heads=8)

        # Classification head over the concatenated image + text features.
        self.classifier = nn.Sequential(
            nn.Linear(512 + 768, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, image, encoded_text):
        """Return class logits for a batch of images and their encoded text.

        Args:
            image: Image tensor passed directly to the ResNet backbone.
            encoded_text: Mapping of tokenizer outputs; it is unpacked as
                keyword arguments into the BERT model.

        Returns:
            The output of the classification head (one row per sample).
        """
        image_features = self.resnet(image)
        # Element 0 of the BERT output is the per-token sequence output.
        text_features = self.bert(**encoded_text)[0]

        # NOTE(review): the attention layer is created without
        # batch_first=True, so it treats dimension 0 as the sequence axis.
        # If the BERT output is (batch, seq, hidden), attention is computed
        # across the batch rather than across tokens. Left unchanged because
        # existing checkpoints were presumably trained with this layout;
        # confirm before changing.
        text_features, _ = self.attention(text_features, text_features, text_features)
        text_features = torch.mean(text_features, dim=1)

        combined_features = torch.cat((image_features, text_features), dim=1)
        return self.classifier(combined_features)
|
|||
|
|
|
|||
|
|
def predict_single2(image_path, weights_path, transform, class_names):
    """Predict the class of a single image using trained model weights.

    The image is run through PaddleOCR, the recognised text is tokenized
    with BERT's tokenizer, and image + text are classified by
    ``MedicalImageClassifier``.

    Args:
        image_path: Path of the image to classify.
        weights_path: Path of the trained model state dict.
        transform: Image preprocessing transform applied before the model.
        class_names: List of class names; also used as the OCR keywords that
            stop text extraction early.

    Returns:
        The entry of ``class_names`` with the highest model output.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the weights into the bare model first so the result does not
    # depend on how many GPUs this host has; map to CPU so a checkpoint saved
    # on a GPU also loads on a CPU-only machine.
    model = MedicalImageClassifier(num_classes=len(class_names))
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(_strip_data_parallel_prefix(state_dict))

    if torch.cuda.device_count() > 1:
        print(f"Let's use {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)  # run the model on multiple GPUs

    model = model.to(device)
    model.eval()

    # Load the image; the context manager closes the underlying file.
    with Image.open(image_path) as opened_image:
        image = opened_image.convert('RGB')

    # Extract text with PaddleOCR from a grayscale copy of the image.
    img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
    ocrreader = PaddleOCR(use_angle_cls=True, use_gpu=True, lang='ch')
    ocr_result = ocrreader.ocr(img_cv)

    # Use the recognised text itself. Previously this value was overwritten
    # with str() of the raw OCR structure (boxes and confidences included),
    # which discarded the extraction below.
    # NOTE(review): predict_folder truncates the text to 50 characters and
    # has no keyword early-stop; confirm which variant matches training.
    text = _ocr_text_until_keyword(ocr_result, class_names)

    max_length = 512
    tokenizer = BertTokenizer.from_pretrained('models/bert')
    encoded_text = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    image = transform(image).unsqueeze(0)  # add the batch dimension

    # Move the inputs to the same device as the model.
    image = image.to(device)
    encoded_text = encoded_text.to(device)

    with torch.no_grad():
        outputs = model(image, encoded_text)

    # Index of the largest output per row selects the predicted class.
    _, predicted = torch.max(outputs, 1)
    return class_names[predicted.item()]


def _strip_data_parallel_prefix(state_dict):
    """Return ``state_dict`` with any leading ``module.`` removed from keys."""
    prefix = "module."
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _ocr_text_until_keyword(ocr_result, keywords, placeholder="无识别文本"):
    """Join recognised OCR text, stopping once any keyword has appeared."""
    text = ""
    for result in ocr_result or []:
        if not result:
            text += placeholder
            continue
        for item in result:
            if item[1]:
                # item[1] is a (text, confidence) pair; keep only the text.
                text_temp, _ = item[1]
                text += text_temp + " "
            else:
                text += placeholder
            if any(keyword in text for keyword in keywords):
                print(f"找到关键字,停止提取。找到的关键字1:'{text}'")
                break
        if any(keyword in text for keyword in keywords):
            print(f"找到关键字,停止提取。找到的关键字2:'{text}'")
            break
    return text or placeholder
|
|||
|
|
|
|||
|
|
|
|||
|
|
def predict_folder(images_dir, weights_path, transform, class_names, result_path='result/predict_results.csv'):
    """Predict the class of every image in a folder using trained weights.

    Files ending in ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) are
    processed; everything else in the folder is skipped.

    Args:
        images_dir: Folder containing the images to classify.
        weights_path: Path of the trained model state dict.
        transform: Image preprocessing transform applied before the model.
        class_names: List of class names.
        result_path: Currently unused by this function; no file is written.

    Returns:
        A list with one dict per image, holding the image path under
        ``'图像地址'`` and the predicted class name under ``'class'``.

    Raises:
        ValueError: If ``images_dir`` is not an existing directory.
    """
    # Validate the folder first so a bad path fails before any model loads.
    if not os.path.isdir(images_dir):
        raise ValueError(f"提供的路径不是一个有效的目录:{images_dir}")

    # Load the trained weights and move the model to its device once.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MedicalImageClassifier(num_classes=len(class_names))
    model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
    model = model.to(device)
    model.eval()

    # Initialise OCR and the tokenizer once for the whole folder.
    ocrreader = PaddleOCR(use_angle_cls=True, use_gpu=True, lang='ch', det_model_dir='../Models/ch_PP-OCRv4_det_infer',rec_model_dir='../Models/ch_PP-OCRv4_rec_infer', cls_model_dir='../Models/ch_ppocr_mobile_v2.0_cls_infer')
    tokenizer = BertTokenizer.from_pretrained('models/bert')

    no_text = "无识别文本"
    max_length = 512
    images_info = []

    for img_filename in os.listdir(images_dir):
        img_path = os.path.join(images_dir, img_filename)

        # Skip anything that is not an image file.
        if not (os.path.isfile(img_path) and img_path.lower().endswith(('.png', '.jpg', '.jpeg'))):
            continue

        with Image.open(img_path) as opened_image:
            image = opened_image.convert('RGB')
        img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)

        # Collect OCR text for THIS image only. The text buffer used to live
        # outside the loop, so after the 50-character cut below every later
        # image inherited the first image's text.
        ocr_result = ocrreader.ocr(img_cv)
        pieces = []
        for result in ocr_result or []:
            if not result:
                pieces.append(no_text)
                continue
            for item in result:
                if item[1]:
                    # item[1] is a (text, confidence) pair; keep the text.
                    text_temp, _ = item[1]
                    pieces.append(text_temp + " ")
                else:
                    pieces.append(no_text)

        # Keep the first 50 characters; fall back to the placeholder when
        # OCR produced nothing at all.
        text = "".join(pieces)[:50] or no_text

        encoded_text = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        image = transform(image).unsqueeze(0)  # add the batch dimension

        # Move image and text tensors to the model's device.
        image = image.to(device)
        encoded_text = {key: val.to(device) for key, val in encoded_text.items()}

        with torch.no_grad():
            outputs = model(image, encoded_text)

        # Index of the largest output per row selects the predicted class.
        _, predicted = torch.max(outputs, 1)
        predicted_class = class_names[predicted.item()]

        # Record and report the prediction.
        images_info.append({
            '图像地址': img_path,
            'class': predicted_class,
        })
        print(f"图片:{img_filename} 预测类别:{predicted_class}")

    return images_info
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
    weights_path = "../Models/checkpoints/model_epoch_8.pth"

    # Deterministic inference preprocessing. Random flips, rotations and
    # colour jitter are training-time augmentations; applying them here made
    # the prediction for the same image vary from run to run.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    class_names = ['发票', '其他']
    print(predict_folder("images/1", weights_path, transform, class_names))
|
|||
|
|
|
|||
|
|
|
|||
|
|
|