SRT_Translation / app.py
Lenylvt's picture
Update app.py
8084b82 verified
raw
history blame
No virus
4.25 kB
import functools
import os
import tempfile
from io import BytesIO
from io import StringIO

import pandas as pd
import pysrt
import requests
import streamlit as st
from transformers import MarianMTModel, MarianTokenizer
def fetch_languages(url, timeout=10):
    """Download the pipe-delimited language table at *url* and build select options.

    The first two lines of the document (table header and separator row) are
    skipped; the remaining rows are split on ``|`` and, after dropping the empty
    edge columns, are expected to hold four columns: ISO 639-1, ISO 639-2,
    language name and native name.

    Args:
        url: Address of the table (a markdown ``iso.md`` file in this app).
        timeout: Seconds to wait for the server before giving up, so a stalled
            request cannot hang the app.

    Returns:
        A list of ``(iso_639_1_code, "code - Language Name")`` tuples, or an
        empty list when the table could not be downloaded.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException:
        # Treat connection errors/timeouts like a non-200 answer: no options,
        # and let the caller decide how to report it.
        return []
    if response.status_code != 200:
        return []

    df = pd.read_csv(BytesIO(response.content), delimiter="|", skiprows=2, header=None).dropna(axis=1, how='all')
    df.columns = ['ISO 639-1', 'ISO 639-2', 'Language Name', 'Native Name']
    # Markdown cells are padded with spaces (e.g. "| en | ... | English |");
    # strip both the code and the display name so labels read "en - English".
    codes = df['ISO 639-1'].str.strip()
    names = df['Language Name'].str.strip()
    return [(code, f"{code} - {name}") for code, name in zip(codes, names)]
@functools.lru_cache(maxsize=2)
def _load_marian(model_name):
    """Load (and cache per process) the Marian tokenizer/model pair for *model_name*."""
    # translate_text is called once per subtitle cue, so reloading the model on
    # every call dominated run time. lru_cache does not cache exceptions, so a
    # failed load is simply retried on the next call.
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return tokenizer, model


def translate_text(text, source_language_code, target_language_code):
    """Translate *text* line by line with a Helsinki-NLP Opus-MT Marian model.

    Args:
        text: Text to translate; each ``\\n``-separated line is translated
            independently and the line structure is preserved. Blank or
            whitespace-only lines are returned unchanged.
        source_language_code: Source language code used in the model name
            ``Helsinki-NLP/opus-mt-{source}-{target}``.
        target_language_code: Target language code used in the model name.

    Returns:
        The translated text. On failure, a human-readable message string is
        returned instead of raising: one when source and target are equal, and
        one when the model cannot be loaded.
    """
    if source_language_code == target_language_code:
        return "Translation between the same languages is not supported."

    lines = text.split("\n")
    # Only non-blank lines go through the model; feeding it an empty string
    # can produce spurious output, and blank lines must survive as-is.
    to_translate = [i for i, line in enumerate(lines) if line.strip()]
    if not to_translate:
        return text

    model_name = f"Helsinki-NLP/opus-mt-{source_language_code}-{target_language_code}"
    try:
        tokenizer, model = _load_marian(model_name)
    except Exception as e:
        return f"Failed to load model for {source_language_code} to {target_language_code}: {str(e)}"

    # Translate all lines in one padded batch instead of one generate() per line.
    batch = tokenizer(
        [lines[i] for i in to_translate],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    generated = model.generate(**batch)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    for i, translated in zip(to_translate, decoded):
        lines[i] = translated
    return "\n".join(lines)
def translate_srt(input_file, source_language_code, target_language_code):
    """Translate every cue of the SRT file at *input_file*.

    Cues are renumbered from 1 and keep their original start/end timestamps;
    only the text is replaced with the output of ``translate_text``. A
    Streamlit progress bar is drawn and advanced after each translated cue.

    Returns:
        A new ``pysrt.SubRipFile`` holding the translated cues.
    """
    source_subs = pysrt.open(input_file)
    cue_count = len(source_subs)
    bar = st.progress(
        0,
        text="Operation in progress. For information, the progress bar start when the translation begin.",
    )

    output_items = []
    for position, original in enumerate(source_subs, start=1):
        new_text = translate_text(original.text, source_language_code, target_language_code)
        output_items.append(
            pysrt.SubRipItem(index=position, start=original.start, end=original.end, text=new_text)
        )
        bar.progress(position / cue_count)

    return pysrt.SubRipFile(items=output_items)
st.title("SRT Translation")
st.write("We use model from [Language Technology Research Group at the University of Helsinki](https://huggingface.co/Helsinki-NLP). For API use please visit [this space](https://huggingface.co/spaces/Lenylvt/SRT_Translation-API)")

# Fetch language options: (ISO 639-1 code, "code - name") tuples.
url = "https://huggingface.co/Lenylvt/LanguageISO/resolve/main/iso.md"
language_options = fetch_languages(url)
source_language_code, target_language_code = None, None
if language_options:
    source_language_code = st.selectbox("1️⃣ Select Source Language", options=language_options, format_func=lambda x: x[1])[0]
    target_language_code = st.selectbox("2️⃣ Select Target Language", options=language_options, format_func=lambda x: x[1])[0]
else:
    st.error("Could not load the list of languages. Please reload the page later.")

file_input = st.file_uploader("📁 Upload SRT File", type=["srt"], accept_multiple_files=False)

ready = file_input is not None and source_language_code and target_language_code
if ready and source_language_code == target_language_code:
    # translate_text refuses same-language requests; without this guard every
    # cue of the output file would be replaced by that refusal message.
    st.warning("Please select two different languages.")
elif ready:
    uploaded_bytes = file_input.getvalue()
    # Streamlit reruns this script on every interaction, including the click on
    # the download button below. Remember the last result so that downloading
    # does not translate the whole file a second time.
    request_key = (uploaded_bytes, source_language_code, target_language_code)
    if st.session_state.get("srt_translation_request") != request_key:
        # translate_srt expects a path, so spool the upload to a temp file. The
        # handle is closed before reading, and the file is removed even if
        # translation fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".srt") as temp_file:
            temp_file.write(uploaded_bytes)
        try:
            translated_srt = translate_srt(temp_file.name, source_language_code, target_language_code)
        finally:
            os.unlink(temp_file.name)

        # Serialize in memory instead of round-tripping through another temp
        # file; pysrt's save() writes the same text via write_into().
        buffer = StringIO()
        translated_srt.write_into(buffer)
        st.session_state["srt_translation_bytes"] = buffer.getvalue().encode("utf-8")
        st.session_state["srt_translation_request"] = request_key

    st.download_button(
        label="⬇️ Download Translated SRT",
        data=st.session_state["srt_translation_bytes"],
        file_name="translated_subtitles.srt",
        mime="text/plain",
    )