453 lines
18 KiB
Python
453 lines
18 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
视频目标检测系统
|
||
使用YOLO模型对视频进行道路损伤检测,并抽取包含目标的帧
|
||
"""
|
||
|
||
import os
|
||
import cv2
|
||
import numpy as np
|
||
from pathlib import Path
|
||
from ultralytics import YOLO
|
||
import argparse
|
||
from datetime import datetime
|
||
import json
|
||
import logging
|
||
import torch
|
||
from PIL import Image, ImageDraw, ImageFont
|
||
from config import config, DetectionConfig
|
||
|
||
class VideoDetectionSystem:
    """Run a YOLO model over every video in a folder and save annotated frames.

    For each video, every frame is passed through the model. Frames with at
    least one detection are annotated and written once per detected class
    into ``<output_folder>/<video stem>/<Chinese class name>/``. A per-video
    JSON file and a run-wide summary JSON are written as well.
    """

    def __init__(self, model_path, confidence_threshold=0.5, input_folder="input_videos", output_folder="output_frames"):
        """Initialise the detection system and load the model.

        Args:
            model_path (str): Path to the YOLO weights file.
            confidence_threshold (float): Base confidence threshold. It may be
                raised by ``analyze_model_type`` once the model is loaded.
            input_folder (str): Folder that is searched recursively for videos.
            output_folder (str): Folder that receives frames and reports.

        Raises:
            Exception: Whatever ``load_model`` re-raises when loading fails.
        """
        self.model_path = model_path
        self.confidence_threshold = confidence_threshold
        self.input_folder = Path(input_folder)
        self.output_folder = Path(output_folder)

        # Make sure the output folder exists before anything is written.
        self.output_folder.mkdir(parents=True, exist_ok=True)

        # Logging must be ready before the model is loaded (load_model logs).
        self.setup_logging()

        # Supported video file extensions (compared in lower case).
        self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm'}

        # Dynamic detection parameters.
        self.adaptive_confidence = True  # retry with a lower threshold on empty results
        self.min_confidence = 0.2        # lowest threshold a retry may use
        self.max_confidence = 0.8        # upper confidence bound
        self.scene_adaptation = True     # scene adaptation switch

        # Run-wide detection statistics; serialised by save_summary_report.
        self.detection_stats = {
            'total_videos': 0,
            'total_frames': 0,
            'detected_frames': 0,
            'saved_frames': 0,
            'detection_results': []
        }

        # Additional counters.
        self.stats = {
            'adaptive_detections': 0
        }

        # Detection quality threshold.
        self.detection_quality_threshold = 0.3

        # Label font is loaded lazily on the first draw and then reused.
        self._label_font = None

        # Load the model LAST: analyze_model_type() tunes confidence_threshold
        # and min_confidence, and the defaults above must not overwrite them.
        self.load_model()

    def setup_logging(self):
        """Configure logging to a timestamped file under ``logs/`` and to the console."""
        log_folder = Path('logs')
        log_folder.mkdir(exist_ok=True)

        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = log_folder / f'detection_{timestamp}.log'

        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(log_file, encoding='utf-8'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def load_model(self):
        """Load the YOLO model and derive model-specific detection parameters.

        Raises:
            Exception: Any error raised while loading or analysing the model
                is logged and re-raised unchanged.
        """
        try:
            self.model = YOLO(self.model_path)
            self.logger.info(f"成功加载模型: {self.model_path}")

            # Push the model's class information into the shared configuration.
            DetectionConfig.load_model_classes(self.model_path)

            # Log what the model can detect.
            if hasattr(self.model, 'names'):
                self.logger.info(f"模型类别数量: {len(self.model.names)}")
                self.logger.info(f"检测类别: {list(self.model.names.values())}")
                self.logger.info(f"动态加载的类别信息: {DetectionConfig.get_all_classes()}")

            # Tune thresholds according to the kind of model.
            self.analyze_model_type()

        except Exception as e:
            self.logger.error(f"加载模型失败: {e}")
            raise

    def analyze_model_type(self):
        """Adjust confidence thresholds depending on the model's class names.

        A model whose class names contain a road-damage keyword gets a floor
        of 0.25 on ``confidence_threshold`` and ``min_confidence`` 0.15.
        Any other model gets a floor of 0.4 and ``min_confidence`` 0.3.
        """
        joined_names = ' '.join(self.model.names.values()).lower()

        # Keywords that identify a dedicated road-damage model.
        road_damage_keywords = ['crack', 'pothole', 'damage', '裂缝', '坑洞', '损伤']
        is_road_damage_model = any(keyword in joined_names for keyword in road_damage_keywords)

        if is_road_damage_model:
            self.logger.info("检测到道路损伤专用模型,使用优化参数")
            self.confidence_threshold = max(0.25, self.confidence_threshold)
            self.min_confidence = 0.15
        else:
            self.logger.info("检测到通用模型,使用标准参数")
            self.confidence_threshold = max(0.4, self.confidence_threshold)
            self.min_confidence = 0.3

        self.logger.info(f"调整后的置信度阈值: {self.confidence_threshold}")

    def get_video_files(self):
        """Return all video files below the input folder, sorted by path.

        Returns:
            list[Path]: Files whose lower-cased suffix is in
            ``video_extensions``; empty if the input folder does not exist.
        """
        if not self.input_folder.exists():
            self.logger.warning(f"输入文件夹不存在: {self.input_folder}")
            return []

        # Sorted so that processing order is reproducible between runs.
        video_files = sorted(
            file_path
            for file_path in self.input_folder.rglob('*')
            if file_path.is_file() and file_path.suffix.lower() in self.video_extensions
        )

        self.logger.info(f"找到 {len(video_files)} 个视频文件")
        return video_files

    @staticmethod
    def _is_empty(results):
        """Return True when *results* holds no detection boxes at all."""
        return not results or results[0].boxes is None or len(results[0].boxes) == 0

    def detect_frame(self, frame, adaptive_conf=None):
        """Run detection on a single frame.

        Args:
            frame: Image passed straight to the model.
            adaptive_conf (float | None): Threshold to use instead of
                ``confidence_threshold`` when given.

        Returns:
            The first result object returned by the model, or ``None`` when
            the model returned nothing or raised.
        """
        try:
            conf_threshold = adaptive_conf if adaptive_conf is not None else self.confidence_threshold

            # Test-time augmentation is enabled to improve the detection rate.
            results = self.model(frame, conf=conf_threshold, verbose=False,
                                 imgsz=640, augment=True)

            # Nothing found: retry once with the threshold lowered by 0.1,
            # but never below min_confidence.
            if self.adaptive_confidence and self._is_empty(results):
                if conf_threshold > self.min_confidence:
                    lower_conf = max(self.min_confidence, conf_threshold - 0.1)
                    self.logger.debug(f"降低置信度重试: {lower_conf}")
                    results = self.model(frame, conf=lower_conf, verbose=False,
                                         imgsz=640, augment=True)

            return results[0] if results else None
        except Exception as e:
            self.logger.error(f"检测失败: {e}")
            return None

    def adaptive_detect_frame(self, frame, base_confidence):
        """Try progressively lower thresholds until something is detected.

        The thresholds tried are 100 %, 80 % and 60 % of *base_confidence*;
        values below ``min_confidence`` are skipped.

        Args:
            frame: Image passed straight to the model.
            base_confidence (float): Starting confidence threshold.

        Returns:
            The model's result list for the first threshold that yields at
            least one box with confidence >= ``min_confidence``, otherwise an
            empty list (also on error).
        """
        try:
            for conf in (base_confidence, base_confidence * 0.8, base_confidence * 0.6):
                if conf < self.min_confidence:
                    continue

                results = self.model(frame, conf=conf, verbose=False)
                if self._is_empty(results):
                    continue

                # Accept the result as soon as one box clears the minimum.
                if any(float(score) >= self.min_confidence for score in results[0].boxes.conf):
                    return results

            return []

        except Exception as e:
            # Report through the logger like every other method of the class.
            self.logger.error(f"自适应检测失败: {e}")
            return []

    def _get_label_font(self):
        """Return the label font, loading it on first use and caching it."""
        cached = getattr(self, '_label_font', None)
        if cached is not None:
            return cached

        # Windows fonts that contain CJK glyphs, in order of preference.
        candidates = (
            "C:/Windows/Fonts/simhei.ttf",   # SimHei
            "C:/Windows/Fonts/msyh.ttf",     # Microsoft YaHei
            "C:/Windows/Fonts/simsun.ttc",   # SimSun
        )
        font = None
        for font_path in candidates:
            if os.path.exists(font_path):
                try:
                    font = ImageFont.truetype(font_path, 20)
                    break
                except OSError:
                    # Unreadable or corrupt font file: try the next candidate.
                    continue
        if font is None:
            font = ImageFont.load_default()

        self._label_font = font
        return font

    def draw_detections(self, frame, result):
        """Draw boxes and Chinese class labels for *result* onto a copy of *frame*.

        Drawing is done with PIL so that a TrueType font with Chinese glyphs
        can be used for the labels.

        Args:
            frame: OpenCV image (converted with ``COLOR_BGR2RGB`` for PIL).
            result: Detection result exposing ``boxes``; may be ``None``.

        Returns:
            The annotated image in OpenCV colour order, or *frame* itself when
            there is nothing to draw.
        """
        if result is None or result.boxes is None:
            return frame

        pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_image)
        font = self._get_label_font()

        for box in result.boxes:
            # Box corners as plain Python ints.
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int).tolist()

            confidence = float(box.conf[0])
            class_id = int(box.cls[0])

            # Only the Chinese class name is shown in the label.
            display_name = config.get_class_name_cn(class_id)

            # Scale colour intensity with confidence; +0.3 keeps it from
            # becoming too dark.
            base_color = config.get_class_color(class_id)
            intensity = min(1.0, confidence + 0.3)
            color = tuple(int(c * intensity) for c in base_color)

            # The channel order is reversed for PIL (the original comment
            # says the configured colour is BGR - presumably; verify in config).
            bg_color = tuple(reversed(color))
            line_width = max(2, int(confidence * 4))  # thicker line = more confident
            draw.rectangle([x1, y1, x2, y2], outline=bg_color, width=line_width)

            label = f"{display_name}: {confidence:.3f}"

            text_box = draw.textbbox((0, 0), label, font=font)
            text_width = text_box[2] - text_box[0]
            text_height = text_box[3] - text_box[1]

            # Put the label above the box; if that would leave the image,
            # put it just inside the top edge of the box instead.
            label_top = y1 - text_height - 10
            if label_top < 0:
                label_top = y1

            draw.rectangle(
                [x1, label_top, x1 + text_width + 10, label_top + text_height + 10],
                fill=bg_color,
            )
            draw.text((x1 + 5, label_top + 5), label, font=font, fill=(255, 255, 255))

        # Back to OpenCV colour order.
        return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

    def _class_name_en(self, class_id):
        """Return the model's own name for *class_id*, or the id as a string."""
        return self.model.names[class_id] if hasattr(self.model, 'names') else str(class_id)

    def _write_jpeg(self, path, image):
        """Encode *image* as JPEG and write it to *path*; return True on success.

        The image is encoded in memory and written with Python file I/O
        because ``cv2.imwrite`` cannot open non-ASCII paths on Windows, and
        the class folders and file names here contain Chinese class names.
        """
        try:
            encoded_ok, encoded = cv2.imencode('.jpg', image)
            if not encoded_ok:
                self.logger.warning(f"保存图片失败: {path}")
                return False
            Path(path).write_bytes(encoded.tobytes())
            return True
        except (OSError, cv2.error) as e:
            self.logger.warning(f"保存图片失败: {path} ({e})")
            return False

    def process_video(self, video_path):
        """Detect objects in every frame of one video and save the hits.

        Args:
            video_path (Path): Video file to process.

        Side effects:
            Writes annotated JPEGs and ``<stem>_detections.json`` below
            ``output_folder / video_path.stem`` and updates
            ``detection_stats``. Returns early (after logging an error) when
            the video cannot be opened.
        """
        self.logger.info(f"开始处理视频: {video_path.name}")

        # Per-video output folder; class sub-folders are created on demand.
        video_output_folder = self.output_folder / video_path.stem
        video_output_folder.mkdir(parents=True, exist_ok=True)

        cap = cv2.VideoCapture(str(video_path))
        if not cap.isOpened():
            self.logger.error(f"无法打开视频文件: {video_path}")
            cap.release()
            return

        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = total_frames / fps if fps > 0 else 0

        self.logger.info(f"视频信息 - 帧率: {fps:.2f}, 总帧数: {total_frames}, 时长: {duration:.2f}秒")

        frame_count = 0
        detected_count = 0
        saved_count = 0

        video_detections = {
            'video_name': video_path.name,
            'total_frames': total_frames,
            'fps': fps,
            'duration': duration,
            'detections': []
        }

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                frame_count += 1

                # Progress report every 100 frames. The reported frame count
                # can be 0, so the percentage is only computed when it is
                # positive.
                if frame_count % 100 == 0:
                    if total_frames > 0:
                        progress = (frame_count / total_frames) * 100
                        self.logger.info(f"处理进度: {progress:.1f}% ({frame_count}/{total_frames})")
                    else:
                        self.logger.info(f"处理进度: 已处理 {frame_count} 帧 (总帧数未知)")

                result = self.detect_frame(frame)
                if result is None or result.boxes is None or len(result.boxes) == 0:
                    continue

                detected_count += 1
                # Frames are numbered from 1, so the first frame maps to 1/fps.
                timestamp = frame_count / fps if fps > 0 else frame_count

                # One record per box, built in a single pass over the boxes.
                box_records = []
                for box in result.boxes:
                    class_id = int(box.cls[0])
                    box_records.append({
                        'class_name': self._class_name_en(class_id),
                        'class_name_cn': config.get_class_name_cn(class_id),
                        'class_id': class_id,
                        'confidence': float(box.conf[0]),
                        'bbox': box.xyxy[0].cpu().numpy().tolist()
                    })

                annotated_frame = self.draw_detections(frame, result)

                # Save the annotated frame once per class present in it.
                for class_id in {record['class_id'] for record in box_records}:
                    class_name_cn = config.get_class_name_cn(class_id)

                    class_folder = video_output_folder / class_name_cn
                    class_folder.mkdir(exist_ok=True)

                    annotated_filename = f"detected_{frame_count:06d}_t{timestamp:.2f}s_{class_name_cn}.jpg"
                    # Count only images that were really written.
                    if self._write_jpeg(class_folder / annotated_filename, annotated_frame):
                        saved_count += 1

                video_detections['detections'].append({
                    'frame_number': frame_count,
                    'timestamp': timestamp,
                    'detections': box_records
                })

        finally:
            cap.release()

        # Per-video detection report.
        json_path = video_output_folder / f"{video_path.stem}_detections.json"
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(video_detections, f, ensure_ascii=False, indent=2)

        self.logger.info(f"视频处理完成: {video_path.name}")
        self.logger.info(f"总帧数: {frame_count}, 检测到目标的帧数: {detected_count}, 保存图片数: {saved_count}")

        # Fold this video into the run-wide statistics.
        self.detection_stats['total_videos'] += 1
        self.detection_stats['total_frames'] += frame_count
        self.detection_stats['detected_frames'] += detected_count
        self.detection_stats['saved_frames'] += saved_count
        self.detection_stats['detection_results'].append(video_detections)

    def process_all_videos(self):
        """Process every video found in the input folder, then write the summary."""
        video_files = self.get_video_files()

        if not video_files:
            self.logger.warning("没有找到视频文件")
            return

        self.logger.info(f"开始处理 {len(video_files)} 个视频文件")

        for i, video_path in enumerate(video_files, 1):
            self.logger.info(f"\n=== 处理第 {i}/{len(video_files)} 个视频 ===")
            self.process_video(video_path)

        self.save_summary_report()

        self.logger.info("\n=== 处理完成 ===")
        self.logger.info(f"总计处理视频: {self.detection_stats['total_videos']} 个")
        self.logger.info(f"总计处理帧数: {self.detection_stats['total_frames']} 帧")
        self.logger.info(f"检测到目标的帧数: {self.detection_stats['detected_frames']} 帧")
        self.logger.info(f"保存图片数量: {self.detection_stats['saved_frames']} 张")

    def save_summary_report(self):
        """Write the run configuration and ``detection_stats`` to a timestamped JSON file."""
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        report_path = self.output_folder / f"detection_summary_{timestamp}.json"

        summary = {
            'timestamp': timestamp,
            'model_path': str(self.model_path),
            'confidence_threshold': self.confidence_threshold,
            'input_folder': str(self.input_folder),
            'output_folder': str(self.output_folder),
            'statistics': self.detection_stats
        }

        with open(report_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)

        self.logger.info(f"总结报告已保存: {report_path}")
|
||
|
||
def main():
    """Parse the command line and run detection over all input videos.

    Options: ``--model`` (YOLO weights path), ``--confidence`` (confidence
    threshold), ``--input`` (video folder) and ``--output`` (frame folder).
    """
    cli = argparse.ArgumentParser(description='视频目标检测系统')

    # (flag, value type, default, help text) for every supported option.
    option_specs = (
        ('--model', str,
         '../Japan/training_results/continue_from_best_20250610_130607/weights/best.pt',
         'YOLO模型文件路径'),
        ('--confidence', float, 0.5, '置信度阈值 (默认: 0.5)'),
        ('--input', str, 'input_videos', '输入视频文件夹 (默认: input_videos)'),
        ('--output', str, 'output_frames', '输出图片文件夹 (默认: output_frames)'),
    )
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    options = cli.parse_args()

    # Build the detection system from the parsed options and process everything.
    VideoDetectionSystem(
        model_path=options.model,
        confidence_threshold=options.confidence,
        input_folder=options.input,
        output_folder=options.output,
    ).process_all_videos()
|
||
|
||
# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()