edithram23 committed
Commit 8bfa5bb
1 Parent(s): 1cb45a7

Update app.py

Files changed (1)
  1. app.py +18 -1
app.py CHANGED
@@ -15,9 +15,26 @@ model_dir_large = 'edithram23/Redaction_Personal_info_v1'
 tokenizer_large = AutoTokenizer.from_pretrained(model_dir_large)
 model_large = AutoModelForSeq2SeqLM.from_pretrained(model_dir_large)
 
+model_dir_small = 'edithram23/Redaction'
+tokenizer_small = AutoTokenizer.from_pretrained(model_dir_small)
+model_small = AutoModelForSeq2SeqLM.from_pretrained(model_dir_small)
+
+def small(text,model=model_small,tokenizer=tokenizer_small):
+    inputs = ["Mask Generation: " + text.lower()+'.']
+    inputs = tokenizer(inputs, max_length=512, truncation=True, return_tensors="pt")
+    output = model.generate(**inputs, num_beams=8, do_sample=True, max_length=len(text))
+    decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
+    predicted_title = decoded_output.strip()
+    pattern = r'\[.*?\]'
+    # Replace all occurrences of the pattern with [redacted]
+    redacted_text = re.sub(pattern, '[redacted]', predicted_title)
+    return redacted_text
+
+
 def mask_generation(text,model=model_large,tokenizer=tokenizer_large):
-    if(len(text)<30):
+    if(len(text)<90):
         text = text+'.'
+        return small(text)
     inputs = ["Mask Generation: " + text.lower()+'.']
     inputs = tokenizer(inputs, max_length=512, truncation=True, return_tensors="pt")
     output = model.generate(**inputs, num_beams=8, do_sample=True, max_length=len(text))
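
Usage note (illustration only, not part of the commit): a minimal sketch of how the new routing behaves after this change, assuming app.py can be imported without side effects and that mask_generation returns the redacted string for longer inputs the same way small() does for short ones. The example strings and variable names are hypothetical.

from app import mask_generation

# Inputs shorter than 90 characters get a '.' appended and are routed to the
# smaller 'edithram23/Redaction' model via small(), which post-processes any
# [ ... ] spans in the model output into the literal token [redacted].
# Longer inputs fall through to the larger
# 'edithram23/Redaction_Personal_info_v1' model.
short_text = "Call John Doe at 555-0199."
long_text = ("John Doe lives at 42 Elm Street, Springfield, was born on 4 July 1990, "
             "and can be reached at john.doe@example.com or 555-0199.")

print(mask_generation(short_text))  # len < 90, handled by small()
print(mask_generation(long_text))   # handled by the large model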