Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
import torch
|
2 |
from transformers import RobertaTokenizer, RobertaForSequenceClassification, Trainer, TrainingArguments
|
3 |
from datasets import load_dataset, load_metric
|
|
|
4 |
|
5 |
# Carregar o dataset IMDB
|
6 |
dataset = load_dataset('imdb')
|
|
|
7 |
|
8 |
# Carregar o tokenizer e o modelo RoBERTa
|
9 |
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
|
@@ -57,15 +59,21 @@ print(results)
|
|
57 |
model.save_pretrained('./model')
|
58 |
tokenizer.save_pretrained('./model')
|
59 |
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
commit_message="Initial upload of RoBERTa IMDB model"
|
70 |
)
|
71 |
|
|
|
|
|
|
1 |
import torch
|
2 |
from transformers import RobertaTokenizer, RobertaForSequenceClassification, Trainer, TrainingArguments
|
3 |
from datasets import load_dataset, load_metric
|
4 |
+
import gradio as gr
|
5 |
|
6 |
# Carregar o dataset IMDB
|
7 |
dataset = load_dataset('imdb')
|
8 |
+
metric = load_metric('accuracy')
|
9 |
|
10 |
# Carregar o tokenizer e o modelo RoBERTa
|
11 |
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
|
|
|
59 |
model.save_pretrained('./model')
|
60 |
tokenizer.save_pretrained('./model')
|
61 |
|
62 |
+
# Função de inferência
|
63 |
+
def predict(text):
|
64 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
65 |
+
outputs = model(**inputs)
|
66 |
+
predictions = torch.argmax(outputs.logits, dim=-1)
|
67 |
+
return "Positive" if predictions.item() == 1 else "Negative"
|
68 |
|
69 |
+
# Interface Gradio
|
70 |
+
iface = gr.Interface(
|
71 |
+
fn=predict,
|
72 |
+
inputs=gr.inputs.Textbox(lines=2, placeholder="Enter a movie review..."),
|
73 |
+
outputs="text",
|
74 |
+
title="IMDB Review Sentiment Analysis",
|
75 |
+
description="A simple Gradio interface to predict sentiment of IMDB movie reviews using a RoBERTa model."
|
|
|
76 |
)
|
77 |
|
78 |
+
if __name__ == "__main__":
|
79 |
+
iface.launch()
|