Skip to content

Commit

Permalink
revert
Browse files Browse the repository at this point in the history
  • Loading branch information
slippersheepig authored Jun 13, 2024
1 parent 82f0c1b commit b57eb18
Showing 1 changed file with 19 additions and 46 deletions.
65 changes: 19 additions & 46 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,47 +3,31 @@
import numpy as np
import cv2
import io
import logging
from decouple import config
from queue import Queue
from threading import Lock
from time import sleep

# 设置日志
logging.basicConfig(level=logging.INFO)

# 常量
BOT_TOKEN = config('BOT_TOKEN')
HUGGINGFACE_TOKEN = config('HUGGINGFACE_TOKEN')
API_URL = config('API_URL')

if not BOT_TOKEN or not HUGGINGFACE_TOKEN or not API_URL:
raise ValueError("请确保BOT_TOKEN, HUGGINGFACE_TOKEN和API_URL在环境变量中正确设置")

# 创建机器人
bot = telebot.TeleBot(BOT_TOKEN)
bot.set_webhook()

# 队列和锁
# 队列
queue = Queue()
queue_lock = Lock()

# 请求到stablediffusion
headers = {"Authorization": f"Bearer {HUGGINGFACE_TOKEN}"}

def stablediffusion(payload, retries=3):
for attempt in range(retries):
try:
response = requests.post(API_URL, headers=headers, json=payload)
response.raise_for_status()
return response.content
except requests.exceptions.RequestException as e:
if attempt < retries - 1:
logging.warning(f"stablediffusion请求错误: {str(e)},重试中... ({attempt+1}/{retries})")
sleep(2) # 等待一段时间后重试
else:
logging.error(f"stablediffusion请求错误: {str(e)},重试次数用尽")
raise Exception(f"stablediffusion请求错误: {str(e)}")
def stablediffusion(payload):
try:
response = requests.post(API_URL, headers=headers, json=payload)
response.raise_for_status()
return response.content
except requests.exceptions.RequestException as e:
raise Exception(f"stablediffusion请求错误: {str(e)}")

def generate_image(message, user, prompt):
try:
Expand All @@ -52,20 +36,14 @@ def generate_image(message, user, prompt):
image_bytes = stablediffusion({'inputs': prompt})
img_bytes = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
success, png_image = cv2.imencode('.png', img_bytes)
if not success:
raise ValueError("图像编码失败")
photo = io.BytesIO(png_image)
photo.seek(0)
bot.send_message(message.chat.id, text=f"请求: {prompt}\nstablediffusion:")
bot.send_photo(message.chat.id, photo)
except Exception as e:
bot.reply_to(message, f"生成图片错误: {str(e)}")
logging.error(f"生成图片错误: {str(e)}")
finally:
with queue_lock:
if user in queue.queue:
queue.get() # 从队列中移除用户
photo.close()
queue.get() # 从队列中移除用户

@bot.message_handler(commands=['start'])
def send_welcome(message):
Expand All @@ -78,24 +56,19 @@ def send_help(message):
@bot.message_handler(commands=['sd'])
def stablediffusion_command(message):
user = message.from_user.username
if not user:
bot.reply_to(message, "无法获取用户名,请确保您的帐户有用户名")
return

with queue_lock:
if user not in queue.queue:
prompt = message.text.replace("/sd", "").strip()
if prompt:
queue.put(user)
bot.reply_to(message, f'请稍等... \n您在队列中的位置: {queue.qsize()}')
generate_image(message, user, prompt)
else:
bot.reply_to(message, '请求不能为空.')
if user not in queue.queue:
prompt = message.text.replace("/sd", "").strip()
if prompt:
queue.put(user)
bot.reply_to(message, f'请稍等... \n您在队列中的位置: {queue.qsize()}')
generate_image(message, user, prompt)
else:
bot.reply_to(message, "您已经在生成此模型的查询,请等待完成.")
bot.reply_to(message, '请求不能为空.')
else:
bot.reply_to(message, "您已经在生成此模型的查询,请等待完成.")

def main():
bot.polling(none_stop=True)
bot.polling()

if __name__ == "__main__":
main()

0 comments on commit b57eb18

Please sign in to comment.