bambadij commited on
Commit
57da3c7
·
1 Parent(s): 3fe712e

change model to llm write

Browse files
Files changed (2) hide show
  1. app.py +92 -54
  2. requirements.txt +10 -3
app.py CHANGED
@@ -1,20 +1,29 @@
1
  #load package
2
- from fastapi import FastAPI,HTTPException,status,UploadFile,File
3
  from pydantic import BaseModel
4
- import uvicorn
5
- import logging
6
  import torch
7
- from transformers import T5Tokenizer, T5ForConditionalGeneration
 
 
 
 
 
 
 
 
8
  import os
9
- import numpy as np
 
 
 
10
 
11
  # Configurer les répertoires de cache
12
  os.environ['TRANSFORMERS_CACHE'] = '/app/.cache'
13
  os.environ['HF_HOME'] = '/app/.cache'
14
  # Charger le modèle et le tokenizer
15
- model_name = "t5-base"
16
- tokenizer = T5Tokenizer.from_pretrained(model_name)
17
- model = T5ForConditionalGeneration.from_pretrained(model_name)
18
 
19
  #Additional information
20
 
@@ -34,57 +43,86 @@ app =FastAPI(
34
  logging.basicConfig(level=logging.INFO)
35
  logger =logging.getLogger(__name__)
36
 
 
 
 
 
 
 
 
37
 
38
  default_prompt = """Bonjour,
39
- en tant qu’expert dans la gestion et le traitement de plaintes réseaux chez un opérateur de télécommunications, fais moi un descriptif clair de la situation concernant la plainte dont les informations sont fournies plus bas dans ce message. Ecris la situation en 4 ou 5 phrases claires et concises, fais comme si tu parlais à un humain et rajoutes les informations relatives au Client pour une meilleure connaissance de ce dernier ainsi que des éléments de dates/délais pour être précis sur le traitement de la plainte. N’hésites pas à innover sur le ton à utiliser car n’oublies pas que tu dois faire comme si tu parlais à un humain. Ce ton peut être adapté et ne pas toujours être le même en fonction des cas.
40
- Pour m’éviter de lire tout le détail de la plainte (voir le texte partagé plus bas), essayes de trouver toutes les informations utiles permettant de mieux appréhender la situation, par exemple : si les coordonnées GPS (Lat, Lon) sont disponibles essayes de m'indiquer le lieu où est survenue la plainte même de manière approximative. Essayes également de glaner sur internet toutes les informations pouvant aider à mieux comprendre et traiter la plainte (cela peut inclure des informations des réseaux sociaux, des concurrents, etc.) tout en priorisant dans l’analyse les informations fournies dans le texte plus bas; informations qui ont été renseignées par les experts internes chez l’opérateur de télécommunications en question et qui sont structurées en plusieurs sections :
41
- a) Un titre de la plainte
42
- b) Une section avec les Détails de la Plainte contenant l’objet, le numéro client, l’expéditeur, la date de création, les coordonnées géographiques (lat, lon)
43
- c) Une section avec les états d’avancement incluant les échanges (sous format chat) entre les différents acteurs impliqués dans le traitement de la plainte
44
- d) Une section contenant les éléments relatifs à la qualification de la plainte (type de plainte, origine, domaine, sous-domaine, etc…)
45
- e) Une section avec les fichiers joints à la plainte et autres pièces jointes pour mieux comprendre et trouver une solution à cette plainte en vue de satisfaire le Client
46
-
47
- Dans la situation que tu vas me donner (en quelques 4 ou 5 phrases comme si tu t’adresses à un humain), assures toi que les points cruciaux (voire rédhibitoires) ci-dessous sont bien présents :
48
- 1) Ecris la situation en 4 ou 5 phrases claires et concises, fais comme si tu parlais à un humain
49
- 2) Rajoutes les informations relatives au Client pour être précis sur la connaissance de ce dernier.
50
- 3) Rajoutes des éléments de dates (remontée, transfert, prise en charge, résolution, clôture, etc…) ainsi que les délais (par exemple de réponse des différents acteurs ou experts de la chaine de traitement) pour mieux apprécier l'efficacité du traitement de la plainte.
51
- 4) Rajoutes à la fin une recommandation importante afin d'éviter le mécontentement du Client par exemple pour éviter qu’une Plainte ne soit clôturée sans solution pour le Client notamment et à titre illustratif seulement dans certains cas pour un Client qui a payé pour un service et ne l'a pas obtenu, On ne peut décemment pas clôturer sa plainte sans solution en lui disant d’être plus vigilant, il faut recommander à l’équipe en charge de la plainte de le rembourser ou de trouver un moyen de donner au Client le service pour lequel il a payé (à défaut de le rembourser).
52
- 5) N’hésites pas à innover sur le ton à utiliser car n’oublies pas que tu dois faire comme si tu parlais à un humain. Ce ton peut être adapté et ne pas toujours être le même en fonction des cas.
53
  """
54
- class TextSummary(BaseModel):
55
- prompt:str
56
-
57
-
58
- class ComplaintData(BaseModel):
59
- raw_text: str
60
-
61
- @app.get("/")
62
- async def home():
63
- return 'STN BIG DATA'
64
- # Fonction pour générer du texte à partir d'une requête
65
- def generate_text(prompt: str) -> str:
66
- # Préparer le texte d'entrée pour le modèle
67
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids
68
-
69
- # Générer du texte avec le modèle
70
- outputs = model.generate(input_ids, max_length=1024, num_beams=5, early_stopping=True)
71
-
72
- # Décoder la sortie du modèle
73
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
74
-
75
- return generated_text
76
- # Endpoint de l'API pour la génération de texte
77
- @app.post("/summarize-v2/")
78
- async def generate_text_endpoint(request: TextSummary):
79
- try:
80
- # Appeler la fonction pour générer du texte
81
- generated_text = generate_text(request.prompt)
82
- return {"generated_text": generated_text}
83
- except Exception as e:
84
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  if __name__ == "__main__":
86
  uvicorn.run("app:app",reload=True)
87
 
88
 
89
-
90
 
 
1
  #load package
2
+ from fastapi import FastAPI
3
  from pydantic import BaseModel
 
 
4
  import torch
5
+ from transformers import (
6
+ AutoModelForCausalLM,
7
+ AutoTokenizer,
8
+ StoppingCriteria,
9
+ StoppingCriteriaList,
10
+ TextIteratorStreamer
11
+ )
12
+ from typing import List, Tuple
13
+ from threading import Thread
14
  import os
15
+ from pydantic import BaseModel
16
+ import logging
17
+ import uvicorn
18
+
19
 
20
  # Configurer les répertoires de cache
21
  os.environ['TRANSFORMERS_CACHE'] = '/app/.cache'
22
  os.environ['HF_HOME'] = '/app/.cache'
23
  # Charger le modèle et le tokenizer
24
+ model = AutoModelForCausalLM.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True, device_map='auto')
25
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True)
26
+
27
 
28
  #Additional information
29
 
 
43
  logging.basicConfig(level=logging.INFO)
44
  logger =logging.getLogger(__name__)
45
 
46
+ class StopOnTokens(StoppingCriteria):
47
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
48
+ stop_ids = model.config.eos_token_id
49
+ for stop_id in stop_ids:
50
+ if input_ids[0][-1] == stop_id:
51
+ return True
52
+ return False
53
 
54
  default_prompt = """Bonjour,
55
+
56
+ En tant qu’expert en gestion des plaintes réseaux, rédige un descriptif clair de la plainte ci-dessous. Résume la situation en 4 ou 5 phrases concises, en mettant l'accent sur :
57
+ 1. **Informations Client** : Indique des détails pertinents sur le client.
58
+ 2. **Dates et Délais** : Mentionne les dates clés et les délais (prise en charge, résolution, etc.).
59
+ 3. **Contexte et Détails** : Inclut les éléments essentiels de la plainte (titre, détails, états d’avancement, qualification, fichiers joints).
60
+
61
+ Ajoute une recommandation importante pour éviter le mécontentement du client, par exemple, en cas de service non fourni malgré le paiement. Adapte le ton pour qu'il soit humain et engageant.
62
+
63
+ Merci !
64
+
 
 
 
 
65
  """
66
+ class PredictionRequest(BaseModel):
67
+ history: List[Tuple[str, str]] = []
68
+ prompt: str = default_prompt
69
+ max_length: int = 128000
70
+ top_p: float = 0.8
71
+ temperature: float = 0.6
72
+ @app.post("/generate/")
73
+ async def predict(request: PredictionRequest):
74
+ history = request.history
75
+ prompt = request.prompt
76
+ max_length = request.max_length
77
+ top_p = request.top_p
78
+ temperature = request.temperature
79
+
80
+ stop = StopOnTokens()
81
+ messages = []
82
+ if prompt:
83
+ messages.append({"role": "system", "content": prompt})
84
+ for idx, (user_msg, model_msg) in enumerate(history):
85
+ if prompt and idx == 0:
86
+ continue
87
+ if idx == len(history) - 1 and not model_msg:
88
+ query = user_msg
89
+ break
90
+ if user_msg:
91
+ messages.append({"role": "user", "content": user_msg})
92
+ if model_msg:
93
+ messages.append({"role": "assistant", "content": model_msg})
94
+
95
+ model_inputs = tokenizer.build_chat_input(query, history=messages, role='user').input_ids.to(
96
+ next(model.parameters()).device)
97
+ streamer = TextIteratorStreamer(tokenizer, timeout=600, skip_prompt=True)
98
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
99
+ tokenizer.get_command("<|observation|>")]
100
+ generate_kwargs = {
101
+ "input_ids": model_inputs,
102
+ "streamer": streamer,
103
+ "max_new_tokens": max_length,
104
+ "do_sample": True,
105
+ "top_p": top_p,
106
+ "temperature": temperature,
107
+ "stopping_criteria": StoppingCriteriaList([stop]),
108
+ "repetition_penalty": 1,
109
+ "eos_token_id": eos_token_id,
110
+ }
111
+
112
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
113
+ t.start()
114
+
115
+ generated_text = ""
116
+ for new_token in streamer:
117
+ if new_token and '<|user|>' in new_token:
118
+ new_token = new_token.split('<|user|>')[0]
119
+ if new_token:
120
+ generated_text += new_token
121
+ history[-1][1] = generated_text
122
+
123
+ return {"history": history}
124
  if __name__ == "__main__":
125
  uvicorn.run("app:app",reload=True)
126
 
127
 
 
128
 
requirements.txt CHANGED
@@ -1,13 +1,20 @@
1
  fastapi==0.111.0
2
- torch==2.3.1
3
- transformers==4.44.1
4
  uvicorn==0.30.1
5
  pydantic==2.7.4
6
  pillow==10.3.0
7
  numpy
8
  scipy==1.11.3
9
- sentencepiece==0.2.0
10
  pytesseract==0.3.10
11
  Pillow==10.3.0
12
  BeautifulSoup4==4.12.3
13
  protobuf
 
 
 
 
 
 
 
 
 
1
  fastapi==0.111.0
2
+ torch==2.2.0
3
+ transformers==4.44.0
4
  uvicorn==0.30.1
5
  pydantic==2.7.4
6
  pillow==10.3.0
7
  numpy
8
  scipy==1.11.3
 
9
  pytesseract==0.3.10
10
  Pillow==10.3.0
11
  BeautifulSoup4==4.12.3
12
  protobuf
13
+ spaces==0.29.2
14
+ accelerate==0.33.0
15
+ sentencepiece==0.2.0
16
+ huggingface-hub==0.24.5
17
+ jinja2==3.1.4
18
+ sentence_transformers==3.0.1
19
+ tiktoken==0.7.0
20
+ einops==0.8.0