import gradio as gr
import os
import spaces
from transformers import GemmaTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread

# Read the Hugging Face access token from the environment
HF_TOKEN = os.environ.get("HF_TOKEN", None)
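# huggingface_hub also reads HF_TOKEN from the environment directly, so fetching
# it here is enough for from_pretrained to authenticate against the gated repo.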

# Load the tokenizer and model
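# device_map="auto" lets accelerate place the weights on the available GPU(s),
# spilling over to CPU if they do not fit.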
tokenizer = GemmaTokenizer.from_pretrained("google/codegemma-7b-it")
model = AutoModelForCausalLM.from_pretrained("google/codegemma-7b-it", device_map="auto")
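
# Note: @spaces.GPU below requests ZeroGPU hardware on Hugging Face Spaces for up
# to 120 seconds per call; per the `spaces` package docs it is a no-op elsewhere.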

@spaces.GPU(duration=120)
def codegemma(message: str, history: list, temperature: float, max_new_tokens: int):
    """
    Stream a response from the CodeGemma model.

    Args:
        message (str): The input message.
        history (list): The conversation history used by ChatInterface.
        temperature (float): The sampling temperature for the response.
        max_new_tokens (int): The maximum number of new tokens to generate.

    Yields:
        str: The response generated so far, growing as tokens stream in.
    """
    chat = []
    for item in history:
        chat.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            chat.append({"role": "assistant", "content": item[1]})
    chat.append({"role": "user", "content": message})
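    # apply_chat_template wraps each turn in Gemma's <start_of_turn> markers and,
    # with add_generation_prompt=True, appends the header for the model's reply.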
    messages = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    # Tokenize the messages string
    model_inputs = tokenizer([messages], return_tensors="pt").to(model.device)
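    # TextIteratorStreamer exposes decoded tokens as an iterator while generate()
    # runs; skip_prompt drops the echoed input and timeout guards a stalled worker.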
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )
    # temperature is only honoured when sampling; fall back to greedy decoding at 0
    if temperature == 0:
        generate_kwargs["do_sample"] = False
    
    t = Thread(target=model.generate, kwargs=generate_kwargs)
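    # Start generation on a worker thread; the loop below consumes the streamer
    # on the main thread and forwards partial text to Gradio.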
    t.start()

    # Accumulate the streamed tokens into the response so far
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        # Yield the full text generated so far, not just the newest chunk
        yield partial_text
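
# ChatInterface treats a generator function as a streaming bot: each yielded
# string replaces the in-progress assistant message in the chat window.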


placeholder = """
<div style="opacity: 0.65;">
    <img src="https://ysharma-dummy-chat-app.hf.space/file=/tmp/gradio/7dd7659cff2eab51f0f5336f378edfca01dd16fa/gemma_lockup_vertical_full-color_rgb.png" style="width:30%;">
    <br><b>CodeGemma-7B-IT Chatbot</b>
</div>
"""


# Gradio block
chatbot = gr.Chatbot(placeholder=placeholder)
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("# CODEGEMMA-7b-IT")
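    # The sliders are created with render=False so that ChatInterface renders
    # them inside the "Parameters" accordion rather than where they are defined.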
    gr.ChatInterface(
        codegemma,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(0, 1, 0.95, label="Temperature", render=False),
            gr.Slider(128, 4096, 512, label="Max new tokens", render=False),
        ],
        examples=[["Write a Python function to calculate the nth Fibonacci number."]],
        cache_examples=False,
    )
    

if __name__ == "__main__":
    demo.launch(debug=False)
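
# To expose a temporary public link when running locally, use demo.launch(share=True).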