import re
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import uvicorn
from typing import Generator
import json  # Make sure this import stays at the top of the file
import nltk
import os
import google.protobuf # This line should execute without errors if protobuf is installed correctly
import sentencepiece
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import spacy
if os.getenv('NLTK_DATA'):  # Only extend the NLTK search path when the env var is set
    nltk.data.path.append(os.getenv('NLTK_DATA'))
app = FastAPI()
# Initialize the InferenceClient with your model
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
class Item(BaseModel):
    prompt: str
    history: list
    system_prompt: str
    temperature: float = 0.8
    max_new_tokens: int = 4000
    top_p: float = 0.15
    repetition_penalty: float = 1.0
def format_prompt(current_prompt, history):
    formatted_history = "<s>"
    for entry in history:
        if entry["role"] == "user":
            formatted_history += f"[USER] {entry['content']} [/USER]"
        elif entry["role"] == "assistant":
            formatted_history += f"[ASSISTANT] {entry['content']} [/ASSISTANT]"
    formatted_history += f"[USER] {current_prompt} [/USER]</s>"
    return formatted_history
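# For example, with history [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
# and current_prompt "How are you?", format_prompt returns:
# "<s>[USER] Hi [/USER][ASSISTANT] Hello! [/ASSISTANT][USER] How are you? [/USER]</s>"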
def generate_stream(item: Item) -> Generator[bytes, None, None]:
    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
    # Estimate the input token count for the formatted prompt (NLTK word tokenization)
    input_token_count = len(nltk.word_tokenize(formatted_prompt))
    # Ensure the combined input and output length stays within the model's context window
    max_tokens_allowed = 32768
    max_new_tokens_adjusted = max(1, min(item.max_new_tokens, max_tokens_allowed - input_token_count))
    generate_kwargs = {
        "temperature": item.temperature,
        "max_new_tokens": max_new_tokens_adjusted,
        "top_p": item.top_p,
        "repetition_penalty": item.repetition_penalty,
        "do_sample": True,
        "seed": 42,
    }
    # Stream the response from the InferenceClient as newline-delimited JSON chunks
    for response in client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True):
        # With details=True each streamed chunk exposes the generated token;
        # generated_text is only populated on the final chunk, which signals completion
        chunk = {
            "text": response.token.text,
            "complete": response.generated_text is not None,
        }
        yield json.dumps(chunk).encode("utf-8") + b"\n"
class SummarizeRequest(BaseModel):
    text: str
@app.post("/generate/")
async def generate_text(item: Item):
# Stream response back to the client
return StreamingResponse(generate_stream(item), media_type="application/x-ndjson")
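# A minimal sketch of how a client could consume this NDJSON stream, kept in comments
# so it does not run on import; the host, port, and payload values are assumptions for
# illustration only:
#
#     import requests, json
#     payload = {"prompt": "Hello", "history": [], "system_prompt": "You are a helpful assistant."}
#     with requests.post("http://localhost:8000/generate/", json=payload, stream=True) as resp:
#         for line in resp.iter_lines():
#             if line:
#                 chunk = json.loads(line)
#                 print(chunk["text"], end="", flush=True)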
# Load spaCy model
nlp = spacy.load("en_core_web_sm")
class TextRequest(BaseModel):
    text: str
def preprocess_text(text: str) -> str:
    # Normalize whitespace and strip punctuation
    text = re.sub(r'\s+', ' ', text.strip())
    text = re.sub(r'[^\w\s]', '', text)
    return text
def reduce_tokens(text: str):
    # Process the text with spaCy
    doc = nlp(text)
    # Keep sentences that contain a syntactic root - a simple importance heuristic
    important_sentences = []
    for sent in doc.sents:
        if any(tok.dep_ == 'ROOT' for tok in sent):
            important_sentences.append(sent.text)
    # Join the selected sentences to form the reduced text
    reduced_text = ' '.join(important_sentences)
    # Re-tokenize the reduced text to count its tokens
    reduced_doc = nlp(reduced_text)
    token_count = len(reduced_doc)
    return reduced_text, token_count
def segment_text(text: str, max_tokens=500):  # Conservative limit below the classifier's 512-token maximum
    doc = nlp(text)
    segments = []
    current_segment = []
    current_length = 0
    for sent in doc.sents:
        sentence = sent.text.strip()
        sentence_length = len(sentence.split())  # Counting words for simplicity
        if sentence_length > max_tokens:
            # Flush what has been accumulated so far to preserve segment order
            if current_segment:
                segments.append(' '.join(current_segment))
                current_segment = []
                current_length = 0
            # Split a single over-long sentence into smaller chunks
            words = sentence.split()
            while words:
                part = ' '.join(words[:max_tokens])
                segments.append(part)
                words = words[max_tokens:]
        elif current_length + sentence_length > max_tokens:
            segments.append(' '.join(current_segment))
            current_segment = [sentence]
            current_length = sentence_length
        else:
            current_segment.append(sentence)
            current_length += sentence_length
    if current_segment:  # Add the last segment
        segments.append(' '.join(current_segment))
    return segments
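# For instance, a 1,200-word text made of short sentences would come back as roughly
# three segments of at most 500 words each, sized to fit the classifier used below.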
classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
def classify_segments(segments):
    # Reuse the module-level classifier rather than reloading the pipeline on every call
    classified_segments = []
    for segment in segments:
        try:
            if len(segment.split()) <= 512:  # Double-check the length to avoid model errors
                result = classifier(segment)
                classified_segments.append(result)
            else:
                classified_segments.append({"error": f"Segment too long: {len(segment.split())} tokens"})
        except Exception as e:
            classified_segments.append({"error": str(e)})
    return classified_segments
@app.post("/process_document")
async def process_document(request: TextRequest):
try:
processed_text = preprocess_text(request.text)
segments = segment_text(processed_text)
classified_segments = classify_segments(segments)
return {
"classified_segments": classified_segments
}
except Exception as e:
print(f"Error during document processing: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/summarize")
async def summarize(request: TextRequest):
try:
# Preprocess and segment the text
processed_text = preprocess_text(request.text)
segments = segment_text(processed_text)
# Classify each segment safely
classified_segments = []
for segment in segments:
try:
result = classifier(segment)
classified_segments.append(result)
except Exception as e:
print(f"Error classifying segment: {e}")
classified_segments.append({"error": str(e)})
# Optional: Reduce tokens or summarize
reduced_texts = []
for segment in segments:
try:
reduced_text, token_count = reduce_tokens(segment)
reduced_texts.append((reduced_text, token_count))
except Exception as e:
print(f"Error during token reduction: {e}")
reduced_texts.append(("Error", 0))
return {
"classified_segments": classified_segments,
"reduced_texts": reduced_texts
}
except Exception as e:
print(f"Error during token reduction: {e}")
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
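# Hedged usage sketch for the document-processing endpoints (the routes match the
# handlers defined above; the example text is illustrative only):
#
#     curl -X POST http://localhost:8000/process_document \
#          -H "Content-Type: application/json" \
#          -d '{"text": "FastAPI makes it easy to build APIs. spaCy handles the NLP."}'
#
#     curl -X POST http://localhost:8000/summarize \
#          -H "Content-Type: application/json" \
#          -d '{"text": "FastAPI makes it easy to build APIs. spaCy handles the NLP."}'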