File size: 5,115 Bytes
0b70ac0
fa3a13f
5a3f7ee
 
fa3a13f
5a3f7ee
 
 
fa3a13f
5a3f7ee
 
 
 
fa3a13f
5a3f7ee
 
8227d8e
47fe778
5a3f7ee
fa3a13f
5a3f7ee
 
8227d8e
 
5a3f7ee
 
 
fa3a13f
5a3f7ee
 
 
8227d8e
5a3f7ee
 
 
 
 
8227d8e
5a3f7ee
 
 
 
 
8227d8e
5a3f7ee
 
 
 
 
 
 
 
8227d8e
47fe778
5a3f7ee
 
 
 
 
 
fa3a13f
5a3f7ee
fa3a13f
5a3f7ee
 
 
 
 
 
8227d8e
5a3f7ee
47fe778
5a3f7ee
 
 
fa3a13f
5a3f7ee
 
 
 
 
 
 
fa3a13f
 
5a3f7ee
 
 
 
fa3a13f
 
5a3f7ee
 
 
 
 
8227d8e
 
 
 
 
 
 
 
 
 
5a3f7ee
 
 
8227d8e
 
5a3f7ee
 
 
 
 
fa3a13f
8227d8e
 
 
 
fa3a13f
5a3f7ee
 
0b70ac0
5a3f7ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# --- Konfiguration ---
MODEL_ID = "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B"
HF_TOKEN = os.getenv("HF_TOKEN") # Optional: Für private Modelle oder Zugriffsbeschränkungen

# --- Lade Modell und Tokenizer (explizit auf CPU) ---
print(f"Lade Tokenizer: {MODEL_ID}")
# Stelle sicher, dass trust_remote_code=True gesetzt ist, da Qwen3 dies oft benötigt
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)

if tokenizer.pad_token is None:
    print("pad_token nicht gesetzt, verwende eos_token als pad_token.")
    tokenizer.pad_token = tokenizer.eos_token

print(f"Lade Modell: {MODEL_ID} auf CPU. Dies kann einige Zeit dauern...")
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
        trust_remote_code=True,
        token=HF_TOKEN
    )
except Exception as e:
    print(f"Fehler beim Laden mit bfloat16 ({e}), versuche float32...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
        token=HF_TOKEN
    )

model.eval()
print("Modell und Tokenizer erfolgreich geladen.")

# --- Vorhersagefunktion für das ChatInterface ---
def predict(message, history):
    messages_for_template = []
    for user_msg, ai_msg in history: # history ist jetzt eine Liste von Listen/Tupeln
        messages_for_template.append({"role": "user", "content": user_msg})
        messages_for_template.append({"role": "assistant", "content": ai_msg})
    messages_for_template.append({"role": "user", "content": message})

    try:
        prompt = tokenizer.apply_chat_template(
            messages_for_template,
            tokenize=False,
            add_generation_prompt=True
        )
    except Exception as e:
        print(f"Fehler beim Anwenden des Chat-Templates: {e}")
        prompt_parts = []
        for turn in messages_for_template:
            prompt_parts.append(f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>")
        prompt = "\n".join(prompt_parts) + "\n<|im_start|>assistant\n"

    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to("cpu")

    generation_kwargs = {
        "max_new_tokens": 512,
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 50,
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
    }

    print("Generiere Antwort...")
    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    response_ids = outputs[0][inputs.input_ids.shape[-1]:]
    response = tokenizer.decode(response_ids, skip_special_tokens=True)
    print(f"Antwort: {response}")
    return response

# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft(), title="DeepSeek Qwen3 8B (CPU)") as demo:
    gr.Markdown(
        """
        # DeepSeek Qwen3 8B Chat (CPU)
        Dies ist eine Demo des `deepseek-ai/DeepSeek-R1-0528-Qwen3-8B` Modells, das auf einer CPU läuft.
        **Achtung:** Antworten können aufgrund der CPU-Inferenz **sehr langsam** sein (mehrere Minuten pro Antwort sind möglich).
        Bitte habe Geduld.
        """
    )
    chatbot_interface = gr.ChatInterface(
        fn=predict,
        chatbot=gr.Chatbot(
            height=600,
            label="Chat",
            show_label=False,
            # bubble_full_width=False, # Entfernt, da veraltet
            # type="messages" # Wichtig, um die Warnung zu beheben, aber history-Format in predict() muss passen
                            # Da predict bereits die history als [[user, ai], [user, ai]] erwartet (Standard für ChatInterface),
                            # lassen wir type hier weg, damit es mit dem Format von predict harmoniert.
                            # Wenn predict `history` als [{"role": "user", ...}, {"role": "assistant", ...}] erwarten würde,
                            # dann wäre `type="messages"` hier richtig.
                            # Da die Warnung sich auf die Standardeinstellung bezieht, die bald "messages" sein wird,
                            # und unsere predict-Funktion bereits das "tuples"-Format verarbeitet, ist das OK für jetzt.
                            # Man könnte predict anpassen, um das "messages" Format direkt zu verarbeiten, wenn man type="messages" setzt.
        ),
        textbox=gr.Textbox(
            placeholder="Stelle mir eine Frage...",
            container=False,
            scale=7
        ),
        examples=[
            ["Hallo, wer bist du?"],
            ["Was ist die Hauptstadt von Frankreich?"],
            ["Schreibe ein kurzes Gedicht über KI."]
        ],
        # Entferne die nicht unterstützten Button-Argumente:
        # retry_btn="Wiederholen",
        # undo_btn="Letzte entfernen",
        # clear_btn="Chat löschen",
    )
    gr.Markdown("Modell von [deepseek-ai](https://huggingface.co/deepseek-ai) auf Hugging Face.")

if __name__ == "__main__":
    demo.launch()