Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
from transformers import TextIteratorStreamer | |
import threading | |
class ModelWrapper:
    """Lazily loaded quantized causal LM with streaming text generation.

    The model is not loaded at construction time; it is loaded on the first
    call to :meth:`generate` (i.e. once a GPU has been allocated).

    NOTE(review): relies on ``AutoGPTQForCausalLM``, ``model_id`` and
    ``tokenizer`` being available at module level; they are not defined in
    this part of the file -- confirm they are provided elsewhere.
    """

    def __init__(self):
        self.model = None  # Model will be loaded when GPU is allocated

    def generate(self, prompt):
        """Stream the text generated for *prompt*.

        Args:
            prompt: Input text, tokenized with the module-level ``tokenizer``.

        Yields:
            The accumulated text received from the streamer so far, i.e. each
            yielded string extends the previous one with the newest chunk.

        Raises:
            Exception: Any exception raised by ``self.model.generate`` in the
                worker thread is re-raised here once the stream is drained,
                instead of leaving the consumer blocked forever.
        """
        if self.model is None:
            # Load the model when GPU is allocated
            self.model = AutoGPTQForCausalLM.from_quantized(
                model_id,
                device_map='auto',
                trust_remote_code=True,
            )
            self.model.eval()

        # Tokenize the input prompt
        inputs = tokenizer(prompt, return_tensors='pt').to('cuda')

        # Set up the streamer
        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

        # Prepare generation arguments
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            do_sample=True,
            max_new_tokens=512,
        )

        # Holds at most one exception raised inside the worker thread.
        worker_errors = []

        def _run_generation():
            try:
                self.model.generate(**generation_kwargs)
            except Exception as exc:
                worker_errors.append(exc)
                # generate() did not finish, so the streamer never received
                # its end-of-stream signal; send it so the consumer loop
                # below terminates instead of blocking on the queue.
                streamer.end()

        # Start generation in a separate thread to enable streaming
        thread = threading.Thread(target=_run_generation)
        thread.start()

        # Yield generated text in real-time
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text

        # The stream is exhausted, so the worker is finishing; reap it.
        thread.join()

        if worker_errors:
            raise worker_errors[0]