vqa-guessing-game / response_db.py
sedrickkeh's picture
major updates demo v2
5a72dbb
raw
history blame
1.84 kB
import datetime
import os
from urllib.parse import quote_plus

from pymongo import MongoClient
class ResponseDb:
    """Log dialogue turns to the ``vqa-game`` collection of the ``vqa-game`` database.

    Connection credentials are read from the environment variables
    ``mongodb_username``, ``mongodb_pw`` and ``mongodb_cluster_url``.
    """

    def __init__(self):
        """Open a client to the cluster and select the ``vqa-game`` collection.

        Raises:
            KeyError: if any of the required environment variables is unset.
        """
        uri = self._connection_uri(
            os.environ['mongodb_username'],
            os.environ['mongodb_pw'],
            os.environ['mongodb_cluster_url'],
        )
        self.client = MongoClient(uri)
        self.db = self.client['vqa-game']
        self.collection = self.db['vqa-game']

    @staticmethod
    def _connection_uri(username, password, cluster_url):
        """Build the ``mongodb+srv://`` connection string for *cluster_url*.

        The username and password are percent-escaped with ``quote_plus``:
        pymongo refuses URIs whose credentials contain unescaped reserved
        characters such as ``@``, ``:``, ``/`` or ``%``. Credentials made of
        plain alphanumerics produce exactly the same URI as before.
        """
        return (
            f"mongodb+srv://{quote_plus(username)}:{quote_plus(password)}"
            f"@{cluster_url}/?retryWrites=true&w=majority"
        )

    def add(self, dialogue_id, task_id, turn, question, response):
        """Insert one record describing a single dialogue turn.

        The record holds the given fields plus a ``datetime`` timestamp
        taken at insertion time.

        Returns:
            The ``inserted_id`` reported by ``insert_one``.
        """
        document = {
            "dialogue_id": dialogue_id,
            "task_id": task_id,
            "turn": turn,
            "question": question,
            "response": response,
            # NOTE(review): naive local time; pymongo encodes naive datetimes
            # as if they were UTC — confirm the host clock is UTC, or switch
            # to an aware timestamp (existing records would then differ).
            "datetime": datetime.datetime.now(),
        }
        return self.collection.insert_one(document).inserted_id

    def get(self):
        """Return a cursor over every record in the collection."""
        return self.collection.find()
# Turn limits for the tasks whose stored code depends on how the game went:
# taskid -> maximum len(history) for which the task's own code is returned.
_CODE_TURN_THRESHOLDS = {1001: 6, 1002: 2, 1003: 4, 1004: 2}
# The alternate code for a thresholded task is stored under this value minus
# the task id (e.g. 1001 -> 1999).
_CODE_ALT_TASKID_BASE = 3000


def resolve_code_taskid(taskid, history, top_pred):
    """Return the ``taskid`` key whose code ``get_code`` should fetch.

    Task ids listed in ``_CODE_TURN_THRESHOLDS`` return their own id only
    when ``len(history)`` is within the task's threshold and
    ``top_pred == 0``. Otherwise they return ``3000 - taskid``. Every other
    task id maps to itself. (Presumably this separates a "solved within the
    turn budget" code from a fallback one — confirm with the game logic.)

    Args:
        taskid: Task identifier; anything accepted by ``int()``.
        history: Sized collection of previous turns; only its length is used.
        top_pred: Compared with ``0`` using ``==``.

    Raises:
        ValueError / TypeError: if *taskid* cannot be converted with ``int()``.
    """
    taskid = int(taskid)
    threshold = _CODE_TURN_THRESHOLDS.get(taskid)
    if threshold is None:
        return taskid
    if len(history) <= threshold and top_pred == 0:
        return taskid
    return _CODE_ALT_TASKID_BASE - taskid


def get_code(taskid, history, top_pred):
    """Fetch the ``code`` field for a task from the ``vqa-codes`` collection.

    Which document is read is decided by ``resolve_code_taskid``.
    Credentials come from the environment variables ``mongodb_username_2``,
    ``mongodb_pw_2`` and ``mongodb_cluster_url_2``. The client is closed
    before returning, so repeated calls do not leak connection pools.

    Raises:
        KeyError: if an environment variable is unset, or if the matching
            document has no ``code`` field.
        IndexError: if no document matches. IndexError is kept because the
            previous ``list(...)[0]`` implementation raised it on a miss.
    """
    lookup_id = resolve_code_taskid(taskid, history, top_pred)
    username = quote_plus(os.environ['mongodb_username_2'])
    password = quote_plus(os.environ['mongodb_pw_2'])
    cluster_url = os.environ['mongodb_cluster_url_2']
    uri = (
        f"mongodb+srv://{username}:{password}@{cluster_url}"
        "/?retryWrites=true&w=majority"
    )
    with MongoClient(uri) as client:
        document = client['vqa-codes']['vqa-codes'].find_one({"taskid": lookup_id})
    if document is None:
        raise IndexError(f"no code stored for taskid {lookup_id}")
    return document['code']