File size: 2,612 Bytes
84e4eb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, logging as hf_logging
from threading import Thread
import gradio as gr
from huggingface_hub import login

# --- Hugging Face authentication ---
# The model below is pulled from the Hub with a token; fail fast at startup
# if the token is missing rather than erroring later mid-download.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("Please set the HF_TOKEN environment variable.")
login(token=HF_TOKEN)  # authenticate this process's huggingface_hub session

hf_logging.set_verbosity_error()  # suppress warnings

# --- Model ID ---
model_id = "motionlabs/NEWT-1.7B-QWEN-PREVIEW"  # causal LM served by this app

# --- Logs helper ---
# Accumulates startup messages so the UI can display them later.
log_messages = []

def log(msg):
    """Record *msg* in the in-memory log, echo it to stdout, and return
    the full log as a newline-joined string."""
    print(msg)
    log_messages.append(msg)
    return "\n".join(log_messages)

log("Initializing tokenizer and model…")

# Load tokenizer.
# FIX: `use_auth_token=` is deprecated in transformers (removed in newer
# releases) — the supported keyword is `token=`.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
log("Tokenizer loaded.")

# Load model in half precision; device_map="auto" shards it across the
# available devices via accelerate.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    token=HF_TOKEN,
)
log("Model loaded.")

# --- Chat streaming ---
def stream_chat(history, message):
    messages = []
    for user, bot in history:
        messages.append({"role": "user", "content": user})
        if bot:
            messages.append({"role": "assistant", "content": bot})
    messages.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
    )

    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    output_text = ""
    for token in streamer:
        output_text += token
        yield history + [(message, output_text)]

# --- Gradio UI ---
# Build the chat interface: a Chatbot, a message box, a clear button, and a
# read-only snapshot of the startup logs.
with gr.Blocks(title=f"Chat with {model_id}") as demo:
    gr.Markdown(f"# Chat with {model_id}")
    
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Type your message here…")
    clear = gr.Button("Clear")
    # Captured once at build time; not updated live after launch.
    logs = gr.Textbox(label="Logs", value="\n".join(log_messages), interactive=False)

    def user_submit(user_message, history):
        # Clear the textbox and append the user turn with no bot reply yet.
        # NOTE(review): this runs before stream_chat in the chain below, so
        # stream_chat receives the already-cleared (empty) `msg` value —
        # verify stream_chat handles the pending (user, None) history entry.
        return "", history + [(user_message, None)]

    # On submit: first record the user turn, then stream the model reply.
    msg.submit(user_submit, [msg, chatbot], [msg, chatbot]).then(
        stream_chat, [chatbot, msg], chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)  # reset the chat

demo.queue()  # queuing is required for generator (streaming) handlers
demo.launch()