"""JobEasy AI: a Streamlit chat front-end over a two-node LangGraph workflow.

The ``info`` node interviews the user about their job, company and tools and,
once it has everything, emits a ``PromptInstructions`` tool call.  The
``profile`` node then turns the collected arguments into a professional
profile.  Streamlit re-executes this whole script on every interaction, so
anything that must survive between turns (the compiled graph with its
checkpointer, the conversation thread id, the chat transcript) is kept in
``st.cache_resource`` or ``st.session_state``.
"""
import openai
import streamlit as st
from streamlit_chat import message
from langchain_core.messages import SystemMessage
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel, Field
from langgraph.graph import MessageGraph, END
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_core.messages import HumanMessage
from typing import List
import os
import uuid

# System prompt for the information-gathering node.  The text is runtime
# input to the model and is kept exactly as it was.
template = (
    "Your job is to get information from a user about their profession. "
    "We are aiming to generate a profile later "
    "You should get the following information from them: "
    "- Job - Company "
    "- tools for example for a software engineer(which frameworks/languages) "
    "If you are not able to discerne this info, ask them to clarify! "
    "Do not attempt to wildly guess. "
    "If you're asking anything please be friendly and comment on any of the "
    "info you have found e.g working at x company must have been a thrilling "
    "challenge Ask one question at a time "
    "After you are able to discerne all the information, call the relevant tool"
)

# SECURITY: the API key must never be committed to source control.  It is
# read from the environment (ChatOpenAI picks up OPENAI_API_KEY on its own).
# Any key that was previously hard-coded in this file should be revoked.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    st.error("OPENAI_API_KEY is not set. Export it before starting the app.")
    st.stop()

llm = ChatOpenAI(temperature=0)


def get_messages_info(messages):
    """Prepend the interviewer system prompt to the conversation so far."""
    return [SystemMessage(content=template)] + messages


class PromptInstructions(BaseModel):
    """Instructions on how to prompt the LLM."""

    job: str
    company: str
    technologies: List[str]
    # NOTE(review): the field name looks like a typo for "hobbies", and the
    # system prompt never asks the user about hobbies.  The name is part of
    # the tool schema sent to the model, so it is left unchanged here.
    hobies: List[str]


llm_with_tool = llm.bind_tools([PromptInstructions])
chain = get_messages_info | llm_with_tool


def _is_tool_call(msg):
    """Return True when *msg* carries a ``tool_calls`` entry in its kwargs."""
    return hasattr(msg, "additional_kwargs") and 'tool_calls' in msg.additional_kwargs


# System prompt for the profile-writing node; ``{reqs}`` receives the raw
# JSON argument string of the tool call.
prompt_system = (
    "Based on the following context, write a good professional profile. \n"
    "Infer the soft skills: {reqs}"
)


def get_profile_messages(messages):
    """Build the prompt for the profile node.

    The arguments of the (last) tool call become the system prompt context;
    only messages that come AFTER a tool call are forwarded as conversation.
    """
    tool_call = None
    other_msgs = []
    for m in messages:
        if _is_tool_call(m):
            tool_call = m.additional_kwargs['tool_calls'][0]['function']['arguments']
        elif tool_call is not None:
            other_msgs.append(m)
    return [SystemMessage(content=prompt_system.format(reqs=tool_call))] + other_msgs


profile_gen_chain = get_profile_messages | llm


def get_state(messages):
    """Route the graph after a node has run.

    * last message is a tool call            -> ``"profile"``
    * last message is not from the human     -> ``END`` (wait for the user)
    * a tool call exists earlier in history  -> ``"profile"``
    * otherwise                              -> ``"info"``
    """
    if _is_tool_call(messages[-1]):
        return "profile"
    elif not isinstance(messages[-1], HumanMessage):
        return END
    for m in messages:
        if _is_tool_call(m):
            return "profile"
    return "info"


@st.cache_resource
def _build_graph():
    """Compile the workflow once per server process.

    Streamlit reruns the script on every widget interaction; rebuilding the
    in-memory checkpointer each time would throw away the conversation the
    graph has accumulated.
    NOTE(review): assumes the SqliteSaver connection can be used from the
    different threads Streamlit runs the script on - confirm for the
    installed langgraph version.
    """
    memory = SqliteSaver.from_conn_string(":memory:")
    nodes = {k: k for k in ['info', 'profile', END]}
    workflow = MessageGraph()
    workflow.add_node("info", chain)
    workflow.add_node("profile", profile_gen_chain)
    workflow.add_conditional_edges("info", get_state, nodes)
    workflow.add_conditional_edges("profile", get_state, nodes)
    workflow.set_entry_point("info")
    return workflow.compile(checkpointer=memory)


graph = _build_graph()

# Streamlit app layout
st.title("JobEasy AI")
clear_button = st.sidebar.button("Clear Conversation", key="clear")

# Initialise session state variables
if 'generated' not in st.session_state:
    st.session_state['generated'] = ['Please tell me about your most recent career']
if 'past' not in st.session_state:
    st.session_state['past'] = []
if 'messages' not in st.session_state:
    st.session_state['messages'] = []
if 'thread_id' not in st.session_state:
    # One checkpointer thread per browser session, stable across reruns.
    st.session_state['thread_id'] = str(uuid.uuid4())

# reset everything
if clear_button:
    st.session_state['generated'] = ['Please tell me about your most recent career']
    st.session_state['past'] = []
    st.session_state['messages'] = []
    # A new thread id makes the graph start from an empty history as well.
    st.session_state['thread_id'] = str(uuid.uuid4())

config = {"configurable": {"thread_id": st.session_state['thread_id']}}

# container for chat history
response_container = st.container()
# container for text box
container = st.container()


def query(payload):
    """Send one user message through the graph and record the exchange.

    Not cached on purpose: the call mutates session state, and a cache keyed
    on *payload* would silently skip a repeated message.

    Args:
        payload: The text the user just submitted.

    Side effects:
        Appends *payload* to ``st.session_state['past']`` exactly once and a
        single combined assistant reply to ``st.session_state['generated']``,
        so both lists stay index-aligned for rendering.  Each non-empty
        assistant message is also logged in ``st.session_state['messages']``.
    """
    replies = []
    for output in graph.stream([HumanMessage(content=payload)], config=config):
        if "__end__" in output:
            continue
        # stream() yields dictionaries with output keyed by node name
        for value in output.values():
            content = value.content
            if not content:
                # Tool-call messages have no text to show to the user.
                continue
            st.session_state['messages'].append({"role": "assistant", "content": content})
            replies.append(content)
    st.session_state['past'].append(payload)
    st.session_state['generated'].append("\n\n".join(replies))


with container:
    with st.form(key='my_form', clear_on_submit=True):
        user_input = st.text_area("You:", key='input', height=100)
        submit_button = st.form_submit_button(label='Send')

    if submit_button and user_input:
        query(user_input)

if st.session_state['generated']:
    with response_container:
        past = st.session_state["past"]
        # generated[0] is the greeting, so generated[i] precedes past[i].
        for i, bot_text in enumerate(st.session_state["generated"]):
            message(bot_text, key=str(i))
            if i < len(past):
                message(past[i], is_user=True, key=str(i) + '_user')