import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Token limits: prompts longer than MAX_INPUT_TOKEN_LENGTH are truncated from the left;
# MAX_MAX_NEW_TOKENS caps the generation-length slider and DEFAULT_MAX_NEW_TOKENS is its default.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 512

# Description
DESCRIPTION = """\
# Demo for "Self-Training Elicits Concise Reasoning in Large Language Models"

This Space showcases the model [tergel/llama-3.2-3b-instruct-gsm8k-fs-gpt4o-bon](https://huggingface.co/tergel/llama-3.2-3b-instruct-gsm8k-fs-gpt4o-bon).

We provide a simple chat interface so you can observe the concise chain-of-thought (CoT) solutions the model produces. Feel free to play with it.
"""

# Decide on device
device = "cuda" if torch.cuda.is_available() else "cpu"

if not torch.cuda.is_available():
    DESCRIPTION += "\n\n**Warning**: Running on CPU 🥶 – this may be extremely slow. We will upgrade to GPUs soon."

# Load model and tokenizer
model_id = "tergel/llama-3.2-3b-instruct-gsm8k-fs-gpt4o-bon"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=None if device == "cpu" else "auto",
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
)
if device == "cpu":
    # Only move the model manually when it was loaded without a device_map;
    # models dispatched via device_map="auto" are already placed on the GPU.
    model.to(device)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.use_default_system_prompt = False

# On Hugging Face ZeroGPU Spaces, this decorator allocates a GPU for the duration of each call.
@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    system_prompt: str = "",
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 40,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
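    """Stream a response to `message` given the prior `chat_history`.

    Yields the accumulated response text chunk by chunk so the Gradio chat UI
    can render tokens as they are produced.
    """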
    # Build conversation
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    conversation += chat_history
    conversation.append({"role": "user", "content": message})
    
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed the conversation to the last {MAX_INPUT_TOKEN_LENGTH} tokens because the input was too long.")
    input_ids = input_ids.to(model.device)
    
    # TextIteratorStreamer exposes generated text as an iterator we consume below.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
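    # model.generate runs in a background thread; the streamer feeds its output back to this generator.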
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    

chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.7,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=40,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        [
            "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?"
        ],
        [
            "Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4 weeks?"
        ],
        [
            "James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?"
        ],
    ],
    cache_examples=False,
    type="messages",
)

# Compose the page: project description on top, chat interface below.
with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()


if __name__ == "__main__":
    # queue(max_size=20) caps how many requests may wait in the queue; launch() starts the Gradio server.
    demo.queue(max_size=20).launch()