OttoYu committed on
Commit
8cd3b31
1 Parent(s): 107d2ee

Create run.py

Files changed (1)
  1. run.py +47 -0
run.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
+
+ # Load the tokenizer and model from the Hugging Face Hub
+ model_id = "OttoYu/Tree-Dbh"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForSequenceClassification.from_pretrained(model_id)
+
+ # Set up the inference pipeline (GPU if available, otherwise CPU)
+ text_classification = pipeline(
+     "text-classification",
+     model=model,
+     tokenizer=tokenizer,
+     device=0 if torch.cuda.is_available() else -1,
+     return_all_scores=True,
+ )
+
+ # Predict tree height and crown spread for a given dbh (diameter at breast height)
+ def predict_tree_properties(dbh):
+     # Prepare the input text
+     input_text = f"dbh: {dbh}"
+
+     # Get the predicted score for each class; with return_all_scores=True the
+     # pipeline returns one {"label", "score"} dict per class for each input
+     results = text_classification(input_text)
+     probs = [result["score"] for result in results[0]]
+
+     # Convert the probabilities to tree height and crown spread
+     tree_height = probs[0] * 100  # scale the probability to 0-100
+     crown_spread = probs[1] * 10  # scale the probability to 0-10
+
+     # Return the predicted tree properties
+     return {"tree_height": tree_height, "crown_spread": crown_spread}
+
+ # Get user input and display the predicted tree properties
+ def run_inference():
+     # Get user input for dbh
+     dbh = input("Enter the dbh value (in cm): ")
+
+     # Make the prediction and display the results
+     tree_properties = predict_tree_properties(dbh)
+     print(f"Predicted Tree Height: {tree_properties['tree_height']:.2f} m")
+     print(f"Predicted Crown Spread: {tree_properties['crown_spread']:.2f} m")
+
+ # Run the inference when the script is executed directly
+ if __name__ == "__main__":
+     run_inference()
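
For reference, a minimal sketch of driving the new module programmatically instead of through the interactive prompt. This driver is not part of the commit: it assumes run.py is importable from the working directory, and it relies on the `if __name__ == "__main__":` guard so that importing run.py does not trigger the prompt.

# Hypothetical driver script (not included in the commit): import the helper
# from run.py and print predictions for a few example dbh values (in cm).
from run import predict_tree_properties

for dbh in (10, 25, 40):
    props = predict_tree_properties(dbh)
    print(f"dbh={dbh} cm -> height {props['tree_height']:.2f} m, "
          f"crown spread {props['crown_spread']:.2f} m")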