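import hashlib
import os
import threading
from os import listdir
from os.path import abspath, dirname, isfile, join

from fastapi import FastAPI, File, HTTPException, Response, UploadFile
from PIL import Image

from model import predictor

# Disable parallelism in the Hugging Face `tokenizers` library (presumably used
# by the model); set before the model is created in the background thread.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Bundled sample images and user uploads live next to this file.
IMAGE_DIR = join(dirname(abspath(__file__)), 'images')
os.makedirs(IMAGE_DIR, exist_ok=True)

# Decoded images, keyed by absolute file path.
images = {}

# The model is created in a background thread so the server starts serving at
# once; model_ready is set when loading has finished (successfully or not).
gpredictor = None
model_ready = threading.Event()

app = FastAPI()


def load_image(path):
    # Decode the whole file and return an in-memory copy, so the cached image
    # does not keep a file handle open.
    with Image.open(path) as img:
        return img.copy()


@app.get('/')
def root():
    return {'app': 'Thanks for visiting!!'}


@app.get('/favicon.ico', include_in_schema=False)
def favicon():
    # No icon is served; reply with an empty 204 response.
    return Response(status_code=204)


@app.post('/uploadfile/')
async def create_upload_file(file: UploadFile = File(...)):
    contents = await file.read()
    # Name the file after the SHA-256 digest of its bytes, so uploading the
    # same image twice reuses the file already on disk.
    digest = hashlib.sha256(contents).hexdigest()
    path = join(IMAGE_DIR, f'upload_{digest}.jpg')
    if not isfile(path):
        with open(path, 'wb') as f:
            f.write(contents)
    images[path] = load_image(path)
    # Clients pass this path back to /vqa as the `image` parameter.
    return {'filename': path}


@app.get('/vqa')
def answer(image: str, question: str):
    # Declared with plain `def` so FastAPI runs it in a worker thread: waiting
    # for the model and running inference then do not block the event loop.
    path = abspath(image)
    pil_image = images.get(path)
    if pil_image is None:
        if not isfile(path):
            raise HTTPException(status_code=404, detail=f'Image not found: {image}')
        print(f'Image not cached, loading {path}')
        pil_image = load_image(path)
        images[path] = pil_image
    model_ready.wait()
    if gpredictor is None:
        raise HTTPException(status_code=503, detail='Model failed to load')
    result = gpredictor.predict_answer_from_text(pil_image, question)
    return {'answer': result}


def collect_images():
    # Pre-load the bundled sample images (file names starting with 'image').
    for name in listdir(IMAGE_DIR):
        if name.startswith('image'):
            full_image_path = join(IMAGE_DIR, name)
            images[full_image_path] = load_image(full_image_path)


def runInThread():
    global gpredictor
    try:
        collect_images()
        print('Initializing model in thread')
        gpredictor = predictor.Predictor()
        print('Model is initialized')
    finally:
        # Release waiting /vqa requests even if loading raised an exception.
        model_ready.set()


thread = threading.Thread(target=runInThread)
thread.start()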
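

# Hardening sketch (hypothetical helper, not wired into the routes above):
# /vqa opens whatever path the client sends, so it can read any image file the
# server process has access to. A containment check like this one could be used
# to reject paths that resolve outside the images directory, for example by
# raising a 404 when is_within_directory(image, <images directory>) is False.
# The helper name and its use are assumptions, not part of the original app.
import os.path


def is_within_directory(path, directory):
    # Resolve symlinks and '..' segments on both sides, then check that the
    # directory is the longest common prefix of the two resolved paths.
    root = os.path.realpath(directory)
    target = os.path.realpath(path)
    try:
        return os.path.commonpath([root, target]) == root
    except ValueError:
        # Raised e.g. on Windows when the two paths are on different drives.
        return False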