Commit 71c54ff · LVKinyanjui
1 Parent(s): 8790464

Implemented llm chat history, modified model inference module to try resolve import errors
Files changed:
- inference_main.py +8 -2
- modules/inference/instruct.py +20 -22
inference_main.py CHANGED

@@ -1,10 +1,16 @@
 import streamlit as st
-from modules.inference.instruct import infer
+from modules.inference.instruct import infer, load_model
 
 st.write("## Ask your Local LLM")
 text_input = st.text_input("Query", value="Why is the sky Blue")
 submit = st.button("Submit")
 
+@st.cache_resource
+def load_model_cached():
+    return load_model()
+
+model = load_model_cached()
+
 if submit:
-    response = infer(text_input)
+    response = infer(model, text_input)
     response
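A note on the caching pattern above: Streamlit reruns the whole script on every widget interaction, so without @st.cache_resource the model would be reloaded on every submit. The decorated function's body runs once per process, and later reruns get the same object back. Below is a minimal, self-contained sketch of that pattern with a deliberately cheap stand-in for the real model load (the actual load_model() lives in modules/inference/instruct.py and is only partially visible in this diff):

    import time

    import streamlit as st

    @st.cache_resource  # body runs once per process; reruns reuse the cached object
    def load_model_cached():
        time.sleep(2)  # stand-in for an expensive model load
        return {"name": "demo-llm"}  # placeholder for the real pipeline object

    model = load_model_cached()  # instant on every rerun after the first
    st.write("model ready:", model["name"])

Save the sketch as app.py and launch it with: streamlit run app.py. Moving the load behind this wrapper, instead of the old module-level pipeline = load_model() in instruct.py, also means importing infer no longer triggers a model load at import time, which is presumably the import error the commit message refers to.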
modules/inference/instruct.py CHANGED

@@ -41,46 +41,44 @@ def load_model():
     )
     return pipeline
 
-pipeline = load_model()
 
-message_store_path = "messages.jsonl"
-
-messages: list[dict] = [
-    {"role": "system", "content": SYSTEM_MESSAGE},
-]
-
-if os.path.exists(message_store_path):
-    with open(message_store_path, "r", encoding="utf-8") as f:
-        messages = [json.loads(line) for line in f]
-print(messages)
-
-def infer(message: str):
+def infer(model, message: str, n_output_tokens=256, message_store_path: str = "messages.jsonl"):
     """
     Params:
         message: Most recent query to the llm.
+        messages: Chat history up to current point properly formatted like
+            {"role": "user", "content": "What is your name?"}
     """
+    if os.path.exists(message_store_path):
+        with open(message_store_path, "r", encoding="utf-8") as f:
+            messages = [json.loads(line) for line in f]
+    else:
+        messages = [
+            {"role": "system", "content": SYSTEM_MESSAGE},
+        ]
     messages.append({"role": "user", "content": message})
+    print(messages)
 
     # Perfom inference
-    outputs = pipeline(
+    outputs = model(
         messages,
-        max_new_tokens=
-
-
-
+        max_new_tokens=n_output_tokens)
+    output: list = outputs[0]["generated_text"]
+
     # Save the newly updated messages object
-    with open(message_store_path, "
+    with open(message_store_path, "a", encoding="utf-8") as f:
         for line in output:
             json.dump(line, f)
             f.write("\n")
-
-    return
+
+    return output[-1]['content']
 
 if __name__ == "__main__":
+    model = load_model()
     while True:
         print("Press Ctrl + C to exit.")
        message = input("Ask a question.")
-        print(infer(message))
+        print(infer(model, message))
 
         print("---------------------------------------")
         print("\n\n")
|