LVKinyanjui committed on
Commit
1ad978f
1 Parent(s): 728c92a

Abstracted away inference implementation and successfully tested the instruct template

Dockerfile CHANGED
@@ -19,4 +19,4 @@ COPY . .
  EXPOSE 8000

  # Run the application.
- CMD streamlit run app_inference.py --server.port 7860
+ CMD streamlit run inference_main.py --server.port 7860
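Note: the image still declares EXPOSE 8000 while Streamlit is started on port 7860 (the port Hugging Face Spaces conventionally serves). If 7860 is the intended port, the EXPOSE line would presumably be updated to match:

EXPOSE 7860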
inference_main.py ADDED
@@ -0,0 +1,10 @@
+ import streamlit as st
+ from modules.inference.llama3_1_8b_instruct import infer
+
+ st.write("## Ask your Local LLM")
+ text_input = st.text_input("Query", value="Why is the sky Blue")
+ submit = st.button("Submit")
+
+ if submit:
+     response = infer(text_input)
+     response
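Note that this new entry point still imports infer from the old module path, while the same commit renames that file to modules/inference/instruct.py (see the RENAMED section below). With the rename in effect, the import would presumably need to follow it:

from modules.inference.instruct import infer  # assumed path after the rename below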
modules/inference/{llama3_1_8b_instruct.py → instruct.py} RENAMED
@@ -1,45 +1,41 @@
- import streamlit as st
-
  import transformers, torch
  import json, os

  from huggingface_hub import login

  # CONSTANTS
- MAX_NEW_TOKENS = 256
+ MAX_NEW_TOKENS = 1024
  SYSTEM_MESSAGE = "You are a hepful, knowledgeable assistant"

- # ENV VARS
- # To avert Permision error with transformer and hf models
- os.environ['SENTENCE_TRANSFORMERS_HOME'] = '.'
- token = os.getenv("HF_TOKEN_WRITE") # Must be a write token
- # STREAMLIT UI AREA
+ # # ENV VARS
+ # # To avert Permision error with transformer and hf models
+ # os.environ['SENTENCE_TRANSFORMERS_HOME'] = '.'
+ # token = os.getenv("HF_TOKEN_WRITE") # Must be a write token

- st.write("## Ask your Local LLM")
- text_input = st.text_input("Query", value="Why is the sky Blue")
- submit = st.button("Submit")
+ # # Use the token to authenticate
+ # login(token=token,
+ #       write_permission=True # Must be set to True when we pass in our own token
+ #                             # Otherwise we get Permission Denied.
+ #       )

- # MODEL AREA
- # Use the token to authenticate
- login(token=token,
-       write_permission=True # Must be set to True when we pass in our own token
-                             # Otherwise we get Permission Denied.
-       )
- model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+ model_id = "microsoft/Phi-3.5-mini-instruct"
+ # model_id = "meta-llama/Llama-3.2-1B-Instruct"

- @st.cache_resource
  def load_model():
+     print(f"Loading {model_id}")
      pipeline = transformers.pipeline(
          "text-generation",
          model=model_id,
          model_kwargs={"torch_dtype": torch.bfloat16},
          device_map="auto",
      )
+     return pipeline

  pipeline = load_model()

  message_store_path = "messages.jsonl"
- messages = [
+
+ messages: list[dict] = [
      {"role": "system", "content": SYSTEM_MESSAGE},
  ]

@@ -48,13 +44,10 @@ if os.path.exists(message_store_path):
      messages = [json.loads(line) for line in f]
      print(messages)

- @st.cache_data
- def infer(message: str, messages: list[dict]):
+ def infer(message: str):
      """
      Params:
          message: Most recent query to the llm.
-         messages: Chat history up to current point properly formatted like
-             {"role": "user", "content": "What is your name?"}
      """
      messages.append({"role": "user", "content": message})

@@ -63,14 +56,23 @@ def infer(message: str, messages: list[dict]):
          messages,
          max_new_tokens=MAX_NEW_TOKENS)

+     output_text = output[-1]['generated_text'][-1]['content']
+
      # Save the newly updated messages object
      with open(message_store_path, "w", encoding="utf-8") as f:
          for line in output:
              json.dump(line, f)
              f.write("\n")

-     return output[-1]['generated_text'][-1]['content']
+     return output_text
+
+ if __name__ == "__main__":
+     while True:
+         print("Press Ctrl + C to exit.")
+         message = input("Ask a question.")
+         print(infer(message))
+
+         print("---------------------------------------")
+         print("\n\n")

+         print(messages)
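For context on the output_text extraction in infer(): when a transformers text-generation pipeline is called with a list of chat message dicts, it applies the model's chat template, and each result's generated_text holds the full conversation with the newly generated assistant turn appended last. A minimal sketch of the shape being unpacked (the payload below is illustrative, assumed from the transformers chat-pipeline convention, not taken from this commit):

output = pipeline(messages, max_new_tokens=MAX_NEW_TOKENS)
# Assumed shape of the result:
# [{"generated_text": [
#     {"role": "system", "content": SYSTEM_MESSAGE},
#     {"role": "user", "content": "Why is the sky Blue"},
#     {"role": "assistant", "content": "..."},   # new turn, appended last
# ]}]
output_text = output[-1]["generated_text"][-1]["content"]  # assistant reply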
 
requirements.txt CHANGED
@@ -1,7 +1,10 @@
+ flash_attn==2.5.8
+ torch==2.3.1
+ accelerate==0.31.0
+ transformers==4.43.0
  chromadb==0.5.5
  pymupdf==1.24.9
  streamlit==1.38.0
- transformers==4.44.2
  langchain==0.3.0
  langchain-core==0.3.5
  langchain-text-splitters==0.3.0
@@ -10,5 +13,4 @@ langchain-community==0.3.0
  python-dotenv==1.0.1
  tiktoken==0.7.0
  huggingface-hub==0.25.1
- torch==2.4.1
  langchain-ollama==0.2.0