boris commited on
Commit
d209547
1 Parent(s): fb1fbca

style: use isort

Browse files
.github/workflows/black.yml DELETED
@@ -1,14 +0,0 @@
1
- name: Lint
2
-
3
- on:
4
- push:
5
- branches: [main]
6
- pull_request:
7
- branches: [main]
8
-
9
- jobs:
10
- lint:
11
- runs-on: ubuntu-latest
12
- steps:
13
- - uses: actions/checkout@v2
14
- - uses: psf/black@stable
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/gradio/app_gradio.py CHANGED
@@ -7,26 +7,18 @@
7
 
8
  import random
9
 
 
10
  import jax
11
- import flax.linen as nn
12
- from flax.training.common_utils import shard
13
  from flax.jax_utils import replicate
14
-
15
- from transformers import BartTokenizer
16
-
17
  from PIL import Image, ImageDraw, ImageFont
18
- import numpy as np
19
-
20
- from vqgan_jax.modeling_flax_vqgan import VQModel
21
- from dalle_mini.model import CustomFlaxBartForConditionalGeneration
22
 
23
  # ## CLIP Scoring
24
- from transformers import CLIPProcessor, FlaxCLIPModel
25
-
26
- import gradio as gr
27
-
28
- from PIL import Image, ImageDraw, ImageFont
29
 
 
30
 
31
  DALLE_REPO = "flax-community/dalle-mini"
32
  DALLE_COMMIT_ID = "4d34126d0df8bc4a692ae933e3b902a1fa8b6114"
 
7
 
8
  import random
9
 
10
+ import gradio as gr
11
  import jax
12
+ import numpy as np
 
13
  from flax.jax_utils import replicate
14
+ from flax.training.common_utils import shard
 
 
15
  from PIL import Image, ImageDraw, ImageFont
 
 
 
 
16
 
17
  # ## CLIP Scoring
18
+ from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel
19
+ from vqgan_jax.modeling_flax_vqgan import VQModel
 
 
 
20
 
21
+ from dalle_mini.model import CustomFlaxBartForConditionalGeneration
22
 
23
  DALLE_REPO = "flax-community/dalle-mini"
24
  DALLE_COMMIT_ID = "4d34126d0df8bc4a692ae933e3b902a1fa8b6114"
app/streamlit/app.py CHANGED
@@ -1,9 +1,10 @@
1
  #!/usr/bin/env python
2
  # coding: utf-8
3
 
4
- from .backend import ServiceError, get_images_from_backend
5
  import streamlit as st
6
 
 
 
7
  st.sidebar.markdown(
8
  """
9
  <style>
 
1
  #!/usr/bin/env python
2
  # coding: utf-8
3
 
 
4
  import streamlit as st
5
 
6
+ from .backend import ServiceError, get_images_from_backend
7
+
8
  st.sidebar.markdown(
9
  """
10
  <style>
app/streamlit/backend.py CHANGED
@@ -1,6 +1,7 @@
1
- import requests
2
- from io import BytesIO
3
  import base64
 
 
 
4
  from PIL import Image
5
 
6
 
 
 
 
1
  import base64
2
+ from io import BytesIO
3
+
4
+ import requests
5
  from PIL import Image
6
 
7
 
dalle_mini/data.py CHANGED
@@ -1,10 +1,12 @@
1
  from dataclasses import dataclass, field
2
- from datasets import load_dataset, Dataset
3
  from functools import partial
4
- import numpy as np
5
  import jax
6
  import jax.numpy as jnp
 
 
7
  from flax.training.common_utils import shard
 
8
  from .text import TextNormalizer
9
 
10
 
 
1
  from dataclasses import dataclass, field
 
2
  from functools import partial
3
+
4
  import jax
5
  import jax.numpy as jnp
6
+ import numpy as np
7
+ from datasets import Dataset, load_dataset
8
  from flax.training.common_utils import shard
9
+
10
  from .text import TextNormalizer
11
 
12
 
dalle_mini/model.py CHANGED
@@ -1,16 +1,14 @@
1
- import jax
2
  import flax.linen as nn
3
-
 
4
  from transformers.models.bart.modeling_flax_bart import (
5
- FlaxBartModule,
6
- FlaxBartForConditionalGenerationModule,
7
- FlaxBartForConditionalGeneration,
8
- FlaxBartEncoder,
9
  FlaxBartDecoder,
 
 
 
 
10
  )
11
 
12
- from transformers import BartConfig
13
-
14
 
15
  class CustomFlaxBartModule(FlaxBartModule):
16
  def setup(self):
 
 
1
  import flax.linen as nn
2
+ import jax
3
+ from transformers import BartConfig
4
  from transformers.models.bart.modeling_flax_bart import (
 
 
 
 
5
  FlaxBartDecoder,
6
+ FlaxBartEncoder,
7
+ FlaxBartForConditionalGeneration,
8
+ FlaxBartForConditionalGenerationModule,
9
+ FlaxBartModule,
10
  )
11
 
 
 
12
 
13
  class CustomFlaxBartModule(FlaxBartModule):
14
  def setup(self):
dalle_mini/text.py CHANGED
@@ -2,13 +2,15 @@
2
  Utilities for processing text.
3
  """
4
 
 
 
 
 
5
  from pathlib import Path
6
- from unidecode import unidecode
7
 
8
- import re, math, random, html
9
  import ftfy
10
-
11
  from huggingface_hub import hf_hub_download
 
12
 
13
  # based on wiki word occurence
14
  person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
 
2
  Utilities for processing text.
3
  """
4
 
5
+ import html
6
+ import math
7
+ import random
8
+ import re
9
  from pathlib import Path
 
10
 
 
11
  import ftfy
 
12
  from huggingface_hub import hf_hub_download
13
+ from unidecode import unidecode
14
 
15
  # based on wiki word occurence
16
  person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
tools/train/train.py CHANGED
@@ -18,37 +18,31 @@ Fine-tuning the library models for seq2seq, text to image.
18
  Script adapted from run_summarization_flax.py
19
  """
20
 
21
- import os
22
  import logging
 
23
  import sys
24
  import time
25
- from dataclasses import dataclass, field
26
  from pathlib import Path
27
  from typing import Callable, Optional
28
- import json
29
 
30
  import datasets
31
- from datasets import Dataset
32
- from tqdm import tqdm
33
- from dataclasses import asdict
34
-
35
  import jax
36
  import jax.numpy as jnp
37
  import optax
38
  import transformers
 
 
39
  from flax import jax_utils, traverse_util
40
- from flax.serialization import from_bytes, to_bytes
41
  from flax.jax_utils import unreplicate
 
42
  from flax.training import train_state
43
  from flax.training.common_utils import get_metrics, onehot, shard_prng_key
44
- from transformers import (
45
- AutoTokenizer,
46
- HfArgumentParser,
47
- )
48
  from transformers.models.bart.modeling_flax_bart import BartConfig
49
 
50
- import wandb
51
-
52
  from dalle_mini.data import Dataset
53
  from dalle_mini.model import CustomFlaxBartForConditionalGeneration
54
 
 
18
  Script adapted from run_summarization_flax.py
19
  """
20
 
21
+ import json
22
  import logging
23
+ import os
24
  import sys
25
  import time
26
+ from dataclasses import asdict, dataclass, field
27
  from pathlib import Path
28
  from typing import Callable, Optional
 
29
 
30
  import datasets
 
 
 
 
31
  import jax
32
  import jax.numpy as jnp
33
  import optax
34
  import transformers
35
+ import wandb
36
+ from datasets import Dataset
37
  from flax import jax_utils, traverse_util
 
38
  from flax.jax_utils import unreplicate
39
+ from flax.serialization import from_bytes, to_bytes
40
  from flax.training import train_state
41
  from flax.training.common_utils import get_metrics, onehot, shard_prng_key
42
+ from tqdm import tqdm
43
+ from transformers import AutoTokenizer, HfArgumentParser
 
 
44
  from transformers.models.bart.modeling_flax_bart import BartConfig
45
 
 
 
46
  from dalle_mini.data import Dataset
47
  from dalle_mini.model import CustomFlaxBartForConditionalGeneration
48