# -*- coding: utf-8 -*-
# Copyright (c) Louis Brulé Naudet. All Rights Reserved.
# This software may be used and distributed according to the terms of the License Agreement.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
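# The input budget can be overridden without editing this file, e.g.
# `MAX_INPUT_TOKEN_LENGTH=8192 python app.py`; the default keeps prompts within
# 4096 tokens (see trim_input_ids below).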
def setup(
model_id: str,
description: str
) -> tuple:
"""
Set up the model and tokenizer for a given pre-trained model ID.
Parameters
----------
model_id : str
The ID of the pre-trained model to load.
description : str
A string containing additional description information.
Returns
-------
tuple
A tuple containing the loaded model, tokenizer, and updated description.
If an error occurs during setup, model and tokenizer are None, and an error message is appended to the description.
"""
    if not torch.cuda.is_available():
        description += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
        return None, None, description

    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_id
        )
        tokenizer.use_default_system_prompt = False

        # Update the description
        description += "\n<p>Model and tokenizer set up successfully.</p>"

        return model, tokenizer, description

    except Exception as e:
        # If an error occurs during setup, append the error message to the description
        description += f"\n<p>Error occurred during model setup: {str(e)}</p>"
        return None, None, description
DESCRIPTION = """\
# Pearl-7B-0211-ties, an extraordinary 7B model
This space showcases the <a style='color:white;' href='https://huggingface.co/louisbrulenaudet/Pearl-7B-0211-ties'>Pearl-7B-0211-ties</a>
model by Louis Brulé Naudet, a language model with 7.24 billion parameters that achieves an average score above 75.10 on the
Open LLM Leaderboard.
"""
model, tokenizer, description = setup(
model_id="louisbrulenaudet/Pearl-7B-0211-ties",
description=DESCRIPTION
)
print(model, tokenizer)
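# Note: on a CPU-only host, setup() returns (None, None, description); the
# @spaces.GPU-decorated generate() below assumes the model and tokenizer loaded.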
def preprocess_conversation(
message: str,
chat_history: list,
system_prompt: str
):
"""
Preprocess the conversation history by formatting it appropriately.
Parameters
----------
message : str
The user's message.
chat_history : list
The conversation history, where each element is a tuple (user_message, assistant_response).
system_prompt : str
The system prompt.
Returns
-------
list
The formatted conversation history.
"""
conversation = []
if system_prompt:
conversation.append(
{
"role": "system",
"content": system_prompt
}
)
for user, assistant in chat_history:
conversation.extend(
[
{
"role": "user",
"content": user
},
{
"role": "assistant",
"content": assistant
}
]
)
conversation.append(
{
"role": "user",
"content": message
}
)
return conversation
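# Illustration with hypothetical values: for message="What is Python?",
# chat_history=[("Hello", "Hi, how can I help?")] and system_prompt="Be concise",
# preprocess_conversation returns:
# [
#     {"role": "system", "content": "Be concise"},
#     {"role": "user", "content": "Hello"},
#     {"role": "assistant", "content": "Hi, how can I help?"},
#     {"role": "user", "content": "What is Python?"},
# ]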
def trim_input_ids(
input_ids,
max_length
):
"""
Trim the input token IDs if they exceed the maximum length.
Parameters
----------
input_ids : torch.Tensor
The input token IDs.
max_length : int
The maximum length allowed.
Returns
-------
torch.Tensor
The trimmed input token IDs.
"""
if input_ids.shape[1] > max_length:
input_ids = input_ids[:, -max_length:]
print(f"Trimmed input from conversation as it was longer than {max_length} tokens.")
return input_ids
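# Example: with max_length=4096, a conversation encoded to a (1, 5000) tensor is
# cut down to its last 4096 token IDs, keeping the most recent turns and dropping
# the oldest ones.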
@spaces.GPU
def generate(
message: str,
chat_history: list,
system_prompt: str,
max_new_tokens: int = 1024,
temperature: float = 0.6,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1,
) -> Iterator[str]:
"""
Generate a response to a given message within a conversation context.
This function utilizes a pre-trained language model to generate a response to a given message, considering the conversation context provided in the chat history.
Parameters
----------
message : str
The user's message for which a response is generated.
chat_history : list
A list containing tuples representing the conversation history. Each tuple should consist of two elements: the user's message and the assistant's response.
system_prompt : str
The system prompt, if any, to be included in the conversation context.
max_new_tokens : int, optional
The maximum number of tokens to generate for the response (default is 1024).
temperature : float, optional
The temperature parameter controlling the randomness of token generation (default is 0.6).
top_p : float, optional
The cumulative probability cutoff for token generation (default is 0.9).
top_k : int, optional
The number of top tokens to consider for token generation (default is 50).
repetition_penalty : float, optional
The repetition penalty controlling the likelihood of repeating tokens in the generated sequence (default is 1).
Yields
------
str
A generated response to the given message.
Notes
-----
- This function requires a GPU for efficient processing and may not work properly on CPU.
- The conversation history should be provided in the form of a list of tuples, where each tuple represents a user message followed by the assistant's response.
"""
global tokenizer
global model
conversation = preprocess_conversation(
message=message,
chat_history=chat_history,
system_prompt=system_prompt
)
input_ids = tokenizer.apply_chat_template(
conversation,
return_tensors="pt",
add_generation_prompt=True
)
input_ids = trim_input_ids(
input_ids=input_ids,
max_length=MAX_INPUT_TOKEN_LENGTH
)
input_ids = input_ids.to(
torch.device("cuda")
)
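    # The streamer decodes tokens as model.generate() produces them from the
    # background thread started below, skipping the prompt and special tokens so
    # only newly generated text reaches the UI.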
streamer = TextIteratorStreamer(
tokenizer,
timeout=10.0,
skip_prompt=True,
skip_special_tokens=True
)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
        "eos_token_id": tokenizer.eos_token_id
    }
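    # model.generate() blocks until generation finishes, so it runs in a worker
    # thread while the main thread consumes the streamer and yields partial text.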
t = Thread(
target=model.generate,
kwargs=generate_kwargs
)
t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
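# gr.ChatInterface passes (message, history) followed by additional_inputs, in
# order, as positional arguments, so the widgets below must line up with
# generate()'s parameters: system_prompt, max_new_tokens, temperature, top_p,
# top_k, repetition_penalty.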
chat_interface = gr.ChatInterface(
fn=generate,
additional_inputs=[
gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
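        # Temperature slider (range and step are editorial assumptions);
        # gr.ChatInterface passes additional_inputs positionally, so generate()'s
        # temperature parameter needs a matching widget at this position.
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=2.0,
            step=0.1,
            value=0.6,
        ),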
gr.Slider(
label="Top-p (nucleus sampling)",
minimum=0.05,
maximum=1.0,
step=0.05,
value=0.9,
),
gr.Slider(
label="Top-k",
minimum=1,
maximum=1000,
step=1,
value=50,
),
gr.Slider(
label="Repetition penalty",
minimum=1.0,
maximum=2.0,
step=0.05,
            value=1.0,
),
],
stop_btn=None,
examples=[
["implement snake game using pygame"],
["Can you explain briefly to me what is the Python programming language?"],
["write a program to find the factorial of a number"],
],
)
with gr.Blocks() as demo:
gr.Markdown(
value=DESCRIPTION
)
chat_interface.render()
if __name__ == "__main__":
demo.queue().launch(
show_api=False
)