删除 batch_detector.py

This commit is contained in:
Wang_Run_Ze 2025-06-27 17:10:37 +08:00
parent 2e89a910a1
commit b4923a4035

View File

@ -1,544 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
批量视频检测器
简化版本专注于高效处理大量视频文件
"""
import os
import cv2
import numpy as np
from pathlib import Path
from ultralytics import YOLO
import json
import time
from datetime import datetime
from config import config, DetectionConfig
from PIL import Image, ImageDraw, ImageFont
class BatchVideoDetector:
def __init__(self, model_path=None, confidence=0.5):
"""
初始化批量视频检测器
Args:
model_path (str): 模型路径如果为None则使用配置文件中的路径
confidence (float): 置信度阈值
"""
self.model_path = model_path or config.get_model_path()
self.confidence = confidence
# 创建必要文件夹
config.create_folders()
# 加载模型
self.load_model()
# 时间间隔过滤变量
self.last_saved_timestamp = 0
# 初始化场景类别变量
self.last_scene_classes = set()
# 统计信息
self.stats = {
'start_time': None,
'end_time': None,
'total_videos': 0,
'processed_videos': 0,
'total_frames': 0,
'detected_frames': 0,
'saved_images': 0,
'filtered_frames': 0,
'errors': []
}
# 视频详细信息列表
self.video_details = []
def load_model(self):
"""加载YOLO模型"""
try:
print(f"正在加载模型: {self.model_path}")
self.model = YOLO(self.model_path)
print(f"模型加载成功!")
# 检查是否有缓存的模型信息
model_filename = os.path.basename(self.model_path)
info_path = os.path.join(os.path.dirname(self.model_path), f"{os.path.splitext(model_filename)[0]}_info.json")
cached_info_loaded = False
if os.path.exists(info_path):
try:
with open(info_path, 'r', encoding='utf-8') as f:
import json
cached_info = json.load(f)
if cached_info.get('analyzed', False):
print(f"使用缓存的模型信息,跳过重复分析")
print(f"模型类别数量: {cached_info.get('class_count', 0)}")
print(f"检测场景类型: {list(cached_info.get('classes', {}).values())}")
cached_info_loaded = True
except Exception as e:
print(f"读取缓存模型信息失败: {e},将重新分析")
# 如果没有缓存信息,则进行完整分析
if not cached_info_loaded:
# 动态加载模型类别信息到配置中
DetectionConfig.load_model_classes(self.model_path)
if hasattr(self.model, 'names'):
print(f"\n=== 模型检测场景分析 ===")
print(f"模型类别数量: {len(self.model.names)}")
print(f"模型原始类别: {list(self.model.names.values())}")
print(f"\n=== 中文类别映射 ===")
for class_id, class_name in self.model.names.items():
class_name_cn = config.get_class_name_cn(class_id)
print(f"类别 {class_id}: {class_name} -> {class_name_cn}")
print(f"\n=== 检测场景总结 ===")
scene_types = list(DetectionConfig.get_all_classes().values())
print(f"本模型可检测的场景类型: {scene_types}")
print(f"检测置信度阈值: {self.confidence}")
print("=" * 40)
else:
# 使用缓存信息更新配置
DetectionConfig.load_model_classes(self.model_path)
except Exception as e:
print(f"模型加载失败: {e}")
raise
def get_video_files(self, input_folder="input_videos"):
"""获取所有视频文件"""
input_path = Path(input_folder)
if not input_path.exists():
print(f"输入文件夹不存在: {input_folder}")
return []
video_files = []
for ext in config.SUPPORTED_VIDEO_FORMATS:
video_files.extend(input_path.rglob(f"*{ext}"))
print(f"找到 {len(video_files)} 个视频文件")
return video_files
def detect_and_save_frame(self, frame, frame_number, timestamp, video_name, output_dir):
"""检测单帧并保存结果"""
try:
# 进行检测
results = self.model(frame, conf=self.confidence, verbose=False)
result = results[0] if results else None
# 如果检测到目标
if result is not None and result.boxes is not None and len(result.boxes) > 0:
# 获取当前帧的检测类别
current_scene_classes = set()
for box in result.boxes:
class_id = int(box.cls[0].cpu().numpy())
current_scene_classes.add(class_id)
# 检查时间间隔
time_diff = timestamp - self.last_saved_timestamp
# 如果时间间隔小于最小间隔,检查是否为新场景
if time_diff < config.MIN_FRAME_INTERVAL and self.last_saved_timestamp > 0:
# 如果当前帧的场景类别与上一帧相同,则过滤
if hasattr(self, 'last_scene_classes') and current_scene_classes == self.last_scene_classes:
self.stats['filtered_frames'] += 1
# 仍然返回检测信息,但不保存图像
detections = []
for box in result.boxes:
class_id = int(box.cls[0].cpu().numpy())
confidence = float(box.conf[0].cpu().numpy())
bbox = box.xyxy[0].cpu().numpy().tolist()
class_name = self.model.names.get(class_id, f"class_{class_id}")
class_name_cn = config.get_class_name_cn(class_id)
detections.append({
'class_id': class_id,
'class_name': class_name,
'class_name_cn': class_name_cn,
'confidence': confidence,
'bbox': bbox,
'filtered': True
})
return {
'frame_number': frame_number,
'timestamp': timestamp,
'detections': detections,
'filtered': True
}
# 更新最后检测到的场景类别
self.last_scene_classes = current_scene_classes
# 更新最后保存的时间戳
self.last_saved_timestamp = timestamp
# 创建文件名
base_name = f"{video_name}_frame_{frame_number:06d}_t{timestamp:.2f}s"
# 获取检测到的场景类别,用于分类保存
detected_scenes = set()
for box in result.boxes:
class_id = int(box.cls[0].cpu().numpy())
class_name_cn = config.get_class_name_cn(class_id)
detected_scenes.add(class_name_cn)
# 为每个检测到的场景创建对应的文件夹并保存图片
for scene_name in detected_scenes:
# 创建视频名称文件夹
video_folder = output_dir / video_name
video_folder.mkdir(parents=True, exist_ok=True)
# 创建场景文件夹
scene_folder = video_folder / scene_name
scene_folder.mkdir(parents=True, exist_ok=True)
# 保存原始帧
if config.SAVE_ORIGINAL_FRAMES:
original_path = scene_folder / f"{base_name}_original.jpg"
cv2.imwrite(str(original_path), frame, [cv2.IMWRITE_JPEG_QUALITY, config.IMAGE_QUALITY])
# 绘制并保存标注帧
if config.SAVE_ANNOTATED_FRAMES:
annotated_frame = self.draw_detections(frame.copy(), result)
annotated_path = scene_folder / f"{base_name}_detected.jpg"
cv2.imwrite(str(annotated_path), annotated_frame, [cv2.IMWRITE_JPEG_QUALITY, config.IMAGE_QUALITY])
# 收集检测信息
detections = []
for box in result.boxes:
class_id = int(box.cls[0].cpu().numpy())
confidence = float(box.conf[0].cpu().numpy())
bbox = box.xyxy[0].cpu().numpy().tolist()
class_name = self.model.names.get(class_id, f"class_{class_id}")
class_name_cn = config.get_class_name_cn(class_id)
detections.append({
'class_id': class_id,
'class_name': class_name,
'class_name_cn': class_name_cn,
'confidence': confidence,
'bbox': bbox
})
return {
'frame_number': frame_number,
'timestamp': timestamp,
'detections': detections
}
return None
except Exception as e:
print(f"处理帧 {frame_number} 时出错: {e}")
return None
def draw_detections(self, frame, result):
"""在帧上绘制检测结果(绘制边界框和中文标注)"""
if result is None or result.boxes is None:
return frame
# 转换为PIL图像
pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(pil_image)
# 尝试加载中文字体
try:
# Windows系统常用中文字体路径
font_paths = [
"C:/Windows/Fonts/simhei.ttf", # 黑体
"C:/Windows/Fonts/simsun.ttc", # 宋体
"C:/Windows/Fonts/msyh.ttc", # 微软雅黑
]
font = None
for font_path in font_paths:
try:
font = ImageFont.truetype(font_path, 20)
break
except:
continue
if font is None:
font = ImageFont.load_default()
except:
font = ImageFont.load_default()
for box in result.boxes:
# 获取坐标和信息
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
class_id = int(box.cls[0].cpu().numpy())
confidence = float(box.conf[0].cpu().numpy())
# 获取颜色和类别名称
color = config.get_class_color(class_id)
bg_color = tuple(reversed(color)) # BGR转RGB
class_name_cn = config.get_class_name_cn(class_id)
# 绘制边界框
draw.rectangle([x1, y1, x2, y2], outline=bg_color, width=2)
# 准备标注文字
label_text = f"{class_name_cn} {confidence:.2f}"
# 计算文字背景框大小
bbox = draw.textbbox((0, 0), label_text, font=font)
text_width = bbox[2] - bbox[0]
text_height = bbox[3] - bbox[1]
# 确保标注不超出图像边界
label_x = max(0, x1)
label_y = max(0, y1 - text_height - 5)
if label_y < 0:
label_y = y1 + 5
# 绘制文字背景
draw.rectangle(
[label_x, label_y, label_x + text_width + 4, label_y + text_height + 4],
fill=bg_color
)
# 绘制文字(白色)
draw.text((label_x + 2, label_y + 2), label_text, fill=(255, 255, 255), font=font)
# 转换回OpenCV格式
frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
return frame
def process_video(self, video_path, output_base_dir="output_frames"):
"""处理单个视频"""
print(f"\n开始处理: {video_path.name}")
# 创建输出目录
output_dir = Path(output_base_dir) / video_path.stem
output_dir.mkdir(parents=True, exist_ok=True)
# 重置时间间隔过滤变量,确保每个视频独立计算时间间隔
self.last_saved_timestamp = 0
# 重置场景类别变量,确保每个视频独立分类场景
self.last_scene_classes = set()
# 打开视频
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
error_msg = f"无法打开视频: {video_path}"
print(error_msg)
self.stats['errors'].append(error_msg)
return
# 获取视频信息
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
print(f"视频信息: {total_frames} 帧, {fps:.2f} FPS")
# 处理统计
frame_count = 0
detected_frames = 0
saved_images = 0
filtered_frames = 0
all_detections = []
start_time = time.time()
try:
while True:
ret, frame = cap.read()
if not ret:
break
frame_count += 1
timestamp = frame_count / fps if fps > 0 else frame_count
# 显示进度
if frame_count % config.PROGRESS_INTERVAL == 0:
progress = (frame_count / total_frames) * 100
elapsed = time.time() - start_time
fps_current = frame_count / elapsed
print(f"进度: {progress:.1f}% ({frame_count}/{total_frames}), 处理速度: {fps_current:.1f} FPS")
# 更新进度回调
if hasattr(self, 'update_progress'):
self.update_progress(f"处理视频帧", None, None, None, frame_count)
# 检测并保存
detection_result = self.detect_and_save_frame(
frame, frame_count, timestamp, video_path.stem, output_dir
)
if detection_result:
detected_frames += 1
all_detections.append(detection_result)
# 检查是否是被过滤的帧
if detection_result.get('filtered', False):
filtered_frames += 1
else:
# 计算保存的图片数量
if config.SAVE_ORIGINAL_FRAMES:
saved_images += 1
if config.SAVE_ANNOTATED_FRAMES:
saved_images += 1
finally:
cap.release()
# 保存检测结果JSON
if config.SAVE_DETECTION_JSON and all_detections:
json_path = output_dir / f"{video_path.stem}_detections.json"
detection_summary = {
'video_name': video_path.name,
'total_frames': frame_count,
'detected_frames': detected_frames,
'fps': fps,
'detections': all_detections,
'processing_time': time.time() - start_time
}
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(detection_summary, f, ensure_ascii=False, indent=2)
# 更新统计
self.stats['processed_videos'] += 1
self.stats['total_frames'] += frame_count
self.stats['detected_frames'] += detected_frames
self.stats['saved_images'] += saved_images
self.stats['filtered_frames'] += filtered_frames
processing_time = time.time() - start_time
print(f"处理完成: {video_path.name}")
print(f"检测到目标的帧数: {detected_frames}/{frame_count}, 保存图片: {saved_images}")
print(f"时间间隔过滤帧数: {filtered_frames} 帧 (间隔阈值: {config.MIN_FRAME_INTERVAL}秒)")
print(f"处理时间: {processing_time:.2f}")
# 记录视频详细信息
video_info = {
'video_name': video_path.name,
'video_path': str(video_path),
'total_frames': frame_count,
'detected_frames': detected_frames,
'saved_images': saved_images,
'filtered_frames': filtered_frames,
'processing_time': processing_time,
'fps': fps,
'duration': frame_count / fps if fps > 0 else 0,
'output_directory': str(output_dir)
}
self.video_details.append(video_info)
def process_all_videos(self, input_folder="input_videos", output_folder="output_frames"):
"""批量处理所有视频"""
self.stats['start_time'] = datetime.now()
video_files = self.get_video_files(input_folder)
if not video_files:
print("没有找到视频文件")
return
self.stats['total_videos'] = len(video_files)
print(f"\n开始批量处理 {len(video_files)} 个视频文件")
print(f"模型: {self.model_path}")
print(f"置信度阈值: {self.confidence}")
print(f"输出目录: {output_folder}")
print("=" * 50)
for i, video_path in enumerate(video_files, 1):
print(f"\n[{i}/{len(video_files)}] 处理视频")
self.process_video(video_path, output_folder)
self.stats['end_time'] = datetime.now()
report_path = self.save_final_report(output_folder)
self.print_summary()
return report_path
def save_final_report(self, output_folder):
"""保存最终批处理报告"""
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
report_path = Path(output_folder) / f"batch_detection_report_{timestamp}.json"
# 计算总处理时间
total_seconds = 0
if self.stats['start_time'] and self.stats['end_time']:
total_seconds = (self.stats['end_time'] - self.stats['start_time']).total_seconds()
# 创建报告数据
report_data = {
'timestamp': timestamp,
'model_path': self.model_path,
'confidence_threshold': self.confidence,
'frame_interval_threshold': config.MIN_FRAME_INTERVAL,
'filter_strategy': '相同场景在指定时间间隔内只保存一次',
'statistics': {
'total_videos': self.stats['total_videos'],
'processed_videos': self.stats['processed_videos'],
'total_frames': self.stats['total_frames'],
'detected_frames': self.stats['detected_frames'],
'filtered_frames': self.stats['filtered_frames'],
'saved_images': self.stats['saved_images'],
'processing_time_seconds': total_seconds
},
'videos': self.video_details,
'errors': self.stats['errors']
}
# 保存为JSON
with open(report_path, 'w', encoding='utf-8') as f:
json.dump(report_data, f, ensure_ascii=False, indent=2)
print(f"\n批处理报告已保存: {report_path}")
return str(report_path)
def print_summary(self):
"""打印处理摘要"""
if not self.stats['end_time'] or not self.stats['start_time']:
return
total_time = (self.stats['end_time'] - self.stats['start_time']).total_seconds()
hours, remainder = divmod(total_time, 3600)
minutes, seconds = divmod(remainder, 60)
print("\n" + "=" * 50)
print("批量处理摘要")
print("=" * 50)
print(f"处理视频数量: {self.stats['processed_videos']}/{self.stats['total_videos']}")
print(f"总处理帧数: {self.stats['total_frames']}")
print(f"检测到目标的帧数: {self.stats['detected_frames']}")
print(f"时间间隔过滤帧数: {self.stats['filtered_frames']} (间隔阈值: {config.MIN_FRAME_INTERVAL}秒)")
print(f"过滤策略: 相同场景在{config.MIN_FRAME_INTERVAL}秒内只保存一次")
print(f"实际保存的图片数量: {self.stats['saved_images']}")
print(f"总处理时间: {int(hours)}小时 {int(minutes)}分钟 {seconds:.2f}")
if self.stats['errors']:
print("\n处理过程中的错误:")
for error in self.stats['errors']:
print(f"- {error}")
print("=" * 50)
def main():
"""主函数"""
import argparse
parser = argparse.ArgumentParser(description='批量视频检测系统')
parser.add_argument('--model', type=str, help='模型文件路径')
parser.add_argument('--confidence', type=float, default=0.5, help='置信度阈值')
parser.add_argument('--input', type=str, default='input_videos', help='输入视频文件夹')
parser.add_argument('--output', type=str, default='output_frames', help='输出图片文件夹')
args = parser.parse_args()
# 创建检测器
detector = BatchVideoDetector(
model_path=args.model,
confidence=args.confidence
)
# 开始批量处理
detector.process_all_videos(args.input, args.output)
if __name__ == '__main__':
main()