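"""Label Vietnamese hate speech (ViHSD) CSV files with Google Generative AI.

For every CSV in --input_dir, the script asks Gemini to rate each row's text on
five categories (individual, groups, religion/creed, race/ethnicity, politics),
word-segments the Vietnamese content, and writes the result to --output_dir.
"""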
import argparse
import glob
import os
import time

import pandas as pd
from google import genai
from tqdm import tqdm

from word_segmentation_vi import word_segmentation_vi


def setup_genai(api_key):
    """Create a Google Generative AI client authenticated with the given API key."""
    return genai.Client(api_key=api_key)
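
# Minimal usage sketch (the key string here is a placeholder): with the
# google-genai SDK, the returned client issues requests directly, e.g.
#   client = setup_genai("YOUR_API_KEY")
#   client.models.generate_content(model="gemini-2.0-flash", contents="Xin chào")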


def classify_text(client, text, suggest_label=False):
    """Classify Vietnamese text into hate speech categories using Google's Generative AI."""
    prompt = f"""
    Analyze the following Vietnamese text for hate speech (each sentence is separated by a newline):
    "{text}"

    Rate it on these categories (0=NORMAL, 1=CLEAN, 2=OFFENSIVE, 3=HATE):
    - individual (targeting specific individuals)
    - groups (targeting groups or organizations)
    - religion/creed (targeting religious groups or beliefs)
    - race/ethnicity (racial/ethnic hate speech)
    - politics (political hate speech)
    If the text doesn't specify a person or group in a category, return 0 for that category.
    Otherwise, return 1 for CLEAN, 2 for OFFENSIVE, or 3 for HATE.

    {'The number at the end of each sentence (between <SuggestLabel> and </SuggestLabel> tags) is the suggested label for that sentence (0 is normal/clean, 1 is offensive/hate in at least one category).' if suggest_label else ''}

    For each sentence in the text, return only 5 numbers separated by commas (the labels for individual, groups, religion/creed, race/ethnicity, politics), with the numbers for each sentence separated by newlines, like this (with no other text):
    0,1,0,0,0
    1,0,0,0,2
    """

    try:
        response = client.models.generate_content(model="gemini-2.0-flash", contents=prompt)
        # Expect one line per input sentence, each with five comma-separated labels.
        lines = response.text.strip().split('\n')
        return [line.split(',') for line in lines]
    except Exception as e:
        print(f"Error classifying text: {e}")
        return None
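
# For a two-sentence batch, the model is expected to reply with, e.g.:
#   0,1,0,0,0
#   1,0,0,0,2
# which classify_text parses into [['0', '1', '0', '0', '0'], ['1', '0', '0', '0', '2']].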


def process_file(input_file, output_file, client, rate_limit_pause=4, text_col="free_text", suggest_column="labels"):
    """Process a single CSV file to match the test.csv format."""
    print(f"Processing {input_file}...")

    try:
        df = pd.read_csv(input_file)
    except Exception as e:
        print(f"Error reading {input_file}: {e}")
        return

    # Normalize the text column name to 'content'.
    if text_col in df.columns:
        df.rename(columns={text_col: 'content'}, inplace=True)
    elif 'content' not in df.columns:
        print(f"Error: 'content' column not found in {input_file}")
        return

    # Make sure every category column exists, defaulting to 0 (NORMAL).
    category_columns = ['individual', 'groups', 'religion/creed', 'race/ethnicity', 'politics']
    for col in category_columns:
        if col not in df.columns:
            df[col] = 0

    print("Suggesting labels:", suggest_column in df.columns)

    # Classify in batches so each API call covers many sentences at once.
    batch_size = 100
    for start in tqdm(range(0, len(df), batch_size), desc="Processing batches"):
        end = min(start + batch_size, len(df))
        batch_df = df.iloc[start:end]

        # Skip batches that are already fully labeled (non-zero in every category).
        if all((batch_df[cat] != 0).all() for cat in category_columns):
            continue

        batch_strings = [str(sentence) for sentence in batch_df['content'].tolist()]
        suggest_label = False
        if suggest_column in df.columns:
            # Append the existing label as a hint: "<sentence> <SuggestLabel><label></SuggestLabel>".
            batch_strings = [f"{sentence} <SuggestLabel>{label}</SuggestLabel>"
                             for sentence, label in zip(batch_strings, batch_df[suggest_column].tolist())]
            suggest_label = True

        text_to_classify = "\n".join(batch_strings)

        # Try up to three times: a batch can fail at the API call (None result)
        # or while parsing the model's reply into integer labels.
        for _ in range(3):
            classifications = classify_text(client, text_to_classify, suggest_label=suggest_label)
            if classifications is not None:
                try:
                    # Ignore any extra reply lines beyond the batch size.
                    for i, row in enumerate(classifications[:end - start]):
                        for j, col in enumerate(category_columns):
                            df.at[start + i, col] = int(row[j])
                    break
                except Exception as e:
                    print(f"Error updating DataFrame: {e}")
            time.sleep(rate_limit_pause)
        else:
            print(f"Error classifying batch starting at index {start}. Skipping...")
            continue

        time.sleep(rate_limit_pause)

    # Word-segment the Vietnamese content before saving.
    df['content'] = df['content'].apply(lambda x: word_segmentation_vi(str(x)))

    # Make sure the category labels are integers.
    for col in category_columns:
        df[col] = df[col].astype(int)

    if 'label_id' in df.columns:
        df.drop(columns=['label_id'], inplace=True)
    df.to_csv(output_file, index=False)
    print(f"Saved processed file to {output_file}")


def main():
    parser = argparse.ArgumentParser(description="Process ViHSD CSV files with Google Generative AI")
    parser.add_argument("--input_dir", required=True, help="Directory containing input CSV files")
    parser.add_argument("--output_dir", required=True, help="Directory to save processed files")
    parser.add_argument("--api_key", required=True, help="Google Generative AI API key")
    parser.add_argument("--pause", type=float, default=4.0, help="Pause between API calls (seconds)")
    parser.add_argument("--text_col", default="free_text", help="Column name for text content in input CSV files")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    client = setup_genai(args.api_key)

    csv_files = glob.glob(os.path.join(args.input_dir, "*.csv"))
    if not csv_files:
        print(f"No CSV files found in {args.input_dir}")
        return
    print(f"Found {len(csv_files)} CSV files to process")

    for input_file in csv_files:
        output_file = os.path.join(args.output_dir, os.path.basename(input_file))
        if os.path.exists(output_file):
            print(f"Output file {output_file} already exists. Skipping...")
            continue
        process_file(input_file, output_file, client, args.pause, text_col=args.text_col)


if __name__ == "__main__":
    main()
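
# Example invocation (hypothetical paths; assumes the script is saved as
# process_vihsd.py):
#   python process_vihsd.py --input_dir data/raw --output_dir data/processed \
#       --api_key YOUR_API_KEY --pause 4 --text_col free_text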