shljessie committed on
Commit
c70b3cf
1 Parent(s): 194ecfa

update gradio chat interface

Files changed (1)
  1. app.py +49 -4
app.py CHANGED
@@ -1,7 +1,52 @@
 
 
  import gradio as gr
  import torch
- def greet(name):
-     return "Hello " + name + "!!"

- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ import os
+ import threading
  import gradio as gr
  import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer

+ # Check if CUDA is available
+ if not torch.cuda.is_available():
+     raise EnvironmentError("CUDA is not available. This script requires a GPU.")
+
+ # Model Configuration
+ MODEL_ID = "meta-llama/Llama-2-7b-chat"
+ MAX_INPUT_TOKEN_LENGTH = 4096
+ MAX_NEW_TOKENS = 1024
+ TEMPERATURE = 0.6
+ TOP_P = 0.9
+ TOP_K = 50
+ REPETITION_PENALTY = 1.2
+
+ # Load the model and tokenizer
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="auto")
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+ def generate_response(user_input):
+     """
+     Generate a response to the user input using the Llama-2 7B model.
+     """
+     input_ids = tokenizer.encode(user_input, return_tensors="pt", truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH)
+     input_ids = input_ids.to(model.device)
+
+     # Sample a response (do_sample=True so temperature/top-k/top-p take effect)
+     output = model.generate(input_ids, do_sample=True, max_new_tokens=MAX_NEW_TOKENS,
+                             temperature=TEMPERATURE, top_k=TOP_K, top_p=TOP_P,
+                             repetition_penalty=REPETITION_PENALTY)
+
+     response = tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)  # decode only the new tokens
+     return response
+
+ def chatbot_interface(user_input):
+     return generate_response(user_input)
+
+ # Create the Gradio interface
+ iface = gr.Interface(
+     fn=chatbot_interface,
+     inputs=gr.Textbox(lines=2, placeholder="Type your message here..."),
+     outputs="text",
+     title="Llama-2 7B Chatbot",
+     description="This is a chatbot powered by the Llama-2 7B model. Try asking it something!",
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
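
Note: the new file imports threading but does not use it yet. A likely follow-up, and a common pattern for Gradio LLM demos, is to stream tokens into the UI by running generate() on a background thread with transformers' TextIteratorStreamer. The sketch below illustrates that pattern with the same model ID and sampling settings as the commit; the function name stream_response is hypothetical and is not part of this commit.

import torch
import gradio as gr
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "meta-llama/Llama-2-7b-chat"  # same ID as in the commit
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

def stream_response(user_input):
    # Tokenize the prompt and move it to the model's device
    input_ids = tokenizer.encode(user_input, return_tensors="pt").to(model.device)
    # The streamer yields decoded text chunks as generate() produces tokens
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(input_ids=input_ids, streamer=streamer, max_new_tokens=1024,
                             do_sample=True, temperature=0.6, top_p=0.9, top_k=50,
                             repetition_penalty=1.2)
    # Run generation on a background thread so the UI can consume partial output
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial  # Gradio re-renders the output with each partial string

iface = gr.Interface(fn=stream_response, inputs="text", outputs="text")
iface.queue().launch()  # queue() enables generator (streaming) outputs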