rajeshradhakrishnan commited on
Commit
696bdbb
1 Parent(s): 617c204

update to facebook Bart50

Browse files
Files changed (2) hide show
  1. main.py +14 -4
  2. static/script.js +1 -1
main.py CHANGED
@@ -2,18 +2,28 @@ import os
2
  from fastapi import FastAPI, Request
3
  from fastapi.staticfiles import StaticFiles
4
  from fastapi.responses import FileResponse
5
- from transformers import pipeline
6
 
7
 
 
 
 
8
  app = FastAPI()
9
 
10
 
11
- pipe_flan = pipeline("translation_en_to_ml", model="google/flan-t5-small")
12
 
13
  @app.get("/infer_t5")
14
  def t5(input):
15
- output = pipe_flan(input)
16
- return {"output": output[0]["generated_text"]}
 
 
 
 
 
 
 
 
17
 
18
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
19
 
 
2
  from fastapi import FastAPI, Request
3
  from fastapi.staticfiles import StaticFiles
4
  from fastapi.responses import FileResponse
5
+ from transformers import pipeline, MBartForConditionalGeneration, MBart50TokenizerFast
6
 
7
 
8
+ model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-one-to-many-mmt")
9
+ tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-one-to-many-mmt", src_lang="en_XX")
10
+
11
  app = FastAPI()
12
 
13
 
 
14
 
15
  @app.get("/infer_t5")
16
  def t5(input):
17
+ model_inputs = tokenizer(input, return_tensors="pt")
18
+
19
+ # translate from English to Malayalam
20
+ generated_tokens = model.generate(
21
+ **model_inputs,
22
+ forced_bos_token_id=tokenizer.lang_code_to_id["ml_IN"]
23
+ )
24
+
25
+ output = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
26
+ return {"output":output}
27
 
28
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
29
 
static/script.js CHANGED
@@ -68,7 +68,7 @@ async function getMessage(){
68
  const [prompterText, assistantText] = generatePrompterAssistantText(data[0].generated_text);
69
  // const en_text_ml = "English: " + assistantText[0] + " Malayalam:";
70
  // console.log(en_text_ml)
71
- outPutElement.textContent = await translateText(assistantText);
72
  const pElement = document.createElement('p')
73
  pElement.textContent = inputElement.value
74
  pElement.addEventListener('click', () => changeInput(pElement.textContent))
 
68
  const [prompterText, assistantText] = generatePrompterAssistantText(data[0].generated_text);
69
  // const en_text_ml = "English: " + assistantText[0] + " Malayalam:";
70
  // console.log(en_text_ml)
71
+ outPutElement.textContent = await translateText(assistantText[0]);
72
  const pElement = document.createElement('p')
73
  pElement.textContent = inputElement.value
74
  pElement.addEventListener('click', () => changeInput(pElement.textContent))