hysts HF staff committed on
Commit
bc904d0
1 Parent(s): d4d8571
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +144 -3
  3. requirements.txt +7 -0
README.md CHANGED
@@ -1,5 +1,4 @@
1
  ---
2
- license: mit
3
  title: InstructBLIP
4
  emoji: ⚡
5
  colorFrom: red
@@ -9,6 +8,7 @@ sdk_version: 3.50.2
9
  python_version: 3.10.13
10
  app_file: app.py
11
  pinned: false
 
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
 
2
  title: InstructBLIP
3
  emoji: ⚡
4
  colorFrom: red
 
8
  python_version: 3.10.13
9
  app_file: app.py
10
  pinned: false
11
+ license: mit
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,9 +1,150 @@
1
  #!/usr/bin/env python
2
 
 
 
 
 
3
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- with gr.Blocks() as demo:
6
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  if __name__ == "__main__":
9
- demo.queue().launch()
 
1
  #!/usr/bin/env python
2
 
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
  import gradio as gr
8
+ import PIL.Image
9
+ import spaces
10
+ import torch
11
+ from transformers import InstructBlipForConditionalGeneration, InstructBlipProcessor
12
+
13
# Markdown heading rendered at the top of the demo page.
DESCRIPTION = "# InstructBLIP"

# Upper bound (in pixels) for the longer image side; read from the
# MAX_IMAGE_SIZE environment variable, defaulting to 1024.
MAX_IMAGE_SIZE = int(os.environ.get("MAX_IMAGE_SIZE", "1024"))

# Device the processor outputs are moved to before generation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The processor and the half-precision model are loaded once at import time.
model_id = "Salesforce/instructblip-vicuna-7b"
processor = InstructBlipProcessor.from_pretrained(model_id)
model = InstructBlipForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
)
22
+
23
+
24
@spaces.GPU
def run(
    image: PIL.Image.Image,
    prompt: str,
    text_decoding_method: str = "Nucleus sampling",
    num_beams: int = 5,
    max_length: int = 256,
    min_length: int = 1,
    top_p: float = 0.9,
    repetition_penalty: float = 1.5,
    length_penalty: float = 1.0,
    temperature: float = 1.0,
) -> str:
    """Generate text for an image and a prompt with the InstructBLIP model.

    The image is shrunk (aspect ratio preserved) so that its longer side does
    not exceed ``MAX_IMAGE_SIZE``; smaller images are passed through unchanged.

    Args:
        image: Input picture as a PIL image.
        prompt: Instruction or question passed to the processor as text.
        text_decoding_method: Sampling is enabled only when this equals
            ``"Nucleus sampling"``; any other value disables ``do_sample``.
        num_beams: Forwarded to ``model.generate``.
        max_length: Forwarded to ``model.generate``.
        min_length: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.
        length_penalty: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.

    Returns:
        The first decoded sequence, with special tokens skipped and
        surrounding whitespace stripped.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please provide an image.")

    # PIL reports size as (width, height); unpacking it the other way round
    # would swap the sides passed to resize() and distort the picture.
    w, h = image.size
    scale = MAX_IMAGE_SIZE / max(w, h)
    if scale < 1:
        # Clamp to 1 px so extreme aspect ratios never yield a zero-sized side.
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))
        image = image.resize((new_w, new_h), resample=PIL.Image.Resampling.LANCZOS)

    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
    generated_ids = model.generate(
        **inputs,
        do_sample=text_decoding_method == "Nucleus sampling",
        num_beams=num_beams,
        max_length=max_length,
        min_length=min_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
    )
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return generated_caption
58
+
59
+
60
# UI definition: inputs and advanced generation options in the first column,
# the generated text in the second column.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button()
            # Generation parameters, collapsed by default.
            with gr.Accordion(label="Advanced options", open=False):
                text_decoding_method = gr.Radio(
                    choices=["Beam search", "Nucleus sampling"],
                    value="Nucleus sampling",
                    label="Text Decoding Method",
                )
                num_beams = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of Beams")
                max_length = gr.Slider(minimum=1, maximum=512, step=1, value=256, label="Max Length")
                min_length = gr.Slider(minimum=1, maximum=64, step=1, value=1, label="Minimum Length")
                top_p = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.9, label="Top P")
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=5.0,
                    step=0.5,
                    value=1.5,
                    label="Repetition Penalty",
                    info="Larger value prevents repetition.",
                )
                length_penalty = gr.Slider(
                    minimum=-1.0,
                    maximum=2.0,
                    step=0.2,
                    value=1.0,
                    label="Length Penalty",
                    info="Set to larger for longer sequence, used with beam search.",
                )
                temperature = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    step=0.1,
                    value=1.0,
                    label="Temperature",
                    info="Used with nucleus sampling.",
                )

        with gr.Column():
            output = gr.Textbox(label="Result")

    # Submitting the prompt box and clicking the button both invoke `run`;
    # the input order below matches the positional parameters of `run`.
    gr.on(
        triggers=[prompt.submit, run_button.click],
        fn=run,
        inputs=[
            input_image,
            prompt,
            text_decoding_method,
            num_beams,
            max_length,
            min_length,
            top_p,
            repetition_penalty,
            length_penalty,
            temperature,
        ],
        outputs=output,
        api_name="run",
    )

if __name__ == "__main__":
    # Queue requests (at most 20 pending) before starting the server.
    demo.queue(max_size=20).launch()
requirements.txt CHANGED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ accelerate==0.23.0
2
+ gradio==3.50.2
3
+ Pillow==10.1.0
4
+ spaces==0.16.3
5
+ torch==2.0.0
6
+ torchvision==0.15.1
7
+ transformers==4.34.1