"""Gradio app that scores how likely an image is to be a "base body".

At import time the script loads a CLIP model, a JSON file of text prompt
pairs and a pickled classifier, pre-computes the CLIP text features for every
prompt, and then builds and launches a Gradio interface around `predict_fn`.
"""

import json
import logging
import os
import pickle

import clip
import gradio as gr
import torch
from PIL import Image
from scipy.special import softmax

logger = logging.getLogger("basebody")

CLIP_MODEL_NAME = "ViT-B/16"
TEXT_PROMPTS_FILE_NAME = "text_prompts.json"
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"
HF_TOKEN = os.getenv('HF_TOKEN')

# Every entry of the prompts file must hold exactly this many prompts; the
# softmax in `predict_fn` is taken across the prompts of one entry.
_PROMPTS_PER_ENTRY = 2

# Directory containing this script; the prompt and model files live next to it.
_APP_DIR = os.path.dirname(__file__)

# User-facing text shown on the interface. Reproduced verbatim.
# NOTE(review): "NOT as base body" looks like a typo for "NOT a base body";
# left unchanged because this string is displayed to users.
_DESCRIPTION = (
    " The \n"
    "model returns the probability of the image being a base body. "
    "If probability > 0.95, the image can be automatically tagged as a "
    "base body. "
    "If probability < 0.4, the image can be automatically REJECTED as NOT "
    "as base body. "
    "All other cases will be submitted for moderation. "
    "To minimize noise, please flag the input only if you think there is "
    "an error in the probability returned by the model and it is off by "
    "at least 0.2 "
)

# Flagged samples are written to this Hugging Face dataset.
hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "Roblox/basebody_feedback")

clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")

with open(
    os.path.join(_APP_DIR, TEXT_PROMPTS_FILE_NAME), "r", encoding="utf-8"
) as f:
    text_prompts = json.load(f)

# SECURITY: pickle.load executes arbitrary code embedded in the file. This is
# only acceptable because the model file ships alongside this script; never
# point it at an untrusted path.
with open(
    os.path.join(_APP_DIR, LOGISTIC_REGRESSION_MODEL_FILE_NAME), "rb"
) as f:
    lr_model = pickle.load(f)

# NOTE(review): the message announces coefficients but none are logged;
# text kept as-is.
logger.info("Logistic regression model loaded, coefficients: ")

# Encode every prompt once up front so each prediction only has to encode the
# image. Rows are stacked in the iteration order of `text_prompts`, with the
# prompts of one entry adjacent to each other.
all_text_features = []
with torch.no_grad():
    for prompt_key, prompts in text_prompts.items():
        # Explicit check instead of `assert`, which is stripped under -O and
        # would let a malformed prompts file silently corrupt the reshape
        # performed in `predict_fn`.
        if len(prompts) != _PROMPTS_PER_ENTRY:
            raise ValueError(
                f"Prompt entry {prompt_key!r} must contain exactly "
                f"{_PROMPTS_PER_ENTRY} prompts, got {len(prompts)}"
            )
        all_text_features.append(
            clip_model.encode_text(clip.tokenize(prompts))
        )
all_text_features = torch.cat(all_text_features, dim=0).cpu()


def predict_fn(input_img) -> bytes:
    """Return the base-body probability for an image as UTF-8 encoded JSON.

    The image is encoded with CLIP and multiplied against the pre-computed
    text features. For each prompt entry a softmax is taken across its
    prompts and the value for the first prompt is kept; those values form the
    feature vector passed to the classifier's ``predict_proba``.

    Args:
        input_img: Image array delivered by the Gradio ``"image"`` input
            (presumably height x width x 3 RGB -- TODO confirm). ``None``
            when no image was supplied.

    Returns:
        ``{"is_base_body": <probability rounded to 3 decimals>}`` serialized
        with ``json.dumps`` and encoded to UTF-8 bytes.

    Raises:
        gr.Error: If no image was supplied.
    """
    if input_img is None:
        # Without this guard the call fails on `None.astype` with an opaque
        # AttributeError; gr.Error surfaces a readable message in the UI.
        raise gr.Error("Please provide an image before submitting.")

    pil_image = Image.fromarray(input_img.astype("uint8"), "RGB")
    image_batch = preprocess(pil_image).unsqueeze(0)  # add a batch dimension

    with torch.no_grad():
        image_features = clip_model.encode_image(image_batch)

    # One score per prompt row, regrouped so that axis 1 indexes the prompts
    # belonging to the same entry.
    logits = all_text_features @ image_features.cpu().T
    paired_logits = logits.squeeze().reshape(
        len(text_prompts), _PROMPTS_PER_ENTRY, -1
    )
    # Softmax across the prompts of each entry, keeping the first prompt's
    # share.
    first_prompt_probs = softmax(paired_logits, axis=1)[:, 0, :]
    logger.info("cosine_simlarities: %s", first_prompt_probs)

    probabilities = lr_model.predict_proba(first_prompt_probs.reshape(1, -1))
    logger.info("probabilities: %s", probabilities)

    # NOTE(review): this returns bytes to an output declared as "text";
    # confirm the displayed value is what consumers expect before changing it.
    decision_json = json.dumps(
        {"is_base_body": float(probabilities[0][1].round(3))}
    ).encode("utf-8")
    logger.info("decision_json: %s", decision_json)
    return decision_json


iface = gr.Interface(
    fn=predict_fn,
    inputs="image",
    outputs="text",
    description=_DESCRIPTION,
    allow_flagging="manual",
    flagging_options=[
        "probability error > 0.5",
        "0.2 < probability error < 0.5",
        "probability error < 0.2",
    ],
    flagging_callback=hf_writer,
)
iface.launch()