tungdop2's picture
init code
284cb2b
raw
history blame
1.51 kB
from fastapi import FastAPI
from pydantic import BaseModel, Extra
import argparse
from typing import Optional
import uvicorn
from model import ChallengePromptGenerator
class Prompt(BaseModel, extra=Extra.allow):
prompt: str
seed: Optional[int] = 0
max_length: Optional[int] = 77
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=10001)
parser.add_argument("--netuid", type=str, default=23)
parser.add_argument("--min_stake", type=int, default=100)
parser.add_argument(
"--chain_endpoint",
type=str,
default="finney",
)
parser.add_argument("--disable_secure", action="store_true", default=False)
args = parser.parse_args()
return args
class ChallengeImage:
def __init__(self):
self.challenge_prompt = ChallengePromptGenerator()
self.app = FastAPI(title="Challenge Prompt")
self.app.add_api_route("/", self.__call__, methods=["POST"])
async def __call__(
self,
data: Prompt,
):
data = dict(data)
prompt = data["prompt"]
if not prompt:
prompt = "an image of "
complete_prompt = self.challenge_prompt.infer_prompt(
[prompt], max_generation_length=77, sampling_topk=100
)[0].strip()
return complete_prompt
if __name__ == "__main__":
args = get_args()
print("Args: ", args)
app = ChallengeImage()
uvicorn.run(app.app, host="0.0.0.0", port=args.port)