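# Gradio demo for universeTBD/astrollama: the model is loaded in 4-bit precision
# and completions are streamed from a background generation thread.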
from threading import Thread
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    TextIteratorStreamer
)

MODEL_ID = "universeTBD/astrollama"
WINDOW_SIZE = 4096  # maximum total length in tokens (prompt + generated continuation)
DEVICE = "cuda"     # device the tokenized prompt is moved to

config = AutoConfig.from_pretrained(pretrained_model_name_or_path=MODEL_ID)

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID
)

# Load the model in 4-bit precision (bitsandbytes), placing weights across available devices
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    config=config,
    device_map="auto",
    use_safetensors=True,
    trust_remote_code=True,
    load_in_4bit=True,
    torch_dtype=torch.bfloat16
)


def generate_text(prompt: str,
                  max_new_tokens: int = 512,
                  temperature: float = 0.5,
                  top_p: float = 0.95,
                  top_k: int = 50) -> str:
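    """Sample a continuation of `prompt` and return the prompt plus the completion."""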
    # Encode the prompt
    inputs = tokenizer([prompt],
                       return_tensors="pt",
                       add_special_tokens=False).to(DEVICE)
    # Prepare arguments for generation
    input_length = inputs["input_ids"].shape[-1]
    # Keep prompt plus completion within the context window (and request at least 1 new token)
    max_new_tokens = max(1, min(max_new_tokens, WINDOW_SIZE - input_length))
    # Clamp sampling parameters to their valid ranges
    if temperature >= 1.0:
        temperature = 0.99
    elif temperature <= 0.0:
        temperature = 0.01
    if top_p > 1.0 or top_p <= 0.0:
        top_p = 1.0
    if top_k <= 0:
        top_k = 100
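    # The streamer yields decoded text chunks as generate() (run in a thread below) produces them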
    streamer = TextIteratorStreamer(tokenizer,
                                    timeout=10.,
                                    skip_prompt=True,
                                    skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
    )
    # Generate text
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
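    # Collect the streamed chunks; the loop ends once generation is complete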
    generated_text = prompt
    for new_text in streamer:
        generated_text += new_text
    return generated_text


# Gradio UI: a prompt textbox plus sampling sliders, with a single textbox for the generated text
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        # Prompt
        gr.Textbox(
            label="Prompt",
            container=False,
            show_label=False,
            placeholder="Enter some text...",
            lines=10,
            scale=10,
        ),
        gr.Slider(
            label="Maximum new tokens",
            minimum=1,
            maximum=4096,
            step=1,
            value=1024,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.01,
            maximum=0.99,
            step=0.01,
            value=0.5,
        ),
        gr.Slider(
            label="Top-p (for sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label="Top-k (for sampling)",
            minimum=1,
            maximum=1000,
            step=1,
            value=100,
        )
    ],
    outputs=[
        gr.Textbox(
            container=False,
            show_label=False,
            placeholder="Generated output...",
            scale=10,
            lines=10,
        )
    ],
)

# Queue incoming requests (up to 20 waiting) and launch the app
demo.queue(max_size=20).launch()