Eric Michael Martinez commited on
Commit
6e8003f
1 Parent(s): 0c2b2b7
Files changed (6) hide show
  1. app/app.py +234 -0
  2. app/db.py +49 -0
  3. app/schemas.py +15 -0
  4. app/users.py +68 -0
  5. main.py +4 -0
  6. requirements.txt +9 -0
app/app.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import httpx
2
+ import os
3
+ import requests
4
+ import gradio as gr
5
+ import openai
6
+
7
+ from fastapi import Depends, FastAPI, Request
8
+ from app.db import User, create_db_and_tables
9
+ from app.schemas import UserCreate, UserRead, UserUpdate
10
+ from app.users import auth_backend, current_active_user, fastapi_users
11
+ from dotenv import load_dotenv
12
+ import examples as chatbot_examples
13
+
14
+ # Get the current environment from the environment variable
15
+ current_environment = os.getenv("APP_ENV", "dev")
16
+
17
+ # Load the appropriate .env file based on the current environment
18
+ if current_environment == "dev":
19
+ load_dotenv(".env.dev")
20
+ elif current_environment == "test":
21
+ load_dotenv(".env.test")
22
+ elif current_environment == "prod":
23
+ load_dotenv(".env.prod")
24
+ else:
25
+ raise ValueError("Invalid environment specified")
26
+
27
+
28
+ def api_login(email, password):
29
+ port = os.getenv("APP_PORT")
30
+ scheme = os.getenv("APP_SCHEME")
31
+ host = os.getenv("APP_HOST")
32
+
33
+ url = f"{scheme}://{host}:{port}/auth/jwt/login"
34
+ payload = {
35
+ 'username': email,
36
+ 'password': password
37
+ }
38
+ headers = {
39
+ 'Content-Type': 'application/x-www-form-urlencoded'
40
+ }
41
+
42
+ response = requests.post(
43
+ url,
44
+ data=payload,
45
+ headers=headers
46
+ )
47
+
48
+ if(response.status_code==200):
49
+ response_json = response.json()
50
+ api_key = response_json['access_token']
51
+ return True, api_key
52
+ else:
53
+ response_json = response.json()
54
+ detail = response_json['detail']
55
+ return False, detail
56
+
57
+
58
+ def get_api_key(email, password):
59
+ successful, message = api_login(email, password)
60
+
61
+ if(successful):
62
+ return os.getenv("APP_API_BASE"), message
63
+ else:
64
+ raise gr.Error(message)
65
+ return "", ""
66
+
67
+ # Define a function to get the AI's reply using the OpenAI API
68
+ def get_ai_reply(message, model="gpt-3.5-turbo", system_message=None, temperature=0, message_history=[]):
69
+ # Initialize the messages list
70
+ messages = []
71
+
72
+ # Add the system message to the messages list
73
+ if system_message is not None:
74
+ messages += [{"role": "system", "content": system_message}]
75
+
76
+ # Add the message history to the messages list
77
+ if message_history is not None:
78
+ messages += message_history
79
+
80
+ # Add the user's message to the messages list
81
+ messages += [{"role": "user", "content": message}]
82
+
83
+ # Make an API call to the OpenAI ChatCompletion endpoint with the model and messages
84
+ completion = openai.ChatCompletion.create(
85
+ model=model,
86
+ messages=messages,
87
+ temperature=temperature
88
+ )
89
+
90
+ # Extract and return the AI's response from the API response
91
+ return completion.choices[0].message.content.strip()
92
+
93
+ # Define a function to handle the chat interaction with the AI model
94
+ def chat(model, system_message, message, chatbot_messages, history_state):
95
+ # Initialize chatbot_messages and history_state if they are not provided
96
+ chatbot_messages = chatbot_messages or []
97
+ history_state = history_state or []
98
+
99
+ # Try to get the AI's reply using the get_ai_reply function
100
+ try:
101
+ ai_reply = get_ai_reply(message, model=model, system_message=system_message, message_history=history_state)
102
+ except Exception as e:
103
+ # If an error occurs, raise a Gradio error
104
+ raise gr.Error(e)
105
+
106
+ # Append the user's message and the AI's reply to the chatbot_messages list
107
+ chatbot_messages.append((message, ai_reply))
108
+
109
+ # Append the user's message and the AI's reply to the history_state list
110
+ history_state.append({"role": "user", "content": message})
111
+ history_state.append({"role": "assistant", "content": ai_reply})
112
+
113
+ # Return None (empty out the user's message textbox), the updated chatbot_messages, and the updated history_state
114
+ return None, chatbot_messages, history_state
115
+
116
+ # Define a function to launch the chatbot interface using Gradio
117
+ def get_chatbot_app(additional_examples=[]):
118
+ # Load chatbot examples and merge with any additional examples provided
119
+ examples = chatbot_examples.load_examples(additional=additional_examples)
120
+
121
+ # Define a function to get the names of the examples
122
+ def get_examples():
123
+ return [example["name"] for example in examples]
124
+
125
+ # Define a function to choose an example based on the index
126
+ def choose_example(index):
127
+ if(index!=None):
128
+ system_message = examples[index]["system_message"].strip()
129
+ user_message = examples[index]["message"].strip()
130
+ return system_message, user_message, [], []
131
+ else:
132
+ return "", "", [], []
133
+
134
+ # Create the Gradio interface using the Blocks layout
135
+ with gr.Blocks() as app:
136
+ with gr.Tab("Conversation"):
137
+ with gr.Row():
138
+ with gr.Column():
139
+ # Create a dropdown to select examples
140
+ example_dropdown = gr.Dropdown(get_examples(), label="Examples", type="index")
141
+ # Create a button to load the selected example
142
+ example_load_btn = gr.Button(value="Load")
143
+ # Create a textbox for the system message (prompt)
144
+ system_message = gr.TextArea(label="System Message (Prompt)", value="You are a helpful assistant.", lines=20, max_lines=400)
145
+ with gr.Column():
146
+ # Create a dropdown to select the AI model
147
+ model_selector = gr.Dropdown(
148
+ ["gpt-3.5-turbo"],
149
+ label="Model",
150
+ value="gpt-3.5-turbo"
151
+ )
152
+ # Create a chatbot interface for the conversation
153
+ chatbot = gr.Chatbot(label="Conversation")
154
+ # Create a textbox for the user's message
155
+ message = gr.Textbox(label="Message")
156
+ # Create a state object to store the conversation history
157
+ history_state = gr.State()
158
+ # Create a button to send the user's message
159
+ btn = gr.Button(value="Send")
160
+
161
+ # Connect the example load button to the choose_example function
162
+ example_load_btn.click(choose_example, inputs=[example_dropdown], outputs=[system_message, message, chatbot, history_state])
163
+ # Connect the send button to the chat function
164
+ btn.click(chat, inputs=[model_selector, system_message, message, chatbot, history_state], outputs=[message, chatbot, history_state])
165
+ with gr.Tab("Get API Key"):
166
+ email_box = gr.Textbox(label="Email Address", placeholder="Student Email")
167
+ password_box = gr.Textbox(label="Password", type="password", placeholder="Student ID")
168
+ btn = gr.Button(value ="Generate")
169
+ api_host_box = gr.Textbox(label="OpenAI API Base", interactive=False)
170
+ api_key_box = gr.Textbox(label="OpenAI API Key", interactive=False)
171
+ btn.click(get_api_key, inputs = [email_box, password_box], outputs = [api_host_box, api_key_box])
172
+ # Return the app
173
+ return app
174
+
175
+ app = FastAPI()
176
+
177
+ app.include_router(
178
+ fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"]
179
+ )
180
+ app.include_router(
181
+ fastapi_users.get_register_router(UserRead, UserCreate),
182
+ prefix="/auth",
183
+ tags=["auth"],
184
+ )
185
+ app.include_router(
186
+ fastapi_users.get_users_router(UserRead, UserUpdate),
187
+ prefix="/users",
188
+ tags=["users"],
189
+ )
190
+
191
+ @app.get("/authenticated-route")
192
+ async def authenticated_route(user: User = Depends(current_active_user)):
193
+ return {"message": f"Hello {user.email}!"}
194
+
195
+ @app.post("/v1/chat/completions")
196
+ async def openai_api_chat_completions_passthrough(
197
+ request: Request,
198
+ user: User = Depends(fastapi_users.current_user()),
199
+ ):
200
+ if not user:
201
+ raise HTTPException(status_code=401, detail="Unauthorized")
202
+
203
+ # Get the request data and headers
204
+ request_data = await request.json()
205
+ request_headers = request.headers
206
+ openai_api_key = os.getenv("OPENAI_API_KEY")
207
+
208
+ if(request_data['model']=='gpt-4' or request_data['model'] == 'gpt-4-32k'):
209
+ print("User requested gpt-4, falling back to gpt-3.5-turbo")
210
+ request_data['model'] = 'gpt-3.5-turbo'
211
+
212
+ # Forward the request to the OpenAI API
213
+ response = requests.post(
214
+ "https://api.openai.com/v1/chat/completions",
215
+ json=request_data,
216
+ headers={
217
+ "Content-Type": request_headers.get("Content-Type"),
218
+ "Authorization": f"Bearer {openai_api_key}",
219
+ },
220
+ )
221
+ print(response)
222
+
223
+ # Return the OpenAI API response
224
+ return response.json()
225
+
226
+ @app.on_event("startup")
227
+ async def on_startup():
228
+ # Not needed if you setup a migration system like Alembic
229
+ await create_db_and_tables()
230
+
231
+ gradio_gui = get_chatbot_app()
232
+ gradio_gui.auth = api_login
233
+ gradio_gui.auth_message = "Hello"
234
+ app = gr.mount_gradio_app(app, gradio_gui, path="/gradio")
app/db.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import AsyncGenerator
2
+
3
+ from fastapi import Depends
4
+ from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
5
+ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
6
+ from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
7
+ from sqlalchemy.orm import sessionmaker
8
+ from dotenv import load_dotenv
9
+ import os
10
+
11
+ # Get the current environment from the environment variable
12
+ current_environment = os.getenv("APP_ENV", "dev")
13
+
14
+ # Load the appropriate .env file based on the current environment
15
+ if current_environment == "dev":
16
+ load_dotenv(".env.dev")
17
+ elif current_environment == "test":
18
+ load_dotenv(".env.test")
19
+ elif current_environment == "prod":
20
+ load_dotenv(".env.prod")
21
+ else:
22
+ raise ValueError("Invalid environment specified")
23
+
24
+ db_connection_string = os.getenv("DB_CONNECTION_STRING")
25
+
26
+ DATABASE_URL = db_connection_string
27
+ Base: DeclarativeMeta = declarative_base()
28
+
29
+
30
+ class User(SQLAlchemyBaseUserTableUUID, Base):
31
+ pass
32
+
33
+
34
+ engine = create_async_engine(DATABASE_URL)
35
+ async_session_maker = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
36
+
37
+
38
+ async def create_db_and_tables():
39
+ async with engine.begin() as conn:
40
+ await conn.run_sync(Base.metadata.create_all)
41
+
42
+
43
+ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
44
+ async with async_session_maker() as session:
45
+ yield session
46
+
47
+
48
+ async def get_user_db(session: AsyncSession = Depends(get_async_session)):
49
+ yield SQLAlchemyUserDatabase(session, User)
app/schemas.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+
3
+ from fastapi_users import schemas
4
+
5
+
6
+ class UserRead(schemas.BaseUser[uuid.UUID]):
7
+ pass
8
+
9
+
10
+ class UserCreate(schemas.BaseUserCreate):
11
+ pass
12
+
13
+
14
+ class UserUpdate(schemas.BaseUserUpdate):
15
+ pass
app/users.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ import os
3
+ from typing import Optional
4
+ from fastapi import Depends, Request
5
+ from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin
6
+ from fastapi_users.authentication import (
7
+ AuthenticationBackend,
8
+ BearerTransport,
9
+ JWTStrategy,
10
+ )
11
+ from fastapi_users.db import SQLAlchemyUserDatabase
12
+ from app.db import User, get_user_db
13
+ from dotenv import load_dotenv
14
+
15
+ # Get the current environment from the environment variable
16
+ current_environment = os.getenv("APP_ENV", "dev")
17
+
18
+ # Load the appropriate .env file based on the current environment
19
+ if current_environment == "dev":
20
+ load_dotenv(".env.dev")
21
+ elif current_environment == "test":
22
+ load_dotenv(".env.test")
23
+ elif current_environment == "prod":
24
+ load_dotenv(".env.prod")
25
+ else:
26
+ raise ValueError("Invalid environment specified")
27
+
28
+ SECRET = os.getenv("APP_SECRET")
29
+
30
+
31
+ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
32
+ reset_password_token_secret = SECRET
33
+ verification_token_secret = SECRET
34
+
35
+ async def on_after_register(self, user: User, request: Optional[Request] = None):
36
+ print(f"User {user.id} has registered.")
37
+
38
+ async def on_after_forgot_password(
39
+ self, user: User, token: str, request: Optional[Request] = None
40
+ ):
41
+ print(f"User {user.id} has forgot their password. Reset token: {token}")
42
+
43
+ async def on_after_request_verify(
44
+ self, user: User, token: str, request: Optional[Request] = None
45
+ ):
46
+ print(f"Verification requested for user {user.id}. Verification token: {token}")
47
+
48
+
49
+ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
50
+ yield UserManager(user_db)
51
+
52
+
53
+ bearer_transport = BearerTransport(tokenUrl="auth/jwt/login")
54
+
55
+
56
+ def get_jwt_strategy() -> JWTStrategy:
57
+ return JWTStrategy(secret=SECRET, lifetime_seconds=3600)
58
+
59
+
60
+ auth_backend = AuthenticationBackend(
61
+ name="jwt",
62
+ transport=bearer_transport,
63
+ get_strategy=get_jwt_strategy,
64
+ )
65
+
66
+ fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
67
+
68
+ current_active_user = fastapi_users.current_user(active=True)
main.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import uvicorn
2
+
3
+ if __name__ == "__main__":
4
+ uvicorn.run(f"app.app:app", host="0.0.0.0", port=8000, log_level="info")
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.95.1
2
+ fastapi_users==10.4.2
3
+ gradio==3.27.0
4
+ httpx==0.24.0
5
+ openai==0.27.4
6
+ python-dotenv==1.0.0
7
+ Requests==2.28.2
8
+ SQLAlchemy==1.4.47
9
+ uvicorn==0.21.1