# Hugging Face Spaces status: Sleeping
import glob | |
from venv import create | |
import gradio as gr | |
from typing import Any | |
from dotenv import load_dotenv | |
import requests | |
from griptape.structures import Agent | |
from griptape.tasks import PromptTask | |
from griptape.drivers import ( | |
LocalConversationMemoryDriver, | |
GriptapeCloudStructureRunDriver, | |
LocalFileManagerDriver, | |
LocalStructureRunDriver, | |
GriptapeCloudConversationMemoryDriver, | |
) | |
from griptape.memory.structure import ConversationMemory | |
from griptape.tools import StructureRunTool, FileManagerTool | |
from griptape.rules import Rule, Ruleset | |
from griptape.configs.drivers import AnthropicDriversConfig | |
from griptape.configs import Defaults | |
import time | |
import os | |
from urllib.parse import urljoin | |
# Load environment variables (API keys, Gradio credentials) from a local .env file, if present.
load_dotenv()
# Use Anthropic drivers as the Griptape-wide default drivers config.
Defaults.drivers_config = AnthropicDriversConfig()

# Griptape Cloud REST API root and the headers sent with every direct request to it.
base_url = "https://cloud.griptape.ai"
headers_api = {
    # os.environ[...] raises KeyError at import time if the API key is not configured,
    # so a missing key fails fast instead of producing unauthenticated requests.
    "Authorization": f"Bearer {os.environ['GT_CLOUD_API_KEY']}",
    "Content-Type": "application/json",
}

# Cache mapping a chat session id (the Gradio session hash) to its Griptape Cloud
# thread id; populated by create_thread_id().
threads: dict[str, str] = {}
def create_thread_id(session_id: str, *, timeout: float = 30.0) -> str:
    """Return the Griptape Cloud thread id for *session_id*, creating the thread on first use.

    Thread ids are cached in the module-level ``threads`` dict, so each session
    is mapped to a single Cloud thread for the life of the process.

    Args:
        session_id: Identifier of the chat session; also used as the new
            thread's ``name``.
        timeout: Seconds to wait for the Griptape Cloud API before giving up.

    Returns:
        The value of ``thread_id`` from the API's JSON response (or the cached
        value for a session seen before).

    Raises:
        requests.HTTPError: If the thread-creation request returns an error status.
        requests.Timeout: If the API does not respond within *timeout* seconds.
    """
    if session_id in threads:
        return threads[session_id]

    response = requests.post(
        url=urljoin(base_url, "/api/threads"),
        headers=headers_api,
        json={"name": session_id, "messages": []},
        # Without a timeout a stalled API call would block this request handler forever.
        timeout=timeout,
    )
    response.raise_for_status()
    thread_id = response.json()["thread_id"]
    threads[session_id] = thread_id
    return thread_id
# build_talk_agent (below) creates an agent that turns the user's questions into a structured query for the Griptape Cloud query agent.
def user(user_message, history):
    """Record a new user turn in a Gradio-style chat history.

    Appends ``[user_message, None]`` to *history* in place; the ``None`` slot is
    where a bot reply is expected to be filled in later. Returns an empty string
    (presumably to clear the input textbox) together with the same history list.

    NOTE(review): not referenced by the visible UI code, which uses gr.ChatInterface.
    """
    pending_turn = [user_message, None]
    history.append(pending_turn)
    cleared_input = ""
    return cleared_input, history
def bot(history, knowledge_bases=None, request: gr.Request = None):
    """Stream the agent's reply to the newest user turn into *history*.

    Generator intended for a Gradio chatbot event: it asks send_message for a
    reply to the last user message, then reveals that reply one character at a
    time, yielding the updated history after each character.

    Args:
        history: Chat history as a list of ``[user_message, bot_reply]`` pairs.
            The reply slot of the last pair is overwritten.
        knowledge_bases: Selected knowledge bases, forwarded to send_message.
        request: The current Gradio request, forwarded to send_message. Gradio
            injects it when this function is registered as an event handler,
            because of the ``gr.Request`` annotation.

    Yields:
        The same *history* list, after each character of the reply is appended.
    """
    # send_message takes (message, history, knowledge_bases, request) with no
    # defaults, so all four must be supplied; passing only the message raises
    # TypeError. The prior turns (excluding the pending one) are passed as history.
    response = send_message(history[-1][0], history[:-1], knowledge_bases, request)
    history[-1][1] = ""
    for character in response:
        history[-1][1] += character
        time.sleep(0.005)  # short delay to produce a typing effect
        yield history
def create_prompt_task(session_id: str, message: str) -> PromptTask:
    """Build the single prompt task run by the query-formulation ("talk") agent.

    Args:
        session_id: Not used by this function.
        message: The user's latest message, quoted verbatim into the prompt.

    Returns:
        A PromptTask asking the model to restructure the user's question, plus
        what is in conversation memory, into a query.
    """
    prompt_text = (
        "\n"
        f"Re-structure the values to form a query from the user's questions: '{message}' "
        "and the input value from the conversation memory. "
        "Leave out attributes that aren't important to the user:"
        "\n"
    )
    return PromptTask(prompt_text)
def build_talk_agent(session_id: str, message: str) -> Agent:
    """Build the agent that turns a user's questions into a structured query.

    The agent asks follow-up questions for any missing candidate attributes. Its
    conversation memory points at the session's Griptape Cloud thread, which is
    the same thread id build_agent uses for that session.

    Args:
        session_id: Session key; used to look up or create the conversation thread.
        message: The user's latest message, embedded in the agent's prompt task.

    Returns:
        A new Agent with one PromptTask and the query-formatting ruleset.
    """
    thread_id = create_thread_id(session_id)

    # Candidate attributes the agent should ask about when the user omits them.
    wanted_fields = (
        "years experience",
        "location",
        "role",
        "skills",
        "expected salary",
        "availability",
        "past companies",
        "past projects",
        "show reel details",
    )
    follow_up_rule = (
        "You ask the user follow-up questions to fill in missing information for:\n"
        + ",\n".join(wanted_fields)
        + "\n"
    )
    rule_texts = (
        "You are responsible for structuring a user's questions into a specific format for a query.",
        follow_up_rule,
        "Return the current query structure and any questions to fill in missing information.",
    )
    ruleset = Ruleset(
        name="Local Gradio Agent",
        rules=[Rule(value=text) for text in rule_texts],
    )

    memory = ConversationMemory(
        conversation_memory_driver=GriptapeCloudConversationMemoryDriver(
            thread_id=thread_id,
        ),
    )
    return Agent(
        conversation_memory=memory,
        tasks=[create_prompt_task(session_id, message)],
        rulesets=[ruleset],
    )
def build_agent(session_id: str, message: str, kbs: str) -> Agent:
    """Build the chat agent that handles one user message for one session.

    A new Agent is built per call. Its conversation memory is the session's
    Griptape Cloud thread (looked up or created via create_thread_id), not local
    memory. It has two tools:

    * ``FormulateQueryFromUser`` - runs the agent from build_talk_agent locally
      to turn the user's input into a structured query.
    * ``QueryResumeSearcher`` - runs the Griptape Cloud structure whose id is in
      the ``GT_STRUCTURE_ID`` environment variable, to search for candidates.

    Args:
        session_id: Session key (send_message passes the Gradio session hash).
        message: The user's latest message, forwarded to the talk agent.
        kbs: String form of the selected knowledge bases; embedded in the search
            tool's description so the model passes it as an extra argument.

    Returns:
        The configured Agent; send_message runs it with the user's message.
    """
    thread_id = create_thread_id(session_id)

    rule_texts = (
        "You are responsible for structuring a user's questions into a query and then querying.",
        "Only return the result of the query, do not provide additional commentary.",
        "Only perform one task at a time.",
        "Do not perform the query unless the user has said 'Done' with formulating.",
        "Only perform the query as one string argument.",
        "If you reformulate the query, then you must ask the user if they are 'Done' again.",
        # NOTE(review): memory lives in a Griptape Cloud thread and this agent has no
        # file tool, so it is unclear how the model could follow this rule - confirm.
        "If the user says they want to start over, then you must delete the conversation memory file.",
    )
    ruleset = Ruleset(
        name="Local Gradio Agent",
        rules=[Rule(value=text) for text in rule_texts],
    )

    search_driver = GriptapeCloudStructureRunDriver(
        structure_id=os.getenv("GT_STRUCTURE_ID"),
        api_key=os.getenv("GT_CLOUD_API_KEY"),
        # Polling: interval 3 (presumably seconds), at most 30 attempts - confirm units.
        structure_run_wait_time_interval=3,
        structure_run_max_wait_time_attempts=30,
    )
    query_client = StructureRunTool(
        name="QueryResumeSearcher",
        description=(
            "Use it to search for a candidate with the query.\n"
            f"Add this as another argument after the input: {kbs}\n"
        ),
        driver=search_driver,
    )

    def make_talk_agent() -> Agent:
        # Factory handed to LocalStructureRunDriver; presumably invoked each time
        # the tool runs, producing a fresh talk agent for this session/message.
        return build_talk_agent(session_id, message)

    talk_client = StructureRunTool(
        name="FormulateQueryFromUser",
        description="Used to formulate a query from the user's input.",
        driver=LocalStructureRunDriver(structure_factory_fn=make_talk_agent),
    )

    memory = ConversationMemory(
        conversation_memory_driver=GriptapeCloudConversationMemoryDriver(
            thread_id=thread_id,
        ),
    )
    return Agent(
        conversation_memory=memory,
        tools=[talk_client, query_client],
        rulesets=[ruleset],
    )
def send_message(message: str, history, knowledge_bases, request: gr.Request) -> Any:
    """Run one chat turn through a freshly built agent and return its reply text.

    Used as the ``fn`` of the gr.ChatInterface. Gradio supplies the new message,
    the chat history, the value of the knowledge-base CheckboxGroup (the
    interface's additional input), and the request object it injects because of
    the ``gr.Request`` annotation.

    Args:
        message: The user's latest chat message.
        history: Prior chat turns. Not used here; conversation context comes from
            the session's Griptape Cloud thread instead.
        knowledge_bases: Selected knowledge bases (presumably a list of names from
            the CheckboxGroup); passed to the agent as ``str(knowledge_bases)``.
        request: The current Gradio request; its ``session_hash`` keys the
            conversation thread. May be ``None`` when called outside a Gradio event.

    Returns:
        The value of the agent run's output.
    """
    # Without a request (or session hash) there is no per-session identity. Fall
    # back to a single shared session id rather than failing with an unbound
    # local variable.
    session_hash = getattr(request, "session_hash", None) or "default"
    agent = build_agent(session_hash, message, str(knowledge_bases))
    response = agent.run(message)
    return response.output.value
# Chat UI: a knowledge-base selector whose value is passed to send_message as an
# additional input alongside each chat message.
with gr.Blocks() as demo:
    knowledge_bases = gr.CheckboxGroup(
        label="Select Knowledge Bases",
        choices=["skills", "demographics", "linked_in", "showreels"],
    )
    chatbot = gr.ChatInterface(
        fn=send_message,
        chatbot=gr.Chatbot(height=300),
        additional_inputs=knowledge_bases,
    )

# Require login. Reading the credentials with os.environ[...] fails fast with a
# KeyError when they are missing (consistent with how GT_CLOUD_API_KEY is read),
# instead of launching with auth=(None, None), which no submitted login can match.
demo.launch(auth=(os.environ["GRADIO_USERNAME"], os.environ["GRADIO_PASSWORD"]))

# demo.launch() blocks until the server shuts down, so this reset only runs at
# shutdown; it does not clear per-session threads while the app is serving.
# TODO: to drop a session's thread when it ends, consider Gradio's Blocks.unload
# event (check that it is available in the installed Gradio version).
threads = {}