# critzyblenderbot / main.py (1.55 kB)
# Last commit a200dd4 by Ehren12: "switch back to 90M model"
import os

import redis
from dotenv import load_dotenv
from fastapi import FastAPI
from pydantic import BaseModel
from starlette.concurrency import run_in_threadpool
from starlette.middleware.cors import CORSMiddleware
from transformers import BlenderbotSmallForConditionalGeneration, AutoTokenizer

from helpers import in_cache
# Populate os.environ from a local .env file (e.g. UPSTASH_REDIS_PWD, read by the
# Redis client setup) before anything consults the environment.
load_dotenv()

# BlenderBot-small 90M checkpoint; model and tokenizer share one local download cache.
_CHECKPOINT = "facebook/blenderbot_small-90M"
_HF_CACHE_DIR = "new_cache_dir/"

model = BlenderbotSmallForConditionalGeneration.from_pretrained(
    _CHECKPOINT, cache_dir=_HF_CACHE_DIR
)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT, cache_dir=_HF_CACHE_DIR)
# JSON request body accepted by POST /generate-message.
class UserMSGRequest(BaseModel):
    # The user's utterance that the chatbot should reply to.
    message: str
# Redis client used by /generate-message to cache generated replies.
# Connection settings come from the environment so the endpoint can be changed
# without editing code; the defaults are the previously hard-coded Upstash host
# and port. The password has always come from UPSTASH_REDIS_PWD. Connects over TLS.
r = redis.Redis(
    host=os.getenv("UPSTASH_REDIS_HOST", "eu2-suitable-cod-32440.upstash.io"),
    port=int(os.getenv("UPSTASH_REDIS_PORT", "32440")),
    password=os.getenv("UPSTASH_REDIS_PWD"),
    ssl=True,
)
app = FastAPI()

# CORS: accept cross-origin requests from any origin, with any method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True — the
# Starlette CORS docs state credentials require explicitly listed origins.
# Confirm whether credentialed cross-origin requests are actually needed.
origins = ["*"]
cors_options = dict(
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **cors_options)
# Liveness probe: returns a constant payload without touching Redis or the model.
@app.get("/status")
def status():
    healthy = dict(status="OK")
    return healthy
def _reply_for(message: str) -> str:
    """Return a reply to ``message``, served from the Redis cache when possible.

    Synchronous on purpose: every Redis call and ``model.generate`` here blocks,
    so callers on the event loop must run this in a worker thread.
    """
    # SCAN walks the keyspace incrementally instead of KEYS' single blocking pass.
    # SCAN may return a key more than once, so de-duplicate before matching.
    key_list = list(set(r.scan_iter()))
    is_in_cache = in_cache(message, key_list)
    if is_in_cache.status:
        cached = r.get(is_in_cache.closest_match)
        # The matched key can expire (TTL set below) between the scan and this GET;
        # in that case fall through and regenerate rather than returning null.
        if cached is not None:
            print("From Cache!")
            # Without decode_responses the client returns bytes; decode so both
            # paths return str (the JSON response body is the same either way).
            return cached.decode("utf-8") if isinstance(cached, bytes) else cached
    inputs = tokenizer(message, return_tensors="pt")
    results = model.generate(**inputs)
    response = tokenizer.batch_decode(results, skip_special_tokens=True)[0]
    r.set(message, response, ex=250)  # cache the reply for 250 seconds
    return response


@app.post("/generate-message")
async def root(utterance: UserMSGRequest):
    """Generate a chatbot reply to ``utterance.message``.

    Replies are cached in Redis for 250 seconds, keyed by the message.
    ``helpers.in_cache`` selects which existing key, if any, to reuse.
    The blocking Redis and model work runs in Starlette's thread pool so it
    does not stall the event loop serving other requests.
    """
    return await run_in_threadpool(_reply_for, utterance.message)
if __name__ == "__main__":
    # uvicorn is imported only when this file is executed directly as a script.
    import uvicorn

    listen_host = "0.0.0.0"  # bind on all interfaces
    listen_port = 8080
    uvicorn.run(app, host=listen_host, port=listen_port)