radinhas commited on
Commit
82cccd9
·
1 Parent(s): 312f0cc

Update apis/chat_api.py

Browse files
Files changed (1) hide show
  1. apis/chat_api.py +73 -6
apis/chat_api.py CHANGED
@@ -1,7 +1,14 @@
1
  import argparse
2
  import uvicorn
3
  import sys
 
 
 
 
4
  import json
 
 
 
5
  import string
6
  import random
7
  import base64
@@ -31,12 +38,12 @@ class ChatAPIApp:
31
  )
32
  self.setup_routes()
33
 
34
- def get_available_models(self):
35
  f = open('apis/lang_name.json', "r")
36
  self.available_models = json.loads(f.read())
37
  return self.available_models
38
 
39
- class ChatCompletionsPostItem(BaseModel):
40
  from_language: str = Field(
41
  default="auto",
42
  description="(str) `Detect`",
@@ -51,7 +58,7 @@ class ChatAPIApp:
51
  )
52
 
53
 
54
- def chat_completions(self, item: ChatCompletionsPostItem):
55
  translator = Translator()
56
  f = open('apis/lang_name.json', "r")
57
  available_langs = json.loads(f.read())
@@ -73,6 +80,60 @@ class ChatAPIApp:
73
  json_compatible_item_data = jsonable_encoder(item_response)
74
  return JSONResponse(content=json_compatible_item_data)
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  class DetectLanguagePostItem(BaseModel):
78
  input_text: str = Field(
@@ -125,15 +186,21 @@ class ChatAPIApp:
125
  def setup_routes(self):
126
  for prefix in ["", "/v1"]:
127
  self.app.get(
128
- prefix + "/models",
129
  summary="Get available languages",
130
- )(self.get_available_models)
131
 
132
  self.app.post(
133
  prefix + "/translate",
134
  summary="translate text",
135
- )(self.chat_completions)
 
 
 
 
 
136
 
 
137
  self.app.post(
138
  prefix + "/detect",
139
  summary="detect language",
 
1
  import argparse
2
  import uvicorn
3
  import sys
4
+ import os
5
+ import io
6
+ from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
7
+ import time
8
  import json
9
+ from typing import List
10
+ import torch
11
+ import logging
12
  import string
13
  import random
14
  import base64
 
38
  )
39
  self.setup_routes()
40
 
41
+ def get_available_langs(self):
42
  f = open('apis/lang_name.json', "r")
43
  self.available_models = json.loads(f.read())
44
  return self.available_models
45
 
46
+ class TranslateCompletionsPostItem(BaseModel):
47
  from_language: str = Field(
48
  default="auto",
49
  description="(str) `Detect`",
 
58
  )
59
 
60
 
61
+ def translate_completions(self, item: TranslateCompletionsPostItem):
62
  translator = Translator()
63
  f = open('apis/lang_name.json', "r")
64
  available_langs = json.loads(f.read())
 
80
  json_compatible_item_data = jsonable_encoder(item_response)
81
  return JSONResponse(content=json_compatible_item_data)
82
 
83
+ def translate_ai_completions(self, item: TranslateCompletionsPostItem):
84
+ translator = Translator()
85
+ f = open('apis/lang_name.json', "r")
86
+ available_langs = json.loads(f.read())
87
+ from_lang = 'en'
88
+ to_lang = 'en'
89
+ for lang_item in available_langs:
90
+ if item.to_language == lang_item['code']:
91
+ to_lang = item.to_language
92
+ if item.from_language == lang_item['code']:
93
+ from_lang = item.from_language
94
+
95
+ if to_lang == 'auto':
96
+ to_lang = 'en'
97
+
98
+ if from_lang == 'auto':
99
+ from_lang = translator.detect(item.input_text).lang
100
+
101
+ if torch.cuda.is_available():
102
+ device = torch.device("cuda:0")
103
+ else:
104
+ device = torch.device("cpu")
105
+ logging.warning("GPU not found, using CPU, translation will be very slow.")
106
+
107
+ time_start = time.time()
108
+
109
+ tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
110
+ model = M2M100ForConditionalGeneration.from_pretrained(
111
+ "facebook/m2m100_1.2B", cache_dir="models/"
112
+ ).to(device)
113
+ model.eval()
114
+
115
+ tokenizer.src_lang = from_lang
116
+ with torch.no_grad():
117
+ encoded_input = tokenizer(item.input_text, return_tensors="pt").to(device)
118
+ generated_tokens = model.generate(
119
+ **encoded_input, forced_bos_token_id=tokenizer.get_lang_id(to_lang)
120
+ )
121
+ translated_text = tokenizer.batch_decode(
122
+ generated_tokens, skip_special_tokens=True
123
+ )[0]
124
+
125
+ time_end = time.time()
126
+ translated = translated_text
127
+ item_response = {
128
+ "from_language": from_lang,
129
+ "to_language": to_lang,
130
+ "text": item.input_text,
131
+ "translate": translated,
132
+ "start": str(time_start),
133
+ "end": str(time_end)
134
+ }
135
+ json_compatible_item_data = jsonable_encoder(item_response)
136
+ return JSONResponse(content=json_compatible_item_data)
137
 
138
  class DetectLanguagePostItem(BaseModel):
139
  input_text: str = Field(
 
186
  def setup_routes(self):
187
  for prefix in ["", "/v1"]:
188
  self.app.get(
189
+ prefix + "/langs",
190
  summary="Get available languages",
191
+ )(self.get_available_langs)
192
 
193
  self.app.post(
194
  prefix + "/translate",
195
  summary="translate text",
196
+ )(self.translate_completions)
197
+
198
+ self.app.post(
199
+ prefix + "/translate/ai",
200
+ summary="translate text with ai",
201
+ )(self.translate_ai_completions)
202
 
203
+
204
  self.app.post(
205
  prefix + "/detect",
206
  summary="detect language",