import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

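# Generation limits; MAX_INPUT_TOKEN_LENGTH can be overridden via an environment variable.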
MAX_MAX_NEW_TOKENS = 1024
DEFAULT_MAX_NEW_TOKENS = 256
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

DESCRIPTION = """\
# AI assistant steered by principles
"""

LICENSE = """
<p/>

---
As a derivative work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

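# Load Llama-2-7b-chat in half precision only when a GPU is available; on CPU the demo stays disabled.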
if torch.cuda.is_available():
    model_id = "meta-llama/Llama-2-7b-chat-hf"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False

@spaces.GPU
def generate(
    message: str,
    principle_prompt: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
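    """Build an Alpaca-style Instruction/Response prompt from the chat history and stream the reply."""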
    print(chat_history)
    conversation_string_list = [principle_prompt]
    for user, assistant in chat_history:
        conversation_string_list.append(f'\n\n### Instruction:\n{user}')
        conversation_string_list.append(f'\n\n### Response:\n{assistant}')
    conversation_string_list.append(f'\n\n### Instruction:\n{message} \n\n### Response:\n')
    conversation_string = "".join(conversation_string_list)
    print(conversation_string)
    input_ids = tokenizer(conversation_string, return_tensors="pt").input_ids
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

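    # Run generation in a background thread and stream decoded tokens back as they arrive.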
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

@spaces.GPU
def get_movie_principles(movie_title: str) -> str:
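    """Ask the model for a short list of values/principles conveyed by the given title, using a one-shot prompt."""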
    prompt = f"""Give me a list of up to 6 values/principles conveyed in {movie_title}. 

Formatting-wise, don't make direct references to the series, just describe the principles. As an example of what this should look like, here is a list of values/principles from The Fellowship of the Ring (2001). 

1. Have the courage to step up and take on great challenges, even when the odds seem insurmountable. Sometimes we are called to difficult journeys and must rise to the occasion with bravery.
2. True friendship means loyalty, sacrifice and being there for each other no matter what. Stick by your friends through thick and thin, and you will accomplish more together than you ever could alone.
3. Even the smallest and most seemingly insignificant person can change the course of the future. Never underestimate your own power and potential to make a difference, regardless of your size or station in life.
4. Power, when sought for its own sake, is ultimately corrupting. Pursuing power above all else will twist your soul. Instead, focus on doing what is right and serving others.
5. Have hope and keep fighting for good, even in the darkest of times. No matter how bleak things seem, your perseverance and commitment to a higher cause has meaning.

ONLY output the list, nothing else, not even a preamble introducing the list."""
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    input_ids = input_ids.to(model.device)
    generated_ids = model.generate(input_ids, num_beams=1, do_sample=True, max_new_tokens=512)
    principles_text = tokenizer.decode(generated_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return principles_text


movie_examples = [
    "The Lord of the Rings: The Fellowship of the Ring (2001)",
    "Her (2013)",
    "Star Trek TV Series",
    "Star Wars TV Series",
    "Avatar: The Last Airbender",
]

examples = [
     ["I want to do something to help address an issue of lack of access to affordable housing and healthcare in my city. Where do I start?"],
     ["My boss is very controlling, what should I do?"],
     ["I feel pretty disconnected from the people around me, what should I do?"],
]

chatbot_instructions_principles = """This is an AI assistant created to help a user in their daily life. It can talk about topics such as daily life, social norms, popular activities, how to behave in common situations, and how to navigate interpersonal relationships in personal and professional contexts.

The user values having an AI assistant that helps them live their values. Specifically, these are principles/values that they care about, that you should help them live by:
{principles}

Every single time you make any suggestion, cite the principle you are using in square brackets.
"""

chatbot_instructions_no_principles = """This is an AI assistant designed to help a user in their daily life. It can talk about topics such as daily life, social norms, popular activities, how to behave in common situations, and how to navigate interpersonal relationships in personal and professional contexts."""

initial_principles = """1. Simple and non-prejudiced communication. 
2. Open-minded curiosity and questioning."""

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    
    principle_list = gr.Textbox(lines=10, max_lines=20,
                             value=initial_principles,
                             label="Principles the chatbot follows",
                             show_copy_button=True)

    movie_dropdown = gr.Dropdown(choices=movie_examples, label="Select a movie to derive principles from")

    movie_dropdown.change(get_movie_principles, inputs=[movie_dropdown], outputs=principle_list)

    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Enter your message")
    submit_btn = gr.Button("Submit")
    clear = gr.Button("Clear")

    def user(user_message, history):
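        # Append the new user turn (with no reply yet) and clear the input textbox.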
        return "", history + [[user_message, None]]

    def bot(history, max_new_tokens, temperature, top_p, top_k, repetition_penalty, principle_list):
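        # Pick the principle-steered system prompt when principles are given, else the plain one.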
        principle_prompt = chatbot_instructions_no_principles if not principle_list else chatbot_instructions_principles.format(principles=principle_list)
        user_message = history[-1][0]
        chat_history = [(pair[0], pair[1]) for pair in history[:-1]]
        bot_message = ""
        for response in generate(user_message, principle_prompt, chat_history, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
            bot_message = response
            history[-1][1] = bot_message
            yield history

    gr.Examples(examples=examples, inputs=[msg], label="Examples")

    with gr.Accordion("Advanced Options", open=False):
        max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
        temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
        top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
        top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
        repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)

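    # Wire both the Submit button and the Enter key to the same two-stage handler: user() then bot().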
    submit_btn.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot,
        [chatbot, max_new_tokens, temperature, top_p, top_k, repetition_penalty, principle_list],
        chatbot,
    )
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot,
        [chatbot, max_new_tokens, temperature, top_p, top_k, repetition_penalty, principle_list],
        chatbot,
    )
    clear.click(lambda: None, None, chatbot, queue=False)
    
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()