noequal committed
Commit c3da8c6 · 1 Parent(s): 8f7110d

Update app.py

Files changed (1)
  1. app.py +68 -40
app.py CHANGED
@@ -1,50 +1,78 @@
- # Import the necessary libraries
  import streamlit as st
- from transformers import GPT2Tokenizer, GPT2LMHeadModel, pipeline
  import torch

- # Load the gpt2-large model and tokenizer for text generation
- gen_model = GPT2LMHeadModel.from_pretrained('gpt2-large')
- gen_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')

- # Load the zero-shot text classification pipeline from HuggingFace
- classifier = pipeline('zero-shot-classification')

- # Define a function that takes a text as input and returns a list of labels as output
- def generate_labels(text):
-     # Append the special token [LABEL] to the text
-     text = text + ' [LABEL]'
-     # Convert the text to input ids and attention mask
-     input_ids = gen_tokenizer.encode(text, return_tensors='pt')
-     attention_mask = torch.ones_like(input_ids)
-     # Generate up to 5 labels from the model
-     outputs = gen_model.generate(input_ids, attention_mask=attention_mask, max_length=len(input_ids)+5, do_sample=True, top_p=0.95)
-     # Decode the generated text
-     generated = gen_tokenizer.decode(outputs[0], skip_special_tokens=False)
-     # Split the generated text by commas
-     labels = generated.split(',')
-     # Remove the special token and any whitespace from the labels
-     labels = [label.replace('[LABEL]', '').strip() for label in labels]
-     # Filter out any empty or duplicate labels
-     labels = list(dict.fromkeys(filter(None, labels)))
-     # Return the labels as a list
-     return labels

  # Create a title and a text input for the app
  st.title('Thematic Analysis with GPT-2 Large')
- text = st.text_input('Enter some text to classify')

- # If the text is not empty, generate labels and classify the text
  if text:
-     # Generate labels from the text
-     labels = generate_labels(text)
-     # Display the generated labels
-     st.write(f'The generated labels are: {", ".join(labels)}')
-     # Classify the text using the generated labels
-     result = classifier(text, labels)
-     # Get the label and the score with the highest probability
-     label = result['labels'][0]
-     score = result['scores'][0]
-     # Display the label and the score
-     st.write(f'The predicted label is: {label}')
-     st.write(f'The probability is: {score:.4f}')

+ # Import necessary libraries
  import streamlit as st
  import torch
+ import pandas as pd
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2ForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding, DataCollatorForLanguageModeling

+ # Step 1: Set Up Your Environment
+ # Environment setup and package installations.

+ # Step 2: Data Preparation
+ # Load and preprocess your CSV dataset.
+ df = pd.read_csv('stepkids_training_data.csv')

+ # Filter out rows with missing label data
+ df = df.dropna(subset=['Theme 1', 'Theme 2', 'Theme 3', 'Theme 4', 'Theme 5'])
+
+ text_list = df['Post Text'].tolist()
+ labels = df[['Theme 1', 'Theme 2', 'Theme 3', 'Theme 4', 'Theme 5']].values.tolist()
+
+ # Step 3: Model Selection
+ # Load your GPT-2 model for text generation.
+ model_name = "gpt2"  # Choose the appropriate GPT-2 model variant
+ text_gen_model = GPT2LMHeadModel.from_pretrained(model_name)
+ text_gen_tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+ text_gen_tokenizer.pad_token = text_gen_tokenizer.eos_token
+
+ # Load your fine-tuned GPT-2 sequence classification model
+ seq_classifier_model = GPT2ForSequenceClassification.from_pretrained("fine_tuned_classifier_model")
+ seq_classifier_tokenizer = GPT2Tokenizer.from_pretrained("fine_tuned_classifier_model")
+ seq_classifier_tokenizer.pad_token = seq_classifier_tokenizer.eos_token
+
+ # Function for generating text based on the input
+ # (defined before the UI code so it exists when the script body calls it)
+ def generate_text(input_text, model, tokenizer):
+     # Append the special token to the input
+     input_text = input_text + ' [LABEL]'
+     input_ids = tokenizer.encode(input_text, return_tensors='pt')
+     attention_mask = torch.ones_like(input_ids)
+     outputs = model.generate(input_ids, attention_mask=attention_mask, max_length=input_ids.shape[1] + 5, do_sample=True, top_p=0.95, pad_token_id=tokenizer.eos_token_id)
+     generated = tokenizer.decode(outputs[0], skip_special_tokens=False)
+     return generated
+
+ # Function for sequence classification
+ def classify_text(input_text, model, tokenizer):
+     # Tokenize the input text
+     input_ids = tokenizer.encode(input_text, return_tensors='pt')
+     attention_mask = torch.ones_like(input_ids)
+     # Perform sequence classification
+     result = model(input_ids, attention_mask=attention_mask)
+     # Post-process the results (e.g., select labels based on a threshold)
+     labels = post_process_labels(result)
+     return labels
+
+ # Post-process labels based on a threshold or confidence score
+ def post_process_labels(results):
+     # Implement your logic to extract and filter labels based on your
+     # sequence classification model's output. For example, you might use a
+     # threshold for each label's score to decide whether it counts as a
+     # valid theme. Return the selected labels as a list.
+     selected_labels = []
+     return selected_labels

  # Create a title and a text input for the app
  st.title('Thematic Analysis with GPT-2 Large')
+ text = st.text_area('Enter some text')

+ # If the text is not empty, perform both text generation and sequence classification
  if text:
+     # Perform text generation
+     generated_text = generate_text(text, text_gen_model, text_gen_tokenizer)
+     st.write('Generated Text:')
+     st.write(generated_text)
+
+     # Perform sequence classification
+     labels = classify_text(text, seq_classifier_model, seq_classifier_tokenizer)
+     st.write('Classified Labels:')
+     st.write(labels)
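
In the new version of app.py, post_process_labels is left as a stub. A minimal sketch of the thresholding it describes, assuming (the commit does not say this) that fine_tuned_classifier_model was trained as a multi-label classifier whose config.id2label maps each output index to a theme name; the 0.5 threshold is likewise an illustrative choice, not something from the commit:

import torch

def post_process_labels(results, threshold=0.5):
    # Multi-label style post-processing: sigmoid over the logits, then keep
    # every label whose probability clears the threshold.
    probs = torch.sigmoid(results.logits)[0]
    id2label = seq_classifier_model.config.id2label  # assumed to map index -> theme name
    selected_labels = [id2label[i] for i, p in enumerate(probs.tolist()) if p >= threshold]
    return selected_labels

If the classifier was instead trained single-label, a softmax plus argmax over results.logits would be the analogous post-processing step.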
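The commit also imports Trainer, TrainingArguments, and DataCollatorWithPadding and loads stepkids_training_data.csv without using them, which suggests the fine_tuned_classifier_model checkpoint is produced in a separate fine-tuning step. A hedged sketch of how that step could use the already-loaded text_list and labels, assuming the five Theme columns are 0/1 indicators and the task is multi-label classification; ThemeDataset and the hyperparameters are illustrative only, none of this is confirmed by app.py:

from torch.utils.data import Dataset
from transformers import (GPT2ForSequenceClassification, GPT2Tokenizer,
                          Trainer, TrainingArguments, DataCollatorWithPadding)

class ThemeDataset(Dataset):
    # Wraps the tokenized posts and their five 0/1 theme indicators.
    def __init__(self, texts, theme_labels, tokenizer):
        self.encodings = tokenizer(texts, truncation=True, max_length=256)
        self.theme_labels = theme_labels

    def __len__(self):
        return len(self.theme_labels)

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item["labels"] = [float(v) for v in self.theme_labels[idx]]  # floats for BCE loss
        return item

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = GPT2ForSequenceClassification.from_pretrained(
    "gpt2", num_labels=5, problem_type="multi_label_classification")
model.config.pad_token_id = tokenizer.pad_token_id

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="fine_tuned_classifier_model",
                           num_train_epochs=3, per_device_train_batch_size=4),
    train_dataset=ThemeDataset(text_list, labels, tokenizer),
    data_collator=DataCollatorWithPadding(tokenizer),
)
trainer.train()
trainer.save_model("fine_tuned_classifier_model")
tokenizer.save_pretrained("fine_tuned_classifier_model")

Saving the tokenizer alongside the model is what lets GPT2Tokenizer.from_pretrained("fine_tuned_classifier_model") in app.py resolve the same directory.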