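# Streaming text-generation demo: FLAN-Alpaca served through a Gradio UI,
# with tokens pushed to the browser as they are produced by a background
# generation thread (via TextIteratorStreamer).
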
from threading import Thread  # Import the Thread class from the threading module

import torch  # Import the PyTorch library
import gradio as gr  # Import Gradio for creating a UI
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer  # Import Hugging Face Transformers

# Define the Hugging Face model ID and check for available GPU (cuda)
model_id = "declare-lab/flan-alpaca-large"
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())

# Load the pre-trained model based on the device; 8-bit loading on GPU
# reduces memory usage but requires the bitsandbytes and accelerate packages
if torch_device == "cuda":
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
else:
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Define a function to run model text generation
def run_generation(user_text, top_p, temperature, top_k, max_new_tokens):
    # Tokenize the user text and move the tensors to the model's device.
    model_inputs = tokenizer([user_text], return_tensors="pt").to(torch_device)

    # Start generation on a separate thread so that the UI is not blocked.
    # The streamer's timeout makes iteration raise an exception if the
    # generation thread dies or stalls, instead of hanging forever.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    # Bundle the tokenized inputs (input_ids, attention_mask) together with
    # the sampling settings as keyword arguments for model.generate()
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=float(temperature),
        top_k=top_k
    )
    
    # Create a new thread for model generation
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output
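
# Quick way to exercise the generator without the UI (a sketch; the prompt
# and parameter values below are illustrative, not from the app itself):
#
#   for partial in run_generation("Tell me a joke.", top_p=0.95,
#                                 temperature=0.8, top_k=50, max_new_tokens=64):
#       print(partial)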

# Define a function to reset the user input textbox
# (not currently attached to any event below)
def reset_textbox():
    return gr.update(value='')

# Create a Gradio UI interface
with gr.Blocks() as demo:
    # Display a title
    gr.Markdown(
        "# Testing the FLAN-Alpaca Model\n"
    )

    with gr.Row():
        with gr.Column(scale=4):
            # Create a textbox for user input
            user_text = gr.Textbox(
                placeholder="Ask Me Anything ... ",
                label="User input"
            )
            # Create a textbox for model output
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            # Create a submit button
            button_submit = gr.Button(value="Submit")

        with gr.Column(scale=1):
            # Create sliders for adjusting generation parameters
            max_new_tokens = gr.Slider(
                minimum=1, maximum=1000, value=250, step=1, interactive=True, label="Max New Tokens",
            )
            top_p = gr.Slider(
                minimum=0.05, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
            )
            top_k = gr.Slider(
                minimum=1, maximum=50, value=50, step=1, interactive=True, label="Top-k",
            )
            temperature = gr.Slider(
                minimum=0.1, maximum=5.0, value=0.8, step=0.1, interactive=True, label="Temperature",
            )

    # Set up the submission of user input
    user_text.submit(run_generation, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)
    button_submit.click(run_generation, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)

    # Launch the Gradio interface; queue() enables request queuing
    # with a cap of 32 pending events
    demo.queue(max_size=32).launch()
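    # Note: pass share=True to launch() to expose a temporary public URL.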