删除 detector.py

This commit is contained in:
Wang_Run_Ze 2025-06-27 17:11:12 +08:00
parent e1d71d67c7
commit a22521f738

View File

@@ -1,356 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
视频目标检测系统
使用YOLO模型对视频进行道路损伤检测并抽取包含目标的帧
"""
import os
import cv2
import numpy as np
from pathlib import Path
from ultralytics import YOLO
import argparse
from datetime import datetime
import json
import logging
from PIL import Image, ImageDraw, ImageFont
from config import config, DetectionConfig
class VideoDetectionSystem:
    """Run a YOLO model over a folder of videos and export the frames with detections.

    For every video found under ``input_folder`` each frame is passed to the
    model. When at least one box is returned, the raw frame and an annotated
    copy are written to ``output_folder/<video stem>/`` and the boxes are
    recorded in a per-video JSON file. A summary JSON covering all videos is
    written at the end of :meth:`process_all_videos`.
    """

    # Fonts that can render CJK labels, tried in order. The first three are the
    # Windows fonts the original implementation looked for; the rest are common
    # Linux/macOS locations so labels also render off Windows.
    _FONT_CANDIDATES = (
        "C:/Windows/Fonts/simhei.ttf",
        "C:/Windows/Fonts/msyh.ttf",
        "C:/Windows/Fonts/msyh.ttc",
        "C:/Windows/Fonts/simsun.ttc",
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
        "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
        "/System/Library/Fonts/PingFang.ttc",
    )
    _FONT_SIZE = 20

    def __init__(self, model_path, confidence_threshold=0.5,
                 input_folder="input_videos", output_folder="output_frames"):
        """Create the output folder, configure logging and load the model.

        Args:
            model_path (str): Path to the YOLO weights file.
            confidence_threshold (float): Minimum confidence passed to the model.
            input_folder (str): Folder searched (recursively) for videos.
            output_folder (str): Folder that receives frames and JSON reports.

        Raises:
            Exception: Whatever ``load_model`` re-raises when loading fails.
        """
        self.model_path = model_path
        self.confidence_threshold = confidence_threshold
        self.input_folder = Path(input_folder)
        self.output_folder = Path(output_folder)

        # Video file suffixes (lower-case) that get_video_files accepts.
        self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm'}

        # Running totals across all processed videos.
        self.detection_stats = {
            'total_videos': 0,
            'total_frames': 0,
            'detected_frames': 0,
            'saved_frames': 0,
            'detection_results': []
        }

        # Label font is resolved lazily on first use and then reused.
        self._label_font = None

        self.output_folder.mkdir(parents=True, exist_ok=True)
        self.setup_logging()
        self.load_model()

    def setup_logging(self):
        """Configure root logging to a timestamped file under ``logs/`` and to the console."""
        log_folder = Path('logs')
        log_folder.mkdir(exist_ok=True)
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = log_folder / f'detection_{timestamp}.log'
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(log_file, encoding='utf-8'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def load_model(self):
        """Load the YOLO model from ``self.model_path`` and log its class names.

        Also calls ``DetectionConfig.load_model_classes`` with the same path
        (presumably so the shared config knows the model's classes; its
        implementation is not visible here).

        Raises:
            Exception: Any error raised while loading is logged and re-raised.
        """
        try:
            self.model = YOLO(self.model_path)
            self.logger.info(f"成功加载模型: {self.model_path}")
            DetectionConfig.load_model_classes(self.model_path)
            if hasattr(self.model, 'names'):
                self.logger.info(f"模型类别数量: {len(self.model.names)}")
                self.logger.info(f"检测类别: {list(self.model.names.values())}")
                self.logger.info(f"动态加载的类别信息: {DetectionConfig.get_all_classes()}")
        except Exception as e:
            self.logger.error(f"加载模型失败: {e}")
            raise

    def get_video_files(self):
        """Return all video files under ``self.input_folder``, sorted by path.

        The search is recursive and the suffix match is case-insensitive.
        Sorting makes the processing order reproducible between runs.

        Returns:
            list: ``Path`` objects; empty when the folder does not exist.
        """
        if not self.input_folder.exists():
            self.logger.warning(f"输入文件夹不存在: {self.input_folder}")
            return []

        video_files = sorted(
            path for path in self.input_folder.rglob('*')
            if path.is_file() and path.suffix.lower() in self.video_extensions
        )
        self.logger.info(f"找到 {len(video_files)} 个视频文件")
        return video_files

    def detect_frame(self, frame):
        """Run the model on one frame.

        Returns:
            The first result object, or ``None`` when the model returned
            nothing or raised (the error is logged, not propagated, so one bad
            frame does not abort the video).
        """
        try:
            results = self.model(frame, conf=self.confidence_threshold, verbose=False)
            return results[0] if results else None
        except Exception as e:
            self.logger.error(f"检测失败: {e}")
            return None

    def _get_label_font(self):
        """Return the font used for labels, resolving it once and caching it."""
        font = getattr(self, '_label_font', None)
        if font is not None:
            return font

        for candidate in self._FONT_CANDIDATES:
            if not os.path.exists(candidate):
                continue
            try:
                font = ImageFont.truetype(candidate, self._FONT_SIZE)
                break
            except OSError:
                # Unreadable or corrupt font file: try the next candidate.
                continue

        if font is None:
            # NOTE(review): the built-in default font may lack CJK glyphs.
            font = ImageFont.load_default()

        self._label_font = font
        return font

    def draw_detections(self, frame, result):
        """Return a copy of ``frame`` with boxes and labels drawn on it.

        Drawing goes through PIL rather than OpenCV so that non-ASCII class
        names can be rendered with a TrueType font.

        Args:
            frame: BGR image as read by OpenCV.
            result: A model result, or ``None``.

        Returns:
            The annotated BGR image; the input ``frame`` unchanged when there
            is no result or no boxes attribute.
        """
        if result is None or result.boxes is None:
            return frame

        pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_image)
        font = self._get_label_font()

        for box in result.boxes:
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int).tolist()
            confidence = float(box.conf[0].cpu().numpy())
            class_id = int(box.cls[0].cpu().numpy())

            class_name_cn = config.get_class_name_cn(class_id)
            color = config.get_class_color(class_id)
            # The configured colour is treated as BGR and flipped for PIL (RGB).
            bg_color = tuple(reversed(color))

            draw.rectangle([x1, y1, x2, y2], outline=bg_color, width=2)

            label = f"{class_name_cn}: {confidence:.2f}"
            bbox = draw.textbbox((0, 0), label, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]

            # Label sits just above the box; if that would fall off the top of
            # the image, put it just inside the box instead so it stays visible.
            label_top = y1 - text_height - 10
            if label_top < 0:
                label_top = y1

            draw.rectangle(
                [x1, label_top, x1 + text_width + 10, label_top + text_height + 10],
                fill=bg_color
            )
            draw.text((x1 + 5, label_top + 5), label, font=font, fill=(255, 255, 255))

        return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

    @staticmethod
    def _format_progress(frame_count, total_frames):
        """Build the progress log line, tolerating an unknown total frame count."""
        if total_frames > 0:
            progress = (frame_count / total_frames) * 100
            return f"处理进度: {progress:.1f}% ({frame_count}/{total_frames})"
        # Some containers/streams report 0 frames; avoid dividing by zero.
        return f"处理进度: 已处理 {frame_count} 帧 (总帧数未知)"

    def _box_records(self, result):
        """Convert the boxes of one result into JSON-serialisable dicts."""
        records = []
        for box in result.boxes:
            class_id = int(box.cls[0].cpu().numpy())
            records.append({
                'class_name': self.model.names[class_id],
                'class_id': class_id,
                'confidence': float(box.conf[0].cpu().numpy()),
                'bbox': box.xyxy[0].cpu().numpy().tolist()
            })
        return records

    def process_video(self, video_path):
        """Detect on every frame of one video and save the frames that have detections.

        For each frame with at least one box, the raw frame and an annotated
        copy are written as JPEGs to ``output_folder/<video stem>/``. After the
        video ends, a ``<stem>_detections.json`` file is written there and the
        running totals in ``self.detection_stats`` are updated.

        Args:
            video_path (Path): Video file to process.
        """
        self.logger.info(f"开始处理视频: {video_path.name}")

        video_output_folder = self.output_folder / video_path.stem
        video_output_folder.mkdir(parents=True, exist_ok=True)

        cap = cv2.VideoCapture(str(video_path))
        if not cap.isOpened():
            self.logger.error(f"无法打开视频文件: {video_path}")
            cap.release()
            return

        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = total_frames / fps if fps > 0 else 0
        self.logger.info(f"视频信息 - 帧率: {fps:.2f}, 总帧数: {total_frames}, 时长: {duration:.2f}")

        frame_count = 0
        detected_count = 0
        saved_count = 0
        video_detections = {
            'video_name': video_path.name,
            'total_frames': total_frames,
            'fps': fps,
            'duration': duration,
            'detections': []
        }

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frame_count += 1

                # Report progress every 100 frames.
                if frame_count % 100 == 0:
                    self.logger.info(self._format_progress(frame_count, total_frames))

                result = self.detect_frame(frame)
                if result is None or result.boxes is None or len(result.boxes) == 0:
                    continue

                detected_count += 1
                # Fall back to the frame index when the frame rate is unknown.
                timestamp = frame_count / fps if fps > 0 else frame_count
                annotated_frame = self.draw_detections(frame, result)

                suffix = f"{frame_count:06d}_t{timestamp:.2f}s.jpg"
                outputs = (
                    (video_output_folder / f"frame_{suffix}", frame),
                    (video_output_folder / f"annotated_{suffix}", annotated_frame),
                )
                for image_path, image in outputs:
                    # imwrite signals failure through its return value, so only
                    # count files that were actually written.
                    if cv2.imwrite(str(image_path), image):
                        saved_count += 1
                    else:
                        self.logger.warning(f"保存图片失败: {image_path}")

                video_detections['detections'].append({
                    'frame_number': frame_count,
                    'timestamp': timestamp,
                    'detections': self._box_records(result)
                })
        finally:
            cap.release()

        json_path = video_output_folder / f"{video_path.stem}_detections.json"
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(video_detections, f, ensure_ascii=False, indent=2)

        self.logger.info(f"视频处理完成: {video_path.name}")
        self.logger.info(f"总帧数: {frame_count}, 检测到目标的帧数: {detected_count}, 保存图片数: {saved_count}")

        self.detection_stats['total_videos'] += 1
        self.detection_stats['total_frames'] += frame_count
        self.detection_stats['detected_frames'] += detected_count
        self.detection_stats['saved_frames'] += saved_count
        self.detection_stats['detection_results'].append(video_detections)

    def process_all_videos(self):
        """Process every video found in the input folder, then write the summary report."""
        video_files = self.get_video_files()
        if not video_files:
            self.logger.warning("没有找到视频文件")
            return

        self.logger.info(f"开始处理 {len(video_files)} 个视频文件")
        for i, video_path in enumerate(video_files, 1):
            self.logger.info(f"\n=== 处理第 {i}/{len(video_files)} 个视频 ===")
            self.process_video(video_path)

        self.save_summary_report()

        self.logger.info("\n=== 处理完成 ===")
        self.logger.info(f"总计处理视频: {self.detection_stats['total_videos']}")
        self.logger.info(f"总计处理帧数: {self.detection_stats['total_frames']}")
        self.logger.info(f"检测到目标的帧数: {self.detection_stats['detected_frames']}")
        self.logger.info(f"保存图片数量: {self.detection_stats['saved_frames']}")

    def save_summary_report(self):
        """Write ``detection_summary_<timestamp>.json`` with settings and totals to the output folder."""
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        report_path = self.output_folder / f"detection_summary_{timestamp}.json"
        summary = {
            'timestamp': timestamp,
            'model_path': str(self.model_path),
            'confidence_threshold': self.confidence_threshold,
            'input_folder': str(self.input_folder),
            'output_folder': str(self.output_folder),
            'statistics': self.detection_stats
        }
        with open(report_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)
        self.logger.info(f"总结报告已保存: {report_path}")
def main(argv=None):
    """Parse command-line options and run detection over all input videos.

    Args:
        argv (list | None): Argument list to parse instead of ``sys.argv[1:]``.
            ``None`` (the default) keeps the normal command-line behaviour;
            passing a list allows the CLI to be driven from other code.
    """
    parser = argparse.ArgumentParser(description='视频目标检测系统')
    parser.add_argument('--model', type=str,
                        default='../Japan/training_results/continue_from_best_20250610_130607/weights/best.pt',
                        help='YOLO模型文件路径')
    parser.add_argument('--confidence', type=float, default=0.5,
                        help='置信度阈值 (默认: 0.5)')
    parser.add_argument('--input', type=str, default='input_videos',
                        help='输入视频文件夹 (默认: input_videos)')
    parser.add_argument('--output', type=str, default='output_frames',
                        help='输出图片文件夹 (默认: output_frames)')
    args = parser.parse_args(argv)

    # Build the detection system from the parsed options.
    detector = VideoDetectionSystem(
        model_path=args.model,
        confidence_threshold=args.confidence,
        input_folder=args.input,
        output_folder=args.output
    )

    # Process every video in the input folder.
    detector.process_all_videos()
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()