commercialcompany/App/Router/ReportRouter.py

315 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
import io
import json
import logging
import os
import time
import uuid

import requests
from PIL import Image
from docx import Document
from docx.enum.table import WD_ALIGN_VERTICAL
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT
from docx.shared import Pt, Inches
from docxtpl import DocxTemplate
from fastapi import APIRouter, Depends, HTTPException
from starlette.background import BackgroundTask
from starlette.responses import FileResponse

from App.Schemas import ReportSchemas
from Utils.OpenaiUtils.api import get_openai_response
# Router that groups every report-generation endpoint under one URL prefix.
router = APIRouter(prefix="/api/report_generation")
def del_file(path):
    """Delete the file at *path*, doing nothing if it no longer exists.

    This runs as a background cleanup task after a response has been sent, so
    a file that has already been removed is not treated as an error.

    Args:
        path: Filesystem path of the file to delete.
    """
    try:
        os.remove(path)
    except FileNotFoundError:
        # Already cleaned up (e.g. by another task); nothing left to do.
        pass
# File locations, resolved against the working directory at import time:
# the Word template to render and the path used for the generated report.
_file_root = os.path.join(os.getcwd(), 'Utils', 'File')
file_template = os.path.join(_file_root, 'template', '信用报告模板v2.docx')
file_path = os.path.join(_file_root, 'generate', '信用报告.docx')
@router.get("/report_template", summary="查看报告模板", tags=["报告生成"])
def func():
    """Return the credit-report Word template as a downloadable file.

    The file served is ``file_template``; the download name offered to the
    client is the fixed string below.
    """
    docx_media_type = 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
    return FileResponse(file_template, filename='信用报告模板.docx', media_type=docx_media_type)
_logger = logging.getLogger(__name__)

# Seconds to wait for a trademark image before giving up on it.
_PICTURE_TIMEOUT = 10
# Width and height used for every embedded trademark image.
_PICTURE_SIZE = Inches(0.9)
# Font size applied to every table cell that is filled in.
_CELL_FONT_SIZE = Pt(9)
# Paragraph inserted under a section heading whose data is empty.
_EMPTY_NOTICE = "截止报告日,未查询到相关信息。"

# Text of a table's first header cell -> report field that feeds that table.
_LIST_TABLE_FIELDS = {
    '申请日期': '商标信息',
    '申请日': '专利信息',
    '批准日期': '软件著作权',
    '序号': '主要供应商情况',
    '发布时间': '招投标情况',
}

# Section heading text -> report field whose emptiness triggers the notice.
_SECTION_FIELDS = {
    "2.3 技术成果": '商标信息',
    "2.4 软件著作权": '软件著作权',
    "2.5 主要供应商情况": '主要供应商情况',
    "2.6 招投标情况": '招投标情况',
}

# Free-text fields rewritten through OpenAI, each with the placeholder values
# that mean "no data" (in which case only the stored prompt is sent).
_REFINED_FIELDS = (
    ('历史沿革', ('-',)),
    ('股权结构', ('-',)),
    ('高管信息', ('-', '--')),
)


def _join_phone_lines(raw_phones, first_line_count=5):
    """Split a comma-separated phone list over at most two lines.

    The first *first_line_count* numbers go on line one and the rest on line
    two. No trailing newline is produced when there is no second line, and a
    missing value yields an empty string.
    """
    if not raw_phones:
        return ''
    numbers = raw_phones.split(',')
    head = ','.join(numbers[:first_line_count])
    tail = ','.join(numbers[first_line_count:])
    return '\n'.join(part for part in (head, tail) if part)


def _wrap_business_scope(scope, first_width=33, width=38):
    """Hard-wrap the business scope: *first_width* chars, then *width* per line.

    No trailing newline is produced for short text, and a missing value yields
    an empty string.
    """
    if not scope:
        return ''
    lines = [scope[:first_width]]
    lines.extend(scope[start:start + width] for start in range(first_width, len(scope), width))
    return '\n'.join(lines)


def _refine_with_openai(report_content, prompts):
    """Rewrite the narrative fields of *report_content* in place via OpenAI.

    *prompts* is the mapping loaded from ``prompt.json``. Each refined field
    is replaced by the model's answer; afterwards the whole report is sent
    with the final prompt, and the answer is parsed as JSON to fill
    '管理制度' and '报告结论'.

    Raises:
        json.JSONDecodeError: if the final answer is non-empty but not JSON.
    """
    for field, empty_markers in _REFINED_FIELDS:
        prompt = prompts.get(field)
        value = report_content.get(field)
        # Only prepend real data; None or a placeholder sends the bare prompt.
        if value is not None and value not in empty_markers:
            prompt = value + prompt
        report_content[field] = get_openai_response(prompt)

    summary_prompt = str(report_content) + prompts.get('管理制度与报告结论')
    summary = get_openai_response(summary_prompt)
    if summary:
        parsed = json.loads(summary)
        report_content['管理制度'] = parsed.get('管理制度')
        report_content['报告结论'] = parsed.get('报告结论')


def _style_cell(cell):
    """Apply the report's cell style: 9pt runs, centred both ways."""
    for paragraph in cell.paragraphs:
        for run in paragraph.runs:
            run.font.size = _CELL_FONT_SIZE
    cell.vertical_alignment = WD_ALIGN_VERTICAL.CENTER
    cell.paragraphs[0].alignment = WD_PARAGRAPH_ALIGNMENT.CENTER


def _add_picture_from_url(cell, url):
    """Download the image at *url* and embed it in *cell*.

    The image is kept in memory, so nothing is left on disk if embedding
    fails. A non-200 response leaves the cell empty. If python-docx cannot
    embed the original bytes, the image is re-encoded as JPEG and retried.
    """
    # NOTE(review): the URL comes from the request payload; consider
    # restricting allowed hosts if this service can reach internal networks.
    response = requests.get(url, timeout=_PICTURE_TIMEOUT)
    if response.status_code != 200:
        return
    run = cell.paragraphs[0].add_run()
    try:
        run.add_picture(io.BytesIO(response.content), width=_PICTURE_SIZE, height=_PICTURE_SIZE)
    except Exception:
        # Fallback for formats python-docx does not recognise.
        with Image.open(io.BytesIO(response.content)) as img:
            jpeg_stream = io.BytesIO()
            # JPEG cannot store alpha/palette modes, so normalise to RGB first.
            img.convert('RGB').save(jpeg_stream, format='JPEG')
        jpeg_stream.seek(0)
        run.add_picture(jpeg_stream, width=_PICTURE_SIZE, height=_PICTURE_SIZE)


def _append_rows(table, rows, picture_column=None):
    """Append one styled table row per entry of *rows*.

    Values are written as text, except in *picture_column* (if given), where
    the value is treated as an image URL and the image is embedded instead.
    """
    for values in rows:
        table.add_row()
        row_index = len(table.rows) - 1
        for column, value in enumerate(values):
            cell = table.cell(row_index, column)
            if column == picture_column:
                try:
                    _add_picture_from_url(cell, value)
                except Exception:
                    # Best effort: a broken image must not fail the report.
                    _logger.warning("Could not embed picture %r", value, exc_info=True)
                    continue
            else:
                cell.text = str(value)
            _style_cell(cell)


def _fill_finance_table(table, finance_data):
    """Write the yearly finance figures into the existing cells of *table*.

    *finance_data* maps each year to a mapping of indicator -> value. Row 0
    receives the years and each later row one indicator's values; column 0 is
    left as it is in the template. Cells with no matching data are left
    untouched, and missing data leaves the whole table untouched.
    """
    if not finance_data:
        return
    years = list(finance_data)
    grid = [['指标名称'] + years]
    for indicator in finance_data[years[0]]:
        grid.append([indicator] + [finance_data[year][indicator] for year in years])

    column_count = len(table.columns)
    for row_index, values in enumerate(grid[:len(table.rows)]):
        for column in range(1, min(len(values), column_count)):
            cell = table.cell(row_index, column)
            cell.text = str(values[column])
            _style_cell(cell)


def _fill_tables(document, report_content):
    """Fill every recognised table and remove list tables that have no data.

    Tables are recognised by the text of their first header cell.
    """
    empty_tables = []
    for table in document.tables:
        header = table.rows[0].cells[0].text
        if header == '指标名称':
            _fill_finance_table(table, report_content.get('主要财务数据'))
            continue
        field = _LIST_TABLE_FIELDS.get(header)
        if field is None:
            continue
        records = report_content.get(field)
        if not records:
            empty_tables.append(table)
            continue
        rows = [list(record.values()) for record in records]
        if field == '主要供应商情况':
            # The supplier table starts with a running number column.
            rows = [[number] + row for number, row in enumerate(rows, start=1)]
        _append_rows(table, rows, picture_column=1 if field == '商标信息' else None)

    for table in empty_tables:
        element = table._element
        element.getparent().remove(element)


def _annotate_empty_sections(document, report_content):
    """Insert a "nothing found" notice under headings whose data is empty.

    Also removes the technical-achievements explanatory paragraph when there
    is no software-copyright data.
    """
    for paragraph in document.paragraphs:
        field = _SECTION_FIELDS.get(paragraph.text)
        if field is not None and not report_content.get(field):
            # add_paragraph appends at the end; move it right below the heading.
            notice = document.add_paragraph(_EMPTY_NOTICE)
            paragraph._element.addnext(notice._element)

    if not report_content.get('软件著作权'):
        for paragraph in document.paragraphs:
            if paragraph.text == '说明技术成果包括专利、商标、科技进步奖、工法、QC 小组活动成果、参与制定标准等。':
                element = paragraph._element
                element.getparent().remove(element)


@router.post("/generation", summary="生成word报告", tags=["报告生成"])
def func(schemas: ReportSchemas.ReportData):
    """Generate the credit report as a Word document and return it.

    Steps: normalise the phone and business-scope text, refine the narrative
    fields through OpenAI, render the template, fill or remove the data
    tables, add "nothing found" notices, and send the file. The file is
    deleted by a background task once the response has been sent.

    Each request writes to its own uniquely named file next to ``file_path``
    so that concurrent requests cannot overwrite or delete each other's
    output.

    Args:
        schemas: The submitted report data.

    Returns:
        A ``FileResponse`` streaming the generated ``.docx``.
    """
    # Load the template first so a missing template fails before any API call.
    template = DocxTemplate(file_template)
    report_content = schemas.dict()
    file_name = '{}信用报告.docx'.format(report_content.get('企业名称'))

    report_content['联系电话'] = _join_phone_lines(report_content.get('联系电话'))
    report_content['经营范围'] = _wrap_business_scope(report_content.get('经营范围'))

    prompt_file = os.path.join(os.getcwd(), 'Utils', 'OpenaiUtils', 'prompt.json')
    with open(prompt_file, 'r', encoding='gbk') as f:
        prompts = json.load(f)
    _refine_with_openai(report_content, prompts)

    output_path = os.path.join(os.path.dirname(file_path), '{}.docx'.format(uuid.uuid4().hex))
    try:
        # Render the placeholders, then reopen the result to edit the tables.
        template.render(report_content)
        template.save(output_path)
        document = Document(output_path)
        _fill_tables(document, report_content)
        _annotate_empty_sections(document, report_content)
        document.save(output_path)
    except Exception:
        # Do not leave a half-written report behind when generation fails.
        if os.path.exists(output_path):
            os.remove(output_path)
        raise

    cleanup = BackgroundTask(del_file, output_path)
    return FileResponse(output_path, filename=file_name, media_type='application/octet-stream',
                        background=cleanup)