| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- #
- # Copyright (c) 2016-present, Facebook, Inc.
- # All rights reserved.
- #
- # This source code is licensed under the MIT license found in the
- # LICENSE file in the root directory of this source tree.
- #
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- from __future__ import unicode_literals
- import numpy as np
- from scipy import stats
- import os
- import math
- import argparse
- def compat_splitting(line):
- return line.decode('utf8').split()
- def similarity(v1, v2):
- n1 = np.linalg.norm(v1)
- n2 = np.linalg.norm(v2)
- return np.dot(v1, v2) / n1 / n2
- parser = argparse.ArgumentParser(description='Process some integers.')
- parser.add_argument(
- '--model',
- '-m',
- dest='modelPath',
- action='store',
- required=True,
- help='path to model'
- )
- parser.add_argument(
- '--data',
- '-d',
- dest='dataPath',
- action='store',
- required=True,
- help='path to data'
- )
- args = parser.parse_args()
- vectors = {}
- fin = open(args.modelPath, 'rb')
- for _, line in enumerate(fin):
- try:
- tab = compat_splitting(line)
- vec = np.array(tab[1:], dtype=float)
- word = tab[0]
- if np.linalg.norm(vec) == 0:
- continue
- if not word in vectors:
- vectors[word] = vec
- except ValueError:
- continue
- except UnicodeDecodeError:
- continue
- fin.close()
- mysim = []
- gold = []
- drop = 0.0
- nwords = 0.0
- fin = open(args.dataPath, 'rb')
- for line in fin:
- tline = compat_splitting(line)
- word1 = tline[0].lower()
- word2 = tline[1].lower()
- nwords = nwords + 1.0
- if (word1 in vectors) and (word2 in vectors):
- v1 = vectors[word1]
- v2 = vectors[word2]
- d = similarity(v1, v2)
- mysim.append(d)
- gold.append(float(tline[2]))
- else:
- drop = drop + 1.0
- fin.close()
- corr = stats.spearmanr(mysim, gold)
- dataset = os.path.basename(args.dataPath)
- print(
- "{0:20s}: {1:2.0f} (OOV: {2:2.0f}%)"
- .format(dataset, corr[0] * 100, math.ceil(drop / nwords * 100.0))
- )
|