File size: 5,100 Bytes
30ffb9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87dd32d
30ffb9e
 
 
 
 
 
 
 
 
87dd32d
 
30ffb9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
from openai import OpenAI
from typing import List, Any, Tuple
from tqdm import tqdm
import streamlit as st
from concurrent.futures import ThreadPoolExecutor, as_completed

from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv('env'), override=True)

# Resolve the OpenAI API key once at import time: prefer the Streamlit
# secrets store, and fall back to the OPENAI_API_KEY environment variable
# (which load_dotenv above may have populated).
try:
    api_key = st.secrets['secrets']['OPENAI_API_KEY']
except Exception:
    # Any failure reading Streamlit secrets (missing secrets file, missing
    # key, ...) triggers the fallback. `Exception` rather than a bare
    # `except` so KeyboardInterrupt / SystemExit are not swallowed.
    # A missing environment variable still raises KeyError, as before.
    api_key = os.environ['OPENAI_API_KEY']
class GPT_Turbo:
    """Convenience wrapper around the OpenAI chat-completions client.

    Provides single-call completions, thread-pooled batch completions that
    are appended to a file, and generation of question/context pairs.
    """

    def __init__(self, model: str="gpt-3.5-turbo-0613", api_key: str=api_key):
        """Store the model name and build the OpenAI client.

        Args:
            model: Name of the chat model passed to every completion request.
            api_key: OpenAI API key; defaults to the module-level ``api_key``.
        """
        self.model = model
        self.client = OpenAI(api_key=api_key)

    @staticmethod
    def _max_workers() -> int:
        """Return the thread-pool size: twice the CPU count, at least 2."""
        # os.cpu_count() may return None when the count is undeterminable.
        return 2 * (os.cpu_count() or 1)

    def get_chat_completion(self, 
                            prompt: str, 
                            system_message: str='You are a helpful assistant.',
                            user_message: str=None,
                            temperature: float=0, 
                            max_tokens: int=500,
                            stream: bool=False,
                            show_response: bool=False
                            ) -> str:
        """Send one chat-completion request and return the reply.

        Args:
            prompt: Main prompt text.
            system_message: Content of the leading system message.
            user_message: Optional extra message appended with the 'user' role.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Completion token limit forwarded to the API.
            stream: Forwarded to the API to request a streamed response.
            show_response: When True, return the raw response object instead
                of only the message text.

        Returns:
            The text of the first choice, or the raw response object when
            ``show_response`` or ``stream`` is True.
        """
        messages = [
            {'role': 'system', 'content': system_message},
            # NOTE(review): the prompt is sent under the 'assistant' role, not
            # 'user'. Kept as-is because changing it alters model behaviour;
            # confirm this is intentional.
            {'role': 'assistant', 'content': prompt},
        ]
        if user_message is not None:
            messages.append({'role': 'user', 'content': user_message})

        response = self.client.chat.completions.create(model=self.model,
                                                       messages=messages,
                                                       temperature=temperature,
                                                       max_tokens=max_tokens,
                                                       stream=stream)
        # A streamed response is consumed chunk by chunk and has no
        # `.choices[0].message`, so hand it back untouched for the caller
        # to iterate.
        if show_response or stream:
            return response
        return response.choices[0].message.content

    def multi_thread_request(self, 
                             filepath: str,
                             prompt: str,
                             content: List[str],
                             temperature: float=0
                             ) -> List[Any]:
        """Run one completion per content item in a thread pool.

        Each request's prompt is ``prompt`` followed by the item wrapped in
        triple backticks. Non-empty results are appended to ``filepath``
        (one per line, in completion order) and collected in a list.

        Args:
            filepath: File opened in append mode to receive the results.
            prompt: Instruction text prefixed to every content item.
            content: Items to process.
            temperature: Sampling temperature forwarded to each request.

        Returns:
            The non-empty results, in the order the requests completed.

        Raises:
            Any exception raised by an individual request is re-raised when
            its future's result is collected.
        """
        data = []
        total = len(content)
        with ThreadPoolExecutor(max_workers=self._max_workers()) as executor:
            # Keyword arguments keep the call aligned with the
            # get_chat_completion signature.
            futures = [
                executor.submit(self.get_chat_completion,
                                prompt=f'{prompt} ```{c}```',
                                temperature=temperature,
                                max_tokens=500,
                                show_response=False)
                for c in content
            ]
            # Explicit UTF-8: model output may contain non-ASCII characters.
            with open(filepath, 'a', encoding='utf-8') as f:
                for completed, future in enumerate(as_completed(futures), start=1):
                    result = future.result()
                    if result:
                        data.append(result)
                        self.write_to_file(file_handle=f, data=result)
                    # Report progress on finished requests, every tenth one.
                    if completed % 10 == 0:
                        print(f'{completed} of {total} completed.')
        return data

    def generate_question_context_pairs(self, 
                                        context_tuple: Tuple[str, str], 
                                        num_questions_per_chunk: int=2, 
                                        max_words_per_question: int=10
                                        ) -> Tuple[str, str]:
        """Generate questions answerable from a single context chunk.

        Args:
            context_tuple: ``(doc_id, context)`` pair.
            num_questions_per_chunk: Number of questions requested in the prompt.
            max_words_per_question: Word limit per question stated in the prompt.

        Returns:
            ``(doc_id, questions)`` where ``questions`` is the model's reply text.
        """
        doc_id, context = context_tuple
        prompt = f'Context information is included below enclosed in triple backticks. Given the context information and not prior knowledge, generate questions based on the below query.\n\nYou are an end user querying for information about your favorite podcast. \
                   Your task is to setup {num_questions_per_chunk} questions that can be answered using only the given context. The questions should be diverse in nature across the document and be no longer than {max_words_per_question} words. \
                   Restrict the questions to the context information provided.\n\
                   ```{context}```\n\n'

        # get_chat_completion already extracts the first choice's message text.
        questions = self.get_chat_completion(prompt=prompt, temperature=0, max_tokens=500)
        return (doc_id, questions)

    def batch_generate_question_context_pairs(self,
                                              context_tuple_list: List[Tuple[str, str]],
                                              num_questions_per_chunk: int=2,
                                              max_words_per_question: int=10
                                              ) -> List[Tuple[str, str]]:
        """Generate question/context pairs for many chunks concurrently.

        Args:
            context_tuple_list: ``(doc_id, context)`` pairs to process.
            num_questions_per_chunk: Forwarded to generate_question_context_pairs.
            max_words_per_question: Forwarded to generate_question_context_pairs.

        Returns:
            ``(doc_id, questions)`` tuples in the order the requests completed.
        """
        data = []
        # Context manager closes the progress bar even if a request raises.
        with tqdm(unit="Generated Questions", total=len(context_tuple_list)) as progress:
            with ThreadPoolExecutor(max_workers=self._max_workers()) as executor:
                futures = [
                    executor.submit(self.generate_question_context_pairs,
                                    context_tuple,
                                    num_questions_per_chunk,
                                    max_words_per_question)
                    for context_tuple in context_tuple_list
                ]
                for future in as_completed(futures):
                    result = future.result()
                    if result:
                        data.append(result)
                    progress.update(1)
        return data

    def get_embedding(self):
        """Placeholder; not implemented yet (returns None)."""

    def write_to_file(self, file_handle, data: str) -> None:
        """Write ``data`` followed by a newline to an open file handle."""
        file_handle.write(data)
        file_handle.write('\n')