Update app.py
app.py
CHANGED
@@ -2,54 +2,89 @@ import gradio as gr
 import subprocess
 import time
 from ollama import chat
-from ollama import ChatResponse
+# from ollama import ChatResponse # (unused)
 from huggingface_hub import InferenceClient
 import os
 
-#
+# -----------------------------
+# Hugging Face Inference Client
+# -----------------------------
+# Accept common env var names
+hf_api_key = (
+    os.getenv("HF_API_KEY")
+    or os.getenv("HF_TOKEN")
+    or os.getenv("HUGGING_FACE_HUB_TOKEN")
+)
+
 if not hf_api_key:
-    print(
+    print(
+        "[WARN] No HF API token found in HF_API_KEY / HF_TOKEN / HUGGING_FACE_HUB_TOKEN.\n"
+        "  If you see 401/404 from the Inference API, set one of these."
+    )
+
+# IMPORTANT: initialize the client WITH the model id here, and do NOT pass a model in the URL
+# This avoids constructing a broken path like /models/<repo>/v1/chat/completions
+HF_MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
+hf_client = InferenceClient(model=HF_MODEL_ID, token=hf_api_key)
 
+# Optional: warn on very old huggingface_hub versions
+try:
+    import huggingface_hub as hfh
+    from packaging import version
 
+    if version.parse(hfh.__version__) < version.parse("0.25.0"):
+        print(
+            f"[WARN] huggingface_hub {hfh.__version__} is a bit old. "
+            "Consider: pip install -U huggingface_hub"
+        )
+except Exception:
+    pass
+
+# -----------------------------
+# Default local model (Ollama)
+# -----------------------------
 OLLAMA_MODEL = "llama3.2:3b"
 # OLLAMA_MODEL = "llama3.2:1b"
 # OLLAMA_MODEL = "llama3.2:3b-instruct-q2_K"
 
-#
+# -----------------------------
+# Fine-tuned BERT classifier
+# -----------------------------
 from transformers import pipeline, DistilBertTokenizerFast
 
-# Path to your locally saved model
-# bert_model_path = "fine_tuned_aita_classifier"
 bert_model_path = "dingusagar/distillbert-aita-classifier"
 
 tokenizer = DistilBertTokenizerFast.from_pretrained(bert_model_path)
 classifier = pipeline(
     "text-classification",
-    model=bert_model_path,
-    tokenizer=tokenizer,
-    truncation=True
+    model=bert_model_path,
+    tokenizer=tokenizer,
+    truncation=True,
 )
 
 bert_label_map = {
+    "LABEL_0": "YTA",
+    "LABEL_1": "NTA",
 }
 
 bert_label_map_formatted = {
+    "LABEL_0": "You are the A**hole (YTA)",
+    "LABEL_1": "Not the A**hole (NTA)",
 }
 
+
+def ask_bert(prompt: str):
+    print("Getting response from Fine-tuned BERT")
     result = classifier([prompt])[0]
-    label = result[
+    label = result["label"]
     confidence = f"{result['score']*100:.2f}"
     return label, confidence
 
+
+# -----------------------------
+# Ollama helpers
+# -----------------------------
+
 def start_ollama_server():
     # Start Ollama server in the background
     print("Starting Ollama server...")

@@ -64,187 +99,185 @@ def start_ollama_server():
     subprocess.Popen(["ollama", "run", OLLAMA_MODEL], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
     print("Ollama started model.")
 
+
+def _build_prompts(question: str, expected_class: str = ""):
     classify_and_explain_prompt = f"""
-### You are an unbiased expert from subreddit community r/AmItheAsshole. In this community people post their life situations and ask if they are the asshole or not.
-The community uses the following acronyms.
-AITA : Am I the asshole? Usually posted in the question.
-YTA : You are the asshole in this situation.
+### You are an unbiased expert from subreddit community r/AmItheAsshole. In this community people post their life situations and ask if they are the asshole or not.
+The community uses the following acronyms.
+AITA : Am I the asshole? Usually posted in the question.
+YTA : You are the asshole in this situation.
 NTA : You are not the asshole in this situation.
 
-### The task for you label YTA or NTA for the given text. Give a short explanation for the label. Be brutally honest and unbiased. Base your explanation entirely on the given text only.
+### The task for you label YTA or NTA for the given text. Give a short explanation for the label. Be brutally honest and unbiased. Base your explanation entirely on the given text only.
 
-If the label is YTA, also explain what could the user have done better.
+If the label is YTA, also explain what could the user have done better.
 ### The output format is as follows:
-"YTA" or "NTA", a short explanation.
+"YTA" or "NTA", a short explanation.
 
 ### Situation : {question}
 ### Response :"""
 
-    explain_only_prompt =
-### You know about the subreddit community r/AmItheAsshole. In this community people post their life situations and ask if they are the asshole or not.
-The community uses the following acronyms.
-AITA : Am I the asshole? Usually posted in the question.
-YTA : You are the asshole in this situation.
+    explain_only_prompt = f"""
+### You know about the subreddit community r/AmItheAsshole. In this community people post their life situations and ask if they are the asshole or not.
+The community uses the following acronyms.
+AITA : Am I the asshole? Usually posted in the question.
+YTA : You are the asshole in this situation.
 NTA : You are not the asshole in this situation.
 
-### The task for you explain why a particular situation was tagged as NTA or YTA by most users. I will give the situation as well as the NTA or YTA tag. just give your explanation for the label. Be nice but give a brutally honest and unbiased view. Base your explanation entirely on the given text and the label tag only. Do not assume anything extra.
+### The task for you explain why a particular situation was tagged as NTA or YTA by most users. I will give the situation as well as the NTA or YTA tag. just give your explanation for the label. Be nice but give a brutally honest and unbiased view. Base your explanation entirely on the given text and the label tag only. Do not assume anything extra.
 Use second person terms like you in the explanation.
 
 ### Situation : {question}
 ### Label Tag : {expected_class}
 ### Explanation for {expected_class} :"""
 
-    if expected_class
+    return (explain_only_prompt if expected_class else classify_and_explain_prompt)
+
+
+def ask_ollama(question: str, expected_class: str = ""):
+    print("Getting response from Ollama")
+    prompt = _build_prompts(question, expected_class)
 
     print(f"Prompt to llama : {prompt}")
-    stream = chat(
-    ], stream=True)
+    stream = chat(
+        model=OLLAMA_MODEL,
+        messages=[{"role": "user", "content": prompt}],
+        stream=True,
+    )
     response = ""
     for chunk in stream:
-        response += chunk[
+        response += chunk["message"]["content"]
         yield response
 
 
+# --------------------------------------
+# Hugging Face Inference (Chat Completions)
+# --------------------------------------
 
-def ask_hf_inference_client(question, expected_class=""):
-    print(
-### You are an unbiased expert from subreddit community r/AmItheAsshole. In this community people post their life situations and ask if they are the asshole or not.
-The community uses the following acronyms.
-AITA : Am I the asshole? Usually posted in the question.
-YTA : You are the asshole in this situation.
-NTA : You are not the asshole in this situation.
-
-### The task for you label YTA or NTA for the given text. Give a short explanation for the label. Be brutally honest and unbiased. Base your explanation entirely on the given text only.
-
-If the label is YTA, also explain what could the user have done better.
-### The output format is as follows:
-"YTA" or "NTA", a short explanation.
-
-### Situation : {question}
-### Response :"""
-
-    explain_only_prompt = f"""
-### You know about the subreddit community r/AmItheAsshole. In this community people post their life situations and ask if they are the asshole or not.
-The community uses the following acronyms.
-AITA : Am I the asshole? Usually posted in the question.
-YTA : You are the asshole in this situation.
-NTA : You are not the asshole in this situation.
-
-### The task for you explain why a particular situation was tagged as NTA or YTA by most users. I will give the situation as well as the NTA or YTA tag. just give your explanation for the label. Be nice but give a brutally honest and unbiased view. Base your explanation entirely on the given text and the label tag only. Do not assume anything extra.
-Use second person terms like you in the explanation.
-
-### Situation : {question}
-### Label Tag : {expected_class}
-### Explanation for {expected_class} :"""
-
-    if expected_class == "":
-        prompt = classify_and_explain_prompt
-    else:
-        prompt = explain_only_prompt
+def ask_hf_inference_client(question: str, expected_class: str = ""):
+    print("Getting response from HF Inference (chat.completions)")
+    prompt = _build_prompts(question, expected_class)
 
     print(f"Prompt to HF_Inference API : {prompt}")
 
-    messages = [
-        {
-            "role": "user",
-            "content": prompt
-        }
-    ]
-
-    stream = hf_client.chat.completions.create(
-        model="meta-llama/Llama-3.2-3B-Instruct",
-        messages=messages,
-        max_tokens=500,
-        stream=True
-    )
+    messages = [{"role": "user", "content": prompt}]
 
+    try:
+        # NOTE: We initialized the client with a model, so we DO NOT pass model= here
+        stream = hf_client.chat.completions.create(
+            messages=messages,
+            max_tokens=500,
+            stream=True,
+            temperature=0.2,
+        )
+
+        for chunk in stream:
+            # Be defensive: delta/content may be None on some events
+            try:
+                delta = chunk.choices[0].delta
+                if delta and getattr(delta, "content", None):
+                    yield delta.content
+            except Exception:
+                # If schema slightly differs, just ignore and continue
+                continue
+    except Exception as e:
+        # Surface a friendly message in the UI instead of crashing Gradio
+        yield f"[HF Inference error] {type(e).__name__}: {e}"
 
 
+# -----------------------------
+# Gradio glue
+# -----------------------------
 
-# Separate function for Ollama response
 def gradio_ollama_interface(prompt, bert_class=""):
     return ask_ollama(prompt, expected_class=bert_class)
+
+
 def gradio_interface(prompt, selected_model):
     if selected_model == MODEL_CHOICE_LLAMA:
         for chunk in ask_ollama(prompt):
             yield chunk
     elif selected_model == MODEL_CHOICE_BERT:
         label, confidence = ask_bert(prompt)
-        response = f"{
+        label_fmt = bert_label_map_formatted[label]
+        response = f"{label_fmt} with confidence {confidence}%"
         return response
     elif selected_model == MODEL_CHOICE_BERT_LLAMA:
         label, confidence = ask_bert(prompt)
-        initial_response =
+        initial_response = (
+            f"Response from BERT model: {bert_label_map_formatted[label]} with confidence {confidence}%\n\n"
+            "Generating explanation using Llama model...\n"
+        )
         yield initial_response
         for chunk in ask_ollama(prompt, expected_class=bert_label_map[label]):
             yield initial_response + "\n" + chunk
     elif selected_model == MODEL_CHOICE_BERT_LLAMA_HF_INFERENCE:
         label, confidence = ask_bert(prompt)
-        initial_response =
+        initial_response = (
+            f"Response from BERT model: {bert_label_map_formatted[label]} with confidence {confidence}%\n\n"
+            "Generating explanation using Llama (HF Inference)...\n"
+        )
         yield initial_response
+        acc = initial_response
+        for piece in ask_hf_inference_client(prompt, expected_class=bert_label_map[label]):
+            acc += piece or ""
+            yield acc
    else:
        return "Something went wrong. Select the correct model configuration from settings. "
 
+
 MODEL_CHOICE_BERT_LLAMA = "Fine-tuned BERT (classification) + Llama 3.2 3B (explanation)"
-MODEL_CHOICE_BERT_LLAMA_HF_INFERENCE =
+MODEL_CHOICE_BERT_LLAMA_HF_INFERENCE = (
+    "Fine-tuned BERT (classification) + Llama 3.2 3B Inference api (fast explanation)"
+)
 MODEL_CHOICE_BERT = "Fine-tuned BERT (classification only)"
 MODEL_CHOICE_LLAMA = "Llama 3.2 3B (classification + explanation)"
 
-MODEL_OPTIONS = [
+MODEL_OPTIONS = [
+    MODEL_CHOICE_BERT_LLAMA_HF_INFERENCE,
+    MODEL_CHOICE_BERT_LLAMA,
+    MODEL_CHOICE_LLAMA,
+    MODEL_CHOICE_BERT,
+]
 
 # Example texts
 EXAMPLES = [
     "I refused to invite my coworker to my birthday party even though we’re part of the same friend group. AITA?",
     "I didn't attend my best friend's wedding because I couldn't afford the trip. Now they are mad at me. AITA?",
     "I told my coworker they were being unprofessional during a meeting in front of everyone. AITA?",
-    "I told my kid that she should become an engineer like me, she is into painting and wants to pursue arts. AITA? "
+    "I told my kid that she should become an engineer like me, she is into painting and wants to pursue arts. AITA? ",
 ]
 
+# -----------------------------
 # Build the Gradio app
-#
-with gr.Blocks(
+# -----------------------------
+with gr.Blocks(
+    theme=gr.themes.Default(
+        primary_hue=gr.themes.colors.green, secondary_hue=gr.themes.colors.purple
+    )
+) as demo:
     gr.Markdown("# AITA Classifier")
     gr.Markdown(
         """### Ask this AI app if you are wrong in a situation. Describe the conflict you experienced, give both sides of the story and find out if you are right (NTA) or, you are the a**shole (YTA). Inspired by the subreddit [r/AmItheAsshole](https://www.reddit.com/r/AmItheAsshole/), this app tries to provide honest and unbiased assessments of user's life situations.
         <sub>**Disclaimer:** The responses generated by this AI model are based on the training data derived from the subreddit posts and do not represent the views or opinions of the creators or authors. This was our fun little project, please don't take the generated responses too seriously :) </sub>
-        """
+        """
+    )
-    # Add Accordion for settings
-    # with gr.Accordion("Settings", open=True):
-    #     model_selector = gr.Dropdown(
-    #         label="Select Models",
-    #         choices=MODEL_OPTIONS,
-    #         value=MODEL_CHOICE_BERT_LLAMA
-    # )
 
     with gr.Row():
         model_selector = gr.Dropdown(
+            label="Selected Model",
+            choices=MODEL_OPTIONS,
+            value=MODEL_CHOICE_BERT_LLAMA_HF_INFERENCE,
+        )
 
     with gr.Row():
-        input_prompt = gr.Textbox(
+        input_prompt = gr.Textbox(
+            label="Enter your situation here",
+            placeholder="Am I the a**hole for...",
+            lines=5,
+        )
 
     with gr.Row():
-        # Add example texts
         example = gr.Examples(
             examples=EXAMPLES,
             inputs=input_prompt,

@@ -255,9 +288,12 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.green, secon
     submit_button = gr.Button("Check A**hole or not!", variant="primary")
 
     with gr.Row():
-        output_response = gr.Textbox(
+        output_response = gr.Textbox(
+            label="Response",
+            lines=10,
+            placeholder="""Result will be YTA (you are the A**hole) or NTA(Not the A**shole)""",
+        )
 
-    # Link the button click to the interface function
     submit_button.click(gradio_interface, inputs=[input_prompt, model_selector], outputs=output_response)
 
     # Launch the app
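
Review note: the functional core of this commit is the Hugging Face Inference call. The client is now constructed once with the model id and token, chat.completions.create() is then called without a model= argument (avoiding the broken /models/<repo>/v1/chat/completions path), and streamed deltas are read defensively. A minimal, self-contained sketch of that pattern; the token env var and the example prompt below are illustrative, not part of the app:

import os
from huggingface_hub import InferenceClient

# Bind the model id when constructing the client; do not pass model= again in create().
client = InferenceClient(model="meta-llama/Llama-3.2-3B-Instruct", token=os.getenv("HF_TOKEN"))

stream = client.chat.completions.create(
    messages=[{"role": "user", "content": "AITA for skipping my friend's wedding?"}],
    max_tokens=200,
    stream=True,
)

for chunk in stream:
    # Each streamed event carries an incremental delta; content can be None on some events.
    delta = chunk.choices[0].delta
    if delta and delta.content:
        print(delta.content, end="", flush=True)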