Spaces:
Runtime error
Runtime error
import pandas as pd | |
from langchain.document_loaders import PyPDFLoader | |
from langchain.text_splitter import CharacterTextSplitter | |
import torch | |
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor | |
from pathlib import Path | |
def make_descriptions(file, title):
    """Build the description records for a file.

    When ``file`` has a ``.csv`` extension (compared case-insensitively) it
    is read with pandas and two table descriptions are returned: a fixed
    placeholder table followed by the description of ``file`` itself. For
    any other extension a single file description is returned and the file
    is not opened.

    Args:
        file: Path of the file; anything ``pathlib.Path`` accepts.
        title: Title stored verbatim in the resulting description.

    Returns:
        A list of dicts. CSV input yields two dicts with the keys ``path``,
        ``number``, ``columns`` and ``title``; any other input yields one
        dict with the keys ``path``, ``number`` and ``title``.
    """
    # Compare the extension case-insensitively so that e.g. "data.CSV" is
    # handled as a table too instead of falling into the generic branch.
    if Path(file).suffix.lower() != '.csv':
        return [{'path': file, 'number': 1, 'title': title}]

    df = pd.read_csv(file)
    # Diagnostic output: preview of the data and the detected column names.
    print(df.head())
    columns = list(df.columns)
    print(columns)

    # NOTE(review): this fixed placeholder entry always precedes the real
    # table; presumably a consumer expects more than one table description
    # -- confirm against the caller.
    placeholder_table = {
        'path': 'random',
        'number': 1,
        'columns': ["clothes", "animals", "students"],
        'title': "fashionable student clothes",
    }
    uploaded_table = {
        'path': file,
        'number': 2,
        'columns': columns,
        'title': title,
    }
    return [placeholder_table, uploaded_table]
def make_documents(pdf, chunk_size=500, chunk_overlap=0, separator='\n'):
    """Load a PDF and split its content into chunks.

    The PDF is loaded with ``PyPDFLoader`` and the loaded documents are
    split with ``CharacterTextSplitter``.

    Args:
        pdf: Path of the PDF file, passed straight to ``PyPDFLoader``.
        chunk_size: ``chunk_size`` given to the text splitter.
        chunk_overlap: ``chunk_overlap`` given to the text splitter.
        separator: ``separator`` given to the text splitter.

    Returns:
        The result of ``CharacterTextSplitter.split_documents`` applied to
        the loaded documents.
    """
    loaded = PyPDFLoader(pdf).load()
    splitter = CharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separator=separator,
    )
    return splitter.split_documents(loaded)
class Matcha_model:
    """Chart question answering backed by a Pix2Struct checkpoint.

    The model and its processor are loaded once at construction time and
    moved to CUDA when it is available, otherwise they stay on the CPU.
    """

    def __init__(self, model_name: str = "google/matcha-chartqa") -> None:
        """Load the model and processor for ``model_name``.

        Args:
            model_name: Checkpoint identifier handed to ``from_pretrained``
                for both the model and the processor.
        """
        # Disabled downloads of example chart images, kept for reference.
        # torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/20294671002019.png', 'chart_example.png')
        # torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/test/png/multi_col_1081.png', 'chart_example_2.png')
        # torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/test/png/18143564004789.png', 'chart_example_3.png')
        # torch.hub.download_url_to_file('https://sharkcoder.com/files/article/matplotlib-bar-plot.png', 'chart_example_4.png')
        self.model_name = model_name
        self.model = Pix2StructForConditionalGeneration.from_pretrained(self.model_name)
        self.processor = Pix2StructProcessor.from_pretrained(self.model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def _filter_output(self, output):
        """Return ``output`` with every literal ``<0x0A>`` marker removed."""
        return output.replace("<0x0A>", "")

    def chart_qa(self, image, question: str, max_new_tokens: int = 512) -> str:
        """Answer ``question`` about the chart in ``image``.

        Args:
            image: Chart image, passed to the processor as ``images``.
            question: Question text, passed to the processor as ``text``.
            max_new_tokens: Generation limit forwarded to ``generate``.

        Returns:
            The first generated sequence, decoded without special tokens
            and with ``<0x0A>`` markers stripped.
        """
        inputs = self.processor(
            images=image, text=question, return_tensors="pt"
        ).to(self.device)
        predictions = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        decoded = self.processor.decode(predictions[0], skip_special_tokens=True)
        return self._filter_output(decoded)