kamau1 committed on
Commit
bb15705
1 Parent(s): 156cab1

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +119 -51
main.py CHANGED
@@ -10,19 +10,23 @@ from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.responses import HTMLResponse
11
  import uvicorn
12
 
 
13
  from pydantic import BaseModel
14
  from pymongo import MongoClient
15
  import jwt
16
  from jwt import encode as jwt_encode
17
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
18
  from bson import ObjectId
19
 
 
20
  import ctranslate2
21
  import sentencepiece as spm
22
  import fasttext
 
23
 
24
- import pytz
25
  from datetime import datetime
 
 
 
26
  import os
27
 
28
  app = FastAPI()
@@ -47,9 +51,9 @@ templates_folder = os.path.join(os.path.dirname(__file__), "templates")
47
 
48
  # Authentication
49
  class User(BaseModel):
50
- username: str = None # Make the username field optional
51
- email: str
52
- password: str
53
 
54
  # Connect to the MongoDB database
55
  client = MongoClient("mongodb://localhost:27017")
@@ -64,63 +68,63 @@ security = HTTPBearer()
64
 
65
  @app.post("/login")
66
  def login(user: User):
67
- # Check if user exists in the database
68
- user_data = users_collection.find_one(
69
- {"email": user.email, "password": user.password}
70
- )
71
- if user_data:
72
- # Generate a token
73
- token = generate_token(user.email)
74
- # Convert ObjectId to string
75
- user_data["_id"] = str(user_data["_id"])
76
- # Store user details and token in local storage
77
- user_data["token"] = token
78
- return user_data
79
- return {"message": "Invalid email or password"}
80
 
81
  #Implement the registration route:
82
  @app.post("/register")
83
  def register(user: User):
84
- # Check if user already exists in the database
85
- existing_user = users_collection.find_one({"email": user.email})
86
- if existing_user:
87
  return {"message": "User already exists"}
88
- #Insert the new user into the database
89
- user_dict = user.dict()
90
- users_collection.insert_one(user_dict)
91
- # Generate a token
92
- token = generate_token(user.email)
93
- # Convert ObjectId to string
94
- user_dict["_id"] = str(user_dict["_id"])
95
- # Store user details and token in local storage
96
- user_dict["token"] = token
97
- return user_dict
98
 
99
 
100
  #Implement the `/api/user` route to fetch user data based on the JWT token
101
  @app.get("/api/user")
102
  def get_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
103
- # Extract the token from the Authorization header
104
- token = credentials.credentials
105
- # Authenticate and retrieve the user data from the database based on the token
106
- # Here, you would implement the authentication logic and fetch user details
107
- # based on the token from the database or any other authentication mechanism
108
- # For demonstration purposes, assuming the user data is stored in local storage
109
- # Note: Local storage is not accessible from server-side code
110
- # This is just a placeholder to demonstrate the concept
111
- user_data = {
112
- "username": "John Doe",
113
- "email": "johndoe@example.com"
114
- }
115
- if user_data["username"] and user_data["email"]:
116
- return user_data
117
- raise HTTPException(status_code=401, detail="Invalid token")
118
 
119
  #Define a helper function to generate a JWT token
120
  def generate_token(email: str) -> str:
121
- payload = {"email": email}
122
- token = jwt_encode(payload, SECRET_KEY, algorithm="HS256")
123
- return token
124
 
125
 
126
  # Get time of request
@@ -135,6 +139,29 @@ def get_time():
135
 
136
  full_date = f"{curr_day} | {curr_date} | {curr_time}"
137
  return full_date, curr_time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  # Load the model and tokenizer ..... only once!
140
  beam_size = 1 # change to a smaller value for faster inference
@@ -154,13 +181,17 @@ sp_model_full_path = os.path.join(os.path.dirname(__file__), sp_model_file)
154
  sp = spm.SentencePieceProcessor()
155
  sp.load(sp_model_full_path)
156
 
 
157
  # Import The Translator model
158
  print("\nimporting Translator model")
159
  ct_model_file = "sematrans-3.3B"
160
  ct_model_full_path = os.path.join(os.path.dirname(__file__), ct_model_file)
161
  translator = ctranslate2.Translator(ct_model_full_path, device)
 
 
 
162
 
163
- print('\nDone importing models\n')
164
 
165
 
166
  def translate_detect(userinput: str, target_lang: str):
@@ -213,6 +244,25 @@ def translate_enter(userinput: str, source_lang: str, target_lang: str):
213
  # Return the source language and the translated text
214
  return translations_desubword[0]
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
  @app.get("/", response_class=HTMLResponse)
218
  async def read_root(request: Request):
@@ -258,5 +308,23 @@ async def translate_enter_endpoint(request: Request):
258
  "translated_text": translated_text_e,
259
  }
260
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
- print("\nAPI starting .......\n")
 
 
 
 
 
 
 
 
10
  from fastapi.responses import HTMLResponse
11
  import uvicorn
12
 
13
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
14
  from pydantic import BaseModel
15
  from pymongo import MongoClient
16
  import jwt
17
  from jwt import encode as jwt_encode
 
18
  from bson import ObjectId
19
 
20
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
21
  import ctranslate2
22
  import sentencepiece as spm
23
  import fasttext
24
+ import torch
25
 
 
26
  from datetime import datetime
27
+ import gradio as gr
28
+ import pytz
29
+ import time
30
  import os
31
 
32
  app = FastAPI()
 
51
 
52
  # Authentication
class User(BaseModel):
    """Request body accepted by both the login and the registration routes."""

    # Optional: stays None when the client does not send a username.
    username: str = None
    # Required credentials.
    email: str
    password: str
57
 
58
  # Connect to the MongoDB database
59
  client = MongoClient("mongodb://localhost:27017")
 
68
 
@app.post("/login")
def login(user: User):
    """Authenticate a user by email and password and issue a JWT.

    Args:
        user: Credentials sent by the client; ``email`` and ``password``
            are matched against the users collection.

    Returns:
        dict: The stored user document with ``_id`` converted to a string
        and a freshly generated ``token`` added, or
        ``{"message": "Invalid email or password"}`` when no document
        matches.
    """
    # NOTE(review): the password is compared verbatim against the stored
    # value, which implies passwords are persisted unhashed - consider
    # hashing on registration and verifying the hash here.
    user_data = users_collection.find_one(
        {"email": user.email, "password": user.password}
    )
    if not user_data:
        return {"message": "Invalid email or password"}

    # ObjectId is not JSON serialisable, so expose it as a plain string.
    user_data["_id"] = str(user_data["_id"])
    # Never send the stored password back to the client.
    user_data.pop("password", None)
    user_data["token"] = generate_token(user.email)
    return user_data
84
 
85
  #Implement the registration route:
@app.post("/register")
def register(user: User):
    """Create a new user account and issue a JWT for it.

    Args:
        user: Registration data sent by the client.

    Returns:
        dict: The inserted user document (without the password) with
        ``_id`` converted to a string and a ``token`` added, or
        ``{"message": "User already exists"}`` when the email is taken.
    """
    # Emails identify accounts, so refuse a second registration.
    if users_collection.find_one({"email": user.email}):
        return {"message": "User already exists"}

    user_dict = user.dict()
    # insert_one adds the generated "_id" to user_dict in place.
    users_collection.insert_one(user_dict)

    # ObjectId is not JSON serialisable, so expose it as a plain string.
    user_dict["_id"] = str(user_dict["_id"])
    # The document is already persisted; do not echo the password back.
    user_dict.pop("password", None)
    user_dict["token"] = generate_token(user.email)
    return user_dict
102
 
103
 
104
  #Implement the `/api/user` route to fetch user data based on the JWT token
@app.get("/api/user")
def get_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the profile of the user identified by the bearer token.

    The token is expected to be one produced by ``generate_token``: an
    HS256 JWT signed with ``SECRET_KEY`` whose payload carries ``email``.

    Args:
        credentials: Bearer credentials extracted from the Authorization
            header by the ``security`` dependency.

    Returns:
        dict: ``{"username": ..., "email": ...}`` for the matching user.

    Raises:
        HTTPException: 401 when the token cannot be verified, carries no
            email, or refers to a user that does not exist.
    """
    token = credentials.credentials
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
    except jwt.InvalidTokenError:
        # Covers malformed tokens and bad signatures alike.
        raise HTTPException(status_code=401, detail="Invalid token")

    email = payload.get("email")
    stored_user = users_collection.find_one({"email": email}) if email else None
    if not stored_user:
        raise HTTPException(status_code=401, detail="Invalid token")

    # Expose only the public profile fields, never the stored password.
    return {
        "username": stored_user.get("username"),
        "email": stored_user.get("email"),
    }
122
 
123
  #Define a helper function to generate a JWT token
def generate_token(email: str) -> str:
    """Return an HS256 JWT, signed with SECRET_KEY, whose payload is the email."""
    return jwt_encode({"email": email}, SECRET_KEY, algorithm="HS256")
128
 
129
 
130
  # Get time of request
 
139
 
140
  full_date = f"{curr_day} | {curr_date} | {curr_time}"
141
  return full_date, curr_time
142
+
143
+
def load_models():
    """Load every configured translation checkpoint with its tokenizer.

    Returns:
        dict: For each configured short name, ``"<name>_model"`` maps to
        the loaded seq2seq model and ``"<name>_tokenizer"`` to its
        tokenizer.
    """
    # Short name -> checkpoint id passed to from_pretrained.
    # Other NLLB checkpoints that can be enabled here instead:
    #   'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M'
    #   'nllb-1.3B':           'facebook/nllb-200-1.3B'
    #   'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B'
    #   'nllb-3.3B':           'facebook/nllb-200-3.3B'
    checkpoints = {
        'nllb-moe-54b': 'facebook/nllb-moe-54b',
    }

    loaded = {}
    for short_name, checkpoint_id in checkpoints.items():
        print('\tLoading model: %s' % short_name)
        # Dict literals evaluate in order, so the model loads before the tokenizer.
        loaded.update({
            short_name + '_model': AutoModelForSeq2SeqLM.from_pretrained(checkpoint_id),
            short_name + '_tokenizer': AutoTokenizer.from_pretrained(checkpoint_id),
        })

    return loaded
164
+
165
 
166
  # Load the model and tokenizer ..... only once!
167
  beam_size = 1 # change to a smaller value for faster inference
 
181
  sp = spm.SentencePieceProcessor()
182
  sp.load(sp_model_full_path)
183
 
184
+ '''
185
  # Import The Translator model
186
  print("\nimporting Translator model")
187
  ct_model_file = "sematrans-3.3B"
188
  ct_model_full_path = os.path.join(os.path.dirname(__file__), ct_model_file)
189
  translator = ctranslate2.Translator(ct_model_full_path, device)
190
+ '''
191
+ print("\nimporting Translator model")
192
+ model_dict = load_models()
193
 
194
+ print('\nDone importing models\n')
195
 
196
 
197
  def translate_detect(userinput: str, target_lang: str):
 
244
  # Return the source language and the translated text
245
  return translations_desubword[0]
246
 
def translate_faster(userinput3: str, source_lang3: str, target_lang3: str):
    """Translate text with the model held in ``model_dict`` and time the call.

    Args:
        userinput3: Text to translate.
        source_lang3: Source language code, passed to the pipeline as
            ``src_lang``.
        target_lang3: Target language code, passed to the pipeline as
            ``tgt_lang``.

    Returns:
        dict: ``inference_time`` (seconds spent building and running the
        pipeline), ``source`` and ``target`` (the language codes given)
        and ``result`` (the translated text).

    Raises:
        KeyError: If the model or tokenizer is missing from ``model_dict``.
    """
    # Always bound: the original only set this when model_dict had exactly
    # two entries and otherwise failed with an unbound local.
    model_name = 'nllb-moe-54b'

    start_time = time.time()

    model = model_dict[model_name + '_model']
    tokenizer = model_dict[model_name + '_tokenizer']

    translator = pipeline(
        'translation',
        model=model,
        tokenizer=tokenizer,
        src_lang=source_lang3,
        tgt_lang=target_lang3,
    )
    output = translator(userinput3, max_length=400)
    end_time = time.time()

    # The pipeline returns one dict per input; a single string was passed.
    translated_text = output[0]['translation_text']
    return {
        'inference_time': end_time - start_time,
        # Previously referenced the undefined names `source` and `target`.
        'source': source_lang3,
        'target': target_lang3,
        'result': translated_text,
    }
266
 
267
  @app.get("/", response_class=HTMLResponse)
268
  async def read_root(request: Request):
 
308
  "translated_text": translated_text_e,
309
  }
310
 
@app.post("/translate_faster/")
async def translate_faster_endpoint(request: Request):
    """Translate text between two caller-specified languages.

    Expects a JSON body with ``userinput``, ``source_lang`` and
    ``target_lang``.

    Returns:
        dict: ``{"translated_text": ...}`` holding the value returned by
        ``translate_faster``.

    Raises:
        HTTPException: 422 when any of the three fields is missing or empty.
    """
    dataf = await request.json()
    # Read from this request's own body (the original read the undefined `datae`).
    userinputf = dataf.get("userinput")
    source_langf = dataf.get("source_lang")
    target_langf = dataf.get("target_lang")
    ffull_date = get_time()[0]
    print(f"\nrequest: {ffull_date}\nSource_language; {source_langf}, Target Language; {target_langf}, User Input: {userinputf}\n")

    # All three fields are required, as the error message states.
    if not userinputf or not source_langf or not target_langf:
        raise HTTPException(status_code=422, detail="'userinput', 'source_lang' and 'target_lang' are required.")

    translated_text_f = translate_faster(userinputf, source_langf, target_langf)
    fcurrent_time = get_time()[1]
    print(f"\nresponse: {fcurrent_time}; ... Translated Text: {translated_text_f}\n\n")
    return {
        "translated_text": translated_text_f,
    }
329
+
330
+ print("\nAPI started successfully .......\n")