File size: 1,359 Bytes
481f3b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9609df9
 
481f3b1
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43

from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List
from typing import Literal
from langchain.prompts import ChatPromptTemplate
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser


class Translation(BaseModel):
    """Analyzing the user message input"""
    # NOTE: this model is converted with convert_to_openai_function in
    # make_translation_chain. The class name becomes the function name, and
    # the docstring and field description are presumably carried into the
    # schema sent to the model (confirm against the installed LangChain
    # version). Editing any of them changes what the LLM sees, not just docs.
    
    # The translated text. The description doubles as the instruction telling
    # the model what to put in this argument.
    translation: str = Field(
        description="Translate the message input to English",
    )


def make_translation_chain(llm):
    """Build a runnable that translates a user message into English.

    The model is bound to a single function derived from the ``Translation``
    schema and forced to call it, so the output is structured function
    arguments rather than free-form text.

    Args:
        llm: Chat model exposing ``.bind(...)``. Presumably one that supports
            the legacy OpenAI ``functions`` / ``function_call`` parameters
            (TODO confirm for non-OpenAI models).

    Returns:
        A runnable ``prompt | llm | parser``. Invoking it with
        ``{"input": text}`` yields the parsed function arguments, expected to
        look like ``{"translation": "..."}``.
    """
    translation_function = convert_to_openai_function(Translation)
    # Force the call by the schema's own name rather than a duplicated string
    # literal, so renaming the Translation class cannot silently break it.
    llm_with_functions = llm.bind(
        functions=[translation_function],
        function_call={"name": translation_function["name"]},
    )

    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a helpful assistant, you will translate the user input message to English using the function provided"),
        ("user", "input: {input}"),
    ])

    # JsonOutputFunctionsParser returns only the function-call arguments by
    # default (args_only=True), i.e. a dict keyed by the Translation fields.
    return prompt | llm_with_functions | JsonOutputFunctionsParser()


def make_translation_node(llm):
    """Create a node function that translates the user's message to English.

    The translation chain is built once here and reused by every call of the
    returned function.

    Args:
        llm: Chat model forwarded to ``make_translation_chain``.

    Returns:
        A callable ``translate_query(state)``. It reads
        ``state["user_input"]`` and returns ``{"query": <translated text>}``.
    """
    chain = make_translation_chain(llm)

    def translate_query(state):
        """Translate ``state["user_input"]`` and expose the result as ``query``."""
        print("---- Translate query ----")
        # The parsed function arguments are expected to hold a "translation" key.
        parsed = chain.invoke({"input": state["user_input"]})
        return {"query": parsed["translation"]}

    return translate_query