GopalChettri committed on
Commit
a3b2be7
1 Parent(s): 6eed899

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -51
app.py CHANGED
@@ -1,63 +1,199 @@
 
 
1
  import gradio as gr
2
- from transformers import pipeline
 
 
3
  import torch
4
-
5
# Monochrome base theme with an indigo/blue accent palette, small corner
# radius, and Open Sans as the primary UI font (with system fallbacks).
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
)
12
 
13
 # Dolly 2.0 (12B) instruction-following pipeline. bfloat16 halves memory use;
 # device_map="auto" spreads the model across available devices; loaded once at
 # module import so every request reuses the same weights.
 instruct_pipeline = pipeline(model="databricks/dolly-v2-12b", torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
14
def generate(instruction):
    """Run *instruction* through the Dolly pipeline and return its raw output."""
    result = instruct_pipeline(instruction)
    return result
16
 
17
 
18
# Canned prompts surfaced in the gr.Examples widget below the question box.
examples = [
    "Instead of making a peanut butter and jelly sandwich, what else could I combine peanut butter with in a sandwich? Give five ideas",
    "How do I make a campfire?",
    "Write me a tweet about the launch of Dolly 2.0, a new LLM",
]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
 
 
 
25
 
26
def process_example(args):
    """Exhaust the (possibly streaming) output of ``generate`` and return the
    final item.

    Bug fix: the original left the loop variable unbound — ``return x`` raised
    ``NameError`` whenever ``generate(args)`` yielded nothing. Initialize the
    result so an empty stream returns None instead of crashing.
    """
    result = None
    for result in generate(args):
        pass  # keep only the last yielded item (final, fully-generated answer)
    return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
# Hide the intermediate "generating" indicator while the model streams.
css = ".generating {visibility: hidden}"

# Single-column UI: header markdown, a question box, a markdown answer panel,
# a Generate button, and clickable canned examples.
with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
    with gr.Column():
        gr.Markdown(
            """ ## Dolly 2.0
Dolly 2.0 is a 12B parameter language model based on the EleutherAI pythia model family and fine-tuned exclusively on a new, high-quality human generated instruction following dataset, crowdsourced among Databricks employees. For more details, please refer to the [model card](https://huggingface.co/databricks/dolly-v2-12b)

Type in the box below and click the button to generate answers to your most pressing questions!
"""
        )
        with gr.Row():
            with gr.Column(scale=3):
                instruction = gr.Textbox(placeholder="Enter your question here", label="Question", elem_id="q-input")

                with gr.Box():
                    gr.Markdown("**Answer**")
                    output = gr.Markdown(elem_id="q-output")
                submit = gr.Button("Generate", variant="primary")
                gr.Examples(
                    examples=examples,
                    inputs=[instruction],
                    cache_examples=False,
                    fn=process_example,
                    outputs=[output],
                )

    # Both the button click and pressing Enter in the textbox run generation.
    submit.click(generate, inputs=[instruction], outputs=[output])
    instruction.submit(generate, inputs=[instruction], outputs=[output])

demo.queue(concurrency_count=16).launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import Iterable
3
  import gradio as gr
4
+ from gradio.themes.base import Base
5
+ from gradio.themes.utils import colors, fonts, sizes
6
+ import time
7
  import torch
8
+ from transformers import pipeline
9
+ import pandas as pd
 
 
 
 
 
 
10
 
11
 # Dolly 2.0 (12B) instruction-following pipeline. bfloat16 halves memory use;
 # device_map="auto" spreads the model across available devices; loaded once at
 # module import so every request reuses the same weights.
 instruct_pipeline = pipeline(model="databricks/dolly-v2-12b", torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
 
 
12
 
13
 
14
def run_pipeline(prompt):
    """Send *prompt* through the Dolly pipeline and return its raw response."""
    return instruct_pipeline(prompt)
17
+
18
def get_user_input(input_question, history):
    """Record a new chat turn and clear the input box.

    Returns ("", updated_history) where the new turn has no bot reply yet
    (it is filled in by the chained dolly_chat callback).
    """
    updated_history = history + [[input_question, None]]
    return "", updated_history
20
+
21
def get_qa_user_input(input_question, history):
    """Q&A-tab twin of get_user_input: append an unanswered turn, clear input."""
    updated_history = history + [[input_question, None]]
    return "", updated_history
23
+
24
def dolly_chat(history):
    """Generate a reply for the newest user turn and write it into history.

    Mutates ``history`` in place (fills the bot slot of the last turn) and
    returns it so Gradio can refresh the Chatbot component.
    """
    latest_prompt = history[-1][0]
    history[-1][1] = run_pipeline(latest_prompt)
    return history
29
+
30
def qa_bot(context, history):
    """Answer the newest question in ``history`` grounded in ``context``.

    Builds an instruction+context prompt, runs the pipeline, stores the reply
    in the last turn's bot slot, and returns the updated history.
    """
    query = history[-1][0]
    prompt = f'instruction: {query} \ncontext: {context}'
    history[-1][1] = run_pipeline(prompt)
    return history
36
+
37
def reset_chatbot():
    """Return a Gradio update that blanks the bound component's value.

    NOTE(review): this is wired to the Q&A Chatbot on context change and sets
    value="" while the Clear buttons use None — confirm "" renders as an empty
    chat in the deployed Gradio version.
    """
    return gr.update(value="")
39
+
40
def load_customer_support_example():
    """Load the customer-support (doc, question) pair — row 0 of examples.csv."""
    example_rows = pd.read_csv("examples.csv")
    first = example_rows.iloc[0]
    return first['doc'], first['question']
43
 
44
def load_databricks_doc_example():
    """Load the Databricks-docs (doc, question) pair — row 1 of examples.csv."""
    example_rows = pd.read_csv("examples.csv")
    second = example_rows.iloc[1]
    return second['doc'], second['question']
47
 
48
# Referred & modified from https://gradio.app/theming-guide/
class SeafoamCustom(Base):
    """Seafoam-style Gradio theme.

    Emerald/blue palette, medium spacing and radius, Quicksand as the text
    font and IBM Plex Mono for code, with gradient primary buttons and drop
    shadows layered on top of the Base theme.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.emerald,
        secondary_hue: colors.Color | str = colors.blue,
        neutral_hue: colors.Color | str = colors.blue,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Quicksand"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        # Base gets the palette/typography; set() then overrides individual
        # component styles on top of it.
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            font=font,
            font_mono=font_mono,
        )
        super().set(
            button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)",
            button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)",
            button_primary_text_color="white",
            button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)",
            block_shadow="*shadow_drop_lg",
            button_shadow="*shadow_drop_lg",
            input_background_fill="zinc",
            input_border_color="*secondary_300",
            input_shadow="*shadow_drop",
            input_shadow_focus="*shadow_drop_lg",
        )
94
+
95
+
96
seafoam = SeafoamCustom()

# Two-tab UI: free-form "Dolly Chat" and "Q&A with Context" (question answered
# against a user-pasted document). Each tab has its own Chatbot + input + Clear.
with gr.Blocks(theme=seafoam) as demo:

    # Header row: logo on the left, title + blurb on the right.
    with gr.Row(variant='panel'):
        with gr.Column():
            gr.HTML(
                """<html><img src='file/dolly.jpg', alt='dolly logo', width=150, height=150 /><br></html>"""
            )
        with gr.Column():
            gr.Markdown("# **<p align='center'>Dolly 2.0: World's First Truly Open Instruction-Tuned LLM</p>**")
            gr.Markdown("Dolly 2.0, the first open source, instruction-following LLM, fine-tuned on a human-generated instruction dataset licensed for research and commercial use. It's a 12B parameter language model based on the EleutherAI pythia model family and fine-tuned exclusively on a new, high-quality human generated instruction following dataset, crowdsourced among Databricks employees.")

    qa_bot_state = gr.State(value=[])

    with gr.Tabs():
        with gr.TabItem("Dolly Chat"):
            with gr.Row():
                with gr.Column():
                    chatbot = gr.Chatbot(label="Chat History")
                    input_question = gr.Text(
                        label="Instruction",
                        placeholder="Type prompt and hit enter.",
                    )
                    clear = gr.Button("Clear", variant="primary")

            with gr.Row():
                with gr.Accordion("Show example inputs I can load:", open=False):
                    gr.Examples(
                        [
                            ["Explain to me the difference between nuclear fission and fusion."],
                            ["Give me a list of 5 science fiction books I should read next."],
                            ["I'm selling my Nikon D-750, write a short blurb for my ad."],
                            ["Write a song about sour donuts"],
                            ["Write a tweet about a new book launch by J.K. Rowling."],
                        ],
                        [input_question],
                        [],
                        None,
                        cache_examples=False,
                    )

        with gr.TabItem("Q&A with Context"):
            with gr.Row():
                with gr.Column():
                    input_context = gr.Text(label="Add context here", lines=10)

                with gr.Column():
                    qa_chatbot = gr.Chatbot(label="Q&A History")
                    qa_input_question = gr.Text(
                        label="Input Question",
                        placeholder="Type question here and hit enter.",
                    )
                    qa_clear = gr.Button("Clear", variant="primary")

            with gr.Row():
                with gr.Accordion("Show example inputs I can load:", open=False):
                    example_1 = gr.Button("Load Customer support example")
                    example_2 = gr.Button("Load Databricks documentation example")

    # Chat tab: record the user turn, then chain the model reply.
    input_question.submit(
        get_user_input,
        [input_question, chatbot],
        [input_question, chatbot],
    ).then(dolly_chat, [chatbot], chatbot)

    clear.click(lambda: None, None, chatbot)

    # Q&A tab: record the question, then answer it against the pasted context.
    qa_input_question.submit(
        get_qa_user_input,
        [qa_input_question, qa_chatbot],
        [qa_input_question, qa_chatbot],
    ).then(qa_bot, [input_context, qa_chatbot], qa_chatbot)

    qa_clear.click(lambda: None, None, qa_chatbot)

    # reset the chatbot Q&A history when input context changes
    input_context.change(fn=reset_chatbot, inputs=[], outputs=qa_chatbot)

    # Example buttons fill both the context box and the question box.
    example_1.click(
        load_customer_support_example,
        [],
        [input_context, qa_input_question],
    )

    example_2.click(
        load_databricks_doc_example,
        [],
        [input_context, qa_input_question],
    )

if __name__ == "__main__":
    demo.queue(concurrency_count=1, max_size=100).launch(max_threads=5, debug=True)