from rag_fns.generation import do_generation | |
from rag_fns.retrieval import do_retrieval | |
from rag_fns.setup_load import import_data, load_oai_model | |
def do_rag(user_input: str, stream: bool = False, n_results: int = 3): | |
# Load the data | |
talk_ids, embeds, talk_info = import_data() | |
# Load the model | |
oai_client = load_oai_model() | |
retrieved_docs = do_retrieval( | |
query0=user_input, | |
n_results=n_results, | |
api_client=oai_client, | |
talk_ids=talk_ids, | |
embeds=embeds, | |
talk_info=talk_info, | |
) | |
response, prompt_tokens = do_generation( | |
query1=user_input, keep_texts=retrieved_docs, gen_client=oai_client, stream=stream | |
) | |
return response, retrieved_docs, prompt_tokens | |