File size: 1,508 Bytes
24ebe7f
f5f2fbb
24ebe7f
f5f2fbb
ec8623a
24ebe7f
 
 
 
 
 
f5f2fbb
24ebe7f
 
 
 
 
 
 
 
 
 
a5cbfe0
24ebe7f
 
eff7b9a
24ebe7f
 
 
 
 
a899f18
 
24ebe7f
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import gradio as gr
from transformers import GPT2Tokenizer, GPT2LMHeadModel


def generate_response(model, tokenizer, prompt, max_length=250):
    """Generate text for ``prompt`` and return it decoded as a string.

    Args:
        model: Object exposing ``generate(input_ids, **kwargs)``; this file
            passes a ``GPT2LMHeadModel``.
        tokenizer: Object exposing ``encode``, ``decode`` and
            ``eos_token_id``; this file passes a ``GPT2Tokenizer``.
        prompt: Text used to seed generation.
        max_length: Forwarded unchanged to ``model.generate`` as
            ``max_length``.

    Returns:
        The first generated sequence decoded with special tokens skipped,
        or an empty string when ``prompt`` is empty or ``None``.
    """
    if not prompt:
        # Nothing to condition on: skip tokenization and generation instead
        # of handing the model an empty token sequence.
        return ""

    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # A single, unpadded sequence: mark every position as attended.
    attention_mask = torch.ones_like(input_ids)

    output = model.generate(
        input_ids,
        max_length=max_length,
        num_return_sequences=1,
        attention_mask=attention_mask,
        # The end-of-sequence id is reused as the padding id.
        pad_token_id=tokenizer.eos_token_id,
    )

    return tokenizer.decode(output[0], skip_special_tokens=True)

# Identifier passed to both from_pretrained calls below
# (presumably a Hugging Face Hub repo id -- confirm).
model_path = "Tinyants21/Canine_model"
# Load the fine-tuned model and its tokenizer once, at module level;
# generate_response_gradio reads these two globals on every call.
my_chat_model = GPT2LMHeadModel.from_pretrained(model_path)
my_chat_tokenizer = GPT2Tokenizer.from_pretrained(model_path)

def generate_response_gradio(prompt):
    """Gradio callback: answer ``prompt`` with the module-level model.

    Delegates to ``generate_response`` using ``my_chat_model`` and
    ``my_chat_tokenizer`` with ``max_length`` fixed at 250.
    """
    return generate_response(
        my_chat_model,
        my_chat_tokenizer,
        prompt,
        max_length=250,
    )

# Static text shown on the Gradio page.
title = "Canine Distemper FAQ"
description = "Chatbot that uses a GPT-2 - 775 Million parameter model to answer common questions about canine distemper."
# Sample questions offered to the user; each inner list holds the value for
# the interface's single input component.
examples = [
    ["What is canine distemper?"],
    ["How is canine distemper transmitted?"],
    ["Is there a vaccine for canine distemper?"],
]

# The gr.inputs / gr.outputs namespaces were deprecated in Gradio 3 and
# removed in Gradio 4. Prefer the top-level component and fall back to the
# legacy namespaces only on releases that predate it.
if hasattr(gr, "Textbox"):
    inputs = gr.Textbox(label="Question")
    outputs = gr.Textbox(label="Answer")
else:
    inputs = gr.inputs.Textbox(label="Question")
    outputs = gr.outputs.Textbox(label="Answer")

# Build the single-input, single-output interface and start serving it.
gr.Interface(
    fn=generate_response_gradio,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
).launch()