CesarLeblanc commited on
Commit
6f59e3c
1 Parent(s): 6b0aebb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -14
app.py CHANGED
@@ -4,16 +4,11 @@ from datasets import load_dataset
4
  import requests
5
  from bs4 import BeautifulSoup
6
 
7
- def return_model(task):
8
- if task == 'classification':
9
- model = pipeline("text-classification", model="CesarLeblanc/test_model")
10
- else:
11
- model = pipeline("fill-mask", model="CesarLeblanc/fill_mask_model")
12
- return model
13
 
14
- def return_dataset():
15
- dataset = load_dataset("CesarLeblanc/text_classification_dataset")
16
- return dataset
 
17
 
18
  def return_text(habitat_label, habitat_score, confidence):
19
  if habitat_score*100 > confidence:
@@ -58,9 +53,7 @@ def return_species_image(species):
58
  return image
59
 
60
  def classification(text, typology, confidence):
61
- model = return_model("classification")
62
- dataset = return_dataset()
63
- result = model(text)
64
  habitat_label = result[0]['label']
65
  habitat_label = dataset['train'].features['label'].names[int(habitat_label.split('_')[1])]
66
  habitat_score = result[0]['score']
@@ -69,9 +62,8 @@ def classification(text, typology, confidence):
69
  return formatted_output, image_output
70
 
71
  def masking(text):
72
- model = return_model("masking")
73
  masked_text = text + ', [MASK] [MASK]'
74
- pred = model(masked_text, top_k=1)
75
  new_species = [pred[i][0]['token_str'] for i in range(len(pred))]
76
  new_species = ' '.join(new_species)
77
  text = f"The last species from this vegetation plot is probably {new_species}."
 
4
  import requests
5
  from bs4 import BeautifulSoup
6
 
 
 
 
 
 
 
7
 
8
+ classification_model = pipeline("text-classification", model="CesarLeblanc/test_model")
9
+ mask_model = pipeline("fill-mask", model="CesarLeblanc/fill_mask_model")
10
+
11
+ dataset = load_dataset("CesarLeblanc/text_classification_dataset")
12
 
13
  def return_text(habitat_label, habitat_score, confidence):
14
  if habitat_score*100 > confidence:
 
53
  return image
54
 
55
  def classification(text, typology, confidence):
56
+ result = classification_model(text)
 
 
57
  habitat_label = result[0]['label']
58
  habitat_label = dataset['train'].features['label'].names[int(habitat_label.split('_')[1])]
59
  habitat_score = result[0]['score']
 
62
  return formatted_output, image_output
63
 
64
  def masking(text):
 
65
  masked_text = text + ', [MASK] [MASK]'
66
+ pred = fill_model(masked_text, top_k=1)
67
  new_species = [pred[i][0]['token_str'] for i in range(len(pred))]
68
  new_species = ' '.join(new_species)
69
  text = f"The last species from this vegetation plot is probably {new_species}."