commit 0f10187716c2a6cd6cd4a8228da121adc09db40e Author: Wang_Run_Ze Date: Fri Jun 27 10:47:10 2025 +0800 上传文件至 / diff --git a/README.md b/README.md new file mode 100644 index 0000000..481eb29 --- /dev/null +++ b/README.md @@ -0,0 +1,287 @@ + + +## 📋 系统简介 + +本系统是一个基于YOLO深度学习模型的道路损伤检测系统,支持对视频文件进行智能分析,自动识别和标注道路损伤区域。系统提供了命令行和Web界面两种使用方式,适合不同用户需求。 + +### 主要功能 +- 📹 支持多种视频格式批量处理 +- 🌐 友好的Web界面操作 +- 📊 详细的检测报告和统计分析 +- 🔧 灵活的模型管理和参数配置 +- 💾 完整的检测结果保存和导出 + +## 🚀 快速开始 + +### 第一步:环境准备 + +1. **安装Python环境** + - 下载并安装 Python 3.8 或更高版本 + - 下载地址:https://www.python.org/downloads/ + - 安装时勾选 "Add Python to PATH" + +2. **安装系统依赖** + ```bash + # 方法一:使用安装脚本(推荐) + 双击运行 install.bat + + # 方法二:手动安装 + pip install -r requirements.txt + ``` + +### 第二步:准备文件 + +1. **准备视频文件** + - 将需要检测的视频文件放入 `input_videos` 文件夹 + - 支持格式:`.mp4`, `.avi`, `.mov`, `.mkv`, `.wmv`, `.flv`, `.webm` + - 可以创建子文件夹组织视频文件 + +2. **准备检测模型** + - 系统已预置 `best.pt` 和 `yolov8x.pt` 模型 + - 可通过Web界面上传自定义模型(.pt格式) + +### 第三步:开始检测 + +#### 方式一:Web界面(推荐) +```bash +# 启动Web应用 +双击运行 start_web_app.bat +# 或者 +python app.py +``` +然后在浏览器中访问:http://localhost:5000 + +#### 方式二:命令行批量检测 +```bash +# 使用启动脚本 +双击运行 start_detection.bat +# 或者 +python run_detection.py +``` + +## 🌐 Web界面使用指南 + +### 界面布局 +- **左侧面板**:检测结果列表,显示历史检测记录 +- **右侧面板**:文件处理和配置功能 + +### 功能详解 + +#### 1. 模型管理 +- **上传新模型**:点击上传区域选择.pt格式的YOLO模型文件 +- **查看模型列表**:显示所有可用模型及其信息 +- **删除模型**:移除不需要的模型文件 + +#### 2. 视频文件管理 +- **添加视频**:直接将视频文件放入 `input_videos` 文件夹 +- **查看文件列表**:显示当前待检测的视频文件 +- **删除文件**:移除单个视频文件 +- **清空文件夹**:一键清空所有视频文件 + +#### 3. 检测配置 +- **选择模型**:从下拉列表中选择检测模型 +- **置信度阈值**:调整检测敏感度(0.1-0.9) + - 低值:检测更多目标,可能包含误检 + - 高值:只检测高置信度目标,可能遗漏部分目标 + +#### 4. 开始检测 +- 点击"开始检测"按钮启动检测过程 +- 实时显示检测进度和状态 +- 显示当前处理的视频和帧数信息 + +#### 5. 
结果查看 +- **检测报告**:查看详细的检测统计和分析 +- **视频结果**:浏览每个视频的检测结果 +- **导出功能**:下载检测报告和结果数据 + +## 💻 命令行使用指南 + +### 基本命令 + +```bash +# 批量检测所有视频 +python batch_detector.py + +# 交互式检测 +python run_detection.py + +# 单个视频检测 +python detector.py --input video.mp4 --model best.pt +``` + +### 参数说明 + +- `--input`: 输入视频文件路径 +- `--model`: 检测模型路径 +- `--conf`: 置信度阈值(默认0.5) +- `--output`: 输出目录路径 + +## 📁 文件结构说明 + +``` +检测系统/ +├── app.py # Web应用主程序 +├── batch_detector.py # 批量检测器 +├── config.py # 配置文件 +├── detector.py # 核心检测器 +├── run_detection.py # 交互式检测脚本 +├── requirements.txt # Python依赖包列表 +├── install.bat # 环境安装脚本 +├── start_web_app.bat # Web应用启动脚本 +├── start_detection.bat # 批量检测启动脚本 +├── input_videos/ # 输入视频文件夹 +├── models/ # 检测模型文件夹 +│ ├── best.pt # 预训练模型 +│ ├── best_info.json # 模型信息 +│ ├── yolov8x.pt # YOLO模型 +│ └── yolov8x_info.json # 模型信息 +├── output_frames/ # 检测结果输出文件夹 +├── logs/ # 系统日志文件夹 +├── static/ # Web界面静态资源 +│ └── css/ +│ └── style.css # 界面样式文件 +└── templates/ # Web界面模板 + ├── index.html # 主页面 + ├── layout.html # 布局模板 + ├── report.html # 报告页面 + └── video_results.html # 视频结果页面 +``` + +## 🔧 配置说明 + +### 系统配置(config.py) + +- `INPUT_FOLDER`: 输入视频文件夹路径 +- `OUTPUT_FOLDER`: 输出结果文件夹路径 +- `MODELS_FOLDER`: 模型文件夹路径 +- `CONFIDENCE_THRESHOLD`: 默认置信度阈值 +- `SUPPORTED_FORMATS`: 支持的视频格式 + +### 检测参数调优 + +1. **置信度阈值** + - 道路损伤检测推荐:0.3-0.6 + - 精确检测推荐:0.6-0.8 + - 宽松检测推荐:0.2-0.4 + +2. **模型选择** + - `best.pt`: 专门训练的道路损伤检测模型 + - `yolov8x.pt`: 通用目标检测模型 + +## 📊 检测结果说明 + +### 输出文件结构 + +``` +output_frames/ +└── [检测时间戳]/ + ├── report.html # 检测报告 + ├── summary.json # 检测摘要 + └── [视频名称]/ + ├── original/ # 原始帧 + ├── annotated/ # 标注帧 + └── detections.json # 检测数据 +``` + +### 检测数据格式 + +```json +{ + "frame_number": 123, + "timestamp": "00:02:05", + "detections": [ + { + "class": "pothole", + "confidence": 0.85, + "bbox": [x, y, width, height] + } + ] +} +``` + +## 🚨 故障排除 + +### 常见问题 + +1. **Web界面无法访问** + - 检查防火墙设置 + - 确认端口5000未被占用 + - 重启Web应用 + +2. **检测速度慢** + - 检查是否有GPU可用 + - 降低视频分辨率 + - 调整检测参数 + +3. 
**内存不足** + - 减少批量处理的视频数量 + - 降低视频分辨率 + - 关闭其他应用程序 + +4. **模型加载失败** + - 检查模型文件完整性 + - 确认模型格式正确(.pt) + - 重新下载模型文件 + +### 日志查看 + +```bash +# 查看系统日志 +type logs\app.log + +# 查看检测日志 +type logs\detection.log +``` + +## 🔄 系统更新 + +### 更新依赖包 + +```bash +pip install -r requirements.txt --upgrade +``` + +### 更新模型 + +1. 下载新的模型文件 +2. 放入 `models` 文件夹 +3. 通过Web界面上传或直接复制 + +## 📞 技术支持 + +### 系统要求 + +- **操作系统**: Windows 10/11, macOS 10.14+, Ubuntu 18.04+ +- **Python版本**: 3.8 或更高 +- **内存**: 建议8GB以上 +- **存储**: 建议10GB可用空间 +- **GPU**: 可选,NVIDIA GPU可加速检测 + +### 性能优化建议 + +1. **硬件优化** + - 使用SSD存储提高I/O性能 + - 配置NVIDIA GPU加速 + - 增加系统内存 + +2. **软件优化** + - 定期清理输出文件夹 + - 优化视频文件大小 + - 调整检测参数 + +## 📝 更新日志 + +### v2.0.0 +- 新增Web界面 +- 优化检测算法 +- 改进用户体验 +- 增加批量处理功能 + +### v1.0.0 +- 初始版本发布 +- 基础检测功能 +- 命令行界面 + +--- + +**注意**: 本系统仅供学习和研究使用,检测结果仅供参考,实际应用请结合专业判断。 \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000..32ec99d --- /dev/null +++ b/app.py @@ -0,0 +1,734 @@ +from flask import Flask, render_template, request, redirect, url_for, flash, jsonify, send_from_directory +import os +import shutil +import time +import threading +import json +from datetime import datetime +from pathlib import Path +from flask import Flask, render_template, request, redirect, url_for, flash, jsonify +from werkzeug.utils import secure_filename +from config import config +from batch_detector import BatchVideoDetector + +app = Flask(__name__) +app.secret_key = 'road_damage_detection_system' +app.config['MAX_CONTENT_LENGTH'] = 500 * 1024 * 1024 # 500MB max upload size + +# 确保必要的文件夹存在 +for folder in [config.INPUT_FOLDER, config.OUTPUT_FOLDER, config.LOG_FOLDER, 'models', 'static/css']: + os.makedirs(folder, exist_ok=True) + +# 全局变量用于跟踪检测进度和状态 +detection_progress = 0 +detection_status = "idle" # idle, running, completed, error +detection_thread = None +current_detection_info = {} +detection_start_time = None +detection_results = None +current_report_path = None +current_video_name = "" 
# ---------------------------------------------------------------------------
# Remaining progress state shared by the Flask views and the worker thread.
# ---------------------------------------------------------------------------
total_videos = 0
current_video_index = 0
current_frame_count = 0
total_frame_count = 0
# Human-readable progress text. It is kept separate from ``detection_status``
# so that the status can stay a small state machine
# ("idle" / "running" / "completed" / "error") that the views can compare on.
detection_message = ""

# Allowed upload extensions: lower-case, without the leading dot.
ALLOWED_VIDEO_EXTENSIONS = {ext.lstrip('.').lower() for ext in config.SUPPORTED_VIDEO_FORMATS}
ALLOWED_MODEL_EXTENSIONS = {'pt', 'pth', 'weights'}

MODELS_FOLDER = 'models'
DEFAULT_CONFIDENCE = 0.25


def allowed_file(filename, allowed_extensions):
    """Return True when *filename* has an extension listed in *allowed_extensions*.

    The comparison is case-insensitive; *allowed_extensions* must contain
    lower-case extensions without the leading dot.
    """
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed_extensions


def _resolve_inside(base_dir, name):
    """Return the real path of ``base_dir/name`` if it stays inside *base_dir*.

    Returns None when *name* escapes the folder (``..``, absolute paths,
    another drive) or resolves to the folder itself. Used for every file
    name that arrives through a URL or a form field.
    """
    base = os.path.realpath(base_dir)
    target = os.path.realpath(os.path.join(base, name))
    try:
        inside = os.path.commonpath([base, target]) == base
    except ValueError:  # e.g. paths on different Windows drives
        return None
    if not inside or target == base:
        return None
    return target


def _safe_upload_name(original_name, allowed_extensions, prefix):
    """Return a file-system safe name for an upload that keeps a valid extension.

    ``secure_filename`` drops non-ASCII characters, so a name such as
    ``视频.mp4`` collapses to ``mp4`` and loses its extension. In that case a
    timestamped name ``<prefix>_<timestamp>.<ext>`` is generated instead.
    The caller must already have checked *original_name* with ``allowed_file``.
    """
    cleaned = secure_filename(original_name)
    if allowed_file(cleaned, allowed_extensions):
        return cleaned
    extension = original_name.rsplit('.', 1)[1].lower()
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
    return f"{prefix}_{stamp}.{extension}"


def _parse_confidence(raw_value, default=DEFAULT_CONFIDENCE):
    """Convert a form value to a confidence in [0.0, 1.0].

    Missing, malformed or NaN values fall back to *default* instead of
    raising, which previously turned a bad form field into an HTTP 500.
    """
    try:
        value = float(raw_value)
    except (TypeError, ValueError):
        return default
    if value != value:  # NaN is the only value that differs from itself
        return default
    return min(max(value, 0.0), 1.0)


def _list_input_videos():
    """Return the names of the supported video files in the input folder."""
    return [name for name in os.listdir(config.INPUT_FOLDER)
            if allowed_file(name, ALLOWED_VIDEO_EXTENSIONS)]


def _list_models():
    """Return the names of the model files in the models folder."""
    if not os.path.isdir(MODELS_FOLDER):
        return []
    return [name for name in os.listdir(MODELS_FOLDER)
            if allowed_file(name, ALLOWED_MODEL_EXTENSIONS)]


def _load_report_summaries():
    """Return one summary dict per batch report in the output folder, newest first."""
    summaries = []
    if not os.path.exists(config.OUTPUT_FOLDER):
        return summaries

    for report_file in os.listdir(config.OUTPUT_FOLDER):
        if not (report_file.startswith('batch_detection_report_') and report_file.endswith('.json')):
            continue
        report_path = os.path.join(config.OUTPUT_FOLDER, report_file)
        try:
            with open(report_path, 'r', encoding='utf-8') as f:
                report_data = json.load(f)
            statistics = report_data.get('statistics', {})
            summaries.append({
                'id': report_file,
                'timestamp': report_data.get('timestamp', '未知'),
                'video_count': statistics.get('processed_videos', 0),
                'total_frames': statistics.get('total_frames', 0),
                'detected_frames': statistics.get('detected_frames', 0)
            })
        except Exception as e:
            # A single unreadable report must not break the whole page.
            print(f"Error loading report {report_path}: {e}")

    summaries.sort(key=lambda item: item['timestamp'], reverse=True)
    return summaries


def _launch_detection():
    """Validate the submitted form and start the background detection thread.

    Shared by the ``index`` POST handler and ``start_detection`` so both
    entry points validate and reset state in the same way. Problems are
    reported to the user through ``flash``.

    Returns:
        bool: True when a detection thread was started.
    """
    global detection_thread, detection_status, detection_progress, detection_results
    global current_report_path, detection_start_time, detection_message
    global current_video_name, total_videos, current_video_index
    global current_frame_count, total_frame_count

    if detection_status == "running":
        flash('已有检测任务正在运行', 'warning')
        return False

    model_name = request.form.get('model_select')
    confidence_threshold = _parse_confidence(request.form.get('confidence_threshold'))

    videos = _list_input_videos()
    if not videos:
        flash('输入文件夹中没有视频文件,请先上传视频', 'danger')
        return False

    if not model_name:
        flash('请选择一个检测模型', 'danger')
        return False

    model_path = os.path.join(MODELS_FOLDER, model_name)
    if _resolve_inside(MODELS_FOLDER, model_name) is None or not os.path.exists(model_path):
        flash(f'模型文件 {model_name} 不存在', 'danger')
        return False

    # Publish the chosen settings through the shared config object.
    config.MODEL_PATH = model_path
    config.CONFIDENCE_THRESHOLD = confidence_threshold

    # Reset every piece of progress state before the worker starts.
    detection_progress = 0
    detection_status = "running"
    detection_message = ""
    detection_results = None
    current_report_path = None
    detection_start_time = datetime.now()
    current_video_name = ""
    total_videos = len(videos)
    current_video_index = 0
    current_frame_count = 0
    total_frame_count = 0

    detection_thread = threading.Thread(target=run_detection)
    detection_thread.daemon = True
    detection_thread.start()

    flash('检测已开始,请等待结果', 'info')
    return True


# Home page
@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the dashboard; a POST additionally tries to start a detection run."""
    if request.method == 'POST':
        _launch_detection()

    return render_template('index.html',
                           models=_list_models(),
                           videos=_list_input_videos(),
                           reports=_load_report_summaries(),
                           detection_status=detection_status,
                           detection_progress=detection_progress,
                           detection_in_progress=(detection_status == "running"),
                           confidence_threshold=config.CONFIDENCE_THRESHOLD,
                           config=config)


# Upload a model
@app.route('/upload_model', methods=['POST'])
def upload_model():
    """Store an uploaded model file and cache its class list next to it.

    The model is loaded once with ultralytics to read ``model.names``; the
    result is written to ``<model stem>_info.json``. A failed analysis only
    produces a warning, the uploaded file is kept.
    """
    if 'model_file' not in request.files:
        flash('没有选择文件', 'danger')
        return redirect(url_for('index'))

    file = request.files['model_file']
    if file.filename == '':
        flash('没有选择文件', 'danger')
        return redirect(url_for('index'))

    if not (file and allowed_file(file.filename, ALLOWED_MODEL_EXTENSIONS)):
        flash('不支持的文件类型', 'danger')
        return redirect(url_for('index'))

    filename = _safe_upload_name(file.filename, ALLOWED_MODEL_EXTENSIONS, 'model')
    model_path = os.path.join(MODELS_FOLDER, filename)
    file.save(model_path)

    try:
        print(f"正在分析上传的模型: {filename}")
        # Imported lazily: loading ultralytics is slow and only needed here.
        from ultralytics import YOLO
        model = YOLO(model_path)
        has_names = hasattr(model, 'names')

        model_info = {
            'filename': filename,
            'upload_time': datetime.now().isoformat(),
            'classes': dict(model.names) if has_names else {},
            'class_count': len(model.names) if has_names else 0,
            'model_type': 'YOLO',
            'analyzed': True
        }

        info_path = os.path.join(MODELS_FOLDER, f"{os.path.splitext(filename)[0]}_info.json")
        with open(info_path, 'w', encoding='utf-8') as f:
            json.dump(model_info, f, ensure_ascii=False, indent=2)

        if has_names:
            scene_types = [str(name) for name in model.names.values()]
            flash(f'模型 {filename} 上传成功!检测到 {len(scene_types)} 种场景类型: {", ".join(scene_types)}', 'success')
        else:
            flash(f'模型 {filename} 上传成功!', 'success')

    except Exception as e:
        print(f"模型分析失败: {e}")
        flash(f'模型 {filename} 上传成功,但分析模型信息时出错: {str(e)}', 'warning')

    return redirect(url_for('index'))


# List models
@app.route('/get_models', methods=['GET'])
def get_models():
    """Return the available models as JSON, enriched with cached class info when present."""
    models = []

    for model_file in _list_models():
        model_info = {
            'filename': model_file,
            'analyzed': False,
            'classes': {},
            'class_count': 0,
            'scene_types': []
        }

        info_path = os.path.join(MODELS_FOLDER, f"{os.path.splitext(model_file)[0]}_info.json")
        if os.path.exists(info_path):
            try:
                with open(info_path, 'r', encoding='utf-8') as f:
                    cached_info = json.load(f)
                model_info.update(cached_info)
                model_info['scene_types'] = list(cached_info.get('classes', {}).values())
            except Exception as e:
                print(f"读取模型信息文件失败: {e}")

        models.append(model_info)

    return jsonify(models)


# Delete a model
@app.route('/delete_model/<filename>', methods=['POST'])
def delete_model(filename):
    """Delete a model file and its cached ``*_info.json`` companion."""
    try:
        model_path = _resolve_inside(MODELS_FOLDER, filename)
        deleted_files = []

        if model_path is not None:
            info_path = os.path.join(
                os.path.dirname(model_path),
                f"{os.path.splitext(os.path.basename(model_path))[0]}_info.json")

            if os.path.isfile(model_path):
                os.remove(model_path)
                deleted_files.append('模型文件')

            if os.path.isfile(info_path):
                os.remove(info_path)
                deleted_files.append('信息文件')

        if deleted_files:
            flash(f'模型 {filename} 及其{"、".join(deleted_files)}已删除', 'success')
        else:
            flash(f'模型 {filename} 不存在', 'warning')
    except Exception as e:
        flash(f'删除模型时出错: {str(e)}', 'danger')

    return redirect(url_for('index'))


# Upload videos
@app.route('/upload_video', methods=['POST'])
def upload_video():
    """Save every uploaded video with a supported extension into the input folder."""
    if 'video_files[]' not in request.files:
        flash('没有选择文件', 'danger')
        return redirect(url_for('index'))

    files = request.files.getlist('video_files[]')
    if not files or files[0].filename == '':
        flash('没有选择文件', 'danger')
        return redirect(url_for('index'))

    success_count = 0
    error_count = 0

    for file in files:
        if file and allowed_file(file.filename, ALLOWED_VIDEO_EXTENSIONS):
            filename = _safe_upload_name(file.filename, ALLOWED_VIDEO_EXTENSIONS, 'video')
            file.save(os.path.join(config.INPUT_FOLDER, filename))
            success_count += 1
        else:
            error_count += 1

    if success_count > 0:
        flash(f'成功上传 {success_count} 个视频文件', 'success')
    if error_count > 0:
        flash(f'{error_count} 个文件上传失败(不支持的文件类型)', 'warning')

    return redirect(url_for('index'))


# List videos
@app.route('/get_videos', methods=['GET'])
def get_videos():
    """Return the names of the videos waiting in the input folder as JSON."""
    return jsonify(_list_input_videos())


# Delete a video
@app.route('/delete_video/<filename>', methods=['POST'])
def delete_video(filename):
    """Delete a single video from the input folder."""
    try:
        video_path = _resolve_inside(config.INPUT_FOLDER, filename)
        if video_path is not None and os.path.isfile(video_path):
            os.remove(video_path)
            flash(f'视频 {filename} 已删除', 'success')
        else:
            flash(f'视频 {filename} 不存在', 'warning')
    except Exception as e:
        flash(f'删除视频时出错: {str(e)}', 'danger')

    return redirect(url_for('index'))


# Empty the input folder
@app.route('/clear_input_folder', methods=['POST'])
def clear_input_folder():
    """Remove every regular file directly inside the input folder (sub-folders are kept)."""
    try:
        for filename in os.listdir(config.INPUT_FOLDER):
            file_path = os.path.join(config.INPUT_FOLDER, filename)
            if os.path.isfile(file_path):
                os.unlink(file_path)
        flash('输入文件夹已清空', 'success')
    except Exception as e:
        flash(f'清空文件夹时出错: {str(e)}', 'danger')

    return redirect(url_for('index'))


# Start detection
@app.route('/start_detection', methods=['POST'])
def start_detection():
    """Start a detection run from the submitted form and return to the dashboard."""
    print("收到检测请求")
    print(f"当前检测状态: {detection_status}")
    _launch_detection()
    return redirect(url_for('index'))


# Worker thread body
def run_detection():
    """Run the batch detector and publish progress through the module globals.

    ``detection_status`` only ever takes the values "running", "completed"
    or "error" here; all descriptive text goes to ``detection_message``.
    Overwriting the status with free text made the "already running" guard
    and the ``in_progress`` flag fail after the first progress callback.
    """
    global detection_progress, detection_status, detection_results, detection_message
    global current_report_path, current_video_name, current_video_index, current_frame_count

    try:
        print(f"开始检测,模型路径: {config.MODEL_PATH}")
        print(f"置信度阈值: {config.CONFIDENCE_THRESHOLD}")
        print(f"输入文件夹: {config.INPUT_FOLDER}")
        print(f"找到视频文件: {_list_input_videos()}")

        print("正在创建检测器实例...")
        detector = BatchVideoDetector(
            model_path=config.MODEL_PATH,
            confidence=config.CONFIDENCE_THRESHOLD
        )
        print("检测器实例创建成功")

        print("\n=== 网页端模型检测分析 ===")
        if hasattr(detector.model, 'names'):
            model_classes = list(detector.model.names.values())
            print(f"当前模型可检测场景: {model_classes}")
        print(f"使用的置信度阈值: {config.CONFIDENCE_THRESHOLD}")
        print("=" * 40)

        def progress_callback(status, progress, video_name=None, video_index=None, frame_count=None):
            """Copy the detector's progress report into the module globals."""
            global detection_message, detection_progress
            global current_video_name, current_video_index, current_frame_count
            detection_message = status
            detection_progress = progress

            if video_name:
                current_video_name = video_name
            if video_index is not None:
                current_video_index = video_index
            if frame_count is not None:
                current_frame_count = frame_count

            print(f"进度更新: {status} - {progress}%")
            if video_name:
                print(f"当前视频: {video_name} ({video_index}/{total_videos})")
            if frame_count is not None:
                print(f"已处理帧数: {frame_count}")

        detector.set_progress_callback(progress_callback)

        detection_message = "开始检测视频..."
        report_path = detector.process_all_videos(config.INPUT_FOLDER, config.OUTPUT_FOLDER)

        if report_path and os.path.exists(report_path):
            with open(report_path, 'r', encoding='utf-8') as f:
                detection_results = json.load(f)
            current_report_path = report_path
            detection_message = "检测完成"
        else:
            detection_message = "检测完成,但未生成报告"
        detection_status = "completed"
    except Exception as e:
        detection_status = "error"
        detection_message = f"检测失败: {str(e)}"
        print(f"检测异常: {str(e)}")
        import traceback
        traceback.print_exc()
    finally:
        # Reached with "running" only if something other than Exception
        # unwound the thread; never leave the UI stuck in that state.
        if detection_status == "running":
            detection_status = "error"
            detection_message = "失败"
        detection_progress = 100


# Poll detection progress
@app.route('/detection_progress')
def get_detection_progress():
    """Return the current detection progress as JSON for the polling front end."""
    running = detection_status == "running"

    elapsed_time = ""
    if detection_start_time and running:
        elapsed_seconds = (datetime.now() - detection_start_time).total_seconds()
        minutes, seconds = divmod(elapsed_seconds, 60)
        elapsed_time = f"{int(minutes)}分{int(seconds)}秒"

    # Prefer the descriptive text; fall back to the bare state name.
    detailed_status = detection_message or detection_status
    if running and current_video_name:
        detailed_status = f"正在处理: {current_video_name} ({current_video_index}/{total_videos})"
        if current_frame_count > 0:
            detailed_status += f" - 已处理帧数: {current_frame_count}"

    return jsonify({
        'in_progress': running,
        'state': detection_status,
        'progress': detection_progress,
        'status': detailed_status,
        'elapsed_time': elapsed_time,
        'current_video': current_video_name,
        'current_video_index': current_video_index,
        'total_videos': total_videos,
        'current_frame_count': current_frame_count,
        'total_frame_count': total_frame_count
    })


def _percentage(part, whole):
    """Return ``part / whole`` as a percentage, or 0 when *whole* is not positive."""
    return (part / whole) * 100 if whole > 0 else 0


# Show a detection report
@app.route('/report/<report_filename>')
def view_report(report_filename):
    """Render one batch report with overall and per-video detection rates."""
    report_path = _resolve_inside(config.OUTPUT_FOLDER, report_filename)
    if report_path is None or not os.path.exists(report_path):
        flash('报告文件不存在', 'danger')
        return redirect(url_for('index'))

    try:
        with open(report_path, 'r', encoding='utf-8') as f:
            report_data = json.load(f)

        statistics = report_data.get('statistics', {})
        detection_rate = _percentage(statistics.get('detected_frames', 0),
                                     statistics.get('total_frames', 0))

        for video in report_data.get('videos', []):
            video['detection_rate'] = _percentage(video.get('detected_frames', 0),
                                                  video.get('total_frames', 0))

        return render_template('report.html',
                               report=report_data,
                               report_filename=report_filename,
                               detection_rate=detection_rate)

    except Exception as e:
        flash(f'读取报告时出错: {str(e)}', 'danger')
        return redirect(url_for('index'))


# Show the results of one video
@app.route('/video_results/<report_filename>/<video_name>')
def view_video_results(report_filename, video_name):
    """Render the detected frames and timeline of one video from a report.

    NOTE(review): frames are looked up as ``detected_<n>.jpg`` directly in
    ``OUTPUT_FOLDER/<video stem>/``; confirm this matches the layout the
    detector actually writes.
    """
    report_path = _resolve_inside(config.OUTPUT_FOLDER, report_filename)
    if report_path is None or not os.path.exists(report_path):
        flash('报告文件不存在', 'danger')
        return redirect(url_for('index'))

    try:
        with open(report_path, 'r', encoding='utf-8') as f:
            report_data = json.load(f)

        video_data = next((video for video in report_data.get('videos', [])
                           if video.get('video_name') == video_name), None)
        if not video_data:
            flash('视频结果不存在', 'danger')
            return redirect(url_for('view_report', report_filename=report_filename))

        # A missing or zero fps would make every timestamp division fail.
        fps = video_data.get('fps') or 30

        video_folder_name = os.path.splitext(video_name)[0]
        video_output_dir = _resolve_inside(config.OUTPUT_FOLDER, video_folder_name)
        detected_frames = []

        if video_output_dir is not None and os.path.isdir(video_output_dir):
            for frame_file in os.listdir(video_output_dir):
                if not (frame_file.startswith('detected_') and frame_file.endswith('.jpg')):
                    continue
                try:
                    frame_number = int(frame_file.split('_')[1].split('.')[0])

                    # The detections of a frame live in a JSON file of the same name.
                    json_path = os.path.join(video_output_dir, frame_file.replace('.jpg', '.json'))
                    detections = []
                    if os.path.exists(json_path):
                        with open(json_path, 'r', encoding='utf-8') as f:
                            detections = json.load(f).get('detections', [])

                    detected_frames.append({
                        'frame_number': frame_number,
                        'filename': frame_file,
                        # Always "/" so the value is a valid URL path on Windows too.
                        'path': f"{video_folder_name}/{frame_file}",
                        'timestamp': frame_number / fps,
                        'detections': detections,
                        'detection_count': len(detections)
                    })
                except Exception as e:
                    print(f"Error processing frame {frame_file}: {e}")

        detected_frames.sort(key=lambda item: item['frame_number'])

        detection_rate = _percentage(video_data.get('detected_frames', 0),
                                     video_data.get('total_frames', 0))

        timeline_data = [{
            'timestamp': frame['timestamp'],
            'frame_number': frame['frame_number'],
            'detection_count': frame['detection_count']
        } for frame in detected_frames]

        return render_template('video_results.html',
                               report_filename=report_filename,
                               report=report_data,
                               video=video_data,
                               detected_frames=detected_frames,
                               detection_rate=detection_rate,
                               timeline_data=json.dumps(timeline_data))

    except Exception as e:
        flash(f'读取视频结果时出错: {str(e)}', 'danger')
        return redirect(url_for('view_report', report_filename=report_filename))


# Serve a detected frame image
@app.route('/frame/<path:frame_path>')
def get_frame(frame_path):
    """Serve an image below the output folder.

    The whole relative path is handed to ``send_from_directory`` so that it,
    not user input, decides which directory is served.
    """
    return send_from_directory(config.OUTPUT_FOLDER, frame_path.replace('\\', '/'))


# Delete a report
@app.route('/delete_report/<report_filename>', methods=['POST'])
def delete_report(report_filename):
    """Delete a report file together with the output folder of every video it lists."""
    report_path = _resolve_inside(config.OUTPUT_FOLDER, report_filename)
    if report_path is None or not os.path.exists(report_path):
        flash('报告文件不存在', 'danger')
        return redirect(url_for('index'))

    try:
        with open(report_path, 'r', encoding='utf-8') as f:
            report_data = json.load(f)

        for video in report_data.get('videos', []):
            # Folder names come from the report content, so they are
            # contained to the output folder like any other external name.
            video_output_folder = _resolve_inside(
                config.OUTPUT_FOLDER, os.path.splitext(video['video_name'])[0])
            if video_output_folder is not None and os.path.isdir(video_output_folder):
                shutil.rmtree(video_output_folder)

        os.remove(report_path)
        flash(f'报告 {report_filename} 已删除', 'success')
    except Exception as e:
        flash(f'删除报告失败: {str(e)}', 'danger')

    return redirect(url_for('index'))


# Add progress reporting to BatchVideoDetector
def add_progress_callback_to_detector():
    """Monkey-patch ``BatchVideoDetector`` with progress reporting.

    Adds ``set_progress_callback`` / ``update_progress`` when the class does
    not define them, and replaces ``process_all_videos`` and
    ``save_final_report`` with versions that report progress and return the
    path of the written report.
    """
    if not hasattr(BatchVideoDetector, 'set_progress_callback'):
        def set_progress_callback(self, callback):
            """Register the callable that receives progress updates."""
            self.progress_callback = callback

        BatchVideoDetector.set_progress_callback = set_progress_callback

    if not hasattr(BatchVideoDetector, 'update_progress'):
        def update_progress(self, status, progress, video_name=None, video_index=None, frame_count=None):
            """Forward a progress update to the registered callback, if any."""
            callback = getattr(self, 'progress_callback', None)
            if callback:
                callback(status, progress, video_name, video_index, frame_count)

        BatchVideoDetector.update_progress = update_progress

    def process_all_videos_with_progress(self, input_folder="input_videos", output_folder="output_frames"):
        """Process every video in *input_folder*; return the report path or None."""
        self.update_progress("获取视频文件列表", 0)
        video_files = self.get_video_files(input_folder)
        video_count = len(video_files)

        if video_count == 0:
            self.update_progress("没有找到视频文件", 100)
            return None

        # Needed by the report: without them the processing time is always 0
        # and the video total stays at its initial value.
        self.stats['start_time'] = datetime.now()
        self.stats['total_videos'] = video_count

        self.update_progress(f"开始处理 {video_count} 个视频文件", 5)

        for i, video_path in enumerate(video_files, 1):
            # Progress counts *finished* videos, so the bar does not reach
            # 95 % before the last video has been processed.
            progress = int(5 + ((i - 1) / video_count) * 90)
            video_name = video_path.name
            self.update_progress(f"处理视频 {i}/{video_count}: {video_name}", progress, video_name, i)
            self.process_video(video_path, output_folder)

        self.update_progress("生成检测报告", 95)
        self.stats['end_time'] = datetime.now()
        report_path = self.save_final_report(output_folder)
        self.print_summary()
        self.update_progress("检测完成", 100)
        return report_path

    def save_final_report_with_return(self, output_folder):
        """Write the batch report as JSON and return its path as a string."""
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        report_path = Path(output_folder) / f"batch_detection_report_{timestamp}.json"

        total_seconds = 0
        start_time = self.stats.get('start_time')
        end_time = self.stats.get('end_time')
        if start_time and end_time:
            total_seconds = (end_time - start_time).total_seconds()

        # NOTE(review): presumably ``video_details`` holds one JSON-friendly
        # dict per processed video (the report views read ``video_name``,
        # ``total_frames``, ``detected_frames`` and ``fps``) - confirm against
        # ``process_video``. It is dropped if it cannot be serialised, so the
        # report is still written.
        videos = list(getattr(self, 'video_details', None) or [])
        try:
            json.dumps(videos)
        except (TypeError, ValueError):
            videos = []

        report_data = {
            'timestamp': timestamp,
            'model_path': self.model_path,
            'confidence_threshold': self.confidence,
            'frame_interval_threshold': config.MIN_FRAME_INTERVAL,
            'filter_strategy': '相同场景在指定时间间隔内只保存一次',
            'statistics': {
                'total_videos': self.stats['total_videos'],
                'processed_videos': self.stats['processed_videos'],
                'total_frames': self.stats['total_frames'],
                'detected_frames': self.stats['detected_frames'],
                'filtered_frames': self.stats['filtered_frames'],
                'saved_images': self.stats['saved_images'],
                'processing_time_seconds': total_seconds
            },
            'errors': self.stats['errors'],
            'videos': videos
        }

        with open(report_path, 'w', encoding='utf-8') as f:
            json.dump(report_data, f, ensure_ascii=False, indent=2)

        print(f"\n批处理报告已保存: {report_path}")
        return str(report_path)

    BatchVideoDetector.process_all_videos = process_all_videos_with_progress
    BatchVideoDetector.save_final_report = save_final_report_with_return


# Install the progress hooks before the application starts serving.
add_progress_callback_to_detector()

if __name__ == '__main__':
    # Make sure the working folders exist.
    for folder in [config.INPUT_FOLDER, config.OUTPUT_FOLDER, config.LOG_FOLDER, 'models', 'static/css']:
        os.makedirs(folder, exist_ok=True)

    # The Werkzeug debugger allows arbitrary code execution and must not be
    # exposed on all interfaces by default; opt in with FLASK_DEBUG=1.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(debug=debug_enabled, host='0.0.0.0', port=5000)
diff --git a/batch_detector.py b/batch_detector.py new file mode 100644 index 0000000..18fdfa1 --- /dev/null +++ b/batch_detector.py @@ -0,0 +1,544 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +批量视频检测器 +简化版本,专注于高效处理大量视频文件 +""" + +import os +import cv2 +import numpy as np +from pathlib import Path +from ultralytics import YOLO +import json +import time +from datetime import datetime +from config import config, DetectionConfig +from PIL import Image, ImageDraw, ImageFont + +class BatchVideoDetector: + def __init__(self, model_path=None, confidence=0.5): + """ + 初始化批量视频检测器 + + Args: + model_path (str): 模型路径,如果为None则使用配置文件中的路径 + confidence (float): 置信度阈值 + """ + self.model_path = model_path or config.get_model_path() + self.confidence = confidence + + # 创建必要文件夹 + config.create_folders() + + # 加载模型 + self.load_model() + + # 时间间隔过滤变量 + self.last_saved_timestamp = 0 + + # 初始化场景类别变量 + self.last_scene_classes = set() + + # 统计信息 + self.stats = { + 'start_time': None, + 'end_time': None, + 'total_videos': 0, + 'processed_videos': 0, + 'total_frames': 0, + 'detected_frames': 0, + 'saved_images': 0, + 'filtered_frames': 0, + 'errors': [] + } + + # 视频详细信息列表 + self.video_details = [] + + def load_model(self): + """加载YOLO模型""" + try: + print(f"正在加载模型: {self.model_path}") + self.model = YOLO(self.model_path) + print(f"模型加载成功!") + + # 检查是否有缓存的模型信息 + model_filename = os.path.basename(self.model_path) + info_path = os.path.join(os.path.dirname(self.model_path), f"{os.path.splitext(model_filename)[0]}_info.json") + + cached_info_loaded = False + if os.path.exists(info_path): + try: + with open(info_path, 'r', encoding='utf-8') as f: + import json + cached_info = json.load(f) + if cached_info.get('analyzed', False): + print(f"使用缓存的模型信息,跳过重复分析") + print(f"模型类别数量: {cached_info.get('class_count', 0)}") + print(f"检测场景类型: {list(cached_info.get('classes', {}).values())}") + cached_info_loaded = True + except Exception as e: + print(f"读取缓存模型信息失败: {e},将重新分析") + + # 如果没有缓存信息,则进行完整分析 
def get_video_files(self, input_folder="input_videos"):
    """Return every supported video file below *input_folder*.

    Args:
        input_folder: Directory to scan; sub-directories are searched too.

    Returns:
        A sorted list of ``Path`` objects. The list is empty when the
        folder does not exist.
    """
    input_path = Path(input_folder)
    if not input_path.exists():
        print(f"输入文件夹不存在: {input_folder}")
        return []

    # Compare lower-cased suffixes instead of globbing once per extension:
    # a glob such as "*.mp4" does not match "CLIP.MP4" on case-sensitive
    # file systems. Sorting makes the processing order reproducible.
    video_files = sorted(
        path
        for path in input_path.rglob("*")
        if path.is_file() and path.suffix.lower() in config.SUPPORTED_VIDEO_FORMATS
    )

    print(f"找到 {len(video_files)} 个视频文件")
    return video_files

def _collect_detections(self, result, filtered=False):
    """Turn the boxes of one YOLO result into JSON-serialisable dicts.

    Args:
        result: A single detection result exposing ``boxes``.
        filtered: When true, every entry is tagged with ``'filtered': True``.

    Returns:
        One dict per box with class id, both class names, confidence and
        the ``xyxy`` bounding box.
    """
    detections = []
    for box in result.boxes:
        class_id = int(box.cls[0].cpu().numpy())
        entry = {
            'class_id': class_id,
            'class_name': self.model.names.get(class_id, f"class_{class_id}"),
            'class_name_cn': config.get_class_name_cn(class_id),
            'confidence': float(box.conf[0].cpu().numpy()),
            'bbox': box.xyxy[0].cpu().numpy().tolist()
        }
        if filtered:
            entry['filtered'] = True
        detections.append(entry)
    return detections

def _save_jpeg(self, image_path, image):
    """Encode *image* as JPEG and write it to *image_path*.

    The image is encoded in memory and written through ``pathlib`` rather
    than ``cv2.imwrite``: the target folders are named after the Chinese
    class labels, and ``cv2.imwrite`` is known to return ``False`` without
    raising for non-ASCII paths on some Windows builds.

    Returns:
        ``True`` when the file was written, ``False`` when encoding failed.
    """
    encoded_ok, buffer = cv2.imencode(
        ".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, config.IMAGE_QUALITY]
    )
    if not encoded_ok:
        print(f"图片编码失败: {image_path}")
        return False
    Path(image_path).write_bytes(buffer.tobytes())
    return True

def detect_and_save_frame(self, frame, frame_number, timestamp, video_name, output_dir):
    """Detect objects in one frame and save it unless it repeats a recent scene.

    A frame is *filtered* (reported but not written to disk) when it shows
    exactly the same set of classes as the last saved frame and arrives
    less than ``config.MIN_FRAME_INTERVAL`` seconds after it.

    Args:
        frame: BGR image as read by OpenCV.
        frame_number: 1-based index of the frame inside the video.
        timestamp: Position of the frame in seconds.
        video_name: Video stem, used for folder and file names.
        output_dir: ``Path`` below which ``<video_name>/<scene>/`` is created.

    Returns:
        ``None`` when nothing was detected or an error occurred; otherwise
        a dict with ``frame_number``, ``timestamp`` and ``detections``
        (plus ``'filtered': True`` for filtered frames).
    """
    try:
        results = self.model(frame, conf=self.confidence, verbose=False)
        result = results[0] if results else None
        if result is None or result.boxes is None or len(result.boxes) == 0:
            return None

        current_scene_classes = {int(box.cls[0].cpu().numpy()) for box in result.boxes}

        # Filter only when a frame was already saved for this video, it was
        # saved recently, and it showed the same classes.
        time_diff = timestamp - self.last_saved_timestamp
        within_interval = time_diff < config.MIN_FRAME_INTERVAL and self.last_saved_timestamp > 0
        same_scene = current_scene_classes == getattr(self, 'last_scene_classes', None)
        if within_interval and same_scene:
            self.stats['filtered_frames'] += 1
            return {
                'frame_number': frame_number,
                'timestamp': timestamp,
                'detections': self._collect_detections(result, filtered=True),
                'filtered': True
            }

        self.last_scene_classes = current_scene_classes
        self.last_saved_timestamp = timestamp

        base_name = f"{video_name}_frame_{frame_number:06d}_t{timestamp:.2f}s"
        detected_scenes = {
            config.get_class_name_cn(class_id) for class_id in current_scene_classes
        }

        # The annotated image is identical for every scene folder, so it is
        # rendered once instead of once per folder.
        annotated_frame = None
        if config.SAVE_ANNOTATED_FRAMES:
            annotated_frame = self.draw_detections(frame.copy(), result)

        # A frame showing several classes is stored in each class folder.
        for scene_name in detected_scenes:
            scene_folder = output_dir / video_name / scene_name
            scene_folder.mkdir(parents=True, exist_ok=True)

            if config.SAVE_ORIGINAL_FRAMES:
                self._save_jpeg(scene_folder / f"{base_name}_original.jpg", frame)
            if annotated_frame is not None:
                self._save_jpeg(scene_folder / f"{base_name}_detected.jpg", annotated_frame)

        return {
            'frame_number': frame_number,
            'timestamp': timestamp,
            'detections': self._collect_detections(result)
        }

    except Exception as e:
        # Best effort: one bad frame must not stop the whole video.
        print(f"处理帧 {frame_number} 时出错: {e}")
        return None

def _load_label_font(self):
    """Return the font used for labels, loading it on first use.

    The common Windows CJK fonts are tried in order; Pillow's built-in
    default font is the fallback. The result is cached on the instance so
    font files are not reopened for every frame.
    """
    cached_font = getattr(self, '_label_font', None)
    if cached_font is not None:
        return cached_font

    font = None
    for font_path in (
        "C:/Windows/Fonts/simhei.ttf",  # SimHei
        "C:/Windows/Fonts/simsun.ttc",  # SimSun
        "C:/Windows/Fonts/msyh.ttc",    # Microsoft YaHei
    ):
        try:
            font = ImageFont.truetype(font_path, 20)
            break
        except (OSError, ImportError):
            continue
    if font is None:
        font = ImageFont.load_default()

    self._label_font = font
    return font

def draw_detections(self, frame, result):
    """Draw bounding boxes and Chinese class labels onto *frame*.

    Drawing goes through Pillow because the labels are Chinese text,
    rendered here with a TrueType font.

    Args:
        frame: BGR image; it is not modified.
        result: YOLO result for the frame, or ``None``.

    Returns:
        A new annotated BGR image, or *frame* itself when there is nothing
        to draw.
    """
    if result is None or result.boxes is None:
        return frame

    pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(pil_image)
    font = self._load_label_font()

    for box in result.boxes:
        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int).tolist()
        class_id = int(box.cls[0].cpu().numpy())
        confidence = float(box.conf[0].cpu().numpy())

        # Colours are configured as BGR; Pillow expects RGB.
        rgb_color = tuple(reversed(config.get_class_color(class_id)))
        draw.rectangle([x1, y1, x2, y2], outline=rgb_color, width=2)

        label_text = f"{config.get_class_name_cn(class_id)} {confidence:.2f}"
        text_box = draw.textbbox((0, 0), label_text, font=font)
        text_width = text_box[2] - text_box[0]
        text_height = text_box[3] - text_box[1]

        # Put the label above the box; when there is no room above, move it
        # just inside the box. The test must run before any clamping,
        # otherwise the fallback can never trigger.
        label_x = max(0, x1)
        label_y = y1 - text_height - 5
        if label_y < 0:
            label_y = max(0, y1) + 5

        draw.rectangle(
            [label_x, label_y, label_x + text_width + 4, label_y + text_height + 4],
            fill=rgb_color
        )
        draw.text((label_x + 2, label_y + 2), label_text, fill=(255, 255, 255), font=font)

    return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
def process_video(self, video_path, output_base_dir="output_frames"):
    """Run detection over every frame of one video and record the outcome.

    Creates ``<output_base_dir>/<video stem>/``, resets the per-video
    filter state, feeds each frame to ``detect_and_save_frame`` and
    finally updates ``self.stats`` and ``self.video_details``.

    Args:
        video_path: ``Path`` of the video file.
        output_base_dir: Root folder for all results.

    Returns:
        ``None``. A video that cannot be opened is recorded in
        ``self.stats['errors']`` and skipped.
    """
    print(f"\n开始处理: {video_path.name}")

    output_dir = Path(output_base_dir) / video_path.stem
    output_dir.mkdir(parents=True, exist_ok=True)

    # The interval/scene filter must not carry over from the previous video.
    self.last_saved_timestamp = 0
    self.last_scene_classes = set()

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        cap.release()
        error_msg = f"无法打开视频: {video_path}"
        print(error_msg)
        self.stats['errors'].append(error_msg)
        return

    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print(f"视频信息: {total_frames} 帧, {fps:.2f} FPS")

    # Images counted per saved frame, as configured.
    images_per_saved_frame = (
        int(bool(config.SAVE_ORIGINAL_FRAMES)) + int(bool(config.SAVE_ANNOTATED_FRAMES))
    )

    frame_count = 0
    detected_frames = 0
    saved_images = 0
    filtered_frames = 0
    all_detections = []
    start_time = time.time()

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1
            timestamp = frame_count / fps if fps > 0 else frame_count

            if frame_count % config.PROGRESS_INTERVAL == 0:
                # Guard both divisions: the reported frame count may be 0,
                # and the elapsed time may be 0 on a coarse clock.
                elapsed = time.time() - start_time
                progress = (frame_count / total_frames) * 100 if total_frames > 0 else 0.0
                fps_current = frame_count / elapsed if elapsed > 0 else 0.0
                print(f"进度: {progress:.1f}% ({frame_count}/{total_frames}), 处理速度: {fps_current:.1f} FPS")

                # Optional progress hook, only present on some instances.
                if hasattr(self, 'update_progress'):
                    self.update_progress(f"处理视频帧", None, None, None, frame_count)

            detection_result = self.detect_and_save_frame(
                frame, frame_count, timestamp, video_path.stem, output_dir
            )
            if not detection_result:
                continue

            detected_frames += 1
            all_detections.append(detection_result)
            if detection_result.get('filtered', False):
                filtered_frames += 1
            else:
                saved_images += images_per_saved_frame
    finally:
        cap.release()

    if config.SAVE_DETECTION_JSON and all_detections:
        json_path = output_dir / f"{video_path.stem}_detections.json"
        detection_summary = {
            'video_name': video_path.name,
            'total_frames': frame_count,
            'detected_frames': detected_frames,
            'fps': fps,
            'detections': all_detections,
            'processing_time': time.time() - start_time
        }
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(detection_summary, f, ensure_ascii=False, indent=2)

    # stats['filtered_frames'] is deliberately not touched here:
    # detect_and_save_frame already increments it for every filtered frame,
    # so adding the local counter would count each one twice.
    self.stats['processed_videos'] += 1
    self.stats['total_frames'] += frame_count
    self.stats['detected_frames'] += detected_frames
    self.stats['saved_images'] += saved_images

    processing_time = time.time() - start_time
    print(f"处理完成: {video_path.name}")
    print(f"检测到目标的帧数: {detected_frames}/{frame_count}, 保存图片: {saved_images} 张")
    print(f"时间间隔过滤帧数: {filtered_frames} 帧 (间隔阈值: {config.MIN_FRAME_INTERVAL}秒)")
    print(f"处理时间: {processing_time:.2f} 秒")

    self.video_details.append({
        'video_name': video_path.name,
        'video_path': str(video_path),
        'total_frames': frame_count,
        'detected_frames': detected_frames,
        'saved_images': saved_images,
        'filtered_frames': filtered_frames,
        'processing_time': processing_time,
        'fps': fps,
        'duration': frame_count / fps if fps > 0 else 0,
        'output_directory': str(output_dir)
    })

def process_all_videos(self, input_folder="input_videos", output_folder="output_frames"):
    """Process every video below *input_folder* and write the batch report.

    Args:
        input_folder: Folder scanned (recursively) for videos.
        output_folder: Root folder for frames, per-video JSON and report.

    Returns:
        The report path as a string, or ``None`` when no video was found.
    """
    self.stats['start_time'] = datetime.now()

    video_files = self.get_video_files(input_folder)
    if not video_files:
        print("没有找到视频文件")
        return

    self.stats['total_videos'] = len(video_files)

    print(f"\n开始批量处理 {len(video_files)} 个视频文件")
    print(f"模型: {self.model_path}")
    print(f"置信度阈值: {self.confidence}")
    print(f"输出目录: {output_folder}")
    print("=" * 50)

    for i, video_path in enumerate(video_files, 1):
        print(f"\n[{i}/{len(video_files)}] 处理视频")
        try:
            self.process_video(video_path, output_folder)
        except Exception as e:
            # One broken video must not abort the batch; record it and
            # continue so the report lists what failed.
            error_msg = f"处理视频失败: {video_path} ({e})"
            print(error_msg)
            self.stats['errors'].append(error_msg)

    self.stats['end_time'] = datetime.now()
    report_path = self.save_final_report(output_folder)
    self.print_summary()
    return report_path
def save_final_report(self, output_folder):
    """Write the JSON report for the whole batch.

    Args:
        output_folder: Folder receiving
            ``batch_detection_report_<timestamp>.json``; it is created
            when missing.

    Returns:
        The path of the written report as a string.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    report_dir = Path(output_folder)
    # Do not rely on a previous step having created the folder.
    report_dir.mkdir(parents=True, exist_ok=True)
    report_path = report_dir / f"batch_detection_report_{timestamp}.json"

    started = self.stats['start_time']
    finished = self.stats['end_time']
    total_seconds = (finished - started).total_seconds() if started and finished else 0

    # A pathlib.Path is not JSON-serialisable; None stays null.
    model_path = self.model_path if self.model_path is None else str(self.model_path)

    statistic_keys = (
        'total_videos',
        'processed_videos',
        'total_frames',
        'detected_frames',
        'filtered_frames',
        'saved_images',
    )
    statistics = {key: self.stats[key] for key in statistic_keys}
    statistics['processing_time_seconds'] = total_seconds

    report_data = {
        'timestamp': timestamp,
        'model_path': model_path,
        'confidence_threshold': self.confidence,
        'frame_interval_threshold': config.MIN_FRAME_INTERVAL,
        'filter_strategy': '相同场景在指定时间间隔内只保存一次',
        'statistics': statistics,
        'videos': self.video_details,
        'errors': self.stats['errors']
    }

    with open(report_path, 'w', encoding='utf-8') as f:
        json.dump(report_data, f, ensure_ascii=False, indent=2)

    print(f"\n批处理报告已保存: {report_path}")
    return str(report_path)

def print_summary(self):
    """Print the batch totals; do nothing when start or end time is unset."""
    started = self.stats['start_time']
    finished = self.stats['end_time']
    if not finished or not started:
        return

    total_time = (finished - started).total_seconds()
    hours, remainder = divmod(total_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    separator = "=" * 50

    print("\n" + separator)
    print("批量处理摘要")
    print(separator)
    print(f"处理视频数量: {self.stats['processed_videos']}/{self.stats['total_videos']}")
    print(f"总处理帧数: {self.stats['total_frames']}")
    print(f"检测到目标的帧数: {self.stats['detected_frames']}")
    print(f"时间间隔过滤帧数: {self.stats['filtered_frames']} (间隔阈值: {config.MIN_FRAME_INTERVAL}秒)")
    print(f"过滤策略: 相同场景在{config.MIN_FRAME_INTERVAL}秒内只保存一次")
    print(f"实际保存的图片数量: {self.stats['saved_images']}")
    print(f"总处理时间: {int(hours)}小时 {int(minutes)}分钟 {seconds:.2f}秒")

    if self.stats['errors']:
        print("\n处理过程中的错误:")
        for error in self.stats['errors']:
            print(f"- {error}")

    print(separator)
def main():
    """Command-line entry point of the batch detector."""
    import argparse

    parser = argparse.ArgumentParser(description='批量视频检测系统')
    parser.add_argument('--model', type=str, help='模型文件路径')
    parser.add_argument('--confidence', type=float, default=0.5, help='置信度阈值')
    parser.add_argument('--input', type=str, default='input_videos', help='输入视频文件夹')
    parser.add_argument('--output', type=str, default='output_frames', help='输出图片文件夹')
    args = parser.parse_args()

    detector = BatchVideoDetector(
        model_path=args.model,
        confidence=args.confidence
    )
    detector.process_all_videos(args.input, args.output)


if __name__ == '__main__':
    main()


class DetectionConfig:
    """Settings shared by the detection scripts.

    All values are class attributes and all helpers are class methods.
    Class names and colours can be replaced at run time by
    ``load_model_classes``; until then the built-in road-damage defaults
    are used.
    """

    # Model settings
    MODEL_PATH = "models/best.pt"
    CONFIDENCE_THRESHOLD = 0.5

    # Class information taken from the loaded model (None = not loaded yet)
    _model_classes = None
    _model_colors = None

    # Folders
    INPUT_FOLDER = "input_videos"
    OUTPUT_FOLDER = "output_frames"
    LOG_FOLDER = "logs"

    # Video processing
    SUPPORTED_VIDEO_FORMATS = {'.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm'}
    PROGRESS_INTERVAL = 100   # print progress every N frames
    MIN_FRAME_INTERVAL = 1.0  # minimum seconds between saved frames of the same scene

    # Output switches
    SAVE_ORIGINAL_FRAMES = False
    SAVE_ANNOTATED_FRAMES = True
    SAVE_DETECTION_JSON = True

    # JPEG quality (1-100)
    IMAGE_QUALITY = 95

    # Built-in Chinese labels for the road-damage classes (fallback)
    DEFAULT_CLASS_NAMES_CN = {
        0: "纵向裂缝",
        1: "横向裂缝",
        2: "网状裂缝",
        3: "坑洞",
        4: "白线模糊"
    }

    # Built-in colours in BGR order; every class is drawn in green
    DEFAULT_CLASS_COLORS = {
        0: (0, 255, 0),
        1: (0, 255, 0),
        2: (0, 255, 0),
        3: (0, 255, 0),
        4: (0, 255, 0)
    }

    @classmethod
    def get_model_path(cls):
        """Return the absolute model path, resolved relative to this file."""
        current_dir = Path(__file__).parent
        return str((current_dir / cls.MODEL_PATH).resolve())

    @classmethod
    def create_folders(cls):
        """Create the input, output and log folders if they are missing."""
        for folder in (cls.INPUT_FOLDER, cls.OUTPUT_FOLDER, cls.LOG_FOLDER):
            Path(folder).mkdir(parents=True, exist_ok=True)

    @classmethod
    def _set_model_classes(cls, names):
        """Install *names* (an id->name mapping or a sequence) as the active classes."""
        # The lookups below use dict.get, so a plain sequence of names is
        # converted to an id->name mapping first.
        if not isinstance(names, dict):
            names = dict(enumerate(names))
        cls._model_classes = dict(names)
        # Key colours by the real class ids rather than range(len(names)),
        # so non-contiguous ids get an entry as well.
        cls._model_colors = {class_id: (0, 255, 0) for class_id in cls._model_classes}

    @classmethod
    def load_model_classes(cls, model_path):
        """Load the class names of the model at *model_path*.

        Returns:
            ``True`` on success, ``False`` when the model could not be
            loaded (the previous class information is then kept).
        """
        try:
            model = YOLO(model_path)
            cls._set_model_classes(model.names)
            return True
        except Exception as e:
            print(f"加载模型类别信息失败: {e}")
            return False

    @classmethod
    def get_class_name_cn(cls, class_id):
        """Return the display name for *class_id*.

        With a loaded model, the built-in Chinese label for that id is
        used when one exists, otherwise the model's own name; ids unknown
        to the model yield a placeholder. Without a loaded model only the
        built-in table is consulted.

        NOTE(review): the Chinese table is matched by class id only, so it
        presumably only fits models trained with that class order; confirm.
        """
        unknown = f"未知类别_{class_id}"
        if cls._model_classes is None:
            return cls.DEFAULT_CLASS_NAMES_CN.get(class_id, unknown)

        original_name = cls._model_classes.get(class_id)
        if not original_name:
            return unknown
        return cls.DEFAULT_CLASS_NAMES_CN.get(class_id, original_name)

    @classmethod
    def get_class_color(cls, class_id):
        """Return the BGR colour for *class_id* (currently always green)."""
        table = cls._model_colors if cls._model_colors is not None else cls.DEFAULT_CLASS_COLORS
        return table.get(class_id, (0, 255, 0))

    @classmethod
    def get_all_classes(cls):
        """Return a copy of the active id->name mapping.

        A copy is returned so that callers cannot alter the shared
        class-level tables by accident.
        """
        if cls._model_classes is not None:
            return dict(cls._model_classes)
        return dict(cls.DEFAULT_CLASS_NAMES_CN)


# Global configuration instance
config = DetectionConfig()
class VideoDetectionSystem:
    """Detect road damage in videos and extract the frames that contain it.

    Every frame of every video below ``input_folder`` is passed through a
    YOLO model. Frames with at least one detection are written to
    ``output_folder/<video stem>/`` both as the original image and with
    boxes and Chinese labels drawn in, together with a JSON file listing
    the detections.
    """

    def __init__(self, model_path, confidence_threshold=0.5, input_folder="input_videos", output_folder="output_frames"):
        """
        Initialise the detection system.

        Args:
            model_path (str): Path of the YOLO model file.
            confidence_threshold (float): Minimum confidence of a detection.
            input_folder (str): Folder containing the input videos.
            output_folder (str): Folder receiving the extracted images.

        Raises:
            Exception: Whatever ``load_model`` raises when the model cannot
                be loaded.
        """
        self.model_path = model_path
        self.confidence_threshold = confidence_threshold
        self.input_folder = Path(input_folder)
        self.output_folder = Path(output_folder)

        self.output_folder.mkdir(parents=True, exist_ok=True)

        self.setup_logging()
        self.load_model()

        # Supported video formats
        self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm'}

        # Running totals over all processed videos
        self.detection_stats = {
            'total_videos': 0,
            'total_frames': 0,
            'detected_frames': 0,
            'saved_frames': 0,
            'detection_results': []
        }

    def setup_logging(self):
        """Log to ``logs/detection_<timestamp>.log`` and to the console."""
        log_folder = Path('logs')
        log_folder.mkdir(exist_ok=True)

        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = log_folder / f'detection_{timestamp}.log'

        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(log_file, encoding='utf-8'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def load_model(self):
        """Load the YOLO model and publish its class names to the config.

        Raises:
            Exception: Re-raised after logging when loading fails.
        """
        try:
            self.model = YOLO(self.model_path)
            self.logger.info(f"成功加载模型: {self.model_path}")

            # Make the model's class names available to the shared config.
            DetectionConfig.load_model_classes(self.model_path)

            if hasattr(self.model, 'names'):
                self.logger.info(f"模型类别数量: {len(self.model.names)}")
                self.logger.info(f"检测类别: {list(self.model.names.values())}")
                self.logger.info(f"动态加载的类别信息: {DetectionConfig.get_all_classes()}")

        except Exception as e:
            self.logger.error(f"加载模型失败: {e}")
            raise

    def get_video_files(self):
        """Return all video files below the input folder (recursively).

        Returns:
            A list of ``Path`` objects; empty when the folder is missing.
        """
        if not self.input_folder.exists():
            self.logger.warning(f"输入文件夹不存在: {self.input_folder}")
            return []

        video_files = [
            file_path
            for file_path in self.input_folder.rglob('*')
            if file_path.is_file() and file_path.suffix.lower() in self.video_extensions
        ]

        self.logger.info(f"找到 {len(video_files)} 个视频文件")
        return video_files

    def detect_frame(self, frame):
        """Run the model on one frame.

        Returns:
            The first result object, or ``None`` when the model returned
            nothing or raised.
        """
        try:
            results = self.model(frame, conf=self.confidence_threshold, verbose=False)
            return results[0] if results else None
        except Exception as e:
            self.logger.error(f"检测失败: {e}")
            return None

    def _get_label_font(self):
        """Return the label font, loading it once and caching it on the instance."""
        cached_font = getattr(self, '_label_font', None)
        if cached_font is not None:
            return cached_font

        font = None
        for font_path in (
            "C:/Windows/Fonts/simhei.ttf",  # SimHei
            "C:/Windows/Fonts/msyh.ttf",    # Microsoft YaHei
            "C:/Windows/Fonts/msyh.ttc",    # Microsoft YaHei (collection file)
            "C:/Windows/Fonts/simsun.ttc",  # SimSun
        ):
            if not os.path.exists(font_path):
                continue
            try:
                font = ImageFont.truetype(font_path, 20)
                break
            except (OSError, ImportError):
                continue
        if font is None:
            font = ImageFont.load_default()

        self._label_font = font
        return font

    @staticmethod
    def _write_jpeg(image_path, image):
        """Encode *image* as JPEG and write it; return ``True`` on success.

        Encoding in memory and writing through ``pathlib`` avoids
        ``cv2.imwrite``, which is known to fail without raising for
        non-ASCII paths on some Windows builds.
        """
        encoded_ok, buffer = cv2.imencode(".jpg", image)
        if not encoded_ok:
            return False
        Path(image_path).write_bytes(buffer.tobytes())
        return True

    def draw_detections(self, frame, result):
        """Draw boxes and Chinese labels for *result* onto a copy of *frame*.

        Returns:
            The annotated BGR image, or *frame* itself when there is
            nothing to draw.
        """
        if result is None or result.boxes is None:
            return frame

        # Pillow is used because the labels are Chinese text rendered with
        # a TrueType font.
        pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_image)
        font = self._get_label_font()

        for box in result.boxes:
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int).tolist()
            confidence = float(box.conf[0].cpu().numpy())
            class_id = int(box.cls[0].cpu().numpy())

            class_name_cn = config.get_class_name_cn(class_id)
            # Colours are configured as BGR; Pillow expects RGB.
            bg_color = tuple(reversed(config.get_class_color(class_id)))

            draw.rectangle([x1, y1, x2, y2], outline=bg_color, width=2)

            label = f"{class_name_cn}: {confidence:.2f}"
            text_box = draw.textbbox((0, 0), label, font=font)
            text_width = text_box[2] - text_box[0]
            text_height = text_box[3] - text_box[1]

            # Label sits above the box; for boxes touching the top edge it
            # is moved inside the box so it stays visible.
            label_top = y1 - text_height - 10
            if label_top < 0:
                label_top = max(0, y1)
            label_left = max(0, x1)

            draw.rectangle(
                [label_left, label_top, label_left + text_width + 10, label_top + text_height + 10],
                fill=bg_color
            )
            draw.text((label_left + 5, label_top + 5), label, font=font, fill=(255, 255, 255))

        return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

    def _describe_detections(self, result, frame_number, timestamp):
        """Build the JSON record for one frame's detections."""
        detection_info = {
            'frame_number': frame_number,
            'timestamp': timestamp,
            'detections': []
        }
        for box in result.boxes:
            class_id = int(box.cls[0].cpu().numpy())
            detection_info['detections'].append({
                'class_name': self.model.names[class_id],
                'class_id': class_id,
                'confidence': float(box.conf[0].cpu().numpy()),
                'bbox': box.xyxy[0].cpu().numpy().tolist()
            })
        return detection_info

    def process_video(self, video_path):
        """Process one video: detect, save matching frames, write its JSON.

        Args:
            video_path: ``Path`` of the video file.

        Returns:
            ``None``. A video that cannot be opened is logged and skipped.
        """
        self.logger.info(f"开始处理视频: {video_path.name}")

        video_output_folder = self.output_folder / video_path.stem
        video_output_folder.mkdir(parents=True, exist_ok=True)

        cap = cv2.VideoCapture(str(video_path))
        if not cap.isOpened():
            cap.release()
            self.logger.error(f"无法打开视频文件: {video_path}")
            return

        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = total_frames / fps if fps > 0 else 0

        self.logger.info(f"视频信息 - 帧率: {fps:.2f}, 总帧数: {total_frames}, 时长: {duration:.2f}秒")

        frame_count = 0
        detected_count = 0
        saved_count = 0

        video_detections = {
            'video_name': video_path.name,
            'total_frames': total_frames,
            'fps': fps,
            'duration': duration,
            'detections': []
        }

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                frame_count += 1

                # Report progress every 100 frames. The reported frame
                # count may be 0, so the percentage is guarded.
                if frame_count % 100 == 0:
                    progress = (frame_count / total_frames) * 100 if total_frames > 0 else 0.0
                    self.logger.info(f"处理进度: {progress:.1f}% ({frame_count}/{total_frames})")

                result = self.detect_frame(frame)
                if result is None or result.boxes is None or len(result.boxes) == 0:
                    continue

                detected_count += 1
                timestamp = frame_count / fps if fps > 0 else frame_count
                annotated_frame = self.draw_detections(frame, result)

                frame_filename = f"frame_{frame_count:06d}_t{timestamp:.2f}s.jpg"
                annotated_filename = f"annotated_{frame_count:06d}_t{timestamp:.2f}s.jpg"

                # Count only the images that were actually written.
                for image_path, image in (
                    (video_output_folder / frame_filename, frame),
                    (video_output_folder / annotated_filename, annotated_frame),
                ):
                    if self._write_jpeg(image_path, image):
                        saved_count += 1
                    else:
                        self.logger.warning(f"保存图片失败: {image_path}")

                video_detections['detections'].append(
                    self._describe_detections(result, frame_count, timestamp)
                )

        finally:
            cap.release()

        json_path = video_output_folder / f"{video_path.stem}_detections.json"
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(video_detections, f, ensure_ascii=False, indent=2)

        self.logger.info(f"视频处理完成: {video_path.name}")
        self.logger.info(f"总帧数: {frame_count}, 检测到目标的帧数: {detected_count}, 保存图片数: {saved_count}")

        self.detection_stats['total_videos'] += 1
        self.detection_stats['total_frames'] += frame_count
        self.detection_stats['detected_frames'] += detected_count
        self.detection_stats['saved_frames'] += saved_count
        self.detection_stats['detection_results'].append(video_detections)

    def process_all_videos(self):
        """Process every video in the input folder and write the summary."""
        video_files = self.get_video_files()

        if not video_files:
            self.logger.warning("没有找到视频文件")
            return

        self.logger.info(f"开始处理 {len(video_files)} 个视频文件")

        for i, video_path in enumerate(video_files, 1):
            self.logger.info(f"\n=== 处理第 {i}/{len(video_files)} 个视频 ===")
            self.process_video(video_path)

        self.save_summary_report()

        self.logger.info("\n=== 处理完成 ===")
        self.logger.info(f"总计处理视频: {self.detection_stats['total_videos']} 个")
        self.logger.info(f"总计处理帧数: {self.detection_stats['total_frames']} 帧")
        self.logger.info(f"检测到目标的帧数: {self.detection_stats['detected_frames']} 帧")
        self.logger.info(f"保存图片数量: {self.detection_stats['saved_frames']} 张")

    def save_summary_report(self):
        """Write ``detection_summary_<timestamp>.json`` into the output folder."""
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        report_path = self.output_folder / f"detection_summary_{timestamp}.json"

        summary = {
            'timestamp': timestamp,
            'model_path': str(self.model_path),
            'confidence_threshold': self.confidence_threshold,
            'input_folder': str(self.input_folder),
            'output_folder': str(self.output_folder),
            'statistics': self.detection_stats
        }

        with open(report_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)

        self.logger.info(f"总结报告已保存: {report_path}")


def main():
    """Command-line entry point: parse the arguments and process all videos."""
    parser = argparse.ArgumentParser(description='视频目标检测系统')
    parser.add_argument('--model', type=str,
                        default='../Japan/training_results/continue_from_best_20250610_130607/weights/best.pt',
                        help='YOLO模型文件路径')
    parser.add_argument('--confidence', type=float, default=0.5,
                        help='置信度阈值 (默认: 0.5)')
    parser.add_argument('--input', type=str, default='input_videos',
                        help='输入视频文件夹 (默认: input_videos)')
    parser.add_argument('--output', type=str, default='output_frames',
                        help='输出图片文件夹 (默认: output_frames)')

    args = parser.parse_args()

    detector = VideoDetectionSystem(
        model_path=args.model,
        confidence_threshold=args.confidence,
        input_folder=args.input,
        output_folder=args.output
    )
    detector.process_all_videos()


if __name__ == '__main__':
    main()