|
import gradio as gr |
|
import onnxruntime as ort |
|
from transformers import RobertaTokenizer, ViTImageProcessor |
|
from PIL import Image |
|
import numpy as np |
|
import torch |
|
import os |
|
import time |
|
import logging |
|
|
|
|
|
# Root logging: INFO and above, each record stamped with time, logger name
# and level; `logger` is this module's named logger.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
# Pretrained preprocessing for the two modalities the model consumes:
# pixel_values for the image branch and token ids / attention mask for the
# text branch.
vit_processor = ViTImageProcessor.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)
tokenizer = RobertaTokenizer.from_pretrained(
    "roberta-base"
)
|
|
|
model_path = "./multimodal_model.onnx"

try:
    # Fail fast with a clear message instead of ONNX Runtime's generic error.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"ONNX model not found at {model_path}")

    logger.info(f"Loading ONNX model from {model_path}")
    sess_options = ort.SessionOptions()
    # ONNX Runtime severity levels: 0=VERBOSE, 1=INFO, 2=WARNING, 3=ERROR,
    # 4=FATAL. A hard-coded 0 makes ORT extremely noisy on every run, so
    # default to WARNING and let ORT_LOG_SEVERITY re-enable verbose output
    # when debugging.
    sess_options.log_severity_level = int(os.environ.get("ORT_LOG_SEVERITY", "2"))
    ort_session = ort.InferenceSession(
        model_path,
        sess_options=sess_options,
        providers=["CPUExecutionProvider"],
    )
    logger.info("ONNX model loaded successfully")

    # Query the graph signature once and log it for diagnostics.
    input_metas = ort_session.get_inputs()
    input_names = [meta.name for meta in input_metas]
    input_shapes = {meta.name: meta.shape for meta in input_metas}
    output_names = [meta.name for meta in ort_session.get_outputs()]

    logger.info(f"Model inputs: {input_names} with shapes {input_shapes}")
    logger.info(f"Model outputs: {output_names}")

except Exception as e:
    # Record the full traceback, then re-raise: the app cannot serve
    # predictions without a loaded session.
    logger.exception(f"Error loading ONNX model: {e}")
    raise
|
|
|
# Class names in model-output order: predict_news maps the argmax index of
# the probabilities straight into this list.
labels = [
    "Real",
    "Real Text with fake image",
    "Fake",
]
|
|
|
def softmax(x):
    """Return softmax probabilities computed independently along axis 1.

    Each slice along axis 1 is shifted by its own maximum before
    exponentiation so that large scores cannot overflow ``np.exp``; the
    result has the same shape as *x* and sums to 1 along axis 1.
    """
    row_max = np.max(x, axis=1, keepdims=True)
    exps = np.exp(x - row_max)
    totals = np.sum(exps, axis=1, keepdims=True)
    return exps / totals
|
|
|
def image_with_prediction(img, label, confidence):
    """Return a copy of *img* with the prediction drawn in a bottom band.

    Args:
        img: Source PIL image; it is not modified.
        label: Class name to display.
        confidence: Probability in [0, 1], rendered as a percentage.

    Returns:
        A new PIL image showing the original content with a translucent black
        band (40 px tall, or the full height for shorter images) along the
        bottom and ``"<label>: <confidence>"`` centred on it in white.
    """
    from PIL import Image, ImageDraw, ImageFont

    # The white RGB fill tuple and the RGBA overlay paste below assume a
    # colour image; modes such as "L" or "I" may reject them, so normalise.
    if img.mode in ("RGB", "RGBA"):
        img_copy = img.copy()
    else:
        img_copy = img.convert("RGB")

    width, height = img_copy.size
    # Never let the band extend above the top of a very short image.
    band_height = min(40, height)

    # Alpha 150/255 darkens the band while keeping the photo visible; the
    # overlay doubles as its own paste mask.
    overlay = Image.new("RGBA", (width, band_height), (0, 0, 0, 150))
    img_copy.paste(overlay, (0, height - band_height), overlay)

    # Create the drawing context only after the paste, so it targets the
    # final pixel buffer.
    draw = ImageDraw.Draw(img_copy)
    text = f"{label}: {confidence:.1%}"

    # "arial.ttf" is typically absent on Linux hosts; try a common
    # alternative before falling back to Pillow's small bitmap font.
    font = None
    for font_name in ("arial.ttf", "DejaVuSans.ttf"):
        try:
            font = ImageFont.truetype(font_name, 20)
            break
        except OSError:
            continue
    if font is None:
        font = ImageFont.load_default()

    try:
        text_width = draw.textlength(text, font=font)
    except AttributeError:
        # Older Pillow without ImageDraw.textlength(): use legacy getsize().
        text_width = font.getsize(text)[0] if hasattr(font, "getsize") else 200

    # Horizontally centred; 5 px below the top of the band (height - 35 for
    # the usual 40 px band).
    text_position = ((width - text_width) // 2, height - band_height + 5)
    draw.text(text_position, text, fill=(255, 255, 255), font=font)

    return img_copy
|
|
|
def predict_news(text, image):
    """Classify a (headline, image) pair with the multimodal ONNX model.

    Args:
        text: Headline or article text from the Gradio textbox (may be None).
        image: PIL image from the Gradio image component (may be None).

    Returns:
        A 3-tuple ``(scores, annotated_image, analysis)`` matching the
        ``[gr.Label, gr.Image, gr.HTML]`` outputs:

        - scores: dict mapping each entry of ``labels`` to a probability
          (all 0.0 when no prediction was made).
        - annotated_image: the image with the predicted label overlaid, or
          None when validation or inference fails.
        - analysis: HTML summary of the verdict and inference time, or a
          message explaining why no prediction was made.

    Errors raised during preprocessing or inference are logged and reported
    through the returned message rather than propagated.
    """
    import html  # escape exception text before embedding it in HTML output

    empty_scores = {label: 0.0 for label in labels}

    if not text or not text.strip():
        return empty_scores, None, "Please enter some text to analyze."

    if image is None:
        return empty_scores, None, "Please upload an image to analyze."

    try:
        logger.info(f"Processing text: {text[:50]}...")
        logger.info(f"Processing image size: {image.size}")

        # The ViT processor normalises with 3-channel statistics; uploads in
        # RGBA/L/P mode (e.g. transparent PNGs) may not fit that, so convert.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Fixed-length (80-token) encoding, presumably the sequence length the
        # ONNX graph was exported with -- TODO confirm against the export.
        encoded = tokenizer(
            text,
            add_special_tokens=True,
            return_tensors="np",
            max_length=80,
            truncation=True,
            padding="max_length",
        )
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]
        logger.info(f"Input IDs shape: {input_ids.shape}")
        logger.info(f"Attention mask shape: {attention_mask.shape}")

        image_processed = vit_processor(images=image, return_tensors="np")["pixel_values"]
        logger.info(f"Processed image shape: {image_processed.shape}")

        # Route each tensor to a graph input by name substring so minor naming
        # differences between exports still work.
        ort_inputs = {}
        for input_meta in ort_session.get_inputs():
            input_name = input_meta.name
            lowered = input_name.lower()
            if "ids" in lowered:
                ort_inputs[input_name] = input_ids
            elif "mask" in lowered:
                ort_inputs[input_name] = attention_mask
            elif "image" in lowered:
                ort_inputs[input_name] = image_processed
            else:
                # Make the mismatch visible here instead of only via ONNX
                # Runtime's missing-input error at run() time.
                logger.warning(f"No tensor mapped to ONNX input '{input_name}'")
        logger.info(f"ONNX input keys: {list(ort_inputs.keys())}")

        logger.info("Starting inference")
        start_time = time.perf_counter()  # monotonic clock for interval timing
        outputs = ort_session.run(None, ort_inputs)
        inference_time = time.perf_counter() - start_time
        logger.info(f"Inference completed in {inference_time:.3f}s")

        logits = outputs[0]
        logger.info(f"Raw output shape: {logits.shape}, values: {logits}")

        # A single example is fed, so keep only the first row.
        probs = softmax(logits)[0]
        logger.info(f"Probabilities: {probs}")

        pred_idx = int(np.argmax(probs))
        confidence = float(probs[pred_idx])

        # Hex colours (not CSS names) so that appending "15" below forms a
        # valid #RRGGBBAA background tint; "orange15" etc. is invalid CSS and
        # was silently ignored. gr.HTML does not render Markdown, so the
        # verdict is emphasised with <strong> rather than **...**.
        if pred_idx == 1:
            color = "#FFA500"
            message = f"This content appears to be <strong>REAL TEXT WITH FAKE IMAGE</strong> with {confidence:.1%} confidence."
        elif pred_idx == 2:
            color = "#FF0000"
            message = f"This content appears to contain <strong>FAKE</strong> with {confidence:.1%} confidence."
        else:
            color = "#008000"
            message = f"This content appears to be <strong>REAL</strong> with {confidence:.1%} confidence."

        analysis = f"""
        <div style='text-align: center; padding: 10px; background-color: {color}15; border-radius: 5px; margin-top: 10px;'>
            <span style='font-size: 18px; color: {color}; font-weight: bold;'>{message}</span>
            <p>Inference time: {inference_time:.3f} seconds</p>
        </div>
        """

        result = {label: float(prob) for label, prob in zip(labels, probs)}

        interpretation = image_with_prediction(image, labels[pred_idx], confidence)

        return result, interpretation, analysis

    except Exception as e:
        # UI boundary: log the traceback and show the error to the user.
        logger.exception(f"Error during analysis: {e}")
        return empty_scores, None, f"Error during analysis: {html.escape(str(e))}"
|
|
|
# (headline, image URL) pairs offered as clickable examples in the UI.
examples = [
    [
        "COVID-19 vaccine causes severe side effects in 80% of recipients",
        "https://images.unsplash.com/photo-1605289982774-9a6fef564df8?q=80&w=1000&auto=format&fit=crop",
    ],
    [
        "Scientists discover new species of deep-sea fish",
        "https://images.unsplash.com/photo-1524704796725-9fc3044a58b2?q=80&w=1000&auto=format&fit=crop",
    ],
]
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # --- Page header -------------------------------------------------------
    gr.Markdown(
        """
        # 📰 Fake News Detector (RoBERTa + ViT)

        This multimodal AI system analyzes both text and images to detect potentially fake news content.
        Upload an image and enter a news headline to see if the combination is likely to be real or fake news.
        """
    )

    with gr.Row():
        # Left column: the two model inputs plus the trigger button.
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="News Headline / Text",
                placeholder="Enter the news headline or text here...",
                lines=3,
            )
            image_input = gr.Image(type="pil", label="Associated Image")
            analyze_btn = gr.Button("Analyze Content", variant="primary")

        # Right column: outputs, in the order predict_news returns them.
        with gr.Column(scale=1):
            label_output = gr.Label(label="Prediction Probabilities")
            image_output = gr.Image(type="pil", label="Visual Analysis")
            analysis_html = gr.HTML(label="Analysis")

    # cache_examples=True asks Gradio to precompute each example's outputs
    # with predict_news (which needs the example image URLs to be reachable).
    gr.Examples(
        fn=predict_news,
        examples=examples,
        inputs=[text_input, image_input],
        outputs=[label_output, image_output, analysis_html],
        cache_examples=True,
    )

    # --- Footer: short description of the architecture ---------------------
    gr.Markdown(
        """
        This system combines:
        - **RoBERTa**: Analyzes the textual content
        - **ViT**: Processes the image data
        - **Multimodal Fusion**: Combines both signals to make a prediction

        The model was trained on the Fakeddit dataset containing real and fake news pairs with both text and images.
        """
    )

    # Wire the button to the classifier.
    analyze_btn.click(
        fn=predict_news,
        inputs=[text_input, image_input],
        outputs=[label_output, image_output, analysis_html],
    )
|
|
|
if __name__ == "__main__":
    # Serve the UI only when executed as a script; importing this module
    # builds `demo` without launching a server.
    logger.info("Starting Gradio application")
    demo.launch()