|
import os |
|
import sys |
|
import subprocess |
|
import socket |
|
|
|
|
|
def install_packages():
    """
    Install required packages if they are not already installed.

    Availability is probed with ``importlib.util.find_spec`` rather than
    ``__import__`` so heavyweight packages (e.g. torch) are not fully
    imported and executed just to check for their presence.  Any missing
    package is installed with the running interpreter's pip, preserving
    an exact version pin (``name==x.y.z``) when one is given.
    """
    import importlib.util

    packages = [
        "gradio==4.44.1",
        "transformers",
        "torch",
        "huggingface_hub",
    ]

    for package in packages:
        # Strip an optional "==version" pin; for every package listed above
        # the PyPI distribution name doubles as the importable module name.
        module_name = package.split('==')[0]
        if importlib.util.find_spec(module_name) is None:
            print(f"Installing {package}...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
|
|
|
|
|
# Run the dependency check at import time, BEFORE the imports below —
# they would fail with ImportError on a fresh environment otherwise.
install_packages()




# These imports are deliberately placed after install_packages() so the
# packages exist by the time they are imported.
import gradio as gr

import torch

from transformers import pipeline
|
|
|
class HallucinationMitigator:
    """Text generator that flags potentially hallucinated output.

    Wraps a Hugging Face ``text-generation`` pipeline and applies a
    lightweight lexical heuristic (length, sentence punctuation, and
    unique-word ratio) to warn the user when generated text looks
    low-quality or repetitive.
    """

    def __init__(self, model_name="distilgpt2", penalty_weight=0.5, threshold=0.8):
        """
        Initialize the hallucination mitigation text generator.
        Using distilgpt2 for lighter weight and better performance on CPU.

        Args:
            model_name: Hugging Face model id to load into the pipeline.
            penalty_weight: Weight for an over-trust penalty term.
                NOTE(review): stored but not used anywhere in this class —
                presumably reserved for a future OTP implementation.
            threshold: Unique-word ratio below which text is flagged as a
                potential hallucination (see ``detect_hallucination``).
        """
        self.model_name = model_name
        self.penalty_weight = penalty_weight
        self.threshold = threshold

        try:
            # Pipeline device convention: 0 = first CUDA device, -1 = CPU.
            device = 0 if torch.cuda.is_available() else -1
            print(f"Using device: {'CUDA' if device == 0 else 'CPU'}")
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are not swallowed; a broken CUDA setup still falls back to CPU.
            device = -1
            print("Defaulting to CPU")

        self.generator = pipeline(
            "text-generation",
            model=self.model_name,
            device=device
        )

    def detect_hallucination(self, text):
        """
        Simple hallucination detection mechanism.

        Flags text that is very short, contains no sentence-ending period,
        or has a unique-word ratio below ``self.threshold`` (i.e. is highly
        repetitive).

        Args:
            text: Generated text to inspect.

        Returns:
            True if the text looks like a potential hallucination,
            False otherwise (including on any internal error).
        """
        try:
            words = text.split()
            word_count = len(words)
            unique_words = len(set(words))

            # The ratio is only evaluated when word_count >= 5, so the
            # division below can never raise ZeroDivisionError.
            is_potential_hallucination = (
                word_count < 5 or
                text.count('.') < 1 or
                (unique_words / word_count) < self.threshold
            )
            return is_potential_hallucination
        except Exception as e:
            # Best-effort heuristic: never let detection break generation.
            print(f"Hallucination detection error: {e}")
            return False

    def generate_text(self, input_text):
        """
        Generate text with hallucination mitigation.

        Args:
            input_text: Prompt to continue.

        Returns:
            The generated text, prefixed with a warning banner when the
            heuristic flags it; on failure, an error message string
            (never raises — suits direct use as a Gradio callback).
        """
        try:
            generated = self.generator(
                input_text,
                max_length=100,
                num_return_sequences=1,
                do_sample=True,
                temperature=0.7
            )

            output = generated[0]['generated_text']

            if self.detect_hallucination(output):
                output = f"⚠️ Generated text might contain hallucinations:\n{output}"

            return output

        except Exception as e:
            return f"Error generating text: {str(e)}"
|
|
|
def create_gradio_interface():
    """
    Create Gradio interface for text generation.

    Builds a fresh HallucinationMitigator (which loads the model) and
    wires its ``generate_text`` method into a simple textbox-in,
    textbox-out UI.

    Returns:
        A configured ``gr.Interface`` ready to be launched.
    """
    text_generator = HallucinationMitigator()

    return gr.Interface(
        fn=text_generator.generate_text,
        inputs=gr.Textbox(label="Enter your prompt"),
        outputs=gr.Textbox(label="Generated Text"),
        title="OPERA - Hallucination Mitigation",
        description="Text generation with Over-Trust Penalty (OTP) hallucination detection",
        theme="default",
    )
|
|
|
|
|
def main():
    """Build the Gradio app and serve it.

    Binds to 0.0.0.0 so the interface is reachable from other hosts
    (e.g. when running inside a container).
    """
    app = create_gradio_interface()
    app.launch(server_name='0.0.0.0')
|
|
|
# Launch the app only when executed as a script, not when imported.
if __name__ == "__main__":

    main()
|
|