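"""Streamlit chat interface for a fine-tuned (LoRA/PEFT) causal LM, running entirely on CPU.

The model is loaded once via st.cache_resource; generation runs in a background
thread and streams decoded text to the UI through a queue-backed TextStreamer.
"""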
import streamlit as st
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, TextStreamer
# bitsandbytes is no longer needed
import io
import sys
import threading
import time
import queue # Import the queue module
# --- Configuration ---
DEFAULT_MODEL_PATH = "lora_model" # Or your default path
# DEFAULT_LOAD_IN_4BIT is removed as we are not using quantization
# --- Page Configuration ---
st.set_page_config(page_title="Fine-tuned LLM Chat Interface (CPU)", layout="wide")
st.title("Fine-tuned LLM Chat Interface (CPU Mode)")
st.warning("Running in CPU mode. Expect slower generation times and higher RAM usage.", icon="⚠️")
# --- Model Loading (Cached for CPU) ---
@st.cache_resource(show_spinner="Loading model and tokenizer onto CPU...")
def load_model_and_tokenizer_cpu(model_path):
"""Loads the PEFT model and tokenizer onto the CPU."""
try:
# Use standard float32 for CPU compatibility and stability
torch_dtype = torch.float32
model = AutoPeftModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
# load_in_4bit=False, # Explicitly removed/not needed
device_map="cpu", # Force loading onto CPU
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.eval() # Set model to evaluation mode
print("Model and tokenizer loaded successfully onto CPU.")
return model, tokenizer
except Exception as e:
st.error(f"Error loading model from path '{model_path}' onto CPU: {e}", icon="🚨")
print(f"Error loading model onto CPU: {e}")
return None, None
# --- Custom Streamer Class (Modified for Queue) ---
class QueueStreamer(TextStreamer):
    """TextStreamer subclass that pushes decoded text chunks onto a queue."""

    def __init__(self, tokenizer, skip_prompt, q):
        super().__init__(tokenizer, skip_prompt=skip_prompt)
        self.queue = q
        self.stop_signal = None  # Sentinel put on the queue when generation ends

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Puts the text onto the queue; at stream end, also puts the stop sentinel."""
        self.queue.put(text)
        if stream_end:
            self.queue.put(self.stop_signal)

    # Note: end() is deliberately not overridden. The inherited TextStreamer.end()
    # flushes any remaining buffered text and calls on_finalized_text(..., stream_end=True),
    # which puts the stop sentinel above, so the final chunk of the response is not lost.
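# This mirrors the queue-based approach of transformers' TextIteratorStreamer:
# generate() runs in a worker thread while the main Streamlit thread drains the
# queue and updates the page.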
# --- Sidebar for Settings ---
with st.sidebar:
st.header("Model Configuration")
st.info(f"Model loaded on startup: `{DEFAULT_MODEL_PATH}` (CPU Mode).")
st.header("Generation Settings")
temperature = st.slider("Temperature", min_value=0.0, max_value=2.0, value=0.7, step=0.05)
# min_p might not be as commonly used or effective without top_p/top_k,
# but keeping it allows experimentation. Consider using top_k or top_p instead.
# Example: top_p = st.slider("Top P", min_value=0.01, max_value=1.0, value=0.9, step=0.01)
min_p = st.slider("Min P", min_value=0.01, max_value=1.0, value=0.1, step=0.01) # Keep for now
max_tokens = st.slider("Max New Tokens", min_value=50, max_value=2048, value=256, step=50) # Reduced default for CPU
if st.button("Clear Chat History"):
st.session_state.messages = []
st.rerun() # Rerun to clear display immediately
# --- Load Model (runs only once on first run or if cache is cleared) ---
model, tokenizer = load_model_and_tokenizer_cpu(DEFAULT_MODEL_PATH)
# --- Initialize Session State ---
if "messages" not in st.session_state:
    st.session_state.messages = []
# --- Main Chat Interface ---
if model is None or tokenizer is None:
st.error("CPU Model loading failed. Please check the path, available RAM, and logs. Cannot proceed.")
st.stop()
# Display conversation history
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# Handle user input
user_input = st.chat_input("Ask the fine-tuned model (CPU)...")
if user_input:
    # Add user message to history and display it
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    # Prepare for model response
    with st.chat_message("assistant"):
        response_placeholder = st.empty()
        response_placeholder.markdown("Generating response on CPU... please wait... ▌")  # Initial message

        text_queue = queue.Queue()  # Create a queue for this specific response
        # Initialize the queue-backed streamer
        text_streamer = QueueStreamer(tokenizer, skip_prompt=True, q=text_queue)

        # Prepare input for the model
        messages_for_model = st.session_state.messages
        generated_text = ""  # Defined up front so the except handler below can reference it safely
        try:
            # Ensure inputs are on the CPU (model.device should be 'cpu' now)
            target_device = model.device
            # print(f"Model device: {target_device}")  # Debugging: should print 'cpu'

            if tokenizer.chat_template:
                inputs = tokenizer.apply_chat_template(
                    messages_for_model,
                    tokenize=True,
                    add_generation_prompt=True,
                    return_tensors="pt",
                ).to(target_device)  # Send input tensors to CPU
            else:
                prompt_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages_for_model]) + "\nassistant:"
                inputs = tokenizer(prompt_text, return_tensors="pt").input_ids.to(target_device)  # Send input tensors to CPU

            # Generation arguments
            generation_kwargs = dict(
                input_ids=inputs,
                streamer=text_streamer,  # Use the QueueStreamer
                max_new_tokens=max_tokens,
                use_cache=True,  # Caching can still help CPU generation speed
                temperature=temperature if temperature > 0 else None,
                top_p=None,  # Consider adding a top_p slider in the UI
                # top_k=50,  # Example: or use top_k
                min_p=min_p,
                do_sample=temperature > 0,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id,
            )
            # Define the target function for the generation thread
            def generation_thread_func():
                try:
                    # Run generation in the background thread (on CPU).
                    # torch.no_grad() avoids storing gradients and saves memory during inference.
                    with torch.no_grad():
                        model.generate(**generation_kwargs)
                except Exception as e:
                    # Streamlit calls are unreliable from a background thread, so surface
                    # the error through the queue and the console instead of st.error().
                    print(f"Error in generation thread: {e}")
                    text_queue.put(f"\n\n*Error during generation: {e}*")
                finally:
                    # Ensure the queue loop terminates even if an error occurred:
                    # the inherited end() flushes remaining text and puts the stop sentinel.
                    text_streamer.end()
            # Start the generation thread
            thread = threading.Thread(target=generation_thread_func)
            thread.start()

            # --- Main thread: read from the queue and update the UI ---
            generated_text = ""
            while True:
                try:
                    # Get the next text chunk from the queue.
                    # Use a timeout to avoid blocking indefinitely if the thread hangs.
                    chunk = text_queue.get(block=True, timeout=1)  # Short timeout is fine for slow CPU generation
                    if chunk is text_streamer.stop_signal:  # Check for the end signal (None)
                        break
                    generated_text += chunk
                    response_placeholder.markdown(generated_text + "▌")  # Update the placeholder
                except queue.Empty:
                    # Queue is empty: check whether the generation thread is still running
                    if not thread.is_alive():
                        # Thread finished but may not have put the stop signal (error?)
                        break
                    # Otherwise, keep waiting for the next chunk
                    continue
                except Exception as e:
                    st.error(f"Error reading from generation queue: {e}")
                    print(f"Error reading from queue: {e}")
                    break  # Exit the loop on queue error
            # Final update without the cursor
            response_placeholder.markdown(generated_text)

            # Add the complete assistant response to history *after* generation
            if generated_text:  # Only add if something was generated
                st.session_state.messages.append({"role": "assistant", "content": generated_text})
            else:
                # Handle the case where generation failed silently in the thread or produced nothing
                if not any(m['role'] == 'assistant' and m['content'].startswith("*Error") for m in st.session_state.messages):
                    st.warning("Assistant produced no output.", icon="⚠️")

            # Wait briefly for the thread to finish if it hasn't already
            thread.join(timeout=5.0)  # A longer timeout may be needed if cleanup is slow
        except Exception as e:
            st.error(f"Error during generation setup or queue handling: {e}", icon="🔥")
            print(f"Error setting up generation or handling queue: {e}")

            # Add the error to the chat history for context
            error_message = f"*Error generating response: {e}*"
            if not generated_text:  # No text was generated at all
                st.session_state.messages.append({"role": "assistant", "content": error_message})
                response_placeholder.error(f"Error generating response: {e}")
            else:  # Some text was generated before the error; append a notice
                st.session_state.messages.append({"role": "assistant", "content": generated_text + "\n\n" + error_message})
                response_placeholder.markdown(generated_text + "\n\n" + error_message)