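"""Gradio app that generates TL;DR summaries of Hugging Face Hub model and dataset cards.

Uses the davanstrien/Smol-Hub-tldr model and caches generated summaries for six hours.
"""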
import json
import logging
import os
from typing import Tuple, Union

# Enable hf_transfer for faster downloads; must be set before importing huggingface_hub
# (or libraries that import it), since the flag is read at import time.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import gradio as gr
import spaces
import torch
from cachetools.func import ttl_cache
from huggingface_hub import DatasetCard, ModelCard, dataset_info, model_info
from transformers import AutoModelForCausalLM, AutoTokenizer
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables
MODEL_NAME = "davanstrien/Smol-Hub-tldr"
model = None
tokenizer = None
device = None
CACHE_TTL = 6 * 60 * 60  # 6 hours in seconds
CACHE_MAXSIZE = 100

def load_model():
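    """Load the summarization model and tokenizer onto the available device."""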
    global model, tokenizer, device
    logger.info("Loading model and tokenizer...")
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
        model = model.to(device)
        model.eval()
        return True
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        return False

def get_card_info(hub_id: str, repo_type: str = "auto") -> Tuple[str, Union[str, Tuple[str, str]]]:
    """Return (card_type, card_text) for a Hugging Face hub_id; card_type is "model",
    "dataset", or "both" (in which case card_text is a (model_text, dataset_text) tuple)."""
    model_exists = False
    dataset_exists = False
    model_text = None
    dataset_text = None

    # Handle based on repo type
    if repo_type == "auto":
        # Try getting model card
        try:
            info = model_info(hub_id)
            card = ModelCard.load(hub_id)
            model_exists = True
            model_text = card.text
        except Exception as e:
            logger.debug(f"No model card found for {hub_id}: {e}")

        # Try getting dataset card
        try:
            info = dataset_info(hub_id)
            card = DatasetCard.load(hub_id)
            dataset_exists = True
            dataset_text = card.text
        except Exception as e:
            logger.debug(f"No dataset card found for {hub_id}: {e}")
    elif repo_type == "model":
        try:
            info = model_info(hub_id)
            card = ModelCard.load(hub_id)
            model_exists = True
            model_text = card.text
        except Exception as e:
            logger.error(f"Failed to get model card for {hub_id}: {e}")
            raise ValueError(f"Could not find model with id {hub_id}")
    elif repo_type == "dataset":
        try:
            info = dataset_info(hub_id)
            card = DatasetCard.load(hub_id)
            dataset_exists = True
            dataset_text = card.text
        except Exception as e:
            logger.error(f"Failed to get dataset card for {hub_id}: {e}")
            raise ValueError(f"Could not find dataset with id {hub_id}")
    else:
        raise ValueError(f"Invalid repo_type: {repo_type}. Must be 'auto', 'model', or 'dataset'")

    # Handle different cases
    if model_exists and dataset_exists:
        return "both", (model_text, dataset_text)
    elif model_exists:
        return "model", model_text
    elif dataset_exists:
        return "dataset", dataset_text
    else:
        raise ValueError(f"Could not find model or dataset with id {hub_id}")

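# The GPU-bound generation step is kept separate from the cached wrapper below, so that
# @spaces.GPU (which requests a GPU slot on Hugging Face ZeroGPU Spaces) only wraps the
# actual inference call.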
@spaces.GPU
def _generate_summary_gpu(card_text: str, card_type: str) -> str:
    """Internal function that runs on GPU."""
    # Determine prefix based on card type
    prefix = "<MODEL_CARD>" if card_type == "model" else "<DATASET_CARD>"

    # Format input according to the chat template
    messages = [{"role": "user", "content": f"{prefix}{card_text[:5000]}"}]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    inputs = inputs.to(device)

    # Generate with optimized settings
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=60,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            temperature=0.4,
            do_sample=True,
            use_cache=True,
        )

    # Extract and clean up the summary
    input_length = inputs.shape[1]
    response = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=False)

    # Extract just the summary between the <CARD_SUMMARY> tags; fall back to the raw
    # response if the model did not emit the tags
    if "<CARD_SUMMARY>" in response:
        summary = response.split("<CARD_SUMMARY>", 1)[1].split("</CARD_SUMMARY>", 1)[0].strip()
    else:
        summary = response.strip()

    return summary

@ttl_cache(maxsize=CACHE_MAXSIZE, ttl=CACHE_TTL)
def generate_summary(card_text: str, card_type: str) -> str:
    """Cached wrapper for generate_summary with TTL."""
    return _generate_summary_gpu(card_text, card_type)

def summarize(hub_id: str = "", repo_type: str = "auto") -> str:
    """Interface function for Gradio. Returns JSON format."""
    try:
        if hub_id:
            # Fetch card information with specified repo_type
            card_type, card_text = get_card_info(hub_id, repo_type)
            
            if card_type == "both":
                model_text, dataset_text = card_text
                model_summary = generate_summary(model_text, "model")
                dataset_summary = generate_summary(dataset_text, "dataset")
                return json.dumps({
                    "type": "both",
                    "hub_id": hub_id,
                    "model_summary": model_summary,
                    "dataset_summary": dataset_summary
                })
            else:
                summary = generate_summary(card_text, card_type)
                return json.dumps({
                    "summary": summary,
                    "type": card_type,
                    "hub_id": hub_id
                })
        else:
            return json.dumps({"error": "Hub ID must be provided"})

    except Exception as e:
        return json.dumps({"error": str(e)})

def create_interface():
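    """Build the Gradio interface that exposes the summarizer."""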
    interface = gr.Interface(
        fn=summarize,
        inputs=[
            gr.Textbox(label="Hub ID", placeholder="e.g., huggingface/llama-7b"),
            gr.Radio(
                choices=["auto", "model", "dataset"], 
                value="auto", 
                label="Repository Type",
                info="Choose 'auto' to detect automatically, or specify the repository type"
            )
        ],
        outputs=gr.JSON(label="Output"),
        title="Hugging Face Hub TLDR Generator",
        description="Generate concise summaries of model and dataset cards from the Hugging Face Hub.",
    )
    return interface

if __name__ == "__main__":
    if load_model():
        interface = create_interface()
        interface.launch()
    else:
        print("Failed to load model. Please check the logs for details.")