krum-utsav commited on
Commit
5fc3c05
1 Parent(s): 66c5c22

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +0 -67
README.md CHANGED
@@ -38,73 +38,6 @@ paraphraser.paraphrase("Hey, can yuo hepl me cancel my last order?", tone="witty
38
  # "Hey, I need your help with my last order. Can you wave your magic wand and make it disappear?"
39
  ```
40
 
41
- OR use directly with transformers
42
-
43
- ```python
44
- from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
45
-
46
- model_id = "llm-toys/RedPajama-INCITE-Base-3B-v1-paraphrase-tone"
47
- DEVICE = "cuda"
48
- EOC_FORMAT = "\n\n### END"
49
- PARAPHRASE_PREDICT_FORMAT = (
50
- "### Instruction:\nGenerate a paraphrase for the following Input sentence.\n\n"
51
- "### Input:\n{input_text}\n\n### Response:\n"
52
- )
53
- TONE_CHANGE_PREDICT_FORMAT = (
54
- "### Instruction:\nChange the tone of the following Input sentence to {tone}.\n\n"
55
- "### Input:\n{input_text}\n\n### Response:\n"
56
- )
57
-
58
- tokenizer = AutoTokenizer.from_pretrained(model_id)
59
- model = AutoModel.from_pretrained(mode_id).to(DEVICE)
60
-
61
- class StoppingCriteriaSub(StoppingCriteria):
62
- """Helps in stopping the generation when a certain sequence of tokens is generated."""
63
-
64
- def __init__(self, stops: list = []):
65
- super().__init__()
66
- self.stops = stops
67
-
68
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> bool:
69
- return input_ids[0][-len(self.stops) :].tolist() == self.stops
70
-
71
-
72
- stopping_criteria = StoppingCriteriaList(
73
- [StoppingCriteriaSub(stops=tokenizer(EOC_FORMAT)["input_ids"])]
74
- )
75
-
76
- def predict(input_text: str) -> str:
77
- tokenized = tokenizer(
78
- input_text,
79
- max_length=self.max_length,
80
- padding=True,
81
- truncation=True,
82
- return_tensors="pt",
83
- )
84
-
85
- with torch.no_grad():
86
- out = model.generate(
87
- input_ids=tokenized["input_ids"].to(DEVICE),
88
- attention_mask=tokenized["attention_mask"].to(DEVICE),
89
- pad_token_id=self.tokenizer.eos_token_id,
90
- max_new_tokens=max_new_tokens,
91
- num_return_sequences=num_return_sequences,
92
- do_sample=True,
93
- temperature=temperature,
94
- top_p=top_p,
95
- stopping_criteria=self.stopping_criteria,
96
- )
97
-
98
- out_texts = [tokenizer.decode(o, skip_special_tokens=True) for o in out]
99
- return out_texts
100
-
101
- print("Paraphrasing:")
102
- print(predict(PARAPHRASE_PREDICT_FORMAT.format(input_text="If you have any further questions, feel free to ask.")))
103
-
104
- print("Tone change:")
105
- print(predict(TONE_CHANGE_PREDICT_FORMAT.format(input_text="If you have any further questions, feel free to ask.", tone="professional")))
106
- ```
107
-
108
  ## Sample training data
109
 
110
  ```json
 
38
  # "Hey, I need your help with my last order. Can you wave your magic wand and make it disappear?"
39
  ```
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  ## Sample training data
42
 
43
  ```json