import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

tokenizer = AutoTokenizer.from_pretrained(
    "stabilityai/stable-code-3b", trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stable-code-3b",
    trust_remote_code=True,
    torch_dtype="auto",
)


class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the model emits an end-of-text token."""

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        stop_ids = [0, 2]  # token ids treated as end-of-generation markers
        return int(input_ids[0][-1]) in stop_ids


def chat(message, history):
    history = history or []
    inputs = tokenizer(message, return_tensors="pt").to(model.device)
    # Generate a completion, halting early when a stop token appears.
    tokens = model.generate(
        **inputs,
        max_new_tokens=4096,
        temperature=0.2,
        do_sample=True,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )
    # Decode only the newly generated tokens so the prompt is not echoed back.
    new_tokens = tokens[0][inputs["input_ids"].shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    history.append((message, response))
    return history, history


iface = gr.Interface(
    chat,
    ["text", "state"],
    ["chatbot", "state"],
    allow_flagging="never",
)
iface.launch()