From 7097e2a09e78d86beed8a92bd446c10ad4a4c11b Mon Sep 17 00:00:00 2001
From: Penny <2500338766@qq.com>
Date: Thu, 4 May 2023 23:33:09 +0800
Subject: [PATCH] =?UTF-8?q?feature-1.0-tx=E6=96=87=E7=94=9F=E5=9B=BE?=
=?UTF-8?q?=EF=BC=9A=E6=B7=BB=E5=8A=A0=E8=85=BE=E8=AE=AF=E6=96=87=E7=94=9F?=
=?UTF-8?q?=E5=9B=BE=E6=8E=A5=E5=8F=A3?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
bnyer-common/bnyer-common-core/pom.xml | 6 +
.../core/constant/RedisKeyConstant.java | 5 +
.../img/controller/TiktokMiniController.java | 7 +
.../bnyer/img/enums/AiPaintButtonEnum.java | 20 ++
.../img/service/StableDiffusionService.java | 6 +
.../impl/StableDiffusionServiceImpl.java | 184 ++++++++++++++----
6 files changed, 186 insertions(+), 42 deletions(-)
create mode 100644 bnyer-services/bnyer-img/src/main/java/com/bnyer/img/enums/AiPaintButtonEnum.java
diff --git a/bnyer-common/bnyer-common-core/pom.xml b/bnyer-common/bnyer-common-core/pom.xml
index 0c67d44..dfa52ae 100644
--- a/bnyer-common/bnyer-common-core/pom.xml
+++ b/bnyer-common/bnyer-common-core/pom.xml
@@ -146,6 +146,12 @@
tencentcloud-sdk-java-tmt
3.1.715
+
+
+ com.tencentcloudapi
+ tencentcloud-sdk-java-aiart
+ 3.1.715
+
diff --git a/bnyer-common/bnyer-common-core/src/main/java/com/bnyer/common/core/constant/RedisKeyConstant.java b/bnyer-common/bnyer-common-core/src/main/java/com/bnyer/common/core/constant/RedisKeyConstant.java
index 4e1a47b..9d20096 100644
--- a/bnyer-common/bnyer-common-core/src/main/java/com/bnyer/common/core/constant/RedisKeyConstant.java
+++ b/bnyer-common/bnyer-common-core/src/main/java/com/bnyer/common/core/constant/RedisKeyConstant.java
@@ -102,6 +102,11 @@ public class RedisKeyConstant {
* 平台用户ai绘画键
*/
public static final String PLATFORM_USER_AI_PAINT_KEY = "bnyer.img.user.aiPaint:";
+
+ /**
+ * ai绘画采用sd或tx文生图开关
+ */
+ public static final String AI_PAINT_BUTTON = "bnyer.img.paint.button";
/**
* 艺术家上传键
*/
diff --git a/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/controller/TiktokMiniController.java b/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/controller/TiktokMiniController.java
index f98e478..d60772a 100644
--- a/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/controller/TiktokMiniController.java
+++ b/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/controller/TiktokMiniController.java
@@ -459,4 +459,11 @@ public class TiktokMiniController extends BaseController {
paintCdkService.useCdk(dto.getCdk(),dto.getSource(),dto.getUserId(),dto.getAppType());
return AjaxResult.success();
}
+
+ @ApiOperation(value="设置Ai绘画开关")
+ @GetMapping(value = "/setAiPaint/{buttonValue}")
+ public AjaxResult setAiPaint(@PathVariable int buttonValue){
+ stableDiffusionService.setPaintButton(buttonValue);
+ return AjaxResult.success();
+ }
}
diff --git a/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/enums/AiPaintButtonEnum.java b/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/enums/AiPaintButtonEnum.java
new file mode 100644
index 0000000..dbc0764
--- /dev/null
+++ b/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/enums/AiPaintButtonEnum.java
@@ -0,0 +1,20 @@
+package com.bnyer.img.enums;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+
+/**
+ * ai绘画采用sd或tx文生图枚举类
+ * @author chengkun
+ * @date 2023/4/19 17:46
+ */
+@Getter
+@AllArgsConstructor
+public enum AiPaintButtonEnum {
+ SD(1,"stable diffusion"),
+ TX(0,"tx文生图");
+
+ private int code;
+
+ private String msg;
+}
diff --git a/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/service/StableDiffusionService.java b/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/service/StableDiffusionService.java
index 5b9cec4..d8d9025 100644
--- a/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/service/StableDiffusionService.java
+++ b/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/service/StableDiffusionService.java
@@ -23,4 +23,10 @@ public interface StableDiffusionService {
* @param paintNUm 绘画次数
*/
void addPlatformUserAiPaintNum(String appType,String platform,Long userId,int paintNUm);
+
+ /**
+ * 设置AI绘画button值
+ * @param buttonValue
+ */
+ void setPaintButton(int buttonValue);
}
diff --git a/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/service/impl/StableDiffusionServiceImpl.java b/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/service/impl/StableDiffusionServiceImpl.java
index fc2fbb3..6eca136 100644
--- a/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/service/impl/StableDiffusionServiceImpl.java
+++ b/bnyer-services/bnyer-img/src/main/java/com/bnyer/img/service/impl/StableDiffusionServiceImpl.java
@@ -6,6 +6,7 @@ import com.alibaba.fastjson.JSONObject;
import com.bnyer.common.core.constant.RedisKeyConstant;
import com.bnyer.common.core.domain.AiPaint;
import com.bnyer.common.core.dto.TextToImgDto;
+import com.bnyer.common.core.exception.ServiceException;
import com.bnyer.common.core.utils.TranslateUtils;
import com.bnyer.common.core.utils.file.Base64ToMultipartFileUtils;
import com.bnyer.common.core.vo.TextToImgVo;
@@ -13,8 +14,13 @@ import com.bnyer.common.redis.service.RedisService;
import com.bnyer.file.api.RemoteFileService;
import com.bnyer.img.config.StableDiffusionConfig;
import com.bnyer.img.config.TencentTranslateConfig;
+import com.bnyer.img.enums.AiPaintButtonEnum;
import com.bnyer.img.service.AiPaintService;
import com.bnyer.img.service.StableDiffusionService;
+import com.tencentcloudapi.aiart.v20221229.AiartClient;
+import com.tencentcloudapi.aiart.v20221229.models.ResultConfig;
+import com.tencentcloudapi.aiart.v20221229.models.TextToImageRequest;
+import com.tencentcloudapi.aiart.v20221229.models.TextToImageResponse;
import com.tencentcloudapi.common.Credential;
import com.tencentcloudapi.common.exception.TencentCloudSDKException;
import com.tencentcloudapi.common.profile.ClientProfile;
@@ -29,10 +35,7 @@ import org.springframework.web.client.RestTemplate;
import org.springframework.web.multipart.MultipartFile;
import java.text.SimpleDateFormat;
-import java.util.Date;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
@Service
@Slf4j
@@ -85,43 +88,63 @@ public class StableDiffusionServiceImpl implements StableDiffusionService {
@Override
public TextToImgVo textToImg(TextToImgDto param) {
- try{
- String prompt = "";
- //判断prompt是否包含中文,中文则翻译,否则跳过
- if(TranslateUtils.isContainChinese(param.getPrompt())){
- //调用翻译api
- prompt = translate(param.getPrompt());
- }else{
- prompt = param.getPrompt();
- }
- System.out.println(prompt);
-
- //TODO 根据选择的风格来选择模型
- Map map = new HashMap<>();
- map.put("width",param.getWidth() == null ? 512 : param.getWidth());
- map.put("height",param.getHeight() == null ? 512 : param.getHeight());
- map.put("prompt", prompt);
- //map.put("prompt", param.getPrompt());
- map.put("seed",-1);
- map.put("batch_size",1);
- map.put("cfg_scale",7);
- map.put("restore_faces",false);
- map.put("tiling",false);
- map.put("eta",0);
- map.put("sampler_index","DPM++ 2S a Karras");
- //map.put("sampler_index",param.getSamplerIndex());
- map.put("steps",25);
- map.put("negative_prompt","easynegative");
- //log.info("请求stable_diffusion请求体为:【{}】", JSON.toJSONString(map));
- JSONObject jsonObject = restTemplate.postForObject(stableDiffusionConfig.getTxt2ImgUrl(), map, JSONObject.class);
- //log.info("请求stable_diffusion响应体的为:【{}】", JSON.toJSONString(jsonObject));
- TextToImgVo img = new TextToImgVo();
- if(jsonObject != null && jsonObject.getJSONArray("images").size() > 0){
- List images = jsonObject.getJSONArray("images").toJavaList(String.class);
- img.setImages(images);
+
+ //配置控制采用sd还是tx文生图
+ int button = redisService.getCacheObject(RedisKeyConstant.AI_PAINT_BUTTON);
+ if(button == AiPaintButtonEnum.TX.getCode()){
+ //采用腾讯文生图
+ try{
+ Credential cred = new Credential(tencentTranslateConfig.getSecretId(), tencentTranslateConfig.getSecretKey());
+ // 实例化一个http选项,可选的,没有特殊需求可以跳过
+ HttpProfile httpProfile = new HttpProfile();
+ httpProfile.setEndpoint("aiart.tencentcloudapi.com");
+ // 实例化一个client选项,可选的,没有特殊需求可以跳过
+ ClientProfile clientProfile = new ClientProfile();
+ clientProfile.setHttpProfile(httpProfile);
+ // 实例化要请求产品的client对象,clientProfile是可选的
+ AiartClient client = new AiartClient(cred, "ap-guangzhou", clientProfile);
+ // 实例化一个请求对象,每个接口都会对应一个request对象
+ TextToImageRequest req = new TextToImageRequest();
+ req.setPrompt(param.getPrompt());
+ //请求风格
+ String[] styles1 = new String[0];
+ switch (param.getStyleName()) {
+ case "细腻":
+ styles1 = new String[]{"110"};
+ break;
+ case "卡通":
+ styles1 = new String[]{"201"};
+ break;
+ case "科幻":
+ styles1 = new String[]{"114"};
+ break;
+ case "中国风":
+ styles1 = new String[]{"101"};
+ break;
+ }
+ req.setStyles(styles1);
+
+ //画布大小
+ ResultConfig resultConfig1 = new ResultConfig();
+ if(param.getWidth() == 512 && param.getHeight() == 512){
+ resultConfig1.setResolution("768:768");
+ }else if(param.getWidth() == 512 && param.getHeight() == 1024){
+ resultConfig1.setResolution("768:1024");
+ }else{
+ resultConfig1.setResolution("1024:768");
+ }
+ req.setResultConfig(resultConfig1);
+
+ // 返回的resp是一个TextToImageResponse的实例,与请求对象对应
+ TextToImageResponse resp = client.TextToImage(req);
+ TextToImgVo img = new TextToImgVo();
+ String images = resp.getResultImage();
+ List list = new ArrayList<>();
+ list.add(images);
+ img.setImages(list);
String paintId = null;
Date paintTime = null;
- for (String image : images) {
+ for (String image : list) {
//base64转file
MultipartFile file = new Base64ToMultipartFileUtils(image, "data:image/png;base64", "file", "tempSDImg");
//上传图片到七牛云/minio
@@ -149,10 +172,82 @@ public class StableDiffusionServiceImpl implements StableDiffusionService {
}
img.setPaintId(paintId);
img.setPaintTime(paintTime);
+ return img;
+ } catch (TencentCloudSDKException e) {
+ log.error("腾讯文生图调用错误!"+e.getMessage());
+ throw new ServiceException(e.getMessage(),500);
+ }
+ }else{
+ //采用sd
+ try{
+ String prompt = "";
+ //判断prompt是否包含中文,中文则翻译,否则跳过
+ if(TranslateUtils.isContainChinese(param.getPrompt())){
+ //调用翻译api
+ prompt = translate(param.getPrompt());
+ }else{
+ prompt = param.getPrompt();
+ }
+ System.out.println(prompt);
+
+ //TODO 根据选择的风格来选择模型
+ Map map = new HashMap<>();
+ map.put("width",param.getWidth() == null ? 512 : param.getWidth());
+ map.put("height",param.getHeight() == null ? 512 : param.getHeight());
+ map.put("prompt", prompt);
+ //map.put("prompt", param.getPrompt());
+ map.put("seed",-1);
+ map.put("batch_size",1);
+ map.put("cfg_scale",7);
+ map.put("restore_faces",false);
+ map.put("tiling",false);
+ map.put("eta",0);
+ map.put("sampler_index","DPM++ 2S a Karras");
+ //map.put("sampler_index",param.getSamplerIndex());
+ map.put("steps",25);
+ map.put("negative_prompt","easynegative,nsfw,naked");
+ //log.info("请求stable_diffusion请求体为:【{}】", JSON.toJSONString(map));
+ JSONObject jsonObject = restTemplate.postForObject(stableDiffusionConfig.getTxt2ImgUrl(), map, JSONObject.class);
+ //log.info("请求stable_diffusion响应体的为:【{}】", JSON.toJSONString(jsonObject));
+ TextToImgVo img = new TextToImgVo();
+ if(jsonObject != null && jsonObject.getJSONArray("images").size() > 0){
+ List images = jsonObject.getJSONArray("images").toJavaList(String.class);
+ img.setImages(images);
+ String paintId = null;
+ Date paintTime = null;
+ for (String image : images) {
+ //base64转file
+ MultipartFile file = new Base64ToMultipartFileUtils(image, "data:image/png;base64", "file", "tempSDImg");
+ //上传图片到七牛云/minio
+ String imgStr = remoteFileService.uploadBanner(file).getData();
+ //保存生辰该图片到ai绘画表
+ AiPaint paint = new AiPaint();
+ //paint.setId(); 主键改成雪花算法后启用
+ paintId = IdUtil.getSnowflakeNextIdStr();
+ paintTime = new Date();
+ paint.setPaintId(paintId);
+ paint.setCreateTime(paintTime);
+ paint.setImgUrl(imgStr);
+ paint.setPrompt(param.getPrompt());
+ paint.setModel(param.getModelName());
+ paint.setStyleName(param.getStyleName());
+ paint.setHeight(param.getHeight() == null ? "512" : String.valueOf(param.getHeight()));
+ paint.setWidth(param.getWidth() == null ? "512" : String.valueOf(param.getWidth()));
+ paint.setIsShow("1");
+ paint.setSource(param.getPlatform());
+ paint.setPainterId(param.getPainterId());
+ paint.setPainterName(param.getPainterName());
+ aiPaintService.insert(paint);
+ //写入ai绘画次数
+ writePlatformUserAiPaintNum(param.getAppType(),param.getPlatform(),param.getPainterId());
+ }
+ img.setPaintId(paintId);
+ img.setPaintTime(paintTime);
+ }
+ return img;
+ }catch (Exception e){
+ log.error("文本翻译错误!"+e);
}
- return img;
- }catch (Exception e){
- log.error("文本翻译错误!"+e);
}
return null;
}
@@ -186,4 +281,9 @@ public class StableDiffusionServiceImpl implements StableDiffusionService {
String redisKey = RedisKeyConstant.PLATFORM_USER_AI_PAINT_KEY + date;
redisService.hashIncr(redisKey, hashKey, -patinNum);
}
+
+ @Override
+ public void setPaintButton(int buttonValue) {
+ redisService.setCacheObject(RedisKeyConstant.AI_PAINT_BUTTON,buttonValue);
+ }
}