wangjin2000 commited on
Commit
cafea32
1 Parent(s): a4223f3

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoProcessor, AutoModelForCausalLM
3
+ import torch
4
+
5
+ device = "cuda" if torch.cuda.is_available() else "cpu"
6
+
7
+ # Inference with the pre-trained model
8
+ checkpoint1 = "microsoft/git-base"
9
+ processor = AutoProcessor.from_pretrained(checkpoint1)
10
+ #model1 = AutoModelForCausalLM.from_pretrained(checkpoint1, use_auth_token=access_token)
11
+ model1 = AutoModelForCausalLM.from_pretrained(checkpoint1)
12
+
13
+ # Inference with the fine-tuned model
14
+ checkpoint2 = "wangjin2000/git-base-finetune"
15
+ #model2 = AutoModelForCausalLM.from_pretrained(checkpoint2, use_auth_token=access_token)
16
+ model2 = AutoModelForCausalLM.from_pretrained(checkpoint2)
17
+
18
+ def img2cap_com(image):
19
+ input1 = processor(images=image, return_tensors="pt").to(device)
20
+ pixel_values1 = input1.pixel_values
21
+ generated_id1 = model1.generate(pixel_values=pixel_values1, max_length=50)
22
+ generated_caption1 = processor.batch_decode(generated_id1, skip_special_tokens=True)[0]
23
+
24
+ input2 = processor(images=image, return_tensors="pt").to(device)
25
+ pixel_values2 = input2.pixel_values
26
+ generated_id2 = model2.generate(pixel_values=pixel_values2, max_length=50)
27
+ generated_caption2 = processor.batch_decode(generated_id2, skip_special_tokens=True)[0]
28
+
29
+ return generated_caption1,generated_caption2
30
+
31
+ inputs = [
32
+ gr.inputs.Image(type="pil", label="Original Image")
33
+ ]
34
+
35
+ outputs = [
36
+ gr.outputs.Textbox(label="Caption from pre-trained model"),
37
+ gr.outputs.Textbox(label="Caption from fine-tuned model"),
38
+ ]
39
+
40
+ title = "Image Captioning using Pre-trained and Fine-tuned Model"
41
+ description = "GIT-base is used to generate Image Caption for the uploaded image."
42
+
43
+ examples = [
44
+ ["Image1.png"],
45
+ ["Image2.png"],
46
+ ["Image3.png"],
47
+ ["Image4.png"],
48
+ ["Image5.png"],
49
+ ["Image6.png"]
50
+ ]
51
+
52
+ gr.Interface(
53
+ img2cap_com,
54
+ inputs,
55
+ outputs,
56
+ title=title,
57
+ description=description,
58
+ examples=examples,
59
+ theme="huggingface",
60
+ ).launch(debug=True, enable_queue=True)