# finetuned-llm-demo-app / prompt_parsing.py
# Last commit: 278c5af "Add emoji" (by tnt306)
from langchain_core.prompts import ChatPromptTemplate
from typing import Dict, List, Tuple
from time import sleep
from enum import Enum
import os
import functools
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from typing import TypedDict, Annotated, List, Any
from pydantic_ai import Agent, Tool, RunContext
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.messages import ModelMessage, ModelMessagesTypeAdapter
from langgraph.types import StreamWriter
from pydantic_ai.providers.openai import OpenAIProvider
from langchain_core.runnables.config import RunnableConfig
from httpx import AsyncClient
from utils import get_query_from_vector_store_index, get_latest_news
from langgraph.graph.state import CompiledStateGraph
class ModelNames(Enum):
    """Identifiers of the locally served Qwen2.5-7B-Instruct-1M (q4_k_m) models.

    Each ``.value`` is used both as the model name handed to the local
    OpenAI-compatible server and as the key of the per-session agent dict.
    """

    # Fine-tuned variant.
    Qwen25_7B_Instruct_1M_q4_k_m_Finetuned = "Qwen2.5-7B-Instruct-1M-q4_k_m-Finetuned"
    # Original (not fine-tuned) variant.
    Qwen25_7B_Instruct_1M_q4_k_m_Original = "Qwen2.5-7B-Instruct-1M-q4_k_m-Original"
class LLMWaitTime(Enum):
    """Seconds to pause after each LLM request to stay within free-tier limits.

    * OpenRouter free tier: 20 requests/minute, 200 requests/day, i.e. one
      request every 3 seconds.
      (https://openrouter.ai/docs/api-reference/limits)
    * Gemini 2.0 Flash free tier: 15 RPM, 1,500 RPD, i.e. one request every
      4 seconds.
      (https://ai.google.dev/gemini-api/docs/rate-limits#free-tier)

    ``llm_wait_after_request`` sleeps for ``member.value`` seconds.

    NOTE: Enum members that share a value are aliases. All three OpenRouter
    members below are therefore the same object, whose canonical name is
    ``OpenRouter_DeepSeek_R1``. Rely only on ``.value``, not on ``.name`` or
    identity.
    """

    # OpenRouter free tier: 60 s / 20 requests.
    OpenRouter_DeepSeek_R1 = 3
    OpenRouter_Qwen25_72B_Instruct = 3
    OpenRouter_Llama33_70B_Instruct = 3
    # Google Gemini free tier: 60 s / 15 requests.
    Google_Gemini_20_Flash = 4
# Base address (scheme + host, no port) of the local OpenAI-compatible LLM
# servers; a ":<port>/v1" suffix is appended to it when building endpoints.
# It can be overridden via the LOCAL_LLM_URL environment variable (e.g. when
# the servers run on another host). A trailing "/" is stripped so that the
# appended port stays well-formed.
LOCAL_LLM_URL = os.environ.get("LOCAL_LLM_URL", "http://127.0.0.1").rstrip("/")
# Persona shared by the two arXiv prompts.
_RESEARCH_BOT_SYSTEM_PROMPT = "You are a helpful Research bot."

# Template variables: {num_questions}, {title}, {abstract}. The
# "Question:" / "Answer:" line markers requested here are the ones
# parse_arxiv_qa_prompt_output() splits on.
_ARXIV_QA_INSTRUCTION = (
    "Below is the title and abstract of a paper from arXiv. Create {num_questions} "
    "pairs of questions and corresponding answers, based on the title and abstract. "
    "Avoid using abbreviations and acronyms. "
    'Questions start with the string "Question:". '
    'Answers start with the string "Answer:". '
    "Include only the list and nothing else.\n\n"
    "Title: {title}\n\nAbstract: {abstract}"
)

# Template variables: {title}, {abstract}.
_ARXIV_SUMMARY_INSTRUCTION = (
    "Below is the title and abstract of a paper from arXiv. Summarize it, and "
    "additionally include other relevant information to help users understand "
    "the paper better.\n\n"
    "Title: {title}\n\nAbstract: {abstract}"
)

# Generates Q&A pairs about a single paper.
prompt_arxiv_qa = ChatPromptTemplate(
    [
        ("system", _RESEARCH_BOT_SYSTEM_PROMPT),
        ("human", _ARXIV_QA_INSTRUCTION),
    ]
)

# Summarizes a single paper.
prompt_arxiv_summary = ChatPromptTemplate(
    [
        ("system", _RESEARCH_BOT_SYSTEM_PROMPT),
        ("human", _ARXIV_SUMMARY_INSTRUCTION),
    ]
)

# Paraphrases a piece of text. Template variables: {further_instruction}
# (appended to the persona), {thing} (what kind of text it is) and {sentence}.
prompt_paraphrase = ChatPromptTemplate(
    [
        ("system", "You are a helpful Research bot. {further_instruction}"),
        ("human", "Paraphrase the following {thing} below:\n\n{thing}:{sentence}"),
    ]
)
def parse_arxiv_qa_prompt_output(output: str) -> List[Dict]:
    """Parse the model's reply to ``prompt_arxiv_qa`` into question/answer dicts.

    Every non-blank line is expected to start with ``"Question:"`` or
    ``"Answer:"``. A question is held until the next answer line. At that
    point the pair is emitted and the held question is reset. Any other
    non-blank line is reported on stdout and skipped.

    Args:
        output: Raw text produced by the model.

    Returns:
        A list of ``{"question": str, "answer": str}`` dicts in output order.
        An answer with no preceding question is emitted with ``question == ""``.
    """
    question_tag = "Question:"
    answer_tag = "Answer:"
    qa_pairs: List[Dict] = []
    question = ""
    for raw_line in output.split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith(question_tag):
            # Slice off the tag instead of splitting on the first space. The
            # model may omit the space ("Question:What ...") or leave the text
            # empty ("Question:"). Splitting on a space would then drop a word
            # or raise ValueError.
            question = line[len(question_tag):].strip()
        elif line.startswith(answer_tag):
            answer = line[len(answer_tag):].strip()
            qa_pairs.append({"question": question, "answer": answer})
            question = ""
        else:
            print(f"Error: [{line}] not question nor answer")
    return qa_pairs
def llm_wait_after_request(provider: "LLMWaitTime"):
    """Decorator factory: sleep ``provider.value`` seconds after every call.

    Used to pace calls to rate-limited LLM APIs. The pause runs in a
    ``finally`` clause, so it also happens when the wrapped call raises. A
    failed request may still have consumed quota, and callers that retry on
    error would otherwise fire the next request immediately.

    Args:
        provider: An ``LLMWaitTime`` member whose ``.value`` is the delay in
            seconds. The annotation is a string so that it is not evaluated
            at definition time.

    Returns:
        A decorator that returns the wrapped function's result unchanged and
        preserves its metadata (name, docstring) via ``functools.wraps``.
    """

    def decorator(some_function):
        @functools.wraps(some_function)
        def wrapper(*args, **kwargs):
            try:
                return some_function(*args, **kwargs)
            finally:
                sleep(provider.value)

        return wrapper

    return decorator
########################################################
# Define state schema
class AgentState(TypedDict):
    """LangGraph state passed through the reasoner graph.

    latest_user_message: the user's newest question for this graph run.
    messages: pydantic-ai message history as serialized JSON ``bytes``
        blobs, one per completed reasoner turn.
    """

    latest_user_message: str
    # The Annotated reducer concatenates lists. A node's returned "messages"
    # is therefore appended to the existing history instead of replacing it.
    messages: Annotated[List[bytes], lambda history, update: history + update]
# Research-bot persona prompt. reasoner() passes it as the agent's deps when
# "with_system_prompt_for_reasoner" is enabled; generate_agentic_flow() also
# returns it.
reasoner_system_prompt_as_ai_assistant = (
    "You are a helpful Artificial Intelligence (AI) Research bot, with expertise "
    "on Large Language Model (LLM). You have especially deep knowledge about the "
    'Research Paper "Byte Latent Transformer (BLT): Patches Scale Better Than '
    'Tokens". Users can ask you questions, and you will provide the corresponding '
    "answers. If the questions are related to Byte Latent Transformer (BLT), the "
    "answers must be in a detailed manner, and primarily come from the information "
    "in the Research Paper, additionally with your general knowledge. The goal is "
    "to help users understand fully."
)
def get_reasoner_system_prompt(ctx: RunContext[str]) -> str:
    """System-prompt hook registered on the reasoner agents.

    The prompt text travels as the run's ``deps`` (a plain ``str``), so the
    hook simply hands it back. reasoner() passes either the research-bot
    prompt or an empty string.
    """
    prompt_text: str = ctx.deps
    return prompt_text
# Shared resources across sessions.
# Built once when this module is imported and reused by every reasoner() call.
# reasoner() calls it with the user's question and interpolates the result into
# the prompt as retrieval context. NOTE(review): presumably this loads a vector
# store index (per the helper's name); confirm in utils.
rag_query_engine = get_query_from_vector_store_index()
async def reasoner(state: AgentState, writer: StreamWriter, config: RunnableConfig):
    """LangGraph node that answers ``state["latest_user_message"]``.

    Per-run behaviour is read from ``config["configurable"]``:

    * ``reasoner_agents`` / ``chosen_model``: the dict of pydantic-ai agents
      and the key of the one to run.
    * ``with_system_prompt_for_reasoner``: if truthy, the research-bot prompt
      is passed as ``deps`` (returned by the agent's system-prompt hook).
      Otherwise an empty string is passed.
    * ``with_rag_for_reasoner``: if truthy, the output of
      ``rag_query_engine`` is appended to the question as context.
    * ``with_tools_for_reasoner``: if truthy, ``get_latest_news`` is
      registered as the agent's tool. Otherwise its tools are cleared.
    * ``reasoner_statistic_report``: callback invoked with the run's usage.

    Answer text is pushed through ``writer``. With tools enabled it arrives
    in one piece from a non-streamed run; otherwise it arrives as incremental
    text deltas.

    Returns:
        ``{"messages": [...]}`` holding this run's new messages as JSON bytes.
        The state reducer appends these to the history.
    """
    question = state["latest_user_message"]
    print(
        f"(っ◕‿◕)っ reasoner(): latest_user_message = {question}", flush=True
    )

    # --- Which agent to run ---
    options = config["configurable"]  # type: ignore
    agents_by_name: Dict[str, Agent[str, str]] = options["reasoner_agents"]
    model_name = options["chosen_model"]
    print(f"reasoner(): chosen model = {model_name}", flush=True)
    agent = agents_by_name[model_name]

    # --- Optional research-bot system prompt (delivered through deps) ---
    use_system_prompt = options["with_system_prompt_for_reasoner"]
    print(
        f"reasoner(): with_system_prompt_for_reasoner = {use_system_prompt}",
        flush=True,
    )
    if use_system_prompt:
        system_prompt = reasoner_system_prompt_as_ai_assistant
    else:
        system_prompt = ""

    # --- Optional RAG: enrich the question with retrieved context ---
    use_rag = options["with_rag_for_reasoner"]
    print(f"reasoner(): with_rag_for_reasoner = {use_rag}", flush=True)
    prompt = question
    if use_rag:
        retrieved = rag_query_engine(state["latest_user_message"])
        prompt += f"\n\nUse the context below for relevant information:\nContext:\n{retrieved}"

    # --- Tools: reset, then register get_latest_news if requested ---
    use_tools = options["with_tools_for_reasoner"]
    print(f"reasoner(): with_tools_for_reasoner = {use_tools}", flush=True)
    # NOTE(review): these are private pydantic-ai members, mutated on an agent
    # that every run of this session shares. They may change across library
    # versions.
    agent._function_tools.clear()
    if use_tools:
        agent._register_tool(Tool(get_latest_news))
    print(
        f"reasoner(): reasoner_agent._function_tools.keys() = {agent._function_tools.keys()}",
        flush=True,
    )

    # --- Rebuild the pydantic-ai history from the serialized turns in state ---
    history: list[ModelMessage] = [
        message
        for serialized_turn in state["messages"]
        for message in ModelMessagesTypeAdapter.validate_json(serialized_turn)
    ]

    # --- Run the agent ---
    if use_tools:
        # With tools the run is not streamed. run_sync() cannot be used here:
        # we are already inside async code, and run_sync() only wraps
        # run()/run_stream().
        print(
            f"reasoner(): reasoner_agent.run(). message_history len = {len(history)}",
            flush=True,
        )
        result = await agent.run(prompt, message_history=history, deps=system_prompt)
        writer(result.output)  # type: ignore
    else:
        print(
            f"reasoner(): reasoner_agent.run_stream(). message_history len = {len(history)}",
            flush=True,
        )
        async with agent.run_stream(
            prompt, message_history=history, deps=system_prompt
        ) as result:
            async for delta in result.stream_text(delta=True):
                writer(delta)
    print("(っ◕‿◕)っ reasoner(): out!", flush=True)

    # Report usage statistics for this call. The new messages cover this
    # turn's user prompt and the model's reply, and also the system prompt on
    # the first turn of a conversation.
    options["reasoner_statistic_report"](result.usage())
    return {"messages": [result.new_messages_json()]}
def generate_agentic_flow() -> (
    Tuple[Dict[str, Agent[str, str]], CompiledStateGraph, str]
):
    """Build the agents and compiled LangGraph flow for one user session.

    Creates one pydantic-ai Agent per model in the table below. Each agent
    talks to its own local OpenAI-compatible endpoint (``LOCAL_LLM_URL`` plus
    a port) and authenticates with the ``LOCAL_LLM_API_KEY`` environment
    variable. Each uses ``get_reasoner_system_prompt`` as its system-prompt
    hook. The graph is a single node (START -> "reasoner" -> END) compiled
    with an in-memory checkpointer.

    Returns:
        A tuple of (agents keyed by model name, compiled graph, the
        research-bot system prompt).

    Raises:
        KeyError: if ``LOCAL_LLM_API_KEY`` is not set.
    """
    # Model id -> port of the local server hosting that model.
    endpoints = {
        ModelNames.Qwen25_7B_Instruct_1M_q4_k_m_Finetuned.value: 8081,
        ModelNames.Qwen25_7B_Instruct_1M_q4_k_m_Original.value: 8080,
    }

    agents: Dict[str, Agent[str, str]] = {}
    for model_id, port in endpoints.items():
        provider = OpenAIProvider(
            api_key=os.environ["LOCAL_LLM_API_KEY"],
            base_url=f"{LOCAL_LLM_URL}:{port}/v1",
            # Dedicated client per agent, sending "Connection: close" on requests.
            http_client=AsyncClient(headers={"Connection": "close"}),
        )
        agents[model_id] = Agent(
            OpenAIModel(model_id, provider=provider),
            retries=3,
            deps_type=str,
        )

    # The system prompt is resolved at run time from each run's deps.
    for agent in agents.values():
        agent.system_prompt(get_reasoner_system_prompt)

    graph = StateGraph(AgentState)
    graph.add_node("reasoner", reasoner)
    graph.add_edge(START, "reasoner")
    graph.add_edge("reasoner", END)

    # The checkpointer keeps state between graph runs. For it to take effect,
    # callers must also supply a "thread_id" under config["configurable"].
    flow = graph.compile(checkpointer=MemorySaver())

    return agents, flow, reasoner_system_prompt_as_ai_assistant