"""Gradio demo that labels a chosen prompt as "Positive" or "Negative".

Prompts are read from ``prompts.csv`` (first column of each row) and offered
in a dropdown. Every prediction made through the UI is appended to the
module-level ``chat_history`` list.
"""
import csv

import gradio as gr
import torch
import transformers

# Hugging Face checkpoint used for both the tokenizer and the model.
MODEL_NAME = "Wessel/DiabloGPT-medium-harrypotter"

# NOTE(review): this checkpoint is named like a conversational (DialoGPT-style)
# model, not a sentiment classifier, and AutoModel loads it without any
# classification head. The label computed below is therefore a placeholder
# threshold on a raw model output, not a trained sentiment prediction. Swap in
# a sentiment checkpoint (e.g. loaded with AutoModelForSequenceClassification)
# to get meaningful labels.

# Use the tokenizer that matches the checkpoint. Calling
# BertTokenizer.encode on the class itself (no loaded vocabulary) fails, and a
# BERT vocabulary would not match this model's token ids anyway.
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModel.from_pretrained(MODEL_NAME)
model.eval()  # inference only: disable dropout


def predict_sentiment(input_text):
    """Return "Positive" or "Negative" for *input_text*.

    The label is decided by the sign of the first element of the model's
    first output for the single input sequence.

    Raises:
        ValueError: if *input_text* is empty or None.
    """
    if not input_text:
        # An empty string tokenizes to zero tokens, which the model rejects.
        raise ValueError("input_text must be a non-empty string")
    input_ids = tokenizer.encode(
        input_text, add_special_tokens=True, return_tensors="pt"
    )
    with torch.no_grad():  # no gradients are needed for inference
        outputs = model(input_ids)
    logits = outputs[0]
    # Reduce to one scalar before comparing. For a base model, logits[0] holds
    # more than one element, and `if tensor > 0` on a multi-element tensor
    # raises "Boolean value of Tensor with more than one element is ambiguous".
    score = logits[0].reshape(-1)[0].item()
    return "Positive" if score > 0 else "Negative"


# Previous inputs and outputs, stored as alternating "User: ..." and
# "Model: ..." strings.
chat_history = []


def update_history(input_text, sentiment):
    """Append one user/model exchange to ``chat_history``."""
    chat_history.append(f"User: {input_text}")
    chat_history.append(f"Model: {sentiment}")


def _predict_and_record(input_text):
    """Gradio callback: predict the label, record the exchange, return it."""
    # gr.Interface has no `on_output` hook, so history is recorded here.
    sentiment = predict_sentiment(input_text)
    update_history(input_text, sentiment)
    return sentiment


def _load_prompts(path):
    """Return the first column of every row of the CSV at *path* whose first cell is non-empty."""
    # newline="" is what the csv module requires. utf-8-sig also strips a
    # byte-order mark that spreadsheet exports often prepend.
    with open(path, newline="", encoding="utf-8-sig") as csvfile:
        return [row[0] for row in csv.reader(csvfile) if row and row[0]]


prompts = _load_prompts("prompts.csv")
if not prompts:
    raise ValueError("prompts.csv contains no prompts")

# Input: a dropdown of the CSV prompts, preselecting the first one.
inputs = gr.Dropdown(choices=prompts, value=prompts[0], label="Prompt")

# Output: a single-line textbox showing the predicted label.
outputs = gr.Textbox(label="Sentiment", lines=1)

interface = gr.Interface(
    _predict_and_record, inputs, outputs, title="Sentiment Analysis"
)

if __name__ == "__main__":
    interface.launch()