Spaces:
Runtime error
Runtime error
"""Gmail source.""" | |
import base64 | |
import dataclasses | |
import os.path | |
import random | |
import re | |
from datetime import datetime | |
from time import sleep | |
from typing import TYPE_CHECKING, Any, Iterable, Optional | |
from pydantic import Field as PydanticField | |
from typing_extensions import override | |
from ..env import data_path | |
from ..schema import Item, field | |
from ..utils import log | |
from .source import Source, SourceSchema | |
if TYPE_CHECKING: | |
from google.oauth2.credentials import Credentials | |
# If modifying these scopes, delete the token json file. | |
_SCOPES = ['https://www.googleapis.com/auth/gmail.readonly'] | |
_GMAIL_CONFIG_DIR = os.path.join(data_path(), '.gmail') | |
_TOKEN_FILENAME = 'token.json' | |
_CREDS_FILENAME = 'credentials.json' | |
_NUM_RETRIES = 10 | |
_MAX_NUM_THREADS = 30_000 | |
_UNWRAP_PATTERN = re.compile(r'(\S)\n(\S)') | |
HTTP_PATTERN = re.compile(r'https?://[^\s]+') | |
class GmailSource(Source): | |
"""Connects to your Gmail and loads the text of your emails. | |
**One time setup** | |
Download the OAuth credentials file from the | |
[Google Cloud Console](https://console.cloud.google.com/apis/credentials) and save it to the | |
correct location. See | |
[guide](https://developers.google.com/gmail/api/quickstart/python#authorize_credentials_for_a_desktop_application) | |
for details. | |
""" | |
name = 'gmail' | |
credentials_file: str = PydanticField( | |
description='Path to the OAuth credentials file.', | |
default=os.path.join(_GMAIL_CONFIG_DIR, _CREDS_FILENAME)) | |
_creds: Optional['Credentials'] = None | |
class Config: | |
# Language is required even though it has a default value. | |
schema_extra = {'required': ['credentials_file']} | |
def setup(self) -> None: | |
try: | |
from google.auth.transport.requests import Request | |
from google.oauth2.credentials import Credentials | |
from google_auth_oauthlib.flow import InstalledAppFlow | |
except ImportError: | |
raise ImportError('Could not import dependencies for the "gmail" source. ' | |
'Please install with pip install lilac[gmail]') | |
# The token file stores the user's access and refresh tokens, and is created automatically when | |
# the authorization flow completes for the first time. | |
token_filepath = os.path.join(_GMAIL_CONFIG_DIR, _TOKEN_FILENAME) | |
if os.path.exists(token_filepath): | |
self._creds = Credentials.from_authorized_user_file(token_filepath, _SCOPES) | |
# If there are no (valid) credentials available, let the user log in. | |
if not self._creds or not self._creds.valid: | |
if self._creds and self._creds.expired and self._creds.refresh_token: | |
self._creds.refresh(Request()) | |
else: | |
if not os.path.exists(self.credentials_file): | |
raise ValueError( | |
f'Could not find the OAuth credentials file at "{self.credentials_file}". Make sure to ' | |
'download it from the Google Cloud Console and save it to the correct location.') | |
flow = InstalledAppFlow.from_client_secrets_file(self.credentials_file, _SCOPES) | |
self._creds = flow.run_local_server() | |
os.makedirs(os.path.dirname(token_filepath), exist_ok=True) | |
# Save the token for the next run. | |
with open(token_filepath, 'w') as token: | |
token.write(self._creds.to_json()) | |
def source_schema(self) -> SourceSchema: | |
return SourceSchema( | |
fields={ | |
'body': field('string'), | |
'snippet': field('string'), | |
'dates': field(fields=['string']), | |
'subject': field('string'), | |
}) | |
def process(self) -> Iterable[Item]: | |
try: | |
from email_reply_parser import EmailReplyParser | |
from googleapiclient.discovery import build | |
from googleapiclient.errors import HttpError | |
except ImportError: | |
raise ImportError('Could not import dependencies for the "gmail" source. ' | |
'Please install with pip install lilac[gmail]') | |
# Call the Gmail API | |
service = build('gmail', 'v1', credentials=self._creds) | |
# threads.list API | |
threads_resource = service.users().threads() | |
thread_batch: list[Item] = [] | |
retry_batch: set[str] = set() | |
num_retries = 0 | |
num_threads_fetched = 0 | |
def _thread_fetched(request_id: str, response: Any, exception: Optional[HttpError]) -> None: | |
if exception is not None: | |
retry_batch.add(request_id) | |
return | |
replies: list[str] = [] | |
dates: list[str] = [] | |
snippets: list[str] = [] | |
subject: Optional[str] = None | |
for msg in response['messages']: | |
epoch_sec = int(msg['internalDate']) / 1000. | |
date = datetime.fromtimestamp(epoch_sec).strftime('%Y-%m-%d %H:%M:%S') | |
dates.append(date) | |
if 'snippet' in msg: | |
snippets.append(msg['snippet']) | |
email_info = _parse_payload(msg['payload']) | |
subject = subject or email_info.subject | |
parsed_parts: list[str] = [] | |
for body in email_info.parts: | |
if not body: | |
continue | |
text = base64.urlsafe_b64decode(body).decode('utf-8') | |
text = EmailReplyParser.parse_reply(text) | |
# Unwrap text. | |
text = _UNWRAP_PATTERN.sub('\\1 \\2', text) | |
# Remove URLs. | |
text = HTTP_PATTERN.sub('', text) | |
if text: | |
parsed_parts.append(text) | |
if email_info.sender and parsed_parts: | |
parsed_parts = [ | |
f'--------------------{email_info.sender}--------------------', *parsed_parts | |
] | |
if parsed_parts: | |
replies.append('\n'.join(parsed_parts)) | |
if replies: | |
thread_batch.append({ | |
'body': '\n\n'.join(replies), | |
'snippet': '\n'.join(snippets) if snippets else None, | |
'dates': dates, | |
'subject': subject, | |
}) | |
if request_id in retry_batch: | |
retry_batch.remove(request_id) | |
# First request. | |
thread_list_req = threads_resource.list(userId='me', includeSpamTrash=False) or None | |
thread_list = thread_list_req.execute(num_retries=_NUM_RETRIES) if thread_list_req else None | |
while (num_threads_fetched < _MAX_NUM_THREADS and thread_list and thread_list_req): | |
batch = service.new_batch_http_request(callback=_thread_fetched) | |
threads = thread_list['threads'] if 'threads' in thread_list else [] | |
for gmail_thread in threads: | |
thread_id = gmail_thread['id'] if 'id' in gmail_thread else None | |
if not thread_id: | |
continue | |
if not retry_batch or (thread_id in retry_batch): | |
batch.add( | |
service.users().threads().get(userId='me', id=thread_id, format='full'), | |
request_id=thread_id) | |
batch.execute() | |
num_threads_fetched += len(thread_batch) | |
yield from thread_batch | |
thread_batch = [] | |
if retry_batch: | |
log(f'Failed to fetch {len(retry_batch)} threads. Retrying...') | |
timeout = 2**(num_retries - 1) + random.uniform(0, 1) | |
sleep(timeout) | |
num_retries += 1 | |
else: | |
retry_batch = set() | |
num_retries = 0 | |
# Fetch next page. | |
thread_list_req = threads_resource.list_next(thread_list_req, thread_list) | |
thread_list = thread_list_req.execute(num_retries=_NUM_RETRIES) if thread_list_req else None | |
class EmailInfo: | |
"""Stores parsed information about an email.""" | |
sender: Optional[str] = None | |
subject: Optional[str] = None | |
parts: list[bytes] = dataclasses.field(default_factory=list) | |
def _get_header(payload: Any, name: str) -> Optional[str]: | |
if 'headers' not in payload: | |
return None | |
values = [h['value'] for h in payload['headers'] if h['name'].lower().strip() == name] | |
return values[0] if values else None | |
def _parse_payload(payload: Any) -> EmailInfo: | |
sender = _get_header(payload, 'from') | |
subject = _get_header(payload, 'subject') | |
parts: list[bytes] = [] | |
# Process the message body. | |
if 'mimeType' in payload and 'text/plain' in payload['mimeType']: | |
if 'body' in payload and 'data' in payload['body']: | |
parts.append(payload['body']['data'].encode('ascii')) | |
# Process the message parts. | |
for part in payload.get('parts', []): | |
email_info = _parse_payload(part) | |
sender = sender or email_info.sender | |
subject = subject or email_info.subject | |
parts.extend(email_info.parts) | |
return EmailInfo(sender, subject, parts) | |