import os
import re
from dotenv import load_dotenv, find_dotenv
from langchain_community.llms import HuggingFaceHub
from langchain_community.llms import OpenAI
# from langchain.llms import HuggingFaceHub, OpenAI
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import warnings

warnings.filterwarnings("ignore")

class LLLResponseGenerator:

    def __init__(self):
        print("initialized")

    @staticmethod
    def remove_html_tags(text: str) -> str:
        """Strip any HTML tags from a piece of text."""
        clean_text = re.sub(r"<[^>]*>", "", text)
        return clean_text

    def llm_inference(
        self,
        model_type: str,
        question: str,
        prompt_template: str,
        context: str,
        ai_tone: str,
        questionnaire: str,
        user_text: str,
        openai_model_name: str = "",
        hf_repo_id: str = "tiiuae/falcon-7b-instruct",
        temperature: float = 0.1,
        max_length: int = 128,
    ) -> str:
        """Call HuggingFace/OpenAI model for inference

        Given a question, prompt_template, and other parameters, this function calls the relevant
        API to fetch LLM inference results.

        Args:
            model_type: The LLM vendor's name. Can be either 'huggingface' or 'openai'.
            question: The question to be asked of the LLM.
            prompt_template: The prompt template itself.
            context: Instructions for the LLM.
            ai_tone: Can be either empathy, encouragement, or suggest medical help.
            questionnaire: Can be either depression, anxiety, or adhd.
            user_text: The response given by the user.
            openai_model_name: Name of the OpenAI model, used when model_type is 'openai'.
            hf_repo_id: The Hugging Face model's repo_id.
            temperature: (Default: 0.1). The sampling temperature; 0 always takes the highest-scoring token, 1 is regular sampling, and larger values move toward a uniform distribution.
            max_length: Maximum length, in tokens, of the generated output.

        Returns:
            A Python string which contains the inference result.

        HuggingFace repo_id examples:
            - google/flan-t5-xxl
            - tiiuae/falcon-7b-instruct

        """
        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=[
                "context",
                "ai_tone",
                "questionnaire",
                "question",
                "user_text",
            ],
        )

        if model_type == "openai":
            # https://api.python.langchain.com/en/stable/llms/langchain.llms.openai.OpenAI.html#langchain.llms.openai.OpenAI
            llm = OpenAI(
                model_name=openai_model_name, temperature=temperature, max_tokens=max_length
            )
            llm_chain = LLMChain(prompt=prompt, llm=llm)
            return llm_chain.run(
                context=context,
                ai_tone=ai_tone,
                questionnaire=questionnaire,
                question=question,
                user_text=user_text,
            )

        elif model_type == "huggingface":
            # https://python.langchain.com/docs/integrations/llms/huggingface_hub
            llm = HuggingFaceHub(
                repo_id=hf_repo_id,
                model_kwargs={"temperature": temperature, "max_length": max_length},
            )

            llm_chain = LLMChain(prompt=prompt, llm=llm)
            response = llm_chain.run(
                context=context,
                ai_tone=ai_tone,
                questionnaire=questionnaire,
                question=question,
                user_text=user_text,
            )
            print(response)

            # Keep only the text that follows the "Response;" marker from the
            # prompt template; fall back to the full output if the marker is
            # missing from the model's reply.
            marker = "Response;"
            marker_index = response.find(marker)
            if marker_index == -1:
                return response.strip()
            return response[marker_index + len(marker):].strip()

        else:
            raise ValueError(
                "Invalid model_type parameter: expected 'openai' or 'huggingface'."
            )


if __name__ == "__main__":
    # Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN' and 'OPENAI_API_KEY' values.
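    # Example .env entries (placeholder values, not real keys):
    # HUGGINGFACEHUB_API_TOKEN=hf_xxxxxxxxxxxxxxxx
    # OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxx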
    load_dotenv(find_dotenv())
    # HuggingFaceHub picks HUGGINGFACEHUB_API_TOKEN up from the environment.
    HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")

    context = "You are a mental health supporting non-medical assistant. DO NOT PROVIDE any medical advice with conviction."

    ai_tone = "EMPATHY"
    questionnaire = "ADHD"
    question = (
        "How often do you find yourself having trouble focusing on tasks or activities?"
    )
    user_text = "I feel distracted all the time, and I am never able to finish"

    # The user may have signs of {questionnaire}.
    template = """INSTRUCTIONS: {context}

    Respond to the user with a tone of {ai_tone}.

    Question asked to the user: {question}

    Response by the user: {user_text}

    Provide some advice and ask a relevant question back to the user.

    Response;
    """

    temperature = 0.1
    max_length = 128

    model = LLLResponseGenerator()

    llm_response = model.llm_inference(
        model_type="huggingface",
        question=question,
        prompt_template=template,
        context=context,
        ai_tone=ai_tone,
        questionnaire=questionnaire,
        user_text=user_text,
        temperature=temperature,
        max_length=max_length,
    )

    print(llm_response)