File size: 2,557 Bytes
b3fc90b
 
82fdbb3
 
b579ccf
 
b3fc90b
06c3202
b3fc90b
06c3202
 
 
 
 
 
 
203a4b1
f282311
 
23e9e9c
b3fc90b
1fac75f
b3fc90b
 
 
 
 
 
 
0cd2aee
 
 
8f62329
b579ccf
 
 
 
5794520
b579ccf
 
8f62329
b579ccf
b3fc90b
 
 
 
315794b
 
06c3202
 
b3fc90b
06c3202
c605c03
b3fc90b
 
 
0cd2aee
c99c9ac
b3fc90b
 
 
 
b579ccf
7929d48
 
 
 
 
16ada75
 
 
7929d48
16ada75
6ca39c0
6afbc36
 
b3fc90b
a91dc62
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import json

from services.qa_service.utils import format_prompt



class QAService:
    """Answer a question from transcript snippets retrieved out of a vector index.

    The question is embedded and the index namespace for this session is queried.
    The returned snippets are formatted into a prompt with ``format_prompt``, and
    that prompt is passed to ``model_pipeline.infer``. Instances can be used as a
    context manager; entering and exiting only print status messages.
    """

    def __init__(self, conf, pinecone, model_pipeline, question, goals, session_key, keycheck):
        """Open the index connection for the chosen namespace.

        Args:
            conf: Configuration mapping. ``conf["embeddings"]`` is read for
                ``demo_namespace``, ``index_name`` and ``examples`` (top-k).
            pinecone: Object whose ``run(namespace=...)`` returns a mapping with
                ``'connection'`` (exposing ``Index(name)``) and ``'embedder'``
                (exposing ``get_text_embedding(text)``).
            model_pipeline: Object exposing ``infer(prompt)``.
            question: Question text to embed and answer.
            goals: Stored on the instance; not used by this class.
            session_key: Namespace used when ``keycheck`` is truthy.
            keycheck: When falsy, the configured demo namespace is used instead
                of ``session_key``.
        """
        self.conf = conf
        # Fall back to the configured demo namespace when keycheck is falsy.
        self.sess_key = session_key if keycheck else conf["embeddings"]["demo_namespace"]
        self.pinecones = pinecone.run(namespace=self.sess_key)

        self.pc = self.pinecones['connection']
        self.embedder = self.pinecones['embedder']
        self.model_pipeline = model_pipeline
        self.question = question
        self.goals = goals

    def __enter__(self):
        print("Start Q&A Service")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so any exception raised inside the ``with`` body propagates.
        print("Exiting Q&A Service")

    def parse_results(self, result):
        """Extract speakers and text from index query matches.

        Each match's ``metadata['_node_content']`` is expected to be a JSON string
        containing ``"text"`` and ``"metadata"["speakers"]``.

        Args:
            result: Query response with a ``'matches'`` list.

        Returns:
            list[dict]: One ``{"speakers": ..., "text": ...}`` dict per match, in
            match order.

        Raises:
            KeyError: If an expected key is missing.
            json.JSONDecodeError: If ``_node_content`` is not valid JSON.
        """
        nodes = (json.loads(match['metadata']['_node_content']) for match in result['matches'])
        return [
            {"speakers": node["metadata"]["speakers"], "text": node["text"]}
            for node in nodes
        ]

    def retrieve_context(self):
        """Embed the question and query the index within this session's namespace.

        Returns:
            list[dict]: Parsed matches, as produced by ``parse_results``.
        """
        embeddings_conf = self.conf['embeddings']
        embedded_query = self.embedder.get_text_embedding(self.question)
        print(f"session key: {self.sess_key}")
        print(f"index name: {embeddings_conf['index_name']}")
        index = self.pc.Index(embeddings_conf['index_name'])
        result = index.query(
            vector=embedded_query,
            # Restrict the search to the namespace chosen in __init__
            # (the session key, or the demo namespace).
            namespace=self.sess_key,
            top_k=embeddings_conf["examples"],
            include_values=False,
            include_metadata=True,
        )
        return self.parse_results(result)

    @staticmethod
    def _build_context(snippets):
        """Join snippets into numbered "Transcript snippet N" sections split by blank lines."""
        # The text is concatenated rather than interpolated, so a non-str "text"
        # still raises instead of being silently stringified into the prompt.
        return "\n\n".join(
            f"Transcript snippet {number}\n" + snippet["text"]
            for number, snippet in enumerate(snippets, start=1)
        )

    def run(self):
        """Retrieve context, build the prompt and run inference.

        Returns:
            tuple: ``(output, context)``. ``output`` is the return value of
            ``model_pipeline.infer``. ``context`` is the formatted snippet string
            that was inserted into the prompt (``""`` when nothing matched).
        """
        context = self._build_context(self.retrieve_context())
        prompt = format_prompt(self.question, context)
        output = self.model_pipeline.infer(prompt)
        return output, context