学习日志|训练一个简单的昆虫识别模型
仅记录自身学习过程,不适合作为教程浏览

数据搜集

可以在 Kaggle(Find Open Datasets and Machine Learning Projects | Kaggle)等网站上下载公开数据集,也可以自行编写爬虫爬取图片。

数据清洗

下载下来的数据往往存在大量重复或不符合要求的低质/主题无关图像,需要进行一定处理才可用于训练。由于此次目的仅在于跑通整个训练流程,因此直接使用整理好的数据集,仅进行查重。

deepseek写的用于查重的窗体程序

并不好用,图片一多就特别卡。本来想完全依靠deepseek生成的,但最后还是得自己下场修bug。

import sys
import os
import shutil
from collections import defaultdict
from PIL import Image
from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton,
    QFileDialog, QListWidget, QListWidgetItem, QCheckBox, QScrollArea, QGroupBox,
    QDialog, QGridLayout, QMessageBox, QProgressBar
)
from PyQt5.QtCore import Qt, QSize, QThread, pyqtSignal
from PyQt5.QtGui import QPixmap, QIcon, QFont, QImage, QPainter
from imagededup.methods import PHash

class ImageViewerDialog(QDialog):
    """Dialog that shows one image at large size together with its path.

    The image is decoded with Pillow, shrunk to fit within 1600x1200 when it
    is larger than that, and displayed inside a scroll area.
    """

    def __init__(self, image_path, parent=None):
        """Build the dialog widgets and load ``image_path`` into the viewer.

        Args:
            image_path: Path of the image file to display.
            parent: Optional parent widget.
        """
        super().__init__(parent)
        self.setWindowTitle("图片查看器")
        self.setWindowFlags(self.windowFlags() | Qt.WindowMaximizeButtonHint)
        self.setMinimumSize(800, 600)

        layout = QVBoxLayout()

        # Path of the image being shown
        path_label = QLabel(f"图片路径: {image_path}")
        path_label.setWordWrap(True)
        path_label.setStyleSheet("font-size: 12px; color: #666;")
        layout.addWidget(path_label)

        # Label that receives the pixmap (or an error message)
        self.image_label = QLabel()
        self.image_label.setAlignment(Qt.AlignCenter)
        self.image_label.setStyleSheet("background-color: #f0f0f0;")

        scroll_area = QScrollArea()
        scroll_area.setWidgetResizable(True)
        scroll_area.setWidget(self.image_label)
        layout.addWidget(scroll_area)

        # Close button
        close_btn = QPushButton("关闭")
        close_btn.setStyleSheet(
            "QPushButton { background-color: #4CAF50; color: white; padding: 8px 16px; border-radius: 4px; }"
            "QPushButton:hover { background-color: #45a049; }"
        )
        close_btn.clicked.connect(self.close)
        layout.addWidget(close_btn, alignment=Qt.AlignCenter)

        self.setLayout(layout)
        self.load_image(image_path)

    def load_image(self, image_path):
        """Decode ``image_path`` and show it in ``self.image_label``.

        Any failure while opening or converting the file is reported as text
        inside the label instead of being raised.

        Args:
            image_path: Path of the image file to display.
        """
        try:
            # Close the file handle as soon as the pixels have been copied
            # into the converted RGBA image.
            with Image.open(image_path) as source:
                image = source.convert("RGBA")

            # Cap the displayed size
            max_size = QSize(1600, 1200)
            if image.width > max_size.width() or image.height > max_size.height():
                image.thumbnail((max_size.width(), max_size.height()), Image.LANCZOS)

            # QImage does not copy the buffer it is given, so keep the bytes
            # object referenced until the pixmap has been built from it.
            pixel_data = image.tobytes()
            qimage = QImage(
                pixel_data,
                image.width,
                image.height,
                image.width * 4,  # bytes per scanline: 4 channels, 1 byte each
                QImage.Format_RGBA8888
            )

            pixmap = QPixmap.fromImage(qimage)
            self.image_label.setPixmap(pixmap)
        except Exception as e:
            self.image_label.setText(f"无法加载图片: {str(e)}")

class DedupWorker(QThread):
    """Background thread that looks for duplicate images under a directory.

    Every image file found below ``image_dir`` is hashed with imagededup's
    ``PHash``; files reported as duplicates of each other are grouped, and
    the remaining files are reported as unique.

    Signals:
        progress_updated(int, int, str): current index, total count and the
            name of the file being processed (or a status text).
        duplicates_found(dict, list): mapping ``first path -> list of paths
            in the group`` and the list of paths that belong to no group.
    """
    progress_updated = pyqtSignal(int, int, str)  # current index, total, current file
    duplicates_found = pyqtSignal(dict, list)     # duplicate groups, unique images

    def __init__(self, image_dir):
        """Store the directory to scan.

        Args:
            image_dir: Root directory that is walked recursively.
        """
        super().__init__()
        self.image_dir = image_dir
        self.canceled = False

    def run(self):
        """Collect image files, hash them and emit the grouping result.

        Nothing is emitted on ``duplicates_found`` once ``cancel()`` has been
        requested. Errors are printed rather than raised because this runs
        as the thread body.
        """
        try:
            # Collect every file with a known image extension
            image_files = []
            extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp')
            for root, _, files in os.walk(self.image_dir):
                for filename in files:
                    if filename.lower().endswith(extensions):
                        image_files.append(os.path.join(root, filename))

            if not image_files:
                self.progress_updated.emit(0, 0, "未找到图片文件")
                # Still report an (empty) result so listeners are told that
                # the run has finished.
                self.duplicates_found.emit({}, [])
                return

            total = len(image_files)
            self.progress_updated.emit(0, total, "开始处理...")

            # Perceptual-hash every image
            phasher = PHash()
            encodings = {}

            for i, image_path in enumerate(image_files, 1):
                if self.canceled:
                    return

                self.progress_updated.emit(i, total, os.path.basename(image_path))
                try:
                    encodings[image_path] = phasher.encode_image(image_path)
                except Exception as e:
                    # A single unreadable file must not abort the whole scan
                    print(f"处理图片 {image_path} 时出错: {str(e)}")

            if self.canceled:
                return

            # Look up duplicates among the hashes
            duplicates = phasher.find_duplicates(encoding_map=encodings)

            # Build groups; a path is assigned to the first group that
            # reaches it and is skipped afterwards.
            grouped_duplicates = {}
            visited = set()
            duplicate_members = set()  # every path that ended up in a group
            for img, dup_list in duplicates.items():
                if img in visited:
                    continue

                group = [img]
                visited.add(img)

                for dup in dup_list:
                    if dup not in visited:
                        group.append(dup)
                        visited.add(dup)

                # Only groups with more than one member are real duplicates
                if len(group) > 1:
                    grouped_duplicates[img] = group
                    duplicate_members.update(group)

            # Unique images are those that are not a member of any group
            # (checking ``visited`` instead would wrongly drop singletons).
            unique_images = [img for img in image_files if img not in duplicate_members]
            print("unique_images:", len(unique_images))

            # The cancel request may have arrived during the duplicate
            # search, which cannot be interrupted.
            if self.canceled:
                return

            self.duplicates_found.emit(grouped_duplicates, unique_images)

        except Exception as e:
            print(f"查重过程中发生错误: {str(e)}")

    def cancel(self):
        """Request cancellation; ``run`` checks the flag between steps."""
        self.canceled = True

class ImageDedupTool(QMainWindow):
    """图片查重工具主界面(修复所有报告的问题)"""
    def __init__(self):
        """Create the main window: state fields, folder pickers, action buttons, progress bar and result area."""
        super().__init__()
        self.setWindowTitle("图片查重工具")
        self.setGeometry(100, 100, 1000, 700)
        self.setWindowIcon(QIcon(self.create_icon()))

        # Initialise state
        self.image_dir = ""
        self.output_dir = ""
        self.duplicate_groups = {}
        self.unique_images = []
        self.selected_images = set()
        self.worker_thread = None

        # Central widget and its top-level layout
        main_widget = QWidget()
        main_layout = QVBoxLayout()

        # Title
        title_label = QLabel("图片查重工具")
        title_label.setFont(QFont("Arial", 18, QFont.Bold))
        title_label.setStyleSheet("color: #2c3e50; margin: 15px 0;")
        title_label.setAlignment(Qt.AlignCenter)
        main_layout.addWidget(title_label)

        # Folder selection area (row 0: image folder, row 1: output folder)
        folder_group = QGroupBox("文件夹设置")
        folder_layout = QGridLayout()

        self.image_dir_label = QLabel("未选择图片文件夹")
        self.image_dir_label.setStyleSheet("font-size: 12px;")
        self.image_dir_label.setWordWrap(True)
        folder_layout.addWidget(QLabel("图片文件夹:"), 0, 0)
        folder_layout.addWidget(self.image_dir_label, 0, 1)

        self.output_dir_label = QLabel("未选择输出文件夹")
        self.output_dir_label.setStyleSheet("font-size: 12px;")
        self.output_dir_label.setWordWrap(True)
        folder_layout.addWidget(QLabel("输出文件夹:"), 1, 0)
        folder_layout.addWidget(self.output_dir_label, 1, 1)

        self.select_image_btn = QPushButton("选择图片文件夹")
        self.select_image_btn.setStyleSheet(self.get_button_style())
        self.select_image_btn.clicked.connect(self.select_image_dir)
        folder_layout.addWidget(self.select_image_btn, 0, 2)

        self.select_output_btn = QPushButton("选择输出文件夹")
        self.select_output_btn.setStyleSheet(self.get_button_style())
        self.select_output_btn.clicked.connect(self.select_output_dir)
        folder_layout.addWidget(self.select_output_btn, 1, 2)

        folder_group.setLayout(folder_layout)
        main_layout.addWidget(folder_group)

        # Action buttons; all three start out disabled
        btn_layout = QHBoxLayout()

        self.start_btn = QPushButton("开始查重")
        self.start_btn.setStyleSheet(self.get_button_style("#3498db"))
        self.start_btn.clicked.connect(self.start_dedup)
        self.start_btn.setEnabled(False)
        btn_layout.addWidget(self.start_btn)

        self.cancel_btn = QPushButton("取消")
        self.cancel_btn.setStyleSheet(self.get_button_style("#e74c3c"))
        self.cancel_btn.clicked.connect(self.cancel_dedup)
        self.cancel_btn.setEnabled(False)
        btn_layout.addWidget(self.cancel_btn)

        self.export_btn = QPushButton("导出结果")
        self.export_btn.setStyleSheet(self.get_button_style("#2ecc71"))
        self.export_btn.clicked.connect(self.export_results)
        self.export_btn.setEnabled(False)
        btn_layout.addWidget(self.export_btn)

        main_layout.addLayout(btn_layout)

        # Progress bar, hidden until an operation starts
        self.progress_bar = QProgressBar()
        self.progress_bar.setStyleSheet(
            "QProgressBar { border: 1px solid #3498db; border-radius: 5px; text-align: center; }"
            "QProgressBar::chunk { background-color: #3498db; }"
        )
        self.progress_bar.setVisible(False)
        main_layout.addWidget(self.progress_bar)

        # Result area
        result_group = QGroupBox("查重结果")
        result_layout = QVBoxLayout()

        # Summary / status label
        self.result_label = QLabel("请选择图片文件夹并点击'开始查重'按钮继续")
        self.result_label.setStyleSheet("font-size: 14px; margin: 10px 0;")
        self.result_label.setAlignment(Qt.AlignCenter)
        result_layout.addWidget(self.result_label)

        # Scrollable container that holds the per-group result widgets
        self.result_scroll = QScrollArea()
        self.result_scroll.setWidgetResizable(True)

        self.result_container = QWidget()
        self.result_container_layout = QVBoxLayout()
        self.result_container_layout.setAlignment(Qt.AlignTop)

        self.result_container.setLayout(self.result_container_layout)
        self.result_scroll.setWidget(self.result_container)

        result_layout.addWidget(self.result_scroll)
        result_group.setLayout(result_layout)
        # Stretch factor 1: the result area is the only item given a stretch factor
        main_layout.addWidget(result_group, 1)

        main_widget.setLayout(main_layout)
        self.setCentralWidget(main_widget)

        # Status bar
        self.statusBar().showMessage("就绪")
    
    def create_icon(self):
        """Return a placeholder application icon: a fully transparent 64x64 pixmap."""
        blank = QPixmap(64, 64)
        blank.fill(Qt.transparent)
        icon = QIcon(blank)
        return icon
    
    def get_button_style(self, color="#7f8c8d"):
        """Return the Qt stylesheet for a flat push button in ``color``.

        The hover colour is the base colour darkened by 0x11 on each RGB
        channel. Channels are darkened independently and clamped at 0, so a
        dark base colour can neither borrow from a neighbouring channel nor
        produce a negative (invalid) hex value.

        Args:
            color: Base background colour as a ``#RRGGBB`` hex string.

        Returns:
            The stylesheet text covering the normal, hover and disabled states.
        """
        rgb = int(color[1:], 16)
        channels = ((rgb >> 16) & 0xFF, (rgb >> 8) & 0xFF, rgb & 0xFF)
        hover_color = "#" + "".join(f"{max(c - 0x11, 0):02X}" for c in channels)
        return f"""
            QPushButton {{
                background-color: {color};
                color: white;
                border: none;
                padding: 8px 16px;
                font-size: 14px;
                border-radius: 4px;
            }}
            QPushButton:hover {{
                background-color: {hover_color};
            }}
            QPushButton:disabled {{
                background-color: #bdc3c7;
            }}
        """
    
    def select_image_dir(self):
        """Ask for the folder to scan; on a choice, remember it and enable the start button."""
        chosen = QFileDialog.getExistingDirectory(self, "选择图片文件夹")
        if not chosen:
            # Dialog was dismissed without a selection
            return

        self.image_dir = chosen
        self.image_dir_label.setText(chosen)
        self.start_btn.setEnabled(True)
        self.result_label.setText("已选择图片文件夹,点击'开始查重'按钮继续")
    
    def select_output_dir(self):
        """Ask for the export folder; on a choice, remember it and show it in the label."""
        chosen = QFileDialog.getExistingDirectory(self, "选择输出文件夹")
        if not chosen:
            # Dialog was dismissed without a selection
            return

        self.output_dir = chosen
        self.output_dir_label.setText(chosen)
    
    def start_dedup(self):
        """Reset previous results, switch the UI to its busy state and start a ``DedupWorker``."""
        if not self.image_dir:
            QMessageBox.warning(self, "警告", "请先选择图片文件夹")
            return

        # Drop whatever the previous run produced
        self.duplicate_groups, self.unique_images = {}, []
        self.selected_images = set()
        self.clear_result_container()

        # Show progress feedback
        self.progress_bar.setVisible(True)
        self.progress_bar.setValue(0)
        self.result_label.setText("正在处理图片,请稍候...")
        self.statusBar().showMessage("正在查找重复图片...")

        # While running, only "cancel" is available
        self.start_btn.setEnabled(False)
        self.cancel_btn.setEnabled(True)
        self.export_btn.setEnabled(False)

        # Create the background worker, wire its signals and start it
        worker = DedupWorker(self.image_dir)
        self.worker_thread = worker
        worker.progress_updated.connect(self.update_progress)
        worker.duplicates_found.connect(self.process_results)
        worker.start()
    
    def cancel_dedup(self):
        """Cancel a running worker (waiting for it to finish) and restore the idle UI state."""
        worker = self.worker_thread
        if worker and worker.isRunning():
            worker.cancel()
            worker.wait()  # blocks until the thread's run() has returned

        self.progress_bar.setVisible(False)
        cancelled_text = "操作已取消"
        self.result_label.setText(cancelled_text)
        self.statusBar().showMessage(cancelled_text)

        # Back to idle: start enabled, cancel disabled
        self.start_btn.setEnabled(True)
        self.cancel_btn.setEnabled(False)
    
    def update_progress(self, current, total, filename):
        """Show ``current``/``total`` as a percentage and name the file being processed.

        Updates with a non-positive ``total`` are ignored.
        """
        if total <= 0:
            return

        percent = int((current / total) * 100)
        self.progress_bar.setValue(percent)
        status_text = f"正在处理: {filename} ({current}/{total})"
        self.statusBar().showMessage(status_text)
    
    def process_results(self, duplicates, uniques):
        """Store the worker's result, pre-select images to keep and refresh the UI.

        Every unique image is selected, plus the first image of each
        duplicate group.

        Args:
            duplicates: Mapping whose values are lists of duplicate image paths.
            uniques: List of image paths that have no duplicates.
        """
        self.duplicate_groups = duplicates
        self.unique_images = uniques

        # Default selection: all unique images and the first member of each
        # non-empty duplicate group.
        self.selected_images = set(self.unique_images)
        self.selected_images.update(
            group[0] for group in self.duplicate_groups.values() if group
        )

        # Hide the progress bar
        self.progress_bar.setVisible(False)

        # Summary line
        duplicate_count = sum(len(group) for group in self.duplicate_groups.values())
        unique_count = len(self.unique_images)

        self.result_label.setText(
            f"查重完成!共发现 {len(self.duplicate_groups)} 组重复图片 "
            f"({duplicate_count} 张), {unique_count} 张唯一图片"
        )
        self.statusBar().showMessage("查重完成")

        # Render the groups
        self.display_results()

        # Back to idle, with export now available
        self.start_btn.setEnabled(True)
        self.cancel_btn.setEnabled(False)
        self.export_btn.setEnabled(True)
    
    def clear_result_container(self):
        """Remove every item from the result layout and schedule its widget for deletion."""
        layout = self.result_container_layout
        while layout.count():
            widget = layout.takeAt(0).widget()
            if widget:
                widget.deleteLater()
    
    def display_results(self):
        """Rebuild the result list: one box per duplicate group, then one box of unique images."""
        self.clear_result_container()
        container = self.result_container_layout

        # Duplicate groups, numbered from 1
        for number, members in enumerate(self.duplicate_groups.values(), 1):
            box = QGroupBox(f"重复图片组 {number} - 共 {len(members)} 张")
            column = QVBoxLayout()
            for path in members:
                column.addWidget(self.create_image_item(path, number))
            box.setLayout(column)
            container.addWidget(box)

        # Unique images (box omitted when there are none)
        if not self.unique_images:
            return

        unique_box = QGroupBox(f"唯一图片 - 共 {len(self.unique_images)} 张")
        unique_column = QVBoxLayout()
        for path in self.unique_images:
            unique_column.addWidget(self.create_image_item(path, is_unique=True))
        unique_box.setLayout(unique_column)
        container.addWidget(unique_box)
        
    def create_image_item(self, image_path, group_id=None, is_unique=False):
        """Build the row widget for one image in the result list.

        The row holds a clickable 120x120 thumbnail, the file name, directory,
        pixel dimensions and file size, and either a "keep" checkbox (for
        images in a duplicate group) or a status text (for unique images).

        Args:
            image_path: Path of the image file.
            group_id: Number of the duplicate group; currently not used by
                this method.
            is_unique: True for images that have no duplicates.

        Returns:
            The assembled ``QWidget``.
        """
        item_widget = QWidget()
        item_layout = QHBoxLayout()

        # Thumbnail. The file is opened once: the original dimensions are
        # read before scaling, and the handle is closed as soon as the pixel
        # data has been copied into the converted image.
        original_size = None
        try:
            with Image.open(image_path) as source:
                original_size = (source.width, source.height)
                thumb = source.convert("RGBA")
            thumb.thumbnail((120, 120), Image.LANCZOS)

            # QImage does not copy the buffer it is given, so keep the bytes
            # object referenced until the pixmap has been built from it.
            pixel_data = thumb.tobytes()
            qimage = QImage(
                pixel_data,
                thumb.width,
                thumb.height,
                thumb.width * 4,  # bytes per scanline: 4 channels, 1 byte each
                QImage.Format_RGBA8888
            )

            pixmap = QPixmap.fromImage(qimage)

            thumbnail_label = QLabel()
            thumbnail_label.setPixmap(pixmap)
            thumbnail_label.setStyleSheet("border: 1px solid #ddd; padding: 5px;")
            thumbnail_label.setCursor(Qt.PointingHandCursor)
            # Clicking the thumbnail opens the large viewer
            thumbnail_label.mousePressEvent = lambda e: self.show_image(image_path)
        except Exception as e:
            print(f"创建缩略图时出错: {str(e)}")
            thumbnail_label = QLabel("无法加载图片")
            thumbnail_label.setStyleSheet("color: red;")

        item_layout.addWidget(thumbnail_label)

        # Image information
        info_layout = QVBoxLayout()

        # File name and directory
        filename = os.path.basename(image_path)
        path_label = QLabel(f"文件名: {filename}")
        path_label.setStyleSheet("font-weight: bold;")

        dir_label = QLabel(f"路径: {os.path.dirname(image_path)}")
        dir_label.setStyleSheet("font-size: 11px; color: #555;")

        # Pixel dimensions and file size; fall back to a placeholder text
        # when the file could not be opened or its size cannot be read.
        size_text = "尺寸信息不可用"
        if original_size is not None:
            try:
                size_mb = os.path.getsize(image_path) / (1024 * 1024)
                size_text = (
                    f"尺寸: {original_size[0]}×{original_size[1]} | 大小: {size_mb:.2f} MB"
                )
            except OSError:
                size_text = "尺寸信息不可用"
        size_label = QLabel(size_text)
        size_label.setStyleSheet("font-size: 12px;")

        info_layout.addWidget(path_label)
        info_layout.addWidget(dir_label)
        info_layout.addWidget(size_label)

        # Spacer plus prompt for images that belong to a duplicate group
        if not is_unique:
            info_layout.addWidget(QLabel())
            info_layout.addWidget(QLabel("请选择要保留的图片:"))

        item_layout.addLayout(info_layout, 1)

        if is_unique:
            # Unique images get a status text instead of a checkbox
            status_label = QLabel("唯一图片,将自动保留")
            status_label.setStyleSheet("color: #27ae60; font-weight: bold;")
            item_layout.addWidget(status_label)
        else:
            checkbox = QCheckBox("保留此图片")

            # Set the initial state before connecting the signal so that
            # building the row does not trigger toggle_selection.
            checkbox.setChecked(image_path in self.selected_images)

            # Bind the path as a default argument so each checkbox reports
            # its own image.
            checkbox.stateChanged.connect(lambda state, path=image_path: self.toggle_selection(path, state))

            item_layout.addWidget(checkbox)

        item_widget.setLayout(item_layout)
        item_widget.setStyleSheet("""
            QWidget {
                border-bottom: 1px solid #eee;
                padding: 10px;
            }
            QWidget:hover {
                background-color: #f9f9f9;
            }
        """)

        return item_widget
    
    def toggle_selection(self, image_path, state):
        """React to a "keep" checkbox change for ``image_path``.

        Within a duplicate group exactly one image stays selected:

        * checking an image deselects the other members of its group;
        * unchecking the only selected member of a group is refused, and the
          image is selected again.

        The result list is rebuilt whenever the displayed check states have
        to change. Paths that belong to no group are simply added to or
        removed from the selection.

        Args:
            image_path: Path of the image whose checkbox changed.
            state: New check state as delivered by ``stateChanged``.
        """
        # First duplicate group that contains the image, if any
        group = next(
            (members for members in self.duplicate_groups.values() if image_path in members),
            None,
        )

        if state == Qt.Checked:
            self.selected_images.add(image_path)
            if group is not None:
                # Single-choice behaviour: drop the rest of the group
                self.selected_images.difference_update(
                    other for other in group if other != image_path
                )
                self.display_results()
            return

        if group is not None:
            # The image being unchecked is still in the selection at this
            # point, so it must be excluded from the check.
            others_selected = any(
                other in self.selected_images for other in group if other != image_path
            )
            if not others_selected:
                # Nothing else in the group is kept: keep this image selected
                # and redraw so the checkbox shows as checked again.
                self.selected_images.add(image_path)
                self.display_results()
                return

        self.selected_images.discard(image_path)
    
    def show_image(self, image_path):
        """Open ``image_path`` in a modal ``ImageViewerDialog``; report failures in a message box."""
        try:
            ImageViewerDialog(image_path, self).exec_()
        except Exception as e:
            message = f"无法显示图片: {str(e)}"
            QMessageBox.critical(self, "错误", message)
    
    def export_results(self):
        """Copy every selected image into the output folder.

        Files are copied with ``shutil.copy2``. When a file of the same name
        already exists in the output folder, ``_1``, ``_2``, ... is appended
        to the name until it is free. Progress is shown in the progress bar
        and status bar; any error aborts the export and is reported in a
        message box.
        """
        if not self.output_dir:
            QMessageBox.warning(self, "警告", "请先选择输出文件夹")
            return

        if not self.selected_images:
            QMessageBox.warning(self, "警告", "没有选择要保留的图片")
            return

        try:
            # Make sure the output directory exists
            os.makedirs(self.output_dir, exist_ok=True)

            # Work on a sorted snapshot: processEvents() below lets the user
            # toggle checkboxes while exporting, which would otherwise change
            # the set during iteration; sorting also makes the order in which
            # name-collision suffixes are assigned deterministic.
            pending = sorted(self.selected_images)
            total = len(pending)

            # Show the progress bar
            self.progress_bar.setVisible(True)
            self.progress_bar.setValue(0)

            for copied, img_path in enumerate(pending, 1):
                filename = os.path.basename(img_path)
                name, ext = os.path.splitext(filename)
                dest_path = os.path.join(self.output_dir, filename)

                # Resolve file name collisions
                counter = 1
                while os.path.exists(dest_path):
                    dest_path = os.path.join(self.output_dir, f"{name}_{counter}{ext}")
                    counter += 1

                shutil.copy2(img_path, dest_path)
                self.progress_bar.setValue(int((copied / total) * 100))
                self.statusBar().showMessage(f"正在导出: {filename} ({copied}/{total})")
                QApplication.processEvents()  # keep the UI responsive

            # Completion message
            QMessageBox.information(
                self,
                "导出完成",
                f"成功导出 {total} 张图片到:\n{self.output_dir}"
            )
            self.statusBar().showMessage(f"导出完成: {total} 张图片")

        except Exception as e:
            QMessageBox.critical(
                self,
                "导出错误",
                f"导出过程中发生错误:\n{str(e)}"
            )
            self.statusBar().showMessage("导出失败")
        finally:
            self.progress_bar.setVisible(False)

if __name__ == "__main__":
    # Application-wide stylesheet
    global_stylesheet = """
        QMainWindow, QDialog {
            background-color: #f5f5f5;
        }
        QGroupBox {
            font-weight: bold;
            border: 1px solid #ddd;
            border-radius: 5px;
            margin-top: 10px;
            padding-top: 15px;
        }
        QGroupBox::title {
            subcontrol-origin: margin;
            subcontrol-position: top center;
            padding: 0 5px;
        }
        QScrollArea {
            border: none;
        }
        QLabel {
            padding: 2px;
        }
    """

    app = QApplication(sys.argv)
    app.setStyle("Fusion")
    app.setStyleSheet(global_stylesheet)

    # Create and show the main window, then hand control to the event loop
    main_window = ImageDedupTool()
    main_window.show()
    exit_code = app.exec_()
    sys.exit(exit_code)

清洗好的数据还需要进行分组,按一定格式放入对应的文件夹。

训练/测试

训练用代码如下。

(感谢deepseek)

import torch
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
from PIL import Image

import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm  # 用于显示进度条 (可选)


def main():
    """Fine-tune the classifier head of a pretrained ResNet50 and evaluate it.

    Steps:
        1. Load ``train``/``val``/``test`` image folders below
           ``Train_test/images``.
        2. Replace the final fully connected layer of an ImageNet-pretrained
           ResNet50 and train only that layer.
        3. After every epoch, evaluate on the validation set; whenever the
           accuracy is at least as good as the best so far, save the weights
           to ``insect_resnet50_weights.pth`` and keep a copy in memory.
        4. Restore the best weights, report overall and per-class accuracy on
           the test set, and print the top predictions for one test image.
    """
    import copy  # local import: only needed to snapshot the best weights

    # Data locations
    data_dir = 'Train_test/images'  # root directory of the extracted data set
    train_dir = data_dir + '/train'
    val_dir = data_dir + '/val'
    test_dir = data_dir + '/test'

    # Training transforms (light augmentation)
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224),       # random crop, resized to 224x224
        transforms.RandomHorizontalFlip(),       # random horizontal flip
        transforms.ToTensor(),                   # convert to tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalisation
    ])

    val_test_transforms = transforms.Compose([
        transforms.Resize(256),                 # resize the shorter side to 256, keeping the aspect ratio
        transforms.CenterCrop(224),             # centre crop to 224x224
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Data sets
    train_dataset = datasets.ImageFolder(train_dir, transform=train_transforms)
    val_dataset = datasets.ImageFolder(val_dir, transform=val_test_transforms)
    test_dataset = datasets.ImageFolder(test_dir, transform=val_test_transforms)

    # Data loaders
    batch_size = 32  # adjust to the available GPU memory
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Class information taken from the training folder names
    class_names = train_dataset.classes
    print(f'训练集类别: {class_names}')
    print(f'类别数量: {len(class_names)}')

    # Pretrained ResNet50
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

    # Freeze all existing parameters; only the new head is trained
    for param in model.parameters():
        param.requires_grad = False

    # Replace the final fully connected layer to match the number of classes
    num_ftrs = model.fc.in_features
    num_classes = len(class_names)
    model.fc = torch.nn.Linear(num_ftrs, num_classes)

    # Move the model to the GPU when one is available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    print(f'模型已移至: {device}')

    # Loss and optimizer (the optimizer only sees the new fc layer)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

    # Number of training epochs
    num_epochs = 10

    # Training history
    train_loss_history = []
    val_loss_history = []
    val_acc_history = []

    # state_dict() returns references to the live tensors, so a snapshot has
    # to be deep-copied; otherwise it silently follows later training steps.
    # Starting from the initial weights keeps the name defined even when no
    # epoch runs.
    best_val_acc = 0.0
    best_model_state = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        # ---------- training phase ----------
        model.train()
        running_loss = 0.0

        for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]'):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # Forward pass and loss
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and parameter update
            loss.backward()
            optimizer.step()

            # loss is a batch mean, so weight it by the batch size
            running_loss += loss.item() * inputs.size(0)

        epoch_train_loss = running_loss / len(train_dataset)

        # ---------- validation phase ----------
        model.eval()
        val_running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_running_loss += loss.item() * inputs.size(0)

                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        epoch_val_loss = val_running_loss / len(val_dataset)
        epoch_val_acc = correct / total

        # Record history
        train_loss_history.append(epoch_train_loss)
        val_loss_history.append(epoch_val_loss)
        val_acc_history.append(epoch_val_acc)

        # Per-epoch report
        print(f'Epoch {epoch+1}/{num_epochs} | '
            f'Train Loss: {epoch_train_loss:.4f} | '
            f'Val Loss: {epoch_val_loss:.4f} | '
            f'Val Acc: {epoch_val_acc:.4f}')

        # Keep the weights with the best validation accuracy so far
        # (ties replace the earlier snapshot).
        if epoch_val_acc >= best_val_acc:
            best_val_acc = epoch_val_acc
            best_model_state = copy.deepcopy(model.state_dict())
            torch.save(best_model_state, 'insect_resnet50_weights.pth')
            print(f'** 保存最佳模型 (准确率: {epoch_val_acc:.4f}) **')

    # Training finished: restore the best snapshot
    model.load_state_dict(best_model_state)
    print('训练完成!加载了验证集上表现最好的模型。')

    # ---------- test phase ----------
    model.eval()
    test_correct = 0
    test_total = 0
    class_correct = [0] * num_classes
    class_total = [0] * num_classes

    confusion_matrix = torch.zeros(num_classes, num_classes)  # rows: true label, columns: predicted label

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)

            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

            # Per-class counts and confusion matrix. Convert the batch to
            # plain Python ints once instead of indexing with device tensors
            # element by element.
            for label, pred in zip(labels.cpu().tolist(), predicted.cpu().tolist()):
                class_correct[label] += int(pred == label)
                class_total[label] += 1
                confusion_matrix[label, pred] += 1

    # Overall test accuracy
    test_acc = test_correct / test_total
    print(f'测试集准确率: {test_acc:.4f}')

    # Per-class accuracy
    print('\n每个类别的准确率:')
    for i in range(num_classes):
        if class_total[i] > 0:
            acc = class_correct[i] / class_total[i]
            print(f'  {class_names[i]}: {acc:.4f} ({class_correct[i]}/{class_total[i]})')
        else:
            print(f'  {class_names[i]}: 无样本')

    # Optional: confusion matrix heat map (needs matplotlib and seaborn).
    # The snippet is kept as an inert string; paste it as code to enable it.
    """ import seaborn as sns
    plt.figure(figsize=(10, 8))
    sns.heatmap(confusion_matrix.numpy(), annot=True, fmt='g', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('预测标签')
    plt.ylabel('真实标签')
    plt.title('混淆矩阵')
    plt.tight_layout()
    plt.show() """

    def predict_image(image_path, model, class_names, transform):
        """Print the top predicted classes (at most three) for one image.

        Args:
            image_path: Path of the image file.
            model: Classifier to use; it is switched to eval mode.
            class_names: Class names indexed by the model's output index.
            transform: Preprocessing applied to the PIL image before inference.
        """
        model.eval()
        image = Image.open(image_path).convert('RGB')  # force three channels
        image_tensor = transform(image).unsqueeze(0)  # add the batch dimension
        image_tensor = image_tensor.to(device)

        with torch.no_grad():
            output = model(image_tensor)
            probabilities = torch.nn.functional.softmax(output[0], dim=0)
            # topk raises when k exceeds the number of classes
            top_k = min(3, probabilities.size(0))
            top_prob, top_class_idx = torch.topk(probabilities, k=top_k)

        # Show the image and print the ranking
        print(f'预测图像: {image_path}')
        plt.imshow(image)
        plt.axis('off')
        plt.show()
        for i in range(top_prob.size(0)):
            class_idx = top_class_idx[i].item()
            print(f'  Top-{i+1}: {class_names[class_idx]} - 概率: {top_prob[i].item():.4f}')

    # Example: predict the first image of the test set
    sample_image_path = test_dataset.samples[0][0]
    predict_image(sample_image_path, model, class_names, val_test_transforms)


if __name__ == '__main__':
    # This call may be needed on Windows
    torch.multiprocessing.freeze_support()
    main()

其他

日常生活中人们拍摄的照片大多不符合模型输入的要求:比如一张图片中昆虫可能只占据一小部分画面,识别模型的裁切方式可能导致部分信息丢失,造成识别错误;又比如一张图片中可能不止一只昆虫,等等。因此,除了识别模型之外,我们还需要一个负责检测并框选昆虫的系统,这部分可以用 YOLO 或 Faster R-CNN 实现。

上一篇
下一篇