# radar-analysis/model.py
import os
from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
import torch
from huggingface_hub import login
import logging
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image
import numpy as np
from config import MODEL_NAME
# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class RadarDetectionModel:
    def __init__(self, model_name=None, use_auth_token=None):
        """
        Initialize the radar detection model.

        Args:
            model_name (str): Name or path of the model to load
            use_auth_token (str, optional): Hugging Face token for accessing gated models
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        self.model_name = model_name if model_name else MODEL_NAME
        logger.info(f"Model name: {self.model_name}")

        self.use_auth_token = use_auth_token or os.environ.get("HF_TOKEN")
        if self.use_auth_token:
            logger.info("Hugging Face token provided")
        else:
            logger.warning("No Hugging Face token provided; gated models may be inaccessible")

        self.processor = None
        self.model = None

        # Load the model and processor
        logger.info("Loading model and processor...")
        self._load_model()
    def _load_model(self):
        """Load the model and processor, and log memory usage."""
        try:
            logger.info(f"Loading processor from {self.model_name}")
            start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
            end_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None

            if start_time:
                start_time.record()

            if self.use_auth_token:
                # Log in to the Hugging Face Hub when a token is provided
                logger.info("Logging in to the Hugging Face Hub with the provided token")
                login(token=self.use_auth_token)
                self.processor = AutoProcessor.from_pretrained(self.model_name, use_auth_token=self.use_auth_token)
            else:
                self.processor = AutoProcessor.from_pretrained(self.model_name)

            if end_time:
                end_time.record()
                torch.cuda.synchronize()
                logger.info(f"Processor load time: {start_time.elapsed_time(end_time):.2f} ms")

            logger.info(f"Loading model from {self.model_name} (8-bit quantization is used on GPU to reduce memory usage)")
            if start_time:
                start_time.record()

            # Build the loading arguments once instead of duplicating the call per branch.
            # 8-bit quantization (via bitsandbytes) is only applied when CUDA is available,
            # since it is not supported on CPU-only machines.
            load_kwargs = {}
            if self.use_auth_token:
                load_kwargs["use_auth_token"] = self.use_auth_token
            if torch.cuda.is_available():
                load_kwargs["load_in_8bit"] = True   # 8-bit quantization to reduce memory usage
                load_kwargs["device_map"] = "auto"   # let accelerate handle device placement

            self.model = AutoModelForVision2Seq.from_pretrained(self.model_name, **load_kwargs)

            if end_time:
                end_time.record()
                torch.cuda.synchronize()
                logger.info(f"Model load time: {start_time.elapsed_time(end_time):.2f} ms")

            logger.info("Model loaded successfully")

            # With device_map="auto" (or a plain CPU load) there is no need to move the model manually
            self.model.eval()

            # Log model information
            param_count = sum(p.numel() for p in self.model.parameters())
            logger.info(f"Number of model parameters: {param_count:,}")

            if torch.cuda.is_available():
                memory_allocated = torch.cuda.memory_allocated() / (1024 * 1024)
                memory_reserved = torch.cuda.memory_reserved() / (1024 * 1024)
                logger.info(f"GPU memory allocated: {memory_allocated:.2f} MB")
                logger.info(f"GPU memory reserved: {memory_reserved:.2f} MB")
        except Exception as e:
            logger.error(f"Error while loading the model: {str(e)}")
            raise
    def detect(self, image):
        """
        Detect objects in a radar image.

        Args:
            image (PIL.Image): Radar image to analyze

        Returns:
            dict: Detection results including bounding boxes, scores, and labels
        """
        try:
            if self.model is None or self.processor is None:
                raise ValueError("Model or processor was not initialized correctly")

            # Preprocess the image
            logger.info("Preprocessing image")
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)

            # Run inference
            logger.info("Running model inference")
            start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
            end_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None

            if start_time:
                start_time.record()

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=50,
                    num_beams=4,
                    early_stopping=True
                )

            if end_time:
                end_time.record()
                torch.cuda.synchronize()
                inference_time = start_time.elapsed_time(end_time)
                logger.info(f"Inference time: {inference_time:.2f} ms")

            # Process the output
            generated_text = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
            logger.info(f"Generated text: {generated_text}")

            # Parse detection results from the generated text
            boxes, scores, labels = self._parse_detection_results(generated_text, image.size)
            logger.info(f"Detected {len(boxes)} objects")

            return {
                'boxes': boxes,
                'scores': scores,
                'labels': labels,
                'image': image
            }
        except Exception as e:
            logger.error(f"Error during detection: {str(e)}")
            # Return a fallback detection result
            return {
                'boxes': [[100, 100, 200, 200]],
                'scores': [0.75],
                'labels': ['Error: ' + str(e)[:50]],
                'image': image
            }
    def _parse_detection_results(self, text, image_size):
        """
        Parse detection results from the generated text.

        Args:
            text (str): Text generated by the model
            image_size (tuple): Size of the input image (width, height)

        Returns:
            tuple: (boxes, scores, labels)
        """
        # This is a simplified example - the actual parsing depends on the model's output format.
        # For demonstration purposes we extract mock detections by scanning the text
        # for common defect keywords (in English and Chinese).
        defects = []

        if "crack" in text.lower() or "裂缝" in text:
            defects.append(("crack", 0.92, [0.2, 0.3, 0.4, 0.5]))
        if "corrosion" in text.lower() or "腐蚀" in text:
            defects.append(("corrosion", 0.85, [0.6, 0.2, 0.8, 0.4]))
        if "damage" in text.lower() or "损坏" in text:
            defects.append(("damage", 0.78, [0.1, 0.7, 0.3, 0.9]))
        if "defect" in text.lower() or "缺陷" in text:
            defects.append(("defect", 0.88, [0.5, 0.5, 0.7, 0.7]))

        # If no defect keyword was found, fall back to a generic label
        if not defects:
            defects.append(("anomaly", 0.75, [0.4, 0.4, 0.6, 0.6]))

        # Convert normalized coordinates to pixel coordinates
        width, height = image_size
        boxes = []
        scores = []
        labels = []

        for label, score, box in defects:
            x1, y1, x2, y2 = box
            pixel_box = [
                int(x1 * width),
                int(y1 * height),
                int(x2 * width),
                int(y2 * height)
            ]
            boxes.append(pixel_box)
            scores.append(score)
            labels.append(label)

        return boxes, scores, labels
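

# --- Usage sketch (illustrative addition, not part of the original pipeline) ---
# A minimal example of how this class might be exercised end to end. It assumes
# MODEL_NAME in config.py points to a vision-to-text checkpoint the current
# environment can access, and it fabricates a blank dummy image with numpy/PIL
# purely to demonstrate the call chain; a real deployment would pass an actual
# radar image instead.
if __name__ == "__main__":
    # Build a dummy 512x512 RGB image standing in for a radar scan
    dummy_array = np.zeros((512, 512, 3), dtype=np.uint8)
    dummy_image = Image.fromarray(dummy_array)

    # Instantiate the model; HF_TOKEN is read from the environment if set
    model = RadarDetectionModel()

    # Run detection and print the parsed results
    results = model.detect(dummy_image)
    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        print(f"{label}: score={score:.2f}, box={box}")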