删除 batch_detector.py
This commit is contained in:
parent
2e89a910a1
commit
b4923a4035
@ -1,544 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
批量视频检测器
|
||||
简化版本,专注于高效处理大量视频文件
|
||||
"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from ultralytics import YOLO
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from config import config, DetectionConfig
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
class BatchVideoDetector:
    """Run a YOLO model over every frame of every video in a folder.

    For each frame with at least one detection the detector stores the frame
    (original and/or annotated, depending on ``config``) once per detected
    scene folder, collects the detections into a per-video JSON file and keeps
    running statistics that end up in a batch report.

    Frames whose set of detected class ids equals that of the previously
    saved frame, and that fall within ``config.MIN_FRAME_INTERVAL`` seconds of
    it, are reported but not written to disk ("filtered" frames).
    """

    # Candidate fonts able to render Chinese labels, tried in order.
    # Missing files are skipped, so listing paths for several platforms is safe.
    _FONT_CANDIDATES = (
        "C:/Windows/Fonts/simhei.ttf",  # SimHei
        "C:/Windows/Fonts/simsun.ttc",  # SimSun
        "C:/Windows/Fonts/msyh.ttc",  # Microsoft YaHei
        "/System/Library/Fonts/PingFang.ttc",
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
        "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
    )
    _FONT_SIZE = 20

    def __init__(self, model_path=None, confidence=0.5):
        """Create the detector and load the model.

        Args:
            model_path (str): Path of the model file; when falsy the path
                returned by ``config.get_model_path()`` is used.
            confidence (float): Confidence threshold passed to the model.
        """
        self.model_path = model_path or config.get_model_path()
        self.confidence = confidence

        # Make sure the folders the configuration expects are present.
        config.create_folders()

        self.load_model()

        # State of the time-interval filter (0 means "nothing saved yet").
        self.last_saved_timestamp = 0
        self.last_scene_classes = set()

        # Lazily loaded label font, shared by all frames.
        self._label_font = None

        # Batch-wide statistics.
        self.stats = {
            'start_time': None,
            'end_time': None,
            'total_videos': 0,
            'processed_videos': 0,
            'total_frames': 0,
            'detected_frames': 0,
            'saved_images': 0,
            'filtered_frames': 0,
            'errors': []
        }

        # One summary dict per processed video.
        self.video_details = []

    # ------------------------------------------------------------------
    # Model loading
    # ------------------------------------------------------------------
    @staticmethod
    def _report_cached_model_info(info_path):
        """Print the cached model summary stored at *info_path*.

        Returns:
            bool: True when the file exists, could be parsed and is flagged
            ``analyzed``; False otherwise (a read/parse failure is printed).
        """
        if not os.path.exists(info_path):
            return False
        try:
            with open(info_path, 'r', encoding='utf-8') as f:
                cached_info = json.load(f)
            if not cached_info.get('analyzed', False):
                return False
            print("使用缓存的模型信息,跳过重复分析")
            print(f"模型类别数量: {cached_info.get('class_count', 0)}")
            print(f"检测场景类型: {list(cached_info.get('classes', {}).values())}")
            return True
        except (OSError, ValueError, AttributeError, TypeError) as e:
            # Best effort: a broken cache file only costs a fresh analysis.
            print(f"读取缓存模型信息失败: {e},将重新分析")
            return False

    def load_model(self):
        """Load the YOLO model and register its classes with the configuration.

        When a ``<model stem>_info.json`` file flagged ``analyzed`` sits next
        to the model, its summary is printed instead of the full class
        analysis.

        Raises:
            Exception: Whatever the model loading raised; it is printed and
                re-raised unchanged.
        """
        try:
            print(f"正在加载模型: {self.model_path}")
            self.model = YOLO(self.model_path)
            print("模型加载成功!")

            model_filename = os.path.basename(self.model_path)
            info_path = os.path.join(
                os.path.dirname(self.model_path),
                f"{os.path.splitext(model_filename)[0]}_info.json",
            )
            cached_info_loaded = self._report_cached_model_info(info_path)

            # The class table is loaded into the configuration in both cases.
            DetectionConfig.load_model_classes(self.model_path)

            if not cached_info_loaded and hasattr(self.model, 'names'):
                print("\n=== 模型检测场景分析 ===")
                print(f"模型类别数量: {len(self.model.names)}")
                print(f"模型原始类别: {list(self.model.names.values())}")

                print("\n=== 中文类别映射 ===")
                for class_id, class_name in self.model.names.items():
                    class_name_cn = config.get_class_name_cn(class_id)
                    print(f"类别 {class_id}: {class_name} -> {class_name_cn}")

                print("\n=== 检测场景总结 ===")
                scene_types = list(DetectionConfig.get_all_classes().values())
                print(f"本模型可检测的场景类型: {scene_types}")
                print(f"检测置信度阈值: {self.confidence}")
                print("=" * 40)

        except Exception as e:
            print(f"模型加载失败: {e}")
            raise

    # ------------------------------------------------------------------
    # Input discovery
    # ------------------------------------------------------------------
    def get_video_files(self, input_folder="input_videos"):
        """Return all video files below *input_folder* (recursive).

        A file qualifies when its suffix, compared case-insensitively, is one
        of ``config.SUPPORTED_VIDEO_FORMATS``. The result is sorted and free
        of duplicates so runs are reproducible.

        Returns:
            list[Path]: Matching files; empty when the folder does not exist.
        """
        input_path = Path(input_folder)
        if not input_path.exists():
            print(f"输入文件夹不存在: {input_folder}")
            return []

        wanted = {
            ext if ext.startswith(".") else f".{ext}"
            for ext in (str(fmt).lower() for fmt in config.SUPPORTED_VIDEO_FORMATS)
        }
        video_files = sorted(
            p for p in input_path.rglob("*")
            if p.is_file() and p.suffix.lower() in wanted
        )

        print(f"找到 {len(video_files)} 个视频文件")
        return video_files

    # ------------------------------------------------------------------
    # Per-frame detection
    # ------------------------------------------------------------------
    def _extract_detections(self, result):
        """Convert every box of *result* into a plain, JSON-serialisable dict."""
        detections = []
        for box in result.boxes:
            class_id = int(box.cls[0].cpu().numpy())
            detections.append({
                'class_id': class_id,
                'class_name': self.model.names.get(class_id, f"class_{class_id}"),
                'class_name_cn': config.get_class_name_cn(class_id),
                'confidence': float(box.conf[0].cpu().numpy()),
                'bbox': box.xyxy[0].cpu().numpy().tolist()
            })
        return detections

    @staticmethod
    def _write_jpeg(path, image):
        """Encode *image* as JPEG and write it to *path*.

        The image is encoded in memory and written through NumPy instead of
        ``cv2.imwrite`` because the scene folders carry Chinese names and
        ``cv2.imwrite`` is known to mishandle non-ASCII paths on Windows.

        Returns:
            bool: True when the image was encoded and written.
        """
        ok, encoded = cv2.imencode(
            ".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, config.IMAGE_QUALITY]
        )
        if not ok:
            print(f"图片编码失败: {path}")
            return False
        encoded.tofile(str(path))
        return True

    def detect_and_save_frame(self, frame, frame_number, timestamp, video_name, output_dir):
        """Run detection on one frame and store the frame when it is kept.

        Args:
            frame: Image handed to the model and to OpenCV for saving.
            frame_number (int): 1-based index of the frame inside the video.
            timestamp (float): Position of the frame in seconds.
            video_name (str): Name used for the file names and the sub-folder.
            output_dir (Path): Directory below which
                ``<video_name>/<scene>/`` folders are created.

        Returns:
            dict | None: ``None`` when nothing was detected or an error
            occurred (the error is printed). Otherwise a dict with
            ``frame_number``, ``timestamp`` and ``detections``; frames dropped
            by the time-interval filter additionally carry ``filtered: True``
            (on the dict and on every detection) and are not written to disk.
        """
        try:
            results = self.model(frame, conf=self.confidence, verbose=False)
            result = results[0] if results else None
            if result is None or result.boxes is None or len(result.boxes) == 0:
                return None

            detections = self._extract_detections(result)
            current_scene_classes = {d['class_id'] for d in detections}

            # Same scene seen again shortly after the last saved frame: report only.
            within_interval = (
                self.last_saved_timestamp > 0
                and timestamp - self.last_saved_timestamp < config.MIN_FRAME_INTERVAL
            )
            if within_interval and current_scene_classes == getattr(self, 'last_scene_classes', None):
                # Single place where filtered frames are counted batch-wide.
                self.stats['filtered_frames'] += 1
                for detection in detections:
                    detection['filtered'] = True
                return {
                    'frame_number': frame_number,
                    'timestamp': timestamp,
                    'detections': detections,
                    'filtered': True
                }

            self.last_scene_classes = current_scene_classes
            self.last_saved_timestamp = timestamp

            base_name = f"{video_name}_frame_{frame_number:06d}_t{timestamp:.2f}s"
            detected_scenes = {d['class_name_cn'] for d in detections}

            # NOTE(review): process_video already passes a per-video output_dir,
            # so this yields <output>/<video>/<video>/<scene>/ - layout kept as is.
            video_folder = output_dir / video_name
            annotated_frame = None
            for scene_name in detected_scenes:
                scene_folder = video_folder / scene_name
                scene_folder.mkdir(parents=True, exist_ok=True)

                if config.SAVE_ORIGINAL_FRAMES:
                    self._write_jpeg(scene_folder / f"{base_name}_original.jpg", frame)

                if config.SAVE_ANNOTATED_FRAMES:
                    # The annotation is identical for every scene: draw it once.
                    if annotated_frame is None:
                        annotated_frame = self.draw_detections(frame.copy(), result)
                    self._write_jpeg(scene_folder / f"{base_name}_detected.jpg", annotated_frame)

            return {
                'frame_number': frame_number,
                'timestamp': timestamp,
                'detections': detections
            }

        except Exception as e:
            # Frame-level boundary: one bad frame must not stop the video.
            print(f"处理帧 {frame_number} 时出错: {e}")
            return None

    # ------------------------------------------------------------------
    # Annotation
    # ------------------------------------------------------------------
    def _get_label_font(self):
        """Return the label font, loading it on first use.

        The first loadable entry of ``_FONT_CANDIDATES`` wins; Pillow's
        default font is the fallback.
        """
        font = getattr(self, '_label_font', None)
        if font is None:
            for font_path in self._FONT_CANDIDATES:
                try:
                    font = ImageFont.truetype(font_path, self._FONT_SIZE)
                    break
                except (OSError, ImportError):
                    # Font file missing/unreadable or FreeType unavailable.
                    continue
            else:
                font = ImageFont.load_default()
            self._label_font = font
        return font

    @staticmethod
    def _label_origin(x1, y1, text_height):
        """Return the top-left corner of the label for a box starting at (x1, y1).

        The label sits above the box; when there is no room above, it is
        moved just below the top edge of the box instead.
        """
        label_x = max(0, x1)
        label_y = y1 - text_height - 5
        if label_y < 0:
            label_y = max(0, y1 + 5)
        return label_x, label_y

    def draw_detections(self, frame, result):
        """Draw bounding boxes and Chinese labels of *result* onto *frame*.

        Drawing goes through Pillow because OpenCV's text rendering cannot
        use TrueType fonts.

        Returns:
            The annotated image in BGR order, or *frame* itself when there is
            nothing to draw.
        """
        if result is None or result.boxes is None:
            return frame

        pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_image)
        font = self._get_label_font()

        for box in result.boxes:
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int).tolist()
            class_id = int(box.cls[0].cpu().numpy())
            confidence = float(box.conf[0].cpu().numpy())

            # The configured colour is reversed (BGR -> RGB) for Pillow.
            box_color = tuple(reversed(config.get_class_color(class_id)))
            class_name_cn = config.get_class_name_cn(class_id)

            draw.rectangle([x1, y1, x2, y2], outline=box_color, width=2)

            label_text = f"{class_name_cn} {confidence:.2f}"
            left, top, right, bottom = draw.textbbox((0, 0), label_text, font=font)
            text_width = right - left
            text_height = bottom - top
            label_x, label_y = self._label_origin(x1, y1, text_height)

            # Filled background so the white text stays readable.
            draw.rectangle(
                [label_x, label_y, label_x + text_width + 4, label_y + text_height + 4],
                fill=box_color
            )
            draw.text((label_x + 2, label_y + 2), label_text, fill=(255, 255, 255), font=font)

        return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

    # ------------------------------------------------------------------
    # Per-video processing
    # ------------------------------------------------------------------
    @staticmethod
    def _format_progress(frame_count, total_frames, elapsed):
        """Build the progress line; tolerates an unknown total and zero elapsed time."""
        fps_current = frame_count / elapsed if elapsed > 0 else 0.0
        if total_frames > 0:
            progress = (frame_count / total_frames) * 100
            return f"进度: {progress:.1f}% ({frame_count}/{total_frames}), 处理速度: {fps_current:.1f} FPS"
        # Some containers/streams report a frame count of 0.
        return f"进度: {frame_count} 帧 (总帧数未知), 处理速度: {fps_current:.1f} FPS"

    @staticmethod
    def _count_saved_images(detection_result, save_original, save_annotated):
        """Return how many image files a frame result corresponds to.

        A kept frame is written once per distinct scene (``class_name_cn``)
        for each enabled variant; filtered frames write nothing.
        """
        if detection_result.get('filtered', False):
            return 0
        scene_count = len({d['class_name_cn'] for d in detection_result['detections']})
        return scene_count * (int(bool(save_original)) + int(bool(save_annotated)))

    def process_video(self, video_path, output_base_dir="output_frames"):
        """Process a single video frame by frame.

        Args:
            video_path (Path | str): Video file to process.
            output_base_dir (str): Base directory; results go to
                ``<output_base_dir>/<video stem>/``.

        Side effects: writes images and (optionally) a detections JSON file,
        updates ``self.stats`` and appends a summary to ``self.video_details``.
        A video that cannot be opened is recorded in ``stats['errors']``.
        """
        video_path = Path(video_path)
        print(f"\n开始处理: {video_path.name}")

        output_dir = Path(output_base_dir) / video_path.stem
        output_dir.mkdir(parents=True, exist_ok=True)

        # Every video starts with a fresh time-interval filter.
        self.last_saved_timestamp = 0
        self.last_scene_classes = set()

        cap = cv2.VideoCapture(str(video_path))
        if not cap.isOpened():
            cap.release()
            error_msg = f"无法打开视频: {video_path}"
            print(error_msg)
            self.stats['errors'].append(error_msg)
            return

        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"视频信息: {total_frames} 帧, {fps:.2f} FPS")

        frame_count = 0
        detected_frames = 0
        saved_images = 0
        filtered_frames = 0
        all_detections = []

        start_time = time.time()

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                frame_count += 1
                # Without a usable FPS the frame index doubles as the timestamp.
                timestamp = frame_count / fps if fps > 0 else frame_count

                if frame_count % config.PROGRESS_INTERVAL == 0:
                    print(self._format_progress(frame_count, total_frames, time.time() - start_time))

                    # Optional progress hook supplied by a subclass or caller.
                    # NOTE(review): indentation was lost in the source; the hook is
                    # assumed to fire together with the progress print - confirm.
                    if hasattr(self, 'update_progress'):
                        self.update_progress("处理视频帧", None, None, None, frame_count)

                detection_result = self.detect_and_save_frame(
                    frame, frame_count, timestamp, video_path.stem, output_dir
                )
                if not detection_result:
                    continue

                detected_frames += 1
                all_detections.append(detection_result)

                if detection_result.get('filtered', False):
                    filtered_frames += 1
                else:
                    saved_images += self._count_saved_images(
                        detection_result,
                        config.SAVE_ORIGINAL_FRAMES,
                        config.SAVE_ANNOTATED_FRAMES,
                    )
        finally:
            cap.release()

        if config.SAVE_DETECTION_JSON and all_detections:
            json_path = output_dir / f"{video_path.stem}_detections.json"
            detection_summary = {
                'video_name': video_path.name,
                'total_frames': frame_count,
                'detected_frames': detected_frames,
                'fps': fps,
                'detections': all_detections,
                'processing_time': time.time() - start_time
            }
            with open(json_path, 'w', encoding='utf-8') as f:
                json.dump(detection_summary, f, ensure_ascii=False, indent=2)

        self.stats['processed_videos'] += 1
        self.stats['total_frames'] += frame_count
        self.stats['detected_frames'] += detected_frames
        self.stats['saved_images'] += saved_images
        # stats['filtered_frames'] is already incremented per frame by
        # detect_and_save_frame; adding the local count again would double it.

        processing_time = time.time() - start_time
        print(f"处理完成: {video_path.name}")
        print(f"检测到目标的帧数: {detected_frames}/{frame_count}, 保存图片: {saved_images} 张")
        print(f"时间间隔过滤帧数: {filtered_frames} 帧 (间隔阈值: {config.MIN_FRAME_INTERVAL}秒)")
        print(f"处理时间: {processing_time:.2f} 秒")

        self.video_details.append({
            'video_name': video_path.name,
            'video_path': str(video_path),
            'total_frames': frame_count,
            'detected_frames': detected_frames,
            'saved_images': saved_images,
            'filtered_frames': filtered_frames,
            'processing_time': processing_time,
            'fps': fps,
            'duration': frame_count / fps if fps > 0 else 0,
            'output_directory': str(output_dir)
        })

    # ------------------------------------------------------------------
    # Batch driver and reporting
    # ------------------------------------------------------------------
    def process_all_videos(self, input_folder="input_videos", output_folder="output_frames"):
        """Process every video found in *input_folder*.

        A video that raises while being processed is recorded in
        ``stats['errors']`` and the batch continues with the next one.

        Returns:
            str | None: Path of the batch report, or ``None`` when no video
            was found.
        """
        self.stats['start_time'] = datetime.now()

        video_files = self.get_video_files(input_folder)
        if not video_files:
            print("没有找到视频文件")
            return

        self.stats['total_videos'] = len(video_files)

        print(f"\n开始批量处理 {len(video_files)} 个视频文件")
        print(f"模型: {self.model_path}")
        print(f"置信度阈值: {self.confidence}")
        print(f"输出目录: {output_folder}")
        print("=" * 50)

        for i, video_path in enumerate(video_files, 1):
            print(f"\n[{i}/{len(video_files)}] 处理视频")
            try:
                self.process_video(video_path, output_folder)
            except Exception as e:
                # Batch boundary: keep going, but make the failure visible.
                error_msg = f"处理视频失败: {video_path}: {e}"
                print(error_msg)
                self.stats['errors'].append(error_msg)

        self.stats['end_time'] = datetime.now()
        report_path = self.save_final_report(output_folder)
        self.print_summary()
        return report_path

    def save_final_report(self, output_folder):
        """Write the batch report as JSON into *output_folder*.

        The folder is created when missing (it may not exist yet if no video
        could be opened).

        Returns:
            str: Path of the written report file.
        """
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        report_dir = Path(output_folder)
        report_dir.mkdir(parents=True, exist_ok=True)
        report_path = report_dir / f"batch_detection_report_{timestamp}.json"

        total_seconds = 0
        if self.stats['start_time'] and self.stats['end_time']:
            total_seconds = (self.stats['end_time'] - self.stats['start_time']).total_seconds()

        report_data = {
            'timestamp': timestamp,
            # str() keeps the report serialisable when a path object was passed in.
            'model_path': str(self.model_path),
            'confidence_threshold': self.confidence,
            'frame_interval_threshold': config.MIN_FRAME_INTERVAL,
            'filter_strategy': '相同场景在指定时间间隔内只保存一次',
            'statistics': {
                'total_videos': self.stats['total_videos'],
                'processed_videos': self.stats['processed_videos'],
                'total_frames': self.stats['total_frames'],
                'detected_frames': self.stats['detected_frames'],
                'filtered_frames': self.stats['filtered_frames'],
                'saved_images': self.stats['saved_images'],
                'processing_time_seconds': total_seconds
            },
            'videos': self.video_details,
            'errors': self.stats['errors']
        }

        with open(report_path, 'w', encoding='utf-8') as f:
            json.dump(report_data, f, ensure_ascii=False, indent=2)

        print(f"\n批处理报告已保存: {report_path}")
        return str(report_path)

    def print_summary(self):
        """Print the batch summary; does nothing until start and end time are set."""
        if not self.stats['end_time'] or not self.stats['start_time']:
            return

        total_time = (self.stats['end_time'] - self.stats['start_time']).total_seconds()
        hours, remainder = divmod(total_time, 3600)
        minutes, seconds = divmod(remainder, 60)

        print("\n" + "=" * 50)
        print("批量处理摘要")
        print("=" * 50)
        print(f"处理视频数量: {self.stats['processed_videos']}/{self.stats['total_videos']}")
        print(f"总处理帧数: {self.stats['total_frames']}")
        print(f"检测到目标的帧数: {self.stats['detected_frames']}")
        print(f"时间间隔过滤帧数: {self.stats['filtered_frames']} (间隔阈值: {config.MIN_FRAME_INTERVAL}秒)")
        print(f"过滤策略: 相同场景在{config.MIN_FRAME_INTERVAL}秒内只保存一次")
        print(f"实际保存的图片数量: {self.stats['saved_images']}")
        print(f"总处理时间: {int(hours)}小时 {int(minutes)}分钟 {seconds:.2f}秒")

        if self.stats['errors']:
            print("\n处理过程中的错误:")
            for error in self.stats['errors']:
                print(f"- {error}")

        print("=" * 50)
|
||||
|
||||
def main():
    """Command-line entry point: parse the options and run the batch detection."""
    import argparse

    cli = argparse.ArgumentParser(description='批量视频检测系统')

    # (flag, keyword arguments for add_argument) - one row per option.
    option_specs = (
        ('--model', {'type': str, 'help': '模型文件路径'}),
        ('--confidence', {'type': float, 'default': 0.5, 'help': '置信度阈值'}),
        ('--input', {'type': str, 'default': 'input_videos', 'help': '输入视频文件夹'}),
        ('--output', {'type': str, 'default': 'output_frames', 'help': '输出图片文件夹'}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    options = cli.parse_args()

    # Build the detector, then process everything in the input folder.
    runner = BatchVideoDetector(
        model_path=options.model,
        confidence=options.confidence,
    )
    runner.process_all_videos(options.input, options.output)


if __name__ == '__main__':
    main()
|
||||
Loading…
Reference in New Issue
Block a user