jobmatcher-ai / app.py
ageraustine's picture
Update app.py
9c43365 verified
raw
history blame
5.15 kB
import openai
import streamlit as st
from streamlit_chat import message
from langchain_core.messages import SystemMessage
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel, Field
from langgraph.graph import MessageGraph, END
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_core.messages import HumanMessage
from typing import List
import os
import uuid
# System prompt for the information-gathering step: it is prepended to the
# conversation by get_messages_info() and tells the model to collect the
# user's job, company and tools before calling the bound tool.
# NOTE(review): the text below is sent to the model verbatim; it contains
# spelling slips ("discerne") that were left untouched because editing the
# prompt changes runtime behaviour.
template = """Your job is to get information from a user about their profession. We are aiming to generate a profile later
You should get the following information from them:
- Job
- Company
- tools for example for a software engineer(which frameworks/languages)
If you are not able to discerne this info, ask them to clarify! Do not attempt to wildly guess.
If you're asking anything please be friendly and comment on any of the info you have found e.g working at x company must have been a thrilling challenge
Ask one question at a time
After you are able to discerne all the information, call the relevant tool"""
# The OpenAI credential is read from the environment (for example a hosting
# platform secret) instead of being hard-coded in the source file.
# NOTE(review): a literal key used to be committed on this line; treat that
# key as compromised and revoke it with the provider.
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
if not OPENAI_API_KEY:
    # Fail early with a readable message rather than an opaque client error
    # on the first model call.
    st.error("OPENAI_API_KEY is not set. Configure it as an environment variable or secret.")
    st.stop()

# Deterministic model (temperature=0) shared by both graph nodes; the client
# picks the key up from the OPENAI_API_KEY environment variable.
llm = ChatOpenAI(temperature=0)
def get_messages_info(messages):
    """Return the conversation with the information-gathering system prompt in front.

    Args:
        messages: list of chat messages accumulated so far.

    Returns:
        A new list whose first element is a SystemMessage built from
        ``template`` and whose remaining elements are ``messages``.
    """
    system_prompt = SystemMessage(content=template)
    return [system_prompt] + messages
# Schema of the tool bound to the model via llm.bind_tools(); the model is
# asked to call it once it has gathered the user's details.
# NOTE(review): the class docstring and field names are presumably forwarded
# to the model as the tool description/parameters -- confirm before rewording
# or renaming anything here.
class PromptInstructions(BaseModel):
    """Instructions on how to prompt the LLM."""
    # Job title reported by the user.
    job: str
    # Employer reported by the user.
    company: str
    # Tools / frameworks / languages the user works with.
    technologies: List[str]
    # NOTE(review): "hobies" looks like a typo for "hobbies", and the system
    # prompt never asks about hobbies even though this field has no default.
    # Renaming or relaxing it changes the tool schema, so it is left as-is.
    hobies: List[str]
# Model variant that can emit a PromptInstructions tool call.
llm_with_tool = llm.bind_tools([PromptInstructions])
# "info" step: prepend the system prompt, then invoke the tool-enabled model.
chain = get_messages_info | llm_with_tool
# Helper: tells whether a message carries a tool call
def _is_tool_call(msg):
    """Return True when ``msg.additional_kwargs`` contains a 'tool_calls' entry.

    Objects without an ``additional_kwargs`` attribute are reported as False.
    """
    if not hasattr(msg, "additional_kwargs"):
        return False
    return 'tool_calls' in msg.additional_kwargs
# System prompt for the profile-writing step. The {reqs} placeholder is filled
# by get_profile_messages() with the arguments of the model's tool call.
prompt_system = """Based on the following context, write a good professional profile. Infer the soft skills:
{reqs}"""
# Builds the message list for the profile-writing step.
# Only messages that come AFTER a tool call are forwarded.
def get_profile_messages(messages):
    """Return the profile-writing system prompt followed by post-tool-call messages.

    The arguments of the most recent tool-call message are substituted into
    ``prompt_system``. Tool-call messages themselves are never forwarded, and
    messages seen before the first tool call are dropped.

    Args:
        messages: list of chat messages in conversation order.

    Returns:
        A list starting with the formatted SystemMessage, followed by the
        retained messages.
    """
    extracted_args = None
    follow_ups = []
    for msg in messages:
        if _is_tool_call(msg):
            first_call = msg.additional_kwargs['tool_calls'][0]
            extracted_args = first_call['function']['arguments']
            continue
        if extracted_args is None:
            # Nothing has been extracted yet, so this message predates the tool call.
            continue
        follow_ups.append(msg)
    system_prompt = SystemMessage(content=prompt_system.format(reqs=extracted_args))
    return [system_prompt] + follow_ups
# "profile" step: build the profile prompt from the tool-call arguments, then
# invoke the plain model (no tools bound).
profile_gen_chain = get_profile_messages | llm
def get_state(messages):
    """Pick the next graph node from the conversation so far.

    Returns:
        "profile" when the latest message is a tool call, or when the latest
        message is a HumanMessage and any message in the history is a tool
        call; END when the latest message is neither a tool call nor a
        HumanMessage; "info" otherwise.
    """
    latest = messages[-1]
    if _is_tool_call(latest):
        return "profile"
    if not isinstance(latest, HumanMessage):
        return END
    # The latest message is from the human: route on whether a tool call
    # already exists anywhere in the history.
    if any(_is_tool_call(m) for m in messages):
        return "profile"
    return "info"
@st.cache_resource
def get_graph():
    """Build and compile the two-node conversation graph.

    The graph has an "info" node (information gathering) and a "profile" node
    (profile writing); after either node, ``get_state`` chooses between
    "info", "profile" and END. State is checkpointed in an in-memory SQLite
    saver. The result is cached by ``st.cache_resource``.

    Returns:
        The compiled graph.
    """
    workflow = MessageGraph()
    workflow.add_node("info", chain)
    workflow.add_node("profile", profile_gen_chain)

    # get_state() returns one of these keys; each maps to the node of the same name.
    routes = {"info": "info", "profile": "profile", END: END}
    for source in ("info", "profile"):
        workflow.add_conditional_edges(source, get_state, routes)
    workflow.set_entry_point("info")

    checkpointer = SqliteSaver.from_conn_string(":memory:")
    return workflow.compile(checkpointer=checkpointer)
graph = get_graph()

# Streamlit app layout
st.title("JobEasy AI")
clear_button = st.sidebar.button("Clear Conversation", key="clear")

# Initialise session state variables
if 'generated' not in st.session_state:
    st.session_state['generated'] = ['Please tell me about your most recent career']
if 'past' not in st.session_state:
    st.session_state['past'] = []
if 'messages' not in st.session_state:
    st.session_state['messages'] = []
# Streamlit re-executes this script on every interaction, so the graph's
# thread id has to live in session state; generating it at module level
# would start a brand-new checkpoint thread (and lose the conversation) on
# each turn.
if 'thread_id' not in st.session_state:
    st.session_state['thread_id'] = str(uuid.uuid4())

# reset everything
if clear_button:
    st.session_state['generated'] = ['Please tell me about your most recent career']
    st.session_state['past'] = []
    st.session_state['messages'] = []
    # A fresh thread id makes the graph forget the cleared conversation too.
    st.session_state['thread_id'] = str(uuid.uuid4())

# Run configuration passed to graph.stream(); built after the reset handling
# so that it always reflects the current thread id.
config = {"configurable": {"thread_id": st.session_state['thread_id']}}
# container for chat history (created before the input container, so it is
# laid out first on the page)
response_container = st.container()
# container for text box
container = st.container()
def query(payload):
    """Send one user message through the graph and record the exchange.

    The function is intentionally not cached: it mutates session state and
    must run for every submission, including a repeated identical message.

    Args:
        payload: text the user submitted.

    Side effects:
        Appends ``payload`` once to ``st.session_state['past']``, appends each
        non-empty assistant reply to ``st.session_state['messages']`` and
        appends one combined entry to ``st.session_state['generated']`` so
        that user and assistant bubbles stay paired one-to-one.
    """
    replies = []
    for output in graph.stream([HumanMessage(content=payload)], config=config):
        if "__end__" in output:
            continue
        # stream() yields dictionaries with output keyed by node name
        for value in output.values():
            # Skip outputs without display text (presumably the tool-call
            # message) so no empty chat bubble is rendered.
            if value.content:
                replies.append(value.content)

    st.session_state['past'].append(payload)
    for reply in replies:
        st.session_state['messages'].append({"role": "assistant", "content": reply})
    # Exactly one generated entry per user message keeps the transcript aligned.
    st.session_state['generated'].append("\n\n".join(str(reply) for reply in replies))
# Input area: a form so the text is only submitted when "Send" is pressed;
# clear_on_submit empties the text box after each submission.
with container:
    with st.form(key='my_form', clear_on_submit=True):
        user_input = st.text_area("You:", key='input', height=100)
        submit_button = st.form_submit_button(label='Send')
    # Only call the graph for a real submission with non-empty text.
    if submit_button and user_input:
        query(user_input)
# Render the transcript: each assistant entry, followed by the user entry at
# the same index when one exists.
generated_history = st.session_state['generated']
if generated_history:
    with response_container:
        user_history = st.session_state["past"]
        for idx, bot_text in enumerate(generated_history):
            message(bot_text, key=str(idx))
            if idx < len(user_history):
                message(user_history[idx], is_user=True, key=str(idx) + '_user')