File size: 1,139 Bytes
5fe04ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import pickle

import faiss
from langchain import OpenAI
from langchain.chains import VectorDBQAWithSourcesChain

from zeno import ZenoOptions, distill, metric, model


@model
def get_model(model_name):
    """Return a Zeno prediction function backed by a LangChain QA chain.

    Adapted from the Blendle Notion chatbot example:
    https://github.com/hwchase17/chat-langchain-notion

    The FAISS index (``./docs.index``) and the pickled vector store
    (``./faiss_store.pkl``) are read relative to the current working
    directory every time this function is called. ``model_name`` is accepted
    but not used by this implementation.

    Returns:
        ``pred(df, ops)``, which sends every question in
        ``df[ops.data_column]`` through the chain and returns a list holding
        one two-line string (answer, then sources) per question.
    """
    faiss_index = faiss.read_index("./docs.index")
    # NOTE(review): pickle.load can execute arbitrary code stored in the file;
    # only point this at a faiss_store.pkl you produced yourself.
    with open("./faiss_store.pkl", "rb") as store_file:
        vector_store = pickle.load(store_file)
    # The index is loaded from its own file and attached to the unpickled store.
    vector_store.index = faiss_index
    qa_chain = VectorDBQAWithSourcesChain.from_llm(
        llm=OpenAI(temperature=0),
        vectorstore=vector_store,
    )

    def pred(df, ops: ZenoOptions):
        """Answer each question in ``df[ops.data_column]`` with the QA chain."""

        def ask(question):
            response = qa_chain({"question": question})
            return "Answer: {}\nSources: {}".format(
                response["answer"], response["sources"]
            )

        return [ask(question) for question in df[ops.data_column]]

    return pred


@distill
def correct(df, ops: ZenoOptions):
    """Flag rows whose model output contains the expected label.

    A row counts as correct when the label text appears anywhere in the
    output text, compared case-insensitively.

    Args:
        df: Frame containing at least ``ops.label_column`` and
            ``ops.output_column``.
        ops: Zeno options naming the label and output columns.

    Returns:
        A boolean Series aligned with ``df.index``. Rows whose label or
        output is missing (None/NaN) are marked ``False`` instead of raising.
    """
    import pandas as pd

    def _normalize(value):
        # casefold() is the Unicode-aware form of lower() intended for caseless
        # matching; str() lets non-string cells (e.g. numeric labels) compare.
        return str(value).casefold()

    flags = [
        not (pd.isna(label) or pd.isna(output))
        and _normalize(label) in _normalize(output)
        for label, output in zip(df[ops.label_column], df[ops.output_column])
    ]
    # Build the Series explicitly: df.apply(..., axis=1) can hand back a
    # DataFrame rather than a Series when df has no rows.
    return pd.Series(flags, index=df.index, dtype=bool)


@metric
def accuracy(df, ops: ZenoOptions):
    """Return the fraction of rows flagged by the ``correct`` distill column.

    The flags are cast to integers (True -> 1, False -> 0) and averaged, so
    the result lies in [0, 1]; pandas yields NaN when ``df`` is empty.
    """
    correct_column = ops.distill_columns["correct"]
    hits = df[correct_column].astype(int)
    return hits.mean()