# Medical image classifier training script (~495 lines, 18 KiB, Python).
import csv
import hashlib
import json
import logging
import os
import time
from io import BytesIO

import cv2
import numpy as np
import pandas as pd
import requests
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from paddleocr import PaddleOCR
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from transformers import BertTokenizer, BertModel
# Logging configuration.
# The log directory must exist before basicConfig opens the file in 'w'
# mode; otherwise logging setup raises FileNotFoundError on a fresh checkout.
os.makedirs('logs', exist_ok=True)
logging.basicConfig(
    level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s',
    filename='logs/medical_image_classifier.log',  # log file location
    filemode='w'  # open in write mode (overwrites any existing log)
)

# Module-level PaddleOCR reader shared by the dataset classes; uses locally
# stored PP-OCRv4 detection/recognition models and the mobile cls model.
reader = PaddleOCR(use_angle_cls=True, use_gpu=True, lang='ch', det_model_dir='../Models/ch_PP-OCRv4_det_infer',
                   rec_model_dir='../Models/ch_PP-OCRv4_rec_infer',
                   cls_model_dir='../Models/ch_ppocr_mobile_v2.0_cls_infer')
def get_cache_path(img_url, cache_dir):
    """Build the cache-file path for an image URL.

    The MD5 digest of the URL serves as a unique, filesystem-safe name.
    """
    digest = hashlib.md5(img_url.encode()).hexdigest()
    return os.path.join(cache_dir, digest + ".json")
def save_cache(cache_path, ocr_result):
    """Persist an OCR result to its cache file as UTF-8 JSON."""
    with open(cache_path, 'w', encoding='utf-8') as fh:
        json.dump(ocr_result, fh, ensure_ascii=False)
def load_cache(cache_path):
    """Return the cached OCR result, or None when no cache file exists."""
    if not os.path.exists(cache_path):
        return None
    with open(cache_path, 'r', encoding='utf-8') as fh:
        return json.load(fh)
def load_excel_dataset(file_path, max_samples_per_class=500, test_size=0.2, random_state=42):
    """Load (image_url, label) pairs from an Excel workbook.

    Each worksheet is one class: the sheet name is the label and the first
    column holds image URLs. Returns (train_data, test_data, class_names).

    Raises:
        ValueError: if every worksheet is empty.
    """
    all_data = []
    class_names = []

    # pd.ExcelFile supports the context-manager protocol; closing it releases
    # the underlying file handle (the original leaked it).
    with pd.ExcelFile(file_path) as xls:
        for sheet_name in xls.sheet_names:
            df = pd.read_excel(xls, sheet_name, header=None)
            if df.empty:
                logging.info(f"工作表 '{sheet_name}' 为空,已跳过")
                continue

            # First column holds the image URLs. Blank cells come back from
            # pandas as NaN; drop them so they don't become bogus "URLs".
            urls = df.iloc[:, 0].dropna()
            image_urls = [str(u) for u in urls.tolist() if str(u).strip()]
            if not image_urls:
                logging.info(f"工作表 '{sheet_name}' 为空,已跳过")
                continue

            # Cap the number of samples per class.
            image_urls = image_urls[:max_samples_per_class]

            all_data.extend((url, sheet_name) for url in image_urls)
            class_names.append(sheet_name)

    if not all_data:
        raise ValueError("所有工作表都为空,没有可用的数据")

    # Stratified split keeps class proportions in train and test subsets.
    train_data, test_data = train_test_split(
        all_data, test_size=test_size, random_state=random_state,
        stratify=[label for _, label in all_data])

    print("图片数据集加载完成")
    print(f"有效类别数量: {len(class_names)}")
    return train_data, test_data, class_names
def load_folder_dataset(folder_path, max_samples_per_class=500, test_size=0.2, random_state=42):
    """Load (image_path, label) pairs from a class-per-subfolder directory.

    Each immediate subfolder of folder_path is one class; the subfolder name
    is the label. Listings are sorted so class discovery and per-class
    truncation are deterministic across runs and filesystems.
    Returns (train_data, test_data, class_names).

    Raises:
        ValueError: if no subfolder yields any files.
    """
    all_data = []
    class_names = []

    # sorted(): os.listdir order is filesystem-dependent, which previously
    # made the max_samples_per_class truncation below nondeterministic.
    for class_name in sorted(os.listdir(folder_path)):
        class_path = os.path.join(folder_path, class_name)

        # Only class subdirectories are considered.
        if not os.path.isdir(class_path):
            logging.info(f"'{class_name}' 不是文件夹,已跳过")
            continue

        # Collect every regular file in the class folder, in sorted order.
        image_files = sorted(
            os.path.join(class_path, f) for f in os.listdir(class_path)
            if os.path.isfile(os.path.join(class_path, f)))

        if not image_files:
            logging.info(f"文件夹 '{class_name}' 中没有图片文件,已跳过")
            continue

        # Cap the number of samples per class.
        image_files = image_files[:max_samples_per_class]

        # Build (path, label) entries.
        all_data.extend((path, class_name) for path in image_files)
        class_names.append(class_name)

    if not all_data:
        raise ValueError("文件夹中没有有效的图片数据")

    # Stratified split keeps class proportions in train and test subsets.
    train_data, test_data = train_test_split(
        all_data, test_size=test_size, random_state=random_state,
        stratify=[label for _, label in all_data])

    print("图片数据集加载完成")
    print(f"有效类别数量: {len(class_names)}")
    return train_data, test_data, class_names
|
||
# dataset
class MedicalImageDataset(Dataset):
    """Dataset of (image, OCR text, label) triples fetched from image URLs.

    Each item downloads the image, runs OCR (with a JSON file cache keyed on
    the URL's MD5 hash), tokenizes the recognized text with BERT, and returns
    (image_tensor, encoded_text, label_idx). Failures yield a zero image and
    an empty-text encoding so one bad sample cannot abort training.
    """

    def __init__(self, data, transform=None, max_length=512, cache_dir="train_ocr_cache"):
        self.data = data  # list of (image_url, label)
        self.transform = transform
        self.max_length = max_length
        # sorted(): set iteration order is nondeterministic across processes,
        # so separately built train/test datasets would otherwise disagree on
        # the label -> index mapping.
        self.classes = sorted(set(label for _, label in data))
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.tokenizer = BertTokenizer.from_pretrained('../Models/bert')
        self.reader = reader  # shared module-level PaddleOCR instance
        self.cache_dir = cache_dir

        # OCR bookkeeping: item count, failure count, cumulative seconds.
        self.ocr_stats = {
            'processed': 0,
            'failed': 0,
            'total_time': 0
        }

        # Ensure the OCR cache directory exists.
        os.makedirs(self.cache_dir, exist_ok=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_url, label = self.data[idx]
        label_idx = self.class_to_idx[label]

        self.ocr_stats['processed'] += 1
        start_time = time.time()

        try:
            # Download the image.
            print("开始下载第"+str(idx)+"张图片")
            response = requests.get(img_url, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')

            # Cache lookup keyed on the image URL.
            cache_path = get_cache_path(img_url, self.cache_dir)
            cached_ocr = load_cache(cache_path)
            if cached_ocr:
                text = cached_ocr
                logging.info(f"从缓存加载OCR结果: {cache_path}")
            else:
                # Run OCR on a grayscale copy of the image.
                img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
                ocr_result = self.reader.ocr(img_cv)

                # PaddleOCR.ocr returns [page], where page is a list of
                # [box, (text, confidence)] entries (or empty/None when
                # nothing is recognized). The old ' '.join(ocr_result)
                # tried to join that nested structure, raised TypeError and
                # turned every sample into the placeholder below.
                page = ocr_result[0] if ocr_result else None
                text = ' '.join(line[1][0] for line in page) if page else "无识别文本"

                if text.strip():  # only cache non-empty results
                    save_cache(cache_path, text)
                    logging.info(f"OCR完成并保存结果到缓存: {cache_path}")
                else:
                    logging.warning(f"OCR未识别到文本: {img_url}")

            encoded_text = self.tokenizer.encode_plus(
                text,
                add_special_tokens=True,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                return_tensors='pt'
            )

            if self.transform:
                image = self.transform(image)

            return image, encoded_text, label_idx

        except Exception as e:
            self.ocr_stats['failed'] += 1
            logging.error(f"Error processing image {img_url}: {str(e)}")
            # Placeholder image/text so the DataLoader keeps running.
            placeholder_image = torch.zeros((3, 224, 224))
            placeholder_text = self.tokenizer.encode_plus(
                "",
                add_special_tokens=True,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                return_tensors='pt'
            )
            return placeholder_image, placeholder_text, label_idx
        finally:
            self.ocr_stats['total_time'] += time.time() - start_time

    def get_ocr_stats(self):
        """Return the mutable OCR statistics dict (processed/failed/total_time)."""
        return self.ocr_stats
|
||
class MedicalImageDataset_v2(torch.utils.data.Dataset):
    """Dataset of (image, OCR text, label) triples read from local paths.

    Like MedicalImageDataset but items are local files instead of URLs, and
    the OCR reader is a private PaddleOCR instance. Items are
    (image_tensor, encoded_text, label_idx); failures yield a zero image and
    an empty-text encoding so one bad sample cannot abort training.
    """

    def __init__(self, data, transform=None, max_length=512, cache_dir="train_ocr_cache"):
        self.data = data  # list of (image_path, label)
        self.transform = transform
        self.max_length = max_length
        # sorted(): set iteration order is nondeterministic across processes,
        # so separately built train/test datasets would otherwise disagree on
        # the label -> index mapping.
        self.classes = sorted(set(label for _, label in data))
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.tokenizer = BertTokenizer.from_pretrained('../Models/bert')
        self.reader = PaddleOCR(use_angle_cls=True, use_gpu=True, lang='ch')  # per-dataset OCR engine
        self.cache_dir = cache_dir

        # OCR bookkeeping: item count, failure count, cumulative seconds.
        self.ocr_stats = {
            'processed': 0,
            'failed': 0,
            'total_time': 0
        }

        # Ensure the OCR cache directory exists.
        os.makedirs(self.cache_dir, exist_ok=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        label_idx = self.class_to_idx[label]

        self.ocr_stats['processed'] += 1
        start_time = time.time()

        try:
            # Load the local image.
            print(f"开始加载第{idx}张图片: {img_path}")
            image = Image.open(img_path).convert('RGB')

            # Cache lookup keyed on the image path.
            cache_path = get_cache_path(img_path, self.cache_dir)
            cached_ocr = load_cache(cache_path)
            if cached_ocr:
                text = cached_ocr
                logging.info(f"从缓存加载OCR结果: {cache_path}")
            else:
                # Run OCR on a grayscale copy of the image.
                img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
                ocr_result = self.reader.ocr(img_cv, cls=True)

                # PaddleOCR.ocr returns [page]; page is a list of
                # [box, (text, confidence)] entries, or empty/None when
                # nothing is recognized. (Removed debug prints that indexed
                # line[1] on the page list and raised IndexError on short
                # results.)
                page = ocr_result[0] if ocr_result else None
                text = ' '.join(line[1][0] for line in page) if page else "无识别文本"

                if text.strip():  # only cache non-empty results
                    save_cache(cache_path, text)
                    logging.info(f"OCR完成并保存结果到缓存: {cache_path}")
                else:
                    logging.warning(f"OCR未识别到文本: {img_path}")

            encoded_text = self.tokenizer.encode_plus(
                text,
                add_special_tokens=True,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                return_tensors='pt'
            )

            if self.transform:
                image = self.transform(image)

            return image, encoded_text, label_idx

        except Exception as e:
            # Re-enabled (was commented out): without this handler any bad
            # image or OCR error aborts the whole training epoch, which is
            # inconsistent with MedicalImageDataset above.
            self.ocr_stats['failed'] += 1
            logging.error(f"Error processing image {img_path}: {str(e)}")
            placeholder_image = torch.zeros((3, 224, 224))
            placeholder_text = self.tokenizer.encode_plus(
                "",
                add_special_tokens=True,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                return_tensors='pt'
            )
            return placeholder_image, placeholder_text, label_idx
        finally:
            self.ocr_stats['total_time'] += time.time() - start_time

    def get_ocr_stats(self):
        """Return the mutable OCR statistics dict (processed/failed/total_time)."""
        return self.ocr_stats
|
||
# 辅助函数:生成缓存路径
|
||
# def get_cache_path(image_path, cache_dir):
|
||
# image_name = os.path.basename(image_path)
|
||
# cache_filename = f"{os.path.splitext(image_name)[0]}.txt"
|
||
# return os.path.join(cache_dir, cache_filename)
|
||
#
|
||
#
|
||
# # 辅助函数:加载缓存
|
||
# def load_cache(cache_path):
|
||
# if os.path.exists(cache_path):
|
||
# with open(cache_path, 'r', encoding='utf-8') as f:
|
||
# return f.read().strip()
|
||
# return None
|
||
#
|
||
#
|
||
# # 辅助函数:保存缓存
|
||
# def save_cache(cache_path, text):
|
||
# with open(cache_path, 'w', encoding='utf-8') as f:
|
||
# f.write(text.strip())
|
||
|
||
# module
class MedicalImageClassifier(nn.Module):
    """Multimodal classifier: ResNet-18 image features + BERT text features.

    The image branch outputs a 512-d vector (final fc replaced by Identity);
    the text branch runs BERT, self-attention over the sequence output, then
    mean-pools to a 768-d vector. Both are concatenated and classified by a
    small MLP head.
    """

    def __init__(self, num_classes):
        super(MedicalImageClassifier, self).__init__()
        # Lightweight ResNet-18 backbone with locally stored weights.
        self.resnet = models.resnet18(pretrained=False)
        # weights_only=True: the checkpoint is a plain state dict; this avoids
        # unpickling arbitrary Python objects from the file.
        self.resnet.load_state_dict(torch.load('../Models/resnet/resnet18-5c106cde.pth', weights_only=True))
        self.resnet.fc = nn.Identity()  # expose 512-d features instead of logits

        # Local BERT model for the OCR-text branch.
        self.bert = BertModel.from_pretrained('../Models/bert')

        # Self-attention over the BERT sequence output.
        # NOTE(review): MultiheadAttention defaults to batch_first=False
        # (seq, batch, embed), but BERT emits (batch, seq, embed) — as written
        # the attention mixes across the batch dimension. Consider
        # batch_first=True; kept as-is to preserve trained-model behavior.
        self.attention = nn.MultiheadAttention(embed_dim=768, num_heads=8)

        self.classifier = nn.Sequential(
            nn.Linear(512 + 768, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, image, encoded_text):
        """Return class logits for a batch of images and tokenized texts."""
        image_features = self.resnet(image)
        text_features = self.bert(**encoded_text)[0]  # sequence output

        # Self-attention, then mean-pool over dim 1.
        text_features, _ = self.attention(text_features, text_features, text_features)
        text_features = torch.mean(text_features, dim=1)

        combined_features = torch.cat((image_features, text_features), dim=1)
        return self.classifier(combined_features)
def train(model, train_loader, criterion, optimizer, device):
    """Run one training epoch and return the mean per-batch loss."""
    print("开始训练")
    model.train()
    total_loss = 0.0
    for images, encoded_texts, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        # Tokenizer tensors arrive as (batch, 1, seq); drop the middle dim.
        encoded_texts = {key: tensor.squeeze(1).to(device)
                         for key, tensor in encoded_texts.items()}

        optimizer.zero_grad()
        loss = criterion(model(images, encoded_texts), labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print("训练结束")
    return total_loss / len(train_loader)
def test(model, test_loader, criterion, device):
|
||
print("开始测试")
|
||
model.eval()
|
||
correct = 0
|
||
total = 0
|
||
running_loss = 0.0
|
||
with torch.no_grad():
|
||
for images, encoded_texts, labels in test_loader:
|
||
images, labels = images.to(device), labels.to(device)
|
||
encoded_texts = {k: v.squeeze(1).to(device) for k, v in encoded_texts.items()}
|
||
|
||
outputs = model(images, encoded_texts)
|
||
loss = criterion(outputs, labels)
|
||
running_loss += loss.item()
|
||
|
||
_, predicted = torch.max(outputs.data, 1)
|
||
total += labels.size(0)
|
||
correct += (predicted == labels).sum().item()
|
||
print("测试结束")
|
||
return running_loss / len(test_loader), correct / total
|
||
|
||
def save_results(epoch, train_loss, test_loss, accuracy, file_path='result/results.csv'):
    """Append one epoch's metrics to a CSV file, writing a header on first use.

    Creates the parent directory only when the path actually has one: the old
    unconditional os.makedirs(os.path.dirname(file_path)) raised
    FileNotFoundError for a bare filename, because dirname is ''.
    """
    parent = os.path.dirname(file_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    file_exists = os.path.isfile(file_path)
    # newline='' stops the csv module from emitting blank rows on Windows.
    with open(file_path, mode='a', newline='') as file:
        writer = csv.writer(file)
        if not file_exists:
            writer.writerow(['Epoch', 'Train Loss', 'Test Loss', 'Accuracy'])
        writer.writerow([epoch, train_loss, test_loss, accuracy])
def main():
    """End-to-end training entry point.

    Loads the folder dataset, builds the OCR+BERT datasets and loaders,
    then trains and evaluates for a fixed number of epochs, logging metrics
    and checkpointing the model after each epoch.
    """
    # cuDNN disabled: avoids cuDNN-related runtime failures at some speed cost.
    torch.backends.cudnn.enabled = False

    # 'spawn' so CUDA and DataLoader worker processes coexist safely.
    mp.set_start_method('spawn', force=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # train_data, test_data, class_names = load_excel_dataset('dataset_url.xlsx')
    train_data, test_data, class_names = load_folder_dataset('images')

    # Preprocessing/augmentation applied before images enter ResNet.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_dataset = MedicalImageDataset_v2(train_data, transform=transform, cache_dir="train_ocr_cache")
    test_dataset = MedicalImageDataset_v2(test_data, transform=transform, cache_dir="test_ocr_cache")

    train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=24, shuffle=False)

    model = MedicalImageClassifier(num_classes=len(class_names)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Ensure the checkpoint directory exists before the first torch.save,
    # which would otherwise fail with FileNotFoundError.
    os.makedirs('../Models/checkpoints', exist_ok=True)

    num_epochs = 10
    for epoch in range(num_epochs):
        print("开始第"+str(epoch)+"轮训练")
        try:
            train_loss = train(model, train_loader, criterion, optimizer, device)
            test_loss, accuracy = test(model, test_loader, criterion, device)
            logging.info(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, Accuracy: {accuracy * 100:.2f}%")

            # Persist metrics and a checkpoint for this epoch.
            save_results(epoch + 1, train_loss, test_loss, accuracy)
            torch.save(model.state_dict(), f'../Models/checkpoints/model_epoch_{epoch + 1}.pth')
        finally:
            print()
            print("第"+str(epoch)+"轮训练完成")


if __name__ == '__main__':
    main()