methodw commited on
Commit
a8c63a3
·
1 Parent(s): 41c5d45

test torch JIT

Browse files
Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -12,6 +12,12 @@ torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  dino_v2_model = AutoModel.from_pretrained("./dinov2-large").to(torch_device)
13
  dino_v2_image_processor = AutoImageProcessor.from_pretrained("./dinov2-large")
14
 
 
 
 
 
 
 
15
 
16
  def process_image(image):
17
  """
@@ -46,7 +52,9 @@ def process_image(image):
46
  inputs = dino_v2_image_processor(images=image, return_tensors="pt").to(
47
  torch_device
48
  )
49
- outputs = dino_v2_model(**inputs)
 
 
50
 
51
  # Normalize the features before search, whatever that means
52
  embeddings = outputs.last_hidden_state
 
12
  dino_v2_model = AutoModel.from_pretrained("./dinov2-large").to(torch_device)
13
  dino_v2_image_processor = AutoImageProcessor.from_pretrained("./dinov2-large")
14
 
15
+ # Provide a sample input for tracing
16
+ sample_input = dino_v2_image_processor(
17
+ images=Image.new("RGB", (64, 64)), return_tensors="pt"
18
+ ).to(torch_device)
19
+ traced_dino_v2_model = torch.jit.trace(dino_v2_model, sample_input["pixel_values"])
20
+
21
 
22
  def process_image(image):
23
  """
 
52
  inputs = dino_v2_image_processor(images=image, return_tensors="pt").to(
53
  torch_device
54
  )
55
+
56
+ # Use the traced model for inference
57
+ outputs = traced_dino_v2_model(**inputs)
58
 
59
  # Normalize the features before search, whatever that means
60
  embeddings = outputs.last_hidden_state