patrickvonplaten committed
Commit 685ce0f
Parent(s): cf42a95

up

Files changed:
- create_confidence_scores.py +30 -0
- example.py +21 -0
- get_sample_code.py +29 -0
create_confidence_scores.py
ADDED
@@ -0,0 +1,30 @@
+#!/usr/bin/env python3
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+from datasets import load_dataset
+import datasets
+import torch
+
+model = Wav2Vec2ForCTC.from_pretrained("facebook/data2vec-audio-base-10m")
+processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-10m")
+
+minds14 = load_dataset("PolyAI/minds14", "en-US", split="train")
+minds14 = minds14.cast_column("audio", datasets.Audio(sampling_rate=16_000))
+
+input_values = processor(minds14[0]["audio"]["array"], return_tensors="pt", sampling_rate=minds14[0]["audio"]["sampling_rate"]).input_values
+
+with torch.no_grad():
+    logits = model(input_values).logits
+scores = torch.nn.functional.softmax(logits, dim=-1)
+pred_ids = torch.argmax(logits, dim=-1)
+pred_scores = scores.gather(-1, pred_ids.unsqueeze(-1))[:, :, 0]  # probability of the predicted token at each frame
+
+output = processor.batch_decode(pred_ids, output_word_offsets=True)
+
+# add a per-word confidence: mean predicted-token probability over the word's frames
+def confidence_score(word_dict):
+    probs = pred_scores[0, word_dict["start_offset"]: word_dict["end_offset"]]
+    return torch.mean(probs)
+
+output["confidence_scores"] = {d["word"]: confidence_score(d) for d in output.word_offsets[0]}
+
+print(output["confidence_scores"])
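A note on the confidence logic above: softmax turns the CTC logits into per-frame probability distributions over the vocabulary, gather(-1, ...) picks out the probability of the predicted token at each frame, and confidence_score averages those probabilities over the frames a word spans (via its start_offset/end_offset). A minimal self-contained sketch of the gather step, on toy tensors rather than real model outputs:

    import torch

    logits = torch.tensor([[[2.0, 0.5, 0.1],   # 1 batch, 4 frames, 3-token vocab
                            [0.1, 3.0, 0.2],
                            [0.3, 0.1, 2.5],
                            [1.5, 0.2, 0.4]]])
    scores = torch.nn.functional.softmax(logits, dim=-1)              # (1, 4, 3)
    pred_ids = torch.argmax(logits, dim=-1)                           # (1, 4)
    pred_scores = scores.gather(-1, pred_ids.unsqueeze(-1))[:, :, 0]  # (1, 4)

    # a word covering frames [1, 3) gets the mean of those frame probabilities
    print(pred_scores, pred_scores[0, 1:3].mean())

One caveat worth flagging: keying the result dict by d["word"] silently collapses repeated words; a list of (word, score) pairs would keep duplicates.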
example.py
ADDED
@@ -0,0 +1,21 @@
+#!/usr/bin/env python3
+from transformers import RobertaTokenizer, RobertaForTokenClassification
+import torch
+tokenizer = RobertaTokenizer.from_pretrained("Jean-Baptiste/roberta-large-ner-english")
+model = RobertaForTokenClassification.from_pretrained("Jean-Baptiste/roberta-large-ner-english")
+inputs = tokenizer("HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt")
+
+with torch.no_grad():
+    logits = model(**inputs).logits
+
+predicted_token_class_ids = logits.argmax(-1)
+# Note that tokens are classified rather than input words, which means that
+# there might be more predicted token classes than words.
+# Multiple token classes might account for the same word.
+predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
+assert predicted_tokens_classes == ['O', 'ORG', 'ORG', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O', 'LOC', 'LOC']
+
+labels = predicted_token_class_ids
+loss = model(**inputs, labels=labels).loss
+rounded_loss = round(loss.item(), 2)
+print(rounded_loss)
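Since RobertaTokenizer splits words into sub-word tokens, the assert above holds 12 labels for a 10-word sentence; printing each token next to its label makes the alignment visible. A small standalone sketch under the same checkpoint assumption, using only public tokenizer/model APIs:

    from transformers import RobertaTokenizer, RobertaForTokenClassification
    import torch

    tokenizer = RobertaTokenizer.from_pretrained("Jean-Baptiste/roberta-large-ner-english")
    model = RobertaForTokenClassification.from_pretrained("Jean-Baptiste/roberta-large-ner-english")
    inputs = tokenizer("HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits

    # pair every sub-word token with its predicted entity label
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    for token, class_id in zip(tokens, logits.argmax(-1)[0]):
        print(f"{token!r:>12} -> {model.config.id2label[class_id.item()]}")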
get_sample_code.py
ADDED
@@ -0,0 +1,29 @@
+#!/usr/bin/env python3
+import sys
+name = sys.argv[1]
+processor_class = sys.argv[2]
+model_class = sys.argv[3]
+checkpoint = sys.argv[4]
+mask = sys.argv[5]
+
+with open("/home/patrick/transformers/src/transformers/file_utils.py", "r") as f:
+    lines = f.readlines()
+
+format_dict = {"processor_class": processor_class, "model_class": model_class, "checkpoint": checkpoint, "mask": mask}
+with open("./example.py", "w") as f:
+    f.write("#!/usr/bin/env python3\n")
+
+    is_in = False
+    is_in_code = False
+    for line in lines:
+        if line.strip() == (name + ' = r"""'):
+            is_in = True
+        if is_in and "```python" in line:
+            is_in_code = True
+        if is_in_code:
+            if ">>>" in line:
+                f.write(line.split(">>> ")[-1].format(**format_dict))
+            elif "..." in line:
+                f.write(line.split("... ")[-1].format(**format_dict))
+        if is_in_code and (line.strip() == '"""'):
+            is_in = is_in_code = False
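What get_sample_code.py does: it scans transformers' file_utils.py for the docstring template named by argv[1], enters copy mode at the opening ```python fence, strips the >>> and ... doctest prompts, fills the {processor_class}/{model_class}/{checkpoint}/{mask} placeholders with str.format, and writes the result to example.py (presumably how the example.py above was generated). A hypothetical invocation, with a guessed template name not taken from the commit:

    python get_sample_code.py PT_TOKEN_CLASSIFICATION_SAMPLE RobertaTokenizer RobertaForTokenClassification Jean-Baptiste/roberta-large-ner-english "<mask>"

And a toy sketch of the same prompt-stripping and templating idea, with an invented docstring:

    doc = '''EXAMPLE = r"""
    >>> from transformers import {processor_class}
    >>> model = {model_class}.from_pretrained("{checkpoint}")
    >>> outputs = model(
    ...     input_values
    ... )
    """'''
    fills = {"processor_class": "RobertaTokenizer", "model_class": "RobertaForTokenClassification",
             "checkpoint": "Jean-Baptiste/roberta-large-ner-english"}
    for line in doc.splitlines():
        if ">>>" in line:
            print(line.split(">>> ")[-1].format(**fills))
        elif "..." in line:
            print(line.split("... ")[-1].format(**fills))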