import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
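
# Prompts longer than this are trimmed to their most recent tokens before generation.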
MAX_INPUT_TOKEN_LENGTH = 8192
DESCRIPTION = """\
# CataLlama-v0.1-Instruct-DPO
This Space demonstrates the [CataLlama-v0.1-Instruct-DPO](https://huggingface.co/catallama/CataLlama-v0.1-Instruct-DPO) model.
CataLlama is a fine-tune of Llama-3-8B that enhances its proficiency in the Catalan language.
The model is capable of performing the following **tasks in Catalan**:
- Translation from English to Catalan and Catalan to English
- Summarization - both short form and long form
- Information extraction (suitable for RAG)
- Named Entity Recognition (NER)
- Open question answering
- Sentiment analysis
"""
LICENSE = """\
As a derivative work of [Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) by Meta, this demo is governed by the original [Llama 3 license](https://llama.meta.com/llama3/license).
"""
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

if torch.cuda.is_available():
    model_id = "catallama/CataLlama-v0.1-Instruct-DPO"
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
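
# ZeroGPU: request a GPU for up to 120 seconds per call to generate().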
@spaces.GPU(duration=120)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> Iterator[str]:
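    # Rebuild the full chat as a list of role/content messages, trim it to the last
    # MAX_INPUT_TOKEN_LENGTH tokens, then stream the model's reply as it is generated.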
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(
            value="Ets un chatbot amigable. Responeu preguntes i ajudeu els usuaris.",
            label="System message",
            lines=6,
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=1024,
            step=256,
            label="Max new tokens",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.3,
            step=0.05,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.90,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    examples=[
        ["A quina velocitat poden volar els cocodrils?"],
        ["Explica pas a pas com resoldre l'equació següent: 2x + 10 = 0"],
        ["Pot Donald Trump sopar amb Juli Cèsar?"],
    ],
)
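
# Page layout: model description, the chat interface, then the license notice.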
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()