# Value-Props/tabs/chatbot_tab.py
import os
import warnings

import streamlit as st  # type: ignore
from dotenv import load_dotenv
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    trim_messages,
)
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_mistralai import ChatMistralAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, END, StateGraph
from langgraph.graph.message import add_messages
from typing import Sequence
from typing_extensions import Annotated, TypedDict

from translate_app import tr

warnings.filterwarnings("ignore")
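
# --- Sales-coaching chatbot: a single-node LangGraph workflow around a
# --- Mistral model, with trimmed and checkpointed conversation history.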
title = "Sales coaching"
sidebar_name = "Sales coaching"

dataPath = st.session_state.DataPath

# API keys (LANGCHAIN_API_KEY, MISTRAL_API_KEY) are read from the
# environment / .env file rather than hardcoded in the source.
load_dotenv()

# LangSmith tracing configuration
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_HUB_API_URL"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_PROJECT"] = "Sales Coaching Chatbot"

model = ChatMistralAI(model="mistral-large-latest")
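
# Keep prompts short: retain only the most recent turns that fit in a
# 60-token budget, always keeping the system message and starting on a
# human turn.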
trimmer = trim_messages(
    max_tokens=60,
    strategy="last",
    token_counter=model,
    include_system=True,
    allow_partial=False,
    start_on="human",
)
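
# Prompt template: a system instruction parameterized by {language},
# followed by the running conversation.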
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful assistant. Answer all questions to the best of your ability in {language}.",
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
)
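
# Graph state: the accumulated message list plus the target answer language.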
class State(TypedDict):
    messages: Annotated[Sequence[BaseMessage], add_messages]
    language: str
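
# Single graph node: trim the history, then invoke the prompt/model chain.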
def call_model(state: State):
    chain = prompt | model
    trimmed_messages = trimmer.invoke(state["messages"])
    response = chain.invoke(
        {"messages": trimmed_messages, "language": state["language"]}
    )
    return {"messages": [response]}
# Define a new graph
workflow = StateGraph(state_schema=State)
# Define the (single) node in the graph
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")
workflow.add_edge("model", END)
# Add memory
memory = MemorySaver()
app = workflow.compile(checkpointer=memory)
config = {"configurable": {"thread_id": "abc123"}}
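
# Entry point called by the Streamlit app when this tab is selected.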
def run():
    st.write("")
    st.title(tr(title))

    # Sample conversation used to exercise the trimmer (the trimmed result
    # is not displayed; with a 60-token budget only the latest turns remain).
    messages = [
        SystemMessage(content="you're a good assistant"),
        HumanMessage(content="hi! I'm bob"),
        AIMessage(content="hi!"),
        HumanMessage(content="I like vanilla ice cream"),
        AIMessage(content="nice"),
        HumanMessage(content="whats 2 + 2"),
        AIMessage(content="4"),
        HumanMessage(content="thanks"),
        AIMessage(content="no problem!"),
        HumanMessage(content="having fun?"),
        AIMessage(content="yes!"),
    ]
    trimmer.invoke(messages)

    query = "Hi I'm Todd, please tell me a joke."
    language = "French"
    input_messages = [HumanMessage(query)]

    # Stream the answer token by token into a single placeholder;
    # st.write() has no `end` parameter, so chunks are accumulated manually.
    placeholder = st.empty()
    response = ""
    for chunk, metadata in app.stream(
        {"messages": input_messages, "language": language},
        config,
        stream_mode="messages",
    ):
        if isinstance(chunk, AIMessage):  # Filter to just model responses
            response += chunk.content
            placeholder.write(response)
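
    # Disabled prototype kept for reference: capture a value-proposition
    # sentence and a sales-pitch sentence (any language), then turn each
    # into a 384-dimension vector.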
    '''
    sentences = ["This is an example sentence", "Each sentence is converted"]
    sentences[0] = st.text_area(label=tr("Saisir un élément issu de la proposition de valeur (quelque soit la langue):"), value="This is an example sentence")
    sentences[1] = st.text_area(label=tr("Saisir une phrase issue de l'acte de vente (quelque soit la langue):"), value="Each sentence is converted", height=200)
    st.button(label=tr("Validez"), type="primary")
    st.write(tr("Transformation de chaque phrase en vecteur (dimension = 384 ):"))
    '''
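
    # Trailing vertical spacing for the tab layout.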
    st.write("")
    st.write("")
    st.write("")
    st.write("")