# Spaces: Sleeping
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Name of the pre-trained conversational model to load.
model_name = "microsoft/DialoGPT-medium"

# Upper bound on prompt + reply tokens, passed to generate() as max_length.
MAX_TOTAL_TOKENS = 1000
# Tokens held back for the reply when old history has to be trimmed.
REPLY_TOKEN_BUDGET = 200


@st.cache_resource
def _load_tokenizer_and_model(name):
    """Return ``(tokenizer, model)`` for *name*, loading them only once.

    Streamlit re-executes this script from the top on every user
    interaction; caching keeps the weights from being reloaded each time.
    """
    return (
        AutoTokenizer.from_pretrained(name),
        AutoModelForCausalLM.from_pretrained(name),
    )


tokenizer, model = _load_tokenizer_and_model(model_name)

# Streamlit UI
st.title("E-Commerce Chatbot")
st.write("Ask about our products or browse categories. Type 'exit' to end the conversation.")

# The chat history has to live in session state: a plain module-level
# variable is reset to its initial value on every script rerun, which would
# make the bot forget all previous turns.
if "chat_history_ids" not in st.session_state:
    st.session_state["chat_history_ids"] = None
chat_history_ids = st.session_state["chat_history_ids"]

# User input
query = st.text_input("Your Query:")
user_text = query.strip()  # ignore whitespace-only submissions

if user_text.lower() == "exit":
    # Honour the instruction shown in the UI: forget the conversation
    # instead of sending the word "exit" to the model.
    chat_history_ids = None
    st.session_state["chat_history_ids"] = None
    st.write("Chatbot: Conversation ended. Type a new query to start over.")
elif user_text:
    # Tokenize the new user turn, terminated by the end-of-sequence token.
    new_input_ids = tokenizer.encode(user_text + tokenizer.eos_token, return_tensors="pt")

    # Append the new turn to whatever history this session already has.
    if chat_history_ids is None:
        input_ids = new_input_ids
    else:
        input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)

    # Drop the oldest tokens so the prompt alone can never reach max_length,
    # which would leave no room for a reply.
    input_ids = input_ids[:, -(MAX_TOTAL_TOKENS - REPLY_TOKEN_BUDGET):]

    # Generate a response from the model.
    output_ids = model.generate(
        input_ids,
        max_length=MAX_TOTAL_TOKENS,
        pad_token_id=tokenizer.eos_token_id,
        attention_mask=torch.ones_like(input_ids),
    )

    # generate() returns prompt + reply; decode only the newly generated part.
    response = tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)

    # Persist prompt + reply so the next turn is conditioned on both.
    chat_history_ids = output_ids
    st.session_state["chat_history_ids"] = output_ids

    # Display the chatbot's response
    st.write(f"Chatbot: {response}")