# Hugging Face Space file-viewer header (scraped page residue, kept as a comment):
# author: betajuned; last commit: 6bfd168 "Update app.py" (verified); file size: 1.42 kB
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import torch
# Text shown on the Gradio page. The description is Indonesian for
# "GPT-2 becomes a chatbot".
title = "GPT-2 JTE Chatbot"
description = "GPT-2 Menjadi Chatbot"

# One sample prompt offered in the UI ("How do I fill in the KRS?").
examples = [["Bagaimana cara mengisi KRS?"]]

# The tokenizer and the causal-LM weights both come from the same Hub
# repository id, so it is named once here.
_checkpoint = "betajuned/GPT-2_Kombinasi4"
tokenizer = AutoTokenizer.from_pretrained(_checkpoint)
model = AutoModelForCausalLM.from_pretrained(_checkpoint)
def predict(input, history=None, *, max_new_tokens=200):
    """Generate the bot's reply to one user message and return the updated history.

    Args:
        input: The user's message text. The name is kept for caller
            compatibility even though it shadows the ``input`` builtin.
        history: Flat list of token ids for the conversation so far, exactly as
            returned by the previous call. ``None`` or an empty list starts a
            new conversation.
        max_new_tokens: Maximum number of tokens to generate for this reply,
            independent of how long the conversation already is.

    Returns:
        A ``(response, new_history)`` tuple: the decoded reply text, and the
        flat token-id list (prompt context followed by the reply) to pass back
        in as ``history`` on the next turn.
    """
    # Encode the new message with a trailing EOS token, which is what separates
    # turns in the token stream this function builds.
    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors="pt")

    if history:
        # `history` is a flat list of ids, so wrap it in an extra list to get a
        # (1, seq_len) tensor. A 1-D tensor cannot be concatenated with the 2-D
        # `new_user_input_ids`; torch.cat raises a dimension-mismatch error.
        past_ids = torch.tensor([history], dtype=torch.long)
        bot_input_ids = torch.cat([past_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # Keep only the most recent tokens so that prompt + reply still fit in the
    # model's position-embedding window; otherwise long chats eventually fail.
    context_size = getattr(model.config, "max_position_embeddings", None)
    if context_size:
        keep = max(context_size - max_new_tokens, 1)
        bot_input_ids = bot_input_ids[:, -keep:]

    # `max_new_tokens` (rather than `max_length`, which counts the prompt too)
    # keeps the bot answering after the history grows past the old 200-token cap.
    # The explicit all-ones attention mask is needed because pad == EOS here and
    # EOS tokens legitimately appear inside the prompt as turn separators.
    chat_history_ids = model.generate(
        bot_input_ids,
        attention_mask=torch.ones_like(bot_input_ids),
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Only the tokens after the prompt are the new reply.
    prompt_length = bot_input_ids.shape[-1]
    response = tokenizer.decode(chat_history_ids[0, prompt_length:], skip_special_tokens=True)

    # The full generated sequence becomes the history for the next turn.
    new_history = chat_history_ids[0].tolist()
    return response, new_history
# Build and launch the Gradio UI. The State component carries the token-id
# history between turns: it is passed to `predict` as its second argument and
# replaced by the second value `predict` returns.
demo = gr.Interface(
    fn=predict,
    title=title,
    description=description,
    examples=examples,
    # Top-level component classes replace the `gr.inputs` / `gr.outputs`
    # namespaces, which were deprecated in Gradio 3 and removed in Gradio 4.
    inputs=[gr.Textbox(lines=2, placeholder="Enter your message here..."), gr.State()],
    outputs=[gr.Textbox(), gr.State()],
    theme="finlaymacklon/boxy_violet",
)
demo.launch()