-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstock_data.py
More file actions
1373 lines (1157 loc) · 60.1 KB
/
stock_data.py
File metadata and controls
1373 lines (1157 loc) · 60.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Tushare数据缓存系统 - 股票数据模块
包含股票相关的数据获取和处理方法
"""
import os
import sqlite3
import threading
import time
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any, Callable
import pandas as pd
import numpy as np
import pickle
import gzip
import json
from concurrent.futures import ThreadPoolExecutor
from collections import OrderedDict
# 设置环境变量,确保Python使用UTF-8编码处理所有I/O
import os
os.environ['PYTHONIOENCODING'] = 'utf-8'
# 使用共享日志模块
try:
from .shared_logger import get_logger
except ImportError:
from shared_logger import get_logger
# 设置日志
logger = get_logger(__name__)
from collections import deque
import random
# API频率限制器类
class APIRateLimiter:
"""API调用频率限制器"""
def __init__(self, max_calls_per_minute=400, max_calls_per_hour=20000, max_calls_per_day=400000):
"""初始化频率限制器
Args:
max_calls_per_minute: 每分钟最大调用次数(留有一定余量)
max_calls_per_hour: 每小时最大调用次数
max_calls_per_day: 每天最大调用次数
"""
self.max_calls_per_minute = max_calls_per_minute
self.max_calls_per_hour = max_calls_per_hour
self.max_calls_per_day = max_calls_per_day
# 调用记录队列
self.minutes_calls = deque()
self.hours_calls = deque()
self.days_calls = deque()
# 锁
self.lock = threading.RLock()
def _clean_old_calls(self):
"""清理过期的调用记录"""
now = time.time()
# 清理超过1分钟的记录
while self.minutes_calls and now - self.minutes_calls[0] > 60:
self.minutes_calls.popleft()
# 清理超过1小时的记录
while self.hours_calls and now - self.hours_calls[0] > 3600:
self.hours_calls.popleft()
# 清理超过1天的记录
while self.days_calls and now - self.days_calls[0] > 86400:
self.days_calls.popleft()
def _check_limits(self):
"""检查是否超过调用限制"""
self._clean_old_calls()
# 检查每分钟限制
if self.max_calls_per_minute and len(self.minutes_calls) >= self.max_calls_per_minute:
return False, "每分钟调用次数超限"
# 检查每小时限制
if self.max_calls_per_hour and len(self.hours_calls) >= self.max_calls_per_hour:
return False, "每小时调用次数超限"
# 检查每天限制
if self.max_calls_per_day and len(self.days_calls) >= self.max_calls_per_day:
return False, "每天调用次数超限"
return True, ""
def wait_for_available(self):
"""等待直到API调用可用"""
while True:
with self.lock:
can_call, reason = self._check_limits()
if can_call:
return True
# 计算需要等待的时间
now = time.time()
wait_time = 0
# 检查下一个过期时间
if self.minutes_calls and self.max_calls_per_minute:
next_expiry = self.minutes_calls[0] + 60
if next_expiry > now:
wait_time = max(wait_time, next_expiry - now)
if self.hours_calls and self.max_calls_per_hour:
next_expiry = self.hours_calls[0] + 3600
if next_expiry > now:
wait_time = max(wait_time, next_expiry - now)
if self.days_calls and self.max_calls_per_day:
next_expiry = self.days_calls[0] + 86400
if next_expiry > now:
wait_time = max(wait_time, next_expiry - now)
# 添加一些随机延迟,避免同时发起请求
wait_time += random.uniform(0.1, 1.0)
wait_time = min(wait_time, 60) # 最多等待60秒
logger.warning(f"⏳ API调用限制触发,等待 {wait_time:.2f} 秒: {reason}")
# 在锁外等待,避免阻塞其他线程
time.sleep(wait_time)
def record_call(self):
"""记录API调用"""
with self.lock:
now = time.time()
self.minutes_calls.append(now)
self.hours_calls.append(now)
self.days_calls.append(now)
try:
from .core import TushareDataCacheCore
except ImportError:
from core import TushareDataCacheCore
class StockDataMixin:
"""股票数据获取和处理混入类"""
def daily(self, ts_code: str = '', start_date: str = '', end_date: str = '',
fields: str = '', trade_date: str = '', use_cache: bool = True, **kwargs) -> pd.DataFrame:
"""
获取股票日线数据(带智能缓存)
Args:
ts_code: 股票代码(格式:000001.SZ)
start_date: 开始日期(格式:20230101)
end_date: 结束日期(格式:20230131)
fields: 字段列表,逗号分隔
trade_date: 交易日期(格式:20230103)
use_cache: 是否使用缓存
**kwargs: 其他参数
Returns:
包含日线数据的DataFrame
"""
# 生成缓存键
cache_key = self._get_cache_key('daily', ts_code, start_date, end_date,
fields, trade_date=trade_date, **kwargs)
# 检查缓存
if use_cache:
cached_data = self._get_from_cache(cache_key)
if cached_data is not None:
logger.info(f"✅ 从缓存获取日线数据: {ts_code}")
return cached_data
# 从Tushare获取数据
if self.pro:
try:
logger.info(f"📊 从Tushare获取日线数据: {ts_code}, 开始日期: {start_date}, 结束日期: {end_date}")
# 构建参数
params = {
'ts_code': ts_code,
'start_date': start_date,
'end_date': end_date,
'fields': fields
}
# 如果提供了trade_date,则使用它而不是start_date和end_date
if trade_date:
params.pop('start_date', None)
params.pop('end_date', None)
params['trade_date'] = trade_date
# 添加额外参数
params.update(kwargs)
# 调用API
df = self.pro.daily(**params)
if df.empty:
logger.error(f"⚠️ Tushare API返回空数据: daily, {params}")
return pd.DataFrame()
# 保存到缓存
if use_cache:
self._save_to_cache(cache_key, df)
logger.debug(f"💾 日线数据已缓存: {cache_key}")
return df
except Exception as e:
logger.error(f"❌ 获取日线数据失败: {e}")
return pd.DataFrame()
else:
logger.warning("⚠️ 未初始化Tushare API客户端")
return pd.DataFrame()
def _get_daily_from_csv(self, symbol: str, start_date: str, end_date: str,
fields: List[str]) -> Optional[pd.DataFrame]:
"""从CSV文件获取日线数据"""
try:
filename = symbol.replace('.', '_') + '.csv'
filepath = self.data_root / 'daily' / filename
logger.debug(f"🔍 CSV检查: {filepath}")
logger.debug(f"📂 文件存在: {filepath.exists()}")
if not filepath.exists():
return None
# 使用pandas读取,支持分块读取大文件
df = pd.read_csv(filepath)
logger.info(f"📊 读取到 {len(df)} 条记录")
if 'trade_date' not in df.columns:
logger.error(f"❌ CSV文件中没有trade_date列")
return None
# 记录原始trade_date列的前几个值
logger.debug(f"🔍 原始trade_date列样本: {df['trade_date'].head().tolist()}")
# 确保trade_date列是字符串格式
df['trade_date'] = df['trade_date'].astype(str)
logger.debug(f"🔍 转换为字符串后的trade_date列样本: {df['trade_date'].head().tolist()}")
# 打印日期范围
min_date = df['trade_date'].min()
max_date = df['trade_date'].max()
logger.debug(f"📅 CSV数据日期范围: {min_date} - {max_date}")
logger.debug(f"🎯 请求日期范围: {start_date} - {end_date}")
# 检查日期比较
logger.debug(f"🔍 日期比较检查:")
logger.debug(f" min_date >= start_date: {min_date} >= {start_date} = {min_date >= start_date}")
logger.debug(f" max_date <= end_date: {max_date} <= {end_date} = {max_date <= end_date}")
# 过滤日期范围
mask = (df['trade_date'] >= start_date) & (df['trade_date'] <= end_date)
filtered_df = df[mask].copy()
logger.debug(f"🔍 过滤后记录数: {len(filtered_df)}")
if filtered_df.empty:
return None
# 选择指定字段
if fields:
available_fields = [f for f in fields if f in filtered_df.columns]
filtered_df = filtered_df[['trade_date'] + available_fields]
return filtered_df.sort_values('trade_date')
except Exception as e:
logger.error(f"❌ CSV读取失败 {symbol}: {e}")
import traceback
logger.debug(traceback.format_exc())
return None
def _get_daily_from_db(self, symbol: str, start_date: str, end_date: str,
fields: List[str]) -> Optional[pd.DataFrame]:
"""从SQLite数据库获取日线数据(带索引优化)"""
try:
db_path = self.data_root / 'quant_data.db'
if not db_path.exists():
return None
conn = sqlite3.connect(str(db_path))
# 创建索引(如果不存在)
index_sql = '''
CREATE INDEX IF NOT EXISTS idx_stock_daily_code_date
ON stock_daily (ts_code, trade_date)
'''
conn.execute(index_sql)
# 构建查询字段
if fields:
field_list = ', '.join([f'"{f}"' for f in fields if f != 'trade_date'])
else:
field_list = '*'
query = f'''
SELECT trade_date, {field_list}
FROM stock_daily
WHERE ts_code = ? AND trade_date BETWEEN ? AND ?
ORDER BY trade_date
'''
df = pd.read_sql_query(query, conn, params=[symbol, start_date, end_date])
conn.close()
return df if not df.empty else None
except Exception as e:
logger.error(f"❌ 数据库读取失败 {symbol}: {e}")
import traceback
logger.debug(traceback.format_exc())
return None
def _get_daily_from_tushare(self, symbol: str, start_date: str, end_date: str,
fields: List[str]) -> Optional[pd.DataFrame]:
"""
从Tushare API获取单个股票的日线数据
"""
# 验证并修复日期参数顺序
start_date, end_date = self._validate_date_parameters(start_date, end_date)
return self._batch_get_daily_from_tushare([symbol], start_date, end_date, fields).get(symbol)
def _batch_get_daily_from_tushare(self, symbols: List[str], start_date: str, end_date: str,
fields: List[str]) -> Dict[str, Optional[pd.DataFrame]]:
"""
从Tushare API批量获取多个股票的日线数据(真正的批量接口调用)
Args:
symbols: 股票代码列表
start_date: 开始日期
end_date: 结束日期
fields: 字段列表
Returns:
Dict[str, pd.DataFrame]: 股票代码到数据的映射
"""
try:
# 验证并修复日期参数顺序
start_date, end_date = self._validate_date_parameters(start_date, end_date)
if not self.pro:
logger.error(f"❌ Tushare API未初始化")
return {symbol: None for symbol in symbols}
if not symbols:
return {}
# Tushare API支持的最大批量大小
# 根据Tushare限制和数据量限制调整批量大小
# Tushare单次返回数据量限制为6000条,假设每只股票平均250个交易日/年*10年=2500条
# 为安全起见,我们将批量大小设置为较小值
MAX_BATCH_SIZE = 2 # 减少批量大小以避免超出6000条数据限制
results = {}
for i in range(0, len(symbols), MAX_BATCH_SIZE):
batch = symbols[i:i+MAX_BATCH_SIZE]
batch_str = ','.join(batch)
# 构建字段参数 - 使用传入的字段或默认字段
if fields:
# 确保包含必要的字段
required_fields = ['ts_code', 'trade_date']
actual_fields = required_fields[:]
for field in fields:
if field not in actual_fields:
actual_fields.append(field)
fields_param = ','.join(actual_fields)
else:
# 默认字段
fields_param = 'ts_code,trade_date,open,high,low,close,pre_close,change,pct_chg,vol,amount'
# 等待API可用
if hasattr(self, 'rate_limiter'):
self.rate_limiter.wait_for_available()
# 获取数据
logger.info(f"📊 调用Tushare API: daily(ts_code={batch_str}, start_date={start_date}, end_date={end_date})")
df = self.pro.daily(ts_code=batch_str, start_date=start_date, end_date=end_date, fields=fields_param)
# 记录API调用
if hasattr(self, 'rate_limiter'):
self.rate_limiter.record_call()
# 增强错误处理:检查返回数据的有效性
if df is None:
logger.error(f"❌ Tushare API返回None: {batch_str} {start_date}-{end_date}")
for symbol in batch:
results[symbol] = None
continue
if not isinstance(df, pd.DataFrame):
logger.error(f"❌ Tushare API返回非DataFrame类型 {type(df)}: {batch_str} {start_date}-{end_date}")
for symbol in batch:
results[symbol] = None
continue
if df.empty:
logger.warning(f"⚠️ Tushare API返回空数据: {batch_str} {start_date}-{end_date}")
logger.warning(f" 这可能是因为:")
logger.warning(f" 1. Token无效或权限不足")
logger.warning(f" 2. 股票代码不存在或已退市")
logger.warning(f" 3. 日期范围无交易数据")
logger.warning(f" 4. API调用频率超限")
for symbol in batch:
results[symbol] = None
continue
logger.info(f"✅ Tushare API批量获取到 {len(df)} 条数据,涉及 {len(batch)} 只股票")
# 调试信息改为debug级别
if not df.empty and 'ts_code' in df.columns:
unique_codes = df['ts_code'].unique()
logger.debug(f" 实际返回的股票代码: {unique_codes}")
for code in unique_codes:
code_data = df[df['ts_code'] == code]
logger.debug(f" {code}: {len(code_data)} 条数据")
# 按股票代码分组
for code, group in df.groupby('ts_code'):
# 选择指定字段
if fields:
# 确保trade_date在字段列表中且不重复
available_fields = [f for f in fields if f in group.columns and f != 'trade_date']
select_fields = ['trade_date'] + available_fields
group = group[select_fields]
results[code] = group.sort_values('trade_date')
# 检查是否有股票没有返回数据
for symbol in batch:
if symbol not in results:
logger.warning(f"⚠️ 股票 {symbol} 在API返回数据中未找到")
results[symbol] = None
return results
except Exception as e:
logger.error(f"❌ Tushare API批量获取失败: {e}")
import traceback
logger.debug(traceback.format_exc())
# 返回空数据而不是抛出异常,允许程序继续执行
return {symbol: None for symbol in symbols}
def get_stock_daily(self, symbol: str, start_date: str = None, end_date: str = None,
fields: List[str] = None, force_update: bool = False,
save_fields: List[str] = None, force_full_sync: bool = False,
trade_date: str = None) -> pd.DataFrame:
"""
获取股票日线数据(内部实现,支持增量更新和缓存部分命中处理)
"""
# 如果提供了trade_date参数,覆盖start_date和end_date
if trade_date:
start_date = trade_date
end_date = trade_date
# 验证并修复日期参数顺序
start_date, end_date = self._validate_date_parameters(start_date, end_date)
logger.info(f"🔍 开始获取 {symbol} 数据 ({start_date}-{end_date})")
cache_key = self._get_cache_key('daily', symbol, start_date, end_date, fields=fields)
if force_full_sync:
logger.info(f"🔄 启用全量同步: {symbol} {start_date}-{end_date}")
# 检查Tushare API是否初始化
if self.pro is None:
logger.warning(f"⚠️ Tushare API未初始化,无法执行全量同步,将从本地缓存获取数据: {symbol}")
# 尝试从本地缓存获取
cached_data = self._get_from_cache(cache_key)
if cached_data is not None and not cached_data.empty:
logger.info(f"✅ 从缓存获取到 {len(cached_data)} 条数据")
return cached_data
# 尝试从本地CSV获取
logger.warning(f"⚠️ 缓存未命中,尝试从本地CSV获取: {symbol}")
csv_data = self._get_daily_from_csv(symbol, start_date, end_date, fields)
if csv_data is not None and not csv_data.empty:
logger.info(f"✅ 从本地CSV获取到 {len(csv_data)} 条数据")
self._save_to_cache(cache_key, csv_data)
return csv_data
# 尝试从本地数据库获取
logger.warning(f"⚠️ CSV未命中,尝试从本地数据库获取: {symbol}")
db_data = self._get_daily_from_db(symbol, start_date, end_date, fields)
if db_data is not None and not db_data.empty:
logger.info(f"✅ 从本地数据库获取到 {len(db_data)} 条数据")
self._save_to_cache(cache_key, db_data)
return db_data
logger.error(f"❌ 所有本地数据源均未获取到数据: {symbol}")
else:
df_full = self._measure_performance(
'tushare', 'get_daily',
lambda: self._get_daily_from_tushare(symbol, start_date, end_date, fields)
)
if df_full is not None and not df_full.empty:
try:
# Ensure ts_code is included in save_fields
csv_save_fields = save_fields.copy() if save_fields else None
if csv_save_fields and 'ts_code' not in csv_save_fields:
csv_save_fields = ['ts_code'] + csv_save_fields
self._save_to_csv(symbol, df_full, csv_save_fields)
except Exception as e:
logger.warning(f"⚠️ 保存CSV失败 {symbol}: {e}")
try:
self._save_daily_to_db(symbol, df_full)
except Exception as e:
logger.warning(f"⚠️ 保存SQLite失败 {symbol}: {e}")
self._save_to_cache(cache_key, df_full)
return df_full
else:
logger.error(f"❌ 全量同步失败或无数据: {symbol}")
# 回退到原有获取策略
# 1. 优先检查持久化存储(CSV文件和数据库),因为这些数据不应过期
persistent_data = None
# 首先检查CSV文件
csv_data = self._get_daily_from_csv(symbol, start_date, end_date, fields)
if csv_data is not None and not csv_data.empty:
# 检查CSV数据是否完全覆盖请求范围
# 使用cache_mode=True,允许高覆盖率的缓存命中
fully_covered, missing_dates = self._check_data_coverage(csv_data, start_date, end_date, cache_mode=True)
if not missing_dates:
logger.info(f"✅ 从CSV文件获取 {symbol} 数据: {len(csv_data)} 条 (完全覆盖)")
# 保存到缓存供下次使用
self._save_to_cache(cache_key, csv_data)
return csv_data
else:
logger.debug(f"⚠️ CSV文件部分命中 {symbol}: {len(csv_data)} 条数据,缺失 {len(missing_dates)} 个交易日")
logger.debug(f" 缺失日期: {missing_dates[:5]}{'...' if len(missing_dates) > 5 else ''}")
persistent_data = csv_data
# 然后检查数据库
if persistent_data is None:
db_data = self._get_daily_from_db(symbol, start_date, end_date, fields)
if db_data is not None and not db_data.empty:
# 检查数据库数据是否完全覆盖请求范围
# 使用cache_mode=True,允许高覆盖率的缓存命中
fully_covered, missing_dates = self._check_data_coverage(db_data, start_date, end_date, cache_mode=True)
if not missing_dates:
logger.info(f"✅ 从数据库获取 {symbol} 数据: {len(db_data)} 条 (完全覆盖)")
# 保存到缓存和CSV文件供下次使用
self._save_to_cache(cache_key, db_data)
self._save_to_csv(symbol, db_data, save_fields)
return db_data
else:
logger.debug(f"⚠️ 数据库部分命中 {symbol}: {len(db_data)} 条数据,缺失 {len(missing_dates)} 个交易日")
logger.debug(f" 缺失日期: {missing_dates[:5]}{'...' if len(missing_dates) > 5 else ''}")
persistent_data = db_data
# 2. 检查内存缓存(除非强制更新)
cached_data = None
if not force_update:
cached_data = self._get_from_cache(cache_key)
if cached_data is not None:
# 检查缓存数据是否完全覆盖请求范围
# 使用cache_mode=True,允许高覆盖率的缓存命中
fully_covered, missing_dates = self._check_data_coverage(cached_data, start_date, end_date, cache_mode=True)
# 只有当数据未被标记为完全覆盖且有缺失日期时才继续获取
if not fully_covered and missing_dates:
logger.debug(f"⚠️ 缓存部分命中 {symbol}: {len(cached_data)} 条数据,缺失 {len(missing_dates)} 个交易日")
logger.debug(f" 缺失日期: {missing_dates[:5]}{'...' if len(missing_dates) > 5 else ''}")
# 继续获取缺失的数据
else:
logger.info(f"✅ 从缓存获取 {symbol} 数据: {len(cached_data)} 条 (完全覆盖)")
return cached_data
else:
logger.info(f"🔄 强制更新模式,跳过缓存检查")
# 3. 合并已有的持久化数据和缓存数据
result_data = pd.DataFrame()
if persistent_data is not None and not persistent_data.empty:
result_data = self._merge_dataframes(result_data, persistent_data)
if cached_data is not None and not cached_data.empty:
result_data = self._merge_dataframes(result_data, cached_data)
# 4. 智能数据获取策略
# 如果缓存数据为空或未完全覆盖,尝试从其他数据源获取
for source in self.data_sources:
logger.info(f"🔄 尝试从 {source} 获取 {symbol} 数据...")
if source == 'local_cache' and force_update:
continue # 强制更新时跳过缓存
try:
source_data = None
if source == 'local_csv':
source_data = self._measure_performance(
'local_csv', 'get_daily',
lambda: self._get_daily_from_csv(symbol, start_date, end_date, fields)
)
elif source == 'local_db':
source_data = self._measure_performance(
'local_db', 'get_daily',
lambda: self._get_daily_from_db(symbol, start_date, end_date, fields)
)
elif source == 'tushare_api' and self.pro:
# 增量更新:只获取缺失的数据
if not result_data.empty and not force_update:
# 检查当前结果数据覆盖的日期范围
_, missing_dates = self._check_data_coverage(result_data, start_date, end_date)
if missing_dates:
# 只获取缺失的日期范围
missing_start = missing_dates[0]
missing_end = missing_dates[-1]
logger.info(f"📊 只获取缺失日期范围: {missing_start} - {missing_end}")
source_data = self._measure_performance(
'tushare', 'get_daily',
lambda: self._get_daily_from_tushare(symbol, missing_start, missing_end, fields)
)
# 添加详细的调试信息
if source_data is None or source_data.empty:
# 移除过多的调试信息
# 尝试获取完整范围作为备选方案,但记录详细日志
logger.warning(f"⚠️ 增量获取失败,尝试获取完整日期范围: {start_date} - {end_date}")
source_data = self._measure_performance(
'tushare', 'get_daily',
lambda: self._get_daily_from_tushare(symbol, start_date, end_date, fields)
)
if source_data is None or source_data.empty:
logger.error(f"❌ 完整范围获取也失败,需要进一步调试")
else:
logger.info(f"📋 {symbol} 数据已完整,跳过网络请求")
continue
else:
# 获取全部数据
source_data = self._measure_performance(
'tushare', 'get_daily',
lambda: self._get_daily_from_tushare(symbol, start_date, end_date, fields)
)
# 保存到本地(同步保存CSV与SQLite)
if source_data is not None and not source_data.empty:
try:
# Ensure ts_code is included in save_fields
csv_save_fields = save_fields.copy() if save_fields else None
if csv_save_fields and 'ts_code' not in csv_save_fields:
csv_save_fields = ['ts_code'] + csv_save_fields
self._save_to_csv(symbol, source_data, csv_save_fields)
except Exception as e:
logger.warning(f"⚠️ 保存CSV失败 {symbol}: {e}")
try:
self._save_daily_to_db(symbol, source_data)
except Exception as e:
logger.warning(f"⚠️ 保存SQLite失败 {symbol}: {e}")
elif source_data is not None and source_data.empty:
logger.warning(f"⚠️ {symbol} 获取到空数据,跳过保存")
else:
logger.warning(f"⚠️ {symbol} 未获取到数据,跳过保存")
continue
elif source_data is not None and source_data.empty:
logger.warning(f"⚠️ {symbol} 获取到空数据,跳过保存")
else:
logger.warning(f"⚠️ {symbol} 未获取到数据,跳过保存")
continue
if source_data is not None and not source_data.empty:
logger.info(f"✅ 从 {source} 成功获取 {symbol} 数据: {len(source_data)} 条")
# 合并数据
result_data = self._merge_dataframes(result_data, source_data)
# 检查是否已完全覆盖请求范围
fully_covered, missing_dates = self._check_data_coverage(result_data, start_date, end_date, cache_mode=True)
# 修复:即使fully_covered为True(90%覆盖率),也检查missing_dates
if not missing_dates:
logger.info(f"✅ {symbol} 数据已完全覆盖请求范围: {len(result_data)} 条")
break
else:
if fully_covered:
logger.debug(f"⚠️ {symbol} 数据高覆盖率({len(result_data)}条),但仍有{len(missing_dates)}个交易日缺失")
logger.debug(f" 缺失日期: {missing_dates[:5]}{'...' if len(missing_dates) > 5 else ''}")
else:
logger.debug(f"⚠️ {symbol} 数据仍未完全覆盖,还缺失 {len(missing_dates)} 个交易日")
continue
else:
logger.error(f"❌ 从 {source} 未获取到 {symbol} 数据")
except Exception as e:
logger.warning(f"⚠️ 数据源 {source} 获取失败: {e}")
import traceback
logger.debug(traceback.format_exc())
continue
# 3. 保存结果到缓存(如果有数据)
if not result_data.empty:
# 保存到缓存
self._save_to_cache(cache_key, result_data)
logger.info(f"💾 已保存 {symbol} 数据到缓存: {len(result_data)} 条")
return result_data
else:
logger.error(f"❌ 无法获取股票 {symbol} 的数据")
return pd.DataFrame()
def batch_get_stock_daily(self, symbols: List[str], start_date: str = None,
end_date: str = None, fields: List[str] = None,
parallel: bool = True, save_fields: List[str] = None,
force_full_sync: bool = False, trade_date: str = None) -> Dict[str, pd.DataFrame]:
"""
批量获取多只股票的日线数据(支持批量API调用和并行处理)
Args:
symbols: 股票代码列表
start_date: 开始日期
end_date: 结束日期
fields: 字段列表
parallel: 是否使用并行处理
save_fields: 保存时使用的字段列表
force_full_sync: 是否强制全量同步
trade_date: 交易日期,格式YYYYMMDD,可选
Returns:
Dict[str, pd.DataFrame]: 股票代码到数据的映射
"""
# 如果提供了trade_date参数,覆盖start_date和end_date
if trade_date:
start_date = trade_date
end_date = trade_date
# 验证并修复日期参数顺序
start_date, end_date = self._validate_date_parameters(start_date, end_date)
# 日期范围日志改为debug级别
logger.debug(f"📅 请求日期范围: start_date={start_date}, end_date={end_date}")
results = {}
# 去重股票代码
unique_symbols = list(set(symbols))
# 获取需要从API获取的股票列表(缓存未命中或强制更新)
need_api_update = []
for symbol in unique_symbols:
cache_key = self._get_cache_key('daily', symbol, start_date, end_date, fields=fields)
cached_data = None
if not force_full_sync:
cached_data = self._get_from_cache(cache_key)
if cached_data is None or force_full_sync:
need_api_update.append(symbol)
else:
results[symbol] = cached_data
# 如果有需要从API获取的股票
if need_api_update:
logger.info(f"📊 需要从API获取 {len(need_api_update)} 只股票的数据")
# 1. 从Tushare API批量获取数据
api_results = self._batch_get_daily_from_tushare(need_api_update, start_date, end_date, fields)
# 2. 处理API返回的数据
for symbol, data in api_results.items():
if data is not None and not data.empty:
# 保存到本地
try:
# Ensure ts_code is included in save_fields
csv_save_fields = save_fields.copy() if save_fields else None
if csv_save_fields and 'ts_code' not in csv_save_fields:
csv_save_fields = ['ts_code'] + csv_save_fields
self._save_to_csv(symbol, data, csv_save_fields)
except Exception as e:
logger.warning(f"⚠️ 保存CSV失败 {symbol}: {e}")
try:
self._save_daily_to_db(symbol, data)
except Exception as e:
logger.warning(f"⚠️ 保存SQLite失败 {symbol}: {e}")
# 保存到缓存
cache_key = self._get_cache_key('daily', symbol, start_date, end_date, fields=fields)
self._save_to_cache(cache_key, data)
# 添加到结果
results[symbol] = data
# 日期范围日志改为debug级别
if 'trade_date' in data.columns:
logger.debug(f"📊 {symbol} 返回数据日期范围: {data['trade_date'].min()} 至 {data['trade_date'].max()}")
else:
logger.debug(f"⚠️ {symbol} 返回数据中没有trade_date列")
else:
logger.error(f"❌ {symbol} 从API获取到空数据或获取失败")
return results
def stock_basic(self, is_hs: str = None, exchange: str = '',
list_status: str = 'L', fields: str = None,
use_cache: bool = True) -> pd.DataFrame:
"""
获取股票基本信息(与tushare Pro API接口一致)
Args:
is_hs: 是否沪深港通标的,N否 H沪股通 S深股通,可选
exchange: 交易所代码,可选
list_status: 上市状态,L上市 D退市 P暂停上市,可选
fields: 字段列表,逗号分隔的字符串,可选
use_cache: 是否使用缓存
Returns:
DataFrame: 股票基本信息
"""
# 生成缓存键
cache_key = self._get_cache_key('stock_basic', '', is_hs=is_hs,
exchange=exchange, list_status=list_status,
fields=fields)
# 尝试从缓存获取
if use_cache:
cached_data = self._get_from_cache(cache_key)
if cached_data is not None:
return cached_data
# 从Tushare获取
if self.pro:
try:
# 调用tushare的stock_basic接口
basic_df = self.pro.stock_basic(
is_hs=is_hs,
exchange=exchange,
list_status=list_status,
fields=fields
)
# 保存到缓存
if use_cache:
self._save_to_cache(cache_key, basic_df)
return basic_df
except Exception as e:
print(f"❌ 获取股票基本信息失败: {e}")
return pd.DataFrame()
def daily_basic(self, ts_code: str = None, trade_date: str = None,
start_date: str = None, end_date: str = None,
fields: str = None) -> pd.DataFrame:
"""
获取每日基本面数据(与Tushare API接口一致,带智能缓存)
Args:
ts_code: 股票代码,格式XXXXXX.XSHE/XXXXXX.XSHG,可选
trade_date: 交易日期,格式YYYYMMDD,可选
start_date: 开始日期,格式YYYYMMDD,可选
end_date: 结束日期,格式YYYYMMDD,可选
fields: 字段列表,可选
Returns:
DataFrame: 每日基本面数据
"""
# 生成缓存键
cache_key = self._get_cache_key('daily_basic', 'daily_basic',
ts_code=ts_code,
trade_date=trade_date,
start_date=start_date,
end_date=end_date,
fields=fields)
# 尝试从缓存获取
cached_data = self._get_from_cache(cache_key)
if cached_data is not None:
return cached_data
# 默认字段
if not fields:
fields = ['ts_code', 'trade_date', 'close', 'turnover_rate', 'volume_ratio',
'pe', 'pe_ttm', 'pb', 'ps', 'ps_ttm', 'dv_ratio', 'dv_ttm',
'total_share', 'float_share', 'free_share', 'total_mv', 'circ_mv']
# 从Tushare获取
if self.pro:
try:
if trade_date:
# 获取特定日期的数据
df = self.pro.daily_basic(
trade_date=trade_date,
fields=','.join(fields)
)
elif start_date and end_date:
# 获取日期范围的数据
df = self.pro.daily_basic(
start_date=start_date,
end_date=end_date,
fields=','.join(fields)
)
else:
# 默认获取最近一个交易日的数据
df = self.pro.daily_basic(
fields=','.join(fields)
)
if df.empty:
print(f"❌ Tushare API返回空数据: daily_basic")
return pd.DataFrame()
# 按股票代码过滤
if ts_code:
df = df[df['ts_code'] == ts_code]
# 保存到缓存
self._save_to_cache(cache_key, df)
return df
except Exception as e:
print(f"❌ 获取每日基本面数据失败: {e}")
return pd.DataFrame()
def limit_list(self, trade_date: str = None, start_date: str = None,
end_date: str = None, limit_type: str = 'U',
fields: str = None, use_cache: bool = True) -> pd.DataFrame:
"""
获取涨跌停股票数据(带智能缓存)
Args:
trade_date: 交易日期,格式YYYYMMDD,可选
start_date: 开始日期,格式YYYYMMDD,可选
end_date: 结束日期,格式YYYYMMDD,可选
limit_type: 涨跌停类型,U表示涨停,D表示跌停,可选
fields: 字段列表,可选
use_cache: 是否使用缓存
Returns:
DataFrame: 涨跌停股票数据
"""
# 生成缓存键
cache_key = self._get_cache_key('limit_list', 'limit_list',
trade_date=trade_date,
start_date=start_date,
end_date=end_date,
limit_type=limit_type,
fields=fields)
# 尝试从缓存获取
if use_cache:
cached_data = self._get_from_cache(cache_key)
if cached_data is not None:
return cached_data
# 默认字段
if not fields:
fields = ['trade_date', 'ts_code', 'name', 'close', 'pct_chg',
'amount', 'turnover_rate', 'volume_ratio', 'limit_times']
# 从Tushare获取
if self.pro:
try:
if trade_date:
# 获取特定日期的数据
df = self.pro.limit_list(
trade_date=trade_date,
limit_type=limit_type,
fields=','.join(fields)
)
elif start_date and end_date:
# 获取日期范围的数据
df = self.pro.limit_list(
start_date=start_date,
end_date=end_date,
limit_type=limit_type,
fields=','.join(fields)
)
else:
# 默认获取最近一个交易日的数据
df = self.pro.limit_list(
limit_type=limit_type,
fields=','.join(fields)
)
if df.empty:
print(f"❌ Tushare API返回空数据: limit_list")
return pd.DataFrame()
# 保存到缓存
if use_cache:
self._save_to_cache(cache_key, df)
return df
except Exception as e:
print(f"❌ 获取涨跌停股票数据失败: {e}")
return pd.DataFrame()
def moneyflow(self, ts_code: str = None, trade_date: str = None,
start_date: str = None, end_date: str = None,