FomoFix / llm_response_generator.py
ASledziewska
add updated files
1830956
raw
history blame
5.19 kB
import os
from langchain_community.llms import HuggingFaceHub
from langchain_community.llms import OpenAI
# from langchain.llms import HuggingFaceHub, OpenAI
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import warnings
warnings.filterwarnings("ignore")
class LLLResponseGenerator():
    """Generate LLM replies through LangChain using an OpenAI or HuggingFace Hub model."""

    def __init__(self):
        print("initialized")

    @staticmethod
    def _extract_response(raw_output: str, marker: str = "Response;") -> str:
        """Return the text following the first ``marker`` in ``raw_output``.

        When the marker is absent the whole output is returned (stripped).
        The previous ``str.find``-based slicing treated the ``-1`` "not found"
        result as a real index and silently dropped the first characters of
        the reply.

        Args:
            raw_output: Text returned by the LLM chain.
            marker: Delimiter that separates the prompt part from the reply.

        Returns:
            The stripped reply text.
        """
        _, found, reply = raw_output.partition(marker)
        if not found:
            return raw_output.strip()
        return reply.strip()

    def llm_inference(
        self,
        model_type: str,
        question: str,
        prompt_template: str,
        context: str,
        ai_tone: str,
        questionnaire: str,
        user_text: str,
        openai_model_name: str = "",
        hf_repo_id: str = "tiiuae/falcon-7b-instruct",
        temperature: float = 0.1,
        max_length: int = 128,
    ) -> str:
        """Call HuggingFace/OpenAI model for inference

        Given a question, prompt_template, and other parameters, this function calls the relevant
        API to fetch LLM inference results.

        Args:
            model_type: Denotes the LLM vendor's name. Can be either 'huggingface' or 'openai'
            question: The question to be asked to the LLM.
            prompt_template: The prompt template itself.
            context: Instructions for the LLM.
            ai_tone: Can be either empathy, encouragement or suggest medical help.
            questionnaire: Can be either depression, anxiety or adhd.
            user_text: Response given by the user.
            openai_model_name: Model name passed to the OpenAI wrapper; only used when
                model_type is 'openai'.
            hf_repo_id: The Huggingface model's repo_id; only used when model_type is
                'huggingface'.
            temperature: (Default: 0.1). Range: Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling, 0 means always take the highest score, 100.0 is getting closer to uniform probability.
            max_length: Integer to define the maximum length in tokens of the output. Passed
                as ``max_tokens`` to OpenAI and as the ``max_length`` model kwarg to HuggingFace.

        Returns:
            A Python string which contains the inference result. For 'huggingface' only the
            text after the "Response;" marker is returned (or the full output when the marker
            is missing). For an unsupported model_type a message is printed and None is returned.

        HuggingFace repo_id examples:
            - google/flan-t5-xxl
            - tiiuae/falcon-7b-instruct
        """
        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=[
                "context",
                "ai_tone",
                "questionnaire",
                "question",
                "user_text",
            ],
        )

        if model_type == "openai":
            # https://api.python.langchain.com/en/stable/llms/langchain.llms.openai.OpenAI.html#langchain.llms.openai.OpenAI
            llm = OpenAI(
                model_name=openai_model_name, temperature=temperature, max_tokens=max_length
            )
        elif model_type == "huggingface":
            # https://python.langchain.com/docs/integrations/llms/huggingface_hub
            llm = HuggingFaceHub(
                repo_id=hf_repo_id,
                model_kwargs={"temperature": temperature, "max_length": max_length},
            )
        else:
            # Kept as print-and-return-None so existing callers that pass a bad
            # model_type keep working instead of crashing.
            print(
                "Please use the correct value of model_type parameter: It can have a value of either openai or huggingface"
            )
            return None

        # Both backends are driven through the same chain call.
        response = LLMChain(prompt=prompt, llm=llm).run(
            context=context,
            ai_tone=ai_tone,
            questionnaire=questionnaire,
            question=question,
            user_text=user_text,
        )
        if model_type == "openai":
            return response

        print(response)
        # Extracting only the response part from the output
        return self._extract_response(response)
if __name__ == "__main__":
    # Demo run: ask the HuggingFace backend for an empathetic reply to one ADHD
    # questionnaire answer and print it.
    #
    # The 'HUGGINGFACEHUB_API_TOKEN' environment variable must be set (and
    # 'OPENAI_API_KEY' when model_type is "openai"). Nothing in this script loads a
    # .env file, so export the variables before running it.
    HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
    if not HUGGINGFACEHUB_API_TOKEN:
        # Fail fast with a clear message instead of an opaque error from the LLM client.
        raise SystemExit(
            "HUGGINGFACEHUB_API_TOKEN is not set; export it before running this script."
        )

    context = "You are a mental health supporting non-medical assistant. DO NOT PROVIDE any medical advice with conviction."

    ai_tone = "EMPATHY"
    questionnaire = "ADHD"
    question = (
        "How often do you find yourself having trouble focusing on tasks or activities?"
    )
    user_text = "I feel distracted all the time, and I am never able to finish"

    # The user may have signs of {questionnaire}.
    # The template ends with "Response;", the marker llm_inference uses to cut the
    # reply out of the HuggingFace output.
    template = """INSTRUCTIONS: {context}
Respond to the user with a tone of {ai_tone}.
Question asked to the user: {question}
Response by the user: {user_text}
Provide some advice and ask a relevant question back to the user.
Response;
"""

    temperature = 0.1
    max_length = 128

    model = LLLResponseGenerator()

    llm_response = model.llm_inference(
        model_type="huggingface",
        question=question,
        prompt_template=template,
        context=context,
        ai_tone=ai_tone,
        questionnaire=questionnaire,
        user_text=user_text,
        temperature=temperature,
        max_length=max_length,
    )

    print(llm_response)