bstraehle commited on
Commit
ce136c7
1 Parent(s): 1ce0835

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -3,7 +3,8 @@ import os, time
3
 
4
  from dotenv import load_dotenv, find_dotenv
5
 
6
- from rag_langchain import llm_chain, rag_chain, rag_ingestion
 
7
  from trace import trace_wandb
8
 
9
  _ = load_dotenv(find_dotenv())
@@ -31,25 +32,28 @@ def invoke(openai_api_key, prompt, rag_option):
31
  os.environ["OPENAI_API_KEY"] = openai_api_key
32
 
33
  if (RAG_INGESTION):
34
- rag_ingestion(config)
 
 
 
35
 
36
  chain = None
37
  completion = ""
38
  result = ""
39
- cb = ""
40
  err_msg = ""
41
 
42
  try:
43
  start_time_ms = round(time.time() * 1000)
44
 
45
  if (rag_option == RAG_LANGCHAIN):
46
- completion, chain, cb = rag_chain(config, prompt)
47
 
48
  result = completion["result"]
49
  elif (rag_option == RAG_LLAMAINDEX):
50
- print("TODO")
51
  else:
52
- completion, chain, cb = llm_chain(config, prompt)
53
 
54
  if (completion.generations[0] != None and completion.generations[0][0] != None):
55
  result = completion.generations[0][0].text
@@ -66,7 +70,7 @@ def invoke(openai_api_key, prompt, rag_option):
66
  # completion,
67
  # result,
68
  # chain,
69
- # cb,
70
  # err_msg,
71
  # start_time_ms,
72
  # end_time_ms)
 
3
 
4
  from dotenv import load_dotenv, find_dotenv
5
 
6
+ from rag_langchain import llm_chain, rag_chain, rag_ingestion_langchain
7
+ from rag_llamaindex import rag_ingestion_llamaindex, rag_retrieval
8
  from trace import trace_wandb
9
 
10
  _ = load_dotenv(find_dotenv())
 
32
  os.environ["OPENAI_API_KEY"] = openai_api_key
33
 
34
  if (RAG_INGESTION):
35
+ if (rag_option == RAG_LANGCHAIN):
36
+ rag_ingestion_llangchain(config)
37
+ elif (rag_option == RAG_LLAMAINDEX):
38
+ rag_ingestion_llamaindex(config)
39
 
40
  chain = None
41
  completion = ""
42
  result = ""
43
+ callback = ""
44
  err_msg = ""
45
 
46
  try:
47
  start_time_ms = round(time.time() * 1000)
48
 
49
  if (rag_option == RAG_LANGCHAIN):
50
+ completion, chain, callback = rag_chain(config, prompt)
51
 
52
  result = completion["result"]
53
  elif (rag_option == RAG_LLAMAINDEX):
54
+ result = rag_retrieval(config, prompt)
55
  else:
56
+ completion, chain, callback = llm_chain(config, prompt)
57
 
58
  if (completion.generations[0] != None and completion.generations[0][0] != None):
59
  result = completion.generations[0][0].text
 
70
  # completion,
71
  # result,
72
  # chain,
73
+ # callback,
74
  # err_msg,
75
  # start_time_ms,
76
  # end_time_ms)