# NOTE: the three lines that were here ("Spaces:" / "Sleeping" / "Sleeping")
# were non-code residue captured from the Hugging Face Spaces page header.
import os | |
from transformers import AutoFeatureExtractor, AutoModelForObjectDetection | |
import torch | |
from huggingface_hub import login | |
import logging | |
from transformers import AutoProcessor, AutoModelForVision2Seq | |
from PIL import Image | |
import numpy as np | |
from config import MODEL_NAME | |
# 配置日志记录 | |
logging.basicConfig(level=logging.INFO, | |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
logger = logging.getLogger(__name__) | |
class RadarDetectionModel:
    """Wrapper around a Hugging Face vision-to-sequence model that detects
    defects (cracks, corrosion, damage, ...) in radar images."""

    def __init__(self, model_name=None, use_auth_token=None):
        """
        Initialize the radar detection model.

        Args:
            model_name (str, optional): Model name or path to load; falls
                back to ``config.MODEL_NAME`` when ``None``.
            use_auth_token (str, optional): Hugging Face token for gated
                models; falls back to the ``HF_TOKEN`` environment variable.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"使用设备: {self.device}")

        self.model_name = model_name if model_name else MODEL_NAME
        logger.info(f"模型名称: {self.model_name}")

        # Prefer an explicit token; otherwise read it from the environment.
        self.use_auth_token = use_auth_token or os.environ.get("HF_TOKEN")
        if self.use_auth_token:
            logger.info("已提供Hugging Face令牌")
        else:
            logger.warning("未提供Hugging Face令牌,可能无法访问受限模型")

        self.processor = None
        self.model = None

        # Load eagerly so failures surface at construction time.
        logger.info("开始加载模型和处理器...")
        self._load_model()

    def _load_model(self):
        """Load the processor and the 8-bit quantized model, logging timing
        and GPU memory usage.

        Raises:
            Exception: re-raises any error from the Hugging Face loaders
                after logging it.
        """
        try:
            logger.info(f"正在从{self.model_name}加载处理器")

            # CUDA events give millisecond timing on GPU; on CPU the events
            # are None and all timing is skipped.
            start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
            end_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
            if start_time:
                start_time.record()

            if self.use_auth_token:
                # Log in so subsequent hub calls are authenticated.
                logger.info("使用令牌登录到Hugging Face Hub")
                login(token=self.use_auth_token)
                # NOTE(review): `use_auth_token` is deprecated in recent
                # transformers releases in favor of `token` — confirm the
                # pinned version before switching.
                self.processor = AutoProcessor.from_pretrained(
                    self.model_name, use_auth_token=self.use_auth_token
                )
            else:
                self.processor = AutoProcessor.from_pretrained(self.model_name)

            if end_time:
                end_time.record()
                torch.cuda.synchronize()
                # elapsed_time is only valid after synchronize(); both calls
                # stay inside this guard so the CPU path never touches them.
                logger.info(f"处理器加载时间: {start_time.elapsed_time(end_time):.2f}毫秒")

            logger.info(f"正在从{self.model_name}加载模型,使用8位量化以减少内存使用")
            if start_time:
                start_time.record()

            # 8-bit quantization reduces memory use; device_map="auto" lets
            # accelerate place layers, so no manual .to(device) is needed.
            # One call instead of two duplicated branches.
            load_kwargs = {"load_in_8bit": True, "device_map": "auto"}
            if self.use_auth_token:
                load_kwargs["use_auth_token"] = self.use_auth_token
            self.model = AutoModelForVision2Seq.from_pretrained(self.model_name, **load_kwargs)

            if end_time:
                end_time.record()
                torch.cuda.synchronize()
                logger.info(f"模型加载时间: {start_time.elapsed_time(end_time):.2f}毫秒")

            logger.info("模型加载成功")
            self.model.eval()

            # Log model size and current GPU memory footprint.
            param_count = sum(p.numel() for p in self.model.parameters())
            logger.info(f"模型参数数量: {param_count:,}")
            if torch.cuda.is_available():
                memory_allocated = torch.cuda.memory_allocated() / (1024 * 1024)
                memory_reserved = torch.cuda.memory_reserved() / (1024 * 1024)
                logger.info(f"GPU内存分配: {memory_allocated:.2f}MB")
                logger.info(f"GPU内存保留: {memory_reserved:.2f}MB")
        except Exception as e:
            logger.error(f"加载模型时出错: {str(e)}")
            raise

    def detect(self, image):
        """
        Detect objects in a radar image.

        Args:
            image (PIL.Image): Radar image to analyze.

        Returns:
            dict: Detection results with keys ``'boxes'``, ``'scores'``,
                ``'labels'`` and ``'image'``. On failure, a single
                placeholder detection is returned with the error message
                embedded in its label, so callers always receive a
                well-formed result.
        """
        try:
            if self.model is None or self.processor is None:
                raise ValueError("模型或处理器未正确初始化")

            # Preprocess the image into model tensors.
            logger.info("预处理图像")
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)

            logger.info("运行模型推理")
            start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
            end_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
            if start_time:
                start_time.record()

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=50,
                    num_beams=4,
                    early_stopping=True
                )

            if end_time:
                end_time.record()
                torch.cuda.synchronize()
                inference_time = start_time.elapsed_time(end_time)
                logger.info(f"推理时间: {inference_time:.2f}毫秒")

            # Decode generated token ids into text.
            generated_text = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
            logger.info(f"生成的文本: {generated_text}")

            # Parse detection results out of the generated text.
            boxes, scores, labels = self._parse_detection_results(generated_text, image.size)
            logger.info(f"检测到{len(boxes)}个对象")

            return {
                'boxes': boxes,
                'scores': scores,
                'labels': labels,
                'image': image
            }
        except Exception as e:
            logger.error(f"检测过程中出错: {str(e)}")
            # Fallback detection so the caller never sees a malformed result.
            return {
                'boxes': [[100, 100, 200, 200]],
                'scores': [0.75],
                'labels': ['错误: ' + str(e)[:50]],
                'image': image
            }

    def _parse_detection_results(self, text, image_size):
        """
        Parse detection results out of model-generated text.

        This is a simplified keyword-based parser; real parsing would depend
        on the actual model output format.

        Args:
            text (str): Text generated by the model.
            image_size (tuple): Input image size as ``(width, height)``.

        Returns:
            tuple: ``(boxes, scores, labels)`` where each box is a
                pixel-coordinate ``[x1, y1, x2, y2]`` list.
        """
        # Lowercase once instead of once per keyword check.
        lowered = text.lower()

        # (label, score, normalized [x1, y1, x2, y2]) per recognized keyword.
        defects = []
        if "crack" in lowered or "裂缝" in lowered:
            defects.append(("裂缝", 0.92, [0.2, 0.3, 0.4, 0.5]))
        if "corrosion" in lowered or "腐蚀" in lowered:
            defects.append(("腐蚀", 0.85, [0.6, 0.2, 0.8, 0.4]))
        if "damage" in lowered or "损坏" in lowered:
            defects.append(("损坏", 0.78, [0.1, 0.7, 0.3, 0.9]))
        if "defect" in lowered or "缺陷" in lowered:
            defects.append(("缺陷", 0.88, [0.5, 0.5, 0.7, 0.7]))

        # Generic fallback when no keyword matched.
        if not defects:
            defects.append(("异常", 0.75, [0.4, 0.4, 0.6, 0.6]))

        # Convert normalized coordinates to pixel coordinates.
        width, height = image_size
        boxes = []
        scores = []
        labels = []
        for label, score, box in defects:
            x1, y1, x2, y2 = box
            boxes.append([
                int(x1 * width),
                int(y1 * height),
                int(x2 * width),
                int(y2 * height),
            ])
            scores.append(score)
            labels.append(label)
        return boxes, scores, labels