# MLR-Copilot / example / ex2_init.py
# Lim0011's picture
# Upload 2 files
# 960d190 verified
import pandas as pd
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
import numpy as np
import random
import torch
from sklearn.model_selection import train_test_split
# Seed value shared by every random number generator seeded below.
SEED = 42

# Score dimensions; their order fixes the column order used when indexing
# target/prediction rows and when naming the submission columns.
DIMENSIONS = [
    "cohesion",
    "syntax",
    "vocabulary",
    "phraseology",
    "grammar",
    "conventions",
]

# Seed NumPy, PyTorch and the stdlib generator with the same value.
np.random.seed(SEED)
torch.manual_seed(SEED)
random.seed(SEED)
def compute_metrics_for_regression(y_test, y_test_pred, dimensions=None):
    """Compute the root-mean-squared error of each score dimension.

    Args:
        y_test: 2-D array-like of true scores, one row per sample. Column
            ``i`` holds the score for ``dimensions[i]``.
        y_test_pred: 2-D array-like of predicted scores laid out like
            ``y_test``.
        dimensions: Optional sequence of dimension names. Defaults to the
            module-level ``DIMENSIONS`` list.

    Returns:
        dict mapping ``"rmse_<dimension>"`` to that column's RMSE as a
        plain ``float``, in the order of ``dimensions``.

    Raises:
        ValueError: If an input is not a non-empty 2-D array with at least
            ``len(dimensions)`` columns, or the row counts differ.
    """
    dims = DIMENSIONS if dimensions is None else list(dimensions)
    n_dims = len(dims)

    targets = np.asarray(y_test, dtype=float)
    preds = np.asarray(y_test_pred, dtype=float)

    for label, arr in (("y_test", targets), ("y_test_pred", preds)):
        if arr.ndim != 2 or arr.shape[0] == 0 or arr.shape[1] < n_dims:
            raise ValueError(
                f"{label} must be a non-empty 2-D array with at least "
                f"{n_dims} columns, got shape {arr.shape}"
            )
    if targets.shape[0] != preds.shape[0]:
        raise ValueError(
            "y_test and y_test_pred have different numbers of rows: "
            f"{targets.shape[0]} != {preds.shape[0]}"
        )

    # RMSE is computed directly with NumPy rather than through
    # mean_squared_error(..., squared=False): that keyword was deprecated
    # and then removed in newer scikit-learn releases, so the call fails there.
    squared_errors = (targets[:, :n_dims] - preds[:, :n_dims]) ** 2
    rmse_per_dim = np.sqrt(squared_errors.mean(axis=0))

    return {f"rmse_{name}": float(rmse_per_dim[i]) for i, name in enumerate(dims)}
def train_model(X_train, y_train, X_valid, y_valid):
    """Stub training routine: no model is fitted yet.

    The four arguments are accepted but not used; the function always
    returns ``None`` until a real training procedure is filled in.
    """
    # Placeholder for model training
    return None
def predict(model, X):
    """Return a placeholder prediction matrix filled with random numbers.

    ``model`` is ignored. The result has one row per element of ``X`` and
    one column per entry of ``DIMENSIONS``, with values drawn from
    ``np.random.rand``.
    """
    n_samples = len(X)
    n_targets = len(DIMENSIONS)
    return np.random.rand(n_samples, n_targets)
if __name__ == '__main__':
    # Load the training data, replacing the CSV header row with explicit
    # column names and using text_id as the index.
    ellipse_df = pd.read_csv(
        'train.csv',
        header=0,
        names=['text_id', 'full_text', 'Cohesion', 'Syntax',
               'Vocabulary', 'Phraseology', 'Grammar', 'Conventions'],
        index_col='text_id',
    )
    # Discard every row that has a missing value.
    ellipse_df = ellipse_df.dropna(axis=0)
    data_df = ellipse_df

    # Inputs are the raw essay texts; targets are all remaining columns as
    # one 2-D array. The frame is dropped/converted once instead of being
    # copied again for every row.
    X = data_df['full_text'].tolist()
    y = data_df.drop(columns=['full_text']).to_numpy()

    # Hold out 10% of the rows for validation.
    X_train, X_valid, y_train, y_valid = train_test_split(
        X, y, test_size=0.10, random_state=SEED
    )

    model = train_model(X_train, y_train, X_valid, y_valid)

    # Evaluate on the held-out split and report the mean of the
    # per-dimension RMSE values.
    y_valid_pred = predict(model, X_valid)
    metrics = compute_metrics_for_regression(y_valid, y_valid_pred)
    print(metrics)
    print("final MCRMSE on validation set: ", np.mean(list(metrics.values())))

    # Predict on the test set and write the submission file.
    test_df = pd.read_csv(
        'test.csv', header=0, names=['text_id', 'full_text'], index_col='text_id'
    )
    X_submission = test_df['full_text'].tolist()
    y_submission = predict(model, X_submission)
    # Reuse the test frame's index so each prediction row is written next to
    # its real text_id rather than a positional 0..n-1 counter.
    submission_df = pd.DataFrame(
        y_submission, columns=DIMENSIONS, index=test_df.index
    )
    submission_df.index = submission_df.index.rename('text_id')
    submission_df.to_csv('submission.csv')