# ########################################################################### # # CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) # (C) Cloudera, Inc. 2022 # All rights reserved. # # Applicable Open Source License: Apache 2.0 # # NOTE: Cloudera open source products are modular software products # made up of hundreds of individual components, each of which was # individually copyrighted. Each Cloudera open source product is a # collective work under U.S. Copyright Law. Your license to use the # collective work is as provided in your written agreement with # Cloudera. Used apart from the collective work, this file is # licensed for your use pursuant to the open source license # identified above. # # This code is provided to you pursuant a written agreement with # (i) Cloudera, Inc. or (ii) a third-party authorized to distribute # this code. If you do not have a written agreement with Cloudera nor # with an authorized and properly licensed third party, you do not # have any rights to access nor to use this code. # # Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the # contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY # KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED # WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO # IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND # FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, # AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS # ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE # OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR # CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES # RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF # DATA. # # ########################################################################### from typing import List import tokenizers import streamlit as st from src.style_transfer import StyleTransfer from src.style_classification import StyleIntensityClassifier from src.content_preservation import ContentPreservationScorer from src.transformer_interpretability import InterpretTransformer from apps.data_utils import StyleAttributeData, string_to_list_string # CALLBACKS def increment_page_progress(): st.session_state.page_progress += 1 def reset_page_progress_state(): del st.session_state.st_result st.session_state.page_progress = 1 # UTILITY CLASSES class DisableableButton: """ Utility class for creating "disable-able" buttons upon click. We initialize an empty container, then update that container with buttons upon calling `create_enabled_button` and `disable` methods where clicking is enabled and then disabled, respectively. """ def __init__(self, button_number, button_text): self.button_number = button_number self.button_text = button_text def _init_placeholder_container(self): self.ph = st.empty() def create_enabled_button(self): self._init_placeholder_container() self.ph.button( self.button_text, on_click=increment_page_progress, key=f"ph{self.button_number}_before", disabled=False, ) def disable(self): self.ph.button( self.button_text, key=f"ph{self.button_number}_after", disabled=True ) # CACHED FUNCTIONS @st.cache( hash_funcs={tokenizers.Tokenizer: lambda _: None}, allow_output_mutation=True, show_spinner=False, ) def get_cached_style_intensity_classifier( style_data: StyleAttributeData, ) -> StyleIntensityClassifier: """ Return a cached style classifier. This function overwrites the existing model's config values for `id2label` and `label2id`. Args: style_data (StyleAttributeData) Returns: StyleIntensityClassifier """ sic = StyleIntensityClassifier(style_data.cls_model_path) # create or overwrite id-label lookup in model config sic.pipeline.model.config.__dict__["id2label"] = { i: a for i, a in enumerate( [ style_data.source_attribute.capitalize(), style_data.target_attribute.capitalize(), ] ) } sic.pipeline.model.config.__dict__["label2id"] = { v: k for k, v in sic.pipeline.model.config.__dict__["id2label"].items() } return sic @st.cache( hash_funcs={tokenizers.Tokenizer: lambda _: None}, allow_output_mutation=True, show_spinner=False, ) def get_cached_word_attributions( text_sample: str, style_data: StyleAttributeData ) -> str: """ Calculated word attributions and return HTML visual. This function overwrites the existing model's config values for `id2label` and `label2id`. Args: text_sample (str) style_data (StyleAttributeData) Returns: str """ it = InterpretTransformer(cls_model_identifier=style_data.cls_model_path) # create or overwrite id-label lookup in model config it.explainer.id2label = { i: a for i, a in enumerate( [ style_data.source_attribute.capitalize(), style_data.target_attribute.capitalize(), ] ) } it.explainer.label2id = {v: k for k, v in it.explainer.id2label.items()} return it.visualize_feature_attribution_scores(text_sample).data @st.cache( hash_funcs={tokenizers.Tokenizer: lambda _: None}, allow_output_mutation=True, show_spinner=False, ) def get_sti_metric( input_text: str, output_text: str, style_data: StyleAttributeData ) -> List[float]: """ Calculate Style Transfer Intensity (STI) Args: input_text (str) output_text (str) style_data (StyleAttributeData) Returns: List[float] """ sti = StyleIntensityClassifier( model_identifier=style_data.cls_model_path, ) return sti.calculate_transfer_intensity_fraction( string_to_list_string(input_text), string_to_list_string(output_text) ) @st.cache( hash_funcs={tokenizers.Tokenizer: lambda _: None}, allow_output_mutation=True, show_spinner=False, ) def get_cps_metric( input_text: str, output_text: str, style_data: StyleAttributeData ) -> List[float]: """ Calculate Content Preservation Score (CPS) Args: input_text (str) output_text (str) style_data (StyleAttributeData) Returns: List[float] """ cps = ContentPreservationScorer( cls_model_identifier=style_data.cls_model_path, sbert_model_identifier=style_data.sbert_model_path, ) return cps.calculate_content_preservation_score( string_to_list_string(input_text), string_to_list_string(output_text), mask_type="none", ) def generate_style_transfer( text_sample: str, style_data: StyleAttributeData, max_gen_length: int, num_beams: int, temperature: int, ): """ Run inference on seq2seq model and persist result to `session_state` varaible. Args: text_sample (str): _description_ style_data (StyleAttributeData): _description_ max_gen_length (int): _description_ num_beams (int): _description_ temperature (int): _description_ """ with st.spinner("Transferring style, hang tight!"): generate_kwargs = { "max_gen_length": max_gen_length, "num_beams": num_beams, "temperature": temperature, } st_class = StyleTransfer( model_identifier=style_data.seq2seq_model_path, **generate_kwargs, ) st_result = st_class.transfer(text_sample) st.session_state.st_result = st_result