File size: 3,651 Bytes
94de423
c4dbf79
 
 
278cccd
 
 
20acc65
718dddc
 
 
eed6afe
15bac75
 
c14b7e6
 
 
e3d3ccf
3a2adb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c14b7e6
15bac75
 
 
 
eed6afe
 
 
15bac75
eed6afe
15bac75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eed6afe
 
 
94de423
718dddc
 
6151565
718dddc
0ed65fe
 
 
718dddc
 
6151565
718dddc
6151565
718dddc
 
 
 
 
 
 
6151565
6d0e485
94de423
4710379
94de423
 
 
7871f27
06a12b8
0ed65fe
 
06a12b8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import gradio as gr
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Base checkpoint to fine-tune; "gpt2" is the 124M-parameter Hugging Face model id.
__checkpoint = "gpt2"
__tokenizer = GPT2Tokenizer.from_pretrained(__checkpoint)
__model = GPT2LMHeadModel.from_pretrained(__checkpoint)
# Directory where the fine-tuned model and tokenizer are saved.
__model_output_path = "gpt_model"
# Create a Data collator object
# mlm=False -> causal-LM (next-token) batches, since GPT-2 is not a masked LM;
# return_tensors="pt" yields PyTorch tensors for the Trainer.
__data_collator = DataCollatorForLanguageModeling(tokenizer=__tokenizer, mlm=False, return_tensors="pt")

#prepare data
def prepareData():
    df=pd.read_csv("MedQuAD.csv")
    df['Question']=df['Question'].replace(r'^\s*$', np.nan, regex=True)
    df['Answer']=df['Answer'].replace(r'^\s*$', np.nan, regex=True)
    df = df.drop_duplicates(subset=['Question', 'Answer'])
    df=df.dropna()
    train_ds=df.groupby('Focus').head(100)
    train_ds=train_ds.groupby('Focus').head(4).reset_index(drop=True)
    test_ds=train_ds.groupby('Focus').head(1).reset_index(drop=True)
    train_seq=list()
    for i in range(len(train_ds)):
      s='<question>'+train_ds.loc[i,'Question']+'<answer>'+train_ds.loc[i,'Answer']
      train_seq.append(s)
    val_seq=list()
    for i in range(len(test_ds)):
      s='<question>'+test_ds.loc[i,'Question']+'<answer>'+test_ds.loc[i,'Answer']
      val_seq.append(s)
    with open("train.txt", "w") as f:
        f.writelines(line+'\n' for line in train_seq)

    with open("val.txt", "w") as f:
        f.writelines(line+'\n' for line in val_seq)
    


def fine_tune_gpt():
    """Fine-tune the module-level GPT-2 model on train.txt / val.txt.

    Expects prepareData() to have written the two text files beforehand.
    Saves the fine-tuned model and its tokenizer to __model_output_path.
    """
    train_dataset = TextDataset(tokenizer=__tokenizer, file_path="train.txt", block_size=128)
    val_dataset = TextDataset(tokenizer=__tokenizer, file_path="val.txt", block_size=128)
    training_args = TrainingArguments(
        output_dir = __model_output_path,
        overwrite_output_dir = True,
        per_device_train_batch_size = 2, # try with 2
        per_device_eval_batch_size = 2,  #  try with 2
        num_train_epochs = 0.01,  # smoke-test fraction of an epoch; raise for a real run
        save_steps = 1_000,
        save_total_limit = 2,
        logging_dir = './logs',
        )

    # Train the model
    trainer = Trainer(
        model = __model,
        args = training_args,
        data_collator = __data_collator,
        train_dataset = train_dataset,
        eval_dataset = val_dataset,
    )

    trainer.train()

    # BUG FIX: the original referenced the undefined name `model_output_path`
    # (NameError at runtime); the module-level constant is `__model_output_path`.
    # Save the model
    trainer.save_model(__model_output_path)

    # Save the tokenizer
    __tokenizer.save_pretrained(__model_output_path)
    
def queryGPT(question):
    """Answer *question* using the module-level GPT-2 model and tokenizer."""
    answer = generate_response(__model, __tokenizer, question)
    return answer

def generate_response(model, tokenizer, prompt, max_length=200):
    """Generate a completion for *prompt* and decode it to plain text.

    Encodes the prompt, asks *model* for a single sequence of at most
    *max_length* tokens, and returns the decoded text with special
    tokens stripped.
    """
    # Encode the prompt as a PyTorch tensor of token ids.
    encoded = tokenizer.encode(prompt, return_tensors="pt")

    # A single-prompt batch has no padding, so attend to every position;
    # use EOS as the pad token id for generation.
    generated = model.generate(
        encoded,
        max_length=max_length,
        num_return_sequences=1,
        attention_mask=torch.ones_like(encoded),
        pad_token_id=tokenizer.eos_token_id,
    )

    return tokenizer.decode(generated[0], skip_special_tokens=True)
    
with gr.Blocks() as demo:
    # Minimal Q&A form: question box, answer box, and a submit button
    # wired to the GPT-2 query function.
    question_box = gr.Textbox(label="Input Question", lines=2)
    answer_box = gr.Textbox(value="", label="Answer")
    submit_btn = gr.Button(value="Submit")
    submit_btn.click(queryGPT, inputs=[question_box], outputs=[answer_box])
if __name__ == "__main__":
    # Data prep and fine-tuning are one-off steps; uncomment to rerun them
    # before serving.
    #prepareData()
    #fine_tune_gpt()
    demo.launch()