stibiumghost commited on
Commit
7c17b3b
β€’
1 Parent(s): 3908076

Update text_gen.py

Browse files
Files changed (1) hide show
  1. text_gen.py +1 -14
text_gen.py CHANGED
@@ -1,18 +1,5 @@
1
- import transformers
2
  import string
3
 
4
- model_names = ['microsoft/GODEL-v1_1-large-seq2seq',
5
- 'facebook/blenderbot-1B-distill',
6
- 'facebook/blenderbot_small-90M']
7
-
8
- tokenizers = [transformers.AutoTokenizer.from_pretrained(model_names[0]),
9
- transformers.BlenderbotTokenizer.from_pretrained(model_names[1]),
10
- transformers.BlenderbotSmallTokenizer.from_pretrained(model_names[2])]
11
-
12
- model = [transformers.AutoModelForSeq2SeqLM.from_pretrained(model_names[0]),
13
- transformers.BlenderbotForConditionalGeneration.from_pretrained(model_names[1]),
14
- transformers.BlenderbotSmallForConditionalGeneration.from_pretrained(model_names[2])]
15
-
16
 
17
  def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
18
  text = f'{context} {text}'
@@ -30,7 +17,7 @@ def generate_text(text, context, model_name, model, tokenizer, minimum=15, maxim
30
  def capitalization(line):
31
  line, end = line[:-1], line[-1]
32
  for mark in '.?!':
33
- line = f'{mark} '.join([part.strip()[0].upper() + part.strip()[1:] for part in line.split(mark) if len(part) > 1])
34
  line = ' '.join([word.capitalize() if word.translate(str.maketrans('', '', string.punctuation)) == 'i'
35
  else word for word in line.split()])
36
  return line.replace(' i\'', ' I\'') + end
 
 
1
  import string
2
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
5
  text = f'{context} {text}'
 
17
  def capitalization(line):
18
  line, end = line[:-1], line[-1]
19
  for mark in '.?!':
20
+ line = f'{mark} '.join([part.strip()[0].upper() + part.strip()[1:] for part in line.split(mark) if len(part) > 1])
21
  line = ' '.join([word.capitalize() if word.translate(str.maketrans('', '', string.punctuation)) == 'i'
22
  else word for word in line.split()])
23
  return line.replace(' i\'', ' I\'') + end