SkalskiP commited on
Commit
9364ec8
1 Parent(s): 7cddbb9
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -2,6 +2,7 @@ from typing import List
2
 
3
  import gradio as gr
4
  import numpy as np
 
5
  from transformers import CLIPProcessor, CLIPModel
6
 
7
  IMAGENET_CLASSES_FILE = "imagenet-classes.txt"
@@ -23,14 +24,15 @@ def load_text_lines(file_path: str) -> List[str]:
23
  return [line.rstrip() for line in lines]
24
 
25
 
26
- model = CLIPModel.from_pretrained("facebook/metaclip-b32-400m")
 
27
  processor = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")
28
  imagenet_classes = load_text_lines(IMAGENET_CLASSES_FILE)
29
 
30
 
31
  def classify_image(input_image) -> str:
32
  inputs = processor(
33
- text=['dog', 'person'],
34
  images=input_image,
35
  return_tensors="pt",
36
  padding=True)
 
2
 
3
  import gradio as gr
4
  import numpy as np
5
+ import torch
6
  from transformers import CLIPProcessor, CLIPModel
7
 
8
  IMAGENET_CLASSES_FILE = "imagenet-classes.txt"
 
24
  return [line.rstrip() for line in lines]
25
 
26
 
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ model = CLIPModel.from_pretrained("facebook/metaclip-b32-400m").to(device)
29
  processor = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")
30
  imagenet_classes = load_text_lines(IMAGENET_CLASSES_FILE)
31
 
32
 
33
  def classify_image(input_image) -> str:
34
  inputs = processor(
35
+ text=imagenet_classes,
36
  images=input_image,
37
  return_tensors="pt",
38
  padding=True)