# -*- coding: utf-8 -*-
# Copyright (c) Louis Brulé Naudet. All Rights Reserved.
# This software may be used and distributed according to the terms of the License Agreement.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 2048
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
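# The input budget can be overridden from the environment before launching the app,
# e.g. (the value here is only illustrative):
#   MAX_INPUT_TOKEN_LENGTH=8192 python app.py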

def setup(
    model_id: str,
    description: str
) -> tuple:
    """
    Set up the model and tokenizer for a given pre-trained model ID.

    Parameters
    ----------
    model_id : str
        The ID of the pre-trained model to load.
    description : str
        A string containing additional description information.

    Returns
    -------
    tuple
        A tuple containing the loaded model, tokenizer, and updated description.
        If an error occurs during setup, model and tokenizer are None, and an error message is appended to the description.
    """
    if not torch.cuda.is_available():
        description += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
        return None, None, description

    try:
        # Load the model in bfloat16 and let accelerate place it on the available device(s).
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_id
        )
        tokenizer.use_default_system_prompt = False

        # Update the description
        description += "\n<p>Model and tokenizer set up successfully.</p>"

        return model, tokenizer, description

    except Exception as e:
        # If an error occurs during setup, append the error message to the description.
        description += f"\n<p>Error occurred during model setup: {str(e)}</p>"
        return None, None, description

DESCRIPTION = """\
# Pearl-7B-0211-ties, an extraordinary 7B model

This space showcases the [Pearl-7B-0211-ties](https://huggingface.co/louisbrulenaudet/Pearl-7B-0211-ties)
model by Louis Brulé Naudet, a language model with 7.24 billion parameters that achieves a score exceeding 75.10 on the Open LLM Leaderboard
(average).

**03-22-2024 - To date, louisbrulenaudet/Pearl-34B-ties is the "Best 🤝 base merges and moerges model of around 30B" on the Open LLM Leaderboard.**
"""

model, tokenizer, description = setup(
    model_id="louisbrulenaudet/Pearl-7B-0211-ties",
    description=DESCRIPTION
)

def preprocess_conversation(
    message: str,
    chat_history: list,
    system_prompt: str
):
    """
    Preprocess the conversation history by formatting it appropriately.

    Parameters
    ----------
    message : str
        The user's message.
    chat_history : list
        The conversation history, where each element is a tuple (user_message, assistant_response).
    system_prompt : str
        The system prompt.

    Returns
    -------
    list
        The formatted conversation history.
    """
    conversation = []

    if system_prompt:
        conversation.append(
            {
                "role": "system",
                "content": system_prompt
            }
        )

    for user, assistant in chat_history:
        conversation.extend(
            [
                {
                    "role": "user",
                    "content": user
                },
                {
                    "role": "assistant",
                    "content": assistant
                }
            ]
        )

    conversation.append(
        {
            "role": "user",
            "content": message
        }
    )

    return conversation
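
# Illustrative example (the messages are made up for this sketch):
#   preprocess_conversation("How?", [("Hi", "Hello!")], "Be concise")
# returns
#   [{"role": "system", "content": "Be concise"},
#    {"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello!"},
#    {"role": "user", "content": "How?"}]
# which is the messages format expected by tokenizer.apply_chat_template below.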

def trim_input_ids(
    input_ids,
    max_length
):
    """
    Trim the input token IDs if they exceed the maximum length.

    Parameters
    ----------
    input_ids : torch.Tensor
        The input token IDs.
    max_length : int
        The maximum length allowed.

    Returns
    -------
    torch.Tensor
        The trimmed input token IDs.
    """
    if input_ids.shape[1] > max_length:
        input_ids = input_ids[:, -max_length:]
        print(f"Trimmed input from conversation as it was longer than {max_length} tokens.")

    return input_ids
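
# Illustrative example: a (1, 5000)-shaped input_ids tensor with max_length=4096 is cut to
# its last 4096 columns, keeping the most recent part of the conversation.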

# Reserve a ZeroGPU device for each call; without this decorator the `spaces` import is unused
# and the function would not get a GPU on a "Running on Zero" Space.
@spaces.GPU
def generate(
    message: str,
    chat_history: list,
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
""" | |
Generate a response to a given message within a conversation context. | |
This function utilizes a pre-trained language model to generate a response to a given message, considering the conversation context provided in the chat history. | |
Parameters | |
---------- | |
message : str | |
The user's message for which a response is generated. | |
chat_history : list | |
A list containing tuples representing the conversation history. Each tuple should consist of two elements: the user's message and the assistant's response. | |
system_prompt : str | |
The system prompt, if any, to be included in the conversation context. | |
max_new_tokens : int, optional | |
The maximum number of tokens to generate for the response (default is 1024). | |
temperature : float, optional | |
The temperature parameter controlling the randomness of token generation (default is 0.6). | |
top_p : float, optional | |
The cumulative probability cutoff for token generation (default is 0.9). | |
top_k : int, optional | |
The number of top tokens to consider for token generation (default is 50). | |
repetition_penalty : float, optional | |
The repetition penalty controlling the likelihood of repeating tokens in the generated sequence (default is 1). | |
Yields | |
------ | |
str | |
A generated response to the given message. | |
Notes | |
----- | |
- This function requires a GPU for efficient processing and may not work properly on CPU. | |
- The conversation history should be provided in the form of a list of tuples, where each tuple represents a user message followed by the assistant's response. | |
""" | |
    global tokenizer
    global model

    conversation = preprocess_conversation(
        message=message,
        chat_history=chat_history,
        system_prompt=system_prompt
    )

    input_ids = tokenizer.apply_chat_template(
        conversation,
        return_tensors="pt",
        add_generation_prompt=True
    )

    input_ids = trim_input_ids(
        input_ids=input_ids,
        max_length=MAX_INPUT_TOKEN_LENGTH
    )

    input_ids = input_ids.to(
        torch.device("cuda")
    )

    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=10.0,
        skip_prompt=True,
        skip_special_tokens=True
    )

    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        # Sampling parameters exposed in the UI; without passing them here the
        # temperature, top-p and top-k controls would have no effect.
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.eos_token_id
    }

    t = Thread(
        target=model.generate,
        kwargs=generate_kwargs
    )
    t.start()

    outputs = []

    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    return "".join(outputs)

chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.0,
        ),
    ],
    fill_height=True,
    examples=[
        ["implement snake game using pygame"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["write a program to find the factorial of a number"],
    ],
)
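
# Note: gr.ChatInterface passes additional_inputs positionally after (message, chat_history),
# so the widget order above must match generate()'s parameter order:
# system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty.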

with gr.Blocks() as demo:
    # Render the description returned by setup() so that status or error messages
    # appended during model loading are actually shown in the UI.
    gr.Markdown(
        value=description
    )
    gr.DuplicateButton()
    chat_interface.render()

if __name__ == "__main__":
    demo.queue().launch(
        show_api=False
    )