avishek-018 commited on
Commit
ec9a5c9
1 Parent(s): 849bcc6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import from_pretrained_keras
2
+ import numpy as np
3
+ import gradio as gr
4
+ import transformers
5
+ import tensorflow as tf
6
+
7
+ class BertSemanticDataGenerator(tf.keras.utils.Sequence):
8
+ """Generates batches of data."""
9
+ def __init__(
10
+ self,
11
+ sentence_pairs,
12
+ labels,
13
+ batch_size=32,
14
+ shuffle=True,
15
+ include_targets=True,
16
+ ):
17
+ self.sentence_pairs = sentence_pairs
18
+ self.labels = labels
19
+ self.shuffle = shuffle
20
+ self.batch_size = batch_size
21
+ self.include_targets = include_targets
22
+ # Load our BERT Tokenizer to encode the text.
23
+ # We will use base-base-uncased pretrained model.
24
+ self.tokenizer = transformers.BertTokenizer.from_pretrained(
25
+ "bert-base-uncased", do_lower_case=True
26
+ )
27
+ self.indexes = np.arange(len(self.sentence_pairs))
28
+ self.on_epoch_end()
29
+
30
+ def __len__(self):
31
+ # Denotes the number of batches per epoch.
32
+ return len(self.sentence_pairs) // self.batch_size
33
+
34
+ def __getitem__(self, idx):
35
+ # Retrieves the batch of index.
36
+ indexes = self.indexes[idx * self.batch_size : (idx + 1) * self.batch_size]
37
+ sentence_pairs = self.sentence_pairs[indexes]
38
+
39
+ # With BERT tokenizer's batch_encode_plus batch of both the sentences are
40
+ # encoded together and separated by [SEP] token.
41
+ encoded = self.tokenizer.batch_encode_plus(
42
+ sentence_pairs.tolist(),
43
+ add_special_tokens=True,
44
+ max_length=128,
45
+ return_attention_mask=True,
46
+ return_token_type_ids=True,
47
+ pad_to_max_length=True,
48
+ return_tensors="tf",
49
+ )
50
+
51
+ # Convert batch of encoded features to numpy array.
52
+ input_ids = np.array(encoded["input_ids"], dtype="int32")
53
+ attention_masks = np.array(encoded["attention_mask"], dtype="int32")
54
+ token_type_ids = np.array(encoded["token_type_ids"], dtype="int32")
55
+
56
+ # Set to true if data generator is used for training/validation.
57
+ if self.include_targets:
58
+ labels = np.array(self.labels[indexes], dtype="int32")
59
+ return [input_ids, attention_masks, token_type_ids], labels
60
+ else:
61
+ return [input_ids, attention_masks, token_type_ids]
62
+
63
+ model = from_pretrained_keras("avishek-018/bert-semantic-similarity")
64
+ labels = ["contradiction", "entailment", "neutral"]
65
+
66
+ def predict(sentence1, sentence2):
67
+ sentence_pairs = np.array([[str(sentence1), str(sentence2)]])
68
+ test_data = BertSemanticDataGenerator(
69
+ sentence_pairs, labels=None, batch_size=1, shuffle=False, include_targets=False,
70
+ )
71
+ probs = model.predict(test_data[0])[0]
72
+
73
+ labels_probs = {labels[i]: float(probs[i]) for i, _ in enumerate(labels)}
74
+ return labels_probs
75
+
76
+ #idx = np.argmax(proba)
77
+ #proba = f"{proba[idx]*100:.2f}%"
78
+ #pred = labels[idx]
79
+ #return f'The semantic similarity of two input sentences is {pred} with {proba} of probability'
80
+
81
+ inputs = [
82
+ gr.Audio(source = "upload", label='Upload audio file', type="filepath"),
83
+ ]
84
+
85
+ examples = [["Two women are observing something together.", "Two women are standing with their eyes closed."],
86
+ ["A smiling costumed woman is holding an umbrella", "A happy woman in a fairy costume holds an umbrella"],
87
+ ["A soccer game with multiple males playing", "Some men are playing a sport"],
88
+ ]
89
+
90
+ gr.Interface(
91
+ fn=predict,
92
+ title="Semantic Similarity with BERT",
93
+ description = "Natural Language Inference by fine-tuning BERT model on SNLI Corpus 📰",
94
+ inputs=["text", "text"],
95
+ examples=examples,
96
+ #outputs=gr.Textbox(label='Prediction'),
97
+ outputs=gr.outputs.Label(num_top_classes=3, label='Semantic similarity'),
98
+ cache_examples=True,
99
+ ).launch(debug=True, enable_queue=True)