import os
import warnings

from dotenv import load_dotenv
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFaceHub, OpenAI

warnings.filterwarnings("ignore")


class LLLResponseGenerator:
    """Generates supportive, non-medical responses via a HuggingFace or OpenAI LLM."""

    def __init__(self):
        # Base system instruction; conversation history is appended to it over time.
        self.base_context = (
            "You are a mental health supporting non-medical assistant. "
            "DO NOT PROVIDE any medical advice with conviction."
        )
        self.conversation_history = []
        self.context = self.base_context

    def update_context(self, user_text):
        """Append the latest user message to the running conversation context."""
        self.conversation_history.append(user_text)
        # Keep the base instruction first so the assistant's constraints are never lost.
        self.context = "\n".join([self.base_context] + self.conversation_history)

    def llm_inference(
        self,
        model_type: str,
        question: str,
        prompt_template: str,
        ai_tone: str,
        questionnaire: str,
        user_text: str,
        openai_model_name: str = "",
        hf_repo_id: str = "tiiuae/falcon-7b-instruct",
        temperature: float = 0.1,
        max_length: int = 128,
    ) -> str:
        """Call a HuggingFace or OpenAI model for inference.

        Given a question, a prompt template, and other parameters, this function
        calls the relevant API to fetch LLM inference results.

        Args:
            model_type: The LLM vendor's name. Either 'huggingface' or 'openai'.
            question: The question to be asked of the LLM.
            prompt_template: The prompt template itself.
            ai_tone: Either empathy, encouragement, or suggest medical help.
            questionnaire: Either depression, anxiety, or adhd.
            user_text: The response given by the user.
            openai_model_name: The OpenAI model name (used when model_type is 'openai').
            hf_repo_id: The HuggingFace model's repo_id.
            temperature: (Default: 0.1). Range: float (0.0-100.0). The temperature
                of the sampling operation. 1 means regular sampling, 0 means always
                take the highest-scoring token, and values approaching 100.0 move
                toward a uniform distribution.
            max_length: The maximum length, in tokens, of the generated output.

        Returns:
            A Python string containing the inference result.

        HuggingFace repo_id examples:
            - google/flan-t5-xxl
            - tiiuae/falcon-7b-instruct
        """
        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=[
                "context",
                "ai_tone",
                "questionnaire",
                "question",
                "user_text",
            ],
        )

        if model_type == "openai":
            llm = OpenAI(
                model_name=openai_model_name,
                temperature=temperature,
                max_tokens=max_length,
            )
            llm_chain = LLMChain(prompt=prompt, llm=llm)
            return llm_chain.run(
                context=self.context,
                ai_tone=ai_tone,
                questionnaire=questionnaire,
                question=question,
                user_text=user_text,
            )
        elif model_type == "huggingface":
            llm = HuggingFaceHub(
                repo_id=hf_repo_id,
                model_kwargs={"temperature": temperature, "max_length": max_length},
            )
            llm_chain = LLMChain(prompt=prompt, llm=llm)
            response = llm_chain.run(
                context=self.context,
                ai_tone=ai_tone,
                questionnaire=questionnaire,
                question=question,
                user_text=user_text,
            )
            # Extract only the part after the "Response;" marker; if the model
            # did not echo the marker, return the full response unchanged.
            marker = "Response;"
            response_start_index = response.find(marker)
            if response_start_index == -1:
                return response.strip()
            return response[response_start_index + len(marker):].strip()
        else:
            raise ValueError(
                "Invalid model_type parameter: it must be either 'openai' or 'huggingface'."
            )


if __name__ == "__main__":
    # Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN'
    # and 'OPENAI_API_KEY' values.
    load_dotenv()
    # HuggingFaceHub reads this token from the environment.
    HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")

    ai_tone = "EMPATHY"
    questionnaire = "ADHD"
    question = (
        "How often do you find yourself having trouble focusing on tasks or activities?"
    )
    user_text = "I feel distracted all the time, and I am never able to finish"

    # The user may have signs of {questionnaire}.
    template = """INSTRUCTIONS: {context}

    Respond to the user with a tone of {ai_tone}.

    Question asked to the user: {question}

    Response by the user: {user_text}

    Provide some advice and ask a relevant question back to the user.

    Response;
    """
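    # --- Illustrative sketch: how the template above renders once its placeholders
    # are filled. The sample values are assumptions for demonstration only;
    # llm_inference() performs this substitution internally via PromptTemplate.
    # Note that {questionnaire} is accepted by llm_inference() but no longer
    # appears in the template text itself.
    #
    # rendered = template.format(
    #     context="You are a mental health supporting non-medical assistant...",
    #     ai_tone="EMPATHY",
    #     question="How often do you find yourself having trouble focusing?",
    #     user_text="I feel distracted all the time.",
    # )
    # print(rendered)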
    temperature = 0.1
    max_length = 128

    model = LLLResponseGenerator()

    # Initial prompt
    print(
        "Bot:",
        model.llm_inference(
            model_type="huggingface",
            question=question,
            prompt_template=template,
            ai_tone=ai_tone,
            questionnaire=questionnaire,
            user_text=user_text,
            temperature=temperature,
            max_length=max_length,
        ),
    )

    # Conversation loop: type 'exit' to quit. Each turn is appended to the
    # running context before the next inference call.
    while True:
        user_input = input("You: ")
        if user_input.lower() == "exit":
            break

        model.update_context(user_input)
        print(
            "Bot:",
            model.llm_inference(
                model_type="huggingface",
                question=question,
                prompt_template=template,
                ai_tone=ai_tone,
                questionnaire=questionnaire,
                user_text=user_input,
                temperature=temperature,
                max_length=max_length,
            ),
        )
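    # --- Alternative sketch: to route a turn through OpenAI instead of
    # HuggingFace, swap the call inside the loop for something like the block
    # below. The model name is an assumption; use any completion-style model
    # available on your OpenAI account.
    #
    # print(
    #     "Bot:",
    #     model.llm_inference(
    #         model_type="openai",
    #         question=question,
    #         prompt_template=template,
    #         ai_tone=ai_tone,
    #         questionnaire=questionnaire,
    #         user_text=user_input,
    #         openai_model_name="gpt-3.5-turbo-instruct",  # assumed model name
    #         temperature=temperature,
    #         max_length=max_length,
    #     ),
    # )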