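# FastAPI service that wraps the BlenderBot-small model with a Redis response cache.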
from dotenv import load_dotenv
import os
from transformers import BlenderbotSmallForConditionalGeneration, AutoTokenizer
from helpers import in_cache
from fastapi import FastAPI
from pydantic import BaseModel
import redis
from starlette.middleware.cors import CORSMiddleware
load_dotenv()
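
# Download the BlenderBot-small model and its tokenizer into a local cache directory.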
model = BlenderbotSmallForConditionalGeneration.from_pretrained("facebook/blenderbot_small-90M", cache_dir="new_cache_dir/")
tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot_small-90M", cache_dir="new_cache_dir/")
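
# Request body schema for the /generate-message endpoint.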
class UserMSGRequest(BaseModel):
    message: str
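
# Upstash Redis client; the password is read from the environment (loaded via dotenv).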
r = redis.Redis(
    host='eu2-suitable-cod-32440.upstash.io',
    port=32440,
    password=os.getenv('UPSTASH_REDIS_PWD'),
    ssl=True
)
app = FastAPI()
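
# Allow cross-origin requests from any origin.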
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
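
# Simple health-check endpoint.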
@app.get("/status")
def status():
    return {"status": "OK"}
@app.post("/generate-message")
async def root(utterance: UserMSGRequest):
    # Fetch all cache keys and check whether the incoming message matches one.
    # Note: KEYS scans the entire keyspace, which is fine for a small cache.
    key_list = r.keys()
    is_in_cache = in_cache(utterance.message, key_list)
    if is_in_cache.status:
        print("From Cache!")
        return r.get(is_in_cache.closest_match)
    # Cache miss: generate a reply with the model.
    inputs = tokenizer(utterance.message, return_tensors="pt")
    results = model.generate(**inputs)
    response = tokenizer.batch_decode(results, skip_special_tokens=True)[0]
    # Cache the reply under the original message, expiring after 250 seconds.
    r.set(utterance.message, response, ex=250)
    return response
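
# Example request, assuming the server is running locally on port 8080:
#   curl -X POST http://localhost:8080/generate-message \
#        -H "Content-Type: application/json" \
#        -d '{"message": "hello there"}'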

if __name__ == '__main__':
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8080)