bstraehle commited on
Commit
5059289
1 Parent(s): 2a2835e

Update trace.py

Browse files
Files changed (1) hide show
  1. trace.py +16 -27
trace.py CHANGED
@@ -5,40 +5,29 @@ from wandb.sdk.data_types.trace_tree import Trace
5
  WANDB_API_KEY = os.environ["WANDB_API_KEY"]
6
 
7
  def trace_wandb(config,
8
- is_rag_off,
9
- prompt,
10
- completion,
11
- result,
12
- chain,
13
- cb,
14
- err_msg,
15
- start_time_ms,
16
  end_time_ms):
17
  wandb.init(project = "openai-llm-rag")
18
 
19
  trace = Trace(
20
- kind = "chain",
21
- name = "" if (chain == None) else type(chain).__name__,
22
  status_code = "success" if (str(err_msg) == "") else "error",
23
  status_message = str(err_msg),
24
- inputs = {"is_rag": not is_rag_off,
25
- "prompt": prompt,
26
- "chain_prompt": (str(chain.prompt) if (is_rag_off) else
27
- str(chain.combine_documents_chain.llm_chain.prompt)),
28
- "source_documents": "" if (is_rag_off) else str([doc.metadata["source"] for doc in completion["source_documents"]]),
29
  } if (str(err_msg) == "") else {},
30
- outputs = {"result": result,
31
- "cb": str(cb),
32
- "completion": str(completion),
33
- } if (str(err_msg) == "") else {},
34
- model_dict = {"client": (str(chain.llm.client) if (is_rag_off) else
35
- str(chain.combine_documents_chain.llm_chain.llm.client)),
36
- "model_name": (str(chain.llm.model_name) if (is_rag_off) else
37
- str(chain.combine_documents_chain.llm_chain.llm.model_name)),
38
- "temperature": (str(chain.llm.temperature) if (is_rag_off) else
39
- str(chain.combine_documents_chain.llm_chain.llm.temperature)),
40
- "retriever": ("" if (is_rag_off) else str(chain.retriever)),
41
- } if (str(err_msg) == "") else {},
42
  start_time_ms = start_time_ms,
43
  end_time_ms = end_time_ms
44
  )
 
5
  WANDB_API_KEY = os.environ["WANDB_API_KEY"]
6
 
7
  def trace_wandb(config,
8
+ rag_option,
9
+ prompt,
10
+ completion,
11
+ result,
12
+ callback,
13
+ err_msg,
14
+ start_time_ms,
 
15
  end_time_ms):
16
  wandb.init(project = "openai-llm-rag")
17
 
18
  trace = Trace(
19
+ kind = "LLM & RAG",
20
+ name = "Context-Aware Reasoning Application",
21
  status_code = "success" if (str(err_msg) == "") else "error",
22
  status_message = str(err_msg),
23
+ inputs = {"prompt": str(prompt),
24
+ "rag_option": str(rag_option),
25
+ "config": str(config)
 
 
26
  } if (str(err_msg) == "") else {},
27
+ outputs = {"result": str(result),
28
+ "callback": str(callback),
29
+ "completion": str(completion)
30
+ } if (str(err_msg) == "") else {}
 
 
 
 
 
 
 
 
31
  start_time_ms = start_time_ms,
32
  end_time_ms = end_time_ms
33
  )