# nwjh/Classify/main.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from transformers import BertTokenizer, BertModel
import os
import csv
import logging
import cv2
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import requests
from io import BytesIO
import time
import torch.multiprocessing as mp
import hashlib
import json
from paddleocr import PaddleOCR
# Logging configuration: write INFO+ records to a per-run log file.
# basicConfig raises FileNotFoundError if the target directory is missing,
# so create it first.
os.makedirs('logs', exist_ok=True)
logging.basicConfig(
    level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s',
    filename='logs/medical_image_classifier.log',  # log file destination
    filemode='w'  # open in write mode (overwrite any previous log)
)
# Shared PaddleOCR reader, using locally stored detection / recognition /
# angle-classification models.
reader = PaddleOCR(use_angle_cls=True, use_gpu=True, lang='ch', det_model_dir='../Models/ch_PP-OCRv4_det_infer',
                   rec_model_dir='../Models/ch_PP-OCRv4_rec_infer',
                   cls_model_dir='../Models/ch_ppocr_mobile_v2.0_cls_infer')
def get_cache_path(img_url, cache_dir):
    """Build the cache-file path for an image URL.

    The URL's MD5 hex digest is used as a unique, filesystem-safe file name.
    """
    digest = hashlib.md5(img_url.encode()).hexdigest()
    return os.path.join(cache_dir, digest + ".json")
def save_cache(cache_path, ocr_result):
    """Persist an OCR result to a cache file as UTF-8 JSON."""
    payload = json.dumps(ocr_result, ensure_ascii=False)
    with open(cache_path, 'w', encoding='utf-8') as fh:
        fh.write(payload)
def load_cache(cache_path):
    """Return the cached OCR result, or None when no cache file exists."""
    if not os.path.exists(cache_path):
        return None
    with open(cache_path, 'r', encoding='utf-8') as fh:
        return json.load(fh)
def load_excel_dataset(file_path, max_samples_per_class=500, test_size=0.2, random_state=42):
    """Load an image-URL dataset from an Excel workbook.

    Each worksheet is one class: the sheet name is the label and the first
    column holds the image URLs.

    Args:
        file_path: path to the Excel workbook.
        max_samples_per_class: cap on URLs kept per class.
        test_size: fraction of samples held out for the test split.
        random_state: seed for the stratified train/test split.

    Returns:
        (train_data, test_data, class_names) where the data lists contain
        (image_url, label) tuples.

    Raises:
        ValueError: if every worksheet yields no usable rows.
    """
    xls = pd.ExcelFile(file_path)
    all_data = []
    class_names = []
    # Each worksheet is one class.
    for sheet_name in xls.sheet_names:
        df = pd.read_excel(xls, sheet_name, header=None)
        if df.empty:
            logging.info(f"工作表 '{sheet_name}' 为空,已跳过")
            continue
        # First column holds the image URLs. Blank cells are read as NaN
        # floats by pandas and would crash the downloader later, so drop
        # them and normalize to str.
        image_urls = df.iloc[:, 0].dropna().astype(str).tolist()
        if not image_urls:
            logging.info(f"工作表 '{sheet_name}' 为空,已跳过")
            continue
        # Cap the number of samples per class.
        image_urls = image_urls[:max_samples_per_class]
        all_data.extend((url, sheet_name) for url in image_urls)
        class_names.append(sheet_name)
    if not all_data:
        raise ValueError("所有工作表都为空,没有可用的数据")
    # Stratified split keeps the class proportions equal in both splits.
    train_data, test_data = train_test_split(
        all_data, test_size=test_size, random_state=random_state,
        stratify=[label for _, label in all_data])
    print("图片数据集加载完成")
    print(f"有效类别数量: {len(class_names)}")
    return train_data, test_data, class_names
def load_folder_dataset(folder_path, max_samples_per_class=500, test_size=0.2, random_state=42):
    """Load an image dataset from a folder-per-class directory tree.

    Each immediate subdirectory of *folder_path* is one class: the directory
    name is the label and the files inside are the samples.

    Args:
        folder_path: root directory containing one subdirectory per class.
        max_samples_per_class: cap on files kept per class.
        test_size: fraction of samples held out for the test split.
        random_state: seed for the stratified train/test split.

    Returns:
        (train_data, test_data, class_names) where the data lists contain
        (image_path, label) tuples.

    Raises:
        ValueError: if no subdirectory contains any files.
    """
    all_data = []
    class_names = []
    # sorted() makes class ordering and the per-class sample cap
    # deterministic; os.listdir order is otherwise filesystem-dependent.
    for class_name in sorted(os.listdir(folder_path)):
        class_path = os.path.join(folder_path, class_name)
        # Only directories are classes; skip stray files.
        if not os.path.isdir(class_path):
            logging.info(f"'{class_name}' 不是文件夹,已跳过")
            continue
        # Collect every regular file in this class directory.
        image_files = [os.path.join(class_path, f) for f in sorted(os.listdir(class_path))
                       if os.path.isfile(os.path.join(class_path, f))]
        if not image_files:
            logging.info(f"文件夹 '{class_name}' 中没有图片文件,已跳过")
            continue
        # Cap the number of samples per class.
        image_files = image_files[:max_samples_per_class]
        all_data.extend((path, class_name) for path in image_files)
        class_names.append(class_name)
    if not all_data:
        raise ValueError("文件夹中没有有效的图片数据")
    # Stratified split keeps the class proportions equal in both splits.
    train_data, test_data = train_test_split(all_data, test_size=test_size, random_state=random_state,
                                             stratify=[label for _, label in all_data])
    print("图片数据集加载完成")
    print(f"有效类别数量: {len(class_names)}")
    return train_data, test_data, class_names
# dataset
class MedicalImageDataset(Dataset):
    """Dataset of (image, OCR-text, label) triples loaded from image URLs.

    Each sample downloads its image, runs (cached) OCR on it, and tokenizes
    the recognized text with BERT. Any per-sample failure yields a zero
    placeholder image and empty text so the DataLoader keeps running.
    """

    def __init__(self, data, transform=None, max_length=512, cache_dir="train_ocr_cache"):
        self.data = data  # [(image_url, label), ...]
        self.transform = transform
        self.max_length = max_length
        # sorted() keeps the label->index mapping deterministic, so
        # independently constructed dataset instances (train vs. test)
        # agree on class indices; a raw set() has no stable order.
        self.classes = sorted(set(label for _, label in data))
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.tokenizer = BertTokenizer.from_pretrained('../Models/bert')
        self.reader = reader  # module-level shared PaddleOCR instance
        self.cache_dir = cache_dir
        # OCR bookkeeping: samples processed / failed and total wall time.
        self.ocr_stats = {
            'processed': 0,
            'failed': 0,
            'total_time': 0
        }
        # Make sure the OCR cache directory exists.
        os.makedirs(self.cache_dir, exist_ok=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_url, label = self.data[idx]
        label_idx = self.class_to_idx[label]
        self.ocr_stats['processed'] += 1
        start_time = time.time()
        try:
            # Download the image.
            print("开始下载第" + str(idx) + "张图片")
            response = requests.get(img_url, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
            # Reuse a cached OCR result when one exists for this URL.
            cache_path = get_cache_path(img_url, self.cache_dir)
            cached_ocr = load_cache(cache_path)
            if cached_ocr:
                text = cached_ocr
                logging.info(f"从缓存加载OCR结果: {cache_path}")
            else:
                # Run OCR on the grayscale image. PaddleOCR's ocr() returns
                # a list with one entry per page; each line of a page is
                # (box, (text, score)). Joining the raw result directly
                # would raise TypeError on the nested structure.
                img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
                ocr_result = self.reader.ocr(img_cv)
                page = ocr_result[0] if ocr_result else None
                text = ' '.join(line[1][0] for line in page) if page else "无识别文本"
                if text.strip():  # only cache non-empty text
                    save_cache(cache_path, text)
                    logging.info(f"OCR完成并保存结果到缓存: {cache_path}")
                else:
                    logging.warning(f"OCR未识别到文本: {img_url}")
            encoded_text = self.tokenizer.encode_plus(
                text,
                add_special_tokens=True,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                return_tensors='pt'
            )
            if self.transform:
                image = self.transform(image)
            return image, encoded_text, label_idx
        except Exception as e:
            self.ocr_stats['failed'] += 1
            logging.error(f"Error processing image {img_url}: {str(e)}")
            # Fall back to a placeholder image and empty text so one bad
            # sample does not abort the epoch.
            placeholder_image = torch.zeros((3, 224, 224))
            placeholder_text = self.tokenizer.encode_plus(
                "",
                add_special_tokens=True,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                return_tensors='pt'
            )
            return placeholder_image, placeholder_text, label_idx
        finally:
            self.ocr_stats['total_time'] += time.time() - start_time

    def get_ocr_stats(self):
        """Return the running OCR statistics dict."""
        return self.ocr_stats
class MedicalImageDataset_v2(torch.utils.data.Dataset):
    """Dataset of (image, OCR-text, label) triples loaded from local files.

    Like MedicalImageDataset but reads images from disk paths instead of
    downloading them, and owns its own PaddleOCR reader.
    """

    def __init__(self, data, transform=None, max_length=512, cache_dir="train_ocr_cache"):
        self.data = data  # [(image_path, label), ...]
        self.transform = transform
        self.max_length = max_length
        # sorted() keeps the label->index mapping deterministic, so
        # independently constructed dataset instances (train vs. test)
        # agree on class indices; a raw set() has no stable order.
        self.classes = sorted(set(label for _, label in data))
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.tokenizer = BertTokenizer.from_pretrained('../Models/bert')
        self.reader = PaddleOCR(use_angle_cls=True, use_gpu=True, lang='ch')  # per-dataset OCR engine
        self.cache_dir = cache_dir
        # OCR bookkeeping: samples processed / failed and total wall time.
        self.ocr_stats = {
            'processed': 0,
            'failed': 0,
            'total_time': 0
        }
        # Make sure the OCR cache directory exists.
        os.makedirs(self.cache_dir, exist_ok=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        label_idx = self.class_to_idx[label]
        self.ocr_stats['processed'] += 1
        start_time = time.time()
        try:
            # Load the local image.
            print(f"开始加载第{idx}张图片: {img_path}")
            image = Image.open(img_path).convert('RGB')
            # Reuse a cached OCR result when one exists for this path.
            cache_path = get_cache_path(img_path, self.cache_dir)
            cached_ocr = load_cache(cache_path)
            if cached_ocr:
                text = cached_ocr
                logging.info(f"从缓存加载OCR结果: {cache_path}")
            else:
                # Run OCR on the grayscale image. ocr() returns a list with
                # one entry per page; each line is (box, (text, score)).
                # The page entry is None when nothing was detected, so it
                # must be guarded before iterating.
                img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
                ocr_result = self.reader.ocr(img_cv, cls=True)
                page = ocr_result[0] if ocr_result else None
                text = ' '.join(line[1][0] for line in page) if page else "无识别文本"
                if text.strip():  # only cache non-empty text
                    save_cache(cache_path, text)
                    logging.info(f"OCR完成并保存结果到缓存: {cache_path}")
                else:
                    logging.warning(f"OCR未识别到文本: {img_path}")
            print(f"OCR结果: {text}")
            encoded_text = self.tokenizer.encode_plus(
                text,
                add_special_tokens=True,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                return_tensors='pt'
            )
            if self.transform:
                image = self.transform(image)
            return image, encoded_text, label_idx
        except Exception as e:
            self.ocr_stats['failed'] += 1
            logging.error(f"Error processing image {img_path}: {str(e)}")
            # Fall back to a placeholder image and empty text so one bad
            # sample does not abort the epoch.
            placeholder_image = torch.zeros((3, 224, 224))
            placeholder_text = self.tokenizer.encode_plus(
                "",
                add_special_tokens=True,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                return_tensors='pt'
            )
            return placeholder_image, placeholder_text, label_idx
        finally:
            self.ocr_stats['total_time'] += time.time() - start_time

    def get_ocr_stats(self):
        """Return the running OCR statistics dict."""
        return self.ocr_stats
# 辅助函数:生成缓存路径
# def get_cache_path(image_path, cache_dir):
# image_name = os.path.basename(image_path)
# cache_filename = f"{os.path.splitext(image_name)[0]}.txt"
# return os.path.join(cache_dir, cache_filename)
#
#
# # 辅助函数:加载缓存
# def load_cache(cache_path):
# if os.path.exists(cache_path):
# with open(cache_path, 'r', encoding='utf-8') as f:
# return f.read().strip()
# return None
#
#
# # 辅助函数:保存缓存
# def save_cache(cache_path, text):
# with open(cache_path, 'w', encoding='utf-8') as f:
# f.write(text.strip())
# module
class MedicalImageClassifier(nn.Module):
    """Multimodal classifier: ResNet-18 image features + BERT text features.

    Image and text embeddings are concatenated and fed to a small MLP head.
    """

    def __init__(self, num_classes):
        super(MedicalImageClassifier, self).__init__()
        # Lightweight image backbone, weights loaded from a local checkpoint.
        # weights_only=True is the safe choice for a plain state dict: it
        # avoids arbitrary-code execution via pickle.
        self.resnet = models.resnet18(pretrained=False)
        self.resnet.load_state_dict(torch.load('../Models/resnet/resnet18-5c106cde.pth', weights_only=True))
        self.resnet.fc = nn.Identity()  # expose the 512-d pooled features
        # Local BERT text encoder.
        self.bert = BertModel.from_pretrained('../Models/bert')
        # Self-attention over the BERT token sequence before pooling.
        # batch_first=True matches BERT's (batch, seq, 768) output; without
        # it MultiheadAttention treats dim 0 as the sequence axis and mixes
        # information across batch samples.
        self.attention = nn.MultiheadAttention(embed_dim=768, num_heads=8, batch_first=True)
        self.classifier = nn.Sequential(
            nn.Linear(512 + 768, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, image, encoded_text):
        """Classify a batch.

        Args:
            image: image batch tensor for the ResNet backbone.
            encoded_text: dict of BERT inputs (input_ids, attention_mask, ...).

        Returns:
            Logits of shape (batch, num_classes).
        """
        image_features = self.resnet(image)
        text_features = self.bert(**encoded_text)[0]  # sequence output (batch, seq, 768)
        # Attend over tokens, then mean-pool to one vector per sample.
        text_features, _ = self.attention(text_features, text_features, text_features)
        text_features = torch.mean(text_features, dim=1)
        combined_features = torch.cat((image_features, text_features), dim=1)
        return self.classifier(combined_features)
def train(model, train_loader, criterion, optimizer, device):
    """Run one training epoch and return the mean batch loss."""
    print("开始训练")
    model.train()
    total_loss = 0.0
    for images, encoded_texts, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        # Tokenized tensors arrive as (batch, 1, seq); drop the extra dim.
        encoded_texts = {key: tensor.squeeze(1).to(device)
                         for key, tensor in encoded_texts.items()}
        optimizer.zero_grad()
        loss = criterion(model(images, encoded_texts), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print("训练结束")
    return total_loss / len(train_loader)
def test(model, test_loader, criterion, device):
    """Evaluate the model; return (mean batch loss, accuracy)."""
    print("开始测试")
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    # No gradients needed during evaluation.
    with torch.no_grad():
        for images, encoded_texts, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            # Tokenized tensors arrive as (batch, 1, seq); drop the extra dim.
            encoded_texts = {key: tensor.squeeze(1).to(device)
                             for key, tensor in encoded_texts.items()}
            outputs = model(images, encoded_texts)
            loss_sum += criterion(outputs, labels).item()
            predictions = outputs.data.argmax(dim=1)
            total += labels.size(0)
            correct += (predictions == labels).sum().item()
    print("测试结束")
    return loss_sum / len(test_loader), correct / total
def save_results(epoch, train_loss, test_loss, accuracy, file_path='result/results.csv'):
    """Append one epoch's metrics to a CSV file, writing a header first.

    Args:
        epoch: 1-based epoch number.
        train_loss: mean training loss for the epoch.
        test_loss: mean test loss for the epoch.
        accuracy: test accuracy in [0, 1].
        file_path: CSV destination; parent directories are created on demand.
    """
    parent = os.path.dirname(file_path)
    # Guard the bare-filename case: os.path.dirname returns '' and
    # os.makedirs('') raises FileNotFoundError.
    if parent:
        os.makedirs(parent, exist_ok=True)
    file_exists = os.path.isfile(file_path)
    with open(file_path, mode='a', newline='') as file:
        writer = csv.writer(file)
        if not file_exists:
            writer.writerow(['Epoch', 'Train Loss', 'Test Loss', 'Accuracy'])
        writer.writerow([epoch, train_loss, test_loss, accuracy])
def main():
    """Train the multimodal classifier end to end on the folder dataset."""
    torch.backends.cudnn.enabled = False
    # DataLoader workers with CUDA/OCR state require the 'spawn' start method.
    mp.set_start_method('spawn', force=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # train_data, test_data, class_names = load_excel_dataset('dataset_url.xlsx')
    train_data, test_data, class_names = load_folder_dataset('images')
    # Preprocessing/augmentation applied before the ResNet backbone.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    train_dataset = MedicalImageDataset_v2(train_data, transform=transform, cache_dir="train_ocr_cache")
    test_dataset = MedicalImageDataset_v2(test_data, transform=transform, cache_dir="test_ocr_cache")
    train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=24, shuffle=False)
    model = MedicalImageClassifier(num_classes=len(class_names)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # Ensure the checkpoint directory exists before the first torch.save.
    os.makedirs('../Models/checkpoints', exist_ok=True)
    num_epochs = 10
    for epoch in range(num_epochs):
        print("开始第" + str(epoch) + "轮训练")
        try:
            train_loss = train(model, train_loader, criterion, optimizer, device)
            test_loss, accuracy = test(model, test_loader, criterion, device)
            logging.info(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, Accuracy: {accuracy * 100:.2f}%")
            # Record metrics for this epoch.
            save_results(epoch + 1, train_loss, test_loss, accuracy)
            # Checkpoint the model after each epoch.
            torch.save(model.state_dict(), f'../Models/checkpoints/model_epoch_{epoch + 1}.pth')
        except Exception as e:
            # Log and continue so one failed epoch doesn't kill the run.
            logging.error(f"Error during training in epoch {epoch + 1}: {str(e)}")
            continue
        finally:
            print()
            print("" + str(epoch) + "轮训练完成")


if __name__ == '__main__':
    main()