|
import torch
|
|
from transformers import LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast
|
|
|
|
|
|
tokenizer = PreTrainedTokenizerFast.from_pretrained(
|
|
"meta-llama/Llama-3.2-1B", legacy=False
|
|
)
|
|
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
|
|
|
|
checkpoint = torch.load(
|
|
r"C:\Users\Michelle\Documents\GitHub\Fine_Tuning_Llama3.2\results\checkpoint-6\rng_state.pth",
|
|
map_location="cpu",
|
|
)
|
|
print(checkpoint.keys())
|
|
|
|
""" model.load_state_dict(
|
|
checkpoint["model_state_dict"]
|
|
) """
|
|
|
|
|
|
|
|
|
|
model.eval()
|
|
|
|
|
|
|
|
def generar_respuesta(pregunta):
|
|
inputs = tokenizer.encode(pregunta, return_tensors="pt")
|
|
outputs = model.generate(inputs, max_length=200, num_return_sequences=1)
|
|
respuesta = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
return respuesta
|
|
|
|
|
|
|
|
pregunta = "No, solo aquellos que cuentan con la autorizaci贸n correspondiente por parte del SAT"
|
|
respuesta = generar_respuesta(pregunta)
|
|
print(f"Pregunta: {pregunta}")
|
|
print(f"Respuesta: {respuesta}")
|
|
|