Omkar Shidore commited on
Commit
360a987
1 Parent(s): e0d8048
Files changed (4) hide show
  1. data/car.jpg +0 -0
  2. data/gsd.jpg +0 -0
  3. data/highway.jpg +0 -0
  4. main.py +39 -0
data/car.jpg ADDED
data/gsd.jpg ADDED
data/highway.jpg ADDED
main.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from PIL import Image
4
+ from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
5
+
6
+ def main():
7
+ model = VisionEncoderDecoderModel.from_pretrained("OmkarShidore/scene-caption")
8
+ feature_extractor = ViTImageProcessor.from_pretrained("OmkarShidore/scene-caption")
9
+ tokenizer = AutoTokenizer.from_pretrained("OmkarShidore/scene-caption")
10
+
11
+ max_length = 16
12
+ num_beams = 4
13
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
14
+ def predict(image):
15
+ #image = Image.open(image_path)
16
+ image = image.convert(mode="RGB")
17
+ pixel_values = feature_extractor(images=[image], return_tensors="pt").pixel_values
18
+ pixel_values = pixel_values.to(device="cpu")
19
+ output_ids = model.generate(pixel_values, **gen_kwargs)
20
+ preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
21
+ preds = [pred.strip() for pred in preds]
22
+ return preds[0]
23
+
24
+ #built interface with gradio to test the function
25
+ imagein = gr.components.Image(label='Scene Image', type='pil')
26
+ output = gr.components.Textbox()
27
+ gui = gr.Interface(fn=predict, inputs=imagein, outputs=[output])
28
+
29
+ gr.Interface(fn=predict,
30
+ inputs=imagein,
31
+ outputs=output,
32
+ title='Image To Text- Scene Description',
33
+ description="<html> <body> <h3>Hugging Face: <a href='https://huggingface.co/OmkarShidore/scene-caption'>OmkarShidore/scene-caption</a></h3><h3>Git: <a href='https://github.com/OmkarShidore/ImageToText-SceneDescription'>OmkarShidore/ImageToText-SceneDescription</a></h3> </body></html>",
34
+ examples=["./data/car.jpg", "./data/gsd.jpg", "./data/highway.jpg"],
35
+ theme=gr.themes.Base()
36
+ ).launch(share=True);
37
+
38
+ if __name__ == '__main__':
39
+ main()