wrz-yolo/config.py
2025-06-27 17:13:03 +08:00

200 lines
6.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
视频检测系统配置文件
"""
import os
from pathlib import Path
from ultralytics import YOLO
class DetectionConfig:
"""检测系统配置类"""
# 模型配置
MODEL_PATH = "models/best.pt"
CONFIDENCE_THRESHOLD = 0.5
# 动态模型信息
_model_classes = None
_model_colors = None
# 文件夹配置
INPUT_FOLDER = "input_videos"
OUTPUT_FOLDER = "output_frames"
LOG_FOLDER = "logs"
# 视频处理配置
SUPPORTED_VIDEO_FORMATS = {'.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm'}
PROGRESS_INTERVAL = 100 # 每处理多少帧显示一次进度
MIN_FRAME_INTERVAL = 1.0 # 检测帧的最小时间间隔(秒)
# 输出配置
SAVE_ORIGINAL_FRAMES = False # 是否保存原始帧
SAVE_ANNOTATED_FRAMES = True # 是否保存标注帧
SAVE_DETECTION_JSON = True # 是否保存检测结果JSON
# 图像质量配置
IMAGE_QUALITY = 95 # JPEG质量 (1-100)
# 默认道路损伤类别映射 (中文标签) - 作为备用
DEFAULT_CLASS_NAMES_CN = {
0: "纵向裂缝",
1: "横向裂缝",
2: "网状裂缝",
3: "坑洞",
4: "白线模糊"
}
# 通用类别关键词映射(扩展版)
COMMON_CLASS_MAPPINGS = {
# 道路损伤相关
'crack': '裂缝',
'pothole': '坑洞',
'damage': '损伤',
'longitudinal': '纵向裂缝',
'transverse': '横向裂缝',
'alligator': '网状裂缝',
'rutting': '车辙',
'bleeding': '泛油',
'weathering': '风化',
'patching': '修补',
# 交通工具
'person': '',
'people': '人群',
'car': '汽车',
'truck': '卡车',
'bus': '公交车',
'motorcycle': '摩托车',
'bicycle': '自行车',
'bike': '自行车',
'vehicle': '车辆',
# 交通设施
'road': '道路',
'traffic': '交通',
'sign': '标志',
'light': '信号灯',
'signal': '信号',
'stop': '停车标志',
'yield': '让行标志',
'crosswalk': '人行横道',
'lane': '车道',
'marking': '标线',
'white': '白线',
'yellow': '黄线',
# 其他常见目标
'animal': '动物',
'dog': '',
'cat': '',
'bird': '',
'tree': '',
'building': '建筑',
'pole': '杆子',
'fence': '围栏'
}
# 动态颜色生成配置
COLOR_PALETTE = [
(0, 255, 0), # 绿色
(255, 0, 0), # 红色
(0, 0, 255), # 蓝色
(255, 255, 0), # 黄色
(255, 0, 255), # 紫色
(0, 255, 255), # 青色
(255, 128, 0), # 橙色
(128, 0, 255), # 紫罗兰
(255, 192, 203), # 粉色
(128, 128, 128) # 灰色
]
# 默认类别颜色配置 (BGR格式)
DEFAULT_CLASS_COLORS = {
0: (0, 255, 0), # 绿色
1: (255, 0, 0), # 红色
2: (0, 0, 255), # 蓝色
3: (255, 255, 0), # 黄色
4: (255, 0, 255) # 紫色
}
@classmethod
def get_model_path(cls):
"""获取模型文件的绝对路径"""
current_dir = Path(__file__).parent
model_path = current_dir / cls.MODEL_PATH
return str(model_path.resolve())
@classmethod
def create_folders(cls):
"""创建必要的文件夹"""
folders = [cls.INPUT_FOLDER, cls.OUTPUT_FOLDER, cls.LOG_FOLDER]
for folder in folders:
Path(folder).mkdir(parents=True, exist_ok=True)
@classmethod
def load_model_classes(cls, model_path):
"""从模型文件动态加载类别信息"""
try:
model = YOLO(model_path)
cls._model_classes = model.names
# 动态生成类别颜色
cls._model_colors = {}
for i in range(len(model.names)):
# 使用调色板循环分配颜色
color_index = i % len(cls.COLOR_PALETTE)
cls._model_colors[i] = cls.COLOR_PALETTE[color_index]
print(f"成功加载 {len(model.names)} 个类别,分配了对应颜色")
return True
except Exception as e:
print(f"加载模型类别信息失败: {e}")
return False
@classmethod
def get_class_name_cn(cls, class_id):
"""获取类别的中文名称"""
# 直接使用模型的原始类别名称,不进行任何映射
if cls._model_classes is not None:
original_name = cls._model_classes.get(class_id)
if original_name:
return original_name
return f"未知类别_{class_id}"
# 回退到默认中文类别
return cls.DEFAULT_CLASS_NAMES_CN.get(class_id, f"未知类别_{class_id}")
@classmethod
def get_class_color(cls, class_id):
"""获取类别对应的颜色"""
# 如果class_id是字符串尝试转换为整数或使用hash
if isinstance(class_id, str):
# 尝试从类别名称获取ID
if cls._model_classes and class_id in cls._model_classes.values():
# 找到对应的ID
for id_val, name in cls._model_classes.items():
if name == class_id:
class_id = id_val
break
else:
# 使用hash生成一个稳定的索引
class_id = hash(class_id) % len(cls.COLOR_PALETTE)
# 确保class_id是整数
if not isinstance(class_id, int):
class_id = 0
if cls._model_colors is not None:
return cls._model_colors.get(class_id, cls.COLOR_PALETTE[class_id % len(cls.COLOR_PALETTE)])
return cls.DEFAULT_CLASS_COLORS.get(class_id, cls.COLOR_PALETTE[class_id % len(cls.COLOR_PALETTE)])
@classmethod
def get_all_classes(cls):
"""获取所有类别信息"""
if cls._model_classes is not None:
return cls._model_classes
return cls.DEFAULT_CLASS_NAMES_CN
# 全局配置实例
config = DetectionConfig()