nebiyu29 committed on
Commit
16a1d49
1 Parent(s): 71d31ee

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+ import gradio as gr
3
+
4
+ # Load model directly
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
+
7
+ import torch
8
+ import transformers
9
+
10
# Hub id of the fine-tuned checkpoint; the tokenizer and the classifier
# weights are both loaded from this same repository.
MODEL_ID = "nebiyu29/fintunned-v2-roberta_GA"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
12
+
13
+
14
+ # Load the model and tokenizer
15
+ # model = transformers.AutoModelForSequenceClassification.from_pretrained("facebook/bart-large-mnli")
16
+
17
+ # tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
18
+
19
+ # Define a function to split a text into segments of 512 tokens
20
def split_text(text, max_tokens=512):
    """Split *text* into consecutive segments of at most *max_tokens* tokens.

    The text is tokenized with the module-level ``tokenizer``, the token list
    is cut into fixed-size windows, and every window is converted back into a
    string.

    Args:
        text: Raw input text. ``None`` or an empty string yields no segments.
        max_tokens: Maximum number of tokens per segment (default 512).

    Returns:
        A list of segment strings in their original order; empty when the
        text produces no tokens.

    Raises:
        ValueError: If *max_tokens* is not positive.
    """
    if max_tokens <= 0:
        raise ValueError("max_tokens must be a positive integer")
    if not text:
        return []

    tokens = tokenizer.tokenize(text)

    # Cut by position. The previous end-of-text test compared each token's
    # value with tokens[-1], so any earlier token equal to the final one
    # flushed a segment prematurely.
    # NOTE(review): special tokens added when a segment is re-encoded are not
    # counted here; callers feeding segments to a model should truncate.
    return [
        tokenizer.convert_tokens_to_string(tokens[start:start + max_tokens])
        for start in range(0, len(tokens), max_tokens)
    ]
44
+
45
+ # a function that classifies text
46
+
47
def classify_text(text):
    """Classify *text* segment by segment and return a readable summary.

    The text is split with ``split_text``; every segment is encoded with the
    module-level ``tokenizer`` and run through the module-level ``model``, and
    the top class is taken from ``extract_predictions``.

    Args:
        text: Raw input text from the UI.

    Returns:
        A string with one line per segment of the form
        ``"Segment <i>: <label> (<probability>)"``, or a short message when
        the input contains nothing to classify.
    """
    segments = split_text(text)
    if not segments:
        return "No text to classify."

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Module.to() moves the parameters in place. Do not rebind ``model`` here:
    # assigning to it would make it a local name and every call would fail
    # with UnboundLocalError before the model is ever used.
    model.to(device)

    # NOTE(review): the hard-coded label list that used to live here was never
    # read. Names now come from the checkpoint's own config - confirm that
    # config.id2label carries the intended class names.
    id2label = model.config.id2label

    lines = []
    for position, segment in enumerate(segments, start=1):
        # Truncate because re-encoding adds special tokens, which can push a
        # full-size segment past the model's maximum input length.
        inputs = tokenizer(
            [segment],
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)

        probs, preds = extract_predictions(outputs)

        # The batch holds a single segment, so row 0 is the only row. Index
        # the class axis explicitly instead of indexing the batch axis with
        # the class id.
        label_id = int(preds[0])
        probability = float(probs[0, label_id])
        label = id2label.get(label_id, str(label_id))

        lines.append(f"Segment {position}: {label} ({probability:.1%})")

    return "\n".join(lines)
79
+
80
+
81
+
82
+ # Define a function to extract predictions from model output (adjust as needed)
83
def extract_predictions(outputs):
    """Turn a model output into class probabilities and predicted class ids.

    Args:
        outputs: Object exposing a ``logits`` tensor; the softmax and the
            argmax are both taken along dimension 1.

    Returns:
        A ``(probs, preds)`` tuple: the softmax of the logits, and the index
        of the largest probability along dimension 1.
    """
    class_probs = torch.softmax(outputs.logits, dim=1)
    return class_probs, class_probs.argmax(dim=1)
89
+
90
+
91
+
92
+ # def classify_text(text):
93
+ # """
94
+ # This function preprocesses, feeds text to the model, and outputs the predicted class.
95
+ # """
96
+ # inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
97
+ # outputs = model(**inputs)
98
+ # logits = outputs.logits # Access logits instead of pipeline output
99
+ # predictions = torch.argmax(logits, dim=-1) # Apply argmax for prediction
100
+ # return model.config.id2label[predictions.item()] # Map index to class label
101
+
102
# Gradio UI: a single text box in, a single text box out, backed by
# classify_text. The first three positional arguments are fn, inputs, outputs.
interface = gr.Interface(
    classify_text,
    "text",
    "text",
    title="Text Classification Demo",
    description="Enter some text, and the model will classify it.",
    # choices=["depression", "anxiety", "bipolar disorder", "schizophrenia", "PTSD", "OCD", "ADHD", "autism", "eating disorder", "personality disorder", "phobia"]  # Adjust class names
)
110
+
111
+ #interface.launch()