From 4dc5f129704337a83bbd1ecd3bf2ad573d1e1e5f Mon Sep 17 00:00:00 2001 From: Mykhailo Istomin Date: Tue, 18 Apr 2023 19:07:36 +0300 Subject: [PATCH 1/2] FluSurv model and flusurv_update refactor --- src/acquisition/flusurv/flusurv_update.py | 194 ++++++++++------------ src/acquisition/flusurv/models.py | 29 ++++ 2 files changed, 121 insertions(+), 102 deletions(-) create mode 100644 src/acquisition/flusurv/models.py diff --git a/src/acquisition/flusurv/flusurv_update.py b/src/acquisition/flusurv/flusurv_update.py index 35fadba05..72d13b35d 100644 --- a/src/acquisition/flusurv/flusurv_update.py +++ b/src/acquisition/flusurv/flusurv_update.py @@ -68,122 +68,112 @@ + initial version """ -# standard library import argparse -# third party -import mysql.connector - -# first party from delphi.epidata.acquisition.flusurv import flusurv -import delphi.operations.secrets as secrets from delphi.utils.epidate import EpiDate from delphi.utils.epiweek import delta_epiweeks +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from ...server._config import ( + SQLALCHEMY_DATABASE_URI, + SQLALCHEMY_ENGINE_OPTIONS +) +from .models import FluSurv +engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS) -def get_rows(cur): - """Return the number of rows in the `flusurv` table.""" - # count all rows - cur.execute('SELECT count(1) `num` FROM `flusurv`') - for (num,) in cur: - return num +def get_rows(session): + """Return the number of rows in the `flusurv` table.""" + return session.query(FluSurv.id).count() def update(issue, location_name, test_mode=False): - """Fetch and store the currently avialble weekly FluSurv dataset.""" - - # fetch data - location_code = flusurv.location_codes[location_name] - print('fetching data for', location_name, location_code) - data = flusurv.get_data(location_code) - - # metadata - epiweeks = sorted(data.keys()) - location = location_name - release_date = str(EpiDate.today()) - - # connect to the database - u, p = secrets.db.epi - cnx = mysql.connector.connect( - host=secrets.db.host, user=u, password=p, database='epidata') - cur = cnx.cursor() - rows1 = get_rows(cur) - print('rows before: %d' % rows1) - - # SQL for insert/update - sql = ''' - INSERT INTO `flusurv` ( - `release_date`, `issue`, `epiweek`, `location`, `lag`, `rate_age_0`, - `rate_age_1`, `rate_age_2`, `rate_age_3`, `rate_age_4`, `rate_overall`, - `rate_age_5`, `rate_age_6`, `rate_age_7` - ) - VALUES ( - %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s - ) - ON DUPLICATE KEY UPDATE - `release_date` = least(`release_date`, %s), - `rate_age_0` = coalesce(%s, `rate_age_0`), - `rate_age_1` = coalesce(%s, `rate_age_1`), - `rate_age_2` = coalesce(%s, `rate_age_2`), - `rate_age_3` = coalesce(%s, `rate_age_3`), - `rate_age_4` = coalesce(%s, `rate_age_4`), - `rate_overall` = coalesce(%s, `rate_overall`), - `rate_age_5` = coalesce(%s, `rate_age_5`), - `rate_age_6` = coalesce(%s, `rate_age_6`), - `rate_age_7` = coalesce(%s, `rate_age_7`) - ''' - - # insert/update each row of data (one per epiweek) - for epiweek in epiweeks: - lag = delta_epiweeks(epiweek, issue) - if lag > 52: - # Ignore values older than one year, as (1) they are assumed not to - # change, and (2) it would adversely affect database performance if all - # values (including duplicates) were stored on each run. - continue - args_meta = [release_date, issue, epiweek, location, lag] - args_insert = data[epiweek] - args_update = [release_date] + data[epiweek] - cur.execute(sql, tuple(args_meta + args_insert + args_update)) - - # commit and disconnect - rows2 = get_rows(cur) - print('rows after: %d (+%d)' % (rows2, rows2 - rows1)) - cur.close() - if test_mode: - print('test mode: not committing database changes') - else: - cnx.commit() - cnx.close() + """Fetch and store the currently avialble weekly FluSurv dataset.""" + + # fetch data + location_code = flusurv.location_codes[location_name] + print('fetching data for', location_name, location_code) + data = flusurv.get_data(location_code) + + # metadata + epiweeks = sorted(data.keys()) + location = location_name + release_date = str(EpiDate.today()) + + with Session(engine) as session: + rows1 = get_rows(session) + print('rows before: %d' % rows1) + for epiweek in epiweeks: + lag = delta_epiweeks(epiweek, issue) + if lag > 52: + # Ignore values older than one year, as (1) they are assumed not to + # change, and (2) it would adversely affect database performance if all + # values (including duplicates) were stored on each run. + continue + args_meta = { + "issue": issue, + "epiweek": epiweek, + "location": location, + } + args_update = { + "release_date": release_date, + "lag": lag, + "rate_age_0": data[epiweek][0], + "rate_age_1": data[epiweek][1], + "rate_age_2": data[epiweek][2], + "rate_age_3": data[epiweek][3], + "rate_age_4": data[epiweek][4], + "rate_overall": data[epiweek][5], + "rate_age_5": data[epiweek][6], + "rate_age_6": data[epiweek][7], + "rate_age_7": data[epiweek][8], + } + existing_flusurv = session.query(FluSurv).filter_by(**args_meta) + if existing_flusurv.first() is not None: + existing_flusurv.update(args_update) + else: + args_create = args_meta + args_create.update(args_update) + session.add(FluSurv(**args_create)) + + rows2 = get_rows(session) + print('rows after: %d (+%d)' % (rows2, rows2 - rows1)) + if test_mode: + print('test mode: not committing database changes') + else: + session.commit() + session.close() def main(): - # args and usage - parser = argparse.ArgumentParser() - parser.add_argument( - 'location', - help='location for which data should be scraped (e.g. "CA" or "all")' - ) - parser.add_argument( - '--test', '-t', - default=False, action='store_true', help='do not commit database changes' - ) - args = parser.parse_args() - - # scrape current issue from the main page - issue = flusurv.get_current_issue() - print('current issue: %d' % issue) - - # fetch flusurv data - if args.location == 'all': - # all locations - for location in flusurv.location_codes.keys(): - update(issue, location, args.test) - else: - # single location - update(issue, args.location, args.test) + # args and usage + parser = argparse.ArgumentParser() + parser.add_argument( + 'location', + help='location for which data should be scraped (e.g. "CA" or "all")' + ) + parser.add_argument( + '--test', '-t', + default=False, action='store_true', help='do not commit database changes' + ) + args = parser.parse_args() + + # scrape current issue from the main page + issue = flusurv.get_current_issue() + print('current issue: %d' % issue) + + # fetch flusurv data + if args.location == 'all': + # all locations + for location in flusurv.location_codes.keys(): + update(issue, location, args.test) + else: + # single location + update(issue, args.location, args.test) if __name__ == '__main__': - main() + main() diff --git a/src/acquisition/flusurv/models.py b/src/acquisition/flusurv/models.py new file mode 100644 index 000000000..e4ef4d38e --- /dev/null +++ b/src/acquisition/flusurv/models.py @@ -0,0 +1,29 @@ +from sqlalchemy import Column +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.types import Date, Float, Integer, String + +Base = declarative_base() + + +class FluSurv(Base): + """ + SQLAlchemy model representing flusurve data. + """ + + __tablename__ = 'flusurv' + + id = Column(Integer, primary_key=True, autoincrement="auto", nullable=False) + release_date = Column(Date, nullable=False) + issue = Column(Integer, unique=True, nullable=False) + epiweek = Column(Integer, unique=True, nullable=False) + location = Column(String(length=32), unique=True, nullable=False) + lag = Column(Integer, nullable=False) + rate_age_0 = Column(Float, default=None) + rate_age_1 = Column(Float, default=None) + rate_age_2 = Column(Float, default=None) + rate_age_3 = Column(Float, default=None) + rate_age_4 = Column(Float, default=None) + rate_age_5 = Column(Float, default=None) + rate_age_6 = Column(Float, default=None) + rate_age_7 = Column(Float, default=None) + rate_overall = Column(Float, default=None) From 9563b64cf50fc072dccb6903da8d6864749a2436 Mon Sep 17 00:00:00 2001 From: Mykhailo Istomin Date: Wed, 26 Apr 2023 18:04:52 +0300 Subject: [PATCH 2/2] changes --- src/acquisition/flusurv/flusurv_update.py | 3 +-- src/acquisition/flusurv/models.py | 18 +++++++++++++----- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/acquisition/flusurv/flusurv_update.py b/src/acquisition/flusurv/flusurv_update.py index 72d13b35d..22ecb1355 100644 --- a/src/acquisition/flusurv/flusurv_update.py +++ b/src/acquisition/flusurv/flusurv_update.py @@ -135,8 +135,7 @@ def update(issue, location_name, test_mode=False): if existing_flusurv.first() is not None: existing_flusurv.update(args_update) else: - args_create = args_meta - args_create.update(args_update) + args_create = dict(**args_meta, **args_update) session.add(FluSurv(**args_create)) rows2 = get_rows(session) diff --git a/src/acquisition/flusurv/models.py b/src/acquisition/flusurv/models.py index e4ef4d38e..38915ceef 100644 --- a/src/acquisition/flusurv/models.py +++ b/src/acquisition/flusurv/models.py @@ -1,5 +1,5 @@ -from sqlalchemy import Column -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, Index, UniqueConstraint +from sqlalchemy.orm import declarative_base from sqlalchemy.types import Date, Float, Integer, String Base = declarative_base() @@ -11,12 +11,20 @@ class FluSurv(Base): """ __tablename__ = 'flusurv' + __table_args__ = ( + UniqueConstraint("issue", "epiweek", "location", name="issue"), + Index("release_date", "release_date"), + Index("issue_2", "issue"), + Index("epiweek", "epiweek"), + Index("region", "location"), + Index("lag", "lag"), + ) id = Column(Integer, primary_key=True, autoincrement="auto", nullable=False) release_date = Column(Date, nullable=False) - issue = Column(Integer, unique=True, nullable=False) - epiweek = Column(Integer, unique=True, nullable=False) - location = Column(String(length=32), unique=True, nullable=False) + issue = Column(Integer, nullable=False) + epiweek = Column(Integer, nullable=False) + location = Column(String(length=32), nullable=False) lag = Column(Integer, nullable=False) rate_age_0 = Column(Float, default=None) rate_age_1 = Column(Float, default=None)