z00mP commited on
Commit
58cde81
·
1 Parent(s): d953944

change complition model interface

Browse files
Files changed (2) hide show
  1. app.py +14 -1
  2. backend/query_llm.py +5 -6
app.py CHANGED
@@ -39,6 +39,14 @@ def bot(history, chunk_table, embedding_model, llm_model, cross_encoder, top_k_p
39
  top_k_param = int(top_k_param)
40
  query = history[-1][0]
41
 
 
 
 
 
 
 
 
 
42
  if not query:
43
  raise gr.Warning("Please submit a non-empty string as a prompt")
44
 
@@ -48,9 +56,13 @@ def bot(history, chunk_table, embedding_model, llm_model, cross_encoder, top_k_p
48
 
49
  #documents = retrieve(query, TOP_K)
50
  documents = retrieve(query, top_k_param, chunk_table, embedding_model)
 
 
51
  if cross_encoder != "None" and len(documents) > 1:
52
  documents = rerank_documents(cross_encoder, documents, query, top_k_rerank=rerank_topk)
53
  #"cross-encoder/ms-marco-MiniLM-L-6-v2"
 
 
54
 
55
 
56
 
@@ -79,7 +91,8 @@ def bot(history, chunk_table, embedding_model, llm_model, cross_encoder, top_k_p
79
  # generate_fn = generate_openai
80
  #else:
81
  # raise gr.Error(f"API {api_kind} is not supported")
82
-
 
83
  history[-1][1] = ""
84
  for character in generate_fn(prompt, history[:-1], llm_model):
85
  history[-1][1] = character
 
39
  top_k_param = int(top_k_param)
40
  query = history[-1][0]
41
 
42
+ logger.info("bot launched ...")
43
+ logger.info(f"embedding model: {embedding_model}")
44
+ logger.info(f"LLM model: {llm_model}")
45
+ logger.info(f"Cross encoder model: {cross_encoder}")
46
+ logger.info(f"TopK: {top_k_param}")
47
+ logger.info(f"ReRank TopK: {rerank_topk}")
48
+
49
+
50
  if not query:
51
  raise gr.Warning("Please submit a non-empty string as a prompt")
52
 
 
56
 
57
  #documents = retrieve(query, TOP_K)
58
  documents = retrieve(query, top_k_param, chunk_table, embedding_model)
59
+ logger.info('Retrived document count:', len(documents))
60
+
61
  if cross_encoder != "None" and len(documents) > 1:
62
  documents = rerank_documents(cross_encoder, documents, query, top_k_rerank=rerank_topk)
63
  #"cross-encoder/ms-marco-MiniLM-L-6-v2"
64
+ logger.info('ReRank done, document count:', len(documents))
65
+
66
 
67
 
68
 
 
91
  # generate_fn = generate_openai
92
  #else:
93
  # raise gr.Error(f"API {api_kind} is not supported")
94
+
95
+ logger.info(f'Complition started. llm_model: {llm_model}, prompt: {prompt}')
96
  history[-1][1] = ""
97
  for character in generate_fn(prompt, history[:-1], llm_model):
98
  history[-1][1] = character
backend/query_llm.py CHANGED
@@ -10,12 +10,12 @@ from transformers import AutoTokenizer
10
 
11
  OPENAI_KEY = os.getenv("OPENAI_API_KEY")
12
  HF_TOKEN = os.getenv("HF_TOKEN")
13
- TOKENIZER = AutoTokenizer.from_pretrained(os.getenv("HF_MODEL"))
14
 
15
- HF_CLIENT = InferenceClient(
16
- os.getenv("HF_MODEL"),
17
- token=HF_TOKEN
18
- )
19
  OAI_CLIENT = openai.Client(api_key=OPENAI_KEY)
20
 
21
  HF_GENERATE_KWARGS = {
@@ -115,7 +115,6 @@ def generate_openai(prompt: str, history: str, model_name: str) -> Generator[str
115
 
116
  try:
117
  stream = OAI_CLIENT.chat.completions.create(
118
- #model=os.getenv("OPENAI_MODEL"),
119
  model = model_name,
120
  messages=formatted_prompt,
121
  **OAI_GENERATE_KWARGS,
 
10
 
11
  OPENAI_KEY = os.getenv("OPENAI_API_KEY")
12
  HF_TOKEN = os.getenv("HF_TOKEN")
13
+ #TOKENIZER = AutoTokenizer.from_pretrained(os.getenv("HF_MODEL"))
14
 
15
+ #HF_CLIENT = InferenceClient(
16
+ # os.getenv("HF_MODEL"),
17
+ # token=HF_TOKEN
18
+ #)
19
  OAI_CLIENT = openai.Client(api_key=OPENAI_KEY)
20
 
21
  HF_GENERATE_KWARGS = {
 
115
 
116
  try:
117
  stream = OAI_CLIENT.chat.completions.create(
 
118
  model = model_name,
119
  messages=formatted_prompt,
120
  **OAI_GENERATE_KWARGS,