import gradio as gr
import os
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

def create_chains(openai_key):
    os.environ["OPENAI_API_KEY"] = openai_key
    
    # Create the classifier chain
    classifier_prompt = PromptTemplate.from_template(
        """Given the user question below, classify it as either being about `LangChain`, `OpenAI`, or `Other`.
        Do not respond with more than one word.
        
        <question>
        {question}
        </question>
        
        Classification:"""
    )
    
    classifier_chain = classifier_prompt | ChatOpenAI(model="gpt-4") | StrOutputParser()
    
    # Create specialized chains
    langchain_chain = (
        PromptTemplate.from_template(
            """You are an expert in LangChain. 
            Always answer questions starting with "As a LangChain expert". 
            Question: {question}
            Answer:"""
        ) | ChatOpenAI(model="gpt-4")
    )
    
    openai_chain = (
        PromptTemplate.from_template(
            """You are an expert in OpenAI. 
            Always answer questions starting with "As an OpenAI expert". 
            Question: {question}
            Answer:"""
        ) | ChatOpenAI(model="gpt-4")
    )
    
    general_chain = (
        PromptTemplate.from_template(
            """Respond to the following question:
            Question: {question}
            Answer:"""
        ) | ChatOpenAI(model="gpt-4")
    )
    
    return classifier_chain, langchain_chain, openai_chain, general_chain

def route_question(question, openai_key):
    try:
        classifier_chain, langchain_chain, openai_chain, general_chain = create_chains(openai_key)
        
        # Classify the question
        classification = classifier_chain.invoke({"question": question})
        
        # Route to appropriate chain
        if "langchain" in classification.lower():
            response = langchain_chain.invoke({"question": question})
        elif "openai" in classification.lower():
            response = openai_chain.invoke({"question": question})
        else:
            response = general_chain.invoke({"question": question})
        
        return f"Classification: {classification}\nResponse: {response.content}"
    except Exception as e:
        return f"Error: {str(e)}"

# Example questions for each category
example_questions = [
    ["What is LangChain and how does it work?"],
    ["How do I use OpenAI's GPT models?"],
    ["What is the capital of France?"],
    ["Explain LangChain's routing capabilities"],
    ["Tell me about OpenAI's latest developments"]
]

# Create Gradio interface
demo = gr.Interface(
    fn=route_question,
    inputs=[
        gr.Textbox(label="Enter your question", placeholder="Type your question here..."),
        gr.Textbox(label="OpenAI API Key", type="password", placeholder="Enter your OpenAI API key")
    ],
    outputs=gr.Textbox(label="Response"),
    title="LangChain Router Demo",
    description="""This demo shows how routing works in LangChain. Ask questions about LangChain, OpenAI, or any other topic.
    """,
    examples=example_questions,
    cache_examples=False
)

if __name__ == "__main__":
    demo.launch()