srota commited on
Commit
727155f
1 Parent(s): 70748bc

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModel
2
+ import pandas as pd
3
+ import numpy as np
4
+ import random
5
+ import torch
6
+ import os
7
+ import gradio as gr
8
+
9
+ # Model
10
+ auth_token = os.environ.get("TOKEN_FROM_SECRET")
11
+ checkpoint = 'srota/job-bert-mini'
12
+ model = AutoModel.from_pretrained(checkpoint, token=auth_token)
13
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint, token=auth_token)
14
+
15
+ # Data
16
+ titles = pd.read_csv('inventory.csv', usecols=['title'])['title'].tolist()
17
+ descriptions = pd.read_csv('inventory.csv', usecols=['description'])['description'].tolist()
18
+ with open('inventory.npy', 'rb') as f:
19
+ embeddings = np.load(f)
20
+
21
+ # Inference
22
+ def inference(query, top_k=5):
23
+ with torch.no_grad():
24
+ inputs = tokenizer([query], padding=True, truncation=True, max_length=512, return_tensors='pt')
25
+ query_embedding = model(**inputs)['last_hidden_state'][:,0,:].detach().numpy()
26
+ cosines = np.dot(query_embedding, embeddings.T)[0]
27
+ indexes = np.argsort(cosines)[-top_k:]
28
+ return '\n\n'.join(['*' + t for i, t in enumerate(titles) if i in indexes])
29
+
30
+ # Gradio
31
+ examples = [['Data Scientist'], ['Warehouse Worker'], ['Gardener'], ['Part-Time Cleaner'], ['Math Teacher'], ['Registered Nurse'], ['Line Cook'],['Night Porter'],['Dietitian'],['Planned Surveyor'],['Driving Instructor'],['Senior It Engineer'],['Stores Person'],['Dental Hygienist'],['Event Manager'],['Welder'],['Underwriter'],['Frontend Developer'],['Paralegal'],['Copywriter'],['Community Nurse'],['Courier'],['Personal Trainer'],['Night Porter'],['Pharmacist'],['Carpenter']]
32
+
33
+ demo = gr.Interface(
34
+ fn=inference,
35
+ title='Job Search',
36
+ description='Simulate a semantic search for retrieving job titles that match the user query (the match is performed between the user query and 15K job descriptions)',
37
+ inputs=gr.Textbox(lines=1, placeholder='', label="User keyword"),
38
+ outputs=gr.Textbox(lines=10, label="Relevant jobs"),
39
+ examples=random.sample(examples, 10)
40
+ )
41
+
42
+ demo.launch()