import random
import threading
from typing import Iterable

import streamlit as st
import torch
# from unsloth import FastLanguageModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, PreTrainedTokenizerFast

# fine_tuned_model_name = "jed-tiotuico/twitter-llama"
# sota_model_name = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"

fine_tuned_model_name = "MBZUAI/LaMini-GPT-124M"
sota_model_name = "MBZUAI/LaMini-GPT-124M"
alpaca_input_text_format = "### Instruction:\n{}\n\n### Response:\n"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# if device is cpu try mps?
if device == "cpu":
    # check if mps is available
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

def get_model_tokenizer(sota_model_name):
    tokenizer = AutoTokenizer.from_pretrained(
        sota_model_name,
        cache_dir="/Users/jedtiotuico/.hf_cache",
        trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        sota_model_name,
        cache_dir="/Users/jedtiotuico/.hf_cache",
        trust_remote_code=True
    ).to(device)

    return model, tokenizer

def write_user_chat_message(user_chat, customer_msg):
    if customer_msg:
        if user_chat is None:
            user_chat = st.chat_message("user")

        user_chat.write(customer_msg)

def write_stream_user_chat_message(user_chat, model, tokenizer, prompt):
    if prompt:
        if user_chat is None:
            user_chat = st.chat_message("user")

        new_customer_msg = user_chat.write_stream(
            stream_generation(
                prompt,
                show_prompt=False,
                tokenizer=tokenizer,
                model=model,
            )
        )

        return new_customer_msg

def get_mistral_model_tokenizer(sota_model_name):
    tokenizer = AutoTokenizer.from_pretrained(
        sota_model_name,
        cache_dir="/Users/jedtiotuico/.hf_cache",
        trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        sota_model_name,
        cache_dir="/Users/jedtiotuico/.hf_cache",
        trust_remote_code=True
    ).to(device)

    return model, tokenizer

class DeckPicker:
    def __init__(self, items):
        self.items = items[:]  # Make a copy of the items to shuffle
        self.original_items = items[:]  # Keep the original order
        random.shuffle(self.items)  # Shuffle the items
        self.index = -1  # Initialize the index

    def pick(self):
        """Pick the next item from the deck. If all items have been picked, reshuffle."""
        self.index += 1
        if self.index >= len(self.items):
            self.index = 0
            random.shuffle(self.items)  # Reshuffle if at the end
        return self.items[self.index]

    def get_state(self):
        """Return the current state of the deck and the last picked index."""
        return self.items, self.index
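
# Hedged usage sketch (comments only, not executed by the app): DeckPicker walks a
# shuffled copy of its items and reshuffles once the deck is exhausted, so repeated
# picks stay varied without immediate repeats. The sample items below are illustrative.
#
#   picker = DeckPicker(["service", "billing", "outage"])
#   first = picker.pick()              # e.g. "billing"; order depends on the shuffle
#   second = picker.pick()             # a different item until the deck is used up
#   items, index = picker.get_state()  # current deck order and last picked index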

# Word pools that DeckPicker draws from when building few-shot prompts
nouns = [
    "service", "issue", "account", "support", "problem", "help", "team",
    "request", "response", "email", "ticket", "update", "error", "system",
    "connection", "downtime", "billing", "charge", "refund", "password",
    "outage", "agent", "feature", "access", "status", "interface", "network",
    "subscription", "upgrade", "notification", "data", "server", "log", "message",
    "renewal", "setup", "security", "feedback", "confirmation", "printer"
]

verbs = [
    "have", "print", "need", "help", "update", "resolve", "access", "contact",
    "receive", "reset", "support", "experience", "report", "request", "process",
    "check", "confirm", "explain", "manage", "handle", "disconnect", "renew",
    "change", "fix", "cancel", "complete", "notify", "respond", "fail", "restore",
    "review", "escalate", "submit", "configure", "troubleshoot", "log", "operate",
    "suspend", "pay", "adjust"
]

adjectives = [
    "quick", "immediate", "urgent", "unable", "detailed", "frequent", "technical",
    "possible", "slow", "helpful", "unresponsive", "secure", "successful", "necessary",
    "available", "scheduled", "regular", "interrupted", "automatic", "manual", "last",
    "online", "offline", "new", "current", "prior", "due", "related", "temporary",
    "permanent", "next", "previous", "complicated", "easy", "difficult", "major",
    "minor", "alternative", "additional", "expired"
]

def create_few_shots(noun_picker, verb_picker, adjective_picker):
    noun = noun_picker.pick()
    verb = verb_picker.pick()
    adjective = adjective_picker.pick()

    context = f"""
Write a short realistic customer support tweet message by a customer for another company.
Avoid adding hashtags or mentions in the message.
Ensure that the sentiment is negative.
Ensure that the word count is around 15 to 25 words.
Ensure the message contains the noun: {noun}, verb: {verb}, and adjective: {adjective}.

Example of return messages 5/5:

1/5: your website is straight up garbage. how do you sell high end technology but you cant get a website right?
2/5: my phone is all static during calls and when i plug in headphones any audio still comes thru the speaks wtf
3/5: hi, i'm having trouble logging into my groceries account it keeps refreshing back to the log in page, any ideas?
4/5: please check you dms asap if you're really about customer service. 2 weeks since my accident and nothing.
5/5: I'm extremely disappointed with your service. You charged me for a temporary solution, and there's no adjustment in sight.

Now it's your turn, ensure to only generate one message
1/1:
"""
    return context
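
# Illustrative sketch (comments only): create_few_shots fills the instruction template
# above with one noun, verb, and adjective drawn from the pools, and the caller wraps the
# result in Mistral-style [INST] tags before streaming, mirroring the button handler below.
#
#   few_shots = create_few_shots(DeckPicker(nouns), DeckPicker(verbs), DeckPicker(adjectives))
#   prompt = f"<s>[INST]{few_shots}[/INST]\n"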

st.header("ReplyCaddy")
st.write("AI-powered customer support assistant. Reduces anxiety when responding to customer support on social media.")
# image https://github.com/unslothai/unsloth/blob/main/images/made%20with%20unsloth.png?raw=true
# st.write("Made with [Unsloth](https://github.com/unslothai/unsloth/blob/main/images/made%20with%20unsloth.png?raw=true")

def stream_generation(
    prompt: str,
    tokenizer: PreTrainedTokenizerFast,
    model: AutoModelForCausalLM,
    max_new_tokens: int = 2048,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 100,
    repetition_penalty: float = 1.1,
    penalty_alpha: float = 0.25,
    no_repeat_ngram_size: int = 3,
    show_prompt: bool = False,
) -> Iterable[str]:
    """
    Stream the generation of a prompt.

    Args:
        prompt (str): the prompt
        max_new_tokens (int, optional): the maximum number of tokens to generate. Defaults to 2048.
        temperature (float, optional): the temperature of the generation. Defaults to 0.7.
        top_p (float, optional): the top-p value of the generation. Defaults to 0.9.
        top_k (int, optional): the top-k value of the generation. Defaults to 100.
        repetition_penalty (float, optional): the repetition penalty of the generation. Defaults to 1.1.
        penalty_alpha (float, optional): the penalty alpha of the generation. Defaults to 0.25.
        no_repeat_ngram_size (int, optional): the no repeat ngram size of the generation. Defaults to 3.
        show_prompt (bool, optional): whether to show the prompt or not. Defaults to False.
        tokenizer (PreTrainedTokenizerFast): the tokenizer
        model (AutoModelForCausalLM): the model

    Yields:
        str: the generated text
    """
    # init the streaming object with tokenizer
    # skip_prompt = not show_prompt, skip_special_tokens = True
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=not show_prompt, skip_special_tokens=True)  # type: ignore

    # setup kwargs for generation
    generation_kwargs = dict(
        input_ids=tokenizer(prompt, return_tensors="pt")["input_ids"].to(device),
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        penalty_alpha=penalty_alpha,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_new_tokens=max_new_tokens,
    )

    # start the generation in a separate thread
    generation_thread = threading.Thread(
        target=model.generate, kwargs=generation_kwargs  # type: ignore
    )
    generation_thread.start()

    blacklisted_tokens = ["<|url|>"]
    for new_text in streamer:
        # filter out blacklisted tokens
        if any(token in new_text for token in blacklisted_tokens):
            continue

        yield new_text

    # wait for the generation to finish
    generation_thread.join()
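
# Hedged usage sketch (comments only): stream_generation yields decoded text chunks as the
# model produces them, so it can feed st.write_stream or be consumed in a plain loop. The
# prompt string below is illustrative and follows alpaca_input_text_format.
#
#   model, tokenizer = get_model_tokenizer(sota_model_name)
#   prompt = alpaca_input_text_format.format("my order arrived damaged, what now?")
#   for chunk in stream_generation(prompt, tokenizer=tokenizer, model=model, max_new_tokens=64):
#       print(chunk, end="", flush=True)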

twitter_llama_model = None
twitter_llama_tokenizer = None
streamer = None

# define state and the chat messages
def init_session_states(assistant_chat, user_chat):
    if "user_msg_as_prompt" not in st.session_state:
        st.session_state["user_msg_as_prompt"] = ""

user_chat = None
if "user_msg_as_prompt" in st.session_state:
    user_chat = st.chat_message("user")

assistant_chat = st.chat_message("assistant")
if "greet" not in st.session_state:
    st.session_state["greet"] = False
    greeting_text = "Hello! I'm here to help. Copy and paste your customer's message, or generate using AI."
    assistant_chat.write(greeting_text)

init_session_states(assistant_chat, user_chat)

# Generate Response Tweet
if user_chat:
    if st.button("Generate Polite and Friendly Response"):
        if "user_msg_as_prompt" in st.session_state:
            customer_msg = st.session_state["user_msg_as_prompt"]
            if customer_msg:
                write_user_chat_message(user_chat, customer_msg)

                model, tokenizer = get_model_tokenizer(sota_model_name)

                input_text = alpaca_input_text_format.format(customer_msg)
                st.markdown(f"""```\n{input_text}```""", unsafe_allow_html=True)
                response_tweet = assistant_chat.write_stream(
                    stream_generation(
                        input_text,
                        show_prompt=False,
                        tokenizer=tokenizer,
                        model=model,
                    )
                )
            else:
                st.error("Please enter a customer message, or generate one for the ai to respond")

# main ui prompt
# - text box
# - submit
with st.form(key="my_form"):
    prompt = st.text_area("Customer Message")
    write_user_chat_message(user_chat, prompt)
    if st.form_submit_button("Submit"):
        assistant_chat.write("Hi, Human.")

# below ui prompt
# - examples
# st.markdown("<b>Example:</b>", unsafe_allow_html=True)
if st.button("your website is straight up garbage. how do you sell high end technology but you cant get a website right?"):
    customer_msg = "your website is straight up garbage. how do you sell high end technology but you cant get a website right?"
    st.session_state["user_msg_as_prompt"] = customer_msg
    write_user_chat_message(user_chat, customer_msg)
    model, tokenizer = get_model_tokenizer(sota_model_name)
    input_text = alpaca_input_text_format.format(customer_msg)
    st.write(f"```\n{input_text}```")
    assistant_chat.write_stream(
        stream_generation(
            input_text,
            show_prompt=False,
            tokenizer=tokenizer,
            model=model,
        )
    )

# - Generate Customer Tweet
if st.button("Generate Customer Message using Few Shots"):
    # leftover 4-bit loading settings from the unsloth/Mistral setup; unused with LaMini-GPT
    max_seq_length = 2048
    dtype = torch.float16
    load_in_4bit = True  # Use 4bit quantization to reduce memory usage. Can be False.

    model, tokenizer = get_mistral_model_tokenizer(sota_model_name)

    noun_picker = DeckPicker(nouns)
    verb_picker = DeckPicker(verbs)
    adjective_picker = DeckPicker(adjectives)
    few_shots = create_few_shots(noun_picker, verb_picker, adjective_picker)
    few_shot_prompt = f"<s>[INST]{few_shots}[/INST]\n"
    st.markdown("Prompt:")
    st.markdown(f"""```\n{few_shot_prompt}```""", unsafe_allow_html=True)

    new_customer_msg = write_stream_user_chat_message(user_chat, model, tokenizer, few_shot_prompt)
    st.session_state["user_msg_as_prompt"] = new_customer_msg


st.markdown("------------")
st.markdown("<p>Thanks to:</p>", unsafe_allow_html=True)
st.markdown("""Unsloth https://github.com/unslothai check out the [wiki](https://github.com/unslothai/unsloth/wiki)""")
st.markdown("""Georgi Gerganov's ggml https://github.com/ggerganov/ggml""")
st.markdown("""Meta's Llama https://github.com/meta-llama""")
st.markdown("""Mistral AI  - https://github.com/mistralai""")
st.markdown("""Zhang Peiyuan's TinyLlama https://github.com/jzhang38/TinyLlama""")
st.markdown("""Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois,
    Xuechen Li, Carlos Guestrin, Percy Liang, Tatsunori B. Hashimoto
    - [Alpaca: A Strong, Replicable Instruction-Following Model](https://crfm.stanford.edu/2023/03/13/alpaca.html)""")

if device == "cuda":
    gpu_stats = torch.cuda.get_device_properties(0)
    max_memory = gpu_stats.total_memory / 1024 ** 3
    start_gpu_memory = torch.cuda.memory_reserved(0) / 1024 ** 3
    st.write(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
    st.write(f"{start_gpu_memory} GB of memory reserved.")

st.write("Packages:")
st.write(f"pytorch: {torch.__version__}")