Spaces:
Build error
Build error
Upload memory_support_chatbot_for_pregnant_women.py
Browse files
memory_support_chatbot_for_pregnant_women.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
import pandas as pd # Pandas for data manipulation and analysis
|
| 4 |
+
import nltk # NLTK for natural language processing tasks
|
| 5 |
+
from nltk.corpus import stopwords # Stopwords from NLTK
|
| 6 |
+
from nltk.tokenize import word_tokenize # Word tokenizer from NLTK
|
| 7 |
+
import streamlit as st # Streamlit for creating interactive web apps
|
| 8 |
+
import matplotlib.pyplot as plt # Matplotlib for data visualization
|
| 9 |
+
from wordcloud import WordCloud # Wordcloud for generating word clouds
|
| 10 |
+
from transformers import GPT2Tokenizer, GPT2LMHeadModel # GPT-2 model from Transformers
|
| 11 |
+
import torch # PyTorch for deep learning tasks
|
| 12 |
+
from torch.utils.data import DataLoader, Dataset,TensorDataset # DataLoader and Dataset for handling data
|
| 13 |
+
from transformers import GPT2Config, GPT2LMHeadModel, AdamW # AdamW optimizer for GPT-2 model
|
| 14 |
+
from transformers import AdamW, get_scheduler # Scheduler for optimizer
|
| 15 |
+
from torch.nn.utils.rnn import pad_sequence # Padding sequences for model input
|
| 16 |
+
from nltk.sentiment import SentimentIntensityAnalyzer # Sentiment analysis from NLTK
|
| 17 |
+
from sklearn.feature_extraction.text import TfidfVectorizer # TF-IDF vectorizer
|
| 18 |
+
from sklearn.decomposition import LatentDirichletAllocation # LDA for topic modeling
|
| 19 |
+
# One-time setup for VADER sentiment scoring used later in the script.
nltk.download('vader_lexicon')  # Download the VADER lexicon for sentiment analysis
sia = SentimentIntensityAnalyzer()  # Initialize the SentimentIntensityAnalyzer


# NOTE: nltk was already imported at the top of the file; this re-import is
# redundant but harmless.
import nltk
|
| 24 |
+
|
| 25 |
+
def setup_nltk():
    """Ensure the NLTK resources this script needs are installed locally.

    Checks each resource individually and downloads only the missing ones.
    The original version downloaded all three packages whenever any single
    one was missing, because one shared try/except covered every find().
    """
    # Map each resource lookup path to its downloadable package name.
    required = {
        'tokenizers/punkt': 'punkt',
        'corpora/stopwords': 'stopwords',
        'sentiment/vader_lexicon': 'vader_lexicon',
    }
    for resource_path, package in required.items():
        try:
            nltk.data.find(resource_path)
        except LookupError:
            # Only fetch the specific package that is actually absent.
            nltk.download(package)

setup_nltk()
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# NOTE: torch and the GPT-2 classes are also imported at the top of the file;
# these re-imports are redundant but harmless.
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel


# Load the pretrained GPT-2 tokenizer and causal language model.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')  # GPT-2 tokenizer
model = GPT2LMHeadModel.from_pretrained('gpt2')  # GPT-2 model
# Keep the embedding matrix sized to the tokenizer vocabulary (a no-op here,
# but required if special tokens are ever added to the tokenizer).
model.resize_token_embeddings(len(tokenizer))  # Resize token embeddings
|
| 47 |
+
|
| 48 |
+
# Source corpus: local text files scraped/saved from articles and papers on
# pregnancy-related memory and cognition. Paths are relative to the working
# directory; filenames are preserved verbatim (including unicode quotes).
txt_files = [
    "Cognition in Pregnancy- Perceptions and Performance, 2005-2006 - Dataset - B2FIND.txt",
    "Frontiers | Cognitive disorder and associated factors among pregnant women attending antenatal servi.txt",
    "Frustrated By Brain Fog? How Pregnancy Actually Alters Yo....txt",
    "Is Pregnancy Brain Real?.txt",
    "Is ‘pregnancy brain’ real or just a myth? | Your Pregnancy Matters | UT Southwestern Medical Center.txt",
    "Memory and affective changes during the antepartum- A narrative review and integrative hypothesis- J.txt",
    "Pregnancy 'does cause memory loss' | Medical research | The Guardian.txt",
    "Pregnancy Brain — Forgetfulness During Pregnancy.txt",
    "Pregnancy brain- When it starts and what causes pregnancy brain fog | BabyCenter.txt",
    "Pregnancy does cause memory loss, study says - CNN.txt",
    "Textbook J.A. Russell, A.J. Douglas, R.J. Windle, C.D. Ingram - The Maternal Brain_ Neurobiological and Neuroendocrine Adaptation and Disorders in Pregnancy & Post Partum-Elsevier Science (2001).txt",
    "The effect of pregnancy on maternal cognition - PMC.txt",
    "This Is Your Brain on Motherhood - The New York Times.txt",
    "Working memory from pregnancy to postpartum.txt",
    "What Is Mom Brain and Is It Real?.txt",
    "Memory loss in Pregnancy- Myth or Fact? - International Forum for Wellbeing in Pregnancy.txt",
    "Memory and mood changes in pregnancy- a qualitative content analysis of women’s first-hand accounts.txt",
    "Is Mom Brain real? Understanding and coping with postpartum brain fog.txt",
    "Everyday Life Memory Deficits in Pregnant Women.txt",
    "Cognitive Function Decline in the Third Trimester.txt",
    "'Mommy brain' might be a good thing, new research suggests | CBC Radio.txt"
]
|
| 71 |
+
|
| 72 |
+
# Read every source article into memory, one DataFrame row per file.
data = []
for file_path in txt_files:
    # Fix: declare UTF-8 explicitly. The articles contain curly quotes and
    # em-dashes, which fail to decode under a legacy platform default
    # encoding (e.g. cp1252 on Windows) and would crash this loop.
    with open(file_path, "r", encoding="utf-8") as file:
        text = file.read()
        data.append({"text": text})

df = pd.DataFrame(data)


print(df)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# setup_nltk() above has already ensured 'punkt' and 'stopwords' are present,
# so the redundant nltk.download(...) network calls were removed from here.

# Tokenize each article into a list of word tokens.
df['tokens'] = df['text'].apply(word_tokenize)

# English stopword list, held in a set for O(1) membership tests.
stop_words = set(stopwords.words('english'))

# Lowercase every token, drop punctuation/non-alphanumerics, remove stopwords.
df['cleaned_text'] = df['tokens'].apply(
    lambda toks: [w.lower() for w in toks if (w.isalnum() and w.lower() not in stop_words)]
)

# Re-join the surviving tokens into one space-separated string per article.
df['cleaned_text'] = df['cleaned_text'].apply(lambda toks: ' '.join(toks))

print(df['cleaned_text'])
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# Per-article token count over the cleaned text.
df['word_count'] = df['cleaned_text'].apply(lambda doc: len(doc.split()))

# Length summary statistics for the corpus.
average_length = df['word_count'].mean()
min_word_count = df['word_count'].min()
max_word_count = df['word_count'].max()

print(f"Average length of articles: {average_length:.2f} words")
print(f"Minimum word count: {min_word_count} words")
print(f"Maximum word count: {max_word_count} words")

print(df[['cleaned_text', 'word_count']])

# Histogram of article lengths.
plt.figure(figsize=(10, 6))
plt.hist(df['word_count'], bins=20, color='brown', edgecolor='black')
plt.xlabel('Word Count')
plt.ylabel('Frequency')
plt.title('Distribution of Word Counts')
plt.show()


# Word cloud over the entire cleaned corpus joined into one string.
all_text = " ".join(df["cleaned_text"])
wordcloud = WordCloud(width=800, height=400, background_color="white").generate(all_text)

plt.figure(figsize=(10, 5))
plt.imshow(wordcloud, interpolation="bilinear")
plt.axis("off")
plt.title('Word Cloud of Cognitive Memory Issues')
plt.show()
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def cleaned_text(text):
    """Return *text* normalized to lowercase.

    Fix: the original bound its result to a local also named ``cleaned_text``,
    shadowing the function itself inside its own body; the trivial wrapper is
    now a single return. Note this helper is later applied to the raw ``text``
    column, replacing the tokenized/stopword-filtered cleaning done earlier.
    """
    return text.lower()
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
df["cleaned_text"] = df["text"].apply(cleaned_text)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
types_of_issues = ['Memory Loss', 'Difficulty Concentrating', 'Forgetfulness', 'Brain Fog', 'Others']
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
frequencies = {issue: 0 for issue in types_of_issues}
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
for text in df["cleaned_text"]:
|
| 155 |
+
for issue in types_of_issues:
|
| 156 |
+
if issue.lower() in text: # Example: use lowercase for comparison
|
| 157 |
+
frequencies[issue] += 1
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
df_frequencies = pd.DataFrame(list(frequencies.items()), columns=['Types of cognitive memory issues', 'Frequency'])
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
plt.figure(figsize=(10, 6))
|
| 164 |
+
plt.bar(df_frequencies['Types of cognitive memory issues'], df_frequencies['Frequency'], color='skyblue')
|
| 165 |
+
plt.xlabel('Types of cognitive memory issues')
|
| 166 |
+
plt.ylabel('Frequency')
|
| 167 |
+
plt.title('Frequency of Cognitive Memory Issues')
|
| 168 |
+
plt.xticks(rotation=45)
|
| 169 |
+
plt.show()
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
print(df['cleaned_text'])
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# Vectorize the cleaned corpus with TF-IDF, then fit a 5-topic LDA model.
tfidf_vectorizer = TfidfVectorizer(stop_words='english')
tfidf_matrix = tfidf_vectorizer.fit_transform(df['cleaned_text'])

lda = LatentDirichletAllocation(n_components=5, random_state=42)
lda.fit(tfidf_matrix)


# Pull the five highest-weight terms for every topic.
terms = tfidf_vectorizer.get_feature_names_out()
topics = [[terms[idx] for idx in component.argsort()[:-6:-1]] for component in lda.components_]


for topic_number, topic_terms in enumerate(topics, start=1):
    print(f"Topic {topic_number}: {', '.join(topic_terms)}")


# VADER compound sentiment score per article, then its distribution.
df['sentiment_score'] = df['cleaned_text'].apply(lambda doc: sia.polarity_scores(doc)['compound'])

plt.figure(figsize=(10, 6))
plt.hist(df['sentiment_score'], bins=20, color='green')
plt.xlabel('Sentiment Score')
plt.ylabel('Frequency')
plt.title('Sentiment Analysis of Articles')
plt.show()
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def baseline_generate_response(raw_text):
    """Generate a GPT-2 completion for *raw_text* with the baseline model.

    Returns the decoded text (prompt included, special tokens stripped).
    """
    # Tokenize the raw text.
    input_ids = tokenizer(raw_text, return_tensors='pt')['input_ids']
    # Fix: pass an explicit attention_mask, consistent with the other
    # generate() helpers in this file; without it, transformers warns and
    # may attend to padding positions.
    output = model.generate(
        input_ids,
        max_length=100,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
        attention_mask=input_ids.ne(tokenizer.eos_token_id),
    )
    # Decode the generated ids back into text.
    return tokenizer.decode(output[0], skip_special_tokens=True)


raw_text1 = "How does pregnancy affect memory?"
raw_text2 = "What are the effects of pregnancy on cognitive function?"
baseline_response_raw1 = baseline_generate_response(raw_text1)
baseline_response_raw2 = baseline_generate_response(raw_text2)
print("Baseline Response (Raw Text 1):", baseline_response_raw1.rstrip('!'))
print("Baseline Response (Raw Text 2):", baseline_response_raw2.rstrip('!'))
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def baseline_generate_response(cleaned_text):
    """Generate a GPT-2 completion for *cleaned_text*.

    Redefines the helper above with the same contract (explicit attention
    mask, greedy generation up to 100 tokens).
    """
    prompt_ids = tokenizer.encode(cleaned_text, return_tensors='pt')
    generated = model.generate(
        prompt_ids,
        max_length=100,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
        attention_mask=prompt_ids.ne(tokenizer.eos_token_id),
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)


cleaned_text1 = "How does pregnancy affect memory?"
cleaned_text2 = "What are the effects of pregnancy on cognitive function?"
baseline_response1 = baseline_generate_response(cleaned_text1)
baseline_response2 = baseline_generate_response(cleaned_text2)
print("Baseline Response 1:", baseline_response1)
print("Baseline Response 2:", baseline_response2)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
# Maximum number of *characters* kept from each cleaned article before
# encoding. NOTE(review): this slices the string, not the token sequence,
# so per-document token counts may still differ — confirm this is intended.
max_length = 512


# Encode each (truncated) article into a 1 x seq_len tensor of token ids.
df['tokenized_text'] = df['cleaned_text'].apply(lambda x: tokenizer.encode(x[:max_length], return_tensors='pt'))


# GPT-2 has no pad token by default, so fall back to 0 as the pad id.
padding_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

# Pad all encoded sequences to a common length (batch_first layout).
padded_sequences = pad_sequence([seq.squeeze(0)[:max_length] for seq in df['tokenized_text']], batch_first=True, padding_value=padding_value)


# NOTE(review): concatenating the rows flattens the batch into a single 1-D
# tensor, and the training loop below re-tokenizes each text itself, so
# input_ids/labels built here appear unused downstream — confirm before
# relying on them.
input_ids = torch.cat(tuple(padded_sequences), dim=0)
labels = input_ids.clone()


# Sanity check that the tensors above were created (always true if no
# exception was raised earlier).
if 'input_ids' in locals() and 'labels' in locals():
    print("input_ids and labels are defined.")
else:
    print("input_ids and labels are not defined.")
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
# Fine-tuning hyperparameters.
num_epochs = 3
learning_rate = 5e-5  # Adjusted learning rate
weight_decay = 0.01  # Adjusted weight decay
warmup_steps = 500  # Adjusted warmup steps
max_seq_length = 1024  # Maximum sequence length (GPT-2 context size)

# NOTE(review): transformers' AdamW is deprecated/removed in newer releases
# in favor of torch.optim.AdamW — confirm the installed version provides it.
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# One scheduler step is taken per document per epoch, so the total step
# count below matches the loop structure.
scheduler = get_scheduler("linear", optimizer, num_warmup_steps=warmup_steps, num_training_steps=len(df) * num_epochs)


# Causal-LM fine-tuning: each document is its own batch and the labels are
# the input ids themselves (the model shifts them internally).
model.train()
for epoch in range(num_epochs):
    total_loss = 0.0
    for text in df['cleaned_text']:
        input_ids = tokenizer.encode(text, return_tensors='pt', max_length=max_seq_length, truncation=True)
        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, labels=input_ids)
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        scheduler.step()
    average_loss = total_loss / len(df)
    print(f"Epoch {epoch+1}: Average Loss = {average_loss}")


# Persist the fine-tuned weights for later reloading.
model.save_pretrained('fine_tuned_gpt2_model')
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# Reload the fine-tuned weights saved above into a separate model object.
fine_tuned_model = GPT2LMHeadModel.from_pretrained('fine_tuned_gpt2_model')


def baseline_generate_response(input_text):
    """Generate a completion using the global `model`.

    NOTE(review): `model` was fine-tuned in place by the training loop above,
    so this "baseline" shares the tuned weights — confirm whether a fresh
    'gpt2' load was intended for a true baseline comparison.
    """
    prompt_ids = tokenizer.encode(input_text, return_tensors='pt')
    mask = prompt_ids.ne(tokenizer.eos_token_id)
    generated = model.generate(prompt_ids, max_length=100, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id, attention_mask=mask)
    return tokenizer.decode(generated[0], skip_special_tokens=True)


def fine_tuned_generate_response(input_text):
    """Generate a completion using the reloaded fine-tuned model."""
    prompt_ids = tokenizer.encode(input_text, return_tensors='pt')
    mask = prompt_ids.ne(tokenizer.eos_token_id)
    generated = fine_tuned_model.generate(prompt_ids, max_length=100, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id, attention_mask=mask)
    return tokenizer.decode(generated[0], skip_special_tokens=True)


# Compare both models on the same prompt.
input_text = "How does pregnancy affect memory?"

baseline_response = baseline_generate_response(input_text)
fine_tuned_response = fine_tuned_generate_response(input_text)

print("Baseline Response:", baseline_response)
print("Fine-Tuned Response:", fine_tuned_response)

print("Are the responses the same?")
print(baseline_response == fine_tuned_response)
|
| 324 |
+
|
| 325 |
+
# --- Streamlit chat UI: echo the model's reply to whatever the user types ---
try:
    st.title("Memory Support Chatbox for Pregnant Women")
    user_input = st.text_input("You:", "Enter your message here...")
    if user_input:
        input_ids = tokenizer.encode(user_input, return_tensors='pt')
        reply_ids = model.generate(input_ids, max_length=100, pad_token_id=tokenizer.eos_token_id)
        reply_text = tokenizer.decode(reply_ids[0], skip_special_tokens=True)
        st.text_area("Chatbot:", value=reply_text, height=200)
except Exception as e:
    # Surface any generation/UI failure in the app rather than crashing it.
    st.error(f"An error occurred: {e}")


st.subheader("Text Analysis")

st.subheader("Word Cloud")
all_text = " ".join(df["cleaned_text"])
wordcloud = WordCloud(width=800, height=400, background_color="white").generate(all_text)
# Fix: build the figure explicitly and hand it to st.pyplot(). Calling
# st.pyplot() with no argument relies on Streamlit's deprecated
# global-pyplot mode, which newer releases reject by default.
fig = plt.figure(figsize=(10, 5))
plt.imshow(wordcloud, interpolation="bilinear")
plt.axis("off")
st.pyplot(fig)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
# Qualitative spot-check: print the model's answer to a few sample prompts.
sample_prompts = [
    "What causes pregnancy brain fog?",
    "How does pregnancy affect the brain?",
    "How can I improve my memory during pregnancy?",
    "Can pregnancy brain fog affect my ability to work or perform daily tasks?",
]


for prompt in sample_prompts:
    prompt_ids = tokenizer.encode(prompt, return_tensors='pt')
    generated = model.generate(prompt_ids, max_length=100, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id, attention_mask=prompt_ids.ne(tokenizer.eos_token_id))
    print("Response:", tokenizer.decode(generated[0], skip_special_tokens=True))


# The same prompt list again (kept as in the original, which defines it twice).
sample_prompts = [
    "What causes pregnancy brain fog?",
    "How does pregnancy affect the brain?",
    "How can I improve my memory during pregnancy?",
    "Can pregnancy brain fog affect my ability to work or perform daily tasks?",
]


# Collect prompt/response pairs. NOTE: this rebinds the module-level `df`,
# replacing the article corpus DataFrame built earlier.
data = []
for prompt in sample_prompts:
    prompt_ids = tokenizer.encode(prompt, return_tensors='pt')
    generated = model.generate(prompt_ids, max_length=100, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id, attention_mask=prompt_ids.ne(tokenizer.eos_token_id))
    data.append({"prompt": prompt, "response": tokenizer.decode(generated[0], skip_special_tokens=True)})


df = pd.DataFrame(data)

print(df)


# Pie chart of response word counts, one slice per prompt.
plt.figure(figsize=(8, 8))
plt.pie([len(reply.split()) for reply in df['response']], labels=df['prompt'], autopct='%1.1f%%', startangle=90)
plt.title('Distribution of response lengths for sample prompts')
plt.axis('equal')
plt.show()
|