Create app.py
app.py ADDED
@@ -0,0 +1,194 @@
import torch
import time
import lhotse
import numpy as np
import os
from transformers import Wav2Vec2ForCTC, Wav2Vec2ForPreTraining
import gradio as gr
import geoviews as gv
import geoviews.tile_sources as gts
import uuid
import gdown
import math
import torch.nn as nn


device = torch.device("cpu")

class AttentionPool(nn.Module):
    def __init__(self, att, query_embed):
        super(AttentionPool, self).__init__()
        self.query_embed = query_embed
        self.att = att

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
        # Build a key padding mask of shape (batch_size, max_seq_length):
        # True marks padded frames that attention should ignore.
        max_seq_length = x_lens.max().item()
        mask = torch.arange(max_seq_length)[None, :].to(x.device) >= x_lens[:, None]

        # Attend over the encoder frames with a single learned query vector,
        # yielding one pooled embedding (and its attention weights) per utterance.
        x, w = self.att(
            self.query_embed.unsqueeze(0).unsqueeze(1).repeat(x.size(0), 1, 1),
            x,
            x,
            key_padding_mask=mask
        )
        x = x.squeeze(1)
        return x, w

class AveragePool(nn.Module):
    def __init__(self):
        super(AveragePool, self).__init__()

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
        # Mark padded frames with NaN, then average only over the valid frames.
        max_seq_length = x_lens.max().item()
        mask = torch.arange(max_seq_length)[None, :].to(x.device) >= x_lens[:, None]
        x[mask] = torch.nan
        return x.nanmean(dim=1), None

class Wav2Vec2Model(nn.Module):
    def __init__(self,
        modelpath='facebook/mms-300m',
        freeze_feat_extractor=True,
        pooling_loc=0,
        pooling_type='att',
    ):
        super(Wav2Vec2Model, self).__init__()
        # Not every checkpoint ships a CTC head, so fall back to the pre-training class.
        try:
            self.encoder = Wav2Vec2ForCTC.from_pretrained(modelpath).wav2vec2
        except Exception:
            self.encoder = Wav2Vec2ForPreTraining.from_pretrained(modelpath).wav2vec2

        if freeze_feat_extractor:
            self.encoder.feature_extractor._freeze_parameters()
        self.freeze_feat_extractor = freeze_feat_extractor
        self.odim = self._get_output_dim()

        self.frozen = False
        if pooling_type == 'att':
            assert pooling_loc == 0
            self.att = nn.MultiheadAttention(self.odim, 1, batch_first=True)
            self.loc_embed = nn.Parameter(
                torch.FloatTensor(self.odim).uniform_(-1, 1)
            )
            self.pooling = AttentionPool(self.att, self.loc_embed)
        elif pooling_type == 'avg':
            self.pooling = AveragePool()
        self.pooling_type = pooling_type
        # pooling_loc: 0 = pool embeddings, 1 = pool unnormalized coords, 2 = pool normalized coords
        self.pooling_loc = pooling_loc
        # Predict a 3-D point that is later normalized onto the unit sphere.
        self.linear_out = nn.Linear(self.odim, 3)

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
        x = self.encoder(
            x.squeeze(-1), output_hidden_states=False
        )[0]

        # Convert sample lengths to frame lengths by replaying the conv
        # feature extractor's (kernel_width, stride) schedule.
        for width, stride in [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]:
            x_lens = torch.floor((x_lens - width) / stride + 1)
        if self.pooling_loc == 0:
            x, w = self.pooling(x, x_lens)
            x = self.linear_out(x)
            x = x.div(x.norm(dim=1).unsqueeze(-1))
        elif self.pooling_loc == 1:
            x = self.linear_out(x)
            x, w = self.pooling(x, x_lens)
            x = x.div(x.norm(dim=1).unsqueeze(-1))
        elif self.pooling_loc == 2:
            x = self.linear_out(x)
            x = x.div(x.norm(dim=1).unsqueeze(-1))
            x, w = self.pooling(x, x_lens)
            x = x.div(x.norm(dim=1).unsqueeze(-1))
        return x, w

    def freeze_encoder(self):
        for p in self.encoder.encoder.parameters():
            if p.requires_grad:
                p.requires_grad = False
        self.frozen = True

    def unfreeze_encoder(self):
        for p in self.encoder.encoder.parameters():
            p.requires_grad = True
        if self.freeze_feat_extractor:
            self.encoder.feature_extractor._freeze_parameters()
        self.frozen = False

    def _get_output_dim(self):
        # Run a dummy waveform through the encoder to read off its hidden size.
        x = torch.rand(1, 400)
        return self.encoder(x).last_hidden_state.size(-1)


# download model checkpoint
# bad way to do this probably but oh well
if 'checkpoint.pt' not in os.listdir():
    checkpoint_url = "https://drive.google.com/uc?id=162jJ_YC4MGEfXBWvAK-kXnZcXX3v1smr"
    output = "checkpoint.pt"
    gdown.download(checkpoint_url, output, quiet=False)

model = Wav2Vec2Model()
model.to(device)

# load model checkpoint
for f in os.listdir():
    if '.pt' in f and 'checkpoint' in f:
        checkpoint = torch.load(f, map_location='cpu')
        model.load_state_dict(checkpoint)
        model.eval()
        print(f'Loaded state dict {f}')

def predict(audio_path):
    # get raw audio data
    try:
        a = lhotse.Recording.from_file(audio_path)
    except Exception:
        return (None, "Please wait a bit until the audio file has uploaded, then try again")
    a = a.resample(16000)
    # if multi channel, convert to single channel
    a = lhotse.cut.MultiCut(
        recording=a, start=0, duration=10, id="temp",
        channel=a.to_dict()['sources'][0]['channels']
    ).to_mono(mono_downmix=True)
    cuts = lhotse.CutSet(cuts={"cut": a})

    audio_data, audio_lens = lhotse.dataset.collation.collate_audio(cuts)

    # pass through model
    x, _ = model.forward(audio_data, audio_lens)
    print(x)

    # recover latitude/longitude (in radians) from the predicted point on the
    # unit sphere, then convert to degrees
    pred_lon = torch.atan2(x[:, 0], x[:, 1]).unsqueeze(-1)
    pred_lat = torch.asin(x[:, 2]).unsqueeze(-1)
    x_polar = torch.cat((pred_lat, pred_lon), dim=1).to(device)
    coords = x_polar.mul(180. / math.pi).cpu().detach().numpy()
    print(coords)

    coords = [[-lon, math.degrees(math.asin(math.sin(math.radians(lat))))] if lat > 90 else [lon, lat] for lat, lon in coords][0]  # wraparound fix (lat > 90)

    # create plot
    guesses = gv.Points([coords]).opts(
        size=8, cmap='Spectral_r', color='blue', fill_alpha=1
    )
    plot = (gts.OSM * guesses).options(
        gv.opts.Points(width=800, height=400, xlim=(-180*110000, 180*110000), ylim=(-90*140000, 90*140000), xaxis=None, yaxis=None)
    )
    filename = f"{str(uuid.uuid4())}.png"
    gv.save(plot, filename=filename, fmt='png')
    # round and reorder to (latitude, longitude) for display
    coords = [round(i, 2) for i in coords]
    coords = [coords[1], coords[0]]
    print(filename, coords)
    return (filename, str(coords)[1:-1])

gradio_app = gr.Interface(
    predict,
    inputs=gr.Audio(label="Record Audio (10 seconds)", type="filepath", min_length=10.0),
    outputs=[gr.Image(type="filepath", label="Map of Prediction"), gr.Textbox(placeholder="Latitude, Longitude", label="Prediction (Latitude, Longitude)")],
    title="Speech Geolocation Demo",
)

if __name__ == "__main__":
    gradio_app.launch()
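
For quick local testing outside the Gradio UI, one option is to call predict() directly. The sketch below assumes this file is importable as app and that a local recording of at least 10 seconds named sample.wav is available; both the module name and the file path are placeholders, not part of this commit.

# Minimal sketch: exercise the prediction pipeline without launching the web UI.
# 'sample.wav' is a hypothetical local recording, not shipped with this Space.
from app import predict

map_png, latlon = predict("sample.wav")
print("map image:", map_png)          # path to the PNG rendered by geoviews
print("predicted lat, lon:", latlon)  # e.g. "12.34, -56.78" (illustrative values)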