# Hugging Face Spaces page header (Space status: Sleeping).
import os | |
from langchain_community.llms import HuggingFaceHub | |
from langchain_community.llms import OpenAI | |
# from langchain.llms import HuggingFaceHub, OpenAI | |
from langchain.chains import LLMChain | |
from langchain.prompts import PromptTemplate | |
import warnings | |
# Suppress every warning category for the whole process (the warnings filter
# list is global), not only warnings raised from this module.
warnings.simplefilter("ignore")
class LLLResponseGenerator:
    """Generate a conversational reply from an OpenAI or Hugging Face LLM.

    The prompt is rendered with a LangChain ``PromptTemplate`` and executed
    through an ``LLMChain``.
    """

    # Text that ends the prompt template; for Hugging Face models only the part
    # of the raw output following it is returned to the caller.
    RESPONSE_MARKER = "Response;"

    # Accepted values for the ``model_type`` argument of ``llm_inference``.
    SUPPORTED_MODEL_TYPES = ("openai", "huggingface")

    def __init__(self):
        print("initialized")

    @staticmethod
    def _extract_response(raw_output: str, marker: str = RESPONSE_MARKER) -> str:
        """Return the stripped text after the first ``marker`` in ``raw_output``.

        If ``marker`` does not occur (e.g. the wrapper already removed the
        echoed prompt), the whole output is returned stripped. Slicing from
        ``str.find``'s ``-1`` sentinel instead would silently drop the first
        ``len(marker) - 1`` characters of the reply.
        """
        _, found, reply = raw_output.partition(marker)
        return (reply if found else raw_output).strip()

    @staticmethod
    def _build_llm(
        model_type: str,
        openai_model_name: str,
        hf_repo_id: str,
        temperature: float,
        max_length: int,
    ):
        """Instantiate the LangChain LLM wrapper for an already-validated ``model_type``."""
        if model_type == "openai":
            # https://api.python.langchain.com/en/stable/llms/langchain.llms.openai.OpenAI.html#langchain.llms.openai.OpenAI
            openai_kwargs = {"temperature": temperature, "max_tokens": max_length}
            # An empty string is not a usable model name; omitting it lets the
            # LangChain OpenAI wrapper fall back to its own default model.
            if openai_model_name:
                openai_kwargs["model_name"] = openai_model_name
            return OpenAI(**openai_kwargs)
        # https://python.langchain.com/docs/integrations/llms/huggingface_hub
        return HuggingFaceHub(
            repo_id=hf_repo_id,
            model_kwargs={"temperature": temperature, "max_length": max_length},
        )

    def llm_inference(
        self,
        model_type: str,
        question: str,
        prompt_template: str,
        context: str,
        ai_tone: str,
        questionnaire: str,
        user_text: str,
        openai_model_name: str = "",
        hf_repo_id: str = "tiiuae/falcon-7b-instruct",
        temperature: float = 0.1,
        max_length: int = 128,
    ) -> str:
        """Call a HuggingFace/OpenAI model for inference.

        Given a question, a prompt template, and the other parameters, this
        calls the relevant API through LangChain and returns the generated text.

        Args:
            model_type: The LLM vendor. Either 'huggingface' or 'openai'.
            question: The question that was asked to the user.
            prompt_template: The prompt template; it may reference {context},
                {ai_tone}, {questionnaire}, {question} and {user_text}.
            context: Instructions for the LLM.
            ai_tone: Either empathy, encouragement or suggest medical help.
            questionnaire: Either depression, anxiety or adhd.
            user_text: The response given by the user.
            openai_model_name: OpenAI model name (used only when
                model_type='openai'). When empty, LangChain's default is used.
            hf_repo_id: The Hugging Face model's repo_id (used only when
                model_type='huggingface').
            temperature: (Default: 0.1). Sampling temperature; 0 means always
                take the highest-scoring token, higher values flatten the
                distribution.
            max_length: Maximum length in tokens of the generated output
                (sent as max_tokens to OpenAI, max_length to Hugging Face).

        Returns:
            The inference result as a string. For Hugging Face models only the
            text after the "Response;" marker is returned (the whole output,
            stripped, if the marker is absent).

        Raises:
            ValueError: If ``model_type`` is not 'openai' or 'huggingface'.

        HuggingFace repo_id examples:
            - google/flan-t5-xxl
            - tiiuae/falcon-7b-instruct
        """
        # Validate before building any LangChain objects so a bad vendor name
        # fails loudly instead of returning None.
        if model_type not in self.SUPPORTED_MODEL_TYPES:
            raise ValueError(
                "Please use the correct value of model_type parameter: It can have a value of either openai or huggingface"
            )

        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=[
                "context",
                "ai_tone",
                "questionnaire",
                "question",
                "user_text",
            ],
        )
        llm = self._build_llm(
            model_type, openai_model_name, hf_repo_id, temperature, max_length
        )
        llm_chain = LLMChain(prompt=prompt, llm=llm)
        response = llm_chain.run(
            context=context,
            ai_tone=ai_tone,
            questionnaire=questionnaire,
            question=question,
            user_text=user_text,
        )

        if model_type == "huggingface":
            print(response)
            # Keep only the reply that follows the template's "Response;" marker.
            return self._extract_response(response)
        return response
if __name__ == "__main__":
    # Demo run against the Hugging Face backend. Nothing in this file loads a
    # .env file: HUGGINGFACEHUB_API_TOKEN (and OPENAI_API_KEY for
    # model_type="openai") must already be set in the environment.
    if not os.getenv("HUGGINGFACEHUB_API_TOKEN"):
        raise SystemExit(
            "HUGGINGFACEHUB_API_TOKEN is not set; export it before running this demo."
        )

    context = "You are a mental health supporting non-medical assistant. DO NOT PROVIDE any medical advice with conviction."
    ai_tone = "EMPATHY"
    questionnaire = "ADHD"
    question = (
        "How often do you find yourself having trouble focusing on tasks or activities?"
    )
    user_text = "I feel distracted all the time, and I am never able to finish"

    # Built from explicit lines so source indentation cannot leak into the
    # prompt. It must end with the "Response;" marker that llm_inference uses
    # to cut the reply out of Hugging Face output. An optional extra line is
    # "The user may have signs of {questionnaire}.".
    template = (
        "INSTRUCTIONS: {context}\n"
        "Respond to the user with a tone of {ai_tone}.\n"
        "Question asked to the user: {question}\n"
        "Response by the user: {user_text}\n"
        "Provide some advice and ask a relevant question back to the user.\n"
        "Response;\n"
    )
    temperature = 0.1
    max_length = 128

    model = LLLResponseGenerator()
    llm_response = model.llm_inference(
        model_type="huggingface",
        question=question,
        prompt_template=template,
        context=context,
        ai_tone=ai_tone,
        questionnaire=questionnaire,
        user_text=user_text,
        temperature=temperature,
        max_length=max_length,
    )
    print(llm_response)