# Hugging Face Spaces scrape header (preserved as comments so the file parses):
# Spaces:
# Runtime error
import gradio as gr
import torch
import os
import sys
import time
import json
from typing import List
from transformers import (
    LlamaTokenizer,
    LlamaForCausalLM,
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaConfig
)
from peft import PeftModel
from accelerate import disk_offload

# Load the Llama-2 7B base model (8-bit quantized to fit Space memory).
# device_map="auto" lets accelerate place layers on GPU/CPU as available.
model = AutoModelForCausalLM.from_pretrained(
    "Johntad110/llama-2-7b-amharic-tokenizer",
    return_dict=True,
    load_in_8bit=True,
    device_map="auto",
    low_cpu_mem_usage=True,
    attn_implementation="sdpa",
)
tokenizer = LlamaTokenizer.from_pretrained(
    "Johntad110/llama-2-7b-amharic-tokenizer"
)

# The Amharic tokenizer has a different vocabulary size than the base
# Llama-2 checkpoint, so grow/shrink the embedding matrix to match it.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) != embedding_size:
    model.resize_token_embeddings(len(tokenizer))

# Attach the Amharic PEFT (LoRA) adapter on top of the resized base model.
model = PeftModel.from_pretrained(model, "Johntad110/llama-2-amharic-peft")
model.eval()  # Set model to evaluation mode
def generate_text(
    prompt: str,
    max_new_tokens: int = None,
    seed: int = 42,
    do_sample: bool = True,
    min_length: int = None,
    use_cache: bool = True,
    top_p: float = 1.0,
    temperature: float = 1.0,
    top_k: int = 1,
    repetition_penalty: float = 1.0,
    length_penalty: int = 1,
):
    """Generate a continuation of *prompt* with the Amharic Llama-2 model.

    Args:
        prompt: Input text to continue.
        max_new_tokens: Cap on generated tokens (None -> model default).
        seed: RNG seed so sampling is reproducible across calls.
        do_sample: Sample from the distribution instead of greedy decoding.
        min_length: Minimum total output length (None -> model default).
        use_cache: Reuse past key/values during decoding.
        top_p: Nucleus-sampling probability mass.
        temperature: Softmax temperature for sampling.
        top_k: Restrict sampling to the k most likely tokens.
        repetition_penalty: Penalty >1.0 discourages repeated tokens.
        length_penalty: Exponential length penalty used with beam search.

    Returns:
        The decoded generation (includes the prompt), special tokens stripped.
    """
    # Seed both CPU and (when present) CUDA RNGs for reproducible sampling.
    # Guarding the CUDA call keeps this working on CPU-only hosts, where
    # torch.cuda.manual_seed would otherwise raise.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    batch = tokenizer(prompt, return_tensors="pt")
    # Move inputs to wherever device_map="auto" actually placed the model,
    # instead of hard-coding "cuda" (which crashes on CPU-only Spaces).
    device = next(model.parameters()).device
    batch = {k: v.to(device) for k, v in batch.items()}

    with torch.no_grad():  # inference only; no gradients needed
        outputs = model.generate(
            **batch,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
            min_length=min_length,
            use_cache=use_cache,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
        )
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return output_text
# Minimal Gradio UI: a single prompt textbox wired to generate_text; all
# other generation parameters keep their function defaults.
interface = gr.Interface(
    fn=generate_text,
    inputs=[gr.Textbox(label="Prompt")],
    outputs="text",
)
interface.launch(debug=True)  # debug=True surfaces tracebacks in the Space log