nmarafo commited on
Commit
7f034b6
·
verified ·
1 Parent(s): 9496a4d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -124
app.py CHANGED
@@ -1,137 +1,126 @@
1
  import gradio as gr
2
- from datasets import load_dataset
3
-
4
- import os
5
- import spaces
6
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
7
  import torch
8
- from threading import Thread
9
- from sentence_transformers import SentenceTransformer
10
- from datasets import load_dataset
11
- import time
12
-
13
- token = os.environ["HF_TOKEN"]
14
- ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
15
-
16
- dataset = load_dataset("not-lain/wikipedia",revision = "embedded")
17
-
18
- data = dataset["train"]
19
- data = data.add_faiss_index("embeddings") # column name that has the embeddings of the dataset
20
-
21
 
22
- model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
 
23
 
24
- # use quantization to lower GPU usage
25
- bnb_config = BitsAndBytesConfig(
26
- load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
27
- )
28
 
29
- tokenizer = AutoTokenizer.from_pretrained(model_id,token=token)
30
  model = AutoModelForCausalLM.from_pretrained(
31
  model_id,
32
- torch_dtype=torch.bfloat16,
33
- device_map="auto",
34
- quantization_config=bnb_config,
35
  token=token
36
  )
37
- terminators = [
38
- tokenizer.eos_token_id,
39
- tokenizer.convert_tokens_to_ids("<|eot_id|>")
40
- ]
41
-
42
- SYS_PROMPT = """You are an assistant for answering questions.
43
- You are given the extracted parts of a long document and a question. Provide a conversational answer.
44
- If you don't know the answer, just say "I do not know." Don't make up an answer."""
45
-
46
-
47
-
48
- def search(query: str, k: int = 3 ):
49
- """a function that embeds a new query and returns the most probable results"""
50
- embedded_query = ST.encode(query) # embed new query
51
- scores, retrieved_examples = data.get_nearest_examples( # retrieve results
52
- "embeddings", embedded_query, # compare our new embedded query with the dataset embeddings
53
- k=k # get only top k results
54
- )
55
- return scores, retrieved_examples
56
-
57
- def format_prompt(prompt,retrieved_documents,k):
58
- """using the retrieved documents we will prompt the model to generate our responses"""
59
- PROMPT = f"Question:{prompt}\nContext:"
60
- for idx in range(k) :
61
- PROMPT+= f"{retrieved_documents['text'][idx]}\n"
62
- return PROMPT
63
-
64
-
65
- @spaces.GPU(duration=150)
66
- def talk(prompt,history):
67
- k = 1 # number of retrieved documents
68
- scores , retrieved_documents = search(prompt, k)
69
- formatted_prompt = format_prompt(prompt,retrieved_documents,k)
70
- formatted_prompt = formatted_prompt[:2000] # to avoid GPU OOM
71
- messages = [{"role":"system","content":SYS_PROMPT},{"role":"user","content":formatted_prompt}]
72
- # tell the model to generate
73
- input_ids = tokenizer.apply_chat_template(
74
- messages,
75
- add_generation_prompt=True,
76
- return_tensors="pt"
77
- ).to(model.device)
78
- outputs = model.generate(
79
- input_ids,
80
- max_new_tokens=1024,
81
- eos_token_id=terminators,
82
- do_sample=True,
83
- temperature=0.6,
84
- top_p=0.9,
85
- )
86
- streamer = TextIteratorStreamer(
87
- tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
88
- )
89
- generate_kwargs = dict(
90
- input_ids= input_ids,
91
- streamer=streamer,
92
- max_new_tokens=1024,
93
- do_sample=True,
94
- top_p=0.95,
95
- temperature=0.75,
96
- eos_token_id=terminators,
97
- )
98
- t = Thread(target=model.generate, kwargs=generate_kwargs)
99
- t.start()
100
-
101
- outputs = []
102
- for text in streamer:
103
- outputs.append(text)
104
- print(outputs)
105
- yield "".join(outputs)
106
-
107
-
108
- TITLE = "# RAG"
109
-
110
- DESCRIPTION = """
111
- A rag pipeline with a chatbot feature
112
- Resources used to build this project :
113
- * embedding model : https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
114
- * dataset : https://huggingface.co/datasets/not-lain/wikipedia
115
- * faiss docs : https://huggingface.co/docs/datasets/v2.18.0/en/package_reference/main_classes#datasets.Dataset.add_faiss_index
116
- * chatbot : https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct
117
- * Full documentation : https://huggingface.co/blog/not-lain/rag-chatbot-using-llama3
118
- """
119
 
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  demo = gr.ChatInterface(
122
- fn=talk,
123
- chatbot=gr.Chatbot(
124
- show_label=True,
125
- show_share_button=True,
126
- show_copy_button=True,
127
- likeable=True,
128
- layout="bubble",
129
- bubble_full_width=False,
130
- ),
131
- theme="Soft",
132
- examples=[["what's anarchy ? "]],
133
- title=TITLE,
134
- description=DESCRIPTION,
135
-
136
  )
137
- demo.launch(debug=True)
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
3
  import torch
4
+ import os
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ # Cargar el token de Hugging Face desde los secretos
7
+ token = os.environ.get("HF_TOKEN")
8
 
9
+ # Cargar el modelo y el tokenizador sin cuantización (usando CPU)
10
+ model_id = "PrunaAI/google-shieldgemma-2b-bnb-4bit-smashed"
11
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
 
12
 
13
+ # Cargar el modelo y forzar el uso de CPU
14
  model = AutoModelForCausalLM.from_pretrained(
15
  model_id,
16
+ device_map="cpu", # Forzar uso de CPU
17
+ torch_dtype=torch.float32, # Cambiar el tipo de datos para CPU
 
18
  token=token
19
  )
20
+
21
+ # Función para generar el prompt dependiendo del idioma seleccionado
22
+ def generar_prompt(message, tipo_clasificacion, idioma):
23
+ if idioma == "Español":
24
+ politicas = {
25
+ "Contenido Peligroso": """
26
+ * "No al contenido peligroso": El mensaje no debe contener ni buscar la generación de contenido que cause daño a uno mismo o a otros (por ejemplo, acceder o fabricar armas de fuego o explosivos, promoción del terrorismo, instrucciones para el suicidio).
27
+ """,
28
+ "Acoso": """
29
+ * "No al acoso": El mensaje no debe contener ni buscar la generación de contenido que sea malicioso, intimidante, que fomente el bullying o abusivo dirigido a otra persona (por ejemplo, amenazas físicas, negación de eventos trágicos, menosprecio a víctimas de violencia).
30
+ """,
31
+ "Discurso de Odio": """
32
+ * "No al discurso de odio": El mensaje no debe contener ni buscar la generación de contenido que ataque la identidad o atributos protegidos (por ejemplo, insultos raciales, promoción de la discriminación, llamados a la violencia contra grupos protegidos) o que deshumanice, menosprecie o vilipendie por motivos de raza, etnia, religión, discapacidad, edad, nacionalidad, estatus de veterano, orientación sexual, género, identidad de género, casta, u otro estatus protegido.
33
+ """,
34
+ "Contenido Sexual": """
35
+ * "No al contenido sexual explícito": El mensaje no debe contener ni buscar la generación de contenido que haga referencia a actos sexuales u otro contenido lascivo (por ejemplo, descripciones gráficas sexuales, contenido destinado a causar excitación). Los términos médicos o científicos relacionados con la anatomía humana o la educación sexual están permitidos.
36
+ """
37
+ }
38
+ else: # English
39
+ politicas = {
40
+ "Dangerous Content": """
41
+ * "No Dangerous Content": The prompt shall not contain or seek generation of content that harms oneself and/or others (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide).
42
+ """,
43
+ "Harassment": """
44
+ * "No Harassment": The prompt shall not contain or seek generation of content that is malicious, intimidating, bullying, or abusive content targeting another individual (e.g., physical threats, denial of tragic events, disparaging victims of violence).
45
+ """,
46
+ "Hate Speech": """
47
+ * "No Hate Speech": The prompt shall not contain or seek generation of content that targets identity and/or protected attributes (e.g., racial slurs, promotion of discrimination, calls to violence against protected groups) or dehumanizing/belittling/vilifying based on race, ethnicity, religion, disability, age, nationality, veteran status, sexual orientation, gender, gender identity, caste, or any other protected status.
48
+ """,
49
+ "Sexually Explicit Information": """
50
+ * "No Sexually Explicit Information": The prompt shall not contain or seek generation of content that contains references to sexual acts or other lewd content (e.g., sexually graphic descriptions, content aimed at causing arousal). Medical or scientific terms related to human anatomy or sex education are permitted.
51
+ """
52
+ }
53
+
54
+ # Formatear el prompt según el tipo de clasificación seleccionado
55
+ politica_seguridad = politicas.get(tipo_clasificacion, "")
56
+
57
+ if idioma == "Español":
58
+ prompt = f"""Eres un experto en políticas tratando de determinar si un mensaje del usuario
59
+ viola las políticas de seguridad definidas.
60
+ <inicio_turno>
61
+ Pregunta del usuario: {message.strip()}
62
+ <fin_turno>
63
+ Nuestro principio de seguridad se define a continuación:
64
+ {politica_seguridad.strip()}
65
+ ¿La pregunta del usuario viola el principio anterior? Tu respuesta debe comenzar con
66
+ 'Sí' o 'No'. Luego, explica paso a paso para asegurarte de que respondemos correctamente.
67
+ """
68
+ else: # Default to English
69
+ prompt = f"""You are a policy expert trying to help determine whether a user prompt is in
70
+ violation of the defined safety policies.
71
+ <start_of_turn>
72
+ Human Question: {message.strip()}
73
+ <end_of_turn>
74
+ Our safety principle is defined below:
75
+ {politica_seguridad.strip()}
76
+ Does the human question violate the above principle? Your answer must start with
77
+ 'Yes' or 'No'. Then walk through step by step to be sure we answer correctly.
78
+ """
79
+
80
+ return prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
 
83
+ # Función para procesar la respuesta y clasificar según las políticas de seguridad
84
+ def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p, language, harm_type):
85
+ prompt = generar_prompt(message, harm_type, language)
86
+
87
+ inputs = tokenizer(prompt, return_tensors="pt").to("cpu") # Forzar a CPU
88
+
89
+ with torch.no_grad():
90
+ logits = model(**inputs).logits
91
+
92
+ # Extraer los logits para los tokens 'Yes'/'No' en inglés o 'Sí'/'No' en español
93
+ vocab = tokenizer.get_vocab()
94
+ if language == "Español":
95
+ selected_logits = logits[0, -1, [vocab['Sí'], vocab['No']]]
96
+ else:
97
+ selected_logits = logits[0, -1, [vocab['Yes'], vocab['No']]]
98
+
99
+ # Convertir los logits en una probabilidad con softmax
100
+ probabilities = torch.softmax(selected_logits, dim=0)
101
+
102
+ # Devolver la probabilidad de 'Sí'/'Yes' y la respuesta generada
103
+ score_yes = probabilities[0].item()
104
+
105
+ if language == "Español":
106
+ response = f"Puntuación para 'Sí' (violación): {score_yes:.4f}"
107
+ else:
108
+ response = f"Score for 'Yes' (violation): {score_yes:.4f}"
109
+
110
+ return response
111
+
112
+ # Crear la interfaz de Gradio con selección de idioma y tipo de contenido
113
  demo = gr.ChatInterface(
114
+ respond,
115
+ additional_inputs=[
116
+ gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
117
+ gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
118
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
119
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
120
+ gr.Dropdown(choices=["English", "Español"], value="English", label="Idioma/Language"),
121
+ gr.Dropdown(choices=["Dangerous Content", "Harassment", "Hate Speech", "Sexually Explicit Information"], value="Harassment", label="Harm Type")
122
+ ],
 
 
 
 
 
123
  )
124
+
125
+
126
+ demo.launch(debug=True)