ASledziewska committed on
Commit
1872b66
1 Parent(s): 5038c7a

Update llm_response_generator.py

Browse files
Files changed (1) hide show
  1. llm_response_generator.py +29 -18
llm_response_generator.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
  from langchain_community.llms import HuggingFaceHub
3
  from langchain_community.llms import OpenAI
4
- # from langchain.llms import HuggingFaceHub, OpenAI
5
  from langchain.chains import LLMChain
6
  from langchain.prompts import PromptTemplate
7
  import warnings
@@ -9,17 +8,19 @@ import warnings
9
  warnings.filterwarnings("ignore")
10
 
11
  class LLLResponseGenerator():
12
-
13
  def __init__(self):
14
- print("initialized")
 
15
 
 
 
 
16
 
17
  def llm_inference(
18
  self,
19
  model_type: str,
20
  question: str,
21
  prompt_template: str,
22
- context: str,
23
  ai_tone: str,
24
  questionnaire: str,
25
  user_text: str,
@@ -37,7 +38,6 @@ class LLLResponseGenerator():
37
  model_type: Denotes the LLM vendor's name. Can be either 'huggingface' or 'openai'
38
  question: The question to be asked to the LLM.
39
  prompt_template: The prompt template itself.
40
- context: Instructions for the LLM.
41
  ai_tone: Can be either empathy, encouragement or suggest medical help.
42
  questionnaire: Can be either depression, anxiety or adhd.
43
  user_text: Response given by the user.
@@ -65,13 +65,12 @@ class LLLResponseGenerator():
65
  )
66
 
67
  if model_type == "openai":
68
- # https://api.python.langchain.com/en/stable/llms/langchain.llms.openai.OpenAI.html#langchain.llms.openai.OpenAI
69
  llm = OpenAI(
70
  model_name=openai_model_name, temperature=temperature, max_tokens=max_length
71
  )
72
  llm_chain = LLMChain(prompt=prompt, llm=llm)
73
  return llm_chain.run(
74
- context=context,
75
  ai_tone=ai_tone,
76
  questionnaire=questionnaire,
77
  question=question,
@@ -79,15 +78,14 @@ class LLLResponseGenerator():
79
  )
80
 
81
  elif model_type == "huggingface":
82
- # https://python.langchain.com/docs/integrations/llms/huggingface_hub
83
  llm = HuggingFaceHub(
84
  repo_id=hf_repo_id,
85
  model_kwargs={"temperature": temperature, "max_length": max_length},
86
  )
87
 
88
  llm_chain = LLMChain(prompt=prompt, llm=llm)
89
- response = llm_chain.run(
90
- context=context,
91
  ai_tone=ai_tone,
92
  questionnaire=questionnaire,
93
  question=question,
@@ -108,8 +106,6 @@ if __name__ == "__main__":
108
  # Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN' and 'OPENAI_API_KEY' values.
109
  HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
110
 
111
- context = "You are a mental health supporting non-medical assistant. DO NOT PROVIDE any medical advice with conviction."
112
-
113
  ai_tone = "EMPATHY"
114
  questionnaire = "ADHD"
115
  question = (
@@ -136,17 +132,32 @@ if __name__ == "__main__":
136
 
137
  model = LLLResponseGenerator()
138
 
139
-
140
- llm_response = model.llm_inference(
141
  model_type="huggingface",
142
  question=question,
143
  prompt_template=template,
144
- context=context,
145
  ai_tone=ai_tone,
146
  questionnaire=questionnaire,
147
  user_text=user_text,
148
  temperature=temperature,
149
  max_length=max_length,
150
- )
151
-
152
- print(llm_response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  from langchain_community.llms import HuggingFaceHub
3
  from langchain_community.llms import OpenAI
 
4
  from langchain.chains import LLMChain
5
  from langchain.prompts import PromptTemplate
6
  import warnings
 
8
  warnings.filterwarnings("ignore")
9
 
10
  class LLLResponseGenerator():
 
11
  def __init__(self):
12
+ self.context = "You are a mental health supporting non-medical assistant. DO NOT PROVIDE any medical advice with conviction."
13
+ self.conversation_history = []
14
 
15
+ def update_context(self, user_text):
16
+ self.conversation_history.append(user_text)
17
+ self.context = "\n".join(self.conversation_history)
18
 
19
  def llm_inference(
20
  self,
21
  model_type: str,
22
  question: str,
23
  prompt_template: str,
 
24
  ai_tone: str,
25
  questionnaire: str,
26
  user_text: str,
 
38
  model_type: Denotes the LLM vendor's name. Can be either 'huggingface' or 'openai'
39
  question: The question to be asked to the LLM.
40
  prompt_template: The prompt template itself.
 
41
  ai_tone: Can be either empathy, encouragement or suggest medical help.
42
  questionnaire: Can be either depression, anxiety or adhd.
43
  user_text: Response given by the user.
 
65
  )
66
 
67
  if model_type == "openai":
 
68
  llm = OpenAI(
69
  model_name=openai_model_name, temperature=temperature, max_tokens=max_length
70
  )
71
  llm_chain = LLMChain(prompt=prompt, llm=llm)
72
  return llm_chain.run(
73
+ context=self.context,
74
  ai_tone=ai_tone,
75
  questionnaire=questionnaire,
76
  question=question,
 
78
  )
79
 
80
  elif model_type == "huggingface":
 
81
  llm = HuggingFaceHub(
82
  repo_id=hf_repo_id,
83
  model_kwargs={"temperature": temperature, "max_length": max_length},
84
  )
85
 
86
  llm_chain = LLMChain(prompt=prompt, llm=llm)
87
+ response = llm_chain.run(
88
+ context=self.context,
89
  ai_tone=ai_tone,
90
  questionnaire=questionnaire,
91
  question=question,
 
106
  # Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN' and 'OPENAI_API_KEY' values.
107
  HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
108
 
 
 
109
  ai_tone = "EMPATHY"
110
  questionnaire = "ADHD"
111
  question = (
 
132
 
133
  model = LLLResponseGenerator()
134
 
135
+ # Initial prompt
136
+ print("Bot:", model.llm_inference(
137
  model_type="huggingface",
138
  question=question,
139
  prompt_template=template,
 
140
  ai_tone=ai_tone,
141
  questionnaire=questionnaire,
142
  user_text=user_text,
143
  temperature=temperature,
144
  max_length=max_length,
145
+ ))
146
+
147
+ while True:
148
+ user_input = input("You: ")
149
+ if user_input.lower() == "exit":
150
+ break
151
+
152
+ model.update_context(user_input)
153
+
154
+ print("Bot:", model.llm_inference(
155
+ model_type="huggingface",
156
+ question=question,
157
+ prompt_template=template,
158
+ ai_tone=ai_tone,
159
+ questionnaire=questionnaire,
160
+ user_text=user_input,
161
+ temperature=temperature,
162
+ max_length=max_length,
163
+ ))