Locutusque
commited on
Commit
·
e63b202
1
Parent(s):
75f80ec
Update README.md
Browse files
README.md
CHANGED
@@ -92,7 +92,7 @@ model = GPT2LMHeadModel.from_pretrained('gpt2-conversational-retrain')
|
|
92 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
93 |
model.to(device)
|
94 |
def generate_text(model, tokenizer, prompt, max_length=1024):
|
95 |
-
prompt = f'<|
|
96 |
input_ids = tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt").to(device)
|
97 |
attention_mask = torch.ones_like(input_ids).to(device)
|
98 |
output = model.generate(input_ids,
|
|
|
92 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
93 |
model.to(device)
|
94 |
def generate_text(model, tokenizer, prompt, max_length=1024):
|
95 |
+
prompt = f'<|USER|> {prompt} <|ASSISTANT|> '
|
96 |
input_ids = tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt").to(device)
|
97 |
attention_mask = torch.ones_like(input_ids).to(device)
|
98 |
output = model.generate(input_ids,
|