khurrameycon commited on
Commit
a3cc5d4
·
verified ·
1 Parent(s): 7f7c55c

TextIteratorStreamer

Browse files
Files changed (1) hide show
  1. app.py +22 -5
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  import os
3
  import torch
4
- from transformers import AutoProcessor, MllamaForConditionalGeneration
5
  from PIL import Image
6
  import spaces
7
  import tempfile
@@ -91,11 +91,28 @@ def predict_text(text, url = 'https://arinsight.co/2024_FA_AEC_1200_GR1_GR2.pdf'
91
  # inputs = processor(image, input_text, return_tensors="pt").to(device)
92
  inputs = processor(text=input_text, return_tensors="pt").to("cuda")
93
  # Generate a response from the model
94
- outputs = model.generate(**inputs, max_new_tokens=1024)
95
 
96
- # Decode the output to return the final response
97
- response = processor.decode(outputs[1], skip_special_tokens=True, skip_prompt=True)
98
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
 
101
  # Define the Gradio interface
 
1
  import gradio as gr
2
  import os
3
  import torch
4
+ from transformers import AutoProcessor, MllamaForConditionalGeneration, TextIteratorStreamer
5
  from PIL import Image
6
  import spaces
7
  import tempfile
 
91
  # inputs = processor(image, input_text, return_tensors="pt").to(device)
92
  inputs = processor(text=input_text, return_tensors="pt").to("cuda")
93
  # Generate a response from the model
94
+ # outputs = model.generate(**inputs, max_new_tokens=1024)
95
 
96
+ # # Decode the output to return the final response
97
+ # response = processor.decode(outputs[0], skip_special_tokens=True, skip_prompt=True)
98
+
99
+
100
+ streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
101
+
102
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
103
+ generated_text = ""
104
+
105
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
106
+ thread.start()
107
+ buffer = ""
108
+
109
+ for new_text in streamer:
110
+ buffer += new_text
111
+ # generated_text_without_prompt = buffer
112
+ # # time.sleep(0.01)
113
+ # yield buffer
114
+
115
+ return buffer
116
 
117
 
118
  # Define the Gradio interface