356 lines
14 KiB
Python
356 lines
14 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
"""
Video object detection system.

Runs a YOLO model over videos to detect road damage and extracts the frames
that contain detected targets.
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import cv2
|
|||
|
|
import numpy as np
|
|||
|
|
from pathlib import Path
|
|||
|
|
from ultralytics import YOLO
|
|||
|
|
import argparse
|
|||
|
|
from datetime import datetime
|
|||
|
|
import json
|
|||
|
|
import logging
|
|||
|
|
from PIL import Image, ImageDraw, ImageFont
|
|||
|
|
from config import config, DetectionConfig
|
|||
|
|
|
|||
|
|
class VideoDetectionSystem:
    """Run a YOLO model over every video in a folder and keep the frames with detections.

    Each frame of each video is passed through the model. A frame with at
    least one detection is written to disk twice (raw and annotated), a JSON
    file with the detections is written per video, and a JSON summary of the
    whole run is written to the output folder.
    """

    # Fonts able to render the Chinese class labels, tried in this order.
    _FONT_CANDIDATES = (
        "C:/Windows/Fonts/simhei.ttf",  # SimHei
        "C:/Windows/Fonts/msyh.ttf",    # Microsoft YaHei
        "C:/Windows/Fonts/simsun.ttc",  # SimSun
    )
    _LABEL_FONT_SIZE = 20
    # Space left between the label text and the edge of its background box.
    _LABEL_PADDING = 5
    # A progress line is logged every this many frames.
    _PROGRESS_INTERVAL = 100

    def __init__(self, model_path, confidence_threshold=0.5, input_folder="input_videos", output_folder="output_frames"):
        """Prepare folders, logging, the model and the statistics counters.

        Args:
            model_path (str): Path of the YOLO model file.
            confidence_threshold (float): Confidence threshold passed to the model.
            input_folder (str): Folder that is searched for input videos.
            output_folder (str): Folder that receives images and reports;
                created (with parents) if it does not exist.

        Raises:
            Exception: Whatever ``load_model`` re-raises when loading fails.
        """
        self.model_path = model_path
        self.confidence_threshold = confidence_threshold
        self.input_folder = Path(input_folder)
        self.output_folder = Path(output_folder)

        # Create the output folder.
        self.output_folder.mkdir(parents=True, exist_ok=True)

        # Logging has to be ready first: load_model() logs its outcome.
        self.setup_logging()
        self.load_model()

        # Supported video formats (compared against the lower-cased suffix).
        self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm'}

        # Label font, resolved lazily on the first annotated frame.
        self._label_font = None

        # Statistics accumulated over all processed videos.
        self.detection_stats = {
            'total_videos': 0,
            'total_frames': 0,
            'detected_frames': 0,
            'saved_frames': 0,
            'detection_results': []
        }

    def setup_logging(self):
        """Configure logging to a timestamped file under ``logs/`` and to the console.

        Note: ``logging.basicConfig`` does nothing when the root logger already
        has handlers, so in that case the file handler below is not installed.
        """
        log_folder = Path('logs')
        log_folder.mkdir(exist_ok=True)

        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = log_folder / f'detection_{timestamp}.log'

        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(log_file, encoding='utf-8'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def load_model(self):
        """Load the YOLO model and push its class information into the config.

        Raises:
            Exception: Any error raised while loading is logged and re-raised.
        """
        try:
            self.model = YOLO(self.model_path)
            self.logger.info(f"成功加载模型: {self.model_path}")

            # Load the model's class information into the configuration.
            DetectionConfig.load_model_classes(self.model_path)

            # Log what the model can detect.
            if hasattr(self.model, 'names'):
                self.logger.info(f"模型类别数量: {len(self.model.names)}")
                self.logger.info(f"检测类别: {list(self.model.names.values())}")
                self.logger.info(f"动态加载的类别信息: {DetectionConfig.get_all_classes()}")

        except Exception as e:
            self.logger.error(f"加载模型失败: {e}")
            raise

    def get_video_files(self):
        """Return all video files below the input folder, sorted by path.

        The folder is searched recursively and a file counts as a video when
        its lower-cased suffix is in ``self.video_extensions``.

        Returns:
            list[Path]: Sorted video paths; empty when the input folder does
            not exist or contains no videos.
        """
        if not self.input_folder.exists():
            self.logger.warning(f"输入文件夹不存在: {self.input_folder}")
            return []

        # Sorting makes the processing order reproducible; rglob() yields
        # entries in whatever order the filesystem returns them.
        video_files = sorted(
            file_path
            for file_path in self.input_folder.rglob('*')
            if file_path.is_file() and file_path.suffix.lower() in self.video_extensions
        )

        self.logger.info(f"找到 {len(video_files)} 个视频文件")
        return video_files

    def detect_frame(self, frame):
        """Run the model on one frame.

        Returns:
            The first result returned by the model, or ``None`` when the model
            returns nothing or raises (the error is logged, not propagated).
        """
        try:
            results = self.model(frame, conf=self.confidence_threshold, verbose=False)
            return results[0] if results else None
        except Exception as e:
            self.logger.error(f"检测失败: {e}")
            return None

    def _get_label_font(self):
        """Return the font used for labels, loading it once and caching it."""
        cached = getattr(self, '_label_font', None)
        if cached is not None:
            return cached

        font = None
        for font_path in self._FONT_CANDIDATES:
            if not os.path.exists(font_path):
                continue
            try:
                font = ImageFont.truetype(font_path, self._LABEL_FONT_SIZE)
                break
            except OSError as e:
                # An unreadable font file: report it and try the next one.
                self.logger.warning(f"加载字体失败: {font_path} ({e})")

        if font is None:
            self.logger.warning("未找到可用的中文字体,使用默认字体")
            font = ImageFont.load_default()

        self._label_font = font
        return font

    @staticmethod
    def _label_top(box_top, text_height):
        """Return the y coordinate of the top edge of a label's background.

        The label normally sits directly above the box. When that would place
        it above the image (negative y) it is moved to start at the box's top
        edge instead, so it stays visible.
        """
        above = box_top - text_height - 2 * VideoDetectionSystem._LABEL_PADDING
        if above >= 0:
            return above
        return max(box_top, 0)

    def draw_detections(self, frame, result):
        """Draw bounding boxes and labels for ``result`` on a copy of ``frame``.

        Args:
            frame: Image in OpenCV BGR layout.
            result: Detection result for the frame, or ``None``.

        Returns:
            The annotated image in BGR layout, or ``frame`` itself when there
            is no result or the result has no boxes.
        """
        if result is None or result.boxes is None:
            return frame

        # Draw through PIL so the Chinese labels can use a TrueType font.
        pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_image)
        font = self._get_label_font()
        padding = self._LABEL_PADDING

        for box in result.boxes:
            # Bounding box corners, truncated to integer pixels.
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int).tolist()

            # Confidence and class of this detection.
            confidence = float(box.conf[0].cpu().numpy())
            class_id = int(box.cls[0].cpu().numpy())

            # Chinese class name and colour come from the project config.
            class_name_cn = config.get_class_name_cn(class_id)
            color = config.get_class_color(class_id)

            # The colour is reversed (BGR to RGB) because drawing happens in PIL.
            bg_color = tuple(reversed(color))
            draw.rectangle([x1, y1, x2, y2], outline=bg_color, width=2)

            label = f"{class_name_cn}: {confidence:.2f}"

            # Measure the label so the background fits it.
            bbox = draw.textbbox((0, 0), label, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]

            label_top = self._label_top(y1, text_height)
            draw.rectangle(
                [x1, label_top, x1 + text_width + 2 * padding, label_top + text_height + 2 * padding],
                fill=bg_color,
            )
            draw.text((x1 + padding, label_top + padding), label, font=font, fill=(255, 255, 255))

        # Back to the OpenCV layout.
        annotated_frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
        return annotated_frame

    def _log_progress(self, frame_count, total_frames):
        """Log how far processing of the current video has got.

        ``total_frames`` comes from the video backend and may be 0 (or less)
        when the frame count is unknown; no percentage is computed then.
        """
        if total_frames > 0:
            progress = (frame_count / total_frames) * 100
            self.logger.info(f"处理进度: {progress:.1f}% ({frame_count}/{total_frames})")
        else:
            self.logger.info(f"处理进度: 已处理 {frame_count} 帧 (总帧数未知)")

    def _save_image(self, path, image):
        """Write ``image`` to ``path``; return True on success, log and return False otherwise."""
        if cv2.imwrite(str(path), image):
            return True
        self.logger.error(f"保存图片失败: {path}")
        return False

    def _collect_detections(self, result):
        """Convert the boxes of ``result`` into JSON-serialisable dictionaries."""
        detections = []
        for box in result.boxes:
            class_id = int(box.cls[0].cpu().numpy())
            detections.append({
                'class_name': self.model.names[class_id],
                'class_id': class_id,
                'confidence': float(box.conf[0].cpu().numpy()),
                'bbox': box.xyxy[0].cpu().numpy().tolist()
            })
        return detections

    def process_video(self, video_path):
        """Process one video: detect on every frame and save the frames with detections.

        For each frame with detections the raw frame and an annotated copy are
        written to ``<output_folder>/<video stem>/``. A JSON file with all
        detections of the video is written to the same folder and the run
        statistics are updated. A video that cannot be opened is logged and
        skipped.

        Args:
            video_path (Path): Path of the video file.
        """
        self.logger.info(f"开始处理视频: {video_path.name}")

        cap = cv2.VideoCapture(str(video_path))
        try:
            if not cap.isOpened():
                self.logger.error(f"无法打开视频文件: {video_path}")
                return

            # Only create the per-video folder once the video is known to open.
            video_output_folder = self.output_folder / video_path.stem
            video_output_folder.mkdir(parents=True, exist_ok=True)

            # Video metadata as reported by the backend.
            fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            duration = total_frames / fps if fps > 0 else 0

            self.logger.info(f"视频信息 - 帧率: {fps:.2f}, 总帧数: {total_frames}, 时长: {duration:.2f}秒")

            frame_count = 0
            detected_count = 0
            saved_count = 0

            video_detections = {
                'video_name': video_path.name,
                'total_frames': total_frames,
                'fps': fps,
                'duration': duration,
                'detections': []
            }

            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                frame_count += 1

                if frame_count % self._PROGRESS_INTERVAL == 0:
                    self._log_progress(frame_count, total_frames)

                result = self.detect_frame(frame)

                # Frames without any detection are skipped.
                if result is None or result.boxes is None or len(result.boxes) == 0:
                    continue

                detected_count += 1

                # Seconds into the video; falls back to the frame number
                # when the frame rate is unknown.
                timestamp = frame_count / fps if fps > 0 else frame_count

                annotated_frame = self.draw_detections(frame, result)

                # Save both the raw and the annotated frame.
                frame_filename = f"frame_{frame_count:06d}_t{timestamp:.2f}s.jpg"
                annotated_filename = f"annotated_{frame_count:06d}_t{timestamp:.2f}s.jpg"

                # Count only the images that were actually written.
                if self._save_image(video_output_folder / frame_filename, frame):
                    saved_count += 1
                if self._save_image(video_output_folder / annotated_filename, annotated_frame):
                    saved_count += 1

                video_detections['detections'].append({
                    'frame_number': frame_count,
                    'timestamp': timestamp,
                    'detections': self._collect_detections(result)
                })

        finally:
            cap.release()

        # Save the detections of this video as JSON.
        json_path = video_output_folder / f"{video_path.stem}_detections.json"
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(video_detections, f, ensure_ascii=False, indent=2)

        self.logger.info(f"视频处理完成: {video_path.name}")
        self.logger.info(f"总帧数: {frame_count}, 检测到目标的帧数: {detected_count}, 保存图片数: {saved_count}")

        # Update the statistics of the whole run.
        self.detection_stats['total_videos'] += 1
        self.detection_stats['total_frames'] += frame_count
        self.detection_stats['detected_frames'] += detected_count
        self.detection_stats['saved_frames'] += saved_count
        self.detection_stats['detection_results'].append(video_detections)

    def process_all_videos(self):
        """Process every video found in the input folder, then write the summary report."""
        video_files = self.get_video_files()

        if not video_files:
            self.logger.warning("没有找到视频文件")
            return

        self.logger.info(f"开始处理 {len(video_files)} 个视频文件")

        for i, video_path in enumerate(video_files, 1):
            self.logger.info(f"\n=== 处理第 {i}/{len(video_files)} 个视频 ===")
            self.process_video(video_path)

        # Save the overall statistics.
        self.save_summary_report()

        self.logger.info("\n=== 处理完成 ===")
        self.logger.info(f"总计处理视频: {self.detection_stats['total_videos']} 个")
        self.logger.info(f"总计处理帧数: {self.detection_stats['total_frames']} 帧")
        self.logger.info(f"检测到目标的帧数: {self.detection_stats['detected_frames']} 帧")
        self.logger.info(f"保存图片数量: {self.detection_stats['saved_frames']} 张")

    def save_summary_report(self):
        """Write the run settings and statistics to a timestamped JSON file in the output folder."""
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        report_path = self.output_folder / f"detection_summary_{timestamp}.json"

        summary = {
            'timestamp': timestamp,
            'model_path': str(self.model_path),
            'confidence_threshold': self.confidence_threshold,
            'input_folder': str(self.input_folder),
            'output_folder': str(self.output_folder),
            'statistics': self.detection_stats
        }

        with open(report_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)

        self.logger.info(f"总结报告已保存: {report_path}")
|
|||
|
|
|
|||
|
|
def main():
    """Command-line entry point: parse the options and run detection on all input videos."""
    cli = argparse.ArgumentParser(description='视频目标检测系统')

    # One row per option: (flag, value type, default, help text).
    option_table = (
        ('--model', str,
         '../Japan/training_results/continue_from_best_20250610_130607/weights/best.pt',
         'YOLO模型文件路径'),
        ('--confidence', float, 0.5, '置信度阈值 (默认: 0.5)'),
        ('--input', str, 'input_videos', '输入视频文件夹 (默认: input_videos)'),
        ('--output', str, 'output_frames', '输出图片文件夹 (默认: output_frames)'),
    )
    for flag, value_type, default_value, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    options = cli.parse_args()

    # Build the detection system from the parsed options and process every video.
    VideoDetectionSystem(
        model_path=options.model,
        confidence_threshold=options.confidence,
        input_folder=options.input,
        output_folder=options.output,
    ).process_all_videos()
|
|||
|
|
|
|||
|
|
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
|