Spaces:
Runtime error
Runtime error
File size: 1,477 Bytes
dde4d03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import os
from modules.log import logly
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
import torch
from transformers import pipeline, AutoTokenizer, BitsAndBytesConfig
def pipeline_tunnel(model, device, quantization_config=None):
    """Build a Hugging Face text-generation pipeline for ``model``.

    Args:
        model: Model identifier forwarded to ``transformers.pipeline``.
        device: Target device, e.g. ``"cuda"``.
        quantization_config: Optional ``BitsAndBytesConfig``. When given, it
            is forwarded to the model loader so the weights are loaded
            quantized. When ``None`` the model is loaded in float16 exactly
            as before.

    Returns:
        The object returned by ``transformers.pipeline``.
    """
    model_kwargs = {"torch_dtype": torch.float16}
    if quantization_config is None:
        return pipeline(
            "text-generation",
            model=model,
            model_kwargs=model_kwargs,
            device=device,
        )

    # bitsandbytes-quantized weights are placed on their device while loading
    # and cannot be moved afterwards, so placement is requested through
    # ``device_map`` rather than the pipeline's ``device`` argument.
    model_kwargs["quantization_config"] = quantization_config
    return pipeline(
        "text-generation",
        model=model,
        model_kwargs=model_kwargs,
        device_map=device,
    )


def chatbot(input_text, max_new_tokens=1250, temperature=0.7, top_k=50, top_p=0.95, model="google/gemma-2b-it",
            quantization="4-bit", device="cuda"):
    """Generate a single-turn chat reply to ``input_text``.

    Args:
        input_text: The user message.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Nucleus sampling cutoff.
        model: Model identifier to load.
        quantization: ``"4-bit"`` selects 4-bit loading; any other value
            selects 8-bit loading.
        device: Device the model is placed on.

    Returns:
        The generated text with the prompt removed. If anything fails, the
        error is logged and the exception message is returned as the string
        instead of being raised.
    """
    try:
        if quantization == "4-bit":
            quantization_config = BitsAndBytesConfig(load_in_4bit=True)
        else:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        # The quantization config must reach the *model* loader; a tokenizer
        # has no weights to quantize, so passing it there had no effect.
        pipeline_model = pipeline_tunnel(model, device, quantization_config)

        # Reuse the tokenizer the pipeline already loaded for this model
        # instead of loading a second copy.
        tokenizer = pipeline_model.tokenizer

        messages = [{"role": "user", "content": input_text}]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        outputs = pipeline_model(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        # The pipeline output starts with the prompt; keep only the
        # continuation.
        generated_text = outputs[0]["generated_text"][len(prompt):]
        return generated_text
    except Exception as e:
        logly.error(f"Error in GemGPT: {e}")
        return str(e)
|