仅记录自身学习过程,不适合作为教程浏览
数据搜集
在 Kaggle（https://www.kaggle.com/datasets）等网站上下载现成的数据集，或者自行编写爬虫抓取。
数据清洗
下载下来的数据往往存在大量重复或不符合要求的低质/主题无关图像,需要进行一定处理才可用于训练。由于此次目的仅在于跑通整个训练流程,因此直接使用整理好的数据集,仅进行查重。
由 DeepSeek 编写的图片查重窗体程序（PyQt5 + imagededup）
实际并不好用：图片一多就特别卡。本来想完全依靠 DeepSeek 生成，但最后还是得自己动手修 bug。
import sys
import os
import shutil
from collections import defaultdict
from PIL import Image
from PyQt5.QtWidgets import (
QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton,
QFileDialog, QListWidget, QListWidgetItem, QCheckBox, QScrollArea, QGroupBox,
QDialog, QGridLayout, QMessageBox, QProgressBar
)
from PyQt5.QtCore import Qt, QSize, QThread, pyqtSignal
from PyQt5.QtGui import QPixmap, QIcon, QFont, QImage, QPainter
from imagededup.methods import PHash
class ImageViewerDialog(QDialog):
    """Modal dialog that shows a single image (scaled down to at most 1600x1200)."""

    def __init__(self, image_path, parent=None):
        """Build the dialog: path label, scrollable image area and a close button.

        Args:
            image_path: path of the image file to display.
            parent: optional parent widget.
        """
        super().__init__(parent)
        self.setWindowTitle("图片查看器")
        self.setWindowFlags(self.windowFlags() | Qt.WindowMaximizeButtonHint)
        self.setMinimumSize(800, 600)

        layout = QVBoxLayout()

        # Image path
        path_label = QLabel(f"图片路径: {image_path}")
        path_label.setWordWrap(True)
        path_label.setStyleSheet("font-size: 12px; color: #666;")
        layout.addWidget(path_label)

        # Image, inside a scroll area so large pictures remain reachable
        self.image_label = QLabel()
        self.image_label.setAlignment(Qt.AlignCenter)
        self.image_label.setStyleSheet("background-color: #f0f0f0;")
        scroll_area = QScrollArea()
        scroll_area.setWidgetResizable(True)
        scroll_area.setWidget(self.image_label)
        layout.addWidget(scroll_area)

        # Close button
        close_btn = QPushButton("关闭")
        close_btn.setStyleSheet(
            "QPushButton { background-color: #4CAF50; color: white; padding: 8px 16px; border-radius: 4px; }"
            "QPushButton:hover { background-color: #45a049; }"
        )
        close_btn.clicked.connect(self.close)
        layout.addWidget(close_btn, alignment=Qt.AlignCenter)

        self.setLayout(layout)
        self.load_image(image_path)

    def load_image(self, image_path):
        """Decode *image_path* with Pillow and show it in ``self.image_label``.

        The image is converted to RGBA and shrunk (aspect ratio kept) if it
        exceeds 1600x1200. On any error, the label shows the error text instead.
        """
        try:
            # `with` releases the file handle. Pillow may otherwise keep
            # multi-frame files (GIF/TIFF) open after convert().
            with Image.open(image_path) as src:
                image = src.convert("RGBA")

            max_size = QSize(1600, 1200)
            if image.width > max_size.width() or image.height > max_size.height():
                image.thumbnail((max_size.width(), max_size.height()), Image.LANCZOS)

            # Keep the raw buffer in a named local so it outlives the QImage that
            # wraps it; QPixmap.fromImage() then makes its own copy.
            data = image.tobytes()
            bytes_per_line = image.width * 4  # RGBA8888: 4 bytes per pixel
            qimage = QImage(data, image.width, image.height, bytes_per_line, QImage.Format_RGBA8888)
            self.image_label.setPixmap(QPixmap.fromImage(qimage))
        except Exception as e:
            self.image_label.setText(f"无法加载图片: {str(e)}")
class DedupWorker(QThread):
    """Background thread that finds perceptually duplicate images with PHash.

    Signals:
        progress_updated(current, total, message): emitted per processed file.
        duplicates_found(groups, uniques): ``groups`` maps the first image of
            each duplicate group to the list of all paths in that group;
            ``uniques`` lists scanned paths that belong to no group.
    """

    progress_updated = pyqtSignal(int, int, str)  # current index, total, current file
    duplicates_found = pyqtSignal(dict, list)  # duplicate groups, unique images

    def __init__(self, image_dir):
        """Prepare a worker that scans *image_dir* recursively."""
        super().__init__()
        self.image_dir = image_dir
        self.canceled = False

    def run(self):
        """Hash every image, group near-duplicates and emit the result.

        Nothing is emitted on ``duplicates_found`` when no images are found,
        when the run is canceled, or when an unexpected error occurs.
        """
        try:
            extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp')
            image_files = [
                os.path.join(root, name)
                for root, _, files in os.walk(self.image_dir)
                for name in files
                if name.lower().endswith(extensions)
            ]
            if not image_files:
                self.progress_updated.emit(0, 0, "未找到图片文件")
                return

            total = len(image_files)
            self.progress_updated.emit(0, total, "开始处理...")

            # Perceptual hash of every readable image
            phasher = PHash()
            encodings = {}
            for i, image_path in enumerate(image_files, 1):
                if self.canceled:
                    return
                self.progress_updated.emit(i, total, os.path.basename(image_path))
                try:
                    encoding = phasher.encode_image(image_path)
                except Exception as e:
                    print(f"处理图片 {image_path} 时出错: {str(e)}")
                    continue
                # Skip images that produced no hash instead of feeding None into find_duplicates.
                if encoding is not None:
                    encodings[image_path] = encoding

            duplicates = phasher.find_duplicates(encoding_map=encodings)
            if self.canceled:
                # Do not deliver results after the user has already been told "canceled".
                return

            # Build disjoint groups: each image joins the first group that claims it.
            grouped_duplicates = {}
            visited = set()
            grouped_paths = set()  # every image in some group; a set keeps the filter below O(n)
            for img, dup_list in duplicates.items():
                if img in visited:
                    continue
                group = [img]
                visited.add(img)
                for dup in dup_list:
                    if dup not in visited:
                        group.append(dup)
                        visited.add(dup)
                if len(group) > 1:  # more than one member -> real duplicates
                    grouped_duplicates[img] = group
                    grouped_paths.update(group)

            # Unique = everything not placed in a group. Images seen in find_duplicates
            # whose duplicates were all claimed by other groups also count as unique.
            unique_images = [img for img in image_files if img not in grouped_paths]
            print("unique_images:", len(unique_images))
            self.duplicates_found.emit(grouped_duplicates, unique_images)
        except Exception as e:
            print(f"查重过程中发生错误: {str(e)}")

    def cancel(self):
        """Ask the worker to stop; it checks the flag between images and before emitting."""
        self.canceled = True
class ImageDedupTool(QMainWindow):
    """Main window of the image de-duplication tool.

    Workflow: choose an image folder, run ``DedupWorker`` in the background,
    review every duplicate group and tick the image to keep, then copy the kept
    images (all unique images plus the chosen image of each group) into the
    output folder.
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle("图片查重工具")
        self.setGeometry(100, 100, 1000, 700)
        self.setWindowIcon(QIcon(self.create_icon()))

        # Application state
        self.image_dir = ""
        self.output_dir = ""
        self.duplicate_groups = {}  # first image of a group -> all paths in that group
        self.unique_images = []  # paths that belong to no duplicate group
        self.selected_images = set()  # paths that export_results() will copy
        self.worker_thread = None
        # path -> "keep" checkbox of every duplicate-group item. Selection changes
        # update these widgets in place instead of rebuilding (and re-decoding)
        # every thumbnail.
        self._checkboxes = {}

        main_widget = QWidget()
        main_layout = QVBoxLayout()

        # Title
        title_label = QLabel("图片查重工具")
        title_label.setFont(QFont("Arial", 18, QFont.Bold))
        title_label.setStyleSheet("color: #2c3e50; margin: 15px 0;")
        title_label.setAlignment(Qt.AlignCenter)
        main_layout.addWidget(title_label)

        # Folder selection
        folder_group = QGroupBox("文件夹设置")
        folder_layout = QGridLayout()

        self.image_dir_label = QLabel("未选择图片文件夹")
        self.image_dir_label.setStyleSheet("font-size: 12px;")
        self.image_dir_label.setWordWrap(True)
        folder_layout.addWidget(QLabel("图片文件夹:"), 0, 0)
        folder_layout.addWidget(self.image_dir_label, 0, 1)

        self.output_dir_label = QLabel("未选择输出文件夹")
        self.output_dir_label.setStyleSheet("font-size: 12px;")
        self.output_dir_label.setWordWrap(True)
        folder_layout.addWidget(QLabel("输出文件夹:"), 1, 0)
        folder_layout.addWidget(self.output_dir_label, 1, 1)

        self.select_image_btn = QPushButton("选择图片文件夹")
        self.select_image_btn.setStyleSheet(self.get_button_style())
        self.select_image_btn.clicked.connect(self.select_image_dir)
        folder_layout.addWidget(self.select_image_btn, 0, 2)

        self.select_output_btn = QPushButton("选择输出文件夹")
        self.select_output_btn.setStyleSheet(self.get_button_style())
        self.select_output_btn.clicked.connect(self.select_output_dir)
        folder_layout.addWidget(self.select_output_btn, 1, 2)

        folder_group.setLayout(folder_layout)
        main_layout.addWidget(folder_group)

        # Action buttons
        btn_layout = QHBoxLayout()

        self.start_btn = QPushButton("开始查重")
        self.start_btn.setStyleSheet(self.get_button_style("#3498db"))
        self.start_btn.clicked.connect(self.start_dedup)
        self.start_btn.setEnabled(False)
        btn_layout.addWidget(self.start_btn)

        self.cancel_btn = QPushButton("取消")
        self.cancel_btn.setStyleSheet(self.get_button_style("#e74c3c"))
        self.cancel_btn.clicked.connect(self.cancel_dedup)
        self.cancel_btn.setEnabled(False)
        btn_layout.addWidget(self.cancel_btn)

        self.export_btn = QPushButton("导出结果")
        self.export_btn.setStyleSheet(self.get_button_style("#2ecc71"))
        self.export_btn.clicked.connect(self.export_results)
        self.export_btn.setEnabled(False)
        btn_layout.addWidget(self.export_btn)

        main_layout.addLayout(btn_layout)

        # Progress bar
        self.progress_bar = QProgressBar()
        self.progress_bar.setStyleSheet(
            "QProgressBar { border: 1px solid #3498db; border-radius: 5px; text-align: center; }"
            "QProgressBar::chunk { background-color: #3498db; }"
        )
        self.progress_bar.setVisible(False)
        main_layout.addWidget(self.progress_bar)

        # Results area
        result_group = QGroupBox("查重结果")
        result_layout = QVBoxLayout()

        self.result_label = QLabel("请选择图片文件夹并点击'开始查重'按钮继续")
        self.result_label.setStyleSheet("font-size: 14px; margin: 10px 0;")
        self.result_label.setAlignment(Qt.AlignCenter)
        result_layout.addWidget(self.result_label)

        self.result_scroll = QScrollArea()
        self.result_scroll.setWidgetResizable(True)
        self.result_container = QWidget()
        self.result_container_layout = QVBoxLayout()
        self.result_container_layout.setAlignment(Qt.AlignTop)
        self.result_container.setLayout(self.result_container_layout)
        self.result_scroll.setWidget(self.result_container)
        result_layout.addWidget(self.result_scroll)

        result_group.setLayout(result_layout)
        main_layout.addWidget(result_group, 1)

        main_widget.setLayout(main_layout)
        self.setCentralWidget(main_widget)

        self.statusBar().showMessage("就绪")

    def create_icon(self):
        """Return a placeholder (fully transparent 64x64) application icon."""
        pixmap = QPixmap(64, 64)
        pixmap.fill(Qt.transparent)
        return QIcon(pixmap)

    @staticmethod
    def _darken(color, amount=0x11):
        """Return ``#RRGGBB`` *color* with every channel reduced by *amount*, clamped at 0.

        Per-channel clamping avoids the borrow between channels (and negative
        values) that subtracting 0x111111 from the packed integer produces.
        """
        value = int(color.lstrip("#"), 16)
        channels = ((value >> 16) & 0xFF, (value >> 8) & 0xFF, value & 0xFF)
        return "#" + "".join(f"{max(c - amount, 0):02X}" for c in channels)

    def get_button_style(self, color="#7f8c8d"):
        """Return the QSS for a flat push button with base colour *color* (``#RRGGBB``)."""
        return f"""
QPushButton {{
background-color: {color};
color: white;
border: none;
padding: 8px 16px;
font-size: 14px;
border-radius: 4px;
}}
QPushButton:hover {{
background-color: {self._darken(color)};
}}
QPushButton:disabled {{
background-color: #bdc3c7;
}}
"""

    def select_image_dir(self):
        """Let the user choose the folder to scan and enable the start button."""
        dir_path = QFileDialog.getExistingDirectory(self, "选择图片文件夹")
        if dir_path:
            self.image_dir = dir_path
            self.image_dir_label.setText(dir_path)
            self.start_btn.setEnabled(True)
            self.result_label.setText("已选择图片文件夹,点击'开始查重'按钮继续")

    def select_output_dir(self):
        """Let the user choose the folder that kept images are copied to."""
        dir_path = QFileDialog.getExistingDirectory(self, "选择输出文件夹")
        if dir_path:
            self.output_dir = dir_path
            self.output_dir_label.setText(dir_path)

    def start_dedup(self):
        """Reset previous results and start a ``DedupWorker`` on the chosen folder."""
        if not self.image_dir:
            QMessageBox.warning(self, "警告", "请先选择图片文件夹")
            return

        self.duplicate_groups = {}
        self.unique_images = []
        self.selected_images = set()
        self.clear_result_container()

        self.progress_bar.setVisible(True)
        self.progress_bar.setValue(0)
        self.result_label.setText("正在处理图片,请稍候...")
        self.statusBar().showMessage("正在查找重复图片...")

        self.start_btn.setEnabled(False)
        self.cancel_btn.setEnabled(True)
        self.export_btn.setEnabled(False)

        self.worker_thread = DedupWorker(self.image_dir)
        self.worker_thread.progress_updated.connect(self.update_progress)
        self.worker_thread.duplicates_found.connect(self.process_results)
        # `finished` also fires when run() returns without results (no images,
        # error), which would otherwise leave the UI stuck in the "running" state.
        self.worker_thread.finished.connect(self._on_worker_finished)
        self.worker_thread.start()

    def cancel_dedup(self):
        """Stop a running worker (blocking until it exits) and restore the idle UI."""
        if self.worker_thread and self.worker_thread.isRunning():
            self.worker_thread.cancel()
            self.worker_thread.wait()
        self.progress_bar.setVisible(False)
        self.result_label.setText("操作已取消")
        self.statusBar().showMessage("操作已取消")
        self.start_btn.setEnabled(True)
        self.cancel_btn.setEnabled(False)

    def _on_worker_finished(self):
        """Restore the idle UI if the worker ended without delivering results."""
        if self.sender() is not self.worker_thread or not self.cancel_btn.isEnabled():
            # A stale worker, or process_results()/cancel_dedup() already reset the UI.
            return
        self.progress_bar.setVisible(False)
        self.result_label.setText("未找到图片或查重失败,请检查图片文件夹")
        self.statusBar().showMessage("查重未完成")
        self.start_btn.setEnabled(True)
        self.cancel_btn.setEnabled(False)

    def update_progress(self, current, total, filename):
        """Show worker progress in the progress bar and the status bar."""
        if total > 0:
            self.progress_bar.setValue(int((current / total) * 100))
        self.statusBar().showMessage(f"正在处理: {filename} ({current}/{total})")

    def process_results(self, duplicates, uniques):
        """Store worker results, pre-select images to keep and render them.

        All unique images, plus the first image of each duplicate group, are
        selected by default.
        """
        self.duplicate_groups = duplicates
        self.unique_images = uniques
        self.selected_images = set(self.unique_images)
        for group in self.duplicate_groups.values():
            if group:
                self.selected_images.add(group[0])

        self.progress_bar.setVisible(False)

        duplicate_count = sum(len(group) for group in self.duplicate_groups.values())
        unique_count = len(self.unique_images)
        self.result_label.setText(
            f"查重完成!共发现 {len(self.duplicate_groups)} 组重复图片 "
            f"({duplicate_count} 张), {unique_count} 张唯一图片"
        )
        self.statusBar().showMessage("查重完成")

        self.display_results()

        self.start_btn.setEnabled(True)
        self.cancel_btn.setEnabled(False)
        self.export_btn.setEnabled(True)

    def clear_result_container(self):
        """Remove (and schedule deletion of) every widget in the results area."""
        self._checkboxes = {}
        while self.result_container_layout.count():
            child = self.result_container_layout.takeAt(0)
            if child.widget():
                child.widget().deleteLater()

    def display_results(self):
        """Render one group box per duplicate group, then one box of unique images."""
        self.clear_result_container()

        for group_idx, group in enumerate(self.duplicate_groups.values(), 1):
            group_box = QGroupBox(f"重复图片组 {group_idx} - 共 {len(group)} 张")
            group_layout = QVBoxLayout()
            for img_path in group:
                group_layout.addWidget(self.create_image_item(img_path, group_idx))
            group_box.setLayout(group_layout)
            self.result_container_layout.addWidget(group_box)

        if self.unique_images:
            unique_group = QGroupBox(f"唯一图片 - 共 {len(self.unique_images)} 张")
            unique_layout = QVBoxLayout()
            for img_path in self.unique_images:
                unique_layout.addWidget(self.create_image_item(img_path, is_unique=True))
            unique_group.setLayout(unique_layout)
            self.result_container_layout.addWidget(unique_group)

    def create_image_item(self, image_path, group_id=None, is_unique=False):
        """Build one result row: clickable thumbnail, file info and keep-checkbox/status.

        Args:
            image_path: image to show.
            group_id: index of the duplicate group (not used for rendering).
            is_unique: unique images get a status label instead of a checkbox.
        """
        item_widget = QWidget()
        item_layout = QHBoxLayout()

        # Thumbnail (click opens the large viewer)
        try:
            with Image.open(image_path) as src:
                img = src.convert("RGBA")
            img.thumbnail((120, 120), Image.LANCZOS)
            data = img.tobytes()  # keep a reference while the QImage wraps it
            qimage = QImage(data, img.width, img.height, img.width * 4, QImage.Format_RGBA8888)
            pixmap = QPixmap.fromImage(qimage)

            thumbnail_label = QLabel()
            thumbnail_label.setPixmap(pixmap)
            thumbnail_label.setStyleSheet("border: 1px solid #ddd; padding: 5px;")
            thumbnail_label.setCursor(Qt.PointingHandCursor)
            thumbnail_label.mousePressEvent = lambda e: self.show_image(image_path)
        except Exception as e:
            print(f"创建缩略图时出错: {str(e)}")
            thumbnail_label = QLabel("无法加载图片")
            thumbnail_label.setStyleSheet("color: red;")
        item_layout.addWidget(thumbnail_label)

        # File info
        info_layout = QVBoxLayout()
        filename = os.path.basename(image_path)
        path_label = QLabel(f"文件名: {filename}")
        path_label.setStyleSheet("font-weight: bold;")
        dir_label = QLabel(f"路径: {os.path.dirname(image_path)}")
        dir_label.setStyleSheet("font-size: 11px; color: #555;")

        try:
            # Opening is lazy: only the header is read to get the original size.
            with Image.open(image_path) as img:
                size = f"{img.width}×{img.height}"
            size_mb = os.path.getsize(image_path) / (1024 * 1024)
            size_label = QLabel(f"尺寸: {size} | 大小: {size_mb:.2f} MB")
        except Exception:
            size_label = QLabel("尺寸信息不可用")
        size_label.setStyleSheet("font-size: 12px;")

        info_layout.addWidget(path_label)
        info_layout.addWidget(dir_label)
        info_layout.addWidget(size_label)
        if not is_unique:
            info_layout.addWidget(QLabel())  # spacer
            info_layout.addWidget(QLabel("请选择要保留的图片:"))
        item_layout.addLayout(info_layout, 1)

        if is_unique:
            status_label = QLabel("唯一图片,将自动保留")
            status_label.setStyleSheet("color: #27ae60; font-weight: bold;")
            item_layout.addWidget(status_label)
        else:
            checkbox = QCheckBox("保留此图片")
            checkbox.setChecked(image_path in self.selected_images)
            checkbox.stateChanged.connect(lambda state, path=image_path: self.toggle_selection(path, state))
            self._checkboxes[image_path] = checkbox
            item_layout.addWidget(checkbox)

        item_widget.setLayout(item_layout)
        item_widget.setStyleSheet("""
QWidget {
border-bottom: 1px solid #eee;
padding: 10px;
}
QWidget:hover {
background-color: #f9f9f9;
}
""")
        return item_widget

    def _find_group(self, image_path):
        """Return the duplicate group that contains *image_path*, or None."""
        for group in self.duplicate_groups.values():
            if image_path in group:
                return group
        return None

    def _set_checkbox(self, image_path, checked):
        """Set the checkbox of *image_path* without re-triggering toggle_selection()."""
        checkbox = self._checkboxes.get(image_path)
        if checkbox is not None and checkbox.isChecked() != checked:
            checkbox.blockSignals(True)
            checkbox.setChecked(checked)
            checkbox.blockSignals(False)

    def toggle_selection(self, image_path, state):
        """React to a keep-checkbox change, enforcing exactly one kept image per group.

        Checking an image deselects the other images of its group. Unchecking
        the only selected image of a group is refused (the box is re-checked).
        Only the affected checkboxes are updated, not the whole result view.
        """
        group = self._find_group(image_path)
        if state == Qt.Checked:
            self.selected_images.add(image_path)
            if group is not None:
                for other_path in group:
                    if other_path != image_path:
                        self.selected_images.discard(other_path)
                        self._set_checkbox(other_path, False)
            return

        # Compare against the *other* members: image_path itself is still selected here.
        if group is not None and not any(p in self.selected_images for p in group if p != image_path):
            self._set_checkbox(image_path, True)
            return
        self.selected_images.discard(image_path)

    def show_image(self, image_path):
        """Open *image_path* in a modal ``ImageViewerDialog``."""
        try:
            viewer = ImageViewerDialog(image_path, self)
            viewer.exec_()
        except Exception as e:
            QMessageBox.critical(self, "错误", f"无法显示图片: {str(e)}")

    def export_results(self):
        """Copy every selected image into the output folder.

        Name clashes get a ``_1``, ``_2``, ... suffix. Files are processed in
        sorted path order, so the suffixes are reproducible between runs.
        """
        if not self.output_dir:
            QMessageBox.warning(self, "警告", "请先选择输出文件夹")
            return
        if not self.selected_images:
            QMessageBox.warning(self, "警告", "没有选择要保留的图片")
            return

        try:
            os.makedirs(self.output_dir, exist_ok=True)

            total = len(self.selected_images)
            copied = 0
            self.progress_bar.setVisible(True)
            self.progress_bar.setValue(0)

            for img_path in sorted(self.selected_images):
                filename = os.path.basename(img_path)
                name, ext = os.path.splitext(filename)
                dest_path = os.path.join(self.output_dir, filename)
                counter = 1
                while os.path.exists(dest_path):
                    dest_path = os.path.join(self.output_dir, f"{name}_{counter}{ext}")
                    counter += 1

                shutil.copy2(img_path, dest_path)
                copied += 1

                self.progress_bar.setValue(int((copied / total) * 100))
                self.statusBar().showMessage(f"正在导出: {filename} ({copied}/{total})")
                QApplication.processEvents()  # keep the UI responsive during the copy loop

            QMessageBox.information(
                self,
                "导出完成",
                f"成功导出 {copied} 张图片到:\n{self.output_dir}"
            )
            self.statusBar().showMessage(f"导出完成: {copied} 张图片")
        except Exception as e:
            QMessageBox.critical(
                self,
                "导出错误",
                f"导出过程中发生错误:\n{str(e)}"
            )
            self.statusBar().showMessage("导出失败")
        finally:
            self.progress_bar.setVisible(False)
if __name__ == "__main__":
    # Launch the de-duplication GUI with the Fusion style and an app-wide stylesheet.
    qt_app = QApplication(sys.argv)
    qt_app.setStyle("Fusion")

    # Global QSS applied to every window and dialog of the tool.
    global_qss = """
QMainWindow, QDialog {
background-color: #f5f5f5;
}
QGroupBox {
font-weight: bold;
border: 1px solid #ddd;
border-radius: 5px;
margin-top: 10px;
padding-top: 15px;
}
QGroupBox::title {
subcontrol-origin: margin;
subcontrol-position: top center;
padding: 0 5px;
}
QScrollArea {
border: none;
}
QLabel {
padding: 2px;
}
"""
    qt_app.setStyleSheet(global_qss)

    main_window = ImageDedupTool()
    main_window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)
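清洗好的数据还需要划分为训练集、验证集和测试集，并按 torchvision `ImageFolder` 要求的格式放入对应文件夹，即 `Train_test/images/{train,val,test}/<类别名>/<图片文件>`（与下方训练代码中的路径一致，类别标签由子文件夹名决定）。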
训练/测试
训练用代码如下。
(感谢deepseek)
import torch
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
from PIL import Image
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm # 用于显示进度条 (可选)
def main():
    """Fine-tune the classification head of a pretrained ResNet-50 and evaluate it.

    Reads an ``ImageFolder`` dataset from ``Train_test/images/{train,val,test}``
    (one sub-folder per class) and trains only a new fully-connected layer on
    the frozen ImageNet backbone. The weights with the best validation accuracy
    are kept and also saved to ``insect_resnet50_weights.pth``. Overall and
    per-class test accuracy are then reported, and a Top-k prediction is shown
    for the first test image.
    """
    # ---------- data ----------
    data_dir = 'Train_test/images'  # dataset root after extraction
    train_dir = data_dir + '/train'
    val_dir = data_dir + '/val'
    test_dir = data_dir + '/test'

    # Light augmentation for training; deterministic preprocessing for val/test.
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224),  # random scale/aspect crop, resized to 224x224
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet stats
    ])
    val_test_transforms = transforms.Compose([
        transforms.Resize(256),  # shorter side -> 256, aspect ratio preserved
        transforms.CenterCrop(224),  # central 224x224 patch
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # ImageFolder derives label ids from the class sub-folder names.
    train_dataset = datasets.ImageFolder(train_dir, transform=train_transforms)
    val_dataset = datasets.ImageFolder(val_dir, transform=val_test_transforms)
    test_dataset = datasets.ImageFolder(test_dir, transform=val_test_transforms)

    batch_size = 32  # lower it if the GPU runs out of memory
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Label id -> class name. Assumes val/test contain the same class folders as train.
    class_names = train_dataset.classes
    num_classes = len(class_names)
    print(f'训练集类别: {class_names}')
    print(f'类别数量: {len(class_names)}')

    # ---------- model ----------
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    # Freeze the pretrained backbone; only the replacement head below is trained.
    for param in model.parameters():
        param.requires_grad = False
    num_ftrs = model.fc.in_features
    # New head sized to our classes (a Dropout could be prepended via nn.Sequential).
    model.fc = torch.nn.Linear(num_ftrs, num_classes)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    print(f'模型已移至: {device}')

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.fc.parameters(), lr=0.001)  # optimize the new head only

    num_epochs = 10

    train_loss_history = []
    val_loss_history = []
    val_acc_history = []
    best_val_acc = float('-inf')
    best_model_state = None

    for epoch in range(num_epochs):
        # ---------- training ----------
        # NOTE: train() also puts the frozen backbone's BatchNorm layers in
        # training mode, so their running statistics keep adapting to this data.
        model.train()
        running_loss = 0.0
        for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]'):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # The loss is a batch mean; re-weight by batch size for an exact epoch mean.
            running_loss += loss.item() * inputs.size(0)
        epoch_train_loss = running_loss / len(train_dataset)

        # ---------- validation ----------
        model.eval()
        val_running_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_running_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        epoch_val_loss = val_running_loss / len(val_dataset)
        epoch_val_acc = correct / total

        train_loss_history.append(epoch_train_loss)
        val_loss_history.append(epoch_val_loss)
        val_acc_history.append(epoch_val_acc)

        print(f'Epoch {epoch+1}/{num_epochs} | '
              f'Train Loss: {epoch_train_loss:.4f} | '
              f'Val Loss: {epoch_val_loss:.4f} | '
              f'Val Acc: {epoch_val_acc:.4f}')

        # Keep the best weights so far (ties go to the later epoch).
        if epoch_val_acc >= best_val_acc:
            best_val_acc = epoch_val_acc
            # state_dict() returns references to the live tensors; without cloning,
            # later epochs would overwrite this "best" snapshot in place.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(best_model_state, 'insect_resnet50_weights.pth')
            print(f'** 保存最佳模型 (准确率: {epoch_val_acc:.4f}) **')

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print('训练完成!加载了验证集上表现最好的模型。')

    # ---------- test ----------
    model.eval()
    test_correct = 0
    test_total = 0
    class_correct = [0] * num_classes
    class_total = [0] * num_classes
    confusion_matrix = torch.zeros(num_classes, num_classes)  # rows: true label, cols: predicted
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()
            # Convert to Python ints once per batch instead of indexing lists
            # with 0-d (possibly CUDA) tensors element by element.
            for label, pred in zip(labels.tolist(), predicted.tolist()):
                class_correct[label] += int(pred == label)
                class_total[label] += 1
                confusion_matrix[label][pred] += 1

    test_acc = test_correct / test_total
    print(f'测试集准确率: {test_acc:.4f}')

    print('\n每个类别的准确率:')
    for i, name in enumerate(class_names):
        if class_total[i] > 0:
            acc = class_correct[i] / class_total[i]
            print(f' {name}: {acc:.4f} ({class_correct[i]}/{class_total[i]})')
        else:
            print(f' {name}: 无样本')

    # (Optional) visualise the confusion matrix (requires seaborn):
    # import seaborn as sns
    # plt.figure(figsize=(10, 8))
    # sns.heatmap(confusion_matrix.numpy(), annot=True, fmt='g', cmap='Blues',
    #             xticklabels=class_names, yticklabels=class_names)
    # plt.xlabel('预测标签')
    # plt.ylabel('真实标签')
    # plt.title('混淆矩阵')
    # plt.tight_layout()
    # plt.show()

    def predict_image(image_path, model, class_names, transform, top_k=3):
        """Show *image_path* and print its Top-k predicted classes.

        Uses ``device`` from the enclosing ``main`` scope. *top_k* is capped at
        the number of classes, because ``torch.topk`` fails when k exceeds it.
        """
        model.eval()
        with Image.open(image_path) as img:
            image = img.convert('RGB')  # force 3 channels
        image_tensor = transform(image).unsqueeze(0).to(device)  # add batch dim -> [1, C, H, W]
        with torch.no_grad():
            output = model(image_tensor)
            probabilities = torch.nn.functional.softmax(output[0], dim=0)
            k = min(top_k, probabilities.numel())
            top_prob, top_class_idx = torch.topk(probabilities, k=k)

        print(f'预测图像: {image_path}')
        plt.imshow(image)
        plt.axis('off')
        plt.show()
        for i in range(top_prob.size(0)):
            class_idx = top_class_idx[i].item()
            print(f' Top-{i+1}: {class_names[class_idx]} - 概率: {top_prob[i].item():.4f}')

    # Demo: predict the first test image (samples holds (path, label) pairs).
    sample_image_path = test_dataset.samples[0][0]
    predict_image(sample_image_path, model, class_names, val_test_transforms)
if __name__ == "__main__":
    # The guard is required because DataLoader(num_workers=2) starts worker
    # processes, which re-import this module on Windows. freeze_support() only
    # matters for frozen Windows executables and has no effect otherwise.
    torch.multiprocessing.freeze_support()
    main()
其他
日常生活中拍摄的照片大多不符合模型输入的要求。比如昆虫可能只占画面的一小部分，或不在画面中央，而识别模型的预处理（缩放后中心裁剪）可能会裁掉部分关键信息，造成识别错误；又比如一张图片中可能不止一只昆虫。因此除了识别系统外，我们还需要一个目标检测（框选）系统，先定位出每只昆虫，再将裁出的区域送入识别模型。这部分可用 YOLO 或 Faster R-CNN 实现。