This repository was archived by the owner on Nov 15, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprocess_dataset.py
More file actions
124 lines (91 loc) · 3.27 KB
/
process_dataset.py
File metadata and controls
124 lines (91 loc) · 3.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import argparse
import csv
from functions import *
from collections import defaultdict
from random import shuffle
import os
import json
# File names (created inside the model folder) for the three dataset splits
# written by _dump_params: training, validation, and test.
params_file1 = 'params_train.txt'
params_file2 = 'params_validate.txt'
params_file3 = 'params_test.txt'
# Command-line interface: --dataset is the input CSV path, --model-folder is
# the output directory (created by main() if it does not exist).
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True)
parser.add_argument('--model-folder', dest='model_folder', required=True)
def _dump_params(lines, model_folder):
    """Shuffle *lines* in place and split them into train/validate/test files.

    The split is 95% train, 5% validate, and whatever integer-division
    rounding leaves over goes to test.  Files are written into
    *model_folder* under the module-level names params_file1/2/3.

    Fix: the original loops used ``range(a, b - 1)`` bounds, which dropped
    the last line of every split (and, because the test loop stopped at
    ``len(lines) - 1``, the final shuffled line was never written anywhere).
    The slices below cover every line exactly once.
    """
    shuffle(lines)
    count_train = 95 * len(lines) // 100
    count_validate = 5 * len(lines) // 100
    print(count_train)
    print(count_validate)
    print(len(lines))
    # Train split: the first count_train shuffled lines.
    with open(os.path.join(model_folder, params_file1), 'w') as fp:
        fp.writelines(lines[:count_train])
    # Validation split: the next count_validate lines.
    with open(os.path.join(model_folder, params_file2), 'w') as fp:
        fp.writelines(lines[count_train:count_train + count_validate])
    # Test split: all remaining lines (integer-division leftovers).
    with open(os.path.join(model_folder, params_file3), 'w') as fp:
        fp.writelines(lines[count_train + count_validate:])
def main():
    """Build a word dictionary and per-image word-count parameters from a CSV.

    Pass 1: count stemmed words over all rows whose 'title' has more than
    9 words, then keep words occurring more than 100 times and save them
    (sorted) to <model_folder>/dict.txt.

    Pass 2: re-read the CSV and, for each qualifying row, emit one JSON
    line mapping the popular-word index (as a string key) to that word's
    count in the stemmed title, plus 'res' -> img_id.  Rows containing no
    popular word are skipped.  The lines are then split into
    train/validate/test files by _dump_params.
    """
    args = parser.parse_args()
    if not os.path.exists(args.model_folder):
        os.makedirs(args.model_folder)
    print('Generating dict')
    with open(args.dataset) as csvfile:
        reader = csv.DictReader(csvfile)
        word_counts = defaultdict(int)
        for i, row in enumerate(reader):
            # Only rows with sufficiently long titles contribute words.
            if len(row['title'].split()) > 9:
                for word in stemm(row['title']):
                    word_counts[word] += 1
            # Progress indicator every 100 rows.
            if i % 100 == 0:
                print(i)
    popular_words = _get_popular_words(word_counts, 100)
    _save_dict_to_file(os.path.join(args.model_folder, 'dict.txt'), popular_words)
    # Release the (potentially large) counter before the second pass.
    word_counts = None
    reader = None
    print('Generating parameters')
    with open(args.dataset) as csvfile:
        lines = []
        reader = csv.DictReader(csvfile)
        for i, row in enumerate(reader):
            if len(row['title'].split()) > 9:
                stemmed = stemm(row['title'])
                pre_json = {'res': row['img_id']}
                add_line = False
                for j, word in enumerate(popular_words):
                    # Hoisted: the original called stemmed.count(word)
                    # twice per word (once in the test, once for the value).
                    occurrences = stemmed.count(word)
                    if occurrences > 0:
                        add_line = True
                        pre_json[str(j)] = occurrences
                # Skip rows that contain no popular word at all.
                if add_line:
                    lines.append(json.dumps(pre_json) + '\n')
            if i % 100 == 0:
                print(i)
    _dump_params(lines, args.model_folder)
def _get_popular_words(wordDict, threshold=4):
print(len(wordDict))
dictList = []
for word in wordDict:
if len(word) > 1 and wordDict[word] > threshold:
dictList.append(word)
return sorted(dictList)
def _save_dict_to_file(file_name, word_dict):
with open(file_name, 'w') as f:
i = 1
for word in word_dict:
f.write(str(i) + '\t' + word + '\n')
i += 1
# Run the CLI entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()