File size: 1,774 Bytes
e19a1ca
09a6e08
aba50e6
 
 
 
 
 
 
 
 
 
 
e19a1ca
09a6e08
93097a8
 
 
 
 
 
 
 
aba50e6
93097a8
 
aba50e6
f1dc58b
93097a8
 
aba50e6
93097a8
 
aba50e6
93097a8
 
 
 
e19a1ca
aba50e6
09a6e08
 
 
 
aba50e6
 
e19a1ca
 
09a6e08
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import pandas as pd
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the Hugging Face forecasting model
def load_model(model_name="Ankur87/Llama2_Time_series_forecasting_7.0"):
    """Load the causal-LM forecasting model and its tokenizer.

    Args:
        model_name: Hugging Face Hub repo id or local directory passed to
            ``from_pretrained``. Defaults to the Llama-2 time-series fine-tune
            this app was built around.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Some checkpoints (commonly Llama-family ones) ship without a pad token,
    # and any tokenizer call with padding enabled then raises. Reuse EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # If a prompt must be truncated, drop tokens from the start (oldest
    # observations) rather than the end (most recent ones, which a forecast
    # depends on most).
    tokenizer.truncation_side = "left"
    return model, tokenizer

# Load the weights once at import time; forecast() reads these globals on every call.
[model, tokenizer] = load_model()

def forecast(csv_file, *, max_new_tokens=512):
    """Forecast an uploaded time series with the LLM and save the result as CSV.

    The input must be a semicolon-separated CSV containing a
    ``timestamp_column`` column and an ``Inbound`` value column. Each row is
    rendered as a ``"<timestamp>: <value>"`` line (in file order), the lines
    are joined into one prompt for the module-level ``model``/``tokenizer``,
    and only the text generated *after* that prompt is saved.

    Args:
        csv_file: The upload handed over by ``gr.File`` (an object whose
            ``.name`` is the file path), or a plain path string / ``os.PathLike``.
        max_new_tokens: Maximum number of tokens to generate after the prompt.
            Keyword-only so the Gradio interface keeps passing a single input.

    Returns:
        The path ``"forecasts.csv"`` (relative to the working directory) of a
        one-row CSV whose ``forecast`` column holds the generated text.

    Raises:
        ValueError: if the CSV has no data rows. Errors from pandas (missing
            columns, unparseable timestamps) propagate unchanged.
    """
    import os  # local: only used to recognise path-like arguments

    # gr.File hands over an object exposing the path as ``.name`` (in newer
    # Gradio it is also a str subclass); also accept a bare path for direct calls.
    if isinstance(csv_file, (str, os.PathLike)):
        csv_path = csv_file
    else:
        csv_path = csv_file.name

    data = pd.read_csv(csv_path, sep=";", parse_dates=["timestamp_column"])
    # NOTE(review): when parse_dates has already produced datetime64 values,
    # pd.to_datetime returns them unchanged and this format is not enforced;
    # it only takes effect if read_csv left the column as strings.
    data["timestamp_column"] = pd.to_datetime(data["timestamp_column"], format="%Y%m%d %H:%M")

    if data.empty:
        raise ValueError("The uploaded CSV contains no data rows.")

    # One "<timestamp>: <value>" line per row. zip over the columns yields the
    # same values as iterrows without building a Series per row.
    input_text = "\n".join(
        f"{ts}: {value}"
        for ts, value in zip(data["timestamp_column"], data["Inbound"])
    )

    # A single prompt never needs padding, and padding=True raises on
    # tokenizers without a pad token. truncation=True caps the prompt at
    # tokenizer.model_max_length, removing tokens from tokenizer.truncation_side.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
    prompt_length = inputs["input_ids"].shape[1]

    # max_new_tokens (not max_length, which counts the prompt too) guarantees a
    # continuation even when the prompt itself is long.
    with torch.no_grad():
        predictions = model.generate(
            **inputs, max_new_tokens=max_new_tokens, num_return_sequences=1
        )

    # generate() returns prompt + continuation; decode only the continuation
    # so the CSV holds the forecast rather than an echo of the input series.
    forecast_text = tokenizer.decode(
        predictions[0][prompt_length:], skip_special_tokens=True
    )

    output_file = "forecasts.csv"
    pd.DataFrame({"forecast": [forecast_text]}).to_csv(output_file, index=False)

    return output_file

# Gradio Interface: a single file upload in, the forecast CSV out.
# The description spells out the input layout forecast() actually requires.
iface = gr.Interface(
    fn=forecast,
    inputs=gr.File(label="Upload CSV File"),
    outputs=gr.File(label="Download Forecasts"),
    title="Time Series Forecasting with Llama2",
    description=(
        "Upload a semicolon-separated CSV file with a 'timestamp_column' column "
        "(format YYYYMMDD HH:MM) and an 'Inbound' value column to generate "
        "forecasts using Llama2."
    ),
)

iface.launch()