上传文件至 /
This commit is contained in:
parent
14418af6bb
commit
9ccc897c3c
734
app.py
Normal file
734
app.py
Normal file
@ -0,0 +1,734 @@
|
||||
from flask import Flask, render_template, request, redirect, url_for, flash, jsonify, send_from_directory
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
import threading
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from flask import Flask, render_template, request, redirect, url_for, flash, jsonify
|
||||
from werkzeug.utils import secure_filename
|
||||
from config import config
|
||||
from batch_detector import BatchVideoDetector
|
||||
|
||||
app = Flask(__name__)
|
||||
app.secret_key = 'road_damage_detection_system'
|
||||
app.config['MAX_CONTENT_LENGTH'] = 500 * 1024 * 1024 # 500MB max upload size
|
||||
|
||||
# 确保必要的文件夹存在
|
||||
for folder in [config.INPUT_FOLDER, config.OUTPUT_FOLDER, config.LOG_FOLDER, 'models', 'static/css']:
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
|
||||
# 全局变量用于跟踪检测进度和状态
|
||||
detection_progress = 0
|
||||
detection_status = "idle" # idle, running, completed, error
|
||||
detection_thread = None
|
||||
current_detection_info = {}
|
||||
detection_start_time = None
|
||||
detection_results = None
|
||||
current_report_path = None
|
||||
current_video_name = ""
|
||||
total_videos = 0
|
||||
current_video_index = 0
|
||||
current_frame_count = 0
|
||||
total_frame_count = 0
|
||||
|
||||
# 允许的文件扩展名
|
||||
ALLOWED_VIDEO_EXTENSIONS = {ext.lstrip('.') for ext in config.SUPPORTED_VIDEO_FORMATS}
|
||||
ALLOWED_MODEL_EXTENSIONS = {'pt', 'pth', 'weights'}
|
||||
|
||||
def allowed_file(filename, allowed_extensions):
|
||||
return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed_extensions
|
||||
|
||||
# 主页
|
||||
@app.route('/', methods=['GET', 'POST'])
|
||||
def index():
|
||||
global detection_thread, detection_status, detection_progress, detection_results, current_report_path, detection_start_time
|
||||
|
||||
# 处理POST请求(开始检测)
|
||||
if request.method == 'POST':
|
||||
if detection_status == "running":
|
||||
flash('已有检测任务正在运行', 'warning')
|
||||
else:
|
||||
# 获取选择的模型和置信度阈值
|
||||
model_name = request.form.get('model_select')
|
||||
confidence_threshold = float(request.form.get('confidence_threshold', 0.25))
|
||||
|
||||
# 检查是否有视频文件
|
||||
videos = [f for f in os.listdir(config.INPUT_FOLDER)
|
||||
if allowed_file(f, ALLOWED_VIDEO_EXTENSIONS)]
|
||||
if not videos:
|
||||
flash('输入文件夹中没有视频文件,请先上传视频', 'danger')
|
||||
elif not model_name:
|
||||
flash('请选择一个检测模型', 'danger')
|
||||
else:
|
||||
model_path = os.path.join('models', model_name)
|
||||
if not os.path.exists(model_path):
|
||||
flash(f'模型文件 {model_name} 不存在', 'danger')
|
||||
else:
|
||||
# 更新配置
|
||||
config.MODEL_PATH = model_path
|
||||
config.CONFIDENCE_THRESHOLD = confidence_threshold
|
||||
|
||||
# 重置检测状态和进度
|
||||
detection_progress = 0
|
||||
detection_status = "running"
|
||||
detection_results = None
|
||||
current_report_path = None
|
||||
detection_start_time = datetime.now()
|
||||
|
||||
# 启动检测线程
|
||||
detection_thread = threading.Thread(target=run_detection)
|
||||
detection_thread.daemon = True
|
||||
detection_thread.start()
|
||||
|
||||
flash('检测已开始,请等待结果', 'info')
|
||||
|
||||
# 获取可用的模型列表
|
||||
print(f"models文件夹路径: {os.path.abspath('models')}")
|
||||
print(f"models文件夹存在: {os.path.exists('models')}")
|
||||
if os.path.exists('models'):
|
||||
all_files = os.listdir('models')
|
||||
print(f"models文件夹中的所有文件: {all_files}")
|
||||
for f in all_files:
|
||||
print(f"检查文件 {f}: allowed_file结果 = {allowed_file(f, ALLOWED_MODEL_EXTENSIONS)}")
|
||||
models = [f for f in os.listdir('models') if allowed_file(f, ALLOWED_MODEL_EXTENSIONS)]
|
||||
print(f"找到的模型文件: {models}")
|
||||
|
||||
# 获取输入文件夹中的视频文件
|
||||
videos = [f for f in os.listdir(config.INPUT_FOLDER)
|
||||
if allowed_file(f, ALLOWED_VIDEO_EXTENSIONS)]
|
||||
|
||||
# 获取检测报告列表
|
||||
reports = []
|
||||
if os.path.exists(config.OUTPUT_FOLDER):
|
||||
for report_file in os.listdir(config.OUTPUT_FOLDER):
|
||||
if report_file.startswith('batch_detection_report_') and report_file.endswith('.json'):
|
||||
report_path = os.path.join(config.OUTPUT_FOLDER, report_file)
|
||||
try:
|
||||
with open(report_path, 'r', encoding='utf-8') as f:
|
||||
report_data = json.load(f)
|
||||
reports.append({
|
||||
'id': report_file,
|
||||
'timestamp': report_data.get('timestamp', '未知'),
|
||||
'video_count': report_data.get('statistics', {}).get('processed_videos', 0),
|
||||
'total_frames': report_data.get('statistics', {}).get('total_frames', 0),
|
||||
'detected_frames': report_data.get('statistics', {}).get('detected_frames', 0)
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"Error loading report {report_path}: {e}")
|
||||
|
||||
# 按时间戳排序,最新的在前面
|
||||
reports.sort(key=lambda x: x['timestamp'], reverse=True)
|
||||
|
||||
return render_template('index.html',
|
||||
models=models,
|
||||
videos=videos,
|
||||
reports=reports,
|
||||
detection_status=detection_status,
|
||||
detection_progress=detection_progress,
|
||||
detection_in_progress=(detection_status == "running"),
|
||||
confidence_threshold=config.CONFIDENCE_THRESHOLD,
|
||||
config=config)
|
||||
|
||||
# 上传模型
|
||||
@app.route('/upload_model', methods=['POST'])
|
||||
def upload_model():
|
||||
if 'model_file' not in request.files:
|
||||
flash('没有选择文件', 'danger')
|
||||
return redirect(url_for('index'))
|
||||
|
||||
file = request.files['model_file']
|
||||
if file.filename == '':
|
||||
flash('没有选择文件', 'danger')
|
||||
return redirect(url_for('index'))
|
||||
|
||||
if file and allowed_file(file.filename, ALLOWED_MODEL_EXTENSIONS):
|
||||
filename = secure_filename(file.filename)
|
||||
model_path = os.path.join('models', filename)
|
||||
file.save(model_path)
|
||||
|
||||
# 立即检测模型信息并缓存
|
||||
try:
|
||||
print(f"正在分析上传的模型: {filename}")
|
||||
from ultralytics import YOLO
|
||||
model = YOLO(model_path)
|
||||
|
||||
# 保存模型信息到JSON文件
|
||||
model_info = {
|
||||
'filename': filename,
|
||||
'upload_time': datetime.now().isoformat(),
|
||||
'classes': dict(model.names) if hasattr(model, 'names') else {},
|
||||
'class_count': len(model.names) if hasattr(model, 'names') else 0,
|
||||
'model_type': 'YOLO',
|
||||
'analyzed': True
|
||||
}
|
||||
|
||||
# 创建模型信息文件
|
||||
info_path = os.path.join('models', f"{os.path.splitext(filename)[0]}_info.json")
|
||||
with open(info_path, 'w', encoding='utf-8') as f:
|
||||
import json
|
||||
json.dump(model_info, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 显示模型分析结果
|
||||
if hasattr(model, 'names'):
|
||||
scene_types = list(model.names.values())
|
||||
flash(f'模型 {filename} 上传成功!检测到 {len(scene_types)} 种场景类型: {", ".join(scene_types)}', 'success')
|
||||
else:
|
||||
flash(f'模型 {filename} 上传成功!', 'success')
|
||||
|
||||
except Exception as e:
|
||||
print(f"模型分析失败: {e}")
|
||||
flash(f'模型 {filename} 上传成功,但分析模型信息时出错: {str(e)}', 'warning')
|
||||
else:
|
||||
flash('不支持的文件类型', 'danger')
|
||||
|
||||
return redirect(url_for('index'))
|
||||
|
||||
# 获取模型列表
|
||||
@app.route('/get_models', methods=['GET'])
|
||||
def get_models():
|
||||
models = []
|
||||
model_files = [f for f in os.listdir('models') if allowed_file(f, ALLOWED_MODEL_EXTENSIONS)]
|
||||
|
||||
for model_file in model_files:
|
||||
model_info = {
|
||||
'filename': model_file,
|
||||
'analyzed': False,
|
||||
'classes': {},
|
||||
'class_count': 0,
|
||||
'scene_types': []
|
||||
}
|
||||
|
||||
# 尝试读取模型信息文件
|
||||
info_path = os.path.join('models', f"{os.path.splitext(model_file)[0]}_info.json")
|
||||
if os.path.exists(info_path):
|
||||
try:
|
||||
with open(info_path, 'r', encoding='utf-8') as f:
|
||||
import json
|
||||
cached_info = json.load(f)
|
||||
model_info.update(cached_info)
|
||||
model_info['scene_types'] = list(cached_info.get('classes', {}).values())
|
||||
except Exception as e:
|
||||
print(f"读取模型信息文件失败: {e}")
|
||||
|
||||
models.append(model_info)
|
||||
|
||||
return jsonify(models)
|
||||
|
||||
# 删除模型
|
||||
@app.route('/delete_model/<filename>', methods=['POST'])
|
||||
def delete_model(filename):
|
||||
try:
|
||||
model_path = os.path.join('models', filename)
|
||||
info_path = os.path.join('models', f"{os.path.splitext(filename)[0]}_info.json")
|
||||
|
||||
deleted_files = []
|
||||
if os.path.exists(model_path):
|
||||
os.remove(model_path)
|
||||
deleted_files.append('模型文件')
|
||||
|
||||
# 同时删除模型信息文件
|
||||
if os.path.exists(info_path):
|
||||
os.remove(info_path)
|
||||
deleted_files.append('信息文件')
|
||||
|
||||
if deleted_files:
|
||||
flash(f'模型 {filename} 及其{"、".join(deleted_files)}已删除', 'success')
|
||||
else:
|
||||
flash(f'模型 {filename} 不存在', 'warning')
|
||||
except Exception as e:
|
||||
flash(f'删除模型时出错: {str(e)}', 'danger')
|
||||
|
||||
return redirect(url_for('index'))
|
||||
|
||||
# 上传视频
|
||||
@app.route('/upload_video', methods=['POST'])
|
||||
def upload_video():
|
||||
if 'video_files[]' not in request.files:
|
||||
flash('没有选择文件', 'danger')
|
||||
return redirect(url_for('index'))
|
||||
|
||||
files = request.files.getlist('video_files[]')
|
||||
if not files or files[0].filename == '':
|
||||
flash('没有选择文件', 'danger')
|
||||
return redirect(url_for('index'))
|
||||
|
||||
success_count = 0
|
||||
error_count = 0
|
||||
|
||||
for file in files:
|
||||
if file and allowed_file(file.filename, ALLOWED_VIDEO_EXTENSIONS):
|
||||
filename = secure_filename(file.filename)
|
||||
file.save(os.path.join(config.INPUT_FOLDER, filename))
|
||||
success_count += 1
|
||||
else:
|
||||
error_count += 1
|
||||
|
||||
if success_count > 0:
|
||||
flash(f'成功上传 {success_count} 个视频文件', 'success')
|
||||
if error_count > 0:
|
||||
flash(f'{error_count} 个文件上传失败(不支持的文件类型)', 'warning')
|
||||
|
||||
return redirect(url_for('index'))
|
||||
|
||||
# 获取视频列表
|
||||
@app.route('/get_videos', methods=['GET'])
|
||||
def get_videos():
|
||||
videos = [f for f in os.listdir(config.INPUT_FOLDER) if allowed_file(f, ALLOWED_VIDEO_EXTENSIONS)]
|
||||
return jsonify(videos)
|
||||
|
||||
# 删除视频
|
||||
@app.route('/delete_video/<filename>', methods=['POST'])
|
||||
def delete_video(filename):
|
||||
try:
|
||||
video_path = os.path.join(config.INPUT_FOLDER, filename)
|
||||
if os.path.exists(video_path):
|
||||
os.remove(video_path)
|
||||
flash(f'视频 {filename} 已删除', 'success')
|
||||
else:
|
||||
flash(f'视频 {filename} 不存在', 'warning')
|
||||
except Exception as e:
|
||||
flash(f'删除视频时出错: {str(e)}', 'danger')
|
||||
|
||||
return redirect(url_for('index'))
|
||||
|
||||
# 清空输入文件夹
|
||||
@app.route('/clear_input_folder', methods=['POST'])
|
||||
def clear_input_folder():
|
||||
try:
|
||||
for filename in os.listdir(config.INPUT_FOLDER):
|
||||
file_path = os.path.join(config.INPUT_FOLDER, filename)
|
||||
if os.path.isfile(file_path):
|
||||
os.unlink(file_path)
|
||||
flash('输入文件夹已清空', 'success')
|
||||
except Exception as e:
|
||||
flash(f'清空文件夹时出错: {str(e)}', 'danger')
|
||||
|
||||
return redirect(url_for('index'))
|
||||
|
||||
# 开始检测
|
||||
@app.route('/start_detection', methods=['POST'])
|
||||
def start_detection():
|
||||
global detection_thread, detection_status, detection_progress, detection_results, current_report_path, detection_start_time
|
||||
|
||||
print("收到检测请求")
|
||||
print(f"当前检测状态: {detection_status}")
|
||||
|
||||
if detection_status == "running":
|
||||
flash('已有检测任务正在运行', 'warning')
|
||||
return redirect(url_for('index'))
|
||||
|
||||
# 获取选择的模型和置信度阈值
|
||||
model_name = request.form.get('model_select')
|
||||
confidence_threshold = float(request.form.get('confidence_threshold', 0.25))
|
||||
|
||||
print(f"接收到的模型名称: {model_name}")
|
||||
print(f"接收到的置信度阈值: {confidence_threshold}")
|
||||
|
||||
# 检查是否有视频文件
|
||||
videos = [f for f in os.listdir(config.INPUT_FOLDER)
|
||||
if allowed_file(f, ALLOWED_VIDEO_EXTENSIONS)]
|
||||
if not videos:
|
||||
flash('输入文件夹中没有视频文件,请先上传视频', 'danger')
|
||||
return redirect(url_for('index'))
|
||||
|
||||
# 检查模型文件是否存在
|
||||
if not model_name:
|
||||
flash('请选择一个检测模型', 'danger')
|
||||
return redirect(url_for('index'))
|
||||
|
||||
model_path = os.path.join('models', model_name)
|
||||
if not os.path.exists(model_path):
|
||||
flash(f'模型文件 {model_name} 不存在', 'danger')
|
||||
return redirect(url_for('index'))
|
||||
|
||||
# 更新配置
|
||||
config.MODEL_PATH = model_path
|
||||
config.CONFIDENCE_THRESHOLD = confidence_threshold
|
||||
|
||||
# 重置检测状态和进度
|
||||
detection_progress = 0
|
||||
detection_status = "running"
|
||||
detection_results = None
|
||||
current_report_path = None
|
||||
detection_start_time = datetime.now()
|
||||
|
||||
# 重置进度相关变量
|
||||
global current_video_name, total_videos, current_video_index, current_frame_count, total_frame_count
|
||||
current_video_name = ""
|
||||
total_videos = len(videos)
|
||||
current_video_index = 0
|
||||
current_frame_count = 0
|
||||
total_frame_count = 0
|
||||
|
||||
# 启动检测线程
|
||||
detection_thread = threading.Thread(target=run_detection)
|
||||
detection_thread.daemon = True
|
||||
detection_thread.start()
|
||||
|
||||
flash('检测已开始,请等待结果', 'info')
|
||||
return redirect(url_for('index'))
|
||||
|
||||
# 检测线程函数
|
||||
def run_detection():
|
||||
global detection_progress, detection_status, detection_results, current_report_path, current_video_name, current_video_index, current_frame_count
|
||||
|
||||
try:
|
||||
print(f"开始检测,模型路径: {config.MODEL_PATH}")
|
||||
print(f"置信度阈值: {config.CONFIDENCE_THRESHOLD}")
|
||||
print(f"输入文件夹: {config.INPUT_FOLDER}")
|
||||
|
||||
# 获取视频列表
|
||||
videos = [f for f in os.listdir(config.INPUT_FOLDER)
|
||||
if any(f.endswith(ext) for ext in config.SUPPORTED_VIDEO_FORMATS)]
|
||||
print(f"找到视频文件: {videos}")
|
||||
|
||||
# 创建检测报告基本信息
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
|
||||
# 创建检测器实例
|
||||
print("正在创建检测器实例...")
|
||||
detector = BatchVideoDetector(
|
||||
model_path=config.MODEL_PATH,
|
||||
confidence=config.CONFIDENCE_THRESHOLD
|
||||
)
|
||||
print("检测器实例创建成功")
|
||||
|
||||
# 显示模型检测能力信息
|
||||
print("\n=== 网页端模型检测分析 ===")
|
||||
if hasattr(detector.model, 'names'):
|
||||
model_classes = list(detector.model.names.values())
|
||||
print(f"当前模型可检测场景: {model_classes}")
|
||||
print(f"使用的置信度阈值: {config.CONFIDENCE_THRESHOLD}")
|
||||
print("=" * 40)
|
||||
|
||||
# 设置进度回调
|
||||
def progress_callback(status, progress, video_name=None, video_index=None, frame_count=None):
|
||||
global detection_status, detection_progress, current_video_name, current_video_index, current_frame_count
|
||||
detection_status = status
|
||||
detection_progress = progress
|
||||
|
||||
# 更新详细进度信息
|
||||
if video_name:
|
||||
current_video_name = video_name
|
||||
if video_index is not None:
|
||||
current_video_index = video_index
|
||||
if frame_count is not None:
|
||||
current_frame_count = frame_count
|
||||
|
||||
print(f"进度更新: {status} - {progress}%")
|
||||
if video_name:
|
||||
print(f"当前视频: {video_name} ({video_index}/{total_videos})")
|
||||
if frame_count is not None:
|
||||
print(f"已处理帧数: {frame_count}")
|
||||
|
||||
detector.set_progress_callback(progress_callback)
|
||||
|
||||
# 执行检测
|
||||
detection_status = "开始检测视频..."
|
||||
report_path = detector.process_all_videos(config.INPUT_FOLDER, config.OUTPUT_FOLDER)
|
||||
|
||||
# 读取检测结果
|
||||
if report_path and os.path.exists(report_path):
|
||||
with open(report_path, 'r', encoding='utf-8') as f:
|
||||
detection_results = json.load(f)
|
||||
|
||||
current_report_path = report_path
|
||||
detection_status = "检测完成"
|
||||
detection_progress = 100
|
||||
else:
|
||||
detection_status = "检测完成,但未生成报告"
|
||||
detection_progress = 100
|
||||
except Exception as e:
|
||||
detection_status = f"检测失败: {str(e)}"
|
||||
print(f"检测异常: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# 如果状态仍为running,则更新为失败
|
||||
if detection_status == "running":
|
||||
detection_status = "失败"
|
||||
# 确保进度显示为100%
|
||||
if detection_progress < 100:
|
||||
detection_progress = 100
|
||||
|
||||
# 获取检测进度
|
||||
@app.route('/detection_progress')
|
||||
def get_detection_progress():
|
||||
# 计算检测时长
|
||||
elapsed_time = ""
|
||||
if detection_start_time and detection_status == "running":
|
||||
elapsed_seconds = (datetime.now() - detection_start_time).total_seconds()
|
||||
minutes, seconds = divmod(elapsed_seconds, 60)
|
||||
elapsed_time = f"{int(minutes)}分{int(seconds)}秒"
|
||||
|
||||
# 构建详细状态信息
|
||||
detailed_status = detection_status
|
||||
if detection_status == "running" and current_video_name:
|
||||
detailed_status = f"正在处理: {current_video_name} ({current_video_index}/{total_videos})"
|
||||
if current_frame_count > 0:
|
||||
detailed_status += f" - 已处理帧数: {current_frame_count}"
|
||||
|
||||
return jsonify({
|
||||
'in_progress': detection_status == "running",
|
||||
'progress': detection_progress,
|
||||
'status': detailed_status,
|
||||
'elapsed_time': elapsed_time,
|
||||
'current_video': current_video_name,
|
||||
'current_video_index': current_video_index,
|
||||
'total_videos': total_videos,
|
||||
'current_frame_count': current_frame_count,
|
||||
'total_frame_count': total_frame_count
|
||||
})
|
||||
|
||||
# 查看检测报告
|
||||
@app.route('/report/<report_filename>')
|
||||
def view_report(report_filename):
|
||||
report_path = os.path.join(config.OUTPUT_FOLDER, report_filename)
|
||||
if not os.path.exists(report_path):
|
||||
flash('报告文件不存在', 'danger')
|
||||
return redirect(url_for('index'))
|
||||
|
||||
try:
|
||||
with open(report_path, 'r', encoding='utf-8') as f:
|
||||
report_data = json.load(f)
|
||||
|
||||
# 计算检测率
|
||||
total_frames = report_data.get('statistics', {}).get('total_frames', 0)
|
||||
detected_frames = report_data.get('statistics', {}).get('detected_frames', 0)
|
||||
detection_rate = 0
|
||||
if total_frames > 0:
|
||||
detection_rate = (detected_frames / total_frames) * 100
|
||||
|
||||
# 为每个视频计算检测率
|
||||
for video in report_data.get('videos', []):
|
||||
video_total_frames = video.get('total_frames', 0)
|
||||
video_detected_frames = video.get('detected_frames', 0)
|
||||
if video_total_frames > 0:
|
||||
video['detection_rate'] = (video_detected_frames / video_total_frames) * 100
|
||||
else:
|
||||
video['detection_rate'] = 0
|
||||
|
||||
return render_template('report.html',
|
||||
report=report_data,
|
||||
report_filename=report_filename,
|
||||
detection_rate=detection_rate)
|
||||
|
||||
except Exception as e:
|
||||
flash(f'读取报告时出错: {str(e)}', 'danger')
|
||||
return redirect(url_for('index'))
|
||||
|
||||
# 查看视频检测结果
|
||||
@app.route('/video_results/<report_filename>/<video_name>')
|
||||
def view_video_results(report_filename, video_name):
|
||||
report_path = os.path.join(config.OUTPUT_FOLDER, report_filename)
|
||||
if not os.path.exists(report_path):
|
||||
flash('报告文件不存在', 'danger')
|
||||
return redirect(url_for('index'))
|
||||
|
||||
try:
|
||||
with open(report_path, 'r', encoding='utf-8') as f:
|
||||
report_data = json.load(f)
|
||||
|
||||
# 查找特定视频的结果
|
||||
video_data = None
|
||||
for video in report_data['videos']:
|
||||
if video['video_name'] == video_name:
|
||||
video_data = video
|
||||
break
|
||||
|
||||
if not video_data:
|
||||
flash('视频结果不存在', 'danger')
|
||||
return redirect(url_for('view_report', report_filename=report_filename))
|
||||
|
||||
# 获取检测到的帧列表
|
||||
video_output_dir = os.path.join(config.OUTPUT_FOLDER, os.path.splitext(video_name)[0])
|
||||
detected_frames = []
|
||||
|
||||
if os.path.exists(video_output_dir):
|
||||
for frame_file in os.listdir(video_output_dir):
|
||||
if frame_file.startswith('detected_') and frame_file.endswith('.jpg'):
|
||||
try:
|
||||
frame_number = int(frame_file.split('_')[1].split('.')[0])
|
||||
frame_path = os.path.join(os.path.basename(video_output_dir), frame_file)
|
||||
|
||||
# 获取对应的JSON文件以读取检测信息
|
||||
json_file = frame_file.replace('.jpg', '.json')
|
||||
json_path = os.path.join(video_output_dir, json_file)
|
||||
detections = []
|
||||
|
||||
if os.path.exists(json_path):
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
frame_data = json.load(f)
|
||||
detections = frame_data.get('detections', [])
|
||||
|
||||
detected_frames.append({
|
||||
'frame_number': frame_number,
|
||||
'filename': frame_file,
|
||||
'path': frame_path,
|
||||
'timestamp': frame_number / video_data.get('fps', 30), # 计算时间戳
|
||||
'detections': detections,
|
||||
'detection_count': len(detections)
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"Error processing frame {frame_file}: {e}")
|
||||
|
||||
# 按帧号排序
|
||||
detected_frames.sort(key=lambda x: x['frame_number'])
|
||||
|
||||
# 计算检测率
|
||||
total_frames = video_data.get('total_frames', 0)
|
||||
detected_count = video_data.get('detected_frames', 0)
|
||||
detection_rate = 0
|
||||
if total_frames > 0:
|
||||
detection_rate = (detected_count / total_frames) * 100
|
||||
|
||||
# 获取时间线数据
|
||||
timeline_data = []
|
||||
for frame in detected_frames:
|
||||
timeline_data.append({
|
||||
'timestamp': frame['timestamp'],
|
||||
'frame_number': frame['frame_number'],
|
||||
'detection_count': frame['detection_count']
|
||||
})
|
||||
|
||||
return render_template('video_results.html',
|
||||
report_filename=report_filename,
|
||||
report=report_data,
|
||||
video=video_data,
|
||||
detected_frames=detected_frames,
|
||||
detection_rate=detection_rate,
|
||||
timeline_data=json.dumps(timeline_data))
|
||||
|
||||
except Exception as e:
|
||||
flash(f'读取视频结果时出错: {str(e)}', 'danger')
|
||||
return redirect(url_for('view_report', report_filename=report_filename))
|
||||
|
||||
# 获取检测到的帧图像
|
||||
@app.route('/frame/<path:frame_path>')
|
||||
def get_frame(frame_path):
|
||||
directory, filename = os.path.split(frame_path)
|
||||
return send_from_directory(os.path.join(config.OUTPUT_FOLDER, directory), filename)
|
||||
|
||||
# 删除报告
|
||||
@app.route('/delete_report/<report_filename>', methods=['POST'])
|
||||
def delete_report(report_filename):
|
||||
report_path = os.path.join(config.OUTPUT_FOLDER, report_filename)
|
||||
if not os.path.exists(report_path):
|
||||
flash('报告文件不存在', 'danger')
|
||||
return redirect(url_for('index'))
|
||||
|
||||
# 读取报告以获取视频文件夹列表
|
||||
try:
|
||||
with open(report_path, 'r', encoding='utf-8') as f:
|
||||
report_data = json.load(f)
|
||||
|
||||
# 删除每个视频的输出文件夹
|
||||
for video in report_data['videos']:
|
||||
video_output_folder = os.path.join(config.OUTPUT_FOLDER, os.path.splitext(video['video_name'])[0])
|
||||
if os.path.exists(video_output_folder) and os.path.isdir(video_output_folder):
|
||||
shutil.rmtree(video_output_folder)
|
||||
|
||||
# 删除报告文件
|
||||
os.remove(report_path)
|
||||
flash(f'报告 {report_filename} 已删除', 'success')
|
||||
except Exception as e:
|
||||
flash(f'删除报告失败: {str(e)}', 'danger')
|
||||
|
||||
return redirect(url_for('index'))
|
||||
|
||||
# 添加BatchVideoDetector的进度回调方法
|
||||
def add_progress_callback_to_detector():
|
||||
# 检查BatchVideoDetector类是否已有set_progress_callback方法
|
||||
if not hasattr(BatchVideoDetector, 'set_progress_callback'):
|
||||
def set_progress_callback(self, callback):
|
||||
self.progress_callback = callback
|
||||
|
||||
def update_progress(self, status, progress, video_name=None, video_index=None, frame_count=None):
|
||||
if hasattr(self, 'progress_callback') and self.progress_callback:
|
||||
self.progress_callback(status, progress, video_name, video_index, frame_count)
|
||||
|
||||
# 添加方法到BatchVideoDetector类
|
||||
BatchVideoDetector.set_progress_callback = set_progress_callback
|
||||
BatchVideoDetector.update_progress = update_progress
|
||||
|
||||
# 修改process_all_videos方法以支持进度更新
|
||||
original_process_all_videos = BatchVideoDetector.process_all_videos
|
||||
|
||||
def process_all_videos_with_progress(self, input_folder="input_videos", output_folder="output_frames"):
|
||||
self.update_progress("获取视频文件列表", 0)
|
||||
video_files = self.get_video_files(input_folder)
|
||||
total_videos = len(video_files)
|
||||
|
||||
if total_videos == 0:
|
||||
self.update_progress("没有找到视频文件", 100)
|
||||
return None
|
||||
|
||||
self.update_progress(f"开始处理 {total_videos} 个视频文件", 5)
|
||||
|
||||
# 处理每个视频文件
|
||||
for i, video_path in enumerate(video_files, 1):
|
||||
progress = int(5 + (i / total_videos) * 90)
|
||||
video_name = video_path.name
|
||||
self.update_progress(f"处理视频 {i}/{total_videos}: {video_name}", progress, video_name, i)
|
||||
self.process_video(video_path, output_folder)
|
||||
|
||||
self.update_progress("生成检测报告", 95)
|
||||
self.stats['end_time'] = datetime.now()
|
||||
report_path = self.save_final_report(output_folder)
|
||||
self.print_summary()
|
||||
self.update_progress("检测完成", 100)
|
||||
return report_path
|
||||
|
||||
# 修改save_final_report方法以返回报告路径
|
||||
original_save_final_report = BatchVideoDetector.save_final_report
|
||||
|
||||
def save_final_report_with_return(self, output_folder):
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
report_path = Path(output_folder) / f"batch_detection_report_{timestamp}.json"
|
||||
|
||||
# 计算总处理时间
|
||||
total_seconds = 0
|
||||
if self.stats['start_time'] and self.stats['end_time']:
|
||||
total_seconds = (self.stats['end_time'] - self.stats['start_time']).total_seconds()
|
||||
|
||||
# 创建报告数据
|
||||
report_data = {
|
||||
'timestamp': timestamp,
|
||||
'model_path': self.model_path,
|
||||
'confidence_threshold': self.confidence,
|
||||
'frame_interval_threshold': config.MIN_FRAME_INTERVAL,
|
||||
'filter_strategy': '相同场景在指定时间间隔内只保存一次',
|
||||
'statistics': {
|
||||
'total_videos': self.stats['total_videos'],
|
||||
'processed_videos': self.stats['processed_videos'],
|
||||
'total_frames': self.stats['total_frames'],
|
||||
'detected_frames': self.stats['detected_frames'],
|
||||
'filtered_frames': self.stats['filtered_frames'],
|
||||
'saved_images': self.stats['saved_images'],
|
||||
'processing_time_seconds': total_seconds
|
||||
},
|
||||
'errors': self.stats['errors'],
|
||||
'videos': []
|
||||
}
|
||||
|
||||
# 保存为JSON
|
||||
with open(report_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(report_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"\n批处理报告已保存: {report_path}")
|
||||
return str(report_path)
|
||||
|
||||
BatchVideoDetector.process_all_videos = process_all_videos_with_progress
|
||||
BatchVideoDetector.save_final_report = save_final_report_with_return
|
||||
|
||||
# 在应用启动前添加进度回调方法
|
||||
add_progress_callback_to_detector()
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 确保必要的文件夹存在
|
||||
for folder in [config.INPUT_FOLDER, config.OUTPUT_FOLDER, config.LOG_FOLDER, 'models', 'static/css']:
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
|
||||
app.run(debug=True, host='0.0.0.0', port=5000)
|
||||
642
batch_detector.py
Normal file
642
batch_detector.py
Normal file
@ -0,0 +1,642 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
批量视频检测器
|
||||
简化版本,专注于高效处理大量视频文件
|
||||
"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from ultralytics import YOLO
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
import torch
|
||||
from config import config, DetectionConfig
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
class BatchVideoDetector:
|
||||
def __init__(self, model_path=None, confidence=0.5):
|
||||
"""
|
||||
初始化批量视频检测器
|
||||
|
||||
Args:
|
||||
model_path (str): 模型路径,如果为None则使用配置文件中的当前模型路径
|
||||
confidence (float): 置信度阈值
|
||||
"""
|
||||
# 优先使用传入的模型路径,如果没有则使用配置中的当前模型路径
|
||||
self.model_path = model_path or config.MODEL_PATH
|
||||
self.confidence = confidence
|
||||
|
||||
# 创建必要文件夹
|
||||
config.create_folders()
|
||||
|
||||
# 加载模型
|
||||
self.load_model()
|
||||
|
||||
# 时间间隔过滤变量
|
||||
self.last_saved_timestamp = 0
|
||||
|
||||
# 初始化场景类别变量
|
||||
self.last_scene_classes = set()
|
||||
|
||||
# 动态检测参数
|
||||
self.adaptive_confidence = True
|
||||
self.min_confidence = 0.2
|
||||
self.scene_confidence_history = [] # 记录场景置信度历史
|
||||
self.detection_quality_threshold = 0.3 # 检测质量阈值
|
||||
|
||||
# 统计信息
|
||||
self.stats = {
|
||||
'start_time': None,
|
||||
'end_time': None,
|
||||
'total_videos': 0,
|
||||
'processed_videos': 0,
|
||||
'total_frames': 0,
|
||||
'detected_frames': 0,
|
||||
'saved_images': 0,
|
||||
'filtered_frames': 0,
|
||||
'adaptive_detections': 0, # 自适应检测次数
|
||||
'errors': []
|
||||
}
|
||||
|
||||
# 视频详细信息列表
|
||||
self.video_details = []
|
||||
|
||||
def load_model(self):
|
||||
"""加载YOLO模型"""
|
||||
try:
|
||||
print(f"正在加载模型: {self.model_path}")
|
||||
self.model = YOLO(self.model_path)
|
||||
print(f"模型加载成功!")
|
||||
|
||||
# 检查是否有缓存的模型信息
|
||||
model_filename = os.path.basename(self.model_path)
|
||||
info_path = os.path.join(os.path.dirname(self.model_path), f"{os.path.splitext(model_filename)[0]}_info.json")
|
||||
|
||||
cached_info_loaded = False
|
||||
if os.path.exists(info_path):
|
||||
try:
|
||||
with open(info_path, 'r', encoding='utf-8') as f:
|
||||
import json
|
||||
cached_info = json.load(f)
|
||||
if cached_info.get('analyzed', False):
|
||||
print(f"使用缓存的模型信息,跳过重复分析")
|
||||
print(f"模型类别数量: {cached_info.get('class_count', 0)}")
|
||||
print(f"检测场景类型: {list(cached_info.get('classes', {}).values())}")
|
||||
cached_info_loaded = True
|
||||
except Exception as e:
|
||||
print(f"读取缓存模型信息失败: {e},将重新分析")
|
||||
|
||||
# 如果没有缓存信息,则进行完整分析
|
||||
if not cached_info_loaded:
|
||||
# 动态加载模型类别信息到配置中
|
||||
DetectionConfig.load_model_classes(self.model_path)
|
||||
|
||||
if hasattr(self.model, 'names'):
|
||||
print(f"\n=== 模型检测场景分析 ===")
|
||||
print(f"模型类别数量: {len(self.model.names)}")
|
||||
print(f"模型原始类别: {list(self.model.names.values())}")
|
||||
|
||||
print(f"\n=== 中文类别映射 ===")
|
||||
for class_id, class_name in self.model.names.items():
|
||||
class_name_cn = config.get_class_name_cn(class_id)
|
||||
print(f"类别 {class_id}: {class_name} -> {class_name_cn}")
|
||||
|
||||
# 分析模型类型并优化参数
|
||||
self.analyze_and_optimize_model()
|
||||
|
||||
print(f"\n=== 检测场景总结 ===")
|
||||
scene_types = list(DetectionConfig.get_all_classes().values())
|
||||
print(f"本模型可检测的场景类型: {scene_types}")
|
||||
print(f"检测置信度阈值: {self.confidence}")
|
||||
print("=" * 40)
|
||||
else:
|
||||
# 使用缓存信息更新配置
|
||||
DetectionConfig.load_model_classes(self.model_path)
|
||||
|
||||
except Exception as e:
|
||||
print(f"模型加载失败: {e}")
|
||||
raise
|
||||
|
||||
def analyze_and_optimize_model(self):
|
||||
"""分析模型类型并优化检测参数"""
|
||||
model_classes = list(self.model.names.values())
|
||||
|
||||
# 检测模型类型
|
||||
road_damage_keywords = ['crack', 'pothole', 'damage', '裂缝', '坑洞', '损伤']
|
||||
general_object_keywords = ['person', 'car', 'truck', 'bus', 'bicycle']
|
||||
|
||||
is_road_damage = any(keyword in ' '.join(model_classes).lower() for keyword in road_damage_keywords)
|
||||
is_general_object = any(keyword in ' '.join(model_classes).lower() for keyword in general_object_keywords)
|
||||
|
||||
if is_road_damage:
|
||||
print("检测到道路损伤专用模型,优化参数设置")
|
||||
self.confidence = max(0.25, self.confidence)
|
||||
self.min_confidence = 0.15
|
||||
self.detection_quality_threshold = 0.2
|
||||
elif is_general_object:
|
||||
print("检测到通用目标检测模型,使用标准参数")
|
||||
self.confidence = max(0.4, self.confidence)
|
||||
self.min_confidence = 0.3
|
||||
self.detection_quality_threshold = 0.35
|
||||
else:
|
||||
print("检测到自定义模型,使用保守参数")
|
||||
self.confidence = max(0.3, self.confidence)
|
||||
self.min_confidence = 0.2
|
||||
self.detection_quality_threshold = 0.25
|
||||
|
||||
print(f"优化后的置信度阈值: {self.confidence}")
|
||||
print(f"最低置信度阈值: {self.min_confidence}")
|
||||
print(f"检测质量阈值: {self.detection_quality_threshold}")
|
||||
|
||||
def get_video_files(self, input_folder="input_videos"):
|
||||
"""获取所有视频文件"""
|
||||
input_path = Path(input_folder)
|
||||
if not input_path.exists():
|
||||
print(f"输入文件夹不存在: {input_folder}")
|
||||
return []
|
||||
|
||||
video_files = []
|
||||
for ext in config.SUPPORTED_VIDEO_FORMATS:
|
||||
video_files.extend(input_path.rglob(f"*{ext}"))
|
||||
|
||||
print(f"找到 {len(video_files)} 个视频文件")
|
||||
return video_files
|
||||
|
||||
def detect_and_save_frame(self, frame, frame_number, timestamp, video_name, output_dir):
|
||||
"""检测单帧并保存结果"""
|
||||
try:
|
||||
# 进行多级检测
|
||||
result = self.adaptive_detect_frame(frame)
|
||||
|
||||
# 如果检测到目标
|
||||
if result is not None and result.boxes is not None and len(result.boxes) > 0:
|
||||
# 获取当前帧的检测类别
|
||||
current_scene_classes = set()
|
||||
for box in result.boxes:
|
||||
class_id = int(box.cls[0].cpu().numpy())
|
||||
current_scene_classes.add(class_id)
|
||||
|
||||
# 检查时间间隔
|
||||
time_diff = timestamp - self.last_saved_timestamp
|
||||
|
||||
# 如果时间间隔小于最小间隔,检查是否为新场景
|
||||
if time_diff < config.MIN_FRAME_INTERVAL and self.last_saved_timestamp > 0:
|
||||
# 如果当前帧的场景类别与上一帧相同,则过滤
|
||||
if hasattr(self, 'last_scene_classes') and current_scene_classes == self.last_scene_classes:
|
||||
self.stats['filtered_frames'] += 1
|
||||
# 仍然返回检测信息,但不保存图像
|
||||
detections = []
|
||||
for box in result.boxes:
|
||||
class_id = int(box.cls[0].cpu().numpy())
|
||||
confidence = float(box.conf[0].cpu().numpy())
|
||||
bbox = box.xyxy[0].cpu().numpy().tolist()
|
||||
class_name = self.model.names.get(class_id, f"class_{class_id}")
|
||||
class_name_cn = config.get_class_name_cn(class_id)
|
||||
|
||||
detections.append({
|
||||
'class_id': class_id,
|
||||
'class_name': class_name,
|
||||
'class_name_cn': class_name_cn,
|
||||
'confidence': confidence,
|
||||
'bbox': bbox,
|
||||
'filtered': True
|
||||
})
|
||||
|
||||
return {
|
||||
'frame_number': frame_number,
|
||||
'timestamp': timestamp,
|
||||
'detections': detections,
|
||||
'filtered': True
|
||||
}
|
||||
|
||||
# 更新最后检测到的场景类别
|
||||
self.last_scene_classes = current_scene_classes
|
||||
|
||||
# 更新最后保存的时间戳
|
||||
self.last_saved_timestamp = timestamp
|
||||
|
||||
# 创建文件名
|
||||
base_name = f"{video_name}_frame_{frame_number:06d}_t{timestamp:.2f}s"
|
||||
|
||||
# 获取检测到的场景类别,用于分类保存
|
||||
detected_scenes = set()
|
||||
for box in result.boxes:
|
||||
class_id = int(box.cls[0].cpu().numpy())
|
||||
class_name_cn = config.get_class_name_cn(class_id)
|
||||
detected_scenes.add(class_name_cn)
|
||||
|
||||
# 为每个检测到的场景创建对应的文件夹并保存图片
|
||||
for scene_name in detected_scenes:
|
||||
# 创建视频名称文件夹
|
||||
video_folder = output_dir / video_name
|
||||
video_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 创建场景文件夹
|
||||
scene_folder = video_folder / scene_name
|
||||
scene_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 保存原始帧
|
||||
if config.SAVE_ORIGINAL_FRAMES:
|
||||
original_path = scene_folder / f"{base_name}_original.jpg"
|
||||
cv2.imwrite(str(original_path), frame, [cv2.IMWRITE_JPEG_QUALITY, config.IMAGE_QUALITY])
|
||||
|
||||
# 绘制并保存标注帧
|
||||
if config.SAVE_ANNOTATED_FRAMES:
|
||||
annotated_frame = self.draw_detections(frame.copy(), result)
|
||||
annotated_path = scene_folder / f"{base_name}_detected.jpg"
|
||||
cv2.imwrite(str(annotated_path), annotated_frame, [cv2.IMWRITE_JPEG_QUALITY, config.IMAGE_QUALITY])
|
||||
|
||||
# 收集检测信息
|
||||
detections = []
|
||||
for box in result.boxes:
|
||||
class_id = int(box.cls[0].cpu().numpy())
|
||||
confidence = float(box.conf[0].cpu().numpy())
|
||||
bbox = box.xyxy[0].cpu().numpy().tolist()
|
||||
class_name = self.model.names.get(class_id, f"class_{class_id}")
|
||||
class_name_cn = config.get_class_name_cn(class_id)
|
||||
|
||||
detections.append({
|
||||
'class_id': class_id,
|
||||
'class_name': class_name,
|
||||
'class_name_cn': class_name_cn,
|
||||
'confidence': confidence,
|
||||
'bbox': bbox
|
||||
})
|
||||
|
||||
return {
|
||||
'frame_number': frame_number,
|
||||
'timestamp': timestamp,
|
||||
'detections': detections
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理帧 {frame_number} 时出错: {e}")
|
||||
self.stats['errors'].append(f"帧 {frame_number}: {str(e)}")
|
||||
return None
|
||||
|
||||
def adaptive_detect_frame(self, frame):
|
||||
"""自适应检测单帧"""
|
||||
try:
|
||||
# 第一次检测:使用标准置信度
|
||||
results = self.model(frame, conf=self.confidence, verbose=False,
|
||||
imgsz=640, augment=True)
|
||||
result = results[0] if results else None
|
||||
|
||||
# 如果没有检测到目标且启用自适应检测
|
||||
if self.adaptive_confidence and (not result or not result.boxes or len(result.boxes) == 0):
|
||||
# 尝试降低置信度检测
|
||||
lower_conf = max(self.min_confidence, self.confidence - 0.15)
|
||||
if lower_conf < self.confidence:
|
||||
results = self.model(frame, conf=lower_conf, verbose=False,
|
||||
imgsz=640, augment=True)
|
||||
result = results[0] if results else None
|
||||
|
||||
if result and result.boxes and len(result.boxes) > 0:
|
||||
self.stats['adaptive_detections'] += 1
|
||||
print(f"自适应检测成功,置信度: {lower_conf:.2f}")
|
||||
|
||||
# 过滤低质量检测
|
||||
if result and result.boxes and len(result.boxes) > 0:
|
||||
valid_indices = []
|
||||
for i, box in enumerate(result.boxes):
|
||||
confidence = float(box.conf[0].cpu().numpy())
|
||||
if confidence >= self.detection_quality_threshold:
|
||||
valid_indices.append(i)
|
||||
|
||||
# 如果有有效检测,保留原始结果结构
|
||||
if not valid_indices:
|
||||
result = None
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print(f"自适应检测失败: {e}")
|
||||
return None
|
||||
|
||||
def draw_detections(self, frame, result):
|
||||
"""在帧上绘制检测结果(绘制边界框和中文标注)"""
|
||||
if result is None or result.boxes is None:
|
||||
return frame
|
||||
|
||||
# 转换为PIL图像
|
||||
pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
||||
draw = ImageDraw.Draw(pil_image)
|
||||
|
||||
# 尝试加载中文字体
|
||||
try:
|
||||
# Windows系统常用中文字体路径
|
||||
font_paths = [
|
||||
"C:/Windows/Fonts/simhei.ttf", # 黑体
|
||||
"C:/Windows/Fonts/simsun.ttc", # 宋体
|
||||
"C:/Windows/Fonts/msyh.ttc", # 微软雅黑
|
||||
]
|
||||
font = None
|
||||
for font_path in font_paths:
|
||||
try:
|
||||
font = ImageFont.truetype(font_path, 20)
|
||||
break
|
||||
except:
|
||||
continue
|
||||
if font is None:
|
||||
font = ImageFont.load_default()
|
||||
except:
|
||||
font = ImageFont.load_default()
|
||||
|
||||
for box in result.boxes:
|
||||
# 获取坐标和信息
|
||||
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
|
||||
class_id = int(box.cls[0].cpu().numpy())
|
||||
confidence = float(box.conf[0].cpu().numpy())
|
||||
|
||||
# 动态获取类别名称(优先使用模型原始名称)
|
||||
if hasattr(self.model, 'names') and class_id in self.model.names:
|
||||
original_name = self.model.names[class_id]
|
||||
class_name_cn = config.get_class_name_cn(class_id)
|
||||
# 如果中文名称就是原始名称,说明没有映射,直接使用原始名称
|
||||
if class_name_cn == original_name or class_name_cn.startswith('未知类别'):
|
||||
display_name = original_name
|
||||
else:
|
||||
display_name = f"{class_name_cn}({original_name})"
|
||||
else:
|
||||
display_name = config.get_class_name_cn(class_id)
|
||||
|
||||
# 根据置信度调整颜色强度和线宽
|
||||
base_color = config.get_class_color(class_id)
|
||||
intensity = min(1.0, confidence + 0.3)
|
||||
color = tuple(int(c * intensity) for c in base_color)
|
||||
bg_color = tuple(reversed(color)) # BGR转RGB
|
||||
line_width = max(2, int(confidence * 4))
|
||||
|
||||
# 绘制边界框
|
||||
draw.rectangle([x1, y1, x2, y2], outline=bg_color, width=line_width)
|
||||
|
||||
# 准备标注文字
|
||||
label_text = f"{display_name}: {confidence:.3f}"
|
||||
|
||||
# 计算文字背景框大小
|
||||
bbox = draw.textbbox((0, 0), label_text, font=font)
|
||||
text_width = bbox[2] - bbox[0]
|
||||
text_height = bbox[3] - bbox[1]
|
||||
|
||||
# 确保标注不超出图像边界
|
||||
label_x = max(0, x1)
|
||||
label_y = max(0, y1 - text_height - 5)
|
||||
if label_y < 0:
|
||||
label_y = y1 + 5
|
||||
|
||||
# 绘制文字背景
|
||||
draw.rectangle(
|
||||
[label_x, label_y, label_x + text_width + 4, label_y + text_height + 4],
|
||||
fill=bg_color
|
||||
)
|
||||
|
||||
# 绘制文字(白色)
|
||||
draw.text((label_x + 2, label_y + 2), label_text, fill=(255, 255, 255), font=font)
|
||||
|
||||
# 转换回OpenCV格式
|
||||
frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
|
||||
return frame
|
||||
|
||||
def process_video(self, video_path, output_base_dir="output_frames"):
|
||||
"""处理单个视频"""
|
||||
print(f"\n开始处理: {video_path.name}")
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = Path(output_base_dir) / video_path.stem
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 文件夹将在检测到对应类别时按需创建
|
||||
|
||||
# 重置时间间隔过滤变量,确保每个视频独立计算时间间隔
|
||||
self.last_saved_timestamp = 0
|
||||
|
||||
# 重置场景类别变量,确保每个视频独立分类场景
|
||||
self.last_scene_classes = set()
|
||||
|
||||
# 打开视频
|
||||
cap = cv2.VideoCapture(str(video_path))
|
||||
if not cap.isOpened():
|
||||
error_msg = f"无法打开视频: {video_path}"
|
||||
print(error_msg)
|
||||
self.stats['errors'].append(error_msg)
|
||||
return
|
||||
|
||||
# 获取视频信息
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
|
||||
print(f"视频信息: {total_frames} 帧, {fps:.2f} FPS")
|
||||
|
||||
# 处理统计
|
||||
frame_count = 0
|
||||
detected_frames = 0
|
||||
saved_images = 0
|
||||
filtered_frames = 0
|
||||
all_detections = []
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
frame_count += 1
|
||||
timestamp = frame_count / fps if fps > 0 else frame_count
|
||||
|
||||
# 显示进度
|
||||
if frame_count % config.PROGRESS_INTERVAL == 0:
|
||||
progress = (frame_count / total_frames) * 100
|
||||
elapsed = time.time() - start_time
|
||||
fps_current = frame_count / elapsed
|
||||
print(f"进度: {progress:.1f}% ({frame_count}/{total_frames}), 处理速度: {fps_current:.1f} FPS")
|
||||
|
||||
# 更新进度回调
|
||||
if hasattr(self, 'update_progress'):
|
||||
self.update_progress(f"处理视频帧", None, None, None, frame_count)
|
||||
|
||||
# 检测并保存
|
||||
detection_result = self.detect_and_save_frame(
|
||||
frame, frame_count, timestamp, video_path.stem, output_dir
|
||||
)
|
||||
|
||||
if detection_result:
|
||||
detected_frames += 1
|
||||
all_detections.append(detection_result)
|
||||
|
||||
# 检查是否是被过滤的帧
|
||||
if detection_result.get('filtered', False):
|
||||
filtered_frames += 1
|
||||
else:
|
||||
# 计算保存的图片数量
|
||||
if config.SAVE_ORIGINAL_FRAMES:
|
||||
saved_images += 1
|
||||
if config.SAVE_ANNOTATED_FRAMES:
|
||||
saved_images += 1
|
||||
|
||||
finally:
|
||||
cap.release()
|
||||
|
||||
# 保存检测结果JSON
|
||||
if config.SAVE_DETECTION_JSON and all_detections:
|
||||
json_path = output_dir / f"{video_path.stem}_detections.json"
|
||||
detection_summary = {
|
||||
'video_name': video_path.name,
|
||||
'total_frames': frame_count,
|
||||
'detected_frames': detected_frames,
|
||||
'fps': fps,
|
||||
'detections': all_detections,
|
||||
'processing_time': time.time() - start_time
|
||||
}
|
||||
|
||||
with open(json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(detection_summary, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 更新统计
|
||||
self.stats['processed_videos'] += 1
|
||||
self.stats['total_frames'] += frame_count
|
||||
self.stats['detected_frames'] += detected_frames
|
||||
self.stats['saved_images'] += saved_images
|
||||
self.stats['filtered_frames'] += filtered_frames
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
print(f"处理完成: {video_path.name}")
|
||||
print(f"检测到目标的帧数: {detected_frames}/{frame_count}, 保存图片: {saved_images} 张")
|
||||
print(f"时间间隔过滤帧数: {filtered_frames} 帧 (间隔阈值: {config.MIN_FRAME_INTERVAL}秒)")
|
||||
print(f"处理时间: {processing_time:.2f} 秒")
|
||||
|
||||
# 记录视频详细信息
|
||||
video_info = {
|
||||
'video_name': video_path.name,
|
||||
'video_path': str(video_path),
|
||||
'total_frames': frame_count,
|
||||
'detected_frames': detected_frames,
|
||||
'saved_images': saved_images,
|
||||
'filtered_frames': filtered_frames,
|
||||
'processing_time': processing_time,
|
||||
'fps': fps,
|
||||
'duration': frame_count / fps if fps > 0 else 0,
|
||||
'output_directory': str(output_dir)
|
||||
}
|
||||
self.video_details.append(video_info)
|
||||
|
||||
def process_all_videos(self, input_folder="input_videos", output_folder="output_frames"):
|
||||
"""批量处理所有视频"""
|
||||
self.stats['start_time'] = datetime.now()
|
||||
|
||||
video_files = self.get_video_files(input_folder)
|
||||
if not video_files:
|
||||
print("没有找到视频文件")
|
||||
return
|
||||
|
||||
self.stats['total_videos'] = len(video_files)
|
||||
|
||||
print(f"\n开始批量处理 {len(video_files)} 个视频文件")
|
||||
print(f"模型: {self.model_path}")
|
||||
print(f"置信度阈值: {self.confidence}")
|
||||
print(f"输出目录: {output_folder}")
|
||||
print("=" * 50)
|
||||
|
||||
for i, video_path in enumerate(video_files, 1):
|
||||
print(f"\n[{i}/{len(video_files)}] 处理视频")
|
||||
self.process_video(video_path, output_folder)
|
||||
|
||||
self.stats['end_time'] = datetime.now()
|
||||
report_path = self.save_final_report(output_folder)
|
||||
self.print_summary()
|
||||
return report_path
|
||||
|
||||
def save_final_report(self, output_folder):
|
||||
"""保存最终批处理报告"""
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
report_path = Path(output_folder) / f"batch_detection_report_{timestamp}.json"
|
||||
|
||||
# 计算总处理时间
|
||||
total_seconds = 0
|
||||
if self.stats['start_time'] and self.stats['end_time']:
|
||||
total_seconds = (self.stats['end_time'] - self.stats['start_time']).total_seconds()
|
||||
|
||||
# 创建报告数据
|
||||
report_data = {
|
||||
'timestamp': timestamp,
|
||||
'model_path': self.model_path,
|
||||
'confidence_threshold': self.confidence,
|
||||
'frame_interval_threshold': config.MIN_FRAME_INTERVAL,
|
||||
'filter_strategy': '相同场景在指定时间间隔内只保存一次',
|
||||
'statistics': {
|
||||
'total_videos': self.stats['total_videos'],
|
||||
'processed_videos': self.stats['processed_videos'],
|
||||
'total_frames': self.stats['total_frames'],
|
||||
'detected_frames': self.stats['detected_frames'],
|
||||
'filtered_frames': self.stats['filtered_frames'],
|
||||
'saved_images': self.stats['saved_images'],
|
||||
'processing_time_seconds': total_seconds
|
||||
},
|
||||
'videos': self.video_details,
|
||||
'errors': self.stats['errors']
|
||||
}
|
||||
|
||||
# 保存为JSON
|
||||
with open(report_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(report_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"\n批处理报告已保存: {report_path}")
|
||||
return str(report_path)
|
||||
|
||||
def print_summary(self):
|
||||
"""打印处理摘要"""
|
||||
if not self.stats['end_time'] or not self.stats['start_time']:
|
||||
return
|
||||
|
||||
total_time = (self.stats['end_time'] - self.stats['start_time']).total_seconds()
|
||||
hours, remainder = divmod(total_time, 3600)
|
||||
minutes, seconds = divmod(remainder, 60)
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("批量处理摘要")
|
||||
print("=" * 50)
|
||||
print(f"处理视频数量: {self.stats['processed_videos']}/{self.stats['total_videos']}")
|
||||
print(f"总处理帧数: {self.stats['total_frames']}")
|
||||
print(f"检测到目标的帧数: {self.stats['detected_frames']}")
|
||||
print(f"时间间隔过滤帧数: {self.stats['filtered_frames']} (间隔阈值: {config.MIN_FRAME_INTERVAL}秒)")
|
||||
print(f"过滤策略: 相同场景在{config.MIN_FRAME_INTERVAL}秒内只保存一次")
|
||||
print(f"实际保存的图片数量: {self.stats['saved_images']}")
|
||||
print(f"总处理时间: {int(hours)}小时 {int(minutes)}分钟 {seconds:.2f}秒")
|
||||
|
||||
if self.stats['errors']:
|
||||
print("\n处理过程中的错误:")
|
||||
for error in self.stats['errors']:
|
||||
print(f"- {error}")
|
||||
|
||||
print("=" * 50)
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='批量视频检测系统')
|
||||
parser.add_argument('--model', type=str, help='模型文件路径')
|
||||
parser.add_argument('--confidence', type=float, default=0.5, help='置信度阈值')
|
||||
parser.add_argument('--input', type=str, default='input_videos', help='输入视频文件夹')
|
||||
parser.add_argument('--output', type=str, default='output_frames', help='输出图片文件夹')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 创建检测器
|
||||
detector = BatchVideoDetector(
|
||||
model_path=args.model,
|
||||
confidence=args.confidence
|
||||
)
|
||||
|
||||
# 开始批量处理
|
||||
detector.process_all_videos(args.input, args.output)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
200
config.py
Normal file
200
config.py
Normal file
@ -0,0 +1,200 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
视频检测系统配置文件
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from ultralytics import YOLO
|
||||
|
||||
class DetectionConfig:
|
||||
"""检测系统配置类"""
|
||||
|
||||
# 模型配置
|
||||
MODEL_PATH = "models/best.pt"
|
||||
CONFIDENCE_THRESHOLD = 0.5
|
||||
|
||||
# 动态模型信息
|
||||
_model_classes = None
|
||||
_model_colors = None
|
||||
|
||||
# 文件夹配置
|
||||
INPUT_FOLDER = "input_videos"
|
||||
OUTPUT_FOLDER = "output_frames"
|
||||
LOG_FOLDER = "logs"
|
||||
|
||||
# 视频处理配置
|
||||
SUPPORTED_VIDEO_FORMATS = {'.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm'}
|
||||
PROGRESS_INTERVAL = 100 # 每处理多少帧显示一次进度
|
||||
MIN_FRAME_INTERVAL = 1.0 # 检测帧的最小时间间隔(秒)
|
||||
|
||||
# 输出配置
|
||||
SAVE_ORIGINAL_FRAMES = False # 是否保存原始帧
|
||||
SAVE_ANNOTATED_FRAMES = True # 是否保存标注帧
|
||||
SAVE_DETECTION_JSON = True # 是否保存检测结果JSON
|
||||
|
||||
# 图像质量配置
|
||||
IMAGE_QUALITY = 95 # JPEG质量 (1-100)
|
||||
|
||||
# 默认道路损伤类别映射 (中文标签) - 作为备用
|
||||
DEFAULT_CLASS_NAMES_CN = {
|
||||
0: "纵向裂缝",
|
||||
1: "横向裂缝",
|
||||
2: "网状裂缝",
|
||||
3: "坑洞",
|
||||
4: "白线模糊"
|
||||
}
|
||||
|
||||
# 通用类别关键词映射(扩展版)
|
||||
COMMON_CLASS_MAPPINGS = {
|
||||
# 道路损伤相关
|
||||
'crack': '裂缝',
|
||||
'pothole': '坑洞',
|
||||
'damage': '损伤',
|
||||
'longitudinal': '纵向裂缝',
|
||||
'transverse': '横向裂缝',
|
||||
'alligator': '网状裂缝',
|
||||
'rutting': '车辙',
|
||||
'bleeding': '泛油',
|
||||
'weathering': '风化',
|
||||
'patching': '修补',
|
||||
|
||||
# 交通工具
|
||||
'person': '人',
|
||||
'people': '人群',
|
||||
'car': '汽车',
|
||||
'truck': '卡车',
|
||||
'bus': '公交车',
|
||||
'motorcycle': '摩托车',
|
||||
'bicycle': '自行车',
|
||||
'bike': '自行车',
|
||||
'vehicle': '车辆',
|
||||
|
||||
# 交通设施
|
||||
'road': '道路',
|
||||
'traffic': '交通',
|
||||
'sign': '标志',
|
||||
'light': '信号灯',
|
||||
'signal': '信号',
|
||||
'stop': '停车标志',
|
||||
'yield': '让行标志',
|
||||
'crosswalk': '人行横道',
|
||||
'lane': '车道',
|
||||
'marking': '标线',
|
||||
'white': '白线',
|
||||
'yellow': '黄线',
|
||||
|
||||
# 其他常见目标
|
||||
'animal': '动物',
|
||||
'dog': '狗',
|
||||
'cat': '猫',
|
||||
'bird': '鸟',
|
||||
'tree': '树',
|
||||
'building': '建筑',
|
||||
'pole': '杆子',
|
||||
'fence': '围栏'
|
||||
}
|
||||
|
||||
# 动态颜色生成配置
|
||||
COLOR_PALETTE = [
|
||||
(0, 255, 0), # 绿色
|
||||
(255, 0, 0), # 红色
|
||||
(0, 0, 255), # 蓝色
|
||||
(255, 255, 0), # 黄色
|
||||
(255, 0, 255), # 紫色
|
||||
(0, 255, 255), # 青色
|
||||
(255, 128, 0), # 橙色
|
||||
(128, 0, 255), # 紫罗兰
|
||||
(255, 192, 203), # 粉色
|
||||
(128, 128, 128) # 灰色
|
||||
]
|
||||
|
||||
# 默认类别颜色配置 (BGR格式)
|
||||
DEFAULT_CLASS_COLORS = {
|
||||
0: (0, 255, 0), # 绿色
|
||||
1: (255, 0, 0), # 红色
|
||||
2: (0, 0, 255), # 蓝色
|
||||
3: (255, 255, 0), # 黄色
|
||||
4: (255, 0, 255) # 紫色
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_model_path(cls):
|
||||
"""获取模型文件的绝对路径"""
|
||||
current_dir = Path(__file__).parent
|
||||
model_path = current_dir / cls.MODEL_PATH
|
||||
return str(model_path.resolve())
|
||||
|
||||
@classmethod
|
||||
def create_folders(cls):
|
||||
"""创建必要的文件夹"""
|
||||
folders = [cls.INPUT_FOLDER, cls.OUTPUT_FOLDER, cls.LOG_FOLDER]
|
||||
for folder in folders:
|
||||
Path(folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@classmethod
|
||||
def load_model_classes(cls, model_path):
|
||||
"""从模型文件动态加载类别信息"""
|
||||
try:
|
||||
model = YOLO(model_path)
|
||||
cls._model_classes = model.names
|
||||
|
||||
# 动态生成类别颜色
|
||||
cls._model_colors = {}
|
||||
for i in range(len(model.names)):
|
||||
# 使用调色板循环分配颜色
|
||||
color_index = i % len(cls.COLOR_PALETTE)
|
||||
cls._model_colors[i] = cls.COLOR_PALETTE[color_index]
|
||||
|
||||
print(f"成功加载 {len(model.names)} 个类别,分配了对应颜色")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"加载模型类别信息失败: {e}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_class_name_cn(cls, class_id):
|
||||
"""获取类别的中文名称"""
|
||||
# 直接使用模型的原始类别名称,不进行任何映射
|
||||
if cls._model_classes is not None:
|
||||
original_name = cls._model_classes.get(class_id)
|
||||
if original_name:
|
||||
return original_name
|
||||
return f"未知类别_{class_id}"
|
||||
# 回退到默认中文类别
|
||||
return cls.DEFAULT_CLASS_NAMES_CN.get(class_id, f"未知类别_{class_id}")
|
||||
|
||||
@classmethod
|
||||
def get_class_color(cls, class_id):
|
||||
"""获取类别对应的颜色"""
|
||||
# 如果class_id是字符串,尝试转换为整数或使用hash
|
||||
if isinstance(class_id, str):
|
||||
# 尝试从类别名称获取ID
|
||||
if cls._model_classes and class_id in cls._model_classes.values():
|
||||
# 找到对应的ID
|
||||
for id_val, name in cls._model_classes.items():
|
||||
if name == class_id:
|
||||
class_id = id_val
|
||||
break
|
||||
else:
|
||||
# 使用hash生成一个稳定的索引
|
||||
class_id = hash(class_id) % len(cls.COLOR_PALETTE)
|
||||
|
||||
# 确保class_id是整数
|
||||
if not isinstance(class_id, int):
|
||||
class_id = 0
|
||||
|
||||
if cls._model_colors is not None:
|
||||
return cls._model_colors.get(class_id, cls.COLOR_PALETTE[class_id % len(cls.COLOR_PALETTE)])
|
||||
return cls.DEFAULT_CLASS_COLORS.get(class_id, cls.COLOR_PALETTE[class_id % len(cls.COLOR_PALETTE)])
|
||||
|
||||
@classmethod
|
||||
def get_all_classes(cls):
|
||||
"""获取所有类别信息"""
|
||||
if cls._model_classes is not None:
|
||||
return cls._model_classes
|
||||
return cls.DEFAULT_CLASS_NAMES_CN
|
||||
|
||||
# 全局配置实例
|
||||
config = DetectionConfig()
|
||||
453
detector.py
Normal file
453
detector.py
Normal file
@ -0,0 +1,453 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
视频目标检测系统
|
||||
使用YOLO模型对视频进行道路损伤检测,并抽取包含目标的帧
|
||||
"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from ultralytics import YOLO
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from config import config, DetectionConfig
|
||||
|
||||
class VideoDetectionSystem:
|
||||
def __init__(self, model_path, confidence_threshold=0.5, input_folder="input_videos", output_folder="output_frames"):
|
||||
"""
|
||||
初始化视频检测系统
|
||||
|
||||
Args:
|
||||
model_path (str): YOLO模型文件路径
|
||||
confidence_threshold (float): 置信度阈值
|
||||
input_folder (str): 输入视频文件夹
|
||||
output_folder (str): 输出图片文件夹
|
||||
"""
|
||||
self.model_path = model_path
|
||||
self.confidence_threshold = confidence_threshold
|
||||
self.input_folder = Path(input_folder)
|
||||
self.output_folder = Path(output_folder)
|
||||
|
||||
# 创建输出文件夹
|
||||
self.output_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 设置日志
|
||||
self.setup_logging()
|
||||
|
||||
# 加载模型
|
||||
self.load_model()
|
||||
|
||||
# 支持的视频格式
|
||||
self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm'}
|
||||
|
||||
# 动态检测参数
|
||||
self.adaptive_confidence = True # 启用自适应置信度
|
||||
self.min_confidence = 0.2 # 最低置信度
|
||||
self.max_confidence = 0.8 # 最高置信度
|
||||
self.scene_adaptation = True # 启用场景自适应
|
||||
|
||||
# 检测结果统计
|
||||
self.detection_stats = {
|
||||
'total_videos': 0,
|
||||
'total_frames': 0,
|
||||
'detected_frames': 0,
|
||||
'saved_frames': 0,
|
||||
'detection_results': []
|
||||
}
|
||||
|
||||
# 统计信息
|
||||
self.stats = {
|
||||
'adaptive_detections': 0
|
||||
}
|
||||
|
||||
# 检测质量阈值
|
||||
self.detection_quality_threshold = 0.3
|
||||
|
||||
def setup_logging(self):
|
||||
"""设置日志记录"""
|
||||
log_folder = Path('logs')
|
||||
log_folder.mkdir(exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
log_file = log_folder / f'detection_{timestamp}.log'
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler(log_file, encoding='utf-8'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def load_model(self):
|
||||
"""加载YOLO模型"""
|
||||
try:
|
||||
self.model = YOLO(self.model_path)
|
||||
self.logger.info(f"成功加载模型: {self.model_path}")
|
||||
|
||||
# 动态加载模型类别信息到配置中
|
||||
DetectionConfig.load_model_classes(self.model_path)
|
||||
|
||||
# 打印模型信息
|
||||
if hasattr(self.model, 'names'):
|
||||
self.logger.info(f"模型类别数量: {len(self.model.names)}")
|
||||
self.logger.info(f"检测类别: {list(self.model.names.values())}")
|
||||
self.logger.info(f"动态加载的类别信息: {DetectionConfig.get_all_classes()}")
|
||||
|
||||
# 分析模型类型并设置优化参数
|
||||
self.analyze_model_type()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"加载模型失败: {e}")
|
||||
raise
|
||||
|
||||
def analyze_model_type(self):
|
||||
"""分析模型类型并设置相应的检测参数"""
|
||||
model_classes = list(self.model.names.values())
|
||||
|
||||
# 检测是否为道路损伤专用模型
|
||||
road_damage_keywords = ['crack', 'pothole', 'damage', '裂缝', '坑洞', '损伤']
|
||||
is_road_damage_model = any(keyword in ' '.join(model_classes).lower() for keyword in road_damage_keywords)
|
||||
|
||||
if is_road_damage_model:
|
||||
self.logger.info("检测到道路损伤专用模型,使用优化参数")
|
||||
self.confidence_threshold = max(0.25, self.confidence_threshold) # 道路损伤检测使用较低阈值
|
||||
self.min_confidence = 0.15
|
||||
else:
|
||||
self.logger.info("检测到通用模型,使用标准参数")
|
||||
self.confidence_threshold = max(0.4, self.confidence_threshold) # 通用模型使用较高阈值
|
||||
self.min_confidence = 0.3
|
||||
|
||||
self.logger.info(f"调整后的置信度阈值: {self.confidence_threshold}")
|
||||
|
||||
def get_video_files(self):
|
||||
"""获取输入文件夹中的所有视频文件"""
|
||||
video_files = []
|
||||
|
||||
if not self.input_folder.exists():
|
||||
self.logger.warning(f"输入文件夹不存在: {self.input_folder}")
|
||||
return video_files
|
||||
|
||||
for file_path in self.input_folder.rglob('*'):
|
||||
if file_path.is_file() and file_path.suffix.lower() in self.video_extensions:
|
||||
video_files.append(file_path)
|
||||
|
||||
self.logger.info(f"找到 {len(video_files)} 个视频文件")
|
||||
return video_files
|
||||
|
||||
def detect_frame(self, frame, adaptive_conf=None):
|
||||
"""对单帧进行目标检测"""
|
||||
try:
|
||||
# 使用自适应置信度或默认置信度
|
||||
conf_threshold = adaptive_conf if adaptive_conf is not None else self.confidence_threshold
|
||||
|
||||
# 多尺度检测以提高检测率
|
||||
results = self.model(frame, conf=conf_threshold, verbose=False,
|
||||
imgsz=640, augment=True) # 启用测试时增强
|
||||
|
||||
# 如果没有检测到目标且启用自适应,尝试降低置信度
|
||||
if self.adaptive_confidence and (not results or not results[0].boxes or len(results[0].boxes) == 0):
|
||||
if conf_threshold > self.min_confidence:
|
||||
lower_conf = max(self.min_confidence, conf_threshold - 0.1)
|
||||
self.logger.debug(f"降低置信度重试: {lower_conf}")
|
||||
results = self.model(frame, conf=lower_conf, verbose=False,
|
||||
imgsz=640, augment=True)
|
||||
|
||||
return results[0] if results else None
|
||||
except Exception as e:
|
||||
self.logger.error(f"检测失败: {e}")
|
||||
return None
|
||||
|
||||
def adaptive_detect_frame(self, frame, base_confidence):
|
||||
"""自适应检测帧"""
|
||||
try:
|
||||
# 尝试不同的置信度阈值
|
||||
confidence_levels = [base_confidence, base_confidence * 0.8, base_confidence * 0.6]
|
||||
|
||||
for conf in confidence_levels:
|
||||
if conf < self.min_confidence:
|
||||
continue
|
||||
|
||||
results = self.model(frame, conf=conf, verbose=False)
|
||||
|
||||
if len(results) > 0 and len(results[0].boxes) > 0:
|
||||
# 过滤低质量检测
|
||||
boxes = results[0].boxes
|
||||
valid_indices = []
|
||||
|
||||
for i in range(len(boxes)):
|
||||
confidence = float(boxes.conf[i])
|
||||
if confidence >= self.min_confidence:
|
||||
valid_indices.append(i)
|
||||
|
||||
if valid_indices:
|
||||
# 直接使用原始结果,但只返回高质量检测
|
||||
return results
|
||||
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
print(f"自适应检测失败: {e}")
|
||||
return []
|
||||
|
||||
def draw_detections(self, frame, result):
|
||||
"""在帧上绘制检测结果"""
|
||||
if result is None or result.boxes is None:
|
||||
return frame
|
||||
|
||||
# 转换为PIL图像以支持中文字体
|
||||
pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
||||
draw = ImageDraw.Draw(pil_image)
|
||||
|
||||
# 尝试加载中文字体
|
||||
try:
|
||||
# Windows系统字体路径
|
||||
font_path = "C:/Windows/Fonts/simhei.ttf" # 黑体
|
||||
if not os.path.exists(font_path):
|
||||
font_path = "C:/Windows/Fonts/msyh.ttf" # 微软雅黑
|
||||
if not os.path.exists(font_path):
|
||||
font_path = "C:/Windows/Fonts/simsun.ttc" # 宋体
|
||||
|
||||
if os.path.exists(font_path):
|
||||
font = ImageFont.truetype(font_path, 20)
|
||||
else:
|
||||
font = ImageFont.load_default()
|
||||
except:
|
||||
font = ImageFont.load_default()
|
||||
|
||||
for box in result.boxes:
|
||||
# 获取边界框坐标
|
||||
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
|
||||
|
||||
# 获取置信度和类别
|
||||
confidence = float(box.conf[0].cpu().numpy())
|
||||
class_id = int(box.cls[0].cpu().numpy())
|
||||
|
||||
# 只显示中文类别名称
|
||||
display_name = config.get_class_name_cn(class_id)
|
||||
|
||||
# 根据置信度调整颜色强度
|
||||
base_color = config.get_class_color(class_id)
|
||||
intensity = min(1.0, confidence + 0.3) # 确保颜色不会太暗
|
||||
color = tuple(int(c * intensity) for c in base_color)
|
||||
|
||||
# 绘制边界框(使用PIL)
|
||||
bg_color = tuple(reversed(color)) # BGR转RGB
|
||||
line_width = max(2, int(confidence * 4)) # 根据置信度调整线宽
|
||||
draw.rectangle([x1, y1, x2, y2], outline=bg_color, width=line_width)
|
||||
|
||||
# 准备标签文字
|
||||
label = f"{display_name}: {confidence:.3f}"
|
||||
|
||||
# 获取文字尺寸
|
||||
bbox = draw.textbbox((0, 0), label, font=font)
|
||||
text_width = bbox[2] - bbox[0]
|
||||
text_height = bbox[3] - bbox[1]
|
||||
|
||||
# 绘制标签背景(使用PIL)
|
||||
draw.rectangle([x1, y1 - text_height - 10, x1 + text_width + 10, y1], fill=bg_color)
|
||||
|
||||
# 绘制文字(使用PIL)
|
||||
draw.text((x1 + 5, y1 - text_height - 5), label, font=font, fill=(255, 255, 255))
|
||||
|
||||
# 转换回OpenCV格式
|
||||
annotated_frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
|
||||
return annotated_frame
|
||||
|
||||
def process_video(self, video_path):
|
||||
"""处理单个视频文件"""
|
||||
self.logger.info(f"开始处理视频: {video_path.name}")
|
||||
|
||||
# 创建视频专用输出文件夹
|
||||
video_output_folder = self.output_folder / video_path.stem
|
||||
video_output_folder.mkdir(exist_ok=True)
|
||||
|
||||
# 文件夹将在检测到对应类别时按需创建
|
||||
|
||||
# 打开视频
|
||||
cap = cv2.VideoCapture(str(video_path))
|
||||
if not cap.isOpened():
|
||||
self.logger.error(f"无法打开视频文件: {video_path}")
|
||||
return
|
||||
|
||||
# 获取视频信息
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
duration = total_frames / fps if fps > 0 else 0
|
||||
|
||||
self.logger.info(f"视频信息 - 帧率: {fps:.2f}, 总帧数: {total_frames}, 时长: {duration:.2f}秒")
|
||||
|
||||
frame_count = 0
|
||||
detected_count = 0
|
||||
saved_count = 0
|
||||
|
||||
video_detections = {
|
||||
'video_name': video_path.name,
|
||||
'total_frames': total_frames,
|
||||
'fps': fps,
|
||||
'duration': duration,
|
||||
'detections': []
|
||||
}
|
||||
|
||||
try:
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
frame_count += 1
|
||||
|
||||
# 每100帧显示一次进度
|
||||
if frame_count % 100 == 0:
|
||||
progress = (frame_count / total_frames) * 100
|
||||
self.logger.info(f"处理进度: {progress:.1f}% ({frame_count}/{total_frames})")
|
||||
|
||||
# 进行目标检测
|
||||
result = self.detect_frame(frame)
|
||||
|
||||
# 如果检测到目标,保存帧
|
||||
if result is not None and result.boxes is not None and len(result.boxes) > 0:
|
||||
detected_count += 1
|
||||
timestamp = frame_count / fps if fps > 0 else frame_count
|
||||
|
||||
# 获取当前帧检测到的所有类别
|
||||
detected_classes = set()
|
||||
for box in result.boxes:
|
||||
class_id = int(box.cls[0].cpu().numpy())
|
||||
detected_classes.add(class_id)
|
||||
|
||||
# 只为检测到的类别创建文件夹并保存图片
|
||||
annotated_frame = self.draw_detections(frame, result)
|
||||
detection_info = {
|
||||
'frame_number': frame_count,
|
||||
'timestamp': timestamp,
|
||||
'detections': []
|
||||
}
|
||||
|
||||
# 按检测到的类别保存图片
|
||||
for class_id in detected_classes:
|
||||
class_name_cn = config.get_class_name_cn(class_id)
|
||||
class_name_en = self.model.names[class_id] if hasattr(self.model, 'names') else str(class_id)
|
||||
|
||||
# 只为实际检测到的类别创建文件夹
|
||||
class_folder = video_output_folder / class_name_cn
|
||||
class_folder.mkdir(exist_ok=True)
|
||||
|
||||
# 保存检测到该类别的帧(只保存标注后的图片)
|
||||
annotated_filename = f"detected_{frame_count:06d}_t{timestamp:.2f}s_{class_name_cn}.jpg"
|
||||
annotated_path = class_folder / annotated_filename
|
||||
cv2.imwrite(str(annotated_path), annotated_frame)
|
||||
saved_count += 1
|
||||
|
||||
# 记录检测信息
|
||||
for box in result.boxes:
|
||||
confidence = float(box.conf[0].cpu().numpy())
|
||||
class_id = int(box.cls[0].cpu().numpy())
|
||||
class_name = self.model.names[class_id] if hasattr(self.model, 'names') else str(class_id)
|
||||
bbox = box.xyxy[0].cpu().numpy().tolist()
|
||||
|
||||
detection_info['detections'].append({
|
||||
'class_name': class_name,
|
||||
'class_name_cn': config.get_class_name_cn(class_id),
|
||||
'class_id': class_id,
|
||||
'confidence': confidence,
|
||||
'bbox': bbox
|
||||
})
|
||||
|
||||
video_detections['detections'].append(detection_info)
|
||||
|
||||
finally:
|
||||
cap.release()
|
||||
|
||||
# 保存检测结果到JSON文件
|
||||
json_path = video_output_folder / f"{video_path.stem}_detections.json"
|
||||
with open(json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(video_detections, f, ensure_ascii=False, indent=2)
|
||||
|
||||
self.logger.info(f"视频处理完成: {video_path.name}")
|
||||
self.logger.info(f"总帧数: {frame_count}, 检测到目标的帧数: {detected_count}, 保存图片数: {saved_count}")
|
||||
|
||||
# 更新统计信息
|
||||
self.detection_stats['total_videos'] += 1
|
||||
self.detection_stats['total_frames'] += frame_count
|
||||
self.detection_stats['detected_frames'] += detected_count
|
||||
self.detection_stats['saved_frames'] += saved_count
|
||||
self.detection_stats['detection_results'].append(video_detections)
|
||||
|
||||
def process_all_videos(self):
|
||||
"""处理所有视频文件"""
|
||||
video_files = self.get_video_files()
|
||||
|
||||
if not video_files:
|
||||
self.logger.warning("没有找到视频文件")
|
||||
return
|
||||
|
||||
self.logger.info(f"开始处理 {len(video_files)} 个视频文件")
|
||||
|
||||
for i, video_path in enumerate(video_files, 1):
|
||||
self.logger.info(f"\n=== 处理第 {i}/{len(video_files)} 个视频 ===")
|
||||
self.process_video(video_path)
|
||||
|
||||
# 保存总体统计信息
|
||||
self.save_summary_report()
|
||||
|
||||
self.logger.info("\n=== 处理完成 ===")
|
||||
self.logger.info(f"总计处理视频: {self.detection_stats['total_videos']} 个")
|
||||
self.logger.info(f"总计处理帧数: {self.detection_stats['total_frames']} 帧")
|
||||
self.logger.info(f"检测到目标的帧数: {self.detection_stats['detected_frames']} 帧")
|
||||
self.logger.info(f"保存图片数量: {self.detection_stats['saved_frames']} 张")
|
||||
|
||||
def save_summary_report(self):
|
||||
"""保存总结报告"""
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
report_path = self.output_folder / f"detection_summary_{timestamp}.json"
|
||||
|
||||
summary = {
|
||||
'timestamp': timestamp,
|
||||
'model_path': str(self.model_path),
|
||||
'confidence_threshold': self.confidence_threshold,
|
||||
'input_folder': str(self.input_folder),
|
||||
'output_folder': str(self.output_folder),
|
||||
'statistics': self.detection_stats
|
||||
}
|
||||
|
||||
with open(report_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(summary, f, ensure_ascii=False, indent=2)
|
||||
|
||||
self.logger.info(f"总结报告已保存: {report_path}")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='视频目标检测系统')
|
||||
parser.add_argument('--model', type=str,
|
||||
default='../Japan/training_results/continue_from_best_20250610_130607/weights/best.pt',
|
||||
help='YOLO模型文件路径')
|
||||
parser.add_argument('--confidence', type=float, default=0.5,
|
||||
help='置信度阈值 (默认: 0.5)')
|
||||
parser.add_argument('--input', type=str, default='input_videos',
|
||||
help='输入视频文件夹 (默认: input_videos)')
|
||||
parser.add_argument('--output', type=str, default='output_frames',
|
||||
help='输出图片文件夹 (默认: output_frames)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 创建检测系统
|
||||
detector = VideoDetectionSystem(
|
||||
model_path=args.model,
|
||||
confidence_threshold=args.confidence,
|
||||
input_folder=args.input,
|
||||
output_folder=args.output
|
||||
)
|
||||
|
||||
# 处理所有视频
|
||||
detector.process_all_videos()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
33
requirements.txt
Normal file
33
requirements.txt
Normal file
@ -0,0 +1,33 @@
|
||||
# 视频检测系统依赖包
|
||||
|
||||
# 核心依赖
|
||||
ultralytics>=8.0.0
|
||||
opencv-python>=4.5.0
|
||||
numpy>=1.21.0
|
||||
Pillow>=8.0.0
|
||||
|
||||
# 数据处理
|
||||
pandas>=1.3.0
|
||||
|
||||
# 进度显示
|
||||
tqdm>=4.62.0
|
||||
|
||||
# 配置文件处理
|
||||
PyYAML>=6.0
|
||||
|
||||
# 图像处理增强
|
||||
scipy>=1.7.0
|
||||
matplotlib>=3.4.0
|
||||
|
||||
# GPU支持 (可选)
|
||||
# torch>=1.9.0
|
||||
# torchvision>=0.10.0
|
||||
|
||||
# 系统工具
|
||||
psutil>=5.8.0
|
||||
|
||||
# Web应用框架
|
||||
flask>=2.0.0
|
||||
werkzeug>=2.0.0
|
||||
flask-wtf>=1.0.0
|
||||
python-dotenv>=0.15.0
|
||||
Loading…
Reference in New Issue
Block a user