ethanlim04 commited on
Commit
967ebb5
1 Parent(s): 3f787be

Upload 3 files

Browse files
Files changed (3) hide show
  1. main.py +24 -0
  2. models.py +52 -0
  3. utils.py +28 -0
main.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import matplotlib
3
+ import models
4
+ import utils
5
+
6
def infer(gt: str, data: str) -> tuple:
    """Compare a model response against a ground-truth text.

    Runs NLI classification (contradiction / entailment / neutral) and a
    sentiment-tone comparison, then prepares the outputs for the Gradio UI.

    Args:
        gt: Ground-truth reference text.
        data: Model response to evaluate against the ground truth.

    Returns:
        A 3-tuple of (verdict text, NLI pie-chart figure, tone bar-chart figure).
    """
    nli_res = models.compute_metric(gt, data)
    tone_res = models.compare_tone(gt, data)

    # The three NLI labels are mutually exclusive, so a single elif chain
    # replaces the original independent if-statements.
    if nli_res["label"] == "neutral":
        res_text = "Model's response is unrelated to the Ground Truth"
    elif nli_res["label"] == "contradiction":
        res_text = "Model's response contradicts the Ground Truth"
    else:  # entailment
        res_text = "Model's response is consistent with the Ground Truth"

    return res_text, utils.create_pie_chart_nli(nli_res), utils.plot_tones(tone_res)
17
+
18
# Demo inputs for the Gradio UI: pairs of (ground truth, model response)
# covering contradiction, neutral and entailment cases.
examples = [
    ["Cross-encoders are better than bi-encoders for analyzing the relationship between texts",
     "Bi-encoders are superior to cross-encoders"],
    ["Cross-encoders are better than bi-encoders for analyzing the relationship between texts",
     "The cosine similarity function can be used to compare the outputs of a bi-encoder"],
    ["Cross-encoders are better than bi-encoders for analyzing the relationship between texts",
     "Bi-encoders are outperformed by cross-encoders in the task of relationship analysis"],
    ["Birds can fly. There are fish in the sea.",
     "Fish inhabit the ocean. Birds can aviate."],
    ["Birds can fly. There are fish in the sea.",
     "Fish inhabit the ocean. Birds can not aviate."],
]

# Wire the comparison function into a simple two-textbox interface with one
# text output and two plot outputs, then start the app.
app = gr.Interface(
    fn=infer,
    inputs=[gr.Textbox(label="Ground Truth"), gr.Textbox(label="Model Response")],
    examples=examples,
    outputs=[gr.Textbox(label="Result"),
             gr.Plot(label="Comparison with GT"),
             gr.Plot(label="Difference in Tone")],
)
app.launch()
models.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import CrossEncoder
2
+ from transformers import AutoModelForSequenceClassification
3
+ from transformers import AutoTokenizer, AutoConfig
4
+ import numpy as np
5
+
6
def softmax(x):
    """Numerically stable softmax over a 1-D array of logits.

    Subtracting the max before exponentiating prevents overflow for
    large logit values without changing the result.
    """
    shifted = np.exp(x - np.max(x))
    return shifted / np.sum(shifted, axis=0)
9
+
10
# NLI cross-encoder, built once at import time (heavyweight: downloads the
# checkpoint on first run). Reports 90.04% accuracy on the MNLI mismatched set.
nli_model = CrossEncoder('cross-encoder/nli-deberta-v3-base')
12
+
13
def compute_metric(ground_truth: str, inference: str) -> dict:
    """Score the NLI relationship between a ground truth and an inference.

    Returns a dict with the winning 'label' plus the softmax probability
    of each of the three NLI classes.
    """
    probs = nli_model.predict([ground_truth, inference], apply_softmax=True)
    # Class order matches the cross-encoder's output head.
    class_names = ('contradiction', 'entailment', 'neutral')
    result = {'label': class_names[int(probs.argmax())]}
    result.update(zip(class_names, probs))
    return result
22
+
23
# Sentiment-model components, loaded lazily and cached so the tokenizer,
# config and model are downloaded/built only once instead of on every call.
_SENTIMENT_MODEL_NAME = r"cardiffnlp/twitter-roberta-base-sentiment-latest"
_sentiment_components = None


def _get_sentiment_components():
    """Load (once) and return (tokenizer, config, model) for the sentiment model."""
    global _sentiment_components
    if _sentiment_components is None:
        _sentiment_components = (
            AutoTokenizer.from_pretrained(_SENTIMENT_MODEL_NAME),
            AutoConfig.from_pretrained(_SENTIMENT_MODEL_NAME),
            AutoModelForSequenceClassification.from_pretrained(_SENTIMENT_MODEL_NAME),
        )
    return _sentiment_components


def _compare_tone(text: str) -> dict:
    """Classify the sentiment of *text*.

    Uses a RoBERTa classifier trained on ~124M tweets.

    Returns:
        {sentiment label: probability rounded to 4 decimals}, inserted in
        descending order of probability.
    """
    tokenizer, config, model = _get_sentiment_components()

    encoded_input = tokenizer(text, return_tensors='pt')
    output = model(**encoded_input)
    scores = softmax(output[0][0].detach().numpy())

    result = {}
    # Iterate label indices from highest to lowest score.
    for idx in np.argsort(scores)[::-1]:
        result[config.id2label[idx]] = np.round(float(scores[idx]), 4)

    return result
43
+
44
def compare_tone(ground_truth: str, inference: str) -> dict:
    """Run sentiment analysis on both texts and bundle the results.

    Returns:
        {"gt": <ground-truth tone scores>, "model": <inference tone scores>}.
    """
    return {
        "gt": _compare_tone(ground_truth),
        "model": _compare_tone(inference),
    }
48
+
49
if __name__ == "__main__":
    # Manual smoke test: one entailment pair, one contradiction pair,
    # then a tone comparison between two differently-toned texts.
    premise = "Foxes are closer to dogs than they are to cats. Therefore, foxes are not cats."
    for hypothesis in ("Foxes are not cats.", "Foxes are cats."):
        print(compute_metric(premise, hypothesis))
    print(compare_tone("This is neutural", "Wtf"))
utils.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+ import matplotlib.pyplot as plt
3
+
4
def create_pie_chart_nli(data: dict) -> matplotlib.figure.Figure:
    """Render the NLI class probabilities as a pie chart.

    Args:
        data: Mapping with at least the keys "neutral", "contradiction"
            and "entailment" holding numeric scores that sum to ~1.

    Returns:
        The matplotlib Figure holding the chart. (The original annotation
        was the `matplotlib.figure` module, not the Figure type.)
    """
    labels = ["neutral", "contradiction", "entailment"]
    sizes = [data[label] for label in labels]
    # gray = unrelated, red = contradicts, green = entails
    colors = ["gray", "red", "green"]

    fig, ax = plt.subplots()

    ax.set_title("Comparison with GT")
    ax.pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%')

    # Equal aspect ratio keeps the pie circular.
    ax.axis('equal')

    return fig
17
+
18
def plot_tones(data: dict) -> matplotlib.figure.Figure:
    """Render overlaid bar charts comparing the tone of GT vs. model response.

    Args:
        data: {"gt": {label: score}, "model": {label: score}} — both inner
            dicts are assumed to share the same label keys (as produced by
            models.compare_tone).

    Returns:
        The matplotlib Figure holding the chart. (The original annotation
        was the `matplotlib.figure` module, not the Figure type.)
    """
    keys = data["gt"].keys()

    fig, ax = plt.subplots()
    ax.set_title("Tone")
    # Model bars are narrower and semi-transparent so both series stay visible.
    ax.bar(x=keys, height=[data["gt"][key] for key in keys], color="b", label="Ground Truth", width=0.7)
    ax.bar(x=keys, height=[data["model"][key] for key in keys], color="r", alpha=0.5, label="Model response", width=0.5)

    fig.legend()

    return fig