fffiloni commited on
Commit
7175dd2
1 Parent(s): 3e63247

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -3
app.py CHANGED
@@ -7,6 +7,10 @@ from transformers import BlipProcessor, BlipForConditionalGeneration
7
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
8
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
9
 
 
 
 
 
10
  def infer(image_input):
11
  #img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
12
  raw_image = Image.open(image_input).convert('RGB')
@@ -19,7 +23,26 @@ def infer(image_input):
19
  caption = processor.decode(out[0], skip_special_tokens=True)
20
  print(caption)
21
 
22
- return caption
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  css="""
25
  #col-container {max-width: 910px; margin-left: auto; margin-right: auto;}
@@ -39,7 +62,8 @@ with gr.Blocks(css=css) as demo:
39
  )
40
  image_in = gr.Image(label="Image input", type="filepath")
41
  submit_btn = gr.Button('Sumbit')
42
- story = gr.Textbox(label="Generated Story")
43
- submit_btn.click(fn=infer, inputs=[image_in], outputs=[story])
 
44
 
45
  demo.queue().launch()
 
7
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
8
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
9
 
10
+ hf_token = os.environ.get('HF_TOKEN')
11
+ from gradio_client import Client
12
+ client = Client("https://fffiloni-test-llama-api.hf.space/", hf_token=hf_token)
13
+
14
  def infer(image_input):
15
  #img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
16
  raw_image = Image.open(image_input).convert('RGB')
 
23
  caption = processor.decode(out[0], skip_special_tokens=True)
24
  print(caption)
25
 
26
+ llama_q = f"""
27
+ I'll give you a simple image caption, from i want you to provide a story that would fit well with the image.
28
+
29
+ Here's the music description :
30
+
31
+ {caption}
32
+
33
+ """
34
+
35
+ result = client.predict(
36
+ llama_q, # str in 'Message' Textbox component
37
+ api_name="/predict"
38
+ )
39
+
40
+
41
+
42
+
43
+ print(f"Llama2 result: {result}")
44
+
45
+ return caption, result
46
 
47
  css="""
48
  #col-container {max-width: 910px; margin-left: auto; margin-right: auto;}
 
62
  )
63
  image_in = gr.Image(label="Image input", type="filepath")
64
  submit_btn = gr.Button('Sumbit')
65
+ caption = gr.Textbox(label="Generated Caption")
66
+ story = gr.Textbox(label="generated Story")
67
+ submit_btn.click(fn=infer, inputs=[image_in], outputs=[caption, story])
68
 
69
  demo.queue().launch()