Sheng Lei commited on
Commit
2a6f3b3
1 Parent(s): 51c5d35

more fixes

Browse files
Files changed (3) hide show
  1. app.py +10 -5
  2. restrictedItems/predict.py +2 -1
  3. restrictedItems/train.py +2 -1
app.py CHANGED
@@ -15,10 +15,15 @@ model = BertForSequenceClassification.from_pretrained("sleiyer/restricted_item_d
15
  # Load the trained model and tokenizer
16
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
17
 
 
 
 
 
 
18
  # Function to predict the class of a single input text
19
- def predict(text: str):
20
  # Preprocess the input text
21
- inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
22
 
23
  # Make predictions
24
  with torch.no_grad():
@@ -34,10 +39,10 @@ def predict(text: str):
34
  predicted_label = label_map[predicted_class]
35
 
36
  # Displaying the user input
37
- return f'The item "{text}" is classified as: "{predicted_label}"'
38
 
39
  return predicted_class
40
 
41
  @app.post("/predict")
42
- def predict(input):
43
- return predict(input)
 
15
  # Load the trained model and tokenizer
16
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
17
 
18
+ from pydantic import BaseModel
19
+
20
+ class Predict(BaseModel):
21
+ input: str
22
+
23
  # Function to predict the class of a single input text
24
+ def predict(request: Predict):
25
  # Preprocess the input text
26
+ inputs = tokenizer(request.input, return_tensors='pt', truncation=True, padding=True)
27
 
28
  # Make predictions
29
  with torch.no_grad():
 
39
  predicted_label = label_map[predicted_class]
40
 
41
  # Displaying the user input
42
+ return f'The item "{request.input}" is classified as: "{predicted_label}"'
43
 
44
  return predicted_class
45
 
46
  @app.post("/predict")
47
+ def predictApi(request: Predict):
48
+ return predict(request)
restrictedItems/predict.py CHANGED
@@ -2,7 +2,8 @@ from transformers import BertTokenizer, BertForSequenceClassification
2
  import torch
3
 
4
  # Load the trained model and tokenizer
5
- model = BertForSequenceClassification.from_pretrained("/Users/slei/hackweek2024-sup-genai-tools/hackweek2024-sup-genai-tools/trained_model")
 
6
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
7
 
8
  # Function to predict the class of a single input text
 
2
  import torch
3
 
4
  # Load the trained model and tokenizer
5
+ # model = BertForSequenceClassification.from_pretrained("/Users/slei/hackweek2024-sup-genai-tools/spaces/restricted_item_detector/trained_model")
6
+ model = BertForSequenceClassification.from_pretrained("sleiyer/restricted_item_detector")
7
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
8
 
9
  # Function to predict the class of a single input text
restrictedItems/train.py CHANGED
@@ -100,7 +100,7 @@ train_dataset = ShoppingCartDataset(train_encodings, train_labels)
100
  val_dataset = ShoppingCartDataset(val_encodings, val_labels)
101
 
102
  # Load pre-trained BERT model
103
- model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
104
 
105
  # Training arguments
106
  training_args = TrainingArguments(
@@ -128,4 +128,5 @@ trainer.train()
128
  # Evaluate model
129
  trainer.evaluate()
130
 
 
131
  model.push_to_hub("sleiyer/restricted_item_detector")
 
100
  val_dataset = ShoppingCartDataset(val_encodings, val_labels)
101
 
102
  # Load pre-trained BERT model
103
+ model = BertForSequenceClassification.from_pretrained('sleiyer/restricted_item_detector')
104
 
105
  # Training arguments
106
  training_args = TrainingArguments(
 
128
  # Evaluate model
129
  trainer.evaluate()
130
 
131
+ model.save_pretrained('trained_model')
132
  model.push_to_hub("sleiyer/restricted_item_detector")