lakshayt committed on
Commit 188ba58
1 Parent(s): 369f754

Create app.py

Files changed (1)
  1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
+ import gradio as gr
+ import requests
+ import torch
+ from PIL import Image
+ 
+ from huggingface_hub import hf_hub_download
+ from open_flamingo import create_model_and_transforms
+ 
+ model, image_processor, tokenizer = create_model_and_transforms(
+     clip_vision_encoder_path="ViT-L-14",
+     clip_vision_encoder_pretrained="openai",
+     lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
+     tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
+     cross_attn_every_n_layers=1,
+     cache_dir="PATH/TO/CACHE/DIR",  # Defaults to ~/.cache
+ )
+ 
+ # grab the model checkpoint from the Hugging Face Hub
+ checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt")
+ model.load_state_dict(torch.load(checkpoint_path), strict=False)
+ 
+ """
+ Step 1: Load images
+ """
+ demo_image_one = Image.open(
+     requests.get(
+         "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
+     ).raw
+ )
+ 
+ demo_image_two = Image.open(
+     requests.get(
+         "http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
+         stream=True,
+     ).raw
+ )
+ 
+ query_image = Image.open(
+     requests.get(
+         "http://images.cocodataset.org/test-stuff2017/000000028352.jpg",
+         stream=True,
+     ).raw
+ )
+ 
+ 
+ """
+ Steps 2-4 below are implemented inside predict_caption.
+ 
+ Step 2: Preprocess the images.
+ Details: OpenFlamingo expects the images as a torch tensor of shape
+ batch_size x num_media x num_frames x channels x height x width.
+ In this case batch_size = 1, num_media = 3, num_frames = 1,
+ channels = 3, height = 224, width = 224.
+ 
+ Step 3: Preprocess the text.
+ Details: In the text we expect an <image> special token to indicate where an image is.
+ We also expect an <|endofchunk|> special token to indicate the end of the text
+ portion associated with an image.
+ 
+ Step 4: Generate text.
+ """
+ 
+ 
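+ # Concretely: vision_x built in predict_caption has shape (1, 3, 1, 3, 224, 224),
+ # and the text prompt contains exactly one <image> token per media item (three here),
+ # with <|endofchunk|> closing the text for each in-context example.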
+ def predict_caption(image, prompt):
+     assert isinstance(prompt, str)
+ 
+     # If no image was uploaded, fall back to the bundled query image.
+     if image is None:
+         image = query_image
+ 
+     # Step 2: preprocess the images. The two demo images act as in-context
+     # examples; the uploaded image is the final query to caption.
+     vision_x = [
+         image_processor(demo_image_one).unsqueeze(0),
+         image_processor(demo_image_two).unsqueeze(0),
+         image_processor(image).unsqueeze(0),
+     ]
+     vision_x = torch.cat(vision_x, dim=0)
+     vision_x = vision_x.unsqueeze(1).unsqueeze(0)
+ 
+     # Step 3: preprocess the text, ending the few-shot prompt with the user-supplied prompt.
+     tokenizer.padding_side = "left"  # For generation padding tokens should be on the left
+     lang_x = tokenizer(
+         [
+             "<image>An image of two cats.<|endofchunk|>"
+             "<image>An image of a bathroom sink.<|endofchunk|>"
+             f"<image>{prompt}"
+         ],
+         return_tensors="pt",
+     )
+ 
+     # Step 4: generate text. The generation settings (max_new_tokens, num_beams)
+     # are illustrative defaults, not tuned values.
+     generated_text = model.generate(
+         vision_x=vision_x,
+         lang_x=lang_x["input_ids"],
+         attention_mask=lang_x["attention_mask"],
+         max_new_tokens=20,
+         num_beams=3,
+     )
+ 
+     caption = tokenizer.decode(generated_text[0])
+ 
+     return caption
+ 
+ 
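+ # Example call (hypothetical inputs): predict_caption(Image.open("cat.jpg"), "An image of")
+ # returns the decoded token sequence, i.e. the few-shot prompt followed by the generated caption.
+ 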
+ # Default prompt shown in the UI; it mirrors the tail of the hard-coded few-shot prompt above.
+ DEFAULT_PROMPT = "An image of"
+ 
+ iface = gr.Interface(
+     fn=predict_caption,
+     inputs=[gr.Image(type="pil"), gr.Textbox(value=DEFAULT_PROMPT, label="Prompt")],
+     outputs="text",
+ )
+ 
+ iface.launch()
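+ 
+ # Running `python app.py` serves the demo via iface.launch(),
+ # by default on a local Gradio server at http://127.0.0.1:7860.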