File size: 1,946 Bytes
47cca70
5517288
47cca70
5517288
 
6b1e60c
5517288
fed235b
5517288
 
 
49aa929
5517288
aad79ca
5517288
 
 
0eb0e45
5517288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305c593
8c4f787
 
 
 
5517288
 
 
20b3000
ee8f3ef
 
5517288
 
 
ee8f3ef
 
49aa929
5517288
 
 
 
 
e0015e1
 
fed235b
5517288
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import gradio as gr
import spaces

import transformers
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

# Hugging Face repo of the German-focused Llama-3-8B finetune served by this Space.
model_id = "doubledsbv/Llama-3-Kafka-8B-v0.1"

# bfloat16 halves the memory footprint vs float32; the pipeline below places
# the model on CUDA via device="cuda".
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
# Use the directly-imported AutoTokenizer (consistent with the model load above).
tokenizer = AutoTokenizer.from_pretrained(model_id)

pipeline = transformers.pipeline(
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,  # chat_function strips the echoed prompt itself
    task='text-generation',
    device="cuda",
)

@spaces.GPU
def chat_function(message, history, system_prompt, max_new_tokens, temperature):
    """Generate one assistant reply for the Gradio chat UI.

    Args:
        message: Latest user message.
        history: Prior turns from gr.ChatInterface, as (user, assistant)
            pairs. Previously this was ignored, so the model forgot all
            earlier context; it is now replayed into the prompt.
        system_prompt: System instruction placed first in the conversation.
        max_new_tokens: Cap on generated tokens.
        temperature: Sampling temperature; 0 is bumped to 0.1 because
            do_sample=True with temperature=0 is invalid in transformers.

    Returns:
        The generated completion text with the echoed prompt removed
        (the pipeline is configured with return_full_text=True).
    """
    messages = [{"role": "system", "content": system_prompt}]
    # Replay the conversation so multi-turn context reaches the model.
    # NOTE(review): assumes the default ChatInterface tuple history format —
    # confirm if the app is moved to gradio's type="messages" mode.
    for user_turn, assistant_turn in history:
        messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    prompt = pipeline.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Llama-3 chat ends a turn with <|eot_id|> in addition to the eos token;
    # stop on either so generation does not run past the assistant turn.
    terminators = [
        pipeline.tokenizer.eos_token_id,
        pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]

    # Greedy temperature=0 is incompatible with do_sample=True below.
    if temperature == 0:
        temperature = 0.1

    outputs = pipeline(
        prompt,
        max_new_tokens=max_new_tokens,
        num_beams=3,
        num_return_sequences=1,
        early_stopping=True,  # only meaningful together with beam search
        eos_token_id=terminators,
        do_sample=True,
        temperature=temperature,
        top_p=0.9,
    )
    # return_full_text=True echoes the prompt; slice it off.
    return outputs[0]["generated_text"][len(prompt):]

# Assemble the chat UI; generation knobs are exposed as additional inputs
# that gr.ChatInterface forwards to chat_function after (message, history).
demo = gr.ChatInterface(
    fn=chat_function,
    chatbot=gr.Chatbot(height=500),
    textbox=gr.Textbox(placeholder="Enter message here", container=False, scale=5),
    title="Llama-3-Kafka-8B-v0.1",
    description="""
    German-focused finetuned version of Llama-3-8B
    """,
    additional_inputs=[
        gr.Textbox("Du bist ein freundlicher KI-Assistent", label="System Prompt"),
        gr.Slider(512, 8192, label="Max New Tokens"),
        gr.Slider(0, 1, label="Temperature"),
    ],
)
demo.launch()