LVKinyanjui committed
Commit 71c54ff · 1 Parent(s): 8790464

Implemented LLM chat history; modified model inference module to try to resolve import errors

Files changed (2):
  1. inference_main.py +8 -2
  2. modules/inference/instruct.py +20 -22
inference_main.py CHANGED
@@ -1,10 +1,16 @@
 import streamlit as st
-from modules.inference.instruct import infer
+from modules.inference.instruct import infer, load_model
 
 st.write("## Ask your Local LLM")
 text_input = st.text_input("Query", value="Why is the sky Blue")
 submit = st.button("Submit")
 
+@st.cache_resource
+def load_model_cached():
+    return load_model()
+
+model = load_model_cached()
+
 if submit:
-    response = infer(text_input)
+    response = infer(model, text_input)
     response
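For context on the caching change above: Streamlit reruns the entire script on every widget interaction, so without @st.cache_resource the pipeline would be rebuilt on each Submit click. Below is a minimal, self-contained sketch of the same pattern; the dict stand-in and the time.sleep placeholder for the slow load_model() call are illustrative assumptions, not code from this repo.

# Minimal sketch of the @st.cache_resource pattern used in the new inference_main.py.
# st.cache_resource memoizes the return value once per server process, so the
# expensive resource is built on the first run and reused on every rerun.
import time
import streamlit as st

@st.cache_resource
def load_model_cached():
    time.sleep(5)                    # stand-in for the slow transformers pipeline load
    return {"name": "stub-model"}    # stand-in for the real pipeline object

model = load_model_cached()  # slow on the first run, instant on every rerun
st.write(model["name"])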
modules/inference/instruct.py CHANGED
@@ -41,46 +41,44 @@ def load_model():
     )
     return pipeline
 
-pipeline = load_model()
 
-message_store_path = "messages.jsonl"
-
-messages: list[dict] = [
-    {"role": "system", "content": SYSTEM_MESSAGE},
-]
-
-if os.path.exists(message_store_path):
-    with open(message_store_path, "r", encoding="utf-8") as f:
-        messages = [json.loads(line) for line in f]
-        print(messages)
-
-def infer(message: str):
+def infer(model, message: str, n_output_tokens=256, message_store_path: str = "messages.jsonl"):
     """
     Params:
         message: Most recent query to the llm.
+        messages: Chat history up to current point properly formatted like
+            {"role": "user", "content": "What is your name?"}
     """
+    if os.path.exists(message_store_path):
+        with open(message_store_path, "r", encoding="utf-8") as f:
+            messages = [json.loads(line) for line in f]
+    else:
+        messages = [
+            {"role": "system", "content": SYSTEM_MESSAGE},
+        ]
     messages.append({"role": "user", "content": message})
+    print(messages)
 
     # Perform inference
-    output = pipeline(
+    outputs = model(
         messages,
-        max_new_tokens=MAX_NEW_TOKENS)
-
-    output_text = output[-1]['generated_text'][-1]['content']
-
+        max_new_tokens=n_output_tokens)
+    output: list = outputs[0]["generated_text"]
+
     # Save the newly updated messages object
-    with open(message_store_path, "w", encoding="utf-8") as f:
+    with open(message_store_path, "a", encoding="utf-8") as f:
         for line in output:
             json.dump(line, f)
             f.write("\n")
-
-    return output_text
+
+    return output[-1]['content']
 
 if __name__ == "__main__":
+    model = load_model()
     while True:
         print("Press Ctrl + C to exit.")
        message = input("Ask a question.")
-        print(infer(message))
+        print(infer(model, message))
 
         print("---------------------------------------")
         print("\n\n")