File size: 3,623 Bytes
04b4d4a
 
 
 
 
 
 
 
 
 
5ab0078
 
 
 
 
 
04b4d4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ab0078
 
 
 
 
 
 
04b4d4a
 
 
 
 
 
 
 
5ab0078
 
04b4d4a
 
 
 
5ab0078
04b4d4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

@st.cache_resource
def _load_model_cached():
    """Load the model and tokenizer; raise on failure.

    Kept separate from ``load_model`` so the cache only ever stores a
    successful load: an exception raised here propagates to the caller and
    the load is attempted again on the next call.
    """
    model_id = "NousResearch/Llama-3.2-1B"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    # Set up padding token: fall back to EOS when the tokenizer defines none,
    # and keep the model config in agreement with the tokenizer's pad id.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.pad_token_id

    return model, tokenizer


def load_model():
    """Load model and tokenizer with caching.

    Returns:
        A ``(model, tokenizer)`` tuple on success. On any loading error the
        error is shown in the UI via ``st.error`` and ``(None, None)`` is
        returned, so callers must check for ``None`` before use.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        # Top-level UI boundary: report the problem instead of crashing the app.
        st.error(f"Error loading model: {str(e)}")
        return None, None

# Page setup: the same text is used for the page title and the on-page heading
_PAGE_TITLE = "Chat with Quasar-32B"
st.set_page_config(page_title=_PAGE_TITLE, layout="wide")
st.title(_PAGE_TITLE)

# Create the chat-history list in session state the first time it is needed
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Obtain the (model, tokenizer) pair; both are None when loading failed
model, tokenizer = load_model()

# Response generation
def generate_response(prompt):
    """Generate a reply to ``prompt`` from the module-level model.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        prompt: The user's message text. It is truncated to at most 512
            tokens before being passed to the model.

    Returns:
        The generated continuation as a stripped string, without the prompt.
        If anything fails, a string starting with
        ``"Error generating response: "`` is returned instead of raising.
    """
    try:
        # Prepare the input
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,  # Upper bound on prompt tokens
        )
        input_ids = inputs["input_ids"]

        # Generate response
        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                attention_mask=inputs["attention_mask"],
                # Budget the reply separately from the prompt: max_length
                # would count the prompt tokens too, leaving little or no
                # room for an answer when the prompt is long.
                max_new_tokens=200,
                num_return_sequences=1,
                # temperature only has an effect when sampling is enabled.
                do_sample=True,
                temperature=0.7,
                pad_token_id=tokenizer.pad_token_id,
            )

        # The output sequence starts with the prompt tokens; drop them by
        # position rather than by string replacement, which breaks whenever
        # decoding does not reproduce the prompt text exactly.
        prompt_length = input_ids.shape[1]
        new_tokens = outputs[0][prompt_length:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:
        # UI boundary: surface the failure as chat text instead of crashing.
        return f"Error generating response: {str(e)}"

# Chat interface
st.write("### Chat")
chat_container = st.container()

# Display chat history: replay every stored message under its role
with chat_container:
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

# User input: the body runs only when chat_input returns a non-empty message
if prompt := st.chat_input("Type your message here"):
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Render the new exchange in the same container as the history so all
    # messages share one place in the layout.
    with chat_container:
        with st.chat_message("user"):
            st.write(prompt)

        # load_model() signals failure with (None, None); test for that
        # sentinel explicitly instead of relying on object truthiness.
        if model is not None and tokenizer is not None:
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    response = generate_response(prompt)
                st.write(response)
            st.session_state.messages.append({"role": "assistant", "content": response})
        else:
            st.error("Model failed to load. Please check your configuration.")

# Add a button to clear chat history
if st.button("Clear Chat History"):
    st.session_state.messages = []
    # st.experimental_rerun has been superseded by st.rerun; prefer the new
    # name and fall back to the old one so either Streamlit API works.
    _rerun = getattr(st, "rerun", None) or st.experimental_rerun
    _rerun()

# Sidebar: model status summary followed by usage instructions
with st.sidebar:
    st.write("### System Information")
    st.write("Model: Quasar-32B")
    # Both objects are None when loading failed
    if model and tokenizer:
        st.write("Status: Running")
    else:
        st.write("Status: Not loaded")

    # Step-by-step usage help, one line per step
    st.write("### Instructions")
    for _instruction in (
        "1. Type your message in the chat input",
        "2. Press Enter or click Send",
        "3. Wait for the AI to respond",
        "4. Use 'Clear Chat History' to start fresh",
    ):
        st.write(_instruction)