Spaces:
Sleeping
Sleeping
File size: 3,509 Bytes
94de423 c4dbf79 278cccd 20acc65 718dddc eed6afe 15bac75 c14b7e6 e3d3ccf 3a2adb5 c14b7e6 15bac75 eed6afe 15bac75 eed6afe 15bac75 eed6afe 94de423 718dddc 6151565 718dddc 6151565 718dddc 6151565 718dddc 6151565 6d0e485 94de423 4710379 94de423 7871f27 06a12b8 9869827 15bac75 06a12b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import gradio as gr
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
# Base checkpoint name used for both the pretrained GPT-2 model and tokenizer.
__checkpoint = "gpt2"
__tokenizer = GPT2Tokenizer.from_pretrained(__checkpoint)
__model = GPT2LMHeadModel.from_pretrained(__checkpoint)
# Directory where the fine-tuned model and tokenizer are saved by fine_tune_gpt().
__model_output_path = "gpt_model"
# Create a Data collator object
# mlm=False: causal (next-token) language modeling, not masked LM.
__data_collator = DataCollatorForLanguageModeling(tokenizer=__tokenizer, mlm=False, return_tensors="pt")
#prepare data
def prepareData():
    """Build train.txt / val.txt question-answer sequences from MedQuAD.csv.

    Reads the CSV, drops whitespace-only and duplicate Q&A pairs, keeps up
    to 4 rows per 'Focus' topic for training and the first row per topic
    for validation, then writes one '<question>...<answer>...' sequence
    per line to train.txt and val.txt.

    NOTE(review): the validation rows are taken from the training frame
    (head(1) of train_ds), so validation data is a subset of training
    data. Kept as-is to preserve behavior — confirm whether this train/val
    leak is intended.
    """
    df = pd.read_csv("MedQuAD.csv")
    # Turn whitespace-only cells into NaN so dropna() removes them too.
    df['Question'] = df['Question'].replace(r'^\s*$', np.nan, regex=True)
    df['Answer'] = df['Answer'].replace(r'^\s*$', np.nan, regex=True)
    df = df.drop_duplicates(subset=['Question', 'Answer'])
    df = df.dropna()
    # head(100) followed by head(4) per group is equivalent to head(4);
    # the two steps are kept to mirror the original pipeline.
    train_ds = df.groupby('Focus').head(100)
    train_ds = train_ds.groupby('Focus').head(4).reset_index(drop=True)
    test_ds = train_ds.groupby('Focus').head(1).reset_index(drop=True)
    # Serialize each row as a single tagged training sequence.
    train_seq = [f"<question>{q}<answer>{a}"
                 for q, a in zip(train_ds['Question'], train_ds['Answer'])]
    val_seq = [f"<question>{q}<answer>{a}"
               for q, a in zip(test_ds['Question'], test_ds['Answer'])]
    # Explicit encoding so output is stable across platforms.
    with open("train.txt", "w", encoding="utf-8") as f:
        f.writelines(line + '\n' for line in train_seq)
    with open("val.txt", "w", encoding="utf-8") as f:
        f.writelines(line + '\n' for line in val_seq)
def fine_tune_gpt():
    """Fine-tune the module-level GPT-2 model on train.txt / val.txt.

    Builds token datasets from the files written by prepareData(), runs a
    (very short) Trainer loop, then saves the fine-tuned model and the
    tokenizer into __model_output_path.
    """
    # block_size=128: each example is a 128-token chunk of the text file.
    train_dataset = TextDataset(tokenizer=__tokenizer, file_path="train.txt", block_size=128)
    val_dataset = TextDataset(tokenizer=__tokenizer, file_path="val.txt", block_size=128)
    training_args = TrainingArguments(
        output_dir=__model_output_path,
        overwrite_output_dir=True,
        per_device_train_batch_size=2,  # try with 2
        per_device_eval_batch_size=2,   # try with 2
        # Fractional epoch: touches only ~1% of the data — a smoke-test run.
        num_train_epochs=0.01,
        save_steps=1_000,
        save_total_limit=2,
        logging_dir='./logs',
    )
    # Train the model
    trainer = Trainer(
        model=__model,
        args=training_args,
        data_collator=__data_collator,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    )
    trainer.train()
    # BUG FIX: the original referenced the undefined name `model_output_path`,
    # which raised NameError immediately after training finished. Use the
    # module-level constant instead.
    # Save the model
    trainer.save_model(__model_output_path)
    # Save the tokenizer
    __tokenizer.save_pretrained(__model_output_path)
def queryGPT(question):
    """Answer *question* with the module-level GPT-2 model and tokenizer."""
    answer = generate_response(__model, __tokenizer, question)
    return answer
def generate_response(model, tokenizer, prompt, max_length=200):
    """Generate a text continuation of *prompt* using *model*/*tokenizer*.

    Returns the decoded output of a single generated sequence with special
    tokens stripped.
    """
    # Encode the prompt as a PyTorch tensor of token ids.
    encoded = tokenizer.encode(prompt, return_tensors="pt")
    # Attend to every prompt token; pad with EOS so generation is well-defined.
    mask = torch.ones_like(encoded)
    generated = model.generate(
        encoded,
        max_length=max_length,
        num_return_sequences=1,
        attention_mask=mask,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Exactly one sequence was requested; decode it without special tokens.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# Minimal Gradio UI: a question box, an answer box, and a submit button.
with gr.Blocks() as demo:
    question_box = gr.Textbox(label="Input Question", lines=2)
    answer_box = gr.Textbox(value="", label="Answer")
    submit_btn = gr.Button(value="Submit")
    # Wire the button to the model query: one text input, one text output.
    submit_btn.click(queryGPT, inputs=[question_box], outputs=[answer_box])
# Script entry point: build the dataset, fine-tune GPT-2, then serve the UI.
if __name__ == "__main__":
    prepareData()
    fine_tune_gpt()
    demo.launch()