Ebimsv commited on
Commit
1ed72c0
1 Parent(s): 333fb39

Chatbot with TinyLlama

Browse files
Files changed (4) hide show
  1. app.py +120 -0
  2. imgs/TinyLlama_logo.png +0 -0
  3. imgs/user_logo.png +0 -0
  4. requirements.txt +86 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from ctransformers import AutoModelForCausalLM, AutoConfig, Config  # loader for GGUF/GGML models
import datetime

# Hugging Face repo id of the chat model served by this app.
modelfile = "TinyLlama/TinyLlama-1.1B-Chat-v0.6"

# Default generation settings baked into the ctransformers Config below.
# (The UI sliders further down pass their own per-request values to llm().)
i_temperature = 0.30
i_max_new_tokens = 1100
i_repetitionpenalty = 1.2
i_contextlength = 12048  # NOTE(review): the page title says "4K context window" — confirm intended value
logfile = 'TinyLlama.1B.txt'  # chat history is appended to this file by writehistory()

print("loading model...")

# Time the model load so the startup cost is visible in the logs.
stt = datetime.datetime.now()
conf = AutoConfig(Config(temperature=i_temperature,
                         repetition_penalty=i_repetitionpenalty,
                         batch_size=64,
                         max_new_tokens=i_max_new_tokens,
                         context_length=i_contextlength))
llm = AutoModelForCausalLM.from_pretrained(modelfile,
                                           model_type="llama",
                                           config=conf)
dt = datetime.datetime.now() - stt
print(f"Model loaded in {dt}")
27
def writehistory(text, path=None):
    """Append *text* plus a trailing newline to the chat log file.

    Parameters:
        text: line to record (a user message or a bot exchange summary).
        path: optional log-file override; defaults to the module-level
            ``logfile`` so existing callers are unaffected.
    """
    target = logfile if path is None else path
    # The with-statement already closes the file; the original body also
    # called f.close() inside the block, which was redundant and removed.
    with open(target, 'a', encoding='utf-8') as f:
        f.write(text)
        f.write('\n')
33
with gr.Blocks(theme='ParityError/Interstellar') as demo:
    # TITLE SECTION
    with gr.Row():
        with gr.Column(scale=12):
            # BUG FIX: the heading was opened with <h1> but closed with </h2>;
            # the tags now match.
            gr.HTML("<center>"
                    + "<h1>🦙 TinyLlama 1.1B 🐋 4K context window</h1></center>")
            gr.Markdown("""
**Currently Running**: [TinyLlama/TinyLlama-1.1B-Chat-v0.6](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.6) &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; **Chat History Log File**: *TinyLlama.1B.txt*

- **Base Model**: TinyLlama/TinyLlama-1.1B-Chat-v0.6, Fine tuned on OpenOrca GPT4 subset for 1 epoch, Using CHATML format.
- **License**: Apache 2.0, following the TinyLlama base model.
The model output is not censored and the authors do not endorse the opinions in the generated content. Use at your own risk.
""")
        # Logo shown to the right of the title column.
        gr.Image(value='imgs/TinyLlama_logo.png', width=70)

    # chat and parameters settings
    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(height=350, show_copy_button=True,
                                 avatar_images=["imgs/user_logo.png", "imgs/TinyLlama_logo.png"])
            with gr.Row():
                with gr.Column(scale=14):
                    msg = gr.Textbox(show_label=False, placeholder="Enter text", lines=2)
                submitBtn = gr.Button("\n💬 Send\n", size="lg", variant="primary", min_width=140)

        # Right-hand sidebar with the per-request generation parameters.
        with gr.Column(min_width=50, scale=1):
            with gr.Tab(label="Parameter Setting"):
                gr.Markdown("# Parameters")
                top_p = gr.Slider(minimum=0,  # was "-0", a no-op oddity
                                  maximum=1.0,
                                  value=0.95,
                                  step=0.05,
                                  interactive=True,
                                  label="Top-p")
                temperature = gr.Slider(minimum=0.1,
                                        maximum=1.0,
                                        value=0.30,
                                        step=0.01,
                                        interactive=True,
                                        label="Temperature")
                max_length_tokens = gr.Slider(minimum=0,
                                              maximum=4096,
                                              value=1060,
                                              step=4,
                                              interactive=True,
                                              label="Max Generation Tokens")
                rep_pen = gr.Slider(minimum=0,
                                    maximum=5,
                                    value=1.2,
                                    step=0.05,
                                    interactive=True,
                                    label="Repetition Penalty")

                clear = gr.Button("🗑️ Clear All Messages", variant='secondary')
+ def user(user_message, history):
87
+ writehistory(f"USER: {user_message}")
88
+ return "", history + [[user_message, None]]
89
+
90
+ def bot(history, t, p, m, r):
91
+ SYSTEM_PROMPT = """<|im_start|>system
92
+ You are a helpful bot. Your answers are clear and concise.
93
+ <|im_end|>
94
+
95
+ """
96
+ prompt = f"<|im_start|>system<|im_end|><|im_start|>user\n{history[-1][0]}<|im_end|>\n<|im_start|>assistant\n"
97
+ print(f"history lenght: {len(history)}")
98
+ if len(history) == 1:
99
+ print("this is the first round")
100
+ else:
101
+ print("here we should pass more conversations")
102
+ history[-1][1] = ""
103
+ for character in llm(prompt,
104
+ temperature = t,
105
+ top_p = p,
106
+ repetition_penalty = r,
107
+ max_new_tokens=m,
108
+ stop = ['<|im_end|>'],
109
+ stream = True):
110
+ history[-1][1] += character
111
+ yield history
112
+ writehistory(f"temperature: {t}, top_p: {p}, maxNewTokens: {m}, repetitionPenalty: {r}\n---\nBOT: {history}\n\n")
113
+ # Log in the terminal the messages
114
+ print(f"USER: {history[-1][0]}\n---\ntemperature: {t}, top_p: {p}, maxNewTokens: {m}, repetitionPenalty: {r}\n---\nBOT: {history[-1][1]}\n\n")
115
    # Clicking the submitBtn will call the generation with Parameters in the slides:
    # user() first clears the textbox and appends the new turn (not queued so the
    # UI updates immediately), then bot() streams tokens into the chatbot.
    submitBtn.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(bot, [chatbot,temperature,top_p,max_length_tokens,rep_pen], chatbot)
    # Clear button resets the chatbot component (the on-disk log is untouched).
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue()  # required to yield the streams from the text generation
demo.launch(inbrowser=True, share=True)
imgs/TinyLlama_logo.png ADDED
imgs/user_logo.png ADDED
requirements.txt ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohttp==3.9.3
3
+ aiosignal==1.3.1
4
+ altair==5.2.0
5
+ annotated-types==0.6.0
6
+ anyio==4.3.0
7
+ async-timeout==4.0.3
8
+ attrs==23.2.0
9
+ certifi==2024.2.2
10
+ charset-normalizer==3.3.2
11
+ click==8.1.7
12
+ cmake==3.28.3
13
+ colorama==0.4.6
14
+ contourpy==1.2.0
15
+ cycler==0.12.1
16
+ exceptiongroup==1.2.0
17
+ fastapi==0.110.0
18
+ ffmpy==0.3.2
19
+ filelock==3.13.1
20
+ fonttools==4.50.0
21
+ frozenlist==1.4.1
22
+ fsspec==2024.3.0
23
+ gradio==4.21.0
24
+ gradio_client==0.12.0
25
+ h11==0.14.0
26
+ httpcore==1.0.4
27
+ httpx==0.27.0
28
+ huggingface-hub==0.21.4
29
+ idna==3.6
30
+ importlib_resources==6.3.0
31
+ Jinja2==3.1.3
32
+ jsonschema==4.21.1
33
+ jsonschema-specifications==2023.12.1
34
+ kiwisolver==1.4.5
35
+ linkify-it-py==2.0.3
36
+ lit==18.1.1
37
+ markdown-it-py==2.2.0
38
+ MarkupSafe==2.1.5
39
+ matplotlib==3.8.3
40
+ mdit-py-plugins==0.3.3
41
+ mdurl==0.1.2
42
+ mpmath==1.3.0
43
+ multidict==6.0.5
44
+ networkx==3.2.1
45
+ numpy==1.26.4
46
+ orjson==3.9.15
47
+ packaging==24.0
48
+ pandas==2.2.1
49
+ pillow==10.2.0
50
+ pydantic==2.6.4
51
+ pydantic_core==2.16.3
52
+ pydub==0.25.1
53
+ Pygments==2.17.2
54
+ pyparsing==3.1.2
55
+ python-dateutil==2.9.0.post0
56
+ python-multipart==0.0.9
57
+ pytz==2024.1
58
+ PyYAML==6.0.1
59
+ referencing==0.33.0
60
+ regex==2023.12.25
61
+ requests==2.31.0
62
+ rich==13.7.1
63
+ rpds-py==0.18.0
64
+ ruff==0.3.3
65
+ safetensors==0.4.2
66
+ semantic-version==2.10.0
67
+ shellingham==1.5.4
68
+ six==1.16.0
69
+ sniffio==1.3.1
70
+ starlette==0.36.3
71
+ sympy==1.12
72
+ tokenizers==0.13.3
73
+ tomlkit==0.12.0
74
+ toolz==0.12.1
75
+ torch==2.0.1
76
+ tqdm==4.66.2
77
+ transformers==4.31.0
78
+ triton==2.0.0
79
+ typer==0.9.0
80
+ typing_extensions==4.10.0
81
+ tzdata==2024.1
82
+ uc-micro-py==1.0.3
83
+ urllib3==2.2.1
84
+ uvicorn==0.28.0
85
+ websockets==11.0.3
86
+ yarl==1.9.4