File size: 3,302 Bytes
79c0556
 
 
 
 
 
 
ba08d17
79c0556
 
08cc5dc
 
79c0556
6719fe3
9d07cf2
6719fe3
9d07cf2
045955f
e64f669
fa26b1c
 
e64f669
fa26b1c
013647d
fa26b1c
 
224e382
79c0556
 
1f1886f
79c0556
 
 
 
 
317e9ce
79c0556
 
 
045955f
 
79c0556
 
 
 
 
 
 
 
 
 
 
231f795
5fc43de
79c0556
217bda1
231f795
 
5fc43de
231f795
 
1f1886f
f41f752
224e382
 
231f795
224e382
5fc43de
 
224e382
231f795
ba08d17
8bc4ecf
5fc43de
 
0e29da6
5fc43de
 
 
0e29da6
231f795
 
ba08d17
 
 
045955f
 
 
9b22ed0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import functools
import os
import pickle
import time
from io import StringIO

import faiss
import gradio as gr
import numpy as np
import openai
import pandas as pd
from huggingface_hub import hf_hub_download
from huggingface_hub import login
from langchain.embeddings.openai import OpenAIEmbeddings

# Credentials are read from the environment; either variable may be unset.
openai.api_key = os.getenv("OPENAI_API_KEY")
hf_token = os.getenv("HF_TOKEN")

# Only log in to the Hub when a token is configured: login() with token=None
# falls back to an interactive prompt, which cannot be answered in a headless
# deployment. hf_hub_download() below receives the token explicitly anyway.
if hf_token:
    login(token=hf_token)

@functools.lru_cache(maxsize=1)
def load_embeddings_and_faiss():
    """Download and load the prebuilt FAISS index and question embeddings.

    Both files come from the ``chukbert/embedding-faq-medquad`` dataset repo on
    the Hugging Face Hub. The result is cached for the life of the process, so
    repeated calls (the chatbot calls this once per request) do not re-read the
    index or re-unpickle the embeddings every time.

    Returns:
        tuple: ``(faiss_index, question_embeddings)``. The first is the index
        returned by ``faiss.read_index``. The second is whatever object
        ``embeddings.pkl`` unpickles to.
    """
    repo_id = "chukbert/embedding-faq-medquad"
    embeddings_path = hf_hub_download(repo_id=repo_id, filename="embeddings.pkl", repo_type="dataset", token=hf_token)
    faiss_index_path = hf_hub_download(repo_id=repo_id, filename="faiss.index", repo_type="dataset", token=hf_token)

    faiss_index = faiss.read_index(faiss_index_path)

    # SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in
    # the file. This is only acceptable while the Hub repo above is trusted;
    # consider a non-executable format (e.g. .npy loaded with allow_pickle=False).
    with open(embeddings_path, "rb") as f:
        question_embeddings = pickle.load(f)

    return faiss_index, question_embeddings

def retrieve_answer(question, faiss_index, embedding_model, answers, log_output, threshold=0.2):
    """Return the stored answer whose indexed vector is nearest to *question*.

    Args:
        question: User question text to embed.
        faiss_index: Index exposing ``search(x, k)`` that returns
            ``(distances, indices)`` arrays with one row per query.
        embedding_model: Object exposing ``embed_query(text)`` that returns one
            embedding vector.
        answers: Sequence of answers, indexed by the labels the FAISS index
            returns. NOTE(review): alignment between the downloaded index and
            the caller's answers list is assumed, not checked.
        log_output: Writable text stream that receives a diagnostic line.
        threshold: Largest accepted distance. NOTE(review): this treats a
            smaller distance as closer (true for L2 indexes); confirm the
            index metric.

    Returns:
        The matched answer. If the match is missing or too far away, returns
        the sentinel string ``"No good match found in dataset. Using
        GPT-4o-mini to generate an answer."``, which callers compare against.
    """
    question_embedding = embedding_model.embed_query(question)
    # FAISS searches over a float32 matrix; this is a single query row.
    query = np.asarray([question_embedding], dtype="float32")
    distances, indices = faiss_index.search(query, k=1)

    closest_distance = distances[0][0]
    closest_index = indices[0][0]
    # Terminate the line so later log messages do not run into it.
    log_output.write(f"closest_distance: {closest_distance}\n")

    # FAISS reports label -1 when it found no neighbour (e.g. an empty index).
    # It must never reach answers[...], where -1 would mean "last answer".
    if closest_index < 0 or closest_distance > threshold:
        return "No good match found in dataset. Using GPT-4o-mini to generate an answer."

    return answers[closest_index]

def ask_openai_gpt4(question, *, model="gpt-4o-mini", max_tokens=150):
    """Ask an OpenAI chat model to answer a medical question directly.

    This is the fallback used when the FAISS lookup finds no close match.

    Args:
        question: The user's question, inserted verbatim into the prompt.
        model: Chat model name sent to the API (default ``"gpt-4o-mini"``).
        max_tokens: Cap on completion length. The default of 150 may cut long
            answers short; callers can raise it.

    Returns:
        The content of the first choice's message.
    """
    response = openai.chat.completions.create(
        messages=[
            {"role": "user", "content": f"Answer the following medical question: {question}"},
        ],
        model=model,
        max_tokens=max_tokens,
    )
    return response.choices[0].message.content

def chatbot(user_input):
    """Answer *user_input* from the FAQ index, falling back to the LLM.

    Flow: load the FAISS index, embed the question, and look up the nearest
    stored answer. If retrieval reports no good match, ask the chat model
    instead.

    Returns:
        A 3-tuple matching the Gradio outputs: (response text, formatted
        response time, accumulated log text).
    """
    # Sentinel returned by retrieve_answer() when it finds no good match.
    no_match = "No good match found in dataset. Using GPT-4o-mini to generate an answer."
    log_output = StringIO()

    # Guard: there is nothing to embed for empty/whitespace input, and the
    # embeddings API may reject it, so answer immediately.
    if not user_input or not user_input.strip():
        log_output.write("Empty question received; skipping retrieval.\n")
        return "Please enter a question.", "Response time: 0.0000 seconds", log_output.getvalue()

    faiss_index, _question_embeddings = load_embeddings_and_faiss()
    embedding_model = OpenAIEmbeddings(openai_api_key=openai.api_key)

    # Timing starts after the index and model are loaded, so it measures
    # retrieval plus any LLM fallback only.
    start_time = time.time()

    log_output.write("Retrieving answer from FAISS...\n")
    # NOTE(review): `answers` is a module-level global, not defined in this
    # function; it must be populated before the first request.
    response_text = retrieve_answer(user_input, faiss_index, embedding_model, answers, log_output, threshold=0.3)

    if response_text == no_match:
        log_output.write(no_match + "\n")
        response_text = ask_openai_gpt4(user_input)

    response_time = time.time() - start_time

    return response_text, f"Response time: {response_time:.4f} seconds", log_output.getvalue()

# Gradio UI: a single free-text input and three text outputs, in the same
# order as the 3-tuple chatbot() returns (answer, timing, logs).
demo = gr.Interface(
    fn=chatbot,
    inputs="text",
    outputs=[
        gr.Textbox(label=box_label)
        for box_label in ("Chatbot Response", "Response Time", "Logs")
    ],
    title="Medical Chatbot with Custom Knowledge About Medical FAQ",
    description=(
        "A chatbot with custom knowledge using FAISS for quick responses or fallback "
        "to GPT-4o-mini when no relevant answer is found. Response time is also tracked."
    ),
)

# Load the FAQ data at import time, not under the __main__ guard. chatbot()
# reads the module-level `answers` on every request, so it must also exist
# when this module is imported rather than run as a script (e.g. by a
# reload-mode launcher). Otherwise the first request fails with NameError.
df = pd.read_csv("medquad.csv")
questions = df['question'].tolist()
answers = df['answer'].tolist()

if __name__ == "__main__":
    print(f"Loaded questions and answers. Number of questions: {len(questions)}, Number of answers: {len(answers)}")
    demo.launch()