删除 detector.py
This commit is contained in:
parent
e1d71d67c7
commit
a22521f738
356
detector.py
356
detector.py
@ -1,356 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
视频目标检测系统
|
||||
使用YOLO模型对视频进行道路损伤检测,并抽取包含目标的帧
|
||||
"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from ultralytics import YOLO
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from config import config, DetectionConfig
|
||||
|
||||
class VideoDetectionSystem:
|
||||
def __init__(self, model_path, confidence_threshold=0.5, input_folder="input_videos", output_folder="output_frames"):
|
||||
"""
|
||||
初始化视频检测系统
|
||||
|
||||
Args:
|
||||
model_path (str): YOLO模型文件路径
|
||||
confidence_threshold (float): 置信度阈值
|
||||
input_folder (str): 输入视频文件夹
|
||||
output_folder (str): 输出图片文件夹
|
||||
"""
|
||||
self.model_path = model_path
|
||||
self.confidence_threshold = confidence_threshold
|
||||
self.input_folder = Path(input_folder)
|
||||
self.output_folder = Path(output_folder)
|
||||
|
||||
# 创建输出文件夹
|
||||
self.output_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 设置日志
|
||||
self.setup_logging()
|
||||
|
||||
# 加载模型
|
||||
self.load_model()
|
||||
|
||||
# 支持的视频格式
|
||||
self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm'}
|
||||
|
||||
# 检测结果统计
|
||||
self.detection_stats = {
|
||||
'total_videos': 0,
|
||||
'total_frames': 0,
|
||||
'detected_frames': 0,
|
||||
'saved_frames': 0,
|
||||
'detection_results': []
|
||||
}
|
||||
|
||||
def setup_logging(self):
|
||||
"""设置日志记录"""
|
||||
log_folder = Path('logs')
|
||||
log_folder.mkdir(exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
log_file = log_folder / f'detection_{timestamp}.log'
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler(log_file, encoding='utf-8'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def load_model(self):
|
||||
"""加载YOLO模型"""
|
||||
try:
|
||||
self.model = YOLO(self.model_path)
|
||||
self.logger.info(f"成功加载模型: {self.model_path}")
|
||||
|
||||
# 动态加载模型类别信息到配置中
|
||||
DetectionConfig.load_model_classes(self.model_path)
|
||||
|
||||
# 打印模型信息
|
||||
if hasattr(self.model, 'names'):
|
||||
self.logger.info(f"模型类别数量: {len(self.model.names)}")
|
||||
self.logger.info(f"检测类别: {list(self.model.names.values())}")
|
||||
self.logger.info(f"动态加载的类别信息: {DetectionConfig.get_all_classes()}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"加载模型失败: {e}")
|
||||
raise
|
||||
|
||||
def get_video_files(self):
|
||||
"""获取输入文件夹中的所有视频文件"""
|
||||
video_files = []
|
||||
|
||||
if not self.input_folder.exists():
|
||||
self.logger.warning(f"输入文件夹不存在: {self.input_folder}")
|
||||
return video_files
|
||||
|
||||
for file_path in self.input_folder.rglob('*'):
|
||||
if file_path.is_file() and file_path.suffix.lower() in self.video_extensions:
|
||||
video_files.append(file_path)
|
||||
|
||||
self.logger.info(f"找到 {len(video_files)} 个视频文件")
|
||||
return video_files
|
||||
|
||||
def detect_frame(self, frame):
|
||||
"""对单帧进行目标检测"""
|
||||
try:
|
||||
results = self.model(frame, conf=self.confidence_threshold, verbose=False)
|
||||
return results[0] if results else None
|
||||
except Exception as e:
|
||||
self.logger.error(f"检测失败: {e}")
|
||||
return None
|
||||
|
||||
def draw_detections(self, frame, result):
|
||||
"""在帧上绘制检测结果"""
|
||||
if result is None or result.boxes is None:
|
||||
return frame
|
||||
|
||||
# 转换为PIL图像以支持中文字体
|
||||
pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
||||
draw = ImageDraw.Draw(pil_image)
|
||||
|
||||
# 尝试加载中文字体
|
||||
try:
|
||||
# Windows系统字体路径
|
||||
font_path = "C:/Windows/Fonts/simhei.ttf" # 黑体
|
||||
if not os.path.exists(font_path):
|
||||
font_path = "C:/Windows/Fonts/msyh.ttf" # 微软雅黑
|
||||
if not os.path.exists(font_path):
|
||||
font_path = "C:/Windows/Fonts/simsun.ttc" # 宋体
|
||||
|
||||
if os.path.exists(font_path):
|
||||
font = ImageFont.truetype(font_path, 20)
|
||||
else:
|
||||
font = ImageFont.load_default()
|
||||
except:
|
||||
font = ImageFont.load_default()
|
||||
|
||||
for box in result.boxes:
|
||||
# 获取边界框坐标
|
||||
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
|
||||
|
||||
# 获取置信度和类别
|
||||
confidence = float(box.conf[0].cpu().numpy())
|
||||
class_id = int(box.cls[0].cpu().numpy())
|
||||
|
||||
# 获取中文类别名称和颜色
|
||||
class_name_cn = config.get_class_name_cn(class_id)
|
||||
color = config.get_class_color(class_id)
|
||||
|
||||
# 绘制边界框(使用PIL)
|
||||
bg_color = tuple(reversed(color)) # BGR转RGB
|
||||
draw.rectangle([x1, y1, x2, y2], outline=bg_color, width=2)
|
||||
|
||||
# 准备标签文字
|
||||
label = f"{class_name_cn}: {confidence:.2f}"
|
||||
|
||||
# 获取文字尺寸
|
||||
bbox = draw.textbbox((0, 0), label, font=font)
|
||||
text_width = bbox[2] - bbox[0]
|
||||
text_height = bbox[3] - bbox[1]
|
||||
|
||||
# 绘制标签背景(使用PIL)
|
||||
draw.rectangle([x1, y1 - text_height - 10, x1 + text_width + 10, y1], fill=bg_color)
|
||||
|
||||
# 绘制文字(使用PIL)
|
||||
draw.text((x1 + 5, y1 - text_height - 5), label, font=font, fill=(255, 255, 255))
|
||||
|
||||
# 转换回OpenCV格式
|
||||
annotated_frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
|
||||
return annotated_frame
|
||||
|
||||
def process_video(self, video_path):
|
||||
"""处理单个视频文件"""
|
||||
self.logger.info(f"开始处理视频: {video_path.name}")
|
||||
|
||||
# 创建视频专用输出文件夹
|
||||
video_output_folder = self.output_folder / video_path.stem
|
||||
video_output_folder.mkdir(exist_ok=True)
|
||||
|
||||
# 打开视频
|
||||
cap = cv2.VideoCapture(str(video_path))
|
||||
if not cap.isOpened():
|
||||
self.logger.error(f"无法打开视频文件: {video_path}")
|
||||
return
|
||||
|
||||
# 获取视频信息
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
duration = total_frames / fps if fps > 0 else 0
|
||||
|
||||
self.logger.info(f"视频信息 - 帧率: {fps:.2f}, 总帧数: {total_frames}, 时长: {duration:.2f}秒")
|
||||
|
||||
frame_count = 0
|
||||
detected_count = 0
|
||||
saved_count = 0
|
||||
|
||||
video_detections = {
|
||||
'video_name': video_path.name,
|
||||
'total_frames': total_frames,
|
||||
'fps': fps,
|
||||
'duration': duration,
|
||||
'detections': []
|
||||
}
|
||||
|
||||
try:
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
frame_count += 1
|
||||
|
||||
# 每100帧显示一次进度
|
||||
if frame_count % 100 == 0:
|
||||
progress = (frame_count / total_frames) * 100
|
||||
self.logger.info(f"处理进度: {progress:.1f}% ({frame_count}/{total_frames})")
|
||||
|
||||
# 进行目标检测
|
||||
result = self.detect_frame(frame)
|
||||
|
||||
# 如果检测到目标,保存帧
|
||||
if result is not None and result.boxes is not None and len(result.boxes) > 0:
|
||||
detected_count += 1
|
||||
|
||||
# 计算时间戳
|
||||
timestamp = frame_count / fps if fps > 0 else frame_count
|
||||
|
||||
# 绘制检测结果
|
||||
annotated_frame = self.draw_detections(frame, result)
|
||||
|
||||
# 保存原始帧和标注帧
|
||||
frame_filename = f"frame_{frame_count:06d}_t{timestamp:.2f}s.jpg"
|
||||
annotated_filename = f"annotated_{frame_count:06d}_t{timestamp:.2f}s.jpg"
|
||||
|
||||
original_path = video_output_folder / frame_filename
|
||||
annotated_path = video_output_folder / annotated_filename
|
||||
|
||||
cv2.imwrite(str(original_path), frame)
|
||||
cv2.imwrite(str(annotated_path), annotated_frame)
|
||||
|
||||
saved_count += 2
|
||||
|
||||
# 记录检测信息
|
||||
detection_info = {
|
||||
'frame_number': frame_count,
|
||||
'timestamp': timestamp,
|
||||
'detections': []
|
||||
}
|
||||
|
||||
for box in result.boxes:
|
||||
confidence = float(box.conf[0].cpu().numpy())
|
||||
class_id = int(box.cls[0].cpu().numpy())
|
||||
class_name = self.model.names[class_id]
|
||||
bbox = box.xyxy[0].cpu().numpy().tolist()
|
||||
|
||||
detection_info['detections'].append({
|
||||
'class_name': class_name,
|
||||
'class_id': class_id,
|
||||
'confidence': confidence,
|
||||
'bbox': bbox
|
||||
})
|
||||
|
||||
video_detections['detections'].append(detection_info)
|
||||
|
||||
finally:
|
||||
cap.release()
|
||||
|
||||
# 保存检测结果到JSON文件
|
||||
json_path = video_output_folder / f"{video_path.stem}_detections.json"
|
||||
with open(json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(video_detections, f, ensure_ascii=False, indent=2)
|
||||
|
||||
self.logger.info(f"视频处理完成: {video_path.name}")
|
||||
self.logger.info(f"总帧数: {frame_count}, 检测到目标的帧数: {detected_count}, 保存图片数: {saved_count}")
|
||||
|
||||
# 更新统计信息
|
||||
self.detection_stats['total_videos'] += 1
|
||||
self.detection_stats['total_frames'] += frame_count
|
||||
self.detection_stats['detected_frames'] += detected_count
|
||||
self.detection_stats['saved_frames'] += saved_count
|
||||
self.detection_stats['detection_results'].append(video_detections)
|
||||
|
||||
def process_all_videos(self):
|
||||
"""处理所有视频文件"""
|
||||
video_files = self.get_video_files()
|
||||
|
||||
if not video_files:
|
||||
self.logger.warning("没有找到视频文件")
|
||||
return
|
||||
|
||||
self.logger.info(f"开始处理 {len(video_files)} 个视频文件")
|
||||
|
||||
for i, video_path in enumerate(video_files, 1):
|
||||
self.logger.info(f"\n=== 处理第 {i}/{len(video_files)} 个视频 ===")
|
||||
self.process_video(video_path)
|
||||
|
||||
# 保存总体统计信息
|
||||
self.save_summary_report()
|
||||
|
||||
self.logger.info("\n=== 处理完成 ===")
|
||||
self.logger.info(f"总计处理视频: {self.detection_stats['total_videos']} 个")
|
||||
self.logger.info(f"总计处理帧数: {self.detection_stats['total_frames']} 帧")
|
||||
self.logger.info(f"检测到目标的帧数: {self.detection_stats['detected_frames']} 帧")
|
||||
self.logger.info(f"保存图片数量: {self.detection_stats['saved_frames']} 张")
|
||||
|
||||
def save_summary_report(self):
|
||||
"""保存总结报告"""
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
report_path = self.output_folder / f"detection_summary_{timestamp}.json"
|
||||
|
||||
summary = {
|
||||
'timestamp': timestamp,
|
||||
'model_path': str(self.model_path),
|
||||
'confidence_threshold': self.confidence_threshold,
|
||||
'input_folder': str(self.input_folder),
|
||||
'output_folder': str(self.output_folder),
|
||||
'statistics': self.detection_stats
|
||||
}
|
||||
|
||||
with open(report_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(summary, f, ensure_ascii=False, indent=2)
|
||||
|
||||
self.logger.info(f"总结报告已保存: {report_path}")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='视频目标检测系统')
|
||||
parser.add_argument('--model', type=str,
|
||||
default='../Japan/training_results/continue_from_best_20250610_130607/weights/best.pt',
|
||||
help='YOLO模型文件路径')
|
||||
parser.add_argument('--confidence', type=float, default=0.5,
|
||||
help='置信度阈值 (默认: 0.5)')
|
||||
parser.add_argument('--input', type=str, default='input_videos',
|
||||
help='输入视频文件夹 (默认: input_videos)')
|
||||
parser.add_argument('--output', type=str, default='output_frames',
|
||||
help='输出图片文件夹 (默认: output_frames)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 创建检测系统
|
||||
detector = VideoDetectionSystem(
|
||||
model_path=args.model,
|
||||
confidence_threshold=args.confidence,
|
||||
input_folder=args.input,
|
||||
output_folder=args.output
|
||||
)
|
||||
|
||||
# 处理所有视频
|
||||
detector.process_all_videos()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Loading…
Reference in New Issue
Block a user