Spaces:
Runtime error
Runtime error
Upload 4 files
Browse files- app.py +44 -49
- instructions.png +0 -0
- requirements.txt +3 -4
app.py
CHANGED
@@ -1,46 +1,35 @@
|
|
1 |
import gradio as gr
|
2 |
-
import tensorflow as tf
|
3 |
import time
|
4 |
-
import warnings
|
5 |
-
import os
|
6 |
from transformers import (
|
7 |
-
AutoTokenizer,
|
8 |
-
AutoModelForSeq2SeqLM,
|
9 |
pipeline,
|
|
|
10 |
AutoModelForCausalLM,
|
11 |
-
|
|
|
|
|
12 |
)
|
13 |
|
14 |
-
# Warning Suppression
|
15 |
-
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
16 |
-
# warnings.filterwarnings('ignore', category=DeprecationWarning)
|
17 |
-
# warnings.filterwarnings('ignore', category=FutureWarning)
|
18 |
-
# warnings.filterwarnings('ignore', category=UserWarning)
|
19 |
-
# tf.get_logger().setLevel('ERROR')
|
20 |
-
|
21 |
# emotion classification
|
22 |
-
|
23 |
-
|
24 |
emotion_pipeline = pipeline("text-classification", model=emotion_model, tokenizer=emotion_tokenizer)
|
25 |
|
26 |
# generation models
|
27 |
|
28 |
-
# no emotion
|
29 |
-
gpt2_model_no_emo = "aegrif/CIS6930_DAAGR_GPT2_NoEmo"
|
30 |
-
gpt2_tokenizer_no_emo = "aegrif/CIS6930_DAAGR_GPT2_NoEmo"
|
31 |
-
chatbot_gpt_no_emo = pipeline(model=gpt2_model_no_emo, tokenizer=gpt2_tokenizer_no_emo, pad_token_id=50256)
|
32 |
-
|
33 |
-
|
34 |
# decoder
|
35 |
-
gpt2_model_emo = "aegrif/CIS6930_DAAGR_GPT2_Emo"
|
36 |
-
gpt2_tokenizer = "aegrif/CIS6930_DAAGR_GPT2_Emo"
|
37 |
-
chatbot_gpt_emo = pipeline(model=gpt2_model_emo, tokenizer=gpt2_tokenizer, pad_token_id=50256)
|
38 |
-
|
39 |
|
40 |
# encoder-decoder
|
41 |
-
t5_model_emo = "aegrif/CIS6930_DAAGR_T5_Emo"
|
42 |
-
t5_tokenizer = "t5-small"
|
43 |
-
chatbot_t5_emo = pipeline(model=t5_model_emo, tokenizer=t5_tokenizer)
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
emotion_dict = {'disappointed': 0, 'annoyed': 1, 'excited': 2, 'afraid': 3, 'disgusted': 4, 'grateful': 5,
|
46 |
'impressed': 6, 'prepared': 7}
|
@@ -48,7 +37,6 @@ inverted_emotion_dict = {v: k for k, v in emotion_dict.items()}
|
|
48 |
|
49 |
|
50 |
def get_context(user_input):
|
51 |
-
|
52 |
output = emotion_pipeline(user_input)[0]['label']
|
53 |
|
54 |
context = inverted_emotion_dict.get(int(output[-1]))
|
@@ -56,23 +44,23 @@ def get_context(user_input):
|
|
56 |
return context
|
57 |
|
58 |
|
59 |
-
def
|
60 |
-
|
|
|
61 |
|
62 |
-
|
|
|
|
|
63 |
user_input,
|
64 |
max_new_tokens=40,
|
65 |
num_return_sequences=1,
|
66 |
do_sample=True,
|
67 |
-
temperature=0.
|
68 |
renormalize_logits=True,
|
69 |
-
exponential_decay_length_penalty=(
|
70 |
no_repeat_ngram_size=3,
|
71 |
repetition_penalty=1.5
|
72 |
-
)
|
73 |
-
|
74 |
-
# Decode the generated response
|
75 |
-
bot_response = output[0]['generated_text'].split("<|bot|>")[1].strip()
|
76 |
|
77 |
return bot_response
|
78 |
|
@@ -89,9 +77,9 @@ def predict_gpt2(user_input, history):
|
|
89 |
max_new_tokens=40,
|
90 |
num_return_sequences=1,
|
91 |
do_sample=True,
|
92 |
-
temperature=0.
|
93 |
renormalize_logits=True,
|
94 |
-
exponential_decay_length_penalty=(
|
95 |
no_repeat_ngram_size=3,
|
96 |
repetition_penalty=1.5
|
97 |
)
|
@@ -110,12 +98,13 @@ def predict_t5(user_input, history):
|
|
110 |
# Generate a response using the T5 model
|
111 |
bot_response = chatbot_t5_emo(
|
112 |
user_input,
|
113 |
-
max_new_tokens=
|
|
|
114 |
num_return_sequences=1,
|
115 |
do_sample=True,
|
116 |
-
temperature=0.
|
117 |
renormalize_logits=True,
|
118 |
-
exponential_decay_length_penalty=(
|
119 |
no_repeat_ngram_size=3,
|
120 |
repetition_penalty=1.5
|
121 |
)[0]['generated_text']
|
@@ -127,13 +116,14 @@ def user(user_message, history):
|
|
127 |
return "", history + [[user_message, None]]
|
128 |
|
129 |
|
130 |
-
def
|
131 |
user_message = history[-1][0]
|
132 |
-
bot_message =
|
133 |
history[-1][1] = bot_message
|
134 |
time.sleep(1)
|
135 |
return history
|
136 |
|
|
|
137 |
def gpt2_bot(history):
|
138 |
user_message = history[-1][0]
|
139 |
bot_message = predict_gpt2(user_message, history)
|
@@ -151,19 +141,24 @@ def t5_bot(history):
|
|
151 |
|
152 |
|
153 |
with gr.Blocks() as demo:
|
|
|
|
|
154 |
with gr.Row():
|
155 |
with gr.Column():
|
156 |
-
chatbot1 = gr.Chatbot().style()
|
157 |
msg1 = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)
|
158 |
with gr.Column():
|
159 |
-
chatbot2 = gr.Chatbot().style()
|
160 |
msg2 = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)
|
161 |
with gr.Column():
|
162 |
-
chatbot3 = gr.Chatbot().style()
|
163 |
msg3 = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)
|
|
|
|
|
|
|
164 |
|
165 |
msg1.submit(user, [msg1, chatbot1], [msg1, chatbot1], queue=False).then(
|
166 |
-
|
167 |
)
|
168 |
msg2.submit(user, [msg2, chatbot2], [msg2, chatbot2], queue=False).then(
|
169 |
gpt2_bot, chatbot2, chatbot2
|
|
|
1 |
import gradio as gr
import time
from transformers import (
    pipeline,
    GenerationConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    TFAutoModelForSeq2SeqLM,
    TFAutoModelForSequenceClassification,
)

# emotion classification
# TF classifier that maps a user utterance to one of 8 emotion labels.
emotion_model = TFAutoModelForSequenceClassification.from_pretrained("aegrif/CIS6930_DAAGR_Classification")
emotion_tokenizer = AutoTokenizer.from_pretrained("aegrif/CIS6930_DAAGR_Classification")
emotion_pipeline = pipeline("text-classification", model=emotion_model, tokenizer=emotion_tokenizer)

# generation models

# decoder
# GPT-2 fine-tuned with emotion conditioning; pad_token_id=50256 is GPT-2's EOS,
# reused as padding because GPT-2 has no dedicated pad token.
gpt2_model_emo = AutoModelForCausalLM.from_pretrained("aegrif/CIS6930_DAAGR_GPT2_Emo")
gpt2_tokenizer = AutoTokenizer.from_pretrained("aegrif/CIS6930_DAAGR_GPT2_Emo")
chatbot_gpt_emo = pipeline("text-generation", model=gpt2_model_emo, tokenizer=gpt2_tokenizer, pad_token_id=50256)

# encoder-decoder
# T5 fine-tuned with emotion conditioning; uses the stock t5-small tokenizer.
t5_model_emo = TFAutoModelForSeq2SeqLM.from_pretrained("aegrif/CIS6930_DAAGR_T5_Emo")
t5_tokenizer = AutoTokenizer.from_pretrained("t5-small")
chatbot_t5_emo = pipeline("text2text-generation", model=t5_model_emo, tokenizer=t5_tokenizer)

# no emotion
# Baseline T5 without emotion conditioning; its generation config is loaded
# explicitly from the hub checkpoint.
t5_model_no_emo = TFAutoModelForSeq2SeqLM.from_pretrained("aegrif/CIS6930_DAAGR_T5_NoEmo")
t5_model_no_emo.generation_config = GenerationConfig.from_pretrained("aegrif/CIS6930_DAAGR_T5_NoEmo")
chatbot_t5_no_emo = pipeline("text2text-generation", model=t5_model_no_emo, tokenizer=t5_tokenizer)

# Emotion label -> class index, as used by the classifier checkpoint.
emotion_dict = {'disappointed': 0, 'annoyed': 1, 'excited': 2, 'afraid': 3, 'disgusted': 4, 'grateful': 5,
                'impressed': 6, 'prepared': 7}
|
|
|
37 |
|
38 |
|
39 |
def get_context(user_input):
    """Classify the emotion of *user_input* and return its emotion name.

    Args:
        user_input: Raw user message text.

    Returns:
        The emotion word (e.g. 'excited'), or None if the index is unknown.
    """
    # assumes the pipeline label ends in its class index (e.g. "LABEL_3") — TODO confirm
    output = emotion_pipeline(user_input)[0]['label']

    # Map the trailing digit back to the emotion word; .get avoids a KeyError
    # on an unexpected label.
    context = inverted_emotion_dict.get(int(output[-1]))

    return context
|
45 |
|
46 |
|
47 |
+
def predict_t5_no_emo(user_input, history):
    """Generate a reply with the no-emotion T5 model.

    Args:
        user_input: The user's latest message.
        history: Chat history; unused here but kept so all predict_* helpers
            share the same signature.

    Returns:
        The generated response text.
    """
    # Get the context from the user input
    context = get_context(user_input)

    # T5-style prompt; </s> marks end-of-sequence explicitly.
    user_input = f"question: {user_input} context: {context} </s>"
    # Generate a response using the T5 model
    bot_response = chatbot_t5_no_emo(
        user_input,
        max_new_tokens=40,
        num_return_sequences=1,
        do_sample=True,
        temperature=0.9,
        renormalize_logits=True,
        # Begin discouraging longer outputs after 20 tokens.
        exponential_decay_length_penalty=(20, 1),
        no_repeat_ngram_size=3,
        repetition_penalty=1.5
    )[0]['generated_text']

    return bot_response
|
66 |
|
|
|
77 |
max_new_tokens=40,
|
78 |
num_return_sequences=1,
|
79 |
do_sample=True,
|
80 |
+
temperature=0.9,
|
81 |
renormalize_logits=True,
|
82 |
+
exponential_decay_length_penalty=(20, 1.05),
|
83 |
no_repeat_ngram_size=3,
|
84 |
repetition_penalty=1.5
|
85 |
)
|
|
|
98 |
# Generate a response using the T5 model
|
99 |
bot_response = chatbot_t5_emo(
|
100 |
user_input,
|
101 |
+
max_new_tokens=60,
|
102 |
+
max_length=160,
|
103 |
num_return_sequences=1,
|
104 |
do_sample=True,
|
105 |
+
temperature=0.9,
|
106 |
renormalize_logits=True,
|
107 |
+
exponential_decay_length_penalty=(20, 1),
|
108 |
no_repeat_ngram_size=3,
|
109 |
repetition_penalty=1.5
|
110 |
)[0]['generated_text']
|
|
|
116 |
return "", history + [[user_message, None]]
|
117 |
|
118 |
|
119 |
+
def t5_bot_no_emo(history):
    """Gradio callback for the no-emotion T5 agent.

    Reads the latest user message from *history*, generates a reply, and
    writes it into the last history entry's bot slot.

    Args:
        history: List of [user_message, bot_message] pairs; the last pair's
            bot slot is expected to be None on entry.

    Returns:
        The updated history list.
    """
    user_message = history[-1][0]
    bot_message = predict_t5_no_emo(user_message, history)
    history[-1][1] = bot_message
    time.sleep(1)  # brief pause before the reply is shown
    return history
|
125 |
|
126 |
+
|
127 |
def gpt2_bot(history):
|
128 |
user_message = history[-1][0]
|
129 |
bot_message = predict_gpt2(user_message, history)
|
|
|
141 |
|
142 |
|
143 |
with gr.Blocks() as demo:
|
144 |
+
with gr.Row():
|
145 |
+
gr.Image("instructions.png", interactive=False, show_label=False)
|
146 |
with gr.Row():
|
147 |
with gr.Column():
|
148 |
+
chatbot1 = gr.Chatbot(label="Chatbot #1").style(height=500)
|
149 |
msg1 = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)
|
150 |
with gr.Column():
|
151 |
+
chatbot2 = gr.Chatbot(label="Chatbot #2").style(height=500)
|
152 |
msg2 = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)
|
153 |
with gr.Column():
|
154 |
+
chatbot3 = gr.Chatbot(label="Chatbot #3").style(height=500)
|
155 |
msg3 = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)
|
156 |
+
with gr.Row():
|
157 |
+
gr.HTML('<p style="font-size:150%; font-family: "Playfair Display", "Didot", "Times New Roman">Once you have finished interacting with the agents, please follow the link below to complete a short survey about your experience.</p>'
|
158 |
+
'<p style="font-size:125%; font-family: "Playfair Display", "Didot", "Times New Roman"><a href=''https://docs.google.com/forms/d/1SICfdLcj_jbDeObZ6lxZ7b8a1L7fsZjX_ETfWc5o4VQ/edit'' target=''_blank''>https://docs.google.com/forms/d/1SICfdLcj_jbDeObZ6lxZ7b8a1L7fsZjX_ETfWc5o4VQ/edit</a></p')
|
159 |
|
160 |
msg1.submit(user, [msg1, chatbot1], [msg1, chatbot1], queue=False).then(
|
161 |
+
t5_bot_no_emo, chatbot1, chatbot1
|
162 |
)
|
163 |
msg2.submit(user, [msg2, chatbot2], [msg2, chatbot2], queue=False).then(
|
164 |
gpt2_bot, chatbot2, chatbot2
|
instructions.png
ADDED
requirements.txt
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
gradio==3.
|
2 |
-
transformers==4.27.
|
3 |
-
|
4 |
-
tensorflow==2.11.1
|
|
|
1 |
+
gradio==3.24.1
|
2 |
+
transformers==4.27.4
|
3 |
+
protobuf==3.20
|
|