from dotenv import load_dotenv
import os

import redis
from fastapi import FastAPI
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from transformers import BlenderbotSmallForConditionalGeneration, AutoTokenizer

from helpers import in_cache

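# `in_cache` lives in the local helpers module, which is not part of this
# listing. Judging from how it is used below, it fuzzy-matches the incoming
# message against the cached Redis keys and returns an object exposing
# .status (bool) and .closest_match (the matching key). A minimal sketch
# consistent with that interface (the difflib approach and the 0.8 cutoff are
# assumptions, not the actual helper):
#
#     from difflib import get_close_matches
#     from types import SimpleNamespace
#
#     def in_cache(message, key_list):
#         matches = get_close_matches(message, key_list, n=1, cutoff=0.8)
#         return SimpleNamespace(status=bool(matches),
#                                closest_match=matches[0] if matches else None)
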
# Load UPSTASH_REDIS_PWD (and any other secrets) from a local .env file.
load_dotenv()

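# Load BlenderBot Small (the 90M-parameter model) and its tokenizer, caching
# the downloaded weights in a local directory so restarts can reuse them.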
model = BlenderbotSmallForConditionalGeneration.from_pretrained(
    "facebook/blenderbot_small-90M", cache_dir="new_cache_dir/"
)
tokenizer = AutoTokenizer.from_pretrained(
    "facebook/blenderbot_small-90M", cache_dir="new_cache_dir/"
)

class UserMSGRequest(BaseModel):
    message: str

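# Upstash Redis doubles as a short-lived response cache: user messages are the
# keys, generated replies the values (see the TTL set in root() below).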
r = redis.Redis(
    host='eu2-suitable-cod-32440.upstash.io',
    port=32440,
    password=os.getenv('UPSTASH_REDIS_PWD'),
    ssl=True,
    decode_responses=True,  # so r.keys()/r.get() return str rather than bytes
)

app = FastAPI()

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

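# The route decorators on the two handlers below were missing from the original
# listing; the paths used here (GET /status for a health check, POST / for
# chat) are assumptions, not confirmed by the source.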
@app.get("/status")  # assumed path
def status():
    return {"status": "OK"}

@app.post("/")  # assumed path
async def root(utterance: UserMSGRequest):
    # Check whether a close-enough message is already cached before running the model.
    key_list = r.keys()
    is_in_cache = in_cache(utterance.message, key_list)
    if is_in_cache.status:
        print("From Cache!")
        return r.get(is_in_cache.closest_match)
    # Cache miss: generate a reply with BlenderBot and cache it for 250 seconds.
    inputs = tokenizer(utterance.message, return_tensors="pt")
    results = model.generate(**inputs)
    response = tokenizer.batch_decode(results, skip_special_tokens=True)[0]
    r.set(utterance.message, response, ex=250)
    return response

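# Example request against the assumed chat route:
#
#     curl -X POST http://localhost:8080/ \
#          -H "Content-Type: application/json" \
#          -d '{"message": "hello there"}'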
if __name__ == '__main__':
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8080)