user2434 committed on
Commit
1a0e502
β€’
1 Parent(s): c90606a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -15
app.py CHANGED
@@ -7,34 +7,51 @@ import email
7
 
8
  # loading and preprocessing dataset
9
  emails = pd.read_csv('emails.csv')
10
-
11
  def preprocess_email_content(raw_email):
12
  message = email.message_from_string(raw_email).get_payload()
13
  return message.replace("\n", "").replace("\r", "").replace("> >>> > >", "").strip()
14
 
15
  content_text = [preprocess_email_content(item) for item in emails['message']]
16
- train_content, _ = train_test_split(content_text, train_size=0.00005)
17
 
18
  # ChromaDB setup
19
  client = chromadb.Client()
20
  collection = client.create_collection(name="Enron_emails")
21
  collection.add(documents=train_content, ids=[f'id{i+1}' for i in range(len(train_content))])
22
 
23
- # initialize model and tokenizer globally but don't load them yet
24
- tokenizer = None
25
- model = None
26
- text_gen = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- def load_model():
29
- global tokenizer, model, text_gen
30
- if model is None or tokenizer is None:
31
- tokenizer = GPT2Tokenizer.from_pretrained('./fine_tuned_model')
32
- model = GPT2LMHeadModel.from_pretrained('./fine_tuned_model')
33
- tokenizer.add_special_tokens({'pad_token': '[PAD]'})
34
- text_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
35
 
 
36
  def question_answer(question):
37
- load_model() # loading model on first use
38
  try:
39
  generated = text_gen(question, max_length=200, num_return_sequences=1)
40
  generated_text = generated[0]['generated_text'].replace(question, "").strip()
@@ -42,12 +59,13 @@ def question_answer(question):
42
  except Exception as e:
43
  return f"Error in generating response: {str(e)}"
44
 
 
45
  iface = gr.Interface(
46
  fn=question_answer,
47
  inputs="text",
48
  outputs="text",
49
  title="Answering questions about the Enron case.",
50
  description="Ask a question about the Enron case!",
51
- examples=["What is Enron?"]
52
  )
53
  iface.launch()
 
7
 
8
  # loading and preprocessing dataset
9
  emails = pd.read_csv('emails.csv')
 
10
  def preprocess_email_content(raw_email):
11
  message = email.message_from_string(raw_email).get_payload()
12
  return message.replace("\n", "").replace("\r", "").replace("> >>> > >", "").strip()
13
 
14
  content_text = [preprocess_email_content(item) for item in emails['message']]
15
+ train_content, _ = train_test_split(content_text, train_size=0.00005) # was unable to load more emails
16
 
17
  # ChromaDB setup
18
  client = chromadb.Client()
19
  collection = client.create_collection(name="Enron_emails")
20
  collection.add(documents=train_content, ids=[f'id{i+1}' for i in range(len(train_content))])
21
 
22
+ # model and tokenizer
23
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
24
+ model = GPT2LMHeadModel.from_pretrained('gpt2')
25
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
26
+
27
+ # tokenizing and training
28
+ tokenized_emails = tokenizer(train_content, truncation=True, padding=True)
29
+ with open('tokenized_emails.txt', 'w') as file:
30
+ for ids in tokenized_emails['input_ids']:
31
+ file.write(' '.join(map(str, ids)) + '\n')
32
+
33
+ dataset = TextDataset(tokenizer=tokenizer, file_path='tokenized_emails.txt', block_size=128)
34
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
35
+ training_args = TrainingArguments(
36
+ output_dir='./output',
37
+ num_train_epochs=3,
38
+ per_device_train_batch_size=8
39
+ )
40
+
41
+ trainer = Trainer(
42
+ model=model,
43
+ args=training_args,
44
+ data_collator=data_collator,
45
+ train_dataset=dataset
46
+ )
47
+ trainer.train()
48
 
49
+ # saving the model
50
+ model.save_pretrained("./fine_tuned_model")
51
+ tokenizer.save_pretrained("./fine_tuned_model")
 
 
 
 
52
 
53
+ # Gradio interface
54
  def question_answer(question):
 
55
  try:
56
  generated = text_gen(question, max_length=200, num_return_sequences=1)
57
  generated_text = generated[0]['generated_text'].replace(question, "").strip()
 
59
  except Exception as e:
60
  return f"Error in generating response: {str(e)}"
61
 
62
+ text_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
63
  iface = gr.Interface(
64
  fn=question_answer,
65
  inputs="text",
66
  outputs="text",
67
  title="Answering questions about the Enron case.",
68
  description="Ask a question about the Enron case!",
69
+ examples=["What is Enron?"]
70
  )
71
  iface.launch()