# Spaces: Runtime error
from transformers import AutoTokenizer, AutoModel | |
import pandas as pd | |
import numpy as np | |
import random | |
import torch | |
import os | |
import gradio as gr | |
# Model
# The access token is read from the environment; it is None when the variable is unset.
auth_token = os.environ.get("TOKEN_FROM_SECRET")
checkpoint = 'srota/job-bert-mini'
model = AutoModel.from_pretrained(checkpoint, token=auth_token)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, token=auth_token)

# Data
# Parse the CSV once and take both columns from the same frame instead of
# reading the file a second time for the second column.
_inventory = pd.read_csv('inventory.csv', usecols=['title', 'description'])
titles = _inventory['title'].tolist()
descriptions = _inventory['description'].tolist()
# NOTE(review): presumably one embedding row per inventory row, in CSV order — confirm.
with open('inventory.npy', 'rb') as f:
    embeddings = np.load(f)
# Inference | |
def inference(query, top_k=5):
    """Return the inventory job titles that best match ``query``.

    The query is embedded with the module-level ``tokenizer`` and ``model``
    (first-token vector of the last hidden state) and scored against every row
    of the module-level ``embeddings`` matrix with a dot product.

    Args:
        query: Free-text search string.
        top_k: Maximum number of titles to return. Values below 1 produce an
            empty result instead of (accidentally) the whole inventory.

    Returns:
        The selected titles, best match first, each prefixed with ``*`` and
        separated by a blank line.
    """
    top_k = int(top_k)
    if top_k < 1:
        # A slice of [-0:] would select every row, so handle this explicitly.
        return ''

    with torch.no_grad():
        inputs = tokenizer([query], padding=True, truncation=True, max_length=512, return_tensors='pt')
        query_embedding = model(**inputs)['last_hidden_state'][:, 0, :].detach().numpy()

    # NOTE(review): a raw dot product equals cosine similarity only if both the
    # query vector and the stored embeddings are L2-normalised — confirm.
    scores = np.dot(query_embedding, embeddings.T)[0]

    # argsort is ascending: reverse it so the highest score comes first, then
    # keep the ranking by indexing titles directly rather than scanning them.
    ranked = np.argsort(scores)[::-1][:top_k]
    return '\n\n'.join('*' + titles[i] for i in ranked)
# Gradio
# Pool of sample queries; a random subset of ten is offered in the UI.
# Entries are unique so the sampled subset never shows the same query twice.
examples = [
    ['Data Scientist'],
    ['Warehouse Worker'],
    ['Gardener'],
    ['Part-Time Cleaner'],
    ['Math Teacher'],
    ['Registered Nurse'],
    ['Line Cook'],
    ['Night Porter'],
    ['Dietitian'],
    ['Planned Surveyor'],
    ['Driving Instructor'],
    ['Senior It Engineer'],
    ['Stores Person'],
    ['Dental Hygienist'],
    ['Event Manager'],
    ['Welder'],
    ['Underwriter'],
    ['Frontend Developer'],
    ['Paralegal'],
    ['Copywriter'],
    ['Community Nurse'],
    ['Courier'],
    ['Personal Trainer'],
    ['Pharmacist'],
    ['Carpenter'],
]
demo = gr.Interface(
    fn=inference,
    title='Job Search',
    description='Simulate a semantic search for retrieving job titles that match the user query (the match is performed between the user query and 15K job descriptions)',
    inputs=gr.Textbox(lines=1, placeholder='', label="User keyword"),
    outputs=gr.Textbox(lines=10, label="Relevant jobs"),
    examples=random.sample(examples, 10),
)
demo.launch()