File size: 4,634 Bytes
734db66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444fbc0
734db66
e72ac74
 
 
3bdc1bc
 
 
 
 
e72ac74
734db66
 
e72ac74
eef7e20
734db66
 
6fac7ff
04f8781
 
6fac7ff
734db66
 
 
444fbc0
33a6d1c
734db66
 
 
89ff17f
 
734db66
444fbc0
eab08cb
734db66
4325935
492adf0
f23d728
00b67ab
734db66
04f8781
7e991d2
734db66
6da8cb0
734db66
f051983
 
 
 
eef7e20
f051983
 
b26b437
f051983
 
 
 
 
 
 
 
 
8b0c846
f051983
 
 
 
3bdc1bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6fac7ff
734db66
 
 
eef7e20
734db66
 
eef7e20
734db66
 
 
 
 
 
 
 
 
eab08cb
734db66
5f90b6d
734db66
 
 
 
bd370c7
3bdc1bc
 
 
f051983
ae99a46
0028070
 
 
26d9c3b
734db66
444fbc0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import gradio as gr
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import LLMChain
from langchain import PromptTemplate
import re
import pandas as pd
from langchain.vectorstores import FAISS
import requests
from typing import List
from langchain.schema import (
    SystemMessage,
    HumanMessage,
    AIMessage
)
import os
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chat_models import ChatOpenAI

from langchain.llms.base import LLM
from typing import Optional, List, Mapping, Any

import ast
from utils import ClaudeLLM

from qdrant_client import models, QdrantClient
from sentence_transformers import SentenceTransformer

# Two sentence embedding backends: the HuggingFaceEmbeddings default model
# (mpnet-based) and BAAI/bge-large-en-v1.5 for comparison in the UI below.
embeddings = HuggingFaceEmbeddings()
embeddings_1 = HuggingFaceEmbeddings(model_name = "BAAI/bge-large-en-v1.5")

# FAISS indexes loaded from local directories; each must have been built with
# the matching embedding model above. NOTE(review): the 'db_yt' index is not
# loaded, yet qa_retrieve_yt below still references it — calling that function
# would raise a NameError.
db_art = FAISS.load_local('db_art', embeddings)
db_art_1 = FAISS.load_local('db_art_1', embeddings_1)
# db_yt = FAISS.load_local('db_yt', embeddings)
# Cache of the most recent successful retrieval, shared by the qa_retrieve_*
# functions as a fallback.
mp_docs = {}


def retrieve_thoughts(query, n, db):
    """Run a similarity search for *query* over the FAISS store *db* and
    return the top matches grouped per source document.

    Parameters:
        query: free-text user query to embed and search with.
        n: currently UNUSED — callers pass 0. TODO(review): remove or use it.
        db: a langchain FAISS vector store; its document metadata is assumed
            to contain '_id', 'id', 'title', 'author' and 'url' keys — TODO
            confirm against the index-building code (not visible here).

    Returns:
        dict with a single key 'tier 1' mapping to a DataFrame of at most 10
        source documents, each row carrying a 'chunks' dict of its matching
        page_content chunks and the median chunk 'score'. NOTE(review): the
        returned dict is always truthy, so the `if not(thoughts)` fallback in
        the qa_retrieve_* callers can never trigger.
    """

    # print(db.similarity_search_with_score(query = query, k = k, fetch_k = k*10))
    #filter = {'Product Name': prod}
    
    # k is set to the full index size, i.e. every stored chunk is scored and
    # returned — filtering happens afterwards on the score column.
    docs_with_score = db.similarity_search_with_score(query = query, k = len(db.index_to_docstore_id.values()), fetch_k = len(db.index_to_docstore_id.values()))
    
    # Build one row per chunk: metadata columns + page_content + distance score.
    df = pd.DataFrame([dict(doc[0])['metadata'] for doc in docs_with_score], )
    df = pd.concat((df, pd.DataFrame([dict(doc[0])['page_content'] for doc in docs_with_score], columns = ['page_content'])), axis = 1)
    df = pd.concat((df, pd.DataFrame([doc[1] for doc in docs_with_score], columns = ['score'])), axis = 1)
    df['_id'] = df['_id'].apply(lambda x: str(x))
    # Lower score = closer match (FAISS distance), so ascending sort puts the
    # best chunks first.
    df.sort_values("score", inplace = True)

  # TO-DO: What if user query doesn't match what we provide as documents

    # Keep only chunks within distance 1 of the query embedding.
    tier_1 = df[df['score'] < 1]
    

    # Per-document dict of its chunks, ordered by the 'id' column within the
    # document. NOTE: both groupby('_id') calls below iterate groups in the
    # same (sorted) key order, which is what keeps chunks_1 and score aligned
    # with tier_1_adjusted rows when assigned positionally via list(...).
    chunks_1 = tier_1.groupby(['_id' ]).apply(lambda x: {f"chunk_{i}": row for i, row  in enumerate(x.sort_values('id')[['id', 'score','page_content']].to_dict('records'))}).values
    # Collapse to one row per document, keeping identifying metadata.
    tier_1_adjusted = tier_1.groupby(['_id']).first().reset_index()[['_id', 'title', 'author','url', 'score']]
    tier_1_adjusted['ref'] = range(1, len(tier_1_adjusted) + 1 )
    tier_1_adjusted['chunks'] = list(chunks_1)
    # Replace the 'first chunk' score with the median score across the
    # document's matching chunks, then rank documents by it.
    score = tier_1.groupby(['_id' ]).apply(lambda x: x['score'].median()).values
    tier_1_adjusted['score'] = list(score)
    tier_1_adjusted.sort_values("score", inplace = True)

    
    # Cap the result at the 10 best documents.
    tier_1_adjusted = tier_1_adjusted[:min(len(tier_1_adjusted), 10)]

    return {'tier 1':tier_1_adjusted, }

def qa_retrieve_art(query,):
    """Retrieve reference articles for *query* from the mpnet FAISS index.

    Parameters:
        query: free-text question from the UI textbox.

    Returns:
        dict with key 'Reference' mapping to a list of per-document records
        ('_id', 'url', 'author', 'title', 'chunks', 'score'), best match first.
    """
    # mp_docs caches the last successful retrieval across calls; we rebind it,
    # so the global declaration is required. (db_art is only read, so no
    # `global db_art` is needed — that statement was removed as redundant,
    # along with the unused `docs = ""` local.)
    global mp_docs

    thoughts = retrieve_thoughts(query, 0, db_art)
    if not thoughts:
        # Fall back to the previous result when retrieval yields nothing.
        # NOTE(review): retrieve_thoughts always returns a non-empty dict, so
        # this branch is currently unreachable — kept for safety.
        if mp_docs:
            thoughts = mp_docs
    else:
        mp_docs = thoughts

    tier_1 = thoughts['tier 1']
    reference = tier_1[['_id', 'url', 'author', 'title', 'chunks', 'score']].to_dict('records')

    return {'Reference': reference}


def qa_retrieve_bge(query,):
    """Retrieve reference articles for *query* from the bge-large FAISS index.

    Parameters:
        query: free-text question from the UI textbox.

    Returns:
        dict with key 'Reference' mapping to a list of per-document records
        ('_id', 'url', 'author', 'title', 'chunks', 'score'), best match first.
    """
    # mp_docs caches the last successful retrieval across calls; we rebind it,
    # so the global declaration is required. (db_art_1 is only read, so no
    # `global db_art_1` is needed — that statement was removed as redundant,
    # along with the unused `docs = ""` local.)
    global mp_docs

    thoughts = retrieve_thoughts(query, 0, db_art_1)
    if not thoughts:
        # Fall back to the previous result when retrieval yields nothing.
        # NOTE(review): retrieve_thoughts always returns a non-empty dict, so
        # this branch is currently unreachable — kept for safety.
        if mp_docs:
            thoughts = mp_docs
    else:
        mp_docs = thoughts

    tier_1 = thoughts['tier 1']
    reference = tier_1[['_id', 'url', 'author', 'title', 'chunks', 'score']].to_dict('records')

    return {'Reference': reference}

def qa_retrieve_yt(query,):
    """Retrieve reference records for *query* from the YouTube FAISS index.

    Parameters:
        query: free-text question from the UI textbox.

    Returns:
        dict with key 'Reference' mapping to a list of per-document records
        ('_id', 'url', 'author', 'title', 'chunks', 'score'), best match first.

    Raises:
        RuntimeError: if the 'db_yt' index has not been loaded. The module
        currently never loads it (the FAISS.load_local('db_yt', ...) line is
        commented out), so this guard replaces the cryptic NameError the
        original code would raise.
    """
    if 'db_yt' not in globals():
        raise RuntimeError(
            "db_yt index is not loaded; enable FAISS.load_local('db_yt', embeddings) at module level."
        )

    # mp_docs caches the last successful retrieval across calls; we rebind it,
    # so the global declaration is required. (Unused `docs = ""` local and the
    # redundant read-only `global db_yt` were removed.)
    global mp_docs

    thoughts = retrieve_thoughts(query, 0, db_yt)
    if not thoughts:
        # Fall back to the previous result when retrieval yields nothing.
        # NOTE(review): retrieve_thoughts always returns a non-empty dict, so
        # this branch is currently unreachable — kept for safety.
        if mp_docs:
            thoughts = mp_docs
    else:
        mp_docs = thoughts

    tier_1 = thoughts['tier 1']
    reference = tier_1[['_id', 'url', 'author', 'title', 'chunks', 'score']].to_dict('records')

    return {'Reference': reference}

def flush():
    """Clear hook for the UI: always yields None (no state to reset here)."""
    return None

# Two side-by-side retrieval interfaces: one per embedding backend, both
# taking the same textbox query and rendering the reference list as JSON.
# NOTE(review): gr.inputs.Textbox / gr.components.JSON are the legacy Gradio
# 3.x component namespaces — confirm the pinned gradio version supports them.
ref_art_1 = gr.Interface(fn=qa_retrieve_bge, label="bge Articles",
                     inputs=gr.inputs.Textbox(lines=5, label="what would you like to learn about?"),
                     outputs = gr.components.JSON(label="articles"))
ref_art = gr.Interface(fn=qa_retrieve_art, label="mpnet Articles",
                     inputs=gr.inputs.Textbox(lines=5, label="what would you like to learn about?"),
                     outputs = gr.components.JSON(label="articles"))
# ref_yt = gr.Interface(fn=qa_retrieve_yt, label="Youtube",
#                      inputs=gr.inputs.Textbox(lines=5, label="what would you like to learn about?"),
#                      outputs = gr.components.JSON(label="youtube"),title = "youtube", examples=examples)
# gr.Parallel runs both interfaces on the same input and shows both outputs.
demo = gr.Parallel( ref_art_1, ref_art)

# Blocks until the server is stopped.
demo.launch()