xingqiang commited on
Commit
9ff97fa
·
1 Parent(s): 4e8a01a

update models

Browse files
Files changed (1) hide show
  1. model.py +11 -19
model.py CHANGED
@@ -1,31 +1,23 @@
1
- from transformers import AutoConfig, AutoModelForObjectDetection
2
- from PIL import Image
3
  import torch
 
4
 
5
 
6
  class RadarDetectionModel:
7
  def __init__(self):
8
- self.config = AutoConfig.from_pretrained(
9
- "Extremely4606/paligemma_9_19")
10
  self.model = AutoModelForObjectDetection.from_pretrained(
11
- "Extremely4606/paligemma_9_19")
12
  self.model.eval()
13
 
14
- def preprocess_image(self, image):
15
- # 这里需要根据模型的具体要求来处理图像
16
- # 这只是一个示例,可能需要调整
17
- image = image.resize((224, 224))
18
- image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
19
- return image.unsqueeze(0)
20
-
21
  @torch.no_grad()
22
  def detect(self, image):
23
- inputs = self.preprocess_image(image)
24
- outputs = self.model(inputs)
25
 
26
- # 这里可能需要根据模型的输出格式进行调整
27
- boxes = outputs.pred_boxes[0]
28
- scores = outputs.scores[0]
29
- labels = outputs.labels[0]
30
 
31
- return {"boxes": boxes, "scores": scores, "labels": labels}
 
1
+ from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
 
2
  import torch
3
+ from config import MODEL_NAME
4
 
5
 
6
  class RadarDetectionModel:
7
  def __init__(self):
8
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(
9
+ "google/pali-gamma-336m")
10
  self.model = AutoModelForObjectDetection.from_pretrained(
11
+ "google/pali-gamma-336m")
12
  self.model.eval()
13
 
 
 
 
 
 
 
 
14
  @torch.no_grad()
15
  def detect(self, image):
16
+ inputs = self.feature_extractor(images=image, return_tensors="pt")
17
+ outputs = self.model(**inputs)
18
 
19
+ target_sizes = torch.tensor([image.size[::-1]])
20
+ results = self.feature_extractor.post_process_object_detection(
21
+ outputs, threshold=0.5, target_sizes=target_sizes)[0]
 
22
 
23
+ return results