import re
import json
import os
from typing import Generator

import nltk
import spacy
import uvicorn
import google.protobuf  # Imports cleanly only if protobuf is installed correctly
import sentencepiece  # Install check: required by several transformers tokenizers
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from huggingface_hub import InferenceClient
from pydantic import BaseModel
from transformers import pipeline


nltk.data.path.append(os.getenv('NLTK_DATA'))

app = FastAPI()

# Initialize the InferenceClient with your model
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")


class Item(BaseModel):
    prompt: str
    history: list
    system_prompt: str
    temperature: float = 0.8
    max_new_tokens: int = 4000
    top_p: float = 0.15
    repetition_penalty: float = 1.0

def format_prompt(current_prompt, history):
    # Mixtral-Instruct is trained on the [INST] ... [/INST] chat template, so
    # format the history accordingly rather than with ad-hoc role tags.
    formatted_history = "<s>"
    for entry in history:
        if entry["role"] == "user":
            formatted_history += f"[INST] {entry['content']} [/INST]"
        elif entry["role"] == "assistant":
            formatted_history += f"{entry['content']}</s>"
    formatted_history += f"[INST] {current_prompt} [/INST]"
    return formatted_history
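# Illustrative sketch of the output: with one prior exchange,
#   format_prompt("And German?", [{"role": "user", "content": "Translate 'hi' to French"},
#                                 {"role": "assistant", "content": "salut"}])
# yields "<s>[INST] Translate 'hi' to French [/INST]salut</s>[INST] And German? [/INST]"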


def generate_stream(item: Item) -> Generator[bytes, None, None]:
    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
    # Roughly estimate the prompt's token count (NLTK word tokens as a proxy for model tokens)
    input_token_count = len(nltk.word_tokenize(formatted_prompt))

    # Ensure total token count doesn't exceed the maximum limit
    max_tokens_allowed = 32768
    max_new_tokens_adjusted = max(1, min(item.max_new_tokens, max_tokens_allowed - input_token_count))
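    # Worked example of the budget: with input_token_count = 30000 and
    # item.max_new_tokens = 4000, this yields max(1, min(4000, 32768 - 30000)) = 2768.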

    generate_kwargs = {
        "temperature": item.temperature,
        "max_new_tokens": max_new_tokens_adjusted,
        "top_p": item.top_p,
        "repetition_penalty": item.repetition_penalty,
        "do_sample": True,
        "seed": 42,
    }

    # Stream the response from the InferenceClient
    for response in client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True):
        # With details=True, each streamed response exposes token.text; generated_text is set only on the final chunk
        chunk = {
            "text": response.token.text,
            "complete": response.generated_text is not None  # Adjust based on how you detect completion
        }
        yield json.dumps(chunk).encode("utf-8") + b"\n"
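# Each yielded line is one NDJSON record, e.g. {"text": " Hello", "complete": false};
# the final chunk switches to "complete": true once generated_text is populated.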


class SummarizeRequest(BaseModel):
    text: str

@app.post("/generate/")
async def generate_text(item: Item):
    # Stream response back to the client
    return StreamingResponse(generate_stream(item), media_type="application/x-ndjson")
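# Example request (a sketch; host/port depend on your deployment, see uvicorn.run below):
#   curl -N -X POST http://localhost:8000/generate/ \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello", "history": [], "system_prompt": "You are a helpful assistant"}'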



# Load spaCy model
nlp = spacy.load("en_core_web_sm")

class TextRequest(BaseModel):
    text: str

def preprocess_text(text: str) -> str:
    # Normalize whitespace; keep sentence-ending punctuation so spaCy can still
    # split sentences downstream in segment_text and reduce_tokens
    text = re.sub(r'\s+', ' ', text.strip())
    text = re.sub(r'[^\w\s.!?]', '', text)
    return text
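# Example: preprocess_text("Hello,   world! \n Bye.") -> "Hello world! Bye."
# (commas and other non-terminal punctuation are dropped, whitespace collapsed)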

def reduce_tokens(text: str):
    # Process the text with spaCy
    doc = nlp(text)
    # Select sentences that might be more important - this is a simple heuristic
    important_sentences = []
    for sent in doc.sents:
        if any(tok.dep_ == 'ROOT' for tok in sent):
            important_sentences.append(sent.text)
    # Join selected sentences to form the reduced text
    reduced_text = ' '.join(important_sentences)
    # Re-tokenize the reduced text to count its tokens
    reduced_doc = nlp(reduced_text)
    token_count = len(reduced_doc)
    return reduced_text, token_count
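# Note: spaCy's parser assigns a ROOT token to nearly every sentence, so this
# heuristic keeps most of the input and reduces little; treat it as a placeholder
# for a stronger importance filter (e.g. scoring sentences before selection).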

def segment_text(text: str, max_tokens=500):  # Setting a conservative limit below 512
    doc = nlp(text)
    segments = []
    current_segment = []
    current_length = 0

    for sent in doc.sents:
        sentence = sent.text.strip()
        sentence_length = len(sentence.split())  # Counting words for simplicity

        if sentence_length > max_tokens:
            # Split long sentences into smaller chunks if a single sentence exceeds max_tokens
            words = sentence.split()
            while words:
                part = ' '.join(words[:max_tokens])
                segments.append(part)
                words = words[max_tokens:]
        elif current_length + sentence_length > max_tokens:
            segments.append(' '.join(current_segment))
            current_segment = [sentence]
            current_length = sentence_length
        else:
            current_segment.append(sentence)
            current_length += sentence_length

    if current_segment:  # Add the last segment
        segments.append(' '.join(current_segment))

    return segments
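# Example: with max_tokens=5, segment_text("One two three. Four five six seven.")
# returns ["One two three.", "Four five six seven."]; sentences are packed whole,
# and only a single sentence longer than max_tokens is split mid-sentence.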


classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")


def classify_segments(segments):
    # Reuse the module-level classifier rather than reloading the model on every call
    classified_segments = []
    
    for segment in segments:
        try:
            if len(segment.split()) <= 512:  # Rough word-count guard for the model's 512-token limit
                result = classifier(segment, truncation=True)  # Truncate defensively at the tokenizer level
                classified_segments.append(result)
            else:
                classified_segments.append({"error": f"Segment too long: {len(segment.split())} tokens"})
        except Exception as e:
            classified_segments.append({"error": str(e)})

    return classified_segments
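# Example element of the returned list, per transformers' text-classification output:
#   [{"label": "POSITIVE", "score": 0.9998}]
# Error entries instead look like {"error": "..."}.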


@app.post("/process_document")
async def process_document(request: TextRequest):
    try:
        processed_text = preprocess_text(request.text)
        segments = segment_text(processed_text)
        classified_segments = classify_segments(segments)

        return {
            "classified_segments": classified_segments
        }
    except Exception as e:
        print(f"Error during document processing: {e}")
        raise HTTPException(status_code=500, detail=str(e))



@app.post("/summarize")
async def summarize(request: SummarizeRequest):
    try:
        # Preprocess and segment the text
        processed_text = preprocess_text(request.text)
        segments = segment_text(processed_text)

        # Classify each segment safely
        classified_segments = []
        for segment in segments:
            try:
                result = classifier(segment, truncation=True)  # Truncate to the model's 512-token limit
                classified_segments.append(result)
            except Exception as e:
                print(f"Error classifying segment: {e}")
                classified_segments.append({"error": str(e)})

        # Optional: Reduce tokens or summarize
        reduced_texts = []
        for segment in segments:
            try:
                reduced_text, token_count = reduce_tokens(segment)
                reduced_texts.append((reduced_text, token_count))
            except Exception as e:
                print(f"Error during token reduction: {e}")
                reduced_texts.append(("Error", 0))

        return {
            "classified_segments": classified_segments,
            "reduced_texts": reduced_texts
        }
    
    except Exception as e:
        print(f"Error during token reduction: {e}")
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
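
# Minimal streaming-client sketch (not part of the app; assumes the `requests`
# package is installed and the server is running on localhost:8000):
#
#   import requests, json
#   payload = {"prompt": "Hi", "history": [], "system_prompt": "You are helpful"}
#   with requests.post("http://localhost:8000/generate/", json=payload, stream=True) as r:
#       for line in r.iter_lines():
#           if line:
#               chunk = json.loads(line)
#               print(chunk["text"], end="", flush=True)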