pragneshbarik committed
Commit a9fbba5
1 parent: 1ae5964

fixed errors

components/generate_chat_stream.py CHANGED
@@ -3,7 +3,7 @@ from middlewares.utils import gen_augmented_prompt_via_websearch
 from middlewares.chat_client import chat
 
 
-def generate_chat_stream(session_state, prompt, config):
+def generate_chat_stream(session_state, query, config):
     # 1. augments prompt according to the template
     # 2. returns chat_stream and source links
     # 3. chat_stream and source links are used by stream_handler and show_source
@@ -11,8 +11,8 @@ def generate_chat_stream(session_state, prompt, config):
     links = []
     if session_state.rag_enabled:
         with st.spinner("Fetching relevent documents from Web...."):
-            prompt, links = gen_augmented_prompt_via_websearch(
-                prompt=prompt,
+            query, links = gen_augmented_prompt_via_websearch(
+                prompt=query,
                 pre_context=session_state.pre_context,
                 post_context=session_state.post_context,
                 pre_prompt=session_state.pre_prompt,
@@ -25,12 +25,6 @@ def generate_chat_stream(session_state, prompt, config):
         )
 
     with st.spinner("Generating response..."):
-        chat_stream = chat(
-            prompt,
-            session_state.history,
-            chat_client=chat_bot_dict[session_state.chat_bot],
-            temperature=session_state.temp,
-            max_new_tokens=session_state.max_tokens,
-        )
+        chat_stream = chat(session_state, query, config)
 
     return chat_stream, links
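
The refactor collapses the chat() call site: model choice and sampling settings no longer travel as keyword arguments but are read from session_state inside chat() itself. A minimal sketch of the new call path, assuming a Streamlit app that loads config.yaml with PyYAML (the loader and field values here are illustrative, not part of this commit):

import streamlit as st
import yaml

from components.generate_chat_stream import generate_chat_stream

with open("config.yaml") as f:
    config = yaml.safe_load(f)

prompt = st.chat_input("Ask something")
if prompt:
    # chat_bot, temp, max_tokens, rag_enabled, ... are assumed to be
    # populated on st.session_state elsewhere in the app.
    chat_stream, links = generate_chat_stream(st.session_state, prompt, config)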
config.yaml CHANGED
@@ -17,4 +17,6 @@ CHAT_BOTS:
   Mistral 7B v0.1: mistralai/Mistral-7B-Instruct-v0.1
   Mistral 7B v0.2: mistralai/Mistral-7B-Instruct-v0.2
 
+CROSS_ENCODERS:
+
 COST_PER_1000_TOKENS_USD: 0.001737375
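
CROSS_ENCODERS is added as an empty key, presumably a placeholder for reranker model names. With PyYAML a key with no value parses to None, so existing readers of CHAT_BOTS and COST_PER_1000_TOKENS_USD are unaffected. A quick check (loading code assumed, not part of this commit):

import yaml

with open("config.yaml") as f:
    config = yaml.safe_load(f)

print(config["CHAT_BOTS"]["Mistral 7B v0.2"])  # mistralai/Mistral-7B-Instruct-v0.2
print(config["CROSS_ENCODERS"])                # None (placeholder, no entries yet)
print(config["COST_PER_1000_TOKENS_USD"])      # 0.001737375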
middlewares/chat_client.py CHANGED
@@ -9,7 +9,7 @@ API_TOKEN = os.getenv("HF_TOKEN")
 
 
 
-def format_prompt(session_state,query, history, chat_client):
+def format_prompt(session_state ,query, history, chat_client):
     if chat_client=="NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO" :
         model_input = f"""<|im_start|>system
 {session_state.system_instruction}
@@ -37,22 +37,23 @@ def format_prompt(session_state,query, history, chat_client):
     return model_input
 
 
-def chat(
-    prompt,
-    history,
-    chat_client="mistralai/Mistral-7B-Instruct-v0.1",
-    temperature=0.9,
-    max_new_tokens=256,
-    top_p=0.95,
-    repetition_penalty=1.0,
-    truncate = False
-):
+def chat(session_state, query, config):
 
+
+
+    chat_bot_dict = config["CHAT_BOTS"]
+    chat_client = chat_bot_dict[session_state.chat_bot]
+    temperature = session_state.temp
+    max_new_tokens = session_state.max_tokens
+    repetition_penalty = session_state.repetition_penalty
+    history = session_state.history
+
+
     client = InferenceClient(chat_client, token=API_TOKEN)
     temperature = float(temperature)
     if temperature < 1e-2:
         temperature = 1e-2
-    top_p = float(top_p)
+    top_p = float(0.95)
 
     generate_kwargs = dict(
         temperature=temperature,
@@ -63,7 +64,7 @@ def chat(
         seed=42,
     )
 
-    formatted_prompt = format_prompt(prompt, history)
+    formatted_prompt = format_prompt(session_state, query, history, chat_client)
 
     stream = client.text_generation(
         formatted_prompt,
@@ -71,6 +72,7 @@ def chat(
         stream=True,
         details=True,
         return_full_text=False,
+        truncate = 32000
     )
 
     return stream
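
chat() now derives every generation setting from session_state and config: top_p is pinned to 0.95, and truncate=32000 is passed to InferenceClient.text_generation so over-long prompts are cut server-side. A minimal sketch of calling it outside Streamlit, where SimpleNamespace stands in for st.session_state and the attribute values are illustrative:

from types import SimpleNamespace

from middlewares.chat_client import chat

config = {"CHAT_BOTS": {"Mistral 7B v0.2": "mistralai/Mistral-7B-Instruct-v0.2"}}
state = SimpleNamespace(
    chat_bot="Mistral 7B v0.2",  # key into config["CHAT_BOTS"]
    temp=0.7,                    # sampling temperature (floored at 1e-2)
    max_tokens=512,              # forwarded as max_new_tokens
    repetition_penalty=1.0,
    history=[],                  # prior turns consumed by format_prompt
    system_instruction="You are a helpful assistant.",
)

# details=True makes the stream yield TextGenerationStreamOutput objects.
for chunk in chat(state, "What is retrieval-augmented generation?", config):
    print(chunk.token.text, end="")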