vincentclaes commited on
Commit
5505694
1 Parent(s): f284c49

avoid gpu memory overflow

Browse files
Files changed (2) hide show
  1. app.py +4 -0
  2. scrape_website.py +1 -0
app.py CHANGED
@@ -89,6 +89,8 @@ def evaluate(
89
  **kwargs,
90
  ):
91
  content = process_webpage(url=url)
 
 
92
  prompt = generate_prompt(instruction, content)
93
  inputs = tokenizer(prompt, return_tensors="pt")
94
  input_ids = inputs["input_ids"].to(device)
@@ -109,6 +111,8 @@ def evaluate(
109
  )
110
  s = generation_output.sequences[0]
111
  output = tokenizer.decode(s)
 
 
112
  return output.split("### Response:")[1].strip()
113
 
114
 
 
89
  **kwargs,
90
  ):
91
  content = process_webpage(url=url)
92
+ # avoid GPU memory overflow
93
+ torch.cuda.empty_cache()
94
  prompt = generate_prompt(instruction, content)
95
  inputs = tokenizer(prompt, return_tensors="pt")
96
  input_ids = inputs["input_ids"].to(device)
 
111
  )
112
  s = generation_output.sequences[0]
113
  output = tokenizer.decode(s)
114
+ # avoid GPU memory overflow
115
+ torch.cuda.empty_cache()
116
  return output.split("### Response:")[1].strip()
117
 
118
 
scrape_website.py CHANGED
@@ -1,6 +1,7 @@
1
  import requests
2
  from bs4 import BeautifulSoup
3
 
 
4
  def process_webpage(url:str):
5
  # A set to keep track of visited pages
6
  visited_pages = set()
 
1
  import requests
2
  from bs4 import BeautifulSoup
3
 
4
+
5
  def process_webpage(url:str):
6
  # A set to keep track of visited pages
7
  visited_pages = set()