|
|
|
from langchain_core.pydantic_v1 import BaseModel, Field |
|
from typing import List |
|
from typing import Literal |
|
from langchain.prompts import ChatPromptTemplate |
|
from langchain_core.utils.function_calling import convert_to_openai_function |
|
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser |
|
|
|
|
|
class KeywordExtraction(BaseModel):
    # Schema for the function call that make_keywords_extraction_chain forces
    # the model to make (it is passed to convert_to_openai_function there).
    # NOTE(review): langchain's convert_to_openai_function derives the function
    # description from this class docstring and the parameter description from
    # Field(description=...), so both strings are presumably part of the prompt
    # the LLM sees -- treat edits to them as prompt changes and confirm.
    """

    Analyzing the user query to extract keywords to feed a search engine

    """

    # Keywords returned by the model, as a list of strings.
    # The "Maximum 3 keywords" limit below is only requested in the description
    # text; the Field sets no max_items, so the schema does not enforce it.
    keywords: List[str] = Field(

        description="""

        Extract the keywords from the user query to feed a search engine as a list

        Avoid adding super specific keywords to prefer general keywords

        Maximum 3 keywords



        Examples:

        - "What is the impact of deep sea mining ?" -> ["deep sea mining"]

        - "How will El Nino be impacted by climate change" -> ["el nino","climate change"]

        - "Is climate change a hoax" -> ["climate change","hoax"]

        """

    )
|
|
|
|
|
def make_keywords_extraction_chain(llm):
    """Build a runnable chain that extracts search-engine keywords from a query.

    The chain formats the query into a short chat prompt, binds the
    ``KeywordExtraction`` schema to the model as its only function, forces the
    model to call that function, and parses the call's arguments as JSON.

    Args:
        llm: A chat model exposing ``.bind(functions=..., function_call=...)``
            for OpenAI-style function calling (presumably a ``ChatOpenAI``-like
            model -- confirm against callers).

    Returns:
        A LangChain runnable. Invoke it with ``{"input": <user query>}``. The
        result is the parsed function-call arguments, which per the
        ``KeywordExtraction`` schema should look like ``{"keywords": [...]}``.
    """
    keywords_function = convert_to_openai_function(KeywordExtraction)

    # Force the model to call our function instead of answering in free text.
    # The name is read from the generated function definition rather than
    # hard-coded, so it cannot drift out of sync if the schema class is renamed.
    llm_with_functions = llm.bind(
        functions=[keywords_function],
        function_call={"name": keywords_function["name"]},
    )

    # "{input}" is the only template variable, so callers pass {"input": ...}.
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a helpful assistant"),
        ("user", "input: {input}"),
    ])

    # JsonOutputFunctionsParser turns the function-call arguments into a dict.
    return prompt | llm_with_functions | JsonOutputFunctionsParser()