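"""Streamlit chat UI for running DeepSeek-R1-Distill-Qwen-7B locally.

Requires streamlit, torch, and transformers (plus accelerate for device_map="auto").
Usage, assuming this file is saved as app.py:

    streamlit run app.py
"""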
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, Any

# Configure model (updated for local execution)
DEFAULT_SYSTEM_PROMPT = """You are a friendly Assistant. Provide clear, accurate, and brief answers. 
Keep responses polite, engaging, and to the point. If unsure, politely suggest alternatives."""

MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"  # Directly specify model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Page configuration
st.set_page_config(
    page_title="DeepSeek-AI R1",
    page_icon="🤖",
    layout="centered"
)

def initialize_session_state():
    """Initialize all session state variables"""
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "model_loaded" not in st.session_state:
        st.session_state.update({
            "model_loaded": False,
            "model": None,
            "tokenizer": None
        })

def load_model():
    """Load model and tokenizer with quantization"""
    if not st.session_state.model_loaded:
        with st.spinner("Loading model (this may take a minute)..."):
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                trust_remote_code=True,
                torch_dtype=torch.float16,
                device_map="auto"
            )
            
            st.session_state.update({
                "model": model,
                "tokenizer": tokenizer,
                "model_loaded": True
            })

def configure_sidebar() -> Dict[str, Any]:
    """Create sidebar components"""
    with st.sidebar:
        st.header("Configuration")
        return {
            "system_message": st.text_area("System Message", value=DEFAULT_SYSTEM_PROMPT, height=100),
            "max_tokens": st.slider("Max Tokens", 10, 4000, 512),
            "temperature": st.slider("Temperature", 0.1, 1.0, 0.7),
            "top_p": st.slider("Top-p", 0.1, 1.0, 0.9)
        }

def format_prompt(system_message: str, user_input: str) -> str:
    """Format prompt according to model's required template"""
    return f"""<|begin_of_sentence|>System: {system_message}
<|User|>{user_input}<|Assistant|>"""

def generate_response(prompt: str, settings: Dict[str, Any]) -> str:
    """Generate a response with the locally loaded model"""
    inputs = st.session_state.tokenizer(prompt, return_tensors="pt").to(DEVICE)

    outputs = st.session_state.model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=settings["max_tokens"],
        do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
        temperature=settings["temperature"],
        top_p=settings["top_p"],
        pad_token_id=st.session_state.tokenizer.eos_token_id
    )

    response = st.session_state.tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the text after the reasoning block, then strip leftover template markers
    response = response.split("\n</think>\n")[-1].strip()
    response = response.replace("<|User|>", "").strip()
    response = response.replace("<|System|>", "").strip()
    return response.split("<|Assistant|>")[-1].strip()

def handle_chat_interaction(settings: Dict[str, Any]):
    """Manage chat interactions"""
    if prompt := st.chat_input("Type your message..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        
        with st.chat_message("user"):
            st.markdown(prompt)

        try:
            with st.spinner("Generating response..."):
                full_prompt = format_prompt(
                    settings["system_message"],
                    prompt
                )
                
                response = generate_response(full_prompt, settings)
                
                with st.chat_message("assistant"):
                    st.markdown(response)
                st.session_state.messages.append({"role": "assistant", "content": response})
        
        except Exception as e:
            st.error(f"Generation error: {str(e)}")

def display_chat_history():
    """Display chat history"""
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

def main():
    initialize_session_state()
    load_model()  # Load model before anything else
    settings = configure_sidebar()
    
    st.title("🤖 DeepSeek Chat")
    st.caption(f"Running {MODEL_NAME} directly on {DEVICE.upper()}")
    
    display_chat_history()
    handle_chat_interaction(settings)

if __name__ == "__main__":
    main()