Lander San Millan committed on
Commit
5f70455
·
1 Parent(s): 96e7cdc

feat: gradio for flamingo-models[QA+COT]

Browse files
Files changed (3) hide show
  1. app.py +79 -0
  2. giraffes.jpg +0 -0
  3. requirements.txt +120 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import matplotlib.pyplot as plt
5
+ import torch
6
+ from flamingo_mini_task.utils import load_url
7
+ from flamingo_mini_task import FlamingoModel, FlamingoProcessor
8
+ from datasets import load_dataset,concatenate_datasets
9
+ from PIL import Image
10
+
11
+
12
# Hub repos for each selectable model, keyed by the label shown in the UI dropdown.
_MODEL_REPOS = {
    'flamingo-tiny-scienceQA[COT+QA]': 'TheMrguiller/Flamingo-tiny_ScienceQA_COT-QA',
    'flamingo-mini-bilbaocaptions-scienceQA[QA]': 'TheMrguiller/Flamingo-mini-Bilbao_Captions-task_BilbaoQA-ScienceQA',
    'flamingo-megatiny-opt-scienceQA[QA]': 'landersanmi/flamingo-megatiny-opt-QA',
}

# Eagerly load every model at import time so the first request is fast.
# NOTE(review): this downloads all three checkpoints on startup — confirm that
# is acceptable for the hosting environment.
flamingo_megatiny_captioning_models = {
    label: {'model': FlamingoModel.from_pretrained(repo)}
    for label, repo in _MODEL_REPOS.items()
}
23
+
24
+
25
def generate_text(image, question, option_a, option_b, option_c, option_d, cot_checkbox, model_name):
    """Answer a multiple-choice question about an image with a Flamingo model.

    Args:
        image: Input image (as provided by the Gradio image component).
        question: The question text.
        option_a, option_b, option_c, option_d: The four answer choices.
        cot_checkbox: If True, ask the model for a chain-of-thought ([COT])
            instead of a direct answer ([QA]).
        model_name: Key into ``flamingo_megatiny_captioning_models`` selecting
            which pre-loaded model to use.

    Returns:
        The generated answer (or chain of thought) as a string.

    Raises:
        KeyError: If ``model_name`` is not a known model key.
    """
    model = flamingo_megatiny_captioning_models[model_name]['model']
    processor = FlamingoProcessor(model.config)

    # The task token steers generation: chain-of-thought vs. direct answer.
    task_token = "[COT]" if cot_checkbox else "[QA]"
    prompt = (
        f"{task_token}[CONTEXT]<image>[QUESTION]{question} "
        f"[OPTIONS] (A) {option_a} (B) {option_b} (C) {option_c} (D) {option_d} [ANSWER]"
    )

    print(prompt)  # debug trace of the exact prompt sent to the model
    prediction = model.generate_captions(images=image,
                                         processor=processor,
                                         prompt=prompt)

    # generate_captions returns one caption per image; we pass a single image.
    return prediction[0]
48
+
49
+
50
+
51
+
52
# Build the demo UI. Uses the modern component API (gr.Textbox, gr.Checkbox,
# gr.Dropdown) instead of the gr.inputs/gr.outputs namespaces, which are
# deprecated in Gradio 3.x and removed in 4.x.
image_input = gr.Image(value="giraffes.jpg")
question_input = gr.Textbox(value="Which animal is this?", label="Question")
opt_a_input = gr.Textbox(value="Dog", label="Option A")
opt_b_input = gr.Textbox(value="Cat", label="Option B")
opt_c_input = gr.Textbox(value="Giraffe", label="Option C")
opt_d_input = gr.Textbox(value="Horse", label="Option D")
cot_checkbox = gr.Checkbox(label="Generate COT")
select_model = gr.Dropdown(choices=list(flamingo_megatiny_captioning_models.keys()),
                           label="Model")

text_output = gr.Textbox(label="Answer")

# Create the Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[image_input,
            question_input,
            opt_a_input,
            opt_b_input,
            opt_c_input,
            opt_d_input,
            cot_checkbox,
            select_model,
            ],
    outputs=text_output,
    title='Generate answers from MCQ',
    description='Generate answers from Multiple Choice Questions or generate a Chain Of Thought about the question and the options given',
    theme='default',
)

if __name__ == "__main__":
    demo.launch()
giraffes.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.4.0
2
+ aiofiles==23.1.0
3
+ aiohttp==3.8.4
4
+ aiosignal==1.3.1
5
+ altair==5.0.1
6
+ anyio==3.7.0
7
+ appdirs==1.4.4
8
+ asttokens==2.2.1
9
+ async-timeout==4.0.2
10
+ attrs==23.1.0
11
+ backcall==0.2.0
12
+ beautifulsoup4==4.12.2
13
+ brotlipy==0.7.0
14
+ charset-normalizer==2.1.1
15
+ click==7.1.2
16
+ comm==0.1.3
17
+ contourpy==1.0.7
18
+ cycler==0.11.0
19
+ debugpy==1.6.7
20
+ decorator==5.1.1
21
+ deep-translator==1.11.0
22
+ dill==0.3.6
23
+ docker-pycreds==0.4.0
24
+ einops==0.6.1
25
+ einops-exts==0.0.4
26
+ evaluate==0.4.0
27
+ exceptiongroup==1.1.1
28
+ executing==1.2.0
29
+ fastapi==0.95.2
30
+ ffmpy==0.3.0
31
+ git+https://github.com/TheMrguiller/MUCSI_Modal.git@main#subdirectory=flamingo-train_task
32
+ fonttools==4.39.3
33
+ frozenlist==1.3.3
34
+ fsspec==2023.5.0
35
+ gitdb==4.0.10
36
+ GitPython==3.1.31
37
+ gradio==3.33.1
38
+ gradio_client==0.2.5
39
+ h11==0.14.0
40
+ httpcore==0.17.2
41
+ httpx==0.24.1
42
+ huggingface-hub==0.14.1
43
+ install==1.3.5
44
+ ipykernel==6.23.1
45
+ ipython==8.13.2
46
+ jedi==0.18.2
47
+ joblib==1.2.0
48
+ jsonschema==4.17.3
49
+ jupyter_client==8.2.0
50
+ jupyter_core==5.3.0
51
+ kiwisolver==1.4.4
52
+ langdetect==1.0.9
53
+ linkify-it-py==2.0.2
54
+ markdown-it-py==2.2.0
55
+ matplotlib==3.7.1
56
+ matplotlib-inline==0.1.6
57
+ mdit-py-plugins==0.3.3
58
+ mdurl==0.1.2
59
+ mpmath==1.2.1
60
+ multidict==6.0.4
61
+ multiprocess==0.70.14
62
+ nest-asyncio==1.5.6
63
+ nltk==3.8
64
+ orjson==3.9.0
65
+ packaging==23.1
66
+ pandas==2.0.1
67
+ parso==0.8.3
68
+ pathtools==0.1.2
69
+ pexpect==4.8.0
70
+ pickleshare==0.7.5
71
+ Pillow==9.4.0
72
+ platformdirs==3.5.1
73
+ prompt-toolkit==3.0.38
74
+ protobuf==4.23.0
75
+ psutil==5.9.5
76
+ ptyprocess==0.7.0
77
+ pure-eval==0.2.2
78
+ pyarrow==12.0.0
79
+ pycocotools==2.0.6
80
+ pydantic==1.10.8
81
+ pydub==0.25.1
82
+ Pygments==2.15.1
83
+ pyparsing==3.0.9
84
+ pyrsistent==0.19.3
85
+ python-dateutil==2.8.2
86
+ python-multipart==0.0.6
87
+ pytz==2023.3
88
+ PyYAML==6.0
89
+ pyzmq==25.0.2
90
+ regex==2023.5.5
91
+ responses==0.18.0
92
+ rfc3987==1.3.8
93
+ rouge-score==0.1.2
94
+ semantic-version==2.10.0
95
+ sentry-sdk==1.22.2
96
+ setproctitle==1.3.2
97
+ six==1.16.0
98
+ smmap==5.0.0
99
+ sniffio==1.3.0
100
+ soupsieve==2.4.1
101
+ stack-data==0.6.2
102
+ starlette==0.27.0
103
+ tokenizers==0.13.3
104
+ toolz==0.12.0
105
+ torch==2.0.1
106
+ torchaudio==2.0.2
107
+ torchvision==0.15.2
108
+ tornado==6.3.2
109
+ tqdm==4.65.0
110
+ traitlets==5.9.0
111
+ transformers==4.28.1
112
+ triton==2.0.0
113
+ tzdata==2023.3
114
+ uc-micro-py==1.0.2
115
+ uvicorn==0.22.0
116
+ wandb==0.15.2
117
+ wcwidth==0.2.6
118
+ websockets==11.0.3
119
+ xxhash==3.2.0
120
+ yarl==1.9.2