rss9051 commited on
Commit
19ed2fd
·
verified ·
1 Parent(s): 4d9e68c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -9
app.py CHANGED
@@ -8,14 +8,15 @@ model_name = "rss9051/autotrein-BERT-iiLEX-dgs-0004"
8
  client = InferenceClient(model=model_name)
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
 
11
- # Função para dividir texto em chunks com base na tokenização
12
  def split_text_into_chunks(text, max_tokens=512):
13
  tokens = tokenizer(text, return_tensors="pt", truncation=False)["input_ids"][0]
14
  chunks = []
15
  for i in range(0, len(tokens), max_tokens):
16
- chunk = tokens[i:i + max_tokens] # Garantir que cada chunk tenha no máximo max_tokens
 
 
17
  chunks.append(chunk)
18
- # Decodificar os chunks de volta para texto
19
  return [tokenizer.decode(chunk, skip_special_tokens=True) for chunk in chunks]
20
 
21
  # Função para classificar texto longo
@@ -24,13 +25,16 @@ def classify_text(text):
24
  all_responses = [] # Lista para armazenar respostas de cada chunk
25
 
26
  for chunk in chunks:
27
- response_bytes = client.post(json={"inputs": chunk}) # Enviar o chunk
28
- response_str = response_bytes.decode('utf-8') # Decodificar de bytes para string
29
- response = json.loads(response_str) # Converter string JSON para objeto Python
 
30
 
31
- if isinstance(response, list) and len(response) > 0:
32
- sorted_response = sorted(response[0], key=lambda x: x['score'], reverse=True)
33
- all_responses.append(sorted_response[0]) # Adicionar a melhor classificação do chunk
 
 
34
 
35
  # Combinar resultados de todos os chunks
36
  if all_responses:
 
8
  client = InferenceClient(model=model_name)
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
 
11
+ # Função para dividir texto em chunks com truncamento garantido
12
  def split_text_into_chunks(text, max_tokens=512):
13
  tokens = tokenizer(text, return_tensors="pt", truncation=False)["input_ids"][0]
14
  chunks = []
15
  for i in range(0, len(tokens), max_tokens):
16
+ chunk = tokens[i:i + max_tokens]
17
+ if len(chunk) > max_tokens:
18
+ chunk = chunk[:max_tokens] # Truncar qualquer excesso
19
  chunks.append(chunk)
 
20
  return [tokenizer.decode(chunk, skip_special_tokens=True) for chunk in chunks]
21
 
22
  # Função para classificar texto longo
 
25
  all_responses = [] # Lista para armazenar respostas de cada chunk
26
 
27
  for chunk in chunks:
28
+ try:
29
+ response_bytes = client.post(json={"inputs": chunk}) # Enviar o chunk
30
+ response_str = response_bytes.decode('utf-8') # Decodificar de bytes para string
31
+ response = json.loads(response_str) # Converter string JSON para objeto Python
32
 
33
+ if isinstance(response, list) and len(response) > 0:
34
+ sorted_response = sorted(response[0], key=lambda x: x['score'], reverse=True)
35
+ all_responses.append(sorted_response[0]) # Adicionar a melhor classificação do chunk
36
+ except Exception as e:
37
+ print(f"Erro ao processar chunk: {e}")
38
 
39
  # Combinar resultados de todos os chunks
40
  if all_responses: