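# RWKV / app.py
# Author: Stevross — commit c53c6b7, "Create app.py" (1.52 kB).
# (Originally a pasted repository file-viewer header; kept as comments so this
# file parses as Python.)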
import argparse
import functools

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Models offered in the UI dropdown.
# Keys are the labels shown to the user; values are the Hugging Face Hub repo
# IDs actually passed to ``from_pretrained``. The ``BlinkDL/*`` repos host raw
# RWKV ``.pth`` checkpoints that ``transformers`` cannot load directly, so those
# labels are mapped to the transformers-format conversions under the ``RWKV``
# organisation.
models = {
    "EleutherAI/gpt-neo-2.7B": "EleutherAI/gpt-neo-2.7B",
    "BlinkDL/rwkv-4-pile-430m": "RWKV/rwkv-4-430m-pile",
    "BlinkDL/rwkv-4-pile-1b5": "RWKV/rwkv-4-1b5-pile",
    # NOTE(review): "RWKV-4-Raven" comes in several sizes; the 1.5B conversion
    # is used here — switch to rwkv-raven-3b / -7b / -14b if a larger one was
    # intended.
    "BlinkDL/RWKV-4-Raven": "RWKV/rwkv-raven-1b5",
}
@functools.lru_cache(maxsize=1)
def _load_model(repo_id):
    """Load and cache the (tokenizer, model) pair for Hub repo ``repo_id``.

    Only the most recently used pair is kept (``maxsize=1``). Repeated prompts
    to the same model skip the expensive reload, and switching models does not
    keep several multi-GB checkpoints resident at once.
    """
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id)
    return tokenizer, model


def generate_text(prompt, model_name, max_new_tokens=100):
    """Generate a continuation of ``prompt`` with the selected model.

    Args:
        prompt: Text to continue.
        model_name: A key of ``models`` (as sent by the dropdown) or a raw
            Hugging Face Hub repo ID.
        max_new_tokens: Maximum number of tokens generated after the prompt.

    Returns:
        The decoded prompt followed by its generated continuation, with
        special tokens removed.

    Raises:
        gr.Error: If no model is selected or the prompt is empty; Gradio
            shows the message to the user instead of a stack trace.
    """
    if not model_name:
        raise gr.Error("Please select a model.")
    if not prompt:
        # An empty prompt tokenizes to zero ids, which generate() cannot start from.
        raise gr.Error("Please enter a prompt.")

    # Dropdown labels map to loadable repo IDs; unknown names pass through as-is.
    tokenizer, model = _load_model(models.get(model_name, model_name))

    inputs = tokenizer(prompt, return_tensors="pt")
    # Some tokenizers (e.g. GPT-Neo's) define no pad token; fall back to EOS so
    # generate() has an explicit pad id.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    output = model.generate(
        inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        # max_new_tokens rather than max_length: max_length counts prompt
        # tokens too, so long prompts would get no continuation at all.
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        pad_token_id=pad_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
def main():
    """Build the Gradio UI around ``generate_text`` and start the web server."""
    # gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in
    # Gradio 4; the top-level components below work in both.
    choices = list(models)
    model_dropdown = gr.Dropdown(
        choices=choices,
        # Preselect a model so generate_text never receives None.
        value=choices[0] if choices else None,
        label="Select Model",
    )
    prompt_input = gr.Textbox(
        lines=5,
        placeholder="Enter your prompt here...",
        label="Prompt",
    )
    output_text = gr.Textbox(label="Generated Text")

    interface = gr.Interface(
        fn=generate_text,
        inputs=[prompt_input, model_dropdown],
        outputs=output_text,
        title="Chat-bot using RWKV LLM",
        description="Select a model and enter a prompt to generate text using the chat-bot.",
    )
    # Blocks here and serves the UI until the process is stopped.
    interface.launch()
# Launch the UI only when run as a script, not when this module is imported.
if __name__ == "__main__":
    main()