import streamlit as st
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, TextStreamer
# bitsandbytes is no longer needed
import io
import sys
import threading
import time
import queue # Import the queue module

# --- Configuration ---
DEFAULT_MODEL_PATH = "lora_model" # Or your default path
# DEFAULT_LOAD_IN_4BIT is removed as we are not using quantization

# --- Page Configuration ---
st.set_page_config(page_title="Fine-tuned LLM Chat Interface (CPU)", layout="wide")
st.title("Fine-tuned LLM Chat Interface (CPU Mode)")
st.warning("Running in CPU mode. Expect slower generation times and higher RAM usage.", icon="⚠️")

# --- Model Loading (Cached for CPU) ---
@st.cache_resource(show_spinner="Loading model and tokenizer onto CPU...")
def load_model_and_tokenizer_cpu(model_path):
    """Loads the PEFT model and tokenizer onto the CPU."""
    try:
        # Use standard float32 for CPU compatibility and stability
        torch_dtype = torch.float32
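        # Note (assumption, not from the original): on CPUs with native bfloat16 support,
        # torch.bfloat16 can roughly halve weight memory, but float32 stays the safest,
        # most widely compatible default here.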

        model = AutoPeftModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            # load_in_4bit=False, # Explicitly removed/not needed
            device_map="cpu",   # Force loading onto CPU
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model.eval() # Set model to evaluation mode
        print("Model and tokenizer loaded successfully onto CPU.")
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model from path '{model_path}' onto CPU: {e}", icon="🚨")
        print(f"Error loading model onto CPU: {e}")
        return None, None

# --- Custom Streamer Class (Modified for Queue) ---
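# model.generate() blocks until generation is complete, so it is run in a background
# thread further below. This subclass overrides TextStreamer to push each decoded text
# chunk onto a queue instead of printing it, letting the Streamlit main thread pull
# chunks and update the chat placeholder incrementally.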
class QueueStreamer(TextStreamer):
    def __init__(self, tokenizer, skip_prompt, q):
        super().__init__(tokenizer, skip_prompt=skip_prompt)
        self.queue = q
        self.stop_signal = None  # Sentinel placed on the queue to mark the end of generation

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Puts each finalized text chunk onto the queue; enqueues the sentinel at stream end."""
        self.queue.put(text)
        if stream_end:
            self.queue.put(self.stop_signal)

    def end(self):
        """Flushes any remaining buffered text, then signals the end of generation."""
        # The base implementation flushes its token cache and calls
        # on_finalized_text(..., stream_end=True), which enqueues the sentinel above.
        super().end()


# --- Sidebar for Settings ---
with st.sidebar:
    st.header("Model Configuration")
    st.info(f"Model loaded on startup: `{DEFAULT_MODEL_PATH}` (CPU Mode).")

    st.header("Generation Settings")
    temperature = st.slider("Temperature", min_value=0.0, max_value=2.0, value=0.7, step=0.05)
    # min_p keeps only tokens whose probability is at least min_p times that of the most
    # likely token. A top_p or top_k slider could be exposed here instead, e.g.:
    #   top_p = st.slider("Top P", min_value=0.01, max_value=1.0, value=0.9, step=0.01)
    min_p = st.slider("Min P", min_value=0.01, max_value=1.0, value=0.1, step=0.01)
    max_tokens = st.slider("Max New Tokens", min_value=50, max_value=2048, value=256, step=50) # Reduced default for CPU

    if st.button("Clear Chat History"):
        st.session_state.messages = []
        st.rerun() # Rerun to clear display immediately


# --- Load Model (runs only once on first run or if cache is cleared) ---
model, tokenizer = load_model_and_tokenizer_cpu(DEFAULT_MODEL_PATH)

# --- Initialize Session State ---
if "messages" not in st.session_state:
    st.session_state.messages = []

# --- Main Chat Interface ---
if model is None or tokenizer is None:
    st.error("CPU Model loading failed. Please check the path, available RAM, and logs. Cannot proceed.")
    st.stop()

# Display conversation history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Handle user input
user_input = st.chat_input("Ask the fine-tuned model (CPU)...")

if user_input:
    # Add user message to history and display it
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    # Prepare for model response
    with st.chat_message("assistant"):
        response_placeholder = st.empty()
        response_placeholder.markdown("Generating response on CPU... please wait... ▌") # Initial message
        text_queue = queue.Queue() # Create a queue for this specific response
        # Initialize the modified streamer
        text_streamer = QueueStreamer(tokenizer, skip_prompt=True, q=text_queue)

        # Prepare input for the model
        messages_for_model = st.session_state.messages
        generated_text = ""  # Initialized before the try block so the except handler can reference it

        try:
            # Ensure inputs are on the CPU (model.device should be 'cpu' now)
            target_device = model.device
            # print(f"Model device: {target_device}") # Debugging: should print 'cpu'

            if tokenizer.chat_template:
                 inputs = tokenizer.apply_chat_template(
                    messages_for_model,
                    tokenize=True,
                    add_generation_prompt=True,
                    return_tensors="pt"
                 ).to(target_device) # Send input tensors to CPU
            else:
                 prompt_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages_for_model]) + "\nassistant:"
                 inputs = tokenizer(prompt_text, return_tensors="pt").input_ids.to(target_device) # Send input tensors to CPU

            # Generation arguments
            generation_kwargs = dict(
                input_ids=inputs,
                streamer=text_streamer, # Use the QueueStreamer
                max_new_tokens=max_tokens,
                use_cache=True, # Caching can still help CPU generation speed
                temperature=temperature if temperature > 0 else None,
                top_p=None,     # Consider adding top_p slider in UI
                # top_k=50,       # Example: Or use top_k
                min_p=min_p,
                do_sample=True if temperature > 0 else False,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
            )
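            # With temperature == 0, sampling is disabled (do_sample=False) and decoding is
            # greedy; min_p and temperature only take effect when sampling is enabled.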

            # Define the target function for the thread
            def generation_thread_func():
                try:
                    # Run generation in the background thread (on CPU).
                    # torch.no_grad() avoids tracking gradients and saves memory during inference.
                    with torch.no_grad():
                        model.generate(**generation_kwargs)
                except Exception as e:
                    # Streamlit UI calls are unreliable from a background thread (no script
                    # run context), so log the error and surface it through the queue instead.
                    print(f"Error in generation thread: {e}")
                    text_queue.put(f"\n\n*Error during generation: {e}*")
                finally:
                    # Ensure the main thread's queue loop terminates even if an error occurred
                    text_streamer.end()


            # Start the generation thread
            thread = threading.Thread(target=generation_thread_func)
            thread.start()

            # --- Main thread: Read from queue and update UI ---
            generated_text = ""
            while True:
                try:
                    # Get the next text chunk from the queue
                    # Use timeout to prevent blocking indefinitely if thread hangs
                    chunk = text_queue.get(block=True, timeout=1) # Short timeout OK for slow CPU gen
                    if chunk is text_streamer.stop_signal: # Check for end signal (None)
                        break
                    generated_text += chunk
                    response_placeholder.markdown(generated_text + "▌") # Update placeholder
                except queue.Empty:
                    # If queue is empty, check if the generation thread is still running
                    if not thread.is_alive():
                        # Thread finished, but maybe didn't put the stop signal (error?)
                        break # Exit loop
                    # Otherwise, continue waiting for next chunk
                    continue
                except Exception as e:
                     st.error(f"Error reading from generation queue: {e}")
                     print(f"Error reading from queue: {e}")
                     break # Exit loop on queue error

            # Final update without the cursor
            response_placeholder.markdown(generated_text)

            # Add the complete assistant response to history *after* generation
            if generated_text: # Only add if something was generated
                st.session_state.messages.append({"role": "assistant", "content": generated_text})
            else:
                # Generation finished without producing any text (empty completion or a
                # silent failure in the worker thread)
                st.warning("Assistant produced no output.", icon="⚠️")


            # Wait briefly for the thread to finish if it hasn't already
            thread.join(timeout=5.0) # Longer timeout might be needed if cleanup is slow


        except Exception as e:
            st.error(f"Error during generation setup or queue handling: {e}", icon="🔥")
            print(f"Error setting up generation or handling queue: {e}")
            # Add error to chat history for context
            error_message = f"*Error generating response: {e}*"
            if not generated_text: # Add if no text was generated at all
                st.session_state.messages.append({"role": "assistant", "content": error_message})
                response_placeholder.error(f"Error generating response: {e}")
            else: # Append error notice if some text was generated before error
                st.session_state.messages.append({"role": "assistant", "content": generated_text + "\n\n" + error_message})
                response_placeholder.markdown(generated_text + "\n\n" + error_message)
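
# Assumed runtime dependencies (not declared in this file): streamlit, torch, transformers, peft.
# Typical launch command: streamlit run <path_to_this_script>.py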