database.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # -*- coding: utf-8 -*-
  2. import sqlalchemy
  3. from sqlalchemy import create_engine
  4. from sqlalchemy.ext.declarative import declarative_base
  5. from sqlalchemy.orm import sessionmaker
  6. from config import settings
  7. from contextlib import contextmanager
  8. mysql_dwd_config = {
  9. 'drivername': 'mysql+pymysql',
  10. 'username': settings.MYSQL_USER,
  11. 'password': settings.MYSQL_PASSWORD,
  12. 'host': settings.MYSQL_SERVER,
  13. 'port':settings.MYSQL_PORT,
  14. 'database': settings.MYSQL_DB_NAME
  15. }
  16. if sqlalchemy.__version__ >= '1.4':
  17. mysql_engine_url = sqlalchemy.engine.URL.create(**mysql_dwd_config)
  18. mysql_engine_url = mysql_engine_url.update_query_dict({'charset': 'utf8mb4'})
  19. else:
  20. mysql_engine_url = '{drivername}://{username}:{password}@{host}:{port}/{database}?charset=utf8mb4'.format(**mysql_dwd_config)
  21. engine = create_engine(mysql_engine_url, echo=False, pool_size=100, pool_recycle=3600, pool_pre_ping=True)
  22. SessionLocal = sessionmaker(bind=engine)
  23. Base = declarative_base()
  24. # Dependency
  25. def get_db():
  26. try:
  27. db = SessionLocal()
  28. yield db
  29. finally:
  30. db.close()
  31. # 适用scheduler
  32. @contextmanager
  33. def get_local_db():
  34. try:
  35. db = SessionLocal()
  36. yield db
  37. finally:
  38. db.close()
  39. def get_db_local():
  40. return SessionLocal()
  41. # from database import engine
  42. # from models.geojson_base import *
  43. # from shapely.geometry import shape
  44. # import json
  45. # import pymysql
  46. # # db = get_db_local()
  47. # conn = pymysql.connect(host=settings.MYSQL_SERVER,
  48. # user=settings.MYSQL_USER,
  49. # password=settings.MYSQL_PASSWORD,
  50. # database=settings.MYSQL_DB_NAME,
  51. # port=settings.MYSQL_PORT,
  52. # charset='utf8mb4')
  53. # cur = conn.cursor()
  54. # with open('/home/python3/zj_geojson.json', 'r', encoding='utf-8') as file:
  55. # geojson = json.load(file)
  56. # features = geojson.get('features', [])
  57. # for feature in features:
  58. # # print(feature)
  59. # name = feature['properties'].get('NAME', '')
  60. # geom = shape(feature['geometry']).__geo_interface__ # 将Shapely对象转换为GeoJSON
  61. # # print(geom)
  62. # properties = json.dumps(feature['properties'], ensure_ascii=False)
  63. # pac = feature['properties'].get('PAC', '')
  64. # sql = """
  65. # INSERT INTO tp_geojson_data_zj (name, geometry, properties,pac)
  66. # VALUES (%s, ST_GeomFromGeoJSON(%s), %s,%s)
  67. # """
  68. # # 执行插入操作
  69. # cur.execute(sql, (name, json.dumps(geom), properties,pac))
  70. # conn.commit()
  71. # 提交事务
  72. # break
  73. # 提交事务
  74. # db.commit()
  75. # # 关闭会话
  76. # db.close()
  77. #from models.oneshare_base import Base
  78. #from models.knowledge_base import Base
  79. #
  80. # #使用Base的metadata和engine来创建所有表
  81. #Base.metadata.create_all(bind=engine)