import datetime
from google.protobuf import message
import torch
import time
import threading
import streamlit as st
import random
from typing import Iterable
# from unsloth import FastLanguageModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, PreTrainedTokenizerFast
from datetime import datetime
from threading import Thread
import os

# NOTE(review): `message`, `time` and `Thread` are not used anywhere visible in
# this file, and `from datetime import datetime` shadows the `datetime` module
# imported above. Kept unchanged; consider removing them.

# fine_tuned_model_name = "jed-tiotuico/twitter-llama"
# sota_model_name = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
fine_tuned_model_name = "MBZUAI/LaMini-GPT-124M"
sota_model_name = "MBZUAI/LaMini-GPT-124M"

# Alpaca-style instruction prompt; `{}` is filled with the customer's message.
alpaca_input_text_format = "### Instruction:\n{}\n\n### Response:\n"

# Hugging Face download cache. This used to be hard-coded to one developer's
# home directory, which does not exist on other machines. It now defaults to
# ~/.hf_cache and can be overridden with the HF_CACHE_DIR environment variable.
HF_CACHE_DIR = os.environ.get("HF_CACHE_DIR") or os.path.expanduser("~/.hf_cache")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Without CUDA, try Apple's MPS backend. Compare `device.type`: a torch.device
# never compares equal to a plain string, so `device == "cpu"` was always False
# and this fallback never ran.
if device.type == "cpu":
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")


@st.cache_resource
def get_model_tokenizer(sota_model_name):
    """Load a causal LM and its tokenizer from the Hugging Face Hub.

    The result is cached with ``st.cache_resource``, so the weights are loaded
    once per process instead of on every Streamlit rerun (i.e. every click).

    Args:
        sota_model_name: Hub repo id (or local path) of the model to load.

    Returns:
        ``(model, tokenizer)`` with the model already moved to ``device``.
    """
    # NOTE: trust_remote_code=True executes Python code shipped with the model
    # repository; only point this at repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained(
        sota_model_name, cache_dir=HF_CACHE_DIR, trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        sota_model_name, cache_dir=HF_CACHE_DIR, trust_remote_code=True
    ).to(device)
    return model, tokenizer


def write_user_chat_message(user_chat, customer_msg):
    """Write *customer_msg* into the user chat bubble.

    Does nothing for an empty or None message. When *user_chat* is None, a new
    ``st.chat_message("user")`` container is created for this write.
    """
    if not customer_msg:
        return
    if user_chat is None:
        user_chat = st.chat_message("user")
    user_chat.write(customer_msg)


def write_stream_user_chat_message(user_chat, model, token, prompt):
    """Stream a model-generated message for *prompt* into the user chat bubble.

    Args:
        user_chat: Target chat container, or None to create a new user bubble.
        model: Causal LM passed on to ``stream_generation``.
        token: Tokenizer matching *model*. This was previously ignored in
            favour of a module-level ``tokenizer`` global that only exists once
            some other code path has assigned it.
        prompt: Text to generate from. When empty, nothing is generated.

    Returns:
        The value returned by ``write_stream`` (the full streamed response), or
        None when *prompt* is empty. An empty prompt previously raised
        UnboundLocalError.
    """
    if not prompt:
        return None
    if user_chat is None:
        user_chat = st.chat_message("user")
    return user_chat.write_stream(
        stream_generation(
            prompt,
            show_prompt=False,
            tokenizer=token,
            model=model,
        )
    )


def get_mistral_model_tokenizer(sota_model_name):
    """Return ``(model, tokenizer)`` for the few-shot customer-message generator.

    This was a verbatim copy of ``get_model_tokenizer``. It now delegates, so
    both share one loader and one cache entry per model name.
    """
    return get_model_tokenizer(sota_model_name)
class DeckPicker:
    """Deal items one at a time in shuffled order, reshuffling when exhausted.

    Within one pass through the deck, every item is returned exactly once.
    """

    def __init__(self, items):
        self.items = items[:]  # Make a copy of the items to shuffle
        self.original_items = items[:]  # Keep the original order
        random.shuffle(self.items)  # Shuffle the items
        self.index = -1  # Index of the last picked item; -1 means none picked yet

    def pick(self):
        """Pick the next item from the deck. If all items have been picked, reshuffle."""
        self.index += 1
        if self.index >= len(self.items):
            self.index = 0
            random.shuffle(self.items)  # Reshuffle if at the end
        return self.items[self.index]

    def get_state(self):
        """Return the current state of the deck and the last picked index."""
        return self.items, self.index


# Example of usage
nouns = [
    "service", "issue", "account", "support", "problem", "help", "team", "request", "response", "email",
    "ticket", "update", "error", "system", "connection", "downtime", "billing", "charge", "refund", "password",
    "outage", "agent", "feature", "access", "status", "interface", "network", "subscription", "upgrade", "notification",
    "data", "server", "log", "message", "renewal", "setup", "security", "feedback", "confirmation", "printer",
]

verbs = [
    "have", "print", "need", "help", "update", "resolve", "access", "contact", "receive", "reset",
    "support", "experience", "report", "request", "process", "check", "confirm", "explain", "manage", "handle",
    "disconnect", "renew", "change", "fix", "cancel", "complete", "notify", "respond", "fail", "restore",
    "review", "escalate", "submit", "configure", "troubleshoot", "log", "operate", "suspend", "pay", "adjust",
]

adjectives = [
    "quick", "immediate", "urgent", "unable", "detailed", "frequent", "technical", "possible", "slow", "helpful",
    "unresponsive", "secure", "successful", "necessary", "available", "scheduled", "regular", "interrupted", "automatic", "manual",
    "last", "online", "offline", "new", "current", "prior", "due", "related", "temporary", "permanent",
    "next", "previous", "complicated", "easy", "difficult", "major", "minor", "alternative", "additional", "expired",
]


def create_few_shots(noun_picker, verb_picker, adjective_picker):
    """Build a few-shot prompt that asks the model for one negative customer tweet.

    One word is drawn from each picker and embedded in the instructions, so
    successive prompts steer the model toward different topics.

    Returns:
        The prompt text (without any model-specific instruction markers).
    """
    noun = noun_picker.pick()
    verb = verb_picker.pick()
    adjective = adjective_picker.pick()
    context = f"""
Write a short realistic customer support tweet message by a customer for another company.
Avoid adding hashtags or mentions in the message.
Ensure that the sentiment is negative.
Ensure that the word count is around 15 to 25 words.
Ensure the message contains the noun: {noun}, verb: {verb}, and adjective: {adjective}.

Example of return messages 5/5:

1/5: your website is straight up garbage. how do you sell high end technology but you cant get a website right?
2/5: my phone is all static during calls and when i plug in headphones any audio still comes thru the speaks wtf
3/5: hi, i'm having trouble logging into my groceries account it keeps refreshing back to the log in page, any ideas?
4/5: please check you dms asap if you're really about customer service. 2 weeks since my accident and nothing.
5/5: I'm extremely disappointed with your service. You charged me for a temporary solution, and there's no adjustment in sight.

Now it's your turn, ensure to only generate one message
1/1:
"""
    return context


st.header("ReplyCaddy")
st.write("AI-powered customer support assistant. Reduces anxiety when responding to customer support on social media.")

# image https://github.com/unslothai/unsloth/blob/main/images/made%20with%20unsloth.png?raw=true
# st.write("Made with [Unsloth](https://github.com/unslothai/unsloth/blob/main/images/made%20with%20unsloth.png?raw=true")


def stream_generation(
    prompt: str,
    tokenizer: PreTrainedTokenizerFast,
    model: AutoModelForCausalLM,
    max_new_tokens: int = 2048,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 100,
    repetition_penalty: float = 1.1,
    penalty_alpha: float = 0.25,
    no_repeat_ngram_size: int = 3,
    show_prompt: bool = False,
) -> Iterable[str]:
    """
    Stream the generation of a prompt.

    Args:
        prompt (str): the prompt
        tokenizer (PreTrainedTokenizerFast): the tokenizer
        model (AutoModelForCausalLM): the model
        max_new_tokens (int, optional): the maximum number of tokens to generate.
            It is clamped so that prompt + generation fits the model's
            ``max_position_embeddings`` when the config defines it. Defaults to 2048.
        temperature (float, optional): the temperature of the generation. Defaults to 0.7.
        top_p (float, optional): the top-p value of the generation. Defaults to 0.9.
        top_k (int, optional): the top-k value of the generation. Defaults to 100.
        repetition_penalty (float, optional): the repetition penalty of the generation. Defaults to 1.1.
        penalty_alpha (float, optional): the penalty alpha of the generation. Defaults to 0.25.
            NOTE(review): transformers uses this for contrastive search, which
            presumably does not apply with ``do_sample=True``; confirm.
        no_repeat_ngram_size (int, optional): the no repeat ngram size of the generation. Defaults to 3.
        show_prompt (bool, optional): whether to show the prompt or not. Defaults to False.

    Yields:
        str: the generated text, chunk by chunk, with blacklisted tokens removed

    Raises:
        Exception: whatever ``model.generate`` raised in the worker thread,
            re-raised once the stream ends (it used to hang forever instead).
    """
    # skip_prompt = not show_prompt, skip_special_tokens = True
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=not show_prompt, skip_special_tokens=True)  # type: ignore

    # Pass the attention mask along with the ids so generate() does not have to
    # guess which positions are padding.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    input_length = input_ids.shape[-1]

    # Do not ask for more tokens than the context window can hold. GPT-2-style
    # models (e.g. the default LaMini-GPT-124M) fail once they run past their
    # position table.
    max_positions = getattr(getattr(model, "config", None), "max_position_embeddings", None)
    if max_positions:
        max_new_tokens = max(1, min(max_new_tokens, max_positions - input_length))

    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # Same fallback generate() applies itself, minus the warning.
        pad_token_id = tokenizer.eos_token_id

    generation_kwargs = dict(
        input_ids=input_ids,
        attention_mask=inputs.get("attention_mask"),
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        penalty_alpha=penalty_alpha,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_new_tokens=max_new_tokens,
        pad_token_id=pad_token_id,
    )

    errors = []

    def _run_generation():
        # Run generate() in the background; on failure, record the error and
        # end the stream so the consumer loop below does not block forever.
        try:
            model.generate(**generation_kwargs)  # type: ignore
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    generation_thread = threading.Thread(target=_run_generation)
    generation_thread.start()

    blacklisted_tokens = ["<|url|>"]
    for new_text in streamer:
        # Strip blacklisted tokens instead of dropping the whole chunk, which
        # could also discard legitimate text decoded alongside them.
        for token in blacklisted_tokens:
            new_text = new_text.replace(token, "")
        if new_text:
            yield new_text

    # wait for the generation to finish
    generation_thread.join()
    if errors:
        raise errors[0]


twitter_llama_model = None
twitter_llama_tokenizer = None
streamer = None


# define state and the chat messages
def init_session_states(assistant_chat, user_chat):
    """Seed the ``st.session_state`` keys this page reads.

    ``user_msg_as_prompt`` holds the customer message that the "Generate
    Polite and Friendly Response" button answers. Both parameters are unused
    and are kept only for call compatibility.
    """
    if "user_msg_as_prompt" not in st.session_state:
        st.session_state["user_msg_as_prompt"] = ""


def _respond_to_customer(customer_msg, user_chat, assistant_chat):
    """Echo *customer_msg* in the user bubble and stream a reply into the assistant bubble."""
    write_user_chat_message(user_chat, customer_msg)
    model, tokenizer = get_model_tokenizer(sota_model_name)
    input_text = alpaca_input_text_format.format(customer_msg)
    # st.code shows the prompt verbatim. It used to be interpolated into a
    # Markdown fence with unsafe_allow_html=True, so a message containing ```
    # could close the fence and inject raw HTML.
    st.code(input_text, language=None)
    return assistant_chat.write_stream(
        stream_generation(
            input_text,
            show_prompt=False,
            tokenizer=tokenizer,
            model=model,
        )
    )


# Seed session state *before* reading it. It used to be initialised after the
# user-bubble check, so on the first run that bubble (and with it the
# "Generate Polite and Friendly Response" button) was missing.
init_session_states(None, None)

user_chat = st.chat_message("user")
assistant_chat = st.chat_message("assistant")

if "greet" not in st.session_state:
    st.session_state["greet"] = False
    greeting_text = "Hello! I'm here to help. Copy and paste your customer's message, or generate using AI."
    assistant_chat.write(greeting_text)

# Generate Response Tweet
if st.button("Generate Polite and Friendly Response"):
    customer_msg = st.session_state.get("user_msg_as_prompt")
    if customer_msg:
        _respond_to_customer(customer_msg, user_chat, assistant_chat)
    else:
        st.error("Please enter a customer message, or generate one for the ai to respond")

# main ui prompt
# - text box
# - submit
with st.form(key="my_form"):
    prompt = st.text_area("Customer Message")
    write_user_chat_message(user_chat, prompt)
    if st.form_submit_button("Submit"):
        # Remember the pasted message so "Generate Polite and Friendly
        # Response" answers it; previously it was never stored.
        if prompt:
            st.session_state["user_msg_as_prompt"] = prompt
        assistant_chat.write("Hi, Human.")

# below ui prompt
# - examples
# st.markdown("Example:", unsafe_allow_html=True)
example_customer_msg = "your website is straight up garbage. how do you sell high end technology but you cant get a website right?"
if st.button(example_customer_msg):
    st.session_state["user_msg_as_prompt"] = example_customer_msg
    _respond_to_customer(example_customer_msg, user_chat, assistant_chat)

# - Generate Customer Tweet
if st.button("Generate Customer Message using Few Shots"):
    model, tokenizer = get_mistral_model_tokenizer(sota_model_name)
    # Keep the pickers across reruns so each click deals the next card from
    # the same shuffled decks. A fresh deck per click made DeckPicker
    # equivalent to a plain random choice.
    if "word_pickers" not in st.session_state:
        st.session_state["word_pickers"] = (
            DeckPicker(nouns),
            DeckPicker(verbs),
            DeckPicker(adjectives),
        )
    noun_picker, verb_picker, adjective_picker = st.session_state["word_pickers"]
    few_shots = create_few_shots(noun_picker, verb_picker, adjective_picker)
    few_shot_prompt = f"[INST]{few_shots}[/INST]\n"
    st.markdown("Prompt:")
    st.code(few_shot_prompt, language=None)
    new_customer_msg = write_stream_user_chat_message(user_chat, model, tokenizer, few_shot_prompt)
    st.session_state["user_msg_as_prompt"] = new_customer_msg

st.markdown("------------")
st.markdown("Thanks to:")
st.markdown("""Unsloth https://github.com/unslothai check out the [wiki](https://github.com/unslothai/unsloth/wiki)""")
st.markdown("""Georgi Gerganov's ggml https://github.com/ggerganov/ggml""")
st.markdown("""Meta's Llama https://github.com/meta-llama""")
st.markdown("""Mistral AI - https://github.com/mistralai""")
st.markdown("""Zhang Peiyuan's TinyLlama https://github.com/jzhang38/TinyLlama""")
st.markdown("""Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li, Carlos Guestrin, Percy Liang, Tatsunori B. Hashimoto - [Alpaca: A Strong, Replicable Instruction-Following Model](https://crfm.stanford.edu/2023/03/13/alpaca.html)""")

# `device` is a torch.device, and comparing it to the string "cuda" was always
# False; check its `.type` so these stats are actually shown on GPU hosts.
if device.type == "cuda":
    gpu_stats = torch.cuda.get_device_properties(0)
    max_memory = gpu_stats.total_memory / 1024 ** 3
    start_gpu_memory = torch.cuda.memory_reserved(0) / 1024 ** 3
    st.write(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
    st.write(f"{start_gpu_memory} GB of memory reserved.")

st.write("Packages:")
st.write(f"pytorch: {torch.__version__}")