Spaces:
Runtime error
Runtime error
import gradio as gr | |
import os | |
import transformers | |
from transformers import AutoModel, AutoTokenizer | |
import numpy as np | |
import torch | |
model_names = ['python']
models = {}

# Output unit i of the classifier is reported as target_cols[i], so this order
# must match the label order the checkpoint was trained with.
target_cols = ['DevelopmentNotes', 'Expand', 'Parameters', 'Summary', 'Usage']
# Token length every comment is padded/truncated to before inference.
MAX_LEN = 200

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("device", device)

# The previous AutoModel call loaded only the encoder, whose outputs are
# hidden states rather than one logit per category, so sigmoid-thresholding
# them could not work. Load the sequence-classification head instead.
# NOTE(review): assumes 'ZarahShibli/tmp_trainer' was saved from an
# AutoModelForSequenceClassification-style model with len(target_cols)
# outputs (the "tmp_trainer" name suggests a transformers Trainer run) —
# confirm against the checkpoint's config.json / architectures field.
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    'ZarahShibli/tmp_trainer',
    num_labels=len(target_cols),
    return_dict=False,  # forward() returns a plain tuple: (logits,)
)
# Move the weights once so they live on the same device the inputs are sent to,
# and disable dropout for inference.
model.to(device)
model.eval()

# return_tensors is a per-call option (predict passes it to encode_plus), not a
# loading option, so it is not passed here.
# NOTE(review): tokenizer comes from 'bert-base-uncased', not from the
# checkpoint repo — confirm this is the tokenizer the model was trained with.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
def predict(comment_sentence, *, threshold=0.25):
    """Classify a code comment into zero or more categories from ``target_cols``.

    The comment is tokenized with the module-level ``tokenizer`` (padded and
    truncated to ``MAX_LEN`` tokens), run through the module-level ``model``,
    and each output score is passed through a sigmoid and compared against
    ``threshold`` independently (multi-label classification).

    Args:
        comment_sentence: The comment text to classify. Empty or
            whitespace-only input returns an empty list without running
            the model.
        threshold: Minimum sigmoid probability for a category to be reported.
            Defaults to 0.25, the previously hard-coded value.

    Returns:
        A list of names from ``target_cols`` whose probability is
        ``>= threshold``, in ``target_cols`` order.

    Raises:
        ValueError: If the model's output is not a single row with exactly one
            score per entry in ``target_cols`` (e.g. a headless encoder was
            loaded instead of a classification model).
    """
    if not comment_sentence or not comment_sentence.strip():
        return []

    # Inference mode: disables dropout.
    model.eval()
    # Send inputs to wherever the weights actually live, so this works whether
    # or not the model was moved to ``device`` when it was loaded.
    model_device = next(model.parameters()).device

    # Tokenize with [CLS]/[SEP] and pad/truncate to a fixed length.
    inputs = tokenizer.encode_plus(
        comment_sentence,
        truncation=True,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        return_token_type_ids=True,
        return_tensors='pt',
    )
    ids = inputs['input_ids'].to(model_device, dtype=torch.long)
    mask = inputs['attention_mask'].to(model_device, dtype=torch.long)
    token_type_ids = inputs['token_type_ids'].to(model_device, dtype=torch.long)

    with torch.no_grad():
        outputs = model(ids, mask, token_type_ids)

    # Hugging Face models return a tuple (return_dict=False) or a ModelOutput
    # (return_dict=True); in both cases index 0 is the primary output (logits
    # for a classification head). A plain nn.Module may return the tensor itself.
    logits = outputs if torch.is_tensor(outputs) else outputs[0]
    if logits.dim() != 2 or logits.shape[0] != 1 or logits.shape[-1] != len(target_cols):
        raise ValueError(
            f"expected model output of shape (1, {len(target_cols)}), got "
            f"{tuple(logits.shape)}; is a classification head loaded?"
        )

    # .cpu() is required before .numpy() when the model runs on CUDA.
    probs = torch.sigmoid(logits[0]).cpu().numpy()
    return [label for label, prob in zip(target_cols, probs) if prob >= threshold]
# Web UI: one free-text box goes in, and predict's list of matching categories
# is shown as text. launch() starts the Gradio server when this file is run.
iface = gr.Interface(
    inputs="text",
    outputs="text",
    fn=predict,
)
iface.launch()