| | import os |
| | import io |
| | import torch |
| | import pandas as pd |
| | import gradio as gr |
| | from PIL import Image |
| | from sd_parsers import ParserManager |
| | from torchvision import transforms |
| | from transformers import CLIPProcessor, CLIPModel, Blip2Processor, Blip2ForConditionalGeneration, BitsAndBytesConfig |
| | import lpips |
| | import piq |
| | import plotly.express as px |
| |
|
| | |
| | |
| | |
| |
|
# Select the compute device once, up front: first CUDA GPU if present, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
| |
|
| | |
# CLIP ViT-B/32: used for prompt-image alignment scoring and the fixed
# "aesthetic" text probe (see compute_clip_score callers below).
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
| |
|
| | |
# BLIP-2 (Flan-T5-XL) captioning model, used by compute_caption_similarity.
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
if torch.cuda.is_available():
    # 8-bit quantization keeps the multi-billion-parameter model inside
    # typical GPU memory; device_map="auto" lets accelerate place it.
    bnb_config = BitsAndBytesConfig(load_in_8bit=True)
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-flan-t5-xl",
        quantization_config=bnb_config,
        device_map="auto",
    )
else:
    # Bug fix: the CPU fallback previously loaded in float16, but half
    # precision is unsupported/extremely slow for many CPU ops and can make
    # generate() fail. Use full precision on CPU.
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-flan-t5-xl",
        torch_dtype=torch.float32,
    ).to(device)
| |
|
| | |
# LPIPS with AlexNet backbone: pairwise perceptual distance, used below to
# compute per-model intra-set "diversity".
lpips_model = lpips.LPIPS(net='alex').to(device)
| |
|
| | |
| | |
| | |
| |
|
def extract_metadata(file):
    """Extract (prompt, model_name) from a Stable-Diffusion image's metadata.

    Parameters:
        file: uploaded file object exposing a filesystem path via ``.name``.

    Returns:
        Tuple ``(prompt, model_name)`` of strings; either may be '' when the
        corresponding field is absent.

    Bug fix: ``ParserManager.parse`` returns ``None`` for images that carry no
    parseable generation metadata; the previous code then crashed with
    ``AttributeError`` on ``info.prompts``. Such images now yield ('', '').
    """
    parser = ParserManager()
    info = parser.parse(file.name)
    if info is None:
        return '', ''
    prompt = info.prompts[0].value if info.prompts else ''

    model_name = ''
    if getattr(info, 'models', None):
        # ``models`` may be an unordered collection; take an arbitrary entry.
        first = next(iter(info.models))
        model_name = first.name if hasattr(first, 'name') else str(first)
    return prompt, model_name
| |
|
| | |
# CLIP-style preprocessing pipeline: 224x224 resize + CLIP's published
# mean/std normalization constants.
# NOTE(review): not referenced anywhere else in this file — possibly dead
# code, or used by a part of the project not shown here; confirm before removal.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711)
    )
])
| |
|
| | |
| | |
| | |
| |
|
@torch.no_grad()
def compute_clip_score(img: Image.Image, text: str) -> float:
    """CLIP image-text similarity scaled to roughly [0, 100].

    Cosine similarity between the CLIP image and text embeddings, clamped at
    zero and multiplied by 100 (the common CLIPScore convention).

    Bug fix: decorated with @torch.no_grad() — this is inference only, and the
    decorator matches the sibling metric functions while avoiding needless
    autograd graph construction and memory growth.
    """
    inputs = clip_processor(text=[text], images=img, return_tensors="pt", padding=True).to(device)
    outputs = clip_model(**inputs)
    score = torch.cosine_similarity(outputs.image_embeds, outputs.text_embeds)
    return float((score.clamp(min=0) * 100).mean())
| |
|
@torch.no_grad()
def compute_caption_similarity(img: Image.Image, prompt: str) -> float:
    """Caption *img* with BLIP-2 and CLIP-score the image against that caption.

    Note: the generated caption, not the ``prompt`` argument, is what gets
    scored — ``prompt`` is accepted for call-site symmetry with the other
    metric functions.
    """
    encoded = blip_processor(images=img, return_tensors="pt").to(device)
    token_ids = blip_model.generate(**encoded)
    generated_caption = blip_processor.decode(token_ids[0], skip_special_tokens=True)
    return compute_clip_score(img, generated_caption)
| |
|
@torch.no_grad()
def compute_iqa_metrics(img: Image.Image):
    """Return the (BRISQUE, NIQE) no-reference quality scores for *img*."""
    batch = transforms.ToTensor()(img).unsqueeze(0).to(device)
    brisque_score = piq.brisque(batch)
    niqe_score = piq.niqe(batch)
    return float(brisque_score.cpu()), float(niqe_score.cpu())
| |
|
@torch.no_grad()
def compute_lpips_pair(img1: Image.Image, img2: Image.Image) -> float:
    """Perceptual (LPIPS) distance between two images; larger = more different.

    Bug fix: ``ToTensor`` produces tensors in [0, 1], but the LPIPS network
    expects inputs in [-1, 1]. Passing ``normalize=True`` makes the model
    rescale internally; previously the distances were computed on the wrong
    input range.
    """
    t1 = transforms.ToTensor()(img1).unsqueeze(0).to(device)
    t2 = transforms.ToTensor()(img2).unsqueeze(0).to(device)
    return float(lpips_model(t1, t2, normalize=True).cpu())
| |
|
| | |
| | |
| | |
| |
|
def analyze_images(files):
    """Compute per-image metrics and per-model aggregates for uploaded images.

    Parameters:
        files: list of uploaded file objects (each exposing a path via
            ``.name``), or None (Gradio passes None when nothing is selected).

    Returns:
        Tuple ``(df, agg)``:
          - ``df``: one row per image with raw metric values;
          - ``agg``: one row per model with the mean of each metric plus an
            LPIPS-based ``diversity`` column.
    """
    per_image_cols = ['model', 'prompt', 'clip_score', 'caption_sim',
                      'brisque', 'niqe', 'aesthetic']
    agg_cols = ['model', 'clip_score_mean', 'caption_sim_mean', 'brisque_mean',
                'niqe_mean', 'aesthetic_mean', 'diversity']

    records = []
    imgs_by_model = {}

    # Bug fix: tolerate files=None (empty upload) instead of raising TypeError.
    for f in (files or []):
        img = Image.open(f.name).convert('RGB')
        prompt, model = extract_metadata(f)

        cs = compute_clip_score(img, prompt)
        cap_sim = compute_caption_similarity(img, prompt)
        brisque, niqe = compute_iqa_metrics(img)
        # Generic "aesthetic" probe: CLIP similarity to a fixed quality prompt.
        aesthetic = compute_clip_score(img, "a beautiful high quality image")

        records.append({
            'model': model,
            'prompt': prompt,
            'clip_score': cs,
            'caption_sim': cap_sim,
            'brisque': brisque,
            'niqe': niqe,
            'aesthetic': aesthetic
        })
        imgs_by_model.setdefault(model, []).append(img)

    if not records:
        # Bug fix: an empty upload previously crashed in groupby('model') on a
        # column-less DataFrame; return empty frames with the expected schema.
        return pd.DataFrame(columns=per_image_cols), pd.DataFrame(columns=agg_cols)

    df = pd.DataFrame(records)

    # Intra-model diversity: mean pairwise LPIPS distance over all image pairs
    # produced by the same model (0.0 when only one image exists).
    diversity = {}
    for model, imgs in imgs_by_model.items():
        if len(imgs) < 2:
            diversity[model] = 0.0
        else:
            pairs = [compute_lpips_pair(imgs[i], imgs[j])
                     for i in range(len(imgs)) for j in range(i + 1, len(imgs))]
            diversity[model] = sum(pairs) / len(pairs)

    agg = df.groupby('model').agg(
        clip_score_mean=('clip_score', 'mean'),
        caption_sim_mean=('caption_sim', 'mean'),
        brisque_mean=('brisque', 'mean'),
        niqe_mean=('niqe', 'mean'),
        aesthetic_mean=('aesthetic', 'mean')
    ).reset_index()
    agg['diversity'] = agg['model'].map(diversity)

    return df, agg
| |
|
| | |
| | |
| | |
| |
|
def plot_metrics(agg: pd.DataFrame):
    """Build a grouped bar chart comparing the per-model aggregate metrics."""
    metric_columns = ['aesthetic_mean', 'clip_score_mean', 'caption_sim_mean', 'diversity']
    figure = px.bar(
        agg,
        x='model',
        y=metric_columns,
        barmode='group',
        title='Сравнение моделей по метрикам',
    )
    return figure
| |
|
| | |
| | |
| | |
| |
|
def run_analysis(files):
    """Gradio callback: return the per-model summary table and metrics chart.

    Bug fix: previously this returned the per-image DataFrame, which does not
    match the aggregate column headers declared on the output ``gr.Dataframe``
    ("model", "clip_score_mean", ..., "diversity"); the summary ``agg`` frame
    is what that component expects.
    """
    df, agg = analyze_images(files)
    fig = plot_metrics(agg)
    return agg, fig
| |
|
# --- Gradio UI wiring -------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# AI Image Quality Evaluator")
    gr.Markdown("Загрузите PNG-изображения (с EXIF-метаданными SD) для анализа и сравнения моделей.")

    with gr.Row():
        # Multi-file upload; the analysis expects SD-generated PNGs whose
        # metadata sd-parsers can read.
        input_files = gr.File(file_count="multiple", label="Выберите PNG файлы")
        # Summary table; the declared headers mirror the per-model aggregate
        # columns produced by analyze_images.
        output_table = gr.Dataframe(
            headers=[
                "model", "clip_score_mean", "caption_sim_mean", "brisque_mean",
                "niqe_mean", "aesthetic_mean", "diversity"
            ],
            label="Сводная таблица"
        )

    plot_output = gr.Plot(label="График метрик")

    # Button wiring: run_analysis produces the table contents and the figure.
    run_btn = gr.Button("Запустить анализ")
    run_btn.click(run_analysis, inputs=[input_files], outputs=[output_table, plot_output])
| |
|
if __name__ == "__main__":
    # Listen on all interfaces (e.g. for container use); no public share link.
    demo.launch(server_name='0.0.0.0', share=False)
| |
|