|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
|
import torch |
|
|
|
|
|
# Hugging Face Hub repository holding the fine-tuned classifier model.
model_id = "kingabzpro/Llama-3.1-8B-Instruct-Mental-Health-Classification"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the weights in float16 and let `device_map="auto"` decide device
# placement. `trust_remote_code` is intentionally NOT enabled: the model id
# names a Llama 3.1 fine-tune, which is a built-in transformers architecture,
# so there is no reason to let the Hub repository execute arbitrary Python
# code on this machine.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    return_dict=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map="auto",
)

# The model is already loaded, cast and placed above, so dtype/device options
# are not repeated here: pipeline() only forwards them to from_pretrained()
# when it has to load the model itself from a name or path.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
|
|
|
|
|
def classify_mental_health(text):
    """Predict a mental-health label for ``text`` with the fine-tuned model.

    Args:
        text: Free-form text entered by the user. May be empty or ``None``
            when the textbox is left blank.

    Returns:
        The text generated by the model for the label (presumably one of the
        labels named in the prompt), stripped of surrounding whitespace, or
        an empty string when ``text`` is blank.
    """
    # Nothing to classify: skip an expensive model call that would only
    # produce an arbitrary label.
    if not text or not text.strip():
        return ""

    # Instruction, then "text:" and "label:" lines; .strip() removes the
    # trailing space after "label:" so the model writes the label next.
    prompt = f"""Classify the text into Normal, Depression, Anxiety, Bipolar, and return the answer as the corresponding mental health disorder label.
text: {text}
label: """.strip()

    # return_full_text=False makes the pipeline return only the newly
    # generated tokens. The label therefore no longer has to be recovered by
    # splitting the echoed prompt on "label: ", which returned the wrong
    # substring (text after a "label: " inside the user's input, or the whole
    # prompt) whenever the generated tokens did not begin with a space.
    # Greedy decoding (do_sample=False) makes the prediction deterministic.
    # NOTE(review): max_new_tokens=2 truncates any longer label; confirm it
    # covers the tokenization of every label the model can emit.
    outputs = pipe(
        prompt,
        max_new_tokens=2,
        do_sample=False,
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()
|
|
|
|
|
with gr.Blocks() as demo:
    # Page heading.
    gr.Markdown("## Mental Health Text Classification")

    # User input box and the read-only-style field that shows the result.
    text_input = gr.Textbox(label="Enter your text:")
    label_output = gr.Textbox(label="Predicted Mental Health Label")

    # Pressing the button sends the textbox contents to the classifier and
    # writes the returned label into the output textbox.
    btn = gr.Button("Classify")
    btn.click(
        fn=classify_mental_health,
        inputs=text_input,
        outputs=label_output,
    )

# Start the Gradio server; this runs whenever the module is executed.
demo.launch()
|
|