200 lines
6.2 KiB
Python
200 lines
6.2 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
视频检测系统配置文件
|
||
"""
|
||
|
||
import os
|
||
from pathlib import Path
|
||
from ultralytics import YOLO
|
||
|
||
class DetectionConfig:
|
||
"""检测系统配置类"""
|
||
|
||
# 模型配置
|
||
MODEL_PATH = "models/best.pt"
|
||
CONFIDENCE_THRESHOLD = 0.5
|
||
|
||
# 动态模型信息
|
||
_model_classes = None
|
||
_model_colors = None
|
||
|
||
# 文件夹配置
|
||
INPUT_FOLDER = "input_videos"
|
||
OUTPUT_FOLDER = "output_frames"
|
||
LOG_FOLDER = "logs"
|
||
|
||
# 视频处理配置
|
||
SUPPORTED_VIDEO_FORMATS = {'.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm'}
|
||
PROGRESS_INTERVAL = 100 # 每处理多少帧显示一次进度
|
||
MIN_FRAME_INTERVAL = 1.0 # 检测帧的最小时间间隔(秒)
|
||
|
||
# 输出配置
|
||
SAVE_ORIGINAL_FRAMES = False # 是否保存原始帧
|
||
SAVE_ANNOTATED_FRAMES = True # 是否保存标注帧
|
||
SAVE_DETECTION_JSON = True # 是否保存检测结果JSON
|
||
|
||
# 图像质量配置
|
||
IMAGE_QUALITY = 95 # JPEG质量 (1-100)
|
||
|
||
# 默认道路损伤类别映射 (中文标签) - 作为备用
|
||
DEFAULT_CLASS_NAMES_CN = {
|
||
0: "纵向裂缝",
|
||
1: "横向裂缝",
|
||
2: "网状裂缝",
|
||
3: "坑洞",
|
||
4: "白线模糊"
|
||
}
|
||
|
||
# 通用类别关键词映射(扩展版)
|
||
COMMON_CLASS_MAPPINGS = {
|
||
# 道路损伤相关
|
||
'crack': '裂缝',
|
||
'pothole': '坑洞',
|
||
'damage': '损伤',
|
||
'longitudinal': '纵向裂缝',
|
||
'transverse': '横向裂缝',
|
||
'alligator': '网状裂缝',
|
||
'rutting': '车辙',
|
||
'bleeding': '泛油',
|
||
'weathering': '风化',
|
||
'patching': '修补',
|
||
|
||
# 交通工具
|
||
'person': '人',
|
||
'people': '人群',
|
||
'car': '汽车',
|
||
'truck': '卡车',
|
||
'bus': '公交车',
|
||
'motorcycle': '摩托车',
|
||
'bicycle': '自行车',
|
||
'bike': '自行车',
|
||
'vehicle': '车辆',
|
||
|
||
# 交通设施
|
||
'road': '道路',
|
||
'traffic': '交通',
|
||
'sign': '标志',
|
||
'light': '信号灯',
|
||
'signal': '信号',
|
||
'stop': '停车标志',
|
||
'yield': '让行标志',
|
||
'crosswalk': '人行横道',
|
||
'lane': '车道',
|
||
'marking': '标线',
|
||
'white': '白线',
|
||
'yellow': '黄线',
|
||
|
||
# 其他常见目标
|
||
'animal': '动物',
|
||
'dog': '狗',
|
||
'cat': '猫',
|
||
'bird': '鸟',
|
||
'tree': '树',
|
||
'building': '建筑',
|
||
'pole': '杆子',
|
||
'fence': '围栏'
|
||
}
|
||
|
||
# 动态颜色生成配置
|
||
COLOR_PALETTE = [
|
||
(0, 255, 0), # 绿色
|
||
(255, 0, 0), # 红色
|
||
(0, 0, 255), # 蓝色
|
||
(255, 255, 0), # 黄色
|
||
(255, 0, 255), # 紫色
|
||
(0, 255, 255), # 青色
|
||
(255, 128, 0), # 橙色
|
||
(128, 0, 255), # 紫罗兰
|
||
(255, 192, 203), # 粉色
|
||
(128, 128, 128) # 灰色
|
||
]
|
||
|
||
# 默认类别颜色配置 (BGR格式)
|
||
DEFAULT_CLASS_COLORS = {
|
||
0: (0, 255, 0), # 绿色
|
||
1: (255, 0, 0), # 红色
|
||
2: (0, 0, 255), # 蓝色
|
||
3: (255, 255, 0), # 黄色
|
||
4: (255, 0, 255) # 紫色
|
||
}
|
||
|
||
@classmethod
|
||
def get_model_path(cls):
|
||
"""获取模型文件的绝对路径"""
|
||
current_dir = Path(__file__).parent
|
||
model_path = current_dir / cls.MODEL_PATH
|
||
return str(model_path.resolve())
|
||
|
||
@classmethod
|
||
def create_folders(cls):
|
||
"""创建必要的文件夹"""
|
||
folders = [cls.INPUT_FOLDER, cls.OUTPUT_FOLDER, cls.LOG_FOLDER]
|
||
for folder in folders:
|
||
Path(folder).mkdir(parents=True, exist_ok=True)
|
||
|
||
@classmethod
|
||
def load_model_classes(cls, model_path):
|
||
"""从模型文件动态加载类别信息"""
|
||
try:
|
||
model = YOLO(model_path)
|
||
cls._model_classes = model.names
|
||
|
||
# 动态生成类别颜色
|
||
cls._model_colors = {}
|
||
for i in range(len(model.names)):
|
||
# 使用调色板循环分配颜色
|
||
color_index = i % len(cls.COLOR_PALETTE)
|
||
cls._model_colors[i] = cls.COLOR_PALETTE[color_index]
|
||
|
||
print(f"成功加载 {len(model.names)} 个类别,分配了对应颜色")
|
||
return True
|
||
except Exception as e:
|
||
print(f"加载模型类别信息失败: {e}")
|
||
return False
|
||
|
||
@classmethod
|
||
def get_class_name_cn(cls, class_id):
|
||
"""获取类别的中文名称"""
|
||
# 直接使用模型的原始类别名称,不进行任何映射
|
||
if cls._model_classes is not None:
|
||
original_name = cls._model_classes.get(class_id)
|
||
if original_name:
|
||
return original_name
|
||
return f"未知类别_{class_id}"
|
||
# 回退到默认中文类别
|
||
return cls.DEFAULT_CLASS_NAMES_CN.get(class_id, f"未知类别_{class_id}")
|
||
|
||
@classmethod
|
||
def get_class_color(cls, class_id):
|
||
"""获取类别对应的颜色"""
|
||
# 如果class_id是字符串,尝试转换为整数或使用hash
|
||
if isinstance(class_id, str):
|
||
# 尝试从类别名称获取ID
|
||
if cls._model_classes and class_id in cls._model_classes.values():
|
||
# 找到对应的ID
|
||
for id_val, name in cls._model_classes.items():
|
||
if name == class_id:
|
||
class_id = id_val
|
||
break
|
||
else:
|
||
# 使用hash生成一个稳定的索引
|
||
class_id = hash(class_id) % len(cls.COLOR_PALETTE)
|
||
|
||
# 确保class_id是整数
|
||
if not isinstance(class_id, int):
|
||
class_id = 0
|
||
|
||
if cls._model_colors is not None:
|
||
return cls._model_colors.get(class_id, cls.COLOR_PALETTE[class_id % len(cls.COLOR_PALETTE)])
|
||
return cls.DEFAULT_CLASS_COLORS.get(class_id, cls.COLOR_PALETTE[class_id % len(cls.COLOR_PALETTE)])
|
||
|
||
@classmethod
|
||
def get_all_classes(cls):
|
||
"""获取所有类别信息"""
|
||
if cls._model_classes is not None:
|
||
return cls._model_classes
|
||
return cls.DEFAULT_CLASS_NAMES_CN
|
||
|
||
# 全局配置实例
|
||
config = DetectionConfig() |