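"""Gradio Space that generates short TL;DR summaries of Hugging Face Hub model and
dataset cards using the davanstrien/Smol-Hub-tldr model."""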
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import ModelCard, DatasetCard, model_info, dataset_info
import logging
from typing import Tuple, Union
import spaces
from cachetools.func import ttl_cache
import os
import json

# Enable hf_transfer for faster downloads from the Hub
# (huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER; requires the hf_transfer package)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables
MODEL_NAME = "davanstrien/Smol-Hub-tldr"
model = None
tokenizer = None
device = None
CACHE_TTL = 6 * 60 * 60  # 6 hours in seconds
CACHE_MAXSIZE = 100


def load_model():
    global model, tokenizer, device
    logger.info("Loading model and tokenizer...")
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
        model = model.to(device)
        model.eval()
        return True
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        return False


def get_card_info(hub_id: str, repo_type: str = "auto") -> Tuple[str, Union[str, Tuple[str, str]]]:
    """Get card information from a Hugging Face hub_id."""
    model_exists = False
    dataset_exists = False
    model_text = None
    dataset_text = None
    # Handle based on repo type
    if repo_type == "auto":
        # Try getting model card (model_info raises if hub_id is not a model repo)
        try:
            info = model_info(hub_id)
            card = ModelCard.load(hub_id)
            model_exists = True
            model_text = card.text
        except Exception as e:
            logger.debug(f"No model card found for {hub_id}: {e}")
        # Try getting dataset card
        try:
            info = dataset_info(hub_id)
            card = DatasetCard.load(hub_id)
            dataset_exists = True
            dataset_text = card.text
        except Exception as e:
            logger.debug(f"No dataset card found for {hub_id}: {e}")
    elif repo_type == "model":
        try:
            info = model_info(hub_id)
            card = ModelCard.load(hub_id)
            model_exists = True
            model_text = card.text
        except Exception as e:
            logger.error(f"Failed to get model card for {hub_id}: {e}")
            raise ValueError(f"Could not find model with id {hub_id}")
    elif repo_type == "dataset":
        try:
            info = dataset_info(hub_id)
            card = DatasetCard.load(hub_id)
            dataset_exists = True
            dataset_text = card.text
        except Exception as e:
            logger.error(f"Failed to get dataset card for {hub_id}: {e}")
            raise ValueError(f"Could not find dataset with id {hub_id}")
    else:
        raise ValueError(f"Invalid repo_type: {repo_type}. Must be 'auto', 'model', or 'dataset'")
    # Handle different cases
    if model_exists and dataset_exists:
        return "both", (model_text, dataset_text)
    elif model_exists:
        return "model", model_text
    elif dataset_exists:
        return "dataset", dataset_text
    else:
        raise ValueError(f"Could not find model or dataset with id {hub_id}")
@spaces.GPU
def _generate_summary_gpu(card_text: str, card_type: str) -> str:
    """Internal function that runs on GPU."""
    # Determine prefix based on card type
    prefix = "<MODEL_CARD>" if card_type == "model" else "<DATASET_CARD>"
    # Format input according to the chat template
    messages = [{"role": "user", "content": f"{prefix}{card_text[:5000]}"}]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    inputs = inputs.to(device)
    # Generate with optimized settings
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=60,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            temperature=0.4,
            do_sample=True,
            use_cache=True,
        )
    # Extract and clean up the summary
    input_length = inputs.shape[1]
    response = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=False)
    # Extract just the summary part
    try:
        summary = response.split("<CARD_SUMMARY>")[-1].split("</CARD_SUMMARY>")[0].strip()
    except IndexError:
        summary = response.strip()
    return summary
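

# Summaries are cached for CACHE_TTL seconds per (card_text, card_type) pair, so
# repeated requests for the same card skip generation entirely.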
@ttl_cache(maxsize=CACHE_MAXSIZE, ttl=CACHE_TTL)
def generate_summary(card_text: str, card_type: str) -> str:
    """Cached wrapper around _generate_summary_gpu with a TTL."""
    return _generate_summary_gpu(card_text, card_type)


def summarize(hub_id: str = "", repo_type: str = "auto") -> str:
    """Interface function for Gradio. Returns a JSON string."""
    try:
        if hub_id:
            # Fetch card information with the specified repo_type
            card_type, card_text = get_card_info(hub_id, repo_type)
            if card_type == "both":
                model_text, dataset_text = card_text
                model_summary = generate_summary(model_text, "model")
                dataset_summary = generate_summary(dataset_text, "dataset")
                return json.dumps({
                    "type": "both",
                    "hub_id": hub_id,
                    "model_summary": model_summary,
                    "dataset_summary": dataset_summary,
                })
            else:
                summary = generate_summary(card_text, card_type)
                return json.dumps({
                    "summary": summary,
                    "type": card_type,
                    "hub_id": hub_id,
                })
        else:
            return json.dumps({"error": "Hub ID must be provided"})
    except Exception as e:
        return json.dumps({"error": str(e)})


def create_interface():
    interface = gr.Interface(
        fn=summarize,
        inputs=[
            gr.Textbox(label="Hub ID", placeholder="e.g., huggingface/llama-7b"),
            gr.Radio(
                choices=["auto", "model", "dataset"],
                value="auto",
                label="Repository Type",
                info="Choose 'auto' to detect automatically, or specify the repository type",
            ),
        ],
        outputs=gr.JSON(label="Output"),
        title="Hugging Face Hub TLDR Generator",
        description="Generate concise summaries of model and dataset cards from the Hugging Face Hub.",
    )
    return interface


if __name__ == "__main__":
    if load_model():
        interface = create_interface()
        interface.launch()
    else:
        print("Failed to load model. Please check the logs for details.")