---
language: "en"
tags:
- counterfactual generation
widget:
- text: "It is great for kids. <|perturb|> [negation] It [BLANK] great for kids. [SEP]"
---

# Polyjuice

## Model description

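This is a port of [Polyjuice](https://homes.cs.washington.edu/~wtshuang/static/papers/2021-arxiv-polyjuice.pdf), a general-purpose counterfactual generator, that can be loaded with the `transformers` library. Given a sentence, a control code such as `[negation]`, and a copy of the sentence with `[BLANK]` placeholders, the model generates perturbed versions of the sentence.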

### How to use

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# AutoModelWithLMHead is deprecated; the causal-LM auto class replaces it here.
tokenizer = AutoTokenizer.from_pretrained("uw-hai/polyjuice")
model = AutoModelForCausalLM.from_pretrained("uw-hai/polyjuice")

# Prompt format: original sentence, <|perturb|>, control code, sentence with [BLANK].
prompt_text = "A dog is embraced by the woman. <|perturb|> [negation] A dog is [BLANK] the woman."
# or try: "A dog is embraced by the woman. <|perturb|> [restructure] A dog is [BLANK] the woman."
stop_token, end_tok = "\n", "<|endoftext|>"
input_ids = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")

# Beam search without sampling is deterministic, so no temperature is passed
# (temperature only takes effect when do_sample=True).
output_sequences = model.generate(
    input_ids=input_ids,
    max_length=100 + input_ids.shape[1],
    num_beams=10,
    num_return_sequences=3,
)

for generated_sequence in output_sequences:
    # Decode token ids back to text (the decoded text includes the prompt)
    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
    # Keep only the text before the first stop token or end-of-text token
    text = text.split(stop_token, 1)[0].split(end_tok, 1)[0]
    print(text)
```
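
The prompts in this card share one pattern: the original sentence, the `<|perturb|>` marker, a bracketed control code, and a copy of the sentence with `[BLANK]` where the change should go. The following is a minimal sketch of a helper that assembles such a prompt string. `build_prompt` and its parameters are hypothetical names, not part of Polyjuice or `transformers`, and the optional trailing `[SEP]` simply mirrors the widget example in this card's metadata.

```python
def build_prompt(sentence: str, control_code: str, blanked: str, add_sep: bool = False) -> str:
    """Assemble a Polyjuice-style prompt string (hypothetical helper).

    sentence:     the original sentence
    control_code: a code such as "negation" or "restructure", without brackets
    blanked:      the sentence with [BLANK] marking the span to rewrite
    add_sep:      append " [SEP]" as in the widget example (an assumption)
    """
    if "[BLANK]" not in blanked:
        raise ValueError("blanked sentence must contain at least one [BLANK]")
    prompt = f"{sentence} <|perturb|> [{control_code}] {blanked}"
    if add_sep:
        prompt += " [SEP]"
    return prompt


prompt_text = build_prompt(
    "A dog is embraced by the woman.",
    "negation",
    "A dog is [BLANK] the woman.",
)
print(prompt_text)
```

The resulting string can be passed to `tokenizer.encode` in place of the hand-written `prompt_text`.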

### BibTeX entry and citation info

```bibtex
@misc{wu2021polyjuice,
    title={Polyjuice: Automated, General-purpose Counterfactual Generation}, 
    author={Tongshuang Wu and Marco Tulio Ribeiro and Jeffrey Heer and Daniel S. Weld},
    year={2021},
    eprint={2101.00288},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```