import streamlit as st
from transformers import pipeline, AutoTokenizer

def initialize_model():
    """Initialize a small and fast model for CPU"""
    # Use a tiny model optimized for CPU inference (125M parameters); the larger
    # "GEB-AGI/geb-1.3b" checkpoint could be swapped in here at the cost of a
    # slower load and more memory.
    model_id = "facebook/opt-125m"
    
    try:
        # Initialize the pipeline directly - more efficient than loading model separately
        pipe = pipeline(
            "text-generation",
            model=model_id,
            device_map="cpu",
            model_kwargs={"low_cpu_mem_usage": True}
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        
        return pipe, tokenizer
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
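
# Optional alternative (a sketch; main() below does not call it): Streamlit's
# st.cache_resource decorator can keep the loaded pipeline alive across reruns
# in place of the manual st.session_state bookkeeping in main(). The wrapper
# name load_model_cached is illustrative.
@st.cache_resource
def load_model_cached():
    """Load the pipeline and tokenizer once per process and reuse them."""
    return initialize_model()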

def generate_response(pipe, tokenizer, prompt, conversation_history):
    """Generate model response"""
    try:
        # Format conversation context from completed turns only; the current
        # turn is appended (with an empty assistant reply) before this function
        # is called, so skip it to avoid repeating the prompt.
        context = ""
        for turn in conversation_history[:-1][-3:]:  # last 3 completed turns
            context += f"Human: {turn['user']}\nAssistant: {turn['assistant']}\n"
        
        # Create the full prompt
        full_prompt = f"{context}Human: {prompt}\nAssistant:"
        
        # Generate response with conservative parameters
        response = pipe(
            full_prompt,
            max_new_tokens=50,  # Limit response length
            temperature=0.7,
            top_p=0.9,
            num_return_sequences=1,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
        )[0]['generated_text']
        
        # Extract only the assistant's reply. The pipeline echoes the prompt,
        # and the model may keep writing further "Human:" turns, so strip the
        # prompt and cut the continuation at the first such turn.
        try:
            if response.startswith(full_prompt):
                continuation = response[len(full_prompt):]
            else:
                continuation = response.split("Assistant:")[-1]
            assistant_response = continuation.split("Human:")[0].strip()
            if not assistant_response:
                return "I apologize, but I couldn't generate a proper response."
            return assistant_response
        except Exception:
            return response.split(prompt)[-1].strip()

    except Exception as e:
        return f"An error occurred: {str(e)}"

def main():
    st.set_page_config(page_title="LLM Chat Interface", page_icon="πŸ€–")
    
    st.title("πŸ’¬ Quick Chat Assistant")

    # Initialize session state
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    
    if "model_loaded" not in st.session_state:
        st.session_state.model_loaded = False

    # Initialize model (only once)
    if not st.session_state.model_loaded:
        with st.spinner("Loading the model... (this should take just a few seconds)"):
            try:
                pipe, tokenizer = initialize_model()
                st.session_state.pipe = pipe
                st.session_state.tokenizer = tokenizer
                st.session_state.model_loaded = True
            except Exception as e:
                st.error(f"Error loading model: {str(e)}")
                return

    # Display chat messages
    for message in st.session_state.chat_history:
        with st.chat_message("user"):
            st.write(message["user"])
        with st.chat_message("assistant"):
            st.write(message["assistant"])

    # Chat input
    if prompt := st.chat_input("Ask me anything!"):
        # Display user message
        with st.chat_message("user"):
            st.write(prompt)

        # Generate and display assistant response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                current_turn = {"user": prompt, "assistant": ""}
                st.session_state.chat_history.append(current_turn)
                
                response = generate_response(
                    st.session_state.pipe,
                    st.session_state.tokenizer,
                    prompt,
                    st.session_state.chat_history
                )
                
                st.write(response)
                st.session_state.chat_history[-1]["assistant"] = response

        # Keep only last 5 turns
        if len(st.session_state.chat_history) > 5:
            st.session_state.chat_history = st.session_state.chat_history[-5:]

    # Sidebar
    with st.sidebar:
        if st.button("Clear Chat"):
            st.session_state.chat_history = []
            st.rerun()
        
        st.markdown("---")
        st.markdown("""
        ### Chat Info
        - Using OPT-125M model
        - Optimized for quick responses
        - Best for short conversations
        """)

if __name__ == "__main__":
    main()
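
# To run this app locally (assuming the file is saved as app.py and the
# streamlit, transformers, and torch packages are installed):
#
#   streamlit run app.py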