| | #include "GlobalLexicalModelUnlimited.h" |
| | #include <fstream> |
| | #include "moses/StaticData.h" |
| | #include "moses/InputFileStream.h" |
| | #include "moses/Hypothesis.h" |
| | #include "moses/TranslationTask.h" |
| | #include "util/string_piece_hash.hh" |
| | #include "util/string_stream.hh" |
| |
|
| | using namespace std; |
| |
|
| | namespace Moses |
| | { |
| | GlobalLexicalModelUnlimited::GlobalLexicalModelUnlimited(const std::string &line) |
| | :StatelessFeatureFunction(0, line) |
| | { |
| | UTIL_THROW(util::Exception, |
| | "GlobalLexicalModelUnlimited hasn't been refactored for new feature function framework yet"); |
| |
|
| | const vector<string> modelSpec = Tokenize(line); |
| |
|
| | for (size_t i = 0; i < modelSpec.size(); i++ ) { |
| | bool ignorePunctuation = true, biasFeature = false, restricted = false; |
| | size_t context = 0; |
| | string filenameSource, filenameTarget; |
| | vector< string > factors; |
| | vector< string > spec = Tokenize(modelSpec[i]," "); |
| |
|
| | |
| | if (spec.size() > 0) { |
| | if (spec.size() != 2 && spec.size() != 3 && spec.size() != 4 && spec.size() != 6) { |
| | std::cerr << "Format of glm feature is <factor-src>-<factor-tgt> [ignore-punct] [use-bias] " |
| | << "[context-type] [filename-src filename-tgt]"; |
| | |
| | } |
| |
|
| | factors = Tokenize(spec[0],"-"); |
| | if (spec.size() >= 2) |
| | ignorePunctuation = Scan<size_t>(spec[1]); |
| | if (spec.size() >= 3) |
| | biasFeature = Scan<size_t>(spec[2]); |
| | if (spec.size() >= 4) |
| | context = Scan<size_t>(spec[3]); |
| | if (spec.size() == 6) { |
| | filenameSource = spec[4]; |
| | filenameTarget = spec[5]; |
| | restricted = true; |
| | } |
| | } else |
| | factors = Tokenize(modelSpec[i],"-"); |
| |
|
| | if ( factors.size() != 2 ) { |
| | std::cerr << "Wrong factor definition for global lexical model unlimited: " << modelSpec[i]; |
| | |
| | } |
| |
|
| | const vector<FactorType> inputFactors = Tokenize<FactorType>(factors[0],","); |
| | const vector<FactorType> outputFactors = Tokenize<FactorType>(factors[1],","); |
| | throw runtime_error("GlobalLexicalModelUnlimited should be reimplemented as a stateful feature"); |
| | GlobalLexicalModelUnlimited* glmu = NULL; |
| |
|
| | if (restricted) { |
| | cerr << "loading word translation word lists from " << filenameSource << " and " << filenameTarget << endl; |
| | if (!glmu->Load(filenameSource, filenameTarget)) { |
| | std::cerr << "Unable to load word lists for word translation feature from files " |
| | << filenameSource |
| | << " and " |
| | << filenameTarget; |
| | |
| | } |
| | } |
| | } |
| | } |
| |
|
| | bool GlobalLexicalModelUnlimited::Load(const std::string &filePathSource, |
| | const std::string &filePathTarget) |
| | { |
| | |
| | ifstream inFileSource(filePathSource.c_str()); |
| | if (!inFileSource) { |
| | cerr << "could not open file " << filePathSource << endl; |
| | return false; |
| | } |
| |
|
| | std::string line; |
| | while (getline(inFileSource, line)) { |
| | m_vocabSource.insert(line); |
| | } |
| |
|
| | inFileSource.close(); |
| |
|
| | |
| | ifstream inFileTarget(filePathTarget.c_str()); |
| | if (!inFileTarget) { |
| | cerr << "could not open file " << filePathTarget << endl; |
| | return false; |
| | } |
| |
|
| | while (getline(inFileTarget, line)) { |
| | m_vocabTarget.insert(line); |
| | } |
| |
|
| | inFileTarget.close(); |
| |
|
| | m_unrestricted = false; |
| | return true; |
| | } |
| |
|
| | void GlobalLexicalModelUnlimited::InitializeForInput(ttasksptr const& ttask) |
| | { |
| | UTIL_THROW_IF2(ttask->GetSource()->GetType() != SentenceInput, |
| | "GlobalLexicalModel works only with sentence input."); |
| | Sentence const* s = reinterpret_cast<Sentence const*>(ttask->GetSource().get()); |
| | m_local.reset(new ThreadLocalStorage); |
| | m_local->input = s; |
| | } |
| |
|
| | void GlobalLexicalModelUnlimited::EvaluateWhenApplied(const Hypothesis& cur_hypo, ScoreComponentCollection* accumulator) const |
| | { |
| | const Sentence& input = *(m_local->input); |
| | const TargetPhrase& targetPhrase = cur_hypo.GetCurrTargetPhrase(); |
| |
|
| | for(size_t targetIndex = 0; targetIndex < targetPhrase.GetSize(); targetIndex++ ) { |
| | StringPiece targetString = targetPhrase.GetWord(targetIndex).GetString(0); |
| |
|
| | if (m_ignorePunctuation) { |
| | |
| | char firstChar = targetString[0]; |
| | CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar ); |
| | if(charIterator != m_punctuationHash.end()) |
| | continue; |
| | } |
| |
|
| | if (m_biasFeature) { |
| | util::StringStream feature; |
| | feature << "glm_"; |
| | feature << targetString; |
| | feature << "~"; |
| | feature << "**BIAS**"; |
| | accumulator->SparsePlusEquals(feature.str(), 1); |
| | } |
| |
|
| | boost::unordered_set<uint64_t> alreadyScored; |
| | for(size_t sourceIndex = 0; sourceIndex < input.GetSize(); sourceIndex++ ) { |
| | const StringPiece sourceString = input.GetWord(sourceIndex).GetString(0); |
| | |
| |
|
| | if (m_ignorePunctuation) { |
| | |
| | char firstChar = sourceString[0]; |
| | CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar ); |
| | if(charIterator != m_punctuationHash.end()) |
| | continue; |
| | } |
| | const uint64_t sourceHash = util::MurmurHashNative(sourceString.data(), sourceString.size()); |
| |
|
| | if ( alreadyScored.find(sourceHash) == alreadyScored.end()) { |
| | bool sourceExists, targetExists; |
| | if (!m_unrestricted) { |
| | sourceExists = FindStringPiece(m_vocabSource, sourceString ) != m_vocabSource.end(); |
| | targetExists = FindStringPiece(m_vocabTarget, targetString) != m_vocabTarget.end(); |
| | } |
| |
|
| | |
| | if (m_unrestricted || (sourceExists && targetExists)) { |
| | if (m_sourceContext) { |
| | if (sourceIndex == 0) { |
| | |
| | util::StringStream feature; |
| | feature << "glm_"; |
| | feature << targetString; |
| | feature << "~"; |
| | feature << "<s>,"; |
| | feature << sourceString; |
| | accumulator->SparsePlusEquals(feature.str(), 1); |
| | alreadyScored.insert(sourceHash); |
| | } |
| |
|
| | |
| | for(int contextIndex = sourceIndex+1; contextIndex < input.GetSize(); contextIndex++ ) { |
| | StringPiece contextString = input.GetWord(contextIndex).GetString(0); |
| | bool contextExists; |
| | if (!m_unrestricted) |
| | contextExists = FindStringPiece(m_vocabSource, contextString ) != m_vocabSource.end(); |
| |
|
| | if (m_unrestricted || contextExists) { |
| | util::StringStream feature; |
| | feature << "glm_"; |
| | feature << targetString; |
| | feature << "~"; |
| | feature << sourceString; |
| | feature << ","; |
| | feature << contextString; |
| | accumulator->SparsePlusEquals(feature.str(), 1); |
| | alreadyScored.insert(sourceHash); |
| | } |
| | } |
| | } else if (m_biphrase) { |
| | |
| | int globalTargetIndex = cur_hypo.GetSize() - targetPhrase.GetSize() + targetIndex; |
| |
|
| | |
| | StringPiece targetContext; |
| | if (globalTargetIndex > 0) |
| | targetContext = cur_hypo.GetWord(globalTargetIndex-1).GetString(0); |
| | else |
| | targetContext = "<s>"; |
| |
|
| | if (sourceIndex == 0) { |
| | StringPiece sourceTrigger = "<s>"; |
| | AddFeature(accumulator, sourceTrigger, sourceString, |
| | targetContext, targetString); |
| | } else |
| | for(int contextIndex = sourceIndex-1; contextIndex >= 0; contextIndex-- ) { |
| | StringPiece sourceTrigger = input.GetWord(contextIndex).GetString(0); |
| | bool sourceTriggerExists = false; |
| | if (!m_unrestricted) |
| | sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end(); |
| |
|
| | if (m_unrestricted || sourceTriggerExists) |
| | AddFeature(accumulator, sourceTrigger, sourceString, |
| | targetContext, targetString); |
| | } |
| |
|
| | |
| | StringPiece sourceContext; |
| | if (sourceIndex-1 >= 0) |
| | sourceContext = input.GetWord(sourceIndex-1).GetString(0); |
| | else |
| | sourceContext = "<s>"; |
| |
|
| | if (globalTargetIndex == 0) { |
| | string targetTrigger = "<s>"; |
| | AddFeature(accumulator, sourceContext, sourceString, |
| | targetTrigger, targetString); |
| | } else |
| | for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) { |
| | StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); |
| | bool targetTriggerExists = false; |
| | if (!m_unrestricted) |
| | targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end(); |
| |
|
| | if (m_unrestricted || targetTriggerExists) |
| | AddFeature(accumulator, sourceContext, sourceString, |
| | targetTrigger, targetString); |
| | } |
| | } else if (m_bitrigger) { |
| | |
| | int globalTargetIndex = cur_hypo.GetSize() - targetPhrase.GetSize() + targetIndex; |
| |
|
| | if (sourceIndex == 0) { |
| | StringPiece sourceTrigger = "<s>"; |
| | bool sourceTriggerExists = true; |
| |
|
| | if (globalTargetIndex == 0) { |
| | string targetTrigger = "<s>"; |
| | bool targetTriggerExists = true; |
| |
|
| | if (m_unrestricted || (sourceTriggerExists && targetTriggerExists)) |
| | AddFeature(accumulator, sourceTrigger, sourceString, |
| | targetTrigger, targetString); |
| | } else { |
| | |
| | for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) { |
| | StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); |
| | bool targetTriggerExists = false; |
| | if (!m_unrestricted) |
| | targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end(); |
| |
|
| | if (m_unrestricted || (sourceTriggerExists && targetTriggerExists)) |
| | AddFeature(accumulator, sourceTrigger, sourceString, |
| | targetTrigger, targetString); |
| | } |
| | } |
| | } |
| | |
| | else { |
| | |
| | for(int contextIndex = sourceIndex-1; contextIndex >= 0; contextIndex-- ) { |
| | StringPiece sourceTrigger = input.GetWord(contextIndex).GetString(0); |
| | bool sourceTriggerExists = false; |
| | if (!m_unrestricted) |
| | sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end(); |
| |
|
| | if (globalTargetIndex == 0) { |
| | string targetTrigger = "<s>"; |
| | bool targetTriggerExists = true; |
| |
|
| | if (m_unrestricted || (sourceTriggerExists && targetTriggerExists)) |
| | AddFeature(accumulator, sourceTrigger, sourceString, |
| | targetTrigger, targetString); |
| | } else { |
| | |
| | for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) { |
| | StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); |
| | bool targetTriggerExists = false; |
| | if (!m_unrestricted) |
| | targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end(); |
| |
|
| | if (m_unrestricted || (sourceTriggerExists && targetTriggerExists)) |
| | AddFeature(accumulator, sourceTrigger, sourceString, |
| | targetTrigger, targetString); |
| | } |
| | } |
| | } |
| | } |
| | } else { |
| | util::StringStream feature; |
| | feature << "glm_"; |
| | feature << targetString; |
| | feature << "~"; |
| | feature << sourceString; |
| | accumulator->SparsePlusEquals(feature.str(), 1); |
| | alreadyScored.insert(sourceHash); |
| |
|
| | } |
| | } |
| | } |
| | } |
| | } |
| | } |
| |
|
| | void GlobalLexicalModelUnlimited::AddFeature(ScoreComponentCollection* accumulator, |
| | StringPiece sourceTrigger, StringPiece sourceWord, |
| | StringPiece targetTrigger, StringPiece targetWord) const |
| | { |
| | util::StringStream feature; |
| | feature << "glm_"; |
| | feature << targetTrigger; |
| | feature << ","; |
| | feature << targetWord; |
| | feature << "~"; |
| | feature << sourceTrigger; |
| | feature << ","; |
| | feature << sourceWord; |
| | accumulator->SparsePlusEquals(feature.str(), 1); |
| |
|
| | } |
| |
|
| | } |
| |
|