# Python T5 base model

Pre-trained model on the CodeSearchNet Python dataset using a span-masking objective. The training objective and model were introduced in this paper and first released in this repository. The PyT5 model was pre-trained with the git-t5 framework, built on top of JAX/Flax, on a TPU v3-8 node.

# How to use

You can use this model to denoise span-masked sequences. Note that you'll need to add some boilerplate code to apply the noise to your sequences.

Add the following code for encoding an input text:
```python
from typing import Dict, Optional, Tuple

import numpy as np
import torch
from transformers import PreTrainedTokenizerBase

from git_t5.data import DataCollatorForT5MLM


def encode(
    tokenizer: PreTrainedTokenizerBase,
    text: str,
    noise_density: float = 0.15,
    mean_noise_span_length: float = 3.0,
    extra_tokens_per_span_inputs: int = 1,
    extra_tokens_per_span_targets: int = 1,
    seed: Optional[int] = None,
) -> Tuple[Dict[str, torch.Tensor], int]:
    def compute_lengths(tokens_length: int) -> Tuple[int, int]:
        num_noise_tokens = int(round(tokens_length * noise_density))
        num_nonnoise_tokens = tokens_length - num_noise_tokens
        num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
        # inputs contain all nonnoise tokens, sentinels for all noise spans
        # and one EOS token.
        return (
            num_nonnoise_tokens + num_noise_spans * extra_tokens_per_span_inputs + 1,
            num_noise_tokens + num_noise_spans * extra_tokens_per_span_targets + 1,
        )

    # Tokenize the raw text without truncation so we know its true length.
    encoding = tokenizer(
        text,
        truncation=False,
        return_attention_mask=False,
        return_length=True,
    )

    input_length = encoding.pop("length")
    input_length = input_length[0]
    input_length, target_length = compute_lengths(input_length)

    np.random.seed(seed)

    # The collator samples noise spans and replaces them with sentinel tokens,
    # producing masked `input_ids` and the corresponding `labels`.
    data_collator = DataCollatorForT5MLM(
        tokenizer=tokenizer,
        noise_density=noise_density,
        mean_noise_span_length=mean_noise_span_length,
        input_length=input_length,
        target_length=target_length,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        decoder_start_token_id=tokenizer.pad_token_id,
        sentinel_token_id=tokenizer.convert_tokens_to_ids("<extra_id_0>"),
    )

    batch = data_collator([encoding])  # type: ignore
    batch = {key: torch.tensor(val) for key, val in batch.items()}

    return batch, target_length
```
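
For intuition, here is a quick sanity check (my own illustration, not part of the original card) of the length arithmetic in `compute_lengths` for a 100-token input with the default settings:

```python
# Illustrative arithmetic only, using the defaults above (noise_density=0.15,
# mean_noise_span_length=3.0, one extra sentinel token per span on each side).
tokens_length = 100
num_noise_tokens = int(round(tokens_length * 0.15))   # 15 masked tokens
num_noise_spans = int(round(num_noise_tokens / 3.0))  # 5 noise spans
input_length = (tokens_length - num_noise_tokens) + num_noise_spans + 1  # 85 + 5 sentinels + EOS = 91
target_length = num_noise_tokens + num_noise_spans + 1                   # 15 + 5 sentinels + EOS = 21
```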

Next, download the model and tokenizer:
```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained("formermagic/pyt5-base")

tokenizer = AutoTokenizer.from_pretrained("formermagic/pyt5-base")
```
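
Optionally, you may want the usual PyTorch inference setup before generating; this is generic Transformers/PyTorch usage rather than anything specific to PyT5:

```python
import torch

# Disable dropout for deterministic inference; optionally move to GPU.
model.eval()
if torch.cuda.is_available():
    model = model.to("cuda")  # the encoded batch must then be moved to the same device
```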

Finally, encode your input and generate the output sequence:
```python
text = """
def alias(self, annotationtype, set, fallback=False):
    if inspect.isclass(annotationtype): annotationtype = annotationtype.ANNOTATIONTYPE
    if annotationtype in self.set_alias and set in self.set_alias[annotationtype]:
        return self.set_alias[annotationtype][set]
    elif fallback:
        return set
    else:
        raise KeyError("No alias for set " + set)
"""

batch, max_length = encode(tokenizer, text, seed=22)
outputs = model.generate(batch["input_ids"], max_length=max_length, num_beams=1)
# Drop the leading decoder-start (pad) token before decoding the prediction.
print(tokenizer.batch_decode(outputs[..., 1:]))
print(tokenizer.batch_decode(batch["labels"]))
```

You should see the following output:
```shell
['<extra_id_0>, fallback=<extra_id_1> inspect<extra_id_2>.set_alias<extra_id_3> return self.set<extra_id_4>) def fallback']
['<extra_id_0>, fallback=<extra_id_1> inspect<extra_id_2>.set_alias<extra_id_3> return self.set<extra_id_4>) </s></s>']
```

As you can see, the predicted result is very close to the target sequence.
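
If you want to turn the predicted spans back into a denoised sequence, one rough approach (the `fill_in_spans` helper below is a hypothetical sketch, not part of git-t5 or Transformers) is to split the prediction on the `<extra_id_*>` sentinels and substitute each span back into the masked input:

```python
import re

def fill_in_spans(masked_text: str, prediction: str) -> str:
    # Text following each sentinel in the prediction is the model's guess
    # for the corresponding masked span.
    spans = re.split(r"<extra_id_\d+>", prediction)[1:]
    result = masked_text
    for i, span in enumerate(spans):
        result = result.replace(f"<extra_id_{i}>", span, 1)
    return result

# You may need to strip padding/EOS tokens from the decoded strings first.
masked_input = tokenizer.batch_decode(batch["input_ids"])[0]
prediction = tokenizer.batch_decode(outputs[..., 1:])[0]
print(fill_in_spans(masked_input, prediction))
```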