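# Scraped Hugging Face file-view header (page residue, not program code):
#   uploaded by Kinobi - commit f257718 "add load model" - 11.7 kB - "No virus" scan badge
#   (page controls: raw / history / blame / contribute / delete)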
import logging
import os

import pandas as pd
import replicate
import streamlit as st

import classification
# App title
st.set_page_config(page_title="Scam Call Prevention System Chatbot")

# Model picker label -> Replicate model identifier.
# ("Llama-3-8b-instruct": "meta/meta-llama-3-8b-instruct" can be re-added here.)
_MODELS = {
    'Llama-3-70b-instruct': "meta/meta-llama-3-70b-instruct",
}

# Replicate credentials, model choice and fixed sampling settings (sidebar).
with st.sidebar:
    st.title('Scam baiter Chatbot')
    # Treat an empty REPLICATE_API_TOKEN as "not provided". Otherwise a blank
    # value exported on an earlier rerun would hide the token input for good.
    if os.environ.get('REPLICATE_API_TOKEN'):
        st.success('API key already provided!', icon='✅')
        replicate_api = os.environ['REPLICATE_API_TOKEN']
    else:
        replicate_api = st.text_input('Enter Replicate API token:', type='password')
        if not (replicate_api.startswith('r8_') and len(replicate_api) == 40):
            st.warning('Please enter your credentials!', icon='⚠️')
        else:
            st.success('Proceed to entering your prompt message!', icon='👉')
            # Export only a well-formed token, so an empty value or a typo is
            # not stored and the input stays editable on the next rerun.
            # NOTE(review): os.environ is process-wide, so this token becomes
            # visible to every session served by this process.
            os.environ['REPLICATE_API_TOKEN'] = replicate_api
    # Refactored from https://github.com/a16z-infra/llama2-chatbot
    st.subheader('Models')
    selected_model = st.sidebar.selectbox('Choose a model', list(_MODELS), key='selected_model')
    # The selectbox only returns one of the offered keys, so this lookup always succeeds.
    llm = _MODELS[selected_model]
    # Fixed sampling parameters (the former sidebar sliders are disabled).
    temperature = 0.1
    top_p = 0.9
    max_length = 512
# Seed the transcript with the receiver's opening line on a session's first run.
if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "receiver", "content": "Hello, who is calling?"}]

# Render every stored message (caller and receiver) on each script run.
for entry in st.session_state.messages:
    speaker, text = entry["role"], entry["content"]
    with st.chat_message(speaker):
        st.write(text)
def clear_chat_history():
    """Reset the transcript, the scam-score chart data and the suggestion cursor.

    Registered as the ``on_click`` callback of the sidebar button below.
    """
    greeting = {"role": "receiver", "content": "Hello, who is calling?"}
    initial_chart = pd.DataFrame([[0, 0.0]], columns=["message_number", "score"])
    state = st.session_state
    state.messages = [greeting]
    state.chart_data = initial_chart
    state.suggestion_index = 0


st.sidebar.button('Clear Chat History', on_click=clear_chat_history)
# Instructions placed in front of the transcript sent to the model.
_SCAM_BAITER_INSTRUCTIONS = '''As an AI scam baiter, your primary objective is to engage the caller in conversation while subtly gathering information to determine if they are attempting to scam you.
Your responses should be concerning. Use open-ended questions to keep the conversation flowing and to encourage the caller to reveal more about their intentions.
Your goal is to gather enough evidence to confidently label the caller as a scammer. Your response should be as short and concise as possible. You must not state that you are the person they are calling for but you must pretend you know them and come up with a different name.'''


def _build_dialogue(messages, prompt_input):
    """Return the instructions followed by the transcript as one prompt string.

    Each message becomes a ``"<role>: <content>"`` paragraph. Any role other
    than ``"caller"`` is rendered as ``"receiver"``. ``prompt_input`` is added
    as a final caller paragraph only when it is non-empty and is not already
    the last (caller) message, so the caller's line is never sent twice.
    """
    # Blank line so the instructions don't run into the first transcript line.
    parts = [_SCAM_BAITER_INSTRUCTIONS, "\n\n"]
    for msg in messages:
        speaker = "caller" if msg["role"] == "caller" else "receiver"
        parts.append(f"{speaker}: {msg['content']}\n\n")
    already_in_history = (
        bool(messages)
        and messages[-1]["role"] == "caller"
        and messages[-1]["content"] == prompt_input
    )
    if prompt_input and not already_in_history:
        parts.append(f"caller: {prompt_input}\n\n")
    return "".join(parts)


# Function for generating the receiver's (LLM) response
def generate_llama_response(prompt_input):
    """Ask the selected Replicate model for the receiver's next line.

    Args:
        prompt_input: The caller's latest utterance. The call sites in this
            file append it to ``st.session_state.messages`` before calling,
            in which case it is not repeated in the prompt.

    Returns:
        The object returned by ``replicate.run``. Callers in this file iterate
        it and concatenate the string chunks.
    """
    dialogue = _build_dialogue(st.session_state.messages, prompt_input)
    return replicate.run(
        llm,
        input={
            "prompt": dialogue,
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_length,
            "repetition_penalty": 1,
        },
    )
def get_conversation_text(messages=None):
    """Return the transcript as ``"<role>: <content>"`` lines for classification.

    Args:
        messages: Optional sequence of ``{"role": ..., "content": ...}``
            dicts. Defaults to the current session's
            ``st.session_state.messages``.

    Returns:
        One newline-terminated line per message, or ``""`` when there are no
        messages.
    """
    if messages is None:
        messages = st.session_state.messages
    # join() instead of repeated += keeps this linear in the transcript length.
    return "".join(f"{m['role']}: {m['content']}\n" for m in messages)
def query_classification_score():
    """Return the classifier's "SCAM" score for the current conversation.

    Sends the transcript from ``get_conversation_text()`` to
    ``classification.query`` and returns the ``score`` of the first entry
    labelled ``"SCAM"`` in element ``[0]`` of the response.

    Raises:
        ValueError: if the response contains no ``"SCAM"`` label. Previously
            this surfaced as an ``UnboundLocalError``.
        KeyError: from the ``[0]`` lookup if the service answers with a
            mapping instead of a list. The app's sidebar note attributes
            ``KeyError: 0`` to the classification model still loading
            (response shape not visible here; TODO confirm).
    """
    conversation_text = get_conversation_text()
    predictions = classification.query({"inputs": conversation_text})[0]
    for item in predictions:
        if item["label"] == "SCAM":
            return item["score"]  # score for the SCAM label
    raise ValueError(
        f"classification response has no 'SCAM' label: {predictions!r}"
    )
# One classification request per session, presumably so the remote classifier
# gets loaded early. If it fails, 'load_model' stays unset and the request is
# retried on the next rerun.
if 'load_model' not in st.session_state:
    try:
        st.session_state.load_model = query_classification_score()
    except Exception as err:  # deliberate best-effort: never break page rendering
        # Catch only ordinary errors. A bare ``except:`` would also swallow
        # BaseException-derived control flow: KeyboardInterrupt/SystemExit and,
        # in current Streamlit releases, its internal stop/rerun signals
        # (verify for the pinned version).
        logging.getLogger(__name__).info("Classifier warm-up failed: %r", err)

# Score history for the sidebar chart, seeded with a single zero point.
if 'chart_data' not in st.session_state:
    st.session_state.chart_data = pd.DataFrame([[0, float(0)]], columns=["message_number", "score"])
def update_chart_data(score):
    """Append ``score`` as a new point of the sidebar chart data.

    The new point's ``message_number`` is the number of rows already stored.
    """
    history = st.session_state.chart_data
    next_point = pd.DataFrame(
        [[len(history), score]],
        columns=["message_number", "score"],
    )
    st.session_state.chart_data = pd.concat([history, next_point], ignore_index=True)
# Canned caller lines offered as one-click buttons, one pair per turn.
# Within each pair, the first line follows an "Officer Johnson" script that ends
# by asking for a social security number; the second is an insurance sales pitch.
suggestions = [
    [
        "Hello, my name is Officer Johnson from the Social Security Administration. We've been trying to reach you about a serious issue with your social security number.",
        "Hello, my name is John and I'm calling from SafeGuard Insurance. How are you today?",
    ],
    [
        "We've received reports of suspicious activity linked to your account, and we need to verify some information to ensure your benefits aren't suspended.",
        "We're offering a special promotion on our life insurance policies and I was wondering if you'd be interested in learning more about it.",
    ],
    [
        "I'm not at liberty to disclose that information over the phone, but I can assure you it's a matter of national security. We just need you to confirm your social security number to proceed with the investigation.",
        "Well, we're offering a 10% discount on all new policies purchased within the next two weeks. Plus, our rates are highly competitive and we have an A+ rating with the Better Business Bureau.",
    ],
]
def _stream_receiver_reply(caller_text):
    """Stream the model's reply to ``caller_text``, score the chat, then store the reply."""
    with st.chat_message("receiver"):
        with st.spinner("Thinking..."):
            response = generate_llama_response(caller_text)
            placeholder = st.empty()
            full_response = ''
            for item in response:
                full_response += item
                placeholder.markdown(full_response)
            placeholder.markdown(full_response)
    # Scored before the reply is appended, so the classifier sees the
    # conversation up to and including the caller's line.
    scam_output = query_classification_score()
    update_chart_data(scam_output)
    st.session_state.messages.append({"role": "receiver", "content": full_response})


def display_suggestions(suggestion_index):
    """Render the current pair of suggested caller lines as buttons.

    Clicking one posts it as the caller's message, streams the receiver's
    reply, records the scam score and advances to the next pair, which is
    rendered immediately. Once every pair has been used, a notice is shown.

    Args:
        suggestion_index: Kept for backward compatibility but not read. The
            cursor used (and advanced) is ``st.session_state.suggestion_index``.
    """
    index = st.session_state.suggestion_index
    if index >= len(suggestions):
        st.write("No more suggestions available.")
        return

    pair = suggestions[index]
    columns = st.columns(len(pair))
    # Draw the whole button row first, then act on the (at most one) clicked
    # button, so the row is not split around the reply rendered below.
    clicked = None
    for column, text in zip(columns, pair):
        if column.button(text):
            clicked = text
    if clicked is None:
        return

    st.session_state.messages.append({"role": "caller", "content": clicked})
    with st.chat_message("caller"):
        st.write(clicked)
    # The caller line was just appended, so a receiver reply is always due.
    _stream_receiver_reply(clicked)

    # Move to the next set of suggestions and render it (or the notice) now.
    st.session_state.suggestion_index += 1
    display_suggestions(st.session_state.suggestion_index)
# On a session's first run, start at the first suggestion pair.
if "suggestion_index" not in st.session_state:
    st.session_state["suggestion_index"] = 0
display_suggestions(st.session_state["suggestion_index"])
# User-provided prompt
if prompt := st.chat_input(disabled=not replicate_api):
    st.session_state.messages.append({"role": "caller", "content": prompt})
    with st.chat_message("caller"):
        st.write(prompt)

# Reply whenever the transcript ends with an unanswered caller line. The reply
# is generated from that stored line rather than from ``prompt``: after a run
# that failed between saving the caller's line and saving the reply (e.g. a
# classification error), the next rerun has ``prompt is None``, and passing it
# on would crash instead of retrying.
pending = st.session_state.messages[-1] if st.session_state.messages else None
if pending is not None and pending["role"] != "receiver":
    caller_text = pending["content"]
    with st.chat_message("receiver"):
        with st.spinner("Thinking..."):
            response = generate_llama_response(caller_text)
            placeholder = st.empty()
            full_response = ''
            for item in response:
                full_response += item
                placeholder.markdown(full_response)
            placeholder.markdown(full_response)
    message = {"role": "receiver", "content": full_response}
    # Scored before the reply is stored: the classifier sees the conversation
    # up to and including the caller's latest line.
    scam_output = query_classification_score()
    update_chart_data(scam_output)
    st.session_state.messages.append(message)
# Sidebar: recovery hint plus the chart of recorded scam scores.
with st.sidebar:
    st.write("Please refresh and type your input again after the classification KeyError: 0, because the classification model need to be loaded first.")
    st.header('Scam classification score')
    st.area_chart(
        st.session_state.chart_data,
        x="message_number",
        y="score",
        color="#FF0000",  # Optional
    )