|
import math

import gradio as gr
import torch
|
|
|
|
|
EXAMPLE_MD = """ |
|
```python |
|
import torch |
|
|
|
t1 = torch.arange({n1}).view({dim1}) |
|
|
|
t2 = torch.arange({n2}).view({dim2}) |
|
|
|
(t1 @ t2).shape = {out_shape} |
|
|
|
``` |
|
|
|
""" |
|
|
|
|
|
def generate_example(dim1: list, dim2: list): |
|
n1 = 1 |
|
n2 = 1 |
|
for i in dim1: |
|
n1 *= i |
|
for i in dim2: |
|
n2 *= i |
|
|
|
t1 = torch.arange(n1).view(dim1) |
|
t2 = torch.arange(n2).view(dim2) |
|
try: |
|
out_shape = list((t1 @ t2).shape) |
|
except RuntimeError: |
|
out_shape = "error" |
|
|
|
code = EXAMPLE_MD.format( |
|
n1=str(n1), dim1=str(dim1), n2=str(n2), dim2=str(dim2), out_shape=str(out_shape) |
|
) |
|
|
|
return dim1, dim2, code |
|
|
|
|
|
def sanitize_dimention(dim):
    """Parse a user-typed shape such as "2,3", "[2, 3]" or "5" into a list of ints.

    Square brackets and parentheses are ignored, and commas and whitespace are
    both accepted as separators, so "(2, 3)" and "2 3" parse too.

    Args:
        dim: raw text from the Gradio textbox (may be None or empty).

    Returns:
        list[int]: the parsed, strictly positive dimensions.

    Raises:
        gr.Error: if the input is empty, contains a token that is not an
            integer, or contains a zero or negative dimension. The error has
            to be *raised* (not just constructed) for Gradio to show it.
    """
    empty_message = "one of the dimentions is empty, please fill it"
    if dim is None:
        raise gr.Error(empty_message)

    cleaned = dim
    for bracket in "[]()":
        cleaned = cleaned.replace(bracket, " ")
    tokens = cleaned.replace(",", " ").split()
    if not tokens:
        raise gr.Error(empty_message)

    try:
        out = [int(token) for token in tokens]
    except ValueError:
        raise gr.Error(
            f"could not parse {dim!r} as a shape, use integers separated by commas e.g. 2,3,4"
        ) from None

    if 0 in out:
        raise gr.Error(
            "Found the number 0 in one of the dimensions which is not allowed, consider using 1 instead"
        )
    if any(size < 0 for size in out):
        raise gr.Error(f"Found a negative number in {dim!r}, dimensions must be positive integers")

    return out
|
|
|
|
|
def create_row(dim):
    """Return one markdown table row with a cell for every entry of ``dim``.

    The row starts with "| ", each value is followed by " | ", and the row is
    terminated by a newline, e.g. ``[2, 3]`` -> ``"| 2 | 3 | \\n"``.
    """
    cells = "".join(str(value) + " | " for value in dim)
    return "| " + cells + "\n"
|
|
|
|
|
def create_header(n_dim, checks=None):
    """Return a markdown table header line plus its ``|---|`` separator line.

    Args:
        n_dim: number of columns in the separator line.
        checks: optional list of strings used as the header cells. When
            omitted, every cell is the HTML comment "<!-- -->" (presumably so
            the header row renders blank -- markdown tables require one).

    Returns:
        str: the two lines, each ending with a newline.
    """
    labels = ["<!-- -->"] * n_dim if checks is None else checks
    title_line = "| " + "".join(label + " | " for label in labels)
    separator_line = "|---" * n_dim + "|"
    return f"{title_line}\n{separator_line}\n"
|
|
|
|
|
def generate_table(dim1, dim2, checks=None):
    """Render the two shapes as a two-row markdown table.

    The column count is taken from ``dim1``; callers pass shapes that have
    already been aligned to the same length. ``checks`` (optional strings)
    become the header cells.
    """
    header = create_header(len(dim1), checks)
    first_row = create_row(dim1)
    second_row = create_row(dim2)
    return header + first_row + second_row
|
|
|
|
|
def alignment_and_fill_with_ones(dim1, dim2):
    """Align two shapes the way ``torch.matmul`` does before broadcasting.

    * A 1-D second operand is treated as a column vector: torch.matmul
      *appends* a 1 to it (``[n] -> [n, 1]``) and drops it from the result.
    * The shorter shape is then left-padded with ones so both have the same
      rank. For a 1-D first operand this padding is exactly the 1 that
      torch.matmul prepends.

    Args:
        dim1: shape of the left operand (list of ints).
        dim2: shape of the right operand (list of ints).

    Returns:
        tuple[list, list]: the aligned shapes, of equal length. The inputs are
        never mutated; a shape that needs no change is returned as-is.
    """
    if len(dim2) == 1:
        # Appending (not prepending) is what makes e.g. [2, 3] @ [3] valid.
        dim2 = list(dim2) + [1]

    n_dim = max(len(dim1), len(dim2))
    if len(dim1) < n_dim:
        dim1 = [1] * (n_dim - len(dim1)) + list(dim1)
    if len(dim2) < n_dim:
        dim2 = [1] * (n_dim - len(dim2)) + list(dim2)
    return dim1, dim2
|
|
|
def check_validity(dim1, dim2):
    """Mark each column of two aligned shapes as valid ("V") or not ("X").

    * Batch columns (all but the last two) get "V" when the sizes are equal.
    * The last two columns share one verdict: "V" when the last size of
      ``dim1`` equals the second-to-last size of ``dim2`` (the inner matmul
      dimension), "X" otherwise.
    * Shapes with fewer than two dimensions get "WIP" in every column.

    Returns:
        list[str]: one marker per column of ``dim1``.
    """
    if len(dim1) < 2:
        return ["WIP"] * len(dim1)

    marks = ["V" if size == dim2[col] else "X" for col, size in enumerate(dim1[:-2])]
    inner_match = dim1[-1] == dim2[-2]
    marks += ["V", "V"] if inner_match else ["X", "X"]
    return marks
|
|
|
|
|
def substitute_ones_with_concat(dim1, dim2):
    """Resolve broadcasting in the batch (all but the last two) dimensions.

    Wherever one shape has a 1 in a batch position it is replaced by the other
    shape's size there, since that is the size broadcasting stretches it to.
    The last two (matrix) dimensions are left untouched.

    Args:
        dim1, dim2: shapes expected to be aligned to the same length.

    Returns:
        tuple[list, list]: new lists with the substitution applied. Unlike
        earlier versions, the caller's lists are not modified in place.
    """
    new1, new2 = list(dim1), list(dim2)
    for i in range(len(new1) - 2):
        if new1[i] == 1:
            new1[i] = new2[i]
        # Reads the possibly-updated new1[i]; if both were 1 it stays 1.
        if new2[i] == 1:
            new2[i] = new1[i]
    return new1, new2
|
|
|
def predict(dim1, dim2):
    """Gradio handler: explain step by step whether ``t1 @ t2`` is valid.

    Args:
        dim1: text typed for the left operand's shape.
        dim2: text typed for the right operand's shape.

    Returns:
        str: markdown made of the example snippet followed by three tables
        (aligned shapes, shapes after resolving ones, per-column verdicts).
    """
    shape_a, shape_b = sanitize_dimention(dim1), sanitize_dimention(dim2)
    shape_a, shape_b, code = generate_example(shape_a, shape_b)

    shape_a, shape_b = alignment_and_fill_with_ones(shape_a, shape_b)
    aligned_table = generate_table(shape_a, shape_b)

    shape_a, shape_b = substitute_ones_with_concat(shape_a, shape_b)
    broadcast_table = generate_table(shape_a, shape_b)
    verdict_table = generate_table(shape_a, shape_b, check_validity(shape_a, shape_b))

    sections = [
        code,
        "\n# Step1 (alignment and pre_append with ones)\n",
        aligned_table,
        "\n# Step2 (susbtitute columns that have 1 with concat)\nexcept for last 2 dimensions\n",
        broadcast_table,
        "\n# Step3 (check if matrix multiplication is valid)\n",
        "* last dimension of dim1 should equal before last dimension of dim2\n",
        "* all the other dimensions should be equal to one another\n\n",
        verdict_table,
    ]
    return "".join(sections)
|
|
|
|
|
# Two free-text shape inputs, rendered markdown output, and two clickable
# examples: a broadcastable batch case and a mismatched inner dimension.
demo = gr.Interface(
    fn=predict,
    inputs=["text", "text"],
    outputs=["markdown"],
    examples=[
        ["9,2,1,3,3", "5,3,7"],
        ["1,2,3", "5,2,7"],
    ],
)

demo.launch(debug=True)
|
|