"""Gradio demo for the ``allenai/longformer-scico`` sequence classifier.

The script loads the tokenizer and model once at import time, exposes a
two-textbox web form, and shows the softmax scores of the four output
classes for the pair of passages the user enters.
"""
import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_NAME = 'allenai/longformer-scico'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

# Mention boundary markers. The original literals were empty strings (the
# markup had been stripped), which made both lookups resolve to the same
# meaningless id. NOTE(review): "<m>" / "</m>" follow the model card's usage
# example -- confirm against the tokenizer's added tokens.
start_token = tokenizer.convert_tokens_to_ids("<m>")
end_token = tokenizer.convert_tokens_to_ids("</m>")

# Output labels, in the order of the model's logits (index 0..3).
LABELS = ('not related', 'coref', 'parent', 'child')


def get_global_attention(input_ids):
    """Build the global attention mask for a batch of token ids.

    Args:
        input_ids: Integer tensor of token ids; the first axis is indexed as
            the batch and the second as the token position.

    Returns:
        A float tensor with the same shape as ``input_ids`` holding 1 at
        position 0 of every row and at every occurrence of ``start_token``
        or ``end_token``, and 0 elsewhere.
    """
    global_attention_mask = torch.zeros(
        input_ids.shape, device=input_ids.device
    )
    global_attention_mask[:, 0] = 1  # global attention to the CLS token
    # Global attention to every mention-start and mention-end marker.
    # Boolean indexing is a no-op when no marker is present.
    is_marker = (input_ids == start_token) | (input_ids == end_token)
    global_attention_mask[is_marker] = 1
    return global_attention_mask


def inference(m1, m2):
    """Score the relation between the mentions in two passages.

    Args:
        m1: First passage, as text.
        m2: Second passage, as text.

    Returns:
        A dict mapping each of ``'not related'``, ``'coref'``, ``'parent'``
        and ``'child'`` to its softmax probability (a float), which is the
        shape the Gradio ``Label`` output expects.
    """
    # NOTE(review): the separator literal had been reduced to a single space;
    # " </s></s> " follows the model card's usage example -- confirm.
    inputs = m1 + " </s></s> " + m2

    tokens = tokenizer(inputs, return_tensors='pt')
    global_attention_mask = get_global_attention(tokens['input_ids'])

    with torch.no_grad():
        output = model(
            input_ids=tokens['input_ids'],
            attention_mask=tokens['attention_mask'],
            global_attention_mask=global_attention_mask,
        )

    # logits has one row per input sequence; only one sequence is submitted,
    # so take row 0 to get a flat list of four probabilities. (Indexing the
    # nested list directly raised IndexError for every label after the first.)
    scores = torch.softmax(output.logits, dim=-1)[0].tolist()
    return dict(zip(LABELS, scores))


title = "Longformer-scico"
# The previous description was left over from an unrelated image demo.
# The markers are written as HTML entities so they survive HTML rendering.
description = (
    "Demo for Longformer-SciCo. Enter two passages, each containing one "
    "mention wrapped in &lt;m&gt; and &lt;/m&gt; markers, to see how the two "
    "mentions are related."
)
# NOTE(review): this footer text was carried over from an unrelated demo and
# its link markup has been lost; replace it with links to the SciCo paper and
# repository. The visible text is preserved as-is until then.
article = (
    "\n\n"
    "Adversarial Open Domain Adaption for Sketch-to-Photo Synthesis | Github Repo"
    "\n\n"
)

gr.Interface(
    inference,
    # inference() takes two passages, so the form needs two text inputs;
    # a single Textbox made every call fail with a missing-argument error.
    [
        gr.inputs.Textbox(label="Input 1"),
        gr.inputs.Textbox(label="Input 2"),
    ],
    gr.outputs.Label(label="Output"),
    title=title,
    description=description,
    article=article,
    enable_queue=True,
).launch(debug=True)