File size: 8,455 Bytes
7df2acb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43d0a4a
7df2acb
 
 
 
43d0a4a
7df2acb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import gradio as gr
import onnxruntime as ort
from transformers import RobertaTokenizer, ViTImageProcessor
from PIL import Image
import numpy as np
import torch
import os
import time
import logging

# Configure logging once at import time: every record carries a timestamp,
# the emitting logger's name and the severity level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Pre-trained preprocessing objects, loaded once and shared by every request:
# the ViT image processor for pictures and the RoBERTa tokenizer for text.
vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

# Path of the exported multimodal classifier, relative to the working directory.
model_path = "./multimodal_model.onnx"
try:
    # Check up front so a missing file produces an explicit, readable message.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"ONNX model not found at {model_path}")

    logger.info(f"Loading ONNX model from {model_path}")
    sess_options = ort.SessionOptions()
    # NOTE(review): 0 should be onnxruntime's most verbose severity - confirm
    # this much runtime logging is wanted outside of debugging.
    sess_options.log_severity_level = 0
    ort_session = ort.InferenceSession(
        model_path,
        sess_options=sess_options,
        providers=['CPUExecutionProvider'],
    )
    logger.info("ONNX model loaded successfully")

    # Capture the graph's input/output signature so it shows up in the startup log.
    input_names = [meta.name for meta in ort_session.get_inputs()]
    input_shapes = {meta.name: meta.shape for meta in ort_session.get_inputs()}
    output_names = [meta.name for meta in ort_session.get_outputs()]

    logger.info(f"Model inputs: {input_names} with shapes {input_shapes}")
    logger.info(f"Model outputs: {output_names}")

except Exception as exc:
    # Record the failure in the application log, then re-raise: nothing below
    # can work without a loaded session.
    logger.error(f"Error loading ONNX model: {exc}")
    raise

# Class names; predict_news pairs labels[i] with the i-th model score.
labels = ["Real", "Real Text with fake image", "Fake"]

def softmax(x, axis=1):
    """Compute numerically stable softmax values for each set of scores in x.

    Args:
        x: Array-like of scores. For the default ``axis`` this is a 2-D batch
            with one row of scores per sample.
        axis: Axis along which the scores are normalised. Defaults to 1,
            i.e. each row of a 2-D batch sums to 1.

    Returns:
        A NumPy array with the same shape as ``x`` (a scalar becomes a
        one-element vector) whose entries along ``axis`` are positive and
        sum to 1.
    """
    scores = np.asarray(x)
    if scores.ndim < 2:
        # A lone score vector has no axis 1; normalise over the only axis it
        # has instead of raising an AxisError.
        scores = np.atleast_1d(scores)
        axis = -1

    # Subtracting the maximum leaves the result unchanged mathematically but
    # keeps np.exp from overflowing on large logits.
    shifted = scores - np.max(scores, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=axis, keepdims=True)

def image_with_prediction(img, label, confidence):
    """Return a copy of ``img`` captioned with the predicted label.

    A 40-pixel translucent black strip is composited over the bottom edge of
    the copy and ``"<label>: <confidence>"`` (confidence as a percentage with
    one decimal) is drawn on it in white, horizontally centred. The image
    passed in is never modified.

    Args:
        img: PIL image to annotate.
        label: Class name to display.
        confidence: Probability of ``label``; formatted with ``:.1%``.

    Returns:
        A new PIL image. ``RGB`` and ``RGBA`` inputs keep their mode; any
        other mode is converted to ``RGB`` so the colour overlay and white
        text can be drawn.
    """
    from PIL import Image, ImageDraw, ImageFont

    bar_height = 40  # pixel height of the caption strip

    # The white (255, 255, 255) fill and the RGBA overlay need a colour
    # canvas; single-band or palette images are converted first. convert()
    # and copy() both return new images, so the caller's image is untouched.
    if img.mode in ("RGB", "RGBA"):
        annotated = img.copy()
    else:
        annotated = img.convert("RGB")

    width, height = annotated.size

    # Passing the strip again as the mask makes paste() honour its alpha,
    # so the picture stays visible behind the caption.
    strip = Image.new('RGBA', (width, bar_height), (0, 0, 0, 150))
    annotated.paste(strip, (0, height - bar_height), strip)

    text = f"{label}: {confidence:.1%}"

    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except OSError:
        # Arial is not available on every host; use Pillow's built-in font.
        font = ImageFont.load_default()

    draw = ImageDraw.Draw(annotated)
    try:
        text_width = draw.textlength(text, font=font)
    except AttributeError:
        # textlength is unavailable here (presumably an older Pillow): use
        # the font's own measurement, or a fixed estimate as a last resort.
        text_width = font.getsize(text)[0] if hasattr(font, 'getsize') else 200

    # Centre horizontally; place the text 5 px below the top of the strip.
    text_position = ((width - text_width) // 2, height - bar_height + 5)
    draw.text(text_position, text, fill=(255, 255, 255), font=font)

    return annotated

def _zero_scores():
    """Return a 0.0 score for every label, for use when no prediction exists."""
    return {name: 0.0 for name in labels}


def _build_ort_inputs(input_ids, attention_mask, pixel_values):
    """Assign each prepared array to the ONNX input whose name suggests it.

    Matching is a case-insensitive substring test on the input name:
    ``ids`` -> token ids, ``mask`` -> attention mask, ``image`` or ``pixel``
    -> image tensor.
    """
    feeds = {}
    for meta in ort_session.get_inputs():
        lowered = meta.name.lower()
        if 'ids' in lowered:
            feeds[meta.name] = input_ids
        elif 'mask' in lowered:
            feeds[meta.name] = attention_mask
        elif 'image' in lowered or 'pixel' in lowered:
            feeds[meta.name] = pixel_values
        else:
            # Left unfed on purpose; this log line explains the error the
            # runtime will raise for the missing input.
            logger.warning("No data source matched ONNX input %r", meta.name)
    return feeds


def _render_analysis(pred_idx, confidence, inference_time):
    """Build the HTML verdict banner for the predicted class index."""
    # Hex colours are required here: the banner appends a two-digit alpha
    # ("15") to form #RRGGBBAA, which is not valid after a colour *name*.
    if pred_idx == 1:
        color = "#ffa500"  # orange
        message = f"This content appears to be <strong>REAL TEXT WITH FAKE IMAGE</strong> with {confidence:.1%} confidence."
    elif pred_idx == 2:
        color = "#ff0000"  # red
        message = f"This content appears to contain <strong>FAKE</strong> with {confidence:.1%} confidence."
    else:
        color = "#008000"  # green
        message = f"This content appears to be <strong>REAL</strong> with {confidence:.1%} confidence."

    return f"""
        <div style='text-align: center; padding: 10px; background-color: {color}15; border-radius: 5px; margin-top: 10px;'>
            <span style='font-size: 18px; color: {color}; font-weight: bold;'>{message}</span>
            <p>Inference time: {inference_time:.3f} seconds</p>
        </div>
        """


def predict_news(text, image):
    """Classify a headline/image pair and prepare the three UI outputs.

    Args:
        text: News headline or body text. ``None`` or blank text is rejected.
        image: PIL image associated with the text. ``None`` is rejected.

    Returns:
        A tuple ``(scores, annotated_image, analysis)``:

        * ``scores`` - dict mapping each entry of ``labels`` to its
          probability (all 0.0 when no prediction could be made);
        * ``annotated_image`` - copy of the image captioned with the top
          prediction, or ``None`` on invalid input or error;
        * ``analysis`` - HTML banner describing the verdict, or a plain
          message explaining what went wrong.

        Exceptions are never propagated: any failure is logged with its
        traceback and reported through ``analysis``.
    """
    import html

    if text is None or not text.strip():
        return _zero_scores(), None, "Please enter some text to analyze."

    if image is None:
        return _zero_scores(), None, "Please upload an image to analyze."

    try:
        logger.info("Processing text: %s...", text[:50])
        logger.info("Processing image size: %s", image.size)

        # The image processor and the caption overlay both work on
        # three-channel colour images; normalise other modes up front.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Tokenise to a fixed length of 80 tokens (padded or truncated).
        encoded = tokenizer(
            text,
            add_special_tokens=True,
            return_tensors='np',
            max_length=80,
            truncation=True,
            padding='max_length',
        )
        input_ids = encoded['input_ids']
        attention_mask = encoded['attention_mask']

        logger.info("Input IDs shape: %s", input_ids.shape)
        logger.info("Attention mask shape: %s", attention_mask.shape)

        image_processed = vit_processor(images=image, return_tensors="np")["pixel_values"]
        logger.info("Processed image shape: %s", image_processed.shape)

        ort_inputs = _build_ort_inputs(input_ids, attention_mask, image_processed)
        logger.info("ONNX input keys: %s", list(ort_inputs))

        # perf_counter is monotonic, so the measured interval cannot be
        # distorted by wall-clock adjustments.
        start_time = time.perf_counter()
        logger.info("Starting inference")
        outputs = ort_session.run(None, ort_inputs)
        inference_time = time.perf_counter() - start_time
        logger.info("Inference completed in %.3fs", inference_time)

        logits = outputs[0]
        logger.info("Raw output shape: %s, values: %s", logits.shape, logits)

        # Only one sample was fed, so take the first (and only) row.
        probs = softmax(logits)[0]
        logger.info("Probabilities: %s", probs)

        if len(probs) < len(labels):
            raise ValueError(
                f"Model returned {len(probs)} scores but {len(labels)} labels are defined"
            )

        pred_idx = int(np.argmax(probs))
        confidence = float(probs[pred_idx])

        result = {name: float(score) for name, score in zip(labels, probs)}
        interpretation = image_with_prediction(image, labels[pred_idx], confidence)
        analysis = _render_analysis(pred_idx, confidence, inference_time)

        return result, interpretation, analysis

    except Exception as exc:
        # Top-level boundary for the UI callback: log the traceback and show
        # the user a message instead of crashing the request.
        logger.exception("Error during analysis: %s", exc)
        # The message is shown in an HTML component, so escape it to keep
        # characters like '<' in exception text from being read as markup.
        return _zero_scores(), None, f"Error during analysis: {html.escape(str(exc))}"

# Sample inputs offered in the UI. Each row is [headline text, image URL],
# in the same order as the `inputs` list given to gr.Examples further down.
examples = [
    ["COVID-19 vaccine causes severe side effects in 80% of recipients", "https://images.unsplash.com/photo-1605289982774-9a6fef564df8?q=80&w=1000&auto=format&fit=crop"],
    ["Scientists discover new species of deep-sea fish", "https://images.unsplash.com/photo-1524704796725-9fc3044a58b2?q=80&w=1000&auto=format&fit=crop"],
]
    
# Build Gradio interface
# Layout: a title, then one row with two equal-width columns (inputs on the
# left, outputs on the right), followed by the examples and an explanatory
# footer. Components appear in the order they are created inside each context.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 📰 Fake News Detector (RoBERTa + ViT)
        
        This multimodal AI system analyzes both text and images to detect potentially fake news content.
        Upload an image and enter a news headline to see if the combination is likely to be real or fake news.
        """
    )
    
    with gr.Row():
        # Left column: the two inputs and the trigger button.
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="News Headline / Text",
                placeholder="Enter the news headline or text here...",
                lines=3
            )
            # type="pil" so predict_news receives a PIL image object.
            image_input = gr.Image(type="pil", label="Associated Image")
            
            analyze_btn = gr.Button("Analyze Content", variant="primary")
        
        # Right column: one component per value returned by predict_news.
        with gr.Column(scale=1):
            label_output = gr.Label(label="Prediction Probabilities")
            image_output = gr.Image(type="pil", label="Visual Analysis")
            analysis_html = gr.HTML(label="Analysis")
    
    # Clickable sample inputs wired to the same function and outputs as the button.
    # NOTE(review): cache_examples=True presumably makes Gradio run predict_news
    # on every example ahead of time (which would need the model and network
    # access to the image URLs at startup) - confirm against the Gradio docs.
    gr.Examples(
        examples=examples,
        inputs=[text_input, image_input],
        outputs=[label_output, image_output, analysis_html],
        fn=predict_news,
        cache_examples=True,
    )
    
    gr.Markdown(
        """
        This system combines:
        - **RoBERTa**: Analyzes the textual content
        - **ViT**: Processes the image data
        - **Multimodal Fusion**: Combines both signals to make a prediction
        
        The model was trained on the Fakeddit dataset containing real and fake news pairs with both text and images.
        """
    )
    
    # Run predict_news on click; its three return values fill the three
    # output components in the order listed.
    analyze_btn.click(
        predict_news,
        inputs=[text_input, image_input],
        outputs=[label_output, image_output, analysis_html]
    )

# Start the web UI only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    logger.info("Starting Gradio application")
    demo.launch()