File size: 2,032 Bytes
76aeebf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torch.nn as nn
import fasttext

class SimpleMultilingualClassifier(nn.Module):
    """Linear classifier on top of per-language fastText sentence vectors.

    One fastText model is loaded for every language code in
    ``embedding_files``. A single ``nn.Linear`` head, shared by all
    languages, maps a sentence vector to ``num_classes`` logits.

    Args:
        embedding_files: Mapping of language code to the path of a fastText
            binary model.
        num_classes: Number of outputs of the linear head.
        embedding_dim: Size of the sentence vectors fed to the linear head.
            Every loaded model must report exactly this dimension.

    Raises:
        ValueError: If a loaded model's dimension differs from
            ``embedding_dim``.
    """

    def __init__(self, embedding_files, num_classes: int, embedding_dim: int = 100):
        super().__init__()
        self.embedding_files = embedding_files
        self.embedding_dim = embedding_dim
        self.linear = nn.Linear(embedding_dim, num_classes)
        # fastText models are not nn.Modules; they are kept in a plain dict.
        self.language_models = {}
        for lang, path in embedding_files.items():
            model = fasttext.load_model(path)
            # Fail at construction time rather than with an opaque shape
            # error on the first forward pass.
            dim = model.get_dimension()
            if dim != embedding_dim:
                raise ValueError(
                    f"Embedding model for '{lang}' has dimension {dim}, "
                    f"expected {embedding_dim}."
                )
            self.language_models[lang] = model

    def get_embedding(self, text: str, lang: str) -> torch.Tensor:
        """Return the fastText sentence vector of ``text`` as a tensor.

        Newlines in ``text`` are replaced by spaces before the lookup,
        because fastText only accepts single-line input.

        Raises:
            ValueError: If no model was loaded for ``lang``.
        """
        model = self.language_models.get(lang)
        if model is None:
            raise ValueError(f"Language '{lang}' not supported.")
        single_line = text.replace("\n", " ")
        return torch.tensor(model.get_sentence_vector(single_line))

    def forward(self, text: str, lang: str) -> torch.Tensor:
        """Return the class logits for a single ``text`` in language ``lang``."""
        embedding = self.get_embedding(text, lang)
        # The embedding is created on the CPU with fastText's dtype; align it
        # with the linear layer so the module still works after .to(device),
        # .double() or .half().
        weight = self.linear.weight
        embedding = embedding.to(device=weight.device, dtype=weight.dtype)
        return self.linear(embedding)

    def predict(self, text: str, lang: str, class_labels):
        """Return the entry of ``class_labels`` with the highest score.

        The module is switched to eval mode for the prediction and its
        previous train/eval mode is restored afterwards.

        Args:
            text: Text to classify.
            lang: Language code selecting the embedding model.
            class_labels: Sequence indexed by the predicted class index.

        Raises:
            ValueError: If no model was loaded for ``lang``.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Call the module itself (not .forward) so hooks still run.
                logits = self(text, lang)
            # Softmax is monotonic, so the argmax of the logits is the
            # predicted class.
            predicted_class_index = torch.argmax(logits, dim=-1).item()
            return class_labels[predicted_class_index]
        finally:
            self.train(was_training)

# Example usage (you'd need to define your classes and supported languages)
if __name__ == '__main__':
    # Demo run: the linear head is never trained here, so the printed labels
    # only show that the pipeline executes end to end.
    class_labels = ["positive", "negative", "neutral"]
    num_classes = len(class_labels)

    # One fastText binary per supported language code.
    embedding_files = {
        'en': 'fasttext_embeddings/cc.en.100.bin',
        'fr': 'fasttext_embeddings/cc.fr.100.bin'
    }
    model = SimpleMultilingualClassifier(embedding_files, num_classes)

    # English sample
    prediction_en = model.predict("This is a great movie.", 'en', class_labels)
    print(f"English Prediction: {prediction_en}")

    # French sample
    prediction_fr = model.predict("C'est un film incroyable.", 'fr', class_labels)
    print(f"French Prediction: {prediction_fr}")