import streamlit as st
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, TextStreamer  # bitsandbytes is no longer needed
import threading
import queue  # Queue carries streamed text from the generation thread to the UI

# --- Configuration ---
DEFAULT_MODEL_PATH = "lora_model"  # Or your default path
# DEFAULT_LOAD_IN_4BIT is removed as we are not using quantization

# --- Page Configuration ---
st.set_page_config(page_title="Fine-tuned LLM Chat Interface (CPU)", layout="wide")
st.title("Fine-tuned LLM Chat Interface (CPU Mode)")
st.warning("Running in CPU mode. Expect slower generation times and higher RAM usage.", icon="⚠️")

# --- Model Loading (Cached for CPU) ---
@st.cache_resource(show_spinner="Loading model and tokenizer onto CPU...")
def load_model_and_tokenizer_cpu(model_path):
    """Loads the PEFT model and tokenizer onto the CPU."""
    try:
        # Use standard float32 for CPU compatibility and stability
        torch_dtype = torch.float32
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            # load_in_4bit=False,  # Explicitly removed/not needed
            device_map="cpu",  # Force loading onto CPU
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model.eval()  # Set model to evaluation mode
        print("Model and tokenizer loaded successfully onto CPU.")
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model from path '{model_path}' onto CPU: {e}", icon="🚨")
        print(f"Error loading model onto CPU: {e}")
        return None, None

# --- Custom Streamer Class (Modified for Queue) ---
class QueueStreamer(TextStreamer):
    """TextStreamer that pushes decoded text chunks onto a queue instead of stdout."""

    def __init__(self, tokenizer, skip_prompt, q):
        super().__init__(tokenizer, skip_prompt=skip_prompt)
        self.queue = q
        self.stop_signal = None  # Sentinel placed on the queue when generation ends

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Puts the text onto the queue; signals the end by putting the sentinel (None)."""
        self.queue.put(text)
        if stream_end:
            self.queue.put(self.stop_signal)

    # Note: the parent class's end() flushes any remaining cached tokens and calls
    # on_finalized_text(..., stream_end=True), which enqueues the stop sentinel,
    # so end() is intentionally not overridden here.

# --- Sidebar for Settings ---
with st.sidebar:
    st.header("Model Configuration")
    st.info(f"Model loaded on startup: `{DEFAULT_MODEL_PATH}` (CPU Mode).")

    st.header("Generation Settings")
    temperature = st.slider("Temperature", min_value=0.0, max_value=2.0, value=0.7, step=0.05)
    # min_p might not be as commonly used or effective without top_p/top_k,
    # but keeping it allows experimentation. Consider using top_k or top_p instead.
    # Example: top_p = st.slider("Top P", min_value=0.01, max_value=1.0, value=0.9, step=0.01)
    min_p = st.slider("Min P", min_value=0.01, max_value=1.0, value=0.1, step=0.01)  # Keep for now
    max_tokens = st.slider("Max New Tokens", min_value=50, max_value=2048, value=256, step=50)  # Reduced default for CPU

    if st.button("Clear Chat History"):
        st.session_state.messages = []
        st.rerun()  # Rerun to clear the display immediately

# --- Load Model (runs only once on first run or if cache is cleared) ---
model, tokenizer = load_model_and_tokenizer_cpu(DEFAULT_MODEL_PATH)

# --- Initialize Session State ---
if "messages" not in st.session_state:
    st.session_state.messages = []

# --- Main Chat Interface ---
if model is None or tokenizer is None:
    st.error("CPU Model loading failed. Please check the path, available RAM, and logs. Cannot proceed.")
    st.stop()

# Display conversation history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Handle user input
user_input = st.chat_input("Ask the fine-tuned model (CPU)...")

if user_input:
    # Add user message to history and display it
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    # Prepare for model response
    with st.chat_message("assistant"):
        response_placeholder = st.empty()
        response_placeholder.markdown("Generating response on CPU... please wait... ▌")  # Initial message

        text_queue = queue.Queue()  # Create a queue for this specific response

        # Initialize the modified streamer
        text_streamer = QueueStreamer(tokenizer, skip_prompt=True, q=text_queue)

        # Prepare input for the model
        messages_for_model = st.session_state.messages

        generated_text = ""  # Defined before the try block so the except handler can reference it

        try:
            # Ensure inputs are on the CPU (model.device should be 'cpu' now)
            target_device = model.device
            # print(f"Model device: {target_device}")  # Debugging: should print 'cpu'

            if tokenizer.chat_template:
                inputs = tokenizer.apply_chat_template(
                    messages_for_model,
                    tokenize=True,
                    add_generation_prompt=True,
                    return_tensors="pt",
                ).to(target_device)  # Send input tensors to CPU
            else:
                prompt_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages_for_model]) + "\nassistant:"
                inputs = tokenizer(prompt_text, return_tensors="pt").input_ids.to(target_device)  # Send input tensors to CPU

            # Generation arguments
            generation_kwargs = dict(
                input_ids=inputs,
                streamer=text_streamer,  # Use the QueueStreamer
                max_new_tokens=max_tokens,
                use_cache=True,  # KV caching still helps CPU generation speed
                temperature=temperature if temperature > 0 else None,
                top_p=None,  # Consider adding a top_p slider in the UI
                # top_k=50,  # Example: or use top_k
                min_p=min_p,
                do_sample=True if temperature > 0 else False,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
            )

            # Define the target function for the thread
            def generation_thread_func():
                try:
                    # Run generation in the background thread (on CPU).
                    # Wrap in torch.no_grad() to save memory during inference.
                    with torch.no_grad():
                        model.generate(**generation_kwargs)
                except Exception as e:
                    # Streamlit UI calls are unreliable from a background thread
                    # (no ScriptRunContext), so surface the error through the queue instead.
                    print(f"Error in generation thread: {e}")
                    text_queue.put(f"\n\n*Error during generation: {e}*")
                finally:
                    # Ensure the queue loop terminates even if an error occurred
                    text_streamer.end()

            # Start the generation thread
            thread = threading.Thread(target=generation_thread_func)
            thread.start()

            # --- Main thread: read from the queue and update the UI ---
            while True:
                try:
                    # Get the next text chunk from the queue.
                    # Use a timeout to avoid blocking indefinitely if the thread hangs.
                    chunk = text_queue.get(block=True, timeout=1)  # Short timeout is fine for slow CPU generation
                    if chunk is text_streamer.stop_signal:  # Check for end signal (None)
                        break
                    generated_text += chunk
                    response_placeholder.markdown(generated_text + "▌")  # Update placeholder
                except queue.Empty:
                    # If the queue is empty, check whether the generation thread is still running
                    if not thread.is_alive():
                        # Thread finished but may not have put the stop signal (error?)
                        break  # Exit loop
                    # Otherwise, continue waiting for the next chunk
                    continue
                except Exception as e:
                    st.error(f"Error reading from generation queue: {e}")
                    print(f"Error reading from queue: {e}")
                    break  # Exit loop on queue error

            # Final update without the cursor
            response_placeholder.markdown(generated_text)

            # Add the complete assistant response to history *after* generation
            if generated_text:  # Only add if something was generated
                st.session_state.messages.append({"role": "assistant", "content": generated_text})
            else:
                # Handle the case where generation failed silently in the thread or produced nothing
                if not any(m["role"] == "assistant" and m["content"].startswith("*Error") for m in st.session_state.messages):
                    st.warning("Assistant produced no output.", icon="⚠️")

            # Wait briefly for the thread to finish if it hasn't already
            thread.join(timeout=5.0)  # A longer timeout might be needed if cleanup is slow

        except Exception as e:
            st.error(f"Error during generation setup or queue handling: {e}", icon="🔥")
            print(f"Error setting up generation or handling queue: {e}")
            # Add the error to chat history for context
            error_message = f"*Error generating response: {e}*"
            if not generated_text:  # Add if no text was generated at all
                st.session_state.messages.append({"role": "assistant", "content": error_message})
                response_placeholder.error(f"Error generating response: {e}")
            else:  # Append an error notice if some text was generated before the error
                st.session_state.messages.append({"role": "assistant", "content": generated_text + "\n\n" + error_message})
                response_placeholder.markdown(generated_text + f"\n\n{error_message}")
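
# --- Usage notes (a sketch; the filename and package list below are assumptions) ---
# Assuming this script is saved as app.py, a typical setup and launch looks like:
#   pip install streamlit torch transformers peft accelerate
#   streamlit run app.py
# `accelerate` is listed because passing device_map to from_pretrained generally
# requires it; pin versions to match the environment used for fine-tuning.
# The LoRA adapter (and its tokenizer files) are expected under ./lora_model,
# or whatever path DEFAULT_MODEL_PATH is changed to.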