import gradio as gr
import onnxruntime as ort
from transformers import RobertaTokenizer, ViTImageProcessor
from PIL import Image
import numpy as np
import torch
import os
import time
import logging
# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
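# Load the Hugging Face preprocessors for both modalities. These checkpoints are
# assumed to match the ones used when the multimodal model was exported to ONNX.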
vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model_path = "./multimodal_model.onnx"
try:
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"ONNX model not found at {model_path}")
    logger.info(f"Loading ONNX model from {model_path}")
    sess_options = ort.SessionOptions()
    sess_options.log_severity_level = 0
    ort_session = ort.InferenceSession(
        model_path,
        sess_options=sess_options,
        providers=['CPUExecutionProvider']
    )
    logger.info("ONNX model loaded successfully")
    input_names = [inp.name for inp in ort_session.get_inputs()]
    input_shapes = {inp.name: inp.shape for inp in ort_session.get_inputs()}
    output_names = [out.name for out in ort_session.get_outputs()]
    logger.info(f"Model inputs: {input_names} with shapes {input_shapes}")
    logger.info(f"Model outputs: {output_names}")
except Exception as e:
    logger.error(f"Error loading ONNX model: {e}")
    raise
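# Class labels in model-output order; this ordering is assumed to match the label
# encoding used when the model was trained.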
labels = ["Real", "Real Text with fake image", "Fake"]
def softmax(x):
    """Compute softmax values for each set of scores in x."""
    e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
    return e_x / e_x.sum(axis=1, keepdims=True)
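# Example: softmax(np.array([[1.0, 2.0, 3.0]])) ≈ [[0.090, 0.245, 0.665]]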
def image_with_prediction(img, label, confidence):
    """Return a copy of the original image with an overlay showing the prediction."""
    from PIL import Image, ImageDraw, ImageFont
    img_copy = img.copy()
    draw = ImageDraw.Draw(img_copy)
    width, height = img_copy.size
    # Semi-transparent banner along the bottom edge of the image
    overlay = Image.new('RGBA', (width, 40), (0, 0, 0, 150))
    img_copy.paste(overlay, (0, height - 40), overlay)
    text = f"{label}: {confidence:.1%}"
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except IOError:
        font = ImageFont.load_default()
    try:
        text_width = draw.textlength(text, font=font)
    except AttributeError:
        # Fallback for older Pillow versions without Draw.textlength
        text_width = font.getsize(text)[0] if hasattr(font, 'getsize') else 200
    text_position = (int((width - text_width) // 2), height - 35)
    draw.text(text_position, text, fill=(255, 255, 255), font=font)
    return img_copy
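# predict_news returns (per-class probabilities, annotated image, HTML summary),
# matching the three Gradio output components wired up below.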
def predict_news(text, image):
    if text is None or text.strip() == "":
        return {labels[0]: 0.0, labels[1]: 0.0, labels[2]: 0.0}, None, "Please enter some text to analyze."
    if image is None:
        return {labels[0]: 0.0, labels[1]: 0.0, labels[2]: 0.0}, None, "Please upload an image to analyze."
    try:
        logger.info(f"Processing text: {text[:50]}...")
        logger.info(f"Processing image size: {image.size}")
        # Tokenize the text into fixed-length (80-token) input IDs and an attention mask
        inputs = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            return_tensors='np',
            max_length=80,
            truncation=True,
            padding='max_length'
        )
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        logger.info(f"Input IDs shape: {input_ids.shape}")
        logger.info(f"Attention mask shape: {attention_mask.shape}")
        # Preprocess the image into the ViT pixel_values tensor
        image_processed = vit_processor(images=image, return_tensors="np")["pixel_values"]
        logger.info(f"Processed image shape: {image_processed.shape}")
        # Map the session's expected input names to the prepared tensors; names are
        # matched heuristically because they depend on how the model was exported.
        ort_inputs = {}
        for input_meta in ort_session.get_inputs():
            input_name = input_meta.name
            if 'ids' in input_name.lower() or input_name == 'text_input_ids':
                ort_inputs[input_name] = input_ids
            elif 'mask' in input_name.lower() or input_name == 'text_attention_mask':
                ort_inputs[input_name] = attention_mask
            elif 'image' in input_name.lower() or input_name == 'image_input':
                ort_inputs[input_name] = image_processed
        logger.info(f"ONNX input keys: {list(ort_inputs.keys())}")
        # Run inference
        start_time = time.time()
        logger.info("Starting inference")
        outputs = ort_session.run(None, ort_inputs)
        inference_time = time.time() - start_time
        logger.info(f"Inference completed in {inference_time:.3f}s")
        # Convert raw logits to class probabilities
        logits = outputs[0]
        logger.info(f"Raw output shape: {logits.shape}, values: {logits}")
        probs = softmax(logits)[0]
        logger.info(f"Probabilities: {probs}")
        pred_idx = int(np.argmax(probs))
        confidence = float(probs[pred_idx])
        # Hex colors so the "15" alpha suffix below yields a valid 8-digit CSS color;
        # the enclosing span is already bold, so the message needs no markdown markers.
        if pred_idx == 1:
            color = "#e67e22"  # orange
            message = f"This content appears to be REAL TEXT WITH FAKE IMAGE with {confidence:.1%} confidence."
        elif pred_idx == 2:
            color = "#d9534f"  # red
            message = f"This content appears to be FAKE with {confidence:.1%} confidence."
        else:
            color = "#2e8b57"  # green
            message = f"This content appears to be REAL with {confidence:.1%} confidence."
        analysis = f"""
        <div style='text-align: center; padding: 10px; background-color: {color}15; border-radius: 5px; margin-top: 10px;'>
            <span style='font-size: 18px; color: {color}; font-weight: bold;'>{message}</span>
            <p>Inference time: {inference_time:.3f} seconds</p>
        </div>
        """
        result = {
            labels[0]: float(probs[0]),
            labels[1]: float(probs[1]),
            labels[2]: float(probs[2])
        }
        interpretation = image_with_prediction(image, labels[pred_idx], confidence)
        return result, interpretation, analysis
    except Exception as e:
        logger.error(f"Error during analysis: {str(e)}", exc_info=True)
        return {labels[0]: 0.0, labels[1]: 0.0, labels[2]: 0.0}, None, f"Error during analysis: {str(e)}"
examples = [
    ["COVID-19 vaccine causes severe side effects in 80% of recipients", "https://images.unsplash.com/photo-1605289982774-9a6fef564df8?q=80&w=1000&auto=format&fit=crop"],
    ["Scientists discover new species of deep-sea fish", "https://images.unsplash.com/photo-1524704796725-9fc3044a58b2?q=80&w=1000&auto=format&fit=crop"],
]
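# Each example pairs a headline with an image URL; with cache_examples=True below,
# Gradio runs predict_news on these examples at startup and caches the outputs.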
# Build Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 📰 Fake News Detector (RoBERTa + ViT)
        This multimodal AI system analyzes both text and images to detect potentially fake news content.
        Upload an image and enter a news headline to see whether the combination is likely to be real or fake news.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="News Headline / Text",
                placeholder="Enter the news headline or text here...",
                lines=3
            )
            image_input = gr.Image(type="pil", label="Associated Image")
            analyze_btn = gr.Button("Analyze Content", variant="primary")
        with gr.Column(scale=1):
            label_output = gr.Label(label="Prediction Probabilities")
            image_output = gr.Image(type="pil", label="Visual Analysis")
            analysis_html = gr.HTML(label="Analysis")
    gr.Examples(
        examples=examples,
        inputs=[text_input, image_input],
        outputs=[label_output, image_output, analysis_html],
        fn=predict_news,
        cache_examples=True,
    )
    gr.Markdown(
        """
        This system combines:
        - **RoBERTa**: Analyzes the textual content
        - **ViT**: Processes the image data
        - **Multimodal Fusion**: Combines both signals to make a prediction

        The model was trained on the Fakeddit dataset, which contains real and fake news pairs with both text and images.
        """
    )
    analyze_btn.click(
        predict_news,
        inputs=[text_input, image_input],
        outputs=[label_output, image_output, analysis_html]
    )
if __name__ == "__main__":
    logger.info("Starting Gradio application")
    demo.launch()