diff --git a/batch_detector.py b/batch_detector.py
deleted file mode 100644
index 18fdfa1..0000000
--- a/batch_detector.py
+++ /dev/null
@@ -1,544 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-Batch video detector
-A simplified version focused on efficiently processing large numbers of video files
-"""
-
-import os
-import cv2
-import numpy as np
-from pathlib import Path
-from ultralytics import YOLO
-import json
-import time
-from datetime import datetime
-from config import config, DetectionConfig
-from PIL import Image, ImageDraw, ImageFont
-
-class BatchVideoDetector:
-    def __init__(self, model_path=None, confidence=0.5):
-        """
-        Initialize the batch video detector
-
-        Args:
-            model_path (str): Path to the model; if None, the path from the config file is used
-            confidence (float): Confidence threshold
-        """
-        self.model_path = model_path or config.get_model_path()
-        self.confidence = confidence
-
-        # Create the required folders
-        config.create_folders()
-
-        # Load the model
-        self.load_model()
-
-        # State for time-interval filtering
-        self.last_saved_timestamp = 0
-
-        # Initialize the scene-class tracking variable
-        self.last_scene_classes = set()
-
-        # Statistics
-        self.stats = {
-            'start_time': None,
-            'end_time': None,
-            'total_videos': 0,
-            'processed_videos': 0,
-            'total_frames': 0,
-            'detected_frames': 0,
-            'saved_images': 0,
-            'filtered_frames': 0,
-            'errors': []
-        }
-
-        # Per-video detail records
-        self.video_details = []
-
-    def load_model(self):
-        """Load the YOLO model"""
-        try:
-            print(f"Loading model: {self.model_path}")
-            self.model = YOLO(self.model_path)
-            print("Model loaded successfully!")
-
-            # Check for cached model information
-            model_filename = os.path.basename(self.model_path)
-            info_path = os.path.join(os.path.dirname(self.model_path), f"{os.path.splitext(model_filename)[0]}_info.json")
-
-            cached_info_loaded = False
-            if os.path.exists(info_path):
-                try:
-                    with open(info_path, 'r', encoding='utf-8') as f:
-                        import json
-                        cached_info = json.load(f)
-                        if cached_info.get('analyzed', False):
-                            print("Using cached model info, skipping re-analysis")
-                            print(f"Number of model classes: {cached_info.get('class_count', 0)}")
-                            print(f"Detectable scene types: {list(cached_info.get('classes', {}).values())}")
-                            cached_info_loaded = True
-                except Exception as e:
-                    print(f"Failed to read cached model info: {e}; re-analyzing")
-
-            # If there is no cached info, run the full analysis
-            if not cached_info_loaded:
-                # Dynamically load the model's class info into the config
-                DetectionConfig.load_model_classes(self.model_path)
-
-                if hasattr(self.model, 'names'):
-                    print("\n=== Model scene analysis ===")
-                    print(f"Number of model classes: {len(self.model.names)}")
-                    print(f"Original model classes: {list(self.model.names.values())}")
-
-                    print("\n=== Chinese class-name mapping ===")
-                    for class_id, class_name in self.model.names.items():
-                        class_name_cn = config.get_class_name_cn(class_id)
-                        print(f"Class {class_id}: {class_name} -> {class_name_cn}")
-
-                    print("\n=== Detection scene summary ===")
-                    scene_types = list(DetectionConfig.get_all_classes().values())
-                    print(f"Scene types this model can detect: {scene_types}")
-                    print(f"Detection confidence threshold: {self.confidence}")
-                    print("=" * 40)
-            else:
-                # Update the config from the cached info
-                DetectionConfig.load_model_classes(self.model_path)
-
-        except Exception as e:
-            print(f"Failed to load model: {e}")
-            raise
-
-    def get_video_files(self, input_folder="input_videos"):
-        """Collect all video files"""
-        input_path = Path(input_folder)
-        if not input_path.exists():
-            print(f"Input folder does not exist: {input_folder}")
-            return []
-
-        video_files = []
-        for ext in config.SUPPORTED_VIDEO_FORMATS:
-            video_files.extend(input_path.rglob(f"*{ext}"))
-
-        print(f"Found {len(video_files)} video files")
-        return video_files
-
-    def detect_and_save_frame(self, frame, frame_number, timestamp, video_name, output_dir):
-        """Run detection on a single frame and save the results"""
-        try:
-            # Run detection
-            results = self.model(frame, conf=self.confidence, verbose=False)
-            result = results[0] if results else None
-
-            # If any targets were detected
-            if result is not None and result.boxes is not None and len(result.boxes) > 0:
-                # Collect the classes detected in the current frame
-                current_scene_classes = set()
-                for box in result.boxes:
-                    class_id = int(box.cls[0].cpu().numpy())
-                    current_scene_classes.add(class_id)
-
-                # Check the time interval
-                time_diff = timestamp - self.last_saved_timestamp
-
-                # If the interval is below the minimum, check whether this is a new scene
-                if time_diff < config.MIN_FRAME_INTERVAL and self.last_saved_timestamp > 0:
-                    # If the current frame's scene classes match the previous frame's, filter it
-                    if hasattr(self, 'last_scene_classes') and current_scene_classes == self.last_scene_classes:
-                        self.stats['filtered_frames'] += 1
-                        # Still return the detection info, but do not save any images
-                        detections = []
-                        for box in result.boxes:
-                            class_id = int(box.cls[0].cpu().numpy())
-                            confidence = float(box.conf[0].cpu().numpy())
-                            bbox = box.xyxy[0].cpu().numpy().tolist()
-                            class_name = self.model.names.get(class_id, f"class_{class_id}")
-                            class_name_cn = config.get_class_name_cn(class_id)
-
-                            detections.append({
-                                'class_id': class_id,
-                                'class_name': class_name,
-                                'class_name_cn': class_name_cn,
-                                'confidence': confidence,
-                                'bbox': bbox,
-                                'filtered': True
-                            })
-
-                        return {
-                            'frame_number': frame_number,
-                            'timestamp': timestamp,
-                            'detections': detections,
-                            'filtered': True
-                        }
-
-                # Record the most recently detected scene classes
-                self.last_scene_classes = current_scene_classes
-
-                # Update the timestamp of the last saved frame
-                self.last_saved_timestamp = timestamp
-
-                # Build the base file name
-                base_name = f"{video_name}_frame_{frame_number:06d}_t{timestamp:.2f}s"
-
-                # Collect the detected scene classes for categorized saving
-                detected_scenes = set()
-                for box in result.boxes:
-                    class_id = int(box.cls[0].cpu().numpy())
-                    class_name_cn = config.get_class_name_cn(class_id)
-                    detected_scenes.add(class_name_cn)
-
-                # Create a folder for each detected scene and save the images
-                for scene_name in detected_scenes:
-                    # Create the per-video folder
-                    video_folder = output_dir / video_name
-                    video_folder.mkdir(parents=True, exist_ok=True)
-
-                    # Create the scene folder
-                    scene_folder = video_folder / scene_name
-                    scene_folder.mkdir(parents=True, exist_ok=True)
-
-                    # Save the original frame
-                    if config.SAVE_ORIGINAL_FRAMES:
-                        original_path = scene_folder / f"{base_name}_original.jpg"
-                        cv2.imwrite(str(original_path), frame, [cv2.IMWRITE_JPEG_QUALITY, config.IMAGE_QUALITY])
-
-                    # Draw and save the annotated frame
-                    if config.SAVE_ANNOTATED_FRAMES:
-                        annotated_frame = self.draw_detections(frame.copy(), result)
-                        annotated_path = scene_folder / f"{base_name}_detected.jpg"
-                        cv2.imwrite(str(annotated_path), annotated_frame, [cv2.IMWRITE_JPEG_QUALITY, config.IMAGE_QUALITY])
-
-                # Collect the detection info
-                detections = []
-                for box in result.boxes:
-                    class_id = int(box.cls[0].cpu().numpy())
-                    confidence = float(box.conf[0].cpu().numpy())
-                    bbox = box.xyxy[0].cpu().numpy().tolist()
-                    class_name = self.model.names.get(class_id, f"class_{class_id}")
-                    class_name_cn = config.get_class_name_cn(class_id)
-
-                    detections.append({
-                        'class_id': class_id,
-                        'class_name': class_name,
-                        'class_name_cn': class_name_cn,
-                        'confidence': confidence,
-                        'bbox': bbox
-                    })
-
-                return {
-                    'frame_number': frame_number,
-                    'timestamp': timestamp,
-                    'detections': detections
-                }
-
-            return None
-
-        except Exception as e:
-            print(f"Error while processing frame {frame_number}: {e}")
-            return None
-
-    def draw_detections(self, frame, result):
-        """Draw detection results on the frame (bounding boxes plus Chinese labels)"""
-        if result is None or result.boxes is None:
-            return frame
-
-        # Convert to a PIL image
-        pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-        draw = ImageDraw.Draw(pil_image)
-
-        # Try to load a Chinese font
-        try:
-            # Common Chinese font paths on Windows
-            font_paths = [
-                "C:/Windows/Fonts/simhei.ttf",  # SimHei
-                "C:/Windows/Fonts/simsun.ttc",  # SimSun
-                "C:/Windows/Fonts/msyh.ttc",    # Microsoft YaHei
-            ]
-            font = None
-            for font_path in font_paths:
-                try:
-                    font = ImageFont.truetype(font_path, 20)
-                    break
-                except:
-                    continue
-            if font is None:
-                font = ImageFont.load_default()
-        except:
-            font = ImageFont.load_default()
-
-        for box in result.boxes:
-            # Get the coordinates and detection info
-            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
-            class_id = int(box.cls[0].cpu().numpy())
-            confidence = float(box.conf[0].cpu().numpy())
-
-            # Get the color and class name
-            color = config.get_class_color(class_id)
-            bg_color = tuple(reversed(color))  # BGR to RGB
-            class_name_cn = config.get_class_name_cn(class_id)
-
-            # Draw the bounding box
-            draw.rectangle([x1, y1, x2, y2], outline=bg_color, width=2)
-
-            # Prepare the label text
-            label_text = f"{class_name_cn} {confidence:.2f}"
-
-            # Measure the label background size
-            bbox = draw.textbbox((0, 0), label_text, font=font)
-            text_width = bbox[2] - bbox[0]
-            text_height = bbox[3] - bbox[1]
-
-            # Keep the label inside the image bounds
-            label_x = max(0, x1)
-            label_y = max(0, y1 - text_height - 5)
-            if label_y < 0:
-                label_y = y1 + 5
-
-            # Draw the label background
-            draw.rectangle(
-                [label_x, label_y, label_x + text_width + 4, label_y + text_height + 4],
-                fill=bg_color
-            )
-
-            # Draw the label text (white)
-            draw.text((label_x + 2, label_y + 2), label_text, fill=(255, 255, 255), font=font)
-
-        # Convert back to OpenCV format
-        frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
-        return frame
-
-    def process_video(self, video_path, output_base_dir="output_frames"):
-        """Process a single video"""
-        print(f"\nProcessing: {video_path.name}")
-
-        # Create the output directory
-        output_dir = Path(output_base_dir) / video_path.stem
-        output_dir.mkdir(parents=True, exist_ok=True)
-
-        # Reset the time-interval filter so each video is timed independently
-        self.last_saved_timestamp = 0
-
-        # Reset the scene classes so each video classifies scenes independently
-        self.last_scene_classes = set()
-
-        # Open the video
-        cap = cv2.VideoCapture(str(video_path))
-        if not cap.isOpened():
-            error_msg = f"Could not open video: {video_path}"
-            print(error_msg)
-            self.stats['errors'].append(error_msg)
-            return
-
-        # Get the video info
-        fps = cap.get(cv2.CAP_PROP_FPS)
-        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-        print(f"Video info: {total_frames} frames, {fps:.2f} FPS")
-
-        # Processing statistics
-        frame_count = 0
-        detected_frames = 0
-        saved_images = 0
-        filtered_frames = 0
-        all_detections = []
-
-        start_time = time.time()
-
-        try:
-            while True:
-                ret, frame = cap.read()
-                if not ret:
-                    break
-
-                frame_count += 1
-                timestamp = frame_count / fps if fps > 0 else frame_count
-
-                # Report progress
-                if frame_count % config.PROGRESS_INTERVAL == 0:
-                    progress = (frame_count / total_frames) * 100
-                    elapsed = time.time() - start_time
-                    fps_current = frame_count / elapsed
-                    print(f"Progress: {progress:.1f}% ({frame_count}/{total_frames}), speed: {fps_current:.1f} FPS")
-
-                    # Invoke the progress callback
-                    if hasattr(self, 'update_progress'):
-                        self.update_progress("Processing video frames", None, None, None, frame_count)
-
-                # Detect and save
-                detection_result = self.detect_and_save_frame(
-                    frame, frame_count, timestamp, video_path.stem, output_dir
-                )
-
-                if detection_result:
-                    detected_frames += 1
-                    all_detections.append(detection_result)
-
-                    # Check whether the frame was filtered
-                    if detection_result.get('filtered', False):
-                        filtered_frames += 1
-                    else:
-                        # Count the saved images
-                        if config.SAVE_ORIGINAL_FRAMES:
-                            saved_images += 1
-                        if config.SAVE_ANNOTATED_FRAMES:
-                            saved_images += 1
-
-        finally:
-            cap.release()
-
-        # Save the detection results as JSON
-        if config.SAVE_DETECTION_JSON and all_detections:
-            json_path = output_dir / f"{video_path.stem}_detections.json"
-            detection_summary = {
-                'video_name': video_path.name,
-                'total_frames': frame_count,
-                'detected_frames': detected_frames,
-                'fps': fps,
-                'detections': all_detections,
-                'processing_time': time.time() - start_time
-            }
-
-            with open(json_path, 'w', encoding='utf-8') as f:
-                json.dump(detection_summary, f, ensure_ascii=False, indent=2)
-
-        # Update the statistics
-        self.stats['processed_videos'] += 1
-        self.stats['total_frames'] += frame_count
-        self.stats['detected_frames'] += detected_frames
-        self.stats['saved_images'] += saved_images
-        self.stats['filtered_frames'] += filtered_frames
-
-        processing_time = time.time() - start_time
-        print(f"Finished processing: {video_path.name}")
-        print(f"Frames with detections: {detected_frames}/{frame_count}, images saved: {saved_images}")
-        print(f"Frames filtered by time interval: {filtered_frames} (interval threshold: {config.MIN_FRAME_INTERVAL}s)")
-        print(f"Processing time: {processing_time:.2f} s")
-
-        # Record the per-video details
-        video_info = {
-            'video_name': video_path.name,
-            'video_path': str(video_path),
-            'total_frames': frame_count,
-            'detected_frames': detected_frames,
-            'saved_images': saved_images,
-            'filtered_frames': filtered_frames,
-            'processing_time': processing_time,
-            'fps': fps,
-            'duration': frame_count / fps if fps > 0 else 0,
-            'output_directory': str(output_dir)
-        }
-        self.video_details.append(video_info)
-
-    def process_all_videos(self, input_folder="input_videos", output_folder="output_frames"):
-        """Process all videos in batch"""
-        self.stats['start_time'] = datetime.now()
-
-        video_files = self.get_video_files(input_folder)
-        if not video_files:
-            print("No video files found")
-            return
-
-        self.stats['total_videos'] = len(video_files)
-
-        print(f"\nStarting batch processing of {len(video_files)} video files")
-        print(f"Model: {self.model_path}")
-        print(f"Confidence threshold: {self.confidence}")
-        print(f"Output directory: {output_folder}")
-        print("=" * 50)
-
-        for i, video_path in enumerate(video_files, 1):
-            print(f"\n[{i}/{len(video_files)}] Processing video")
-            self.process_video(video_path, output_folder)
-
-        self.stats['end_time'] = datetime.now()
-        report_path = self.save_final_report(output_folder)
-        self.print_summary()
-        return report_path
-
-    def save_final_report(self, output_folder):
-        """Save the final batch-processing report"""
-        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
-        report_path = Path(output_folder) / f"batch_detection_report_{timestamp}.json"
-
-        # Compute the total processing time
-        total_seconds = 0
-        if self.stats['start_time'] and self.stats['end_time']:
-            total_seconds = (self.stats['end_time'] - self.stats['start_time']).total_seconds()
-
-        # Build the report data
-        report_data = {
-            'timestamp': timestamp,
-            'model_path': self.model_path,
-            'confidence_threshold': self.confidence,
-            'frame_interval_threshold': config.MIN_FRAME_INTERVAL,
-            'filter_strategy': 'Each scene is saved only once within the configured time interval',
-            'statistics': {
-                'total_videos': self.stats['total_videos'],
-                'processed_videos': self.stats['processed_videos'],
-                'total_frames': self.stats['total_frames'],
-                'detected_frames': self.stats['detected_frames'],
-                'filtered_frames': self.stats['filtered_frames'],
-                'saved_images': self.stats['saved_images'],
-                'processing_time_seconds': total_seconds
-            },
-            'videos': self.video_details,
-            'errors': self.stats['errors']
-        }
-
-        # Save as JSON
-        with open(report_path, 'w', encoding='utf-8') as f:
-            json.dump(report_data, f, ensure_ascii=False, indent=2)
-
-        print(f"\nBatch report saved: {report_path}")
-        return str(report_path)
-
-    def print_summary(self):
-        """Print the processing summary"""
-        if not self.stats['end_time'] or not self.stats['start_time']:
-            return
-
-        total_time = (self.stats['end_time'] - self.stats['start_time']).total_seconds()
-        hours, remainder = divmod(total_time, 3600)
-        minutes, seconds = divmod(remainder, 60)
-
-        print("\n" + "=" * 50)
-        print("Batch processing summary")
-        print("=" * 50)
-        print(f"Videos processed: {self.stats['processed_videos']}/{self.stats['total_videos']}")
-        print(f"Total frames processed: {self.stats['total_frames']}")
-        print(f"Frames with detections: {self.stats['detected_frames']}")
-        print(f"Frames filtered by time interval: {self.stats['filtered_frames']} (interval threshold: {config.MIN_FRAME_INTERVAL}s)")
-        print(f"Filter strategy: each scene is saved only once within {config.MIN_FRAME_INTERVAL} seconds")
-        print(f"Images actually saved: {self.stats['saved_images']}")
-        print(f"Total processing time: {int(hours)}h {int(minutes)}m {seconds:.2f}s")
-
-        if self.stats['errors']:
-            print("\nErrors during processing:")
-            for error in self.stats['errors']:
-                print(f"- {error}")
-
-        print("=" * 50)
-
-def main():
-    """Entry point"""
-    import argparse
-
-    parser = argparse.ArgumentParser(description='Batch video detection system')
-    parser.add_argument('--model', type=str, help='Path to the model file')
-    parser.add_argument('--confidence', type=float, default=0.5, help='Confidence threshold')
-    parser.add_argument('--input', type=str, default='input_videos', help='Input video folder')
-    parser.add_argument('--output', type=str, default='output_frames', help='Output image folder')
-
-    args = parser.parse_args()
-
-    # Create the detector
-    detector = BatchVideoDetector(
-        model_path=args.model,
-        confidence=args.confidence
-    )
-
-    # Start batch processing
-    detector.process_all_videos(args.input, args.output)
-
-if __name__ == '__main__':
-    main()
\ No newline at end of file