wrz-yolo/batch_detector.py
2025-06-27 17:13:03 +08:00

642 lines
27 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
批量视频检测器
简化版本,专注于高效处理大量视频文件
"""
import os
import cv2
import numpy as np
from pathlib import Path
from ultralytics import YOLO
import json
import time
from datetime import datetime
import torch
from config import config, DetectionConfig
from PIL import Image, ImageDraw, ImageFont
class BatchVideoDetector:
def __init__(self, model_path=None, confidence=0.5):
"""
初始化批量视频检测器
Args:
model_path (str): 模型路径如果为None则使用配置文件中的当前模型路径
confidence (float): 置信度阈值
"""
# 优先使用传入的模型路径,如果没有则使用配置中的当前模型路径
self.model_path = model_path or config.MODEL_PATH
self.confidence = confidence
# 创建必要文件夹
config.create_folders()
# 加载模型
self.load_model()
# 时间间隔过滤变量
self.last_saved_timestamp = 0
# 初始化场景类别变量
self.last_scene_classes = set()
# 动态检测参数
self.adaptive_confidence = True
self.min_confidence = 0.2
self.scene_confidence_history = [] # 记录场景置信度历史
self.detection_quality_threshold = 0.3 # 检测质量阈值
# 统计信息
self.stats = {
'start_time': None,
'end_time': None,
'total_videos': 0,
'processed_videos': 0,
'total_frames': 0,
'detected_frames': 0,
'saved_images': 0,
'filtered_frames': 0,
'adaptive_detections': 0, # 自适应检测次数
'errors': []
}
# 视频详细信息列表
self.video_details = []
def load_model(self):
"""加载YOLO模型"""
try:
print(f"正在加载模型: {self.model_path}")
self.model = YOLO(self.model_path)
print(f"模型加载成功!")
# 检查是否有缓存的模型信息
model_filename = os.path.basename(self.model_path)
info_path = os.path.join(os.path.dirname(self.model_path), f"{os.path.splitext(model_filename)[0]}_info.json")
cached_info_loaded = False
if os.path.exists(info_path):
try:
with open(info_path, 'r', encoding='utf-8') as f:
import json
cached_info = json.load(f)
if cached_info.get('analyzed', False):
print(f"使用缓存的模型信息,跳过重复分析")
print(f"模型类别数量: {cached_info.get('class_count', 0)}")
print(f"检测场景类型: {list(cached_info.get('classes', {}).values())}")
cached_info_loaded = True
except Exception as e:
print(f"读取缓存模型信息失败: {e},将重新分析")
# 如果没有缓存信息,则进行完整分析
if not cached_info_loaded:
# 动态加载模型类别信息到配置中
DetectionConfig.load_model_classes(self.model_path)
if hasattr(self.model, 'names'):
print(f"\n=== 模型检测场景分析 ===")
print(f"模型类别数量: {len(self.model.names)}")
print(f"模型原始类别: {list(self.model.names.values())}")
print(f"\n=== 中文类别映射 ===")
for class_id, class_name in self.model.names.items():
class_name_cn = config.get_class_name_cn(class_id)
print(f"类别 {class_id}: {class_name} -> {class_name_cn}")
# 分析模型类型并优化参数
self.analyze_and_optimize_model()
print(f"\n=== 检测场景总结 ===")
scene_types = list(DetectionConfig.get_all_classes().values())
print(f"本模型可检测的场景类型: {scene_types}")
print(f"检测置信度阈值: {self.confidence}")
print("=" * 40)
else:
# 使用缓存信息更新配置
DetectionConfig.load_model_classes(self.model_path)
except Exception as e:
print(f"模型加载失败: {e}")
raise
def analyze_and_optimize_model(self):
"""分析模型类型并优化检测参数"""
model_classes = list(self.model.names.values())
# 检测模型类型
road_damage_keywords = ['crack', 'pothole', 'damage', '裂缝', '坑洞', '损伤']
general_object_keywords = ['person', 'car', 'truck', 'bus', 'bicycle']
is_road_damage = any(keyword in ' '.join(model_classes).lower() for keyword in road_damage_keywords)
is_general_object = any(keyword in ' '.join(model_classes).lower() for keyword in general_object_keywords)
if is_road_damage:
print("检测到道路损伤专用模型,优化参数设置")
self.confidence = max(0.25, self.confidence)
self.min_confidence = 0.15
self.detection_quality_threshold = 0.2
elif is_general_object:
print("检测到通用目标检测模型,使用标准参数")
self.confidence = max(0.4, self.confidence)
self.min_confidence = 0.3
self.detection_quality_threshold = 0.35
else:
print("检测到自定义模型,使用保守参数")
self.confidence = max(0.3, self.confidence)
self.min_confidence = 0.2
self.detection_quality_threshold = 0.25
print(f"优化后的置信度阈值: {self.confidence}")
print(f"最低置信度阈值: {self.min_confidence}")
print(f"检测质量阈值: {self.detection_quality_threshold}")
def get_video_files(self, input_folder="input_videos"):
"""获取所有视频文件"""
input_path = Path(input_folder)
if not input_path.exists():
print(f"输入文件夹不存在: {input_folder}")
return []
video_files = []
for ext in config.SUPPORTED_VIDEO_FORMATS:
video_files.extend(input_path.rglob(f"*{ext}"))
print(f"找到 {len(video_files)} 个视频文件")
return video_files
def detect_and_save_frame(self, frame, frame_number, timestamp, video_name, output_dir):
"""检测单帧并保存结果"""
try:
# 进行多级检测
result = self.adaptive_detect_frame(frame)
# 如果检测到目标
if result is not None and result.boxes is not None and len(result.boxes) > 0:
# 获取当前帧的检测类别
current_scene_classes = set()
for box in result.boxes:
class_id = int(box.cls[0].cpu().numpy())
current_scene_classes.add(class_id)
# 检查时间间隔
time_diff = timestamp - self.last_saved_timestamp
# 如果时间间隔小于最小间隔,检查是否为新场景
if time_diff < config.MIN_FRAME_INTERVAL and self.last_saved_timestamp > 0:
# 如果当前帧的场景类别与上一帧相同,则过滤
if hasattr(self, 'last_scene_classes') and current_scene_classes == self.last_scene_classes:
self.stats['filtered_frames'] += 1
# 仍然返回检测信息,但不保存图像
detections = []
for box in result.boxes:
class_id = int(box.cls[0].cpu().numpy())
confidence = float(box.conf[0].cpu().numpy())
bbox = box.xyxy[0].cpu().numpy().tolist()
class_name = self.model.names.get(class_id, f"class_{class_id}")
class_name_cn = config.get_class_name_cn(class_id)
detections.append({
'class_id': class_id,
'class_name': class_name,
'class_name_cn': class_name_cn,
'confidence': confidence,
'bbox': bbox,
'filtered': True
})
return {
'frame_number': frame_number,
'timestamp': timestamp,
'detections': detections,
'filtered': True
}
# 更新最后检测到的场景类别
self.last_scene_classes = current_scene_classes
# 更新最后保存的时间戳
self.last_saved_timestamp = timestamp
# 创建文件名
base_name = f"{video_name}_frame_{frame_number:06d}_t{timestamp:.2f}s"
# 获取检测到的场景类别,用于分类保存
detected_scenes = set()
for box in result.boxes:
class_id = int(box.cls[0].cpu().numpy())
class_name_cn = config.get_class_name_cn(class_id)
detected_scenes.add(class_name_cn)
# 为每个检测到的场景创建对应的文件夹并保存图片
for scene_name in detected_scenes:
# 创建视频名称文件夹
video_folder = output_dir / video_name
video_folder.mkdir(parents=True, exist_ok=True)
# 创建场景文件夹
scene_folder = video_folder / scene_name
scene_folder.mkdir(parents=True, exist_ok=True)
# 保存原始帧
if config.SAVE_ORIGINAL_FRAMES:
original_path = scene_folder / f"{base_name}_original.jpg"
cv2.imwrite(str(original_path), frame, [cv2.IMWRITE_JPEG_QUALITY, config.IMAGE_QUALITY])
# 绘制并保存标注帧
if config.SAVE_ANNOTATED_FRAMES:
annotated_frame = self.draw_detections(frame.copy(), result)
annotated_path = scene_folder / f"{base_name}_detected.jpg"
cv2.imwrite(str(annotated_path), annotated_frame, [cv2.IMWRITE_JPEG_QUALITY, config.IMAGE_QUALITY])
# 收集检测信息
detections = []
for box in result.boxes:
class_id = int(box.cls[0].cpu().numpy())
confidence = float(box.conf[0].cpu().numpy())
bbox = box.xyxy[0].cpu().numpy().tolist()
class_name = self.model.names.get(class_id, f"class_{class_id}")
class_name_cn = config.get_class_name_cn(class_id)
detections.append({
'class_id': class_id,
'class_name': class_name,
'class_name_cn': class_name_cn,
'confidence': confidence,
'bbox': bbox
})
return {
'frame_number': frame_number,
'timestamp': timestamp,
'detections': detections
}
return None
except Exception as e:
print(f"处理帧 {frame_number} 时出错: {e}")
self.stats['errors'].append(f"{frame_number}: {str(e)}")
return None
def adaptive_detect_frame(self, frame):
"""自适应检测单帧"""
try:
# 第一次检测:使用标准置信度
results = self.model(frame, conf=self.confidence, verbose=False,
imgsz=640, augment=True)
result = results[0] if results else None
# 如果没有检测到目标且启用自适应检测
if self.adaptive_confidence and (not result or not result.boxes or len(result.boxes) == 0):
# 尝试降低置信度检测
lower_conf = max(self.min_confidence, self.confidence - 0.15)
if lower_conf < self.confidence:
results = self.model(frame, conf=lower_conf, verbose=False,
imgsz=640, augment=True)
result = results[0] if results else None
if result and result.boxes and len(result.boxes) > 0:
self.stats['adaptive_detections'] += 1
print(f"自适应检测成功,置信度: {lower_conf:.2f}")
# 过滤低质量检测
if result and result.boxes and len(result.boxes) > 0:
valid_indices = []
for i, box in enumerate(result.boxes):
confidence = float(box.conf[0].cpu().numpy())
if confidence >= self.detection_quality_threshold:
valid_indices.append(i)
# 如果有有效检测,保留原始结果结构
if not valid_indices:
result = None
return result
except Exception as e:
print(f"自适应检测失败: {e}")
return None
def draw_detections(self, frame, result):
"""在帧上绘制检测结果(绘制边界框和中文标注)"""
if result is None or result.boxes is None:
return frame
# 转换为PIL图像
pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(pil_image)
# 尝试加载中文字体
try:
# Windows系统常用中文字体路径
font_paths = [
"C:/Windows/Fonts/simhei.ttf", # 黑体
"C:/Windows/Fonts/simsun.ttc", # 宋体
"C:/Windows/Fonts/msyh.ttc", # 微软雅黑
]
font = None
for font_path in font_paths:
try:
font = ImageFont.truetype(font_path, 20)
break
except:
continue
if font is None:
font = ImageFont.load_default()
except:
font = ImageFont.load_default()
for box in result.boxes:
# 获取坐标和信息
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
class_id = int(box.cls[0].cpu().numpy())
confidence = float(box.conf[0].cpu().numpy())
# 动态获取类别名称(优先使用模型原始名称)
if hasattr(self.model, 'names') and class_id in self.model.names:
original_name = self.model.names[class_id]
class_name_cn = config.get_class_name_cn(class_id)
# 如果中文名称就是原始名称,说明没有映射,直接使用原始名称
if class_name_cn == original_name or class_name_cn.startswith('未知类别'):
display_name = original_name
else:
display_name = f"{class_name_cn}({original_name})"
else:
display_name = config.get_class_name_cn(class_id)
# 根据置信度调整颜色强度和线宽
base_color = config.get_class_color(class_id)
intensity = min(1.0, confidence + 0.3)
color = tuple(int(c * intensity) for c in base_color)
bg_color = tuple(reversed(color)) # BGR转RGB
line_width = max(2, int(confidence * 4))
# 绘制边界框
draw.rectangle([x1, y1, x2, y2], outline=bg_color, width=line_width)
# 准备标注文字
label_text = f"{display_name}: {confidence:.3f}"
# 计算文字背景框大小
bbox = draw.textbbox((0, 0), label_text, font=font)
text_width = bbox[2] - bbox[0]
text_height = bbox[3] - bbox[1]
# 确保标注不超出图像边界
label_x = max(0, x1)
label_y = max(0, y1 - text_height - 5)
if label_y < 0:
label_y = y1 + 5
# 绘制文字背景
draw.rectangle(
[label_x, label_y, label_x + text_width + 4, label_y + text_height + 4],
fill=bg_color
)
# 绘制文字(白色)
draw.text((label_x + 2, label_y + 2), label_text, fill=(255, 255, 255), font=font)
# 转换回OpenCV格式
frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
return frame
def process_video(self, video_path, output_base_dir="output_frames"):
"""处理单个视频"""
print(f"\n开始处理: {video_path.name}")
# 创建输出目录
output_dir = Path(output_base_dir) / video_path.stem
output_dir.mkdir(parents=True, exist_ok=True)
# 文件夹将在检测到对应类别时按需创建
# 重置时间间隔过滤变量,确保每个视频独立计算时间间隔
self.last_saved_timestamp = 0
# 重置场景类别变量,确保每个视频独立分类场景
self.last_scene_classes = set()
# 打开视频
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
error_msg = f"无法打开视频: {video_path}"
print(error_msg)
self.stats['errors'].append(error_msg)
return
# 获取视频信息
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
print(f"视频信息: {total_frames} 帧, {fps:.2f} FPS")
# 处理统计
frame_count = 0
detected_frames = 0
saved_images = 0
filtered_frames = 0
all_detections = []
start_time = time.time()
try:
while True:
ret, frame = cap.read()
if not ret:
break
frame_count += 1
timestamp = frame_count / fps if fps > 0 else frame_count
# 显示进度
if frame_count % config.PROGRESS_INTERVAL == 0:
progress = (frame_count / total_frames) * 100
elapsed = time.time() - start_time
fps_current = frame_count / elapsed
print(f"进度: {progress:.1f}% ({frame_count}/{total_frames}), 处理速度: {fps_current:.1f} FPS")
# 更新进度回调
if hasattr(self, 'update_progress'):
self.update_progress(f"处理视频帧", None, None, None, frame_count)
# 检测并保存
detection_result = self.detect_and_save_frame(
frame, frame_count, timestamp, video_path.stem, output_dir
)
if detection_result:
detected_frames += 1
all_detections.append(detection_result)
# 检查是否是被过滤的帧
if detection_result.get('filtered', False):
filtered_frames += 1
else:
# 计算保存的图片数量
if config.SAVE_ORIGINAL_FRAMES:
saved_images += 1
if config.SAVE_ANNOTATED_FRAMES:
saved_images += 1
finally:
cap.release()
# 保存检测结果JSON
if config.SAVE_DETECTION_JSON and all_detections:
json_path = output_dir / f"{video_path.stem}_detections.json"
detection_summary = {
'video_name': video_path.name,
'total_frames': frame_count,
'detected_frames': detected_frames,
'fps': fps,
'detections': all_detections,
'processing_time': time.time() - start_time
}
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(detection_summary, f, ensure_ascii=False, indent=2)
# 更新统计
self.stats['processed_videos'] += 1
self.stats['total_frames'] += frame_count
self.stats['detected_frames'] += detected_frames
self.stats['saved_images'] += saved_images
self.stats['filtered_frames'] += filtered_frames
processing_time = time.time() - start_time
print(f"处理完成: {video_path.name}")
print(f"检测到目标的帧数: {detected_frames}/{frame_count}, 保存图片: {saved_images}")
print(f"时间间隔过滤帧数: {filtered_frames} 帧 (间隔阈值: {config.MIN_FRAME_INTERVAL}秒)")
print(f"处理时间: {processing_time:.2f}")
# 记录视频详细信息
video_info = {
'video_name': video_path.name,
'video_path': str(video_path),
'total_frames': frame_count,
'detected_frames': detected_frames,
'saved_images': saved_images,
'filtered_frames': filtered_frames,
'processing_time': processing_time,
'fps': fps,
'duration': frame_count / fps if fps > 0 else 0,
'output_directory': str(output_dir)
}
self.video_details.append(video_info)
def process_all_videos(self, input_folder="input_videos", output_folder="output_frames"):
"""批量处理所有视频"""
self.stats['start_time'] = datetime.now()
video_files = self.get_video_files(input_folder)
if not video_files:
print("没有找到视频文件")
return
self.stats['total_videos'] = len(video_files)
print(f"\n开始批量处理 {len(video_files)} 个视频文件")
print(f"模型: {self.model_path}")
print(f"置信度阈值: {self.confidence}")
print(f"输出目录: {output_folder}")
print("=" * 50)
for i, video_path in enumerate(video_files, 1):
print(f"\n[{i}/{len(video_files)}] 处理视频")
self.process_video(video_path, output_folder)
self.stats['end_time'] = datetime.now()
report_path = self.save_final_report(output_folder)
self.print_summary()
return report_path
def save_final_report(self, output_folder):
"""保存最终批处理报告"""
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
report_path = Path(output_folder) / f"batch_detection_report_{timestamp}.json"
# 计算总处理时间
total_seconds = 0
if self.stats['start_time'] and self.stats['end_time']:
total_seconds = (self.stats['end_time'] - self.stats['start_time']).total_seconds()
# 创建报告数据
report_data = {
'timestamp': timestamp,
'model_path': self.model_path,
'confidence_threshold': self.confidence,
'frame_interval_threshold': config.MIN_FRAME_INTERVAL,
'filter_strategy': '相同场景在指定时间间隔内只保存一次',
'statistics': {
'total_videos': self.stats['total_videos'],
'processed_videos': self.stats['processed_videos'],
'total_frames': self.stats['total_frames'],
'detected_frames': self.stats['detected_frames'],
'filtered_frames': self.stats['filtered_frames'],
'saved_images': self.stats['saved_images'],
'processing_time_seconds': total_seconds
},
'videos': self.video_details,
'errors': self.stats['errors']
}
# 保存为JSON
with open(report_path, 'w', encoding='utf-8') as f:
json.dump(report_data, f, ensure_ascii=False, indent=2)
print(f"\n批处理报告已保存: {report_path}")
return str(report_path)
def print_summary(self):
"""打印处理摘要"""
if not self.stats['end_time'] or not self.stats['start_time']:
return
total_time = (self.stats['end_time'] - self.stats['start_time']).total_seconds()
hours, remainder = divmod(total_time, 3600)
minutes, seconds = divmod(remainder, 60)
print("\n" + "=" * 50)
print("批量处理摘要")
print("=" * 50)
print(f"处理视频数量: {self.stats['processed_videos']}/{self.stats['total_videos']}")
print(f"总处理帧数: {self.stats['total_frames']}")
print(f"检测到目标的帧数: {self.stats['detected_frames']}")
print(f"时间间隔过滤帧数: {self.stats['filtered_frames']} (间隔阈值: {config.MIN_FRAME_INTERVAL}秒)")
print(f"过滤策略: 相同场景在{config.MIN_FRAME_INTERVAL}秒内只保存一次")
print(f"实际保存的图片数量: {self.stats['saved_images']}")
print(f"总处理时间: {int(hours)}小时 {int(minutes)}分钟 {seconds:.2f}")
if self.stats['errors']:
print("\n处理过程中的错误:")
for error in self.stats['errors']:
print(f"- {error}")
print("=" * 50)
def main():
"""主函数"""
import argparse
parser = argparse.ArgumentParser(description='批量视频检测系统')
parser.add_argument('--model', type=str, help='模型文件路径')
parser.add_argument('--confidence', type=float, default=0.5, help='置信度阈值')
parser.add_argument('--input', type=str, default='input_videos', help='输入视频文件夹')
parser.add_argument('--output', type=str, default='output_frames', help='输出图片文件夹')
args = parser.parse_args()
# 创建检测器
detector = BatchVideoDetector(
model_path=args.model,
confidence=args.confidence
)
# 开始批量处理
detector.process_all_videos(args.input, args.output)
if __name__ == '__main__':
main()