"""Gradio demo: WinoWhy explanations and "anti-explanations" from a fine-tuned GPT-2."""
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Fine-tuned GPT-2 checkpoint, loaded (Hub or local cache) at import time.
tokenizer = GPT2Tokenizer.from_pretrained("monsoon-nlp/gpt-winowhy")
model = GPT2LMHeadModel.from_pretrained("monsoon-nlp/gpt-winowhy", pad_token_id=tokenizer.eos_token_id)

# Appended after the prompt; the model's explanation follows it.
# NOTE(review): presumably the separator used during fine-tuning — confirm.
_MARKER = ' %'

# Linking verbs kept together with the word after them when building the
# forced prefix for the alternative reply.
_COPULAS = ('is', 'are', 'was', 'were')

# Characters trimmed from the ends of prompt words before treating them as names.
_NAME_PUNCTUATION = '.,;:!?"\'()'


def _prepare_input(prompt):
    """Return the model input built from `prompt`, or None if it is blank.

    The stripped prompt gets a terminating '.' unless it already ends in
    '?', '!' or '.', and is then followed by the ' %' marker.
    """
    inp = prompt.strip()
    if not inp:
        return None
    if inp[-1] not in ('?', '!', '.'):
        inp += '.'
    return inp + _MARKER


def _generate(text, **generate_kwargs):
    """Encode `text`, run `model.generate`, and return the decoded first sequence."""
    input_ids = torch.tensor([tokenizer.encode(text)])
    output = model.generate(input_ids, **generate_kwargs)
    return tokenizer.decode(output[0], skip_special_tokens=True)


def _text_after_marker(decoded, model_input):
    """Return the part of `decoded` that follows the marker ending `model_input`.

    Normally the decoded sequence starts with the input verbatim, so the
    reply is whatever follows it. This also copes with prompts that contain
    '%' themselves. If decoding altered the prefix, fall back to the text
    after the first '%', or to the whole text when there is no '%' at all.
    """
    if decoded.startswith(model_input):
        return decoded[len(model_input):]
    _, sep, after = decoded.partition('%')
    return after if sep else decoded


def _candidate_names(prompt, items):
    """Return the candidate entity names, in order.

    If `items` contains a comma, its non-empty comma-separated entries are
    used. Otherwise the capitalized words of `prompt` are used, with
    surrounding punctuation trimmed. When more than two are found, the
    first is dropped, on the assumption that it is only capitalized because
    it starts the sentence.
    """
    items = items or ''
    if ',' in items:
        names = [item.strip() for item in items.split(',') if item.strip()]
    else:
        trimmed = (word.strip(_NAME_PUNCTUATION) for word in prompt.split())
        names = [word for word in trimmed if word[:1].isupper()]
    if len(names) > 2:
        # remove first one which assumedly is a capital
        names = names[1:]
    return names


def _forced_prefix(resp1, first, second):
    """Swap the earliest-mentioned of two names in `resp1` for the other one.

    Returns the reply text up to that name, with the other name substituted,
    followed by the next word (or the next two, if the next word is a
    copula). Returns None if neither name occurs in `resp1`.
    """
    if first in resp1 and (second not in resp1 or resp1.index(first) < resp1.index(second)):
        found, other = first, second
    elif second in resp1:
        found, other = second, first
    else:
        return None
    cut = resp1.index(found)
    prefix = resp1[:cut] + other
    # split() (not split(' ')) so an empty tail really yields no words.
    remainder = resp1[cut + len(found):].split()
    if remainder:
        if remainder[0] in _COPULAS:
            prefix += ' ' + ' '.join(remainder[:2])
        else:
            prefix += ' ' + remainder[0]
    return prefix


def hello(prompt, items):
    """Generate an explanation and an "anti-explanation" for a WinoWhy prompt.

    Args:
        prompt: the WinoWhy sentence/question.
        items: optional comma-separated candidate answers. When it contains
            no comma, capitalized words from the prompt are used instead.

    Returns:
        A two-element list [resp1, resp2]:
        - resp1 is the fine-tuned model's reply.
        - resp2 is an alternative reply, sampled after forcing the other
          candidate into the spot where resp1 first names one.
        If no alternative can be built, resp2 is an explanatory message.
    """
    inp = _prepare_input(prompt)
    if inp is None:
        return ['Please enter a prompt.', '']

    # max_length counts the prompt tokens as well as the generated ones.
    resp1 = _text_after_marker(_generate(inp, max_length=35), inp)

    names = _candidate_names(prompt, items)
    if len(names) < 2:
        return [resp1, 'Need two names: list them under Answers, separated by a comma']

    force_inp = _forced_prefix(resp1, names[0], names[1])
    if force_inp is None:
        return [resp1, 'Name not present']

    alt = inp + ' ' + force_inp
    # do_sample=True: the alternative reply varies from call to call.
    decoded_alt = _generate(alt, max_new_tokens=30, min_length=30, do_sample=True)
    resp2 = _text_after_marker(decoded_alt, inp)
    return [resp1, resp2]


# NOTE(review): gr.inputs / gr.outputs and the `verbose` argument are legacy
# Gradio APIs (deprecated in 3.x, removed in 4.x); pin an older gradio or
# migrate to gr.Textbox before upgrading.
io = gr.Interface(
    fn=hello,
    inputs=[
        gr.inputs.Textbox(label="WinoWhy Prompt"),
        gr.inputs.Textbox(label="Answers (optional)"),
    ],
    outputs=[
        gr.outputs.Textbox(label="Fine-tuned reply"),
        gr.outputs.Textbox(label="Alternative reply"),
    ],
    verbose=True,
    title='Anti-Explanations',
    description='Learn more at https://medium.com/nerd-for-tech/searching-for-anti-explanations-418d26816b44',
    #thumbnail='https://github.com/MonsoonNLP/gradio-gptnyc',
    analytics_enabled=True)
io.launch(debug=True)