speech-test commited on
Commit
f0dfc26
β€’
1 Parent(s): 9ece248

temporary API rate limit fix

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -2,14 +2,19 @@ import random
2
  import torch
3
  import gradio as gr
4
  from gradio.mix import Series
 
5
  from rudalle.pipelines import generate_images
6
  from rudalle import get_rudalle_model, get_tokenizer, get_vae
7
 
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
9
  dalle = get_rudalle_model("Malevich", pretrained=True, fp16=True, device=device)
10
  tokenizer = get_tokenizer()
11
  vae = get_vae().to(device)
12
 
 
 
 
13
  def dalle_wrapper(prompt: str):
14
  top_k, top_p = random.choice([
15
  (1024, 0.98),
@@ -30,8 +35,9 @@ def dalle_wrapper(prompt: str):
30
  return title, images[0]
31
 
32
 
33
- translator = gr.Interface.load("huggingface/facebook/wmt19-en-ru",
34
- inputs=[gr.inputs.Textbox(label="What would you like to see?")])
 
35
  outputs = [
36
  gr.outputs.HTML(label=""),
37
  gr.outputs.Image(label=""),
 
2
  import torch
3
  import gradio as gr
4
  from gradio.mix import Series
5
+ from transformers import pipeline
6
  from rudalle.pipelines import generate_images
7
  from rudalle import get_rudalle_model, get_tokenizer, get_vae
8
 
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ translation_pipe = pipeline("translation", model="facebook/wmt19-en-ru", device=0)
11
  dalle = get_rudalle_model("Malevich", pretrained=True, fp16=True, device=device)
12
  tokenizer = get_tokenizer()
13
  vae = get_vae().to(device)
14
 
15
+ def translation_wrapper(text: str):
16
+ return translation_pipe(text)[0]["translation_text"]
17
+
18
  def dalle_wrapper(prompt: str):
19
  top_k, top_p = random.choice([
20
  (1024, 0.98),
 
35
  return title, images[0]
36
 
37
 
38
+ translator = gr.Interface(fn=translation_wrapper,
39
+ inputs=[gr.inputs.Textbox(label='What would you like to see?')],
40
+ outputs="text")
41
  outputs = [
42
  gr.outputs.HTML(label=""),
43
  gr.outputs.Image(label=""),