File size: 686 Bytes
6bf8f89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from transformers import pipeline

# Broad, top-level subject areas. Presumably intended as candidate labels for
# the zero-shot classifier below — NOTE(review): not referenced in the code
# visible here; confirm how callers use it.
umbrellaSubjects = [
    "Science, Technology, Engineering, Mathematics",
    "Philosophy",
    "Arts and Humanities",
    "Social Sciences",
    "Languages",
    "Professional Studies",
]

# Shared zero-shot classification pipeline backed by the
# "vicgalle/xlm-roberta-large-xnli-anli" checkpoint.
# NOTE: this is constructed at import time, so importing the module loads
# the model.
classifier = pipeline(
    task="zero-shot-classification",
    model="vicgalle/xlm-roberta-large-xnli-anli",
)

def process_string(string):
    # Remove whitespace from string
    string = string.replace(' ', '')
    # Split string by comma
    string_list = string.split(',')
    # Return list with no whitespace
    return [term.strip() for term in string_list]
def classify(sentences, categories, **pipeline_kwargs):
    """Run zero-shot classification over comma-separated sentences and labels.

    Both inputs are parsed with ``process_string`` and handed to the
    module-level ``classifier`` pipeline.

    NOTE: because parsing splits on commas, a sentence that itself contains
    a comma is treated as several separate sequences.

    Args:
        sentences: Comma-separated text(s) to classify.
        categories: Comma-separated candidate labels.
        **pipeline_kwargs: Extra keyword arguments forwarded unchanged to
            ``classifier`` (for example ``multi_label=True`` or a custom
            ``hypothesis_template``). Omitting them reproduces the original
            call exactly.

    Returns:
        Whatever ``classifier`` returns for the parsed inputs.
        NOTE(review): per the transformers docs this is presumably a dict
        with ``sequence``/``labels``/``scores`` per sentence. Confirm
        against the installed version.
    """
    sequences = process_string(sentences)
    candidate_labels = process_string(categories)
    return classifier(sequences, candidate_labels, **pipeline_kwargs)