# ayush0504's picture
# Update app.py
# 53ec504 verified
import streamlit as st
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, TextStreamer
# bitsandbytes is no longer needed
import io
import sys
import threading
import time
import queue # Import the queue module
# --- Configuration ---
# Path handed to the model loader; replace with your own default path.
DEFAULT_MODEL_PATH = "lora_model"
# No 4-bit default is defined here because quantization is not used.

# --- Page Configuration ---
st.set_page_config(
    page_title="Fine-tuned LLM Chat Interface (CPU)",
    layout="wide",
)
st.title("Fine-tuned LLM Chat Interface (CPU Mode)")
# Set expectations up front: CPU inference is slow and memory hungry.
st.warning(
    "Running in CPU mode. Expect slower generation times and higher RAM usage.",
    icon="⚠️",
)
# --- Model Loading (Cached for CPU) ---
@st.cache_resource(show_spinner="Loading model and tokenizer onto CPU...")
def load_model_and_tokenizer_cpu(model_path):
    """Load the PEFT model and its tokenizer from ``model_path`` onto the CPU.

    Args:
        model_path: Location passed to both ``from_pretrained`` calls.

    Returns:
        ``(model, tokenizer)`` on success. If anything raises while loading,
        the error is shown with ``st.error``, printed, and ``(None, None)``
        is returned instead of propagating the exception.
    """
    try:
        # float32 is chosen for CPU compatibility and stability.
        peft_model = AutoPeftModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float32,
            device_map="cpu",  # force loading onto the CPU
        )
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_path)
        peft_model.eval()  # inference only: switch to evaluation mode
        print("Model and tokenizer loaded successfully onto CPU.")
    except Exception as e:
        st.error(f"Error loading model from path '{model_path}' onto CPU: {e}", icon="🚨")
        print(f"Error loading model onto CPU: {e}")
        return None, None
    else:
        return peft_model, loaded_tokenizer
# --- Custom Streamer Class (Modified for Queue) ---
class QueueStreamer(TextStreamer):
    """``TextStreamer`` variant that publishes decoded text on a queue.

    A consumer reads text chunks from ``q`` until it receives
    ``stop_signal`` (``None``), which marks the end of the stream.

    Args:
        tokenizer: Tokenizer forwarded to ``TextStreamer``.
        skip_prompt: Forwarded to ``TextStreamer`` as ``skip_prompt``.
        q: Queue-like object with a ``put`` method that receives the chunks.
    """

    def __init__(self, tokenizer, skip_prompt, q):
        super().__init__(tokenizer, skip_prompt=skip_prompt)
        self.queue = q
        # Sentinel placed on the queue to tell the consumer the stream is over.
        self.stop_signal = None

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Queue ``text`` and, when ``stream_end`` is set, the stop sentinel.

        The sentinel is queued directly here rather than by calling
        ``self.end()``: the base ``end()`` itself calls this method with
        ``stream_end=True``, so calling back into it would recurse.
        """
        if text:  # an empty flush carries nothing worth rendering
            self.queue.put(text)
        if stream_end:
            self.queue.put(self.stop_signal)

    def end(self):
        """Flush any text still buffered, then signal the end of the stream.

        Delegating to ``TextStreamer.end()`` is essential: the base class
        holds back the most recent text until it is safe to emit and only
        releases the remainder in ``end()``, finishing with
        ``on_finalized_text(..., stream_end=True)``. Overriding ``end()``
        without delegating silently drops the tail of every response.

        Calling ``end()`` more than once is harmless; each call queues one
        more sentinel after an empty flush.
        """
        super().end()
# --- Sidebar for Settings ---
# Every widget below is created through ``st.sidebar`` so it renders in the
# sidebar, in the same top-to-bottom order as declared here.
st.sidebar.header("Model Configuration")
st.sidebar.info(f"Model loaded on startup: `{DEFAULT_MODEL_PATH}` (CPU Mode).")

st.sidebar.header("Generation Settings")
temperature = st.sidebar.slider(
    "Temperature",
    min_value=0.0,
    max_value=2.0,
    value=0.7,
    step=0.05,
)
# min_p is kept for experimentation; a top_p or top_k control may be a better
# fit, e.g.:
#   top_p = st.slider("Top P", min_value=0.01, max_value=1.0, value=0.9, step=0.01)
min_p = st.sidebar.slider(
    "Min P",
    min_value=0.01,
    max_value=1.0,
    value=0.1,
    step=0.01,
)
# The default is deliberately modest because generation runs on the CPU.
max_tokens = st.sidebar.slider(
    "Max New Tokens",
    min_value=50,
    max_value=2048,
    value=256,
    step=50,
)

if st.sidebar.button("Clear Chat History"):
    st.session_state.messages = []
    st.rerun()  # rerun so the cleared history disappears immediately
# --- Load Model (runs only once on first run or if cache is cleared) ---
model, tokenizer = load_model_and_tokenizer_cpu(DEFAULT_MODEL_PATH)

# --- Initialize Session State ---
# The chat transcript lives in session state so it survives script reruns.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# --- Main Chat Interface ---
# The loader returns (None, None) on failure; nothing below can work then.
model_ready = model is not None and tokenizer is not None
if not model_ready:
    st.error("CPU Model loading failed. Please check the path, available RAM, and logs. Cannot proceed.")
    st.stop()
# Replay the stored transcript so earlier turns remain visible on every rerun.
for past_turn in st.session_state.messages:
    st.chat_message(past_turn["role"]).markdown(past_turn["content"])
# Handle user input: run generation in a worker thread and stream the text
# into the assistant's chat bubble as it arrives.
user_input = st.chat_input("Ask the fine-tuned model (CPU)...")
if user_input:
    # Add user message to history and display it
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    # Prepare for model response
    with st.chat_message("assistant"):
        response_placeholder = st.empty()
        response_placeholder.markdown("Generating response on CPU... please wait... ▌")  # Initial message

        # A fresh queue per response keeps stale chunks or sentinels from a
        # previous turn out of this one.
        text_queue = queue.Queue()
        text_streamer = QueueStreamer(tokenizer, skip_prompt=True, q=text_queue)
        messages_for_model = st.session_state.messages

        # Initialised before the ``try`` so the ``except`` handler can always
        # read it, even when the failure happens during input preparation.
        generated_text = ""
        # Exceptions raised by the worker thread. They are reported from the
        # main thread because Streamlit commands issued from a plain worker
        # thread have no script context and never reach the page.
        generation_errors = []

        try:
            # Inputs must live on the same device as the model.
            target_device = model.device
            if tokenizer.chat_template:
                inputs = tokenizer.apply_chat_template(
                    messages_for_model,
                    tokenize=True,
                    add_generation_prompt=True,
                    return_tensors="pt",
                ).to(target_device)
            else:
                # No chat template: fall back to a plain "role: content" prompt.
                prompt_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages_for_model]) + "\nassistant:"
                inputs = tokenizer(prompt_text, return_tensors="pt").input_ids.to(target_device)

            # Temperature 0 means greedy decoding; sampling-only knobs are
            # passed as None in that case.
            do_sample = temperature > 0
            # Compare against None explicitly: 0 is a valid pad token id.
            pad_token_id = tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = tokenizer.eos_token_id

            # Generation arguments
            generation_kwargs = dict(
                input_ids=inputs,
                streamer=text_streamer,  # Use the QueueStreamer
                max_new_tokens=max_tokens,
                use_cache=True,
                do_sample=do_sample,
                temperature=temperature if do_sample else None,
                top_p=None,  # Consider adding top_p slider in UI
                min_p=min_p if do_sample else None,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=pad_token_id,
            )

            def generation_thread_func():
                """Run ``model.generate`` and always terminate the stream."""
                try:
                    # no_grad avoids keeping autograd state during inference.
                    with torch.no_grad():
                        model.generate(**generation_kwargs)
                except Exception as e:
                    print(f"Error in generation thread: {e}")
                    # Hand the error to the main thread for display.
                    generation_errors.append(e)
                finally:
                    # Ensure the queue loop terminates even if error occurred
                    text_streamer.end()

            # Start the generation thread
            thread = threading.Thread(target=generation_thread_func)
            thread.start()

            # --- Main thread: Read from queue and update UI ---
            while True:
                try:
                    # The timeout keeps this loop from blocking forever if
                    # the worker dies without delivering the sentinel.
                    chunk = text_queue.get(block=True, timeout=1)
                    if chunk is text_streamer.stop_signal:  # end-of-stream sentinel (None)
                        break
                    generated_text += chunk
                    response_placeholder.markdown(generated_text + "▌")  # Update placeholder
                except queue.Empty:
                    # Nothing yet: stop only if the worker has already exited.
                    if not thread.is_alive():
                        break
                    continue
                except Exception as e:
                    st.error(f"Error reading from generation queue: {e}")
                    print(f"Error reading from queue: {e}")
                    break  # Exit loop on queue error

            # Final update without the cursor
            response_placeholder.markdown(generated_text)

            # Surface a worker failure here, on the main thread.
            if generation_errors:
                st.error(f"Error during generation: {generation_errors[0]}")

            # Add the complete assistant response to history *after* generation
            if generated_text:
                st.session_state.messages.append({"role": "assistant", "content": generated_text})
            elif not generation_errors:
                # Nothing was produced and no error was reported above.
                if not any(m['role'] == 'assistant' and m['content'].startswith("*Error") for m in st.session_state.messages):
                    st.warning("Assistant produced no output.", icon="⚠️")

            # Wait briefly for the thread to finish if it hasn't already
            thread.join(timeout=5.0)
        except Exception as e:
            st.error(f"Error during generation setup or queue handling: {e}", icon="🔥")
            print(f"Error setting up generation or handling queue: {e}")
            # Add error to chat history for context
            error_message = f"*Error generating response: {e}*"
            if not generated_text:  # Add if no text was generated at all
                st.session_state.messages.append({"role": "assistant", "content": error_message})
                response_placeholder.error(f"Error generating response: {e}")
            else:  # Append error notice if some text was generated before error
                st.session_state.messages.append({"role": "assistant", "content": generated_text + "\n\n" + error_message})
                response_placeholder.markdown(generated_text + f"\n\n*{error_message}*")