model / main.py
edithram23's picture
Update main.py
c21ac27 verified
raw
history blame contribute delete
No virus
2.43 kB
import os
os.environ["HF_HOME"] = "/.cache"
import re
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
model_dir_large = 'edithram23/Redaction_Personal_info_v1'
# Load the tokenizer, then the seq2seq weights, from the same Hub repository.
# This runs once at import time; mask_generation uses both objects as its
# default arguments.
tokenizer_large, model_large = (
    AutoTokenizer.from_pretrained(model_dir_large),
    AutoModelForSeq2SeqLM.from_pretrained(model_dir_large),
)
def mask_generation(text, model=model_large, tokenizer=tokenizer_large):
    """Redact personal information in *text* with the seq2seq redaction model.

    The text is wrapped in the ``"Mask Generation: " ... "."`` task prompt and
    passed to ``model.generate`` (8 beams, sampling enabled). Every ``[...]``
    span in the decoded output is then replaced with ``[redacted]``.

    Args:
        text: The raw user message to redact.
        model: Seq2seq model used for generation. Defaults to the
            module-level ``model_large``.
        tokenizer: Tokenizer matching ``model``. Defaults to
            ``tokenizer_large``.

    Returns:
        The stripped, decoded model output with bracketed spans replaced by
        ``"[redacted]"``. Because ``do_sample=True``, repeated calls on the
        same text may return different results.
    """
    prompt = "Mask Generation: " + text + "."
    # The prompt is capped at 512 tokens; anything longer is truncated.
    inputs = tokenizer([prompt], max_length=512, truncation=True, return_tensors="pt")
    # ``max_length`` limits generated *tokens*, but ``len(text)`` counts
    # *characters*. For very short or empty messages that count can fall
    # below what the output needs (``len("") == 0`` leaves no room at all),
    # which truncates the result. The tokenized prompt length serves as a
    # floor; longer messages keep the original character-based bound.
    max_length = max(len(text), inputs["input_ids"].shape[-1])
    output = model.generate(**inputs, num_beams=8, do_sample=True, max_length=max_length)
    decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    predicted_title = decoded_output.strip()
    # Non-greedy match, so "[A] and [B]" yields two separate replacements
    # rather than one span from the first "[" to the last "]".
    return re.sub(r'\[.*?\]', '[redacted]', predicted_title)
from fastapi import FastAPI, Form
from fastapi.responses import HTMLResponse
import uvicorn

# The single application object. The route decorators in this module
# register on it, and uvicorn serves it as "main:app".
app = FastAPI()
# Serve the HTML form
@app.get("/", response_class=HTMLResponse)
async def read_form():
    """Return the static landing page containing the message-entry form.

    The form POSTs a single text field named ``message`` to ``/submit``.
    """
    return """
<!DOCTYPE html>
<html>
<head>
<title>FastAPI Input Form</title>
</head>
<body>
<h1>Enter Your Message</h1>
<form action="/submit" method="post">
<label for="message">Message:</label>
<input type="text" id="message" name="message">
<input type="submit" value="Submit">
</form>
</body>
</html>
"""
# Handle the form submission
@app.post("/submit", response_class=HTMLResponse)
async def submit_form(message: str = Form(...)):
    """Redact the submitted form message and render the result page.

    Args:
        message: The required ``message`` form field, posted by the page
            served at ``/``.

    Returns:
        An HTML page showing the redacted message, plus a "HOME" button that
        issues a GET to ``/``.
    """
    import asyncio
    import html

    # mask_generation runs model.generate synchronously. Calling it directly
    # in this coroutine would block the event loop, and with it every other
    # request, for the whole generation. Run it in a worker thread instead.
    redacted = await asyncio.to_thread(mask_generation, message)
    # The model output echoes user-supplied text, so it is HTML-escaped
    # before interpolation to prevent reflected XSS.
    safe_redacted = html.escape(redacted)
    html_content = f"""
<!DOCTYPE html>
<html>
<head>
<title>FastAPI Input Form</title>
</head>
<body>
<h1>Enter Your Message</h1>
<form action="/" method="get">
<p>{safe_redacted}</p>
<input type="submit" value="HOME">
</form>
</body>
</html>
"""
    return html_content
# @app.get("/")
# async def hello():
# return {"msg" : "Live"}
@app.post("/mask")
async def mask_input(query: str):
    """Redact ``query`` and return the result as JSON.

    ``query`` is a plain scalar parameter that is not part of the path, so
    FastAPI reads it from the request's query string (e.g.
    ``POST /mask?query=...``).

    Returns:
        ``{"data": <redacted text>}``.
    """
    import asyncio

    # Offload the blocking model call so the event loop can keep serving
    # other requests while generation runs.
    output = await asyncio.to_thread(mask_generation, query)
    return {"data": output}
if __name__ == '__main__':
    # HF_HOME is set at the top of the module, before transformers is
    # imported; that is the point at which it determines the cache location.
    # NOTE(review): with reload=True, uvicorn imports "main:app" again in a
    # child process. The module-level model is therefore loaded twice: once
    # here (as __main__) and once in that child. Consider lazy loading if
    # memory is tight.
    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True, workers=1)