# -*- coding: utf-8 -*-
# ohlc_chart_maker.py
# ==========================
# * Version : 0.4
# * Last update : 2019-03-18
#
import csv
import os
import sys
from datetime import datetime as DateTime

import matplotlib.dates
import mpl_finance
import numpy
from matplotlib import pyplot
from matplotlib.dates import WeekdayLocator
from matplotlib.ticker import EngFormatter
from matplotlib.ticker import Formatter
from matplotlib.ticker import MaxNLocator

# Window lengths (in rows) of the two simple moving averages.
SMA_SHORT_WINDOW = 5
SMA_LONG_WINDOW = 25

# Column index of the closing price inside a data row.
_CLOSE_COLUMN = 4


class MyDateFormatter(Formatter):
    """Tick formatter producing ``%b-%d`` labels.

    The label at position 0 is additionally prefixed with the year
    (``%Y-%b-%d``) so the axis shows which year the chart starts in.
    """

    def __call__(self, x, pos=None):
        fmt = "%b-%d"
        if pos == 0:
            fmt = "%Y-" + fmt
        return matplotlib.dates.num2date(x).strftime(fmt)


def make_mpl_date_from_string(s, fmt="%Y-%m-%d"):
    """Parse the date string *s* with *fmt* and return a matplotlib date number."""
    return matplotlib.dates.date2num(DateTime.strptime(s, fmt))


def _simple_moving_average(rows, last, window):
    """Return the mean close of the *window* rows ending at index *last*."""
    closes = [row[_CLOSE_COLUMN] for row in rows[last - window + 1:last + 1]]
    return sum(closes) / window


def read_csv_data(fname, start_date=None, end_date=None):
    """Read OHLCV rows from a CSV file and append two moving averages.

    Each CSV row must hold, in order: date (``%Y-%m-%d``), open, high, low,
    close, volume. Rows are assumed to be sorted by ascending date, because
    the moving averages are taken over the preceding rows.

    @param fname       path of the CSV file
    @param start_date  first date to keep (``%Y-%m-%d``), or None for no limit
    @param end_date    last date to keep (``%Y-%m-%d``), or None for no limit
    @return 2d array (Time, Open, High, Low, Close, Volume, SMA5, SMA25);
            both SMA columns are NaN until 25 rows of history exist.
            The array has shape (0, 8) when no row is selected.
    """
    rows = []
    with open(fname, "r", newline="") as reader:
        for record in csv.reader(reader):
            if not record:
                # csv.reader yields an empty list for blank lines.
                continue
            row = [make_mpl_date_from_string(record[0])]
            # float() accepts both integer and decimal prices; the returned
            # array is floating point either way.
            row.extend(float(record[column]) for column in range(1, 6))
            rows.append(row)

    start_mdate = None
    end_mdate = None
    if start_date is not None:
        start_mdate = make_mpl_date_from_string(start_date)
    if end_date is not None:
        end_mdate = make_mpl_date_from_string(end_date)

    # The long window is complete from index SMA_LONG_WINDOW - 1 onwards.
    first_full_index = SMA_LONG_WINDOW - 1
    for i, row in enumerate(rows):
        in_range = ((start_mdate is None or row[0] >= start_mdate)
                    and (end_mdate is None or row[0] <= end_mdate))
        if i < first_full_index or not in_range:
            row.extend([numpy.nan, numpy.nan])
        else:
            row.extend([
                _simple_moving_average(rows, i, SMA_SHORT_WINDOW),
                _simple_moving_average(rows, i, SMA_LONG_WINDOW),
            ])

    # Snap the requested range to the rows actually present, so that a
    # start/end date falling on a day without data still selects the window.
    start_index = None
    if start_mdate is not None:
        start_index = next(
            (i for i, row in enumerate(rows) if row[0] >= start_mdate),
            len(rows))
    end_index = None
    if end_mdate is not None:
        # Exclusive bound: the first row dated after the end date.
        end_index = next(
            (i for i, row in enumerate(rows) if row[0] > end_mdate),
            len(rows))

    selected = rows[start_index:end_index]
    return numpy.array(selected, dtype=float).reshape(-1, 8)


def _main(argc, argv):
    """Render an OHLC + volume chart from a CSV file and save it as a PNG.

    @param argc  number of command-line arguments (``len(argv)``)
    @param argv  argument list: script name, CSV file, optional start date,
                 optional end date
    """
    if argc < 2:
        print("Usage: python3", argv[0], "[historical data file] ([start date]) ([end date])")
        print("e.g.1: python3", argv[0], "test_data.csv")
        print("e.g.2: python3", argv[0], "test_data.csv 2019-02-25 2019-06-07")
        sys.exit(0)

    csv_file_name = argv[1]
    # Both dates are independently optional.
    start_date = argv[2] if argc > 2 else None
    end_date = argv[3] if argc > 3 else None

    # Time, Open, High, Low, Close, Volume, SMA5, SMA25
    data = read_csv_data(csv_file_name, start_date, end_date)
    if data.shape[0] == 0:
        print("No data rows to plot in", csv_file_name, file=sys.stderr)
        sys.exit(1)

    # pyplot settings
    pyplot.rcParams["font.family"] = "monospace"
    pyplot.rcParams["axes.xmargin"] = 0.01
    fig = pyplot.figure(figsize=(14, 4), dpi=92)
    ax1 = fig.add_subplot(1, 1, 1)
    ax2 = ax1.twinx()

    # Display ratio between the OHLC chart and the volume bars
    # (2 means they are shown 2:1).
    pv_yratio = 2

    # ax1 settings
    ax1.plot(data[:, 0], data[:, 6], color="mediumslateblue", linestyle="-", linewidth=0.75)
    ax1.plot(data[:, 0], data[:, 7], color="#a79bee", linestyle="-.", linewidth=0.75)
    mpl_finance.plot_day_summary_ohlc(ax1, data, ticksize=3.25,
                                      colorup="hotpink", colordown="cornflowerblue")
    ax1.tick_params("x", labelsize="small")
    ax1.xaxis.set_major_locator(WeekdayLocator(byweekday=matplotlib.dates.MO))
    ax1.xaxis.set_major_formatter(MyDateFormatter())
    # Extend the price axis downwards to make room for the volume bars, but
    # keep tick labels only for the part that shows prices.
    (bottom, top) = ax1.get_ylim()
    new_bottom = bottom - (top - bottom) / pv_yratio
    ax1.set_ylim(bottom=new_bottom)
    ax1.yaxis.set_major_locator(MaxNLocator(13))
    ticks = ax1.get_yticks()
    indices = numpy.where(ticks >= bottom)
    ax1.set_yticks(ticks[indices])
    tick_step = ticks[1] - ticks[0]
    ax1.set_ylim(new_bottom - tick_step, top + tick_step)
    ax1.set_axisbelow(True)
    ax1.grid(color="silver", linestyle=":")

    # ax2 settings
    ax2.bar(data[:, 0], data[:, 5], width=0, edgecolor="forestgreen")
    # Extend the volume axis upwards so the bars fill only the lower part,
    # and keep tick labels only for that part.
    (bottom, top) = ax2.get_ylim()
    new_top = top * (pv_yratio + 1)
    ax2.set_ylim(top=new_top)
    ax2.yaxis.set_major_locator(MaxNLocator(13))
    ticks = ax2.get_yticks()
    indices = numpy.where(ticks <= top)
    ax2.set_yticks(ticks[indices])
    ax2.yaxis.set_major_formatter(EngFormatter())
    ax2.set_axisbelow(True)
    ax2.grid(axis="y", color="orange", linestyle="--")

    fig.tight_layout()

    # Name the image after the CSV file (directory and extension stripped)
    # plus the current minute and second; it is written to the working
    # directory.
    stem = os.path.splitext(os.path.basename(csv_file_name))[0]
    suffix = DateTime.now().strftime("%M%S")
    png_file_name = "fig_{0}_{1}.png".format(stem, suffix)
    pyplot.savefig(png_file_name)


if __name__ == "__main__":
    _main(len(sys.argv), sys.argv)