SeyedAli commited on
Commit
de15cca
1 Parent(s): 7ae86f7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -0
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tempfile
3
+ from transformers import MT5ForConditionalGeneration, MT5Tokenizer,VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
4
+ import torch
5
+ from PIL import Image
6
+
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
+ model_size = "small"
10
+ model_name = f"persiannlp/mt5-{model_size}-parsinlu-translation_en_fa"
11
+ translation_tokenizer = MT5Tokenizer.from_pretrained(model_name)
12
+ translation_model = MT5ForConditionalGeneration.from_pretrained(model_name)
13
+
14
+ translation_model=translation_model.to(device)
15
+
16
+ def run_transaltion_model(input_string, **generator_args):
17
+ input_ids = translation_tokenizer.encode(input_string, return_tensors="pt")
18
+ res = translation_model.generate(input_ids, **generator_args)
19
+ output = translation_tokenizer.batch_decode(res, skip_special_tokens=True)
20
+ return output
21
+
22
+ model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
23
+ feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
24
+ tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
25
+
26
+ model=model.to(device)
27
+
28
+ max_length = 32
29
+ num_beams = 4
30
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
31
+ def predict_step(image_paths):
32
+ images = []
33
+ for image_path in image_paths:
34
+ i_image = Image.open(image_path)
35
+ if i_image.mode != "RGB":
36
+ i_image = i_image.convert(mode="RGB")
37
+
38
+ images.append(i_image)
39
+
40
+ pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
41
+ pixel_values = pixel_values.to(device)
42
+
43
+ output_ids = model.generate(pixel_values, **gen_kwargs)
44
+
45
+ preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
46
+ preds = [pred.strip() for pred in preds]
47
+ return run_transaltion_model(preds[0])[0]
48
+
49
+ def ImageCaptioning(image):
50
+ with tempfile.NamedTemporaryFile(suffix=".png") as temp_image_file:
51
+ # Copy the contents of the uploaded image file to the temporary file
52
+ Image.fromarray(image).save(temp_image_file.name)
53
+ # Load the image file using Pillow
54
+ caption=predict_step(temp_image_file.name)
55
+ return caption
56
+
57
+ iface = gr.Interface(fn=ImageCaptioning, inputs="image", outputs="text")
58
+ iface.launch(share=False)