import os
from threading import Thread
from typing import Iterator

import gradio as gr  # Gradio for building the chat UI.
import spaces  # Hugging Face Spaces helpers (GPU decorator).
import torch  # PyTorch, used here for CUDA device detection.
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer  # Model, tokenizer, and token streaming from Transformers.

# Constants for maximum token lengths and defaults.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
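# This budget can be raised per deployment via the environment before launch,
# e.g. (hypothetical value; assumes the script is saved as app.py):
#   MAX_INPUT_TOKEN_LENGTH=8192 python app.py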

# Description shown at the top of the UI, naming the model and its creator.
DESCRIPTION = """\
# Masher AI v6 7B
This Space demonstrates Masher AI v6 7B by Maheswar.
"""

# Check for GPU availability, append a warning to the description if running on CPU.
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU! This demo does not work on CPU.</p>"

# If a GPU is available, load the model and tokenizer with specific configurations.
if torch.cuda.is_available():
    model_id = "mahiatlinux/MasherAI-v6-7B"
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)  # 4-bit quantization (requires bitsandbytes).
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False
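
# A minimal sketch of the equivalent explicit quantization config; newer
# transformers releases prefer this over the bare load_in_4bit flag
# (assumes bitsandbytes is installed, which the flag above already requires):
#
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       model_id,
#       device_map="auto",
#       quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#   )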

# Run generation on GPU via Hugging Face Spaces; the decorator's own queue is disabled (the demo queues at launch instead).
@spaces.GPU(enable_queue=False)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.1,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
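    """Stream the model's reply to `message` given the prior chat history.

    Yields the cumulative text generated so far, so the UI can update the
    assistant's message in place as new tokens arrive.
    """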
    # Preparing conversation history for processing.
    conversation = []
    # System prompt injection is currently disabled; `system_prompt` is accepted but unused.
    # conversation.append({"from": "human", "value": system_prompt})
    # Extending the conversation history with user and assistant interactions.
    for user, assistant in chat_history:
        conversation.extend([{"from": "human", "value": user}, {"from": "gpt", "value": assistant}])
    # Adding the latest message from the user to the conversation.
    conversation.append({"from": "human", "value": message})
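    # For illustration, after one prior exchange plus the new message, the
    # list has this ShareGPT-style shape (values are made up):
    #   [{"from": "human", "value": "Hi"},
    #    {"from": "gpt", "value": "Hello! How can I help?"},
    #    {"from": "human", "value": "Summarise this article."}]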

    # Tokenize via the chat template; if the prompt is too long, keep only the most recent MAX_INPUT_TOKEN_LENGTH tokens.
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Stream generation from a background thread: model.generate blocks, so it runs
    # there while TextIteratorStreamer feeds decoded text back to this generator.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Accumulate chunks and yield the full text so far; the chat UI replaces the displayed reply with each yield.
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

# Setup Gradio interface for chat, including additional controls for the generation parameters.
chat_interface = gr.ChatInterface(
    fn=generate,
    fill_height=True,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        # Examples to assist users in starting conversations with the AI.
    ],
)
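
# The additional inputs above are passed to `generate` positionally after
# (message, chat_history), in order: system_prompt, max_new_tokens,
# temperature, top_p, top_k, repetition_penalty.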

# Note: gr.ChatInterface renders its own chatbot component, so this standalone one is unused.
chatbot = gr.Chatbot(height=450, label="Gradio ChatInterface")
# Compose the demo with the Blocks API; it is launched at the bottom of the file.
with gr.Blocks(css="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

# Main entry point to start the web application if this script is run directly.
if __name__ == "__main__":
    demo.queue(max_size=20).launch()
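
# A minimal sketch of exercising the generator directly, outside Gradio
# (hypothetical prompt; requires the GPU branch above to have loaded the model):
#
#   for partial in generate("Hello!", chat_history=[], system_prompt=""):
#       print(partial)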