Spaces:
Sleeping
Sleeping
main updated
Browse files
main.py
CHANGED
@@ -1,145 +1,154 @@
|
|
1 |
-
import argparse
|
2 |
-
import os
|
3 |
-
import random
|
4 |
-
import io
|
5 |
-
from PIL import Image
|
6 |
-
import numpy as np
|
7 |
-
import torch
|
8 |
-
import torch.backends.cudnn as cudnn
|
9 |
-
|
10 |
-
from minigpt4.common.config import Config
|
11 |
-
from minigpt4.common.dist_utils import get_rank
|
12 |
-
from minigpt4.common.registry import registry
|
13 |
-
from minigpt4.conversation.conversation import Chat, CONV_VISION
|
14 |
-
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
|
15 |
-
from fastapi.responses import RedirectResponse
|
16 |
-
from fastapi.middleware.cors import CORSMiddleware
|
17 |
-
from pydantic import BaseModel
|
18 |
-
from PIL import Image
|
19 |
-
import io
|
20 |
-
import uvicorn
|
21 |
-
# imports modules for registration
|
22 |
-
from minigpt4.datasets.builders import *
|
23 |
-
from minigpt4.models import *
|
24 |
-
from minigpt4.processors import *
|
25 |
-
from minigpt4.runners import *
|
26 |
-
from minigpt4.tasks import *
|
27 |
-
|
28 |
-
|
29 |
-
def parse_args():
|
30 |
-
parser = argparse.ArgumentParser(description="Demo")
|
31 |
-
parser.add_argument("--cfg-path", type=str, default='eval_configs/minigpt4_eval.yaml',
|
32 |
-
help="path to configuration file.")
|
33 |
-
parser.add_argument(
|
34 |
-
"--options",
|
35 |
-
nargs="+",
|
36 |
-
help="override some settings in the used config, the key-value pair "
|
37 |
-
"in xxx=yyy format will be merged into config file (deprecate), "
|
38 |
-
"change to --cfg-options instead.",
|
39 |
-
)
|
40 |
-
args = parser.parse_args()
|
41 |
-
return args
|
42 |
-
|
43 |
-
|
44 |
-
def setup_seeds(config):
|
45 |
-
seed = config.run_cfg.seed + get_rank()
|
46 |
-
|
47 |
-
random.seed(seed)
|
48 |
-
np.random.seed(seed)
|
49 |
-
torch.manual_seed(seed)
|
50 |
-
|
51 |
-
cudnn.benchmark = False
|
52 |
-
cudnn.deterministic = True
|
53 |
-
|
54 |
-
|
55 |
-
# ========================================
|
56 |
-
# Model Initialization
|
57 |
-
# ========================================
|
58 |
-
|
59 |
-
SHARED_UI_WARNING = f'''### [NOTE] It is possible that you are waiting in a lengthy queue.
|
60 |
-
You can duplicate and use it with a paid private GPU.
|
61 |
-
<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/Vision-CAIR/minigpt4?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
|
62 |
-
Alternatively, you can also use the demo on our [project page](https://minigpt-4.github.io).
|
63 |
-
'''
|
64 |
-
|
65 |
-
print('Initializing Chat')
|
66 |
-
cfg = Config(parse_args())
|
67 |
-
|
68 |
-
model_config = cfg.model_cfg
|
69 |
-
model_cls = registry.get_model_class(model_config.arch)
|
70 |
-
model = model_cls.from_config(model_config).to('cuda:0')
|
71 |
-
|
72 |
-
vis_processor_cfg = cfg.datasets_cfg.cc_align.vis_processor.train
|
73 |
-
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
|
74 |
-
chat = Chat(model, vis_processor)
|
75 |
-
print('Initialization Finished')
|
76 |
-
|
77 |
-
# ========================================
|
78 |
-
# Gradio Setting
|
79 |
-
# ========================================
|
80 |
-
|
81 |
-
app = FastAPI()
|
82 |
-
app.add_middleware(
|
83 |
-
CORSMiddleware,
|
84 |
-
allow_origins=["*"], # Replace "*" with your frontend domain
|
85 |
-
allow_credentials=True,
|
86 |
-
allow_methods=["GET", "POST"],
|
87 |
-
allow_headers=["*"],
|
88 |
-
)
|
89 |
-
|
90 |
-
|
91 |
-
class Item(BaseModel):
|
92 |
-
gr_img: UploadFile = File(..., description="Image file")
|
93 |
-
text_input: str = None
|
94 |
-
|
95 |
-
|
96 |
-
chat_state = CONV_VISION.copy()
|
97 |
-
img_list = []
|
98 |
-
chatbot = []
|
99 |
-
|
100 |
-
|
101 |
-
@app.get("/")
|
102 |
-
async def root():
|
103 |
-
return RedirectResponse(url="/docs")
|
104 |
-
|
105 |
-
|
106 |
-
@app.post("/upload_img/")
|
107 |
-
async def upload_img(
|
108 |
-
file: UploadFile = File(...),
|
109 |
-
):
|
110 |
-
pil_image = Image.open(io.BytesIO(await file.read()))
|
111 |
-
chat.upload_img(pil_image, chat_state, img_list)
|
112 |
-
return {"message": "image uploaded successfully."}
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import io
|
5 |
+
from PIL import Image
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.backends.cudnn as cudnn
|
9 |
+
from typing import List
|
10 |
+
from minigpt4.common.config import Config
|
11 |
+
from minigpt4.common.dist_utils import get_rank
|
12 |
+
from minigpt4.common.registry import registry
|
13 |
+
from minigpt4.conversation.conversation import Chat, CONV_VISION
|
14 |
+
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
|
15 |
+
from fastapi.responses import RedirectResponse
|
16 |
+
from fastapi.middleware.cors import CORSMiddleware
|
17 |
+
from pydantic import BaseModel
|
18 |
+
from PIL import Image
|
19 |
+
import io
|
20 |
+
import uvicorn
|
21 |
+
# imports modules for registration
|
22 |
+
from minigpt4.datasets.builders import *
|
23 |
+
from minigpt4.models import *
|
24 |
+
from minigpt4.processors import *
|
25 |
+
from minigpt4.runners import *
|
26 |
+
from minigpt4.tasks import *
|
27 |
+
|
28 |
+
|
29 |
+
def parse_args():
|
30 |
+
parser = argparse.ArgumentParser(description="Demo")
|
31 |
+
parser.add_argument("--cfg-path", type=str, default='eval_configs/minigpt4_eval.yaml',
|
32 |
+
help="path to configuration file.")
|
33 |
+
parser.add_argument(
|
34 |
+
"--options",
|
35 |
+
nargs="+",
|
36 |
+
help="override some settings in the used config, the key-value pair "
|
37 |
+
"in xxx=yyy format will be merged into config file (deprecate), "
|
38 |
+
"change to --cfg-options instead.",
|
39 |
+
)
|
40 |
+
args = parser.parse_args()
|
41 |
+
return args
|
42 |
+
|
43 |
+
|
44 |
+
def setup_seeds(config):
|
45 |
+
seed = config.run_cfg.seed + get_rank()
|
46 |
+
|
47 |
+
random.seed(seed)
|
48 |
+
np.random.seed(seed)
|
49 |
+
torch.manual_seed(seed)
|
50 |
+
|
51 |
+
cudnn.benchmark = False
|
52 |
+
cudnn.deterministic = True
|
53 |
+
|
54 |
+
|
55 |
+
# ========================================
|
56 |
+
# Model Initialization
|
57 |
+
# ========================================
|
58 |
+
|
59 |
+
SHARED_UI_WARNING = f'''### [NOTE] It is possible that you are waiting in a lengthy queue.
|
60 |
+
You can duplicate and use it with a paid private GPU.
|
61 |
+
<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/Vision-CAIR/minigpt4?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
|
62 |
+
Alternatively, you can also use the demo on our [project page](https://minigpt-4.github.io).
|
63 |
+
'''
|
64 |
+
|
65 |
+
print('Initializing Chat')
|
66 |
+
cfg = Config(parse_args())
|
67 |
+
|
68 |
+
model_config = cfg.model_cfg
|
69 |
+
model_cls = registry.get_model_class(model_config.arch)
|
70 |
+
model = model_cls.from_config(model_config).to('cuda:0')
|
71 |
+
|
72 |
+
vis_processor_cfg = cfg.datasets_cfg.cc_align.vis_processor.train
|
73 |
+
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
|
74 |
+
chat = Chat(model, vis_processor)
|
75 |
+
print('Initialization Finished')
|
76 |
+
|
77 |
+
# ========================================
|
78 |
+
# Gradio Setting
|
79 |
+
# ========================================
|
80 |
+
|
81 |
+
app = FastAPI()
|
82 |
+
app.add_middleware(
|
83 |
+
CORSMiddleware,
|
84 |
+
allow_origins=["*"], # Replace "*" with your frontend domain
|
85 |
+
allow_credentials=True,
|
86 |
+
allow_methods=["GET", "POST"],
|
87 |
+
allow_headers=["*"],
|
88 |
+
)
|
89 |
+
|
90 |
+
|
91 |
+
class Item(BaseModel):
|
92 |
+
gr_img: UploadFile = File(..., description="Image file")
|
93 |
+
text_input: str = None
|
94 |
+
|
95 |
+
|
96 |
+
chat_state = CONV_VISION.copy()
|
97 |
+
img_list = []
|
98 |
+
chatbot = []
|
99 |
+
|
100 |
+
|
101 |
+
@app.get("/")
|
102 |
+
async def root():
|
103 |
+
return RedirectResponse(url="/docs")
|
104 |
+
|
105 |
+
|
106 |
+
@app.post("/upload_img/")
|
107 |
+
async def upload_img(
|
108 |
+
file: UploadFile = File(...),
|
109 |
+
):
|
110 |
+
pil_image = Image.open(io.BytesIO(await file.read()))
|
111 |
+
chat.upload_img(pil_image, chat_state, img_list)
|
112 |
+
return {"message": "image uploaded successfully."}
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
@app.post("/process/")
|
117 |
+
async def process_item(prompts: List[str] = Form(...)):
|
118 |
+
if not img_list: # Check if img_list is empty or None
|
119 |
+
raise HTTPException(status_code=400, detail="No images uploaded.")
|
120 |
+
|
121 |
+
global chatbot
|
122 |
+
responses = []
|
123 |
+
|
124 |
+
for prompt in prompts:
|
125 |
+
# Process each prompt individually
|
126 |
+
chat.ask(prompt, chat_state)
|
127 |
+
chatbot.append([prompt, None])
|
128 |
+
llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=1, max_length=2000)[0]
|
129 |
+
chatbot[-1][1] = llm_message
|
130 |
+
responses.append({
|
131 |
+
"prompt": prompt,
|
132 |
+
"response": llm_message
|
133 |
+
})
|
134 |
+
|
135 |
+
return responses
|
136 |
+
|
137 |
+
|
138 |
+
@app.post("/reset/")
|
139 |
+
async def reset(
|
140 |
+
|
141 |
+
):
|
142 |
+
global chat_state, img_list, chatbot # Use global keyword to reassign
|
143 |
+
img_list = []
|
144 |
+
if chat_state is not None:
|
145 |
+
chat_state.messages = []
|
146 |
+
if img_list is not None:
|
147 |
+
img_list = []
|
148 |
+
if chatbot is not None:
|
149 |
+
chatbot = []
|
150 |
+
|
151 |
+
|
152 |
+
if __name__ == "__main__":
|
153 |
+
# Run the FastAPI app with Uvicorn
|
154 |
+
uvicorn.run("main:app", host="0.0.0.0", port=7860)
|