vickeee465 committed
Commit ac4450b · Parent(s): 2e2f372
added figure logic
app.py CHANGED

@@ -5,6 +5,8 @@ import numpy as np
 from transformers import AutoModelForSequenceClassification
 from transformers import AutoTokenizer
 import gradio as gr
+import matplotlib.pyplot as plt
+import seaborn as sns
 
 PATH = '/data/' # at least 150GB storage needs to be attached
 os.environ['TRANSFORMERS_CACHE'] = PATH

@@ -70,6 +72,25 @@ def get_most_probable_label(probs):
     probability = f"{round(100 * probs.max(), 2)}%"
     return label, probability
 
+def prepare_heatmap_data(data):
+    heatmap_data = pd.DataFrame(0, index=range(len(data)), columns=emotion_mapping.values())
+    for idx, item in enumerate(data):
+        for idy, confidence in enumerate(item["emotions"]):
+            emotion = emotion_mapping[idy]
+            heatmap_data.at[idx, emotion] = confidence
+    heatmap_data.index = [item["sentence"] for item in data]
+    return heatmap_data
+
+def plot_emotion_heatmap(data):
+    heatmap_data = prepare_heatmap_data(data)
+    fig = plt.figure(figsize=(10, len(data) * 0.5 + 2))
+    sns.heatmap(heatmap_data, annot=True, cmap="coolwarm", cbar=True, linewidths=0.5, linecolor='gray')
+    plt.title("Emotion Confidence Heatmap")
+    plt.xlabel("Emotions")
+    plt.ylabel("Sentences")
+    plt.tight_layout()
+    return fig
+
 def predict_wrapper(text, language):
     model_id = build_huggingface_path(language)
     tokenizer_id = "xlm-roberta-large"

@@ -78,13 +99,17 @@ def predict_wrapper(text, language):
     sentences = split_sentences(text, spacy_model)
 
     results = []
+    results_heatmap = []
     for sentence in sentences:
         probs = predict(sentence, model_id, tokenizer_id)
         label, probability = get_most_probable_label(probs)
         results.append([sentence, label, probability])
+        results_heatmap.append({"sentence": sentence, "emotions": probs})
 
+    figure = plot_emotion_heatmap(results_heatmap)
     output_info = f'Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.'
-    return results, output_info
+    return results, figure, output_info
+
 
 with gr.Blocks() as demo:
     with gr.Row():

@@ -108,7 +133,7 @@ with gr.Blocks() as demo:
     predict_button.click(
         fn=predict_wrapper,
         inputs=[input_text, language_choice],
-        outputs=[result_table, model_info]
+        outputs=[result_table, "plot", model_info]
     )
 
 if __name__ == "__main__":
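The commit only shows these helpers as a diff, so the snippet below is a minimal, self-contained sketch of the data flow they implement: per-sentence probability vectors are collected into a pandas DataFrame and rendered as a seaborn heatmap. The emotion_mapping labels and the probability values are invented for illustration, and pandas is assumed to be imported as pd elsewhere in app.py; neither is visible in this diff.

# Standalone sketch (not part of the commit): the same DataFrame construction and
# plotting calls as prepare_heatmap_data / plot_emotion_heatmap, run on made-up data.
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

emotion_mapping = {0: "anger", 1: "fear", 2: "joy"}  # hypothetical labels, not app.py's real mapping

# Shape of the records predict_wrapper now collects in results_heatmap:
data = [
    {"sentence": "I am thrilled.", "emotions": [0.05, 0.10, 0.85]},
    {"sentence": "This is worrying.", "emotions": [0.20, 0.70, 0.10]},
]

# One row per sentence, one column per emotion, filled with the confidences.
heatmap_data = pd.DataFrame(0.0, index=range(len(data)), columns=list(emotion_mapping.values()))
for idx, item in enumerate(data):
    for idy, confidence in enumerate(item["emotions"]):
        heatmap_data.at[idx, emotion_mapping[idy]] = confidence
heatmap_data.index = [item["sentence"] for item in data]

# Render the heatmap with the same calls as plot_emotion_heatmap, then save a preview file.
fig = plt.figure(figsize=(10, len(data) * 0.5 + 2))
sns.heatmap(heatmap_data, annot=True, cmap="coolwarm", cbar=True, linewidths=0.5, linecolor='gray')
plt.title("Emotion Confidence Heatmap")
plt.xlabel("Emotions")
plt.ylabel("Sentences")
plt.tight_layout()
fig.savefig("heatmap_preview.png")

Initialising the frame with 0.0 rather than 0 keeps the columns as floats, which avoids dtype-upcasting warnings in recent pandas versions when the confidences are written in with .at.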
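On the UI side, the diff routes the figure to the string "plot" in outputs, but the component definitions inside gr.Blocks are not part of this commit, and in a Blocks app the outputs of a click handler are normally component instances rather than strings. The sketch below shows one plausible wiring with a gr.Plot component; the component constructors, the plot_output name, and the stub predict_wrapper are assumptions made for illustration, not the actual layout of app.py.

# Sketch only (assumed layout, not from the commit): wire the returned Matplotlib
# figure to a gr.Plot component instead of the string "plot".
import gradio as gr
import matplotlib.pyplot as plt

def predict_wrapper(text, language):
    # Stand-in for app.py's real predict_wrapper: returns (table rows, figure, info HTML).
    fig = plt.figure()
    plt.plot([0, 1], [0, 1])
    return [[text, "joy", "99.0%"]], fig, f"dummy prediction, language={language}"

with gr.Blocks() as demo:
    with gr.Row():
        input_text = gr.Textbox(lines=6, label="Text")
        language_choice = gr.Dropdown(choices=["english", "hungarian"], label="Language")  # hypothetical choices
    predict_button = gr.Button("Predict")
    result_table = gr.Dataframe(headers=["sentence", "label", "probability"])
    plot_output = gr.Plot(label="Emotion Confidence Heatmap")  # hypothetical component name
    model_info = gr.HTML()

    predict_button.click(
        fn=predict_wrapper,
        inputs=[input_text, language_choice],
        outputs=[result_table, plot_output, model_info],  # component instances, not "plot"
    )

if __name__ == "__main__":
    demo.launch()

If app.py already defines such a plot component, only the outputs list in the click call would need to change.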