AvaPersona committed (verified)
Commit 8bb427b · 1 Parent(s): 6805787

Update app.py

Files changed (1)
  1. app.py +31 -15
app.py CHANGED
@@ -1,27 +1,43 @@
 
  from transformers import LlamaForCausalLM, LlamaTokenizer
  import gradio as gr
- import torch

  # Load the model and tokenizer
- model_name = "meta-llama/Llama-3.1-8B"  # Replace with the desired LLaMA model
- tokenizer = LlamaTokenizer.from_pretrained(model_name)
- model = LlamaForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)

- # Define the response generation function
- def generate_response(prompt, max_length=100):
-     inputs = tokenizer(prompt, return_tensors="pt").to("cuda")  # Use CUDA if available
-     outputs = model.generate(inputs['input_ids'], max_length=max_length, temperature=0.7)
      response = tokenizer.decode(outputs[0], skip_special_tokens=True)
      return response

- # Create the Gradio interface
- interface = gr.Interface(
      fn=generate_response,
-     inputs=gr.Textbox(lines=5, placeholder="Enter your prompt here..."),
-     outputs="text",
-     title="LLaMA Chatbot",
-     description="A chatbot powered by LLaMA. Enter a prompt and get a response!",
  )

  # Launch the app
- interface.launch()
 
+ import torch
  from transformers import LlamaForCausalLM, LlamaTokenizer
  import gradio as gr

  # Load the model and tokenizer
+ MODEL_NAME = "meta-llama/Llama-3.1-8B"  # Update this if using a custom LLaMA model
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ print("Loading model...")
+ tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME)
+ model = LlamaForCausalLM.from_pretrained(
+     MODEL_NAME,
+     torch_dtype=torch.float16,  # Use float16 for better performance
+     device_map="auto"  # Automatically place weights on the available GPU
+ )

+ # Define a function for generating responses
+ def generate_response(prompt):
+     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
+     with torch.no_grad():
+         outputs = model.generate(
+             input_ids=inputs["input_ids"],
+             attention_mask=inputs["attention_mask"],
+             max_new_tokens=256,  # Cap on newly generated tokens, independent of prompt length
+             temperature=0.7,  # Adjust creativity level
+             top_p=0.95,  # Top-p (nucleus) sampling
+             do_sample=True  # Sampling must be enabled for temperature/top_p to apply
+         )
      response = tokenizer.decode(outputs[0], skip_special_tokens=True)
      return response

+ # Gradio UI
+ iface = gr.Interface(
      fn=generate_response,
+     inputs=gr.Textbox(lines=3, placeholder="Enter your prompt here..."),
+     outputs=gr.Textbox(label="LLaMA Response"),
+     title="LLaMA 3.1 8B Chatbot",
+     description="An interactive demo of the LLaMA 3.1 8B model using Hugging Face Spaces."
  )
 
  # Launch the app
+ if __name__ == "__main__":
+     iface.launch()