# Spaces: Sleeping (Hugging Face Space page header captured with this file; not part of the program)
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import gradio as gr | |
import torch | |
import json | |
# Text and example prompts displayed by the Gradio interface.
title = "Smart AI ChatBot"
description = "A conversational model capable of intelligently answering questions (DialoGPT)"
# One example row per prompt; each row fills the interface's text input.
examples = [
    ["How are you?"],
    ["What's the weather like?"],
]
# Load the DialoGPT model and its tokenizer by Hub id. A single constant
# keeps the tokenizer and model checkpoints from drifting apart.
MODEL_NAME = "microsoft/DialoGPT-large"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
# Canned replies, keyed by the exact user message; predict() uses one of
# these instead of calling the model when the input matches a key exactly.
# Add more pairs here as needed.
known_questions_answers = {
    "How are you?": "I'm fine, thank you for asking.",
    "What's the weather like?": "The weather is nice today, sunny and warm.",
    "What's your name?": "I am Smart AI ChatBot.",
    "Do you speak English?": "I can understand and respond to English questions.",
}
def predict(input, history=None):
    """Answer one chat turn and return the updated transcript and token history.

    If ``input`` exactly matches a key of ``known_questions_answers``, the
    canned reply is used. Otherwise DialoGPT generates a reply conditioned
    on the whole conversation so far.

    Args:
        input: The user's message. The name is kept for Gradio and keyword
            callers.
        history: The Gradio ``state`` value from the previous call. It is
            ``None`` or ``[]`` on the first turn. Otherwise it is the
            ``[token_ids]`` list this function returned last time.

    Returns:
        ``(pairs, history)``:
            - ``pairs`` is a list of ``(user_message, bot_reply)`` tuples for
              the ``chatbot`` output.
            - ``history`` is ``[token_ids]`` for the whole conversation. Each
              message in it ends with the tokenizer's EOS token.
    """
    # Gradio's "state" input starts as None; later calls pass back [[ids]].
    past_ids = list(history[0]) if history else []
    user_ids = tokenizer.encode(input + tokenizer.eos_token)

    answer = known_questions_answers.get(input)
    if answer is not None:
        # Record the canned exchange in the token history as well, so later
        # model-generated turns see it as context.
        conversation_ids = (
            past_ids + user_ids + tokenizer.encode(answer + tokenizer.eos_token)
        )
    else:
        bot_input_ids = torch.LongTensor([past_ids + user_ids])
        # NOTE: max_length counts the prompt too, so once the conversation
        # reaches 400 tokens there is no room left for a new reply.
        conversation_ids = model.generate(
            bot_input_ids, max_length=400, pad_token_id=tokenizer.eos_token_id
        )[0].tolist()

    # Keep the EOS tokens when decoding. They separate the alternating
    # user/bot messages, which the chatbot output needs as pairs.
    messages = tokenizer.decode(conversation_ids).split(tokenizer.eos_token)
    pairs = [(messages[i], messages[i + 1]) for i in range(0, len(messages) - 1, 2)]
    return pairs, [conversation_ids]
def main():
    """Placeholder hook for loading extra known question-answer pairs.

    It could, for example, read them from a JSON file. Right now it does
    nothing and returns None. The module-level launch code does not call it.
    """
# Wire predict() into a Gradio app.
# Inputs: the user's text and the conversation state.
# Outputs: the chat transcript and the updated state, which Gradio passes
# back on the next call.
demo = gr.Interface(
    fn=predict,
    title=title,
    description=description,
    examples=examples,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    theme="finlaymacklon/boxy_violet",
)
demo.launch()