josequinonez committed
Commit 03fb0c0 · verified · 1 Parent(s): cc3c752

Update app.py

"facebook/bart-large-cnn"
Release memory

Files changed (1):
  1. app.py +16 -4
app.py CHANGED
@@ -1,4 +1,3 @@
-
 import streamlit as st
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer # Use AutoModelForSeq2SeqLM for BART
 import torch
@@ -28,19 +27,25 @@ def summarize_bart(article):
     )
 
     summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    # --- Memory Release ---
+    del inputs
+    del outputs
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    # --- End Memory Release ---
+
     return summary
 
 def answer_question_bart(article, question):
     """Answers a question based on an article using the facebook/bart-large-cnn model."""
     # For Q&A with BART, concatenate the question and article with a separator
-    # Let's use a prompt format similar to what worked in the notebook tests for BART QA
     input_text = f"Answer the question based on the following article.\n\nArticle: {article}\n\nQuestion: {question}\n\nAnswer:"
 
     # Tokenize the input
     inputs = tokenizer(input_text, return_tensors="pt", max_length=1024, truncation=True, padding=True)
 
     # Generate the answer
-    # Adjust generation parameters as needed for Q&A
     outputs = model.generate(
         inputs["input_ids"],
         attention_mask=inputs["attention_mask"],
@@ -67,6 +72,14 @@ def answer_question_bart(article, question):
     if answer.startswith("Answer:"): # Handle cases where the model might repeat "Answer:"
         answer = answer[len("Answer:"):].strip()
 
+    # --- Memory Release ---
+    del inputs
+    del outputs
+    # del generated_text # Be careful deleting generated_text if you need to return it
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    # --- End Memory Release ---
+
 
     return answer
 
@@ -103,5 +116,4 @@ if st.button("Process"):
         st.warning("Please provide an article to answer the question from.")
     elif not question_input:
         st.warning("Please provide a question to answer.")
-
 
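
For context, the memory-release block this commit adds to both functions is a standard PyTorch pattern: del drops the Python references to the input and output tensors so their storage becomes reclaimable, and torch.cuda.empty_cache() hands the freed-but-cached blocks back to the GPU driver. Below is a minimal standalone sketch of the same pattern; the article string and the memory_allocated/memory_reserved printouts are illustrative additions, not part of the committed app.py, and the .to("cuda") calls assume a GPU is present.

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "facebook/bart-large-cnn"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to("cuda")

# Illustrative input, not from the commit
article = "Long-lived Streamlit processes can accumulate GPU scratch tensors across requests."
inputs = tokenizer(article, return_tensors="pt", max_length=1024, truncation=True).to("cuda")
outputs = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"])
print("allocated before release:", torch.cuda.memory_allocated())

# --- Memory Release (same pattern as the commit) ---
del inputs                # drop the reference to the tokenized batch
del outputs               # drop the reference to the generated token ids
torch.cuda.empty_cache()  # return cached-but-unused blocks to the driver
print("allocated after release:", torch.cuda.memory_allocated())
print("reserved after release:", torch.cuda.memory_reserved())

Note that empty_cache() cannot free tensors that are still referenced; the del statements (or letting the locals go out of scope on return) are what make the storage reclaimable, which is why the commit pairs the two before each return.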