Upload 14 files

- .gitattributes +0 -33
- LICENSE +9 -0
- README.md +62 -3
- README_ASSETS.md +8 -0
- app.py +41 -0
- configs/response_config.json +8 -0
- configs/text_emotion_config.json +8 -0
- example_inputs/example_images.txt +4 -0
- example_inputs/example_texts.txt +3 -0
- inference.py +33 -0
- requirements.txt +13 -0
- train_response_generator.py +75 -0
- train_text_emotion.py +72 -0
- utils.py +50 -0
.gitattributes
CHANGED
@@ -1,35 +1,2 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,9 @@
Apache License 2.0

Copyright 2025 hmnshudhmn24

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0
README.md
CHANGED
@@ -1,3 +1,62 @@
----
-
-

---
language: en
license: apache-2.0
datasets: [go_emotions, empathetic_dialogues]
pipeline_tag: text-generation
library_name: transformers
tags:
- multimodal
- emotion-detection
- empathetic-chatbot
- t5
- clip
- streamlit
base_model: t5-small
---

# Emo — Multimodal Emotion-Aware Assistant

**Repository:** `hmnshudhmn24/emo-multimodal-assistant`

**Short:** An assistant that detects user emotion from text *and image*, and responds empathetically by conditioning a text generator (T5) on the detected emotions.

**Components**
- Text emotion classifier (DistilBERT fine-tuned on GoEmotions)
- Image emotion detector (CLIP zero-shot with emotion labels)
- Response generator (T5-small fine-tuned on EmpatheticDialogues)
- Inference script combining everything
- Streamlit app for a quick demo (text + optional image upload)

## Quick usage (inference)

```python
from inference import EmoAssistant

assistant = EmoAssistant(
    text_emotion_model="hmnshudhmn24/emo-text-emotion",
    response_model="hmnshudhmn24/emo-response-generator"
)

# text-only
reply = assistant.respond(user_text="I'm so stressed about exams.")
print(reply)

# text + image (image path)
reply = assistant.respond(user_text="I had a rough day", image_path="example.jpg")
print(reply)
```

## How to train (short)
1. Train the text emotion classifier:
```bash
python train_text_emotion.py --save-dir ./emo-text-emotion
```
2. Train the response generator (empathetic responses):
```bash
python train_response_generator.py --save-dir ./emo-response-generator
```
3. After training, add `pytorch_model.bin`, the tokenizer files, and a README for each model, then upload them to Hugging Face or place them in the local folders referenced by `inference.py` (see the sketch below).

## Notes & Ethics
- This is not a tool for medical or mental-health diagnosis; it is built for supportive, empathetic responses only.
- Always add content/safety filters before production use.
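Step 3 above says to upload the trained folders to Hugging Face; a minimal sketch of one way to do that with `huggingface_hub` (assumes the package is installed and you are logged in via `huggingface-cli login`; the repo id mirrors the names used above):

```python
# Upload sketch — an assumed workflow, not a script shipped in this repo.
from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="./emo-text-emotion",         # local training output dir
    repo_id="hmnshudhmn24/emo-text-emotion",  # target model repo on the Hub
    repo_type="model",
)
```

The same call with `./emo-response-generator` and `hmnshudhmn24/emo-response-generator` covers the second model.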
README_ASSETS.md
ADDED
@@ -0,0 +1,8 @@
# Assets & Where to put model files

- After training with `train_text_emotion.py`, save the model and tokenizer to `./emo-text-emotion/` or push to `hmnshudhmn24/emo-text-emotion`.
- After training with `train_response_generator.py`, save to `./emo-response-generator/` or push to `hmnshudhmn24/emo-response-generator`.
- The inference script expects:
  - a text classifier model (HF name or local path)
  - a response generator (HF name or local path)
- CLIP is loaded from `openai/clip-vit-base-patch32` via transformers.
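Since `from_pretrained` accepts either a Hub repo id or a local directory, pointing the inference class at the local folders is straightforward; a short sketch:

```python
# Local-path variant of the README's Quick usage example.
from inference import EmoAssistant

assistant = EmoAssistant(
    text_emotion_model="./emo-text-emotion",
    response_model="./emo-response-generator",
)
print(assistant.respond("I feel like nobody understands me."))
```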
app.py
ADDED
@@ -0,0 +1,41 @@
# app.py - Streamlit demo for Emo Multimodal Assistant
import streamlit as st
from inference import EmoAssistant
from PIL import Image
import tempfile
import os

st.set_page_config(page_title="Emo Assistant", layout="centered")
st.title("Emo — Multimodal Emotion-Aware Assistant")

st.markdown("This demo detects emotion from text and an optional image, then generates an empathetic response.")

# model selection / paths
text_model = st.text_input("Text emotion model (HF repo or local path)", value="distilbert-base-uncased")
response_model = st.text_input("Response generator model (HF repo or local path)", value="t5-small")

# Keep the loaded assistant in session state so it survives Streamlit reruns
# (a plain local variable would be reset on every widget interaction).
if "assistant" not in st.session_state:
    st.session_state.assistant = None

if st.button("Load models"):
    with st.spinner("Loading models — this may take a minute..."):
        st.session_state.assistant = EmoAssistant(text_emotion_model=text_model, response_model=response_model)
    st.success("Models loaded — ready to go!")

user_text = st.text_area("Your message", value="I had a rough day at work and feel exhausted.", height=120)

uploaded_file = st.file_uploader("Upload an image (optional)", type=["jpg", "jpeg", "png"])
image_path = None
if uploaded_file is not None:
    # Persist the upload to a temp file so inference can read it by path.
    tfile = tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1])
    tfile.write(uploaded_file.read())
    tfile.flush()
    image_path = tfile.name
    st.image(Image.open(image_path), caption="Uploaded image", use_column_width=True)

if st.button("Get empathetic reply"):
    if st.session_state.assistant is None:
        with st.spinner("Loading models (first time)..."):
            st.session_state.assistant = EmoAssistant(text_emotion_model=text_model, response_model=response_model)
    with st.spinner("Detecting emotion and generating response..."):
        reply = st.session_state.assistant.respond(user_text, image_path=image_path)
    st.subheader("Assistant reply")
    st.write(reply)
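To try the demo locally (assuming the dependencies from `requirements.txt` are installed):

```bash
pip install -r requirements.txt
streamlit run app.py
```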
configs/response_config.json
ADDED
@@ -0,0 +1,8 @@
{
  "model_name_or_path": "t5-small",
  "max_input_length": 256,
  "max_target_length": 64,
  "learning_rate": 0.0003,
  "batch_size": 8,
  "num_train_epochs": 3
}
configs/text_emotion_config.json
ADDED
@@ -0,0 +1,8 @@
{
  "model_name_or_path": "distilbert-base-uncased",
  "num_labels": 28,
  "max_length": 128,
  "learning_rate": 2e-05,
  "batch_size": 16,
  "num_train_epochs": 3
}
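Note that the training scripts currently hardcode these hyperparameters rather than reading the JSON files; if you want the configs to drive training, a minimal loader sketch (an assumed pattern, not part of the shipped scripts) looks like:

```python
import json

# Read the generator hyperparameters from the config above.
with open("configs/response_config.json") as f:
    cfg = json.load(f)

print(cfg["model_name_or_path"], cfg["learning_rate"], cfg["batch_size"])
```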
example_inputs/example_images.txt
ADDED
@@ -0,0 +1,4 @@
# Put image file paths or URLs here for testing image emotion detection.
# Example:
# ./examples/sad_person.jpg
# ./examples/happy_group.jpg
example_inputs/example_texts.txt
ADDED
@@ -0,0 +1,3 @@
I had a terrible day at work, nothing I did went right.
I’m so excited for my new project — I can’t wait!
I feel like nobody understands me.
inference.py
ADDED
@@ -0,0 +1,33 @@
# inference.py
from transformers import pipeline, T5ForConditionalGeneration, T5TokenizerFast
from utils import predict_image_emotion, combine_emotions
import torch

class EmoAssistant:
    def __init__(self, text_emotion_model: str, response_model: str, device: int = None):
        # transformers pipelines take an int device: 0 = first GPU, -1 = CPU
        self.device = device if device is not None else (0 if torch.cuda.is_available() else -1)
        # text emotion pipeline (single-label; returns the top label by default)
        self.text_clf = pipeline("text-classification", model=text_emotion_model, device=self.device)
        # response generator
        self.response_tokenizer = T5TokenizerFast.from_pretrained(response_model)
        self.response_model = T5ForConditionalGeneration.from_pretrained(response_model).to("cuda" if torch.cuda.is_available() else "cpu")

    def detect_text_emotion(self, text: str):
        res = self.text_clf(text)[0]
        return res.get("label", "neutral")

    def respond(self, user_text: str, image_path: str = None, max_length: int = 64):
        text_emotion = self.detect_text_emotion(user_text)
        image_emotions = None
        if image_path:
            image_emotions = predict_image_emotion(image_path)
        combined = combine_emotions(text_emotion, image_emotions)
        # Prompt format must match the training prefix in train_response_generator.py
        prompt = f"emotion: {combined} context: {user_text}"
        inputs = self.response_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256).to(self.response_model.device)
        outputs = self.response_model.generate(**inputs, max_length=max_length, num_beams=4, early_stopping=True)
        return self.response_tokenizer.decode(outputs[0], skip_special_tokens=True)

if __name__ == "__main__":
    assistant = EmoAssistant(text_emotion_model="distilbert-base-uncased", response_model="t5-small")
    print(assistant.respond("I failed my exam today and feel terrible."))
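For orientation, here is what `respond` actually feeds the generator, assuming the text classifier returns "sadness" and CLIP returns `["sadness", "grief"]` for the image (hypothetical outputs for illustration):

```python
from utils import combine_emotions

# combine_emotions dedupes and keeps at most two image emotions
combined = combine_emotions("sadness", ["sadness", "grief"])
print(combined)  # "sadness, grief"
print(f"emotion: {combined} context: I had a rough day")
# -> "emotion: sadness, grief context: I had a rough day"
```

This prefix matches the one used at training time in `train_response_generator.py`, which is what makes the emotion conditioning work.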
requirements.txt
ADDED
@@ -0,0 +1,13 @@
transformers>=4.30.0
datasets>=2.10.0
torch>=1.12.0
accelerate>=0.18.0
sentencepiece
torchvision
pillow
streamlit
python-multipart
scikit-learn
evaluate
ftfy
rouge_score
train_response_generator.py
ADDED
@@ -0,0 +1,75 @@
# train_response_generator.py
import argparse
from datasets import load_dataset
from transformers import (T5TokenizerFast, T5ForConditionalGeneration,
                          Seq2SeqTrainer, Seq2SeqTrainingArguments)
import numpy as np
import evaluate

def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--save-dir", type=str, default="./emo-response-generator")
    p.add_argument("--num_train_epochs", type=int, default=3)
    p.add_argument("--per_device_train_batch_size", type=int, default=8)
    return p.parse_args()

def main():
    args = parse_args()
    dataset = load_dataset("empathetic_dialogues")

    tokenizer = T5TokenizerFast.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    def preprocess(examples):
        # EmpatheticDialogues column naming: "context" holds the emotion label,
        # "prompt" the situation description, "utterance" the reply text.
        prompts = []
        targets = []
        for situation, resp, emo in zip(examples["prompt"], examples["utterance"], examples["context"]):
            prompts.append(f"emotion: {emo} context: {situation}")
            targets.append(resp)
        model_inputs = tokenizer(prompts, max_length=256, truncation=True, padding="max_length")
        labels = tokenizer(text_target=targets, max_length=64, truncation=True, padding="max_length")
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    tokenized = dataset.map(preprocess, batched=True, remove_columns=dataset["train"].column_names)

    # predict_with_generate requires the seq2seq variants of the Trainer classes
    training_args = Seq2SeqTrainingArguments(
        output_dir=args.save_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=3e-4,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=16,
        num_train_epochs=args.num_train_epochs,
        weight_decay=0.01,
        logging_steps=200,
        predict_with_generate=True
    )

    rouge = evaluate.load("rouge")
    def compute_metrics(eval_pred):
        preds, labels = eval_pred
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        # Replace label padding (-100) before decoding
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
        # the `evaluate` rouge implementation returns plain floats
        return {k: float(v * 100) for k, v in result.items()}

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["validation"],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )

    trainer.train()
    trainer.save_model(args.save_dir)
    tokenizer.save_pretrained(args.save_dir)
    print(f"Saved response generator to {args.save_dir}")

if __name__ == "__main__":
    main()
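EmpatheticDialogues' column naming is unintuitive (`context` is the emotion label, `prompt` the situation, `utterance` the reply), so a quick sanity check before training is worthwhile (assumes Hub access and a `datasets` version that still supports script-based datasets):

```python
from datasets import load_dataset

ds = load_dataset("empathetic_dialogues", split="train[:3]")
print(ds.column_names)  # expect 'context', 'prompt', 'utterance', among others
print(ds[0]["context"], "->", ds[0]["utterance"])
```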
train_text_emotion.py
ADDED
@@ -0,0 +1,72 @@
# train_text_emotion.py
import argparse
from datasets import load_dataset
from transformers import (DistilBertTokenizerFast, DistilBertForSequenceClassification,
                          Trainer, TrainingArguments)
import numpy as np
import evaluate

def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--save-dir", type=str, default="./emo-text-emotion")
    p.add_argument("--num_train_epochs", type=int, default=3)
    p.add_argument("--per_device_train_batch_size", type=int, default=16)
    return p.parse_args()

def main():
    args = parse_args()
    dataset = load_dataset("go_emotions")

    # Simplify multi-label to single-label for the demo: pick the first label.
    def to_single_label(example):
        labels = example.get("labels", [])
        example["label"] = labels[0] if labels else 27  # 27 = neutral in GoEmotions
        return example

    # Drop the original multi-label "labels" column here so the later rename
    # of "label" to "labels" does not collide with it.
    dataset = dataset.map(to_single_label, remove_columns=["labels"])

    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

    def preprocess(examples):
        return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)

    tokenized = dataset.map(preprocess, batched=True)
    tokenized = tokenized.rename_column("label", "labels")
    tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=28)

    metric = evaluate.load("accuracy")
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=-1)
        return metric.compute(predictions=preds, references=labels)

    training_args = TrainingArguments(
        output_dir=args.save_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=64,
        num_train_epochs=args.num_train_epochs,
        weight_decay=0.01,
        logging_steps=200
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["validation"],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )

    trainer.train()
    trainer.save_model(args.save_dir)
    tokenizer.save_pretrained(args.save_dir)
    print(f"Saved text emotion model to {args.save_dir}")

if __name__ == "__main__":
    main()
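To map the classifier's integer predictions back to GoEmotions label names, one option (an assumed approach, relying on the dataset's `ClassLabel` feature) is:

```python
from datasets import load_dataset

# The "labels" feature is a Sequence of ClassLabel; int2str converts ids to names.
features = load_dataset("go_emotions", split="validation").features
id2name = features["labels"].feature.int2str
print(id2name(27))  # "neutral"
```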
utils.py
ADDED
@@ -0,0 +1,50 @@
# utils.py
from typing import Optional, List
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import numpy as np

# GoEmotions label set, reused as zero-shot text prompts for CLIP
EMOTION_LABELS = [
    "admiration","amusement","anger","annoyance","approval","caring","confusion","curiosity",
    "desire","disappointment","disapproval","disgust","embarrassment","excitement","fear",
    "gratitude","grief","joy","love","nervousness","optimism","pride","realization","relief",
    "remorse","sadness","surprise","neutral"
]

_clip_cache = None

def load_clip():
    # Cache the model/processor so repeated calls don't reload them from disk
    global _clip_cache
    if _clip_cache is None:
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        _clip_cache = (model, processor)
    return _clip_cache

def predict_image_emotion(image_path: str, top_k: int = 3):
    """Zero-shot emotion detection using CLIP: rank emotion labels by image-text similarity."""
    model, processor = load_clip()
    img = Image.open(image_path).convert("RGB")
    inputs = processor(text=EMOTION_LABELS, images=img, return_tensors="pt", padding=True)
    with torch.no_grad():
        image_features = model.get_image_features(pixel_values=inputs["pixel_values"])
        text_features = model.get_text_features(input_ids=inputs["input_ids"],
                                                attention_mask=inputs["attention_mask"])
    # normalize and compute cosine similarity
    img_feat = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
    txt_feat = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
    sims = (img_feat @ txt_feat.T).squeeze(0).cpu().numpy()
    idx = np.argsort(-sims)[:top_k]
    return [EMOTION_LABELS[i] for i in idx]

def combine_emotions(text_emotion: Optional[str], image_emotions: Optional[List[str]]):
    """Merge text and image emotions into a deduplicated, comma-separated string."""
    parts = []
    if text_emotion:
        parts.append(text_emotion)
    if image_emotions:
        parts.extend(image_emotions[:2])
    seen = set()
    combined = []
    for p in parts:
        if p not in seen:
            combined.append(p)
            seen.add(p)
    if not combined:
        return "neutral"
    return ", ".join(combined)
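Example usage of the helpers above (the image path is hypothetical, matching `example_inputs/example_images.txt`; any RGB-convertible image works):

```python
from utils import predict_image_emotion, combine_emotions

top = predict_image_emotion("./examples/sad_person.jpg", top_k=3)
print(top)  # e.g. ["sadness", "grief", "neutral"] — ordering depends on the image
print(combine_emotions("sadness", top))  # "sadness, grief" (top-2 image emotions, deduped)
```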