fffiloni committed
Commit
522e040
1 Parent(s): fec39f8

Update app.py

Files changed (1)
  1. app.py +25 -11
app.py CHANGED
@@ -2,10 +2,13 @@ import gradio as gr
 import torch
 
 from PIL import Image
-from transformers import BlipProcessor, BlipForConditionalGeneration
+from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
 
-processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
-model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
+model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b")
+processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model.to(device)
 
 import os
 hf_token = os.environ.get('HF_TOKEN')
@@ -16,17 +19,28 @@ def infer(image_input):
     #img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
     raw_image = Image.open(image_input).convert('RGB')
 
-    # unconditional image captioning
-    inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
-
-    out = model.generate(**inputs)
+    prompt = "Can you please describe what's happening in the image, and give information about the characters and the place ?"
+    inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(device)
+
+    outputs = model.generate(
+        **inputs,
+        do_sample=False,
+        num_beams=5,
+        max_length=256,
+        min_length=1,
+        top_p=0.9,
+        repetition_penalty=1.5,
+        length_penalty=1.0,
+        temperature=1,
+    )
+    generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
+    print(generated_text)
 
-    caption = processor.decode(out[0], skip_special_tokens=True)
-    print(caption)
+
 
     llama_q = f"""
     I'll give you a simple image caption, from i want you to provide a story that would fit well with the image:
-    '{caption}'
+    '{generated_text}'
 
     """
 
@@ -40,7 +54,7 @@ def infer(image_input):
 
     print(f"Llama2 result: {result}")
 
-    return caption, result
+    return generated_text, result
 
 css="""
 #col-container {max-width: 910px; margin-left: auto; margin-right: auto;}
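For a quick local sanity check of the captioning path this commit introduces, outside the Gradio app, here is a minimal standalone sketch. The model ID, prompt, and generation settings are taken from the diff above; the local test image "demo.jpg" is an assumption, and top_p/temperature are left out because do_sample=False selects beam search, where they have no effect.

# Minimal standalone sketch of the new InstructBLIP captioning step.
# "demo.jpg" is an assumed local test image, not part of the Space.
import torch
from PIL import Image
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b").to(device)

raw_image = Image.open("demo.jpg").convert("RGB")
prompt = "Can you please describe what's happening in the image, and give information about the characters and the place ?"

inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(device)

# Deterministic beam search, mirroring the generate() call in the diff.
outputs = model.generate(
    **inputs,
    do_sample=False,
    num_beams=5,
    max_length=256,
    min_length=1,
    repetition_penalty=1.5,
    length_penalty=1.0,
)

caption = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
print(caption)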