yonikremer commited on
Commit
826e275
1 Parent(s): feb3275

created an initial app

Browse files
Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The Streamlit app for the project demo.
3
+ In the demo, the user can write a prompt and the model will generate a response using the grouped sampling algorithm.
4
+ """
5
+
6
+ import streamlit as st
7
+ from grouped_sampling import GroupedSamplingPipeLine
8
+
9
+ available_models_list = "https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads"
10
+
11
+
12
def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
    """
    Create a grouped-sampling pipeline for the given model.

    :param model_name: The name of the text-generation model to use
        (a model id from the Hugging Face hub).
    :param group_size: The size of the token groups the pipeline samples.
    :return: A GroupedSamplingPipeLine configured with the given model name
        and group size.
    """
    # NOTE(review): a fresh pipeline (and therefore the model) is built on
    # every call; if model loading is slow, consider caching the pipeline
    # (e.g. with st.cache_resource) keyed on (model_name, group_size).
    return GroupedSamplingPipeLine(model_name=model_name, group_size=group_size)
20
+
21
+
22
def on_form_submit(model_name: str, group_size: int, prompt: str) -> str:
    """
    Handle a submission of the request form.

    Builds a pipeline for the requested model and runs it on the prompt.

    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :param prompt: The prompt to use.
    :return: The output of the model.
    """
    generation_result = create_pipeline(model_name, group_size)(prompt)
    return generation_result["generated_text"]
32
+
33
+
34
# --- Streamlit UI: collect generation parameters in a single form. ---------
with st.form("request_form"):
    # Free-text model id; any text-generation model from the hub is accepted.
    model_name_field: str = st.text_input(
        label="Model name",
        value="gpt2",
        help=f"The name of the model to use. Must be a model from this list: {available_models_list}"
    )

    # NOTE(review): this value is passed to on_form_submit as group_size,
    # although the label describes it as the output length — confirm that the
    # grouped-sampling group size is indeed meant to equal the output length.
    length_field: int = st.number_input(
        label="Output Length in tokens",
        min_value=1,
        max_value=4096,
        value=100,
        help="The length of the output text in tokens (word pieces)."
    )

    prompt_field: str = st.text_area(
        label="Input for the model",
        help="Enter the prompt for the model. The model will generate a response based on this prompt.",
        max_chars=16384,
    )

    # True only on the rerun triggered by the user pressing the button.
    generate_clicked: bool = st.form_submit_button(
        label="Generate",
        help="Generate the output text.",
        disabled=False
    )

if generate_clicked:
    # Generation runs only after an explicit submit; may take a while
    # since the model is loaded on demand inside on_form_submit.
    generated = on_form_submit(model_name_field, length_field, prompt_field)
    st.write(f"Generated text: {generated}")