Spaces:
Runtime error
Runtime error
import logging | |
import re | |
import string | |
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig | |
GEMMA_WORD_PATTERNS = [ | |
"(?<=\*)(.*?)(?=\*)", | |
'(?<=")(.*?)(?=")', | |
] | |
def query_hf( | |
query: str, | |
model: AutoModelForCausalLM, | |
tokenizer: AutoTokenizer, | |
generation_config: dict, | |
device: str, | |
) -> str: | |
"""Queries an LLM model using the Vertex AI API. | |
Args: | |
query (str): Query sent to the Vertex API | |
model (str): Model target by Vertex | |
generation_config (dict): Configurations used by the model | |
Returns: | |
str: Vertex AI text response | |
""" | |
generation_config = GenerationConfig( | |
do_sample=True, | |
max_new_tokens=generation_config["max_output_tokens"], | |
top_k=generation_config["top_k"], | |
top_p=generation_config["top_p"], | |
temperature=generation_config["temperature"], | |
) | |
input_ids = tokenizer(query, return_tensors="pt").to(device) | |
outputs = model.generate(**input_ids, generation_config=generation_config) | |
outputs = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
outputs = outputs.replace(query, "") | |
return outputs | |
def query_word( | |
category: str, | |
model: AutoModelForCausalLM, | |
tokenizer: AutoTokenizer, | |
generation_config: dict, | |
device: str, | |
) -> str: | |
"""Queries a word to be used for the hangman game. | |
Args: | |
category (str): Category used as source sample a word | |
model (str): Model target by Vertex | |
generation_config (dict): Configurations used by the model | |
Returns: | |
str: Queried word | |
""" | |
logger.info(f"Quering word for category: '{category}'...") | |
query = f"Name a single existing {category}." | |
matched_word = "" | |
while not matched_word: | |
word = query_hf(query, model, tokenizer, generation_config, device) | |
logger.info(f"Evaluating result: '{word}'...") | |
# Extract word of interest from Gemma's output | |
for pattern in GEMMA_WORD_PATTERNS: | |
matched_words = re.findall(rf"{pattern}", word) | |
matched_words = [x for x in matched_words if x != ""] | |
if matched_words: | |
matched_word = matched_words[-1] | |
matched_word = matched_word.translate(str.maketrans("", "", string.punctuation)) | |
matched_word = matched_word.lower() | |
logger.info("Word queried successful") | |
return matched_word | |
def query_hint( | |
word: str, | |
model: AutoModelForCausalLM, | |
tokenizer: AutoTokenizer, | |
generation_config: dict, | |
device: str, | |
) -> str: | |
"""Queries a hint for the hangman game. | |
Args: | |
word (str): Word used as source to create the hint | |
model (str): Model target by Vertex | |
generation_config (dict): Configurations used by the model | |
Returns: | |
str: Queried hint | |
""" | |
logger.info(f"Quering hint for word: '{word}'...") | |
query = f"Describe the word '{word}' without mentioning it." | |
hint = query_hf(query, model, tokenizer, generation_config, device) | |
hint = re.sub(re.escape(word), "***", hint, flags=re.IGNORECASE) | |
logger.info("Hint queried successful") | |
return hint | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__file__) | |