Rakib Hossain
Add complete Bangla sentiment analysis: data, fine-tuned model, and visualizations
49c214c
| """ | |
| Compare multiple sentiment analysis models for Bangla | |
| Shows which model performs best | |
| """ | |
| from transformers import pipeline | |
| import pandas as pd | |
| from sklearn.metrics import accuracy_score, classification_report | |
| import time | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |


class ModelComparison:
    def __init__(self):
        self.models = {
            'Multilingual Sentiment': 'tabularisai/multilingual-sentiment-analysis',
            'XLM-RoBERTa Base': 'cardiffnlp/twitter-xlm-roberta-base-sentiment',
            'mBERT Sentiment': 'nlptown/bert-base-multilingual-uncased-sentiment',
        }
        self.results = {}

    def load_test_data(self, csv_file='data/raw/bangla_news_labeled.csv'):
        """Load manually labeled test data."""
        df = pd.read_csv(csv_file)
        return df['text'].tolist(), df['sentiment'].tolist()

    @staticmethod
    def normalize_label(raw_label):
        """Map model-specific output labels onto positive/neutral/negative.

        The three models use different label schemes (nlptown, for example,
        emits star ratings such as '1 star' .. '5 stars'), so raw labels
        must be normalized before comparing against the gold labels.
        """
        label = raw_label.lower()
        if 'star' in label:  # nlptown star ratings
            stars = int(label.split()[0])
            if stars <= 2:
                return 'negative'
            if stars == 3:
                return 'neutral'
            return 'positive'
        if 'negative' in label:  # covers 'negative' and 'very negative'
            return 'negative'
        if 'positive' in label:  # covers 'positive' and 'very positive'
            return 'positive'
        return 'neutral'

    def evaluate_model(self, model_name, model_id, texts, true_labels):
        """Evaluate a single model on the first 100 samples."""
        print(f"\n🔍 Testing: {model_name}")
        print("-" * 60)
        try:
            # Load model
            classifier = pipeline("sentiment-analysis", model=model_id)

            # Predict
            predictions = []
            start_time = time.time()
            for text in texts[:100]:  # Test on first 100 samples
                try:
                    # Rough character-level truncation for long inputs
                    result = classifier(text[:512])[0]
                    predictions.append(self.normalize_label(result['label']))
                except Exception:
                    predictions.append('neutral')  # Fall back on failure
            end_time = time.time()

            # Calculate metrics
            accuracy = accuracy_score(true_labels[:len(predictions)], predictions)
            avg_time = (end_time - start_time) / len(predictions)

            print(f"✅ Accuracy: {accuracy:.4f}")
            print(f"⏱️ Avg time per prediction: {avg_time:.4f}s")

            self.results[model_name] = {
                'accuracy': accuracy,
                'avg_time': avg_time,
                'predictions': predictions,
            }
            return True
        except Exception as e:
            print(f"❌ Error: {e}")
            return False

    def compare_all_models(self):
        """Compare all models on the same test set."""
        print("=" * 60)
        print("🏆 MODEL COMPARISON FOR BANGLA SENTIMENT ANALYSIS")
        print("=" * 60)

        # Load test data
        texts, true_labels = self.load_test_data()

        # Test each model
        for model_name, model_id in self.models.items():
            self.evaluate_model(model_name, model_id, texts, true_labels)
            time.sleep(2)  # Prevent rate limiting

        # Summary
        self.print_summary()
        self.plot_comparison()

    def print_summary(self):
        """Print comparison summary and save it to CSV."""
        print("\n" + "=" * 60)
        print("📊 COMPARISON SUMMARY")
        print("=" * 60)

        df_results = pd.DataFrame(self.results).T
        print(df_results[['accuracy', 'avg_time']])

        # Find best model (cast: the transpose leaves object-dtype columns)
        best_model = df_results['accuracy'].astype(float).idxmax()
        print(f"\n🏆 Best Model: {best_model}")
        print(f"   Accuracy: {df_results.loc[best_model, 'accuracy']:.4f}")

        # Save results
        os.makedirs('outputs', exist_ok=True)
        df_results.to_csv('outputs/model_comparison_results.csv')
        print("\n💾 Results saved to outputs/model_comparison_results.csv")

    def plot_comparison(self):
        """Create comparison visualizations."""
        # The transpose leaves object-dtype columns (predictions are lists),
        # so keep only the numeric columns and cast before plotting.
        df = pd.DataFrame(self.results).T[['accuracy', 'avg_time']].astype(float)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

        # Plot 1: Accuracy
        df['accuracy'].plot(kind='bar', ax=ax1, color='skyblue', edgecolor='black')
        ax1.set_title('Model Accuracy Comparison', fontsize=14, fontweight='bold')
        ax1.set_ylabel('Accuracy', fontsize=12)
        ax1.set_xlabel('Model', fontsize=12)
        ax1.set_ylim(0, 1)
        ax1.grid(axis='y', alpha=0.3)
        plt.setp(ax1.xaxis.get_majorticklabels(), rotation=45, ha='right')

        # Plot 2: Speed
        df['avg_time'].plot(kind='bar', ax=ax2, color='lightcoral', edgecolor='black')
        ax2.set_title('Average Prediction Time', fontsize=14, fontweight='bold')
        ax2.set_ylabel('Time (seconds)', fontsize=12)
        ax2.set_xlabel('Model', fontsize=12)
        ax2.grid(axis='y', alpha=0.3)
        plt.setp(ax2.xaxis.get_majorticklabels(), rotation=45, ha='right')

        plt.tight_layout()
        os.makedirs('outputs', exist_ok=True)
        plt.savefig('outputs/model_comparison.png', dpi=300, bbox_inches='tight')
        print("📊 Visualization saved to outputs/model_comparison.png")


def main():
    comparator = ModelComparison()
    comparator.compare_all_models()


if __name__ == "__main__":
    main()
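
For reference, load_test_data() expects the labeled file at data/raw/bangla_news_labeled.csv to provide a text column and a sentiment column whose values use the same lowercase scheme the predictions are normalized to (positive / neutral / negative). A minimal sketch of that layout, with placeholder rows rather than real dataset entries:

    text,sentiment
    "<Bangla news sentence>",positive
    "<Bangla news sentence>",negative
    "<Bangla news sentence>",neutral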