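# llama2-chat / app.py
# Hugging Face Space by Hunzla. Last commit 2e669c6 ("Update app.py"),
# originally 1.14 kB (Hub scan: no virus).
# Gradio demo that serves meta-llama/Llama-2-7b-chat-hf through a
# transformers text-generation pipeline.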
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
# Load model and tokenizer, then wrap them in a text-generation pipeline.
# NOTE(review): meta-llama repositories are gated on the Hub; confirm the
# runtime has an access token with granted access configured.
model_name = "meta-llama/Llama-2-7b-chat-hf"
# Pin model and tokenizer to the same Hub revision so they cannot drift apart.
model_revision = "main"
# Half precision roughly halves memory for the 7B weights (~14 GB instead of
# ~28 GB in float32). Many CPU kernels lack float16 support, so only use it
# when a CUDA device is available.
use_cuda = torch.cuda.is_available()
model_dtype = torch.float16 if use_cuda else torch.float32
print("started loading model")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    return_dict=True,
    revision=model_revision,
    torch_dtype=model_dtype,
)
print("loaded model")
tokenizer = AutoTokenizer.from_pretrained(model_name, revision=model_revision)
print("loaded tokenizer")
# pipeline() device convention: 0 = first CUDA device, -1 = CPU.
chat_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if use_cuda else -1,
)
print("built pipeline")
# Define the generate_response function used as the Gradio callback.
def generate_response(prompt, max_new_tokens=50):
    """Generate a chat reply to *prompt* with the Llama-2 chat pipeline.

    Args:
        prompt: User message typed into the Gradio textbox.
        max_new_tokens: Upper bound on the number of tokens generated for the
            reply. Unlike ``max_length``, the prompt does not count against it.

    Returns:
        The generated reply with surrounding whitespace stripped and without
        the prompt echoed back; ``""`` when *prompt* is empty, ``None`` or
        whitespace-only.
    """
    # Nothing to answer; skip an expensive generation from an empty context.
    if not prompt or not prompt.strip():
        return ""
    # Llama-2 *chat* checkpoints are fine-tuned on the [INST] ... [/INST]
    # instruction format; without it the model behaves like a plain completer.
    formatted_prompt = f"[INST] {prompt.strip()} [/INST]"
    # max_new_tokens bounds only the generated part, so long prompts no longer
    # exhaust the budget the way the former max_length=50 did.
    outputs = chat_pipeline(
        formatted_prompt,
        max_new_tokens=max_new_tokens,
        return_full_text=False,  # return only the reply, not the prompt
    )
    return outputs[0]["generated_text"].strip()
# Create the Gradio interface: one text box in, one text box out.
# `layout=` was dropped: gr.Interface treats it as a deprecated, no-effect
# argument in Gradio 3.x, and newer releases no longer support it.
interface = gr.Interface(
    fn=generate_response,
    inputs="text",
    outputs="text",
    title="LLAMA-2-7B Chatbot",
    description="Enter a prompt and get a chatbot response.",
    examples=[["Tell me a joke."]],
)
def main():
    """Start the Gradio web server for the chat interface."""
    interface.launch()


if __name__ == "__main__":
    main()