ComeBien_Demo / app.py
rovi27's picture
Update app.py
b613f1c verified
raw
history blame
4.69 kB
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# !python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'"
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer, StoppingCriteria, StoppingCriteriaList, GenerationConfig
import os
# --- Model selection -------------------------------------------------------
# `sft_model` is the fine-tuned adapter repository and `base_model_name` is the
# base checkpoint it is applied on top of. The commented-out lines are
# alternative checkpoints kept for reference.
#sft_model = "somosnlp/gemma-FULL-RAC-Colombia_v2"
#sft_model = "somosnlp/RecetasDeLaAbuela_mistral-7b-instruct-v0.2-bnb-4bit"
#base_model_name = "unsloth/Mistral-7B-Instruct-v0.2"
sft_model = "somosnlp/RecetasDeLaAbuela_gemma-2b-it-bnb-4bit"
base_model_name = "unsloth/gemma-2b-it-bnb-4bit"
# 4-bit NF4 quantization settings with bfloat16 compute.
# NOTE(review): `bnb_config` is not passed to any `from_pretrained` call in the
# visible code, so it currently has no effect -- confirm whether that is intended.
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
# Passed to the tokenizer below as `max_length`.
max_seq_length=400
# Disabled Flash Attention auto-detection (would require compute capability >= 8).
# if torch.cuda.get_device_capability()[0] >= 8:
# # print("Flash Attention")
# attn_implementation="flash_attention_2"
# else:
# attn_implementation=None
# NOTE(review): `attn_implementation` is only referenced by a commented-out
# loader line below, so this assignment is unused in the visible code.
attn_implementation=None
# Load the base model in float16 and let `device_map="auto"` decide placement.
#base_model = AutoModelForCausalLM.from_pretrained(model_name,return_dict=True,torch_dtype=torch.float16,)
base_model = AutoModelForCausalLM.from_pretrained(base_model_name,return_dict=True,device_map="auto", torch_dtype=torch.float16,)
#base_model = AutoModelForCausalLM.from_pretrained(base_model_name, return_dict=True, device_map = {"":0}, attn_implementation = attn_implementation,).eval()
# The tokenizer comes from the base checkpoint, not from the adapter repository.
tokenizer = AutoTokenizer.from_pretrained(base_model_name, max_length = max_seq_length)
# Wrap the base model with the fine-tuned adapter; the result of
# `merge_and_unload()` is the `model` object used by the rest of this file.
ft_model = PeftModel.from_pretrained(base_model, sft_model)
model = ft_model.merge_and_unload()
# Persist the merged model and the tokenizer into the current working directory.
model.save_pretrained(".")
#model.to('cuda')
tokenizer.save_pretrained(".")
class ListOfTokensStoppingCriteria(StoppingCriteria):
    """
    Stopping criterion that fires when the sequence ends with any stop token.

    Each stop string is encoded once at construction time; on every call the
    tail of the first sequence in the batch is compared against each encoded
    stop sequence.
    """

    def __init__(self, tokenizer, stop_tokens):
        """Store the tokenizer and pre-encode every stop string to token ids."""
        self.tokenizer = tokenizer
        # One list of token ids per stop string, encoded without special tokens.
        encoded_stops = []
        for stop_token in stop_tokens:
            encoded_stops.append(tokenizer.encode(stop_token, add_special_tokens=False))
        self.stop_token_ids_list = encoded_stops

    def __call__(self, input_ids, scores, **kwargs):
        """Return True when the first sequence ends with one of the stop sequences."""

        def _ends_with(candidate_ids):
            # A stop sequence longer than the generated text can never match.
            width = len(candidate_ids)
            if len(input_ids[0]) < width:
                return False
            return input_ids[0, -width:].tolist() == candidate_ids

        return any(_ends_with(candidate_ids) for candidate_ids in self.stop_token_ids_list)
# Set up the custom stopping criterion used during generation
stop_tokens = ["<end_of_turn>"] # List of stop-token strings
# Build the criterion from the tokenizer and the list of stop tokens
stopping_criteria = ListOfTokensStoppingCriteria(tokenizer, stop_tokens)
# Wrap it in a StoppingCriteriaList, the form passed to `model.generate`
stopping_criteria_list = StoppingCriteriaList([stopping_criteria])
def generate_text(prompt, context, max_length=2100):
    """Generate a model reply for a user prompt and a system context.

    The prompt is stripped of newlines and question marks, wrapped together
    with the context in the chat template below, tokenized, and sampled from
    the module-level `model`.

    Args:
        prompt: User question. Newlines, "¿" and "?" are removed before the
            template adds its own question marks.
        context: System message placed in the first turn of the template.
        max_length: Upper bound on newly generated tokens (forwarded as
            `max_new_tokens`).

    Returns:
        The decoded text of the first returned sequence, with special tokens
        kept in the output.
    """
    prompt = prompt.replace("\n", "").replace("¿", "").replace("?", "")
    input_text = f'''<bos><start_of_turn>system ¿{context}?<end_of_turn><start_of_turn>user ¿{prompt}?<end_of_turn><start_of_turn>model'''
    # Place the inputs on whatever device the model actually lives on: the
    # model is loaded with device_map="auto", so hard-coding "cuda:0" breaks
    # on CPU-only hosts or when the weights end up on a different device.
    inputs = tokenizer.encode(
        input_text, return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    generation_config = GenerationConfig(
        max_new_tokens=max_length,
        temperature=0.32,
        # top_p=0.9,
        top_k=50,  # 45
        repetition_penalty=1.04,  # 1.1
        do_sample=True,
    )
    outputs = model.generate(
        generation_config=generation_config,
        input_ids=inputs,
        stopping_criteria=stopping_criteria_list,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=False)  # True
def mostrar_respuesta(pregunta, contexto):
    """Gradio callback: return the generated answer, or the error text on failure.

    Calls `generate_text` with a 500-token budget. Any exception is caught at
    this UI boundary and its message is returned as the displayed response.
    """
    try:
        return str(generate_text(pregunta, contexto, max_length=500))
    except Exception as error:
        return str(error)
# Example questions offered in the UI. Each row supplies a value for the first
# input ("Pregunta") only.
mis_ejemplos = [
    ["¿Dime la receta de la tortilla de patatatas?"],
    ["¿Dime la receta del ceviche?"],
    ["¿Como se cocinan unos autenticos frijoles?"],
]

# Gradio interface: two text inputs (question and system context) and one
# text output. The keyword for example rows is `examples`; the previous
# `ejemplos=` is not a parameter of gr.Interface.
iface = gr.Interface(
    fn=mostrar_respuesta,
    inputs=[
        gr.Textbox(label="Pregunta"),
        gr.Textbox(
            label="Contexto",
            value="You are a helpful AI assistant. Eres un experto cocinero hispanoamericano.",
        ),
    ],
    outputs=[gr.Textbox(label="Respuesta", lines=2)],
    title="Recetas de la Abuel@",
    description="Introduce tu pregunta sobre recetas de cocina.",
    examples=mis_ejemplos,
)

# Queue up to 14 pending requests, then start the app.
iface.queue(max_size=14).launch()  # share=True,debug=True