クラスの並列化

multiprocessingモジュールで関数を並列化する方法はネットにあるが、クラスを並列化するというのは少なかったので、やってみた。結論としては、処理速度上げるのしんどいしよくわからんエラー出るしで結局関数を並列化する事で落ち着いた。

何がしたいのか?

  1. ファイルからデータを読み込む
  2. SQLに加工するクラスにデータを渡す
  3. 上記クラスでSQLを生成し、実行

並列化するのは↑の2、3の部分

環境

Ubuntu12.04
Python2.7
PostgreSQL9.1.6

実装

ここの、プロセスへメッセージを渡す のところを参考に実装してみた。

class Sample:
    def __init__(self, json):
        self.json = json
        self.con = DBに繋ぐ関数
        self.cur = self.con.cursor()
    def __call__(self):
        #実際に並列化して実行したい関数
        self.add_db()
    def creat_sql(self):
        hogehoge
    def add_db(self):
        #クラス変数やインスタンス変数の扱いが面倒なので、関数をネストさせてローカルスコープ内で処理
        def main():
            sql = self.creat_sql()
            self.cur.execute(sql)

class Worker(multiprocessing.Process):
    def __init__(self, task_queue, result_queue):
        multiprocessing.Process.__init__(self)
        self.task_queue = task_queue
        self.result_queue = result_queue
    def run(self)
        while True:
            next_task = self.task_queue.get()
            if next_task == None:
                self.task_queue.task.done()
                break
            answer = next_task()
            self.task_queue.task_done()
            self.result_queue.put(answer)

f = open('読み込むファイル', 'r')
jsons = simplejson.load(f)
tasks = multiprocessing.JoinableQueue()
results = multiprocessing.Queue()
workers = [Worker(tasks, results) for i in range(multiprocessing.cpu_count() * 10)]
for w in workers:
    w.start()
#各ワーカーにデータを挿入
for w, data in itertools.izip(itertools.cycle(workers), json):
    tasks.put(tweet.Tweet(data))
#各Worker に poison pill を挿入
for w in iter(json):
    tasks.put(None)
tasks.join()
jobs = len(json)
while jobs:
    res = results.get()
    print 'res:', res
    jobs -= 1

Workerクラス以下はほとんど変えてないな。
取り敢えずはこれで動いたんだが、ここでも言ってるようにあまり処理速度が向上しなかった(8coreの自分の環境の場合、CPU使用率が20%程度にしかならなかった。map関数を使ったときは90%程度まで向上した)。ただ単にプログラムの書き方が悪いんだろう。
Workerクラスを使わずにmultiprocessing.Poolのmap関数を使ったほうがずっと早かった(Workerクラス使用時:2H → map関数使用時:11m)。これだけ違うってことは、プログラムの書き方が悪いんだろう。たぶん大事(だと思うから)二度書いてみた。
map関数使用時のコード

def add(data):
    s = Sample(data)
    s.add()
p = multiprocessing.Pool()
p.map(add, jsons)

詰まったところ備忘録

1.Sampleクラスのメソッドでcursorオブジェクトは渡せない(↓のような使い方)ので、curはインスタンス変数として持っておく

#ダメな例
cur = con.cursor()
def hoge(self, a, b, cur):
    hogehoge
    cur.execute(sql)
self.hoge(a=aaa, b=bbb, cur=cur)
#良い例
self.cur = self.con.curosr()
def hoge(self, a, b):
    hogehoge
    self.cur.execute(sql)
self.hoge(a=aaa, b=bbb)

2.インスタンス変数
クラスメソッドAの中で定義したインスタンス変数AとクラスメソッドBの中で定義したインスタンス変数Aは別だと思っていたが、同じだった。

class Sample:
    def hoge(self):
        self.a = 'aaa'
    def fuga(self):
        self.a = 'bbb'
    self.hoge() # self.a = 'aaa'になる
    self.fuga() # self.a = 'bbb'に上書き

3.よくわからんエラー
"psql: FATAL: remaining connection slots are reserved for non-replication superuser connections”
解決できなかったorz