Fizzarolli commited on
Commit
1967c03
1 Parent(s): 009db99
Files changed (1) hide show
  1. app.py +16 -3
app.py CHANGED
@@ -2,6 +2,8 @@ import gradio as gr
2
  import spaces
3
  from PIL import Image
4
  from transformers import PaliGemmaForConditionalGeneration, AutoProcessor
 
 
5
 
6
  TITLE = "E621 Tagger"
7
  DESCRIPTION = "Tag images with E621 tags"
@@ -14,12 +16,23 @@ processor = AutoProcessor.from_pretrained(MODEL_ID)
14
  @spaces.GPU
15
  def tag_image(image):
16
  inputs = processor(images=image, text="<image>tag en", return_tensors="pt")
17
- outputs = model.generate(**inputs)
18
- return processor.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  gr.Interface(
21
  fn=tag_image,
22
- inputs=gr.Image(type="pil"),
23
  outputs=gr.Textbox(type="text"),
24
  title=TITLE,
25
  description=DESCRIPTION,
 
2
  import spaces
3
  from PIL import Image
4
  from transformers import PaliGemmaForConditionalGeneration, AutoProcessor
5
+ from transformers import TextIteratorStreamer
6
+ from threading import Thread
7
 
8
  TITLE = "E621 Tagger"
9
  DESCRIPTION = "Tag images with E621 tags"
 
16
  @spaces.GPU
17
  def tag_image(image):
18
  inputs = processor(images=image, text="<image>tag en", return_tensors="pt")
19
+ streamer = TextIteratorStreamer(tokenizer=processor.tokenizer, skip_prompt=True)
20
+ generation_kwargs = dict(inputs, streamer=streamer, return_full_text=False)
21
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
22
+
23
+ text = ""
24
+ thread.start()
25
+ for new_text in streamer:
26
+ text += new_text
27
+ yield text
28
+
29
+ thread.join()
30
+
31
+ return text
32
 
33
  gr.Interface(
34
  fn=tag_image,
35
+ inputs=[gr.Image(type="pil")],
36
  outputs=gr.Textbox(type="text"),
37
  title=TITLE,
38
  description=DESCRIPTION,