marcderbauer committed on
Commit 25b5680
1 Parent(s): 9909948

Improved logging

Files changed (1)
  1. app.py +5 -3
app.py CHANGED
@@ -4,8 +4,7 @@ from transformers import BloomTokenizerFast, BloomForCausalLM
 tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
 # https://huggingface.co/blog/how-to-generate
 
-def generate(text, temp=0.7):
-    print(temp)
+def generate(text, temp=0.7, logging=True):
     input_ids = tokenizer.encode(text, return_tensors='pt')
     output = model.generate(
         input_ids,
@@ -17,7 +16,10 @@ def generate(text, temp=0.7):
         repetition_penalty=1.2,
         min_length=len(text)+1
     )
-    return tokenizer.decode(output[0], skip_special_tokens=True)
+    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
+    if logging:
+        print(f"\n\n{'-'*100}\nInput: {text}\nOutput: {decoded}\nTemp: {temp}")
+    return decoded
 
 description = "Generate Titles for the Vice Youtube Channel"
 title = "Vice Headlines"
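For reference, a minimal sketch of what the patched generate function in app.py might look like after this commit, assembled from the two hunks above. The model = BloomForCausalLM.from_pretrained(...) line and the do_sample / temperature arguments are assumptions: they fall outside or between the hunks and are not shown in the diff, so the actual file may differ.

from transformers import BloomTokenizerFast, BloomForCausalLM

tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
# Assumption: the diff only shows the import; the model is presumably loaded
# somewhere in app.py along these lines.
model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")

# https://huggingface.co/blog/how-to-generate
def generate(text, temp=0.7, logging=True):
    input_ids = tokenizer.encode(text, return_tensors='pt')
    # Assumption: sampling arguments between the two hunks are not visible in the
    # diff; only repetition_penalty and min_length appear there.
    output = model.generate(
        input_ids,
        do_sample=True,
        temperature=temp,
        repetition_penalty=1.2,
        min_length=len(text) + 1,
    )
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    if logging:
        # New in this commit: log input, output, and temperature for each call.
        print(f"\n\n{'-'*100}\nInput: {text}\nOutput: {decoded}\nTemp: {temp}")
    return decoded

# Example usage (hypothetical prompt):
# print(generate("Vice visits the world's largest", temp=0.7))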