MahsaShahidi committed on
Commit
1898ee1
1 Parent(s): 8d129e6

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -0
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import re
3
+ import gradio as gr
4
+ from pathlib import Path
5
+ from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
6
+
7
def predict(image, max_length=30, num_beams=4):
    """Generate a caption for a single image.

    Uses the module-level ``feature_extractor``, ``model``, ``tokenizer``
    and ``device`` objects, which must be initialised before this is called.

    Args:
        image: Image object exposing ``convert('RGB')`` (e.g. a PIL image).
        max_length: Upper bound on the generated sequence length, forwarded
            to ``model.generate``.
        num_beams: Beam-search width, forwarded to ``model.generate``.

    Returns:
        The decoded caption string with special tokens removed.
    """
    image = image.convert('RGB')
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    # Keep the input on the same device as the model. The previous code moved
    # it back to the CPU right before generation, which would break as soon
    # as ``device`` is anything other than "cpu".
    pixel_values = pixel_values.to(device)
    with torch.no_grad():
        # Forward the generation settings; previously they were accepted by
        # this function but silently ignored.
        caption_ids = model.generate(
            pixel_values,
            max_length=max_length,
            num_beams=num_beams,
        )[0]
    return tokenizer.decode(caption_ids, skip_special_tokens=True)
15
+
16
# --- Model / preprocessing setup -------------------------------------------
# These module-level names (model, feature_extractor, tokenizer, device) are
# read by predict(), so they must keep these exact names.
model_path = "MahsaShahidi/Persian-Image-Captioning"
device = "cpu"

# Load model.
model = VisionEncoderDecoderModel.from_pretrained(model_path)
model.to(device)
print("Loaded model")
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
print("Loaded feature_extractor")
tokenizer = AutoTokenizer.from_pretrained('HooshvareLab/bert-fa-base-uncased-clf-persiannews')
print("Loaded tokenizer")

# --- Gradio UI ---------------------------------------------------------------
title = "Persian Image Captioning"
# Shown through the Interface's ``description`` slot (see below).
article = "This HuggingFace Space presents a demo for Persian Image Captioning on VIT as its Encoder and ParsBERT (v2.0) as its Decoder"

# Named image_input / caption_output so the ``input`` builtin is not shadowed.
# NOTE(review): gr.inputs / gr.outputs is the legacy Gradio component API;
# it is kept here because the pinned Gradio version is not visible. Confirm
# the version before migrating to gr.Image / gr.Textbox.
image_input = gr.inputs.Image(label="Image to search", type='pil', optional=False)
caption_output = gr.outputs.Textbox(type="auto", label="Captions")

# Example images expected next to this script: image-1.jpg .. image-3.jpg.
images = [f"./image-{i}.jpg" for i in range(1, 4)]

interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=caption_output,
    examples=images,
    title=title,
    description=article,
)
interface.launch(share=True)