import gradio as gr
# Disabled model setup, kept for reference.
# import torch
# from threading import Thread
# tokenizer = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
# model = AutoModelForCausalLM.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1", torch_dtype=torch.float16)
# model = model.to('cuda:0')
#
# class StopOnTokens(StoppingCriteria):
#     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
#         stop_ids = [29, 0]
#         for stop_id in stop_ids:
#             if input_ids[0][-1] == stop_id:
#                 return True
#         return False
def _build_prompt(message, history):
    """Render the earlier turns plus the new message as one prompt string.

    Each turn contributes ``"\\n<human>:" + user_text + "\\n<bot>:" + bot_text``.
    The new message is appended as a final turn with an empty bot part.

    Args:
        message: The newest user message.
        history: Iterable of ``[user_text, bot_text]`` pairs, or ``None``.

    Returns:
        The concatenated prompt string.
    """
    # Copy into a list so a ``None`` or tuple history cannot break the "+".
    turns = list(history or []) + [[message, ""]]
    # ``or ""`` keeps a turn whose text is still None from raising TypeError.
    return "".join(
        f"\n<human>:{turn[0] or ''}\n<bot>:{turn[1] or ''}" for turn in turns
    )


def predict(message, history):
    """Answer a chat message with a fixed placeholder reply.

    The model-backed streaming generation is disabled (see the reference
    comment in the body), so every call returns the same string.

    Args:
        message: The newest user message.
        history: Earlier turns as ``[user_text, bot_text]`` pairs. May be
            empty or ``None``; a ``None`` text is treated as an empty string.

    Returns:
        The string ``"Hello"``.
    """
    # Only consumed by the disabled model call below; unused until restored.
    messages = _build_prompt(message, history)

    # Disabled streaming generation, kept for reference.
    # NOTE(review): `stop` is never defined in the visible code; presumably
    # it should be `StopOnTokens()` - confirm before re-enabling.
    #
    # model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
    # streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    # generate_kwargs = dict(
    #     model_inputs,
    #     streamer=streamer,
    #     max_new_tokens=1024,
    #     do_sample=True,
    #     top_p=0.95,
    #     top_k=1000,
    #     temperature=1.0,
    #     num_beams=1,
    #     stopping_criteria=StoppingCriteriaList([stop])
    # )
    # t = Thread(target=model.generate, kwargs=generate_kwargs)
    # t.start()
    # partial_message = ""
    # for new_token in streamer:
    #     if new_token != '<':
    #         partial_message += new_token
    #         yield partial_message
    return "Hello"
# Build the chat UI around `predict` and start the Gradio server.
demo = gr.ChatInterface(predict)
demo.launch()