import gradio as gr
from PIL import Image
import torch
import numpy as np
import faiss

from transformers import (
    GitProcessor,
    GitForCausalLM,
    AutoTokenizer,
    AutoModelForCausalLM,
    CLIPProcessor,
    CLIPModel
)
from sentence_transformers import SentenceTransformer
from datasets import load_dataset

device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

tokenizer_llama = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model_llama = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto"
).eval()

text_encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# git_processor and git_model are used in generate_captions() below but were never
# instantiated; the checkpoint here is an assumption (any GIT image-captioning
# checkpoint such as "microsoft/git-base-coco" works with the same calls).
git_processor = GitProcessor.from_pretrained("microsoft/git-base-coco")
git_model = GitForCausalLM.from_pretrained("microsoft/git-base-coco").to(device).eval()

# Load only the first 10,000 examples via streaming, stopping early so the
# rest of the dataset is never downloaded.
MAX_IMAGES = 10_000
dataset_stream = load_dataset("huggan/wikiart", split="train", streaming=True)
first_10000 = []
for i, example in enumerate(dataset_stream):
    if i >= MAX_IMAGES:
        break
    first_10000.append(example)

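# The FAISS indices are assumed to be prebuilt offline, with vector i
# corresponding to the i-th streamed WikiArt example loaded above.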
image_index = faiss.read_index("image_index_llama.faiss")
text_index = faiss.read_index("text_index_llama.faiss")

def clean_caption(text):
    return text.replace("[ unused0 ]", "").strip()

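# Generate three candidate captions with GIT: one greedy (deterministic)
# and two sampled for diversity.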
def generate_captions(image: Image.Image):
    inputs = git_processor(images=image, return_tensors="pt")["pixel_values"].to(device)

    captions = []
    with torch.no_grad():
        deterministic_ids = git_model.generate(
            pixel_values=inputs,
            max_new_tokens=30,
            do_sample=False
        )
        captions.append(clean_caption(git_processor.tokenizer.decode(deterministic_ids[0], skip_special_tokens=True)))

        sampled_ids = git_model.generate(
            pixel_values=inputs,
            max_new_tokens=30,
            do_sample=True,
            top_k=100,
            temperature=0.8,
            num_return_sequences=2
        )
        sampled = git_processor.tokenizer.batch_decode(sampled_ids, skip_special_tokens=True)
        captions.extend([clean_caption(c) for c in sampled])

    return captions

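# Use TinyLlama to merge the factual base caption with the two sampled
# captions into a single short description.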
def refine_caption(base, desc1, desc2):
    prompt = f"""
Given the base caption that is true and factual:
\"{base}\"

And two descriptive captions:
1) {desc1}
2) {desc2}

Write a short, coherent description that is faithful to the base caption but incorporates descriptive elements from captions 1 and 2 without contradicting the original meaning.
"""
    inputs = tokenizer_llama(prompt, return_tensors="pt").to(model_llama.device)
    with torch.no_grad():
        output = model_llama.generate(**inputs, max_new_tokens=100, do_sample=False)
        text = tokenizer_llama.decode(output[0], skip_special_tokens=True)
        answer = text[len(prompt):].strip()
        for prefix in ["Example:", "example:"]:
            if answer.startswith(prefix):
                answer = answer[len(prefix):].strip()
        return answer

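# Embed text with MiniLM and L2-normalize it; this assumes the saved text
# index stores vectors normalized the same way, so inner-product search
# behaves like cosine similarity.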
def get_text_embedding(text):
    emb = text_encoder.encode([text], normalize_embeddings=False).astype("float32")
    faiss.normalize_L2(emb)
    return emb

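# Embed the query image with CLIP and L2-normalize it; this assumes the saved
# image index was built from the same CLIP checkpoint.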
def get_image_embedding(image):
    inputs = clip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_features = clip_model.get_image_features(**inputs)
    emb = image_features.cpu().numpy().astype("float32")
    faiss.normalize_L2(emb)
    return emb

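# Search the index for the top_k nearest neighbours and map the returned ids
# back to the streamed WikiArt examples.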
def get_results_with_images(embedding, index, top_k=2):
    D, I = index.search(embedding, top_k)
    results = []
    for idx in I[0]:
        # FAISS returns -1 when fewer than top_k neighbours exist; skip those
        # as well as any id outside the loaded subset.
        if idx < 0 or idx >= MAX_IMAGES:
            continue
        try:
            item = first_10000[idx]
            img = item["image"]
            caption = item["caption"]
            caption_text = f"ID: {idx}\n{caption}"
            results.append((img, caption_text))
        except IndexError:
            continue
    return results

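# Full pipeline: caption the image, refine the caption, embed both the refined
# text and the image, then retrieve similar paintings through each index.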
def search_similar_images(image: Image.Image):
    captions = generate_captions(image)
    refined = refine_caption(captions[0], captions[1], captions[2])

    text_emb = get_text_embedding(refined)
    image_emb = get_image_embedding(image)

    text_results = get_results_with_images(text_emb, text_index)
    image_results = get_results_with_images(image_emb, image_index)

    return refined, text_results, image_results

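# Gradio UI: one image input; outputs the refined caption plus two galleries
# of retrieved paintings (text-based and image-based retrieval).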
demo = gr.Interface(
    fn=search_similar_images,
    inputs=gr.Image(label="Upload an image", type="pil"),
    outputs=[
        gr.Textbox(label="📜 Generated description"),
        gr.Gallery(label="🔍 Similar by description (caption)", height="auto", columns=2),
        gr.Gallery(label="🎨 Similar by image (CLIP)", height="auto", columns=2)
    ],
    title="🎨 Semantic WikiArt Search",
    description="Upload an image. The model will generate a description, compute embeddings, and find similar paintings by both description and image."
)

demo.launch()