diff --git a/detector.py b/detector.py deleted file mode 100644 index ffbb0ce..0000000 --- a/detector.py +++ /dev/null @@ -1,356 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -视频目标检测系统 -使用YOLO模型对视频进行道路损伤检测,并抽取包含目标的帧 -""" - -import os -import cv2 -import numpy as np -from pathlib import Path -from ultralytics import YOLO -import argparse -from datetime import datetime -import json -import logging -from PIL import Image, ImageDraw, ImageFont -from config import config, DetectionConfig - -class VideoDetectionSystem: - def __init__(self, model_path, confidence_threshold=0.5, input_folder="input_videos", output_folder="output_frames"): - """ - 初始化视频检测系统 - - Args: - model_path (str): YOLO模型文件路径 - confidence_threshold (float): 置信度阈值 - input_folder (str): 输入视频文件夹 - output_folder (str): 输出图片文件夹 - """ - self.model_path = model_path - self.confidence_threshold = confidence_threshold - self.input_folder = Path(input_folder) - self.output_folder = Path(output_folder) - - # 创建输出文件夹 - self.output_folder.mkdir(parents=True, exist_ok=True) - - # 设置日志 - self.setup_logging() - - # 加载模型 - self.load_model() - - # 支持的视频格式 - self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm'} - - # 检测结果统计 - self.detection_stats = { - 'total_videos': 0, - 'total_frames': 0, - 'detected_frames': 0, - 'saved_frames': 0, - 'detection_results': [] - } - - def setup_logging(self): - """设置日志记录""" - log_folder = Path('logs') - log_folder.mkdir(exist_ok=True) - - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - log_file = log_folder / f'detection_{timestamp}.log' - - logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler(log_file, encoding='utf-8'), - logging.StreamHandler() - ] - ) - self.logger = logging.getLogger(__name__) - - def load_model(self): - """加载YOLO模型""" - try: - self.model = YOLO(self.model_path) - self.logger.info(f"成功加载模型: {self.model_path}") - - # 动态加载模型类别信息到配置中 - DetectionConfig.load_model_classes(self.model_path) - - # 打印模型信息 - if hasattr(self.model, 'names'): - self.logger.info(f"模型类别数量: {len(self.model.names)}") - self.logger.info(f"检测类别: {list(self.model.names.values())}") - self.logger.info(f"动态加载的类别信息: {DetectionConfig.get_all_classes()}") - - except Exception as e: - self.logger.error(f"加载模型失败: {e}") - raise - - def get_video_files(self): - """获取输入文件夹中的所有视频文件""" - video_files = [] - - if not self.input_folder.exists(): - self.logger.warning(f"输入文件夹不存在: {self.input_folder}") - return video_files - - for file_path in self.input_folder.rglob('*'): - if file_path.is_file() and file_path.suffix.lower() in self.video_extensions: - video_files.append(file_path) - - self.logger.info(f"找到 {len(video_files)} 个视频文件") - return video_files - - def detect_frame(self, frame): - """对单帧进行目标检测""" - try: - results = self.model(frame, conf=self.confidence_threshold, verbose=False) - return results[0] if results else None - except Exception as e: - self.logger.error(f"检测失败: {e}") - return None - - def draw_detections(self, frame, result): - """在帧上绘制检测结果""" - if result is None or result.boxes is None: - return frame - - # 转换为PIL图像以支持中文字体 - pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) - draw = ImageDraw.Draw(pil_image) - - # 尝试加载中文字体 - try: - # Windows系统字体路径 - font_path = "C:/Windows/Fonts/simhei.ttf" # 黑体 - if not os.path.exists(font_path): - font_path = "C:/Windows/Fonts/msyh.ttf" # 微软雅黑 - if not os.path.exists(font_path): - font_path = "C:/Windows/Fonts/simsun.ttc" # 宋体 - - if os.path.exists(font_path): - font = ImageFont.truetype(font_path, 20) - else: - font = ImageFont.load_default() - except: - font = ImageFont.load_default() - - for box in result.boxes: - # 获取边界框坐标 - x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int) - - # 获取置信度和类别 - confidence = float(box.conf[0].cpu().numpy()) - class_id = int(box.cls[0].cpu().numpy()) - - # 获取中文类别名称和颜色 - class_name_cn = config.get_class_name_cn(class_id) - color = config.get_class_color(class_id) - - # 绘制边界框(使用PIL) - bg_color = tuple(reversed(color)) # BGR转RGB - draw.rectangle([x1, y1, x2, y2], outline=bg_color, width=2) - - # 准备标签文字 - label = f"{class_name_cn}: {confidence:.2f}" - - # 获取文字尺寸 - bbox = draw.textbbox((0, 0), label, font=font) - text_width = bbox[2] - bbox[0] - text_height = bbox[3] - bbox[1] - - # 绘制标签背景(使用PIL) - draw.rectangle([x1, y1 - text_height - 10, x1 + text_width + 10, y1], fill=bg_color) - - # 绘制文字(使用PIL) - draw.text((x1 + 5, y1 - text_height - 5), label, font=font, fill=(255, 255, 255)) - - # 转换回OpenCV格式 - annotated_frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR) - return annotated_frame - - def process_video(self, video_path): - """处理单个视频文件""" - self.logger.info(f"开始处理视频: {video_path.name}") - - # 创建视频专用输出文件夹 - video_output_folder = self.output_folder / video_path.stem - video_output_folder.mkdir(exist_ok=True) - - # 打开视频 - cap = cv2.VideoCapture(str(video_path)) - if not cap.isOpened(): - self.logger.error(f"无法打开视频文件: {video_path}") - return - - # 获取视频信息 - fps = cap.get(cv2.CAP_PROP_FPS) - total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - duration = total_frames / fps if fps > 0 else 0 - - self.logger.info(f"视频信息 - 帧率: {fps:.2f}, 总帧数: {total_frames}, 时长: {duration:.2f}秒") - - frame_count = 0 - detected_count = 0 - saved_count = 0 - - video_detections = { - 'video_name': video_path.name, - 'total_frames': total_frames, - 'fps': fps, - 'duration': duration, - 'detections': [] - } - - try: - while True: - ret, frame = cap.read() - if not ret: - break - - frame_count += 1 - - # 每100帧显示一次进度 - if frame_count % 100 == 0: - progress = (frame_count / total_frames) * 100 - self.logger.info(f"处理进度: {progress:.1f}% ({frame_count}/{total_frames})") - - # 进行目标检测 - result = self.detect_frame(frame) - - # 如果检测到目标,保存帧 - if result is not None and result.boxes is not None and len(result.boxes) > 0: - detected_count += 1 - - # 计算时间戳 - timestamp = frame_count / fps if fps > 0 else frame_count - - # 绘制检测结果 - annotated_frame = self.draw_detections(frame, result) - - # 保存原始帧和标注帧 - frame_filename = f"frame_{frame_count:06d}_t{timestamp:.2f}s.jpg" - annotated_filename = f"annotated_{frame_count:06d}_t{timestamp:.2f}s.jpg" - - original_path = video_output_folder / frame_filename - annotated_path = video_output_folder / annotated_filename - - cv2.imwrite(str(original_path), frame) - cv2.imwrite(str(annotated_path), annotated_frame) - - saved_count += 2 - - # 记录检测信息 - detection_info = { - 'frame_number': frame_count, - 'timestamp': timestamp, - 'detections': [] - } - - for box in result.boxes: - confidence = float(box.conf[0].cpu().numpy()) - class_id = int(box.cls[0].cpu().numpy()) - class_name = self.model.names[class_id] - bbox = box.xyxy[0].cpu().numpy().tolist() - - detection_info['detections'].append({ - 'class_name': class_name, - 'class_id': class_id, - 'confidence': confidence, - 'bbox': bbox - }) - - video_detections['detections'].append(detection_info) - - finally: - cap.release() - - # 保存检测结果到JSON文件 - json_path = video_output_folder / f"{video_path.stem}_detections.json" - with open(json_path, 'w', encoding='utf-8') as f: - json.dump(video_detections, f, ensure_ascii=False, indent=2) - - self.logger.info(f"视频处理完成: {video_path.name}") - self.logger.info(f"总帧数: {frame_count}, 检测到目标的帧数: {detected_count}, 保存图片数: {saved_count}") - - # 更新统计信息 - self.detection_stats['total_videos'] += 1 - self.detection_stats['total_frames'] += frame_count - self.detection_stats['detected_frames'] += detected_count - self.detection_stats['saved_frames'] += saved_count - self.detection_stats['detection_results'].append(video_detections) - - def process_all_videos(self): - """处理所有视频文件""" - video_files = self.get_video_files() - - if not video_files: - self.logger.warning("没有找到视频文件") - return - - self.logger.info(f"开始处理 {len(video_files)} 个视频文件") - - for i, video_path in enumerate(video_files, 1): - self.logger.info(f"\n=== 处理第 {i}/{len(video_files)} 个视频 ===") - self.process_video(video_path) - - # 保存总体统计信息 - self.save_summary_report() - - self.logger.info("\n=== 处理完成 ===") - self.logger.info(f"总计处理视频: {self.detection_stats['total_videos']} 个") - self.logger.info(f"总计处理帧数: {self.detection_stats['total_frames']} 帧") - self.logger.info(f"检测到目标的帧数: {self.detection_stats['detected_frames']} 帧") - self.logger.info(f"保存图片数量: {self.detection_stats['saved_frames']} 张") - - def save_summary_report(self): - """保存总结报告""" - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - report_path = self.output_folder / f"detection_summary_{timestamp}.json" - - summary = { - 'timestamp': timestamp, - 'model_path': str(self.model_path), - 'confidence_threshold': self.confidence_threshold, - 'input_folder': str(self.input_folder), - 'output_folder': str(self.output_folder), - 'statistics': self.detection_stats - } - - with open(report_path, 'w', encoding='utf-8') as f: - json.dump(summary, f, ensure_ascii=False, indent=2) - - self.logger.info(f"总结报告已保存: {report_path}") - -def main(): - parser = argparse.ArgumentParser(description='视频目标检测系统') - parser.add_argument('--model', type=str, - default='../Japan/training_results/continue_from_best_20250610_130607/weights/best.pt', - help='YOLO模型文件路径') - parser.add_argument('--confidence', type=float, default=0.5, - help='置信度阈值 (默认: 0.5)') - parser.add_argument('--input', type=str, default='input_videos', - help='输入视频文件夹 (默认: input_videos)') - parser.add_argument('--output', type=str, default='output_frames', - help='输出图片文件夹 (默认: output_frames)') - - args = parser.parse_args() - - # 创建检测系统 - detector = VideoDetectionSystem( - model_path=args.model, - confidence_threshold=args.confidence, - input_folder=args.input, - output_folder=args.output - ) - - # 处理所有视频 - detector.process_all_videos() - -if __name__ == '__main__': - main() \ No newline at end of file