"""Gradio demo: binary text classification with a locally fine-tuned BERT model."""

import string  # NOTE(review): unused in this file — kept in case other code relies on it
import gradio as gr
import requests  # NOTE(review): unused in this file — kept in case other code relies on it
import torch
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

# Local directory holding the fine-tuned checkpoint, tokenizer files and config.
model_dir = "my-bert-model"

config = AutoConfig.from_pretrained(
    model_dir, num_labels=2, finetuning_task="text-classification"
)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(model_dir, config=config)
# fix: put the model in eval mode so dropout/batch-norm behave deterministically
# at inference time (the original never left training mode).
model.eval()


def inference(input_text: str) -> str:
    """Classify ``input_text`` and return the predicted label name.

    The text is tokenized (truncated/padded to 512 tokens), run through the
    model with gradient tracking disabled, and the argmax over the logits is
    mapped to a human-readable label via ``model.config.id2label``.
    """
    # fix: use the standard tokenizer __call__ API. The original called the
    # deprecated batch_encode_plus and passed the removed pad_to_max_length=True
    # together with padding="max_length" — the latter alone is the correct form.
    inputs = tokenizer(
        [input_text],
        max_length=512,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_id = logits.argmax().item()
    return model.config.id2label[predicted_class_id]


with gr.Blocks(css="""
.message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
#component-21 > div.wrap.svelte-w6rprc {height: 600px;}
""") as demo:
    with gr.Column(elem_id="container"):
        # NOTE(review): the doubled gr.Row() looks redundant, but it is kept
        # because the CSS above targets auto-generated component ids
        # (#component-21) that depend on the exact layout tree.
        with gr.Row():
            with gr.Row():
                input_text = gr.Textbox(
                    placeholder="Insert your prompt here:", scale=5, container=False
                )
                answer = gr.Textbox(lines=0, label="Answer")
                generate_bt = gr.Button("Generate", scale=1)

    inputs = [input_text]
    outputs = [answer]
    generate_bt.click(
        fn=inference, inputs=inputs, outputs=outputs, show_progress=False
    )

# fix: guard the server launch so importing this module does not start the app.
if __name__ == "__main__":
    demo.queue()
    demo.launch()