from fastapi import FastAPI, HTTPException
from transformers import pipeline
import uvicorn
import streamlit as st

# Load trained model
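# The token-classification pipeline below is created once at import time; the
# model weights are downloaded from the Hugging Face Hub on first use and
# cached locally.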
model_name = "DINGOLANI/distilbert-ner-v2"

try:
    nlp_ner = pipeline("token-classification", model=model_name, tokenizer=model_name)
except Exception as e:
    raise RuntimeError(f"Failed to load model: {e}")

# Corrected label mapping based on expected training labels
label_map = {
    "LABEL_1": "B-BRAND",
    "LABEL_2": "I-BRAND",
    "LABEL_3": "B-CATEGORY",
    "LABEL_4": "I-CATEGORY",
    "LABEL_5": "B-GENDER",
    "LABEL_6": "B-PRICE",
    "LABEL_7": "I-PRICE"
}

entity_filter = {
    "B-BRAND": "BRAND",
    "I-BRAND": "BRAND",
    "B-CATEGORY": "CATEGORY",
    "I-CATEGORY": "CATEGORY",
    "B-GENDER": "GENDER",
    "B-PRICE": "PRICE",
    "I-PRICE": "PRICE"
}

app = FastAPI()

@app.get("/")
def home():
    return {"message": "NER API is running!"}

@app.post("/predict/")
def predict(query: str):
    try:
        result = nlp_ner(query)  

        for label in result:
            label["score"] = float(label["score"])

        print("RAW MODEL OUTPUT:", result)  
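        # Without an aggregation strategy the pipeline returns one dict per
        # sub-token, typically with "entity", "score", "word", "start" and "end"
        # keys; the loop below stitches "##" sub-words back into whole words.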

        structured_output = {}
        prev_label = None
        prev_word = None
        
        for label in result:
            entity_bio = label_map.get(label.get("entity"))  
            entity = entity_filter.get(entity_bio)

            if entity:  
                word = label["word"]
                
                if word.startswith("##"):
                    if prev_label == entity and prev_word:
                        structured_output[entity][-1] += word[2:]
                    else:
                        structured_output.setdefault(entity, []).append(word[2:])
                else:
                    structured_output.setdefault(entity, []).append(word)

                prev_label = entity
                prev_word = word

        return {
            "query": query,
            "raw_output": result,  
            "structured_output": structured_output  
        }
    
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {e}")
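
# A minimal usage sketch (assumptions: this file is saved as app.py and the
# FastAPI app is served on its own with `uvicorn app:app --port 8000`):
#
#   curl -X POST "http://localhost:8000/predict/?query=Gucci%20handbags%20for%20women%20under%20%245000"
#
# FastAPI treats the simple `query: str` parameter as a query-string parameter,
# so the text is passed in the URL rather than in a JSON body.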

# 🚀 Streamlit Frontend
def main():
    st.set_page_config(page_title="Luxury Fashion NER", layout="wide")

    st.title("👜 Luxury Fashion Entity Extractor")
    st.write("Enter a text query and extract structured entities like **Brand, Category, Gender, and Price.**")

    query = st.text_input("Enter Query:", "Gucci handbags for women under $5000")

    if st.button("Analyze"):
        response = predict(query)
        
        col1, col2 = st.columns(2)
        
        with col1:
            st.subheader("🔍 Structured Output")
            for key, value in response["structured_output"].items():
                st.write(f"**{key}:** {', '.join(value)}")
        
        with col2:
            st.subheader("🛠 Raw Model Output")
            st.json(response["raw_output"])

if __name__ == "__main__":
    main()
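
# A possible way to launch the app (assuming the file is named app.py):
#   streamlit run app.py
# Note that the Streamlit UI calls predict() directly in-process rather than
# going through the HTTP endpoint.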