import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
MAX_MAX_NEW_TOKENS = 1024
DEFAULT_MAX_NEW_TOKENS = 256
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
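# MAX_INPUT_TOKEN_LENGTH can be overridden via the environment without editing
# this file, e.g.:
#     MAX_INPUT_TOKEN_LENGTH=8192 python app.py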
DESCRIPTION = """\
# AI assistant steered by principles
"""
LICENSE = """
<p/>
---
As a derivative work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
"""
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
if torch.cuda.is_available():
    model_id = "meta-llama/Llama-2-7b-chat-hf"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False
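    # Lower-memory alternative (a sketch, not what this Space ships): loading
    # the weights in 4-bit via bitsandbytes roughly quarters their footprint.
    # from transformers import BitsAndBytesConfig
    # model = AutoModelForCausalLM.from_pretrained(
    #     model_id,
    #     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    #     device_map="auto",
    # )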
@spaces.GPU
def generate(
    message: str,
    principle_prompt: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    print(chat_history)
    # Assemble an Alpaca-style prompt: the principle preamble, then each past
    # turn as an Instruction/Response pair, then the new user message.
    conversation_string_list = [principle_prompt]
    for user, assistant in chat_history:
        conversation_string_list.append(f'\n\n### Instruction:\n{user}')
        conversation_string_list.append(f'\n\n### Response:\n{assistant}')
    conversation_string_list.append(f'\n\n### Instruction:\n{message}\n\n### Response:\n')
    conversation_string = "".join(conversation_string_list)
    print(conversation_string)
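    # With one prior exchange, for example, the assembled prompt reads:
    #
    #   <principle_prompt>
    #
    #   ### Instruction:
    #   <earlier user message>
    #
    #   ### Response:
    #   <earlier assistant reply>
    #
    #   ### Instruction:
    #   <new message>
    #
    #   ### Response: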
    input_ids = tokenizer(conversation_string, return_tensors="pt").input_ids
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
    # Run generation on a background thread and stream tokens back, so partial
    # output can be yielded to the UI as soon as it is produced.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
def get_movie_principles(movie_title: str) -> str:
    prompt = f"""Give me a list of up to 6 values/principles conveyed in {movie_title}.
Formatting-wise, don't make direct references to the series, just describe the principles. As an example of what this should look like, here is a list of values/principles from The Fellowship of the Ring (2001).
1. Have the courage to step up and take on great challenges, even when the odds seem insurmountable. Sometimes we are called to difficult journeys and must rise to the occasion with bravery.
2. True friendship means loyalty, sacrifice and being there for each other no matter what. Stick by your friends through thick and thin, and you will accomplish more together than you ever could alone.
3. Even the smallest and most seemingly insignificant person can change the course of the future. Never underestimate your own power and potential to make a difference, regardless of your size or station in life.
4. Power, when sought for its own sake, is ultimately corrupting. Pursuing power above all else will twist your soul. Instead, focus on doing what is right and serving others.
5. Have hope and keep fighting for good, even in the darkest of times. No matter how bleak things seem, your perseverance and commitment to a higher cause has meaning.
ONLY output the list, nothing else, not even a preamble introducing the list."""
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    input_ids = input_ids.to(model.device)
    generated_ids = model.generate(input_ids, num_beams=1, do_sample=True, max_length=512)
    # Decode only the newly generated tokens, slicing off the echoed prompt.
    principles_text = tokenizer.decode(generated_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return principles_text
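# Usage sketch: get_movie_principles("Her (2013)") returns the numbered list as
# plain text, ready to drop into the principles textbox. Because do_sample=True,
# selecting the same movie twice can yield different principles.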
movie_examples = [
    "The Lord of the Rings: The Fellowship of the Ring (2001)",
    "Her (2013)",
    "Star Trek TV Series",
    "Star Wars TV Series",
    "Avatar: The Last Airbender",
]
examples = [
    ["I want to do something to help address the lack of access to affordable housing and healthcare in my city. Where do I start?"],
    ["My boss is very controlling, what should I do?"],
    ["I feel pretty disconnected from the people around me, what should I do?"],
]
chatbot_instructions_principles = """This is an AI assistant created to help a user in their daily life. It can talk about topics such as daily life, social norms, popular activities, how to behave in common situations, and how to navigate interpersonal relationships in personal and professional contexts.
The user values having an AI assistant that helps them live their values. Specifically, these are principles/values that they care about, that you should help them live by:
{principles}
Every single time you make any suggestion, cite the principle you are using in square brackets.
"""
chatbot_instructions_no_principles = """This is an AI assistant designed to help a user in their daily life. It can talk about topics such as daily life, social norms, popular activities, how to behave in common situations, and how to navigate interpersonal relationships in personal and professional contexts."""
initial_principles = """1. Simple and non-prejudiced communication.
2. Open-minded curiosity and questioning."""
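# Sketch of how the pieces combine (bot() below does this for each turn):
#   system_prompt = chatbot_instructions_principles.format(principles=initial_principles)
# An empty principles textbox falls back to chatbot_instructions_no_principles.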
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    principle_list = gr.Textbox(
        lines=10,
        max_lines=20,
        value=initial_principles,
        label="Principles the chatbot follows",
        show_copy_button=True,
    )
    movie_dropdown = gr.Dropdown(choices=movie_examples, label="Select a movie to derive principles from")
    movie_dropdown.change(get_movie_principles, inputs=[movie_dropdown], outputs=principle_list)
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Enter your message")
    submit_btn = gr.Button("Submit")
    clear = gr.Button("Clear")
    def user(user_message, history):
        # Append the user's turn to the history and clear the input box.
        return "", history + [[user_message, None]]

    def bot(history, max_new_tokens, temperature, top_p, top_k, repetition_penalty, principle_list):
        principle_prompt = chatbot_instructions_no_principles if not principle_list else chatbot_instructions_principles.format(principles=principle_list)
        user_message = history[-1][0]
        chat_history = [(user_turn, assistant_turn) for user_turn, assistant_turn in history[:-1]]
        bot_message = ""
        for response in generate(user_message, principle_prompt, chat_history, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
            bot_message = response
            history[-1][1] = bot_message
            yield history
    gr.Examples(examples=examples, inputs=[msg], label="Examples")
    with gr.Accordion("Advanced Options", open=False):
        max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
        temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
        top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
        top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
        repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
    submit_btn.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot,
        [chatbot, max_new_tokens, temperature, top_p, top_k, repetition_penalty, principle_list],
        chatbot,
    )
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot,
        [chatbot, max_new_tokens, temperature, top_p, top_k, repetition_penalty, principle_list],
        chatbot,
    )
    clear.click(lambda: None, None, chatbot, queue=False)
    gr.Markdown(LICENSE)
if __name__ == "__main__":
    demo.queue(max_size=20).launch()
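# To serve outside localhost (e.g. in a container), one option (standard
# Gradio API, not what this Space uses) is:
#     demo.queue(max_size=20).launch(server_name="0.0.0.0")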