doubledsbv committed on
Commit
5517288
1 Parent(s): 47cca70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -1
app.py CHANGED
@@ -1,3 +1,60 @@
1
  import gradio as gr
 
 
2
 
3
- gr.load("models/doubledsbv/Llama-3-Kafka-8B-v0.3").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import gradio as gr
import spaces
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face model repository served by this Space.
model_id = "doubledsbv/Llama-3-Kafka-8B-v0.3"

# bfloat16 halves memory versus fp32 and is well supported on modern GPUs.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Shared text-generation pipeline used by every chat request.
# BUG FIX: the original passed device="gpu", which transformers rejects
# (valid values are an int index, "cuda", "cpu", or a torch.device).
# Fall back to CPU so startup also works before @spaces.GPU grants a device.
generate_text = transformers.pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
21
+
22
@spaces.GPU
def chat_function(message, history, system_prompt, max_new_tokens, temperature):
    """Generate one assistant reply for the Gradio ChatInterface.

    Args:
        message: Latest user message.
        history: Prior turns supplied by gr.ChatInterface (unused: each
            reply is built from the system prompt and current message only).
        system_prompt: System instruction prepended to the conversation.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature passed to the pipeline.

    Returns:
        The generated reply text with the prompt prefix stripped.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": message},
    ]
    # BUG FIX: the original referenced an undefined name `pipeline`
    # (NameError at call time); the module-level pipeline is `generate_text`.
    prompt = generate_text.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Llama-3 chat models end a turn with <|eot_id|> in addition to eos.
    terminators = [
        generate_text.tokenizer.eos_token_id,
        generate_text.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    outputs = generate_text(
        prompt,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        do_sample=True,
        temperature=temperature,
        top_p=0.9,
    )
    # return_full_text=True means the output includes the prompt; slice it off.
    return outputs[0]["generated_text"][len(prompt):]
46
+
47
+ gr.ChatInterface(
48
+ chat_function,
49
+ chatbot=gr.Chatbot(height=400),
50
+ textbox=gr.Textbox(placeholder="Enter message here", container=False, scale=7),
51
+ title="Llama-3-Kafka-8B-v0.3",
52
+ description="""
53
+ German-focused finetuned version of Llama-3-8B
54
+ """,
55
+ additional_inputs=[
56
+ gr.Textbox("Du bist ein freundlicher KI-Assistent", label="System Prompt"),
57
+ gr.Slider(512, 8192, label="Max New Tokens"),
58
+ gr.Slider(0, 1, label="Temperature")
59
+ ]
60
+ ).launch()