ethanlim04 committed
Commit 967ebb5
Parent(s): 3f787be

Upload 3 files
main.py
ADDED
import gradio as gr
from matplotlib.figure import Figure

import models
import utils


def infer(gt: str, data: str) -> tuple[str, Figure, Figure]:
    """Judge a model response against the ground truth: NLI verdict plus tone comparison."""
    nli_res = models.compute_metric(gt, data)
    tone_res = models.compare_tone(gt, data)

    if nli_res["label"] == "neutral":
        res_text = "Model's response is unrelated to the Ground Truth"
    elif nli_res["label"] == "contradiction":
        res_text = "Model's response contradicts the Ground Truth"
    else:  # entailment
        res_text = "Model's response is consistent with the Ground Truth"

    return res_text, utils.create_pie_chart_nli(nli_res), utils.plot_tones(tone_res)


examples = [
    ["Cross-encoders are better than bi-encoders for analyzing the relationship between texts",
     "Bi-encoders are superior to cross-encoders"],
    ["Cross-encoders are better than bi-encoders for analyzing the relationship between texts",
     "The cosine similarity function can be used to compare the outputs of a bi-encoder"],
    ["Cross-encoders are better than bi-encoders for analyzing the relationship between texts",
     "Bi-encoders are outperformed by cross-encoders in the task of relationship analysis"],
    ["Birds can fly. There are fish in the sea.", "Fish inhabit the ocean. Birds can aviate."],
    ["Birds can fly. There are fish in the sea.", "Fish inhabit the ocean. Birds can not aviate."],
]

app = gr.Interface(
    fn=infer,
    inputs=[gr.Textbox(label="Ground Truth"), gr.Textbox(label="Model Response")],
    outputs=[
        gr.Textbox(label="Result"),
        gr.Plot(label="Comparison with GT"),
        gr.Plot(label="Difference in Tone"),
    ],
    examples=examples,
)
app.launch()
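Running python main.py starts the demo on Gradio's default local port (7860). As a minimal variation, and assuming a temporary public link is acceptable for this demo, the launch call can be swapped for:

app.launch(share=True)  # assumption: expose the demo via a temporary Gradio share link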
models.py
ADDED
import numpy as np
from sentence_transformers import CrossEncoder
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer


def softmax(x):
    # Numerically stable softmax: subtracting the max avoids overflow in exp.
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)


# NLI cross-encoder (90.04% accuracy on the MNLI mismatched set).
nli_model = CrossEncoder('cross-encoder/nli-deberta-v3-base')

# Sentiment model trained on ~124M tweets. Loaded once at module level so each
# call to compare_tone does not re-initialize the tokenizer, config, and weights.
sentiment_model_name = "cardiffnlp/twitter-roberta-base-sentiment-latest"
tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)
config = AutoConfig.from_pretrained(sentiment_model_name)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(sentiment_model_name)


def compute_metric(ground_truth: str, inference: str) -> dict:
    # A flat [premise, hypothesis] list is treated as a single pair by
    # CrossEncoder.predict; apply_softmax turns the three logits into probabilities.
    scores = nli_model.predict([ground_truth, inference], apply_softmax=True)
    label = ['contradiction', 'entailment', 'neutral'][scores.argmax()]
    return {
        'label': label,
        'contradiction': scores[0],
        'entailment': scores[1],
        'neutral': scores[2],
    }


def _analyze_tone(text: str) -> dict:
    """Score the sentiment of one text, returning labels ordered most to least likely."""
    encoded_input = tokenizer(text, return_tensors='pt')
    output = sentiment_model(**encoded_input)
    scores = softmax(output[0][0].detach().numpy())

    ranking = np.argsort(scores)[::-1]  # indices sorted by descending probability
    result = {}
    for i in range(scores.shape[0]):
        label = config.id2label[int(ranking[i])]
        result[label] = np.round(float(scores[ranking[i]]), 4)

    return result


def compare_tone(ground_truth: str, inference: str) -> dict:
    return {"gt": _analyze_tone(ground_truth), "model": _analyze_tone(inference)}


if __name__ == "__main__":
    print(compute_metric("Foxes are closer to dogs than they are to cats. Therefore, foxes are not cats.", "Foxes are not cats."))
    print(compute_metric("Foxes are closer to dogs than they are to cats. Therefore, foxes are not cats.", "Foxes are cats."))
    print(compare_tone("This is neutral", "Wtf"))
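For quick reference, a minimal sketch of calling these helpers from another script. Exact scores depend on the downloaded checkpoints, and the tone keys come from the sentiment model's id2label config (negative / neutral / positive for this checkpoint):

from models import compute_metric, compare_tone

res = compute_metric("Birds can fly.", "Birds are able to fly.")
print(res["label"])       # one of: contradiction / entailment / neutral
print(res["entailment"])  # probability assigned to entailment

tones = compare_tone("Birds can fly.", "Birds can not fly.")
print(tones["gt"], tones["model"])  # per-label sentiment probabilities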
utils.py
ADDED
import matplotlib.pyplot as plt
from matplotlib.figure import Figure


def create_pie_chart_nli(data: dict) -> Figure:
    """Pie chart of the NLI probabilities (neutral / contradiction / entailment)."""
    labels = ["neutral", "contradiction", "entailment"]
    sizes = [data[label] for label in labels]
    colors = ["gray", "red", "green"]

    fig, ax = plt.subplots()
    ax.set_title("Comparison with GT")
    ax.pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%')
    ax.axis('equal')  # draw the pie as a circle
    return fig


def plot_tones(data: dict) -> Figure:
    """Overlaid bar chart comparing the tone of the ground truth and the model response."""
    keys = list(data["gt"].keys())

    fig, ax = plt.subplots()
    ax.set_title("Tone")
    # Ground-truth bars are drawn wider so the narrower, semi-transparent
    # model-response bars stay visible on top of them.
    ax.bar(x=keys, height=[data["gt"][key] for key in keys], color="b", label="Ground Truth", width=0.7)
    ax.bar(x=keys, height=[data["model"][key] for key in keys], color="r", alpha=0.5, label="Model response", width=0.5)
    fig.legend()
    return fig
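The plotting helpers only assume the dict shapes produced in models.py, so they can be exercised standalone. A short sketch with hypothetical, made-up scores for illustration:

import utils

# Hypothetical scores shaped like compute_metric's output.
nli_scores = {"neutral": 0.10, "contradiction": 0.05, "entailment": 0.85}

# Hypothetical scores shaped like compare_tone's output.
tone_scores = {
    "gt": {"negative": 0.05, "neutral": 0.80, "positive": 0.15},
    "model": {"negative": 0.70, "neutral": 0.20, "positive": 0.10},
}

utils.create_pie_chart_nli(nli_scores).savefig("nli.png")
utils.plot_tones(tone_scores).savefig("tones.png")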