Pclanglais committed on
Commit
1b0843f
1 Parent(s): 9bd205f

Create inference_classification_transcript.py

Browse files
inference_classification_transcript.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Batch classification of French TV transcripts with a HF text-classification model.

Reads a parquet file containing 'corrected_text' and 'identifier' columns,
scores every text against all labels of the PleIAs French-TV-Headline
classification model, and writes one row per (text, label) pair to
transcript_classification.csv.
"""
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from tqdm.auto import tqdm

# Constants
batch_size = 1000        # rows pulled from the DataFrame per chunk
pipeline_batch_size = 32  # forward-pass batch size inside the HF pipeline

model_checkpoint = "PleIAs/French-TV-Headline-Classification"


def _classify_batch(classifier, batch, start_index):
    """Classify one DataFrame chunk; return a list of result-row dicts.

    Each returned dict carries the global text index ('text_id'), a predicted
    'label', its 'score' as a percentage rounded to 2 decimals, and the source
    row's 'identifier'.
    """
    texts = batch["corrected_text"].tolist()
    # Hoisted once per chunk: the original did a per-entry
    # batch.iloc[text_index]['identifier'] positional lookup in the inner loop.
    identifiers = batch["identifier"].tolist()

    # top_k=None returns the full label/score distribution for every text;
    # batch_size makes the pipeline run genuinely batched forward passes
    # instead of one text per forward pass.
    classifications = classifier(
        texts,
        truncation=True,
        padding=True,
        top_k=None,
        batch_size=pipeline_batch_size,
    )

    rows = []
    for text_index, class_results in enumerate(classifications):
        for entry in class_results:
            rows.append({
                'text_id': start_index + text_index,
                'label': entry['label'],
                'score': round(entry['score'] * 100, 2),
                'identifier': identifiers[text_index],
            })
    return rows


def main():
    """Load the model, classify the whole dataset chunk by chunk, save a CSV."""
    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, model_max_length=512)
    model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
    classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)

    # Read the dataset; reset the index so positional slicing matches row order.
    val_classification = pd.read_parquet("[file]")
    val_classification.reset_index(drop=True, inplace=True)

    # Ceiling division: number of chunks needed to cover every row.
    num_batches = (len(val_classification) + batch_size - 1) // batch_size

    # Collect plain dict rows and build ONE DataFrame at the end, instead of
    # building a DataFrame per batch and concatenating them.
    all_rows = []
    for i in tqdm(range(num_batches), desc="Processing batches"):
        start_index = i * batch_size
        end_index = min((i + 1) * batch_size, len(val_classification))
        batch = val_classification.iloc[start_index:end_index]
        all_rows.extend(_classify_batch(classifier, batch, start_index))

    final_df = pd.DataFrame(all_rows)
    print(final_df)

    # Save the resulting DataFrame to a CSV file.
    final_df.to_csv("transcript_classification.csv", index=False)


# Guard the entry point: the original ran model download, inference, and file
# writes at import time.
if __name__ == "__main__":
    main()