wrz-yolo/app.py
2025-06-27 17:13:03 +08:00

734 lines
30 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from flask import Flask, render_template, request, redirect, url_for, flash, jsonify, send_from_directory
import os
import shutil
import time
import threading
import json
from datetime import datetime
from pathlib import Path
from flask import Flask, render_template, request, redirect, url_for, flash, jsonify
from werkzeug.utils import secure_filename
from config import config
from batch_detector import BatchVideoDetector
app = Flask(__name__)
# Flask signs the session cookie (which carries flash messages) with this key.
# Deployments should supply their own key via the SECRET_KEY environment
# variable; the literal below is only a development fallback.
app.secret_key = os.environ.get('SECRET_KEY', 'road_damage_detection_system')
app.config['MAX_CONTENT_LENGTH'] = 500 * 1024 * 1024  # reject request bodies larger than 500 MB
# Make sure the working folders exist before any request touches them.
for folder in [config.INPUT_FOLDER, config.OUTPUT_FOLDER, config.LOG_FOLDER, 'models', 'static/css']:
    os.makedirs(folder, exist_ok=True)
# Detection state shared between the request handlers and the detection worker thread.
detection_progress = 0  # percentage, 0-100
# "idle" before the first run, "running" once a run is started; the worker
# replaces it with a human-readable final message when it finishes.
detection_status = "idle"
detection_thread = None  # threading.Thread executing run_detection, if any
current_detection_info = {}
detection_start_time = None
detection_results = None  # parsed JSON of the latest report
current_report_path = None
current_video_name = ""
total_videos = 0
current_video_index = 0
current_frame_count = 0
total_frame_count = 0
# Allowed upload extensions: lower case and without the leading dot, because
# allowed_file() lower-cases the file's extension before the membership test.
ALLOWED_VIDEO_EXTENSIONS = {ext.lstrip('.').lower() for ext in config.SUPPORTED_VIDEO_FORMATS}
ALLOWED_MODEL_EXTENSIONS = {'pt', 'pth', 'weights'}
def allowed_file(filename, allowed_extensions):
    """Tell whether *filename* carries one of *allowed_extensions*.

    The extension is the text after the last '.', compared in lower case, so the
    entries of *allowed_extensions* should be lower case and dot-free. A name
    without any '.' is never allowed.
    """
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in allowed_extensions
# Home page: overview of models, videos and reports; a POST starts a detection run
@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the dashboard, or start a detection when the form is POSTed here.

    POST requests are handed to ``start_detection``, which performs the form
    validation, progress reset and worker-thread start and answers with a
    redirect back to this page (Post/Redirect/Get).

    GET renders ``index.html`` with the available model files, the uploaded
    videos, the saved batch reports (newest first) and the current detection
    state.
    """
    if request.method == 'POST':
        # Delegate rather than duplicate: a second copy of the start logic would
        # have to reset exactly the same progress globals and apply the same checks.
        return start_detection()

    models = (
        [f for f in os.listdir('models') if allowed_file(f, ALLOWED_MODEL_EXTENSIONS)]
        if os.path.isdir('models') else []
    )
    videos = [f for f in os.listdir(config.INPUT_FOLDER)
              if allowed_file(f, ALLOWED_VIDEO_EXTENSIONS)]

    # Summaries of the batch reports found in the output folder.
    reports = []
    if os.path.exists(config.OUTPUT_FOLDER):
        for report_file in os.listdir(config.OUTPUT_FOLDER):
            if not (report_file.startswith('batch_detection_report_') and report_file.endswith('.json')):
                continue
            report_path = os.path.join(config.OUTPUT_FOLDER, report_file)
            try:
                with open(report_path, 'r', encoding='utf-8') as f:
                    report_data = json.load(f)
                statistics = report_data.get('statistics', {})
                reports.append({
                    'id': report_file,
                    'timestamp': report_data.get('timestamp', '未知'),
                    'video_count': statistics.get('processed_videos', 0),
                    'total_frames': statistics.get('total_frames', 0),
                    'detected_frames': statistics.get('detected_frames', 0),
                })
            except Exception as e:
                # Best effort: one unreadable report must not break the page.
                print(f"Error loading report {report_path}: {e}")
    # Reports written by this app use '%Y%m%d_%H%M%S' timestamps, so string order
    # is chronological; str() keeps the sort from failing on a non-string value.
    reports.sort(key=lambda report: str(report['timestamp']), reverse=True)
    return render_template('index.html',
                           models=models,
                           videos=videos,
                           reports=reports,
                           detection_status=detection_status,
                           detection_progress=detection_progress,
                           detection_in_progress=(detection_status == "running"),
                           confidence_threshold=config.CONFIDENCE_THRESHOLD,
                           config=config)
# Upload a model
@app.route('/upload_model', methods=['POST'])
def upload_model():
    """Save an uploaded weights file into ``models/`` and cache its class names.

    After saving, the file is loaded with ``ultralytics.YOLO`` and its ``names``
    mapping is written to ``models/<stem>_info.json``. If loading or analysis
    fails, the uploaded file is kept and a warning is flashed.
    Always redirects back to the index page.
    """
    file = request.files.get('model_file')
    if file is None or not file.filename:
        flash('没有选择文件', 'danger')
        return redirect(url_for('index'))
    if not allowed_file(file.filename, ALLOWED_MODEL_EXTENSIONS):
        flash('不支持的文件类型', 'danger')
        return redirect(url_for('index'))

    extension = file.filename.rsplit('.', 1)[1].lower()
    filename = secure_filename(file.filename)
    if not allowed_file(filename, ALLOWED_MODEL_EXTENSIONS):
        # secure_filename() strips non-ASCII characters, so e.g. "模型.pt" becomes
        # "pt": saved without an extension it would never be listed as a model.
        # Fall back to a generated ASCII name that keeps the extension.
        filename = f"model_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.{extension}"
    model_path = os.path.join('models', filename)
    file.save(model_path)

    # Analyse the model right away and cache the result next to it.
    try:
        print(f"正在分析上传的模型: {filename}")
        from ultralytics import YOLO
        model = YOLO(model_path)
        names = dict(model.names) if hasattr(model, 'names') else None
        model_info = {
            'filename': filename,
            'upload_time': datetime.now().isoformat(),
            'classes': names if names is not None else {},
            'class_count': len(names) if names is not None else 0,
            'model_type': 'YOLO',
            'analyzed': True,
        }
        info_path = os.path.join('models', f"{os.path.splitext(filename)[0]}_info.json")
        with open(info_path, 'w', encoding='utf-8') as f:
            json.dump(model_info, f, ensure_ascii=False, indent=2)
        if names is not None:
            scene_types = [str(name) for name in names.values()]
            flash(f'模型 {filename} 上传成功!检测到 {len(scene_types)} 种场景类型: {", ".join(scene_types)}', 'success')
        else:
            flash(f'模型 {filename} 上传成功!', 'success')
    except Exception as e:
        print(f"模型分析失败: {e}")
        flash(f'模型 {filename} 上传成功,但分析模型信息时出错: {str(e)}', 'warning')
    return redirect(url_for('index'))
# List models together with their cached class information
@app.route('/get_models', methods=['GET'])
def get_models():
    """Return a JSON list describing every model file in ``models/``.

    Each entry starts from defaults (``analyzed`` False, no classes) and is
    merged with ``models/<stem>_info.json`` when that file exists;
    ``scene_types`` lists the cached class names.
    """
    models = []
    for model_file in os.listdir('models'):
        if not allowed_file(model_file, ALLOWED_MODEL_EXTENSIONS):
            continue
        model_info = {
            'filename': model_file,
            'analyzed': False,
            'classes': {},
            'class_count': 0,
            'scene_types': [],
        }
        info_path = os.path.join('models', f"{os.path.splitext(model_file)[0]}_info.json")
        if os.path.exists(info_path):
            try:
                with open(info_path, 'r', encoding='utf-8') as f:
                    cached_info = json.load(f)
                model_info.update(cached_info)
                # "a.pt" and "a.pth" share "a_info.json"; always report the name of
                # the file actually being listed, not the one cached in the JSON.
                model_info['filename'] = model_file
                classes = cached_info.get('classes', {})
                model_info['scene_types'] = list(classes.values()) if isinstance(classes, dict) else []
            except Exception as e:
                # Best effort: an unreadable cache just leaves the defaults.
                print(f"读取模型信息文件失败: {e}")
        models.append(model_info)
    return jsonify(models)
# Delete a model and its cached info file
@app.route('/delete_model/<filename>', methods=['POST'])
def delete_model(filename):
    """Delete ``models/<filename>`` and ``models/<stem>_info.json``, reporting via flash."""
    # Accept only bare model file names: a crafted name (e.g. containing a
    # backslash on Windows, which the URL converter does not block) must not
    # delete files outside the models folder, or non-model files inside it.
    if os.path.basename(filename) != filename or not allowed_file(filename, ALLOWED_MODEL_EXTENSIONS):
        flash(f'模型 {filename} 不存在', 'warning')
        return redirect(url_for('index'))
    try:
        model_path = os.path.join('models', filename)
        info_path = os.path.join('models', f"{os.path.splitext(filename)[0]}_info.json")
        deleted_files = []
        if os.path.exists(model_path):
            os.remove(model_path)
            deleted_files.append('模型文件')
        # Remove the cached model info as well.
        if os.path.exists(info_path):
            os.remove(info_path)
            deleted_files.append('信息文件')
        if deleted_files:
            # Separate the parts; an empty separator produced "模型文件信息文件".
            flash(f'模型 {filename} 及其{"、".join(deleted_files)}已删除', 'success')
        else:
            flash(f'模型 {filename} 不存在', 'warning')
    except Exception as e:
        flash(f'删除模型时出错: {str(e)}', 'danger')
    return redirect(url_for('index'))
# Upload one or more videos
@app.route('/upload_video', methods=['POST'])
def upload_video():
    """Save the files posted as ``video_files[]`` into the input folder.

    Files whose extension is not a supported video format are skipped and
    counted as failures; the outcome is reported with flash messages.
    Always redirects back to the index page.
    """
    if 'video_files[]' not in request.files:
        flash('没有选择文件', 'danger')
        return redirect(url_for('index'))
    files = request.files.getlist('video_files[]')
    if not files or files[0].filename == '':
        flash('没有选择文件', 'danger')
        return redirect(url_for('index'))
    success_count = 0
    error_count = 0
    for index, file in enumerate(files):
        if not (file and allowed_file(file.filename, ALLOWED_VIDEO_EXTENSIONS)):
            error_count += 1
            continue
        extension = file.filename.rsplit('.', 1)[1].lower()
        filename = secure_filename(file.filename)
        if not allowed_file(filename, ALLOWED_VIDEO_EXTENSIONS):
            # secure_filename() strips non-ASCII characters, so e.g. "视频.mp4"
            # becomes "mp4" and would be saved without an extension (and then be
            # ignored by every video listing). Use a generated ASCII name instead.
            filename = f"video_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}_{index}.{extension}"
        file.save(os.path.join(config.INPUT_FOLDER, filename))
        success_count += 1
    if success_count > 0:
        flash(f'成功上传 {success_count} 个视频文件', 'success')
    if error_count > 0:
        flash(f'{error_count} 个文件上传失败(不支持的文件类型)', 'warning')
    return redirect(url_for('index'))
# List uploaded videos
@app.route('/get_videos', methods=['GET'])
def get_videos():
    """Return a JSON array with the names of the supported video files in the input folder."""
    video_names = []
    for entry in os.listdir(config.INPUT_FOLDER):
        if allowed_file(entry, ALLOWED_VIDEO_EXTENSIONS):
            video_names.append(entry)
    return jsonify(video_names)
# Delete one uploaded video
@app.route('/delete_video/<filename>', methods=['POST'])
def delete_video(filename):
    """Delete ``<INPUT_FOLDER>/<filename>`` and report the outcome via flash."""
    # Accept only bare video file names so a crafted URL (e.g. with a backslash
    # on Windows) cannot delete files outside the input folder or non-videos.
    if os.path.basename(filename) != filename or not allowed_file(filename, ALLOWED_VIDEO_EXTENSIONS):
        flash(f'视频 {filename} 不存在', 'warning')
        return redirect(url_for('index'))
    try:
        video_path = os.path.join(config.INPUT_FOLDER, filename)
        if os.path.isfile(video_path):
            os.remove(video_path)
            flash(f'视频 {filename} 已删除', 'success')
        else:
            flash(f'视频 {filename} 不存在', 'warning')
    except Exception as e:
        flash(f'删除视频时出错: {str(e)}', 'danger')
    return redirect(url_for('index'))
# Empty the input folder
@app.route('/clear_input_folder', methods=['POST'])
def clear_input_folder():
    """Delete every regular file directly inside the input folder.

    Sub-directories are left alone. Any error is reported via flash, and the
    request always redirects back to the index page.
    """
    try:
        input_dir = config.INPUT_FOLDER
        candidate_paths = (os.path.join(input_dir, name) for name in os.listdir(input_dir))
        for path in candidate_paths:
            if not os.path.isfile(path):
                continue
            os.unlink(path)
        flash('输入文件夹已清空', 'success')
    except Exception as e:
        flash(f'清空文件夹时出错: {str(e)}', 'danger')
    return redirect(url_for('index'))
# Start a detection run
@app.route('/start_detection', methods=['POST'])
def start_detection():
    """Validate the submitted form and start ``run_detection`` in a daemon thread.

    Form fields: ``model_select`` (a file name inside ``models/``) and
    ``confidence_threshold`` (a number in [0, 1], default 0.25). On success the
    chosen model path and threshold are written to ``config``, all progress
    globals are reset and the worker thread is started.
    Always redirects back to the index page.
    """
    global detection_thread, detection_status, detection_progress, detection_results
    global current_report_path, detection_start_time
    global current_video_name, total_videos, current_video_index, current_frame_count, total_frame_count

    print("收到检测请求")
    print(f"当前检测状态: {detection_status}")
    # Besides the "running" flag, ask the worker thread itself: run_detection may
    # replace the status text while it is still working.
    if detection_status == "running" or (detection_thread is not None and detection_thread.is_alive()):
        flash('已有检测任务正在运行', 'warning')
        return redirect(url_for('index'))

    model_name = request.form.get('model_select')
    try:
        confidence_threshold = float(request.form.get('confidence_threshold', 0.25))
    except (TypeError, ValueError):
        confidence_threshold = float('nan')
    print(f"接收到的模型名称: {model_name}")
    print(f"接收到的置信度阈值: {confidence_threshold}")
    # Unparseable input (NaN) also fails this range check.
    if not 0.0 <= confidence_threshold <= 1.0:
        flash('置信度阈值必须是 0 到 1 之间的数字', 'danger')
        return redirect(url_for('index'))

    videos = [f for f in os.listdir(config.INPUT_FOLDER)
              if allowed_file(f, ALLOWED_VIDEO_EXTENSIONS)]
    if not videos:
        flash('输入文件夹中没有视频文件,请先上传视频', 'danger')
        return redirect(url_for('index'))

    if not model_name:
        flash('请选择一个检测模型', 'danger')
        return redirect(url_for('index'))
    model_path = os.path.join('models', model_name)
    # Only a bare file name inside models/ is accepted: the weights loader
    # deserialises the file, so it must not be pointed at arbitrary paths.
    if os.path.basename(model_name) != model_name or not os.path.isfile(model_path):
        flash(f'模型文件 {model_name} 不存在', 'danger')
        return redirect(url_for('index'))

    # Hand the run parameters to the worker through the shared config object.
    config.MODEL_PATH = model_path
    config.CONFIDENCE_THRESHOLD = confidence_threshold

    # Reset the run state and the per-video progress details.
    detection_progress = 0
    detection_status = "running"
    detection_results = None
    current_report_path = None
    detection_start_time = datetime.now()
    current_video_name = ""
    total_videos = len(videos)
    current_video_index = 0
    current_frame_count = 0
    total_frame_count = 0

    detection_thread = threading.Thread(target=run_detection)
    detection_thread.daemon = True
    detection_thread.start()
    flash('检测已开始,请等待结果', 'info')
    return redirect(url_for('index'))
# Detection worker (runs in a background thread)
def run_detection():
    """Run one batch detection over ``config.INPUT_FOLDER``.

    Uses ``config.MODEL_PATH`` and ``config.CONFIDENCE_THRESHOLD``, as set by
    the start handler. While the batch runs, ``detection_status`` stays
    "running" (the value the start guard and the progress endpoint test for).
    The detector's step messages are stored in
    ``current_detection_info['message']`` instead. On exit,
    ``detection_status`` holds a final message and ``detection_progress`` is
    100. When a report was written, ``detection_results`` and
    ``current_report_path`` refer to it.
    """
    global detection_progress, detection_status, detection_results, current_report_path
    current_detection_info.clear()
    try:
        print(f"开始检测,模型路径: {config.MODEL_PATH}")
        print(f"置信度阈值: {config.CONFIDENCE_THRESHOLD}")
        print(f"输入文件夹: {config.INPUT_FOLDER}")
        videos = [f for f in os.listdir(config.INPUT_FOLDER)
                  if allowed_file(f, ALLOWED_VIDEO_EXTENSIONS)]
        print(f"找到视频文件: {videos}")

        print("正在创建检测器实例...")
        detector = BatchVideoDetector(
            model_path=config.MODEL_PATH,
            confidence=config.CONFIDENCE_THRESHOLD
        )
        print("检测器实例创建成功")

        # Log what the loaded model can detect.
        print("\n=== 网页端模型检测分析 ===")
        if hasattr(detector.model, 'names'):
            model_classes = list(detector.model.names.values())
            print(f"当前模型可检测场景: {model_classes}")
        print(f"使用的置信度阈值: {config.CONFIDENCE_THRESHOLD}")
        print("=" * 40)

        def progress_callback(status, progress, video_name=None, video_index=None, frame_count=None):
            """Record a progress update from the detector in the shared globals."""
            global detection_progress, current_video_name, current_video_index, current_frame_count
            # Do not overwrite detection_status: it must stay "running" so the UI
            # and the start guard can tell that a job is still active.
            current_detection_info['message'] = status
            detection_progress = progress
            if video_name:
                current_video_name = video_name
            if video_index is not None:
                current_video_index = video_index
            if frame_count is not None:
                current_frame_count = frame_count
            print(f"进度更新: {status} - {progress}%")
            if video_name:
                print(f"当前视频: {video_name} ({video_index}/{total_videos})")
            if frame_count is not None:
                print(f"已处理帧数: {frame_count}")

        detector.set_progress_callback(progress_callback)

        current_detection_info['message'] = "开始检测视频..."
        report_path = detector.process_all_videos(config.INPUT_FOLDER, config.OUTPUT_FOLDER)

        # Load the results of the report that was just written, if any.
        if report_path and os.path.exists(report_path):
            with open(report_path, 'r', encoding='utf-8') as f:
                detection_results = json.load(f)
            current_report_path = report_path
            detection_status = "检测完成"
        else:
            detection_status = "检测完成,但未生成报告"
        detection_progress = 100
    except Exception as e:
        detection_status = f"检测失败: {str(e)}"
        print(f"检测异常: {str(e)}")
        import traceback
        traceback.print_exc()
    finally:
        # Never leave the job marked as running once the worker is done
        # (e.g. after a BaseException that bypassed the handler above).
        if detection_status == "running":
            detection_status = "失败"
        if detection_progress < 100:
            detection_progress = 100
# Report detection progress as JSON
@app.route('/detection_progress')
def get_detection_progress():
    """Return the current detection state as JSON for the page's progress poller.

    ``in_progress`` is true while the status is "running" or the worker thread
    is still alive, because the status text may change during a run.
    ``elapsed_time`` is "<m>分<s>秒" while a run is in progress and "" otherwise.
    """
    worker_alive = detection_thread is not None and detection_thread.is_alive()
    in_progress = detection_status == "running" or worker_alive

    elapsed_time = ""
    if detection_start_time and in_progress:
        elapsed_seconds = int((datetime.now() - detection_start_time).total_seconds())
        minutes, seconds = divmod(elapsed_seconds, 60)
        # Unit separators are required: without them 65 s was rendered as "15".
        elapsed_time = f"{minutes}分{seconds}秒"

    detailed_status = detection_status
    if in_progress:
        if current_video_name:
            detailed_status = f"正在处理: {current_video_name} ({current_video_index}/{total_videos})"
            if current_frame_count > 0:
                detailed_status += f" - 已处理帧数: {current_frame_count}"
        elif current_detection_info.get('message'):
            detailed_status = current_detection_info['message']

    return jsonify({
        'in_progress': in_progress,
        'progress': detection_progress,
        'status': detailed_status,
        'elapsed_time': elapsed_time,
        'current_video': current_video_name,
        'current_video_index': current_video_index,
        'total_videos': total_videos,
        'current_frame_count': current_frame_count,
        'total_frame_count': total_frame_count
    })
# View a detection report
@app.route('/report/<report_filename>')
def view_report(report_filename):
    """Render ``report.html`` for one JSON report in the output folder.

    Adds the overall detection rate (detected / total frames, in percent) and a
    ``detection_rate`` entry to every video listed in the report.
    """
    report_path = os.path.join(config.OUTPUT_FOLDER, report_filename)
    # Only bare .json names inside the output folder are valid report ids.
    if (os.path.basename(report_filename) != report_filename
            or not report_filename.endswith('.json')
            or not os.path.isfile(report_path)):
        flash('报告文件不存在', 'danger')
        return redirect(url_for('index'))
    try:
        with open(report_path, 'r', encoding='utf-8') as f:
            report_data = json.load(f)
        statistics = report_data.get('statistics', {})
        total_frames = statistics.get('total_frames', 0)
        detected_frames = statistics.get('detected_frames', 0)
        detection_rate = (detected_frames / total_frames) * 100 if total_frames > 0 else 0
        for video in report_data.get('videos', []):
            video_total_frames = video.get('total_frames', 0)
            video['detection_rate'] = (
                (video.get('detected_frames', 0) / video_total_frames) * 100
                if video_total_frames > 0 else 0
            )
        return render_template('report.html',
                               report=report_data,
                               report_filename=report_filename,
                               detection_rate=detection_rate)
    except Exception as e:
        flash(f'读取报告时出错: {str(e)}', 'danger')
        return redirect(url_for('index'))
# View the detection results of one video
@app.route('/video_results/<report_filename>/<video_name>')
def view_video_results(report_filename, video_name):
    """Render ``video_results.html`` for one video listed in a report.

    Frames are read from ``<OUTPUT_FOLDER>/<video stem>/detected_<n>*.jpg``, each
    with an optional sibling ``.json`` whose ``detections`` list is attached.
    A frame's timestamp is ``n / fps``, using the report's fps, or 30 when that
    is missing or zero.
    """
    report_path = os.path.join(config.OUTPUT_FOLDER, report_filename)
    if os.path.basename(report_filename) != report_filename or not os.path.isfile(report_path):
        flash('报告文件不存在', 'danger')
        return redirect(url_for('index'))
    try:
        with open(report_path, 'r', encoding='utf-8') as f:
            report_data = json.load(f)
        video_data = next(
            (video for video in report_data.get('videos', []) if video.get('video_name') == video_name),
            None,
        )
        if not video_data:
            flash('视频结果不存在', 'danger')
            return redirect(url_for('view_report', report_filename=report_filename))

        # A zero or missing fps would make every timestamp division raise, and the
        # per-frame handler below would then silently drop every frame.
        fps = video_data.get('fps') or 30
        video_output_dir = os.path.join(config.OUTPUT_FOLDER, os.path.splitext(video_name)[0])
        frame_dir_name = os.path.basename(video_output_dir)

        detected_frames = []
        if os.path.isdir(video_output_dir):
            for frame_file in os.listdir(video_output_dir):
                if not (frame_file.startswith('detected_') and frame_file.endswith('.jpg')):
                    continue
                try:
                    frame_number = int(frame_file.split('_')[1].split('.')[0])
                    json_path = os.path.join(video_output_dir, frame_file[:-len('.jpg')] + '.json')
                    detections = []
                    if os.path.exists(json_path):
                        with open(json_path, 'r', encoding='utf-8') as f:
                            detections = json.load(f).get('detections', [])
                    detected_frames.append({
                        'frame_number': frame_number,
                        'filename': frame_file,
                        # URL path for the /frame/<path> route, always '/'-separated.
                        'path': f"{frame_dir_name}/{frame_file}",
                        'timestamp': frame_number / fps,
                        'detections': detections,
                        'detection_count': len(detections),
                    })
                except Exception as e:
                    print(f"Error processing frame {frame_file}: {e}")
        detected_frames.sort(key=lambda frame: frame['frame_number'])

        total_frames = video_data.get('total_frames', 0)
        detected_count = video_data.get('detected_frames', 0)
        detection_rate = (detected_count / total_frames) * 100 if total_frames > 0 else 0

        timeline_data = [
            {
                'timestamp': frame['timestamp'],
                'frame_number': frame['frame_number'],
                'detection_count': frame['detection_count'],
            }
            for frame in detected_frames
        ]
        return render_template('video_results.html',
                               report_filename=report_filename,
                               report=report_data,
                               video=video_data,
                               detected_frames=detected_frames,
                               detection_rate=detection_rate,
                               timeline_data=json.dumps(timeline_data))
    except Exception as e:
        flash(f'读取视频结果时出错: {str(e)}', 'danger')
        return redirect(url_for('view_report', report_filename=report_filename))
# Serve a saved detection frame image
@app.route('/frame/<path:frame_path>')
def get_frame(frame_path):
    """Serve ``<OUTPUT_FOLDER>/<frame_path>``, e.g. ``video_dir/detected_12.jpg``.

    The whole relative path goes to ``send_from_directory`` so Flask's
    ``safe_join`` check covers the directory part too. Splitting off the
    directory and joining it manually would leave ``../`` segments unchecked
    and expose files outside the output folder.
    """
    if os.sep != '/':
        # On Windows a path may arrive with backslashes, which safe_join rejects;
        # normalise to URL-style separators first.
        frame_path = frame_path.replace(os.sep, '/')
    return send_from_directory(config.OUTPUT_FOLDER, frame_path)
# Delete a report and its frame folders
@app.route('/delete_report/<report_filename>', methods=['POST'])
def delete_report(report_filename):
    """Delete a JSON report and the per-video frame folders it lists.

    Only direct sub-folders of the output folder are removed. Entries whose
    name would resolve to the output folder itself or to a path outside it are
    skipped. Always redirects back to the index page.
    """
    report_path = os.path.join(config.OUTPUT_FOLDER, report_filename)
    if (os.path.basename(report_filename) != report_filename
            or not report_filename.endswith('.json')
            or not os.path.isfile(report_path)):
        flash('报告文件不存在', 'danger')
        return redirect(url_for('index'))
    try:
        with open(report_path, 'r', encoding='utf-8') as f:
            report_data = json.load(f)
        output_root = os.path.abspath(config.OUTPUT_FOLDER)
        for video in report_data.get('videos', []):
            folder_name = os.path.splitext(str(video.get('video_name', '')))[0]
            folder_path = os.path.abspath(os.path.join(output_root, folder_name))
            # An empty name resolves to output_root itself, and ".." or a nested
            # path to somewhere else; rmtree on those would wipe unrelated data.
            if os.path.dirname(folder_path) != output_root or not os.path.isdir(folder_path):
                continue
            shutil.rmtree(folder_path)
        os.remove(report_path)
        flash(f'报告 {report_filename} 已删除', 'success')
    except Exception as e:
        flash(f'删除报告失败: {str(e)}', 'danger')
    return redirect(url_for('index'))
# Attach progress-reporting hooks to BatchVideoDetector
def add_progress_callback_to_detector():
    """Monkey-patch progress reporting onto ``BatchVideoDetector`` if it lacks it.

    Only when the class has no ``set_progress_callback`` attribute, this adds
    ``set_progress_callback`` and ``update_progress``. It also replaces
    ``process_all_videos`` with a version that reports progress and
    ``save_final_report`` with a version that returns the saved report path.
    Calling it again is a no-op, because the class then has
    ``set_progress_callback``.
    """
    if hasattr(BatchVideoDetector, 'set_progress_callback'):
        return

    def set_progress_callback(self, callback):
        """Register ``callback(status, progress, video_name, video_index, frame_count)``."""
        self.progress_callback = callback

    def update_progress(self, status, progress, video_name=None, video_index=None, frame_count=None):
        """Forward a progress update to the registered callback, if there is one."""
        callback = getattr(self, 'progress_callback', None)
        if callback:
            callback(status, progress, video_name, video_index, frame_count)

    def process_all_videos_with_progress(self, input_folder="input_videos", output_folder="output_frames"):
        """Process every video in *input_folder*, report progress and return the report path.

        Progress goes 0 -> 5 -> (5..95 while videos are processed) -> 95 -> 100.
        Returns None when no video files are found.
        """
        self.update_progress("获取视频文件列表", 0)
        video_files = self.get_video_files(input_folder)
        total_videos = len(video_files)
        if total_videos == 0:
            self.update_progress("没有找到视频文件", 100)
            return None
        # The report computes its processing time from start_time and reads
        # total_videos; this replacement runner is responsible for filling them.
        # NOTE(review): assumes BatchVideoDetector does not set these elsewhere
        # with different meaning — confirm against batch_detector.py.
        if not self.stats['start_time']:
            self.stats['start_time'] = datetime.now()
        self.stats['total_videos'] = total_videos
        self.update_progress(f"开始处理 {total_videos} 个视频文件", 5)
        for i, video_path in enumerate(video_files, 1):
            # Count only the videos already finished (i - 1), so the bar does not
            # reach 95 % before the last video has actually been processed.
            progress = int(5 + ((i - 1) / total_videos) * 90)
            video_name = video_path.name
            self.update_progress(f"处理视频 {i}/{total_videos}: {video_name}", progress, video_name, i)
            self.process_video(video_path, output_folder)
        self.update_progress("生成检测报告", 95)
        self.stats['end_time'] = datetime.now()
        report_path = self.save_final_report(output_folder)
        self.print_summary()
        self.update_progress("检测完成", 100)
        return report_path

    def save_final_report_with_return(self, output_folder):
        """Write ``batch_detection_report_<timestamp>.json`` and return its path as a string."""
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        report_path = Path(output_folder) / f"batch_detection_report_{timestamp}.json"
        # Total processing time, 0 when either end of the interval is unknown.
        total_seconds = 0
        if self.stats['start_time'] and self.stats['end_time']:
            total_seconds = (self.stats['end_time'] - self.stats['start_time']).total_seconds()
        report_data = {
            'timestamp': timestamp,
            'model_path': self.model_path,
            'confidence_threshold': self.confidence,
            'frame_interval_threshold': config.MIN_FRAME_INTERVAL,
            'filter_strategy': '相同场景在指定时间间隔内只保存一次',
            'statistics': {
                'total_videos': self.stats['total_videos'],
                'processed_videos': self.stats['processed_videos'],
                'total_frames': self.stats['total_frames'],
                'detected_frames': self.stats['detected_frames'],
                'filtered_frames': self.stats['filtered_frames'],
                'saved_images': self.stats['saved_images'],
                'processing_time_seconds': total_seconds
            },
            'errors': self.stats['errors'],
            # NOTE(review): always empty, so per-video pages find no entry and
            # report deletion removes no frame folders for these reports.
            # Confirm whether the detector keeps per-video results that belong here.
            'videos': []
        }
        with open(report_path, 'w', encoding='utf-8') as f:
            json.dump(report_data, f, ensure_ascii=False, indent=2)
        print(f"\n批处理报告已保存: {report_path}")
        return str(report_path)

    BatchVideoDetector.set_progress_callback = set_progress_callback
    BatchVideoDetector.update_progress = update_progress
    BatchVideoDetector.process_all_videos = process_all_videos_with_progress
    BatchVideoDetector.save_final_report = save_final_report_with_return
# Install the progress-reporting hooks on BatchVideoDetector once, at import
# time, so they are in place before any request starts a detection thread.
add_progress_callback_to_detector()
if __name__ == '__main__':
    # Make sure the working folders exist (also done at import time; harmless to repeat).
    for folder in [config.INPUT_FOLDER, config.OUTPUT_FOLDER, config.LOG_FOLDER, 'models', 'static/css']:
        os.makedirs(folder, exist_ok=True)
    # The Werkzeug debugger allows arbitrary code execution from the browser.
    # With host='0.0.0.0' it would be reachable from the network, so debug mode
    # is enabled only when FLASK_DEBUG is explicitly set to a true value.
    debug_mode = os.environ.get('FLASK_DEBUG', '').strip().lower() in {'1', 'true', 'yes', 'on'}
    app.run(debug=debug_mode, host='0.0.0.0', port=5000)