import streamlit as st
import torch
import pandas as pd
import PyPDF2
import pickle
import os
from transformers import AutoTokenizer, PreTrainedModel, PretrainedConfig
from huggingface_hub import login, hf_hub_download
import time
from ch09util import subsequent_mask, create_model
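# ch09util is assumed to be a local helper module shipped alongside this app; it is
# expected to provide create_model (the encoder-decoder Transformer builder) and
# subsequent_mask (the causal attention mask) used in the decoding loop below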

# Device setup
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Set page configuration
st.set_page_config(
    page_title="Translator Agent",
    page_icon="🚀",
    layout="centered"
)

# Model repository name
MODEL_NAME = "amiguel/custom-en2fr-transformer-v1"

# Retrieve Hugging Face token from environment variable
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    st.error("🔐 Hugging Face token not found in environment variables. Please set HF_TOKEN in Space secrets.")
    st.stop()

# Title with rocket emojis
st.title("🚀 English to French Translator 🚀")

# Configure Avatars
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"

# Sidebar configuration
with st.sidebar:
    st.header("Upload Documents 📂")
    uploaded_file = st.file_uploader(
        "Choose a PDF or XLSX file to translate",
        type=["pdf", "xlsx"],
        label_visibility="collapsed"
    )

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# File processing function
@st.cache_data
def process_file(uploaded_file):
    if uploaded_file is None:
        return ""
    
    try:
        if uploaded_file.type == "application/pdf":
            pdf_reader = PyPDF2.PdfReader(uploaded_file)
            # extract_text() can return None on image-only pages; fall back to ""
            return "\n".join([page.extract_text() or "" for page in pdf_reader.pages])
        elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
            df = pd.read_excel(uploaded_file)
            return df.to_markdown()
        return ""  # unsupported file type
    except Exception as e:
        st.error(f"📄 Error processing file: {str(e)}")
        return ""

# Custom model loading function
@st.cache_resource
def load_model_and_resources():
    try:
        login(token=HF_TOKEN)
        
        # Load tokenizer from the model repo
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            token=HF_TOKEN
        )
        
        # Define Transformer configuration
        class TransformerConfig(PretrainedConfig):
            model_type = "custom_transformer"
            def __init__(self, src_vocab_size=11055, tgt_vocab_size=11239,
                         d_model=256, d_ff=1024, h=8, N=6, dropout=0.1, **kwargs):
                super().__init__(**kwargs)
                self.src_vocab_size = src_vocab_size
                self.tgt_vocab_size = tgt_vocab_size
                self.d_model = d_model
                self.d_ff = d_ff
                self.h = h
                self.N = N
                self.dropout = dropout

        # Define Transformer model
        class CustomTransformer(PreTrainedModel):
            config_class = TransformerConfig
            def __init__(self, config):
                super().__init__(config)
                self.model = create_model(
                    config.src_vocab_size,
                    config.tgt_vocab_size,
                    N=config.N,
                    d_model=config.d_model,
                    d_ff=config.d_ff,
                    h=config.h,
                    dropout=config.dropout
                )
            def forward(self, src, tgt, src_mask, tgt_mask, **kwargs):
                return self.model(src, tgt, src_mask, tgt_mask)

        # Load the raw config dict so we can tell whether the vocab sizes were
        # actually saved; from_pretrained() would silently fill in the defaults,
        # so checking its to_dict() output could never detect the missing keys
        config_dict, _ = TransformerConfig.get_config_dict(MODEL_NAME, token=HF_TOKEN)
        if "src_vocab_size" not in config_dict or "tgt_vocab_size" not in config_dict:
            st.warning(
                f"Config at {MODEL_NAME}/config.json is missing 'src_vocab_size' or 'tgt_vocab_size'. "
                "Using defaults (11055, 11239). For accuracy, update the training script to save these values."
            )
            config = TransformerConfig()
        else:
            config = TransformerConfig(**config_dict)

        # Instantiate the model and load the safetensors weights explicitly
        model = CustomTransformer(config)
        weights_path = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors", token=HF_TOKEN)
        from safetensors.torch import load_file
        state_dict = load_file(weights_path)
        model.load_state_dict(state_dict)
        
        # The weights are fully materialized on CPU at this point, so a plain
        # .to() is safe for both CPU and CUDA targets
        model = model.to(DEVICE)
        
        model.eval()

        # Load dictionaries from the model repo
        dict_path = hf_hub_download(repo_id=MODEL_NAME, filename="dict.p", token=HF_TOKEN)
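        # dict.p is expected to pack four vocab maps, in this order:
        # EN word->id, EN id->word, FR word->id, FR id->word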
        with open(dict_path, "rb") as fb:
            en_word_dict, en_idx_dict, fr_word_dict, fr_idx_dict = pickle.load(fb)
        
        return model, tokenizer, en_word_dict, fr_word_dict, en_idx_dict, fr_idx_dict
        
    except Exception as e:
        st.error(f"🤖 Model loading failed: {str(e)}")
        return None

# Custom streaming generation function
def custom_streaming_generate(input_text, model, tokenizer, en_word_dict, fr_word_dict, fr_idx_dict):
    try:
        model.eval()
        PAD, UNK = 0, 1
        tokenized_en = ["BOS"] + tokenizer.tokenize(input_text) + ["EOS"]
        enidx = [en_word_dict.get(i, UNK) for i in tokenized_en]
        src = torch.tensor(enidx).long().to(DEVICE).unsqueeze(0)
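        # Padding mask: True where the source token is not PAD(0); shape (1, 1, src_len)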
        src_mask = (src != 0).unsqueeze(-2)
        memory = model.model.encode(src, src_mask)
        start_symbol = fr_word_dict["BOS"]
        ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
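        # Greedy decoding: re-run the decoder on everything generated so far,
        # pick the argmax token, and stop at "EOS" or after 100 tokens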
        for _ in range(100):
            out = model.model.decode(memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src.data))
            prob = model.model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.data[0].item()  # 0-dim tensor -> Python int so the dict lookup can match
            sym = fr_idx_dict.get(next_word, "UNK")
            if sym != "EOS":
                token = sym.replace("</w>", " ")
                for x in '''?:;.,'("-!&)%''':
                    token = token.replace(f" {x}", f"{x}")
                yield token
            else:
                break
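            # Append the chosen token so the next decode step conditions on it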
            ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
        # Yield a final empty token to ensure completion
        yield ""
        
    except Exception as e:
        raise RuntimeError(f"Generation error: {str(e)}") from e

# Display chat messages
for message in st.session_state.messages:
    try:
        avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
        with st.chat_message(message["role"], avatar=avatar):
            st.markdown(message["content"])
    except Exception:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

# Chat input handling
if prompt := st.chat_input("Enter text to translate into French..."):
    # Load model and resources if not already loaded
    if "model" not in st.session_state:
        model_data = load_model_and_resources()
        if model_data is None:
            st.error("Failed to load model. Please check the HF_TOKEN in Space secrets and try again.")
            st.stop()
            
        st.session_state.model, st.session_state.tokenizer, \
        st.session_state.en_word_dict, st.session_state.fr_word_dict, \
        st.session_state.en_idx_dict, st.session_state.fr_idx_dict = model_data
    
    model = st.session_state.model
    tokenizer = st.session_state.tokenizer
    en_word_dict = st.session_state.en_word_dict
    fr_word_dict = st.session_state.fr_word_dict
    fr_idx_dict = st.session_state.fr_idx_dict
    
    # Add user message
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Process file or use prompt directly
    file_context = process_file(uploaded_file)
    input_text = file_context if file_context else prompt
    
    # Generate translation with streaming
    if model and tokenizer:
        try:
            with st.chat_message("assistant", avatar=BOT_AVATAR):
                start_time = time.time()
                
                # Create a placeholder for streaming output
                response_container = st.empty()
                full_response = ""
                
                # Stream translation tokens
                for token in custom_streaming_generate(
                    input_text, model, tokenizer, en_word_dict, fr_word_dict, fr_idx_dict
                ):
                    if token:  # Only append non-empty tokens
                        full_response += token
                        response_container.markdown(full_response)
                
                # Calculate performance metrics
                end_time = time.time()
                input_tokens = len(tokenizer(input_text)["input_ids"])
                output_tokens = len(tokenizer(full_response)["input_ids"])
                speed = output_tokens / (end_time - start_time) if (end_time - start_time) > 0 else 0
                
                # Calculate costs (hypothetical pricing model)
                input_cost = (input_tokens / 1000000) * 5  # $5 per million input tokens
                output_cost = (output_tokens / 1000000) * 15  # $15 per million output tokens
                total_cost_usd = input_cost + output_cost
                total_cost_aoa = total_cost_usd * 1160  # Convert to AOA (Angolan Kwanza)
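                # NOTE: 1160 is a hardcoded, assumed USD->AOA exchange rate; update it as rates move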
                
                # Display metrics
                st.caption(
                    f"🤖 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
                    f"🕒 Speed: {speed:.1f}t/s | 💰 Cost (USD): ${total_cost_usd:.4f} | "
                    f"💵 Cost (AOA): {total_cost_aoa:.4f}"
                )
                
                # Store the full response in chat history
                st.session_state.messages.append({"role": "assistant", "content": full_response})
                
        except Exception as e:
            st.error(f"⚡ Translation error: {str(e)}")
    else:
        st.error("🤖 Model not loaded!")