g8a9 commited on
Commit
f46ae84
1 Parent(s): 0109735

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -6,7 +6,13 @@ import torch
6
 
7
 
8
  CHECKPOINT = "g8a9/vit-geppetto-captioning"
9
- model = VisionEncoderDecoderModel.from_pretrained(CHECKPOINT)
 
 
 
 
 
 
10
  feature_extractor = AutoFeatureExtractor.from_pretrained(CHECKPOINT)
11
  tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
12
 
@@ -42,6 +48,7 @@ elif gen_mode == "sampling":
42
  def generate_caption(url):
43
  image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
44
  inputs = feature_extractor(image, return_tensors="pt")
 
45
  generated_ids = model.generate(
46
  inputs["pixel_values"],
47
  max_length=20,
6
 
7
 
8
  CHECKPOINT = "g8a9/vit-geppetto-captioning"
9
+
10
+ @st.cache
11
+ def get_model():
12
+ model = VisionEncoderDecoderModel.from_pretrained(CHECKPOINT)
13
+ return model
14
+
15
+
16
  feature_extractor = AutoFeatureExtractor.from_pretrained(CHECKPOINT)
17
  tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
18
 
48
  def generate_caption(url):
49
  image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
50
  inputs = feature_extractor(image, return_tensors="pt")
51
+ model = get_model()
52
  generated_ids = model.generate(
53
  inputs["pixel_values"],
54
  max_length=20,