# huggyllama / main.py
# Importing required libraries
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import torch
# Loading the tokenizer and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
# device_map="auto" places the weights across the available devices and offloads
# anything that does not fit to the local "offload" folder, using float16 precision
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b", device_map="auto", offload_folder="offload", torch_dtype=torch.float16
)
# With device_map="auto" the model is already placed, so no explicit .to(device) is needed
# model = model.to(device)
print("***")
print("Loaded tokenizer and model")
print(device)
print("***")
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
print("***")
print("Created pipeline")
print("***")
# Text generation
def generator(prompt):
    output = pipe(prompt, max_length=50, num_return_sequences=1)
    return output[0]["generated_text"]
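# Example: generator("Once upon a time") returns a single completion string;
# max_length=50 caps the total length at 50 tokens, prompt included.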
# Creating the Gradio interface
demo = gr.Interface(
    fn=generator,
    inputs=gr.Textbox(lines=5, label="Input Text"),
    outputs=gr.Textbox(label="Generated Text"),
)
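# 0.0.0.0 binds to all network interfaces; 7860 is Gradio's default port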
host, port = "0.0.0.0", 7860
print("***")
print(f"Set up interface. Hosting now on {host}:{port}")
print("***")
# Launching the Gradio interface
demo.launch(server_name=host, server_port=port)