Spaces:
Runtime error
Runtime error
import openai | |
import streamlit as st | |
from streamlit_chat import message | |
from langchain_core.messages import SystemMessage | |
from langchain_openai import ChatOpenAI | |
from langchain_core.pydantic_v1 import BaseModel, Field | |
from langgraph.graph import MessageGraph, END | |
from langgraph.checkpoint.sqlite import SqliteSaver | |
from langchain_core.messages import HumanMessage | |
from typing import List | |
import os | |
import uuid | |
# System prompt for the information-gathering step. The requested items mirror
# the fields of the PromptInstructions tool (job, company, technologies,
# hobies) so the model is never forced to invent a value for a field it was
# not told to collect.
template = """Your job is to get information from a user about their profession. We are aiming to generate a profile later
You should get the following information from them:
- Job
- Company
- tools for example for a software engineer(which frameworks/languages)
- Hobbies
If you are not able to discern this info, ask them to clarify! Do not attempt to wildly guess.
If you're asking anything please be friendly and comment on any of the info you have found e.g working at x company must have been a thrilling challenge
Ask one question at a time
After you are able to discern all the information, call the relevant tool"""
# Security: never hard-code the API key in source. The key that used to be
# committed on this line is public and must be revoked. Supply the key through
# the OPENAI_API_KEY environment variable instead (for example a host/Space
# secret); it is read here only so the module-level name stays available.
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY', '')
# Shared chat model used by both graph nodes; temperature=0 asks for the
# least random sampling. Presumably picks the API key up from the
# OPENAI_API_KEY environment variable - confirm against langchain_openai docs.
llm = ChatOpenAI(temperature=0)
def get_messages_info(messages):
    """Return the conversation with the information-gathering system prompt in front.

    Args:
        messages: List of chat messages exchanged so far.

    Returns:
        A new list: one SystemMessage built from ``template`` followed by
        every element of ``messages``.
    """
    system_prompt = SystemMessage(content=template)
    return [system_prompt] + messages
class PromptInstructions(BaseModel):
    """Record the professional details gathered from the user so that a profile can be written from them."""

    # NOTE(review): this docstring is presumably forwarded to the model as the
    # tool description by bind_tools, so it describes what the tool is for.
    job: str = Field(description="The user's job title or role.")
    company: str = Field(description="The company the user works or worked for.")
    technologies: List[str] = Field(
        description="Tools, frameworks or languages the user works with."
    )
    # "hobies" (sic) is kept because it is the argument name of the tool
    # schema. It defaults to an empty list: the information-gathering prompt
    # does not require hobbies, so the model must be able to call the tool
    # without inventing any.
    hobies: List[str] = Field(
        default_factory=list,
        description="The user's hobbies, if they mentioned any.",
    )
# Offer PromptInstructions to the model as a tool it can call once it has
# collected the user's details.
llm_with_tool = llm.bind_tools([PromptInstructions])
# "info" pipeline: prepend the system prompt, then invoke the tool-enabled model.
chain = get_messages_info | llm_with_tool
def _is_tool_call(msg):
    """Tell whether ``msg`` carries a tool call.

    Returns True only when ``msg`` has an ``additional_kwargs`` attribute and
    that mapping contains the key ``'tool_calls'``; False otherwise.
    """
    if not hasattr(msg, "additional_kwargs"):
        return False
    return 'tool_calls' in msg.additional_kwargs
# System prompt for the profile-writing step. The {reqs} placeholder is filled
# by get_profile_messages with the arguments of the recorded tool call.
prompt_system = """Based on the following context, write a good professional profile. Infer the soft skills:
{reqs}"""
def get_profile_messages(messages):
    """Build the message list for the profile-writing model call.

    Scans ``messages`` in order. Each tool-call message supplies the
    requirements (the ``arguments`` of its first tool call) and is itself left
    out; every non-tool-call message seen after the first tool call is kept.

    Args:
        messages: Full conversation history as a list of chat messages.

    Returns:
        A list starting with a SystemMessage built from ``prompt_system``
        (with ``{reqs}`` replaced by the most recent tool-call arguments, or
        ``None`` if there was no tool call) followed by the kept messages.
    """
    requirements = None
    follow_ups = []
    for msg in messages:
        if _is_tool_call(msg):
            first_call = msg.additional_kwargs['tool_calls'][0]
            requirements = first_call['function']['arguments']
            continue
        # Only messages that come AFTER a tool call belong to the profile step.
        if requirements is not None:
            follow_ups.append(msg)
    system_prompt = SystemMessage(content=prompt_system.format(reqs=requirements))
    return [system_prompt] + follow_ups
# "profile" pipeline: build the profile prompt messages, then call the plain
# model (no tools bound).
profile_gen_chain = get_profile_messages | llm
def get_state(messages):
    """Pick the next graph node from the conversation so far.

    Args:
        messages: Non-empty list of chat messages; the last one decides.

    Returns:
        ``"profile"`` if the last message is a tool call; ``END`` if the last
        message is anything other than a HumanMessage; otherwise ``"profile"``
        when any earlier message was a tool call, else ``"info"``.
    """
    latest = messages[-1]
    if _is_tool_call(latest):
        return "profile"
    if not isinstance(latest, HumanMessage):
        return END
    # A human message arrived: continue with the profile step if the details
    # were already captured by a tool call, otherwise keep gathering info.
    if any(_is_tool_call(msg) for msg in messages):
        return "profile"
    return "info"
def get_graph():
    """Assemble and compile the two-node conversation graph.

    The graph has an "info" node (``chain``) and a "profile" node
    (``profile_gen_chain``), starts at "info", and after either node lets
    ``get_state`` choose between "info", "profile" and ``END``.

    Returns:
        The compiled graph, checkpointed through a SqliteSaver created from
        the ``":memory:"`` connection string.
    """
    builder = MessageGraph()
    builder.add_node("info", chain)
    builder.add_node("profile", profile_gen_chain)
    builder.set_entry_point("info")

    # get_state already returns node names, so the routing table maps each
    # possible result onto itself.
    routes = {"info": "info", "profile": "profile", END: END}
    for source in ("info", "profile"):
        builder.add_conditional_edges(source, get_state, routes)

    checkpointer = SqliteSaver.from_conn_string(":memory:")
    return builder.compile(checkpointer=checkpointer)
# Streamlit re-executes this script from the top on every interaction. The
# compiled graph (with its in-memory checkpointer) and the conversation thread
# id are therefore kept in st.session_state: rebuilding them on each run would
# hand the model a brand-new, empty conversation for every message sent.
# NOTE(review): assumes the checkpointer's SQLite connection may be reused
# across Streamlit reruns - confirm for the installed langgraph version.
if 'graph' not in st.session_state:
    st.session_state['graph'] = get_graph()
if 'thread_id' not in st.session_state:
    st.session_state['thread_id'] = str(uuid.uuid4())
graph = st.session_state['graph']

# Streamlit app layout
st.title("JobEasy AI")
clear_button = st.sidebar.button("Clear Conversation", key="clear")

# Initialise session state variables
if 'generated' not in st.session_state:
    st.session_state['generated'] = ['Please tell me about your most recent career']
if 'past' not in st.session_state:
    st.session_state['past'] = []
if 'messages' not in st.session_state:
    st.session_state['messages'] = []

# reset everything
if clear_button:
    st.session_state['generated'] = ['Please tell me about your most recent career']
    st.session_state['past'] = []
    st.session_state['messages'] = []
    # A new thread id makes the checkpointer start from an empty history
    # instead of continuing the conversation that was just cleared.
    st.session_state['thread_id'] = str(uuid.uuid4())

# Built after the reset handling so a cleared conversation uses its new thread.
config = {"configurable": {"thread_id": st.session_state['thread_id']}}

# container for chat history
response_container = st.container()
# container for text box
container = st.container()
def query(payload):
    """Send one user message through the graph and record the exchange.

    Streams the graph with ``payload`` wrapped in a HumanMessage and collects
    the text of every node output. The user message is recorded once in
    ``st.session_state['past']`` and the combined reply once in
    ``st.session_state['generated']``, so the two lists stay aligned for
    rendering even when a single turn runs more than one node.

    Args:
        payload: The text the user submitted.

    Returns:
        None. Results are written to ``st.session_state``.
    """
    replies = []
    for output in graph.stream([HumanMessage(content=payload)], config=config):
        if "__end__" in output:
            continue
        # stream() yields dictionaries with output keyed by node name
        for value in output.values():
            # Outputs without text (typically the tool-call message) are not
            # shown; rendering them would add an empty chat bubble.
            if value.content:
                replies.append(value.content)
                st.session_state['messages'].append(
                    {"role": "assistant", "content": value.content}
                )
    # Use the function's own argument rather than the global form variable.
    st.session_state['past'].append(payload)
    st.session_state['generated'].append("\n\n".join(replies))
# Input form: submitting non-empty text runs one turn of the conversation.
with container:
    with st.form(key='my_form', clear_on_submit=True):
        user_input = st.text_area("You:", key='input', height=100)
        submit_button = st.form_submit_button(label='Send')

    if submit_button and user_input:
        query(user_input)

# Chat history: each assistant message is followed by the user message with
# the same index, when one exists (the opening greeting has index 0).
bot_history = st.session_state['generated']
if bot_history:
    with response_container:
        user_history = st.session_state["past"]
        for i, bot_text in enumerate(bot_history):
            message(bot_text, key=str(i))
            if i < len(user_history):
                message(user_history[i], is_user=True, key=str(i) + '_user')