File size: 5,974 Bytes
9f58c97
322e17b
c967827
3938eeb
c967827
 
9f58c97
 
a1918b1
9f58c97
 
 
 
 
43c2efd
9f58c97
 
c967827
9f58c97
 
c967827
9f58c97
 
 
 
 
 
 
c967827
9f58c97
 
 
 
 
 
 
c967827
 
 
9f58c97
c967827
9f58c97
 
 
 
 
 
 
 
c967827
 
9f58c97
c967827
9f58c97
 
 
 
 
 
c967827
9f58c97
 
b3a8f85
9f58c97
a1918b1
9f58c97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# blip2_generate_fix.py
import requests
from PIL import Image
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration

MODEL = "Salesforce/blip2-opt-2.7b"  # یا مدل موردنظر
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print("loading processor and model (this may take a while)...")
processor = Blip2Processor.from_pretrained(MODEL, use_auth_token=None)
model = Blip2ForConditionalGeneration.from_pretrained(MODEL, use_auth_token=None)
model.to(DEVICE)
model.eval()

def generate_answer_from_image(pil_image: Image.Image, prompt_text: str = "", 
                               max_new_tokens: int = 64, num_beams: int = 4, temperature: float = 1.0):
    """
    Robust generation that avoids returning the prompt in the final output.
    Returns: (generated_text, debug_info)
    """
    # آماده‌سازی ورودی با پردازشگر
    inputs = processor(pil_image, prompt_text, return_tensors="pt")
    
    # تعیین طول پرامپت (تعداد توکن‌های ورودی)
    input_len = 0
    if "input_ids" in inputs and inputs["input_ids"] is not None:
        input_len = inputs["input_ids"].shape[-1]
    else:
        # اگر input_ids وجود ندارد، متن پرامپت را با tokenizer رمزکن کن تا طولش را داشته باشیم
        # (بعضی پردازشگرها pixel_values فقط بازمی‌گردانند)
        tok = processor.tokenizer(prompt_text, return_tensors="pt")
        if "input_ids" in tok and tok["input_ids"] is not None:
            input_len = tok["input_ids"].shape[-1]
        else:
            input_len = 0

    # انتقال به دستگاه
    for k, v in inputs.items():
        inputs[k] = v.to(DEVICE)

    # تولید با جزئیات و بازگشت dict
    gen_kwargs = dict(max_new_tokens=max_new_tokens, num_beams=num_beams)
    # اگر temperature=0 => deterministic (do_sample=False)
    if temperature is None or temperature == 0:
        gen_kwargs["do_sample"] = False
    else:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = float(temperature)

    with torch.no_grad():
        outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=False, **gen_kwargs)

    # گرفتن توالی کامل (ممکن است tensor یا attribute باشد)
    if hasattr(outputs, "sequences"):
        sequences = outputs.sequences
    else:
        # fallback (older transformers)
        sequences = outputs

    # اطمینان از انتقال به CPU برای پردازش tokenizer
    sequences_cpu = sequences.cpu()

    # slice جدید: فقط توکن‌های جدید (اگر طول خروجی بیشتر از طول پرامپت بود)
    try:
        total_len = sequences_cpu.shape[-1]
    except Exception:
        total_len = None

    generated_text = ""
    debug = {"input_len": input_len, "total_len": total_len}

    if total_len and input_len and total_len > input_len:
        # فقط بخش تولیدشده را جدا کن
        gen_tokens = sequences_cpu[:, input_len:]
        # batch_decode منتظر لیست یا آرایه‌ی اعداد است
        token_list = gen_tokens[0].tolist()
        generated_text = processor.tokenizer.decode(token_list, skip_special_tokens=True).strip()
        debug["method"] = "slice_tokens"
    else:
        # fallback: اگر نتوانستیم برش بزنیم، کل توالی را decode کن و سپس تلاش کن پرامپت متنی را از ابتدای
        # خروجی حذف کنی (پاراگرافی). این روش آخرین امید است زیرا معیار دقیق‌تری نیست،
        # ولی امن‌تر از بازگرداندن کل prompt است.
        full = processor.decode(sequences_cpu[0], skip_special_tokens=True).strip()
        debug["method"] = "decode_full"
        # تلاش برای حذف متن پرامپت (اولین وقوع) — فقط اگر prompt_text غیرخالی و در متن وجود داشته باشد
        if prompt_text and prompt_text.strip():
            # ممکن است تفاوت‌های tokenization منجر به عدم داشتن دقیق prompt_text شود؛
            # ما اولین وقوع متن پرامپت را حذف می‌کنیم اگر دقیقاً در خروجی آمده باشد.
            if prompt_text.strip() in full:
                generated_text = full.replace(prompt_text.strip(), "", 1).strip()
                debug["removed_prompt_by_string"] = True
            else:
                # به عنوان آخرین راه، اگر prompt کوتاه است، سعی می‌کنیم تا نزدیک‌ترین بخش را حذف کنیم
                # (ایمن عمل کن: فقط اگر طول خروجی خیلی طولانی باشد)
                generated_text = full
                debug["removed_prompt_by_string"] = False
        else:
            generated_text = full

    # اگر هنوز خالی بود، خروجی کامل را بازگردان کن (ولی این دیگر نباید پرامپت باشد)
    if not generated_text:
        # در صورت نیاز می‌توانیم full را بازگردانیم
        generated_text = processor.decode(sequences_cpu[0], skip_special_tokens=True).strip()
        debug["final_fallback"] = True

    return generated_text, debug

# -------------------- مثال اجرای تست --------------------
if __name__ == "__main__":
    img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
    prompt = "Describe the image in detail in fluent English."
    output, dbg = generate_answer_from_image(raw_image, prompt_text=prompt, max_new_tokens=64, num_beams=4, temperature=1.0)
    print("=== GENERATED ===")
    print(output)
    print("=== DEBUG ===")
    print(dbg)