import os import pandas as pd from PIL import Image, UnidentifiedImageError import torch from torchvision import transforms from transformers import AutoProcessor, FocalNetForImageClassification import pyarrow as pa import pyarrow.parquet as pq # 画像フォルダとモデルのパスを指定 image_folder = "scraped_images" # 画像フォルダのパス model_path = "MichalMlodawski/nsfw-image-detection-large" # NSFWモデルのパス # サブフォルダを含めてjpgファイルを再帰的に取得 jpg_files = [] for root, dirs, files in os.walk(image_folder): for file in files: if file.lower().endswith(".jpg"): jpg_files.append(os.path.join(root, file)) # jpgファイルが存在するか確認 if not jpg_files: print("No jpg files found in folder:", image_folder) exit() # モデルとプロセッサの読み込み feature_extractor = AutoProcessor.from_pretrained(model_path) model = FocalNetForImageClassification.from_pretrained(model_path) model.eval() # 画像の変換処理 transform = transforms.Compose([ transforms.Resize((512, 512)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # ラベルとNSFWカテゴリのマッピング label_to_category = { "LABEL_0": "Safe", "LABEL_1": "Questionable", "LABEL_2": "Unsafe" } # 結果を保存するためのリスト results = [] # ログファイルを作成(破損画像ファイルを記録) error_log = "error_log.txt" # 各画像に対して分類処理を行い、結果を取得 for jpg_file in jpg_files: try: # 画像を開く image = Image.open(jpg_file).convert("RGB") except UnidentifiedImageError: # 画像を識別できない場合のエラーハンドリング with open(error_log, "a", encoding="utf-8") as log_file: log_file.write(f"Unidentified image file: {jpg_file}. Skipping...\n") print(f"Unidentified image file: {jpg_file}. Skipping...") continue image_tensor = transform(image).unsqueeze(0) # モデルでの推論 inputs = feature_extractor(images=image, return_tensors="pt") with torch.no_grad(): outputs = model(**inputs) probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1) confidence, predicted = torch.max(probabilities, 1) # ラベルを取得 label = model.config.id2label[predicted.item()] category = label_to_category.get(label, "Unknown") # 結果をリストに追加 results.append({ "file_path": jpg_file, "label": label, "category": category, "confidence": confidence.item() * 100 }) # 結果をDataFrameに変換 df = pd.DataFrame(results) # Parquet形式で保存 parquet_file = "nsfw_classification_results.parquet" table = pa.Table.from_pandas(df) pq.write_table(table, parquet_file) print(f"Classification completed and saved to {parquet_file}!")