kshtiiz commited on
Commit
503acc3
1 Parent(s): 5100edd

main updated

Browse files
Files changed (1) hide show
  1. main.py +154 -145
main.py CHANGED
@@ -1,145 +1,154 @@
1
- import argparse
2
- import os
3
- import random
4
- import io
5
- from PIL import Image
6
- import numpy as np
7
- import torch
8
- import torch.backends.cudnn as cudnn
9
-
10
- from minigpt4.common.config import Config
11
- from minigpt4.common.dist_utils import get_rank
12
- from minigpt4.common.registry import registry
13
- from minigpt4.conversation.conversation import Chat, CONV_VISION
14
- from fastapi import FastAPI, HTTPException, File, UploadFile, Form
15
- from fastapi.responses import RedirectResponse
16
- from fastapi.middleware.cors import CORSMiddleware
17
- from pydantic import BaseModel
18
- from PIL import Image
19
- import io
20
- import uvicorn
21
- # imports modules for registration
22
- from minigpt4.datasets.builders import *
23
- from minigpt4.models import *
24
- from minigpt4.processors import *
25
- from minigpt4.runners import *
26
- from minigpt4.tasks import *
27
-
28
-
29
- def parse_args():
30
- parser = argparse.ArgumentParser(description="Demo")
31
- parser.add_argument("--cfg-path", type=str, default='eval_configs/minigpt4_eval.yaml',
32
- help="path to configuration file.")
33
- parser.add_argument(
34
- "--options",
35
- nargs="+",
36
- help="override some settings in the used config, the key-value pair "
37
- "in xxx=yyy format will be merged into config file (deprecate), "
38
- "change to --cfg-options instead.",
39
- )
40
- args = parser.parse_args()
41
- return args
42
-
43
-
44
- def setup_seeds(config):
45
- seed = config.run_cfg.seed + get_rank()
46
-
47
- random.seed(seed)
48
- np.random.seed(seed)
49
- torch.manual_seed(seed)
50
-
51
- cudnn.benchmark = False
52
- cudnn.deterministic = True
53
-
54
-
55
- # ========================================
56
- # Model Initialization
57
- # ========================================
58
-
59
- SHARED_UI_WARNING = f'''### [NOTE] It is possible that you are waiting in a lengthy queue.
60
- You can duplicate and use it with a paid private GPU.
61
- <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/Vision-CAIR/minigpt4?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
62
- Alternatively, you can also use the demo on our [project page](https://minigpt-4.github.io).
63
- '''
64
-
65
- print('Initializing Chat')
66
- cfg = Config(parse_args())
67
-
68
- model_config = cfg.model_cfg
69
- model_cls = registry.get_model_class(model_config.arch)
70
- model = model_cls.from_config(model_config).to('cuda:0')
71
-
72
- vis_processor_cfg = cfg.datasets_cfg.cc_align.vis_processor.train
73
- vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
74
- chat = Chat(model, vis_processor)
75
- print('Initialization Finished')
76
-
77
- # ========================================
78
- # Gradio Setting
79
- # ========================================
80
-
81
- app = FastAPI()
82
- app.add_middleware(
83
- CORSMiddleware,
84
- allow_origins=["*"], # Replace "*" with your frontend domain
85
- allow_credentials=True,
86
- allow_methods=["GET", "POST"],
87
- allow_headers=["*"],
88
- )
89
-
90
-
91
- class Item(BaseModel):
92
- gr_img: UploadFile = File(..., description="Image file")
93
- text_input: str = None
94
-
95
-
96
- chat_state = CONV_VISION.copy()
97
- img_list = []
98
- chatbot = []
99
-
100
-
101
- @app.get("/")
102
- async def root():
103
- return RedirectResponse(url="/docs")
104
-
105
-
106
- @app.post("/upload_img/")
107
- async def upload_img(
108
- file: UploadFile = File(...),
109
- ):
110
- pil_image = Image.open(io.BytesIO(await file.read()))
111
- chat.upload_img(pil_image, chat_state, img_list)
112
- return {"message": "image uploaded successfully."}
113
-
114
-
115
- @app.post("/process/")
116
- async def process_item(prompt: str = Form(...)):
117
- if not img_list: # Check if img_list is empty or None
118
- return {"error": "No images uploaded."}
119
-
120
- global chatbot
121
- chat.ask(prompt, chat_state)
122
- chatbot = chatbot + [[prompt, None]]
123
- llm_message = \
124
- chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=1, max_length=2000)[0]
125
- chatbot[-1][1] = llm_message
126
- return chatbot
127
-
128
-
129
- @app.post("/reset/")
130
- async def reset(
131
-
132
- ):
133
- global chat_state, img_list, chatbot # Use global keyword to reassign
134
- img_list = []
135
- if chat_state is not None:
136
- chat_state.messages = []
137
- if img_list is not None:
138
- img_list = []
139
- if chatbot is not None:
140
- chatbot = []
141
-
142
-
143
- if __name__ == "__main__":
144
- # Run the FastAPI app with Uvicorn
145
- uvicorn.run("main:app", host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import random
4
+ import io
5
+ from PIL import Image
6
+ import numpy as np
7
+ import torch
8
+ import torch.backends.cudnn as cudnn
9
+ from typing import List
10
+ from minigpt4.common.config import Config
11
+ from minigpt4.common.dist_utils import get_rank
12
+ from minigpt4.common.registry import registry
13
+ from minigpt4.conversation.conversation import Chat, CONV_VISION
14
+ from fastapi import FastAPI, HTTPException, File, UploadFile, Form
15
+ from fastapi.responses import RedirectResponse
16
+ from fastapi.middleware.cors import CORSMiddleware
17
+ from pydantic import BaseModel
18
+ from PIL import Image
19
+ import io
20
+ import uvicorn
21
+ # imports modules for registration
22
+ from minigpt4.datasets.builders import *
23
+ from minigpt4.models import *
24
+ from minigpt4.processors import *
25
+ from minigpt4.runners import *
26
+ from minigpt4.tasks import *
27
+
28
+
29
+ def parse_args():
30
+ parser = argparse.ArgumentParser(description="Demo")
31
+ parser.add_argument("--cfg-path", type=str, default='eval_configs/minigpt4_eval.yaml',
32
+ help="path to configuration file.")
33
+ parser.add_argument(
34
+ "--options",
35
+ nargs="+",
36
+ help="override some settings in the used config, the key-value pair "
37
+ "in xxx=yyy format will be merged into config file (deprecate), "
38
+ "change to --cfg-options instead.",
39
+ )
40
+ args = parser.parse_args()
41
+ return args
42
+
43
+
44
+ def setup_seeds(config):
45
+ seed = config.run_cfg.seed + get_rank()
46
+
47
+ random.seed(seed)
48
+ np.random.seed(seed)
49
+ torch.manual_seed(seed)
50
+
51
+ cudnn.benchmark = False
52
+ cudnn.deterministic = True
53
+
54
+
55
+ # ========================================
56
+ # Model Initialization
57
+ # ========================================
58
+
59
+ SHARED_UI_WARNING = f'''### [NOTE] It is possible that you are waiting in a lengthy queue.
60
+ You can duplicate and use it with a paid private GPU.
61
+ <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/Vision-CAIR/minigpt4?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
62
+ Alternatively, you can also use the demo on our [project page](https://minigpt-4.github.io).
63
+ '''
64
+
65
+ print('Initializing Chat')
66
+ cfg = Config(parse_args())
67
+
68
+ model_config = cfg.model_cfg
69
+ model_cls = registry.get_model_class(model_config.arch)
70
+ model = model_cls.from_config(model_config).to('cuda:0')
71
+
72
+ vis_processor_cfg = cfg.datasets_cfg.cc_align.vis_processor.train
73
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
74
+ chat = Chat(model, vis_processor)
75
+ print('Initialization Finished')
76
+
77
+ # ========================================
78
+ # Gradio Setting
79
+ # ========================================
80
+
81
+ app = FastAPI()
82
+ app.add_middleware(
83
+ CORSMiddleware,
84
+ allow_origins=["*"], # Replace "*" with your frontend domain
85
+ allow_credentials=True,
86
+ allow_methods=["GET", "POST"],
87
+ allow_headers=["*"],
88
+ )
89
+
90
+
91
+ class Item(BaseModel):
92
+ gr_img: UploadFile = File(..., description="Image file")
93
+ text_input: str = None
94
+
95
+
96
+ chat_state = CONV_VISION.copy()
97
+ img_list = []
98
+ chatbot = []
99
+
100
+
101
+ @app.get("/")
102
+ async def root():
103
+ return RedirectResponse(url="/docs")
104
+
105
+
106
+ @app.post("/upload_img/")
107
+ async def upload_img(
108
+ file: UploadFile = File(...),
109
+ ):
110
+ pil_image = Image.open(io.BytesIO(await file.read()))
111
+ chat.upload_img(pil_image, chat_state, img_list)
112
+ return {"message": "image uploaded successfully."}
113
+
114
+
115
+
116
+ @app.post("/process/")
117
+ async def process_item(prompts: List[str] = Form(...)):
118
+ if not img_list: # Check if img_list is empty or None
119
+ raise HTTPException(status_code=400, detail="No images uploaded.")
120
+
121
+ global chatbot
122
+ responses = []
123
+
124
+ for prompt in prompts:
125
+ # Process each prompt individually
126
+ chat.ask(prompt, chat_state)
127
+ chatbot.append([prompt, None])
128
+ llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=1, max_length=2000)[0]
129
+ chatbot[-1][1] = llm_message
130
+ responses.append({
131
+ "prompt": prompt,
132
+ "response": llm_message
133
+ })
134
+
135
+ return responses
136
+
137
+
138
+ @app.post("/reset/")
139
+ async def reset(
140
+
141
+ ):
142
+ global chat_state, img_list, chatbot # Use global keyword to reassign
143
+ img_list = []
144
+ if chat_state is not None:
145
+ chat_state.messages = []
146
+ if img_list is not None:
147
+ img_list = []
148
+ if chatbot is not None:
149
+ chatbot = []
150
+
151
+
152
+ if __name__ == "__main__":
153
+ # Run the FastAPI app with Uvicorn
154
+ uvicorn.run("main:app", host="0.0.0.0", port=7860)