# Spaces:
# Runtime error
# Runtime error
import os
import tempfile

import gradio as gr
import pandas as pd
from tqdm import tqdm

from facility_predict import Preprocess, Facility_Model, obj_Facility_Model, processor
def _predict_text(text):
    """Return the model prediction for a single raw text value.

    Null values (anything ``pd.isnull`` treats as missing) and values that
    ``processor.clean_text`` reduces to an empty string yield ``""``; in those
    cases the tokenizer and the model are not called at all.
    """
    if pd.isnull(text):
        return ""
    cleaned_text = processor.clean_text(text)
    if cleaned_text == "":
        return ""
    prepared_data = processor.process_tokenizer(cleaned_text)
    return obj_Facility_Model.inference(prepared_data)


def predict_batch_from_csv(input_file, output_file, text_column="facility_name"):
    """Predict every row of a CSV file and write the results to a new CSV.

    Args:
        input_file: Path (or anything ``pd.read_csv`` accepts) of the input CSV.
        output_file: Destination path for the output CSV.
        text_column: Name of the column holding the text to run through the
            model. Defaults to ``"facility_name"``, the column this app used.

    Returns:
        A status message naming ``output_file``.

    Raises:
        KeyError: if ``text_column`` is not a column of the input CSV.

    The output CSV contains every input column followed by a ``prediction``
    column; rows whose text is missing or empty after cleaning get ``""``.
    """
    batch_data = pd.read_csv(input_file)
    if text_column not in batch_data.columns:
        raise KeyError(
            f"column {text_column!r} not found in input CSV; "
            f"available columns: {list(batch_data.columns)}"
        )

    # Iterate the single text column directly: cheaper than iterrows(), and it
    # avoids the per-row dtype coercion iterrows() does when building a Series.
    predictions = [
        _predict_text(text)
        for text in tqdm(batch_data[text_column], total=len(batch_data))
    ]

    output_data = pd.DataFrame({"prediction": predictions})
    # reset_index gives both frames a 0..n-1 index so concat aligns row-by-row.
    pred_output_df = pd.concat(
        [batch_data.reset_index(drop=True), output_data], axis=1
    )
    pred_output_df.to_csv(output_file, index=False)
    # f-string so a pathlib.Path output_file works too ("+" would raise TypeError).
    return f"Prediction completed. Results saved to {output_file}"
# Gradio components. `gr.inputs` / `gr.outputs` (and File's `type="file"`)
# were removed in Gradio 4, so the top-level `gr.File` component is used.
# `type` is left at its default; depending on the Gradio version the handler
# then receives either a path string or an object carrying the path in `.name`,
# and `predict_interface` accepts both.
input_csv = gr.File(label="Input CSV", file_types=[".csv"])
output_csv = gr.File(label="Output CSV")


def predict_interface(input_file):
    """Run batch prediction on an uploaded CSV and return the result file path.

    Args:
        input_file: The upload as delivered by ``gr.File``. This is a path
            string or an object exposing the path as ``.name``, or ``None``
            when the user submits without uploading anything.

    Returns:
        Path (str) of a CSV holding the input rows plus a ``prediction`` column.

    Raises:
        gr.Error: if no file was uploaded (displayed to the user in the UI).
    """
    if input_file is None:
        raise gr.Error("Please upload a CSV file.")
    input_path = getattr(input_file, "name", input_file)
    # Use a fresh directory per request so concurrent sessions never overwrite
    # each other's results, which a fixed "./output.csv" would do. The file keeps
    # the name "output.csv" for the download. NOTE(review): these temp dirs are
    # not cleaned up here; consider periodic cleanup if the Space runs long.
    output_file = os.path.join(tempfile.mkdtemp(prefix="csv_batch_"), "output.csv")
    predict_batch_from_csv(input_path, output_file)
    return output_file


# Connect the interface with the prediction function.
iface = gr.Interface(
    fn=predict_interface,
    inputs=input_csv,
    outputs=output_csv,
    title="CSV Batch Prediction",
)

# Launch only when executed as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    iface.launch()