File size: 3,693 Bytes
51391bc
 
 
 
1f34b73
d3f3263
e1242c3
50c7bec
d3f3263
51391bc
 
 
6f7e417
51391bc
 
 
6f7e417
51391bc
 
 
 
 
 
7283ba2
 
51391bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import streamlit as st
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from huggingface_hub import login
from dotenv import load_dotenv
load_dotenv()
# token = os.environ['YOUR_ACCESS_TOKEN_VARIABLE']

# Authenticate with Hugging Face
def authenticate_huggingface():
    """Log in to the Hugging Face Hub with the token from the environment.

    Reads HUGGINGFACEHUB_API_TOKEN; if it is unset, reports the problem in
    the Streamlit UI instead of attempting to log in.
    """
    hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")  # Load token from environment variable
    if not hf_token:
        st.error("Hugging Face token not found. Please set the HUGGINGFACEHUB_API_TOKEN environment variable.")
        return
    login(hf_token)  # This logs in using the Hugging Face token

# Load the Llama 2 model from Hugging Face
@st.cache_resource
def load_llama_model():
    """Authenticate, then download and cache the Llama 2 tokenizer and model.

    Cached by Streamlit so the (large) download happens only once per session.

    Returns:
        (tokenizer, model) pair for "meta-llama/Llama-2-7b-hf".
    """
    authenticate_huggingface()  # Ensure authentication is done before loading
    repo_id = "meta-llama/Llama-2-7b-hf"
    llama_tokenizer = AutoTokenizer.from_pretrained(repo_id, token=True)
    llama_model = AutoModelForCausalLM.from_pretrained(repo_id, token=True)
    return llama_tokenizer, llama_model

# Function to query the Llama 2 model
def query_llama_model(penal_code, tokenizer, model):
    """Ask the model to describe a California penal code section.

    Args:
        penal_code: Penal code identifier (e.g. "187"), interpolated into
            the prompt.
        tokenizer: Hugging Face tokenizer for the model.
        model: Causal LM exposing a ``generate`` method.

    Returns:
        The generated description text only — the prompt tokens are
        stripped before decoding.
    """
    prompt = f"What is California Penal Code {penal_code}?"

    # Tokenize the input prompt
    inputs = tokenizer(prompt, return_tensors="pt")

    # Generate output from the model
    outputs = model.generate(**inputs, max_new_tokens=100)

    # BUG FIX: a decoder-only model's generate() output begins with the
    # prompt tokens, so decoding outputs[0] wholesale made every stored
    # "description" start with the question itself. Decode only the newly
    # generated continuation.
    prompt_length = inputs["input_ids"].shape[1]
    description = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return description.strip()

# Function to process CSV and update descriptions
def update_csv_with_descriptions(csv_file, tokenizer, model):
    """Fill in missing 'Description' values in an uploaded penal-code CSV.

    Args:
        csv_file: Uploaded file-like object (must expose ``.name``) with
            'Offense Number' and 'Description' columns.
        tokenizer: Tokenizer passed through to query_llama_model.
        model: Model passed through to query_llama_model.

    Returns:
        (penal_code_dict, updated_file_path): mapping of the newly
        described codes to their descriptions, and the path the updated
        CSV was saved to.
    """
    # Read the CSV file
    df = pd.read_csv(csv_file)

    # Dictionary to store penal codes and their descriptions
    penal_code_dict = {}

    # Iterate through each row in the CSV
    for index, row in df.iterrows():
        penal_code = row['Offense Number']
        existing = row['Description']

        # BUG FIX: pandas reads empty CSV cells as NaN, and `not NaN` is
        # False, so `if not row['Description']` never detected missing
        # descriptions and the query was skipped for every row. Treat NaN
        # and whitespace-only strings as missing.
        if pd.isna(existing) or not str(existing).strip():
            st.write(f"Querying description for {penal_code}...")
            description = query_llama_model(penal_code, tokenizer, model)

            # Update the dataframe with the description
            df.at[index, 'Description'] = description

            # Add to dictionary
            penal_code_dict[penal_code] = description

    # Save the updated CSV file
    updated_file_path = 'updated_' + csv_file.name
    df.to_csv(updated_file_path, index=False)

    return penal_code_dict, updated_file_path

# Streamlit UI
def main():
    """Render the Streamlit app: upload a CSV, fill in descriptions, offer a download."""
    st.title("Penal Code Description Extractor with Llama 2")

    # Load the (cached) Llama 2 model and tokenizer up front.
    tokenizer, model = load_llama_model()

    uploaded_file = st.file_uploader("Upload a CSV file with Penal Codes", type=["csv"])
    if uploaded_file is None:
        return

    # Preview the uploaded data before any processing.
    st.write("Uploaded CSV File:")
    st.dataframe(pd.read_csv(uploaded_file))

    if not st.button("Get Penal Code Descriptions"):
        return

    # Fill missing descriptions and persist the result.
    penal_code_dict, updated_file_path = update_csv_with_descriptions(uploaded_file, tokenizer, model)

    # Show the dictionary of newly fetched descriptions.
    st.write("Penal Code Descriptions:")
    st.json(penal_code_dict)

    # Provide a download link for the updated CSV.
    with open(updated_file_path, 'rb') as f:
        st.download_button(
            label="Download Updated CSV",
            data=f,
            file_name=updated_file_path,
            mime='text/csv'
        )

if __name__ == "__main__":
    main()