hk-bt-rnd committed
Commit ad5ee12
Parent: c5550af

Init space

Files changed (5)
  1. __pycache__/model.cpython-310.pyc +0 -0
  2. app.py +87 -0
  3. model.py +37 -0
  4. requirements.txt +76 -0
  5. weight.pt +3 -0
__pycache__/model.cpython-310.pyc ADDED
Binary file (1.24 kB).
 
app.py ADDED
@@ -0,0 +1,87 @@
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ from matplotlib import cm
+ import torch
+ from transformers import AutoTokenizer, AutoModel, AutoConfig
+ from model import Classifier
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # Load the base encoder, tokenizer, and fine-tuned weights
+ MODEL_NAME = "cahya/roberta-base-indonesian-522M"
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ class_names = ['Action', 'Adventure', 'Comedy', 'Drama', 'Fantasy', 'Romance', 'Sci-Fi']
+ config = AutoConfig.from_pretrained(MODEL_NAME)
+ transformer = AutoModel.from_pretrained(MODEL_NAME)
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ # Checkpoint holds the fine-tuned transformer weights ('w_t') and classifier-head weights ('w_c')
+ cp = torch.load("weight.pt", map_location="cpu")
+ transformer.load_state_dict(cp['w_t'])
+ classifier = Classifier(input_size=config.hidden_size, output_sizes=[1, 1, 1, 3, 5])
+ classifier.load_state_dict(cp['w_c'])
+
+ transformer.to(device)
+ classifier.to(device)
+
+ target_names = ["Individual", "Group"]
+ strength_names = ["Weak", "Moderate", "Strong"]
+ type_names = ["Religion", "Race", "Physical", "Gender", "Other"]
+
+ act_sig = nn.Sigmoid()
+ act_soft = nn.Softmax(dim=-1)
+
+ def predict(sentence):
+     # Tokenize the input sentence
+     inputs = tokenizer(sentence,
+                        add_special_tokens=True,
+                        max_length=256,
+                        padding="max_length",
+                        truncation=True,
+                        return_tensors='pt')
+
+     input_ids = inputs['input_ids'].to(device)
+     att_masks = inputs['attention_mask'].to(device)
+
+     # Get model predictions
+     with torch.no_grad():
+         out = transformer(input_ids, attention_mask=att_masks)
+         pooled = out.pooler_output
+         out = classifier(pooled)
+     hs_out, abusive_out, target_out, strength_out, type_out = out[0], out[1], out[2], out[3], out[4]
+     hs_act, abusive_act, target_act, strength_act, type_act = act_sig(hs_out).squeeze(), \
+         act_sig(abusive_out).squeeze(), act_sig(target_out).squeeze(0), act_soft(strength_out), act_sig(type_out).squeeze(0)
+
+     # Interpret the predictions
+     is_hate_speech = bool(hs_act >= 0.5)
+     is_abusive = bool(abusive_act >= 0.5)
+     hate_speech_target = int(target_act >= 0.5)
+     hate_speech_strength = strength_act.argmax().item()
+     if is_hate_speech:
+         hate_speech_target_label = target_names[hate_speech_target]
+         hate_speech_strength_label = strength_names[hate_speech_strength]
+         hate_speech_type_label = []
+         print('target', target_act)
+         print('strength', strength_act)
+
+         for idx, prob in enumerate(type_act):
+             if prob >= 0.5:
+                 hate_speech_type_label.append(type_names[idx])
+         if len(hate_speech_type_label) == 0:
+             hate_speech_type_label.append("Other")
+     else:
+         hate_speech_target_label = "Non-HS"
+         hate_speech_strength_label = "Non-HS"
+         hate_speech_type_label = "Non-HS"
+
+     return is_hate_speech, is_abusive, hate_speech_target_label, hate_speech_strength_label, {"hs_type": hate_speech_type_label}
+
+ # Create the Gradio interface
+ iface = gr.Interface(fn=predict, inputs=gr.Textbox(label="Enter a sentence"), outputs=[
+     gr.Label(label="Is Hate Speech"),
+     gr.Label(label="Is Abusive"),
+     gr.Label(label="Hate Speech Target"),
+     gr.Label(label="Hate Speech Strength"),
+     gr.JSON(label="Hate Speech Type")
+ ], title="Hate Speech Detection")
+ iface.launch()  # Launches the mini app!
+
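Once app.py is running, the interface can also be queried programmatically. Below is a minimal sketch using gradio_client (already pinned in requirements.txt), assuming the app is reachable at Gradio's default local URL and exposes the auto-generated /predict endpoint; the URL and the example sentence are placeholders.

from gradio_client import Client

# Assumption: the Space is running locally on the default port; swap in the
# hosted Space id (e.g. "user/space-name") if querying it on the Hub.
client = Client("http://127.0.0.1:7860/")

# The single Textbox input maps to one positional argument; gr.Interface
# auto-names the endpoint "/predict".
result = client.predict("contoh kalimat", api_name="/predict")

# Outputs arrive in the order declared in app.py:
# (is_hate_speech, is_abusive, target_label, strength_label, {"hs_type": ...})
print(result)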
model.py ADDED
@@ -0,0 +1,37 @@
+ import torch.nn as nn
+ import torchvision.models as models
+ import torch
+ from transformers import AutoTokenizer, AutoModel, AutoConfig
+
+ # Five parallel classification heads over a shared pooled representation
+ class Classifier(nn.Module):
+     def __init__(self, input_size=512, output_sizes=[1], dropout_rate=0.1):
+         super(Classifier, self).__init__()
+
+         self.hs_head = nn.Sequential(
+             nn.Dropout(dropout_rate),
+             nn.Linear(input_size, output_sizes[0])
+         )
+
+         self.abusive_head = nn.Sequential(
+             nn.Dropout(dropout_rate),
+             nn.Linear(input_size, output_sizes[1])
+         )
+
+         self.target_head = nn.Sequential(
+             nn.Dropout(dropout_rate),
+             nn.Linear(input_size, output_sizes[2])
+         )
+
+         self.strength_head = nn.Sequential(
+             nn.Dropout(dropout_rate),
+             nn.Linear(input_size, output_sizes[3])
+         )
+
+         self.type_head = nn.Sequential(
+             nn.Dropout(dropout_rate),
+             nn.Linear(input_size, output_sizes[4])
+         )
+
+     def forward(self, input):
+         return self.hs_head(input), self.abusive_head(input), self.target_head(input), \
+             self.strength_head(input), self.type_head(input)
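For a quick sanity check of the head dimensions, the sketch below instantiates Classifier the same way app.py does (hidden size 768 for the RoBERTa-base encoder, output sizes [1, 1, 1, 3, 5]) and feeds it a dummy pooled vector; the random input is illustrative only.

import torch
from model import Classifier

# Same configuration as app.py: 768-dim pooled input, five task heads.
clf = Classifier(input_size=768, output_sizes=[1, 1, 1, 3, 5])
clf.eval()

pooled = torch.randn(1, 768)  # stand-in for the transformer's pooler_output
with torch.no_grad():
    hs, abusive, target, strength, hs_type = clf(pooled)

print(hs.shape, abusive.shape, target.shape)  # torch.Size([1, 1]) each
print(strength.shape)                         # torch.Size([1, 3])
print(hs_type.shape)                          # torch.Size([1, 5])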
requirements.txt ADDED
@@ -0,0 +1,76 @@
+ aiofiles==23.2.1
+ altair==5.2.0
+ annotated-types==0.6.0
+ anyio==4.3.0
+ attrs==23.2.0
+ certifi==2024.2.2
+ charset-normalizer==3.3.2
+ click==8.1.7
+ colorama==0.4.6
+ contourpy==1.2.0
+ cycler==0.12.1
+ exceptiongroup==1.2.0
+ fastapi==0.110.0
+ ffmpy==0.3.2
+ filelock==3.13.1
+ fonttools==4.50.0
+ fsspec==2024.3.1
+ gradio==4.22.0
+ gradio_client==0.13.0
+ h11==0.14.0
+ httpcore==1.0.4
+ httpx==0.27.0
+ huggingface-hub==0.21.4
+ idna==3.6
+ importlib_resources==6.3.2
+ Jinja2==3.1.3
+ jsonschema==4.21.1
+ jsonschema-specifications==2023.12.1
+ kiwisolver==1.4.5
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ matplotlib==3.8.3
+ mdurl==0.1.2
+ mpmath==1.3.0
+ networkx==3.2.1
+ numpy==1.26.4
+ orjson==3.9.15
+ packaging==24.0
+ pandas==2.2.1
+ pillow==10.2.0
+ pydantic==2.6.4
+ pydantic_core==2.16.3
+ pydub==0.25.1
+ Pygments==2.17.2
+ pyparsing==3.1.2
+ python-dateutil==2.9.0.post0
+ python-multipart==0.0.9
+ pytz==2024.1
+ PyYAML==6.0.1
+ referencing==0.34.0
+ regex==2023.12.25
+ requests==2.31.0
+ rich==13.7.1
+ rpds-py==0.18.0
+ ruff==0.3.3
+ safetensors==0.4.2
+ semantic-version==2.10.0
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.1
+ starlette==0.36.3
+ sympy==1.12
+ tokenizers==0.15.2
+ tomlkit==0.12.0
+ toolz==0.12.1
+ torch==2.2.1
+ torchaudio==2.2.1
+ torchvision==0.17.1
+ tqdm==4.66.2
+ transformers==4.38.2
+ typer==0.9.0
+ typing_extensions==4.10.0
+ tzdata==2024.1
+ urllib3==2.2.1
+ uvicorn==0.29.0
+ websockets==11.0.3
weight.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eb083732f7dd150113bba50f7f5125a4d3b83adf98db912d64394d38d9290e1b
+ size 504022203