|
import torch |
|
import streamlit as st |
|
import os |
|
from dotenv import load_dotenv |
|
from airllm import AutoModel,AirLLMInternLM |
|
from transformers import AutoTokenizer, GenerationConfig |
|
|
|
|
|
# Pull variables from a local .env file into os.environ. python-dotenv does
# not override variables that are already set in the process environment.
load_dotenv()

# Hugging Face access token, from the .env file or the process environment.
api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# The token used to be read and then never used. The Hugging Face Hub client
# (used by transformers for model downloads) reads HF_TOKEN, not
# HUGGINGFACEHUB_API_TOKEN, so expose it under that name. An HF_TOKEN that is
# already set explicitly is left untouched.
if api_token:
    os.environ.setdefault("HF_TOKEN", api_token)
|
|
|
|
|
# Upper bound on prompt length in tokens; longer user input is truncated.
MAX_LENGTH = 128

# NOTE(review): the page title and sidebar text advertise
# "internlm2_5-7b-chat", but this id is the base (non-chat) checkpoint.
# Confirm which model is intended.
model_name = "internlm/internlm2_5-7b"


@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer(name):
    """Load the AirLLM model and its tokenizer for ``name`` once per process.

    Streamlit re-executes this script from top to bottom on every widget
    interaction. ``st.cache_resource`` keeps the loaded objects alive across
    reruns, so the checkpoint is not reloaded on every "Send". The spinner is
    disabled because this runs before ``st.set_page_config``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        Exception: whatever the final ``AutoModel.from_pretrained`` fallback
        raises if every loading strategy fails.
    """
    # Try progressively more generic AirLLM entry points. The handlers were
    # previously bare ``except:``, which also swallowed Ctrl-C / SystemExit.
    try:
        loaded_model = AirLLMInternLM.from_pretrained(name)
    except Exception:
        try:
            loaded_model = AirLLMInternLM(name)
        except Exception:
            loaded_model = AutoModel.from_pretrained(name)

    # InternLM2 repos ship a custom tokenizer class. AutoTokenizer only loads
    # it with trust_remote_code=True, which executes Python code from the Hub
    # repo, so use this only with trusted model ids.
    loaded_tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer(model_name)
|
|
|
|
|
# Page-level settings. Run this before any other Streamlit call that renders
# output; older Streamlit releases reject set_page_config otherwise.
st.set_page_config(
    page_title="Conversational Chatbot with internlm2_5-7b-chat",
    # Was the mis-encoded string "π€", which is not a valid emoji icon.
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
)

st.title("Conversational Chatbot with internlm2_5-7b-chat and AirLLM")
|
|
|
|
|
st.sidebar.header("Chatbot Configuration")

theme = st.sidebar.selectbox("Choose a theme", ["Default", "Dark", "Light"])

# CSS injected for each non-default theme.
# `.reportview-container` and `.sidebar .sidebar-content` are class names from
# old Streamlit releases that current versions no longer render, so on
# current versions the themes had no visible effect. The `data-testid`
# selectors target the same regions in current releases. Both selectors are
# kept so the styles apply on old and new versions alike.
_THEME_CSS = {
    "Dark": """
<style>
.reportview-container,
[data-testid="stAppViewContainer"] {
    background: #2E2E2E;
    color: #FFFFFF;
}
.sidebar .sidebar-content,
[data-testid="stSidebar"] {
    background: #333333;
}
</style>
""",
    "Light": """
<style>
.reportview-container,
[data-testid="stAppViewContainer"] {
    background: #FFFFFF;
    color: #000000;
}
.sidebar .sidebar-content,
[data-testid="stSidebar"] {
    background: #F5F5F5;
}
</style>
""",
}

# "Default" has no entry, so Streamlit's own theme is left untouched.
theme_css = _THEME_CSS.get(theme)
if theme_css:
    st.markdown(theme_css, unsafe_allow_html=True)
|
|
|
|
|
user_input = st.text_input("You: ", "")

if st.button("Send"):
    # Whitespace-only input used to be sent to the model as a real prompt.
    prompt = user_input.strip()
    if not prompt:
        st.warning("Please enter a message.")
    else:
        input_tokens = tokenizer(
            prompt,
            return_tensors="pt",
            return_attention_mask=False,
            truncation=True,
            max_length=MAX_LENGTH,
            padding=False,
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): AirLLM models manage layer placement themselves and may
        # not expose `.to()`. Only move the model when the method exists.
        if callable(getattr(model, "to", None)):
            model.to(device)
        input_tokens = input_tokens.to(device)

        prompt_ids = input_tokens["input_ids"]
        with st.spinner("Generating response..."):
            generation_output = model.generate(
                input_ids=prompt_ids,
                max_new_tokens=20,
                use_cache=True,
                return_dict_in_generate=True,
            )

        # `sequences` holds the prompt followed by the generated tokens.
        # Decode only the generated part so the bot does not echo the user.
        new_token_ids = generation_output.sequences[0][prompt_ids.shape[-1]:]
        response = tokenizer.decode(new_token_ids, skip_special_tokens=True)

        st.text_area("Bot:", value=response, height=200, max_chars=None)
|
|
|
|
|
# Static "About" blurb rendered in the sidebar, below the configuration widgets.
with st.sidebar:
    st.markdown(
        """

### About

This is a conversational chatbot built using the internlm2_5-7b-chat model and AirLLM.

"""
    )
|
|