lrschuman17 commited on
Commit
9b1fd3e
·
verified ·
1 Parent(s): 0117ed0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -27
app.py CHANGED
@@ -1,35 +1,60 @@
1
- import requests
2
  import pandas as pd
 
3
 
4
- # Read your data
5
- data = pd.read_csv('/path/to/torn.csv')
6
- notes = data['Notes'].tolist() # Ensure 'Notes' is a column in your CSV
7
 
8
- # Define your candidate labels for classification
9
  candidate_labels = ["ACL Tear", "Meniscus Tear", "Achilles Tear", "Fracture", "Hamstring", "Foot", "Shoulder", "Hip", "Calf", "Hand", "Wrist"]
10
 
11
- # Hugging Face API details
12
- API_URL = "https://api-inference.huggingface.co/models/facebook/bart-large-mnli"
13
- headers = {"Authorization": f"Bearer YOUR_HUGGING_FACE_API_KEY"} # Replace with your actual API key
 
 
 
14
 
15
- def classify_with_api(note):
16
- # Prepare payload
17
- payload = {
18
- "inputs": note,
19
- "parameters": {"candidate_labels": candidate_labels},
20
- }
21
- # Make request to API
22
- response = requests.post(API_URL, headers=headers, json=payload)
23
- if response.status_code == 200:
24
- return response.json()
25
- else:
26
- print("Error:", response.status_code, response.text)
27
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- # Classify each note and store results
30
- data['Classifications'] = data['Notes'].apply(lambda x: classify_with_api(x))
31
- data['Top Classification'] = data['Classifications'].apply(lambda x: x['labels'][0] if x else None)
32
- data['Top Score'] = data['Classifications'].apply(lambda x: x['scores'][0] if x else None)
 
 
 
 
33
 
34
- # Display results
35
- print(data[['Notes', 'Top Classification', 'Top Score']])
 
1
+ import gradio as gr
2
  import pandas as pd
3
+ from transformers import pipeline
4
 
5
+ # Initialize the zero-shot classifier
6
+ classifier = pipeline('zero-shot-classification', model='distilbert-base-uncased')
 
7
 
8
+ # Define the possible categories (more granular categories)
9
  candidate_labels = ["ACL Tear", "Meniscus Tear", "Achilles Tear", "Fracture", "Hamstring", "Foot", "Shoulder", "Hip", "Calf", "Hand", "Wrist"]
10
 
11
+ def classify_injuries(file):
12
+ # Load the uploaded CSV file
13
+ df = pd.read_csv(file.name)
14
+
15
+ # Limit to a sample (e.g., first 100 rows) if necessary for performance
16
+ new_df = df.head(100).copy()
17
 
18
+ # Apply zero-shot classification to each note in the 'Notes' column
19
+ classifications = classifier(new_df['Notes'].tolist(), candidate_labels)
20
+
21
+ # Add the classification results to the DataFrame
22
+ new_df['Classifications'] = classifications
23
+ new_df['Top Classification'] = new_df['Classifications'].apply(lambda x: x['labels'][0] if isinstance(x, dict) else None)
24
+ new_df['Top Score'] = new_df['Classifications'].apply(lambda x: x['scores'][0] if isinstance(x, dict) else None)
25
+
26
+ # Initialize the 'Specific Injury' column with default value
27
+ new_df['Specific Injury'] = None
28
+
29
+ # Define a function to determine the specific injury based on keywords
30
+ def extract_specific_injury(note, injury):
31
+ note = note.lower()
32
+ if "left" in note:
33
+ return f"left {injury.lower()} injury"
34
+ elif "right" in note:
35
+ return f"right {injury.lower()} injury"
36
+ else:
37
+ return f"{injury.lower()} injury"
38
+
39
+ # Apply specific injury classification based on keywords
40
+ for injury in candidate_labels:
41
+ new_df.loc[new_df['Top Classification'].str.contains(injury, case=False, na=False), 'Specific Injury'] = \
42
+ new_df['Notes'].apply(lambda x: extract_specific_injury(x, injury) if injury.lower() in x.lower() else None)
43
+
44
+ # Sort by 'Top Score' in descending order
45
+ new_df_sorted = new_df.sort_values(by='Top Score', ascending=False)
46
+
47
+ # Return a subset of columns for clarity
48
+ return new_df_sorted[['Notes', 'Top Classification', 'Top Score', 'Specific Injury']]
49
 
50
+ # Set up the Gradio interface
51
+ iface = gr.Interface(
52
+ fn=classify_injuries,
53
+ inputs=gr.File(label="Upload CSV File (must have a 'Notes' column)"),
54
+ outputs="dataframe",
55
+ title="Injury Classification App",
56
+ description="Upload a CSV file with injury notes. The app classifies each note based on specified injury types and provides specific classifications where possible."
57
+ )
58
 
59
+ # Launch the Gradio app
60
+ iface.launch()