ROS2 / app.py
AnasHXH's picture
Create app.py
79bb5a6
raw
history blame
991 Bytes
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Identifier of the fine-tuned sequence-to-sequence checkpoint that is loaded
# below. The tokenizer and the model are both created from this one
# identifier, so they always match each other.
model_checkpoint = "AnasHXH/Ros_model"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_checkpoint),
    AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint),
)
def generate_command(input_text, max_new_tokens=None):
    """Translate an English instruction into a robot command string.

    Args:
        input_text: English text typed by the user.
        max_new_tokens: Optional upper bound on the number of tokens the model
            may generate. When None (the default), the model's own generation
            settings decide the output length, exactly as before.

    Returns:
        The decoded model output with special tokens stripped, or an empty
        string when ``input_text`` is None, empty, or only whitespace.
    """
    # Nothing to translate: skip the model call rather than generating
    # output for an empty prompt (or crashing on None).
    if not (input_text and input_text.strip()):
        return ""

    # Encode the single prompt into PyTorch tensors.
    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)

    generate_kwargs = {
        "input_ids": inputs["input_ids"],
        # Hand the tokenizer's own attention mask to generate() instead of
        # discarding it and letting generate() reconstruct one.
        "attention_mask": inputs.get("attention_mask"),
    }
    if max_new_tokens is not None:
        generate_kwargs["max_new_tokens"] = max_new_tokens

    outputs = model.generate(**generate_kwargs)
    # Only one prompt was encoded, so the first generated sequence is the answer.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Build a single text-box-in / text-box-out web UI around generate_command.
# Gradio's Interface takes the wrapped function, the input component and the
# output component as its first three positional arguments.
iface = gr.Interface(
    generate_command,
    "text",
    "text",
    title="Robot Command Generator",
    description="Type in English to get the robot command",
)

# Start serving the UI. This is a module-level call, so it runs as soon as
# this file is executed.
iface.launch()