wrz-yolo/config.py

200 lines
6.2 KiB
Python
Raw Normal View History

2025-06-27 17:13:03 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
视频检测系统配置文件
"""
import os
from pathlib import Path
from ultralytics import YOLO
class DetectionConfig:
"""检测系统配置类"""
# 模型配置
MODEL_PATH = "models/best.pt"
CONFIDENCE_THRESHOLD = 0.5
# 动态模型信息
_model_classes = None
_model_colors = None
# 文件夹配置
INPUT_FOLDER = "input_videos"
OUTPUT_FOLDER = "output_frames"
LOG_FOLDER = "logs"
# 视频处理配置
SUPPORTED_VIDEO_FORMATS = {'.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm'}
PROGRESS_INTERVAL = 100 # 每处理多少帧显示一次进度
MIN_FRAME_INTERVAL = 1.0 # 检测帧的最小时间间隔(秒)
# 输出配置
SAVE_ORIGINAL_FRAMES = False # 是否保存原始帧
SAVE_ANNOTATED_FRAMES = True # 是否保存标注帧
SAVE_DETECTION_JSON = True # 是否保存检测结果JSON
# 图像质量配置
IMAGE_QUALITY = 95 # JPEG质量 (1-100)
# 默认道路损伤类别映射 (中文标签) - 作为备用
DEFAULT_CLASS_NAMES_CN = {
0: "纵向裂缝",
1: "横向裂缝",
2: "网状裂缝",
3: "坑洞",
4: "白线模糊"
}
# 通用类别关键词映射(扩展版)
COMMON_CLASS_MAPPINGS = {
# 道路损伤相关
'crack': '裂缝',
'pothole': '坑洞',
'damage': '损伤',
'longitudinal': '纵向裂缝',
'transverse': '横向裂缝',
'alligator': '网状裂缝',
'rutting': '车辙',
'bleeding': '泛油',
'weathering': '风化',
'patching': '修补',
# 交通工具
'person': '',
'people': '人群',
'car': '汽车',
'truck': '卡车',
'bus': '公交车',
'motorcycle': '摩托车',
'bicycle': '自行车',
'bike': '自行车',
'vehicle': '车辆',
# 交通设施
'road': '道路',
'traffic': '交通',
'sign': '标志',
'light': '信号灯',
'signal': '信号',
'stop': '停车标志',
'yield': '让行标志',
'crosswalk': '人行横道',
'lane': '车道',
'marking': '标线',
'white': '白线',
'yellow': '黄线',
# 其他常见目标
'animal': '动物',
'dog': '',
'cat': '',
'bird': '',
'tree': '',
'building': '建筑',
'pole': '杆子',
'fence': '围栏'
}
# 动态颜色生成配置
COLOR_PALETTE = [
(0, 255, 0), # 绿色
(255, 0, 0), # 红色
(0, 0, 255), # 蓝色
(255, 255, 0), # 黄色
(255, 0, 255), # 紫色
(0, 255, 255), # 青色
(255, 128, 0), # 橙色
(128, 0, 255), # 紫罗兰
(255, 192, 203), # 粉色
(128, 128, 128) # 灰色
]
# 默认类别颜色配置 (BGR格式)
DEFAULT_CLASS_COLORS = {
0: (0, 255, 0), # 绿色
1: (255, 0, 0), # 红色
2: (0, 0, 255), # 蓝色
3: (255, 255, 0), # 黄色
4: (255, 0, 255) # 紫色
}
@classmethod
def get_model_path(cls):
"""获取模型文件的绝对路径"""
current_dir = Path(__file__).parent
model_path = current_dir / cls.MODEL_PATH
return str(model_path.resolve())
@classmethod
def create_folders(cls):
"""创建必要的文件夹"""
folders = [cls.INPUT_FOLDER, cls.OUTPUT_FOLDER, cls.LOG_FOLDER]
for folder in folders:
Path(folder).mkdir(parents=True, exist_ok=True)
@classmethod
def load_model_classes(cls, model_path):
"""从模型文件动态加载类别信息"""
try:
model = YOLO(model_path)
cls._model_classes = model.names
# 动态生成类别颜色
cls._model_colors = {}
for i in range(len(model.names)):
# 使用调色板循环分配颜色
color_index = i % len(cls.COLOR_PALETTE)
cls._model_colors[i] = cls.COLOR_PALETTE[color_index]
print(f"成功加载 {len(model.names)} 个类别,分配了对应颜色")
return True
except Exception as e:
print(f"加载模型类别信息失败: {e}")
return False
@classmethod
def get_class_name_cn(cls, class_id):
"""获取类别的中文名称"""
# 直接使用模型的原始类别名称,不进行任何映射
if cls._model_classes is not None:
original_name = cls._model_classes.get(class_id)
if original_name:
return original_name
return f"未知类别_{class_id}"
# 回退到默认中文类别
return cls.DEFAULT_CLASS_NAMES_CN.get(class_id, f"未知类别_{class_id}")
@classmethod
def get_class_color(cls, class_id):
"""获取类别对应的颜色"""
# 如果class_id是字符串尝试转换为整数或使用hash
if isinstance(class_id, str):
# 尝试从类别名称获取ID
if cls._model_classes and class_id in cls._model_classes.values():
# 找到对应的ID
for id_val, name in cls._model_classes.items():
if name == class_id:
class_id = id_val
break
else:
# 使用hash生成一个稳定的索引
class_id = hash(class_id) % len(cls.COLOR_PALETTE)
# 确保class_id是整数
if not isinstance(class_id, int):
class_id = 0
if cls._model_colors is not None:
return cls._model_colors.get(class_id, cls.COLOR_PALETTE[class_id % len(cls.COLOR_PALETTE)])
return cls.DEFAULT_CLASS_COLORS.get(class_id, cls.COLOR_PALETTE[class_id % len(cls.COLOR_PALETTE)])
@classmethod
def get_all_classes(cls):
"""获取所有类别信息"""
if cls._model_classes is not None:
return cls._model_classes
return cls.DEFAULT_CLASS_NAMES_CN
# 全局配置实例
config = DetectionConfig()