|
import csv
import functools

import gradio as gr
import torch
import transformers
|
|
|
|
|
# Hugging Face Hub checkpoint whose weights back predict_sentiment().
# NOTE(review): AutoModel instantiates the checkpoint's *base* architecture
# with no task head, so its first output is (typically) hidden states rather
# than sentiment logits, and the checkpoint name suggests a conversational
# model. Confirm the intended checkpoint / AutoModelFor... class for sentiment.
_MODEL_CHECKPOINT = "Wessel/DiabloGPT-medium-harrypotter"

# eval() disables training-only behaviour such as dropout and returns the
# module itself, so it can be chained onto the load.
model = transformers.AutoModel.from_pretrained(_MODEL_CHECKPOINT).eval()
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def _get_tokenizer():
    """Load, once, the tokenizer that belongs to the global ``model``.

    The tokenizer must come from the same checkpoint as the model: token ids
    from an unrelated vocabulary (e.g. BERT's) are meaningless to it.
    """
    return transformers.AutoTokenizer.from_pretrained(model.name_or_path)


def _sentiment_from_logits(logits):
    """Map the model's first output tensor to ``"Positive"`` or ``"Negative"``.

    Decision rule: the first value of the first example decides; it is
    ``"Positive"`` when strictly greater than zero, otherwise ``"Negative"``.

    Raises:
        ValueError: if ``logits[0][0]`` is not a single value, e.g. when the
            model returned per-token hidden states instead of per-example
            logits (otherwise an opaque "ambiguous truth value" error).
    """
    score = logits[0][0]
    if score.numel() != 1:
        raise ValueError(
            "Expected a single logit at logits[0][0], got shape "
            f"{tuple(score.shape)}; the loaded model does not appear to "
            "have a classification head."
        )
    return "Positive" if score.item() > 0 else "Negative"


def predict_sentiment(input_text):
    """Classify ``input_text`` as ``"Positive"`` or ``"Negative"``.

    Tokenizes the text with the model's own tokenizer (special tokens added),
    runs one inference-only forward pass of the global ``model``, and applies
    ``_sentiment_from_logits`` to the first element of the model output.
    """
    # encode() is an instance method; return_tensors="pt" yields a
    # (1, seq_len) LongTensor, i.e. a batch of one.
    input_ids = _get_tokenizer().encode(
        input_text, add_special_tokens=True, return_tensors="pt"
    )
    # No gradients are needed for inference; skip building the autograd graph.
    with torch.no_grad():
        outputs = model(input_ids)
    return _sentiment_from_logits(outputs[0])
|
|
|
|
|
# Running transcript of exchanges as alternating "User: ..." / "Model: ..."
# strings. It lives at module level and is never reset, so every caller
# appends to the same list.
chat_history = []


def update_history(input_text, sentiment):
    """Append one exchange to ``chat_history`` and return ``None``.

    Adds the user's text first, then the model's answer, each prefixed
    with its speaker; the list is mutated in place.
    """
    for speaker, text in (("User", input_text), ("Model", sentiment)):
        chat_history.append(f"{speaker}: {text}")
|
|
|
|
|
# Dropdown choices: the first column of every non-empty row of prompts.csv,
# resolved relative to the current working directory.
# - newline="" lets the csv module handle line endings itself (required for
#   quoted fields that contain newlines).
# - utf-8-sig decodes UTF-8 and strips a leading BOM, if present, instead of
#   depending on the platform's locale encoding.
# - Blank rows parse as [] and are skipped rather than raising IndexError.
with open("prompts.csv", newline="", encoding="utf-8-sig") as csvfile:
    prompts = [row[0] for row in csv.reader(csvfile) if row]
|
|
|
|
|
# Prompt picker. Uses the top-level gr.Dropdown component (the gr.inputs
# namespace is deprecated/removed); pre-selects the first prompt when the
# CSV yielded any, instead of crashing on an empty list.
inputs = gr.Dropdown(
    choices=prompts,
    value=prompts[0] if prompts else None,
)
|
|
|
|
|
# The prediction is a single "Positive"/"Negative" string, so a one-line
# textbox displays it directly (a chatbot component would expect a list of
# message pairs instead of a plain string).
outputs = gr.Textbox(label="Sentiment", lines=1)
|
|
|
|
|
def _predict_and_record(input_text):
    """Classify ``input_text``, record the exchange via ``update_history``,
    and return the predicted label for display."""
    sentiment = predict_sentiment(input_text)
    update_history(input_text, sentiment)
    return sentiment


# gr.Interface does not accept an `on_output` callback, so the history is
# recorded inside the wrapped function instead.
interface = gr.Interface(
    _predict_and_record,
    inputs,
    outputs,
    title="Sentiment Analysis",
)

interface.launch()