loubnabnl (HF staff) committed
Commit 41d27ac
Parent: c3fb384

add multiprocessing in inference and clean code

Files changed (1):
  1. app.py +19 -19
app.py CHANGED
@@ -1,29 +1,27 @@
-import streamlit as st
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from transformers import pipeline
-import torch
 import json
 import pandas as pd
 import requests
+from multiprocessing import Pool
+from functools import partial
+import streamlit as st
+
 
 GITHUB_CODE = "https://huggingface.co/datasets/lvwerra/github-code"
 INCODER_IMG = "https://huggingface.co/datasets/loubnabnl/repo-images/raw/main/incoder.png"
 
-@st.cache(allow_output_mutation=True)
-def load_tokenizer(model_ckpt):
-    return AutoTokenizer.from_pretrained(model_ckpt)
-
-@st.cache(allow_output_mutation=True)
-def load_model(model_ckpt):
-    model = AutoModelForCausalLM.from_pretrained(model_ckpt, low_cpu_mem_usage=True)
-    return model
-
 @st.cache()
 def load_examples():
     with open("utils/examples.json", "r") as f:
         examples = json.load(f)
     return examples
 
+def generate_code(model_name, gen_prompt, max_new_tokens, temperature, seed):
+    url = f'https://hf.space/embed/loubnabnl/{model_name.lower()}-subspace/+/api/predict/'
+    r = requests.post(url=url, json={"data": [gen_prompt, max_new_tokens, temperature, seed]})
+    generated_text = r.json()['data'][0]
+    st.markdown(model_name)
+    st.code(generated_text)
+
 st.set_page_config(page_icon=":laptop:", layout="wide")
 
 st.sidebar.header("Models")
@@ -84,9 +82,11 @@ elif selected_task == "Code generation":
     gen_prompt = st.text_area("Generate code with prompt:", value=example_text, height=220,).strip()
     if st.button("Generate code!"):
         with st.spinner("Generating code..."):
-            for model in selected_models:
-                url = f'https://hf.space/embed/loubnabnl/{model.lower()}-subspace/+/api/predict/'
-                r = requests.post(url=url, json={"data": [gen_prompt, max_new_tokens, temperature, seed]})
-                generated_text = r.json()['data'][0]
-                st.markdown(f"{model}")
-                st.code(generated_text)
+            # Create a multiprocessing Pool
+            pool = Pool()
+            generate_parallel = partial(generate_code,
+                                        gen_prompt=gen_prompt,
+                                        max_new_tokens=max_new_tokens,
+                                        temperature=temperature,
+                                        seed=seed)
+            pool.map(generate_parallel, selected_models)
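
Note: below is a minimal standalone sketch of the same fan-out pattern, runnable outside the Streamlit app. The endpoint shape and request payload follow generate_code in the diff above; the model names, prompt, and sampling parameters are hypothetical placeholders, and results are returned to the parent process instead of rendered with st.markdown/st.code.

import requests
from functools import partial
from multiprocessing import Pool

def generate_code(model_name, gen_prompt, max_new_tokens, temperature, seed):
    # Endpoint shape taken from the diff: one Space per model.
    url = f"https://hf.space/embed/loubnabnl/{model_name.lower()}-subspace/+/api/predict/"
    r = requests.post(url=url, json={"data": [gen_prompt, max_new_tokens, temperature, seed]})
    # Return the generated text so the parent process can display it.
    return model_name, r.json()["data"][0]

if __name__ == "__main__":
    selected_models = ["InCoder", "CodeParrot"]  # hypothetical selection
    generate_parallel = partial(generate_code,
                                gen_prompt="def hello_world():",
                                max_new_tokens=64,
                                temperature=0.2,
                                seed=42)
    # One worker per model; map supplies model_name as the positional argument.
    with Pool(len(selected_models)) as pool:
        for model_name, generated_text in pool.map(generate_parallel, selected_models):
            print(f"### {model_name}\n{generated_text}\n")

Since each call is an I/O-bound HTTP request, a thread pool (multiprocessing.pool.ThreadPool) would serve equally well here and avoids the requirement that the mapped function be picklable.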