# Spaces: Runtime error
import gradio as gr
import numpy as np
import torch
from datasets import load_dataset, load_metric
from transformers import (
    DataCollatorWithPadding,
    RobertaForSequenceClassification,
    RobertaTokenizer,
    Trainer,
    TrainingArguments,
)
# Load the IMDB dataset and the accuracy metric used during evaluation.
dataset = load_dataset('imdb')
# NOTE(review): `load_metric` is reported as deprecated in recent `datasets`
# releases in favour of the separate `evaluate` package — confirm against the
# installed version before upgrading.
metric = load_metric('accuracy')

# Load the RoBERTa tokenizer and sequence-classification model.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForSequenceClassification.from_pretrained('roberta-base')
# Tokenize the data.
def preprocess_function(examples):
    """Tokenize the ``text`` field of a batch of examples.

    Sequences are truncated but deliberately NOT padded here: padding is left
    to the ``DataCollatorWithPadding`` handed to the ``Trainer``, which pads
    each batch only to the length of its longest member. Padding every review
    to the maximum length up front would make that collator a no-op and
    inflate memory use and compute for short reviews.

    Args:
        examples: A batch mapping that contains a ``'text'`` entry.

    Returns:
        The tokenizer output for ``examples['text']``.
    """
    return tokenizer(examples['text'], truncation=True)
# Apply the tokenization to the loaded dataset, processing examples in batches.
tokenized_datasets = dataset.map(preprocess_function, batched=True)

# Prepare the data collator that pads each batch when it is assembled.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Configure the training arguments.
# NOTE(review): newer `transformers` releases are reported to rename
# `evaluation_strategy` to `eval_strategy` — confirm against the pinned version.
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy='epoch',
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
)
# Define the metrics function.
def compute_metrics(eval_pred):
    """Compute accuracy for one evaluation pass.

    The ``Trainer`` hands over predictions and labels as NumPy arrays, so the
    arg-max is taken with NumPy; ``torch.argmax`` rejects non-tensor input.

    Args:
        eval_pred: A ``(logits, labels)`` pair as supplied by the ``Trainer``.

    Returns:
        Whatever ``metric.compute`` returns for the predicted class indices
        against the reference labels.
    """
    logits, labels = eval_pred
    # np.asarray also accepts CPU tensors, so either input kind is handled.
    predictions = np.asarray(logits).argmax(axis=-1)
    return metric.compute(predictions=predictions, references=labels)
# Define the Trainer: train on the 'train' split, evaluate on the 'test' split.
# NOTE(review): newer `transformers` releases are reported to prefer
# `processing_class=` over `tokenizer=` — confirm against the pinned version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
# Train the model.
trainer.train()

# Evaluate the model and show the resulting metrics.
results = trainer.evaluate()
print(results)

# Save the model and tokenizer side by side in the same directory.
model.save_pretrained('./model')
tokenizer.save_pretrained('./model')
# Inference function.
def predict(text):
    """Classify a single piece of review text.

    Args:
        text: The review text to classify.

    Returns:
        ``"Positive"`` when the highest-scoring class index is 1, otherwise
        ``"Negative"``.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    # Training may have moved the model to an accelerator; the inputs must
    # live on the same device or the forward pass fails with a device mismatch.
    inputs = inputs.to(model.device)
    model.eval()
    # Inference needs no gradients; skipping them saves memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    predictions = torch.argmax(outputs.logits, dim=-1)
    return "Positive" if predictions.item() == 1 else "Negative"
# Gradio interface.
# `gr.Textbox` is used in place of the legacy `gr.inputs.Textbox`, which is
# not available in current Gradio releases.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="Enter a movie review..."),
    outputs="text",
    title="IMDB Review Sentiment Analysis",
    description="A simple Gradio interface to predict sentiment of IMDB movie reviews using a RoBERTa model.",
)
# Launch the web UI only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()