File size: 5,479 Bytes
5038c7a
 
 
 
 
 
 
 
 
 
 
1872b66
 
5038c7a
1872b66
 
 
5038c7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1872b66
5038c7a
 
 
 
 
 
 
 
 
 
 
 
 
1872b66
 
5038c7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1872b66
 
5038c7a
 
 
 
 
 
 
 
1872b66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import os
from langchain_community.llms import HuggingFaceHub
from langchain_community.llms import OpenAI
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import warnings

# NOTE(review): this silences *every* warning, process-wide, as soon as this
# module is imported (so it also affects any code that imports this module).
# Presumably intended to hide langchain deprecation noise (e.g. LLMChain.run) —
# consider narrowing it with filterwarnings("ignore", category=..., module=...).
warnings.filterwarnings("ignore")

class LLLResponseGenerator:
    """Generate supportive chatbot replies using a HuggingFace Hub or OpenAI LLM.

    The generator keeps the user's messages in ``conversation_history`` and
    exposes ``context``: the system instruction followed by that history,
    newline-separated. ``context`` fills the prompt template's ``{context}`` slot.
    """

    # Default system instruction. It stays at the head of ``context`` for the
    # whole conversation so the non-medical-advice guardrail is never dropped.
    DEFAULT_INSTRUCTION = (
        "You are a mental health supporting non-medical assistant. "
        "DO NOT PROVIDE any medical advice with conviction."
    )
    # Marker the prompt template ends with; HuggingFace output after it is the reply.
    RESPONSE_MARKER = "Response;"
    # Accepted values for ``llm_inference(model_type=...)``.
    SUPPORTED_MODEL_TYPES = ("openai", "huggingface")

    def __init__(self, system_instruction: str = DEFAULT_INSTRUCTION) -> None:
        """Create a generator with an empty conversation history.

        Args:
            system_instruction: Instruction placed at the start of ``context``.
                Defaults to ``DEFAULT_INSTRUCTION``.
        """
        self.system_instruction = system_instruction
        self.context = system_instruction
        # User messages, in the order they were passed to update_context().
        self.conversation_history = []

    def update_context(self, user_text: str) -> None:
        """Record a user message and rebuild ``context``.

        ``context`` becomes the system instruction followed by every recorded
        user message, one per line. The instruction is kept first so that the
        guardrail it contains is still sent to the model on later turns.

        Args:
            user_text: The message the user just typed.
        """
        self.conversation_history.append(user_text)
        self.context = "\n".join([self.system_instruction, *self.conversation_history])

    def llm_inference(
        self,
        model_type: str,
        question: str,
        prompt_template: str,
        ai_tone: str,
        questionnaire: str,
        user_text: str,
        openai_model_name: str = "",
        hf_repo_id: str = "tiiuae/falcon-7b-instruct",
        temperature: float = 0.1,
        max_length: int = 128,
    ) -> str:
        """Call a HuggingFace Hub or OpenAI model for inference.

        Builds a ``PromptTemplate`` from ``prompt_template`` and runs it through
        an ``LLMChain``, filling ``{context}`` from this instance and the other
        placeholders from the arguments below.

        Args:
            model_type: LLM vendor; either ``'huggingface'`` or ``'openai'``.
            question: The question that was asked to the user.
            prompt_template: The prompt template text. It may reference
                ``{context}``, ``{ai_tone}``, ``{questionnaire}``, ``{question}``
                and ``{user_text}``.
            ai_tone: Tone of the reply, e.g. empathy, encouragement or suggest
                medical help.
            questionnaire: Questionnaire topic, e.g. depression, anxiety or adhd.
            user_text: Response given by the user.
            openai_model_name: OpenAI model name. When empty (the default), the
                LangChain ``OpenAI`` wrapper's own default model is used.
            hf_repo_id: The HuggingFace model's repo_id, e.g.
                ``google/flan-t5-xxl`` or ``tiiuae/falcon-7b-instruct``.
            temperature: Sampling temperature (default 0.1); lower is more
                deterministic. The valid range depends on the provider.
            max_length: Maximum number of tokens to generate.

        Returns:
            The model's reply. For HuggingFace, only the text after the
            ``"Response;"`` marker is returned (or the whole stripped output if
            the marker does not appear).

        Raises:
            ValueError: If ``model_type`` is not ``'openai'`` or ``'huggingface'``.
        """
        # Validate first so a bad argument fails before any client is built.
        if model_type not in self.SUPPORTED_MODEL_TYPES:
            raise ValueError(
                "Please use the correct value of model_type parameter: "
                "It can have a value of either openai or huggingface "
                f"(got {model_type!r})"
            )

        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=[
                "context",
                "ai_tone",
                "questionnaire",
                "question",
                "user_text",
            ],
        )
        prompt_values = {
            "ai_tone": ai_tone,
            "questionnaire": questionnaire,
            "question": question,
            "user_text": user_text,
        }

        if model_type == "openai":
            openai_kwargs = {"temperature": temperature, "max_tokens": max_length}
            # An empty model name is not a valid model; let the wrapper pick its default.
            if openai_model_name:
                openai_kwargs["model_name"] = openai_model_name
            llm = OpenAI(**openai_kwargs)
            return self._run_chain(llm, prompt, **prompt_values)

        llm = HuggingFaceHub(
            repo_id=hf_repo_id,
            model_kwargs={"temperature": temperature, "max_length": max_length},
        )
        response = self._run_chain(llm, prompt, **prompt_values)
        return self._extract_response(response)

    def _run_chain(self, llm, prompt, **prompt_values) -> str:
        """Run ``prompt`` through ``llm``, filling ``{context}`` from this instance."""
        llm_chain = LLMChain(prompt=prompt, llm=llm)
        return llm_chain.run(context=self.context, **prompt_values)

    @classmethod
    def _extract_response(cls, response: str) -> str:
        """Return the stripped text after the first RESPONSE_MARKER.

        If the marker is absent, the whole output is returned stripped. With
        ``str.find`` plus slicing, a missing marker (-1) would chop the first
        eight characters off a genuine reply; ``str.partition`` avoids that.
        """
        _, marker, reply = response.partition(cls.RESPONSE_MARKER)
        return (reply if marker else response).strip()


if __name__ == "__main__":
    # Interactive demo against the HuggingFace Hub model.
    # HuggingFaceHub reads HUGGINGFACEHUB_API_TOKEN (and the OpenAI wrapper reads
    # OPENAI_API_KEY) from the environment. Nothing in this script loads a .env
    # file, so export the variables (or load your .env) before running.
    if not os.getenv("HUGGINGFACEHUB_API_TOKEN"):
        raise SystemExit(
            "HUGGINGFACEHUB_API_TOKEN is not set; export it before running this demo."
        )

    ai_tone = "EMPATHY"
    questionnaire = "ADHD"
    question = (
        "How often do you find yourself having trouble focusing on tasks or activities?"
    )
    user_text = "I feel distracted all the time, and I am never able to finish"

    # NOTE: {questionnaire} is passed to the prompt but not used by this template.
    # To use it, add a line such as "The user may have signs of {questionnaire}."
    template = """INSTRUCTIONS: {context}

    Respond to the user with a tone of {ai_tone}.

    Question asked to the user: {question}

    Response by the user: {user_text}

    Provide some advice and ask a relevant question back to the user.

    Response;
    """

    temperature = 0.1
    max_length = 128

    model = LLLResponseGenerator()

    def ask(text: str) -> str:
        """Send one user message to the HuggingFace model and return its reply."""
        return model.llm_inference(
            model_type="huggingface",
            question=question,
            prompt_template=template,
            ai_tone=ai_tone,
            questionnaire=questionnaire,
            user_text=text,
            temperature=temperature,
            max_length=max_length,
        )

    # Initial turn uses the canned example answer above.
    print("Bot:", ask(user_text))

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C end the session cleanly instead of with a traceback.
            print()
            break
        if user_input.lower() == "exit":
            break
        if not user_input:
            # Don't spend an API call on an empty message.
            continue

        model.update_context(user_input)
        print("Bot:", ask(user_input))