import os

from modules.log import logly

os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
import torch
from transformers import pipeline, AutoTokenizer, BitsAndBytesConfig


def pipeline_tunnel(model, device, quantization_config=None):
    """Build a float16 text-generation pipeline, optionally quantized."""
    model_kwargs = {"torch_dtype": torch.float16}
    if quantization_config is not None:
        # bitsandbytes-quantized models manage their own device placement,
        # and transformers rejects an explicit `device` for them, so pass
        # the config through model_kwargs and omit `device`.
        model_kwargs["quantization_config"] = quantization_config
        return pipeline("text-generation", model=model, model_kwargs=model_kwargs)
    return pipeline(
        "text-generation",
        model=model,
        model_kwargs=model_kwargs,
        device=device,
    )


def chatbot(input_text, max_new_tokens=1250, temperature=0.7, top_k=50, top_p=0.95, model="google/gemma-2b-it",
            quantization="4-bit", device="cuda"):
    try:
        # Choose 4-bit or 8-bit weight quantization via bitsandbytes.
        if quantization == "4-bit":
            quantization_config = BitsAndBytesConfig(load_in_4bit=True)
        else:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        # Quantization applies to the model weights, not the tokenizer,
        # so the config is passed to the pipeline below instead.
        tokenizer = AutoTokenizer.from_pretrained(model)

        pipeline_model = pipeline_tunnel(model, device, quantization_config)
        # Wrap the user input in the model's chat template.
        messages = [{"role": "user", "content": input_text}]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        outputs = pipeline_model(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )
        # The pipeline echoes the prompt; slice it off to return only the reply.
        generated_text = outputs[0]["generated_text"][len(prompt):]
        return generated_text
    except Exception as e:
        logly.error(f"Error in GemGPT: {e}")
        return str(e)
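

# A minimal usage sketch, not part of the module's original surface: it
# assumes a CUDA-capable GPU plus the accelerate and bitsandbytes packages,
# and that the defaults (google/gemma-2b-it, 4-bit) are what you want.
if __name__ == "__main__":
    reply = chatbot("In one sentence, what does 4-bit quantization do?")
    print(reply)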