Upload files to /

Wang_Run_Ze 2025-06-27 17:13:03 +08:00
parent 14418af6bb
commit 9ccc897c3c
5 changed files with 2062 additions and 0 deletions

app.py Normal file

@@ -0,0 +1,734 @@
from flask import Flask, render_template, request, redirect, url_for, flash, jsonify, send_from_directory
import os
import shutil
import time
import threading
import json
from datetime import datetime
from pathlib import Path
from werkzeug.utils import secure_filename
from config import config
from batch_detector import BatchVideoDetector
app = Flask(__name__)
app.secret_key = 'road_damage_detection_system'
app.config['MAX_CONTENT_LENGTH'] = 500 * 1024 * 1024 # 500MB max upload size
# 确保必要的文件夹存在
for folder in [config.INPUT_FOLDER, config.OUTPUT_FOLDER, config.LOG_FOLDER, 'models', 'static/css']:
os.makedirs(folder, exist_ok=True)
# 全局变量用于跟踪检测进度和状态
detection_progress = 0
detection_status = "idle" # idle, running, completed, error
detection_thread = None
current_detection_info = {}
detection_start_time = None
detection_results = None
current_report_path = None
current_video_name = ""
total_videos = 0
current_video_index = 0
current_frame_count = 0
total_frame_count = 0
# 允许的文件扩展名
ALLOWED_VIDEO_EXTENSIONS = {ext.lstrip('.') for ext in config.SUPPORTED_VIDEO_FORMATS}
ALLOWED_MODEL_EXTENSIONS = {'pt', 'pth', 'weights'}
def allowed_file(filename, allowed_extensions):
return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed_extensions
# 主页
@app.route('/', methods=['GET', 'POST'])
def index():
global detection_thread, detection_status, detection_progress, detection_results, current_report_path, detection_start_time
# 处理POST请求开始检测
if request.method == 'POST':
if detection_status == "running":
flash('已有检测任务正在运行', 'warning')
else:
# 获取选择的模型和置信度阈值
model_name = request.form.get('model_select')
confidence_threshold = float(request.form.get('confidence_threshold', 0.25))
# 检查是否有视频文件
videos = [f for f in os.listdir(config.INPUT_FOLDER)
if allowed_file(f, ALLOWED_VIDEO_EXTENSIONS)]
if not videos:
flash('输入文件夹中没有视频文件,请先上传视频', 'danger')
elif not model_name:
flash('请选择一个检测模型', 'danger')
else:
model_path = os.path.join('models', model_name)
if not os.path.exists(model_path):
flash(f'模型文件 {model_name} 不存在', 'danger')
else:
# 更新配置
config.MODEL_PATH = model_path
config.CONFIDENCE_THRESHOLD = confidence_threshold
# 重置检测状态和进度
detection_progress = 0
detection_status = "running"
detection_results = None
current_report_path = None
detection_start_time = datetime.now()
# 启动检测线程
detection_thread = threading.Thread(target=run_detection)
detection_thread.daemon = True
detection_thread.start()
flash('检测已开始,请等待结果', 'info')
# 获取可用的模型列表
print(f"models文件夹路径: {os.path.abspath('models')}")
print(f"models文件夹存在: {os.path.exists('models')}")
if os.path.exists('models'):
all_files = os.listdir('models')
print(f"models文件夹中的所有文件: {all_files}")
for f in all_files:
print(f"检查文件 {f}: allowed_file结果 = {allowed_file(f, ALLOWED_MODEL_EXTENSIONS)}")
models = [f for f in os.listdir('models') if allowed_file(f, ALLOWED_MODEL_EXTENSIONS)]
print(f"找到的模型文件: {models}")
# 获取输入文件夹中的视频文件
videos = [f for f in os.listdir(config.INPUT_FOLDER)
if allowed_file(f, ALLOWED_VIDEO_EXTENSIONS)]
# 获取检测报告列表
reports = []
if os.path.exists(config.OUTPUT_FOLDER):
for report_file in os.listdir(config.OUTPUT_FOLDER):
if report_file.startswith('batch_detection_report_') and report_file.endswith('.json'):
report_path = os.path.join(config.OUTPUT_FOLDER, report_file)
try:
with open(report_path, 'r', encoding='utf-8') as f:
report_data = json.load(f)
reports.append({
'id': report_file,
'timestamp': report_data.get('timestamp', '未知'),
'video_count': report_data.get('statistics', {}).get('processed_videos', 0),
'total_frames': report_data.get('statistics', {}).get('total_frames', 0),
'detected_frames': report_data.get('statistics', {}).get('detected_frames', 0)
})
except Exception as e:
print(f"Error loading report {report_path}: {e}")
# 按时间戳排序,最新的在前面
reports.sort(key=lambda x: x['timestamp'], reverse=True)
return render_template('index.html',
models=models,
videos=videos,
reports=reports,
detection_status=detection_status,
detection_progress=detection_progress,
detection_in_progress=(detection_status == "running"),
confidence_threshold=config.CONFIDENCE_THRESHOLD,
config=config)
# 上传模型
@app.route('/upload_model', methods=['POST'])
def upload_model():
if 'model_file' not in request.files:
flash('没有选择文件', 'danger')
return redirect(url_for('index'))
file = request.files['model_file']
if file.filename == '':
flash('没有选择文件', 'danger')
return redirect(url_for('index'))
if file and allowed_file(file.filename, ALLOWED_MODEL_EXTENSIONS):
filename = secure_filename(file.filename)
model_path = os.path.join('models', filename)
file.save(model_path)
# 立即检测模型信息并缓存
try:
print(f"正在分析上传的模型: {filename}")
from ultralytics import YOLO
model = YOLO(model_path)
# 保存模型信息到JSON文件
model_info = {
'filename': filename,
'upload_time': datetime.now().isoformat(),
'classes': dict(model.names) if hasattr(model, 'names') else {},
'class_count': len(model.names) if hasattr(model, 'names') else 0,
'model_type': 'YOLO',
'analyzed': True
}
# 创建模型信息文件
info_path = os.path.join('models', f"{os.path.splitext(filename)[0]}_info.json")
with open(info_path, 'w', encoding='utf-8') as f:
json.dump(model_info, f, ensure_ascii=False, indent=2)
# 显示模型分析结果
if hasattr(model, 'names'):
scene_types = list(model.names.values())
flash(f'模型 {filename} 上传成功!检测到 {len(scene_types)} 种场景类型: {", ".join(scene_types)}', 'success')
else:
flash(f'模型 {filename} 上传成功!', 'success')
except Exception as e:
print(f"模型分析失败: {e}")
flash(f'模型 {filename} 上传成功,但分析模型信息时出错: {str(e)}', 'warning')
else:
flash('不支持的文件类型', 'danger')
return redirect(url_for('index'))
# 获取模型列表
@app.route('/get_models', methods=['GET'])
def get_models():
models = []
model_files = [f for f in os.listdir('models') if allowed_file(f, ALLOWED_MODEL_EXTENSIONS)]
for model_file in model_files:
model_info = {
'filename': model_file,
'analyzed': False,
'classes': {},
'class_count': 0,
'scene_types': []
}
# 尝试读取模型信息文件
info_path = os.path.join('models', f"{os.path.splitext(model_file)[0]}_info.json")
if os.path.exists(info_path):
try:
with open(info_path, 'r', encoding='utf-8') as f:
cached_info = json.load(f)
model_info.update(cached_info)
model_info['scene_types'] = list(cached_info.get('classes', {}).values())
except Exception as e:
print(f"读取模型信息文件失败: {e}")
models.append(model_info)
return jsonify(models)
# 删除模型
@app.route('/delete_model/<filename>', methods=['POST'])
def delete_model(filename):
try:
model_path = os.path.join('models', filename)
info_path = os.path.join('models', f"{os.path.splitext(filename)[0]}_info.json")
deleted_files = []
if os.path.exists(model_path):
os.remove(model_path)
deleted_files.append('模型文件')
# 同时删除模型信息文件
if os.path.exists(info_path):
os.remove(info_path)
deleted_files.append('信息文件')
if deleted_files:
flash(f'模型 {filename} 及其{"和".join(deleted_files)}已删除', 'success')
else:
flash(f'模型 {filename} 不存在', 'warning')
except Exception as e:
flash(f'删除模型时出错: {str(e)}', 'danger')
return redirect(url_for('index'))
# 上传视频
@app.route('/upload_video', methods=['POST'])
def upload_video():
if 'video_files[]' not in request.files:
flash('没有选择文件', 'danger')
return redirect(url_for('index'))
files = request.files.getlist('video_files[]')
if not files or files[0].filename == '':
flash('没有选择文件', 'danger')
return redirect(url_for('index'))
success_count = 0
error_count = 0
for file in files:
if file and allowed_file(file.filename, ALLOWED_VIDEO_EXTENSIONS):
filename = secure_filename(file.filename)
file.save(os.path.join(config.INPUT_FOLDER, filename))
success_count += 1
else:
error_count += 1
if success_count > 0:
flash(f'成功上传 {success_count} 个视频文件', 'success')
if error_count > 0:
flash(f'{error_count} 个文件上传失败(不支持的文件类型)', 'warning')
return redirect(url_for('index'))
# 获取视频列表
@app.route('/get_videos', methods=['GET'])
def get_videos():
videos = [f for f in os.listdir(config.INPUT_FOLDER) if allowed_file(f, ALLOWED_VIDEO_EXTENSIONS)]
return jsonify(videos)
# 删除视频
@app.route('/delete_video/<filename>', methods=['POST'])
def delete_video(filename):
try:
video_path = os.path.join(config.INPUT_FOLDER, filename)
if os.path.exists(video_path):
os.remove(video_path)
flash(f'视频 {filename} 已删除', 'success')
else:
flash(f'视频 {filename} 不存在', 'warning')
except Exception as e:
flash(f'删除视频时出错: {str(e)}', 'danger')
return redirect(url_for('index'))
# 清空输入文件夹
@app.route('/clear_input_folder', methods=['POST'])
def clear_input_folder():
try:
for filename in os.listdir(config.INPUT_FOLDER):
file_path = os.path.join(config.INPUT_FOLDER, filename)
if os.path.isfile(file_path):
os.unlink(file_path)
flash('输入文件夹已清空', 'success')
except Exception as e:
flash(f'清空文件夹时出错: {str(e)}', 'danger')
return redirect(url_for('index'))
# 开始检测
@app.route('/start_detection', methods=['POST'])
def start_detection():
global detection_thread, detection_status, detection_progress, detection_results, current_report_path, detection_start_time
print("收到检测请求")
print(f"当前检测状态: {detection_status}")
if detection_status == "running":
flash('已有检测任务正在运行', 'warning')
return redirect(url_for('index'))
# 获取选择的模型和置信度阈值
model_name = request.form.get('model_select')
confidence_threshold = float(request.form.get('confidence_threshold', 0.25))
print(f"接收到的模型名称: {model_name}")
print(f"接收到的置信度阈值: {confidence_threshold}")
# 检查是否有视频文件
videos = [f for f in os.listdir(config.INPUT_FOLDER)
if allowed_file(f, ALLOWED_VIDEO_EXTENSIONS)]
if not videos:
flash('输入文件夹中没有视频文件,请先上传视频', 'danger')
return redirect(url_for('index'))
# 检查模型文件是否存在
if not model_name:
flash('请选择一个检测模型', 'danger')
return redirect(url_for('index'))
model_path = os.path.join('models', model_name)
if not os.path.exists(model_path):
flash(f'模型文件 {model_name} 不存在', 'danger')
return redirect(url_for('index'))
# 更新配置
config.MODEL_PATH = model_path
config.CONFIDENCE_THRESHOLD = confidence_threshold
# 重置检测状态和进度
detection_progress = 0
detection_status = "running"
detection_results = None
current_report_path = None
detection_start_time = datetime.now()
# 重置进度相关变量
global current_video_name, total_videos, current_video_index, current_frame_count, total_frame_count
current_video_name = ""
total_videos = len(videos)
current_video_index = 0
current_frame_count = 0
total_frame_count = 0
# 启动检测线程
detection_thread = threading.Thread(target=run_detection)
detection_thread.daemon = True
detection_thread.start()
flash('检测已开始,请等待结果', 'info')
return redirect(url_for('index'))
# 检测线程函数
def run_detection():
global detection_progress, detection_status, detection_results, current_report_path, current_video_name, current_video_index, current_frame_count
try:
print(f"开始检测,模型路径: {config.MODEL_PATH}")
print(f"置信度阈值: {config.CONFIDENCE_THRESHOLD}")
print(f"输入文件夹: {config.INPUT_FOLDER}")
# 获取视频列表
videos = [f for f in os.listdir(config.INPUT_FOLDER)
if any(f.endswith(ext) for ext in config.SUPPORTED_VIDEO_FORMATS)]
print(f"找到视频文件: {videos}")
# 创建检测报告基本信息
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
# 创建检测器实例
print("正在创建检测器实例...")
detector = BatchVideoDetector(
model_path=config.MODEL_PATH,
confidence=config.CONFIDENCE_THRESHOLD
)
print("检测器实例创建成功")
# 显示模型检测能力信息
print("\n=== 网页端模型检测分析 ===")
if hasattr(detector.model, 'names'):
model_classes = list(detector.model.names.values())
print(f"当前模型可检测场景: {model_classes}")
print(f"使用的置信度阈值: {config.CONFIDENCE_THRESHOLD}")
print("=" * 40)
# 设置进度回调
def progress_callback(status, progress, video_name=None, video_index=None, frame_count=None):
global detection_status, detection_progress, current_video_name, current_video_index, current_frame_count
detection_status = status
detection_progress = progress if progress is not None else detection_progress  # 帧级回调不携带进度值时保留原进度
# 更新详细进度信息
if video_name:
current_video_name = video_name
if video_index is not None:
current_video_index = video_index
if frame_count is not None:
current_frame_count = frame_count
print(f"进度更新: {status} - {progress}%")
if video_name:
print(f"当前视频: {video_name} ({video_index}/{total_videos})")
if frame_count is not None:
print(f"已处理帧数: {frame_count}")
detector.set_progress_callback(progress_callback)
# 执行检测
detection_status = "开始检测视频..."
report_path = detector.process_all_videos(config.INPUT_FOLDER, config.OUTPUT_FOLDER)
# 读取检测结果
if report_path and os.path.exists(report_path):
with open(report_path, 'r', encoding='utf-8') as f:
detection_results = json.load(f)
current_report_path = report_path
detection_status = "检测完成"
detection_progress = 100
else:
detection_status = "检测完成,但未生成报告"
detection_progress = 100
except Exception as e:
detection_status = f"检测失败: {str(e)}"
print(f"检测异常: {str(e)}")
import traceback
traceback.print_exc()
finally:
# 如果状态仍为running则更新为失败
if detection_status == "running":
detection_status = "失败"
# 确保进度显示为100%
if detection_progress < 100:
detection_progress = 100
# 获取检测进度
@app.route('/detection_progress')
def get_detection_progress():
# 计算检测时长
elapsed_time = ""
if detection_start_time and detection_status == "running":
elapsed_seconds = (datetime.now() - detection_start_time).total_seconds()
minutes, seconds = divmod(elapsed_seconds, 60)
elapsed_time = f"{int(minutes)}分{int(seconds)}秒"
# 构建详细状态信息
detailed_status = detection_status
if detection_status == "running" and current_video_name:
detailed_status = f"正在处理: {current_video_name} ({current_video_index}/{total_videos})"
if current_frame_count > 0:
detailed_status += f" - 已处理帧数: {current_frame_count}"
return jsonify({
'in_progress': detection_status == "running",
'progress': detection_progress,
'status': detailed_status,
'elapsed_time': elapsed_time,
'current_video': current_video_name,
'current_video_index': current_video_index,
'total_videos': total_videos,
'current_frame_count': current_frame_count,
'total_frame_count': total_frame_count
})
# 查看检测报告
@app.route('/report/<report_filename>')
def view_report(report_filename):
report_path = os.path.join(config.OUTPUT_FOLDER, report_filename)
if not os.path.exists(report_path):
flash('报告文件不存在', 'danger')
return redirect(url_for('index'))
try:
with open(report_path, 'r', encoding='utf-8') as f:
report_data = json.load(f)
# 计算检测率
total_frames = report_data.get('statistics', {}).get('total_frames', 0)
detected_frames = report_data.get('statistics', {}).get('detected_frames', 0)
detection_rate = 0
if total_frames > 0:
detection_rate = (detected_frames / total_frames) * 100
# 为每个视频计算检测率
for video in report_data.get('videos', []):
video_total_frames = video.get('total_frames', 0)
video_detected_frames = video.get('detected_frames', 0)
if video_total_frames > 0:
video['detection_rate'] = (video_detected_frames / video_total_frames) * 100
else:
video['detection_rate'] = 0
return render_template('report.html',
report=report_data,
report_filename=report_filename,
detection_rate=detection_rate)
except Exception as e:
flash(f'读取报告时出错: {str(e)}', 'danger')
return redirect(url_for('index'))
# 查看视频检测结果
@app.route('/video_results/<report_filename>/<video_name>')
def view_video_results(report_filename, video_name):
report_path = os.path.join(config.OUTPUT_FOLDER, report_filename)
if not os.path.exists(report_path):
flash('报告文件不存在', 'danger')
return redirect(url_for('index'))
try:
with open(report_path, 'r', encoding='utf-8') as f:
report_data = json.load(f)
# 查找特定视频的结果
video_data = None
for video in report_data['videos']:
if video['video_name'] == video_name:
video_data = video
break
if not video_data:
flash('视频结果不存在', 'danger')
return redirect(url_for('view_report', report_filename=report_filename))
# 获取检测到的帧列表
video_output_dir = os.path.join(config.OUTPUT_FOLDER, os.path.splitext(video_name)[0])
detected_frames = []
if os.path.exists(video_output_dir):
for frame_file in os.listdir(video_output_dir):
if frame_file.startswith('detected_') and frame_file.endswith('.jpg'):
try:
frame_number = int(frame_file.split('_')[1].split('.')[0])
frame_path = os.path.join(os.path.basename(video_output_dir), frame_file)
# 获取对应的JSON文件以读取检测信息
json_file = frame_file.replace('.jpg', '.json')
json_path = os.path.join(video_output_dir, json_file)
detections = []
if os.path.exists(json_path):
with open(json_path, 'r', encoding='utf-8') as f:
frame_data = json.load(f)
detections = frame_data.get('detections', [])
detected_frames.append({
'frame_number': frame_number,
'filename': frame_file,
'path': frame_path,
'timestamp': frame_number / video_data.get('fps', 30), # 计算时间戳
'detections': detections,
'detection_count': len(detections)
})
except Exception as e:
print(f"Error processing frame {frame_file}: {e}")
# 按帧号排序
detected_frames.sort(key=lambda x: x['frame_number'])
# 计算检测率
total_frames = video_data.get('total_frames', 0)
detected_count = video_data.get('detected_frames', 0)
detection_rate = 0
if total_frames > 0:
detection_rate = (detected_count / total_frames) * 100
# 获取时间线数据
timeline_data = []
for frame in detected_frames:
timeline_data.append({
'timestamp': frame['timestamp'],
'frame_number': frame['frame_number'],
'detection_count': frame['detection_count']
})
return render_template('video_results.html',
report_filename=report_filename,
report=report_data,
video=video_data,
detected_frames=detected_frames,
detection_rate=detection_rate,
timeline_data=json.dumps(timeline_data))
except Exception as e:
flash(f'读取视频结果时出错: {str(e)}', 'danger')
return redirect(url_for('view_report', report_filename=report_filename))
# 获取检测到的帧图像
@app.route('/frame/<path:frame_path>')
def get_frame(frame_path):
directory, filename = os.path.split(frame_path)
return send_from_directory(os.path.join(config.OUTPUT_FOLDER, directory), filename)
# 删除报告
@app.route('/delete_report/<report_filename>', methods=['POST'])
def delete_report(report_filename):
report_path = os.path.join(config.OUTPUT_FOLDER, report_filename)
if not os.path.exists(report_path):
flash('报告文件不存在', 'danger')
return redirect(url_for('index'))
# 读取报告以获取视频文件夹列表
try:
with open(report_path, 'r', encoding='utf-8') as f:
report_data = json.load(f)
# 删除每个视频的输出文件夹
for video in report_data['videos']:
video_output_folder = os.path.join(config.OUTPUT_FOLDER, os.path.splitext(video['video_name'])[0])
if os.path.exists(video_output_folder) and os.path.isdir(video_output_folder):
shutil.rmtree(video_output_folder)
# 删除报告文件
os.remove(report_path)
flash(f'报告 {report_filename} 已删除', 'success')
except Exception as e:
flash(f'删除报告失败: {str(e)}', 'danger')
return redirect(url_for('index'))
# 添加BatchVideoDetector的进度回调方法
def add_progress_callback_to_detector():
# 检查BatchVideoDetector类是否已有set_progress_callback方法
if not hasattr(BatchVideoDetector, 'set_progress_callback'):
def set_progress_callback(self, callback):
self.progress_callback = callback
def update_progress(self, status, progress, video_name=None, video_index=None, frame_count=None):
if hasattr(self, 'progress_callback') and self.progress_callback:
self.progress_callback(status, progress, video_name, video_index, frame_count)
# 添加方法到BatchVideoDetector类
BatchVideoDetector.set_progress_callback = set_progress_callback
BatchVideoDetector.update_progress = update_progress
# 修改process_all_videos方法以支持进度更新
original_process_all_videos = BatchVideoDetector.process_all_videos
def process_all_videos_with_progress(self, input_folder="input_videos", output_folder="output_frames"):
self.update_progress("获取视频文件列表", 0)
video_files = self.get_video_files(input_folder)
total_videos = len(video_files)
if total_videos == 0:
self.update_progress("没有找到视频文件", 100)
return None
self.update_progress(f"开始处理 {total_videos} 个视频文件", 5)
# 处理每个视频文件
for i, video_path in enumerate(video_files, 1):
progress = int(5 + (i / total_videos) * 90)
video_name = video_path.name
self.update_progress(f"处理视频 {i}/{total_videos}: {video_name}", progress, video_name, i)
self.process_video(video_path, output_folder)
self.update_progress("生成检测报告", 95)
self.stats['end_time'] = datetime.now()
report_path = self.save_final_report(output_folder)
self.print_summary()
self.update_progress("检测完成", 100)
return report_path
# 修改save_final_report方法以返回报告路径
original_save_final_report = BatchVideoDetector.save_final_report
def save_final_report_with_return(self, output_folder):
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
report_path = Path(output_folder) / f"batch_detection_report_{timestamp}.json"
# 计算总处理时间
total_seconds = 0
if self.stats['start_time'] and self.stats['end_time']:
total_seconds = (self.stats['end_time'] - self.stats['start_time']).total_seconds()
# 创建报告数据
report_data = {
'timestamp': timestamp,
'model_path': self.model_path,
'confidence_threshold': self.confidence,
'frame_interval_threshold': config.MIN_FRAME_INTERVAL,
'filter_strategy': '相同场景在指定时间间隔内只保存一次',
'statistics': {
'total_videos': self.stats['total_videos'],
'processed_videos': self.stats['processed_videos'],
'total_frames': self.stats['total_frames'],
'detected_frames': self.stats['detected_frames'],
'filtered_frames': self.stats['filtered_frames'],
'saved_images': self.stats['saved_images'],
'processing_time_seconds': total_seconds
},
'errors': self.stats['errors'],
'videos': self.video_details
}
# 保存为JSON
with open(report_path, 'w', encoding='utf-8') as f:
json.dump(report_data, f, ensure_ascii=False, indent=2)
print(f"\n批处理报告已保存: {report_path}")
return str(report_path)
BatchVideoDetector.process_all_videos = process_all_videos_with_progress
BatchVideoDetector.save_final_report = save_final_report_with_return
# 在应用启动前添加进度回调方法
add_progress_callback_to_detector()
if __name__ == '__main__':
# 确保必要的文件夹存在
for folder in [config.INPUT_FOLDER, config.OUTPUT_FOLDER, config.LOG_FOLDER, 'models', 'static/css']:
os.makedirs(folder, exist_ok=True)
app.run(debug=True, host='0.0.0.0', port=5000)
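A minimal client-side sketch (not part of this commit) of polling the /detection_progress endpoint defined above, assuming the Flask app is running locally on port 5000 as configured in app.run():

import json
import time
import urllib.request

def poll_progress(base_url="http://localhost:5000"):
    # Query /detection_progress every 2 seconds until the run reports it is no longer in progress.
    while True:
        with urllib.request.urlopen(f"{base_url}/detection_progress") as resp:
            data = json.loads(resp.read().decode("utf-8"))
        print(f"{data['status']} - {data['progress']}%")
        if not data["in_progress"]:
            break
        time.sleep(2)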

batch_detector.py Normal file

@@ -0,0 +1,642 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
批量视频检测器
简化版本专注于高效处理大量视频文件
"""
import os
import cv2
import numpy as np
from pathlib import Path
from ultralytics import YOLO
import json
import time
from datetime import datetime
import torch
from config import config, DetectionConfig
from PIL import Image, ImageDraw, ImageFont
class BatchVideoDetector:
def __init__(self, model_path=None, confidence=0.5):
"""
初始化批量视频检测器
Args:
model_path (str): 模型路径如果为None则使用配置文件中的当前模型路径
confidence (float): 置信度阈值
"""
# 优先使用传入的模型路径,如果没有则使用配置中的当前模型路径
self.model_path = model_path or config.MODEL_PATH
self.confidence = confidence
# 创建必要文件夹
config.create_folders()
# 加载模型
self.load_model()
# 时间间隔过滤变量
self.last_saved_timestamp = 0
# 初始化场景类别变量
self.last_scene_classes = set()
# 动态检测参数
self.adaptive_confidence = True
self.min_confidence = 0.2
self.scene_confidence_history = [] # 记录场景置信度历史
self.detection_quality_threshold = 0.3 # 检测质量阈值
# 统计信息
self.stats = {
'start_time': None,
'end_time': None,
'total_videos': 0,
'processed_videos': 0,
'total_frames': 0,
'detected_frames': 0,
'saved_images': 0,
'filtered_frames': 0,
'adaptive_detections': 0, # 自适应检测次数
'errors': []
}
# 视频详细信息列表
self.video_details = []
def load_model(self):
"""加载YOLO模型"""
try:
print(f"正在加载模型: {self.model_path}")
self.model = YOLO(self.model_path)
print(f"模型加载成功!")
# 检查是否有缓存的模型信息
model_filename = os.path.basename(self.model_path)
info_path = os.path.join(os.path.dirname(self.model_path), f"{os.path.splitext(model_filename)[0]}_info.json")
cached_info_loaded = False
if os.path.exists(info_path):
try:
with open(info_path, 'r', encoding='utf-8') as f:
cached_info = json.load(f)
if cached_info.get('analyzed', False):
print(f"使用缓存的模型信息,跳过重复分析")
print(f"模型类别数量: {cached_info.get('class_count', 0)}")
print(f"检测场景类型: {list(cached_info.get('classes', {}).values())}")
cached_info_loaded = True
except Exception as e:
print(f"读取缓存模型信息失败: {e},将重新分析")
# 如果没有缓存信息,则进行完整分析
if not cached_info_loaded:
# 动态加载模型类别信息到配置中
DetectionConfig.load_model_classes(self.model_path)
if hasattr(self.model, 'names'):
print(f"\n=== 模型检测场景分析 ===")
print(f"模型类别数量: {len(self.model.names)}")
print(f"模型原始类别: {list(self.model.names.values())}")
print(f"\n=== 中文类别映射 ===")
for class_id, class_name in self.model.names.items():
class_name_cn = config.get_class_name_cn(class_id)
print(f"类别 {class_id}: {class_name} -> {class_name_cn}")
# 分析模型类型并优化参数
self.analyze_and_optimize_model()
print(f"\n=== 检测场景总结 ===")
scene_types = list(DetectionConfig.get_all_classes().values())
print(f"本模型可检测的场景类型: {scene_types}")
print(f"检测置信度阈值: {self.confidence}")
print("=" * 40)
else:
# 使用缓存信息更新配置
DetectionConfig.load_model_classes(self.model_path)
except Exception as e:
print(f"模型加载失败: {e}")
raise
def analyze_and_optimize_model(self):
"""分析模型类型并优化检测参数"""
model_classes = list(self.model.names.values())
# 检测模型类型
road_damage_keywords = ['crack', 'pothole', 'damage', '裂缝', '坑洞', '损伤']
general_object_keywords = ['person', 'car', 'truck', 'bus', 'bicycle']
is_road_damage = any(keyword in ' '.join(model_classes).lower() for keyword in road_damage_keywords)
is_general_object = any(keyword in ' '.join(model_classes).lower() for keyword in general_object_keywords)
if is_road_damage:
print("检测到道路损伤专用模型,优化参数设置")
self.confidence = max(0.25, self.confidence)
self.min_confidence = 0.15
self.detection_quality_threshold = 0.2
elif is_general_object:
print("检测到通用目标检测模型,使用标准参数")
self.confidence = max(0.4, self.confidence)
self.min_confidence = 0.3
self.detection_quality_threshold = 0.35
else:
print("检测到自定义模型,使用保守参数")
self.confidence = max(0.3, self.confidence)
self.min_confidence = 0.2
self.detection_quality_threshold = 0.25
print(f"优化后的置信度阈值: {self.confidence}")
print(f"最低置信度阈值: {self.min_confidence}")
print(f"检测质量阈值: {self.detection_quality_threshold}")
def get_video_files(self, input_folder="input_videos"):
"""获取所有视频文件"""
input_path = Path(input_folder)
if not input_path.exists():
print(f"输入文件夹不存在: {input_folder}")
return []
video_files = []
for ext in config.SUPPORTED_VIDEO_FORMATS:
video_files.extend(input_path.rglob(f"*{ext}"))
print(f"找到 {len(video_files)} 个视频文件")
return video_files
def detect_and_save_frame(self, frame, frame_number, timestamp, video_name, output_dir):
"""检测单帧并保存结果"""
try:
# 进行多级检测
result = self.adaptive_detect_frame(frame)
# 如果检测到目标
if result is not None and result.boxes is not None and len(result.boxes) > 0:
# 获取当前帧的检测类别
current_scene_classes = set()
for box in result.boxes:
class_id = int(box.cls[0].cpu().numpy())
current_scene_classes.add(class_id)
# 检查时间间隔
time_diff = timestamp - self.last_saved_timestamp
# 如果时间间隔小于最小间隔,检查是否为新场景
if time_diff < config.MIN_FRAME_INTERVAL and self.last_saved_timestamp > 0:
# 如果当前帧的场景类别与上一帧相同,则过滤
if hasattr(self, 'last_scene_classes') and current_scene_classes == self.last_scene_classes:
self.stats['filtered_frames'] += 1
# 仍然返回检测信息,但不保存图像
detections = []
for box in result.boxes:
class_id = int(box.cls[0].cpu().numpy())
confidence = float(box.conf[0].cpu().numpy())
bbox = box.xyxy[0].cpu().numpy().tolist()
class_name = self.model.names.get(class_id, f"class_{class_id}")
class_name_cn = config.get_class_name_cn(class_id)
detections.append({
'class_id': class_id,
'class_name': class_name,
'class_name_cn': class_name_cn,
'confidence': confidence,
'bbox': bbox,
'filtered': True
})
return {
'frame_number': frame_number,
'timestamp': timestamp,
'detections': detections,
'filtered': True
}
# 更新最后检测到的场景类别
self.last_scene_classes = current_scene_classes
# 更新最后保存的时间戳
self.last_saved_timestamp = timestamp
# 创建文件名
base_name = f"{video_name}_frame_{frame_number:06d}_t{timestamp:.2f}s"
# 获取检测到的场景类别,用于分类保存
detected_scenes = set()
for box in result.boxes:
class_id = int(box.cls[0].cpu().numpy())
class_name_cn = config.get_class_name_cn(class_id)
detected_scenes.add(class_name_cn)
# 为每个检测到的场景创建对应的文件夹并保存图片
for scene_name in detected_scenes:
# 创建视频名称文件夹
video_folder = output_dir / video_name
video_folder.mkdir(parents=True, exist_ok=True)
# 创建场景文件夹
scene_folder = video_folder / scene_name
scene_folder.mkdir(parents=True, exist_ok=True)
# 保存原始帧
if config.SAVE_ORIGINAL_FRAMES:
original_path = scene_folder / f"{base_name}_original.jpg"
cv2.imwrite(str(original_path), frame, [cv2.IMWRITE_JPEG_QUALITY, config.IMAGE_QUALITY])
# 绘制并保存标注帧
if config.SAVE_ANNOTATED_FRAMES:
annotated_frame = self.draw_detections(frame.copy(), result)
annotated_path = scene_folder / f"{base_name}_detected.jpg"
cv2.imwrite(str(annotated_path), annotated_frame, [cv2.IMWRITE_JPEG_QUALITY, config.IMAGE_QUALITY])
# 收集检测信息
detections = []
for box in result.boxes:
class_id = int(box.cls[0].cpu().numpy())
confidence = float(box.conf[0].cpu().numpy())
bbox = box.xyxy[0].cpu().numpy().tolist()
class_name = self.model.names.get(class_id, f"class_{class_id}")
class_name_cn = config.get_class_name_cn(class_id)
detections.append({
'class_id': class_id,
'class_name': class_name,
'class_name_cn': class_name_cn,
'confidence': confidence,
'bbox': bbox
})
return {
'frame_number': frame_number,
'timestamp': timestamp,
'detections': detections
}
return None
except Exception as e:
print(f"处理帧 {frame_number} 时出错: {e}")
self.stats['errors'].append(f"{frame_number}: {str(e)}")
return None
def adaptive_detect_frame(self, frame):
"""自适应检测单帧"""
try:
# 第一次检测:使用标准置信度
results = self.model(frame, conf=self.confidence, verbose=False,
imgsz=640, augment=True)
result = results[0] if results else None
# 如果没有检测到目标且启用自适应检测
if self.adaptive_confidence and (not result or not result.boxes or len(result.boxes) == 0):
# 尝试降低置信度检测
lower_conf = max(self.min_confidence, self.confidence - 0.15)
if lower_conf < self.confidence:
results = self.model(frame, conf=lower_conf, verbose=False,
imgsz=640, augment=True)
result = results[0] if results else None
if result and result.boxes and len(result.boxes) > 0:
self.stats['adaptive_detections'] += 1
print(f"自适应检测成功,置信度: {lower_conf:.2f}")
# 过滤低质量检测
if result and result.boxes and len(result.boxes) > 0:
valid_indices = []
for i, box in enumerate(result.boxes):
confidence = float(box.conf[0].cpu().numpy())
if confidence >= self.detection_quality_threshold:
valid_indices.append(i)
# 如果有有效检测,保留原始结果结构
if not valid_indices:
result = None
return result
except Exception as e:
print(f"自适应检测失败: {e}")
return None
def draw_detections(self, frame, result):
"""在帧上绘制检测结果(绘制边界框和中文标注)"""
if result is None or result.boxes is None:
return frame
# 转换为PIL图像
pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(pil_image)
# 尝试加载中文字体
try:
# Windows系统常用中文字体路径
font_paths = [
"C:/Windows/Fonts/simhei.ttf", # 黑体
"C:/Windows/Fonts/simsun.ttc", # 宋体
"C:/Windows/Fonts/msyh.ttc", # 微软雅黑
]
font = None
for font_path in font_paths:
try:
font = ImageFont.truetype(font_path, 20)
break
except:
continue
if font is None:
font = ImageFont.load_default()
except:
font = ImageFont.load_default()
for box in result.boxes:
# 获取坐标和信息
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
class_id = int(box.cls[0].cpu().numpy())
confidence = float(box.conf[0].cpu().numpy())
# 动态获取类别名称(优先使用模型原始名称)
if hasattr(self.model, 'names') and class_id in self.model.names:
original_name = self.model.names[class_id]
class_name_cn = config.get_class_name_cn(class_id)
# 如果中文名称就是原始名称,说明没有映射,直接使用原始名称
if class_name_cn == original_name or class_name_cn.startswith('未知类别'):
display_name = original_name
else:
display_name = f"{class_name_cn}({original_name})"
else:
display_name = config.get_class_name_cn(class_id)
# 根据置信度调整颜色强度和线宽
base_color = config.get_class_color(class_id)
intensity = min(1.0, confidence + 0.3)
color = tuple(int(c * intensity) for c in base_color)
bg_color = tuple(reversed(color)) # BGR转RGB
line_width = max(2, int(confidence * 4))
# 绘制边界框
draw.rectangle([x1, y1, x2, y2], outline=bg_color, width=line_width)
# 准备标注文字
label_text = f"{display_name}: {confidence:.3f}"
# 计算文字背景框大小
bbox = draw.textbbox((0, 0), label_text, font=font)
text_width = bbox[2] - bbox[0]
text_height = bbox[3] - bbox[1]
# 确保标注不超出图像边界
label_x = max(0, x1)
label_y = max(0, y1 - text_height - 5)
if label_y < 0:
label_y = y1 + 5
# 绘制文字背景
draw.rectangle(
[label_x, label_y, label_x + text_width + 4, label_y + text_height + 4],
fill=bg_color
)
# 绘制文字(白色)
draw.text((label_x + 2, label_y + 2), label_text, fill=(255, 255, 255), font=font)
# 转换回OpenCV格式
frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
return frame
def process_video(self, video_path, output_base_dir="output_frames"):
"""处理单个视频"""
print(f"\n开始处理: {video_path.name}")
# 创建输出目录
output_dir = Path(output_base_dir) / video_path.stem
output_dir.mkdir(parents=True, exist_ok=True)
# 文件夹将在检测到对应类别时按需创建
# 重置时间间隔过滤变量,确保每个视频独立计算时间间隔
self.last_saved_timestamp = 0
# 重置场景类别变量,确保每个视频独立分类场景
self.last_scene_classes = set()
# 打开视频
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
error_msg = f"无法打开视频: {video_path}"
print(error_msg)
self.stats['errors'].append(error_msg)
return
# 获取视频信息
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
print(f"视频信息: {total_frames} 帧, {fps:.2f} FPS")
# 处理统计
frame_count = 0
detected_frames = 0
saved_images = 0
filtered_frames = 0
all_detections = []
start_time = time.time()
try:
while True:
ret, frame = cap.read()
if not ret:
break
frame_count += 1
timestamp = frame_count / fps if fps > 0 else frame_count
# 显示进度
if frame_count % config.PROGRESS_INTERVAL == 0:
progress = (frame_count / total_frames) * 100
elapsed = time.time() - start_time
fps_current = frame_count / elapsed
print(f"进度: {progress:.1f}% ({frame_count}/{total_frames}), 处理速度: {fps_current:.1f} FPS")
# 更新进度回调
if hasattr(self, 'update_progress'):
self.update_progress(f"处理视频帧", None, None, None, frame_count)
# 检测并保存
detection_result = self.detect_and_save_frame(
frame, frame_count, timestamp, video_path.stem, output_dir
)
if detection_result:
detected_frames += 1
all_detections.append(detection_result)
# 检查是否是被过滤的帧
if detection_result.get('filtered', False):
filtered_frames += 1
else:
# 计算保存的图片数量
if config.SAVE_ORIGINAL_FRAMES:
saved_images += 1
if config.SAVE_ANNOTATED_FRAMES:
saved_images += 1
finally:
cap.release()
# 保存检测结果JSON
if config.SAVE_DETECTION_JSON and all_detections:
json_path = output_dir / f"{video_path.stem}_detections.json"
detection_summary = {
'video_name': video_path.name,
'total_frames': frame_count,
'detected_frames': detected_frames,
'fps': fps,
'detections': all_detections,
'processing_time': time.time() - start_time
}
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(detection_summary, f, ensure_ascii=False, indent=2)
# 更新统计
self.stats['processed_videos'] += 1
self.stats['total_frames'] += frame_count
self.stats['detected_frames'] += detected_frames
self.stats['saved_images'] += saved_images
self.stats['filtered_frames'] += filtered_frames
processing_time = time.time() - start_time
print(f"处理完成: {video_path.name}")
print(f"检测到目标的帧数: {detected_frames}/{frame_count}, 保存图片: {saved_images}")
print(f"时间间隔过滤帧数: {filtered_frames} 帧 (间隔阈值: {config.MIN_FRAME_INTERVAL}秒)")
print(f"处理时间: {processing_time:.2f}")
# 记录视频详细信息
video_info = {
'video_name': video_path.name,
'video_path': str(video_path),
'total_frames': frame_count,
'detected_frames': detected_frames,
'saved_images': saved_images,
'filtered_frames': filtered_frames,
'processing_time': processing_time,
'fps': fps,
'duration': frame_count / fps if fps > 0 else 0,
'output_directory': str(output_dir)
}
self.video_details.append(video_info)
def process_all_videos(self, input_folder="input_videos", output_folder="output_frames"):
"""批量处理所有视频"""
self.stats['start_time'] = datetime.now()
video_files = self.get_video_files(input_folder)
if not video_files:
print("没有找到视频文件")
return
self.stats['total_videos'] = len(video_files)
print(f"\n开始批量处理 {len(video_files)} 个视频文件")
print(f"模型: {self.model_path}")
print(f"置信度阈值: {self.confidence}")
print(f"输出目录: {output_folder}")
print("=" * 50)
for i, video_path in enumerate(video_files, 1):
print(f"\n[{i}/{len(video_files)}] 处理视频")
self.process_video(video_path, output_folder)
self.stats['end_time'] = datetime.now()
report_path = self.save_final_report(output_folder)
self.print_summary()
return report_path
def save_final_report(self, output_folder):
"""保存最终批处理报告"""
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
report_path = Path(output_folder) / f"batch_detection_report_{timestamp}.json"
# 计算总处理时间
total_seconds = 0
if self.stats['start_time'] and self.stats['end_time']:
total_seconds = (self.stats['end_time'] - self.stats['start_time']).total_seconds()
# 创建报告数据
report_data = {
'timestamp': timestamp,
'model_path': self.model_path,
'confidence_threshold': self.confidence,
'frame_interval_threshold': config.MIN_FRAME_INTERVAL,
'filter_strategy': '相同场景在指定时间间隔内只保存一次',
'statistics': {
'total_videos': self.stats['total_videos'],
'processed_videos': self.stats['processed_videos'],
'total_frames': self.stats['total_frames'],
'detected_frames': self.stats['detected_frames'],
'filtered_frames': self.stats['filtered_frames'],
'saved_images': self.stats['saved_images'],
'processing_time_seconds': total_seconds
},
'videos': self.video_details,
'errors': self.stats['errors']
}
# 保存为JSON
with open(report_path, 'w', encoding='utf-8') as f:
json.dump(report_data, f, ensure_ascii=False, indent=2)
print(f"\n批处理报告已保存: {report_path}")
return str(report_path)
def print_summary(self):
"""打印处理摘要"""
if not self.stats['end_time'] or not self.stats['start_time']:
return
total_time = (self.stats['end_time'] - self.stats['start_time']).total_seconds()
hours, remainder = divmod(total_time, 3600)
minutes, seconds = divmod(remainder, 60)
print("\n" + "=" * 50)
print("批量处理摘要")
print("=" * 50)
print(f"处理视频数量: {self.stats['processed_videos']}/{self.stats['total_videos']}")
print(f"总处理帧数: {self.stats['total_frames']}")
print(f"检测到目标的帧数: {self.stats['detected_frames']}")
print(f"时间间隔过滤帧数: {self.stats['filtered_frames']} (间隔阈值: {config.MIN_FRAME_INTERVAL}秒)")
print(f"过滤策略: 相同场景在{config.MIN_FRAME_INTERVAL}秒内只保存一次")
print(f"实际保存的图片数量: {self.stats['saved_images']}")
print(f"总处理时间: {int(hours)}小时 {int(minutes)}分钟 {seconds:.2f}")
if self.stats['errors']:
print("\n处理过程中的错误:")
for error in self.stats['errors']:
print(f"- {error}")
print("=" * 50)
def main():
"""主函数"""
import argparse
parser = argparse.ArgumentParser(description='批量视频检测系统')
parser.add_argument('--model', type=str, help='模型文件路径')
parser.add_argument('--confidence', type=float, default=0.5, help='置信度阈值')
parser.add_argument('--input', type=str, default='input_videos', help='输入视频文件夹')
parser.add_argument('--output', type=str, default='output_frames', help='输出图片文件夹')
args = parser.parse_args()
# 创建检测器
detector = BatchVideoDetector(
model_path=args.model,
confidence=args.confidence
)
# 开始批量处理
detector.process_all_videos(args.input, args.output)
if __name__ == '__main__':
main()
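Besides the CLI entry point above, BatchVideoDetector can also be driven programmatically; a short sketch follows (the weights path is a placeholder, the folders are the defaults used throughout this commit):

from batch_detector import BatchVideoDetector

# Placeholder weights path; any YOLO .pt file trained for the target scenes works here.
detector = BatchVideoDetector(model_path="models/best.pt", confidence=0.5)
report_path = detector.process_all_videos(
    input_folder="input_videos",
    output_folder="output_frames",
)
print(f"Batch report saved to: {report_path}")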

config.py Normal file

@@ -0,0 +1,200 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
视频检测系统配置文件
"""
import os
from pathlib import Path
from ultralytics import YOLO
class DetectionConfig:
"""检测系统配置类"""
# 模型配置
MODEL_PATH = "models/best.pt"
CONFIDENCE_THRESHOLD = 0.5
# 动态模型信息
_model_classes = None
_model_colors = None
# 文件夹配置
INPUT_FOLDER = "input_videos"
OUTPUT_FOLDER = "output_frames"
LOG_FOLDER = "logs"
# 视频处理配置
SUPPORTED_VIDEO_FORMATS = {'.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm'}
PROGRESS_INTERVAL = 100 # 每处理多少帧显示一次进度
MIN_FRAME_INTERVAL = 1.0 # 检测帧的最小时间间隔(秒)
# 输出配置
SAVE_ORIGINAL_FRAMES = False # 是否保存原始帧
SAVE_ANNOTATED_FRAMES = True # 是否保存标注帧
SAVE_DETECTION_JSON = True # 是否保存检测结果JSON
# 图像质量配置
IMAGE_QUALITY = 95 # JPEG质量 (1-100)
# 默认道路损伤类别映射 (中文标签) - 作为备用
DEFAULT_CLASS_NAMES_CN = {
0: "纵向裂缝",
1: "横向裂缝",
2: "网状裂缝",
3: "坑洞",
4: "白线模糊"
}
# 通用类别关键词映射(扩展版)
COMMON_CLASS_MAPPINGS = {
# 道路损伤相关
'crack': '裂缝',
'pothole': '坑洞',
'damage': '损伤',
'longitudinal': '纵向裂缝',
'transverse': '横向裂缝',
'alligator': '网状裂缝',
'rutting': '车辙',
'bleeding': '泛油',
'weathering': '风化',
'patching': '修补',
# 交通工具
'person': '人',
'people': '人群',
'car': '汽车',
'truck': '卡车',
'bus': '公交车',
'motorcycle': '摩托车',
'bicycle': '自行车',
'bike': '自行车',
'vehicle': '车辆',
# 交通设施
'road': '道路',
'traffic': '交通',
'sign': '标志',
'light': '信号灯',
'signal': '信号',
'stop': '停车标志',
'yield': '让行标志',
'crosswalk': '人行横道',
'lane': '车道',
'marking': '标线',
'white': '白线',
'yellow': '黄线',
# 其他常见目标
'animal': '动物',
'dog': '狗',
'cat': '猫',
'bird': '鸟',
'tree': '树',
'building': '建筑',
'pole': '杆子',
'fence': '围栏'
}
# 动态颜色生成配置
COLOR_PALETTE = [
(0, 255, 0), # 绿色
(255, 0, 0), # 红色
(0, 0, 255), # 蓝色
(255, 255, 0), # 黄色
(255, 0, 255), # 紫色
(0, 255, 255), # 青色
(255, 128, 0), # 橙色
(128, 0, 255), # 紫罗兰
(255, 192, 203), # 粉色
(128, 128, 128) # 灰色
]
# 默认类别颜色配置 (BGR格式)
DEFAULT_CLASS_COLORS = {
0: (0, 255, 0), # 绿色
1: (255, 0, 0), # 红色
2: (0, 0, 255), # 蓝色
3: (255, 255, 0), # 黄色
4: (255, 0, 255) # 紫色
}
@classmethod
def get_model_path(cls):
"""获取模型文件的绝对路径"""
current_dir = Path(__file__).parent
model_path = current_dir / cls.MODEL_PATH
return str(model_path.resolve())
@classmethod
def create_folders(cls):
"""创建必要的文件夹"""
folders = [cls.INPUT_FOLDER, cls.OUTPUT_FOLDER, cls.LOG_FOLDER]
for folder in folders:
Path(folder).mkdir(parents=True, exist_ok=True)
@classmethod
def load_model_classes(cls, model_path):
"""从模型文件动态加载类别信息"""
try:
model = YOLO(model_path)
cls._model_classes = model.names
# 动态生成类别颜色
cls._model_colors = {}
for i in range(len(model.names)):
# 使用调色板循环分配颜色
color_index = i % len(cls.COLOR_PALETTE)
cls._model_colors[i] = cls.COLOR_PALETTE[color_index]
print(f"成功加载 {len(model.names)} 个类别,分配了对应颜色")
return True
except Exception as e:
print(f"加载模型类别信息失败: {e}")
return False
@classmethod
def get_class_name_cn(cls, class_id):
"""获取类别的中文名称"""
# 直接使用模型的原始类别名称,不进行任何映射
if cls._model_classes is not None:
original_name = cls._model_classes.get(class_id)
if original_name:
return original_name
return f"未知类别_{class_id}"
# 回退到默认中文类别
return cls.DEFAULT_CLASS_NAMES_CN.get(class_id, f"未知类别_{class_id}")
@classmethod
def get_class_color(cls, class_id):
"""获取类别对应的颜色"""
# 如果class_id是字符串尝试转换为整数或使用hash
if isinstance(class_id, str):
# 尝试从类别名称获取ID
if cls._model_classes and class_id in cls._model_classes.values():
# 找到对应的ID
for id_val, name in cls._model_classes.items():
if name == class_id:
class_id = id_val
break
else:
# 使用hash生成一个稳定的索引
class_id = hash(class_id) % len(cls.COLOR_PALETTE)
# 确保class_id是整数
if not isinstance(class_id, int):
class_id = 0
if cls._model_colors is not None:
return cls._model_colors.get(class_id, cls.COLOR_PALETTE[class_id % len(cls.COLOR_PALETTE)])
return cls.DEFAULT_CLASS_COLORS.get(class_id, cls.COLOR_PALETTE[class_id % len(cls.COLOR_PALETTE)])
@classmethod
def get_all_classes(cls):
"""获取所有类别信息"""
if cls._model_classes is not None:
return cls._model_classes
return cls.DEFAULT_CLASS_NAMES_CN
# 全局配置实例
config = DetectionConfig()
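For reference, app.py mutates this global config instance before each detection run; a minimal sketch of the same pattern in a standalone script (the model path is a placeholder):

from config import config, DetectionConfig

config.MODEL_PATH = "models/best.pt"   # placeholder; point at the uploaded weights
config.CONFIDENCE_THRESHOLD = 0.3
config.create_folders()
# Optionally preload class names and colors from the weights file.
DetectionConfig.load_model_classes(config.MODEL_PATH)
print(DetectionConfig.get_all_classes())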

detector.py Normal file

@@ -0,0 +1,453 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
视频目标检测系统
使用YOLO模型对视频进行道路损伤检测并抽取包含目标的帧
"""
import os
import cv2
import numpy as np
from pathlib import Path
from ultralytics import YOLO
import argparse
from datetime import datetime
import json
import logging
import torch
from PIL import Image, ImageDraw, ImageFont
from config import config, DetectionConfig
class VideoDetectionSystem:
def __init__(self, model_path, confidence_threshold=0.5, input_folder="input_videos", output_folder="output_frames"):
"""
初始化视频检测系统
Args:
model_path (str): YOLO模型文件路径
confidence_threshold (float): 置信度阈值
input_folder (str): 输入视频文件夹
output_folder (str): 输出图片文件夹
"""
self.model_path = model_path
self.confidence_threshold = confidence_threshold
self.input_folder = Path(input_folder)
self.output_folder = Path(output_folder)
# 创建输出文件夹
self.output_folder.mkdir(parents=True, exist_ok=True)
# 设置日志
self.setup_logging()
# 加载模型
self.load_model()
# 支持的视频格式
self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm'}
# 动态检测参数
self.adaptive_confidence = True # 启用自适应置信度
self.min_confidence = 0.2 # 最低置信度
self.max_confidence = 0.8 # 最高置信度
self.scene_adaptation = True # 启用场景自适应
# 检测结果统计
self.detection_stats = {
'total_videos': 0,
'total_frames': 0,
'detected_frames': 0,
'saved_frames': 0,
'detection_results': []
}
# 统计信息
self.stats = {
'adaptive_detections': 0
}
# 检测质量阈值
self.detection_quality_threshold = 0.3
def setup_logging(self):
"""设置日志记录"""
log_folder = Path('logs')
log_folder.mkdir(exist_ok=True)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = log_folder / f'detection_{timestamp}.log'
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(log_file, encoding='utf-8'),
logging.StreamHandler()
]
)
self.logger = logging.getLogger(__name__)
def load_model(self):
"""加载YOLO模型"""
try:
self.model = YOLO(self.model_path)
self.logger.info(f"成功加载模型: {self.model_path}")
# 动态加载模型类别信息到配置中
DetectionConfig.load_model_classes(self.model_path)
# 打印模型信息
if hasattr(self.model, 'names'):
self.logger.info(f"模型类别数量: {len(self.model.names)}")
self.logger.info(f"检测类别: {list(self.model.names.values())}")
self.logger.info(f"动态加载的类别信息: {DetectionConfig.get_all_classes()}")
# 分析模型类型并设置优化参数
self.analyze_model_type()
except Exception as e:
self.logger.error(f"加载模型失败: {e}")
raise
def analyze_model_type(self):
"""分析模型类型并设置相应的检测参数"""
model_classes = list(self.model.names.values())
# 检测是否为道路损伤专用模型
road_damage_keywords = ['crack', 'pothole', 'damage', '裂缝', '坑洞', '损伤']
is_road_damage_model = any(keyword in ' '.join(model_classes).lower() for keyword in road_damage_keywords)
if is_road_damage_model:
self.logger.info("检测到道路损伤专用模型,使用优化参数")
self.confidence_threshold = max(0.25, self.confidence_threshold) # 道路损伤检测使用较低阈值
self.min_confidence = 0.15
else:
self.logger.info("检测到通用模型,使用标准参数")
self.confidence_threshold = max(0.4, self.confidence_threshold) # 通用模型使用较高阈值
self.min_confidence = 0.3
self.logger.info(f"调整后的置信度阈值: {self.confidence_threshold}")
def get_video_files(self):
"""获取输入文件夹中的所有视频文件"""
video_files = []
if not self.input_folder.exists():
self.logger.warning(f"输入文件夹不存在: {self.input_folder}")
return video_files
for file_path in self.input_folder.rglob('*'):
if file_path.is_file() and file_path.suffix.lower() in self.video_extensions:
video_files.append(file_path)
self.logger.info(f"找到 {len(video_files)} 个视频文件")
return video_files
def detect_frame(self, frame, adaptive_conf=None):
"""对单帧进行目标检测"""
try:
# 使用自适应置信度或默认置信度
conf_threshold = adaptive_conf if adaptive_conf is not None else self.confidence_threshold
# 多尺度检测以提高检测率
results = self.model(frame, conf=conf_threshold, verbose=False,
imgsz=640, augment=True) # 启用测试时增强
# 如果没有检测到目标且启用自适应,尝试降低置信度
if self.adaptive_confidence and (not results or not results[0].boxes or len(results[0].boxes) == 0):
if conf_threshold > self.min_confidence:
lower_conf = max(self.min_confidence, conf_threshold - 0.1)
self.logger.debug(f"降低置信度重试: {lower_conf}")
results = self.model(frame, conf=lower_conf, verbose=False,
imgsz=640, augment=True)
return results[0] if results else None
except Exception as e:
self.logger.error(f"检测失败: {e}")
return None
def adaptive_detect_frame(self, frame, base_confidence):
"""自适应检测帧"""
try:
# 尝试不同的置信度阈值
confidence_levels = [base_confidence, base_confidence * 0.8, base_confidence * 0.6]
for conf in confidence_levels:
if conf < self.min_confidence:
continue
results = self.model(frame, conf=conf, verbose=False)
if len(results) > 0 and len(results[0].boxes) > 0:
# 过滤低质量检测
boxes = results[0].boxes
valid_indices = []
for i in range(len(boxes)):
confidence = float(boxes.conf[i])
if confidence >= self.min_confidence:
valid_indices.append(i)
if valid_indices:
# 直接使用原始结果,但只返回高质量检测
return results
return []
except Exception as e:
print(f"自适应检测失败: {e}")
return []
def draw_detections(self, frame, result):
"""在帧上绘制检测结果"""
if result is None or result.boxes is None:
return frame
# 转换为PIL图像以支持中文字体
pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(pil_image)
# 尝试加载中文字体
try:
# Windows系统字体路径
font_path = "C:/Windows/Fonts/simhei.ttf" # 黑体
if not os.path.exists(font_path):
font_path = "C:/Windows/Fonts/msyh.ttf" # 微软雅黑
if not os.path.exists(font_path):
font_path = "C:/Windows/Fonts/simsun.ttc" # 宋体
if os.path.exists(font_path):
font = ImageFont.truetype(font_path, 20)
else:
font = ImageFont.load_default()
except:
font = ImageFont.load_default()
for box in result.boxes:
# 获取边界框坐标
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
# 获取置信度和类别
confidence = float(box.conf[0].cpu().numpy())
class_id = int(box.cls[0].cpu().numpy())
# 只显示中文类别名称
display_name = config.get_class_name_cn(class_id)
# 根据置信度调整颜色强度
base_color = config.get_class_color(class_id)
intensity = min(1.0, confidence + 0.3) # 确保颜色不会太暗
color = tuple(int(c * intensity) for c in base_color)
# 绘制边界框使用PIL
bg_color = tuple(reversed(color)) # BGR转RGB
line_width = max(2, int(confidence * 4)) # 根据置信度调整线宽
draw.rectangle([x1, y1, x2, y2], outline=bg_color, width=line_width)
# 准备标签文字
label = f"{display_name}: {confidence:.3f}"
# 获取文字尺寸
bbox = draw.textbbox((0, 0), label, font=font)
text_width = bbox[2] - bbox[0]
text_height = bbox[3] - bbox[1]
# 绘制标签背景使用PIL
draw.rectangle([x1, y1 - text_height - 10, x1 + text_width + 10, y1], fill=bg_color)
# 绘制文字使用PIL
draw.text((x1 + 5, y1 - text_height - 5), label, font=font, fill=(255, 255, 255))
# 转换回OpenCV格式
annotated_frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
return annotated_frame
def process_video(self, video_path):
"""处理单个视频文件"""
self.logger.info(f"开始处理视频: {video_path.name}")
# 创建视频专用输出文件夹
video_output_folder = self.output_folder / video_path.stem
video_output_folder.mkdir(exist_ok=True)
# 文件夹将在检测到对应类别时按需创建
# 打开视频
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
self.logger.error(f"无法打开视频文件: {video_path}")
return
# 获取视频信息
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
duration = total_frames / fps if fps > 0 else 0
self.logger.info(f"视频信息 - 帧率: {fps:.2f}, 总帧数: {total_frames}, 时长: {duration:.2f}")
frame_count = 0
detected_count = 0
saved_count = 0
video_detections = {
'video_name': video_path.name,
'total_frames': total_frames,
'fps': fps,
'duration': duration,
'detections': []
}
try:
while True:
ret, frame = cap.read()
if not ret:
break
frame_count += 1
# 每100帧显示一次进度
if frame_count % 100 == 0:
progress = (frame_count / total_frames) * 100
self.logger.info(f"处理进度: {progress:.1f}% ({frame_count}/{total_frames})")
# 进行目标检测
result = self.detect_frame(frame)
# 如果检测到目标,保存帧
if result is not None and result.boxes is not None and len(result.boxes) > 0:
detected_count += 1
timestamp = frame_count / fps if fps > 0 else frame_count
# 获取当前帧检测到的所有类别
detected_classes = set()
for box in result.boxes:
class_id = int(box.cls[0].cpu().numpy())
detected_classes.add(class_id)
# 只为检测到的类别创建文件夹并保存图片
annotated_frame = self.draw_detections(frame, result)
detection_info = {
'frame_number': frame_count,
'timestamp': timestamp,
'detections': []
}
# 按检测到的类别保存图片
for class_id in detected_classes:
class_name_cn = config.get_class_name_cn(class_id)
class_name_en = self.model.names[class_id] if hasattr(self.model, 'names') else str(class_id)
# 只为实际检测到的类别创建文件夹
class_folder = video_output_folder / class_name_cn
class_folder.mkdir(exist_ok=True)
# 保存检测到该类别的帧(只保存标注后的图片)
annotated_filename = f"detected_{frame_count:06d}_t{timestamp:.2f}s_{class_name_cn}.jpg"
annotated_path = class_folder / annotated_filename
cv2.imwrite(str(annotated_path), annotated_frame)
saved_count += 1
# 记录检测信息
for box in result.boxes:
confidence = float(box.conf[0].cpu().numpy())
class_id = int(box.cls[0].cpu().numpy())
class_name = self.model.names[class_id] if hasattr(self.model, 'names') else str(class_id)
bbox = box.xyxy[0].cpu().numpy().tolist()
detection_info['detections'].append({
'class_name': class_name,
'class_name_cn': config.get_class_name_cn(class_id),
'class_id': class_id,
'confidence': confidence,
'bbox': bbox
})
video_detections['detections'].append(detection_info)
finally:
cap.release()
# 保存检测结果到JSON文件
json_path = video_output_folder / f"{video_path.stem}_detections.json"
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(video_detections, f, ensure_ascii=False, indent=2)
self.logger.info(f"视频处理完成: {video_path.name}")
self.logger.info(f"总帧数: {frame_count}, 检测到目标的帧数: {detected_count}, 保存图片数: {saved_count}")
# 更新统计信息
self.detection_stats['total_videos'] += 1
self.detection_stats['total_frames'] += frame_count
self.detection_stats['detected_frames'] += detected_count
self.detection_stats['saved_frames'] += saved_count
self.detection_stats['detection_results'].append(video_detections)
def process_all_videos(self):
"""处理所有视频文件"""
video_files = self.get_video_files()
if not video_files:
self.logger.warning("没有找到视频文件")
return
self.logger.info(f"开始处理 {len(video_files)} 个视频文件")
for i, video_path in enumerate(video_files, 1):
self.logger.info(f"\n=== 处理第 {i}/{len(video_files)} 个视频 ===")
self.process_video(video_path)
# 保存总体统计信息
self.save_summary_report()
self.logger.info("\n=== 处理完成 ===")
self.logger.info(f"总计处理视频: {self.detection_stats['total_videos']}")
self.logger.info(f"总计处理帧数: {self.detection_stats['total_frames']}")
self.logger.info(f"检测到目标的帧数: {self.detection_stats['detected_frames']}")
self.logger.info(f"保存图片数量: {self.detection_stats['saved_frames']}")
def save_summary_report(self):
"""保存总结报告"""
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
report_path = self.output_folder / f"detection_summary_{timestamp}.json"
summary = {
'timestamp': timestamp,
'model_path': str(self.model_path),
'confidence_threshold': self.confidence_threshold,
'input_folder': str(self.input_folder),
'output_folder': str(self.output_folder),
'statistics': self.detection_stats
}
with open(report_path, 'w', encoding='utf-8') as f:
json.dump(summary, f, ensure_ascii=False, indent=2)
self.logger.info(f"总结报告已保存: {report_path}")
def main():
parser = argparse.ArgumentParser(description='视频目标检测系统')
parser.add_argument('--model', type=str,
default='../Japan/training_results/continue_from_best_20250610_130607/weights/best.pt',
help='YOLO模型文件路径')
parser.add_argument('--confidence', type=float, default=0.5,
help='置信度阈值 (默认: 0.5)')
parser.add_argument('--input', type=str, default='input_videos',
help='输入视频文件夹 (默认: input_videos)')
parser.add_argument('--output', type=str, default='output_frames',
help='输出图片文件夹 (默认: output_frames)')
args = parser.parse_args()
# 创建检测系统
detector = VideoDetectionSystem(
model_path=args.model,
confidence_threshold=args.confidence,
input_folder=args.input,
output_folder=args.output
)
# 处理所有视频
detector.process_all_videos()
if __name__ == '__main__':
main()
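VideoDetectionSystem can likewise be used without the argparse wrapper; a brief sketch (all paths are placeholders) that processes a single video:

from pathlib import Path
from detector import VideoDetectionSystem

# Placeholder paths; adjust to the actual weights file and folders.
system = VideoDetectionSystem(
    model_path="models/best.pt",
    confidence_threshold=0.5,
    input_folder="input_videos",
    output_folder="output_frames",
)
system.process_video(Path("input_videos/sample.mp4"))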

requirements.txt Normal file

@@ -0,0 +1,33 @@
# 视频检测系统依赖包
# 核心依赖
ultralytics>=8.0.0
opencv-python>=4.5.0
numpy>=1.21.0
Pillow>=8.0.0
# 数据处理
pandas>=1.3.0
# 进度显示
tqdm>=4.62.0
# 配置文件处理
PyYAML>=6.0
# 图像处理增强
scipy>=1.7.0
matplotlib>=3.4.0
# GPU支持 (可选)
# torch>=1.9.0
# torchvision>=0.10.0
# 系统工具
psutil>=5.8.0
# Web应用框架
flask>=2.0.0
werkzeug>=2.0.0
flask-wtf>=1.0.0
python-dotenv>=0.15.0