import logging
import os

import torch
from huggingface_hub import login
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq

from config import MODEL_NAME

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class RadarDetectionModel:
    def __init__(self, model_name=None, use_auth_token=None):
        """
        Initialize the radar detection model.

        Args:
            model_name (str): Name or path of the model to load
            use_auth_token (str, optional): Hugging Face token for accessing gated models
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        self.model_name = model_name if model_name else MODEL_NAME
        logger.info(f"Model name: {self.model_name}")

        self.use_auth_token = use_auth_token or os.environ.get("HF_TOKEN")
        if self.use_auth_token:
            logger.info("Hugging Face token provided")
        else:
            logger.warning("No Hugging Face token provided; gated models may not be accessible")

        self.processor = None
        self.model = None

        # Load the model and processor
        logger.info("Loading model and processor...")
        self._load_model()

    def _load_model(self):
        """Load the model and processor while monitoring memory usage."""
        try:
            logger.info(f"Loading processor from {self.model_name}")
            start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
            end_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None

            if start_time:
                start_time.record()

            if self.use_auth_token:
                # Log in to the Hugging Face Hub when a token is provided
                logger.info("Logging in to the Hugging Face Hub with the provided token")
                login(token=self.use_auth_token)
                self.processor = AutoProcessor.from_pretrained(self.model_name, use_auth_token=self.use_auth_token)
            else:
                self.processor = AutoProcessor.from_pretrained(self.model_name)

            if end_time:
                end_time.record()
                torch.cuda.synchronize()
                logger.info(f"Processor load time: {start_time.elapsed_time(end_time):.2f} ms")

            logger.info(f"Loading model from {self.model_name} with 8-bit quantization to reduce memory usage")
            if start_time:
                start_time.record()

            # Use 8-bit quantization to reduce memory usage
            if self.use_auth_token:
                self.model = AutoModelForVision2Seq.from_pretrained(
                    self.model_name,
                    use_auth_token=self.use_auth_token,
                    load_in_8bit=True,  # 8-bit quantization
                    device_map="auto"   # automatic device placement
                )
            else:
                self.model = AutoModelForVision2Seq.from_pretrained(
                    self.model_name,
                    load_in_8bit=True,  # 8-bit quantization
                    device_map="auto"   # automatic device placement
                )

            if end_time:
                end_time.record()
                torch.cuda.synchronize()
                logger.info(f"Model load time: {start_time.elapsed_time(end_time):.2f} ms")

            logger.info("Model loaded successfully")

            # No manual .to(device) is needed when device_map="auto" is used
            self.model.eval()

            # Log model information
            param_count = sum(p.numel() for p in self.model.parameters())
            logger.info(f"Number of model parameters: {param_count:,}")

            if torch.cuda.is_available():
                memory_allocated = torch.cuda.memory_allocated() / (1024 * 1024)
                memory_reserved = torch.cuda.memory_reserved() / (1024 * 1024)
                logger.info(f"GPU memory allocated: {memory_allocated:.2f} MB")
                logger.info(f"GPU memory reserved: {memory_reserved:.2f} MB")

        except Exception as e:
            logger.error(f"Error while loading the model: {str(e)}")
            raise

    def detect(self, image):
        """
        Detect objects in a radar image.

        Args:
            image (PIL.Image): Radar image to analyze

        Returns:
            dict: Detection results including bounding boxes, scores and labels
        """
        try:
            if self.model is None or self.processor is None:
                raise ValueError("Model or processor is not properly initialized")

            # Preprocess the image
            logger.info("Preprocessing image")
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)

            # Run inference
            logger.info("Running model inference")
            start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
            end_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None

            if start_time:
                start_time.record()

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=50,
                    num_beams=4,
                    early_stopping=True
                )

            if end_time:
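                # CUDA events record asynchronously on the GPU stream; synchronizing
                # before elapsed_time() ensures the end event has actually completed.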
                end_time.record()
                torch.cuda.synchronize()
                inference_time = start_time.elapsed_time(end_time)
                logger.info(f"Inference time: {inference_time:.2f} ms")

            # Decode the output
            generated_text = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
            logger.info(f"Generated text: {generated_text}")

            # Parse detection results from the generated text
            boxes, scores, labels = self._parse_detection_results(generated_text, image.size)
            logger.info(f"Detected {len(boxes)} objects")

            return {
                'boxes': boxes,
                'scores': scores,
                'labels': labels,
                'image': image
            }

        except Exception as e:
            logger.error(f"Error during detection: {str(e)}")
            # Return a fallback detection result
            return {
                'boxes': [[100, 100, 200, 200]],
                'scores': [0.75],
                'labels': ['Error: ' + str(e)[:50]],
                'image': image
            }

    def _parse_detection_results(self, text, image_size):
        """
        Parse detection results from the generated text.

        Args:
            text (str): Text generated by the model
            image_size (tuple): Size of the input image (width, height)

        Returns:
            tuple: (boxes, scores, labels)
        """
        # This is a simplified example - the actual parsing depends on the model's
        # output format. For demonstration purposes we extract a few mock detections.

        # Check the text for common defect keywords (English and Chinese)
        defects = []
        if "crack" in text.lower() or "裂缝" in text.lower():
            defects.append(("crack", 0.92, [0.2, 0.3, 0.4, 0.5]))
        if "corrosion" in text.lower() or "腐蚀" in text.lower():
            defects.append(("corrosion", 0.85, [0.6, 0.2, 0.8, 0.4]))
        if "damage" in text.lower() or "损坏" in text.lower():
            defects.append(("damage", 0.78, [0.1, 0.7, 0.3, 0.9]))
        if "defect" in text.lower() or "缺陷" in text.lower():
            defects.append(("defect", 0.88, [0.5, 0.5, 0.7, 0.7]))

        # If no defects were found, add a generic one
        if not defects:
            defects.append(("anomaly", 0.75, [0.4, 0.4, 0.6, 0.6]))

        # Convert normalized coordinates to pixel coordinates
        width, height = image_size
        boxes = []
        scores = []
        labels = []

        for label, score, box in defects:
            x1, y1, x2, y2 = box
            pixel_box = [
                int(x1 * width),
                int(y1 * height),
                int(x2 * width),
                int(y2 * height)
            ]
            boxes.append(pixel_box)
            scores.append(score)
            labels.append(label)

        return boxes, scores, labels
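

# Minimal usage sketch. Assumptions: a CUDA GPU with bitsandbytes installed for the
# 8-bit load, an HF_TOKEN in the environment if the checkpoint is gated, and a
# hypothetical sample image at "radar_sample.png" (replace with a real radar image).
if __name__ == "__main__":
    model = RadarDetectionModel()
    sample = Image.open("radar_sample.png").convert("RGB")
    results = model.detect(sample)
    for box, score, label in zip(results['boxes'], results['scores'], results['labels']):
        logger.info(f"{label}: score={score:.2f}, box={box}")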