# -*- coding: utf-8 -*-
# Provenance: contents of app.py as added by patch
# 9ccc897c3caaa1f516313225f06e02ba2ab715a2 (Wang_Run_Ze, Fri, 27 Jun 2025).
"""Flask web front end for the batch video detection system.

The app lets a user upload model weights and videos, start a background
detection run, poll its progress, and browse or delete the JSON reports the
run produces.

NOTE(review): this module was reconstructed from a patch whose line breaks
and indentation had been lost.  Block nesting was inferred from the statement
order; confirm against the original app.py before relying on subtle details.
"""
import json
import os
import shutil
import threading
import time
import traceback
from datetime import datetime
from pathlib import Path

from flask import (Flask, flash, jsonify, redirect, render_template, request,
                   send_from_directory, url_for)
from werkzeug.utils import secure_filename

from config import config
from batch_detector import BatchVideoDetector

app = Flask(__name__)
# The literal fallback keeps existing deployments working; set SECRET_KEY in
# the environment so session cookies are not signed with a public constant.
app.secret_key = os.environ.get('SECRET_KEY', 'road_damage_detection_system')
app.config['MAX_CONTENT_LENGTH'] = 500 * 1024 * 1024  # 500MB max upload size


def _ensure_folders():
    """Create every folder the application reads from or writes to."""
    for folder in [config.INPUT_FOLDER, config.OUTPUT_FOLDER, config.LOG_FOLDER,
                   'models', 'static/css']:
        os.makedirs(folder, exist_ok=True)


_ensure_folders()

# ---------------------------------------------------------------------------
# Shared detection state (written by the worker thread, read by the views).
# ---------------------------------------------------------------------------
detection_progress = 0
# Machine-readable state; always one of: idle, running, completed, error.
detection_status = "idle"
# Human-readable progress text shown to the user.  It is kept separate from
# detection_status so that `detection_status == "running"` stays reliable.
detection_message = ""
detection_thread = None
current_detection_info = {}
detection_start_time = None
detection_results = None
current_report_path = None
current_video_name = ""
total_videos = 0
current_video_index = 0
current_frame_count = 0
total_frame_count = 0

# Allowed upload extensions (compared lower-case, without the leading dot).
ALLOWED_VIDEO_EXTENSIONS = {ext.lstrip('.').lower() for ext in config.SUPPORTED_VIDEO_FORMATS}
ALLOWED_MODEL_EXTENSIONS = {'pt', 'pth', 'weights'}


def allowed_file(filename, allowed_extensions):
    """Return True if *filename* has an extension listed in *allowed_extensions*."""
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed_extensions


# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
def _safe_child(base, name):
    """Return ``os.path.join(base, name)`` if *name* is a plain file name.

    Returns None for empty names, ``.``/``..`` and anything containing a path
    separator, so request-supplied names cannot address files outside *base*.
    """
    if not name or name in ('.', '..'):
        return None
    if '/' in name or '\\' in name or name != os.path.basename(name):
        return None
    return os.path.join(base, name)


def _storage_name(original_name, allowed_extensions):
    """Return a safe on-disk name for an upload that passed ``allowed_file``.

    ``secure_filename`` drops non-ASCII characters, which can remove the whole
    stem (and with it the extension).  When that happens a timestamped name
    with the original extension is generated instead.
    """
    cleaned = secure_filename(original_name)
    if allowed_file(cleaned, allowed_extensions):
        return cleaned
    extension = original_name.rsplit('.', 1)[1].lower()
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
    return f"upload_{stamp}.{extension}"


def _list_files(folder, allowed_extensions):
    """Return the names in *folder* whose extension is allowed ([] if missing)."""
    if not os.path.isdir(folder):
        return []
    return [name for name in os.listdir(folder) if allowed_file(name, allowed_extensions)]


def _list_models():
    """Return the model weight files available in the models folder."""
    return _list_files('models', ALLOWED_MODEL_EXTENSIONS)


def _list_videos():
    """Return the video files waiting in the input folder."""
    return _list_files(config.INPUT_FOLDER, ALLOWED_VIDEO_EXTENSIONS)


def _read_json(path):
    """Load and return the UTF-8 JSON document stored at *path*."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _percentage(part, total):
    """Return part/total as a percentage, or 0 when *total* is not positive."""
    return (part / total) * 100 if total > 0 else 0


def _parse_confidence(raw, default=0.25):
    """Parse a confidence threshold from form data.

    Returns *default* when the field is absent or empty, and None when the
    value is not a number within [0, 1].
    """
    if raw is None or raw == '':
        return default
    try:
        value = float(raw)
    except (TypeError, ValueError):
        return None
    # The chained comparison is also False for NaN.
    return value if 0.0 <= value <= 1.0 else None


def _load_report_summaries():
    """Return summary dicts for every batch report, newest first."""
    summaries = []
    if not os.path.isdir(config.OUTPUT_FOLDER):
        return summaries

    for report_file in os.listdir(config.OUTPUT_FOLDER):
        if not (report_file.startswith('batch_detection_report_')
                and report_file.endswith('.json')):
            continue
        report_path = os.path.join(config.OUTPUT_FOLDER, report_file)
        try:
            report_data = _read_json(report_path)
            stats = report_data.get('statistics', {})
            summaries.append({
                'id': report_file,
                'timestamp': report_data.get('timestamp', '未知'),
                'video_count': stats.get('processed_videos', 0),
                'total_frames': stats.get('total_frames', 0),
                'detected_frames': stats.get('detected_frames', 0)
            })
        except (OSError, ValueError, AttributeError) as e:
            # One unreadable report must not break the whole dashboard.
            print(f"Error loading report {report_path}: {e}")

    # str() keeps the sort from failing on a report with a non-string timestamp.
    summaries.sort(key=lambda item: str(item['timestamp']), reverse=True)
    return summaries


def _launch_detection(form):
    """Validate *form* and start the background detection thread.

    Flashes a message describing the outcome.  Returns True when a new run
    was started and False when the request was rejected.
    """
    global detection_thread, detection_status, detection_message, detection_progress
    global detection_results, current_report_path, detection_start_time
    global current_video_name, total_videos, current_video_index
    global current_frame_count, total_frame_count

    if detection_status == "running":
        flash('已有检测任务正在运行', 'warning')
        return False

    model_name = form.get('model_select')
    confidence_threshold = _parse_confidence(form.get('confidence_threshold'))
    if confidence_threshold is None:
        flash('置信度阈值无效,请输入 0 到 1 之间的数值', 'danger')
        return False

    videos = _list_videos()
    if not videos:
        flash('输入文件夹中没有视频文件,请先上传视频', 'danger')
        return False

    if not model_name:
        flash('请选择一个检测模型', 'danger')
        return False

    model_path = _safe_child('models', model_name)
    if model_path is None or not os.path.exists(model_path):
        flash(f'模型文件 {model_name} 不存在', 'danger')
        return False

    # Publish the chosen settings to the shared configuration object.
    config.MODEL_PATH = model_path
    config.CONFIDENCE_THRESHOLD = confidence_threshold

    # Reset all progress state before the worker starts writing to it.
    detection_progress = 0
    detection_status = "running"
    detection_message = ""
    detection_results = None
    current_report_path = None
    detection_start_time = datetime.now()
    current_video_name = ""
    total_videos = len(videos)
    current_video_index = 0
    current_frame_count = 0
    total_frame_count = 0

    detection_thread = threading.Thread(target=run_detection, daemon=True)
    detection_thread.start()

    flash('检测已开始,请等待结果', 'info')
    return True


# ---------------------------------------------------------------------------
# Views
# ---------------------------------------------------------------------------
@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the dashboard; a POST additionally tries to start a detection run."""
    if request.method == 'POST':
        _launch_detection(request.form)

    return render_template('index.html',
                           models=_list_models(),
                           videos=_list_videos(),
                           reports=_load_report_summaries(),
                           detection_status=detection_status,
                           detection_message=detection_message,
                           detection_progress=detection_progress,
                           detection_in_progress=(detection_status == "running"),
                           confidence_threshold=config.CONFIDENCE_THRESHOLD,
                           config=config)


@app.route('/upload_model', methods=['POST'])
def upload_model():
    """Store an uploaded model file and cache its class names next to it."""
    file = request.files.get('model_file')
    if file is None or file.filename == '':
        flash('没有选择文件', 'danger')
        return redirect(url_for('index'))

    if not allowed_file(file.filename, ALLOWED_MODEL_EXTENSIONS):
        flash('不支持的文件类型', 'danger')
        return redirect(url_for('index'))

    filename = _storage_name(file.filename, ALLOWED_MODEL_EXTENSIONS)
    model_path = os.path.join('models', filename)
    file.save(model_path)

    # Analyse the model right away and cache the result as <stem>_info.json.
    try:
        print(f"正在分析上传的模型: {filename}")
        from ultralytics import YOLO  # imported lazily, only needed for uploads
        model = YOLO(model_path)
        has_names = hasattr(model, 'names')

        model_info = {
            'filename': filename,
            'upload_time': datetime.now().isoformat(),
            'classes': dict(model.names) if has_names else {},
            'class_count': len(model.names) if has_names else 0,
            'model_type': 'YOLO',
            'analyzed': True
        }

        info_path = os.path.join('models', f"{os.path.splitext(filename)[0]}_info.json")
        with open(info_path, 'w', encoding='utf-8') as f:
            json.dump(model_info, f, ensure_ascii=False, indent=2)

        if has_names:
            scene_types = list(model.names.values())
            flash(f'模型 {filename} 上传成功!检测到 {len(scene_types)} 种场景类型: {", ".join(scene_types)}', 'success')
        else:
            flash(f'模型 {filename} 上传成功!', 'success')
    except Exception as e:
        # The upload itself succeeded; analysis is best-effort.
        print(f"模型分析失败: {e}")
        flash(f'模型 {filename} 上传成功,但分析模型信息时出错: {str(e)}', 'warning')

    return redirect(url_for('index'))


@app.route('/get_models', methods=['GET'])
def get_models():
    """Return the available models, enriched with cached analysis info, as JSON."""
    models = []
    for model_file in _list_models():
        model_info = {
            'filename': model_file,
            'analyzed': False,
            'classes': {},
            'class_count': 0,
            'scene_types': []
        }

        info_path = os.path.join('models', f"{os.path.splitext(model_file)[0]}_info.json")
        if os.path.exists(info_path):
            try:
                cached_info = _read_json(info_path)
                model_info.update(cached_info)
                model_info['scene_types'] = list(cached_info.get('classes', {}).values())
            except Exception as e:
                print(f"读取模型信息文件失败: {e}")

        models.append(model_info)

    return jsonify(models)


@app.route('/delete_model/<filename>', methods=['POST'])
def delete_model(filename):
    """Delete a model file together with its cached info file."""
    try:
        model_path = _safe_child('models', filename)
        deleted_files = []
        if model_path is not None:
            info_path = os.path.join('models', f"{os.path.splitext(filename)[0]}_info.json")
            if os.path.exists(model_path):
                os.remove(model_path)
                deleted_files.append('模型文件')
            if os.path.exists(info_path):
                os.remove(info_path)
                deleted_files.append('信息文件')

        if deleted_files:
            flash(f'模型 {filename} 及其{"、".join(deleted_files)}已删除', 'success')
        else:
            flash(f'模型 {filename} 不存在', 'warning')
    except Exception as e:
        flash(f'删除模型时出错: {str(e)}', 'danger')

    return redirect(url_for('index'))


@app.route('/upload_video', methods=['POST'])
def upload_video():
    """Store every uploaded video that has a supported extension."""
    files = request.files.getlist('video_files[]')
    if not files or files[0].filename == '':
        flash('没有选择文件', 'danger')
        return redirect(url_for('index'))

    success_count = 0
    error_count = 0
    for file in files:
        if file and allowed_file(file.filename, ALLOWED_VIDEO_EXTENSIONS):
            filename = _storage_name(file.filename, ALLOWED_VIDEO_EXTENSIONS)
            file.save(os.path.join(config.INPUT_FOLDER, filename))
            success_count += 1
        else:
            error_count += 1

    if success_count > 0:
        flash(f'成功上传 {success_count} 个视频文件', 'success')
    if error_count > 0:
        flash(f'{error_count} 个文件上传失败(不支持的文件类型)', 'warning')

    return redirect(url_for('index'))


@app.route('/get_videos', methods=['GET'])
def get_videos():
    """Return the names of the videos in the input folder as JSON."""
    return jsonify(_list_videos())


@app.route('/delete_video/<filename>', methods=['POST'])
def delete_video(filename):
    """Delete a single video from the input folder."""
    try:
        video_path = _safe_child(config.INPUT_FOLDER, filename)
        if video_path is not None and os.path.exists(video_path):
            os.remove(video_path)
            flash(f'视频 {filename} 已删除', 'success')
        else:
            flash(f'视频 {filename} 不存在', 'warning')
    except Exception as e:
        flash(f'删除视频时出错: {str(e)}', 'danger')

    return redirect(url_for('index'))


@app.route('/clear_input_folder', methods=['POST'])
def clear_input_folder():
    """Remove every regular file from the input folder."""
    try:
        for filename in os.listdir(config.INPUT_FOLDER):
            file_path = os.path.join(config.INPUT_FOLDER, filename)
            if os.path.isfile(file_path):
                os.unlink(file_path)
        flash('输入文件夹已清空', 'success')
    except Exception as e:
        flash(f'清空文件夹时出错: {str(e)}', 'danger')

    return redirect(url_for('index'))


@app.route('/start_detection', methods=['POST'])
def start_detection():
    """Start a detection run from the submitted form, then return to the dashboard."""
    print("收到检测请求")
    print(f"当前检测状态: {detection_status}")
    print(f"接收到的模型名称: {request.form.get('model_select')}")
    print(f"接收到的置信度阈值: {request.form.get('confidence_threshold')}")

    _launch_detection(request.form)
    return redirect(url_for('index'))


def run_detection():
    """Worker-thread entry point: run the batch detector and record the outcome.

    Leaves ``detection_status`` as ``"completed"`` or ``"error"`` and always
    finishes with ``detection_progress`` at 100 so pollers stop waiting.
    """
    global detection_progress, detection_status, detection_message
    global detection_results, current_report_path

    try:
        print(f"开始检测,模型路径: {config.MODEL_PATH}")
        print(f"置信度阈值: {config.CONFIDENCE_THRESHOLD}")
        print(f"输入文件夹: {config.INPUT_FOLDER}")
        print(f"找到视频文件: {_list_videos()}")

        print("正在创建检测器实例...")
        detector = BatchVideoDetector(
            model_path=config.MODEL_PATH,
            confidence=config.CONFIDENCE_THRESHOLD
        )
        print("检测器实例创建成功")

        print("\n=== 网页端模型检测分析 ===")
        if hasattr(detector.model, 'names'):
            model_classes = list(detector.model.names.values())
            print(f"当前模型可检测场景: {model_classes}")
        print(f"使用的置信度阈值: {config.CONFIDENCE_THRESHOLD}")
        print("=" * 40)

        def progress_callback(status, progress, video_name=None, video_index=None, frame_count=None):
            """Copy the detector's progress report into the module-level state."""
            global detection_message, detection_progress
            global current_video_name, current_video_index, current_frame_count
            # `status` is display text; it must not replace the state machine
            # value held in detection_status.
            detection_message = status
            detection_progress = progress

            if video_name:
                current_video_name = video_name
            if video_index is not None:
                current_video_index = video_index
            if frame_count is not None:
                current_frame_count = frame_count

            print(f"进度更新: {status} - {progress}%")
            if video_name:
                print(f"当前视频: {video_name} ({video_index}/{total_videos})")
            if frame_count is not None:
                print(f"已处理帧数: {frame_count}")

        detector.set_progress_callback(progress_callback)

        detection_message = "开始检测视频..."
        report_path = detector.process_all_videos(config.INPUT_FOLDER, config.OUTPUT_FOLDER)

        if report_path and os.path.exists(report_path):
            detection_results = _read_json(report_path)
            current_report_path = report_path
            detection_message = "检测完成"
        else:
            detection_message = "检测完成,但未生成报告"
        detection_progress = 100
        detection_status = "completed"
    except Exception as e:
        # Top-level boundary of the worker thread: record and log everything.
        detection_status = "error"
        detection_message = f"检测失败: {str(e)}"
        print(f"检测异常: {str(e)}")
        traceback.print_exc()
    finally:
        # Only reachable in the "running" state if something other than
        # Exception (e.g. SystemExit) unwound the thread.
        if detection_status == "running":
            detection_status = "error"
            detection_message = "失败"
        if detection_progress < 100:
            detection_progress = 100


@app.route('/detection_progress')
def get_detection_progress():
    """Return the current detection state as JSON for the polling front end."""
    running = detection_status == "running"

    elapsed_time = ""
    if detection_start_time and running:
        elapsed_seconds = (datetime.now() - detection_start_time).total_seconds()
        minutes, seconds = divmod(elapsed_seconds, 60)
        elapsed_time = f"{int(minutes)}分{int(seconds)}秒"

    # Prefer the detailed per-video line while running, otherwise the last
    # message from the worker, otherwise the bare state name.
    detailed_status = detection_message or detection_status
    if running and current_video_name:
        detailed_status = f"正在处理: {current_video_name} ({current_video_index}/{total_videos})"
        if current_frame_count > 0:
            detailed_status += f" - 已处理帧数: {current_frame_count}"

    return jsonify({
        'in_progress': running,
        'progress': detection_progress,
        'status': detailed_status,
        'elapsed_time': elapsed_time,
        'current_video': current_video_name,
        'current_video_index': current_video_index,
        'total_videos': total_videos,
        'current_frame_count': current_frame_count,
        'total_frame_count': total_frame_count
    })


@app.route('/report/<report_filename>')
def view_report(report_filename):
    """Render one batch report with overall and per-video detection rates."""
    report_path = _safe_child(config.OUTPUT_FOLDER, report_filename)
    if report_path is None or not os.path.exists(report_path):
        flash('报告文件不存在', 'danger')
        return redirect(url_for('index'))

    try:
        report_data = _read_json(report_path)

        stats = report_data.get('statistics', {})
        detection_rate = _percentage(stats.get('detected_frames', 0),
                                     stats.get('total_frames', 0))

        for video in report_data.get('videos', []):
            video['detection_rate'] = _percentage(video.get('detected_frames', 0),
                                                  video.get('total_frames', 0))

        return render_template('report.html',
                               report=report_data,
                               report_filename=report_filename,
                               detection_rate=detection_rate)
    except Exception as e:
        flash(f'读取报告时出错: {str(e)}', 'danger')
        return redirect(url_for('index'))


def _collect_detected_frames(video_name, video_data):
    """Return frame records for *video_name*, sorted by frame number.

    Scans ``OUTPUT_FOLDER/<video stem>`` for ``detected_<n>.jpg`` files and
    attaches the detections stored in the matching ``.json`` side file.
    NOTE(review): the visible part of batch_detector.py writes
    ``<video>_frame_<n>_t<ts>s_detected.jpg`` into per-scene sub-folders, which
    this naming scheme would not match — confirm which layout is current.
    """
    stem = os.path.splitext(video_name)[0]
    video_output_dir = _safe_child(config.OUTPUT_FOLDER, stem)
    if video_output_dir is None or not os.path.isdir(video_output_dir):
        return []

    fps = video_data.get('fps') or 30  # also guards against a stored fps of 0
    frames = []
    for frame_file in os.listdir(video_output_dir):
        if not (frame_file.startswith('detected_') and frame_file.endswith('.jpg')):
            continue
        try:
            frame_number = int(frame_file.split('_')[1].split('.')[0])

            detections = []
            json_path = os.path.join(video_output_dir, frame_file.replace('.jpg', '.json'))
            if os.path.exists(json_path):
                detections = _read_json(json_path).get('detections', [])

            frames.append({
                'frame_number': frame_number,
                'filename': frame_file,
                # Always "/"-separated: this value is used to build a URL.
                'path': f"{stem}/{frame_file}",
                'timestamp': frame_number / fps,
                'detections': detections,
                'detection_count': len(detections)
            })
        except (ValueError, IndexError, OSError, AttributeError) as e:
            print(f"Error processing frame {frame_file}: {e}")

    frames.sort(key=lambda item: item['frame_number'])
    return frames


@app.route('/video_results/<report_filename>/<video_name>')
def view_video_results(report_filename, video_name):
    """Render the detected frames and timeline of one video from a report."""
    report_path = _safe_child(config.OUTPUT_FOLDER, report_filename)
    if report_path is None or not os.path.exists(report_path):
        flash('报告文件不存在', 'danger')
        return redirect(url_for('index'))

    try:
        report_data = _read_json(report_path)

        video_data = next((video for video in report_data.get('videos', [])
                           if video.get('video_name') == video_name), None)
        if not video_data:
            flash('视频结果不存在', 'danger')
            return redirect(url_for('view_report', report_filename=report_filename))

        detected_frames = _collect_detected_frames(video_name, video_data)
        detection_rate = _percentage(video_data.get('detected_frames', 0),
                                     video_data.get('total_frames', 0))

        timeline_data = [{
            'timestamp': frame['timestamp'],
            'frame_number': frame['frame_number'],
            'detection_count': frame['detection_count']
        } for frame in detected_frames]

        return render_template('video_results.html',
                               report_filename=report_filename,
                               report=report_data,
                               video=video_data,
                               detected_frames=detected_frames,
                               detection_rate=detection_rate,
                               timeline_data=json.dumps(timeline_data))
    except Exception as e:
        flash(f'读取视频结果时出错: {str(e)}', 'danger')
        return redirect(url_for('view_report', report_filename=report_filename))


@app.route('/frame/<path:frame_path>')
def get_frame(frame_path):
    """Serve a saved frame image from below the output folder."""
    # send_from_directory refuses paths that escape the given directory, so
    # the whole relative path is handed to it instead of being split here.
    return send_from_directory(config.OUTPUT_FOLDER, frame_path.replace('\\', '/'))


@app.route('/delete_report/<report_filename>', methods=['POST'])
def delete_report(report_filename):
    """Delete a report file and the output folders of the videos it lists."""
    report_path = _safe_child(config.OUTPUT_FOLDER, report_filename)
    if report_path is None or not os.path.exists(report_path):
        flash('报告文件不存在', 'danger')
        return redirect(url_for('index'))

    try:
        report_data = _read_json(report_path)

        for video in report_data.get('videos', []):
            stem = os.path.splitext(video.get('video_name', ''))[0]
            # An empty or path-like name would otherwise resolve to the output
            # folder itself (or beyond) and be removed recursively.
            video_output_folder = _safe_child(config.OUTPUT_FOLDER, stem)
            if video_output_folder is not None and os.path.isdir(video_output_folder):
                shutil.rmtree(video_output_folder)

        os.remove(report_path)
        flash(f'报告 {report_filename} 已删除', 'success')
    except Exception as e:
        flash(f'删除报告失败: {str(e)}', 'danger')

    return redirect(url_for('index'))


# ---------------------------------------------------------------------------
# Progress support patched onto BatchVideoDetector
# ---------------------------------------------------------------------------
def add_progress_callback_to_detector():
    """Give BatchVideoDetector progress reporting if it does not have it yet.

    Adds ``set_progress_callback``/``update_progress`` and replaces
    ``process_all_videos``/``save_final_report`` with versions that report
    progress and return the report path.  Does nothing when the class already
    exposes ``set_progress_callback``, which also makes repeated calls safe.
    NOTE(review): whether the two method replacements were originally inside
    this guard could not be determined from the collapsed source — confirm.
    """
    if hasattr(BatchVideoDetector, 'set_progress_callback'):
        return

    def set_progress_callback(self, callback):
        """Register the callable that receives progress updates."""
        self.progress_callback = callback

    def update_progress(self, status, progress, video_name=None, video_index=None, frame_count=None):
        """Forward a progress update to the registered callback, if any."""
        callback = getattr(self, 'progress_callback', None)
        if callback:
            callback(status, progress, video_name, video_index, frame_count)

    def process_all_videos_with_progress(self, input_folder="input_videos", output_folder="output_frames"):
        """Process every video in *input_folder*; return the report path or None."""
        self.update_progress("获取视频文件列表", 0)
        video_files = self.get_video_files(input_folder)
        video_total = len(video_files)

        if video_total == 0:
            self.update_progress("没有找到视频文件", 100)
            return None

        started_at = datetime.now()
        self.update_progress(f"开始处理 {video_total} 个视频文件", 5)

        for position, video_path in enumerate(video_files, 1):
            # Progress reflects the videos already finished, so the bar does
            # not jump to 95% before the last video has even started.
            progress = int(5 + ((position - 1) / video_total) * 90)
            video_name = video_path.name
            self.update_progress(f"处理视频 {position}/{video_total}: {video_name}",
                                 progress, video_name, position)
            self.process_video(video_path, output_folder)

        self.update_progress("生成检测报告", 95)
        # Fill in whatever process_video did not record itself, so the report
        # has a real processing time and video count.
        if self.stats.get('start_time') is None:
            self.stats['start_time'] = started_at
        if not self.stats.get('total_videos'):
            self.stats['total_videos'] = video_total
        self.stats['end_time'] = datetime.now()
        report_path = self.save_final_report(output_folder)
        self.print_summary()
        self.update_progress("检测完成", 100)
        return report_path

    def save_final_report_with_return(self, output_folder):
        """Write the batch report JSON into *output_folder* and return its path."""
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        report_dir = Path(output_folder)
        report_dir.mkdir(parents=True, exist_ok=True)
        report_path = report_dir / f"batch_detection_report_{timestamp}.json"

        total_seconds = 0
        if self.stats['start_time'] and self.stats['end_time']:
            total_seconds = (self.stats['end_time'] - self.stats['start_time']).total_seconds()

        report_data = {
            'timestamp': timestamp,
            'model_path': self.model_path,
            'confidence_threshold': self.confidence,
            'frame_interval_threshold': config.MIN_FRAME_INTERVAL,
            'filter_strategy': '相同场景在指定时间间隔内只保存一次',
            'statistics': {
                'total_videos': self.stats['total_videos'],
                'processed_videos': self.stats['processed_videos'],
                'total_frames': self.stats['total_frames'],
                'detected_frames': self.stats['detected_frames'],
                'filtered_frames': self.stats['filtered_frames'],
                'saved_images': self.stats['saved_images'],
                'processing_time_seconds': total_seconds
            },
            'errors': self.stats['errors'],
            # NOTE(review): always empty, so the per-video pages never find a
            # video.  The detector keeps a `video_details` list; confirm it is
            # JSON-serialisable before writing it here.
            'videos': []
        }

        with open(report_path, 'w', encoding='utf-8') as f:
            json.dump(report_data, f, ensure_ascii=False, indent=2)

        print(f"\n批处理报告已保存: {report_path}")
        return str(report_path)

    BatchVideoDetector.set_progress_callback = set_progress_callback
    BatchVideoDetector.update_progress = update_progress
    BatchVideoDetector.process_all_videos = process_all_videos_with_progress
    BatchVideoDetector.save_final_report = save_final_report_with_return


# Patch the detector class before the first request can start a run.
add_progress_callback_to_detector()

if __name__ == '__main__':
    _ensure_folders()
    # The Werkzeug debugger allows arbitrary code execution and the server
    # listens on all interfaces, so debug mode must be requested explicitly.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    app.run(debug=debug_enabled, host='0.0.0.0', port=5000)
diff --git a/batch_detector.py b/batch_detector.py new file mode 100644 index 0000000..af37826 --- /dev/null +++ b/batch_detector.py @@ -0,0 +1,642 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +批量视频检测器 +简化版本,专注于高效处理大量视频文件 +""" + +import os +import cv2 +import numpy as np +from pathlib import Path +from ultralytics import YOLO +import json +import time +from datetime import datetime +import torch +from config import config, DetectionConfig +from PIL import Image, ImageDraw, ImageFont + +class BatchVideoDetector: + def __init__(self, model_path=None, confidence=0.5): + """ + 初始化批量视频检测器 + + Args: + model_path (str): 模型路径,如果为None则使用配置文件中的当前模型路径 + confidence (float): 置信度阈值 + """ + # 优先使用传入的模型路径,如果没有则使用配置中的当前模型路径 + self.model_path = model_path or config.MODEL_PATH + self.confidence = confidence + + # 创建必要文件夹 + config.create_folders() + + # 加载模型 + self.load_model() + + # 时间间隔过滤变量 + self.last_saved_timestamp = 0 + + # 初始化场景类别变量 + self.last_scene_classes = set() + + # 动态检测参数 + self.adaptive_confidence = True + self.min_confidence = 0.2 + self.scene_confidence_history = [] # 记录场景置信度历史 + self.detection_quality_threshold = 0.3 # 检测质量阈值 + + # 统计信息 + self.stats = { + 'start_time': None, + 'end_time': None, + 'total_videos': 0, + 'processed_videos': 0, + 'total_frames': 0, + 'detected_frames': 0, + 'saved_images': 0, + 'filtered_frames': 0, + 'adaptive_detections': 0, # 自适应检测次数 + 'errors': [] + } + + # 视频详细信息列表 + self.video_details = [] + + def load_model(self): + """加载YOLO模型""" + try: + print(f"正在加载模型: {self.model_path}") + self.model = YOLO(self.model_path) + print(f"模型加载成功!") + + # 检查是否有缓存的模型信息 + model_filename = os.path.basename(self.model_path) + info_path = os.path.join(os.path.dirname(self.model_path), f"{os.path.splitext(model_filename)[0]}_info.json") + + cached_info_loaded = False + if os.path.exists(info_path): + try: + with open(info_path, 'r', encoding='utf-8') as f: + import json + cached_info = json.load(f) + if cached_info.get('analyzed', False): + 
print(f"使用缓存的模型信息,跳过重复分析") + print(f"模型类别数量: {cached_info.get('class_count', 0)}") + print(f"检测场景类型: {list(cached_info.get('classes', {}).values())}") + cached_info_loaded = True + except Exception as e: + print(f"读取缓存模型信息失败: {e},将重新分析") + + # 如果没有缓存信息,则进行完整分析 + if not cached_info_loaded: + # 动态加载模型类别信息到配置中 + DetectionConfig.load_model_classes(self.model_path) + + if hasattr(self.model, 'names'): + print(f"\n=== 模型检测场景分析 ===") + print(f"模型类别数量: {len(self.model.names)}") + print(f"模型原始类别: {list(self.model.names.values())}") + + print(f"\n=== 中文类别映射 ===") + for class_id, class_name in self.model.names.items(): + class_name_cn = config.get_class_name_cn(class_id) + print(f"类别 {class_id}: {class_name} -> {class_name_cn}") + + # 分析模型类型并优化参数 + self.analyze_and_optimize_model() + + print(f"\n=== 检测场景总结 ===") + scene_types = list(DetectionConfig.get_all_classes().values()) + print(f"本模型可检测的场景类型: {scene_types}") + print(f"检测置信度阈值: {self.confidence}") + print("=" * 40) + else: + # 使用缓存信息更新配置 + DetectionConfig.load_model_classes(self.model_path) + + except Exception as e: + print(f"模型加载失败: {e}") + raise + + def analyze_and_optimize_model(self): + """分析模型类型并优化检测参数""" + model_classes = list(self.model.names.values()) + + # 检测模型类型 + road_damage_keywords = ['crack', 'pothole', 'damage', '裂缝', '坑洞', '损伤'] + general_object_keywords = ['person', 'car', 'truck', 'bus', 'bicycle'] + + is_road_damage = any(keyword in ' '.join(model_classes).lower() for keyword in road_damage_keywords) + is_general_object = any(keyword in ' '.join(model_classes).lower() for keyword in general_object_keywords) + + if is_road_damage: + print("检测到道路损伤专用模型,优化参数设置") + self.confidence = max(0.25, self.confidence) + self.min_confidence = 0.15 + self.detection_quality_threshold = 0.2 + elif is_general_object: + print("检测到通用目标检测模型,使用标准参数") + self.confidence = max(0.4, self.confidence) + self.min_confidence = 0.3 + self.detection_quality_threshold = 0.35 + else: + print("检测到自定义模型,使用保守参数") + self.confidence = 
max(0.3, self.confidence) + self.min_confidence = 0.2 + self.detection_quality_threshold = 0.25 + + print(f"优化后的置信度阈值: {self.confidence}") + print(f"最低置信度阈值: {self.min_confidence}") + print(f"检测质量阈值: {self.detection_quality_threshold}") + + def get_video_files(self, input_folder="input_videos"): + """获取所有视频文件""" + input_path = Path(input_folder) + if not input_path.exists(): + print(f"输入文件夹不存在: {input_folder}") + return [] + + video_files = [] + for ext in config.SUPPORTED_VIDEO_FORMATS: + video_files.extend(input_path.rglob(f"*{ext}")) + + print(f"找到 {len(video_files)} 个视频文件") + return video_files + + def detect_and_save_frame(self, frame, frame_number, timestamp, video_name, output_dir): + """检测单帧并保存结果""" + try: + # 进行多级检测 + result = self.adaptive_detect_frame(frame) + + # 如果检测到目标 + if result is not None and result.boxes is not None and len(result.boxes) > 0: + # 获取当前帧的检测类别 + current_scene_classes = set() + for box in result.boxes: + class_id = int(box.cls[0].cpu().numpy()) + current_scene_classes.add(class_id) + + # 检查时间间隔 + time_diff = timestamp - self.last_saved_timestamp + + # 如果时间间隔小于最小间隔,检查是否为新场景 + if time_diff < config.MIN_FRAME_INTERVAL and self.last_saved_timestamp > 0: + # 如果当前帧的场景类别与上一帧相同,则过滤 + if hasattr(self, 'last_scene_classes') and current_scene_classes == self.last_scene_classes: + self.stats['filtered_frames'] += 1 + # 仍然返回检测信息,但不保存图像 + detections = [] + for box in result.boxes: + class_id = int(box.cls[0].cpu().numpy()) + confidence = float(box.conf[0].cpu().numpy()) + bbox = box.xyxy[0].cpu().numpy().tolist() + class_name = self.model.names.get(class_id, f"class_{class_id}") + class_name_cn = config.get_class_name_cn(class_id) + + detections.append({ + 'class_id': class_id, + 'class_name': class_name, + 'class_name_cn': class_name_cn, + 'confidence': confidence, + 'bbox': bbox, + 'filtered': True + }) + + return { + 'frame_number': frame_number, + 'timestamp': timestamp, + 'detections': detections, + 'filtered': True + } + + # 更新最后检测到的场景类别 + 
self.last_scene_classes = current_scene_classes + + # 更新最后保存的时间戳 + self.last_saved_timestamp = timestamp + + # 创建文件名 + base_name = f"{video_name}_frame_{frame_number:06d}_t{timestamp:.2f}s" + + # 获取检测到的场景类别,用于分类保存 + detected_scenes = set() + for box in result.boxes: + class_id = int(box.cls[0].cpu().numpy()) + class_name_cn = config.get_class_name_cn(class_id) + detected_scenes.add(class_name_cn) + + # 为每个检测到的场景创建对应的文件夹并保存图片 + for scene_name in detected_scenes: + # 创建视频名称文件夹 + video_folder = output_dir / video_name + video_folder.mkdir(parents=True, exist_ok=True) + + # 创建场景文件夹 + scene_folder = video_folder / scene_name + scene_folder.mkdir(parents=True, exist_ok=True) + + # 保存原始帧 + if config.SAVE_ORIGINAL_FRAMES: + original_path = scene_folder / f"{base_name}_original.jpg" + cv2.imwrite(str(original_path), frame, [cv2.IMWRITE_JPEG_QUALITY, config.IMAGE_QUALITY]) + + # 绘制并保存标注帧 + if config.SAVE_ANNOTATED_FRAMES: + annotated_frame = self.draw_detections(frame.copy(), result) + annotated_path = scene_folder / f"{base_name}_detected.jpg" + cv2.imwrite(str(annotated_path), annotated_frame, [cv2.IMWRITE_JPEG_QUALITY, config.IMAGE_QUALITY]) + + # 收集检测信息 + detections = [] + for box in result.boxes: + class_id = int(box.cls[0].cpu().numpy()) + confidence = float(box.conf[0].cpu().numpy()) + bbox = box.xyxy[0].cpu().numpy().tolist() + class_name = self.model.names.get(class_id, f"class_{class_id}") + class_name_cn = config.get_class_name_cn(class_id) + + detections.append({ + 'class_id': class_id, + 'class_name': class_name, + 'class_name_cn': class_name_cn, + 'confidence': confidence, + 'bbox': bbox + }) + + return { + 'frame_number': frame_number, + 'timestamp': timestamp, + 'detections': detections + } + + return None + + except Exception as e: + print(f"处理帧 {frame_number} 时出错: {e}") + self.stats['errors'].append(f"帧 {frame_number}: {str(e)}") + return None + + def adaptive_detect_frame(self, frame): + """自适应检测单帧""" + try: + # 第一次检测:使用标准置信度 + results = self.model(frame, 
def draw_detections(self, frame, result):
    """Draw a bounding box and a text label for every detection on a frame.

    The frame is converted to a PIL image so that non-ASCII class names can
    be rendered with a TrueType font, and converted back to an OpenCV BGR
    array before returning.

    Args:
        frame: BGR image array as used by OpenCV.
        result: Detection result exposing ``boxes``; every box provides
            ``xyxy``, ``cls`` and ``conf``. ``None`` (or ``boxes is None``)
            means there is nothing to draw.

    Returns:
        The annotated BGR frame, or ``frame`` itself when there is nothing
        to draw.
    """
    if result is None or result.boxes is None:
        return frame

    # PIL works in RGB; OpenCV frames are BGR.
    pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(pil_image)

    # Try the usual Windows CJK fonts in order and fall back to PIL's
    # built-in font when none of them can be opened.
    font = None
    for font_path in (
        "C:/Windows/Fonts/simhei.ttf",  # SimHei
        "C:/Windows/Fonts/simsun.ttc",  # SimSun
        "C:/Windows/Fonts/msyh.ttc",    # Microsoft YaHei
    ):
        try:
            font = ImageFont.truetype(font_path, 20)
            break
        except OSError:
            # Missing or unreadable font file: try the next candidate.
            continue
    if font is None:
        font = ImageFont.load_default()

    for box in result.boxes:
        # Box corners, class id and confidence of this detection.
        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int).tolist()
        class_id = int(box.cls[0].cpu().numpy())
        confidence = float(box.conf[0].cpu().numpy())

        # Prefer the model's own class name; add the configured name only
        # when it actually differs from it.
        if hasattr(self.model, 'names') and class_id in self.model.names:
            original_name = self.model.names[class_id]
            class_name_cn = config.get_class_name_cn(class_id)
            if class_name_cn == original_name or class_name_cn.startswith('未知类别'):
                display_name = original_name
            else:
                display_name = f"{class_name_cn}({original_name})"
        else:
            display_name = config.get_class_name_cn(class_id)

        # Higher confidence -> brighter colour and thicker outline.
        base_color = config.get_class_color(class_id)
        intensity = min(1.0, confidence + 0.3)
        color = tuple(int(c * intensity) for c in base_color)
        bg_color = tuple(reversed(color))  # channel order flipped for PIL
        line_width = max(2, int(confidence * 4))

        draw.rectangle([x1, y1, x2, y2], outline=bg_color, width=line_width)

        label_text = f"{display_name}: {confidence:.3f}"

        # Measure the label so its background can be sized.
        bbox = draw.textbbox((0, 0), label_text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]

        # Keep the label inside the image: normally it sits above the box,
        # but when there is no room it moves just below the box's top edge.
        # The test must use the unclamped value, otherwise the fallback can
        # never trigger.
        label_x = max(0, x1)
        label_y = y1 - text_height - 5
        if label_y < 0:
            label_y = max(0, y1 + 5)

        draw.rectangle(
            [label_x, label_y, label_x + text_width + 4, label_y + text_height + 4],
            fill=bg_color
        )
        draw.text((label_x + 2, label_y + 2), label_text, fill=(255, 255, 255), font=font)

    # Back to OpenCV's BGR layout.
    frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
    return frame
+ + print(f"视频信息: {total_frames} 帧, {fps:.2f} FPS") + + # 处理统计 + frame_count = 0 + detected_frames = 0 + saved_images = 0 + filtered_frames = 0 + all_detections = [] + + start_time = time.time() + + try: + while True: + ret, frame = cap.read() + if not ret: + break + + frame_count += 1 + timestamp = frame_count / fps if fps > 0 else frame_count + + # 显示进度 + if frame_count % config.PROGRESS_INTERVAL == 0: + progress = (frame_count / total_frames) * 100 + elapsed = time.time() - start_time + fps_current = frame_count / elapsed + print(f"进度: {progress:.1f}% ({frame_count}/{total_frames}), 处理速度: {fps_current:.1f} FPS") + + # 更新进度回调 + if hasattr(self, 'update_progress'): + self.update_progress(f"处理视频帧", None, None, None, frame_count) + + # 检测并保存 + detection_result = self.detect_and_save_frame( + frame, frame_count, timestamp, video_path.stem, output_dir + ) + + if detection_result: + detected_frames += 1 + all_detections.append(detection_result) + + # 检查是否是被过滤的帧 + if detection_result.get('filtered', False): + filtered_frames += 1 + else: + # 计算保存的图片数量 + if config.SAVE_ORIGINAL_FRAMES: + saved_images += 1 + if config.SAVE_ANNOTATED_FRAMES: + saved_images += 1 + + finally: + cap.release() + + # 保存检测结果JSON + if config.SAVE_DETECTION_JSON and all_detections: + json_path = output_dir / f"{video_path.stem}_detections.json" + detection_summary = { + 'video_name': video_path.name, + 'total_frames': frame_count, + 'detected_frames': detected_frames, + 'fps': fps, + 'detections': all_detections, + 'processing_time': time.time() - start_time + } + + with open(json_path, 'w', encoding='utf-8') as f: + json.dump(detection_summary, f, ensure_ascii=False, indent=2) + + # 更新统计 + self.stats['processed_videos'] += 1 + self.stats['total_frames'] += frame_count + self.stats['detected_frames'] += detected_frames + self.stats['saved_images'] += saved_images + self.stats['filtered_frames'] += filtered_frames + + processing_time = time.time() - start_time + print(f"处理完成: {video_path.name}") + 
def process_all_videos(self, input_folder="input_videos", output_folder="output_frames"):
    """Process every video found in ``input_folder`` and write a batch report.

    Records the start and end time in ``self.stats``, runs
    ``self.process_video`` on each file returned by
    ``self.get_video_files``, then saves the final report and prints the
    summary.

    Args:
        input_folder: Folder handed to ``self.get_video_files``.
        output_folder: Folder handed to ``self.process_video`` and
            ``self.save_final_report``.

    Returns:
        Whatever ``self.save_final_report`` returns, or ``None`` when no
        video files were found.
    """
    stats = self.stats
    stats['start_time'] = datetime.now()

    videos = self.get_video_files(input_folder)
    if not videos:
        # Nothing to do: no end time is recorded and no report is written.
        print("没有找到视频文件")
        return None

    total = len(videos)
    stats['total_videos'] = total

    # Banner describing the run.
    print(f"\n开始批量处理 {total} 个视频文件")
    print(f"模型: {self.model_path}")
    print(f"置信度阈值: {self.confidence}")
    print(f"输出目录: {output_folder}")
    print("=" * 50)

    for position, clip in enumerate(videos, start=1):
        print(f"\n[{position}/{total}] 处理视频")
        self.process_video(clip, output_folder)

    stats['end_time'] = datetime.now()
    saved_report = self.save_final_report(output_folder)
    self.print_summary()
    return saved_report
def print_summary(self):
    """Print a recap of the finished batch run to stdout.

    Nothing is printed unless both ``self.stats['start_time']`` and
    ``self.stats['end_time']`` are set. Otherwise the counters kept in
    ``self.stats`` are printed together with the elapsed wall-clock time
    and any recorded error messages.
    """
    stats = self.stats
    if not stats['end_time'] or not stats['start_time']:
        return

    # Split the elapsed seconds into hours / minutes / seconds.
    elapsed = (stats['end_time'] - stats['start_time']).total_seconds()
    hrs, leftover = divmod(elapsed, 3600)
    mins, secs = divmod(leftover, 60)

    rule = "=" * 50

    print("\n" + rule)
    print("批量处理摘要")
    print(rule)
    print(f"处理视频数量: {stats['processed_videos']}/{stats['total_videos']}")
    print(f"总处理帧数: {stats['total_frames']}")
    print(f"检测到目标的帧数: {stats['detected_frames']}")
    print(f"时间间隔过滤帧数: {stats['filtered_frames']} (间隔阈值: {config.MIN_FRAME_INTERVAL}秒)")
    print(f"过滤策略: 相同场景在{config.MIN_FRAME_INTERVAL}秒内只保存一次")
    print(f"实际保存的图片数量: {stats['saved_images']}")
    print(f"总处理时间: {int(hrs)}小时 {int(mins)}分钟 {secs:.2f}秒")

    # Error messages collected while processing, one per line.
    failures = stats['errors']
    if failures:
        print("\n处理过程中的错误:")
        for message in failures:
            print(f"- {message}")

    print(rule)
class DetectionConfig:
    """Configuration for the video detection system.

    All settings are class attributes and every helper is a classmethod, so
    the class can be used directly or through the module-level ``config``
    instance. Class names and colours can be replaced at runtime by
    ``load_model_classes``; until that succeeds the built-in defaults are
    used.
    """

    # Model settings
    MODEL_PATH = "models/best.pt"
    CONFIDENCE_THRESHOLD = 0.5

    # Class information loaded from a model (None until load_model_classes succeeds)
    _model_classes = None
    _model_colors = None

    # Folder settings
    INPUT_FOLDER = "input_videos"
    OUTPUT_FOLDER = "output_frames"
    LOG_FOLDER = "logs"

    # Video processing settings
    SUPPORTED_VIDEO_FORMATS = {'.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm'}
    PROGRESS_INTERVAL = 100  # print progress every this many frames
    MIN_FRAME_INTERVAL = 1.0  # minimum time between saved detection frames (seconds)

    # Output settings
    SAVE_ORIGINAL_FRAMES = False  # save the untouched frame
    SAVE_ANNOTATED_FRAMES = True  # save the frame with drawn detections
    SAVE_DETECTION_JSON = True  # save detection results as JSON

    # Image quality
    IMAGE_QUALITY = 95  # JPEG quality (1-100)

    # Default road-damage class names (Chinese labels), used when no model
    # class information has been loaded
    DEFAULT_CLASS_NAMES_CN = {
        0: "纵向裂缝",
        1: "横向裂缝",
        2: "网状裂缝",
        3: "坑洞",
        4: "白线模糊"
    }

    # English keyword -> Chinese label table.
    # NOTE(review): nothing in this class reads this table; confirm whether
    # other modules use it.
    COMMON_CLASS_MAPPINGS = {
        # Road damage
        'crack': '裂缝',
        'pothole': '坑洞',
        'damage': '损伤',
        'longitudinal': '纵向裂缝',
        'transverse': '横向裂缝',
        'alligator': '网状裂缝',
        'rutting': '车辙',
        'bleeding': '泛油',
        'weathering': '风化',
        'patching': '修补',

        # People and vehicles
        'person': '人',
        'people': '人群',
        'car': '汽车',
        'truck': '卡车',
        'bus': '公交车',
        'motorcycle': '摩托车',
        'bicycle': '自行车',
        'bike': '自行车',
        'vehicle': '车辆',

        # Traffic facilities
        'road': '道路',
        'traffic': '交通',
        'sign': '标志',
        'light': '信号灯',
        'signal': '信号',
        'stop': '停车标志',
        'yield': '让行标志',
        'crosswalk': '人行横道',
        'lane': '车道',
        'marking': '标线',
        'white': '白线',
        'yellow': '黄线',

        # Other common objects
        'animal': '动物',
        'dog': '狗',
        'cat': '猫',
        'bird': '鸟',
        'tree': '树',
        'building': '建筑',
        'pole': '杆子',
        'fence': '围栏'
    }

    # Palette cycled through when colours are assigned to classes.
    # NOTE(review): the colour names below describe the tuples read as RGB,
    # while the original comment on DEFAULT_CLASS_COLORS called the same
    # tuples BGR; confirm which channel order the drawing code expects.
    COLOR_PALETTE = [
        (0, 255, 0),      # green
        (255, 0, 0),      # red
        (0, 0, 255),      # blue
        (255, 255, 0),    # yellow
        (255, 0, 255),    # magenta
        (0, 255, 255),    # cyan
        (255, 128, 0),    # orange
        (128, 0, 255),    # violet
        (255, 192, 203),  # pink
        (128, 128, 128)   # gray
    ]

    # Default colour per class id (same tuples as the first five palette entries)
    DEFAULT_CLASS_COLORS = {
        0: (0, 255, 0),
        1: (255, 0, 0),
        2: (0, 0, 255),
        3: (255, 255, 0),
        4: (255, 0, 255)
    }

    @classmethod
    def get_model_path(cls):
        """Return the absolute path of ``MODEL_PATH`` relative to this file."""
        current_dir = Path(__file__).parent
        model_path = current_dir / cls.MODEL_PATH
        return str(model_path.resolve())

    @classmethod
    def create_folders(cls):
        """Create the input, output and log folders if they do not exist."""
        folders = [cls.INPUT_FOLDER, cls.OUTPUT_FOLDER, cls.LOG_FOLDER]
        for folder in folders:
            Path(folder).mkdir(parents=True, exist_ok=True)

    @classmethod
    def load_model_classes(cls, model_path):
        """Load class names from a model file and assign a colour to each.

        Args:
            model_path: Path of the model file passed to ``YOLO``.

        Returns:
            ``True`` on success. ``False`` when loading failed; in that
            case previously stored class information is left untouched.
        """
        try:
            model = YOLO(model_path)
            names = model.names
            if not isinstance(names, dict):
                # A plain sequence of names: index by position.
                names = dict(enumerate(names))

            # Key the colours by the model's real class ids instead of
            # assuming the ids are exactly 0..n-1.
            palette = cls.COLOR_PALETTE
            colors = {
                class_id: palette[position % len(palette)]
                for position, class_id in enumerate(names)
            }

            cls._model_classes = names
            cls._model_colors = colors

            print(f"成功加载 {len(names)} 个类别,分配了对应颜色")
            return True
        except Exception as e:
            print(f"加载模型类别信息失败: {e}")
            return False

    @classmethod
    def get_class_name_cn(cls, class_id):
        """Return the display name for a class id.

        When model class information is loaded, the model's own name is
        returned unchanged (no translation is applied). Otherwise the
        default Chinese label is used. Unknown ids yield ``"未知类别_<id>"``.
        """
        if cls._model_classes is not None:
            original_name = cls._model_classes.get(class_id)
            if original_name:
                return original_name
            return f"未知类别_{class_id}"
        # Fall back to the default Chinese labels
        return cls.DEFAULT_CLASS_NAMES_CN.get(class_id, f"未知类别_{class_id}")

    @classmethod
    def get_class_color(cls, class_id):
        """Return the colour tuple for a class id or class name.

        Args:
            class_id: An integer-like class id, or a string. A string is
                resolved, in order, as a loaded class name, as a numeric
                id, and finally through a deterministic CRC32 of its UTF-8
                bytes, so the same name gets the same colour in every run.

        Returns:
            A 3-tuple from the loaded colours, the defaults, or the palette.
            Values that cannot be interpreted as an integer use class 0.
        """
        import operator
        import zlib

        palette = cls.COLOR_PALETTE

        if isinstance(class_id, str):
            resolved = None
            if cls._model_classes:
                # Reverse lookup: class name -> class id
                for id_val, name in cls._model_classes.items():
                    if name == class_id:
                        resolved = id_val
                        break
            if resolved is None:
                try:
                    resolved = int(class_id)
                except ValueError:
                    # The built-in hash() of str is salted per process, so
                    # it cannot give a stable index; CRC32 can.
                    resolved = zlib.crc32(class_id.encode("utf-8")) % len(palette)
            class_id = resolved

        # Accept any integer-like value (e.g. NumPy integers), not only int.
        try:
            class_id = operator.index(class_id)
        except TypeError:
            class_id = 0

        fallback = palette[class_id % len(palette)]
        if cls._model_colors is not None:
            return cls._model_colors.get(class_id, fallback)
        return cls.DEFAULT_CLASS_COLORS.get(class_id, fallback)

    @classmethod
    def get_all_classes(cls):
        """Return the loaded class mapping, or the default Chinese labels."""
        if cls._model_classes is not None:
            return cls._model_classes
        return cls.DEFAULT_CLASS_NAMES_CN


# Global configuration instance
config = DetectionConfig()
def analyze_model_type(self):
    """Pick detection thresholds based on the model's class names.

    A model counts as a road-damage model when any of a fixed list of
    keywords occurs in its lower-cased, space-joined class names.
    ``self.confidence_threshold`` is then raised to at least 0.25 (0.4 for
    any other model) and is never lowered. ``self.min_confidence`` is set
    to 0.15 (0.3 for any other model).
    """
    damage_markers = ('crack', 'pothole', 'damage', '裂缝', '坑洞', '损伤')
    joined_names = ' '.join(list(self.model.names.values())).lower()

    # Does any road-damage keyword occur in the class names?
    looks_like_road_damage = False
    for marker in damage_markers:
        if marker in joined_names:
            looks_like_road_damage = True
            break

    if looks_like_road_damage:
        notice, threshold_floor, retry_floor = "检测到道路损伤专用模型,使用优化参数", 0.25, 0.15
    else:
        notice, threshold_floor, retry_floor = "检测到通用模型,使用标准参数", 0.4, 0.3

    self.logger.info(notice)
    # max() makes the constant a lower bound for the configured threshold.
    self.confidence_threshold = max(threshold_floor, self.confidence_threshold)
    self.min_confidence = retry_floor

    self.logger.info(f"调整后的置信度阈值: {self.confidence_threshold}")
def draw_detections(self, frame, result):
    """Draw a bounding box and a text label for every detection on a frame.

    The frame is converted to a PIL image so that non-ASCII class names can
    be rendered with a TrueType font, and converted back to an OpenCV BGR
    array before returning.

    Args:
        frame: BGR image array as used by OpenCV.
        result: Detection result exposing ``boxes``; every box provides
            ``xyxy``, ``conf`` and ``cls``. ``None`` (or ``boxes is None``)
            means there is nothing to draw.

    Returns:
        The annotated BGR frame, or ``frame`` itself when there is nothing
        to draw.
    """
    if result is None or result.boxes is None:
        return frame

    # PIL works in RGB; OpenCV frames are BGR.
    pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(pil_image)

    # Try the usual Windows CJK fonts in order and fall back to PIL's
    # built-in font when none of them can be opened. The ".ttc" spelling of
    # Microsoft YaHei is the one the batch detector looks for.
    font = None
    for font_path in (
        "C:/Windows/Fonts/simhei.ttf",  # SimHei
        "C:/Windows/Fonts/msyh.ttf",    # Microsoft YaHei
        "C:/Windows/Fonts/msyh.ttc",    # Microsoft YaHei (collection file)
        "C:/Windows/Fonts/simsun.ttc",  # SimSun
    ):
        try:
            font = ImageFont.truetype(font_path, 20)
            break
        except OSError:
            # Missing or unreadable font file: try the next candidate.
            continue
    if font is None:
        font = ImageFont.load_default()

    for box in result.boxes:
        # Box corners, confidence and class id of this detection.
        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int).tolist()
        confidence = float(box.conf[0].cpu().numpy())
        class_id = int(box.cls[0].cpu().numpy())

        display_name = config.get_class_name_cn(class_id)

        # Higher confidence -> brighter colour (never too dark) and a
        # thicker outline.
        base_color = config.get_class_color(class_id)
        intensity = min(1.0, confidence + 0.3)
        color = tuple(int(c * intensity) for c in base_color)
        bg_color = tuple(reversed(color))  # channel order flipped for PIL
        line_width = max(2, int(confidence * 4))
        draw.rectangle([x1, y1, x2, y2], outline=bg_color, width=line_width)

        label = f"{display_name}: {confidence:.3f}"

        # Measure the label so its background can be sized.
        bbox = draw.textbbox((0, 0), label, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]

        # The label normally sits directly above the box. For a box that
        # touches the top of the image that position is off-image, so the
        # label is moved just inside the box instead.
        label_top = y1 - text_height - 10
        if label_top < 0:
            label_top = max(0, y1)

        draw.rectangle(
            [x1, label_top, x1 + text_width + 10, label_top + text_height + 10],
            fill=bg_color
        )
        draw.text((x1 + 5, label_top + 5), label, font=font, fill=(255, 255, 255))

    # Back to OpenCV's BGR layout.
    annotated_frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
    return annotated_frame
{video_path.name}") + + # 创建视频专用输出文件夹 + video_output_folder = self.output_folder / video_path.stem + video_output_folder.mkdir(exist_ok=True) + + # 文件夹将在检测到对应类别时按需创建 + + # 打开视频 + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + self.logger.error(f"无法打开视频文件: {video_path}") + return + + # 获取视频信息 + fps = cap.get(cv2.CAP_PROP_FPS) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + duration = total_frames / fps if fps > 0 else 0 + + self.logger.info(f"视频信息 - 帧率: {fps:.2f}, 总帧数: {total_frames}, 时长: {duration:.2f}秒") + + frame_count = 0 + detected_count = 0 + saved_count = 0 + + video_detections = { + 'video_name': video_path.name, + 'total_frames': total_frames, + 'fps': fps, + 'duration': duration, + 'detections': [] + } + + try: + while True: + ret, frame = cap.read() + if not ret: + break + + frame_count += 1 + + # 每100帧显示一次进度 + if frame_count % 100 == 0: + progress = (frame_count / total_frames) * 100 + self.logger.info(f"处理进度: {progress:.1f}% ({frame_count}/{total_frames})") + + # 进行目标检测 + result = self.detect_frame(frame) + + # 如果检测到目标,保存帧 + if result is not None and result.boxes is not None and len(result.boxes) > 0: + detected_count += 1 + timestamp = frame_count / fps if fps > 0 else frame_count + + # 获取当前帧检测到的所有类别 + detected_classes = set() + for box in result.boxes: + class_id = int(box.cls[0].cpu().numpy()) + detected_classes.add(class_id) + + # 只为检测到的类别创建文件夹并保存图片 + annotated_frame = self.draw_detections(frame, result) + detection_info = { + 'frame_number': frame_count, + 'timestamp': timestamp, + 'detections': [] + } + + # 按检测到的类别保存图片 + for class_id in detected_classes: + class_name_cn = config.get_class_name_cn(class_id) + class_name_en = self.model.names[class_id] if hasattr(self.model, 'names') else str(class_id) + + # 只为实际检测到的类别创建文件夹 + class_folder = video_output_folder / class_name_cn + class_folder.mkdir(exist_ok=True) + + # 保存检测到该类别的帧(只保存标注后的图片) + annotated_filename = 
f"detected_{frame_count:06d}_t{timestamp:.2f}s_{class_name_cn}.jpg" + annotated_path = class_folder / annotated_filename + cv2.imwrite(str(annotated_path), annotated_frame) + saved_count += 1 + + # 记录检测信息 + for box in result.boxes: + confidence = float(box.conf[0].cpu().numpy()) + class_id = int(box.cls[0].cpu().numpy()) + class_name = self.model.names[class_id] if hasattr(self.model, 'names') else str(class_id) + bbox = box.xyxy[0].cpu().numpy().tolist() + + detection_info['detections'].append({ + 'class_name': class_name, + 'class_name_cn': config.get_class_name_cn(class_id), + 'class_id': class_id, + 'confidence': confidence, + 'bbox': bbox + }) + + video_detections['detections'].append(detection_info) + + finally: + cap.release() + + # 保存检测结果到JSON文件 + json_path = video_output_folder / f"{video_path.stem}_detections.json" + with open(json_path, 'w', encoding='utf-8') as f: + json.dump(video_detections, f, ensure_ascii=False, indent=2) + + self.logger.info(f"视频处理完成: {video_path.name}") + self.logger.info(f"总帧数: {frame_count}, 检测到目标的帧数: {detected_count}, 保存图片数: {saved_count}") + + # 更新统计信息 + self.detection_stats['total_videos'] += 1 + self.detection_stats['total_frames'] += frame_count + self.detection_stats['detected_frames'] += detected_count + self.detection_stats['saved_frames'] += saved_count + self.detection_stats['detection_results'].append(video_detections) + + def process_all_videos(self): + """处理所有视频文件""" + video_files = self.get_video_files() + + if not video_files: + self.logger.warning("没有找到视频文件") + return + + self.logger.info(f"开始处理 {len(video_files)} 个视频文件") + + for i, video_path in enumerate(video_files, 1): + self.logger.info(f"\n=== 处理第 {i}/{len(video_files)} 个视频 ===") + self.process_video(video_path) + + # 保存总体统计信息 + self.save_summary_report() + + self.logger.info("\n=== 处理完成 ===") + self.logger.info(f"总计处理视频: {self.detection_stats['total_videos']} 个") + self.logger.info(f"总计处理帧数: {self.detection_stats['total_frames']} 帧") + 
def save_summary_report(self):
    """Write the overall run summary as a JSON file in the output folder.

    The file is named ``detection_summary_<timestamp>.json`` and contains
    the timestamp, the model path, the confidence threshold, the input and
    output folders, and ``self.detection_stats``. The saved path is logged
    through ``self.logger``.
    """
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    destination = self.output_folder / f"detection_summary_{stamp}.json"

    # Paths are stored as strings so the document is JSON-serialisable.
    payload = dict([
        ('timestamp', stamp),
        ('model_path', str(self.model_path)),
        ('confidence_threshold', self.confidence_threshold),
        ('input_folder', str(self.input_folder)),
        ('output_folder', str(self.output_folder)),
        ('statistics', self.detection_stats),
    ])

    with open(destination, 'w', encoding='utf-8') as handle:
        json.dump(payload, handle, ensure_ascii=False, indent=2)

    self.logger.info(f"总结报告已保存: {destination}")
torchvision>=0.10.0 + +# 系统工具 +psutil>=5.8.0 + +# Web应用框架 +flask>=2.0.0 +werkzeug>=2.0.0 +flask-wtf>=1.0.0 +python-dotenv>=0.15.0 \ No newline at end of file