wrz-yolo/detector.py
2025-06-27 17:13:03 +08:00

453 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
视频目标检测系统
使用YOLO模型对视频进行道路损伤检测并抽取包含目标的帧
"""
import os
import cv2
import numpy as np
from pathlib import Path
from ultralytics import YOLO
import argparse
from datetime import datetime
import json
import logging
import torch
from PIL import Image, ImageDraw, ImageFont
from config import config, DetectionConfig
class VideoDetectionSystem:
def __init__(self, model_path, confidence_threshold=0.5, input_folder="input_videos", output_folder="output_frames"):
"""
初始化视频检测系统
Args:
model_path (str): YOLO模型文件路径
confidence_threshold (float): 置信度阈值
input_folder (str): 输入视频文件夹
output_folder (str): 输出图片文件夹
"""
self.model_path = model_path
self.confidence_threshold = confidence_threshold
self.input_folder = Path(input_folder)
self.output_folder = Path(output_folder)
# 创建输出文件夹
self.output_folder.mkdir(parents=True, exist_ok=True)
# 设置日志
self.setup_logging()
# 加载模型
self.load_model()
# 支持的视频格式
self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm'}
# 动态检测参数
self.adaptive_confidence = True # 启用自适应置信度
self.min_confidence = 0.2 # 最低置信度
self.max_confidence = 0.8 # 最高置信度
self.scene_adaptation = True # 启用场景自适应
# 检测结果统计
self.detection_stats = {
'total_videos': 0,
'total_frames': 0,
'detected_frames': 0,
'saved_frames': 0,
'detection_results': []
}
# 统计信息
self.stats = {
'adaptive_detections': 0
}
# 检测质量阈值
self.detection_quality_threshold = 0.3
def setup_logging(self):
"""设置日志记录"""
log_folder = Path('logs')
log_folder.mkdir(exist_ok=True)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = log_folder / f'detection_{timestamp}.log'
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(log_file, encoding='utf-8'),
logging.StreamHandler()
]
)
self.logger = logging.getLogger(__name__)
def load_model(self):
"""加载YOLO模型"""
try:
self.model = YOLO(self.model_path)
self.logger.info(f"成功加载模型: {self.model_path}")
# 动态加载模型类别信息到配置中
DetectionConfig.load_model_classes(self.model_path)
# 打印模型信息
if hasattr(self.model, 'names'):
self.logger.info(f"模型类别数量: {len(self.model.names)}")
self.logger.info(f"检测类别: {list(self.model.names.values())}")
self.logger.info(f"动态加载的类别信息: {DetectionConfig.get_all_classes()}")
# 分析模型类型并设置优化参数
self.analyze_model_type()
except Exception as e:
self.logger.error(f"加载模型失败: {e}")
raise
def analyze_model_type(self):
"""分析模型类型并设置相应的检测参数"""
model_classes = list(self.model.names.values())
# 检测是否为道路损伤专用模型
road_damage_keywords = ['crack', 'pothole', 'damage', '裂缝', '坑洞', '损伤']
is_road_damage_model = any(keyword in ' '.join(model_classes).lower() for keyword in road_damage_keywords)
if is_road_damage_model:
self.logger.info("检测到道路损伤专用模型,使用优化参数")
self.confidence_threshold = max(0.25, self.confidence_threshold) # 道路损伤检测使用较低阈值
self.min_confidence = 0.15
else:
self.logger.info("检测到通用模型,使用标准参数")
self.confidence_threshold = max(0.4, self.confidence_threshold) # 通用模型使用较高阈值
self.min_confidence = 0.3
self.logger.info(f"调整后的置信度阈值: {self.confidence_threshold}")
def get_video_files(self):
"""获取输入文件夹中的所有视频文件"""
video_files = []
if not self.input_folder.exists():
self.logger.warning(f"输入文件夹不存在: {self.input_folder}")
return video_files
for file_path in self.input_folder.rglob('*'):
if file_path.is_file() and file_path.suffix.lower() in self.video_extensions:
video_files.append(file_path)
self.logger.info(f"找到 {len(video_files)} 个视频文件")
return video_files
def detect_frame(self, frame, adaptive_conf=None):
"""对单帧进行目标检测"""
try:
# 使用自适应置信度或默认置信度
conf_threshold = adaptive_conf if adaptive_conf is not None else self.confidence_threshold
# 多尺度检测以提高检测率
results = self.model(frame, conf=conf_threshold, verbose=False,
imgsz=640, augment=True) # 启用测试时增强
# 如果没有检测到目标且启用自适应,尝试降低置信度
if self.adaptive_confidence and (not results or not results[0].boxes or len(results[0].boxes) == 0):
if conf_threshold > self.min_confidence:
lower_conf = max(self.min_confidence, conf_threshold - 0.1)
self.logger.debug(f"降低置信度重试: {lower_conf}")
results = self.model(frame, conf=lower_conf, verbose=False,
imgsz=640, augment=True)
return results[0] if results else None
except Exception as e:
self.logger.error(f"检测失败: {e}")
return None
def adaptive_detect_frame(self, frame, base_confidence):
"""自适应检测帧"""
try:
# 尝试不同的置信度阈值
confidence_levels = [base_confidence, base_confidence * 0.8, base_confidence * 0.6]
for conf in confidence_levels:
if conf < self.min_confidence:
continue
results = self.model(frame, conf=conf, verbose=False)
if len(results) > 0 and len(results[0].boxes) > 0:
# 过滤低质量检测
boxes = results[0].boxes
valid_indices = []
for i in range(len(boxes)):
confidence = float(boxes.conf[i])
if confidence >= self.min_confidence:
valid_indices.append(i)
if valid_indices:
# 直接使用原始结果,但只返回高质量检测
return results
return []
except Exception as e:
print(f"自适应检测失败: {e}")
return []
def draw_detections(self, frame, result):
"""在帧上绘制检测结果"""
if result is None or result.boxes is None:
return frame
# 转换为PIL图像以支持中文字体
pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(pil_image)
# 尝试加载中文字体
try:
# Windows系统字体路径
font_path = "C:/Windows/Fonts/simhei.ttf" # 黑体
if not os.path.exists(font_path):
font_path = "C:/Windows/Fonts/msyh.ttf" # 微软雅黑
if not os.path.exists(font_path):
font_path = "C:/Windows/Fonts/simsun.ttc" # 宋体
if os.path.exists(font_path):
font = ImageFont.truetype(font_path, 20)
else:
font = ImageFont.load_default()
except:
font = ImageFont.load_default()
for box in result.boxes:
# 获取边界框坐标
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
# 获取置信度和类别
confidence = float(box.conf[0].cpu().numpy())
class_id = int(box.cls[0].cpu().numpy())
# 只显示中文类别名称
display_name = config.get_class_name_cn(class_id)
# 根据置信度调整颜色强度
base_color = config.get_class_color(class_id)
intensity = min(1.0, confidence + 0.3) # 确保颜色不会太暗
color = tuple(int(c * intensity) for c in base_color)
# 绘制边界框使用PIL
bg_color = tuple(reversed(color)) # BGR转RGB
line_width = max(2, int(confidence * 4)) # 根据置信度调整线宽
draw.rectangle([x1, y1, x2, y2], outline=bg_color, width=line_width)
# 准备标签文字
label = f"{display_name}: {confidence:.3f}"
# 获取文字尺寸
bbox = draw.textbbox((0, 0), label, font=font)
text_width = bbox[2] - bbox[0]
text_height = bbox[3] - bbox[1]
# 绘制标签背景使用PIL
draw.rectangle([x1, y1 - text_height - 10, x1 + text_width + 10, y1], fill=bg_color)
# 绘制文字使用PIL
draw.text((x1 + 5, y1 - text_height - 5), label, font=font, fill=(255, 255, 255))
# 转换回OpenCV格式
annotated_frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
return annotated_frame
def process_video(self, video_path):
"""处理单个视频文件"""
self.logger.info(f"开始处理视频: {video_path.name}")
# 创建视频专用输出文件夹
video_output_folder = self.output_folder / video_path.stem
video_output_folder.mkdir(exist_ok=True)
# 文件夹将在检测到对应类别时按需创建
# 打开视频
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
self.logger.error(f"无法打开视频文件: {video_path}")
return
# 获取视频信息
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
duration = total_frames / fps if fps > 0 else 0
self.logger.info(f"视频信息 - 帧率: {fps:.2f}, 总帧数: {total_frames}, 时长: {duration:.2f}")
frame_count = 0
detected_count = 0
saved_count = 0
video_detections = {
'video_name': video_path.name,
'total_frames': total_frames,
'fps': fps,
'duration': duration,
'detections': []
}
try:
while True:
ret, frame = cap.read()
if not ret:
break
frame_count += 1
# 每100帧显示一次进度
if frame_count % 100 == 0:
progress = (frame_count / total_frames) * 100
self.logger.info(f"处理进度: {progress:.1f}% ({frame_count}/{total_frames})")
# 进行目标检测
result = self.detect_frame(frame)
# 如果检测到目标,保存帧
if result is not None and result.boxes is not None and len(result.boxes) > 0:
detected_count += 1
timestamp = frame_count / fps if fps > 0 else frame_count
# 获取当前帧检测到的所有类别
detected_classes = set()
for box in result.boxes:
class_id = int(box.cls[0].cpu().numpy())
detected_classes.add(class_id)
# 只为检测到的类别创建文件夹并保存图片
annotated_frame = self.draw_detections(frame, result)
detection_info = {
'frame_number': frame_count,
'timestamp': timestamp,
'detections': []
}
# 按检测到的类别保存图片
for class_id in detected_classes:
class_name_cn = config.get_class_name_cn(class_id)
class_name_en = self.model.names[class_id] if hasattr(self.model, 'names') else str(class_id)
# 只为实际检测到的类别创建文件夹
class_folder = video_output_folder / class_name_cn
class_folder.mkdir(exist_ok=True)
# 保存检测到该类别的帧(只保存标注后的图片)
annotated_filename = f"detected_{frame_count:06d}_t{timestamp:.2f}s_{class_name_cn}.jpg"
annotated_path = class_folder / annotated_filename
cv2.imwrite(str(annotated_path), annotated_frame)
saved_count += 1
# 记录检测信息
for box in result.boxes:
confidence = float(box.conf[0].cpu().numpy())
class_id = int(box.cls[0].cpu().numpy())
class_name = self.model.names[class_id] if hasattr(self.model, 'names') else str(class_id)
bbox = box.xyxy[0].cpu().numpy().tolist()
detection_info['detections'].append({
'class_name': class_name,
'class_name_cn': config.get_class_name_cn(class_id),
'class_id': class_id,
'confidence': confidence,
'bbox': bbox
})
video_detections['detections'].append(detection_info)
finally:
cap.release()
# 保存检测结果到JSON文件
json_path = video_output_folder / f"{video_path.stem}_detections.json"
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(video_detections, f, ensure_ascii=False, indent=2)
self.logger.info(f"视频处理完成: {video_path.name}")
self.logger.info(f"总帧数: {frame_count}, 检测到目标的帧数: {detected_count}, 保存图片数: {saved_count}")
# 更新统计信息
self.detection_stats['total_videos'] += 1
self.detection_stats['total_frames'] += frame_count
self.detection_stats['detected_frames'] += detected_count
self.detection_stats['saved_frames'] += saved_count
self.detection_stats['detection_results'].append(video_detections)
def process_all_videos(self):
"""处理所有视频文件"""
video_files = self.get_video_files()
if not video_files:
self.logger.warning("没有找到视频文件")
return
self.logger.info(f"开始处理 {len(video_files)} 个视频文件")
for i, video_path in enumerate(video_files, 1):
self.logger.info(f"\n=== 处理第 {i}/{len(video_files)} 个视频 ===")
self.process_video(video_path)
# 保存总体统计信息
self.save_summary_report()
self.logger.info("\n=== 处理完成 ===")
self.logger.info(f"总计处理视频: {self.detection_stats['total_videos']}")
self.logger.info(f"总计处理帧数: {self.detection_stats['total_frames']}")
self.logger.info(f"检测到目标的帧数: {self.detection_stats['detected_frames']}")
self.logger.info(f"保存图片数量: {self.detection_stats['saved_frames']}")
def save_summary_report(self):
"""保存总结报告"""
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
report_path = self.output_folder / f"detection_summary_{timestamp}.json"
summary = {
'timestamp': timestamp,
'model_path': str(self.model_path),
'confidence_threshold': self.confidence_threshold,
'input_folder': str(self.input_folder),
'output_folder': str(self.output_folder),
'statistics': self.detection_stats
}
with open(report_path, 'w', encoding='utf-8') as f:
json.dump(summary, f, ensure_ascii=False, indent=2)
self.logger.info(f"总结报告已保存: {report_path}")
def main():
parser = argparse.ArgumentParser(description='视频目标检测系统')
parser.add_argument('--model', type=str,
default='../Japan/training_results/continue_from_best_20250610_130607/weights/best.pt',
help='YOLO模型文件路径')
parser.add_argument('--confidence', type=float, default=0.5,
help='置信度阈值 (默认: 0.5)')
parser.add_argument('--input', type=str, default='input_videos',
help='输入视频文件夹 (默认: input_videos)')
parser.add_argument('--output', type=str, default='output_frames',
help='输出图片文件夹 (默认: output_frames)')
args = parser.parse_args()
# 创建检测系统
detector = VideoDetectionSystem(
model_path=args.model,
confidence_threshold=args.confidence,
input_folder=args.input,
output_folder=args.output
)
# 处理所有视频
detector.process_all_videos()
if __name__ == '__main__':
main()