|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import time |
|
import gradio as gr |
|
|
|
def generate_prompt(instruction, input=""):

    """Build the prompt string fed to the RWKV model.

    If *input* is non-empty, an Instruction/Input/Response template is used;
    otherwise a two-turn "User"/"Lover" chat template is used.  Both fields are
    trimmed and newline-normalized before being spliced into the template.

    NOTE(review): the parameter `input` shadows the builtin of the same name;
    renaming it would change the keyword-call interface, so it is left as-is.
    """

    # Normalize CRLF to LF, then collapse doubled newlines.  A single
    # .replace('\n\n', '\n') collapses only one level (e.g. '\n\n\n' -> '\n\n');
    # presumably acceptable for typical chat input — TODO confirm intent.
    instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')

    input = input.strip().replace('\r\n','\n').replace('\n\n','\n')

    if input:

        return f"""Instruction: {instruction}



Input: {input}



Response:"""

    else:

        return f"""User: hi



Lover: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.



User: {instruction}



Lover:"""
|
|
|
# --- Model & tokenizer setup -------------------------------------------------
# Local path to the RWKV-6 "World" 1.6B checkpoint (Hugging Face format).
model_path = "models/rwkv-6-world-1b6/"

# trust_remote_code is required: RWKV ships its own modeling code with the
# checkpoint rather than living in the transformers library itself.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    use_flash_attention_2=False,  # explicitly keep flash-attention disabled
).to(torch.float32)  # cast weights to full fp32

tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    bos_token="</s>",
    # FIX: was "</ s>" — the stray interior space made this a token that can
    # match no vocabulary entry; "</s>" is clearly what was intended.
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    trust_remote_code=True,
    padding_side='left',  # left padding is the norm for causal-LM generation
    clean_up_tokenization_spaces=False,
)
|
|
|
|
|
def generate_text(input_text):
    """Generate a response for *input_text*, streaming each token to stdout.

    Wraps *input_text* with generate_prompt(), then decodes one token per
    iteration (capped at 333) so every new word can be printed as soon as it
    is produced.  Stops early on a newline token, a period token, or the
    tokenizer's EOS token.  Returns the accumulated response text.
    """
    prompt = generate_prompt(input_text)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    generated_text = ""
    stop_sequence_found = False
    eos_id = tokenizer.eos_token_id  # may be None for custom tokenizers
    for _ in range(333):
        # NOTE: generate() re-processes the full sequence on every call; fine
        # for short replies, but a streamer / cached-past approach would be
        # faster for long ones.
        output = model.generate(input_ids, max_new_tokens=1, do_sample=True,
                                temperature=1.0, top_p=0.3, top_k=0)

        # FIX: stop at end-of-text.  With skip_special_tokens=True an EOS
        # token decodes to "", so without this check the loop would silently
        # generate past EOS until the iteration cap.
        if eos_id is not None and output[0, -1].item() == eos_id:
            stop_sequence_found = True
            break

        new_word = tokenizer.decode(output[0][-1:], skip_special_tokens=True)

        print(new_word, end="", flush=True)
        generated_text += new_word

        if new_word == '\n' or new_word == '.':
            stop_sequence_found = True
            break

        # Feed the extended sequence back in for the next token.
        input_ids = output

    if stop_sequence_found:
        print("\n(Stop sequence found)")
    print()
    return generated_text
|
|
|
|
|
# Wire the generator into a simple text-in / text-out Gradio UI.
_interface_config = dict(
    fn=generate_text,
    inputs="text",
    outputs="text",
    title="RWKV Chatbot",
    description="Enter your prompt below:",
)
iface = gr.Interface(**_interface_config)
|
|
|
|
|
|
|
|
|
|
|
|