import streamlit as st
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


class LLAMAChatbot:
    def __init__(self):
        st.title("LLAMA Chatbot")
        self.initialize_model()
        self.initialize_session_state()

    def initialize_model(self):
        """Initialize the LLAMA model and tokenizer"""
        try:
            def load_model():
                tokenizer = AutoTokenizer.from_pretrained("joermd/llma-speedy")
                model = AutoModelForCausalLM.from_pretrained(
                    "joermd/llma-speedy",
                    torch_dtype=torch.float16,
                    device_map="auto"
                )
                return model, tokenizer

            self.model, self.tokenizer = load_model()
            st.success("Model loaded successfully!")
        except Exception as e:
            st.error(f"An error occurred while loading the model: {str(e)}")
            st.stop()
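
    # Note: since the chatbot is rebuilt on every Streamlit rerun, the model is
    # reloaded on each interaction. Wrapping load_model with Streamlit's
    # @st.cache_resource decorator would keep it in memory across reruns; that
    # decorator is not part of the original code, so this is only a suggestion.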

    def initialize_session_state(self):
        """Initialize chat history if it doesn't exist"""
        if "messages" not in st.session_state:
            st.session_state.messages = []

    def display_chat_history(self):
        """Display all messages from chat history"""
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    def add_message(self, role, content):
        """Add a message to the chat history"""
        st.session_state.messages.append({
            "role": role,
            "content": content
        })

    def generate_response(self, user_input, max_length=1000):
        """Generate a response using the LLAMA model"""
        try:
            # Build the prompt from the most recent turns of the chat history
            context = ""
            for message in st.session_state.messages[-4:]:  # use the last 4 messages as context
                if message["role"] == "user":
                    context += f"Human: {message['content']}\n"
                else:
                    context += f"Assistant: {message['content']}\n"
            context += f"Human: {user_input}\nAssistant:"

            # Tokenize the prompt and move it to the model's device
            inputs = self.tokenizer(context, return_tensors="pt", truncation=True)
            inputs = inputs.to(self.model.device)

            # Generate the response
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs["input_ids"],
                    max_length=max_length,
                    num_return_sequences=1,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Keep only the text after the final "Assistant:" marker
            response = response.split("Assistant:")[-1].strip()
            return response
        except Exception as e:
            return f"Sorry, an error occurred while generating the answer: {str(e)}"

    def simulate_typing(self, message_placeholder, response):
        """Simulate a typing effect for the bot response"""
        full_response = ""
        for chunk in response.split():
            full_response += chunk + " "
            time.sleep(0.05)
            message_placeholder.markdown(full_response + "▌")
        message_placeholder.markdown(full_response)
        return full_response

    def run(self):
        """Main application loop"""
        # Display existing chat history
        self.display_chat_history()

        # Handle user input
        if user_input := st.chat_input("Type your message here..."):
            # Display and save the user message
            self.add_message("user", user_input)
            with st.chat_message("user"):
                st.markdown(user_input)

            # Generate and display the response
            with st.chat_message("assistant"):
                message_placeholder = st.empty()
                with st.spinner("Thinking..."):
                    assistant_response = self.generate_response(user_input)
                full_response = self.simulate_typing(message_placeholder, assistant_response)
            self.add_message("assistant", full_response)
if __name__ == "__main__": | |
# Set page config | |
st.set_page_config( | |
page_title="LLAMA Chatbot", | |
page_icon="🤖", | |
layout="wide" | |
) | |
# Initialize and run the chatbot | |
chatbot = LLAMAChatbot() | |
chatbot.run() |
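

# To run this app locally (assuming this file is saved as app.py and that
# streamlit, transformers, torch, and accelerate are installed; accelerate is
# needed for device_map="auto"):
#
#     streamlit run app.py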