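"""Streamlit conversational chatbot serving internlm/internlm2_5-7b through AirLLM."""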
import torch
import streamlit as st
import os
from dotenv import load_dotenv
from airllm import AutoModel, AirLLMInternLM
from transformers import AutoTokenizer, GenerationConfig
# Load environment variables
load_dotenv()
# Retrieve the API token from the environment variables
api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
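# Note: api_token is read here but never passed to the loaders below. If the
# model repo were gated, it could be forwarded via the `token` keyword of
# from_pretrained (assumption: the installed transformers version supports it).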
# Initialize model and tokenizer
MAX_LENGTH = 128
model_name = "internlm/internlm2_5-7b"
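# AirLLM keeps VRAM usage low by streaming transformer layers from disk one
# at a time during inference, so the full 7B model never has to fit in GPU
# memory at once.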
# Try the InternLM-specific AirLLM wrapper first, then fall back to the
# generic AirLLM AutoModel loader.
try:
    model = AirLLMInternLM.from_pretrained(model_name)
except Exception:
    try:
        model = AirLLMInternLM(model_name)
    except Exception:
        model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Streamlit app configuration
st.set_page_config(
    page_title="Conversational Chatbot with internlm2_5-7b",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
)
# App title
st.title("Conversational Chatbot with internlm2_5-7b and AirLLM")
# Sidebar configuration
st.sidebar.header("Chatbot Configuration")
theme = st.sidebar.selectbox("Choose a theme", ["Default", "Dark", "Light"])
# Set theme based on user selection
if theme == "Dark":
    st.markdown(
        """
        <style>
        .reportview-container {
            background: #2E2E2E;
            color: #FFFFFF;
        }
        .sidebar .sidebar-content {
            background: #333333;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )
elif theme == "Light":
    st.markdown(
        """
        <style>
        .reportview-container {
            background: #FFFFFF;
            color: #000000;
        }
        .sidebar .sidebar-content {
            background: #F5F5F5;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )
# Chat input and output
user_input = st.text_input("You: ", "")
if st.button("Send"):
    if user_input:
        # Tokenize the user input
        input_tokens = tokenizer(
            user_input,
            return_tensors="pt",
            return_attention_mask=False,
            truncation=True,
            max_length=MAX_LENGTH,
            padding=False,
        )
        # Use CUDA if available, otherwise fall back to CPU
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        input_tokens = input_tokens.to(device)
        # Generate a response (kept short via max_new_tokens)
        generation_output = model.generate(
            input_ids=input_tokens["input_ids"],
            max_new_tokens=20,
            use_cache=True,
            return_dict_in_generate=True,
        )
        # Decode only the newly generated tokens so the reply does not
        # echo the prompt back to the user
        new_tokens = generation_output.sequences[0][input_tokens["input_ids"].shape[1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
        st.text_area("Bot:", value=response, height=200, max_chars=None)
    else:
        st.warning("Please enter a message.")
# Footer
st.sidebar.markdown(
    """
    ### About
    This is a conversational chatbot built with the internlm2_5-7b model and AirLLM.
    """
)