paavansundar committed on
Commit 718dddc
1 Parent(s): c4dbf79

Update app.py

Files changed (1)
  1. app.py +24 -1
app.py CHANGED
@@ -9,8 +9,31 @@ import torch
 from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling
 from transformers import Trainer, TrainingArguments
 
+__checkpoint = "gpt2"
+__tokenizer = GPT2Tokenizer.from_pretrained(__checkpoint)
+__model = GPT2LMHeadModel.from_pretrained(__checkpoint)
+# Create a data collator object
+data_collator = DataCollatorForLanguageModeling(tokenizer=__tokenizer, mlm=False, return_tensors="pt")
 def queryGPT(question):
-    return "<question>"+question+"<Answer>Test"
+    return generate_response(__model, __tokenizer, question)
+
+def generate_response(__model, __tokenizer, prompt, max_length=200):
+
+    input_ids = __tokenizer.encode(prompt, return_tensors="pt")  # 'pt' returns a PyTorch tensor
+
+    # Create the attention mask and pad token id
+    attention_mask = torch.ones_like(input_ids)
+    pad_token_id = __tokenizer.eos_token_id
+
+    output = __model.generate(
+        input_ids,
+        max_length=max_length,
+        num_return_sequences=1,
+        attention_mask=attention_mask,
+        pad_token_id=pad_token_id,
+    )
+
+    return __tokenizer.decode(output[0], skip_special_tokens=True)
 
 with gr.Blocks() as demo:
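
The hunk ends at the opening line of the gr.Blocks context, so the demo wiring itself is not shown. As a minimal hypothetical sketch (not part of this commit), the updated queryGPT could be hooked up as follows; the widget names question_box, answer_box, and ask_btn are illustrative assumptions:

import gradio as gr

# Hypothetical sketch: queryGPT is the function defined in app.py above.
# Widget names are assumptions, not taken from this commit.
with gr.Blocks() as demo:
    question_box = gr.Textbox(label="Question")
    ask_btn = gr.Button("Ask")
    answer_box = gr.Textbox(label="Answer")
    # On click, pass the question text to queryGPT and show the generated answer
    ask_btn.click(fn=queryGPT, inputs=question_box, outputs=answer_box)

demo.launch()

Any layout works here; the only requirement is that the input and output components match queryGPT's single string argument and string return value.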
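
Note that data_collator (and the TextDataset/Trainer/TrainingArguments imports) are never used in this hunk; they only make sense alongside a fine-tuning step. A hedged sketch of the conventional wiring, assuming the names defined in app.py above plus an assumed local training file train.txt:

from transformers import TextDataset, Trainer, TrainingArguments

# Hypothetical sketch, not part of this commit. __tokenizer, __model, and
# data_collator are the objects created in app.py above; "train.txt" and
# the TrainingArguments values are assumptions for illustration.
train_dataset = TextDataset(tokenizer=__tokenizer, file_path="train.txt", block_size=128)

training_args = TrainingArguments(
    output_dir="gpt2-finetuned",    # assumed output directory
    num_train_epochs=1,             # assumed
    per_device_train_batch_size=2,  # assumed
)

trainer = Trainer(
    model=__model,
    args=training_args,
    data_collator=data_collator,    # the collator created in this commit
    train_dataset=train_dataset,
)
trainer.train()

The mlm=False flag on the collator is what makes this a causal (GPT-2-style) language-modeling objective rather than masked-LM.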