Spaces:
Runtime error
Runtime error
File size: 5,974 Bytes
9f58c97 322e17b c967827 3938eeb c967827 9f58c97 a1918b1 9f58c97 43c2efd 9f58c97 c967827 9f58c97 c967827 9f58c97 c967827 9f58c97 c967827 9f58c97 c967827 9f58c97 c967827 9f58c97 c967827 9f58c97 c967827 9f58c97 b3a8f85 9f58c97 a1918b1 9f58c97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
# blip2_generate_fix.py
import requests
from PIL import Image
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration
MODEL = "Salesforce/blip2-opt-2.7b" # یا مدل موردنظر
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("loading processor and model (this may take a while)...")
processor = Blip2Processor.from_pretrained(MODEL, use_auth_token=None)
model = Blip2ForConditionalGeneration.from_pretrained(MODEL, use_auth_token=None)
model.to(DEVICE)
model.eval()
def generate_answer_from_image(pil_image: Image.Image, prompt_text: str = "",
max_new_tokens: int = 64, num_beams: int = 4, temperature: float = 1.0):
"""
Robust generation that avoids returning the prompt in the final output.
Returns: (generated_text, debug_info)
"""
# آمادهسازی ورودی با پردازشگر
inputs = processor(pil_image, prompt_text, return_tensors="pt")
# تعیین طول پرامپت (تعداد توکنهای ورودی)
input_len = 0
if "input_ids" in inputs and inputs["input_ids"] is not None:
input_len = inputs["input_ids"].shape[-1]
else:
# اگر input_ids وجود ندارد، متن پرامپت را با tokenizer رمزکن کن تا طولش را داشته باشیم
# (بعضی پردازشگرها pixel_values فقط بازمیگردانند)
tok = processor.tokenizer(prompt_text, return_tensors="pt")
if "input_ids" in tok and tok["input_ids"] is not None:
input_len = tok["input_ids"].shape[-1]
else:
input_len = 0
# انتقال به دستگاه
for k, v in inputs.items():
inputs[k] = v.to(DEVICE)
# تولید با جزئیات و بازگشت dict
gen_kwargs = dict(max_new_tokens=max_new_tokens, num_beams=num_beams)
# اگر temperature=0 => deterministic (do_sample=False)
if temperature is None or temperature == 0:
gen_kwargs["do_sample"] = False
else:
gen_kwargs["do_sample"] = True
gen_kwargs["temperature"] = float(temperature)
with torch.no_grad():
outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=False, **gen_kwargs)
# گرفتن توالی کامل (ممکن است tensor یا attribute باشد)
if hasattr(outputs, "sequences"):
sequences = outputs.sequences
else:
# fallback (older transformers)
sequences = outputs
# اطمینان از انتقال به CPU برای پردازش tokenizer
sequences_cpu = sequences.cpu()
# slice جدید: فقط توکنهای جدید (اگر طول خروجی بیشتر از طول پرامپت بود)
try:
total_len = sequences_cpu.shape[-1]
except Exception:
total_len = None
generated_text = ""
debug = {"input_len": input_len, "total_len": total_len}
if total_len and input_len and total_len > input_len:
# فقط بخش تولیدشده را جدا کن
gen_tokens = sequences_cpu[:, input_len:]
# batch_decode منتظر لیست یا آرایهی اعداد است
token_list = gen_tokens[0].tolist()
generated_text = processor.tokenizer.decode(token_list, skip_special_tokens=True).strip()
debug["method"] = "slice_tokens"
else:
# fallback: اگر نتوانستیم برش بزنیم، کل توالی را decode کن و سپس تلاش کن پرامپت متنی را از ابتدای
# خروجی حذف کنی (پاراگرافی). این روش آخرین امید است زیرا معیار دقیقتری نیست،
# ولی امنتر از بازگرداندن کل prompt است.
full = processor.decode(sequences_cpu[0], skip_special_tokens=True).strip()
debug["method"] = "decode_full"
# تلاش برای حذف متن پرامپت (اولین وقوع) — فقط اگر prompt_text غیرخالی و در متن وجود داشته باشد
if prompt_text and prompt_text.strip():
# ممکن است تفاوتهای tokenization منجر به عدم داشتن دقیق prompt_text شود؛
# ما اولین وقوع متن پرامپت را حذف میکنیم اگر دقیقاً در خروجی آمده باشد.
if prompt_text.strip() in full:
generated_text = full.replace(prompt_text.strip(), "", 1).strip()
debug["removed_prompt_by_string"] = True
else:
# به عنوان آخرین راه، اگر prompt کوتاه است، سعی میکنیم تا نزدیکترین بخش را حذف کنیم
# (ایمن عمل کن: فقط اگر طول خروجی خیلی طولانی باشد)
generated_text = full
debug["removed_prompt_by_string"] = False
else:
generated_text = full
# اگر هنوز خالی بود، خروجی کامل را بازگردان کن (ولی این دیگر نباید پرامپت باشد)
if not generated_text:
# در صورت نیاز میتوانیم full را بازگردانیم
generated_text = processor.decode(sequences_cpu[0], skip_special_tokens=True).strip()
debug["final_fallback"] = True
return generated_text, debug
# -------------------- مثال اجرای تست --------------------
if __name__ == "__main__":
img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
prompt = "Describe the image in detail in fluent English."
output, dbg = generate_answer_from_image(raw_image, prompt_text=prompt, max_new_tokens=64, num_beams=4, temperature=1.0)
print("=== GENERATED ===")
print(output)
print("=== DEBUG ===")
print(dbg)
|