Syrinx commited on
Commit
1e8f625
1 Parent(s): 257b300

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -28
app.py CHANGED
@@ -1,52 +1,49 @@
1
  import torch
2
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
3
-
4
  import streamlit as st
5
 
6
  # Load the tokenizer and model
7
  tokenizer = GPT2Tokenizer.from_pretrained('webtoon_tokenizer')
8
  model = GPT2LMHeadModel.from_pretrained('webtoon_model')
9
 
10
-
11
- # Define the app
12
- def main():
13
- st.title('Webtoon Description Generator')
14
-
15
- # Get the input from the user
16
- title = st.text_input('Enter the title of the Webtoon:', '')
17
-
18
- # Generate the description
19
- if st.button('Generate Description'):
20
- with st.spinner('Generating...'):
21
- description = generate_description(title)
22
- st.success(description)
23
-
24
  # Check if GPU is available
25
- if torch.cuda.is_available():
26
- device = torch.device("cuda")
27
- else:
28
- device = torch.device("cpu")
29
 
30
  # Define the function that generates the description
31
  def generate_description(title):
32
  # Preprocess the input
33
  input_text = f"{title}"
34
  input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
 
35
 
36
  # Generate the output using the model
37
- output = model.generate(
38
- input_ids=input_ids,
39
- max_length=200,
40
- num_beams=4,
41
- early_stopping=True,
42
- no_repeat_ngram_size=2
43
- )
 
 
44
 
45
  # Convert the output to text
46
  description = tokenizer.decode(output[0], skip_special_tokens=True)
47
-
48
  return description
49
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  if __name__ == '__main__':
52
- main()
 
1
  import torch
2
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
 
3
  import streamlit as st
4
 
5
  # Load the tokenizer and model
6
  tokenizer = GPT2Tokenizer.from_pretrained('webtoon_tokenizer')
7
  model = GPT2LMHeadModel.from_pretrained('webtoon_model')
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  # Check if GPU is available
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ model.to(device)
 
 
12
 
13
  # Define the function that generates the description
14
  def generate_description(title):
15
  # Preprocess the input
16
  input_text = f"{title}"
17
  input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
18
+ attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
19
 
20
  # Generate the output using the model
21
+ with torch.no_grad(): # Disable gradient calculation for inference
22
+ output = model.generate(
23
+ input_ids=input_ids,
24
+ attention_mask=attention_mask, # Pass attention_mask to avoid warnings
25
+ max_length=100, # Reduce max_length for quicker inference
26
+ num_beams=2, # Reduce num_beams for quicker inference
27
+ early_stopping=True,
28
+ no_repeat_ngram_size=2
29
+ )
30
 
31
  # Convert the output to text
32
  description = tokenizer.decode(output[0], skip_special_tokens=True)
 
33
  return description
34
 
35
+ # Define the app
36
+ def main():
37
+ st.title('Webtoon Description Generator')
38
+
39
+ # Get the input from the user
40
+ title = st.text_input('Enter the title of the Webtoon:', '')
41
+
42
+ # Generate the description
43
+ if st.button('Generate Description'):
44
+ with st.spinner('Generating...'):
45
+ description = generate_description(title)
46
+ st.success(description)
47
 
48
  if __name__ == '__main__':
49
+ main()