|
import streamlit as st |
|
import pandas as pd |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
import torch |
|
import os |
|
from huggingface_hub import login |
|
from dotenv import load_dotenv |
|
# Populate os.environ from a local .env file (if one exists) so that
# HUGGINGFACEHUB_API_TOKEN can be supplied without exporting it in the shell.
load_dotenv()
|
|
|
|
|
|
|
def authenticate_huggingface():
    """Log in to the Hugging Face Hub with the token from the environment.

    The token is read from the ``HUGGINGFACEHUB_API_TOKEN`` environment
    variable. Surrounding whitespace is stripped, so a trailing newline or
    space in a ``.env`` file does not yield an invalid token. A value that
    is empty after stripping is treated as missing.

    Returns:
        The token used to log in, or ``None`` when no token is configured.
        In the ``None`` case an error is shown in the Streamlit UI. The
        return value lets callers detect the failure instead of proceeding
        to a confusing download/authorization error later on.
    """
    token = (os.getenv("HUGGINGFACEHUB_API_TOKEN") or "").strip()
    if not token:
        st.error("Hugging Face token not found. Please set the HUGGINGFACEHUB_API_TOKEN environment variable.")
        return None

    login(token=token)
    return token
|
|
|
|
|
@st.cache_resource
def load_llama_model(model_name="meta-llama/Llama-2-7b-hf"):
    """Authenticate with the Hub and load a causal-LM tokenizer/model pair.

    Wrapped in ``st.cache_resource`` so the (large) model is loaded once per
    ``model_name`` and reused across Streamlit reruns instead of being
    reloaded on every interaction.

    Args:
        model_name: Hugging Face Hub repository id of the checkpoint to load.
            Defaults to the Llama 2 7B base model this app was written for.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    authenticate_huggingface()

    # token=True tells transformers to use the token saved by login() above,
    # which access-restricted repositories such as Llama 2 require.
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, token=True)
    return tokenizer, model
|
|
|
|
|
def query_llama_model(penal_code, tokenizer, model, max_new_tokens=100):
    """Ask the language model to describe one California Penal Code section.

    Args:
        penal_code: The offense number to ask about; it is interpolated
            verbatim into the prompt (e.g. ``"187"``).
        tokenizer: Tokenizer paired with ``model``.
        model: Causal language model used for generation.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        Only the generated continuation, with the prompt removed and
        surrounding whitespace stripped.
    """
    prompt = f"What is California Penal Code {penal_code}?"

    # Put the input tensors on the model's device so generation also works
    # when the model lives on a GPU (a no-op for a CPU-resident model).
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # A decoder-only model returns prompt + continuation in one sequence;
    # drop the prompt tokens so the stored description is not prefixed with
    # the question itself.
    prompt_length = inputs["input_ids"].shape[-1]
    description = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return description.strip()
|
|
|
|
|
def update_csv_with_descriptions(csv_file, tokenizer, model):
    """Fill in missing ``Description`` cells of a penal-code CSV via the model.

    Rows whose ``Description`` is empty (NaN, None, or whitespace-only) and
    whose ``Offense Number`` is present get a generated description. Rows
    that already have a description, or that have no offense number, are
    left untouched. Each distinct offense number is queried only once.

    Args:
        csv_file: A path or a file-like object accepted by ``pd.read_csv``
            (in this app, the Streamlit upload). File-like objects are
            rewound first, because the caller may already have read them.
        tokenizer: Tokenizer passed through to ``query_llama_model``.
        model: Model passed through to ``query_llama_model``.

    Returns:
        ``(penal_code_dict, updated_file_path)``:
        ``penal_code_dict`` maps each offense number queried in this call
        (as a string) to its generated description.
        ``updated_file_path`` is the CSV written to the current working
        directory as ``updated_<source file name>``.

    Raises:
        KeyError: if the CSV has rows but no ``Offense Number`` column.
    """
    # main() previews the upload with pd.read_csv before calling this, which
    # leaves the stream at EOF; without rewinding, this read finds no data.
    if hasattr(csv_file, "seek"):
        csv_file.seek(0)

    # Read offense numbers as text. Otherwise pandas may parse "187" as a
    # number, or as 187.0 once a row is upcast to float, which mangles the
    # prompt.
    df = pd.read_csv(csv_file, dtype={"Offense Number": str})

    # Ensure a column that can hold strings: a missing column is created, and
    # an entirely empty column (parsed as float64 NaN) is converted to object.
    if "Description" not in df.columns:
        df["Description"] = None
    df["Description"] = df["Description"].astype(object)

    penal_code_dict = {}
    for index, row in df.iterrows():
        penal_code = row["Offense Number"]
        if _is_blank(penal_code) or not _is_blank(row["Description"]):
            continue
        penal_code = str(penal_code).strip()

        # Reuse an answer already generated for a duplicate offense number.
        if penal_code not in penal_code_dict:
            st.write(f"Querying description for {penal_code}...")
            penal_code_dict[penal_code] = query_llama_model(penal_code, tokenizer, model)
        df.at[index, "Description"] = penal_code_dict[penal_code]

    updated_file_path = "updated_" + _source_basename(csv_file)
    df.to_csv(updated_file_path, index=False)

    return penal_code_dict, updated_file_path


def _is_blank(value):
    """Return True when a CSV cell holds no usable text (NaN/None/empty/whitespace)."""
    if isinstance(value, str):
        return not value.strip()
    # Empty CSV cells arrive as NaN, which is truthy, so test NA explicitly.
    return bool(pd.isna(value))


def _source_basename(csv_file):
    """Return the bare file name of *csv_file*, given a path or a named file object."""
    if isinstance(csv_file, (str, os.PathLike)):
        name = csv_file
    else:
        name = getattr(csv_file, "name", "penal_codes.csv")
    # basename drops any directory part, so a client-supplied upload name
    # cannot redirect the write outside the working directory.
    return os.path.basename(os.fspath(name))
|
|
|
|
|
def main():
    """Render the app: upload a penal-code CSV, generate descriptions, offer download."""
    st.title("Penal Code Description Extractor with Llama 2")

    uploaded_file = st.file_uploader("Upload a CSV file with Penal Codes", type=["csv"])
    if uploaded_file is None:
        return

    st.write("Uploaded CSV File:")
    df = pd.read_csv(uploaded_file)
    st.dataframe(df)

    if not st.button("Get Penal Code Descriptions"):
        return

    # Load the model only when descriptions are requested, so the upload
    # page renders without waiting on it; load_llama_model is cached, so
    # later clicks reuse the already-loaded model.
    tokenizer, model = load_llama_model()

    # The preview above consumed the upload stream; rewind it so the next
    # read starts at the beginning instead of hitting EOF.
    uploaded_file.seek(0)
    penal_code_dict, updated_file_path = update_csv_with_descriptions(uploaded_file, tokenizer, model)

    st.write("Penal Code Descriptions:")
    st.json(penal_code_dict)

    with open(updated_file_path, "rb") as f:
        st.download_button(
            label="Download Updated CSV",
            data=f,
            file_name=updated_file_path,
            mime="text/csv",
        )
|
|
|
# Entry point. `streamlit run <this file>` executes the script as the
# __main__ module, so this guard starts the app on every Streamlit rerun.
if __name__ == "__main__":
    main()
|
|