import subprocess import jinja2 import gradio import matplotlib.pyplot as plt import numpy as np import base64 from io import BytesIO subprocess.run( ["curl", "--output", "checkpoint.pkl", "https://storage.googleapis.com/ithaca-resources/models/checkpoint_v1.pkl"]) #@article{asssome2022restoring, # title = {Restoring and attributing ancient texts using deep neural networks}, # author = {Assael*, Yannis and Sommerschield*, Thea and Shillingford, Brendan and Bordbar, Mahyar and Pavlopoulos, John and Chatzipanagiotou, Marita and Androutsopoulos, Ion and Prag, Jonathan and de Freitas, Nando}, # doi = {10.1038/s41586-022-04448-z}, # journal = {Nature}, # year = {2022} #} # Copyright 2021 the Ithaca Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Example for running inference. See also colab.""" import functools import pickle from ithaca.eval import inference from ithaca.models.model import Model from ithaca.util.alphabet import GreekAlphabet import jax def create_time_plot(attribution): class dataset_config: date_interval = 10 date_max = 800 date_min = -800 def bce_ad(d): if d < 0: return f'{abs(d)} BCE' elif d > 0: return f'{abs(d)} AD' return 0 #compute scores date_pred_y = np.array(attribution.year_scores) date_pred_x = np.arange( dataset_config.date_min + dataset_config.date_interval / 2, dataset_config.date_max + dataset_config.date_interval / 2, dataset_config.date_interval) date_pred_argmax = date_pred_y.argmax( ) * dataset_config.date_interval + dataset_config.date_min + dataset_config.date_interval // 2 date_pred_avg = np.dot(date_pred_y, date_pred_x) # Plot figure fig = plt.figure(figsize=(10, 5), dpi=100) plt.bar(date_pred_x, date_pred_y, color='#f2c852', width=10., label='Ithaca distribution') plt.axvline(x=date_pred_avg, color='#67ac5b', linewidth=2., label='Ithaca average') plt.ylabel('Probability', fontsize=14) yticks = np.arange(0, 1.1, 0.1) yticks_str = list(map(lambda x: f'{int(x*100)}%', yticks)) plt.yticks(yticks, yticks_str, fontsize=12, rotation=0) plt.ylim(0, int((date_pred_y.max()+0.1)*10)/10) plt.xlabel('Date', fontsize=14) xticks = list(range(dataset_config.date_min, dataset_config.date_max + 1, 25)) xticks_str = list(map(bce_ad, xticks)) plt.xticks(xticks, xticks_str, fontsize=12, rotation=0) plt.xlim(int(date_pred_avg - 100), int(date_pred_avg + 100)) plt.legend(loc='upper right', fontsize=12) #encode to base64 for html parsing tmpfile = BytesIO() fig.savefig(tmpfile, format='png') encoded = base64.b64encode(tmpfile.getvalue()).decode('utf-8') html = '
Input text: | {% for char in restoration_results.input_text -%} {%- if loop.index0 in prediction_idx -%} {{char}} {%- else -%} {{char}} {%- endif -%} {%- endfor %} | |
Hypothesis {{ loop.index }}: | {{ "%.1f%%"|format(100 * pred.score) }} | {% for char in pred.text -%} {%- if loop.index0 in prediction_idx -%} {{char}} {%- else -%} {{char}} {%- endif -%} {%- endfor %} |