# toxicity_debias_pipeline / my_toxicity_debiaser.py
# Author: shainaraza
# Last commit: 32bed4f ("Update my_toxicity_debiaser.py")
import torch
class MyToxicityDebiaserPipeline:
    """Classify a comment as toxic / non-toxic and, if toxic, rewrite it politely.

    Follows a pipeline-style ``preprocess -> _forward -> postprocess`` flow:
    a sequence classifier (``model`` + ``tokenizer``) predicts a label, and
    for label 1 (toxic) a generative model (``gpt_model`` + ``gpt_tokenizer``)
    is prompted to produce a debiased rewrite of the text.

    NOTE(review): the model/tokenizer objects are presumably Hugging Face
    ``transformers`` objects; only their ``__call__``/``decode``/``generate``/
    ``to`` interfaces and the ``.logits`` output attribute are used here.
    """

    # Instruction sent to the generative model; the comment text is appended.
    _DEBIAS_INSTRUCTION = (
        "Remove the offensive words and biased tone and write the same sentence nicely: "
    )

    def __init__(self, model, tokenizer, gpt_model, gpt_tokenizer, device=None, **kwargs):
        """Store the models/tokenizers and move both models to one device.

        Args:
            model: sequence classifier whose output exposes ``.logits``.
            tokenizer: tokenizer matching ``model``.
            gpt_model: generative model exposing ``generate``.
            gpt_tokenizer: tokenizer matching ``gpt_model``.
            device: target device (anything ``torch.device`` accepts);
                ``None`` means CPU.
            **kwargs: accepted for interface compatibility and ignored.
        """
        # Resolve the device *before* moving the models, so that the models
        # and the input tensors built later always share the same device.
        self.device = torch.device(device) if device is not None else torch.device("cpu")
        self.model = model.to(self.device)
        self.tokenizer = tokenizer
        self.gpt_model = gpt_model.to(self.device)
        self.gpt_tokenizer = gpt_tokenizer

    def _forward(self, inputs, **kwargs):
        """Run the toxicity classifier on ``inputs["text"]``.

        Args:
            inputs: dict produced by ``preprocess`` with key ``"text"``.
            **kwargs: forward parameters from ``_sanitize_parameters``;
                currently unused.

        Returns:
            dict with ``label`` (int argmax class), ``probabilities`` (softmax
            as nested lists), ``text_input_ids`` (classifier token ids) and
            ``text`` (the raw input).
        """
        text = inputs["text"]
        encoded = self.tokenizer(
            text, truncation=True, padding=True, return_tensors="pt"
        ).to(self.device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            logits = self.model(encoded["input_ids"], encoded["attention_mask"]).logits
        probs = torch.softmax(logits, dim=-1)
        # .item() requires exactly one prediction, i.e. a single input text.
        label = torch.argmax(probs, dim=-1).item()
        return {
            "label": label,
            "probabilities": probs.tolist(),
            "text_input_ids": encoded["input_ids"],
            "text": text,
        }

    def _sanitize_parameters(self, **kwargs):
        """Split call kwargs into (preprocess, forward, postprocess) params.

        All kwargs are routed to the preprocess slot; forward and postprocess
        get none.
        """
        return kwargs, {}, {}

    def preprocess(self, inputs):
        """Wrap the raw input text in the dict shape ``_forward`` expects."""
        return {"text": inputs}

    def _debias(self, text):
        """Prompt the generative model to rewrite *text* and return only the rewrite."""
        encoded = self.gpt_tokenizer(
            self._DEBIAS_INSTRUCTION + text, return_tensors="pt"
        ).to(self.device)
        input_ids = encoded["input_ids"]
        # GPT-2 style tokenizers have no pad token; fall back to EOS.
        pad_token_id = self.gpt_tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.gpt_tokenizer.eos_token_id
        generated = self.gpt_model.generate(
            input_ids=input_ids,
            attention_mask=encoded["attention_mask"],
            do_sample=True,
            # max_new_tokens, unlike max_length, does not count the prompt,
            # so long comments still leave room for the rewrite.
            max_new_tokens=100,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=pad_token_id,
            eos_token_id=self.gpt_tokenizer.eos_token_id,
        )
        sequence = generated[0]
        # Decoder-only models return prompt + continuation; drop the prompt so
        # the instruction is not echoed back as the "debiased" text.
        config = getattr(self.gpt_model, "config", None)
        if not getattr(config, "is_encoder_decoder", False):
            sequence = sequence[input_ids.shape[-1]:]
        return self.gpt_tokenizer.decode(sequence, skip_special_tokens=True).strip()

    def postprocess(self, outputs):
        """Turn classifier outputs into a human-readable message.

        Label 0 yields ``"This comment is non-toxic."``. Label 1 yields a
        message containing the original text and a generated debiased rewrite.

        Raises:
            ValueError: if the classifier predicted any other label.
        """
        label = outputs["label"]
        if label == 0:
            return "This comment is non-toxic."
        if label != 1:
            raise ValueError(
                f"Unexpected classifier label {label!r}; expected 0 (non-toxic) or 1 (toxic)."
            )
        # Prefer the untouched input text; decoding token ids would add
        # special tokens and lose casing/content removed by truncation.
        text = outputs.get("text")
        if not isinstance(text, str):
            text = self.tokenizer.decode(
                outputs["text_input_ids"][0], skip_special_tokens=True
            )
        generated_text = self._debias(text)
        return (
            "This comment is toxic but has been debiased as follows:"
            f"\nOriginal text: {text}\nDebiased text: {generated_text}"
        )

    def __call__(self, inputs, *args, **kwargs):
        """Classify *inputs* and return the formatted (possibly debiased) message."""
        _args, _kwargs, forward_kwargs = self._sanitize_parameters(*args, **kwargs)
        inputs = self.preprocess(inputs)
        outputs = self._forward(inputs, **forward_kwargs)
        return self.postprocess(outputs)