paavansundar committed
Commit 15bac75
1 Parent(s): 9869827

Update app.py

Files changed (1): app.py +37 -4
app.py CHANGED
@@ -10,6 +10,8 @@ __checkpoint = "gpt2"
 __tokenizer = GPT2Tokenizer.from_pretrained(__checkpoint)
 __model = GPT2LMHeadModel.from_pretrained(__checkpoint)
 __model_output_path = "paavansundar/Medical_QNA_GPT2"
+# Create a Data collator object
+__data_collator = DataCollatorForLanguageModeling(tokenizer=__tokenizer, mlm=False, return_tensors="pt")
 
 #prepare data
 def prepareData():
@@ -35,14 +37,44 @@ def prepareData():
     with open("val.txt", "w") as f:
         f.writelines(line+'\n' for line in val_seq)
 
-# Create a Data collator object
-data_collator = DataCollatorForLanguageModeling(tokenizer=__tokenizer, mlm=False, return_tensors="pt")
+
+
+def fine_tune_gpt():
+
+    model_output_path = "gpt_model"
+    train_dataset = TextDataset(tokenizer=tokenizer, file_path="train.txt", block_size=128)
+    val_dataset = TextDataset(tokenizer=tokenizer, file_path="val.txt", block_size=128)
+    training_args = TrainingArguments(
+        output_dir = model_output_path,
+        overwrite_output_dir = True,
+        per_device_train_batch_size = 2, # try with 2
+        per_device_eval_batch_size = 2, # try with 2
+        num_train_epochs = 0.01,
+        save_steps = 1_000,
+        save_total_limit = 2,
+        logging_dir = './logs',
+    )
+
+    # Train the model
+    trainer = Trainer(
+        model = __model,
+        args = training_args,
+        data_collator = __data_collator,
+        train_dataset = train_dataset,
+        eval_dataset = val_dataset,
+    )
+
+    trainer.train()
+
+    # Save the model
+    trainer.save_model(model_output_path)
+
+    # Save the tokenizer
+    tokenizer.save_pretrained(model_output_path)
 def queryGPT(question):
     return generate_response(__model, __tokenizer, question)
 
 def generate_response(model,tokenizer, prompt, max_length=200):
-    train_dataset = TextDataset(tokenizer=tokenizer, file_path="train.txt", block_size=128)
-    val_dataset = TextDataset(tokenizer=tokenizer, file_path="val.txt", block_size=128)
     input_ids = tokenizer.encode(prompt, return_tensors="pt") # 'pt' for returning pytorch tensor
 
     # Create the attention mask and pad token id
@@ -67,4 +99,5 @@ with gr.Blocks() as demo:
     btn.click(queryGPT, inputs=[txt_input], outputs=[txt_output])
 if __name__ == "__main__":
     prepareData()
+    fine_tune_gpt()
     demo.launch()
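
The hunk above truncates generate_response right after the attention-mask comment, so the rest of the function is not visible in this commit. A minimal sketch of how that body typically continues for GPT-2 follows; the sampling settings (do_sample, top_k, top_p) and the EOS-as-pad choice are illustrative assumptions, not taken from the diff.

```python
# Sketch only -- the remainder of generate_response is not shown in this diff.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

def generate_response(model, tokenizer, prompt, max_length=200):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")  # 'pt' returns a PyTorch tensor

    # Create the attention mask and pad token id: GPT-2 defines no pad token,
    # so the EOS token is commonly reused for padding during generation.
    attention_mask = torch.ones_like(input_ids)
    pad_token_id = tokenizer.eos_token_id

    output_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        pad_token_id=pad_token_id,
        max_length=max_length,
        do_sample=True,   # illustrative sampling settings, not from the commit
        top_k=50,
        top_p=0.95,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

# Example call, mirroring queryGPT:
# tok = GPT2Tokenizer.from_pretrained("gpt2"); mdl = GPT2LMHeadModel.from_pretrained("gpt2")
# print(generate_response(mdl, tok, "What are the symptoms of diabetes?"))
```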
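
fine_tune_gpt saves the trained weights and tokenizer to the local gpt_model directory (note that its TextDataset and save_pretrained calls use a bare tokenizer name, while this diff only shows the tokenizer bound as __tokenizer, so those lines depend on a global tokenizer existing elsewhere in app.py). Assuming the save succeeds, the app could reload the fine-tuned checkpoint instead of the base "gpt2" weights; this is a sketch, not part of the commit.

```python
# Hypothetical follow-up: load the checkpoint written by fine_tune_gpt().
from transformers import GPT2LMHeadModel, GPT2Tokenizer

fine_tuned_path = "gpt_model"  # matches model_output_path inside fine_tune_gpt
model = GPT2LMHeadModel.from_pretrained(fine_tuned_path)
tokenizer = GPT2Tokenizer.from_pretrained(fine_tuned_path)
```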