import streamlit as st
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, TextStreamer
# bitsandbytes is no longer needed (no quantization in CPU mode)
import io
import sys
import threading
import time
import queue
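
# Assumed install command for the imports above (these are the standard PyPI package names;
# exact versions are untested here):
#   pip install streamlit torch transformers peft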

# --- Configuration ---
DEFAULT_MODEL_PATH = "lora_model"  # Or your default adapter path
# DEFAULT_LOAD_IN_4BIT is removed, as quantization is not used on CPU

# --- Page Configuration ---
st.set_page_config(page_title="Fine-tuned LLM Chat Interface (CPU)", layout="wide")
st.title("Fine-tuned LLM Chat Interface (CPU Mode)")
st.warning("Running in CPU mode. Expect slower generation times and higher RAM usage.", icon="⚠️")

# --- Model Loading (Cached for CPU) ---
@st.cache_resource  # Cache the model so it is loaded only once per process
def load_model_and_tokenizer_cpu(model_path):
    """Loads the PEFT model and tokenizer onto the CPU."""
    try:
        # Use standard float32 for CPU compatibility and stability
        torch_dtype = torch.float32
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            device_map="cpu",  # Force loading onto the CPU (no quantization)
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model.eval()  # Set the model to evaluation mode
        print("Model and tokenizer loaded successfully onto CPU.")
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model from path '{model_path}' onto CPU: {e}", icon="🚨")
        print(f"Error loading model onto CPU: {e}")
        return None, None
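
# Minimal standalone usage sketch of the loader (illustrative only, not executed by the app;
# assumes an adapter folder exists at DEFAULT_MODEL_PATH):
#
#   model, tokenizer = load_model_and_tokenizer_cpu(DEFAULT_MODEL_PATH)
#   if model is not None:
#       ids = tokenizer("Hello", return_tensors="pt").input_ids
#       out = model.generate(input_ids=ids, max_new_tokens=20)
#       print(tokenizer.decode(out[0], skip_special_tokens=True))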

# --- Custom Streamer Class (Modified for Queue) ---
class QueueStreamer(TextStreamer):
    def __init__(self, tokenizer, skip_prompt, q):
        super().__init__(tokenizer, skip_prompt=skip_prompt)
        self.queue = q
        self.stop_signal = None  # Sentinel placed in the queue when generation ends

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Puts finalized text chunks onto the queue."""
        self.queue.put(text)
        if stream_end:
            self.end()

    def end(self):
        """Signals the end of generation by putting the sentinel (None) in the queue."""
        self.queue.put(self.stop_signal)
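
# Streaming hand-off in a nutshell (illustrative sketch; names mirror the app code below):
# model.generate() runs in a background thread and pushes text chunks through QueueStreamer,
# while the main thread drains the queue until it sees the None sentinel.
#
#   q = queue.Queue()
#   streamer = QueueStreamer(tokenizer, skip_prompt=True, q=q)
#   threading.Thread(target=lambda: model.generate(input_ids=ids, streamer=streamer)).start()
#   while (chunk := q.get()) is not None:
#       print(chunk, end="", flush=True)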

# --- Sidebar for Settings ---
with st.sidebar:
    st.header("Model Configuration")
    st.info(f"Model loaded on startup: `{DEFAULT_MODEL_PATH}` (CPU Mode).")
    st.header("Generation Settings")
    temperature = st.slider("Temperature", min_value=0.0, max_value=2.0, value=0.7, step=0.05)
    # min_p might not be as commonly used or effective without top_p/top_k,
    # but keeping it allows experimentation. Consider top_k or top_p instead.
    # Example: top_p = st.slider("Top P", min_value=0.01, max_value=1.0, value=0.9, step=0.01)
    min_p = st.slider("Min P", min_value=0.01, max_value=1.0, value=0.1, step=0.01)  # Keep for now
    max_tokens = st.slider("Max New Tokens", min_value=50, max_value=2048, value=256, step=50)  # Reduced default for CPU
    if st.button("Clear Chat History"):
        st.session_state.messages = []
        st.rerun()  # Rerun to clear the display immediately

# --- Load Model (runs only once on first run or if the cache is cleared) ---
model, tokenizer = load_model_and_tokenizer_cpu(DEFAULT_MODEL_PATH)

# --- Initialize Session State ---
if "messages" not in st.session_state:
    st.session_state.messages = []

# --- Main Chat Interface ---
if model is None or tokenizer is None:
    st.error("CPU model loading failed. Please check the path, available RAM, and logs. Cannot proceed.")
    st.stop()

# Display conversation history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Handle user input
user_input = st.chat_input("Ask the fine-tuned model (CPU)...")

if user_input:
    # Add the user message to history and display it
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    # Prepare for the model response
    with st.chat_message("assistant"):
        response_placeholder = st.empty()
        response_placeholder.markdown("Generating response on CPU... please wait... ▌")  # Initial message
        text_queue = queue.Queue()  # Queue for this specific response
        # Initialize the queue-backed streamer
        text_streamer = QueueStreamer(tokenizer, skip_prompt=True, q=text_queue)
        # Prepare the input for the model
        messages_for_model = st.session_state.messages
        generated_text = ""  # Initialized here so the except block below can reference it safely
        try:
            # Ensure inputs are on the CPU (model.device should be 'cpu')
            target_device = model.device
            # print(f"Model device: {target_device}")  # Debugging: should print 'cpu'
            if tokenizer.chat_template:
                inputs = tokenizer.apply_chat_template(
                    messages_for_model,
                    tokenize=True,
                    add_generation_prompt=True,
                    return_tensors="pt",
                ).to(target_device)  # Send input tensors to the CPU
            else:
                prompt_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages_for_model]) + "\nassistant:"
                inputs = tokenizer(prompt_text, return_tensors="pt").input_ids.to(target_device)  # Send input tensors to the CPU

            # Generation arguments
            generation_kwargs = dict(
                input_ids=inputs,
                streamer=text_streamer,  # Use the QueueStreamer
                max_new_tokens=max_tokens,
                use_cache=True,  # KV caching still helps CPU generation speed
                temperature=temperature if temperature > 0 else None,
                top_p=None,  # Consider adding a top_p slider in the UI
                # top_k=50,  # Example: or use top_k
                min_p=min_p,
                do_sample=temperature > 0,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
            )

            # Define the target function for the background thread
            def generation_thread_func():
                try:
                    # Run generation in the background thread (on CPU).
                    # torch.no_grad() saves memory during inference.
                    with torch.no_grad():
                        model.generate(**generation_kwargs)
                except Exception as e:
                    # Streamlit UI calls are unreliable from a background thread,
                    # so log the error and surface it through the queue instead.
                    print(f"Error in generation thread: {e}")
                    text_queue.put(f"\n\n*Error during generation: {e}*")
                finally:
                    # Ensure the consumer loop below terminates even if an error occurred
                    text_streamer.end()

            # Start the generation thread
            thread = threading.Thread(target=generation_thread_func)
            thread.start()

            # --- Main thread: read from the queue and update the UI ---
            while True:
                try:
                    # Get the next text chunk from the queue.
                    # A timeout prevents blocking indefinitely if the thread hangs.
                    chunk = text_queue.get(block=True, timeout=1)  # Short timeout is fine for slow CPU generation
                    if chunk is text_streamer.stop_signal:  # End signal (None)
                        break
                    generated_text += chunk
                    response_placeholder.markdown(generated_text + "▌")  # Update the placeholder
                except queue.Empty:
                    # Queue is empty: check whether the generation thread is still running
                    if not thread.is_alive():
                        # Thread finished but may not have put the stop signal (error?)
                        break
                    # Otherwise, keep waiting for the next chunk
                    continue
                except Exception as e:
                    st.error(f"Error reading from generation queue: {e}")
                    print(f"Error reading from queue: {e}")
                    break  # Exit the loop on queue errors

            # Final update without the cursor
            response_placeholder.markdown(generated_text)

            # Add the complete assistant response to history *after* generation
            if generated_text:  # Only add if something was generated
                st.session_state.messages.append({"role": "assistant", "content": generated_text})
            else:
                # Handle the case where generation failed silently or produced nothing
                if not any(m["role"] == "assistant" and m["content"].startswith("*Error") for m in st.session_state.messages):
                    st.warning("The assistant produced no output.", icon="⚠️")

            # Wait briefly for the thread to finish if it hasn't already
            thread.join(timeout=5.0)  # A longer timeout may be needed if cleanup is slow
        except Exception as e:
            st.error(f"Error during generation setup or queue handling: {e}", icon="🔥")
            print(f"Error setting up generation or handling queue: {e}")
            # Add the error to the chat history for context
            error_message = f"*Error generating response: {e}*"
            if not generated_text:  # Nothing was generated before the error
                st.session_state.messages.append({"role": "assistant", "content": error_message})
                response_placeholder.error(f"Error generating response: {e}")
            else:  # Some text was generated before the error; append a notice
                st.session_state.messages.append({"role": "assistant", "content": generated_text + "\n\n" + error_message})
                response_placeholder.markdown(generated_text + "\n\n" + error_message)
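
# To launch the app (the filename app.py is illustrative; use whatever name this file is saved under):
#   streamlit run app.py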