File size: 5,146 Bytes
1841e7e
 
 
f3c8871
 
 
 
 
8cb514e
f3c8871
 
 
 
 
 
 
 
23a1fa8
f3c8871
 
 
 
 
23a1fa8
f3c8871
 
 
1841e7e
 
 
 
 
f3c8871
 
 
 
 
 
 
 
 
 
8cb514e
 
f3c8871
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c43365
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3c8871
 
 
e315388
f79a327
f3c8871
1841e7e
 
105edfc
1841e7e
f6568b9
1841e7e
863891d
1841e7e
7084404
1841e7e
 
a7bd912
1841e7e
a7bd912
1841e7e
27feb48
1841e7e
 
 
 
b60483e
7c38857
 
1568509
7c38857
 
 
 
 
 
 
 
9d8ffa1
13ac208
 
 
 
 
7c38857
105edfc
1841e7e
 
 
105edfc
9466441
105edfc
a7bd912
105edfc
a7bd912
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import openai
import streamlit as st
from streamlit_chat import message
from langchain_core.messages import SystemMessage
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel, Field
from langgraph.graph import MessageGraph, END
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_core.messages import HumanMessage
from typing import List
import os
import uuid

# System prompt for the information-gathering ("info") node; it is prepended to
# the conversation by get_messages_info() before every model call.
# NOTE(review): the text has typos ("discerne" -> "discern", "e.g" -> "e.g.");
# left as-is here because the string is sent verbatim to the model.
template = """Your job is to get information from a user about their profession. We are aiming to generate a profile later

You should get the following information from them:

- Job
- Company
- tools for example for a software engineer(which frameworks/languages)

If you are not able to discerne this info, ask them to clarify! Do not attempt to wildly guess.
If you're asking anything please be friendly and comment on any of the info you have found e.g working at x company must have been a thrilling challenge
Ask one question at a time
 
After you are able to discerne all the information, call the relevant tool"""


# The OpenAI client (ChatOpenAI) reads OPENAI_API_KEY from the environment.
# Never hard-code a secret key in source: the key previously committed here is
# exposed and must be revoked/rotated in the OpenAI dashboard.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    # Fail fast with a visible message instead of an opaque auth error later.
    st.error("OPENAI_API_KEY is not set. Export it in the environment before starting the app.")
    st.stop()


llm = ChatOpenAI(temperature=0)

def get_messages_info(messages):
    """Return a new message list with the info-gathering system prompt first.

    ``messages`` is concatenated with ``+``, so it is expected to be a list; the
    caller's list is not modified.
    """
    head = [SystemMessage(content=template)]
    return head + messages


# Tool schema bound to the model below via llm.bind_tools(). The model "calls"
# it once it has gathered the user's details; get_profile_messages() later
# reads the call's raw JSON arguments out of the message.
# NOTE(review): the docstring is presumably forwarded to the model as the tool
# description, so it is left unchanged even though it does not describe the
# fields well — confirm before editing it.
class PromptInstructions(BaseModel):
    """Instructions on how to prompt the LLM."""
    job: str  # the user's job / role (the "Job" item in the system prompt)
    company: str  # employer (the "Company" item in the system prompt)
    technologies: List[str]  # tools/frameworks/languages the user mentions
    hobies: List[str]  # NOTE(review): misspelling of "hobbies"; renaming changes the tool's argument schema. The system prompt never asks for hobbies — confirm this required field is intended.

# Expose PromptInstructions to the model as a callable tool.
llm_with_tool = llm.bind_tools([PromptInstructions])

# "info" node: prepend the system prompt, then ask the tool-enabled model for
# the next turn (a question, or a PromptInstructions tool call).
chain = get_messages_info | llm_with_tool

# Helper: does this message carry a tool call?
def _is_tool_call(msg):
    """Return True if ``msg.additional_kwargs`` contains a ``'tool_calls'`` key.

    Objects without an ``additional_kwargs`` attribute are treated as not
    being tool calls.
    """
    try:
        extra = msg.additional_kwargs
    except AttributeError:
        return False
    return 'tool_calls' in extra


# New system prompt
# System prompt for the "profile" node. get_profile_messages() fills {reqs}
# with the raw JSON arguments string of the PromptInstructions tool call.
prompt_system = """Based on the following context, write a good professional profile. Infer the soft skills:

{reqs}"""

# Build the prompt for the profile node from the conversation.
# Only messages that come AFTER a tool call are kept; the tool-call messages
# themselves are not included.
def get_profile_messages(messages):
    """Return ``[system prompt, *messages after the tool call]``.

    The system prompt is ``prompt_system`` formatted with the JSON arguments of
    the first tool call of the most recent tool-call message seen. If no
    tool-call message exists, ``reqs`` is ``None`` and no other messages are
    kept.
    """
    reqs = None
    after_call = []
    for msg in messages:
        if _is_tool_call(msg):
            reqs = msg.additional_kwargs['tool_calls'][0]['function']['arguments']
            continue
        if reqs is not None:
            after_call.append(msg)
    system = SystemMessage(content=prompt_system.format(reqs=reqs))
    return [system, *after_call]

# "profile" node: build the profile-writing prompt from the tool call, then
# send it to the plain (tool-less) model.
profile_gen_chain = get_profile_messages | llm

def get_state(messages):
    """Pick the next graph node from the conversation so far.

    * The last message is a tool call -> ``"profile"``.
    * The last message is not from the human -> ``END`` (wait for user input).
    * Otherwise: ``"profile"`` if any earlier message was a tool call, else
      ``"info"``.
    """
    last = messages[-1]
    if _is_tool_call(last):
        return "profile"
    if not isinstance(last, HumanMessage):
        return END
    return "profile" if any(_is_tool_call(m) for m in messages) else "info"

@st.cache_resource
def get_graph():
    """Compile the info -> profile message graph with an in-memory checkpointer.

    The graph is cached with ``st.cache_resource``, so one compiled graph (and
    one SQLite ``:memory:`` checkpointer) is shared across reruns.
    """
    routes = {name: name for name in ('info', 'profile', END)}
    builder = MessageGraph()
    builder.add_node("info", chain)
    builder.add_node("profile", profile_gen_chain)
    # Both nodes route through the same decision function.
    for source in ("info", "profile"):
        builder.add_conditional_edges(source, get_state, routes)
    builder.set_entry_point("info")
    checkpointer = SqliteSaver.from_conn_string(":memory:")
    return builder.compile(checkpointer=checkpointer)
    
graph = get_graph()

# Streamlit re-executes this whole script on every interaction. A module-level
# uuid4() would therefore start a fresh checkpointer thread for every message,
# and the graph would never see earlier turns. Keep the thread id in session
# state so it survives reruns; each browser session gets its own id.
if 'thread_id' not in st.session_state:
    st.session_state['thread_id'] = str(uuid.uuid4())

# Streamlit app layout
st.title("JobEasy AI")
clear_button = st.sidebar.button("Clear Conversation", key="clear")

# Initialise session state variables
if 'generated' not in st.session_state:
    st.session_state['generated'] = ['Please tell me about your most recent career']
if 'past' not in st.session_state:
    st.session_state['past'] = []
if 'messages' not in st.session_state:
    st.session_state['messages'] = []

# Reset everything, including the graph's conversation memory: a new thread id
# makes the checkpointer start from an empty history.
if clear_button:
    st.session_state['generated'] = ['Please tell me about your most recent career']
    st.session_state['past'] = []
    st.session_state['messages'] = []
    st.session_state['thread_id'] = str(uuid.uuid4())

# Graph config used by query(). It is built after the clear handling so that a
# reset takes effect on this run.
config = {"configurable": {"thread_id": st.session_state['thread_id']}}

# container for chat history
response_container = st.container()
# container for text box
container = st.container()

def query(payload):
    """Run one user message through the graph and record the exchange.

    Deliberately NOT wrapped in ``st.cache_resource``. The function exists for
    its side effects on ``st.session_state``; caching it by ``payload`` would
    silently skip a repeated identical message.

    The user input is appended to ``past`` once per call, and all non-empty
    node replies from this turn are joined into a single ``generated`` entry.
    This keeps the two lists aligned for the transcript renderer. Tool-call
    messages have empty content and would otherwise show as blank bubbles.
    """
    replies = []
    for output in graph.stream([HumanMessage(content=payload)], config=config):
        if "__end__" in output:
            continue
        # stream() yields dictionaries with output keyed by node name
        for value in output.values():
            st.session_state['messages'].append({"role": "assistant", "content": value.content})
            if value.content:
                replies.append(value.content)
    if replies:
        st.session_state['past'].append(payload)
        st.session_state['generated'].append("\n\n".join(replies))
                      
with container:
    with st.form(key='my_form', clear_on_submit=True):
        user_input = st.text_area("You:", key='input', height=100)
        submit_button = st.form_submit_button(label='Send')

    if submit_button and user_input:
        query(user_input)

# Render the transcript. Entry i of "generated" is followed by entry i of
# "past" when it exists, so the chat alternates bot/user starting from the
# greeting.
if st.session_state['generated']:
    with response_container:
        past_inputs = st.session_state["past"]
        for i, bot_text in enumerate(st.session_state["generated"]):
            message(bot_text, key=str(i))
            if i < len(past_inputs):
                message(past_inputs[i], is_user=True, key=str(i) + '_user')